In [1]:
def attention_head(n_tokens, dim, dim_head):
    macs = 0
    
    # token -> k, q, v
    macs += n_tokens * 3 * dim * dim_head
    
    # q * k'
    # (n_tokens, dim_head) * (dim_head, n_tokens) -> (n_tokens, n_tokens)
    macs += n_tokens * dim_head * n_tokens
    
    # Softmax e diviso sqrt(dim_head) 
    # ...
    
    # (q * k') * v
    # (n_tokens, n_tokens) * (n_tokens, dim_head) -> (n_tokens, dim_head)
    macs += n_tokens * n_tokens * dim_head
    
    return macs
    
def attention(n_tokens, dim, dim_head, n_heads):
    macs = 0
    
    macs += n_heads * attention_head(n_tokens, dim, dim_head)
    
    # Riporta gli z concatenati a dimensione dim
    macs += n_tokens * (dim_head * n_heads) * dim if not (n_heads == 1 and dim_head == dim) else 0
    
    return macs

def feed_forward(n_tokens, dim, mlp_dim):
    # 2 Linear: dim -> mlp_dim, mlp_dim -> dim
    return n_tokens * dim * mlp_dim * 2

def transformer(n_tokens, dim, dim_head, n_heads, mlp_dim, depth):
    return depth * (attention(n_tokens, dim, dim_head, n_heads) + feed_forward(n_tokens, dim, mlp_dim))

def vit(patch_size, dim, dim_head, n_heads, mlp_dim, depth):
    macs = 0
    
    n_tokens = 300 // patch_size
    
    # linear embedding
    macs += n_tokens * (14 * patch_size) * dim
    
    # +1 perché c'è cls_token
    macs += transformer(n_tokens + 1, dim, dim_head, n_heads, mlp_dim, depth)
    
    # output
    # Da mean o last token a class_scores
    macs += dim * 8
    
    return macs

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)
               
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x)
        qkv = qkv.chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_cls_token=True):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        #self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.use_cls_token = use_cls_token
        if self.use_cls_token:
            self.pos_embedding = nn.Parameter(torch.empty(1, num_patches + 1, dim))
        else:
            self.pos_embedding = nn.Parameter(torch.empty(1, num_patches, dim))
        #nn.init.kaiming_uniform_(self.pos_embedding, a=5 ** .5)
        nn.init.normal_(self.pos_embedding, std=.02)

        #self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.cls_token = nn.Parameter(torch.empty(1, 1, dim))
        nn.init.zeros_(self.cls_token)
        
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # MACs: patch_size * n_patches * dim, es (30 * 14) * 10 * 300
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        
        if self.use_cls_token:
            cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
            x = torch.cat((cls_tokens, x), dim=1)
            x += self.pos_embedding[:, :(n + 1)]
        else :
            x += self.pos_embedding
        
        x = self.dropout(x)
        
        # FeedForward    Attention       project out
        # 300*300*10*2 + 300*(64*3)*10 + ((64)*300*10)
        # Attention -> manca softmax e attention vera e propria, c'è solo linear encoding a qkv
        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        
        x = self.to_latent(x)
        x = self.mlp_head(x)
        return x
    
# Ratio of params
def vit_aff_ratio(patch_size, dim, dim_head, n_heads, mlp_dim, depth): 
    n_tokens = 300 // patch_size + 1
    
    a = (((dim) * dim_head * 3 * n_heads) + ((dim_head * n_heads) * dim) + dim)
    ff = dim * mlp_dim * 2 + mlp_dim + dim
    
    return a / (a + ff)

