In [1]:
# Load IPython's autoreload extension; mode 2 re-imports all modules before each cell runs,
# so edits to the project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Every path below is derived from the kernel's current working directory, so the notebook
# has to be started from its usual location for these paths to resolve as intended.
cwd = os.getcwd()

# NOTE(review): NOTEBOOK_DIR is the *parent* of the working directory and ROOT sits three
# more levels above it -- confirm that this matches the repository layout.
NOTEBOOK_DIR = os.path.dirname(cwd)
ROOT = os.path.dirname(
    os.path.dirname(
        os.path.dirname(NOTEBOOK_DIR)
    )
)

# Output folder for figures and the YAML configuration read by the cells below.
FIGURES_DIR = os.path.join(
    ROOT,
    'figures/abc_parameterizations/initialization',
)
CONFIG_PATH = os.path.join(
    ROOT,
    'pytorch/configs/abc_parameterizations',
    'fc_ipllr_mnist.yaml',
)

In [3]:
import sys

# Make the project packages (`utils`, `pytorch`) importable from ROOT. The membership check
# keeps the cell idempotent: re-running it does not add duplicate entries to sys.path.
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [4]:
import os
from copy import deepcopy
import torch
import math
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, Subset, DataLoader
import torch.nn.functional as F

from utils.tools import read_yaml, set_random_seeds
from pytorch.configs.base import BaseConfig
from pytorch.configs.model import ModelConfig
from pytorch.models.abc_params.fully_connected.ipllr import FcIPLLR
from pytorch.models.abc_params.fully_connected.muP import FCmuP
from pytorch.models.abc_params.fully_connected.ntk import FCNTK
from pytorch.models.abc_params.fully_connected.standard_fc_ip import StandardFCIP
from utils.data.mnist import load_data
from utils.abc_params.debug_ipllr import *

### Load basic configuration and define variables 

In [5]:
# Experiment settings used throughout the notebook.
N_TRIALS = 1
SEED = 30
L = 6               # depth parameter: the config is given n_layers = L + 1
width = 1024        # written to config_dict['architecture']['width']
n_warmup_steps = 1
batch_size = 512    # batch size of the train and test DataLoaders
base_lr = 0.1       # written to config_dict['optimizer']['params']['lr']
n_steps = 50

set_random_seeds(SEED)  # set random seed for reproducibility
# The YAML config is read in the next cell, right before it is modified, so it is not
# loaded a second time here.

In [6]:
# Re-read the YAML file so that the overrides below always start from the on-disk values.
config_dict = read_yaml(CONFIG_PATH)

input_size = config_dict['architecture']['input_size']

# Override architecture and optimizer entries with the settings chosen in the previous cell.
config_dict['architecture'].update({'width': width, 'n_layers': L + 1})
config_dict['optimizer']['params'].update({'lr': base_lr})

# Build the model configuration and remove any learning-rate scheduler from it.
base_model_config = ModelConfig(config_dict)
base_model_config.scheduler = None

### Load data

In [7]:
# Load the MNIST train/test datasets without downloading (download=False), flattened (flatten=True).
training_dataset, test_dataset = load_data(download=False, flatten=True)
train_data_loader = DataLoader(training_dataset, shuffle=True, batch_size=batch_size)
# Materialise both loaders as lists so that a specific batch can be fetched by index later.
# NOTE(review): keep the statement order -- presumably creating/iterating a DataLoader draws
# from the global torch RNG, so reordering could change the shuffled batches; confirm.
test_batches = list(DataLoader(test_dataset, shuffle=False, batch_size=batch_size))
batches = list(train_data_loader)
eval_batch = test_batches[0]

### Define model

In [8]:
# Instantiate the fully-connected muP model from the base configuration built above.
muP = FCmuP(base_model_config)

In [9]:
# Multiply the learning rate of the optimizer's first parameter group by (muP.d + 1).
# NOTE(review): presumably group 0 holds the input-layer parameters and muP.d is the input
# dimension -- confirm against FCmuP.
# WARNING: the current value is rescaled in place, so executing this cell twice applies the
# factor twice; re-create `muP` before re-running it.
first_param_group = muP.optimizer.param_groups[0]
first_param_group['lr'] = first_param_group['lr'] * (muP.d + 1)

### Save initial model: t=0

In [10]:
# Independent copy of the untrained model; it is the t=0 reference for W0 and b0 below.
muP_0 = deepcopy(muP)

### Train model one step: t=1

In [11]:
# Run one training step of `muP` on the first training batch, then snapshot the model as t=1.
# NOTE(review): train_model_one_step presumably comes from the wildcard import of
# utils.abc_params.debug_ipllr -- confirm.
x, y = batches[0]
train_model_one_step(muP, x, y, normalize_first=True)
muP_1 = deepcopy(muP)

input abs mean in training:  0.6950533986091614
loss derivatives for model: tensor([[-0.8969,  0.1024,  0.1054,  ...,  0.0975,  0.0981,  0.0973],
        [ 0.1020, -0.8974,  0.1068,  ...,  0.0999,  0.1001,  0.0976],
        [ 0.1028,  0.1035,  0.1075,  ...,  0.0969,  0.1011, -0.9035],
        ...,
        [ 0.1007,  0.1015,  0.1078,  ...,  0.0990,  0.0991, -0.9043],
        [ 0.1040,  0.1026,  0.1052,  ...,  0.0975,  0.1001,  0.0976],
        [ 0.1053,  0.1027,  0.1094,  ...,  0.0963, -0.8984,  0.0981]])
