In [None]:
#!/usr/bin/env python3
"""
🤖 DeiT (Data-efficient Image Transformers) Cross-Validation Training for Dog Emotion Recognition
======================================================================

Complete pipeline for training DeiT (Data-efficient Image Transformers) on dog emotion dataset with:
- Automatic dataset download and preparation
- 5-fold stratified cross-validation
- 50 epochs training per fold
- Comprehensive visualization and evaluation
- Model saving and download

Author: Dog Emotion Recognition Team
Date: 2024
"""

import copy
import os
import shutil
import subprocess
import sys
import time
import warnings
import zipfile
from datetime import datetime
from pathlib import Path

# Suppress warnings
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

print("🚀 Starting DeiT (Data-efficient Image Transformers) Cross-Validation Training Pipeline")
print("=" * 60)

# =====================================
# 1. PACKAGE INSTALLATION
# =====================================

print("\n📦 Installing required packages...")
packages = [
    'torch>=1.9.0',
    'torchvision>=0.10.0', 
    'scikit-learn>=1.0.0',
    'matplotlib>=3.3.0',
    'seaborn>=0.11.0',
    'gdown>=4.0.0',
    'Pillow>=8.0.0',
    'numpy>=1.21.0',
    'pandas>=1.3.0',
    'tqdm>=4.60.0'
]

for package in packages:
    # Invoke pip for the *running* interpreter with an argument list (no shell).
    # A shell command line such as `pip install torch>=1.9.0` treats `>` as an
    # output redirect, silently dropping the version pin and creating a stray
    # file; os.system also never raises, so failures looked like successes.
    result = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', package, '--quiet'],
        check=False,
    )
    if result.returncode == 0:
        print(f"✅ {package}")
    else:
        print(f"❌ Failed to install {package}: pip exited with code {result.returncode}")

print("📦 Package installation completed!")

# =====================================
# 1.5. CLONE REPOSITORY & IMPORT MODULES
# =====================================

print("\n📥 Cloning repository and importing custom modules...")

# Clone the repository from GitHub
REPO_URL = "https://github.com/hoangh-e/dog-emotion-recognition-hybrid.git"
REPO_DIR = "dog-emotion-recognition-hybrid"

# A re-run in the same process (e.g. re-executing the notebook cell) is already
# inside REPO_DIR; without this guard it would clone a nested copy and chdir
# into that instead.
if Path.cwd().name != REPO_DIR:
    if not os.path.exists(REPO_DIR):
        print("📥 Cloning repository from GitHub...")
        # Argument list (no shell) + check=True: a failed clone stops here with
        # CalledProcessError rather than surfacing later as an os.chdir error.
        subprocess.run(["git", "clone", REPO_URL], check=True)
    # Change to the repository directory
    os.chdir(REPO_DIR)

# Make the repository's packages importable, without duplicating the entry
repo_path = os.getcwd()
if repo_path not in sys.path:
    sys.path.insert(0, repo_path)

# Import modules from the custom package
print("📦 Importing custom modules...")
from dog_emotion_classification.deit import (
    load_deit_model,
    predict_emotion_deit,
    get_deit_transforms,
    create_deit_model
)

# Import utility functions for 3-class conversion
from dog_emotion_classification.utils import (
    convert_dataframe_4class_to_3class,
    get_3class_emotion_classes,
    EMOTION_CLASSES_3CLASS
)
from dog_emotion_classification import EMOTION_CLASSES as PACKAGE_EMOTION_CLASSES

print("✅ Imported 3-class utility functions")
print(f"📊 Target emotion classes: {EMOTION_CLASSES_3CLASS}")
print(f"📦 Package emotion classes: {PACKAGE_EMOTION_CLASSES}")

# =====================================
# 2. IMPORTS
# =====================================

print("\n📚 Importing libraries...")

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import torchvision.transforms as transforms
import torchvision.models as models

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from PIL import Image
import gdown
from tqdm import tqdm
import json
import random

# Select the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f"🔧 Using device: {device}")

if _has_cuda:
    # Report the name and total memory (in GiB) of GPU 0
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {_gpu_name}")
    print(f"GPU Memory: {_gpu_mem_gb:.1f} GB")

