# Profiling with PyTorch
In this notebook we will go through profiling your training with PyTorch and Holistic Trace Analysis.

## Setting up model and dataset
For this example we will use [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet/overview), which is similar to ImageNet but has a lower resolution (64x64) and fewer images (100 k). For this dataset we will use a variant of the ResNet architecture, which is a type of Convolutional Neural Network with residual connections. For the sake of this tutorial you do not need to understand the details about the model or the dataset.

### Datapipe
First we set up the data pipeline: after checking that HTA can be imported, we load Tiny ImageNet with a custom dataset class and wrap it in the DataLoaders that we use later.

In [None]:
# Sanity check: make sure the HTA module can be imported before running
# the rest of the notebook.
import hta
print("HTA module imported successfully!")

In [None]:
%matplotlib inline
import time
import json
import os
import torch
from torchvision.models import resnet18
from pytorch_dataset import TinyImageNetDataset 
from torch import nn, optim, profiler
from torch.utils.data import DataLoader
from PIL import Image
from hta.trace_analysis import TraceAnalysis  # Import HTA
import matplotlib.pyplot as plt
# In case we need a high-level API for creating interactive plots using Plotly:
import plotly.express as px
import warnings
# Silence every Python warning for the rest of the notebook. Note that this
# also hides deprecation warnings raised by the libraries used below.
warnings.filterwarnings("ignore")
# For performance set precision,
# see https://www.c3se.chalmers.se/documentation/applications/pytorch/#performance-and-precision
torch.set_float32_matmul_precision("high")

In [None]:
# Load the TinyImageNet dataset using the custom dataset class.
path_to_dataset = '/mimer/NOBACKUP/Datasets/tiny-imagenet-200/tiny-imagenet-200.zip'

# Single place to change the batch size (the exercises at the end of the
# notebook ask you to vary it); both loaders use the same value.
BATCH_SIZE = 32

train_dataset = TinyImageNetDataset(path_to_dataset=path_to_dataset, split='train')
val_dataset = TinyImageNetDataset(path_to_dataset=path_to_dataset, split='val')

# Only the training data is shuffled; validation keeps its original order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)


In [None]:
# ResNet-18 with a 200-way classification head.
pretrained = True
model = resnet18(weights=None, num_classes=200)
if pretrained:
    # Load the ImageNet weights through the `weights=` API, consistent with
    # the call above; the `pretrained=` keyword is deprecated in torchvision.
    pretrained_state_dict = resnet18(
        weights="IMAGENET1K_V1",
        num_classes=1000,
        progress=False,
    ).state_dict()
    # The pretrained classifier has 1000 outputs and does not fit the
    # 200-class head, so drop it and keep the freshly initialised `fc` layer.
    for key in ["fc.weight", "fc.bias"]:
        del pretrained_state_dict[key]
    # strict=False because the `fc.*` entries are intentionally missing.
    model.load_state_dict(pretrained_state_dict, strict=False)

# Optimizer
opt = optim.SGD(model.parameters(), lr=0.005, momentum=0.9)

# Loss function and device (falls back to the CPU when no GPU is available).
loss_func = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# Move the model parameters to the selected device before training.
model.to(device)


def train_step(images, labels):
    """Run one optimisation step on a single batch.

    The batch is moved to ``device``, the gradients held by ``opt`` are
    cleared, a forward and backward pass is run through ``model`` with
    ``loss_func``, and the optimizer is stepped.

    Returns:
        The loss of this batch as a Python number (``loss.item()``).
    """
    images, labels = images.to(device), labels.to(device)

    opt.zero_grad()
    batch_loss = loss_func(model(images), labels)
    batch_loss.backward()
    opt.step()

    return batch_loss.item()


Having taken care of these initialisations we are ready to take a look at profiling.

In [None]:
trace_dir = './trace_hta'
os.makedirs(trace_dir, exist_ok=True)

# Profiler schedule. One cycle is wait + warmup + active steps and the cycle
# is repeated `cycles` times, so the loop below runs exactly that many steps.
wait_steps, warmup_steps, active_steps, cycles = 10, 5, 10, 2
total_steps = (wait_steps + warmup_steps + active_steps) * cycles

