In [None]:
# Helper functions for loading experiment results and creating the plots

import json
import matplotlib.pyplot as plt
import numpy as np
from adjustText import adjust_text
from collections.abc import Iterable
import os

def load_json(filepath):
    """Parse the JSON document stored at ``filepath`` and return the decoded object."""
    with open(filepath) as handle:
        payload = json.load(handle)
    return payload


def plot_results(experiment, methods, series_result, const_result, seeds, filename, x_axis_label=None, y_axis_label=None, title=None, legend=True, close=True):
    """Plot seed-averaged results of several methods and save them to ``{filename}.pdf``.

    For every method and seed the file ``{experiment}/{method}_s={seed}_results.json``
    is loaded with ``load_json``. If it contains ``series_result`` (a list of
    ``[x, y]`` pairs) the method is drawn as a curve. Otherwise the scalar stored
    under ``const_result`` is drawn as a horizontal line from 0 to the largest x
    of any curve. Values are multiplied by 100, and the shaded band is +/- one
    standard deviation across the loaded seeds.

    Args:
        experiment: directory containing the per-method, per-seed result files.
        methods: mapping ``result-file prefix -> legend label``; its order sets
            the drawing order (legend order and line colours).
        series_result: key of the ``[x, y]`` series inside each result file.
        const_result: key of the scalar used when ``series_result`` is absent.
        seeds: a single seed or an iterable of seeds to average over.
        filename: output path without the ``.pdf`` extension.
        x_axis_label, y_axis_label, title: optional annotations (skipped if None).
        legend: draw a legend in the lower-right corner.
        close: close the figure after saving, so repeated calls (e.g. in a
            loop) do not accumulate open figures. Pass False to keep it open.

    Returns:
        None. If any result file is missing, its path is printed and the
        function returns without saving; the empty figure is closed. Files that
        fail to load are reported and skipped.
    """
    # Normalise once instead of on every method iteration.
    if not isinstance(seeds, Iterable):
        seeds = [seeds]

    fig, ax = plt.subplots()
    fig.patch.set_alpha(0)
    ax.set_facecolor('white')

    max_x = 0
    to_plot = []

    for method, label in methods.items():
        x_runs, y_runs = [], []
        for seed in seeds:
            path = f"{experiment}/{method}_s={seed}_results.json"
            if not os.path.exists(path):
                print(f"Missing: {path}")
                plt.close(fig)  # do not leave a blank figure behind
                return
            try:
                method_results = load_json(path)
                if series_result in method_results:
                    pairs = method_results[series_result]
                    x_run = np.array([x for x, _ in pairs])
                    y_run = np.array([y for _, y in pairs])
                else:
                    x_run, y_run = 0, method_results[const_result]
            except Exception as e:
                print(f"Problem with file {path}, error: {e}")
                continue
            # Append x and y together so a failure can never leave them misaligned.
            x_runs.append(x_run)
            y_runs.append(y_run)

        if not y_runs:
            # No seed could be loaded: skip rather than plot a NaN line with a legend entry.
            continue

        y_runs = np.array(y_runs) * 100
        y_std = np.std(y_runs, axis=0)
        y_mean = np.mean(y_runs, axis=0)
        x_mean = np.mean(np.array(x_runs), axis=0)

        if isinstance(x_mean, np.ndarray) and x_mean.size > 0:
            max_x = max(max_x, x_mean.max())

        # A series averages to an array; a constant result averages to a scalar float.
        if (isinstance(y_mean, np.ndarray) and y_mean.size > 0) or isinstance(y_mean, float):
            to_plot.append((x_mean, y_mean, y_std, label))

    # Draw only after every method is loaded, so constant lines can span max_x.
    for x_mean, y_mean, y_std, label in to_plot:
        if isinstance(y_mean, np.ndarray):
            ax.plot(x_mean, y_mean, linestyle="-", label=label)
            ax.fill_between(x=x_mean, y1=y_mean - y_std, y2=y_mean + y_std, alpha=0.3)
        else:
            ax.plot([0, max_x], [y_mean, y_mean], linestyle="-", label=label)
            ax.fill_between(x=[0, max_x], y1=[y_mean - y_std] * 2, y2=[y_mean + y_std] * 2, alpha=0.3)

    if legend and to_plot:
        ax.legend(loc='lower right', fontsize=8, ncol=2)
    if title is not None:
        ax.set_title(title)
    if x_axis_label is not None:
        ax.set_xlabel(x_axis_label)
    if y_axis_label is not None:
        ax.set_ylabel(y_axis_label)

    if max_x > 0:  # avoid a degenerate [0, 0] x-range when only constants were plotted
        ax.set_xlim([0, max_x])
    ax.margins(0.15, 0.25)
    fig.savefig(f"{filename}.pdf", dpi=300, bbox_inches='tight')

    if close:
        plt.close(fig)

