# imports

In [1]:
!pip install pytorch_lightning -qq

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m826.4/826.4 KB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m517.2/517.2 KB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
import re
import os
import itertools
from typing import Optional
from collections import OrderedDict
from functools import partial

#import git

import torch
from torch import nn, optim, linalg
from torch.nn import functional as F
from torchmetrics.functional import accuracy


import torchmetrics

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms

#from src.models.mlp import SkipMLP, SimpleMLP
#from src.data.mnist import MNISTDataModule


# geometric bounds

In [None]:
#from collections import OrderedDict
from functools import partial
#import torch
#from torch import nn, optim, linalg
from functorch import jacrev, vmap


def calc_bound(w, h, jh):
    """Per-sample bound ``(1 + ||h||^2) / (||w||_2^2 * ||jh||_F^2)``.

    Args:
        w: batch of matrices; the spectral norm (``ord=2``) is taken per batch entry.
        h: 2-D tensor; the vector norm is taken along ``dim=1``.
        jh: batch of matrices; the Frobenius norm is taken per batch entry.

    Returns:
        Tensor with one value per batch entry.
    """
    numerator = 1 + linalg.vector_norm(h, dim=1) ** 2
    spectral_sq = linalg.matrix_norm(w, ord=2) ** 2
    frobenius_sq = linalg.matrix_norm(jh, ord='fro') ** 2
    return numerator / (spectral_sq * frobenius_sq)

def calc_tight_bound(w, h, jh):
    """Per-sample bound ``(1 + ||h||^2) / ||w @ jh||_F^2``.

    Args:
        w: 3-D tensor, left operand of the batched matrix product.
        h: 2-D tensor; the vector norm is taken along ``dim=1``.
        jh: 3-D tensor, right operand of the batched matrix product.

    Returns:
        Tensor with one value per batch entry.
    """
    product = torch.bmm(w, jh)
    denominator = linalg.matrix_norm(product, ord='fro') ** 2
    return (1 + linalg.vector_norm(h, dim=1) ** 2) / denominator

def mlp_bound(weight_layer, partial_net, x, bound_f = calc_tight_bound):
    """Evaluate ``bound_f`` for a plain linear layer.

    Args:
        weight_layer: layer whose ``.weight`` matrix is used.
        partial_net: the layers that precede ``weight_layer``; it must support
            ``len()`` and be callable on the flattened input. When it is empty
            the hidden state is the input itself and its Jacobian the identity.
        x: input batch; flattened to ``(x.shape[0], -1)``.
        bound_f: callable ``(w, h, jh) -> bound``.

    Returns:
        ``(bound, (w, h_n, jh_n))`` where ``w`` is the weight repeated once per
        sample, ``h_n`` the output of ``partial_net`` and ``jh_n`` its
        per-sample Jacobian with respect to the flattened input.
    """
    batch_size = x.shape[0]
    h_0 = x.view(batch_size, -1)
    if len(partial_net) > 0:
        h_n = partial_net(h_0)
        jh_n = vmap(jacrev(partial_net))(h_0)
    else:
        h_n = h_0
        # Build the identity on the same device / dtype as the input so the
        # later batched matmul does not mix CPU and accelerator tensors.
        identity = torch.eye(h_0.shape[1], dtype=h_0.dtype, device=h_0.device)
        jh_n = identity.unsqueeze(0).repeat(batch_size, 1, 1)

    w = weight_layer.weight.repeat(batch_size, 1, 1)
    return bound_f(w, h_n, jh_n), (w, h_n, jh_n)

