In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#%matplotlib notebook
%matplotlib inline

In [3]:
import os

# Directory the notebook kernel was started from.
cwd = os.getcwd()

# Parent of the working directory.
NOTEBOOK_DIR = os.path.dirname(cwd)

# Repository root: three more levels above NOTEBOOK_DIR, climbed one level at a time.
ROOT = os.path.dirname(NOTEBOOK_DIR)
ROOT = os.path.dirname(ROOT)
ROOT = os.path.dirname(ROOT)

# Where the figures of this family of experiments live.
FIGURES_DIR = os.path.join(ROOT, 'figures/abc_parameterizations/training')

In [4]:
import sys

# Make the repository root importable so that the project packages imported in
# the next cell can be resolved. The membership check keeps the cell idempotent:
# re-running it does not keep appending duplicate entries to sys.path.
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [None]:
import pickle
from copy import deepcopy

import matplotlib.pylab as pylab
import numpy as np
import pandas as pd
import torch

from pytorch.configs.model import ModelConfig
from pytorch.models.abc_params.fully_connected.ipllr import FcIPLLR
from utils.tools import *

## Set variables

In [None]:
# --- model / optimisation settings ---
L = 8
WIDTH = 1024  # 8000
BASE_LR = 1.0
N_STEPS = 2
BIAS = False
ACTIVATION = 'tanh'
CONFIG_FILE = 'fc_ipllr_mnist.yaml'

# --- synthetic data settings ---
DIM = 20
OUTPUT_DIM = 1
LOSS = 'mse'
N_VAL = 100
SEED = 42

# --- plotting settings ---
FONTSIZE = 12
FIGSIZE = (10, 6)

# NOTE: FIGURES_DIR is itself built from ROOT (which derives from os.getcwd()),
# so it is already absolute and os.path.join discards the leading ROOT here.
fig_dir = os.path.join(ROOT, FIGURES_DIR, 'linearization')

# Use a single font size for legends, axis labels, titles and tick labels.
params = dict.fromkeys(
    ['legend.fontsize',
     'axes.labelsize',
     'axes.titlesize',
     'xtick.labelsize',
     'ytick.labelsize'],
    FONTSIZE,
)
pylab.rcParams.update(params)

In [None]:
# Seed the random number generators before any data or model is created.
# NOTE(review): set_random_seeds is not defined in this notebook — presumably it
# comes from the `utils.tools` star import; confirm which libraries it seeds.
set_random_seeds(SEED)

## Model and data

In [None]:
# Training data: one single-sample batch per forward/backward pass
# (N_STEPS + 1 passes in total), Gaussian inputs and a constant target of 0.5.
x_train = torch.randn(N_STEPS + 1, 1, DIM)
y_train = torch.full((N_STEPS + 1, 1, OUTPUT_DIM), 0.5)

# Validation data built the same way, with N_STEPS single-sample batches.
x_val = torch.randn(N_STEPS, 1, DIM)
y_val = torch.full((N_STEPS, 1, OUTPUT_DIM), 0.5)

In [None]:
# Load the base configuration from CONFIG_FILE and override it with the
# notebook-level settings.
config_dict = read_yaml(os.path.join(ROOT, 'pytorch/configs/abc_parameterizations', CONFIG_FILE))
config_dict['architecture']['width'] = WIDTH
# The diagnostic cells below index L entries of pre-activations/activations,
# so the network is configured with one more layer than L.
config_dict['architecture']['n_layers'] = L + 1
config_dict['architecture']['input_size'] = DIM
config_dict['architecture']['output_size'] = OUTPUT_DIM
config_dict['optimizer']['params']['lr'] = BASE_LR
config_dict['activation']['name'] = ACTIVATION
# Take the loss name from the LOSS setting instead of a hard-coded literal, so
# that changing LOSS in the settings cell actually changes the model.
config_dict['loss'] = {'name': LOSS, 'params': {'reduction': 'mean'}}
config_dict['scheduler']['params']['calibrate_base_lr'] = False
# NOTE(review): BIAS is declared in the settings cell but is never written into
# config_dict in the visible code — confirm the YAML default is what is intended.

