In [None]:
# Install required packages
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install opencv-python-headless pillow pandas tqdm gdown albumentations matplotlib seaborn

# Check GPU and setup
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
import numpy as np
import os
import time
from tqdm import tqdm
import matplotlib.pyplot as plt

# Report the PyTorch build, then pick the device every later cell runs on.
cuda_ready = torch.cuda.is_available()

print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🚀 CUDA available: {cuda_ready}")

device = torch.device('cuda' if cuda_ready else 'cpu')
if cuda_ready:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"🎯 GPU: {gpu_name}")
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"💾 GPU Memory: {gpu_total_bytes / 1e9:.1f} GB")
else:
    print("⚠️ Using CPU - training will be slower")


In [None]:
# Download and extract dataset
import zipfile

DATASET_ID = "1ZAgz5u64i3LDbwMFpBXjzsKt6FrhNGdW"
DATASET_ZIP = "cropped_dataset_4k_face.zip"
EXTRACT_PATH = "data"

print("📥 Downloading dog emotion dataset...")
if not os.path.exists(DATASET_ZIP):
    !gdown {DATASET_ID} -O {DATASET_ZIP}
    print(f"✅ Dataset downloaded: {DATASET_ZIP}")
else:
    print(f"✅ Dataset already exists: {DATASET_ZIP}")

# Extract dataset
if not os.path.exists(EXTRACT_PATH):
    print("📂 Extracting dataset...")
    with zipfile.ZipFile(DATASET_ZIP, 'r') as zip_ref:
        zip_ref.extractall(EXTRACT_PATH)
    print("✅ Dataset extracted successfully")

# Dataset paths
data_root = os.path.join(EXTRACT_PATH, "cropped_dataset_4k_face", "Dog Emotion")
labels_csv = os.path.join(data_root, "labels.csv")

print(f"\n📂 Dataset structure:")
emotions = [d for d in os.listdir(data_root) if os.path.isdir(os.path.join(data_root, d))]
print(f"   Emotion classes: {emotions}")

