In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# All paths are derived from the directory the kernel was started in.
cwd = os.getcwd()
NOTEBOOK_DIR = os.path.dirname(cwd)

# ROOT is three directory levels above NOTEBOOK_DIR.
ROOT = os.path.dirname(NOTEBOOK_DIR)
ROOT = os.path.dirname(ROOT)
ROOT = os.path.dirname(ROOT)

# Locations (relative to ROOT) of the figures folder and of the YAML config
# read further down.
FIGURES_DIR = os.path.join(ROOT, 'figures/abc_parameterizations/initialization')
CONFIG_PATH = os.path.join(ROOT, 'pytorch/configs/abc_parameterizations', 'fc_ipllr_mnist.yaml')

In [3]:
import sys

# Make the project's top-level packages (`utils`, `pytorch`) importable from
# ROOT. The membership check keeps the cell idempotent: re-running it no
# longer appends a duplicate entry to sys.path.
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [4]:
import os
from copy import deepcopy
import torch
import math
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, Subset, DataLoader
import torch.nn.functional as F

from utils.tools import read_yaml, set_random_seeds
from pytorch.configs.base import BaseConfig
from pytorch.configs.model import ModelConfig
from pytorch.models.abc_params.fully_connected.ipllr import FcIPLLR
from pytorch.models.abc_params.fully_connected.muP import FCmuP
from pytorch.models.abc_params.fully_connected.ntk import FCNTK
from pytorch.models.abc_params.fully_connected.standard_fc_ip import StandardFCIP
from utils.data.mnist import load_data
from utils.abc_params.debug_ipllr import *

### Load basic configuration and define variables 

In [5]:
# Experiment settings used throughout the notebook.
N_TRIALS = 1
SEED = 30
L = 6                # layers 1..L precede the output layer L+1 (config gets n_layers = L + 1)
width = 1024         # written to config_dict['architecture']['width']
n_warmup_steps = 1   # written to the scheduler section of the config
batch_size = 512     # batch size of the train / test data loaders
base_lr = 0.1        # written to config_dict['optimizer']['params']['lr']
n_steps = 50
tol = 1.0e-8         # absolute tolerance passed to the matrix-rank computations

set_random_seeds(SEED)  # set random seed for reproducibility

# The YAML configuration is intentionally NOT loaded here: the next cell
# starts with `config_dict = read_yaml(CONFIG_PATH)`, so loading it in this
# cell as well was a redundant duplicate read.

In [6]:
# Re-read the YAML file so the overrides below are always applied to a fresh
# configuration, even when this cell is executed more than once.
config_dict = read_yaml(CONFIG_PATH)

input_size = config_dict['architecture']['input_size']

# Override architecture and optimizer settings with the notebook's values.
config_dict['architecture'].update(width=width, n_layers=L + 1)
config_dict['optimizer']['params'].update(lr=base_lr)

# Replace the scheduler section wholesale.
config_dict['scheduler'] = dict(
    name='warmup_switch',
    params=dict(
        n_warmup_steps=n_warmup_steps,
        calibrate_base_lr=True,
        default_calibration=False,
    ),
)

base_model_config = ModelConfig(config_dict)

### Load data

In [7]:
# Load the datasets returned by the project's MNIST helper and wrap the
# training set in a shuffling loader.
training_dataset, test_dataset = load_data(download=False, flatten=True)
train_data_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)

# Materialise both loaders into plain lists of batches. The test loader is
# iterated first and the training loader second; this order is kept on purpose
# because iterating a loader may consume random state.
test_batches = [batch for batch in DataLoader(test_dataset, batch_size=batch_size, shuffle=False)]
batches = [batch for batch in train_data_loader]

# First test batch, kept as a fixed evaluation batch.
eval_batch = test_batches[0]

In [8]:
# Each batch is an (inputs, targets) pair: transpose the list of pairs into
# "all inputs" / "all targets" and concatenate each along the batch dimension.
full_x, full_y = (torch.cat(parts, dim=0) for parts in zip(*batches))

### Define model

In [9]:
# Instantiate the FcIPLLR model from the base configuration; the list of
# training batches is handed over as `lr_calibration_batches` (presumably used
# to calibrate the initial base learning rates printed below — TODO confirm).
# NOTE(review): n_warmup_steps is hard-coded to 12 here, whereas the notebook
# variable n_warmup_steps (= 1) is what was written into the scheduler config.
# Confirm that the mismatch is intended.
ipllr = FcIPLLR(base_model_config, n_warmup_steps=12, lr_calibration_batches=batches)

