## Hyperparameter Tuning for DeBERTa + BiLSTM using Optuna

This notebook performs hyperparameter tuning for a hybrid **DeBERTa + BiLSTM model** on a binary Natural Language Inference (NLI) task using **Optuna**.

The goal is to find well-performing values for the following hyperparameters:

- Learning rate
- Weight decay
- Dropout rate
- Hidden dimension of the LSTM
- Batch size
- Number of training epochs

We use validation accuracy as the optimization metric, and Optuna's `MedianPruner` to speed up the tuning process by pruning underperforming trials early.
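To make the pruning idea concrete, the following is a simplified pure-Python sketch of the median rule for a maximizing study. It is not Optuna's implementation: the `median_rule` function and the sample accuracy values are hypothetical, and the real pruner has more options than are shown here.

```python
from statistics import median

def median_rule(trial_values, history, step, n_startup_trials=5, n_warmup_steps=1):
    """Return True if a maximizing trial should be pruned at `step`.

    trial_values: values the current trial has reported so far (index = step).
    history: one list of reported values per earlier, finished trial.
    """
    # No pruning until enough earlier trials have finished.
    if len(history) < n_startup_trials:
        return False
    # No pruning during the warm-up steps of the current trial.
    if step < n_warmup_steps:
        return False
    # Compare against earlier trials that reached the same step.
    peers = [values[step] for values in history if len(values) > step]
    if not peers:
        return False
    # Prune when the trial's best value so far is below the peers' median.
    return max(trial_values[: step + 1]) < median(peers)

# Hypothetical validation accuracies from five finished trials (two epochs each).
history = [[0.80, 0.85], [0.82, 0.88], [0.78, 0.84], [0.81, 0.86], [0.83, 0.89]]

weak = median_rule([0.70, 0.75], history, step=1)     # below the median of 0.86
strong = median_rule([0.84, 0.90], history, step=1)   # above the median
warmup = median_rule([0.10], history, step=0)         # step 0 is inside the warm-up
print(weak, strong, warmup)
```

With these made-up numbers, only the weak trial is flagged for pruning. The strong trial continues, and the very low first-epoch value is ignored because it falls inside the warm-up window.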


In [None]:
import torch
import pandas as pd
import numpy as np
import optuna
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from torch.nn.functional import softmax
from sklearn.metrics import accuracy_score
from transformers import AutoModel, AutoTokenizer
from optuna.trial import TrialState

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load Tokenizer and Dataset

We load the pretrained tokenizer (`microsoft/deberta-v3-base`) and read the `train.csv` and `dev.csv` datasets, which contain `premise`, `hypothesis`, and `label` columns.


In [None]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
df_train = pd.read_csv('/kaggle/input/training/train.csv')
df_val = pd.read_csv('/kaggle/input/training/dev.csv')

## Define Dataset Class and Tokenization Function

We define:
- An `NLIDataset` class that wraps the tokenized inputs and labels so a `DataLoader` can batch them.
- A helper function that tokenizes `premise` and `hypothesis` pairs, truncating or padding each pair to a maximum length of 128 tokens.
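As an illustration of what a padded and truncated sentence pair looks like, here is a minimal sketch that uses whitespace splitting instead of the real SentencePiece tokenizer. The `encode_pair` helper, the example sentences, and the small `max_length` are hypothetical. The truncation loop only approximates the behavior of `truncation=True`, trimming the longer sentence first.

```python
def encode_pair(premise, hypothesis, max_length=12):
    """Toy pair encoding: [CLS] premise [SEP] hypothesis [SEP] plus padding."""
    if max_length < 3:
        raise ValueError("max_length must leave room for the three special tokens")
    p, h = premise.split(), hypothesis.split()
    budget = max_length - 3  # room left after [CLS] and two [SEP] tokens
    # Trim one token at a time from whichever sentence is currently longer.
    while len(p) + len(h) > budget:
        if len(p) >= len(h):
            p.pop()
        else:
            h.pop()
    tokens = ["[CLS]", *p, "[SEP]", *h, "[SEP]"]
    n_real = len(tokens)
    n_pad = max_length - n_real
    # The attention mask is 1 for real tokens and 0 for padding.
    return tokens + ["[PAD]"] * n_pad, [1] * n_real + [0] * n_pad

long_tokens, long_mask = encode_pair(
    "A man is playing a guitar on stage", "A person plays music"
)
short_tokens, short_mask = encode_pair("Dogs bark", "Animals make noise")
print(long_tokens)
print(short_tokens, short_mask)
```

Both outputs have exactly `max_length` entries. The long pair is trimmed to fit, while the short pair is padded and its attention mask marks the padding positions with zeros.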


In [None]:
# Dataset class
class NLIDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'labels': self.labels[idx]
        }

    def __len__(self):
        return len(self.labels)

# Tokenization function
def tokenize_premise_hypothesis(premises, hypotheses, max_length=128):
    return tokenizer(
        premises,
        hypotheses,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )

## Define the DeBERTa + BiLSTM Model

The model combines:
- A DeBERTa-v3 transformer encoder
- A bidirectional LSTM (`hidden_dim`, 1 layer)
- Dropout, applied to the BiLSTM output at the first token position
- A fully connected classification layer (`Linear(hidden_dim * 2, 2)`)
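To get a feel for how much the `hidden_dim` choice changes the size of the layers added on top of the encoder, the sketch below counts their trainable parameters with plain arithmetic. It assumes the PyTorch LSTM parameterization (four gates, two bias vectors per gate) and an encoder width of 768, as used in the model code. The helper names are hypothetical.

```python
def lstm_param_count(input_dim, hidden_dim, bidirectional=True):
    # Per direction: 4 gates, each with input weights, recurrent weights,
    # and two bias vectors (PyTorch stores b_ih and b_hh separately).
    per_direction = 4 * (input_dim * hidden_dim + hidden_dim * hidden_dim + 2 * hidden_dim)
    return per_direction * (2 if bidirectional else 1)

def head_param_count(hidden_dim, encoder_dim=768, num_classes=2):
    bilstm = lstm_param_count(encoder_dim, hidden_dim)
    # Linear(hidden_dim * 2, num_classes): weight matrix plus bias.
    classifier = (2 * hidden_dim) * num_classes + num_classes
    return bilstm + classifier

for hidden_dim in (128, 256, 384):
    print(hidden_dim, head_param_count(hidden_dim))
# 128 920066
# 256 2102274
# 384 3546626
```

Even the largest setting adds only a few million parameters, which is small next to the pretrained encoder itself.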


In [None]:
# BiLSTM model
class DeBERTaWithBiLSTM(nn.Module):
    def __init__(self, hidden_dim=256, dropout=0.3):
        super().__init__()
        self.base_model = AutoModel.from_pretrained("microsoft/deberta-v3-base")
        self.bilstm = nn.LSTM(768, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_dim * 2, 2)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
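        sequence_output = outputs.last_hidden_state  # (batch, seq_len, 768)
        lstm_out, _ = self.bilstm(sequence_output)   # (batch, seq_len, hidden_dim * 2)
        # Use the BiLSTM output at the first position ([CLS]) as the sequence summary
        pooled_output = lstm_out[:, 0]
        out = self.dropout(pooled_output)
        logits = self.classifier(out)                # (batch, 2)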
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
            return {'loss': loss, 'logits': logits}
        return {'logits': logits}

## Define Optuna Objective Function

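This function:
- Samples hyperparameters: learning rate (log-uniform, 5e-6 to 3e-5), weight decay (log-uniform, 1e-6 to 5e-4), dropout (0.2 to 0.4), `hidden_dim` (128, 256, or 384), batch size (16 or 32), and `num_epochs` (5 to 10).
- Initializes the model and an AdamW optimizer.
- Trains the model for `num_epochs`, clipping the gradient norm to 1.0.
- Evaluates on the validation set after each epoch.
- Reports the accuracy with `trial.report()` and checks `trial.should_prune()` so that unpromising trials can stop early.

The validation accuracy from the final epoch is returned as the trial's value, and the study is configured to maximize it.

`MedianPruner` is created with `n_startup_trials=5` and `n_warmup_steps=1`, so pruning only starts once five trials have finished, and no trial is pruned at its first epoch.

We run 20 trials to explore different hyperparameter combinations.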

In [None]:


# Tokenize once, outside the objective, so the work is not repeated for every trial
train_enc = tokenize_premise_hypothesis(df_train['premise'].tolist(), df_train['hypothesis'].tolist())
val_enc = tokenize_premise_hypothesis(df_val['premise'].tolist(), df_val['hypothesis'].tolist())
train_labels = torch.tensor(df_train['label'].values)
val_labels = torch.tensor(df_val['label'].values)

train_dataset = NLIDataset(train_enc, train_labels)
val_dataset = NLIDataset(val_enc, val_labels)

# Optuna objective
def objective(trial):
    # Sample hyperparameters (log=True samples uniformly in log space)
    learning_rate = trial.suggest_float("learning_rate", 5e-6, 3e-5, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 5e-4, log=True)
    dropout = trial.suggest_float("dropout", 0.2, 0.4)
    hidden_dim = trial.suggest_categorical("hidden_dim", [128, 256, 384])
    batch_size = trial.suggest_categorical("batch_size", [16, 32])
    num_epochs = trial.suggest_int("num_epochs", 5, 10)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    model = DeBERTaWithBiLSTM(hidden_dim=hidden_dim, dropout=dropout).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            outputs = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                labels=batch['labels'].to(device)
            )
            loss = outputs['loss']
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        # Evaluate after each epoch for pruning
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                outputs = model(
                    input_ids=batch['input_ids'].to(device),
                    attention_mask=batch['attention_mask'].to(device),
                    labels=batch['labels'].to(device)
                )
                preds = torch.argmax(outputs['logits'], dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(batch['labels'].cpu().numpy())

        val_acc = accuracy_score(all_labels, all_preds)
        trial.report(val_acc, epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return val_acc

# Run Optuna with pruning
pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=1)
study = optuna.create_study(direction="maximize", pruner=pruner)
study.optimize(objective, n_trials=20)

# Print best trial
print("✅ Best Hyperparameters Found:")
for key, value in study.best_trial.params.items():
    print(f"{key}: {value}")


[I 2025-04-07 01:58:54,505] A new study created in memory with name: no-name-e05d654e-c658-4754-91f9-2d9e1a70e844


[I 2025-04-07 03:07:38,190] Trial 0 finished with value: 0.9073634204275535 and parameters: {'learning_rate': 1.1596292509631827e-05, 'weight_decay': 1.86556109418857e-05, 'dropout': 0.2629848916241302, 'hidden_dim': 256, 'batch_size': 32, 'num_epochs': 10}. Best is trial 0 with value: 0.9073634204275535.
[I 2025-04-07 04:24:25,228] Trial 1 finished with value: 0.9125593824228029 and parameters: {'learning_rate': 1.5191865828425241e-05, 'weight_decay': 0.00029798996534771856, 'dropout': 0.2127259462963618, 'hidden_dim': 256, 'batch_size': 16, 'num_epochs': 10}. Best is trial 1 with value: 0.9125593824228029.
[I 2025-04-07 05:10:32,887] Trial 2 finished with value: 0.9103325415676959 and parameters: {'learning_rate': 2.1124594298151137e-05, 'weight_decay': 7.028390337983515e-06, 'dropout': 0.3275037533574017, 'hidden_dim': 256, 'batch_size': 16, 'num_epochs': 6}. Best is trial 1 with value: 0.9125593824228029.
[I 2025-04-07 06:20:15,132] Trial 3 finished with value: 0.9171615201900237 a

✅ Best Hyperparameters Found:
learning_rate: 8.663394579529044e-06
weight_decay: 0.00043725068136265345
dropout: 0.38920213235632034
hidden_dim: 384
batch_size: 32
num_epochs: 5
