In [120]:
import glob
import json
import itertools

from os import path

import pandas as pd

import torch

In [94]:
exp_fp = "/data/b2p-siteident/experiments/ssl_hyperparam_search/ssl-v*/*/v0"
hyperparams = [
    "batch_size",
    "lr", 
    "alpha",
    "lambda_u",
    "T",
    "ema_decay",
    "data_version",
    "no_use_pretrained"
]
metrics = [
    "best_acc",
    "max_train_acc",
    "best_epoch"
]

In [95]:
entries = []
for fp in glob.glob(exp_fp):
    if not (path.isfile(path.join(fp, "opts.json")) and path.isfile(path.join(fp, "logs.csv"))):
        print("{} not finished.".format(fp))
        continue
    
    with open(path.join(fp, "opts.json")) as f:
        opts = json.load(f)
    entry = [opts[h] for h in hyperparams]

    df = pd.read_csv(path.join(fp, "logs.csv"))
    best_acc = df.iloc[len(df) - 1]["Best Acc."]
    best_epoch = df.iloc[len(df) - 1]["Best Epoch"]
    max_train_acc = df["Train Acc."].max()
    entry += [best_acc, max_train_acc, best_epoch]
    entries.append(entry)

resdf = pd.DataFrame(entries, columns=hyperparams + metrics)
print("Number of experiments: {}".format(len(resdf)))

/data/b2p-siteident/experiments/ssl_hyperparam_search/ssl-v2/resnet18_ema-0.75_lmdu-0_T-0.75_a-0.5_tile-300/v0 not finished.
Number of experiments: 91


In [96]:
resdf[resdf.data_version == "v1"].sort_values(by="best_acc", ascending=False)[:3]

Unnamed: 0,batch_size,lr,alpha,lambda_u,T,ema_decay,data_version,no_use_pretrained,best_acc,max_train_acc,best_epoch
63,256,0.05,0.5,0.01,0.5,0.75,v1,False,85.7438,86.7839,82.0
69,256,0.05,0.5,0.01,0.5,0.9,v1,False,85.7438,86.6536,82.0
54,256,0.05,0.5,0.01,0.5,0.9,v1,False,85.7438,86.6536,82.0


In [97]:
# 49   256	0.05	0.5	0.01	0.5	0.90	v1	False	85.7438	86.6536
# 56   256	0.05	0.5	0.01	0.5	0.75	v1	False	85.7438	86.7839
# 74   256	0.05	0.5	0.01	0.5	0.75	v1	False	85.7438	86.7839

In [98]:
resdf[resdf.data_version == "v2"].sort_values(by="best_acc", ascending=False)[:3]

Unnamed: 0,batch_size,lr,alpha,lambda_u,T,ema_decay,data_version,no_use_pretrained,best_acc,max_train_acc,best_epoch
4,256,0.05,0.5,0.1,0.5,0.75,v2,False,88.3142,85.0446,100.0
10,256,0.05,0.5,0.1,0.5,0.5,v2,False,87.931,85.0446,100.0
11,256,0.05,0.5,0.1,0.5,0.9,v2,False,87.7395,84.933,100.0


In [73]:
# 2   256	0.05	0.5	0.1	0.5	0.75	v2	False	88.3142	85.0446
# 8   256	0.05	0.5	0.1	0.5	0.50	v2	False	87.9310	85.0446
# 9   256	0.05	0.5	0.1	0.5	0.90	v2	False	87.7395	84.9330

In [74]:
resdf[resdf.data_version == "v1"].sort_values(by="max_train_acc", ascending=False)[:3]

Unnamed: 0,batch_size,lr,alpha,lambda_u,T,ema_decay,data_version,no_use_pretrained,best_acc,max_train_acc,best_epoch
63,256,0.05,0.25,0.1,0.5,0.5,v1,False,84.5041,87.5,65.0
42,256,0.05,0.25,0.1,0.5,0.75,v1,False,84.7107,87.3047,125.0
75,256,0.05,0.25,0.1,0.5,0.75,v1,False,84.7107,87.3047,125.0


In [75]:
resdf[resdf.data_version == "v2"].sort_values(by="best_acc", ascending=False)[:3]

Unnamed: 0,batch_size,lr,alpha,lambda_u,T,ema_decay,data_version,no_use_pretrained,best_acc,max_train_acc,best_epoch
2,256,0.05,0.5,0.1,0.5,0.75,v2,False,88.3142,85.0446,100.0
8,256,0.05,0.5,0.1,0.5,0.5,v2,False,87.931,85.0446,100.0
9,256,0.05,0.5,0.1,0.5,0.9,v2,False,87.7395,84.933,100.0


In [None]:
CMD = (
        "python train_ssl.py --out {out} --model {model} --tile_size {tile} "
        "--manualSeed {seed} "
        "--data_version {version} --lambda-u {lambdau} --ema-decay {ema} "
        "--alpha {alpha} --gpu {gpu}"
    )

In [91]:
hyperparams = {
    "lambda-u": [0.01],
    "ema-decay": [0.75],
    "T": [0.5],
    "alpha": [0.25],
    "tile_size": [1200],
    "model": ["resnet50", "wide_resnet50_2"],
    "learning-rate": [0.002],
    "batch-size": [64]
}

