## Prior Setup:
* Spark settings in `$SPARK_HOME/conf/spark-defaults.conf`:
    * Laptop:
    
    ```
    spark.driver.memory 12g
    spark.driver.maxResultSize 0
    spark.default.parallelism 100
    spark.serializer org.apache.spark.serializer.KryoSerializer
    ```
    
    * Server:
    
    ```
    spark.driver.memory 70g
    spark.executor.memory 100g
    spark.driver.maxResultSize 0
    spark.akka.frameSize 128
    spark.driver.extraJavaOptions -server -Xmn12G
    spark.executor.extraJavaOptions -server -Xmn12G
    spark.local.dirs /disk2/local,/disk3/local,/disk4/local,/disk5/local,/disk6/local,/disk7/local,/disk8/local,/disk9/local,/disk10/local,/disk11/local,/disk12/local
    spark.network.timeout 1000s
    ```
    
* Spark started with:
    * Laptop:

    ```
    PYSPARK_PYTHON=python3 PYSPARK_DRIVER_PYTHON=jupyter PYSPARK_DRIVER_PYTHON_OPTS="notebook" pyspark --master "local[*]" --driver-class-path $SYSTEMML_HOME/target/SystemML.jar --jars $SYSTEMML_HOME/target/SystemML.jar
    ```
    
    * Server:
    
    ```
    PYSPARK_PYTHON=python3 PYSPARK_DRIVER_PYTHON=jupyter PYSPARK_DRIVER_PYTHON_OPTS="notebook" pyspark --master spark://MASTER_URL:7077 --driver-class-path $SYSTEMML_HOME/target/SystemML.jar --jars $SYSTEMML_HOME/target/SystemML.jar
    ```
    
* A `data` folder containing a `training_image_data` folder with at least the first two training slides (`TUPAC-TR-001.svs` and `TUPAC-TR-002.svs`).
* A `training_ground_truth.csv` file in the `data` folder containing the tumor and molecular scores for each slide: no header row, one row per slide, in slide-number order.
* Layout:

    ```
    - Preprocessing.ipynb
    - data/
        - training_ground_truth.csv
        - training_image_data
            - slide1....
            - slide2....
            - ...
    ```

# Setup

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import functools
import math
import multiprocessing as mp

import matplotlib.pyplot as plt
import numpy as np
import openslide
from openslide.deepzoom import DeepZoomGenerator
import pandas as pd
from pyspark import StorageLevel
from pyspark.mllib.linalg import Vectors
# import systemml  # pip3 install systemml

from scipy.ndimage.morphology import binary_fill_holes
from skimage.color import rgb2gray
from skimage.feature import canny
from skimage.morphology import binary_closing, binary_dilation, disk

# Default figure size for every plot drawn in this notebook.
plt.rcParams['figure.figsize'] = (10, 6)

# Open Whole-Slide Image

In [None]:
def open_slide(slide_num, dir="data/training_image_data"):
  """
  Open the whole-slide image with the given number.

  The slide file is looked up as `<dir>/TUPAC-TR-<NNN>.svs`, where
  `<NNN>` is `slide_num` left-padded with zeros to at least three digits.

  Args:
    slide_num: Slide image number as an integer.
    dir: Directory in which the slides are stored, as a string.

  Returns:
    The slide object returned by `openslide.open_slide` for that file.
  """
  padded_num = str(slide_num).zfill(3)
  path = "{0}/TUPAC-TR-{1}.svs".format(dir, padded_num)
  return openslide.open_slide(path)

# Create Tile Generator

In [None]:
def create_tile_generator(slide, tile_size=1024, overlap=0):
  """
  Build a DeepZoom tile generator over a whole-slide image.

  The returned object is not a Python generator function; it is an
  object that extracts individual tiles on request, e.g. via
  `generator.get_tile(level, (col, row))`.

  Args:
    slide: An OpenSlide object representing a whole-slide image.
    tile_size: The width and height of a square tile to be generated.
    overlap: Number of pixels by which to overlap the tiles.

  Returns:
    A DeepZoomGenerator object. Extracted tiles are images of shape
    (tile_size, tile_size, channels); tiles at the slide border may be
    smaller, which is why `keep_tile` checks tile dimensions.
  """
  return DeepZoomGenerator(slide,
                           tile_size=tile_size,
                           overlap=overlap,
                           limit_bounds=True)

# Determine 20x Magnification Zoom Level

