<span style="font-weight: bold; font-size: 18px;">**Multi-Label Posture Classification: Model Development Strategy**<br><br>

We propose a comparative evaluation of two complementary modeling approaches to address the multi-label posture prediction task, each offering distinct advantages for legal document classification.

**Baseline Approach: Bag-of-Words Models**<br>

Our initial baseline leverages traditional bag-of-words representations (TF-IDF, BM25) combined with multi-label classifiers, justified by several key factors:

<div style="margin-left: 20px;"><b>• Computational Efficiency:</b> Lightweight architecture enables rapid prototyping and establishes performance baselines without GPU requirements</div>
<div style="margin-left: 20px;"><b>• Statistical Robustness:</b> Word-frequency features provide interpretable, domain-agnostic representations suitable for legal terminology analysis</div>
<div style="margin-left: 20px;"><b>• Multi-Label Compatibility:</b> Well-established integration with multi-label algorithms (One-vs-Rest, Binary Relevance, Label Powerset)</div>
<div style="margin-left: 20px;"><b>• Baseline Establishment:</b> Provides interpretable performance benchmarks for evaluating more complex architectures</div>

**Advanced Approach: Transformer-Based Models (ModernBERT)**<br>

Our primary model leverages ModernBERT encoder architecture, specifically designed to address the limitations of traditional BERT for our use case:

<div style="margin-left: 20px;"><b>• Extended Context Coverage:</b> ModernBERT's 8,192-token context window accommodates ~90% of our corpus without truncation, preserving critical legal context that may span entire documents</div>

<div style="margin-left: 20px;"><b>• Contextual Understanding:</b> Unlike bag-of-words approaches, transformer architectures capture:
  <div style="margin-left: 40px;">- Long-range dependencies between legal arguments</div>
  <div style="margin-left: 40px;">- Positional relationships between procedural elements</div>
  <div style="margin-left: 40px;">- Semantic nuances distinguishing similar posture categories</div>
</div>

<div style="margin-left: 20px;"><b>• Multi-Label Architecture:</b> The encoder's [CLS] token representation can be effectively coupled with multi-label classification heads, enabling simultaneous prediction of multiple postures</div>

<div style="margin-left: 20px;"><b>• Legal Domain Adaptation:</b> Pre-trained language understanding provides superior handling of complex legal terminology and document structure</div>

**Comparative Justification:**<br>

This dual-approach strategy enables comprehensive evaluation of feature representation impact on multi-label performance, ranging from traditional statistical methods to state-of-the-art contextual understanding, ultimately identifying the optimal balance between computational efficiency and classification accuracy for legal posture prediction.

</span>

In [None]:
%%capture --no-stderr
%pip install -r /mnt/d/TR-Project/requirements.txt

## Bag-of-word (TFIDF): Benchmark

In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import pickle
import xgboost as xgb
import lightgbm as lgb
from lightgbm import early_stopping, log_evaluation
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
# from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, hamming_loss
from sklearn.base import BaseEstimator, ClassifierMixin

from sklearn.metrics import (
    precision_score, recall_score, f1_score, 
    roc_auc_score, average_precision_score,
    hamming_loss, jaccard_score, accuracy_score
)
from sklearn.preprocessing import MultiLabelBinarizer

import warnings
warnings.filterwarnings('ignore')

In [2]:
### Load Dataset for model training and evaluation ###
data_path=os.path.join(os.getcwd(), 'processed_data')
with open(os.path.join(data_path,'train_arrays.pkl'), 'rb') as f:
    train_data = pickle.load(f)
    X_train = train_data['X_train']
    y_train = train_data['y_train']

with open(os.path.join(data_path,'val_arrays.pkl'), 'rb') as f:
    val_data = pickle.load(f)
    X_val = val_data['X_val']
    y_val = val_data['y_val']

with open(os.path.join(data_path,'test_arrays.pkl'), 'rb') as f:
    test_data = pickle.load(f)
    X_test = test_data['X_test']
    y_test = test_data['y_test']

with open(os.path.join(data_path,'class_name.pkl'), 'rb') as f:
    class_name_data = pickle.load(f)
    class_name = class_name_data['class_name']


In [3]:
# Create TF-IDF vectorizer
# Using parameters optimized for legal text
tfidf = TfidfVectorizer(
    max_features=10000,  # Limit features for computational efficiency
    stop_words='english',
    ngram_range=(1, 2),  # Include unigrams and bigrams
    min_df=5,           # Ignore terms that appear in fewer than 5 documents
    max_df=0.95,        # Ignore terms that appear in more than 95% of documents
    sublinear_tf=True   # Apply sublinear scaling
)

print("Fitting TF-IDF vectorizer...")
X_train_tfidf = tfidf.fit_transform(X_train)
X_val_tfidf = tfidf.transform(X_val)
X_test_tfidf = tfidf.transform(X_test)

print(f"TF-IDF matrix shape (train): {X_train_tfidf.shape}")
print(f"TF-IDF matrix shape (val): {X_val_tfidf.shape}")
print(f"TF-IDF matrix shape (test): {X_test_tfidf.shape}")
print(f"Vocabulary size: {len(tfidf.vocabulary_)}")

# Show some sample features
feature_names = tfidf.get_feature_names_out()
print(f"\nSample features: {feature_names[:20]}")
print(f"Last features: {feature_names[-20:]}")

Fitting TF-IDF vectorizer...
TF-IDF matrix shape (train): (11597, 10000)
TF-IDF matrix shape (val): (2485, 10000)
TF-IDF matrix shape (test): (2486, 10000)
Vocabulary size: 10000

Sample features: ['00' '000' '000 00' '000 000' '001' '01' '010' '02' '020' '03' '030' '04'
 '040' '05' '06' '07' '08' '09' '10' '10 000']
Last features: ['years prior' 'years prison' 'years supervised' 'yes' 'yes sir' 'yield'
 'york' 'york city' 'york county' 'york law' 'york state' 'young'
 'younger' 'youth' 'zba' 'zero' 'zone' 'zoning' 'zoning board'
 'zoning ordinance']


In [4]:
class Train_XGBoost(BaseEstimator, ClassifierMixin):
    """XGBoost classifier with validation-based early stopping for multi-label"""
    
    def __init__(self, **xgb_params):
        self.xgb_params = xgb_params
        self.models_ = []
        self.n_classes_ = None
        
    def fit(self, X, y, X_val=None, y_val=None):
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)
        if X_val is not None and len(y_val.shape) == 1:
            y_val = y_val.reshape(-1, 1)
            
        self.n_classes_ = y.shape[1]
        self.models_ = []
        
        for i in tqdm(range(self.n_classes_), total=self.n_classes_, leave=True, position=0):
            
            y_single = y[:, i]
            
            # Skip if no positive samples
            if y_single.sum() == 0:
                self.models_.append(None)
                continue
            
            model = xgb.XGBClassifier(**self.xgb_params)
            
            if X_val is not None and y_val is not None:
                y_val_single = y_val[:, i]
                model.fit(
                    X, y_single,
                    eval_set=[(X_val, y_val_single)],
                    verbose=False
                )
            else:
                model.fit(X, y_single)
            
            self.models_.append(model)
        
        return self
    
    def predict(self, X):
        predictions = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                predictions[:, i] = model.predict(X)
        
        return predictions
    
    def predict_proba(self, X):
        probabilities = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                proba = model.predict_proba(X)
                # Handle case where only one class is present
                if proba.shape[1] == 1:
                    probabilities[:, i] = 0  # All negative class
                else:
                    probabilities[:, i] = proba[:, 1]  # Positive class probability
        
        return probabilities

