# Training and Evaluation
This notebook trains the model (by running `../src/train_tuner.py`) and evaluates its classification accuracy and segmentation quality (IoU) on the test dataset.

In [1]:
# Install torchinfo into the kernel's environment (%pip targets the running kernel).
# NOTE(review): torchinfo is not imported anywhere in this notebook; presumably it is needed
# by the scripts under ../src — pin a version (torchinfo==X.Y.Z) so runs are reproducible.
%pip install torchinfo

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys
import os
sys.path.append(os.path.abspath('../src'))

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np
from tqdm.notebook import tqdm
import torch
from torch.utils.data import DataLoader
from inference import TumorPredictor
from train import train_model
from preprocessor import ImageDataset
import yaml
from utils import load_config

In [3]:
# Load the project configuration (train, data, features_model and seg_model sections)
# and display it as the cell output.
config = load_config()
config

{'train': {'class_weight': 0.7,
  'device': 'cuda',
  'log_frequency': 100,
  'save_dir': '/home/xi8t/WORK/BreastTumorPredictor/checkpoints/trained',
  'checkpoint_frequency': 2},
 'data': {'img_size': [512, 512],
  'dir': '/home/xi8t/WORK/BreastTumorPredictor/data/breast_ultrasonic_dataset',
  'augmentation': {'flip': True, 'rotation': 15},
  'num_classes': 5},
 'features_model': {'architecture': 'resnet50',
  'train_last_n_layers': 2,
  'train_layer4': True,
  'train_bn': True},
 'seg_model': {'path': '/home/xi8t/WORK/BreastTumorPredictor/checkpoints/sam2/sam2.1_hiera_small.pt',
  'download_file_path': '/home/xi8t/WORK/BreastTumorPredictor/checkpoints/sam2/download.sh',
  'config_path': '/home/xi8t/WORK/BreastTumorPredictor/configs/sam2/sam2_hiera_s.yaml'}}



KeyError: 'model'

In [4]:
# Run the training script inside this kernel; its log shows it starts a tuning study
# ("A new study created in memory").
# NOTE(review): the recorded run failed with Hydra's MissingConfigException for
# '../configs/sam2/sam2_hiera_s.yaml' — that path is not on Hydra's config search path
# (pkg://hydra.conf, pkg://sam2, structured://). Fix the SAM2 config reference before re-running.
%run ../src/train_tuner.py


  from .autonotebook import tqdm as notebook_tqdm
[I 2024-11-17 06:35:47,478] A new study created in memory with name: no-name-1cd74734-037e-4eb0-b9d8-b4636907054e


Checkpoint file not found. Running /home/xi8t/WORK/BreastTumorPredictor/checkpoints/sam2/download.sh to download it...
Downloading sam2.1_hiera_small.pt checkpoint...


--2024-11-17 06:35:47--  https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt
Loaded CA certificate '/etc/ssl/certs/ca-certificates.crt'
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 18.164.21.88, 18.164.21.117, 18.164.21.21, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|18.164.21.88|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 184416285 (176M) [application/vnd.snesdev-page-table]
Saving to: ‘sam2.1_hiera_small.pt.4’

     0K .......... .......... .......... .......... ..........  0% 4.27M 41s
    50K .......... .......... .......... .......... ..........  0% 8.51M 31s
   100K .......... .......... .......... .......... ..........  0% 3.80M 36s
   150K .......... .......... .......... .......... ..........  0% 2.49M 45s
   200K .......... .......... .......... .......... ..........  0% 20.8M 37s
   250K .......... .......... .......... .......... ..........  0% 8.98M 34s
   300K .......... .......

All checkpoints are downloaded successfully.
Checkpoint file downloaded successfully.
sam2 model cfg - ../configs/sam2/sam2_hiera_s.yaml


MissingConfigException: Cannot find primary config '../configs/sam2/sam2_hiera_s.yaml'. Check that it's in your config search path.

Config search path:
	provider=hydra, path=pkg://hydra.conf
	provider=main, path=pkg://sam2
	provider=schema, path=structured://

In [5]:
# Relative paths used in this notebook (e.g. '../src' added to sys.path) resolve against this directory.
current_dir = os.getcwd()
print(current_dir)

/home/xi8t/WORK/BreastTumorPredictor/notebooks