In [None]:
def get_20x_zoom_level(slide, generator):
  """
  Return the zoom level that corresponds to a 20x magnification.

  The generator can extract tiles from multiple zoom levels, downsampling
  by a factor of 2 per level from highest to lowest resolution.

  Args:
    slide: An OpenSlide object representing a whole-slide image.
    generator: A DeepZoomGenerator object representing a tile generator.
      Note: This generator is not a true "Python generator function", but
      rather is an object that is capable of extracting individual tiles.

  Returns:
    The zoom level closest to 20x magnification without going below it
    (so a 60x slide yields the 30x level). If the slide's objective power
    is missing, unparsable, or not a positive finite number, or is below
    20x, the highest-resolution level is returned.
  """
  highest_zoom_level = generator.level_count - 1  # 0-based indexing
  try:
    # `float` (not `int`) so values such as "20.0" are accepted.
    mag = float(slide.properties[openslide.PROPERTY_NAME_OBJECTIVE_POWER])
  except (KeyError, ValueError):
    # KeyError: the slide does not record its magnification.
    # ValueError: the recorded value is not a number.
    # Either way, just use the highest resolution.
    return highest_zoom_level
  if not (math.isfinite(mag) and mag > 0):
    return highest_zoom_level
  # `mag / 20` is the downsampling factor between the slide's native
  # magnification and 20x. Each zoom level halves the resolution, so the
  # level offset is log2 of that factor. Flooring keeps the result at or
  # above 20x; clamping at 0 handles slides scanned below 20x.
  offset = max(0, math.floor(math.log2(mag / 20)))
  return max(0, highest_zoom_level - offset)

# Generate Tile Indices For Whole-Slide Image

In [None]:
def process_slide(slide_num, tile_size=1024, overlap=0):
  """
  Generate all possible tile indices for a whole-slide image.

  Given a slide number, tile size, and overlap, generate
  all possible (slide_num, tile_size, overlap, zoom_level, col, row)
  indices at the slide's 20x zoom level.

  Args:
    slide_num: Slide image number as an integer.
    tile_size: The width and height of a square tile to be generated.
    overlap: Number of pixels by which to overlap the tiles.

  Returns:
    A list of (slide_num, tile_size, overlap, zoom_level, col, row)
    integer index tuples representing possible tiles to extract,
    ordered column by column.
  """
  slide = open_slide(slide_num)
  try:
    generator = create_tile_generator(slide, tile_size, overlap)
    zoom_level = get_20x_zoom_level(slide, generator)
    cols, rows = generator.level_tiles[zoom_level]
  finally:
    # Release the slide's file handle; only the tile grid dimensions are
    # needed from here on. This runs once per slide on Spark workers, so
    # leaking the handle would accumulate open files.
    slide.close()
  return [(slide_num, tile_size, overlap, zoom_level, col, row)
          for col in range(cols) for row in range(rows)]

# Generate 1024x1024x3 Tile From Tile Index

In [None]:
def process_tile_index(tile_index):
  """
  Generate a tile from a tile index.

  Given a (slide_num, tile_size, overlap, zoom_level, col, row) tile
  index, generate a (slide_num, tile) tuple.

  Args:
    tile_index: A (slide_num, tile_size, overlap, zoom_level, col, row)
      integer index tuple representing a tile to extract.

  Returns:
    A (slide_num, tile) tuple, where slide_num is an integer, and tile
    is a 3D NumPy array of shape (tile_size, tile_size, channels).
  """
  slide_num, tile_size, overlap, zoom_level, col, row = tile_index
  # The slide is reopened for every tile, so it must be closed every time
  # as well; otherwise each processed tile leaks a file handle.
  slide = open_slide(slide_num)
  try:
    generator = create_tile_generator(slide, tile_size, overlap)
    # `np.array` copies the pixels out of the returned image, so the tile
    # remains valid after the slide is closed.
    tile = np.array(generator.get_tile(zoom_level, (col, row)))
  finally:
    slide.close()
  return (slide_num, tile)

# Filter Tile For Dimensions & 90% Tissue Threshold

In [None]:
def calc_tissue_percentage(binary_tile):
  """
  Calculate the fraction of the tile that is filled by tissue.

  Args:
    binary_tile: A region of a whole-slide image as a NumPy array (or
      array-like), typically 2D, that has been binarized to the pixel
      values {0, 1}. Pixels with a value of `1` represent tissue, while
      pixels with a value of `0` represent background.

  Returns:
    Fraction of the tile's pixels that are tissue, in the range [0, 1].
    Despite the function name, this is not scaled to 0-100.
  """
  # The mean of a {0, 1} array is exactly the fraction of 1-valued pixels.
  # It matches `sum / (height * width)` for 2D input, and it also works
  # for inputs of any rank.
  return np.asarray(binary_tile).mean()