average training loss for model1 : 2.303544521331787



### Train model for a second step: t=2

In [12]:
# Run a second training step of `muP`, on the second training batch, then snapshot it as t=2.
x, y = batches[1]
train_model_one_step(muP, x, y, normalize_first=True)
muP_2 = deepcopy(muP)

input abs mean in training:  0.6921874284744263
loss derivatives for model: tensor([[ 0.0947,  0.1087,  0.1084,  ...,  0.0948, -0.8999,  0.0992],
        [ 0.1012,  0.1073,  0.1296,  ...,  0.0830,  0.1024,  0.0911],
        [ 0.1007,  0.1040,  0.1072,  ...,  0.0943,  0.1014,  0.0981],
        ...,
        [ 0.1057,  0.0997,  0.1064,  ...,  0.0922,  0.1043,  0.0990],
        [ 0.0970, -0.8785,  0.1138,  ...,  0.0948,  0.0976,  0.0933],
        [ 0.1000,  0.1067,  0.1193,  ...,  0.0912, -0.8911,  0.0967]])
average training loss for model1 : 2.2230353355407715



In [13]:
# Call .eval() on the live model and on the three snapshots. A `for` loop is a statement,
# so the cell displays nothing and no dummy print() is needed to hide the last repr.
for model_snapshot in (muP, muP_0, muP_1, muP_2):
    model_snapshot.eval()




In [14]:
# Per-layer multipliers exposed by the model; the cells below index them from 0
# (layer_scales[l-1] is used for layer l).
layer_scales = muP.layer_scales
# Attribute names under which layers 2..L are looked up in `intermediate_layers`.
intermediate_layer_keys = list(map("layer_{:,}_intermediate".format, range(2, L + 1)))

### Define W0 and b0

In [15]:
# Scaled weights of the t=0 snapshot, keyed by layer index 1..L+1.
W0 = dict()

# Input layer: layer scale times raw weight, divided by sqrt(d + 1).
W0[1] = layer_scales[0] * muP_0.input_layer.weight.data.detach() / math.sqrt(muP_0.d + 1)

# Intermediate layers 2..L: layer scale times raw weight.
for l, key in zip(range(2, L + 1), intermediate_layer_keys):
    layer = getattr(muP_0.intermediate_layers, key)
    W0[l] = layer_scales[l-1] * layer.weight.data.detach()

# Output layer.
W0[L+1] = layer_scales[L] * muP_0.output_layer.weight.data.detach()

In [16]:
# Scaled input-layer bias of the t=0 snapshot (same scale and 1/sqrt(d + 1) factor as W0[1]).
b0 = layer_scales[0] * muP_0.input_layer.bias.data.detach() / math.sqrt(muP_0.d + 1)

### Define Delta_W_1 and Delta_b_1

In [17]:
with torch.no_grad():
    # Scaled weight change produced by the first training step (t=1 minus t=0),
    # keyed by layer index 1..L+1.
    Delta_W_1 = dict()

    # Input layer: scaled difference, divided by sqrt(d + 1).
    Delta_W_1[1] = (
        layer_scales[0]
        * (muP_1.input_layer.weight.data.detach() - muP_0.input_layer.weight.data.detach())
        / math.sqrt(muP_1.d + 1)
    )

    # Intermediate layers 2..L.
    for l, key in zip(range(2, L + 1), intermediate_layer_keys):
        layer_1 = getattr(muP_1.intermediate_layers, key)
        layer_0 = getattr(muP_0.intermediate_layers, key)
        Delta_W_1[l] = layer_scales[l-1] * (layer_1.weight.data.detach() - layer_0.weight.data.detach())

    # Output layer.
    Delta_W_1[L+1] = layer_scales[L] * (
        muP_1.output_layer.weight.data.detach() - muP_0.output_layer.weight.data.detach()
    )

In [18]:
with torch.no_grad():
    # Scaled change of the input-layer bias produced by the first training step (t=1 minus t=0).
    Delta_b_1 = layer_scales[0] * (muP_1.input_layer.bias.data.detach() -
                                   muP_0.input_layer.bias.data.detach()) / math.sqrt(muP_1.d + 1)

### Define Delta_W_2

In [19]:
with torch.no_grad():
    # Scaled weight change produced by the second training step (t=2 minus t=1),
    # keyed by layer index 1..L+1.
    Delta_W_2 = dict()

    # Input layer: scaled difference, divided by sqrt(d + 1).
    Delta_W_2[1] = (
        layer_scales[0]
        * (muP_2.input_layer.weight.data.detach() - muP_1.input_layer.weight.data.detach())
        / math.sqrt(muP_2.d + 1)
    )

    # Intermediate layers 2..L.
    for l, key in zip(range(2, L + 1), intermediate_layer_keys):
        layer_2 = getattr(muP_2.intermediate_layers, key)
        layer_1 = getattr(muP_1.intermediate_layers, key)
        Delta_W_2[l] = layer_scales[l-1] * (layer_2.weight.data.detach() - layer_1.weight.data.detach())

    # Output layer.
    Delta_W_2[L+1] = layer_scales[L] * (
        muP_2.output_layer.weight.data.detach() - muP_1.output_layer.weight.data.detach()
    )

