# Homework 2: Topic Classification
- Dataset: https://www.kaggle.com/datasets/amananandrai/ag-news-classification-dataset/

## Part 0: Download the dataset and upload it to Colab (or copy it to your local machine)
- You need a Kaggle account to download it.
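- Create a folder named `data` in the notebook's working directory and put the downloaded `train.csv` and `test.csv` inside it; the notebook reads them as `data/train.csv` and `data/test.csv`.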

## Part I: Data pre-processing

In [None]:
from typing import Tuple
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [None]:
# Hyperparameters
# Fill in a value for each name below; the cell is intentionally incomplete.
use_agnews_title =  # bool, forwarded to preprocess_agnews(use_agnews_title=...)
batch_size =  # batch size of the training DataLoader
test_batch_size =  # batch size of the validation and test DataLoaders
num_epoch =  # number of passes over train_loader in the training loop
embedding_dim =  # size of each token embedding (LSTMTextClassifier embedding_dim)
hidden_size =  # passed to LSTMTextClassifier as hidden_dim
dropout_rate =  # passed to LSTMTextClassifier as dropout
learning_rate =  # learning rate of the Adam optimizer

In [None]:
def preprocess_agnews(
    data_type: str = "train",
    use_agnews_title: bool = False,
    train_size: float = 0.8,
    random_state: int = 42,
) -> Tuple[list, list] | Tuple[list, list, list, list]:
    """Load one AG News split from ``data/{data_type}.csv`` and pre-process it.

    Args:
        data_type: ``"train"`` produces a train/validation split; any other
            value (the notebook passes ``"test"``) is treated as the test set.
        use_agnews_title: flag supplied from the hyperparameter cell;
            presumably controls whether the article title is used as (part
            of) the text -- TODO confirm once TODO1 is implemented.
        train_size: not used by the template yet; presumably the fraction of
            the training CSV kept for training in TODO1-1 (the rest becomes
            validation data), e.g. via ``train_test_split`` imported above.
        random_state: not used by the template yet; presumably the seed for
            the TODO1-1 split so that it is reproducible.

    Returns:
        ``(train_text, train_label, val_text, val_label)`` when
        ``data_type == "train"``, otherwise ``(test_text, test_label)``.

    NOTE(review): until the TODOs are filled in, the returned names are never
    assigned, so calling this function raises ``NameError``. The ``X | Y``
    return annotation also requires Python 3.10+.
    """
    # Read data
    df = pd.read_csv(f"data/{data_type}.csv")

    if data_type == "train":
        # TODO1-1: split the validation data from the training data
        # TODO1-2: do some data pre-processing for the train/valid set
        # Write your code here

        return train_text, train_label, val_text, val_label

    else: # this part should be for the test set
        # TODO1-3: do some data pre-processing for the test set
        # Write your code here

        return test_text, test_label

In [None]:
# Load the train/validation split and the test split with the same title setting.
train_text, train_label, val_text, val_label = preprocess_agnews(
    use_agnews_title=use_agnews_title,
    data_type="train",
)
test_text, test_label = preprocess_agnews(
    use_agnews_title=use_agnews_title,
    data_type="test",
)
# Number of distinct classes observed in the training labels.
num_labels = len(frozenset(train_label))

In [None]:
# Seed the vocabulary with the special tokens, in id order: '<pad>' -> 0
# (used as the embedding padding_idx) and '<unk>' -> 1.
vocab = {token: index for index, token in enumerate(("<pad>", "<unk>"))}
# TODO2: Build the vocabulary
# Write your code here


In [None]:
# TODO3-1: Write the torch Dataset
# `torch` itself is otherwise first imported in Part II, which runs after this
# cell (the first cell only does `from torch.utils.data import DataLoader`,
# which does not bind the name `torch`). Import it here so the base class
# below resolves when the notebook is executed top to bottom.
import torch


class AGNewsDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing AG News texts with their labels.

    Args:
        texts: sequence of text examples, aligned index-by-index with ``labels``.
        labels: sequence of labels; its length defines ``len(dataset)``.
        vocab: token -> id mapping (this notebook seeds it with
            ``'<pad>'``: 0 and ``'<unk>'``: 1).
        tokenizer: tokenizer chosen in TODO3-2; stored for use in
            ``__getitem__``.
        lower: stored flag; presumably whether ``__getitem__`` lower-cases
            the text before tokenizing -- confirm when implementing TODO3-1.
    """

    def __init__(self, texts, labels, vocab, tokenizer, lower=True):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.tokenizer = tokenizer  # TODO3-2. Write in the next block.
        self.lower = lower

    def __getitem__(self, idx):
        # Write your code here

        return # Two things (both can be tensor) should be returned

    def __len__(self):
        return len(self.labels)

In [None]:
# TODO3-2: the `tokenizer = ...` line below is left incomplete on purpose and
# will not run until a tokenizer is assigned. Whatever is chosen is stored on
# each dataset as `self.tokenizer`; presumably it should be a callable that
# maps one text string to a list of tokens -- confirm against __getitem__.
tokenizer = # TODO3-2: Decide your tokenizer. You can use SpaCy, NLTK, and so on ...

# One dataset per split, all sharing the same vocabulary and tokenizer.
train_dataset = AGNewsDataset(train_text, train_label, vocab, tokenizer, lower=True)
val_dataset = AGNewsDataset(val_text, val_label, vocab, tokenizer, lower=True)
test_dataset = AGNewsDataset(test_text, test_label, vocab, tokenizer, lower=True)

In [None]:
# TODO4: Write the collate function

def collate_batch(batch):
    """Merge a list of dataset samples into one mini-batch (TODO4).

    Used as ``collate_fn`` for the train, validation and test DataLoaders.
    ``batch`` is the list of samples returned by ``AGNewsDataset.__getitem__``
    for one mini-batch. Must return ``(text, label)``. Presumably ``text`` is
    a padded tensor of token ids shaped ``[batch_size, seq_len]`` (the shape
    the model's ``forward`` expects), e.g. built with
    ``pad_sequence(..., batch_first=True, padding_value=vocab['<pad>'])``,
    and ``label`` is a tensor of integer class ids -- confirm against your
    implementation.
    """
    # Write your code here

    return text, label

In [None]:
# Only the training split is shuffled; validation and test batches keep the
# dataset order and use the evaluation batch size. All three share the same
# collate function.
def _make_loader(dataset, loader_batch_size, shuffle):
    return DataLoader(
        dataset,
        batch_size=loader_batch_size,
        shuffle=shuffle,
        collate_fn=collate_batch,
    )


train_loader = _make_loader(train_dataset, batch_size, shuffle=True)
val_loader = _make_loader(val_dataset, test_batch_size, shuffle=False)
test_loader = _make_loader(test_dataset, test_batch_size, shuffle=False)

## Part II: Build your model
- You may only use an LSTM-based model.

In [None]:
import torch
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

In [None]:
# TODO5: Write the class for your model

class LSTMTextClassifier(torch.nn.Module):
    """LSTM text classifier skeleton (to be completed in TODO5).

    Args:
        vocab_size: number of rows in the embedding table (the notebook
            passes ``len(vocab)``).
        embedding_dim: size of each token embedding.
        hidden_dim: intended hidden size of the LSTM.
        output_dim: number of classes; the final layer should emit this many
            logits per example.
        dropout: intended dropout probability.
        padding_idx: id of the padding token; ``torch.nn.Embedding`` does not
            update the embedding row at this index during training.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, dropout, padding_idx):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        # Write your code here
        # You can adjust anything you want.

    def forward(self, x):
        """Map a batch of token ids to class logits.

        ``x`` holds token ids shaped ``[batch_size, seq_len]``. The return
        value is the raw (pre-softmax) logits, presumably shaped
        ``[batch_size, output_dim]``; the evaluation code takes the argmax
        over the last dimension.
        """
        # x: [batch_size, seq_len]
        # Write your code here

        return logits # model outputs before softmax

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Size the embedding table from the vocabulary and the output layer from the
# number of classes found in the training labels.
model = LSTMTextClassifier(
    padding_idx=vocab['<pad>'],
    vocab_size=len(vocab),
    output_dim=num_labels,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_size,
    dropout=dropout_rate,
)
model = model.to(device)

In [None]:
# Use CrossEntropyLoss for classification: it takes the model's pre-softmax
# logits directly. Adam updates every parameter of the model.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

## Part III: Training

In [None]:
def evaluate(dataloader, model, loss_fn, device=None):
    """Run ``model`` over every batch of ``dataloader`` and score the predictions.

    Arguments:
        - dataloader: iterable of ``(x, y)`` mini-batches (e.g. a PyTorch
          DataLoader); ``y`` holds integer class ids.
        - model: the model to evaluate; ``model(x)`` must return pre-softmax
          logits of shape ``(batch_size, num_classes)``.
        - loss_fn: loss applied once to the logits and labels of the whole
          split (e.g. ``torch.nn.CrossEntropyLoss()``).
        - device: device for the forward pass. Defaults to the device of the
          model's first parameter (CPU if the model has no parameters).
    Returns:
        - loss: ``loss_fn`` evaluated over the whole split (a 0-d tensor for
          the default CrossEntropyLoss)
        - acc: accuracy on the split
        - f1: macro-averaged F1 over all classes
    Raises:
        - ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Remember the caller's mode so that evaluating in the middle of training
    # does not leave dropout switched off for the remaining epochs.
    was_training = model.training
    # Evaluation mode disables dropout.
    model.eval()

    all_logits = []
    all_labels = []
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for x, y in dataloader:
                logits = model(x.to(device))
                # Collect on the CPU so results can be concatenated and handed
                # to sklearn; labels must be integer class ids for the loss.
                all_logits.append(logits.cpu())
                all_labels.append(y.cpu().long())
    finally:
        model.train(was_training)

    if not all_labels:
        raise ValueError("evaluate() received a dataloader with no batches")

    # Concatenate once at the end instead of re-allocating every batch.
    logits = torch.cat(all_logits)
    y_true = torch.cat(all_labels)
    y_pred = logits.argmax(dim=-1)

    # CrossEntropyLoss expects raw logits (N, C) and integer targets (N,);
    # computing it on argmax predictions yields a meaningless value.
    with torch.no_grad():
        loss = loss_fn(logits, y_true)

    y_true_np = y_true.numpy()
    y_pred_np = y_pred.numpy()
    acc = accuracy_score(y_true_np, y_pred_np)
    # AG News has more than two classes, so sklearn's default
    # average="binary" would raise; macro-average over classes instead.
    f1 = f1_score(y_true_np, y_pred_np, average="macro")

    return loss, acc, f1

In [None]:
# TODO6: Write the training script
# Each inner iteration is expected to bind `loss` for the current batch
# (typically: move x/y to `device`, compute logits with `model`, apply
# `loss_fn`, zero the gradients, backpropagate and step `optimizer`). The
# progress bar below reads `loss`, so it must be assigned before
# `set_postfix` is reached.

for epoch in range(num_epoch):
    # NOTE(review): if evaluate() is called at the end of an epoch, make sure
    # the model is in training mode (model.train()) when the next epoch
    # starts, so dropout is active during training.
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epoch}")
    for x, y in progress_bar:
        # Write your code here

        progress_bar.set_postfix(loss=loss.item())

    # Write your code here for evaluating your model on the validation data

## Part IV: Evaluation

In [None]:
# Score the trained model on the test split: loss, accuracy and F1.
test_loss, test_acc, test_f1 = evaluate(
    dataloader=test_loader,
    model=model,
    loss_fn=loss_fn,
)
print(f"Test Loss: {test_loss}, Test Acc: {test_acc}, Test F1: {test_f1}")