# In [None]:  (Jupyter notebook cell marker left over from export; kept as a comment)
#!/usr/bin/env python3
"""
🧠 NASNet Cross-Validation Training for Dog Emotion Recognition
==============================================================

Complete pipeline for training NASNet on dog emotion dataset with:
- Automatic dataset download and preparation
- 5-fold stratified cross-validation
- 30 epochs training per fold
- Comprehensive visualization and evaluation
- Model saving and download

Author: Dog Emotion Recognition Team
Date: 2024
"""

import os
import sys
import warnings
import time
from datetime import datetime
import zipfile
import shutil
from pathlib import Path

# Suppress warnings
# Silence Python warnings and TensorFlow's C++ logging for the whole run.
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

print("🚀 Starting NASNet Cross-Validation Training Pipeline")
print("=" * 60)

# =============================================================================
# 1. CLONE REPOSITORY
# =============================================================================
# Imported here because only this script section needs it.
import subprocess

print("\n🔄 Cloning repository from GitHub...")
REPO_URL = "https://github.com/hoangh-e/dog-emotion-recognition-hybrid.git"
REPO_NAME = "dog-emotion-recognition-hybrid"

if not os.path.exists(REPO_NAME):
    # check=True makes a failed clone raise CalledProcessError right here,
    # instead of printing "success" and failing later in os.chdir().
    # The argument-list form avoids going through a shell.
    subprocess.run(["git", "clone", REPO_URL], check=True)
    print("✅ Repository cloned successfully")
else:
    print("✅ Repository already exists")

# Work from inside the checkout and make its packages importable.
os.chdir(REPO_NAME)
sys.path.insert(0, os.getcwd())

# =============================================================================
# 2. PACKAGE INSTALLATION
# =============================================================================
print("\n📦 Installing required packages...")

# Imported here because only this script section needs them.
import shlex
import subprocess

# Install packages
# Each entry is the argument string handed to `pip install`; the first token
# is the name shown in progress output.
packages = [
    'torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121',
    'gdown',
    'scikit-learn',
    'matplotlib',
    'seaborn',
    'Pillow',
    'numpy',
    'pandas',
    'tqdm',
    'opencv-python'
]

failed_packages = []
for package in packages:
    package_label = package.split()[0]
    print(f"Installing {package_label}...")
    # `sys.executable -m pip` installs into the interpreter running this
    # script; a bare `pip` on PATH may belong to a different environment.
    completed = subprocess.run(
        [sys.executable, "-m", "pip", "install", *shlex.split(package), "-q"]
    )
    if completed.returncode != 0:
        failed_packages.append(package_label)

# Stay best-effort (a package may already be present even if the install
# failed), but do not claim success when pip reported an error.
if failed_packages:
    print(f"⚠️ Failed to install: {', '.join(failed_packages)}")
else:
    print("✅ All packages installed successfully!")

# =============================================================================
# 3. IMPORTS
# =============================================================================
print("\n📚 Importing libraries...")

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import mobilenet_v3_large
import torch.nn.functional as F

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2
import gdown
import json
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import random
from collections import Counter

# Import NASNet from our custom module
from dog_emotion_classification.nasnet import (
    load_nasnet_model,
    predict_emotion_nasnet,
    get_nasnet_transforms,
    create_simple_nasnet
)

# Fix the seed of every RNG used by the pipeline (Python, NumPy, PyTorch CPU
# and CUDA) so runs are repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"🔧 Using device: {device}")

if device.type == 'cuda':
    # Report the name and total memory (in GiB) of GPU 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / (1 << 30):.1f} GB")

# =============================================================================
# 4. DATASET DOWNLOAD AND PREPARATION
# =============================================================================
print("\n💾 Downloading and preparing dataset...")

# Download dataset from Google Drive
dataset_url = "https://drive.google.com/uc?id=1ZAgz5u64i3LDbwMFpBXjzsKt6FrhNGdW"
dataset_zip = "dog_emotion_dataset.zip"

if not os.path.exists(dataset_zip):
    print("📥 Downloading dataset...")
    gdown.download(dataset_url, dataset_zip, quiet=False)
    # Confirm the archive actually landed on disk before reporting success;
    # otherwise the failure would only surface later as a confusing zip error.
    if not os.path.exists(dataset_zip):
        raise RuntimeError(
            f"Dataset download failed: '{dataset_zip}' was not created from {dataset_url}"
        )
    print("✅ Dataset downloaded successfully!")
