In [110]:
import pandas as pd 
pd.set_option('display.max_colwidth', None)
import numpy as np
from tqdm import tqdm

In [43]:
import wandb
from IPython.display import display

In [247]:
api = wandb.Api()
entity, project = "gmum", "conditional_contrastive"  # set to your entity and project 
runs = list(api.runs(
    entity + "/" + project, 
    filters={
        # "config.dataset": "imagenet100",
        # "job_type": "pretrain"
    }
))

# runs += [
#     api.run("/gmum/conditional_contrastive/runs/7m53pp39"), 
#     api.run("/gmum/conditional_contrastive/runs/3aboubsn"),
#     # official weights
#     api.run("/gmum/conditional_contrastive/runs/zuu4ufh5"),
#     api.run("/gmum/conditional_contrastive/runs/x531ebgf"),
# ]

runs = [r for r in runs if r.job_type in ["pretrain", None]]

len(runs)

215

In [253]:
linear_sets = {
    k
    for r in runs
    for k in r.summary.keys() if k.startswith("test_")
    if "stl" not in k
}

linear_sets

{'test_few-shot/fc100',
 'test_few-shot_5-way_1-shot/cub200',
 'test_few-shot_5-way_1-shot/fc100',
 'test_few-shot_5-way_1-shot/plant_disease',
 'test_few-shot_5-way_1-way/fc100',
 'test_few-shot_5-way_5-shot/cub200',
 'test_few-shot_5-way_5-shot/fc100',
 'test_few-shot_5-way_5-shot/plant_disease',
 'test_few_shot/cub200',
 'test_few_shot/fc100',
 'test_few_shot/plant_disease',
 'test_linear/aircraft',
 'test_linear/caltech101',
 'test_linear/cars',
 'test_linear/cifar10',
 'test_linear/cifar100',
 'test_linear/dtd',
 'test_linear/flowers',
 'test_linear/food101',
 'test_linear/mit67',
 'test_linear/pets',
 'test_linear/sun397',
 'test_linear_looc_like_acc/cub200',
 'test_linear_looc_like_acc/imagenet100',
 'test_linear_looc_like_best_acc/cub200',
 'test_linear_looc_like_best_acc/imagenet100',
 'test_linear_looc_like_v2_acc/cub200',
 'test_linear_looc_like_v2_best_acc/cub200'}

In [254]:
runs[-1].summary["test_linear/food101"]

0.30099010467529297

