In [None]:
"""
SuperEmotion Dataset Tokenization Script
For DeBERTa v3 Base - Single-label Multi-class Classification
7 Emotions: fear, joy, sadness, anger, love, neutral, surprise
"""

import os
import pandas as pd
import numpy as np
import torch
from transformers import DebertaV2Tokenizer
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from tqdm.auto import tqdm
import warnings
import json
warnings.filterwarnings('ignore')

# ==================== CONFIGURATION ====================
DATASET_PATH = '/content/drive/MyDrive/SuperEmotion/superemotion.csv'
SAVE_DIR = '/content/drive/MyDrive/SuperEmotion/'
MODEL_NAME = 'microsoft/deberta-v3-base'
MAX_LENGTH = 128  # every sample is padded/truncated to exactly this many tokens

# 7 Emotion Classes (alphabetically ordered for consistency).
# The position of each name in this list is used as its integer label id.
EMOTION_CLASSES = ['anger', 'fear', 'joy', 'love', 'neutral', 'sadness', 'surprise']

print("="*80)
print(" "*20 + "🚀 SuperEmotion Dataset Tokenization 🚀")
print(" "*22 + "DeBERTa v3 Base - 7 Emotions")
print("="*80)
print(f"\n📁 Dataset: {DATASET_PATH}")
print(f"💾 Save Directory: {SAVE_DIR}")
print(f"🤖 Model: {MODEL_NAME}")
print(f"📏 Max Length: {MAX_LENGTH}")
print(f"🎭 Emotions: {len(EMOTION_CLASSES)}")
print(f"   {', '.join(EMOTION_CLASSES)}")

# ==================== MOUNT GOOGLE DRIVE ====================
# Best effort only: outside Colab there is nothing to mount, and a failed mount
# is reported rather than raised (the dataset existence check in the next step
# fails loudly if the data is really unreachable). Catching ImportError instead
# of a bare `except:` keeps KeyboardInterrupt/SystemExit working.
try:
    from google.colab import drive
except ImportError:
    print("\n⚠️  Not running in Colab - skipping Drive mount")
else:
    try:
        if not os.path.exists('/content/drive'):
            print("\n🔗 Mounting Google Drive...")
            drive.mount('/content/drive')
            print("✅ Google Drive mounted successfully!")
        else:
            print("\n✅ Google Drive already mounted!")
    except Exception as e:
        print(f"\n⚠️  Could not mount Google Drive: {e}")

# ==================== LOAD DATASET ====================
print("\n" + "="*80)
print("STEP 1: LOADING DATASET")
print("="*80)

# Fatal errors below use `raise SystemExit(1)` instead of `exit(1)`: inside a
# Jupyter/Colab kernel `exit` is IPython's exit helper, which returns normally,
# so the cell would keep running after a "fatal" error. SystemExit stops
# execution in both notebooks and plain scripts, and `except Exception` below
# does not catch it.
try:
    # Check if file exists
    if not os.path.exists(DATASET_PATH):
        print(f"❌ ERROR: Dataset not found at {DATASET_PATH}")
        print("\n🔍 Searching for superemotion.csv...")
        import glob
        matches = glob.glob('/content/drive/MyDrive/**/superemotion.csv', recursive=True)
        if matches:
            print("Found at:")
            for match in matches:
                print(f"  ✅ {match}")
            print("\nPlease update DATASET_PATH with correct path.")
        else:
            print("❌ File not found in Drive")
        raise SystemExit(1)

    # Load CSV
    print(f"📂 Loading CSV from: {DATASET_PATH}")
    df = pd.read_csv(DATASET_PATH, delimiter=',')
    print(f"✅ Dataset loaded successfully!")
    print(f"   Shape: {df.shape}")
    print(f"   Columns: {df.columns.tolist()}")

    # Verify columns
    if 'text' not in df.columns or 'emotions' not in df.columns:
        print(f"❌ ERROR: Expected columns 'text' and 'emotions'")
        print(f"   Found: {df.columns.tolist()}")
        raise SystemExit(1)

    print(f"\n✅ Columns verified: text, emotions")

    # Check for missing values
    print(f"\n🔍 Checking data quality...")
    text_nulls = df['text'].isnull().sum()
    emotion_nulls = df['emotions'].isnull().sum()

    if text_nulls > 0:
        print(f"⚠️  {text_nulls} null texts found - removing")
    if emotion_nulls > 0:
        print(f"⚠️  {emotion_nulls} null emotions found - removing")
    # .copy() so the column assignments below act on an independent frame
    # rather than on a filtered view of the original.
    df = df.dropna(subset=['text', 'emotions']).copy()

    # Clean text and emotion columns
    df['text'] = df['text'].astype(str).str.strip()
    df['emotions'] = df['emotions'].astype(str).str.strip().str.lower()

    # Remove empty texts
    df = df[df['text'].str.len() > 0].copy()

    if df.empty:
        print("❌ ERROR: No usable samples left after cleaning")
        raise SystemExit(1)

    print(f"✅ Dataset after cleaning: {df.shape[0]:,} samples")

    # Show emotion distribution BEFORE filtering
    print(f"\n📊 Emotion Distribution (Original):")
    emotion_counts = df['emotions'].value_counts()
    for emotion, count in emotion_counts.items():
        percentage = (count / len(df)) * 100
        print(f"   {emotion:12s}: {count:6,} ({percentage:5.2f}%)")

    # Filter only valid emotions (in case there are typos)
    valid_emotions_mask = df['emotions'].isin(EMOTION_CLASSES)
    if not valid_emotions_mask.all():
        invalid_count = (~valid_emotions_mask).sum()
        print(f"\n⚠️  Warning: Found {invalid_count} samples with invalid emotions")
        invalid_emotions = df[~valid_emotions_mask]['emotions'].unique()
        print(f"   Invalid emotions: {invalid_emotions}")
        print(f"   Removing invalid samples...")
        df = df[valid_emotions_mask].copy()

    if df.empty:
        print("❌ ERROR: No samples with a known emotion label")
        raise SystemExit(1)

    print(f"\n✅ Final dataset: {df.shape[0]:,} samples")

    # Show final distribution. Reindex to EMOTION_CLASSES so a class with zero
    # samples is shown (and detected) instead of silently disappearing.
    print(f"\n📊 Final Emotion Distribution:")
    emotion_counts = df['emotions'].value_counts().reindex(EMOTION_CLASSES, fill_value=0)
    total_samples = len(df)

    for emotion in EMOTION_CLASSES:
        count = emotion_counts[emotion]
        percentage = (count / total_samples) * 100
        bar = '█' * int(percentage / 2)
        print(f"   {emotion:12s}: {count:6,} ({percentage:5.2f}%) {bar}")

    missing_classes = [emotion for emotion in EMOTION_CLASSES if emotion_counts[emotion] == 0]
    if missing_classes:
        print(f"\n⚠️  WARNING: No samples for: {', '.join(missing_classes)}")

    # Check for severe imbalance among the classes that are present at all
    present_counts = emotion_counts[emotion_counts > 0]
    imbalance_ratio = present_counts.max() / present_counts.min()
    print(f"\n📈 Imbalance Ratio: {imbalance_ratio:.2f}x (max/min)")

    if imbalance_ratio > 5:
        print(f"⚠️  WARNING: Severe class imbalance detected!")
    elif imbalance_ratio > 2:
        print(f"⚠️  Moderate class imbalance - will use class weights")
    else:
        print(f"✅ Dataset is reasonably balanced")

except FileNotFoundError:
    print(f"❌ ERROR: File not found at {DATASET_PATH}")
    raise SystemExit(1)
except Exception as e:
    print(f"❌ ERROR loading dataset: {e}")
    import traceback
    traceback.print_exc()
    raise SystemExit(1)

# ==================== ENCODE LABELS ====================
print("\n" + "="*80)
print("STEP 2: ENCODING LABELS")
print("="*80)

try:
    # Label id = position of the emotion in EMOTION_CLASSES
    label_to_id = {label: idx for idx, label in enumerate(EMOTION_CLASSES)}
    id_to_label = {idx: label for label, idx in label_to_id.items()}

    print(f"✅ Label Encoding:")
    for label, idx in label_to_id.items():
        count = (df['emotions'] == label).sum()
        print(f"   {label:12s} → {idx} ({count:,} samples)")

    # Encode labels; emotions not in the mapping become NaN
    df['label'] = df['emotions'].map(label_to_id)

    # Verify no missing labels
    if df['label'].isnull().any():
        missing = df[df['label'].isnull()]['emotions'].unique()
        print(f"❌ ERROR: Could not encode emotions: {missing}")
        # SystemExit rather than exit(): IPython's exit() does not stop a cell
        raise SystemExit(1)

    print(f"\n✅ All labels encoded successfully!")

except Exception as e:
    print(f"❌ ERROR encoding labels: {e}")
    raise SystemExit(1)

# ==================== CALCULATE CLASS WEIGHTS ====================
print("\n" + "="*80)
print("STEP 3: CALCULATING CLASS WEIGHTS")
print("="*80)

try:
    # 'balanced' = inverse-frequency weights, one per label id 0..N-1
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.arange(len(EMOTION_CLASSES)),
        y=df['label'].values
    )

    print(f"✅ Class Weights Calculated:")
    for emotion, weight in zip(EMOTION_CLASSES, class_weights):
        print(f"   {emotion:12s}: {weight:.4f}")

    # Plain floats keyed by emotion name, written to metadata.json later
    weights_dict = {emotion: float(weight) for emotion, weight in zip(EMOTION_CLASSES, class_weights)}

except Exception as e:
    print(f"❌ ERROR calculating class weights: {e}")
    # Fall back to uniform weights so the pipeline can continue
    class_weights = np.ones(len(EMOTION_CLASSES))
    weights_dict = {emotion: 1.0 for emotion in EMOTION_CLASSES}

# ==================== STRATIFIED SPLIT ====================
print("\n" + "="*80)
print("STEP 4: STRATIFIED TRAIN/VAL/TEST SPLIT")
print("="*80)

try:
    # Extract texts and labels
    texts = df['text'].values
    labels = df['label'].values

    total_samples = len(texts)
    train_size = int(0.8 * total_samples)
    val_size = int(0.1 * total_samples)
    test_size = total_samples - train_size - val_size

    print(f"📊 Target Split (80/10/10):")
    print(f"   Total: {total_samples:,}")
    print(f"   Train: {train_size:,} (80%)")
    print(f"   Val:   {val_size:,} (10%)")
    print(f"   Test:  {test_size:,} (10%)")

    print(f"\n🔀 Performing stratified split with shuffle...")

    # First split: train (80%) vs held-out (20%); stratify keeps class ratios
    train_texts, temp_texts, train_labels, temp_labels = train_test_split(
        texts, labels,
        test_size=0.2,
        random_state=42,
        stratify=labels,
        shuffle=True
    )

    # Second split: the held-out 20% is halved into val (10%) and test (10%)
    val_texts, test_texts, val_labels, test_labels = train_test_split(
        temp_texts, temp_labels,
        test_size=0.5,
        random_state=42,
        stratify=temp_labels,
        shuffle=True
    )

    print(f"\n✅ Split complete!")
    print(f"   Train: {len(train_texts):,} samples")
    print(f"   Val:   {len(val_texts):,} samples")
    print(f"   Test:  {len(test_texts):,} samples")

    # Verify stratification: track the largest gap (percentage points)
    # between the train distribution and the val/test distributions.
    print(f"\n📊 Verifying Stratification:")
    print(f"\n{'Emotion':<12} {'Train %':<10} {'Val %':<10} {'Test %':<10}")
    print("-" * 45)

    max_gap = 0.0
    for emotion_idx, emotion in enumerate(EMOTION_CLASSES):
        train_pct = (train_labels == emotion_idx).sum() / len(train_labels) * 100
        val_pct = (val_labels == emotion_idx).sum() / len(val_labels) * 100
        test_pct = (test_labels == emotion_idx).sum() / len(test_labels) * 100
        print(f"{emotion:<12} {train_pct:>8.2f}%  {val_pct:>8.2f}%  {test_pct:>8.2f}%")
        max_gap = max(max_gap, abs(train_pct - val_pct), abs(train_pct - test_pct))

    if max_gap <= 1.0:
        print(f"\n✅ Stratification verified - distributions match!")
    else:
        print(f"\n⚠️  Split distributions differ by up to {max_gap:.2f} percentage points")

except Exception as e:
    print(f"❌ ERROR during split: {e}")
    import traceback
    traceback.print_exc()
    # SystemExit rather than exit(): IPython's exit() does not stop a cell
    raise SystemExit(1)

# ==================== INITIALIZE TOKENIZER ====================
print("\n" + "="*80)
print("STEP 5: INITIALIZING TOKENIZER")
print("="*80)

try:
    print(f"🤖 Loading {MODEL_NAME} tokenizer...")
    tokenizer = DebertaV2Tokenizer.from_pretrained(MODEL_NAME)
    print(f"✅ Tokenizer loaded successfully!")
    print(f"   Vocabulary size: {tokenizer.vocab_size:,}")
    print(f"   Special tokens: {tokenizer.all_special_tokens}")
    print(f"   Model max length: {tokenizer.model_max_length:,}")
    print(f"   Using max length: {MAX_LENGTH}")

except Exception as e:
    print(f"❌ ERROR loading tokenizer: {e}")
    print("   Make sure transformers library is installed:")
    print("   !pip install transformers")
    raise SystemExit(1)

