# Summarize Online Results

This notebook generates the result tables for the paper on the OMMA algorithm.

## Main results

In [4]:
import json
import os
import numpy as np

In [None]:
# Random seeds of the runs whose results are averaged for every table cell.
seeds = [1, 13, 23, 2024, 7700]
results_dir = "../results_online_replicate_new_reg_2"
# (results sub-directory, LaTeX caption) pairs, one per dataset.
datasets = [
        ("youtube_deepwalk_plt", "\\datasettable{YouTube} ($m = 46, n = 7926$)"),
        ("eurlex_lexglue_plt", "\\datasettable{Eurlex-LexGlue} ($m = 100, n = 5000$)"),
        ("mediamill_plt", "\\datasettable{Mediamill} ($m = 101, n = 12914$)"),
        ("flicker_deepwalk_plt", "\\datasettable{Flickr} ($m = 195, n = 24154$)"),
        ("rcv1x_plt", "\\datasettable{RCV1X} ($m = 2456, n = 155962$)"),
        ("amazoncat_plt", "\\datasettable{AmazonCat} ($m = 13330, n = 306784$)"),

        #("eurlex_plt", "Eurlex"),
        #("eurlex_lightxml", "Eurlex"),
        #("eurlex_plt", "Eurlex ($m = 3993, n = 3809$)"),
        #("eurlex_lightxml", "Eurlex ($m = 5600, n = 5000$)"),
        #("amazoncat_lightxml", "AmazonCat"),
        #("amazon_lightxml", "Amazon")
        # "wiki500_lightxml", "Wikipedia",
]
# (metric tag used in result file names, short key, display name) triples.
metrics = [
    ("micro_f1_k=0", "miF@0", "Micro F1"),
    ("macro_f1_k=0", "mF@0", "Macro F1"),
    ("macro_f1_k=3", "mF@3", "Macro F1$@3$"),
    #("macro_f1_k=5", f"mF@5", "Macro F1$@5$"),
    ("macro_recall_k=3", "mR@3", "Macro Recall$@3$"),
    ("macro_precision_k=3", "mP@3", "Macro Prec.$@3$"),
    #("macro_recall_k=5", f"mR@5", "Macro Recall$@5$"),
    ("macro_gmean_k=0", "mG@0", "Macro G-Mean"),
    ("macro_hmean_k=0", "mH@0", "Macro H-Mean"),
    #("macro_min_tp_tn_k=0", "min_tp_tn@0", "Macro $\\min(\\text{tp},\\text{tn})$"),
    #("macro_precision", f"mP@{k}", "Macro Precision"),
    #("macro_min_tp_tn", f"min_tp_tn@{k}", "Macro-Min(TP,TN)"),
]

# (method tag used in result file names, LaTeX row label) pairs.
# NOTE: this first ordering is overridden by the second `methods` assignment below.
methods = [
    #("default_prediction", "Top-K / $\hat \eta(x) > 0.5$"),
    # Backslashes are doubled so the literal contains no invalid escape sequences.
    ("online_default", "Top-$k$ / $\\hat \\eta \\!>\\!0.5$"),

    ("ofo", "OFO"),
    ("online_greedy", "Greedy"),
    ("online_frank_wolfe_exp=1.1", "Online-FW"),
    ("omma", "\\OMMA{}"),
    #("online_my", "OMMA"),
    #("frank_wolfe", "FW"),
    #("frank_wolfe_on_test", "FW-t"),

    #("ofo_etu", "OFO$(\\hat \\eta)$"),
    #("online_greedy_etu", "Greedy$(\\hat \\eta)$"),
    ("online_frank_wolfe_etu_exp=1.1", "Online-FW$(\\hat \\eta)$"),
    #("omma_etu", "OMMA$(\\hat \\eta)$"),
    ("omma_etu", "\\OMMAeta{}"),
    #("online_my_etu", "OMMA$(\\hat \\eta)$"),
    #("frank_wolfe_on_train_etu", "FW$(\\hat \\eta)$"),
    #("frank_wolfe_etu", "FW-t$(\\hat \\eta)$"),
    ("frank_wolfe", "Offline-FW"),
]

# Method ordering actually used by the cells below.
methods = [
    #("default_prediction", "Top-K / $\hat \eta(x) > 0.5$"),
    ("online_default", "Top-$k$ / $\\hat \\eta \\!>\\!0.5$"),

    ("ofo", "OFO"),
    ("online_greedy", "Greedy"),
    ("online_frank_wolfe_exp=1.1", "Online-FW"),
    ("online_frank_wolfe_etu_exp=1.1", "Online-FW$(\\hat \\eta)$"),
    #("omma", "OMMA"),
    #("omma_etu", "OMMA$(\\hat \\eta)$"),
    ("omma", "\\OMMA{}"),
    ("omma_etu", "\\OMMAeta{}"),
    #("frank_wolfe", "Offline-FW"),
]

def load_json(filepath):
    """Read the JSON document stored at ``filepath`` and return the parsed object.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default locale encoding.
    """
    with open(filepath, encoding="utf-8") as file:
        return json.load(file)


def check_cell(c, command):
    """Wrap the stripped text of ``c`` in a LaTeX ``command`` call.

    A cell that is empty or whitespace-only is returned as an empty string
    so that no empty command is emitted.
    """
    content = c.strip()
    return f"\\{command}{{{content}}}" if content else ""

