In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Resolve every path from the kernel's current working directory.
cwd = os.getcwd()
NOTEBOOK_DIR = os.path.dirname(cwd)

# ROOT is three directory levels above NOTEBOOK_DIR.
ROOT = NOTEBOOK_DIR
for _level in range(3):
    ROOT = os.path.dirname(ROOT)

# Locations of the figures directory and of the YAML configuration file.
FIGURES_DIR = os.path.join(ROOT, 'figures/abc_parameterizations/initialization')
_config_dir = os.path.join(ROOT, 'pytorch/configs/abc_parameterizations')
CONFIG_PATH = os.path.join(_config_dir, 'fc_ipllr_mnist.yaml')

In [3]:
import sys

# Make the project root importable. The membership check keeps the cell idempotent:
# re-running it no longer appends a duplicate entry to sys.path.
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [4]:
import os
from copy import deepcopy
import torch
import math
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, Subset, DataLoader
import torch.nn.functional as F

from utils.tools import read_yaml, set_random_seeds
from pytorch.configs.base import BaseConfig
from pytorch.configs.model import ModelConfig
from pytorch.models.abc_params.fully_connected.ipllr import FcIPLLR
from pytorch.models.abc_params.fully_connected.muP import FCmuP
from pytorch.models.abc_params.fully_connected.ntk import FCNTK
from pytorch.models.abc_params.fully_connected.standard_fc_ip import StandardFCIP
from utils.data.mnist import load_data
from utils.abc_params.debug_ipllr import *

### Load basic configuration and define variables 

In [5]:
# Experiment constants used by the cells below.
N_TRIALS = 1
SEED = 30
L = 6                # the configured number of layers is set to L + 1
width = 1024         # written to config_dict['architecture']['width']
n_warmup_steps = 1   # written to the scheduler parameters
batch_size = 512     # batch size of the train / test data loaders
base_lr = 0.1        # written to the optimizer parameters as 'lr'
n_steps = 50

set_random_seeds(SEED)  # set random seed for reproducibility

# The YAML configuration is read at the top of the next cell, right before it is
# modified, so it is not read a second time here.

In [6]:
# Re-read the YAML file, then apply the notebook-level overrides to it.
config_dict = read_yaml(CONFIG_PATH)

architecture = config_dict['architecture']
input_size = architecture['input_size']
architecture['width'] = width
architecture['n_layers'] = L + 1

config_dict['optimizer']['params']['lr'] = base_lr

# Scheduler section: replaced wholesale by a 'warmup_switch' entry.
scheduler_params = {'n_warmup_steps': n_warmup_steps,
                    'calibrate_base_lr': True,
                    'default_calibration': False}
config_dict['scheduler'] = {'name': 'warmup_switch', 'params': scheduler_params}

base_model_config = ModelConfig(config_dict)

### Load data

In [7]:
# Load the flattened datasets (no download) and materialise every batch as a list.
training_dataset, test_dataset = load_data(download=False, flatten=True)

train_data_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)

# The test loader is iterated before the (shuffled) train loader, as in the original cell.
# NOTE(review): the iteration order may influence the shuffled batches through the global RNG;
# keep this order unless that has been checked.
test_batches = [batch for batch in DataLoader(test_dataset, batch_size=batch_size, shuffle=False)]
batches = [batch for batch in train_data_loader]

# First test batch, kept aside under its own name.
eval_batch = test_batches[0]

### Define model

In [8]:
# Build the FcIPLLR model from the base configuration; the training batches are handed over
# as `lr_calibration_batches` (presumably used to calibrate the initial base learning rates
# printed below — verify in FcIPLLR).
# NOTE(review): n_warmup_steps=12 differs from the n_warmup_steps = 1 value written into the
# scheduler config earlier in the notebook — confirm which one is intended.
ipllr = FcIPLLR(base_model_config, n_warmup_steps=12, lr_calibration_batches=batches)

initial base lr : [0.1, 47.14387893676758, 62.35230255126953, 61.66318893432617, 69.83497619628906, 80.47408294677734, 2.422165632247925]


In [9]:
# Multiply the first warm-up learning rate by (d + 1).
# Caution: this updates the scheduler in place, so running the cell twice applies the factor twice.
warm_lr_0 = ipllr.scheduler.warm_lrs[0]
ipllr.scheduler.warm_lrs[0] = warm_lr_0 * (ipllr.d + 1)

### Save initial model : t=0

In [10]:
# Snapshot of the model before any training step (t=0); later cells read its weights as W0 / b0.
ipllr_0 = deepcopy(ipllr)

