In [1]:
!python train_models.py -h

usage: train_models.py [-h] --loop LOOP [--epochs EPOCHS] [--lr LR]
                       [--batch_size BATCH_SIZE] [--dataset DATASET]
                       [--optimizer OPTIMIZER] [--model MODEL]
                       [--wandb_mode WANDB_MODE] [--wandb_log WANDB_LOG]
                       [--wandb_log_freq WANDB_LOG_FREQ]

Train a model

options:
  -h, --help            show this help message and exit
  --loop LOOP           Loop over all the combinations of the datasets,
                        optimizers and models. 0: Disabled, 1: Enabled
  --epochs EPOCHS       Number of epochs to train for
  --lr LR               Learning rate for training
  --batch_size BATCH_SIZE
                        Batch size for training
  --dataset DATASET     Name of the dataset to train on: mnist, tmnist,
                        fashion_mnist, cifar10
  --optimizer OPTIMIZER
                        Name of the optimizer to train: SGD, HessianFree,
                        PB_BFGS, K_BFGS, K_LBFGS
 

In [39]:
!python train_models.py --loop 0 --epochs 5 --batch_size 32 --dataset tmnist --optimizer HessianFree --model SmallCNN --wandb_mode 1 --wandb_log 3

^C


In [None]:
!nohup python train_models --loop 0 --epochs 50 --batch_size 32 --dataset mnist --optimizer SGD --model SmallCNN --wandb_mode 0 --wandb_log 3 &

In [40]:
!python train_models.py --loop 0 --epochs 1 --batch_size 32 --dataset tmnist --optimizer HessianFree --model SmallCNN --wandb_mode 0 --wandb_log 3

-------


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

 50%|█████     | 1/2 [09:22<09:22, 562.89s/it]
100%|██████████| 2/2 [10:41<00:00, 278.22s/it]
100%|██████████| 2/2 [10:41<00:00, 320.93s/it]



New experiment started at 2023-05-25 17:08:29

Config: {'epochs': 2, 'learning_rate': 0.001, 'batch_size': 32, 'dataset': 'tmnist', 'optimizer': 'HessianFree', 'model': 'SmallCNN', 'architecture': 'CNN', 'wandb_log': None, 'wandb_log_freq': 1}

-------
Epoch: 0
-------
Train_loss: 6.42623 | Train_acc: 0.11 | Total_train_time: 547.9349539999967 |               Test_loss: 4.51977 | Test_acc: 0.12 | Total_test_time: 14.946251000001212

Epoch: 1
-------
Train_loss: 5.16695 | Train_acc: 0.13 | Total_train_time: 65.96289369999431 |               Test_loss: 4.51977 | Test_acc: 0.12 | Total_test_time: 12.983146700004



# Senario: SmallCNN with HessianFree

In [4]:
%load_ext autoreload
%autoreload 2

In [1]:
import sys
sys.path.append("..")

import torch
from torch import nn
import torchmetrics
import numpy as np
import src.engine as engine
import src.experiments_maker as experiments_maker
import wandb
import time
import utils.config_manager as cm


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
config = cm.create_config(epochs=1, batch_size=128, optimizer='K_BFGS', model='SmallCNN', wandb_log_batch=1)
with wandb.init(project="baselines_cnn", config=config, mode=cm.wandb_modes[0]):
    model, train_dataloader, test_dataloader, optimizer, criterion = experiments_maker.make(config, cm.device)
    engine.train(model, train_dataloader, test_dataloader, cm.loss_fn, optimizer, criterion, cm.device, config)

batch:  1
-------
New experiment started at 2023_06_21_09_53_34

Config: {'loop': 0, 'epochs': 1, 'lr': 0.001, 'batch_size': 128, 'dataset': 'mnist', 'optimizer': 'K_BFGS', 'model': 'SmallCNN', 'wandb_mode': 0, 'wandb_log': 3, 'wandb_log_freq': 0, 'wandb_log_batch': 1, 'slice_size': 1.0, 'activation_fn': 'Tanh', 'dropout': 0.0, 'checkpoints': 0}

