## Introduction
This notebook tests the model tracing algorithm implemented in [idlmav_tracing.py](../idlmav_tracing.py). For notebooks containing development notes and exploratory code for this algorithm, see [07_explore_tracing.ipynb](./07_explore_tracing.ipynb) and [09_explore_fx_interpreter.ipynb](./09_explore_fx_interpreter.ipynb).

## Imports

In [1]:
import sys
import os
import importlib
workspace_path = os.path.abspath(os.path.join(os.path.abspath(''), '..'))
sys.path.append(workspace_path)

from idlmav_tracing import MavTracer

In [2]:
from typing import Dict, Tuple
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torchvision
import timm

import torchinfo

from miniai.init import clean_mem
from miniai.resnet import ResBlock, act_gr, conv

### To reload updated imports

In [3]:
def reload_imports():
    """Reload `idlmav_tracing` and rebind the notebook-level `MavTracer` name.

    The already-imported module object is fetched from `sys.modules`,
    reloaded in place, and `MavTracer` is then imported again into the
    global namespace so later cells use the reloaded definition.
    """
    global MavTracer
    tracing_module = sys.modules['idlmav_tracing']
    importlib.reload(tracing_module)
    from idlmav_tracing import MavTracer

## Models

### Hand-written CNN (for MNIST)
* Found [here](https://github.com/pytorch/examples/blob/main/mnist/main.py)
* Modified for experimental purposes, e.g. breaking the forward pass

In [4]:
class MnistCnn(nn.Module):
    """Small convolutional classifier: two 3x3 convolutions, a 2x2 max-pool and two linear layers.

    `forward` returns log-probabilities over 10 classes. `fc1` takes 9216
    input features, i.e. 64 channels * 12 * 12, which is what a
    `(N, 1, 28, 28)` input is reduced to before flattening.
    """
    def __init__(self):
        super().__init__()
        # Keep this creation order: it fixes the order in which the random
        # initial weights are drawn.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        # Convolutional feature extractor
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = F.max_pool2d(features, 2)
        features = self.dropout1(features)

        # Classifier head: flatten everything except the batch dimension
        flat = torch.flatten(features, 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

def get_model_cnn_mnist():
    """Build a fresh `MnistCnn` with its default layer sizes."""
    model = MnistCnn()
    return model

def get_model_cnn_mnist_broken():
    """Build an `MnistCnn` whose forward pass is deliberately inconsistent.

    `fc1` is replaced by a layer with 120 outputs while `fc2` still expects
    128 inputs, so running the model fails at `fc2`.
    """
    broken = MnistCnn()
    # 120 != 128: intentional mismatch with fc2's in_features
    broken.fc1 = nn.Linear(in_features=9216, out_features=120)
    return broken

### ResNet18 (from TorchVision)

In [5]:
def get_model_tv_resnet18():
    """Create TorchVision's ResNet-18 using the constructor's default arguments."""
    build_resnet18 = torchvision.models.resnet.resnet18
    return build_resnet18()

### ResNet18 (from timm)

In [6]:
def get_model_timm_resnet18():
    """Create timm's `resnet18` configured for 3 input channels and 10 output classes."""
    timm_kwargs = dict(in_chans=3, num_classes=10)
    return timm.create_model('resnet18', **timm_kwargs)

### ResNet (from miniai)

In [7]:
def get_model_miniai_resnet(act=nn.ReLU, nfs=(16,32,64,128,256), norm=nn.BatchNorm2d):
    """Build a ResNet-style classifier from miniai's `conv` and `ResBlock`.

    The model is an `nn.Sequential` of: a stem `conv(1, nfs[0], ks=5, stride=1)`,
    one stride-2 `ResBlock` per consecutive pair in `nfs`, an average over the
    last two dimensions, and a bias-free linear layer to 10 outputs followed
    by `BatchNorm1d`.

    Args:
        act: Activation passed through to `conv` and every `ResBlock`.
        nfs: Sequence of widths. `nfs[0]` is the stem's output width and
            `nfs[-1]` is the number of features fed to the linear head.
        norm: Normalisation layer passed through to `conv` and every `ResBlock`.

    Returns:
        The assembled `nn.Sequential` model.
    """
    class GlobalAvgPool(nn.Module):
        """Average the input over its last two dimensions."""
        def forward(self, x:Tensor): return x.mean((-2,-1))

    # The stem and head widths follow `nfs` instead of the literals 16 and 256,
    # so a non-default `nfs` still produces layers whose widths line up.
    layers = [conv(1, nfs[0], ks=5, stride=1, act=act, norm=norm)]
    layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
    layers += [GlobalAvgPool(), nn.Linear(nfs[-1], 10, bias=False), nn.BatchNorm1d(10)]
    return nn.Sequential(*layers)

### Autoencoder (from miniai)

In [8]:
def get_model_miniai_autoenc(act=act_gr, nfs=(32,64,128,256,512), norm=nn.BatchNorm2d, drop=0.1):
    """Build an autoencoder-shaped `nn.Sequential` out of miniai `ResBlock`s.

    Layout: one `ResBlock(3, nfs[0], ks=5, stride=1)`, a stride-2 `ResBlock`
    per consecutive pair in `nfs`, the mirrored sequence of upsample+ResBlock
    stages, and a final `ResBlock(nfs[0], 3)` that uses `nn.Identity` as its
    activation.

    Note: `drop` is accepted but not used anywhere in the body.
    """
    def upsample_then_res(ni, nf, ks=3, act=act_gr, norm=None):
        # 2x nearest-neighbour upsampling followed by a ResBlock
        upsample = nn.UpsamplingNearest2d(scale_factor=2)
        res_block = ResBlock(ni, nf, ks=ks, act=act, norm=norm)
        return nn.Sequential(upsample, res_block)

    blocks = [ResBlock(3, nfs[0], ks=5, stride=1, act=act, norm=norm)]

    # Downsampling half: nfs[0] -> nfs[1] -> ... -> nfs[-1]
    for i in range(len(nfs)-1):
        blocks.append(ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2))

    # Upsampling half: walk the widths back from nfs[-1] down to nfs[0]
    for i in reversed(range(1, len(nfs))):
        blocks.append(upsample_then_res(nfs[i], nfs[i-1], act=act, norm=norm))

    blocks.append(ResBlock(nfs[0], 3, act=nn.Identity, norm=norm))
    return nn.Sequential(*blocks)

### UNet (from miniai)

In [9]:
def up_block(ni, nf, ks=3, act=act_gr, norm=None):
    """Return an `nn.Sequential` of 2x nearest-neighbour upsampling followed by `ResBlock(ni, nf)`."""
    upsample = nn.UpsamplingNearest2d(scale_factor=2)
    res_block = ResBlock(ni, nf, ks=ks, act=act, norm=norm)
    return nn.Sequential(upsample, res_block)

class TinyUnet(nn.Module):
    """Small U-Net assembled from miniai `ResBlock`s with additive skip connections.

    Submodules:
        start: `ResBlock(3, nfs[0])` with stride 1.
        dn:    one stride-2 `ResBlock` per consecutive pair in `nfs`.
        up:    one `up_block` per pair in reverse order, plus a final
               `ResBlock(nfs[0], 3)`.
        end:   `ResBlock(3, 3)` using `nn.Identity` as its activation.
    """
    def __init__(self, act=act_gr, nfs=(32,64,128,256,512), norm=nn.BatchNorm2d):
        super().__init__()
        n_stages = len(nfs) - 1
        self.start = ResBlock(3, nfs[0], stride=1, act=act, norm=norm)

        down_blocks = [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2)
                       for i in range(n_stages)]
        self.dn = nn.ModuleList(down_blocks)

        up_blocks = [up_block(nfs[i], nfs[i-1], act=act, norm=norm)
                     for i in range(n_stages, 0, -1)]
        self.up = nn.ModuleList(up_blocks)
        self.up.append(ResBlock(nfs[0], 3, act=act, norm=norm))

        self.end = ResBlock(3, 3, act=nn.Identity, norm=norm)

    def forward(self, x):
        # skips[0] is the raw input; every later entry is the tensor that
        # entered one of the down blocks.
        skips = [x]
        x = self.start(x)
        for down in self.dn:
            skips.append(x)
            x = down(x)

        n_skips = len(skips)
        for step, up in enumerate(self.up):
            if step > 0:
                # Deliberately in-place, exactly as in the original code
                x += skips[n_skips-step]
            x = up(x)

        return self.end(x+skips[0])
    
