# Multi-head attention (MHA) 구현


In [1]:
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
from tokenizers import (
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer,
)


# Load the "stanfordnlp/imdb" dataset; its "train" and "test" splits are
# indexed further down when the data loaders are built.
ds = load_dataset("stanfordnlp/imdb")
# Fetch the "bert-base-uncased" tokenizer through torch.hub. The result is
# used as a module-level global by collate_fn and by the model's padding mask.
tokenizer = torch.hub.load(
    "huggingface/pytorch-transformers", "tokenizer", "bert-base-uncased"
)


def collate_fn(batch):
    """Collate dataset rows into a pair of ``torch.LongTensor`` objects.

    Every row is indexed with ``"text"`` and ``"label"``. The texts are passed
    to the module-level ``tokenizer`` with ``padding=True``,
    ``truncation=True`` and ``max_length=400``; the resulting ``input_ids``
    become the first returned tensor.

    Args:
        batch: Iterable of rows as handed over by the ``DataLoader``.

    Returns:
        ``(token_ids, labels)`` as two ``torch.LongTensor`` objects.
    """
    max_len = 400  # cap on the number of tokens kept per text

    targets = [row["label"] for row in batch]
    reviews = [row["text"] for row in batch]

    encoded = tokenizer(reviews, padding=True, truncation=True, max_length=max_len)
    token_ids = torch.LongTensor(encoded.input_ids)

    return token_ids, torch.LongTensor(targets)


# Training batches: 64 rows each, reshuffled, collated into token-id and
# label tensors by collate_fn.
train_loader = DataLoader(
    ds["train"], batch_size=64, shuffle=True, collate_fn=collate_fn
)
# Test batches use the same batch size and collation but keep dataset order.
test_loader = DataLoader(
    ds["test"], batch_size=64, shuffle=False, collate_fn=collate_fn
)

Using cache found in /Users/obov/.cache/torch/hub/huggingface_pytorch-transformers_main


In [2]:
# Notebook inspection: display the first row of the training split.
ds["train"][0]

{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

## Self-attention


In [3]:
from torch import nn
from math import sqrt


class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with an output projection.

    Args:
        input_dim: Size of the last dimension of the input ``x``.
        d_model: Width of the query/key/value projections and of the output.
    """

    def __init__(self, input_dim, d_model):
        super().__init__()

        self.input_dim, self.d_model = input_dim, d_model

        # Query, key and value projections, followed by the output projection.
        self.wq = nn.Linear(input_dim, d_model)
        self.wk = nn.Linear(input_dim, d_model)
        self.wv = nn.Linear(input_dim, d_model)
        self.dense = nn.Linear(d_model, d_model)

        # Normalises attention logits over the last (key) axis.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask):
        """Attend over ``x`` and return the projected result.

        Args:
            x: Input tensor whose last dimension is ``input_dim``.
            mask: ``None``, or a tensor that broadcasts against the score
                matrix; positions where it is non-zero/True are pushed to
                ``-1e9`` before the softmax so they receive ~0 weight.

        Returns:
            Tensor with last dimension ``d_model``.
        """
        query = self.wq(x)
        key = self.wk(x)
        value = self.wv(x)

        # (B, S, D) @ (B, D, S) -> (B, S, S), scaled by sqrt(d_model).
        logits = (query @ key.transpose(-1, -2)) / sqrt(self.d_model)
        if mask is not None:
            logits = logits + (mask * -1e9)

        weights = self.softmax(logits)
        return self.dense(weights @ value)

## Multi-Head Attention

ref : [4-2. Transformer(Multi-head Attention) [초등학생도 이해하는 자연어처리]](https://codingopera.tistory.com/44)


In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The model width ``d_model`` is split evenly across ``n_heads`` heads of
    width ``d_prime = d_model // n_heads``. Every head owns its own
    query/key/value projections; the per-head results are concatenated along
    the feature axis and passed through a final ``d_model -> d_model``
    projection.

    Args:
        input_dim: Size of the last dimension of the input ``x``.
        d_model: Total width across all heads (and of the output).
        n_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        AssertionError: If ``d_model`` is not a multiple of ``n_heads``.
    """

    def __init__(
        self,
        input_dim,
        d_model,
        n_heads,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.d_model = d_model
        self.n_heads = n_heads

        assert (
            d_model % n_heads == 0
        ), f"d_model ({d_model}) must be a multiple of n_heads ({n_heads})"

        # Width of a single head.
        self.d_prime = self.d_model // self.n_heads

        # One {wq, wk, wv} projection triple per head.
        self.wqkvs = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "wq": nn.Linear(input_dim, self.d_prime),
                        "wk": nn.Linear(input_dim, self.d_prime),
                        "wv": nn.Linear(input_dim, self.d_prime),
                    }
                )
                for _ in range(self.n_heads)
            ]
        )

        # Output projection applied to the concatenated heads.
        self.dense = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask):
        """Run every head over ``x`` and merge the results.

        Args:
            x: Input tensor whose last dimension is ``input_dim``.
            mask: ``None``, or a tensor that broadcasts against each head's
                score matrix; positions where it is non-zero/True are pushed
                to ``-1e9`` before the softmax so they receive ~0 weight.

        Returns:
            Tensor with last dimension ``d_model``.
        """
        # Each head's queries/keys have width d_prime, so the dot products
        # must be scaled by sqrt(d_prime) -- not sqrt(d_model), which would
        # over-shrink the logits by a factor of sqrt(n_heads).
        scale = sqrt(self.d_prime)

        results = []
        for wqkv in self.wqkvs:
            q, k, v = wqkv["wq"](x), wqkv["wk"](x), wqkv["wv"](x)
            score = torch.matmul(
                q,
                k.transpose(-1, -2),
            )  # (B, S, D') * (B, D', S) = (B, S, S)
            score = score / scale

            if mask is not None:
                score = score + (mask * -1e9)

            score = self.softmax(score)
            result = torch.matmul(
                score,  # (B, S, S)
                v,  # (B, S, D')
            )  # (B, S, D')
            results.append(result)
        results = torch.cat(results, dim=-1)  # (B, S, D' * n_heads)
        output = self.dense(results)
        return output

