# Tracking Numerics in Torch
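A simple example: train a small MLP classifier on MNIST while tracking the numerics of its tensors (weights and optimiser state) with `tandv`.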

In [None]:
# Install torch dependecies if needed 
!pip install torch torchvision

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


from tandv.track.torch import track

In [None]:

class MLP(nn.Module):
    """Three-layer fully connected classifier with ReLU activations.

    The defaults reproduce the original MNIST network
    (784 -> 128 -> 64 -> 10), and the output is raw logits.

    Args:
        in_features: Number of input features after flattening
            (28 * 28 for MNIST images).
        hidden_sizes: Widths of the two hidden layers, as a pair.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=28 * 28, hidden_sizes=(128, 64), num_classes=10):
        super().__init__()
        hidden1, hidden2 = hidden_sizes
        self.in_features = in_features
        # Keep the layer names fc1/fc2/fc3 stable. They appear in state_dict keys
        # and in the tracker's per-tensor log (e.g. 'fc1', 'fc2.bias').
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        The input is flattened to (-1, in_features), e.g. MNIST images of
        shape (N, 1, 28, 28) become (N, 784).
        """
        # Use reshape rather than view so that non-contiguous inputs (e.g.
        # transposed tensors) are accepted; view would raise on those.
        x = x.reshape(-1, self.in_features)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

In [None]:
# Configuration: seed torch's RNG so that weight initialisation and the
# DataLoader's shuffling are reproducible under Restart & Run All.
SEED = 0  # arbitrary fixed seed; change it to sample a different run
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

# Set device: prefer CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Prepare the MNIST training set. ToTensor produces values in [0, 1], and
# Normalize with mean 0.5 / std 0.5 then maps them to [-1, 1].
mnist_transform_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(mnist_transform_steps)

train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 images for training.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

In [None]:
# Build the network on the selected device, together with the loss and the
# Adam optimiser over all of its parameters.
model = MLP()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Training configuration: number of passes over the data, and how often
# (in batches) to print the averaged training loss.
num_epochs = 10
log_interval = 100

# Make training mode explicit. It is a no-op for this MLP, but it stays
# correct if dropout or batch-norm layers are added later.
model.train()

# Wrap the training loop in the tracker's context manager so it can track the
# numerics of `model` and `optimizer`. tracker.step() is called once per
# optimisation step, after optimizer.step().
with track(module=model, optimizer=optimizer) as tracker:
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            tracker.step()  # call tracker.step() at the end of each training step

            # Accumulate the loss and report its mean over the last
            # `log_interval` batches. The same constant drives both the
            # trigger and the divisor, so the printed average stays correct
            # if the interval is changed.
            running_loss += loss.item()
            if (batch_idx + 1) % log_interval == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {running_loss / log_interval:.4f}')
                running_loss = 0.0

In [11]:
from pathlib import Path

from tandv.track.common import read_pickle

# Directory holding the tracker's run folders.
# NOTE(review): the original cell hard-coded
# '/Users/colmb/numerics-vis/example-notebooks/tandv/rank-rhea/final_logframe.pkl'.
# That suggests runs are written to ./tandv/<run-name>/final_logframe.pkl,
# relative to the notebook's working directory — confirm against tandv's tracker.
TRACK_DIR = Path("tandv")

# Pick the most recently written log frame, so the cell works on any machine
# and right after a fresh training run, instead of pointing at one
# author-specific run folder.
logframe_candidates = sorted(
    TRACK_DIR.glob("*/final_logframe.pkl"),
    key=lambda path: path.stat().st_mtime,
)
if not logframe_candidates:
    raise FileNotFoundError(
        f"No '*/final_logframe.pkl' found under {TRACK_DIR.resolve()}; run the training cell first."
    )
logframe_path = logframe_candidates[-1]

# The .pkl file is presumably unpickled, which can execute arbitrary code;
# only load log frames you produced yourself.
logframe = read_pickle(str(logframe_path))
logframe

Unnamed: 0_level_0,metadata,metadata,metadata,metadata,metadata,scalar_stats,scalar_stats,scalar_stats,scalar_stats,scalar_stats,...,exponent_counts,exponent_counts,exponent_counts,exponent_counts,exponent_counts,exponent_counts,exponent_counts,exponent_counts,exponent_counts,exponent_counts
Unnamed: 0_level_1,step,name,type,tensor_type,dtype,mean,std,mean_abs,max_abs,min_abs,...,8,9,10,11,12,13,14,15,16,inf
0,0,fc1,,Weights,float32,4.014089e-05,0.020662,0.017892,0.035714,2.607703e-08,...,0,0,0,0,0,0,0,0,0,0
1,0,fc1.bias,,Weights,float32,-3.297047e-04,0.022365,0.019830,0.035681,1.496822e-05,...,0,0,0,0,0,0,0,0,0,0
2,0,fc2,,Weights,float32,3.476351e-04,0.051278,0.044444,0.088385,4.649162e-06,...,0,0,0,0,0,0,0,0,0,0
3,0,fc2.bias,,Weights,float32,-6.186668e-03,0.047025,0.039629,0.086022,1.875915e-03,...,0,0,0,0,0,0,0,0,0,0
4,0,fc3,,Weights,float32,-2.974902e-03,0.070225,0.060364,0.124378,1.822859e-04,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
300137,9379,fc2.bias,,Optimiser_State.exp_avg_sq,float32,6.081242e-06,0.000004,0.000006,0.000015,3.178590e-13,...,0,0,0,0,0,0,0,0,0,0
300138,9379,fc3,,Optimiser_State.exp_avg,float32,1.158514e-09,0.006127,0.003708,0.025309,5.605194e-45,...,0,0,0,0,0,0,0,0,0,0
300139,9379,fc3,,Optimiser_State.exp_avg_sq,float32,6.355176e-04,0.000928,0.000636,0.009392,9.010034e-17,...,0,0,0,0,0,0,0,0,0,0
300140,9379,fc3.bias,,Optimiser_State.exp_avg,float32,5.238689e-10,0.001920,0.001672,0.003188,8.360653e-04,...,0,0,0,0,0,0,0,0,0,0
