## Profiling

- PyTorch tutorial (PyTorch Profiler with TensorBoard): https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html

To analyze the collected traces, see also Holistic Trace Analysis (HTA): https://hta.readthedocs.io/en/latest/index.html

- PyTorch tutorial (Introduction to Holistic Trace Analysis): https://pytorch.org/tutorials/beginner/hta_intro_tutorial.html

In [None]:
import torch
import torch.nn
import torch.optim
import torch.profiler
import torch.utils.data
import torchvision.datasets
import torchvision.models
import torchvision.transforms as T
from hta.trace_analysis import TraceAnalysis

In [None]:
# Input pipeline: resize so the shorter image side is 224 pixels, convert to a
# tensor (values in [0, 1]), then normalize each of the 3 channels with
# mean 0.5 / std 0.5, which maps [0, 1] onto [-1, 1].
preprocessing_steps = [
    T.Resize(224),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = T.Compose(preprocessing_steps)

# CIFAR-10 training split, downloaded into ./data on first use.
train_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of 32 samples.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)

In [None]:
# Everything runs on the CPU. To use a GPU, switch this to torch.device("cuda:0");
# the profiler cell would then also need ProfilerActivity.CUDA to capture GPU work.
device = torch.device('cpu')

# ResNet-18 initialised from the torchvision IMAGENET1K_V1 pretrained weights.
model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
criterion = criterion.to(device)

optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

# Put the model in training mode before the training/profiling loop.
model.train()
print("Resnet model")

In [None]:
def train(data):
    """Run one optimization step on a single batch.

    ``data`` is indexed positionally: element 0 is the model input and
    element 1 the target passed to ``criterion``. Both are moved to the
    module-level ``device`` first. The step uses the module-level ``model``,
    ``criterion`` and ``optimizer`` defined in an earlier cell and returns
    nothing.
    """
    batch_inputs = data[0].to(device=device)
    batch_targets = data[1].to(device=device)

    predictions = model(batch_inputs)
    batch_loss = criterion(predictions, batch_targets)

    # Clear gradients from the previous step before back-propagating this one.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

In [None]:
# Profiler schedule, counted in calls to prof.step():
#   - skip PROFILE_WAIT steps,
#   - trace PROFILE_WARMUP steps but discard the results,
#   - record PROFILE_ACTIVE steps,
#   - run the whole cycle PROFILE_REPEAT times.
PROFILE_WAIT, PROFILE_WARMUP, PROFILE_ACTIVE, PROFILE_REPEAT = 1, 1, 3, 1
# After this many steps the schedule is complete and the trace has been handed
# to on_trace_ready. Any further iterations would train without being recorded.
# The old hard-coded bound, 1 + (1 + 3) * 2, ran 10 steps for a 5-step schedule.
profile_steps = (PROFILE_WAIT + PROFILE_WARMUP + PROFILE_ACTIVE) * PROFILE_REPEAT

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    schedule=torch.profiler.schedule(
        wait=PROFILE_WAIT,
        warmup=PROFILE_WARMUP,
        active=PROFILE_ACTIVE,
        repeat=PROFILE_REPEAT,
    ),
    # Writes the trace as uncompressed .pt.trace.json files (TensorBoard / HTA input).
    on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name='./log/resnet18', use_gzip=False),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    for step, batch_data in enumerate(train_loader):
        train(batch_data)
        prof.step()  # Must be called once per step so the profiler knows the step boundaries.
        if step + 1 >= profile_steps:
            break

In [None]:
# Load the most recent profiler trace into Holistic Trace Analysis.
# Imports are repeated so this cell can run on its own against traces from an
# earlier session.
from pathlib import Path

from hta.trace_analysis import TraceAnalysis

# Use the directory the profiling cell's tensorboard_trace_handler writes to,
# resolved against the working directory. This replaces a machine-specific
# absolute path.
trace_dir = str(Path('./log/resnet18').resolve())

# Trace file names embed host, PID and a timestamp, so they change on every run.
# Pick the most recently written one instead of hard-coding a name.
candidate_traces = sorted(
    Path(trace_dir).glob('*.pt.trace.json'),
    key=lambda trace_path: trace_path.stat().st_mtime,
)
if not candidate_traces:
    raise FileNotFoundError(
        f"No '*.pt.trace.json' files found in {trace_dir}; run the profiling cell first."
    )

# HTA maps rank -> trace file; this single-process run only produces rank 0.
trace_files = {0: candidate_traces[-1].name}
analyzer = TraceAnalysis(trace_dir=trace_dir, trace_files=trace_files)