# Tensor Parallelism: row-wise vs column-wise

There are two ways to split a linear layer across GPUs: row-wise or column-wise.
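As a rough orientation, here is a minimal sketch of which pieces each GPU would hold under the two splits of a single layer `X @ W`. The sizes and the number of simulated GPUs `p` are arbitrary assumptions, and the "GPUs" are just list entries on the CPU.

```python
import torch

batch, d_in, d_out, p = 8, 6, 4, 2  # hypothetical sizes; p = number of simulated GPUs
X = torch.randn(batch, d_in)
W = torch.randn(d_in, d_out)

# Column-wise: every GPU gets the full X and a block of W's columns (dim=1).
col_W = W.chunk(p, dim=1)
col_out = [X @ W_k for W_k in col_W]  # each GPU: a block of output columns

# Row-wise: GPU k gets a block of X's columns and the matching block of W's rows (dim=0).
row_X = X.chunk(p, dim=1)
row_W = W.chunk(p, dim=0)
row_out = [X_k @ W_k for X_k, W_k in zip(row_X, row_W)]  # each GPU: a full-size partial sum

print([tuple(t.shape) for t in col_out])  # [(8, 2), (8, 2)]
print([tuple(t.shape) for t in row_out])  # [(8, 4), (8, 4)]
```

The column-wise shards have to be concatenated to form the full output, while the row-wise shards have to be summed.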

## Question

Suppose we apply a linear transformation to an input `X` by matrix-multiplying it (`@`) with parameters `W1`, then apply the non-linear activation function `gelu()`, and finally apply a second linear transformation by matrix-multiplying with `W2`, giving

$$ y = \mathrm{gelu}(X @ W1) @ W2 $$


Which of these ways of tensor-parallelizing the two matmuls, `X @ W1` and `gelu(X @ W1) @ W2`, is fastest in terms of minimizing data transfer across GPUs, while keeping the final result `y` the same as it would be without parallelism (i.e. computed entirely on one CPU or GPU)?

- a) column-wise-parallel-matmul -> all-gather -> gelu -> column-wise-parallel-matmul -> all-gather

- b) 

```python
import torch
torch.manual_seed(42)
from torch.nn.functional import gelu  # GELU activation, written f() in the text below
```

<img src="pics/matmul.png" height=200 width=400>

```python
X = torch.randn(2, 4, device="cpu", dtype=torch.float32)
W = torch.randn(4, 4, device="cpu", dtype=torch.float32)

y = gelu(X @ W)

print("Baseline")
y
```

```
Baseline

tensor([[0.5324, 0.2612, 0.0106, 0.7086],
        [1.1399, 2.8075, 1.8650, 1.5034]])
```

<img src="pics/column_wise_parallel.png" height=300 width=600>

W is split horizontally (dim=1), and the outputs of the separate matmuls are concatenated (all-gather) along that same dimension (dim=1).
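To make the communication step concrete, here is a sketch that simulates the all-gather on the CPU. `simulated_all_gather` is a hypothetical helper, not a PyTorch API: a real all-gather (e.g. `torch.distributed.all_gather`) leaves every GPU holding all shards, which this mimics by returning one full copy per simulated GPU. `p = 4` and the tensor sizes are arbitrary choices.

```python
import torch
from torch.nn.functional import gelu

def simulated_all_gather(shards, dim):
    """Hypothetical stand-in for an all-gather: every rank receives the concatenation of all shards."""
    full = torch.cat(shards, dim=dim)
    return [full.clone() for _ in shards]  # one copy per simulated GPU

p = 4  # number of simulated GPUs (assumption)
X = torch.randn(2, 4)
W = torch.randn(4, 8)

# Each GPU k computes gelu(X @ W_k) locally on its block of columns.
local = [gelu(X @ W_k) for W_k in W.chunk(p, dim=1)]
gathered = simulated_all_gather(local, dim=1)

# After the all-gather every GPU holds the same result as the unsplit computation.
print(all(torch.allclose(g, gelu(X @ W), atol=1e-6) for g in gathered))  # True
```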

```python
# Column linear

W_0, W_1 = W.chunk(2, dim=1)  # W is split horizontally (dim=1) into two column blocks

y_col_1 = torch.cat([gelu(X @ W_0), gelu(X @ W_1)], dim=1)  # concatenate after the nonlinearity

y_col_2 = gelu(torch.cat([X @ W_0, X @ W_1], dim=1))  # concatenate before the nonlinearity

try:
    torch.testing.assert_close(y_col_1, y_col_2, rtol=1e-5, atol=1e-5)
    col_match = True
    print("Column linear match")
except AssertionError:
    col_match = False
    print("Column linear mismatch")

y_col_2
```

```
Column linear match

tensor([[0.5324, 0.2612, 0.0106, 0.7086],
        [1.1399, 2.8075, 1.8650, 1.5034]])
```

<img src="pics/row_wise_parallel.png" height=350 width=700>

X is split horizontally (dim=1) and W is split vertically (dim=0). Instead of a concatenation, the final step is an element-wise addition (all-reduce) of the partial results.
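The all-reduce can be simulated the same way. `simulated_all_reduce` is again a hypothetical helper rather than a PyTorch function: it sums the partial products from all simulated GPUs and hands every GPU a copy of the sum, which is what a sum all-reduce (e.g. `torch.distributed.all_reduce` with its default `ReduceOp.SUM`) produces. `p = 4` and the sizes are arbitrary.

```python
import torch

def simulated_all_reduce(partials):
    """Hypothetical stand-in for a sum all-reduce: every rank receives the element-wise sum."""
    total = torch.stack(partials).sum(dim=0)
    return [total.clone() for _ in partials]  # one copy per simulated GPU

p = 4  # number of simulated GPUs (assumption)
X = torch.randn(2, 8)
W = torch.randn(8, 4)

# GPU k multiplies its block of X's columns by the matching block of W's rows.
partials = [X_k @ W_k for X_k, W_k in zip(X.chunk(p, dim=1), W.chunk(p, dim=0))]
reduced = simulated_all_reduce(partials)

print(all(torch.allclose(r, X @ W, atol=1e-5) for r in reduced))  # True
```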

Unlike column-wise linear, row-wise linear does not allow the non-linearity to be applied before the synchronization step (all-reduce). Because the activation function is non-linear, `f(a) + f(b) != f(a + b)` in general, whereas concatenation commutes with any element-wise function: `f(concat(a, b)) = concat(f(a), f(b))`.
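A two-element check illustrates both halves of that statement with the real `gelu`; the input values are arbitrary.

```python
import torch
from torch.nn.functional import gelu

a = torch.tensor([1.0, -2.0])
b = torch.tensor([0.5, 3.0])

# Activate, then sum vs. sum, then activate: not equal for a non-linear f.
print(torch.allclose(gelu(a) + gelu(b), gelu(a + b)))  # False

# Concatenation only rearranges elements, so an element-wise f commutes with it.
print(torch.allclose(gelu(torch.cat([a, b])), torch.cat([gelu(a), gelu(b)])))  # True
```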

Note, however, that after the initial row-wise split (before the matmul), the input tensors (X_0, X_1) are distributed across GPUs in exactly the same way as the outputs (X @ W_0, X @ W_1) of a column-wise parallel matmul. A column-wise layer can therefore feed its output directly into a row-wise layer with no communication in between, even with an element-wise activation such as `gelu` applied in between.
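The sketch below checks that observation on the CPU: splitting the full activation `gelu(X @ W)` along `dim=1`, as a row-wise layer would split its input, produces the same blocks that the column-wise GPUs already hold after their local matmul and `gelu`. The seed and sizes are arbitrary assumptions.

```python
import torch
from torch.nn.functional import gelu

torch.manual_seed(0)  # arbitrary seed
X = torch.randn(2, 4)
W = torch.randn(4, 4)

# What the two column-wise GPUs hold after their local matmul and gelu.
held = [gelu(X @ W_k) for W_k in W.chunk(2, dim=1)]

# What a row-wise layer would want on each GPU: its input split along dim=1.
wanted = gelu(X @ W).chunk(2, dim=1)

print(all(torch.allclose(h, w, atol=1e-6) for h, w in zip(held, wanted)))  # True
```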

```python
# Row linear

X_0, X_1 = X.chunk(2, dim=1)  # X is split horizontally (dim=1)
W_0, W_1 = W.chunk(2, dim=0)  # W is split vertically (dim=0)

y_row_1 = gelu(X_0 @ W_0) + gelu(X_1 @ W_1)  # element-wise addition after the nonlinearity
y_row_2 = gelu(X_0 @ W_0 + X_1 @ W_1)  # element-wise addition before the nonlinearity

try:
    torch.testing.assert_close(y_row_1, y_row_2, rtol=1e-5, atol=1e-5)
    row_match = True
    print("Row linear match")
except AssertionError:
    row_match = False
    print("Row linear mismatch")

print(y_row_1, "\n\n", y_row_2)
```

```
Row linear mismatch
tensor([[0.4549, 0.2409, 0.0150, 0.5894],
        [2.8925, 2.8190, 1.6791, 2.2097]]) 

 tensor([[0.5324, 0.2612, 0.0106, 0.7086],
        [1.1399, 2.8075, 1.8650, 1.5034]])
```


<img src="pics/column_row_parallel.png" height=350 width=700>
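The pattern in the figure generalizes to any number of GPUs. Below is a minimal CPU sketch with `p` simulated GPUs (the list entries); `column_row_mlp`, `p`, and the layer sizes, including a hidden width of 4x the model width, are assumptions for illustration. The only communication is the final sum, which stands in for the all-reduce.

```python
import torch
from torch.nn.functional import gelu

def column_row_mlp(X, W1, W2, p):
    """Hypothetical helper: simulate y = gelu(X @ W1) @ W2 with W1 column-split and W2 row-split over p GPUs."""
    W1_shards = W1.chunk(p, dim=1)  # GPU k: a block of W1's columns
    W2_shards = W2.chunk(p, dim=0)  # GPU k: the matching block of W2's rows
    # Purely local work on each GPU: no communication between the two matmuls.
    partials = [gelu(X @ A) @ B for A, B in zip(W1_shards, W2_shards)]
    # The single synchronization step: a sum all-reduce of the partial outputs.
    return torch.stack(partials).sum(dim=0)

d_model, p = 8, 4                        # hypothetical sizes
X = torch.randn(3, d_model)
W1 = torch.randn(d_model, 4 * d_model)   # hidden width 4x d_model (assumption)
W2 = torch.randn(4 * d_model, d_model)

y_ref = gelu(X @ W1) @ W2
print(torch.allclose(column_row_mlp(X, W1, W2, p), y_ref, atol=1e-4))  # True
```

Only `W1` is split along its columns and only `W2` along its rows; reversing that order would leave partial sums before the `gelu`, which would force an all-reduce between the two matmuls.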

```python
X = torch.randn(2, 4, device="cpu", dtype=torch.float32)
W1 = torch.randn(4, 4, device="cpu", dtype=torch.float32)
W2 = torch.randn(4, 4, device="cpu", dtype=torch.float32)

y = gelu(X @ W1) @ W2

print("Baseline")
y
```

```
Baseline

tensor([[ 0.1601, -0.0207,  0.0361, -0.1454],
        [ 0.5498,  0.3071, -0.1691,  1.1577]])
```

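```python
W1_0, W1_1 = W1.chunk(2, dim=1)  # W1 is split horizontally (dim=1), as in column linear
W2_0, W2_1 = W2.chunk(2, dim=0)  # W2 is split vertically (dim=0), as in row linear

# Local work on each GPU: no communication is needed between the two matmuls
X2_0 = gelu(X @ W1_0)
X2_1 = gelu(X @ W1_1)

# The single synchronization step: element-wise sum (all-reduce) of the partial outputs
y_col_row = X2_0 @ W2_0 + X2_1 @ W2_1
torch.testing.assert_close(y_col_row, y)  # matches the unparallelized baseline
y_col_row
```

```
tensor([[ 0.1601, -0.0207,  0.0361, -0.1454],
        [ 0.5498,  0.3071, -0.1691,  1.1577]])
```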
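To connect this back to the question about data transfer, here is a back-of-the-envelope sketch. It assumes ring-based collectives, where an all-gather of an N-element tensor over p GPUs sends (p-1)/p·N elements per GPU and an all-reduce (a reduce-scatter followed by an all-gather) sends 2(p-1)/p·N. The function names are hypothetical, and the model ignores latency, compute/communication overlap, and differences between links.

```python
def ring_all_gather_elems(n: int, p: int) -> float:
    """Elements sent per GPU by a ring all-gather of an n-element tensor (assumed cost model)."""
    return (p - 1) / p * n

def ring_all_reduce_elems(n: int, p: int) -> float:
    """Elements sent per GPU by a ring all-reduce: reduce-scatter + all-gather (assumed cost model)."""
    return 2 * (p - 1) / p * n

def option_a_elems(batch: int, d_hidden: int, d_out: int, p: int) -> float:
    # a) all-gather X @ W1 (batch x d_hidden), then all-gather the output (batch x d_out)
    return ring_all_gather_elems(batch * d_hidden, p) + ring_all_gather_elems(batch * d_out, p)

def column_row_elems(batch: int, d_hidden: int, d_out: int, p: int) -> float:
    # column -> gelu -> row: nothing between the matmuls, one all-reduce of the output.
    # d_hidden is unused; it is kept only so both functions share a signature.
    return ring_all_reduce_elems(batch * d_out, p)

for d_hidden in (4, 16):  # square layer as in this notebook, then a 4x wider hidden layer
    print(d_hidden, option_a_elems(2, d_hidden, 4, 2), column_row_elems(2, d_hidden, 4, 2))
# 4 8.0 8.0
# 16 20.0 8.0
```

Under this cost model, the square 4x4 layers used here move the same number of elements either way, but column -> gelu -> row does it with one collective instead of two. With a wider hidden layer (for example the 4x expansion common in transformer MLPs), option a) also has to gather the larger hidden activation, while column -> gelu -> row only ever communicates the output.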