## Introduction
This notebook uses the `Interpreter` class from the `torch.fx` library to extract shape information from nodes after symbolic tracing has been performed



## Imports

In [1]:
import warnings
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torchvision
import torch.fx as fx
from typing import Any, Dict, List, Tuple
from tabulate import tabulate
import torchprofile
import torchinfo

## Models

### Hand-written CNN (for MNIST)
* Found [here](https://github.com/pytorch/examples/blob/main/mnist/main.py)
* Modified for experimental purposes, e.g. breaking the forward pass

In [2]:
class MnistCnn(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

def get_model_cnn_mnist():
    return MnistCnn()

def get_model_cnn_mnist_broken():
    model = MnistCnn()
    model.fc1 = nn.Linear(9216, 120)
    return model

### ResNet18 (from TorchVision)

In [3]:
def get_model_resnet18():
    return torchvision.models.resnet.resnet18()

### Model helpers

In [4]:
def get_model_and_input_size(name, device):
    match name:
        case 'cnn_mnist':
            return get_model_cnn_mnist().to(device), (16,1,28,28)
        case 'cnn_mnist_broken':
            return get_model_cnn_mnist_broken().to(device), (16,1,28,28)
        case 'resnet18':
            return get_model_resnet18().to(device), (16,3,160,160)
        case _:
            return None, None
        
def get_model_and_input_tensor(name, device):
    model, input_size = get_model_and_input_size(name, device)
    xb = torch.randn(input_size).to(device)
    return model, xb

In [5]:
device = 'cpu'
model_names = ['cnn_mnist', 'resnet18', 'cnn_mnist_broken']
model0, x0 = get_model_and_input_tensor(model_names[0], device)
model1, x1 = get_model_and_input_tensor(model_names[1], device)
model2, x2 = get_model_and_input_tensor(model_names[2], device)

## Shape interpreter

In [None]:
class ShapeInterpreter(fx.Interpreter):
    def __init__(self, mod : torch.nn.Module):
        gm = fx.symbolic_trace(mod)
        super().__init__(gm)

        self.cur_macs: int = None
        self.shapes : Dict[fx.Node, Tuple[int]] = {}
        self.macs : Dict[fx.Node, int] = {}
        self.param_counts : Dict[fx.Node, int] = {}
        self.type_names : Dict[fx.Node,str] = {}

    def rgetattr(self, m: nn.Module, attr: str) -> Tensor | None:
        # From torchinfo, used in `get_param_count()`:
        for attr_i in attr.split("."):
            if not hasattr(m, attr_i):
                return None
            m = getattr(m, attr_i)
        assert isinstance(m, Tensor)  # type: ignore[unreachable]
        return m  # type: ignore[unreachable]

    def get_num_trainable_params(self, m:nn.Module):
        num_params = 0
        for name, param in m.named_parameters():
            # We're only looking for trainable parameters here
            if not param.requires_grad: continue

            num_params_loop = param.nelement()

            # From torchinfo `get_param_count()`:
            # Masked models save parameters with the suffix "_orig" added.
            # They have a buffer ending with "_mask" which has only 0s and 1s.
            # If a mask exists, the sum of 1s in mask is number of params.
            if name.endswith("_orig"):
                without_suffix = name[:-5]
                pruned_weights = self.rgetattr(m, f"{without_suffix}_mask")
                if pruned_weights is not None:
                    num_params_loop = int(torch.sum(pruned_weights))
            
            num_params += num_params_loop
        return num_params

    def run_node(self, n:fx.Node) -> Any:
        # Run the node
        self.cur_macs = None
        result = super().run_node(n)

        # Retrieve the shape
        if isinstance(result, Tensor):
            shape = tuple(result.shape)
        else:
            shape = (0,0,0,0)
        self.shapes[n] = shape

        # Retrieve the module type and parameter count
        if n.op == 'call_module':
            submod = self.fetch_attr(n.target)
            self.type_names[n] = submod.__class__.__name__
            self.param_counts[n] = self.get_num_trainable_params(submod)
            if self.cur_macs is not None: self.macs[n] = self.cur_macs
        if n.op == 'call_function':
            self.type_names[n] = n.target.__name__

        # Return the result
        return result
    
    def call_module(self, target, args, kwargs):
        # Run the module
        result = super().call_module(target, args, kwargs)

        # Estimate the FLOPS
        try:
            submod = self.fetch_attr(target)
            macs = torchprofile.profile_macs(submod, args)
        except Exception as e:
            warnings.warn(f'FLOPS calculation failed for module {submod.__class__.__name__}: {e}')
            macs = 0  
        self.cur_macs = macs

        # Return the result
        return result
        
    def summary(self) -> str:
        node_summaries : List[List[Any]] = []

        for node, shape in self.shapes.items():
            type_name = self.type_names.get(node, '')
            num_params = self.param_counts.get(node, '')
            macs = self.macs.get(node, '')
            node_summaries.append(
                [node.op, node.name, type_name, node.all_input_nodes, list(node.users.keys()), shape, num_params, macs, node.target, node.args, node.kwargs])

        headers : List[str] = ['opcode', 'name', 'type', 'inputs', 'outputs', 'activations', '# params', 'MACs', 'target', 'args', 'kwargs']
        return tabulate(node_summaries, headers=headers)

### Testing model: working small model

In [7]:
model0

MnistCnn(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [8]:
interp0 = ShapeInterpreter(model0)
interp0.run(x0)
print(interp0.summary())

opcode         name         type         inputs         outputs        activations       # params    MACs       target                                                      args            kwargs
-------------  -----------  -----------  -------------  -------------  ----------------  ----------  ---------  ----------------------------------------------------------  --------------  ------------------------------------------------------------------------------------------
placeholder    x                         []             [conv1]        (16, 1, 28, 28)                          x                                                           ()              {}
call_module    conv1        Conv2d       [x]            [relu]         (16, 32, 26, 26)  320         3115008    conv1                                                       (x,)            {}
call_function  relu         relu         [conv1]        [conv2]        (16, 32, 26, 26)                         <function relu at 0x7f46a59f5ea0

In [9]:
torchinfo.summary(model0, input_size=x0.shape, verbose=1, depth=8, device=device,
                  col_names=["output_size","num_params","mult_adds","input_size","kernel_size"]);

Layer (type:depth-idx)                   Output Shape              Param #                   Mult-Adds                 Input Shape               Kernel Shape
MnistCnn                                 [16, 10]                  --                        --                        [16, 1, 28, 28]           --
├─Conv2d: 1-1                            [16, 32, 26, 26]          320                       3,461,120                 [16, 1, 28, 28]           [3, 3]
├─Conv2d: 1-2                            [16, 64, 24, 24]          18,496                    170,459,136               [16, 32, 26, 26]          [3, 3]
├─Dropout: 1-3                           [16, 64, 12, 12]          --                        --                        [16, 64, 12, 12]          --
├─Linear: 1-4                            [16, 128]                 1,179,776                 18,876,416                [16, 9216]                --
├─Dropout: 1-5                           [16, 128]                 --                        -

### Testing: working ResNet18

In [10]:
model1

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [11]:
interp1 = ShapeInterpreter(model1)
interp1.run(x1)
print(interp1.summary())

opcode         name                   type               inputs                                 outputs                                  activations        # params    MACs       target                                                      args                                   kwargs
-------------  ---------------------  -----------------  -------------------------------------  ---------------------------------------  -----------------  ----------  ---------  ----------------------------------------------------------  -------------------------------------  --------
placeholder    x                                         []                                     [conv1]                                  (16, 3, 160, 160)                         x                                                           ()                                     {}
call_module    conv1                  Conv2d             [x]                                    [bn1]                                    (16, 64, 8

In [12]:
torchinfo.summary(model1, input_size=x1.shape, verbose=1, depth=8, device=device,
                  col_names=["output_size","num_params","mult_adds","input_size","kernel_size"]);

Layer (type:depth-idx)                   Output Shape              Param #                   Mult-Adds                 Input Shape               Kernel Shape
ResNet                                   [16, 1000]                --                        --                        [16, 3, 160, 160]         --
├─Conv2d: 1-1                            [16, 64, 80, 80]          9,408                     963,379,200               [16, 3, 160, 160]         [7, 7]
├─BatchNorm2d: 1-2                       [16, 64, 80, 80]          128                       2,048                     [16, 64, 80, 80]          --
├─ReLU: 1-3                              [16, 64, 80, 80]          --                        --                        [16, 64, 80, 80]          --
├─MaxPool2d: 1-4                         [16, 64, 40, 40]          --                        --                        [16, 64, 80, 80]          3
├─Sequential: 1-5                        [16, 64, 40, 40]          --                        --    

In [13]:
# How to resolve an nn.Module from the `targets` string
# From `Interpreter.fetch_attr`, called from `Interpreter.call_module`
m = model1
target = 'layer1.1.conv2'

target_atoms = target.split('.')
attr_itr = m
for i, atom in enumerate(target_atoms):
    if not hasattr(attr_itr, atom):
        raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i+1])}")
    attr_itr = getattr(attr_itr, atom)
m = attr_itr
m

Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

In [14]:
model1.layer1[1].conv2

Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

In [15]:
model1.layer1[1].conv2.__class__

torch.nn.modules.conv.Conv2d

### Testing: Broken model

In [16]:
interp2 = ShapeInterpreter(model2)
interp2.run(x2)
print(interp2.summary())

RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x120 and 128x10)

While executing %fc2 : [num_users=1] = call_module[target=fc2](args = (%dropout2,), kwargs = {})
Original traceback:
None