In [1]:
import pandas as pd
import torch
import torch.nn as nn
from datasets import Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import AutoModel, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
from transformers import  Trainer, TrainingArguments

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [2]:
def get_company_dataset(file_path, tokenizer, max_length=128, test_size=0.2):
    """
    Reads a CSV file, processes it for ModernBERT Multi-Head classification,
    and returns Train/Val datasets + dimensions for the model heads.
    """
    # 1. Load Data
    df = pd.read_csv(f"{company}.csv")

    # 2. Label Encoding (Local to this company/dataset)
    # If you need global consistency across companies, pass fitted encoders instead.
    type_encoder = LabelEncoder()
    code_encoder = LabelEncoder()

    df['cc_type_id'] = type_encoder.fit_transform(df['cc_type'])
    df['cc_code_id'] = code_encoder.fit_transform(df['cc_code'])

    # Calculate dimensions for the model heads
    num_type_labels = len(type_encoder.classes_)
    num_code_labels = len(code_encoder.classes_)

    print(f"Dataset Loaded: {len(df)} records")
    print(f"Found {num_type_labels} Transaction Types and {num_code_labels} GL Codes.")

    # 3. Stratified Train/Test Split
    # We create a temporary 'stratify_col' to ensure both Type and Code distributions are preserved
    df['stratify_col'] = df['cc_type'].astype(str) + "_" + df['cc_code'].astype(str)

    train_df, val_df = train_test_split(
        df,
        test_size=test_size,
        random_state=42,
        stratify=df['stratify_col']  # Critical for rare GL codes
    )

    # Cleanup auxiliary columns
    cols_to_keep = ['merchant_group', 'merchant_name', 'cc_type_id', 'cc_code_id']
    train_df = train_df[cols_to_keep]
    val_df = val_df[cols_to_keep]

    # 4. Convert to Hugging Face Datasets
    train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
    val_dataset = Dataset.from_pandas(val_df, preserve_index=False)

    # 5. Tokenization Function
    def preprocess_function(examples):
        # Create input: "[CLS] Merchant Group [SEP] Merchant Name [SEP]"
        tokenized_inputs = tokenizer(
            text=examples["merchant_group"],
            text_pair=examples["merchant_name"],
            truncation=True,
            max_length=max_length,
            padding="max_length"
        )

        # Map to specific arguments expected by our custom ModernBertMultiHead
        tokenized_inputs["labels_type"] = examples["cc_type_id"]
        tokenized_inputs["labels_code"] = examples["cc_code_id"]

        return tokenized_inputs

    # 6. Apply Processing
    # We remove the text columns to leave only the tensors
    remove_cols = train_dataset.column_names

    train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=remove_cols)
    val_dataset = val_dataset.map(preprocess_function, batched=True, remove_columns=remove_cols)

    # 7. Set Format for PyTorch
    target_columns = ["input_ids", "attention_mask", "labels_type", "labels_code"]
    train_dataset.set_format(type="torch", columns=target_columns)
    val_dataset.set_format(type="torch", columns=target_columns)

    return {
        "train": train_dataset,
        "val": val_dataset,
        "num_type_labels": num_type_labels,
        "num_code_labels": num_code_labels,
        "encoders": {"type": type_encoder, "code": code_encoder}
    }

In [3]:
class ModernBertMultiHead(nn.Module):
    def __init__(self, model_name, num_type_labels, num_code_labels):
        super().__init__()
        # Load base ModernBERT model
        self.bert = AutoModel.from_pretrained(model_name)
        
        self.config = self.bert.config

        hidden_size = self.bert.config.hidden_size
        print(hidden_size)
        # Define two separate heads
        self.type_head = nn.Linear(hidden_size, num_type_labels)
        self.code_head = nn.Linear(hidden_size, num_code_labels)

        # Dropout for regularization
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask, labels_type=None, labels_code=None, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # ModernBERT typically uses Mean Pooling or [CLS] (index 0)
        # We use index 0 (CLS-equivalent) for classification
        sequence_output = outputs.last_hidden_state
        pooled_output = sequence_output[:, 0, :]
        pooled_output = self.dropout(pooled_output)

        # Get logits from both heads
        logits_type = self.type_head(pooled_output)
        logits_code = self.code_head(pooled_output)

        loss = None
        if labels_type is not None and labels_code is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss_type = loss_fct(logits_type, labels_type)
            loss_code = loss_fct(logits_code, labels_code)

            loss = (2.0 * loss_type) + (1.0 * loss_code) 
           # loss = loss_type + loss_code  # Sum losses

        return {"loss": loss, "logits_type": logits_type, "logits_code": logits_code}


In [4]:
# 1. Setup
model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
companies = ["company_A", "company_B", "company_C"]  # Your list of companies

#Define LoRA Config
# modules_to_save is CRITICAL here. It ensures your custom heads are trainable.
peft_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,  # Using custom model, so generic task
    inference_mode=False,
    r=64,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["Wqkv", "Wo", "Wi", "W2"],  # ModernBERT target modules (verify exact names)
    modules_to_save=["type_head", "code_head"]  # TRAIN THESE LAYERS
)


