<a href="https://colab.research.google.com/github/ljppro/how-to-read-pytorch/blob/master/DTensor_Examples.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install PyTorch nightly to try out DTensor

In [None]:
!pip install expecttest hypothesis

Collecting expecttest
  Downloading expecttest-0.2.1-py3-none-any.whl (7.4 kB)
Collecting hypothesis
  Downloading hypothesis-6.100.0-py3-none-any.whl (458 kB)
Installing collected packages: hypothesis, expecttest
Successfully installed expecttest-0.2.1 hypothesis-6.100.0


In [None]:
import warnings

# Define a function that raises a warning
def example_function():
    warnings.warn("This is a warning message")

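# Use a context manager to catch and record warnings
with warnings.catch_warnings(record=True) as warning_list:
    # Make sure every warning is recorded, regardless of the active filters
    warnings.simplefilter("always")
    example_function()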

# Print the captured warnings
for warning in warning_list:
    print(f"Captured Warning: {warning.message}")



In [None]:
!pip install torch

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)

## Import some testing utils to run DTensor in a notebook

Setting up multiprocessing in a notebook is challenging (it is easy in a `*.py` file, but a notebook needs some hacks), so we developed some testing utils that spawn multiple threads and "mimic" the ProcessGroup communicator.

In [None]:
from torch.testing._internal.common_distributed import spawn_threads_and_init_comms
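
To make the "one thread per rank" idea concrete, here is a minimal pure-Python sketch. It is not how `spawn_threads_and_init_comms` is implemented; `run_with_threads` and `describe` are hypothetical names used only for illustration, and no real communicator is set up.

```python
import threading

def run_with_threads(world_size, fn):
    """Run fn(rank, world_size) once per rank, each in its own thread."""
    results = [None] * world_size

    def worker(rank):
        # Each thread plays the role of one rank and stores its own result.
        results[rank] = fn(rank, world_size)

    threads = [threading.Thread(target=worker, args=(rank,)) for rank in range(world_size)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return results

def describe(rank, world_size):
    return f"rank {rank} of {world_size}"

outputs = run_with_threads(4, describe)
print(outputs)
```

Because the threads run concurrently, anything they print can interleave, which is why the cell outputs later in this notebook are not ordered by rank.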

In [None]:
WORLD_SIZE=4

In [None]:
import torch

In [None]:
torch.tensor([[1,2,3],[4,5,6]]).sum(dim=0)

tensor([5, 7, 9])

# DTensor examples

DTensor is a prototype release in PyTorch 2.0. Let's try out some examples in this notebook to play around with it.

First, we need some imports for DTensor.

In [None]:
# some necessary imports
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, DeviceMesh, Shard, Replicate, distribute_tensor

How can we shard a big tensor across ranks?

In [None]:
@spawn_threads_and_init_comms
def shard_big_tensor(world_size):
  mesh = DeviceMesh("cpu", [0, 1, 2, 3])
  big_tensor = torch.randn((653, 10), device="meta")
  dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
  print(f"on rank: {dist.get_rank()}, dtensor global shape: {dtensor.shape}, local shape: {dtensor.to_local().shape}\n")
  print(f"   global device: {dtensor.device}, local device: {dtensor.to_local().device}\n")

shard_big_tensor(WORLD_SIZE)

on rank: 0, dtensor global shape: torch.Size([653, 10]), local shape: torch.Size([164, 10])

   global device: meta, local device: meta

on rank: 2, dtensor global shape: torch.Size([653, 10]), local shape: torch.Size([164, 10])

on rank: 1, dtensor global shape: torch.Size([653, 10]), local shape: torch.Size([164, 10])

on rank: 3, dtensor global shape: torch.Size([653, 10]), local shape: torch.Size([161, 10])

   global device: meta, local device: meta
   global device: meta, local device: meta


   global device: meta, local device: meta
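
The uneven local shapes above (164, 164, 164, 161) are consistent with `torch.chunk`-style splitting, where each shard gets ceil(653 / 4) rows until the rows run out. The pure-Python sketch below reproduces those sizes under that assumption; `chunk_sizes` is a hypothetical helper written for illustration, not a DTensor API.

```python
import math

def chunk_sizes(dim_size, num_shards):
    """Per-shard sizes, assuming ceil-division splitting like torch.chunk."""
    full = math.ceil(dim_size / num_shards)
    sizes = []
    remaining = dim_size
    for _ in range(num_shards):
        # Every shard takes a full chunk until fewer rows than that remain.
        size = min(full, remaining)
        sizes.append(size)
        remaining -= size
    return sizes

sizes = chunk_sizes(653, 4)
print(sizes)
```

With 888 rows, which divide evenly by 4, the same helper gives every shard 222 rows.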



What if we want to replicate a big tensor across ranks?

In [None]:
@spawn_threads_and_init_comms
def replicate_big_tensor(world_size):
  mesh = DeviceMesh("cpu", [0, 1, 2, 3])
  big_tensor = torch.randn((888, 10))
  dtensor = distribute_tensor(big_tensor, mesh, [Replicate()])
  print(f"on rank: {dist.get_rank()}, dtensor global shape: {dtensor.shape}, local shape: {dtensor.to_local().shape}\n")

replicate_big_tensor(WORLD_SIZE)

on rank: 1, dtensor global shape: torch.Size([888, 10]), local shape: torch.Size([888, 10])

on rank: 2, dtensor global shape: torch.Size([888, 10]), local shape: torch.Size([888, 10])

on rank: 0, dtensor global shape: torch.Size([888, 10]), local shape: torch.Size([888, 10])

on rank: 3, dtensor global shape: torch.Size([888, 10]), local shape: torch.Size([888, 10])



What if we want a more complex sharding placement, say sharding this big tensor across one subset of devices and replicating the shards across another subset of devices?

In [None]:
@spawn_threads_and_init_comms
def partially_shard_tensor(world_size):
  # if we want to distribute a tensor with both replication and sharding,
  # create a 2-d mesh
  device_mesh = DeviceMesh("cpu", torch.arange(world_size).reshape(2, 2))
  print(str(device_mesh) + "\n")

  big_tensor = torch.randn((888, 10))
  # replicate across the first dimension of the device mesh, then shard (on tensor dim 1) across the second dimension of the device mesh
  spec=[Replicate(), Shard(1)]
  partial_shard = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=spec)
  print(f"on rank: {dist.get_rank()} === {partial_shard.sum()}\n")
  print(f"on rank: {dist.get_rank()}, dtensor global shape: {partial_shard.shape}, local shape: {partial_shard.to_local().shape}\n")


partially_shard_tensor(WORLD_SIZE)

DeviceMesh([[0, 1], [2, 3]])
DeviceMesh([[0, 1], [2, 3]])


DeviceMesh([[0, 1], [2, 3]])

DeviceMesh([[0, 1], [2, 3]])

on rank: 3 === DTensor(local_tensor=-68.71055603027344, device_mesh=DeviceMesh([[0, 1], [2, 3]]), placements=(Replicate(), _Partial(reduce_op=RedOpType.SUM)))

on rank: 3, dtensor global shape: torch.Size([888, 10]), local shape: torch.Size([888, 5])

on rank: 1 === DTensor(local_tensor=-68.71055603027344, device_mesh=DeviceMesh([[0, 1], [2, 3]]), placements=(Replicate(), _Partial(reduce_op=RedOpType.SUM)))

on rank: 0 === DTensor(local_tensor=-19.24100112915039, device_mesh=DeviceMesh([[0, 1], [2, 3]]), placements=(Replicate(), _Partial(reduce_op=RedOpType.SUM)))

on rank: 0, dtensor global shape: torch.Size([888, 10]), local shape: torch.Size([888, 5])

on rank: 1, dtensor global shape: torch.Size([888, 10]), local shape: torch.Size([888, 5])
on rank: 2 === DTensor(local_tensor=-19.24100112915039, device_mesh=DeviceMesh([[0, 1], [2, 3]]), placements=(Replicate(), _P
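
The output above shows ranks 0 and 2 holding the same partial sum, and likewise ranks 1 and 3. That matches a layout in which the mesh rows replicate the data and the mesh columns split tensor dim 1. The sketch below works out which columns each rank would hold under that reading; `local_column_range` is a hypothetical helper, and it assumes the column count divides evenly.

```python
def local_column_range(rank, mesh, num_cols):
    """Return the (start, stop) columns held by `rank`.

    Assumes placements [Replicate(), Shard(1)]: mesh dim 0 replicates,
    and mesh dim 1 splits tensor dim 1 into equal pieces.
    """
    for row in mesh:
        if rank in row:
            position = row.index(rank)  # coordinate along mesh dim 1
            shard_width = num_cols // len(row)
            start = position * shard_width
            return (start, start + shard_width)
    raise ValueError(f"rank {rank} is not in the mesh")

mesh = [[0, 1], [2, 3]]
layout = {rank: local_column_range(rank, mesh, num_cols=10) for rank in range(4)}
print(layout)
```

Ranks in the same mesh column (0 and 2, or 1 and 3) end up with identical column ranges, which is the replication half of the placement.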

How does DTensor interact with torch.Tensor?

We offer two APIs to convert from/to torch.Tensor:
- `from_local`, which converts a torch.Tensor on each rank to a DTensor in SPMD fashion.
- `to_local`, which converts the DTensor to a torch.Tensor on each rank in SPMD fashion.

Note that both `from_local` and `to_local` are differentiable.

In [None]:
@spawn_threads_and_init_comms
def dtensor_from_local_to_local(world_size):
  mesh = DeviceMesh("cpu", torch.arange(world_size))
  # create a DistributedTensor that shards on dim 0, from a local torch.Tensor
  local_tensor = torch.randn((8, 8), requires_grad=True)
  rowwise_placement = [Shard(0)]
  rowwise_tensor = DTensor.from_local(local_tensor, mesh, rowwise_placement)
  print(f"on rank: {dist.get_rank()}, dtensor global shape: {rowwise_tensor.shape}, local shape: {rowwise_tensor.to_local().shape}")

dtensor_from_local_to_local(WORLD_SIZE)

on rank: 0, dtensor global shape: torch.Size([32, 8]), local shape: torch.Size([8, 8])
on rank: 1, dtensor global shape: torch.Size([32, 8]), local shape: torch.Size([8, 8])
on rank: 3, dtensor global shape: torch.Size([32, 8]), local shape: torch.Size([8, 8])
on rank: 2, dtensor global shape: torch.Size([32, 8]), local shape: torch.Size([8, 8])
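
In the output above, each rank contributes an (8, 8) local tensor and the resulting DTensor reports a global shape of (32, 8): with `Shard(0)` over 4 ranks, the local tensors are treated as the row shards of one larger tensor. A small sketch of that arithmetic, assuming every rank passes a local tensor of the same shape (`global_shape_from_local` is a hypothetical helper, not part of the DTensor API):

```python
def global_shape_from_local(local_shape, shard_dim, num_ranks):
    """Global shape implied by equal-sized local shards along `shard_dim`."""
    shape = list(local_shape)
    # The sharded dimension is the concatenation of all ranks' local pieces.
    shape[shard_dim] *= num_ranks
    return tuple(shape)

rowwise_global = global_shape_from_local((8, 8), shard_dim=0, num_ranks=4)
colwise_global = global_shape_from_local((8, 8), shard_dim=1, num_ranks=4)
print(rowwise_global, colwise_global)
```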


What if we want to change the layout of the DTensor (e.g. convert a row-wise sharded DTensor to a col-wise sharded one, or convert it back to a replicated DTensor)? DTensor offers a `redistribute` API to do the transformation automatically:

- `dtensor.redistribute(mesh: DeviceMesh, placements: Sequence[Placement])`

Let's see an example:

In [None]:
@spawn_threads_and_init_comms
def dtensor_reshard(world_size):
  mesh = DeviceMesh("cpu", torch.arange(world_size))
  rowwise_placement = [Shard(0)]
  colwise_placement = [Shard(1)]
  # create a rowwise tensor
  local_tensor = torch.randn(8, 8)
  rowwise_tensor = DTensor.from_local(local_tensor, mesh, rowwise_placement)
  # reshard the current row-wise tensor to a colwise tensor or replicate tensor
  replica_placement = [Replicate()]
  colwise_tensor = rowwise_tensor.redistribute(mesh, colwise_placement)
  print(f"on rank: {dist.get_rank()}, col-wise dtensor global shape: {colwise_tensor.shape}, local shape: {colwise_tensor.to_local().shape}")
  replica_tensor = colwise_tensor.redistribute(mesh, replica_placement)
  print(f"on rank: {dist.get_rank()}, replicate dtensor global shape: {replica_tensor.shape}, local shape: {replica_tensor.to_local().shape}")

dtensor_reshard(WORLD_SIZE)

on rank: 0, col-wise dtensor global shape: torch.Size([32, 8]), local shape: torch.Size([32, 2])
on rank: 1, col-wise dtensor global shape: torch.Size([32, 8]), local shape: torch.Size([32, 2])
on rank: 3, col-wise dtensor global shape: torch.Size([32, 8]), local shape: torch.Size([32, 2])
on rank: 2, col-wise dtensor global shape: torch.Size([32, 8]), local shape: torch.Size([32, 2])
on rank: 0, replicate dtensor global shape: torch.Size([32, 8]), local shape: torch.Size([32, 8])
on rank: 1, replicate dtensor global shape: torch.Size([32, 8]), local shape: torch.Size([32, 8])
on rank: 2, replicate dtensor global shape: torch.Size([32, 8]), local shape: torch.Size([32, 8])
on rank: 3, replicate dtensor global shape: torch.Size([32, 8]), local shape: torch.Size([32, 8])
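
The shapes printed by this example can be reproduced with plain nested lists. The sketch below only models the resulting data layout: it gathers everything and re-slices it, whereas the real `redistribute` uses collective communication. All helper names here are hypothetical.

```python
def shard_rows(matrix, num_ranks):
    """Split a matrix into equal row blocks, one per rank."""
    rows_per_rank = len(matrix) // num_ranks
    return [matrix[r * rows_per_rank:(r + 1) * rows_per_rank] for r in range(num_ranks)]

def shard_cols(matrix, num_ranks):
    """Split a matrix into equal column blocks, one per rank."""
    cols_per_rank = len(matrix[0]) // num_ranks
    return [
        [row[r * cols_per_rank:(r + 1) * cols_per_rank] for row in matrix]
        for r in range(num_ranks)
    ]

def gather_rows(shards):
    """Concatenate row blocks back into the full (replicated) matrix."""
    return [row for shard in shards for row in shard]

def shape(matrix):
    return (len(matrix), len(matrix[0]))

num_ranks = 4
global_matrix = [[r * 8 + c for c in range(8)] for r in range(32)]

row_shards = shard_rows(global_matrix, num_ranks)  # row-wise: 8 x 8 per rank
replicated = gather_rows(row_shards)               # replicated: 32 x 8 everywhere
col_shards = shard_cols(replicated, num_ranks)     # col-wise: 32 x 2 per rank

print(shape(row_shards[0]), shape(col_shards[0]), shape(replicated))
```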


# Tensor Parallel Examples


Below we present an example of Tensor Parallel (TP). We first define a dummy model, which is essentially a two-layer multilayer perceptron (MLP).

In [None]:
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
)
from torch.distributed.tensor.parallel import (
    PairwiseParallel,
    parallelize_module,
)

ITER_TIME = 20

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(32, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


We then create an example that shows an end-to-end workflow covering forward,
backward, and optimization.

More context about the API design can be found in the [design](https://github.com/pytorch/pytorch/issues/89884). The API is built on top of the Distributed Tensor shown above.

We use two `nn.Linear` layers with an element-wise `nn.ReLU`
in between to illustrate the Megatron-LM approach, which was proposed in this [paper](https://arxiv.org/abs/1909.08053).

The basic idea is that we parallelize the first linear layer by column
and the second linear layer by row, so that we only need
one all-reduce at the end of the second linear layer.

We can speed up model training by avoiding communication between
the two layers.

To parallelize an nn module, we need to specify which parallel style we want
to use, and our `parallelize_module` API will parse and parallelize the modules
based on the given `ParallelStyle`. We use these PyTorch native Tensor
Parallelism APIs in this example to show users how to use them.
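
To see why the column-then-row split needs only one all-reduce, here is a tiny pure-Python check with integer matrices. It is a sketch under simplifying assumptions: biases are omitted, there are only two shards, and the "all-reduce" is modeled as an element-wise sum of the partial outputs. The helper names are hypothetical and unrelated to the PyTorch APIs.

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def relu(m):
    return [[max(0, v) for v in row] for row in m]

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

x = [[1, -2]]                 # input, shape 1 x 2
w1 = [[1, 2, -1, 0],
      [0, 1, 3, -2]]          # first layer weight, shape 2 x 4
w2 = [[1, 0],
      [2, 1],
      [0, -1],
      [1, 1]]                 # second layer weight, shape 4 x 2

# Unsharded reference computation.
reference = matmul(relu(matmul(x, w1)), w2)

# Each shard owns 2 columns of w1 and the matching 2 rows of w2.
partials = []
for shard in range(2):
    cols = slice(shard * 2, shard * 2 + 2)
    w1_shard = [row[cols] for row in w1]       # column-wise split
    w2_shard = w2[cols]                        # row-wise split
    hidden = relu(matmul(x, w1_shard))         # no communication needed here
    partials.append(matmul(hidden, w2_shard))  # partial output

# The single "all-reduce": sum the partial outputs.
combined = add(partials[0], partials[1])
print(reference, combined)
```

The ReLU can be applied independently on each shard because it is element-wise, so the hidden activations never need to be gathered between the two layers.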

In [None]:
def print0(msg, rank):
    if rank == 0:
        print(msg)

@spawn_threads_and_init_comms
def demo_tp(world_size):
    """
    Main body of the demo of a basic version of tensor parallel by using
    PyTorch native APIs.
    """
    rank = dist.get_rank()
    print0("Create a sharding plan based on the given world_size", rank)
    # create a sharding plan based on the given world_size.
    device_mesh = DeviceMesh(
        "cpu",
        torch.arange(world_size),
    )

    # create the model; this notebook runs on CPU, so no device move is needed
    model = ToyModel()
    print0("Parallelize the module based on the given Parallel Style", rank)
    # Parallelize the module based on the given Parallel Style.
    model = parallelize_module(model, device_mesh, PairwiseParallel())
    # Create an optimizer for the parallelized module. It is created after
    # parallelization so that it tracks the parameters the module actually uses.
    LR = 0.25
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)

    # Perform a num of iterations of forward/backward
    # and optimizations for the sharded module.
    for i in range(ITER_TIME):
        inp = torch.rand(20, 10)
        output = model(inp)
        print0(f"FWD Step: iter {i}", rank)
        output.sum().backward()
        print0(f"BWD Step: iter {i}", rank)
        optimizer.step()
        print0(f"Optimization Step: iter {i}", rank)

    print0("Training finished", rank)

demo_tp(WORLD_SIZE)

Create a sharding plan based on the given world_size
Parallelize the module based on the given Parallel Style
FWD Step: iter 0
BWD Step: iter 0
Optimization Step: iter 0
FWD Step: iter 1
BWD Step: iter 1
Optimization Step: iter 1
FWD Step: iter 2
BWD Step: iter 2
Optimization Step: iter 2
FWD Step: iter 3
BWD Step: iter 3
Optimization Step: iter 3
FWD Step: iter 4
BWD Step: iter 4
Optimization Step: iter 4
FWD Step: iter 5
BWD Step: iter 5
Optimization Step: iter 5
FWD Step: iter 6
BWD Step: iter 6
Optimization Step: iter 6
FWD Step: iter 7
BWD Step: iter 7
Optimization Step: iter 7
FWD Step: iter 8
BWD Step: iter 8
Optimization Step: iter 8
FWD Step: iter 9
BWD Step: iter 9
Optimization Step: iter 9
FWD Step: iter 10
BWD Step: iter 10
Optimization Step: iter 10
FWD Step: iter 11
BWD Step: iter 11
Optimization Step: iter 11
FWD Step: iter 12
BWD Step: iter 12
Optimization Step: iter 12
FWD Step: iter 13
BWD Step: iter 13
Optimization Step: iter 13
FWD Step: iter 14
BWD Step: iter 14
Op

# 2D parallel and beyond

For 2D parallel with FullyShardedDataParallel (FSDP): since FSDP can currently run only on GPU, we provide a [link](https://github.com/pytorch/pytorch/blob/master/test/distributed/tensor/parallel/test_2d_parallel.py) here as a reference.

In response to the community's request to combine TP with PyTorch native pipeline parallelism, a.k.a. [PiPPy](https://github.com/pytorch/tau/tree/main), we also provide a [link](https://github.com/pytorch/tau/blob/main/examples/tp%2Bpp/pippy_tp.py) to an example showing how TP works with PiPPy.

# Call to Action

Both DTensor and Tensor Parallel are in an early stage of development (prototype release along with PyTorch 2.0). Feel free to try them out now with the nightly build, or with the upcoming 2.0 release.

If you run into blockers, feel free to file a GitHub issue or open a PR to contribute!