class Train_LGBM(BaseEstimator, ClassifierMixin):
    """LightGBM classifier with validation-based early stopping for multi-label"""
    
    def __init__(self, **lgb_params):
        self.lgb_params = lgb_params
        self.models_ = []
        self.n_classes_ = None
        
    def fit(self, X, y, X_val=None, y_val=None):
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)
        if X_val is not None and len(y_val.shape) == 1:
            y_val = y_val.reshape(-1, 1)
            
        self.n_classes_ = y.shape[1]
        self.models_ = []
        
        for i in tqdm(range(self.n_classes_), total=self.n_classes_, leave=True, position=0):
            
            y_single = y[:, i]
            
            # Skip if no positive samples
            if y_single.sum() == 0:
                self.models_.append(None)
                continue
            
            model = lgb.LGBMClassifier(**self.lgb_params)
            
            if X_val is not None and y_val is not None:
                y_val_single = y_val[:, i]
                model.fit(
                    X, y_single,
                    eval_set=[(X_val, y_val_single)],
                    callbacks=[
                        lgb.early_stopping(10, verbose=False),
                        lgb.log_evaluation(0)
                    ]
                )
            else:
                model.fit(X, y_single)
            
            self.models_.append(model)
        
        return self
    
    def predict(self, X):
        predictions = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                predictions[:, i] = model.predict(X)
        
        return predictions
    
    def predict_proba(self, X):
        probabilities = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                proba = model.predict_proba(X)
                # Handle case where only one class is present
                if proba.shape[1] == 1:
                    probabilities[:, i] = 0  # All negative class
                else:
                    probabilities[:, i] = proba[:, 1]  # Positive class probability
        
        return probabilities

class Train_logistic(BaseEstimator, ClassifierMixin):
    """Logistic Regression classifier with validation monitoring for multi-label"""
    
    def __init__(self, **lr_params):
        self.lr_params = lr_params
        self.models_ = []
        self.n_classes_ = None
        self.validation_scores_ = []
        
    def fit(self, X, y, X_val=None, y_val=None):
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)
        if X_val is not None and len(y_val.shape) == 1:
            y_val = y_val.reshape(-1, 1)
            
        self.n_classes_ = y.shape[1]
        self.models_ = []
        self.validation_scores_ = []
        
        for i in tqdm(range(self.n_classes_), total=self.n_classes_, leave=True, position=0):
            
            y_single = y[:, i]
            
            # Skip if no positive samples
            if y_single.sum() == 0:
                self.models_.append(None)
                self.validation_scores_.append(0.0)
                continue
            
            model = LogisticRegression(**self.lr_params)
            model.fit(X, y_single)
            
            # Calculate validation score if validation data provided
            if X_val is not None and y_val is not None:
                y_val_single = y_val[:, i]
                val_score = model.score(X_val, y_val_single)
                self.validation_scores_.append(val_score)
            else:
                self.validation_scores_.append(None)
            
            self.models_.append(model)
        
        return self
    
    def predict(self, X):
        predictions = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                predictions[:, i] = model.predict(X)
        
        return predictions
    
    def predict_proba(self, X):
        probabilities = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                proba = model.predict_proba(X)
                # Handle case where only one class is present
                if proba.shape[1] == 1:
                    probabilities[:, i] = 0  # All negative class
                else:
                    probabilities[:, i] = proba[:, 1]  # Positive class probability
        
        return probabilities
    
    def get_validation_scores(self):
        """Return validation scores for each label"""
        return self.validation_scores_

class Train_RandomForest(BaseEstimator, ClassifierMixin):
    """Random Forest classifier with validation monitoring for multi-label"""
    
    def __init__(self, **rf_params):
        self.rf_params = rf_params
        self.models_ = []
        self.n_classes_ = None
        self.validation_scores_ = []
        self.feature_importances_ = []
        
    def fit(self, X, y, X_val=None, y_val=None):
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)
        if X_val is not None and len(y_val.shape) == 1:
            y_val = y_val.reshape(-1, 1)
            
        self.n_classes_ = y.shape[1]
        self.models_ = []
        self.validation_scores_ = []
        self.feature_importances_ = []
        
        for i in tqdm(range(self.n_classes_), total=self.n_classes_, leave=True, position=0):
            
            y_single = y[:, i]
            
            # Skip if no positive samples
            if y_single.sum() == 0:
                self.models_.append(None)
                self.validation_scores_.append(0.0)
                self.feature_importances_.append(None)
                continue
            
            model = RandomForestClassifier(**self.rf_params)
            model.fit(X, y_single)
            
            # Store feature importances
            self.feature_importances_.append(model.feature_importances_)
            
            # Calculate validation score if validation data provided
            if X_val is not None and y_val is not None:
                y_val_single = y_val[:, i]
                val_score = model.score(X_val, y_val_single)
                self.validation_scores_.append(val_score)
            else:
                self.validation_scores_.append(None)
            
            self.models_.append(model)
        
        return self
    
    def predict(self, X):
        predictions = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                predictions[:, i] = model.predict(X)
        
        return predictions
    
    def predict_proba(self, X):
        probabilities = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                proba = model.predict_proba(X)
                # Handle case where only one class is present
                if proba.shape[1] == 1:
                    probabilities[:, i] = 0  # All negative class
                else:
                    probabilities[:, i] = proba[:, 1]  # Positive class probability
        
        return probabilities
    
    def get_validation_scores(self):
        """Return validation scores for each label"""
        return self.validation_scores_
    
    def get_feature_importances(self):
        """Return feature importances for each label"""
        return self.feature_importances_