def wrap_calls(text, command):
    """Apply the LaTeX ``command`` to every ``&``-separated cell of ``text``."""
    wrapped_cells = (check_cell(cell, command) for cell in text.split("&"))
    return "&".join(wrapped_cells)


def wrap_calls_md(text, command):
    """Surround every ``|``-separated cell of ``text`` with the Markdown marker ``command``."""
    return "|".join(command + cell + command for cell in text.split("|"))


def get_result(results_dirs, dataset, method, metric, metric_key, seeds=seeds, lower_better=False):
    """Return ``(mean, std)`` of the best seed-averaged result over several directories.

    Args:
        results_dirs: A results directory or a list of them.
        dataset: Name of the dataset sub-directory.
        method: Method tag used as the result file name prefix.
        metric: Metric tag used in the result file name.
        metric_key: Key read from the result JSON. For
            ``"pred_utility_history"`` the last value of the last history
            entry is taken and multiplied by 100.
        seeds: Seeds whose result files are averaged; missing files are skipped.
        lower_better: If true the directory with the smallest mean wins,
            otherwise the one with the largest mean.

    Returns:
        The winning directory's mean and standard deviation over seeds. When
        no usable result file is found the initial values are returned:
        ``(0, 0)``, or ``(1e9, 0)`` if ``lower_better`` is set.
    """
    if not isinstance(results_dirs, list):
        results_dirs = [results_dirs]

    # These start values double as the "nothing found" return value.
    best_result = 1e9 if lower_better else 0
    best_result_std = 0

    # Fallback file names use only the last "_"-separated token of the metric tag.
    metric_suffix = metric.split('_')[-1]

    for r_dir in results_dirs:
        results = []
        for s in seeds:
            candidates = (
                f"{r_dir}/{dataset}/{method}_{metric}_s={s}_results.json",
                f"{r_dir}/{dataset}/{method}_{metric_suffix}_s={s}_results.json",
            )
            path = next((p for p in candidates if os.path.exists(p)), None)
            if path is None:
                continue
            method_results = load_json(path)

            try:
                if metric_key == "pred_utility_history":
                    results.append(method_results["pred_utility_history"][-1][-1] * 100)
                else:
                    results.append(method_results[metric_key])
            except (KeyError, IndexError, TypeError) as e:
                # Best effort: a file without the requested entry is reported and skipped.
                print(f"Skipping {path}: {e!r}")
                continue

        if not results:
            continue

        results = np.array(results)
        result = results.mean()
        improved = result < best_result if lower_better else result > best_result
        if improved:
            best_result = result
            best_result_std = results.std()

    return best_result, best_result_std

# Create data
def _highlight_top_two(values, cells, wrap, best_cmd, second_cmd):
    """Mark, in place, the best cells with ``best_cmd`` and the runner-up cells with ``second_cmd``.

    ``values`` holds the numbers to rank (higher is better), ``cells`` the
    matching formatted strings, and ``wrap(cell, cmd)`` returns the decorated
    cell. Every entry tied for the highest value counts as best and every
    entry tied for the second-highest distinct value counts as runner-up.
    """
    levels = sorted({float(v) for v in values}, reverse=True)[:2]
    for cmd, level in zip((best_cmd, second_cmd), levels):
        for idx, value in enumerate(values):
            if value == level:
                cells[idx] = wrap(cells[idx], cmd)


def create_main_table(results_dirs, datasets, methods, metrics, seeds, header, rules, format="tex"):
    """Print the main results table as LaTeX (``format="tex"``) or Markdown (any other value).

    Args:
        results_dirs: Directory or list of directories passed to ``get_result``.
        datasets: ``(directory name, caption)`` pairs.
        methods: ``(method tag, row label)`` pairs; one table row per method.
        metrics: ``(metric tag, key, display name)`` triples; one column per metric.
        seeds: Seeds forwarded to ``get_result``.
        header: Text printed before the table (LaTeX) or before each dataset (Markdown).
        rules: Row indices within a dataset group after which ``\\midrule``
            is added (LaTeX only).
        format: ``"tex"`` for LaTeX output, anything else for Markdown.

    In LaTeX mode two datasets share one group of rows (left and right half).
    In Markdown mode every dataset gets its own table. In each column the best
    value is bold and the runner-up italic; a value of 0 is printed as a
    "missing" marker.
    """
    is_tex = format == "tex"
    n_methods = len(methods)
    n_metrics = len(metrics)

    table = []
    for i, (dataset, dataset_name) in enumerate(datasets):
        # LaTeX puts two datasets side by side, so new rows start only on even datasets.
        if i % 2 == 0 or not is_tex:
            table.extend([method_name] for method, method_name in methods)

        for metric, metric_key, metric_name in metrics:
            table_col = np.array([
                get_result(results_dirs, dataset, method, metric, "pred_utility_history", seeds=seeds)[0]
                for method, method_name in methods
            ]).round(2)

            if is_tex:
                table_col_str = [f"{x:.2f}" if x != 0 else "$\\times$" for x in table_col]
                _highlight_top_two(table_col, table_col_str, wrap_calls, "textbf", "textit")
            else:
                table_col_str = [f"{x:.2f}" if x != 0 else "x" for x in table_col]
                _highlight_top_two(table_col, table_col_str, wrap_calls_md, "**", "*")

            # Append this column to the rows of the current dataset group.
            for row, cell in zip(table[-n_methods:], table_col_str):
                row.append(cell)

    if is_tex:
        print(header)

    d = 0
    for i, table_row in enumerate(table):
        if is_tex:
            if i % n_methods == 0:
                left = datasets[d][1]
                # With an odd number of datasets the last group has no right half.
                right = datasets[d + 1][1] if d + 1 < len(datasets) else ""
                sec_header = (
                    f"\\midrule\n& \\multicolumn{{{n_metrics}}}{{c|}}{{{left}}}"
                    f" & \\multicolumn{{{n_metrics}}}{{c}}{{{right}}} \\\\ \n \\midrule"
                )
                d += 2
                print(sec_header)

            row = "    " + " & ".join(table_row) + " \\\\"
            for r in rules:
                if i % n_methods == r:
                    row += "\n    \\midrule"
        else:
            # Pad the method column to 30 characters and value columns to 10.
            padded = [t.ljust(30 if col == 0 else 10) for col, t in enumerate(table_row)]
            if i % n_methods == 0:
                print(f"**{datasets[d][1]}**")
                d += 1
                print(header)
            row = "| " + " | ".join(padded) + " |"
        print(row)

    if is_tex:
        print("\\bottomrule\n\\end{tabular}")


