In [None]:
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image, ImageFilter
import torch.nn.functional as F
from collections import Counter

# Configuration - Key improvements
# Paths are resolved relative to DATA_DIR; train.csv is read below and must
# provide the 'ID' and 'TARGET' columns.
DATA_DIR = "./"
TRAIN_DIR = os.path.join(DATA_DIR, "train/train/")
TEST_DIR = os.path.join(DATA_DIR, "test/test/")
CSV_PATH = os.path.join(DATA_DIR, "train.csv")

BATCH_SIZE = 32  # Increased for better GPU utilization
# NOTE: EPOCHS and LR are not read by train_models() in this cell, which uses
# the per-model 'epochs' / 'lr' values from its own model_configs instead.
EPOCHS = 20  # More epochs with better early stopping
LR = 5e-5  # Slightly lower learning rate
N_SPLITS = 7  # More folds for better validation
IMG_SIZE = 384  # Larger image size for better feature extraction
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

# Set seeds (torch and NumPy global generators only)
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Using device: {DEVICE}")

# Load data
df = pd.read_csv(CSV_PATH)
# Encode the 'TARGET' column as integer class indices in 'label_idx'.
le = LabelEncoder()
df['label_idx'] = le.fit_transform(df['TARGET'])
# 'image_id' is joined onto the image directory by the dataset classes, so
# 'ID' presumably holds image file names -- TODO confirm against train.csv.
df['image_id'] = df['ID']
num_classes = df['label_idx'].nunique()

print(f"Number of classes: {num_classes}")
print("Class distribution:")
print(df['TARGET'].value_counts())

