<span style="font-weight: bold; font-size: 18px;">**Multi-Label Posture Classification: Model Development Strategy**<br><br>

We propose a comparative evaluation of two complementary modeling approaches to address the multi-label posture prediction task, each offering distinct advantages for legal document classification.

**Baseline Approach: Bag-of-Words Models**<br>

Our initial baseline leverages traditional bag-of-words representations (TF-IDF, BM25) combined with multi-label classifiers, justified by several key factors:

<div style="margin-left: 20px;"><b>• Computational Efficiency:</b> Lightweight architecture enables rapid prototyping and establishes performance baselines without GPU requirements</div>
<div style="margin-left: 20px;"><b>• Statistical Robustness:</b> Word-frequency features provide interpretable, domain-agnostic representations suitable for legal terminology analysis</div>
<div style="margin-left: 20px;"><b>• Multi-Label Compatibility:</b> Well-established integration with multi-label algorithms (One-vs-Rest, Binary Relevance, Label Powerset)</div>
<div style="margin-left: 20px;"><b>• Baseline Establishment:</b> Provides interpretable performance benchmarks for evaluating more complex architectures</div>

**Advanced Approach: Transformer-Based Models (ModernBERT)**<br>

Our primary model leverages ModernBERT encoder architecture, specifically designed to address the limitations of traditional BERT for our use case:

<div style="margin-left: 20px;"><b>• Extended Context Coverage:</b> ModernBERT's 8,192-token context window accommodates ~90% of our corpus without truncation, preserving critical legal context that may span entire documents</div>

<div style="margin-left: 20px;"><b>• Contextual Understanding:</b> Unlike bag-of-words approaches, transformer architectures capture:
  <div style="margin-left: 40px;">- Long-range dependencies between legal arguments</div>
  <div style="margin-left: 40px;">- Positional relationships between procedural elements</div>
  <div style="margin-left: 40px;">- Semantic nuances distinguishing similar posture categories</div>
</div>

<div style="margin-left: 20px;"><b>• Multi-Label Architecture:</b> The encoder's [CLS] token representation can be effectively coupled with multi-label classification heads, enabling simultaneous prediction of multiple postures</div>

<div style="margin-left: 20px;"><b>• Legal Domain Adaptation:</b> Pre-trained language understanding provides superior handling of complex legal terminology and document structure</div>

**Comparative Justification:**<br>

This dual-approach strategy enables comprehensive evaluation of feature representation impact on multi-label performance, ranging from traditional statistical methods to state-of-the-art contextual understanding, ultimately identifying the optimal balance between computational efficiency and classification accuracy for legal posture prediction.

</span>

## Data Preparation for ML

In [None]:
import os
import pandas as pd
import pickle
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer

In [None]:
# Prepare the labels - convert postures to a list format
def prepare_labels(postures_str):
    """Convert posture string to list of postures"""
    if pd.isna(postures_str) or postures_str == '':
        return []
    return [p.strip() for p in postures_str.split(',') if p.strip()]

# Apply to dataframe
_dir=os.path.join(os.getcwd(),"processed_data")
df=pd.read_pickle(os.path.join(_dir, "data.pkl"))
df['posture_list'] = df['postures'].apply(prepare_labels)

# Remove documents with no postures
df_ml = df[df['posture_list'].apply(len) > 0].copy()
print(f"Documents with postures: {len(df_ml)}")

In [None]:
# Analyze posture distribution
all_postures_ml = []
for postures in df_ml['posture_list']:
    all_postures_ml.extend(postures)

posture_counts = pd.Series(all_postures_ml).value_counts()
print(f"\nTotal unique postures: {len(posture_counts)}")
print()
print(f"Most common postures:")
print(posture_counts.head(15))

In [None]:
# Filter to most common postures (those appearing in at least 100 documents)
min_frequency = 100
common_postures = posture_counts[posture_counts >= min_frequency].index.tolist()
print(f"\nPostures with >= {min_frequency} occurrences: {len(common_postures)}")
print(common_postures)

In [None]:
# Filter documents to only include those with common postures
def filter_common_postures(posture_list, common_postures):
    """Keep only postures that are in the common_postures list"""
    return [p for p in posture_list if p in common_postures]

df_ml['filtered_postures'] = df_ml['posture_list'].apply(
    lambda x: filter_common_postures(x, common_postures)
)

# Remove documents that have no common postures after filtering
df_ml = df_ml[df_ml['filtered_postures'].apply(len) > 0].copy()
print(f"Documents after filtering to common postures: {len(df_ml)}")

In [None]:
## Multi-label Classification Setup

# Create binary label matrix using MultiLabelBinarizer
mlb = MultiLabelBinarizer()
y_multilabel = mlb.fit_transform(df_ml['filtered_postures'])

print(f"Label matrix shape: {y_multilabel.shape}")
print(f"Labels: {mlb.classes_}")

In [None]:
_counts = df_ml['num_postures'].value_counts(dropna=False)
_pct = df_ml['num_postures'].value_counts(dropna=False,normalize=True) 

pd.DataFrame({
    'count': _counts,
    'percentage': _pct
}).sort_index().style.format({'count':'{:,}','percentage':'{:.2%}'}).set_caption("Distribution of num_postures")\
    .set_table_styles([{'selector': 'caption','props': [('color', 'red'),('font-size', '15px')]}])

In [None]:
# Prepare text data
X_text = df_ml['full_text'].values

# Split the data
X_train, X_temp, y_train, y_temp = train_test_split(
    X_text, y_multilabel, 
    test_size=0.3, # 30% for temp (which will be split into val and test)
    random_state=42, 
    stratify=None
)

 # Split temp into validation and test (50-50 split of the 30%)
# # This gives us 15% each
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp,
    test_size=0.5,
    random_state=42, 
    stratify=None
)

print(f"Total samples: {len(df_ml)}")
print(f"Training set: {len(X_train)} ({len(X_train)/len(df_ml):.2%})")
print(f"Validation set: {len(X_val)} ({len(X_val)/len(df_ml):.2%})")
print(f"Test set: {len(X_test)} ({len(X_test)/len(df_ml):.2%})")

In [None]:
# Check label distribution
train_label_sums = y_train.sum(axis=0)
val_label_sums = y_val.sum(axis=0)
test_label_sums = y_test.sum(axis=0)

print("\nLabel distribution in training set:")
for i, label in enumerate(mlb.classes_):
    print(f"{label}: {train_label_sums[i]} ({train_label_sums[i]/len(y_train)*100:.1f}%)")

In [None]:
## save preprocess data
saved_data=os.path.join(os.getcwd(), 'processed_data')
os.makedirs(saved_data, exist_ok=True)
# Save using pickle
with open(os.path.join(saved_data,'train_arrays.pkl'), 'wb') as f:
    pickle.dump({'X_train': X_train, 'y_train': y_train, 'label_train': label_train}, f)

with open(os.path.join(saved_data,'val_arrays.pkl'), 'wb') as f:
    pickle.dump({'X_val': X_val, 'y_val': y_val, 'label_val': label_val}, f)

with open(os.path.join(saved_data,'test_arrays.pkl'), 'wb') as f:
    pickle.dump({'X_test': X_test, 'y_test': y_test, 'label_test': label_test}, f)

with open(os.path.join(saved_data,'class_name.pkl'), 'wb') as f:
    pickle.dump({'class_name': mlb.classes_}, f)

print("All arrays saved with pickle!")

# To load later:
# with open(os.path.join(saved_data,'train_arrays.pkl'), 'rb') as f:
#     train_data = pickle.load(f)
#     X_train = train_data['X_train']
#     y_train = train_data['y_train']
#     label_train = train_data['label_train']

## Bag-of-word (TFIDF): Benchmark

In [None]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import xgboost as xgb
import lightgbm as lgb
from lightgbm import early_stopping, log_evaluation
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
# from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, hamming_loss
from sklearn.metrics import (
    precision_score, recall_score, f1_score, 
    roc_auc_score, average_precision_score,
    hamming_loss, jaccard_score
)
from sklearn.preprocessing import MultiLabelBinarizer

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Create TF-IDF vectorizer
# Using parameters optimized for legal text
tfidf = TfidfVectorizer(
    max_features=10000,  # Limit features for computational efficiency
    stop_words='english',
    ngram_range=(1, 2),  # Include unigrams and bigrams
    min_df=5,           # Ignore terms that appear in fewer than 5 documents
    max_df=0.95,        # Ignore terms that appear in more than 95% of documents
    sublinear_tf=True   # Apply sublinear scaling
)

print("Fitting TF-IDF vectorizer...")
X_train_tfidf = tfidf.fit_transform(X_train)
X_val_tfidf = tfidf.transform(X_val)
X_test_tfidf = tfidf.transform(X_test)

print(f"TF-IDF matrix shape (train): {X_train_tfidf.shape}")
print(f"TF-IDF matrix shape (val): {X_val_tfidf.shape}")
print(f"TF-IDF matrix shape (test): {X_test_tfidf.shape}")
print(f"Vocabulary size: {len(tfidf.vocabulary_)}")

# Show some sample features
feature_names = tfidf.get_feature_names_out()
print(f"\nSample features: {feature_names[:20]}")
print(f"Last features: {feature_names[-20:]}")

def comprehensive_evaluation(y_true, y_pred_binary, y_pred_proba, threshold=0.5):
    """
    Comprehensive evaluation function for multi-label classification.
    
    Args:
        y_true: Ground truth binary labels (n_samples, n_labels)
        y_pred_binary: Predicted binary labels (n_samples, n_labels) 
        y_pred_proba: Predicted probabilities (n_samples, n_labels)
        threshold: Threshold for converting probabilities to binary (default: 0.5)
    
    Returns:
        dict: Comprehensive metrics including all averaging methods
    """
    import numpy as np
    from sklearn.metrics import (
        precision_score, recall_score, f1_score, accuracy_score,
        hamming_loss, jaccard_score, roc_auc_score, average_precision_score
    )
    
    # Ensure inputs are numpy arrays
    y_true = np.array(y_true, dtype=int)
    y_pred_binary = np.array(y_pred_binary, dtype=int)
    y_pred_proba = np.array(y_pred_proba, dtype=float)
    
    metrics = {}
    
    try:
        # SAMPLES AVERAGE (per-sample then average across samples)
        metrics['precision_samples'] = precision_score(y_true, y_pred_binary, average='samples', zero_division=0)
        metrics['recall_samples'] = recall_score(y_true, y_pred_binary, average='samples', zero_division=0)
        metrics['f1_samples'] = f1_score(y_true, y_pred_binary, average='samples', zero_division=0)
        
        # MICRO AVERAGE (global average)
        metrics['precision_micro'] = precision_score(y_true, y_pred_binary, average='micro', zero_division=0)
        metrics['recall_micro'] = recall_score(y_true, y_pred_binary, average='micro', zero_division=0)
        metrics['f1_micro'] = f1_score(y_true, y_pred_binary, average='micro', zero_division=0)
        
        # MACRO AVERAGE (unweighted average across labels)
        metrics['precision_macro'] = precision_score(y_true, y_pred_binary, average='macro', zero_division=0)
        metrics['recall_macro'] = recall_score(y_true, y_pred_binary, average='macro', zero_division=0)
        metrics['f1_macro'] = f1_score(y_true, y_pred_binary, average='macro', zero_division=0)
        
        # WEIGHTED AVERAGE (weighted by support)
        metrics['precision_weighted'] = precision_score(y_true, y_pred_binary, average='weighted', zero_division=0)
        metrics['recall_weighted'] = recall_score(y_true, y_pred_binary, average='weighted', zero_division=0)
        metrics['f1_weighted'] = f1_score(y_true, y_pred_binary, average='weighted', zero_division=0)
        
        # ACCURACY METRICS
        metrics['accuracy'] = accuracy_score(y_true, y_pred_binary)
        metrics['hamming_loss'] = hamming_loss(y_true, y_pred_binary)
        
        # JACCARD (IoU) METRICS 
        metrics['jaccard_samples'] = jaccard_score(y_true, y_pred_binary, average='samples', zero_division=0)
        metrics['jaccard_macro'] = jaccard_score(y_true, y_pred_binary, average='macro', zero_division=0)
        metrics['jaccard_weighted'] = jaccard_score(y_true, y_pred_binary, average='weighted', zero_division=0)
        
        # ROC-AUC METRICS (using probabilities)
        try:
            metrics['roc_auc_micro'] = roc_auc_score(y_true, y_pred_proba, average='micro')
            metrics['roc_auc_macro'] = roc_auc_score(y_true, y_pred_proba, average='macro')
            metrics['roc_auc_weighted'] = roc_auc_score(y_true, y_pred_proba, average='weighted')
            metrics['roc_auc_samples'] = roc_auc_score(y_true, y_pred_proba, average='samples')
        except ValueError as e:
            print(f"Warning: ROC-AUC calculation failed: {e}")
            metrics['roc_auc_micro'] = 0.0
            metrics['roc_auc_macro'] = 0.0
            metrics['roc_auc_weighted'] = 0.0
            metrics['roc_auc_samples'] = 0.0
        
        # PR-AUC METRICS (using probabilities)
        try:
            metrics['pr_auc_micro'] = average_precision_score(y_true, y_pred_proba, average='micro')
            metrics['pr_auc_macro'] = average_precision_score(y_true, y_pred_proba, average='macro')
            metrics['pr_auc_weighted'] = average_precision_score(y_true, y_pred_proba, average='weighted')
            metrics['pr_auc_samples'] = average_precision_score(y_true, y_pred_proba, average='samples')
        except ValueError as e:
            print(f"Warning: PR-AUC calculation failed: {e}")
            metrics['pr_auc_micro'] = 0.0
            metrics['pr_auc_macro'] = 0.0
            metrics['pr_auc_weighted'] = 0.0
            metrics['pr_auc_samples'] = 0.0
        
    except Exception as e:
        print(f"Error in comprehensive_evaluation: {e}")
        # Return minimal metrics if calculation fails
        metrics = {
            'precision_micro': 0.0, 'recall_micro': 0.0, 'f1_micro': 0.0,
            'precision_macro': 0.0, 'recall_macro': 0.0, 'f1_macro': 0.0,
            'accuracy': 0.0, 'hamming_loss': 1.0
        }
    
    return metrics

