In [3]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import timm
import warnings
warnings.filterwarnings('ignore')

# Notebook banner: the title framed above and below by an 80-character rule.
print(
    "=" * 80,
    "NOTEBOOK 5: TRANSFORMER MODELS (ViT & Swin Transformer)",
    "=" * 80,
    sep="\n",
)

# ========== DATASET PATH DETECTION ==========
# Resolves SPLIT_DIR, the folder that holds the train/val/test sub-folders.
# Known locations are tried first; if none match, /kaggle/working and
# /kaggle/input are walked recursively. Raises FileNotFoundError when
# nothing suitable is found.
print("\n[0/6] Detecting dataset path...")

# Check all possible locations (in priority order)
possible_paths = [
    '/kaggle/working/split_80_10_10',           # From previous notebook
    '/kaggle/input/split_dataset/split_80_10_10',
    '/kaggle/input/split-dataset/split_80_10_10',
]

SPLIT_DIR = None
for path in possible_paths:
    # A candidate only counts if it contains a train/ directory; the rest
    # of the script calls os.listdir() on it.
    if os.path.isdir(os.path.join(path, 'train')):
        SPLIT_DIR = path
        print(f"✓ Dataset found at: {SPLIT_DIR}")
        break

# If not found in standard paths, search recursively
if SPLIT_DIR is None:
    print("Searching recursively for split_80_10_10...")
    for search_root in ['/kaggle/working', '/kaggle/input']:
        if not os.path.exists(search_root):
            continue
        for root, dirs, _files in os.walk(search_root):
            # Heuristic 1: the path mentions split_80_10_10 and has train/.
            named_match = (
                'split_80_10_10' in root
                and os.path.exists(os.path.join(root, 'train'))
            )
            # Heuristic 2: any folder with train/val/test whose train/ has
            # at least 4 entries (we have 4 classes). Only evaluated when
            # heuristic 1 did not match.
            layout_match = (
                not named_match
                and 'train' in dirs and 'val' in dirs and 'test' in dirs
                and len(os.listdir(os.path.join(root, 'train'))) >= 4
            )
            if named_match or layout_match:
                SPLIT_DIR = root
                print(f"✓ Found at: {SPLIT_DIR}")
                break
        if SPLIT_DIR is not None:
            break

if SPLIT_DIR is None:
    print("\n❌ Dataset not found in expected locations!")
    print("\nPlease check:")
    print("1. Did you run the notebook that created split_80_10_10?")
    print("2. Is it saved in /kaggle/working/split_80_10_10?")
    print("\nAvailable paths:")
    print("  /kaggle/working contents:")
    try:
        for item in os.listdir('/kaggle/working'):
            print(f"    - {item}")
    except OSError as exc:
        # The listing is only a debugging aid; report why it failed and
        # still raise the real error below.
        print(f"    (could not list /kaggle/working: {exc})")
    raise FileNotFoundError("split_80_10_10 directory not found!")

# Where checkpoints, plots and CSVs are written.
OUTPUT_DIR = '/kaggle/working'

# The three split folders live directly under the detected SPLIT_DIR.
TRAIN_DIR, VAL_DIR, TEST_DIR = (
    os.path.join(SPLIT_DIR, split) for split in ('train', 'val', 'test')
)

# Input resolution, loader settings and training schedule.
IMG_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 4
NUM_EPOCHS_VIT = 20
NUM_EPOCHS_SWIN = 20
LEARNING_RATE = 1e-4

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print(f"\n✓ Device: {DEVICE}")
print(f"✓ Using split from: {SPLIT_DIR}")

# Verify directories
# Classes and image counts use the same filters as OCTDataset (class =
# sub-directory, image = .png/.jpg/.jpeg file), so the totals printed here
# match what the loaders will use and stray files do not crash the check.
_VERIFY_IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg')


def _class_dirs(split_dir):
    """Return the sorted names of the sub-directories of ``split_dir``."""
    return sorted(
        entry for entry in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, entry))
    )


def _count_images(split_dir, classes):
    """Count files with an image extension across the given class folders."""
    return sum(
        1
        for cls in classes
        for fname in os.listdir(os.path.join(split_dir, cls))
        if fname.lower().endswith(_VERIFY_IMAGE_SUFFIXES)
    )


