# imports

In [None]:
import os
import pickle

import numpy as np
from tqdm import tqdm

from matplotlib import pyplot as plt

import torch
import numpy as np
from glob import glob

from matplotlib.pyplot import cm

In [None]:
def check_key_name(key):
    """Decide whether a state-dict entry takes part in the flattened weight vector.

    Args:
        key: name of an entry in a model ``state_dict``.

    Returns:
        ``True`` when ``key`` contains none of the excluded substrings listed
        below, ``False`` otherwise.
    """
    # Presumably BatchNorm buffers, the final linear layer's weight and an
    # averaging counter -- confirm against the model definition.
    excluded_markers = (
        '.running_var',
        '.num_batches_tracked',
        '.running_mean',
        'linear.weight',
        'n_averaged',
    )
    return not any(marker in key for marker in excluded_markers)

def make_flatten_vec(state_dict, layer=None):
    """Concatenate state-dict tensors into one flat ``float64`` vector.

    Args:
        state_dict: mapping from entry name to tensor.
        layer: when given, only ``state_dict[layer]`` is flattened; when
            ``None``, every entry accepted by ``check_key_name`` is used,
            in the iteration order of ``state_dict``.

    Returns:
        A 1-D ``torch.float64`` tensor.
    """
    if layer is not None:
        pieces = [torch.flatten(state_dict[layer])]
    else:
        pieces = [
            torch.flatten(tensor)
            for name, tensor in state_dict.items()
            if check_key_name(name)
        ]
    return torch.cat(pieces, 0).to(torch.float64)

In [None]:
def unit_vector(vector):
    """Scale ``vector`` by the inverse of its norm (``np.linalg.norm``)."""
    length = np.linalg.norm(vector)
    return vector / length

def angle_between(v1, v2):
    """Angle, in radians, between the vectors ``v1`` and ``v2``.

    Both inputs are normalised with ``unit_vector``; the dot product of the
    results is clipped to ``[-1, 1]`` before ``arccos`` so that rounding
    error cannot push it outside the function's domain.

    Examples::

            >>> angle_between((1, 0, 0), (0, 1, 0))
            1.5707963267948966
            >>> angle_between((1, 0, 0), (1, 0, 0))
            0.0
            >>> angle_between((1, 0, 0), (-1, 0, 0))
            3.141592653589793
    """
    cosine = np.dot(unit_vector(v1), unit_vector(v2))
    cosine = np.clip(cosine, -1.0, 1.0)
    return np.arccos(cosine)

def get_init_angle_dist(point_a:str, point_b:str):
    """Angle in radians between the flattened weights of two checkpoints.

    Args:
        point_a: path to the first checkpoint file.
        point_b: path to the second checkpoint file.

    Each file is read with ``torch.load`` and must hold a ``'state_dict'``
    entry; the entries accepted by ``make_flatten_vec`` are flattened and
    compared with ``angle_between``.

    Returns:
        The value returned by ``angle_between`` for the two weight vectors.
    """
    # map_location='cpu' lets checkpoints written from a GPU be read on any
    # machine and keeps the weights off the GPU: only CPU math follows.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    sd1 = torch.load(point_a, map_location='cpu')['state_dict']
    sd2 = torch.load(point_b, map_location='cpu')['state_dict']

    # angle_between is NumPy-based, so hand it real arrays instead of relying
    # on implicit tensor -> array conversion.
    vec1 = make_flatten_vec(sd1).detach().cpu().numpy()
    vec2 = make_flatten_vec(sd2).detach().cpu().numpy()
    angle = angle_between(vec1, vec2)
    return angle

# Linear interpolation (LI) with barrier