# Enhanced dataset with multiple augmentation strategies
class AdvancedDataset(Dataset):
    """Image classification dataset backed by a DataFrame.

    Each row must provide an ``image_id`` (joined onto ``img_dir`` to form the
    image path) and a ``label_idx`` (returned as the label).

    Args:
        df: DataFrame with ``image_id`` and ``label_idx`` columns. The index
            is reset so positional access matches ``__len__``.
        img_dir: Directory containing the image files.
        transform: Optional callable applied to the PIL image.
        is_training: When True, a light Gaussian blur is applied to roughly
            10% of the samples before ``transform``.
    """

    def __init__(self, df, img_dir, transform=None, is_training=False):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.is_training = is_training

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        An image that cannot be read is reported on stdout and replaced by a
        gray 224x224 placeholder so a single bad file does not abort training.
        """
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['image_id'])

        try:
            # The context manager closes the file handle deterministically;
            # convert() returns an independent, fully loaded copy.
            with Image.open(img_path) as raw:
                image = raw.convert("RGB")

            # Apply subtle preprocessing
            if self.is_training and np.random.random() < 0.1:
                image = image.filter(ImageFilter.GaussianBlur(radius=0.5))

        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            image = Image.new('RGB', (224, 224), color=(128, 128, 128))

        if self.transform:
            image = self.transform(image)

        label = row['label_idx']
        return image, label

# More aggressive and diverse augmentations
def get_train_transforms(img_size=384):
    """Build the training augmentation pipeline.

    The image is resized to 1.1x ``img_size`` and randomly cropped back to
    ``img_size``, then flipped, rotated, color-jittered, affine- and
    perspective-distorted. After conversion to a tensor it is normalized with
    the mean/std constants below and random erasing is applied.
    """
    enlarged = int(img_size * 1.1)

    # Operations applied before the ToTensor conversion.
    pil_stage = [
        transforms.Resize((enlarged, enlarged)),
        transforms.RandomCrop((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
        transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.85, 1.15)),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
    ]

    # Conversion to tensor and the operations applied after it.
    tensor_stage = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
    ]

    return transforms.Compose(pil_stage + tensor_stage)

def get_val_transforms(img_size=384):
    """Build the deterministic validation pipeline: resize, tensor, normalize."""
    target_shape = (img_size, img_size)
    steps = [
        transforms.Resize(target_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

# Test Time Augmentation transforms
def get_tta_transforms(img_size=384):
    """Return the three test-time-augmentation pipelines.

    In order: plain resize, resize + always-applied horizontal flip, and
    resize to 1.05x followed by a center crop. Every pipeline ends with
    tensor conversion and the same normalization.
    """
    def _tensor_tail():
        # Fresh instances per pipeline, as in a hand-written Compose list.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    zoomed = int(img_size * 1.05)
    heads = (
        [transforms.Resize((img_size, img_size))],
        [
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=1.0),
        ],
        [
            transforms.Resize((zoomed, zoomed)),
            transforms.CenterCrop((img_size, img_size)),
        ],
    )
    return [transforms.Compose(head + _tensor_tail()) for head in heads]

# Enhanced model with better architectures
def get_model(model_name='efficientnet', num_classes=20):
    """Create a pretrained backbone whose final layer outputs ``num_classes`` logits.

    Recognized names are 'efficientnet' (EfficientNet-B3), 'resnext101'
    (ResNeXt101 32x8d) and 'densenet' (DenseNet161). Any other name yields a
    ResNeXt50 with a dropout + linear head.
    """
    if model_name == 'efficientnet':
        from torchvision.models import efficientnet_b3
        net = efficientnet_b3(pretrained=True)
        head_in = net.classifier[1].in_features
        net.classifier[1] = nn.Linear(head_in, num_classes)
        print("✅ Using EfficientNet-B3")
        return net

    if model_name == 'resnext101':
        net = models.resnext101_32x8d(pretrained=True)
        head_in = net.fc.in_features
        net.fc = nn.Linear(head_in, num_classes)
        print("✅ Using ResNeXt101")
        return net

    if model_name == 'densenet':
        net = models.densenet161(pretrained=True)
        head_in = net.classifier.in_features
        net.classifier = nn.Linear(head_in, num_classes)
        print("✅ Using DenseNet161")
        return net

    # Fallback for unrecognized names: ResNeXt50 with a dropout head.
    net = models.resnext50_32x4d(pretrained=True)
    head_in = net.fc.in_features
    net.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(head_in, num_classes)
    )
    print("✅ Using Enhanced ResNeXt50")
    return net

# Label smoothing loss
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy with label smoothing and optional per-class weights.

    The per-sample loss is ``(1 - smoothing) * NLL(target)`` plus
    ``smoothing`` times the mean negative log-probability over all classes.
    When ``weight`` is given, each sample's loss is multiplied by
    ``weight[target]`` before the batch mean is taken.
    """

    def __init__(self, smoothing=0.1, weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.weight = weight

    def forward(self, x, target):
        """Return the scalar mean loss for logits ``x`` and class indices ``target``."""
        log_probs = F.log_softmax(x, dim=-1)

        # Negative log-likelihood of the true class, one value per sample.
        picked = log_probs.gather(dim=-1, index=target.unsqueeze(1))
        target_term = -picked
        target_term = target_term.squeeze(1)

        # Negative mean log-probability across all classes.
        uniform_term = -log_probs.mean(dim=-1)

        keep = 1. - self.smoothing
        per_sample = keep * target_term + self.smoothing * uniform_term

        if self.weight is not None:
            per_sample = per_sample * self.weight[target]

        return per_sample.mean()

# Calculate class weights with smoothing
def get_class_weights(df, smoothing=0.1, num_classes=None):
    """Compute smoothed, clipped inverse-frequency class weights.

    The result is indexed by class id: element ``i`` is the weight for label
    ``i``. Classes with no samples in ``df`` still get an entry (count 0), so
    the vector stays aligned with the label indices even when a fold is
    missing a class.

    Args:
        df: DataFrame with an integer ``label_idx`` column.
        smoothing: Fraction of the average class size added to every count to
            avoid extreme weights.
        num_classes: Total number of classes. Defaults to the largest label
            in ``df`` plus one.

    Returns:
        A float32 tensor of per-class weights clipped to ``[0.5, 3.0]``.
    """
    labels = df['label_idx'].to_numpy().astype(np.int64)
    if num_classes is None:
        num_classes = int(labels.max()) + 1 if labels.size else 0

    # bincount keeps a slot for every class id, including absent ones.
    class_counts = np.bincount(labels, minlength=num_classes)
    if len(class_counts) == 0:
        return torch.zeros(0, dtype=torch.float32)

    total_samples = len(df)
    # Apply smoothing to avoid extreme weights
    smoothed_counts = class_counts + smoothing * total_samples / len(class_counts)
    with np.errstate(divide='ignore'):
        class_weights = total_samples / (len(smoothed_counts) * smoothed_counts)
    # Cap extreme weights
    class_weights = np.clip(class_weights, 0.5, 3.0)
    return torch.tensor(class_weights, dtype=torch.float32)

# Training function with gradient accumulation
def train_one_epoch(model, loader, optimizer, criterion, device, accumulation_steps=1):
    """Train ``model`` for one pass over ``loader`` with gradient accumulation.

    Gradients are accumulated over ``accumulation_steps`` batches, clipped to
    a max norm of 1.0 and then applied. The last batch of the epoch always
    triggers an optimizer step, so a trailing partial group is not discarded
    when ``len(loader)`` is not a multiple of ``accumulation_steps``.

    Args:
        model: Network to train; switched to train mode.
        loader: Sized iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to.
        accumulation_steps: Number of batches per optimizer step.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    num_batches = len(loader)

    pbar = tqdm(loader, desc="Training")
    optimizer.zero_grad()

    for batch_idx, (imgs, labels) in enumerate(pbar):
        imgs, labels = imgs.to(device), labels.to(device)

        outputs = model(imgs)
        # Scale so the accumulated gradient averages over the group.
        loss = criterion(outputs, labels) / accumulation_steps
        loss.backward()

        group_complete = (batch_idx + 1) % accumulation_steps == 0
        last_batch = (batch_idx + 1) == num_batches
        # Step on the final batch too; otherwise leftover gradients would be
        # wiped by the zero_grad() at the start of the next epoch. A partial
        # final group keeps the 1/accumulation_steps scaling, so its update
        # is proportionally smaller.
        if group_complete or last_batch:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

        # Undo the accumulation scaling when reporting the loss.
        total_loss += loss.item() * imgs.size(0) * accumulation_steps
        preds = outputs.argmax(1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)

        if batch_idx % 50 == 0:
            pbar.set_postfix({
                'Loss': f'{loss.item() * accumulation_steps:.4f}',
                'Acc': f'{correct/total:.4f}',
                'Batch': f'{batch_idx}/{num_batches}'
            })

    return total_loss / total, correct / total

# Enhanced validation with TTA
def validate_with_tta(model, dataset, device, batch_size=32):
    """Evaluate ``model`` on ``dataset`` with test-time augmentation.

    Each image is read from ``dataset.img_dir`` / ``image_id``, passed through
    every pipeline from ``get_tta_transforms(IMG_SIZE)``, and the softmax
    outputs are averaged before taking the argmax.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataset: Object exposing ``df`` (with ``image_id`` and ``label_idx``
            columns) and ``img_dir``.
        device: Device the image tensors are moved to.
        batch_size: Accepted for interface compatibility; images are
            processed one at a time, so it is not used.

    Returns:
        ``(micro_f1, accuracy)`` of the averaged predictions.
    """
    model.eval()
    tta_transforms = get_tta_transforms(IMG_SIZE)

    all_preds = []
    labels_all = []

    samples = zip(dataset.df['image_id'], dataset.df['label_idx'])

    with torch.no_grad():
        for image_id, label in tqdm(samples, total=len(dataset.df), desc="TTA Validation"):
            img_path = os.path.join(dataset.img_dir, image_id)

            try:
                with Image.open(img_path) as raw:
                    image = raw.convert("RGB")
            except Exception as e:
                # Catch Exception rather than everything, so Ctrl-C still
                # interrupts. Unreadable images get a gray placeholder.
                print(f"Error loading {img_path}: {e}")
                image = Image.new('RGB', (224, 224), color=(128, 128, 128))

            tta_preds = []
            for transform in tta_transforms:
                img_tensor = transform(image).unsqueeze(0).to(device)
                output = model(img_tensor)
                pred = F.softmax(output, dim=1)
                tta_preds.append(pred.cpu().numpy())

            # Average TTA predictions
            avg_pred = np.mean(tta_preds, axis=0)
            all_preds.append(avg_pred)
            labels_all.append(label)

    all_preds = np.concatenate(all_preds, axis=0)
    pred_classes = np.argmax(all_preds, axis=1)

    f1 = f1_score(labels_all, pred_classes, average="micro")
    accuracy = (pred_classes == np.array(labels_all)).mean()

    return f1, accuracy

# Regular validation function
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(mean_loss, accuracy, micro_f1)`` over all samples in ``loader``.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    predicted = []
    expected = []

    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="Validating", leave=False):
            imgs = imgs.to(device)
            labels = labels.to(device)
            logits = model(imgs)
            batch_n = imgs.size(0)

            # criterion returns a batch mean; weight it by the batch size.
            loss_sum += criterion(logits, labels).item() * batch_n
            batch_preds = logits.argmax(1)
            n_correct += (batch_preds == labels).sum().item()
            n_seen += batch_n

            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(labels.cpu().numpy())

    micro_f1 = f1_score(expected, predicted, average="micro")
    return loss_sum / n_seen, n_correct / n_seen, micro_f1

# Main training loop with multiple models
def train_models():
    """Train every configured architecture with stratified K-fold CV.

    For each entry in ``model_configs`` the global ``df`` is split into
    ``N_SPLITS`` stratified folds. Each fold trains a fresh model with AdamW,
    cosine annealing with warm restarts, and a class-weighted label-smoothing
    loss. The weights with the best validation F1 are saved to
    ``{model_name}_fold{fold}_best.pth`` in the working directory.

    Returns:
        dict mapping model name to ``{'scores', 'mean', 'std'}``, where
        ``scores`` holds the best F1 of each fold.
    """
    # Use better performing models
    # Per-model learning rate and epoch budget; the module-level LR and EPOCHS
    # constants are not used here.
    model_configs = [
        {'name': 'efficientnet', 'lr': 3e-5, 'epochs': 25},
        {'name': 'resnext101', 'lr': 2e-5, 'epochs': 20},
        {'name': 'densenet', 'lr': 4e-5, 'epochs': 22},
    ]
    
    all_results = {}
    
    for config in model_configs:
        model_name = config['name']
        lr = config['lr']
        epochs = config['epochs']
        
        print(f"\n{'='*60}")
        print(f"Training {model_name}")
        print(f"{'='*60}")
        
        # Same seed for every architecture, so all models see identical folds.
        skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
        fold_scores = []
        
        for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['label_idx'])):
            print(f"\n🔥 Fold {fold+1}/{N_SPLITS}")
            
            train_df = df.iloc[train_idx].reset_index(drop=True)
            val_df = df.iloc[val_idx].reset_index(drop=True)
            
            # Create datasets
            # Both splits read from TRAIN_DIR; only the training split is augmented.
            train_dataset = AdvancedDataset(train_df, TRAIN_DIR, get_train_transforms(IMG_SIZE), is_training=True)
            val_dataset = AdvancedDataset(val_df, TRAIN_DIR, get_val_transforms(IMG_SIZE))
            
            train_loader = DataLoader(
                train_dataset, 
                batch_size=BATCH_SIZE, 
                shuffle=True,
                num_workers=4,
                pin_memory=True,
                drop_last=True
            )
            val_loader = DataLoader(
                val_dataset, 
                batch_size=BATCH_SIZE, 
                shuffle=False,
                num_workers=4,
                pin_memory=True
            )
            
            # Model setup
            model = get_model(model_name, num_classes).to(DEVICE)
            
            # Enhanced optimizer with different parameters
            # Only the weight decay differs between the two branches.
            if model_name == 'efficientnet':
                optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
            else:
                optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=5e-5)
            
            # Cosine annealing with warm restarts
            # Stepped once per epoch below, so T_0 / T_mult are in epochs.
            scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=5, T_mult=2, eta_min=1e-7
            )
            
            # Enhanced loss
            # Class weights come from the training split only.
            class_weights = get_class_weights(train_df).to(DEVICE)
            criterion = LabelSmoothingCrossEntropy(smoothing=0.1, weight=class_weights)
            
            # Training loop
            best_f1 = 0
            patience = 7
            patience_counter = 0
            
            for epoch in range(epochs):
                print(f"\nEpoch {epoch+1}/{epochs}")
                
                # Train with gradient accumulation if batch size is small
                # Accumulates up to an effective batch of 64 samples.
                accumulation_steps = max(1, 64 // BATCH_SIZE)
                train_loss, train_acc = train_one_epoch(
                    model, train_loader, optimizer, criterion, DEVICE, accumulation_steps
                )
                
                # Validate
                val_loss, val_acc, val_f1 = validate(model, val_loader, criterion, DEVICE)
                
                scheduler.step()
                
                print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
                print(f"Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
                
                # Enhanced TTA validation for best models
                if val_f1 > best_f1:
                    best_f1 = val_f1
                    patience_counter = 0
                    
                    # Use TTA for final validation on promising models
                    if val_f1 > 0.88:
                        tta_f1, tta_acc = validate_with_tta(model, val_dataset, DEVICE)
                        print(f"TTA - Acc: {tta_acc:.4f}, F1: {tta_f1:.4f}")
                        # NOTE(review): best_f1 may now hold a TTA score, so
                        # later epochs compare their plain validation F1
                        # against it -- confirm this mixing is intended.
                        if tta_f1 > best_f1:
                            best_f1 = tta_f1
                    
                    torch.save(model.state_dict(), f"{model_name}_fold{fold}_best.pth")
                    print(f"✅ New best F1: {best_f1:.4f}")
                else:
                    # No improvement: stop after `patience` such epochs in a row.
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(f"⏰ Early stopping at epoch {epoch+1}")
                        break
            
            fold_scores.append(best_f1)
            print(f"\n📊 Fold {fold+1} completed - Best F1: {best_f1:.4f}")
        
        avg_score = np.mean(fold_scores)
        std_score = np.std(fold_scores)
        all_results[model_name] = {
            'scores': fold_scores,
            'mean': avg_score,
            'std': std_score
        }
        
        print(f"\n🎯 {model_name} Results:")
        print(f"Fold scores: {[f'{score:.4f}' for score in fold_scores]}")
        print(f"Average F1: {avg_score:.4f} ± {std_score:.4f}")
    
    return all_results

# Ensemble prediction function
def create_ensemble_predictions():
    """Placeholder for weighted-ensemble inference over the test images.

    Lists ``TEST_DIR`` and defines the per-model ensemble weights, but does
    not yet load any model or write predictions.
    """
    print("\n🔮 Creating ensemble predictions...")

    # Per-architecture ensemble weights.
    ensemble_members = [
        {'name': 'efficientnet', 'weight': 0.4},
        {'name': 'resnext101', 'weight': 0.35},
        {'name': 'densenet', 'weight': 0.25}
    ]

    # Load test data info (you'll need to implement this based on test structure)
    test_image_names = os.listdir(TEST_DIR)

    # Implementation would continue here for ensemble predictions...
    print("Ensemble prediction framework ready!")

if __name__ == "__main__":
    # Script entry point: cross-validate all models, summarize, then run the
    # ensemble placeholder.
    print("🚀 Starting enhanced training...")
    results = train_models()

    print("\n🎉 Training completed!")
    print("\n📊 Final Results Summary:")
    for arch in results:
        stats = results[arch]
        print(f"{arch}: {stats['mean']:.4f} ± {stats['std']:.4f}")

    create_ensemble_predictions()

In [None]:
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image, ImageFilter
import torch.nn.functional as F
from collections import Counter
import gc

# Configuration optimized for RTX 4060 (8GB VRAM)
# Paths are resolved relative to DATA_DIR; train.csv is read below and must
# provide the 'ID' and 'TARGET' columns.
DATA_DIR = "./"
TRAIN_DIR = os.path.join(DATA_DIR, "train/train/")
TEST_DIR = os.path.join(DATA_DIR, "test/test/")
CSV_PATH = os.path.join(DATA_DIR, "train.csv")

# RTX 4060 Optimized Settings
BATCH_SIZE = 16  # Conservative for 8GB VRAM
# NOTE: EPOCHS and LR are not read by train_models() in this cell, which uses
# the per-model 'epochs' / 'lr' values from its own model_configs instead.
EPOCHS = 18
LR = 3e-5
N_SPLITS = 5  # Reduced for faster training
IMG_SIZE = 320  # Balance between quality and memory
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

# cuDNN settings
# NOTE(review): these flags are documented as kernel auto-tuning / determinism
# switches rather than memory savers; with deterministic=False, GPU runs may
# not be exactly reproducible despite the seeds below -- confirm intent.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# Create model save directories
MODEL_SAVE_DIR = "saved_models"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# One sub-folder per architecture name used by train_models().
for model_type in ['efficientnet', 'resnext101', 'densenet']:
    model_dir = os.path.join(MODEL_SAVE_DIR, model_type)
    os.makedirs(model_dir, exist_ok=True)

# Set seeds (torch and NumPy global generators only)
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Using device: {DEVICE}")
if DEVICE == "cuda":
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Load data
df = pd.read_csv(CSV_PATH)
# Encode the 'TARGET' column as integer class indices in 'label_idx'.
le = LabelEncoder()
df['label_idx'] = le.fit_transform(df['TARGET'])
# 'image_id' is joined onto the image directory by the dataset class, so
# 'ID' presumably holds image file names -- TODO confirm against train.csv.
df['image_id'] = df['ID']
num_classes = df['label_idx'].nunique()

print(f"Number of classes: {num_classes}")
print("Class distribution:")
print(df['TARGET'].value_counts())

# Memory-efficient dataset
class MemoryEfficientDataset(Dataset):
    """DataFrame-backed image dataset that closes each file right after use.

    Rows must provide ``image_id`` (joined onto ``img_dir``) and
    ``label_idx``. In training mode roughly 5% of the images get a light
    Gaussian blur before ``transform`` is applied.
    """

    def __init__(self, df, img_dir, transform=None, is_training=False):
        self.img_dir = img_dir
        self.transform = transform
        self.is_training = is_training
        # Reset the index so positional lookups line up with __len__.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        """Return the number of samples."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        Any failure while loading or transforming is printed and the sample
        is replaced by a (transformed) gray 224x224 placeholder.
        """
        record = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, record['image_id'])

        try:
            # Load and process inside the context so the file closes promptly.
            with Image.open(img_path) as handle:
                sample = handle.convert("RGB")

                # Occasional mild blur, training only.
                if self.is_training:
                    if np.random.random() < 0.05:
                        sample = sample.filter(ImageFilter.GaussianBlur(radius=0.3))

                if self.transform:
                    sample = self.transform(sample)

        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            # Minimal fallback image.
            sample = Image.new('RGB', (224, 224), color=(128, 128, 128))
            if self.transform:
                sample = self.transform(sample)

        return sample, record['label_idx']

# Optimized augmentations for 4060
def get_train_transforms(img_size=320):
    """Build the training augmentation pipeline.

    Resizes to 1.05x ``img_size``, random-crops back to ``img_size``, applies
    flips, rotation, color jitter and a mild affine distortion, then converts
    to a normalized tensor and applies random erasing.
    """
    enlarged = int(img_size * 1.05)

    # Operations applied before the ToTensor conversion.
    pil_stage = [
        transforms.Resize((enlarged, enlarged)),
        transforms.RandomCrop((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.25),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    ]

    # Conversion to tensor and the operations applied after it.
    tensor_stage = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.2, scale=(0.02, 0.25)),
    ]

    return transforms.Compose(pil_stage + tensor_stage)

def get_val_transforms(img_size=320):
    """Build the deterministic validation pipeline: resize, tensor, normalize."""
    target_shape = (img_size, img_size)
    steps = [
        transforms.Resize(target_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

# TTA transforms (memory efficient)
def get_tta_transforms(img_size=320):
    """Return the two test-time-augmentation pipelines.

    In order: plain resize, and resize + always-applied horizontal flip.
    Both end with tensor conversion and the same normalization.
    """
    def _tensor_tail():
        # Fresh instances per pipeline, as in a hand-written Compose list.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    heads = (
        [transforms.Resize((img_size, img_size))],
        [
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=1.0),
        ],
    )
    return [transforms.Compose(head + _tensor_tail()) for head in heads]

# Memory-optimized model loading
def get_model(model_name='efficientnet', num_classes=20):
    """Create a pretrained backbone whose final layer outputs ``num_classes`` logits.

    The names are kept for compatibility with the training configs but map to
    smaller networks: 'efficientnet' -> EfficientNet-B2 (B0 if B2 cannot be
    built), 'resnext101' -> ResNeXt50 with a dropout head, 'densenet' ->
    DenseNet121. Any other name yields a plain ResNeXt50.

    Args:
        model_name: Architecture key, see above.
        num_classes: Size of the replacement classification layer.

    Returns:
        The constructed model (on CPU; the caller moves it to a device).
    """
    # Clear GPU cache before loading new model. Collect garbage first so that
    # tensors released by the collector can be returned by empty_cache().
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()

    if model_name == 'efficientnet':
        # Use smaller EfficientNet for 4060
        try:
            model = models.efficientnet_b2(pretrained=True)  # B2 instead of B3 for memory
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
            print("✅ Using EfficientNet-B2 (Memory Optimized)")
        except Exception as e:
            # Catch Exception rather than everything, so Ctrl-C is not turned
            # into a silent fallback; report why B2 was not used.
            print(f"EfficientNet-B2 unavailable ({e}); falling back to EfficientNet-B0")
            model = models.efficientnet_b0(pretrained=True)
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
            print("✅ Using EfficientNet-B0 (Fallback)")
    elif model_name == 'resnext101':
        # Use ResNeXt50 instead of 101 for memory
        model = models.resnext50_32x4d(pretrained=True)
        model.fc = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(model.fc.in_features, num_classes)
        )
        print("✅ Using ResNeXt50 (Memory Optimized)")
    elif model_name == 'densenet':
        # Use smaller DenseNet
        model = models.densenet121(pretrained=True)  # 121 instead of 161
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
        print("✅ Using DenseNet121 (Memory Optimized)")
    else:
        model = models.resnext50_32x4d(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        print("✅ Using ResNeXt50 (Default)")

    return model

# Label smoothing loss (memory efficient)
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy with label smoothing and optional per-class weights.

    The per-sample loss is ``(1 - smoothing) * NLL(target)`` plus
    ``smoothing`` times the mean negative log-probability over all classes.
    When ``weight`` is given, each sample's loss is multiplied by
    ``weight[target]`` before the batch mean is taken.
    """

    def __init__(self, smoothing=0.1, weight=None):
        super().__init__()
        self.weight = weight
        self.smoothing = smoothing

    def forward(self, x, target):
        """Return the scalar mean loss for logits ``x`` and class indices ``target``."""
        log_probs = F.log_softmax(x, dim=-1)

        # Negative log-likelihood of the true class, one value per sample.
        target_index = target.unsqueeze(1)
        target_term = -log_probs.gather(dim=-1, index=target_index).squeeze(1)

        # Negative mean log-probability across all classes.
        uniform_term = -log_probs.mean(dim=-1)

        keep = 1. - self.smoothing
        per_sample = keep * target_term + self.smoothing * uniform_term

        if self.weight is None:
            return per_sample.mean()

        per_sample = per_sample * self.weight[target]
        return per_sample.mean()

# Memory-optimized class weights
def get_class_weights(df, smoothing=0.1, num_classes=None):
    """Compute smoothed, clipped inverse-frequency class weights.

    The result is indexed by class id: element ``i`` is the weight for label
    ``i``. Classes with no samples in ``df`` still get an entry (count 0), so
    the vector stays aligned with the label indices even when a fold is
    missing a class.

    Args:
        df: DataFrame with an integer ``label_idx`` column.
        smoothing: Fraction of the average class size added to every count to
            avoid extreme weights.
        num_classes: Total number of classes. Defaults to the largest label
            in ``df`` plus one.

    Returns:
        A float32 tensor of per-class weights clipped to ``[0.7, 2.0]``.
    """
    labels = df['label_idx'].to_numpy().astype(np.int64)
    if num_classes is None:
        num_classes = int(labels.max()) + 1 if labels.size else 0

    # bincount keeps a slot for every class id, including absent ones.
    class_counts = np.bincount(labels, minlength=num_classes)
    if len(class_counts) == 0:
        return torch.zeros(0, dtype=torch.float32)

    total_samples = len(df)
    smoothed_counts = class_counts + smoothing * total_samples / len(class_counts)
    with np.errstate(divide='ignore'):
        class_weights = total_samples / (len(smoothed_counts) * smoothed_counts)
    class_weights = np.clip(class_weights, 0.7, 2.0)  # Less extreme weights
    return torch.tensor(class_weights, dtype=torch.float32)

# Memory-efficient training function
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``loader``, stepping after every batch.

    Gradients are clipped to a max norm of 1.0 before each optimizer step.
    Every 20th batch the CUDA cache is emptied (when CUDA is available) and
    the progress bar is refreshed.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    pbar = tqdm(loader, desc="Training")

    for batch_idx, (imgs, labels) in enumerate(pbar):
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_n = imgs.size(0)
        loss_sum += loss.item() * batch_n
        n_correct += (logits.argmax(1) == labels).sum().item()
        n_seen += batch_n

        # Only every 20th batch does the cleanup / progress refresh below.
        if batch_idx % 20 != 0:
            continue

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            vram_text = f'{torch.cuda.memory_allocated()/1e9:.1f}GB'
        else:
            vram_text = 'N/A'

        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Acc': f'{n_correct/n_seen:.4f}',
            'VRAM': vram_text
        })

    return loss_sum / n_seen, n_correct / n_seen