# =====================================
# 3. DATASET DOWNLOAD
# =====================================

print("\n💾 Downloading dataset...")

# Google Drive dataset ID
dataset_id = "1ZAgz5u64i3LDbwMFpBXjzsKt6FrhNGdW"
dataset_zip = "cropped_dataset_4k_face.zip"
# Folder the archive is expected to extract into; its presence means "already downloaded"
dataset_dir = "cropped_dataset_4k_face"

if not os.path.exists(dataset_dir):
    print("📥 Downloading dataset from Google Drive...")
    try:
        downloaded = gdown.download(f'https://drive.google.com/uc?id={dataset_id}', dataset_zip, quiet=False)
        # Older gdown releases may signal failure by returning None instead of raising.
        if downloaded is None:
            raise RuntimeError("gdown did not return a downloaded file")

        print("📂 Extracting dataset...")
        with zipfile.ZipFile(dataset_zip, 'r') as zip_ref:
            zip_ref.extractall('.')

        os.remove(dataset_zip)
        print("✅ Dataset downloaded and extracted successfully!")

    except Exception as e:
        print(f"❌ Error downloading dataset: {e}")
        print("Please check your internet connection and try again.")
        # Remove partial artifacts: a half-extracted folder would otherwise make
        # the existence check above skip the download on the next run.
        shutil.rmtree(dataset_dir, ignore_errors=True)
        if os.path.exists(dataset_zip):
            os.remove(dataset_zip)
        sys.exit(1)
else:
    print("✅ Dataset already exists!")

# =====================================
# 4. DATASET CLASS
# =====================================

