<a href="https://colab.research.google.com/github/hugeclear/study_with_chatgpt/blob/master/implement_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1-head attentionを実装しよう


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy dimensions for one attention head.
seq_len, d_model = 3, 4
d_k = d_v = 2

# Random token embeddings, one row per position: shape (seq_len, d_model).
X = torch.randn(seq_len, d_model)

# Plain random projection matrices (nn.Parameter or nn.Linear would also work).
# They are drawn in Q, K, V order so the random stream is consumed in a fixed order.
W_Q, W_K, W_V = (torch.randn(d_model, width) for width in (d_k, d_k, d_v))

# Project the input into queries, keys and values.
Q = torch.matmul(X, W_Q)
K = torch.matmul(X, W_K)
V = torch.matmul(X, W_V)

# Raw query-key similarities (no 1/sqrt(d_k) scaling in this first version),
# turned into per-query weights by a softmax over the key axis.
scores = torch.matmul(Q, K.t())
attn = scores.softmax(dim=-1)

# Each output row is the attention-weighted average of the value rows.
output = torch.matmul(attn, V)

print(f"Q: {Q}")
print(f"K: {K}")
print("scores:", scores)
print("attn:", attn)
print("output:", output)


Q: tensor([[ 0.0555,  0.0939],
        [ 1.4626,  2.3343],
        [-0.4570, -0.8459]])
K: tensor([[-0.1360, -0.2635],
        [ 1.1948, -1.5937],
        [-0.3831,  1.4915]])
scores: tensor([[-0.0323, -0.0833,  0.1188],
        [-0.8140, -1.9728,  2.9212],
        [ 0.2850,  0.8020, -1.0865]])
attn: tensor([[0.3212, 0.3052, 0.3736],
        [0.0231, 0.0073, 0.9696],
        [0.3412, 0.5722, 0.0866]])
output: tensor([[ 0.2101,  0.0382],
        [-0.3863,  0.3088],
        [ 0.6750, -0.0187]])


softmaxを使った重み付けattention出力ができるようになった

In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

seq_len = 3
d_model = 4
d_k = 2
d_v = 2

# Random input embeddings, one row per token: shape (seq_len, d_model).
X = torch.randn(seq_len, d_model)

W_Q = torch.randn(d_model, d_k)
W_K = torch.randn(d_model, d_k)
W_V = torch.randn(d_model, d_v)

Q = X @ W_Q
K = X @ W_K
V = X @ W_V

# Why divide by sqrt(d_k)?
# Each score is a sum of d_k products, so its typical magnitude grows with
# d_k. When the scores get too large the softmax saturates: the gaps between
# the weights become larger, not smaller, with almost all of the weight
# landing on a single key, and the gradient through the softmax becomes
# vanishingly small. Dividing by sqrt(d_k) keeps the scores in a range where
# the softmax stays responsive.
scores = (Q @ K.T) / math.sqrt(d_k)


In [None]:
import torch
import numpy as np
def layer_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Normalize ``x`` to zero mean and unit variance along its last axis.

    Statistics are computed per feature vector, so a batch of vectors is
    normalized row by row; a single 1-D vector behaves exactly as before.

    Args:
        x: Floating-point tensor of shape ``(..., d_model)``.
        eps: Small constant added to the variance to avoid division by zero.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Population (biased) variance, computed over the feature axis only.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    x_norm = (x - mean) / torch.sqrt(var + eps)

    # Scale and shift are fixed to the identity (gamma = 1, beta = 0);
    # nothing is learned in this hand-written version.
    gamma = torch.ones_like(x)
    beta = torch.zeros_like(x)

    return x_norm * gamma + beta


# Quick check on a single feature vector.
x = torch.tensor([5.0, 6.0, 7.0, 8.0])
print(layer_norm(x))

tensor([-1.3416, -0.4472,  0.4472,  1.3416])


Transformer blockの実装へ