def get_results(configs, results_, additional_columns, extract_model_hparams):
        
    acccs = []
    acccs_steady = []
    acccs_val0 = []
    acccs_steady_val0 = []
    acccs_val1 = []
    acccs_steady_val1 = []

    acccs_val_val0 = 0
    acccs_val_val1 = 0
    acccs_train_val0 = 0
    acccs_train_val1 = 0
    
    acccs_steady_persubject = np.array([0] * 10, dtype=float)
    preds_steady_bincounts_subject = np.zeros((10, 8), dtype=int)
    
    for config, r in zip(configs, results_):

        accs = np.array([0] * len(r['test_sessions']), dtype=float)
        for i, (y_pred, y_true) in enumerate(zip(r['val-fold_0']['y_preds'], r['val-fold_0']['y_trues'])):
            accs[i] += (y_pred == y_true).sum().float() / len(y_true)
        #for i, (y_pred, y_true) in enumerate(zip(r['val-fold_1']['y_preds'], r['val-fold_1']['y_trues'])):
        #    accs[i] += (y_pred == y_true).sum().float() / len(y_true)
        #accs /= 2

        accs_steady = np.array([0] * len(r['test_sessions']), dtype=float)
        for i, (y_pred, y_true) in enumerate(zip(r['val-fold_0']['y_preds_steady'], r['val-fold_0']['y_trues_steady'])):
            accs_steady[i] += (y_pred == y_true).sum().float() / len(y_true)
        for i, (y_pred, y_true) in enumerate(zip(r['val-fold_1']['y_preds_steady'], r['val-fold_1']['y_trues_steady'])):
            accs_steady[i] += (y_pred == y_true).sum().float() / len(y_true)
        accs_steady /= 2

        accs_val0 = np.array([0] * len(r['test_sessions']), dtype=float)
        for i, (y_pred, y_true) in enumerate(zip(r['val-fold_0']['y_preds'], r['val-fold_0']['y_trues'])):
            accs_val0[i] += (y_pred == y_true).sum().float() / len(y_true)

        accs_steady_val0 = np.array([0] * len(r['test_sessions']), dtype=float)
        for i, (y_pred, y_true) in enumerate(zip(r['val-fold_0']['y_preds_steady'], r['val-fold_0']['y_trues_steady'])):
            accs_steady_val0[i] += (y_pred == y_true).sum().float() / len(y_true)
            acccs_steady_persubject[r['subject'] - 1] += (y_pred == y_true).sum().float() / len(y_true)
            preds_steady_bincounts_subject[r['subject'] - 1] += y_pred.bincount(minlength=(y_true.max() + 1)).numpy()
            
        accs_val1 = np.array([0] * len(r['test_sessions']), dtype=float)
        for i, (y_pred, y_true) in enumerate(zip(r['val-fold_1']['y_preds'], r['val-fold_1']['y_trues'])):
            accs_val1[i] += (y_pred == y_true).sum().float() / len(y_true)

        accs_steady_val1 = np.array([0] * len(r['test_sessions']), dtype=float)
        for i, (y_pred, y_true) in enumerate(zip(r['val-fold_1']['y_preds_steady'], r['val-fold_1']['y_trues_steady'])):
            accs_steady_val1[i] += (y_pred == y_true).sum().float() / len(y_true)
            acccs_steady_persubject[r['subject'] - 1] += (y_pred == y_true).sum().float() / len(y_true)
            preds_steady_bincounts_subject[r['subject'] - 1] += y_pred.bincount(minlength=(y_true.max() + 1)).numpy()
            
        acccs_train_val0 += r['val-fold_0']['losses_accs'][-1]['train_acc']
        acccs_train_val1 += r['val-fold_1']['losses_accs'][-1]['train_acc']

        acccs_val_val0 += r['val-fold_0']['losses_accs'][-1]['val_acc']
        acccs_val_val1 += r['val-fold_1']['losses_accs'][-1]['val_acc']

        acccs.append(accs)
        acccs_steady.append(accs_steady)
        acccs_val0.append(accs_val0)
        acccs_steady_val0.append(accs_steady_val0)
        acccs_val1.append(accs_val1)
        acccs_steady_val1.append(accs_steady_val1)
        
    test_sessions = len(r['test_sessions'])

    acccs_steady_persubject /= test_sessions * 2 # 5 sessioni per due

    acccs = np.array(acccs).mean(axis=0)

    acccs_steady_ = np.array(acccs_steady).mean(axis=1)
    acccs_steady = np.array(acccs_steady).mean(axis=0)

    acccs_val0 = np.array(acccs_val0).mean(axis=0)
    acccs_val1 = np.array(acccs_val1).mean(axis=0)
    acccs_steady_val0 = np.array(acccs_steady_val0).mean(axis=0)
    acccs_steady_val1 = np.array(acccs_steady_val1).mean(axis=0)
    acccs_val_val0 /= 10
    acccs_val_val1 /= 10
    acccs_train_val0 /= 10
    acccs_train_val1 /= 10
    
    model_hparams = extract_model_hparams(config)

    return {        
        **model_hparams,

        **additional_columns,

        "train accuracy steady fold1":  acccs_train_val0,
        "train accuracy steady fold2":  acccs_train_val1,
        "train accuracy steady avg2folds": .5 * (acccs_train_val0 + acccs_train_val1),

        "validation accuracy steady fold1": acccs_val_val0,
        "validation accuracy steady fold2": acccs_val_val1,

        "test accuracy fold1": acccs_val0.mean(),
        "test accuracy fold2": acccs_val1.mean(),
        "test accuracy avg2folds": acccs.mean(), 
        "test accuracy steady fold1": acccs_steady_val0.mean(),
        "test accuracy steady fold2": acccs_steady_val1.mean(), 
        "test accuracy steady avg2folds": acccs_steady.mean(),

        "test accuracy steady avg2folds std across sessions": acccs_steady.std(),
        "test accuracy steady avg2folds std across subjects": acccs_steady_.std(),
        
        **{
          f"test accuracy steady session{s + 1 + test_sessions} avg2folds": acccs_steady[s] for s in range(test_sessions)
        },
        
        **{
            f"test accuracy steady subj{s} avg2folds": acccs_steady_persubject[s] for s in range(10)
        },
        
        **{
            f"test preds steady subj{s} avg2folds": preds_steady_bincounts_subject[s] for s in range(10)
        },
    }

