In [285]:
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
from tokenizers import (
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer,
)

In [286]:
train_ds = load_dataset("stanfordnlp/imdb", split="train[:5%]")
test_ds = load_dataset("stanfordnlp/imdb", split="test[:5%]")

In [287]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [288]:
def collate_fn(batch):
  max_len = 400
  texts, labels = [], []
  for row in batch:
    labels.append(row['label'])
    texts.append(row['text'])

  texts = torch.LongTensor(tokenizer(texts, padding=True, truncation=True, max_length=max_len).input_ids)
  labels = torch.LongTensor(labels)

  return texts, labels

train_loader = DataLoader(
    train_ds, batch_size=32, shuffle=True, collate_fn=collate_fn
)
test_loader = DataLoader(
    test_ds, batch_size=32, shuffle=False, collate_fn=collate_fn
)

In [289]:
# text, label 확인하기
for texts, labels in train_loader:
    print("texts shape:", texts.shape)
    print("labels shape:", labels.shape)
    print("texts (token IDs):", texts[:2])
    print("labels:", labels[:10])
    break

texts shape: torch.Size([32, 400])
labels shape: torch.Size([32])
texts (token IDs): tensor([[  101,  2043,  2097,  1996, 11878,  2644,  1029,  1045,  2196,  2215,
          2000,  2156,  2178,  2544,  1997,  1037,  4234,  8594,  2153,  1012,
          2027,  2562,  2006,  2437,  5691,  2007,  1996,  2168,  2466,  1010,
          4634,  2058,  2169,  2060,  1999,  2667,  2000,  2191,  1996,  3185,
          2488,  2059,  1996,  2717,  1010,  2021, 13718,  8246,  2000,  2079,
          2061,  1010,  2004,  2023,  2003,  2025,  1037,  2204,  2466,  1012,
          7191,  6553,  1010,  2214,  1011, 13405,  1010,  4603,  3407,  1011,
          3241,  1012,  2004,  2065,  2111,  4553,  1012,  1996,  3365,  2367,
          4617,  1997,  2023,  2143,  6011,  2008,  2057,  2123, 29658,  2102,
          1012,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,   

In [290]:
from torch import nn
from math import sqrt

In [291]:
# SelfAttention Module을 Multi-head attention으로 확장
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0 # model의 차원은 head의 수로 나누어 떨어져야함

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        
        self.dense = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask):
        # print("x shape", x.shape)
        batch_size = x.size(0)

        q, k, v = self.wq(x), self.wk(x), self.wv(x) # (B, S, D)
        # print("q shape", q.shape)
        # print("k shape", k.shape)
        # print("v shape", v.shape)

        # Q, K, V (B, S, D)를 (B, S, H, D')로 reshape 
        # D = H X D' => D' = D / H = d_head

        # [step 1] (B, S, D) -> (B, H, S, D')
        q = q.view(batch_size, -1, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(batch_size, -1, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(batch_size, -1, self.n_heads, self.d_head).transpose(1, 2)

        # print("q reshaped:", q.shape)
        # print("k reshaped:", k.shape)
        # print("v reshaped:", v.shape)

        # [step 2] Attention score : (B, H, S, D') X (B, H, D', S) = (B, H, S, S)
        score = torch.matmul(q, k.transpose(-1, -2)) / sqrt(self.d_head)

        if mask is not None:
            mask = mask.unsqueeze(1)
            score = score + (mask * -1e9)

        score = self.softmax(score)
        score = self.dropout(score)
        result = torch.matmul(score, v)

        # transpose(1, 2)하고 나면 (B, S, H, D')
        # 다시 (S, D)로 reshape
        # contiguous()는 transpose하고 나서 tensor의 연속성을 보장해주기위해 사용함 (안 하면 오류 발생)
        result = result.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)  # (B, S, D)
        return self.dense(result)

* contiguous를 안 넣어줬을 때  
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [293]:
class TransformerLayer(nn.Module):
    def __init__(self, d_model, n_heads, dff, dropout=0.1):
        super().__init__()

        self.mha = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model)
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        output = self.mha(x, mask)
        
        x1 = self.dropout1(output)
        x1 = self.norm1(x + x1)

        x2 = self.ffn(x1)
        x2 = self.dropout2(x2)
        x2 = self.norm2(x2 + x1)

        return x2

In [294]:
import numpy as np

def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates

def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, None], np.arange(d_model)[None, :], d_model)
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[None, ...]

    return torch.FloatTensor(pos_encoding)