In [255]:
paper_rows = [
    {
        "name": "moco-resnet18-stl10_BASELINE_PAPER",
        'test_linear/aircraft': 0.2663,
        'test_linear/caltech101': 0.6415,
        'test_linear/cars': 0.1609,
        'test_linear/cifar10': 0.818,
        'test_linear/cifar100': 0.5375,
        'test_linear/dtd': 0.412,
        'test_linear/flowers': 0.6101,
        'test_linear/food101': 0.3369,
        'test_linear/mit67': 0.3901,
        'test_linear/pets': 0.4234,
        'test_linear/sun397': 0.285,
        "cfg/base_lr": 0.03
    },
    {
        "name": "moco-resnet18-stl10_AUGSELF_PAPER",
        'test_linear/aircraft': 0.2802,
        'test_linear/caltech101': 0.6602,
        'test_linear/cars': 0.1753,
        'test_linear/cifar10': 0.8245,
        'test_linear/cifar100': 0.5717,
        'test_linear/dtd': 0.4521,
        'test_linear/flowers': 0.6696,
        'test_linear/food101': 0.3691,
        'test_linear/mit67': 0.4167,
        'test_linear/pets': 0.4380,
        'test_linear/sun397': 0.3093,
        "cfg/base_lr": 0.03

    },
    {
        "name": "moco-resnet50-imagenet100_BASELINE_PAPER",
        'test_linear/aircraft': 0.4121,
        'test_linear/caltech101': 0.7725,
        'test_linear/cars': 0.3386,
        'test_linear/cifar10': 0.846,
        'test_linear/cifar100': 0.616,
        'test_linear/dtd': 0.6447,
        'test_linear/flowers': 0.8243,
        'test_linear/food101': 0.5967,
        'test_linear/mit67': 0.6164,
        'test_linear/pets': 0.7008,
        'test_linear/sun397': 0.4650,
        
        "test_few-shot_5-way_1-shot/fc100": 0.3167,
        "test_few-shot_5-way_5-shot/fc100": 0.4388,
        
        "test_few-shot_5-way_1-shot/cub200": 0.4167,
        "test_few-shot_5-way_5-shot/cub200": 0.5692,
        
        "test_few-shot_5-way_1-shot/plant_disease": 0.6573,
        "test_few-shot_5-way_5-shot/plant_disease": 0.8498,
        
        "test_linear_looc_like_v2_best_acc/cub200": 32.2,
        "cfg/base_lr": 0.03

    },
    {
        "name": "moco-resnet50-imagenet100_AUGSELF_PAPER",
        'test_linear/aircraft': 0.3947,
        'test_linear/caltech101': 0.7893,
        'test_linear/cars':0.3735,
        'test_linear/cifar10': 0.8526,
        'test_linear/cifar100': 0.639,
        'test_linear/dtd': 0.6622,
        'test_linear/flowers': 0.8570,
        'test_linear/food101': 0.6078,
        'test_linear/mit67': 0.6336,
        'test_linear/pets': 0.7346,
        'test_linear/sun397': 0.4852,
        
        "test_few-shot_5-way_1-shot/fc100": 0.3502,
        "test_few-shot_5-way_5-shot/fc100": 0.4877,
        
        "test_few-shot_5-way_1-shot/cub200": 0.4417,
        "test_few-shot_5-way_5-shot/cub200": 0.5735,
        
        "test_few-shot_5-way_1-shot/plant_disease": 0.718,
        "test_few-shot_5-way_5-shot/plant_disease": 0.8781,
        
        "test_linear_looc_like_v2_best_acc/cub200": 37.0,
                "cfg/base_lr": 0.03

    },
    
    {
        "name": "moco-resnet50-imagenet100_LOOC_PAPER",
        "test_linear_looc_like_v2_best_acc/cub200": 40.1,
                "cfg/base_lr": 0.03


    },
    {
        "name": "simsiam-resnet50-imagenet100_BASELINE_PAPER",
        'test_linear/aircraft': 0.4863,
        'test_linear/caltech101': 0.8413,
        'test_linear/cars': 0.4820,
        'test_linear/cifar10': 0.8689,
        'test_linear/cifar100': 0.6633,
        'test_linear/dtd': 0.6511,
        'test_linear/flowers': 0.8806,
        'test_linear/food101': 0.6148,
        'test_linear/mit67': 0.6575,
        'test_linear/pets': 0.7469,
        'test_linear/sun397': 0.5060,
        
        "test_few-shot_5-way_1-shot/fc100": 0.3619,
        "test_few-shot_5-way_5-shot/fc100": 0.5036,
        
        "test_few-shot_5-way_1-shot/cub200": 0.4556,
        "test_few-shot_5-way_5-shot/cub200": 0.6248,
        
        "test_few-shot_5-way_1-shot/plant_disease": 0.7572,
        "test_few-shot_5-way_5-shot/plant_disease": 0.8994,
        
        "test_linear_looc_like_v2_best_acc/cub200": 38.4,
                "cfg/base_lr": 0.05

    },
    {
        "name": "simsiam-resnet50-imagenet100_AUGSELF_PAPER",
        'test_linear/aircraft': 0.4976,
        'test_linear/caltech101': 0.8530,
        'test_linear/cars': 0.4752,
        'test_linear/cifar10': 0.8880,
        'test_linear/cifar100': 0.7027,
        'test_linear/dtd': 0.6729,
        'test_linear/flowers': 0.9070,
        'test_linear/food101': 0.6563,
        'test_linear/mit67': 0.6776,
        'test_linear/pets': 0.7634,
        'test_linear/sun397': 0.5228,
        
        "test_few-shot_5-way_1-shot/fc100": 0.3937,
        "test_few-shot_5-way_5-shot/fc100": 0.5527,
        
        "test_few-shot_5-way_1-shot/cub200": 0.4808,
        "test_few-shot_5-way_5-shot/cub200": 0.6627,
        
        "test_few-shot_5-way_1-shot/plant_disease": 0.7793,
        "test_few-shot_5-way_5-shot/plant_disease": 0.9152,
        
        "test_linear_looc_like_v2_best_acc/cub200": 45.3,
                "cfg/base_lr": 0.05
    }
    
]

In [257]:
summary_list, config_list, name_list = [], [], []

rows = [] + paper_rows

td_runs = tqdm(list(runs)) 
for run in td_runs:
    
    # if run.config is not None:
    #     print(f"\tOmitting {run.name} because {run.config.get('base_lr', 0.05)=}")
    #     if True or run.name not in ["moco-resnet50-imagenet100_augself", "moco-resnet50-imagenet100_mlp_2_16_proj-cat_crop_color"]:
    #         continue
            # pass
    
    eval_runs = api.runs(entity + "/" + project, filters={"group": run.name})
    
    td_runs.set_postfix({run.name: len(eval_runs)})
    # print()
    
    # if "augself" in run.name.lower():
        # for r in eval_runs:
            # print(f"\t{r.name}")
            # print("\t", r.summary.keys())
        
    metrics = {k:v for (k,v) in run.summary.items() if "stl" not in k}
    for k in linear_sets:
        for e in eval_runs:
            if k in e.summary:
                metrics[k] = e.summary[k]
        if k not in metrics:
            metrics[k] = -1
    
    rows.append({
        "id": run.id,
        "name": f"{run.name}_lr_0.05" if run.name in ["moco-resnet50-imagenet100_augself", "moco-resnet50-imagenet100_mlp_2_16_proj-cat_crop_color"] else run.name,
        "run": run,
        **metrics,
        # **{
        #     k: run.summary.get(k, 0)
        #     for k in linear_sets
        # }
        **{f"cfg/{k}": v for (k,v) in run.config.items()}
    
    })