from pickle import load

import numpy as np
import pandas as pd

def group_configs(configs, group_exclude_columns):
    # https://stackoverflow.com/a/6027615
    import collections.abc

    def flatten(d, parent_key='', sep='_'):
        items = []
        for k, v in d.items():
            new_key = parent_key + sep + k if parent_key else k
            if isinstance(v, collections.abc.MutableMapping):
                items.extend(flatten(v, new_key, sep=sep).items())
            else:
                items.append((new_key, v))
        return dict(items)

    df = pd.DataFrame([flatten(config) for config in configs])
    df['row_id'] = [[a] for a in df.index]
    
    if 'training_config_lr_scheduler_hparams_milestones' in df:
        df['training_config_lr_scheduler_hparams_milestones'] = df['training_config_lr_scheduler_hparams_milestones'].apply(lambda x: ','.join(map(str, x)))
    df = df.fillna('null')
    
    grouped_configs = df.groupby([c for c in df.columns if c not in group_exclude_columns]).agg({'subjects': 'count', 'row_id': 'sum'})
    
    if (grouped_configs['subjects'] != 10).sum() != 0:
        display(grouped_configs)
        raise ValueError("For every config, it is assumed that you trained on 10 subjects")
    
    return list(grouped_configs["row_id"])


extract_model_hparams_generator = {
    'vit': lambda config: {
        "window_size": config["image_size"][1],
        "patch_size": config["patch_size"][1],
        "dim_projection": config["dim"],
        "dim_ff": config["mlp_dim"],
        "dim_head": config["dim_head"],
        "n_heads": config["heads"],
        "depth": config["depth"],
        "dropout": config["dropout"],
        "emb_dropout": config["emb_dropout"],
        
        "MACs": vit(patch_size=config["patch_size"][1], dim=config["dim"], dim_head=config["dim_head"], n_heads=config["heads"], mlp_dim=config["mlp_dim"], depth=config["depth"]),
        "params":  sum([param.nelement() for param in ViT(image_size=(1, 300), patch_size=config["patch_size"], dim=config["dim"], dim_head=config["dim_head"], heads=config["heads"], mlp_dim=config["mlp_dim"], depth=config["depth"], num_classes=8).parameters()]),
        "params_aff_ratio": vit_aff_ratio(patch_size=config["patch_size"][1], dim=config["dim"], dim_head=config["dim_head"], n_heads=config["heads"], mlp_dim=config["mlp_dim"], depth=config["depth"]),
        
    },
    'temponet': lambda _: {
        "MACs": 16028672,
        "params": 461512,
    },
    "convit": lambda config: {
        "depth": config["depth"],
    }
}

