# Vanilla CNN

In [None]:
# Uncomment if needed:
# !pip install -q torch torchvision scikit-learn matplotlib seaborn

# Standard libraries
import os
import random
import numpy as np
import zipfile
from pathlib import Path

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Metrics
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    roc_curve,
    confusion_matrix
)

# Misc
from tqdm import tqdm
import time

print("All libraries imported successfully!")

# Report the runtime environment before anything stateful happens.
print("PyTorch version: {}".format(torch.__version__))
print("CUDA available: {}".format(torch.cuda.is_available()))


# One seed shared by every random number generator this notebook touches.
RANDOM_SEED = 42

# Seed Python, NumPy, PyTorch (CPU) and PyTorch (current CUDA device), in
# that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(RANDOM_SEED)

# Disable the cuDNN autotuner and request deterministic kernels so repeated
# runs pick the same algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# Colab helper module (presumably only importable inside a Colab runtime).
from google.colab import drive

# Mount Google Drive at /content/drive so the dataset archive and the model
# directory can be addressed as ordinary filesystem paths.
drive.mount('/content/drive')

print("‚úÖ Google Drive mounted successfully!")

# Absolute path of the dataset archive on the mounted Drive.
# NOTE(review): this path spells the course folder "(DeepLearning)" while
# CONFIG['save_dir'] spells it "(Deep Learning)" -- confirm which spelling
# matches the real Drive folder.
ZIP_PATH = '/content/drive/MyDrive/NEU - MS CS/3_SEM/CS - 7150 (DeepLearning)/midterm-proj-1/data/faceforensics_dataset.zip'

# Report the archive size when it exists. A missing archive is only reported
# here, not raised; execution continues to the extraction step.
if os.path.exists(ZIP_PATH):
    file_size = os.path.getsize(ZIP_PATH) / (1024**3)  # GB
    print(f"‚úÖ Found dataset: {os.path.basename(ZIP_PATH)}")
    print(f"   Size: {file_size:.2f} GB")
else:
    print(f"‚ùå File not found: {ZIP_PATH}")
    print(f"   Please check the path!")

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"üñ•Ô∏è Device Configuration:")
print(f"   Device: {device}")