In [None]:
# Wrap the overridden dictionary into the ModelConfig object expected by the model constructor.
config = ModelConfig(config_dict=config_dict)

In [None]:
# Build the network from the configuration.
model = FcIPLLR(config)
pg = list(model.optimizer.param_groups)
# Divide the learning rate of the first optimizer parameter group by the input dimension.
# NOTE(review): presumably group 0 holds the input-layer weights — confirm against FcIPLLR.
pg[0]['lr'] = pg[0]['lr'] / DIM
#for l in range(2, L):
#    pg[l]['lr'] = pg[l]['lr'] * (WIDTH ** 0.5)
# Snapshot of the network (and its optimizer state) before any update, taken
# after the learning-rate adjustment above.
model_0 = deepcopy(model)

In [None]:
def forward_backward(x, y, net=None):
    """Run one forward and one backward pass, keeping all intermediate tensors.

    The network is evaluated layer by layer (instead of through its own forward
    method) so that the pre-activations ``h_l`` and activations ``x_l`` of every
    layer can be captured, together with the gradient of the loss with respect
    to each of them. Layer ``l`` is scaled by ``net.width ** (-net.a[l])``.

    The optimizer's gradients are zeroed before the pass; no optimizer or
    scheduler step is taken here.

    Args:
        x: input batch fed to the input layer (the notebook passes a tensor of
            shape ``(1, DIM)``).
        y: target given to ``net.loss`` together with the prediction.
        net: network to run. When ``None`` (the default) the notebook-level
            ``model`` is used, looked up at call time.

    Returns:
        A tuple ``(hs, xs, y_hat, h_grads, x_grads, y_hat_grad, loss_)`` where
        ``hs`` / ``xs`` are the lists of pre-activations / activations of the
        input and intermediate layers, ``y_hat`` is the scaled output of the
        output layer, ``h_grads`` / ``x_grads`` / ``y_hat_grad`` are the
        gradients of the loss with respect to those tensors, and ``loss_`` is
        the loss value.
    """
    net = model if net is None else net

    net.optimizer.zero_grad()

    hs = []
    xs = []

    # The input layer uses exponent a[0] and the l-th intermediate layer uses
    # a[l+1], so walking them as one sequence lines the indices up with `a`.
    hidden_layers = [net.input_layer, *net.intermediate_layers]
    for l, layer in enumerate(hidden_layers):
        # Call the module itself (not .forward) so registered hooks are honoured.
        h = (net.width ** (-net.a[l])) * layer(x)  # h_l, layer l pre-activations
        x = net.activation(h)  # x_l, layer l activations

        # h and x are non-leaf tensors: their .grad is only kept on request.
        h.retain_grad()
        x.retain_grad()

        hs.append(h)
        xs.append(x)

    y_hat = (net.width ** (-net.a[net.n_layers - 1])) * net.output_layer(x)  # f(x)
    y_hat.retain_grad()

    loss_ = net.loss(y_hat, y)
    loss_.backward()

    h_grads = [h_.grad for h_ in hs]
    x_grads = [x_.grad for x_ in xs]
    y_hat_grad = y_hat.grad

    return hs, xs, y_hat, h_grads, x_grads, y_hat_grad, loss_

## 1st forward backward

In [None]:
# First pass, on the first training sample, with the weights at initialization.
hs0, xs0, y_hat0, h_grads0, x_grads0, y_hat_grad0, loss_0 = forward_backward(x_train[0], y_train[0])

In [None]:
# Width-rescaled ("tilde") versions of the first-pass forward and backward quantities.
tilde_hs0, tilde_xs0 = [], []
tilde_h_grads0, tilde_x_grads0 = [], []