# ==================== TOKENIZATION FUNCTION ====================
def tokenize_dataset(texts, labels, split_name, batch_size=1000):
    """Tokenize one dataset split into fixed-length PyTorch tensors.

    Uses the module-level ``tokenizer`` and ``MAX_LENGTH``. Texts are sent to
    the tokenizer ``batch_size`` at a time rather than one call per sample.
    If a whole batch fails, its samples are retried one by one. A sample that
    still fails becomes an all-padding row with a zero attention mask but
    KEEPS ITS REAL LABEL, so failures never inject fake samples of class 0.

    Args:
        texts: Sequence of raw texts; each item is converted with ``str()``.
        labels: Sequence of integer class ids aligned with ``texts``.
        split_name: Name used in progress/log messages.
        batch_size: Number of texts per tokenizer call (values < 1 become 1).

    Returns:
        dict with ``'input_ids'`` and ``'attention_mask'`` (int64 tensors of
        shape ``[N, MAX_LENGTH]``) and ``'labels'`` (int64 tensor of shape ``[N]``).
    """
    print(f"\n{'='*80}")
    print(f"TOKENIZING {split_name.upper()} SET")
    print(f"{'='*80}")
    print(f"📝 Processing {len(texts):,} samples...")

    batch_size = max(1, int(batch_size))
    encode_kwargs = dict(
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    id_chunks = []
    mask_chunks = []
    failed_samples = 0
    num_texts = len(texts)

    for start in tqdm(range(0, num_texts, batch_size), desc=f"Tokenizing {split_name}", unit="batch"):
        batch_texts = [str(text) for text in texts[start:start + batch_size]]
        try:
            encoding = tokenizer(batch_texts, **encode_kwargs)
        except Exception as e:
            print(f"\n⚠️  Batch starting at index {start} failed ({e}); retrying sample by sample")
        else:
            id_chunks.append(encoding['input_ids'])
            mask_chunks.append(encoding['attention_mask'])
            continue

        # Slow path: isolate the failing sample(s) within this batch
        for offset, text in enumerate(batch_texts):
            try:
                encoding = tokenizer(text, **encode_kwargs)
                id_chunks.append(encoding['input_ids'])
                mask_chunks.append(encoding['attention_mask'])
            except Exception as e:
                failed_samples += 1
                if failed_samples <= 5:
                    print(f"\n⚠️  Error at index {start + offset}: {e}")
                # Placeholder row keeps N aligned with labels; label stays real
                id_chunks.append(torch.full((1, MAX_LENGTH), pad_id, dtype=torch.long))
                mask_chunks.append(torch.zeros((1, MAX_LENGTH), dtype=torch.long))

    if failed_samples > 0:
        print(f"\n⚠️  {failed_samples} samples failed (stored as padding rows, original labels kept)")

    # Stack tensors
    print(f"\n📦 Stacking tensors...")
    if id_chunks:
        input_ids = torch.cat(id_chunks, dim=0)
        attention_masks = torch.cat(mask_chunks, dim=0)
    else:
        input_ids = torch.zeros((0, MAX_LENGTH), dtype=torch.long)
        attention_masks = torch.zeros((0, MAX_LENGTH), dtype=torch.long)
    labels = torch.tensor(np.asarray(labels, dtype=np.int64)).reshape(-1)

    print(f"✅ Tokenization complete!")
    print(f"   Input IDs shape: {input_ids.shape}")
    print(f"   Attention masks shape: {attention_masks.shape}")
    print(f"   Labels shape: {labels.shape}")

    # Memory usage
    memory_mb = (input_ids.element_size() * input_ids.numel() +
                 attention_masks.element_size() * attention_masks.numel() +
                 labels.element_size() * labels.numel()) / (1024**2)
    print(f"   Memory: ~{memory_mb:.2f} MB")

    return {
        'input_ids': input_ids,
        'attention_mask': attention_masks,
        'labels': labels
    }

# ==================== TOKENIZE ALL SPLITS ====================
print("\n" + "="*80)
print("STEP 6: TOKENIZING ALL SPLITS")
print("="*80)

try:
    train_data = tokenize_dataset(train_texts, train_labels, 'train')
    val_data = tokenize_dataset(val_texts, val_labels, 'validation')
    test_data = tokenize_dataset(test_texts, test_labels, 'test')

    print("\n✅ All splits tokenized successfully!")

except Exception as e:
    print(f"\n❌ ERROR during tokenization: {e}")
    import traceback
    traceback.print_exc()
    # SystemExit rather than exit(): IPython's exit() does not stop a cell
    raise SystemExit(1)

# ==================== SAVE TOKENIZED DATA ====================
print("\n" + "="*80)
print("STEP 7: SAVING TOKENIZED DATA")
print("="*80)

try:
    os.makedirs(SAVE_DIR, exist_ok=True)
    print(f"📁 Save directory: {SAVE_DIR}")

    # Define paths (also used by the verification and summary steps)
    train_path = os.path.join(SAVE_DIR, 'tokenized_train.pt')
    val_path = os.path.join(SAVE_DIR, 'tokenized_val.pt')
    test_path = os.path.join(SAVE_DIR, 'tokenized_test.pt')
    metadata_path = os.path.join(SAVE_DIR, 'metadata.json')

    for display_name, split_data, split_path in (
        ('train', train_data, train_path),
        ('validation', val_data, val_path),
        ('test', test_data, test_path),
    ):
        print(f"\n💾 Saving {display_name} data...")
        torch.save(split_data, split_path)
        print(f"✅ {split_path}")
        print(f"   Size: {os.path.getsize(split_path) / (1024**2):.2f} MB")

    # Save metadata.
    # NOTE: JSON object keys are always strings, so id_to_label is stored with
    # keys "0", "1", ...; readers must convert them back to int if needed.
    print(f"\n💾 Saving metadata...")
    metadata = {
        'emotion_classes': EMOTION_CLASSES,
        'label_to_id': label_to_id,
        'id_to_label': id_to_label,
        'class_weights': weights_dict,
        'num_classes': len(EMOTION_CLASSES),
        'max_length': MAX_LENGTH,
        'model_name': MODEL_NAME,
        'train_samples': len(train_texts),
        'val_samples': len(val_texts),
        'test_samples': len(test_texts),
        'total_samples': total_samples
    }

    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"✅ {metadata_path}")

    total_size = sum(os.path.getsize(path) for path in (train_path, val_path, test_path)) / (1024**2)
    print(f"\n📊 Total storage: {total_size:.2f} MB")

except Exception as e:
    print(f"\n❌ ERROR saving data: {e}")
    import traceback
    traceback.print_exc()
    raise SystemExit(1)

# ==================== VERIFICATION ====================
print("\n" + "="*80)
print("STEP 8: VERIFICATION")
print("="*80)

try:
    print("🔍 Verifying saved files...")

    # weights_only=True: the files hold only tensors, so there is no reason to
    # unpickle arbitrary objects (and the result no longer depends on the
    # torch version's default for this flag).
    loaded_train = torch.load(train_path, weights_only=True)
    loaded_val = torch.load(val_path, weights_only=True)
    loaded_test = torch.load(test_path, weights_only=True)

    loaded_splits = [('Train', loaded_train), ('Val', loaded_val), ('Test', loaded_test)]

    # Explicit checks on every split (plain `assert` is stripped under -O and
    # previously only covered the train split).
    for split_name, data in loaded_splits:
        split_ids = data['input_ids']
        split_mask = data['attention_mask']
        split_labels = data['labels']
        if split_ids.dim() != 2 or split_ids.shape[1] != MAX_LENGTH:
            raise ValueError(f"{split_name}: input_ids shape {tuple(split_ids.shape)}, expected [N, {MAX_LENGTH}]")
        if split_mask.shape != split_ids.shape:
            raise ValueError(f"{split_name}: attention_mask shape {tuple(split_mask.shape)} != input_ids shape")
        if split_labels.dim() != 1 or split_labels.shape[0] != split_ids.shape[0]:
            raise ValueError(f"{split_name}: labels shape {tuple(split_labels.shape)}, expected [{split_ids.shape[0]}]")
        if split_labels.numel() and (split_labels.min() < 0 or split_labels.max() >= len(EMOTION_CLASSES)):
            raise ValueError(f"{split_name}: labels outside [0, {len(EMOTION_CLASSES) - 1}]")

    print(f"✅ Train: {loaded_train['input_ids'].shape[0]:,} samples")
    print(f"✅ Val: {loaded_val['input_ids'].shape[0]:,} samples")
    print(f"✅ Test: {loaded_test['input_ids'].shape[0]:,} samples")

    # Verify label distribution
    print(f"\n📊 Label Distribution Verification:")
    for split_name, data in loaded_splits:
        label_counts = torch.bincount(data['labels'], minlength=len(EMOTION_CLASSES))
        print(f"\n{split_name}:")
        for idx, count in enumerate(label_counts):
            print(f"  {EMOTION_CLASSES[idx]:12s}: {count:,}")

    print("\n✅ All verifications passed!")

except Exception as e:
    print(f"\n❌ ERROR during verification: {e}")
    import traceback
    traceback.print_exc()
    # SystemExit rather than exit(): IPython's exit() does not stop a cell
    raise SystemExit(1)

# ==================== SUMMARY ====================
print("\n" + "="*80)
print("🎉 TOKENIZATION COMPLETE!")
print("="*80)
print(f"\n📋 Summary:")
print(f"   ✅ Total samples: {total_samples:,}")
# Actual split fractions (rounded), not the nominal 80/10/10 targets
print(f"   ✅ Train: {len(train_texts):,} ({len(train_texts) / total_samples:.0%})")
print(f"   ✅ Val: {len(val_texts):,} ({len(val_texts) / total_samples:.0%})")
print(f"   ✅ Test: {len(test_texts):,} ({len(test_texts) / total_samples:.0%})")
print(f"   ✅ Emotions: {len(EMOTION_CLASSES)}")
print(f"   ✅ Max length: {MAX_LENGTH}")
print(f"   ✅ Stratified: Yes")
print(f"   ✅ Shuffled: Yes")
print(f"   ✅ Class weights: Calculated")
print(f"\n📁 Saved Files:")
print(f"   📄 {train_path}")
print(f"   📄 {val_path}")
print(f"   📄 {test_path}")
print(f"   📄 {metadata_path}")
print(f"\n💡 Data Format:")
print(f"   - input_ids: [N, {MAX_LENGTH}]")
print(f"   - attention_mask: [N, {MAX_LENGTH}]")
print(f"   - labels: [N] (single-label integers)")
print(f"\n🚀 Ready for DeBERTa v3 training!")
print("="*80)

                    🚀 SuperEmotion Dataset Tokenization 🚀
                      DeBERTa v3 Base - 7 Emotions

📁 Dataset: /content/drive/MyDrive/SuperEmotion/superemotion.csv
💾 Save Directory: /content/drive/MyDrive/SuperEmotion/
🤖 Model: microsoft/deberta-v3-base
📏 Max Length: 128
🎭 Emotions: 7
   anger, fear, joy, love, neutral, sadness, surprise

🔗 Mounting Google Drive...
Mounted at /content/drive
✅ Google Drive mounted successfully!

STEP 1: LOADING DATASET
📂 Loading CSV from: /content/drive/MyDrive/SuperEmotion/superemotion.csv
✅ Dataset loaded successfully!
   Shape: (190716, 2)
   Columns: ['text', 'emotions']

✅ Columns verified: text, emotions

🔍 Checking data quality...
✅ Dataset after cleaning: 190,716 samples

📊 Emotion Distribution (Original):
   fear        : 30,000 (15.73%)
   joy         : 30,000 (15.73%)
   sadness     : 30,000 (15.73%)
   anger       : 30,000 (15.73%)
   love        : 30,000 (15.73%)
   neutral     : 24,443 (12.82%)
   surprise    : 16,273 ( 8.53%)

✅

tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/579 [00:00<?, ?B/s]

✅ Tokenizer loaded successfully!
   Vocabulary size: 128,000
   Special tokens: ['[CLS]', '[SEP]', '[UNK]', '[PAD]', '[MASK]']
   Model max length: 1,000,000,000,000,000,019,884,624,838,656
   Using max length: 128

STEP 6: TOKENIZING ALL SPLITS

TOKENIZING TRAIN SET
📝 Processing 152,572 samples...


Tokenizing train:   0%|          | 0/152572 [00:00<?, ?sample/s]


📦 Stacking tensors...
✅ Tokenization complete!
   Input IDs shape: torch.Size([152572, 128])
   Attention masks shape: torch.Size([152572, 128])
   Labels shape: torch.Size([152572])
   Memory: ~299.16 MB

TOKENIZING VALIDATION SET
📝 Processing 19,072 samples...


Tokenizing validation:   0%|          | 0/19072 [00:00<?, ?sample/s]


📦 Stacking tensors...
✅ Tokenization complete!
   Input IDs shape: torch.Size([19072, 128])
   Attention masks shape: torch.Size([19072, 128])
   Labels shape: torch.Size([19072])
   Memory: ~37.40 MB

TOKENIZING TEST SET
📝 Processing 19,072 samples...


Tokenizing test:   0%|          | 0/19072 [00:00<?, ?sample/s]


📦 Stacking tensors...
✅ Tokenization complete!
   Input IDs shape: torch.Size([19072, 128])
   Attention masks shape: torch.Size([19072, 128])
   Labels shape: torch.Size([19072])
   Memory: ~37.40 MB

✅ All splits tokenized successfully!

STEP 7: SAVING TOKENIZED DATA
📁 Save directory: /content/drive/MyDrive/SuperEmotion/

💾 Saving train data...
✅ /content/drive/MyDrive/SuperEmotion/tokenized_train.pt
   Size: 299.16 MB

