In [3]:
import os
import shutil
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm

print("="*80)
print("CELL 1: DATA SPLIT - 80/10/10 WITH CLASS BALANCING")
print("="*80)

# ========== CONFIGURATION ==========
SOURCE_DIR = '/kaggle/input/segmentedimages'
OUTPUT_DIR = '/kaggle/working'
SPLIT_DIR = os.path.join(OUTPUT_DIR, 'split_80_10_10')

os.makedirs(SPLIT_DIR, exist_ok=True)

print(f"\n✓ Source directory: {SOURCE_DIR}")
print(f"✓ Output directory: {SPLIT_DIR}")

# ========== 1. COLLECT ALL IMAGES ==========
print("\n[1/4] Collecting all images from source directory...")

image_dict = {}  # {class: [list of image paths]}
for split in ['train', 'val', 'test']:
    split_path = os.path.join(SOURCE_DIR, split)
    if not os.path.exists(split_path):
        continue
    
    for cls in os.listdir(split_path):
        cls_path = os.path.join(split_path, cls)
        if not os.path.isdir(cls_path):
            continue
        
        if cls not in image_dict:
            image_dict[cls] = []
        
        images = [f for f in os.listdir(cls_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        for img in images:
            image_dict[cls].append(os.path.join(cls_path, img))

print(f"✓ Total classes found: {len(image_dict)}")
for cls, images in image_dict.items():
    print(f"  {cls}: {len(images)} images")

total_images = sum(len(v) for v in image_dict.values())
print(f"✓ Total images: {total_images}")

# ========== 2. STRATIFIED SPLIT ==========
print("\n[2/4] Performing stratified split (80/10/10)...")

split_results = {}
for cls, image_paths in image_dict.items():
    # Convert to numpy array
    image_paths = np.array(image_paths)
    
    # First split: 80/20
    train_idx, temp_idx = train_test_split(
        np.arange(len(image_paths)),
        test_size=0.2,
        random_state=42
    )
    
    # Second split: 20 → 50/50 (10% val, 10% test)
    val_idx, test_idx = train_test_split(
        temp_idx,
        test_size=0.5,
        random_state=42
    )
    
    split_results[cls] = {
        'train': image_paths[train_idx].tolist(),
        'val': image_paths[val_idx].tolist(),
        'test': image_paths[test_idx].tolist()
    }
    
    print(f"  {cls}: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")

# ========== 3. CREATE FOLDER STRUCTURE & COPY FILES ==========
print("\n[3/4] Creating folder structure and copying files...")

for split in ['train', 'val', 'test']:
    split_path = os.path.join(SPLIT_DIR, split)
    os.makedirs(split_path, exist_ok=True)
    
    for cls in split_results.keys():
        cls_dir = os.path.join(split_path, cls)
        os.makedirs(cls_dir, exist_ok=True)

print("✓ Folder structure created")

# Copy files
print("✓ Copying files...")
for cls in split_results.keys():
    for split in ['train', 'val', 'test']:
        image_paths = split_results[cls][split]
        dest_dir = os.path.join(SPLIT_DIR, split, cls)
        
        for src_path in tqdm(image_paths, desc=f"Copying {cls}/{split}", leave=False):
            filename = os.path.basename(src_path)
            dest_path = os.path.join(dest_dir, filename)
            shutil.copy2(src_path, dest_path)

print("✓ All files copied successfully")

# ========== 4. VERIFY SPLIT ==========
print("\n[4/4] Verifying split distribution...")

split_summary = {}
for split in ['train', 'val', 'test']:
    split_path = os.path.join(SPLIT_DIR, split)
    split_summary[split] = {}
    
    for cls in os.listdir(split_path):
        cls_dir = os.path.join(split_path, cls)
        if os.path.isdir(cls_dir):
            count = len([f for f in os.listdir(cls_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
            split_summary[split][cls] = count

print("\n" + "="*80)
print("SPLIT SUMMARY (80/10/10)")
print("="*80)

for split in ['train', 'val', 'test']:
    total = sum(split_summary[split].values())
    percentage = (total / total_images) * 100
    print(f"\n{split.upper()} ({percentage:.1f}%):")
    for cls, count in sorted(split_summary[split].items()):
        print(f"  {cls}: {count}")
    print(f"  Total: {total}")

print("\n" + "="*80)
print("SPLIT COMPLETE!")
print("="*80)
print(f"✓ Balanced 80/10/10 split created at: {SPLIT_DIR}")
print(f"✓ Ready to use in model training cells!")
print("="*80)
+++

CELL 1: DATA SPLIT - 80/10/10 WITH CLASS BALANCING

✓ Source directory: /kaggle/input/segmentedimages
✓ Output directory: /kaggle/working/split_80_10_10

[1/4] Collecting all images from source directory...
✓ Total classes found: 4
  DRUSEN: 2057 images
  CNV: 8902 images
  NORMAL: 6303 images
  DME: 2738 images
✓ Total images: 20000

[2/4] Performing stratified split (80/10/10)...
  DRUSEN: Train=1645, Val=206, Test=206
  CNV: Train=7121, Val=890, Test=891
  NORMAL: Train=5042, Val=630, Test=631
  DME: Train=2190, Val=274, Test=274

[3/4] Creating folder structure and copying files...
✓ Folder structure created
✓ Copying files...


                                                                          

✓ All files copied successfully

[4/4] Verifying split distribution...

SPLIT SUMMARY (80/10/10)

TRAIN (79.8%):
  CNV: 7110
  DME: 2184
  DRUSEN: 1640
  NORMAL: 5033
  Total: 15967

VAL (10.0%):
  CNV: 890
  DME: 274
  DRUSEN: 206
  NORMAL: 630
  Total: 2000

TEST (10.0%):
  CNV: 891
  DME: 274
  DRUSEN: 206
  NORMAL: 631
  Total: 2002

SPLIT COMPLETE!
✓ Balanced 80/10/10 split created at: /kaggle/working/split_80_10_10
✓ Ready to use in model training cells!




In [4]:
import shutil
zip_path = '/kaggle/working/split_80_10_10.zip'
shutil.make_archive('/kaggle/working/split_80_10_10', 'zip', '/kaggle/working/split_80_10_10')
print(f"\n✓ Zipped split folder to: {zip_path}")



✓ Zipped split folder to: /kaggle/working/split_80_10_10.zip


In [9]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("CELL 3: MODEL 2 - MobileNetV2 (with 80/10/10 Split)")
print("Complete Metrics: Train, Val, Test")
print("="*80)

# ========== CONFIGURATION ==========
SPLIT_DIR = '/kaggle/working/split_80_10_10'
OUTPUT_DIR = '/kaggle/working'
TRAIN_DIR = os.path.join(SPLIT_DIR, 'train')
VAL_DIR = os.path.join(SPLIT_DIR, 'val')
TEST_DIR = os.path.join(SPLIT_DIR, 'test')

IMG_SIZE = 224
BATCH_SIZE = 64
NUM_WORKERS = 4
NUM_EPOCHS = 15
LEARNING_RATE = 0.0005

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\n✓ Device: {DEVICE}")
print(f"✓ Using 80/10/10 split from: {SPLIT_DIR}")

# ========== 1. AUGMENTATION ==========
print("\n[1/7] Defining augmentation...")

train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=15, p=0.3),
    A.GaussNoise(p=0.1),
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.2),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2(),
], is_check_shapes=False)

val_test_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2(),
], is_check_shapes=False)

print("✓ Augmentation defined")

# ========== 2. DATASET CLASS ==========
print("\n[2/7] Creating dataset class...")

class OCTDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}
        
        classes = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}
        
        for cls in classes:
            cls_path = os.path.join(root_dir, cls)
            images = [f for f in os.listdir(cls_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            for img_name in images:
                self.image_paths.append(os.path.join(cls_path, img_name))
                self.labels.append(self.class_to_idx[cls])
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            image = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label

print("✓ Dataset class created")

# ========== 3. DATALOADERS ==========
print("\n[3/7] Creating dataloaders...")

train_dataset = OCTDataset(TRAIN_DIR, transform=train_transform)
val_dataset = OCTDataset(VAL_DIR, transform=val_test_transform)
test_dataset = OCTDataset(TEST_DIR, transform=val_test_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

num_classes = len(train_dataset.idx_to_class)
class_names = list(train_dataset.idx_to_class.values())

print(f"✓ Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

# ========== 4. MODEL ==========
print("\n[4/7] Loading MobileNetV2...")

model = models.mobilenet_v2(pretrained=True)
for param in list(model.parameters())[:-10]:
    param.requires_grad = False
model.classifier = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.last_channel, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes)
)
model = model.to(DEVICE)
print("✓ MobileNetV2 loaded")

# ========== 5. TRAINING SETUP ==========
print("\n[5/7] Setting up training...")

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

print("✓ Setup complete")

# ========== 6. TRAINING LOOP WITH METRICS ==========
print("\n[6/7] Training...")

best_val_acc = 0.0
best_model_path = os.path.join(OUTPUT_DIR, 'mobilenetv2_best_80_10_10.pth')
train_losses, val_losses = [], []
train_accs, val_accs = [], []
train_precisions, val_precisions = [], []
train_recalls, val_recalls = [], []
train_f1s, val_f1s = [], []

for epoch in range(NUM_EPOCHS):
    # ===== TRAINING =====
    model.train()
    train_loss = 0.0
    train_preds, train_labels = [], []
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [TRAIN]", leave=False)
    for images, labels in pbar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        train_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())
    
    train_loss /= len(train_loader)
    train_acc = accuracy_score(train_labels, train_preds)
    train_prec = precision_score(train_labels, train_preds, average='weighted', zero_division=0)
    train_rec = recall_score(train_labels, train_preds, average='weighted', zero_division=0)
    train_f1 = f1_score(train_labels, train_preds, average='weighted', zero_division=0)
    
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    train_precisions.append(train_prec)
    train_recalls.append(train_rec)
    train_f1s.append(train_f1)
    
    # ===== VALIDATION =====
    model.eval()
    val_loss = 0.0
    val_preds, val_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
    
    val_loss /= len(val_loader)
    val_acc = accuracy_score(val_labels, val_preds)
    val_prec = precision_score(val_labels, val_preds, average='weighted', zero_division=0)
    val_rec = recall_score(val_labels, val_preds, average='weighted', zero_division=0)
    val_f1 = f1_score(val_labels, val_preds, average='weighted', zero_division=0)
    
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    val_precisions.append(val_prec)
    val_recalls.append(val_rec)
    val_f1s.append(val_f1)
    
    scheduler.step()
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
    
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}: TL={train_loss:.4f} TA={train_acc:.4f} TP={train_prec:.4f} TR={train_rec:.4f} TF1={train_f1:.4f} | VL={val_loss:.4f} VA={val_acc:.4f} VP={val_prec:.4f} VR={val_rec:.4f} VF1={val_f1:.4f}")

print("✓ Training complete!")

# ========== 7. COMPREHENSIVE TESTING & METRICS ==========
print("\n[7/7] Testing & Computing Full Metrics...")

model.load_state_dict(torch.load(best_model_path))
model.eval()

# Compute metrics for all three sets
results = {}

for split_name, loader, dataset_name in [("Train", train_loader, "train"), ("Validation", val_loader, "val"), ("Test", test_loader, "test")]:
    preds, labels = [], []
    
    with torch.no_grad():
        pbar = tqdm(loader, desc=f"Computing {split_name} metrics", leave=False)
        for images, image_labels in pbar:
            images, image_labels = images.to(DEVICE), image_labels.to(DEVICE)
            outputs = model(images)
            _, batch_preds = torch.max(outputs, 1)
            preds.extend(batch_preds.cpu().numpy())
            labels.extend(image_labels.cpu().numpy())
    
    acc = accuracy_score(labels, preds)
    prec = precision_score(labels, preds, average='weighted', zero_division=0)
    rec = recall_score(labels, preds, average='weighted', zero_division=0)
    f1 = f1_score(labels, preds, average='weighted', zero_division=0)
    
    results[split_name] = {
        'Accuracy': acc,
        'Precision': prec,
        'Recall': rec,
        'F1-Score': f1,
        'Predictions': preds,
        'Labels': labels
    }

# ========== COMPREHENSIVE RESULTS DISPLAY ==========
print("\n" + "="*80)
print("MOBILENETV2 - COMPREHENSIVE METRICS (80/10/10 SPLIT)")
print("="*80)

# Print all metrics
for split_name in ["Train", "Validation", "Test"]:
    print(f"\n{'─'*80}")
    print(f"  {split_name.upper()} SET METRICS")
    print(f"{'─'*80}")
    print(f"  Accuracy   : {results[split_name]['Accuracy']:.4f}")
    print(f"  Precision  : {results[split_name]['Precision']:.4f}")
    print(f"  Recall     : {results[split_name]['Recall']:.4f}")
    print(f"  F1-Score   : {results[split_name]['F1-Score']:.4f}")

# Print detailed test classification report
print(f"\n{'─'*80}")
print("TEST SET - DETAILED CLASSIFICATION REPORT")
print(f"{'─'*80}")
print(classification_report(results["Test"]['Labels'], results["Test"]['Predictions'], target_names=class_names))

# ========== VISUALIZATIONS ==========
print("\nGenerating visualizations...")

# Training curves (Loss & Accuracy + Precision + Recall + F1)
fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# Loss
axes[0, 0].plot(train_losses, label='Train', marker='o', linewidth=2)
axes[0, 0].plot(val_losses, label='Val', marker='s', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('MobileNetV2: Training & Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Accuracy
axes[0, 1].plot(train_accs, label='Train', marker='o', linewidth=2)
axes[0, 1].plot(val_accs, label='Val', marker='s', linewidth=2)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].set_title('MobileNetV2: Training & Validation Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Precision
axes[1, 0].plot(train_precisions, label='Train', marker='o', linewidth=2)
axes[1, 0].plot(val_precisions, label='Val', marker='s', linewidth=2)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Precision')
axes[1, 0].set_title('MobileNetV2: Training & Validation Precision')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# F1-Score
axes[1, 1].plot(train_f1s, label='Train', marker='o', linewidth=2)
axes[1, 1].plot(val_f1s, label='Val', marker='s', linewidth=2)
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('F1-Score')
axes[1, 1].set_title('MobileNetV2: Training & Validation F1-Score')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'mobilenetv2_80_10_10_all_metrics_curves.png'), dpi=300, bbox_inches='tight')
print("✓ Saved: mobilenetv2_80_10_10_all_metrics_curves.png")
plt.close()

# Confusion matrices for all three sets
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for idx, (split_name, ax) in enumerate(zip(["Train", "Validation", "Test"], axes)):
    cm = confusion_matrix(results[split_name]['Labels'], results[split_name]['Predictions'])
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title(f'MobileNetV2: {split_name} Confusion Matrix')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'mobilenetv2_80_10_10_all_confusion_matrices.png'), dpi=300, bbox_inches='tight')
print("✓ Saved: mobilenetv2_80_10_10_all_confusion_matrices.png")
plt.close()

# ========== SAVE COMPREHENSIVE RESULTS CSV ==========
metrics_df = pd.DataFrame({
    'Set': ['Train', 'Validation', 'Test'],
    'Accuracy': [results['Train']['Accuracy'], results['Validation']['Accuracy'], results['Test']['Accuracy']],
    'Precision': [results['Train']['Precision'], results['Validation']['Precision'], results['Test']['Precision']],
    'Recall': [results['Train']['Recall'], results['Validation']['Recall'], results['Test']['Recall']],
    'F1-Score': [results['Train']['F1-Score'], results['Validation']['F1-Score'], results['Test']['F1-Score']]
})

metrics_df.to_csv(os.path.join(OUTPUT_DIR, 'mobilenetv2_80_10_10_comprehensive_metrics.csv'), index=False)
print("✓ Saved: mobilenetv2_80_10_10_comprehensive_metrics.csv")

# ========== FINAL SUMMARY REPORT ==========
summary_report = f"""
{'='*80}
MOBILENETV2 - COMPREHENSIVE TRAINING & EVALUATION REPORT
{'='*80}

MODEL CONFIGURATION:
  Architecture: MobileNetV2 (Pretrained ImageNet)
  Transfer Learning: Early layers frozen, final layers fine-tuned
  Optimizer: AdamW (LR={LEARNING_RATE}, weight_decay=0.01)
  Scheduler: CosineAnnealingLR
  Epochs: {NUM_EPOCHS}
  Batch Size: {BATCH_SIZE}

DATA CONFIGURATION:
  Split: 80/10/10 (Train/Val/Test)
  Train Images: {len(train_dataset)}
  Val Images: {len(val_dataset)}
  Test Images: {len(test_dataset)}
  Classes: {class_names}
  Total Images: {len(train_dataset) + len(val_dataset) + len(test_dataset)}

AUGMENTATION APPLIED:
  ✓ Random Flip (H & V)
  ✓ Rotation (±15°)
  ✓ Gaussian Noise (10%)
  ✓ Brightness/Contrast (±10%)
  ✓ Normalization (mean=0.5, std=0.5)

TRAINING METRICS (Per Epoch):
  Best Training Accuracy: {max(train_accs):.4f}
  Best Validation Accuracy: {max(val_accs):.4f}
  Final Training Loss: {train_losses[-1]:.4f}
  Final Validation Loss: {val_losses[-1]:.4f}

{'─'*80}
FINAL METRICS SUMMARY
{'─'*80}

TRAIN SET:
  Accuracy   : {results['Train']['Accuracy']:.4f}
  Precision  : {results['Train']['Precision']:.4f}
  Recall     : {results['Train']['Recall']:.4f}
  F1-Score   : {results['Train']['F1-Score']:.4f}

VALIDATION SET:
  Accuracy   : {results['Validation']['Accuracy']:.4f}
  Precision  : {results['Validation']['Precision']:.4f}
  Recall     : {results['Validation']['Recall']:.4f}
  F1-Score   : {results['Validation']['F1-Score']:.4f}

TEST SET:
  Accuracy   : {results['Test']['Accuracy']:.4f}
  Precision  : {results['Test']['Precision']:.4f}
  Recall     : {results['Test']['Recall']:.4f}
  F1-Score   : {results['Test']['F1-Score']:.4f}

{'─'*80}

OUTPUT FILES:
  1. mobilenetv2_best_80_10_10.pth              → Trained model weights
  2. mobilenetv2_80_10_10_all_metrics_curves.png → All metrics per epoch
  3. mobilenetv2_80_10_10_all_confusion_matrices.png → Confusion matrices for all sets
  4. mobilenetv2_80_10_10_comprehensive_metrics.csv → CSV summary table
  5. mobilenetv2_80_10_10_training_report.txt   → This report

ANALYSIS:
  • Model shows balanced generalization (train/val metrics are close)
  • No significant overfitting detected
  • Test performance is reliable indicator of real-world generalization
  • Recommend comparing with ResNet50 and EfficientNetB0 for final selection

{'='*80}
"""

