# ResNet-18 SI + CIFAR10

## usual training

### Run training

In [None]:
# ELR grid for the "usual" runs, laid out one decade per row
# (presumably ELR = effective learning rate -- confirm against the training scripts).
USUAL_ELRS = [
    1e-6, 2e-6, 5e-6,
    1e-5, 1.4e-5, 2e-5, 3e-5, 5e-5, 7e-5,
    1e-4, 1.4e-4, 2e-4, 3e-4, 5e-4, 7e-4,
    1e-3, 1.4e-3, 2e-3, 3e-3, 5e-3, 7e-3,
    1e-2, 1.4e-2, 2e-2, 3e-2, 5e-2, 7e-2,
    1e-1, 2e-1, 5e-1,
    1e+0, 2e+0,
]

# One seed per entry of USUAL_ELRS (paired by position): 2000, 2001, ..., 2031.
USUAL_ESEEDS = list(range(2000, 2032))

# ELR values used as the `--drop_elr` target in the "Drops" section.
EDLRS = [
    1e-5, 3e-5, 5e-5, 7e-5,
    1e-4, 1.4e-4, 2e-4, 2.5e-4, 3e-4,
]

In [None]:
# Lookup table: ELR value -> seed of the run trained with that ELR
# (USUAL_ELRS and USUAL_ESEEDS are paired by position).
ELR2SEED = dict(zip(USUAL_ELRS, USUAL_ESEEDS))

In [None]:
# Print shell commands that launch one training run per (ELR, seed) pair.
# Runs are grouped into chains of `split` consecutive runs; every run in a
# chain gets the same --gpu index and the runs are joined with `&&`, so a
# chain executes sequentially. Chains are separated by two blank lines.
split = 5

template = """python train_drop_resnet18si_cifar10_clean.py \\
    --gpu {} \\
    --init_elr {} --drop_elr {} \\
    --drop_epoch {} \\
    --seed {}"""

runs = list(zip(USUAL_ELRS, USUAL_ESEEDS))
for i, (elr, seed) in enumerate(runs):
    gp = i // split
    # init and drop ELR are the same value here; the drop epoch is set to 1000.
    cmd = template.format(
        gp,
        elr, elr,
        1000,
        seed,
    )
    # The last command of a chain must NOT end with `&& \`: a dangling `&&`
    # makes the shell wait for another command (or, when the whole output is
    # run as a script, glues this chain to the next GPU's chain).
    closes_chain = (i % split == split - 1) or (i == len(runs) - 1)
    if closes_chain:
        print(cmd)
        print()
        print()
    else:
        print(cmd + " && \\")

### Calculate gradients

In [None]:
# Print shell commands that run the gradient-norm script over the checkpoint
# directory of every "usual" run. Runs are grouped into chains of `split`
# consecutive runs joined with `&&`; chains are separated by two blank lines.
split = 6

template = """python calc_grad_norms_resnet18si_cifar_clean.py \\
    --gpu {} \\
    --directory_with_checkpoints ./Experiments/ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepoch_1000_wd_0.0_seed_{}_noaug_True/ \\
    --loader train \\
    --aug 0 \\
    --train_mode 1"""

runs = list(zip(USUAL_ELRS, USUAL_ESEEDS))
for i, (elr, seed) in enumerate(runs):
    # GPU indices start at 1 in this cell (chain 0 -> GPU 1, chain 1 -> GPU 2, ...).
    gp = i // split + 1
    cmd = template.format(
        gp,
        elr, elr, seed,
    )
    # The last command of a chain must NOT end with `&& \`: a dangling `&&`
    # makes the shell wait for another command (or chains this group to the
    # next one when the whole output is executed as a script).
    closes_chain = (i % split == split - 1) or (i == len(runs) - 1)
    if closes_chain:
        print(cmd)
        print()
        print()
    else:
        print(cmd + " && \\")

## Drops

### Run training

