<a href="https://colab.research.google.com/github/lizhieffe/llm_knowledge/blob/main/examples/pytorch_dist/%5BDist%5D_PyTorch_FSDP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook follows the PyTorch FSDP tutorial:
https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html

In [1]:
# @title Install the requirements
# See https://github.com/pytorch/examples/blob/main/distributed/FSDP2/requirements.txt
# `%pip` installs into the environment of the running kernel; `!pip` runs
# whichever `pip` is first on the shell PATH, which may belong to another env.
%pip install torch==2.7.0



In [2]:
import torch
import torch.nn as nn

from torch.distributed.fsdp import fully_shard, FSDPModule

## Base Transformer Model

https://github.com/pytorch/examples/blob/main/distributed/FSDP2/model.py


In [12]:

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    """Hyperparameters for the toy Transformer model.

    The field order below is the positional order of the generated ``__init__``.
    """

    n_layers: int = 4  # number of TransformerBlock layers stacked in the model
    vocab_size: int = 8  # rows of the token embedding and width of the output logits
    max_seq_len: int = 16  # longest accepted sequence; rows of the positional embedding
    dim: int = 16  # model width; must be divisible by n_heads
    n_heads: int = 4  # number of attention heads (head_dim = dim // n_heads)
    dropout_p: float = 0.1  # dropout probability used by attention, MLP and embeddings


class Attention(nn.Module):
    """Multi-head causal self-attention.

    Projects the input to queries/keys/values, runs
    ``F.scaled_dot_product_attention`` with a causal mask so each position only
    attends to itself and earlier positions (as in GPT-style / nanoGPT language
    models), then applies the output projection and residual dropout.

    Args:
        args: model hyperparameters; reads ``dim``, ``n_heads`` and
            ``dropout_p``. ``dim`` must be divisible by ``n_heads``.
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()
        assert args.dim % args.n_heads == 0
        self.head_dim = args.dim // args.n_heads
        self.n_heads = args.n_heads
        self.dropout_p = args.dropout_p
        self.resid_dropout = nn.Dropout(args.dropout_p)

        self.wq = nn.Linear(args.dim, args.dim, bias=False)
        self.wk = nn.Linear(args.dim, args.dim, bias=False)
        self.wv = nn.Linear(args.dim, args.dim, bias=False)
        self.wo = nn.Linear(args.dim, args.dim, bias=False)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: input of shape ``(bsz, seq_len, dim)``.

        Returns:
            Tensor of shape ``(bsz, seq_len, dim)``.
        """
        bsz, seq_len, _ = x.size()
        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)

        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)

        output = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=None,
            # Attention-weight dropout only while training.
            dropout_p=self.dropout_p if self.training else 0.0,
            # Position t must not see tokens after t; without the mask the
            # model can read the very tokens it is meant to predict.
            is_causal=True,
        )
        # (bsz, n_heads, seq_len, head_dim) -> (bsz, seq_len, dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.resid_dropout(self.wo(output))

    def reset_parameters(self):
        """Re-initialize the query/key/value/output projections."""
        for proj in (self.wq, self.wk, self.wv, self.wo):
            proj.reset_parameters()


class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear, then dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden projection.
        dropout_p: dropout probability applied to the block output.
    """

    def __init__(self, dim, hidden_dim, dropout_p):
        super().__init__()
        # Submodules are registered in a fixed order (w1, gelu, w2,
        # resid_dropout); that order fixes parameter names and the order in
        # which random initial weights are drawn.
        self.w1 = nn.Linear(dim, hidden_dim)
        self.gelu = nn.GELU()
        self.w2 = nn.Linear(hidden_dim, dim)
        self.resid_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        hidden = self.gelu(self.w1(x))
        projected = self.w2(hidden)
        return self.resid_dropout(projected)

    def reset_parameters(self):
        for linear in (self.w1, self.w2):
            linear.reset_parameters()


class TransformerBlock(nn.Module):
    """Pre-norm Transformer layer: attention and MLP, each behind a residual.

    Computes ``h = x + attention(LN(x))`` and returns ``h + mlp(LN(h))``.

    Args:
        args: model hyperparameters. ``dim`` sizes both LayerNorms and the MLP
            (hidden width ``4 * dim``); the whole object is passed to
            ``Attention``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Creation order is kept so parameter names/ordering and the sequence
        # of random draws during initialization are unchanged.
        self.attention_norm = nn.LayerNorm(args.dim)
        self.attention = Attention(args)
        self.ffn_norm = nn.LayerNorm(args.dim)
        self.feed_forward = FeedForward(
            args.dim,
            hidden_dim=args.dim * 4,
            dropout_p=args.dropout_p,
        )

    def forward(self, x):
        attn_out = self.attention(self.attention_norm(x))
        h = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + ffn_out

    def reset_parameters(self):
        submodules = (
            self.attention_norm,
            self.attention,
            self.ffn_norm,
            self.feed_forward,
        )
        for module in submodules:
            module.reset_parameters()