In [None]:
# 2. Evaluation
def _mask_iou(true_mask, pred_mask, empty_iou=0.0):
    """Return the IoU of two masks, treating every non-zero element as foreground.

    IoU is undefined when both masks are empty (zero union); ``empty_iou`` is returned then.
    """
    true_fg = np.asarray(true_mask).astype(bool)
    pred_fg = np.asarray(pred_mask).astype(bool)
    union = np.logical_or(true_fg, pred_fg).sum()
    if union == 0:
        return float(empty_iou)
    return float(np.logical_and(true_fg, pred_fg).sum() / union)


def evaluate_model(predictor, data_loader, empty_iou=0.0, show_progress=True):
    """Run ``predictor`` over ``data_loader`` and collect classification and segmentation results.

    Args:
        predictor: Object whose ``predict_batch(images)`` returns one dict per image with
            ``'class_idx'`` and ``'segmentation_mask'`` keys.
        data_loader: Iterable of ``(images, masks, labels)`` batches. Each per-image mask is
            indexed with ``mask[0]`` (presumably a leading channel dimension — confirm against
            ImageDataset).
        empty_iou: IoU recorded when both the ground-truth and predicted masks are empty.
            The default 0.0 reproduces the old ``intersection / (union + 1e-10)`` result;
            1.0 is the other common convention and credits correct "nothing to segment"
            predictions.
        show_progress: Wrap the loader in a tqdm progress bar.

    Returns:
        Tuple of lists ``(true_classes, pred_classes, seg_ious)``: ground-truth labels,
        predicted class indices, and per-image segmentation IoU.
    """
    true_classes = []
    pred_classes = []
    seg_ious = []

    batches = tqdm(data_loader, desc="Evaluating") if show_progress else data_loader
    for images, masks, labels in batches:
        batch_results = predictor.predict_batch(images)

        # Collect classification results; tolist() works for tensors on any device and arrays.
        true_classes.extend(labels.tolist())
        pred_classes.extend(r['class_idx'] for r in batch_results)

        # Per-image segmentation IoU against the first channel of the ground-truth mask.
        seg_ious.extend(
            _mask_iou(mask[0], result['segmentation_mask'], empty_iou)
            for mask, result in zip(masks, batch_results)
        )

    return true_classes, pred_classes, seg_ious


In [None]:
# Create test dataset and loader.
# The config printed above has data.dir but no data.test_dir, so check for the key
# explicitly and fail with an actionable message instead of a bare KeyError.
TEST_BATCH_SIZE = 16
test_dir = config["data"].get("test_dir")
if test_dir is None:
    raise KeyError(
        "config['data']['test_dir'] is not set; add the test-split directory to the config "
        f"(config['data']['dir'] is {config['data'].get('dir')!r})"
    )
test_dataset = ImageDataset(test_dir, config)
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)


In [None]:
# Initialize predictor.
# The kernel's working directory is notebooks/ (see the os.getcwd() output above), so a bare
# "configs/config.yaml" would resolve to notebooks/configs/config.yaml. Anchor the path at
# the project root, where the other configs (e.g. configs/sam2/) live; an absolute path also
# keeps working if TumorPredictor joins it onto its own base directory.
# NOTE(review): assumes the main config is <project root>/configs/config.yaml — confirm.
PREDICTOR_CONFIG_PATH = os.path.abspath(os.path.join("..", "configs", "config.yaml"))
predictor = TumorPredictor(config_path=PREDICTOR_CONFIG_PATH)


In [None]:
# Evaluate the predictor on every batch of the test loader, then unpack
# ground-truth labels, predicted classes and per-image IoUs.
evaluation = evaluate_model(predictor, test_loader)
true_classes, pred_classes, seg_ious = evaluation


In [None]:
# 3. Visualizations

# Classification results: confusion matrix (left) and per-image IoU histogram (right).
# Passing every class index keeps the matrix K x K and aligned with predictor.class_labels
# even when a class never occurs in true_classes or pred_classes; without it the matrix
# shrinks and the tick labels no longer match its rows/columns.
# NOTE(review): assumes class_idx values are positions 0..K-1 in predictor.class_labels.
class_indices = np.arange(len(predictor.class_labels))