-------


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 0
-------
Batch: 0/469
Loss: 2.3065879344940186
Accuracy: 0.13636364042758942
-------
Batch: 0/469
Loss: 2.2877602577209473
Accuracy: 0.18000000715255737
-------
Batch: 1/469
Loss: 2.2841403484344482
Accuracy: 0.2133333384990692
-------
Batch: 2/469
Loss: 2.28005313873291
Accuracy: 0.24166667461395264
-------
Batch: 3/469
Loss: 2.259582757949829
Accuracy: 0.30642858147621155
-------
Batch: 4/469
Loss: 2.2263941764831543
Accuracy: 0.3904545307159424
-------
Batch: 5/469
Loss: 2.1780245304107666
Accuracy: 0.5178506970405579
-------
Batch: 6/469
Loss: 2.119352102279663
Accuracy: 0.5838174819946289
-------


  0%|          | 0/1 [02:25<?, ?it/s]


KeyboardInterrupt: 

In [3]:
config = cm.create_config(epochs=1, batch_size=128, optimizer='SGD', model='SmallCNN', wandb_log_batch=32)
model, train_dataloader, test_dataloader, optimizer, criterion = experiments_maker.make(config, cm.device)

In [4]:
for (X, y) in train_dataloader:
    break
X.to(cm.device)
y.to(cm.device)
# Forward pass
y_pred = model(X)
# Calculate loss
loss = cm.loss_fn(y_pred, y)
# Optimizer zero grad
optimizer.zero_grad()
# Loss backward
loss.backward()
#optimizer.step()

In [5]:
grouped = zip(*[iter(model.parameters())]*2)
for l, (param1, param2) in enumerate(grouped):
    print(param1.grad[0])
    break

tensor([[[-7.8005e-04,  1.5315e-04,  7.8725e-04],
         [-1.0772e-03, -1.2037e-04,  7.2590e-04],
         [-9.3098e-04, -9.1770e-05,  3.3776e-04]]])


In [16]:
lw = model.conv1.weight.reshape(param1.size()[0], -1)
lw.is_leaf = 

AttributeError: attribute 'is_leaf' of 'torch._C._TensorBase' objects is not writable

In [15]:
lw.grad

  lw.grad


In [12]:
model.conv1.weight.grad[0]

tensor([[[-7.8005e-04,  1.5315e-04,  7.8725e-04],
         [-1.0772e-03, -1.2037e-04,  7.2590e-04],
         [-9.3098e-04, -9.1770e-05,  3.3776e-04]]])

In [9]:
model.layers_weights[0]['W'].grad

  model.layers_weights[0]['W'].grad


In [26]:
model.layers_weights[0]['W']