100%|████████| 215/215 [00:57<00:00,  3.77it/s, moco-resnet18-stl10_hn_2_2048=0]


In [267]:
runs_df = pd.DataFrame(rows)

for c in runs_df.columns:
    if c.startswith("test_linear/"):
        runs_df = runs_df[runs_df[c].notnull()] 
        runs_df = runs_df[ runs_df[c]>0]

runs_df = runs_df[
    runs_df.apply(
        lambda r: (
            ("moco" in r["name"] and r.get("cfg/base_lr")==0.03) or
            ("simsiam" in r["name"] and r.get("cfg/base_lr")==0.05)
            or "official" in r["name"].lower()

        ), 
        axis=1)
]
        

runs_df = runs_df[runs_df.apply(lambda r: "moco" in r["name"], axis=1)]
runs_df = runs_df[runs_df.apply(lambda r: "imagenet100" in r["name"], axis=1)]
# runs_df = runs_df[runs_df.apply(lambda r: "stl10" in r["name"], axis=1)]

runs_df = runs_df[runs_df.apply(lambda r: "+augself" not in r["name"].lower(), axis=1)]

# runs_df = run



In [271]:


for m_subset in [
    "test_linear/",
    "test_few-shot_5-way",
    "test_linear_looc_like_v2_best_acc/"
]:
    col_subset = [c for c in runs_df.columns if c.startswith(m_subset) and "way/" not in c and not c.endswith("rank")]
    df = runs_df[["id", "name"] + col_subset].copy().dropna(subset=col_subset)
    
    df["name"] = df.name.map(
            lambda n: n.replace("moco-resnet50-imagenet100_", "").replace("OFFICIAL", "OFFICIAL_WEIGHTS").upper()
        )

    for n in [
        # "MLP_3_256_PROJ-CAT_CROP_COLOR",
        # "MLP_4_64_PROJ-CAT_CROP_COLOR_LR_0.03",
        "MLP_2_32_PROJ-CAT_CROP_COLOR"
    ]:
        df = df[df.name!=n]

    df["name"] = df.name.map(
        lambda n: n if any([(p in n) for p in ["AUGSELF", "LOOC", "BASELINE"] ]) else f"OUR_{n}"
    )
    
        
    df.style.highlight_max(color = 'blue', axis = 0)

    dn = len(df)
    print(f"{dn=}")
    c_subset = [c for c in df.columns if c not in ["name", "id"]]
    print("""
GREEN - 1
YELLOW - 2
PINK - 3
""")
    style = df.sort_values("name").style.highlight_max(
                props="font-weight: bold; background: lightgreen",
                subset=c_subset,
            ).highlight_quantile(
                    props="font-style: italic; background: yellow",
                    q_left=(dn-2)/dn, q_right=(dn-1)/dn,
                    subset=c_subset,
                )

    if len(df) > 4:
        style = style.highlight_quantile(
                    props="font-style: normal; background: pink",
                    q_left=(dn-3)/dn, q_right=(dn-2)/dn,
                    subset=c_subset,
                )
    display(style)
        # print(df[df.columns[1]])

dn=12

GREEN - 1
YELLOW - 2
PINK - 3



