In [None]:
import os
import torch
import torch.nn as nn
from transformers import XLMRobertaTokenizer, XLMRobertaModel
from torch.utils.data import DataLoader
from dataset import IdiomDataset, preprocess_data
import pandas as pd
from tqdm import tqdm
import numpy as np

In [None]:
# Config: every knob a reader might want to tune lives here.
MAX_LEN = 128                     # sequence length passed to IdiomDataset
BATCH_SIZE = 8                    # batch size for the train / validation loaders
EPOCHS = 3                        # number of passes over the training loader
MODEL_NAME = "xlm-roberta-base"   # Hugging Face checkpoint for tokenizer and encoder
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42                         # fixed seed so shuffling / dropout are repeatable

# Seed the RNGs up front: training shuffles batches and applies dropout, so
# without this two "Restart & Run All" executions give different results.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Tokenizer shared by preprocessing and by the dataset wrappers.
tokenizer = XLMRobertaTokenizer.from_pretrained(MODEL_NAME)

# Raw frames: training split first, then the labelled evaluation split.
train_df, val_df = (
    pd.read_csv(csv_path)
    for csv_path in ("dataset/train.csv", "dataset/eval.csv")
)

# preprocess_data yields an (inputs, labels) pair for each frame.
(train_inputs, train_labels), (val_inputs, val_labels) = (
    preprocess_data(frame, tokenizer) for frame in (train_df, val_df)
)

# Wrap each split in the project's dataset class.
train_dataset, val_dataset = (
    IdiomDataset(split_inputs, split_labels, tokenizer, MAX_LEN)
    for split_inputs, split_labels in (
        (train_inputs, train_labels),
        (val_inputs, val_labels),
    )
)

# Only the training loader shuffles; validation keeps file order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE)

In [None]:
# Model
class IdiomModel(nn.Module):
    """Per-token classifier: XLM-RoBERTa encoder, dropout, then a linear head.

    Args:
        num_labels: Number of output classes per token. Defaults to 3
            (the O / B / I tag set used by this notebook).
        dropout: Dropout probability applied to the encoder's hidden states
            before classification. Defaults to 0.1.
        model_name: Checkpoint name passed to ``from_pretrained``. When
            ``None`` the notebook-level ``MODEL_NAME`` is used.
    """

    def __init__(self, num_labels=3, dropout=0.1, model_name=None):
        super().__init__()
        self.num_labels = num_labels
        # Resolve the checkpoint at call time so edits to MODEL_NAME are picked up.
        self.roberta = XLMRobertaModel.from_pretrained(model_name or MODEL_NAME)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.roberta.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return per-token logits; the last dimension has ``num_labels`` entries."""
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        return self.classifier(sequence_output)

model = IdiomModel().to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# Do NOT ignore index 0: inference treats labels 1 and 2 as idiom tags, so 0 is
# the "O" class. Ignoring it meant the model was never penalised for tagging
# ordinary tokens as idiom, and a batch with no idiom tokens had every target
# ignored, which yields a NaN mean loss. The default ignore_index (-100) is kept
# so a dataset may still mark positions to skip with -100.
# NOTE(review): if IdiomDataset pads labels with 0, padded positions are now
# scored as "O"; confirm the padding label used by the dataset.
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Training loop
def train():
    """Fine-tune the notebook-level ``model`` on ``train_loader`` for ``EPOCHS`` epochs.

    Uses the notebook-level ``optimizer``, ``loss_fn`` and ``DEVICE``. Each batch
    is expected to be a dict with ``input_ids``, ``attention_mask`` and
    ``labels`` entries. Prints the mean loss of the batches that actually
    produced an optimizer step after every epoch.
    """
    for epoch in range(EPOCHS):
        # Re-assert training mode each epoch in case evaluation ran in between.
        model.train()
        total_loss = 0.0
        num_steps = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            optimizer.zero_grad()
            logits = model(input_ids, attention_mask)

            # Score only attended positions so padding never contributes to the
            # loss, and take the class count from the logits instead of a
            # hard-coded 3.
            active = attention_mask.reshape(-1) == 1
            flat_logits = logits.reshape(-1, logits.size(-1))[active]
            flat_labels = labels.reshape(-1)[active]
            loss = loss_fn(flat_logits, flat_labels)

            # A mean over zero scored targets (e.g. every target ignored) is
            # NaN; back-propagating it would poison every weight, so skip.
            if not torch.isfinite(loss):
                continue

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_steps += 1
        # max(..., 1) avoids ZeroDivisionError on an empty loader or when
        # every batch was skipped.
        print(f"Epoch {epoch+1} Loss: {total_loss/max(num_steps, 1):.4f}")

In [None]:
# Inference function to generate predictions
def predict_and_save(test_path="dataset/eval_w_o_labels.csv",
                     output_path="app/input/res/prediction.csv",
                     batch_size=1):
    """Run the notebook-level ``model`` on an unlabelled CSV and save predictions.

    Args:
        test_path: CSV to predict on; must contain ``id`` and ``language``
            columns plus whatever ``preprocess_data`` reads.
        output_path: Destination CSV; its parent directory is created if needed.
        batch_size: Number of rows per forward pass. Defaults to 1.

    Writes one row per input row with columns ``id``, ``indices`` (stringified
    list of predicted positions, or ``[-1]`` when nothing is tagged) and
    ``language``.
    """
    model.eval()
    test_df = pd.read_csv(test_path)

    ids = test_df["id"].tolist()
    langs = test_df["language"].tolist()

    inputs, _ = preprocess_data(test_df, tokenizer)
    # The dataset class requires labels; feed all-zero placeholders.
    dummy_labels = [[0] * len(x) for x in inputs]
    test_dataset = IdiomDataset(inputs, dummy_labels, tokenizer, MAX_LEN)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    all_preds = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Predicting"):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)

            logits = model(input_ids, attention_mask)
            # Keep the batch dimension (no squeeze) so any batch size works.
            batch_preds = torch.argmax(logits, dim=-1).cpu().tolist()
            batch_masks = attention_mask.cpu().tolist()

            for preds, mask in zip(batch_preds, batch_masks):
                # Drop padding (mask == 0), then the first and last attended
                # positions. Assumes the dataset adds exactly one leading and
                # one trailing special token (<s> / </s> for XLM-R) - TODO confirm.
                valid_preds = [p for p, m in zip(preds, mask) if m != 0][1:-1]

                # Labels 1 and 2 are the idiom tags; 0 is treated as "outside".
                # NOTE(review): positions index the sequence produced by
                # IdiomDataset; if the expected output is word-level indices
                # and the dataset emits subword tokens, a mapping step is
                # needed - confirm against the dataset implementation.
                indices = [pos for pos, label in enumerate(valid_preds) if label in (1, 2)]
                if not indices:
                    indices = [-1]

                all_preds.append(indices)

    # Save to the prediction CSV
    output_df = pd.DataFrame({
        "id": ids,
        "indices": [str(x) for x in all_preds],
        "language": langs
    })

    # dirname is "" for a bare filename; fall back to the current directory.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    output_df.to_csv(output_path, index=False)
    print(f"Predictions saved to {output_path}")

In [None]:
if __name__ == "__main__":
    # Fine-tune first, then write predictions with the freshly trained weights.
    for stage in (train, predict_and_save):
        stage()