tensor([[ 0.1274,  0.1383, -0.0390,  0.1531, -0.0365,  0.0336, -0.0811,  0.0979,
          0.1469],
        [-0.1223,  0.1449,  0.0312,  0.1231,  0.0226,  0.0804, -0.0235,  0.1285,
          0.0246],
        [-0.0778,  0.0425, -0.0768, -0.0195, -0.0677,  0.1106, -0.1316, -0.0768,
         -0.0471],
        [-0.1002,  0.0157, -0.1646,  0.1505, -0.1416,  0.1287,  0.0277, -0.0541,
          0.1030],
        [ 0.0260,  0.1347,  0.0182, -0.0526,  0.0448, -0.0452,  0.0701,  0.1488,
          0.0963],
        [-0.0729,  0.0962,  0.0298,  0.0846, -0.1016, -0.1650, -0.0644, -0.1278,
          0.1368],
        [ 0.0480,  0.0690,  0.0527, -0.0029,  0.1304, -0.1184,  0.0105, -0.1138,
          0.0514],
        [-0.0574,  0.0511, -0.0347,  0.1382, -0.0988, -0.0994, -0.0994,  0.1499,
          0.0555],
        [ 0.1604, -0.1375, -0.1653, -0.1304, -0.1121,  0.0675,  0.0597,  0.1385,
         -0.0861],
        [-0.1136,  0.0884, -0.0674,  0.1012, -0.0396,  0.0953, -0.1295, -0.0841,
          0.0508],


In [16]:
x = torch.randn(128, 1, 28, 28)
print(x.shape)
x.unfold(3, 3, 1).shape

torch.Size([128, 1, 28, 28])


torch.Size([128, 1, 28, 26, 3])

In [8]:
def reshape_a_h(a, h, layers_params, kernel_size=3, stride=1):
    a_new = []
    h_new = []
    for l in range(len(a)):
        if layers_params[l] == 'conv':
            h_l = h[l].unfold(3, kernel_size, stride).unfold(2, kernel_size, stride).reshape(-1, kernel_size**2)
            a_l = h[l].unfold(3, kernel_size, stride).unfold(2, kernel_size, stride).reshape(-1, kernel_size**2)
        elif layers_params[l] == 'fc':
            h_l = h[l]
            a_l = a[l]
        a_new.append(a_l)
        h_new.append(h_l)
    return a_new, h_new


In [22]:
a = model.a
h = model.h
#a, h = reshape_a_h(a, h, model.layers_params)

In [10]:
for l in range(len(a)):
    print(a[l].reshape(-1, a[l].size()[1]).shape)

torch.Size([86528, 32])
torch.Size([67712, 32])
torch.Size([51200, 64])
torch.Size([128, 64])
torch.Size([128, 10])


In [23]:
h[0].shape

torch.Size([128, 1, 28, 28])

In [24]:
h[0].unfold(3, 3, 1).unfold(2, 3, 1).reshape(-1, 3**2).shape

torch.Size([86528, 9])

In [9]:
for l in range(len(a)):
    print(h[l].shape, a[l].shape)

torch.Size([86528, 9]) torch.Size([86528, 9])
torch.Size([2359296, 9]) torch.Size([2359296, 9])
torch.Size([1806336, 9]) torch.Size([1806336, 9])
torch.Size([128, 23104]) torch.Size([128, 64])
torch.Size([128, 64]) torch.Size([128, 10])


In [4]:
grouped = zip(*[iter(model.parameters())]*2)
for l, (param1, param2) in enumerate(grouped):
    print(param1.shape, param2.shape)
    print('In: ', param1.size()[1])
    print('Out: ', param1.size()[0])
    #print(l, param1.shape, param2.shape)

torch.Size([32, 1, 3, 3]) torch.Size([32])
In:  1
Out:  32
torch.Size([32, 32, 3, 3]) torch.Size([32])
In:  32
Out:  32
torch.Size([64, 32, 3, 3]) torch.Size([64])
In:  32
Out:  64
torch.Size([64, 23104]) torch.Size([64])
In:  23104
Out:  64
torch.Size([10, 64]) torch.Size([10])
In:  64
Out:  10


In [6]:
for l in range(model.numlayers):
    print(model.layers_weights[l]['W'].shape)
    print(model.layers_weights[l]['b'].shape)
    print(' ')

torch.Size([32, 9])
torch.Size([32])
 
torch.Size([32, 288])
torch.Size([32])
 
torch.Size([64, 288])
torch.Size([64])
 
torch.Size([64, 23104])
torch.Size([64])
 
torch.Size([10, 64])
torch.Size([10])
 


In [10]:
for h_l in model.h:
    print(h_l.shape)
print(' ')
for a_l in model.a:
    print(a_l.shape)

torch.Size([128, 1, 28, 28])
torch.Size([128, 32, 26, 26])
torch.Size([128, 32, 23, 23])
torch.Size([128, 23104])
torch.Size([128, 64])
 
torch.Size([128, 32, 26, 26])
torch.Size([128, 32, 23, 23])
torch.Size([128, 64, 20, 20])
torch.Size([128, 64])
torch.Size([128, 10])


In [52]:
# iterate through model.parameters() taking next 2 at a time
grouped = zip(*[iter(model.parameters())]*2)
for l, (param1, param2) in enumerate(grouped):
    if model._layers_params[l] == 'conv':
        #print(param2.reshape(param2.size()[0], -1).shape)
        homo_param = torch.cat((param1.reshape(param1.size()[0], -1), param2.unsqueeze(1)), dim=1)
        print(param1.grad.shape)
        print(param2.shape)
        print(homo_param.shape)
        print(' ')
    elif model._layers_params[l] == 'fc':
        homo_param = torch.cat((param1, param2.unsqueeze(1)), dim=1)
        print(param1.shape)
        print(param2.shape)
        print(homo_param.shape)
        print(' ')

torch.Size([32, 1, 3, 3])
torch.Size([32])
torch.Size([32, 10])
 
torch.Size([32, 32, 3, 3])
torch.Size([32])
torch.Size([32, 289])
 
torch.Size([64, 32, 3, 3])
torch.Size([64])
torch.Size([64, 289])
 
torch.Size([64, 23104])
torch.Size([64])
torch.Size([64, 23105])
 
torch.Size([10, 64])
torch.Size([10])
torch.Size([10, 65])
 


In [None]:
def update_params(self, update):
    '''
    Assumes that update is a flat 1D tensor all the deltas for all the parameters
    '''
    i = 0
    for layer in self.layers:
        W, b = tuple(layer.parameters())
        n_weights = W.numel() + b.numel()
        layer_update = update[i:i + n_weights].reshape(W.size()[0], -1)
        W_update = layer_update[:, :-1]
        b_update = layer_update[:, -1]
        if type(layer) is nn.Conv2d:
            W_update = W_update.reshape(*W.size())

        W.data.add_(W_update)
        b.data.add_(b_update)
        i += n_weights

In [4]:
model._layers_params

['conv', 'conv', 'conv', 'fc', 'fc']

In [10]:
for h_0 in model.a:
    print(h_0.shape)

torch.Size([128, 32, 26, 26])
torch.Size([128, 32, 23, 23])
torch.Size([128, 64, 20, 20])
torch.Size([128, 64])
torch.Size([128, 10])


In [12]:
model.h[3].size()[0]

128

In [5]:
h_upd = []
a_upd = []
# clear torch caches
torch.cuda.empty_cache()
for l in range(len(model.layers_params)):
    if model.layers_params[l]['name'] == 'conv':
        h_0 = model.h[l].unfold(3, 3, 1)
        h_0 = h_0.unfold(2, 3, 1)
        h_0 = h_0.reshape(-1, 9)
        ones = torch.ones(h_0.size()[0], 1)
        ones = ones.to(cm.device)
        h_0 = torch.cat([*[h_0 for _ in range(model.h[l].size(1))], ones], dim=1)
        A_temp = torch.matmul(torch.t(h_0), h_0) / h_0.size(0)
        print(A_temp.shape)
    elif model.layers_params[l] == 'fc':
        h_0 = model.h[l]
        ones = torch.ones(h_0.size()[0], 1)
        ones = ones.to(cm.device)
        h_0 = torch.cat([h_0.data, ones], dim=1)
        A_temp = torch.matmul(torch.t(h_0), h_0) / h_0.size(0)
        print(A_temp.shape)
print(' ')
# for l in range(len(model._layers)):
#     if model.a[l].dim() > 2:
#         a_upd.append(model.a[l].mean(dim=[2,3]))
#     else:
#         a_upd.append(model.a[l])
#     print(a_upd[l].shape)

 


In [10]:
h_0 = model.h[0]
print(h_0.shape)
h_0 = h_0.unfold(3, 3, 1)
print(h_0.shape)
h_0 = h_0.unfold(2, 3, 1)
print(h_0.shape)
h_0 = h_0.reshape(-1, 9)
print(h_0.shape)

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 26, 3])
torch.Size([128, 1, 26, 26, 3, 3])
torch.Size([86528, 9])


