In [None]:
#!/usr/bin/env python3
"""
DeiT Cross-Validation Training Pipeline
Huấn luyện DeiT cho nhận diện cảm xúc chó với 5-fold cross-validation
"""

import os
import sys
import time
import warnings
warnings.filterwarnings('ignore')

# Wall-clock reference point for the elapsed-time report printed at the end.
PIPELINE_START_TIME = time.time()

print("🚀 Bắt đầu pipeline huấn luyện DeiT Cross-Validation...")
print("=" * 60)

# ==========================================
# STEP 1: CLONE REPOSITORY
# ==========================================
import subprocess  # only needed for the clone step

print("🔄 Clone repository từ GitHub...")
REPO_URL = "https://github.com/hoangh-e/dog-emotion-recognition-hybrid.git"
REPO_NAME = "dog-emotion-recognition-hybrid"

if os.path.basename(os.getcwd()) == REPO_NAME:
    # Re-running this cell: an earlier run already chdir'ed into the clone,
    # so looking for ./REPO_NAME here would start a nested clone.
    print("✅ Repository already exists")
else:
    if not os.path.exists(REPO_NAME):
        # Argument list (no shell) plus an explicit exit-code check: os.system
        # ignored failures, which then surfaced as a confusing os.chdir error.
        clone = subprocess.run(["git", "clone", REPO_URL], check=False)
        if clone.returncode != 0:
            raise RuntimeError(f"git clone {REPO_URL} failed (exit code {clone.returncode})")
        print("✅ Repository cloned successfully")
    else:
        print("✅ Repository already exists")
    os.chdir(REPO_NAME)

# Make the repository's packages (e.g. dog_emotion_classification) importable.
sys.path.insert(0, os.getcwd())

# ==========================================
# BƯỚC 2: CÀI ĐẶT PACKAGES
# ==========================================
print("📦 Cài đặt các packages cần thiết...")
packages_to_install = [
    "torch>=1.9.0",
    "torchvision>=0.10.0",
    "timm>=0.6.0",
    "gdown",
    "scikit-learn",
    "matplotlib",
    "seaborn",
    "Pillow",
    "numpy",
    "pandas",
    "tqdm"
]

import importlib
import re
import subprocess

# pip distribution name -> importable module name. Looking modules up by exact
# name replaces the old prefix test, where "torchvision>=0.10.0" matched
# startswith("torch") first, so torchvision itself was never checked.
PIP_TO_MODULE = {
    "torch": "torch",
    "torchvision": "torchvision",
    "timm": "timm",
    "gdown": "gdown",
    "scikit-learn": "sklearn",
    "matplotlib": "matplotlib",
    "seaborn": "seaborn",
    "Pillow": "PIL",
    "numpy": "numpy",
    "pandas": "pandas",
    "tqdm": "tqdm",
}

for package in packages_to_install:
    # Strip the version specifier: "torch>=1.9.0" -> "torch".
    dist_name = re.split(r"[<>=!~;\[\s]", package, maxsplit=1)[0]
    module_name = PIP_TO_MODULE.get(dist_name, dist_name)
    try:
        importlib.import_module(module_name)
        print(f"✅ {package} đã có sẵn")
    except ImportError:
        print(f"⬇️ Đang cài đặt {package}...")
        # One argv element per argument: passed through a shell (os.system),
        # "torch>=1.9.0" is parsed as a redirect into a file named "=1.9.0"
        # and the version constraint is silently dropped.
        install = subprocess.run([sys.executable, "-m", "pip", "install", package], check=False)
        if install.returncode != 0:
            print(f"❌ pip install {package} failed (exit code {install.returncode})")
        # Let the import system see packages installed during this session.
        importlib.invalidate_caches()

print("\n" + "=" * 60)

# ==========================================
# BƯỚC 3: IMPORT LIBRARIES
# ==========================================
print("📚 Import thư viện...")

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import timm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from PIL import Image
import gdown
import zipfile
import json
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Import DeiT from our custom module
from dog_emotion_classification.deit import (
    load_deit_model,
    predict_emotion_deit,
    get_deit_transforms,
    create_simple_deit_model
)

# Select the compute device: the default CUDA GPU when one is present, else CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f"🖥️ Sử dụng device: {device}")

if _has_cuda:
    # Basic facts about GPU 0: name, CUDA build version, total memory (GiB).
    gpu_name = torch.cuda.get_device_name(0)
    print(f"   GPU: {gpu_name}")
    print(f"   CUDA Version: {torch.version.cuda}")
    total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"   Memory: {total_gib:.1f} GB")

print("\n" + "=" * 60)

# ==========================================
# BƯỚC 4: TẢI DATASET
# ==========================================
print("📥 Tải dataset từ Google Drive...")

# Google Drive file ID and local archive name
dataset_id = "1ZAgz5u64i3LDbwMFpBXjzsKt6FrhNGdW"
dataset_filename = "dog_emotion_dataset.zip"

# Download the archive once; later runs reuse it.
if not os.path.exists(dataset_filename):
    print(f"⬇️ Đang tải {dataset_filename}...")
    gdown.download(f"https://drive.google.com/uc?id={dataset_id}", dataset_filename, quiet=False)
    if not os.path.exists(dataset_filename):
        raise RuntimeError(f"Download of {dataset_filename} from Google Drive failed")
    print("✅ Tải dataset thành công!")
else:
    print("✅ Dataset đã tồn tại!")

# Extract the archive unless a previous run already did
print("📂 Giải nén dataset...")
if not os.path.exists("dataset"):
    with zipfile.ZipFile(dataset_filename, 'r') as zip_ref:
        zip_ref.extractall(".")
    print("✅ Giải nén thành công!")
else:
    print("✅ Dataset đã được giải nén!")

