In [1]:
# Reload edited Python modules automatically before each cell runs, so changes to the
# project code (utils/, pytorch/) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Resolve the project paths relative to the directory the notebook is launched from.
# NOTE(review): ROOT is assumed to sit three levels above NOTEBOOK_DIR (four above cwd);
# update this if the notebook is moved.
cwd = os.getcwd()
NOTEBOOK_DIR = os.path.dirname(cwd)
ROOT = os.path.normpath(os.path.join(NOTEBOOK_DIR, os.pardir, os.pardir, os.pardir))

FIGURES_DIR = os.path.join(ROOT, 'figures/abc_parameterizations/initialization')
CONFIG_PATH = os.path.join(ROOT, 'pytorch/configs/abc_parameterizations', 'fc_ipllr_mnist.yaml')

In [3]:
import sys

# Make the project packages (utils, pytorch) importable. The membership check keeps the
# cell idempotent: re-running it does not append duplicate entries to sys.path.
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [4]:
import os
from copy import deepcopy
import torch
import math
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, Subset, DataLoader
import torch.nn.functional as F

from utils.tools import read_yaml, set_random_seeds
from pytorch.configs.base import BaseConfig
from pytorch.configs.model import ModelConfig
from pytorch.models.abc_params.fully_connected.ipllr import FcIPLLR
from pytorch.models.abc_params.fully_connected.muP import FCmuP
from pytorch.models.abc_params.fully_connected.ntk import FCNTK
from pytorch.models.abc_params.fully_connected.standard_fc_ip import StandardFCIP
from utils.data.mnist import load_data
from utils.abc_params.debug_ipllr import *

### Load the base configuration and define the experiment constants

In [5]:
# Experiment constants.
N_TRIALS = 1            # NOTE(review): not used elsewhere in this notebook
SEED = 30
L = 6                   # number of hidden layers; the network has L + 1 weight matrices
width = 1024
n_warmup_steps = 1      # written into config_dict['scheduler'] in the next cell
batch_size = 512
base_lr = 0.1
n_steps = 50            # NOTE(review): not used elsewhere in this notebook

# Seed the RNGs before data loading and model construction so the shuffled batches and the
# initial weights are reproducible. The YAML config is read in the next cell, not here.
set_random_seeds(SEED)

In [6]:
# Re-read the YAML config (keeps this cell idempotent), then override it with the
# experiment constants defined above.
config_dict = read_yaml(CONFIG_PATH)
input_size = config_dict['architecture']['input_size']

config_dict['architecture'].update(width=width, n_layers=L + 1)  # L hidden layers + output layer
config_dict['optimizer']['params']['lr'] = base_lr
config_dict['scheduler'] = dict(
    name='warmup_switch',
    params=dict(
        n_warmup_steps=n_warmup_steps,
        calibrate_base_lr=True,
        default_calibration=False,
    ),
)

base_model_config = ModelConfig(config_dict)

### Load data

In [7]:
training_dataset, test_dataset = load_data(download=False, flatten=True)
train_data_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)

# Materialise the loaders into lists so that every later cell indexes the same fixed batches.
# NOTE: keep this statement order. Creating a DataLoader iterator draws from torch's global
# RNG (in current PyTorch versions), so materialising the loaders in another order would
# change the shuffled training batches obtained after set_random_seeds.
test_batches = list(DataLoader(test_dataset, batch_size=batch_size, shuffle=False))
batches = list(train_data_loader)
eval_batch = test_batches[0]

### Define model

In [8]:
# Build the IPLLR model; the training batches are passed as `lr_calibration_batches`.
# NOTE(review): n_warmup_steps=12 here differs from the `n_warmup_steps` constant (1)
# written into config_dict['scheduler'] — confirm which value is intended.
ipllr = FcIPLLR(
    base_model_config,
    n_warmup_steps=12,
    lr_calibration_batches=batches,
)

initial base lr : [78.5, 47.14387893676758, 62.35230255126953, 61.66318893432617, 69.83497619628906, 80.47408294677734, 24.221656799316406]


In [9]:
# Multiply the first warm-up learning rate (presumably the input layer's) by (d + 1).
# NOTE(review): presumably this offsets the 1 / sqrt(d + 1) input-layer scaling used when
# defining W0 below — confirm.
# The flag stored on the model makes the cell idempotent: re-running it without rebuilding
# `ipllr` would otherwise compound the factor.
if not getattr(ipllr, '_input_warm_lr_rescaled', False):
    ipllr.scheduler.warm_lrs[0] = ipllr.scheduler.warm_lrs[0] * (ipllr.d + 1)
    ipllr._input_warm_lr_rescaled = True

