# In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
from dataset import get_dataloaders
from transformers import BertTokenizer, BertModel
from torchcrf import CRF
from tqdm import tqdm
import numpy as np

# Define the Enhanced Model
class EnhancedBertForIdiomDetection(nn.Module):
    """BERT -> BiLSTM -> dense head -> CRF token tagger for idiom spans.

    Intended for the BIO scheme used elsewhere in this file
    (0=O, 1=B-IDIOM, 2=I-IDIOM), hence ``num_labels=3`` by default.

    Args:
        model_name: Checkpoint passed to ``BertModel.from_pretrained``.
        num_labels: Number of tag classes scored per token.
        lstm_hidden_size: Hidden size of each BiLSTM direction.
        lstm_layers: Number of stacked BiLSTM layers.
        lstm_dropout: Dropout between LSTM layers (ignored when
            ``lstm_layers == 1``, since nn.LSTM only applies it between layers).
        hidden_dropout: Dropout applied before and after the dense layer.
        use_layer_norm: Apply LayerNorm after the dense activation.
        freeze_bert_layers: If > 0, freeze BERT's embeddings plus its first
            ``freeze_bert_layers`` encoder layers.
    """

    def __init__(self,
                 model_name="bert-base-multilingual-cased",
                 num_labels=3,
                 lstm_hidden_size=384,
                 lstm_layers=2,
                 lstm_dropout=0.3,
                 hidden_dropout=0.3,
                 use_layer_norm=True,
                 freeze_bert_layers=0):
        super().__init__()

        # Pre-trained BERT encoder
        self.bert = BertModel.from_pretrained(model_name)

        # Optionally freeze the embeddings and the lowest encoder layers
        if freeze_bert_layers > 0:
            frozen_modules = [self.bert.embeddings]
            frozen_modules.extend(self.bert.encoder.layer[:freeze_bert_layers])
            for module in frozen_modules:
                for param in module.parameters():
                    param.requires_grad = False

        # BiLSTM on top of BERT's token representations
        self.lstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=lstm_dropout if lstm_layers > 1 else 0
        )

        # Classification head (attribute names are part of the saved state_dict)
        self.dropout = nn.Dropout(hidden_dropout)
        self.dense = nn.Linear(lstm_hidden_size * 2, lstm_hidden_size)
        self.activation = nn.ReLU()
        self.use_layer_norm = use_layer_norm
        if use_layer_norm:
            self.norm = nn.LayerNorm(lstm_hidden_size)
        self.classifier = nn.Linear(lstm_hidden_size, num_labels)

        # CRF layer over the per-token emission scores
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Score tokens, optionally compute the CRF loss, and Viterbi-decode.

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: 1 for real tokens, 0 for padding, same shape.
                torchcrf requires the first position of every row to be
                unmasked (the [CLS] token satisfies this).
            labels: Optional tag ids, same shape. Values at padded
                positions are ignored.

        Returns:
            dict with
            - 'loss': mean CRF negative log-likelihood, or None without labels;
            - 'logits': emission scores, shape (batch, seq_len, num_labels);
            - 'predictions': decoded tag ids shaped like ``input_ids``, with 0
              at every position past each row's unmasked length.
        """
        crf_mask = attention_mask.bool()

        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state  # [batch, seq_len, hidden]

        # NOTE(review): the LSTM is run on the full padded sequence (no
        # packing), so its backward direction also reads padding positions.
        lstm_output, _ = self.lstm(sequence_output)  # [batch, seq_len, 2*lstm_hidden]

        x = self.dropout(lstm_output)
        x = self.activation(self.dense(x))
        if self.use_layer_norm:
            x = self.norm(x)
        x = self.dropout(x)
        emissions = self.classifier(x)  # [batch, seq_len, num_labels]

        loss = None
        if labels is not None:
            # torchcrf looks up transition/emission scores for every label,
            # including masked positions, before multiplying by the mask.
            # An ignore value such as -100 in the padding would therefore
            # raise. Zeroing padded labels leaves the loss unchanged, because
            # those terms are masked out.
            safe_labels = labels.masked_fill(~crf_mask, 0)
            loss = -self.crf(emissions, safe_labels, mask=crf_mask, reduction='mean')

        # Viterbi decoding returns one variable-length list per row; pad with 0
        decoded = self.crf.decode(emissions, mask=crf_mask)
        pred_tensor = torch.zeros_like(input_ids)
        for row, tag_seq in enumerate(decoded):
            pred_tensor[row, :len(tag_seq)] = torch.tensor(tag_seq, device=pred_tensor.device)

        return {
            'loss': loss,
            'logits': emissions,
            'predictions': pred_tensor
        }

# Post-processing function
def post_process_bio_tags(tokens, tags, token_is_first_subword):
    """
    Apply linguistic rules to fix common errors in BIO tag sequences.

    Parameters:
    - tokens: List of tokens (including subtokens). Not used by the current
      rules; kept for interface compatibility.
    - tags: Predicted BIO tag ids (0=O, 1=B-IDIOM, 2=I-IDIOM). Any sequence
      type is accepted, and it is not modified.
    - token_is_first_subword: Booleans, True where a token starts a word.
      Rules only fire on word-initial tokens, so tags on "##" continuation
      pieces are left as they are.

    Returns:
    - A new list of corrected BIO tag ids.

    Rules:
    2. B, B   -> B, I     (second B continues the idiom)
    3. B, O, I -> B, I, I (bridge a one-word gap inside an idiom)
    1. O, I   -> O, B if the I is followed by another I, else O, O
       (orphan I-IDIOM)

    Rules 2 and 3 read the raw predictions and never rewrite the same
    position: Rule 2 only rewrites a B, Rule 3 only an O. Rule 1 runs last
    on the *corrected* sequence. Otherwise it would see the O that Rule 3
    just bridged, treat the following I as an orphan, and undo the bridge
    (e.g. B,O,I would end up as B,I,O).
    """
    O_TAG, B_TAG, I_TAG = 0, 1, 2
    original = list(tags)
    corrected = list(original)
    n = len(original)
    is_first = token_is_first_subword

    # Rules 2 and 3: both look at a word-initial B followed by a word-initial token
    for i in range(n - 1):
        if not (is_first[i] and is_first[i + 1] and original[i] == B_TAG):
            continue
        if original[i + 1] == B_TAG:
            corrected[i + 1] = I_TAG  # Rule 2: B, B -> B, I
        elif original[i + 1] == O_TAG and i + 2 < n and original[i + 2] == I_TAG:
            corrected[i + 1] = I_TAG  # Rule 3: B, O, I -> B, I, I

    # Rule 1: an I-IDIOM directly after an O has no opening B
    for i in range(1, n):
        if is_first[i] and corrected[i] == I_TAG and corrected[i - 1] == O_TAG:
            followed_by_inside = i + 1 < n and corrected[i + 1] == I_TAG
            # Promote to B if the idiom continues, otherwise drop the stray I
            corrected[i] = B_TAG if followed_by_inside else O_TAG

    return corrected

def apply_post_processing(model, tokenizer, input_ids, attention_mask, device):
    """Run the model on a batch and post-process the BIO tags of its first row.

    Only row 0 of ``input_ids`` / ``attention_mask`` is post-processed.
    ``device`` is accepted for interface compatibility but is not used. The
    corrected tags are placed on the same device as the model's predictions.

    Returns:
        (corrected_tags, tokens, mask): a 1-D tensor of corrected tag ids,
        the WordPiece tokens of row 0, and the attention mask of row 0.
    """
    with torch.no_grad():
        model_out = model(input_ids=input_ids, attention_mask=attention_mask)
    raw_tags = model_out['predictions'][0]

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    # Position 0 ([CLS]) and any [SEP]/[PAD] count as word starts. Every
    # other token starts a word unless it is a "##" continuation piece.
    word_start_flags = [True] + [
        tok in ['[SEP]', '[PAD]'] or not tok.startswith('##')
        for tok in tokens[1:]
    ]

    fixed_tags = post_process_bio_tags(tokens, raw_tags.tolist(), word_start_flags)
    return torch.tensor(fixed_tags, device=raw_tags.device), tokens, attention_mask[0]

def predict_idioms_with_postprocessing(model, tokenizer, sentence, device, language=None):
    """Predict the word indices of idiom words in a single sentence.

    The sentence is WordPiece-tokenized (truncated/padded to 128 tokens),
    tagged by ``model`` and cleaned by ``apply_post_processing``. Tags are
    then collapsed to word level: a word takes the tag of its first sub-token,
    and "##" continuation pieces are skipped.

    Word-level decoding (identical to the decoding used for scoring):
    - B-IDIOM (1) closes any idiom in progress and opens a new one, so two
      adjacent idioms are both reported;
    - I-IDIOM (2) extends an open idiom; an I with no open idiom is ignored;
    - any other tag closes the idiom in progress.

    NOTE(review): word positions come from BERT's tokenization, which also
    splits punctuation off words. Confirm these indices line up with how the
    task defines word positions.

    Args:
        model: Tagger returning a dict with 'predictions'.
        tokenizer: BERT tokenizer matching the model.
        sentence: Raw input sentence.
        device: Device to place the encoded inputs on.
        language: Optional language name. The 'italian' / 'turkish' hooks
            below are placeholders and currently do nothing.

    Returns:
        list[int]: 0-based word indices of all predicted idiom words, in
        order (empty when no idiom is found).
    """
    model.eval()

    encoding = tokenizer.encode_plus(
        sentence,
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt"
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    corrected_preds, tokens, masks = apply_post_processing(
        model, tokenizer, input_ids, attention_mask, device
    )

    idiom_indices = []
    open_idiom = []
    word_idx = -1
    for token, mask, pred in zip(tokens, masks, corrected_preds):
        if mask == 0 or token in ['[CLS]', '[SEP]', '[PAD]']:
            continue
        if token.startswith('##'):
            continue  # continuation piece of the current word
        word_idx += 1
        tag = pred.item()
        if tag == 1:  # B-IDIOM: flush any open idiom, start a new one
            idiom_indices.extend(open_idiom)
            open_idiom = [word_idx]
        elif tag == 2 and open_idiom:  # I-IDIOM continuing an open idiom
            open_idiom.append(word_idx)
        else:  # O (or orphan I): close the idiom in progress
            idiom_indices.extend(open_idiom)
            open_idiom = []
    idiom_indices.extend(open_idiom)

    # Language-specific post-processing hooks
    if language == 'italian':
        # Add Italian-specific rules here if needed
        pass
    elif language == 'turkish':
        # Add Turkish-specific rules here if needed
        pass

    return idiom_indices

def _first_subword_flags(tokens):
    """Mark which WordPiece tokens start a word.

    Position 0 (expected to be [CLS]) and any [SEP]/[PAD] token count as word
    starts. Every other token starts a word unless it is a "##" continuation.
    """
    return [True] + [
        token in ('[SEP]', '[PAD]') or not token.startswith('##')
        for token in tokens[1:]
    ]


def _bio_tags_to_word_indices(tokens, mask, tags):
    """Collapse token-level BIO tags into the word indices covered by idioms.

    Masked-out positions, [CLS]/[SEP]/[PAD] and "##" continuation pieces are
    skipped. Every remaining token starts a new word and contributes its tag.
    Word-level decoding:
    - B-IDIOM (1) closes any idiom in progress and opens a new one, so two
      adjacent idioms are both kept;
    - I-IDIOM (2) extends an open idiom; an I with no open idiom is ignored;
    - any other tag closes the idiom in progress.

    Args:
        tokens: WordPiece tokens of one sequence.
        mask: Attention-mask values aligned with ``tokens`` (0 = padding).
        tags: Tag ids aligned with ``tokens`` (ints or 0-d tensors).

    Returns:
        list[int]: 0-based word indices of all idiom words, in order.
    """
    indices = []
    open_idiom = []
    word_idx = -1
    for token, keep, tag in zip(tokens, mask, tags):
        if keep == 0 or token in ('[CLS]', '[SEP]', '[PAD]') or token.startswith('##'):
            continue
        word_idx += 1
        tag = int(tag)
        if tag == 1:
            indices.extend(open_idiom)
            open_idiom = [word_idx]
        elif tag == 2 and open_idiom:
            open_idiom.append(word_idx)
        else:
            indices.extend(open_idiom)
            open_idiom = []
    indices.extend(open_idiom)
    return indices


def _sentence_f1(pred, gold):
    """Set-based F1 between predicted and gold word indices of one sentence.

    An empty gold list means "no idiom". The score is then 1.0 if the
    prediction is also empty, and 0.0 otherwise.
    """
    if not gold:
        return 0.0 if pred else 1.0
    pred_set, gold_set = set(pred), set(gold)
    overlap = len(pred_set & gold_set)
    precision = overlap / len(pred_set) if pred_set else 0
    recall = overlap / len(gold_set)
    return 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0


def evaluate(model, val_loader, tokenizer, device, apply_postprocessing=True):
    """Compute validation loss and mean per-sentence idiom F1.

    For every sequence, predicted and gold BIO tags are collapsed to idiom
    word indices with the same decoding (``_bio_tags_to_word_indices``). The
    per-sentence F1 scores are then averaged.

    Args:
        model: Tagger returning a dict with 'loss' and 'predictions'.
        val_loader: Iterable of batches with 'input_ids', 'attention_mask'
            and 'labels' tensors. Empty batches are skipped.
        tokenizer: BERT tokenizer used to recover WordPiece tokens.
        device: Device the batches are moved to.
        apply_postprocessing: Run ``post_process_bio_tags`` on each predicted
            sequence before decoding.

    Returns:
        dict with 'loss' (mean batch loss over non-empty batches), 'f1'
        (mean per-sentence F1), and 'predictions' / 'ground_truth' (one list
        of word indices per sequence).
    """
    model.eval()
    val_loss = 0.0
    predictions = []
    ground_truth = []
    total_batches = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating"):
            if batch['input_ids'].size(0) == 0:
                continue

            total_batches += 1
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            if outputs['loss'] is not None:
                val_loss += outputs['loss'].item()

            for seq_preds, seq_mask, seq_labels, seq_ids in zip(
                    outputs['predictions'], attention_mask, labels, input_ids):
                tokens = tokenizer.convert_ids_to_tokens(seq_ids)

                pred_tags = seq_preds.tolist()
                if apply_postprocessing:
                    pred_tags = post_process_bio_tags(
                        tokens, pred_tags, _first_subword_flags(tokens)
                    )

                predictions.append(_bio_tags_to_word_indices(tokens, seq_mask, pred_tags))
                ground_truth.append(_bio_tags_to_word_indices(tokens, seq_mask, seq_labels))

    avg_val_loss = val_loss / max(1, total_batches)

    f1_scores = [_sentence_f1(pred, gold) for pred, gold in zip(predictions, ground_truth)]
    mean_f1 = sum(f1_scores) / max(1, len(f1_scores))

    print(f"\nValidation Loss: {avg_val_loss:.4f}")
    print(f"Mean F1 Score: {mean_f1:.4f}")

    return {
        'loss': avg_val_loss,
        'f1': mean_f1,
        'predictions': predictions,
        'ground_truth': ground_truth
    }

def train_model(train_loader, val_loader, tokenizer, model=None, epochs=10, lr=2e-5,
                weight_decay=0.01, lr_multiplier=10, patience=3,
                checkpoint_path="best_idiom_model.pt"):
    """Fine-tune the idiom tagger with early stopping on validation F1.

    Learning rates and decay:
    - BERT parameters train at ``lr``; all other (head) parameters train at
      ``lr * lr_multiplier``.
    - Parameters whose name contains 'bias' or 'LayerNorm.weight' get no
      weight decay.
    - A OneCycleLR schedule with 10% warm-up is stepped once per batch.

    After each epoch the model is evaluated with post-processing. The
    best-F1 weights are written to ``checkpoint_path`` and reloaded before
    returning.

    Args:
        train_loader: Batches of dicts passed to ``model(**batch)``; must
            include 'labels' so a loss is produced.
        val_loader: Validation batches for ``evaluate``.
        tokenizer: BERT tokenizer used by ``evaluate``.
        model: Model to train. A default ``EnhancedBertForIdiomDetection``
            is built when None.
        epochs: Maximum number of epochs.
        lr: Peak learning rate for BERT parameters.
        weight_decay: AdamW weight decay for the decayed groups.
        lr_multiplier: Head learning rate = ``lr * lr_multiplier``.
        patience: Epochs without F1 improvement before stopping early.
        checkpoint_path: File the best weights are saved to and reloaded from.

    Returns:
        The model (on the training device) with its best-F1 weights loaded.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if model is None:
        model = EnhancedBertForIdiomDetection()
    model = model.to(device)

    no_decay = ['bias', 'LayerNorm.weight']
    head_lr = lr * lr_multiplier

    def param_group(in_bert, decay, group_lr):
        """Trainable parameters of one (BERT/head, decay/no-decay) bucket."""
        params = [
            p for n, p in model.named_parameters()
            if p.requires_grad
            and ('bert' in n) == in_bert
            and any(nd in n for nd in no_decay) != decay
        ]
        return {
            'params': params,
            'weight_decay': weight_decay if decay else 0.0,
            'lr': group_lr,
        }

    optimizer = torch.optim.AdamW([
        param_group(in_bert=True, decay=True, group_lr=lr),
        param_group(in_bert=True, decay=False, group_lr=lr),
        param_group(in_bert=False, decay=True, group_lr=head_lr),
        param_group(in_bert=False, decay=False, group_lr=head_lr),
    ])

    # Learning-rate scheduler with warm-up; max_lr order matches the groups above.
    # OneCycleLR rejects total_steps <= 0, so the loop below always runs at least once.
    total_steps = len(train_loader) * epochs
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=[lr, lr, head_lr, head_lr],
        total_steps=total_steps,
        pct_start=0.1  # 10% warmup
    )

    # Start below any achievable score so the first epoch always writes a
    # checkpoint. With 0, a run whose F1 stays 0.0 never saves. The final
    # torch.load would then fail, or silently load a stale file from an
    # earlier run or trial, possibly with a different architecture.
    best_f1 = float('-inf')
    no_improve_epochs = 0

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")

        # Training
        model.train()
        train_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")

        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            loss = model(**batch)['loss']
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            train_loss += loss.item()
            progress_bar.set_postfix({'loss': f"{loss.item():.4f}"})

        avg_train_loss = train_loss / len(train_loader)

        # Evaluation
        metrics = evaluate(model, val_loader, tokenizer, device, apply_postprocessing=True)

        print(f"Epoch {epoch+1}:")
        print(f"Training Loss: {avg_train_loss:.4f}")
        print(f"Validation Loss: {metrics['loss']:.4f}")
        print(f"F1 Score: {metrics['f1']:.4f}")

        # Save best model and check for early stopping
        if metrics['f1'] > best_f1:
            best_f1 = metrics['f1']
            torch.save(model.state_dict(), checkpoint_path)
            print("New best model saved!")
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
            print(f"No improvement for {no_improve_epochs} epochs")

            if no_improve_epochs >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Restore the best weights saved during this run
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model