# Memory-efficient validation
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(mean_loss, accuracy, micro_f1)`` over all samples in ``loader``.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    predicted = []
    expected = []

    with torch.no_grad():
        for batch_imgs, batch_labels in tqdm(loader, desc="Validating", leave=False):
            batch_imgs = batch_imgs.to(device, non_blocking=True)
            batch_labels = batch_labels.to(device, non_blocking=True)
            logits = model(batch_imgs)
            batch_n = batch_imgs.size(0)

            # criterion returns a batch mean; weight it by the batch size.
            loss_sum += criterion(logits, batch_labels).item() * batch_n
            batch_preds = logits.argmax(1)
            n_correct += (batch_preds == batch_labels).sum().item()
            n_seen += batch_n

            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(batch_labels.cpu().numpy())

            # Drop references to this batch's tensors before the next one.
            del batch_imgs, batch_labels, logits

    micro_f1 = f1_score(expected, predicted, average="micro")
    return loss_sum / n_seen, n_correct / n_seen, micro_f1

# Memory-efficient TTA validation (limited)
def validate_with_tta(model, dataset, device, batch_size=8):  # Smaller batch for TTA
    """Evaluate ``model`` on ``dataset`` with test-time augmentation.

    Each image is read from ``dataset.img_dir`` / ``image_id``, passed through
    every pipeline from ``get_tta_transforms(IMG_SIZE)``, and the softmax
    outputs are averaged before taking the argmax. An unreadable image is
    reported and replaced by a gray placeholder, so the output width always
    matches the model regardless of the number of classes. Inference errors
    are not caught here and propagate to the caller.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataset: Object exposing ``df`` (with ``image_id`` and ``label_idx``
            columns) and ``img_dir``.
        device: Device the image tensors are moved to.
        batch_size: Number of images processed between CUDA cache cleanups;
            images are still run through the model one at a time.

    Returns:
        ``(micro_f1, accuracy)`` of the averaged predictions.
    """
    model.eval()
    tta_transforms = get_tta_transforms(IMG_SIZE)

    all_preds = []
    labels_all = []
    num_samples = len(dataset.df)

    with torch.no_grad():
        # Process in smaller chunks for memory
        for start_idx in tqdm(range(0, num_samples, batch_size), desc="TTA Validation"):
            end_idx = min(start_idx + batch_size, num_samples)

            for idx in range(start_idx, end_idx):
                row = dataset.df.iloc[idx]
                img_path = os.path.join(dataset.img_dir, row['image_id'])

                # Guard only the image loading, not the model call.
                try:
                    with Image.open(img_path) as raw:
                        image = raw.convert("RGB")
                except Exception as e:
                    print(f"Error loading {img_path}: {e}")
                    image = Image.new('RGB', (224, 224), color=(128, 128, 128))

                tta_preds = []
                for transform in tta_transforms:
                    img_tensor = transform(image).unsqueeze(0).to(device)
                    output = model(img_tensor)
                    tta_preds.append(F.softmax(output, dim=1).cpu().numpy())
                    del img_tensor, output  # Immediate cleanup

                all_preds.append(np.mean(tta_preds, axis=0))
                labels_all.append(row['label_idx'])

            # Memory cleanup after each chunk
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    all_preds = np.concatenate(all_preds, axis=0)
    pred_classes = np.argmax(all_preds, axis=1)

    f1 = f1_score(labels_all, pred_classes, average="micro")
    accuracy = (pred_classes == np.array(labels_all)).mean()

    return f1, accuracy

