# Training and Evaluation
This notebook trains the model and evaluates it on the test dataset.

In [1]:
import sys
import os
sys.path.append(os.path.abspath('../src'))

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np
from tqdm.notebook import tqdm
import torch
from torch.utils.data import DataLoader
from inference import TumorPredictor
from train import train_model
from preprocessor import ImageDataset
import yaml

In [2]:
# Load configuration
# The path is relative to the notebook's working directory. The file is decoded
# as UTF-8 explicitly so that parsing does not depend on the platform's default
# locale encoding.
with open("../configs/config.yaml", 'r', encoding="utf-8") as f:
    config = yaml.safe_load(f)


In [3]:
# 1. Training
# Runs the project's training entry point with the configuration loaded above.
# NOTE(review): the output saved with this cell ends in
# "AttributeError: 'ImageDataset' object has no attribute 'image_dir'", i.e. the
# last recorded run failed inside the train_model(config) call. That error has
# to be resolved in the project sources before this notebook can run top to bottom.
print("Starting training...")
train_model(config)


Starting training...


AttributeError: 'ImageDataset' object has no attribute 'image_dir'

In [None]:
# 2. Evaluation
def _to_numpy(values):
    """Return ``values`` as a NumPy array.

    Tensor-like objects (anything exposing ``detach``) are detached and copied
    to CPU first; everything else goes through ``np.asarray``.
    """
    if hasattr(values, "detach"):
        values = values.detach().cpu().numpy()
    return np.asarray(values)


def _mask_iou(true_mask, pred_mask):
    """Intersection-over-union of two masks, treating non-zero entries as foreground.

    Returns 1.0 when both masks are empty: there is nothing to segment and
    nothing was predicted, which is perfect agreement rather than a miss.
    """
    true_fg = _to_numpy(true_mask).astype(bool)
    pred_fg = _to_numpy(pred_mask).astype(bool)

    union = np.logical_or(true_fg, pred_fg).sum()
    if union == 0:
        return 1.0
    intersection = np.logical_and(true_fg, pred_fg).sum()
    return float(intersection / union)


def evaluate_model(predictor, data_loader):
    """Run ``predictor`` over ``data_loader`` and collect per-sample results.

    Args:
        predictor: object whose ``predict_batch(images)`` returns one dict per
            sample containing at least ``'class_idx'`` and
            ``'segmentation_mask'``.
        data_loader: iterable yielding ``(images, masks, labels)`` batches.
            Only the first entry along each mask's leading axis (``mask[0]``)
            is compared with the predicted mask.

    Returns:
        Tuple ``(true_classes, pred_classes, seg_ious)`` of three lists, each
        with one entry per evaluated sample, in iteration order.
    """
    true_classes = []
    pred_classes = []
    seg_ious = []

    for images, masks, labels in tqdm(data_loader, desc="Evaluating"):
        batch_results = predictor.predict_batch(images)

        # Collect classification results
        true_classes.extend(_to_numpy(labels))
        pred_classes.extend(r['class_idx'] for r in batch_results)

        # Calculate IoU for segmentation
        seg_ious.extend(
            _mask_iou(mask[0], result['segmentation_mask'])
            for mask, result in zip(masks, batch_results)
        )

    return true_classes, pred_classes, seg_ious


In [None]:
# Build the held-out dataset from the configured test directory and wrap it
# in a fixed-order loader (no shuffling) that yields batches of 16 samples.
test_dir = config["data"]["test_dir"]
test_dataset = ImageDataset(test_dir, config)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=16)


In [None]:
# Initialize predictor
# Every other path in this notebook is relative to the notebook's own directory
# ('../src', "../configs/config.yaml"), so the predictor is pointed at the same
# config file instead of "configs/config.yaml", which does not resolve from here.
# NOTE(review): assumes TumorPredictor opens config_path relative to the current
# working directory - confirm against its implementation.
predictor = TumorPredictor(config_path="../configs/config.yaml")


In [None]:
# Evaluate
# Runs the predictor over every batch of the test loader and collects three
# lists with one entry per sample: true class indices, predicted class indices
# and the segmentation IoU.
true_classes, pred_classes, seg_ious = evaluate_model(predictor, test_loader)


In [None]:
# 3. Visualizations