# Describe the selected hardware so the run log records what was used.
if device.type != 'cuda':
    print(f"   ‚ö†Ô∏è GPU not available, using CPU")
    print(f"   Note: Training will be much slower on CPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"   GPU Name: {gpu_name}")
    print(f"   GPU Memory: {gpu_memory_gb:.2f} GB")
    print(f"   ‚úÖ Using GPU for training!")

# Central run configuration. Later cells read these keys when building the
# DataLoaders, the model, the optimizer and the checkpoint paths.
CONFIG = {
    # Data
    'batch_size': 64,              # Number of samples per batch
    'num_workers': 2,              # Parallel data loading (2 for Colab)

    # Training
    'num_epochs': 10,              # Total training epochs
    'learning_rate': 0.001,        # Learning rate for optimizer

    # Checkpoints
    'checkpoint_interval': 10000,  # Unit is samples, not batches (10k samples)
    'save_dir': '/content/drive/MyDrive/NEU - MS CS/3_SEM/CS - 7150 (Deep Learning)/midterm-proj-1/models',

    # Data split
    # NOTE(review): the split cell below derives its sizes from a 90/10 ratio
    # and does not read these two keys -- confirm whether they are still needed.
    'train_size': 90000,           # 90k for training
    'test_size': 10000,            # 10k for testing

    # Model
    'input_channels': 3,           # RGB images
    'image_size': 64,              # 64x64 images
    'num_classes': 2,              # Binary: Real vs Fake
}

# Create the checkpoint directory up front; exist_ok makes re-running this
# cell harmless.
os.makedirs(CONFIG['save_dir'], exist_ok=True)

# Echo every setting so the run log records the exact configuration used.
print("üìã Configuration:")
for key, value in CONFIG.items():
    print(f"   {key}: {value}")

print(f"\n‚úÖ Configuration set!")

# Directory that receives the unpacked archive contents.
EXTRACT_DIR = '/content/extracted_data'

# Create directory if it doesn't exist
os.makedirs(EXTRACT_DIR, exist_ok=True)

print(f"üìÇ Extracting to: {EXTRACT_DIR}")
print(f"‚è≥ This may take a moment...")

# Wall-clock timer for the extraction step only.
start_time = time.time()

# Unpack the whole archive; the context manager closes the zip handle even if
# extraction raises. Opening ZIP_PATH is what fails when the file is missing.
with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
    zip_ref.extractall(EXTRACT_DIR)

elapsed = time.time() - start_time

print(f"‚úÖ Extraction complete in {elapsed:.2f} seconds")

# List the regular files directly inside EXTRACT_DIR with their sizes;
# sub-directories are skipped and not descended into.
print(f"\nüìÑ Extracted files:")
extracted_files = os.listdir(EXTRACT_DIR)
for file in extracted_files:
    file_path = os.path.join(EXTRACT_DIR, file)
    if os.path.isfile(file_path):
        size_mb = os.path.getsize(file_path) / (1024**2)
        print(f"   ‚Ä¢ {file} ({size_mb:.2f} MB)")


print("\nüì• LOADING TENSORS")
print("="*60)

# Define paths to extracted tensor files
X_path = os.path.join(EXTRACT_DIR, 'FaceForensics_X.pt')
y_path = os.path.join(EXTRACT_DIR, 'FaceForensics_y.pt')

# Check if files exist
if not os.path.exists(X_path):
    print(f"‚ùå X tensor not found at: {X_path}")
    raise FileNotFoundError("X tensor file missing!")

if not os.path.exists(y_path):
    print(f"‚ùå y tensor not found at: {y_path}")
    raise FileNotFoundError("y tensor file missing!")

print(f"üìÇ Loading X tensor from: {os.path.basename(X_path)}")
X = torch.load(X_path)

print(f"üìÇ Loading y tensor from: {os.path.basename(y_path)}")
y = torch.load(y_path)

print(f"\n‚úÖ Tensors loaded successfully!")

# Display tensor information
print(f"\nüìä Tensor Information:")
print(f"   X shape: {X.shape} (N, C, H, W)")
print(f"   y shape: {y.shape} (N,)")
print(f"   X dtype: {X.dtype}")
print(f"   y dtype: {y.dtype}")
print(f"   X range: [{X.min():.4f}, {X.max():.4f}]")
print(f"   y unique values: {torch.unique(y).tolist()}")

# Memory usage
X_memory = X.element_size() * X.nelement() / (1024**3)
y_memory = y.element_size() * y.nelement() / (1024**2)

print(f"\nüíæ Memory Usage:")
print(f"   X tensor: {X_memory:.2f} GB")
print(f"   y tensor: {y_memory:.2f} MB")

# Count samples per class. Label 0 is reported as "Real"; every other label
# value is reported as "Fake".
unique_labels, counts = torch.unique(y, return_counts=True)

print(f"üìä Label Distribution:")
for label, count in zip(unique_labels, counts):
    label_name = "Real" if label == 0 else "Fake"
    percentage = (count / len(y)) * 100
    print(f"   {label} ({label_name}): {count:,} samples ({percentage:.1f}%)")

# Ratio of the first class count to the second. With fewer than two classes
# there is no second count, so report an infinite ratio (which the check
# below classifies as imbalanced) instead of raising IndexError.
if len(counts) >= 2:
    balance_ratio = counts[0].item() / counts[1].item()
else:
    balance_ratio = float('inf')
print(f"\n‚öñÔ∏è Balance Ratio: {balance_ratio:.3f}:1")

if 0.95 <= balance_ratio <= 1.05:
    print(f"   ‚úÖ Dataset is well balanced!")
else:
    print(f"   ‚ö†Ô∏è Dataset has imbalance")

# Scan X once for NaN and once for Inf and reuse the results; each scan
# walks the whole tensor, so repeating them is pure overhead.
has_nan = torch.isnan(X).any().item()
has_inf = torch.isinf(X).any().item()

print(f"\nüîç Data Quality Checks:")
print(f"   NaN in X: {has_nan}")
print(f"   Inf in X: {has_inf}")
print(f"   All y values valid (0 or 1): {torch.all((y == 0) | (y == 1)).item()}")

if has_nan or has_inf:
    print(f"   ‚ùå Warning: Data contains NaN or Inf values!")
else:
    print(f"   ‚úÖ No NaN or Inf values detected")


print("\nüîÄ SHUFFLING DATA")
print("="*60)

# Get total number of samples
num_samples = X.shape[0]

# Create indices for all samples
indices = torch.randperm(num_samples)

# Shuffle X and y using the same indices
X_shuffled = X[indices]
y_shuffled = y[indices]

print(f"‚úÖ Data shuffled with random seed {RANDOM_SEED}")

# Verify shuffle worked
print(f"\nüîç Verification:")
print(f"   Original first 10 labels: {y[:10].tolist()}")
print(f"   Shuffled first 10 labels: {y_shuffled[:10].tolist()}")
print(f"   (Should be different)")

# Verify data integrity after shuffle
assert X_shuffled.shape == X.shape, "X shape changed after shuffle!"
assert y_shuffled.shape == y.shape, "y shape changed after shuffle!"
assert len(torch.unique(y_shuffled)) == 2, "Labels corrupted after shuffle!"

print(f"\n‚úÖ Shuffle verified - data integrity maintained")

# Replace original tensors with shuffled versions
X = X_shuffled
y = y_shuffled

# Clean up
del X_shuffled, y_shuffled

print("\n‚úÇÔ∏è TRAIN/TEST SPLIT")
print("="*60)

# Use 90/10 ratio instead of fixed numbers
# (test_ratio is only printed; the test size is whatever remains after the
# train slice).
train_ratio = 0.9
test_ratio = 0.1

# Calculate split sizes based on actual dataset size
num_samples = len(X)
train_size = int(num_samples * train_ratio)
test_size = num_samples - train_size  # Remaining samples go to test

print(f"üìä Split Configuration:")
print(f"   Total samples: {num_samples:,}")
print(f"   Train ratio: {train_ratio*100:.0f}%")
print(f"   Test ratio: {test_ratio*100:.0f}%")
print(f"   Calculated train size: {train_size:,}")
print(f"   Calculated test size: {test_size:,}")

# The two sizes must account for every sample.
assert train_size + test_size == num_samples, "Split calculation error!"

# Contiguous split of the already-shuffled data: the first train_size samples
# train the model, the rest are held out. The split is not stratified; class
# balance is only inspected below.
X_train = X[:train_size]
y_train = y[:train_size]

X_test = X[train_size:]  # All remaining samples
y_test = y[train_size:]

print(f"\n‚úÖ Split Complete:")
print(f"   Train samples: {len(X_train):,} ({len(X_train)/num_samples*100:.1f}%)")
print(f"   Test samples: {len(X_test):,} ({len(X_test)/num_samples*100:.1f}%)")

# Per-class sample counts in each split (0 = Real, 1 = Fake).
train_real = (y_train == 0).sum().item()
train_fake = (y_train == 1).sum().item()
test_real = (y_test == 0).sum().item()
test_fake = (y_test == 1).sum().item()

print(f"\nüìä Train Set Distribution:")
print(f"   Real (0): {train_real:,} ({train_real/len(y_train)*100:.1f}%)")
print(f"   Fake (1): {train_fake:,} ({train_fake/len(y_train)*100:.1f}%)")

print(f"\nüìä Test Set Distribution:")
print(f"   Real (0): {test_real:,} ({test_real/len(y_test)*100:.1f}%)")
print(f"   Fake (1): {test_fake:,} ({test_fake/len(y_test)*100:.1f}%)")

# Real-to-fake ratio per split; falls back to 0 when a split has no fake
# samples, which avoids dividing by zero.
train_balance = train_real / train_fake if train_fake > 0 else 0
test_balance = test_real / test_fake if test_fake > 0 else 0

print(f"\n‚öñÔ∏è Balance Check:")
print(f"   Train balance: {train_balance:.3f}:1")
print(f"   Test balance: {test_balance:.3f}:1")

if 0.9 <= train_balance <= 1.1 and 0.9 <= test_balance <= 1.1:
    print(f"   ‚úÖ Both splits are well balanced!")
else:
    print(f"   ‚ö†Ô∏è Some imbalance detected (still acceptable)")

# Drop the names X and y.
# NOTE(review): the train/test tensors above are basic slices, which in
# PyTorch are normally views sharing X's storage, so this likely frees no
# memory despite the message below -- confirm.
del X, y
import gc
gc.collect()

print(f"\n‚úÖ Split complete and original tensors cleared from memory")

# Wrap the split tensors so each dataset item is an (image, label) pair.
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

print(f"üì¶ Datasets created:")
print(f"   Train dataset: {len(train_dataset):,} samples")
print(f"   Test dataset: {len(test_dataset):,} samples")

# Pinned host memory only speeds up host-to-GPU copies; on a CPU-only runtime
# it is useless and makes DataLoader emit a warning, so enable it only when
# CUDA is available.
use_pinned_memory = torch.cuda.is_available()

# Training loader reshuffles every epoch; the test loader keeps a fixed order
# so predictions line up with labels.
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    pin_memory=use_pinned_memory,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=use_pinned_memory,
)