In [None]:
for company in companies:
    print(f"Training adapter for: {company}")

    # A. Get Data for this Company
    # get_company_dataset returns a formatted HF Dataset
    # Format: "Merchant Group: {grp} [SEP] Merchant Name: {name}"
    # Get the data bundle
    data_bundle = get_company_dataset(company, tokenizer)


    # 2. Extract components
    train_dataset = data_bundle["train"]
    val_dataset = data_bundle["val"]
    num_type_labels = data_bundle["num_type_labels"]
    num_code_labels = data_bundle["num_code_labels"]
    base_model = ModernBertMultiHead(model_id, num_type_labels, num_code_labels)
    model = get_peft_model(base_model, peft_config)
    # B. Manage Adapters
    # If it's the first run, the 'default' adapter is active.
    # For subsequent runs, we add a new adapter.
    adapter_name = f"adapter_{company}"

    try:
        model.add_adapter(adapter_name, peft_config)
    except ValueError:
        pass  # Adapter might already exist if resuming

    model.set_adapter(adapter_name)

    # C. Train
    training_args = TrainingArguments(
        output_dir=f"./results/{company}",
        per_device_train_batch_size=64,
        num_train_epochs=5,
        save_strategy="no",  # We save manually to be safe
        learning_rate=1e-3,
        remove_unused_columns=False  # Important for custom models
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )

    trainer.train()

    metrics = trainer.evaluate(eval_dataset=val_dataset)
    print(metrics)
    
    # D. Save Adapter & Heads
    # This saves the LoRA weights AND the 'modules_to_save' (heads) to the folder
    model.save_pretrained(f"./final_adapters/{company}")

    # E. Cleanup to free VRAM for next company
    # delete_adapter removes the LoRA weights from memory
    # Note: 'modules_to_save' weights might persist in the base model state dict
    # if not carefully handled, but delete_adapter handles the PEFT part.
    model.delete_adapter(adapter_name)



Training adapter for: company_A
Dataset Loaded: 10000 records
Found 8 Transaction Types and 25 GL Codes.


Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [00:00<00:00, 30685.97 examples/s]
Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:00<00:00, 28603.70 examples/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████| 134/134 [00:00<00:00, 1020.18it/s, Materializing param=layers.21.mlp_norm.weight]
ModernBertModel LOAD REPORT from: answerdotai/ModernBERT-base
Key               | Status     |  | 
------------------+------------+--+-
decoder.bias      | UNEXPECTED |  | 
head.norm.weight  | UNEXPECTED |  | 
head.dense.weight | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


768


Step,Training Loss


In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
from torchmetrics.classification import AUROC

# 1. Get Raw Predictions from Trainer
# predictions is a tuple: ( (logits_type, logits_code), label_ids, metrics )
# or a dict if your model output dict was unpacked differently. 
# For custom models, it usually returns the tuple of logits.
predictions_output = trainer.predict(val_dataset)

# Unpack logits (Trainer returns them as a tuple if model returns dict/tuple)
# Note: Check the order! It matches your forward() return or output dict order.
# Based on my previous code: return {"loss": loss, "logits_type": ..., "logits_code": ...}
# The Trainer often returns logits as a tuple: (logits_type, logits_code)
logits_type = predictions_output.predictions[0] 
logits_code = predictions_output.predictions[1]

# Unpack True Labels (Trainer stacks them in order of dataset columns)
# If your dataset has 'labels_type' and 'labels_code', Trainer might aggregate them.
# SAFER WAY: Extract directly from dataset to be 100% sure of order
true_types = val_dataset['labels_type']
true_codes = val_dataset['labels_code']

# 2. Convert Logits to Class IDs
pred_types = np.argmax(logits_type, axis=1)
pred_codes = np.argmax(logits_code, axis=1)

# 3. Generate Confusion Matrices
cm_type = confusion_matrix(true_types, pred_types)
cm_code = confusion_matrix(true_codes, pred_codes)

# --- Visualization Function ---
def plot_cm(cm, class_names, title):
    fig, ax = plt.subplots(figsize=(10, 8))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    
    # Filter for readability if you have too many classes
    if len(class_names) > 20:
        plt.title(f"{title} (Showing first 20 classes)", fontsize=14)
        disp.plot(ax=ax, cmap='Blues', xticks_rotation=45, include_values=False) # Hide numbers if dense
    else:
        plt.title(title, fontsize=14)
        disp.plot(ax=ax, cmap='Blues', xticks_rotation=45)
    
    plt.show()

# 4. Plot
# Get class names from your encoders
type_classes = data_bundle["encoders"]["type"].classes_
code_classes = data_bundle["encoders"]["code"].classes_

print("Generating Confusion Matrix for Transaction Types...")
plot_cm(cm_type, type_classes, "Confusion Matrix: Transaction Types")

print("Generating Confusion Matrix for GL Codes...")
plot_cm(cm_code, code_classes, "Confusion Matrix: GL Codes")

precision_types = precision_score(true_types, pred_types, average='weighted')
recall_types = recall_score(true_types, pred_types, average='weighted')
f1_types = f1_score(true_types, pred_types, average='weighted')
print("Precision, Recall F1 Types % % %", precision_types, recall_types, f1_types)
precision_codes = precision_score(true_codes, pred_codes, average='weighted')
recall_codes = recall_score(true_codes, pred_codes, average='weighted')
f1_codes = f1_score(true_codes, pred_codes, average='weighted')
print("Precision, Recall F1 Codes % % %", precision_codes, recall_codes, f1_codes)