CMD = (
    "python train_ssl.py --out {out} --model {model} --tile_size {tile} "
    "--manualSeed {seed} "
    "--data_version {version} --lambda-u {lambdau} --ema-decay {ema} "
    "--alpha {alpha} --learning-rate {learning_rate} "
    "--batch-size {batch_size} --no_early_stopping --no_use_pretrained"
)
MODEL_NAME = "{model}_ema-{ema}_lmdu-{lmdu}_T-{T}_a-{a}_tile-{tile}-nopretrained"
sorted_names = sorted(list(hyperparams.keys()))

cmds = []

for data_version in ["v1", "v2"]:
    vfpath = "experiments/ssl_results/data-{}".format(data_version)
    combinations = sorted(list(itertools.product(
        *[hyperparams[k] for k in sorted_names])))
    # print("Number of combinations: {}".format(len(combinations)))
    for combination in combinations:
        # print(combination)
        model_name = MODEL_NAME.format(
            model=combination[sorted_names.index("model")],
            ema=combination[sorted_names.index("ema-decay")],
            lmdu=combination[sorted_names.index("lambda-u")],
            T=combination[sorted_names.index("T")],
            a=combination[sorted_names.index("alpha")],
            tile=combination[sorted_names.index("tile_size")])
        for seed, version in [
                (10, 0), 
                # (42, 1), (100, 2)
        ]:
            version_name = "v{version}".format(version=version)
            out = path.join(vfpath, model_name, version_name)
            cmd = CMD.format(
                out=out,
                model=combination[sorted_names.index("model")],
                tile=combination[sorted_names.index("tile_size")],
                seed=42,
                version=data_version,
                lambdau=combination[sorted_names.index("lambda-u")],
                ema=combination[sorted_names.index("ema-decay")],
                alpha=combination[sorted_names.index("alpha")],
                log="{}_{}".format(model_name, version_name),
                learning_rate=combination[sorted_names.index("learning-rate")],
                batch_size=combination[sorted_names.index("batch-size")]
            )
            cmds.append(cmd)

In [92]:
cmds

['python train_ssl.py --out experiments/ssl_results/data-v1/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200-nopretrained/v0 --model resnet50 --tile_size 1200 --manualSeed 42 --data_version v1 --lambda-u 0.01 --ema-decay 0.75 --alpha 0.25 --learning-rate 0.002 --batch-size 64 --no_early_stopping --no_use_pretrained',
 'python train_ssl.py --out experiments/ssl_results/data-v1/wide_resnet50_2_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200-nopretrained/v0 --model wide_resnet50_2 --tile_size 1200 --manualSeed 42 --data_version v1 --lambda-u 0.01 --ema-decay 0.75 --alpha 0.25 --learning-rate 0.002 --batch-size 64 --no_early_stopping --no_use_pretrained',
 'python train_ssl.py --out experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200-nopretrained/v0 --model resnet50 --tile_size 1200 --manualSeed 42 --data_version v2 --lambda-u 0.01 --ema-decay 0.75 --alpha 0.25 --learning-rate 0.002 --batch-size 64 --no_early_stopping --no_use_pretrained',
 'python train_ssl.py -

In [126]:
CMD = "CUDA_VISIBLE_DEVICES=2 python train_ssl.py --evaluate --out {}"

all_fps = []
for fp in sorted(glob.glob("../experiments/ssl_results/*/*/*")):
    model_fp = path.join(fp, "model_best.pth.tar")
    if not path.isfile(model_fp):
        continue
    checkpoint = torch.load(model_fp, map_location="cpu")
    best_epoch = checkpoint["epoch"]
    all_fps.append((fp, best_epoch))

all_fps = sorted(all_fps, key=lambda x: -x[1])

all_cmds = []
for use_several in [False, True]:
    for fp, _ in all_fps:
        for use_last in [True, False]:
            cmd = CMD.format(fp.replace("../", ""))
            if use_last:
                cmd += " --use_last_checkpoint"
            if use_several:
                cmd += " --use_several_test_samples"
            all_cmds.append(cmd + ";")

In [127]:
" ".join(all_cmds)

'CUDA_VISIBLE_DEVICES=2 python train_ssl.py --evaluate --out experiments/ssl_results/data-v1/wide_resnet50_2_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200/v0 --use_last_checkpoint; CUDA_VISIBLE_DEVICES=2 python train_ssl.py --evaluate --out experiments/ssl_results/data-v1/wide_resnet50_2_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200/v0; CUDA_VISIBLE_DEVICES=2 python train_ssl.py --evaluate --out experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200/v0 --use_last_checkpoint; CUDA_VISIBLE_DEVICES=2 python train_ssl.py --evaluate --out experiments/ssl_results/data-v2/resnet50_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200/v0; CUDA_VISIBLE_DEVICES=2 python train_ssl.py --evaluate --out experiments/ssl_results/data-v2/wide_resnet50_2_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200/v0 --use_last_checkpoint; CUDA_VISIBLE_DEVICES=2 python train_ssl.py --evaluate --out experiments/ssl_results/data-v2/wide_resnet50_2_ema-0.75_lmdu-0.01_T-0.5_a-0.25_tile-1200/v0; CUDA_VISIBLE_DEVICES=2 