def read_results(filename, additional_columns=None, group_exclude_columns=None, model_name='vit'):
    additional_columns = {} if additional_columns is None else additional_columns
    
    group_exclude_columns = set() if group_exclude_columns is None else group_exclude_columns
    group_exclude_columns = group_exclude_columns.union({'subjects', 'row_id'})
    
    configs, results_ = load(open(filename, 'rb'))
    
    groups_indices = group_configs(configs, group_exclude_columns)
    
    df_l = []
    for idx in groups_indices:
        c = [configs[i] for i in idx]
        r = [results_[i] for i in idx]
        
        ac = additional_columns.copy()
        ac["training_config"] = c[0]["training_config"]
        df_l.append(get_results(c, r, ac, extract_model_hparams_generator[model_name]))   
    
    return pd.DataFrame(df_l) 

def get_rows(all_res_vit, group):
    m = None
    for k in group.keys():
        current_m = all_res_vit[k] == group[k]
        if m is None:
            m = current_m
        else:
            m &= current_m
    return all_res_vit[m].copy()

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

In [2]:
df = read_results("pretrain/train_conv/results_1631106425.pickle",additional_columns={"conv": "all", "pretrain": "no"}, model_name='convit')\
.append(read_results("train_h3d8_adamw/train_conv/results_1631096710.pickle", additional_columns={"conv": "single", "pretrain": "no"}, group_exclude_columns={'conv_layers'}, model_name='convit'), ignore_index=True)\
.append(read_results("pretrain/pretrain_conv/finetune25/results_1631115451.pickle", additional_columns={"conv": "single", "pretrain": "yes_nonfixed", "finetune": 25}, group_exclude_columns={'conv_layers'}, model_name='convit'), ignore_index=True)\
.append(read_results("pretrain/pretrain_conv/finetune50/results_1631130033.pickle", additional_columns={"conv": "single", "pretrain": "yes_nonfixed", "finetune": 50}, group_exclude_columns={'conv_layers'}, model_name='convit'), ignore_index=True)\
.append(read_results("pretrain/pretrain_conv/finetune75/results_1631124308.pickle", additional_columns={"conv": "single", "pretrain": "yes_nonfixed", "finetune": 75}, group_exclude_columns={'conv_layers'}, model_name='convit'), ignore_index=True)\
.append(read_results("pretrain/pretrain_conv/finetune25_/results_1631175251.pickle", additional_columns={"conv": "single", "pretrain": "yes", "finetune": 25}, group_exclude_columns={'conv_layers'}, model_name='convit'), ignore_index=True)\
.append(read_results("pretrain/pretrain_conv/finetune50_/results_1631177094.pickle", additional_columns={"conv": "single", "pretrain": "yes", "finetune": 50}, group_exclude_columns={'conv_layers'}, model_name='convit'), ignore_index=True)\
.append(read_results("pretrain/pretrain_conv/finetune75_/results_1631178989.pickle", additional_columns={"conv": "single", "pretrain": "yes", "finetune": 50}, group_exclude_columns={'conv_layers'}, model_name='convit'), ignore_index=True)\
.append(read_results("pretrain_2/pretrain_conv/finetune25/results_1631285553.pickle", additional_columns={"conv": "single", "pretrain": "yes", "finetune": 'n'}, group_exclude_columns={'conv_layers'}, model_name='convit'), ignore_index=True)\

In [4]:
pd.options.display.max_colwidth = 1000

In [5]:
df