# Main training loop optimized for RTX 4060
def train_models():
    """Train every configured architecture with stratified K-fold CV.

    For each entry in ``model_configs`` the global ``df`` is split into
    ``N_SPLITS`` stratified folds. Each fold trains a fresh model with AdamW,
    cosine annealing and a class-weighted label-smoothing loss. The best
    checkpoint of each fold is written to
    ``MODEL_SAVE_DIR/<model_name>/fold<fold>_best.pth``.

    Returns:
        dict mapping model name to ``{'scores', 'mean', 'std'}``, where
        ``scores`` holds the best F1 of each fold.
    """
    # Prioritize best performing models for limited compute.
    # Per-model learning rate and epoch budget; the module-level LR and EPOCHS
    # constants are not used here.
    model_configs = [
        {'name': 'efficientnet', 'lr': 3e-5, 'epochs': 16},
        {'name': 'resnext101', 'lr': 2e-5, 'epochs': 14},
        {'name': 'densenet', 'lr': 4e-5, 'epochs': 15},
    ]

    all_results = {}

    for config in model_configs:
        model_name = config['name']
        lr = config['lr']
        epochs = config['epochs']

        print(f"\n{'='*60}")
        print(f"Training {model_name}")
        print(f"{'='*60}")

        # Clear memory before each model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Same seed for every architecture, so all models see identical folds.
        skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
        fold_scores = []

        for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['label_idx'])):
            print(f"\n🔥 Fold {fold+1}/{N_SPLITS}")

            train_df = df.iloc[train_idx].reset_index(drop=True)
            val_df = df.iloc[val_idx].reset_index(drop=True)

            print(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}")

            # Create datasets (both read from TRAIN_DIR; only the training
            # split is augmented).
            train_dataset = MemoryEfficientDataset(
                train_df, TRAIN_DIR, get_train_transforms(IMG_SIZE), is_training=True
            )
            val_dataset = MemoryEfficientDataset(
                val_df, TRAIN_DIR, get_val_transforms(IMG_SIZE)
            )

            # DataLoaders optimized for 4060
            train_loader = DataLoader(
                train_dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                num_workers=2,  # Reduced for stability
                pin_memory=True,
                drop_last=True,
                persistent_workers=True
            )
            val_loader = DataLoader(
                val_dataset,
                batch_size=BATCH_SIZE,
                shuffle=False,
                num_workers=2,
                pin_memory=True,
                persistent_workers=True
            )

            # Model setup
            model = get_model(model_name, num_classes).to(DEVICE)

            # Optimizer
            optimizer = optim.AdamW(
                model.parameters(),
                lr=lr,
                weight_decay=1e-4 if model_name == 'efficientnet' else 5e-5,
                eps=1e-8
            )

            # Learning rate scheduler (stepped once per epoch below)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=epochs, eta_min=1e-7
            )

            # Loss function; class weights come from the training split only.
            class_weights = get_class_weights(train_df).to(DEVICE)
            criterion = LabelSmoothingCrossEntropy(smoothing=0.1, weight=class_weights)

            # Training loop
            best_f1 = 0
            patience = 5
            patience_counter = 0

            print("Starting training...")
            for epoch in range(epochs):
                print(f"\nEpoch {epoch+1}/{epochs}")

                # Train
                train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)

                # Validate
                val_loss, val_acc, val_f1 = validate(model, val_loader, criterion, DEVICE)

                scheduler.step()

                print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
                print(f"Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")

                # Save best model in organized folders
                if val_f1 > best_f1:
                    best_f1 = val_f1
                    patience_counter = 0

                    # Use TTA for final validation on very good models
                    if val_f1 > 0.88:  # Only for promising models to save time
                        try:
                            tta_f1, tta_acc = validate_with_tta(model, val_dataset, DEVICE, batch_size=4)
                            print(f"TTA - Acc: {tta_acc:.4f}, F1: {tta_f1:.4f}")
                            # NOTE(review): best_f1 may now hold a TTA score,
                            # so later epochs compare their plain validation
                            # F1 against it -- confirm this is intended.
                            if tta_f1 > best_f1:
                                best_f1 = tta_f1
                        except Exception as e:
                            print(f"TTA failed: {e}, using regular validation")

                    # Save in organized folder structure
                    model_folder = os.path.join(MODEL_SAVE_DIR, model_name)
                    model_path = os.path.join(model_folder, f"fold{fold}_best.pth")
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        # Store a plain float so the checkpoint does not
                        # contain a pickled NumPy scalar.
                        'best_f1': float(best_f1),
                        'epoch': epoch,
                        'model_name': model_name,
                        'fold': fold
                    }, model_path)

                    print(f"✅ New best F1: {best_f1:.4f}")
                    print(f"📁 Model saved: {model_path}")
                else:
                    # No improvement: stop after `patience` such epochs in a row.
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(f"⏰ Early stopping at epoch {epoch+1}")
                        break

                # Memory cleanup
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            fold_scores.append(best_f1)
            print(f"\n📊 Fold {fold+1} completed - Best F1: {best_f1:.4f}")

            # Clean up for next fold. The optimizer and scheduler reference
            # the model's parameters, so they must be released together with
            # the model for the memory to be reclaimed. Dropping the loaders
            # releases the references to their persistent workers.
            del model, optimizer, scheduler, criterion, class_weights
            del train_loader, val_loader
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        avg_score = np.mean(fold_scores)
        std_score = np.std(fold_scores)
        all_results[model_name] = {
            'scores': fold_scores,
            'mean': avg_score,
            'std': std_score
        }

        print(f"\n🎯 {model_name} Results:")
        print(f"Fold scores: {[f'{score:.4f}' for score in fold_scores]}")
        print(f"Average F1: {avg_score:.4f} ± {std_score:.4f}")

    return all_results