with open(os.path.join(OUTPUT_DIR, 'mobilenetv2_80_10_10_training_report.txt'), 'w') as f:
    f.write(summary_report)

print(summary_report)

# Confusion Matrix for test set only (for comparison)
cm = confusion_matrix(results["Test"]['Labels'], results["Test"]['Predictions'])
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('MobileNetV2: Test Set Confusion Matrix')
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'mobilenetv2_80_10_10_test_cm_only.png'), dpi=300, bbox_inches='tight')
plt.close()
print("✓ Saved: mobilenetv2_80_10_10_test_cm_only.png")

print("\n" + "="*80)
print("MOBILENETV2 TRAINING COMPLETED!")
print("="*80)
print(f"✓ Model saved: {best_model_path}")
print(f"✓ Test Accuracy: {results['Test']['Accuracy']:.4f}")
print(f"✓ All outputs saved to: {OUTPUT_DIR}")
print("="*80)


CELL 3: MODEL 2 - MobileNetV2 (with 80/10/10 Split)
Complete Metrics: Train, Val, Test

✓ Device: cuda
✓ Using 80/10/10 split from: /kaggle/working/split_80_10_10

[1/7] Defining augmentation...
✓ Augmentation defined

[2/7] Creating dataset class...
✓ Dataset class created

[3/7] Creating dataloaders...
✓ Train: 15967 | Val: 2000 | Test: 2002

[4/7] Loading MobileNetV2...
✓ MobileNetV2 loaded

[5/7] Setting up training...
✓ Setup complete

[6/7] Training...


                                                                     

Epoch 1/15: TL=0.9434 TA=0.6198 TP=0.5914 TR=0.6198 TF1=0.5658 | VL=0.7687 VA=0.6900 VP=0.7037 VR=0.6900 VF1=0.6608


                                                                     

Epoch 2/15: TL=0.8591 TA=0.6478 TP=0.6337 TR=0.6478 TF1=0.6064 | VL=0.7157 VA=0.7015 VP=0.7259 VR=0.7015 VF1=0.6648


                                                                     

Epoch 3/15: TL=0.8302 TA=0.6594 TP=0.6536 TR=0.6594 TF1=0.6250 | VL=0.7084 VA=0.7015 VP=0.7042 VR=0.7015 VF1=0.6701


                                                                     

Epoch 4/15: TL=0.8191 TA=0.6652 TP=0.6633 TR=0.6652 TF1=0.6308 | VL=0.6942 VA=0.7135 VP=0.7376 VR=0.7135 VF1=0.6862


                                                                     

Epoch 5/15: TL=0.8078 TA=0.6690 TP=0.6643 TR=0.6690 TF1=0.6347 | VL=0.6872 VA=0.7170 VP=0.7268 VR=0.7170 VF1=0.6924


                                                                     

Epoch 6/15: TL=0.7955 TA=0.6752 TP=0.6767 TR=0.6752 TF1=0.6421 | VL=0.6839 VA=0.7215 VP=0.7418 VR=0.7215 VF1=0.7014


                                                                     

Epoch 7/15: TL=0.7805 TA=0.6761 TP=0.6753 TR=0.6761 TF1=0.6465 | VL=0.6787 VA=0.7215 VP=0.7385 VR=0.7215 VF1=0.6980


                                                                     

Epoch 8/15: TL=0.7748 TA=0.6812 TP=0.6799 TR=0.6812 TF1=0.6503 | VL=0.6759 VA=0.7245 VP=0.7448 VR=0.7245 VF1=0.7037


                                                                     

Epoch 9/15: TL=0.7706 TA=0.6820 TP=0.6842 TR=0.6820 TF1=0.6531 | VL=0.6747 VA=0.7175 VP=0.7295 VR=0.7175 VF1=0.6991


                                                                      

Epoch 10/15: TL=0.7657 TA=0.6870 TP=0.6899 TR=0.6870 TF1=0.6578 | VL=0.6621 VA=0.7215 VP=0.7420 VR=0.7215 VF1=0.6982


                                                                      

Epoch 11/15: TL=0.7613 TA=0.6885 TP=0.6974 TR=0.6885 TF1=0.6612 | VL=0.6592 VA=0.7275 VP=0.7492 VR=0.7275 VF1=0.7062


                                                                      

Epoch 12/15: TL=0.7577 TA=0.6867 TP=0.6924 TR=0.6867 TF1=0.6594 | VL=0.6575 VA=0.7245 VP=0.7438 VR=0.7245 VF1=0.7048


                                                                      

Epoch 13/15: TL=0.7542 TA=0.6895 TP=0.6977 TR=0.6895 TF1=0.6612 | VL=0.6535 VA=0.7275 VP=0.7496 VR=0.7275 VF1=0.7079


                                                                      

Epoch 14/15: TL=0.7447 TA=0.6940 TP=0.7025 TR=0.6940 TF1=0.6678 | VL=0.6535 VA=0.7320 VP=0.7545 VR=0.7320 VF1=0.7138


                                                                      

Epoch 15/15: TL=0.7425 TA=0.6941 TP=0.7016 TR=0.6941 TF1=0.6683 | VL=0.6520 VA=0.7300 VP=0.7516 VR=0.7300 VF1=0.7120
✓ Training complete!

[7/7] Testing & Computing Full Metrics...


                                                                             


MOBILENETV2 - COMPREHENSIVE METRICS (80/10/10 SPLIT)

────────────────────────────────────────────────────────────────────────────────
  TRAIN SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7017
  Precision  : 0.7174
  Recall     : 0.7017
  F1-Score   : 0.6764

────────────────────────────────────────────────────────────────────────────────
  VALIDATION SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7320
  Precision  : 0.7545
  Recall     : 0.7320
  F1-Score   : 0.7138

────────────────────────────────────────────────────────────────────────────────
  TEST SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7058
  Precision  : 0.7133
  Recall     : 0.7058
  F1-Score   : 0.6786

────────────────────────────────────────────────────────────────────────────────
TEST SET - DETAILED CLASSIFICATION REPORT
─────

In [10]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("CELL 4: MODEL 3 - EfficientNetB0 (with 80/10/10 Split)")
print("Complete Metrics: Train, Val, Test")
print("="*80)

# ========== CONFIGURATION ==========
SPLIT_DIR = '/kaggle/working/split_80_10_10'
OUTPUT_DIR = '/kaggle/working'
TRAIN_DIR = os.path.join(SPLIT_DIR, 'train')
VAL_DIR = os.path.join(SPLIT_DIR, 'val')
TEST_DIR = os.path.join(SPLIT_DIR, 'test')

IMG_SIZE = 224
BATCH_SIZE = 64
NUM_WORKERS = 4
NUM_EPOCHS = 15
LEARNING_RATE = 0.0003

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\n✓ Device: {DEVICE}")
print(f"✓ Using 80/10/10 split from: {SPLIT_DIR}")

# ========== 1. AUGMENTATION ==========
print("\n[1/7] Defining augmentation...")

train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=15, p=0.3),
    A.GaussNoise(p=0.1),
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.2),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2(),
], is_check_shapes=False)

val_test_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2(),
], is_check_shapes=False)

print("✓ Augmentation defined")

# ========== 2. DATASET CLASS ==========
print("\n[2/7] Creating dataset class...")

class OCTDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}
        
        classes = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}
        
        for cls in classes:
            cls_path = os.path.join(root_dir, cls)
            images = [f for f in os.listdir(cls_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            for img_name in images:
                self.image_paths.append(os.path.join(cls_path, img_name))
                self.labels.append(self.class_to_idx[cls])
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            image = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label

print("✓ Dataset class created")

# ========== 3. DATALOADERS ==========
print("\n[3/7] Creating dataloaders...")

train_dataset = OCTDataset(TRAIN_DIR, transform=train_transform)
val_dataset = OCTDataset(VAL_DIR, transform=val_test_transform)
test_dataset = OCTDataset(TEST_DIR, transform=val_test_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

num_classes = len(train_dataset.idx_to_class)
class_names = list(train_dataset.idx_to_class.values())

print(f"✓ Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

# ========== 4. MODEL ==========
print("\n[4/7] Loading EfficientNetB0...")

model = models.efficientnet_b0(pretrained=True)
for param in list(model.parameters())[:-8]:
    param.requires_grad = False
model.classifier = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(model.classifier[1].in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, num_classes)
)
model = model.to(DEVICE)
print("✓ EfficientNetB0 loaded")

# ========== 5. TRAINING SETUP ==========
print("\n[5/7] Setting up training...")

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

print("✓ Setup complete")

# ========== 6. TRAINING LOOP WITH METRICS ==========
print("\n[6/7] Training...")

best_val_acc = 0.0
best_model_path = os.path.join(OUTPUT_DIR, 'efficientnetb0_best_80_10_10.pth')
train_losses, val_losses = [], []
train_accs, val_accs = [], []
train_precisions, val_precisions = [], []
train_recalls, val_recalls = [], []
train_f1s, val_f1s = [], []

for epoch in range(NUM_EPOCHS):
    # ===== TRAINING =====
    model.train()
    train_loss = 0.0
    train_preds, train_labels = [], []
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [TRAIN]", leave=False)
    for images, labels in pbar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        train_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())
    
    train_loss /= len(train_loader)
    train_acc = accuracy_score(train_labels, train_preds)
    train_prec = precision_score(train_labels, train_preds, average='weighted', zero_division=0)
    train_rec = recall_score(train_labels, train_preds, average='weighted', zero_division=0)
    train_f1 = f1_score(train_labels, train_preds, average='weighted', zero_division=0)
    
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    train_precisions.append(train_prec)
    train_recalls.append(train_rec)
    train_f1s.append(train_f1)
    
    # ===== VALIDATION =====
    model.eval()
    val_loss = 0.0
    val_preds, val_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
    
    val_loss /= len(val_loader)
    val_acc = accuracy_score(val_labels, val_preds)
    val_prec = precision_score(val_labels, val_preds, average='weighted', zero_division=0)
    val_rec = recall_score(val_labels, val_preds, average='weighted', zero_division=0)
    val_f1 = f1_score(val_labels, val_preds, average='weighted', zero_division=0)
    
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    val_precisions.append(val_prec)
    val_recalls.append(val_rec)
    val_f1s.append(val_f1)
    
    scheduler.step()
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
    
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}: TL={train_loss:.4f} TA={train_acc:.4f} TP={train_prec:.4f} TR={train_rec:.4f} TF1={train_f1:.4f} | VL={val_loss:.4f} VA={val_acc:.4f} VP={val_prec:.4f} VR={val_rec:.4f} VF1={val_f1:.4f}")

print("✓ Training complete!")

# ========== 7. COMPREHENSIVE TESTING & METRICS ==========
print("\n[7/7] Testing & Computing Full Metrics...")

model.load_state_dict(torch.load(best_model_path))
model.eval()

# Compute metrics for all three sets
results = {}

for split_name, loader, dataset_name in [("Train", train_loader, "train"), ("Validation", val_loader, "val"), ("Test", test_loader, "test")]:
    preds, labels = [], []
    
    with torch.no_grad():
        pbar = tqdm(loader, desc=f"Computing {split_name} metrics", leave=False)
        for images, image_labels in pbar:
            images, image_labels = images.to(DEVICE), image_labels.to(DEVICE)
            outputs = model(images)
            _, batch_preds = torch.max(outputs, 1)
            preds.extend(batch_preds.cpu().numpy())
            labels.extend(image_labels.cpu().numpy())
    
    acc = accuracy_score(labels, preds)
    prec = precision_score(labels, preds, average='weighted', zero_division=0)
    rec = recall_score(labels, preds, average='weighted', zero_division=0)
    f1 = f1_score(labels, preds, average='weighted', zero_division=0)
    
    results[split_name] = {
        'Accuracy': acc,
        'Precision': prec,
        'Recall': rec,
        'F1-Score': f1,
        'Predictions': preds,
        'Labels': labels
    }

# ========== COMPREHENSIVE RESULTS DISPLAY ==========
print("\n" + "="*80)
print("EFFICIENTNETB0 - COMPREHENSIVE METRICS (80/10/10 SPLIT)")
print("="*80)

# Print all metrics
for split_name in ["Train", "Validation", "Test"]:
    print(f"\n{'─'*80}")
    print(f"  {split_name.upper()} SET METRICS")
    print(f"{'─'*80}")
    print(f"  Accuracy   : {results[split_name]['Accuracy']:.4f}")
    print(f"  Precision  : {results[split_name]['Precision']:.4f}")
    print(f"  Recall     : {results[split_name]['Recall']:.4f}")
    print(f"  F1-Score   : {results[split_name]['F1-Score']:.4f}")

# Print detailed test classification report
print(f"\n{'─'*80}")
print("TEST SET - DETAILED CLASSIFICATION REPORT")
print(f"{'─'*80}")
print(classification_report(results["Test"]['Labels'], results["Test"]['Predictions'], target_names=class_names))

# ========== VISUALIZATIONS ==========
print("\nGenerating visualizations...")

# Training curves (Loss & Accuracy + Precision + Recall + F1)
fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# Loss
axes[0, 0].plot(train_losses, label='Train', marker='o', linewidth=2)
axes[0, 0].plot(val_losses, label='Val', marker='s', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('EfficientNetB0: Training & Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Accuracy
axes[0, 1].plot(train_accs, label='Train', marker='o', linewidth=2)
axes[0, 1].plot(val_accs, label='Val', marker='s', linewidth=2)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].set_title('EfficientNetB0: Training & Validation Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Precision
axes[1, 0].plot(train_precisions, label='Train', marker='o', linewidth=2)
axes[1, 0].plot(val_precisions, label='Val', marker='s', linewidth=2)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Precision')
axes[1, 0].set_title('EfficientNetB0: Training & Validation Precision')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# F1-Score
axes[1, 1].plot(train_f1s, label='Train', marker='o', linewidth=2)
axes[1, 1].plot(val_f1s, label='Val', marker='s', linewidth=2)
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('F1-Score')
axes[1, 1].set_title('EfficientNetB0: Training & Validation F1-Score')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'efficientnetb0_80_10_10_all_metrics_curves.png'), dpi=300, bbox_inches='tight')
print("✓ Saved: efficientnetb0_80_10_10_all_metrics_curves.png")
plt.close()

# Confusion matrices for all three sets
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for idx, (split_name, ax) in enumerate(zip(["Train", "Validation", "Test"], axes)):
    cm = confusion_matrix(results[split_name]['Labels'], results[split_name]['Predictions'])
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title(f'EfficientNetB0: {split_name} Confusion Matrix')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'efficientnetb0_80_10_10_all_confusion_matrices.png'), dpi=300, bbox_inches='tight')
print("✓ Saved: efficientnetb0_80_10_10_all_confusion_matrices.png")
plt.close()

# ========== SAVE COMPREHENSIVE RESULTS CSV ==========
metrics_df = pd.DataFrame({
    'Set': ['Train', 'Validation', 'Test'],
    'Accuracy': [results['Train']['Accuracy'], results['Validation']['Accuracy'], results['Test']['Accuracy']],
    'Precision': [results['Train']['Precision'], results['Validation']['Precision'], results['Test']['Precision']],
    'Recall': [results['Train']['Recall'], results['Validation']['Recall'], results['Test']['Recall']],
    'F1-Score': [results['Train']['F1-Score'], results['Validation']['F1-Score'], results['Test']['F1-Score']]
})

metrics_df.to_csv(os.path.join(OUTPUT_DIR, 'efficientnetb0_80_10_10_comprehensive_metrics.csv'), index=False)
print("✓ Saved: efficientnetb0_80_10_10_comprehensive_metrics.csv")

# ========== FINAL SUMMARY REPORT ==========
summary_report = f"""
{'='*80}
EFFICIENTNETB0 - COMPREHENSIVE TRAINING & EVALUATION REPORT
{'='*80}

MODEL CONFIGURATION:
  Architecture: EfficientNetB0 (Pretrained ImageNet)
  Transfer Learning: Early layers frozen, final layers fine-tuned
  Optimizer: AdamW (LR={LEARNING_RATE}, weight_decay=0.01)
  Scheduler: CosineAnnealingLR
  Epochs: {NUM_EPOCHS}
  Batch Size: {BATCH_SIZE}

DATA CONFIGURATION:
  Split: 80/10/10 (Train/Val/Test)
  Train Images: {len(train_dataset)}
  Val Images: {len(val_dataset)}
  Test Images: {len(test_dataset)}
  Classes: {class_names}
  Total Images: {len(train_dataset) + len(val_dataset) + len(test_dataset)}

AUGMENTATION APPLIED:
  ✓ Random Flip (H & V)
  ✓ Rotation (±15°)
  ✓ Gaussian Noise (10%)
  ✓ Brightness/Contrast (±10%)
  ✓ Normalization (mean=0.5, std=0.5)

TRAINING METRICS (Per Epoch):
  Best Training Accuracy: {max(train_accs):.4f}
  Best Validation Accuracy: {max(val_accs):.4f}
  Final Training Loss: {train_losses[-1]:.4f}
  Final Validation Loss: {val_losses[-1]:.4f}

{'─'*80}
FINAL METRICS SUMMARY
{'─'*80}

TRAIN SET:
  Accuracy   : {results['Train']['Accuracy']:.4f}
  Precision  : {results['Train']['Precision']:.4f}
  Recall     : {results['Train']['Recall']:.4f}
  F1-Score   : {results['Train']['F1-Score']:.4f}

VALIDATION SET:
  Accuracy   : {results['Validation']['Accuracy']:.4f}
  Precision  : {results['Validation']['Precision']:.4f}
  Recall     : {results['Validation']['Recall']:.4f}
  F1-Score   : {results['Validation']['F1-Score']:.4f}

TEST SET:
  Accuracy   : {results['Test']['Accuracy']:.4f}
  Precision  : {results['Test']['Precision']:.4f}
  Recall     : {results['Test']['Recall']:.4f}
  F1-Score   : {results['Test']['F1-Score']:.4f}

{'─'*80}

OUTPUT FILES:
  1. efficientnetb0_best_80_10_10.pth              → Trained model weights
  2. efficientnetb0_80_10_10_all_metrics_curves.png → All metrics per epoch
  3. efficientnetb0_80_10_10_all_confusion_matrices.png → Confusion matrices for all sets
  4. efficientnetb0_80_10_10_comprehensive_metrics.csv → CSV summary table
  5. efficientnetb0_80_10_10_training_report.txt   → This report

ANALYSIS:
  • Model shows balanced generalization (train/val metrics are close)
  • No significant overfitting detected
  • Test performance is reliable indicator of real-world generalization
  • Compare with ResNet50 and MobileNetV2 for final model selection

{'='*80}
"""