# Inspect the dataset layout: one sub-directory per class.
dataset_path = "dataset"
if not os.path.isdir(dataset_path):
    # Every later step needs `classes`; stop with a clear error here instead
    # of a NameError further down.
    print("❌ Không tìm thấy thư mục dataset!")
    raise FileNotFoundError(f"Dataset folder '{dataset_path}' not found after extraction")

# Only directories are classes: stray files such as ".DS_Store" would
# otherwise be listed as classes and inflate num_classes (= len(classes)).
classes = sorted(
    entry for entry in os.listdir(dataset_path)
    if os.path.isdir(os.path.join(dataset_path, entry))
)
print(f"📊 Tìm thấy {len(classes)} classes: {classes}")

total_images = 0
for class_name in classes:
    class_path = os.path.join(dataset_path, class_name)
    class_images = len([f for f in os.listdir(class_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
    total_images += class_images
    print(f"   {class_name}: {class_images} ảnh")

print(f"📈 Tổng số ảnh: {total_images}")

print("\n" + "=" * 60)

# ==========================================
# STEP 5: DATA PREPARATION
# ==========================================
print("🔧 Chuẩn bị dữ liệu cho cross-validation...")

class DogEmotionDataset(Dataset):
    """Map-style dataset over parallel lists of image paths and labels.

    Each item is opened with PIL, converted to RGB and passed through
    ``transform`` when one is given. If anything in that process raises,
    a white 224x224 placeholder (transformed the same way) is returned
    together with the item's original label, so one unreadable file does
    not abort an epoch.
    """

    def __init__(self, image_paths, labels, transform=None):
        # labels[i] belongs to image_paths[i]
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        try:
            path = self.image_paths[idx]
            img = Image.open(path).convert('RGB')
            target = self.labels[idx]
            return (self.transform(img) if self.transform else img), target
        except Exception as err:
            print(f"⚠️ Lỗi khi tải ảnh {self.image_paths[idx]}: {err}")
            # White stand-in image, keeping the real label
            placeholder = Image.new('RGB', (224, 224), color='white')
            if self.transform:
                placeholder = self.transform(placeholder)
            return placeholder, self.labels[idx]

# Collect every image path together with its integer class label
all_image_paths = []
all_labels = []
class_to_idx = {}

# Label indices go only to real sub-directories, so a stray file in the
# dataset root cannot take an index (and shift every real label).
class_dirs = [
    name for name in sorted(classes)
    if os.path.isdir(os.path.join(dataset_path, name))
]
for idx, class_name in enumerate(class_dirs):
    class_to_idx[class_name] = idx
    class_path = os.path.join(dataset_path, class_name)
    # os.listdir order is filesystem-dependent, and sample order feeds the
    # seeded StratifiedKFold; sorting keeps the folds reproducible.
    for image_name in sorted(os.listdir(class_path)):
        if image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
            all_image_paths.append(os.path.join(class_path, image_name))
            all_labels.append(idx)

print(f"📊 Tổng số mẫu: {len(all_image_paths)}")
print(f"🏷️ Mapping classes: {class_to_idx}")

# numpy arrays so the per-fold index arrays can select samples directly
all_image_paths = np.array(all_image_paths)
all_labels = np.array(all_labels)

print("\n" + "=" * 60)

# ==========================================
# BƯỚC 5: THIẾT LẬP CROSS-VALIDATION
# ==========================================
print("🔄 Thiết lập 5-fold Stratified Cross-Validation...")

# Stratified folds keep class proportions similar across folds; shuffling
# with a fixed seed makes the split repeatable for the same input order.
n_folds = 5
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

# Training-mode and evaluation-mode transforms from the project's DeiT module
train_transform, val_transform = (
    get_deit_transforms(input_size=224, is_training=mode)
    for mode in (True, False)
)

print("✅ Thiết lập transforms thành công!")
print(f"📊 Số folds: {n_folds}")

print("\n" + "=" * 60)

# ==========================================
# STEP 6: DEFINE MODEL AND TRAINING FUNCTIONS
# ==========================================
print("🏗️ Định nghĩa DeiT model và training functions...")

def create_deit_model(num_classes=4, pretrained=True):
    """Build a timm DeiT-Small/16 classifier, falling back to ViT-Small/16.

    Args:
        num_classes: size of the classification head.
        pretrained: forwarded to ``timm.create_model``.

    Returns:
        The timm model. If creating ``deit_small_patch16_224`` raises any
        exception, ``vit_small_patch16_224`` is created with the same
        arguments instead.
    """
    try:
        net = timm.create_model('deit_small_patch16_224', pretrained=pretrained, num_classes=num_classes)
    except Exception as err:
        print(f"⚠️ Không thể tải DeiT từ timm: {err}")
        print("🔄 Sử dụng Vision Transformer thay thế...")
        net = timm.create_model('vit_small_patch16_224', pretrained=pretrained, num_classes=num_classes)
        print("✅ Tạo ViT model thay thế thành công")
        return net
    print(f"✅ Tạo DeiT model thành công với {num_classes} classes")
    return net

def train_epoch(model, train_loader, criterion, optimizer, device, epoch, total_epochs):
    """Run one optimisation pass over ``train_loader``.

    Args:
        epoch: zero-based epoch index (shown 1-based in the progress bar).
        total_epochs: total epoch count, used only for the progress label.

    Returns:
        (mean of the per-batch losses, accuracy in percent) for this epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{total_epochs} [Train]')
    for step, (inputs, targets) in enumerate(progress, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        preds = logits.data.max(dim=1).indices
        n_seen += targets.size(0)
        n_correct += preds.eq(targets).sum().item()

        # Running mean loss and running accuracy so far
        progress.set_postfix({
            'Loss': f'{loss_sum/step:.4f}',
            'Acc': f'{100.*n_correct/n_seen:.2f}%'
        })

    return loss_sum / len(train_loader), 100. * n_correct / n_seen

def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Returns:
        (mean of the per-batch losses, accuracy in percent).
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        # Count batches ourselves: tqdm only syncs ``pbar.n`` when it refreshes
        # the display, so ``pbar.n + 1`` can lag the real batch count and the
        # running loss shown in the bar was inflated.
        for batch_idx, (data, target) in enumerate(pbar):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            running_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            pbar.set_postfix({
                'Loss': f'{running_loss/(batch_idx+1):.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

    epoch_loss = running_loss / len(val_loader)
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

print("✅ Định nghĩa functions thành công!")

print("\n" + "=" * 60)

# ==========================================
# STEP 7: CROSS-VALIDATION TRAINING
# ==========================================
print("🚀 Bắt đầu 5-fold Cross-Validation Training...")

# Hyperparameters
num_epochs = 50
batch_size = 16
learning_rate = 1e-4
# The classifier head must cover every label index stored in all_labels;
# class_to_idx is exactly the mapping those indices were assigned from.
num_classes = len(class_to_idx)

# Per-fold metric storage
fold_results = []
all_train_losses = []
all_val_losses = []
all_train_accs = []
all_val_accs = []

# Run stratified cross-validation
for fold, (train_idx, val_idx) in enumerate(skf.split(all_image_paths, all_labels)):
    print(f"\n{'='*20} FOLD {fold + 1}/{n_folds} {'='*20}")

    # Drop the previous fold's model/optimizer/scheduler *before* building new
    # ones; otherwise the old weights and Adam state stay alive on the GPU
    # while the next model is created.
    model = optimizer = scheduler = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Split data for this fold
    train_paths = all_image_paths[train_idx]
    train_labels = all_labels[train_idx]
    val_paths = all_image_paths[val_idx]
    val_labels = all_labels[val_idx]

    print(f"📊 Train: {len(train_paths)} samples, Val: {len(val_paths)} samples")

    # Datasets and loaders (only the training loader shuffles)
    train_dataset = DogEmotionDataset(train_paths, train_labels, train_transform)
    val_dataset = DogEmotionDataset(val_paths, val_labels, val_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Fresh model per fold, built by the project's DeiT module
    model = create_simple_deit_model(num_classes=num_classes, device=device)
    model = model.to(device)

    # Loss, optimizer and LR schedule (LR multiplied by 0.1 every 15 epochs)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

    # Per-epoch metrics for this fold
    fold_train_losses = []
    fold_val_losses = []
    fold_train_accs = []
    fold_val_accs = []

    # Start below any achievable accuracy so epoch 1 always writes a checkpoint;
    # with 0.0 and a strict `>`, a fold stuck at 0% saved no file at all.
    best_val_acc = -1.0

    print(f"🏋️ Bắt đầu training fold {fold + 1}...")

    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, epoch, num_epochs)
        val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)

        scheduler.step()

        fold_train_losses.append(train_loss)
        fold_val_losses.append(val_loss)
        fold_train_accs.append(train_acc)
        fold_val_accs.append(val_acc)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"  LR: {scheduler.get_last_lr()[0]:.6f}")

        # Keep the weights of the best validation epoch for this fold
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), f'best_deit_fold_{fold+1}.pth')

    # Record this fold's results
    fold_results.append({
        'fold': fold + 1,
        'best_val_acc': best_val_acc,
        'final_train_acc': fold_train_accs[-1],
        'final_val_acc': fold_val_accs[-1],
        'train_losses': fold_train_losses,
        'val_losses': fold_val_losses,
        'train_accs': fold_train_accs,
        'val_accs': fold_val_accs
    })

    all_train_losses.append(fold_train_losses)
    all_val_losses.append(fold_val_losses)
    all_train_accs.append(fold_train_accs)
    all_val_accs.append(fold_val_accs)

    print(f"✅ Fold {fold + 1} hoàn thành! Best Val Acc: {best_val_acc:.2f}%")

print("\n" + "=" * 60)

# ==========================================
# BƯỚC 8: PHÂN TÍCH KẾT QUẢ
# ==========================================
print("📊 Phân tích kết quả Cross-Validation...")

# Summary statistics over each fold's best validation accuracy
val_accs = [res['best_val_acc'] for res in fold_results]
mean_val_acc, std_val_acc = np.mean(val_accs), np.std(val_accs)

print("\n🎯 KẾT QUẢ CROSS-VALIDATION:")
print(f"Mean Validation Accuracy: {mean_val_acc:.2f}% ± {std_val_acc:.2f}%")
print(f"Min Validation Accuracy: {min(val_accs):.2f}%")
print(f"Max Validation Accuracy: {max(val_accs):.2f}%")

print("\n📈 Chi tiết từng fold:")
for fold_no, res in enumerate(fold_results, start=1):
    print(f"Fold {fold_no}: {res['best_val_acc']:.2f}%")

# One row per fold
results_df = pd.DataFrame({
    'Fold': [f"Fold {k}" for k in range(1, n_folds + 1)],
    'Best_Val_Acc': val_accs,
    'Final_Train_Acc': [res['final_train_acc'] for res in fold_results],
    'Final_Val_Acc': [res['final_val_acc'] for res in fold_results],
})

print("\n📋 Bảng kết quả:")
print(results_df.to_string(index=False))

print("\n" + "=" * 60)

# ==========================================
# BƯỚC 9: VISUALIZATION
# ==========================================
print("📈 Tạo biểu đồ kết quả...")

# Plot style
plt.style.use('default')
sns.set_palette("husl")

# 2x2 grid: loss curves, accuracy curves, accuracy box plot, per-fold bars
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('DeiT Cross-Validation Training Results', fontsize=16, fontweight='bold')

# 1. Training/validation loss curves for every fold
ax1 = axes[0, 0]
for i, fold_result in enumerate(fold_results):
    epochs = range(1, len(fold_result['train_losses']) + 1)
    ax1.plot(epochs, fold_result['train_losses'], alpha=0.7, label=f'Fold {i+1} Train')
    ax1.plot(epochs, fold_result['val_losses'], alpha=0.7, linestyle='--', label=f'Fold {i+1} Val')

ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.grid(True, alpha=0.3)

# 2. Training/validation accuracy curves for every fold
ax2 = axes[0, 1]
for i, fold_result in enumerate(fold_results):
    epochs = range(1, len(fold_result['train_accs']) + 1)
    ax2.plot(epochs, fold_result['train_accs'], alpha=0.7, label=f'Fold {i+1} Train')
    ax2.plot(epochs, fold_result['val_accs'], alpha=0.7, linestyle='--', label=f'Fold {i+1} Val')

ax2.set_title('Training and Validation Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.grid(True, alpha=0.3)

# 3. Distribution of the per-fold best validation accuracies
ax3 = axes[1, 0]
# Label the single box through its tick instead of boxplot(labels=...):
# Matplotlib 3.9 renamed that keyword to `tick_labels` and deprecated the old
# name, whereas set_xticks/set_xticklabels work on every version. The box
# sits at the default position 1.
ax3.boxplot(val_accs)
ax3.set_xticks([1])
ax3.set_xticklabels(['DeiT'])
ax3.set_title('Validation Accuracy Distribution')
ax3.set_ylabel('Accuracy (%)')
ax3.grid(True, alpha=0.3)

# Mean/std annotation in the top-left corner
ax3.text(0.02, 0.98, f'Mean: {mean_val_acc:.2f}%\nStd: {std_val_acc:.2f}%',
         transform=ax3.transAxes, verticalalignment='top',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

# 4. Best validation accuracy per fold
ax4 = axes[1, 1]
x_pos = np.arange(len(fold_results))
bars = ax4.bar(x_pos, val_accs, alpha=0.7, color='skyblue', edgecolor='navy')
ax4.set_title('Best Validation Accuracy by Fold')
ax4.set_xlabel('Fold')
ax4.set_ylabel('Accuracy (%)')
ax4.set_xticks(x_pos)
ax4.set_xticklabels([f'Fold {i+1}' for i in range(n_folds)])
ax4.grid(True, alpha=0.3)

# Value labels just above each bar
for bar, acc in zip(bars, val_accs):
    height = bar.get_height()
    ax4.text(bar.get_x() + bar.get_width()/2., height + 0.5,
             f'{acc:.1f}%', ha='center', va='bottom')

plt.tight_layout()
plt.show()

print("✅ Biểu đồ đã được tạo!")

print("\n" + "=" * 60)

# ==========================================
# BƯỚC 10: LƯU KẾT QUẢ
# ==========================================
print("💾 Lưu kết quả và models...")

# Run configuration plus per-fold metrics, serialized to JSON
results_summary = dict(
    model_name='DeiT',
    cross_validation_folds=n_folds,
    epochs_per_fold=num_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    mean_val_accuracy=float(mean_val_acc),
    std_val_accuracy=float(std_val_acc),
    min_val_accuracy=float(min(val_accs)),
    max_val_accuracy=float(max(val_accs)),
    fold_results=fold_results,
    class_mapping=class_to_idx,
)

with open('deit_cv_results.json', 'w') as json_file:
    json.dump(results_summary, json_file, indent=2)

results_df.to_csv('deit_cv_results.csv', index=False)

print("✅ Đã lưu kết quả vào:")
for saved in ('deit_cv_results.json', 'deit_cv_results.csv', 'best_deit_fold_*.pth (model weights)'):
    print(f"   - {saved}")

# One "Fold k: xx.xx%" line per fold for the plain-text report
fold_lines = "\n".join(f"Fold {k}: {acc:.2f}%" for k, acc in enumerate(val_accs, start=1))

summary_text = f"""
DeiT Cross-Validation Training Summary
=====================================

Model: DeiT (Data-efficient Image Transformer)
Dataset: Dog Emotion Recognition
Classes: {list(class_to_idx.keys())}
Total Images: {len(all_image_paths)}

Training Configuration:
- Cross-validation: {n_folds}-fold Stratified
- Epochs per fold: {num_epochs}
- Batch size: {batch_size}
- Learning rate: {learning_rate}
- Optimizer: Adam
- Scheduler: StepLR (step_size=15, gamma=0.1)

Results:
- Mean Validation Accuracy: {mean_val_acc:.2f}% ± {std_val_acc:.2f}%
- Best Fold Accuracy: {max(val_accs):.2f}%
- Worst Fold Accuracy: {min(val_accs):.2f}%

Fold Details:
{fold_lines}

Files Generated:
- deit_cv_results.json: Detailed results
- deit_cv_results.csv: Results table
- best_deit_fold_*.pth: Best model weights for each fold
"""

with open('deit_training_summary.txt', 'w') as report:
    report.write(summary_text)

print("   - deit_training_summary.txt")

print("\n" + "=" * 60)

# ==========================================
# BƯỚC 11: DOWNLOAD FILES (CHO COLAB)
# ==========================================
print("📥 Chuẩn bị download files...")

# google.colab is importable only inside a Colab runtime; anywhere else the
# results simply stay on local disk.
try:
    from google.colab import files
except ImportError:
    print("ℹ️ Không phải Colab environment - files đã được lưu locally")
else:
    print("🔍 Phát hiện Google Colab - tự động download files...")

    # Result files first, skipping any that were not written
    for result_file in ('deit_cv_results.json', 'deit_cv_results.csv', 'deit_training_summary.txt'):
        if not os.path.exists(result_file):
            continue
        print(f"⬇️ Downloading {result_file}...")
        files.download(result_file)

    # Then the checkpoint of the fold with the highest best-val accuracy
    # (first such fold on ties; fold numbers in file names are 1-based).
    best_fold_number = max(range(len(val_accs)), key=val_accs.__getitem__) + 1
    best_model_file = f'best_deit_fold_{best_fold_number}.pth'

    if os.path.exists(best_model_file):
        print(f"⬇️ Downloading best model: {best_model_file}...")
        files.download(best_model_file)

    print("✅ Download hoàn tất!")

print("\n" + "=" * 60)

# ==========================================
# BƯỚC 12: HƯỚNG DẪN SỬ DỤNG
# ==========================================
print("📋 HƯỚNG DẪN SỬ DỤNG MODEL")
print("=" * 40)

# NOTE(review): training used create_simple_deit_model from
# dog_emotion_classification.deit, whose architecture is not visible here;
# confirm its state_dict matches timm's 'deit_small_patch16_224' before
# relying on step 1 of this guide.
#
# An f-string, so the guide shows the real class mapping, class count and best
# fold. As a plain string it printed the literal text "{class_to_idx}".
_best_fold_number = val_accs.index(max(val_accs)) + 1
print(f"""
🎯 CÁCH SỬ DỤNG MODEL ĐÃ HUẤN LUYỆN:

1. Load model:
   ```python
   import torch
   import timm

   # Tạo model architecture
   model = timm.create_model('deit_small_patch16_224', pretrained=False, num_classes={len(class_to_idx)})

   # Load weights
   model.load_state_dict(torch.load('best_deit_fold_{_best_fold_number}.pth', map_location='cpu'))
   model.eval()
   ```

2. Predict trên ảnh mới:
   ```python
   from PIL import Image
   import torchvision.transforms as transforms

   # Transforms
   transform = transforms.Compose([
       transforms.Resize((224, 224)),
       transforms.ToTensor(),
       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
   ])

   # Load và predict
   image = Image.open('path/to/image.jpg').convert('RGB')
   input_tensor = transform(image).unsqueeze(0)

   with torch.no_grad():
       output = model(input_tensor)
       prediction = torch.argmax(output, dim=1)

   classes = {list(class_to_idx)}
   predicted_emotion = classes[prediction.item()]
   ```

3. Class mapping:
   {class_to_idx}

🎉 TRAINING HOÀN TẤT!
""")

print("🏁 DeiT Cross-Validation Training Pipeline hoàn thành!")
# Elapsed time needs a start timestamp taken at the top of the pipeline;
# the former `time.time() - time.time()` always printed ~0.
_start_time = globals().get("PIPELINE_START_TIME")
if _start_time is None:
    print("⏱️ Thời gian chạy: không xác định (chưa ghi PIPELINE_START_TIME)")
else:
    print(f"⏱️ Thời gian chạy: {time.time() - _start_time:.2f} giây")
print("🎯 Kết quả tốt nhất: {:.2f}%".format(max(val_accs)))

print("\n" + "=" * 60)
print("✨ CẢM ƠN BẠN ĐÃ SỬ DỤNG PIPELINE! ✨")
print("=" * 60)


In [None]:
#!/usr/bin/env python3
"""
🎓 DeiT Cross-Validation Training for Dog Emotion Recognition
============================================================

Complete pipeline for training DeiT on dog emotion dataset with:
- Automatic dataset download and preparation
- 5-fold stratified cross-validation
- 50 epochs training per fold
- Comprehensive visualization and evaluation
- Model saving and download

Author: Dog Emotion Recognition Team
Date: 2024
"""

import os
import sys
import warnings
import time
from datetime import datetime
import zipfile
import shutil
from pathlib import Path

# Quiet the console: set TensorFlow's native log level to '3' (only matters
# if TensorFlow ends up loaded) and ignore Python warnings.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings('ignore')

print("🚀 Starting DeiT Cross-Validation Training Pipeline")
print("=" * 55)

# =====================================
# 1. PACKAGE INSTALLATION
# =====================================

print("\n📦 Installing required packages...")
packages = [
    'torch>=1.9.0',
    'torchvision>=0.10.0',
    'scikit-learn>=1.0.0',
    'matplotlib>=3.3.0',
    'seaborn>=0.11.0',
    'gdown>=4.0.0',
    'Pillow>=8.0.0',
    'numpy>=1.21.0',
    'pandas>=1.3.0',
    'tqdm>=4.60.0'
]

import subprocess

for package in packages:
    # One argv element per argument. Through os.system the shell parsed
    # "torch>=1.9.0" as an output redirect into a file named "=1.9.0", so the
    # constraint was dropped; os.system also never raised, so every package
    # was reported as installed regardless of pip's outcome.
    try:
        result = subprocess.run(
            [sys.executable, '-m', 'pip', 'install', package, '--quiet'],
            check=False,
        )
    except OSError as e:
        print(f"❌ Failed to install {package}: {e}")
        continue
    if result.returncode == 0:
        print(f"✅ {package}")
    else:
        print(f"❌ Failed to install {package}: pip exited with code {result.returncode}")

print("📦 Package installation completed!")

# =====================================
# 2. IMPORTS
# =====================================

print("\n📚 Importing libraries...")

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from PIL import Image
import gdown
from tqdm import tqdm
import json
import random

# Use the default CUDA GPU when one is present, otherwise the CPU.
gpu_available = torch.cuda.is_available()
device = torch.device('cuda' if gpu_available else 'cpu')
print(f"🔧 Using device: {device}")

if gpu_available:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    gpu_mem_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU Memory: {gpu_mem_gib:.1f} GB")

# =====================================
# 3. DATASET DOWNLOAD
# =====================================

print("\n💾 Downloading dataset...")

# Google Drive dataset ID
dataset_id = "1ZAgz5u64i3LDbwMFpBXjzsKt6FrhNGdW"
dataset_zip = "dog_emotion_dataset.zip"
dataset_dir = "dog_emotion_dataset"

if not os.path.exists(dataset_dir):
    try:
        # Reuse an archive already on disk (e.g. left by an earlier cell)
        # instead of downloading it again.
        if os.path.exists(dataset_zip):
            print("📦 Reusing existing dataset archive...")
        else:
            print("📥 Downloading dataset from Google Drive...")
            gdown.download(f'https://drive.google.com/uc?id={dataset_id}', dataset_zip, quiet=False)

        print("📂 Extracting dataset...")
        with zipfile.ZipFile(dataset_zip, 'r') as zip_ref:
            zip_ref.extractall('.')

        os.remove(dataset_zip)

    except Exception as e:
        print(f"❌ Error downloading dataset: {e}")
        print("Please check your internet connection and try again.")
        # Drop a partial/corrupt archive so the next run downloads a fresh copy.
        if os.path.exists(dataset_zip):
            os.remove(dataset_zip)
        sys.exit(1)

    # Nothing above guarantees the archive's top-level folder is dataset_dir;
    # fail here rather than silently finding zero images later.
    if not os.path.isdir(dataset_dir):
        print(f"❌ Extracted archive does not contain the '{dataset_dir}' folder.")
        sys.exit(1)
    print("✅ Dataset downloaded and extracted successfully!")
else:
    print("✅ Dataset already exists!")

# =====================================
# 4. DATASET CLASS
# =====================================

class DogEmotionDataset(Dataset):
    """Dataset class for dog emotion recognition.

    Pairs ``image_paths[i]`` with ``labels[i]``. Each image is opened with
    PIL, converted to RGB and passed through ``transform`` when given. An
    image that fails to load is replaced by a zero tensor of shape
    (3, 224, 224) but keeps its real label.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Read the label up front so a failed image load still yields the
        # sample's true class. Returning a fixed 0 mislabelled every
        # unreadable image as class 0.
        label = self.labels[idx]
        try:
            image = Image.open(self.image_paths[idx]).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            print(f"Error loading image {self.image_paths[idx]}: {e}")
            # Placeholder sized like the 224x224 RGB tensors the transforms produce
            dummy_image = torch.zeros(3, 224, 224)
            return dummy_image, label

# =====================================
# 5. DATA PREPARATION
# =====================================

print("\n🔍 Preparing dataset...")

# Emotion classes; the list position is the integer label
emotion_classes = ['angry', 'happy', 'relaxed', 'sad']
class_to_idx = {cls: idx for idx, cls in enumerate(emotion_classes)}

# Case-insensitive extension match. The former glob('*.jpg') / glob('*.png')
# is case-sensitive on Linux and skipped .jpeg, .JPG and .PNG files.
image_extensions = ('.jpg', '.jpeg', '.png')

# Collect all images and labels
all_images = []
all_labels = []

for class_name in emotion_classes:
    class_dir = Path(dataset_dir) / class_name
    if not class_dir.exists():
        print(f"⚠️ Missing class folder: {class_dir}")
        continue
    # Sorted, so sample order (and hence any seeded CV split) does not depend
    # on filesystem listing order.
    images = sorted(
        p for p in class_dir.iterdir()
        if p.is_file() and p.suffix.lower() in image_extensions
    )
    all_images.extend(images)
    all_labels.extend([class_to_idx[class_name]] * len(images))
    print(f"📁 {class_name}: {len(images)} images")

print(f"\n📊 Total dataset: {len(all_images)} images")
print(f"📊 Classes: {emotion_classes}")

if not all_images:
    raise FileNotFoundError(f"No images found under '{dataset_dir}' for classes {emotion_classes}")

# numpy arrays so fold index arrays can select samples directly
all_images = np.array(all_images)
all_labels = np.array(all_labels)

# =====================================
# 6. DEIT MODEL CREATION
# =====================================

def _torchvision_weights_kwargs(weights_enum_name, pretrained):
    """Return constructor kwargs that request (or skip) ImageNet weights.

    torchvision >= 0.13 takes ``weights=<Enum member>`` and deprecates
    ``pretrained=``; older releases only accept ``pretrained=``.
    ``IMAGENET1K_V1`` is the checkpoint legacy ``pretrained=True`` loaded, so
    both API generations get the same weights.
    """
    weights_enum = getattr(models, weights_enum_name, None)
    if weights_enum is None:
        return {'pretrained': pretrained}
    return {'weights': weights_enum.IMAGENET1K_V1 if pretrained else None}


def create_deit_model(num_classes=4, pretrained=True):
    """Create the classifier used as the "DeiT" model in this script.

    torchvision ships no DeiT, so this builds torchvision's ViT-B/16 (same
    backbone family as DeiT-Base, but not DeiT-trained weights) and replaces
    its head with a ``num_classes``-way linear layer. On torchvision builds
    without ``vit_b_16`` it falls back to ResNet-50 with a new ``fc`` layer.

    Args:
        num_classes: number of output classes.
        pretrained: load ImageNet weights for the backbone.

    Returns:
        The ``nn.Module``; it is not moved to any device here.
    """
    if hasattr(models, 'vit_b_16'):
        model = models.vit_b_16(**_torchvision_weights_kwargs('ViT_B_16_Weights', pretrained))
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
        print("✅ Using ViT-B/16 as DeiT base model")
    else:
        print("⚠️ DeiT/ViT not available in torchvision, using ResNet50 as fallback...")
        model = models.resnet50(**_torchvision_weights_kwargs('ResNet50_Weights', pretrained))
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model

# =====================================
# 7. TRAINING FUNCTIONS
# =====================================

def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch over ``dataloader``.

    The progress bar shows the current batch's loss and the running accuracy.

    Returns:
        (mean of the per-batch losses, accuracy in percent).
    """
    model.train()
    loss_total, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(dataloader, desc="Training")
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_total += batch_loss.item()
        batch_preds = logits.data.max(dim=1).indices
        n_seen += batch_labels.size(0)
        n_correct += batch_preds.eq(batch_labels).sum().item()

        progress.set_postfix({
            'Loss': f'{batch_loss.item():.4f}',
            'Acc': f'{100.*n_correct/n_seen:.2f}%'
        })

    return loss_total / len(dataloader), 100. * n_correct / n_seen

def evaluate_model(model, dataloader, criterion, device):
    """Evaluate *model* on *dataloader* without updating its parameters.

    The model is switched to eval mode and left in that mode. Batches run
    under ``torch.no_grad()``. Predictions and ground-truth labels are
    collected alongside the loss and the accuracy.

    Args:
        model: Module to evaluate.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function. It is assumed to use mean reduction (the
            ``nn.CrossEntropyLoss`` default), so ``loss * batch size`` gives
            the batch's summed loss.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc, all_predicted, all_labels)``:
        - the per-sample mean loss;
        - the accuracy in percent;
        - a list of predicted class indices (NumPy scalars, dataloader order);
        - a list of true labels (NumPy scalars, dataloader order).

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0  # summed per-sample loss (see train_epoch for the rationale)
    correct = 0
    total = 0
    predicted_list = []
    label_list = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Evaluating")
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            n_batch = labels.size(0)
            batch_loss = loss.item()
            # Weight by batch size so a smaller final batch counts per sample
            loss_sum += batch_loss * n_batch
            predicted = outputs.argmax(dim=1)
            total += n_batch
            correct += (predicted == labels).sum().item()

            predicted_list.extend(predicted.cpu().numpy())
            label_list.extend(labels.cpu().numpy())

            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

    if total == 0:
        raise ValueError("evaluate_model: dataloader yielded no samples")

    epoch_loss = loss_sum / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc, predicted_list, label_list

# =====================================
# 8. CROSS-VALIDATION TRAINING
# =====================================

# Training parameters
n_folds = 5
epochs = 50
batch_size = 16
learning_rate = 1e-4
input_size = 224

# Banner derived from n_folds so it stays correct if the fold count changes
print(f"\n🎯 Starting {n_folds}-Fold Cross-Validation Training...")

# Standard ImageNet channel mean/std, the normalization torchvision documents
# for its pretrained weights; shared by the train and validation pipelines.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Data transforms for DeiT: augmentation (flip, rotation, color jitter) only
# on the training split
train_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
])

# Validation: deterministic resize + normalization only
val_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
])

