## [MY CODE] Last word prediction dataset 준비

In [1]:
import torch
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

ds = load_dataset("stanfordnlp/imdb")
tokenizer = torch.hub.load(
    "huggingface/pytorch-transformers", "tokenizer", "bert-base-uncased"
)


def collate_fn(batch):
    max_len = 50
    texts, labels = [], []
    for row in batch:
        labels.append(
            tokenizer(row["text"], truncation=True, max_length=max_len).input_ids[-2]
        )
        texts.append(
            torch.LongTensor(
                tokenizer(row["text"], truncation=True, max_length=max_len).input_ids[
                    :-2
                ]
            )
        )

    texts = pad_sequence(texts, batch_first=True, padding_value=tokenizer.pad_token_id)
    labels = torch.LongTensor(labels)

    return texts, labels


train_loader = DataLoader(
    ds["train"], batch_size=64, shuffle=True, collate_fn=collate_fn
)
test_loader = DataLoader(
    ds["test"], batch_size=64, shuffle=False, collate_fn=collate_fn
)

text, label = next(iter(train_loader))
print(text.shape, label.shape)

max_word_len = 48


def count_1012_in_dataloader(dataloader):
    count_1012 = 0
    for texts, labels in dataloader:  # DataLoader에서 이미 처리된 데이터를 가져옴
        # 텐서를 리스트로 변환 후 1012 토큰 개수 카운트
        count_1012 += labels.tolist().count(1012)

    return count_1012


# 트레인 데이터 로더에서 1012 개수 찾기
count_1012 = count_1012_in_dataloader(train_loader)
print(f"Total number of 1012 tokens in test dataset labels: {count_1012}")

README.md:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

Using cache found in /Users/yunhyeokchoi/.cache/torch/hub/huggingface_pytorch-transformers_main


torch.Size([64, 48]) torch.Size([64])


KeyboardInterrupt: 

## [MY CODE] Last word prediction dataset 준비 - 1012 토큰을 적게 포함하도록 collate 함수 재정의

In [2]:
# 마지막 단어 제외하고, 해당 단어를 라벨에 넣는 collate 함수
def collate_fn_2(batch):
    max_len = 50
    texts, labels = [], []
    for row in batch:
        input_ids = tokenizer(
            row["text"], truncation=True, max_length=max_len
        ).input_ids

        # 구두점은 텍스트에 남겨두고 라벨 후보에서만 제외
        if len(input_ids) > 2:
            candidate_label = input_ids[-2]
            if candidate_label not in {1012, 999, 1028, 102, 101}:
                labels.append(candidate_label)
            else:
                labels.append(
                    input_ids[-3]
                )  # 마지막 두 번째 단어가 의미 없으면 그 앞 단어 선택
        else:
            labels.append(tokenizer.pad_token_id)

        texts.append(torch.LongTensor(input_ids[:-2]))  # 마지막 두 단어 제외

    texts = pad_sequence(texts, batch_first=True, padding_value=tokenizer.pad_token_id)
    labels = torch.LongTensor(labels)

    return texts, labels


train_loader2 = DataLoader(
    ds["train"], batch_size=64, shuffle=True, collate_fn=collate_fn_2
)
test_loader2 = DataLoader(
    ds["test"], batch_size=64, shuffle=False, collate_fn=collate_fn_2
)

text, label = next(iter(train_loader2))
print(text.shape, label.shape)

max_word_len = 48

count_1012 = count_1012_in_dataloader(train_loader2)
print(f"Total number of 1012 tokens in test dataset labels: {count_1012}")

torch.Size([64, 48]) torch.Size([64])
Total number of 1012 tokens in test dataset labels: 94


## [LOG] 데이터 확인해보기 - 데이터 눈으로 확인해보기

In [154]:
text, label = next(iter(train_loader))
print(text.shape, label.shape)
print(text[0, 1].item(), label[1].item())
print(text[0, 2].item(), label[2].item())
print(text[0, 3].item(), label[3].item())
print(text[0, 4].item(), label[4].item())
print(text[0, 5].item(), label[5].item())
print(text[0, 6].item(), label[6].item())
print(text[0, 7].item(), label[7].item())
print(text[0, 8].item(), label[8].item())
print(text[0, 9].item(), label[9].item())
print(text[0, 10].item(), label[10].item())

text, label = next(iter(test_loader))
print(text.shape, label.shape)
print(text[0, 1].item(), label[1].item())
print(text[0, 2].item(), label[2].item())
print(text[0, 3].item(), label[3].item())
print(text[0, 4].item(), label[4].item())
print(text[0, 5].item(), label[5].item())
print(text[0, 6].item(), label[6].item())
print(text[0, 7].item(), label[7].item())
print(text[0, 8].item(), label[8].item())
print(text[0, 9].item(), label[9].item())
print(text[0, 10].item(), label[10].item())

