# ML HW4 Sample Code
TODO:
 - Design your LSTM model
 - Use unlabeled data (train_nolabel.csv) for Word2Vec training
    - Combine labeled + unlabeled data to train better embeddings
 - Train with labeled data (train_label.csv)
    - Optional: Data augmentation
    - Optional: Custom loss function

## Download data

In [None]:
# !pip install -U gdown -q
# !gdown --folder https://drive.google.com/drive/folders/1786AXJRAtqFvWMBeh-bLm4MtU21IQpBg

## Import packages

In [None]:
import torch
import os
import csv
import random
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

# !pip install gensim
from gensim.models import Word2Vec
from sklearn.model_selection import train_test_split

## Set the Configurations

In [None]:
# Training Config
DEVICE_NUM = 2
BATCH_SIZE = 128
EPOCH_NUM = 20
MAX_POSITIONS_LEN = 500
SEED = 2025
MODEL_DIR = 'model.pth'
lr = 0.001

# Set Seed: make cuDNN deterministic and seed every RNG used in this file
# (torch CPU/CUDA, Python's random, NumPy).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
random.seed(SEED)
np.random.seed(SEED)

# torch.cuda.set_device(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# RNN Config
w2v_config = {'path': 'w2v_model', 'dim': 512}
lstm_config = {'hidden_dim': 512, 'num_layers': 2, 'bidirectional': True, 'fix_embedding': True}
header_config = {'dropout': 0.3, 'hidden_dim': 1024}

# The header consumes the LSTM output, whose feature size is hidden_dim for a
# unidirectional LSTM and 2 * hidden_dim for a bidirectional one. Check the
# exact value (and raise instead of assert so the check survives `python -O`).
_expected_header_dim = lstm_config['hidden_dim'] * (2 if lstm_config['bidirectional'] else 1)
if header_config['hidden_dim'] != _expected_header_dim:
    raise ValueError(
        "header_config['hidden_dim'] must be {} to match the LSTM output, got {}".format(
            _expected_header_dim, header_config['hidden_dim']))

# Probability of replacing each training token with <PAD> (token dropout).
train_dataset_dropout = 0.2

## Utils for Datasets and Dataloaders

In [None]:
# Translation table that deletes every character treated as "special".
_SPECIAL_CHARS_TABLE = str.maketrans('', '', '!?()[]{}"\'_/\\@#$%^&*+=<>~`')


def parsing_text(text):
    """Normalize one raw text cell into a cleaned, lowercase string.

    Args:
        text: The raw cell value. Usually a ``str``; ``None`` and NaN-like
            values (as detected by ``pd.isna``) are treated as missing.

    Returns:
        ``""`` for missing values; otherwise ``str(text)`` lowercased with
        the special characters removed.
    """
    if text is None or pd.isna(text):
        return ""
    # Convert to str *before* lowercasing so non-string cells (e.g. numbers)
    # are handled instead of raising AttributeError.
    return str(text).lower().translate(_SPECIAL_CHARS_TABLE)

def load_train_label(path='train_label.csv'):
    """Read the labelled training CSV.

    Reads the ``id``, ``text`` and ``label`` columns of the CSV at ``path``.

    Returns:
        A tuple ``(ids, tokenized_texts, labels)`` where each text has been
        cleaned with ``parsing_text`` and split on whitespace.
    """
    frame = pd.read_csv(path)
    ids = frame['id'].tolist()
    tokenized = []
    for raw in frame['text'].tolist():
        tokenized.append(parsing_text(raw).split())
    labels = frame['label'].tolist()
    return ids, tokenized, labels

def load_train_nolabel(path='train_nolabel.csv'):
    """Read the unlabelled training CSV.

    Returns:
        A list with one token list per row of the ``text`` column, each
        cleaned with ``parsing_text`` and split on whitespace.
    """
    raw_texts = pd.read_csv(path)['text'].tolist()
    tokenized = []
    for raw in raw_texts:
        tokenized.append(parsing_text(raw).split())
    return tokenized

def load_test(path='test.csv'):
    """Read the test CSV.

    Returns:
        A tuple ``(ids, tokenized_texts)`` built from the ``id`` and ``text``
        columns; each text is cleaned with ``parsing_text`` and split on
        whitespace.
    """
    frame = pd.read_csv(path)
    ids = frame['id'].tolist()
    tokenized = []
    for raw in frame['text'].tolist():
        tokenized.append(parsing_text(raw).split())
    return ids, tokenized

