# Profiling with PyTorch
In this notebook we will profile a training loop with the PyTorch profiler and inspect the results in TensorBoard.

## Setting up model and dataset
For this example we will use [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview), which is similar to ImageNet but has lower resolution (64x64) and fewer images (100k). On this dataset we will train a variant of the ResNet architecture, which is a type of Convolutional Neural Network with residual connections. For the sake of this tutorial you do not need to understand the details of the model or the dataset.

### Dataset
First we import what we need, then load the training and validation splits with a custom dataset class and wrap them in DataLoaders.

In [None]:
import os
import torch
from torchvision.models import resnet18
from pytorch_dataset import TinyImageNetDataset 
from torch import nn, optim, profiler
from torch.utils.data import DataLoader
from PIL import Image


In [None]:
# Load TinyImageNet dataset using the custom dataset class
path_to_dataset = '/mimer/NOBACKUP/Datasets/tiny-imagenet-200/tiny-imagenet-200.zip'

train_dataset = TinyImageNetDataset(path_to_dataset=path_to_dataset, split='train')
val_dataset = TinyImageNetDataset(path_to_dataset=path_to_dataset, split='val')

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=32)


In [None]:
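# ResNet-18 with a 200-class output layer for Tiny ImageNet
pretrained = True
model = resnet18(weights=None, num_classes=200)
if pretrained:
    # Load ImageNet weights (1000 classes); `weights=` replaces the
    # deprecated `pretrained=` argument in torchvision
    pretrained_state_dict = resnet18(
        weights="IMAGENET1K_V1",
        num_classes=1000,
        progress=False,
    ).state_dict()
    # Drop the final classifier, whose shape does not match 200 classes
    for key in ["fc.weight", "fc.bias"]:
        del pretrained_state_dict[key]
    # strict=False tolerates the missing fc.* keys
    model.load_state_dict(pretrained_state_dict, strict=False)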

# Optimizer
opt = optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

# Other
loss_func = nn.CrossEntropyLoss()
device = torch.device("cuda")



In [None]:
model.to(device)
def train_step(images, labels):
    images = images.to(device)
    labels = labels.to(device)
    opt.zero_grad()
    
    est = model(images)
    loss = loss_func(est, labels)
    loss.backward()
    opt.step()
    
    return loss.item()


Having taken care of these initialisations we are ready to take a look at profiling.

In [None]:
with profiler.profile(
        schedule=profiler.schedule(wait=10, warmup=5, active=10, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs/base.ptb'),
        record_shapes=False,
        profile_memory=False,
        with_stack=False,
) as prof:
    for images, labels in train_loader:
        loss = train_step(images, labels)
        
        # Step scheduler
        prof.step()
        
        print(f"\rStep: {prof.step_num}/50", end="")
        if prof.step_num >= 50:
            break


Note that you might get warnings about calling `prof.step()` during the wait steps.
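
The schedule used above (`wait=10, warmup=5, active=10, repeat=2`) is why the loop stops after 50 steps: one cycle is 10 + 5 + 10 = 25 steps, and it is repeated twice. The following is a minimal pure-Python sketch of how the steps map to phases. It mirrors our understanding of `torch.profiler.schedule` but does not call it; the helper names `profiler_phase` and `summarize` are hypothetical and exist only for this illustration.

```python
def profiler_phase(step, wait=10, warmup=5, active=10, repeat=2):
    """Return 'wait', 'warmup', 'active' or 'done' for a 0-based step index.

    Illustrative re-implementation only, not the PyTorch API.
    """
    cycle = wait + warmup + active
    # After `repeat` full cycles the profiler stays idle (repeat=0 means forever)
    if repeat > 0 and step >= cycle * repeat:
        return "done"
    pos = step % cycle
    if pos < wait:
        return "wait"
    if pos < wait + warmup:
        return "warmup"
    return "active"


def summarize(total_steps, **schedule_kwargs):
    """Collapse consecutive steps in the same phase into (phase, first, last)."""
    runs = []
    for step in range(total_steps):
        phase = profiler_phase(step, **schedule_kwargs)
        if runs and runs[-1][0] == phase:
            runs[-1] = (phase, runs[-1][1], step)
        else:
            runs.append((phase, step, step))
    return runs


for phase, first, last in summarize(50):
    print(f"steps {first:2d}-{last:2d}: {phase}")
```

Only the two `active` windows (steps 15-24 and 40-49 in this sketch) end up in the trace, so each of the two cycles contributes 10 recorded steps.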

## Exercises
1. Look at the profiling results in TensorBoard. To do this, follow the instructions in README.md.
2. Try to follow the Performance Recommendation and run the profiling again with the code below.

In [None]:
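# Adjust the DataLoader arguments here according to the Performance Recommendation
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)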
model.to(device)

with profiler.profile(
        schedule=profiler.schedule(wait=10, warmup=5, active=10, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs/improved.ptb'),
        record_shapes=False,
        profile_memory=False,
        with_stack=False,
) as prof:

    for images, labels in train_loader:
        loss = train_step(images, labels)
        
        # This informs the profile scheduler
        prof.step()
        
        print(f"\rStep: {prof.step_num}/50", end="")
        if prof.step_num >= 50:
            # Part of an epoch may be enough information for us
            break
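
As a rough cross-check of what TensorBoard reports, one can also time how long the loop waits for the next batch compared with how long the training step takes. The sketch below is not part of the profiler; `time_wait_and_work` is a hypothetical helper that works on any iterable of `(images, labels)` pairs, and the fake batches stand in for `train_loader` so that the cell runs on its own.

```python
import time


def time_wait_and_work(batches, work_fn, max_steps=50):
    """Return (seconds waiting for batches, seconds in work_fn, steps run)."""
    wait_s = 0.0
    work_s = 0.0
    steps = 0
    iterator = iter(batches)
    while steps < max_steps:
        t0 = time.perf_counter()
        try:
            images, labels = next(iterator)  # time spent waiting for data
        except StopIteration:
            break
        t1 = time.perf_counter()
        work_fn(images, labels)  # time spent in the training step
        t2 = time.perf_counter()
        wait_s += t1 - t0
        work_s += t2 - t1
        steps += 1
    return wait_s, work_s, steps


# Stand-in data and a no-op step; replace with train_loader and train_step
fake_batches = [(i, i) for i in range(5)]
wait_s, work_s, steps = time_wait_and_work(
    fake_batches, lambda images, labels: None, max_steps=3
)
print(f"{steps} steps: waited {wait_s:.6f}s for data, worked {work_s:.6f}s")
```

With the objects from this notebook the call would be `time_wait_and_work(train_loader, train_step)`. Because `train_step` ends with `loss.item()`, which waits for the GPU to finish, the measured step time should include the GPU work, though this remains a coarse estimate compared with the profiler trace.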