In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
cwd = os.getcwd()

NOTEBOOK_DIR = os.path.dirname(cwd)
ROOT = os.path.dirname(os.path.dirname(os.path.dirname(NOTEBOOK_DIR)))

FIGURES_DIR = os.path.join(ROOT, 'figures/abc_parameterizations/debug_ipllr_renorm')
CONFIG_PATH = os.path.join(ROOT, 'pytorch/configs/abc_parameterizations', 'fc_ipllr_mnist.yaml')

In [3]:
import sys
sys.path.append(ROOT)

In [4]:
import os
from copy import deepcopy
import torch
import math
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, Subset, DataLoader
import torch.nn.functional as F

from utils.tools import read_yaml, set_random_seeds
from pytorch.configs.base import BaseConfig
from pytorch.configs.model import ModelConfig
from pytorch.models.abc_params.fully_connected.ipllr import FcIPLLR
from pytorch.models.abc_params.fully_connected.muP import FCmuP
from pytorch.models.abc_params.fully_connected.ntk import FCNTK
from pytorch.models.abc_params.fully_connected.standard_fc_ip import StandardFCIP
from utils.data.mnist import load_data
from utils.abc_params.debug_ipllr import *
from utils.plot.abc_parameterizations.debug_ipllr import *

### Load basic configuration and define variables 

In [5]:
N_TRIALS = 5
SEED = 30
L = 6
width = 1024
n_warmup_steps = 1
batch_size = 512
base_lr = 0.1
n_steps = 50

set_random_seeds(SEED)  # set random seed for reproducibility
config_dict = read_yaml(CONFIG_PATH)

In [6]:
config_dict = read_yaml(CONFIG_PATH)

input_size = config_dict['architecture']['input_size']

config_dict['architecture']['width'] = width
config_dict['architecture']['n_layers'] = L + 1
config_dict['optimizer']['params']['lr'] = base_lr
config_dict['scheduler'] = {'name': 'warmup_switch',
                            'params': {'n_warmup_steps': n_warmup_steps,
                                       'calibrate_base_lr': True,
                                       'default_calibration': False}}
        
base_model_config = ModelConfig(config_dict)

### Load data & define models

In [7]:
training_dataset, test_dataset = load_data(download=False, flatten=True)
train_data_loader = DataLoader(training_dataset, shuffle=True, batch_size=batch_size)
test_batches = list(DataLoader(test_dataset, shuffle=False, batch_size=batch_size))
batches = list(train_data_loader)
eval_batch = test_batches[0]

In [8]:
ipllrs = [FcIPLLR(base_model_config, n_warmup_steps=12, lr_calibration_batches=batches) for _ in range(N_TRIALS)]
base_model_config.scheduler = None
muPs = [FCmuP(base_model_config) for _ in range(N_TRIALS)]

initial base lr : [69.26097106933594, 36.901771545410156, 60.06058120727539, 61.465023040771484, 69.81842803955078, 80.47272491455078, 242.21490478515625]
initial base lr : [68.2101821899414, 38.062652587890625, 67.09164428710938, 72.38660430908203, 78.3447494506836, 93.293212890625, 289.7236633300781]
initial base lr : [69.08634948730469, 38.01998519897461, 66.5303726196289, 69.7338638305664, 72.82850646972656, 86.75313568115234, 247.78399658203125]
initial base lr : [69.98839569091797, 38.506282806396484, 67.24036407470703, 63.19353485107422, 76.40673065185547, 94.85008239746094, 269.35626220703125]
initial base lr : [68.72810363769531, 36.14960479736328, 60.415069580078125, 66.5676498413086, 64.38241577148438, 73.41007995605469, 205.81072998046875]


In [9]:
for ipllr in ipllrs:
    ipllr.scheduler.calibrate_base_lr(ipllr, batches=batches)

