데이터 정의 및 기존 코드

In [1]:
import torch
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import os

# Set before the tokenizer is created.
# NOTE(review): presumably this silences the Hugging Face tokenizers
# fork/parallelism warning when used with a DataLoader — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# IMDB dataset; the "train" and "test" splits and the "text" field are used below.
ds = load_dataset("stanfordnlp/imdb")
# bert-base-uncased tokenizer fetched through torch.hub.
tokenizer = torch.hub.load(
    "huggingface/pytorch-transformers", "tokenizer", "bert-base-uncased"
)


# 마지막 단어 제외하고, 해당 단어를 라벨에 넣는 collate 함수
def collate_fn_2(batch):
    """Turn a batch of dataset rows into (padded input ids, label ids).

    Each row's ``"text"`` is tokenized and truncated to 400 ids. The label is
    the second-to-last id (the id right before the final one), unless that id
    is in the excluded set, in which case the id before it is used instead.
    The input is everything *before* the chosen label, so the label never
    appears as the last input token.

    Args:
        batch: iterable of rows, each exposing ``row["text"]``.

    Returns:
        A tuple ``(texts, labels)``: ``texts`` is a LongTensor padded with the
        tokenizer's pad id (batch first), ``labels`` is a LongTensor with one
        id per row.
    """
    max_len = 400
    # Ids that must not be used as a label. The original author describes
    # them as punctuation / meaningless tokens.
    # NOTE(review): presumably '.', '!', '>' and [SEP]/[CLS] in the
    # bert-base-uncased vocabulary — confirm.
    excluded_ids = {1012, 999, 1028, 102, 101}
    pad_id = tokenizer.pad_token_id

    texts, labels = [], []
    for row in batch:
        input_ids = tokenizer(
            row["text"], truncation=True, max_length=max_len
        ).input_ids

        # Defaults cover sequences too short to yield a usable label.
        label, cut = pad_id, -2
        if len(input_ids) > 2:
            if input_ids[-2] not in excluded_ids:
                label = input_ids[-2]
            elif len(input_ids) > 3:
                # Fall back to the previous id, and cut the input one id
                # earlier so the label is not also the last input token.
                label, cut = input_ids[-3], -3

        labels.append(label)
        texts.append(torch.LongTensor(input_ids[:cut]))

    texts = pad_sequence(texts, batch_first=True, padding_value=pad_id)
    labels = torch.LongTensor(labels)

    return texts, labels


# Both loaders tokenize lazily through collate_fn_2; only the training
# loader is shuffled.
train_loader = DataLoader(
    ds["train"], batch_size=64, shuffle=True, collate_fn=collate_fn_2
)
test_loader = DataLoader(
    ds["test"], batch_size=64, shuffle=False, collate_fn=collate_fn_2
)

# Sanity check: pull one batch and print the input / label shapes.
text, label = next(iter(train_loader))
print(text.shape, label.shape)

# Upper bound on the input length produced by collate_fn_2 (its 400-id
# truncation minus the trailing ids it strips). The models below use it to
# size their positional-encoding table.
max_word_len = 398

Using cache found in /Users/yunhyeokchoi/.cache/torch/hub/huggingface_pytorch-transformers_main


torch.Size([64, 398]) torch.Size([64])


In [2]:
from torch import nn
from math import sqrt
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim import Adam
import numpy as np
import matplotlib.pyplot as plt


