# Neural Image Compression # 



## Imports ##

In [12]:
import numpy as np
import pandas as pd
from os.path import join, dirname, basename, splitext, exists
import os
from glob import glob
import sys
import shutil
import random
from tqdm import tqdm
from PIL import Image

import multiresolutionimageinterface as mri
import keras

## Data ##

To demonstrate the functionality of NIC, we will need a set of whole-slide images (WSIs) with their respective slide-level labels. In this case, we will use the WSIs that can be found using the following pattern:

`//chansey.umcn.nl/pathology/projects/pathology-liver-survival/data/images/Batch*/*.mrxs`

And the slide-level labels can be found in column `HGP_SL` from:

`//chansey.umcn.nl/pathology/projects/pathology-liver-survival/data/clinical/slide_list_hgpbin.csv`

Additionally, column `partition` assigns each slide to a data fold (partitions 1 to 4 are used for cross-validation, and `encoder` to extract patches for encoder training -see next section-).

In [13]:
# Project root on the mapped network drive; the data paths below are derived from it.
root_dir = r'W:\projects\pathology-liver-survival'

# Folder with the whole-slide images and CSV file listing the slides with their labels.
slide_dir = os.path.join(root_dir, 'data', 'images')
csv_path = os.path.join(root_dir, 'data', 'clinical', 'slide_list_hgpbin.csv')

# Used to store local copies of files during I/O operations (useful in cluster).
cache_dir = None

## 1. Encoder network ##

To perform NIC, we will need an encoder network to transform small image patches into embedding vectors. According to the paper, BiGAN produces the best unsupervised encoder and it is the one we will train here.

Alternatively, a collection of pretrained encoders (the ones used in the NIC paper) can be found in:

`./models/encoders_patches_pathology/*.h5`

Remember that these pretrained encoders accept 128x128x3 patches taken at 0.5 um/px resolution (often level 1), except for the BiGAN model that takes 64x64x3 at 1 um/px (often level 2).



In order to train the BiGAN model, we will first extract patches from the slides in the `encoder` partition. We will sample 10K patches per slide, producing ~260K patches in total. We select 96x96 patches to perform crop augmentation during training later.

In [None]:
from source.extract_patches import create_patch_dataset

# Destination of the extracted patches (a single numpy array file).
patches_npy_path = join(root_dir, 'results', 'patches', 'training.npy')

# Extracts patches from whole-slide images and stores them in a numpy array file.
# `cache_dir` may be None (no local cache configured); join(None, ...) raises TypeError,
# so the cache sub-folder is only derived when a cache directory was actually set.
# NOTE(review): assumes create_patch_dataset treats cache_dir=None as "no caching" - confirm.
create_patch_dataset(
    input_dir=slide_dir,
    csv_path=csv_path,
    partition_tag='encoder',
    output_path=patches_npy_path,
    image_level=2,
    patch_size=96,
    n_patches_per_image=10000,
    cache_dir=join(cache_dir, 'patches') if cache_dir is not None else None
)

Once we have extracted the patches, we can proceed to train the BiGAN model. We will use the hyper-parameters described in the NIC paper. 

In [None]:
from source.train_bigan_model import BiganModel

# Output folder handed to `bigan.train` below.
model_bigan_dir = os.path.join(root_dir, 'results', 'encoders', 'bigan', 'rotterdam1_96_noaug', '0.0001')

# Build the BiGAN model, then train it on the patch array extracted earlier.
bigan = BiganModel(patch_size=64, latent_dim=128, n_filters=128, lr=0.0001)
bigan.train(
    output_dir=model_bigan_dir,
    x_path=patches_npy_path,
    batch_size=64,
    epochs=400000,
    sample_interval=1000,
    save_models_on_epoch=True,
)

Beware that training this model is highly unstable, thus it can fail or collapse with ease. If this happens, restart the training. Selecting a checkpoint model is a manual procedure: check the generated images and loss values and avoid abnormal results. 

## 2. Compress images ##

Once we have a trained encoder, we can proceed with the WSI compression. I recommend running several `IDLE` instances of the following code in the cluster to speed up the lengthy process.

Before the actual compression, we need to vectorize the WSIs. This process extracts all non-background patches from the slide and stores them in numpy array format for quick access. In this case, we will read 64x64 patches at 1 um/px resolution (level 2).