Unnamed: 0,depth,conv,pretrain,training_config,train accuracy steady fold1,train accuracy steady fold2,train accuracy steady avg2folds,validation accuracy steady fold1,validation accuracy steady fold2,test accuracy fold1,test accuracy fold2,test accuracy avg2folds,test accuracy steady fold1,test accuracy steady fold2,test accuracy steady avg2folds,test accuracy steady avg2folds std across sessions,test accuracy steady avg2folds std across subjects,test accuracy steady session6 avg2folds,test accuracy steady session7 avg2folds,test accuracy steady session8 avg2folds,test accuracy steady session9 avg2folds,test accuracy steady session10 avg2folds,test accuracy steady subj0 avg2folds,test accuracy steady subj1 avg2folds,test accuracy steady subj2 avg2folds,test accuracy steady subj3 avg2folds,test accuracy steady subj4 avg2folds,test accuracy steady subj5 avg2folds,test accuracy steady subj6 avg2folds,test accuracy steady subj7 avg2folds,test accuracy steady subj8 avg2folds,test accuracy steady subj9 avg2folds,test preds steady subj0 avg2folds,test preds steady subj1 avg2folds,test preds steady subj2 avg2folds,test preds steady subj3 avg2folds,test preds steady subj4 avg2folds,test preds steady subj5 avg2folds,test preds steady subj6 avg2folds,test preds steady subj7 avg2folds,test preds steady subj8 avg2folds,test preds steady subj9 avg2folds,finetune
0,8,all,no,"{'epochs': 75, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'CyclicLR', 'lr_scheduler_hparams': {'base_lr': 1e-07, 'max_lr': 0.001, 'step_size_up': 50, 'step_size_down': None, 'mode': 'triangular', 'cycle_momentum': False}}",0.832824,0.891702,0.862263,0.628401,0.638954,0.441933,0.449447,0.441933,0.602031,0.60594,0.603985,0.012319,0.118277,0.616156,0.590528,0.62098,0.593185,0.599078,0.753866,0.684087,0.450906,0.670771,0.801496,0.651356,0.541126,0.47742,0.53656,0.472264,"[41807, 8911, 7252, 11555, 5250, 8265, 8493, 5719]","[37130, 10021, 4311, 7235, 14533, 10389, 6589, 6860]","[97402, 0, 0, 0, 0, 0, 0, 0]","[45596, 6682, 6356, 7039, 7494, 12349, 5225, 6273]","[51944, 7929, 7942, 7570, 4418, 6030, 6991, 4506]","[41677, 5808, 6131, 5142, 8978, 13131, 9451, 6554]","[36381, 7532, 12002, 7409, 12199, 6717, 10644, 4496]","[97942, 0, 0, 0, 0, 0, 0, 0]","[73835, 5280, 3776, 2604, 3837, 2701, 4316, 2719]","[68396, 2106, 2865, 3539, 6398, 6286, 4247, 3403]",
1,8,single,no,"{'epochs': 75, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'CyclicLR', 'lr_scheduler_hparams': {'base_lr': 1e-07, 'max_lr': 0.001, 'step_size_up': 50, 'step_size_down': None, 'mode': 'triangular', 'cycle_momentum': False}}",0.882529,0.859839,0.871184,0.647656,0.635783,0.45573,0.448515,0.45573,0.614778,0.605766,0.610272,0.016849,0.106803,0.622817,0.589404,0.636001,0.599015,0.604123,0.746648,0.673362,0.450906,0.668524,0.787117,0.654543,0.558581,0.47742,0.538243,0.547374,"[43159, 9577, 8003, 8520, 5472, 8625, 8858, 5038]","[35555, 9297, 6584, 7208, 13775, 12204, 6015, 6430]","[97402, 0, 0, 0, 0, 0, 0, 0]","[43968, 8248, 5617, 8590, 6828, 12203, 6886, 4674]","[52325, 8817, 7123, 7560, 3986, 5596, 8221, 3702]","[44828, 5115, 8246, 5673, 7816, 10145, 10383, 4666]","[35180, 7878, 8591, 6620, 12448, 9308, 12962, 4393]","[97942, 0, 0, 0, 0, 0, 0, 0]","[74923, 3967, 4307, 2553, 4213, 2639, 4008, 2458]","[46626, 4676, 7847, 6280, 10477, 9842, 6589, 4903]",
2,8,single,yes_nonfixed,"{'epochs': 75, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'CyclicLR', 'lr_scheduler_hparams': {'base_lr': 1e-07, 'max_lr': 0.001, 'step_size_up': 50, 'step_size_down': None, 'mode': 'triangular', 'cycle_momentum': False}}",0.959673,0.998564,0.979119,0.676759,0.709262,0.46177,0.472543,0.46177,0.617109,0.623918,0.620513,0.019044,0.059948,0.642582,0.613916,0.643545,0.598913,0.60361,0.743256,0.541221,0.602619,0.656861,0.698363,0.615981,0.615221,0.594285,0.549048,0.588277,"[41603, 10600, 6281, 9057, 5419, 10113, 8405, 5774]","[34901, 2659, 906, 1782, 5888, 11619, 12282, 27031]","[55468, 3750, 4904, 4073, 7213, 10869, 5904, 5221]","[47962, 9259, 4465, 3970, 6029, 11772, 8041, 5516]","[53005, 4404, 10682, 2972, 8192, 4930, 3816, 9329]","[44820, 4944, 5629, 6212, 6996, 11084, 12243, 4944]","[38982, 7702, 10332, 7889, 8685, 8117, 11674, 3999]","[54955, 8021, 7953, 4404, 5753, 3206, 7772, 5878]","[73546, 4095, 4041, 2086, 4118, 2839, 4980, 3363]","[45239, 4659, 6754, 5554, 11795, 11614, 5909, 5716]",25.0
3,8,single,yes_nonfixed,"{'epochs': 75, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'CyclicLR', 'lr_scheduler_hparams': {'base_lr': 1e-07, 'max_lr': 0.001, 'step_size_up': 50, 'step_size_down': None, 'mode': 'triangular', 'cycle_momentum': False}}",0.998241,0.998068,0.998154,0.708974,0.736615,0.492865,0.47686,0.492865,0.654383,0.629076,0.64173,0.018011,0.070781,0.663354,0.633657,0.663246,0.621813,0.626578,0.74109,0.608682,0.613583,0.661406,0.802737,0.616337,0.622353,0.597427,0.559118,0.594563,"[42732, 9380, 6097, 10432, 5282, 9455, 8193, 5681]","[37894, 3918, 1626, 2037, 8323, 13903, 13684, 15683]","[54454, 3561, 6678, 4240, 7351, 10790, 6140, 4188]","[45895, 8892, 4499, 4612, 6601, 9868, 9783, 6864]","[53051, 7970, 9514, 5260, 4665, 5714, 6734, 4422]","[44650, 5211, 6166, 7183, 6559, 12127, 9700, 5276]","[38431, 7660, 9622, 8804, 9962, 8500, 10432, 3969]","[57191, 8231, 5578, 4329, 5725, 2475, 8116, 6297]","[75288, 4047, 3520, 2274, 4141, 3326, 4238, 2234]","[44094, 4775, 5595, 7104, 12449, 13096, 4691, 5436]",50.0
4,8,single,yes_nonfixed,"{'epochs': 75, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'CyclicLR', 'lr_scheduler_hparams': {'base_lr': 1e-07, 'max_lr': 0.001, 'step_size_up': 50, 'step_size_down': None, 'mode': 'triangular', 'cycle_momentum': False}}",0.997923,0.999701,0.998812,0.707988,0.750072,0.488877,0.479657,0.488877,0.651552,0.633752,0.642652,0.017293,0.067983,0.661105,0.634181,0.66529,0.621955,0.630728,0.737263,0.616858,0.606898,0.671402,0.79105,0.622458,0.630461,0.602135,0.553396,0.5946,"[43115, 9285, 6600, 7995, 5412, 10557, 8257, 6031]","[39793, 4466, 1305, 3008, 8029, 11032, 14869, 14566]","[59113, 3787, 5261, 3256, 9532, 8669, 3745, 4039]","[46457, 7633, 5563, 5589, 6069, 10831, 8701, 6171]","[52701, 7551, 8620, 6385, 4664, 5717, 7183, 4509]","[43665, 5496, 6070, 7341, 7203, 11224, 11157, 4716]","[38720, 7628, 11747, 7897, 9376, 7944, 10219, 3849]","[56763, 8292, 5717, 4860, 5750, 2313, 8025, 6222]","[75015, 4012, 4625, 2322, 3900, 3177, 3601, 2416]","[45302, 5342, 4967, 6092, 12818, 12275, 5342, 5102]",75.0
5,8,single,yes,"{'epochs': 20, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0.0001, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'StepLR', 'lr_scheduler_hparams': {'step_size': 10, 'gamma': 0.1}}",1.0,1.0,1.0,0.701014,0.764761,0.483799,0.4743,0.483799,0.643965,0.626183,0.635074,0.018878,0.070415,0.651722,0.626734,0.662666,0.613291,0.620956,0.734223,0.594341,0.596575,0.658075,0.793397,0.613355,0.625503,0.589611,0.554981,0.590679,"[43426, 8087, 5520, 9837, 5046, 9789, 8088, 7459]","[39860, 4575, 1939, 3178, 7745, 7514, 14139, 18118]","[55611, 3600, 6652, 2922, 7491, 9971, 5388, 5767]","[46824, 7787, 5527, 6266, 5439, 10751, 8663, 5757]","[52223, 6711, 8547, 6653, 4597, 5466, 7289, 5844]","[44634, 5457, 6599, 6428, 7675, 10613, 10054, 5412]","[39161, 8029, 10414, 9881, 8867, 8154, 8687, 4187]","[54289, 8005, 5533, 5024, 7263, 3152, 7692, 6984]","[72966, 4448, 4262, 2559, 4544, 3116, 3788, 3385]","[43836, 4776, 6043, 7410, 13825, 10651, 5289, 5410]",25.0
6,8,single,yes,"{'epochs': 20, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0.0001, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'StepLR', 'lr_scheduler_hparams': {'step_size': 10, 'gamma': 0.1}}",1.0,0.99982,0.99991,0.702524,0.807746,0.478013,0.471284,0.478013,0.63678,0.624332,0.630556,0.020094,0.069056,0.647797,0.627941,0.658224,0.604172,0.614647,0.721906,0.60057,0.582704,0.657313,0.787451,0.604389,0.627124,0.582811,0.552383,0.588909,"[44290, 7628, 4814, 9366, 4816, 8933, 10043, 7362]","[41895, 3567, 1726, 3068, 7367, 8923, 15480, 15042]","[55801, 2912, 6968, 2807, 7047, 11267, 5424, 5176]","[46969, 7024, 5467, 6470, 6181, 11766, 7567, 5570]","[52604, 6767, 8361, 6425, 4677, 4804, 7843, 5849]","[46128, 5281, 6577, 6521, 8039, 10520, 8288, 5518]","[39324, 8032, 10180, 9062, 9033, 8306, 9025, 4418]","[55071, 6924, 4619, 5158, 7330, 2840, 8064, 7936]","[72851, 5088, 4719, 2535, 4165, 2882, 4261, 2567]","[44943, 4741, 6018, 7261, 13895, 10012, 5111, 5259]",50.0
7,8,single,yes,"{'epochs': 20, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0.0001, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'StepLR', 'lr_scheduler_hparams': {'step_size': 10, 'gamma': 0.1}}",1.0,0.999773,0.999886,0.702832,0.840944,0.472904,0.473583,0.472904,0.630852,0.62658,0.628716,0.021979,0.069588,0.646831,0.625515,0.6596,0.59972,0.611913,0.727337,0.597063,0.588965,0.6528,0.783393,0.590392,0.630161,0.581199,0.552252,0.583597,"[44030, 8163, 4897, 8939, 4745, 8144, 10883, 7451]","[41544, 4669, 1501, 3609, 6368, 5648, 19503, 14226]","[57643, 3724, 5541, 3079, 7479, 10554, 4307, 5075]","[46061, 6775, 5052, 6628, 6273, 11788, 8376, 6061]","[52326, 6627, 8173, 6890, 4463, 4476, 8202, 6173]","[45202, 5160, 7282, 6341, 8477, 10099, 9139, 5172]","[40511, 7815, 9473, 9209, 8422, 7682, 9360, 4908]","[56393, 6867, 4145, 4588, 6427, 3156, 8034, 8332]","[73831, 4962, 3703, 2337, 4295, 2823, 4136, 2981]","[42834, 5174, 6948, 7438, 13104, 10737, 5819, 5186]",50.0
8,2,single,yes,"{'epochs': 20, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'StepLR', 'lr_scheduler_hparams': {'step_size': 10, 'gamma': 0.1}}",0.856633,0.667566,0.762099,0.667566,0.856633,0.452885,0.452885,0.452885,0.599935,0.599935,0.599935,0.026289,0.069445,0.636965,0.598142,0.61604,0.558392,0.590136,0.706369,0.570263,0.550364,0.551571,0.743818,0.568019,0.642014,0.580132,0.521666,0.565136,"[43032, 8378, 7302, 7696, 5438, 10330, 8276, 6800]","[40844, 3362, 4898, 3110, 6048, 8036, 16512, 14258]","[67090, 2956, 1696, 1016, 6354, 12694, 2574, 3022]","[41908, 6354, 1982, 12918, 2768, 15794, 9088, 6202]","[51428, 5758, 8040, 8244, 4586, 5878, 8318, 5078]","[43684, 4420, 5670, 6176, 14072, 13966, 4870, 4014]","[38168, 9210, 8626, 7652, 9938, 9642, 8862, 5282]","[51120, 9808, 4566, 5468, 4404, 5906, 8770, 7900]","[65166, 10368, 2506, 4848, 4748, 6562, 3484, 1386]","[36440, 7978, 7538, 13962, 14572, 7888, 4650, 4212]",n
9,2,single,yes,"{'epochs': 75, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'CyclicLR', 'lr_scheduler_hparams': {'base_lr': 1e-07, 'max_lr': 0.001, 'step_size_up': 50, 'step_size_down': None, 'mode': 'triangular', 'cycle_momentum': False}}",0.996431,0.995991,0.996211,0.709019,0.716973,0.491747,0.473887,0.491747,0.652992,0.62623,0.639611,0.017223,0.073236,0.660144,0.62586,0.661206,0.626749,0.624097,0.742451,0.60914,0.614546,0.66274,0.801409,0.622619,0.621077,0.585088,0.545317,0.591726,"[41349, 9469, 5210, 9590, 5803, 10619, 9143, 6069]","[37060, 4211, 2475, 1805, 10566, 11821, 12975, 16155]","[56629, 4186, 5231, 4385, 8075, 6273, 6299, 6324]","[47119, 8846, 5968, 4902, 6268, 9855, 8812, 5244]","[53032, 7977, 7167, 8271, 3994, 6180, 6706, 4003]","[46397, 5147, 4691, 6809, 7651, 10073, 11881, 4223]","[39878, 8133, 8943, 6423, 10118, 7564, 11789, 4532]","[58284, 9226, 6306, 3983, 5022, 2858, 8093, 4170]","[74421, 5471, 3941, 2283, 4063, 3345, 4055, 1489]","[45078, 4215, 5158, 7160, 11889, 14052, 4732, 4956]",n