initial base lr : [69.26097106933594, 36.901771545410156, 60.06058120727539, 61.465023040771484, 69.81842803955078, 80.47272491455078, 242.21490478515625]
initial base lr : [68.2101821899414, 38.062652587890625, 67.09164428710938, 72.38660430908203, 78.3447494506836, 93.293212890625, 289.7236633300781]
initial base lr : [69.08634948730469, 38.01998519897461, 66.5303726196289, 69.7338638305664, 72.82850646972656, 86.75313568115234, 247.78399658203125]
initial base lr : [69.98839569091797, 38.506282806396484, 67.24036407470703, 63.19353485107422, 76.40673065185547, 94.85008239746094, 269.35626220703125]
initial base lr : [68.72810363769531, 36.14960479736328, 60.415069580078125, 66.5676498413086, 64.38241577148438, 73.41007995605469, 205.81072998046875]


In [10]:
for i, ipllr in enumerate(ipllrs):
    ipllr.copy_initial_params_from_model(muPs[i])
    ipllr.initialize_params()

In [11]:
muPs_0 = [deepcopy(muP) for muP in muPs]

x, y = batches[0]
for muP in muPs:
    train_model_one_step(ipllr, x, y, batch_size)

input abs mean in training:  0.6950533986091614
loss derivatives for model: tensor([[-0.9000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000, -0.9000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        ...,
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000, -0.9000,  0.1000]])
average training loss for model1 : 2.3025991916656494

input abs mean in training:  0.6950533986091614
loss derivatives for model: tensor([[-0.9620,  0.2115,  0.0943,  ...,  0.0697,  0.1970,  0.1129],
        [ 0.0280, -0.7540,  0.0885,  ...,  0.0603,  0.2249,  0.1112],
        [ 0.0367,  0.2152,  0.0937,  ...,  0.0686,  0.2002, -0.8871],
        ...,
        [ 0.0421,  0.1994,  0.0960,  ...,  0.0730,  0.1871, -0.8870],
        [ 0.0423,  0.1990,  0.0960,  ...,  0.0731,  0.1868,  0.113

In [12]:
x, y = batches[0]
ipllrs_1 = []
for ipllr in ipllrs:
    train_model_one_step(ipllr, x, y, batch_size)
    ipllrs_1.append(deepcopy(ipllr))

input abs mean in training:  0.6950533986091614
loss derivatives for model: tensor([[-0.9000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000, -0.9000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        ...,
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000, -0.9000,  0.1000]])
average training loss for model1 : 2.3025991916656494

input abs mean in training:  0.6950533986091614
loss derivatives for model: tensor([[-0.9000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000, -0.9000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        ...,
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.100

In [13]:
results = {'muP': [], 'IPLLR': []}
results['muP'] = [collect_scales(muPs[i], muPs_0[i], batches[1:], eval_batch, n_steps) 
                  for i in range(N_TRIALS)]

In [None]:
results['IPLLR'] = [collect_scales(ipllrs[i], ipllrs_1[i], batches[1:], eval_batch, n_steps) 
                    for i in range(N_TRIALS)]

# Training

In [None]:
mode = 'training'

In [None]:
losses_muP = [r[0] for r in results['muP']]
losses_ip = [r[0] for r in results['IPLLR']]

chis_muP = [r[1] for r in results['muP']]
chis_ip = [r[1] for r in results['IPLLR']]

## Losses and derivatives

In [None]:
key = 'loss'
plt.figure(figsize=(12, 8))
plot_losses(losses_ip, losses_muP, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size, mode=mode)
plt.savefig(os.path.join(FIGURES_DIR, '{}_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, L, width, base_lr, 
                                                                               batch_size)))
plt.show()

In [None]:
key = 'chi'
plt.figure(figsize=(12, 8))
plot_losses(chis_ip, chis_muP, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size, mode=mode)
plt.savefig(os.path.join(FIGURES_DIR, '{}_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, L, width, base_lr, 
                                                                               batch_size)))
plt.show()

### Magnitude of the actvations of the network at different layers

In [None]:
dfs_muP = [r[2] for r in results['muP']]
dfs_ip = [r[2] for r in results['IPLLR']]

In [None]:
key = 'h'
for l in range(L-1, L+2):
    plt.figure(figsize=(12, 8))
    plot_output_scale(dfs_ip, dfs_muP, layer=l, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size, 
                      mode=mode, y_scale='log')
    plt.savefig(os.path.join(FIGURES_DIR, 
                             '{}_{}_layer_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, l, L, width, base_lr, 
                                                                               batch_size)))
    plt.show()

In [None]:
key = 'h'
for l in range(L-1, L+2):
    plt.figure(figsize=(12, 8))
    plot_output_scale(dfs_ip, None, layer=l, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size, 
                      mode=mode, y_scale='log')
    plt.savefig(os.path.join(FIGURES_DIR, 
                             '{}_{}_ip_layer_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, l, L, width, base_lr, 
                                                                                  batch_size)))
    plt.show()

### Contribution of the init to the activations at different layers.

In [None]:
key = 'h_init'
for l in range(1, L+2):
    plt.figure(figsize=(12, 8))
    plot_output_scale(dfs_ip, dfs_muP, layer=l, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size,
                      mode=mode, y_scale='log')
    plt.savefig(os.path.join(FIGURES_DIR, 
                             '{}_{}_layer_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, l, L, width, base_lr,
                                                                               batch_size)))
    plt.show()

### Magnitude of the update at different layers.

In [None]:
key = 'delta_h'
for l in range(1, L+2):
    plt.figure(figsize=(12, 8))
    plot_output_scale(dfs_ip, dfs_muP, layer=l, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size,
                      mode=mode, y_scale='log')
    plt.savefig(os.path.join(FIGURES_DIR, 
                             '{}_{}_layer_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, l, L, width, base_lr,
                                                                               batch_size)))
    plt.show()