print("✅ Comprehensive evaluation function updated and ready to use!")

In [None]:
def comprehensive_evaluation(y_true, y_pred_proba, y_pred_binary=None, threshold=0.5):
    """
    Comprehensive evaluation for multi-label classification with all averaging methods
    """
    if y_pred_binary is None:
        y_pred_binary = (y_pred_proba >= threshold).astype(int)
    
    metrics = {}
    
    # SAMPLES AVERAGE (per-sample then average across samples)
    metrics['precision_samples'] = precision_score(y_true, y_pred_binary, average='samples', zero_division=0)
    metrics['recall_samples'] = recall_score(y_true, y_pred_binary, average='samples', zero_division=0)
    metrics['f1_samples'] = f1_score(y_true, y_pred_binary, average='samples', zero_division=0)
    
    # MICRO AVERAGE (global aggregation)
    metrics['precision_micro'] = precision_score(y_true, y_pred_binary, average='micro', zero_division=0)
    metrics['recall_micro'] = recall_score(y_true, y_pred_binary, average='micro', zero_division=0)
    metrics['f1_micro'] = f1_score(y_true, y_pred_binary, average='micro', zero_division=0)
    
    # MACRO AVERAGE (unweighted average across labels)
    metrics['precision_macro'] = precision_score(y_true, y_pred_binary, average='macro', zero_division=0)
    metrics['recall_macro'] = recall_score(y_true, y_pred_binary, average='macro', zero_division=0)
    metrics['f1_macro'] = f1_score(y_true, y_pred_binary, average='macro', zero_division=0)
    
    # WEIGHTED AVERAGE (weighted by support/frequency)
    metrics['precision_weighted'] = precision_score(y_true, y_pred_binary, average='weighted', zero_division=0)
    metrics['recall_weighted'] = recall_score(y_true, y_pred_binary, average='weighted', zero_division=0)
    metrics['f1_weighted'] = f1_score(y_true, y_pred_binary, average='weighted', zero_division=0)
    
    # ROC-AUC (multiple averaging methods)
    try:
        metrics['roc_auc_macro'] = roc_auc_score(y_true, y_pred_proba, average='macro')
        metrics['roc_auc_weighted'] = roc_auc_score(y_true, y_pred_proba, average='weighted')
        metrics['roc_auc_samples'] = roc_auc_score(y_true, y_pred_proba, average='samples')
    except ValueError as e:
        print(f"ROC-AUC calculation failed: {e}")
        metrics['roc_auc_macro'] = 0.0
        metrics['roc_auc_weighted'] = 0.0
        metrics['roc_auc_samples'] = 0.0
    
    # Precision-Recall AUC (multiple averaging methods)
    try:
        metrics['pr_auc_macro'] = average_precision_score(y_true, y_pred_proba, average='macro')
        metrics['pr_auc_weighted'] = average_precision_score(y_true, y_pred_proba, average='weighted')
        metrics['pr_auc_samples'] = average_precision_score(y_true, y_pred_proba, average='samples')
    except ValueError as e:
        print(f"PR-AUC calculation failed: {e}")
        metrics['pr_auc_macro'] = 0.0
        metrics['pr_auc_weighted'] = 0.0
        metrics['pr_auc_samples'] = 0.0
    
    # Hamming Loss (inherently micro-averaged)
    metrics['hamming_loss'] = hamming_loss(y_true, y_pred_binary)
    
    # Jaccard Score (multiple averaging methods)
    metrics['jaccard_samples'] = jaccard_score(y_true, y_pred_binary, average='samples', zero_division=0)
    metrics['jaccard_macro'] = jaccard_score(y_true, y_pred_binary, average='macro', zero_division=0)
    metrics['jaccard_weighted'] = jaccard_score(y_true, y_pred_binary, average='weighted', zero_division=0)
    
    # Note: micro average for Jaccard in multi-label is not directly supported in sklearn
    # but can be calculated manually if needed
    
    return metrics

In [None]:
# # Define models to test with optimized hyperparameters and validation-aware training
# models = {
#     'Logistic Regression': OneVsRestClassifier(
#         LogisticRegression(
#             random_state=42, 
#             max_iter=1000,
#             C=1.0,
#             solver='liblinear'
#         )
#     ),
#     'Random Forest': OneVsRestClassifier(
#         RandomForestClassifier(
#             n_estimators=100, 
#             random_state=42, 
#             n_jobs=-1,
#             max_depth=10,
#             min_samples_split=5,
#             min_samples_leaf=2,
#             # Additional overfitting control
#             min_impurity_decrease=0.0001,
#             max_features='sqrt'
#         )
#     ),
#     'XGBoost': OneVsRestClassifier(
#         xgb.XGBClassifier(
#             random_state=42,
#             n_estimators=100,
#             max_depth=6,
#             learning_rate=0.1,
#             subsample=0.8,
#             colsample_bytree=0.8,
#             eval_metric='logloss',
#             verbosity=0,
#             # Early stopping will be handled in training loop
#             early_stopping_rounds=10
#         )
#     ),
#     'LightGBM': OneVsRestClassifier(
#         lgb.LGBMClassifier(
#             random_state=42,
#             n_estimators=100,
#             max_depth=6,
#             learning_rate=0.1,
#             subsample=0.8,
#             colsample_bytree=0.8,
#             verbosity=-1,
#             # Early stopping will be handled in training loop
#             early_stopping_rounds=10
#         )
#     )
# }

# # Enhanced training function with validation monitoring
# def train_with_validation_control(model, X_train, y_train, X_val, y_val, model_name):
#     """
#     Train model with validation monitoring to control overfitting
#     """
#     print(f"\nTraining {model_name} with validation control...")
    
#     if model_name in ['XGBoost', 'LightGBM']:
#         # For tree-based models, we can use early stopping
#         if model_name == 'XGBoost':
#             # XGBoost with early stopping
#             for i, estimator in enumerate(model.estimators_):
#                 print(f"  Training label {i+1}/{len(model.estimators_)}")
                
#                 # Get single label
#                 y_train_single = y_train[:, i]
#                 y_val_single = y_val[:, i]
                
#                 # Only train if there are positive samples
#                 if y_train_single.sum() > 0:
#                     estimator.fit(
#                         X_train, y_train_single,
#                         eval_set=[(X_val, y_val_single)],
#                         verbose=False
#                     )
#                 else:
#                     # For labels with no positive samples, create a dummy classifier
#                     estimator.fit(X_train[:10], y_train_single[:10])
        
#         elif model_name == 'LightGBM':
#             # LightGBM with early stopping
#             for i, estimator in enumerate(model.estimators_):
#                 print(f"  Training label {i+1}/{len(model.estimators_)}")
                
#                 # Get single label
#                 y_train_single = y_train[:, i]
#                 y_val_single = y_val[:, i]
                
#                 # Only train if there are positive samples
#                 if y_train_single.sum() > 0:
#                     estimator.fit(
#                         X_train, y_train_single,
#                         eval_set=[(X_val, y_val_single)],
#                         callbacks=[
#                             early_stopping(10, verbose=False),
#                             log_evaluation(0)  # No logging
#                         ]
#                     )
#                 else:
#                     # For labels with no positive samples, create a dummy classifier
#                     estimator.fit(X_train[:10], y_train_single[:10])
#     else:
#         # For other models, use regular training
#         model.fit(X_train, y_train)
    
#     return model

# # Store results with validation tracking
# results = {}
# validation_scores = {}

# print("Training and evaluating models with validation control...")
# print("="*60)
# print("Models to evaluate:")
# for name in models.keys():
#     print(f"  • {name}")
# print()

# for name, model in models.items():
#     # Train with validation control
#     if name in ['XGBoost', 'LightGBM']:
#         # For tree-based models, we need to handle OneVsRestClassifier manually
#         # to implement early stopping properly
#         trained_model = OneVsRestClassifier(
#             model.estimator,
#             n_jobs=1  # Sequential to handle early stopping
#         )
#         trained_model.fit(X_train_tfidf, y_train)
#     else:
#         trained_model = model
#         trained_model.fit(X_train_tfidf, y_train)
    
#     # Make predictions on all sets
#     y_pred_train = trained_model.predict(X_train_tfidf)
#     y_pred_val = trained_model.predict(X_val_tfidf)
#     y_pred_test = trained_model.predict(X_test_tfidf)
    
#     # Calculate metrics for all sets
#     train_accuracy = accuracy_score(y_train, y_pred_train)
#     val_accuracy = accuracy_score(y_val, y_pred_val)
#     test_accuracy = accuracy_score(y_test, y_pred_test)
    
#     train_hamming = hamming_loss(y_train, y_pred_train)
#     val_hamming = hamming_loss(y_val, y_pred_val)
#     test_hamming = hamming_loss(y_test, y_pred_test)
    
#     # Calculate F1 scores
#     train_f1_micro = f1_score(y_train, y_pred_train, average='micro')
#     val_f1_micro = f1_score(y_val, y_pred_val, average='micro')
#     test_f1_micro = f1_score(y_test, y_pred_test, average='micro')
    
#     # Store results
#     results[name] = {
#         'model': trained_model,
#         'train_accuracy': train_accuracy,
#         'val_accuracy': val_accuracy,
#         'test_accuracy': test_accuracy,
#         'train_hamming_loss': train_hamming,
#         'val_hamming_loss': val_hamming,
#         'test_hamming_loss': test_hamming,
#         'train_f1_micro': train_f1_micro,
#         'val_f1_micro': val_f1_micro,
#         'test_f1_micro': test_f1_micro,
#         'y_pred_test': y_pred_test,
#         'y_pred_val': y_pred_val
#     }
    
#     # Check for overfitting
#     accuracy_gap = train_accuracy - val_accuracy
#     f1_gap = train_f1_micro - val_f1_micro
    
#     overfitting_status = "✅ Good" if accuracy_gap < 0.05 else "⚠️ Moderate" if accuracy_gap < 0.1 else "🚨 High"
    
#     print(f"\n{name} Results:")
#     print(f"  Train Accuracy: {train_accuracy:.4f}")
#     print(f"  Val Accuracy:   {val_accuracy:.4f}")
#     print(f"  Test Accuracy:  {test_accuracy:.4f}")
#     print(f"  Train-Val Gap:  {accuracy_gap:.4f} ({overfitting_status})")
#     print(f"  Train F1:       {train_f1_micro:.4f}")
#     print(f"  Val F1:         {val_f1_micro:.4f}")
#     print(f"  Test F1:        {test_f1_micro:.4f}")
#     print(f"  F1 Gap:         {f1_gap:.4f}")

# print("\n" + "="*80)
# print("Model Comparison with Overfitting Analysis:")
# print(f"{'Model':<15} | {'Test Acc':<8} | {'Val Acc':<8} | {'Gap':<6} | {'Status':<12} | {'Performance':<12}")
# print("-" * 85)

# # Sort results by validation accuracy (better indicator than test accuracy)
# sorted_results = sorted(results.items(), key=lambda x: x[1]['val_accuracy'], reverse=True)

# for name, result in sorted_results:
#     gap = result['train_accuracy'] - result['val_accuracy']
#     status = "Good" if gap < 0.05 else "Moderate" if gap < 0.1 else "High"
#     performance = "🥇 Best" if name == sorted_results[0][0] else "🥈 Good" if result['val_accuracy'] > 0.55 else "⚠️ Poor"
#     print(f"{name:<15} | {result['test_accuracy']:<8.4f} | {result['val_accuracy']:<8.4f} | {gap:<6.4f} | {status:<12} | {performance}")

# # Identify best model based on validation performance
# best_model_name = sorted_results[0][0]
# best_model = sorted_results[0][1]['model']
# print(f"\n🏆 Best performing model (based on validation): {best_model_name}")
# print(f"   Validation Accuracy: {sorted_results[0][1]['val_accuracy']:.4f}")
# print(f"   Test Accuracy: {sorted_results[0][1]['test_accuracy']:.4f}")
# print(f"   Overfitting Gap: {sorted_results[0][1]['train_accuracy'] - sorted_results[0][1]['val_accuracy']:.4f}")

In [None]:
import lightgbm as lgb
import xgboost as xgb
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import (
    precision_score, recall_score, f1_score, 
    roc_auc_score, average_precision_score,
    hamming_loss, jaccard_score, accuracy_score
)
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from tqdm import tqdm

class Train_XGBoost(BaseEstimator, ClassifierMixin):
    """XGBoost classifier with validation-based early stopping for multi-label"""
    
    def __init__(self, **xgb_params):
        self.xgb_params = xgb_params
        self.models_ = []
        self.n_classes_ = None
        
    def fit(self, X, y, X_val=None, y_val=None):
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)
        if X_val is not None and len(y_val.shape) == 1:
            y_val = y_val.reshape(-1, 1)
            
        self.n_classes_ = y.shape[1]
        self.models_ = []
        
        for i in tqdm(range(self.n_classes_), total=self.n_classes_, leave=True, position=0):
            
            y_single = y[:, i]
            
            # Skip if no positive samples
            if y_single.sum() == 0:
                self.models_.append(None)
                continue
            
            model = xgb.XGBClassifier(**self.xgb_params)
            
            if X_val is not None and y_val is not None:
                y_val_single = y_val[:, i]
                model.fit(
                    X, y_single,
                    eval_set=[(X_val, y_val_single)],
                    verbose=False
                )
            else:
                model.fit(X, y_single)
            
            self.models_.append(model)
        
        return self
    
    def predict(self, X):
        predictions = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                predictions[:, i] = model.predict(X)
        
        return predictions
    
    def predict_proba(self, X):
        probabilities = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                proba = model.predict_proba(X)
                # Handle case where only one class is present
                if proba.shape[1] == 1:
                    probabilities[:, i] = 0  # All negative class
                else:
                    probabilities[:, i] = proba[:, 1]  # Positive class probability
        
        return probabilities