In [5]:
class TransformerLayer(nn.Module):
    """One post-norm Transformer encoder layer: attention then feed-forward.

    Each sub-layer is followed by dropout, a residual connection and a
    ``LayerNorm``. Because the attention output (width ``d_model``) is added
    to the layer input, ``input_dim`` has to be compatible with ``d_model``
    for the residual sum.

    Args:
        input_dim: Size of the last dimension of the layer input.
        d_model: Width of the attention output and of both layer norms.
        dff: Hidden width of the position-wise feed-forward network.
        n_heads: Number of attention heads.
        dropout_prob: Dropout probability used after both sub-layers.
    """

    def __init__(self, input_dim, d_model, dff, n_heads, dropout_prob=0.5):
        super().__init__()

        self.input_dim = input_dim
        self.d_model = d_model
        self.dff = dff

        self.mha = MultiHeadAttention(input_dim, d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dff), nn.ReLU(), nn.Linear(dff, d_model)
        )
        self.dropout = nn.Dropout(dropout_prob)
        self.layer_norm_mha = nn.LayerNorm(normalized_shape=d_model)
        self.layer_norm_ffn = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x, mask):
        """Apply the attention and feed-forward sub-layers to ``x``.

        Args:
            x: Layer input.
            mask: Attention mask forwarded unchanged to the attention module.

        Returns:
            Tensor produced by the second ``LayerNorm``.
        """
        # Attention sub-layer. Dropout must act on the attention output
        # (previously it was applied to ``x``, discarding the attention
        # result entirely).
        x1 = self.mha(x, mask)
        x1 = self.dropout(x1)
        x1 = self.layer_norm_mha(x1 + x)

        # Feed-forward sub-layer. It consumes the normalised attention
        # output ``x1`` (previously it was fed the raw layer input ``x``).
        x2 = self.ffn(x1)
        x2 = self.dropout(x2)
        x2 = self.layer_norm_ffn(x2 + x1)

        return x2

In [6]:
import numpy as np