# Model loading utility
def load_best_models():
    """Print the best F1 recorded in every checkpoint under MODEL_SAVE_DIR.

    The ``efficientnet``, ``resnext101`` and ``densenet`` sub-folders are
    scanned for ``.pth`` files. Each file is loaded on the CPU and its
    ``best_f1`` entry is printed. Despite the name, nothing is returned and
    no model is instantiated; the function only reports.

    A checkpoint that cannot be read is reported together with the reason,
    and a readable checkpoint without a numeric ``best_f1`` is reported as
    ``Unknown``.
    """
    print("📁 Available models:")

    for model_type in ['efficientnet', 'resnext101', 'densenet']:
        model_folder = os.path.join(MODEL_SAVE_DIR, model_type)
        # isdir (not exists) so a stray file with that name cannot break listdir.
        if not os.path.isdir(model_folder):
            continue

        # Sorted so the report order does not depend on the filesystem.
        model_files = sorted(f for f in os.listdir(model_folder) if f.endswith('.pth'))
        print(f"\n{model_type}:")
        for file_name in model_files:
            full_path = os.path.join(model_folder, file_name)
            try:
                # NOTE(review): torch.load unpickles the file; only point this
                # at checkpoints produced by this project.
                checkpoint = torch.load(full_path, map_location='cpu')
            except Exception as e:
                # Keep the listing going, but say why this file was skipped.
                print(f"  {file_name}: Unable to load info ({e})")
                continue

            best_f1 = checkpoint.get('best_f1') if isinstance(checkpoint, dict) else None
            try:
                f1_text = f"{float(best_f1):.4f}"
            except (TypeError, ValueError):
                # Missing or non-numeric score: formatting with .4f would fail.
                f1_text = "Unknown"
            print(f"  {file_name}: F1 = {f1_text}")