Unnamed: 0,id,name,test_linear/aircraft,test_linear/caltech101,test_linear/cars,test_linear/cifar10,test_linear/cifar100,test_linear/dtd,test_linear/flowers,test_linear/food101,test_linear/mit67,test_linear/pets,test_linear/sun397
46,2urk87ky,AUGSELF_LR_0.03,0.405954,0.794825,0.367492,0.85,0.6402,0.659043,0.860893,0.612713,0.627612,0.735757,0.491285
127,3aboubsn,AUGSELF_OFFICIAL_WEIGHTS,0.394688,0.789258,0.376446,0.8508,0.6363,0.664894,0.858014,0.610218,0.632836,0.734562,0.485239
3,,AUGSELF_PAPER,0.3947,0.7893,0.3735,0.8526,0.639,0.6622,0.857,0.6078,0.6336,0.7346,0.4852
126,7m53pp39,BASELINE_OFFICIAL_WEIGHTS,0.410036,0.772417,0.338142,0.8458,0.6159,0.644149,0.823412,0.595406,0.615672,0.699543,0.468212
2,,BASELINE_PAPER,0.4121,0.7725,0.3386,0.846,0.616,0.6447,0.8243,0.5967,0.6164,0.7008,0.465
39,g4p483qi,OUR_HN_4_64_PROJ-CAT_ALL_AUG_LR_0.03,0.363806,0.759078,0.29959,0.8347,0.6026,0.620213,0.78082,0.581822,0.608209,0.665031,0.4534
31,1azzl8ke,OUR_MLP_2_16_PROJ-CAT_ALL_AUG_LR_0.03,0.408414,0.801109,0.373834,0.861,0.6438,0.67234,0.84345,0.612594,0.629851,0.723683,0.484987
47,kov8l1xd,OUR_MLP_2_16_PROJ-CAT_CROP_COLOR_LR_0.03,0.416141,0.798187,0.37856,0.8544,0.6403,0.655851,0.847396,0.606218,0.624627,0.726159,0.490378
26,1pk05rfj,OUR_MLP_4_256_PROJ-CAT_ALL_AUG_LR_0.03,0.420258,0.794717,0.372964,0.8618,0.6416,0.643617,0.840418,0.607485,0.639552,0.719933,0.490428
40,tiwa3x21,OUR_MLP_4_64_PROJ-CAT_ALL_AUG_LR_0.03,0.42123,0.802872,0.389504,0.8627,0.651,0.667021,0.843547,0.610614,0.636567,0.721151,0.488967


dn=12

GREEN - 1
YELLOW - 2
PINK - 3



Unnamed: 0,id,name,test_few-shot_5-way_1-shot/fc100,test_few-shot_5-way_5-shot/fc100,test_few-shot_5-way_1-shot/cub200,test_few-shot_5-way_5-shot/cub200,test_few-shot_5-way_1-shot/plant_disease,test_few-shot_5-way_5-shot/plant_disease
46,2urk87ky,AUGSELF_LR_0.03,0.362006,0.484737,0.429556,0.563619,0.715644,0.894725
127,3aboubsn,AUGSELF_OFFICIAL_WEIGHTS,0.349381,0.489706,0.445425,0.571694,0.72435,0.883937
3,,AUGSELF_PAPER,0.3502,0.4877,0.4417,0.5735,0.718,0.8781
126,7m53pp39,BASELINE_OFFICIAL_WEIGHTS,0.321813,0.457837,0.421777,-1.0,0.659612,0.852275
2,,BASELINE_PAPER,0.3167,0.4388,0.4167,0.5692,0.6573,0.8498
39,g4p483qi,OUR_HN_4_64_PROJ-CAT_ALL_AUG_LR_0.03,0.261344,0.386606,0.3996,0.512056,0.618444,0.801019
31,1azzl8ke,OUR_MLP_2_16_PROJ-CAT_ALL_AUG_LR_0.03,0.333925,0.468894,0.429431,0.562719,0.687431,0.863719
47,kov8l1xd,OUR_MLP_2_16_PROJ-CAT_CROP_COLOR_LR_0.03,0.344031,0.478056,0.430063,0.553612,0.692187,0.860644
26,1pk05rfj,OUR_MLP_4_256_PROJ-CAT_ALL_AUG_LR_0.03,0.340931,0.480863,0.423225,0.552525,0.687631,0.862775
40,tiwa3x21,OUR_MLP_4_64_PROJ-CAT_ALL_AUG_LR_0.03,0.335919,0.477088,0.424387,0.550456,0.700187,0.870419


dn=12

GREEN - 1
YELLOW - 2
PINK - 3



Unnamed: 0,id,name,test_linear_looc_like_v2_best_acc/cub200
46,2urk87ky,AUGSELF_LR_0.03,31.239212
127,3aboubsn,AUGSELF_OFFICIAL_WEIGHTS,33.793579
3,,AUGSELF_PAPER,37.0
126,7m53pp39,BASELINE_OFFICIAL_WEIGHTS,32.171211
2,,BASELINE_PAPER,32.2
39,g4p483qi,OUR_HN_4_64_PROJ-CAT_ALL_AUG_LR_0.03,27.821884
31,1azzl8ke,OUR_MLP_2_16_PROJ-CAT_ALL_AUG_LR_0.03,31.877804
47,kov8l1xd,OUR_MLP_2_16_PROJ-CAT_CROP_COLOR_LR_0.03,31.9641
26,1pk05rfj,OUR_MLP_4_256_PROJ-CAT_ALL_AUG_LR_0.03,31.532619
40,tiwa3x21,OUR_MLP_4_64_PROJ-CAT_ALL_AUG_LR_0.03,32.430099
