In [None]:
import random
import os
import json
import torch
import optuna
import numpy as np
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from sentence_transformers import SentenceTransformer
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torchmetrics import Precision, Recall


# === SEED ===
SEED = 42


def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch random number generators, then Lightning.

    Args:
        seed: Seed value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)

    # Seeds the CPU generator (recent PyTorch releases also seed accelerators here).
    torch.manual_seed(seed)

    # Seed accelerator generators explicitly so the result does not depend on the
    # PyTorch version. torch.mps only exists in newer releases, hence the guard.
    if torch.backends.mps.is_available() and hasattr(torch, "mps"):
        torch.mps.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # workers=True additionally derives per-worker seeds for DataLoader workers.
    pl.seed_everything(seed, workers=True)


# Ensure it's called!
seed_everything(SEED)

# === CONFIG ===
TRAIN_JSONL_PATH = "data/train.jsonl"  # training records, one JSON object per line
VAL_JSONL_PATH = "data/val.jsonl"  # validation records, one JSON object per line
MODEL_NAME = "BAAI/bge-large-en-v1.5"
OUTPUT_PATH = "data/fine_tuned_gaap_classifier"
os.makedirs(OUTPUT_PATH, exist_ok=True)
OPTUNA_DB_PATH = os.path.join(OUTPUT_PATH, "optuna_study.db")
EPOCHS = 20  # maximum epochs per Optuna trial
PATIENCE = 5  # early-stopping patience, in epochs without val/loss improvement

