# This notebooks explores data distributed parallelization
Here the step to make data distributed parallel work:
1. initialize data distributed parallel -> 3. Initialize
2. prepare the data loader, so that the batches are loaded in the different processes (and cudas!) -> point 4
3. wrap the model -> point 5
4. if you want to save the model (or better it's params) every n steps, use model.module.state_dict() -> point 5
5. To run the model use <i>mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)</i>

# 1. Import Packages

In [10]:
# data distributed parallelization DDP
import torch
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

from torch.utils.data import Dataset, DataLoader

# 2. Params

In [11]:
world_size = torch.cuda.device_count() # also in 6. Run Training
print(f'number of cudas = {world_size}')

number of cudas = 1


# 3. Initialize Data Distributed Parallel

In [None]:
def ddp_setup(rank, world_size):
    '''
    Args:
            rank: unique identifier for each process -> rank is given automatically by mp.spawn()
            world_size: total number of processes
    '''
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    init_process_group(backend='nccl', rank = rank, world_size = world_size)
    torch.cuda.set_device(rank)
    

# 4. Prepare Dataloader

In [5]:
def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(dataset)
    )

# 5. Wrap Model

In [None]:
model = DDP(model, device_ids=[gpu_id])

In [6]:
def _save_checkpoint(self, epoch):
# access to the model parameter
    ckp = self.model.module.state_dict()
    PATH = "checkpoint.pt"
    torch.save(ckp, PATH)
    print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

# 6. Run the training

In [8]:
def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
    ddp_setup(rank, world_size)
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, rank, save_every)
    trainer.train(total_epochs)
    destroy_process_group()

In [7]:
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()
    
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)

usage: ipykernel_launcher.py [-h] [--batch_size BATCH_SIZE]
                             total_epochs save_every
ipykernel_launcher.py: error: argument total_epochs: invalid int value: '/home/furio/.local/share/jupyter/runtime/kernel-bd8d7af0-4f4f-447d-b574-a21062084381.json'


AssertionError: 