In [None]:
import json
import math
import os
from collections import Counter

import torch
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.utils.tensorboard import SummaryWriter

# === CONFIG ===
DATASET_PATH = "data/us_gaap_multilabel_training_data.json"
MODEL_NAME = "BAAI/bge-large-en-v1.5"
OUTPUT_PATH = "fine_tuned_gaap_classifier"
BATCH_SIZE = 16
EPOCHS = 30  # maximum number of epochs; early stopping (PATIENCE) may end training sooner
PATIENCE = 3  # consecutive epochs without a validation-loss improvement before stopping
VALIDATION_SPLIT = 0.2  # fraction of the stratifiable samples held out for validation
CHECKPOINT_DIR = os.path.join(OUTPUT_PATH, "checkpoints")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Pick the fastest available backend: NVIDIA CUDA first, then Apple-silicon MPS,
# falling back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# === Load and stratify (with fallback for rare categories) ===
# JSON text is UTF-8 (RFC 8259); pin the encoding so loading does not depend on the
# platform's locale default (e.g. cp1252 on Windows).
with open(DATASET_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

def label_hash(entry):
    """Return a string key describing *entry*'s complete label combination.

    The ``statement_type``, ``balance`` and ``period_type`` values under
    ``entry["labels"]`` are joined with underscores in that order (for example
    ``"0_1_2"``). The key serves as the class for the stratified split.
    """
    labels = entry["labels"]
    return "_".join(
        str(labels[field]) for field in ("statement_type", "balance", "period_type")
    )

# Step 1: Fingerprint every sample by its full label combination and count them.
label_hashes = list(map(label_hash, data))
hash_counts = Counter(label_hashes)

# Step 2: StratifiedShuffleSplit rejects classes with fewer than 2 members, so
# combinations seen only once are set aside as "rare".
common_data = [entry for entry, lh in zip(data, label_hashes) if hash_counts[lh] >= 2]
common_labels = [lh for lh in label_hashes if hash_counts[lh] >= 2]
rare_data = [entry for entry, lh in zip(data, label_hashes) if hash_counts[lh] < 2]

# Step 3: Stratified split of the common samples (fixed seed for reproducibility).
# NOTE(review): sklearn also raises ValueError if the validation size ends up smaller
# than the number of distinct label combinations — relevant for very small datasets.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=VALIDATION_SPLIT, random_state=42)
train_idx, val_idx = next(splitter.split(common_data, common_labels))

# Step 4: Rare samples still contribute to training; they never reach validation.
train_data = [common_data[i] for i in train_idx] + rare_data
val_data = [common_data[i] for i in val_idx]

# === Build custom dataset ===
class MultiLabelDataset(Dataset):
    """Map-style dataset of ``(text, label_tensor)`` pairs built from raw records.

    Each record needs a ``"text"`` entry and a ``"labels"`` mapping with
    ``statement_type``, ``balance`` and ``period_type`` keys. Those three values
    are packed, in that order, into a 1-D ``torch.long`` tensor. All samples are
    built eagerly in ``__init__``.
    """

    # Column order of the label values inside every target tensor.
    _LABEL_KEYS = ("statement_type", "balance", "period_type")

    def __init__(self, data):
        self.samples = []
        for record in data:
            text = record["text"]
            label_values = [record["labels"][key] for key in self._LABEL_KEYS]
            self.samples.append((text, torch.tensor(label_values, dtype=torch.long)))

    def __len__(self):
        """Return how many samples the dataset holds."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the stored ``(text, label_tensor)`` pair at position *idx*."""
        return self.samples[idx]

def collate_batch(batch):
    """Combine ``(text, label_tensor)`` samples into a single DataLoader batch.

    Returns the texts as a plain list (in batch order) together with the label
    tensors stacked along a new leading batch dimension.
    """
    text_column, label_column = zip(*batch)
    stacked_labels = torch.stack(label_column)
    return [*text_column], stacked_labels

# Sentence encoder providing the fixed-size text embeddings fed to the head.
model = SentenceTransformer(MODEL_NAME, device=device)

# Wrap the split records; only the training loader shuffles.
train_dataset, val_dataset = MultiLabelDataset(train_data), MultiLabelDataset(val_data)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_batch,
)

loss_fn = torch.nn.CrossEntropyLoss()
writer = SummaryWriter()