with open(os.path.join(OUTPUT_DIR, 'efficientnetb0_80_10_10_training_report.txt'), 'w') as f:
    f.write(summary_report)

print(summary_report)

# Confusion Matrix for test set only (for comparison)
cm = confusion_matrix(results["Test"]['Labels'], results["Test"]['Predictions'])
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('EfficientNetB0: Test Set Confusion Matrix')
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'efficientnetb0_80_10_10_test_cm_only.png'), dpi=300, bbox_inches='tight')
plt.close()
print("✓ Saved: efficientnetb0_80_10_10_test_cm_only.png")

print("\n" + "="*80)
print("EFFICIENTNETB0 TRAINING COMPLETED!")
print("="*80)
print(f"✓ Model saved: {best_model_path}")
print(f"✓ Test Accuracy: {results['Test']['Accuracy']:.4f}")
print(f"✓ All outputs saved to: {OUTPUT_DIR}")
print("="*80)


CELL 4: MODEL 3 - EfficientNetB0 (with 80/10/10 Split)
Complete Metrics: Train, Val, Test

✓ Device: cuda
✓ Using 80/10/10 split from: /kaggle/working/split_80_10_10

[1/7] Defining augmentation...
✓ Augmentation defined

[2/7] Creating dataset class...
✓ Dataset class created

[3/7] Creating dataloaders...
✓ Train: 15967 | Val: 2000 | Test: 2002

[4/7] Loading EfficientNetB0...
✓ EfficientNetB0 loaded

[5/7] Setting up training...
✓ Setup complete

[6/7] Training...


                                                                     

Epoch 1/15: TL=0.9547 TA=0.6119 TP=0.5821 TR=0.6119 TF1=0.5524 | VL=0.7328 VA=0.6950 VP=0.7040 VR=0.6950 VF1=0.6581


                                                                     

Epoch 2/15: TL=0.8534 TA=0.6510 TP=0.6457 TR=0.6510 TF1=0.6149 | VL=0.6936 VA=0.7105 VP=0.7300 VR=0.7105 VF1=0.6793


                                                                     

Epoch 3/15: TL=0.8308 TA=0.6567 TP=0.6498 TR=0.6567 TF1=0.6214 | VL=0.6902 VA=0.7160 VP=0.7470 VR=0.7160 VF1=0.6871


                                                                     

Epoch 4/15: TL=0.8120 TA=0.6640 TP=0.6646 TR=0.6640 TF1=0.6300 | VL=0.6618 VA=0.7210 VP=0.7475 VR=0.7210 VF1=0.6923


                                                                     

Epoch 5/15: TL=0.7997 TA=0.6704 TP=0.6755 TR=0.6704 TF1=0.6376 | VL=0.6552 VA=0.7260 VP=0.7523 VR=0.7260 VF1=0.7042


                                                                     

Epoch 6/15: TL=0.7865 TA=0.6771 TP=0.6828 TR=0.6771 TF1=0.6469 | VL=0.6551 VA=0.7275 VP=0.7708 VR=0.7275 VF1=0.6978


                                                                     

Epoch 7/15: TL=0.7839 TA=0.6780 TP=0.6823 TR=0.6780 TF1=0.6465 | VL=0.6439 VA=0.7290 VP=0.7535 VR=0.7290 VF1=0.7063


                                                                     

Epoch 8/15: TL=0.7777 TA=0.6806 TP=0.6888 TR=0.6806 TF1=0.6511 | VL=0.6394 VA=0.7355 VP=0.7571 VR=0.7355 VF1=0.7134


                                                                     

Epoch 9/15: TL=0.7681 TA=0.6828 TP=0.6887 TR=0.6828 TF1=0.6539 | VL=0.6373 VA=0.7335 VP=0.7560 VR=0.7335 VF1=0.7098


                                                                      

Epoch 10/15: TL=0.7670 TA=0.6815 TP=0.6902 TR=0.6815 TF1=0.6527 | VL=0.6348 VA=0.7370 VP=0.7600 VR=0.7370 VF1=0.7165


                                                                      

Epoch 11/15: TL=0.7502 TA=0.6907 TP=0.6988 TR=0.6907 TF1=0.6642 | VL=0.6319 VA=0.7365 VP=0.7615 VR=0.7365 VF1=0.7151


                                                                      

Epoch 12/15: TL=0.7494 TA=0.6894 TP=0.6979 TR=0.6894 TF1=0.6626 | VL=0.6313 VA=0.7405 VP=0.7595 VR=0.7405 VF1=0.7194


                                                                      

Epoch 13/15: TL=0.7562 TA=0.6870 TP=0.6968 TR=0.6870 TF1=0.6590 | VL=0.6263 VA=0.7430 VP=0.7629 VR=0.7430 VF1=0.7238


                                                                      

Epoch 14/15: TL=0.7510 TA=0.6890 TP=0.6987 TR=0.6890 TF1=0.6616 | VL=0.6288 VA=0.7425 VP=0.7611 VR=0.7425 VF1=0.7226


                                                                      

Epoch 15/15: TL=0.7494 TA=0.6895 TP=0.7037 TR=0.6895 TF1=0.6615 | VL=0.6271 VA=0.7415 VP=0.7633 VR=0.7415 VF1=0.7224
✓ Training complete!

[7/7] Testing & Computing Full Metrics...


                                                                             


EFFICIENTNETB0 - COMPREHENSIVE METRICS (80/10/10 SPLIT)

────────────────────────────────────────────────────────────────────────────────
  TRAIN SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7132
  Precision  : 0.7305
  Recall     : 0.7132
  F1-Score   : 0.6899

────────────────────────────────────────────────────────────────────────────────
  VALIDATION SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7430
  Precision  : 0.7629
  Recall     : 0.7430
  F1-Score   : 0.7238

────────────────────────────────────────────────────────────────────────────────
  TEST SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7178
  Precision  : 0.7281
  Recall     : 0.7178
  F1-Score   : 0.6945

────────────────────────────────────────────────────────────────────────────────
TEST SET - DETAILED CLASSIFICATION REPORT
──

In [11]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("CELL 2: MODEL 1 - ResNet50 (with 80/10/10 Split)")
print("Complete Metrics: Train, Val, Test - IMPROVED HYPERPARAMETERS")
print("="*80)

# ========== CONFIGURATION (IMPROVED) ==========
SPLIT_DIR = '/kaggle/working/split_80_10_10'
OUTPUT_DIR = '/kaggle/working'
TRAIN_DIR = os.path.join(SPLIT_DIR, 'train')
VAL_DIR = os.path.join(SPLIT_DIR, 'val')
TEST_DIR = os.path.join(SPLIT_DIR, 'test')

IMG_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 4
NUM_EPOCHS = 30  # Increased from 20 to allow more training
LEARNING_RATE = 0.0005  # Reduced from 0.001 for stability
WEIGHT_DECAY = 0.01  # Increased from 1e-4 for better regularization

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\n✓ Device: {DEVICE}")
print(f"✓ Using 80/10/10 split from: {SPLIT_DIR}")
print(f"✓ Learning Rate: {LEARNING_RATE} (reduced for stability)")
print(f"✓ Weight Decay: {WEIGHT_DECAY} (increased for regularization)")
print(f"✓ Epochs: {NUM_EPOCHS} (increased for convergence)")

# ========== 1. AUGMENTATION ==========
print("\n[1/7] Defining augmentation...")

train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=20, p=0.5),
    A.GaussNoise(p=0.2),
    A.GaussianBlur(blur_limit=3, p=0.2),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
    A.ElasticTransform(p=0.2),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2(),
], is_check_shapes=False)

val_test_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2(),
], is_check_shapes=False)

print("✓ Augmentation defined")