LayerNorm -> Attention -> Residual -> LayerNorm -> FFN -> Residual（Pre-LN 構成）

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MiniTransformerBlock(nn.Module):
    """Single-head Transformer block in the pre-LayerNorm arrangement.

    The block computes ``x = x + W_O(Attention(LN1(x)))`` followed by
    ``x = x + FFN(LN2(x))``.

    Args:
        d_model: Width of the input and output features.
        d_k: Width of the query and key projections.
        d_v: Width of the value projection.
        d_hidden: Width of the hidden layer inside the feed-forward network.
    """

    def __init__(self, d_model, d_k, d_v, d_hidden):
        super().__init__()

        # Linear projections for queries, keys and values.
        self.W_Q = nn.Linear(d_model, d_k)
        self.W_K = nn.Linear(d_model, d_k)
        self.W_V = nn.Linear(d_model, d_v)

        # Output projection back to the model width.
        self.W_O = nn.Linear(d_v, d_model)

        # One LayerNorm in front of each sub-layer.
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.GELU(),  # smoother activation than ReLU
            nn.Linear(d_hidden, d_model),
        )

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers with residuals.

        Args:
            x: Tensor whose last two axes are ``(seq_len, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # 1. LayerNorm before attention.
        h = self.ln1(x)

        # 2. Queries, keys and values.
        Q = self.W_Q(h)
        K = self.W_K(h)
        V = self.W_V(h)

        # 3. Scaled dot-product attention over the sequence axis.
        scores = Q @ K.transpose(-1, -2) / math.sqrt(Q.size(-1))
        attn = torch.softmax(scores, dim=-1)
        attn_out = attn @ V

        # 4. Project back to d_model.
        attn_out = self.W_O(attn_out)

        # 5. First residual connection.
        x = x + attn_out

        # 6. LayerNorm -> FFN.
        ffn_out = self.ffn(self.ln2(x))

        # 7. Second residual connection: add the FFN output itself, not the
        #    normalized input that was fed into the FFN.
        x = x + ffn_out

        return x

In [None]:
# Smoke test for the single-head block. d_model is spelled out so the cell
# does not depend on a variable left behind by an earlier cell; it has to
# match the feature width of X below.
block = MiniTransformerBlock(d_model=4, d_k=2, d_v=2, d_hidden=8)

# Three tokens with four features each.
X = torch.randn(3, 4)
out = block(X)

print("入力 X:", X)
print("出力 out:", out)
print("shape:", out.shape)

入力 X: tensor([[ 0.9023,  0.3733,  0.3107,  0.7211],
        [ 1.9552,  0.5846,  1.4028,  0.3647],
        [ 0.7259,  0.8475, -0.3782,  0.9767]])
出力 out: tensor([[ 0.1082, -1.8860,  2.0069,  1.5440],
        [ 1.7865, -1.2831,  3.5608, -0.2731],
        [-0.3764, -0.6505, -0.1868,  2.8405]], grad_fn=<AddBackward0>)
shape: torch.Size([3, 4])


In [None]:
class MultHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The model width is split evenly across ``num_heads`` heads. Every head
    attends over the token axis independently; the head outputs are then
    concatenated and passed through a final linear projection.

    Args:
        d_model: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_v = d_model // num_heads
        # One d_model-wide projection each for Q, K and V covers all heads at once.
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Run self-attention over the token axis of ``x``.

        Args:
            x: Tensor of shape ``(B, T, d_model)`` (batch, tokens, features).

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, D = x.shape  # batch, token length, d_model

        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        # Split the feature axis into heads and move the head axis in front of
        # the token axis: (B, T, D) -> (B, num_heads, T, d_head).
        Q = Q.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(B, T, self.num_heads, self.d_v).transpose(1, 2)

        # The 1/sqrt(d_k) scaling is applied once, here on the scores;
        # K and V themselves are left untouched.
        scores = Q @ K.transpose(-1, -2) / math.sqrt(self.d_k)
        attn = torch.softmax(scores, dim=-1)
        out = attn @ V

        # Undo the head split: (B, num_heads, T, d_v) -> (B, T, D).
        out = out.transpose(1, 2).contiguous().view(B, T, D)

        out = self.W_O(out)
        return out

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    ``d_model`` features are divided into ``num_heads`` equal slices. Each
    slice forms one head that computes attention between tokens on its own;
    the per-head results are concatenated and mixed by an output projection.

    Args:
        d_model: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_v = d_model // num_heads

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)

        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over the token axis of ``x`` with every head.

        Args:
            x: Tensor of shape ``(B, T, d_model)`` (batch, tokens, features).

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, D = x.shape

        Q = self.W_Q(x)
        K = self.W_K(x)
        V = self.W_V(x)

        # Head split. The transpose is essential: it turns (B, T, H, d) into
        # (B, H, T, d) so that the matmuls below pair tokens with tokens inside
        # each head. Without it the scores would compare heads with heads
        # inside each token.
        Q = Q.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(B, T, self.num_heads, self.d_v).transpose(1, 2)

        scores = Q @ K.transpose(-1, -2) / math.sqrt(self.d_k)  # (B, H, T, T)
        attn = torch.softmax(scores, dim=-1)
        out = attn @ V  # (B, H, T, d_v)

        # Merge the heads back: (B, H, T, d_v) -> (B, T, H, d_v) -> (B, T, D).
        out = out.transpose(1, 2).contiguous().view(B, T, D)

        # Output projection.
        out = self.W_O(out)

        return out