# LaTeX preamble of the main table: one method column plus seven metric
# columns for each of the two datasets shown side by side.
header1 = "\\begin{tabular}{l|c|cccccc|c|cccccc}"
header2 = "\\toprule\n    Method & Micro & \\multicolumn{6}{c|}{Macro} & Micro & \\multicolumn{6}{c}{Macro} \\\\"
# Column titles are the last word of every metric's display name, listed
# once for the left dataset and once for the right one.
header3 = (
    "    & "
    + " & ".join(" & ".join(metric[2].split(' ')[-1] for metric in metrics) for _ in range(2))
    + " \\\\"
)
header = "\n".join((header1, header2, header3))

# Markdown header used when the table is rendered with format="md".
header_md = """
| Method | Mi-F1 | Ma-F1 | Ma-F1@3 | Ma-R@3 | Ma-P@3 | Ma-G-Mean | Ma-H-Mean |
|:-------|------:|------:|--------:|-------:|-------:|----------:|----------:|
""".strip()

# A \midrule follows rows 0 and 4 of every dataset group.
rules = [0, 4]
# The best seed-averaged result over all of these directories is reported.
results_dirs = [
    "../results_online_replicate_new_reg_0",
    "../results_online_replicate_new_reg_1e-6",
    "../results_online_replicate_new_reg_1e-3",
    "../results_online_replicate_new_reg_0.1",
    "../results_online_replicate_new_reg_0.1_v2",
    "../results_online_replicate_new_reg_1",
    "../results_online_replicate_old",
    "../results_online_replicate_new_reg_0.1_fw",
]

create_main_table(results_dirs, datasets, methods, metrics, seeds, header, rules, format="tex")
#create_main_table(results_dirs, datasets, methods, metrics, seeds, header_md, rules, format="md")

#################################
# Multi-class classification
#################################

results_dir = "../results_online_multiclass_2"
# (results sub-directory, caption) pairs.
# NOTE(review): the "fill" entries are presumably placeholders that leave the
# right-hand half of the two-dataset LaTeX layout empty -- confirm.
datasets = [
        #("sensorless_hsm", "Sensorless ($m = 11, n = 58509$)"),
        ("news20_hsm", "News20 ($m = 20, n = 7532$)"),
        ("fill", "News ($m = 20, n = 7532$)"),
        ("ledgar_hsm", "Ledgar ($m = 100, n = 10000$)"),
        ("fill", "News ($m = 20, n = 7532$)"),
        #("cal101_hsm", "Caltech 101 ($m = 101, n = 4339$)"),
        #("cal101_lr", "Caltech 101 ($m = 101, n = 4339$)"),
        #("cal256_hsm", "Caltech 256 ($m = 256, n = 14890$)"),
        ("cal256_lr", "Caltech 256 ($m = 256, n = 14890$)"),
        ("fill", "News ($m = 20, n = 7532$)"),
        #("aloi.bin_hsm", "Protein ($m = 3548, n = 6621$)"),
]
# (metric tag used in result file names, short key, display name) triples.
metrics = [
    #("micro_f1_k=1", "miF@1", "Micro F1"),
    ("macro_f1_k=1", "mF@1", "Macro F1"),
    ("macro_f1_k=3", "mF@3", "Macro F1$@3$"),
    ("macro_recall_k=3", "mR@3", "Macro Recall$@3$"),
    ("macro_precision_k=3", "mP@3", "Macro Prec.$@3$"),
    ("multiclass_gmean_k=1", "G@1", "G-Mean"),
    ("multiclass_hmean_k=1", "H@1", "H-Mean"),
    ("multiclass_qmean_k=1", "Q@1", "Q-Mean"),
]
# NOTE: this first ordering is overridden by the second `methods` assignment below.
methods = [
    #("default_prediction", "Top-K / $\hat \eta(x) > 0.5$"),
    ("online_default", "Top-$1$ / Top-$k$"),

    ("online_greedy", "Greedy"),
    ("online_frank_wolfe_exp=1.1", "Online-FW"),
    ("omma", "OMMA"),
    #("frank_wolfe_on_test", "FW$(\\boldsymbol{y})$"),

    #("ofo_etu", "OFO$(\\hat \\eta)$"),
    #("online_greedy_etu", "Greedy$(\\hat \\eta)$"),
    ("online_frank_wolfe_etu_exp=1.1", "Online-FW$_{\\hat \\eta}$"),
    ("omma_etu", "OMMA$_{\\hat \\eta}$"),
    #("frank_wolfe_etu", "FW$(\\hat \\eta)$"),
    #("frank_wolfe", "Offline-FW"),
]