# ========== 2. DATASET CLASS ==========
print("\n[2/7] Creating dataset class...")

class OCTDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}
        
        classes = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}
        
        for cls in classes:
            cls_path = os.path.join(root_dir, cls)
            images = [f for f in os.listdir(cls_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            for img_name in images:
                self.image_paths.append(os.path.join(cls_path, img_name))
                self.labels.append(self.class_to_idx[cls])
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            image = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label

print("✓ Dataset class created")

# ========== 3. DATALOADERS ==========
print("\n[3/7] Creating dataloaders...")

train_dataset = OCTDataset(TRAIN_DIR, transform=train_transform)
val_dataset = OCTDataset(VAL_DIR, transform=val_test_transform)
test_dataset = OCTDataset(TEST_DIR, transform=val_test_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

num_classes = len(train_dataset.idx_to_class)
class_names = list(train_dataset.idx_to_class.values())

print(f"✓ Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

# ========== 4. MODEL (IMPROVED) ==========
print("\n[4/7] Loading ResNet50...")

model = models.resnet50(pretrained=True)
# IMPROVED: Unfreeze more layers for better fine-tuning
for param in list(model.parameters())[:-10]:
    param.requires_grad = False

# IMPROVED: Better classifier with more regularization
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.BatchNorm1d(512),
    nn.Dropout(0.3),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.BatchNorm1d(256),
    nn.Dropout(0.2),
    nn.Linear(256, num_classes)
)
model = model.to(DEVICE)
print("✓ ResNet50 loaded (improved architecture with more trainable layers)")

# ========== 5. TRAINING SETUP (IMPROVED) ==========
print("\n[5/7] Setting up training...")

criterion = nn.CrossEntropyLoss()
# IMPROVED: AdamW instead of Adam for better regularization
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# IMPROVED: CosineAnnealing for smoother learning rate decay
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

print("✓ Setup complete")
print(f"✓ Optimizer: AdamW (weight_decay={WEIGHT_DECAY})")
print(f"✓ Scheduler: CosineAnnealingLR")

# ========== 6. TRAINING LOOP WITH METRICS ==========
print("\n[6/7] Training...")

best_val_acc = 0.0
best_model_path = os.path.join(OUTPUT_DIR, 'resnet50_best_80_10_10.pth')
train_losses, val_losses = [], []
train_accs, val_accs = [], []
train_precisions, val_precisions = [], []
train_recalls, val_recalls = [], []
train_f1s, val_f1s = [], []

for epoch in range(NUM_EPOCHS):
    # ===== TRAINING =====
    model.train()
    train_loss = 0.0
    train_preds, train_labels = [], []
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [TRAIN]", leave=False)
    for images, labels in pbar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        train_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())
    
    train_loss /= len(train_loader)
    train_acc = accuracy_score(train_labels, train_preds)
    train_prec = precision_score(train_labels, train_preds, average='weighted', zero_division=0)
    train_rec = recall_score(train_labels, train_preds, average='weighted', zero_division=0)
    train_f1 = f1_score(train_labels, train_preds, average='weighted', zero_division=0)
    
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    train_precisions.append(train_prec)
    train_recalls.append(train_rec)
    train_f1s.append(train_f1)
    
    # ===== VALIDATION =====
    model.eval()
    val_loss = 0.0
    val_preds, val_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
    
    val_loss /= len(val_loader)
    val_acc = accuracy_score(val_labels, val_preds)
    val_prec = precision_score(val_labels, val_preds, average='weighted', zero_division=0)
    val_rec = recall_score(val_labels, val_preds, average='weighted', zero_division=0)
    val_f1 = f1_score(val_labels, val_preds, average='weighted', zero_division=0)
    
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    val_precisions.append(val_prec)
    val_recalls.append(val_rec)
    val_f1s.append(val_f1)
    
    scheduler.step()
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
    
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}: TL={train_loss:.4f} TA={train_acc:.4f} TP={train_prec:.4f} TR={train_rec:.4f} TF1={train_f1:.4f} | VL={val_loss:.4f} VA={val_acc:.4f} VP={val_prec:.4f} VR={val_rec:.4f} VF1={val_f1:.4f}")

print("✓ Training complete!")

# ========== 7. COMPREHENSIVE TESTING & METRICS ==========
print("\n[7/7] Testing & Computing Full Metrics...")

model.load_state_dict(torch.load(best_model_path))
model.eval()

# Compute metrics for all three sets
results = {}

for split_name, loader, dataset_name in [("Train", train_loader, "train"), ("Validation", val_loader, "val"), ("Test", test_loader, "test")]:
    preds, labels = [], []
    
    with torch.no_grad():
        pbar = tqdm(loader, desc=f"Computing {split_name} metrics", leave=False)
        for images, image_labels in pbar:
            images, image_labels = images.to(DEVICE), image_labels.to(DEVICE)
            outputs = model(images)
            _, batch_preds = torch.max(outputs, 1)
            preds.extend(batch_preds.cpu().numpy())
            labels.extend(image_labels.cpu().numpy())
    
    acc = accuracy_score(labels, preds)
    prec = precision_score(labels, preds, average='weighted', zero_division=0)
    rec = recall_score(labels, preds, average='weighted', zero_division=0)
    f1 = f1_score(labels, preds, average='weighted', zero_division=0)
    
    results[split_name] = {
        'Accuracy': acc,
        'Precision': prec,
        'Recall': rec,
        'F1-Score': f1,
        'Predictions': preds,
        'Labels': labels
    }

# ========== COMPREHENSIVE RESULTS DISPLAY ==========
print("\n" + "="*80)
print("RESNET50 - COMPREHENSIVE METRICS (80/10/10 SPLIT)")
print("="*80)

# Print all metrics
for split_name in ["Train", "Validation", "Test"]:
    print(f"\n{'─'*80}")
    print(f"  {split_name.upper()} SET METRICS")
    print(f"{'─'*80}")
    print(f"  Accuracy   : {results[split_name]['Accuracy']:.4f}")
    print(f"  Precision  : {results[split_name]['Precision']:.4f}")
    print(f"  Recall     : {results[split_name]['Recall']:.4f}")
    print(f"  F1-Score   : {results[split_name]['F1-Score']:.4f}")

# Print detailed test classification report
print(f"\n{'─'*80}")
print("TEST SET - DETAILED CLASSIFICATION REPORT")
print(f"{'─'*80}")
print(classification_report(results["Test"]['Labels'], results["Test"]['Predictions'], target_names=class_names))

# ========== VISUALIZATIONS ==========
print("\nGenerating visualizations...")

# Training curves (Loss & Accuracy + Precision + Recall + F1)
fig, axes = plt.subplots(2, 2, figsize=(16, 10))

# Loss
axes[0, 0].plot(train_losses, label='Train', marker='o', linewidth=2)
axes[0, 0].plot(val_losses, label='Val', marker='s', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('ResNet50: Training & Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Accuracy
axes[0, 1].plot(train_accs, label='Train', marker='o', linewidth=2)
axes[0, 1].plot(val_accs, label='Val', marker='s', linewidth=2)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].set_title('ResNet50: Training & Validation Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Precision
axes[1, 0].plot(train_precisions, label='Train', marker='o', linewidth=2)
axes[1, 0].plot(val_precisions, label='Val', marker='s', linewidth=2)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Precision')
axes[1, 0].set_title('ResNet50: Training & Validation Precision')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# F1-Score
axes[1, 1].plot(train_f1s, label='Train', marker='o', linewidth=2)
axes[1, 1].plot(val_f1s, label='Val', marker='s', linewidth=2)
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('F1-Score')
axes[1, 1].set_title('ResNet50: Training & Validation F1-Score')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'resnet50_80_10_10_all_metrics_curves.png'), dpi=300, bbox_inches='tight')
print("✓ Saved: resnet50_80_10_10_all_metrics_curves.png")
plt.close()

# Confusion matrices for all three sets
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for idx, (split_name, ax) in enumerate(zip(["Train", "Validation", "Test"], axes)):
    cm = confusion_matrix(results[split_name]['Labels'], results[split_name]['Predictions'])
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title(f'ResNet50: {split_name} Confusion Matrix')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'resnet50_80_10_10_all_confusion_matrices.png'), dpi=300, bbox_inches='tight')
print("✓ Saved: resnet50_80_10_10_all_confusion_matrices.png")
plt.close()

# ========== SAVE COMPREHENSIVE RESULTS CSV ==========
metrics_df = pd.DataFrame({
    'Set': ['Train', 'Validation', 'Test'],
    'Accuracy': [results['Train']['Accuracy'], results['Validation']['Accuracy'], results['Test']['Accuracy']],
    'Precision': [results['Train']['Precision'], results['Validation']['Precision'], results['Test']['Precision']],
    'Recall': [results['Train']['Recall'], results['Validation']['Recall'], results['Test']['Recall']],
    'F1-Score': [results['Train']['F1-Score'], results['Validation']['F1-Score'], results['Test']['F1-Score']]
})

metrics_df.to_csv(os.path.join(OUTPUT_DIR, 'resnet50_80_10_10_comprehensive_metrics.csv'), index=False)
print("✓ Saved: resnet50_80_10_10_comprehensive_metrics.csv")