class Train_LGBM(BaseEstimator, ClassifierMixin):
    """LightGBM classifier with validation-based early stopping for multi-label"""
    
    def __init__(self, **lgb_params):
        self.lgb_params = lgb_params
        self.models_ = []
        self.n_classes_ = None
        
    def fit(self, X, y, X_val=None, y_val=None):
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)
        if X_val is not None and len(y_val.shape) == 1:
            y_val = y_val.reshape(-1, 1)
            
        self.n_classes_ = y.shape[1]
        self.models_ = []
        
        for i in tqdm(range(self.n_classes_), total=self.n_classes_, leave=True, position=0):
            
            y_single = y[:, i]
            
            # Skip if no positive samples
            if y_single.sum() == 0:
                self.models_.append(None)
                continue
            
            model = lgb.LGBMClassifier(**self.lgb_params)
            
            if X_val is not None and y_val is not None:
                y_val_single = y_val[:, i]
                model.fit(
                    X, y_single,
                    eval_set=[(X_val, y_val_single)],
                    callbacks=[
                        lgb.early_stopping(10, verbose=False),
                        lgb.log_evaluation(0)
                    ]
                )
            else:
                model.fit(X, y_single)
            
            self.models_.append(model)
        
        return self
    
    def predict(self, X):
        predictions = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                predictions[:, i] = model.predict(X)
        
        return predictions
    
    def predict_proba(self, X):
        probabilities = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                proba = model.predict_proba(X)
                # Handle case where only one class is present
                if proba.shape[1] == 1:
                    probabilities[:, i] = 0  # All negative class
                else:
                    probabilities[:, i] = proba[:, 1]  # Positive class probability
        
        return probabilities

class Train_logistic(BaseEstimator, ClassifierMixin):
    """Logistic Regression classifier with validation monitoring for multi-label"""
    
    def __init__(self, **lr_params):
        self.lr_params = lr_params
        self.models_ = []
        self.n_classes_ = None
        self.validation_scores_ = []
        
    def fit(self, X, y, X_val=None, y_val=None):
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)
        if X_val is not None and len(y_val.shape) == 1:
            y_val = y_val.reshape(-1, 1)
            
        self.n_classes_ = y.shape[1]
        self.models_ = []
        self.validation_scores_ = []
        
        for i in tqdm(range(self.n_classes_), total=self.n_classes_, leave=True, position=0):
            
            y_single = y[:, i]
            
            # Skip if no positive samples
            if y_single.sum() == 0:
                self.models_.append(None)
                self.validation_scores_.append(0.0)
                continue
            
            model = LogisticRegression(**self.lr_params)
            model.fit(X, y_single)
            
            # Calculate validation score if validation data provided
            if X_val is not None and y_val is not None:
                y_val_single = y_val[:, i]
                val_score = model.score(X_val, y_val_single)
                self.validation_scores_.append(val_score)
            else:
                self.validation_scores_.append(None)
            
            self.models_.append(model)
        
        return self
    
    def predict(self, X):
        predictions = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                predictions[:, i] = model.predict(X)
        
        return predictions
    
    def predict_proba(self, X):
        probabilities = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                proba = model.predict_proba(X)
                # Handle case where only one class is present
                if proba.shape[1] == 1:
                    probabilities[:, i] = 0  # All negative class
                else:
                    probabilities[:, i] = proba[:, 1]  # Positive class probability
        
        return probabilities
    
    def get_validation_scores(self):
        """Return validation scores for each label"""
        return self.validation_scores_

class Train_RandomForest(BaseEstimator, ClassifierMixin):
    """Random Forest classifier with validation monitoring for multi-label"""
    
    def __init__(self, **rf_params):
        self.rf_params = rf_params
        self.models_ = []
        self.n_classes_ = None
        self.validation_scores_ = []
        self.feature_importances_ = []
        
    def fit(self, X, y, X_val=None, y_val=None):
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)
        if X_val is not None and len(y_val.shape) == 1:
            y_val = y_val.reshape(-1, 1)
            
        self.n_classes_ = y.shape[1]
        self.models_ = []
        self.validation_scores_ = []
        self.feature_importances_ = []
        
        for i in tqdm(range(self.n_classes_), total=self.n_classes_, leave=True, position=0):
            
            y_single = y[:, i]
            
            # Skip if no positive samples
            if y_single.sum() == 0:
                self.models_.append(None)
                self.validation_scores_.append(0.0)
                self.feature_importances_.append(None)
                continue
            
            model = RandomForestClassifier(**self.rf_params)
            model.fit(X, y_single)
            
            # Store feature importances
            self.feature_importances_.append(model.feature_importances_)
            
            # Calculate validation score if validation data provided
            if X_val is not None and y_val is not None:
                y_val_single = y_val[:, i]
                val_score = model.score(X_val, y_val_single)
                self.validation_scores_.append(val_score)
            else:
                self.validation_scores_.append(None)
            
            self.models_.append(model)
        
        return self
    
    def predict(self, X):
        predictions = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                predictions[:, i] = model.predict(X)
        
        return predictions
    
    def predict_proba(self, X):
        probabilities = np.zeros((X.shape[0], self.n_classes_))
        
        for i, model in enumerate(self.models_):
            if model is not None:
                proba = model.predict_proba(X)
                # Handle case where only one class is present
                if proba.shape[1] == 1:
                    probabilities[:, i] = 0  # All negative class
                else:
                    probabilities[:, i] = proba[:, 1]  # Positive class probability
        
        return probabilities
    
    def get_validation_scores(self):
        """Return validation scores for each label"""
        return self.validation_scores_
    
    def get_feature_importances(self):
        """Return feature importances for each label"""
        return self.feature_importances_

def training_function_with_validation(X_train, y_train, X_val, y_val, model_type='lightgbm'):
    """
    Enhanced training function with proper validation control for multi-label classification
    """
    
    print(f"Training {model_type} with validation control...")
    print(f"X_train shape: {X_train.shape}")
    print(f"y_train shape: {y_train.shape}")
    print(f"X_val shape: {X_val.shape}")
    print(f"y_val shape: {y_val.shape}")
    
    if model_type == 'lightgbm':
        model = Train_LGBM(
            random_state=42,
            n_estimators=200,  # More estimators for early stopping
            max_depth=6,
            learning_rate=0.1,
            subsample=0.8,
            colsample_bytree=0.8,
            verbosity=-1,
            early_stopping_rounds=10
        )
    elif model_type == 'xgboost':
        model = Train_XGBoost(
            random_state=42,
            n_estimators=200,  # More estimators for early stopping
            max_depth=6,
            learning_rate=0.1,
            subsample=0.8,
            colsample_bytree=0.8,
            eval_metric='logloss',
            verbosity=0,
            early_stopping_rounds=10
        )
    elif model_type == 'logistic':
        model = Train_logistic(
            random_state=42,
            max_iter=1000,
            C=1.0,
            solver='liblinear',
            class_weight='balanced'  # Handle class imbalance
        )
    elif model_type == 'randomforest':
        model = Train_RandomForest(
            random_state=42,
            n_estimators=100,
            max_depth=10,
            min_samples_split=5,
            min_samples_leaf=2,
            max_features='sqrt',
            class_weight='balanced',  # Handle class imbalance
            n_jobs=-1
        )
    else:
        raise ValueError("Supported model types: 'lightgbm', 'xgboost', 'logistic', 'randomforest'")
    
    # Fit with validation data
    model.fit(X_train, y_train, X_val, y_val)
    
    # Make predictions
    y_pred_train = model.predict(X_train)
    y_pred_val = model.predict(X_val)
    
    # Calculate metrics
    train_acc = accuracy_score(y_train, y_pred_train)
    val_acc = accuracy_score(y_val, y_pred_val)
    train_f1 = f1_score(y_train, y_pred_train, average='micro')
    val_f1 = f1_score(y_val, y_pred_val, average='micro')
    
    # Calculate hamming loss (lower is better)
    train_hamming = hamming_loss(y_train, y_pred_train)
    val_hamming = hamming_loss(y_val, y_pred_val)
    
    # Calculate overfitting gaps for different metrics
    accuracy_gap = train_acc - val_acc
    f1_gap = train_f1 - val_f1
    hamming_gap = val_hamming - train_hamming  # Note: val - train because lower hamming is better
    
    print(f"Training completed!")
    print(f"Train Accuracy: {train_acc:.4f}")
    print(f"Val Accuracy: {val_acc:.4f}")
    print(f"Train F1: {train_f1:.4f}")
    print(f"Val F1: {val_f1:.4f}")
    print(f"Train Hamming Loss: {train_hamming:.4f}")
    print(f"Val Hamming Loss: {val_hamming:.4f}")
    print(f"Overfitting Gap (Accuracy): {accuracy_gap:.4f}")
    print(f"Overfitting Gap (F1): {f1_gap:.4f}")
    print(f"Overfitting Gap (Hamming): {hamming_gap:.4f}")
    
    return model, {
        'train_accuracy': train_acc,
        'val_accuracy': val_acc,
        'train_f1': train_f1,
        'val_f1': val_f1,
        'train_hamming_loss': train_hamming,
        'val_hamming_loss': val_hamming,
        'accuracy_gap': accuracy_gap,
        'f1_gap': f1_gap,
        'hamming_gap': hamming_gap,
        'overfitting_gap': hamming_gap  # Use hamming gap as primary overfitting indicator
    }



In [None]:
# Comprehensive Model Comparison with Validation Control

def compare_all_models(X_train, y_train, X_val, y_val, X_test, y_test):
    """
    Train and compare all models with validation control
    """
    
    print("🚀 COMPREHENSIVE MODEL COMPARISON WITH VALIDATION CONTROL")
    print("="*80)
    
    models_to_test = ['logistic', 'randomforest', 'lightgbm', 'xgboost']
    results = {}
    
    for model_type in models_to_test:
        print(f"\n{'='*60}")
        print(f"🔧 Training {model_type.upper()} Model")
        print(f"{'='*60}")
        
        try:
            # Train model with validation
            model, metrics = training_function_with_validation(
                X_train, y_train, X_val, y_val, model_type=model_type
            )
            
            # Test on unseen data
            y_pred_test = model.predict(X_test)
            test_acc = accuracy_score(y_test, y_pred_test)
            test_f1 = f1_score(y_test, y_pred_test, average='micro')
            test_hamming = hamming_loss(y_test, y_pred_test)
            
            # Store all results
            results[model_type] = {
                'model': model,
                'train_accuracy': metrics['train_accuracy'],
                'val_accuracy': metrics['val_accuracy'],
                'test_accuracy': test_acc,
                'train_f1': metrics['train_f1'],
                'val_f1': metrics['val_f1'],
                'test_f1': test_f1,
                'train_hamming_loss': metrics['train_hamming_loss'],
                'val_hamming_loss': metrics['val_hamming_loss'],
                'test_hamming_loss': test_hamming,
                'accuracy_gap': metrics['accuracy_gap'],
                'f1_gap': metrics['f1_gap'],
                'hamming_gap': metrics['hamming_gap'],
                'overfitting_gap': metrics['overfitting_gap']  # Based on hamming loss
            }
            
            print(f"✅ {model_type.upper()} completed successfully!")
            print(f"   Test Accuracy: {test_acc:.4f}")
            print(f"   Test F1: {test_f1:.4f}")
            print(f"   Test Hamming Loss: {test_hamming:.4f}")
            print(f"   Overfitting Gap (Hamming): {metrics['overfitting_gap']:.4f}")
            
        except Exception as e:
            print(f"❌ Error training {model_type}: {str(e)}")
            results[model_type] = None
    
    return results