## Datasets and Dataloaders

In [None]:
class Preprocessor:
    """Vocabulary and embedding matrix backed by a Word2Vec model.

    After construction:
        * ``word2idx`` maps each word to its row in ``embedding_matrix``;
        * ``idx2word`` is the inverse list;
        * ``embedding_matrix`` is a float tensor of shape
          ``(vocab_size + 2, embedding_dim)`` whose last two rows are the
          randomly initialised ``<PAD>`` and ``<UNK>`` vectors.
    """

    def __init__(self, sentences, w2v_config):
        """Build the vocabulary.

        Args:
            sentences: Iterable of token lists used to train Word2Vec when no
                saved model exists.
            w2v_config: Dict with keys ``path`` and ``dim``, forwarded to
                ``build_word2vec``.
        """
        self.sentences = sentences
        self.idx2word = []
        self.word2idx = {}
        self.embedding_matrix = []
        self.build_word2vec(sentences, **w2v_config)

    def build_word2vec(self, x, path, dim):
        """Load or train a Word2Vec model and fill the vocabulary tables.

        If a file exists at ``path`` the saved model is loaded and ``x`` /
        ``dim`` are ignored; otherwise a new model is trained on ``x`` and
        saved to ``path``.
        """
        if os.path.isfile(path):
            print("loading word2vec model ...")
            w2v_model = Word2Vec.load(path)
        else:
            print("training word2vec model ...")
            w2v_model = Word2Vec(x, vector_size=dim, window=5, min_count=2, workers=12, epochs=10, sg=1)
            print("saving word2vec model ...")
            w2v_model.save(path)

        # Use the model's own size: a loaded model may differ from `dim`.
        self.embedding_dim = w2v_model.vector_size

        vectors = []
        for word in w2v_model.wv.key_to_index:
            # e.g. self.word2idx['he'] = 1, self.idx2word[1] = 'he',
            #      vectors[1] = embedding of 'he'
            self.word2idx[word] = len(self.word2idx)
            self.idx2word.append(word)
            vectors.append(w2v_model.wv[word])

        if vectors:
            # Stack into one ndarray first: building a tensor directly from a
            # list of ndarrays is extremely slow.
            self.embedding_matrix = torch.tensor(np.stack(vectors))
        else:
            # Keep a 2-D shape so the special tokens can still be appended.
            self.embedding_matrix = torch.empty(0, self.embedding_dim)

        self.add_embedding('<PAD>')
        self.add_embedding('<UNK>')
        print("total words: {}".format(len(self.embedding_matrix)))

    def add_embedding(self, word):
        """Append ``word`` to the vocabulary with a uniform-random vector."""
        vector = torch.empty(1, self.embedding_dim)
        torch.nn.init.uniform_(vector)
        self.word2idx[word] = len(self.word2idx)
        self.idx2word.append(word)
        self.embedding_matrix = torch.cat([self.embedding_matrix, vector], 0)

    def sentence2idx(self, sentence):
        """Convert a token list to a LongTensor of vocabulary indices.

        Tokens missing from the vocabulary map to the ``<UNK>`` index.
        """
        unk_idx = self.word2idx["<UNK>"]
        return torch.LongTensor([self.word2idx.get(word, unk_idx) for word in sentence])