for emotion in emotions:
    emotion_path = os.path.join(data_root, emotion)
    if os.path.isdir(emotion_path):
        count = len([f for f in os.listdir(emotion_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        print(f"     {emotion}: {count} images")

print(f"   Labels CSV: {'✅' if os.path.exists(labels_csv) else '❌'} {labels_csv}")


In [None]:
# Create Dataset class
class DogEmotionDataset(Dataset):
    """Image-classification dataset driven by a labels CSV.

    The CSV is read with pandas and must provide ``filename`` and ``label``
    columns. Each image is loaded from ``<root>/<label>/<filename>``. Label
    values are mapped to integer class indices in sorted order.

    Args:
        root: Directory containing one sub-directory per label.
        labels_csv: Path to the CSV with ``filename`` and ``label`` columns.
        transform: Optional callable applied to each loaded PIL image.
    """

    # Size of the black placeholder image substituted for unreadable files.
    FALLBACK_SIZE = (224, 224)

    def __init__(self, root, labels_csv, transform=None):
        self.root = root
        df = pd.read_csv(labels_csv)
        # One (filename, label) row per sample.
        self.items = df[['filename', 'label']].values
        unique_labels = sorted(df['label'].unique())
        self.label2index = {name: i for i, name in enumerate(unique_labels)}
        self.index2label = {i: name for name, i in self.label2index.items()}
        self.transform = transform
        print(f"📊 Dataset: {len(self.items)} samples")
        print(f"🏷️  Classes: {list(self.label2index.keys())}")

    def __len__(self):
        """Return the number of rows read from the labels CSV."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for sample ``idx``.

        If the image file cannot be opened or decoded, a warning naming the
        file is issued and a black placeholder image is used instead, so one
        bad file does not abort an epoch. Errors raised by ``transform`` are
        not suppressed.
        """
        fn, label_str = self.items[idx]
        label_idx = self.label2index[label_str]
        img_path = os.path.join(self.root, label_str, fn)

        try:
            # The context manager closes the file handle once the pixel data
            # has been copied by convert().
            with Image.open(img_path) as src:
                img = src.convert('RGB')
        except Exception as e:
            # Fallback for corrupted/missing images: keep training, but make
            # the substitution visible instead of silently feeding black
            # frames with a real label.
            import warnings
            warnings.warn(
                f"Could not load image {img_path!r} ({e}); substituting a black placeholder"
            )
            img = Image.new('RGB', self.FALLBACK_SIZE, (0, 0, 0))

        if self.transform is not None:
            img = self.transform(img)
        return img, label_idx

# Create transforms for ResNet (224x224 ImageNet standard)
# Both pipelines finish with the same tensor conversion + ImageNet normalisation.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

# Training: resize slightly larger, then random crop / flip / rotate / colour jitter.
_train_augmentations = [
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
]
train_transform = transforms.Compose(_train_augmentations + _to_normalized_tensor)

# Validation / inference: deterministic resize only.
val_transform = transforms.Compose([transforms.Resize((224, 224))] + _to_normalized_tensor)

# Create dataset
dataset = DogEmotionDataset(root=data_root, labels_csv=labels_csv, transform=train_transform)
EMOTION_CLASSES = list(dataset.label2index)
NUM_CLASSES = len(dataset.label2index)

print(f"\n✅ Dataset ready:")
for _line in (
    f"   Total samples: {len(dataset)}",
    f"   Number of classes: {NUM_CLASSES}",
    f"   Emotion classes: {EMOTION_CLASSES}",
):
    print(_line)


In [None]:
# Training function with custom checkpointing
def train_resnet_model(model_name, num_epochs=30, batch_size=16, learning_rate=1e-4,
                       save_from_epoch=10, checkpoint_every=5):
    """
    Fine-tune an ImageNet-pretrained ResNet on the notebook-level ``dataset``.

    Checkpointing strategy (with the default settings):
    - Save the best model (highest training accuracy) from epoch 10 onwards
    - Save a checkpoint every 5 epochs from epoch 10 onwards (10, 15, 20, 25, 30)

    Uses the notebook-level names ``dataset``, ``NUM_CLASSES`` and ``device``.
    Accuracy is measured on the training batches; no validation split is used.

    Args:
        model_name: 'resnet50' or 'resnet101'.
        num_epochs: Number of passes over ``dataset`` (must be >= 1).
        batch_size: Mini-batch size for the training DataLoader.
        learning_rate: Initial Adam learning rate.
        save_from_epoch: First (1-based) epoch at which anything is saved.
        checkpoint_every: A periodic checkpoint is written when the 1-based
            epoch number is a multiple of this value (must be >= 1).

    Returns:
        dict with keys 'model', 'train_losses', 'train_accuracies', 'best_acc',
        'best_models' (periodic checkpoints keyed 'epoch_<n>'),
        'checkpoint_dir' and 'total_time' (seconds).

    Raises:
        ValueError: for an unsupported ``model_name``, a non-positive
            ``num_epochs`` / ``checkpoint_every``, or an empty ``dataset``.
    """
    # Fail fast instead of hitting a ZeroDivisionError / IndexError after setup.
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")
    if checkpoint_every < 1:
        raise ValueError(f"checkpoint_every must be >= 1, got {checkpoint_every}")
    if len(dataset) == 0:
        raise ValueError("dataset is empty; nothing to train on")

    print(f"\n🚀 Training {model_name} for {num_epochs} epochs")
    print("="*60)

    # Create model
    if model_name == 'resnet50':
        model = models.resnet50(pretrained=True)
        print("🏗️  Created ResNet50 with ImageNet pretrained weights")
    elif model_name == 'resnet101':
        model = models.resnet101(pretrained=True)
        print("🏗️  Created ResNet101 with ImageNet pretrained weights")
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    # Modify final layer for our classes
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    model = model.to(device)

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"📊 Model stats:")
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,}")
    # 4 bytes per parameter (float32)
    print(f"   Model size: {total_params * 4 / (1024**2):.1f} MB")

    # Create data loader. Pinned host memory only helps host-to-GPU copies,
    # so it is requested only when training on CUDA.
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                             num_workers=2, pin_memory=(device.type == 'cuda'))

    # Training setup
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # Training tracking
    train_losses = []
    train_accuracies = []
    best_acc = 0.0
    best_saved = False  # whether best_model.pth was written at least once
    best_models = {}  # Periodic checkpoints, keyed 'epoch_<n>'

    # Create checkpoint directory
    checkpoint_dir = f"checkpoints_{model_name}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    def _snapshot(epoch_num, loss, acc):
        """Bundle the state saved in every checkpoint file."""
        return {
            'epoch': epoch_num,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': loss,
            'accuracy': acc,
            'model_name': model_name
        }

    print(f"🎯 Training configuration:")
    print(f"   Batch size: {batch_size}")
    print(f"   Learning rate: {learning_rate}")
    print(f"   Optimizer: Adam with weight decay 1e-4")
    print(f"   Scheduler: StepLR(step_size=10, gamma=0.1)")
    print(f"   Checkpoint dir: {checkpoint_dir}")
    print(f"   Device: {device}")

    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_num = epoch + 1  # 1-based epoch number used in logs and file names
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        # Training loop with progress bar
        pbar = tqdm(train_loader, desc=f"Epoch {epoch_num}/{num_epochs}")
        for batch_idx, (images, labels) in enumerate(pbar):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Update progress bar
            pbar.set_postfix({
                'Loss': f'{running_loss/(batch_idx+1):.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

        # Calculate epoch metrics (mean of per-batch losses; sample-level accuracy)
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = correct / total
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)

        # Record the LR this epoch actually trained with *before* stepping the
        # scheduler; reading it afterwards would report the next epoch's LR.
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        # Print epoch summary
        elapsed = time.time() - start_time
        eta = elapsed * (num_epochs - epoch_num) / epoch_num
        print(f"Epoch {epoch_num:2d}/{num_epochs} | "
              f"Loss: {epoch_loss:.4f} | "
              f"Acc: {epoch_acc:.4f} ({epoch_acc*100:.2f}%) | "
              f"LR: {current_lr:.2e} | "
              f"Time: {elapsed/60:.1f}m | ETA: {eta/60:.1f}m")

        # Checkpointing strategy: nothing is saved before `save_from_epoch`
        if epoch_num >= save_from_epoch:
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save best model (overwrites the previous best)
                best_path = os.path.join(checkpoint_dir, "best_model.pth")
                torch.save(_snapshot(epoch_num, epoch_loss, epoch_acc), best_path)
                best_saved = True
                print(f"✅ New best model saved: {best_path} (Acc: {epoch_acc:.4f})")

            # Periodic checkpoint
            if epoch_num % checkpoint_every == 0:
                checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch_num}.pth")
                torch.save(_snapshot(epoch_num, epoch_loss, epoch_acc), checkpoint_path)
                best_models[f"epoch_{epoch_num}"] = {
                    'path': checkpoint_path,
                    'accuracy': epoch_acc,
                    'loss': epoch_loss
                }
                print(f"📦 Checkpoint saved: {checkpoint_path}")

    total_time = time.time() - start_time
    print(f"\n🎉 Training completed!")
    print(f"   Total time: {total_time/60:.1f} minutes")
    print(f"   Best accuracy: {best_acc:.4f} ({best_acc*100:.2f}%)")
    print(f"   Final accuracy: {train_accuracies[-1]:.4f}")
    # Count best_model.pth only if it was actually written
    print(f"   Saved models: {len(best_models) + (1 if best_saved else 0)}")

    return {
        'model': model,
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'best_acc': best_acc,
        'best_models': best_models,
        'checkpoint_dir': checkpoint_dir,
        'total_time': total_time
    }