def analyze_model_results(results):
    """
    Analyze and display comprehensive results
    """
    
    print(f"\n{'='*100}")
    print("📊 COMPREHENSIVE MODEL ANALYSIS")
    print(f"{'='*100}")
    
    # Filter successful results
    successful_results = {k: v for k, v in results.items() if v is not None}
    
    if not successful_results:
        print("❌ No models trained successfully!")
        return
    
    # Display detailed comparison table
    print(f"\n{'Model':<15} | {'Train Acc':<9} | {'Val Acc':<9} | {'Test Acc':<9} | {'Train Ham':<9} | {'Val Ham':<8} | {'Test Ham':<8} | {'Ham Gap':<8} | {'Status'}")
    print("-" * 105)
    
    # Sort by validation accuracy (best practice)
    sorted_results = sorted(successful_results.items(), 
                          key=lambda x: x[1]['val_accuracy'], reverse=True)
    
    for rank, (model_name, result) in enumerate(sorted_results, 1):
        hamming_gap = result['hamming_gap']
        
        # Determine overfitting status based on hamming gap
        # For hamming loss, positive gap means validation is worse (overfitting)
        if hamming_gap < 0.01:
            status = "✅ Excellent"
        elif hamming_gap < 0.02:
            status = "🟢 Good"
        elif hamming_gap < 0.04:
            status = "🟡 Moderate"
        else:
            status = "🔴 High"
        
        rank_emoji = "🥇" if rank == 1 else "🥈" if rank == 2 else "🥉" if rank == 3 else "4️⃣"
        
        print(f"{model_name.upper():<15} | {result['train_accuracy']:<9.4f} | {result['val_accuracy']:<9.4f} | "
              f"{result['test_accuracy']:<9.4f} | {result['train_hamming_loss']:<9.4f} | {result['val_hamming_loss']:<8.4f} | "
              f"{result['test_hamming_loss']:<8.4f} | {hamming_gap:<8.4f} | {status}")
    
    # Identify best models
    best_model = sorted_results[0]
    print(f"\n🏆 BEST MODEL (Based on Validation Performance): {best_model[0].upper()}")
    print(f"   📈 Validation Accuracy: {best_model[1]['val_accuracy']:.4f}")
    print(f"   🎯 Test Accuracy: {best_model[1]['test_accuracy']:.4f}")
    print(f"   📊 Test F1 Score: {best_model[1]['test_f1']:.4f}")
    print(f"   🔻 Test Hamming Loss: {best_model[1]['test_hamming_loss']:.4f}")
    print(f"   ⚖️ Overfitting Gap (Hamming): {best_model[1]['overfitting_gap']:.4f}")
    print(f"   📏 Accuracy Gap: {best_model[1]['accuracy_gap']:.4f}")
    print(f"   📈 F1 Gap: {best_model[1]['f1_gap']:.4f}")
    
    # Best test performance (might be different from best validation)
    best_test = max(successful_results.items(), key=lambda x: x[1]['test_accuracy'])
    if best_test[0] != best_model[0]:
        print(f"\n🎯 BEST TEST PERFORMANCE: {best_test[0].upper()}")
        print(f"   Test Accuracy: {best_test[1]['test_accuracy']:.4f}")
        print(f"   (Note: Choose model based on validation, not test performance)")
    
    # Best hamming loss performance
    best_hamming = min(successful_results.items(), key=lambda x: x[1]['test_hamming_loss'])
    if best_hamming[0] != best_model[0]:
        print(f"\n🔻 BEST HAMMING LOSS PERFORMANCE: {best_hamming[0].upper()}")
        print(f"   Test Hamming Loss: {best_hamming[1]['test_hamming_loss']:.4f}")
        print(f"   (Lower hamming loss = better multi-label performance)")
    
    # Model-specific insights
    print(f"\n{'='*80}")
    print("🔍 MODEL-SPECIFIC INSIGHTS:")
    print(f"{'='*80}")
    
    for model_name, result in successful_results.items():
        if hasattr(result['model'], 'get_validation_scores'):
            val_scores = result['model'].get_validation_scores()
            if val_scores and any(score for score in val_scores if score is not None):
                valid_scores = [s for s in val_scores if s is not None and s > 0]
                if valid_scores:
                    avg_label_score = np.mean(valid_scores)
                    print(f"{model_name.upper()}:")
                    print(f"   Average per-label validation score: {avg_label_score:.4f}")
                    print(f"   Labels with good performance (>0.8): {sum(1 for s in valid_scores if s > 0.8)}/{len(valid_scores)}")
    
    # Recommendations
    print(f"\n{'='*80}")
    print("💡 RECOMMENDATIONS:")
    print(f"{'='*80}")
    
    if best_model[1]['overfitting_gap'] < 0.02:
        print("✅ Your best model shows excellent generalization based on Hamming loss!")
    elif best_model[1]['overfitting_gap'] < 0.04:
        print("🟢 Your best model shows good generalization based on Hamming loss!")
    else:
        print("⚠️ Consider additional regularization for your best model:")
        print("   - Increase regularization parameters")
        print("   - Use more training data")
        print("   - Apply feature selection")
        print("   - Consider ensemble methods")
    
    hamming_gap_threshold = 0.02
    models_with_overfitting = [name for name, result in successful_results.items() 
                              if result['hamming_gap'] > hamming_gap_threshold]
    
    if models_with_overfitting:
        print(f"\n⚠️ Models showing overfitting based on Hamming loss (gap > {hamming_gap_threshold}):")
        for model in models_with_overfitting:
            result = successful_results[model]
            print(f"   - {model.upper()}:")
            print(f"     • Hamming Gap: {result['hamming_gap']:.4f}")
            print(f"     • Accuracy Gap: {result['accuracy_gap']:.4f}")
            print(f"     • F1 Gap: {result['f1_gap']:.4f}")
    
    print(f"\n🎯 Model Selection Priority (Updated with Hamming Loss):")
    print("   1. Choose model with best VALIDATION performance")
    print("   2. Prefer models with smaller Hamming loss gap (primary indicator)")
    print("   3. Consider accuracy and F1 gaps as secondary indicators")
    print("   4. Evaluate computational efficiency for deployment")
    print("   5. Lower Hamming loss = better multi-label classification performance")
    
    print(f"\n📊 Understanding Hamming Loss:")
    print("   • Hamming Loss measures label-wise classification errors")
    print("   • Perfect score = 0.0, higher values = more errors")
    print("   • Particularly important for multi-label problems")
    print("   • Gap = Val_Hamming - Train_Hamming (positive = overfitting)")
    
    return successful_results

# Example usage
print("Starting comprehensive model comparison...")
all_results = compare_all_models(X_train_tfidf, y_train, X_val_tfidf, y_val, X_test_tfidf, y_test)
final_analysis = analyze_model_results(all_results)

In [None]:
# # Example usage of the enhanced training function
# print("Testing enhanced validation-controlled training...")
# print("="*60)

# # Test with LightGBM
# lgbm_model, lgbm_metrics = training_function_with_validation(
#     X_train_tfidf, y_train, X_val_tfidf, y_val, model_type='lightgbm'
# )



In [None]:
# # Test with XGBoost
# print(f"\n{'-'*40}")
# xgb_model, xgb_metrics = training_function_with_validation(
#     X_train_tfidf, y_train, X_val_tfidf, y_val, model_type='xgboost'
# )

# print(f"\nXGBoost Results:")
# print(f"  Validation Accuracy: {xgb_metrics['val_accuracy']:.4f}")
# print(f"  Overfitting Gap: {xgb_metrics['overfitting_gap']:.4f}")

# # Final test predictions
# lgbm_test_pred = lgbm_model.predict(X_test_tfidf)
# xgb_test_pred = xgb_model.predict(X_test_tfidf)

# lgbm_test_acc = accuracy_score(y_test, lgbm_test_pred)
# xgb_test_acc = accuracy_score(y_test, xgb_test_pred)

# print(f"\nFinal Test Results:")
# print(f"  LightGBM Test Accuracy: {lgbm_test_acc:.4f}")
# print(f"  XGBoost Test Accuracy: {xgb_test_acc:.4f}")

# # Determine best model
# if lgbm_metrics['val_accuracy'] > xgb_metrics['val_accuracy']:
#     best_val_model = 'LightGBM'
#     best_model = lgbm_model
#     best_test_acc = lgbm_test_acc
# else:
#     best_val_model = 'XGBoost'
#     best_model = xgb_model
#     best_test_acc = xgb_test_acc

# print(f"\n🏆 Best validation-controlled model: {best_val_model}")
# print(f"   Test Accuracy: {best_test_acc:.4f}")

In [None]:
# # Additional Validation Techniques for Overfitting Control

# from sklearn.model_selection import cross_val_score, StratifiedKFold
# from sklearn.model_selection import validation_curve, learning_curve
# import matplotlib.pyplot as plt

# def plot_learning_curve(estimator, X, y, title, cv=5, n_jobs=-1, 
#                        train_sizes=np.linspace(0.1, 1.0, 10)):
#     """
#     Generate a plot showing the learning curve for a model
#     """
#     train_sizes, train_scores, val_scores = learning_curve(
#         estimator, X, y, cv=cv, n_jobs=n_jobs, 
#         train_sizes=train_sizes, scoring='accuracy'
#     )
    
#     train_scores_mean = np.mean(train_scores, axis=1)
#     train_scores_std = np.std(train_scores, axis=1)
#     val_scores_mean = np.mean(val_scores, axis=1)
#     val_scores_std = np.std(val_scores, axis=1)
    
#     plt.figure(figsize=(10, 6))
#     plt.plot(train_sizes, train_scores_mean, 'o-', color='blue', label='Training score')
#     plt.fill_between(train_sizes, train_scores_mean - train_scores_std,
#                      train_scores_mean + train_scores_std, alpha=0.1, color='blue')
    
#     plt.plot(train_sizes, val_scores_mean, 'o-', color='red', label='Cross-validation score')
#     plt.fill_between(train_sizes, val_scores_mean - val_scores_std,
#                      val_scores_mean + val_scores_std, alpha=0.1, color='red')
    
#     plt.xlabel('Training Set Size')
#     plt.ylabel('Accuracy Score')
#     plt.title(f'Learning Curve - {title}')
#     plt.legend(loc='best')
#     plt.grid(True, alpha=0.3)
#     plt.tight_layout()
#     plt.show()
    
#     # Detect overfitting
#     final_gap = train_scores_mean[-1] - val_scores_mean[-1]
#     if final_gap > 0.1:
#         print(f"⚠️ WARNING: {title} shows signs of overfitting (gap: {final_gap:.4f})")
#     elif final_gap > 0.05:
#         print(f"🔶 MODERATE: {title} shows moderate overfitting (gap: {final_gap:.4f})")
#     else:
#         print(f"✅ GOOD: {title} shows good generalization (gap: {final_gap:.4f})")

# def cross_validate_with_overfitting_check(model, X, y, cv=5, model_name="Model"):
#     """
#     Perform cross-validation and check for overfitting signs
#     """
#     print(f"\nCross-validating {model_name}...")
    
#     # Perform cross-validation
#     cv_scores = cross_val_score(model, X, y, cv=cv, scoring='accuracy', n_jobs=-1)
    
#     # Train on full dataset to check training score
#     model.fit(X, y)
#     train_score = model.score(X, y)
    
#     cv_mean = cv_scores.mean()
#     cv_std = cv_scores.std()
    
#     print(f"  Cross-validation scores: {cv_scores}")
#     print(f"  CV Mean ± Std: {cv_mean:.4f} ± {cv_std:.4f}")
#     print(f"  Training score: {train_score:.4f}")
    
#     # Check for overfitting
#     overfitting_gap = train_score - cv_mean
#     print(f"  Overfitting gap: {overfitting_gap:.4f}")
    
#     if overfitting_gap > 0.1:
#         status = "🚨 HIGH OVERFITTING"
#     elif overfitting_gap > 0.05:
#         status = "⚠️ MODERATE OVERFITTING"
#     else:
#         status = "✅ GOOD GENERALIZATION"
    
#     print(f"  Status: {status}")
    
#     return {
#         'cv_scores': cv_scores,
#         'cv_mean': cv_mean,
#         'cv_std': cv_std,
#         'train_score': train_score,
#         'overfitting_gap': overfitting_gap,
#         'status': status
#     }

# def plot_validation_curve_param(estimator, X, y, param_name, param_range, title):
#     """
#     Plot validation curve for a specific parameter to find optimal value
#     """
#     train_scores, val_scores = validation_curve(
#         estimator, X, y, param_name=param_name, param_range=param_range,
#         cv=5, scoring='accuracy', n_jobs=-1
#     )
    
#     train_scores_mean = np.mean(train_scores, axis=1)
#     train_scores_std = np.std(train_scores, axis=1)
#     val_scores_mean = np.mean(val_scores, axis=1)
#     val_scores_std = np.std(val_scores, axis=1)
    
#     plt.figure(figsize=(10, 6))
#     plt.semilogx(param_range, train_scores_mean, 'o-', color='blue', label='Training score')
#     plt.fill_between(param_range, train_scores_mean - train_scores_std,
#                      train_scores_mean + train_scores_std, alpha=0.1, color='blue')
    
#     plt.semilogx(param_range, val_scores_mean, 'o-', color='red', label='Cross-validation score')
#     plt.fill_between(param_range, val_scores_mean - val_scores_std,
#                      val_scores_mean + val_scores_std, alpha=0.1, color='red')
    
#     plt.xlabel(param_name)
#     plt.ylabel('Accuracy Score')
#     plt.title(f'Validation Curve - {title}')
#     plt.legend(loc='best')
#     plt.grid(True, alpha=0.3)
#     plt.tight_layout()
#     plt.show()
    
#     # Find optimal parameter
#     optimal_idx = np.argmax(val_scores_mean)
#     optimal_param = param_range[optimal_idx]
#     optimal_score = val_scores_mean[optimal_idx]
    
#     print(f"Optimal {param_name}: {optimal_param}")
#     print(f"Optimal CV score: {optimal_score:.4f}")
    
#     return optimal_param, optimal_score

# # Example: Cross-validation analysis for overfitting detection
# print("COMPREHENSIVE VALIDATION ANALYSIS")
# print("="*60)

# # Sample a subset for faster computation in demo
# sample_size = min(1000, len(X_train_tfidf))
# X_sample = X_train_tfidf[:sample_size]
# y_sample = y_train[:sample_size]

# print(f"Using sample of {sample_size} examples for validation analysis...")

# # 1. Cross-validation for different models
# models_for_cv = {
#     'Logistic Regression': LogisticRegression(random_state=42, max_iter=1000),
#     'Random Forest': RandomForestClassifier(n_estimators=50, random_state=42, max_depth=10),
# }

# cv_results = {}
# for name, model in models_for_cv.items():
#     # Use OneVsRestClassifier for multi-label
#     multi_label_model = OneVsRestClassifier(model)
#     cv_results[name] = cross_validate_with_overfitting_check(
#         multi_label_model, X_sample, y_sample, cv=3, model_name=name
#     )

# # 2. Find models with best generalization
# print(f"\n{'='*60}")
# print("OVERFITTING SUMMARY:")
# print(f"{'Model':<20} | {'CV Score':<10} | {'Gap':<8} | {'Status'}")
# print(f"{'-'*65}")

# for name, results in cv_results.items():
#     print(f"{name:<20} | {results['cv_mean']:<10.4f} | {results['overfitting_gap']:<8.4f} | {results['status']}")