# 결과 출력을 위한 모니터
class AccuracyMonitor:
    """Track and plot per-evaluation accuracy for several models side by side."""

    def __init__(
        self, models, dataloaders, labels, title="Model Accuracies", accuracy_fn=None
    ):
        """
        models: list of models to evaluate
        dataloaders: list of dataloaders, one per model
        labels: list of legend labels, one per model/dataloader pair
        title: plot title
        accuracy_fn: optional callable ``(model, dataloader, **kwargs) -> float``;
            ``default_accuracy`` is used when omitted

        Raises ValueError when the three lists differ in length.
        """
        if not (len(models) == len(dataloaders) == len(labels)):
            raise ValueError(
                "models, dataloaders, labels는 모두 같은 길이를 가져야 합니다."
            )

        self.models = models
        self.dataloaders = dataloaders
        self.labels = labels
        self.title = title
        # One accuracy history per label; each update appends one value.
        self.acc_lists = [[] for _ in labels]
        self.accuracy_fn = (
            accuracy_fn if accuracy_fn is not None else self.default_accuracy
        )

    def default_accuracy(self, model, dataloader, **kwargs):
        """
        Default accuracy: fraction of samples whose argmax prediction equals
        the label. Extra ``**kwargs`` are accepted and ignored.

        The model is evaluated in eval mode without gradients and is put back
        into whatever mode (train/eval) it was in before the call. Batches are
        moved to the device of the model's parameters when it has any.
        Returns 0.0 for an empty dataloader.
        """
        was_training = model.training
        first_param = next(model.parameters(), None)
        target = first_param.device if first_param is not None else None

        correct = 0
        total = 0
        model.eval()
        try:
            with torch.no_grad():
                for inputs, labels in dataloader:
                    if target is not None:
                        inputs, labels = inputs.to(target), labels.to(target)
                    preds = torch.argmax(model(inputs), dim=-1)
                    total += labels.shape[0]
                    correct += (labels == preds).sum().item()
        finally:
            # Restore the caller's mode instead of forcing train mode.
            model.train(was_training)
        return correct / total if total > 0 else 0.0

    def update_accuracies(self, verbose=False, **kwargs):
        """
        Evaluate every (model, dataloader) pair and append the result to the
        matching history in ``acc_lists``.
        verbose: when True, print each updated accuracy.
        Extra ``**kwargs`` are forwarded to the accuracy function.
        """
        for i, (model, dataloader) in enumerate(zip(self.models, self.dataloaders)):
            acc = self.accuracy_fn(model, dataloader, **kwargs)
            self.acc_lists[i].append(acc)
            if verbose:
                print(f"Updated Accuracy ({self.labels[i]}): {acc:.4f}")

    def plot(self, epoch=None, save_path=None):
        """
        Plot the recorded accuracy histories.
        epoch: epoch number to show in the title (omitted when None)
        save_path: file path to save the figure to; when None the figure is
            shown instead of saved
        """
        if not self.acc_lists or len(self.acc_lists[0]) == 0:
            print("No accuracies to plot")
            return

        # One x position per recorded evaluation (0-based).
        x = np.arange(len(self.acc_lists[0]))

        plt.figure(figsize=(10, 6))
        for acc_list, label in zip(self.acc_lists, self.labels):
            plt.plot(x, acc_list, label=label, marker="o")

        title = f"{self.title} at Epoch {epoch}" if epoch is not None else self.title
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.grid(True)
        plt.legend()
        plt.ylim(0.0, 1.0)

        if save_path:
            plt.savefig(save_path)
            print(f"Plot saved to {save_path}")
        else:
            plt.show()

# Use Apple's MPS backend when it is actually usable, otherwise fall back to CPU.
# is_available() is the runtime check; is_built() only reports whether this
# PyTorch build was compiled with MPS support.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with an output projection."""

    def __init__(self, input_dim, d_model):
        super().__init__()

        self.input_dim, self.d_model = input_dim, d_model

        # Query / key / value projections, then the final output projection.
        self.wq = nn.Linear(input_dim, d_model)
        self.wk = nn.Linear(input_dim, d_model)
        self.wv = nn.Linear(input_dim, d_model)
        self.dense = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask):
        """Attend over ``x``; ``mask`` (or None) marks positions to suppress.

        Where ``mask`` is true/non-zero, -1e9 is added to the attention logits
        before the softmax.
        """
        queries = self.wq(x)
        keys = self.wk(x)
        values = self.wv(x)

        # (B, S, D) x (B, D, S) -> (B, S, S)
        logits = torch.matmul(queries, keys.transpose(-1, -2))
        logits = logits / sqrt(self.d_model)

        if mask is not None:
            logits = logits + (mask * -1e9)

        weights = self.softmax(logits)
        context = torch.matmul(weights, values)
        return self.dense(context)