# Method ordering actually used for the multi-class tables.
methods = [
    #("default_prediction", "Top-K / $\hat \eta(x) > 0.5$"),
    ("online_default", "Top-$1$ / Top-$k$"),
    ("online_greedy", "Greedy"),
    ("online_frank_wolfe_exp=1.1", "Online-FW"),
    ("online_frank_wolfe_etu_exp=1.1", "Online-FW$_{\\hat \\eta}$"),
    ("omma", "OMMA"),
    ("omma_etu", "OMMA$_{\\hat \\eta}$"),
    ("frank_wolfe", "Offline-FW"),
]

header1 = "\\begin{tabular}{l|cccc|ccc|cccc|ccc}"
# The group title is spelled "Multi-class" (the original literal read "Mulit-class").
header2 = "\\toprule\n    Method & \\multicolumn{4}{c|}{Macro} & \\multicolumn{3}{c|}{Multi-class}& \\multicolumn{4}{c|}{Macro} & \\multicolumn{3}{c|}{Multi-class} \\\\"
# Column titles: last word of every metric's display name, once per side-by-side dataset.
header3 = "    & " + " & ".join([f"{metric[2].split(' ')[-1]}" for metric in metrics]) + " & " + " & ".join([f"{metric[2].split(' ')[-1]}" for metric in metrics]) + " \\\\"
header = f"{header1}\n{header2}\n{header3}"

header_md = """
| Method | Ma-F1 | Ma-F1@3 | Ma-R@3 | Ma-P@3 | MC-G-Mean | MC-H-Mean | MC-Q-Mean |
|:-------|------:|--------:|--------:|------:|----------:|----------:|----------:|
"""
header_md = header_md.strip()

# The best seed-averaged result over all of these directories is reported.
results_dirs = [
    "../results_online_multiclass_reg_0",
    "../results_online_multiclass_reg_1e-6",
    "../results_online_multiclass_reg_1e-3",
    "../results_online_multiclass_reg_0.1",
    "../results_online_multiclass_reg_0.1_v2",
    "../results_online_multiclass_reg_1",
]


# A \midrule follows rows 0 and 3 of every dataset group.
rules = [0, 3]
#create_main_table(results_dirs, datasets, methods, metrics, seeds, header, rules, format="tex")
#create_main_table(results_dirs, datasets, methods, metrics, seeds, header_md, rules, format="md")
#create_main_table("../results_online_multiclass_0", datasets, methods, metrics, seeds, header_md, rules, format="md")
#create_main_table("../results_online_multiclass_10", datasets, methods, metrics, seeds, header_md, rules, format="md")
#create_main_table("../results_online_multiclass_100", datasets, methods, metrics, seeds, header_md, rules, format="md")
#create_main_table("../results_online_multiclass_1000", datasets, methods, metrics, seeds, header_md, rules, format="md")
#print("\n\n========================================================================")
#print("\n\n------------------------------------------------------------------------")
#create_main_table("../results_online_multiclass_reg_0", datasets, methods, metrics, seeds, header_md, rules, format="md")
#print("\n\n------------------------------------------------------------------------")
#create_main_table("../results_online_multiclass_reg_1e-6", datasets, methods, metrics, seeds, header_md, rules, format="md")
#print("\n\n------------------------------------------------------------------------")
#create_main_table("../results_online_multiclass_reg_1e-3", datasets, methods, metrics, seeds, header_md, rules, format="md")
#print("\n\n------------------------------------------------------------------------")
#create_main_table("../results_online_multiclass_reg_0.1", datasets, methods, metrics, seeds, header_md, rules, format="md")
#print("\n\n------------------------------------------------------------------------")
#create_main_table("../results_online_multiclass_reg_1", datasets, methods, metrics, seeds, header_md, rules, format="md")

## Extended results