In [None]:
# Train ResNet50 with pretrained weights
print("🔥 Starting ResNet50 Training")
print("🎯 Strategy: 30 epochs, save best từ epoch 10, save mỗi 5 epochs từ epoch 10")

# batch_size is the knob to adjust when GPU memory is tight
resnet50_results = train_resnet_model('resnet50', num_epochs=30, batch_size=16, learning_rate=1e-4)

# Pull out the headline numbers once, then report them
r50_final_acc = resnet50_results['train_accuracies'][-1]
r50_best_acc = resnet50_results['best_acc']
r50_minutes = resnet50_results['total_time'] / 60

print(f"\n📊 ResNet50 Training Summary:")
print(f"   Final accuracy: {r50_final_acc:.4f}")
print(f"   Best accuracy: {r50_best_acc:.4f}")
print(f"   Training time: {r50_minutes:.1f} minutes")
print(f"   Checkpoint directory: {resnet50_results['checkpoint_dir']}")

# List saved models
print(f"\n💾 Saved ResNet50 models:")
print(f"   📄 best_model.pth - Best accuracy: {r50_best_acc:.4f}")
for tag, ckpt in resnet50_results['best_models'].items():
    print(f"   📄 model_{tag}.pth - Accuracy: {ckpt['accuracy']:.4f}, Loss: {ckpt['loss']:.4f}")