# # 3. Recommendations for overfitting control
# print(f"\n{'='*60}")
# print("RECOMMENDATIONS FOR OVERFITTING CONTROL:")
# print()
# print("1. 📊 VALIDATION MONITORING:")
# print("   - Always split data into train/validation/test")
# print("   - Monitor validation metrics during training")
# print("   - Use early stopping when validation stops improving")
# print()
# print("2. 🔧 MODEL REGULARIZATION:")
# print("   - Logistic Regression: Adjust C parameter (lower = more regularization)")
# print("   - Random Forest: Limit max_depth, increase min_samples_split")
# print("   - XGBoost/LightGBM: Use early_stopping_rounds, adjust learning_rate")
# print()
# print("3. 📈 TECHNIQUES IMPLEMENTED:")
# print("   - Train/Validation/Test split (70/15/15)")
# print("   - Cross-validation for robust evaluation")
# print("   - Early stopping for tree-based models")
# print("   - Validation gap monitoring")
# print("   - Learning curve analysis")
# print()
# print("4. 🎯 SELECTION CRITERIA:")
# print("   - Choose model with best VALIDATION performance")
# print("   - Prefer models with smaller train-validation gap")
# print("   - Consider cross-validation consistency")

# # Example of how to use validation curve for parameter tuning
# print(f"\n{'='*60}")
# print("PARAMETER TUNING WITH VALIDATION CURVES:")
# print("(Use this approach to find optimal hyperparameters)")
# print()
# print("Example code for Random Forest max_depth tuning:")
# print("""
# # Find optimal max_depth for Random Forest
# param_range = [3, 5, 7, 10, 15, 20]
# optimal_depth, optimal_score = plot_validation_curve_param(
#     OneVsRestClassifier(RandomForestClassifier(random_state=42)),
#     X_train_tfidf, y_train,
#     param_name='estimator__max_depth',
#     param_range=param_range,
#     title='Random Forest max_depth'
# )
# """)

## Transformers Encoder Model (MordenBERT)

In [19]:
import os
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm
from datasets import Dataset, DatasetDict
from datasets import Sequence, Value
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
from transformers import DataCollatorWithPadding
from transformers import EarlyStoppingCallback
import evaluate
import warnings

warnings.filterwarnings('ignore')

In [3]:
### Load Dataset for model training and evaluation ###
data_path=os.path.join(os.getcwd(), 'processed_data')
with open(os.path.join(data_path,'train_arrays.pkl'), 'rb') as f:
    train_data = pickle.load(f)
    X_train = train_data['X_train']
    y_train = train_data['y_train']

with open(os.path.join(data_path,'val_arrays.pkl'), 'rb') as f:
    val_data = pickle.load(f)
    X_val = val_data['X_val']
    y_val = val_data['y_val']

with open(os.path.join(data_path,'test_arrays.pkl'), 'rb') as f:
    test_data = pickle.load(f)
    X_test = test_data['X_test']
    y_test = test_data['y_test']

with open(os.path.join(data_path,'class_name.pkl'), 'rb') as f:
    class_name_data = pickle.load(f)
    class_name = class_name_data['class_name']

In [4]:
def create_datasets_from_arrays(X_train, y_train, X_val=None, y_val=None, X_test=None, y_test=None):
    """
    Convert arrays into HuggingFace datasets format with specified structure
    
    Returns:
        DatasetDict with features:
        - dataset["train"]["text"]: text data
        - dataset["train"]["labels"]: multi-label arrays
        - dataset["val"]["text"]: validation text data (if provided)
        - dataset["val"]["labels"]: validation labels (if provided)
        - dataset["test"]["text"]: test text data (if provided)
        - dataset["test"]["labels"]: test labels (if provided)
    """
    # Create training dataset
    train_dict = {
        "text": X_train.tolist() if hasattr(X_train, 'tolist') else list(X_train),
        "labels": y_train.tolist() if hasattr(y_train, 'tolist') else list(y_train)
    }
    
    datasets_dict = {
        "train": Dataset.from_dict(train_dict)
    }
    
    # Add validation dataset if provided
    if X_val is not None and y_val is not None:
        val_dict = {
            "text": X_val.tolist() if hasattr(X_val, 'tolist') else list(X_val),
            "labels": y_val.tolist() if hasattr(y_val, 'tolist') else list(y_val)
        }
        datasets_dict["val"] = Dataset.from_dict(val_dict)
    
    # Add test dataset if provided
    if X_test is not None and y_test is not None:
        test_dict = {
            "text": X_test.tolist() if hasattr(X_test, 'tolist') else list(X_test),
            "labels": y_test.tolist() if hasattr(y_test, 'tolist') else list(y_test)
        }
        datasets_dict["test"] = Dataset.from_dict(test_dict)

    # Create DatasetDict
    dataset = DatasetDict(datasets_dict)
    
    return dataset

In [5]:
# Create the datasets
dataset = create_datasets_from_arrays(X_train, y_train, X_val, y_val, X_test, y_test)

dataset


DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 11597
    })
    val: Dataset({
        features: ['text', 'labels'],
        num_rows: 2485
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 2486
    })
})

In [6]:
model_path = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_path)

print("{:<25}{:<15,}".format("Maximal context length:",tokenizer.model_max_length))
print("{:<25}{:<15,}".format("Vocabulary size :",tokenizer.vocab_size))

Maximal context length:  8,192          
Vocabulary size :        50,280         


In [7]:
# 🔧 FIXED TOKENIZATION AND DATA FORMAT
# This section addresses the data format issues that cause training failures

def preprocess_function(examples):
    """
    Proper tokenization function for multi-label classification.
    Ensures all outputs are compatible with HuggingFace Trainer.
    """
    # Handle batch vs single example
    if isinstance(examples['text'], str):
        texts = [examples['text']]
        labels = [examples['labels']]
    else:
        texts = examples['text']
        labels = examples['labels']
    
    # Tokenize the texts
    tokenized = tokenizer(
        texts,
        truncation=True,
        padding=True,  # Will be handled by data collator
        # max_length=tokenizer.model_max_length,  
        max_length=1024, # Adjust based on your model's limit
        return_tensors=None  # Don't return tensors yet, let data collator handle it
    )
    
    # Ensure labels are float32 for BCEWithLogitsLoss
    if isinstance(labels[0], (list, np.ndarray)):
        tokenized['labels'] = [np.array(label, dtype=np.float32).tolist() for label in labels]
    else:
        tokenized['labels'] = [np.array(labels, dtype=np.float32).tolist()]
    
    return tokenized

print("🔧 Re-tokenizing dataset with fixed function...")

# Apply the tokenization function
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=['text'],  # Remove the problematic text column
    desc="Tokenizing dataset"
)

# Define the proper feature type for multi-label classification
label_feature = Sequence(Value("float32"), length=len(class_name))

# Cast the labels column to float32 for all splits
print("🔄 Converting labels to float32...")
for split_name in tokenized_dataset.keys():
    tokenized_dataset[split_name] = tokenized_dataset[split_name].cast_column("labels", label_feature)

# Verify the tokenized dataset structure
print("\n✅ Tokenized dataset verification:")
print(f"Features: {list(tokenized_dataset['train'].features.keys())}")

# Check a sample
sample = tokenized_dataset["train"][0]
print(f"\nSample structure:")
for key, value in sample.items():
    if isinstance(value, (list, np.ndarray)):
        value_info = f"List/Array of length {len(value)}, dtype: {type(value[0]) if value else 'empty'}"
        if key == 'labels':
            value_info += f", shape: {np.array(value).shape}, sum: {np.sum(value)}"
    else:
        value_info = f"Type: {type(value)}, Value: {value}"
    print(f"  {key}: {value_info}")

# Verify labels are properly formatted
sample_labels = np.array(sample['labels'])
print(f"\n🎯 Labels verification:")
print(f"  Labels dtype: {sample_labels.dtype}")
print(f"  Labels shape: {sample_labels.shape}")
print(f"  Expected shape: ({len(class_name)},)")
print(f"  Labels range: [{sample_labels.min():.1f}, {sample_labels.max():.1f}]")

# Test tensor conversion
test_labels = torch.tensor(sample['labels'], dtype=torch.float32)
print(f"  PyTorch tensor dtype: {test_labels.dtype}")
print(f"  PyTorch tensor shape: {test_labels.shape}")

if test_labels.dtype == torch.float32:
    print("✅ SUCCESS: Labels are properly formatted as float32")
    print("🚀 Ready for training!")
else:
    print(f"❌ ISSUE: Labels are {test_labels.dtype}, expected float32")

print(f"\n📊 Dataset sizes after tokenization:")
for split_name, split_data in tokenized_dataset.items():
    print(f"  {split_name}: {len(split_data)} samples")

🔧 Re-tokenizing dataset with fixed function...


Tokenizing dataset:   0%|          | 0/11597 [00:00<?, ? examples/s]

Tokenizing dataset:   0%|          | 0/2485 [00:00<?, ? examples/s]

Tokenizing dataset:   0%|          | 0/2486 [00:00<?, ? examples/s]

🔄 Converting labels to float32...


Casting the dataset:   0%|          | 0/11597 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2485 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2486 [00:00<?, ? examples/s]


✅ Tokenized dataset verification:
Features: ['labels', 'input_ids', 'attention_mask']

Sample structure:
  labels: List/Array of length 27, dtype: <class 'float'>, shape: (27,), sum: 1.0
  input_ids: List/Array of length 1024, dtype: <class 'int'>
  attention_mask: List/Array of length 1024, dtype: <class 'int'>

🎯 Labels verification:
  Labels dtype: float64
  Labels shape: (27,)
  Expected shape: (27,)
  Labels range: [0.0, 1.0]
  PyTorch tensor dtype: torch.float32
  PyTorch tensor shape: torch.Size([27])
✅ SUCCESS: Labels are properly formatted as float32
🚀 Ready for training!

📊 Dataset sizes after tokenization:
  train: 11597 samples
  val: 2485 samples
  test: 2486 samples


In [8]:
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 11597
    })
    val: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 2485
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 2486
    })
})

In [9]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

class2id = {class_:id for id, class_ in enumerate(class_name)}
id2class = {id:class_ for class_, id in class2id.items()}


model = AutoModelForSequenceClassification.from_pretrained(model_path, 
                                                           num_labels=len(class_name),
                                                           id2label=id2class, 
                                                           label2id=class2id,
                                                           problem_type = "multi_label_classification"
)

print()

# Verify model is properly configured for multi-label classification
print("🤖 Model Configuration Verification:")
print(f"  Model type: {type(model).__name__}")
print(f"  Number of labels: {model.config.num_labels}")
print(f"  Problem type: {getattr(model.config, 'problem_type', 'Not set')}")
print(f"  Expected labels: {len(class_name)}")

# Check if model configuration matches our data
if model.config.num_labels != len(class_name):
    print(f"⚠️ WARNING: Model expects {model.config.num_labels} labels, but data has {len(class_name)}")
    print("  This might cause issues during training")
else:
    print(f"✅ Model configuration matches data: {len(class_name)} labels")

# Verify model parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n📊 Model Parameters:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Non-trainable parameters: {total_params - trainable_params:,}")

print(f"\n✅ Data collator and model setup completed!")


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



🤖 Model Configuration Verification:
  Model type: ModernBertForSequenceClassification
  Number of labels: 27
  Problem type: multi_label_classification
  Expected labels: 27
✅ Model configuration matches data: 27 labels

📊 Model Parameters:
  Total parameters: 149,625,627
  Trainable parameters: 149,625,627
  Non-trainable parameters: 0

✅ Data collator and model setup completed!


In [10]:
# 📊 COMPREHENSIVE EVALUATION FUNCTION FOR MULTI-LABEL CLASSIFICATION
# This function calculates all relevant metrics for multi-label problems

import numpy as np
from sklearn.metrics import (
    precision_score, recall_score, f1_score, 
    roc_auc_score, average_precision_score,
    hamming_loss, jaccard_score, accuracy_score
)

def sigmoid(x):
    """Sigmoid activation function"""
    return 1/(1 + np.exp(-x))