def get_model_miniai_unet():
    """Build a `TinyUnet` with its default activation, widths and norm layer."""
    unet = TinyUnet()
    return unet

### Model preparation helpers

In [10]:
def get_model_and_input_size(name, device):
    """Build the test model registered under `name` and report its input size.

    Args:
        name: One of 'cnn_mnist', 'cnn_mnist_broken', 'tv_resnet18',
            'timm_resnet18', 'miniai_resnet', 'miniai_autoenc', 'miniai_unet'.
        device: Passed to the model's `.to()`.

    Returns:
        A `(model, input_size)` tuple, or `None` when `name` is not one of
        the names listed above.
    """
    mnist_batch = (16,1,28,28)
    rgb_batch = (16,3,160,160)

    # Pick the factory first; it is only called once a name has matched.
    if name == 'cnn_mnist':
        build, input_size = get_model_cnn_mnist, mnist_batch
    elif name == 'cnn_mnist_broken':
        build, input_size = get_model_cnn_mnist_broken, mnist_batch
    elif name == 'tv_resnet18':
        build, input_size = get_model_tv_resnet18, rgb_batch
    elif name == 'timm_resnet18':
        build, input_size = get_model_timm_resnet18, rgb_batch
    elif name == 'miniai_resnet':
        build, input_size = get_model_miniai_resnet, mnist_batch
    elif name == 'miniai_autoenc':
        build, input_size = get_model_miniai_autoenc, rgb_batch
    elif name == 'miniai_unet':
        build, input_size = get_model_miniai_unet, rgb_batch
    else:
        return None

    return build().to(device), input_size
        
