In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Directory the kernel is currently running in.
cwd = os.getcwd()

# Parent of the working directory.
NOTEBOOK_DIR = os.path.dirname(cwd)

# ROOT sits three directory levels above NOTEBOOK_DIR; the figures/ and
# pytorch/ paths below are resolved relative to it.
ROOT = NOTEBOOK_DIR
for _level in range(3):
    ROOT = os.path.dirname(ROOT)

FIGURES_DIR = os.path.join(ROOT, 'figures/abc_parameterizations/initialization')
CONFIG_PATH = os.path.join(ROOT, 'pytorch/configs/abc_parameterizations', 'fc_ipllr_mnist.yaml')

In [3]:
import sys
# Make packages under ROOT importable (the `utils` / `pytorch` imports below rely on it).
# Guarded so that re-running this cell does not append duplicate entries.
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [4]:
import os
from copy import deepcopy
import torch
import math
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, Subset, DataLoader
import torch.nn.functional as F

from utils.tools import read_yaml, set_random_seeds
from pytorch.configs.base import BaseConfig
from pytorch.configs.model import ModelConfig
from pytorch.models.abc_params.fully_connected.ipllr import FcIPLLR
from pytorch.models.abc_params.fully_connected.muP import FCmuP
from pytorch.models.abc_params.fully_connected.ntk import FCNTK
from pytorch.models.abc_params.fully_connected.standard_fc_ip import StandardFCIP
from utils.data.mnist import load_data
from utils.abc_params.debug_ipllr import *

### Load basic configuration and define variables 

In [5]:
# Experiment settings used throughout the notebook.
N_TRIALS = 1
SEED = 30
L = 6  # depth parameter: the model config is given n_layers = L + 1
width = 1024  # written into config_dict['architecture']['width']
n_warmup_steps = 1  # passed to the 'warmup_switch' scheduler parameters
batch_size = 512  # batch size of the train / test DataLoaders
base_lr = 0.001  # written into the optimizer params as 'lr'
n_steps = 50

set_random_seeds(SEED)  # set random seed for reproducibility
# The YAML configuration is read in the next cell, right before it is modified,
# so it is not read a second time here.

In [6]:
# Read the YAML configuration; reading it in this cell means the overrides
# below are always applied to a freshly loaded dict.
config_dict = read_yaml(CONFIG_PATH)

input_size = config_dict['architecture']['input_size']

# Override architecture and optimizer settings with the notebook-level variables.
config_dict['architecture'].update(width=width, n_layers=L + 1)
config_dict['optimizer']['params'].update(lr=base_lr)

# Set the scheduler section (overwriting any existing one) to the warm-up switch scheduler.
config_dict['scheduler'] = dict(
    name='warmup_switch',
    params=dict(n_warmup_steps=n_warmup_steps,
                calibrate_base_lr=True,
                default_calibration=False),
)

base_model_config = ModelConfig(config_dict)

### Load data

In [7]:
# load_data is imported from utils.data.mnist.
# Presumably flatten=True yields flat input vectors and download=False expects
# the dataset to already be on disk — TODO confirm against load_data.
training_dataset, test_dataset = load_data(download=False, flatten=True)
train_data_loader = DataLoader(training_dataset, shuffle=True, batch_size=batch_size)
# Materialise the loaders into lists so that individual batches can be indexed
# (batches[0], batches[1], ...) and reused across cells.
test_batches = list(DataLoader(test_dataset, shuffle=False, batch_size=batch_size))
# One full shuffled pass over the training set, frozen into a list.
batches = list(train_data_loader)
eval_batch = test_batches[0]

### Define model

In [8]:
# Build the IP-LLR model; the list of training batches is passed as lr_calibration_batches.
# NOTE(review): n_warmup_steps is hard-coded to 12 here, whereas the scheduler config
# uses the notebook-level n_warmup_steps variable (set to 1) — confirm they are meant to differ.
ipllr = FcIPLLR(base_model_config, n_warmup_steps=12, lr_calibration_batches=batches)

initial base lr : [69.26097106933594, 36.901771545410156, 60.06058120727539, 61.465023040771484, 69.81842803955078, 80.47272491455078, 242.21490478515625]


### Save initial model : t=0

In [9]:
# Independent copy of the untrained model (t=0), so later training of `ipllr` does not affect it.
ipllr_0 = deepcopy(ipllr)

### Train model one step : t=1

In [10]:
# First training batch.
x, y = batches[0]
# One training step on that batch; train_model_one_step presumably comes from the
# star import of utils.abc_params.debug_ipllr and prints the diagnostics shown below.
train_model_one_step(ipllr, x, y, batch_size)
# Snapshot of the model after the first step (t=1).
ipllr_1 = deepcopy(ipllr)