In [7]:
df[['depth', 'pretrain','training_config', 'test accuracy steady avg2folds']]

Unnamed: 0,depth,pretrain,training_config,test accuracy steady avg2folds
0,8,no,"{'epochs': 75, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'CyclicLR', 'lr_scheduler_hparams': {'base_lr': 1e-07, 'max_lr': 0.001, 'step_size_up': 50, 'step_size_down': None, 'mode': 'triangular', 'cycle_momentum': False}}",0.603985
1,8,no,"{'epochs': 75, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'CyclicLR', 'lr_scheduler_hparams': {'base_lr': 1e-07, 'max_lr': 0.001, 'step_size_up': 50, 'step_size_down': None, 'mode': 'triangular', 'cycle_momentum': False}}",0.610272
2,8,yes_nonfixed,"{'epochs': 75, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'CyclicLR', 'lr_scheduler_hparams': {'base_lr': 1e-07, 'max_lr': 0.001, 'step_size_up': 50, 'step_size_down': None, 'mode': 'triangular', 'cycle_momentum': False}}",0.620513
3,8,yes_nonfixed,"{'epochs': 75, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'CyclicLR', 'lr_scheduler_hparams': {'base_lr': 1e-07, 'max_lr': 0.001, 'step_size_up': 50, 'step_size_down': None, 'mode': 'triangular', 'cycle_momentum': False}}",0.64173
4,8,yes_nonfixed,"{'epochs': 75, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'CyclicLR', 'lr_scheduler_hparams': {'base_lr': 1e-07, 'max_lr': 0.001, 'step_size_up': 50, 'step_size_down': None, 'mode': 'triangular', 'cycle_momentum': False}}",0.642652
5,8,yes,"{'epochs': 20, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0.0001, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'StepLR', 'lr_scheduler_hparams': {'step_size': 10, 'gamma': 0.1}}",0.635074
6,8,yes,"{'epochs': 20, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0.0001, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'StepLR', 'lr_scheduler_hparams': {'step_size': 10, 'gamma': 0.1}}",0.630556
7,8,yes,"{'epochs': 20, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0.0001, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'StepLR', 'lr_scheduler_hparams': {'step_size': 10, 'gamma': 0.1}}",0.628716
8,2,yes,"{'epochs': 20, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'StepLR', 'lr_scheduler_hparams': {'step_size': 10, 'gamma': 0.1}}",0.599935
9,2,yes,"{'epochs': 75, 'batch_size': 64, 'optim': 'AdamW', 'optim_hparams': {'lr': 0, 'betas': (0.9, 0.999), 'weight_decay': 0.01}, 'lr_scheduler': 'CyclicLR', 'lr_scheduler_hparams': {'base_lr': 1e-07, 'max_lr': 0.001, 'step_size_up': 50, 'step_size_down': None, 'mode': 'triangular', 'cycle_momentum': False}}",0.639611