In [17]:
from source.vectorize_wsi import vectorize_wsi

def vectorize_images(input_dir, csv_path, output_dir, cache_dir, image_level, patch_size, overwrite):
    """
    Converts a set of whole-slide images into numpy arrays with valid tissue patches for fast processing.

    Slides are visited in random order. A slide is skipped when its `<slide_id>_im_shape.npy` file already
    exists in `output_dir`, unless `overwrite` is set. A failure on one slide is printed and the loop
    continues with the next slide.

    :param input_dir: folder containing the whole-slide images (one sub-folder per `batch` value in the CSV).
    :param csv_path: list of slides; the `batch` and `slide_id` columns are read.
    :param output_dir: destination folder to store the vectorized images (created if missing).
    :param cache_dir: folder to store whole-slide images temporarily for fast access.
    :param image_level: image resolution to read the patches.
    :param patch_size: size of the read patches (also used as the stride).
    :param overwrite: True to overwrite existing images.
    :return: nothing
    """

    # Output dir. `exist_ok` avoids a FileExistsError when several instances of this code start
    # at the same time and race between the existence check and the directory creation.
    os.makedirs(output_dir, exist_ok=True)

    # Read image file names
    df = pd.read_csv(csv_path, header=0, index_col=0)

    # Shuffle names (presumably so that parallel instances do not all start on the same slide)
    df = df.sample(len(df), replace=False)

    # Process files. iterrows() has no length, so tqdm needs `total` to show progress.
    for _, row in tqdm(df.iterrows(), total=len(df)):

        try:
            wsi_path = join(input_dir, row['batch'], row['slide_id'] + '.mrxs')
            output_pattern = join(output_dir, row['slide_id'] + '_{item}.npy')

            # The `im_shape` file is used as the marker of an already vectorized slide.
            if overwrite or not exists(output_pattern.format(item='im_shape')):
                print('Processing image {image}'.format(image=row['slide_id']), flush=True)
                # NOTE(review): `cache_file` is not defined or imported in the visible part of this
                # file. If it is really missing, the NameError is caught by the handler below and
                # every slide is reported as failed - confirm where this helper comes from.
                vectorize_wsi(
                    image_path=cache_file(wsi_path, cache_dir, overwrite=False),
                    mask_path=None,
                    output_pattern=output_pattern,
                    image_level=image_level,
                    mask_level=None,
                    patch_size=patch_size,
                    stride=patch_size,
                    downsample=1,
                    select_bounding_box=True
                )
                print('Successful vectorized {image}'.format(image=row['slide_id']), flush=True)
            else:
                print('Already existing file {image}'.format(image=row['slide_id']), flush=True)

        except Exception as e:
            # Deliberate best-effort: report the failing slide and keep going with the rest.
            print('Failed to process image {row}. Exception: {e}'.format(row=row, e=e), flush=True)

In [None]:
# Vectorize WSIs

# Destination folder of the vectorized slides.
vectorized_dir = join(root_dir, 'results', 'vectorized', 'rotterdam1')

# `cache_dir` may be None (no local cache configured); join(None, ...) raises TypeError,
# so the cache sub-folder is only derived when a cache directory was actually set.
# NOTE(review): assumes the caching helper used by vectorize_images treats cache_dir=None
# as "no caching" - confirm.
vectorize_images(
    input_dir=slide_dir,
    csv_path=csv_path,
    output_dir=vectorized_dir,
    cache_dir=join(cache_dir, 'vectorized') if cache_dir is not None else None,
    image_level=2,
    patch_size=64,
    overwrite=False
)

Now we can compress the WSIs. Each WSI (vectorized file) will be processed 8 times due to WSI-level augmentation (rotation and flip). We will use an existing pretrained encoder from the NIC paper.

In [19]:
from source.featurize_wsi import encode_augment_wsi