input abs mean in training:  0.6950533986091614
loss derivatives for model: tensor([[-0.9000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000, -0.9000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        ...,
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000, -0.9000,  0.1000]])
average training loss for model1 : 2.3025991916656494



### Train model for a second step : t=2

In [11]:
# Second training batch.
x, y = batches[1]
# Second training step, applied to the same live model `ipllr`.
train_model_one_step(ipllr, x, y, batch_size)
# Snapshot of the model after the second step (t=2).
ipllr_2 = deepcopy(ipllr)

input abs mean in training:  0.6921874284744263
loss derivatives for model: tensor([[ 0.0096,  0.2137,  0.0487,  ...,  0.0331, -0.5867,  0.1340],
        [ 0.0021,  0.2148,  0.0237,  ...,  0.0133,  0.5745,  0.1072],
        [ 0.0074,  0.2168,  0.0432,  ...,  0.0284,  0.4452,  0.1304],
        ...,
        [ 0.0071,  0.2172,  0.0424,  ...,  0.0277,  0.4501,  0.1298],
        [ 0.0045, -0.7809,  0.0341,  ...,  0.0210,  0.5019,  0.1219],
        [ 0.0012,  0.2074,  0.0176,  ...,  0.0093, -0.3769,  0.0953]])
average training loss for model1 : 3.235839366912842



In [12]:
# Put the live model and every saved snapshot into evaluation mode.
for snapshot in (ipllr, ipllr_0, ipllr_1, ipllr_2):
    snapshot.eval()
# Kept as the last statement of the cell (prints an empty line).
print()




In [13]:
# Per-layer scaling factors exposed by the model (indexed from 0 for layer 1).
layer_scales = ipllr.layer_scales
# Attribute names of the intermediate layers, for layer indices 2..L.
intermediate_layer_keys = list(map("layer_{:,}_intermediate".format, range(2, L + 1)))

### Define W0 and b0

In [14]:
# Scaled weights of the t=0 model, keyed by layer index:
#   1       -> input layer (additionally divided by sqrt(d + 1)),
#   2..L    -> intermediate layers,
#   L + 1   -> output layer.
W0 = {
    1: layer_scales[0] * ipllr_0.input_layer.weight.data.detach() / math.sqrt(ipllr_0.d + 1),
    **{
        l: layer_scales[l-1] * getattr(ipllr_0.intermediate_layers, key).weight.data.detach()
        for l, key in zip(range(2, L + 1), intermediate_layer_keys)
    },
    L + 1: layer_scales[L] * ipllr_0.output_layer.weight.data.detach(),
}

In [15]:
# Scaled input-layer bias of the t=0 model, using the same scaling as W0[1].
b0 = layer_scales[0] * ipllr_0.input_layer.bias.data.detach() / math.sqrt(ipllr_0.d + 1)

### Define Delta_W_1 and Delta_b_1

In [16]:
# Scaled weight change produced by the first training step (t=1 minus t=0),
# keyed by layer index 1..L+1 with the same scaling as W0.
Delta_W_1 = {
    1: layer_scales[0] * (ipllr_1.input_layer.weight.data.detach() -
                          ipllr_0.input_layer.weight.data.detach()) / math.sqrt(ipllr_1.d + 1),
    **{
        l: layer_scales[l-1] * (getattr(ipllr_1.intermediate_layers, key).weight.data.detach() -
                                getattr(ipllr_0.intermediate_layers, key).weight.data.detach())
        for l, key in zip(range(2, L + 1), intermediate_layer_keys)
    },
    L + 1: layer_scales[L] * (ipllr_1.output_layer.weight.data.detach() -
                              ipllr_0.output_layer.weight.data.detach()),
}

In [17]:
# Scaled change of the input-layer bias produced by the first training step (t=1 minus t=0).
Delta_b_1 = layer_scales[0] * (ipllr_1.input_layer.bias.data.detach() -
                               ipllr_0.input_layer.bias.data.detach()) / math.sqrt(ipllr_1.d + 1)

### Define Delta_W_2 and Delta_b_2

In [18]:
# Scaled weight change produced by the second training step (t=2 minus t=1),
# keyed by layer index 1..L+1 with the same scaling as W0.
Delta_W_2 = {
    1: layer_scales[0] * (ipllr_2.input_layer.weight.data.detach() -
                          ipllr_1.input_layer.weight.data.detach()) / math.sqrt(ipllr_2.d + 1),
    **{
        l: layer_scales[l-1] * (getattr(ipllr_2.intermediate_layers, key).weight.data.detach() -
                                getattr(ipllr_1.intermediate_layers, key).weight.data.detach())
        for l, key in zip(range(2, L + 1), intermediate_layer_keys)
    },
    L + 1: layer_scales[L] * (ipllr_2.output_layer.weight.data.detach() -
                              ipllr_1.output_layer.weight.data.detach()),
}