# Initialize cross-validation. Stratified splits keep each fold's class
# proportions close to the full dataset's; the fixed seed makes the splits
# reproducible.
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

# Storage for results
fold_results = []
all_train_losses = []
all_val_losses = []
all_train_accs = []
all_val_accs = []

# Training loop: every fold gets a freshly created model, optimizer and scheduler.
# NOTE(review): each fold's reported score is its best epoch on the same
# validation split that is used to pick that epoch, so the CV mean is
# optimistically biased (there is no held-out test split here).
for fold, (train_idx, val_idx) in enumerate(skf.split(all_images, all_labels)):
    print(f"\n{'='*20} FOLD {fold+1}/{n_folds} {'='*20}")

    # Split data. Indexing with skf's index arrays assumes all_images and
    # all_labels are NumPy arrays — TODO confirm against where they are built.
    train_images = all_images[train_idx]
    train_labels = all_labels[train_idx]
    val_images = all_images[val_idx]
    val_labels = all_labels[val_idx]

    print(f"Train samples: {len(train_images)}")
    print(f"Validation samples: {len(val_images)}")

    # Create datasets
    train_dataset = DogEmotionDataset(train_images, train_labels, train_transform)
    val_dataset = DogEmotionDataset(val_images, val_labels, val_transform)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Create model
    model = create_deit_model(num_classes=len(emotion_classes), pretrained=True)
    model = model.to(device)

    # Loss and optimizer. StepLR (stepped once per epoch below) multiplies the
    # learning rate by 0.1 every 15 epochs.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

    # Training tracking
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    # Start at -inf rather than 0.0 so the first epoch always writes a
    # checkpoint. With 0.0 and the strict '>' below, a fold stuck at 0%
    # accuracy saved no file, yet later steps look for
    # deit_fold_<k>_best.pth.
    best_val_acc = float('-inf')

    # Training epochs
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        print("-" * 30)

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, _, _ = evaluate_model(model, val_loader, criterion, device)

        # Update scheduler
        scheduler.step()

        # Save metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        # Print progress
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Save best model (strict '>' keeps the earliest epoch on ties)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), f'deit_fold_{fold+1}_best.pth')
            print(f"💾 New best model saved! Accuracy: {best_val_acc:.2f}%")

    # Store fold results
    fold_results.append({
        'fold': fold + 1,
        'best_val_acc': best_val_acc,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs
    })

    all_train_losses.append(train_losses)
    all_val_losses.append(val_losses)
    all_train_accs.append(train_accs)
    all_val_accs.append(val_accs)

    print(f"\n✅ Fold {fold+1} completed! Best validation accuracy: {best_val_acc:.2f}%")