if __name__ == "__main__":
    # Entry point: train every configured architecture, summarise the
    # cross-validation scores, then list the checkpoints found on disk.
    print("🚀 Starting RTX 4060 optimized training...")
    print(f"💾 Models will be saved in: {MODEL_SAVE_DIR}/")

    results = train_models()

    print("\n🎉 Training completed!")
    print("\n📊 Final Results Summary:")
    for arch in results:
        summary = results[arch]
        mean_f1 = summary['mean']
        std_f1 = summary['std']
        print(f"{arch}: {mean_f1:.4f} ± {std_f1:.4f}")

    print("\n📁 All models saved in organized folders:")
    load_best_models()

    # Follow-up work that this script does not perform.
    print("\n💡 Next steps:")
    for hint in (
        "1. Use the best models for ensemble predictions",
        "2. Apply Test Time Augmentation on test set",
        "3. Average predictions from multiple folds/models",
    ):
        print(hint)

In [None]:
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import torch.nn.functional as F
import gc
import time

# Run configuration, sized for an RTX 4060 (8GB VRAM).

# Dataset locations, all relative to the working directory.
DATA_DIR = "./"
CSV_PATH = os.path.join(DATA_DIR, "train.csv")
TRAIN_DIR = os.path.join(DATA_DIR, "train/train/")
TEST_DIR = os.path.join(DATA_DIR, "test/test/")

# Deliberately conservative hyperparameters to avoid hanging.
SEED = 42
IMG_SIZE = 320
BATCH_SIZE = 12  # Reduced from 16
N_SPLITS = 5
EPOCHS = 15
LR = 3e-5

# DEVICE stays a plain string because later code compares it with "cuda".
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# cuDNN: favour speed (autotuned kernels) over bit-exact reproducibility.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# Checkpoint layout: saved_models/<architecture>/
MODEL_SAVE_DIR = "saved_models"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
for model_type in ('efficientnet', 'resnext', 'densenet'):
    model_dir = os.path.join(MODEL_SAVE_DIR, model_type)
    os.makedirs(model_dir, exist_ok=True)

# Seed the NumPy and PyTorch generators.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Report the hardware that will be used.
print(f"Using device: {DEVICE}")
if DEVICE == "cuda":
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Load the label CSV and encode the string targets as integer class ids.
print("Loading and validating data...")
df = pd.read_csv(CSV_PATH)
le = LabelEncoder()
df['label_idx'] = le.fit_transform(df['TARGET'])
df['image_id'] = df['ID']
# Counted before any row is dropped so it matches the encoder's label range.
num_classes = df['label_idx'].nunique()

# Keep only rows whose image exists on disk and passes PIL's integrity check.
print("Validating image files...")
total_images = len(df)  # pre-filter size, needed for the summary line below
valid_indices = []
for idx, image_id in tqdm(df['image_id'].items(), total=total_images, desc="Checking files"):
    img_path = os.path.join(TRAIN_DIR, image_id)
    if not os.path.exists(img_path):
        print(f"Missing file: {img_path}")
        continue
    try:
        # verify() checks the file for corruption.
        with Image.open(img_path) as img:
            img.verify()
    except Exception as e:
        print(f"Skipping corrupted image: {img_path} - {e}")
    else:
        valid_indices.append(idx)

# valid_indices holds index *labels*, so select with .loc rather than .iloc.
df = df.loc[valid_indices].reset_index(drop=True)
print(f"Valid images: {len(df)} / {total_images}")

print(f"Number of classes: {num_classes}")
print("Class distribution:")
print(df['TARGET'].value_counts().head())

# Robust dataset with better error handling
class RobustDataset(Dataset):
    """Image dataset that retries failed reads and falls back to a black image.

    Args:
        df: DataFrame with an ``image_id`` column (file name inside
            ``img_dir``) and a ``label_idx`` column (class id).
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the RGB PIL image. When
            omitted, the image is only converted with ``ToTensor``.
        is_training: Stored on the instance; not consulted by this class.
    """

    def __init__(self, df, img_dir, transform=None, is_training=False):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.is_training = is_training

        # Build every path once, straight from the column (no per-row Series).
        self.image_paths = [
            os.path.join(img_dir, image_id) for image_id in self.df['image_id']
        ]

        print(f"Dataset initialized with {len(self.df)} samples")

    def __len__(self):
        """Return the number of samples."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The image read is attempted up to three times with a short pause in
        between. If every attempt fails, a zero tensor of shape
        ``(3, IMG_SIZE, IMG_SIZE)`` is returned with the sample's real label.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
        """
        # Resolve path and label before the retry loop: a bad index must raise
        # IndexError immediately instead of being retried, and the fallback
        # branch below must always have a label to return.
        img_path = self.image_paths[idx]
        label = self.df.iloc[idx]['label_idx']

        max_retries = 3
        for attempt in range(max_retries):
            try:
                with Image.open(img_path) as image:
                    image = image.convert("RGB")

                    # Transform while the file is still open.
                    if self.transform:
                        image = self.transform(image)
                    else:
                        image = transforms.ToTensor()(image)

                return image, label

            except Exception as e:
                print(f"Error loading image {idx} (attempt {attempt+1}): {e}")
                if attempt == max_retries - 1:
                    # Out of retries: substitute a black image so the batch
                    # can still be assembled.
                    print(f"Using fallback for image {idx}")
                    fallback_img = torch.zeros((3, IMG_SIZE, IMG_SIZE))
                    return fallback_img, label
                time.sleep(0.1)  # Brief pause before retry

