In [1]:
from jupytertracerviz import init_multigpus_repl, multigpus
# Start the multi-GPU REPL backend. Per the output below, this brings up one
# worker per GPU (3 here, rendezvousing on localhost:12355); subsequent
# `%%multigpus` cells print per-worker output prefixed with "[GPU n]".
init_multigpus_repl()

Worker 0/3 initialized on GPU 0 on localhost:12355
Worker 1/3 initialized on GPU 1 on localhost:12355
Worker 2/3 initialized on GPU 2 on localhost:12355


In [2]:
%%multigpus

# Sanity check: each worker builds a small Linear layer on its own GPU and
# runs one forward pass.
# NOTE(review): `torch` and `rank` are not defined in this cell; presumably the
# %%multigpus worker injects them — TODO confirm.
model = torch.nn.Linear(10, 10).cuda(rank)

# NOTE(review): every rank printed the same sum, which suggests all workers
# start from the same RNG seed. Keep the statement order as-is: the Linear
# weight init above and this randn both draw random numbers.
x = torch.randn(10, 10).cuda(rank)
output = model(x)

print(f"Rank {rank}: Output sum {output.sum().item()}")

[GPU 0] Rank 0: Output sum 1.8705439567565918
[GPU 1] Rank 1: Output sum 1.8705439567565918
[GPU 2] Rank 2: Output sum 1.8705439567565918


In [4]:
%%multigpus

from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
from torch.distributed.device_mesh import init_device_mesh

class FeedForward(nn.Module):
    """Minimal feed-forward block: a single bias-free projection ``dim -> hidden_dim``.

    Its only ``nn.Linear`` submodule is ``w1``, which makes it a small target
    for experimenting with tensor-parallel sharding plans.

    Args:
        dim: size of the input's last dimension.
        hidden_dim: size of the output's last dimension.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # nn.Linear stores its weight as (out_features, in_features),
        # i.e. (hidden_dim, dim) here; no bias term.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        """Project ``x`` through ``w1``.

        Args:
            x: tensor whose last dimension is ``dim``.

        Returns:
            ``self.w1(x)``, whose last dimension is ``hidden_dim``.
        """
        # Return the projection; computing it into a local without returning
        # would make the module's output None.
        return self.w1(x)

model = FeedForward(8192, 8192)

# One-dimensional device mesh spanning the three GPU workers.
device_mesh = init_device_mesh('cuda', (3,))

# Tensor-parallel plan: every nn.Linear submodule (keyed by its qualified
# name) is sharded column-wise. use_local_output=False keeps each sharded
# module's output as a DTensor rather than converting it to a local tensor.
# NOTE(review): 8192 is not divisible by 3, so the weight shards cannot all be
# the same size; confirm the uneven split is intended.
linears = {
    name: ColwiseParallel(use_local_output=False)
    for name, submodule in model.named_modules()
    if isinstance(submodule, nn.Linear)
}

# Swap in the parallelized module; afterwards each rank's w1.weight is a
# DTensor with placement Shard(dim=0) (see the printed weight in a later cell).
model = parallelize_module(model, device_mesh, linears)

In [7]:
%%multigpus

# Each rank prints its own view of the parallelized weight: a DTensor whose
# local_tensor is that rank's row shard (placements=(Shard(dim=0),) over the
# 3-device mesh, as the output below shows).
print(model.w1.weight)

[GPU 0] DTensor(local_tensor=tensor([[ 0.0093,  0.0076,  0.0007,  ...,  0.0004, -0.0050, -0.0019],
        [-0.0073, -0.0070, -0.0062,  ...,  0.0004,  0.0077,  0.0005],
        [-0.0094,  0.0043,  0.0092,  ..., -0.0058, -0.0054,  0.0059],
        ...,
        [-0.0102,  0.0093,  0.0051,  ..., -0.0065, -0.0071,  0.0065],
        [ 0.0029,  0.0036,  0.0107,  ...,  0.0089, -0.0007, -0.0023],
        [ 0.0064, -0.0008,  0.0090,  ..., -0.0075,  0.0102, -0.0078]],
       device='cuda:0'), device_mesh=DeviceMesh('cuda', [0, 1, 2]), placements=(Shard(dim=0),))
[GPU 1] DTensor(local_tensor=tensor([[ 0.0073,  0.0030, -0.0023,  ..., -0.0001, -0.0056, -0.0017],
        [ 0.0072, -0.0014, -0.0031,  ..., -0.0042,  0.0065, -0.0058],
        [ 0.0009, -0.0069, -0.0058,  ...,  0.0058, -0.0091,  0.0048],
        ...,
        [-0.0108,  0.0031, -0.0069,  ..., -0.0067,  0.0077,  0.0108],
        [ 0.0034,  0.0045,  0.0104,  ..., -0.0045,  0.0036,  0.0046],
        [ 0.0105,  0.0030, -0.0090,  ...,  0.0098