initial base lr : [78.5, 35.073184967041016, 59.817752838134766, 61.449581146240234, 69.81730651855469, 80.47264099121094, 24.22148323059082]


In [10]:
# Multiply the first warm-up learning rate by (d + 1).
# NOTE: the entry is read and then overwritten, so executing this cell a
# second time applies the factor again; re-create `ipllr` before re-running.
warm_lrs = ipllr.scheduler.warm_lrs
warm_lrs[0] = warm_lrs[0] * (ipllr.d + 1)

### Save initial model : t=0

In [11]:
# Snapshot of the model before any training step (t=0). A deep copy is taken
# so that later training of `ipllr` does not modify this reference.
ipllr_0 = deepcopy(ipllr)

### Train model one step : t=1

In [12]:
# First training step: run one step on the first training batch, then keep a
# deep copy of the updated model as the t=1 snapshot.
# NOTE: this rebinds the notebook-level names `x` and `y`.
x, y = batches[0]
train_model_one_step(ipllr, x, y, normalize_first=True)
ipllr_1 = deepcopy(ipllr)

input abs mean in training:  0.6950533986091614
loss derivatives for model: tensor([[-0.9000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000, -0.9000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        ...,
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000, -0.9000,  0.1000]])
average training loss for model1 : 2.3025991916656494



### Train model for a second step : t=2

In [13]:
# Second training step: run one step on the second training batch, then keep a
# deep copy of the updated model as the t=2 snapshot.
# NOTE: this rebinds the notebook-level names `x` and `y`.
x, y = batches[1]
train_model_one_step(ipllr, x, y, normalize_first=True)
ipllr_2 = deepcopy(ipllr)

input abs mean in training:  0.6921874284744263
loss derivatives for model: tensor([[ 0.0702,  0.1263,  0.0989,  ...,  0.0898, -0.8576,  0.1143],
        [ 0.0582,  0.1394,  0.0970,  ...,  0.0840,  0.1667,  0.1202],
        [ 0.0681,  0.1285,  0.0987,  ...,  0.0888,  0.1463,  0.1153],
        ...,
        [ 0.0675,  0.1291,  0.0986,  ...,  0.0886,  0.1474,  0.1156],
        [ 0.0635, -0.8665,  0.0980,  ...,  0.0867,  0.1553,  0.1176],
        [ 0.0540,  0.1445,  0.0960,  ...,  0.0816, -0.8233,  0.1222]])
average training loss for model1 : 2.3237860202789307



In [14]:
# Put the live model and its three snapshots (t=0, 1, 2) in evaluation mode.
for model in (ipllr, ipllr_0, ipllr_1, ipllr_2):
    model.eval()
print()




In [15]:
# Per-layer scaling factors of the model (indexed from 0 for layer 1).
layer_scales = ipllr.layer_scales
# Attribute names of the intermediate layers 2..L inside `intermediate_layers`.
intermediate_layer_keys = [f"layer_{layer_idx:,}_intermediate" for layer_idx in range(2, L + 1)]

### Define W0 and b0

In [16]:
with torch.no_grad():
    # Layer 1: stored input weights at t=0 times the layer scale, divided by
    # sqrt(d + 1).
    W0 = {1: layer_scales[0] * ipllr_0.input_layer.weight.data.detach() / math.sqrt(ipllr_0.d + 1)}

    # Layers 2..L: stored intermediate weights times the matching layer scale.
    for idx, key in zip(range(2, L + 1), intermediate_layer_keys):
        weight_0 = getattr(ipllr_0.intermediate_layers, key).weight.data.detach()
        W0[idx] = layer_scales[idx - 1] * weight_0

    # Layer L+1: output weights times the last layer scale.
    W0[L+1] = layer_scales[L] * ipllr_0.output_layer.weight.data.detach()

In [17]:
with torch.no_grad():
    # First-layer bias at t=0, with the same scaling as W0[1]
    # (layer scale, then division by sqrt(d + 1)).
    input_bias_0 = ipllr_0.input_layer.bias.data.detach()
    b0 = layer_scales[0] * input_bias_0 / math.sqrt(ipllr_0.d + 1)

### Define Delta_W_1 and Delta_b_1