### Save a copy of the initial model ($t=0$)

In [10]:
# Snapshot of the model at t=0 (before any training step); `ipllr` itself keeps training below.
ipllr_0 = deepcopy(ipllr)

### Train the model for one step ($t=1$)

In [11]:
# One training step on the first batch (t=0 -> t=1), then snapshot the model at t=1.
# Order matters: `train_model_one_step` (from the star import of debug_ipllr) updates
# `ipllr`'s weights in place, so the copy must be taken after it.
x, y = batches[0]
train_model_one_step(ipllr, x, y, normalize_first=True)
ipllr_1 = deepcopy(ipllr)

input abs mean in training:  0.6950533986091614
loss derivatives for model: tensor([[-0.9000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000, -0.9000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        ...,
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000, -0.9000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000,  0.1000,  0.1000],
        [ 0.1000,  0.1000,  0.1000,  ...,  0.1000, -0.9000,  0.1000]])
average training loss for model1 : 2.3025991916656494



### Train the model for a second step ($t=2$)

In [12]:
# A second training step on the second batch (t=1 -> t=2), then snapshot the model at t=2.
# Re-running this cell would train `ipllr` one step further; re-run from the model cell instead.
x, y = batches[1]
train_model_one_step(ipllr, x, y, normalize_first=True)
ipllr_2 = deepcopy(ipllr)

input abs mean in training:  0.6921874284744263
loss derivatives for model: tensor([[ 0.0605,  0.1368,  0.0974,  ...,  0.0852, -0.8384,  0.1191],
        [ 0.0461,  0.1547,  0.0935,  ...,  0.0766,  0.1982,  0.1259],
        [ 0.0579,  0.1399,  0.0969,  ...,  0.0837,  0.1676,  0.1204],
        ...,
        [ 0.0572,  0.1407,  0.0967,  ...,  0.0834,  0.1691,  0.1207],
        [ 0.0523, -0.8534,  0.0955,  ...,  0.0806,  0.1809,  0.1230],
        [ 0.0412,  0.1616,  0.0915,  ...,  0.0730, -0.7863,  0.1281]])
average training loss for model1 : 2.35217547416687



In [13]:
# Put every model snapshot in eval mode before probing it. A `for` statement produces no
# cell output, so the module reprs returned by `.eval()` are not displayed (no print() hack).
for model in (ipllr, ipllr_0, ipllr_1, ipllr_2):
    model.eval()




In [14]:
layer_scales = ipllr.layer_scales
# Attribute names of the intermediate layers l = 2..L inside `intermediate_layers`.
intermediate_layer_keys = [f"layer_{l:,}_intermediate" for l in range(2, L + 1)]

### Define $W^0$ and $b^0$ (layer-scaled initial weights and input-layer bias)

In [15]:
with torch.no_grad():
    # Layer-scaled initial weights W0[l] for l = 1..L+1; the input layer is additionally
    # divided by sqrt(d + 1).
    input_weight_0 = ipllr_0.input_layer.weight.data.detach()
    W0 = {1: layer_scales[0] * input_weight_0 / math.sqrt(ipllr_0.d + 1)}

    for l, key in zip(range(2, L + 1), intermediate_layer_keys):
        W0[l] = layer_scales[l - 1] * getattr(ipllr_0.intermediate_layers, key).weight.data.detach()

    W0[L + 1] = layer_scales[L] * ipllr_0.output_layer.weight.data.detach()

In [16]:
with torch.no_grad():
    # Layer-scaled initial input-layer bias, divided by sqrt(d + 1) like W0[1].
    input_bias_0 = ipllr_0.input_layer.bias.data.detach()
    b0 = layer_scales[0] * input_bias_0 / math.sqrt(ipllr_0.d + 1)

### Define $\Delta W^1$ and $\Delta b^1$ (change of the layer-scaled weights and input-layer bias during step 1)

In [17]:
with torch.no_grad():
    # Change of the layer-scaled weights during step 1 (t=0 -> t=1), layer by layer;
    # the input layer is additionally divided by sqrt(d + 1), as in W0.
    input_weight_1 = ipllr_1.input_layer.weight.data.detach()
    input_weight_0 = ipllr_0.input_layer.weight.data.detach()
    Delta_W_1 = {1: layer_scales[0] * (input_weight_1 - input_weight_0) / math.sqrt(ipllr_1.d + 1)}

    for l, key in zip(range(2, L + 1), intermediate_layer_keys):
        layer_1 = getattr(ipllr_1.intermediate_layers, key)
        layer_0 = getattr(ipllr_0.intermediate_layers, key)
        Delta_W_1[l] = layer_scales[l - 1] * (layer_1.weight.data.detach() - layer_0.weight.data.detach())

    Delta_W_1[L + 1] = layer_scales[L] * (ipllr_1.output_layer.weight.data.detach() -
                                          ipllr_0.output_layer.weight.data.detach())