In [5]:
def training_function_with_validation(X_train, y_train, X_val, y_val, model_type='lightgbm'):
    """
    Enhanced training function with proper validation control for multi-label classification
    """
    
    print(f"Training {model_type} with validation control...")
    print(f"X_train shape: {X_train.shape}")
    print(f"y_train shape: {y_train.shape}")
    print(f"X_val shape: {X_val.shape}")
    print(f"y_val shape: {y_val.shape}")
    
    if model_type == 'lightgbm':
        model = Train_LGBM(
            random_state=42,
            n_estimators=200,  # More estimators for early stopping
            max_depth=6,
            learning_rate=0.1,
            subsample=0.8,
            colsample_bytree=0.8,
            verbosity=-1,
            early_stopping_rounds=10
        )
    elif model_type == 'xgboost':
        model = Train_XGBoost(
            random_state=42,
            n_estimators=200,  # More estimators for early stopping
            max_depth=6,
            learning_rate=0.1,
            subsample=0.8,
            colsample_bytree=0.8,
            eval_metric='logloss',
            verbosity=0,
            early_stopping_rounds=10
        )
    elif model_type == 'logistic':
        model = Train_logistic(
            random_state=42,
            max_iter=1000,
            C=1.0,
            solver='liblinear',
            class_weight='balanced'  # Handle class imbalance
        )
    elif model_type == 'randomforest':
        model = Train_RandomForest(
            random_state=42,
            n_estimators=100,
            max_depth=10,
            min_samples_split=5,
            min_samples_leaf=2,
            max_features='sqrt',
            class_weight='balanced',  # Handle class imbalance
            n_jobs=-1
        )
    else:
        raise ValueError("Supported model types: 'lightgbm', 'xgboost', 'logistic', 'randomforest'")
    
    # Fit with validation data
    model.fit(X_train, y_train, X_val, y_val)
    
    # Make predictions
    y_pred_train = model.predict(X_train)
    y_pred_val = model.predict(X_val)

    # For ROC-AUC and PR-AUC, need probability estimates
    if hasattr(model, "predict_proba"):
        y_prob_val = model.predict_proba(X_val)
    else:
        y_prob_val = None
    
    # Calculate metrics
    # train_acc = accuracy_score(y_train, y_pred_train)
    val_acc = accuracy_score(y_val, y_pred_val)
    # train_f1 = f1_score(y_train, y_pred_train, average='micro')


    val_f1_samples = f1_score(y_val, y_pred_val, average='samples', zero_division=0)
    val_f1_micro = f1_score(y_val, y_pred_val, average='micro', zero_division=0)
    val_f1_macro = f1_score(y_val, y_pred_val, average='macro', zero_division=0)
    val_f1_weighted = f1_score(y_val, y_pred_val, average='weighted', zero_division=0)

    val_jaccard_samples = jaccard_score(y_val, y_pred_val, average='samples', zero_division=0)
    val_jaccard_macro = jaccard_score(y_val, y_pred_val, average='macro', zero_division=0)
    val_jaccard_weighted = jaccard_score(y_val, y_pred_val, average='weighted', zero_division=0)

    # Calculate hamming loss (lower is better)
    train_hamming = hamming_loss(y_train, y_pred_train)
    val_hamming = hamming_loss(y_val, y_pred_val)
    
    # Calculate overfitting gaps for different metrics
    hamming_gap = val_hamming - train_hamming  # Note: val - train because lower hamming is better

    if y_prob_val is not None:
        try:
            val_roc_auc_macro = roc_auc_score(y_val, y_prob_val, average="macro")
            val_roc_auc_micro = roc_auc_score(y_val, y_prob_val, average="micro")
            val_roc_auc_weighted = roc_auc_score(y_val, y_prob_val, average="weighted")

        except ValueError as e:
            print(f"Warning: ROC-AUC calculation failed: {e}")
            val_roc_auc_macro=0.0
            val_roc_auc_micro=0.0
            val_roc_auc_weighted=0.0
        try:
            val_pr_auc_macro = average_precision_score(y_val, y_prob_val, average="macro")
            val_pr_auc_micro = average_precision_score(y_val, y_prob_val, average="weighted")
            val_pr_auc_weighted = average_precision_score(y_val, y_prob_val, average="micro")
        except ValueError as e:
            print(f"Warning: PR-AUC calculation failed: {e}")
            val_pr_auc_macro=0.0
            val_pr_auc_micro=0.0
            val_pr_auc_weighted=0.0
    else:
        val_roc_auc_macro=0.0
        val_roc_auc_micro=0.0
        val_roc_auc_weighted=0.0
        val_pr_auc_macro=0.0
        val_pr_auc_micro=0.0
        val_pr_auc_weighted=0.0
    
    print(f"Training completed!")
    print(f"Val Accuracy: {val_acc:.4f}")
    print(f"Val F1 samples: {val_f1_samples:.4f}")
    print(f"Val F1 macro: {val_f1_macro:.4f}")
    print(f"Val F1 micro: {val_f1_micro:.4f}")
    print(f"Val F1 weighted: {val_f1_weighted:.4f}")
    print(f"Train Hamming Loss: {train_hamming:.4f}")
    print(f"Val Hamming Loss: {val_hamming:.4f}")
    print(f"Overfitting Gap (Hamming): {hamming_gap:.4f}")
    
    return model, {
        'val_accuracy': val_acc,
        'val_f1_samples': val_f1_samples,
        'val_f1_macro': val_f1_macro,
        'val_f1_micro': val_f1_micro,
        'val_f1_weighted': val_f1_weighted,
        'val_hamming_loss': val_hamming,
        "val_jaccard_samples":val_jaccard_samples,
        "val_jaccard_macro":val_jaccard_macro,
        "val_jaccard_weighted":val_jaccard_weighted,        
        'val_roc_auc_macro': val_roc_auc_macro,
        'val_roc_auc_micro': val_roc_auc_micro,
        'val_roc_auc_weighted': val_roc_auc_weighted,
        'val_pr_auc_macro': val_pr_auc_macro,
        'val_pr_auc_micro': val_pr_auc_micro,
        'val_pr_auc_weighted': val_pr_auc_weighted,
    }


In [6]:
# Comprehensive Model Comparison with Validation Control

def compare_all_models(X_train, y_train, X_val, y_val, X_test, y_test):
    """
    Train and compare all models with validation control
    """
    
    print("🚀 COMPREHENSIVE MODEL COMPARISON WITH VALIDATION CONTROL")
    print("="*80)
    
    models_to_test = ['logistic', 'randomforest', 'lightgbm', 'xgboost']
    results = {}
    
    for model_type in models_to_test:
        print(f"\n{'='*60}")
        print(f"🔧 Training {model_type.upper()} Model")
        print(f"{'='*60}")
        
        try:
            # Train model with validation
            model, metrics = training_function_with_validation(
                X_train, y_train, X_val, y_val, model_type=model_type
            )
            
            # Test on unseen data
            y_pred_test = model.predict(X_test)
            # For ROC-AUC and PR-AUC, need probability estimates
            if hasattr(model, "predict_proba"):
                y_prob_test = model.predict_proba(X_test)
            else:
                y_prob_test = None
        
            test_acc = accuracy_score(y_test, y_pred_test)
            test_f1_samples = f1_score(y_test, y_pred_test, average='samples', zero_division=0)
            test_f1_micro = f1_score(y_test, y_pred_test, average='micro', zero_division=0)
            test_f1_macro = f1_score(y_test, y_pred_test, average='macro', zero_division=0)
            test_f1_weighted = f1_score(y_test, y_pred_test, average='weighted', zero_division=0)
            test_hamming = hamming_loss(y_test, y_pred_test)

            test_jaccard_samples = jaccard_score(y_test, y_pred_test, average='samples', zero_division=0)
            test_jaccard_macro = jaccard_score(y_test, y_pred_test, average='macro', zero_division=0)
            test_jaccard_weighted = jaccard_score(y_test, y_pred_test, average='weighted', zero_division=0)

            if y_prob_test is not None:
                try:
                    test_roc_auc_macro = roc_auc_score(y_test, y_prob_test, average="macro")
                    test_roc_auc_micro = roc_auc_score(y_test, y_prob_test, average="micro")
                    test_roc_auc_weighted = roc_auc_score(y_test, y_prob_test, average="weighted")

                except ValueError as e:
                    print(f"Warning: ROC-AUC calculation failed: {e}")
                    test_roc_auc_macro=0.0
                    test_roc_auc_micro=0.0
                    test_roc_auc_weighted=0.0
                try:
                    test_pr_auc_macro = average_precision_score(y_test, y_prob_test, average="macro")
                    test_pr_auc_micro = average_precision_score(y_test, y_prob_test, average="weighted")
                    test_pr_auc_weighted = average_precision_score(y_test, y_prob_test, average="micro")
                except ValueError as e:
                    print(f"Warning: PR-AUC calculation failed: {e}")
                    test_pr_auc_macro=0.0
                    test_pr_auc_micro=0.0
                    test_pr_auc_weighted=0.0
            else:
                test_roc_auc_macro=0.0
                test_roc_auc_micro=0.0
                test_roc_auc_weighted=0.0
                test_pr_auc_macro=0.0
                test_pr_auc_micro=0.0
                test_pr_auc_weighted=0.0            
            
            # Store all results
            results[model_type] = {
                'model': model,
                'val_accuracy': metrics['val_accuracy'],
                'test_accuracy': test_acc,
                'val_f1_samples': metrics['val_f1_samples'],
                'val_f1_micro': metrics['val_f1_micro'],
                'val_f1_macro': metrics['val_f1_macro'],
                'val_f1_weighted': metrics['val_f1_weighted'],
                'test_f1_samples': test_f1_samples,
                'test_f1_micro': test_f1_micro,
                'test_f1_macro': test_f1_macro,
                'test_f1_weighted': test_f1_weighted,
                'val_hamming_loss': metrics['val_hamming_loss'],
                'test_hamming_loss': test_hamming,
                'val_jaccard_samples': metrics['val_jaccard_samples'],
                'val_jaccard_macro': metrics['val_jaccard_macro'],
                'val_jaccard_weighted': metrics['val_jaccard_weighted'],
                'test_jaccard_samples': test_jaccard_samples,
                'test_jaccard_macro': test_jaccard_macro,
                'test_jaccard_weighted': test_jaccard_weighted,
                'val_roc_auc_macro': metrics['val_roc_auc_macro'],
                'val_roc_auc_micro': metrics['val_roc_auc_micro'],
                'val_roc_auc_weighted': metrics['val_roc_auc_weighted'],
                'val_pr_auc_macro': metrics['val_pr_auc_macro'],
                'val_pr_auc_micro': metrics['val_pr_auc_micro'],
                'val_pr_auc_weighted': metrics['val_pr_auc_weighted'],
                'test_roc_auc_macro': test_roc_auc_macro,
                'test_roc_auc_micro': test_roc_auc_micro,
                'test_roc_auc_weighted': test_roc_auc_weighted,
                'test_pr_auc_macro': test_pr_auc_macro,
                'test_pr_auc_micro': test_pr_auc_micro,
                'test_pr_auc_weighted': test_pr_auc_weighted,

            }
            
            print(f"✅ {model_type.upper()} completed successfully!")
            print(f"   Test Accuracy: {test_acc:.4f}")
            print(f"   test_f1_samples: {test_f1_samples:.4f}")
            print(f"   test_f1_macro: {test_f1_macro:.4f}")
            print(f"   test_f1_micro: {test_f1_micro:.4f}")
            print(f"   test_f1_weighted: {test_f1_weighted:.4f}")
            print(f"   Test Hamming Loss: {test_hamming:.4f}")
            print(f"   test_jaccard_samples: {test_jaccard_samples:.4f}")
            print(f"   test_jaccard_macro: {test_jaccard_macro:.4f}")
            print(f"   test_jaccard_weighted: {test_jaccard_weighted:.4f}")
            print(f"   test_roc_auc_macro: {test_roc_auc_macro:.4f}")
            print(f"   test_roc_auc_micro: {test_roc_auc_micro:.4f}")
            print(f"   test_roc_auc_weighted: {test_roc_auc_weighted:.4f}")
            print(f"   test_pr_auc_macro: {test_pr_auc_macro:.4f}")
            print(f"   test_pr_auc_micro: {test_pr_auc_micro:.4f}")
            print(f"   test_pr_auc_weighted: {test_pr_auc_weighted:.4f}")
        except Exception as e:
            print(f"❌ Error training {model_type}: {str(e)}")
            results[model_type] = None
    
    return results


