In [1]:
from jupytertracerviz import init_multigpus_repl, multigpus
# Must run before any `%%multigpus` cell below.
# NOTE(review): presumably this starts the per-GPU worker processes and
# registers the `%%multigpus` cell magic — confirm against jupytertracerviz.
init_multigpus_repl()

In [2]:
%%multigpus

# NOTE(review): `torch` and `rank` are not defined in this cell; presumably the
# %%multigpus environment provides them on every worker — confirm.

# Create the layer first, then move it onto this worker's GPU.
model = torch.nn.Linear(10, 10)
model = model.cuda(rank)

# The input batch is moved to the same device before the forward pass.
x = torch.randn(10, 10)
x = x.cuda(rank)
output = model(x)

# One line per worker, tagged with its rank.
print(f"Rank {rank}: Output sum {output.sum().item()}")

[GPU 1] Rank 1: Output sum 1.752981185913086
[GPU 2] Rank 2: Output sum 1.752981185913086
[GPU 0] Rank 0: Output sum 1.752981185913086


In [3]:
%%multigpus

from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
from torch.distributed.device_mesh import init_device_mesh

class FeedForward(nn.Module):
    """A single bias-free linear projection wrapped as a module.

    Args:
        dim: Input feature size handed to ``nn.Linear``.
        hidden_dim: Output feature size handed to ``nn.Linear``.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        """Apply ``w1`` to ``x`` and return the result.

        The projection has to be returned explicitly; assigning it to a local
        and falling off the end makes every call to the module yield ``None``.
        """
        return self.w1(x)

# Unsharded module; both the input and output width are 8192.
model = FeedForward(8192, 8192)

# One-dimensional CUDA mesh with three entries.
device_mesh = init_device_mesh('cuda', (3,))

# Parallelize plan: every nn.Linear reported by named_modules() is mapped,
# by its module name, to a ColwiseParallel style with use_local_output=False.
linears = dict(
    (module_name, ColwiseParallel(use_local_output=False))
    for module_name, module in model.named_modules()
    if isinstance(module, nn.Linear)
)

# Rebind `model` to the module returned by parallelize_module.
model = parallelize_module(model, device_mesh, linears)

In [4]:
%%multigpus

# Show the w1 weight as each worker sees it. In the captured output it is a
# DTensor whose repr carries the local shard, the device mesh and the
# placements (Shard(dim=0)).
print(model.w1.weight)

[GPU 0] DTensor(local_tensor=tensor([[-0.0091, -0.0020, -0.0044,  ...,  0.0051, -0.0003, -0.0052],
        [-0.0031, -0.0035,  0.0102,  ...,  0.0049,  0.0089, -0.0098],
        [-0.0023,  0.0038,  0.0077,  ...,  0.0042,  0.0058,  0.0008],
        ...,
        [ 0.0039, -0.0059,  0.0026,  ..., -0.0065,  0.0105,  0.0100],
        [-0.0059,  0.0055,  0.0086,  ..., -0.0009, -0.0025,  0.0101],
        [ 0.0110,  0.0024, -0.0085,  ...,  0.0073, -0.0097,  0.0062]],
       device='cuda:0'), device_mesh=DeviceMesh('cuda', [0, 1, 2]), placements=(Shard(dim=0),))
[GPU 1] DTensor(local_tensor=tensor([[ 0.0100, -0.0057, -0.0053,  ..., -0.0034, -0.0009,  0.0100],
        [-0.0042,  0.0054, -0.0007,  ..., -0.0038,  0.0090, -0.0033],
        [-0.0012, -0.0079, -0.0041,  ...,  0.0049,  0.0048,  0.0045],
        ...,
        [-0.0007, -0.0089, -0.0094,  ..., -0.0040,  0.0080, -0.0100],
        [ 0.0039,  0.0050,  0.0066,  ..., -0.0005,  0.0077,  0.0065],
        [ 0.0014, -0.0091,  0.0014,  ..., -0.0093

In [6]:
%%multigpus

# Dump the raw attribute dict of the w1 module on each worker. The captured
# output lists the parameters, buffers and hook registries, including a
# forward hook that comes from `distribute_module`.
print(model.w1.__dict__)

[GPU 0] {'training': True, '_parameters': {'weight': DTensor(local_tensor=tensor([[-0.0091, -0.0020, -0.0044,  ...,  0.0051, -0.0003, -0.0052],
        [-0.0031, -0.0035,  0.0102,  ...,  0.0049,  0.0089, -0.0098],
        [-0.0023,  0.0038,  0.0077,  ...,  0.0042,  0.0058,  0.0008],
        ...,
        [ 0.0039, -0.0059,  0.0026,  ..., -0.0065,  0.0105,  0.0100],
        [-0.0059,  0.0055,  0.0086,  ..., -0.0009, -0.0025,  0.0101],
        [ 0.0110,  0.0024, -0.0085,  ...,  0.0073, -0.0097,  0.0062]],
       device='cuda:0'), device_mesh=DeviceMesh('cuda', [0, 1, 2]), placements=(Shard(dim=0),)), 'bias': None}, '_buffers': {}, '_non_persistent_buffers_set': set(), '_backward_pre_hooks': OrderedDict(), '_backward_hooks': OrderedDict(), '_is_full_backward_hook': None, '_forward_hooks': OrderedDict([(1, <function distribute_module.<locals>.<lambda> at 0x7f577d652ef0>)]), '_forward_hooks_with_kwargs': OrderedDict(), '_forward_hooks_always_called': OrderedDict(), '_forward_pre_hooks': Orde