In [None]:
def read_angle_dist_from_track(checkpoint_dir: str, n_steps: int=20):
    """Angle between the two endpoint checkpoints of an interpolation run.

    Args:
        checkpoint_dir: directory holding the ``interp_result_*`` files.
        n_steps: accepted for signature parity with ``read_li_track``;
            not used here.

    Returns:
        The result of ``get_init_angle_dist`` for the files written at
        interpolation coefficients 0.0 and 1.0.
    """
    path_template = os.path.join(checkpoint_dir, 'interp_result_{:5.4f}-{}.pt')
    # Only the two ends of the interpolation path are needed.
    start_path, end_path = [
        path_template.format(coef, int(coef)) for coef in (0.0, 1.0)
    ]
    return get_init_angle_dist(point_a=start_path, point_b=end_path)

In [None]:
def read_li_track(checkpoint_dir: str, n_steps: int=20):
    """Load the per-step records of a linear-interpolation run.

    Args:
        checkpoint_dir: directory holding the ``interp_result_*`` files.
        n_steps: number of interpolation intervals; ``n_steps + 1`` files
            are read, one per coefficient in ``linspace(0, 1, n_steps + 1)``.

    Returns:
        A list with one dict per coefficient, in increasing order. Each dict
        is the loaded file content with ``'state_dict'`` removed and two
        entries added: ``'pnorm'`` (norm of the flattened weights) and
        ``'alpha'`` (the interpolation coefficient).
    """
    # The path pattern does not depend on alpha, so build it once.
    base = os.path.join(checkpoint_dir, 'interp_result_{:5.4f}-{}.pt')

    track = []
    for alpha in np.linspace(0.0, 1.0, n_steps + 1):
        pt_path = base.format(alpha, int(alpha))
        # map_location='cpu' lets GPU-saved files be read on any machine and
        # keeps the weights of every step off the GPU.
        data = torch.load(pt_path, map_location='cpu')

        # Only the weight norm is kept; the weights themselves are dropped so
        # the returned track stays small.
        state_dict = data.pop('state_dict')
        flat_weights = make_flatten_vec(state_dict).detach().cpu().numpy()
        data['pnorm'] = np.linalg.norm(flat_weights)

        data['alpha'] = alpha

        track.append(data)
    return track

In [None]:
def track_to_barrier(track_values: list, barrier_is_higher: bool=True):
    """Height of the barrier along an interpolation track.

    The reference is the straight line joining the first and last values of
    the track, sampled at evenly spaced coefficients. The barrier is the
    largest one-sided deviation from that line, never less than zero.

    Args:
        track_values: metric values along the track, endpoints included.
        barrier_is_higher: ``True`` to measure how far the track rises above
            the line, ``False`` to measure how far it drops below it.

    Returns:
        The maximum clipped deviation.
    """
    values = np.array(track_values)
    start, end = values[0], values[-1]

    weights = np.linspace(0.0, 1.0, len(values))
    baseline = (1.0 - weights) * start + weights * end

    deviation = values - baseline if barrier_is_higher else baseline - values
    return deviation.clip(min=0.0).max()

# general

In [None]:
# Grid of ELR values swept by the experiments, in increasing order
# (ELR presumably stands for effective learning rate -- confirm).
USUAL_ELRS = [
    1e-6, 2e-6, 5e-6,
    1e-5, 1.4e-5, 2e-5, 3e-5, 5e-5, 7e-5,
    1e-4, 1.4e-4, 2e-4, 3e-4, 5e-4, 7e-4,
    1e-3, 1.4e-3, 2e-3, 3e-3, 5e-3, 7e-3,
    1e-2, 1.4e-2, 2e-2, 3e-2, 5e-2, 7e-2,
    1e-1, 2e-1, 5e-1,
    1e+0, 2e+0,
]

# One seed per ELR, paired by position: consecutive integers 2000..2031.
USUAL_ESEEDS = list(range(2000, 2032))

# Separate, shorter grid of rates.
EDLRS = [
    1e-5, 3e-5, 5e-5, 7e-5,
    1e-4, 1.4e-4, 2e-4, 2.5e-4, 3e-4,
]

# DROP(HIGH) -> DROP(LOW)

