# In [None]:
import random
import os
import json
import torch
import optuna
import numpy as np
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from sentence_transformers import SentenceTransformer
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torchmetrics import Precision, Recall


# === SEED ===
SEED = 42


def seed_everything(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (plus all CUDA devices when
    CUDA is available), then calls ``pl.seed_everything`` with
    ``workers=True``.

    Args:
        seed: Integer seed handed to each generator.
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    pl.seed_everything(seed, workers=True)


# Apply the seed immediately, before any data is loaded or any model is built.
seed_everything(SEED)

# === CONFIG ===
TRAIN_JSONL_PATH = "data/train.jsonl"  # Correct path for your dataset
VAL_JSONL_PATH = "data/val.jsonl"  # Correct path for your dataset
MODEL_NAME = "BAAI/bge-large-en-v1.5"
OUTPUT_PATH = "data/fine_tuned_gaap_classifier"
os.makedirs(OUTPUT_PATH, exist_ok=True)  # Logs and the Optuna DB are written here
OPTUNA_DB_PATH = os.path.join(OUTPUT_PATH, "optuna_study.db")
EPOCHS = 200  # Upper bound on epochs per trial
PATIENCE = 5  # Early-stopping patience, in validation checks

# Device used for the sentence encoder. Prefer CUDA, then Apple MPS, then CPU,
# so the encoder is not pinned to the CPU on a CUDA machine.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# === Load Data from JSONL files ===
def load_jsonl(filepath):
    """Read a JSON Lines file into a list of decoded records.

    Blank or whitespace-only lines (for example a trailing newline at the end
    of the file) are skipped instead of raising a decode error. The file is
    always read as UTF-8 so the result does not depend on the platform's
    default encoding.

    Args:
        filepath: Path of the ``.jsonl`` file to read.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# Load the training split first, then the validation split.
train_data, val_data = map(load_jsonl, (TRAIN_JSONL_PATH, VAL_JSONL_PATH))

# === Dynamically determine the number of possible categories ===
# Category ids are 1-based (the dataset maps id ``k`` to vector index ``k - 1``),
# so the highest id seen in either split is also the width of the label vector.
all_categories = set()
for split in (train_data, val_data):
    for entry in split:
        # Coerce to int and treat a missing/None/empty list as "no labels",
        # matching how MultiLabelDataset normalizes labels.
        all_categories.update(int(label) for label in (entry.get("labels") or []))

if not all_categories:
    raise ValueError(
        "No category labels found in the training/validation data; "
        "cannot determine the number of categories."
    )

num_labels = max(all_categories)  # Highest category id == number of output units
print(f"Number of categories: {num_labels}")


# === Dataset Class ===
class MultiLabelDataset(Dataset):
    """Dataset yielding ``(text, multi-hot label vector)`` pairs.

    Each record must provide ``"input_text"``. ``"labels"`` is a list of
    1-based category ids; a missing key, ``None`` or an empty list all mean
    "no categories". Ids are coerced with ``int`` and validated up front so an
    id of 0 or below cannot silently wrap around to a negative tensor index
    and mark the wrong category.

    Args:
        data: Iterable of record dicts as described above.
        num_classes: Width of the label vector. When ``None`` (the default)
            the module-level ``num_labels`` is used, which keeps existing
            ``MultiLabelDataset(data)`` call sites working.

    Raises:
        ValueError: If any label id falls outside ``1..num_classes``.
    """

    def __init__(self, data, num_classes=None):
        self.num_classes = num_labels if num_classes is None else num_classes
        self.samples = []
        for d in data:
            input_text = d["input_text"]
            labels = [int(label) for label in (d.get("labels") or [])]
            for label in labels:
                if not 1 <= label <= self.num_classes:
                    raise ValueError(
                        f"Label {label} is outside the valid range "
                        f"1..{self.num_classes}"
                    )
            self.samples.append((input_text, labels))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        text, labels = self.samples[idx]
        labels_tensor = torch.zeros(self.num_classes)
        for label in labels:
            # Category ids are 1-based; tensor positions are 0-based.
            labels_tensor[label - 1] = 1
        return text, labels_tensor

def collate_fn(batch):
    """Turn a list of ``(text, label_tensor)`` samples into one batch.

    Args:
        batch: Sequence of two-item samples as produced by the dataset.

    Returns:
        A ``(texts, labels)`` tuple where ``texts`` is a list holding each
        sample's first item and ``labels`` is the result of ``torch.stack``
        over each sample's second item.
    """
    text_column, target_column = tuple(zip(*batch))
    stacked_targets = torch.stack(target_column)
    return list(text_column), stacked_targets

# === Model Definition ===
class GAAPClassifier(pl.LightningModule):
    """Multi-label classifier trained on top of a frozen sentence encoder.

    Texts are embedded with a ``SentenceTransformer`` under ``torch.no_grad``.
    The embeddings then pass through a two-layer MLP (``attn``), a LayerNorm
    and a classification head that emits one logit per label. Training
    minimizes ``BCEWithLogitsLoss``; macro precision and recall are tracked
    with separate metric objects for the training and validation stages.

    Args:
        model_name: Name or path given to ``SentenceTransformer``.
        dropout_rate: Dropout probability used inside the head.
        num_labels: Number of output logits (one per category).
        lr: Learning rate for AdamW.
    """

    def __init__(self, model_name, dropout_rate, num_labels, lr):
        super().__init__()
        self.encoder = SentenceTransformer(model_name, device=device)
        # forward() only ever runs the encoder under no_grad, so its weights
        # can never receive gradients. Mark them frozen so they are not handed
        # to the optimizer.
        self.encoder.requires_grad_(False)
        self.dim = self.encoder.get_sentence_embedding_dimension()
        self.attn = torch.nn.Sequential(
            torch.nn.Linear(self.dim, self.dim),
            torch.nn.GELU(),
            torch.nn.Linear(self.dim, self.dim),  # Keep the same dimension
        )

        self.norm = torch.nn.LayerNorm(self.dim)

        self.num_labels = num_labels

        # Output layer for multi-label classification: one logit per label.
        self.head = torch.nn.Sequential(
            torch.nn.Linear(self.dim, 128),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(128, self.num_labels),
        )

        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.save_hyperparameters()

        # One metric object per stage so training and validation statistics
        # never accumulate into the same state.
        # NOTE(review): stage-prefixed names also avoid a bare `precision`
        # attribute, which some Lightning releases appear to use on the module
        # itself - confirm against the installed version.
        metric_kwargs = dict(
            task="multilabel", num_labels=self.num_labels, average="macro"
        )
        self.train_precision = Precision(**metric_kwargs)
        self.train_recall = Recall(**metric_kwargs)
        self.val_precision = Precision(**metric_kwargs)
        self.val_recall = Recall(**metric_kwargs)

    def forward(self, texts):
        """Return one row of ``num_labels`` logits per input text."""
        with torch.no_grad():
            embeddings = self.encoder.encode(
                texts, convert_to_tensor=True, device=device
            )
        # The encoder is run on the module-level `device`, while Lightning
        # decides where the trainable layers live; align the two before the
        # first Linear layer. This is a no-op when they already match.
        embeddings = embeddings.to(self.device)
        attended = self.attn(embeddings)
        attended = self.norm(attended)
        return self.head(attended)

    def compute_loss(self, outputs, labels):
        """Binary cross-entropy between logits and multi-hot targets."""
        return self.loss_fn(outputs, labels)

    def _shared_step(self, batch, stage, precision_metric, recall_metric):
        """Run one batch, log ``<stage>/loss|precision|recall``, return the loss."""
        texts, labels = batch
        outputs = self(texts)
        loss = self.compute_loss(outputs, labels)

        # The batch mixes a list of strings with a tensor, so state the batch
        # size explicitly instead of relying on Lightning to infer it.
        batch_size = len(texts)
        self.log(f"{stage}/loss", loss, prog_bar=True, batch_size=batch_size)

        # Threshold probabilities at 0.5; hand integer 0/1 tensors to the metrics.
        pred = (torch.sigmoid(outputs) > 0.5).int()
        target = labels.int()
        precision_metric(pred, target)
        recall_metric(pred, target)
        # Logging the metric objects (not the per-batch values) lets Lightning
        # compute and reset them, following the TorchMetrics integration guide.
        self.log(f"{stage}/precision", precision_metric, batch_size=batch_size)
        self.log(f"{stage}/recall", recall_metric, batch_size=batch_size)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(
            batch, "train", self.train_precision, self.train_recall
        )

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val", self.val_precision, self.val_recall)

    def configure_optimizers(self):
        # Only optimize parameters that can receive gradients (the encoder is frozen).
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable_params, lr=self.hparams.lr)

# === Objective Function for Optuna ===
def objective(trial):
    """Train one model with trial-sampled hyperparameters and score it.

    Samples ``batch_size``, ``lr`` and ``dropout_rate`` from the trial, then
    trains a ``GAAPClassifier`` with early stopping on ``val/loss``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        The best ``val/loss`` seen during training, as a float. Early stopping
        ends a run ``PATIENCE`` validation checks after the best one, so the
        last logged ``val/loss`` is normally worse than the best; returning
        the best value compares trials on what each configuration achieved.
        Falls back to the last logged ``val/loss`` if the callback recorded no
        finite score.
    """
    batch_size = trial.suggest_int("batch_size", 8, 64, step=8)
    lr = trial.suggest_float("lr", 1e-6, 1e-2, log=True)
    dropout_rate = trial.suggest_float("dropout_rate", 0, 0.5, step=0.1)

    # Load dataset
    train_loader = DataLoader(
        MultiLabelDataset(train_data),
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        MultiLabelDataset(val_data),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    model = GAAPClassifier(MODEL_NAME, dropout_rate, num_labels, lr)
    # Keep a handle on the callback so its best score can be read after fit().
    early_stopping = EarlyStopping(monitor="val/loss", patience=PATIENCE)
    trainer = pl.Trainer(
        max_epochs=EPOCHS,
        callbacks=[early_stopping],
        logger=TensorBoardLogger(OUTPUT_PATH),
        accelerator="auto",
        devices=1,
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    best_val_loss = early_stopping.best_score
    if best_val_loss is not None and torch.isfinite(best_val_loss):
        return best_val_loss.item()
    return trainer.callback_metrics["val/loss"].item()

# === Optuna Optimization ===
# `load_if_exists=True` only resumes a study when a fixed `study_name` is
# given; without one Optuna generates a fresh name on every run, so the SQLite
# storage would accumulate unrelated studies and never continue one.
STUDY_NAME = "gaap_classifier_hpo"
study = optuna.create_study(
    study_name=STUDY_NAME,
    direction="minimize",
    storage=f"sqlite:///{OPTUNA_DB_PATH}",
    load_if_exists=True,
)
study.optimize(objective, n_trials=200)

# Best Params
print("Best params:", study.best_params)
best_trial = study.best_trial
print(f"Best trial value: {best_trial.value}")
for k, v in best_trial.params.items():
    print(f"    {k}: {v}")
