# Profiling your code

### Prerequisites

Make sure to read the following sections of the documentation before going through this example:

- [PyTorch setup](../../frameworks/pytorch_setup/index.rst)
- [Checkpointing](../checkpointing/index.rst)
- [Multi-GPU training](../../distributed/multi_gpu/index.rst)

Figuring out whether, and where, your code runs slower than it needs to can be complicated.
In this minimal example, we'll go through a basic profiling procedure that covers the following:

- Diagnosing whether training or dataloading is the bottleneck in your code
- Using the PyTorch profiler to find additional bottlenecks
- (WIP) Potential avenues for further optimization with ``torch.compile``, additional workers, multiple GPUs, etc.

### Dataloading

A simple way to tell whether your bottleneck comes from dataloading is to run the main script, ``main.py``, with and without training.  
The rationale: if an epoch run without training has a throughput similar to the one you obtain while training, then dataloading alone takes about as long as a full training epoch, so the training computation is mostly waiting on data and dataloading is the bottleneck. If instead the run without training is much faster, the training computation is the limiting factor. Take a minute to make sure this makes sense, then observe the two runs below.  
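
The comparison can be sketched in a few lines of plain Python. This is a minimal illustration, not the code in ``main.py``: `measure_throughput`, `slow_batches`, and `fake_train_step` are hypothetical names, and `time.sleep` stands in for real loading and compute costs.

```python
import time


def measure_throughput(batches, train_step=None, count=len):
    """Return samples/s for one pass over `batches`.

    If `train_step` is given, it is called on every batch, which mimics a
    run with training. `count` maps a batch to its number of samples.
    """
    n_samples = 0
    start = time.perf_counter()
    for batch in batches:
        if train_step is not None:
            train_step(batch)
        n_samples += count(batch)
    elapsed = time.perf_counter() - start
    return n_samples / elapsed


def slow_batches(n_batches=5, batch_size=4, load_time=0.02):
    """Hypothetical stand-in for a dataloader: each batch takes `load_time` s."""
    for _ in range(n_batches):
        time.sleep(load_time)  # simulated disk reads, decoding, augmentation
        yield list(range(batch_size))


def fake_train_step(batch, compute_time=0.002):
    """Hypothetical stand-in for the forward/backward/optimizer step."""
    time.sleep(compute_time)


loading_only = measure_throughput(slow_batches())
with_training = measure_throughput(slow_batches(), fake_train_step)
print(f"without training: {loading_only:.1f} samples/s")
print(f"with training:    {with_training:.1f} samples/s")
```

Because the simulated loading time dominates, the two numbers end up close to each other, which is the signature of a dataloading bottleneck. With a PyTorch `DataLoader` that yields `(inputs, targets)` tuples, you would pass something like `count=lambda batch: len(batch[0])` so that samples, not tuple elements, are counted.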

In [9]:
!python main.py --n-samples=20 --epochs=1 --skip-training

[08/05/24 13:25:45] INFO: Setting up ImageNet
Train epoch 0: 100%|████████████████████| 1.00/1.00 [00:01<00:00, 1.20s/Samples]
[08/05/24 13:25:52] INFO: epoch 0:
samples/s: 14.8144, 
updates/s: 0.0000, 
val_loss: 50.1568, 
val_accuracy: 0.00%


In [10]:
!python main.py --n-samples=20 --epochs=1 

[08/05/24 13:25:58] INFO: Setting up ImageNet
Train epoch 0: 100%|█| 1.00/1.00 [00:01<00:00, 1.39s/Samples, accuracy=0, loss=7
[08/05/24 13:26:05] INFO: epoch 0:
samples/s: 12.8945, 
updates/s: 0.7164, 
val_loss: 17.2102, 
val_accuracy: 0.00%
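
To make the comparison concrete, the two throughputs reported above can be turned into an average time per sample. The arithmetic below is only a rough sketch: it assumes the ``--skip-training`` run measures the dataloading path alone and that loading and compute do not overlap, and a 20-sample run is too short to give a stable estimate.

```python
# Throughputs reported by the two runs above, in samples/s.
loading_only = 14.8144   # run with --skip-training
with_training = 12.8945  # run with training

# Average wall-clock time per sample in each run, in seconds.
t_loading = 1.0 / loading_only
t_total = 1.0 / with_training

# Share of the per-sample time that is already spent when training is skipped.
loading_share = t_loading / t_total

print(f"without training: {t_loading * 1e3:.1f} ms/sample")
print(f"with training:    {t_total * 1e3:.1f} ms/sample")
print(f"share of time spent outside the training step: {loading_share:.0%}")
```

Under these assumptions, roughly 87% of the per-sample time is spent outside the training step.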


In [6]:
## Throughput with training
# Take a look at https://docs.mila.quebec/examples/good_practices/launch_many_jobs/index.html

!srun --pty --gpus=1 --cpus-per-task=8 --mem=16G job.sh --epochs=1 --n-samples=20

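Comparing the throughput of the first two cells, skipping training yields about 14.8 samples/s, versus about 12.9 samples/s with training. The two are close, which suggests that dataloading accounts for most of the epoch time here, although a 20-sample run is too short for a firm conclusion.  
Did we leave any money on the table? Let's take a more in-depth look with the PyTorch profiler.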
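
As a starting point, here is a minimal sketch of a `torch.profiler` setup. The tiny model and the random batches are placeholders chosen for illustration and are not what ``main.py`` uses; the `record_function` labels (`dataloading`, `forward`, `backward`, `optimizer_step`) are likewise arbitrary names.

```python
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder model and optimizer, only here to give the profiler some work.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# Always profile CPU activity; add CUDA kernels when a GPU is present.
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    for _ in range(5):
        # Label each phase so it shows up as its own row in the report.
        with record_function("dataloading"):
            inputs = torch.randn(32, 128).to(device)  # stand-in for a real batch
            targets = torch.randint(0, 10, (32,)).to(device)
        with record_function("forward"):
            loss = loss_fn(model(inputs), targets)
        with record_function("backward"):
            optimizer.zero_grad()
            loss.backward()
        with record_function("optimizer_step"):
            optimizer.step()

# Aggregate events by name and show the most expensive ones.
table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(table)
```

Wrapping each phase of the training loop in `record_function` makes it easier to see which phase dominates before drilling down into individual operators.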

In [None]:
## Basic profiler setup

In [None]:
## Profiler run

A-ha! [Component]'s utilization seems off. Let's introduce a quick fix.

In [None]:
## Fix to last bottleneck

#!python main.py --num-batches=20 --epochs=1 --skip-training  --num-workers=8

In [None]:
## New profiler run, with fixed bottleneck

See? We now have a pretty telling difference in profiler outputs. Can we do any better?

Show how the output of the profiler changes once this last bottleneck is fixed. Give hints as to how to keep identifying the next bottleneck, and potential avenues for further optimization (for example using something like ``torch.compile``, or more workers, multiple GPUs, etc.)
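
One of the avenues listed above, using more dataloader workers, can be explored with a small sweep over `num_workers`. The sketch below is an illustration under stated assumptions: `SlowDataset` and `samples_per_second` are hypothetical names, and the `time.sleep` call simulates per-item decoding or augmentation cost rather than measuring the dataset used by ``main.py``.

```python
import time

import torch
from torch.utils.data import DataLoader, Dataset


class SlowDataset(Dataset):
    """Hypothetical dataset whose items are slow to produce."""

    def __init__(self, n_items=64, delay=0.005):
        self.n_items = n_items
        self.delay = delay

    def __len__(self):
        return self.n_items

    def __getitem__(self, index):
        time.sleep(self.delay)  # simulated decoding / augmentation cost
        return torch.randn(3, 32, 32), index % 10


def samples_per_second(num_workers, batch_size=8):
    """Time one full pass over the dataset with the given number of workers."""
    loader = DataLoader(SlowDataset(), batch_size=batch_size, num_workers=num_workers)
    n_samples = 0
    start = time.perf_counter()
    for inputs, _targets in loader:
        n_samples += len(inputs)
    return n_samples / (time.perf_counter() - start)


# The guard matters on platforms that start workers with "spawn".
if __name__ == "__main__":
    for n in (0, 2, 4):
        print(f"num_workers={n}: {samples_per_second(n):.1f} samples/s")
```

If throughput stops improving as workers are added, the limit has probably moved elsewhere (CPU cores, disk, or the training step itself), which is a cue to go back to the profiler.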


In [None]:
## More code changes, potential avenues for improvement.