In [13]:
def analyze_model_results(results):
    """
    Analyze and display comprehensive results
    """

    # Filter successful results
    successful_results = {k: v for k, v in results.items() if v is not None}
    if not successful_results:
        print("❌ No models trained successfully!")
        return

    # Define column widths for perfect alignment
    col_widths = {
        "Model": 15,
        "Acc": 9,
        "Ham": 9,
        "f1_macro": 13,
        "f1_weighted": 16,
        "roc_auc_macro": 18,
        "roc_auc_weighted": 20,
        "pr_auc_macro": 17,
        "pr_auc_weighted": 19,
        "jaccard_macro": 18,
        "jaccard_weighted": 20
    }

    # Validation Set Table
    print("\n" + "="*120)
    print("📊 Model Evaluation in Validation Set")
    print("="*120 + "\n")

    val_header = (
        f"{'Model':<{col_widths['Model']}} | "
        f"{'Val Acc':<{col_widths['Acc']}} | "
        f"{'Val Ham':<{col_widths['Ham']}} | "
        f"{'val_f1_macro':<{col_widths['f1_macro']}} | "
        f"{'val_f1_weighted':<{col_widths['f1_weighted']}} | "
        f"{'val_roc_auc_macro':<{col_widths['roc_auc_macro']}} | "
        f"{'val_roc_auc_weighted':<{col_widths['roc_auc_weighted']}} | "
        f"{'val_pr_auc_macro':<{col_widths['pr_auc_macro']}} | "
        f"{'val_pr_auc_weighted':<{col_widths['pr_auc_weighted']}} | "
        # f"{'val_jaccard_macro':<{col_widths['jaccard_macro']}} | "
        # f"{'val_jaccard_weighted':<{col_widths['jaccard_weighted']}}"
    )
    print(val_header)
    print("-" * len(val_header))

    for model_name, result in successful_results.items():
        print(
            f"{model_name.upper():<{col_widths['Model']}} | "
            f"{result['val_accuracy']:<{col_widths['Acc']}.4f} | "
            f"{result['val_hamming_loss']:<{col_widths['Ham']}.4f} | "
            f"{result['val_f1_macro']:<{col_widths['f1_macro']}.4f} | "
            f"{result['val_f1_weighted']:<{col_widths['f1_weighted']}.4f} | "
            f"{result['val_roc_auc_macro']:<{col_widths['roc_auc_macro']}.4f} | "
            f"{result['val_roc_auc_weighted']:<{col_widths['roc_auc_weighted']}.4f} | "
            f"{result['val_pr_auc_macro']:<{col_widths['pr_auc_macro']}.4f} | "
            f"{result['val_pr_auc_weighted']:<{col_widths['pr_auc_weighted']}.4f} | "
            # f"{result['val_jaccard_macro']:<{col_widths['jaccard_macro']}.4f} | "
            # f"{result['val_jaccard_weighted']:<{col_widths['jaccard_weighted']}.4f}"
        )

    print("\n" + "="*120)
    print("📊 Model Evaluation in Test Set")
    print("="*120 + "\n")

    test_header = (
        f"{'Model':<{col_widths['Model']}} | "
        f"{'test Acc':<{col_widths['Acc']}} | "
        f"{'test Ham':<{col_widths['Ham']}} | "
        f"{'test_f1_macro':<{col_widths['f1_macro']}} | "
        f"{'test_f1_weighted':<{col_widths['f1_weighted']}} | "
        f"{'test_roc_auc_macro':<{col_widths['roc_auc_macro']}} | "
        f"{'test_roc_auc_weighted':<{col_widths['roc_auc_weighted']}} | "
        f"{'test_pr_auc_macro':<{col_widths['pr_auc_macro']}} | "
        f"{'test_pr_auc_weighted':<{col_widths['pr_auc_weighted']}} | "
        # f"{'test_jaccard_macro':<{col_widths['jaccard_macro']}} | "
        # f"{'test_jaccard_weighted':<{col_widths['jaccard_weighted']}}"
    )
    print(test_header)
    print("-" * len(test_header))

    for model_name, result in successful_results.items():
        print(
            f"{model_name.upper():<{col_widths['Model']}} | "
            f"{result['test_accuracy']:<{col_widths['Acc']}.4f} | "
            f"{result['test_hamming_loss']:<{col_widths['Ham']}.4f} | "
            f"{result['test_f1_macro']:<{col_widths['f1_macro']}.4f} | "
            f"{result['test_f1_weighted']:<{col_widths['f1_weighted']}.4f} | "
            f"{result['test_roc_auc_macro']:<{col_widths['roc_auc_macro']}.4f} | "
            f"{result['test_roc_auc_weighted']:<{col_widths['roc_auc_weighted']}.4f} | "
            f"{result['test_pr_auc_macro']:<{col_widths['pr_auc_macro']}.4f} | "
            f"{result['test_pr_auc_weighted']:<{col_widths['pr_auc_weighted']}.4f} | "
            # f"{result['test_jaccard_macro']:<{col_widths['jaccard_macro']}.4f} | "
            # f"{result['test_jaccard_weighted']:<{col_widths['jaccard_weighted']}.4f}"
        )
    return successful_results

In [9]:
# Example usage
print("Starting Model Training...")
all_results = compare_all_models(X_train_tfidf, y_train, X_val_tfidf, y_val, X_test_tfidf, y_test)


Starting Model Training...
🚀 COMPREHENSIVE MODEL COMPARISON WITH VALIDATION CONTROL