💾 Saving validation data...
✅ /content/drive/MyDrive/SuperEmotion/tokenized_val.pt
   Size: 37.40 MB

💾 Saving test data...
✅ /content/drive/MyDrive/SuperEmotion/tokenized_test.pt
   Size: 37.40 MB

💾 Saving metadata...
✅ /content/drive/MyDrive/SuperEmotion/metadata.json

📊 Total storage: 373.95 MB

STEP 8: VERIFICATION
🔍 Verifying saved files...
✅ Train: 152,572 samples
✅ Val: 19,072 samples
✅ Test: 19,072 samples

📊 Label Distribution Verification:

Train:
  anger       : 24,000
  fear        : 24,000
  joy         : 24,000
  love        : 24,000
  neutral     : 19

In [None]:
"""
DeBERTa v3 Fine-tuning Script for Emotion Detection
SuperEmotion Dataset - 7 Emotions Single-label Classification
"""

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler
from transformers import DebertaV2Model, get_cosine_schedule_with_warmup
from sklearn.metrics import (f1_score, precision_score, recall_score,
                            accuracy_score, confusion_matrix, classification_report)
from tqdm.auto import tqdm
import numpy as np
import json
import time
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
warnings.filterwarnings('ignore')

# ==================== CONFIGURATION ====================
DATA_DIR = '/content/drive/MyDrive/SuperEmotion/'
SAVE_DIR_EPOCHS = '/content/drive/MyDrive/SimpleSaves/'
SAVE_DIR_BEST = '/content/drive/MyDrive/BestModelSave/'
MODEL_NAME = 'microsoft/deberta-v3-base'
MAX_LENGTH = 128
TRAIN_BATCH_SIZE = 64
VAL_BATCH_SIZE = 128
EPOCHS = 20
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1
GRADIENT_ACCUMULATION_STEPS = 2
GRADIENT_CLIP = 1.0
LOG_INTERVAL = 100
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Mixed precision in this script goes through torch.cuda.amp (autocast /
# GradScaler), which only has an effect on a CUDA device, so enable it only
# there; the printed configuration then reflects what actually happens.
USE_MIXED_PRECISION = DEVICE.type == 'cuda'

print("="*90)
print(" "*25 + "🚀 DeBERTa v3 Emotion Detection Training 🚀")
print(" "*30 + "SuperEmotion - 7 Classes")
print("="*90)
print(f"\n🖥️  Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    print(f"   CUDA: {torch.version.cuda}")

print(f"\n📊 Configuration:")
print(f"   Model: {MODEL_NAME}")
print(f"   Train Batch: {TRAIN_BATCH_SIZE} (x{GRADIENT_ACCUMULATION_STEPS} accum = {TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS} effective)")
print(f"   Val Batch: {VAL_BATCH_SIZE}")
print(f"   Epochs: {EPOCHS}")
print(f"   Learning Rate: {LEARNING_RATE}")
print(f"   Weight Decay: {WEIGHT_DECAY}")
print(f"   Max Length: {MAX_LENGTH}")
print(f"   Mixed Precision: {USE_MIXED_PRECISION}")
print(f"   Gradient Clipping: {GRADIENT_CLIP}")

# ==================== MOUNT DRIVE ====================
# Best effort: outside Colab the import fails and mounting is skipped; a failed
# mount is reported instead of raised. Catching ImportError rather than using a
# bare `except:` keeps KeyboardInterrupt/SystemExit working.
try:
    from google.colab import drive
except ImportError:
    print("\n⚠️  Not in Colab - skipping Drive mount")
else:
    try:
        if not os.path.exists('/content/drive'):
            print("\n🔗 Mounting Google Drive...")
            drive.mount('/content/drive')
            print("✅ Drive mounted!")
        else:
            print("\n✅ Drive already mounted!")
    except Exception as e:
        print(f"\n⚠️  Could not mount Drive: {e}")

# Setup directories
for directory in [SAVE_DIR_EPOCHS, SAVE_DIR_BEST]:
    os.makedirs(directory, exist_ok=True)
print(f"✅ Save directories ready")

# ==================== MODEL CLASS ====================
class DeBERTaEmotionClassifier(nn.Module):
    """DeBERTa-v3 encoder followed by dropout and a linear head on the first token.

    Args:
        num_labels: Number of output classes (size of the logits' last dimension).
    """

    def __init__(self, num_labels):
        super().__init__()
        # Backbone weights come from the checkpoint named by the global MODEL_NAME.
        self.deberta = DebertaV2Model.from_pretrained(MODEL_NAME)

        # Gradient checkpointing is deliberately not enabled here: it caused
        # problems together with gradient accumulation, at the cost of more memory.
        self.dropout = nn.Dropout(0.1)
        hidden_size = self.deberta.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits (last dimension = num_labels)."""
        encoder_out = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        # No pooled output is used; take the hidden state of the first ([CLS]) position.
        first_token_state = encoder_out.last_hidden_state[:, 0, :]
        return self.classifier(self.dropout(first_token_state))

# ==================== LOAD DATA ====================
def load_data():
    """Load the pre-tokenized splits and metadata written by the tokenization script.

    Reads ``metadata.json`` and ``tokenized_{train,val,test}.pt`` from DATA_DIR.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset, emotion_classes,
        weights_tensor, metadata)``. Each dataset is a TensorDataset of
        (input_ids, attention_mask, labels); ``weights_tensor`` is a
        FloatTensor of class weights ordered like ``emotion_classes``.

    Raises:
        ValueError: if the metadata's class list and ``num_classes`` disagree.
        Any loading error is re-raised after its traceback is printed.
    """
    print("\n" + "="*90)
    print("STEP 1: LOADING DATA")
    print("="*90)

    try:
        # Load metadata
        metadata_path = os.path.join(DATA_DIR, 'metadata.json')
        print(f"\n📂 Loading metadata from: {metadata_path}")
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        emotion_classes = metadata['emotion_classes']
        class_weights = metadata['class_weights']
        num_classes = metadata['num_classes']

        if len(emotion_classes) != num_classes:
            raise ValueError(
                f"metadata lists {len(emotion_classes)} emotion classes but num_classes={num_classes}"
            )

        print(f"✅ Metadata loaded!")
        print(f"   Emotions: {', '.join(emotion_classes)}")
        print(f"   Classes: {num_classes}")

        # (name used in the "Loading" line, short display name, file) per split,
        # in the order the datasets are returned.
        split_specs = [
            ('train', 'Train', 'tokenized_train.pt'),
            ('validation', 'Val', 'tokenized_val.pt'),
            ('test', 'Test', 'tokenized_test.pt'),
        ]
        datasets = []
        for long_name, short_name, file_name in split_specs:
            print(f"\n📂 Loading {long_name} data...")
            # weights_only=True: the files hold plain tensors, so refuse to
            # unpickle arbitrary objects (and don't depend on torch's default).
            split_data = torch.load(os.path.join(DATA_DIR, file_name), weights_only=True)
            print(f"✅ {short_name}: {split_data['input_ids'].shape[0]:,} samples")
            datasets.append(TensorDataset(
                split_data['input_ids'],
                split_data['attention_mask'],
                split_data['labels'],
            ))
        train_dataset, val_dataset, test_dataset = datasets

        # Class weights ordered by label id (= position in emotion_classes)
        weights_tensor = torch.FloatTensor([class_weights[emotion] for emotion in emotion_classes])

        print(f"\n📊 Class Weights:")
        for emotion, weight in class_weights.items():
            print(f"   {emotion:12s}: {weight:.4f}")

        return train_dataset, val_dataset, test_dataset, emotion_classes, weights_tensor, metadata

    except Exception as e:
        print(f"\n❌ ERROR loading data: {e}")
        import traceback
        traceback.print_exc()
        raise

# ==================== CREATE DATALOADERS ====================
def create_dataloaders(train_dataset, val_dataset, test_dataset):
    """Wrap the three datasets in DataLoaders.

    Only the training loader shuffles and uses TRAIN_BATCH_SIZE; the
    validation and test loaders use VAL_BATCH_SIZE and keep stored order.
    Every loader pins memory and loads in the main process (num_workers=0).

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    print("\n" + "="*90)
    print("STEP 2: CREATING DATALOADERS")
    print("="*90)

    def _make_loader(dataset, batch_size, shuffle):
        # Shared loader settings for all three splits
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            pin_memory=True,
            num_workers=0
        )

    train_loader = _make_loader(train_dataset, TRAIN_BATCH_SIZE, True)
    val_loader = _make_loader(val_dataset, VAL_BATCH_SIZE, False)
    test_loader = _make_loader(test_dataset, VAL_BATCH_SIZE, False)

    print(f"✅ DataLoaders created!")
    for display_name, loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
        print(f"   {display_name}: {len(loader):,} batches")
    print(f"   Effective batch size: {TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")

    return train_loader, val_loader, test_loader

# ==================== INITIALIZE MODEL ====================
def initialize_model(num_classes):
    """Build the emotion classifier, move it to DEVICE and print a parameter summary.

    Args:
        num_classes: Number of output classes for the classification head.

    Returns:
        The DeBERTaEmotionClassifier instance, already on DEVICE.

    Raises:
        Re-raises any construction or device-transfer error after printing it.
    """
    print("\n" + "="*90)
    print("STEP 3: INITIALIZING MODEL")
    print("="*90)

    try:
        print(f"🤖 Loading {MODEL_NAME}...")
        classifier_model = DeBERTaEmotionClassifier(num_labels=num_classes)
        classifier_model.to(DEVICE)

        all_params = list(classifier_model.parameters())
        total_params = sum(param.numel() for param in all_params)
        trainable_params = sum(param.numel() for param in all_params if param.requires_grad)
        # Size estimate assumes 4 bytes (fp32) per parameter
        size_mb = total_params * 4 / 1024**2

        print(f"✅ Model initialized!")
        print(f"   Total parameters: {total_params:,}")
        print(f"   Trainable parameters: {trainable_params:,}")
        print(f"   Model size: ~{size_mb:.2f} MB")
        print(f"   Gradient checkpointing: Disabled (for stability)")

        return classifier_model

    except Exception as err:
        print(f"❌ ERROR: {err}")
        raise

# ==================== OPTIMIZER & SCHEDULER ====================
def setup_optimizer_scheduler(model, train_loader):
    """Create the AdamW optimizer and the cosine-with-warmup LR scheduler.

    The DeBERTa backbone uses LEARNING_RATE and the classifier head uses
    10x that. Within each part, biases and LayerNorm weights get no weight
    decay (the usual transformer fine-tuning convention); all other weights
    use WEIGHT_DECAY.

    The schedule length is counted in optimizer steps:
    ceil(batches / GRADIENT_ACCUMULATION_STEPS) per epoch. This matches
    train_epoch, which also steps once for a trailing partial accumulation
    window, so the cosine curve does not run past its end.

    Args:
        model: The classifier; parameter names decide the groups.
        train_loader: Training DataLoader (only its length is used).

    Returns:
        Tuple ``(optimizer, scheduler)``.
    """
    import math

    print("\n" + "="*90)
    print("STEP 4: OPTIMIZER & SCHEDULER")
    print("="*90)

    head_lr = LEARNING_RATE * 10  # Higher LR for the freshly initialised classifier
    no_decay_keys = ('bias', 'LayerNorm.weight')

    backbone_decay, backbone_no_decay, head_decay, head_no_decay = [], [], [], []
    for name, param in model.named_parameters():
        skip_decay = any(key in name for key in no_decay_keys)
        if 'classifier' in name:
            (head_no_decay if skip_decay else head_decay).append(param)
        else:
            (backbone_no_decay if skip_decay else backbone_decay).append(param)

    # Differential learning rates. The backbone decay group stays first, so
    # scheduler.get_last_lr()[0] keeps reporting the backbone LR.
    optimizer_params = [
        {'params': backbone_decay, 'lr': LEARNING_RATE, 'weight_decay': WEIGHT_DECAY},
        {'params': backbone_no_decay, 'lr': LEARNING_RATE, 'weight_decay': 0.0},
        {'params': head_decay, 'lr': head_lr, 'weight_decay': WEIGHT_DECAY},
        {'params': head_no_decay, 'lr': head_lr, 'weight_decay': 0.0},
    ]
    optimizer_params = [group for group in optimizer_params if group['params']]

    optimizer = AdamW(optimizer_params, eps=1e-8)

    steps_per_epoch = math.ceil(len(train_loader) / GRADIENT_ACCUMULATION_STEPS)
    total_steps = steps_per_epoch * EPOCHS
    warmup_steps = int(WARMUP_RATIO * total_steps)

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    print(f"✅ Optimizer: AdamW (Differential LR)")
    print(f"   DeBERTa LR: {LEARNING_RATE}")
    print(f"   Classifier LR: {head_lr}")
    print(f"   Weight Decay: {WEIGHT_DECAY} (none on biases / LayerNorm weights)")
    print(f"\n✅ Scheduler: Cosine with Warmup")
    print(f"   Total steps: {total_steps:,}")
    print(f"   Warmup steps: {warmup_steps:,} ({WARMUP_RATIO*100:.0f}%)")

    return optimizer, scheduler