class SATransformerLayer(nn.Module):
    """Self-attention followed by a feed-forward network (Linear-ReLU-Linear)."""

    def __init__(self, input_dim, d_model, dff):
        super().__init__()

        self.input_dim, self.d_model, self.dff = input_dim, d_model, dff

        self.sa = SelfAttention(input_dim, d_model)
        # Processes each token vector on its own; the non-linearity between
        # the two projections adds expressive power.
        widen = nn.Linear(d_model, dff)
        narrow = nn.Linear(dff, d_model)
        self.ffn = nn.Sequential(widen, nn.ReLU(), narrow)

    def forward(self, x, mask):
        """Run attention with ``mask``, then the feed-forward network."""
        attended = self.sa(x, mask)
        return self.ffn(attended)

def get_angles(pos, i, d_model):
    """Return ``pos`` scaled by ``1 / 10000 ** (2 * (i // 2) / d_model)``.

    ``pos`` and ``i`` may be scalars or broadcastable NumPy arrays.
    """
    exponent = (2 * (i // 2)) / np.float32(d_model)
    inverse_frequency = 1 / np.power(10000, exponent)
    return pos * inverse_frequency

def positional_encoding(position, d_model):
    """Build a sinusoidal table of shape (1, position, d_model) as a FloatTensor.

    Even feature columns hold the sine of the angle from ``get_angles``, odd
    columns hold the cosine.
    """
    positions = np.arange(position)[:, None]
    dims = np.arange(d_model)[None, :]
    table = get_angles(positions, dims, d_model)

    table[:, 0::2] = np.sin(table[:, 0::2])
    table[:, 1::2] = np.cos(table[:, 1::2])

    # Leading axis of size 1 so the table broadcasts over the batch.
    return torch.FloatTensor(table[None, ...])

class SALastWordPrediction(nn.Module):
    """Predict a token id from a token sequence with single-head attention layers.

    Args:
        vocab_size: size of the embedding table and of the output logits.
        d_model: embedding / hidden width.
        n_layers: number of stacked ``SATransformerLayer`` blocks.
        dff: hidden width of each layer's feed-forward network.

    Relies on the module-level ``max_word_len`` (length of the positional
    table) and ``tokenizer`` (pad id used to build the mask).
    """

    def __init__(self, vocab_size, d_model, n_layers, dff):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.dff = dff

        self.embedding = nn.Embedding(vocab_size, d_model)
        # Fixed (non-trainable) sinusoidal table; stored as a Parameter so it
        # moves with the module on .to(device).
        self.pos_encoding = nn.parameter.Parameter(
            positional_encoding(max_word_len, d_model), requires_grad=False
        )
        self.layers = nn.ModuleList(
            [SATransformerLayer(d_model, d_model, dff) for _ in range(n_layers)]
        )
        self.classification = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` (batch, seq) to logits of shape (batch, vocab_size)."""
        # True where the input is padding; the extra axis lets it broadcast
        # over query positions inside attention.
        mask = (x == tokenizer.pad_token_id).unsqueeze(1)
        seq_len = x.shape[1]

        x = self.embedding(x)
        x = x * sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len]

        for layer in self.layers:
            x = layer(x, mask)

        # Only the last position is used, so slice before the vocabulary
        # projection: same logits, without materialising (batch, seq, vocab).
        # NOTE(review): with right-padded batches position -1 is a PAD slot
        # for rows shorter than the batch maximum — confirm this is the
        # intended readout position.
        x = x[:, -1, :]
        return self.classification(x)

## [MY CODE] Multi Head Attention 구현해 마지막 단어 예측하는 모델 만들기

In [5]:
class Multi_Head_Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with an output projection.

    Args:
        input_dim: feature width of the input.
        d_model: total width of the projected queries/keys/values.
        n_head: number of heads; must be positive and divide ``d_model``.

    Raises:
        ValueError: if ``n_head`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, input_dim, d_model, n_head):
        super().__init__()

        # Fail at construction time instead of with an opaque view() error
        # on the first forward pass.
        if n_head <= 0 or d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )

        self.input_dim = input_dim
        self.d_model = d_model
        self.n_head = n_head
        # Per-head width (attribute name kept as-is for compatibility).
        self.d_devied = d_model // n_head

        # Linear layers for Q, K, V
        self.wq = nn.Linear(input_dim, d_model)
        self.wk = nn.Linear(input_dim, d_model)
        self.wv = nn.Linear(input_dim, d_model)

        # Final linear layer
        self.dense = nn.Linear(d_model, d_model)

        # Softmax over the key axis
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (B, S, D) into (B, H, S, D') with D' = d_model // n_head."""
        return t.view(batch_size, seq_len, self.n_head, self.d_devied).transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, S, input_dim); returns (B, S, d_model).

        ``mask`` (optional, 3-D, e.g. (B, 1, S)) marks positions to suppress:
        where it is true/non-zero, -1e9 is added to the attention logits.
        """
        batch_size = x.size(0)
        seq_len = x.size(1)

        # Linear transformations, split into heads: (B, H, S, D')
        q = self._split_heads(self.wq(x), batch_size, seq_len)
        k = self._split_heads(self.wk(x), batch_size, seq_len)
        v = self._split_heads(self.wv(x), batch_size, seq_len)

        # Attention scores: (B, H, S, S)
        score = torch.matmul(q, k.transpose(-1, -2)) / sqrt(self.d_devied)

        if mask is not None:
            # Add a head axis; broadcasting then covers every head, so no
            # explicit expand is needed.
            score = score + (mask.unsqueeze(1) * -1e9)

        # Attention weights
        score = self.softmax(score)

        # Weighted sum: (B, H, S, D')
        result = torch.matmul(score, v)

        # Concatenate heads and project: (B, S, D)
        result = result.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.dense(result)

        return output
    

class MHATransformerLayer(nn.Module):
    """Multi-head attention and feed-forward sub-layers.

    Each sub-layer's output goes through dropout, is added to the sub-layer's
    input, and the sum is layer-normalized.
    """

    def __init__(self, input_dim, d_model, n_head, dff, dropout_prob=0.1):
        super().__init__()

        self.input_dim, self.d_model = input_dim, d_model
        self.n_head, self.dff = n_head, dff

        self.MHA = Multi_Head_Attention(input_dim, d_model, n_head)
        # Processes each token vector on its own; the non-linearity between
        # the two projections adds expressive power.
        widen = nn.Linear(d_model, dff)
        narrow = nn.Linear(dff, d_model)
        self.ffn = nn.Sequential(widen, nn.ReLU(), narrow)
        self.dropout = nn.Dropout(dropout_prob)
        self.normal1 = nn.LayerNorm(d_model)
        self.normal2 = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        """Apply the attention sub-layer, then the feed-forward sub-layer."""
        attn_out = self.dropout(self.MHA(x, mask))
        hidden = self.normal1(attn_out + x)

        ffn_out = self.dropout(self.ffn(hidden))
        return self.normal2(ffn_out + hidden)
    

class MHALastWordPrediction(nn.Module):
    """Predict a token id from a token sequence with multi-head attention layers.

    Args:
        vocab_size: size of the embedding table and of the output logits.
        d_model: embedding / hidden width.
        n_layers: number of stacked ``MHATransformerLayer`` blocks.
        n_head: number of attention heads per layer.
        dff: hidden width of each layer's feed-forward network.

    Relies on the module-level ``max_word_len`` (length of the positional
    table) and ``tokenizer`` (pad id used to build the mask).
    """

    def __init__(self, vocab_size, d_model, n_layers, n_head, dff):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        # Stored like the other hyper-parameters.
        self.n_head = n_head
        self.dff = dff

        self.embedding = nn.Embedding(vocab_size, d_model)
        # Fixed (non-trainable) sinusoidal table; stored as a Parameter so it
        # moves with the module on .to(device).
        self.pos_encoding = nn.parameter.Parameter(
            positional_encoding(max_word_len, d_model), requires_grad=False
        )
        self.layers = nn.ModuleList(
            [MHATransformerLayer(d_model, d_model, n_head, dff) for _ in range(n_layers)]
        )
        self.classification = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` (batch, seq) to logits of shape (batch, vocab_size)."""
        # True where the input is padding; the extra axis lets it broadcast
        # over query positions inside attention.
        mask = (x == tokenizer.pad_token_id).unsqueeze(1)
        seq_len = x.shape[1]

        x = self.embedding(x)
        x = x * sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len]

        for layer in self.layers:
            x = layer(x, mask)

        # Only the last position is used, so slice before the vocabulary
        # projection: same logits, without materialising (batch, seq, vocab).
        # NOTE(review): with right-padded batches position -1 is a PAD slot
        # for rows shorter than the batch maximum — confirm this is the
        # intended readout position.
        x = x[:, -1, :]
        return self.classification(x)