# =====================================
# 9. RESULTS ANALYSIS
# =====================================

print("\n📊 Training Results Analysis")
print("=" * 50)

# Summary statistics over each fold's best validation accuracy.
# np.std is the population standard deviation (ddof=0).
fold_accuracies = [fold_info['best_val_acc'] for fold_info in fold_results]
mean_acc, std_acc = np.mean(fold_accuracies), np.std(fold_accuracies)

print("Cross-Validation Results:")
print(f"Mean Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%")
print(f"Best Fold: {max(fold_accuracies):.2f}%")
print(f"Worst Fold: {min(fold_accuracies):.2f}%")

print("\nFold-by-fold results:")
for i, acc in enumerate(fold_accuracies):
    fold_label = f"Fold {i+1}"
    print(f"{fold_label}: {acc:.2f}%")

# =====================================
# 10. VISUALIZATION
# =====================================

print("\n📈 Creating visualizations...")

# Set up plotting style
plt.style.use('default')
sns.set_palette("husl")

# Create comprehensive visualization: a 2x2 grid of panels
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('DeiT Cross-Validation Training Results', fontsize=16, fontweight='bold')


def _plot_fold_curves(ax, train_histories, val_histories, title, ylabel):
    """Draw a solid train curve and a dashed val curve per fold on *ax*.

    Each curve's x-axis comes from that history's own length, so the panel
    stays valid even if a fold ran for a different number of epochs than
    ``epochs``.
    """
    for fold_idx, (train_hist, val_hist) in enumerate(zip(train_histories, val_histories)):
        ax.plot(range(1, len(train_hist) + 1), train_hist, alpha=0.7,
                label=f'Fold {fold_idx+1} Train')
        ax.plot(range(1, len(val_hist) + 1), val_hist, alpha=0.7, linestyle='--',
                label=f'Fold {fold_idx+1} Val')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, alpha=0.3)