# A toy transformer model, partly inspired by the nanoGPT model:
# https://github.com/karpathy/nanoGPT.
class Transformer(nn.Module):
    """Toy Transformer: token + learned positional embeddings, a stack of
    ``TransformerBlock`` layers, a final LayerNorm and a vocabulary projection.

    Args:
        args: model hyperparameters. ``vocab_size`` and ``max_seq_len`` must
            not be ``None``; ``n_layers`` blocks are created.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size is not None
        assert args.max_seq_len is not None
        self.model_args = args
        self.max_seq_len = args.max_seq_len
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
        self.dropout = nn.Dropout(args.dropout_p)
        self.layers = nn.ModuleList(
            TransformerBlock(args) for _ in range(args.n_layers)
        )
        self.norm = nn.LayerNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens):
        """Compute per-position vocabulary logits.

        Args:
            tokens: integer token ids of shape ``(bsz, seq_len)`` with
                ``seq_len <= max_seq_len``.

        Returns:
            float32 logits of shape ``(bsz, seq_len, vocab_size)``.
        """
        _bsz, seq_len = tokens.size()
        assert seq_len <= self.max_seq_len, (
            f"seq_len={seq_len} exceeds max_seq_len={self.max_seq_len}"
        )
        h = self.tok_embeddings(tokens)
        pos = torch.arange(0, seq_len, device=tokens.device)
        # (seq_len, dim); broadcast over the batch dimension when added.
        p = self.pos_embeddings(pos)
        h = h + p
        h = self.dropout(h)
        for layer in self.layers:
            h = layer(h)
        h = self.norm(h)
        # Upcast the logits to float32 regardless of the parameter dtype.
        output = self.output(h).float()
        return output

    def reset_parameters(self):
        """Re-initialize every parameter of the model.

        Resets the embeddings, every ``TransformerBlock`` in ``self.layers``
        (each resets its own submodules), the final norm and the output
        projection, so a single call re-initializes the whole model.
        """
        self.tok_embeddings.reset_parameters()
        self.pos_embeddings.reset_parameters()
        for layer in self.layers:
            layer.reset_parameters()
        self.norm.reset_parameters()
        self.output.reset_parameters()

## FSDP

In [13]:
from torch.distributed.tensor import DTensor, Shard

epochs = 2


def run(rank, world_size, num_epochs=epochs):
  """Train the toy Transformer under FSDP2 on this rank (implicit prefetching).

  Meant to be called once per process after the process group has been
  initialized; the launcher cell calls it as ``fn(rank, size)``.

  Args:
    rank: rank of this process; only rank 0 prints progress.
    world_size: number of processes (not used in the body; kept for the
      ``fn(rank, size)`` calling convention).
    num_epochs: number of training steps. The default is bound when this cell
      runs, so a later rebinding of the global ``epochs`` does not leak in.
  """
  device = "cpu"
  max_seq_len = 16
  vocab_size = 8
  batch_size = 4
  seq_len = max_seq_len

  # Same seed on every rank: all ranks construct identical initial weights
  # before sharding, and the run is reproducible.
  torch.manual_seed(0)
  model_args = ModelArgs(max_seq_len=max_seq_len, vocab_size=vocab_size)
  model = Transformer(model_args)

  # Shard each TransformerBlock as its own FSDP unit, then the root module,
  # which takes the remaining parameters (embeddings, final norm, output).
  for layer in model.layers:
    fully_shard(layer)
  fully_shard(model)

  # fully_shard modifies the module in place: it is still a Transformer but
  # now also an FSDPModule, and every parameter is a DTensor sharded on dim 0.
  assert isinstance(model, Transformer)
  assert isinstance(model, FSDPModule)
  for param in model.parameters():
    assert isinstance(param, DTensor)
    assert param.placements == (Shard(0),)

  optim = torch.optim.Adam(model.parameters(), lr=1e-2)

  # Different seed per rank so each data-parallel rank draws different batches.
  torch.manual_seed(rank + 1)
  for epoch in range(num_epochs):
    tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

    # Dummy objective with no physical meaning: sum the logits to get a
    # scalar to backpropagate through.
    logits = model(tokens)  # [B, N, vocab_size]
    loss = logits.sum()  # 0-dim scalar

    loss.backward()
    optim.step()
    optim.zero_grad()

    if rank == 0:
      print(f"Finished {epoch=}")


In [14]:
# @title Boilerplate to run the distributed programs

import os
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment, run ``fn``, then tear it down.

    Args:
        rank: rank of this process.
        size: total number of processes (the world size).
        fn: callable invoked as ``fn(rank, size)`` once the group is up.
        backend: ``torch.distributed`` backend; ``gloo`` is used here on CPU.
    """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    # Use the `size` parameter: a global `world_size` is only defined by the
    # parent's `__main__` block and does not exist in a spawned child.
    print(f"Initiated {backend=} {rank=} world_size={size}")
    try:
        fn(rank, size)
    finally:
        # Release the process group even if `fn` raises.
        dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    if "google.colab" in sys.modules:
        print("Running in Google Colab")
        # Functions defined in notebook cells cannot be located by spawned
        # children (multiprocessing needs an importable `__main__`), so fork
        # on Colab (Linux).
        start_method = "fork"
    else:
        start_method = "spawn"
    # A context object avoids mp.set_start_method, which raises RuntimeError
    # once a start method has already been set (e.g. when a cell is re-run).
    ctx = mp.get_context(start_method)

    processes = []
    for rank in range(world_size):
        p = ctx.Process(target=init_process, args=(rank, world_size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Surface worker crashes instead of letting the cell finish silently.
    failed = [(i, p.exitcode) for i, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        raise RuntimeError(f"Worker processes exited abnormally (rank, exitcode): {failed}")

Running in Google Colab
Initiated backend='gloo' rank=0 world_size=2
Initiated backend='gloo' rank=1 world_size=2
Finished epoch=0
Finished epoch=1


## FSDP with Explicit Prefetching

In [19]:
from torch.distributed.tensor import DTensor, Shard

epochs = 4


def _set_forward_prefetch(model, num_to_forward_prefetch):
  """Make layer i issue the all-gathers of layers i+1..i+n in forward."""
  layers = model.layers
  for i, layer in enumerate(layers):
    # Layers whose full window would run past the last layer get no explicit
    # list (they keep FSDP's default behavior).
    if i >= len(layers) - num_to_forward_prefetch:
      break
    layer.set_modules_to_forward_prefetch(
        [layers[i + j] for j in range(1, num_to_forward_prefetch + 1)]
    )


def _set_backward_prefetch(model, num_to_backward_prefetch):
  """Make layer i issue the all-gathers of layers i-1..i-n in backward."""
  layers = model.layers
  for i, layer in enumerate(layers):
    # The first n layers have fewer than n predecessors; skip them.
    if i < num_to_backward_prefetch:
      continue
    layer.set_modules_to_backward_prefetch(
        [layers[i - j] for j in range(1, num_to_backward_prefetch + 1)]
    )


def run_fsdp_explicit_prefetching(rank, world_size, num_epochs=epochs,
                                  num_to_forward_prefetch=2,
                                  num_to_backward_prefetch=2):
  """Train the toy Transformer under FSDP2 with explicit all-gather prefetching.

  Meant to be called once per process after the process group has been
  initialized; the launcher cell calls it as ``fn(rank, size)``.

  Args:
    rank: rank of this process; only rank 0 prints progress.
    world_size: number of processes (not used in the body; kept for the
      ``fn(rank, size)`` calling convention).
    num_epochs: number of training steps. The default is bound when this cell
      runs, so a later rebinding of the global ``epochs`` does not leak in.
    num_to_forward_prefetch: how many following layers each layer prefetches
      during forward.
    num_to_backward_prefetch: how many preceding layers each layer prefetches
      during backward.
  """
  device = "cpu"
  max_seq_len = 16
  vocab_size = 8
  batch_size = 4
  seq_len = max_seq_len

  # Same seed on every rank: all ranks construct identical initial weights
  # before sharding, and the run is reproducible.
  torch.manual_seed(0)
  model_args = ModelArgs(max_seq_len=max_seq_len, vocab_size=vocab_size)
  model = Transformer(model_args)

  # Shard each TransformerBlock as its own FSDP unit, then the root module,
  # which takes the remaining parameters (embeddings, final norm, output).
  for layer in model.layers:
    fully_shard(layer)
  fully_shard(model)

  # fully_shard modifies the module in place: it is still a Transformer but
  # now also an FSDPModule, and every parameter is a DTensor sharded on dim 0.
  assert isinstance(model, Transformer)
  assert isinstance(model, FSDPModule)
  for param in model.parameters():
    assert isinstance(param, DTensor)
    assert param.placements == (Shard(0),)

  # Explicit prefetching: with the defaults, the CPU thread issues the
  # all-gathers for layers i + 1 and i + 2 while layer i runs forward, and for
  # layers i - 1 and i - 2 while layer i runs backward.
  _set_forward_prefetch(model, num_to_forward_prefetch)
  _set_backward_prefetch(model, num_to_backward_prefetch)

  optim = torch.optim.Adam(model.parameters(), lr=1e-2)

  # Different seed per rank so each data-parallel rank draws different batches.
  torch.manual_seed(rank + 1)
  for epoch in range(num_epochs):
    # Trigger the root module's all-gather early so it can overlap with any
    # work done before model(tokens).
    model.unshard()

    tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

    # Dummy objective with no physical meaning: sum the logits to get a
    # scalar to backpropagate through.
    logits = model(tokens)  # [B, N, vocab_size]
    loss = logits.sum()  # 0-dim scalar

    loss.backward()
    optim.step()
    optim.zero_grad()

    if rank == 0:
      print(f"Finished {epoch=}")


In [20]:
# @title Boilerplate to run the distributed programs

import os
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment, run ``fn``, then tear it down.

    Args:
        rank: rank of this process.
        size: total number of processes (the world size).
        fn: callable invoked as ``fn(rank, size)`` once the group is up.
        backend: ``torch.distributed`` backend; ``gloo`` is used here on CPU.
    """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    # Use the `size` parameter: a global `world_size` is only defined by the
    # parent's `__main__` block and does not exist in a spawned child.
    print(f"Initiated {backend=} {rank=} world_size={size}")
    try:
        fn(rank, size)
    finally:
        # Release the process group even if `fn` raises.
        dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    if "google.colab" in sys.modules:
        print("Running in Google Colab")
        # Functions defined in notebook cells cannot be located by spawned
        # children (multiprocessing needs an importable `__main__`), so fork
        # on Colab (Linux).
        start_method = "fork"
    else:
        start_method = "spawn"
    # A context object avoids mp.set_start_method, which raises RuntimeError
    # once a start method has already been set (e.g. by an earlier cell).
    ctx = mp.get_context(start_method)

    processes = []
    for rank in range(world_size):
        p = ctx.Process(
            target=init_process,
            args=(rank, world_size, run_fsdp_explicit_prefetching),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Surface worker crashes instead of letting the cell finish silently.
    failed = [(i, p.exitcode) for i, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        raise RuntimeError(f"Worker processes exited abnormally (rank, exitcode): {failed}")

Running in Google Colab
Initiated backend='gloo' rank=0 world_size=2
Initiated backend='gloo' rank=1 world_size=2
Finished epoch=0
Finished epoch=1
Finished epoch=2
Finished epoch=3


In [21]:
# Show which PyTorch build the kernel actually imported.
installed_version = torch.__version__
print(installed_version)

2.7.0+cu126