In [18]:
with torch.no_grad():
    # Layer 1: scaled difference of the input weights between t=1 and t=0.
    input_weight_step = (ipllr_1.input_layer.weight.data.detach() -
                         ipllr_0.input_layer.weight.data.detach())
    Delta_W_1 = {1: layer_scales[0] * input_weight_step / math.sqrt(ipllr_1.d + 1)}

    # Layers 2..L: scaled difference of the intermediate weights.
    for idx, key in zip(range(2, L + 1), intermediate_layer_keys):
        weight_after = getattr(ipllr_1.intermediate_layers, key).weight.data.detach()
        weight_before = getattr(ipllr_0.intermediate_layers, key).weight.data.detach()
        Delta_W_1[idx] = layer_scales[idx - 1] * (weight_after - weight_before)

    # Layer L+1: scaled difference of the output weights.
    output_weight_step = (ipllr_1.output_layer.weight.data.detach() -
                          ipllr_0.output_layer.weight.data.detach())
    Delta_W_1[L+1] = layer_scales[L] * output_weight_step

In [19]:
with torch.no_grad():
    # Scaled change of the first-layer bias between t=0 and t=1, with the same
    # scaling as Delta_W_1[1].
    input_bias_step = (ipllr_1.input_layer.bias.data.detach() -
                       ipllr_0.input_layer.bias.data.detach())
    Delta_b_1 = layer_scales[0] * input_bias_step / math.sqrt(ipllr_1.d + 1)

### Define Delta_W_2 and Delta_b_2

In [20]:
with torch.no_grad():
    # Layer 1: scaled difference of the input weights between t=2 and t=1.
    input_weight_step = (ipllr_2.input_layer.weight.data.detach() -
                         ipllr_1.input_layer.weight.data.detach())
    Delta_W_2 = {1: layer_scales[0] * input_weight_step / math.sqrt(ipllr_2.d + 1)}

    # Layers 2..L: scaled difference of the intermediate weights.
    for idx, key in zip(range(2, L + 1), intermediate_layer_keys):
        weight_after = getattr(ipllr_2.intermediate_layers, key).weight.data.detach()
        weight_before = getattr(ipllr_1.intermediate_layers, key).weight.data.detach()
        Delta_W_2[idx] = layer_scales[idx - 1] * (weight_after - weight_before)

    # Layer L+1: scaled difference of the output weights.
    output_weight_step = (ipllr_2.output_layer.weight.data.detach() -
                          ipllr_1.output_layer.weight.data.detach())
    Delta_W_2[L+1] = layer_scales[L] * output_weight_step

In [21]:
with torch.no_grad():
    # Scaled change of the first-layer bias between t=1 and t=2.
    # The normalisation reads `d` from the t=2 snapshot, exactly as
    # Delta_W_2[1] does, so the weight and bias updates of the second step are
    # scaled by the same model (the original used ipllr_1.d here).
    Delta_b_2 = layer_scales[0] * (ipllr_2.input_layer.bias.data.detach() -
                                   ipllr_1.input_layer.bias.data.detach()) / math.sqrt(ipllr_2.d + 1)

## Explore at step 1

### On all training examples

In [22]:
# Use the whole training set (all batches concatenated) as the inputs and
# targets for the exploration cells below.
x, y = full_x, full_y

In [23]:
with torch.no_grad():
    # Activations per layer at t=1; index 0 holds the raw network input.
    x1 = {0: x}

    # First-layer pre-activations computed by the t=1 model itself ...
    h1 = {1: layer_scales[0] * ipllr_1.input_layer.forward(x) / math.sqrt(ipllr_1.d + 1)}
    x1[1] = ipllr_1.activation(h1[1])

    # ... and their two components: the contribution of the initial
    # weights/bias and the contribution of the first update.
    h0 = {1: F.linear(x, W0[1], b0)}
    delta_h_1 = {1: F.linear(x, Delta_W_1[1], Delta_b_1)}

In [24]:
# Sanity check: the initial contribution plus the update contribution must
# reproduce the first-layer pre-activations of the t=1 model (up to 1e-5).
torch.testing.assert_allclose(h0[1] + delta_h_1[1], h1[1], rtol=1e-5, atol=1e-5)

In [25]:
with torch.no_grad():
    for l, key in zip(range(2, L + 1), intermediate_layer_keys):
        layer_1 = getattr(ipllr_1.intermediate_layers, key)
        # Input of layer l is the t=1 activation of layer l-1. It is kept
        # under its own name so that the notebook-level `x` (the network input
        # read by the first-layer cell) is no longer overwritten; otherwise
        # re-running that cell would silently use hidden activations as input.
        layer_input = x1[l-1]

        # Contributions of the initial weights and of the first update.
        h0[l] = F.linear(layer_input, W0[l])
        delta_h_1[l] = F.linear(layer_input, Delta_W_1[l])

        # Pre-activations and activations produced by the t=1 model itself.
        h1[l] = layer_scales[l-1] * layer_1.forward(layer_input)
        x1[l] = ipllr_1.activation(h1[l])

        # The two contributions must add up to the model's pre-activations.
        torch.testing.assert_allclose(h0[l] + delta_h_1[l], h1[l], rtol=1e-5, atol=1e-5)