In [18]:
with torch.no_grad():
    # Change of the layer-scaled input-layer bias during step 1, scaled like b0.
    bias_change_1 = ipllr_1.input_layer.bias.data.detach() - ipllr_0.input_layer.bias.data.detach()
    Delta_b_1 = layer_scales[0] * bias_change_1 / math.sqrt(ipllr_1.d + 1)

### Define $\Delta W^2$ and $\Delta b^2$ (change of the layer-scaled weights and input-layer bias during step 2)

In [19]:
with torch.no_grad():
    # Change of the layer-scaled weights during step 2 (t=1 -> t=2), layer by layer;
    # the input layer is additionally divided by sqrt(d + 1), as in W0.
    input_weight_2 = ipllr_2.input_layer.weight.data.detach()
    input_weight_1 = ipllr_1.input_layer.weight.data.detach()
    Delta_W_2 = {1: layer_scales[0] * (input_weight_2 - input_weight_1) / math.sqrt(ipllr_2.d + 1)}

    for l, key in zip(range(2, L + 1), intermediate_layer_keys):
        layer_2 = getattr(ipllr_2.intermediate_layers, key)
        layer_1 = getattr(ipllr_1.intermediate_layers, key)
        Delta_W_2[l] = layer_scales[l - 1] * (layer_2.weight.data.detach() - layer_1.weight.data.detach())

    Delta_W_2[L + 1] = layer_scales[L] * (ipllr_2.output_layer.weight.data.detach() -
                                          ipllr_1.output_layer.weight.data.detach())

In [20]:
with torch.no_grad():
    # Change of the layer-scaled input-layer bias during step 2 (t=1 -> t=2), scaled like b0.
    # The sqrt(d + 1) factor is taken from ipllr_2, consistently with Delta_W_2[1]
    # (presumably d is identical for all snapshots, so the value does not change).
    Delta_b_2 = layer_scales[0] * (ipllr_2.input_layer.bias.data.detach() -
                                   ipllr_1.input_layer.bias.data.detach()) / math.sqrt(ipllr_2.d + 1)

## Explore at step 2

### On examples from the second batch

In [21]:
# Probe the models on the second training batch (the batch used for the t=1 -> t=2 step).
x, y = batches[1]

In [22]:
with torch.no_grad():
    # Layer 1 on this batch: h0 uses the initial weights, delta_h_1 / delta_h_2 the weight
    # changes of steps 1 / 2; h1 / h2 are the actual pre-activations of the t=1 / t=2 models.
    h0 = {1: F.linear(x, W0[1], b0)}
    delta_h_1 = {1: F.linear(x, Delta_W_1[1], Delta_b_1)}
    delta_h_2 = {1: F.linear(x, Delta_W_2[1], Delta_b_2)}

    h1 = {1: layer_scales[0] * ipllr_1.input_layer.forward(x) / math.sqrt(ipllr_1.d + 1)}
    h2 = {1: layer_scales[0] * ipllr_2.input_layer.forward(x) / math.sqrt(ipllr_2.d + 1)}

    # Activations of the t=2 model per layer; x2[0] is the raw batch input.
    x2 = {0: x, 1: ipllr_2.activation(h2[1])}

In [23]:
# Sanity check: the layer-1 decomposition reproduces the actual t=1 and t=2 pre-activations.
layer1_checks = [(h0[1] + delta_h_1[1], h1[1]),
                 (h0[1] + delta_h_1[1] + delta_h_2[1], h2[1])]
for actual, expected in layer1_checks:
    torch.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)

In [24]:
# Elementwise product: negative entries mark units where the step-1 and step-2 increments disagree in sign.
prod_1 = torch.mul(delta_h_1[1], delta_h_2[1])

In [25]:
# Fraction of entries where the two increments have opposite signs.
is_negative = prod_1 < 0
is_negative.sum() / is_negative.numel()

tensor(0.5192)

In [26]:
# Fraction of negative entries in the step-1 increment at layer 1.
is_negative = delta_h_1[1] < 0
is_negative.sum() / is_negative.numel()

tensor(0.4914)

In [27]:
# Fraction of negative entries in the step-2 increment at layer 1.
is_negative = delta_h_2[1] < 0
is_negative.sum() / is_negative.numel()