fig, (cm_ax, iou_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Confusion Matrix
cm = confusion_matrix(true_classes, pred_classes, labels=class_indices)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=predictor.class_labels,
            yticklabels=predictor.class_labels,
            ax=cm_ax)
cm_ax.set(title='Confusion Matrix', xlabel='Predicted', ylabel='True')

# IoU Distribution
iou_ax.hist(seg_ious, bins=20)
iou_ax.set(title='Segmentation IoU Distribution', xlabel='IoU', ylabel='Count')

fig.tight_layout()
plt.show()


In [None]:
# Print Classification Report.
# labels= keeps the report aligned with target_names even when some class is missing from
# true_classes/pred_classes (otherwise sklearn raises ValueError on the size mismatch);
# zero_division=0 reports 0 instead of emitting warnings for classes that are never predicted.
print("\nClassification Report:")
print(classification_report(true_classes, pred_classes,
                            labels=np.arange(len(predictor.class_labels)),
                            target_names=predictor.class_labels,
                            zero_division=0))

# Print Average IoU (np.mean of an empty list warns and returns nan).
if seg_ious:
    print(f"\nAverage Segmentation IoU: {np.mean(seg_ious):.4f}")
else:
    print("\nAverage Segmentation IoU: n/a (no segmentation results)")


In [None]:
# 4. Example Predictions Visualization
def _to_numpy(value):
    """Return ``value`` as a NumPy array, moving torch tensors to CPU first."""
    if hasattr(value, "detach"):
        value = value.detach().cpu().numpy()
    return np.asarray(value)


def _to_display_image(image):
    """Convert an image array/tensor into a layout ``imshow`` accepts ((H, W) or (H, W, C))."""
    array = _to_numpy(image)
    # Channel-first (1 or 3, H, W) -> channel-last (H, W, C); imshow rejects channel-first.
    if array.ndim == 3 and array.shape[0] in (1, 3) and array.shape[-1] not in (1, 3):
        array = np.transpose(array, (1, 2, 0))
    # imshow rejects (H, W, 1); drop the singleton channel.
    if array.ndim == 3 and array.shape[-1] == 1:
        array = array[..., 0]
    return array


def visualize_prediction(image, result, class_labels=None):
    """Plot an input image, its predicted segmentation mask and its class probabilities.

    Args:
        image: Image as a NumPy array or torch tensor; channel-first inputs are transposed
            for display. NOTE(review): if the dataset normalizes images, float RGB values
            outside [0, 1] are clipped by imshow — denormalize first if colors look wrong.
        result: One prediction dict from ``predictor.predict_batch``; this reads its
            ``'segmentation_mask'`` and ``'class_probabilities'`` entries.
        class_labels: Labels for the probability bars; defaults to the notebook-global
            ``predictor.class_labels``.
    """
    if class_labels is None:
        class_labels = predictor.class_labels

    fig, (image_ax, mask_ax, prob_ax) = plt.subplots(1, 3, figsize=(15, 5))

    # Original Image
    image_ax.imshow(_to_display_image(image))
    image_ax.set_title('Original Image')
    image_ax.axis('off')

    # Segmentation Mask
    mask_ax.imshow(_to_numpy(result['segmentation_mask']), cmap='gray')
    mask_ax.set_title('Segmentation Mask')
    mask_ax.axis('off')

    # Class Probabilities
    sns.barplot(x=list(class_labels),
                y=_to_numpy(result['class_probabilities']),
                ax=prob_ax)
    prob_ax.set_title('Class Probabilities')
    prob_ax.tick_params(axis='x', labelrotation=45)

    fig.tight_layout()
    plt.show()


In [None]:
# Visualize a few example predictions (fewer if the test set is smaller than N_EXAMPLES).
N_EXAMPLES = 5
n_examples = min(N_EXAMPLES, len(test_dataset))
test_images = [test_dataset[i][0] for i in range(n_examples)]

if test_images:
    # NOTE(review): evaluate_model passes predict_batch a collated batch from the DataLoader,
    # while this passes a Python list of per-sample images — confirm predict_batch accepts both.
    results = predictor.predict_batch(test_images)
    for image, result in zip(test_images, results):
        visualize_prediction(image, result)
else:
    print("Test dataset is empty; no example predictions to show.")
