In [1]:
import glob
import json
import itertools

from os import path

import pandas as pd
import numpy as np
from scipy.stats import sem

import torch

# Hyperparameter search

In [2]:
exp_fp = "/data/b2p-siteident/experiments/ssl_hyperparam_search/ssl-v*/*/v0"
hyperparams = [
    "batch_size",
    "lr", 
    "alpha",
    "lambda_u",
    "T",
    "ema_decay",
    "data_version",
    "no_use_pretrained"
]
metrics = [
    "best_acc",
    "max_train_acc",
    "best_epoch"
]

In [3]:
entries = []
for fp in glob.glob(exp_fp):
    if not (path.isfile(path.join(fp, "opts.json")) and path.isfile(path.join(fp, "logs.csv"))):
        print("{} not finished.".format(fp))
        continue
    
    with open(path.join(fp, "opts.json")) as f:
        opts = json.load(f)
    entry = [opts[h] for h in hyperparams]

    df = pd.read_csv(path.join(fp, "logs.csv"))
    best_acc = df.iloc[len(df) - 1]["Best Acc."]
    best_epoch = df.iloc[len(df) - 1]["Best Epoch"]
    max_train_acc = df["Train Acc."].max()
    entry += [best_acc, max_train_acc, best_epoch]
    entries.append(entry)

resdf = pd.DataFrame(entries, columns=hyperparams + metrics)
print("Number of experiments: {}".format(len(resdf)))

Number of experiments: 92


In [96]:
resdf[resdf.data_version == "v1"].sort_values(by="best_acc", ascending=False)[:3]

Unnamed: 0,batch_size,lr,alpha,lambda_u,T,ema_decay,data_version,no_use_pretrained,best_acc,max_train_acc,best_epoch
63,256,0.05,0.5,0.01,0.5,0.75,v1,False,85.7438,86.7839,82.0
69,256,0.05,0.5,0.01,0.5,0.9,v1,False,85.7438,86.6536,82.0
54,256,0.05,0.5,0.01,0.5,0.9,v1,False,85.7438,86.6536,82.0


In [97]:
# 49   256	0.05	0.5	0.01	0.5	0.90	v1	False	85.7438	86.6536
# 56   256	0.05	0.5	0.01	0.5	0.75	v1	False	85.7438	86.7839
# 74   256	0.05	0.5	0.01	0.5	0.75	v1	False	85.7438	86.7839

In [98]:
resdf[resdf.data_version == "v2"].sort_values(by="best_acc", ascending=False)[:3]

Unnamed: 0,batch_size,lr,alpha,lambda_u,T,ema_decay,data_version,no_use_pretrained,best_acc,max_train_acc,best_epoch
4,256,0.05,0.5,0.1,0.5,0.75,v2,False,88.3142,85.0446,100.0
10,256,0.05,0.5,0.1,0.5,0.5,v2,False,87.931,85.0446,100.0
11,256,0.05,0.5,0.1,0.5,0.9,v2,False,87.7395,84.933,100.0


In [73]:
# 2   256	0.05	0.5	0.1	0.5	0.75	v2	False	88.3142	85.0446
# 8   256	0.05	0.5	0.1	0.5	0.50	v2	False	87.9310	85.0446
# 9   256	0.05	0.5	0.1	0.5	0.90	v2	False	87.7395	84.9330

In [74]:
resdf[resdf.data_version == "v1"].sort_values(by="max_train_acc", ascending=False)[:3]

Unnamed: 0,batch_size,lr,alpha,lambda_u,T,ema_decay,data_version,no_use_pretrained,best_acc,max_train_acc,best_epoch
63,256,0.05,0.25,0.1,0.5,0.5,v1,False,84.5041,87.5,65.0
42,256,0.05,0.25,0.1,0.5,0.75,v1,False,84.7107,87.3047,125.0
75,256,0.05,0.25,0.1,0.5,0.75,v1,False,84.7107,87.3047,125.0