In [None]:
# Render one PDF per (dataset, k, measure, series) combination of the online
# experiments into online_plots/<dataset>_<measure>_k=<k>_<series>.pdf.
os.makedirs("online_plots", exist_ok=True)

plt.rcParams.update({
    # "figure.figsize": (4, 1.5),  # for smaller plots
    "figure.figsize": (4, 3),
    "figure.dpi": 200,
    "figure.autolayout": False,
    "text.usetex": True,
    'mathtext.fontset': 'stix',
    'text.latex.preamble': "\\usepackage{amsmath}",
    'font.family': 'STIXGeneral',
    'savefig.transparent': False,
})

seeds = [1, 13, 23, 2024, 7700]
results_dir = "../results_online6"
# Figure titles are off by default; set to True to draw "<dataset> --- <measure>".
SHOW_TITLES = False

for dataset, dataset_name in {
        "youtube_deepwalk_plt": "YouTube",
        "mediamill_plt": "Mediamill",
        "flicker_deepwalk_plt": "Flickr",
        # "eurlex_plt": "Eurlex",
        "eurlex_lexglue_plt": "Eurlex-LexGlue",
        # "eurlex_lightxml": "Eurlex",
        "rcv1x_plt": "RCV1-x",
        "amazoncat_plt": "AmazonCat",
        # "amazoncat_lightxml": "AmazonCat",
        # "amazon_lightxml": "Amazon",
        # "wiki500_lightxml": "Wikipedia",
}.items():
    for k in [0, 3]:
        # For k > 0 every measure name gets a uniform "@k" suffix.
        at_k = f"@{k}" if k > 0 else ""
        # (measure id used in result file names, fallback scalar key, display name)
        # NOTE(review): macro_f1 reuses micro_f1's "mF@k" key and macro_precision
        # reuses macro_recall's "mR@k" key. They are only read when a result file
        # lacks the plotted series - confirm they match the keys in the JSON files.
        for measure, short, base_name in [
                ("micro_f1", f"mF@{k}", "Micro F1"),
                ("macro_f1", f"mF@{k}", "Macro F1"),
                ("macro_recall", f"mR@{k}", "Macro Recall"),
                ("macro_precision", f"mR@{k}", "Macro Precision"),
                ("macro_g_mean", f"mG@{k}", "Macro G-mean"),
                ("macro_h_mean", f"mH@{k}", "Macro H-Mean"),
            ]:
            measure_name = base_name + at_k

            # Methods to compare: result-file prefix -> legend label.
            # Insertion order fixes the legend order and line colours.
            methods = {
                f"online_default_{measure}_k={k}": "Top-$k$" if k > 0 else "$\\hat \\eta > 0.5$",
            }
            if measure in ["micro_f1", "macro_f1"] and k == 0:
                methods[f"ofo_{measure}_k={k}"] = "OFO"
            if "micro" not in measure:
                methods[f"online_greedy_{measure}_k={k}"] = "Greedy"
            methods[f"online_frank_wolfe_exp=1.2_{measure}_k={k}"] = "Online-FW"
            methods[f"online_my_{measure}_k={k}"] = "OMMA"

            if "micro" not in measure:
                methods[f"online_greedy_etu_{measure}_k={k}"] = "Greedy$(\\hat \\eta)$"
            methods[f"online_frank_wolfe_etu_exp=1.2_{measure}_k={k}"] = "Online-FW$(\\hat \\eta)$"
            methods[f"online_my_etu_{measure}_k={k}"] = "OMMA$(\\hat \\eta)$"

            for series in ["pred_utility_history"]:  # also available: "solution_utility_history"
                out_path = f"online_plots/{dataset}_{measure}_k={k}_{series}"
                print(out_path)
                plot_results(f"{results_dir}/{dataset}", methods, series, short, seeds, out_path,
                             x_axis_label="$t$",
                             y_axis_label=f"$\\text{{{measure_name}}}(\\boldsymbol{{C}}(\\boldsymbol{{y}}^t, \\boldsymbol{{\\widehat y}}^t))$ (\\%)",
                             title=f"{dataset_name} --- {measure_name}" if SHOW_TITLES else None,
                             legend=True)

In [None]:

# Four vertically stacked panels for a combined figure.
fig, axs = plt.subplots(4, 1)
fig.patch.set_alpha(0)
# plt.subplots(4, 1) returns a NumPy array of Axes, which has no set_facecolor;
# apply the white background to each panel instead.
for panel_ax in axs.flat:
    panel_ax.set_facecolor('white')