def skip_mlp_bound(weight_layer, partial_net, x, bound_f = calc_tight_bound):
    """Evaluate ``bound_f`` for a block exposing ``.fc`` and ``.activation``.

    The matrix handed to ``bound_f`` is ``fc.weight`` plus a diagonal term:
    an identity whose entries are scaled to ``1 / negative_slope`` wherever the
    pre-activation ``fc(h_n)`` is negative and left at ``1`` elsewhere.

    Args:
        weight_layer: block with an ``fc`` linear layer and an ``activation``
            that has a ``negative_slope`` attribute.
        partial_net: callable made of the layers that precede ``weight_layer``.
        x: input batch; flattened to ``(x.shape[0], -1)``.
        bound_f: callable ``(w, h, jh) -> bound``.

    Returns:
        ``(bound, (w, h_n, jh_n))``.
    """
    negative_slope = weight_layer.activation.negative_slope
    batch_size = x.shape[0]

    h_0 = x.view(batch_size, -1)
    h_n = partial_net(h_0)
    jh_n = vmap(jacrev(partial_net))(h_0)

    w_base = weight_layer.fc.weight.repeat(batch_size, 1, 1)
    # Create the identity next to the weights (same device / dtype) so that
    # `w_base + w_skip` also works when the model is not on the CPU.
    identity = torch.eye(*w_base.shape[1:], dtype=w_base.dtype, device=w_base.device)
    w_skip = identity.unsqueeze(0).repeat(batch_size, 1, 1)

    # 1 where the pre-activation is >= 0, 1 / negative_slope where it is < 0.
    pre_activation_is_negative = torch.gt(-weight_layer.fc(h_n), 0)
    scale = (1 + (pre_activation_is_negative * (negative_slope**(-1) - 1))).unsqueeze(1)
    w_skip = w_skip * scale

    w = w_base + w_skip
    return bound_f(w, h_n, jh_n), (w, h_n, jh_n)

def get_bounds(model, x):
    """Collect the bound of every ``fc*`` / ``sk*`` layer in ``model.layers``.

    Layers are dispatched on the first two characters of their registered
    name; every other layer (for example activations) is skipped. Each bound
    function receives the layer, a sequential container of all layers that
    come before it, and the input batch ``x``.

    Returns:
        Detached tensor: the per-layer bounds stacked and transposed, so that
        rows follow the bound functions' output axis and columns the layers.
    """
    named_layers = list(model.layers._modules.items())
    dispatch = {
        'fc': mlp_bound,
        'sk': skip_mlp_bound
    }
    per_layer_bounds = []
    for position, (layer_name, layer) in enumerate(named_layers):
        bound_fn = dispatch.get(layer_name[:2])
        if bound_fn is None:
            continue
        preceding = nn.Sequential(OrderedDict(named_layers[:position]))
        layer_bound, _parts = bound_fn(layer, preceding, x)
        per_layer_bounds.append(layer_bound)
    return torch.stack(per_layer_bounds).T.detach()

def compose_funcs(f, g):
    """Return the composition ``x -> g(f(x))``; ``f`` is applied first."""
    def composed(x):
        return g(f(x))
    return composed

def norm_grad_params(model):
    """Squared L2 norm of ``.grad`` for every parameter of ``model``.

    Returns a 1-D tensor with one entry per parameter, in
    ``model.parameters()`` order. Every parameter must already have a
    populated ``.grad``.
    """
    squared_norms = []
    for parameter in model.parameters():
        squared_norms.append((parameter.grad ** 2).sum())
    return torch.tensor(squared_norms)

def norm_grad_x(model, loss_fn, x, labels):
    """Squared L2 norm of the gradient of the loss with respect to ``x``.

    The loss is ``loss_fn(model(x), target=labels)`` and is differentiated
    with ``jacrev``. The result is a detached scalar tensor.
    """
    def loss_of_input(inputs):
        return loss_fn(model(inputs), target=labels)

    input_gradient = jacrev(loss_of_input)(x)
    return input_gradient.pow(2).sum().detach()

def norm_grads(model, loss_fn, optimizer, x, labels):
    """Squared gradient norms of ``loss_fn(model(x), labels)``.

    The gradients are obtained with ``torch.autograd.grad`` rather than
    ``loss.backward()``. ``backward()`` accumulates into whatever is already
    stored in ``param.grad``, so stale gradients from a previous step would
    leak into the reported norms; it also forces this function to wipe every
    ``.grad`` afterwards. ``autograd.grad`` neither reads nor writes
    ``param.grad``.

    Args:
        model: callable module; its ``parameters()`` are differentiated.
        loss_fn: callable ``(outputs, labels) -> scalar loss``.
        optimizer: unused; kept so existing callers keep working.
        x: input batch. A detached copy is differentiated, ``x`` is untouched.
        labels: targets forwarded to ``loss_fn``.

    Returns:
        ``(norm_grad_wrt_input, norm_grad_wrt_params)``: a 0-d tensor and a
        1-D tensor with one entry per parameter in ``model.parameters()``
        order. Parameters that are frozen or do not influence the loss
        contribute ``0``.
    """
    inputs = x.detach().clone().requires_grad_(True)

    loss = loss_fn(model(inputs), labels)

    params = list(model.parameters())
    trainable = [p for p in params if p.requires_grad]
    grads = torch.autograd.grad(loss, [inputs, *trainable], allow_unused=True)
    input_grad = grads[0]
    trainable_grads = iter(grads[1:])

    def _squared_norm(grad):
        # `allow_unused=True` yields None for tensors the loss does not depend on.
        if grad is None:
            return loss.new_zeros(())
        return (grad ** 2).sum()

    norm_grad_wrt_input = _squared_norm(input_grad)
    per_param = [
        _squared_norm(next(trainable_grads)) if p.requires_grad else _squared_norm(None)
        for p in params
    ]
    norm_grad_wrt_params = torch.stack(per_param) if per_param else loss.new_zeros(0)

    return (
        norm_grad_wrt_input,
        norm_grad_wrt_params
    )

