In [None]:
"""
=============================================================================
SIAMESE CONVNEXT V2 NANO - GPU OPTIMIZED (UNFROZEN BACKBONE)
=============================================================================
This version trains ALL 15.5M parameters for better accuracy.
Optimized for GTX 1050 Ti (4GB VRAM) with ~30% memory usage.

Expected improvement: 82.8% → 85-88% accuracy
=============================================================================
"""

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import rasterio
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    classification_report, confusion_matrix, f1_score, accuracy_score,
    precision_score, recall_score, roc_curve, auc, precision_recall_curve
)
from sklearn.preprocessing import label_binarize
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

In [None]:
# =============================================================================
# CONFIGURATION - GPU OPTIMIZED (UNFROZEN BACKBONE)
# =============================================================================
CONFIG = dict(
    image_size=224,            # Square input side length (pixels)
    batch_size=12,             # Increased for GPU efficiency
    epochs=30,                 # More epochs since more params are trained
    learning_rate=5e-5,        # Classifier-head LR (was 1e-3)
    backbone_lr=1e-5,          # Lower LR for the pretrained backbone
    num_workers=2,             # DataLoader worker processes
    freeze_backbone=False,     # *** UNFROZEN - backbone params are trained too ***
    use_amp=True,              # Mixed precision for memory efficiency
    n_folds=5,                 # Stratified K-fold splits
    patience=10,               # Early-stopping patience (epochs without F1 gain)
    tta_enabled=True,          # Test-time augmentation at inference
    model_name='convnextv2_nano',
    gradient_clip=1.0,         # Max global gradient norm
    warmup_epochs=2,           # Linear LR warmup length (epochs)
)

In [None]:
# Input folders holding the pre-event and post-event TIFF tiles; per-pair
# filenames are joined onto these directories when the dataset is assembled.
PRE_DIR, POST_DIR = "PRE-event", "POST-event"

In [None]:
# Output directories
PLOT_DIR = "metric_plots_gpu"
if __name__ == "__main__":
    import importlib
    import subprocess
    import sys

    os.makedirs(PLOT_DIR, exist_ok=True)

    print("\n" + "="*70)
    print("SIAMESE CONVNEXT V2 NANO - GPU OPTIMIZED (FULL FINE-TUNING)")
    print("Training ALL 15.5M parameters for maximum accuracy")
    print("="*70 + "\n")

    # =============================================================================
    # CHECK GPU
    # =============================================================================
    print("STEP 1: CHECKING HARDWARE")
    print("-" * 70)

    # `device` and `timm` are bound at module level here; later cells use them
    # as globals.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU: {gpu_name}")
        print(f"Total VRAM: {gpu_memory:.1f} GB")
        print(f"Backbone: UNFROZEN (Full fine-tuning)")
        print(f"Expected memory: ~1.2 GB (~30%)")
        torch.cuda.empty_cache()
    else:
        print("ERROR: CUDA not available!")
        print("This script requires GPU. Install CUDA-enabled PyTorch:")
        print("  pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118")
        # sys.exit works without the `site` module (the bare `exit` builtin may not exist).
        sys.exit(1)

    # Import timm, installing it on first use if necessary.
    try:
        import timm
    except ImportError:
        # Run pip through *this* interpreter so the package lands in the
        # environment executing the code (a bare `pip` on PATH may belong to
        # a different Python installation).
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'timm'])
        # Let the import system notice a package installed after start-up.
        importlib.invalidate_caches()
        import timm
    print(f"timm version: {timm.__version__}")

In [None]:
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================
def raster_to_uint8(img):
    """Convert a raster array of any numeric dtype to uint8 for ``PIL``.

    Rules, applied in order:

    * uint8 input is returned unchanged (same object).
    * For float input, NaN and +/-inf (typical nodata markers) become 0.
    * If any value is negative (e.g. SAR backscatter in dB, signed nodata),
      the array is min-max stretched to [0, 255]. A direct uint8 cast of
      negative values wraps around or is undefined. A constant array maps to 0.
    * Otherwise, if the maximum exceeds 255 (e.g. 16-bit imagery), values are
      scaled by ``255 / max``.
    * Otherwise, float input whose maximum is <= 1.0 (reflectance in [0, 1])
      is scaled by 255. A direct cast would truncate it to an all-black image.
    * Anything else is cast to uint8 directly.

    Args:
        img: numpy array, e.g. (H, W, 3).

    Returns:
        numpy uint8 array with the same shape.
    """
    if img.dtype == np.uint8:
        return img

    is_float = np.issubdtype(img.dtype, np.floating)
    if is_float:
        img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)

    lo, hi = img.min(), img.max()
    if lo < 0:
        span = float(hi) - float(lo)
        if span == 0:
            return np.zeros(img.shape, dtype=np.uint8)
        return ((img.astype(np.float64) - float(lo)) / span * 255).astype(np.uint8)
    if hi > 255:
        return (img / hi * 255).astype(np.uint8)
    if is_float and hi <= 1.0:
        return (img * 255).astype(np.uint8)
    return img.astype(np.uint8)


def load_tiff_image(path):
    """Load a geospatial TIFF file and return it as an RGB PIL Image.

    Bands 1-3 are stacked as R, G, B when the file has at least three bands;
    otherwise band 1 is replicated into all three channels. Pixel values are
    converted to uint8 with ``raster_to_uint8``.
    """
    with rasterio.open(path) as src:
        if src.count >= 3:
            img = np.stack([src.read(i) for i in [1, 2, 3]], axis=-1)
        else:
            img = np.stack([src.read(1)] * 3, axis=-1)

    # The pixel data is fully in memory, so the file can be closed before conversion.
    return Image.fromarray(raster_to_uint8(img))