In [None]:
# Print (not run) the shell commands for the DROP(HIGH) -> DROP(LOW) setup:
# for every (ELR, seed) pair, one interpolation command followed by one
# gradient-norm command over the directory the first command saves to.
interp_cmd = """python ./../drops_with_clean_cifar10/linear_interpolation_resnet18si_cifar10_clean.py \\
    --gpu {} --elr {} --n_interp 20 --recalc_bn 1 \\
    --point_a ./../drops_with_clean_cifar10/Experiments/FIXEDINIT_DROP_ResNet18SI_CIFAR10_elri_{}_elrd_0.0003_dropepochfrom_200_wd_0.0_seed_{}_noaug_True/checkpoint-400.pt \\
    --point_b ./../drops_with_clean_cifar10/Experiments/FIXEDINIT_DROP_ResNet18SI_CIFAR10_elri_{}_elrd_1e-05_dropepochfrom_200_wd_0.0_seed_{}_noaug_True/checkpoint-400.pt \\
    --save ./Experiments/CONNECTIVITY_RN18C10_lri_{}_from_drop_0.0003_to_drop_1e-05_seed_{}/ && \\"""
grad_cmd = """python ./../drops_with_clean_cifar10/calc_grad_norms_resnet18si_cifar_clean.py \\
    --gpu {} \\
    --directory_with_checkpoints ./Experiments/CONNECTIVITY_RN18C10_lri_{}_from_drop_0.0003_to_drop_1e-05_seed_{}/ \\
    --train_mode 1 && \\"""

for i, (elr, seed) in enumerate(zip(USUAL_ELRS, USUAL_ESEEDS)):
    # Six consecutive runs share the same --gpu value.
    gp = i // 6

    print(interp_cmd.format(gp, elr, elr, seed, elr, seed, elr, seed))
    print(grad_cmd.format(gp, elr, seed))

    # Visually separate the batches of six.
    if (i + 1) % 6 == 0:
        print('\n\n')

In [None]:
# Collect the interpolation results for DROP(HIGH) -> DROP(LOW), keyed by ELR:
# endpoint angle, the full track, and the barriers of four metrics.
interp_drophigh_droplow = dict()
# The directory pattern is the same for every run, so define it once.
base = './Experiments/CONNECTIVITY_RN18C10_lri_{}_from_drop_0.0003_to_drop_1e-05_seed_{}'
# zip() has no length, so tqdm needs an explicit total to draw a bar and ETA.
n_runs = min(len(USUAL_ELRS), len(USUAL_ESEEDS))
for elr, seed in tqdm(zip(USUAL_ELRS, USUAL_ESEEDS), total=n_runs):
    pt_path = base.format(elr, seed)

    entry = dict()
    entry['angle'] = read_angle_dist_from_track(pt_path)

    track = read_li_track(pt_path)
    entry['track'] = track

    # Losses: the barrier is how far the track rises above the straight line.
    loss_barrier = track_to_barrier([x['loss_trainmode_train'] for x in track], barrier_is_higher=True)
    entry['loss_barrier'] = loss_barrier

    lossts_barrier = track_to_barrier([x['test_res']['loss'] for x in track], barrier_is_higher=True)
    entry['lossts_barrier'] = lossts_barrier

    # Accuracies: the barrier is how far the track drops below the straight line.
    testacc_barrier = track_to_barrier([x['test_res']['accuracy'] for x in track], barrier_is_higher=False)
    entry['testacc_barrier'] = testacc_barrier

    # Own name for the train-accuracy barrier so loss_barrier is not clobbered.
    trainacc_barrier = track_to_barrier([x['acc_trainmode_train'] for x in track], barrier_is_higher=False)
    entry['trainacc_barrier'] = trainacc_barrier

    interp_drophigh_droplow[elr] = entry

In [None]:
# interp_drophigh_droplow[1e-6]

# SWA(5) -> DROP(HIGH)