In [None]:
results_dir = "../results_online6"
# (results sub-directory, LaTeX caption) pairs, one per dataset.
datasets = [
        ("youtube_deepwalk_plt", "\\datasettable{YouTube} ($m = 46, n = 7926$)"),
        ("eurlex_lexglue_plt", "\\datasettable{Eurlex-LexGlue} ($m = 100, n = 5000$)"),
        ("mediamill_plt", "\\datasettable{Mediamill} ($m = 101, n = 12914$)"),
        ("flicker_deepwalk_plt", "\\datasettable{Flickr} ($m = 195, n = 24154$)"),
        #("eurlex_plt", "Eurlex"),
        #("eurlex_lightxml", "Eurlex"),
        ("rcv1x_plt", "\\datasettable{RCV1X} ($m = 2456, n = 155962$)"),
        #("eurlex_plt", "Eurlex ($m = 3993, n = 3809$)"),
        #("eurlex_lightxml", "Eurlex ($m = 5600, n = 5000$)"),
        ("amazoncat_plt", "\\datasettable{AmazonCat} ($m = 13330, n = 306784$)"),
        #("amazoncat_lightxml", "AmazonCat"),
        #("amazon_lightxml", "Amazon")
        #("wiki500_lightxml", "Wikipedia"),
]
# (metric tag used in result file names, short key, display name) triples.
metrics = [
    ("micro_f1_k=0", "miF@0", "Micro F1"),
    ("macro_f1_k=0", "mF@0", "Macro F1"),
    ("macro_f1_k=3", "mF@3", "Macro F1$@3$"),
    #("macro_f1_k=5", f"mF@5", "Macro F1$@5$"),
    ("macro_recall_k=3", "mR@3", "Macro Recall$@3$"),
    ("macro_precision_k=3", "mP@3", "Macro Prec.$@3$"),
    #("macro_recall_k=5", f"mR@5", "Macro Recall$@5$"),
    ("macro_gmean_k=0", "mG@0", "Macro G-Mean"),
    ("macro_hmean_k=0", "mH@0", "Macro H-Mean"),
    #("macro_min_tp_tn_k=0", "min_tp_tn@0", "Macro $\\min(\\text{tp},\\text{tn})$"),
    #("macro_precision", f"mP@{k}", "Macro Precision"),
    #("macro_min_tp_tn", f"min_tp_tn@{k}", "Macro-Min(TP,TN)"),
]
# NOTE: this first ordering is overridden by the second `methods` assignment below.
methods = [
    #("default_prediction", "Top-K / $\hat \eta(x) > 0.5$"),
    # Backslashes are doubled so the literal contains no invalid escape sequences.
    ("online_default", "Top-$k$ / $\\hat \\eta \\!>\\!0.5$"),

    ("ofo", "OFO$(\\boldsymbol{y})$"),
    ("online_greedy", "Greedy$(\\boldsymbol{y})$"),
    ("online_frank_wolfe_exp=1.1", "Online-FW$(\\boldsymbol{y})$"),
    #("frank_wolfe_on_test", "FW$(\\boldsymbol{y})$"),

    ("ofo_etu", "OFO$(\\hat \\eta)$"),
    ("online_greedy_etu", "Greedy$(\\hat \\eta)$"),
    ("online_frank_wolfe_etu_exp=1.1", "Online-FW$(\\hat \\eta)$"),
    ("omma", "OMMA$(\\boldsymbol{y})$"),
    ("omma_etu", "OMMA$(\\hat \\eta)$"),
    ("frank_wolfe", "Offline-FW"),
]

# Method ordering actually used; the last entry is handled as the offline
# reference row by the extended `create_table`.
methods = [
    #("default_prediction", "Top-K / $\hat \eta(x) > 0.5$"),
    ("online_default", "Top-$k$ / $\\hat \\eta \\!>\\!0.5$"),
    ("ofo", "OFO"),
    ("online_greedy", "Greedy"),
    ("online_greedy_etu", "Greedy$(\\hat \\eta)$"),
    ("online_frank_wolfe_exp=1.1", "Online-FW"),
    ("online_frank_wolfe_etu_exp=1.1", "Online-FW$(\\hat \\eta)$"),
    #("omma", "OMMA"),
    #("omma_etu", "OMMA$(\\hat \\eta)$"),
    ("omma", "\\OMMA{}"),
    ("omma_etu", "\\OMMAeta{}"),
    ("frank_wolfe", "Offline-FW"),
]

# Random seeds of the runs whose results are averaged for every table cell.
seeds = [1, 13, 23, 2024, 7700]


def check_cell(c, command):
    """Return ``c`` stripped and wrapped in a LaTeX ``command`` call, or ``""`` for a blank cell."""
    stripped = c.strip()
    if not stripped:
        return ""
    return "\\" + command + "{" + stripped + "}"

def wrap_calls(text, command):
    """Wrap each ``&``-separated cell of ``text`` in the LaTeX ``command``; blank cells stay empty."""
    wrapped = []
    for cell in text.split("&"):
        wrapped.append(check_cell(cell, command))
    return "&".join(wrapped)
    


# Create data
def _highlight_top_two(values, cells, wrap, best_cmd, second_cmd):
    """Mark, in place, the best cells with ``best_cmd`` and the runner-up cells with ``second_cmd``.

    ``values`` holds the numbers to rank (higher is better), ``cells`` the
    matching formatted strings, and ``wrap(cell, cmd)`` returns the decorated
    cell. Every entry tied for the highest value counts as best and every
    entry tied for the second-highest distinct value counts as runner-up.
    """
    levels = sorted({float(v) for v in values}, reverse=True)[:2]
    for cmd, level in zip((best_cmd, second_cmd), levels):
        for idx, value in enumerate(values):
            if value == level:
                cells[idx] = wrap(cells[idx], cmd)