# ==================== TRAINING FUNCTION ====================
def train_epoch(model, dataloader, optimizer, scheduler, criterion, scaler, epoch):
    """Train ``model`` for one epoch with mixed precision and gradient accumulation.

    Each loss is divided by GRADIENT_ACCUMULATION_STEPS and its gradients are
    accumulated. Every GRADIENT_ACCUMULATION_STEPS batches, one clipped
    optimizer + scheduler step is taken. A trailing partial window is flushed
    at the end of the epoch, but only if it actually holds gradients.

    Batches that raise (e.g. CUDA OOM) are reported and skipped. Any
    gradients accumulated in the current window are discarded.

    Args:
        model: Classifier returning logits from (input_ids, attention_mask).
        dataloader: Yields (input_ids, attention_mask, labels) batches.
        optimizer, scheduler, scaler: Optimizer, LR scheduler and GradScaler.
        criterion: Loss function applied to (logits, labels).
        epoch: Zero-based epoch index (only used in the progress label).

    Returns:
        Mean training loss over the batches that completed successfully, or
        ``float('nan')`` if none did (including an empty dataloader).
    """
    model.train()
    total_loss = 0.0
    batch_losses = []
    num_batches = len(dataloader)
    # Micro-batches whose gradients are accumulated but not yet applied
    pending_micro_batches = 0

    def _optimizer_step():
        """Unscale, clip, step optimizer/scaler/scheduler, then clear gradients."""
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRADIENT_CLIP)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()

    # Start the epoch from clean gradients
    optimizer.zero_grad()

    progress_bar = tqdm(
        dataloader,
        desc=f"Epoch {epoch+1}/{EPOCHS} [TRAIN]",
        unit="batch",
        colour="green"
    )

    for batch_idx, batch in enumerate(progress_bar):
        try:
            input_ids = batch[0].to(DEVICE)
            attention_mask = batch[1].to(DEVICE)
            labels = batch[2].to(DEVICE)

            # Forward pass with mixed precision
            with autocast(enabled=USE_MIXED_PRECISION):
                logits = model(input_ids, attention_mask)
                loss = criterion(logits, labels)
                # Scale down so the accumulated gradient averages over the window
                loss = loss / GRADIENT_ACCUMULATION_STEPS

            # Backward pass
            scaler.scale(loss).backward()
            pending_micro_batches += 1

            # Update weights after accumulation steps
            if (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
                _optimizer_step()
                pending_micro_batches = 0

            batch_loss = loss.item() * GRADIENT_ACCUMULATION_STEPS
            total_loss += batch_loss
            batch_losses.append(batch_loss)

            # Average over successful batches only
            avg_loss = total_loss / len(batch_losses)
            current_lr = scheduler.get_last_lr()[0]
            progress_bar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'avg': f'{avg_loss:.4f}',
                'lr': f'{current_lr:.2e}'
            })

            if (batch_idx + 1) % LOG_INTERVAL == 0:
                recent_avg = np.mean(batch_losses[-LOG_INTERVAL:])
                print(f"\n   Batch [{batch_idx+1}/{num_batches}] "
                      f"Loss: {batch_loss:.4f} | "
                      f"Avg: {avg_loss:.4f} | "
                      f"Recent Avg: {recent_avg:.4f} | "
                      f"LR: {current_lr:.2e}")

        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f"\n⚠️  OOM at batch {batch_idx}")
                torch.cuda.empty_cache()
            else:
                print(f"\n⚠️  Runtime error at batch {batch_idx}: {e}")
            optimizer.zero_grad()
            pending_micro_batches = 0
        except Exception as e:
            print(f"\n⚠️  Error at batch {batch_idx}: {e}")
            optimizer.zero_grad()
            pending_micro_batches = 0

    # Apply gradients from a trailing partial accumulation window. Guarding on
    # pending work means an empty loader no longer hits an undefined batch_idx.
    # It also avoids calling scaler.step() after a failed last batch cleared
    # the gradients (GradScaler can fail when no gradients were recorded).
    if pending_micro_batches > 0:
        _optimizer_step()

    if not batch_losses:
        print("\n⚠️  No training batch completed successfully this epoch")
        return float('nan')
    return total_loss / len(batch_losses)