In [None]:
# Train ResNet101 with pretrained weights
print("🔥 Starting ResNet101 Training")
print("🎯 Strategy: 30 epochs, save best từ epoch 10, save mỗi 5 epochs từ epoch 10")

# Smaller batch size than ResNet50 because the model is larger
resnet101_results = train_resnet_model('resnet101', num_epochs=30, batch_size=12, learning_rate=1e-4)

# Pull out the headline numbers once, then report them
r101_final_acc = resnet101_results['train_accuracies'][-1]
r101_best_acc = resnet101_results['best_acc']
r101_minutes = resnet101_results['total_time'] / 60

print(f"\n📊 ResNet101 Training Summary:")
print(f"   Final accuracy: {r101_final_acc:.4f}")
print(f"   Best accuracy: {r101_best_acc:.4f}")
print(f"   Training time: {r101_minutes:.1f} minutes")
print(f"   Checkpoint directory: {resnet101_results['checkpoint_dir']}")

# List saved models
print(f"\n💾 Saved ResNet101 models:")
print(f"   📄 best_model.pth - Best accuracy: {r101_best_acc:.4f}")
for tag, ckpt in resnet101_results['best_models'].items():
    print(f"   📄 model_{tag}.pth - Accuracy: {ckpt['accuracy']:.4f}, Loss: {ckpt['loss']:.4f}")


In [None]:
# Visualize training results
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# X-axes are derived from the recorded histories rather than a hard-coded
# 30 epochs, so the plots stay valid if either run used a different length.
epochs_r50 = range(1, len(resnet50_results['train_losses']) + 1)
epochs_r101 = range(1, len(resnet101_results['train_losses']) + 1)