def keep_tile(tile_tuple, tile_size=1024, tissue_threshold=0.9):
  """
  Decide whether a tile is worth keeping.

  A tile is kept only if both of these hold:
    1. Its height and width are exactly (tile_size, tile_size).
    2. The fraction of it estimated to be tissue is at least
       `tissue_threshold`.
  Tissue is estimated with an edge-based heuristic: grayscale, invert,
  Canny edges, binary closing, binary dilation, then hole filling.

  Args:
    tile_tuple: A (slide_num, tile) tuple, where slide_num is an
      integer, and tile is a 3D NumPy array of shape
      (tile_size, tile_size, channels).
    tile_size: Required height and width of the tile, in pixels.
    tissue_threshold: Minimum tissue fraction, in [0, 1], needed to
      keep the tile.

  Returns:
    A Boolean indicating whether or not the tile should be kept
    for future usage.
  """
  _, tile = tile_tuple
  # Guard: reject tiles whose dimensions differ (presumably partial tiles
  # cut off at the slide border).
  if tile.shape[0:2] != (tile_size, tile_size):
    return False
  # rgb2gray yields intensities from 0 (dense tissue) to 1 (plain
  # background); inverting makes tissue bright and background dark.
  inverted = 1 - rgb2gray(tile)
  # Canny edge map (with hysteresis thresholding): True where an edge is
  # found. Tissue should be full of edges; plain background should not.
  edges = canny(inverted)
  structure = disk(10)
  # Closing (dilation then erosion) fills small dark gaps, removing
  # background noise; the following dilation grows the edge regions
  # further so they merge into solid tissue areas.
  mask = binary_dilation(binary_closing(edges, structure), structure)
  # Fill any holes still enclosed by tissue regions.
  mask = binary_fill_holes(mask)
  return calc_tissue_percentage(mask) >= tissue_threshold

# Generate Flattened 3x256x256 Samples From Tile

In [None]:
def process_tile(tile_tuple, new_size=256, grayscale=False):
  """
  Cut a tile into square samples and flatten each one.

  The tile is split into non-overlapping new_size x new_size blocks,
  taken row by row starting at the top-left block. Each block is
  reordered from (H, W, channels) to (channels, H, W) and flattened
  into a vector of length channels*new_size*new_size.

  Args:
    tile_tuple: A (slide_num, tile) tuple, where slide_num is an
      integer, and tile is a 3D NumPy array of shape
      (tile_size, tile_size, channels).
    new_size: The width and height of the square samples to generate.
      Both tile dimensions must be multiples of this value.
    grayscale: If True, convert the tile to a single grayscale channel
      with `rgb2gray` before cutting it up.

  Returns:
    A list of (slide_num, sample) tuples, one per block, where each
    sample is a 1D NumPy array of length channels*new_size*new_size.
  """
  slide_num, tile = tile_tuple
  if grayscale:
    # rgb2gray drops the channel axis; restore it as a single channel.
    tile = rgb2gray(tile)[:, :, np.newaxis]
  height, width, channels = tile.shape
  blocks_down, blocks_across = height // new_size, width // new_size
  # View the tile as a grid of blocks:
  #   (blocks_down, new_size, blocks_across, new_size, channels)
  # then reorder to
  #   (blocks_down, blocks_across, channels, new_size, new_size)
  # so flattening a block yields its values in (channels, H, W) order.
  grid = tile.reshape(blocks_down, new_size, blocks_across, new_size, channels)
  grid = grid.transpose(0, 2, 4, 1, 3)
  # One row per block, in row-major block order.
  flat = grid.reshape(blocks_down * blocks_across, -1)
  return [(slide_num, vector) for vector in flat]

# Visualize Tile

In [None]:
def visualize_tile(tile):
  """
  Plot a tissue tile with matplotlib and show the figure.

  Args:
    tile: A 3D NumPy array of shape (tile_size, tile_size, channels),
      e.g. the `tile` element of a (slide_num, tile) tuple.

  Returns:
    None
  """
  plt.imshow(tile)
  plt.show()

# Visualize Sample

In [None]:
def visualize_sample(sample, size=256):
  """
  Plot a tissue sample.

  Args:
    sample: A square sample flattened to a vector of length
      channels*size*size in (channels, size, size) order, as produced by
      `process_tile`. Any array-like exposing NumPy's `shape`, `astype`,
      `reshape`, and `max` works (e.g. a NumPy array).
    size: The width and height of the square sample.

  Returns:
    None

  Raises:
    ValueError: If the sample length is not an exact multiple of
      size*size, or implies a channel count other than 1, 3, or 4 (the
      only layouts `imshow` can display).
  """
  length = sample.shape[0]
  # Integer division with remainder, so a mismatched `size` is detected
  # here instead of failing later inside reshape or imshow.
  channels, remainder = divmod(length, size * size)
  if remainder or channels not in (1, 3, 4):
    raise ValueError(
        "Sample of length {0} cannot be shown as a {1}x{1} image: expected "
        "1, 3, or 4 channels of {2} values each.".format(length, size, size * size))
  if channels > 1:
    # Cast to 8-bit and move channels last, giving (size, size, channels).
    image = sample.astype('uint8').reshape((channels, size, size)).transpose(1, 2, 0)
    plt.imshow(image)
  else:
    # Single channel: values may be in [0, 1] (e.g. from rgb2gray) or in
    # [0, 255]; choose the display range accordingly.
    vmax = 255 if sample.max() > 1 else 1
    plt.imshow(sample.reshape((size, size)), cmap="gray", vmin=0, vmax=vmax)
  plt.show()