def create_table(results_dirs, datasets, methods, metrics, seeds, header, rules):
    """Print the extended LaTeX table with utility and running time per metric.

    Args:
        results_dirs: Directory or list of directories passed to ``get_result``.
        datasets: ``(directory name, caption)`` pairs; one row group per dataset.
        methods: ``(method tag, row label)`` pairs. The last entry is a
            reference row: it is excluded from the bold/italic ranking and its
            time cell is always ``$-$``.
        metrics: ``(metric tag, key, display name)`` triples.
        seeds: Seeds forwarded to ``get_result``.
        header: LaTeX text printed before the table body.
        rules: Row indices within a dataset group after which ``\\midrule`` is added.

    Every metric contributes a "mean & std" utility cell and a "mean & std"
    time cell, i.e. four tabular columns. Missing results are printed as
    ``$\\times$``.
    """
    n_methods = len(methods)
    ranked_methods = methods[:-1]
    reference_method = methods[-1][0]

    table = []
    for dataset, dataset_name in datasets:
        table.extend([method_name] for method, method_name in methods)

        for metric, metric_key, metric_name in metrics:
            # Utility columns: rank every method except the reference one.
            stats = [
                get_result(results_dirs, dataset, method, metric, "pred_utility_history", seeds=seeds)
                for method, method_name in ranked_methods
            ]
            table_col = np.array([mean for mean, std in stats]).round(2)
            table_col_str = [
                f"{x:.2f} & \\scriptsize $\\pm$ {x_std:.2f}" if x != 0 else "& $\\times$"
                for x, (mean, x_std) in zip(table_col, stats)
            ]
            _highlight_top_two(table_col, table_col_str, wrap_calls, "textbf", "textit")

            mean, std = get_result(results_dirs, dataset, reference_method, metric, "pred_utility_history", seeds=seeds)
            table_col_str.append(f"{mean:.2f} & \\scriptsize $\\pm$ {std:.2f}" if mean != 0 else "& $\\times$")

            for row, cell in zip(table[-n_methods:], table_col_str):
                row.append(cell)

            # Time columns. NOTE(review): the lowest mean time over results_dirs
            # is picked independently of the directory that gave the best utility.
            times = [
                get_result(results_dirs, dataset, method, metric, "time", seeds=seeds, lower_better=True)
                if method != "frank_wolfe" else (0, 0)
                for method, method_name in ranked_methods
            ]
            # 0 and 1e9 are the "nothing found" values returned by get_result.
            table_col_str = [
                f"{x:.2f} & \\scriptsize $\\pm$ {x_std:.2f}" if x != 0 and x != 1e9 else "& $\\times$"
                for x, x_std in times
            ]
            table_col_str.append("& $-$")

            for row, cell in zip(table[-n_methods:], table_col_str):
                row.append(cell)

    print(header)
    # Each metric spans four tabular columns (utility mean/std, time mean/std).
    span = 4 * len(metrics)
    for i, table_row in enumerate(table):
        if i % n_methods == 0:
            caption = datasets[i // n_methods][1]
            print(f"\\midrule\n& \\multicolumn{{{span}}}{{c}}{{{caption}}}  \\\\ \n \\midrule")
        row = "    " + " & ".join(table_row) + " \\\\"
        for r in rules:
            if i % n_methods == r:
                row += "\n    \\midrule"
        print(row)

    print("\\bottomrule\n\\end{tabular}")


# One method column followed by seven metric groups of four columns each
# (utility mean/std and time mean/std).
header1 = "\\begin{tabular}{l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l}"
#header2 = "\\toprule\n    Method & \\multicolumn{4}{c|}{Micro} & \\multicolumn{24}{c}{Macro} \\\\"
header2 = "\\toprule\n    Method & " + " & ".join([f"\\multicolumn{{4}}{{c|}}{{{metric[2]}}}" for metric in metrics[:-1]]) + f"& \\multicolumn{{4}}{{c}}{{{metrics[-1][2]}}}" + " \\\\"
# "\\%" keeps the literal backslash-percent without an invalid escape sequence.
header3 = "    " + "& \\multicolumn{2}{c}{(\\%)} & \\multicolumn{2}{c|}{time (s)} " * len(metrics[:-1]) + "& \\multicolumn{2}{c}{(\\%)} & \\multicolumn{2}{c}{time (s)}" + " \\\\"
header = f"{header1}\n{header2}\n{header3}"


# The best seed-averaged result over all of these directories is reported.
results_dirs = [
    "../results_online_replicate_new_reg_0",
    "../results_online_replicate_new_reg_1e-6",
    "../results_online_replicate_new_reg_1e-3",
    "../results_online_replicate_new_reg_0.1",
    "../results_online_replicate_new_reg_0.1_v2",
    "../results_online_replicate_new_reg_1",
    "../results_online_replicate_old",
]

# A \midrule follows rows 0, 5 and 7 of every dataset group.
rules = [0, 5, 7]
create_table(results_dirs, datasets, methods, metrics, seeds, header, rules)


# The best seed-averaged result over all of these directories is reported.
results_dirs = [
    "../results_online_multiclass_reg_0",
    "../results_online_multiclass_reg_1e-6",
    "../results_online_multiclass_reg_1e-3",
    "../results_online_multiclass_reg_0.1",
    "../results_online_multiclass_reg_0.1_v2",
    "../results_online_multiclass_reg_1",
]

# (results sub-directory, caption) pairs, one per dataset.
datasets = [
        ("news20_hsm", "News20 ($m = 20, n = 7532$)"),
        ("ledgar_hsm", "Ledgar ($m = 100, n = 10000$)"),
        ("cal256_lr", "Caltech 256 ($m = 256, n = 14890$)"),
]
# (metric tag used in result file names, short key, display name) triples.
metrics = [
    #("micro_f1_k=1", "miF@1", "Micro F1"),
    ("macro_f1_k=1", "mF@1", "Macro F1"),
    ("macro_f1_k=3", "mF@3", "Macro F1$@3$"),
    ("macro_recall_k=3", "mR@3", "Macro Recall$@3$"),
    ("macro_precision_k=3", "mP@3", "Macro Prec.$@3$"),
    ("multiclass_gmean_k=1", "G@1", "Multi-class G-Mean"),
    ("multiclass_hmean_k=1", "H@1", "Multi-class H-Mean"),
    ("multiclass_qmean_k=1", "Q@1", "Multi-class Q-Mean"),
]

# (method tag, row label) pairs; the last entry is the offline reference row.
methods = [
    #("default_prediction", "Top-K / $\hat \eta(x) > 0.5$"),
    ("online_default", "Top-$k$"),
    ("online_greedy", "Greedy"),
    ("online_greedy_etu", "Greedy$(\\hat \\eta)$"),
    ("online_frank_wolfe_exp=1.1", "Online-FW"),
    ("online_frank_wolfe_etu_exp=1.1", "Online-FW$(\\hat \\eta)$"),
    ("omma", "OMMA"),
    ("omma_etu", "OMMA$(\\hat \\eta)$"),
    ("frank_wolfe", "Offline-FW"),
]


header1 = "\\begin{tabular}{l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l}"
#header2 = "\\toprule\n    Method & \\multicolumn{4}{c|}{Micro} & \\multicolumn{24}{c}{Macro} \\\\"
header2 = "\\toprule\n    Method & " + " & ".join([f"\\multicolumn{{4}}{{c|}}{{{metric[2]}}}" for metric in metrics[:-1]]) + f"& \\multicolumn{{4}}{{c}}{{{metrics[-1][2]}}}" + " \\\\"
# "\\%" keeps the literal backslash-percent without an invalid escape sequence.
header3 = "    " + "& \\multicolumn{2}{c}{(\\%)} & \\multicolumn{2}{c|}{time (s)} " * len(metrics[:-1]) + "& \\multicolumn{2}{c}{(\\%)} & \\multicolumn{2}{c}{time (s)}" + " \\\\"
header = f"{header1}\n{header2}\n{header3}"

# A \midrule follows rows 0, 4 and 6 of every dataset group.
rules = [0, 4, 6]
create_table(results_dirs, datasets, methods, metrics, seeds, header, rules)


## Results for the online CPE

In [None]:
def _highlight_top_two(values, cells, wrap, best_cmd, second_cmd):
    """Mark, in place, the best cells with ``best_cmd`` and the runner-up cells with ``second_cmd``.

    ``values`` holds the numbers to rank (higher is better), ``cells`` the
    matching formatted strings, and ``wrap(cell, cmd)`` returns the decorated
    cell. Every entry tied for the highest value counts as best and every
    entry tied for the second-highest distinct value counts as runner-up.
    """
    levels = sorted({float(v) for v in values}, reverse=True)[:2]
    for cmd, level in zip((best_cmd, second_cmd), levels):
        for idx, value in enumerate(values):
            if value == level:
                cells[idx] = wrap(cells[idx], cmd)


def create_table(results_dirs, datasets, methods, metrics, seeds, header, rules):
    """Print the LaTeX table with utility and running time per metric for all methods.

    Unlike the extended-results variant defined earlier in the notebook, every
    method (including the last one) takes part in the bold/italic ranking.

    Args:
        results_dirs: Directory or list of directories passed to ``get_result``.
        datasets: ``(directory name, caption)`` pairs; one row group per dataset.
        methods: ``(method tag, row label)`` pairs; one table row per method.
        metrics: ``(metric tag, key, display name)`` triples.
        seeds: Seeds forwarded to ``get_result``.
        header: LaTeX text printed before the table body.
        rules: Row indices within a dataset group after which ``\\midrule`` is added.

    Every metric contributes a "mean & std" utility cell and a "mean & std"
    time cell, i.e. four tabular columns. Missing results are printed as
    ``$\\times$``.
    """
    n_methods = len(methods)

    table = []
    for dataset, dataset_name in datasets:
        table.extend([method_name] for method, method_name in methods)

        for metric, metric_key, metric_name in metrics:
            # Utility columns, with the best value bold and the runner-up italic.
            stats = [
                get_result(results_dirs, dataset, method, metric, "pred_utility_history", seeds=seeds)
                for method, method_name in methods
            ]
            table_col = np.array([mean for mean, std in stats]).round(2)
            table_col_str = [
                f"{x:.2f} & \\scriptsize $\\pm$ {x_std:.2f}" if x != 0 else "& $\\times$"
                for x, (mean, x_std) in zip(table_col, stats)
            ]
            _highlight_top_two(table_col, table_col_str, wrap_calls, "textbf", "textit")

            for row, cell in zip(table[-n_methods:], table_col_str):
                row.append(cell)

            # Time columns. NOTE(review): the lowest mean time over results_dirs
            # is picked independently of the directory that gave the best utility.
            times = [
                get_result(results_dirs, dataset, method, metric, "time", seeds=seeds, lower_better=True)
                if method != "frank_wolfe" else (0, 0)
                for method, method_name in methods
            ]
            # 0 and 1e9 are the "nothing found" values returned by get_result.
            table_col_str = [
                f"{x:.2f} & \\scriptsize $\\pm$ {x_std:.2f}" if x != 0 and x != 1e9 else "& $\\times$"
                for x, x_std in times
            ]

            for row, cell in zip(table[-n_methods:], table_col_str):
                row.append(cell)

    print(header)
    # Each metric spans four tabular columns (utility mean/std, time mean/std).
    span = 4 * len(metrics)
    for i, table_row in enumerate(table):
        if i % n_methods == 0:
            caption = datasets[i // n_methods][1]
            print(f"\\midrule\n& \\multicolumn{{{span}}}{{c}}{{{caption}}}  \\\\ \n \\midrule")
        row = "    " + " & ".join(table_row) + " \\\\"
        for r in rules:
            if i % n_methods == r:
                row += "\n    \\midrule"
        print(row)

    print("\\bottomrule\n\\end{tabular}")


#seeds = [1, 13, 23, 2024, 7700]
# Only three seeds are averaged for this table.
seeds = [1, 13, 23]
results_dir = "../results_online9"
# (results sub-directory, LaTeX caption) pairs, one per dataset.
datasets = [
        ("youtube_deepwalk_online", "\\datasettable{YouTube} ($m = 46, n = 31703$)"),
        ("eurlex_lexglue_online", "\\datasettable{Eurlex-LexGlue} ($m = 100, n = 65000$)"),
        ("mediamill_online", "\\datasettable{Mediamill} ($m = 101, n = 43907$)"),
        ("flicker_deepwalk_online", "\\datasettable{Flickr} ($m = 195, n = 80513$)"),
        #("rcv1x_plt", "\\datasettable{RCV1X} ($m = 2456, n = 155962$)"),
        #("amazoncat_plt", "\\datasettable{AmazonCat} ($m = 13330, n = 306784$)"),

        #("eurlex_plt", "Eurlex"),
        #("eurlex_lightxml", "Eurlex"),
        #("eurlex_plt", "Eurlex ($m = 3993, n = 3809$)"),
        #("eurlex_lightxml", "Eurlex ($m = 5600, n = 5000$)"),
        #("amazoncat_lightxml", "AmazonCat"),
        #("amazon_lightxml", "Amazon")
        # "wiki500_lightxml", "Wikipedia",
]
# (metric tag used in result file names, short key, display name) triples.
metrics = [
    ("micro_f1_k=0", "miF@0", "Micro F1"),
    ("macro_f1_k=0", "mF@0", "Macro F1"),
    ("macro_f1_k=3", "mF@3", "Macro F1$@3$"),
    #("macro_f1_k=5", f"mF@5", "Macro F1$@5$"),
    ("macro_recall_k=3", "mR@3", "Macro Recall$@3$"),
    ("macro_precision_k=3", "mP@3", "Macro Prec.$@3$"),
    #("macro_recall_k=5", f"mR@5", "Macro Recall$@5$"),
    ("macro_gmean_k=0", "mG@0", "Macro G-Mean"),
    ("macro_hmean_k=0", "mH@0", "Macro H-Mean"),
    #("macro_min_tp_tn_k=0", "min_tp_tn@0", "Macro $\\min(\\text{tp},\\text{tn})$"),
    #("macro_precision", f"mP@{k}", "Macro Precision"),
    #("macro_min_tp_tn", f"min_tp_tn@{k}", "Macro-Min(TP,TN)"),
]

# (method tag used in result file names, LaTeX row label) pairs.
methods = [
    #("default_prediction", "Top-K / $\hat \eta(x) > 0.5$"),
    # Backslashes are doubled so the literal contains no invalid escape sequences.
    ("online_default", "Top-$k$ / $\\hat \\eta \\!>\\!0.5$"),

    ("ofo", "OFO"),
    ("online_greedy", "Greedy"),
    ("online_greedy_etu", "Greedy$(\\hat \\eta)$"),
    ("online_frank_wolfe_exp=1.1", "Online-FW"),
    ("online_frank_wolfe_etu_exp=1.1", "Online-FW$(\\hat \\eta)$"),
    #("omma", "OMMA"),
    #("omma_etu", "OMMA$(\\hat \\eta)$"),
    ("omma", "\\OMMA{}"),
    ("omma_etu", "\\OMMAeta{}"),
]

# The best seed-averaged result over all of these directories is reported.
results_dirs = [
    "../results_online_replicate_new_reg_0",
    "../results_online_replicate_new_reg_1e-6",
    "../results_online_replicate_new_reg_1e-3",
    "../results_online_replicate_new_reg_0.1",
    "../results_online_replicate_new_reg_0.1_v2",
    "../results_online_replicate_new_reg_1",
    "../results_online_replicate_old",
]


header1 = "\\begin{tabular}{l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l|r@{}lr@{}l}"
#header2 = "\\toprule\n    Method & \\multicolumn{4}{c|}{Micro} & \\multicolumn{24}{c}{Macro} \\\\"
header2 = "\\toprule\n    Method & " + " & ".join([f"\\multicolumn{{4}}{{c|}}{{{metric[2]}}}" for metric in metrics[:-1]]) + f"& \\multicolumn{{4}}{{c}}{{{metrics[-1][2]}}}" + " \\\\"
# "\\%" keeps the literal backslash-percent without an invalid escape sequence.
header3 = "    " + "& \\multicolumn{2}{c}{(\\%)} & \\multicolumn{2}{c|}{time (s)} " * len(metrics[:-1]) + "& \\multicolumn{2}{c}{(\\%)} & \\multicolumn{2}{c}{time (s)}" + " \\\\"
header = f"{header1}\n{header2}\n{header3}"

# A \midrule follows rows 0 and 5 of every dataset group.
rules = [0, 5]
create_table(results_dirs, datasets, methods, metrics, seeds, header, rules)