In [25]:
import sys
from pathlib import Path

# Put the project root on sys.path so the `src.*` packages below can be imported.
# Assumes the kernel's working directory is this notebook's folder and that the
# folder sits two levels below the root (as the variable names suggest) — confirm
# if the notebook is moved.
notebook_dir = Path().resolve()
src_dir = notebook_dir.parent
project_root = src_dir.parent

_root_str = str(project_root)
if _root_str not in sys.path:
    # Prepend so project packages take precedence over same-named installed ones.
    sys.path.insert(0, _root_str)

In [26]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, classification_report, f1_score, accuracy_score
from torch.utils.data import DataLoader

from src.models.cnn_model import UNet
from src.data.dataset import PixelClassificationDataset
from src.utils.helpers import get_image_and_mask_paths, compare_two_images, clean_mask

# Pick the fastest available backend: CUDA, then Apple-silicon MPS, then CPU.
# torch.backends.mps.is_available() is PyTorch's documented MPS availability check.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


# Setup

In [27]:
# Filesystem locations, all anchored at the project root.
# DATA_DIR / MASK_DIR are kept as plain strings for the path helper that consumes them.
CHECKPOINT_PATH = project_root / "checkpoints"
_CNN_DATA_ROOT = project_root / "data" / "cnn_training"
DATA_DIR = str(_CNN_DATA_ROOT / "resized_images")
MASK_DIR = str(_CNN_DATA_ROOT / "resized_masks")

In [28]:
# Build the network on the chosen device and restore the trained weights.
# map_location remaps tensors saved on another device (e.g. CUDA) onto `device`,
# so the checkpoint also loads on MPS/CPU-only machines; weights_only=True keeps
# torch.load from unpickling arbitrary objects.
model = UNet().to(device)
model.load_state_dict(
    torch.load(CHECKPOINT_PATH / "final_model.pth", map_location=device, weights_only=True)
)

<All keys matched successfully>

In [29]:
# Only the test split is evaluated in this notebook; train paths are kept for reference.
(train_image_paths, train_mask_paths,
 test_image_paths, test_mask_paths) = get_image_and_mask_paths(DATA_DIR, MASK_DIR)

# No transform and no shuffling, so predictions stay in the same order as test_image_paths.
test_dataset = PixelClassificationDataset(test_image_paths, test_mask_paths, transform=None)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=8)

# Quantitative Evaluation

In [30]:
# Run the model over the whole test set and collect per-pixel class predictions
# alongside the ground-truth masks, both as CPU numpy arrays.
model.eval()

pred_batches = []
gt_batches = []
with torch.no_grad():
    for images, masks in test_loader:
        outputs = model(images.to(device))
        # argmax over dim=1 — assumes logits are (batch, classes, H, W); TODO confirm.
        pred_batches.append(torch.argmax(outputs, dim=1).cpu().numpy())
        gt_batches.append(masks.cpu().numpy())

if not pred_batches:
    # np.concatenate([]) would otherwise fail with an unhelpful message.
    raise ValueError("test_loader yielded no batches; check the test split / DATA_DIR")

preds = np.concatenate(pred_batches, axis=0)
gts = np.concatenate(gt_batches, axis=0)
# The per-pixel metrics compare these element-wise, so the shapes must agree.
assert preds.shape == gts.shape, f"prediction/ground-truth shape mismatch: {preds.shape} vs {gts.shape}"

print(f"Predictions shape: {preds.shape}")
print(f"Ground truth shape: {gts.shape}")

Predictions shape: (4, 512, 512)
Ground truth shape: (4, 512, 512)


In [31]:
# The sklearn metrics in the next cell take 1-D label vectors, so collapse every
# pixel of every test image into one flat array each (flatten() returns a copy).
preds_flat, gt_flat = preds.flatten(), gts.flatten()

print(f"Flattened predictions shape: {preds_flat.shape}")
print(f"Flattened ground truth shape: {gt_flat.shape}")

Flattened predictions shape: (1048576,)
Flattened ground truth shape: (1048576,)


In [32]:
# Class index -> display name.
# NOTE(review): carried over from an "Adjust class names" placeholder — confirm
# these match the mask encoding.
class_names = ['Background', 'Nucleus', 'Other']
# Pass the label set explicitly so per-class outputs stay aligned with class_names
# even if a class is absent from both gt and predictions. Otherwise sklearn infers
# the labels from the data, and classification_report raises on a target_names
# count mismatch.
class_labels = list(range(len(class_names)))

accuracy = accuracy_score(gt_flat, preds_flat)
f1_macro = f1_score(gt_flat, preds_flat, labels=class_labels, average='macro')
f1_weighted = f1_score(gt_flat, preds_flat, labels=class_labels, average='weighted')
f1_per_class = f1_score(gt_flat, preds_flat, labels=class_labels, average=None)