### Train model one step : t=1

In [11]:
# Step t=1: run one training step on the first training batch, then keep a deep copy of the
# updated model as ipllr_1 so it can be compared with the other snapshots.
x, y = batches[0]
train_model_one_step(ipllr, x, y, normalize_first=True)
ipllr_1 = deepcopy(ipllr)

input abs mean in training:  0.6950533986091614
loss derivatives for model: tensor([[-0.9000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000, -0.9000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        ...,
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000, -0.9000,  0.1000]])
average training loss for model1 : 2.3025991916656494



### Train model for a second step : t=2

In [12]:
# Step t=2: run one more training step, this time on the second training batch, then keep a
# deep copy of the updated model as ipllr_2.
x, y = batches[1]
train_model_one_step(ipllr, x, y, normalize_first=True)
ipllr_2 = deepcopy(ipllr)

input abs mean in training:  0.6921874284744263
loss derivatives for model: tensor([[ 0.0820,  0.1149,  0.1018,  ...,  0.0950, -0.8772,  0.1078],
        [ 0.0744,  0.1220,  0.1021,  ...,  0.0924,  0.1345,  0.1112],
        [ 0.0787,  0.1179,  0.1020,  ...,  0.0939,  0.1277,  0.1093],
        ...,
        [ 0.0799,  0.1168,  0.1019,  ...,  0.0944,  0.1258,  0.1087],
        [ 0.0795, -0.8829,  0.1019,  ...,  0.0942,  0.1264,  0.1089],
        [ 0.0726,  0.1238,  0.1022,  ...,  0.0917, -0.8625,  0.1120]])
average training loss for model1 : 2.3087775707244873



In [13]:
# Put the live model and its three snapshots in evaluation mode.
# A loop produces no cell output, so no trailing print() is needed to hide the repr
# returned by the last .eval() call.
for snapshot in (ipllr, ipllr_0, ipllr_1, ipllr_2):
    snapshot.eval()




In [14]:
# Per-layer scales, read from the model after the two training steps.
layer_scales = ipllr.layer_scales
# Attribute names of intermediate layers 2..L inside `intermediate_layers`.
intermediate_layer_keys = list(map("layer_{:,}_intermediate".format, range(2, L + 1)))

### Define W0 and b0

In [15]:
# Scaled weights of the t=0 snapshot, keyed by 1-based layer index (1 = input, L+1 = output).
W0 = {}

# Input layer: additionally divided by sqrt(d + 1).
W0[1] = layer_scales[0] * ipllr_0.input_layer.weight.data.detach() / math.sqrt(ipllr_0.d + 1)

# Intermediate layers 2..L.
for l, key in enumerate(intermediate_layer_keys, start=2):
    layer = getattr(ipllr_0.intermediate_layers, key)
    W0[l] = layer_scales[l-1] * layer.weight.data.detach()

# Output layer.
W0[L+1] = layer_scales[L] * ipllr_0.output_layer.weight.data.detach()

In [16]:
# Scaled input-layer bias of the t=0 snapshot, with the same sqrt(d + 1) division as W0[1].
initial_bias = ipllr_0.input_layer.bias.data.detach()
b0 = layer_scales[0] * initial_bias / math.sqrt(ipllr_0.d + 1)

### Define Delta_W_1 and Delta_b_1

In [17]:
# Scaled weight change made by the first training step (ipllr_1 minus ipllr_0), keyed by layer.
input_weight_diff = (ipllr_1.input_layer.weight.data.detach() -
                     ipllr_0.input_layer.weight.data.detach())
Delta_W_1 = {1: layer_scales[0] * input_weight_diff / math.sqrt(ipllr_1.d + 1)}

# Intermediate layers 2..L.
for l, key in enumerate(intermediate_layer_keys, start=2):
    layer_1 = getattr(ipllr_1.intermediate_layers, key)
    layer_0 = getattr(ipllr_0.intermediate_layers, key)
    Delta_W_1[l] = layer_scales[l-1] * (layer_1.weight.data.detach() -
                                        layer_0.weight.data.detach())

# Output layer.
output_weight_diff = (ipllr_1.output_layer.weight.data.detach() -
                      ipllr_0.output_layer.weight.data.detach())
Delta_W_1[L+1] = layer_scales[L] * output_weight_diff

In [18]:
# Scaled change of the input-layer bias made by the first training step (ipllr_1 minus ipllr_0).
input_bias_diff = (ipllr_1.input_layer.bias.data.detach() -
                   ipllr_0.input_layer.bias.data.detach())