# ==================== EVALUATION FUNCTION ====================
def evaluate(model, dataloader, criterion, emotion_classes, split_name="VAL"):
    """Run *model* over *dataloader* without gradients and compute metrics.

    Args:
        model: classifier whose ``forward(input_ids, attention_mask)`` returns logits.
        dataloader: yields ``(input_ids, attention_mask, labels)`` batches.
        criterion: loss applied to ``(logits, labels)``.
        emotion_classes: ordered class names; position ``i`` is label id ``i``.
            Per-class metrics and the confusion matrix always cover every
            class, even if one is missing from this split's labels/predictions.
        split_name: label shown on the progress bar.

    Returns:
        dict with ``loss`` (mean over successfully evaluated batches),
        accuracy, macro/weighted F1/precision/recall, per-class
        F1/precision/recall arrays of length ``len(emotion_classes)``,
        the confusion matrix, and the raw ``predictions``/``labels`` arrays.

    Raises:
        RuntimeError: if no batch could be evaluated.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0  # batches that completed without error
    all_preds = []
    all_labels = []
    # Fixed label set so per-class arrays / confusion matrix have one slot per class.
    label_ids = list(range(len(emotion_classes)))

    progress_bar = tqdm(
        dataloader,
        desc=f"{split_name}",
        unit="batch",
        colour="blue"
    )

    with torch.no_grad():
        for batch in progress_bar:
            try:
                input_ids = batch[0].to(DEVICE)
                attention_mask = batch[1].to(DEVICE)
                labels = batch[2].to(DEVICE)

                with autocast(enabled=USE_MIXED_PRECISION):
                    logits = model(input_ids, attention_mask)
                    loss = criterion(logits, labels)

                batch_loss = loss.item()
                batch_preds = torch.argmax(logits, dim=1).cpu().numpy()
                batch_labels = labels.cpu().numpy()
            except Exception as e:
                print(f"\n⚠️  Error: {e}")
                continue

            # Only record a batch once every step for it succeeded, so loss,
            # predictions and labels always stay in sync.
            total_loss += batch_loss
            num_batches += 1
            all_preds.extend(batch_preds)
            all_labels.extend(batch_labels)

            # Running mean over processed batches (robust to partial/skipped batches).
            progress_bar.set_postfix({'loss': f'{total_loss / num_batches:.4f}'})

    if num_batches == 0:
        raise RuntimeError(f"{split_name}: no batches could be evaluated")

    # Calculate metrics
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    accuracy = accuracy_score(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    precision_macro = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    precision_weighted = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall_macro = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    recall_weighted = recall_score(all_labels, all_preds, average='weighted', zero_division=0)

    # Per-class metrics (one entry per class in emotion_classes order)
    f1_per_class = f1_score(all_labels, all_preds, labels=label_ids, average=None, zero_division=0)
    precision_per_class = precision_score(all_labels, all_preds, labels=label_ids, average=None, zero_division=0)
    recall_per_class = recall_score(all_labels, all_preds, labels=label_ids, average=None, zero_division=0)

    # Confusion matrix, always len(emotion_classes) x len(emotion_classes)
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    avg_loss = total_loss / num_batches

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'precision_macro': precision_macro,
        'precision_weighted': precision_weighted,
        'recall_macro': recall_macro,
        'recall_weighted': recall_weighted,
        'f1_per_class': f1_per_class,
        'precision_per_class': precision_per_class,
        'recall_per_class': recall_per_class,
        'confusion_matrix': cm,
        'predictions': all_preds,
        'labels': all_labels
    }

# ==================== SAVE MODEL ====================
def save_model(model, path, name, metrics=None, metadata=None):
    """Save a checkpoint of *model* to ``<path>/<name>/``.

    Writes the encoder via ``model.deberta.save_pretrained``, a
    ``classifier.pt`` holding the classifier/dropout state dicts plus
    *metrics* and *metadata*, and, when *metrics* is non-empty, a
    ``metrics.json`` with the scalar and per-class values. Predictions,
    labels and the confusion matrix are left out of the JSON.

    Best-effort: any error is printed and ``False`` is returned instead of
    raising, so a failed save never aborts training.

    Returns:
        True on success, False if anything failed.
    """
    excluded_keys = ('predictions', 'labels', 'confusion_matrix')
    try:
        full_path = os.path.join(path, name)
        os.makedirs(full_path, exist_ok=True)

        # Save DeBERTa
        model.deberta.save_pretrained(full_path)

        # Save classifier
        save_dict = {
            'classifier_state_dict': model.classifier.state_dict(),
            'dropout_state_dict': model.dropout.state_dict(),
            'metrics': metrics,
            'metadata': metadata
        }

        torch.save(save_dict, os.path.join(full_path, 'classifier.pt'))

        # Save metrics as JSON. Sequences (ndarray, list, tuple) go through
        # np.asarray(...).tolist() so numpy scalars inside them become plain
        # Python numbers; a bare list has no .tolist() of its own.
        if metrics:
            metrics_to_save = {
                k: np.asarray(v).tolist() if isinstance(v, (np.ndarray, list, tuple)) else float(v)
                for k, v in metrics.items() if k not in excluded_keys
            }
            with open(os.path.join(full_path, 'metrics.json'), 'w') as f:
                json.dump(metrics_to_save, f, indent=2)

        print(f"✅ Saved: {full_path}")
        return True

    except Exception as e:
        print(f"⚠️  Save error: {e}")
        return False

# ==================== PLOT CONFUSION MATRIX ====================
def plot_confusion_matrix(cm, emotion_classes, epoch):
    """Render *cm* as an annotated heatmap and save it under SAVE_DIR_EPOCHS.

    The file is ``confusion_matrix_epoch_<epoch>.png``. Best-effort: plotting
    or saving errors are printed and swallowed. The figure is always closed
    (``finally``), so a failure doesn't leak a matplotlib figure every epoch.

    Args:
        cm: integer confusion matrix (rows = true label, cols = predicted).
        emotion_classes: tick labels for both axes, in label-id order.
        epoch: 1-based epoch number used in the title and file name.
    """
    fig = None
    try:
        fig = plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=emotion_classes,
                   yticklabels=emotion_classes)
        plt.title(f'Confusion Matrix - Epoch {epoch}')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()

        cm_path = os.path.join(SAVE_DIR_EPOCHS, f'confusion_matrix_epoch_{epoch}.png')
        plt.savefig(cm_path, dpi=150, bbox_inches='tight')
        print(f"   📊 Confusion matrix saved: {cm_path}")
    except Exception as e:
        print(f"   ⚠️  Could not save confusion matrix: {e}")
    finally:
        if fig is not None:
            plt.close(fig)

# ==================== MAIN ====================
def main():
    """Train, validate every epoch, checkpoint, then evaluate on the test split.

    Each epoch:
    - runs ``train_epoch`` and evaluates on the validation split;
    - prints overall and per-class metrics;
    - saves a confusion-matrix plot and an epoch checkpoint under SAVE_DIR_EPOCHS;
    - saves ``best_model`` under SAVE_DIR_BEST whenever validation macro-F1 improves.

    An exception inside an epoch is printed and training moves on.

    After training, the per-epoch history is written to
    ``SAVE_DIR_EPOCHS/training_history.json``. The best-validation weights are
    restored (if any epoch completed) before the test pass, so the reported
    test metrics belong to the same model as "Best Val F1 Macro".
    """
    print("\n" + "="*90)
    print("🚀 STARTING TRAINING")
    print("="*90)

    start_time = time.time()

    # Load data
    train_dataset, val_dataset, test_dataset, emotion_classes, class_weights, metadata = load_data()

    # Create dataloaders
    train_loader, val_loader, test_loader = create_dataloaders(train_dataset, val_dataset, test_dataset)

    # Initialize model
    model = initialize_model(len(emotion_classes))

    # Setup optimizer
    optimizer, scheduler = setup_optimizer_scheduler(model, train_loader)

    # Loss function with class weights
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))
    print(f"\n✅ Loss: CrossEntropyLoss with class weights")

    # Mixed precision scaler
    scaler = GradScaler(enabled=USE_MIXED_PRECISION)

    # Training tracking
    best_f1_macro = 0.0
    best_epoch = 0
    # Host-memory copy of the best-validation weights; restored before testing.
    best_state = None
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_f1_macro': [],
        'val_accuracy': []
    }

    print("\n" + "="*90)
    print(f"⏰ TRAINING START: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print("="*90)

    # Training loop
    for epoch in range(EPOCHS):
        epoch_start = time.time()

        print(f"\n{'='*90}")
        print(f"EPOCH {epoch + 1}/{EPOCHS}")
        print(f"{'='*90}")

        try:
            # Train
            train_loss = train_epoch(
                model, train_loader, optimizer, scheduler,
                criterion, scaler, epoch
            )

            # Validate
            print(f"\n🔍 Evaluating on validation set...")
            val_metrics = evaluate(model, val_loader, criterion, emotion_classes, "VALIDATION")

            epoch_time = time.time() - epoch_start

            # Update history
            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_metrics['loss'])
            history['val_f1_macro'].append(val_metrics['f1_macro'])
            history['val_accuracy'].append(val_metrics['accuracy'])

            # Print results
            print(f"\n{'='*90}")
            print(f"EPOCH {epoch + 1} RESULTS")
            print(f"{'='*90}")
            print(f"⏱️  Time: {epoch_time:.2f}s ({epoch_time/60:.2f} min)")
            print(f"\n📉 Loss:")
            print(f"   Train: {train_loss:.4f}")
            print(f"   Val:   {val_metrics['loss']:.4f}")
            print(f"\n📊 Overall Metrics:")
            print(f"   Accuracy:          {val_metrics['accuracy']:.4f}")
            print(f"   F1 Macro:          {val_metrics['f1_macro']:.4f}")
            print(f"   F1 Weighted:       {val_metrics['f1_weighted']:.4f}")
            print(f"   Precision Macro:   {val_metrics['precision_macro']:.4f}")
            print(f"   Precision Weighted: {val_metrics['precision_weighted']:.4f}")
            print(f"   Recall Macro:      {val_metrics['recall_macro']:.4f}")
            print(f"   Recall Weighted:   {val_metrics['recall_weighted']:.4f}")

            # Per-class metrics
            print(f"\n📊 Per-Class Metrics:")
            print(f"{'Emotion':<12} {'F1':>6} {'Precision':>10} {'Recall':>8}")
            print("-" * 40)
            for i, emotion in enumerate(emotion_classes):
                print(f"{emotion:<12} {val_metrics['f1_per_class'][i]:>6.4f} "
                      f"{val_metrics['precision_per_class'][i]:>10.4f} "
                      f"{val_metrics['recall_per_class'][i]:>8.4f}")

            print(f"{'='*90}")

            # Plot confusion matrix
            plot_confusion_matrix(val_metrics['confusion_matrix'], emotion_classes, epoch + 1)

            # Save epoch model
            save_model(model, SAVE_DIR_EPOCHS, f'epoch_{epoch+1}', val_metrics, metadata)

            # Save best model
            if val_metrics['f1_macro'] > best_f1_macro:
                best_f1_macro = val_metrics['f1_macro']
                best_epoch = epoch + 1
                # copy=True guarantees an independent snapshot that later
                # optimizer steps cannot modify (even for CPU-resident tensors).
                best_state = {
                    k: v.detach().to('cpu', copy=True)
                    for k, v in model.state_dict().items()
                }
                print(f"\n🌟 NEW BEST MODEL! F1 Macro: {best_f1_macro:.4f}")
                save_model(model, SAVE_DIR_BEST, 'best_model', val_metrics, metadata)
            else:
                print(f"\n   Best F1 Macro: {best_f1_macro:.4f} (Epoch {best_epoch})")

        except Exception as e:
            print(f"\n❌ Error in epoch {epoch+1}: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Persist the learning curves collected above (best-effort).
    history_path = os.path.join(SAVE_DIR_EPOCHS, 'training_history.json')
    try:
        with open(history_path, 'w') as f:
            json.dump({k: [float(x) for x in v] for k, v in history.items()}, f, indent=2)
        print(f"\n📝 Training history saved: {history_path}")
    except OSError as e:
        print(f"\n⚠️  Could not save training history: {e}")

    # Final test evaluation
    print("\n" + "="*90)
    print("🎯 FINAL TEST SET EVALUATION")
    print("="*90)

    # Test the model selected on validation, not whatever weights the last
    # epoch left in memory, so "Final Test" and "Best Val" describe one model.
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"♻️  Restored best weights from epoch {best_epoch}")

    test_metrics = evaluate(model, test_loader, criterion, emotion_classes, "TEST")

    total_time = time.time() - start_time

    print(f"\n{'='*90}")
    print("📊 TEST SET RESULTS")
    print(f"{'='*90}")
    print(f"   Accuracy:          {test_metrics['accuracy']:.4f}")
    print(f"   F1 Macro:          {test_metrics['f1_macro']:.4f}")
    print(f"   F1 Weighted:       {test_metrics['f1_weighted']:.4f}")
    print(f"   Precision Macro:   {test_metrics['precision_macro']:.4f}")
    print(f"   Recall Macro:      {test_metrics['recall_macro']:.4f}")
    print(f"{'='*90}")

    # Final summary
    print(f"\n{'='*90}")
    print("🎉 TRAINING COMPLETE!")
    print(f"{'='*90}")
    print(f"⏱️  Total Time: {total_time/3600:.2f} hours")
    print(f"🏆 Best Val F1 Macro: {best_f1_macro:.4f} (Epoch {best_epoch})")
    print(f"📈 Final Test F1 Macro: {test_metrics['f1_macro']:.4f}")
    print(f"⏰ End: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"\n📁 Saved Models:")
    print(f"   • Best: {SAVE_DIR_BEST}best_model/")
    print(f"   • All: {SAVE_DIR_EPOCHS}epoch_X/")
    print(f"{'='*90}\n")

if __name__ == "__main__":
    # Script entry point: turn Ctrl+C and unexpected failures into readable
    # messages (plus a traceback and checklist for real errors).
    try:
        main()
    except KeyboardInterrupt:
        print("\n\n⚠️  Training interrupted by user!")
        print("   Models saved up to last completed epoch.")
    except Exception as e:
        print(f"\n\n❌ CRITICAL ERROR: {e}")
        from traceback import print_exc
        print_exc()
        print("\n💡 Troubleshooting:")
        print("\n".join((
            "   1. Check tokenized data exists in SuperEmotion folder",
            "   2. Verify GPU memory available",
            "   3. Ensure transformers library installed",
            "   4. Check Drive is properly mounted",
        )))

                         🚀 DeBERTa v3 Emotion Detection Training 🚀
                              SuperEmotion - 7 Classes

🖥️  Device: cuda
   GPU: Tesla T4
   Memory: 14.74 GB
   CUDA: 12.6

📊 Configuration:
   Model: microsoft/deberta-v3-base
   Train Batch: 64 (x2 accum = 128 effective)
   Val Batch: 128
   Epochs: 20
   Learning Rate: 2e-05
   Weight Decay: 0.01
   Max Length: 128
   Mixed Precision: True
   Gradient Clipping: 1.0

✅ Drive already mounted!
✅ Save directories ready

🚀 STARTING TRAINING

STEP 1: LOADING DATA

📂 Loading metadata from: /content/drive/MyDrive/SuperEmotion/metadata.json
✅ Metadata loaded!
   Emotions: anger, fear, joy, love, neutral, sadness, surprise
   Classes: 7

📂 Loading train data...
✅ Train: 152,572 samples

📂 Loading validation data...
✅ Val: 19,072 samples

📂 Loading test data...
✅ Test: 19,072 samples

📊 Class Weights:
   anger       : 0.9082
   fear        : 0.9082
   joy         : 0.9082
   love        : 0.9082
   neutral     : 1.1146
   sadn

Epoch 1/20 [TRAIN]:   0%|          | 0/2384 [00:00<?, ?batch/s]


   Batch [100/2384] Loss: 2.0513 | Avg: 2.0105 | Recent Avg: 2.0105 | LR: 4.19e-07

   Batch [200/2384] Loss: 1.9606 | Avg: 1.9974 | Recent Avg: 1.9843 | LR: 8.39e-07

   Batch [300/2384] Loss: 1.9082 | Avg: 1.9842 | Recent Avg: 1.9577 | LR: 1.26e-06

   Batch [400/2384] Loss: 1.9024 | Avg: 1.9712 | Recent Avg: 1.9324 | LR: 1.68e-06

   Batch [500/2384] Loss: 1.7784 | Avg: 1.9491 | Recent Avg: 1.8608 | LR: 2.10e-06

   Batch [600/2384] Loss: 1.6884 | Avg: 1.9061 | Recent Avg: 1.6911 | LR: 2.52e-06

   Batch [700/2384] Loss: 1.5206 | Avg: 1.8506 | Recent Avg: 1.5173 | LR: 2.94e-06

   Batch [800/2384] Loss: 1.0039 | Avg: 1.7855 | Recent Avg: 1.3295 | LR: 3.36e-06

   Batch [900/2384] Loss: 1.0339 | Avg: 1.7089 | Recent Avg: 1.0962 | LR: 3.78e-06

   Batch [1000/2384] Loss: 0.7789 | Avg: 1.6279 | Recent Avg: 0.8988 | LR: 4.19e-06

   Batch [1100/2384] Loss: 0.8867 | Avg: 1.5502 | Recent Avg: 0.7739 | LR: 4.61e-06

   Batch [1200/2384] Loss: 0.6008 | Avg: 1.4766 | Recent Avg: 0.6671 | LR

VALIDATION:   0%|          | 0/149 [00:00<?, ?batch/s]


EPOCH 1 RESULTS
⏱️  Time: 1351.08s (22.52 min)

📉 Loss:
   Train: 0.9556
   Val:   0.3264

📊 Overall Metrics:
   Accuracy:          0.8903
   F1 Macro:          0.8845
   F1 Weighted:       0.8917
   Precision Macro:   0.8842
   Precision Weighted: 0.8978
   Recall Macro:      0.8903
   Recall Weighted:   0.8903

📊 Per-Class Metrics:
Emotion          F1  Precision   Recall
----------------------------------------
anger        0.9042     0.8959   0.9127
fear         0.8972     0.9562   0.8450
joy          0.9185     0.9904   0.8563
love         0.9117     0.8899   0.9347
neutral      0.8278     0.7992   0.8584
sadness      0.9254     0.9285   0.9223
surprise     0.8065     0.7290   0.9023
   📊 Confusion matrix saved: /content/drive/MyDrive/SimpleSaves/confusion_matrix_epoch_1.png
✅ Saved: /content/drive/MyDrive/SimpleSaves/epoch_1

🌟 NEW BEST MODEL! F1 Macro: 0.8845
✅ Saved: /content/drive/MyDrive/BestModelSave/best_model

EPOCH 2/20


Epoch 2/20 [TRAIN]:   0%|          | 0/2384 [00:00<?, ?batch/s]


   Batch [100/2384] Loss: 0.1785 | Avg: 0.3253 | Recent Avg: 0.3253 | LR: 1.04e-05

   Batch [200/2384] Loss: 0.2390 | Avg: 0.3270 | Recent Avg: 0.3287 | LR: 1.08e-05

   Batch [300/2384] Loss: 0.2211 | Avg: 0.3193 | Recent Avg: 0.3039 | LR: 1.13e-05

   Batch [400/2384] Loss: 0.3629 | Avg: 0.3202 | Recent Avg: 0.3229 | LR: 1.17e-05

   Batch [500/2384] Loss: 0.2968 | Avg: 0.3206 | Recent Avg: 0.3224 | LR: 1.21e-05

   Batch [600/2384] Loss: 0.2382 | Avg: 0.3187 | Recent Avg: 0.3088 | LR: 1.25e-05

   Batch [700/2384] Loss: 0.3045 | Avg: 0.3183 | Recent Avg: 0.3164 | LR: 1.29e-05

   Batch [800/2384] Loss: 0.3073 | Avg: 0.3166 | Recent Avg: 0.3043 | LR: 1.34e-05

   Batch [900/2384] Loss: 0.1848 | Avg: 0.3155 | Recent Avg: 0.3070 | LR: 1.38e-05

   Batch [1000/2384] Loss: 0.2728 | Avg: 0.3131 | Recent Avg: 0.2912 | LR: 1.42e-05

   Batch [1100/2384] Loss: 0.3596 | Avg: 0.3119 | Recent Avg: 0.3003 | LR: 1.46e-05

   Batch [1200/2384] Loss: 0.3521 | Avg: 0.3107 | Recent Avg: 0.2972 | LR

VALIDATION:   0%|          | 0/149 [00:00<?, ?batch/s]


EPOCH 2 RESULTS
⏱️  Time: 1349.39s (22.49 min)

📉 Loss:
   Train: 0.2953
   Val:   0.2526

📊 Overall Metrics:
   Accuracy:          0.9071
   F1 Macro:          0.9014
   F1 Weighted:       0.9091
   Precision Macro:   0.9007
   Precision Weighted: 0.9155
   Recall Macro:      0.9075
   Recall Weighted:   0.9071

📊 Per-Class Metrics:
Emotion          F1  Precision   Recall
----------------------------------------
anger        0.9278     0.9413   0.9147
fear         0.9118     0.9433   0.8823
joy          0.9365     0.9907   0.8880
love         0.9190     0.9026   0.9360
neutral      0.8503     0.7999   0.9075
sadness      0.9500     0.9920   0.9113
surprise     0.8142     0.7349   0.9128
   📊 Confusion matrix saved: /content/drive/MyDrive/SimpleSaves/confusion_matrix_epoch_2.png
✅ Saved: /content/drive/MyDrive/SimpleSaves/epoch_2

🌟 NEW BEST MODEL! F1 Macro: 0.9014
✅ Saved: /content/drive/MyDrive/BestModelSave/best_model

EPOCH 3/20


Epoch 3/20 [TRAIN]:   0%|          | 0/2384 [00:00<?, ?batch/s]


   Batch [100/2384] Loss: 0.2544 | Avg: 0.2546 | Recent Avg: 0.2546 | LR: 2.00e-05

   Batch [200/2384] Loss: 0.1801 | Avg: 0.2606 | Recent Avg: 0.2667 | LR: 2.00e-05

   Batch [300/2384] Loss: 0.3869 | Avg: 0.2616 | Recent Avg: 0.2635 | LR: 2.00e-05

   Batch [400/2384] Loss: 0.2847 | Avg: 0.2605 | Recent Avg: 0.2573 | LR: 2.00e-05

   Batch [500/2384] Loss: 0.2241 | Avg: 0.2607 | Recent Avg: 0.2615 | LR: 2.00e-05

   Batch [600/2384] Loss: 0.1354 | Avg: 0.2564 | Recent Avg: 0.2348 | LR: 2.00e-05

   Batch [700/2384] Loss: 0.2027 | Avg: 0.2549 | Recent Avg: 0.2458 | LR: 2.00e-05

   Batch [800/2384] Loss: 0.4426 | Avg: 0.2557 | Recent Avg: 0.2617 | LR: 2.00e-05

   Batch [900/2384] Loss: 0.2525 | Avg: 0.2545 | Recent Avg: 0.2444 | LR: 2.00e-05

   Batch [1000/2384] Loss: 0.1203 | Avg: 0.2531 | Recent Avg: 0.2411 | LR: 2.00e-05

   Batch [1100/2384] Loss: 0.2803 | Avg: 0.2516 | Recent Avg: 0.2363 | LR: 2.00e-05

   Batch [1200/2384] Loss: 0.1977 | Avg: 0.2512 | Recent Avg: 0.2467 | LR

VALIDATION:   0%|          | 0/149 [00:00<?, ?batch/s]


EPOCH 3 RESULTS
⏱️  Time: 1348.05s (22.47 min)

📉 Loss:
   Train: 0.2481
   Val:   0.2549

📊 Overall Metrics:
   Accuracy:          0.9073
   F1 Macro:          0.9014
   F1 Weighted:       0.9090
   Precision Macro:   0.9005
   Precision Weighted: 0.9145
   Recall Macro:      0.9071
   Recall Weighted:   0.9073

📊 Per-Class Metrics:
Emotion          F1  Precision   Recall
----------------------------------------
anger        0.9276     0.9561   0.9007
fear         0.9113     0.9254   0.8977
joy          0.9365     0.9882   0.8900
love         0.9189     0.8883   0.9517
neutral      0.8467     0.8186   0.8768
sadness      0.9512     0.9885   0.9167
surprise     0.8174     0.7381   0.9158
   📊 Confusion matrix saved: /content/drive/MyDrive/SimpleSaves/confusion_matrix_epoch_3.png
✅ Saved: /content/drive/MyDrive/SimpleSaves/epoch_3

   Best F1 Macro: 0.9014 (Epoch 2)

EPOCH 4/20


Epoch 4/20 [TRAIN]:   0%|          | 0/2384 [00:00<?, ?batch/s]


   Batch [100/2384] Loss: 0.2524 | Avg: 0.2143 | Recent Avg: 0.2143 | LR: 1.98e-05

   Batch [200/2384] Loss: 0.2467 | Avg: 0.2141 | Recent Avg: 0.2140 | LR: 1.98e-05

   Batch [300/2384] Loss: 0.2636 | Avg: 0.2174 | Recent Avg: 0.2240 | LR: 1.98e-05

   Batch [400/2384] Loss: 0.1610 | Avg: 0.2184 | Recent Avg: 0.2212 | LR: 1.98e-05

   Batch [500/2384] Loss: 0.1403 | Avg: 0.2180 | Recent Avg: 0.2164 | LR: 1.98e-05

   Batch [600/2384] Loss: 0.1166 | Avg: 0.2189 | Recent Avg: 0.2237 | LR: 1.98e-05

   Batch [700/2384] Loss: 0.2813 | Avg: 0.2181 | Recent Avg: 0.2134 | LR: 1.97e-05

   Batch [800/2384] Loss: 0.1957 | Avg: 0.2181 | Recent Avg: 0.2179 | LR: 1.97e-05

   Batch [900/2384] Loss: 0.3136 | Avg: 0.2191 | Recent Avg: 0.2270 | LR: 1.97e-05

   Batch [1000/2384] Loss: 0.2124 | Avg: 0.2181 | Recent Avg: 0.2091 | LR: 1.97e-05

   Batch [1100/2384] Loss: 0.2001 | Avg: 0.2179 | Recent Avg: 0.2155 | LR: 1.97e-05

   Batch [1200/2384] Loss: 0.1194 | Avg: 0.2191 | Recent Avg: 0.2323 | LR

VALIDATION:   0%|          | 0/149 [00:00<?, ?batch/s]


EPOCH 4 RESULTS
⏱️  Time: 1347.74s (22.46 min)

📉 Loss:
   Train: 0.2188
   Val:   0.2534

📊 Overall Metrics:
   Accuracy:          0.9105
   F1 Macro:          0.9050
   F1 Weighted:       0.9114
   Precision Macro:   0.9051
   Precision Weighted: 0.9163
   Recall Macro:      0.9090
   Recall Weighted:   0.9105

📊 Per-Class Metrics:
Emotion          F1  Precision   Recall
----------------------------------------
anger        0.9223     0.8810   0.9677
fear         0.9052     0.9663   0.8513
joy          0.9400     0.9948   0.8910
love         0.9232     0.9076   0.9393
neutral      0.8582     0.8305   0.8879
sadness      0.9517     0.9758   0.9287
surprise     0.8344     0.7796   0.8974
   📊 Confusion matrix saved: /content/drive/MyDrive/SimpleSaves/confusion_matrix_epoch_4.png
✅ Saved: /content/drive/MyDrive/SimpleSaves/epoch_4

🌟 NEW BEST MODEL! F1 Macro: 0.9050
✅ Saved: /content/drive/MyDrive/BestModelSave/best_model

EPOCH 5/20


Epoch 5/20 [TRAIN]:   0%|          | 0/2384 [00:00<?, ?batch/s]


   Batch [100/2384] Loss: 0.2472 | Avg: 0.1940 | Recent Avg: 0.1940 | LR: 1.94e-05

   Batch [200/2384] Loss: 0.2158 | Avg: 0.1883 | Recent Avg: 0.1827 | LR: 1.93e-05

   Batch [300/2384] Loss: 0.1820 | Avg: 0.1874 | Recent Avg: 0.1855 | LR: 1.93e-05

   Batch [400/2384] Loss: 0.1044 | Avg: 0.1867 | Recent Avg: 0.1845 | LR: 1.93e-05

   Batch [500/2384] Loss: 0.1466 | Avg: 0.1886 | Recent Avg: 0.1966 | LR: 1.93e-05

   Batch [600/2384] Loss: 0.1613 | Avg: 0.1893 | Recent Avg: 0.1929 | LR: 1.92e-05

   Batch [700/2384] Loss: 0.1523 | Avg: 0.1893 | Recent Avg: 0.1891 | LR: 1.92e-05

   Batch [800/2384] Loss: 0.1054 | Avg: 0.1907 | Recent Avg: 0.2001 | LR: 1.92e-05

   Batch [900/2384] Loss: 0.1443 | Avg: 0.1901 | Recent Avg: 0.1856 | LR: 1.92e-05

   Batch [1000/2384] Loss: 0.2121 | Avg: 0.1896 | Recent Avg: 0.1849 | LR: 1.91e-05

   Batch [1100/2384] Loss: 0.0598 | Avg: 0.1891 | Recent Avg: 0.1838 | LR: 1.91e-05

   Batch [1200/2384] Loss: 0.2458 | Avg: 0.1899 | Recent Avg: 0.1988 | LR

VALIDATION:   0%|          | 0/149 [00:00<?, ?batch/s]


EPOCH 5 RESULTS
⏱️  Time: 1348.20s (22.47 min)

📉 Loss:
   Train: 0.1923
   Val:   0.2561

📊 Overall Metrics:
   Accuracy:          0.9062
   F1 Macro:          0.8994
   F1 Weighted:       0.9071
   Precision Macro:   0.8980
   Precision Weighted: 0.9108
   Recall Macro:      0.9044
   Recall Weighted:   0.9062

📊 Per-Class Metrics:
Emotion          F1  Precision   Recall
----------------------------------------
anger        0.9302     0.9484   0.9127
fear         0.9128     0.9256   0.9003
joy          0.9371     0.9789   0.8987
love         0.9224     0.8957   0.9507
neutral      0.8369     0.8559   0.8187
sadness      0.9396     0.9438   0.9353
surprise     0.8168     0.7379   0.9146
   📊 Confusion matrix saved: /content/drive/MyDrive/SimpleSaves/confusion_matrix_epoch_5.png
✅ Saved: /content/drive/MyDrive/SimpleSaves/epoch_5

   Best F1 Macro: 0.9050 (Epoch 4)

EPOCH 6/20


Epoch 6/20 [TRAIN]:   0%|          | 0/2384 [00:00<?, ?batch/s]


   Batch [100/2384] Loss: 0.1129 | Avg: 0.1708 | Recent Avg: 0.1708 | LR: 1.86e-05

   Batch [200/2384] Loss: 0.1724 | Avg: 0.1618 | Recent Avg: 0.1527 | LR: 1.86e-05

   Batch [300/2384] Loss: 0.0825 | Avg: 0.1627 | Recent Avg: 0.1646 | LR: 1.85e-05

   Batch [400/2384] Loss: 0.1418 | Avg: 0.1610 | Recent Avg: 0.1560 | LR: 1.85e-05

   Batch [500/2384] Loss: 0.1340 | Avg: 0.1633 | Recent Avg: 0.1725 | LR: 1.85e-05

   Batch [600/2384] Loss: 0.1070 | Avg: 0.1630 | Recent Avg: 0.1615 | LR: 1.84e-05

   Batch [700/2384] Loss: 0.2059 | Avg: 0.1627 | Recent Avg: 0.1609 | LR: 1.84e-05

   Batch [800/2384] Loss: 0.2707 | Avg: 0.1635 | Recent Avg: 0.1687 | LR: 1.84e-05

   Batch [900/2384] Loss: 0.1730 | Avg: 0.1638 | Recent Avg: 0.1666 | LR: 1.83e-05

   Batch [1000/2384] Loss: 0.1249 | Avg: 0.1642 | Recent Avg: 0.1676 | LR: 1.83e-05

   Batch [1100/2384] Loss: 0.1665 | Avg: 0.1642 | Recent Avg: 0.1639 | LR: 1.82e-05

   Batch [1200/2384] Loss: 0.1634 | Avg: 0.1647 | Recent Avg: 0.1699 | LR

VALIDATION:   0%|          | 0/149 [00:00<?, ?batch/s]


EPOCH 6 RESULTS
⏱️  Time: 1347.82s (22.46 min)

📉 Loss:
   Train: 0.1658
   Val:   0.2727

📊 Overall Metrics:
   Accuracy:          0.9067
   F1 Macro:          0.9000
   F1 Weighted:       0.9081
   Precision Macro:   0.8997
   Precision Weighted: 0.9140
   Recall Macro:      0.9057
   Recall Weighted:   0.9067

📊 Per-Class Metrics:
Emotion          F1  Precision   Recall
----------------------------------------
anger        0.9270     0.8996   0.9560
fear         0.9038     0.9662   0.8490
joy          0.9380     0.9828   0.8970
love         0.9231     0.8997   0.9477
neutral      0.8472     0.8348   0.8601
sadness      0.9519     0.9868   0.9193
surprise     0.8093     0.7280   0.9109
   📊 Confusion matrix saved: /content/drive/MyDrive/SimpleSaves/confusion_matrix_epoch_6.png
✅ Saved: /content/drive/MyDrive/SimpleSaves/epoch_6

   Best F1 Macro: 0.9050 (Epoch 4)

EPOCH 7/20


Epoch 7/20 [TRAIN]:   0%|          | 0/2384 [00:00<?, ?batch/s]


   Batch [100/2384] Loss: 0.2214 | Avg: 0.1367 | Recent Avg: 0.1367 | LR: 1.76e-05

   Batch [200/2384] Loss: 0.1450 | Avg: 0.1415 | Recent Avg: 0.1463 | LR: 1.76e-05


⚠️  Training interrupted by user!
   Models saved up to last completed epoch.


In [None]:
from google.colab import drive  # Colab-only module; fails outside Colab
drive.mount('/content/drive')  # mount Google Drive so the /content/drive/MyDrive/... paths used below resolve

Mounted at /content/drive


In [None]:
"""
DeBERTa v3 Model Evaluation Script for Emotion Detection
Evaluates all saved models on test set with comprehensive metrics
"""

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import DebertaV2Model
from sklearn.metrics import (f1_score, precision_score, recall_score,
                            accuracy_score, confusion_matrix, classification_report,
                            roc_auc_score, matthews_corrcoef, cohen_kappa_score)
from tqdm.auto import tqdm
import numpy as np
import json
import time
import warnings
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import glob

warnings.filterwarnings('ignore')

# ==================== CONFIGURATION ====================
DATA_DIR = '/content/drive/MyDrive/SuperEmotion/'
SAVE_DIR_EPOCHS = '/content/drive/MyDrive/SimpleSaves/'
SAVE_DIR_BEST = '/content/drive/MyDrive/BestModelSave/'
RESULTS_DIR = '/content/drive/MyDrive/ModelEvaluationResults/'
MODEL_NAME = 'microsoft/deberta-v3-base'
MAX_LENGTH = 128
VAL_BATCH_SIZE = 128
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Mixed precision here means torch.cuda.amp.autocast, which only takes effect
# on a CUDA device (on CPU it warns and disables itself). Derive the flag from
# the chosen device so it reflects what actually happens.
USE_MIXED_PRECISION = DEVICE.type == 'cuda'

# Create results directory
os.makedirs(RESULTS_DIR, exist_ok=True)

print("="*90)
print(" "*25 + "🔍 DeBERTa v3 Model Evaluation 🔍")
print(" "*30 + "Testing All Saved Models")
print("="*90)
print(f"\n🖥️  Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

# ==================== MOUNT DRIVE ====================
# Best-effort Drive mount. Only the expected failure modes are caught, so
# KeyboardInterrupt/SystemExit still propagate (a bare `except:` swallowed them).
try:
    from google.colab import drive
    if not os.path.exists('/content/drive'):
        print("\n🔗 Mounting Google Drive...")
        drive.mount('/content/drive')
        print("✅ Drive mounted!")
    else:
        print("\n✅ Drive already mounted!")
except ImportError:
    # google.colab is unavailable: not running inside Colab.
    print("\n⚠️  Not in Colab or Drive mounted")
except Exception as e:
    # Inside Colab, but the mount itself failed (e.g. auth cancelled/timeout):
    # report the real cause instead of claiming we're not in Colab.
    print(f"\n⚠️  Could not mount Google Drive: {e}")

# ==================== MODEL CLASS ====================
class DeBERTaEmotionClassifier(nn.Module):
    """DeBERTa encoder followed by dropout and a linear head on the first token.

    The encoder is initialised from ``MODEL_NAME``. Submodules are registered
    as ``deberta``, ``dropout``, ``classifier`` (in that order), which is the
    layout ``classifier.pt`` checkpoints refer to.

    Args:
        num_labels: number of output classes (width of the linear head).
    """

    def __init__(self, num_labels):
        super().__init__()
        self.deberta = DebertaV2Model.from_pretrained(MODEL_NAME)
        self.dropout = nn.Dropout(0.1)
        hidden_size = self.deberta.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits (last dim = ``num_labels``)."""
        hidden_states = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Position 0 of each sequence serves as the pooled sentence representation.
        first_token = hidden_states[:, 0, :]
        return self.classifier(self.dropout(first_token))