# Confusion matrix (sklearn convention: rows = true class, columns = predicted class)
cm = confusion_matrix(gt_flat, preds_flat, labels=class_labels)

# Detailed classification report
report = classification_report(gt_flat, preds_flat,
                               labels=class_labels,
                               target_names=class_names)

print(f"Accuracy: {accuracy:.4f}")
print(f"F1 Score (Macro): {f1_macro:.4f}")
print(f"F1 Score (Weighted): {f1_weighted:.4f}")
print(f"F1 Score per class: {f1_per_class}")
print("\nConfusion Matrix:")
print(cm)
print("\nClassification Report:")
print(report)

Accuracy: 0.9636
F1 Score (Macro): 0.7850
F1 Score (Weighted): 0.9673
F1 Score per class: [0.98239304 0.7993227  0.5732047 ]

Confusion Matrix:
[[941971  24059   7671]
 [  1593  59952   4035]
 [   442    416   8437]]

Classification Report:
              precision    recall  f1-score   support

  Background       1.00      0.97      0.98    973701
     Nucleus       0.71      0.91      0.80     65580
       Other       0.42      0.91      0.57      9295

    accuracy                           0.96   1048576
   macro avg       0.71      0.93      0.78   1048576
weighted avg       0.97      0.96      0.97   1048576



In [33]:
# Full-resolution image size as (width, height) — the order cv2.resize expects
# for `dsize`. Presumably the size of the raw images before they were resized
# for training; confirm against the source data.
ORG_SIZE = 2560, 1920

In [34]:
def clean_mask_tune(mask, num_classes=2, min_size=800, area_threshold=100, radius=2):
    """
    Clean a multi-class segmentation mask with per-class morphology.

    Each foreground class (1 .. num_classes-1) is processed independently:
    - small connected components are removed;
    - small holes are filled;
    - a binary closing with a disk footprint is applied.

    Pixels not claimed by any cleaned foreground class end up as class 0.
    Classes are written in increasing order, so where the cleaned regions of
    two classes overlap, the higher class index wins.

    Args:
        mask: numpy array with integer class labels (0, 1, 2, ...)
        num_classes: number of classes in the mask; labels >= num_classes
            are not written and therefore become 0
        min_size: minimum size of objects to keep (remove_small_objects)
        area_threshold: maximum size of holes to fill (remove_small_holes)
        radius: radius of the disk used for morphological closing

    Returns:
        Cleaned mask with same shape and dtype as input
    """
    # Local import: scikit-image is only loaded when this function is called.
    from skimage.morphology import remove_small_objects, remove_small_holes, binary_closing, disk

    cleaned_mask = np.zeros_like(mask)
    # The closing footprint does not depend on the class, so build it once.
    footprint = disk(radius)

    # Class 0 is skipped. The output starts as all zeros, so writing class 0 back
    # into it never changed anything, and that pass spent the most time on
    # morphology over the (large) background region.
    for cls in range(1, num_classes):
        binary_mask = mask == cls
        if not binary_mask.any():
            continue

        cleaned_class = remove_small_objects(binary_mask, min_size=min_size)
        cleaned_class = remove_small_holes(cleaned_class, area_threshold=area_threshold)
        cleaned_class = binary_closing(cleaned_class, footprint)

        cleaned_mask[cleaned_class] = cls

    return cleaned_mask

In [36]:
# NOTE(review): disabled visual sanity check, kept for reference only (not executed).
# Check these points before re-enabling it:
#   * It only reads the first batch of test_loader (batch_size=8). If preds spans
#     more than one batch, test_images[i] goes out of range.
#   * preds comes from torch.argmax (int64). cv2.resize may reject 64-bit integer
#     arrays, so cast first (e.g. .astype(np.uint8)). TODO confirm.
#   * `cleaned_gt` is computed from scaled_pred, so it is the cleaned *prediction*,
#     despite the "Cleaned Ground Truth" title and variable name.
# test_images = next(iter(test_loader))[0]

# for i, (pred, gt) in enumerate(zip(preds, gts)):
#     image = test_images[i].permute(1, 2, 0).cpu().numpy()
#     plt.imshow(image)
#     plt.axis('off')
#     plt.title("Original Image")
#     plt.show()
#     compare_two_images(pred, gt, "Prediction", "Ground Truth")
#     scaled_gt = cv2.resize(gt, ORG_SIZE, interpolation=cv2.INTER_NEAREST)
#     scaled_pred = cv2.resize(pred, ORG_SIZE, interpolation=cv2.INTER_NEAREST)
#     cleaned_gt = clean_mask_tune(scaled_pred, num_classes=3)
#     compare_two_images(cleaned_gt, scaled_gt, "Cleaned Ground Truth", "Ground Truth")