class TwitterDataset(torch.utils.data.Dataset):
    """Dataset of tokenized sentences with optional labels and token dropout.

    Args:
        id_list: Sample ids, indexable in parallel with ``sentences``.
        sentences: Token lists.
        labels: Per-sample labels, or ``None`` for an unlabelled (test) set.
        preprocessor: Object providing ``sentence2idx(sentence)`` and a
            ``word2idx`` mapping that contains ``'<PAD>'``.
        dropout: Probability of replacing each token with ``<PAD>``. Only
            applied when labels are present.
    """

    def __init__(self, id_list, sentences, labels, preprocessor, dropout=0.0):
        self.id_list = id_list
        self.sentences = sentences
        self.labels = labels
        self.preprocessor = preprocessor
        self.dropout = dropout

    def __getitem__(self, idx):
        """Return ``(id, token_indices)`` or ``(id, token_indices, label)``."""
        sentence_idx = self.preprocessor.sentence2idx(self.sentences[idx])
        if self.labels is None:
            return self.id_list[idx], sentence_idx

        if self.dropout > 0.0:
            # Token dropout: each position is kept with probability
            # (1 - dropout); dropped positions become <PAD>.
            keep_mask = torch.rand(len(sentence_idx)) > self.dropout
            sentence_idx = sentence_idx.masked_fill(
                ~keep_mask, self.preprocessor.word2idx['<PAD>'])

        return self.id_list[idx], sentence_idx, self.labels[idx]

    def __len__(self):
        return len(self.sentences)

    def collate_fn(self, data):
        """Batch samples into ``(ids, lengths, padded_texts[, labels])``.

        Sequences are right-padded to the longest one in the batch using the
        ``<PAD>`` index. Padding with the default value 0 would insert a real
        vocabulary word instead of the padding token.
        """
        id_list = torch.LongTensor([d[0] for d in data])
        lengths = torch.LongTensor([len(d[1]) for d in data])
        texts = pad_sequence(
            [d[1] for d in data],
            batch_first=True,
            padding_value=self.preprocessor.word2idx['<PAD>'],
        ).contiguous()

        if self.labels is None:
            return id_list, lengths, texts

        labels = torch.FloatTensor([d[2] for d in data])
        return id_list, lengths, texts, labels

## RNN Backbone

In [None]:
class LSTM_Backbone(torch.nn.Module):
    """Pretrained embedding lookup followed by a batch-first LSTM.

    Args:
        embedding: 2-D float tensor ``(vocab_size, embedding_dim)`` used to
            initialise the embedding table.
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
        fix_embedding: If True the embedding table is frozen; otherwise it is
            fine-tuned together with the LSTM.
    """

    def __init__(self, embedding, hidden_dim, num_layers, bidirectional, fix_embedding=True):
        super().__init__()
        # Clone so that fine-tuning (fix_embedding=False) never writes through
        # to the caller's tensor, which may be reused to build other models.
        self.embedding = torch.nn.Embedding.from_pretrained(
            embedding.detach().clone(), freeze=fix_embedding)

        self.lstm = torch.nn.LSTM(
            embedding.size(1), hidden_dim, num_layers=num_layers,
            bidirectional=bidirectional, batch_first=True)

    def forward(self, inputs):
        """Embed token indices and return the LSTM output for every position.

        ``inputs`` is a batch-first tensor of token indices; the result is the
        first element returned by ``torch.nn.LSTM`` (the per-step outputs).
        """
        embedded = self.embedding(inputs)
        outputs, _ = self.lstm(embedded)
        return outputs

class Header(torch.nn.Module):
    """Convolutional classification head producing one probability per sample.

    Args:
        dropout: Dropout probability used after each conv block.
        hidden_dim: Feature size of the incoming sequence (the number of
            input channels of the first Conv1d).
    """

    def __init__(self, dropout, hidden_dim):
        super().__init__()
        self.classifier = torch.nn.Sequential(
            torch.nn.Conv1d(hidden_dim, 128, kernel_size=3, padding=1),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Conv1d(128, 128, kernel_size=3, padding=1),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.AdaptiveMaxPool1d(1),
            torch.nn.Flatten(),
            torch.nn.Linear(128, 1),
            torch.nn.Sigmoid()
        )

    @torch.no_grad()
    def _get_length_masks(self, lengths, max_len=None):
        """Return a boolean mask of shape ``(batch, max_len, 1)``.

        Entry ``[b, t, 0]`` is True when ``t < lengths[b]``.

        Args:
            lengths: 1-D integer tensor of sequence lengths.
            max_len: Number of time steps to cover; defaults to
                ``lengths.max()``.
        """
        if max_len is None:
            max_len = int(lengths.max().item())
        positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)
        return (positions < lengths.unsqueeze(-1)).unsqueeze(-1)

    def forward(self, inputs, lengths):
        """Classify a batch of sequences.

        Args:
            inputs: Tensor of shape (N, L, D*H).
            lengths: 1-D tensor with the true length of each sequence.

        Returns:
            Tensor of shape (N,) with probabilities in [0, 1].
        """
        # Size the mask from the input itself so it always matches, whatever
        # the padded sequence length is.
        pad_mask = self._get_length_masks(lengths, inputs.size(1))
        inputs = inputs * pad_mask

        # Permute to (N, C, L) for Conv1d
        inputs = inputs.permute(0, 2, 1)

        # Squeeze only the trailing feature dim: a bare squeeze() would also
        # drop the batch dim when N == 1 and return a 0-d tensor.
        out = self.classifier(inputs).squeeze(-1)
        return out