def featurize_images(input_dir, csv_path, model_path, output_dir, batch_size, overwrite):
    """
    Compresses a set of whole-slide images using a trained encoder network.

    Slides are visited in random order. Each slide found in `input_dir` is encoded once per augmentation
    mode (8 rotation/flip combinations). A slide without a `<slide_id>_im_shape.npy` file is reported and
    skipped. A failure on one slide is printed and the loop continues with the next slide.

    :param input_dir: directory containing the vectorized images (`<slide_id>_{item}.npy` files).
    :param csv_path: path to list of slides; the `slide_id` column is read.
    :param model_path: path to trained encoder network (loaded with `keras.models.load_model`).
    :param output_dir: destination folder to store the compressed images (created if missing).
    :param batch_size: number of images to process in the GPU in one-go.
    :param overwrite: True to overwrite existing files.
    :return: nothing.
    """

    # Output dir. `exist_ok` avoids a FileExistsError when several instances of this code start
    # at the same time and race between the existence check and the directory creation.
    os.makedirs(output_dir, exist_ok=True)

    # Read image file names
    df = pd.read_csv(csv_path, header=0, index_col=0)

    # Shuffle names (presumably so that parallel instances do not all start on the same slide)
    df = df.sample(len(df), replace=False)

    # Load encoder model once; it is reused for every slide.
    encoder = keras.models.load_model(
        filepath=model_path
    )

    # WSI-level augmentation: (flip, rotation in degrees) pairs applied to every slide.
    aug_modes = [
        ('none', 0), ('none', 90), ('none', 180), ('none', 270),
        ('horizontal', 0), ('vertical', 0), ('vertical', 90), ('vertical', 270),
    ]

    # Process files. iterrows() has no length, so tqdm needs `total` to show progress.
    for _, row in tqdm(df.iterrows(), total=len(df)):

        try:
            wsi_pattern = join(input_dir, row['slide_id'] + '_{item}.npy')

            # The `im_shape` file is used as the marker of an available vectorized slide.
            if exists(wsi_pattern.format(item='im_shape')):
                encode_augment_wsi(
                    wsi_pattern=wsi_pattern,
                    encoder=encoder,
                    output_dir=output_dir,
                    batch_size=batch_size,
                    aug_modes=aug_modes,
                    overwrite=overwrite
                )
            else:
                print('Vectorized file not found: {f}'.format(f=wsi_pattern.format(item='im_shape')), flush=True)

        except Exception as e:
            # Deliberate best-effort: report the failing slide and keep going with the rest.
            print('Failed to process image {row}. Exception: {e}'.format(row=row, e=e), flush=True)


In [None]:
# Featurize images

# Destination of the compressed slides, and the pretrained encoder file used to produce them.
# The encoder path is relative, so it is resolved against the current working directory.
featurized_dir = os.path.join(root_dir, 'results', 'featurized', 'rotterdam1', 'bigan', 'nic')
model_path = os.path.join('models', 'encoders_patches_pathology', 'encoder_bigan.h5')

featurize_images(
    model_path=model_path,
    input_dir=vectorized_dir,
    output_dir=featurized_dir,
    csv_path=csv_path,
    overwrite=False,
    batch_size=128,
)

## 3. Train CNN on compressed images ##

Once we have compressed the WSIs, we can proceed with the CNN classifier. In this example, we will train a classifier targeting the binary label `HGP_SL` found in the CSV file. We will be training 4 models using cross-validation: in each fold, we will use 2 data partitions for training, 1 for validation and 1 for testing. At the end of model training, we perform inference on the test set, compute metrics, and run GradCAM on the images.



In [None]:
from source.gradcam_wsi import gradcam_on_dataset
from source.train_compressed_wsi import train_wsi_classifier, eval_model, compute_metrics

