In [1]:
# Automatically reload imported modules before each cell runs, so edits to project code are picked up.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Resolve all project paths from the kernel's current working directory.
cwd = os.getcwd()
NOTEBOOK_DIR = os.path.dirname(cwd)  # parent of the working directory

# ROOT sits three directory levels above NOTEBOOK_DIR (presumably the repository root).
ROOT = NOTEBOOK_DIR
for _level in range(3):
    ROOT = os.path.dirname(ROOT)
del _level

# Where figures are written, and the YAML experiment configuration to load.
FIGURES_DIR = os.path.join(ROOT, 'figures/abc_parameterizations/debug_ipllr_renorm')
CONFIG_PATH = os.path.join(ROOT, 'pytorch/configs/abc_parameterizations', 'fc_ipllr_mnist.yaml')

In [3]:
import sys

# Make the repository root importable (the `utils.*` and `pytorch.*` imports below rely on it).
# Guarded so re-running this cell does not append duplicate entries to sys.path.
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [4]:
import os
from copy import deepcopy
import torch
import math
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, Subset, DataLoader
import torch.nn.functional as F

from utils.tools import read_yaml, set_random_seeds
from pytorch.configs.base import BaseConfig
from pytorch.configs.model import ModelConfig
from pytorch.models.abc_params.fully_connected.ipllr import FcIPLLR
from pytorch.models.abc_params.fully_connected.muP import FCmuP
from pytorch.models.abc_params.fully_connected.ntk import FCNTK
from pytorch.models.abc_params.fully_connected.standard_fc_ip import StandardFCIP
from utils.data.mnist import load_data
from utils.abc_params.debug_ipllr import *
from utils.plot.abc_parameterizations.debug_ipllr import *

### Load basic configuration and define variables 

In [5]:
# Experiment settings.
N_TRIALS = 5           # number of models built and trained per variant
SEED = 30              # global seed, applied below
L = 6                  # depth parameter: the model config gets n_layers = L + 1
width = 1024           # written to config['architecture']['width']
n_warmup_steps = 1     # warm-up steps of the 'warmup_switch' scheduler
batch_size = 512       # batch size for both train and test loaders
base_lr = 0.1          # base optimizer learning rate written to the config
n_steps = 100          # training steps per model

# These two flags are referenced only by the plotting cells (labels / file names).
# They do NOT select which variants are trained: every calibrated variant is trained below.
renorm_first = False
scale_first_lr = False

# Seed before any model is built or any data is shuffled (both happen in later cells),
# so run the remaining cells in order for reproducible results.
# The YAML config is loaded in the next cell; reading it here as well was redundant.
set_random_seeds(SEED)

In [6]:
# Load the base YAML config and override it with this experiment's settings.
config_dict = read_yaml(CONFIG_PATH)
input_size = config_dict['architecture']['input_size']

arch_cfg = config_dict['architecture']
arch_cfg['width'] = width
arch_cfg['n_layers'] = L + 1

config_dict['optimizer']['params']['lr'] = base_lr

# Warm-up scheduler with base-LR calibration switched on by default.
config_dict['scheduler'] = {
    'name': 'warmup_switch',
    'params': {
        'n_warmup_steps': n_warmup_steps,
        'calibrate_base_lr': True,
        'default_calibration': False,
    },
}

base_model_config = ModelConfig(config_dict)

### Load data & define models

In [7]:
# Load flattened MNIST and materialize the batches once, so every model sees identical data.
training_dataset, test_dataset = load_data(download=False, flatten=True)
train_data_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)

# NOTE(review): keep this statement order. Iterating a DataLoader may draw from torch's global RNG,
# so materializing the test batches before the training batches can affect the training shuffle.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
test_batches = list(test_loader)
batches = list(train_data_loader)

eval_batch = test_batches[0]  # first (unshuffled) test batch

In [8]:
# Build N_TRIALS models per variant. Each variant is built from a *copy* of config_dict,
# so this cell no longer toggles the shared config in place and can be re-run safely.