def get_models_and_input_tensors(model_names, device):
    """
    Build every requested model together with a random input batch for it.

    This function returns 3 dictionaries, all indexed by the model name:
    * A dictionary of `nn.Module` models
    * A dictionary of tuples: the input sizes for each model
    * A dictionary of tensors that can be passed to each model

    Raises:
        ValueError: if a name in `model_names` is not recognised by
            `get_model_and_input_size`.
    """
    # miniai helper, called before any model is built (presumably to free
    # memory held by earlier runs -- its body is not visible in this notebook).
    clean_mem()
    models: Dict[str, nn.Module] = {}
    input_sizes: Dict[str, Tuple[int, ...]] = {}
    inputs: Dict[str, Tensor] = {}
    for name in model_names:
        spec = get_model_and_input_size(name, device)
        if spec is None:
            # Fail with the offending name instead of an opaque
            # "cannot unpack non-iterable NoneType" error.
            raise ValueError(f'Unknown model name: {name!r}')
        model, input_size = spec
        x = torch.randn(input_size).to(device)
        models[name] = model
        input_sizes[name] = input_size
        inputs[name] = x
    return models, input_sizes, inputs

def remove_all_hooks(model:nn.Module):
    """Remove every forward and backward hook from `model` and all of its submodules.

    Clears the per-module hook registries in place: forward hooks, forward
    pre-hooks, backward hooks and backward pre-hooks, together with the
    dictionaries that record per-hook flags for the forward hooks. Registries
    that do not exist on the installed PyTorch version are skipped.

    Args:
        model: Root module; `model.modules()` is walked, so the root itself
            is included.
    """
    hook_registries = (
        '_forward_hooks',
        '_forward_hooks_with_kwargs',
        '_forward_hooks_always_called',
        '_forward_pre_hooks',
        '_forward_pre_hooks_with_kwargs',
        '_backward_hooks',
        '_backward_pre_hooks',
    )
    for module in model.modules():
        for attr in hook_registries:
            registry = getattr(module, attr, None)
            if registry is not None:
                registry.clear()


In [11]:
# Device on which all models and their input tensors are created
device = 'cpu'  # ['cpu','cuda']
# Names understood by `get_model_and_input_size`
model_names = ['cnn_mnist', 'cnn_mnist_broken', 'tv_resnet18', 'timm_resnet18', 'miniai_resnet', 'miniai_autoenc', 'miniai_unet']
# Build each model once; the test cells below look them up by name
models, input_sizes, inputs = get_models_and_input_tensors(model_names, device)

## Tests

In [12]:
def run_algorithm(model_name):
    """Trace the named model with `MavTracer` and print the tracer's summary."""
    model, sample_input = models[model_name], inputs[model_name]
    tracer = MavTracer(model, sample_input, device=device)
    print(tracer.summary())

def run_comparison(model_name):
    """Run `torchinfo.summary` on the named model as a reference for `run_algorithm`.

    The summary is requested with `verbose=1`, `depth=8` and the output-size,
    parameter-count, mult-adds, input-size and kernel-size columns. The
    returned object is discarded.
    """
    columns = ["output_size","num_params","mult_adds","input_size","kernel_size"]
    torchinfo.summary(models[model_name], input_size=input_sizes[model_name],
                      verbose=1, depth=8, col_names=columns)