Delta_b_1 = layer_scales[0] * input_bias_diff / math.sqrt(ipllr_1.d + 1)

### Define Delta_W_2

In [19]:
# Scaled weight change made by the second training step (ipllr_2 minus ipllr_1), keyed by layer.
input_weight_diff = (ipllr_2.input_layer.weight.data.detach() -
                     ipllr_1.input_layer.weight.data.detach())
Delta_W_2 = {1: layer_scales[0] * input_weight_diff / math.sqrt(ipllr_2.d + 1)}

# Intermediate layers 2..L.
for l, key in enumerate(intermediate_layer_keys, start=2):
    layer_2 = getattr(ipllr_2.intermediate_layers, key)
    layer_1 = getattr(ipllr_1.intermediate_layers, key)
    Delta_W_2[l] = layer_scales[l-1] * (layer_2.weight.data.detach() -
                                        layer_1.weight.data.detach())

# Output layer.
output_weight_diff = (ipllr_2.output_layer.weight.data.detach() -
                      ipllr_1.output_layer.weight.data.detach())
Delta_W_2[L+1] = layer_scales[L] * output_weight_diff

In [20]:
# Scaled change of the input-layer bias made by the second training step (ipllr_2 minus ipllr_1).
# The normalisation uses ipllr_2.d, matching Delta_W_2[1] and h2[1], which are both built
# from the step-2 model.
Delta_b_2 = layer_scales[0] * (ipllr_2.input_layer.bias.data.detach() -
                               ipllr_1.input_layer.bias.data.detach()) / math.sqrt(ipllr_2.d + 1)

## Explore at step 2

### On examples from the second batch

In [21]:
# Inputs and targets of the second training batch (the batch used for the t=2 training step).
x, y = batches[1]

In [22]:
with torch.no_grad():
    # Layer inputs of the step-2 model, keyed by layer index; index 0 is the raw batch.
    x2 = {0: x}

    # Layer-1 pre-activation contributions computed from the stored weight pieces.
    h0 = {1: F.linear(x, W0[1], b0)}
    delta_h_1 = {1: F.linear(x, Delta_W_1[1], Delta_b_1)}
    delta_h_2 = {1: F.linear(x, Delta_W_2[1], Delta_b_2)}

    # Same layer evaluated directly with the step-1 and step-2 models.
    h1, h2 = {}, {}
    h1[1] = layer_scales[0] * ipllr_1.input_layer.forward(x) / math.sqrt(ipllr_1.d + 1)
    h2[1] = layer_scales[0] * ipllr_2.input_layer.forward(x) / math.sqrt(ipllr_2.d + 1)

    # Activations fed to layer 2.
    x2[1] = ipllr_2.activation(h2[1])

In [23]:
# Sanity checks for layer 1: the initial term plus the first update must reproduce the
# step-1 pre-activations, and adding the second update must reproduce the step-2 ones.
# NOTE(review): torch.testing.assert_allclose is deprecated in recent PyTorch releases in favour
# of torch.testing.assert_close — consider migrating once the minimum torch version allows it.
torch.testing.assert_allclose(h0[1] + delta_h_1[1], h1[1], rtol=1e-5, atol=1e-5)
torch.testing.assert_allclose(h0[1] + delta_h_1[1] + delta_h_2[1], h2[1], rtol=1e-5, atol=1e-5)

In [24]:
# Element-wise product of the two layer-1 updates; it is negative where their signs differ.
prod_1 = torch.mul(delta_h_1[1], delta_h_2[1])

In [25]:
# Fraction of strictly negative entries in prod_1.
torch.sum(prod_1 < 0) / prod_1.numel()

tensor(0.4852)

In [29]:
# Fraction of strictly negative entries in the first update at layer 1.
torch.sum(delta_h_1[1] < 0) / delta_h_1[1].numel()

tensor(0.4914)

In [30]:
# Fraction of strictly negative entries in the second update at layer 1.
torch.sum(delta_h_2[1] < 0) / delta_h_2[1].numel()

tensor(0.6582)

In [31]:
# Mean absolute value of h0 at layer 1 for the first sample of the batch.
torch.mean(torch.abs(h0[1][0, :]))

tensor(0.9147)

In [32]:
# Mean absolute value of h1 at layer 1 for the first sample of the batch.
torch.mean(torch.abs(h1[1][0, :]))

tensor(0.9146)

In [33]:
# Mean absolute value of h2 at layer 1 for the first sample of the batch.
torch.mean(torch.abs(h2[1][0, :]))

tensor(1.1195)