In [20]:
with torch.no_grad():
    # Scaled change of the input-layer bias produced by the second training step (t=2 minus
    # t=1). The 1/sqrt(d + 1) factor reads `d` from muP_2, the same snapshot that the matching
    # weight increment Delta_W_2[1] uses.
    Delta_b_2 = layer_scales[0] * (muP_2.input_layer.bias.data.detach() -
                                   muP_1.input_layer.bias.data.detach()) / math.sqrt(muP_2.d + 1)

## Explore at step 2

### On examples from the second batch

In [21]:
# Inputs and labels of the second training batch (the batch used for the second training step).
x, y = batches[1]

In [22]:
with torch.no_grad():
    # Split the scaled first-layer pre-activation of the batch into the part produced by the
    # initial parameters (h0), by the first update (delta_h_1) and by the second (delta_h_2).
    # All dictionaries are keyed by layer index; x2[0] is the raw input.
    x2 = {0: x}
    h0 = {1: F.linear(x, W0[1], b0)}
    delta_h_1 = {1: F.linear(x, Delta_W_1[1], Delta_b_1)}
    delta_h_2 = {1: F.linear(x, Delta_W_2[1], Delta_b_2)}
    # The same pre-activation computed directly by the t=1 and t=2 snapshots.
    h1 = {1: layer_scales[0] * muP_1.input_layer.forward(x) / math.sqrt(muP_1.d + 1)}
    h2 = {1: layer_scales[0] * muP_2.input_layer.forward(x) / math.sqrt(muP_2.d + 1)}
    # Activations of the t=2 model after layer 1; they feed the next layer.
    x2[1] = muP_2.activation(h2[1])
    
    # Sanity check: the pieces must add up to what the snapshots compute.
    # NOTE(review): torch.testing.assert_allclose is deprecated in recent PyTorch releases in
    # favour of torch.testing.assert_close -- switch once the minimum torch version allows it.
    torch.testing.assert_allclose(h0[1] + delta_h_1[1], h1[1], rtol=1e-5, atol=1e-5)
    torch.testing.assert_allclose(h0[1] + delta_h_1[1] + delta_h_2[1], h2[1], rtol=1e-5, atol=1e-5)

In [23]:
# Element-wise product of the two first-layer updates; an entry is negative exactly where
# delta_h_1[1] and delta_h_2[1] have opposite signs.
prod_1 = delta_h_1[1] * delta_h_2[1]

In [24]:
# Fraction of entries of prod_1 that are negative.
(prod_1 < 0).sum() / prod_1.numel()

tensor(0.5487)

In [25]:
# Fraction of negative entries in delta_h_1[1].
(delta_h_1[1] < 0).sum() / delta_h_1[1].numel()

tensor(0.4941)

In [26]:
# Fraction of negative entries in delta_h_2[1].
(delta_h_2[1] < 0).sum() / delta_h_2[1].numel()

tensor(0.3548)

In [27]:
# Mean absolute value of h0[1] for the first example of the batch (row 0).
h0[1][0, :].abs().mean()

tensor(0.9147)

In [28]:
# Mean absolute value of h1[1] for the first example of the batch (row 0).
h1[1][0, :].abs().mean()

tensor(1.2890)

In [29]:
# Mean absolute value of h2[1] for the first example of the batch (row 0).
h2[1][0, :].abs().mean()

tensor(1.7712)

In [30]:
with torch.no_grad():
    # Repeat the decomposition for the intermediate layers 2..L. Each layer is fed the
    # activations of the t=2 model coming out of the previous layer (x2[l-1]).
    for l, key in zip(range(2, L + 1), intermediate_layer_keys):
        layer_1 = getattr(muP_1.intermediate_layers, key)
        layer_2 = getattr(muP_2.intermediate_layers, key)
        # A dedicated name is used for the layer input so that the batch tensor `x` is not
        # overwritten; the earlier cells that read `x` therefore stay re-runnable.
        layer_input = x2[l-1]

        # Contributions of the initial weights and of the two updates (no bias here).
        h0[l] = F.linear(layer_input, W0[l])
        delta_h_1[l] = F.linear(layer_input, Delta_W_1[l])
        delta_h_2[l] = F.linear(layer_input, Delta_W_2[l])

        # The same pre-activation computed directly by the t=1 and t=2 snapshots.
        h1[l] = layer_scales[l-1] * layer_1.forward(layer_input)
        h2[l] = layer_scales[l-1] * layer_2.forward(layer_input)
        x2[l] = muP_2.activation(h2[l])

        # Sanity check: the pieces must add up to what the snapshots compute.
        torch.testing.assert_allclose(h0[l] + delta_h_1[l], h1[l], rtol=1e-5, atol=1e-5)
        torch.testing.assert_allclose(h0[l] + delta_h_1[l] + delta_h_2[l], h2[l], rtol=1e-5, atol=1e-5)

In [31]:
with torch.no_grad():
    # Same decomposition for the output layer (index L+1), fed with the activations of the
    # t=2 model after layer L. A dedicated name is used for the layer input so that the batch
    # tensor `x` is not overwritten by this cell.
    output_input = x2[L]
    h0[L+1] = F.linear(output_input, W0[L+1])
    delta_h_1[L+1] = F.linear(output_input, Delta_W_1[L+1])
    delta_h_2[L+1] = F.linear(output_input, Delta_W_2[L+1])
    h1[L+1] = layer_scales[L] * muP_1.output_layer.forward(output_input)
    h2[L+1] = layer_scales[L] * muP_2.output_layer.forward(output_input)
    # NOTE(review): the activation is also applied to the output-layer values here; x2[L+1]
    # is not read by any cell shown in this notebook section -- confirm it is intended.
    x2[L+1] = muP_2.activation(h2[L+1])

    # Sanity check: the pieces must add up to what the snapshots compute.
    torch.testing.assert_allclose(h0[L+1] + delta_h_1[L+1], h1[L+1], rtol=1e-5, atol=1e-5)
    torch.testing.assert_allclose(h0[L+1] + delta_h_1[L+1] + delta_h_2[L+1], h2[L+1], rtol=1e-5, atol=1e-5)