text, label = next(iter(train_loader2))
print(text.shape, label.shape)
print(text[0, 1].item(), label[1].item())
print(text[0, 2].item(), label[2].item())
print(text[0, 3].item(), label[3].item())
print(text[0, 4].item(), label[4].item())
print(text[0, 5].item(), label[5].item())
print(text[0, 6].item(), label[6].item())
print(text[0, 7].item(), label[7].item())
print(text[0, 8].item(), label[8].item())
print(text[0, 9].item(), label[9].item())
print(text[0, 10].item(), label[10].item())


text, label = next(iter(test_loader2))
print(text.shape, label.shape)
print(text[0, 1].item(), label[1].item())
print(text[0, 2].item(), label[2].item())
print(text[0, 3].item(), label[3].item())
print(text[0, 4].item(), label[4].item())
print(text[0, 5].item(), label[5].item())
print(text[0, 6].item(), label[6].item())
print(text[0, 7].item(), label[7].item())
print(text[0, 8].item(), label[8].item())
print(text[0, 9].item(), label[9].item())
print(text[0, 10].item(), label[10].item())

torch.Size([64, 48]) torch.Size([64])
2004 2012
1996 2003
9338 2013
1997 2870
1000 2143
2023 2111
4768 7800
1010 1010
1000 2017
1045 12077
torch.Size([64, 48]) torch.Size([64])
1045 1012
2293 2000
16596 8149
1011 1997
10882 2049
1998 2005
2572 5595
5627 1045
2000 1010
2404 2002
torch.Size([64, 48]) torch.Size([64])
1045 5440
7021 1012
2023 1005
3185 2471
1999 2892
4966 1999
4289 2878
1012 1997
2026 5387
6100 2007
torch.Size([64, 48]) torch.Size([64])
1045 9767
2293 2000
16596 8149
1011 1997
10882 2049
1998 2005
2572 5595
5627 1045
2000 1010
2404 2002


## [LOG] collate_fn 에 따른 라벨 경우의수 확인하기

In [156]:
def extract_labels_from_dataloader(dataloader):
    """
    데이터로더에서 라벨에 사용된 모든 고유한 토큰을 추출.
    """
    label_tokens = set()
    for _, labels in dataloader:  # 데이터로더에서 라벨만 추출
        label_tokens.update(labels.tolist())  # 라벨을 set에 추가
    return list(label_tokens)  # 중복 제거된 라벨 토큰 반환


def merge_and_deduplicate_multiple_labels(*label_lists):
    """
    여러 개의 라벨 리스트를 합치고 중복을 제거합니다.
    """
    # 모든 리스트를 순회하며 중복 제거
    unique_labels = set()
    for labels in label_lists:
        unique_labels.update(labels)
    return list(unique_labels)


# Train 데이터로더에서 라벨 추출
train_labels = extract_labels_from_dataloader(train_loader)
# Test 데이터로더에서 라벨 추출
test_labels = extract_labels_from_dataloader(test_loader)
# 기존 col 함수의 라벨 경우의수
all_labels = merge_and_deduplicate_multiple_labels(train_labels, test_labels)
print(all_labels.__len__())

train_labels2 = extract_labels_from_dataloader(train_loader2)
test_labels2 = extract_labels_from_dataloader(test_loader2)
# 업데이트한 col 함수의 라벨 경우의수
all_labels2 = merge_and_deduplicate_multiple_labels(train_labels2, test_labels2)
print(all_labels2.__len__())

6906
7172


In [3]:
from torch import nn
from math import sqrt
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch


class SelfAttention(nn.Module):
    def __init__(self, input_dim, d_model):
        super().__init__()

        self.input_dim = input_dim
        self.d_model = d_model

        self.wq = nn.Linear(
            input_dim, d_model
        )  # selfAttention 은 이 가중치들을 학습하는 과정?
        self.wk = nn.Linear(input_dim, d_model)
        self.wv = nn.Linear(input_dim, d_model)
        self.dense = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask):
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        score = torch.matmul(
            q, k.transpose(-1, -2)
        )  # (B, S, D) * (B, D, S) = (B, S, S)
        score = score / sqrt(self.d_model)

        if mask is not None:
            score = score + (mask * -1e9)

        score = self.softmax(score)
        result = torch.matmul(score, v)
        result = self.dense(result)

        return result


class TransformerLayer(nn.Module):
    def __init__(self, input_dim, d_model, dff):
        super().__init__()

        self.input_dim = input_dim
        self.d_model = d_model
        self.dff = dff

        self.sa = SelfAttention(input_dim, d_model)
        self.ffn = nn.Sequential(  # 토큰 백터를 개별적으로 처리후 빈선형 변환을 통해 풍부한 표현력 제공
            nn.Linear(d_model, dff), nn.ReLU(), nn.Linear(dff, d_model)
        )

    def forward(self, x, mask):
        x = self.sa(x, mask)
        x = self.ffn(x)

        return x


def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates


def positional_encoding(position, d_model):
    angle_rads = get_angles(
        np.arange(position)[:, None], np.arange(d_model)[None, :], d_model
    )
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[None, ...]

    return torch.FloatTensor(pos_encoding)


