# Detecting Cancer in Gigapixel Medical Images
## Applied Deep Learning (Spring 2018) 
### Akarsh Zingade, Kiran Ramesh, Arjun D'Cunha

### YouTube [demo](https://www.youtube.com/watch?v=royB3p2m9pM). GitHub [repo](https://github.com/kira-95/adl_cancer_detection).

Note: The 22 slides and tumor masks prepared by Prof. Joshua Gordon can be found [here](https://drive.google.com/drive/folders/1rwWL8zU9v0M27BtQKI52bF6bVLW82RL5?usp=sharing). The superset of this dataset can be found at [CAMELYON16](https://camelyon17.grand-challenge.org/Data/).


### Summary

We base our approach on the work by Google AI's [Liu et al. (2017)](https://arxiv.org/abs/1703.02442) in "Detecting Cancer Metastases on Gigapixel Pathology Images". We use an ImageNet-pretrained architecture and then use transfer learning to solve the problem of detecting cancer cells in the images. We train it using a sliding-window-based approach, where the model is trained on the patches extracted by the sliding windows. Once the model is trained, we create a heatmap of the predictions on medical slides that were not used during training.

In this Notebook, we train the model using Focal Loss


### Flow of the Notebook.

1. Load the train and test slides.
2. Extract patches for train and test slides.
3. Split the train patches into train and validation set
4. Save the train, validation and test patches.

#### Train Slides: 031, 064, 075, 084, 091, 094, 096, 101
#### Test Slides:  016, 078, 110


In [None]:
# Slide levels from which the paired patches are extracted.
lvl1, lvl2 = 4, 5

# Side length (in pixels) of the square sliding window.
window_size = 299

# Side length of the central square used to label a patch as tumorous or healthy.
patch_centre = 128

# Per slide and per level: how many tumorous and how many healthy patches to keep.
tumor_sampled_limit = healthy_sampled_limit = 100

# Prefix for the data file name.
prefix = 'multilevel'

## Import the relevant modules

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from openslide import open_slide, __library_version__ as openslide_version
import os
from PIL import Image
from skimage.color import rgb2gray
import random

In [None]:
# Directory containing the slide files and their tumor mask files.
SLIDES_DIR = 'slides'

In [None]:
def read_slide(slide, x, y, level, width, height, as_float=False):
    """Read a rectangular region of a slide into a numpy RGB array.

    slide: object exposing read_region(location, level, size).
    x, y: location forwarded to slide.read_region as (x, y).
    level: slide level to read from.
    width, height: size of the region, in pixels at `level`.
    as_float: when True the array is float32; otherwise numpy picks the
        dtype from the image.

    Returns an array of shape (height, width, 3).
    """
    # Converting to RGB drops the alpha channel.
    region = slide.read_region((x, y), level, (width, height)).convert('RGB')
    pixels = np.asarray(region, dtype=np.float32 if as_float else None)
    assert pixels.shape == (height, width, 3)
    return pixels
  
def find_tissue_pixels(image, intensity=0.8):
    """Locate the tissue pixels of an RGB image.

    A pixel counts as tissue when its grayscale value (computed with
    rgb2gray) is less than or equal to `intensity`.

    Returns zip(rows, cols): pairs of (row, column) indices into `image`.
    """
    gray = rgb2gray(image)
    height, width = image.shape[0], image.shape[1]
    assert gray.shape == (height, width)
    rows, cols = np.where(gray <= intensity)
    return zip(rows, cols)
  
def apply_mask(im, mask, color=(255,0,0)):
    """Return a copy of `im` with every masked pixel painted in `color`.

    im: image array; it is copied, never modified.
    mask: iterable of (row, column) index pairs to paint.
    color: value written at each masked position (red by default).
    """
    painted = np.array(im, copy=True)
    for row, col in mask:
        painted[row, col] = color
    return painted
  
def check_patch_centre(patch_mask, patch_centre):
    """ Check if there is any tumor pixel in the 128x128 centre
    inputs:
    - patch_mask: array, tumor mask
    - patch_centre: int, usually 128
    outputs: Boolean
    """
    patch_size = patch_mask.shape[0]
    offset = int((patch_size-patch_centre)/2)
    sum_cancers = np.sum(patc

In [None]:
def _sample_mask_pixels(mask_image, max_candidates):
    """Randomly choose up to `max_candidates` non-zero pixels of a 2-D mask.

    Returns a list of (x, y) tuples, i.e. (column, row) indices into
    `mask_image`.
    """
    rows, cols = np.nonzero(mask_image)
    # Cap the sample size at the population size: random.sample raises
    # ValueError otherwise, which would abort on slides with a small (or no)
    # tumor region. Sampling indices avoids building a list of every pixel.
    sample_size = min(len(rows), max_candidates)
    chosen = random.sample(range(len(rows)), sample_size)
    return [(cols[i], rows[i]) for i in chosen]


def _collect_patches(slide, tumor_mask, candidates, want_tumor, limit,
                     lvl1, lvl2, window_size, patch_centre, reference_lvl):
    """Read up to `limit` accepted patch pairs centred on `candidates`.

    A candidate is accepted when more than 50% of its lvl1 patch is tissue
    and the tumor status of its centre (per check_patch_centre) equals
    `want_tumor`. Returns (lvl1_patches, lvl2_patches).
    """
    patches_1 = []
    patches_2 = []
    # Candidates are reference-level coordinates; read_slide is given
    # level 0 coordinates.
    ref_scale = 2 ** reference_lvl
    half_window = (window_size // 2) * (2 ** lvl1)

    for x_ref, y_ref in candidates:
        if len(patches_1) >= limit:
            break

        x = x_ref * ref_scale - half_window
        y = y_ref * ref_scale - half_window

        patch = read_slide(slide, x=x, y=y, level=lvl1,
                           width=window_size, height=window_size)
        mask_patch = read_slide(tumor_mask, x=x, y=y, level=lvl1,
                                width=window_size, height=window_size)[:, :, 0]

        tissue_count = sum(1 for _ in find_tissue_pixels(patch))
        percent_tissue = tissue_count / float(patch.shape[0] * patch.shape[1]) * 100
        if percent_tissue <= 50:
            continue
        if bool(check_patch_centre(mask_patch, patch_centre)) != want_tumor:
            continue

        patches_1.append(patch)
        # NOTE(review): the lvl2 patch reuses the lvl1 top-left corner, so the
        # two patches share a corner rather than a centre. Kept as-is because
        # downstream code may rely on this alignment - confirm.
        patches_2.append(read_slide(slide, x=x, y=y, level=lvl2,
                                    width=window_size, height=window_size))

    return patches_1, patches_2


def generate_images(slide_path, tumor_mask_path, lvl1, lvl2, window_size,
                    tumor_sampled_limit, healthy_sampled_limit, patch_centre=128):
    """
    Return the patches for two levels and their labels

    slide_path: Path to the slide.
    tumor_mask_path: Path to the tumor mask slide
    lvl1 and lvl2: Levels of the slide
    window_size: Sliding Window size
    tumor_sampled_limit: Number of Tumorous patches to return per slide per level
    healthy_sampled_limit: Number of Healthy patches to return per slide per level
    patch_centre: Side of the central square checked for tumor pixels when
        labelling a patch (default 128, the value previously hard-coded)

    Returns (patch_images_1, patch_images_2, patch_labels): the lvl1 patches,
    the matching lvl2 patches and their labels (1 = tumorous, 0 = healthy).
    Tumorous patches come first, then healthy ones. Fewer patches than the
    limits are returned when not enough candidates pass the filters.
    """
    # Level at which the whole slide is read to pick candidate pixels.
    reference_lvl = 4

    with open_slide(slide_path) as slide:
        print ("Read WSI from %s with width: %d, height: %d" % (slide_path, slide.level_dimensions[0][0], slide.level_dimensions[0][1]))

        with open_slide(tumor_mask_path) as tumor_mask:
            print ("Read tumor mask from %s" % (tumor_mask_path))

            ref_width, ref_height = slide.level_dimensions[reference_lvl]
            slide_image = read_slide(slide, x=0, y=0, level=reference_lvl,
                                     width=ref_width, height=ref_height)
            tumor_mask_image = read_slide(tumor_mask, x=0, y=0, level=reference_lvl,
                                          width=ref_width, height=ref_height)[:, :, 0]

            # Healthy candidates are tissue pixels that are not tumor pixels.
            # The threshold matches the default of find_tissue_pixels. Built as
            # boolean masks because subtracting the mask from a painted image
            # also selected background and tumor pixels.
            tissue_mask = rgb2gray(slide_image) <= 0.8
            healthy_mask_image = np.logical_and(tissue_mask, tumor_mask_image == 0)

            # Oversample candidates (x10 / x20) because many are rejected by
            # the tissue-percentage and centre checks.
            tumor_pixels = _sample_mask_pixels(tumor_mask_image, tumor_sampled_limit * 10)
            tumor_1, tumor_2 = _collect_patches(
                slide, tumor_mask, tumor_pixels, True, tumor_sampled_limit,
                lvl1, lvl2, window_size, patch_centre, reference_lvl)

            healthy_pixels = _sample_mask_pixels(healthy_mask_image, healthy_sampled_limit * 20)
            healthy_1, healthy_2 = _collect_patches(
                slide, tumor_mask, healthy_pixels, False, healthy_sampled_limit,
                lvl1, lvl2, window_size, patch_centre, reference_lvl)

    patch_images_1 = tumor_1 + healthy_1
    patch_images_2 = tumor_2 + healthy_2
    patch_labels = [1] * len(tumor_1) + [0] * len(healthy_1)

    return patch_images_1, patch_images_2, patch_labels

## Extract patches for train slides. 

In [None]:
# Accumulators for the patches and labels of all train/validation slides.
trainval_patch_images_lev1 = []
trainval_patch_images_lev2 = []
trainval_patch_labels = []

# Train slides as documented at the top of the notebook. The list must stay
# disjoint from the test slides (016, 078, 110): it previously contained 016
# and 078, which leaked test data into training, and omitted 091.
TRAIN_SLIDE_NUMS = ['031', '064', '075', '084', '091', '094', '096', '101']

for num in TRAIN_SLIDE_NUMS:
    slide_path = os.path.join(SLIDES_DIR, 'tumor_' + num + '.tif')
    tumor_mask_path = os.path.join(SLIDES_DIR, 'tumor_' + num + '_mask.tif')
    patch_images_1, patch_images_2, patch_labels = generate_images(slide_path, tumor_mask_path, lvl1, lvl2, window_size, tumor_sampled_limit, healthy_sampled_limit)
    trainval_patch_images_lev1.extend(patch_images_1)
    trainval_patch_images_lev2.extend(patch_images_2)
    trainval_patch_labels.extend(patch_labels)

## Extract patches for test slides.

In [None]:
# Slides held out for testing.
TEST_SLIDE_NUMS = ['016','078','110']

# Patches (one list per level) and labels gathered over all test slides.
test_patch_images_lev1, test_patch_images_lev2, test_patch_labels = [], [], []

for num in TEST_SLIDE_NUMS:
    slide_path = os.path.join(SLIDES_DIR, 'tumor_{}.tif'.format(num))
    tumor_mask_path = os.path.join(SLIDES_DIR, 'tumor_{}_mask.tif'.format(num))
    patch_images_1, patch_images_2, patch_labels = generate_images(
        slide_path,
        tumor_mask_path,
        lvl1,
        lvl2,
        window_size,
        tumor_sampled_limit,
        healthy_sampled_limit,
    )
    test_patch_images_lev1 += patch_images_1
    test_patch_images_lev2 += patch_images_2
    test_patch_labels += patch_labels

In [None]:
# Turn the accumulated patch and label lists into numpy arrays.
X1_trainval, X2_trainval, y_trainval = (
    np.asarray(trainval_patch_images_lev1),
    np.asarray(trainval_patch_images_lev2),
    np.asarray(trainval_patch_labels),
)
X1_test, X2_test, y_test = (
    np.asarray(test_patch_images_lev1),
    np.asarray(test_patch_images_lev2),
    np.asarray(test_patch_labels),
)

# Report array shapes and label counts as a quick sanity check.
print(
    X1_trainval.shape, X2_trainval.shape, len(y_trainval),
    X1_test.shape, X2_test.shape, len(y_test),
)

In [None]:
# The list versions are no longer needed once the arrays exist; drop the names.
del trainval_patch_images_lev1
del trainval_patch_images_lev2
del trainval_patch_labels
del test_patch_images_lev1
del test_patch_images_lev2
del test_patch_labels

In [None]:
# Shuffle the patch indices, then use the first 80% for training and the
# remaining 20% for validation.
idxs = list(range(len(X1_trainval)))
np.random.shuffle(idxs)
split_at = int(0.8*len(idxs))
train_idxs = idxs[:split_at]
val_idxs = idxs[split_at:]

X1_train = X1_trainval[train_idxs]
X2_train = X2_trainval[train_idxs]
y_train = y_trainval[train_idxs]

X1_val = X1_trainval[val_idxs]
X2_val = X2_trainval[val_idxs]
y_val = y_trainval[val_idxs]

dataset = {
    'X1_train' : X1_train,
    'X2_train' : X2_train,
    'y_train' : y_train,
    'X1_val' : X1_val,
    'X2_val' : X2_val,
    'y_val' : y_val,
    'X1_test' : X1_test,
    'X2_test' : X2_test,
    'y_test' : y_test,
}

# The level variables are named lvl1/lvl2 in the configuration cell; the
# previous lev1/lev2 were undefined and raised NameError.
# np.save stores the dict as a pickled object array, so it must be read back
# with np.load(..., allow_pickle=True).
np.save('./custom_multilevel_levels_' + str(lvl1) + '_' + str(lvl2), dataset)