# Classification Results
# One figure with two panels: confusion matrix (left), IoU histogram (right).
fig, (ax_cm, ax_iou) = plt.subplots(1, 2, figsize=(12, 5))

# Confusion Matrix
# labels= pins the matrix to one row/column per known class. Without it, a class
# that never occurs in true_classes or pred_classes is dropped from the matrix
# and the tick labels no longer line up with the rows/columns.
# Assumes class index i corresponds to predictor.class_labels[i] - TODO confirm.
class_indices = np.arange(len(predictor.class_labels))
cm = confusion_matrix(true_classes, pred_classes, labels=class_indices)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=predictor.class_labels,
            yticklabels=predictor.class_labels,
            ax=ax_cm)
ax_cm.set_title('Confusion Matrix')
ax_cm.set_xlabel('Predicted')
ax_cm.set_ylabel('True')

# IoU Distribution
ax_iou.hist(seg_ious, bins=20)
ax_iou.set_title('Segmentation IoU Distribution')
ax_iou.set_xlabel('IoU')
ax_iou.set_ylabel('Count')

fig.tight_layout()
plt.show()


In [None]:
# Print Classification Report
# labels= keeps the report aligned with target_names even when a class is
# missing from this evaluation run; without it the number of classes found in
# the data can differ from len(target_names).
# Assumes class index i corresponds to predictor.class_labels[i] - TODO confirm.
print("\nClassification Report:")
print(classification_report(true_classes, pred_classes,
                            labels=np.arange(len(predictor.class_labels)),
                            target_names=predictor.class_labels))

# Print Average IoU
print(f"\nAverage Segmentation IoU: {np.mean(seg_ious):.4f}")


In [None]:
# 4. Example Predictions Visualization
def _to_display_image(image):
    """Convert ``image`` into an array layout that ``imshow`` accepts.

    Tensor-like inputs (anything exposing ``detach``) are detached and copied to
    CPU. A channels-first array ``(C, H, W)`` with C in {1, 3, 4} is moved to
    channels-last, and a single trailing channel is dropped so the image is
    drawn as a 2-D array. Inputs that are already 2-D or ``(H, W, 3|4)`` are
    returned unchanged.
    """
    if hasattr(image, "detach"):
        image = image.detach().cpu().numpy()
    image = np.asarray(image)

    channel_counts = (1, 3, 4)
    if (image.ndim == 3
            and image.shape[0] in channel_counts
            and image.shape[-1] not in channel_counts):
        image = np.moveaxis(image, 0, -1)
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]
    return image


def visualize_prediction(image, result, class_labels=None):
    """Show an input image, its predicted mask and the class probabilities.

    Args:
        image: image to display; see ``_to_display_image`` for accepted layouts.
        result: prediction dict providing ``'segmentation_mask'`` and
            ``'class_probabilities'``.
        class_labels: names for the bars of the probability plot, one per
            probability. Defaults to the notebook-level
            ``predictor.class_labels``.
    """
    if class_labels is None:
        # Same source the function has always used; only looked up when the
        # caller does not supply labels.
        class_labels = predictor.class_labels

    fig, (ax_image, ax_mask, ax_probs) = plt.subplots(1, 3, figsize=(15, 5))

    # Original Image
    ax_image.imshow(_to_display_image(image))
    ax_image.set_title('Original Image')
    ax_image.axis('off')

    # Segmentation Mask
    ax_mask.imshow(result['segmentation_mask'], cmap='gray')
    ax_mask.set_title('Segmentation Mask')
    ax_mask.axis('off')

    # Class Probabilities
    sns.barplot(x=class_labels,
                y=result['class_probabilities'],
                ax=ax_probs)
    ax_probs.set_title('Class Probabilities')
    ax_probs.tick_params(axis='x', labelrotation=45)

    fig.tight_layout()
    plt.show()


In [None]:
# Visualize some example predictions
# Take at most 5 samples, so a test set with fewer images does not raise
# IndexError; with an empty test set nothing is predicted or drawn.
num_examples = min(5, len(test_dataset))
test_images = [test_dataset[i][0] for i in range(num_examples)]

# NOTE(review): predict_batch receives a Python list here, whereas
# evaluate_model passes it the batches produced by the DataLoader - confirm
# that predict_batch accepts both forms.
results = predictor.predict_batch(test_images) if test_images else []

for image, result in zip(test_images, results):
    visualize_prediction(image, result)