else:
    print("📁 Dataset already exists, skipping download")

# Extract dataset
dataset_dir = "dog_emotion_dataset"
if not os.path.exists(dataset_dir):
    print("📂 Extracting dataset...")
    with zipfile.ZipFile(dataset_zip, 'r') as zip_ref:
        zip_ref.extractall()
    # The rest of the script reads images from `dataset_dir`. If the archive
    # does not contain that top-level folder, stop here rather than building
    # an empty dataset further down.
    if not os.path.isdir(dataset_dir):
        raise RuntimeError(
            f"Extracted '{dataset_zip}' but the expected folder '{dataset_dir}' was not found"
        )
    print("✅ Dataset extracted successfully!")
else:
    print("📁 Dataset already extracted")

# =============================================================================
# 5. NASNET MODEL ARCHITECTURE
# =============================================================================
print("\n🏗️ Using NASNet from our custom module...")
        out1 = self.relu(self.bn1(self.conv1(x)))
        out2 = self.relu(self.bn2(self.conv2(x)))
        out3 = self.relu(self.bn3(self.conv3(x)))
        
        # Combine outputs
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x
        
        # Simple combination strategy
        out = out1 + out2 + out3
        if out.size() == residual.size():
            out = out + residual
        
        return out

class NASNetModel(nn.Module):
    """Stem convolution, a stack of ``NASNetCell`` blocks, then a linear head.

    Args:
        num_classes: Size of the output logit vector.
        num_cells: Number of ``NASNetCell`` blocks stacked after the stem.
        channels: Output width of the stem; cell ``i`` is created with
            ``channels * 2 ** (i // 2)`` output channels.

    NOTE(review): ``NASNetCell`` must be defined elsewhere; its constructor is
    called as ``NASNetCell(in_channels, out_channels, is_reduction)``.
    """

    def __init__(self, num_classes=4, num_cells=6, channels=32):
        super().__init__()

        # Stem: stride-2 3x3 convolution + batch norm + ReLU.
        self.stem = nn.Sequential(
            nn.Conv2d(3, channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

        # Cell stack: the width doubles every two cells, and every
        # odd-indexed cell is flagged as a reduction cell.
        self.cells = nn.ModuleList()
        prev_width = channels
        for cell_idx in range(num_cells):
            width = channels * (2 ** (cell_idx // 2))
            reduction_flag = cell_idx % 2 == 1
            self.cells.append(NASNetCell(prev_width, width, reduction_flag))
            prev_width = width

        # Head: global average pooling followed by a linear classifier.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(prev_width, num_classes)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        features = self.stem(x)
        for cell in self.cells:
            features = cell(features)

        pooled = self.global_pool(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

def create_nasnet_model(num_classes=4):
    """Build a ``NASNetModel`` with ``num_classes`` outputs and default settings otherwise."""
    return NASNetModel(num_classes=num_classes)

print("✅ NASNet model architecture defined!")

# =============================================================================
# 5. DATASET CLASS
# =============================================================================
print("\n📊 Setting up dataset class...")

class DogEmotionDataset(Dataset):
    """Image-folder dataset for dog emotion recognition.

    Expects ``data_dir/<class_name>/<image>`` for each of the four classes
    ``angry``, ``happy``, ``relaxed`` and ``sad``. Missing class folders are
    skipped.

    Args:
        data_dir: Root folder containing one sub-folder per class.
        transform: Optional callable applied to the resized PIL image.
        target_size: ``(width, height)`` every image is resized to before the
            transform runs.
        extensions: File suffixes (with leading dot) treated as images. They
            are matched case-insensitively. The default keeps the original
            ``.jpg``-only behaviour.

    Attributes set here: ``classes``, ``class_to_idx`` and ``samples`` (a list
    of ``(path_str, label_idx)`` tuples).
    """

    def __init__(self, data_dir, transform=None, target_size=(224, 224), extensions=('.jpg',)):
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.target_size = target_size
        # Normalised once so that matching in _load_samples is case-insensitive.
        self.extensions = tuple(ext.lower() for ext in extensions)

        # Emotion classes
        self.classes = ['angry', 'happy', 'relaxed', 'sad']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # Load all images
        self.samples = []
        self._load_samples()

        print(f"📁 Dataset loaded: {len(self.samples)} images")
        print(f"📊 Classes: {self.classes}")

        # Print class distribution
        class_counts = Counter(label for _, label in self.samples)
        for cls, count in class_counts.items():
            print(f"   {self.classes[cls]}: {count} images")

    def _load_samples(self):
        """Populate ``self.samples`` with ``(path, label)`` pairs.

        Files are visited in sorted order so that sample indices, and hence
        any index-based split, do not depend on filesystem enumeration order.
        """
        for class_name in self.classes:
            class_dir = self.data_dir / class_name
            if not class_dir.exists():
                continue
            label = self.class_to_idx[class_name]
            for img_path in sorted(class_dir.iterdir()):
                if img_path.is_file() and img_path.suffix.lower() in self.extensions:
                    self.samples.append((str(img_path), label))

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        An image that cannot be read is reported and replaced by a black
        placeholder of ``target_size``; the stored label is kept. Errors
        raised by ``transform`` itself are not caught.
        """
        img_path, label = self.samples[idx]

        try:
            # The context manager closes the file handle as soon as the pixel
            # data has been copied by convert().
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')

            # Resize image
            image = image.resize(self.target_size, Image.LANCZOS)
        except Exception as e:
            # Best-effort: keep the epoch running when a single file is bad.
            print(f"⚠️ Error loading image {img_path}: {e}")
            image = Image.new('RGB', self.target_size, color='black')

        # Apply transforms
        if self.transform:
            image = self.transform(image)

        return image, label

# =============================================================================
# 6. DATA TRANSFORMS
# =============================================================================
print("\n🔄 Setting up data transforms...")

# Training pipeline: upscale to 256x256, take a random 224x224 crop, apply
# flip / rotation / colour jitter, then convert to a tensor and normalise
# with the per-channel mean and std below.
train_transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomCrop(size=224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Validation pipeline: deterministic resize to 224x224, no augmentation,
# same tensor conversion and normalisation as training.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

print("✅ Data transforms configured!")

# =============================================================================
# 7. TRAINING FUNCTIONS
# =============================================================================
print("\n🏋️ Setting up training functions...")

def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; switched to training mode.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean of the per-batch loss values
        and the accuracy in percent over all samples seen.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(dataloader, desc="Training", leave=False)
    for inputs, labels in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        loss_sum += batch_loss_value

        # Predicted class = index of the largest logit in each row.
        predicted = torch.max(logits.data, 1)[1]
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        # Show the current batch loss and the running accuracy.
        progress.set_postfix({
            'Loss': f'{batch_loss_value:.4f}',
            'Acc': f'{100.*n_correct/n_seen:.2f}%'
        })

    return loss_sum / len(dataloader), 100. * n_correct / n_seen

def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating weights.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc, all_preds, all_targets)``: the mean of the
        per-batch loss values, the accuracy in percent, and the flat lists of
        predicted and true labels in iteration order.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    all_preds = []
    all_targets = []

    progress = tqdm(dataloader, desc="Validation", leave=False)
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for inputs, labels in progress:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            batch_loss_value = criterion(logits, labels).item()
            loss_sum += batch_loss_value

            # Predicted class = index of the largest logit in each row.
            predicted = torch.max(logits.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

            # Keep predictions and targets for later confusion-matrix use.
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

            progress.set_postfix({
                'Loss': f'{batch_loss_value:.4f}',
                'Acc': f'{100.*n_correct/n_seen:.2f}%'
            })

    mean_loss = loss_sum / len(dataloader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy, all_preds, all_targets

def train_fold(fold_num, train_dataset, val_dataset, num_epochs=30, batch_size=16, learning_rate=1e-4):
    """Train a fresh model on one cross-validation fold.

    Args:
        fold_num: Zero-based fold index (used for log output only).
        train_dataset: Dataset used for optimisation.
        val_dataset: Dataset used for per-epoch validation and model selection.
        num_epochs: Number of training epochs.
        batch_size: Batch size for both data loaders.
        learning_rate: Initial learning rate for Adam.

    Returns:
        dict with keys ``model`` (restored to its best-validation weights),
        ``train_losses``, ``train_accs``, ``val_losses``, ``val_accs``
        (per-epoch histories), ``best_val_acc``, and ``final_preds`` /
        ``final_targets`` from a final validation pass of the restored model.
    """
    print(f"\n🔄 Training Fold {fold_num + 1}")
    print("-" * 50)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Create NASNet model using our custom module
    model = create_simple_nasnet(num_classes=4, device=device).to(device)

    # Loss and optimizer; the learning rate is multiplied by 0.1 every 10 epochs.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # Training history
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []

    best_val_acc = 0.0
    best_model_state = None

    # Training loop
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, _, _ = validate_epoch(model, val_loader, criterion, device)

        # Update scheduler
        scheduler.step()

        # Save best model.
        # state_dict() returns references to the live tensors, and dict.copy()
        # is shallow, so each tensor must be cloned; otherwise later epochs
        # overwrite the "best" snapshot in place. The first epoch is always
        # snapshotted so a snapshot exists even if accuracy never exceeds 0.
        if best_model_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

        # Record history
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Load best model (there is no snapshot when num_epochs == 0).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Final validation of the restored weights; its predictions feed the
    # confusion matrix built by the caller.
    _, _, final_preds, final_targets = validate_epoch(model, val_loader, criterion, device)

    print(f"\n✅ Fold {fold_num + 1} completed!")
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")

    return {
        'model': model,
        'train_losses': train_losses,
        'train_accs': train_accs,
        'val_losses': val_losses,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc,
        'final_preds': final_preds,
        'final_targets': final_targets
    }

print("✅ Training functions ready!")

# =============================================================================
# 8. CROSS-VALIDATION TRAINING
# =============================================================================
# Imported here because only this script section needs it.
import copy

# Number of stratified folds; every fold-count reference below derives from it.
N_FOLDS = 5

print(f"\n🎯 Starting {N_FOLDS}-Fold Cross-Validation Training...")

# Load full dataset (scanned from disk once; per-fold views are derived below)
full_dataset = DogEmotionDataset(
    data_dir=dataset_dir,
    transform=None,  # Will be set per fold
    target_size=(224, 224)
)

# Prepare data for stratified k-fold: X holds sample indices, y their labels.
X = list(range(len(full_dataset)))
y = [label for _, label in full_dataset.samples]

# Stratified cross-validation with a fixed shuffle seed
kfold = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

# Training parameters
NUM_EPOCHS = 30
BATCH_SIZE = 16
LEARNING_RATE = 1e-4

# Store results
fold_results = []
all_val_accs = []

print(f"📊 Training Configuration:")
print(f"   - Epochs per fold: {NUM_EPOCHS}")
print(f"   - Batch size: {BATCH_SIZE}")
print(f"   - Learning rate: {LEARNING_RATE}")
print(f"   - Device: {device}")

# Start training
start_time = time.time()

for fold, (train_idx, val_idx) in enumerate(kfold.split(X, y)):
    print(f"\n{'='*60}")
    print(f"🔄 FOLD {fold + 1}/{N_FOLDS}")
    print(f"{'='*60}")

    # Create fold datasets
    train_samples = [full_dataset.samples[i] for i in train_idx]
    val_samples = [full_dataset.samples[i] for i in val_idx]

    # Create datasets with transforms.
    # A shallow copy of the already-scanned dataset avoids walking the image
    # folders again (and re-printing the full-dataset statistics) twice per
    # fold. Only `transform` and `samples` differ between the views.
    train_dataset = copy.copy(full_dataset)
    train_dataset.transform = train_transform
    train_dataset.samples = train_samples

    val_dataset = copy.copy(full_dataset)
    val_dataset.transform = val_transform
    val_dataset.samples = val_samples

    print(f"📊 Fold {fold + 1} data:")
    print(f"   - Training samples: {len(train_samples)}")
    print(f"   - Validation samples: {len(val_samples)}")

    # Train fold
    fold_result = train_fold(fold, train_dataset, val_dataset, NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE)

    # Every fold's model is kept in `fold_results`. Park it on the CPU so its
    # weights stop occupying GPU memory during the next fold; without this,
    # empty_cache() below cannot release them.
    fold_result['model'].cpu()

    fold_results.append(fold_result)
    all_val_accs.append(fold_result['best_val_acc'])

    # Memory cleanup
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Calculate overall statistics
mean_acc = np.mean(all_val_accs)
std_acc = np.std(all_val_accs)
total_time = time.time() - start_time

print(f"\n🎉 CROSS-VALIDATION COMPLETED!")
print(f"{'='*60}")
print(f"📊 Final Results:")
print(f"   - Mean Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%")
print(f"   - Individual Folds: {[f'{acc:.2f}%' for acc in all_val_accs]}")
print(f"   - Total Training Time: {total_time/3600:.2f} hours")
print(f"   - Best Fold: {np.argmax(all_val_accs) + 1} ({max(all_val_accs):.2f}%)")

# =============================================================================
# 9. VISUALIZATION
# =============================================================================
print("\n📊 Creating visualizations...")

# Set up plotting style
plt.style.use('default')
sns.set_palette("husl")

# Derive the fold count and class names from the data rather than repeating
# literals, so the figure stays correct if the configuration changes.
num_folds = len(fold_results)
class_names = list(full_dataset.classes)

# Create comprehensive plots: a 2x3 grid of panels.
fig, axes = plt.subplots(2, 3, figsize=(20, 15))
ax1, ax2, ax3, ax4, ax5, ax6 = axes.ravel()

# 1. Training curves for all folds
for fold, result in enumerate(fold_results):
    epochs = range(1, len(result['train_losses']) + 1)
    ax1.plot(epochs, result['train_losses'], label=f'Fold {fold+1}', alpha=0.7)
ax1.set_title('Training Loss Across Folds', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Validation curves for all folds
for fold, result in enumerate(fold_results):
    epochs = range(1, len(result['val_accs']) + 1)
    ax2.plot(epochs, result['val_accs'], label=f'Fold {fold+1}', alpha=0.7)
ax2.set_title('Validation Accuracy Across Folds', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 3. Fold comparison
fold_names = [f'Fold {i+1}' for i in range(num_folds)]
bars = ax3.bar(fold_names, all_val_accs, color=sns.color_palette("husl", num_folds))
ax3.set_title('Best Validation Accuracy by Fold', fontsize=14, fontweight='bold')
ax3.set_ylabel('Accuracy (%)')
ax3.set_ylim(0, 100)

# Add value labels on bars
for bar, acc in zip(bars, all_val_accs):
    ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
             f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold')

# Add mean line
ax3.axhline(y=mean_acc, color='red', linestyle='--', alpha=0.7, label=f'Mean: {mean_acc:.1f}%')
ax3.legend()

# 4. Confusion Matrix for best fold
best_fold_idx = int(np.argmax(all_val_accs))
best_result = fold_results[best_fold_idx]

# Passing `labels` pins the matrix to one row/column per class even when a
# class never appears in this fold's targets or predictions; otherwise the
# matrix shrinks and the tick labels no longer line up with the cells.
cm = confusion_matrix(
    best_result['final_targets'],
    best_result['final_preds'],
    labels=list(range(len(class_names))),
)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names,
            ax=ax4)
ax4.set_title(f'Confusion Matrix - Best Fold ({best_fold_idx+1})', fontsize=14, fontweight='bold')
ax4.set_xlabel('Predicted')
ax4.set_ylabel('Actual')

# 5. Training vs Validation curves for best fold
epochs = range(1, len(best_result['train_accs']) + 1)
ax5.plot(epochs, best_result['train_accs'], label='Training', linewidth=2)
ax5.plot(epochs, best_result['val_accs'], label='Validation', linewidth=2)
ax5.set_title(f'Training vs Validation - Best Fold ({best_fold_idx+1})', fontsize=14, fontweight='bold')
ax5.set_xlabel('Epoch')
ax5.set_ylabel('Accuracy (%)')
ax5.legend()
ax5.grid(True, alpha=0.3)

# 6. Statistics summary (text-only panel)
ax6.axis('off')
stats_text = f"""
📊 NASNet Cross-Validation Results

🎯 Model Performance:
   • Mean Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%
   • Best Fold: {best_fold_idx+1} ({max(all_val_accs):.2f}%)
   • Worst Fold: {np.argmin(all_val_accs)+1} ({min(all_val_accs):.2f}%)

⚙️ Training Configuration:
   • Architecture: NASNet Custom
   • Epochs per fold: {NUM_EPOCHS}
   • Batch size: {BATCH_SIZE}
   • Learning rate: {LEARNING_RATE}
   • Device: {device}

⏱️ Training Time:
   • Total: {total_time/3600:.2f} hours
   • Per fold: {total_time/3600/num_folds:.2f} hours
   • Per epoch: {total_time/3600/num_folds/NUM_EPOCHS:.2f} hours

📈 Data Information:
   • Total samples: {len(full_dataset)}
   • Classes: {len(full_dataset.classes)}
   • Folds: {num_folds} (stratified)
"""

ax6.text(0.1, 0.9, stats_text, transform=ax6.transAxes, fontsize=12,
         verticalalignment='top', fontfamily='monospace',
         bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.8))

fig.tight_layout()
fig.savefig('nasnet_cross_validation_results.png', dpi=300, bbox_inches='tight')
plt.show()

print("✅ Visualizations created and saved!")

# =============================================================================
# 10. MODEL SAVING AND DOWNLOAD
# =============================================================================
print("\n💾 Saving best model...")

# Save best model
best_model = fold_results[best_fold_idx]['model']
best_accuracy = float(max(all_val_accs))
model_filename = f'nasnet_best_fold_{best_fold_idx+1}_acc_{best_accuracy:.2f}.pth'

# Store the weights as CPU tensors so the checkpoint can be loaded on a
# machine without a GPU, whatever device training ran on.
cpu_state_dict = {
    name: tensor.detach().cpu()
    for name, tensor in best_model.state_dict().items()
}

# Save model state dict.
# All metadata is converted to plain Python numbers: NumPy scalars (which
# np.mean / np.std / np.argmax produce) are pickled as NumPy objects, and
# torch.load's restricted `weights_only` unpickler rejects those.
torch.save({
    'model_state_dict': cpu_state_dict,
    'model_config': {
        'num_classes': 4,
        'architecture': 'NASNet Custom'
    },
    'training_info': {
        'best_fold': int(best_fold_idx) + 1,
        'best_accuracy': best_accuracy,
        'mean_accuracy': float(mean_acc),
        'std_accuracy': float(std_acc),
        'epochs': NUM_EPOCHS,
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE
    },
    'class_names': ['angry', 'happy', 'relaxed', 'sad']
}, model_filename)

print(f"✅ Model saved as: {model_filename}")

# Save training results
results_filename = 'nasnet_training_results.json'
training_results = {
    'cross_validation_results': {
        'mean_accuracy': float(mean_acc),
        'std_accuracy': float(std_acc),
        'fold_accuracies': [float(acc) for acc in all_val_accs],
        'best_fold': int(best_fold_idx) + 1,
        'best_accuracy': best_accuracy
    },
    'training_config': {
        'epochs': NUM_EPOCHS,
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'device': str(device),
        'total_time_hours': total_time/3600
    },
    'dataset_info': {
        'total_samples': len(full_dataset),
        'num_classes': len(full_dataset.classes),
        'class_names': list(full_dataset.classes)
    }
}

with open(results_filename, 'w', encoding='utf-8') as f:
    json.dump(training_results, f, indent=2)

print(f"✅ Training results saved as: {results_filename}")

# Download files (for Colab). Outside Colab the import fails and the files
# simply stay in the working directory.
try:
    from google.colab import files
    print("\n📥 Downloading files...")
    files.download(model_filename)
    files.download(results_filename)
    files.download('nasnet_cross_validation_results.png')
    print("✅ Files downloaded successfully!")
except ImportError:
    print("📁 Files saved locally (not in Colab environment)")

# =============================================================================
# 11. FINAL SUMMARY
# =============================================================================
print(f"\n🎉 NASNet CROSS-VALIDATION TRAINING COMPLETED!")
print(f"{'='*80}")
print(f"📊 FINAL RESULTS SUMMARY:")
print(f"   🎯 Mean Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%")
print(f"   🏆 Best Fold: {best_fold_idx+1} with {max(all_val_accs):.2f}% accuracy")
print(f"   📈 All Fold Accuracies: {[f'{acc:.2f}%' for acc in all_val_accs]}")
print(f"   ⏱️ Total Training Time: {total_time/3600:.2f} hours")
print(f"   💾 Model saved as: {model_filename}")
print(f"   📊 Results saved as: {results_filename}")
print(f"   🖼️ Visualization saved as: nasnet_cross_validation_results.png")
print(f"{'='*80}")

# Usage instructions.
# The saved weights come from the model built by create_simple_nasnet() in
# train_fold, so the instructions must rebuild the model with that same
# factory; a different architecture would not accept the checkpoint's keys.
print(f"\n📋 HOW TO USE THE TRAINED MODEL:")
print(f"1. Load the model:")
print(f"   from dog_emotion_classification.nasnet import create_simple_nasnet")
print(f"   model = create_simple_nasnet(num_classes=4, device=device)")
print(f"   checkpoint = torch.load('{model_filename}', map_location=device)")
print(f"   model.load_state_dict(checkpoint['model_state_dict'])")
print(f"2. Use for inference on new dog images")
print(f"3. Classes: {full_dataset.classes}")

print(f"\n🎯 Training completed successfully! NASNet is ready for dog emotion recognition.")