# Get Ground Truth Labels

In [None]:
# Ground-truth scores: a header-less CSV with one (tumor_score,
# molecular_score) row per training slide.
labels = pd.read_csv("data/training_ground_truth.csv", names=["tumor_score", "molecular_score"], header=None)
# This assumes row i (0-based) holds slide i + 1. The range is derived from
# the row count instead of being hard-coded to 500 slides, so a CSV with a
# different number of rows does not fail on assignment.
labels["slide_num"] = range(1, len(labels) + 1)

# Create slide_num -> tumor_score and slide_num -> molecular_score dictionaries
tumor_score_dict = {int(s): int(score) for s, score in zip(labels.slide_num, labels.tumor_score)}
molecular_score_dict = {int(s): float(score) for s, score in zip(labels.slide_num, labels.molecular_score)}

---

# Process All Images

## Singlenode

## Singlenode w/ multiprocessing

## Spark

In [None]:
# Emit single-channel (grayscale) samples instead of RGB ones.
grayscale = True
# Slides to process. For a full run, use every slide except the broken ones:
# num_partitions = 20000  #1600
# broken = {2, 45, 91, 112, 242, 256, 280, 313, 329, 467}
# slide_nums = sorted(set(range(1,501)) - broken)
num_partitions = 200
slide_nums = [1, 2]
slides = sc.parallelize(slide_nums)

# Enumerate every candidate tile index, then collect the indices on the
# driver and re-parallelize them so they are spread evenly over
# `num_partitions` partitions -- for memory issues.
# TODO: Try mapping tile_indices to a (random_num, tuple) pair, then sort by key
tile_indices = sc.parallelize(
    slides.flatMap(process_slide).repartition(num_partitions).collect(),
    num_partitions)

# Tile index -> (slide_num, tile) -> full-size, tissue-rich tiles only
# -> (slide_num, flattened sample) pairs.
tiles = tile_indices.map(process_tile_index)
filtered_tiles = tiles.filter(keep_tile)
samples = filtered_tiles.flatMap(functools.partial(process_tile, grayscale=grayscale))


def _to_labeled_row(slide_sample):
  """Turn a (slide_num, sample) pair into a labeled DataFrame row tuple."""
  slide_num, sample = slide_sample
  return (slide_num,
          tumor_score_dict[slide_num],
          molecular_score_dict[slide_num],
          Vectors.dense(sample))


# Unlabeled alternative:
# df = samples.map(lambda tup: (tup[0], Vectors.dense(tup[1]))).toDF(["slide_num", "sample"])
df = samples.map(_to_labeled_row).toDF(
    ["slide_num", "tumor_score", "molecular_score", "sample"])
# Cast the integer label columns to Spark "int". `df["sample"]` is used
# instead of attribute access because `df.sample` is a DataFrame method.
df = df.select(df.slide_num.cast("int"), df.tumor_score.cast("int"),
               df.molecular_score, df["sample"])
df = df.repartition(num_partitions)

In [None]:
# df.write.mode('overwrite').save("data/samples_labels_df.parquet", format="parquet")
#df.write.save("data/samples_labels_df.parquet", format="parquet")
# df.write.save("data/samples_labels_df_grayscale.parquet", format="parquet")

In [None]:
# Inspect the labeled sample DataFrame built above.
df

---

# Test Reading in Preprocessed DataFrame

In [None]:
# Reload previously saved preprocessed samples. Swap in the commented lines
# to load the grayscale variant instead.
# NOTE(review): `channels` is not used by the visible cells; visualize_sample
# infers the channel count from the vector length.
df2 = sqlContext.read.load("data/samples_labels_df.parquet")
channels = 3
# df2 = sqlContext.read.load("data/samples_labels_df_grayscale.parquet")
# channels = 1

In [None]:
# Inspect the reloaded DataFrame.
df2

In [None]:
# Total number of rows (presumably one per sample, if the file was written from `df`).
df2.count()

In [None]:
# Print the dtype of the first row's sample vector, then plot it.
# NOTE(review): assumes samples were cut at visualize_sample's default 256x256 size.
sample = df2.first().sample
print(sample.array.dtype)
visualize_sample(sample)

In [None]:
# Column names and types of the reloaded DataFrame.
df2.schema

In [None]:
# Sanity check: the partition count after repartitioning to 100.
df2.repartition(100).rdd.getNumPartitions()

In [None]:
# Number of rows per slide.
df2.groupBy("slide_num").count().show()

---

# Explore Adding Ground Truth Labels