class DogEmotionDataset(Dataset):
    """Dataset of dog images labelled with an emotion class.

    Images are read from ``<root>/<label>/<filename>`` using the ``filename``
    and ``label`` columns of ``labels_csv``. Class indices are assigned by
    sorting the distinct label strings, so two instances built from CSVs with
    the same label set share the same mapping.

    Args:
        root: Directory containing one sub-folder per label.
        labels_csv: Path to a CSV with at least ``filename`` and ``label`` columns.
        transform: Optional callable applied to each RGB ``PIL.Image``; the
            attribute may be reassigned after construction.
    """

    def __init__(self, root, labels_csv, transform=None):
        self.root = root
        df = pd.read_csv(labels_csv)
        # (filename, label) pairs, one row per sample
        self.items = df[['filename', 'label']].values
        unique_labels = sorted(df['label'].unique())
        self.label2index = {name: i for i, name in enumerate(unique_labels)}
        self.index2label = {i: name for name, i in self.label2index.items()}
        self.transform = transform
        # Paths already reported as unreadable, to avoid repeating the same
        # warning within one process.
        self._unreadable_paths = set()
        print(f"📊 Dataset: {len(self.items)} samples")
        print(f"🏷️  Classes: {list(self.label2index.keys())}")

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for sample ``idx``.

        A file that cannot be opened/decoded is replaced by a black 224x224
        image (reported once) so one bad file does not abort training.
        """
        fn, label_str = self.items[idx]
        label_idx = self.label2index[label_str]
        img_path = os.path.join(self.root, label_str, fn)

        try:
            # The context manager closes the file handle; convert() loads the
            # pixels into a new image that remains valid after closing.
            with Image.open(img_path) as src:
                img = src.convert('RGB')
        except Exception as exc:
            # Only loading is guarded: errors raised by the transform itself are
            # genuine bugs and must not be masked by the placeholder fallback.
            if img_path not in self._unreadable_paths:
                self._unreadable_paths.add(img_path)
                print(f"⚠️ Could not load {img_path} ({exc}); using a black placeholder image")
            img = Image.new('RGB', (224, 224), (0, 0, 0))

        if self.transform:
            img = self.transform(img)
        return img, label_idx

# =====================================
# 5. DATA PREPARATION
# =====================================

print("\n🔍 Preparing dataset...")

# Extracted image folders and the original (4-class) label file
data_root = os.path.join("cropped_dataset_4k_face", "Dog Emotion")
labels_csv = os.path.join(data_root, "labels.csv")

print("\n📂 Dataset structure:")
# Every sub-directory of data_root is treated as one emotion-class folder
emotions = [
    entry
    for entry in os.listdir(data_root)
    if os.path.isdir(os.path.join(data_root, entry))
]
print(f"   Emotion classes: {emotions}")

_image_exts = ('.jpg', '.jpeg', '.png')
for emotion in emotions:
    emotion_path = os.path.join(data_root, emotion)
    if not os.path.isdir(emotion_path):
        continue
    count = sum(1 for name in os.listdir(emotion_path) if name.lower().endswith(_image_exts))
    print(f"     {emotion}: {count} images")

csv_status = '✅' if os.path.exists(labels_csv) else '❌'
print(f"   Labels CSV: {csv_status} {labels_csv}")

# Load the unfiltered dataset once, only to report its class set and size
print("\n📊 Loading original dataset...")
original_dataset = DogEmotionDataset(data_root, labels_csv, None)
original_classes = list(original_dataset.label2index)
print(f"   Original classes: {original_classes}")
print(f"   Original samples: {len(original_dataset)}")

print("\n🔧 Converting to 3-class configuration...")

labels_df = pd.read_csv(labels_csv)
print(f"   Original DataFrame: {len(labels_df)} samples")

# Project helper that maps the 'label' column to the 3-class setup
# (per the original notes this drops the 'sad' samples)
filtered_df = convert_dataframe_4class_to_3class(labels_df, 'label')

# Write the filtered labels to disk so DogEmotionDataset can read them
filtered_labels_csv = os.path.join(data_root, "labels_3class.csv")
filtered_df.to_csv(filtered_labels_csv, index=False)
print(f"   Saved filtered labels to: {filtered_labels_csv}")

# 3-class dataset, created without a transform (one is assigned later)
dataset = DogEmotionDataset(data_root, filtered_labels_csv, None)
EMOTION_CLASSES = list(dataset.label2index.keys())
NUM_CLASSES = len(EMOTION_CLASSES)

print("\n✅ Dataset ready:")
print(f"   Total samples: {len(dataset)}")
print(f"   Number of classes: {NUM_CLASSES}")
print(f"   Emotion classes: {EMOTION_CLASSES}")

# =====================================
# 6. TRAINING FUNCTIONS
# =====================================

def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Network to train; it is switched to training mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function. The epoch loss assumes it returns the batch
            *mean* (true for the default ``nn.CrossEntropyLoss`` this script uses).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample mean loss and accuracy in percent.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc="Training")
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its size so a smaller final batch
        # does not skew the epoch average.
        n_batch = labels.size(0)
        batch_loss = loss.item()
        loss_sum += batch_loss * n_batch
        predicted = outputs.detach().argmax(dim=1)
        total += n_batch
        correct += (predicted == labels).sum().item()

        # Update progress bar
        pbar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })

    if total == 0:
        raise ValueError("train_epoch received a dataloader with no samples")

    epoch_loss = loss_sum / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

def evaluate_model(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function. The epoch loss assumes it returns the batch
            *mean* (true for the default ``nn.CrossEntropyLoss`` this script uses).
        device: Device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc, all_predicted, all_labels)``: per-sample mean
        loss, accuracy in percent, and the predicted / true class indices in
        the order the batches were produced.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_predicted = []
    all_labels = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Evaluating")
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Size-weighted accumulation so the last (possibly smaller) batch
            # counts proportionally in the epoch loss.
            n_batch = labels.size(0)
            batch_loss = loss.item()
            loss_sum += batch_loss * n_batch
            predicted = outputs.argmax(dim=1)
            total += n_batch
            correct += (predicted == labels).sum().item()

            all_predicted.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

    if total == 0:
        raise ValueError("evaluate_model received a dataloader with no samples")

    epoch_loss = loss_sum / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc, all_predicted, all_labels

# =====================================
# 7. CROSS-VALIDATION TRAINING
# =====================================

print("\n🎯 Starting 5-Fold Cross-Validation Training...")

# Training parameters
n_folds = 5
epochs = 50
batch_size = 16
learning_rate = 1e-4
input_size = 224

# Data transforms
train_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Two shallow copies of the dataset share the same sample list, so fold indices
# line up, but each carries its own transform. Both DataLoaders used to wrap the
# single `dataset` object, and the later `dataset.transform = val_transform`
# overwrote the training transform, so training ran without augmentation.
train_dataset = copy.copy(dataset)
train_dataset.transform = train_transform
val_dataset = copy.copy(dataset)
val_dataset.transform = val_transform

# Integer class index per sample, used to stratify the folds
labels = np.array([dataset.label2index[label_str] for _, label_str in dataset.items])

# Initialize cross-validation
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

# Storage for results
fold_results = []
all_train_losses = []
all_val_losses = []
all_train_accs = []
all_val_accs = []

# Training loop
for fold, (train_idx, val_idx) in enumerate(skf.split(np.arange(len(dataset)), labels)):
    print(f"\n{'='*20} FOLD {fold+1}/{n_folds} {'='*20}")

    print(f"Train samples: {len(train_idx)}")
    print(f"Validation samples: {len(val_idx)}")

    # Samplers restrict each loader to this fold's indices
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler, num_workers=2)

    # Create model using custom function
    model = create_deit_model(num_classes=NUM_CLASSES)
    model = model.to(device)

    # Loss and optimizer; StepLR divides the LR by 10 every 15 epochs
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

    # Training tracking
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    # -inf guarantees the first epoch writes a checkpoint, so the download step
    # always finds deit_fold_<k>_best.pth even if a fold never exceeds 0%.
    best_val_acc = float('-inf')

    # Training epochs
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        print("-" * 30)

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, _, _ = evaluate_model(model, val_loader, criterion, device)

        scheduler.step()

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Save best model (by validation accuracy)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), f'deit_fold_{fold+1}_best.pth')
            print(f"💾 New best model saved! Accuracy: {best_val_acc:.2f}%")

    # Store fold results
    fold_results.append({
        'fold': fold + 1,
        'best_val_acc': best_val_acc,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs
    })

    all_train_losses.append(train_losses)
    all_val_losses.append(val_losses)
    all_train_accs.append(train_accs)
    all_val_accs.append(val_accs)

    print(f"\n✅ Fold {fold+1} completed! Best validation accuracy: {best_val_acc:.2f}%")

    # Release this fold's model/optimizer before the next fold builds new ones,
    # so two models never coexist on the GPU.
    del model, optimizer, scheduler
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# =====================================
# 8. RESULTS ANALYSIS
# =====================================

print("\n📊 Training Results Analysis")
print("=" * 50)

# Best validation accuracy reached in each fold, in fold order
fold_accuracies = [record['best_val_acc'] for record in fold_results]
mean_acc = np.mean(fold_accuracies)
std_acc = np.std(fold_accuracies)  # NumPy default ddof=0 (population std)

print("Cross-Validation Results:")
print(f"Mean Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%")
print(f"Best Fold: {max(fold_accuracies):.2f}%")
print(f"Worst Fold: {min(fold_accuracies):.2f}%")

print("\nFold-by-fold results:")
for fold_no, acc in enumerate(fold_accuracies, start=1):
    print(f"Fold {fold_no}: {acc:.2f}%")

# =====================================
# 9. VISUALIZATION
# =====================================

print("\n📈 Creating visualizations...")

# Set up plotting style
plt.style.use('default')
sns.set_palette("husl")

# 2x2 grid: loss curves, accuracy curves, per-fold bars, text summary
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('DeiT (Data-efficient Image Transformers) Cross-Validation Training Results', fontsize=16, fontweight='bold')

# 1. Training and Validation Loss (solid = train, dashed = validation)
ax1 = axes[0, 0]
for fold in range(n_folds):
    # x-axis derived from the recorded history so it always matches its length
    epochs_range = range(1, len(all_train_losses[fold]) + 1)
    ax1.plot(epochs_range, all_train_losses[fold], alpha=0.7, label=f'Fold {fold+1} Train')
    ax1.plot(epochs_range, all_val_losses[fold], alpha=0.7, linestyle='--', label=f'Fold {fold+1} Val')

ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.grid(True, alpha=0.3)

# 2. Training and Validation Accuracy (solid = train, dashed = validation)
ax2 = axes[0, 1]
for fold in range(n_folds):
    epochs_range = range(1, len(all_train_accs[fold]) + 1)
    ax2.plot(epochs_range, all_train_accs[fold], alpha=0.7, label=f'Fold {fold+1} Train')
    ax2.plot(epochs_range, all_val_accs[fold], alpha=0.7, linestyle='--', label=f'Fold {fold+1} Val')

ax2.set_title('Training and Validation Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.grid(True, alpha=0.3)

# 3. Cross-Validation Accuracy Distribution
ax3 = axes[1, 0]
ax3.bar(range(1, n_folds + 1), fold_accuracies, alpha=0.7, color='skyblue', edgecolor='navy')
ax3.axhline(y=mean_acc, color='red', linestyle='--', label=f'Mean: {mean_acc:.2f}%')
ax3.set_title('Cross-Validation Accuracy by Fold')
ax3.set_xlabel('Fold')
ax3.set_ylabel('Accuracy (%)')
ax3.set_xticks(range(1, n_folds + 1))
ax3.legend()
ax3.grid(True, alpha=0.3)

# Add accuracy values just above each bar
for fold_no, acc in enumerate(fold_accuracies, start=1):
    ax3.text(fold_no, acc + 0.5, f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold')

# 4. Model Performance Summary
ax4 = axes[1, 1]
ax4.axis('off')

# The underline is sized from the title itself (the old code referenced an
# undefined `display_name` and crashed with NameError). The timestamp is when
# the figure was built, not a duration, so it is labelled accordingly.
summary_title = "DEIT (DATA-EFFICIENT IMAGE TRANSFORMERS) TRAINING SUMMARY"
summary_text = f"""
{summary_title}
{'=' * len(summary_title)}