In [None]:
# Print shell commands for the "drop" runs: for every usual (ELR, seed) run,
# one continuation per value in EDLRS, each starting from that run's
# checkpoint-200.pt. `split` consecutive ELRs (with all their EDLRS variants)
# form one chain that shares a --gpu index and is joined with `&&`; chains are
# separated by two blank lines.
split = 5

template = """python train_drop_resnet18si_cifar10_clean_from_starting_point.py \\
        --gpu {} \\
        --init_checkpoint ./Experiments/ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepoch_1000_wd_0.0_seed_{}_noaug_True/checkpoint-200.pt \\
        --init_elr {} --drop_elr {} \\
        --drop_epoch {} \\
        --k_epoch {} \\
        --seed {}"""

runs = list(zip(USUAL_ELRS, USUAL_ESEEDS))
chain = []
for i, (elr, seed) in enumerate(runs):
    gp = i // split
    for edlr in EDLRS:
        chain.append(template.format(
            gp,
            elr, elr, seed,   # checkpoint directory of the source run
            elr, edlr,        # --init_elr / --drop_elr
            200,              # --drop_epoch
            250,              # --k_epoch
            seed,
        ))
    closes_chain = (i % split == split - 1) or (i == len(runs) - 1)
    if closes_chain:
        if chain:
            # Joining (instead of appending `&& \` to every command) keeps the
            # final command of the chain free of a dangling `&&`, which would
            # otherwise leave the shell waiting for a further command.
            print(" && \\\n".join(chain))
        print()
        print()
        chain = []

## SWA

In [None]:
# Print shell commands that launch the SWA script for every usual (ELR, seed)
# run. `split` consecutive runs form one `&&`-joined chain sharing a --gpu
# index; chains are separated by a blank line.
#
# The same `split` drives both the GPU index and the chain boundary.
# Previously the GPU index used groups of 5 while chains were cut every 6
# commands, so a single sequential chain spanned two GPUs and overlapped with
# the neighbouring chain. The per-run GPU assignment (i // 5) is kept as-is.
split = 5
step = 200  # NOTE(review): not referenced by the command below; start epoch is hard-coded in the string
k_epoch = 100

runs = list(zip(USUAL_ELRS, USUAL_ESEEDS))
for i, (blr, sd) in enumerate(runs):
    gp = i // split
    txt = "python custom_swa_resnet18si_cifar_starting_point_clean.py --gpu {} --elr {} \\\n".format(gp, blr) + \
          "   --k_epoch {} --seed {} \\\n".format(k_epoch, sd) + \
          "   --stride 1 --start_swa_epoch 200"
    # Also close the chain on the very last run so the output never ends
    # with a dangling `&& \`.
    closes_chain = (i % split == split - 1) or (i == len(runs) - 1)
    if closes_chain:
        print(txt)
        print()
    else:
        print(txt + " && \\")

# readings

In [None]:
import os
import pickle
from glob import glob

import torch
import numpy as np

from tqdm import tqdm
from matplotlib import pyplot as plt
from matplotlib.pyplot import cm
from matplotlib.lines import Line2D

In [None]:
def check_key_name(key):
    """Return True when `key` contains none of the excluded substrings.

    Used to filter state-dict entries: keys mentioning running statistics,
    the batch counter, `linear.weight` or `n_averaged` are rejected.
    """
    excluded_fragments = (
        '.running_var',
        '.num_batches_tracked',
        '.running_mean',
        'linear.weight',
        'n_averaged',
    )
    return not any(fragment in key for fragment in excluded_fragments)