🔧 Training LOGISTIC Model
Training logistic with validation control...
X_train shape: (11597, 10000)
y_train shape: (11597, 27)
X_val shape: (2485, 10000)
y_val shape: (2485, 27)


100%|██████████| 27/27 [00:21<00:00,  1.25it/s]


Training completed!
Val Accuracy: 0.4978
Val F1 samples: 0.8097
Val F1 macro: 0.6277
Val F1 micro: 0.7936
Val F1 weighted: 0.8241
Train Hamming Loss: 0.0199
Val Hamming Loss: 0.0269
Overfitting Gap (Hamming): 0.0070
✅ LOGISTIC completed successfully!
   Test Accuracy: 0.4702
   test_f1_samples: 0.7981
   test_f1_macro: 0.6184
   test_f1_micro: 0.7859
   test_f1_weighted: 0.8188
   Test Hamming Loss: 0.0281
   test_jaccard_samples: 0.7203
   test_jaccard_macro: 0.4815
   test_jaccard_weighted: 0.7227
   test_roc_auc_macro: 0.9772
   test_roc_auc_micro: 0.9885
   test_roc_auc_weighted: 0.9746
   test_pr_auc_macro: 0.6688
   test_pr_auc_micro: 0.8618
   test_pr_auc_weighted: 0.8556

🔧 Training RANDOMFOREST Model
Training randomforest with validation control...
X_train shape: (11597, 10000)
y_train shape: (11597, 27)
X_val shape: (2485, 10000)
y_val shape: (2485, 27)


100%|██████████| 27/27 [00:15<00:00,  1.75it/s]


Training completed!
Val Accuracy: 0.5127
Val F1 samples: 0.7891
Val F1 macro: 0.3862
Val F1 micro: 0.7894
Val F1 weighted: 0.7666
Train Hamming Loss: 0.0114
Val Hamming Loss: 0.0245
Overfitting Gap (Hamming): 0.0130
✅ RANDOMFOREST completed successfully!
   Test Accuracy: 0.5270
   test_f1_samples: 0.7884
   test_f1_macro: 0.3861
   test_f1_micro: 0.7901
   test_f1_weighted: 0.7671
   Test Hamming Loss: 0.0242
   test_jaccard_samples: 0.7232
   test_jaccard_macro: 0.2972
   test_jaccard_weighted: 0.6723
   test_roc_auc_macro: 0.9641
   test_roc_auc_micro: 0.9840
   test_roc_auc_weighted: 0.9649
   test_pr_auc_macro: 0.6091
   test_pr_auc_micro: 0.8343
   test_pr_auc_weighted: 0.8438

🔧 Training LIGHTGBM Model
Training lightgbm with validation control...
X_train shape: (11597, 10000)
y_train shape: (11597, 27)
X_val shape: (2485, 10000)
y_val shape: (2485, 27)


100%|██████████| 27/27 [02:41<00:00,  5.97s/it]


Training completed!
Val Accuracy: 0.6149
Val F1 samples: 0.8339
Val F1 macro: 0.5954
Val F1 micro: 0.8340
Val F1 weighted: 0.8183
Train Hamming Loss: 0.0016
Val Hamming Loss: 0.0181
Overfitting Gap (Hamming): 0.0165
✅ LIGHTGBM completed successfully!
   Test Accuracy: 0.6070
   test_f1_samples: 0.8205
   test_f1_macro: 0.5729
   test_f1_micro: 0.8248
   test_f1_weighted: 0.8096
   Test Hamming Loss: 0.0190
   test_jaccard_samples: 0.7675
   test_jaccard_macro: 0.4471
   test_jaccard_weighted: 0.7212
   test_roc_auc_macro: 0.9621
   test_roc_auc_micro: 0.9874
   test_roc_auc_weighted: 0.9755
   test_pr_auc_macro: 0.6197
   test_pr_auc_micro: 0.8605
   test_pr_auc_weighted: 0.8836

🔧 Training XGBOOST Model
Training xgboost with validation control...
X_train shape: (11597, 10000)
y_train shape: (11597, 27)
X_val shape: (2485, 10000)
y_val shape: (2485, 27)


100%|██████████| 27/27 [08:49<00:00, 19.61s/it]


Training completed!
Val Accuracy: 0.6205
Val F1 samples: 0.8357
Val F1 macro: 0.5861
Val F1 micro: 0.8368
Val F1 weighted: 0.8162
Train Hamming Loss: 0.0027
Val Hamming Loss: 0.0176
Overfitting Gap (Hamming): 0.0149
✅ XGBOOST completed successfully!
   Test Accuracy: 0.6219
   test_f1_samples: 0.8273
   test_f1_macro: 0.5815
   test_f1_micro: 0.8331
   test_f1_weighted: 0.8133
   Test Hamming Loss: 0.0179
   test_jaccard_samples: 0.7767
   test_jaccard_macro: 0.4584
   test_jaccard_weighted: 0.7255
   test_roc_auc_macro: 0.9797
   test_roc_auc_micro: 0.9920
   test_roc_auc_weighted: 0.9780
   test_pr_auc_macro: 0.6835
   test_pr_auc_micro: 0.8706
   test_pr_auc_weighted: 0.9148


In [14]:
print("Starting comprehensive model comparison...")
print()
final_analysis = analyze_model_results(all_results)

Starting comprehensive model comparison...


📊 Model Evaluation in Validation Set

Model           | Val Acc   | Val Ham   | val_f1_macro  | val_f1_weighted  | val_roc_auc_macro  | val_roc_auc_weighted | val_pr_auc_macro  | val_pr_auc_weighted | 
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
LOGISTIC        | 0.4978    | 0.0269    | 0.6277        | 0.8241           | 0.9800             | 0.9765               | 0.6757            | 0.8560              | 
RANDOMFOREST    | 0.5127    | 0.0245    | 0.3862        | 0.7666           | 0.9655             | 0.9676               | 0.6136            | 0.8480              | 
LIGHTGBM        | 0.6149    | 0.0181    | 0.5954        | 0.8183           | 0.9522             | 0.9753               | 0.6355            | 0.8891              | 
XGBOOST         | 0.6205    | 0.0176    | 0.5861        | 0.8162           | 0.9778             |

In [15]:
import json

# Suppose your dictionary is named 'results'
# Remove non-serializable objects (like model instances) before saving
results_to_save = {}
for k, v in all_results.items():
    results_to_save[k] = {key: value for key, value in v.items() if key != 'model'}

# Save to JSON file
with open('TFIDF-model.json', 'w') as f:
    json.dump(results_to_save, f, indent=4)

In [None]:
# # Load from JSON file
# with open('TFIDF-model.json', 'r') as f:
#     loaded_results = json.load(f)

# print(loaded_results)

## Transformers Encoder Model (MordenBERT)

### you need GPU to run the following script

In [16]:
import os
import gc
import pandas as pd
import numpy as np
import pickle
from datasets import Dataset, DatasetDict
from datasets import Sequence, Value
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
from transformers import DataCollatorWithPadding
from transformers import EarlyStoppingCallback
from transformers import TrainerCallback
from transformers import set_seed
import evaluate
import argparse
from functools import partial
from sklearn.metrics import (
    precision_score, recall_score, f1_score, 
    roc_auc_score, average_precision_score,
    hamming_loss, jaccard_score, accuracy_score
)

import warnings
warnings.filterwarnings('ignore')

In [17]:
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.random.manual_seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    set_seed(seed)

