Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# Usage example (PopTorch)

Create a toy model to track:

In [1]:
import torch
from torch import nn, Tensor

class Model(nn.Module):
    """Toy model used to demonstrate tracking: embed -> project -> unembed.

    The forward pass returns a scalar cross-entropy loss that scores the
    output logits against the input tokens themselves.
    """

    def __init__(self):
        super().__init__()
        # 10 possible token ids, hidden size 4.
        self.embed = nn.Embedding(10, 4)
        self.project = nn.Linear(4, 4)
        self.unembed = nn.Linear(4, 10)

    def forward(self, tokens: Tensor) -> Tensor:
        """Compute the cross-entropy loss of the model's logits w.r.t. `tokens`.

        Args:
            tokens: integer token ids, used both as the model input and as
                the classification targets.

        Returns:
            The cross-entropy loss (default mean reduction).
        """
        embedded = self.embed(tokens)
        projected = self.project(embedded)
        logits = self.unembed(projected)
        loss = nn.functional.cross_entropy(logits, tokens)
        return loss

# Seed PyTorch's global RNG first: both the parameter initialisation in Model()
# and the sampled inputs below draw from it, so the order of these lines matters.
torch.manual_seed(100)
module = Model()
# Three token ids sampled from [0, 10), matching the embedding table's 10 entries.
inputs = torch.randint(0, 10, (3,))

**PopTorch:**

A few modifications are needed to make tracking work with PopTorch:
 - Any tracking should be contained within `forward()`.
 - We shouldn't call `tensor.cpu()`, as this is implicit on returned tensors.
 - We don't have access to the backward pass.

In [8]:
from typing import Any, Dict, List, Tuple

import poptorch
import tensor_tracker

class TrackingModel(Model):
    """`Model` with tensor tracking performed entirely inside `forward()`.

    The tracker is opened and closed within `forward()` and its results are
    handed back as part of the return value, so nothing outside `forward()`
    needs to touch the tracker.
    """

    def forward(self, inputs: Tensor) -> Tuple[Tensor, List[Dict[str, Any]]]:
        """Run `Model.forward` under a tracker and return the loss plus stashes.

        Args:
            inputs: token ids, passed unchanged to `Model.forward`.

        Returns:
            A `(loss, stashes)` tuple, where `loss` is the value returned by
            `Model.forward` and `stashes` holds one plain dict per tracked
            stash (that stash's `__dict__`). Callers can rebuild the stash
            objects with `tensor_tracker.Stash(**d)`.
        """
        # stash_value is the identity: tensors are stashed as-is, with no
        # explicit `.cpu()` call.
        with tensor_tracker.track(self, stash_value=lambda t: t) as tracker:
            loss = super().forward(inputs)
        # NOTE(review): stashes are returned as dicts rather than Stash objects,
        # presumably so the output is made of plain containers -- confirm
        # against PopTorch's supported output types.
        return loss, [t.__dict__ for t in tracker]

# Wrap a fresh TrackingModel for inference and run it on the inputs;
# forward() returns the loss together with one plain dict per stash.
loss, stash_fields = poptorch.inferenceModel(TrackingModel())(inputs)

# Rebuild Stash objects from the returned dicts.
tracked = [tensor_tracker.Stash(**fields) for fields in stash_fields]
display(tracked)
# => [Stash(name="embed", type=nn.Embedding, grad=False, value=tensor(...)),
#     ...]

Graph compilation: 100%|██████████| 100/100 [00:04<00:00]


[Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=False, value=tensor([[ 0.4520, -0.1066,  1.1028, -1.1578],
         [-0.4866, -0.1484, -1.6819,  0.7740],
         [-1.0324,  0.2063, -0.7983,  0.4695]])),
 Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[ 1.2474,  0.4518,  0.2115, -0.6991],
         [-0.3698, -0.1035, -0.2358, -0.3482],
         [ 0.2165,  0.2673, -0.1278, -0.1348]])),
 Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.2676,  0.0945,  0.4727,  0.0716, -0.1146,  0.2311,  0.4380, -0.1172,
           0.6078, -0.0632],
         [ 0.2343, -0.0936,  0.1143, -0.0777,  0.0148, -0.0783,  0.2015,  0.1975,
           0.2441, -0.3956],
         [ 0.1521, -0.0814,  0.2678,  0.0481,  0.1128, -0.0149,  0.3953,  0.2135,
           0.3824, -0.2818]]))]