max_len = 400
print(positional_encoding(max_len, 256).shape)

torch.Size([1, 400, 256])


In [295]:
class TextClassifier(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, dff, dropout=0.1):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.dff = dff
        self.dropout = nn.Dropout(dropout)

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.parameter.Parameter(positional_encoding(max_len, d_model), requires_grad=False)
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, n_heads=4, dff=dff, dropout=dropout) for _ in range(n_layers)
        ])
        self.classification = nn.Linear(d_model, 1)

    def forward(self, x):
        mask = (x == tokenizer.pad_token_id)
        mask = mask[:, None, :]

        seq_len = x.shape[1]
        x = self.embedding(x)
        x = x * sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len]
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, mask)

        x = x[:, 0]
        x = self.classification(x)
        
        return x

In [296]:
model = TextClassifier(len(tokenizer), d_model=64, n_layers=5, dff=128)

In [297]:
print(torch.__version__)
print(torch.backends.mps.is_built())
print(torch.backends.mps.is_available())

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

2.6.0
True
True


In [298]:
from torch.optim import Adam

lr = 0.001
model = model.to(device)
loss_fn = nn.BCEWithLogitsLoss()

optimizer = Adam(model.parameters(), lr=lr)

In [299]:
def accuracy(model, dataloader):
    correct = 0
    total = 0

    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            labels = labels.float().unsqueeze(1)  # (B, 1)

            outputs = model(inputs)
            preds = (outputs > 0).long()

            correct += (preds == labels.long()).sum().item()
            total += labels.size(0)

    return correct / total if total > 0 else 0

In [300]:
train_accs = []
test_accs = []

n_epochs = 50

for epoch in range(n_epochs):
    total_loss = 0.
    model.train()

    for data in train_loader:
        model.zero_grad()
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # print("labels", labels[:0])

        labels = labels.float().unsqueeze(1)
        
        preds = model(inputs)
        loss = loss_fn(preds, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f'Epoch {epoch + 1} | Train Loss: {total_loss}')

    with torch.no_grad():
        model.eval()
        train_acc = accuracy(model, train_loader)
        test_acc = accuracy(model, test_loader)
        train_accs.append(train_acc)
        test_accs.append(test_acc)
        print(f'=====> Train acc: {train_acc:.3f} | Test acc: {test_acc:.3f}')

Epoch 1 | Train Loss: 0.6040689274668694
=====> Train acc: 1.000 | Test acc: 1.000
Epoch 2 | Train Loss: 0.1068037748336792
=====> Train acc: 1.000 | Test acc: 1.000
Epoch 3 | Train Loss: 0.05715763568878174
=====> Train acc: 1.000 | Test acc: 1.000
Epoch 4 | Train Loss: 0.03455835580825806
=====> Train acc: 1.000 | Test acc: 1.000
Epoch 5 | Train Loss: 0.02263106405735016
=====> Train acc: 1.000 | Test acc: 1.000
Epoch 6 | Train Loss: 0.01572476327419281
=====> Train acc: 1.000 | Test acc: 1.000
Epoch 7 | Train Loss: 0.011381357908248901
=====> Train acc: 1.000 | Test acc: 1.000
Epoch 8 | Train Loss: 0.008488595485687256
=====> Train acc: 1.000 | Test acc: 1.000
Epoch 9 | Train Loss: 0.00651174783706665
=====> Train acc: 1.000 | Test acc: 1.000
Epoch 10 | Train Loss: 0.005070596933364868
=====> Train acc: 1.000 | Test acc: 1.000
Epoch 11 | Train Loss: 0.004045933485031128
=====> Train acc: 1.000 | Test acc: 1.000
Epoch 12 | Train Loss: 0.0032656490802764893
=====> Train acc: 1.000 | T