# Reference models (calibrate_base_lr=False). Below, they serve only as the source of
# initial parameters for the calibrated variants.
ref_config_dict = deepcopy(config_dict)
ref_config_dict['scheduler']['params']['calibrate_base_lr'] = False
config = ModelConfig(ref_config_dict)
ipllrs = [FcIPLLR(config) for _ in range(N_TRIALS)]

# Calibrated variants. Each constructor receives the calibration batches; the printed
# "initial base lr" output suggests calibration runs at construction time (confirm).
calib_config_dict = deepcopy(config_dict)
calib_config_dict['scheduler']['params']['calibrate_base_lr'] = True
config = ModelConfig(calib_config_dict)
ipllrs_calib = [FcIPLLR(config, lr_calibration_batches=batches) for _ in range(N_TRIALS)]
ipllrs_calib_renorm = [FcIPLLR(config, lr_calibration_batches=batches) for _ in range(N_TRIALS)]
ipllrs_calib_renorm_scale_lr = [FcIPLLR(config, lr_calibration_batches=batches) for _ in range(N_TRIALS)]

initial base lr : [78.5, 38.08141326904297, 68.91138458251953, 68.3167724609375, 74.03038024902344, 102.55865478515625, 30.983808517456055]
initial base lr : [78.5, 37.010562896728516, 66.8042984008789, 72.8503189086914, 78.61959075927734, 90.54930877685547, 29.0798397064209]
initial base lr : [78.5, 36.512081146240234, 58.382625579833984, 65.4113540649414, 75.37689971923828, 83.77328491210938, 22.76885414123535]
initial base lr : [78.5, 42.51643371582031, 72.11700439453125, 76.09538269042969, 83.21681213378906, 115.44151306152344, 35.31344223022461]
initial base lr : [78.5, 40.96342086791992, 71.92338562011719, 69.65662384033203, 77.32463836669922, 93.46656036376953, 26.484838485717773]
initial base lr : [78.5, 36.689453125, 71.21871948242188, 70.73763275146484, 75.81897735595703, 87.43914794921875, 28.48207664489746]
initial base lr : [78.5, 33.913116455078125, 59.712738037109375, 63.80599594116211, 63.03622055053711, 69.04557800292969, 19.315895080566406]
initial base lr : [78.5, 35

In [9]:
# Give all calibrated variants of trial i the same starting point as ipllrs[i]:
# first copy the reference model's initial parameters into each variant, then re-initialize them.
# NOTE(review): the semantics of copy_initial_params_from_model / initialize_params live in the
# project code. Presumably initialize_params re-initializes from the copied initial params — confirm.
for source_model, *variants in zip(ipllrs, ipllrs_calib, ipllrs_calib_renorm, ipllrs_calib_renorm_scale_lr):
    for variant in variants:
        variant.copy_initial_params_from_model(source_model)
    for variant in variants:
        variant.initialize_params()

In [10]:
# Re-calibrate the per-layer base learning rates of every calibrated variant using the same
# normalize_first setting the variant is trained with later (False for ipllrs_calib, True for
# both renorm variants). Then push the rates into the optimizer via the scheduler.
calibration_plan = (
    (ipllrs_calib, False),
    (ipllrs_calib_renorm, True),
    (ipllrs_calib_renorm_scale_lr, True),
)
for models, normalize in calibration_plan:
    for ipllr in models:
        initial_base_lrs = ipllr.scheduler.calibrate_base_lr(ipllr, batches=batches, normalize_first=normalize)
        ipllr.scheduler._set_param_group_lrs(initial_base_lrs)

initial base lr : [0.1, 0.06003998592495918, 1.2169116735458374, 2.17669939994812, 2.4937546253204346, 2.8725359439849854, 0.8645324110984802]
initial base lr : [0.1, 0.0596611388027668, 1.2741634845733643, 2.544586658477783, 2.7952044010162354, 3.329960584640503, 1.0341224670410156]
initial base lr : [0.1, 0.059218455106019974, 1.2441296577453613, 2.4515788555145264, 2.6010076999664307, 3.0965986251831055, 0.8844329118728638]
initial base lr : [0.1, 0.060484614223241806, 1.2854281663894653, 2.2207565307617188, 2.7277324199676514, 3.3855533599853516, 0.9613930583000183]
initial base lr : [0.1, 0.06183946132659912, 1.1814137697219849, 2.329699754714966, 2.2961533069610596, 2.620152473449707, 0.7345650792121887]
initial base lr : [78.5, 35.073184967041016, 59.817752838134766, 61.449581146240234, 69.81730651855469, 80.47264099121094, 24.22148323059082]
initial base lr : [78.5, 36.027191162109375, 66.83576965332031, 72.37831115722656, 78.34403991699219, 93.29310607910156, 28.97234916687011

In [11]:
# For the "scale_lr" variant, multiply the first layer's warm-up learning rate by (d + 1),
# where d is read from each model.
# NOTE(review): not idempotent. Re-running this cell rescales warm_lrs[0] again.
# NOTE(review): the non-calibrated version of this step (previously kept here as commented-out
# code) also rescaled optimizer.param_groups[0]['lr']. Only scheduler.warm_lrs[0] is rescaled
# here — confirm that is intended.
for scaled_model in ipllrs_calib_renorm_scale_lr:
    first_layer_factor = scaled_model.d + 1
    scaled_model.scheduler.warm_lrs[0] = scaled_model.scheduler.warm_lrs[0] * first_layer_factor

In [None]:
# Train every calibrated variant (N_TRIALS models each) for n_steps on the pre-materialized
# training batches. Each row is (result key, models, normalize_first).
# NOTE(review): presumably collect_training_losses trains the model in place. If so, re-running
# this cell continues training from the current weights instead of restarting.
training_plan = (
    ('ipllr_calib', ipllrs_calib, False),
    ('ipllr_calib_renorm', ipllrs_calib_renorm, True),
    ('ipllr_calib_renorm_scale_lr', ipllrs_calib_renorm_scale_lr, True),
)

results = {}
for variant_name, variant_models, renorm in training_plan:
    results[variant_name] = [
        collect_training_losses(variant_models[trial], batches, n_steps, normalize_first=renorm)
        for trial in range(N_TRIALS)
    ]

# Training

In [None]:
# Label passed to the plotting helper and embedded in saved figure file names;
# the curves plotted below were collected on the training batches.
mode = 'training'

In [None]:
# Split each variant's per-trial results into loss curves (item 0) and chi curves (item 1).
# NOTE(review): assumes collect_training_losses returns a sequence of (losses, chis, ...) — confirm.
losses, chis = {}, {}
for key, res in results.items():
    losses[key] = [trial_result[0] for trial_result in res]
    chis[key] = [trial_result[1] for trial_result in res]

## Losses and derivatives

In [None]:
# Plot the training-loss curves of every variant and save the figure.
key = 'loss'
plt.figure(figsize=(12, 8))
# NOTE(review): renorm_first is a single global flag, while `losses` mixes renormalized and
# non-renormalized variants — confirm what plot_losses_models does with normalize_first.
plot_losses_models(losses, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size, mode=mode,
                   normalize_first=renorm_first, marker=None, name='IPLLR')

# The format string has six placeholders. renorm_first / scale_first_lr used to be passed as extra
# arguments that str.format silently ignored, so they are dropped here; the file name is unchanged.
fig_name = 'IPLLRs_1_last_{}_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, L, width, base_lr, batch_size)
os.makedirs(FIGURES_DIR, exist_ok=True)  # savefig does not create missing directories
plt.savefig(os.path.join(FIGURES_DIR, fig_name))
plt.show()

In [None]:
# Plot the chi curves of every variant and save the figure.
key = 'chi'
plt.figure(figsize=(12, 8))
plot_losses_models(chis, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size, mode=mode, marker=None,
                   name='IPLLR')

fig_name = 'IPLLRs_1_last_{}_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, L, width, base_lr, batch_size)
os.makedirs(FIGURES_DIR, exist_ok=True)  # savefig does not create missing directories
plt.savefig(os.path.join(FIGURES_DIR, fig_name))
plt.show()