In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#%matplotlib notebook
%matplotlib inline

In [3]:
import os
cwd = os.getcwd()

NOTEBOOK_DIR = os.path.dirname(cwd)
ROOT = os.path.dirname(os.path.dirname(os.path.dirname(NOTEBOOK_DIR)))

FIGURES_DIR = os.path.join(ROOT, 'figures/abc_parameterizations/training')

In [4]:
import sys
sys.path.append(ROOT)

In [5]:
import torch
import pickle
import pandas as pd
import matplotlib.pylab as pylab
from copy import deepcopy

from pytorch.configs.model import ModelConfig
from pytorch.models.abc_params.fully_connected.ipllr import FcIPLLR
from utils.tools import *

  from .autonotebook import tqdm as notebook_tqdm


## Set variables

In [6]:
L = 8
WIDTH = 1024  # 8000
BASE_LR = 1.0
N_STEPS = 2
BIAS = False
ACTIVATION = 'sigmoid'
CONFIG_FILE = 'fc_ipllr_mnist.yaml'

DIM = 20
OUTPUT_DIM = 1
LOSS = 'mse'
N_VAL = 100
SEED = 42

FONTSIZE = 12
FIGSIZE = (10, 6)

fig_dir = os.path.join(ROOT, FIGURES_DIR, 'linearization')

params = {'legend.fontsize': FONTSIZE,
         'axes.labelsize': FONTSIZE,
         'axes.titlesize': FONTSIZE,
         'xtick.labelsize': FONTSIZE,
         'ytick.labelsize': FONTSIZE}
pylab.rcParams.update(params)

In [7]:
set_random_seeds(SEED)

## Model and data

In [8]:
x_train = torch.randn(size=(N_STEPS+1, 1, DIM), requires_grad=False)
y_train = torch.ones(size=(N_STEPS+1, 1, OUTPUT_DIM), requires_grad=False) / 2

x_val = torch.randn(size=(N_STEPS, 1, DIM), requires_grad=False)
y_val = torch.ones(size=(N_STEPS, 1, OUTPUT_DIM), requires_grad=False) / 2

In [9]:
config_dict = read_yaml(os.path.join(ROOT, 'pytorch/configs/abc_parameterizations', CONFIG_FILE))
config_dict['architecture']['width'] = WIDTH
config_dict['architecture']['n_layers'] = L + 1
config_dict['architecture']['input_size'] = DIM
config_dict['architecture']['output_size'] = OUTPUT_DIM
config_dict['optimizer']['params']['lr'] = BASE_LR
config_dict['activation']['name'] = ACTIVATION
config_dict['loss'] = {'name': 'mse', 'params': {'reduction': 'mean'}}
config_dict['scheduler']['params']['calibrate_base_lr'] = False

In [10]:
config = ModelConfig(config_dict=config_dict)

In [11]:
model = FcIPLLR(config)
pg = list(model.optimizer.param_groups)
pg[0]['lr'] = pg[0]['lr'] / DIM
#for l in range(2, L):
#    pg[l]['lr'] = pg[l]['lr'] * (WIDTH ** 0.5)
model_0 = deepcopy(model)

In [12]:
def forward_backward(x, y):
    h_grads = []
    x_grads = []

    hs = []
    xs = []
    
    model.optimizer.zero_grad()
    h = (model.width ** (-model.a[0])) * model.input_layer.forward(x)  # h_0 first layer pre-activations
    hs.append(h)

    x = model.activation(h)  # x_0, first layer activations
    xs.append(x)

    for l, layer in enumerate(model.intermediate_layers):  # L-1 intermediate layers
        h = (model.width ** (-model.a[l+1])) * layer.forward(x)  # h_l, layer l pre-activations
        hs.append(h)
        x = model.activation(h)  # x_l, l-th layer activations
        xs.append(x)
    
    for h_ in hs:
        h_.retain_grad()
    for x_ in xs:
        x_.retain_grad()
        
    y_hat = (model.width ** (-model.a[model.n_layers-1])) * model.output_layer.forward(x)  # f(x)
    y_hat.retain_grad()
    
    loss_ = model.loss(y_hat, y)
    loss_.backward()
    
    h_grads = [h_.grad for h_ in hs]
    x_grads = [x_.grad for x_ in xs]
    
    y_hat_grad = y_hat.grad
    
    return  hs, xs, y_hat, h_grads, x_grads, y_hat_grad, loss_

