## Training

* Training Loop
* Automatic Mixed Precision (AMP)
* Distributed Data Parallelism (DDP)
* DDP with Gradient Accumulation
* Logging

### Setup

In [None]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

# Fix the RNG seed so the random data and weight init are reproducible,
# and print tensors in plain decimal notation on wide lines.
torch.manual_seed(740)
torch.set_printoptions(sci_mode=False, linewidth=160)

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_type = "cuda"
else:
    device_type = "cpu"
device = device_type

# Toy problem sizes.
data_size = 100
n_embd = 1000
vocab_size = 10
batch_size = 2
context_size = 8

# Random token ids in [0, vocab_size) standing in for a real tokenized corpus.
# torch.randint already yields int64, which is stated explicitly here.
train_tokens = torch.randint(low=0, high=vocab_size, size=(data_size,), dtype=torch.int64)
valid_tokens = torch.randint(low=0, high=vocab_size, size=(data_size,), dtype=torch.int64)

def get_batch(data, device=device, batch_size=batch_size, context_size=context_size):
    """Sample a random mini-batch of next-token prediction examples.

    Args:
        data: Tensor of token ids, sliced along its first dimension
            (e.g. `train_tokens`).
        device: Device the returned tensors are moved to.
        batch_size: Number of windows to sample.
        context_size: Length of each window.

    Returns:
        A pair `(X, y)` of stacked windows, `batch_size` rows each. Row `i`
        of `y` is row `i` of `X` shifted one position to the right in
        `data`, i.e. the next-token targets.

    Note:
        The defaults are bound when the function is defined, so reassigning
        the module-level `device` later does not change them; pass `device`
        explicitly in that case.
    """
    # Upper bound leaves room for the one-token look-ahead used by the targets.
    starts = torch.randint(low=0, high=data.shape[0] - context_size, size=(batch_size,))

    inputs = []
    targets = []
    for start in starts:
        stop = start + context_size
        inputs.append(data[start:stop])
        targets.append(data[start + 1:stop + 1])

    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

# Draw one example batch up front; it is used below to inspect the untrained model.
X, y = get_batch(train_tokens)

