In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd 
import seaborn as sns
import wandb

%config Completer.use_jedi = False
%matplotlib inline

In [None]:
# Matplotlib rc overrides applied to every figure drawn after this cell.
pgf_with_rc_fonts = {
    # An empty list leaves the serif choice to LaTeX's default serif font.
    "font.serif": [],
    # Use one specific sans-serif font.
    "font.sans-serif": ["DejaVu Sans"],
    "font.size": 12,
    "ps.useafm": True,
    "pdf.use14corefonts": True,
    # Text is typeset through LaTeX, which is why labels below are written as "$...$".
    "text.usetex": True,
}

matplotlib.rcParams.update(pgf_with_rc_fonts)


In [None]:
# "INPUT_YOUR_ENTITY" is a placeholder: replace it with your own W&B entity
# before running this cell.
entity, project = "INPUT_YOUR_ENTITY", "curl"

api = wandb.Api()
# Runs are looked up by the path "<entity>/<project>".
runs = api.runs("/".join([entity, project]))

In [None]:
def extract_best_checkpoint_and_epoch(output_model_name):
    """Split a weight-file name into its base model name and a checkpoint tag.

    The name is split on ``"_"``.

    * If the last piece is exactly ``"model.pt"``, the name is returned
      unchanged together with the tag ``"best"``.
    * Otherwise the tag is the last piece up to its first ``"."``, and the base
      name is the first four pieces joined by ``"_"`` with ``".pt"`` appended.

    Args:
        output_model_name: File name (or path) of the saved weights.

    Returns:
        A ``(base_name, checkpoint)`` tuple of strings.

    Raises:
        IndexError: If the name does not end in ``"model.pt"`` and has fewer
            than four ``"_"``-separated pieces.
    """
    parts = output_model_name.split("_")
    last_part = parts[-1]

    if last_part == "model.pt":
        return output_model_name, "best"

    checkpoint = last_part.split(".")[0]
    # The base name is always rebuilt from exactly the first four pieces.
    base_parts = parts[:3] + [parts[3] + ".pt"]
    return "_".join(base_parts), checkpoint


In [None]:
# Collect one flat record per W&B run, separated into contrastive (training)
# runs and evaluation runs of the classifiers listed in `classifiers`.
datasets = ["wiki3029", "cifar10", "cifar100"]
eval_records = []
contrastive_records = []
classifiers = ("linear", "mean")
# Summary metrics copied from each evaluation run into its record.
extracted_keys = ("supervised_test_acc", "supervised_val_acc")


for run in runs:
    # Runs whose config has no "hydra_path" entry are ignored.
    if "hydra_path" not in run.config:
        continue

    dataset = run.config["dataset.name"]
    if dataset not in datasets:
        continue

    # Build the record from a copy instead of writing into `run.config`.
    # Writing into it made this cell non-idempotent: if the cell is re-run and
    # `runs` yields the same run objects again, an already rewritten
    # "target_weight_file" ends in "model.pt", so every epoch checkpoint
    # would be relabelled as "best".
    record = dict(run.config)

    if record["name"] in classifiers:
        # Evaluation runs with the "normalize" option enabled are excluded.
        if record["normalize"]:
            continue

        for k in extracted_keys:
            record[k] = run.summary[k]

        # Map the evaluated weight file back to the base model file (the merge
        # key used later) and remember which checkpoint was evaluated.
        record["target_weight_file"], record["checkpoint"] = extract_best_checkpoint_and_epoch(
            record["target_weight_file"]
        )
        eval_records.append(record)

    elif record["name"] == "contrastive":
        # Contrastive runs get the same key, built from their output location.
        record["target_weight_file"] = record["hydra_path"] + "/" + record["output_model_name"]
        contrastive_records.append(record)


In [None]:
contrastive_df = pd.DataFrame.from_records(contrastive_records)
eval_df = pd.DataFrame.from_records(eval_records)

# Columns kept for the analysis below. Columns that exist on both sides of the
# merge keep their name for the contrastive run and get a "_y" suffix for the
# evaluation run (hence "name_y").
used_columns = ["seed", "dataset.name", "supervised_test_acc", "supervised_val_acc", "dataset.num_used_classes", "optimizer.lr", "loss.neg_size", "epochs", "name_y", "checkpoint"]

# Pair every evaluation run with the contrastive run that produced its weights.
merged_df = pd.merge(contrastive_df, eval_df, how="inner", on="target_weight_file", suffixes=("", "_y"))
unused_columns = [column for column in merged_df.columns if column not in used_columns]
results_df = merged_df.drop(columns=unused_columns)


