# Example: Single Node Multi-GPU training with `quix` in Jupyter

In this example, we look at how `quix` can be used for single-node, multi-GPU training from a Jupyter Notebook.
This is not the most efficient way to train. It can still be useful for light training runs on a
node where multiple users have direct access to the GPUs, such as `samsida.hpc.uio.no`. Treat it as an *illustrative
example* rather than the recommended workflow for training with `quix`.

To start off, we set up some standard Jupyter magic commands. They are of limited use in this particular notebook, but it is
good practice to put them in the first cell. The `autoreload` extension reloads imported modules before each cell is executed,
so edits to their source are picked up without restarting the kernel. `%matplotlib inline` renders plots directly in the notebook.

In [1]:
```python
%load_ext autoreload
%autoreload 2
%matplotlib inline
```

## Step 1: Imports and environment variables

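Our first actual step is to set the `CUDA_VISIBLE_DEVICES` environment variable. 
This tells PyTorch to use only a subset of the GPUs on the node, since we are sharing it with others. 
In this case, GPUs 1, 2 and 3 appear to be free, so we make exactly these three GPUs visible. The variable must be set before CUDA is initialized in the process, which is why we do it first. 
We also set `MASTER_ADDR` and `MASTER_PORT`. `torch.distributed` reads these when the process group is initialized with the default `env://` method. Since all workers run on this node, `localhost` and any free port will do.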

In [2]:
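```python
import os

# Restrict this process (and the workers it spawns) to GPUs 1, 2 and 3
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3'
# Rendezvous address and port for torch.distributed on a single node
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29501'
```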
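`CUDA_VISIBLE_DEVICES` also renumbers the GPUs a process can see. With `'1,2,3'`, PyTorch exposes them as `cuda:0`, `cuda:1` and `cuda:2`, so `torch.cuda.device_count()` returns 3. The hypothetical helper below (not part of `quix` or PyTorch) makes that mapping explicit, which can help when matching ranks against `nvidia-smi` output. It is a sketch that assumes the variable holds plain integer indices rather than GPU UUIDs. It also assumes each rank uses the device with the same index, the usual single-node DDP convention.

```python
def visible_gpu_map(cuda_visible_devices):
    """Hypothetical helper: map logical CUDA indices (cuda:i, as seen by
    PyTorch) to the physical GPU indices listed in CUDA_VISIBLE_DEVICES.

    Assumes the value holds comma-separated integer indices, not GPU UUIDs.
    """
    physical = [int(tok) for tok in cuda_visible_devices.split(",") if tok.strip()]
    # The i-th listed physical GPU becomes logical device cuda:i.
    return dict(enumerate(physical))


mapping = visible_gpu_map("1,2,3")  # e.g. os.environ["CUDA_VISIBLE_DEVICES"]
for rank, gpu in mapping.items():
    # Assumes each rank uses cuda:<rank>, the usual single-node DDP convention.
    print(f"rank {rank} -> cuda:{rank} -> physical GPU {gpu}")
```

For `'1,2,3'` this prints `rank 0 -> cuda:0 -> physical GPU 1`, and so on for ranks 1 and 2.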

## Step 2: Writing a `main_worker` using `quix.run`

The next step is to define a `main_worker` function that carries out the 
training. This function has to live in a separate script. `torch.multiprocessing.spawn` 
starts each worker in a fresh Python process, and that process can only find 
`main_worker` if it is a top-level function in an importable module. A function 
defined in a notebook cell does not meet that requirement. The workaround is simple: 
the cell magic `%%writefile` writes the contents of a cell to a file in the 
notebook's current working directory.
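The top-level requirement comes from how the workers are started. `torch.multiprocessing.spawn` uses the `spawn` start method, so each child receives the function by pickling. `pickle` stores functions by reference (module plus qualified name), not by their code. The standard-library sketch below shows the difference. A function importable by name round-trips. A nested function cannot be pickled at all. A notebook-defined function pickles in the parent but fails to unpickle in a spawned child, because the child cannot import the kernel's `__main__`.

```python
import json
import pickle

# Any function importable by name (here json.dumps) is pickled as a reference:
# the payload records the module and attribute name, not the function's code.
payload = pickle.dumps(json.dumps)
restored = pickle.loads(payload)  # re-imports json and looks up "dumps"
print(restored is json.dumps)  # True


def make_local_worker():
    def local_worker(rank):  # nested function: not reachable as a module attribute
        return rank
    return local_worker


# A nested function cannot be pickled, since a child process could never
# look it up by name.
try:
    pickle.dumps(make_local_worker())
    local_ok = True
except (pickle.PicklingError, AttributeError) as exc:
    local_ok = False
    print(f"cannot pickle a local function: {type(exc).__name__}")
```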

The `from_dict` method of `quix.run.Runner` lets us pass the run configuration 
directly as keyword arguments. We will train a ResNet18 on the Caltech256 
classification dataset for 50 epochs as a test run. The main worker mostly 
relies on the `quix` defaults, with a few small variations. The script also 
includes a small workaround for importing `quix` when this notebook is run from 
a subfolder of the repository. The workaround is skipped if `quix` has been 
installed with `pip`.

Once the file has been written, `main_worker` is a top-level function in the 
`train_test` module. We can therefore import it into the notebook and launch it 
with `torch.multiprocessing.spawn`.
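Outside a notebook, or without cell magics, the `%%writefile` step can be sketched with the standard library. The steps are: write the source to a `.py` file, make its directory importable, and import the function from it. The module name `spawn_demo_worker` and its stand-in `main_worker` are hypothetical. The file goes to a temporary directory rather than the working directory.

```python
import importlib
import pickle
import sys
import tempfile
from pathlib import Path

WORKER_SOURCE = '''
def main_worker(rank):
    # Stand-in for the real training function.
    return f"worker {rank} ready"
'''

# Equivalent of `%%writefile`: put the source in a .py file on disk.
module_dir = Path(tempfile.mkdtemp())
(module_dir / "spawn_demo_worker.py").write_text(WORKER_SOURCE)

# Make the directory importable and clear stale import-system caches,
# since the file was created after the interpreter started.
sys.path.insert(0, str(module_dir))
importlib.invalidate_caches()
worker_module = importlib.import_module("spawn_demo_worker")

print(worker_module.main_worker(0))  # worker 0 ready
# The function now belongs to a real module, so pickling stores a reference.
roundtrip = pickle.loads(pickle.dumps(worker_module.main_worker))
print(roundtrip is worker_module.main_worker)  # True
```

After the import, pickling the function stores only a reference to `spawn_demo_worker.main_worker`. That reference is what a spawned child uses to look the function up.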

In [3]:
```python
%%writefile train_test.py

import torch

# Check if quix is installed as a package in the environment
try:
    from quix.run import Runner

# If not, assume we are running from the notebook in the repo and add the repo root to the path
except ModuleNotFoundError:
    import sys
    sys.path.append('../')
    from quix.run import Runner

# Define a small main_worker function
# Feel free to change these settings if you want to test for yourself
def main_worker(rank):
    # One process per visible GPU; spawn passes each process its index as rank
    world_size = torch.cuda.device_count()
    runner = Runner.from_dict(
        model='resnet18',
        custom_runid='mytestrun',
        project='testproject',
        dataset='Caltech256',
        num_classes=257,
        epochs=50,
        aug3=True,
        input_ext='jpg',
        target_ext='cls',
        data_path='/work2/litdata/',
        batch_size=512,
        lr_init=3e-5,
        model_ema=True,
        world_size=world_size,
        rank=rank,
        local_world_size=world_size,
        local_rank=rank,
    )
    runner.run()
```

Overwriting train_test.py
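Because `Runner.from_dict` receives the configuration as keyword arguments, the settings can also be kept in a plain dictionary and unpacked with `**`. This makes it easy to reuse them across test runs. The sketch below assumes `from_dict` accepts the same keyword arguments as in `train_test.py`. The helper `make_run_config` and the `BASE_CONFIG` dictionary are hypothetical and not part of `quix`.

```python
# Shared hyperparameters, copied from the main_worker in train_test.py.
BASE_CONFIG = {
    "model": "resnet18",
    "dataset": "Caltech256",
    "num_classes": 257,
    "epochs": 50,
    "batch_size": 512,
    "lr_init": 3e-5,
}


def make_run_config(rank, world_size, **overrides):
    """Hypothetical helper: merge shared settings, per-run overrides and the
    single-node distributed fields into one keyword-argument dictionary."""
    config = {**BASE_CONFIG, **overrides}
    # On a single node, the global and local views of the process group coincide.
    config.update(
        world_size=world_size,
        rank=rank,
        local_world_size=world_size,
        local_rank=rank,
    )
    return config


cfg = make_run_config(rank=1, world_size=3, batch_size=256)
print(cfg["batch_size"], cfg["rank"], cfg["local_world_size"])  # 256 1 3
# Inside main_worker this would become: Runner.from_dict(**cfg)
```

The merge copies `BASE_CONFIG` rather than mutating it. Each worker or test run therefore starts from the same defaults.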


## Step 3: Running the training using the `main_worker`

We've arrived at the meat and potatoes of the example: the actual run. We wrote
a small `main_worker` function, and we now want to use it to launch training. 
To carry out distributed training, we start one process per GPU with 
`torch.multiprocessing.spawn`. We use `torch.cuda.device_count()` (3 here, given 
`CUDA_VISIBLE_DEVICES`) as the world size for DDP. Each process receives its 
index as the `rank` argument of `main_worker`.
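`torch.multiprocessing.spawn(fn, args=(), nprocs=1, join=True, ...)` calls `fn(i, *args)` in each of the `nprocs` processes, with `i` running from `0` to `nprocs - 1`. While debugging, it can help to mimic this calling convention sequentially in the notebook process. The `run_sequentially` helper below is a hypothetical stand-in for that purpose only. It creates no processes and no process group, so it is only useful with toy workers. It cannot replace `spawn` for real DDP training, where a multi-rank `main_worker` would wait for peers that never start.

```python
def run_sequentially(fn, args=(), nprocs=1):
    """Hypothetical debugging aid: call fn(i, *args) for i in range(nprocs)
    in the current process, mirroring the arguments spawn passes to each worker.
    No processes or process group are created."""
    return [fn(i, *args) for i in range(nprocs)]


def toy_worker(rank, world_size):
    # A real worker would initialize the process group and train here.
    return f"rank {rank} of {world_size}"


results = run_sequentially(toy_worker, args=(3,), nprocs=3)
print(results)  # ['rank 0 of 3', 'rank 1 of 3', 'rank 2 of 3']
```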

In [4]:
```python
import torch
from torch.multiprocessing import spawn
from train_test import main_worker

# Launch one worker per visible GPU; join=True blocks until all workers finish
spawn(main_worker, nprocs=torch.cuda.device_count(), join=True)
```

```
Parsing augmentations...
Parsing data...
Parsing model...
Parsing loss...
Parsing parameter groups...
Parsing optimizer...
Parsing scaler...
Parsing scheduler...
Parsing DDP...
Parsing EMA...
Parsing checkpoint...
Parsing logger...
Finished parsing!
{'time': 1703173166.673809, 'epoch': 0, 'iteration': 0, 'timedelta': 17.3375506401062, 'loss': 5.662629127502441, 'Acc1': 0.001953125, 'Acc5': 0.01171875, 'last_lr': 3.013852679268472e-05}
{'time': 1703173166.993087, 'epoch': 0, 'iteration': 1, 'timedelta': 0.3192780017852783, 'loss': 5.6316657066345215, 'Acc1': 0.00390625, 'Acc5': 0.013671875, 'last_lr': 3.055408132606264e-05}
{'time': 1703173168.0989506, 'epoch': 0, 'iteration': 2, 'timedelta': 1.1058635711669922, 'loss': 5.651215553283691, 'Acc1': 0.001953125, 'Acc5': 0.01171875, 'last_lr': 3.124658607092611e-05}
{'time': 1703173169.2100976, 'epoch': 0, 'iteration': 3, 'timedelta': 1.1111469268798828, 'loss': 5.627786636352539, 'Acc1': 0.00390625, 'Acc5': 0.025390625, 'last_lr': 3.221591
```

## Step 4: Clean up the test script and prepare for inference

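Since we no longer need the module for the main worker, we can delete the 
reference and the file, and do some manual clean-up. The spawned worker processes 
have already exited by the time `spawn` returns, which releases the GPU memory 
they held. The calls below additionally free whatever the notebook process itself 
has cached, so the GPUs are available for inference or other tasks.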

In [None]:
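```python
import gc
import os
import sys

import torch

# Delete the reference, drop the cached module, and remove the file
del main_worker
sys.modules.pop('train_test', None)
os.remove('train_test.py')

# Run garbage collection and release cached GPU memory held by this process
gc.collect()
torch.cuda.empty_cache()

# Now we can load the model and perform inference as normal if we want.
```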

## Conclusion

We've seen a method for training a model on a single node with multiple GPUs 
using the Jupyter Notebook format. 

Even though we need to jump through a few 
hoops (defining the main function in a separate script), it can be 
done `quixly`$^\mathrm{TM}$ in just a few lines of code.