In [None]:
import random
import os
import json
import torch
import optuna
import numpy as np
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedShuffleSplit
from sentence_transformers import SentenceTransformer
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import Callback
from torchmetrics.classification import MulticlassAccuracy

# === SEED ===
SEED = 42
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    pl.seed_everything(seed, workers=True)

# Ensure it's called!
seed_everything(SEED)

# === CONFIG ===
DATASET_PATH = "data/us_gaap_multilabel_training_data.json"
MODEL_NAME = "BAAI/bge-large-en-v1.5"
OUTPUT_PATH = "data/fine_tuned_gaap_classifier"
os.makedirs(OUTPUT_PATH, exist_ok=True)
OPTUNA_DB_PATH = os.path.join(OUTPUT_PATH, "optuna_study.db")
EPOCHS = 200
PATIENCE = 5
VALIDATION_SPLIT = 0.2

device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

with open(DATASET_PATH, "r") as f:
    data = json.load(f)

def label_hash(entry):
    return f"{entry['labels']['statement_type']}_{entry['labels']['balance']}_{entry['labels']['period_type']}"

def get_split():
    label_hashes = [label_hash(d) for d in data]
    hash_counts = Counter(label_hashes)
    common_data, common_labels, rare_data = [], [], []
    for i, lh in enumerate(label_hashes):
        if hash_counts[lh] >= 2:
            common_data.append(data[i])
            common_labels.append(lh)
        else:
            rare_data.append(data[i])
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=VALIDATION_SPLIT, random_state=42)
    train_idx, val_idx = next(splitter.split(common_data, common_labels))
    train_data = [common_data[i] for i in train_idx]
    val_data = [common_data[i] for i in val_idx]
    train_data.extend(rare_data)
    return train_data, val_data

class MultiLabelDataset(Dataset):
    def __init__(self, data):
        self.samples = [
            (d["text"], torch.tensor([
                d["labels"]["statement_type"],
                d["labels"]["balance"],
                d["labels"]["period_type"]
            ], dtype=torch.long))
            for d in data
        ]
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        return self.samples[idx]

def collate_fn(batch):
    texts, labels = zip(*batch)
    return list(texts), torch.stack(labels)