# Validation

In [None]:
mode = 'val'

In [None]:
dfs_muP = [r[3] for r in results['muP']]
dfs_ip = [r[3] for r in results['IPLLR']]

## Losses and derivatives

In [None]:
key = 'loss'
plt.figure(figsize=(12, 8))
plot_losses(losses_ip, losses_muP, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size, mode=mode)
plt.savefig(os.path.join(FIGURES_DIR, '{}_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, L, width, base_lr, 
                                                                               batch_size)))
plt.show()

In [None]:
key = 'chi'
plt.figure(figsize=(12, 8))
plot_losses(chis_ip, chis_muP, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size, mode=mode)
plt.savefig(os.path.join(FIGURES_DIR, '{}_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, L, width, base_lr, 
                                                                               batch_size)))
plt.show()

### Magnitude of the actvations of the network at different layers

In [None]:
dfs_muP = [r[2] for r in results['muP']]
dfs_ip = [r[2] for r in results['IPLLR']]

In [None]:
key = 'h'
for l in range(L-1, L+2):
    plt.figure(figsize=(12, 8))
    plot_output_scale(dfs_ip, dfs_muP, layer=l, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size, 
                      mode=mode, y_scale='log')
    plt.savefig(os.path.join(FIGURES_DIR, 
                             '{}_{}_layer_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, l, L, width, base_lr, 
                                                                               batch_size)))
    plt.show()

In [None]:
key = 'h'
for l in range(L-1, L+2):
    plt.figure(figsize=(12, 8))
    plot_output_scale(dfs_ip, None, layer=l, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size, 
                      mode=mode, y_scale='log')
    plt.savefig(os.path.join(FIGURES_DIR, 
                             '{}_{}_ip_layer_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, l, L, width, base_lr, 
                                                                                  batch_size)))
    plt.show()

### Contribution of the init to the activations at different layers.

In [None]:
key = 'h_init'
for l in range(1, L+2):
    plt.figure(figsize=(12, 8))
    plot_output_scale(dfs_ip, dfs_muP, layer=l, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size,
                      mode=mode, y_scale='log')
    plt.savefig(os.path.join(FIGURES_DIR, 
                             '{}_{}_layer_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, l, L, width, base_lr,
                                                                               batch_size)))
    plt.show()

### Magnitude of the update at different layers.

In [None]:
key = 'delta_h'
for l in range(1, L+2):
    plt.figure(figsize=(12, 8))
    plot_output_scale(dfs_ip, dfs_muP, layer=l, key=key, L=L, width=width, lr=base_lr, batch_size=batch_size,
                      mode=mode, y_scale='log')
    plt.savefig(os.path.join(FIGURES_DIR, 
                             '{}_{}_layer_{}_L={}_m={}_lr={}_bs={}.png'.format(mode, key, l, L, width, base_lr,
                                                                               batch_size)))
    plt.show()