In [None]:
# take the model and test 5 images in in validation set

import torch
from utils.load_data import BrainDataset, get_data_loaders
from architecture.shared_model import load_trained_model, All_view, evaluate_dice_scores, class_dice, arrange_img, dice_coef, mean_iou, DLUNet
import random


device = torch.device("mps")

In [None]:
# get data
train_loader, val_loader, test_loader = get_data_loaders("../data")

In [13]:
# from architecture.shared_model import DLUNet, ReASPP3, load_trained_model
# 1. First create a fresh model instance

model = DLUNet(in_channels=4)
model = torch.load(
    # "../mlruns/820203924686178493/40d84b246dd1475a93e8fb5c244ae452/artifacts/checkpoints/dlu_net_model_epoch_15.pth"
    "../model/base_model/dlu_net_model_epoch_10.pth",
# weights_only=False,
    map_location=device
)

# model.load_state_dict(state_dict['state_dict'])


# 5. Set model to evaluation mode
model.eval()
print("Model loaded successfully and set to evaluation mode")

AttributeError: 'collections.OrderedDict' object has no attribute 'eval'

In [None]:
# Get a sample batch from the validation loader
# next() retrieves the first item from the iterator created by iter(val_loader)
# iter() creates an iterator from val_loader which is likely a DataLoader object
# This gets one batch of data containing both images and masks
import matplotlib.pyplot as plt
val_images, val_masks = next(iter(test_loader))
print(f"Batch shape - Images: {val_images.shape}, Masks: {val_masks.shape}")


# Select a random sample from the batch
random_idx = random.randint(0, val_images.shape[0]-1)


print(random_idx)


with torch.no_grad():
    # Get the first sample
    sample_image = val_images[random_idx:random_idx + 1].to(device)
    sample_mask = val_masks[random_idx:random_idx + 1].to(device)

    # Get prediction
    prediction = model(sample_image)
    thresholded_pred = (prediction > 0.2).float()

    # Use the prediction directly for dice calculation instead of calling evaluate_dice_scores
    tc_dice = class_dice(prediction, sample_mask, 2).item()
    ec_dice = class_dice(prediction, sample_mask, 3).item()
    wt_dice = class_dice(prediction, sample_mask, 4).item()

    print(
        f"Sample Dice Scores - Tumor Core: {tc_dice:.4f}, Enhancing Tumor: {ec_dice:.4f}, Whole Tumor: {wt_dice:.4f}")

# Visualize the prediction vs ground truth

# Convert tensors to the right format for visualization
sample_image_np = sample_image.cpu().numpy().transpose(
    0, 2, 3, 1)  # [B,C,H,W] -> [B,H,W,C]
sample_mask_np = sample_mask.cpu().numpy().transpose(0, 2, 3, 1)
pred_np = thresholded_pred.cpu().numpy().transpose(0, 2, 3, 1)

# Use the visualization functions
GT, Pre, TC, EC, WT = arrange_img(
    torch.from_numpy(sample_image_np),
    torch.from_numpy(sample_mask_np),
    torch.from_numpy(pred_np)
)

# Convert tensors to numpy for matplotlib if needed
if isinstance(GT, torch.Tensor):
    GT = GT.numpy()
if isinstance(Pre, torch.Tensor):
    Pre = Pre.numpy()
if isinstance(TC, torch.Tensor):
    TC = TC.numpy()
if isinstance(EC, torch.Tensor):
    EC = EC.numpy()
if isinstance(WT, torch.Tensor):
    WT = WT.numpy()

# Display the results
fig, ax = plt.subplots(2, 3, figsize=(18, 12))

ax[0, 0].imshow(GT)
ax[0, 0].set_title('Ground Truth', fontsize=15)
ax[0, 0].axis("off")

ax[0, 1].imshow(Pre)
ax[0, 1].set_title('Prediction', fontsize=15)
ax[0, 1].axis("off")

ax[0, 2].imshow(TC)
ax[0, 2].set_title(f'Tumor Core: {tc_dice:.4f}', fontsize=15)
ax[0, 2].axis("off")

ax[1, 0].imshow(EC)
ax[1, 0].set_title(f'Enhancing Tumor: {ec_dice:.4f}', fontsize=15)
ax[1, 0].axis("off")

ax[1, 1].imshow(WT)
ax[1, 1].set_title(f'Whole Tumor: {wt_dice:.4f}', fontsize=15)
ax[1, 1].axis("off")

# Display original MRI (first channel)
ax[1, 2].imshow(sample_image_np[0, :, :, 0], cmap='gray')
ax[1, 2].set_title('Original MRI (Channel 0)', fontsize=15)
ax[1, 2].axis("off")

plt.tight_layout()
plt.show()

In [None]:
def evaluate_model(model, val_loader):
    """
    Evaluate model on validation data.
    """
    model.eval()
    total_dice = 0
    total_iou = 0
    total_class_dice = {2: 0, 3: 0, 4: 0}
    num_batches = 0

    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)

            # Calculate metrics
            dice = dice_coef(outputs, masks)
            iou = mean_iou(outputs, masks)

            class_dices = {
                2: class_dice(outputs, masks, 2),
                3: class_dice(outputs, masks, 3),
                4: class_dice(outputs, masks, 4)
            }

            total_dice += dice.item()
            total_iou += iou.item()
            for i in [2, 3, 4]:
                total_class_dice[i] += class_dices[i].item()

            num_batches += 1

    # Calculate mean metrics
    mean_metrics = {
        "dice_coef": float(total_dice / num_batches),
        "mean_iou": float(total_iou / num_batches),
        "class_dice": {
            "c2": float(total_class_dice[2] / num_batches),
            "c3": float(total_class_dice[3] / num_batches),
            "c4": float(total_class_dice[4] / num_batches)
        }
    }

    print(mean_metrics)
    
evaluate_model(model, test_loader)