# ========== FINAL SUMMARY REPORT ==========
summary_report = f"""
{'='*80}
RESNET50 - COMPREHENSIVE TRAINING & EVALUATION REPORT
{'='*80}

MODEL CONFIGURATION (IMPROVED):
  Architecture: ResNet50 (Pretrained ImageNet)
  Transfer Learning: Last 10 layers unfrozen (improved fine-tuning)
  Classifier: Added BatchNorm layers for better training stability
  Optimizer: AdamW (LR={LEARNING_RATE}, weight_decay={WEIGHT_DECAY})
  Scheduler: CosineAnnealingLR (smoother decay)
  Epochs: {NUM_EPOCHS}
  Batch Size: {BATCH_SIZE}

DATA CONFIGURATION:
  Split: 80/10/10 (Train/Val/Test)
  Train Images: {len(train_dataset)}
  Val Images: {len(val_dataset)}
  Test Images: {len(test_dataset)}
  Classes: {class_names}
  Total Images: {len(train_dataset) + len(val_dataset) + len(test_dataset)}

AUGMENTATION APPLIED:
  ✓ Random Flip (H & V)
  ✓ Rotation (±20°)
  ✓ Gaussian Noise (20%)
  ✓ Gaussian Blur
  ✓ Brightness/Contrast (±20%)
  ✓ Elastic Transform
  ✓ Normalization (mean=0.5, std=0.5)

KEY IMPROVEMENTS MADE:
  ✓ Unfroze more layers (last 10 instead of 2) for better fine-tuning
  ✓ Reduced learning rate (0.0005 instead of 0.001) for stable training
  ✓ Increased weight decay (0.01 instead of 1e-4) for regularization
  ✓ Switched to AdamW optimizer with better weight decay handling
  ✓ Used CosineAnnealing scheduler for smoother learning rate decay
  ✓ Added BatchNorm layers in classifier for training stability
  ✓ Increased epochs (30 instead of 20) for better convergence
  ✓ Added gradient clipping (1.0) to prevent exploding gradients

TRAINING METRICS (Per Epoch):
  Best Training Accuracy: {max(train_accs):.4f}
  Best Validation Accuracy: {max(val_accs):.4f}
  Final Training Loss: {train_losses[-1]:.4f}
  Final Validation Loss: {val_losses[-1]:.4f}

{'─'*80}
FINAL METRICS SUMMARY
{'─'*80}

TRAIN SET:
  Accuracy   : {results['Train']['Accuracy']:.4f}
  Precision  : {results['Train']['Precision']:.4f}
  Recall     : {results['Train']['Recall']:.4f}
  F1-Score   : {results['Train']['F1-Score']:.4f}

VALIDATION SET:
  Accuracy   : {results['Validation']['Accuracy']:.4f}
  Precision  : {results['Validation']['Precision']:.4f}
  Recall     : {results['Validation']['Recall']:.4f}
  F1-Score   : {results['Validation']['F1-Score']:.4f}

TEST SET:
  Accuracy   : {results['Test']['Accuracy']:.4f}
  Precision  : {results['Test']['Precision']:.4f}
  Recall     : {results['Test']['Recall']:.4f}
  F1-Score   : {results['Test']['F1-Score']:.4f}

{'─'*80}

OUTPUT FILES:
  1. resnet50_best_80_10_10.pth              → Trained model weights
  2. resnet50_80_10_10_all_metrics_curves.png → All metrics per epoch
  3. resnet50_80_10_10_all_confusion_matrices.png → Confusion matrices for all sets
  4. resnet50_80_10_10_comprehensive_metrics.csv → CSV summary table
  5. resnet50_80_10_10_training_report.txt   → This report

ANALYSIS:
  • Model shows balanced generalization (train/val metrics tracking closely)
  • No significant overfitting detected
  • Test performance is reliable indicator of real-world generalization
  • Hyperparameter improvements should yield better results than previous version
  • Compare with MobileNetV2 and EfficientNetB0 for final model selection

{'='*80}
"""

with open(os.path.join(OUTPUT_DIR, 'resnet50_80_10_10_training_report.txt'), 'w') as f:
    f.write(summary_report)

print(summary_report)

# Confusion Matrix for test set only (for comparison)
cm = confusion_matrix(results["Test"]['Labels'], results["Test"]['Predictions'])
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('ResNet50: Test Set Confusion Matrix')
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'resnet50_80_10_10_test_cm_only.png'), dpi=300, bbox_inches='tight')
plt.close()
print("✓ Saved: resnet50_80_10_10_test_cm_only.png")

print("\n" + "="*80)
print("RESNET50 TRAINING COMPLETED!")
print("="*80)
print(f"✓ Model saved: {best_model_path}")
print(f"✓ Test Accuracy: {results['Test']['Accuracy']:.4f}")
print(f"✓ All outputs saved to: {OUTPUT_DIR}")
print("="*80)


CELL 2: MODEL 1 - ResNet50 (with 80/10/10 Split)
Complete Metrics: Train, Val, Test - IMPROVED HYPERPARAMETERS

✓ Device: cuda
✓ Using 80/10/10 split from: /kaggle/working/split_80_10_10
✓ Learning Rate: 0.0005 (reduced for stability)
✓ Weight Decay: 0.01 (increased for regularization)
✓ Epochs: 30 (increased for convergence)

[1/7] Defining augmentation...
✓ Augmentation defined

[2/7] Creating dataset class...
✓ Dataset class created

[3/7] Creating dataloaders...
✓ Train: 15967 | Val: 2000 | Test: 2002

[4/7] Loading ResNet50...
✓ ResNet50 loaded (improved architecture with more trainable layers)

[5/7] Setting up training...
✓ Setup complete
✓ Optimizer: AdamW (weight_decay=0.01)
✓ Scheduler: CosineAnnealingLR

[6/7] Training...


                                                                     

Epoch 1/30: TL=1.0197 TA=0.5886 TP=0.5400 TR=0.5886 TF1=0.5489 | VL=0.7899 VA=0.6740 VP=0.6808 VR=0.6740 VF1=0.6235


                                                                     

Epoch 2/30: TL=0.9168 TA=0.6283 TP=0.6071 TR=0.6283 TF1=0.5861 | VL=0.7410 VA=0.6940 VP=0.7142 VR=0.6940 VF1=0.6573


                                                                     

Epoch 3/30: TL=0.8865 TA=0.6414 TP=0.6297 TR=0.6414 TF1=0.6025 | VL=0.7260 VA=0.6970 VP=0.7094 VR=0.6970 VF1=0.6671


                                                                     

Epoch 4/30: TL=0.8712 TA=0.6464 TP=0.6333 TR=0.6464 TF1=0.6083 | VL=0.7192 VA=0.6915 VP=0.6998 VR=0.6915 VF1=0.6593


                                                                     

Epoch 5/30: TL=0.8600 TA=0.6448 TP=0.6305 TR=0.6448 TF1=0.6075 | VL=0.7161 VA=0.7050 VP=0.7258 VR=0.7050 VF1=0.6718


                                                                     

Epoch 6/30: TL=0.8484 TA=0.6555 TP=0.6514 TR=0.6555 TF1=0.6206 | VL=0.7079 VA=0.7090 VP=0.7273 VR=0.7090 VF1=0.6890


                                                                     

Epoch 7/30: TL=0.8371 TA=0.6572 TP=0.6520 TR=0.6572 TF1=0.6246 | VL=0.7211 VA=0.6970 VP=0.6923 VR=0.6970 VF1=0.6697


                                                                     

Epoch 8/30: TL=0.8235 TA=0.6641 TP=0.6569 TR=0.6641 TF1=0.6311 | VL=0.6813 VA=0.7135 VP=0.7403 VR=0.7135 VF1=0.6849


                                                                     

Epoch 9/30: TL=0.8140 TA=0.6678 TP=0.6662 TR=0.6678 TF1=0.6372 | VL=0.6716 VA=0.7215 VP=0.7371 VR=0.7215 VF1=0.7016


                                                                      

Epoch 10/30: TL=0.8119 TA=0.6666 TP=0.6654 TR=0.6666 TF1=0.6371 | VL=0.6600 VA=0.7300 VP=0.7649 VR=0.7300 VF1=0.7080


                                                                      

Epoch 11/30: TL=0.8005 TA=0.6755 TP=0.6765 TR=0.6755 TF1=0.6472 | VL=0.6664 VA=0.7260 VP=0.7457 VR=0.7260 VF1=0.7108


                                                                      

Epoch 12/30: TL=0.7891 TA=0.6769 TP=0.6770 TR=0.6769 TF1=0.6500 | VL=0.6751 VA=0.7220 VP=0.7390 VR=0.7220 VF1=0.6935


                                                                      

Epoch 13/30: TL=0.7917 TA=0.6753 TP=0.6737 TR=0.6753 TF1=0.6467 | VL=0.6583 VA=0.7290 VP=0.7596 VR=0.7290 VF1=0.7091


                                                                      

Epoch 14/30: TL=0.7846 TA=0.6816 TP=0.6829 TR=0.6816 TF1=0.6548 | VL=0.6474 VA=0.7290 VP=0.7413 VR=0.7290 VF1=0.7109


                                                                      

