<a href="https://colab.research.google.com/github/clam-sdx/tensor_parallel/blob/main/Tensor_Parallel_Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch

- Among data parallelism, pipeline parallelism, and tensor parallelism, which ones help if your model parameters cannot fit on a single GPU?

- What is `torch.distributed.barrier()` for?

- When training LLMs, about how many gigabytes of VRAM would you expect to save if you trained a 7B model in 16-bit vs. 32-bit?


### Problem 2: Tensor Parallelism
Combining Column-Wise and Row-Wise Tensor Parallelism

*Suggested Time: 10 minutes*

#### Background

Two practical ways to split a linear transformation (linear layer) so that the work is distributed across GPUs are column-wise and row-wise tensor parallelism.

In column-wise parallel, W is split horizontally into column blocks (dim=1), X is duplicated across GPUs, and the outputs of the separate matmuls are concatenated (all_gather) along that same dimension (dim=1).

<img src="https://raw.githubusercontent.com/clam-sdx/tensor_parallel/refs/heads/main/notebooks/pics/column_wise_parallel.png" height=250 width=500>
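
As a compact restatement of the column-wise split, here is the block-matrix identity it relies on. The symbols $W_0$ and $W_1$ are the two column blocks of $W$ for the 2-GPU case, matching the names used in the code in this notebook:

$$ X W = X \begin{bmatrix} W_0 & W_1 \end{bmatrix} = \begin{bmatrix} X W_0 & X W_1 \end{bmatrix} $$

Each GPU computes one block $X W_i$ independently, and concatenating the blocks along dim=1 recovers $X W$. With a 2x4 `X` and a 4x4 `W`, each $W_i$ is 4x2 and each $X W_i$ is 2x2.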

In row-wise parallel, X is split horizontally (dim=1) and W is split vertically (dim=0). The i-th shards of X and W go to the same GPU. There is no concatenation step; instead, the partial outputs are summed elementwise (all_reduce) at the end.

Unlike column-wise linear, row-wise linear does not let you apply the non-linearity before the synchronization step (all-reduce). Because the activation function is non-linear, `f(a) + f(b) != f(a + b)`, whereas an elementwise function does commute with concatenation: `f(concat(a, b)) == concat(f(a), f(b))`.
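
The row-wise split can also be written as a block-matrix identity. Here $X_0, X_1$ are the column blocks of $X$ and $W_0, W_1$ are the row blocks of $W$ for the 2-GPU case, matching the names used in the code in this notebook:

$$ X W = \begin{bmatrix} X_0 & X_1 \end{bmatrix} \begin{bmatrix} W_0 \\ W_1 \end{bmatrix} = X_0 W_0 + X_1 W_1 $$

Each GPU holds one pair $(X_i, W_i)$ and produces a partial result with the full output shape, which is why the partial results are summed rather than concatenated.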

<img src="https://raw.githubusercontent.com/clam-sdx/tensor_parallel/refs/heads/main/notebooks/pics/row_wise_parallel.png" height=250 width=500>
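
To make the point about the activation function concrete, here is a minimal sketch that checks both properties numerically. The tensors `a` and `b` are made-up stand-ins for two partial results and are not part of the exercise:

```python
import torch
from torch.nn.functional import gelu

# Hypothetical partial results, standing in for the outputs of two GPUs.
a = torch.tensor([[1.0, -2.0]])
b = torch.tensor([[-1.0, 2.0]])

# gelu is elementwise, so applying it before or after concatenation agrees.
concat_commutes = torch.allclose(
    gelu(torch.cat([a, b], dim=1)),
    torch.cat([gelu(a), gelu(b)], dim=1),
)

# gelu is non-linear, so applying it before or after a sum does not agree.
sum_commutes = torch.allclose(gelu(a + b), gelu(a) + gelu(b))

print(f"gelu commutes with concat: {concat_commutes}")
print(f"gelu commutes with sum:    {sum_commutes}")
```

This is why the column-wise version can apply `gelu` to each shard before the all-gather, while the row-wise version has to wait until after the all-reduce.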

For better or worse, we are in a Jupyter notebook without access to real GPUs. So let's agree to represent which GPU we are using in our implementation by giving intermediate variables a `_GPU{index}` suffix, like this:

In [2]:
import torch; torch.manual_seed(42)
from torch.nn.functional import gelu # f()

# Non-Distributed Reference

X = torch.randn(2, 4, device="cpu", dtype=torch.float32)
W = torch.randn(4, 4, device="cpu", dtype=torch.float32)

y = gelu(X @ W)

print(f"Non-Distributed Reference \n{y}\n")

# Column-Wise Distributed Linear

# column-wise-split
W_0, W_1 = W.chunk(2, dim=1) # W is split horizontally (dim=1)

# distributed-matmul
y_GPU0 = X @ W_0
y_GPU1 = X @ W_1

# non-linear activation function gelu()
y_GPU0 = gelu(y_GPU0)
y_GPU1 = gelu(y_GPU1)

# all-gather
y_col = torch.cat([y_GPU0, y_GPU1], dim=1)

print(f"Distributed Column-wise Linear \n{y_col}\n")

try:
    torch.testing.assert_close(y_col, y, rtol=1e-5, atol=1e-5)
    col_match = True
    print("Column linear match\n")
except AssertionError:
    col_match = False
    print("Column linear mismatch\n")

# Row-Wise Distributed Linear

# row-wise-split
X_0, X_1 = X.chunk(2, dim=1) # X is split horizontally (dim=1)
W_0, W_1 = W.chunk(2, dim=0) # W is split vertically (dim=0)

# distributed-matmul

y_GPU0 = X_0 @ W_0
y_GPU1 = X_1 @ W_1

# all-reduce
y_GPU0 = y_GPU0 + y_GPU1

# gelu
y_row = gelu(y_GPU0)

print(f"Distributed Row-wise Linear \n{y_row}\n")

try:
    torch.testing.assert_close(y_row, y, rtol=1e-5, atol=1e-5)
    row_match = True
    print("Row linear match\n")
except AssertionError:
    row_match = False
    print("Row linear mismatch\n")

Non-Distributed Reference 
tensor([[0.5324, 0.2612, 0.0106, 0.7086],
        [1.1399, 2.8075, 1.8650, 1.5034]])

Distributed Column-wise Linear 
tensor([[0.5324, 0.2612, 0.0106, 0.7086],
        [1.1399, 2.8075, 1.8650, 1.5034]])

Column linear match

Distributed Row-wise Linear 
tensor([[0.5324, 0.2612, 0.0106, 0.7086],
        [1.1399, 2.8075, 1.8650, 1.5034]])

Row linear match
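
The two patterns above are written out by hand for 2 GPUs. As a rough sketch of how they might generalize to more shards, the helpers below simulate both on the CPU. The function names `column_parallel_linear` and `row_parallel_linear`, the `num_gpus` parameter, and the 8x8 shapes are assumptions made for illustration, and the sketch assumes the split dimension is evenly divisible by `num_gpus`:

```python
import torch
from torch.nn.functional import gelu

def column_parallel_linear(X, W, num_gpus):
    """Simulate gelu(X @ W) with W split into column blocks (hypothetical helper)."""
    W_shards = W.chunk(num_gpus, dim=1)             # one column block per simulated GPU
    partials = [gelu(X @ W_i) for W_i in W_shards]  # gelu is applied before the gather
    return torch.cat(partials, dim=1)               # stands in for all-gather

def row_parallel_linear(X, W, num_gpus):
    """Simulate gelu(X @ W) with X split by columns and W by rows (hypothetical helper)."""
    X_shards = X.chunk(num_gpus, dim=1)
    W_shards = W.chunk(num_gpus, dim=0)
    partials = [X_i @ W_i for X_i, W_i in zip(X_shards, W_shards)]
    total = torch.stack(partials).sum(dim=0)        # stands in for all-reduce (sum)
    return gelu(total)                              # gelu only after the reduction

torch.manual_seed(42)
X = torch.randn(2, 8)
W = torch.randn(8, 8)

y_ref = gelu(X @ W)
y_col = column_parallel_linear(X, W, num_gpus=4)
y_row = row_parallel_linear(X, W, num_gpus=4)

print("column-wise matches:", torch.allclose(y_col, y_ref, rtol=1e-4, atol=1e-4))
print("row-wise matches:   ", torch.allclose(y_row, y_ref, rtol=1e-4, atol=1e-4))
```

The list comprehensions play the role of the per-GPU work, while `torch.cat` and the stacked sum stand in for the collective operations; on real hardware those would be communication steps between devices rather than local tensor operations.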



#### Goal

We want to transform a 2x4 input matrix `X` by first matrix multiplying it (`@`) with a 4x4 parameter matrix `W1`, then applying the non-linear activation function `gelu()` to `X @ W1`, and finally applying one more linear transformation by matrix multiplying with another 4x4 parameter matrix `W2` to get `y`. This sequence of operations can be represented by the expression:

$$ y = gelu(X @ W1) @ W2 $$

In a non-distributed setting, this computation can be implemented as follows:

In [3]:
import torch; torch.manual_seed(42)
from torch.nn.functional import gelu # f()

X = torch.randn(2, 4, device="cpu", dtype=torch.float32)
W1 = torch.randn(4, 4, device="cpu", dtype=torch.float32)
W2 = torch.randn(4, 4, device="cpu", dtype=torch.float32)

y_nondist = gelu(X @ W1) @ W2

print(y_nondist)

tensor([[-1.5733,  0.9400, -0.9233,  0.3600],
        [-5.5034,  2.8632, -2.8932, -2.0718]])


#### Question

Which method of distributing and synchronizing the operations of `gelu(X @ W1) @ W2` across multiple GPUs will be fastest, in the sense of minimizing data transfer between GPUs, while keeping the final result `y_dist` the same as the result `y_nondist` we would get without parallelism on a single CPU or GPU?

1. column-wise-split -> distributed-matmul -> all-gather -> gelu -> column-wise-split -> distributed-matmul -> all-gather

2. row-wise-split -> distributed-matmul -> all-reduce -> gelu -> row-wise-split -> distributed-matmul -> all-reduce

3. column-wise-split -> distributed-matmul -> gelu -> distributed-matmul -> all-reduce

4. row-wise-split -> distributed-matmul -> gelu -> distributed-matmul -> all-gather


#### Implementation

Write your solution below, assuming you have 2 GPUs, `GPU0` and `GPU1`, at your disposal. Use intermediate variable names with a `_GPU{index}` suffix, just like in the examples above, to indicate which operations are occurring in parallel.

In [4]:
## Write your solution here ##

# Check if your y_dist matches the y_nondist

# print(f"Distributed Column-Row Linear \n{y_dist}\n")

# try:
#     torch.testing.assert_close(y_dist, y_nondist, rtol=1e-5, atol=1e-5)
#     col_row_match = True
#     print("Column Row linear match\n")
# except AssertionError:
#     col_row_match = False
#     print("Column Row linear mismatch\n")