def plot_combined_results(axis, experiment, methods, series_result, const_result, seeds, filename, x_axis_label=None, y_axis_label=None, title=None, legend=True):
    """Draw seed-averaged results of several methods onto a single Matplotlib Axes.

    For every method and seed the file ``{experiment}/{method}_s={seed}_results.json``
    is loaded with ``load_json``. If it contains ``series_result`` (a list of
    ``[x, y]`` pairs) the method is drawn as a curve. Otherwise the scalar under
    ``const_result`` is drawn as a horizontal line from 0 to the largest x of
    any curve. Values are multiplied by 100, and the shaded band is +/- one
    standard deviation across the loaded seeds. Nothing is saved here.

    Args:
        axis: the Axes to draw on (e.g. one panel of ``plt.subplots``).
        experiment: directory containing the per-method, per-seed result files.
        methods: mapping ``result-file prefix -> legend label``.
        series_result: key of the ``[x, y]`` series inside each result file.
        const_result: key of the scalar used when ``series_result`` is absent.
        seeds: a single seed or an iterable of seeds to average over.
        filename: unused; kept for signature compatibility.
        x_axis_label, y_axis_label, title: optional annotations (skipped if None).
        legend: draw a legend on ``axis``.

    Returns:
        None. If any result file is missing, its path is printed and nothing is
        drawn. Files that fail to load are reported and skipped.
    """
    if not isinstance(seeds, Iterable):
        seeds = [seeds]

    max_x = 0  # largest x over all curves; bounds constant lines and the x-limit
    to_plot = []

    for method, label in methods.items():
        x_runs, y_runs = [], []
        for seed in seeds:
            path = f"{experiment}/{method}_s={seed}_results.json"
            if not os.path.exists(path):
                print(f"Missing: {path}")
                return
            try:
                method_results = load_json(path)
                if series_result in method_results:
                    pairs = method_results[series_result]
                    x_run = np.array([x for x, _ in pairs])
                    y_run = np.array([y for _, y in pairs])
                else:
                    x_run, y_run = 0, method_results[const_result]
            except Exception as e:
                print(f"Problem with file {path}, error: {e}")
                continue
            # Append x and y together so a failure can never leave them misaligned.
            x_runs.append(x_run)
            y_runs.append(y_run)

        if not y_runs:
            # No seed could be loaded: skip rather than plot a NaN line.
            continue

        y_runs = np.array(y_runs) * 100
        y_std = np.std(y_runs, axis=0)
        y_mean = np.mean(y_runs, axis=0)
        x_mean = np.mean(np.array(x_runs), axis=0)

        if isinstance(x_mean, np.ndarray) and x_mean.size > 0:
            max_x = max(max_x, x_mean.max())

        # A series averages to an array; a constant result averages to a scalar float.
        if (isinstance(y_mean, np.ndarray) and y_mean.size > 0) or isinstance(y_mean, float):
            to_plot.append((x_mean, y_mean, y_std, label))

    # Draw only after every method is loaded, so constant lines can span max_x.
    for x_mean, y_mean, y_std, label in to_plot:
        if isinstance(y_mean, np.ndarray):
            axis.plot(x_mean, y_mean, linestyle="-", label=label)
            axis.fill_between(x=x_mean, y1=y_mean - y_std, y2=y_mean + y_std, alpha=0.3)
        else:
            axis.plot([0, max_x], [y_mean, y_mean], linestyle="-", label=label)
            axis.fill_between(x=[0, max_x], y1=[y_mean - y_std] * 2, y2=[y_mean + y_std] * 2, alpha=0.3)

    if legend and to_plot:
        axis.legend()
    # Axes objects expose set_* methods; xlabel/ylabel/xlim exist only on pyplot.
    if title is not None:
        axis.set_title(title)
    if x_axis_label is not None:
        axis.set_xlabel(x_axis_label)
    if y_axis_label is not None:
        axis.set_ylabel(y_axis_label)
    if max_x > 0:  # avoid a degenerate [0, 0] x-range
        axis.set_xlim([0, max_x])


# NOTE(review): the original saved to `filename`, which is undefined at module
# level (NameError on a fresh kernel); this output path is a placeholder -
# confirm the intended name. No call to plot_combined_results is visible in this
# notebook, so the panels are currently empty.
combined_filename = "online_plots/combined_results"
os.makedirs(os.path.dirname(combined_filename), exist_ok=True)
# plt.margins only affected the current (last) panel; apply it to every panel.
for panel_ax in axs.flat:
    panel_ax.margins(0.15, 0.25)
fig.savefig(f"{combined_filename}.pdf", dpi=300, bbox_inches='tight')