- 张量并行
    - 更细粒度的模型并行，细到权重矩阵（tensor）粒度
    - 数学上：矩阵分块

In [210]:
import math
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [211]:
# Fixed-seed generator so the demo inputs are identical on every run
# (Restart Kernel -> Run All reproduces the same arrays).
rng = np.random.default_rng(0)
# X is the input; A and B are the two matrices of the forward pass f(X·A)·B.
X = rng.standard_normal((100, 200))
A = rng.standard_normal((200, 300))
# X·A has shape (100, 300)
B = rng.standard_normal((300, 400))

In [212]:
def split_columnwise(A, num_splits):
    """Split ``A`` along axis 1 (columns) with ``np.split``.

    When ``num_splits`` is an integer, ``A`` is cut into that many equal
    column blocks; ``np.split`` raises ``ValueError`` if the column count is
    not evenly divisible. Returns the list of sub-arrays.
    """
    column_blocks = np.split(A, num_splits, axis=1)
    return column_blocks
 
def split_rowwise(A, num_splits):
    """Split ``A`` along axis 0 (rows) with ``np.split``.

    When ``num_splits`` is an integer, ``A`` is cut into that many equal
    row blocks; ``np.split`` raises ``ValueError`` if the row count is not
    evenly divisible. Returns the list of sub-arrays.
    """
    row_blocks = np.split(A, num_splits, axis=0)
    return row_blocks

In [213]:
def normal_forward_pass(X, A, B, f):
    """Reference (unsharded) two-layer forward pass: ``f(X·A)·B``.

    ``f`` is called on the result of the first matrix product, and its
    output is multiplied by ``B``.
    """
    hidden = np.dot(X, A)        # e.g. (100,200)·(200,300) -> (100,300)
    activated = f(hidden)
    return np.dot(activated, B)  # e.g. (100,300)·(300,400) -> (100,400)

In [214]:
def tensor_parallel_forward_pass(X, A, B, f, num_splits=2):
    """Tensor-parallel version of ``f(X·A)·B`` built from matrix blocking.

    ``A`` is split column-wise and ``B`` row-wise into ``num_splits`` shards
    (2 by default, matching the original two-shard demo). Each (A_k, B_k)
    pair produces a partial output on its own, and the partial outputs are
    summed to give the final result.

    Notes:
        * The result matches the unsharded pass only when ``f`` acts
          element-wise (e.g. ``np.tanh``), since every shard only sees its
          own columns of ``X·A``.
        * The number of columns of ``A`` and of rows of ``B`` must be evenly
          divisible by ``num_splits``; the split helpers raise otherwise.
    """
    A_shards = split_columnwise(A, num_splits)  # e.g. 2 x (200,150)
    B_shards = split_rowwise(B, num_splits)     # e.g. 2 x (150,400)
    partial_outputs = [
        np.dot(f(np.dot(X, A_k)), B_k)          # e.g. (100,150)·(150,400) -> (100,400)
        for A_k, B_k in zip(A_shards, B_shards)
    ]
    # Summing the partial outputs combines the shards' contributions.
    Z = partial_outputs[0]
    for partial in partial_outputs[1:]:
        Z = Z + partial
    return Z

In [215]:
# Run the reference and the tensor-parallel pass with the same activation.
activation = np.tanh
Z_normal = normal_forward_pass(X, A, B, activation)
Z_tensor = tensor_parallel_forward_pass(X, A, B, activation)

In [216]:
Z_tensor.shape

(100, 400)

In [217]:
np.allclose(Z_normal, Z_tensor)

True

FFN
- h -> 4h
- 4h -> h

In [218]:
from transformers import AutoModel
import os

# Proxy settings: send HTTP and HTTPS requests through the local proxy.
os.environ.update(
    http_proxy="http://127.0.0.1:7890",
    https_proxy="http://127.0.0.1:7890",
)

In [219]:
bert = AutoModel.from_pretrained('bert-base-uncased')

In [220]:
bert.encoder.layer[0].intermediate

BertIntermediate(
  (dense): Linear(in_features=768, out_features=3072, bias=True)
  (intermediate_act_fn): GELUActivation()
)

In [221]:
bert.encoder.layer[0].output

BertOutput(
  (dense): Linear(in_features=3072, out_features=768, bias=True)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

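FFN case: tensor-parallel split of the two linear layers (h -> 4h split by output features, 4h -> h split by input features)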

In [222]:
input = torch.randn(size=(1, 5, 10), dtype=torch.float32)

In [223]:
embedding_dim = input.size(2)
embedding_dim

10

In [224]:
dense_h_to_4h = nn.Linear(embedding_dim, embedding_dim*4, bias=False)
# (1, 5, 10) -> (1, 5, 40)
output1 = dense_h_to_4h(input)
output1.shape

torch.Size([1, 5, 40])

In [225]:
dense_h_to_4h.weight.shape

torch.Size([40, 10])

In [226]:
dense_4h_to_h = nn.Linear(embedding_dim*4, embedding_dim, bias=False)
# (1, 5, 40) -> (1, 5, 10)
output2 = dense_4h_to_h(output1)
output2.shape

torch.Size([1, 5, 10])

In [227]:
n_devices = 2
parallel_4h = embedding_dim * 4 // n_devices
parallel_4h

20

In [228]:
# Shard 0 of the h -> 4h projection: rows [0, parallel_4h) of the full weight.
# The full weight is (4h, h), so a row slice selects a subset of output features.
shard_rows = slice(0, parallel_4h)
dense_h_to_4h_parallel = nn.Linear(embedding_dim, parallel_4h, bias=False)
dense_h_to_4h_parallel.weight.data = dense_h_to_4h.weight.data[shard_rows, :]
first_h_to_4h = dense_h_to_4h_parallel(input)
first_h_to_4h.shape

torch.Size([1, 5, 20])

In [229]:
# Shard 1 of the h -> 4h projection: rows [parallel_4h, 2*parallel_4h) of the full weight.
shard_rows = slice(parallel_4h, 2 * parallel_4h)
dense_h_to_4h_parallel = nn.Linear(embedding_dim, parallel_4h, bias=False)
dense_h_to_4h_parallel.weight.data = dense_h_to_4h.weight.data[shard_rows, :]
second_h_to_4h = dense_h_to_4h_parallel(input)
second_h_to_4h.shape

torch.Size([1, 5, 20])

In [230]:
# Shard 0 of the 4h -> h projection: columns [0, parallel_4h) of the full weight.
# This shard maps parallel_4h input features to embedding_dim output features,
# so the layer is built as Linear(parallel_4h, embedding_dim). Its weight is
# stored as (out_features, in_features), which matches the column slice
# assigned below and keeps in_features/out_features consistent with the weight.
dense_4h_to_h_parallel = nn.Linear(parallel_4h, embedding_dim, bias=False)
dense_4h_to_h_parallel.weight.data = dense_4h_to_h.weight.data[:, :parallel_4h]
first_4h_to_h = dense_4h_to_h_parallel(first_h_to_4h)
first_4h_to_h.shape

torch.Size([1, 5, 10])

In [231]:
# Shard 1 of the 4h -> h projection: columns [parallel_4h, 2*parallel_4h) of the full weight.
# This shard maps parallel_4h input features to embedding_dim output features,
# so the layer is built as Linear(parallel_4h, embedding_dim). Its weight is
# stored as (out_features, in_features), which matches the column slice
# assigned below and keeps in_features/out_features consistent with the weight.
dense_4h_to_h_parallel = nn.Linear(parallel_4h, embedding_dim, bias=False)
dense_4h_to_h_parallel.weight.data = dense_4h_to_h.weight.data[:, parallel_4h:2*parallel_4h]
second_4h_to_h = dense_4h_to_h_parallel(second_h_to_4h)
second_4h_to_h.shape

torch.Size([1, 5, 10])

In [232]:
first_4h_to_h+second_4h_to_h

tensor([[[-4.9719e-02, -7.4587e-02,  8.9481e-03,  5.1355e-01, -8.9134e-03,
          -4.1438e-01,  2.5453e-01, -3.6612e-01,  6.8409e-02,  3.3059e-01],
         [ 4.3118e-01,  4.5226e-01, -1.5077e-01,  1.5164e-01,  3.7224e-01,
          -4.8630e-01, -1.6320e-01,  2.3650e-02,  1.2778e-01, -2.4209e-01],
         [ 4.7697e-02, -7.8100e-04,  6.6952e-01,  4.9806e-02,  1.3544e-01,
           2.9979e-02, -7.5627e-02, -3.7002e-01, -4.3875e-02,  5.0383e-01],
         [-2.0860e-01, -3.3780e-01,  6.9807e-01,  2.5011e-01, -1.0857e+00,
           7.3868e-01, -7.2284e-02, -2.4534e-01,  3.2172e-02,  7.3528e-01],
         [-1.6256e-01, -5.2663e-01,  6.3661e-01,  2.1291e-01,  1.7782e-01,
          -2.5116e-01, -7.2236e-03, -6.2389e-01,  1.3003e-01,  5.0622e-01]]],
       grad_fn=<AddBackward0>)

In [233]:
output2

tensor([[[-4.9719e-02, -7.4587e-02,  8.9481e-03,  5.1355e-01, -8.9133e-03,
          -4.1438e-01,  2.5453e-01, -3.6612e-01,  6.8409e-02,  3.3059e-01],
         [ 4.3118e-01,  4.5226e-01, -1.5077e-01,  1.5164e-01,  3.7224e-01,
          -4.8630e-01, -1.6320e-01,  2.3650e-02,  1.2778e-01, -2.4209e-01],
         [ 4.7697e-02, -7.8097e-04,  6.6952e-01,  4.9806e-02,  1.3544e-01,
           2.9979e-02, -7.5627e-02, -3.7002e-01, -4.3875e-02,  5.0383e-01],
         [-2.0860e-01, -3.3780e-01,  6.9807e-01,  2.5011e-01, -1.0857e+00,
           7.3868e-01, -7.2284e-02, -2.4534e-01,  3.2172e-02,  7.3528e-01],
         [-1.6256e-01, -5.2663e-01,  6.3661e-01,  2.1291e-01,  1.7782e-01,
          -2.5116e-01, -7.2236e-03, -6.2389e-01,  1.3003e-01,  5.0622e-01]]],
       grad_fn=<UnsafeViewBackward0>)

Attention

In [234]:
input = torch.randn(size=(1, 5, 32), dtype=torch.float32)

In [235]:
batch_size, seq_len, hidden_size = input.size()
hidden_size 

32

In [236]:
num_heads = 4
head_dim = hidden_size // num_heads
head_dim

8

In [237]:
Wq = nn.Linear(hidden_size, hidden_size, bias=False)
Wk = nn.Linear(hidden_size, hidden_size, bias=False)
Wv = nn.Linear(hidden_size, hidden_size, bias=False)
Wo = nn.Linear(hidden_size, hidden_size, bias=False)

In [238]:
q = Wq(input)
k = Wk(input)
v = Wv(input)

In [239]:
# Split the hidden dimension into heads, then move the head axis forward:
#   (bsz, seq_len, hidden) -> (bsz, seq_len, num_heads, head_dim)
#                          -> (bsz, num_heads, seq_len, head_dim)
per_head_shape = (batch_size, seq_len, num_heads, head_dim)
q = q.view(*per_head_shape).transpose(1, 2)
k = k.view(*per_head_shape).transpose(1, 2)
v = v.view(*per_head_shape).transpose(1, 2)
v.shape

torch.Size([1, 4, 5, 8])

In [240]:
attn_weight = torch.matmul(q, k.transpose(2,3))/math.sqrt(head_dim)

In [241]:
attn_weight.shape

torch.Size([1, 4, 5, 5])

In [242]:
attn_weight = F.softmax(attn_weight, dim=-1)
attn_output = torch.matmul(attn_weight, v)
attn_output.shape

torch.Size([1, 4, 5, 8])

In [243]:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, seq_len, num_heads*head_dim)
attn_output.shape

torch.Size([1, 5, 32])

``` python
x = torch.arange(12).view(3, 4)      # 连续
y = x.transpose(0, 1)                # 不连续
y.is_contiguous()                    # False
y.view(12)                           # ❌ RuntimeError
y.contiguous().view(12)              # ✅ 可以
```

In [244]:
attn_output_non_tp = Wo(attn_output)
attn_output_non_tp.shape

torch.Size([1, 5, 32])

In [245]:
attn_output_non_tp

tensor([[[ 0.1467, -0.1020,  0.0323, -0.0405,  0.1146,  0.0616,  0.1730,
          -0.2219,  0.0287,  0.2283,  0.1665,  0.1526,  0.0512, -0.1623,
          -0.0435,  0.0033, -0.1520, -0.0482, -0.0878,  0.0365,  0.1547,
          -0.1591, -0.0755,  0.0978,  0.1997,  0.1412,  0.0167, -0.0183,
           0.0036, -0.2445, -0.1841, -0.2814],
         [ 0.1176, -0.0888,  0.0617, -0.0142,  0.0818,  0.1042,  0.2191,
          -0.1904,  0.1060,  0.2371,  0.1683,  0.1092,  0.0542, -0.1224,
          -0.0410,  0.0388, -0.2000, -0.0624, -0.0393,  0.0318,  0.0554,
          -0.0615, -0.0725,  0.1469,  0.2720,  0.1751, -0.0576, -0.0669,
          -0.0062, -0.2414, -0.1992, -0.2961],
         [ 0.0731, -0.0706,  0.0947, -0.0608,  0.0897,  0.1043,  0.1654,
          -0.1146,  0.0608,  0.2052,  0.0981,  0.0356,  0.1239, -0.0489,
          -0.0732,  0.0068, -0.1020, -0.0595, -0.0776, -0.0383,  0.0699,
          -0.1072, -0.1308,  0.1756,  0.1683,  0.1028, -0.0533, -0.1249,
          -0.0881, -0.2516, -0

Tensor parallelism for attention: the heads are split across devices

In [246]:
n_devices = 2
# Heads handled by each device. The total head count is recovered from
# hidden_size and head_dim (neither is modified by this cell) instead of from
# num_heads itself, so re-running this cell does not halve the value again.
# NOTE: num_heads is still overwritten here; the non-parallel attention cells
# above use the full head count and should not be re-run after this cell.
num_heads = (hidden_size // head_dim) // n_devices
num_heads

2

In [247]:
Wq.weight.shape

torch.Size([32, 32])

In [248]:
Wq_blocks = Wq.weight.split(num_heads*head_dim, dim=0)
print(Wq_blocks[0].shape)
print(Wq_blocks[1].shape)

torch.Size([16, 32])
torch.Size([16, 32])


In [249]:
Wk_blocks = Wk.weight.split(num_heads * head_dim, dim=0)
Wv_blocks = Wv.weight.split(num_heads * head_dim, dim=0)

In [250]:
Wo_blocks = Wo.weight.split(num_heads * head_dim, dim=1)
print(Wo_blocks[0].shape)
print(Wo_blocks[1].shape)

torch.Size([32, 16])
torch.Size([32, 16])


In [251]:
# first device
q1 = F.linear(input, Wq_blocks[0])
k1 = F.linear(input, Wk_blocks[0])
v1 = F.linear(input, Wv_blocks[0])
q1.shape

torch.Size([1, 5, 16])

In [252]:
q1 = q1.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
k1 = k1.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
v1 = v1.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

In [253]:
# Scaled dot-product attention over the first device's heads only.
scores = torch.matmul(q1, k1.transpose(2, 3)) / math.sqrt(head_dim)
attn_weight = F.softmax(scores, dim=-1)
# (bsz, heads, seq_len, head_dim) -> (bsz, seq_len, heads, head_dim)
attn_output1 = torch.matmul(attn_weight, v1).transpose(1, 2).contiguous()
attn_output1.shape

torch.Size([1, 5, 2, 8])

In [254]:
attn_output1 = attn_output1.reshape(batch_size, seq_len, num_heads*head_dim)
attn_output1.shape

torch.Size([1, 5, 16])

In [255]:
Wo_blocks[0].shape

torch.Size([32, 16])

In [256]:
attn_output1 = F.linear(attn_output1, Wo_blocks[0])
attn_output1.shape

torch.Size([1, 5, 32])

In [257]:
# second device: same computation as the first device, using block 1 of every weight
heads_shape = (batch_size, seq_len, num_heads, head_dim)
q2 = F.linear(input, Wq_blocks[1]).view(*heads_shape).transpose(1, 2)
k2 = F.linear(input, Wk_blocks[1]).view(*heads_shape).transpose(1, 2)
v2 = F.linear(input, Wv_blocks[1]).view(*heads_shape).transpose(1, 2)

# Scaled dot-product attention over this device's heads only.
attn_weight = F.softmax(
    torch.matmul(q2, k2.transpose(2, 3)) / math.sqrt(head_dim),
    dim=-1,
)
attn_output2 = torch.matmul(attn_weight, v2)

# (bsz, heads, seq_len, head_dim) -> (bsz, seq_len, heads * head_dim),
# then apply this device's column block of the output projection.
attn_output2 = attn_output2.transpose(1, 2).contiguous()
attn_output2 = attn_output2.reshape(batch_size, seq_len, num_heads * head_dim)
attn_output2 = F.linear(attn_output2, Wo_blocks[1])
attn_output2.shape

torch.Size([1, 5, 32])

In [258]:
attn_output1+attn_output2

tensor([[[ 0.1467, -0.1020,  0.0323, -0.0405,  0.1146,  0.0616,  0.1730,
          -0.2219,  0.0287,  0.2283,  0.1665,  0.1526,  0.0512, -0.1623,
          -0.0435,  0.0033, -0.1520, -0.0482, -0.0878,  0.0365,  0.1547,
          -0.1591, -0.0755,  0.0978,  0.1997,  0.1412,  0.0167, -0.0183,
           0.0036, -0.2445, -0.1841, -0.2814],
         [ 0.1176, -0.0888,  0.0617, -0.0142,  0.0818,  0.1042,  0.2191,
          -0.1904,  0.1060,  0.2371,  0.1683,  0.1092,  0.0542, -0.1224,
          -0.0410,  0.0388, -0.2000, -0.0624, -0.0393,  0.0318,  0.0554,
          -0.0615, -0.0725,  0.1469,  0.2720,  0.1751, -0.0576, -0.0669,
          -0.0062, -0.2414, -0.1992, -0.2961],
         [ 0.0731, -0.0706,  0.0947, -0.0608,  0.0897,  0.1043,  0.1654,
          -0.1146,  0.0608,  0.2052,  0.0981,  0.0356,  0.1239, -0.0489,
          -0.0732,  0.0068, -0.1020, -0.0595, -0.0776, -0.0383,  0.0699,
          -0.1072, -0.1308,  0.1756,  0.1683,  0.1028, -0.0533, -0.1249,
          -0.0881, -0.2516, -0

In [259]:
# Sanity check: calling a bias-free nn.Linear gives the same result as
# F.linear applied with that module's weight.
x = torch.randn(10, 20)
ww = nn.Linear(in_features=20, out_features=30, bias=False)
my_w = ww.weight
module_out = ww(x)
functional_out = F.linear(x, my_w)
o1 = module_out.detach().numpy()
o2 = functional_out.detach().numpy()
np.allclose(o1, o2)

True