## Training & Validation

In [None]:
def train(train_loader, backbone, header, optimizer, criterion, device, epoch):
    """Run one training epoch.

    Args:
        train_loader: Iterable of ``(ids, lengths, texts, labels)`` batches
            that also supports ``len()``.
        backbone: Optional module applied to ``texts`` before the header;
            skipped when ``None``.
        header: Callable ``header(inputs, lengths)`` returning per-sample
            probabilities.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss ``criterion(predictions, labels)``.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index, used only for progress output.

    Returns:
        ``(mean_loss, mean_accuracy_percent)`` averaged over batches.
    """
    total_loss = []
    total_acc = []
    num_batches = len(train_loader)

    for step, (_, lengths, texts, labels) in enumerate(train_loader, start=1):
        lengths, inputs, labels = lengths.to(device), texts.to(device), labels.to(device)

        optimizer.zero_grad()
        if backbone is not None:
            inputs = backbone(inputs)
        soft_predicted = header(inputs, lengths)
        loss = criterion(soft_predicted, labels)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            hard_predicted = (soft_predicted >= 0.5).int()
            # Tensor.sum() reduces in one call; the builtin sum() would
            # iterate over the tensor element by element.
            correct = (hard_predicted == labels).sum().item()
            total_loss.append(loss.item())
            total_acc.append(correct * 100 / len(labels))
            print('[ Epoch {}: {}/{} ] loss:{:.3f} acc:{:.3f} '.format(
                epoch+1, step, num_batches, np.mean(total_loss), np.mean(total_acc)), end='\r')

    return np.mean(total_loss), np.mean(total_acc)


def valid(valid_loader, backbone, header, criterion, device, epoch):
    """Evaluate on the validation set without tracking gradients.

    The modules are switched to eval mode for the duration of the call and
    put back into train mode before returning.

    Args:
        valid_loader: Iterable of ``(ids, lengths, texts, labels)`` batches.
        backbone: Optional module applied to ``texts`` before the header;
            skipped when ``None``.
        header: Module called as ``header(inputs, lengths)``.
        criterion: Loss ``criterion(predictions, labels)``.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index, used only for progress output.

    Returns:
        ``(mean_loss, mean_accuracy_percent)`` averaged over batches.
    """
    # backbone is optional everywhere else in this function, so guard the
    # mode switches as well instead of crashing on None.
    if backbone is not None:
        backbone.eval()
    header.eval()

    total_loss = []
    total_acc = []
    with torch.no_grad():
        for _, lengths, texts, labels in valid_loader:
            lengths, inputs, labels = lengths.to(device), texts.to(device), labels.to(device)

            if backbone is not None:
                inputs = backbone(inputs)
            soft_predicted = header(inputs, lengths)
            loss = criterion(soft_predicted, labels)
            total_loss.append(loss.item())

            hard_predicted = (soft_predicted >= 0.5).int()
            correct = (hard_predicted == labels).sum().item()
            total_acc.append(correct * 100 / len(labels))

            print('[Validation in epoch {:}] loss:{:.3f} acc:{:.3f}'.format(
                epoch+1, np.mean(total_loss), np.mean(total_acc)), end='\r')

    if backbone is not None:
        backbone.train()
    header.train()
    return np.mean(total_loss), np.mean(total_acc)