### Hand-written CNN (for MNIST)

In [13]:
run_algorithm('cnn_mnist')

name         operation      activations         params      flops
-----------  -------------  ----------------  --------  ---------
x            input          (16, 1, 28, 28)          0          0
conv1        nn.Conv2d      (16, 32, 26, 26)       320    6230016
relu         relu()         (16, 32, 26, 26)         0          0
conv2        nn.Conv2d      (16, 64, 24, 24)     18496  339738624
relu_1       relu()         (16, 64, 24, 24)         0          0
max_pool2d   max_pool2d()   (16, 64, 12, 12)         0          0
dropout1     nn.Dropout     (16, 64, 12, 12)         0          0
flatten      flatten()      (16, 9216)               0          0
fc1          nn.Linear      (16, 128)          1179776   37748736
relu_2       relu()         (16, 128)                0          0
dropout2     nn.Dropout     (16, 128)                0          0
fc2          nn.Linear      (16, 10)              1290      40960
log_softmax  log_softmax()  (16, 10)                 0          0
output    

In [14]:
run_comparison('cnn_mnist')

Layer (type:depth-idx)                   Output Shape              Param #                   Mult-Adds                 Input Shape               Kernel Shape
MnistCnn                                 [16, 10]                  --                        --                        [16, 1, 28, 28]           --
├─Conv2d: 1-1                            [16, 32, 26, 26]          320                       3,461,120                 [16, 1, 28, 28]           [3, 3]
├─Conv2d: 1-2                            [16, 64, 24, 24]          18,496                    170,459,136               [16, 32, 26, 26]          [3, 3]
├─Dropout: 1-3                           [16, 64, 12, 12]          --                        --                        [16, 64, 12, 12]          --
├─Linear: 1-4                            [16, 128]                 1,179,776                 18,876,416                [16, 9216]                --
├─Dropout: 1-5                           [16, 128]                 --                        -

### Broken model

In [15]:
run_algorithm('cnn_mnist_broken')

name         operation      activations         params      flops
-----------  -------------  ----------------  --------  ---------
x            input          (16, 1, 28, 28)          0          0
conv1        nn.Conv2d      (16, 32, 26, 26)       320    6230016
relu         relu()         (16, 32, 26, 26)         0          0
conv2        nn.Conv2d      (16, 64, 24, 24)     18496  339738624
relu_1       relu()         (16, 64, 24, 24)         0          0
max_pool2d   max_pool2d()   (16, 64, 12, 12)         0          0
dropout1     nn.Dropout     (16, 64, 12, 12)         0          0
flatten      flatten()      (16, 9216)               0          0
fc1          nn.Linear      (16, 120)          1106040   35389440
relu_2       relu()         (16, 120)                0          0
dropout2     nn.Dropout     (16, 120)                0          0
fc2          nn.Linear      (0,)                  1290          0
log_softmax  log_softmax()  (0,)                     0          0
output    


While executing %fc2 : [num_users=1] = call_module[target=fc2](args = (%dropout2,), kwargs = {})
Original traceback:
None


In [16]:
# The failure is the point of this cell: the broken model is expected to make
# the comparison raise, so the error is caught and reported instead of
# stopping the notebook.
try:
    run_comparison('cnn_mnist_broken')
except Exception as e:
    print(f'Comparison failed: {e}')


Comparison failed: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Conv2d: 1, Conv2d: 1, Dropout: 1, Linear: 1, Dropout: 1]


### ResNet18 (from TorchVision)

In [17]:
run_algorithm('tv_resnet18')

name                   operation             activations          params       flops
---------------------  --------------------  -----------------  --------  ----------
x                      input                 (16, 3, 160, 160)         0           0
conv1                  nn.Conv2d             (16, 64, 80, 80)       9408  1926758400
bn1                    nn.BatchNorm2d        (16, 64, 80, 80)        128    13107200
relu                   nn.ReLU               (16, 64, 80, 80)          0           0
maxpool                nn.MaxPool2d          (16, 64, 40, 40)          0           0
layer1_0_conv1         nn.Conv2d             (16, 64, 40, 40)      36864  1887436800
layer1_0_bn1           nn.BatchNorm2d        (16, 64, 40, 40)        128     3276800
layer1_0_relu          nn.ReLU               (16, 64, 40, 40)          0           0
layer1_0_conv2         nn.Conv2d             (16, 64, 40, 40)      36864  1887436800
layer1_0_bn2           nn.BatchNorm2d        (16, 64, 40, 40)    