def norm_gradients(model, loss_fn, inputs, labels):
    """Run ``norm_grads`` on every sample of a batch, one sample at a time.

    Each ``(input, label)`` pair gets a leading batch axis of size one before
    it is passed on. The per-sample result tuples are regrouped by position
    and stacked, so the return value is a list with one stacked tensor per
    element of the tuple returned by ``norm_grads``.
    """
    per_sample = [
        norm_grads(model, loss_fn, None, sample.unsqueeze(0), target.unsqueeze(0))
        for sample, target in zip(inputs, labels)
    ]
    return [torch.stack(column) for column in zip(*per_sample)]

# model definition

## block definitions

In [3]:
class SimpleMlpBlock(nn.Module):
    """A linear layer (``fc``) followed by a LeakyReLU (``activation``)."""

    def __init__(self, in_features, out_features, negative_slope):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.activation = nn.LeakyReLU(negative_slope)

    def forward(self, x):
        """Return ``activation(fc(x))``."""
        pre_activation = self.fc(x)
        return self.activation(pre_activation)

class SkipMlpBlock(nn.Module):
    """Residual block: the input plus ``activation(fc(input))``."""

    def __init__(self, in_features, out_features, negative_slope):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.activation = nn.LeakyReLU(negative_slope)

    def forward(self, x):
        """Return ``x + activation(fc(x))``."""
        branch = self.activation(self.fc(x))
        return x + branch

## simple model architecture