In [None]:
# Print (not run) the shell commands for the SWA(5) -> DROP(HIGH) setup:
# for every (ELR, seed) pair, one interpolation command followed by one
# gradient-norm command over the directory the first command saves to.
interp_cmd = """python ./../drops_with_clean_cifar10/linear_interpolation_resnet18si_cifar10_clean.py \\
    --gpu {} --elr {} --n_interp 20 --recalc_bn 1 \\
    --point_a ./../drops_with_clean_cifar10/Experiments/SWA_K_100_stride_1_ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepoch_1001_wd_0.0/swa_start_200_k_001/checkpoint-204.pt \\
    --point_b ./../drops_with_clean_cifar10/Experiments/FIXEDINIT_DROP_ResNet18SI_CIFAR10_elri_{}_elrd_0.0003_dropepochfrom_200_wd_0.0_seed_{}_noaug_True/checkpoint-400.pt \\
    --save ./Experiments/CONNECTIVITY_RN18C10_lri_{}_from_swa_5_to_drop_0.0003_seed_{}/ && \\"""
grad_cmd = """python ./../drops_with_clean_cifar10/calc_grad_norms_resnet18si_cifar_clean.py \\
    --gpu {} \\
    --directory_with_checkpoints ./Experiments/CONNECTIVITY_RN18C10_lri_{}_from_swa_5_to_drop_0.0003_seed_{}/ \\
    --train_mode 1 && \\"""

for i, (elr, seed) in enumerate(zip(USUAL_ELRS, USUAL_ESEEDS)):
    # Six consecutive runs share the same --gpu value.
    gp = i // 6

    # point_a is addressed by ELR only (twice); point_b and --save by ELR and seed.
    print(interp_cmd.format(gp, elr, elr, elr, elr, seed, elr, seed))
    print(grad_cmd.format(gp, elr, seed))

    # Visually separate the batches of six.
    if (i + 1) % 6 == 0:
        print('\n\n')

In [None]:
# Collect the interpolation results for SWA(5) -> DROP(HIGH), keyed by ELR:
# endpoint angle, the full track, and the barriers of four metrics.
interp_swa5_drophigh = dict()
# The directory pattern is the same for every run, so define it once.
base = './Experiments/CONNECTIVITY_RN18C10_lri_{}_from_swa_5_to_drop_0.0003_seed_{}'
# zip() has no length, so tqdm needs an explicit total to draw a bar and ETA.
n_runs = min(len(USUAL_ELRS), len(USUAL_ESEEDS))
for elr, seed in tqdm(zip(USUAL_ELRS, USUAL_ESEEDS), total=n_runs):
    pt_path = base.format(elr, seed)

    entry = dict()
    entry['angle'] = read_angle_dist_from_track(pt_path)

    track = read_li_track(pt_path)
    entry['track'] = track

    # Losses: the barrier is how far the track rises above the straight line.
    loss_barrier = track_to_barrier([x['loss_trainmode_train'] for x in track], barrier_is_higher=True)
    entry['loss_barrier'] = loss_barrier

    lossts_barrier = track_to_barrier([x['test_res']['loss'] for x in track], barrier_is_higher=True)
    entry['lossts_barrier'] = lossts_barrier

    # Accuracies: the barrier is how far the track drops below the straight line.
    testacc_barrier = track_to_barrier([x['test_res']['accuracy'] for x in track], barrier_is_higher=False)
    entry['testacc_barrier'] = testacc_barrier

    # Own name for the train-accuracy barrier so loss_barrier is not clobbered.
    trainacc_barrier = track_to_barrier([x['acc_trainmode_train'] for x in track], barrier_is_higher=False)
    entry['trainacc_barrier'] = trainacc_barrier

    interp_swa5_drophigh[elr] = entry

# SWA(5) -> DROP(LOW)

