# Baseline Flat Classification

This notebook trains a **flat (non-hierarchical) classifier** on the merged 3D MedMNIST datasets for comparison with the hierarchical approach.

The flat classifier predicts the fine-grained label directly without the intermediate region classification step.

In [None]:
import os
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from config import (
    DEVICE, DATA_CONFIG, MODEL_CONFIG, TRAINING_CONFIG,
    PATHS, set_seed, DEFAULT_MERGED_DATASETS
)
from utils.data_loader import create_hierarchical_dataset
from utils.cnn_3d_models import get_3d_model
from utils.trainer import Trainer

# Call the project's set_seed helper with a fixed seed, then report the configured device.
SEED = 42
set_seed(SEED)
print(f"Device: {DEVICE}")

## 1. Load Merged Dataset

In [None]:
# Build the train/val/test loaders over every dataset listed in DEFAULT_MERGED_DATASETS
print(f"Loading datasets: {DEFAULT_MERGED_DATASETS}")

loader_kwargs = dict(
    datasets_to_include=DEFAULT_MERGED_DATASETS,
    batch_size=DATA_CONFIG['batch_size'],
    num_workers=DATA_CONFIG['num_workers'],
)
train_loader, val_loader, test_loader, dataset_info = create_hierarchical_dataset(**loader_kwargs)

# Print a short report of what was loaded, taken from the returned dataset_info dict
rule = "=" * 60
print("\n" + rule)
print("MERGED DATASET INFO")
print(rule)
print(f"Datasets: {dataset_info['datasets_included']}")
for split_name, count_key in (
    ("Train", "train_samples"),
    ("Val", "val_samples"),
    ("Test", "test_samples"),
):
    print(f"{split_name} samples: {dataset_info[count_key]:,}")
print(f"Fine classes: {dataset_info['num_fine_classes']}")

## 2. Create Wrapper DataLoaders for Flat Training

The hierarchical dataset returns `(img, coarse_label, fine_label)`. For flat training, we only need `(img, fine_label)`.

In [None]:
class FlatDataLoaderWrapper:
    """Adapter over a hierarchical loader that drops the coarse label.

    The wrapped loader is expected to yield 3-tuples
    ``(imgs, coarse_labels, fine_labels)``; iterating the wrapper yields
    ``(imgs, fine_labels)`` pairs instead. ``len()`` is forwarded to the
    wrapped loader unchanged.
    """

    def __init__(self, hierarchical_loader):
        # Only a reference is stored; the underlying loader is not copied or consumed here.
        self.loader = hierarchical_loader

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        for batch in self.loader:
            # Strict 3-way unpack: a batch with any other arity raises, as a plain loop target would.
            imgs, _coarse_labels, fine_labels = batch
            yield imgs, fine_labels

# Wrap each hierarchical split so that it yields (img, fine_label) batches
flat_train_loader, flat_val_loader, flat_test_loader = (
    FlatDataLoaderWrapper(split_loader)
    for split_loader in (train_loader, val_loader, test_loader)
)

print("\n✓ Flat data loaders created")

## 3. Build Flat Classifier

In [None]:
# Instantiate the flat classifier with one output per fine-grained class
architecture = MODEL_CONFIG['architecture']
num_classes = dataset_info['num_fine_classes']

model = get_3d_model(model_name=architecture, num_classes=num_classes)
model = model.to(DEVICE)

# Tally parameter counts in a single pass over the model's parameters
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements

print("\nFlat classifier created:")
print(f"  Architecture: {architecture}")
print(f"  Number of classes: {num_classes}")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

## 4. Train Flat Classifier

In [None]:
# Loss, optimizer and step-wise LR schedule; hyperparameters all come from TRAINING_CONFIG
criterion = nn.CrossEntropyLoss()

adam_kwargs = {
    'lr': TRAINING_CONFIG['learning_rate'],
    'weight_decay': TRAINING_CONFIG['weight_decay'],
}
optimizer = torch.optim.Adam(model.parameters(), **adam_kwargs)

step_lr_kwargs = {
    'step_size': TRAINING_CONFIG['scheduler_step_size'],
    'gamma': TRAINING_CONFIG['scheduler_gamma'],
}
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **step_lr_kwargs)

# The trainer is fed the flat (img, fine_label) loaders rather than the hierarchical ones
trainer = Trainer(
    model=model,
    train_loader=flat_train_loader,
    val_loader=flat_val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=DEVICE,
    scheduler=scheduler,
)