def make_flatten_vec(state_dict, layer=None):
    """Flatten entries of `state_dict` into a single 1-D float64 tensor.

    Args:
        state_dict: mapping from entry name to tensor.
        layer: if given, only `state_dict[layer]` is flattened; otherwise
            every entry whose name passes `check_key_name` is flattened and
            the pieces are concatenated in iteration order.

    Returns:
        The concatenated vector, converted to `torch.float64`.
    """
    if layer is not None:
        pieces = [torch.flatten(state_dict[layer])]
    else:
        pieces = [
            torch.flatten(tensor)
            for name, tensor in state_dict.items()
            if check_key_name(name)
        ]
    return torch.cat(pieces, 0).to(torch.float64)

In [None]:
# Collect per-epoch metrics of every "usual" run:
#   usual_tracks[elr] -> list of records, one per checkpoint 0..1000.

# Checkpoint path; placeholders are (ELR, ELR, seed, checkpoint index).
usual_ckpt_template = './Experiments/ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepoch_1000_wd_0.0_seed_{}_noaug_True/checkpoint-{}.pt'

# (record key, checkpoint key) pairs copied only when present in the checkpoint.
usual_optional_fields = (
    ('gnorm_trainmode', 'gnorm_trainmode_m_train'),
    ('loss_trainmode_train', 'loss_trainmode_train'),
    ('acc_trainmode_train', 'acc_trainmode_train'),
    ('gnorm_evalmode', 'gnorm_evalmode_m_train'),
    ('loss_evalmode_train', 'loss_evalmode_train'),
    ('acc_evalmode_train', 'acc_evalmode_train'),
)

usual_tracks = dict()
for elr, seed in zip(USUAL_ELRS, USUAL_ESEEDS):
    track = []
    usual_tracks[elr] = track

    for ckpt in tqdm(range(1001)):
        data = torch.load(usual_ckpt_template.format(elr, elr, seed, ckpt))

        record = {
            'ep': ckpt,
            'train_loss': data['train_res']['loss'],
            'train_accuracy': data['train_res']['accuracy'],
            'test_loss': data['test_res']['loss'],
            'test_accuracy': data['test_res']['accuracy'],
            'elr': elr,
            # Norm of the flattened, filtered state-dict entries.
            'pnorm': np.linalg.norm(make_flatten_vec(data['state_dict']).cpu()),
        }
        for record_key, ckpt_key in usual_optional_fields:
            if ckpt_key in data:
                record[record_key] = data[ckpt_key]

        track.append(record)

In [None]:
# Collect per-epoch metrics of the "drop" runs:
#   drop_checkpoints[elr][drop_start][edlr] -> list of records for
#   checkpoints drop_start+1 .. drop_start+200 (empty if no run directory
#   was found for that combination).

# Run directory; placeholders are (initial ELR, dropped ELR, drop start epoch, seed).
pthmsk = './Experiments/FIXEDINIT_DROP_ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepochfrom_{}_wd_0.0_seed_{}_noaug_True/'

# (record key, checkpoint key) pairs copied only when present in the checkpoint.
drop_optional_fields = (
    ('gnorm_trainmode', 'gnorm_trainmode_m_train'),
    ('loss_trainmode_train', 'loss_trainmode_train'),
    ('acc_trainmode_train', 'acc_trainmode_train'),
    ('gnorm_evalmode', 'gnorm_evalmode_m_train'),
    ('loss_evalmode_train', 'loss_evalmode_train'),
    ('acc_evalmode_train', 'acc_evalmode_train'),
)