print(f"\nüîÑ DataLoaders created:")
print(f"   Batch size: {CONFIG['batch_size']}")
print(f"   Train batches: {len(train_loader):,}")
print(f"   Test batches: {len(test_loader):,}")

# Samples seen per epoch is the dataset size. Multiplying batch count by batch
# size over-counts whenever the last batch is only partially filled.
samples_per_epoch = len(train_dataset)
print(f"\nüìä Training Info:")
print(f"   Samples per epoch: {samples_per_epoch:,}")
print(f"   Batches per epoch: {len(train_loader):,}")
print(f"   Epochs planned: {CONFIG['num_epochs']}")
print(f"   Total training steps: {len(train_loader) * CONFIG['num_epochs']:,}")

# Convert the sample-based checkpoint interval to batches. Clamp to at least
# one batch so an interval smaller than the batch size cannot produce a
# division by zero in the estimate below.
batches_per_checkpoint = max(1, CONFIG['checkpoint_interval'] // CONFIG['batch_size'])
print(f"\nüíæ Checkpoint Info:")
print(f"   Checkpoint every: {CONFIG['checkpoint_interval']:,} samples")
print(f"   That's every: {batches_per_checkpoint} batches")
print(f"   Expected checkpoints per epoch: {len(train_loader) // batches_per_checkpoint}")

print(f"\n‚úÖ DataLoaders ready for training!")

class SimpleCNN(nn.Module):
    """
    Simple CNN for binary classification (Real vs Fake).

    Architecture: two Conv -> ReLU -> MaxPool blocks, then Flatten and a
    two-layer fully connected head that outputs raw class logits.
    """

    def __init__(self, input_channels=3, num_classes=2, image_size=64):
        """
        Initialize the CNN model.

        Args:
            input_channels (int): Number of input channels (3 for RGB).
            num_classes (int): Number of output classes (2 for binary).
            image_size (int): Height and width of the square input images.
                Defaults to 64, which reproduces the original fixed
                32 * 16 * 16 = 8192 flattened feature size.

        Raises:
            ValueError: If image_size is too small to survive the two
                2x2 pooling stages.
        """
        super().__init__()

        # Each 2x2/stride-2 pooling stage halves the spatial size (floor
        # division), so two stages leave (image_size // 2) // 2 per side.
        pooled_size = (image_size // 2) // 2
        if pooled_size < 1:
            raise ValueError(
                f"image_size must be at least 4 for two pooling stages, got {image_size}"
            )

        # Conv block 1: (batch, C, S, S) -> (batch, 16, S/2, S/2).
        # 3x3 kernel with padding 1 keeps the spatial size; pooling halves it.
        self.conv1 = nn.Conv2d(
            in_channels=input_channels,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Conv block 2: (batch, 16, S/2, S/2) -> (batch, 32, S/4, S/4).
        self.conv2 = nn.Conv2d(
            in_channels=16,
            out_channels=32,
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head: flatten the 32 feature maps and map them to logits.
        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(
            in_features=32 * pooled_size * pooled_size,  # 8192 for 64x64 input
            out_features=128
        )
        self.relu3 = nn.ReLU()

        self.fc2 = nn.Linear(
            in_features=128,
            out_features=num_classes
        )

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): Input images of shape
                (batch, input_channels, image_size, image_size).

        Returns:
            torch.Tensor: Class logits of shape (batch, num_classes).
        """
        # Conv block 1 (shapes shown for the default 3x64x64 input)
        x = self.conv1(x)      # (batch, 3, 64, 64) -> (batch, 16, 64, 64)
        x = self.relu1(x)
        x = self.pool1(x)      # (batch, 16, 64, 64) -> (batch, 16, 32, 32)

        # Conv block 2
        x = self.conv2(x)      # (batch, 16, 32, 32) -> (batch, 32, 32, 32)
        x = self.relu2(x)
        x = self.pool2(x)      # (batch, 32, 32, 32) -> (batch, 32, 16, 16)

        # Flatten
        x = self.flatten(x)    # (batch, 32, 16, 16) -> (batch, 8192)

        # Fully connected layers
        x = self.fc1(x)        # (batch, 8192) -> (batch, 128)
        x = self.relu3(x)
        x = self.fc2(x)        # (batch, 128) -> (batch, num_classes)

        return x

# Build the network from the shared configuration.
model = SimpleCNN(
    input_channels=CONFIG['input_channels'],
    num_classes=CONFIG['num_classes']
)

# Move model to GPU/CPU
model = model.to(device)

print(f"‚úÖ Model created and moved to {device}")

# Count parameters (all of them, and only those that receive gradients).
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nüìä Model Statistics:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
# 4 bytes per parameter, i.e. float32 storage.
print(f"   Model size: {total_params * 4 / (1024**2):.2f} MB (float32)")

print(model)

print("\n" + "="*60)

# Smoke-test the forward pass with one random image. The shape comes from
# CONFIG rather than hard-coded numbers, so a configuration that disagrees
# with the model fails here instead of in the middle of training.
print("\nüß™ Testing forward pass...")
dummy_input = torch.randn(
    1,
    CONFIG['input_channels'],
    CONFIG['image_size'],
    CONFIG['image_size'],
).to(device)  # Batch of 1 image

# No gradients are needed for a shape check, so skip building the autograd graph.
with torch.no_grad():
    dummy_output = model(dummy_input)

print(f"‚úÖ Forward pass successful!")
print(f"   Input shape: {dummy_input.shape}")
print(f"   Output shape: {dummy_output.shape}")
print(f"   Output (logits): {dummy_output}")
print(f"\nüí° Interpretation:")
print(f"   Output has 2 values: [score_for_class_0, score_for_class_1]")
print(f"   Higher score = model's prediction for that class")

# Cross-entropy over the two class logits.
criterion = nn.CrossEntropyLoss()

print("\n".join([
    "üìâ Loss Function: CrossEntropyLoss",
    "   ‚Ä¢ Combines LogSoftmax + NLLLoss",
    "   ‚Ä¢ Suitable for multi-class classification",
    "   ‚Ä¢ Expects raw logits (not probabilities)",
]))

# Adam over every model parameter, using the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])

print("\n".join([
    "\n‚öôÔ∏è Optimizer: Adam",
    "   ‚Ä¢ Learning rate: {}".format(CONFIG['learning_rate']),
    "   ‚Ä¢ Adaptive learning rates per parameter",
    "   ‚Ä¢ Momentum and adaptive estimates",
]))

print("\n‚úÖ Training components ready!")

# Section banner for the helper definitions that follow.
print("\nüîß DEFINING TRAINING HELPER FUNCTIONS")
print("=" * 60)

def save_checkpoint(model, optimizer, epoch, batch_idx, loss, accuracy, filepath):
    """
    Write a resumable training snapshot to disk.

    The snapshot is a dict holding the model and optimizer state dicts plus
    the position (epoch, batch index) and metrics (loss, accuracy) supplied
    by the caller.

    Args:
        model: The neural network model
        optimizer: The optimizer
        epoch: Current epoch number
        batch_idx: Current batch index
        loss: Current loss value
        accuracy: Current accuracy
        filepath: Where to save the checkpoint
    """
    snapshot = dict(
        epoch=epoch,
        batch_idx=batch_idx,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
        accuracy=accuracy,
    )
    torch.save(snapshot, filepath)

    # Report only the file name; the directory is the same for every checkpoint.
    saved_name = os.path.basename(filepath)
    print(f"   üíæ Checkpoint saved: {saved_name}")


def load_checkpoint(model, optimizer, filepath, map_location=None):
    """
    Load a checkpoint written by save_checkpoint() into model and optimizer.

    Args:
        model: The neural network model (updated in place)
        optimizer: The optimizer (updated in place)
        filepath: Path to checkpoint file
        map_location: Device onto which stored tensors are loaded. When
            omitted, the device of the model's first parameter is used (CPU
            for a parameter-less model), so a checkpoint saved on a GPU can
            be restored on a CPU-only runtime.

    Returns:
        epoch, batch_idx: Resume training from these values
    """
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else 'cpu'

    # Without map_location, torch.load tries to restore tensors onto the
    # device they were saved from and fails when that device is unavailable.
    checkpoint = torch.load(filepath, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    batch_idx = checkpoint['batch_idx']
    loss = checkpoint['loss']
    accuracy = checkpoint['accuracy']

    print(f"‚úÖ Checkpoint loaded: {os.path.basename(filepath)}")
    print(f"   Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss:.4f}, Acc: {accuracy:.2f}%")

    return epoch, batch_idx


def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch,
                   checkpoint_interval, save_dir):
    """
    Train model for one epoch with periodic checkpointing.

    Args:
        model: Neural network
        train_loader: DataLoader for training data
        criterion: Loss function
        optimizer: Optimizer
        device: GPU/CPU device
        epoch: Current epoch number (zero-based; displayed and saved as epoch+1
            in file names)
        checkpoint_interval: Save checkpoint every N samples (must be > 0)
        save_dir: Directory to save checkpoints

    Returns:
        avg_loss, avg_accuracy: Mean of the per-batch losses and the accuracy
        in percent over all samples seen. (0.0, 0.0) when the loader yields
        no batches.

    Raises:
        ValueError: If checkpoint_interval is not positive.
    """
    if checkpoint_interval <= 0:
        raise ValueError(
            f"checkpoint_interval must be a positive number of samples, got {checkpoint_interval}"
        )

    model.train()  # Set model to training mode

    running_loss = 0.0
    correct = 0
    total = 0
    samples_processed = 0
    num_batches = 0

    # Progress bar for batches
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for batch_idx, (images, labels) in enumerate(pbar):
        # Move data to device
        images = images.to(device)
        labels = labels.to(device)
        batch_size = images.size(0)

        # Standard optimization step: reset grads, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accuracy bookkeeping on detached logits (no autograd involvement).
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Read the scalar loss once; each .item() call synchronizes with the device.
        batch_loss = loss.item()
        running_loss += batch_loss
        samples_processed += batch_size
        num_batches += 1

        # Update progress bar
        current_accuracy = 100 * correct / total
        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{current_accuracy:.2f}%'
        })

        # A multiple of checkpoint_interval was crossed by this batch exactly
        # when the remainder is smaller than the batch just processed.
        if (samples_processed % checkpoint_interval) < batch_size:
            avg_loss = running_loss / num_batches
            checkpoint_path = os.path.join(
                save_dir,
                f'checkpoint_epoch{epoch+1}_batch{batch_idx+1}.pt'
            )
            save_checkpoint(
                model, optimizer, epoch, batch_idx,
                avg_loss, current_accuracy, checkpoint_path
            )

    # An empty loader has nothing to average; avoid dividing by zero.
    if num_batches == 0:
        return 0.0, 0.0

    # Calculate epoch averages
    avg_loss = running_loss / num_batches
    avg_accuracy = 100 * correct / total

    return avg_loss, avg_accuracy

print("\n".join([
    "‚úÖ Helper functions defined:",
    "   ‚Ä¢ save_checkpoint() - Save model state",
    "   ‚Ä¢ load_checkpoint() - Load model state",
    "   ‚Ä¢ train_one_epoch() - Train for one epoch with checkpoints",
]))


# Per-epoch metrics; one entry is appended to each list after every epoch.
history = dict(train_loss=[], train_accuracy=[])

# Figures used only for the summary printed below.
total_batches = len(train_loader)
total_samples = len(train_loader.dataset)
batches_per_checkpoint = CONFIG['checkpoint_interval'] // CONFIG['batch_size']

print(f"\nüìä Training Configuration:")
print(f"   Total epochs: {CONFIG['num_epochs']}")
print(f"   Batches per epoch: {total_batches:,}")
print(f"   Samples per epoch: {total_samples:,}")
print(f"   Checkpoint every {CONFIG['checkpoint_interval']:,} samples (~{batches_per_checkpoint} batches)")
# The time figures are fixed rough guesses (12 minutes per epoch), not measurements.
print(f"\n‚è±Ô∏è Estimated time per epoch: ~10-15 minutes (with GPU)")
print(f"‚è±Ô∏è Total estimated time: ~{CONFIG['num_epochs'] * 12} minutes")

print(f"\n{'='*60}")
print("üéØ TRAINING STARTED")
print(f"{'='*60}\n")

# Wall-clock timer covering the whole training run.
start_time = time.time()

for epoch in range(CONFIG['num_epochs']):
    print(f"\nüìÖ Epoch {epoch + 1}/{CONFIG['num_epochs']}")
    print("-" * 60)

    # One full pass over the training data (writes mid-epoch checkpoints itself).
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device, epoch,
        CONFIG['checkpoint_interval'], CONFIG['save_dir'],
    )

    # Record this epoch's metrics.
    history['train_loss'].append(train_loss)
    history['train_accuracy'].append(train_acc)

    print(f"\nüìä Epoch {epoch + 1} Summary:")
    print(f"   Average Loss: {train_loss:.4f}")
    print(f"   Average Accuracy: {train_acc:.2f}%")

    # End-of-epoch checkpoint, tagged with the index of the last batch.
    epoch_checkpoint_path = os.path.join(CONFIG['save_dir'], f'model_epoch{epoch+1}.pt')
    last_batch_idx = len(train_loader) - 1
    save_checkpoint(
        model, optimizer, epoch, last_batch_idx,
        train_loss, train_acc, epoch_checkpoint_path
    )

    print(f"   üíæ Epoch checkpoint saved")

# Final summary once every epoch has run.
total_time = time.time() - start_time
print(f"\n{'='*60}")
print("üéâ TRAINING COMPLETE!")
print(f"{'='*60}")
print(f"‚è±Ô∏è Total training time: {total_time/60:.2f} minutes")
print(f"üìä Final Training Accuracy: {history['train_accuracy'][-1]:.2f}%")
print(f"üìâ Final Training Loss: {history['train_loss'][-1]:.4f}")

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# X-axis follows the number of epochs actually recorded, so the plot still
# works when training stopped early and history is shorter than
# CONFIG['num_epochs'] (mismatched x/y lengths make matplotlib raise).
epochs_recorded = range(1, len(history['train_loss']) + 1)

# Plot loss
axes[0].plot(epochs_recorded, history['train_loss'],
             marker='o', linewidth=2, markersize=8, color='#e74c3c')
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Loss Over Epochs', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].set_xticks(epochs_recorded)

# Plot accuracy
axes[1].plot(epochs_recorded, history['train_accuracy'],
             marker='o', linewidth=2, markersize=8, color='#27ae60')
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy (%)', fontsize=12)
axes[1].set_title('Training Accuracy Over Epochs', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3)
axes[1].set_xticks(epochs_recorded)
axes[1].set_ylim([0, 100])

plt.tight_layout()
plt.show()

print("‚úÖ Training curves displayed")

# Accuracy change from the first to the last recorded epoch. Skipped when no
# epoch was recorded, since there is nothing to compare.
if history['train_accuracy']:
    initial_acc = history['train_accuracy'][0]
    final_acc = history['train_accuracy'][-1]
    improvement = final_acc - initial_acc

    print(f"\nüìä Training Progress:")
    print(f"   Initial accuracy: {initial_acc:.2f}%")
    print(f"   Final accuracy: {final_acc:.2f}%")
    # The "+" format flag prints the real sign; a hard-coded "+" would render
    # a drop in accuracy as "+-1.23%".
    print(f"   Improvement: {improvement:+.2f}%")

# Destination of the final snapshot inside the checkpoint directory.
final_model_path = os.path.join(CONFIG['save_dir'], 'final_model.pt')

# Bundle weights, optimizer state, the metric history and the configuration
# into a single file.
final_snapshot = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_loss': history['train_loss'],
    'train_accuracy': history['train_accuracy'],
    'config': CONFIG,
    'num_epochs': CONFIG['num_epochs'],
}
torch.save(final_snapshot, final_model_path)

print(f"‚úÖ Final model saved to:")
print(f"   {final_model_path}")

# Numbered, alphabetically sorted listing of every .pt file in the checkpoint
# directory, with sizes in MB.
print(f"\nüìÇ Saved checkpoints:")
checkpoint_files = sorted(
    name for name in os.listdir(CONFIG['save_dir']) if name.endswith('.pt')
)
for i, filename in enumerate(checkpoint_files, start=1):
    filepath = os.path.join(CONFIG['save_dir'], filename)
    size_mb = os.path.getsize(filepath) / (1024**2)
    print(f"   {i}. {filename} ({size_mb:.2f} MB)")

print(f"\n‚úÖ All checkpoints saved to Google Drive")
print(f"   These will persist even after session ends!")

# Load the final trained model
final_model_path = os.path.join(CONFIG['save_dir'], 'final_model.pt')

if os.path.exists(final_model_path):
    # Map stored tensors onto the device the model currently lives on, so a
    # snapshot written on a GPU still loads in a CPU-only session.
    checkpoint = torch.load(
        final_model_path,
        map_location=next(model.parameters()).device,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"‚úÖ Loaded final model from: {os.path.basename(final_model_path)}")

    # Display training history from checkpoint
    if 'train_accuracy' in checkpoint:
        print(f"\nüìä Training History:")
        print(f"   Final training accuracy: {checkpoint['train_accuracy'][-1]:.2f}%")
        print(f"   Final training loss: {checkpoint['train_loss'][-1]:.4f}")
else:
    # Fall back to whatever weights the in-memory model already holds.
    print(f"‚ö†Ô∏è Final model not found, using current model state")

# Switch to evaluation mode before running inference.
model.eval()
print(f"\n‚úÖ Model set to evaluation mode")
print(f"   (Dropout disabled, BatchNorm in eval mode)")

# Accumulators for per-sample results across all test batches.
all_predictions, all_labels, all_probs = [], [], []

print(f"üìä Test set: {len(test_loader.dataset):,} samples")
print(f"   Processing {len(test_loader)} batches...")

# Inference only: no gradients are tracked inside this block.
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Evaluating"):
        images, labels = images.to(device), labels.to(device)

        # Raw logits for the batch.
        outputs = model(images)

        # Class probabilities (softmax across the class dimension).
        probs = F.softmax(outputs, dim=1)

        # Predicted class index = position of the largest logit.
        predicted = outputs.max(dim=1).indices

        # Move results to host memory and append them one sample at a time.
        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

# Turn the accumulated per-sample lists into NumPy arrays.
all_predictions, all_labels, all_probs = (
    np.array(collected) for collected in (all_predictions, all_labels, all_probs)
)

print(f"\n‚úÖ Evaluation complete!")
print(f"   Predictions collected: {len(all_predictions):,}")