In [18]:
run_comparison('tv_resnet18')

Layer (type:depth-idx)                   Output Shape              Param #                   Mult-Adds                 Input Shape               Kernel Shape
ResNet                                   [16, 1000]                --                        --                        [16, 3, 160, 160]         --
├─Conv2d: 1-1                            [16, 64, 80, 80]          9,408                     963,379,200               [16, 3, 160, 160]         [7, 7]
├─BatchNorm2d: 1-2                       [16, 64, 80, 80]          128                       2,048                     [16, 64, 80, 80]          --
├─ReLU: 1-3                              [16, 64, 80, 80]          --                        --                        [16, 64, 80, 80]          --
├─MaxPool2d: 1-4                         [16, 64, 40, 40]          --                        --                        [16, 64, 80, 80]          3
├─Sequential: 1-5                        [16, 64, 40, 40]          --                        --    

### ResNet18 (from timm)

In [19]:
run_algorithm('timm_resnet18')

name                   operation             activations          params       flops
---------------------  --------------------  -----------------  --------  ----------
x                      input                 (16, 3, 160, 160)         0           0
conv1                  nn.Conv2d             (16, 64, 80, 80)       9408  1926758400
bn1                    nn.BatchNorm2d        (16, 64, 80, 80)        128    13107200
act1                   nn.ReLU               (16, 64, 80, 80)          0           0
maxpool                nn.MaxPool2d          (16, 64, 40, 40)          0           0
layer1_0_conv1         nn.Conv2d             (16, 64, 40, 40)      36864  1887436800
layer1_0_bn1           nn.BatchNorm2d        (16, 64, 40, 40)        128     3276800
layer1_0_drop_block    nn.Identity           (16, 64, 40, 40)          0           0
layer1_0_act1          nn.ReLU               (16, 64, 40, 40)          0           0
layer1_0_aa            nn.Identity           (16, 64, 40, 40)    

In [20]:
run_comparison('timm_resnet18')

Layer (type:depth-idx)                   Output Shape              Param #                   Mult-Adds                 Input Shape               Kernel Shape
ResNet                                   [16, 10]                  --                        --                        [16, 3, 160, 160]         --
├─Conv2d: 1-1                            [16, 64, 80, 80]          9,408                     963,379,200               [16, 3, 160, 160]         [7, 7]
├─BatchNorm2d: 1-2                       [16, 64, 80, 80]          128                       2,048                     [16, 64, 80, 80]          --
├─ReLU: 1-3                              [16, 64, 80, 80]          --                        --                        [16, 64, 80, 80]          --
├─MaxPool2d: 1-4                         [16, 64, 40, 40]          --                        --                        [16, 64, 80, 80]          3
├─Sequential: 1-5                        [16, 64, 40, 40]          --                        --    

### ResNet (from miniai)

In [21]:
run_algorithm('miniai_resnet')

name          operation       activations         params      flops
------------  --------------  ----------------  --------  ---------
input_1       input           (16, 1, 28, 28)          0          0
_0_0          nn.Conv2d       (16, 16, 28, 28)       416   10035200
_0_1          nn.BatchNorm2d  (16, 16, 28, 28)        32     401408
_0_2          nn.ReLU         (16, 16, 28, 28)         0          0
_1_convs_0_0  nn.Conv2d       (16, 32, 28, 28)      4640  115605504
_1_convs_0_1  nn.BatchNorm2d  (16, 32, 28, 28)        64     802816
_1_convs_0_2  nn.ReLU         (16, 32, 28, 28)         0          0
_1_convs_1_0  nn.Conv2d       (16, 32, 14, 14)      9248   57802752
_1_convs_1_1  nn.BatchNorm2d  (16, 32, 14, 14)        64     200704
_1_pool       nn.AvgPool2d    (16, 16, 14, 14)         0     100352
_1_idconv_0   nn.Conv2d       (16, 32, 14, 14)       544    3211264
add           add()           (16, 32, 14, 14)         0          0
_1_act        nn.ReLU         (16, 32, 14, 14)  

In [22]:
run_comparison('miniai_resnet')