In [26]:
with torch.no_grad():
    # Input of the output layer is the t=1 activation of layer L. It is kept
    # under its own name so that the notebook-level `x` (the network input
    # read by the first-layer cell) is no longer overwritten.
    last_hidden = x1[L]

    # Contributions of the initial weights and of the first update.
    h0[L+1] = F.linear(last_hidden, W0[L+1])
    delta_h_1[L+1] = F.linear(last_hidden, Delta_W_1[L+1])

    # Output-layer pre-activations produced by the t=1 model itself.
    h1[L+1] = layer_scales[L] * ipllr_1.output_layer.forward(last_hidden)
    # NOTE(review): the activation is also applied to the output layer's
    # pre-activations here; x1[L+1] is not read again in the visible cells.
    # Confirm this entry is wanted.
    x1[L+1] = ipllr_1.activation(h1[L+1])

    # The two contributions must add up to the model's pre-activations.
    torch.testing.assert_allclose(h0[L+1] + delta_h_1[L+1], h1[L+1], rtol=1e-5, atol=1e-5)

## Ranks of the initial weights and of the first two updates

In [27]:
# Numerical rank (absolute singular-value tolerance `tol`) of the initial
# weights and of the first two weight updates, per layer. 'max' is the largest
# rank the layer's weight matrix can have, i.e. min(rows, cols).
#
# torch.linalg.matrix_rank replaces the deprecated torch.matrix_rank, which
# has been removed from recent PyTorch releases; `tol` keeps the same meaning.
# The frame is built in one go from a list of records instead of being filled
# row by row into an empty (object-dtype) frame.
columns = ['W0', 'Delta_W_1', 'Delta_W_2', 'max']
layer_index = pd.Index(range(1, L + 2), name='layer')

rank_records = [
    {
        'W0': torch.linalg.matrix_rank(W0[l], tol=tol).item(),
        'Delta_W_1': torch.linalg.matrix_rank(Delta_W_1[l], tol=tol).item(),
        'Delta_W_2': torch.linalg.matrix_rank(Delta_W_2[l], tol=tol).item(),
        'max': min(W0[l].shape[0], W0[l].shape[1]),
    }
    for l in layer_index
]

df = pd.DataFrame(rank_records, index=layer_index, columns=columns)
df['batch_size'] = batch_size
df

Unnamed: 0_level_0,W0,Delta_W_1,Delta_W_2,max,batch_size
layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,784,783,770,784,512
2,1024,643,350,1024,512
3,1024,708,57,1024,512
4,1024,689,16,1024,512
5,1024,681,12,1024,512
6,1024,683,11,1024,512
7,10,9,3,10,512


In [29]:
# Rank of the first-layer update with the notebook-wide tolerance `tol`.
# torch.linalg.matrix_rank replaces the deprecated/removed torch.matrix_rank.
print(Delta_W_1[1].shape)
print(torch.linalg.matrix_rank(Delta_W_1[1], tol=tol))

torch.Size([1024, 784])
tensor(783)


In [32]:
# Same rank computation with a much looser absolute tolerance (1e-4), to see
# how sensitive the numerical rank is to the threshold.
# torch.linalg.matrix_rank replaces the deprecated/removed torch.matrix_rank.
print(Delta_W_1[1].shape)
print(torch.linalg.matrix_rank(Delta_W_1[1], tol=1e-4))

torch.Size([1024, 784])
tensor(509)


In [34]:
# Same rank computation with the library's default tolerance.
# torch.linalg.matrix_rank replaces the deprecated/removed torch.matrix_rank.
print(Delta_W_1[1].shape)
print(torch.linalg.matrix_rank(Delta_W_1[1]))

torch.Size([1024, 784])
tensor(492)


In [36]:
import sympy

In [39]:
# Reduced row echelon form of the first-layer update, computed with sympy on a
# plain-Python copy of the tensor; `m` receives rref()'s return value.
# NOTE(review): the recorded output of this cell is a KeyboardInterrupt, so the
# run was aborted — presumably because rref() is too slow for a 1024 x 784
# matrix. `m` is therefore undefined unless the cell is allowed to finish.
m = sympy.Matrix(Delta_W_1[1].numpy().tolist()).rref()

KeyboardInterrupt: 

In [None]:
# Second element of rref()'s result holds the pivot columns; their count is
# the rank found by the row reduction.
pivot_columns = m[1]
len(pivot_columns)

In [40]:
import time

In [None]:
time.