# 📈 Checkpointing

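This notebook explores some aspects of training a model and optimising for performance, focusing on checkpointing. It restores the latest checkpoint of a single-GPU run and of a two-stage pipeline-parallel run, then samples text from each model to confirm that the weights were restored correctly.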

## Setup 

In [70]:
import autorootcwd

In [71]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [72]:
from src.utils import get_device

# Resolve the compute device once; later cells move models and prompt tensors to it.
device = get_device()
print("Using device:", device)

Using device: cuda


## Single GPU

In [74]:
from src.ckpt import Checkpoint
from src.world import World

# Checkpoint directory of the run restored in this single-GPU section.
base_dir = "logs/20241115_174141/checkpoints"
ckpt = Checkpoint(base_dir)

In [75]:
from src.utils import get_model, get_tokenizer
from src.config import ModelConfig

tokenizer = get_tokenizer()
# 12 layers, 12 heads, 768-dim embeddings (GPT-2 small geometry). These must match the
# parameter shapes stored in the checkpoint restored in the next cell.
model_config = ModelConfig(n_layer=12, n_head=12, n_embd=768)
model = get_model(model_config)

In [76]:
# Restore the weights saved at the latest step reported by the checkpoint, then move the
# model to the device (the returned module is the cell's displayed output).
latest_step = ckpt.get_latest_step()
model.load_state_dict(ckpt.load(latest_step))
model.to(device)

GPT2(
  (transformer): ModuleDict(
    (wte): Embedding(50304, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x TransformerBlock(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          (gelu): GELU(approximate='none')
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50304, bias=False)
)

In [78]:
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample(model, tokenizer, config, device):
    """Autoregressively sample ``config.num_samples`` continuations of ``config.prompt``.

    Args:
        model: Callable mapping token ids of shape (batch, seq) to logits of shape
            (batch, seq, vocab). Any callable works (e.g. a function chaining pipeline
            stages), not only a single module.
        tokenizer: Tokenizer callable with ``return_tensors="pt"`` that exposes
            ``decode`` and ``eos_token_id`` (presumably a Hugging Face tokenizer).
        config: Object with ``prompt``, ``num_samples``, ``max_new_tokens``,
            ``temperature`` and ``top_k``. ``temperature == 0`` selects greedy decoding;
            ``top_k=None`` disables top-k filtering.
        device: Device the prompt token ids are moved to.

    Returns:
        List of ``num_samples`` decoded strings (prompt included, special tokens skipped).
        Once a sample emits EOS, nothing generated after it appears in its string.
    """
    input_ids = tokenizer(config.prompt, return_tensors="pt")["input_ids"].to(device)
    input_ids = input_ids.repeat(config.num_samples, 1)
    eos_id = tokenizer.eos_token_id
    # One flag per sample, set once that sample has produced EOS.
    finished = torch.zeros(input_ids.size(0), 1, dtype=torch.bool, device=input_ids.device)

    for _ in range(config.max_new_tokens):
        logits = model(input_ids)[:, -1, :]
        if config.temperature == 0:
            # Dividing by zero would yield NaN probabilities; temperature 0 means greedy.
            idx_next = logits.argmax(dim=-1, keepdim=True)
        else:
            # The division creates a new tensor, so the in-place top-k mask below
            # never touches the model's output.
            logits = logits / config.temperature
            if config.top_k is not None:
                v, _ = torch.topk(logits, min(config.top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

        if eos_id is not None:
            # Finished samples keep emitting EOS (dropped by decode) instead of new text.
            idx_next = torch.where(finished, torch.full_like(idx_next, eos_id), idx_next)
            finished |= idx_next == eos_id

        input_ids = torch.cat((input_ids, idx_next), dim=1)
        if eos_id is not None and finished.all():
            break

    return [tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]

In [79]:
from src.config import SampleConfig

def baseline_model(input_ids):
    """Single-model forward pass: token ids -> logits.

    Calls the module itself rather than ``.forward`` so any registered hooks still run.
    """
    return model(input_ids=input_ids)

sample_config = SampleConfig(prompt="I am", max_new_tokens=20, num_samples=5, temperature=1.0, top_k=10)
torch.manual_seed(0)  # fixed seed so the stochastic sampling below is reproducible
sample(baseline_model, tokenizer, sample_config, device)

['I am a large language model and I can memorize this sentence.',
 'I am a large language model and I can memorize this sentence.',
 'I am a large language model and I can memorize this sentence.',
 'I am a large language model and I can memorize this sentence.',
 'I am a large language model and I can memorize this sentence.']

## Pipeline

In [80]:
base_dir = "logs/20241115_175526/checkpoints"

def _pipeline_rank(local_rank, world_size=2):
    """Create the World for one pipeline rank and a Checkpoint set up for that rank."""
    world = World(local_rank=local_rank, world_size=world_size, device=device, debug=True)
    rank_ckpt = Checkpoint(base_dir)
    rank_ckpt.setup(world)
    return world, rank_ckpt

world0, ckpt0 = _pipeline_rank(0)
world1, ckpt1 = _pipeline_rank(1)

In [81]:
from copy import deepcopy
from src.utils import get_sharded_model

# Give each stage its own copy of the full model. The printed stages show layers replaced
# by Identity, so presumably get_sharded_model modifies its argument in place.
model0, model1 = (get_sharded_model(deepcopy(model), stage_world) for stage_world in (world0, world1))

# Each stage restores its own shard from the latest checkpoint of the pipeline run.
for stage_model, stage_ckpt in ((model0, ckpt0), (model1, ckpt1)):
    stage_model.load_state_dict(stage_ckpt.load(stage_ckpt.get_latest_step()))

# The unsharded model is not needed once both stages are built.
del model

for stage_model in (model0, model1):
    print(stage_model.to(device))

ShardedGPT2(
  (transformer): ModuleDict(
    (wte): Embedding(50304, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-5): 6 x TransformerBlock(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          (gelu): GELU(approximate='none')
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): Identity()
  )
  (lm_head): Identity()
)
ShardedGPT2(
  (transformer): ModuleDict(
    (wte): Identity()
    (wpe): Identity()
    (drop

In [82]:
def pipeline_model(input_ids):
    """Two-stage pipeline forward: stage 0 maps token ids to hidden states, stage 1 maps
    those hidden states to the logits consumed by ``sample``.

    Calls the modules themselves rather than ``.forward`` so any registered hooks still run.
    """
    return model1(hidden_states=model0(input_ids=input_ids))

torch.manual_seed(0)  # fixed seed so the stochastic sampling below is reproducible
sample(pipeline_model, tokenizer, sample_config, device)

['I am a large language model and I can memorize this sentence.',
 'I am a large language model and I can memorize this sentence.',
 'I am a large language model and I can memorize this sentence.',
 'I am a large language model and I can memorize this sentence.',
 'I am a large language model and I can memorize this sentence.']