## 1st forward backward

In [13]:
hs0, xs0, y_hat0, h_grads0, x_grads0, y_hat_grad0, loss_0 = forward_backward(x_train[0, : ,:], y_train[0, :, :])

In [14]:
tilde_hs0 = []
tilde_xs0 = []
tilde_h_grads0 = []
tilde_x_grads0 = []

with torch.no_grad():
    for l in range(L):
        forward_scale = WIDTH ** (l/2)
        tilde_hs0.append(forward_scale * hs0[l])
        tilde_xs0.append(forward_scale * xs0[l])
        
        backward_scale = WIDTH * (WIDTH ** ((L-(l+1)) / 2))
        tilde_h_grads0.append(backward_scale * h_grads0[l])
        tilde_x_grads0.append(backward_scale * x_grads0[l])
        
    tilde_y_hat0 = (WIDTH ** (L/2)) * y_hat0

In [15]:
with torch.no_grad():
    for tilde_h in tilde_hs0:
        print(torch.sum((tilde_h)**2).detach().item() / WIDTH)
    print('')

    for tilde_x in tilde_xs0:
        print(torch.sum((tilde_x)**2).detach().item() / WIDTH)
        
    print((tilde_y_hat0.detach().item())**2)

4.698747634887695
0.7188931107521057
482.9482727050781
518127.65625
546759232.0
556967133184.0
540849393893376.0
5.763604562271273e+17

0.34303274750709534
256.25225830078125
262388.6875
268436192.0
274880626688.0
281318512394240.0
2.8809505032214938e+17
2.953098588440777e+20
4.394934928351072e+20


In [16]:
with torch.no_grad():
    for tilde_h_grad in tilde_h_grads0:
        print(torch.sum((tilde_h_grad)**2).detach().item() / WIDTH)
    print('')

    for tilde_x_grad in tilde_x_grads0:
        print(torch.sum((tilde_x_grad)**2).detach().item() / WIDTH)
        
    print((y_hat_grad0.detach().item())**2)

1.0266147043580531e-08
1.8961415548801597e-07
1.6430534515166073e-06
1.2515798516687937e-05
9.844550368143246e-05
0.0008299332694150507
0.006708770990371704
0.05289517343044281

3.5941593523602933e-07
3.0349037842825055e-06
2.6294979761587456e-05
0.00020029988081660122
0.001575519680045545
0.013282040134072304
0.10736584663391113
0.8465492725372314
0.9251871599622206


In [17]:
with torch.no_grad():
    for h in hs0:
        print(np.sqrt(torch.sum((h)**2).detach().item() / WIDTH))
    print('')

    for x in xs0:
        print(np.sqrt(torch.sum((x)**2).detach().item() / WIDTH))
        
    print(y_hat0.detach().item())

2.1676594831494396
0.026496113931458953
0.021461019636345207
0.021966883775315825
0.022299655258533377
0.022241541946912635
0.02165899513544738
0.022095164537040607

0.5856899755904103
0.5002462853403878
0.5002332977564525
0.5000006854529451
0.5000024735866395
0.499861012453168
0.4998826097728106
0.5001371612767347
0.019066737964749336


## 1st optimizer step : weight updates

In [18]:
model.optimizer.step()
model.scheduler.step()
model_1 = deepcopy(model)

## 2nd forward backward

In [19]:
hs1, xs1, y_hat1, h_grads1, x_grads1, y_hat_grad1, loss_1 = forward_backward(x_train[1, : ,:], y_train[1, :, :])

In [20]:
with torch.no_grad():
    for h in hs1:
        print(np.sqrt(torch.sum((h)**2).detach().item() / WIDTH))
    print('')

    for x in xs1:
        print(np.sqrt(torch.sum((x)**2).detach().item() / WIDTH))
        
    print(y_hat1.detach().item())

1.9267763735326202
0.025843651392611935
0.024114697395927604
0.9075732272005221
80.9285709360421
7664.12865236486
669404.6828160078
60855934.83142455

0.5739621737490722
0.500207828006174
0.5003017468699691
0.5346124210853432
0.7091570767323151
0.6980721219902711
0.7022564079451322
0.7173900177030623
8508037120.0