# === Classifier head ===
# A single linear layer emits 9 logits: three consecutive groups of 3 classes,
# one group per label (statement_type, balance, period_type).
embedding_dim = model.get_sentence_embedding_dimension()
classifier = torch.nn.Linear(embedding_dim, 3 * 3).to(device)
optimizer = torch.optim.AdamW(
    classifier.parameters(),
    lr=2e-5,
    weight_decay=0.01,
)

# === Training loop ===
# Only the linear head is trained: `optimizer` holds nothing but
# `classifier.parameters()`, and `SentenceTransformer.encode` computes embeddings
# without autograd, so the encoder weights never change. The embeddings are
# therefore computed once up front instead of re-encoding every text each epoch.

# Logit columns owned by each head, in label order (statement_type, balance, period_type).
_HEAD_SLICES = (slice(0, 3), slice(3, 6), slice(6, 9))


def _multitask_loss(outputs, labels):
    """Return the sum of the three per-head cross-entropy losses.

    ``outputs`` are the 9-column logits produced by ``classifier``; column ``i``
    of ``labels`` holds the class index targeted by head ``i``.
    """
    return sum(
        loss_fn(outputs[:, cols], labels[:, head])
        for head, cols in enumerate(_HEAD_SLICES)
    )


def _encode_split(loader):
    """Embed every text yielded by *loader* exactly once.

    Returns ``(features, labels)`` CPU tensors with aligned rows: row ``k`` of
    ``features`` is the embedding of the text whose targets are row ``k`` of
    ``labels``.
    """
    feature_chunks, label_chunks = [], []
    for texts, labels in loader:
        embeddings = model.encode(texts, convert_to_tensor=True, device=device)
        feature_chunks.append(embeddings.cpu())
        label_chunks.append(labels)
    return torch.cat(feature_chunks), torch.cat(label_chunks)


model.eval()  # frozen encoder: embed in evaluation mode
train_features, train_targets = _encode_split(train_loader)
val_features, val_targets = _encode_split(val_loader)

train_feature_loader = DataLoader(
    TensorDataset(train_features, train_targets), batch_size=BATCH_SIZE, shuffle=True
)
val_feature_loader = DataLoader(
    TensorDataset(val_features, val_targets), batch_size=BATCH_SIZE
)

best_val_loss = float("inf")
patience_counter = 0

try:
    for epoch in range(EPOCHS):
        classifier.train()
        total_loss = 0.0

        for embeddings, labels in train_feature_loader:
            embeddings, labels = embeddings.to(device), labels.to(device)
            loss = _multitask_loss(classifier(embeddings), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_feature_loader)
        writer.add_scalar("Loss/Train", avg_train_loss, epoch)

        # === Validation ===
        classifier.eval()
        val_loss = 0.0
        correct = [0, 0, 0]  # per-head count of correctly predicted classes
        total = 0

        with torch.no_grad():
            for embeddings, labels in val_feature_loader:
                embeddings, labels = embeddings.to(device), labels.to(device)
                outputs = classifier(embeddings)
                val_loss += _multitask_loss(outputs, labels).item()
                for head, cols in enumerate(_HEAD_SLICES):
                    preds = outputs[:, cols].argmax(dim=1)
                    correct[head] += (preds == labels[:, head]).sum().item()
                total += labels.size(0)

        avg_val_loss = val_loss / len(val_feature_loader)
        writer.add_scalar("Loss/Val", avg_val_loss, epoch)
        writer.add_scalar("Accuracy/Statement", correct[0] / total, epoch)
        writer.add_scalar("Accuracy/Balance", correct[1] / total, epoch)
        writer.add_scalar("Accuracy/Period", correct[2] / total, epoch)

        print(f"[Epoch {epoch}] Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        print(f"  Statement Acc: {correct[0]/total:.4f}, Balance Acc: {correct[1]/total:.4f}, Period Acc: {correct[2]/total:.4f}")

        # Early stopping on validation loss. The best head is checkpointed here;
        # the last-epoch head is saved separately after the loop.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(classifier.state_dict(), os.path.join(CHECKPOINT_DIR, "best_classifier.pt"))
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print("Early stopping triggered.")
                break

    torch.save(classifier.state_dict(), os.path.join(OUTPUT_PATH, "final_classifier.pt"))
finally:
    # Flush buffered TensorBoard events and release the event file, even if
    # training is interrupted.
    writer.close()
print("✅ Finished training.")