# NOTE(review): no `on_trace_ready` handler is registered, so presumably only
# the most recent cycle is available to `export_chrome_trace` below — confirm
# against the torch.profiler documentation if every cycle should be kept.
with profiler.profile(
    schedule=profiler.schedule(
        wait=wait_steps,
        warmup=warmup_steps,
        active=active_steps,
        repeat=cycles,
    ),
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    for step, (images, labels) in enumerate(train_loader):
        loss = train_step(images, labels)

        # Tell the profiler that one training step has finished.
        prof.step()
        print(f"\rStep: {step + 1}/{total_steps}", end="")
        if step + 1 >= total_steps:
            break

# Export after the `with` block so the profiler has been stopped before the
# trace is written. The timestamp gives every run its own file name.
timestamp = int(time.time())
trace_file = os.path.join(trace_dir, f'trace_{timestamp}.json')
prof.export_chrome_trace(trace_file)
# Leading newline: the progress line above ends without one.
print(f"\nTrace written to {trace_file}")
 

Note that you might get warnings for using step() during wait steps.

In [None]:
# Add a "distributedInfo" key to the trace file when it is missing.
# NOTE(review): presumably HTA reads the rank of a trace from this entry —
# confirm against the HTA documentation.
with open(trace_file, 'r') as file:
    trace_data = json.load(file)

if "distributedInfo" not in trace_data:
    trace_data["distributedInfo"] = {"rank": 0}
    # Rewrite the file only when it actually changed, and without
    # pretty-printing: traces can be large and indentation only bloats them.
    with open(trace_file, 'w') as file:
        json.dump(trace_data, file)

rank = trace_data["distributedInfo"].get("rank", 0)
print("\nTrace file top-level keys:", list(trace_data.keys()))

# Analyze the trace using HTA. The file written by this run is passed
# explicitly: `trace_dir` also contains the timestamped traces of earlier
# runs, so loading the whole directory could pick up a stale trace.
analyzer = TraceAnalysis(
    trace_files={rank: os.path.basename(trace_file)},
    trace_dir=trace_dir,
)

# Get the temporal breakdown; HTA's own visualization is turned off because
# the plots are drawn manually with matplotlib in the following cells.
time_spent_df = analyzer.get_temporal_breakdown(visualize=False)

print("\nTemporal Breakdown DataFrame:")
print(time_spent_df.head(2))

In [None]:
# Peek at the first rows and list the column names of the temporal breakdown.
for summary in (time_spent_df.head(), time_spent_df.columns):
    print(summary)


In [None]:
# Average busy time per rank.
# The breakdown has a 'rank' column and the cells below treat each row as one
# rank, so dividing by the number of rows averages over ranks, not over
# training steps. Idle time is not part of the total either. The printed
# label says exactly that instead of claiming a "step time".
total_time = time_spent_df['compute_time(us)'].sum() + time_spent_df['non_compute_time(us)'].sum()
average_busy_time_per_rank = total_time / len(time_spent_df)
# Old variable name kept so code that still refers to it keeps working.
average_step_time = average_busy_time_per_rank
print(f"\nAverage busy (compute + non-compute) time per rank: {average_busy_time_per_rank} us")

In [None]:
import numpy as np

# Absolute times (us) and percentages have very different scales, so each
# group gets its own panel and y-axis instead of sharing a "Time (us)" axis.
time_columns = ['idle_time(us)', 'compute_time(us)', 'non_compute_time(us)', 'kernel_time(us)']
pctg_columns = ['idle_time_pctg', 'compute_time_pctg', 'non_compute_time_pctg']
categories = time_columns + pctg_columns

# First row (assuming only one rank); columns are selected by name so the
# plot does not depend on the column order of the DataFrame.
first_rank = time_spent_df.iloc[0]
values = first_rank[categories].astype(float).values

fig, (time_ax, pctg_ax) = plt.subplots(1, 2, figsize=(12, 5))

time_ax.bar(time_columns, first_rank[time_columns].astype(float).values,
            color=['blue', 'green', 'red', 'purple'])
time_ax.set_ylabel("Time (us)")
time_ax.set_title("Temporal Breakdown by Category")

pctg_ax.bar(pctg_columns, first_rank[pctg_columns].astype(float).values,
            color=['blue', 'green', 'red'])
pctg_ax.set_ylabel("Share of time (%)")
pctg_ax.set_title("Temporal Breakdown (percentage)")

for ax in (time_ax, pctg_ax):
    ax.tick_params(axis="x", labelrotation=45)
    ax.grid(axis="y", linestyle="--", alpha=0.7)

fig.tight_layout()
plt.show()

In [None]:
import numpy as np

# Bar chart of the four absolute time columns for the first row.
categories = ['idle_time(us)', 'compute_time(us)', 'non_compute_time(us)', 'kernel_time(us)']
# Columns 1..4 of the first row, i.e. everything after the leading column.
values = time_spent_df.iloc[0, 1:5].values

fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(categories, values, color=['blue', 'green', 'red', 'purple'])

ax.set_ylabel("Time (us)")
ax.set_title("Temporal Breakdown by Category")
ax.tick_params(axis="x", labelrotation=45)
ax.grid(axis="y", linestyle="--", alpha=0.7)

plt.show()


In [None]:
# Stacked bar per rank: compute time at the bottom, idle time on top of it.
ranks = time_spent_df['rank']
compute_us = time_spent_df['compute_time(us)']
idle_us = time_spent_df['idle_time(us)']

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(ranks, compute_us, label='Compute Time')
ax.bar(ranks, idle_us, bottom=compute_us, label='Idle Time')
ax.set_xlabel('Rank')
ax.set_ylabel('Time (us)')
ax.set_title('Temporal Breakdown')
ax.legend()

# The figure is also stored next to the trace files.
fig.savefig(os.path.join(trace_dir, 'temporal_breakdown.png'))
plt.show()

In [None]:
# One pie per rank showing how its time splits into idle, compute and
# non-compute shares. A pie of 'compute_time_pctg' across ranks is a single
# 100 % slice for one rank, and this data is not a kernel-type breakdown.
share_columns = ['idle_time_pctg', 'compute_time_pctg', 'non_compute_time_pctg']

n_ranks = max(len(time_spent_df), 1)
fig, axes = plt.subplots(1, n_ranks, figsize=(10, 6), squeeze=False)
for ax, (_, rank_row) in zip(axes[0], time_spent_df.iterrows()):
    ax.pie(rank_row[share_columns].astype(float), labels=share_columns, autopct='%1.1f%%')
    ax.set_title(f"Rank {int(rank_row['rank'])}")
fig.suptitle('Temporal Breakdown (share of time per category)')
plt.show()

In [None]:
# Get idle time breakdown
idle_time_df = analyzer.get_idle_time_breakdown(visualize=False)
print("\nIdle Time Breakdown DataFrame:")
print(idle_time_df)


In [None]:
print(type(idle_time_df))  # Check if it's really a DataFrame
print(idle_time_df)

# The breakdown call returns a tuple whose first element is the DataFrame.
# Unpack only when we still hold the tuple, so this cell can safely be run
# more than once (indexing an already unpacked DataFrame with [0] fails).
if isinstance(idle_time_df, tuple):
    idle_time_df = idle_time_df[0]

print(idle_time_df.head())
print(idle_time_df.columns)


In [None]:
# Idle time and idle-time ratio of the first row of the breakdown.
# The two quantities are selected by column name and drawn on separate
# y-axes: on a shared axis the ratio is invisible next to the time in us.
categories = ['idle_time', 'idle_time_ratio']
first_row = idle_time_df.iloc[0]
values = [first_row[name] for name in categories]

fig, time_ax = plt.subplots(figsize=(8, 5))
ratio_ax = time_ax.twinx()  # second y-axis that shares the x-axis

time_ax.bar([0], [values[0]], color='blue')
ratio_ax.bar([1], [values[1]], color='green')

# Labels and title
time_ax.set_xticks([0, 1])
time_ax.set_xticklabels(categories, rotation=45)
time_ax.set_ylabel("Idle time (us)")
ratio_ax.set_ylabel("Idle time ratio")
time_ax.set_title("Idle Time and Ratio Breakdown")
time_ax.grid(axis="y", linestyle="--", alpha=0.7)

plt.show()

In [None]:
# Idle time and idle-time ratio for every row of the breakdown.
# `idle_time_df` must already be the DataFrame (element 0 of the tuple
# returned by the breakdown call) when this cell runs.
#
# The bars are drawn side by side on two y-axes rather than stacked: a
# dimensionless ratio cannot be stacked on a time in microseconds, and the
# previous `bottom=` offset used the ratio itself instead of the first series.
positions = list(range(len(idle_time_df)))
bar_width = 0.4

fig, time_ax = plt.subplots(figsize=(10, 6))
ratio_ax = time_ax.twinx()

time_bars = time_ax.bar(
    [p - bar_width / 2 for p in positions],
    idle_time_df['idle_time'],
    width=bar_width,
    color='tab:blue',
    label='idle_time',
)
ratio_bars = ratio_ax.bar(
    [p + bar_width / 2 for p in positions],
    idle_time_df['idle_time_ratio'],
    width=bar_width,
    color='tab:orange',
    label='idle_time_ratio',
)

# One tick per row, labelled with the rank the row belongs to.
time_ax.set_xticks(positions)
time_ax.set_xticklabels(idle_time_df['rank'])
time_ax.set_xlabel('Rank')
time_ax.set_ylabel('Time (us)')
ratio_ax.set_ylabel('Idle time ratio')
time_ax.set_title('Idle Time')
time_ax.legend(handles=[time_bars, ratio_bars])
plt.show()

In [None]:
# Pie chart of the 'idle_time' column, one slice per row, labelled by rank.
fig, ax = plt.subplots(figsize=(10, 6))
ax.pie(idle_time_df['idle_time'], labels=idle_time_df['rank'], autopct='%1.1f%%')
ax.set_title('Idle Time Distribution')
plt.show()

In [None]:
# Communication/computation overlap as reported by HTA (HTA's plot disabled).
overlap_df = analyzer.get_comm_comp_overlap(visualize=False)
print("\nCommunication Computation Overlap DataFrame:", overlap_df, sep="\n")

# Manual visualization: total time per category, summed over all rows of the
# temporal breakdown.
operations = ['idle_time(us)', 'compute_time(us)', 'non_compute_time(us)', 'kernel_time(us)']
time_spent = list(time_spent_df[operations].sum())

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(operations, time_spent, color=['blue', 'green', 'red', 'purple'])
ax.set_title('Temporal Breakdown of Operations')
ax.set_ylabel('Time Spent (us)')
ax.set_xlabel('Operation')
ax.tick_params(axis="x", labelrotation=45)

plt.show()


In [None]:
# GPU kernel breakdown: the call returns two DataFrames, one with metrics per
# kernel type and one with metrics per kernel. HTA's own plots are disabled.
kernel_type_metrics_df, kernel_metrics_df = analyzer.get_gpu_kernel_breakdown(
    num_kernels=5,
    include_memory_kernels=True,
    image_renderer="notebook",
    visualize=False,
)

print("\nKernel Type Metrics DataFrame:", kernel_type_metrics_df, sep="\n")
print("\nKernel Metrics DataFrame:", kernel_metrics_df, sep="\n")

In [None]:
# In case we want to check which Plotly renderers are available:
import plotly
print(plotly.io.renderers)

In [None]:
# Check the type, the first rows and the columns of both kernel DataFrames;
# a line of asterisks separates the two reports.
for position, frame in enumerate((kernel_type_metrics_df, kernel_metrics_df)):
    if position:
        print('*************************')
    print(type(frame))
    print(frame.head())
    print(frame.columns)

In [None]:
# Manual visualization: column totals of the kernel type metrics.
# NOTE(review): the 'percentage' total is presumably close to 100 and is drawn
# on the same axis as the 'sum' total — confirm this is the intended chart.
operations = ['sum', 'percentage']
kernel_type_metrics = list(kernel_type_metrics_df[operations].sum())

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(operations, kernel_type_metrics, color=['blue', 'red'])
ax.set_title('Kernel Type Metrics - Sum and Percentage')
ax.set_ylabel('Value')
ax.set_xlabel('Operations')
ax.tick_params(axis="x", labelrotation=45)
plt.show()

In [None]:
# Column totals of the per-kernel metrics, drawn as one bar per metric.
# NOTE(review): this adds up the 'mean (us)' and 'stddev' values of all
# kernels — confirm that these totals are what the chart should show.
metrics = ['sum (us)', 'mean (us)', 'stddev']
kernel_metrics = list(kernel_metrics_df[metrics].sum())

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(metrics, kernel_metrics, color=['blue', 'green', 'red'])
ax.set_title('Kernel Metrics - Sum, Mean, and Stddev')
ax.set_ylabel('Value (us)')
ax.set_xlabel('Metrics')
ax.tick_params(axis="x", labelrotation=45)
plt.show()

In [None]:
# Directly explore the JSON trace data.
trace_dir = './trace_hta'
# Sorted so that "the first trace" is the same file on every run;
# os.listdir returns the names in arbitrary order.
trace_files = sorted(f for f in os.listdir(trace_dir) if f.endswith('.json'))

all_trace_data = []

# The loop uses its own variable names: reusing `trace_file` / `trace_data`
# would overwrite the values set by the profiling and analysis cells and
# break them when they are run again.
for trace_name in trace_files:
    with open(os.path.join(trace_dir, trace_name), 'r') as f:
        all_trace_data.append(json.load(f))

# Print keys or inspect the data structure of the first trace file
if all_trace_data:
    print("\nKeys in the first trace data:", all_trace_data[0].keys())

print("\nProfiling and analysis completed.")

## Exercises
1. Use HTA to analyze how the execution time, compute time, and idle time are affected by changing the batch size in your training module.
2. Run the model with different batch sizes (e.g., 64) and use the HTA tool to visualize the impact on the idle time.