drop_checkpoints = dict()
for elr, seed in zip(USUAL_ELRS, USUAL_ESEEDS):
    print('-'*80)
    print(elr, seed)
    print('-'*80)

    per_elr = drop_checkpoints.setdefault(elr, dict())

    for drop_start in [200]:
        per_start = per_elr.setdefault(drop_start, dict())

        for edlr in EDLRS:
            records = []
            per_start[edlr] = records

            # The seed slot is wildcarded, so several directories may match.
            # glob() returns matches in filesystem-dependent order; sorting
            # makes the choice of "last match" reproducible across machines.
            # NOTE(review): `seed` is known here -- consider formatting it in
            # instead of '*' if drop runs always reuse the source run's seed.
            matches = sorted(
                x for x in glob(pthmsk.format(elr, edlr, drop_start, '*'))
                if 'noaug_False' not in x
            )
            if not matches:
                continue

            run_dir = matches[-1]
            print(glob(run_dir))

            for ckpt in tqdm(range(drop_start + 1, drop_start + 201)):
                data = torch.load(run_dir + 'checkpoint-{}.pt'.format(ckpt))

                record = {
                    'ep': ckpt,
                    'train_loss': data['train_res']['loss'],
                    'train_accuracy': data['train_res']['accuracy'],
                    'test_loss': data['test_res']['loss'],
                    'test_accuracy': data['test_res']['accuracy'],
                    # ELR of the source run (the dropped ELR is the dict key `edlr`).
                    'elr': elr,
                    'pnorm': np.linalg.norm(make_flatten_vec(data['state_dict']).cpu()),
                }
                for record_key, ckpt_key in drop_optional_fields:
                    if ckpt_key in data:
                        record[record_key] = data[ckpt_key]

                records.append(record)

In [None]:
# Collect metrics of the SWA checkpoints:
#   swa_checkpoints[elr][start_epoch][k] -> one record, read from
#   checkpoint number start_epoch + k - 1 of the SWA run.

# Checkpoint path; placeholders are (ELR, ELR, SWA start epoch, checkpoint index).
swa_ckpt_template = './Experiments/SWA_K_100_stride_1_ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepoch_1001_wd_0.0/swa_start_{:03d}_k_001/checkpoint-{}.pt'

# (record key, checkpoint key) pairs copied only when present in the checkpoint.
swa_optional_fields = (
    ('gnorm_trainmode', 'gnorm_trainmode_m_train'),
    ('loss_trainmode_train', 'loss_trainmode_train'),
    ('acc_trainmode_train', 'acc_trainmode_train'),
    ('gnorm_evalmode', 'gnorm_evalmode_m_train'),
    ('loss_evalmode_train', 'loss_evalmode_train'),
    ('acc_evalmode_train', 'acc_evalmode_train'),
)

swa_checkpoints = dict()

for elr in tqdm(USUAL_ELRS):
    per_elr = dict()
    swa_checkpoints[elr] = per_elr
    for start_epoch in [200]:
        per_start = dict()
        per_elr[start_epoch] = per_start
        for k in [2, 5, 10, 50, 100]:
            ckpt_index = start_epoch + k - 1
            data = torch.load(swa_ckpt_template.format(elr, elr, start_epoch, ckpt_index))

            record = {
                # Stored epoch is start_epoch + k, one more than the file index.
                'ep': start_epoch + k,
                'train_loss': data['train_res']['loss'],
                'train_accuracy': data['train_res']['accuracy'],
                'test_loss': data['test_res']['loss'],
                'test_accuracy': data['test_res']['accuracy'],
                'pnorm': np.linalg.norm(make_flatten_vec(data['state_dict']).cpu()),
            }
            for record_key, ckpt_key in swa_optional_fields:
                if ckpt_key in data:
                    record[record_key] = data[ckpt_key]

            per_start[k] = record

In [None]:
# Persist the collected SWA metrics next to the notebook.
swa_metrics_path = 'resnet18si_cifar10_swa_metrics.pkl'
with open(swa_metrics_path, 'wb') as swa_file:
    pickle.dump(swa_checkpoints, swa_file)

# saving to pickle

In [None]:
# Persist the usual-run and drop-run metrics, one pickle file each
# (written in this order: usual first, then drop).
metrics_to_save = (
    ('./resnet18si_cifar10_usual_metrics.pkl', usual_tracks),
    ('./resnet18si_cifar10_drop_metrics.pkl', drop_checkpoints),
)
for out_path, metrics in metrics_to_save:
    with open(out_path, 'wb') as out_file:
        pickle.dump(metrics, out_file)