In [None]:
# Print (not run) the shell commands for the SWA(5) -> DROP(LOW) setup:
# for every (ELR, seed) pair, one interpolation command followed by one
# gradient-norm command over the directory the first command saves to.
interp_cmd = """python ./../drops_with_clean_cifar10/linear_interpolation_resnet18si_cifar10_clean.py \\
    --gpu {} --elr {} --n_interp 20 --recalc_bn 1 \\
    --point_a ./../drops_with_clean_cifar10/Experiments/SWA_K_100_stride_1_ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepoch_1001_wd_0.0/swa_start_200_k_001/checkpoint-204.pt \\
    --point_b ./../drops_with_clean_cifar10/Experiments/FIXEDINIT_DROP_ResNet18SI_CIFAR10_elri_{}_elrd_1e-05_dropepochfrom_200_wd_0.0_seed_{}_noaug_True/checkpoint-400.pt \\
    --save ./Experiments/CONNECTIVITY_RN18C10_lri_{}_from_swa_5_to_drop_1e-05_seed_{}/ && \\"""
grad_cmd = """python ./../drops_with_clean_cifar10/calc_grad_norms_resnet18si_cifar_clean.py \\
    --gpu {} \\
    --directory_with_checkpoints ./Experiments/CONNECTIVITY_RN18C10_lri_{}_from_swa_5_to_drop_1e-05_seed_{}/ \\
    --train_mode 1 && \\"""

for i, (elr, seed) in enumerate(zip(USUAL_ELRS, USUAL_ESEEDS)):
    # Six consecutive runs share the same --gpu value.
    gp = i // 6

    # point_a is addressed by ELR only (twice); point_b and --save by ELR and seed.
    print(interp_cmd.format(gp, elr, elr, elr, elr, seed, elr, seed))
    print(grad_cmd.format(gp, elr, seed))

    # Visually separate the batches of six.
    if (i + 1) % 6 == 0:
        print('\n\n')

In [None]:
# Collect the interpolation results for SWA(5) -> DROP(LOW), keyed by ELR:
# endpoint angle, the full track, and the barriers of four metrics.
interp_swa5_droplow = dict()
# The directory pattern is the same for every run, so define it once.
base = './Experiments/CONNECTIVITY_RN18C10_lri_{}_from_swa_5_to_drop_1e-05_seed_{}'
# zip() has no length, so tqdm needs an explicit total to draw a bar and ETA.
n_runs = min(len(USUAL_ELRS), len(USUAL_ESEEDS))
for elr, seed in tqdm(zip(USUAL_ELRS, USUAL_ESEEDS), total=n_runs):
    pt_path = base.format(elr, seed)

    entry = dict()
    entry['angle'] = read_angle_dist_from_track(pt_path)

    track = read_li_track(pt_path)
    entry['track'] = track

    # Losses: the barrier is how far the track rises above the straight line.
    loss_barrier = track_to_barrier([x['loss_trainmode_train'] for x in track], barrier_is_higher=True)
    entry['loss_barrier'] = loss_barrier

    lossts_barrier = track_to_barrier([x['test_res']['loss'] for x in track], barrier_is_higher=True)
    entry['lossts_barrier'] = lossts_barrier

    # Accuracies: the barrier is how far the track drops below the straight line.
    testacc_barrier = track_to_barrier([x['test_res']['accuracy'] for x in track], barrier_is_higher=False)
    entry['testacc_barrier'] = testacc_barrier

    # Own name for the train-accuracy barrier so loss_barrier is not clobbered.
    trainacc_barrier = track_to_barrier([x['acc_trainmode_train'] for x in track], barrier_is_higher=False)
    entry['trainacc_barrier'] = trainacc_barrier

    interp_swa5_droplow[elr] = entry

# dump to the disk

In [None]:
# Bundle the three interpolation setups under their variable names so they
# can be stored and reloaded as one object.
barrier_setups = {
    'interp_drophigh_droplow': interp_drophigh_droplow,
    'interp_swa5_drophigh': interp_swa5_drophigh,
    'interp_swa5_droplow': interp_swa5_droplow,
}

In [None]:
# Write the bundled results to a single pickle file next to the notebook.
with open('./resnet18si_cifar10_barrier_setups.pkl', mode='wb') as pkl_file:
    pickle.dump(barrier_setups, pkl_file)