tensor(0.9020)

In [28]:
h0[1][0].abs().mean()  # mean |h0| over layer-1 units, first example only

tensor(0.9147)

In [29]:
h1[1][0].abs().mean()  # mean |h1| over layer-1 units, first example only

tensor(1.2897)

In [30]:
h2[1][0].abs().mean()  # mean |h2| over layer-1 units, first example only

tensor(1.3521)

In [31]:
with torch.no_grad():
    # Hidden layers l = 2..L: decompose each pre-activation as h0[l] + delta_h_1[l] + delta_h_2[l]
    # and check it against h1[l] / h2[l]. Every layer is fed the t=2 activations x2[l-1], so h1[l]
    # is the t=1 layer applied to the t=2 activations, not the t=1 model's own pre-activation.
    # A dedicated name is used for the layer input so the batch input `x` is not overwritten
    # (otherwise re-running the layer-1 cell afterwards fails on a shape mismatch).
    for l, key in zip(range(2, L + 1), intermediate_layer_keys):
        layer_1 = getattr(ipllr_1.intermediate_layers, key)
        layer_2 = getattr(ipllr_2.intermediate_layers, key)
        x_prev = x2[l - 1]

        # No bias term here; presumably the intermediate layers are bias-free (the asserts below confirm it).
        h0[l] = F.linear(x_prev, W0[l])
        delta_h_1[l] = F.linear(x_prev, Delta_W_1[l])
        delta_h_2[l] = F.linear(x_prev, Delta_W_2[l])

        h1[l] = layer_scales[l - 1] * layer_1.forward(x_prev)
        h2[l] = layer_scales[l - 1] * layer_2.forward(x_prev)
        x2[l] = ipllr_2.activation(h2[l])

        torch.testing.assert_allclose(h0[l] + delta_h_1[l], h1[l], rtol=1e-5, atol=1e-5)
        torch.testing.assert_allclose(h0[l] + delta_h_1[l] + delta_h_2[l], h2[l], rtol=1e-5, atol=1e-5)

In [32]:
with torch.no_grad():
    # Output layer L+1: same decomposition as for the hidden layers, fed with x2[L].
    # A dedicated name is used so the batch input `x` is not overwritten.
    x_last = x2[L]
    h0[L + 1] = F.linear(x_last, W0[L + 1])
    delta_h_1[L + 1] = F.linear(x_last, Delta_W_1[L + 1])
    delta_h_2[L + 1] = F.linear(x_last, Delta_W_2[L + 1])
    h1[L + 1] = layer_scales[L] * ipllr_1.output_layer.forward(x_last)
    h2[L + 1] = layer_scales[L] * ipllr_2.output_layer.forward(x_last)
    # NOTE(review): applies the hidden activation to the output pre-activations; x2[L+1] is not used below.
    x2[L + 1] = ipllr_2.activation(h2[L + 1])

    torch.testing.assert_allclose(h0[L + 1] + delta_h_1[L + 1], h1[L + 1], rtol=1e-5, atol=1e-5)
    torch.testing.assert_allclose(h0[L + 1] + delta_h_1[L + 1] + delta_h_2[L + 1], h2[L + 1],
                                  rtol=1e-5, atol=1e-5)

##### Signs of the pre-activation increments

In [33]:
# Elementwise product of the step-1 and step-2 increments at layer 2 (negative = opposite signs).
prod_1 = torch.mul(delta_h_1[2], delta_h_2[2])

In [34]:
# Fraction of layer-2 entries where the two increments have opposite signs.
is_negative = prod_1 < 0
is_negative.sum() / is_negative.numel()

tensor(0.4599)

In [35]:
# Fraction of negative entries in delta_h_2 at layer 1.
is_negative = delta_h_2[1] < 0
is_negative.sum() / is_negative.numel()

tensor(0.9020)

In [36]:
# Fraction of negative entries in delta_h_2 at layer 2.
is_negative = delta_h_2[2] < 0
is_negative.sum() / is_negative.numel()

tensor(0.5661)

In [37]:
# Fraction of negative entries in delta_h_2 at layer 3.
is_negative = delta_h_2[3] < 0
is_negative.sum() / is_negative.numel()

tensor(0.4919)

In [38]:
# Fraction of negative entries in delta_h_2 at layer 4.
is_negative = delta_h_2[4] < 0
is_negative.sum() / is_negative.numel()

tensor(0.4131)

In [39]:
# Fraction of negative entries in delta_h_2 at layer 5 (numerator and denominator both use layer 5).
(delta_h_2[5] < 0).sum() / delta_h_2[5].numel()