def run_training(train_loader, valid_loader, backbone, header, epoch_num, lr, device, model_dir):
    """Train with early stopping and reload the best checkpoint.

    Each epoch runs ``train`` then ``valid``. Whenever validation accuracy
    beats the best value so far, the state dicts are saved to ``model_dir``.
    Training stops early after ``patience`` (5) consecutive epochs without
    improvement. If a checkpoint was written during this call, it is loaded
    back into the modules before returning.

    Args:
        train_loader: Training batches, passed to ``train``.
        valid_loader: Validation batches, passed to ``valid``.
        backbone: Optional backbone module (may be ``None``).
        header: Classification head module.
        epoch_num: Maximum number of epochs.
        lr: Initial learning rate for AdamW.
        device: Device to train on.
        model_dir: File path of the checkpoint.

    Returns:
        The best validation accuracy (in percent) seen during training.
    """
    patience = 5
    best_acc = 0.0
    epochs_without_improvement = 0
    checkpoint_saved = False

    # Move modules first so the optimizer is built over the final parameters.
    header = header.to(device)
    header.train()
    if backbone is None:
        trainable_paras = list(header.parameters())
    else:
        backbone = backbone.to(device)
        backbone.train()
        trainable_paras = list(backbone.parameters()) + list(header.parameters())

    optimizer = torch.optim.AdamW(trainable_paras, lr=lr, weight_decay=1e-3)

    # Halve the learning rate when validation accuracy stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

    criterion = torch.nn.BCELoss()

    for epoch in range(epoch_num):
        train_loss, train_acc = train(train_loader, backbone, header, optimizer, criterion, device, epoch)
        loss, acc = valid(valid_loader, backbone, header, criterion, device, epoch)
        print('[Epoch {:}] Train loss:{:.3f} acc:{:.3f} | Val loss:{:.3f} acc:{:.3f} | LR:{:.6f}'.format(
            epoch+1, train_loss, train_acc, loss, acc, optimizer.param_groups[0]['lr']))

        # Checkpointing and early stopping share one comparison so they can
        # never disagree about what "best" means.
        improved = acc > best_acc
        if improved:
            best_acc = acc
            epochs_without_improvement = 0
            # Save state_dict instead of full model objects
            state = {'header_state_dict': header.state_dict()}
            if backbone is not None:
                state['backbone_state_dict'] = backbone.state_dict()
            torch.save(state, model_dir)
            checkpoint_saved = True
            print(f'New best model saved with accuracy: {acc:.3f}')
        else:
            epochs_without_improvement += 1

        scheduler.step(acc)

        if not improved and epochs_without_improvement >= patience:
            print(f'Early stopping triggered after {epoch+1} epochs due to no improvement in accuracy for {patience} epochs.')
            break

    # Load best model for testing. Only do so if this run actually wrote a
    # checkpoint; otherwise the file may be missing or left over from an
    # earlier run.
    if checkpoint_saved:
        checkpoint = torch.load(model_dir, weights_only=True, map_location=device)
        if backbone is not None:
            backbone.load_state_dict(checkpoint['backbone_state_dict'])
        header.load_state_dict(checkpoint['header_state_dict'])
        print(f'Loaded best model with accuracy: {best_acc:.3f}')
    else:
        print('No checkpoint was saved during training; keeping the current weights.')

    return best_acc

## Testing

In [None]:
def run_testing(test_loader, backbone, header, device, output_path):
    """Predict hard labels for the test set and write them to a CSV file.

    The modules are put into eval mode while predicting (so Dropout and
    BatchNorm behave deterministically) and their previous train/eval state
    is restored afterwards.

    Args:
        test_loader: Iterable of ``(ids, lengths, texts)`` batches, where
            ``ids`` is a tensor.
        backbone: Optional module applied to ``texts`` before the header;
            skipped when ``None``.
        header: Module called as ``header(inputs, lengths)`` returning
            probabilities; values >= 0.5 are written as label 1.
        device: Device the batch tensors are moved to.
        output_path: Destination CSV path; written with an ``id,label`` header.
    """
    modules = [m for m in (backbone, header) if m is not None]
    previous_modes = [m.training for m in modules]
    for module in modules:
        module.eval()

    try:
        # newline='' is what the csv module expects; without it extra blank
        # rows appear on platforms that translate '\n'.
        with open(output_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['id', 'label'])

            with torch.no_grad():
                for idx_list, lengths, texts in test_loader:
                    lengths, inputs = lengths.to(device), texts.to(device)
                    if backbone is not None:
                        inputs = backbone(inputs)
                    soft_predicted = header(inputs, lengths)
                    # reshape(-1) keeps a single-sample (0-d) output iterable.
                    hard_predicted = (soft_predicted >= 0.5).int().reshape(-1)
                    for sample_id, pred in zip(idx_list, hard_predicted):
                        writer.writerow([str(sample_id.item()), str(pred.item())])
    finally:
        for module, was_training in zip(modules, previous_modes):
            module.train(was_training)