def train_model(featurized_dir, csv_path, fold_n, output_dir, cache_dir, batch_size=16,
                images_dir=None, vectorized_dir=None, lr=1e-2, patience=4,
                occlusion_augmentation=False, elastic_augmentation=False, shuffle_augmentation=None):
    """
    Trains a CNN using compressed whole-slide images, then evaluates it and runs GradCAM.

    All results are stored in `<output_dir>/fold_<fold_n>`. The model file `checkpoint.h5` is read from
    that folder for evaluation and GradCAM (presumably written there by `train_wsi_classifier`).
    A failure while computing metrics is printed and does not stop the GradCAM step.

    :param featurized_dir: folder containing the compressed (featurized) images.
    :param csv_path: list of slides with labels.
    :param fold_n: fold determining which data partitions to use for training, validation and testing
        (index into a table of 4 folds).
    :param output_dir: destination folder to store results.
    :param cache_dir: folder to store compressed images temporarily for fast access.
    :param batch_size: number of samples to train with in one-go (also used for evaluation).
    :param images_dir: forwarded to `gradcam_on_dataset`.
    :param vectorized_dir: forwarded to `gradcam_on_dataset`.
    :param lr: forwarded to `train_wsi_classifier`.
    :param patience: forwarded to `train_wsi_classifier`.
    :param occlusion_augmentation: forwarded to `train_wsi_classifier`.
    :param elastic_augmentation: forwarded to `train_wsi_classifier`.
    :param shuffle_augmentation: forwarded to `train_wsi_classifier`.
    :return: nothing.
    """

    # Fold table: fold k trains on partitions k and k+1, validates on k+2 and tests on k+3
    # (indices wrap around the 4 partitions).
    partition_names = ['partition_{i}'.format(i=i) for i in range(4)]
    folds = [
        {
            'training': [partition_names[k % 4], partition_names[(k + 1) % 4]],
            'validation': [partition_names[(k + 2) % 4]],
            'test': [partition_names[(k + 3) % 4]],
        }
        for k in range(4)
    ]
    fold_partitions = folds[fold_n]

    # Paths shared by the steps below
    result_dir = join(output_dir, 'fold_{n}'.format(n=fold_n))
    checkpoint_path = join(result_dir, 'checkpoint.h5')
    eval_dir = join(result_dir, 'eval')
    preds_path = join(eval_dir, 'preds.csv')

    # Step 1: train the classifier
    train_wsi_classifier(
        data_dir=featurized_dir,
        csv_path=csv_path,
        partitions=fold_partitions,
        crop_size=400,
        output_dir=result_dir,
        output_units=2,
        cache_dir=cache_dir,
        n_epochs=200,
        batch_size=batch_size,
        lr=lr,
        code_size=128,
        workers=1,
        train_step_multiplier=1,
        val_step_multiplier=0.5,
        keep_data_training=1,
        keep_data_validation=1,
        patience=patience,
        occlusion_augmentation=occlusion_augmentation,
        elastic_augmentation=elastic_augmentation,
        shuffle_augmentation=shuffle_augmentation
    )

    # Step 2: run inference with the stored checkpoint (receives the whole fold description)
    eval_model(
        model_path=checkpoint_path,
        data_dir=featurized_dir,
        csv_path=csv_path,
        partitions=fold_partitions,
        crop_size=400,
        output_path=preds_path,
        cache_dir=cache_dir,
        batch_size=batch_size,
        keep_data=1
    )

    # Step 3: metrics are best-effort; a failure here must not prevent the GradCAM analysis
    try:
        compute_metrics(input_path=preds_path, output_dir=eval_dir)
    except Exception as e:
        print('Failed to compute metrics. Exception: {e}'.format(e=e), flush=True)

    # Step 4: GradCAM on the test partition only
    gradcam_on_dataset(
        featurized_dir=featurized_dir,
        csv_path=csv_path,
        model_path=checkpoint_path,
        partitions=fold_partitions['test'],
        layer_name='separable_conv2d_1',
        output_unit=1,
        custom_objects=None,
        cache_dir=cache_dir,
        images_dir=images_dir,
        vectorized_dir=vectorized_dir
    )

In [None]:
# Train CNN

# Fold to train (index into the fold table of train_model) and root folder for the results.
selected_fold = 0
model_dir = join(root_dir, 'results', 'models', 'rotterdam1', 'bigan', 'nic', 'hgp_bin')

# `cache_dir` may be None (no local cache configured); join(None, ...) raises TypeError,
# so the cache sub-folder is only derived when a cache directory was actually set.
# NOTE(review): assumes the training/evaluation/GradCAM functions treat cache_dir=None
# as "no caching" - confirm.
train_model(
    featurized_dir=featurized_dir,
    csv_path=csv_path,
    fold_n=selected_fold,
    output_dir=model_dir,
    cache_dir=join(cache_dir, 'cnn') if cache_dir is not None else None,
    occlusion_augmentation=False,
    lr=1e-2,
    patience=4,
    elastic_augmentation=False,
    images_dir=slide_dir,  # required for GradCAM
    vectorized_dir=vectorized_dir,  # required for GradCAM
    shuffle_augmentation=None
)