In [19]:
# Scaled change of the input-layer bias produced by the second training step (t=2 minus t=1).
# Normalised with ipllr_2.d, matching the normalisation used for Delta_W_2[1].
Delta_b_2 = layer_scales[0] * (ipllr_2.input_layer.bias.data.detach() -
                               ipllr_1.input_layer.bias.data.detach()) / math.sqrt(ipllr_2.d + 1)

## Explore at step 2

### On examples from the second batch

In [20]:
# Work on the second training batch (the one used for the second training step).
x, y = batches[1]

In [23]:
with torch.no_grad():
    # Layer-1 terms computed separately from the initial weights and from
    # each of the two updates.
    h0 = {1: F.linear(x, W0[1], b0)}
    delta_h_1 = {1: F.linear(x, Delta_W_1[1], Delta_b_1)}
    delta_h_2 = {1: F.linear(x, Delta_W_2[1], Delta_b_2)}

    # Scaled layer-1 outputs of the t=1 and t=2 snapshots on the same input.
    h1 = {1: layer_scales[0] * ipllr_1.input_layer.forward(x) / math.sqrt(ipllr_1.d + 1)}
    h2 = {1: layer_scales[0] * ipllr_2.input_layer.forward(x) / math.sqrt(ipllr_2.d + 1)}

    # x2[0] is the batch input, x2[1] the layer-1 activations of the t=2 model.
    x2 = {0: x, 1: ipllr_2.activation(h2[1])}

In [24]:
# Sanity checks: the separately computed terms must add up to the snapshots' layer-1 outputs.
# NOTE(review): torch.testing.assert_allclose is deprecated in recent PyTorch releases in favour of
# torch.testing.assert_close — consider migrating once the minimum torch version allows it.
torch.testing.assert_allclose(h0[1] + delta_h_1[1], h1[1], rtol=1e-5, atol=1e-5)
torch.testing.assert_allclose(h0[1] + delta_h_1[1] + delta_h_2[1], h2[1], rtol=1e-5, atol=1e-5)

In [25]:
# Element-wise product of the layer-1 terms of the two updates; negative where their signs differ.
prod_1 = delta_h_1[1] * delta_h_2[1]

In [26]:
# Fraction of entries of prod_1 that are negative.
(prod_1 < 0).sum() / prod_1.numel()

tensor(0.5108)

In [27]:
# First 5 entries of row 0 of h0[1] (presumably the first sample of the batch).
h0[1][0, :5]

tensor([ 0.4677,  0.8577,  0.4337, -1.1502, -1.4598])

In [28]:
# First 5 entries of row 0 of delta_h_1[1] (first-update term, layer 1).
delta_h_1[1][0, :5]

tensor([ 0.1314, -0.9324,  0.5104,  0.6791,  0.8944])

In [29]:
# First 5 entries of row 0 of delta_h_2[1] (second-update term, layer 1).
delta_h_2[1][0, :5]

tensor([-1.8309e-04, -1.4756e-04, -1.4610e-04, -1.1224e-04, -4.4122e-05])

In [67]:
# Fraction of negative entries in delta_h_1[1].
(delta_h_1[1] < 0).sum() / delta_h_1[1].numel()

tensor(0.4914)

In [30]:
# Fraction of negative entries in delta_h_2[1].
(delta_h_2[1] < 0).sum() / delta_h_2[1].numel()

tensor(0.9891)

In [49]:
# Mean absolute value of row 0 of h0[1].
h0[1][0, :].abs().mean()

tensor(0.9147)

In [50]:
# Mean absolute value of row 0 of h1[1] (t=1 snapshot).
h1[1][0, :].abs().mean()

tensor(1.2164)

In [51]:
# Mean absolute value of row 0 of h2[1] (t=2 snapshot).
h2[1][0, :].abs().mean()

tensor(1.2163)