In [18]:
def create_datasets_from_arrays(X_train, y_train, X_val=None, y_val=None, X_test=None, y_test=None):
    """
    Convert arrays into HuggingFace datasets format with specified structure
    
    Returns:
        DatasetDict with features:
        - dataset["train"]["text"]: text data
        - dataset["train"]["labels"]: multi-label arrays
        - dataset["val"]["text"]: validation text data (if provided)
        - dataset["val"]["labels"]: validation labels (if provided)
        - dataset["test"]["text"]: test text data (if provided)
        - dataset["test"]["labels"]: test labels (if provided)
    """
    # Create training dataset
    train_dict = {
        "text": X_train.tolist() if hasattr(X_train, 'tolist') else list(X_train),
        "labels": y_train.tolist() if hasattr(y_train, 'tolist') else list(y_train)
    }
    
    datasets_dict = {
        "train": Dataset.from_dict(train_dict)
    }
    
    # Add validation dataset if provided
    if X_val is not None and y_val is not None:
        val_dict = {
            "text": X_val.tolist() if hasattr(X_val, 'tolist') else list(X_val),
            "labels": y_val.tolist() if hasattr(y_val, 'tolist') else list(y_val)
        }
        datasets_dict["val"] = Dataset.from_dict(val_dict)
    
    # Add test dataset if provided
    if X_test is not None and y_test is not None:
        test_dict = {
            "text": X_test.tolist() if hasattr(X_test, 'tolist') else list(X_test),
            "labels": y_test.tolist() if hasattr(y_test, 'tolist') else list(y_test)
        }
        datasets_dict["test"] = Dataset.from_dict(test_dict)

    # Create DatasetDict
    dataset = DatasetDict(datasets_dict)
    
    return dataset

def preprocess_function(examples,tokenizer,max_length):
    """
    Proper tokenization function for multi-label classification.
    Ensures all outputs are compatible with HuggingFace Trainer.
    """
    # Handle batch vs single example
    if isinstance(examples['text'], str):
        texts = [examples['text']]
        labels = [examples['labels']]
    else:
        texts = examples['text']
        labels = examples['labels']
    
    # Tokenize the texts
    tokenized = tokenizer(
        texts,
        truncation=True,
        padding=True,  # Will be handled by data collator
        max_length=max_length,  # Adjust based on your model's limit
        return_tensors=None  # Don't return tensors yet, let data collator handle it
    )
    
    # Ensure labels are float32 for BCEWithLogitsLoss
    if isinstance(labels[0], (list, np.ndarray)):
        tokenized['labels'] = [np.array(label, dtype=np.float32).tolist() for label in labels]
    else:
        tokenized['labels'] = [np.array(labels, dtype=np.float32).tolist()]
    
    return tokenized

def sigmoid(x):
    """Sigmoid activation function"""
    return 1/(1 + np.exp(-x))

In [19]:
def comprehensive_evaluation(y_true, y_pred_proba, y_pred_binary=None, threshold=0.5):
    """
    Comprehensive evaluation for multi-label classification with all averaging methods
    
    Args:
        y_true: Ground truth binary labels (n_samples, n_labels)
        y_pred_proba: Predicted probabilities (n_samples, n_labels)
        y_pred_binary: Predicted binary labels (n_samples, n_labels), optional
        threshold: Threshold for converting probabilities to binary (default: 0.5)
    
    Returns:
        dict: Comprehensive metrics including all averaging methods
    """
    if y_pred_binary is None:
        y_pred_binary = (y_pred_proba >= threshold).astype(int)
    
    metrics = {}
    
    try:
        # SAMPLES AVERAGE (per-sample then average across samples)
        # metrics['precision_samples'] = precision_score(y_true, y_pred_binary, average='samples', zero_division=0)
        # metrics['recall_samples'] = recall_score(y_true, y_pred_binary, average='samples', zero_division=0)
        metrics['f1_samples'] = f1_score(y_true, y_pred_binary, average='samples', zero_division=0)
        
        # MICRO AVERAGE (global aggregation)
        # metrics['precision_micro'] = precision_score(y_true, y_pred_binary, average='micro', zero_division=0)
        # metrics['recall_micro'] = recall_score(y_true, y_pred_binary, average='micro', zero_division=0)
        metrics['f1_micro'] = f1_score(y_true, y_pred_binary, average='micro', zero_division=0)
        
        # MACRO AVERAGE (unweighted average across labels)
        # metrics['precision_macro'] = precision_score(y_true, y_pred_binary, average='macro', zero_division=0)
        # metrics['recall_macro'] = recall_score(y_true, y_pred_binary, average='macro', zero_division=0)
        metrics['f1_macro'] = f1_score(y_true, y_pred_binary, average='macro', zero_division=0)
        
        # WEIGHTED AVERAGE (weighted by support)
        # metrics['precision_weighted'] = precision_score(y_true, y_pred_binary, average='weighted', zero_division=0)
        # metrics['recall_weighted'] = recall_score(y_true, y_pred_binary, average='weighted', zero_division=0)
        metrics['f1_weighted'] = f1_score(y_true, y_pred_binary, average='weighted', zero_division=0)
        
        # ACCURACY METRICS
        metrics['accuracy'] = accuracy_score(y_true, y_pred_binary)
        metrics['hamming_loss'] = hamming_loss(y_true, y_pred_binary)
        
        # JACCARD (IoU) METRICS 
        metrics['jaccard_samples'] = jaccard_score(y_true, y_pred_binary, average='samples', zero_division=0)
        metrics['jaccard_macro'] = jaccard_score(y_true, y_pred_binary, average='macro', zero_division=0)
        metrics['jaccard_weighted'] = jaccard_score(y_true, y_pred_binary, average='weighted', zero_division=0)
        
        # ROC-AUC METRICS (using probabilities)
        try:
            metrics['roc_auc_micro'] = roc_auc_score(y_true, y_pred_proba, average='micro')
            metrics['roc_auc_macro'] = roc_auc_score(y_true, y_pred_proba, average='macro')
            metrics['roc_auc_weighted'] = roc_auc_score(y_true, y_pred_proba, average='weighted')
            metrics['roc_auc_samples'] = roc_auc_score(y_true, y_pred_proba, average='samples')
        except ValueError as e:
            print(f"Warning: ROC-AUC calculation failed: {e}")
            metrics['roc_auc_micro'] = 0.0
            metrics['roc_auc_macro'] = 0.0
            metrics['roc_auc_weighted'] = 0.0
            metrics['roc_auc_samples'] = 0.0
        
        # PR-AUC METRICS (using probabilities)
        try:
            metrics['pr_auc_micro'] = average_precision_score(y_true, y_pred_proba, average='micro')
            metrics['pr_auc_macro'] = average_precision_score(y_true, y_pred_proba, average='macro')
            metrics['pr_auc_weighted'] = average_precision_score(y_true, y_pred_proba, average='weighted')
            metrics['pr_auc_samples'] = average_precision_score(y_true, y_pred_proba, average='samples')
        except ValueError as e:
            print(f"Warning: PR-AUC calculation failed: {e}")
            metrics['pr_auc_micro'] = 0.0
            metrics['pr_auc_macro'] = 0.0
            metrics['pr_auc_weighted'] = 0.0
            metrics['pr_auc_samples'] = 0.0
        
    except Exception as e:
        print(f"Error in comprehensive_evaluation: {e}")
        # Return minimal metrics if calculation fails
        metrics = {
            'precision_micro': 0.0, 'recall_micro': 0.0, 'f1_micro': 0.0,
            'precision_macro': 0.0, 'recall_macro': 0.0, 'f1_macro': 0.0,
            'accuracy': 0.0, 'hamming_loss': 1.0
        }
    
    return metrics