# 결과 출력을 위한 모니터
class AccuracyMonitor:
    def __init__(
        self, models, dataloaders, labels, title="Model Accuracies", accuracy_fn=None
    ):
        """
        models: 모델 리스트
        dataloaders: 데이터로더 리스트
        labels: 모델 및 데이터로더에 대한 레이블 리스트
        title: 그래프 제목
        accuracy_fn: 사용자 정의 정확도 함수 (default_accuracy를 기본값으로 사용)
        """
        if not (len(models) == len(dataloaders) == len(labels)):
            raise ValueError(
                "models, dataloaders, labels는 모두 같은 길이를 가져야 합니다."
            )

        self.models = models
        self.dataloaders = dataloaders
        self.labels = labels
        self.title = title
        self.acc_lists = [[] for _ in labels]
        self.accuracy_fn = accuracy_fn if accuracy_fn else self.default_accuracy

    def default_accuracy(self, model, dataloader, **kwargs):
        """
        기본 정확도 계산 함수. **kwargs를 통해 추가 옵션 지원.
        """
        cnt = 0
        acc = 0
        model.eval()
        with torch.no_grad():
            for data in dataloader:
                inputs, labels = data
                preds = model(inputs)
                preds = torch.argmax(preds, dim=-1)
                cnt += labels.shape[0]
                acc += (labels == preds).sum().item()
        model.train()
        return acc / cnt if cnt > 0 else 0.0

    def update_accuracies(self, verbose=False, **kwargs):
        """
        models와 dataloaders를 평가하고 acc_lists에 기록.
        verbose: True인 경우 정확도 업데이트 로그 출력.
        """
        for i, (model, dataloader) in enumerate(zip(self.models, self.dataloaders)):
            acc = self.accuracy_fn(model, dataloader, **kwargs)
            self.acc_lists[i].append(acc)
            if verbose:
                print(f"Updated Accuracy ({self.labels[i]}): {acc:.4f}")

    def plot(self, epoch=None, save_path=None):
        """
        정확도 그래프를 출력하고 저장하는 기능.
        epoch: 그래프 제목에 표시할 에폭 정보
        save_path: 그래프를 저장할 경로 (None이면 저장하지 않음)
        """
        if not self.acc_lists or len(self.acc_lists[0]) == 0:
            print("No accuracies to plot")
            return

        x = np.arange(len(self.acc_lists[0]))  # 에폭 수만큼 x축 생성

        plt.figure(figsize=(10, 6))  # 그래프 크기 조정
        for acc_list, label in zip(self.acc_lists, self.labels):
            plt.plot(x, acc_list, label=label, marker="o")

        title = f"{self.title} at Epoch {epoch}" if epoch is not None else self.title
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.grid(True)
        plt.legend()
        plt.ylim(0.0, 1.0)

        if save_path:
            plt.savefig(save_path)
            print(f"Plot saved to {save_path}")
        else:
            plt.show()


device = torch.device("mps") if torch.backends.mps.is_built() else torch.device("cpu")

## [MY CODE] Loss function 및 classifier output 변경

In [6]:
from torch.optim import Adam
import numpy as np
import matplotlib.pyplot as plt


class LastWordPrediction(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, dff):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.dff = dff

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.parameter.Parameter(
            positional_encoding(max_word_len, d_model), requires_grad=False
        )
        self.layers = nn.ModuleList(
            [TransformerLayer(d_model, d_model, dff) for _ in range(n_layers)]
        )
        self.classification = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        mask = (x == tokenizer.pad_token_id).unsqueeze(1)
        seq_len = x.shape[1]

        x = self.embedding(x)
        x = x * sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len]

        for layer in self.layers:
            x = layer(x, mask)

        x = self.classification(x)
        x = x[:, -1, :]
        return x


model = LastWordPrediction(len(tokenizer), 8, 2, 16)
lr = 0.001
loss_fn = nn.CrossEntropyLoss()

optimizer = Adam(model.parameters(), lr=lr)


model = model.to(device)


def custom_accuracy(model, dataloader):
    cnt = 0
    acc = 0

    for data in dataloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model(inputs)
        # preds = torch.argmax(preds, dim=-1)
        preds = torch.argmax(preds, dim=-1)  # 가장 높은 확률을 가진 클래스를 선택

        cnt += labels.shape[0]
        acc += (labels == preds).sum().item()

    return acc / cnt


n_epochs = 50

trans_monitor = AccuracyMonitor(
    models=[model, model],
    dataloaders=[train_loader2, test_loader2],
    labels=["Train", "Test"],
    accuracy_fn=custom_accuracy,
)

for epoch in range(n_epochs):

    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}", unit="batch") as pbar:
        total_loss = 0.0
        model.train()
        for data in train_loader2:
            model.zero_grad()
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            labels = labels.long()

            preds = model(inputs)
            loss = loss_fn(preds, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())

        with torch.no_grad():
            model.eval()
            trans_monitor.update_accuracies()

trans_monitor.plot()

Epoch 1/50:   0%|          | 0/391 [00:00<?, ?batch/s]

KeyboardInterrupt: 