##### Signs

In [32]:
# Element-wise product of the two layer-2 updates (negative where their signs differ).
# Note that this rebinds `prod_1`, which previously held the layer-1 product.
prod_1 = delta_h_1[2] * delta_h_2[2]

In [33]:
# Fraction of entries of prod_1 that are negative.
(prod_1 < 0).sum() / prod_1.numel()

tensor(0.5046)

In [34]:
# Fraction of negative entries in delta_h_2[1] (layer 1), shown first in the per-layer series below.
(delta_h_2[1] < 0).sum() / delta_h_2[1].numel()

tensor(0.3548)

In [35]:
# Fraction of negative entries in delta_h_2[2].
(delta_h_2[2] < 0).sum() / delta_h_2[2].numel()

tensor(0.3726)

In [36]:
# Fraction of negative entries in delta_h_2[3].
(delta_h_2[3] < 0).sum() / delta_h_2[3].numel()

tensor(0.3658)

In [37]:
# Fraction of negative entries in delta_h_2[4].
(delta_h_2[4] < 0).sum() / delta_h_2[4].numel()

tensor(0.3825)

In [38]:
# Fraction of negative entries in delta_h_2[5]; numerator and denominator both refer to
# layer 5 so the ratio stays correct even if layers differ in size.
(delta_h_2[5] < 0).sum() / delta_h_2[5].numel()

tensor(0.3916)

In [39]:
# Fraction of negative entries in delta_h_2[6].
(delta_h_2[6] < 0).sum() / delta_h_2[6].numel()

tensor(0.3795)

In [40]:
# Fraction of negative entries in delta_h_2[7] (index 7 is the output layer L+1 when L = 6).
(delta_h_2[7] < 0).sum() / delta_h_2[7].numel()

tensor(0.4650)

In [41]:
# Display the second-step contribution at index 7 (the output layer L+1 when L = 6).
delta_h_2[7]

tensor([[ 6.1566e-03,  1.8271e-03,  6.0032e-04,  ...,  1.0044e-03,
          6.4865e-03,  1.3175e-03],
        [ 2.2221e-02, -3.1599e-03,  2.5931e-03,  ..., -2.8529e-03,
          8.0231e-03,  9.9353e-05],
        [ 1.1129e-02,  1.3531e-04, -1.2471e-03,  ...,  3.8138e-04,
          6.5811e-03,  1.1236e-03],
        ...,
        [ 1.6294e-02, -1.8284e-03, -6.6554e-04,  ..., -6.8285e-04,
          6.1427e-03,  3.1220e-04],
        [ 6.5086e-03,  7.7496e-03,  2.3772e-03,  ...,  3.6495e-04,
          8.8420e-03, -3.3902e-04],
        [ 1.8702e-02, -2.1560e-03,  1.8816e-03,  ...,  1.1074e-04,
          9.8820e-03,  9.3896e-04]])

##### Outputs

In [42]:
# Per-example index of the largest entry of h0[7] (index 7 is the output layer L+1 when
# L = 6). argmax avoids binding the unused maximum values to `_`, the name IPython uses
# for the previous cell output.
b = h0[7].argmax(dim=1)
b

