In [13]:
import torch.nn as nn
import torch 
from typing import Iterable, List
import numpy as np

from torch.utils.data import Dataset, Sampler
from torch.utils.data.dataloader import _collate_fn_t, _worker_init_fn_t

from torchvision.transforms import PILToTensor

from typing import Any
import lightning as L
import torch.utils.data as tud
from torchvision.datasets import MNIST
import torch.optim as opt
import torch.nn.functional as F

import wandb


In [14]:

class Mlp(nn.Module):
    def __init__(self, dims, *args, **kwargs) -> None:
        super().__init__()

        self.linears = nn.ModuleList([
            nn.Linear(a, b) for a, b in zip(dims[:-1], dims[1:])
        ])

        self.act = nn.SELU()

    def forward(self, x):
        for layer in self.linears[:-1]:
            x = layer(x)
            x = self.act(x)

        final = self.linears[-1]
        return final(x)


class ResMlp(Mlp):
    def __init__(self, dims, *args, **kwargs) -> None:
        super().__init__(dims, *args, **kwargs)

    def forward(self, x):
        x = self.linears[0](x)
        x = self.act(x)
        for layer in self.linears[1:-1]:
            z = layer(x)
            x = self.act(z) + x

        final = self.linears[-1]
        return final(x)

In [15]:

from typing import Any


class Model(L.LightningModule):
    def __init__(self, n_layers, hidden_dim, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.mlp = ResMlp([28*28] + self.n_layers * [self.hidden_dim]+[10])

        self.save_hyperparameters()

    def training_step(self, batch, batch_index, *args: Any, **kwargs: Any):
        # return super().training_step(*args, **kwargs)
        x, y = batch
        x = x.view(-1, 784)

        logits = self.mlp(x)
        labels = F.one_hot(y, num_classes=10).type(torch.float32)

        loss = nn.functional.cross_entropy(logits, labels)

        return loss
    
    def test_step(self, batch, *args):
        x, y = batch
        x = x.view(-1, 784)

        logits = self.mlp(x)

        top1_error = (torch.argmax(logits, -1) != y).sum()/y.shape[0]
        self.log('test/top1_error', top1_error)

        labels = F.one_hot(y, num_classes=10).type(torch.float32)




        loss = nn.functional.cross_entropy(logits, labels)

        self.log('test/cce_loss', loss)

        return loss

    
    def on_train_start(self) -> None:
        # return super().on_train_start()
        self.init_params = {k:v.detach().clone() for k,v in self.state_dict().items() if 'weight' in k and v.shape[0]==v.shape[1] }
    def on_train_end(self) -> None:
        # return super().on_train_end()
        keys = list(self.init_params.keys())
        params = self.state_dict()
        diffs = [torch.abs(params[k] - self.init_params[k]).flatten().tolist() for k in keys]

        _, bins = np.histogram(diffs, bins=10)
        hists = [np.histogram(c, bins)[0] for c in diffs]

        # wandb.log({
        #     'changes/weight_delta': wandb.Table(keys, list(zip(*diffs)))
        # }
        # )

        self.store = diffs
        for k, h in enumerate(hists):
            wandb.log({
                'changes/weights': wandb.Histogram(np_histogram=(h, bins)), 'layer': k
            })


    def configure_optimizers(self):

        optim = opt.AdamW(self.mlp.parameters())
        return optim

In [16]:
dummy_model = ResMlp([28*28, 1, 10])

X = torch.zeros((1, 28*28))
dummy_model(X)

tensor([[ 0.4241, -0.8525,  0.2553,  0.0081,  0.2327,  0.4389,  0.3660,  0.4610,
          0.7931, -0.3177]], grad_fn=<AddmmBackward0>)

In [17]:


class MnistData(tud.Dataset):
    def __init__(self, train):
        self.mnist = MNIST('mnist', train=train)
        self.transform = PILToTensor()

    def __len__(self) -> int:
        return len(self.mnist)

    def __getitem__(self, index):
        im, label = self.mnist[index]
        return  (self.transform(im)/255).type(torch.float32), label

In [18]:
model = Model(20, 5)
model.compile()

data = tud.DataLoader(MnistData(True), batch_size=256)
test_data = tud.DataLoader(MnistData(False), batch_size=1024)

logger = L.pytorch.loggers.WandbLogger()

trainer = L.Trainer(max_epochs=2, logger=logger)
logger.watch(model)

trainer.fit(model, data, test_data)



GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/lightning/pytorch/loggers/wandb.py:390: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:72: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:652: Checkpoint directory ./lightning_logs/z5r4omma/checkpoints exists and is not empty.

  | Name | Type   | P

Epoch 1:  86%|████████▌ | 201/235 [00:07<00:01, 27.90it/s, v_num=omma]

In [None]:
trainer.test(model, test_data)

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 30.69it/s]


[{'test/top1_error': 0.20430000126361847, 'test/cce_loss': 0.7030673027038574}]

In [None]:
diffs = model.store

In [None]:
_, bins = np.histogram(diffs, bins=10)
hists = [np.histogram(c, bins)[0] for c in diffs]

In [None]:
hists

[array([9, 9, 4, 2, 0, 0, 1, 0, 0, 0]),
 array([6, 6, 6, 3, 3, 1, 0, 0, 0, 0]),
 array([10,  7,  2,  2,  4,  0,  0,  0,  0,  0]),
 array([8, 9, 4, 3, 1, 0, 0, 0, 0, 0]),
 array([10, 11,  1,  2,  0,  0,  1,  0,  0,  0]),
 array([10, 10,  3,  1,  1,  0,  0,  0,  0,  0]),
 array([9, 9, 1, 2, 3, 1, 0, 0, 0, 0]),
 array([11,  7,  4,  1,  0,  2,  0,  0,  0,  0]),
 array([10,  9,  2,  1,  1,  0,  1,  0,  0,  1]),
 array([9, 8, 4, 2, 2, 0, 0, 0, 0, 0]),
 array([8, 6, 6, 2, 1, 2, 0, 0, 0, 0]),
 array([12,  8,  3,  0,  2,  0,  0,  0,  0,  0]),
 array([10,  4,  1,  2,  3,  1,  2,  1,  1,  0]),
 array([8, 7, 5, 2, 3, 0, 0, 0, 0, 0]),
 array([6, 9, 5, 2, 1, 1, 0, 0, 1, 0]),
 array([12,  2,  3,  4,  1,  1,  1,  0,  1,  0]),
 array([11,  7,  1,  3,  3,  0,  0,  0,  0,  0]),
 array([3, 4, 2, 5, 6, 3, 1, 1, 0, 0]),
 array([6, 5, 6, 2, 1, 3, 1, 1, 0, 0])]

In [None]:
len(model.store)

19