# Both models share vocab size, d_model=16, 2 layers and dff=32; the
# multi-head model additionally uses 4 heads.
SA_model = SALastWordPrediction(len(tokenizer), 16, 2, 32)
MHA_model = MHALastWordPrediction(len(tokenizer), 16, 2, 4, 32)

# Move the models to the target device before building the optimizers, so
# the optimizers are constructed over the parameters in their final location.
SA_model = SA_model.to(device)
MHA_model = MHA_model.to(device)

lr = 0.01
loss_fn = nn.CrossEntropyLoss()

SA_optimizer = Adam(SA_model.parameters(), lr=lr)
MHA_optimizer = Adam(MHA_model.parameters(), lr=lr)

def custom_accuracy(model, dataloader, **kwargs):
    """Return the fraction of samples whose argmax prediction equals the label.

    Args:
        model: callable mapping a batch of inputs to per-class scores.
        dataloader: iterable of ``(inputs, labels)`` tensor pairs.
        **kwargs: accepted and ignored, so the function matches the
            ``accuracy_fn(model, dataloader, **kwargs)`` call made by
            ``AccuracyMonitor.update_accuracies``.

    Returns:
        Accuracy in [0, 1]; 0.0 when the dataloader yields no samples.

    Batches are moved to the device of the model's parameters; the
    module-level ``device`` is used only when the model has no parameters.
    The model's train/eval mode is left to the caller.
    """
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device

    total = 0
    correct = 0

    # Evaluation needs no autograd graph.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(target), labels.to(target)
            # Pick the class with the highest score.
            preds = torch.argmax(model(inputs), dim=-1)

            total += labels.shape[0]
            correct += (labels == preds).sum().item()

    return correct / total if total > 0 else 0.0