def get_angles(pos, i, d_model):
    """Return ``pos / 10000 ** (2 * (i // 2) / d_model)``.

    Args:
        pos: Position index (or array of indices).
        i: Feature index (or array of indices); indices ``2k`` and ``2k + 1``
            share the same rate because of the integer division by two.
        d_model: Model width used to normalise the exponent.

    Returns:
        ``pos`` multiplied by the per-feature angle rate, with NumPy
        broadcasting applied between ``pos`` and ``i``.
    """
    pair_index = 2 * (i // 2)
    exponent = pair_index / np.float32(d_model)
    return pos * (1 / np.power(10000, exponent))


def positional_encoding(position, d_model):
    """Build a sinusoidal positional-encoding table.

    Angles come from ``get_angles`` for every (position, feature) pair; sine
    is applied to even feature columns and cosine to odd ones.

    Args:
        position: Number of positions (rows) to generate.
        d_model: Number of features (columns) per position.

    Returns:
        ``torch.FloatTensor`` of shape ``(1, position, d_model)``.
    """
    rows = np.arange(position)[:, None]
    cols = np.arange(d_model)[None, :]
    table = get_angles(rows, cols, d_model)

    # Column offset 0 -> sine on even features, offset 1 -> cosine on odd ones.
    for offset, fn in ((0, np.sin), (1, np.cos)):
        table[:, offset::2] = fn(table[:, offset::2])

    # Leading axis of size 1 so the table broadcasts over a batch dimension.
    return torch.FloatTensor(table[None, ...])


# Number of positions in the positional-encoding table; this global is read
# by TextClassifier.__init__. collate_fn keeps its own local limit of 400.
max_len = 400
# Sanity check of the table shape for d_model=256.
print(positional_encoding(max_len, 256).shape)

torch.Size([1, 400, 256])


In [7]:
class TextClassifier(nn.Module):
    """Transformer encoder stack with a single-logit classification head.

    Relies on module-level names: ``max_len`` (size of the positional table),
    ``positional_encoding``, ``TransformerLayer`` and ``tokenizer`` (its
    ``pad_token_id`` defines the padding mask).

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Embedding / model width.
        n_layers: Number of stacked ``TransformerLayer`` modules.
        dff: Feed-forward hidden width passed to every layer.
        n_heads: Number of attention heads passed to every layer.
    """

    def __init__(self, vocab_size, d_model, n_layers, dff, n_heads):
        super().__init__()

        self.vocab_size, self.d_model = vocab_size, d_model
        self.n_layers, self.dff = n_layers, dff

        self.embedding = nn.Embedding(vocab_size, d_model)
        # Fixed positional table, stored as a parameter that is never trained.
        self.pos_encoding = nn.parameter.Parameter(
            positional_encoding(max_len, d_model), requires_grad=False
        )

        encoder_blocks = []
        for _ in range(n_layers):
            encoder_blocks.append(TransformerLayer(d_model, d_model, dff, n_heads))
        self.layers = nn.ModuleList(encoder_blocks)

        self.classification = nn.Linear(d_model, 1)

    def forward(self, x):
        """Return one logit per sequence.

        Args:
            x: Token-id tensor indexed as ``(batch, sequence)``.

        Returns:
            Tensor of shape ``(batch, 1)`` computed from the first position
            of the final layer's output.
        """
        # True at padding positions; the inserted middle axis lets the mask
        # broadcast over the query dimension of the attention scores.
        pad_mask = (x == tokenizer.pad_token_id)[:, None, :]
        n_tokens = x.shape[1]

        hidden = self.embedding(x) * sqrt(self.d_model)
        hidden = hidden + self.pos_encoding[:, :n_tokens]

        for block in self.layers:
            hidden = block(hidden, pad_mask)

        # Classify from the hidden state of the first token only.
        first_token = hidden[:, 0]
        return self.classification(first_token)


# Build the classifier: vocabulary size taken from the tokenizer, five
# encoder layers of width 32 with a 32-unit feed-forward block and four
# attention heads.
model = TextClassifier(
    len(tokenizer),
    n_layers=5,
    dff=32,
    d_model=32,
    n_heads=4,
)

In [8]:
from torch.optim import Adam

lr = 0.001  # learning rate handed to Adam below
# Move the model to the "mps" device; the training and evaluation code moves
# every batch to the same device string.
model = model.to("mps")
# Binary cross-entropy on raw logits, matching the model's single output unit.
loss_fn = nn.BCEWithLogitsLoss()

optimizer = Adam(model.parameters(), lr=lr)

In [9]:
import numpy as np
import matplotlib.pyplot as plt


def accuracy(model, dataloader, device=None):
    """Return the fraction of examples in ``dataloader`` classified correctly.

    The model is expected to emit one logit per example in its last
    dimension; a logit greater than 0 is counted as class 1, otherwise 0.
    The function does not change the model's train/eval mode.

    Args:
        model: Callable module mapping a batch of inputs to logits.
        dataloader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, so the function
            works on whatever device the model lives on instead of a
            hard-coded one. If the model has no parameters the batches are
            left where they are.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when the loader yields no
        examples (instead of raising ``ZeroDivisionError``).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device

    total = 0
    correct = 0

    # Evaluation never needs gradients; disabling them here keeps the
    # function safe even when the caller forgets to.
    with torch.no_grad():
        for inputs, labels in dataloader:
            if device is not None:
                inputs, labels = inputs.to(device), labels.to(device)

            # Threshold the logit at 0 and drop the trailing size-1 axis.
            preds = (model(inputs) > 0).long()[..., 0]

            total += labels.shape[0]
            correct += (labels == preds).sum().item()

    if total == 0:
        return 0.0
    return correct / total

In [10]:
# Training driver: for each epoch, run one optimisation pass over
# train_loader, then report accuracy on both the train and test loaders.
n_epochs = 50

for epoch in range(n_epochs):
    total_loss = 0.0  # sum (not mean) of the per-batch losses in this epoch
    model.train()
    for data in train_loader:
        # Clear gradients left over from the previous step.
        model.zero_grad()
        inputs, labels = data
        # Labels are cast to float because loss_fn compares them to logits.
        inputs, labels = inputs.to("mps"), labels.to("mps").float()

        # Drop the trailing size-1 axis so predictions line up with labels.
        preds = model(inputs)[..., 0]
        loss = loss_fn(preds, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch:3d} | Train Loss: {total_loss}")

    # Evaluation: gradients disabled and the model switched to eval mode.
    # Accuracy is computed over the full training set as well as the test set.
    with torch.no_grad():
        model.eval()
        train_acc = accuracy(model, train_loader)
        test_acc = accuracy(model, test_loader)
        print(f"=========> Train acc: {train_acc:.3f} | Test acc: {test_acc:.3f}")

Epoch   0 | Train Loss: 273.24636590480804


KeyboardInterrupt: 