# Idiom Detection with BERT + BiLSTM + CRF
This notebook is adapted from your `main.py` for use in Google Colab.

**Instructions:**
- Upload your `biltsm_crf_model.py` and `dataset.py` files to the Colab environment (these are the modules the import cell below loads).
- Place your data in the appropriate paths (e.g., `public_data/`, `starting_kit/`).
- Use the cells below to train, evaluate, or predict.


In [None]:
# Install dependencies
!pip install transformers tqdm scikit-learn pytorch-crf


In [None]:
# Import modules
import os
import pandas as pd
import torch
from dataset import get_dataloaders
from transformers import BertTokenizer

from biltsm_crf_model import (
    EnhancedBertForIdiomDetection,
    train_model,
    predict_idioms_with_postprocessing,
    evaluate
)


## Training Function

In [None]:
def run_train(epochs=10, lr=2e-5, batch_size=8, max_length=128,
              lstm_hidden_size=384, lstm_layers=2, lstm_dropout=0.3,
              hidden_dropout=0.3, use_layer_norm=True, freeze_bert_layers=0):
    """Build an ``EnhancedBertForIdiomDetection`` model and train it.

    Args:
        epochs: Forwarded to ``train_model`` as ``epochs``.
        lr: Forwarded to ``train_model`` as ``lr``.
        batch_size: Not used by the code visible here; presumably consumed
            by the elided dataloader code -- TODO confirm.
        max_length: Not used by the code visible here; presumably consumed
            by the elided dataloader code -- TODO confirm.
        lstm_hidden_size: Forwarded to the model constructor.
        lstm_layers: Forwarded to the model constructor.
        lstm_dropout: Forwarded to the model constructor.
        hidden_dropout: Forwarded to the model constructor.
        use_layer_norm: Forwarded to the model constructor.
        freeze_bert_layers: Forwarded to the model constructor.

    Returns:
        The value returned by ``train_model`` (presumably the trained
        model -- verify against ``biltsm_crf_model.train_model``).
    """
    # ... existing dataloader code ...
    # NOTE(review): `train_loader`, `val_loader` and `tokenizer` must be
    # defined by the elided code above (or already exist as notebook
    # globals); otherwise the train_model call below raises NameError.

    model = EnhancedBertForIdiomDetection(
        lstm_hidden_size=lstm_hidden_size,
        lstm_layers=lstm_layers,
        lstm_dropout=lstm_dropout,
        hidden_dropout=hidden_dropout,
        use_layer_norm=use_layer_norm,
        freeze_bert_layers=freeze_bert_layers
    )

    model = train_model(
        train_loader, val_loader, tokenizer,
        model=model,
        epochs=epochs,
        lr=lr
    )

    # Hand the result back so it can be reused in the same session instead
    # of being discarded when the function exits.
    return model

## Evaluation Function

In [None]:
def run_eval(batch_size=8, max_length=128,
            lstm_hidden_size=384, lstm_layers=2, lstm_dropout=0.3,
            hidden_dropout=0.3, use_layer_norm=True, freeze_bert_layers=0):
    """Instantiate ``EnhancedBertForIdiomDetection`` for evaluation.

    The code visible here only constructs the model from the architecture
    arguments; ``batch_size`` and ``max_length`` are not referenced and are
    presumably consumed by the elided device/dataloader code -- TODO confirm.
    Nothing is returned.
    """
    # ... existing device and dataloader code ...

    # Gather the architecture settings once, then unpack them into the
    # constructor as keyword arguments.
    architecture = {
        'lstm_hidden_size': lstm_hidden_size,
        'lstm_layers': lstm_layers,
        'lstm_dropout': lstm_dropout,
        'hidden_dropout': hidden_dropout,
        'use_layer_norm': use_layer_norm,
        'freeze_bert_layers': freeze_bert_layers,
    }
    model = EnhancedBertForIdiomDetection(**architecture)

## Prediction Function

In [None]:
# ... existing code ...

def run_predict(output='predictions.csv',
                lstm_hidden_size=384, lstm_layers=2, lstm_dropout=0.3,
                hidden_dropout=0.3, use_layer_norm=True, freeze_bert_layers=0,
                model_path='best_idiom_model.pt',
                test_path='starting_kit/eval_w_o_labels.csv',
                tokenizer_name='bert-base-multilingual-cased'):
    """Predict idiom indices for every test sentence and write them to CSV.

    Args:
        output: Path of the CSV file to write.
        lstm_hidden_size: Forwarded to the model constructor.
        lstm_layers: Forwarded to the model constructor.
        lstm_dropout: Forwarded to the model constructor.
        hidden_dropout: Forwarded to the model constructor.
        use_layer_norm: Forwarded to the model constructor.
        freeze_bert_layers: Forwarded to the model constructor.
        model_path: Checkpoint file whose state dict is loaded into the model.
        test_path: CSV file read for input; the columns ``id``, ``sentence``
            and ``language`` are accessed.
        tokenizer_name: Name or path passed to ``BertTokenizer.from_pretrained``.

    Returns:
        None. The output CSV has the columns ``id``, ``indices`` (the
        ``str()`` of the predicted index list, or of ``[-1]`` when the
        prediction is empty) and ``language``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The architecture arguments must match those used when the checkpoint
    # was produced, otherwise load_state_dict will reject the weights.
    model = EnhancedBertForIdiomDetection(
        lstm_hidden_size=lstm_hidden_size,
        lstm_layers=lstm_layers,
        lstm_dropout=lstm_dropout,
        hidden_dropout=hidden_dropout,
        use_layer_norm=use_layer_norm,
        freeze_bert_layers=freeze_bert_layers
    )

    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # torch version supports it).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    # Switch to inference mode so dropout is disabled; without this the
    # predictions can vary from run to run.
    model.eval()
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

    # Read test data
    test_df = pd.read_csv(test_path)
    ids = test_df['id'].tolist()
    sentences = test_df['sentence'].tolist()
    languages = test_df['language'].tolist()

    results = []
    for idx, sentence, lang in zip(ids, sentences, languages):
        idiom_indices = predict_idioms_with_postprocessing(model, tokenizer, sentence, device)

        # If no idiom is found, use [-1] as per the competition format
        if not idiom_indices:
            idiom_indices = [-1]

        results.append({
            'id': idx,
            'indices': str(idiom_indices),
            'language': lang
        })

    out_df = pd.DataFrame(results)
    out_df.to_csv(output, index=False)
    print(f'Predictions saved to {output}')

# ... existing code ...

## Example Usage
Uncomment the call for the desired operation in the cell below, then run the cell.

In [None]:
# Train
# run_train(epochs=10, lr=2e-5, batch_size=8, max_length=128)

# Evaluate
# run_eval(batch_size=8, max_length=128)

# Predict
# run_predict(output='predictions.csv')