In [75]:
resdf[resdf.data_version == "v2"].sort_values(by="best_acc", ascending=False)[:3]

Unnamed: 0,batch_size,lr,alpha,lambda_u,T,ema_decay,data_version,no_use_pretrained,best_acc,max_train_acc,best_epoch
2,256,0.05,0.5,0.1,0.5,0.75,v2,False,88.3142,85.0446,100.0
8,256,0.05,0.5,0.1,0.5,0.5,v2,False,87.931,85.0446,100.0
9,256,0.05,0.5,0.1,0.5,0.9,v2,False,87.7395,84.933,100.0


# Commands SSL training

In [4]:
CMD = (
        "python train_ssl.py --out {out} --model {model} --tile_size {tile} "
        "--manualSeed {seed} "
        "--data_version {version} --lambda-u {lambdau} --ema-decay {ema} "
        "--alpha {alpha} --gpu {gpu}"
    )

In [5]:
hyperparams = {
    "lambda-u": [0.01],
    "ema-decay": [0.75],
    "T": [0.5],
    "alpha": [0.25],
    "tile_size": [1200],
    "model": ["resnet50"],
    "learning-rate": [0.002],
    "batch-size": [64],
    "use_last_n_layers": [
        -1, 
        # 9
    ],
    "no_use_pretrained": [
        # True, 
        False
    ],
    "data_modalities": [
        "", "osm_img slope waterways admin_bounds_qgis"]
}

CMD = (
    "python train_ssl.py --out {out} --model {model} --tile_size {tile} "
    "--manualSeed {seed} "
    "--data_version {version} --lambda-u {lambdau} --ema-decay {ema} "
    "--alpha {alpha} --learning-rate {learning_rate} "
    "--batch-size {batch_size} --use_last_n_layers {use_last_n_layers}"
)
MODEL_NAME = "{model}_ema-{ema}_lmdu-{lmdu}_T-{T}_a-{a}_tile-{tile}_freeze-{freeze}"
sorted_names = sorted(list(hyperparams.keys()))

cmds = []

for data_version in ["v1", "v2"]:
    vfpath = "experiments/ssl_results/data-{}".format(data_version)
    combinations = sorted(list(itertools.product(
        *[hyperparams[k] for k in sorted_names])))
    # print("Number of combinations: {}".format(len(combinations)))
    for combination in combinations:
        # print(combination)
        model_name = MODEL_NAME.format(
            model=combination[sorted_names.index("model")],
            ema=combination[sorted_names.index("ema-decay")],
            lmdu=combination[sorted_names.index("lambda-u")],
            T=combination[sorted_names.index("T")],
            a=combination[sorted_names.index("alpha")],
            tile=combination[sorted_names.index("tile_size")],
            freeze=combination[sorted_names.index("use_last_n_layers")]
        )
        if combination[sorted_names.index("no_use_pretrained")]:
            model_name += "-nopretrained"
        if combination[sorted_names.index("data_modalities")] != "":
            model_name += "-small"
            
        for seed, version in [
                (42, 0), 
                (10, 1), (100, 2)
        ]:
            version_name = "v{version}".format(version=version)
            out = path.join(vfpath, model_name, version_name)
            cmd = CMD.format(
                out=out,
                model=combination[sorted_names.index("model")],
                tile=combination[sorted_names.index("tile_size")],
                seed=seed,
                version=data_version,
                lambdau=combination[sorted_names.index("lambda-u")],
                ema=combination[sorted_names.index("ema-decay")],
                alpha=combination[sorted_names.index("alpha")],
                log="{}_{}".format(model_name, version_name),
                learning_rate=combination[sorted_names.index("learning-rate")],
                batch_size=combination[sorted_names.index("batch-size")],
                use_last_n_layers=combination[sorted_names.index("use_last_n_layers")]
            )
            if combination[sorted_names.index("no_use_pretrained")]:
                cmd += " --no_use_pretrained"
            if combination[sorted_names.index("data_modalities")] != "":
                cmd += " --data_modalities {}".format(combination[
                    sorted_names.index("data_modalities")])
            cmds.append(cmd + " --gpu 2")