In [21]:
with torch.no_grad():
    for l in range(L):
        print(torch.sum(tilde_xs0[l] * xs1[l]).detach().item() / WIDTH)
    print(model_0.output_layer(xs1[2]).detach().item() / WIDTH)

0.2810658812522888
8.00703239440918
256.2706298828125
8155.666015625
266035.875
8172732.0
264603104.0
8845340672.0
0.019277792423963547


In [22]:
with torch.no_grad():
    for l in range(L):
        if l == 0:
            print(torch.sum(tilde_xs0[l] * xs1[l]).detach().item() / WIDTH)
        else:
            print(torch.sum(tilde_xs0[l] * xs1[l]).detach().item() / np.sqrt(WIDTH))
    print(model_0.output_layer(xs1[2]).detach().item() / np.sqrt(WIDTH))

0.2810658812522888
256.22503662109375
8200.66015625
260981.3125
8513148.0
261527424.0
8467299328.0
283050901504.0
0.6168893575668335


In [23]:
with torch.no_grad():
    print(torch.sum(tilde_xs0[0] * xs1[0]).detach().item() / WIDTH)
    print(torch.sum(tilde_xs0[1] * xs1[1]).detach().item() / np.sqrt(WIDTH))
    print(torch.sum(tilde_xs0[2] * xs1[2]).detach().item() / np.sqrt(WIDTH))
    
    print(model_0.output_layer(xs1[2]).detach().item() / np.sqrt(WIDTH))

0.2810658812522888
256.22503662109375
8200.66015625
0.6168893575668335


In [24]:
with torch.no_grad():
    for h_grad in h_grads1:
        print(np.sqrt(torch.sum((h_grad)**2).detach().item() / WIDTH))
    print('')

    for x_grad in x_grads1:
        print(np.sqrt(torch.sum((x_grad)**2).detach().item() / WIDTH))
        
    print(y_hat_grad1.detach().item())

0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0

0.0
0.0
0.0
0.0
0.0
0.0
0.0
2.746712305533744e+17
17016074240.0


## 2nd optimizer step : weight updates

In [25]:
with torch.no_grad():
    print([g['lr'] for g in model.optimizer.param_groups])
    #for g in model.optimizer.param_groups:
    #    g['lr'] = g['lr'] * 0.0001
    #print([g['lr'] for g in model.optimizer.param_groups])

[1024.0, 1048576.0, 1048576.0, 1048576.0, 1048576.0, 1048576.0, 1048576.0, 1048576.0, 1024.0]


In [26]:
model.optimizer.step()
model_2 = deepcopy(model)

## 3rd forward pass

In [27]:
hs2, xs2, y_hat2, h_grads2, x_grads2, y_hat_grad2, loss_2 = forward_backward(x_train[2, : ,:], y_train[2, :, :])

In [28]:
with torch.no_grad():
    for h in hs2:
        print(np.sqrt(torch.sum((h)**2).detach().item() / WIDTH))
    print('')

    for x in xs2:
        print(np.sqrt(torch.sum((x)**2).detach().item() / WIDTH))
        
    print(y_hat2.detach().item())

1.9329723245817247
0.02653240171056875
0.024112040880760945
0.9075758213494385
80.92856791929998
7664.12865236486
669404.6828160078
60855934.83142455

0.58707340762283
0.5002197736294834
0.5003032063008787
0.5346125604494756
0.7091570767323151
0.6980721219902711
0.7022564079451322
0.7173900177030623
-249259392.0


In [29]:
with torch.no_grad():
    for l in range(L):
        print(torch.sum(tilde_xs0[l] * xs2[l]).detach().item() / WIDTH)

    
    print(model_0.output_layer(xs2[2]).detach().item() / WIDTH)

0.2872248888015747
8.007235527038574
256.27142333984375
8155.666015625
266035.875
8172732.0
264603104.0
8845340672.0
0.019279837608337402


In [30]:
with torch.no_grad():
    for l in range(L):
        print(torch.sum(xs1[l] * xs2[l]).detach().item() / WIDTH)
    
    print(model_1.output_layer(xs2[2]).detach().item() / WIDTH)

0.2874434292316437
0.2502078115940094
0.2503025531768799
0.28581053018569946
0.5029037594795227
0.4873046875
0.4931640625
0.5146484375
8268525568.0