tensor([2, 6, 3, 2, 7, 8, 2, 1, 6, 1, 2, 2, 3, 0, 2, 3, 0, 1, 2, 2, 2, 2, 7, 2,
        0, 6, 2, 6, 0, 2, 2, 8, 1, 2, 2, 0, 2, 2, 2, 3, 0, 2, 7, 2, 7, 2, 0, 0,
        8, 2, 8, 2, 8, 0, 7, 2, 4, 0, 7, 1, 3, 3, 7, 1, 6, 7, 2, 0, 1, 7, 2, 4,
        8, 2, 0, 1, 2, 0, 8, 2, 0, 8, 3, 3, 2, 2, 1, 1, 2, 7, 2, 7, 1, 0, 8, 8,
        2, 2, 2, 8, 1, 8, 0, 2, 7, 8, 8, 1, 1, 8, 1, 2, 7, 0, 7, 1, 1, 3, 0, 8,
        6, 3, 8, 2, 7, 2, 2, 8, 8, 0, 1, 3, 0, 2, 0, 2, 2, 1, 8, 2, 0, 2, 2, 2,
        2, 2, 6, 0, 3, 2, 1, 0, 0, 2, 2, 9, 2, 4, 0, 0, 9, 7, 2, 0, 0, 2, 2, 9,
        2, 2, 7, 1, 9, 2, 7, 3, 9, 3, 6, 2, 0, 1, 2, 2, 7, 8, 3, 2, 2, 2, 0, 3,
        0, 8, 0, 2, 8, 2, 1, 3, 8, 0, 2, 2, 4, 7, 0, 7, 2, 6, 8, 7, 0, 7, 7, 2,
        2, 2, 7, 7, 8, 3, 2, 0, 2, 0, 2, 7, 0, 1, 6, 1, 1, 8, 8, 8, 2, 2, 8, 2,
        8, 2, 2, 2, 2, 2, 8, 9, 2, 0, 2, 8, 8, 2, 2, 6, 7, 0, 8, 3, 3, 8, 2, 1,
        2, 2, 0, 2, 7, 2, 2, 3, 0, 8, 0, 8, 8, 2, 0, 2, 0, 9, 3, 2, 2, 8, 8, 7,
        8, 7, 8, 2, 2, 0, 1, 2, 2, 0, 8,

In [43]:
# Per-example index of the largest entry of delta_h_1[7] (first-step contribution at the
# output layer when L = 6). argmax avoids binding the unused maximum values to `_`, the
# name IPython uses for the previous cell output.
b = delta_h_1[7].argmax(dim=1)
b

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,

In [44]:
# Per-example index of the largest entry of delta_h_2[7] (second-step contribution at the
# output layer when L = 6). argmax avoids binding the unused maximum values to `_`, the
# name IPython uses for the previous cell output.
b = delta_h_2[7].argmax(dim=1)
b

tensor([3, 0, 3, 0, 0, 0, 0, 3, 0, 3, 0, 0, 3, 0, 0, 3, 0, 3, 0, 0, 0, 0, 0, 0,
        0, 0, 3, 0, 0, 0, 3, 3, 3, 0, 0, 0, 0, 3, 3, 3, 0, 0, 0, 3, 3, 3, 0, 0,
        3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 0, 3, 0, 3, 3, 0, 3, 3, 3, 0,
        0, 0, 0, 3, 3, 0, 0, 0, 0, 0, 3, 3, 0, 3, 3, 3, 0, 0, 3, 0, 3, 0, 0, 0,
        0, 3, 0, 3, 3, 3, 0, 0, 0, 3, 0, 3, 3, 3, 3, 0, 3, 0, 0, 3, 3, 3, 0, 3,
        0, 3, 3, 0, 0, 0, 0, 0, 3, 0, 3, 3, 0, 3, 0, 3, 3, 3, 3, 0, 0, 0, 3, 0,
        3, 0, 0, 0, 3, 0, 3, 0, 0, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 3,
        0, 3, 0, 3, 0, 3, 3, 3, 0, 3, 0, 0, 0, 3, 0, 3, 3, 0, 3, 0, 0, 0, 0, 3,
        0, 3, 0, 0, 0, 0, 3, 3, 0, 0, 3, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 3, 3, 0,
        0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 3, 0, 3, 3, 0, 0, 0, 0, 3, 0, 0,
        3, 0, 3, 0, 0, 0, 0, 3, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 3, 3, 3, 0, 0, 3,
        0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
        3, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0,

In [45]:
# Per-example index of the largest entry of h2[7] (output-layer values of the t=2 snapshot
# when L = 6). argmax avoids binding the unused maximum values to `_`, the name IPython
# uses for the previous cell output.
b = h2[7].argmax(dim=1)
b

tensor([2, 6, 3, 2, 7, 8, 2, 1, 6, 1, 2, 2, 3, 0, 2, 3, 0, 1, 2, 2, 2, 2, 7, 2,
        0, 6, 2, 6, 0, 2, 2, 8, 1, 2, 2, 0, 2, 2, 2, 3, 0, 2, 7, 2, 7, 2, 0, 0,
        8, 2, 8, 2, 8, 0, 7, 2, 4, 0, 7, 1, 3, 3, 7, 1, 6, 7, 2, 0, 1, 7, 2, 4,
        8, 2, 0, 1, 8, 0, 8, 2, 0, 8, 3, 3, 2, 2, 1, 1, 2, 7, 8, 7, 1, 0, 8, 8,
        2, 2, 2, 8, 1, 8, 0, 2, 7, 8, 8, 1, 1, 8, 1, 2, 7, 0, 7, 1, 1, 3, 0, 8,
        6, 3, 8, 2, 7, 2, 2, 8, 8, 0, 1, 3, 0, 8, 0, 2, 2, 1, 8, 2, 0, 2, 2, 2,
        2, 2, 6, 0, 3, 2, 1, 0, 0, 2, 2, 9, 2, 4, 0, 0, 9, 7, 2, 0, 0, 8, 2, 9,
        2, 2, 7, 1, 9, 2, 7, 3, 9, 3, 6, 2, 0, 1, 2, 2, 7, 8, 3, 2, 2, 2, 0, 3,
        0, 8, 0, 2, 8, 2, 1, 3, 8, 0, 2, 2, 4, 7, 0, 7, 2, 6, 8, 7, 0, 7, 7, 2,
        2, 2, 7, 9, 8, 3, 2, 0, 2, 0, 2, 7, 0, 1, 6, 1, 1, 8, 8, 8, 2, 2, 8, 2,
        8, 2, 8, 2, 2, 2, 8, 9, 2, 0, 2, 8, 8, 2, 2, 2, 7, 0, 8, 3, 3, 8, 2, 1,
        2, 2, 0, 2, 7, 2, 2, 3, 0, 8, 0, 8, 8, 2, 0, 2, 0, 9, 3, 2, 2, 8, 8, 7,
        8, 7, 8, 0, 2, 0, 1, 2, 9, 0, 8,

In [46]:
# Per-example index of the largest entry of h0[1] (first layer, initial parameters).
# argmax avoids binding the unused maximum values to `_`, the name IPython uses for the
# previous cell output.
b = h0[1].argmax(dim=1)
b

tensor([ 441,  494,  958,   49,  223,  813,   92,  910,  456,  416,   12,   28,
         867,  603,  612,  291,  568,  541,   86, 1015,  630,  405,  834,  910,
         745,  291,  713,  166,  561,  543,  494,   86,  416,  828,  630,  568,
         494,  199,  910,  675,  117,  206,  855,  703,  910,  422,  187,  790,
          53,  910,  427,  206,  943,   86,  755,  416,  311,  568,  703,  723,
         731,  760,  661,  227,  836, 1015,  206,   11,  227,   42,  206,  813,
         910,  942,  745,  227,  510,  403,   86,    9,  117,  291,  209, 1018,
          86,  910,  227,  589,   19,  703,  144,  509,  227,  612,   86,  903,
         165,  206,   12,  958,  227,  706,  117, 1015,  703,  498,  958,  910,
         910,  790,  910,  154,  910,   86,  945,  713,  227,  706,  612,  731,
         813,  469,  346,  510,  985,  696,  782,  133,  755,  439,  910,  760,
          23,  910,  971,   86,  427,  723,  713,  612,  372,  976,  498,  416,
         910,  257,   49,  354,  235,  6

In [47]:
# Per-example index of the largest entry of delta_h_1[1] (first layer, first update).
# argmax avoids binding the unused maximum values to `_`, the name IPython uses for the
# previous cell output.
b = delta_h_1[1].argmax(dim=1)
b

tensor([ 541,  524,  602,  541,  541,  541,  356,   68,  524,   68,  524,  524,
         602,  541,   25,  602,  790,   68,  524,  524,  541,  541,  541,  541,
         156,  524,   68,  524,  541,  541,   68,   68,   68,  156,  524,  790,
         524,  541,  541,  602,  839,  853,  541,  541,  541,  120,  524,  541,
          68,   68,  541,  156,  541,  524,  541,  524,  524,  524,  541,   68,
         602,  602,  356,   68,  524,  356,   68,  524,   68,  541,  199,  524,
         541,  524,  839,   68,  602,  790,  541,  524,  790,  541,  602,  602,
         541,  541,   68,   68,  541,  541,   68,  541,   68,  541,  541,  541,
         524,   68,  524,  541,   68,  602,  790,  524,  541,  541,  541,   68,
          68,  541,   68,  541,  541,  524,  541,   68,   68,  602,  524,  541,
         524,  602,  602,  524,  541,  524,  302,  541,  541,  156,   68,  602,
         541,   68,  602,   68,  337,   68,  541,  524,  528,  541,   68,  156,
          68,  261,  524,  790,   68,  5

In [48]:
# Per-example index of the largest entry of delta_h_2[1] (first layer, second update).
# argmax avoids binding the unused maximum values to `_`, the name IPython uses for the
# previous cell output.
b = delta_h_2[1].argmax(dim=1)
b

tensor([ 908, 1022,  728,  908,  908,  908,  908,  824, 1022,  824, 1022, 1022,
         728,  167,  824,  728,  760,  824, 1022, 1022, 1022, 1022,  167,  908,
         760, 1022,  824, 1022,  187,  908,  824,  187,  824, 1022,  908,   18,
        1022,  187,  908,  728,  760, 1022,  908,  908,  908, 1022,  187,  187,
         187,  824, 1022, 1022,  167,  187,  167, 1022,  908, 1022,  908,  588,
         728,  728,  908,  588,  728,  908, 1022, 1022,  588,  908,  728,  908,
         728, 1022,  760,  824,  187,  760,  187, 1022,  230, 1022,  728,  728,
         908,  908,  824,  824,  908,  167,  187,  908,  824,  728,  908,  187,
        1022,  824, 1022,  187,  824,  187,  760,  824,  908, 1022,  908,  824,
         824,  187,  824,  908,  908,  230,  908,  824,  824,  728,  187,  187,
        1022,  187,  728, 1022,  167, 1022,  908, 1022,  187,   18,  824,  728,
        1022, 1022,  760,  187, 1022,  824, 1022, 1022,  760, 1022,  824,  824,
         824,  824, 1022,   18,  728, 10

In [49]:
# Per-example index of the largest entry of h2[1] (first layer, t=2 snapshot).
# argmax avoids binding the unused maximum values to `_`, the name IPython uses for the
# previous cell output.
b = h2[1].argmax(dim=1)
b

tensor([ 228, 1022,  187,  602,  187,  908, 1022,  228,  602,  681, 1022, 1022,
         728,  602, 1022,  187,  230,  801, 1022, 1022, 1022, 1022,  187,  657,
         230,  602,  801,  656,  230,  908,  494,   57,  801, 1022,  293,  230,
         494,  657,  228,  728,  760, 1022,   57,  154,  228,  494,  187,  790,
          57,  494,  908,  206,  908,  230,  187, 1022,  602,  230,   57,  681,
          57,  760,  293,  588,  188,  494,  206,  230,  894,   57,  206,  908,
         657,  765,  230,  228,  602,  230,   57, 1022,  230,  790,  507,  728,
         908,  908,  588,  588,  228,  187,   17,  177,  681,  790,  228,  507,
        1022,  878, 1022,  602,  681,  602,  230,  211,  167,  657,  657,  824,
          57,  808,  681,  908,  228,  230,  167,  801,  588,  602,  230,  602,
        1022,   57,  602,  228,  167, 1022, 1022, 1022,  602,  230,  588,  760,
        1022,  790,  187,   57,  529,   57,  228, 1022,  230, 1022,  494, 1022,
          57,  494,  524,  230,  187, 10

In [50]:
# Per-example index of the largest entry of h0[2] (second layer, initial weights).
# argmax avoids binding the unused maximum values to `_`, the name IPython uses for the
# previous cell output.
b = h0[2].argmax(dim=1)
b

tensor([ 156,  726,  726,  684,  769,  428,  684,  902,  271,  778,  726,  531,
         726,  769,  726,  726,  726,  778,  531,  883,  156, 1021,  769,  684,
         726,  428,  902, 1021,  726,  684,  156,  726,  778,  726,  883,  726,
         726,  769,  902,  726,  726,  883,  769,  902,  769,  531,  726,  726,
         902,  902,  769,  883,  531,  726,  769,  726,  684,  726,  246,  778,
         221, 1021,  246,  902,  726,  246,  902,  531,  778,  769,  726,  428,
         726,  531,  726,  902,  156,  726,  726,  726,  726,  726,  726,  726,
         428,  661,  778,  246,  246,  769,  726,  769,  778,  726,  531,  726,
         726,  726,  883,  769,  778,  726,  726,  156,  961,  726,  769,  902,
         778,  769,  778,  684,  769,  726,  684,  778,  902,  726,  726,  726,
         428,  221,  726,  156,  684,  726,  401,  726,  769,  726,  778,  726,
         531,  902,  726,  902,  902,  902,  156,  726,  726,  726,  778,  726,
         902, 1021,  531,  726,  726,  7

In [51]:
# Per-example index of the largest entry of delta_h_1[2] (second layer, first update).
# argmax avoids binding the unused maximum values to `_`, the name IPython uses for the
# previous cell output.
b = delta_h_1[2].argmax(dim=1)
b

tensor([797, 814, 895, 797, 895, 895, 895, 797, 814, 797, 814, 797, 340, 895,
        814, 895, 340, 797, 895, 895, 797, 895, 895, 895, 340, 814, 797, 814,
        340, 895, 797, 797, 797, 814, 797, 340, 814, 797, 797, 340, 340, 814,
        895, 797, 797, 797, 340, 340, 797, 797, 797, 814, 797, 340, 895, 814,
        797, 340, 797, 797, 797, 340, 797, 797, 814, 797, 797, 340, 797, 797,
        797, 895, 895, 797, 340, 797, 797, 340, 895, 814, 340, 797, 340, 895,
        895, 797, 797, 797, 797, 895, 797, 895, 797, 340, 895, 895, 814, 797,
        814, 895, 797, 340, 340, 797, 797, 797, 895, 797, 797, 895, 797, 895,
        797, 340, 895, 797, 797, 895, 340, 797, 895, 797, 814, 797, 895, 814,
        895, 895, 895, 340, 797, 340, 895, 797, 340, 797, 797, 797, 797, 814,
        340, 895, 797, 814, 797, 797, 797, 340, 814, 814, 797, 340, 340, 814,
        895, 895, 797, 895, 340, 340, 895, 895, 814, 340, 340, 797, 797, 895,
        895, 797, 895, 797, 895, 797, 895, 340, 895, 340, 814, 8

In [52]:
# Per-example index of the largest entry of delta_h_2[2] (second layer, second update).
# argmax avoids binding the unused maximum values to `_`, the name IPython uses for the
# previous cell output.
b = delta_h_2[2].argmax(dim=1)
b

tensor([469, 469, 469, 469, 469, 469, 469, 975, 469, 975, 469, 469, 469, 469,
        469, 469, 469, 975, 469, 469, 469, 469, 469, 469, 469, 469, 975, 469,
        469, 469, 975, 469, 975, 469, 469, 469, 469, 469, 469, 469, 469, 490,
        469, 469, 469, 469, 469, 469, 469, 975, 469, 469, 469, 469, 469, 469,
        469, 469, 469, 975, 469, 469, 469, 975, 469, 469, 469, 469, 975, 469,
        469, 469, 469, 469, 469, 975, 469, 469, 469, 469, 469, 469, 469, 469,
        469, 469, 975, 975, 469, 469, 469, 469, 975, 469, 469, 469, 469, 469,
        469, 469, 975, 469, 469, 469, 469, 469, 469, 975, 975, 469, 975, 469,
        469, 469, 469, 975, 975, 469, 469, 469, 469, 469, 469, 469, 469, 469,
        469, 469, 469, 469, 975, 469, 469, 469, 469, 469, 469, 975, 469, 469,
        469, 469, 975, 490, 469, 469, 469, 469, 469, 469, 975, 469, 469, 469,
        469, 469, 469, 469, 469, 469, 469, 469, 469, 469, 469, 469, 469, 469,
        469, 401, 469, 975, 469, 469, 469, 469, 469, 469, 469, 4

In [53]:
# Per-example index of the largest entry of h2[2] (second layer, t=2 snapshot).
# argmax avoids binding the unused maximum values to `_`, the name IPython uses for the
# previous cell output.
b = h2[2].argmax(dim=1)
b

tensor([ 156,  726,  726,  684,  769,  428,  684,  902,  271,  778,  726,  531,
         726,  769,  726,  726,  726,  778,  531,  883,  156, 1021,  769,  684,
         726,  428,  902, 1021,  726,  684,  156,  726,  778,  726,  883,  726,
         726,  769,  902,  726,  726,  883,  769,  902,  769,  531,  726,  726,
         902,  902,  769,  883,  531,  726,  769,  726,  684,  726,  246,  778,
         221, 1021,  246,  902,  726,  246,  902,  531,  778,  769,  726,  428,
         726,  531,  726,  902,  156,  726,  726,  726,  726,  726,  726,  726,
         428,  661,  778,  246,  246,  769,  726,  769,  778,  726,  531,  726,
         726,  726,  883,  769,  778,  726,  726,  156,  961,  726,  769,  902,
         778,  769,  778,  684,  769,  726,  684,  778,  902,  726,  726,  726,
         428,  221,  726,  156,  684,  726,  401,  726,  769,  726,  778,  726,
         531,  902,  726,  902,  902,  902,  156,  726,  726,  726,  778,  726,
         902, 1021,  531,  726,  726,  7

In [54]:
# Fraction of examples for which the arg-max entry at index 7 (the output layer when L = 6)
# is the same for h2 as for h0. argmax is used so that `_`, the name IPython uses for the
# previous cell output, is not overwritten with the unused maximum values.
max_h2 = h2[7].argmax(dim=1)
max_h0 = h0[7].argmax(dim=1)
(max_h2 == max_h0).sum() / max_h2.numel()

tensor(0.9648)

In [55]:
# For every layer 1..L+1 and every quantity, count how many distinct units are the row-wise
# arg-max across the batch. A count that is small compared with batch_size means that most
# examples share the same arg-max unit.
columns = ['h0', 'delta_h_1', 'delta_h_2', 'h1', 'h2']
df = pd.DataFrame(columns=columns, index=range(1, L+2))
df.index.name = 'layer'

for l in df.index:
    # argmax gives the indices directly, so nothing is bound to `_` (IPython's name for the
    # previous cell output) and no intermediate dictionary is needed.
    df.loc[l, columns] = [quantity[l].argmax(dim=1).unique().numel()
                          for quantity in (h0, delta_h_1, delta_h_2, h1, h2)]

# Reference column: the configured batch size.
df['batch_size'] = batch_size
df

Unnamed: 0_level_0,h0,delta_h_1,delta_h_2,h1,h2,batch_size
layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1,171,25,15,69,42,512
2,21,4,4,21,21,512
3,47,1,4,46,46,512
4,17,4,2,17,16,512
5,18,3,3,18,18,512
6,10,3,3,10,10,512
7,9,1,2,9,9,512


##### Scales

In [56]:
# Mean absolute value of every quantity at every layer 1..L+1, measured on the first
# example of the batch only (row 0).
columns = ['h0', 'delta_h_1', 'delta_h_2', 'h1', 'h2']
df = pd.DataFrame(columns=columns, index=range(1, L+2))
df.index.name = 'layer'
for l in df.index:
    df.loc[l, columns] = [quantity[l][0, :].abs().mean().item()
                          for quantity in (h0, delta_h_1, delta_h_2, h1, h2)]
df

Unnamed: 0_level_0,h0,delta_h_1,delta_h_2,h1,h2
layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,0.91465,0.902734,1.15864,1.28904,1.77124
2,2.4183,0.00362996,0.0104477,2.41916,2.42351
3,2.70222,0.00540379,0.0150112,2.70288,2.70762
4,2.86657,0.0064816,0.0172351,2.86723,2.87247
5,3.14571,0.00661806,0.0177092,3.14623,3.15094
6,3.22414,0.00611874,0.0160097,3.2243,3.22706
7,0.0846157,0.00222558,0.00523002,0.0862315,0.0902302


In [57]:
# Layer 2, first example only (row 0): fraction of units where h1 and the second-step
# update delta_h_2 have opposite signs.
prod = h1[2][0, :] * delta_h_2[2][0, :]
(prod < 0).sum() / prod.numel()

tensor(0.3994)

In [58]:
# Layer 3, first example only (row 0): fraction of units where h1 and delta_h_2 have
# opposite signs.
prod = h1[3][0, :] * delta_h_2[3][0, :]
(prod < 0).sum() / prod.numel()

tensor(0.4199)

In [59]:
# Layer 4, first example only (row 0): fraction of units where h1 and delta_h_2 have
# opposite signs.
prod = h1[4][0, :] * delta_h_2[4][0, :]
(prod < 0).sum() / prod.numel()

tensor(0.4131)

In [60]:
# Layer 5: fraction of entries where h1 and delta_h_2 have opposite signs.
# NOTE(review): this uses every example in the batch, whereas the layer 2-4 cells use
# row 0 only -- confirm which of the two is intended.
prod = h1[5] * delta_h_2[5]
(prod < 0).sum() / prod.numel()

tensor(0.4055)

In [61]:
# Layer 6: fraction of entries where h1 and delta_h_2 have opposite signs.
# NOTE(review): this uses every example in the batch, whereas the layer 2-4 cells use
# row 0 only -- confirm which of the two is intended.
prod = h1[6] * delta_h_2[6]
(prod < 0).sum() / prod.numel()

tensor(0.4199)

In [62]:
# Index 7 (the output layer when L = 6), example at row 132 only: fraction of entries
# where h1 and delta_h_2 have opposite signs.
prod = h1[7][132, :] * delta_h_2[7][132, :]
(prod < 0).sum() / prod.numel()

tensor(0.3000)

In [63]:
# Fraction of negative entries of h1[7] (the output layer when L = 6) for the example at row 1.
(h1[7][1, :] < 0).sum() / h1[7][1, :].numel()

tensor(0.6000)