In [34]:
# Propagate the step-2 activations through intermediate layers 2..L. For every layer the
# pre-activations are computed twice: from the stored weight pieces (h0, delta_h_1,
# delta_h_2) and directly from the step-1 / step-2 models (h1, h2); the two must agree.
with torch.no_grad():
    for i, l in enumerate(range(2, L + 1)):
        layer_1 = getattr(ipllr_1.intermediate_layers, intermediate_layer_keys[i])
        layer_2 = getattr(ipllr_2.intermediate_layers, intermediate_layer_keys[i])
        # Dedicated name for the layer input: the global batch input `x` is no longer
        # overwritten, so the layer-1 cell above stays correct if it is re-run.
        layer_input = x2[l-1]

        h0[l] = F.linear(layer_input, W0[l])
        delta_h_1[l] = F.linear(layer_input, Delta_W_1[l])
        delta_h_2[l] = F.linear(layer_input, Delta_W_2[l])

        h1[l] = layer_scales[l-1] * layer_1.forward(layer_input)
        h2[l] = layer_scales[l-1] * layer_2.forward(layer_input)
        # Input of the next layer: activation of the step-2 pre-activations.
        x2[l] = ipllr_2.activation(h2[l])

        torch.testing.assert_allclose(h0[l] + delta_h_1[l], h1[l], rtol=1e-5, atol=1e-5)
        torch.testing.assert_allclose(h0[l] + delta_h_1[l] + delta_h_2[l], h2[l], rtol=1e-5, atol=1e-5)

In [35]:
with torch.no_grad():
    out_idx = L + 1
    # Input of the output layer: activations of layer L under the step-2 model.
    # Note that this rebinds the global `x` (previously the raw batch input).
    x = x2[L]

    # Direct evaluation with the step-1 and step-2 models.
    h1[out_idx] = layer_scales[L] * ipllr_1.output_layer.forward(x)
    h2[out_idx] = layer_scales[L] * ipllr_2.output_layer.forward(x)
    x2[out_idx] = ipllr_2.activation(h2[out_idx])

    # Contributions computed from the stored weight pieces.
    h0[out_idx] = F.linear(x, W0[out_idx])
    delta_h_1[out_idx] = F.linear(x, Delta_W_1[out_idx])
    delta_h_2[out_idx] = F.linear(x, Delta_W_2[out_idx])

    # Both routes must give the same pre-activations.
    torch.testing.assert_allclose(h0[out_idx] + delta_h_1[out_idx], h1[out_idx], rtol=1e-5, atol=1e-5)
    torch.testing.assert_allclose(h0[out_idx] + delta_h_1[out_idx] + delta_h_2[out_idx], h2[out_idx],
                                  rtol=1e-5, atol=1e-5)

##### Signs

In [36]:
# Element-wise product of the two layer-2 updates; it is negative where their signs differ.
prod_1 = torch.mul(delta_h_1[2], delta_h_2[2])

In [37]:
# Fraction of strictly negative entries in prod_1.
torch.sum(prod_1 < 0) / prod_1.numel()

tensor(0.4418)

In [41]:
# Fraction of strictly negative entries in the second update at layer 1.
torch.sum(delta_h_2[1] < 0) / delta_h_2[1].numel()

tensor(0.6582)

In [42]:
# Fraction of strictly negative entries in the second update at layer 2.
torch.sum(delta_h_2[2] < 0) / delta_h_2[2].numel()

tensor(0.5192)

In [43]:
# Fraction of strictly negative entries in the second update at layer 3.
torch.sum(delta_h_2[3] < 0) / delta_h_2[3].numel()

tensor(0.4844)

In [44]:
# Fraction of strictly negative entries in the second update at layer 4.
torch.sum(delta_h_2[4] < 0) / delta_h_2[4].numel()

tensor(0.4092)

In [45]:
# Fraction of strictly negative entries in the second update at layer 5.
# The count is divided by the size of the same layer-5 tensor (the denominator previously
# used the layer-4 tensor, a copy-paste slip).
(delta_h_2[5] < 0).sum() / delta_h_2[5].numel()

tensor(0.3379)

In [46]:
# Fraction of strictly negative entries in the second update at layer 6.
torch.sum(delta_h_2[6] < 0) / delta_h_2[6].numel()

tensor(0.3008)

In [47]:
# Fraction of strictly negative entries in the second update at layer 7.
torch.sum(delta_h_2[7] < 0) / delta_h_2[7].numel()

tensor(0.6000)

In [48]:
# Display the second-update contribution at layer 7 (= L + 1 with L = 6, i.e. the entry
# computed from the output-layer weights).
delta_h_2[7]