# ==================== LOAD TEST DATA ====================
def load_test_data():
    """Load the pre-tokenized test split and dataset metadata from DATA_DIR.

    Reads ``metadata.json`` (needs ``emotion_classes`` and ``num_classes``)
    and ``tokenized_test.pt`` (a dict with ``input_ids``, ``attention_mask``
    and ``labels`` tensors). The tensors are wrapped in an unshuffled
    DataLoader with batch size VAL_BATCH_SIZE.

    Returns:
        (test_loader, emotion_classes, metadata)

    Raises:
        Any loading error, after printing it with a traceback.
    """
    print("\n" + "="*90)
    print("STEP 1: LOADING TEST DATA")
    print("="*90)

    try:
        # Load metadata
        metadata_path = os.path.join(DATA_DIR, 'metadata.json')
        print(f"\n📂 Loading metadata from: {metadata_path}")
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        emotion_classes = metadata['emotion_classes']
        num_classes = metadata['num_classes']

        print(f"✅ Metadata loaded!")
        print(f"   Emotions: {', '.join(emotion_classes)}")
        print(f"   Classes: {num_classes}")

        # Load test data. map_location='cpu': batches are moved to DEVICE later,
        # and this keeps CPU-only runtimes working even if the file was saved
        # from GPU tensors.
        test_path = os.path.join(DATA_DIR, 'tokenized_test.pt')
        print(f"\n📂 Loading test data...")
        test_data = torch.load(test_path, map_location='cpu')
        print(f"✅ Test: {test_data['input_ids'].shape[0]:,} samples")

        # Create dataset and dataloader
        test_dataset = TensorDataset(
            test_data['input_ids'],
            test_data['attention_mask'],
            test_data['labels']
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=VAL_BATCH_SIZE,
            shuffle=False,
            # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
            pin_memory=DEVICE.type == 'cuda',
            num_workers=0
        )

        print(f"✅ Test DataLoader created!")
        print(f"   Batches: {len(test_loader):,}")

        return test_loader, emotion_classes, metadata

    except Exception as e:
        print(f"\n❌ ERROR loading data: {e}")
        import traceback
        traceback.print_exc()
        raise

