In [None]:


import os
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from utils.base_model import Base3DCNN, Enhanced3DCNN
from utils.data_loader import get_medmnist_dataloaders
from utils.trainer import Trainer
from utils.metrics import evaluate_model, compute_metrics
from utils.visualization import plot_training_history, plot_confusion_matrix
from config import *

# Report the compute device chosen in config, then fix the random seed via
# config's set_seed helper (which RNGs it seeds depends on its implementation).
print("Device: {}".format(DEVICE))
set_seed(42)



Collecting seaborn
  Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Collecting pandas>=1.2 (from seaborn)
  Using cached pandas-2.3.3-cp311-cp311-win_amd64.whl.metadata (19 kB)
Collecting matplotlib!=3.6.1,>=3.4 (from seaborn)
  Using cached matplotlib-3.10.7-cp311-cp311-win_amd64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib!=3.6.1,>=3.4->seaborn)
  Using cached contourpy-1.3.3-cp311-cp311-win_amd64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib!=3.6.1,>=3.4->seaborn)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib!=3.6.1,>=3.4->seaborn)
  Using cached fonttools-4.60.1-cp311-cp311-win_amd64.whl.metadata (114 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib!=3.6.1,>=3.4->seaborn)
  Using cached kiwisolver-1.4.9-cp311-cp311-win_amd64.whl.metadata (6.4 kB)
Collecting packaging>=20.0 (from matplotlib!=3.6.1,>=3.4->seaborn)
  Using cached packaging-25.0-py3-none-any.whl.metadat

ModuleNotFoundError: No module named 'utils.trainer'

In [None]:
# Build the OrganMNIST3D ('organ') loaders. The helper returns the train,
# validation and test loaders plus the number of target classes; batch size
# and worker count are taken from DATA_CONFIG.
(train_loader, val_loader, test_loader,
 num_classes) = get_medmnist_dataloaders(
    dataset_name='organ',
    batch_size=DATA_CONFIG['batch_size'],
    num_workers=DATA_CONFIG['num_workers'],
)

# Quick sanity report: class count and number of batches per split.
print("Number of classes:", num_classes)
print("Training batches:", len(train_loader))
print("Validation batches:", len(val_loader))
print("Test batches:", len(test_loader))

In [None]:
# Build the network. MODEL_CONFIG['architecture'] == 'enhanced' selects
# Enhanced3DCNN; any other value falls back to Base3DCNN. Both take
# single-channel input and the same dropout rate, so only the class differs.
model_cls = Enhanced3DCNN if MODEL_CONFIG['architecture'] == 'enhanced' else Base3DCNN
model = model_cls(
    in_channels=1,
    num_classes=num_classes,
    dropout_rate=MODEL_CONFIG['dropout_rate'],
).to(DEVICE)

# Parameter counts: all parameters vs. those that will receive gradients.
all_params = list(model.parameters())
total_params = sum(param.numel() for param in all_params)
trainable_params = sum(param.numel() for param in all_params if param.requires_grad)

print(f"\nModel Architecture: {MODEL_CONFIG['architecture']}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print("\nModel Summary:")
print(model)

In [None]:
# Optimisation setup, all hyper-parameters read from TRAINING_CONFIG:
#   - cross-entropy loss over class logits,
#   - Adam with weight decay,
#   - StepLR, which multiplies the learning rate by `scheduler_gamma` every
#     `scheduler_step_size` calls to scheduler.step() (presumably once per
#     epoch inside Trainer — confirm).
train_cfg = TRAINING_CONFIG

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=train_cfg['learning_rate'],
    weight_decay=train_cfg['weight_decay'],
)
scheduler = StepLR(optimizer,
                   step_size=train_cfg['scheduler_step_size'],
                   gamma=train_cfg['scheduler_gamma'])

# Bundle model, data and optimisation objects into the project Trainer.
trainer = Trainer(
    model=model,
    device=DEVICE,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
)