In [6]:
"; ".join(cmds)

'python train_ssl.py --out experiments/ssl_results/data-v1/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1/v0 --model resnet50 --tile_size 1200 --manualSeed 42 --data_version v1 --lambda-u 0.01 --ema-decay 0.75 --alpha 0.25 --learning-rate 0.002 --batch-size 64 --use_last_n_layers -1 --gpu 2; python train_ssl.py --out experiments/ssl_results/data-v1/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1/v1 --model resnet50 --tile_size 1200 --manualSeed 10 --data_version v1 --lambda-u 0.01 --ema-decay 0.75 --alpha 0.25 --learning-rate 0.002 --batch-size 64 --use_last_n_layers -1 --gpu 2; python train_ssl.py --out experiments/ssl_results/data-v1/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1/v2 --model resnet50 --tile_size 1200 --manualSeed 100 --data_version v1 --lambda-u 0.01 --ema-decay 0.75 --alpha 0.25 --learning-rate 0.002 --batch-size 64 --use_last_n_layers -1 --gpu 2; python train_ssl.py --out experiments/ssl_results/data-v1/resnet50_ema-0.75_lmdu

# Commands SSL Eval

In [7]:
CMD = "python train_ssl.py --evaluate --out {}"
CMD_GPU = "CUDA_VISIBLE_DEVICES={} "
all_fps = []
for fp in sorted(glob.glob("../experiments/ssl_results/*/*/*")):
    model_fp = path.join(fp, "checkpoint.pth.tar")
    if not path.isfile(model_fp):
        continue
    checkpoint = torch.load(model_fp, map_location="cpu")
    best_epoch = checkpoint["epoch"]
    if best_epoch < 25:
        continue
    all_fps.append((fp, best_epoch))

all_fps = sorted(all_fps, key=lambda x: -x[1])

all_cmds = []

for fp, _ in all_fps:
    for use_last in [True, False]:
        for use_several in [False, True]:
            cmd = CMD.format(fp.replace("../", ""))
            if use_last:
                cmd += " --use_last_checkpoint"
            if use_several:
                cmd += " --use_several_test_samples"
            all_cmds.append(cmd + ";")

no_gpus = 3
cmds_per_gpu = len(all_cmds) // no_gpus
for i in range(no_gpus):
    print(" ".join([CMD_GPU.format(i) + cmd for cmd in all_cmds[i * cmds_per_gpu: (i + 1) * cmds_per_gpu]]))
    print()

CUDA_VISIBLE_DEVICES=0 python train_ssl.py --evaluate --out experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1/v0 --use_last_checkpoint; CUDA_VISIBLE_DEVICES=0 python train_ssl.py --evaluate --out experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1/v0 --use_last_checkpoint --use_several_test_samples; CUDA_VISIBLE_DEVICES=0 python train_ssl.py --evaluate --out experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1/v0; CUDA_VISIBLE_DEVICES=0 python train_ssl.py --evaluate --out experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1/v0 --use_several_test_samples; CUDA_VISIBLE_DEVICES=0 python train_ssl.py --evaluate --out experiments/ssl_results/data-v1/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1/v0 --use_last_checkpoint; CUDA_VISIBLE_DEVICES=0 python train_ssl.py --evaluate --out experiments/ssl_results/data-v1/res

# Eval

In [2]:
for fp in sorted(glob.glob("../experiments/ssl_results/*/*/*")):
    best_fp = path.join(fp, "model_best.pth.tar")
    last_fp = path.join(fp, "checkpoint.pth.tar")
    if not path.isfile(best_fp):
        continue
    best = torch.load(best_fp, map_location="cpu")
    last = torch.load(last_fp, map_location="cpu")
    best_epoch = best["epoch"]
    last_epoch = last["epoch"]
    with open(path.join(fp, "opts.json")) as f:
        opts = json.load(f)
    if opts["finished"]:
        continue
    print("/".join(fp.split("/")[3:]) + ": best {} last {} finished {}".format(
        best_epoch, last_epoch, opts["finished"]))