# ==================== FIND ALL MODELS ====================
def find_all_models():
    """Discover saved checkpoints under SAVE_DIR_EPOCHS and SAVE_DIR_BEST.

    A directory counts as a model only if it contains ``classifier.pt``.
    Epoch checkpoints must be named ``epoch_<N>`` with an ASCII decimal ``N``;
    other ``epoch_*`` directories are reported and skipped instead of aborting
    the scan. Epoch models come first, in numeric order (epoch_2 before
    epoch_10), followed by "best" models sorted by directory name.

    Returns:
        list of dicts with keys ``path``, ``name``, ``type`` ('epoch' or
        'best') and, for epoch models, ``epoch`` (int).
    """
    print("\n" + "="*90)
    print("STEP 2: SCANNING FOR MODELS")
    print("="*90)

    models = []

    # Scan epoch models; collect (epoch_int, suffix, dir) so sorting is numeric.
    prefix = 'epoch_'
    epoch_entries = []
    for epoch_dir in glob.glob(os.path.join(SAVE_DIR_EPOCHS, prefix + '*')):
        if not os.path.isdir(epoch_dir):
            continue
        epoch_num = os.path.basename(epoch_dir)[len(prefix):]
        if not (epoch_num.isascii() and epoch_num.isdecimal()):
            print(f"⚠️  Skipping unrecognized epoch directory: {epoch_dir}")
            continue
        epoch_entries.append((int(epoch_num), epoch_num, epoch_dir))

    for epoch, epoch_num, epoch_dir in sorted(epoch_entries):
        classifier_path = os.path.join(epoch_dir, 'classifier.pt')
        if os.path.exists(classifier_path):
            models.append({
                'path': epoch_dir,
                'name': f'epoch_{epoch_num}',
                'type': 'epoch',
                'epoch': epoch
            })
            print(f"✅ Found epoch model: {epoch_num}")

    # Scan best models (sorted for a deterministic order)
    for best_dir in sorted(glob.glob(os.path.join(SAVE_DIR_BEST, '*'))):
        if os.path.isdir(best_dir):
            name = os.path.basename(best_dir)
            classifier_path = os.path.join(best_dir, 'classifier.pt')
            if os.path.exists(classifier_path):
                models.append({
                    'path': best_dir,
                    'name': name,
                    'type': 'best'
                })
                print(f"✅ Found best model: {name}")

    print(f"\n📊 Total models found: {len(models)}")
    return models

# ==================== LOAD MODEL ====================
def load_model(model_info, num_classes):
    """Rebuild a DeBERTaEmotionClassifier from a checkpoint directory.

    ``model_info['path']`` is expected to hold the encoder written by
    ``save_pretrained`` plus ``classifier.pt`` (classifier/dropout state
    dicts), and optionally ``metrics.json``.

    Args:
        model_info: dict with at least ``path`` and ``name``.
        num_classes: width of the classification head.

    Returns:
        ``(model, existing_metrics)``. The model is on DEVICE in eval mode;
        existing_metrics is the parsed metrics.json or None. If anything
        fails, the error is printed and ``(None, None)`` is returned.
    """
    model_path = model_info['path']
    model_name = model_info['name']

    try:
        print(f"\n🤖 Loading model: {model_name}")

        # Initialize model
        model = DeBERTaEmotionClassifier(num_labels=num_classes)

        # Load classifier weights. weights_only=False is required because the
        # checkpoint also pickles metrics/metadata; only load checkpoints this
        # project wrote itself (unpickling untrusted files can execute code).
        classifier_path = os.path.join(model_path, 'classifier.pt')
        checkpoint = torch.load(classifier_path, map_location=DEVICE, weights_only=False)

        model.classifier.load_state_dict(checkpoint['classifier_state_dict'])
        model.dropout.load_state_dict(checkpoint['dropout_state_dict'])

        # Swap in the fine-tuned encoder. from_pretrained picks up either
        # model.safetensors or the PyTorch weights file on its own, so one call
        # covers both formats without building a temporary second model.
        model.deberta = DebertaV2Model.from_pretrained(model_path)

        model.to(DEVICE)
        model.eval()

        # Load existing metrics if available
        metrics_path = os.path.join(model_path, 'metrics.json')
        existing_metrics = None
        if os.path.exists(metrics_path):
            with open(metrics_path, 'r') as f:
                existing_metrics = json.load(f)

        print(f"✅ Model loaded: {model_name}")
        return model, existing_metrics

    except Exception as e:
        print(f"❌ ERROR loading model {model_name}: {e}")
        return None, None

# ==================== COMPREHENSIVE EVALUATION ====================
def comprehensive_evaluate(model, dataloader, emotion_classes, model_name):
    """Evaluate ``model`` on ``dataloader`` and compute a full set of metrics.

    Args:
        model: classifier called as ``model(input_ids, attention_mask)`` that
            returns logits; argmax is taken over dim 1 and softmax over dim -1.
        dataloader: iterable of batches indexed as
            ``(input_ids, attention_mask, labels)``.
        emotion_classes: ordered class names; label id ``i`` corresponds to
            ``emotion_classes[i]``.
        model_name: name used in progress output and stored in the result.

    Returns:
        dict containing the following:
        - the cross-entropy loss (no class weights) averaged over the batches
          that were evaluated successfully;
        - accuracy;
        - macro/weighted/micro F1, precision and recall;
        - per-class F1/precision/recall lists, with exactly one entry per class
          in ``emotion_classes``;
        - MCC, Cohen's kappa and macro one-vs-rest AUC-ROC (each 0.0 if it
          cannot be computed);
        - the raw confusion matrix and the row-normalised one (rows with no
          true samples are all zeros);
        - the raw predictions, labels and probabilities as plain lists.

    Raises:
        RuntimeError: if no batch could be evaluated at all.
    """
    print(f"\n🔍 Evaluating model: {model_name}")

    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    total_loss = 0.0
    num_batches = 0  # batches that completed; failing batches are skipped

    # Loss function (without class weights for fair comparison)
    criterion = nn.CrossEntropyLoss()

    progress_bar = tqdm(
        dataloader,
        desc=f"Evaluating {model_name}",
        unit="batch",
        colour="blue"
    )

    with torch.no_grad():
        for batch in progress_bar:
            try:
                input_ids = batch[0].to(DEVICE)
                attention_mask = batch[1].to(DEVICE)
                labels = batch[2].to(DEVICE)

                with torch.cuda.amp.autocast(enabled=USE_MIXED_PRECISION):
                    logits = model(input_ids, attention_mask)
                    loss = criterion(logits, labels)
                    probs = torch.softmax(logits, dim=-1)

                # Bring everything to the CPU before touching the accumulators,
                # so a batch that fails part-way cannot leave them out of sync.
                batch_loss = loss.item()
                batch_preds = torch.argmax(logits, dim=1).cpu().numpy()
                batch_labels = labels.cpu().numpy()
                batch_probs = probs.float().cpu().numpy()
            except Exception as e:
                print(f"\n⚠️  Error during evaluation: {e}")
                continue

            total_loss += batch_loss
            num_batches += 1
            all_preds.extend(batch_preds)
            all_labels.extend(batch_labels)
            all_probs.extend(batch_probs)

            # Running mean over completed batches. Deriving the count from
            # len(all_preds) // VAL_BATCH_SIZE gives 0 (ZeroDivisionError)
            # when the first batch is smaller than VAL_BATCH_SIZE.
            progress_bar.set_postfix({'loss': f'{total_loss / num_batches:.4f}'})

    if num_batches == 0:
        raise RuntimeError(f"No batches could be evaluated for model {model_name}")

    # Convert to numpy arrays
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    # Explicit label ids keep the per-class arrays and the confusion matrix at
    # exactly len(emotion_classes) entries even when a class never appears in
    # the labels or predictions (they are indexed by class position later).
    label_ids = list(range(len(emotion_classes)))

    # Basic metrics
    accuracy = accuracy_score(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1_micro = f1_score(all_labels, all_preds, average='micro', zero_division=0)
    precision_macro = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    precision_weighted = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    precision_micro = precision_score(all_labels, all_preds, average='micro', zero_division=0)
    recall_macro = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    recall_weighted = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall_micro = recall_score(all_labels, all_preds, average='micro', zero_division=0)

    # Per-class metrics (one value per class id, absent classes score 0)
    f1_per_class = f1_score(all_labels, all_preds, labels=label_ids, average=None, zero_division=0)
    precision_per_class = precision_score(all_labels, all_preds, labels=label_ids, average=None, zero_division=0)
    recall_per_class = recall_score(all_labels, all_preds, labels=label_ids, average=None, zero_division=0)

    # Additional metrics; each falls back to 0.0 independently if sklearn
    # rejects the input.
    try:
        mcc = float(matthews_corrcoef(all_labels, all_preds))
    except ValueError:
        mcc = 0.0

    try:
        kappa = float(cohen_kappa_score(all_labels, all_preds))
    except ValueError:
        kappa = 0.0

    # AUC-ROC (one-vs-rest for multiclass). roc_auc_score raises ValueError
    # e.g. when some class has no true samples in the test labels.
    try:
        auc_roc = float(roc_auc_score(all_labels, all_probs, multi_class='ovr', average='macro'))
    except ValueError:
        auc_roc = 0.0

    # Confusion matrix over the full, fixed label set
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    # Row-normalised confusion matrix; rows with no true samples stay 0
    # instead of becoming NaN (which json.dump would write as invalid JSON).
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )

    avg_loss = total_loss / num_batches

    return {
        'model_name': model_name,
        'loss': avg_loss,
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'f1_micro': f1_micro,
        'precision_macro': precision_macro,
        'precision_weighted': precision_weighted,
        'precision_micro': precision_micro,
        'recall_macro': recall_macro,
        'recall_weighted': recall_weighted,
        'recall_micro': recall_micro,
        'f1_per_class': f1_per_class.tolist(),
        'precision_per_class': precision_per_class.tolist(),
        'recall_per_class': recall_per_class.tolist(),
        'mcc': mcc,
        'kappa': kappa,
        'auc_roc': auc_roc,
        'confusion_matrix': cm.tolist(),
        'confusion_matrix_normalized': cm_normalized.tolist(),
        'predictions': all_preds.tolist(),
        'labels': all_labels.tolist(),
        'probabilities': all_probs.tolist()
    }