Dataset: Dog Emotion Recognition (3-Class)
Architecture: DeiT (Data-efficient Image Transformers)
Input Size: {input_size}×{input_size}
Number of Classes: {NUM_CLASSES}

Training Configuration:
• Folds: {n_folds}
• Epochs per fold: {epochs}
• Batch size: {batch_size}
• Learning rate: {learning_rate}
• Optimizer: Adam

Results:
• Mean CV Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%
• Best Fold Accuracy: {max(fold_accuracies):.2f}%
• Completed at: {datetime.now().strftime('%H:%M:%S')}

Classes: {', '.join(EMOTION_CLASSES)}
"""

ax4.text(0.05, 0.95, summary_text, transform=ax4.transAxes, fontsize=11,
         verticalalignment='top', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

plt.tight_layout()
plt.show()

# =====================================
# 10. SAVE RESULTS
# =====================================

print("\n💾 Saving results...")

# Hyperparameters of this run, recorded alongside the metrics
training_config = {
    'n_folds': n_folds,
    'epochs': epochs,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'input_size': input_size,
}

# Per-fold histories, summary statistics, class names and run configuration
results_data = dict(
    fold_results=fold_results,
    mean_accuracy=mean_acc,
    std_accuracy=std_acc,
    emotion_classes=EMOTION_CLASSES,
    training_config=training_config,
)

results_path = 'deit_training_results.json'
with open(results_path, 'w') as fh:
    json.dump(results_data, fh, indent=2)

print(f"✅ Results saved to {results_path}")

# =====================================
# 11. MODEL DOWNLOAD
# =====================================

print("\n📥 Preparing models for download...")

# Fold with the highest best-validation accuracy (first such fold on ties)
best_acc = max(fold_accuracies)
best_fold = fold_accuracies.index(best_acc) + 1
best_model_file = f'deit_fold_{best_fold}_best.pth'

print(f"\n🏆 Best model: {best_model_file} (Accuracy: {best_acc:.2f}%)")

# Trigger browser downloads only when the Colab helper module is importable
try:
    from google.colab import files
    print(f"\n📥 Downloading best model: {best_model_file}")
    for artifact in (best_model_file, 'deit_training_results.json'):
        files.download(artifact)
    print("✅ Files downloaded successfully!")
except ImportError:
    print("⚠️ Not running in Colab - files saved locally")

separator = "=" * 60
print("\n🎉 DeiT (Data-efficient Image Transformers) Cross-Validation Training Completed!")
print(separator)
print("✅ Final Results:")
print(f"   Mean Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%")
print(f"   Best Model: {best_model_file} ({best_acc:.2f}%)")
print(f"   Classes: {EMOTION_CLASSES}")
print(separator)