# Prefer CUDA, then Apple MPS, then CPU. The Trainer runs with accelerator="auto",
# which picks CUDA when present. GAAPClassifier creates its encoder on this
# `device`, so it must agree with the accelerator Lightning chooses.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# === Load Data from JSONL files ===
def load_jsonl(filepath):
    """Read a JSON Lines file and return its records as a list.

    The file is decoded as UTF-8 explicitly rather than with the platform's default
    encoding. Blank or whitespace-only lines are skipped, so a stray empty line
    (e.g. a doubled trailing newline) does not abort loading with a JSONDecodeError.

    Args:
        filepath: Path to the ``.jsonl`` file.

    Returns:
        One decoded JSON value per non-blank line, in file order.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

train_data = load_jsonl(TRAIN_JSONL_PATH)
val_data = load_jsonl(VAL_JSONL_PATH)
all_data = train_data + val_data

# === Dynamically determine the number of possible categories ===
# Category labels are 1-based ids and may be stored as strings (MultiLabelDataset
# converts them with int()). Cast here too, because max() over strings compares
# lexicographically ("9" > "10"). A missing or None label list counts as empty.
all_categories = {
    int(label) for entry in all_data for label in (entry.get("labels") or [])
}
if not all_categories:
    raise ValueError("No category labels found in the training/validation data")
if min(all_categories) < 1:
    # MultiLabelDataset maps label k to output column k - 1, so 0 or negative ids
    # would wrap around to the last columns instead of failing.
    raise ValueError(
        f"Category labels must be 1-based, but found {min(all_categories)}"
    )

num_labels = max(all_categories)  # highest 1-based id = width of the multi-hot target
print(f"Number of categories: {num_labels}")

# Count distinct balance/period type ids across both splits.
# NOTE(review): if nn.Embedding layers are ever built from these counts, the table
# needs max(id) + 1 rows, which equals the distinct count only when ids are 0..n-1.
num_balance_types = len({d["balance_type_id"] for d in all_data})
num_period_types = len({d["period_type_id"] for d in all_data})

print(f"number of balance types: {num_balance_types}")
print(f"number of period types: {num_period_types}")


# === Dataset Class ===
class MultiLabelDataset(Dataset):
    """Dataset yielding (text, multi-hot labels, balance_type_id, period_type_id).

    Every record must provide ``input_text``, ``balance_type_id`` and
    ``period_type_id``. ``labels`` is an optional list of 1-based category ids
    (ints or int-like strings). A missing, ``None`` or empty ``labels`` produces
    an all-zero target.

    Args:
        data: Records as loaded from the JSONL files.
        num_labels: Width of the multi-hot target. If omitted, the module-level
            ``num_labels`` is used (looked up when an item is accessed).
    """

    def __init__(self, data, num_labels=None):
        self._num_labels = num_labels
        self.samples = []
        for d in data:
            # Normalise labels to ints; absent/None/empty -> no positive categories.
            labels = [int(label) for label in (d.get("labels") or [])]
            self.samples.append(
                (d["input_text"], labels, d["balance_type_id"], d["period_type_id"])
            )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        text, labels, balance_type_id, period_type_id = self.samples[idx]

        size = self._num_labels if self._num_labels is not None else num_labels
        labels_tensor = torch.zeros(size)
        for label in labels:
            # Without this check, label 0 or a negative label would index from the
            # end of the tensor and silently mark the wrong category.
            if not 1 <= label <= size:
                raise ValueError(
                    f"Label {label} of sample {idx} is outside the 1-based range 1..{size}"
                )
            labels_tensor[label - 1] = 1  # category k -> column k - 1

        return text, labels_tensor, balance_type_id, period_type_id


def collate_fn(batch):
    """Combine a list of dataset samples into one training batch.

    Args:
        batch: Sequence of ``(text, labels_tensor, balance_type_id, period_type_id)``.

    Returns:
        ``(texts, labels, balance_type_ids, period_type_ids)``. ``texts`` is a list
        of strings, ``labels`` is the per-sample label tensors stacked along a new
        first dimension, and the two id columns are ``torch.long`` tensors.
    """
    texts, label_rows, balance_ids, period_ids = zip(*batch)
    stacked_labels = torch.stack(label_rows)
    id_tensors = [torch.tensor(ids, dtype=torch.long) for ids in (balance_ids, period_ids)]
    return (list(texts), stacked_labels, *id_tensors)


# === Model Definition ===
class GAAPClassifier(pl.LightningModule):
    """Multi-label GAAP classifier: frozen SentenceTransformer encoder + trainable MLP head.

    Texts are embedded with ``model_name`` under ``torch.no_grad()``. Only the head
    is trained, using ``BCEWithLogitsLoss`` against a multi-hot target of width
    ``num_labels``.

    Args:
        model_name: SentenceTransformer model name or path used as the encoder.
        dropout_rate: Dropout probability inside the head.
        num_labels: Number of output logits (categories).
        batch_size: Stored as ``self.batch_size`` and in ``hparams``; not otherwise used here.
        lr: AdamW learning rate (read back from ``hparams`` in ``configure_optimizers``).
        num_balance_types: Saved in ``hparams`` only; the balance-type embedding is disabled.
        num_period_types: Saved in ``hparams`` only; the period-type embedding is disabled.
    """

    def __init__(self, model_name, dropout_rate, num_labels, batch_size, lr, num_balance_types, num_period_types):
        super().__init__()

        self.batch_size = batch_size
        self.lr = lr

        self.encoder = SentenceTransformer(model_name, device=device)
        # The encoder only ever runs under torch.no_grad() (see forward). Freezing it
        # keeps its parameters out of the optimizer and the trainable-param count.
        self.encoder.requires_grad_(False)
        self.dim = self.encoder.get_sentence_embedding_dimension()

        # Disabled experiment: embeddings for balance_type_id / period_type_id.
        # self.balance_type_embedding = torch.nn.Embedding(num_balance_types, 8)
        # self.period_type_embedding = torch.nn.Embedding(num_period_types, 8)

        # Disabled experiment: self-attention over the sentence embeddings.
        # self.attn = torch.nn.MultiheadAttention(embed_dim=self.dim, num_heads=1)
        # self.norm = torch.nn.LayerNorm(self.dim)

        self.num_labels = num_labels

        # Multi-label head: one independent logit per category.
        self.head = torch.nn.Sequential(
            # torch.nn.Linear(self.dim + 8 + 8, 128),  # input width with both 8-dim id embeddings
            torch.nn.Linear(self.dim, 128),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(128, self.num_labels),
        )

        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.save_hyperparameters()

        # One metric instance per stage. torchmetrics objects accumulate state across
        # calls, so sharing them between training and validation would mix the stages.
        metric_args = {"num_labels": self.num_labels, "average": "macro", "task": "multilabel"}
        self.train_precision = Precision(**metric_args)
        self.train_recall = Recall(**metric_args)
        self.val_precision = Precision(**metric_args)
        self.val_recall = Recall(**metric_args)

        # Initialise only the freshly created head. self.apply(...) would also visit
        # every nn.Linear inside the pretrained encoder and overwrite its weights.
        self.head.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-uniform weights and zero bias for ``nn.Linear`` modules; others are untouched."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def forward(self, texts, balance_type_ids, period_type_ids):
        """Return raw logits, one row per text and ``num_labels`` columns.

        ``balance_type_ids`` and ``period_type_ids`` are accepted but currently
        ignored, because the layers that would consume them are disabled.
        """
        with torch.no_grad():
            # Run the encoder on the module's own device (where Lightning placed the
            # head). encode() runs on, and may move the encoder to, the device it is
            # given, so a fixed global device could split the model across devices.
            embeddings = self.encoder.encode(texts, convert_to_tensor=True, device=str(self.device))

        # Disabled experiment: attention + normalisation over the embeddings.
        # attended, _ = self.attn(embeddings.unsqueeze(0), embeddings.unsqueeze(0), embeddings.unsqueeze(0))
        # attended = self.norm(attended.squeeze(0))

        # Disabled experiment: concatenate the id embeddings.
        # balance_embedding = self.balance_type_embedding(balance_type_ids)
        # period_embedding = self.period_type_embedding(period_type_ids)
        # embeddings = torch.cat((embeddings, balance_embedding, period_embedding), dim=-1)

        return self.head(embeddings)

    def compute_loss(self, outputs, labels):
        """BCE-with-logits loss; ``outputs`` are raw logits (the sigmoid is applied inside)."""
        return self.loss_fn(outputs, labels)

    def _shared_step(self, batch, batch_idx, stage, print_label, precision_metric, recall_metric):
        """Forward, loss and metric logging for one batch; ``stage`` prefixes the log keys."""
        texts, labels, balance_type_id, period_type_id = batch
        outputs = self(texts, balance_type_id, period_type_id)  # raw logits

        if batch_idx % 100 == 0:
            print(f"{print_label} logits range: min={outputs.min()}, max={outputs.max()}")

        loss = self.compute_loss(outputs, labels)
        # The batch contains a list of strings; give Lightning the batch size explicitly.
        batch_size = len(texts)
        self.log(f"{stage}/loss", loss, prog_bar=True, batch_size=batch_size)

        # A logit of 0 corresponds to a sigmoid probability of 0.5. Thresholding raw
        # logits at 0.5 would correspond to a probability of about 0.62.
        preds = outputs > 0
        target = labels.int()
        precision_metric(preds, target)
        recall_metric(preds, target)
        self.log(f"{stage}/precision", precision_metric, batch_size=batch_size)
        self.log(f"{stage}/recall", recall_metric, batch_size=batch_size)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, "train", "Train", self.train_precision, self.train_recall)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, "val", "Val", self.val_precision, self.val_recall)

    def configure_optimizers(self):
        """AdamW over the trainable parameters only (the encoder is frozen)."""
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable_params, lr=self.hparams.lr, weight_decay=1e-5)

# === Objective Function for Optuna ===
def objective(trial):
    """Optuna objective: train one GAAPClassifier and return its best validation loss.

    Samples ``batch_size``, ``lr`` and ``dropout_rate``, then trains for up to
    ``EPOCHS`` epochs with early stopping on ``val/loss`` (patience ``PATIENCE``).
    Returns the lowest ``val/loss`` observed, which the study minimises.
    """
    batch_size = trial.suggest_int("batch_size", 8, 64, step=8)
    lr = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    dropout_rate = trial.suggest_float("dropout_rate", 0, 0.5, step=0.1)

    train_loader = DataLoader(MultiLabelDataset(train_data),
                              batch_size=batch_size,
                              shuffle=True,
                              collate_fn=collate_fn)
    val_loader = DataLoader(MultiLabelDataset(val_data),
                            batch_size=batch_size,
                            shuffle=False,
                            collate_fn=collate_fn)

    model = GAAPClassifier(MODEL_NAME, dropout_rate, num_labels, batch_size, lr, num_balance_types, num_period_types)

    early_stopping = EarlyStopping(monitor="val/loss", patience=PATIENCE)
    trainer = pl.Trainer(
        max_epochs=EPOCHS,
        callbacks=[early_stopping],
        # One log directory per trial, so TensorBoard runs map back to Optuna trials.
        logger=TensorBoardLogger(OUTPUT_PATH, version=f"trial_{trial.number}"),
        accelerator="auto",
        devices=1,
        gradient_clip_val=0.75,
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # callback_metrics holds the *last* epoch's val/loss. After early stopping that
    # epoch can be up to PATIENCE epochs past the best one, so report the best score.
    best = early_stopping.best_score
    if torch.isfinite(best):
        return best.item()
    # best_score stays infinite if val/loss never improved on it (e.g. NaN losses).
    return trainer.callback_metrics["val/loss"].item()

# === Optuna Optimization ===
# load_if_exists can only resume a study with a fixed study_name. With
# study_name=None, Optuna generates a new unique name on every run, so trials
# already stored in the SQLite file would never be loaded.
STUDY_NAME = "gaap_classifier"
N_TRIALS = 50

study = optuna.create_study(
    study_name=STUDY_NAME,
    direction="minimize",
    storage=f"sqlite:///{OPTUNA_DB_PATH}",
    load_if_exists=True,
    # Seed the sampler with the project-wide SEED for reproducible sampling.
    sampler=optuna.samplers.TPESampler(seed=SEED),
)
study.optimize(objective, n_trials=N_TRIALS)

# Best Params
best_trial = study.best_trial
print("Best params:", best_trial.params)
print(f"Best trial value: {best_trial.value}")
for k, v in best_trial.params.items():
    print(f"    {k}: {v}")