In [15]:
h_0 = model.h[0]
print(h_0.shape)
h_0 = h_0.unfold(3, 3, 1)
print(h_0.shape)
h_0 = h_0.unfold(2, 3, 1)
print(h_0.shape)
h_0 = h_0.mean(dim=[2,3])
print(h_0.shape)
h_0 = h_0.reshape(-1, 9)
print(h_0.shape)

torch.Size([128, 1, 28, 28])
torch.Size([128, 1, 28, 26, 3])
torch.Size([128, 1, 26, 26, 3, 3])
torch.Size([128, 1, 3, 3])
torch.Size([128, 9])


In [21]:
h_upd = []
a_upd = []
# clear torch caches
torch.cuda.empty_cache()
# time the forward pass
start = time.time()
for l in range(len(model.layers_params)):
    if model.layers_params[l]['name'] == 'conv':
        h_0 = model.h[l].unfold(3, 3, 1)
        h_0 = h_0.unfold(2, 3, 1)
        #h_0 = h_0.mean(dim=[2,3])
        h_0 = h_0.reshape(-1, 9)
        ones = torch.ones(h_0.size()[0], 1)
        ones = ones.to(cm.device)
        h_0 = torch.cat([*[h_0 for _ in range(model.h[l].size(1))], ones], dim=1)
        A_temp = torch.matmul(torch.t(h_0), h_0) / h_0.size(0)
        #print(A_temp.shape)
    elif model.layers_params[l]['name'] == 'fc':
        h_0 = model.h[l]
        ones = torch.ones(h_0.size()[0], 1)
        ones = ones.to(cm.device)
        h_0 = torch.cat([h_0.data, ones], dim=1)
        A_temp = torch.matmul(torch.t(h_0), h_0) / h_0.size(0)
        #print(A_temp.shape)