try:
    train_classes = _class_dirs(TRAIN_DIR)
    val_classes = _class_dirs(VAL_DIR)
    test_classes = _class_dirs(TEST_DIR)

    print("\nVerifying dataset structure:")
    print(f"  Train: {len(train_classes)} classes - {train_classes}")
    print(f"  Val:   {len(val_classes)} classes - {val_classes}")
    print(f"  Test:  {len(test_classes)} classes - {test_classes}")

    # Count images
    train_count = _count_images(TRAIN_DIR, train_classes)
    val_count = _count_images(VAL_DIR, val_classes)
    test_count = _count_images(TEST_DIR, test_classes)

    print(f"\n  Total Train images: {train_count}")
    print(f"  Total Val images:   {val_count}")
    print(f"  Total Test images:  {test_count}")
    print("  ✓ Dataset verified!")
except Exception as e:
    # Print a readable message, then let the original error propagate.
    print(f"\n❌ Error verifying dataset: {e}")
    raise

# ========== 1. AUGMENTATION ==========
print("\n[1/6] Defining augmentation...")

# Training pipeline, applied in list order: resize, random flips/rotation,
# noise/blur, brightness/contrast, elastic warp, then normalize and convert
# to a tensor.
_train_steps = [A.Resize(IMG_SIZE, IMG_SIZE)]
_train_steps += [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=20, p=0.5),
]
_train_steps += [
    A.GaussNoise(p=0.2),
    A.GaussianBlur(blur_limit=3, p=0.2),
]
_train_steps += [
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
    A.ElasticTransform(p=0.2),
]
_train_steps += [A.Normalize(mean=0.5, std=0.5), ToTensorV2()]
train_transform = A.Compose(_train_steps, is_check_shapes=False)

# Validation/test pipeline: no random augmentation, only the resize,
# normalize and tensor-conversion steps used for training.
_eval_steps = [
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2(),
]
val_test_transform = A.Compose(_eval_steps, is_check_shapes=False)

print("✓ Augmentation defined")

# ========== 2. DATASET CLASS ==========
print("\n[2/6] Creating dataset class...")

class OCTDataset(Dataset):
    """Folder-per-class image dataset.

    Every sub-directory of ``root_dir`` is one class; class indices follow
    the sorted directory names. Only files ending in .png/.jpg/.jpeg
    (case-insensitive) are indexed, in sorted order so that sample ``i``
    is the same file on every run regardless of ``os.listdir`` ordering.

    Args:
        root_dir: directory containing one sub-directory per class.
        transform: optional callable invoked as ``transform(image=array)``
            and expected to return a mapping with an ``'image'`` entry.
    """

    # File extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []

        classes = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}

        for cls, cls_idx in self.class_to_idx.items():
            cls_path = os.path.join(root_dir, cls)
            # Sorted: os.listdir returns entries in arbitrary order.
            for img_name in sorted(os.listdir(cls_path)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(cls_path, img_name))
                    self.labels.append(cls_idx)

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The file is read as grayscale and expanded to three identical
        channels before the transform (if any) is applied.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread returned nothing for this path; substitute an
            # all-black IMG_SIZE x IMG_SIZE image so the batch can still
            # be formed. NOTE(review): the sample keeps its real label, so
            # unreadable files silently become black training examples.
            image = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label

print("✓ Dataset class created")

# ========== 3. DATALOADERS ==========
print("\n[3/6] Creating dataloaders...")

# Only the training split gets the random augmentation pipeline.
train_dataset = OCTDataset(TRAIN_DIR, transform=train_transform)
val_dataset = OCTDataset(VAL_DIR, transform=val_test_transform)
test_dataset = OCTDataset(TEST_DIR, transform=val_test_transform)

# Settings shared by all three loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# Class names ordered by class index, taken from the training split.
num_classes = len(train_dataset.idx_to_class)
class_names = [train_dataset.idx_to_class[i] for i in range(num_classes)]

print(f"✓ Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

# ========== 4. UTILITY FUNCTIONS ==========
def compute_metrics(y_true, y_pred):
    """Return ``(accuracy, precision, recall, f1)`` for the given labels.

    Precision, recall and F1 are support-weighted averages over classes;
    a class with no predictions/labels contributes 0 instead of an error.
    """
    averaging = {'average': 'weighted', 'zero_division': 0}
    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, **averaging),
        recall_score(y_true, y_pred, **averaging),
        f1_score(y_true, y_pred, **averaging),
    )