In [None]:
# =============================================================================
# DATASET CLASS
# =============================================================================
class FloodDataset(Dataset):
    """Pre-/post-event image pairs with one class label per pair.

    Each item is ``(pre_img, post_img, label)``. Both images are read with
    ``load_tiff_image`` and, when ``transform`` is given, passed through it.

    Args:
        pre_paths: sequence of pre-event TIFF paths.
        post_paths: sequence of post-event TIFF paths, parallel to ``pre_paths``.
        labels: sequence of class ids, parallel to the path sequences.
        transform: optional callable applied to each image.
        sync_transforms: if True (default), the torch and Python ``random`` RNG
            states are snapshotted before transforming the pre image and
            restored before transforming the post image. A random augmentation
            pipeline therefore applies the SAME flips, rotations, jitter and
            erasing to both images of a pair. Independent draws would introduce
            artificial pre/post differences that the model's |post - pre|
            feature could mistake for real change.

    Raises:
        ValueError: if the three input sequences have different lengths.
    """

    def __init__(self, pre_paths, post_paths, labels, transform=None, sync_transforms=True):
        self.pre_paths = list(pre_paths)
        self.post_paths = list(post_paths)
        self.labels = list(labels)
        if not len(self.pre_paths) == len(self.post_paths) == len(self.labels):
            raise ValueError(
                f"pre_paths ({len(self.pre_paths)}), post_paths ({len(self.post_paths)}) "
                f"and labels ({len(self.labels)}) must have the same length"
            )
        self.transform = transform
        self.sync_transforms = sync_transforms

    def __len__(self):
        return len(self.pre_paths)

    def __getitem__(self, idx):
        pre_img = load_tiff_image(self.pre_paths[idx])
        post_img = load_tiff_image(self.post_paths[idx])

        if self.transform is not None:
            if self.sync_transforms:
                pre_img, post_img = self._transform_pair(pre_img, post_img)
            else:
                pre_img = self.transform(pre_img)
                post_img = self.transform(post_img)

        return pre_img, post_img, self.labels[idx]

    def _transform_pair(self, pre_img, post_img):
        """Apply ``self.transform`` to both images using identical random draws."""
        import random

        torch_state = torch.get_rng_state()
        python_state = random.getstate()
        pre_out = self.transform(pre_img)
        # Rewind both RNGs so the post image replays exactly the same draws.
        torch.set_rng_state(torch_state)
        random.setstate(python_state)
        post_out = self.transform(post_img)
        return pre_out, post_out

In [None]:
# =============================================================================
# SIAMESE CONVNEXT V2 MODEL
# =============================================================================
class SiameseConvNeXtV2(nn.Module):
    """Siamese classifier: one shared timm ConvNeXt V2 backbone plus an MLP head.

    Both images pass through the same backbone (shared weights). The head
    receives ``[f_pre, f_post, |f_post - f_pre|]`` concatenated along dim 1,
    i.e. ``3 * feature_dim`` inputs, and returns ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        freeze_backbone: if True, backbone parameters get ``requires_grad=False``.
        model_name: timm architecture name. If it has no pretrained tag (no
            ``'.'``), ``'.fcmae_ft_in22k_in1k'`` is appended. A name that already
            carries a tag (e.g. ``'convnextv2_nano.fcmae'``) is used as-is.
        pretrained: whether timm should load pretrained weights (default True).
            Set False for offline use or when a saved state_dict is loaded
            afterwards.
    """

    DEFAULT_PRETRAINED_TAG = 'fcmae_ft_in22k_in1k'

    def __init__(self, num_classes=3, freeze_backbone=False, model_name='convnextv2_nano',
                 pretrained=True):
        super().__init__()

        timm_name = model_name if '.' in model_name else f"{model_name}.{self.DEFAULT_PRETRAINED_TAG}"
        # num_classes=0 drops timm's classification layer so the backbone
        # returns globally average-pooled features.
        self.backbone = timm.create_model(
            timm_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool='avg',
        )

        self.feature_dim = self.backbone.num_features

        if freeze_backbone:
            self.backbone.requires_grad_(False)

        # Head input is [pre, post, |post - pre|] -> 3 * feature_dim
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.feature_dim * 3, 512),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

        print(f"Model: {model_name}")
        print(f"Feature dimension: {self.feature_dim}")
        print(f"Backbone frozen: {freeze_backbone}")

    def forward_one(self, x):
        """Embed one batch of images with the shared backbone."""
        return self.backbone(x)

    def forward(self, pre_img, post_img):
        """Return class logits for a batch of (pre, post) image pairs."""
        feat_pre = self.forward_one(pre_img)
        feat_post = self.forward_one(post_img)
        feat_diff = torch.abs(feat_post - feat_pre)
        combined = torch.cat([feat_pre, feat_post, feat_diff], dim=1)
        return self.classifier(combined)

In [None]:
# =============================================================================
# DATA TRANSFORMS (Enhanced for fine-tuning)
# =============================================================================
# Training-time augmentation. Op order matters because each random op draws
# from the RNG in sequence: resize -> flips -> rotation -> colour jitter ->
# affine shift/scale -> occasional grayscale -> tensor + ImageNet
# normalisation -> random erasing.
train_transform = transforms.Compose([
    transforms.Resize(size=(CONFIG['image_size'],) * 2),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=20),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.85, 1.15)),
    transforms.RandomGrayscale(p=0.05),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # Cutout-style erasing runs on the normalised tensor (erased value defaults to 0).
    transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),
])