# Simplified transforms to reduce processing time
def get_train_transforms(img_size=320):
    """Build the training-time preprocessing and augmentation pipeline.

    Images are resized to ``img_size`` x ``img_size``, randomly flipped
    horizontally, rotated by up to 10 degrees and lightly colour-jittered,
    then converted to a tensor and normalised per channel.
    """
    target_size = (img_size, img_size)
    steps = [
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        # Presumably the ImageNet statistics matching the pretrained backbones.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

def get_val_transforms(img_size=320):
    """Build the deterministic validation pipeline (no augmentation).

    Images are resized to ``img_size`` x ``img_size``, converted to a tensor
    and normalised per channel.
    """
    target_size = (img_size, img_size)
    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        # Presumably the ImageNet statistics matching the pretrained backbones.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

# Simplified model loading
def get_model(model_name='efficientnet', num_classes=20):
    """Create a pretrained torchvision backbone with a new classification head.

    Args:
        model_name: ``'efficientnet'`` (EfficientNet-B1), ``'resnext'``
            (ResNeXt50 32x4d) or ``'densenet'`` (DenseNet121). Any other
            value falls back to ResNet50.
        num_classes: Output size of the replacement linear layer.

    Returns:
        The model with its final linear layer swapped for a fresh
        ``nn.Linear`` of ``num_classes`` outputs.
    """
    # Free cached GPU memory before allocating a new network.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    if model_name == 'efficientnet':
        net = models.efficientnet_b1(pretrained=True)  # B1 instead of B2 for more stability
        in_dim = net.classifier[1].in_features
        net.classifier[1] = nn.Linear(in_dim, num_classes)
        banner = "✅ Using EfficientNet-B1"
    elif model_name == 'densenet':
        net = models.densenet121(pretrained=True)
        in_dim = net.classifier.in_features
        net.classifier = nn.Linear(in_dim, num_classes)
        banner = "✅ Using DenseNet121"
    else:
        # ResNeXt and the ResNet50 fallback both expose their head as ``fc``.
        if model_name == 'resnext':
            net = models.resnext50_32x4d(pretrained=True)
            banner = "✅ Using ResNeXt50"
        else:
            net = models.resnet50(pretrained=True)
            banner = "✅ Using ResNet50 (fallback)"
        in_dim = net.fc.in_features
        net.fc = nn.Linear(in_dim, num_classes)

    print(banner)
    return net

# Simplified loss function
def get_class_weights(df):
    """Return clipped inverse-frequency class weights as a float32 tensor.

    Position ``i`` of the result is the weight of class id ``i``. The vector
    covers ids ``0 .. max(label_idx)``, so a class that has no rows in ``df``
    (possible for rare classes inside a cross-validation fold) still gets a
    slot instead of shifting every later weight onto the wrong class.

    Present classes get ``n_samples / (n_classes * count)`` clipped to
    ``[0.5, 2.0]``. Absent classes get the upper bound ``2.0``.

    Args:
        df: DataFrame with a non-negative integer ``label_idx`` column.

    Returns:
        1-D ``torch.float32`` tensor; empty when ``df`` has no rows.
    """
    labels = df['label_idx'].to_numpy()
    if labels.size == 0:
        return torch.tensor([], dtype=torch.float32)

    # bincount indexes by class id, unlike value_counts which drops absent ids.
    class_counts = np.bincount(labels.astype(np.int64))
    total_samples = labels.size
    present = class_counts > 0

    # Start every class at the upper clip bound; this is the final value for
    # absent classes and avoids a division by zero for them.
    class_weights = np.full(class_counts.shape, 2.0, dtype=np.float64)
    class_weights[present] = total_samples / (len(class_counts) * class_counts[present])
    class_weights = np.clip(class_weights, 0.5, 2.0)
    return torch.tensor(class_weights, dtype=torch.float32)

# Robust training function
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one training pass over ``loader``.

    Each batch is moved to ``device``, a forward/backward pass is run,
    gradients are clipped to a norm of 1.0 and the optimizer is stepped.
    A batch that raises is reported and skipped.

    Returns:
        ``(mean_loss, accuracy)`` over the samples that were processed, or
        ``(0, 0)`` when no batch completed.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    n_batches = len(loader)
    print(f"Starting epoch with {n_batches} batches...")

    # Start from a clean, settled CUDA state.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    progress = tqdm(loader, desc="Training", total=n_batches)

    for step, (imgs, labels) in enumerate(progress):
        try:
            # Blocking copies, kept for stability.
            imgs = imgs.to(device, non_blocking=False)
            labels = labels.to(device, non_blocking=False)

            optimizer.zero_grad()
            logits = model(imgs)
            batch_loss = criterion(logits, labels)
            batch_loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Accumulate sample-weighted loss and hit counts.
            batch_size = imgs.size(0)
            loss_sum += batch_loss.item() * batch_size
            predicted = logits.argmax(1)
            n_correct += (predicted == labels).sum().item()
            n_seen += batch_size

            # Refresh the progress bar every fifth batch.
            if step % 5 == 0:
                running_acc = n_correct / n_seen if n_seen > 0 else 0
                progress.set_postfix({
                    'Loss': f'{batch_loss.item():.4f}',
                    'Acc': f'{running_acc:.3f}',
                    'Batch': f'{step+1}/{n_batches}'
                })

            # Periodically hand cached blocks back to the allocator.
            if step % 10 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            # Best effort: report, free cached memory and move on.
            print(f"Error in batch {step}: {e}")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            continue

    progress.close()
    if n_seen > 0:
        return loss_sum / n_seen, n_correct / n_seen
    return 0, 0

# Simplified validation function
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    A batch that raises is reported and skipped.

    Returns:
        ``(mean_loss, accuracy, f1)`` where ``f1`` is the micro-averaged F1
        over all evaluated samples. Each value is ``0`` when no batch
        completed.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    collected_preds = []
    collected_labels = []

    with torch.no_grad():
        progress = tqdm(loader, desc="Validating", leave=False)
        for imgs, labels in progress:
            try:
                imgs = imgs.to(device, non_blocking=False)
                labels = labels.to(device, non_blocking=False)

                logits = model(imgs)
                batch_loss = criterion(logits, labels)

                batch_size = imgs.size(0)
                loss_sum += batch_loss.item() * batch_size
                predicted = logits.argmax(1)
                n_correct += (predicted == labels).sum().item()
                n_seen += batch_size

                # Keep host-side copies for the F1 computation.
                collected_preds.extend(predicted.cpu().numpy())
                collected_labels.extend(labels.cpu().numpy())

            except Exception as e:
                print(f"Error in validation batch: {e}")
                continue
        progress.close()

    if len(collected_preds) > 0:
        f1 = f1_score(collected_labels, collected_preds, average="micro")
    else:
        f1 = 0

    if n_seen > 0:
        return loss_sum / n_seen, n_correct / n_seen, f1
    return 0, 0, f1

# Main training function with better error handling
def train_models():
    """Cross-validate each configured architecture and save per-fold checkpoints.

    For every entry in ``model_configs`` the global ``df`` is split with a
    stratified K-fold (``N_SPLITS`` folds, seeded with ``SEED``). Each fold
    trains a fresh model with AdamW, cosine learning-rate annealing and a
    class-weighted cross-entropy loss. Whenever the validation F1 improves,
    the weights are written to
    ``MODEL_SAVE_DIR/<model_name>/fold<fold>_best.pth``.

    Returns:
        dict mapping model name to ``{'scores', 'mean', 'std'}``, where
        ``scores`` is the list of best validation F1 values of the folds
        that got as far as the training loop. A model for which no fold
        reached that point is omitted.
    """
    # Learning rate and epoch budget are set per architecture here; the
    # module-level LR and EPOCHS constants are not read by this function.
    model_configs = [
        {'name': 'efficientnet', 'lr': 3e-5, 'epochs': 15},
        {'name': 'resnext', 'lr': 2e-5, 'epochs': 12}, 
        {'name': 'densenet', 'lr': 4e-5, 'epochs': 14},
    ]
    
    all_results = {}
    
    for config in model_configs:
        model_name = config['name']
        lr = config['lr']
        epochs = config['epochs']
        
        print(f"\n{'='*60}")
        print(f"Training {model_name}")
        print(f"{'='*60}")
        
        # Clear memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        
        # A new splitter is created per architecture with the same seed.
        skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
        fold_scores = []
        
        for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['label_idx'])):
            print(f"\n🔥 Fold {fold+1}/{N_SPLITS}")
            
            train_df = df.iloc[train_idx].reset_index(drop=True)
            val_df = df.iloc[val_idx].reset_index(drop=True)
            
            print(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}")
            
            # Create datasets
            train_dataset = RobustDataset(
                train_df, TRAIN_DIR, get_train_transforms(IMG_SIZE), is_training=True
            )
            val_dataset = RobustDataset(
                val_df, TRAIN_DIR, get_val_transforms(IMG_SIZE)
            )
            
            # Create robust data loaders
            print("Creating data loaders...")
            try:
                # Loading happens in the main process (num_workers=0); the
                # last incomplete training batch is dropped.
                train_loader = DataLoader(
                    train_dataset, 
                    batch_size=BATCH_SIZE, 
                    shuffle=True,
                    num_workers=0,  # Disable multiprocessing
                    pin_memory=False,  # Disable memory pinning
                    drop_last=True,
                    timeout=0,  # No timeout
                    persistent_workers=False
                )
                val_loader = DataLoader(
                    val_dataset, 
                    batch_size=BATCH_SIZE, 
                    shuffle=False,
                    num_workers=0,
                    pin_memory=False,
                    timeout=0,
                    persistent_workers=False
                )
                
                print(f"✅ DataLoaders created - Train: {len(train_loader)} batches")
                
                # Test first batch
                # Smoke test: pull one batch so loading problems surface here
                # rather than inside the first epoch.
                print("Testing first batch...")
                test_batch = next(iter(train_loader))
                print(f"✅ First batch loaded: {test_batch[0].shape}")
                del test_batch
                
            except Exception as e:
                # The fold is skipped entirely; no score is recorded for it.
                print(f"❌ DataLoader creation failed: {e}")
                continue
            
            # Model setup
            print(f"Setting up {model_name} model...")
            try:
                model = get_model(model_name, num_classes).to(DEVICE)
                
                # Test model
                # Smoke test: one forward pass on random input of the
                # training resolution.
                dummy_input = torch.randn(2, 3, IMG_SIZE, IMG_SIZE).to(DEVICE)
                with torch.no_grad():
                    dummy_output = model(dummy_input)
                print(f"✅ Model test passed: {dummy_output.shape}")
                del dummy_input, dummy_output
                
            except Exception as e:
                # The fold is skipped entirely; no score is recorded for it.
                print(f"❌ Model setup failed: {e}")
                continue
            
            # Optimizer and scheduler
            # The cosine schedule spans the full epoch budget and is stepped
            # once per epoch below.
            optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-7)
            
            # Loss function
            # Weights are derived from this fold's training rows only.
            class_weights = get_class_weights(train_df).to(DEVICE)
            criterion = nn.CrossEntropyLoss(weight=class_weights)
            
            # Training loop
            # best_f1 starts at 0, so an epoch must score strictly above 0
            # before any checkpoint is written.
            best_f1 = 0
            patience = 4
            patience_counter = 0
            
            print("Starting training loop...")
            for epoch in range(epochs):
                print(f"\n--- Epoch {epoch+1}/{epochs} ---")
                
                try:
                    # Train
                    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
                    
                    # Validate
                    val_loss, val_acc, val_f1 = validate(model, val_loader, criterion, DEVICE)
                    
                    scheduler.step()
                    
                    print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
                    print(f"Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
                    
                    # Save best model
                    if val_f1 > best_f1:
                        best_f1 = val_f1
                        patience_counter = 0
                        
                        # The file name uses the zero-based fold index, while
                        # the log messages show fold+1.
                        model_folder = os.path.join(MODEL_SAVE_DIR, model_name)
                        model_path = os.path.join(model_folder, f"fold{fold}_best.pth")
                        torch.save({
                            'model_state_dict': model.state_dict(),
                            'best_f1': best_f1,
                            'epoch': epoch,
                            'model_name': model_name,
                            'fold': fold
                        }, model_path)
                        
                        print(f"✅ New best F1: {best_f1:.4f} - Model saved!")
                    else:
                        # Early stopping: give up after `patience` consecutive
                        # epochs without a new best F1.
                        patience_counter += 1
                        if patience_counter >= patience:
                            print(f"⏰ Early stopping at epoch {epoch+1}")
                            break
                    
                    # Memory cleanup
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                        
                except Exception as e:
                    # Abort this fold's training; the best F1 reached so far
                    # is still recorded below.
                    print(f"❌ Error in epoch {epoch+1}: {e}")
                    break
            
            fold_scores.append(best_f1)
            print(f"\n📊 Fold {fold+1} completed - Best F1: {best_f1:.4f}")
            
            # Cleanup
            # Drop references to the fold's objects before the next fold
            # allocates a new model.
            del model, optimizer, scheduler, criterion
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
        
        # Only summarise architectures for which at least one fold was scored.
        if fold_scores:
            avg_score = np.mean(fold_scores)
            std_score = np.std(fold_scores)
            all_results[model_name] = {
                'scores': fold_scores,
                'mean': avg_score,
                'std': std_score
            }
            
            print(f"\n🎯 {model_name} Results:")
            print(f"Fold scores: {[f'{score:.4f}' for score in fold_scores]}")
            print(f"Average F1: {avg_score:.4f} ± {std_score:.4f}")
    
    return all_results

if __name__ == "__main__":
    # Entry point: run the full cross-validation and print a summary. Any
    # failure is reported with its traceback instead of aborting the cell.
    print("🚀 Starting robust training for RTX 4060...")
    print(f"💾 Models will be saved in: {MODEL_SAVE_DIR}/")

    try:
        results = train_models()

        print("\n🎉 Training completed!")
        print("\n📊 Final Results Summary:")
        for arch in results:
            summary = results[arch]
            mean_f1 = summary['mean']
            std_f1 = summary['std']
            print(f"{arch}: {mean_f1:.4f} ± {std_f1:.4f}")

    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()

    # Printed on success and on failure alike.
    print(f"\n📁 Check {MODEL_SAVE_DIR}/ for saved models")