def _collect_predictions(model, loader, criterion=None, desc=None):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Args:
        model: network to evaluate; batches are moved to ``DEVICE`` first.
        loader: iterable of ``(images, labels)`` batches that supports ``len``.
        criterion: optional loss function; when given, the mean of the
            per-batch losses is returned.
        desc: optional progress-bar label; when given, the loop is wrapped
            in a transient tqdm bar.

    Returns:
        ``(mean_loss, preds, labels)``; ``mean_loss`` is ``None`` when no
        criterion was supplied. ``preds`` / ``labels`` are flat lists with
        one entry per sample, in loader order.
    """
    model.eval()
    total_loss = 0.0
    preds, labels = [], []
    batches = tqdm(loader, desc=desc, leave=False) if desc else loader
    with torch.no_grad():
        for images, targets in batches:
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            outputs = model(images)
            if criterion is not None:
                total_loss += criterion(outputs, targets).item()
            _, batch_preds = torch.max(outputs, 1)
            preds.extend(batch_preds.cpu().numpy())
            labels.extend(targets.cpu().numpy())
    if criterion is None:
        return None, preds, labels
    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    return total_loss / max(len(loader), 1), preds, labels


def train_transformer_model(model_name, model, train_loader, val_loader, test_loader, num_epochs, learning_rate):
    """Generic training function for transformer models.

    Trains ``model`` with AdamW (weight decay 0.01), cosine LR annealing and
    gradient-norm clipping at 1.0, checkpoints the weights with the best
    validation accuracy, then evaluates that checkpoint on all three
    loaders and writes curves, confusion matrices and a metrics CSV to
    ``OUTPUT_DIR``.

    Relies on the module-level names ``DEVICE``, ``OUTPUT_DIR``,
    ``class_names`` and ``compute_metrics``.

    Args:
        model_name: short identifier used in log lines and output file names.
        model: network already moved to ``DEVICE``.
        train_loader: batches used for optimisation. It is also used as-is
            for the final "Train" metrics, so any augmentation it applies
            is active during that evaluation.
        val_loader: batches used for per-epoch validation and model selection.
        test_loader: batches used only for the final evaluation.
        num_epochs: number of training epochs (also the cosine period).
        learning_rate: initial AdamW learning rate.

    Returns:
        ``(results, metrics_df)``: ``results`` maps "Train" / "Validation" /
        "Test" to a dict with Accuracy, Precision, Recall, F1-Score,
        Predictions and Labels; ``metrics_df`` has one row per split with
        the four scalar metrics.
    """

    print(f"\n{'='*80}")
    print(f"TRAINING {model_name.upper()}")
    print(f"{'='*80}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; starting at 0.0 with a strict ">" could leave no file to
    # load if validation accuracy never rose above zero.
    best_val_acc = -1.0
    checkpoint_saved = False
    best_model_path = os.path.join(OUTPUT_DIR, f'{model_name}_best_80_10_10.pth')

    # Per-epoch history, keyed by split and then by metric name.
    metric_names = ('Loss', 'Accuracy', 'Precision', 'Recall', 'F1-Score')
    history = {split: {name: [] for name in metric_names} for split in ('Train', 'Val')}

    # ===== TRAINING LOOP =====
    for epoch in range(num_epochs):
        # Training
        model.train()
        running_loss = 0.0
        epoch_preds, epoch_labels = [], []

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [TRAIN]", leave=False)
        for images, labels in pbar:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            epoch_preds.extend(preds.cpu().numpy())
            epoch_labels.extend(labels.cpu().numpy())

        train_loss = running_loss / max(len(train_loader), 1)
        train_acc, train_prec, train_rec, train_f1 = compute_metrics(epoch_labels, epoch_preds)

        # Validation
        val_loss, val_preds, val_labels = _collect_predictions(model, val_loader, criterion=criterion)
        val_acc, val_prec, val_rec, val_f1 = compute_metrics(val_labels, val_preds)

        epoch_values = {
            'Train': (train_loss, train_acc, train_prec, train_rec, train_f1),
            'Val': (val_loss, val_acc, val_prec, val_rec, val_f1),
        }
        for split, values in epoch_values.items():
            for name, value in zip(metric_names, values):
                history[split][name].append(value)

        scheduler.step()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved = True

        print(f"Epoch {epoch+1}/{num_epochs}: TL={train_loss:.4f} TA={train_acc:.4f} TF1={train_f1:.4f} | VL={val_loss:.4f} VA={val_acc:.4f} VF1={val_f1:.4f}")

    print("✓ Training complete!")

    # ===== TESTING =====
    print(f"\nTesting {model_name}...")
    if checkpoint_saved:
        # map_location keeps the load working if the checkpoint's device
        # differs from the one currently in use.
        model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))

    split_order = ["Train", "Validation", "Test"]
    results = {}
    for split_name, loader in zip(split_order, (train_loader, val_loader, test_loader)):
        _, preds, labels = _collect_predictions(model, loader, desc=f"Computing {split_name} metrics")
        acc, prec, rec, f1 = compute_metrics(labels, preds)
        results[split_name] = {
            'Accuracy': acc,
            'Precision': prec,
            'Recall': rec,
            'F1-Score': f1,
            'Predictions': preds,
            'Labels': labels
        }

    # ===== DISPLAY RESULTS =====
    print("\n" + "="*80)
    print(f"{model_name.upper()} - COMPREHENSIVE METRICS (80/10/10 SPLIT)")
    print("="*80)

    for split_name in split_order:
        print(f"\n{'─'*80}")
        print(f"  {split_name.upper()} SET METRICS")
        print(f"{'─'*80}")
        print(f"  Accuracy   : {results[split_name]['Accuracy']:.4f}")
        print(f"  Precision  : {results[split_name]['Precision']:.4f}")
        print(f"  Recall     : {results[split_name]['Recall']:.4f}")
        print(f"  F1-Score   : {results[split_name]['F1-Score']:.4f}")

    # Explicit label ids keep report rows and matrix rows/columns aligned
    # with class_names even when a class is absent from a split.
    label_ids = list(range(len(class_names)))

    print(f"\n{'─'*80}")
    print("TEST SET - DETAILED CLASSIFICATION REPORT")
    print(f"{'─'*80}")
    print(classification_report(
        results["Test"]['Labels'],
        results["Test"]['Predictions'],
        labels=label_ids,
        target_names=class_names,
    ))

    # ===== VISUALIZATIONS =====
    # 2x2 grid of train-vs-val curves (recall is tracked but not plotted).
    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    for ax, name in zip(axes.flat, ('Loss', 'Accuracy', 'Precision', 'F1-Score')):
        ax.plot(history['Train'][name], label='Train', marker='o', linewidth=2)
        ax.plot(history['Val'][name], label='Val', marker='s', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(name)
        ax.set_title(f'{model_name}: Training & Validation {name}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, f'{model_name}_80_10_10_all_metrics_curves.png'), dpi=300, bbox_inches='tight')
    print(f"✓ Saved: {model_name}_80_10_10_all_metrics_curves.png")
    plt.close(fig)

    # Confusion matrices
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    for split_name, ax in zip(split_order, axes):
        cm = confusion_matrix(
            results[split_name]['Labels'],
            results[split_name]['Predictions'],
            labels=label_ids,
        )
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('Actual')
        ax.set_title(f'{model_name}: {split_name} Confusion Matrix')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, f'{model_name}_80_10_10_all_confusion_matrices.png'), dpi=300, bbox_inches='tight')
    print(f"✓ Saved: {model_name}_80_10_10_all_confusion_matrices.png")
    plt.close(fig)

    # Save metrics to CSV
    metrics_df = pd.DataFrame({
        'Set': split_order,
        **{
            name: [results[split_name][name] for split_name in split_order]
            for name in ('Accuracy', 'Precision', 'Recall', 'F1-Score')
        },
    })

    metrics_df.to_csv(os.path.join(OUTPUT_DIR, f'{model_name}_80_10_10_comprehensive_metrics.csv'), index=False)
    print(f"✓ Saved: {model_name}_80_10_10_comprehensive_metrics.csv")

    return results, metrics_df

# ========== 5A. VISION TRANSFORMER (ViT) ==========
print("\n[4/6] Training Vision Transformer (ViT)...")

# Pretrained ViT whose classifier head is sized for this dataset's classes,
# placed on the training device.
vit_model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=num_classes,
).to(DEVICE)
print("✓ ViT loaded")

vit_results, vit_metrics = train_transformer_model(
    model_name='vit',
    model=vit_model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=NUM_EPOCHS_VIT,
    learning_rate=LEARNING_RATE,
)

# ========== 5B. SWIN TRANSFORMER ==========
print("\n[5/6] Training Swin Transformer...")

# Pretrained Swin-Tiny whose classifier head is sized for this dataset's
# classes, placed on the training device.
swin_model = timm.create_model(
    'swin_tiny_patch4_window7_224',
    pretrained=True,
    num_classes=num_classes,
).to(DEVICE)
print("✓ Swin Transformer loaded")

swin_results, swin_metrics = train_transformer_model(
    model_name='swin',
    model=swin_model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=NUM_EPOCHS_SWIN,
    learning_rate=LEARNING_RATE,
)

# ========== 6. COMPARISON TABLE ==========
print("\n[6/6] Creating comparison tables...")


def _save_table_png(frame, title, filename):
    """Render ``frame`` as a matplotlib table and save it in OUTPUT_DIR."""
    plt.figure(figsize=(8, 2))
    rendered = plt.table(cellText=frame.values, colLabels=frame.columns, loc='center', cellLoc='center')
    rendered.auto_set_font_size(False)
    rendered.set_fontsize(12)
    rendered.scale(1.3, 1.7)
    plt.axis('off')
    plt.title(title, fontsize=14)
    plt.savefig(os.path.join(OUTPUT_DIR, filename), dpi=300, bbox_inches='tight')
    plt.close()


# Individual model tables
for name, df in (('vit', vit_metrics), ('swin', swin_metrics)):
    _save_table_png(df, f"{name.upper()} - All Metrics", f'{name}_all_metrics_table.png')

# Combined comparison (Test Set Only): one row per model.
_test_rows = [
    metrics[metrics["Set"].str.lower() == "test"].assign(Model=label)
    for metrics, label in ((vit_metrics, 'ViT'), (swin_metrics, 'Swin'))
]
combined = pd.concat(_test_rows, ignore_index=True)
combined = combined[["Model", "Accuracy", "Precision", "Recall", "F1-Score"]]

_save_table_png(combined, "Transformer Models - Test Set Comparison", 'transformers_comparison_table.png')

print("✓ All comparison tables saved!")

# ========== FINAL SUMMARY ==========
_rule = "=" * 80
print(f"\n{_rule}")
print("SUMMARY: ALL TRANSFORMER MODELS TRAINED")
print(_rule)

# Same four test-set numbers for each model, ViT first.
for heading, model_results in (
    ("ViT Test Results:", vit_results),
    ("Swin Transformer Test Results:", swin_results),
):
    test_scores = model_results['Test']
    print(f"\n{heading}")
    print(f"  Accuracy:  {test_scores['Accuracy']:.4f}")
    print(f"  Precision: {test_scores['Precision']:.4f}")
    print(f"  Recall:    {test_scores['Recall']:.4f}")
    print(f"  F1-Score:  {test_scores['F1-Score']:.4f}")

print(f"\n✓ All outputs saved to: {OUTPUT_DIR}")
print(_rule)

 

NOTEBOOK 5: TRANSFORMER MODELS (ViT & Swin Transformer)

[0/6] Detecting dataset path...
Searching recursively for split_80_10_10...
✓ Found at: /kaggle/input/split-dataset

✓ Device: cuda
✓ Using split from: /kaggle/input/split-dataset

Verifying dataset structure:
  Train: 4 classes - ['CNV', 'DME', 'DRUSEN', 'NORMAL']
  Val:   4 classes - ['CNV', 'DME', 'DRUSEN', 'NORMAL']
  Test:  4 classes - ['CNV', 'DME', 'DRUSEN', 'NORMAL']

  Total Train images: 15967
  Total Val images:   2000
  Total Test images:  2002
  ✓ Dataset verified!

[1/6] Defining augmentation...
✓ Augmentation defined

[2/6] Creating dataset class...
✓ Dataset class created

[3/6] Creating dataloaders...
✓ Train: 15967 | Val: 2000 | Test: 2002

[4/6] Training Vision Transformer (ViT)...


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

✓ ViT loaded

TRAINING VIT


                                                                     

Epoch 1/20: TL=0.8139 TA=0.6741 TF1=0.6487 | VL=0.6642 VA=0.7290 VF1=0.7169


                                                                     

Epoch 2/20: TL=0.7013 TA=0.7128 TF1=0.6953 | VL=0.6388 VA=0.7385 VF1=0.7276


                                                                     

Epoch 3/20: TL=0.6774 TA=0.7234 TF1=0.7061 | VL=0.6053 VA=0.7315 VF1=0.7167


                                                                     

Epoch 4/20: TL=0.6623 TA=0.7265 TF1=0.7111 | VL=0.5933 VA=0.7510 VF1=0.7361


                                                                     

Epoch 5/20: TL=0.6382 TA=0.7350 TF1=0.7204 | VL=0.5717 VA=0.7625 VF1=0.7500


                                                                     

Epoch 6/20: TL=0.6162 TA=0.7467 TF1=0.7340 | VL=0.5485 VA=0.7630 VF1=0.7481


                                                                     

Epoch 7/20: TL=0.6079 TA=0.7465 TF1=0.7338 | VL=0.5727 VA=0.7615 VF1=0.7477


                                                                     

Epoch 8/20: TL=0.5995 TA=0.7512 TF1=0.7388 | VL=0.5526 VA=0.7685 VF1=0.7552


                                                                     

Epoch 9/20: TL=0.5878 TA=0.7543 TF1=0.7428 | VL=0.5366 VA=0.7740 VF1=0.7651


                                                                      

Epoch 10/20: TL=0.5734 TA=0.7601 TF1=0.7491 | VL=0.5432 VA=0.7720 VF1=0.7592


                                                                      

Epoch 11/20: TL=0.5603 TA=0.7643 TF1=0.7539 | VL=0.5354 VA=0.7705 VF1=0.7615


                                                                      

Epoch 12/20: TL=0.5415 TA=0.7718 TF1=0.7624 | VL=0.5336 VA=0.7815 VF1=0.7723


                                                                      

Epoch 13/20: TL=0.5308 TA=0.7752 TF1=0.7662 | VL=0.5478 VA=0.7745 VF1=0.7654


                                                                      

Epoch 14/20: TL=0.5240 TA=0.7817 TF1=0.7730 | VL=0.5350 VA=0.7705 VF1=0.7657


                                                                      

Epoch 15/20: TL=0.5132 TA=0.7832 TF1=0.7752 | VL=0.5164 VA=0.7860 VF1=0.7788


                                                                      

Epoch 16/20: TL=0.4988 TA=0.7888 TF1=0.7815 | VL=0.5100 VA=0.7885 VF1=0.7803


                                                                      

Epoch 17/20: TL=0.4930 TA=0.7913 TF1=0.7843 | VL=0.5206 VA=0.7830 VF1=0.7752


                                                                      

Epoch 18/20: TL=0.4819 TA=0.7947 TF1=0.7881 | VL=0.5191 VA=0.7835 VF1=0.7757


                                                                      

Epoch 19/20: TL=0.4810 TA=0.7957 TF1=0.7891 | VL=0.5227 VA=0.7845 VF1=0.7765


                                                                      

Epoch 20/20: TL=0.4774 TA=0.7964 TF1=0.7899 | VL=0.5199 VA=0.7840 VF1=0.7761
✓ Training complete!

Testing vit...


                                                                             


VIT - COMPREHENSIVE METRICS (80/10/10 SPLIT)

────────────────────────────────────────────────────────────────────────────────
  TRAIN SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7928
  Precision  : 0.8174
  Recall     : 0.7928
  F1-Score   : 0.7857

────────────────────────────────────────────────────────────────────────────────
  VALIDATION SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7885
  Precision  : 0.8122
  Recall     : 0.7885
  F1-Score   : 0.7803

────────────────────────────────────────────────────────────────────────────────
  TEST SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7727
  Precision  : 0.7990
  Recall     : 0.7727
  F1-Score   : 0.7628

────────────────────────────────────────────────────────────────────────────────
TEST SET - DETAILED CLASSIFICATION REPORT
─────────────

model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

✓ Swin Transformer loaded

TRAINING SWIN


                                                                     

Epoch 1/20: TL=0.7639 TA=0.6885 TF1=0.6617 | VL=0.5981 VA=0.7580 VF1=0.7487


                                                                     

Epoch 2/20: TL=0.6480 TA=0.7318 TF1=0.7160 | VL=0.5821 VA=0.7470 VF1=0.7397


                                                                     

Epoch 3/20: TL=0.6303 TA=0.7411 TF1=0.7275 | VL=0.5259 VA=0.7790 VF1=0.7676


                                                                     

Epoch 4/20: TL=0.6047 TA=0.7507 TF1=0.7381 | VL=0.5113 VA=0.7840 VF1=0.7773


                                                                     

Epoch 5/20: TL=0.5879 TA=0.7562 TF1=0.7446 | VL=0.5128 VA=0.7845 VF1=0.7758


                                                                     

Epoch 6/20: TL=0.5780 TA=0.7593 TF1=0.7478 | VL=0.5135 VA=0.7855 VF1=0.7768


                                                                     

Epoch 7/20: TL=0.5582 TA=0.7640 TF1=0.7546 | VL=0.5574 VA=0.7745 VF1=0.7612


                                                                     

Epoch 8/20: TL=0.5548 TA=0.7687 TF1=0.7593 | VL=0.5268 VA=0.7830 VF1=0.7732


                                                                     

Epoch 9/20: TL=0.5434 TA=0.7725 TF1=0.7631 | VL=0.5044 VA=0.7850 VF1=0.7771


                                                                      

Epoch 10/20: TL=0.5374 TA=0.7752 TF1=0.7667 | VL=0.4961 VA=0.7890 VF1=0.7813


                                                                      

Epoch 11/20: TL=0.5268 TA=0.7799 TF1=0.7713 | VL=0.5025 VA=0.7920 VF1=0.7842


                                                                      

Epoch 12/20: TL=0.5166 TA=0.7807 TF1=0.7728 | VL=0.4952 VA=0.7900 VF1=0.7816


                                                                      

Epoch 13/20: TL=0.5051 TA=0.7811 TF1=0.7736 | VL=0.5014 VA=0.7895 VF1=0.7822


                                                                      

Epoch 14/20: TL=0.4990 TA=0.7876 TF1=0.7805 | VL=0.4997 VA=0.7865 VF1=0.7798


                                                                      

Epoch 15/20: TL=0.4889 TA=0.7908 TF1=0.7842 | VL=0.4963 VA=0.7960 VF1=0.7891


                                                                      

Epoch 16/20: TL=0.4847 TA=0.7937 TF1=0.7876 | VL=0.4972 VA=0.7920 VF1=0.7852


                                                                      

Epoch 17/20: TL=0.4816 TA=0.7938 TF1=0.7873 | VL=0.5005 VA=0.7965 VF1=0.7898


                                                                      

Epoch 18/20: TL=0.4731 TA=0.7955 TF1=0.7895 | VL=0.4975 VA=0.7920 VF1=0.7854


                                                                      

Epoch 19/20: TL=0.4727 TA=0.7982 TF1=0.7922 | VL=0.4962 VA=0.7960 VF1=0.7892


                                                                      

Epoch 20/20: TL=0.4703 TA=0.7975 TF1=0.7913 | VL=0.4962 VA=0.7970 VF1=0.7902
✓ Training complete!

Testing swin...


                                                                             


SWIN - COMPREHENSIVE METRICS (80/10/10 SPLIT)

────────────────────────────────────────────────────────────────────────────────
  TRAIN SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.8022
  Precision  : 0.8340
  Recall     : 0.8022
  F1-Score   : 0.7965

────────────────────────────────────────────────────────────────────────────────
  VALIDATION SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7970
  Precision  : 0.8315
  Recall     : 0.7970
  F1-Score   : 0.7902

────────────────────────────────────────────────────────────────────────────────
  TEST SET METRICS
────────────────────────────────────────────────────────────────────────────────
  Accuracy   : 0.7782
  Precision  : 0.8102
  Recall     : 0.7782
  F1-Score   : 0.7698

────────────────────────────────────────────────────────────────────────────────
TEST SET - DETAILED CLASSIFICATION REPORT
────────────