There are two ways a linear layer can be split across GPUs: row-wise or column-wise.

In [5]:
import torch; torch.manual_seed(42)
from torch.nn.functional import gelu

<img src="pics/matmul.png" height=200 width=400>

In [6]:
X = torch.randn(2, 4, device="cpu", dtype=torch.float32)
W = torch.randn(4, 4, device="cpu", dtype=torch.float32)

y = gelu(X @ W)

print("Baseline")
y

Baseline


tensor([[0.5324, 0.2612, 0.0106, 0.7086],
        [1.1399, 2.8075, 1.8650, 1.5034]])

<img src="pics/column_wise_parallel.png" height=300 width=600>

W is split horizontally (dim=1), and the outputs of the separate matmuls are concatenated (all_gather) along that same dimension (dim=1). Because GELU is applied element-wise, it can be applied either before or after the concatenation.

In [7]:
# Column linear

W_0, W_1 = W.chunk(2, dim=1) # W is split horizontally (dim=1)

y_col_1 = torch.cat([gelu(X @ W_0), gelu(X @ W_1)], dim=1) # concatenate after the nonlinearity

y_col_2 = gelu(torch.cat([X @ W_0, X @ W_1], dim=1)) # concatenate before the nonlinearity

try:
    torch.testing.assert_close(y_col_1, y_col_2, rtol=1e-5, atol=1e-5)
    col_match = True
    print("Column linear match")
except AssertionError:
    col_match = False
    print("Column linear mismatch")

y_col_2

Column linear match


tensor([[0.5324, 0.2612, 0.0106, 0.7086],
        [1.1399, 2.8075, 1.8650, 1.5034]])
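The two-way split above generalizes to any number of shards. The following is a minimal sketch that simulates this on a single CPU, with a Python list standing in for the devices; the helper `column_parallel_linear`, its `n_shards` parameter, and the `X_demo`/`W_demo` tensors are hypothetical names introduced only for illustration.

```python
import torch
from torch.nn.functional import gelu


def column_parallel_linear(X, W, n_shards):
    """Simulate a column-wise split of W across n_shards devices (hypothetical helper)."""
    # Each shard holds a slice of W's columns (dim=1) and sees the full X.
    W_shards = W.chunk(n_shards, dim=1)
    # Each shard applies the activation to its own slice of the output.
    local_outputs = [gelu(X @ W_i) for W_i in W_shards]
    # Concatenating along dim=1 plays the role of all_gather.
    return torch.cat(local_outputs, dim=1)


torch.manual_seed(0)
X_demo = torch.randn(2, 4)
W_demo = torch.randn(4, 8)

reference = gelu(X_demo @ W_demo)
for n_shards in (1, 2, 4, 8):
    sharded = column_parallel_linear(X_demo, W_demo, n_shards)
    torch.testing.assert_close(sharded, reference, rtol=1e-5, atol=1e-5)
```

Every output column depends on exactly one column of W, so each shard can finish its slice, activation included, without talking to the others.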

<img src="pics/row_wise_parallel.png" height=350 width=700>

X is split horizontally (dim=1) and W is split vertically (dim=0). There is no concatenation step; instead, the partial outputs are summed element-wise (all_reduce) at the end.

In [8]:
# Row linear

X_0, X_1 = X.chunk(2, dim=1) # X is split horizontally (dim=1)
W_0, W_1 = W.chunk(2, dim=0) # W is split vertically (dim=0)

y_row_1 = gelu(X_0 @ W_0) + gelu(X_1 @ W_1) # element-wise addition after the nonlinearity
y_row_2 = gelu(X_0 @ W_0 + X_1 @ W_1) # element-wise addition before the nonlinearity

try:
    torch.testing.assert_close(y_row_1, y_row_2, rtol=1e-5, atol=1e-5)
    row_match = True
    print("Row linear match")
except AssertionError:
    row_match = False
    print("Row linear mismatch")

print(y_row_1, "\n\n", y_row_2)

Row linear mismatch
tensor([[0.4549, 0.2409, 0.0150, 0.5894],
        [2.8925, 2.8190, 1.6791, 2.2097]]) 

 tensor([[0.5324, 0.2612, 0.0106, 0.7086],
        [1.1399, 2.8075, 1.8650, 1.5034]])


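Because the activation function is non-linear, `f(a) + f(b) != f(a+b)` in general, so in the row-wise split the element-wise addition (all_reduce) has to happen before the nonlinearity.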
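The mismatch can be checked in isolation with two scalars, and the row-wise split can then be written so that the sum happens before the activation. This is a minimal single-CPU sketch; the helper `row_parallel_linear`, its `n_shards` parameter, and the `X_demo`/`W_demo` tensors are hypothetical names introduced only for illustration.

```python
import torch
from torch.nn.functional import gelu

# Scalar counterexample: gelu(0) is exactly 0, but gelu(1) + gelu(-1) is not.
a = torch.tensor(1.0)
b = torch.tensor(-1.0)
sum_of_activations = gelu(a) + gelu(b)
activation_of_sum = gelu(a + b)
assert not torch.isclose(sum_of_activations, activation_of_sum)


def row_parallel_linear(X, W, n_shards):
    """Simulate a row-wise split of W across n_shards devices (hypothetical helper)."""
    # X is split along its columns (dim=1), W along its rows (dim=0).
    X_shards = X.chunk(n_shards, dim=1)
    W_shards = W.chunk(n_shards, dim=0)
    # Each shard computes a partial product with the full output shape.
    partial_outputs = [X_i @ W_i for X_i, W_i in zip(X_shards, W_shards)]
    # Summing the partial products plays the role of all_reduce.
    summed = torch.stack(partial_outputs).sum(dim=0)
    # The activation is applied only after the sum.
    return gelu(summed)


torch.manual_seed(0)
X_demo = torch.randn(2, 8)
W_demo = torch.randn(8, 4)

reference = gelu(X_demo @ W_demo)
for n_shards in (1, 2, 4, 8):
    sharded = row_parallel_linear(X_demo, W_demo, n_shards)
    torch.testing.assert_close(sharded, reference, rtol=1e-5, atol=1e-5)
```

Unlike the column-wise case, each shard here holds only a partial sum of every output element, so no shard can apply the activation on its own.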