In [3]:
eval_fps = "/data/b2p-siteident/experiments/ssl_results/*/*/*"

In [16]:
feats = [
    # "batch_size",
    # "lr",
    "alpha",
    "lambda_u",
    "ema_decay",
    "T",
    # "tile_size",
    # "model",
    "data_version",
    "use_last",
    "pretrained",
    "data_modalities"
]

add_feats = [
    "best/last",
    "use_several",
    # "num_test_samples"
]

metrics = [
    # "epoch",
    "val_acc",
    # "val_weighted_f1",
    # "val_loss",
    "test_acc",
    "test_weighted_f1",
    # "test_loss",
    # "test_rw_acc",
    "test_rw_weighted_f1",
    # "test_rw_loss",
    # "test_ug_acc",
    "test_ug_weighted_f1",
    # "test_ug_loss",
    "ecount"
]

In [17]:
entries_dict = dict()
for fp in glob.glob(eval_fps):
    print(fp)
    with open(path.join(fp, "opts.json")) as f:
        opts = json.load(f)
    k = []
    for f in feats:
        if f == "pretrained":
            k.append(not opts["no_use_pretrained"])
        elif f == "use_last":
            k.append(opts["use_last_n_layers"])
        elif f == "data_modalities":
            if len(opts[f]) < 7:
                k.append("small")
            else:
                k.append("large")
        else:
            k.append(opts[f])
    stats_fps = glob.glob(path.join(fp, "stats_*.json"))
    for stats_fp in stats_fps:
        _, best_or_last, use_several_test_samples, num_test_samples = stats_fp.split(
            "/")[-1].split(".")[0].split("_")
        use_several_test_samples = use_several_test_samples == "True"
        num_test_samples = int(num_test_samples)
        k_stats = list(k)
        k_stats += [
            best_or_last, 
            use_several_test_samples, 
            # num_test_samples
        ]
        k_stats = tuple(k_stats)
        if k_stats not in entries_dict:
            entries_dict[k_stats] = {m: [] for m in metrics if m != "ecount"}
            entries_dict[k_stats]["ecount"] = 0
        with open(stats_fp) as f:
            stats = json.load(f)
        entries_dict[k_stats]["ecount"] += 1
        for m in metrics:
            if m == "ecount":
                continue
            entries_dict[k_stats][m].append(stats[m])

entries = []
for k, entry in entries_dict.items():
    avg_entry = list(k)
    for m in metrics:
        if m == "ecount" or m == "epoch":
            avg_entry.append(entry[m])
        else:
            val = np.array(entry[m])
            if "acc" in m:
                val *= 100
            avg_entry.append(round(np.mean(val), 2))
            avg_entry.append(round(sem(val), 2))
    entries.append(avg_entry)

metrics_header = []
for m in metrics:
    if m == "ecount" or m == "epoch":
        metrics_header.append(m)
    else:
        metrics_header.append(m + "_m")
        metrics_header.append(m + "_ste")
df = pd.DataFrame(entries, columns=feats + add_feats + metrics_header)

/data/b2p-siteident/experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze-9/v0
/data/b2p-siteident/experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1-small/v0
/data/b2p-siteident/experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1-small/v2
/data/b2p-siteident/experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1-small/v1
/data/b2p-siteident/experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1/v0
/data/b2p-siteident/experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1/v2
/data/b2p-siteident/experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze--1/v1
/data/b2p-siteident/experiments/ssl_results/data-v1/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200_freeze-9/v0
/data/b2p-siteident/experiments/ssl_results/data-v1/resn

  keepdims=keepdims, where=where)
  ret = ret.dtype.type(ret / rcount)


In [19]:
df[df.data_modalities == "large"].sort_values(by=["data_version", "val_acc_m"], ascending=False)