In [None]:
# Shorten the column names: keep only the text after the last "." and replace
# underscores with hyphens, e.g. "loss.neg_size" -> "neg-size".
removed_prefix_columns = [
    column.split(".")[-1].replace("_", "-") for column in results_df.columns
]
results_df.columns = removed_prefix_columns

# Display names of the datasets, used in figure titles.
rename = {
    "wiki3029": "Wiki-3029",
    "cifar10": "CIFAR-10",
    "cifar100": "CIFAR-100",
}

## Wiki-3029

In [None]:
dataset = "wiki3029"

# Rows of this dataset that were evaluated on the "best" checkpoint. The
# filter does not depend on the classifier, so it is applied once up front.
_results_df = results_df[(results_df["name"] == dataset) & (results_df["checkpoint"] == "best")]

for classifier in ["mean", "linear"]:
    df_per_classifer = _results_df[_results_df["name-y"] == classifier]

    # Model selection: for every (seed, C, K) keep the single row with the
    # highest validation accuracy.
    idx = df_per_classifer.groupby(["seed", "num-used-classes", "neg-size"])["supervised-val-acc"].idxmax()
    df_per_classifer = df_per_classifer.loc[idx]

    # Aggregate only the test accuracy over the remaining rows of each (C, K).
    # Aggregating the whole frame would include string columns ("name",
    # "name-y", "checkpoint"), which raises a TypeError on pandas >= 2.0.
    test_acc = df_per_classifer.groupby(["num-used-classes", "neg-size"])["supervised-test-acc"]
    mean = test_acc.mean().reset_index()
    std = test_acc.std().reset_index()

    # Rows: number of used classes (C); columns: number of negatives (K).
    mean = mean.pivot(index="num-used-classes", columns="neg-size", values="supervised-test-acc")
    std = std.pivot(index="num-used-classes", columns="neg-size", values="supervised-test-acc")

    # Largest C on top, K ascending from left to right.
    mean = mean.sort_values(by="num-used-classes", ascending=False).sort_index(axis=1)
    std = std.sort_values(by="num-used-classes", ascending=False).sort_index(axis=1)

    Ks = mean.columns
    Cs = mean.index
    data = mean.to_numpy()
    std_data = std.to_numpy()

    plt.imshow(data)

    # Tick labels come straight from the plotted frame, so they always match
    # the order of the rows and columns of `data`.
    plt.yticks(np.arange(len(Cs)), [r"${}$".format(c) for c in Cs])
    plt.xticks(np.arange(len(Ks)), [r"${}$".format(k) for k in Ks])

    for c in range(len(Cs)):
        # Hard-coded: the first two rows are annotated in white, the rest in
        # black (presumably chosen for contrast with the colormap - confirm).
        color = "white" if c <= 1 else "black"
        for k in range(len(Ks)):
            # Mean above the cell centre, standard deviation in parentheses below.
            plt.text(k, c, "${:.2f}$".format(data[c, k]),
                     ha="center", va="bottom", color=color)
            plt.text(k, c, "$({:.2f})$".format(std_data[c, k]),
                     ha="center", va="top", color=color)
    plt.title("{} {} Classifier".format(rename[dataset], classifier.capitalize()))
    plt.xlabel(r"$K$")
    plt.ylabel(r"$C$")
    plt.savefig("../../papers/figures/heatmap-{}-{}.pdf".format(dataset, classifier))
    plt.show()

## CIFAR-10/100


