Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# Usage example

Create a toy model whose activations and gradients we will track:

In [1]:
import torch
from torch import nn, Tensor

class Model(nn.Module):
    """Toy model: token embedding -> linear projection -> linear unembedding.

    The forward pass returns the cross-entropy loss of predicting each input
    token from itself, giving a scalar that can be back-propagated.

    Args:
        vocab_size: number of distinct token ids (embedding rows and output
            classes). Defaults to 10.
        hidden_size: width of the embedding and the projection layer.
            Defaults to 4.
    """

    def __init__(self, vocab_size: int = 10, hidden_size: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.project = nn.Linear(hidden_size, hidden_size)
        self.unembed = nn.Linear(hidden_size, vocab_size)

    def forward(self, tokens: Tensor) -> Tensor:
        """Return the mean cross-entropy of predicting ``tokens`` from themselves.

        Args:
            tokens: integer token ids in ``[0, vocab_size)``, of any shape
                (e.g. ``()``, ``(seq,)`` or ``(batch, seq)``).

        Returns:
            A scalar (0-dim) loss tensor.
        """
        logits = self.unembed(self.project(self.embed(tokens)))
        # For batched input, cross_entropy expects class scores in dim 1, so a
        # (batch, seq, vocab) logits tensor would be misinterpreted. Flatten all
        # leading dims to (N, vocab) / (N,) so any token shape works; for 1-D
        # tokens this leaves the shapes (and the loss) unchanged.
        return nn.functional.cross_entropy(
            logits.reshape(-1, logits.shape[-1]), tokens.reshape(-1)
        )

# Seed the global RNG before building the model, so that both the parameter
# initialisation and the sampled tokens below are reproducible.
torch.manual_seed(100)
module = Model()
# Three token ids drawn uniformly from [0, 10).
inputs = torch.randint(low=0, high=10, size=(3,))

Use `tensor_tracker` to capture forward-pass activations and backward-pass gradients from our toy model. By default, the tracker saves full tensors as a list of `tensor_tracker.Stash` objects.

In [2]:
import tensor_tracker

with tensor_tracker.track(module) as tracker:
    # Both the forward call and the backward call happen inside the context,
    # so activations and gradients are both recorded.
    loss = module(inputs)
    loss.backward()

print(tracker)

Tracker(stashes=8, tracking=0)


Note that calls are tracked only within the `with` context. Afterwards, the tracker behaves like a list of `Stash` objects, each with attributes such as `name`, `type`, `grad` and `value`.

In [3]:
# Materialise the tracker as a plain list of Stash objects. Forward-pass stashes
# (grad=False) come first, followed by backward-pass stashes (grad=True).
display(list(tracker))
# => [Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=False, value=tensor(...)),
#     ...]

[Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=False, value=tensor([[ 0.4698,  1.2426,  0.5403, -1.1454],
         [-0.8425, -0.6475, -0.2189, -1.1326],
         [ 0.1268,  1.3564,  0.5632, -0.1039]])),
 Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.6237, -0.1652,  0.3782, -0.8841],
         [-0.9278, -0.2848, -0.8688, -0.4719],
         [-0.3449,  0.3643,  0.3935, -0.6302]])),
 Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.2458,  1.0003, -0.8231, -0.1405, -0.2964,  0.5837,  0.2889,  0.2059,
          -0.6114, -0.5916],
         [-0.6345,  1.0882, -0.4304, -0.2196, -0.0426,  0.9428,  0.2051,  0.5897,
          -0.2217, -0.9132],
         [-0.0822,  0.9985, -0.7097, -0.3139, -0.4805,  0.6878,  0.2560,  0.3254,
          -0.4447, -0.3332]])),
 Stash(name='', type=<class '__main__.Model'>, grad=False, value=tensor(2.5663)),
 Stash(name='', type=<class '__m

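As a higher-level API, `to_frame` summarises each stash as one row of a table, applying a summary statistic to its value (by default `torch.std`). The standard deviation of a single-element tensor is undefined (NaN), which is why the two rows for the scalar loss (the `__main__.Model` output and its gradient) have an empty `std` entry.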

In [4]:
summary_frame = tracker.to_frame()  # one row per stash; default statistic is torch.std
display(summary_frame)

Unnamed: 0,name,type,grad,std
0,embed,torch.nn.modules.sparse.Embedding,False,0.853265
1,project,torch.nn.modules.linear.Linear,False,0.494231
2,unembed,torch.nn.modules.linear.Linear,False,0.581503
3,,__main__.Model,False,
4,,__main__.Model,True,
5,unembed,torch.nn.modules.linear.Linear,True,0.105266
6,project,torch.nn.modules.linear.Linear,True,0.112392
7,embed,torch.nn.modules.sparse.Embedding,True,0.068816