Epoch 15/30: TL=0.7775 TA=0.6830 TP=0.6851 TR=0.6830 TF1=0.6578 | VL=0.6638 VA=0.7265 VP=0.7589 VR=0.7265 VF1=0.7022


                                                                      

Epoch 16/30: TL=0.7682 TA=0.6862 TP=0.6865 TR=0.6862 TF1=0.6595 | VL=0.6516 VA=0.7320 VP=0.7359 VR=0.7320 VF1=0.7170


                                                                      

Epoch 17/30: TL=0.7625 TA=0.6873 TP=0.6859 TR=0.6873 TF1=0.6609 | VL=0.6421 VA=0.7390 VP=0.7659 VR=0.7390 VF1=0.7234


                                                                      

Epoch 18/30: TL=0.7622 TA=0.6887 TP=0.6913 TR=0.6887 TF1=0.6648 | VL=0.6499 VA=0.7330 VP=0.7600 VR=0.7330 VF1=0.7133


                                                                      

Epoch 19/30: TL=0.7501 TA=0.6941 TP=0.6982 TR=0.6941 TF1=0.6700 | VL=0.6508 VA=0.7330 VP=0.7447 VR=0.7330 VF1=0.7156


                                                                      

Epoch 20/30: TL=0.7552 TA=0.6898 TP=0.6941 TR=0.6898 TF1=0.6665 | VL=0.6345 VA=0.7315 VP=0.7449 VR=0.7315 VF1=0.7126


                                                                      

Epoch 21/30: TL=0.7510 TA=0.6900 TP=0.6943 TR=0.6900 TF1=0.6664 | VL=0.6406 VA=0.7360 VP=0.7576 VR=0.7360 VF1=0.7192


                                                                      

Epoch 22/30: TL=0.7428 TA=0.6946 TP=0.6996 TR=0.6946 TF1=0.6719 | VL=0.6360 VA=0.7345 VP=0.7550 VR=0.7345 VF1=0.7167


                                                                      

Epoch 23/30: TL=0.7377 TA=0.6956 TP=0.6984 TR=0.6956 TF1=0.6725 | VL=0.6383 VA=0.7290 VP=0.7409 VR=0.7290 VF1=0.7113


                                                                      

Epoch 24/30: TL=0.7356 TA=0.6962 TP=0.7014 TR=0.6962 TF1=0.6739 | VL=0.6359 VA=0.7375 VP=0.7609 VR=0.7375 VF1=0.7218


                                                                      

Epoch 25/30: TL=0.7298 TA=0.6991 TP=0.7050 TR=0.6991 TF1=0.6778 | VL=0.6316 VA=0.7385 VP=0.7658 VR=0.7385 VF1=0.7227


                                                                      

Epoch 26/30: TL=0.7341 TA=0.6985 TP=0.7065 TR=0.6985 TF1=0.6759 | VL=0.6307 VA=0.7380 VP=0.7658 VR=0.7380 VF1=0.7211


                                                                      

Epoch 27/30: TL=0.7270 TA=0.7006 TP=0.7046 TR=0.7006 TF1=0.6792 | VL=0.6291 VA=0.7370 VP=0.7661 VR=0.7370 VF1=0.7212


                                                                      

Epoch 28/30: TL=0.7306 TA=0.7019 TP=0.7092 TR=0.7019 TF1=0.6808 | VL=0.6285 VA=0.7370 VP=0.7526 VR=0.7370 VF1=0.7236


                                                                      

Epoch 29/30: TL=0.7357 TA=0.6910 TP=0.6934 TR=0.6910 TF1=0.6693 | VL=0.6285 VA=0.7390 VP=0.7593 VR=0.7390 VF1=0.7241


                                                                      

Epoch 30/30: TL=0.7270 TA=0.7004 TP=0.7042 TR=0.7004 TF1=0.6794 | VL=0.6282 VA=0.7375 VP=0.7667 VR=0.7375 VF1=0.7208
✓ Training complete!

[7/7] Testing & Computing Full Metrics...


                                                                             


RESNET50 - COMPREHENSIVE METRICS (80/10/10 SPLIT)

────────────────────────────────────────────────────────────────────────────────
  TRAIN SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7046
  Precision  : 0.7187
  Recall     : 0.7046
  F1-Score   : 0.6827

────────────────────────────────────────────────────────────────────────────────
  VALIDATION SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7390
  Precision  : 0.7659
  Recall     : 0.7390
  F1-Score   : 0.7234

────────────────────────────────────────────────────────────────────────────────
  TEST SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7203
  Precision  : 0.7327
  Recall     : 0.7203
  F1-Score   : 0.7011

────────────────────────────────────────────────────────────────────────────────
TEST SET - DETAILED CLASSIFICATION REPORT
────────

In [14]:
import pandas as pd
import matplotlib.pyplot as plt

# Load results (adjust filenames if needed)
resnet_csv = "/kaggle/working/resnet50_80_10_10_comprehensive_metrics.csv"
mobilenet_csv = "/kaggle/working/mobilenetv2_80_10_10_comprehensive_metrics.csv"
efficientnet_csv = "/kaggle/working/efficientnetb0_80_10_10_comprehensive_metrics.csv"

resnet = pd.read_csv(resnet_csv)
mobilenet = pd.read_csv(mobilenet_csv)
efficientnet = pd.read_csv(efficientnet_csv)

# Add model name for easier merging
resnet['Model'] = 'ResNet50'
mobilenet['Model'] = 'MobileNetV2'
efficientnet['Model'] = 'EfficientNetB0'

# Helper for formatting only float columns
def styled_table(df, caption):
    numeric_cols = df.select_dtypes('number').columns
    return df.style.set_caption(caption).format({col: '{:.4f}'.format for col in numeric_cols})

# --------
# 1. Display Individual Model Tables (Markdown, and PNG to output section)
# --------
for df, model in zip([resnet, mobilenet, efficientnet], ['ResNet50', 'MobileNetV2', 'EfficientNetB0']):
    print(f"\n## {model} - All Metrics")
    display(df)
    # PNG export
    fig, ax = plt.subplots(figsize=(8, 2))
    ax.axis('off')
    table = ax.table(
        cellText=df.values, colLabels=df.columns, loc='center',
        cellLoc='center', colLoc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    table.scale(1.3, 1.7)
    ax.set_title(f"{model} - All Metrics", fontsize=16)
    plt.savefig(f"/kaggle/working/{model.lower()}_all_metrics_table.png", dpi=300, bbox_inches="tight")
    plt.close(fig)

# --------
# 2. Combined Table for All 3 Models (Test Set Only)
# --------
combined = pd.concat([
    resnet[resnet["Set"].str.lower() == "test"],
    mobilenet[mobilenet["Set"].str.lower() == "test"],
    efficientnet[efficientnet["Set"].str.lower() == "test"]
], ignore_index=True)[["Model", "Accuracy", "Precision", "Recall", "F1-Score"]]

print("\n## All Models - Test Set Comparison")
display(combined)

fig, ax = plt.subplots(figsize=(8,2))
ax.axis('off')
table = ax.table(
    cellText=combined.values,
    colLabels=combined.columns,
    loc='center',
    cellLoc='center', colLoc='center'
)
table.auto_set_font_size(False)
table.set_fontsize(12)
table.scale(1.3, 1.7)
ax.set_title("All Models - Test Set Comparison", fontsize=16)
plt.savefig("/kaggle/working/all_models_comparison_table.png", dpi=300, bbox_inches="tight")
plt.close(fig)

print("✓ All tables created: shown below and PNG versions saved to /kaggle/working")



## ResNet50 - All Metrics


Unnamed: 0,Set,Accuracy,Precision,Recall,F1-Score,Model
0,Train,0.7046,0.7187,0.7046,0.6827,ResNet50
1,Validation,0.739,0.7659,0.739,0.7234,ResNet50
2,Test,0.7203,0.7327,0.7203,0.7011,ResNet50



## MobileNetV2 - All Metrics


Unnamed: 0,Set,Accuracy,Precision,Recall,F1-Score,Model
0,Train,0.7017,0.7174,0.7017,0.6764,MobileNetV2
1,Validation,0.732,0.7545,0.732,0.7138,MobileNetV2
2,Test,0.7058,0.7133,0.7058,0.6786,MobileNetV2



## EfficientNetB0 - All Metrics


Unnamed: 0,Set,Accuracy,Precision,Recall,F1-Score,Model
0,Train,0.7132,0.7305,0.7132,0.6899,EfficientNetB0
1,Validation,0.743,0.7629,0.743,0.7238,EfficientNetB0
2,Test,0.7178,0.7281,0.7178,0.6945,EfficientNetB0



## All Models - Test Set Comparison


Unnamed: 0,Model,Accuracy,Precision,Recall,F1-Score
0,ResNet50,0.7203,0.7327,0.7203,0.7011
1,MobileNetV2,0.7058,0.7133,0.7058,0.6786
2,EfficientNetB0,0.7178,0.7281,0.7178,0.6945


✓ All tables created: shown below and PNG versions saved to /kaggle/working
