In [None]:
import os

import pandas as pd
import torch
from transformers import AutoTokenizer

from dataset import get_dataloaders
from xlmr_model import XLMRForIdiomDetection, train_model, evaluate_model, predict_idiom_indices

In [None]:
def run_train(epochs=5, lr=3e-5, batch_size=8, max_length=128,
              save_path='models/saved_pts/xlmr_best_model.pt'):
    """Train an ``XLMRForIdiomDetection`` model and save its weights to disk.

    Builds the train/validation loaders from ``dataset/train.csv`` and
    ``dataset/eval.csv``, passes a freshly constructed model to
    ``train_model``, and writes the returned model's ``state_dict`` to
    ``save_path``.

    Args:
        epochs: Forwarded to ``train_model`` as ``epochs``.
        lr: Forwarded to ``train_model`` as ``lr``.
        batch_size: Forwarded to ``get_dataloaders``.
        max_length: Forwarded to ``get_dataloaders``.
        save_path: File the state dict is written to. Missing parent
            directories are created first. The default is the same path
            that ``run_eval`` loads its checkpoint from.

    Returns:
        None. The trained weights are only persisted to ``save_path``.
    """
    train_loader, val_loader, tokenizer = get_dataloaders(
        train_path='dataset/train.csv',
        val_path='dataset/eval.csv',
        batch_size=batch_size,
        max_length=max_length
    )
    model = XLMRForIdiomDetection()
    model = train_model(model, train_loader, val_loader, tokenizer, epochs=epochs, lr=lr)

    # A bare file name has an empty dirname, and os.makedirs('') raises,
    # so only create the directory when the path actually contains one.
    checkpoint_dir = os.path.dirname(save_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print("Training complete. Model saved.")

In [None]:
def run_eval(batch_size=8, max_length=128):
    """Evaluate the saved checkpoint on the validation loader and print the metrics.

    The validation loader and tokenizer come from ``get_dataloaders`` (the
    training loader it also returns is discarded). Weights are read from
    ``models/saved_pts/xlmr_best_model.pt`` into a new
    ``XLMRForIdiomDetection`` instance, which is moved to CUDA when
    available and to CPU otherwise.

    Args:
        batch_size: Forwarded to ``get_dataloaders``.
        max_length: Forwarded to ``get_dataloaders``.

    Returns:
        None. Whatever ``evaluate_model`` returns is printed, not returned.
    """
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")

    loader_kwargs = dict(
        train_path='dataset/train.csv',
        val_path='dataset/eval.csv',
        batch_size=batch_size,
        max_length=max_length,
    )
    _train_unused, eval_loader, tok = get_dataloaders(**loader_kwargs)

    net = XLMRForIdiomDetection()
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    saved_weights = torch.load('models/saved_pts/xlmr_best_model.pt', map_location=target_device)
    net.load_state_dict(saved_weights)
    net.to(target_device)

    scores = evaluate_model(net, eval_loader, tok, target_device)
    print("Evaluation complete.\n", scores)

In [None]:
def run_predict(output_path='prediction.csv', max_length=128,
                model_path='models/saved_pts/xlmr_best_model.pt',
                test_path='dataset/eval_w_o_labels.csv'):
    """Predict idiom indices for every test sentence and write them to a CSV.

    Loads the checkpoint at ``model_path`` into a new
    ``XLMRForIdiomDetection``, calls ``predict_idiom_indices`` once per row
    of ``test_path``, and writes one output row per input row with the
    columns ``id``, ``indices`` and ``language``.

    Args:
        output_path: Destination of the prediction CSV.
        max_length: Forwarded to ``predict_idiom_indices``.
        model_path: State-dict file to load. The default is the path
            ``run_train`` saves to and ``run_eval`` loads from, so a fresh
            train -> predict run works without moving files.
        test_path: Input CSV; it must provide ``id``, ``sentence`` and
            ``language`` columns.

    Returns:
        None. Predictions are only written to ``output_path``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = XLMRForIdiomDetection()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

    test_df = pd.read_csv(test_path)
    results = []

    # Walk the three needed columns in lockstep instead of materialising a
    # Series per row with iterrows().
    rows = zip(test_df["id"], test_df["sentence"], test_df["language"])
    for idx, sentence, lang in rows:
        _, pred_indices = predict_idiom_indices(
            model, tokenizer, sentence, device, max_length=max_length
        )
        if not pred_indices:
            # An empty prediction is written as [-1] rather than [].
            pred_indices = [-1]
        results.append({
            "id": idx,
            "indices": str(pred_indices),
            "language": lang
        })

    # Naming the columns keeps the header row even when the input has no rows.
    predictions = pd.DataFrame(results, columns=["id", "indices", "language"])
    predictions.to_csv(output_path, index=False)
    print(f"Predictions saved to {output_path}")

In [None]:
# Uncomment the step(s) you want to run (run_eval and run_predict load a previously saved checkpoint):
# run_train()
# run_eval()
# run_predict()