Unnamed: 0,alpha,lambda_u,ema_decay,T,data_version,use_last,pretrained,data_modalities,best/last,use_several,...,val_acc_ste,test_acc_m,test_acc_ste,test_weighted_f1_m,test_weighted_f1_ste,test_rw_weighted_f1_m,test_rw_weighted_f1_ste,test_ug_weighted_f1_m,test_ug_weighted_f1_ste,ecount
10,0.25,0.01,0.75,0.5,v2,-1,True,large,last,True,...,0.68,86.79,1.4,0.87,0.01,0.83,0.01,0.82,0.03,3
8,0.25,0.01,0.75,0.5,v2,-1,True,large,last,False,...,0.57,86.14,1.03,0.86,0.01,0.83,0.01,0.8,0.03,3
9,0.25,0.01,0.75,0.5,v2,-1,True,large,best,True,...,0.17,87.1,0.53,0.87,0.01,0.82,0.01,0.87,0.0,3
11,0.25,0.01,0.75,0.5,v2,-1,True,large,best,False,...,0.88,85.87,1.96,0.86,0.02,0.81,0.01,0.84,0.02,3
2,0.25,0.01,0.75,0.5,v2,9,True,large,last,True,...,,81.11,,0.81,,0.69,,0.87,,1
3,0.25,0.01,0.75,0.5,v2,9,True,large,best,False,...,,80.07,,0.8,,0.68,,0.87,,1
1,0.25,0.01,0.75,0.5,v2,9,True,large,best,True,...,,82.14,,0.82,,0.67,,0.87,,1
0,0.25,0.01,0.75,0.5,v2,9,True,large,last,False,...,,81.34,,0.81,,0.67,,0.91,,1
25,0.25,0.01,0.75,0.5,v1,-1,True,large,best,True,...,0.84,78.17,1.11,0.78,0.01,0.82,0.0,0.68,0.03,3
24,0.25,0.01,0.75,0.5,v1,-1,True,large,last,False,...,0.73,76.69,0.47,0.76,0.01,0.83,0.01,0.64,0.04,3


In [None]:
\res{0.82}{<0.01} & \res{0.68}{0.03} & {0.78}{0.01} & & \res{0.83}{0.01} & \res{0.82}{0.03} & \res{0.87}{0.01}

In [None]:
\res{0.82}{0.01}& \res{0.55}{0.03} & \res{0.75}{0.01} & & \res{0.82}{0.01} & \res{0.76}{0.01} & \res{0.85}{0.01}

In [40]:
df.sort_values(by=feats + add_feats)

Unnamed: 0,alpha,lambda_u,ema_decay,T,model,data_version,use_last,pretrained,best/last,use_several,epoch,val_acc_m,val_acc_ste,test_acc_m,test_acc_ste,ecount
11,0.25,0.01,0.75,0.5,resnet50,v1,-1,False,best,False,[53],84.09,,79.84,,1
10,0.25,0.01,0.75,0.5,resnet50,v1,-1,False,last,False,[73],80.37,,76.29,,1
14,0.25,0.01,0.75,0.5,resnet50,v1,-1,True,best,False,[105],84.71,,76.61,,1
12,0.25,0.01,0.75,0.5,resnet50,v1,-1,True,last,False,[141],79.13,,75.32,,1
13,0.25,0.01,0.75,0.5,resnet50,v1,-1,True,last,True,[134],84.3,,77.34,,1
17,0.25,0.01,0.75,0.5,resnet50,v1,9,False,best,False,[3],71.49,,64.68,,1
15,0.25,0.01,0.75,0.5,resnet50,v1,9,False,last,False,[2],64.05,,61.21,,1
16,0.25,0.01,0.75,0.5,resnet50,v1,9,False,last,True,[4],66.12,,64.84,,1
2,0.25,0.01,0.75,0.5,resnet50,v2,-1,False,best,False,[39],84.67,,81.8,,1
1,0.25,0.01,0.75,0.5,resnet50,v2,-1,False,best,True,[25],85.82,,84.1,,1