end = time.time() - start
print('Time taken: ', end)

Time taken:  2.2812726497650146


In [12]:
a_l = model.a[0]
print(a_l.shape)
a_l = a_l.mean(dim=[2,3])
print(a_l.shape)

torch.Size([96, 32, 26, 26])
torch.Size([96, 32])


In [13]:
a_l = model.a[0]
a_l = a_l.reshape(-1, model.a[0].size()[1])
print(a_l.shape)

torch.Size([64896, 32])


In [None]:
0.9164321422576904

In [23]:
def get_Al_size(l):
    if model._layers_params[l] == 'conv':
        return model.h[l].size(1) * 9 + 1
    elif model._layers_params[l] == 'fc':
        return model.h[l].size(1) + 1
print(get_Al_size(4))

65


In [None]:
def get_hl_size(l):
    if model._layers_params[l] == 'conv':
        return model.h[l].size(1) * 9 + 1
    elif model._layers_params[l] == 'fc':
        return model.h[l].size(1) + 1

In [24]:
h_upd = []
a_upd = []
# clear torch caches
torch.cuda.empty_cache()
for l in range(len(model._layers_params)):
    print(get_Al_size(l))

10
289
289
23105
65


In [38]:
p_mean = torch.mean(model.a[0], dim=0)

In [40]:
p_mean.shape
    

torch.Size([32, 26, 26])

In [69]:
params = []
for p in model.parameters():
    print(p.shape)
    params.append(p)

torch.Size([32, 1, 3, 3])
torch.Size([32])
torch.Size([32, 32, 3, 3])
torch.Size([32])
torch.Size([64, 32, 3, 3])
torch.Size([64])
torch.Size([64, 23104])
torch.Size([64])
torch.Size([10, 64])
torch.Size([10])


In [82]:
params[0].shape

torch.Size([32, 1, 3, 3])

In [91]:
params[0][0].sum()/9

tensor(0.1201, grad_fn=<DivBackward0>)

In [88]:
tau_params = torch.mean(params[0], dim=2)
tau_params = torch.mean(tau_params, dim=2)
print(tau_params.shape)
print(tau_params)

torch.Size([32, 1])
tensor([[ 0.1201],
        [ 0.0910],
        [-0.0765],
        [-0.0077],
        [ 0.0980],
        [-0.0409],
        [ 0.0282],
        [ 0.0011],
        [-0.0456],
        [-0.0219],
        [ 0.0073],
        [-0.1558],
        [-0.0660],
        [-0.1591],
        [ 0.0093],
        [-0.0949],
        [ 0.0441],
        [ 0.0455],
        [-0.0340],
        [-0.0450],
        [ 0.0576],
        [ 0.0283],
        [-0.0292],
        [-0.0171],
        [-0.0611],
        [ 0.0505],
        [-0.0766],
        [-0.0932],
        [-0.0143],
        [-0.0157],
        [ 0.0564],
        [-0.1206]], grad_fn=<MeanBackward1>)