print("Training setup complete!")

In [None]:
# Train the model.
# The epoch count comes from TRAINING_CONFIG['num_epochs'] when the config
# defines it, so it lives next to the other hyper-parameters; otherwise it
# falls back to the previously hard-coded 30. (Assumes TRAINING_CONFIG is a
# dict, as its other key lookups suggest.)
num_epochs = TRAINING_CONFIG.get('num_epochs', 30)
print(f"Training for {num_epochs} epochs...\n")

history = trainer.train(num_epochs=num_epochs)

print("\n" + "="*50)
print("Training completed!")
# max() raises ValueError on an empty sequence, e.g. when num_epochs is 0.
val_accs = history['val_acc']
if len(val_accs) > 0:
    print(f"Best validation accuracy: {max(val_accs):.4f}")
else:
    print("No validation accuracy recorded.")
# NOTE(review): the test evaluation below uses the model's weights as left by
# trainer.train(); unless Trainer restores the best checkpoint, these are the
# final-epoch weights, not the ones that achieved the accuracy above — confirm.
print("="*50)

In [None]:
# Plot the training history returned by trainer.train() and save it under
# ../figures. The directory is created first so the save does not depend on it
# already existing (NOTE(review): assumes plot_training_history does not create
# it itself — confirm).
os.makedirs('../figures', exist_ok=True)
plot_training_history(history, save_path='../figures/baseline_training_history.png')

In [None]:
# Score the trained model on the held-out test split. evaluate_model returns a
# metrics dict together with the raw predictions and ground-truth labels.
print("Evaluating on test set...\n")
test_metrics, test_preds, test_labels = evaluate_model(model, test_loader, DEVICE)

print("Test Set Results:")
print(f"  Accuracy:  {test_metrics['accuracy']:.4f}")
print(f"  Precision: {test_metrics['precision']:.4f}")
print(f"  Recall:    {test_metrics['recall']:.4f}")
print(f"  F1-Score:  {test_metrics['f1_score']:.4f}")

print("\nPer-class metrics:")
for class_idx, organ_name in ORGAN_CLASSES.items():
    per_class_f1 = test_metrics['per_class']['f1_score']
    # Skip class indices the per-class metrics do not cover.
    if class_idx >= len(per_class_f1):
        continue
    print(f"  {organ_name:15s}: F1={per_class_f1[class_idx]:.4f}")

In [None]:
# Plot the test-set confusion matrix with organ names on the axes and save it.
# ORGAN_CLASSES is read with .get so that a class index it does not name gets a
# generic "class_<i>" label instead of raising KeyError. This is consistent with
# the bounds check used for the per-class report.
class_names = [ORGAN_CLASSES.get(i, f"class_{i}") for i in range(num_classes)]

# Make sure the output directory exists before the figure is written.
os.makedirs('../figures', exist_ok=True)
plot_confusion_matrix(
    test_metrics['confusion_matrix'],
    class_names=class_names,
    save_path='../figures/baseline_confusion_matrix.png'
)

In [None]:
# Save trained model.
# Besides the weights, store what is needed to rebuild or resume the run:
#   - optimizer and scheduler state,
#   - history and test metrics,
#   - the constructor settings (architecture name picks Base3DCNN vs
#     Enhanced3DCNN; num_classes and dropout_rate are passed to either).
# Lower-case names avoid clobbering anything pulled in by `from config import *`.
checkpoint_dir = '../models'
checkpoint_path = f"{checkpoint_dir}/baseline_flat_model.pth"
os.makedirs(checkpoint_dir, exist_ok=True)

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'history': history,
    'test_metrics': test_metrics,
    'num_classes': num_classes,
    'architecture': MODEL_CONFIG['architecture'],
    'dropout_rate': MODEL_CONFIG['dropout_rate'],
}, checkpoint_path)

print(f"Model saved to '{checkpoint_path}'")