with torch.no_grad():
    for l in range(L):
        # Forward tensors of layer l are multiplied by WIDTH^(l/2); their gradients
        # by WIDTH * WIDTH^((L-(l+1))/2).
        forward_scale = WIDTH ** (l/2)
        backward_scale = WIDTH * (WIDTH ** ((L-(l+1)) / 2))

        tilde_hs0.append(forward_scale * hs0[l])
        tilde_xs0.append(forward_scale * xs0[l])

        tilde_h_grads0.append(backward_scale * h_grads0[l])
        tilde_x_grads0.append(backward_scale * x_grads0[l])

    # The network output is rescaled by WIDTH^(L/2).
    tilde_y_hat0 = (WIDTH ** (L/2)) * y_hat0

In [None]:
# Sum of squares divided by WIDTH for each rescaled pre-activation, then for each
# rescaled activation, followed by the squared rescaled output.
with torch.no_grad():
    for tilde_h in tilde_hs0:
        print((tilde_h ** 2).sum().item() / WIDTH)
    print('')

    for tilde_x in tilde_xs0:
        print((tilde_x ** 2).sum().item() / WIDTH)

    print(tilde_y_hat0.item() ** 2)

In [None]:
# Same statistic for the rescaled gradients: pre-activation gradients first, then
# activation gradients, followed by the squared (unscaled) output gradient.
with torch.no_grad():
    for tilde_h_grad in tilde_h_grads0:
        print((tilde_h_grad ** 2).sum().item() / WIDTH)
    print('')

    for tilde_x_grad in tilde_x_grads0:
        print((tilde_x_grad ** 2).sum().item() / WIDTH)

    print(y_hat_grad0.item() ** 2)

In [None]:
# Root of (sum of squares / WIDTH) of the raw first-pass pre-activations and
# activations, layer by layer, then the raw network output.
with torch.no_grad():
    for h in hs0:
        print(np.sqrt((h ** 2).sum().item() / WIDTH))
    print('')

    for x in xs0:
        print(np.sqrt((x ** 2).sum().item() / WIDTH))

    print(y_hat0.item())

## 1st optimizer step : weight updates

In [None]:
# Apply the gradients computed by the first backward pass, then advance the
# learning-rate scheduler, and only then snapshot the once-updated network.
model.optimizer.step()
model.scheduler.step()
model_1 = deepcopy(model)

## 2nd forward backward

In [None]:
# Second pass, on the second training sample, with the once-updated weights.
hs1, xs1, y_hat1, h_grads1, x_grads1, y_hat_grad1, loss_1 = forward_backward(x_train[1], y_train[1])

In [None]:
# Root of (sum of squares / WIDTH) of the second-pass pre-activations and
# activations, layer by layer, then the raw network output.
with torch.no_grad():
    for h in hs1:
        print(np.sqrt((h ** 2).sum().item() / WIDTH))
    print('')

    for x in xs1:
        print(np.sqrt((x ** 2).sum().item() / WIDTH))

    print(y_hat1.item())

In [None]:
# Inner product, divided by WIDTH, between the rescaled initial activations and
# the second-pass activations of each layer.
with torch.no_grad():
    for l in range(L):
        print(torch.sum(tilde_xs0[l] * xs1[l]).detach().item() / WIDTH)
    # Initial output weights applied to the LAST layer's activations. The index
    # used to be a hard-coded 2, which only designates the last layer in a
    # 3-layer setup; xs1[-1] is what the output layer actually receives.
    print(model_0.output_layer(xs1[-1]).detach().item() / WIDTH)

In [None]:
# Same inner products as in the previous cell, but normalized by sqrt(WIDTH)
# instead of WIDTH for every layer after the first one.
with torch.no_grad():
    for l in range(L):
        if l == 0:
            print(torch.sum(tilde_xs0[l] * xs1[l]).detach().item() / WIDTH)
        else:
            print(torch.sum(tilde_xs0[l] * xs1[l]).detach().item() / np.sqrt(WIDTH))
    # Initial output weights applied to the LAST layer's activations. The index
    # used to be a hard-coded 2, which only designates the last layer in a
    # 3-layer setup; xs1[-1] is what the output layer actually receives.
    print(model_0.output_layer(xs1[-1]).detach().item() / np.sqrt(WIDTH))