tensor([[ 0.0057, -0.0036,  0.0013,  ...,  0.0008, -0.0018, -0.0012],
        [ 0.0018, -0.0011,  0.0004,  ...,  0.0003, -0.0006, -0.0004],
        [ 0.0033, -0.0021,  0.0008,  ...,  0.0005, -0.0011, -0.0007],
        ...,
        [ 0.0017, -0.0011,  0.0004,  ...,  0.0002, -0.0005, -0.0004],
        [ 0.0083, -0.0052,  0.0019,  ...,  0.0012, -0.0026, -0.0017],
        [ 0.0044, -0.0027,  0.0010,  ...,  0.0006, -0.0014, -0.0009]])

##### Outputs

In [111]:
# For every sample of the batch, index of the largest output-layer value produced by the
# initial weights. Only the indices are kept: unpacking into `_` would shadow IPython's
# last-output variable of the same name.
b = torch.max(h0[L + 1], dim=1).indices
b

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,

In [112]:
# For every sample of the batch, index of the largest output-layer value contributed by the
# first update. Only the indices are kept: unpacking into `_` would shadow IPython's
# last-output variable of the same name.
b = torch.max(delta_h_1[L + 1], dim=1).indices
b

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,

In [110]:
# For every sample of the batch, index of the largest output-layer value contributed by the
# second update. Only the indices are kept: unpacking into `_` would shadow IPython's
# last-output variable of the same name.
b = torch.max(delta_h_2[L + 1], dim=1).indices
b

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

##### Scales

In [138]:
# Per-layer scale table: mean absolute value, over the first sample of the batch, of each
# pre-activation term. The frame is built in one shot from plain floats so its columns are
# numeric, instead of filling an empty (object-dtype) frame row by row with .loc.
columns = ['h0', 'delta_h_1', 'delta_h_2', 'h1', 'h2']
# Pre-activation dictionaries, listed in the same order as `columns`.
pre_activations = [h0, delta_h_1, delta_h_2, h1, h2]
layers = range(1, L + 2)

df = pd.DataFrame(
    {name: [values[l][0, :].abs().mean().item() for l in layers]
     for name, values in zip(columns, pre_activations)},
    index=layers,
    columns=columns,
)
df.index.name = 'layer'
df

Unnamed: 0_level_0,h0,delta_h_1,delta_h_2,h1,h2
layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,0.91465,0.00114543,0.299482,0.914639,1.11946
2,0.0376886,1.24705,0.000515771,1.24928,1.24934
3,0.0471406,1.25427,0.00141231,1.27666,1.27563
4,0.048591,1.25596,0.00152903,1.28043,1.27898
5,0.0473902,1.25069,0.00182065,1.27336,1.27197
6,0.0482783,1.24478,0.00317893,1.26475,1.2637
7,0.132329,0.012407,0.00220172,0.144206,0.14315


In [135]:
# Layer 2, first sample: fraction of coordinates where h1 and the second update have opposite signs.
prod = torch.mul(h1[2][0, :], delta_h_2[2][0, :])
torch.sum(prod < 0) / prod.numel()

tensor(0.4209)

In [136]:
# Layer 3, first sample: fraction of coordinates where h1 and the second update have opposite signs.
prod = torch.mul(h1[3][0, :], delta_h_2[3][0, :])
torch.sum(prod < 0) / prod.numel()

tensor(0.4688)

In [137]:
# Layer 4, first sample: fraction of coordinates where h1 and the second update have opposite signs.
prod = torch.mul(h1[4][0, :], delta_h_2[4][0, :])
torch.sum(prod < 0) / prod.numel()

tensor(0.4072)

In [132]:
# Layer 5, all samples of the batch: fraction of entries where h1 and the second update have
# opposite signs.
prod = torch.mul(h1[5], delta_h_2[5])
torch.sum(prod < 0) / prod.numel()

tensor(0.3381)

In [133]:
# Layer 6, all samples of the batch: fraction of entries where h1 and the second update have
# opposite signs.
prod = torch.mul(h1[6], delta_h_2[6])
torch.sum(prod < 0) / prod.numel()

tensor(0.3008)

In [141]:
# Layer 7, sample at row 132: fraction of coordinates where h1 and the second update have
# opposite signs.
prod = torch.mul(h1[7][132, :], delta_h_2[7][132, :])
torch.sum(prod < 0) / prod.numel()

tensor(0.6000)

In [144]:
# Layer 7, sample at row 1: fraction of strictly negative h1 values.
row = h1[7][1, :]
torch.sum(row < 0) / row.numel()

tensor(0.6000)