In [None]:
class MLP(nn.Module):
    """Multi-layer perceptron assembled from a configuration mapping.

    Keys read from ``config``:
        ``input_size``: number of features after flattening the input.
        ``hl_depth`` / ``hl_width``: number and width of the hidden layers.
        ``use_skip``: selects ``SkipMlpBlock`` over ``SimpleMlpBlock`` for
            hidden layers whose input and output widths are equal.
        ``negative_slope``: slope of every LeakyReLU.
        ``out_features``: width of the final linear layer.

    Registered layer names: a hidden layer with equal input and output width
    becomes one block named ``skip_fc<i>``; otherwise it becomes ``fc<i>``
    (linear) followed by ``af<i>`` (LeakyReLU). The final linear layer is
    ``fc<depth>`` and has no activation after it.
    """

    def __init__(self, config):
        super().__init__()
        next_layer_input = config["input_size"]#784
        hidden_layers = config['hl_depth'] * [config['hl_width']]
        negative_slope = config['negative_slope']
        # nn.Sequential only treats its single argument as a name -> module
        # mapping when it is an OrderedDict; a plain dict is rejected.
        layers = OrderedDict()
        # NOTE(review): equal-width hidden layers are registered as
        # 'skip_fc<i>' whether or not `use_skip` is set, and get_bounds
        # dispatches on the first two characters of the name, so it applies
        # skip_mlp_bound to SimpleMlpBlock layers as well; confirm intended.
        BlockType = SkipMlpBlock if config["use_skip"] else SimpleMlpBlock
        for _i, hidden_layer in enumerate(hidden_layers):
            if next_layer_input == hidden_layer:
                layers['skip_fc' + str(_i)] = BlockType(
                    in_features=next_layer_input,
                    out_features=hidden_layer,
                    negative_slope=negative_slope,
                )
            else:
                layers['fc' + str(_i)] = nn.Linear(in_features=next_layer_input, out_features=hidden_layer)
                layers['af' + str(_i)] = nn.LeakyReLU(negative_slope=negative_slope)
            # Update input size
            next_layer_input = hidden_layer

        layers['fc' + str(len(hidden_layers))] = nn.Linear(
            in_features=next_layer_input, out_features=config["out_features"]
        )
        self.layers = nn.Sequential(layers)

    def configure_optimizers(self):
        """Return an Adam optimizer (lr=1e-3) over this module's parameters."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def forward(self, x):
        """Flatten everything after the batch axis and apply the layers.

        Any input with a leading batch axis is accepted, not only 4-D tensors.
        """
        x_resized = x.reshape(x.size(0), -1)
        return self.layers(x_resized)

## Training run script

In [None]:
#from collections import OrderedDict
#import pdb
#import torch
#import pytorch_lightning as pl
#import src.func_geometric_bounds as gb

class RunBase(pl.LightningModule):
    """Lightning wrapper that trains ``model`` with a label-smoothed cross entropy.

    Every 16th training batch it also logs the layer-wise bounds from
    ``get_bounds`` next to per-sample gradient norms from ``norm_gradients``.
    """

    def __init__(self, model, config):
        super().__init__()
        # `config` is accepted but not used by this class.
        self.model = model

    def configure_optimizers(self):
        """Return an Adam optimizer (lr=1e-3) over all parameters."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def forward(self, x):
        return self.model(x)

    @staticmethod
    def _batch_accuracy(preds, target):
        """Fraction of samples whose arg-max prediction equals the target.

        Computed directly so it does not depend on the signature of
        ``torchmetrics.functional.accuracy``, which requires an explicit
        ``task`` argument in newer torchmetrics releases.
        """
        return (preds.argmax(dim=1) == target).float().mean()

    def training_step(self, batch, batch_idx):
        data, target = batch

        # Bound statistics
        if batch_idx % 16 == 0:
            # The helpers are defined in this notebook (no `gb` module is
            # imported) and get_bounds needs the wrapped network, because it
            # reads `model.layers`, which this wrapper does not have.
            bounds = get_bounds(self.model, data)
            gradients_data, gradients_params = norm_gradients(
                self.model, nn.CrossEntropyLoss(label_smoothing=0.1), data, target
            )
            self.log('bounds/Gradients_x', gradients_data.sum())
            self.log('bounds/Bound', bounds.sum())
            self.log('bounds/Bound div Gradients_x', (bounds.sum(dim=1)/gradients_data).mean())
            self.log('bounds/Gradients_x div Bound', (gradients_data/bounds.sum(dim=1)).mean())
            self.log('bounds/Gradients_x times bound',(
                gradients_data* bounds.sum(dim=1)
            ).sum(), on_step=True)
            self.log('bounds/Gradients parameters',gradients_params.sum(), on_step=True)

        preds = self(data)
        loss = F.cross_entropy(preds, target, label_smoothing=0.1)
        # Logging to TensorBoard by default
        self.log('train/acc', self._batch_accuracy(preds, target), on_step=True, on_epoch=True)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, valid_batch, batch_idx):
        data, target = valid_batch
        preds = self(data)
        loss = F.cross_entropy(preds, target, label_smoothing=0.1)
        self.log("validation/loss", loss)
        self.log('validation/acc', self._batch_accuracy(preds, target), on_step=True, on_epoch=True)
        return loss

# dataset definition