Layer (type:depth-idx)                   Output Shape              Param #                   Mult-Adds                 Input Shape               Kernel Shape
Sequential                               [16, 10]                  --                        --                        [16, 1, 28, 28]           --
├─Sequential: 1-1                        [16, 16, 28, 28]          --                        --                        [16, 1, 28, 28]           --
│    └─Conv2d: 2-1                       [16, 16, 28, 28]          416                       5,218,304                 [16, 1, 28, 28]           [5, 5]
│    └─BatchNorm2d: 2-2                  [16, 16, 28, 28]          32                        512                       [16, 16, 28, 28]          --
│    └─ReLU: 2-3                         [16, 16, 28, 28]          --                        --                        [16, 16, 28, 28]          --
├─ResBlock: 1-2                          [16, 32, 14, 14]          --                        --   

### Autoencoder (from miniai)

In [23]:
run_algorithm('miniai_autoenc')

name            operation               activations           params        flops
--------------  ----------------------  ------------------  --------  -----------
input_1         input                   (16, 3, 160, 160)          0            0
_0_convs_0_0    nn.Conv2d               (16, 32, 160, 160)      2432   1966080000
_0_convs_0_1    nn.BatchNorm2d          (16, 32, 160, 160)        64     26214400
leaky_relu      leaky_relu()            (16, 32, 160, 160)         0            0
sub             sub()                   (16, 32, 160, 160)         0            0
_0_convs_1_0    nn.Conv2d               (16, 32, 160, 160)     25632  20971520000
_0_convs_1_1    nn.BatchNorm2d          (16, 32, 160, 160)        64     26214400
_0_idconv_0     nn.Conv2d               (16, 32, 160, 160)       128     78643200
add             add()                   (16, 32, 160, 160)         0            0
leaky_relu_1    leaky_relu()            (16, 32, 160, 160)         0            0
sub_1           

In [24]:
run_comparison('miniai_autoenc')

Layer (type:depth-idx)                        Output Shape              Param #                   Mult-Adds                 Input Shape               Kernel Shape
Sequential                                    [16, 3, 160, 160]         --                        --                        [16, 3, 160, 160]         --
├─ResBlock: 1-1                               [16, 32, 160, 160]        --                        --                        [16, 3, 160, 160]         --
│    └─Sequential: 2-1                        [16, 32, 160, 160]        --                        --                        [16, 3, 160, 160]         --
│    │    └─Sequential: 3-1                   [16, 32, 160, 160]        --                        --                        [16, 3, 160, 160]         --
│    │    │    └─Conv2d: 4-1                  [16, 32, 160, 160]        2,432                     996,147,200               [16, 3, 160, 160]         [5, 5]
│    │    │    └─BatchNorm2d: 4-2             [16, 32, 160, 160]    

### UNet (from miniai)

In [27]:
run_algorithm('miniai_unet')

name              operation               activations           params        flops
----------------  ----------------------  ------------------  --------  -----------
x                 input                   (16, 3, 160, 160)          0            0
start_convs_0_0   nn.Conv2d               (16, 32, 160, 160)       896    707788800
start_convs_0_1   nn.BatchNorm2d          (16, 32, 160, 160)        64     26214400
leaky_relu        leaky_relu()            (16, 32, 160, 160)         0            0
sub               sub()                   (16, 32, 160, 160)         0            0
start_convs_1_0   nn.Conv2d               (16, 32, 160, 160)      9248   7549747200
start_convs_1_1   nn.BatchNorm2d          (16, 32, 160, 160)        64     26214400
start_idconv_0    nn.Conv2d               (16, 32, 160, 160)       128     78643200
add               add()                   (16, 32, 160, 160)         0            0
leaky_relu_1      leaky_relu()            (16, 32, 160, 160)         0      

In [26]:
run_comparison('miniai_unet')

Layer (type:depth-idx)                             Output Shape              Param #                   Mult-Adds                 Input Shape               Kernel Shape
TinyUnet                                           [16, 3, 160, 160]         --                        --                        [16, 3, 160, 160]         --
├─ResBlock: 1-1                                    [16, 32, 160, 160]        --                        --                        [16, 3, 160, 160]         --
│    └─Sequential: 2-1                             [16, 32, 160, 160]        --                        --                        [16, 3, 160, 160]         --
│    │    └─Sequential: 3-1                        [16, 32, 160, 160]        --                        --                        [16, 3, 160, 160]         --
│    │    │    └─Conv2d: 4-1                       [16, 32, 160, 160]        896                       367,001,600               [16, 3, 160, 160]         [3, 3]
│    │    │    └─BatchNorm2d: 4-2     