In [37]:
# Repeat the layer-1 analysis for the intermediate layers 2..L: for each layer, compute
# the terms coming from the initial weights and from each update, and check that
# they add up to the scaled outputs of the t=1 and t=2 snapshots.
with torch.no_grad():
    for l, key in zip(range(2, L + 1), intermediate_layer_keys):
        layer_1 = getattr(ipllr_1.intermediate_layers, key)
        layer_2 = getattr(ipllr_2.intermediate_layers, key)
        # Input of layer l: activations of layer l-1 under the t=2 model.
        # Stored under its own name so the batch input `x` is NOT overwritten;
        # otherwise re-running earlier cells would silently use activations as input.
        x_prev = x2[l-1]

        h0[l] = F.linear(x_prev, W0[l])
        delta_h_1[l] = F.linear(x_prev, Delta_W_1[l])
        delta_h_2[l] = F.linear(x_prev, Delta_W_2[l])

        h1[l] = layer_scales[l-1] * layer_1.forward(x_prev)
        h2[l] = layer_scales[l-1] * layer_2.forward(x_prev)
        # Activations of layer l under the t=2 model, fed to the next layer.
        x2[l] = ipllr_2.activation(h2[l])

        # The three terms must reconstruct the snapshots' outputs.
        torch.testing.assert_allclose(h0[l] + delta_h_1[l], h1[l], rtol=1e-5, atol=1e-5)
        torch.testing.assert_allclose(h0[l] + delta_h_1[l] + delta_h_2[l], h2[l], rtol=1e-5, atol=1e-5)

In [38]:
# Element-wise product of the layer-2 terms of the two updates (overwrites the layer-1 prod_1).
prod_1 = delta_h_1[2] * delta_h_2[2]

In [39]:
# Fraction of entries of prod_1 (now the layer-2 product) that are negative.
(prod_1 < 0).sum() / prod_1.numel()

tensor(0.4838)

In [40]:
# First 10 entries of row 0 of h0[2] (initial-weights term, layer 2).
h0[2][0, :10]

tensor([-0.0103,  0.0352,  0.0609, -0.0158, -0.0485,  0.0913, -0.0283,  0.0037,
        -0.0649, -0.0334])

In [41]:
# First 10 entries of row 0 of delta_h_1[2] (first-update term, layer 2).
delta_h_1[2][0, :10]

tensor([-0.2592,  1.2580,  0.7624, -0.5873, -0.3028,  1.3465,  0.0032,  0.0545,
        -0.0993,  0.3867])

In [42]:
# First 10 entries of row 0 of delta_h_2[1] (second-update term, layer 1).
# NOTE(review): the neighbouring cells inspect layer 2 (h0[2], delta_h_1[2]) — confirm
# whether delta_h_2[2] was intended here.
delta_h_2[1][0, :10]

tensor([-1.8309e-04, -1.4756e-04, -1.4610e-04, -1.1224e-04, -4.4122e-05,
        -5.0859e-04, -1.7721e-04, -6.4048e-06, -4.1427e-04, -5.5448e-04])

In [68]:
# Fraction of negative entries in delta_h_2[1] (layer 1).
(delta_h_2[1] < 0).sum() / delta_h_2[1].numel()

tensor(0.9891)

In [55]:
# Fraction of negative entries in delta_h_2[2] (layer 2).
(delta_h_2[2] < 0).sum() / delta_h_2[2].numel()

tensor(0.6173)

In [56]:
# Fraction of negative entries in delta_h_2[3] (layer 3).
(delta_h_2[3] < 0).sum() / delta_h_2[3].numel()

tensor(0.5020)

In [59]:
# Fraction of negative entries in delta_h_2[4] (layer 4).
(delta_h_2[4] < 0).sum() / delta_h_2[4].numel()

tensor(0.4131)

In [60]:
# Fraction of negative entries in delta_h_2[5] (layer 5).
# The count is normalised by the size of the same tensor that is counted.
(delta_h_2[5] < 0).sum() / delta_h_2[5].numel()

tensor(0.3789)

In [61]:
# Fraction of negative entries in delta_h_2[6] (layer 6).
(delta_h_2[6] < 0).sum() / delta_h_2[6].numel()

tensor(0.3496)

In [62]:
# Fraction of negative entries in row 0 of delta_h_2[2].
(delta_h_2[2][0, :] < 0).sum() / delta_h_2[2][0, :].numel()

tensor(0.6162)

In [63]:
# Fraction of negative entries in row 0 of delta_h_2[3].
(delta_h_2[3][0, :] < 0).sum() / delta_h_2[3][0, :].numel()

tensor(0.5020)

In [64]:
# Fraction of negative entries in row 0 of delta_h_2[4].
(delta_h_2[4][0, :] < 0).sum() / delta_h_2[4][0, :].numel()

tensor(0.4131)

In [65]:
# Fraction of negative entries in row 0 of delta_h_2[5].
# The count is normalised by the size of the same slice that is counted.
(delta_h_2[5][0, :] < 0).sum() / delta_h_2[5][0, :].numel()

tensor(0.3789)

In [66]:
# Fraction of negative entries in row 0 of delta_h_2[6].
(delta_h_2[6][0, :] < 0).sum() / delta_h_2[6][0, :].numel()

tensor(0.3496)