def compute_metrics(eval_pred):
    """
    Enhanced compute_metrics function for transformers Trainer using comprehensive evaluation
    """
    predictions, labels = eval_pred
    
    # Apply sigmoid to get probabilities
    predictions_proba = sigmoid(predictions)
    
    # Convert to binary predictions using threshold 0.5
    predictions_binary = (predictions_proba > 0.5).astype(int)
    
    # Ensure labels are integers
    labels = labels.astype(int)
    
    # Use comprehensive evaluation
    metrics = comprehensive_evaluation(
        y_true=labels,
        y_pred_proba=predictions_proba,
        y_pred_binary=predictions_binary,
        threshold=0.5
    )
    
    # Return metrics with eval_ prefix for Trainer compatibility
    return {
        # Primary metrics for monitoring
        'eval_accuracy': metrics['accuracy'],
        'eval_hamming_loss': metrics['hamming_loss'],
        'eval_f1_macro': metrics['f1_macro'],
        # 'eval_f1_samples': metrics['f1_samples'],
        'eval_f1_weighted': metrics['f1_weighted'],

        # # Precision metrics
        # 'eval_precision_micro': metrics['precision_micro'],
        # 'eval_precision_macro': metrics['precision_macro'],
        # 'eval_precision_samples': metrics['precision_samples'],
        # 'eval_precision_weighted': metrics['precision_weighted'],
        
        # # Recall metrics
        # 'eval_recall_micro': metrics['recall_micro'],
        # 'eval_recall_macro': metrics['recall_macro'],
        # 'eval_recall_samples': metrics['recall_samples'],
        # 'eval_recall_weighted': metrics['recall_weighted'],
        
        # ROC-AUC metrics
        # 'eval_roc_auc_micro': metrics['roc_auc_micro'],
        'eval_roc_auc_macro': metrics['roc_auc_macro'],
        'eval_roc_auc_weighted': metrics['roc_auc_weighted'],
        # 'eval_roc_auc_samples': metrics['roc_auc_samples'],
        
        # PR-AUC metrics
        # 'eval_pr_auc_micro': metrics['pr_auc_micro'],
        'eval_pr_auc_macro': metrics['pr_auc_macro'],
        'eval_pr_auc_weighted': metrics['pr_auc_weighted'],
        # 'eval_pr_auc_samples': metrics['pr_auc_samples'],
        
        # Jaccard metrics
        # 'eval_jaccard_samples': metrics['jaccard_samples'],
        'eval_jaccard_macro': metrics['jaccard_macro'],
        'eval_jaccard_weighted': metrics['jaccard_weighted'],
    }

In [20]:
class ClearCUDACacheCallback(TrainerCallback):
    def on_step_end(self,args,state,control,**kwargs):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    def on_evaluate(self,args,state,control,**kwargs):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    
class EarlyStoppingCallback(TrainerCallback):
    def __init__(self,patience=1):
        super().__init__()
        self.patience=patience
        self.best_loss = float("inf")
        self.early_stop_count = 0

    def on_evaluate(self, args, state, control, **kwargs):
        # Access the evaluation loss
        eval_loss = kwargs["metrics"]["eval_loss"]
        if eval_loss < self.best_loss:
            self.best_loss = eval_loss
            self.early_stop_count = 0
        else:
            self.early_stop_count += 1
        if self.early_stop_count >= self.patience:
            print("Early stopping triggered")
            control.should_training_stop = True
            
class LoggingCallback(TrainerCallback):
    def __init__(self,log_file):
        super().__init__()
        self.log_file=log_file
        self.last_train_loss = None
    def on_log(self, args, state, control, logs=None, **kwargs):
        logs = logs or {}

        # Extract metrics
        epoch = round(float(state.epoch), 2) if state.epoch is not None else None
        step=int(state.global_step)
        # train_loss = logs.get("loss")
        # if train_loss is not None:
        #     train_loss = round(float(train_loss), 4)
        #     self.last_train_loss = train_loss
        # elif self.last_train_loss is not None:
        #     train_loss = self.last_train_loss
        # else:
        #     train_loss = "N/A"  # or skip logging this time
        eval_loss = logs.get("eval_loss")
        eval_accuracy = logs.get("eval_accuracy")
        eval_hamming_loss = logs.get("eval_hamming_loss")
        eval_jaccard_weighted = logs.get("eval_jaccard_weighted")

        # Round losses to 4 decimal places if present
        # train_loss = round(float(train_loss), 4) if train_loss is not None else None
        eval_loss = round(float(eval_loss), 4) if eval_loss is not None else None
        eval_accuracy = round(float(eval_accuracy), 4) if eval_accuracy is not None else None
        eval_hamming_loss = round(float(eval_hamming_loss), 4) if eval_hamming_loss is not None else None
        eval_jaccard_weighted = round(float(eval_jaccard_weighted), 4) if eval_jaccard_weighted is not None else None

        # Prepare log line: epoch, train_loss, eval_loss, eval_accuracy, eval_hamming_loss, eval_jaccard_weighted
        log_line = (
            f"Epoch: {epoch} | "
            f"Step: {step} | "
            f"Val Loss: {eval_loss} | "
            f"Val Acc: {eval_accuracy} | "
            f"Val Hamming Loss: {eval_hamming_loss} | "
            f"Val Jaccard Weighted: {eval_jaccard_weighted}\n"
        )

        # Write to log file
        with open(self.log_file, "a") as f:
            f.write(log_line)