In [None]:
# Spot check of the first three layers only: inner product between the rescaled
# initial activations and the second-pass activations, normalized by WIDTH for
# the first layer and by sqrt(WIDTH) for the next two.
with torch.no_grad():
    print(torch.sum(tilde_xs0[0] * xs1[0]).detach().item() / WIDTH)
    print(torch.sum(tilde_xs0[1] * xs1[1]).detach().item() / np.sqrt(WIDTH))
    print(torch.sum(tilde_xs0[2] * xs1[2]).detach().item() / np.sqrt(WIDTH))

    # Initial output weights applied to the LAST layer's activations (previously
    # a hard-coded index 2, which is a middle layer when L > 3).
    print(model_0.output_layer(xs1[-1]).detach().item() / np.sqrt(WIDTH))

In [None]:
# Root of (sum of squares / WIDTH) of the second-pass gradients: pre-activation
# gradients first, then activation gradients, then the raw output gradient.
with torch.no_grad():
    for h_grad in h_grads1:
        print(np.sqrt((h_grad ** 2).sum().item() / WIDTH))
    print('')

    for x_grad in x_grads1:
        print(np.sqrt((x_grad ** 2).sum().item() / WIDTH))

    print(y_hat_grad1.item())

## 2nd optimizer step : weight updates

In [None]:
# Show the learning rate currently set on each optimizer parameter group
# (i.e. after the scheduler step taken with the first update).
with torch.no_grad():
    print([group['lr'] for group in model.optimizer.param_groups])
    #for g in model.optimizer.param_groups:
    #    g['lr'] = g['lr'] * 0.0001
    #print([g['lr'] for g in model.optimizer.param_groups])

In [None]:
# Apply the gradients computed by the second backward pass and snapshot the
# twice-updated network.
# NOTE(review): unlike the first update, model.scheduler.step() is not called
# here — confirm this is intentional.
model.optimizer.step()
model_2 = deepcopy(model)

## 3rd forward backward

In [None]:
# Third pass, on the third training sample, with the twice-updated weights.
hs2, xs2, y_hat2, h_grads2, x_grads2, y_hat_grad2, loss_2 = forward_backward(x_train[2], y_train[2])

In [None]:
# Root of (sum of squares / WIDTH) of the third-pass pre-activations and
# activations, layer by layer, then the raw network output.
with torch.no_grad():
    for h in hs2:
        print(np.sqrt((h ** 2).sum().item() / WIDTH))
    print('')

    for x in xs2:
        print(np.sqrt((x ** 2).sum().item() / WIDTH))

    print(y_hat2.item())

In [None]:
# Inner product, divided by WIDTH, between the rescaled initial activations and
# the third-pass activations of each layer.
with torch.no_grad():
    for l in range(L):
        print(torch.sum(tilde_xs0[l] * xs2[l]).detach().item() / WIDTH)

    # Initial output weights applied to the LAST layer's activations. The index
    # used to be a hard-coded 2, which only designates the last layer in a
    # 3-layer setup; xs2[-1] is what the output layer actually receives.
    print(model_0.output_layer(xs2[-1]).detach().item() / WIDTH)

In [None]:
# Inner product, divided by WIDTH, between the second-pass and third-pass
# activations of each layer.
with torch.no_grad():
    for l in range(L):
        print(torch.sum(xs1[l] * xs2[l]).detach().item() / WIDTH)

    # Once-updated output weights applied to the LAST layer's activations. The
    # index used to be a hard-coded 2, which only designates the last layer in a
    # 3-layer setup; xs2[-1] is what the output layer actually receives.
    print(model_1.output_layer(xs2[-1]).detach().item() / WIDTH)