# Idiom Detection with BERT + BiLSTM + CRF
This notebook is adapted from your `main.py` for use in Google Colab.

**Instructions:**
- Upload your `model.py` and `dataset.py` files to the Colab environment.
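- Place your data where the cells below look for it by default: `public_data/train.csv` and `public_data/eval.csv` for training/evaluation, and `starting_kit/eval_w_o_labels.csv` for prediction.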
- Use the cells below to train, evaluate, or predict.


In [None]:
# Install dependencies
!pip install transformers tqdm scikit-learn pytorch-crf


In [None]:
# Import modules
import os
import pandas as pd
import torch
from dataset import get_dataloaders
from model import train_model, BertForIdiomDetection, predict_idioms
from transformers import BertTokenizer


## Training Function

In [None]:
def run_train(epochs=10, lr=2e-5, batch_size=8, max_length=128,
              train_path='public_data/train.csv',
              val_path='public_data/eval.csv'):
    """Train the idiom-detection model on the train/validation CSV files.

    Builds the data loaders and tokenizer with ``get_dataloaders`` and passes
    them to ``train_model``. Nothing is returned.

    Args:
        epochs: Number of training epochs, forwarded to ``train_model``.
        lr: Learning rate, forwarded to ``train_model``.
        batch_size: Batch size, forwarded to ``get_dataloaders``.
        max_length: Forwarded to ``get_dataloaders`` (presumably the maximum
            tokenized sequence length -- confirm in ``dataset.py``).
        train_path: CSV file used for training.
        val_path: CSV file used for validation.
    """
    train_loader, val_loader, tokenizer = get_dataloaders(
        train_path=train_path,
        val_path=val_path,
        batch_size=batch_size,
        max_length=max_length,
    )
    # The model object returned by train_model is not needed here.
    # NOTE(review): the message below implies train_model writes the best
    # checkpoint to best_idiom_model.pt itself -- confirm in model.py.
    train_model(
        train_loader, val_loader, tokenizer,
        epochs=epochs, lr=lr,
    )
    print('Training complete. Best model saved as best_idiom_model.pt')


## Evaluation Function

In [None]:
def run_eval(batch_size=8, max_length=128,
             train_path='public_data/train.csv',
             val_path='public_data/eval.csv',
             model_path='best_idiom_model.pt'):
    """Evaluate a saved checkpoint on the validation set.

    Loads the weights stored at ``model_path`` into a fresh
    ``BertForIdiomDetection``, runs ``evaluate`` on the validation loader,
    prints the resulting metrics and returns them.

    Args:
        batch_size: Batch size, forwarded to ``get_dataloaders``.
        max_length: Forwarded to ``get_dataloaders`` (presumably the maximum
            tokenized sequence length -- confirm in ``dataset.py``).
        train_path: Training CSV; ``get_dataloaders`` requires it even though
            only the validation loader is used here.
        val_path: CSV file to evaluate on.
        model_path: Checkpoint file holding the model's state dict.

    Returns:
        Whatever ``evaluate`` returns (the same object that is printed).
    """
    # Local import: ``evaluate`` is not part of the notebook's import cell.
    from model import evaluate

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # get_dataloaders builds both loaders; the training one is unused here.
    _, val_loader, tokenizer = get_dataloaders(
        train_path=train_path,
        val_path=val_path,
        batch_size=batch_size,
        max_length=max_length,
    )

    model = BertForIdiomDetection()
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints you created or trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    # Switch off training-only behaviour (e.g. dropout) so the reported
    # metrics do not depend on whether ``evaluate`` does this itself.
    model.eval()

    metrics = evaluate(model, val_loader, tokenizer, device)
    print('Evaluation complete.')
    print(metrics)
    return metrics


## Prediction Function

In [None]:
def run_predict(output='predictions.csv',
                input_path='starting_kit/eval_w_o_labels.csv',
                model_path='best_idiom_model.pt',
                tokenizer_name='bert-base-multilingual-cased'):
    """Predict idiom indices for every sentence in a CSV and save them.

    Reads the ``id``, ``sentence`` and ``language`` columns of ``input_path``,
    runs ``predict_idioms`` on each sentence, and writes a CSV with the
    columns ``id``, ``indices`` and ``language`` to ``output``.

    Args:
        output: Destination CSV path for the predictions.
        input_path: CSV file with the sentences to label.
        model_path: Checkpoint file holding the model's state dict.
        tokenizer_name: Name or path passed to
            ``BertTokenizer.from_pretrained``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = BertForIdiomDetection()
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints you created or trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    # Inference must not use training-only behaviour such as dropout;
    # set eval mode here rather than relying on predict_idioms to do it.
    model.eval()
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

    # Read test data
    test_df = pd.read_csv(input_path)
    rows = zip(
        test_df['id'].tolist(),
        test_df['sentence'].tolist(),
        test_df['language'].tolist(),  # language is copied through to the output
    )

    results = []
    # No gradients are needed for prediction; skipping them saves memory.
    with torch.no_grad():
        for idx, sentence, lang in rows:
            # Only the second element of the returned pair is used.
            _, idiom_indices = predict_idioms(model, tokenizer, sentence, device)
            # If no idiom, output [-1] as in training
            if not idiom_indices:
                idiom_indices = [-1]
            results.append({
                'id': idx,
                'indices': str(idiom_indices),  # list stored as its string form
                'language': lang,
            })

    # Pin the columns so an empty input still yields a file with a header.
    out_df = pd.DataFrame(results, columns=['id', 'indices', 'language'])
    out_df.to_csv(output, index=False)
    print(f'Predictions saved to {output}')

## Example Usage
Uncomment the call for the desired operation in the cell below, then run the cell.

In [None]:
# Train
# run_train(epochs=10, lr=2e-5, batch_size=8, max_length=128)

# Evaluate
# run_eval(batch_size=8, max_length=128)

# Predict
# run_predict(output='predictions.csv')
