## Multi-Head Attention 실습

In [None]:
from einops import rearrange
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """Pre-norm multi-head self-attention.

    The input is layer-normalised, projected to queries/keys/values by a
    single bias-free linear layer, split into ``num_heads`` heads, combined
    with scaled dot-product attention, and projected back to ``embed_dim``
    by a linear layer followed by dropout. No residual connection is added
    inside this module.

    Args:
        embed_dim: Size of the last (feature) dimension of input and output.
            Must be divisible by ``num_heads``.
        num_heads: Number of attention heads. Must be positive.
        dim_head: Accepted for backward compatibility but NOT used; the
            per-head size is always ``embed_dim // num_heads``.
        dropout: Dropout probability applied to the attention weights and to
            the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``embed_dim`` is not
            divisible by ``num_heads``.
    """

    def __init__(self, embed_dim=768, num_heads=12, dim_head=64, dropout=0.):
        super().__init__()
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        # Per-head feature size; an int so it can be used in reshape().
        self.d_h = embed_dim // num_heads

        # One projection yields Q, K and V concatenated on the last dim.
        # Submodules are created in the same order as before so seeded
        # initialisation and state_dict keys are unchanged.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)

        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, embed_dim)``; a typical input
                is ``(batch, seq_len, embed_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.norm(x)
        # Split the fused projection into Q, K, V, then each into heads.
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (self._split_heads(t) for t in (q, k, v))

        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_h)) V.
        scores = q.matmul(k.transpose(-1, -2)) / math.sqrt(self.d_h)
        attn = self.dropout(self.softmax(scores))

        out = self._merge_heads(attn.matmul(v))
        return self.linear(out)

    def _split_heads(self, t):
        """Reshape ``(..., seq, heads*d_h)`` to ``(..., heads, seq, d_h)``."""
        return t.reshape(*t.shape[:-1], self.num_heads, self.d_h).transpose(-3, -2)

    @staticmethod
    def _merge_heads(t):
        """Reshape ``(..., heads, seq, d_h)`` back to ``(..., seq, heads*d_h)``."""
        return t.transpose(-3, -2).flatten(-2)

In [None]:
# Build the module with its default configuration (768-dim, 12 heads).
multihead_att = MultiHeadAttention()

# Random batch: 8 sequences of 196 tokens, each a 768-dim embedding.
input_ = torch.randn(8, 196, 768)

# Run the forward pass; attention preserves the input shape.
output = multihead_att(input_)
print("출력 shape: ", output.shape)

출력 shape:  torch.Size([8, 196, 768])