# Training Loss Comparison
ax1.plot(epochs_r50, resnet50_results['train_losses'], 'b-', label='ResNet50', linewidth=2)
ax1.plot(epochs_r101, resnet101_results['train_losses'], 'r-', label='ResNet101', linewidth=2)
ax1.set_title('Training Loss Comparison', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Training Accuracy Comparison
ax2.plot(epochs_r50, [acc*100 for acc in resnet50_results['train_accuracies']], 'b-', label='ResNet50', linewidth=2)
ax2.plot(epochs_r101, [acc*100 for acc in resnet101_results['train_accuracies']], 'r-', label='ResNet101', linewidth=2)
ax2.set_title('Training Accuracy Comparison', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_ylim(0, 100)

# Model Performance Bar Chart
# Deliberately NOT named `models`: that name is the `torchvision.models`
# import alias, and rebinding it to a list breaks every later
# `models.resnet50(...)` / `models.resnet101(...)` call.
model_names = ['ResNet50', 'ResNet101']
best_accs = [resnet50_results['best_acc']*100, resnet101_results['best_acc']*100]
final_accs = [resnet50_results['train_accuracies'][-1]*100, resnet101_results['train_accuracies'][-1]*100]

x = np.arange(len(model_names))
width = 0.35

ax3.bar(x - width/2, best_accs, width, label='Best Accuracy', alpha=0.8, color='green')
ax3.bar(x + width/2, final_accs, width, label='Final Accuracy', alpha=0.8, color='orange')
ax3.set_title('Best vs Final Accuracy', fontsize=14, fontweight='bold')
ax3.set_ylabel('Accuracy (%)')
ax3.set_xticks(x)
ax3.set_xticklabels(model_names)
ax3.legend()
ax3.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for i, (best, final) in enumerate(zip(best_accs, final_accs)):
    ax3.text(i - width/2, best + 1, f'{best:.1f}%', ha='center', va='bottom', fontweight='bold')
    ax3.text(i + width/2, final + 1, f'{final:.1f}%', ha='center', va='bottom', fontweight='bold')

# Training Time Comparison
times = [resnet50_results['total_time']/60, resnet101_results['total_time']/60]
colors = ['skyblue', 'lightcoral']
bars = ax4.bar(model_names, times, color=colors, alpha=0.8)
ax4.set_title('Training Time Comparison', fontsize=14, fontweight='bold')
ax4.set_ylabel('Time (minutes)')
ax4.grid(True, alpha=0.3, axis='y')

# Add value labels on bars (`time_val`, not `time`, to keep the time module intact)
for bar, time_val in zip(bars, times):
    height = bar.get_height()
    ax4.text(bar.get_x() + bar.get_width()/2., height + 1, f'{time_val:.1f}m',
             ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

# Print comprehensive comparison
# Gather the compared quantities once so each table row reads cleanly.
r50_best, r101_best = resnet50_results['best_acc'], resnet101_results['best_acc']
r50_final, r101_final = resnet50_results['train_accuracies'][-1], resnet101_results['train_accuracies'][-1]
r50_loss, r101_loss = resnet50_results['train_losses'][-1], resnet101_results['train_losses'][-1]
r50_secs, r101_secs = resnet50_results['total_time'], resnet101_results['total_time']

# Winners: higher is better for accuracy, lower is better for loss and time.
best_winner = 'ResNet101' if r101_best > r50_best else 'ResNet50'
final_winner = 'ResNet101' if r101_final > r50_final else 'ResNet50'
loss_winner = 'ResNet101' if r101_loss < r50_loss else 'ResNet50'
time_winner = 'ResNet50' if r50_secs < r101_secs else 'ResNet101'

print("🏆 TRAINING RESULTS COMPARISON")
print("="*60)
print(f"{'Metric':<20} {'ResNet50':<15} {'ResNet101':<15} {'Winner'}")
print("-"*60)
print(f"{'Best Accuracy':<20} {r50_best*100:<14.2f}% {r101_best*100:<14.2f}% {best_winner}")
print(f"{'Final Accuracy':<20} {r50_final*100:<14.2f}% {r101_final*100:<14.2f}% {final_winner}")
print(f"{'Final Loss':<20} {r50_loss:<14.4f} {r101_loss:<14.4f} {loss_winner}")
print(f"{'Training Time':<20} {r50_secs/60:<14.1f}m {r101_secs/60:<14.1f}m {time_winner}")

# Calculate improvement (percentage points) and extra wall-clock cost (minutes)
acc_improvement = (r101_best - r50_best) * 100
time_overhead = (r101_secs - r50_secs) / 60

print(f"\n📈 Performance Analysis:")
print(f"   ResNet101 vs ResNet50 accuracy improvement: {acc_improvement:+.2f}%")
print(f"   ResNet101 training time overhead: {time_overhead:+.1f} minutes")
print(f"   Accuracy per minute (ResNet50): {r50_best*100/(r50_secs/60):.2f}%/min")
print(f"   Accuracy per minute (ResNet101): {r101_best*100/(r101_secs/60):.2f}%/min")


In [None]:
# Test models on sample images
def test_model_predictions(model_path, model_name, num_classes=4):
    """Load a saved checkpoint and check its prediction on one image per class.

    For every class in the notebook-level ``EMOTION_CLASSES``, the
    alphabetically first image file in ``<data_root>/<class>`` is passed
    through ``val_transform`` and the model, and the outcome is printed.

    Args:
        model_path: Path to a checkpoint containing a 'model_state_dict' entry.
        model_name: 'resnet50' builds a ResNet50; any other value builds a
            ResNet101.
        num_classes: Output size of the replaced final ``fc`` layer.

    Returns:
        Fraction of sample images predicted correctly (0 when no sample image
        was found). Images that fail to load or process are reported and
        count as incorrect.
    """
    # Imported at function scope so this does not depend on the notebook-level
    # name `models`, which is easy to rebind by accident in another cell.
    import torchvision.models as tv_models

    print(f"\n🧪 Testing {model_name} predictions...")

    # Load model (architecture only; weights come from the checkpoint)
    if model_name == 'resnet50':
        model = tv_models.resnet50(pretrained=False)
    else:
        model = tv_models.resnet101(pretrained=False)

    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # Pick one sample image per class. Sorting makes the choice reproducible
    # (os.listdir order is arbitrary).
    # NOTE(review): these files presumably belong to the training set, so this
    # is a smoke test rather than a held-out evaluation.
    sample_images = {}
    for emotion in EMOTION_CLASSES:
        emotion_path = os.path.join(data_root, emotion)
        if os.path.isdir(emotion_path):
            image_files = sorted(
                f for f in os.listdir(emotion_path)
                if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            )
            if image_files:
                sample_images[emotion] = os.path.join(emotion_path, image_files[0])

    print(f"📷 Testing on {len(sample_images)} sample images:")
    correct_predictions = 0
    total_predictions = len(sample_images)

    for true_emotion, image_path in sample_images.items():
        try:
            # Load and preprocess image
            with Image.open(image_path) as src:
                img = src.convert('RGB')
            img_tensor = val_transform(img).unsqueeze(0).to(device)

            # Predict
            with torch.no_grad():
                outputs = model(img_tensor)
                probabilities = torch.softmax(outputs, dim=1)
                _, predicted_idx = torch.max(outputs, 1)
                predicted_emotion = EMOTION_CLASSES[predicted_idx.item()]
                confidence = probabilities[0][predicted_idx].item()

            correct = predicted_emotion == true_emotion
            if correct:
                correct_predictions += 1

            print(f"   🖼️  {os.path.basename(image_path)}: {true_emotion} → {predicted_emotion} ({confidence:.3f}) {'✅' if correct else '❌'}")

        except Exception as e:
            # Best effort: report the failure and continue with the next sample
            print(f"   ❌ Error processing {image_path}: {e}")

    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    print(f"   🎯 Sample accuracy: {correct_predictions}/{total_predictions} ({accuracy*100:.1f}%)")

    return accuracy

# Locate the best checkpoint written for each architecture
resnet50_best_path = os.path.join(resnet50_results['checkpoint_dir'], 'best_model.pth')
resnet101_best_path = os.path.join(resnet101_results['checkpoint_dir'], 'best_model.pth')

# Test ResNet50 best model (skipped when no best checkpoint exists)
if os.path.exists(resnet50_best_path):
    resnet50_sample_acc = test_model_predictions(
        model_path=resnet50_best_path, model_name='resnet50', num_classes=NUM_CLASSES
    )

# Test ResNet101 best model (skipped when no best checkpoint exists)
if os.path.exists(resnet101_best_path):
    resnet101_sample_acc = test_model_predictions(
        model_path=resnet101_best_path, model_name='resnet101', num_classes=NUM_CLASSES
    )


In [None]:
# Final Summary
print("🎉 TRAINING COMPLETE - FINAL SUMMARY")
print("="*70)

print(f"\n📊 Dataset Information:")
print(f"   Total samples: {len(dataset):,}")
print(f"   Emotion classes: {EMOTION_CLASSES}")
print(f"   Input resolution: 224x224 (ImageNet standard)")
print(f"   Data augmentation: ✅ (rotation, flip, color jitter)")

print(f"\n🏗️  Model Architectures:")
# Count the parameters of the networks that were actually trained (returned
# under 'model' by train_resnet_model). This reflects the replaced
# classification head, avoids constructing two throwaway networks, and does
# not depend on the notebook-level name `models`.
resnet50_params = sum(p.numel() for p in resnet50_results['model'].parameters())
resnet101_params = sum(p.numel() for p in resnet101_results['model'].parameters())
# Size estimate assumes 4 bytes per parameter (float32)
print(f"   ResNet50:  {resnet50_params:,} parameters (~{resnet50_params*4/(1024**2):.0f} MB)")
print(f"   ResNet101: {resnet101_params:,} parameters (~{resnet101_params*4/(1024**2):.0f} MB)")
print(f"   Pretrained: ✅ ImageNet weights")
print(f"   Fine-tuning: Full model training")

print(f"\n🏆 Training Results:")
print(f"   {'Model':<10} {'Best Acc':<10} {'Final Acc':<11} {'Final Loss':<11} {'Time':<8} {'Models Saved'}")
print(f"   {'-'*10} {'-'*10} {'-'*11} {'-'*11} {'-'*8} {'-'*12}")
# One table row per architecture; "+1" accounts for best_model.pth
for arch_label, arch_results in (('ResNet50', resnet50_results), ('ResNet101', resnet101_results)):
    print(f"   {arch_label:<10} {arch_results['best_acc']*100:<9.2f}% {arch_results['train_accuracies'][-1]*100:<10.2f}% {arch_results['train_losses'][-1]:<10.4f} {arch_results['total_time']/60:<7.1f}m {len(arch_results['best_models'])+1}")

# List all saved models, one section per architecture
print(f"\n💾 Saved Models Summary:")
for arch_label, arch_results in (("ResNet50", resnet50_results), ("ResNet101", resnet101_results)):
    print(f"   📁 {arch_label} checkpoints ({arch_results['checkpoint_dir']}):")
    print(f"      📄 best_model.pth (Accuracy: {arch_results['best_acc']*100:.2f}%)")
    for tag, ckpt in arch_results['best_models'].items():
        print(f"      📄 model_{tag}.pth (Accuracy: {ckpt['accuracy']*100:.2f}%)")

# Total models count: periodic checkpoints plus one best model per architecture
resnet50_model_count = len(resnet50_results['best_models']) + 1
resnet101_model_count = len(resnet101_results['best_models']) + 1
total_models = resnet50_model_count + resnet101_model_count
print(f"\n📊 Total Models Generated: {total_models}")
print(f"   ResNet50: {resnet50_model_count} models")
print(f"   ResNet101: {resnet101_model_count} models")

# Download models (for Colab); outside Colab the import fails and the local
# checkpoint directories are listed instead.
try:
    from google.colab import files
    print(f"\n📥 Downloading trained models...")

    # Offer each architecture's best checkpoint, if it was written
    for arch_label, ckpt_path in (("ResNet50", resnet50_best_path), ("ResNet101", resnet101_best_path)):
        if os.path.exists(ckpt_path):
            print(f"📦 Downloading {arch_label} best model...")
            files.download(ckpt_path)

    print("✅ Model downloads completed!")

except ImportError:
    print(f"\n💾 Models saved locally:")
    for arch_label, arch_results in (("ResNet50", resnet50_results), ("ResNet101", resnet101_results)):
        print(f"   {arch_label}: {arch_results['checkpoint_dir']}/")

# Print a copy-pasteable inference snippet. The class count and class list are
# taken from the dataset used in this notebook (NUM_CLASSES / EMOTION_CLASSES)
# so the snippet always matches the checkpoints that were just trained.
print(f"\n🔄 Usage Example:")
print("```python")
print("import torch")
print("import torchvision.models as models")
print("import torchvision.transforms as transforms")
print("")
print("# Load ResNet50 model")
print("model = models.resnet50(pretrained=False)")
print(f"model.fc = torch.nn.Linear(model.fc.in_features, {NUM_CLASSES})")
# map_location lets a checkpoint saved from CUDA tensors load on a CPU-only machine
print("checkpoint = torch.load('best_model.pth', map_location='cpu')")
print("model.load_state_dict(checkpoint['model_state_dict'])")
print("model.eval()")
print("")
print("# Transform for inference")
print("transform = transforms.Compose([")
print("    transforms.Resize((224, 224)),")
print("    transforms.ToTensor(),")
print("    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])")
print("])")
print("")
print("# Predict emotion")
# Index order must match the label -> index mapping used during training
print(f"emotion_classes = {EMOTION_CLASSES!r}")
print("img_tensor = transform(image).unsqueeze(0)")
print("with torch.no_grad():")
print("    outputs = model(img_tensor)")
print("    predicted_idx = outputs.argmax(dim=1)")
print("    emotion = emotion_classes[predicted_idx.item()]")
print("```")

print(f"\n" + "="*70)
print("🎉 ResNet Pretrained Training Experiment Completed Successfully! 🎯")