tensor(0.3740)

In [40]:
# Fraction of negative entries in delta_h_2 at layer 6.
is_negative = delta_h_2[6] < 0
is_negative.sum() / is_negative.numel()

tensor(0.3447)

In [41]:
# Fraction of negative entries in delta_h_2 at layer 7.
is_negative = delta_h_2[7] < 0
is_negative.sum() / is_negative.numel()

tensor(0.4000)

In [42]:
delta_h_2[L + 1]  # step-2 increment of the output-layer pre-activations (layer L+1)

tensor([[ 0.0060, -0.0046,  0.0021,  ...,  0.0014, -0.0062, -0.0022],
        [ 0.0009, -0.0007,  0.0003,  ...,  0.0002, -0.0009, -0.0003],
        [ 0.0021, -0.0016,  0.0007,  ...,  0.0005, -0.0021, -0.0008],
        ...,
        [ 0.0009, -0.0007,  0.0003,  ...,  0.0002, -0.0009, -0.0003],
        [ 0.0144, -0.0110,  0.0051,  ...,  0.0033, -0.0148, -0.0054],
        [ 0.0019, -0.0014,  0.0007,  ...,  0.0004, -0.0019, -0.0007]])

##### Argmax units of the pre-activations

In [43]:
# Argmax unit of the t=0 output-layer (L+1) pre-activations, per example.
b = torch.max(h0[L + 1], dim=1).indices
b

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,

In [44]:
# Argmax unit of the step-1 increment at the output layer (L+1), per example.
b = torch.max(delta_h_1[L + 1], dim=1).indices
b

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,

In [45]:
# Argmax unit of the step-2 increment at the output layer (L+1), per example.
b = torch.max(delta_h_2[L + 1], dim=1).indices
b

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [46]:
# Argmax unit of the t=2 output-layer (L+1) pre-activations, per example.
b = torch.max(h2[L + 1], dim=1).indices
b

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,

In [47]:
# Argmax unit of the t=0 layer-1 pre-activations, per example.
b = torch.max(h0[1], dim=1).indices
b