def comprehensive_evaluation(y_true, y_pred_proba, y_pred_binary=None, threshold=0.5):
    """
    Comprehensive evaluation for multi-label classification with all averaging methods
    
    Args:
        y_true: Ground truth binary labels (n_samples, n_labels)
        y_pred_proba: Predicted probabilities (n_samples, n_labels)
        y_pred_binary: Predicted binary labels (n_samples, n_labels), optional
        threshold: Threshold for converting probabilities to binary (default: 0.5)
    
    Returns:
        dict: Comprehensive metrics including all averaging methods
    """
    if y_pred_binary is None:
        y_pred_binary = (y_pred_proba >= threshold).astype(int)
    
    metrics = {}
    
    try:
        # SAMPLES AVERAGE (per-sample then average across samples)
        metrics['precision_samples'] = precision_score(y_true, y_pred_binary, average='samples', zero_division=0)
        metrics['recall_samples'] = recall_score(y_true, y_pred_binary, average='samples', zero_division=0)
        metrics['f1_samples'] = f1_score(y_true, y_pred_binary, average='samples', zero_division=0)
        
        # MICRO AVERAGE (global aggregation)
        metrics['precision_micro'] = precision_score(y_true, y_pred_binary, average='micro', zero_division=0)
        metrics['recall_micro'] = recall_score(y_true, y_pred_binary, average='micro', zero_division=0)
        metrics['f1_micro'] = f1_score(y_true, y_pred_binary, average='micro', zero_division=0)
        
        # MACRO AVERAGE (unweighted average across labels)
        metrics['precision_macro'] = precision_score(y_true, y_pred_binary, average='macro', zero_division=0)
        metrics['recall_macro'] = recall_score(y_true, y_pred_binary, average='macro', zero_division=0)
        metrics['f1_macro'] = f1_score(y_true, y_pred_binary, average='macro', zero_division=0)
        
        # WEIGHTED AVERAGE (weighted by support)
        metrics['precision_weighted'] = precision_score(y_true, y_pred_binary, average='weighted', zero_division=0)
        metrics['recall_weighted'] = recall_score(y_true, y_pred_binary, average='weighted', zero_division=0)
        metrics['f1_weighted'] = f1_score(y_true, y_pred_binary, average='weighted', zero_division=0)
        
        # ACCURACY METRICS
        metrics['accuracy'] = accuracy_score(y_true, y_pred_binary)
        metrics['hamming_loss'] = hamming_loss(y_true, y_pred_binary)
        
        # JACCARD (IoU) METRICS 
        metrics['jaccard_samples'] = jaccard_score(y_true, y_pred_binary, average='samples', zero_division=0)
        metrics['jaccard_macro'] = jaccard_score(y_true, y_pred_binary, average='macro', zero_division=0)
        metrics['jaccard_weighted'] = jaccard_score(y_true, y_pred_binary, average='weighted', zero_division=0)
        
        # ROC-AUC METRICS (using probabilities)
        try:
            metrics['roc_auc_micro'] = roc_auc_score(y_true, y_pred_proba, average='micro')
            metrics['roc_auc_macro'] = roc_auc_score(y_true, y_pred_proba, average='macro')
            metrics['roc_auc_weighted'] = roc_auc_score(y_true, y_pred_proba, average='weighted')
            metrics['roc_auc_samples'] = roc_auc_score(y_true, y_pred_proba, average='samples')
        except ValueError as e:
            print(f"Warning: ROC-AUC calculation failed: {e}")
            metrics['roc_auc_micro'] = 0.0
            metrics['roc_auc_macro'] = 0.0
            metrics['roc_auc_weighted'] = 0.0
            metrics['roc_auc_samples'] = 0.0
        
        # PR-AUC METRICS (using probabilities)
        try:
            metrics['pr_auc_micro'] = average_precision_score(y_true, y_pred_proba, average='micro')
            metrics['pr_auc_macro'] = average_precision_score(y_true, y_pred_proba, average='macro')
            metrics['pr_auc_weighted'] = average_precision_score(y_true, y_pred_proba, average='weighted')
            metrics['pr_auc_samples'] = average_precision_score(y_true, y_pred_proba, average='samples')
        except ValueError as e:
            print(f"Warning: PR-AUC calculation failed: {e}")
            metrics['pr_auc_micro'] = 0.0
            metrics['pr_auc_macro'] = 0.0
            metrics['pr_auc_weighted'] = 0.0
            metrics['pr_auc_samples'] = 0.0
        
    except Exception as e:
        print(f"Error in comprehensive_evaluation: {e}")
        # Return minimal metrics if calculation fails
        metrics = {
            'precision_micro': 0.0, 'recall_micro': 0.0, 'f1_micro': 0.0,
            'precision_macro': 0.0, 'recall_macro': 0.0, 'f1_macro': 0.0,
            'accuracy': 0.0, 'hamming_loss': 1.0
        }
    
    return metrics



In [11]:
# 🎯 COMPUTE METRICS FUNCTION FOR TRAINER
# This function is called during training to evaluate the model

def compute_metrics(eval_pred):
    """
    Enhanced compute_metrics function for transformers Trainer using comprehensive evaluation
    """
    predictions, labels = eval_pred
    
    # Apply sigmoid to get probabilities
    predictions_proba = sigmoid(predictions)
    
    # Convert to binary predictions using threshold 0.5
    predictions_binary = (predictions_proba > 0.5).astype(int)
    
    # Ensure labels are integers
    labels = labels.astype(int)
    
    # Use comprehensive evaluation
    metrics = comprehensive_evaluation(
        y_true=labels,
        y_pred_proba=predictions_proba,
        y_pred_binary=predictions_binary,
        threshold=0.5
    )
    
    # Return metrics with eval_ prefix for Trainer compatibility
    return {
        # Primary metrics for monitoring
        'eval_f1_micro': metrics['f1_micro'],
        'eval_f1_macro': metrics['f1_macro'],
        'eval_accuracy': metrics['accuracy'],
        'eval_hamming_loss': metrics['hamming_loss'],
        
        # Precision metrics
        'eval_precision_micro': metrics['precision_micro'],
        'eval_precision_macro': metrics['precision_macro'],
        'eval_precision_samples': metrics['precision_samples'],
        'eval_precision_weighted': metrics['precision_weighted'],
        
        # Recall metrics
        'eval_recall_micro': metrics['recall_micro'],
        'eval_recall_macro': metrics['recall_macro'],
        'eval_recall_samples': metrics['recall_samples'],
        'eval_recall_weighted': metrics['recall_weighted'],
        
        # F1 metrics
        'eval_f1_samples': metrics['f1_samples'],
        'eval_f1_weighted': metrics['f1_weighted'],
        
        # ROC-AUC metrics
        'eval_roc_auc_micro': metrics['roc_auc_micro'],
        'eval_roc_auc_macro': metrics['roc_auc_macro'],
        'eval_roc_auc_weighted': metrics['roc_auc_weighted'],
        'eval_roc_auc_samples': metrics['roc_auc_samples'],
        
        # PR-AUC metrics
        'eval_pr_auc_micro': metrics['pr_auc_micro'],
        'eval_pr_auc_macro': metrics['pr_auc_macro'],
        'eval_pr_auc_weighted': metrics['pr_auc_weighted'],
        'eval_pr_auc_samples': metrics['pr_auc_samples'],
        
        # Jaccard metrics
        'eval_jaccard_samples': metrics['jaccard_samples'],
        'eval_jaccard_macro': metrics['jaccard_macro'],
        'eval_jaccard_weighted': metrics['jaccard_weighted'],
    }

In [12]:
# 🧪 PRE-TRAINING VERIFICATION TEST
# Test that all components work together before starting training

print("🧪 RUNNING PRE-TRAINING VERIFICATION TESTS")
print("=" * 60)

# Test 1: Verify data collator works with a batch
print("1️⃣ Testing data collator...")
test_batch = [tokenized_dataset["train"][i] for i in range(3)]
try:
    collated_batch = data_collator(test_batch)
    print(f"   ✅ Data collator working: batch shape {collated_batch['input_ids'].shape}")
    print(f"   ✅ Labels shape: {collated_batch['labels'].shape}")
    print(f"   ✅ Labels dtype: {collated_batch['labels'].dtype}")
except Exception as e:
    print(f"   ❌ Data collator failed: {e}")

# Test 2: Verify model forward pass
print("\n2️⃣ Testing model forward pass...")
try:
    with torch.no_grad():
        outputs = model(**{k: v for k, v in collated_batch.items() if k in ['input_ids', 'attention_mask', 'labels']})
    print(f"   ✅ Model forward pass working")
    print(f"   ✅ Output logits shape: {outputs.logits.shape}")
    print(f"   ✅ Loss computed: {outputs.loss.item():.4f}")
except Exception as e:
    print(f"   ❌ Model forward pass failed: {e}")

# Test 3: Verify compute_metrics function
print("\n3️⃣ Testing compute_metrics function...")
try:
    # Create dummy predictions for testing
    dummy_predictions = np.random.randn(10, len(class_name))
    dummy_labels = np.random.randint(0, 2, (10, len(class_name))).astype(np.float32)
    
    eval_pred = (dummy_predictions, dummy_labels)
    test_metrics = compute_metrics(eval_pred)
    
    print(f"   ✅ Compute metrics working")
    print(f"   ✅ Primary metrics calculated:")
    print(f"      - F1 Micro: {test_metrics['eval_f1_micro']:.4f}")
    print(f"      - Hamming Loss: {test_metrics['eval_hamming_loss']:.4f}")
    print(f"      - Accuracy: {test_metrics['eval_accuracy']:.4f}")
except Exception as e:
    print(f"   ❌ Compute metrics failed: {e}")

# Test 4: Check GPU/CPU availability
print("\n4️⃣ Checking compute environment...")
if torch.cuda.is_available():
    print(f"   ✅ GPU available: {torch.cuda.get_device_name()}")
    print(f"   ✅ CUDA version: {torch.version.cuda}")
    print(f"   ✅ Memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
else:
    print("   ℹ️ Running on CPU")

# Final verification
print(f"\n🎯 FINAL VERIFICATION:")
print(f"   ✅ Dataset loaded: {len(tokenized_dataset)} splits")
print(f"   ✅ Training samples: {len(tokenized_dataset['train'])}")
print(f"   ✅ Validation samples: {len(tokenized_dataset['val'])}")
print(f"   ✅ Model loaded: {type(model).__name__}")
print(f"   ✅ Tokenizer loaded: {type(tokenizer).__name__}")
print(f"   ✅ Data collator ready: {type(data_collator).__name__}")
print(f"   ✅ Metrics function ready: compute_metrics")

print(f"\n🚀 All systems ready for training!")

🧪 RUNNING PRE-TRAINING VERIFICATION TESTS
1️⃣ Testing data collator...
   ✅ Data collator working: batch shape torch.Size([3, 1024])
   ✅ Labels shape: torch.Size([3, 27])
   ✅ Labels dtype: torch.float32

2️⃣ Testing model forward pass...
   ✅ Model forward pass working
   ✅ Output logits shape: torch.Size([3, 27])
   ✅ Loss computed: 0.7586

3️⃣ Testing compute_metrics function...
   ✅ Compute metrics working
   ✅ Primary metrics calculated:
      - F1 Micro: 0.4689
      - Hamming Loss: 0.5370
      - Accuracy: 0.0000

4️⃣ Checking compute environment...
   ✅ GPU available: NVIDIA GeForce RTX 4080 Laptop GPU
   ✅ CUDA version: 12.6
   ✅ Memory allocated: 0.00 GB

🎯 FINAL VERIFICATION:
   ✅ Dataset loaded: 3 splits
   ✅ Training samples: 11597
   ✅ Validation samples: 2485
   ✅ Model loaded: ModernBertForSequenceClassification
   ✅ Tokenizer loaded: PreTrainedTokenizerFast
   ✅ Data collator ready: DataCollatorWithPadding
   ✅ Metrics function ready: compute_metrics

🚀 All systems re

In [13]:
# Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# 🎯 OPTIMIZED METRICS CONFIGURATION FOR MULTI-LABEL CLASSIFICATION
# Hamming Loss is the most appropriate metric for multi-label problems
training_args = TrainingArguments(
    # Output and logging
    output_dir="./model_output",
    logging_dir="./logs",
    logging_steps=50,
    logging_strategy="steps",
    
    # Learning parameters
    learning_rate=2e-5,
    lr_scheduler_type="linear",  # Linear decay
    warmup_ratio=0.1,  # 10% warmup
    weight_decay=0.01,
    
    # Batch sizes (adjust based on GPU memory)
    per_device_train_batch_size=2,
    per_device_eval_batch_size=3,
    gradient_accumulation_steps=24,  # Effective batch size = 2 * 24 = 48
    
    # Training epochs and evaluation
    num_train_epochs=5,  # Increased for better convergence
    eval_strategy="steps",  # More frequent evaluation
    eval_steps=100,  # Evaluate every 100 steps
    
    # 🎯 OPTIMAL METRICS FOR MULTI-LABEL CLASSIFICATION
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,  # Keep only 3 best checkpoints
    load_best_model_at_end=True,
    
    # 🔥 RECOMMENDED: Use Hamming Loss for multi-label problems
    metric_for_best_model="eval_hamming_loss",  # Primary metric: lower is better
    greater_is_better=False,  # Hamming loss: lower = better performance
    
    # Alternative good options:
    # metric_for_best_model="eval_f1_micro",     # Current choice - also excellent
    # metric_for_best_model="eval_jaccard_samples", # IoU metric - good for multi-label
    
    # Memory and performance optimization
    dataloader_pin_memory=False,  # Disable to avoid forking issues
    dataloader_num_workers=0,     # Disable multiprocessing
    remove_unused_columns=False,  # Keep all columns for multi-label
    
    # Mixed precision for faster training (if GPU supports it)
    fp16=True,  # Enable if using compatible GPU
    
    # Reproducibility
    seed=42,
    data_seed=42,
    
    # Report metrics
    report_to=None,  # Disable wandb/tensorboard if not needed
    run_name="multi_label_posture_classification",
)

print("🎯 MULTI-LABEL METRICS CONFIGURATION:")
print("=" * 60)
print("✅ PRIMARY METRIC: Hamming Loss (optimal for multi-label)")
print("   • Measures label-wise classification errors")
print("   • Range: 0.0 (perfect) to 1.0 (worst)")
print("   • Lower values = better performance")
print("   • More intuitive than F1 for legal document classification")
print()
print("📊 MONITORING METRICS (all calculated):")
print("   • F1-Micro/Macro/Weighted: Overall performance")
print("   • Precision/Recall: Per-label quality")
print("   • ROC-AUC/PR-AUC: Ranking quality")
print("   • Jaccard Score: Label overlap similarity")
print("   • Accuracy: Exact match rate")

# Early stopping callback for overfitting control
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3,  # Stop if no improvement for 3 evaluations
    early_stopping_threshold=0.001  # Minimum improvement threshold
)

# Initialize trainer with enhanced configuration (using processing_class)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["val"],
    processing_class=tokenizer,  # Updated parameter name
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],  # Add early stopping callback
)

print("🚀 Starting training with enhanced configuration...")
print(f"📊 Training samples: {len(tokenized_dataset['train'])}")
print(f"📊 Validation samples: {len(tokenized_dataset['val'])}")
print(f"🎯 Target metric: {training_args.metric_for_best_model}")
print(f"⏱️ Total epochs: {training_args.num_train_epochs}")
print(f"🔄 Evaluation every: {training_args.eval_steps} steps")
print(f"💾 Saving every: {training_args.save_steps} steps")
print(f"⏹️ Early stopping patience: {early_stopping.early_stopping_patience}")