# ==================== SAVE RESULTS ====================
def save_results(results, emotion_classes):
    """Write per-model and cross-model evaluation results under ``RESULTS_DIR``.

    For every entry in ``results`` this writes ``<RESULTS_DIR>/<model_name>/metrics.json``
    and ``evaluation_plots.png``. It then writes ``model_comparison.csv`` and
    ``model_comparison.png``, which cover all models.

    Args:
        results: list of result dicts as returned by ``comprehensive_evaluate``.
            It is expected to be non-empty, because the summary plot sorts by
            ``f1_macro``.
        emotion_classes: class names in label-id order. Per-class lists in each
            result are indexed by position in this list.

    Returns:
        pandas.DataFrame with one row per model. Columns are the scalar metrics
        plus ``f1_<emotion>``, ``precision_<emotion>`` and ``recall_<emotion>``.
    """
    print("\n" + "="*90)
    print("STEP 5: SAVING RESULTS")
    print("="*90)

    # Scalar metrics copied into the comparison table, in column order.
    scalar_keys = (
        'model_name', 'loss', 'accuracy',
        'f1_macro', 'f1_weighted', 'f1_micro',
        'precision_macro', 'precision_weighted', 'precision_micro',
        'recall_macro', 'recall_weighted', 'recall_micro',
        'mcc', 'kappa', 'auc_roc',
    )

    # Collect plain dict rows and build the DataFrame once. DataFrame.append
    # was removed in pandas 2.0 (and re-copied the frame on every call).
    rows = []
    for result in results:
        row = {key: result[key] for key in scalar_keys}

        # Add per-class metrics
        for i, emotion in enumerate(emotion_classes):
            row[f'f1_{emotion}'] = result['f1_per_class'][i]
            row[f'precision_{emotion}'] = result['precision_per_class'][i]
            row[f'recall_{emotion}'] = result['recall_per_class'][i]

        rows.append(row)

    df_results = pd.DataFrame(rows)

    # Save the dataframe
    csv_path = os.path.join(RESULTS_DIR, 'model_comparison.csv')
    df_results.to_csv(csv_path, index=False)
    print(f"✅ Results saved to: {csv_path}")

    # Save detailed results for each model
    for result in results:
        model_dir = os.path.join(RESULTS_DIR, result['model_name'])
        os.makedirs(model_dir, exist_ok=True)

        # Save metrics
        metrics_path = os.path.join(model_dir, 'metrics.json')
        with open(metrics_path, 'w') as f:
            json.dump(result, f, indent=2)

        # Plot and save confusion matrix
        plt.figure(figsize=(12, 10))

        # Raw counts
        plt.subplot(2, 2, 1)
        sns.heatmap(result['confusion_matrix'], annot=True, fmt='d', cmap='Blues',
                   xticklabels=emotion_classes, yticklabels=emotion_classes)
        plt.title(f'Confusion Matrix - {result["model_name"]} (Counts)')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')

        # Normalized
        plt.subplot(2, 2, 2)
        sns.heatmap(result['confusion_matrix_normalized'], annot=True, fmt='.2f', cmap='Blues',
                   xticklabels=emotion_classes, yticklabels=emotion_classes)
        plt.title(f'Confusion Matrix - {result["model_name"]} (Normalized)')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')

        # Per-class metrics: three bars (F1, precision, recall) per class
        plt.subplot(2, 2, 3)
        x = np.arange(len(emotion_classes))
        width = 0.25

        plt.bar(x - width, result['f1_per_class'], width, label='F1')
        plt.bar(x, result['precision_per_class'], width, label='Precision')
        plt.bar(x + width, result['recall_per_class'], width, label='Recall')

        plt.xticks(x, emotion_classes, rotation=45)
        plt.title(f'Per-Class Metrics - {result["model_name"]}')
        plt.ylabel('Score')
        plt.legend()

        # Overall metrics
        plt.subplot(2, 2, 4)
        metrics = ['accuracy', 'f1_macro', 'precision_macro', 'recall_macro', 'mcc', 'kappa', 'auc_roc']
        values = [result[m] for m in metrics]

        bars = plt.bar(metrics, values)
        plt.title(f'Overall Metrics - {result["model_name"]}')
        plt.ylabel('Score')
        plt.xticks(rotation=45)

        # Add value labels on bars
        for bar, value in zip(bars, values):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    f'{value:.3f}', ha='center', va='bottom')

        plt.tight_layout()
        plots_path = os.path.join(model_dir, 'evaluation_plots.png')
        plt.savefig(plots_path, dpi=150, bbox_inches='tight')
        plt.close()

        print(f"✅ Detailed results saved for: {result['model_name']}")

    # Create summary comparison plot
    plt.figure(figsize=(15, 10))

    # Sort models by F1 macro for better visualization
    df_sorted = df_results.sort_values('f1_macro', ascending=False)

    # Plot 1: Overall metrics comparison (grouped bars, one group per model)
    plt.subplot(2, 2, 1)
    metrics = ['accuracy', 'f1_macro', 'precision_macro', 'recall_macro']
    x = np.arange(len(df_sorted))
    width = 0.2

    for i, metric in enumerate(metrics):
        plt.bar(x + i*width, df_sorted[metric], width, label=metric)

    # Centre the tick under each group of four bars
    plt.xticks(x + width*1.5, df_sorted['model_name'], rotation=45, ha='right')
    plt.title('Overall Metrics Comparison')
    plt.ylabel('Score')
    plt.legend()

    # Plot 2: F1 scores per emotion
    plt.subplot(2, 2, 2)
    for emotion in emotion_classes:
        plt.plot(df_sorted['model_name'], df_sorted[f'f1_{emotion}'], marker='o', label=emotion)

    plt.xticks(rotation=45, ha='right')
    plt.title('F1 Scores per Emotion')
    plt.ylabel('F1 Score')
    plt.legend()

    # Plot 3: Additional metrics
    plt.subplot(2, 2, 3)
    additional_metrics = ['mcc', 'kappa', 'auc_roc']
    x = np.arange(len(df_sorted))
    width = 0.25

    for i, metric in enumerate(additional_metrics):
        plt.bar(x + i*width, df_sorted[metric], width, label=metric)

    plt.xticks(x + width, df_sorted['model_name'], rotation=45, ha='right')
    plt.title('Additional Metrics Comparison')
    plt.ylabel('Score')
    plt.legend()

    # Plot 4: Best model per emotion
    plt.subplot(2, 2, 4)
    best_models = []
    best_scores = []

    for emotion in emotion_classes:
        # idxmax returns an index label, so look it up with .loc
        best_idx = df_sorted[f'f1_{emotion}'].idxmax()
        best_models.append(df_sorted.loc[best_idx, 'model_name'])
        best_scores.append(df_sorted.loc[best_idx, f'f1_{emotion}'])

    y_pos = np.arange(len(emotion_classes))
    bars = plt.barh(y_pos, best_scores)

    plt.yticks(y_pos, emotion_classes)
    plt.xlabel('F1 Score')
    plt.title('Best Model per Emotion')

    # Add model names as labels
    for bar, model_label in zip(bars, best_models):
        plt.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height()/2,
                model_label, ha='left', va='center', fontsize=8)

    plt.tight_layout()
    summary_path = os.path.join(RESULTS_DIR, 'model_comparison.png')
    plt.savefig(summary_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"✅ Summary comparison plot saved to: {summary_path}")

    return df_results

# ==================== MAIN ====================
def main():
    """Evaluate every saved checkpoint on the test set and save a comparison.

    Steps:
    1. Load the test DataLoader and class names via ``load_test_data``.
    2. Discover checkpoints via ``find_all_models``.
    3. Evaluate each checkpoint with ``comprehensive_evaluate``.
    4. Pass the collected results to ``save_results`` and print a short summary.

    A checkpoint that fails to load or to evaluate is reported and skipped,
    so one bad checkpoint does not discard the results of the others.
    """
    print("\n" + "="*90)
    print("🚀 STARTING MODEL EVALUATION")
    print("="*90)

    start_time = time.time()

    # Load test data (the metadata dict is not needed here)
    test_loader, emotion_classes, _metadata = load_test_data()
    num_classes = len(emotion_classes)

    # Find all models
    models = find_all_models()

    if not models:
        print("\n❌ No models found! Please check the directories.")
        return

    # Evaluate each model
    results = []

    print("\n" + "="*90)
    print("STEP 3: EVALUATING MODELS")
    print("="*90)

    for i, model_info in enumerate(models):
        print(f"\n[{i+1}/{len(models)}] Processing model: {model_info['name']}")

        # Load model; load_model returns (None, None) on failure
        model, existing_metrics = load_model(model_info, num_classes)
        if model is None:
            continue

        try:
            result = comprehensive_evaluate(model, test_loader, emotion_classes, model_info['name'])
        except Exception as e:
            # Same best-effort policy as for load failures: report and move on
            # instead of aborting the run and losing earlier results.
            print(f"❌ ERROR evaluating model {model_info['name']}: {e}")
            continue
        finally:
            # Free this checkpoint's memory before loading the next one,
            # whether or not evaluation succeeded.
            del model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Add existing metrics if available
        if existing_metrics:
            result['existing_metrics'] = existing_metrics

        results.append(result)

    # Save results
    if results:
        df_results = save_results(results, emotion_classes)

        # Print summary
        print("\n" + "="*90)
        print("📊 EVALUATION SUMMARY")
        print("="*90)

        # Sort by F1 macro
        df_sorted = df_results.sort_values('f1_macro', ascending=False)

        print("\nTop 5 Models by F1 Macro Score:")
        print(df_sorted[['model_name', 'f1_macro', 'accuracy', 'f1_weighted']].head(5).to_string(index=False))

        print("\nBest Model per Emotion:")
        for emotion in emotion_classes:
            best_idx = df_sorted[f'f1_{emotion}'].idxmax()
            best_model = df_sorted.loc[best_idx, 'model_name']
            best_score = df_sorted.loc[best_idx, f'f1_{emotion}']
            print(f"  {emotion:12s}: {best_model:20s} ({best_score:.4f})")

        total_time = time.time() - start_time
        print(f"\n⏱️  Total evaluation time: {total_time/60:.2f} minutes")
        print(f"📁 Results saved to: {RESULTS_DIR}")
    else:
        print("\n❌ No models were successfully evaluated!")

    print("\n" + "="*90)
    print("🎉 EVALUATION COMPLETE!")
    print("="*90)

if __name__ == "__main__":
    # Script / notebook-cell entry point. Interrupts get a short notice; any
    # other uncaught error is reported together with its full stack trace.
    import traceback

    try:
        main()
    except KeyboardInterrupt:
        print("\n\n⚠️  Evaluation interrupted by user!")
    except Exception as err:
        print(f"\n\n❌ CRITICAL ERROR: {err}")
        traceback.print_exc()

                         🔍 DeBERTa v3 Model Evaluation 🔍
                              Testing All Saved Models

🖥️  Device: cuda
   GPU: Tesla T4
   Memory: 14.74 GB

✅ Drive already mounted!

🚀 STARTING MODEL EVALUATION

STEP 1: LOADING TEST DATA

📂 Loading metadata from: /content/drive/MyDrive/SuperEmotion/metadata.json
✅ Metadata loaded!
   Emotions: anger, fear, joy, love, neutral, sadness, surprise
   Classes: 7

📂 Loading test data...
✅ Test: 19,072 samples
✅ Test DataLoader created!
   Batches: 149

STEP 2: SCANNING FOR MODELS
✅ Found epoch model: 1
✅ Found epoch model: 2
✅ Found epoch model: 3
✅ Found epoch model: 4
✅ Found epoch model: 5
✅ Found epoch model: 6
✅ Found best model: best_model

📊 Total models found: 7

STEP 3: EVALUATING MODELS

[1/7] Processing model: epoch_1

🤖 Loading model: epoch_1
✅ Model loaded: epoch_1

🔍 Evaluating model: epoch_1


Evaluating epoch_1:   0%|          | 0/149 [00:00<?, ?batch/s]


[2/7] Processing model: epoch_2

🤖 Loading model: epoch_2
✅ Model loaded: epoch_2

🔍 Evaluating model: epoch_2


Evaluating epoch_2:   0%|          | 0/149 [00:00<?, ?batch/s]


[3/7] Processing model: epoch_3

🤖 Loading model: epoch_3
✅ Model loaded: epoch_3

🔍 Evaluating model: epoch_3


Evaluating epoch_3:   0%|          | 0/149 [00:00<?, ?batch/s]


[4/7] Processing model: epoch_4

🤖 Loading model: epoch_4
✅ Model loaded: epoch_4

🔍 Evaluating model: epoch_4


Evaluating epoch_4:   0%|          | 0/149 [00:00<?, ?batch/s]


[5/7] Processing model: epoch_5

🤖 Loading model: epoch_5
✅ Model loaded: epoch_5

🔍 Evaluating model: epoch_5


Evaluating epoch_5:   0%|          | 0/149 [00:00<?, ?batch/s]


[6/7] Processing model: epoch_6

🤖 Loading model: epoch_6
✅ Model loaded: epoch_6

🔍 Evaluating model: epoch_6


Evaluating epoch_6:   0%|          | 0/149 [00:00<?, ?batch/s]


[7/7] Processing model: best_model

🤖 Loading model: best_model
✅ Model loaded: best_model

🔍 Evaluating model: best_model


Evaluating best_model:   0%|          | 0/149 [00:00<?, ?batch/s]


STEP 5: SAVING RESULTS


❌ CRITICAL ERROR: 'DataFrame' object has no attribute 'append'


Traceback (most recent call last):
  File "/tmp/ipython-input-2534668249.py", line 602, in <cell line: 0>
    main()
  File "/tmp/ipython-input-2534668249.py", line 570, in main
    df_results = save_results(results, emotion_classes)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-2534668249.py", line 374, in save_results
    df_results = df_results.append(row, ignore_index=True)
                 ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/pandas/core/generic.py", line 6299, in __getattr__
    return object.__getattribute__(self, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'DataFrame' object has no attribute 'append'. Did you mean: '_append'?