In [21]:
def main(args):
    dataset = create_datasets_from_arrays(X_train, y_train, X_val, y_val, X_test, y_test)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    print("{:<25}{:<15,}".format("Maximal context length:",tokenizer.model_max_length))
    print("{:<25}{:<15,}".format("Vocabulary size :",tokenizer.vocab_size))

    # Apply the tokenization function

    encode_function = partial(preprocess_function,tokenizer=tokenizer, max_length=args.max_context_length)
    
    tokenized_dataset = dataset.map(
        encode_function,
        batched=True,
        remove_columns=['text'],  # Remove the problematic text column
        desc="Tokenizing dataset"
    )
    
    # Define the proper feature type for multi-label classification
    label_feature = Sequence(Value("float32"), length=len(class_name))
    
    # Cast the labels column to float32 for all splits
    for split_name in tokenized_dataset.keys():
        tokenized_dataset[split_name] = tokenized_dataset[split_name].cast_column("labels", 
                                                                                  label_feature)
    print(f"Features: {list(tokenized_dataset['train'].features.keys())}")


    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    
    class2id = {class_:id for id, class_ in enumerate(class_name)}
    id2class = {id:class_ for class_, id in class2id.items()}
    
    
    model = AutoModelForSequenceClassification.from_pretrained(args.model_path, 
                                                               num_labels=len(class_name),
                                                               id2label=id2class, 
                                                               label2id=class2id,
                                                               problem_type = "multi_label_classification"
                                                              )
    
    
    print()
    
    # Verify model is properly configured for multi-label classification
    print("🤖 Model Configuration Verification:")
    print(f"  Model type: {type(model).__name__}")
    print(f"  Number of labels: {model.config.num_labels}")
    print(f"  Expected labels: {len(class_name)}")
    
    # Check if model configuration matches our data
    if model.config.num_labels != len(class_name):
        print(f"⚠️ WARNING: Model expects {model.config.num_labels} labels, but data has {len(class_name)}")
        print("  This might cause issues during training")
    else:
        print(f"✅ Model configuration matches data: {len(class_name)} labels")
    
    # Verify model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"\n📊 Model Parameters:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")

    # Fix tokenizer parallelism warning
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    training_args = TrainingArguments(
        # Output and logging
        output_dir=args.output_dir,
        # logging_dir="./logs",
        # logging_steps=100,
        logging_strategy="no", ## we customize logging intead of using the built-in logging
        
        # Learning parameters
        learning_rate=2e-5,
        lr_scheduler_type="linear",  # Linear decay
        warmup_ratio=0.1,  # 10% warmup
        weight_decay=0.01,
        
        # Batch sizes (adjust based on GPU memory)
        per_device_train_batch_size=args.train_batch,
        per_device_eval_batch_size=args.eval_batch,
        gradient_accumulation_steps=args.gradient_accumulation_step,  # Effective batch size = 4 * 12 = 48

        # Training epochs and evaluation
        num_train_epochs=args.num_epochs,  # Increased for better convergence
        eval_strategy="steps",  # More frequent evaluation
        eval_steps=100,  # Evaluate every 100 steps
        
        # 🎯 OPTIMAL METRICS FOR MULTI-LABEL CLASSIFICATION
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,  # Keep only 3 best checkpoints
        load_best_model_at_end=True,
        
        # 🔥 RECOMMENDED: Use Hamming Loss for multi-label problems
        metric_for_best_model="eval_hamming_loss",  # Primary metric: lower is better
        greater_is_better=False,  # Hamming loss: lower = better performance
        
        # Alternative good options:
        # metric_for_best_model="eval_f1_micro",     # Current choice - also excellent
        # metric_for_best_model="eval_jaccard_samples", # IoU metric - good for multi-label
        
        # Memory and performance optimization
        dataloader_pin_memory=False,  # Disable to avoid forking issues
        dataloader_num_workers=0,     # Disable multiprocessing
        remove_unused_columns=False,  # Keep all columns for multi-label
        
        # Mixed precision for faster training (if GPU supports it)
        fp16=True,  # Enable if using compatible GPU
        
        # Reproducibility
        seed=args.seed,
        data_seed=args.seed,
        
        # Report metrics
        report_to=None,  # Disable wandb/tensorboard if not needed
        run_name="multi_label_posture_classification",
    )
    # # Early stopping callback for overfitting control
    # early_stopping = EarlyStoppingCallback(
    #     early_stopping_patience=3,  # Stop if no improvement for 3 evaluations
    #     early_stopping_threshold=0.001  # Minimum improvement threshold
    # )
    
    # Initialize trainer with enhanced configuration (using processing_class)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["val"],
        processing_class=tokenizer,  # Updated parameter name
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(patience=3),
                   ClearCUDACacheCallback(),
                   LoggingCallback(log_file=os.path.join(args.output_dir,"training_logs.txt"))],  # Add early stopping callback
    )    

    print("\n🎯 Starting training...")
    trainer.train()
    print("✅ Training completed successfully!")

    val_results = trainer.evaluate()

    # Display key metrics
    key_metrics = [
        'eval_accuracy', 'eval_hamming_loss', 'eval_f1_micro', 'eval_f1_macro', 
        'eval_f1_weighted','eval_f1_samples', 'eval_jaccard_macro', 
        'eval_jaccard_weighted'
    ]

    with open(os.path.join(args.output_dir,"val_logs.txt"), "w") as f:
        for metric in key_metrics:
            if metric in val_results:
                line = f"{metric}: {val_results[metric]:.4f}"
                print(f"   {line}")
                f.write(line + "\n")

    if "test" in tokenized_dataset:
        print("\n🎯 Test Set Evaluation:")
        test_results = trainer.evaluate(eval_dataset=tokenized_dataset["test"])
        with open(os.path.join(args.output_dir,"test_logs.txt"), "w") as f:
            for metric in key_metrics:
                if metric in test_results:
                    line = f"{metric}: {test_results[metric]:.4f}"
                    print(f"   {line}")
                    f.write(line + "\n")


In [22]:
if __name__=="__main__":
    parser = argparse.ArgumentParser(description='Fine-tune MordenBERT')

    parser.add_argument("--seed",  type=int,default=42)
    parser.add_argument("--data_path", type=str, default='processed_data')
    parser.add_argument("--output_dir", type=str, default='model_output')
    parser.add_argument('--model_path', type=str, default="answerdotai/ModernBERT-base")
    parser.add_argument('--train_batch', type=int, default=4)
    parser.add_argument('--eval_batch', type=int, default=8)
    parser.add_argument('--gradient_accumulation_step', type=int, default=12)
    parser.add_argument('--num_epochs', type=int, default=5)
    parser.add_argument('--max_context_length', type=int, default=512)
    
    args, _= parser.parse_known_args()

    seed_everything(args.seed)

    ### Load Dataset for model training and evaluation ###
    data_path=os.path.join(os.getcwd(), args.data_path)
    with open(os.path.join(data_path,'train_arrays.pkl'), 'rb') as f:
        train_data = pickle.load(f)
        X_train = train_data['X_train']
        y_train = train_data['y_train']
    
    with open(os.path.join(data_path,'val_arrays.pkl'), 'rb') as f:
        val_data = pickle.load(f)
        X_val = val_data['X_val']
        y_val = val_data['y_val']
    
    with open(os.path.join(data_path,'test_arrays.pkl'), 'rb') as f:
        test_data = pickle.load(f)
        X_test = test_data['X_test']
        y_test = test_data['y_test']
    
    with open(os.path.join(data_path,'class_name.pkl'), 'rb') as f:
        class_name_data = pickle.load(f)
        class_name = class_name_data['class_name']

    main(args)
    

Maximal context length:  8,192          
Vocabulary size :        50,280         


Tokenizing dataset:   0%|          | 0/11597 [00:00<?, ? examples/s]

Tokenizing dataset:   0%|          | 0/2485 [00:00<?, ? examples/s]

Tokenizing dataset:   0%|          | 0/2486 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/11597 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2485 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2486 [00:00<?, ? examples/s]

Features: ['labels', 'input_ids', 'attention_mask']


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



🤖 Model Configuration Verification:
  Model type: ModernBertForSequenceClassification
  Number of labels: 27
  Expected labels: 27
✅ Model configuration matches data: 27 labels

📊 Model Parameters:
  Total parameters: 149,625,627
  Trainable parameters: 149,625,627

🎯 Starting training...


Step,Training Loss,Validation Loss,Accuracy,Hamming Loss,F1 Macro,F1 Weighted,Roc Auc Macro,Roc Auc Weighted,Pr Auc Macro,Pr Auc Weighted,Jaccard Macro,Jaccard Weighted
100,No log,0.108869,0.324346,0.036277,0.098176,0.544194,0.757261,0.889413,0.177906,0.664161,0.077208,0.45351
200,No log,0.075353,0.500201,0.024547,0.191859,0.679791,0.874292,0.951123,0.344305,0.774936,0.15033,0.598609
300,No log,0.068367,0.54326,0.022938,0.347228,0.738232,0.925669,0.962463,0.446651,0.801689,0.268993,0.65186
400,No log,0.061599,0.55493,0.021149,0.355828,0.737189,0.940895,0.966702,0.521841,0.81905,0.280129,0.660274
500,No log,0.059241,0.574648,0.020762,0.451176,0.761073,0.94689,0.968905,0.581514,0.837406,0.352403,0.675321
600,No log,0.056793,0.591147,0.019674,0.480668,0.784267,0.951841,0.970376,0.59893,0.840041,0.381261,0.701362
700,No log,0.055119,0.592354,0.01951,0.527601,0.786582,0.958111,0.971808,0.616069,0.848114,0.406824,0.699286
800,No log,0.054728,0.610865,0.018735,0.536085,0.799105,0.956394,0.970696,0.629109,0.848413,0.420717,0.710891
900,No log,0.0542,0.615694,0.018451,0.541879,0.806194,0.95473,0.970917,0.642234,0.852992,0.424259,0.720777
1000,No log,0.053624,0.621328,0.018213,0.587482,0.818676,0.956112,0.971402,0.63745,0.852427,0.461486,0.730777


✅ Training completed successfully!


Early stopping triggered
   eval_accuracy: 0.6213
   eval_hamming_loss: 0.0182
   eval_f1_macro: 0.5875
   eval_f1_weighted: 0.8187
   eval_jaccard_macro: 0.4615
   eval_jaccard_weighted: 0.7308

🎯 Test Set Evaluation:
Early stopping triggered
   eval_accuracy: 0.6102
   eval_hamming_loss: 0.0190
   eval_f1_macro: 0.5631
   eval_f1_weighted: 0.8102
   eval_jaccard_macro: 0.4437
   eval_jaccard_weighted: 0.7220