class AttentionLayer(torch.nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.attn = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, 1) # Scalar attention score
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply attention over a batch of embeddings.
        x: (batch_size, input_dim)
        returns: (batch_size, input_dim)
        """
        # Simulate sequence: [batch_size, 1, dim]
        x = x.unsqueeze(1)
        weights = torch.softmax(self.attn(x), dim=1)  # (batch_size, 1, 1)
        attended = torch.sum(weights * x, dim=1)      # (batch_size, dim)
        return attended


class GAAPClassifier(pl.LightningModule):
    def __init__(self, batch_size, lr, weight_decay, model_name, units_st, units_bal, units_pt, dropout_rate):
        super().__init__()
        self.batch_size = batch_size
        self.lr = lr
        self.weight_decay = weight_decay
        self.model_name = model_name
        self.encoder = SentenceTransformer(model_name, device=device)
        dim = self.encoder.get_sentence_embedding_dimension()

        self.attn = AttentionLayer(dim)
        self.norm = torch.nn.LayerNorm(dim)

        # self.head_st = torch.nn.Linear(dim, 3) # Dependent on number of classes
        # self.head_bal = torch.nn.Linear(dim, 3) # Dependent on number of classes
        # self.head_pt = torch.nn.Linear(dim, 3) # Dependent on number of classes

        self.head_st = torch.nn.Sequential(
            torch.nn.Linear(dim, units_st),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(units_st, 3)  # Dependent on number of classes
        )

        self.head_bal = torch.nn.Sequential(
            torch.nn.Linear(dim, units_bal),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(units_bal, 3)  # Dependent on number of classes
        )

        self.head_pt = torch.nn.Sequential(
            torch.nn.Linear(dim, units_pt),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(units_pt, 3)  # Dependent on number of classes
        )

        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.save_hyperparameters()

        self.acc_st = MulticlassAccuracy(num_classes=3, average='micro')
        self.acc_bal = MulticlassAccuracy(num_classes=3, average='micro')
        self.acc_pt = MulticlassAccuracy(num_classes=3, average='micro')


    def forward(self, texts):
        with torch.no_grad():
            embeddings = self.encoder.encode(
                texts, convert_to_tensor=True, device=device
            )

        attended = self.attn(embeddings)
        attended = self.norm(attended)

        out_st = self.head_st(attended)
        out_bal = self.head_bal(attended)
        out_pt = self.head_pt(attended)

        return out_st, out_bal, out_pt

    def compute_loss(self, outputs, labels):
        out_st, out_bal, out_pt = outputs
        loss_0 = self.loss_fn(out_st, labels[:, 0])
        loss_1 = self.loss_fn(out_bal, labels[:, 1])
        loss_2 = self.loss_fn(out_pt, labels[:, 2])
        return loss_0 + loss_1 + loss_2, (loss_0, loss_1, loss_2)

    def training_step(self, batch, batch_idx):
        texts, labels = batch
        outputs = self(texts)
        loss, (l0, l1, l2) = self.compute_loss(outputs, labels)
        self.log("train/loss", loss, prog_bar=True)
        self.log("train/loss_statement_type", l0)
        self.log("train/loss_balance", l1)
        self.log("train/loss_period_type", l2)

        pred_st = torch.argmax(outputs[0], dim=1)
        pred_bal = torch.argmax(outputs[1], dim=1)
        pred_pt = torch.argmax(outputs[2], dim=1)

        self.log("train/acc_statement_type", self.acc_st(pred_st, labels[:, 0]))
        self.log("train/acc_balance", self.acc_bal(pred_bal, labels[:, 1]))
        self.log("train/acc_period_type", self.acc_pt(pred_pt, labels[:, 2]))

        return loss

    def validation_step(self, batch, batch_idx):
        texts, labels = batch
        outputs = self(texts)
        loss, (l0, l1, l2) = self.compute_loss(outputs, labels)
        self.log("val/loss", loss, prog_bar=True)
        self.log("val/loss_statement_type", l0)
        self.log("val/loss_balance", l1)
        self.log("val/loss_period_type", l2)

        pred_st = torch.argmax(outputs[0], dim=1)
        pred_bal = torch.argmax(outputs[1], dim=1)
        pred_pt = torch.argmax(outputs[2], dim=1)

        self.log("val/acc_statement_type", self.acc_st(pred_st, labels[:, 0]))
        self.log("val/acc_balance", self.acc_bal(pred_bal, labels[:, 1]))
        self.log("val/acc_period_type", self.acc_pt(pred_pt, labels[:, 2]))

        return loss

    def configure_optimizers(self):
        params = list(self.head_st.parameters()) + \
                 list(self.head_bal.parameters()) + \
                 list(self.head_pt.parameters())
        return torch.optim.AdamW(
            params,
            lr=self.lr,
            weight_decay=self.weight_decay
        )

# class SafePruningCallback(Callback):
#     def __init__(self, trial, monitor: str = "val/loss", history_window: int = 10):
#         super().__init__()
#         self.trial = trial
#         self.monitor = monitor
#         self.history_window = history_window
#         self.recent_losses = []

#     def on_validation_end(self, trainer, pl_module):
#         current = trainer.callback_metrics.get(self.monitor)
#         if current is None:
#             return

#         self.recent_losses.append(current.item())
#         if len(self.recent_losses) > self.history_window:
#             self.recent_losses.pop(0)

#         if (
#             trainer.current_epoch >= self.history_window
#             and len(self.recent_losses) == self.history_window
#         ):
#             avg_loss = sum(self.recent_losses) / self.history_window
#             self.trial.report(avg_loss, step=trainer.current_epoch)
#             if self.trial.should_prune():
#                 print(f"Pruning trial at epoch {trainer.current_epoch} "
#                       f"with avg val/loss: {avg_loss:.4f}")
#                 raise optuna.TrialPruned()
#         else:
#             print(f"Skipping pruning at epoch {trainer.current_epoch}, "
#                   f"not enough history ({len(self.recent_losses)}/{self.history_window})")


def objective(trial):
    batch_size = trial.suggest_int("batch_size", 8, 64, step=8)
    lr = trial.suggest_float("lr", 1e-6, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 0.0, 0.1)
    units_st = trial.suggest_int("units_st", 8, 128, step=8)
    units_bal = trial.suggest_int("units_bal", 8, 128, step=8)
    units_pt = trial.suggest_int("units_bal", 8, 128, step=8) 
    dropout_rate = trial.suggest_float("dropout_rate", 0, 0.5, step=0.1)

    train_data, val_data = get_split()
    train_loader = DataLoader(MultiLabelDataset(train_data),
                              batch_size=batch_size,
                              shuffle=True,
                              collate_fn=collate_fn)
    val_loader = DataLoader(MultiLabelDataset(val_data),
                            batch_size=batch_size,
                            shuffle=False,
                            collate_fn=collate_fn)

    model = GAAPClassifier(
        batch_size=batch_size,
        lr=lr,
        weight_decay=weight_decay,
        units_st=units_st,
        units_bal=units_bal,
        units_pt=units_pt,
        dropout_rate=dropout_rate,
        model_name=MODEL_NAME
    )

    log_dir = os.path.join(OUTPUT_PATH, f"trial_{trial.number}")
    logger = TensorBoardLogger(save_dir=log_dir, name="logs")

    early_stop = EarlyStopping(monitor="val/loss", patience=PATIENCE, mode="min")
    checkpoint_cb = ModelCheckpoint(
        dirpath=os.path.join(log_dir, "checkpoints"),
        filename="best",
        monitor="val/loss",
        mode="min",
        save_top_k=1,
        save_weights_only=True
    )

    trainer = pl.Trainer(
        logger=logger,
        max_epochs=EPOCHS,
        enable_checkpointing=True,
        callbacks=[
            early_stop,
            checkpoint_cb,
            # SafePruningCallback(trial, monitor="val/loss")
        ],
        log_every_n_steps=1,
        accelerator="auto",
        devices=1
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    print("Best checkpoint:", checkpoint_cb.best_model_path)
    return trainer.callback_metrics["val/loss"].item()


Seed set to 42


Using device: mps


In [2]:
from optuna.pruners import HyperbandPruner
from optuna.samplers import TPESampler

study = optuna.create_study(
     direction="minimize",
     sampler=TPESampler(),
     study_name="gaap_tuning",
     storage=f"sqlite:///{OPTUNA_DB_PATH}",
     load_if_exists=True,
     pruner=HyperbandPruner(
        min_resource=5,    # usually epochs
        max_resource=EPOCHS,  # total epochs
        reduction_factor=3
    )
)
study.optimize(objective, n_trials=200)
print("Best params:", study.best_params)

print("Best trial:")
best_trial = study.best_trial
print("  Value: %.6f" % best_trial.value)
print("  Params:")
for k, v in best_trial.params.items():
    print("    %s: %s" % (k, v))



[I 2025-03-22 20:56:45,583] A new study created in RDB with name: gaap_tuning
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name     | Type                | Params | Mode 
---------------------------------------------------------
0 | encoder  | SentenceTransformer | 335 M  | train
1 | attn     | AttentionLayer      | 131 K  | train
2 | norm     | LayerNorm           | 2.0 K  | train
3 | head_st  | Sequential          | 131 K  | train
4 | head_bal | Sequential          | 131 K  | train
5 | head_pt  | Sequential          | 131 K  | train
6 | loss_fn  | CrossEntropyLoss    | 0      | train
7 | acc_st   | MulticlassAccuracy  | 0      | train
8 | acc_bal  | MulticlassAccuracy  | 0      | train
9 | acc_pt   | MulticlassAccuracy  | 0      | train
---------------------------------------------------------
335 M     Trainable params
0         Non-trainable params
335 M     Total params
1,342.680 Total estimated model param

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/Volumes/2TB Storage Vault/rust-sec-fetcher/python/venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Volumes/2TB Storage Vault/rust-sec-fetcher/python/venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

[I 2025-03-22 21:00:34,595] Trial 0 finished with value: 0.847109854221344 and parameters: {'batch_size': 48, 'lr': 0.000722000755763429, 'weight_decay': 0.09144624512486152, 'units_st': 128, 'units_bal': 128, 'dropout_rate': 0.1}. Best is trial 0 with value: 0.847109854221344.


Best checkpoint: /Volumes/2TB Storage Vault/rust-sec-fetcher/python/data/fine_tuned_gaap_classifier/trial_0/checkpoints/best.ckpt


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name     | Type                | Params | Mode 
---------------------------------------------------------
0 | encoder  | SentenceTransformer | 335 M  | train
1 | attn     | AttentionLayer      | 131 K  | train
2 | norm     | LayerNorm           | 2.0 K  | train
3 | head_st  | Sequential          | 115 K  | train
4 | head_bal | Sequential          | 106 K  | train
5 | head_pt  | Sequential          | 106 K  | train
6 | loss_fn  | CrossEntropyLoss    | 0      | train
7 | acc_st   | MulticlassAccuracy  | 0      | train
8 | acc_bal  | MulticlassAccuracy  | 0      | train
9 | acc_pt   | MulticlassAccuracy  | 0      | train
---------------------------------------------------------
335 M     Trainable params
0         Non-trainable params
335 M     Total params
1,342.417 Total estimated model params size (MB)
29        Modules in train mode
444       Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/Volumes/2TB Storage Vault/rust-sec-fetcher/python/venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Volumes/2TB Storage Vault/rust-sec-fetcher/python/venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...
[W 2025-03-22 21:10:02,060] Trial 1 failed with parameters: {'batch_size': 8, 'lr': 0.005796186821065959, 'weight_decay': 0.06349790279915665, 'units_st': 112, 'units_bal': 104, 'dropout_rate': 0.30000000000000004} because of the following error: NameError("name 'exit' is not defined").
Traceback (most recent call last):
  File "/Volumes/2TB Storage Vault/rust-sec-fetcher/python/venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Volumes/2TB Storage Vault/rust-sec-fetcher/python/venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 599, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/Volumes/2TB Storage Vault/rust-sec-fetcher/python/venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1012, in _run
    results = self._r

NameError: name 'exit' is not defined