In [None]:
# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module serving MNIST train / val / test / predict loaders.

    Images are converted to tensors and normalised with the constants
    ``0.1307`` / ``0.3081`` (presumably the dataset mean and std; confirm).
    """

    def __init__(self, data_dir: str = "./", batch_size = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        to_tensor = transforms.ToTensor()
        normalise = transforms.Normalize((0.1307,), (0.3081,))
        self.transform = transforms.Compose([to_tensor, normalise])

    def prepare_data(self):
        """Download both the training and the held-out MNIST split."""
        for is_train_split in (True, False):
            MNIST(self.data_dir, train=is_train_split, download=True)

    def setup(self, stage: Optional[str] = None):
        """Create the datasets needed for ``stage`` (all of them when it is None)."""
        build_everything = stage is None

        # Training data is split 55000 / 5000 into train and validation parts.
        if build_everything or stage == "fit":
            full_training_set = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(full_training_set, [55000, 5000])

        # Test and predict both read the `train=False` split.
        if build_everything or stage == "test":
            self.mnist_test = self._held_out_split()

        if build_everything or stage == "predict":
            self.mnist_predict = self._held_out_split()

    def _held_out_split(self):
        """Return the ``train=False`` MNIST split with the module's transform."""
        return MNIST(self.data_dir, train=False, transform=self.transform)

    def _loader(self, dataset):
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        return DataLoader(dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        return self._loader(self.mnist_train)

    def val_dataloader(self):
        return self._loader(self.mnist_val)

    def test_dataloader(self):
        return self._loader(self.mnist_test)

    def predict_dataloader(self):
        return self._loader(self.mnist_predict)


# run

In [None]:
def main(config):
    """Train the model described by ``config`` on MNIST and log to TensorBoard.

    Keys read here: ``batch_size``, ``model_type`` (called as
    ``model_type(config)``), ``hl_depth`` (used in the logger version string)
    and ``max_epochs``. The whole mapping is also logged as hyper-parameters.
    """
    mnist = MNISTDataModule(data_dir="./data", batch_size=config['batch_size'])
    model = config['model_type'](config)
    # Name the run after the network class, taken before any wrapping below.
    model_name = type(model).__name__

    # `Trainer.fit` needs a LightningModule; a plain nn.Module is wrapped in
    # RunBase, which supplies the training / validation steps and optimizer.
    if not isinstance(model, pl.LightningModule):
        model = RunBase(model, config)

    logger = TensorBoardLogger(
        'lightning_logs/',
        name=model_name,
        version="depth"+str(config['hl_depth'])
    )
    logger.log_hyperparams(config)

    # `accelerator="cpu", devices=1` is the supported spelling of the
    # deprecated `num_processes=1` flag.
    trainer = pl.Trainer(max_epochs=config['max_epochs'],
                         accelerator='cpu',
                         devices=1,
                         logger=logger,
                         deterministic=True)
    trainer.fit(model, datamodule=mnist)

In [None]:
def _current_git_sha():
    """Return the HEAD commit hash of the enclosing git repository.

    Falls back to ``"unknown"`` when the ``git`` executable is missing or the
    working directory is not inside a repository, so the notebook also runs
    outside a checkout.
    """
    import subprocess

    try:
        return subprocess.check_output(
            ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL, text=True
        ).strip()
    except (OSError, subprocess.CalledProcessError):
        return "unknown"


pl.seed_everything(1234)
# Current git commit, stored with the hyper-parameters:
sha = _current_git_sha()
config = {
    # git revision
    'sha': sha,
    # dataset
    'batch_size': 32,
    'input_size': 28**2,   # flattened 28x28 image
    'out_features': 10,    # one output per digit class
    # model config
    'model_type': MLP,
    'use_skip': False,
    #'num_parameters': (28**2 + 10) * w + (h-1)*w**2 # for single P 28**2 * 10
    #'hl_depth': 2,
    #'hl_width': 40,
    'negative_slope': 0.01,
    # training
    'max_epochs': 20,
}
# w * ((28*2 + 10) + (h-1)*w) approx. h*w^2
# TODO: Find way to split depth versus wide.
# (h,w) = [(1,8), (4,4), (16,2), (64,1), () ]
# TODO: Change loop to set width and weight dependent on num parameters
# TODO: Perhaps around 10 layers of 40x40 size ish. And 1,2,4 to that.
# (2**4 * 10**4)
# OLD
def num_params(w, h):
    """Parameter-count estimate for width ``w`` and depth ``h`` (biases ignored)."""
    return (h-1)*w**2 + (28**2 + 10) * w

base_width = 8 * 40 //2
base_depth = 1

for n in range(1):
    config['hl_width'] = base_width // (2**n)
    config['hl_depth'] = base_depth * 2**(2*n) + int(
        2**(-0.5) * 2**(2**0.75*(n-1)) * (n>0) * (28**2 + 10)/config['hl_width'] #* (1-2**(-n))
    )
    #print(config['hl_width'])
    #print(config['hl_depth'])
    #print(num_params(config['hl_width'], config['hl_depth']))
    print(config)
    main(config)