In [None]:
# One heatmap per classifier, stacking the rows of every non-Wiki dataset.
for classifier in ["mean", "linear"]:
    mean_dfs = []
    std_dfs = []
    # sorted() makes the iteration order (and so the concat order) deterministic.
    for dataset in sorted(set(results_df["name"].values) - {"wiki3029"}):
        _results_df = results_df[(results_df["name"] == dataset) & (results_df["checkpoint"] == "best")]
        df_per_classifer = _results_df[_results_df["name-y"] == classifier]

        # Model selection: for every (seed, C, K) keep the single row with the
        # highest validation accuracy.
        idx = df_per_classifer.groupby(["seed", "num-used-classes", "neg-size"])["supervised-val-acc"].idxmax()
        df_per_classifer = df_per_classifer.loc[idx]

        # Aggregate only the test accuracy over the remaining rows of each
        # (C, K). Aggregating the whole frame would include string columns
        # ("name", "name-y", "checkpoint"), which raises a TypeError on
        # pandas >= 2.0.
        test_acc = df_per_classifer.groupby(["num-used-classes", "neg-size"])["supervised-test-acc"]
        mean = test_acc.mean().reset_index()
        std = test_acc.std().reset_index()

        # Rows: number of used classes (C); columns: number of negatives (K).
        mean = mean.pivot(index="num-used-classes", columns="neg-size", values="supervised-test-acc")
        std = std.pivot(index="num-used-classes", columns="neg-size", values="supervised-test-acc")

        mean_dfs.append(mean)
        std_dfs.append(std)
    mean = pd.concat(mean_dfs)
    std = pd.concat(std_dfs)

    # Largest C on top. The columns are sorted explicitly because concat does
    # not guarantee ascending K when the datasets use different K values;
    # without this the x tick labels could be attached to the wrong columns.
    mean = mean.sort_values(by="num-used-classes", ascending=False).sort_index(axis=1)
    std = std.sort_values(by="num-used-classes", ascending=False).sort_index(axis=1)

    Ks = mean.columns
    Cs = mean.index
    data = mean.to_numpy()
    std_data = std.to_numpy()

    plt.imshow(data)

    # Tick labels come straight from the plotted frame, so they always match
    # the order of the rows and columns of `data`.
    plt.yticks(np.arange(len(Cs)), [r"${}$".format(c) for c in Cs])
    plt.xticks(np.arange(len(Ks)), [r"${}$".format(k) for k in Ks])

    for c in range(len(Cs)):
        # Hard-coded: only the first row is annotated in white (presumably
        # chosen for contrast with the colormap - confirm).
        color = "white" if c == 0 else "black"
        for k in range(len(Ks)):
            # Mean above the cell centre, standard deviation in parentheses below.
            plt.text(k, c, "${:.2f}$".format(data[c, k]),
                     ha="center", va="bottom", color=color)
            plt.text(k, c, "$({:.2f})$".format(std_data[c, k]),
                     ha="center", va="top", color=color)
    plt.title("CIFAR {} Classifier".format(classifier.capitalize()))
    plt.xlabel(r"$K$")
    plt.ylabel(r"$C$")
    plt.savefig("../../papers/figures/heatmap-cifar-{}.pdf".format(classifier))
    plt.show()

### Mean classifier's test accuracy by checkpoints

In [None]:
import sys

# Make the project's scripts directory importable. The guard keeps sys.path
# free of duplicate entries when this cell is re-run.
if "../../scripts/" not in sys.path:
    sys.path.append("../../scripts/")
# NOTE(review): `style` is not referenced below; presumably it is imported for
# its side effects (plot styling) - confirm.
import style


# sorted() makes the order of the generated figures deterministic.
for dataset in sorted(set(results_df["name"].values)):
    if dataset == "wiki3029":
        continue

    # Keep only the evaluations of intermediate (non-"best") checkpoints.
    _results_df = results_df[(results_df["name"] == dataset) & (results_df["checkpoint"] != "best")]
    # .assign returns a new frame, so no column is written into a filtered
    # slice of `results_df` (which triggers SettingWithCopyWarning).
    _results_df = _results_df.assign(checkpoint=pd.to_numeric(_results_df["checkpoint"]))

    df_per_classifer = _results_df[_results_df["name-y"] == "mean"]

    # Model selection: for every (seed, checkpoint, K) keep the single row
    # with the highest validation accuracy.
    idx = df_per_classifer.groupby(["seed", "checkpoint", "neg-size"])["supervised-val-acc"].idxmax()
    df_per_classifer = df_per_classifer.loc[idx]

    # Aggregate only the test accuracy. Aggregating the whole frame would
    # include string columns ("name", "name-y"), which raises a TypeError on
    # pandas >= 2.0.
    test_acc = df_per_classifer.groupby(["checkpoint", "neg-size"])["supervised-test-acc"]
    mean = test_acc.mean().reset_index()
    # The standard deviation is computed for reference; the bar plot below
    # draws the means only.
    std = test_acc.std().reset_index()

    mean = mean.sort_values(by="checkpoint", ascending=False)
    std = std.sort_values(by="checkpoint", ascending=False)

    plot = sns.barplot(x="checkpoint", y="supervised-test-acc", hue="neg-size", data=mean)
    plt.ylabel("Mean classifier's test accuracy")
    plt.xlabel("Epochs")
    plt.legend(loc="lower left", title=r"$K$")
    plt.savefig("../../papers/figures/{}-mean-performance-by-epochs.pdf".format(dataset))
    # The title is set after savefig, so it appears in the notebook output
    # but not in the saved PDF.
    plt.title(rename[dataset])
    plt.show()