In [None]:
# Smoke test for MultiHeadAttention: 4 features split across 2 heads.
mha = MultiHeadAttention(d_model=4, num_heads=2)
# Input laid out as (batch=1, tokens=3, d_model=4), the (B, T, D) layout
# that forward() unpacks.
x = torch.randn(1, 3, 4)
out = mha(x)

# Print the input, the attention output and the output shape.
print("入力:", x)
print("出力:", out)
print("shape:", out.shape)


入力: tensor([[[-1.3003, -1.8968,  1.7726,  0.7622],
         [ 0.8025,  0.5793, -0.3923,  0.2900],
         [ 0.2742, -0.1099, -0.2101,  2.7433]]])
出力: tensor([[[-0.4964,  0.6343, -0.0478,  0.3425],
         [-1.1541,  0.4896, -0.2930,  0.3507],
         [-0.2134,  0.2677, -0.4627, -0.2826]]], grad_fn=<ViewBackward0>)
shape: torch.Size([1, 3, 4])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MiniTransformerBlock(nn.Module):
    """Pre-LayerNorm Transformer block built on multi-head self-attention.

    The block computes ``x = x + MHA(LN1(x))`` followed by
    ``x = x + FFN(LN2(x))``.

    Args:
        d_model: Width of the input and output features.
        num_heads: Number of attention heads handed to ``MultiHeadAttention``.
        d_hidden: Width of the hidden layer inside the feed-forward network.
    """

    def __init__(self, d_model, num_heads, d_hidden):
        super().__init__()

        # One LayerNorm in front of each sub-layer.
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # All Q/K/V/output projections live inside the attention module, so
        # the block owns no projection layers of its own and needs nothing
        # beyond its constructor arguments.
        self.mha = MultiHeadAttention(d_model, num_heads)

        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.GELU(),  # smoother activation than ReLU
            nn.Linear(d_hidden, d_model),
        )

    def forward(self, x):
        """Apply attention and feed-forward sub-layers, each with a residual.

        Args:
            x: Input tensor whose last axis has size ``d_model``; it is passed
                to ``self.mha`` after LayerNorm, so it must have whatever
                layout that module expects.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # 1. LayerNorm -> multi-head attention.
        attn_out = self.mha(self.ln1(x))

        # 2. Residual 1: add this call's attention output.
        x = x + attn_out

        # 3. LayerNorm -> FFN.
        ffn_out = self.ffn(self.ln2(x))

        # 4. Residual 2: add the FFN output, not the normalized FFN input.
        x = x + ffn_out

        return x

In [None]:
# Smoke test for the multi-head MiniTransformerBlock (the variant that takes
# num_heads): model width 4, 2 heads, FFN hidden width 8.
block = MiniTransformerBlock(d_model=4, num_heads=2, d_hidden=8)

# Input of shape (1, 3, 4); the last axis must equal d_model for the
# LayerNorm layers inside the block.
X = torch.randn(1, 3, 4)
out = block(X)

# Print the input, the block output and the output shape.
print("入力:", X)
print("出力:", out)
print("shape:", out.shape)


入力: tensor([[[-0.8717,  0.7862,  0.0088,  1.2842],
         [-0.8768, -1.0473, -1.0807, -0.9000],
         [-0.9916,  0.1127,  1.7978, -1.5221]]])
出力: tensor([[[-2.8367,  2.2549, -0.4098,  2.6315],
         [-3.4882,  0.3624, -1.7703,  0.3845],
         [-1.9120,  0.9449,  2.6656, -2.9927]]], grad_fn=<AddBackward0>)
shape: torch.Size([1, 3, 4])