## [MY CODE] 마지막 단어 예측 Self Attention, Multi Head Attention 비교

In [6]:
n_epochs = 50

# NOTE(review): both entries are evaluated on train_loader, so the curves
# show training accuracy; test_loader is defined above but unused here —
# confirm this is intended.
trans_monitor = AccuracyMonitor(
    models=[SA_model, MHA_model],
    dataloaders=[train_loader, train_loader],
    labels=["Self", "Multi"],
    accuracy_fn=custom_accuracy,
)

for epoch in range(n_epochs):
    SA_model.train()
    MHA_model.train()

    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}", unit="batch") as pbar:
        # Iterate the progress bar itself (not train_loader) so it advances
        # once per batch.
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            labels = labels.long()

            # Both models are trained on the same batch, each with its own
            # optimizer.
            SA_model.zero_grad()
            preds = SA_model(inputs)
            loss = loss_fn(preds, labels)
            loss.backward()
            SA_optimizer.step()

            MHA_model.zero_grad()
            MHApreds = MHA_model(inputs)
            MHAloss = loss_fn(MHApreds, labels)
            MHAloss.backward()
            MHA_optimizer.step()

            pbar.set_postfix(SA_loss=loss.item(), MHA_loss=MHAloss.item())

    # Record one accuracy value per model at the end of every epoch.
    with torch.no_grad():
        SA_model.eval()
        MHA_model.eval()
        trans_monitor.update_accuracies()

trans_monitor.plot()

Epoch 1/50:   0%|          | 0/391 [00:00<?, ?batch/s]

KeyboardInterrupt: 