# 1. Training and Validation Loss
ax1 = axes[0, 0]
_plot_fold_curves(ax1, all_train_losses, all_val_losses,
                  'Training and Validation Loss', 'Loss')

# 2. Training and Validation Accuracy
ax2 = axes[0, 1]
_plot_fold_curves(ax2, all_train_accs, all_val_accs,
                  'Training and Validation Accuracy', 'Accuracy (%)')

# 3. Cross-Validation Accuracy Distribution (bar positions follow the number
# of recorded folds)
ax3 = axes[1, 0]
fold_positions = range(1, len(fold_accuracies) + 1)
ax3.bar(fold_positions, fold_accuracies, alpha=0.7, color='skyblue', edgecolor='navy')
ax3.axhline(y=mean_acc, color='red', linestyle='--', label=f'Mean: {mean_acc:.2f}%')
ax3.set_title('Cross-Validation Accuracy by Fold')
ax3.set_xlabel('Fold')
ax3.set_ylabel('Accuracy (%)')
ax3.set_xticks(fold_positions)
ax3.legend()
ax3.grid(True, alpha=0.3)

# Add accuracy values on bars
for pos, acc in zip(fold_positions, fold_accuracies):
    ax3.text(pos, acc + 0.5, f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold')

# 4. Model Performance Summary (text-only panel)
ax4 = axes[1, 1]
ax4.axis('off')

# Create summary text. The timestamp is the wall-clock time at which this
# cell runs, not a duration, so it is labelled as such.
summary_text = f"""
DEIT TRAINING SUMMARY
{'='*21}

Dataset: Dog Emotion Recognition
Architecture: DeiT (ViT-based)
Input Size: {input_size}×{input_size}
Classes: {len(emotion_classes)}

Training Configuration:
• Folds: {n_folds}
• Epochs per fold: {epochs}
• Batch size: {batch_size}
• Learning rate: {learning_rate}
• Optimizer: Adam

Results:
• Mean CV Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%
• Best Fold Accuracy: {max(fold_accuracies):.2f}%
• Finished At: {datetime.now().strftime('%H:%M:%S')}

Classes: {', '.join(emotion_classes)}
"""

ax4.text(0.05, 0.95, summary_text, transform=ax4.transAxes, fontsize=11,
         verticalalignment='top', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

plt.tight_layout()
plt.show()

# =====================================
# 11. SAVE RESULTS
# =====================================

print("\n💾 Saving results...")

# Hyper-parameters of this run, stored next to the metrics for reproducibility
run_config = {
    'n_folds': n_folds,
    'epochs': epochs,
    'batch_size': batch_size,
    'learning_rate': learning_rate,
    'input_size': input_size
}

# Save training history: per-fold curves, aggregate accuracy and the config
results_data = {
    'fold_results': fold_results,
    'mean_accuracy': mean_acc,
    'std_accuracy': std_acc,
    'emotion_classes': emotion_classes,
    'training_config': run_config
}

with open('deit_training_results.json', 'w') as results_file:
    json.dump(results_data, results_file, indent=2)

print("✅ Results saved to deit_training_results.json")

# =====================================
# 12. MODEL DOWNLOAD
# =====================================

print("\n📥 Preparing models for download...")

# Create a summary of every fold checkpoint that actually exists on disk
model_summary = []
for fold, fold_acc in enumerate(fold_accuracies):
    model_file = f'deit_fold_{fold+1}_best.pth'
    if os.path.exists(model_file):
        model_summary.append({
            'fold': fold + 1,
            'filename': model_file,
            'accuracy': fold_acc,
            'size_mb': os.path.getsize(model_file) / (1024 * 1024)
        })

# Find best model (first fold wins on ties, as list.index returns the first match)
best_acc = max(fold_accuracies)
best_fold = fold_accuracies.index(best_acc) + 1
best_model_file = f'deit_fold_{best_fold}_best.pth'

print(f"\n🏆 Best model: {best_model_file} (Accuracy: {best_acc:.2f}%)")
print("📊 Model summary:")
# Loop variable is deliberately not named `model`: that would overwrite the
# trained network object left over from the training loop.
for entry in model_summary:
    print(f"  Fold {entry['fold']}: {entry['filename']} - {entry['accuracy']:.2f}% - {entry['size_mb']:.1f}MB")

# Download best model (in Colab). Only the missing-Colab import is treated as
# "not in Colab"; a missing checkpoint is reported instead of crashing.
try:
    from google.colab import files
except ImportError:
    print("⚠️ Not running in Colab - model saved locally")
else:
    if os.path.exists(best_model_file):
        print(f"\n📥 Downloading best model: {best_model_file}")
        files.download(best_model_file)
        print("✅ Model downloaded successfully!")
    else:
        print(f"⚠️ Best model file not found, skipping download: {best_model_file}")

print("\n🎉 DeiT Cross-Validation Training Completed!")
print("=" * 55)
print("✅ Final Results:")
print(f"   Mean Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%")
print(f"   Best Model: {best_model_file} ({best_acc:.2f}%)")
print(f"   Total Models: {len(model_summary)}")
print("=" * 55)