class SimpleModel(nn.Module):
    """Toy per-token language model: embedding -> two-layer GELU MLP -> vocab logits.

    Positions are processed independently (there is no attention or other
    mixing across the sequence), so the model only serves as a small
    stand-in for demonstrating training loops.
    """

    def __init__(self, vocab_size, n_embd):
        """Build the layers.

        Args:
            vocab_size: Number of distinct token ids; also the logit width.
            n_embd: Embedding width; the hidden layer is 4x this size.
        """
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.linear1 = nn.Linear(n_embd, 4 * n_embd)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(4 * n_embd, n_embd)
        self.linear_out = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            idx: Integer tensor of token ids, any leading shape
                (e.g. `(batch, context)`).
            targets: Optional integer tensor of target token ids with the
                same shape as `idx`.

        Returns:
            `(logits, loss)` where `logits` has shape `idx.shape + (vocab_size,)`
            and `loss` is a scalar tensor, or `None` when `targets` is `None`.
        """
        tok_emb = self.token_embedding_table(idx)
        x = self.activation(self.linear1(tok_emb))
        x = self.activation(self.linear2(x))
        logits = self.linear_out(x)

        if targets is None:
            return logits, None

        # Flatten using the tensors' own shapes rather than module-level
        # batch_size/context_size/vocab_size, so any batch shape and any
        # constructor vocab_size works.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
        )
        return logits, loss

# Place the model on the same device get_batch() moves X and y to; otherwise the
# forward pass fails with a device mismatch whenever CUDA is available.
model = SimpleModel(n_embd=n_embd, vocab_size=vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = 0.0001)

max_steps = 100

# Sanity check on the untrained model. One forward pass is enough for all printouts.
# With near-uniform predictions the loss should be close to -log(1/vocab_size).
logits, loss = model(X, y)
print(f"{X=}, {X.shape}")
print(f"{y=}, {y.shape}")
print(f"logits = {logits}, {logits.shape}")
print(f"loss at init = {loss:.4f}, expected loss at init = {-math.log(1/vocab_size):.4f}")

### Training Loop

In [None]:
# Minimal training loop: one optimizer update per randomly sampled batch.
for step in range(max_steps):
    # Sample a fresh random batch of inputs and next-token targets.
    X, y = get_batch(train_tokens)
    # Forward pass; the model returns the loss because targets are supplied.
    logits, loss = model(X, y)
    # Backpropagate to populate .grad on every parameter.
    loss.backward()
    # Apply the AdamW update using the gradients just computed.
    optimizer.step()
    # Clear gradients so the next step does not accumulate onto this one;
    # set_to_none=True resets .grad to None instead of filling it with zeros.
    optimizer.zero_grad(set_to_none=True)

### Training Loop with Automatic Mixed Precision  (AMP)

We use two precision formats here: brain floating point 16 (BF16) and 32-bit floating point (FP32), also called full precision.

1. Starting out with the weights in FP32 you then copy the weights to BF16 format.
2. Forward pass: compute the outputs of the NN with the BF16 weights (`model(x)`)
3. Compute the gradients in BF16 (`loss.backward()`)
4. Copy the BF16 gradients back to FP32
5. Update the FP32 weights using the optimizer (`optimizer.step()`).

To summarize, you are doing the forward pass and the calculation of the gradients in BF16, and you are doing the weight update (and the loss calculation) in FP32.

This is all handled behind the scenes within a context manager provided by torch called `autocast`. They explain it [in the following way](https://pytorch.org/docs/stable/amp.html#autocasting):

> When entering an autocast-enabled region, Tensors may be any type. You should not call half() or bfloat16() on your model(s) or inputs when using autocasting. `autocast` should wrap only the forward pass(es) of your network, including the loss computation(s). Backward passes under autocast are not recommended. Backward ops run in the same type that autocast used for corresponding forward ops.

In [None]:
# Same loop as before, with the forward pass and loss computed under autocast.
for step in range(max_steps):
    X, y = get_batch(train_tokens)
    # Only the forward pass and loss computation sit inside the autocast region;
    # model weights and inputs are not cast by hand.
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits, loss = model(X, y)
    # Backward pass and optimizer update run outside the autocast region.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

### Training with Distributed Data Parallelism (DDP)

DDP launches multiple processes, typically one per GPU (on one or more machines). Each process holds a full copy of the model weights and trains on a different batch; the gradients are then synchronized across processes before being used to update the weights, so every copy of the model stays identical.


This works when the model fits on a single GPU. (Larger models require different parallel techniques, such as [FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html).)

To run the training, launch the script with the `torchrun` command instead of `python`, e.g.:

```
# Original
python train.py  --max_steps 600000

# DDP
torchrun  --standalone --nproc_per_node=8 train.py --max_steps 600000
```

To implement DDP, you need to modify your training code in 4 places.

In [None]:
# 0. Initialize the DDP process
import os
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

init_process_group(backend="nccl") # https://pytorch.org/docs/stable/distributed.html

# 1. Incorporate environment variables (set from torchrun)
ddp_rank = int(os.environ["RANK"])              # global rank of this process
ddp_local_rank = int(os.environ["LOCAL_RANK"])  # rank on this node; selects the GPU
ddp_world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
device = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(device)
main_process = ddp_rank == 0  # use this to restrict logging/checkpointing to one process

# 2. Wrap your model in a DDP container (access model using model.module)
# The model has to live on this process's GPU before it is wrapped.
model.to(device)
model = DDP(model, device_ids=[ddp_local_rank])

# Same as before, except batches are placed on this process's GPU.
# DDP synchronizes gradients across processes during loss.backward().
for step in range(max_steps):
    X, y = get_batch(train_tokens, device)
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits, loss = model(X, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

# 3. Cleanup at training end
destroy_process_group()

### Training with DDP and Gradient Accumulation

In [None]:
# 0. Initialize the DDP process
import os
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

init_process_group(backend="nccl") # https://pytorch.org/docs/stable/distributed.html

# 1. Incorporate environment variables (set from torchrun)
ddp_rank = int(os.environ["RANK"])              # global rank of this process
ddp_local_rank = int(os.environ["LOCAL_RANK"])  # rank on this node; selects the GPU
ddp_world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
device = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(device)
main_process = ddp_rank == 0  # use this to restrict logging/checkpointing to one process

# 2. Wrap your model in a DDP container (access model using model.module)
# The model has to live on this process's GPU before it is wrapped.
model.to(device)
model = DDP(model, device_ids=[ddp_local_rank])

# Number of micro-batches whose gradients are accumulated before each optimizer step.
gradient_accumulation_steps_per_gpu = 4

for step in range(max_steps):

    # Iterate through each substep to simulate the larger batch size
    # Example: batch_size_per_gpu = 16, gradient_accumulation_steps_per_gpu = 4, n_gpus = 8
    # So effective batch_size is 512, meaning parameter update is based off of 4 * 16 * 8 = 512 training examples.
    for sub_step in range(gradient_accumulation_steps_per_gpu):
        X, y = get_batch(train_tokens, device)

        # Sync gradients only on the last gradient accumulation step; earlier
        # micro-steps just add into the local .grad buffers. The flag is set
        # before the forward pass. (DDP's no_sync() context manager is the
        # public API for skipping synchronization.)
        model.require_backward_grad_sync = (sub_step + 1) == gradient_accumulation_steps_per_gpu
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(X, y)
            # Scale so the accumulated gradient is the mean over micro-batches, not the sum.
            loss = loss / gradient_accumulation_steps_per_gpu
        # Backward must run on every micro-step; otherwise only the last
        # micro-batch would contribute to the update.
        loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

# 3. Cleanup at training end
destroy_process_group()

### Training with Logging

In [None]:

from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
import os
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

# NOTE: this cell assumes estimate_loss, lr, gradient_accumulation_steps_per_gpu,
# best_eval_loss and the hyperparameters referenced in hparam_dict are defined
# elsewhere in the training script.

# 0. Before Training
init_process_group(backend="nccl") # https://pytorch.org/docs/stable/distributed.html
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
device = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(device)
main_process = ddp_rank == 0

# Only the main process writes TensorBoard events; otherwise every rank would
# produce its own duplicate run directory.
writer = None
if main_process:
    writer = SummaryWriter(log_dir=f"/tmp/data/output/tensorboard/training-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}")

# The model has to live on this process's GPU before it is wrapped.
model.to(device)
model = DDP(model, device_ids=[ddp_local_rank])

for step in range(max_steps):

    # 1. Within Training
    # Log once per optimizer step (not once per micro-step), so each scalar
    # gets a single point per `step` and evaluation is not repeated needlessly.
    # NOTE(review): estimate_loss is defined elsewhere; this assumes it is safe
    # to call on the main process only. Confirm it performs no collective ops.
    if main_process:
        losses = estimate_loss(model)
        writer.add_scalar("Loss/train", losses["train"], step)
        writer.add_scalar("Loss/eval", losses["eval"], step)
        writer.add_scalar("learning_rate", lr, step)

    for sub_step in range(gradient_accumulation_steps_per_gpu):
        X, y = get_batch(train_tokens, device)

        # Sync gradients only on the last gradient accumulation step.
        model.require_backward_grad_sync = (sub_step + 1) == gradient_accumulation_steps_per_gpu
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(X, y)
            # Scale so the accumulated gradient is the mean over micro-batches, not the sum.
            loss = loss / gradient_accumulation_steps_per_gpu
        loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)


# After training
if main_process:
    hparam_dict = {
        "max_learning_rate": max_learning_rate,
        "max_steps": max_steps,
        "weight_decay": weight_decay,
        "n_layer": n_layer,
        "n_embd": n_embd,
        "context_size": context_size,
        "total_training_tokens": total_training_tokens,
        "n_times_through_data": total_training_tokens / total_training_tokens_unique,
    }
    metric_dict = {"hparam/loss": best_eval_loss.item()}
    writer.add_hparams(hparam_dict=hparam_dict, metric_dict=metric_dict)
    writer.flush()
    # Release the event file handle once everything has been written.
    writer.close()
destroy_process_group()