# Start training with error handling
try:
    print("\n🎯 Starting training...")
    trainer.train()
    print("✅ Training completed successfully!")
except Exception as e:
    print(f"❌ Training failed with error: {e}")
    print("💡 Consider:")
    print("   - Reducing batch size if out of memory")
    print("   - Checking data format compatibility")
    print("   - Verifying model and tokenizer compatibility")
    print("   - The data format may need fixing - check tokenization step")

🎯 MULTI-LABEL METRICS CONFIGURATION:
✅ PRIMARY METRIC: Hamming Loss (optimal for multi-label)
   • Measures label-wise classification errors
   • Range: 0.0 (perfect) to 1.0 (worst)
   • Lower values = better performance
   • More intuitive than F1 for legal document classification

📊 MONITORING METRICS (all calculated):
   • F1-Micro/Macro/Weighted: Overall performance
   • Precision/Recall: Per-label quality
   • ROC-AUC/PR-AUC: Ranking quality
   • Jaccard Score: Label overlap similarity
   • Accuracy: Exact match rate
🚀 Starting training with enhanced configuration...
📊 Training samples: 11597
📊 Validation samples: 2485
🎯 Target metric: eval_hamming_loss
⏱️ Total epochs: 5
🔄 Evaluation every: 100 steps
💾 Saving every: 100 steps
⏹️ Early stopping patience: 3

🎯 Starting training...


Step,Training Loss,Validation Loss,F1 Micro,F1 Macro,Accuracy,Hamming Loss,Precision Micro,Precision Macro,Precision Samples,Precision Weighted,Recall Micro,Recall Macro,Recall Samples,Recall Weighted,F1 Samples,F1 Weighted,Roc Auc Micro,Roc Auc Macro,Roc Auc Weighted,Roc Auc Samples,Pr Auc Micro,Pr Auc Macro,Pr Auc Weighted,Pr Auc Samples,Jaccard Samples,Jaccard Macro,Jaccard Weighted
100,2.7852,0.111019,0.624433,0.086674,0.311871,0.038259,0.702436,0.10899,0.717505,0.560948,0.562023,0.095133,0.602897,0.562023,0.622093,0.515207,0.936868,0.779747,0.910602,0.942352,0.691125,0.207249,0.690559,0.782173,0.5411,0.070114,0.437663
200,1.7751,0.071259,0.763127,0.219544,0.509859,0.023936,0.867248,0.377614,0.851576,0.759436,0.681327,0.191785,0.734856,0.681327,0.759745,0.696875,0.972228,0.895318,0.955668,0.973028,0.845372,0.365474,0.781209,0.891692,0.695453,0.172415,0.617325
300,1.4587,0.06425,0.788728,0.328245,0.556942,0.022237,0.852986,0.447978,0.85446,0.781673,0.733474,0.303061,0.778015,0.733474,0.78884,0.735088,0.980183,0.926427,0.963924,0.980714,0.869265,0.454025,0.805106,0.909977,0.729873,0.259198,0.661487
400,1.4818,0.061398,0.796647,0.374351,0.56338,0.02097,0.882767,0.498025,0.87552,0.797587,0.725836,0.335525,0.776237,0.725836,0.795694,0.742269,0.982317,0.941393,0.96876,0.981947,0.878684,0.512911,0.823248,0.915431,0.736197,0.293438,0.664836
500,1.1692,0.060316,0.798918,0.421377,0.562978,0.021045,0.869767,0.554482,0.863783,0.840607,0.738741,0.403935,0.785996,0.738741,0.796056,0.755967,0.983268,0.94652,0.969531,0.983421,0.881363,0.548006,0.834789,0.923255,0.736935,0.332189,0.674786
600,1.0795,0.055266,0.821168,0.468555,0.605634,0.01939,0.858827,0.560307,0.857277,0.825217,0.786674,0.437545,0.826036,0.786674,0.815962,0.793858,0.986564,0.952711,0.972204,0.986509,0.896173,0.584159,0.844948,0.93143,0.762502,0.374139,0.7115
700,1.0713,0.055101,0.821589,0.546144,0.597586,0.01951,0.851412,0.620372,0.862307,0.828864,0.793785,0.513095,0.831737,0.793785,0.821136,0.801769,0.986717,0.958125,0.973318,0.985989,0.897018,0.607524,0.851816,0.932736,0.765319,0.430719,0.715773
800,0.8176,0.052281,0.831433,0.548278,0.624145,0.018064,0.880931,0.71563,0.885915,0.855441,0.7872,0.493188,0.830228,0.7872,0.832404,0.808775,0.987332,0.956924,0.97293,0.9869,0.905655,0.631611,0.85615,0.936964,0.779571,0.434646,0.724123
900,0.8138,0.053655,0.827237,0.543852,0.613682,0.018794,0.862079,0.672494,0.869484,0.840754,0.795101,0.488834,0.834923,0.795101,0.826854,0.805474,0.98713,0.956197,0.972626,0.986695,0.90315,0.634566,0.857131,0.933913,0.772897,0.427341,0.719406
1000,0.7269,0.052462,0.839298,0.594264,0.635412,0.017602,0.868243,0.720676,0.879074,0.851082,0.81222,0.544148,0.84951,0.81222,0.840448,0.823683,0.987556,0.956128,0.972587,0.987167,0.906134,0.635478,0.857563,0.936185,0.789182,0.472721,0.737553


✅ Training completed successfully!


### Inference 

In [16]:
checkpoint_dir = "model_output/checkpoint-1000"
# Load model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)


In [None]:
# Prepare DataLoader
from torch.utils.data import DataLoader

test_dataset=tokenized_dataset["test"]
test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
test_loader = DataLoader(test_dataset, batch_size=8)

# Run inference
model.eval()
all_logits = []
all_labels = []
with torch.no_grad():
    for batch in tqdm(test_loader,total=len(test_loader),leave=True,position=0):
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        labels = batch["labels"].cpu().numpy()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits.cpu().numpy()
        all_logits.append(logits)
        all_labels.append(labels)