tensor([ 441,  494,  958,   49,  223,  813,   92,  910,  456,  416,   12,   28,
         867,  603,  612,  291,  568,  541,   86, 1015,  630,  405,  834,  910,
         745,  291,  713,  166,  561,  543,  494,   86,  416,  828,  630,  568,
         494,  199,  910,  675,  117,  206,  855,  703,  910,  422,  187,  790,
          53,  910,  427,  206,  943,   86,  755,  416,  311,  568,  703,  723,
         731,  760,  661,  227,  836, 1015,  206,   11,  227,   42,  206,  813,
         910,  942,  745,  227,  510,  403,   86,    9,  117,  291,  209, 1018,
          86,  910,  227,  589,   19,  703,  144,  509,  227,  612,   86,  903,
         165,  206,   12,  958,  227,  706,  117, 1015,  703,  498,  958,  910,
         910,  790,  910,  154,  910,   86,  945,  713,  227,  706,  612,  731,
         813,  469,  346,  510,  985,  696,  782,  133,  755,  439,  910,  760,
          23,  910,  971,   86,  427,  723,  713,  612,  372,  976,  498,  416,
         910,  257,   49,  354,  235,  6

In [48]:
# Argmax unit of the step-1 increment at layer 1, per example.
b = torch.max(delta_h_1[1], dim=1).indices
b

tensor([541, 524, 602, 541, 541, 541, 356,  68, 524,  68, 524, 524, 602, 541,
         25, 602, 790,  68, 524, 524, 541, 541, 541, 541, 156, 524,  68, 524,
        541, 541,  68,  68,  68, 156, 524, 790, 524, 541, 541, 602, 839, 156,
        541, 541, 541,  68, 524, 541,  68,  68, 541, 156, 541, 524, 541,  25,
        356, 524, 541,  68, 602, 602, 356,  68, 524, 261,  68, 524,  68, 541,
        199, 524, 541, 524, 839,  68, 602, 839, 541, 524, 790, 541, 602, 602,
        541, 541,  68,  68, 541, 541,  68, 541,  68, 541, 541, 541, 524,  68,
        524, 541,  68, 602, 156, 261, 541, 541, 541,  68,  68, 541,  68, 541,
        541, 524, 541,  68,  68, 602, 524, 541, 524, 709, 602, 524, 541, 524,
        302, 541, 541, 156,  68, 602, 541,  68, 156,  68, 337,  68, 541, 528,
        528, 541,  68, 156,  68, 261, 524, 839,  68, 524,  68, 790, 602, 156,
        541, 541,  68, 356, 790, 524, 541, 541, 524, 156, 156, 541, 120, 541,
        524, 611, 541,  68, 541, 541, 541, 602, 541, 602, 524, 5

In [49]:
# Argmax unit of the step-2 increment at layer 1, per example.
b = torch.max(delta_h_2[1], dim=1).indices
b

tensor([  49,  104,   49,  336,  336,  336,  104,   49,  104,   49,  104,  336,
          49,  336,   49,   49,  104,   49,  104,  694,   49,  104,  336,  336,
         104,  104,   49,  104,  104,  104,   49,   49,   49,   49,  336,  104,
         336,   49,   49,   49, 1013,   49,  683,  336,   49,  336,  104,   49,
          49,   49,   49,   49,  336,  104,  336,   49,  104,  104,  336,  227,
          49,   49,  336,  227,   49,  336,   49,  104,  227,  336,   49,  336,
          49,  336,  104,   49,   49,  104,   49,  104,  104,   49,   49,   49,
         336,   49,   49,  227,   49,  336,   49,  104,  227,   49,  336,   49,
         104,   49,   49,   49,  227,   49,  104,  104,  336,   49,  336,   49,
          49,   49,  227,  104,  336,  104,  336,   49,   49,   49,  104,   49,
         336,   49,   49,  336,  336,  104,  104,   49,   49,  104,  227,   49,
         336,   49,  104,   49,   49,   49,   49,  104,  813,  104,   49,   49,
          49,   49,  336,   49,   49,  1

In [50]:
# Argmax unit of the t=2 layer-1 pre-activations, per example.
b = torch.max(h2[1], dim=1).indices
b

tensor([ 925,  156,  404,   49,  350,  220,  302,  120,  891,  120,  220,  417,
         962,  220,  528,  274,  932,   68,  648,  667,   39, 1000,  220,  515,
         547,  165,  801,  656,  482,  515,   68,  294,   68,  853,  350,  791,
         420,  220,  407,  404,  547,  853,  515,   68,  515,  660,  220,  220,
         120,  971,   39,  853,  567,  420,  220,  803,  220,  997,  231,   68,
         294,   76,  356,  227,  997,  779,  993,  853,   68,  220,   57,  460,
         220,  620,  482,   68,  962,  667,  515,  527,   70,  750,  183,  274,
         292,  220,  227,  971,  925,  482,  251,  220,  120,  896,  407,  709,
         165,  227, 1006,  515,   68,  361,   70,  361,  482,  783,  515,  120,
          68,  220,   68,  925,  515,  749,  667,   68,  227,  274,  220,  220,
         292,  220,   57,  228,  667,  165,  782,  274,  220,  482,   68,   76,
         331,  611,  932,  750,  611,  120,  361,  853,  372,  925,  120,  853,
         578,  801,   49,  979,  361,  7

In [51]:
# Argmax unit of the t=0 pre-activations at the last hidden layer (L), per example.
b = torch.max(h0[L], dim=1).indices
b

tensor([582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582,
        582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582,
        582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582,
        582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582,
        582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582,
        582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582,
        582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582,
        582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582,
        582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582,
        582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582,
        582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582,
        582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582,
        582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 582, 5

In [52]:
# Argmax unit of the step-1 increment at the last hidden layer (L), per example.
b = torch.max(delta_h_1[L], dim=1).indices
b

tensor([545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 5

In [53]:
# Argmax unit of the step-2 increment at the last hidden layer (L), per example.
b = torch.max(delta_h_2[L], dim=1).indices
b

tensor([215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215,
        215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215,
        215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215,
        215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215,
        215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215,
        215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215,
        215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215,
        215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215,
        215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215,
        215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215,
        215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215,
        215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215,
        215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 2

In [54]:
# Argmax unit of the t=2 pre-activations at the last hidden layer (L), per example.
b = torch.max(h2[L], dim=1).indices
b

tensor([545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545,
        545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 545, 5

In [55]:
# Argmax unit of the t=0 layer-2 pre-activations, per example.
b = torch.max(h0[2], dim=1).indices
b

tensor([ 517,  951,   25,  661,  531,   42,  928,  902,  924,  902,  798,  375,
         707,  611,  902,  960,  611,  778,  418,  974,  882,  843,  611,   42,
         306,  980,  902,  148,  611,  886,  491,  902,  902,  491,  377,  749,
         418,  632,  517,  742,  850,  491,  246,  843,  843,  902,  193,  913,
         902,  902,  902,  882,  778, 1016,  246,  902,  825,  924,  246,  446,
         517,  843,  611,  902,  989,  246,  902,  392,  797,  193,  902,  878,
         244,  902,  306,  797,  242,  596,  186,  519,  306,  902,  743,  843,
          42,  843,  778,  446,  895,  611,  797,  424,  446,  913,  895,  196,
         121,  491,  902,  797,  446,  350,  306,  950,  611,  997,  611,  446,
         882,   42,  838,  400,  334,  611,  611,  491,  882,  960,  193,   42,
         895,  679,  196,  517,  611,  974,  428,  902,  916,  306,  838,  740,
         531,  882,  227,  902,  902,  902,  997,  340,  381,  974,  601,  491,
         902,  491,  446,  381,  679,  1

In [56]:
len(b.unique()) / len(b)  # fraction of the batch accounted for by distinct argmax units

0.26953125

In [57]:
# Argmax unit of the step-1 increment at layer 2, per example.
b = torch.max(delta_h_1[2], dim=1).indices
b

tensor([797, 814, 797, 797, 893, 895, 895, 797, 895, 797, 895, 797, 577, 895,
        797, 797, 340, 797, 797, 895, 797, 340, 893, 797, 340, 797, 797, 814,
        340, 895, 797, 797, 797, 797, 797, 340, 797, 797, 797, 577, 340, 797,
        797, 797, 797, 797, 340, 893, 797, 797, 797, 797, 797, 340, 893, 797,
        797, 340, 797, 797, 797, 797, 797, 797, 895, 797, 797, 340, 797, 797,
        797, 895, 797, 797, 340, 797, 797, 340, 797, 814, 340, 797, 797, 797,
        895, 797, 797, 797, 797, 893, 797, 340, 797, 340, 797, 797, 797, 797,
        797, 895, 797, 797, 340, 797, 797, 797, 797, 797, 797, 797, 797, 895,
        797, 340, 895, 797, 797, 797, 340, 797, 895, 797, 797, 797, 893, 814,
        895, 797, 895, 340, 797, 797, 895, 797, 340, 797, 797, 797, 797, 203,
        340, 895, 797, 797, 797, 797, 797, 340, 797, 895, 797, 340, 895, 814,
        797, 895, 797, 895, 340, 340, 895, 893, 797, 340, 340, 797, 797, 797,
        895, 797, 895, 797, 895, 797, 797, 797, 895, 797, 895, 7

In [58]:
# Argmax unit of the step-2 increment at layer 2, per example.
b = torch.max(delta_h_2[2], dim=1).indices
b

tensor([523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523,
        523, 523, 523, 560, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523,
        523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523,
        523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523,
        523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 560, 523, 523, 523,
        560, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 560, 523, 523,
        523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 560,
        523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523,
        523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523,
        523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 560, 523, 523, 523,
        523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523,
        523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523, 523,
        523, 560, 523, 523, 523, 523, 523, 523, 523, 523, 523, 5

In [59]:
# Argmax unit of the t=2 layer-2 pre-activations, per example.
b = torch.max(h2[2], dim=1).indices
b

tensor([797, 814, 797, 797, 895, 895, 895, 797, 895, 797, 895, 797, 577, 895,
        902, 797, 340, 797, 797, 895, 797, 895, 893, 797, 340, 797, 797, 797,
        340, 895, 797, 797, 797, 797, 797, 340, 797, 797, 797, 577, 340, 797,
        797, 797, 797, 797, 340, 893, 797, 797, 797, 797, 797, 340, 893, 797,
        797, 340, 797, 797, 797, 797, 797, 797, 895, 797, 797, 340, 797, 797,
        797, 895, 797, 797, 340, 797, 797, 340, 797, 203, 340, 797, 797, 797,
        895, 797, 797, 797, 797, 893, 797, 895, 797, 340, 797, 797, 797, 797,
        797, 895, 797, 797, 340, 797, 893, 797, 797, 797, 797, 797, 797, 895,
        797, 340, 895, 797, 797, 797, 340, 797, 895, 797, 797, 797, 893, 895,
        895, 797, 895, 340, 797, 797, 895, 797, 340, 797, 797, 797, 797, 203,
        340, 895, 797, 797, 797, 797, 797, 340, 797, 340, 797, 340, 895, 814,
        895, 895, 797, 895, 340, 340, 895, 893, 797, 340, 340, 797, 797, 797,
        895, 797, 895, 797, 895, 797, 797, 797, 895, 797, 895, 7

In [60]:
len(b.unique())  # number of distinct argmax units across the batch

9

In [61]:
# Number of distinct argmax units across the batch, per layer (rows) and component (columns).
# A value of 1 means every example in the batch peaks at the same unit.
# Built in one shot (instead of row-by-row `.loc` assignment into an empty frame) so the
# columns get an integer dtype rather than `object`.
components = {'h0': h0, 'delta_h_1': delta_h_1, 'delta_h_2': delta_h_2, 'h1': h1, 'h2': h2}
layers = pd.Index(range(1, L + 2), name='layer')

df = pd.DataFrame(
    {name: [torch.max(values[l], dim=1).indices.unique().numel() for l in layers]
     for name, values in components.items()},
    index=layers,
)
df['batch_size'] = batch_size
df

Unnamed: 0_level_0,h0,delta_h_1,delta_h_2,h1,h2,batch_size
layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1,171,25,9,68,105,512
2,138,7,2,9,9,512
3,15,1,2,1,1,512
4,2,1,1,1,1,512
5,1,1,1,1,1,512
6,1,1,1,1,1,512
7,1,1,1,1,1,512


In [62]:
# Scratch check: 0.00195312 (~ 1/512) times the batch size is ~1, i.e. a single distinct argmax unit.
0.00195312 * batch_size

0.99999744

In [63]:
# Distinct argmax units of the t=0 output-layer (L+1) pre-activations over the batch.
b = torch.max(h0[L + 1], dim=1).indices
b.unique()

tensor([8])

In [64]:
# Fraction of the batch covered by distinct argmax units; divides by the actual number of
# examples instead of a hard-coded 512.
b.unique().numel() / b.numel()

0.001953125

##### Scales (mean absolute value) of the pre-activation components

In [65]:
# Mean absolute value over units of each component, per layer, for the FIRST example of the
# batch only. Built in one shot so the columns are float64 rather than `object`.
components = {'h0': h0, 'delta_h_1': delta_h_1, 'delta_h_2': delta_h_2, 'h1': h1, 'h2': h2}
layers = pd.Index(range(1, L + 2), name='layer')

df = pd.DataFrame(
    {name: [values[l][0, :].abs().mean().item() for l in layers]
     for name, values in components.items()},
    index=layers,
)
df

Unnamed: 0_level_0,h0,delta_h_1,delta_h_2,h1,h2
layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,0.91465,0.899165,0.539691,1.2897,1.35211
2,0.035248,0.642963,0.00274383,0.655749,0.654193
3,0.0247807,0.649087,0.00478905,0.660623,0.656214
4,0.0248684,0.636594,0.00525915,0.649044,0.644057
5,0.023723,0.618592,0.00547487,0.629891,0.62506
6,0.0236104,0.598841,0.00673273,0.608588,0.604272
7,0.0619903,0.0578427,0.00260525,0.117393,0.115223


In [66]:
# Layer 2, first example: fraction of units where h1 and the step-2 increment have opposite signs.
prod = h1[2][0] * delta_h_2[2][0]
is_negative = prod < 0
is_negative.sum() / is_negative.numel()

tensor(0.4443)

In [67]:
# Layer 3, first example: fraction of units where h1 and the step-2 increment have opposite signs.
prod = h1[3][0] * delta_h_2[3][0]
is_negative = prod < 0
is_negative.sum() / is_negative.numel()

tensor(0.4727)

In [68]:
# Layer 4, first example: fraction of units where h1 and the step-2 increment have opposite signs.
prod = h1[4][0] * delta_h_2[4][0]
is_negative = prod < 0
is_negative.sum() / is_negative.numel()

tensor(0.4102)

In [69]:
# Layer 5, whole batch (unlike the single-example cells above): fraction of opposite-sign entries.
prod = torch.mul(h1[5], delta_h_2[5])
is_negative = prod < 0
is_negative.sum() / is_negative.numel()

tensor(0.3735)

In [70]:
# Layer 6, whole batch: fraction of entries where h1 and the step-2 increment have opposite signs.
prod = torch.mul(h1[6], delta_h_2[6])
is_negative = prod < 0
is_negative.sum() / is_negative.numel()

tensor(0.3438)

In [71]:
# Output layer (L+1), example 132: fraction of units where h1 and the step-2 increment disagree in sign.
prod = h1[L + 1][132, :] * delta_h_2[L + 1][132, :]
(prod < 0).sum() / prod.numel()

tensor(0.8000)

In [72]:
# Fraction of negative t=1 output-layer (L+1) pre-activations for example 1.
# NOTE(review): the previous cell inspects example 132 — confirm which example is intended here.
(h1[L + 1][1, :] < 0).sum() / h1[L + 1][1, :].numel()

tensor(0.6000)