# Setup and Imports

In [1]:
# Install if monai not installed
!pip install monai

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
# Imports
import torch
from torch.utils.data import DataLoader

import time
import json
import numpy as np

import monai
from monai.data import ArrayDataset

from skimage.transform import resize

import matplotlib.pyplot as plt
from matplotlib import cm

# Model Prediction Method Setup

In [3]:
# Helper methods to generate and resize images for model input.
def generate_heatmap(heatmap, mask=None, alpha=0.2, cmap='jet'):
    """Map a normalised intensity array onto RGB colours from a matplotlib colormap.

    Args:
        heatmap: Numeric array of values expected in [0, 1]. Values outside
            that range are clipped to the nearest end of the colormap.
        mask: Accepted for backward compatibility; it is not applied to the
            output. May be None.
        alpha: Accepted for backward compatibility; it is not applied
            (no blending is performed).
        cmap: Name of a matplotlib colormap.

    Returns:
        Float array of shape ``heatmap.shape + (3,)`` with RGB values in [0, 1].
    """
    # Quantise to 256 levels. Clipping first matters: np.take treats negative
    # indices as wrap-around, so an out-of-range value would otherwise select
    # a colour from the wrong end of the map (or raise for values > 1).
    indices = np.round(255. * np.clip(heatmap, 0., 1.)).astype(np.int32)

    # Sample the colormap at 256 evenly spaced points in [0, 1]. For 256-entry
    # maps such as 'jet' this gives exactly the same table as integer lookup,
    # and it also samples correctly for maps whose LUT size is not 256.
    # (pyplot.get_cmap replaces matplotlib.cm.get_cmap, removed in 3.9.)
    colormap = plt.get_cmap(cmap)
    cmap_vals = colormap(np.linspace(0., 1., 256))[:, :3]  # drop the alpha channel

    # Gather colourmap values at indices
    return np.take(cmap_vals, indices, axis=0)
  
def resize_reshape_image(cdis_img, num_slices=25, size=(224, 224)):
    """Keep at most ``num_slices`` leading slices and resize each one.

    Args:
        cdis_img: Array indexed by slice along axis 0; each ``cdis_img[i]``
            is passed to ``skimage.transform.resize``.
        num_slices: Maximum number of leading slices to keep. Volumes with
            fewer slices are NOT padded, so the result may be shorter.
        size: Target shape passed to ``resize`` for every slice.

    Returns:
        A list of ``min(len(cdis_img), num_slices)`` resized slices.
    """
    # Truncate to the first `num_slices` slices (a no-op for shorter volumes).
    kept_slices = cdis_img[:num_slices]
    return [resize(slc, size) for slc in kept_slices]


In [4]:
# Preprocessing: turn a raw CDI^s volume into the layout the model expects.
cdis_linear_window = [0, 5000]  # [low, high] intensity window for clipping and normalisation
alpha = 0.3  # lower --> increased transparency, higher --> reduced transparency

def preprocess_cdis_imgs(img):
    """Window, colour-map, trim/resize and reorder a CDI^s volume.

    The volume is clipped to ``cdis_linear_window`` and rescaled to [0, 1].
    It is then colour-mapped by ``generate_heatmap``, trimmed and resized
    slice by slice by ``resize_reshape_image``, and its trailing colour axis
    is moved to the front.

    Args:
        img: Raw CDI^s volume with slices along axis 0 (must be 3-D for the
            final transpose to succeed).

    Returns:
        numpy array with axes (channel, slice, row, col).
    """
    low, high = cdis_linear_window

    # Non-zero voxels; forwarded to generate_heatmap as its mask argument.
    positive_mask = img > 0

    # Clip to the window, then rescale linearly into [0, 1].
    windowed = np.clip(img, low, high)
    normalised = (windowed - low) / (high - low)

    coloured = generate_heatmap(normalised, mask=positive_mask, alpha=alpha)
    stacked_slices = np.array(resize_reshape_image(coloured))

    # (slice, row, col, channel) -> (channel, slice, row, col)
    return stacked_slices.transpose(3, 0, 1, 2)

In [5]:
# Method to get model prediction given images and a path to the model weights.
def get_model_prediction(imgs, path_to_model_weights):
    """Classify preprocessed volumes with the trained 3-D ResNet-34.

    Args:
        imgs: Sequence of preprocessed volumes with 3 channels on the leading
            axis, e.g. outputs of ``preprocess_cdis_imgs``.
        path_to_model_weights: Path to a ``state_dict`` saved for this
            architecture.

    Returns:
        List of predicted class indices (``int``), one per entry of ``imgs``,
        in the same order as ``imgs``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Rebuild the architecture the checkpoint was trained with: a MONAI 3-D
    # ResNet-34 whose first convolution is swapped for a 3-input-channel one.
    model = monai.networks.nets.resnet34(spatial_dims=3, n_input_channels=1,
                                         num_classes=2, pretrained=False)
    model.conv1 = torch.nn.Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1),
                                  padding=(3, 3, 3), bias=False)
    model = model.to(device)

    # map_location lets weights saved on a GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    model.load_state_dict(torch.load(path_to_model_weights, map_location=device))
    model.eval()

    # ArrayDataset zips its inputs, and the zipped length is the shortest
    # input. The placeholder labels must therefore be as long as `imgs`, or
    # images after the first would be silently dropped.
    pred_ds = ArrayDataset(img=imgs, labels=[0] * len(imgs))
    # shuffle=False keeps the predictions aligned with the order of `imgs`.
    test_loader = DataLoader(pred_ds, batch_size=1, shuffle=False,
                             pin_memory=torch.cuda.is_available())

    y_pred = []
    with torch.no_grad():
        for batch in test_loader:
            # batch[0] holds the images; batch[1] is the unused placeholder label.
            test_images = batch[0].float().to(device)
            y_pred.extend(model(test_images).argmax(dim=1).tolist())

    return y_pred

# Demo Image Output with CancerNet-BCa-A Model

In [6]:
# Step 0: locate the trained CancerNet-BCa-A weights and load the demo CDI^s volume.
path_to_model_weights = "models/CancerNet-BCa-A.pth"
demo_cdis_img = np.load("demo_cdis.npy")

# Step 1: window, colour-map and resize the volume into the model-input layout.
demo_cdis_img_preprocessed = preprocess_cdis_imgs(demo_cdis_img)

# Step 2: classify it; the helper takes a list of volumes and returns one label each.
pred = get_model_prediction(
    imgs=[demo_cdis_img_preprocessed],
    path_to_model_weights=path_to_model_weights,
)

# Step 3: report the result. The "actual" label is hard-coded for this demo
# volume, and predicted class 0 is reported as pCR.
print("Actual patient pCR: True")
print("Predicted patient pCR: ", pred[0] == 0)


Actual patient pCR: True
Predicted patient pCR:  True