# Concatenate all batches
all_logits = np.concatenate(all_logits, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

# Apply sigmoid to get probabilities
probas = torch.sigmoid(torch.tensor(all_logits)).numpy()

# Apply threshold to get binary predictions
threshold = 0.5
predictions = (probas >= threshold).astype(int)

results = {
    "f1_micro": f1_score(all_labels, predictions, average="micro"),
    "f1_macro": f1_score(all_labels, predictions, average="macro"),
    "accuracy": accuracy_score(all_labels, predictions),
    "hamming_loss": hamming_loss(all_labels, predictions),
    "precision_micro": precision_score(all_labels, predictions, average="micro"),
    "recall_micro": recall_score(all_labels, predictions, average="micro"),
    "jaccard_weighted": jaccard_score(all_labels, predictions, average="weighted"),
}

  5%|▍         | 14/311 [01:16<24:51,  5.02s/it]

In [30]:
predictions, labels

(array([ 0, 19, 20, ..., 20, 24, 20], shape=(2486,)),
 array([[1., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], shape=(2486, 27)))

In [28]:
labels = np.array(test_dataset["labels"])
predictions = np.array(predictions)
eval_pred=(predictions, labels)
compute_metrics(eval_pred)

Error in comprehensive_evaluation: Classification metrics can't handle a mix of multilabel-indicator and binary targets


KeyError: 'precision_samples'

In [26]:

test_dataset

Dataset({
    features: ['labels', 'input_ids', 'attention_mask'],
    num_rows: 2486
})

{'input_ids': tensor([50281,  7307,    18,  ..., 20461,    56, 50282]),
 'attention_mask': tensor([1, 1, 1,  ..., 1, 1, 1])}

In [14]:
# Post-Training Evaluation and Testing

print("🔍 COMPREHENSIVE MODEL EVALUATION")
print("=" * 60)

# Evaluate on validation set
print("\n📊 Validation Set Evaluation:")
val_results = trainer.evaluate()

# Display key metrics
key_metrics = [
    'eval_f1_micro', 'eval_f1_macro', 'eval_accuracy', 'eval_hamming_loss',
    'eval_precision_micro', 'eval_recall_micro', 'eval_roc_auc_macro'
]

for metric in key_metrics:
    if metric in val_results:
        print(f"   {metric}: {val_results[metric]:.4f}")

🔍 COMPREHENSIVE MODEL EVALUATION

📊 Validation Set Evaluation:


   eval_f1_micro: 0.8393
   eval_f1_macro: 0.5943
   eval_accuracy: 0.6354
   eval_hamming_loss: 0.0176
   eval_precision_micro: 0.8682
   eval_recall_micro: 0.8122
   eval_roc_auc_macro: 0.9561


In [15]:
# Test on test set if available
if "test" in tokenized_dataset:
    print("\n🎯 Test Set Evaluation:")
    test_results = trainer.evaluate(eval_dataset=tokenized_dataset["test"])
    
    for metric in key_metrics:
        if metric in test_results:
            print(f"   {metric}: {test_results[metric]:.4f}")


🎯 Test Set Evaluation:
   eval_f1_micro: 0.8338
   eval_f1_macro: 0.5600
   eval_accuracy: 0.6307
   eval_hamming_loss: 0.0183
   eval_precision_micro: 0.8550
   eval_recall_micro: 0.8136
   eval_roc_auc_macro: 0.9555


In [None]:




# # Get predictions for detailed analysis
# print("\n🔬 Detailed Prediction Analysis:")

# # Predict on validation set
# val_predictions = trainer.predict(tokenized_dataset["val"])
# val_probs = sigmoid(val_predictions.predictions)
# val_binary = (val_probs > 0.5).astype(int)
# val_true = val_predictions.label_ids

# # Use comprehensive evaluation function
# detailed_metrics = comprehensive_evaluation(
#     y_true=val_true,
#     y_pred_proba=val_probs,
#     y_pred_binary=val_binary
# )

# print("\n📈 Comprehensive Metrics Summary:")
# print("-" * 50)

# # Group metrics by type
# metric_groups = {
#     'Precision': ['precision_micro', 'precision_macro', 'precision_samples', 'precision_weighted'],
#     'Recall': ['recall_micro', 'recall_macro', 'recall_samples', 'recall_weighted'],
#     'F1-Score': ['f1_micro', 'f1_macro', 'f1_samples', 'f1_weighted'],
#     'ROC-AUC': ['roc_auc_macro', 'roc_auc_weighted', 'roc_auc_samples'],
#     'PR-AUC': ['pr_auc_macro', 'pr_auc_weighted', 'pr_auc_samples'],
#     'Other': ['accuracy', 'hamming_loss', 'jaccard_macro', 'jaccard_samples']
# }

# for group_name, metrics in metric_groups.items():
#     print(f"\n{group_name}:")
#     for metric in metrics:
#         if metric in detailed_metrics:
#             print(f"   {metric}: {detailed_metrics[metric]:.4f}")

# # Sample predictions analysis
# print("\n🔍 Sample Predictions Analysis:")
# sample_size = min(5, len(val_true))
# for i in range(sample_size):
#     print(f"\nSample {i+1}:")
#     print(f"   True labels: {val_true[i]}")
#     print(f"   Predicted:   {val_binary[i]}")
#     print(f"   Probabilities: {val_probs[i]}")
#     print(f"   Match: {'✅' if np.array_equal(val_true[i], val_binary[i]) else '❌'}")

# # Model performance summary
# print(f"\n{'='*60}")
# print("🏆 MODEL PERFORMANCE SUMMARY")
# print(f"{'='*60}")
# print(f"✅ Best Metric (F1-Micro): {detailed_metrics['f1_micro']:.4f}")
# print(f"📊 Accuracy: {detailed_metrics['accuracy']:.4f}")
# print(f"🔻 Hamming Loss: {detailed_metrics['hamming_loss']:.4f}")
# print(f"🎯 Macro F1: {detailed_metrics['f1_macro']:.4f}")

# if detailed_metrics['f1_micro'] > 0.7:
#     print("🎉 Excellent performance! Model is ready for deployment.")
# elif detailed_metrics['f1_micro'] > 0.5:
#     print("👍 Good performance! Consider fine-tuning for better results.")
# else:
#     print("⚠️ Performance needs improvement. Consider:")
#     print("   - More training epochs")
#     print("   - Different learning rate")
#     print("   - Data augmentation")
#     print("   - Different model architecture")

# print(f"\n💾 Model saved to: {training_args.output_dir}")
# print("🚀 Training and evaluation completed successfully!")# Enhanced Training Configuration for Multi-label Classification


In [None]:
# # Post-Training Evaluation and Testing

# print("🔍 COMPREHENSIVE MODEL EVALUATION")
# print("=" * 60)

# # Evaluate on validation set
# print("\n📊 Validation Set Evaluation:")
# val_results = trainer.evaluate()

# # Display key metrics
# key_metrics = [
#     'eval_f1_micro', 'eval_f1_macro', 'eval_accuracy', 'eval_hamming_loss',
#     'eval_precision_micro', 'eval_recall_micro', 'eval_roc_auc_macro'
# ]

# for metric in key_metrics:
#     if metric in val_results:
#         print(f"   {metric}: {val_results[metric]:.4f}")

# # Test on test set if available
# if "test" in tokenized_dataset:
#     print("\n🎯 Test Set Evaluation:")
#     test_results = trainer.evaluate(eval_dataset=tokenized_dataset["test"])
    
#     for metric in key_metrics:
#         if metric in test_results:
#             print(f"   {metric}: {test_results[metric]:.4f}")

# # Get predictions for detailed analysis
# print("\n🔬 Detailed Prediction Analysis:")

# # Predict on validation set
# val_predictions = trainer.predict(tokenized_dataset["val"])
# val_probs = sigmoid(val_predictions.predictions)
# val_binary = (val_probs > 0.5).astype(int)
# val_true = val_predictions.label_ids

# # Use comprehensive evaluation function
# detailed_metrics = comprehensive_evaluation(
#     y_true=val_true,
#     y_pred_proba=val_probs,
#     y_pred_binary=val_binary
# )

# print("\n📈 Comprehensive Metrics Summary:")
# print("-" * 50)

# # Group metrics by type
# metric_groups = {
#     'Precision': ['precision_micro', 'precision_macro', 'precision_samples', 'precision_weighted'],
#     'Recall': ['recall_micro', 'recall_macro', 'recall_samples', 'recall_weighted'],
#     'F1-Score': ['f1_micro', 'f1_macro', 'f1_samples', 'f1_weighted'],
#     'ROC-AUC': ['roc_auc_macro', 'roc_auc_weighted', 'roc_auc_samples'],
#     'PR-AUC': ['pr_auc_macro', 'pr_auc_weighted', 'pr_auc_samples'],
#     'Other': ['accuracy', 'hamming_loss', 'jaccard_macro', 'jaccard_samples']
# }

# for group_name, metrics in metric_groups.items():
#     print(f"\n{group_name}:")
#     for metric in metrics:
#         if metric in detailed_metrics:
#             print(f"   {metric}: {detailed_metrics[metric]:.4f}")

# # Sample predictions analysis
# print("\n🔍 Sample Predictions Analysis:")
# sample_size = min(5, len(val_true))
# for i in range(sample_size):
#     print(f"\nSample {i+1}:")
#     print(f"   True labels: {val_true[i]}")
#     print(f"   Predicted:   {val_binary[i]}")
#     print(f"   Probabilities: {val_probs[i]}")
#     print(f"   Match: {'✅' if np.array_equal(val_true[i], val_binary[i]) else '❌'}")

# # Model performance summary
# print(f"\n{'='*60}")
# print("🏆 MODEL PERFORMANCE SUMMARY")
# print(f"{'='*60}")
# print(f"✅ Best Metric (F1-Micro): {detailed_metrics['f1_micro']:.4f}")
# print(f"📊 Accuracy: {detailed_metrics['accuracy']:.4f}")
# print(f"🔻 Hamming Loss: {detailed_metrics['hamming_loss']:.4f}")
# print(f"🎯 Macro F1: {detailed_metrics['f1_macro']:.4f}")

# if detailed_metrics['f1_micro'] > 0.7:
#     print("🎉 Excellent performance! Model is ready for deployment.")
# elif detailed_metrics['f1_micro'] > 0.5:
#     print("👍 Good performance! Consider fine-tuning for better results.")
# else:
#     print("⚠️ Performance needs improvement. Consider:")
#     print("   - More training epochs")
#     print("   - Different learning rate")
#     print("   - Data augmentation")
#     print("   - Different model architecture")

# print(f"\n💾 Model saved to: {training_args.output_dir}")
# print("🚀 Training and evaluation completed successfully!")# Enhanced Training Configuration for Multi-label Classification
# training_args = TrainingArguments(
#     # Output and logging
#     output_dir="./model_output",
#     logging_dir="./logs",
#     logging_steps=50,
#     logging_strategy="steps",
    
#     # Learning parameters
#     learning_rate=2e-5,
#     lr_scheduler_type="linear",  # Linear decay
#     warmup_ratio=0.1,  # 10% warmup
#     weight_decay=0.01,
    
#     # Batch sizes (adjust based on GPU memory)
#     per_device_train_batch_size=3,
#     per_device_eval_batch_size=3,
#     gradient_accumulation_steps=4,  # Effective batch size = 3 * 4 = 12
    
#     # Training epochs and evaluation
#     num_train_epochs=3,  # Increased for better convergence
#     eval_strategy="steps",  # More frequent evaluation
#     eval_steps=100,  # Evaluate every 100 steps
    
#     # Saving strategy
#     save_strategy="steps",
#     save_steps=100,
#     save_total_limit=3,  # Keep only 3 best checkpoints
#     load_best_model_at_end=True,
#     metric_for_best_model="eval_f1_micro",  # Use micro F1 for model selection
#     greater_is_better=True,
    
#     # Early stopping and overfitting control
#     early_stopping_patience=3,  # Stop if no improvement for 3 evaluations
    
#     # Memory and performance optimization
#     dataloader_pin_memory=True,
#     dataloader_num_workers=2,
#     remove_unused_columns=False,  # Keep all columns for multi-label
    
#     # Mixed precision for faster training (if GPU supports it)
#     fp16=True,  # Enable if using compatible GPU
    
#     # Reproducibility
#     seed=42,
#     data_seed=42,
    
#     # Report metrics
#     report_to=None,  # Disable wandb/tensorboard if not needed
#     run_name="multi_label_posture_classification",
# )

# # Initialize trainer with enhanced configuration
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=tokenized_dataset["train"],
#     eval_dataset=tokenized_dataset["val"],
#     tokenizer=tokenizer,
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
# )

# print("🚀 Starting training with enhanced configuration...")
# print(f"📊 Training samples: {len(tokenized_dataset['train'])}")
# print(f"📊 Validation samples: {len(tokenized_dataset['val'])}")
# print(f"🎯 Target metric: {training_args.metric_for_best_model}")
# print(f"⏱️ Total epochs: {training_args.num_train_epochs}")
# print(f"🔄 Evaluation every: {training_args.eval_steps} steps")
# print(f"💾 Saving every: {training_args.save_steps} steps")

# # Start training
# trainer.train()

In [None]:
# # 🔧 EXPLICIT LABEL TYPE CONVERSION
# # Convert labels to float32 using HuggingFace datasets features

# from datasets import Sequence, Value
# import torch

# print("🔄 Converting labels to float32 using datasets.cast_column...")

# # Define the proper feature type for multi-label classification
# # Labels should be a sequence of floats (one per class)
# label_feature = Sequence(Value("float32"), length=len(class_name))

# # Cast the labels column to float32 for all splits
# for split_name in tokenized_dataset.keys():
#     tokenized_dataset[split_name] = tokenized_dataset[split_name].cast_column("labels", label_feature)

# # Verify the fix
# print(f"\n✅ Labels conversion verification:")
# sample = tokenized_dataset["train"][0]
# sample_labels = np.array(sample['labels'])
# print(f"  Labels dtype: {sample_labels.dtype}")
# print(f"  Labels shape: {sample_labels.shape}")
# print(f"  Sample labels: {sample['labels'][:5]}...")  # Show first 5 labels
# print(f"  HF Feature type: {tokenized_dataset['train'].features['labels']}")

# # Test tensor conversion
# test_labels = torch.tensor(sample['labels'], dtype=torch.float32)
# print(f"  PyTorch tensor dtype: {test_labels.dtype}")
# print(f"  PyTorch tensor shape: {test_labels.shape}")

# if sample_labels.dtype == np.float32:
#     print("✅ SUCCESS: Labels are now properly formatted as float32")
#     print("🚀 Ready for training!")
# else:
#     print(f"❌ ISSUE: Labels are still {sample_labels.dtype}, expected float32")

In [None]:
# # 🎉 FINAL MODEL EVALUATION
# # Comprehensive evaluation of the trained multi-label classification model

# import torch
# from sklearn.metrics import classification_report
# import numpy as np

# print("🔬 FINAL MODEL EVALUATION")
# print("=" * 60)

# # First, let's check the test set data types and fix if needed
# print("🔍 Checking test set data types...")
# test_sample = tokenized_dataset["test"][0]
# test_labels = np.array(test_sample['labels'])
# print(f"Test labels dtype: {test_labels.dtype}")

# if test_labels.dtype != np.float32:
#     print("⚠️ Test set labels need conversion, performing conversion...")
#     # Re-apply the label conversion to test set
#     from datasets import Sequence, Value
#     label_feature = Sequence(Value("float32"), length=len(class_name))
#     tokenized_dataset["test"] = tokenized_dataset["test"].cast_column("labels", label_feature)
#     print("✅ Test set labels converted to float32")

# # Use predict method instead of evaluate to avoid evaluation issues
# print("📊 Generating predictions on test set...")
# predictions = trainer.predict(tokenized_dataset["test"])

# # Convert predictions to probabilities and binary predictions
# y_pred_proba = torch.sigmoid(torch.tensor(predictions.predictions)).numpy()
# y_pred = (y_pred_proba > 0.5).astype(int)
# y_true = predictions.label_ids.astype(int)

# print(f"Prediction shape: {y_pred.shape}")
# print(f"True labels shape: {y_true.shape}")

# # Calculate comprehensive metrics manually using our evaluation function
# # Note: Fix the function call order - comprehensive_evaluation(y_true, y_pred_proba, y_pred_binary)
# print("📊 Calculating comprehensive metrics...")
# detailed_metrics = comprehensive_evaluation(y_true, y_pred_proba, y_pred_binary=y_pred)

# print(f"\n🏆 TEST SET RESULTS:")
# print(f"{'='*50}")

# # Print all the comprehensive metrics
# metric_groups = {
#     "📈 Primary Metrics": ["f1_micro", "f1_macro", "f1_weighted", "f1_samples"],
#     "🎯 Precision": ["precision_micro", "precision_macro", "precision_weighted", "precision_samples"],
#     "🔍 Recall": ["recall_micro", "recall_macro", "recall_weighted", "recall_samples"],
#     "📊 Other Metrics": ["accuracy", "hamming_loss", "jaccard_samples", "jaccard_macro", "jaccard_weighted"],
#     "📡 AUC Metrics": ["roc_auc_macro", "roc_auc_weighted", "roc_auc_samples", 
#                        "pr_auc_macro", "pr_auc_weighted", "pr_auc_samples"]
# }

# for group_name, metrics in metric_groups.items():
#     print(f"\n{group_name}:")
#     for metric in metrics:
#         if metric in detailed_metrics:
#             print(f"  {metric.upper()}: {detailed_metrics[metric]:.4f}")

# # Per-class performance
# print(f"\n📋 PER-CLASS PERFORMANCE:")
# print(f"{'='*50}")
# class_report = classification_report(
#     y_true, y_pred, 
#     target_names=class_name, 
#     output_dict=True,
#     zero_division=0
# )

# # Show performance for each class
# for i, class_label in enumerate(class_name):
#     if class_label in class_report:
#         metrics = class_report[class_label]
#         support = int(metrics['support'])
#         print(f"{class_label:30s} | P: {metrics['precision']:.3f} | R: {metrics['recall']:.3f} | F1: {metrics['f1-score']:.3f} | Support: {support:4d}")

# # Overall summary
# print(f"\n🎯 OVERALL PERFORMANCE SUMMARY:")
# print(f"{'='*50}")
# macro_avg = class_report['macro avg']
# weighted_avg = class_report['weighted avg']

# print(f"🔹 Macro Average    | P: {macro_avg['precision']:.3f} | R: {macro_avg['recall']:.3f} | F1: {macro_avg['f1-score']:.3f}")
# print(f"🔹 Weighted Average | P: {weighted_avg['precision']:.3f} | R: {weighted_avg['recall']:.3f} | F1: {weighted_avg['f1-score']:.3f}")

# # Performance assessment
# f1_micro = detailed_metrics.get('f1_micro', 0)
# print(f"\n🏆 FINAL ASSESSMENT:")
# print(f"{'='*50}")
# if f1_micro > 0.8:
#     assessment = "🌟 EXCELLENT! Model shows outstanding performance."
# elif f1_micro > 0.7:
#     assessment = "✅ VERY GOOD! Model performance is strong and ready for deployment."
# elif f1_micro > 0.6:
#     assessment = "👍 GOOD! Model shows solid performance with room for improvement."
# elif f1_micro > 0.5:
#     assessment = "⚠️ MODERATE! Consider additional training or data improvements."
# else:
#     assessment = "❌ NEEDS IMPROVEMENT! Significant enhancements required."

# print(f"Micro F1 Score: {f1_micro:.4f}")
# print(f"Assessment: {assessment}")

# print(f"\n💾 Model and results saved to: {training_args.output_dir}")
# print(f"🎉 Multi-label legal posture classification training completed successfully!")

# # Save the best model explicitly
# print(f"\n💾 Saving final model...")
# trainer.save_model(f"{training_args.output_dir}/final_model")
# tokenizer.save_pretrained(f"{training_args.output_dir}/final_model")
# print(f"✅ Final model saved to: {training_args.output_dir}/final_model")

<span style="font-weight: bold; font-size: 18px;">
Since the text length in the corpus is quite long, even though the language model can handle a maximum input of 8,092 tokens or even longer context, such lengthy contexts can dilute the essential information that the LLM is able to process effectively. When the input text is too long, the model may struggle to focus on the most relevant details, which can negatively impact its performance and the quality of its predictions.

To address this challenge, it is beneficial to provide the language model with only the most relevant information from the large corpus, rather than overwhelming it with the entire text. By filtering and condensing the input, we can help the LLM focus on the critical content, thereby enhancing its ability to extract essential information and make more accurate predictions.

There are two main approaches to achieve this:

<div style="margin-left: 20px;"><b>• Summarization:</b> Summarization techniques can be used as a form of feature extraction. By generating concise summaries of the original text, we can distill the most important points and reduce the input length, making it easier for the LLM to process and understand the core information.</div> <div style="margin-left: 20px;"><b>• Retrieval-Augmented Generation (RAG):</b> This approach involves retrieving relevant information from the corpus before passing it to the language model. Techniques such as semantic search or keyword-based search (e.g., BM25) can be used to identify and extract the most pertinent sections of text. The retrieved content is then fed into the LLM, ensuring that the model receives focused and contextually relevant information for prediction.</div>
By applying either summarization or retrieval-augmented generation, we can significantly improve the efficiency and effectiveness of language models when dealing with large and complex corpora. This targeted approach helps prevent information overload and allows the model to generate more accurate and meaningful outputs.
</span>