# Now update the main functions for training, evaluation, and prediction

def run_train(epochs=10, lr=2e-5, batch_size=8, max_length=128,
              lstm_hidden_size=384, lstm_layers=2, lstm_dropout=0.3,
              hidden_dropout=0.3, use_layer_norm=True, freeze_bert_layers=0,
              weight_decay=0.01, lr_multiplier=10, patience=3):
    """Build, train and save an idiom tagger with the given hyperparameters.

    Loads the data via ``get_dataloaders`` (the test split is not used
    here), trains with ``train_model`` (which restores the best-F1 weights),
    and writes the result to 'final_idiom_model.pt'.

    Returns:
        The trained model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    train_loader, val_loader, _test_loader = get_dataloaders(
        batch_size=batch_size,
        max_length=max_length
    )
    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

    architecture = dict(
        lstm_hidden_size=lstm_hidden_size,
        lstm_layers=lstm_layers,
        lstm_dropout=lstm_dropout,
        hidden_dropout=hidden_dropout,
        use_layer_norm=use_layer_norm,
        freeze_bert_layers=freeze_bert_layers,
    )
    untrained = EnhancedBertForIdiomDetection(**architecture)

    trained = train_model(
        train_loader,
        val_loader,
        tokenizer,
        model=untrained,
        epochs=epochs,
        lr=lr,
        weight_decay=weight_decay,
        lr_multiplier=lr_multiplier,
        patience=patience,
    )

    torch.save(trained.state_dict(), 'final_idiom_model.pt')
    print("Training complete!")
    return trained

def run_eval(batch_size=8, max_length=128,
             lstm_hidden_size=384, lstm_layers=2, lstm_dropout=0.3,
             hidden_dropout=0.3, use_layer_norm=True, freeze_bert_layers=0,
             model_path='best_idiom_model.pt'):
    """Score a saved checkpoint on the validation split.

    The architecture arguments (sizes, layer count, layer norm) must match
    the checkpoint at ``model_path``; otherwise ``load_state_dict`` rejects
    the weights.

    Returns:
        The metrics dict from ``evaluate`` (with post-processing enabled).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    _train_loader, val_loader, _test_loader = get_dataloaders(
        batch_size=batch_size,
        max_length=max_length
    )
    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

    model = EnhancedBertForIdiomDetection(
        lstm_hidden_size=lstm_hidden_size,
        lstm_layers=lstm_layers,
        lstm_dropout=lstm_dropout,
        hidden_dropout=hidden_dropout,
        use_layer_norm=use_layer_norm,
        freeze_bert_layers=freeze_bert_layers,
    )
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.to(device)

    metrics = evaluate(model, val_loader, tokenizer, device, apply_postprocessing=True)

    print("Evaluation complete!")
    print(f"F1 Score: {metrics['f1']:.4f}")
    return metrics