In [None]:
# Deterministic evaluation pipeline (no augmentation): resize to the model's
# input size, convert to a float tensor, apply ImageNet mean/std.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(CONFIG['image_size'],) * 2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# TTA transforms
# Test-time augmentation views: the plain evaluation pipeline followed by three
# deterministic variants (horizontal flip, vertical flip, 90-degree rotation).
# Each variant is wrapped in the same resize / ToTensor / Normalize steps.
tta_transforms = [val_transform] + [
    transforms.Compose([
        transforms.Resize(size=(CONFIG['image_size'],) * 2),
        variant,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    for variant in (
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.RandomVerticalFlip(p=1.0),
        transforms.RandomRotation((90, 90)),
    )
]

In [None]:
# =============================================================================
# LEARNING RATE SCHEDULER WITH WARMUP
# =============================================================================
def get_lr_scheduler(optimizer, num_warmup_steps, num_training_steps):
    """Linear warmup followed by cosine decay, as a per-step ``LambdaLR``.

    The multiplier applied to every param group's base LR is:

    * ``step / num_warmup_steps`` while ``step < num_warmup_steps`` (ramps 0 -> 1);
    * ``0.5 * (1 + cos(pi * progress))`` afterwards. ``progress`` is the
      fraction of post-warmup steps completed, clamped to [0, 1], so the LR
      stays at 0 (rather than rising again) if stepped past
      ``num_training_steps``.

    Call ``scheduler.step()`` once per optimizer step, not once per epoch.

    Args:
        optimizer: the wrapped optimizer.
        num_warmup_steps: length of the linear warmup, in steps.
        num_training_steps: total number of steps over which to decay.

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        decay_steps = max(1, num_training_steps - num_warmup_steps)
        progress = min(1.0, float(current_step - num_warmup_steps) / float(decay_steps))
        return max(0.0, 0.5 * (1.0 + float(np.cos(np.pi * progress))))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [None]:
# =============================================================================
# TRAINING FUNCTION FOR ONE FOLD
# =============================================================================
def _train_one_epoch(model, loader, optimizer, scheduler, scaler, criterion, use_amp):
    """Run one optimisation pass over ``loader``; return (mean batch loss, accuracy %).

    The LR scheduler is stepped once per batch. The global gradient norm is
    clipped to ``CONFIG['gradient_clip']`` after un-scaling, so the threshold
    applies to the true (unscaled) gradients.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for pre_img, post_img, label in loader:
        pre_img = pre_img.to(device)
        post_img = post_img.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        with torch.amp.autocast('cuda', enabled=use_amp):
            outputs = model(pre_img, post_img)
            loss = criterion(outputs, label)

        scaler.scale(loss).backward()
        # Unscale first so clipping sees real gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['gradient_clip'])
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        loss_sum += loss.item()
        n_correct += outputs.argmax(dim=1).eq(label).sum().item()
        n_seen += label.size(0)

    # max(1, ...) avoids ZeroDivisionError on an empty loader (e.g. drop_last
    # with fewer samples than one batch).
    return loss_sum / max(1, len(loader)), 100.0 * n_correct / max(1, n_seen)


def _evaluate(model, loader, criterion, use_amp):
    """Score ``model`` on ``loader`` without gradients; return (mean loss, accuracy %, macro F1)."""
    model.eval()
    loss_sum = 0.0
    preds, targets = [], []

    with torch.no_grad():
        for pre_img, post_img, label in loader:
            pre_img = pre_img.to(device)
            post_img = post_img.to(device)
            label = label.to(device)

            with torch.amp.autocast('cuda', enabled=use_amp):
                outputs = model(pre_img, post_img)
                loss = criterion(outputs, label)

            loss_sum += loss.item()
            preds.extend(outputs.argmax(dim=1).cpu().numpy())
            targets.extend(label.cpu().numpy())

    val_loss = loss_sum / max(1, len(loader))
    val_acc = accuracy_score(targets, preds) * 100
    val_f1 = f1_score(targets, preds, average='macro')
    return val_loss, val_acc, val_f1


def train_fold(fold_idx, train_loader, val_loader, class_weights, num_training_steps):
    """Train one cross-validation fold and return its best checkpoint.

    Builds a fresh ``SiameseConvNeXtV2`` and trains it with:

    * AdamW, with separate LRs for the backbone and the classifier head;
    * a warmup + cosine schedule stepped per batch;
    * class-weighted, label-smoothed cross-entropy;
    * gradient clipping;
    * mixed precision when ``CONFIG['use_amp']`` is set and ``device`` is CUDA.

    Training stops early once validation macro F1 has not improved for
    ``CONFIG['patience']`` consecutive epochs.

    Args:
        fold_idx: zero-based fold index (used for logging).
        train_loader: DataLoader yielding (pre, post, label) training batches.
        val_loader: DataLoader yielding (pre, post, label) validation batches.
        class_weights: per-class loss weights (already on ``device``).
        num_training_steps: total scheduler steps (epochs * batches per epoch).

    Returns:
        (model with its best-F1 weights loaded, best validation macro F1,
         history dict with one entry per completed epoch).
    """
    print(f"\n{'='*50}")
    print(f"FOLD {fold_idx + 1}/{CONFIG['n_folds']}")
    print(f"{'='*50}")

    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'val_f1': [], 'learning_rate': []
    }

    model = SiameseConvNeXtV2(
        num_classes=3,
        freeze_backbone=CONFIG['freeze_backbone'],
        model_name=CONFIG['model_name']
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Param group 0 = backbone (lower LR), group 1 = classifier head.
    optimizer = optim.AdamW([
        {'params': list(model.backbone.parameters()), 'lr': CONFIG['backbone_lr']},
        {'params': list(model.classifier.parameters()), 'lr': CONFIG['learning_rate']},
    ], weight_decay=0.05)

    # Warmup + cosine scheduler, stepped once per batch
    num_warmup_steps = CONFIG['warmup_epochs'] * len(train_loader)
    scheduler = get_lr_scheduler(optimizer, num_warmup_steps, num_training_steps)

    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

    # Mixed precision only when requested in CONFIG and running on CUDA. With
    # enabled=False, autocast and the scaler act as pass-throughs.
    use_amp = bool(CONFIG.get('use_amp', True)) and device.type == 'cuda'
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    # Macro F1 is >= 0, so starting at -1 guarantees the first completed epoch
    # stores a checkpoint even if its F1 is exactly 0.
    best_val_f1 = -1.0
    best_model_state = None
    patience_counter = 0

    for epoch in range(CONFIG['epochs']):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, optimizer, scheduler, scaler, criterion, use_amp)
        val_loss, val_acc, val_f1 = _evaluate(model, val_loader, criterion, use_amp)

        current_lr = optimizer.param_groups[1]['lr']  # Classifier LR
        for key, value in (('train_loss', train_loss), ('val_loss', val_loss),
                           ('train_acc', train_acc), ('val_acc', val_acc),
                           ('val_f1', val_f1), ('learning_rate', current_lr)):
            history[key].append(value)

        # Report GPU memory once, after the first epoch
        if epoch == 0 and torch.cuda.is_available():
            mem_used = torch.cuda.memory_allocated() / 1024**3
            mem_reserved = torch.cuda.memory_reserved() / 1024**3
            mem_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            print(f"GPU Memory: {mem_used:.2f} GB used, {mem_reserved:.2f} GB reserved ({100*mem_used/mem_total:.0f}%)")

        # Early stopping on validation macro F1
        improved = ""
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
            improved = " ★"
        else:
            patience_counter += 1

        print(f"Epoch {epoch+1:2d}/{CONFIG['epochs']} | "
              f"Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.1f}% | "
              f"Val F1: {val_f1:.3f} | LR: {current_lr:.2e}{improved}")

        if patience_counter >= CONFIG['patience']:
            print(f"Early stopping at epoch {epoch + 1}")
            break

    if best_model_state is not None:
        # load_state_dict copies into the existing on-device parameters.
        model.load_state_dict(best_model_state)
    else:
        best_val_f1 = 0.0  # no epoch ran (CONFIG['epochs'] == 0)

    return model, best_val_f1, history

In [None]:
# =============================================================================
# TTA PREDICTION
# =============================================================================
def predict_with_tta(models, pre_path, post_path):
    """Ensemble prediction for one image pair, with optional test-time augmentation.

    The pre/post TIFFs are loaded once. Each view transform is applied to both
    images once: ``tta_transforms`` when ``CONFIG['tta_enabled']``, otherwise
    just ``val_transform``. Every model then scores every view, and softmax
    probabilities are averaged over all (model, view) combinations.

    Args:
        models: non-empty sequence of trained models (switched to eval mode here).
        pre_path: path of the pre-event TIFF.
        post_path: path of the post-event TIFF.

    Returns:
        (predicted class index, averaged probability of the most likely class).

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("predict_with_tta needs at least one model")

    views = tta_transforms if CONFIG['tta_enabled'] else [val_transform]
    use_amp = bool(CONFIG.get('use_amp', True)) and device.type == 'cuda'

    # Decode each raster once and build each view once, instead of re-reading
    # both files for every (model, view) pair.
    pre_img = load_tiff_image(pre_path)
    post_img = load_tiff_image(post_path)
    view_pairs = [
        (view(pre_img).unsqueeze(0).to(device), view(post_img).unsqueeze(0).to(device))
        for view in views
    ]

    all_probs = []
    with torch.no_grad():
        for model in models:
            model.eval()
            for pre_tensor, post_tensor in view_pairs:
                with torch.amp.autocast('cuda', enabled=use_amp):
                    output = model(pre_tensor, post_tensor)
                probs = torch.softmax(output.float(), dim=1)
                all_probs.append(probs.cpu().numpy())

    avg_probs = np.mean(all_probs, axis=0)
    prediction = np.argmax(avg_probs, axis=1)[0]
    confidence = np.max(avg_probs)

    return prediction, confidence

## PREPARE DATA

Load the cluster labels from `unsupervised_results.csv` and keep only the pre/post image pairs whose files exist on disk.

In [None]:
if __name__ == "__main__":
    # ...existing code for GPU check and timm import...

    print("\n\nSTEP 2: PREPARING DATASET")
    print("-" * 70)

    if os.path.exists('unsupervised_results.csv'):
        df = pd.read_csv('unsupervised_results.csv')
        print(f"Loaded {len(df)} samples from unsupervised_results.csv")
    else:
        print("ERROR: Run run_analysis.py first!")
        exit()

    pre_paths, post_paths, labels = [], [], []
    for _, row in df.iterrows():
        pre_path = os.path.join(PRE_DIR, row['pre_filename'])
        post_path = os.path.join(POST_DIR, row['post_filename'])
        if os.path.exists(pre_path) and os.path.exists(post_path):
            pre_paths.append(pre_path)
            post_paths.append(post_path)
            labels.append(int(row['cluster']))

    pre_paths = np.array(pre_paths)
    post_paths = np.array(post_paths)
    labels = np.array(labels)

    print(f"Found {len(pre_paths)} valid image pairs")
    print(f"Label distribution: {pd.Series(labels).value_counts().sort_index().to_dict()}")

    # Class weights
    class_counts = pd.Series(labels).value_counts().sort_index()
    class_weights = torch.tensor([1.0 / c for c in class_counts.values], dtype=torch.float32)
    class_weights = class_weights / class_weights.sum() * len(class_counts)
    class_weights = class_weights.to(device)
    print(f"Class weights: {class_weights.cpu().numpy()}")

    # =============================================================================
    # K-FOLD TRAINING
    # =============================================================================
    print("\n\nSTEP 3: K-FOLD CROSS-VALIDATION (GPU OPTIMIZED)")
    print("-" * 70)
    print(f"Model: ConvNeXt V2 Nano (UNFROZEN - Full fine-tuning)")
    print(f"Batch size: {CONFIG['batch_size']}")
    print(f"Learning rate: Backbone={CONFIG['backbone_lr']}, Classifier={CONFIG['learning_rate']}")

    skf = StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=42)

    fold_models = []
    fold_f1_scores = []
    fold_accuracies = []
    all_val_preds = []
    all_val_labels = []
    all_val_probs = []
    fold_histories = []

    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(pre_paths, labels)):
        train_dataset = FloodDataset(pre_paths[train_idx], post_paths[train_idx], labels[train_idx], train_transform)
        val_dataset = FloodDataset(pre_paths[val_idx], post_paths[val_idx], labels[val_idx], val_transform)
        
        train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True,
                                  num_workers=CONFIG['num_workers'], pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False,
                                num_workers=CONFIG['num_workers'], pin_memory=True)
        
        num_training_steps = CONFIG['epochs'] * len(train_loader)
        
        model, best_f1, history = train_fold(fold_idx, train_loader, val_loader, class_weights, num_training_steps)
        fold_models.append(model)
        fold_f1_scores.append(best_f1)
        fold_histories.append(history)
        
        # Evaluate
        model.eval()
        fold_preds, fold_labels_list = [], []
        
        with torch.no_grad():
            for pre_img, post_img, label in val_loader:
                pre_img, post_img = pre_img.to(device), post_img.to(device)
                outputs = model(pre_img, post_img)
                probs = torch.softmax(outputs, dim=1)
                _, predicted = outputs.max(1)
                fold_preds.extend(predicted.cpu().numpy())
                fold_labels_list.extend(label.numpy())
                all_val_probs.extend(probs.cpu().numpy())
        
        fold_acc = accuracy_score(fold_labels_list, fold_preds) * 100
        fold_accuracies.append(fold_acc)
        all_val_preds.extend(fold_preds)
        all_val_labels.extend(fold_labels_list)
        
        print(f"Fold {fold_idx + 1} Complete - Accuracy: {fold_acc:.1f}%, F1: {best_f1:.3f}")
        
        # Save model
        torch.save(model.state_dict(), f'convnextv2_gpu_fold_{fold_idx + 1}.pth')
        torch.cuda.empty_cache()

    # =============================================================================
    # RESULTS
    # =============================================================================
    print("\n\nSTEP 4: CROSS-VALIDATION RESULTS")
    print("-" * 70)

    mean_acc = np.mean(fold_accuracies)
    std_acc = np.std(fold_accuracies)
    mean_f1 = np.mean(fold_f1_scores)
    std_f1 = np.std(fold_f1_scores)

    print(f"\n{'='*50}")
    print(f"CONVNEXT V2 NANO (GPU UNFROZEN) RESULTS")
    print(f"{'='*50}")
    print(f"Accuracy: {mean_acc:.1f}% ± {std_acc:.1f}%")
    print(f"F1 Score: {mean_f1:.3f} ± {std_f1:.3f}")
    print(f"{'='*50}")

    print("\nClassification Report:")
    print(classification_report(all_val_labels, all_val_preds, target_names=['Cluster 0', 'Cluster 1', 'Cluster 2']))

    # =============================================================================
    # SAVE METRIC PLOTS
    # =============================================================================
    print("\n\nSTEP 4.5: SAVING METRIC PLOTS")
    print("-" * 70)

    # Plot 1-13: Training curves and metrics (same as before but with new data)
    # ... [Similar plotting code as original but saves to PLOT_DIR]

    # Training loss curves
    plt.figure(figsize=(10, 6))
    for i, history in enumerate(fold_histories):
        plt.plot(history['train_loss'], label=f'Fold {i+1} Train', linestyle='-', alpha=0.8)
        plt.plot(history['val_loss'], label=f'Fold {i+1} Val', linestyle='--', alpha=0.8)
    plt.xlabel('Epoch'); plt.ylabel('Loss')
    plt.title('Training & Validation Loss (GPU Unfrozen)', fontweight='bold')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.3); plt.tight_layout()
    plt.savefig(f'{PLOT_DIR}/01_training_loss_curves.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved {PLOT_DIR}/01_training_loss_curves.png")

    # Training accuracy curves
    plt.figure(figsize=(10, 6))
    for i, history in enumerate(fold_histories):
        plt.plot(history['train_acc'], label=f'Fold {i+1} Train', linestyle='-', alpha=0.8)
        plt.plot(history['val_acc'], label=f'Fold {i+1} Val', linestyle='--', alpha=0.8)
    plt.xlabel('Epoch'); plt.ylabel('Accuracy (%)')
    plt.title('Training & Validation Accuracy (GPU Unfrozen)', fontweight='bold')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.3); plt.tight_layout()
    plt.savefig(f'{PLOT_DIR}/02_training_accuracy_curves.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved {PLOT_DIR}/02_training_accuracy_curves.png")

    # Per-fold accuracy
    plt.figure(figsize=(10, 6))
    bars = plt.bar(range(1, CONFIG['n_folds'] + 1), fold_accuracies, color='steelblue', alpha=0.7)
    plt.axhline(y=mean_acc, color='red', linestyle='--', label=f'Mean: {mean_acc:.1f}%')
    for bar, acc in zip(bars, fold_accuracies):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, f'{acc:.1f}%', ha='center')
    plt.xlabel('Fold'); plt.ylabel('Accuracy (%)')
    plt.title('Per-Fold Accuracy (GPU Unfrozen)', fontweight='bold')
    plt.legend(); plt.grid(True, alpha=0.3)
    plt.savefig(f'{PLOT_DIR}/03_per_fold_accuracy.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved {PLOT_DIR}/03_per_fold_accuracy.png")

    # Confusion matrix
    plt.figure(figsize=(8, 6))
    cm = confusion_matrix(all_val_labels, all_val_preds)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=['Cluster 0', 'Cluster 1', 'Cluster 2'],
                yticklabels=['Cluster 0', 'Cluster 1', 'Cluster 2'],
                annot_kws={'size': 14, 'weight': 'bold'})
    plt.xlabel('Predicted'); plt.ylabel('Actual')
    plt.title(f'CV Confusion Matrix (GPU Unfrozen)\nAccuracy: {mean_acc:.1f}%', fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'{PLOT_DIR}/04_confusion_matrix.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved {PLOT_DIR}/04_confusion_matrix.png")

    # Per-class metrics
    precision_per_class = precision_score(all_val_labels, all_val_preds, average=None)
    recall_per_class = recall_score(all_val_labels, all_val_preds, average=None)
    f1_per_class = f1_score(all_val_labels, all_val_preds, average=None)

    x = np.arange(3)
    width = 0.25
    plt.figure(figsize=(10, 6))
    plt.bar(x - width, precision_per_class, width, label='Precision', color='#3498db')
    plt.bar(x, recall_per_class, width, label='Recall', color='#2ecc71')
    plt.bar(x + width, f1_per_class, width, label='F1', color='#e74c3c')
    plt.xlabel('Class'); plt.ylabel('Score')
    plt.title('Per-Class Metrics (GPU Unfrozen)', fontweight='bold')
    plt.xticks(x, ['Cluster 0', 'Cluster 1', 'Cluster 2'])
    plt.legend(); plt.ylim(0, 1.1); plt.grid(True, alpha=0.3)
    plt.savefig(f'{PLOT_DIR}/05_per_class_metrics.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved {PLOT_DIR}/05_per_class_metrics.png")

    # ROC Curves
    all_val_probs_arr = np.array(all_val_probs)
    all_val_labels_bin = label_binarize(all_val_labels, classes=[0, 1, 2])
    plt.figure(figsize=(10, 8))
    colors = ['#e74c3c', '#3498db', '#2ecc71']
    for i in range(3):
        fpr, tpr, _ = roc_curve(all_val_labels_bin[:, i], all_val_probs_arr[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, color=colors[i], linewidth=2, label=f'Cluster {i} (AUC={roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate'); plt.ylabel('True Positive Rate')
    plt.title('ROC Curves (GPU Unfrozen)', fontweight='bold')
    plt.legend(loc='lower right'); plt.grid(True, alpha=0.3)
    plt.savefig(f'{PLOT_DIR}/06_roc_curves.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved {PLOT_DIR}/06_roc_curves.png")

    # Additional metric plots
    # 1. Validation F1 curves
    plt.figure(figsize=(10, 6))
    for i, history in enumerate(fold_histories):
        plt.plot(history['val_f1'], label=f'Fold {i+1}')
    plt.xlabel('Epoch'); plt.ylabel('Validation F1')
    plt.title('Validation F1 Curves', fontweight='bold')
    plt.legend(); plt.grid(True, alpha=0.3); plt.tight_layout()
    plt.savefig(f'{PLOT_DIR}/07_validation_f1_curves.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved {PLOT_DIR}/07_validation_f1_curves.png")

    # 2. Learning rate schedule
    plt.figure(figsize=(10, 6))
    for i, history in enumerate(fold_histories):
        plt.plot(history['learning_rate'], label=f'Fold {i+1}')
    plt.xlabel('Epoch'); plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedule', fontweight='bold')
    plt.legend(); plt.grid(True, alpha=0.3); plt.tight_layout()
    plt.savefig(f'{PLOT_DIR}/08_learning_rate_schedule.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved {PLOT_DIR}/08_learning_rate_schedule.png")

    # 3. Per-fold F1 score
    plt.figure(figsize=(10, 6))
    bars = plt.bar(range(1, CONFIG['n_folds'] + 1), fold_f1_scores, color='purple', alpha=0.7)
    plt.axhline(y=mean_f1, color='red', linestyle='--', label=f'Mean: {mean_f1:.3f}')
    for bar, f1 in zip(bars, fold_f1_scores):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, f'{f1:.3f}', ha='center')
    plt.xlabel('Fold'); plt.ylabel('F1 Score')
    plt.title('Per-Fold F1 Score', fontweight='bold')
    plt.legend(); plt.grid(True, alpha=0.3)
    plt.savefig(f'{PLOT_DIR}/09_per_fold_f1_score.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved {PLOT_DIR}/09_per_fold_f1_score.png")

    # 4. Normalized CV confusion matrix
    plt.figure(figsize=(8, 6))
    cm_norm = confusion_matrix(all_val_labels, all_val_preds, normalize='true')
    sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues', 
                xticklabels=['Cluster 0', 'Cluster 1', 'Cluster 2'],
                yticklabels=['Cluster 0', 'Cluster 1', 'Cluster 2'],
                annot_kws={'size': 14, 'weight': 'bold'})
    plt.xlabel('Predicted'); plt.ylabel('Actual')
    plt.title('CV Confusion Matrix (Normalized)', fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'{PLOT_DIR}/10_cv_confusion_matrix_normalized.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved {PLOT_DIR}/10_cv_confusion_matrix_normalized.png")

    # 5. Class distribution
    plt.figure(figsize=(8, 6))
    class_counts = pd.Series(labels).value_counts().sort_index()
    plt.bar(class_counts.index, class_counts.values, color=colors)
    plt.xlabel('Class'); plt.ylabel('Count')
    plt.title('Class Distribution', fontweight='bold')
    plt.xticks([0, 1, 2], ['Cluster 0', 'Cluster 1', 'Cluster 2'])
    plt.grid(True, alpha=0.3)
    plt.savefig(f'{PLOT_DIR}/11_class_distribution.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved {PLOT_DIR}/11_class_distribution.png")

    # 6. Precision-Recall curves
    plt.figure(figsize=(10, 8))
    for i in range(3):
        pr, rc, _ = precision_recall_curve(all_val_labels_bin[:, i], all_val_probs_arr[:, i])
        pr_auc = auc(rc, pr)
        plt.plot(rc, pr, color=colors[i], linewidth=2, label=f'Cluster {i} (AUC={pr_auc:.3f})')
    plt.xlabel('Recall'); plt.ylabel('Precision')
    plt.title('Precision-Recall Curves', fontweight='bold')
    plt.legend(loc='lower left'); plt.grid(True, alpha=0.3)
    plt.savefig(f'{PLOT_DIR}/12_precision_recall_curves.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved {PLOT_DIR}/12_precision_recall_curves.png")

    # 7. Performance summary (bar plot)

    # The following plots require ensemble_preds and ensemble_confidences, so move them after ensemble evaluation.

    # =============================================================================
    # ENSEMBLE + TTA
    # =============================================================================
    print("\n\nSTEP 5: ENSEMBLE + TTA EVALUATION")
    print("-" * 70)

    # FINAL RESULTS (ConvNeXt V2 Nano - GPU UNFROZEN):
    # {'─'*50}
    ensemble_preds = []
    ensemble_confidences = []

    print("Running ensemble prediction with TTA...")
    for i in range(len(pre_paths)):
        pred, conf = predict_with_tta(fold_models, pre_paths[i], post_paths[i])
        ensemble_preds.append(pred)
        ensemble_confidences.append(conf)
        if (i + 1) % 100 == 0:
            print(f"  Processed {i + 1}/{len(pre_paths)} samples...")

    ensemble_acc = accuracy_score(labels, ensemble_preds) * 100
    ensemble_f1 = f1_score(labels, ensemble_preds, average='macro')

    print(f"\nEnsemble + TTA Results:")
    print(f"  Accuracy: {ensemble_acc:.1f}%")
    print(f"  Macro F1: {ensemble_f1:.3f}")

    # 7. Performance summary (bar plot)
    # Accuracies are percentages, so they are divided by 100 to share the
    # 0-1 score axis with the F1 values.
    summary_names = ['CV Accuracy', 'CV F1', 'Ensemble Acc', 'Ensemble F1']
    summary_scores = [mean_acc / 100, mean_f1, ensemble_acc / 100, ensemble_f1]
    summary_colors = ['#3498db', '#e74c3c', '#2ecc71', '#9b59b6']

    fig_summary, ax_summary = plt.subplots(figsize=(8, 6))
    ax_summary.bar(summary_names, summary_scores, color=summary_colors)
    ax_summary.set_ylim(0, 1.1)
    ax_summary.set_ylabel('Score')
    ax_summary.set_title('Performance Summary', fontweight='bold')
    ax_summary.grid(True, alpha=0.3)
    fig_summary.savefig(f'{PLOT_DIR}/13_performance_summary.png', dpi=150, bbox_inches='tight')
    plt.close(fig_summary)
    print(f"✓ Saved {PLOT_DIR}/13_performance_summary.png")

    # 8. Ensemble confusion matrix
    # The label set is pinned to [0, 1, 2] so the matrix is always 3x3 and
    # matches the three tick labels. Without this, a cluster absent from both
    # labels and predictions would shrink the matrix and misalign the axes.
    cluster_names = ['Cluster 0', 'Cluster 1', 'Cluster 2']
    cm_ens = confusion_matrix(labels, ensemble_preds, labels=[0, 1, 2])

    fig_cm, ax_cm = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm_ens, annot=True, fmt='d', cmap='Greens', ax=ax_cm,
                xticklabels=cluster_names,
                yticklabels=cluster_names,
                annot_kws={'size': 14, 'weight': 'bold'})
    ax_cm.set_xlabel('Predicted'); ax_cm.set_ylabel('Actual')
    ax_cm.set_title('Ensemble Confusion Matrix', fontweight='bold')
    fig_cm.tight_layout()
    fig_cm.savefig(f'{PLOT_DIR}/14_ensemble_confusion_matrix.png', dpi=150, bbox_inches='tight')
    plt.close(fig_cm)
    print(f"✓ Saved {PLOT_DIR}/14_ensemble_confusion_matrix.png")

    # 9. Ensemble confidence distribution
    # Histogram of the per-sample confidence returned by predict_with_tta.
    fig_conf, ax_conf = plt.subplots(figsize=(8, 6))
    ax_conf.hist(ensemble_confidences, bins=30, color='#f39c12', alpha=0.8)
    ax_conf.set_xlabel('Confidence')
    ax_conf.set_ylabel('Count')
    ax_conf.set_title('Ensemble Confidence Distribution', fontweight='bold')
    ax_conf.grid(True, alpha=0.3)
    fig_conf.savefig(f'{PLOT_DIR}/15_ensemble_confidence_distribution.png', dpi=150, bbox_inches='tight')
    plt.close(fig_conf)
    print(f"✓ Saved {PLOT_DIR}/15_ensemble_confidence_distribution.png")

    # 10. Confidence by correctness
    # Overlay the confidence histograms of correctly and incorrectly
    # classified samples. `correct` is a boolean mask aligned with the inputs.
    correct = np.array(ensemble_preds) == np.array(labels)
    confidence_values = np.array(ensemble_confidences)

    fig_corr, ax_corr = plt.subplots(figsize=(8, 6))
    for mask, series_name, series_color in (
        (correct, 'Correct', '#2ecc71'),
        (~correct, 'Incorrect', '#e74c3c'),
    ):
        ax_corr.hist(confidence_values[mask], bins=30, alpha=0.7, label=series_name, color=series_color)
    ax_corr.set_xlabel('Confidence'); ax_corr.set_ylabel('Count')
    ax_corr.set_title('Confidence by Correctness', fontweight='bold')
    ax_corr.legend(); ax_corr.grid(True, alpha=0.3)
    fig_corr.savefig(f'{PLOT_DIR}/16_confidence_by_correctness.png', dpi=150, bbox_inches='tight')
    plt.close(fig_corr)
    print(f"✓ Saved {PLOT_DIR}/16_confidence_by_correctness.png")

    # 11. Confidence by class
    # One confidence histogram per predicted cluster. The array conversions
    # are done once, outside the loop.
    conf_values = np.array(ensemble_confidences)
    pred_values = np.array(ensemble_preds)

    fig_cls, ax_cls = plt.subplots(figsize=(10, 6))
    for class_idx in range(3):
        ax_cls.hist(conf_values[pred_values == class_idx], bins=30, alpha=0.7,
                    label=f'Class {class_idx}', color=colors[class_idx])
    ax_cls.set_xlabel('Confidence'); ax_cls.set_ylabel('Count')
    ax_cls.set_title('Confidence by Predicted Class', fontweight='bold')
    ax_cls.legend(); ax_cls.grid(True, alpha=0.3)
    fig_cls.savefig(f'{PLOT_DIR}/17_confidence_by_class.png', dpi=150, bbox_inches='tight')
    plt.close(fig_cls)
    print(f"✓ Saved {PLOT_DIR}/17_confidence_by_class.png")

    # 12. CV vs Ensemble comparison
    # All four scores on a 0-100 scale: accuracies are already percentages,
    # and the F1 fractions are multiplied by 100.
    comparison = {
        'CV Accuracy': (mean_acc, '#3498db'),
        'Ensemble Accuracy': (ensemble_acc, '#2ecc71'),
        'CV F1': (mean_f1 * 100, '#e74c3c'),
        'Ensemble F1': (ensemble_f1 * 100, '#9b59b6'),
    }
    bar_heights = [height for height, _ in comparison.values()]
    bar_colors = [color for _, color in comparison.values()]

    fig_cmp, ax_cmp = plt.subplots(figsize=(8, 6))
    ax_cmp.bar(list(comparison), bar_heights, color=bar_colors, alpha=0.8)
    ax_cmp.set_ylabel('Score (%)')
    ax_cmp.set_title('CV vs Ensemble Comparison', fontweight='bold')
    ax_cmp.set_ylim(0, 110)
    ax_cmp.grid(True, alpha=0.3)
    fig_cmp.savefig(f'{PLOT_DIR}/18_cv_vs_ensemble_comparison.png', dpi=150, bbox_inches='tight')
    plt.close(fig_cmp)
    print(f"✓ Saved {PLOT_DIR}/18_cv_vs_ensemble_comparison.png")

    # =============================================================================
    # SAVE RESULTS
    # =============================================================================
    print("\n\nSTEP 6: SAVING RESULTS")
    print("-" * 70)

    # Copy of df with the ensemble's predicted cluster and confidence attached
    # to each row (assign returns a new frame; df itself is left untouched).
    results_csv = 'convnextv2_gpu_ensemble_results.csv'
    results_df = df.assign(
        ensemble_prediction=ensemble_preds,
        confidence=ensemble_confidences,
    )
    results_df.to_csv(results_csv, index=False)
    print(f"✓ Saved {results_csv}")

    # =============================================================================
    # FINAL SUMMARY
    # =============================================================================
    # Closing report: CV mean ± std and ensemble + TTA scores, plus the files
    # written by this run. Accuracies are percentages; F1 values are fractions.
    # NOTE(review): the "Trainable Parameters" count and the fold_1..fold_5
    # checkpoint names are hard-coded text. They will be stale if
    # CONFIG['freeze_backbone'] or CONFIG['n_folds'] changes — consider deriving them.
    print("\n" + "="*70)
    print("GPU OPTIMIZED TRAINING COMPLETE!")
    print("="*70)
    print(f"""
FINAL RESULTS (ConvNeXt V2 Nano - GPU UNFROZEN):
{'─'*50}
  Trainable Parameters:    15,475,347 (ALL)
  Cross-Validation Acc:    {mean_acc:.1f}% ± {std_acc:.1f}%
  Cross-Validation F1:     {mean_f1:.3f} ± {std_f1:.3f}
  Ensemble + TTA Acc:      {ensemble_acc:.1f}%
  Ensemble + TTA F1:       {ensemble_f1:.3f}
{'─'*50}

Saved Files:
  - convnextv2_gpu_fold_1.pth to convnextv2_gpu_fold_5.pth
  - convnextv2_gpu_ensemble_results.csv
  - {PLOT_DIR}/*.png (metric plots)
""")

In [None]:
# Release GPU memory that PyTorch's caching allocator still holds after the run.
torch.cuda.empty_cache()