# Epoch budget = coarse + fine epochs, so the baseline gets the same total as the hierarchical run
total_epochs = sum(TRAINING_CONFIG[stage] for stage in ('coarse_epochs', 'fine_epochs'))
print(f"Training for {total_epochs} epochs...")

In [None]:
# Run the whole epoch budget and keep whatever history the trainer returns
history = trainer.train(num_epochs=total_epochs)

banner_rule = "=" * 60
print("\n" + banner_rule, "TRAINING COMPLETE!", banner_rule, sep="\n")

## 5. Training Visualization

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# One panel per metric: (axis, history-key suffix, legend suffix, y-axis / title name).
# A single loop keeps the loss and accuracy panels styled identically.
panels = (
    (axes[0], 'loss', 'Loss', 'Loss'),
    (axes[1], 'acc', 'Acc', 'Accuracy'),
)
for ax, metric_key, legend_name, axis_name in panels:
    ax.plot(history[f'train_{metric_key}'], label=f'Train {legend_name}', color='blue')
    ax.plot(history[f'val_{metric_key}'], label=f'Val {legend_name}', color='orange')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(axis_name)
    ax.set_title(f'Flat Classifier {axis_name}')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()

# savefig does not create missing directories; make sure the target exists so a
# fresh checkout does not fail here after the full training run has finished.
os.makedirs(PATHS['figures'], exist_ok=True)
plt.savefig(f"{PATHS['figures']}/baseline_training_{architecture}.png", dpi=150)
plt.show()

## 6. Evaluation on Test Set

In [None]:
def evaluate_flat_model(model, test_loader, device):
    """Evaluate a flat classifier on a test loader.

    Args:
        model: Module invoked as ``model(imgs)``; predictions are the argmax
            of its output along dim 1. It is switched to eval mode.
        test_loader: Iterable yielding ``(imgs, labels)`` batches.
        device: Device the images and labels are moved to.

    Returns:
        dict with:
            'accuracy': correct / total over all batches (0 if no samples),
            'predictions': numpy array of predicted class indices,
            'labels': numpy array of the ground-truth labels.
    """
    model.eval()

    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels in tqdm(test_loader, desc="Evaluating"):
            imgs = imgs.to(device, dtype=torch.float32)
            # Heuristic rescale: a batch with values above 1 is treated as 0-255 intensities.
            if imgs.max() > 1:
                imgs = imgs / 255.0

            # Flatten labels to 1-D. reshape(-1) handles both (B, 1) and (B,) inputs;
            # squeeze(-1) would turn a (1,) batch into a 0-d tensor, which cannot be
            # iterated when extending all_labels below.
            labels = labels.reshape(-1).long().to(device)

            outputs = model(imgs)
            preds = outputs.argmax(1)

            correct += (preds == labels).sum().item()
            total += imgs.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Guard against an empty loader to avoid division by zero.
    accuracy = correct / total if total > 0 else 0

    return {
        'accuracy': accuracy,
        'predictions': np.array(all_preds),
        'labels': np.array(all_labels),
    }

In [None]:
# Score the trained model on the held-out test split (flat loader: img + fine label only)
results = evaluate_flat_model(model=model, test_loader=flat_test_loader, device=DEVICE)

divider = "=" * 60
print(f"\n{divider}\nTEST SET RESULTS\n{divider}")
print(f"\nTest Accuracy: {results['accuracy']:.4f}")

## 7. Save Model

In [None]:
import os
os.makedirs(PATHS['models'], exist_ok=True)

model_path = f"{PATHS['models']}/baseline_{architecture}.pth"

# Bundle the weights with the dataset description, training history and test
# results in a single checkpoint file, alongside the arguments needed to
# rebuild the model.
torch.save({
    'model_state_dict': model.state_dict(),
    'dataset_info': dataset_info,
    'history': history,
    'test_results': results,
    'config': {
        'architecture': architecture,
        'num_classes': num_classes,
    }
}, model_path)

print(f"\n✓ Model saved to: {model_path}")

## 8. Summary

In [None]:
# Collect the recap lines first, then emit them in one call (one line per entry)
summary_rule = "=" * 60
summary_lines = [
    "\n" + summary_rule,
    "TRAINING SUMMARY",
    summary_rule,
    f"\nDatasets used: {', '.join(dataset_info['datasets_included'])}",
    f"Total training samples: {dataset_info['train_samples']:,}",
    f"\nModel architecture: {architecture}",
    f"Number of classes: {num_classes}",
    f"Total parameters: {total_params:,}",
    f"\nFinal Test Accuracy: {results['accuracy']:.4f}",
    f"\nModel saved to: {model_path}",
]
print(*summary_lines, sep="\n")