## Main

In [None]:
# Load the labelled and unlabelled training corpora
train_idx, train_label_text, label = load_train_label('dataset/train_label.csv')
train_nolabel_text = load_train_nolabel('dataset/train_nolabel.csv')

# For sanity check
print(train_label_text[:10])
print(train_nolabel_text[:10])

# Both labelled and unlabelled sentences are handed to the Preprocessor as the
# Word2Vec corpus.
w2v_sentences = train_label_text + train_nolabel_text
preprocessor = Preprocessor(w2v_sentences, w2v_config)

# K-Fold Cross Validation: shuffle the sample indices once, then cut them into
# n_splits contiguous folds (the last fold absorbs the remainder).
n_splits = 5
all_indices = list(range(len(train_label_text)))
random.shuffle(all_indices)

fold_size = len(all_indices) // n_splits
folds = []
for i in range(n_splits):
    start = i * fold_size
    end = (i + 1) * fold_size if i < n_splits - 1 else len(all_indices)
    folds.append(all_indices[start:end])

# The test loader is shared by all folds; shuffle=False keeps predictions
# aligned with test_idx.
test_idx, test_text = load_test('dataset/test.csv')
test_dataset = TwitterDataset(test_idx, test_text, None, preprocessor)
test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                            batch_size = BATCH_SIZE,
                                            shuffle = False,
                                            collate_fn = test_dataset.collate_fn,
                                            num_workers = 8)

# One list of test-set probabilities per fold
test_pred_probs = []

for fold in range(n_splits):
    print(f'\nFold {fold+1}/{n_splits}')

    # Fold `fold` is held out for validation; every other fold is training data.
    val_indices = folds[fold]
    train_indices = []
    for f in range(n_splits):
        if f != fold:
            train_indices.extend(folds[f])

    fold_train_idx = [train_idx[i] for i in train_indices]
    fold_train_text = [train_label_text[i] for i in train_indices]
    fold_train_label = [label[i] for i in train_indices]

    fold_val_idx = [train_idx[i] for i in val_indices]
    fold_val_text = [train_label_text[i] for i in val_indices]
    fold_val_label = [label[i] for i in val_indices]

    # Token dropout is applied to the training split only.
    train_dataset = TwitterDataset(fold_train_idx, fold_train_text, fold_train_label, preprocessor, train_dataset_dropout)
    valid_dataset = TwitterDataset(fold_val_idx, fold_val_text, fold_val_label, preprocessor)

    train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                                batch_size = BATCH_SIZE,
                                                shuffle = True,
                                                collate_fn = train_dataset.collate_fn,
                                                num_workers = 8)
    valid_loader = torch.utils.data.DataLoader(dataset = valid_dataset,
                                                batch_size = BATCH_SIZE,
                                                shuffle = False,
                                                collate_fn = valid_dataset.collate_fn,
                                                num_workers = 8)

    # Fresh model for every fold
    backbone = LSTM_Backbone(preprocessor.embedding_matrix, **lstm_config)
    header = Header(**header_config)

    model_path = f'model_fold_{fold+1}.pth'
    run_training(train_loader, valid_loader, backbone, header, EPOCH_NUM, lr, device, model_path)

    # Inference must run in eval mode
    backbone.eval()
    header.eval()

    fold_probs = []
    with torch.no_grad():
        for _, lengths, texts in test_loader:
            lengths, inputs = lengths.to(device), texts.to(device)
            inputs = backbone(inputs)
            soft_predicted = header(inputs, lengths)
            # reshape(-1) keeps a single-sample (0-d) output extendable.
            fold_probs.extend(soft_predicted.reshape(-1).cpu().tolist())
    test_pred_probs.append(fold_probs)

# Ensemble: average the per-fold probabilities, then threshold at 0.5.
avg_probs = np.mean(test_pred_probs, axis=0)
hard_preds = (avg_probs >= 0.5).astype(int)

# newline='' is what the csv module expects when writing.
with open('pred.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['id', 'label'])
    for i, p in zip(test_idx, hard_preds):
        writer.writerow([str(i), str(p)])