def run_predict(output='predictions.csv',
                lstm_hidden_size=384, lstm_layers=2, lstm_dropout=0.3,
                hidden_dropout=0.3, use_layer_norm=True, freeze_bert_layers=0,
                model_path='best_idiom_model.pt',
                test_path='public_data/eval_w_o_labels.csv'):
    """Predict idiom word indices for every sentence in a test CSV.

    Args:
        output: CSV file to write with columns 'id', 'indices', 'language'.
        lstm_hidden_size, lstm_layers, lstm_dropout, hidden_dropout,
        use_layer_norm, freeze_bert_layers: Architecture of the checkpoint.
        model_path: Saved state_dict to load.
        test_path: Input CSV with 'id', 'sentence' and 'language' columns.
            The default is the path that was previously hard-coded.

    Returns:
        pandas.DataFrame: the rows written to ``output``. 'indices' is a
        compact list string such as "[2,3]", or "[-1]" when no idiom is
        found.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = EnhancedBertForIdiomDetection(
        lstm_hidden_size=lstm_hidden_size,
        lstm_layers=lstm_layers,
        lstm_dropout=lstm_dropout,
        hidden_dropout=hidden_dropout,
        use_layer_norm=use_layer_norm,
        freeze_bert_layers=freeze_bert_layers
    )
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

    test_df = pd.read_csv(test_path)
    ids = test_df['id'].tolist()
    sentences = test_df['sentence'].tolist()
    languages = test_df['language'].tolist()

    results = []
    for idx, sentence, lang in zip(ids, sentences, languages):
        idiom_indices = predict_idioms_with_postprocessing(
            model, tokenizer, sentence, device, language=lang
        )
        # The competition format uses [-1] for "no idiom"
        if not idiom_indices:
            idiom_indices = [-1]

        results.append({
            'id': idx,
            'indices': str(idiom_indices).replace(' ', ''),
            'language': lang
        })

    # Explicit columns keep the CSV header even when the input has no rows
    out_df = pd.DataFrame(results, columns=['id', 'indices', 'language'])
    out_df.to_csv(output, index=False)
    print(f'Predictions saved to {output}')

    return out_df

def run_hyperparameter_search(n_trials=5, batch_size=8, max_length=128, trial_epochs=3):
    """Random search over model and optimizer hyperparameters.

    Each trial samples one value per hyperparameter with ``np.random`` (seed
    NumPy's global RNG for reproducibility). It then trains for
    ``trial_epochs`` epochs with patience 2 and scores the restored best
    weights on the validation set with post-processing.

    NOTE: every trial goes through ``train_model``, which writes
    'best_idiom_model.pt'. After the search, that file reflects the most
    recent save (typically the last trial), not necessarily the best
    configuration.

    Args:
        n_trials: Number of random configurations to try (must be >= 1).
        batch_size: Batch size for the data loaders.
        max_length: Maximum sequence length for the data loaders.
        trial_epochs: Training epochs per trial (previously fixed at 3).

    Returns:
        dict: the best configuration, including its 'f1_score'.

    Raises:
        ValueError: if ``n_trials`` is less than 1.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be at least 1, got {n_trials}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    train_loader, val_loader, _ = get_dataloaders(
        batch_size=batch_size,
        max_length=max_length
    )

    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

    param_grid = {
        'learning_rate': [1e-5, 2e-5, 3e-5, 5e-5],
        'lstm_hidden_size': [256, 384, 512],
        'lstm_layers': [1, 2, 3],
        'lstm_dropout': [0.2, 0.3, 0.4],
        'hidden_dropout': [0.1, 0.2, 0.3, 0.4],
        'use_layer_norm': [True, False],
        'freeze_bert_layers': [0, 3, 6],
        'weight_decay': [0.0, 0.01, 0.05],
        'lr_multiplier': [5, 10, 15]
    }

    # Any real score beats -inf, so best_config is always set after the first
    # trial. With 0, a search where every trial scores 0.0 left best_config as
    # None and crashed when printing it.
    best_f1 = float('-inf')
    best_config = None
    results = []

    for trial in range(n_trials):
        print(f"\n===== Trial {trial+1}/{n_trials} =====")

        # Sample random hyperparameters as Python native types
        config = {
            'learning_rate': float(np.random.choice(param_grid['learning_rate'])),
            'lstm_hidden_size': int(np.random.choice(param_grid['lstm_hidden_size'])),
            'lstm_layers': int(np.random.choice(param_grid['lstm_layers'])),
            'lstm_dropout': float(np.random.choice(param_grid['lstm_dropout'])),
            'hidden_dropout': float(np.random.choice(param_grid['hidden_dropout'])),
            'use_layer_norm': bool(np.random.choice(param_grid['use_layer_norm'])),
            'freeze_bert_layers': int(np.random.choice(param_grid['freeze_bert_layers'])),
            'weight_decay': float(np.random.choice(param_grid['weight_decay'])),
            'lr_multiplier': int(np.random.choice(param_grid['lr_multiplier']))
        }

        print("Configuration:")
        for k, v in config.items():
            print(f"  {k}: {v}")

        model = EnhancedBertForIdiomDetection(
            lstm_hidden_size=config['lstm_hidden_size'],
            lstm_layers=config['lstm_layers'],
            lstm_dropout=config['lstm_dropout'],
            hidden_dropout=config['hidden_dropout'],
            use_layer_norm=config['use_layer_norm'],
            freeze_bert_layers=config['freeze_bert_layers']
        )

        # Short training run to rank the configuration
        model = train_model(
            train_loader,
            val_loader,
            tokenizer,
            model=model,
            epochs=trial_epochs,
            lr=config['learning_rate'],
            weight_decay=config['weight_decay'],
            lr_multiplier=config['lr_multiplier'],
            patience=2  # Use shorter patience for hyperparameter search
        )

        metrics = evaluate(model, val_loader, tokenizer, device, apply_postprocessing=True)
        f1_score = metrics['f1']

        print(f"Trial {trial+1} F1 Score: {f1_score:.4f}")

        config['f1_score'] = f1_score
        results.append(config)

        if f1_score > best_f1:
            best_f1 = f1_score
            best_config = config
            print(f"New best configuration found! F1: {best_f1:.4f}")

        # Release this trial's model before the next one is built
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\n===== Hyperparameter Search Results =====")
    print(f"Best F1 Score: {best_f1:.4f}")
    print("Best Configuration:")
    for k, v in best_config.items():
        print(f"  {k}: {v}")

    results.sort(key=lambda x: x['f1_score'], reverse=True)
    print("\nTop 3 configurations:")
    for i, config in enumerate(results[:3]):
        print(f"Rank {i+1}: F1={config['f1_score']:.4f}")
        for k, v in config.items():
            if k != 'f1_score':
                print(f"  {k}: {v}")

    return best_config

# Function to run full pipeline with best hyperparameters
def run_full_pipeline(n_trials=5, final_epochs=10, output='predictions.csv'):
    """Search hyperparameters, retrain, evaluate and write test predictions.

    Steps:
    1. ``run_hyperparameter_search`` with ``n_trials`` random configurations.
    2. ``run_train`` for ``final_epochs`` epochs with the best configuration
       (batch size 8, max length 128).
    3. ``run_eval`` of the saved best checkpoint.
    4. ``run_predict`` into ``output``.

    Returns:
        dict with 'best_config', 'final_f1' and 'predictions' (the DataFrame
        returned by ``run_predict``).
    """
    print("Starting hyperparameter search...")
    best_config = run_hyperparameter_search(n_trials=n_trials)

    print("\nTraining final model with the best configuration...")
    # Cast to Python native types before handing the values to run_train
    run_train(
        epochs=final_epochs,
        lr=float(best_config['learning_rate']),
        batch_size=8,
        max_length=128,
        lstm_hidden_size=int(best_config['lstm_hidden_size']),
        lstm_layers=int(best_config['lstm_layers']),
        lstm_dropout=float(best_config['lstm_dropout']),
        hidden_dropout=float(best_config['hidden_dropout']),
        use_layer_norm=bool(best_config['use_layer_norm']),
        freeze_bert_layers=int(best_config['freeze_bert_layers']),
        weight_decay=float(best_config['weight_decay']),
        lr_multiplier=int(best_config['lr_multiplier'])
    )

    # Architecture settings shared by evaluation and prediction
    architecture_keys = ('lstm_hidden_size', 'lstm_layers', 'lstm_dropout',
                         'hidden_dropout', 'use_layer_norm', 'freeze_bert_layers')
    architecture = {key: best_config[key] for key in architecture_keys}

    print("\nEvaluating the final model...")
    metrics = run_eval(batch_size=8, max_length=128, **architecture)

    print("\nGenerating predictions...")
    predictions = run_predict(output=output, **architecture)

    final_f1 = metrics['f1']
    print("\nFull pipeline completed successfully!")
    print(f"Final F1 Score: {final_f1:.4f}")
    print(f"Predictions saved to: {output}")

    return {
        'best_config': best_config,
        'final_f1': final_f1,
        'predictions': predictions
    }

# In [None]:
# Choose which steps this cell runs. The defaults reproduce the original
# cell, which runs all three steps in order. run_full_pipeline retrains and
# overwrites 'best_idiom_model.pt', 'final_idiom_model.pt' and
# 'predictions.csv', i.e. the files the first two steps read or write.
RUN_EVAL_AND_PREDICT = True
RUN_FULL_PIPELINE = True

# The guard keeps training from starting when this file is imported as a module.
if __name__ == "__main__":
    if RUN_EVAL_AND_PREDICT:
        # Evaluate an existing checkpoint ('best_idiom_model.pt' by default)
        run_eval()

        # Generate predictions
        run_predict(output='predictions.csv')

    if RUN_FULL_PIPELINE:
        # Alternatively: hyperparameter search, retraining and prediction from scratch
        run_full_pipeline(n_trials=5, final_epochs=10)