# 4) Emory Confidence Heatmaps

Generate a confidence heatmap, where each channel corresponds to a pathology feature (cored, diffuse, CAA). The main function runs this for a single input: the directory of norm tiles for one WSI. These tiles are used to generate three confidence heatmaps for that WSI (one per class of pathology: cored, diffuse, CAA), stored as the three channels of a single array.

NOTE: this is the longest-running notebook in the project. Creating the heatmap for a single WSI can take 2-4 hours.

In [None]:
import warnings
warnings.filterwarnings("ignore")
import sys, os, glob
import torch
torch.manual_seed(123456789)
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from tqdm import tqdm_notebook
from os.path import join as oj
from pprint import pprint

#### <center>PyTorch Classes</center>

In [None]:
# PyTorch classes for loading model
class HeatmapDataset(Dataset):
    """Sliding-window dataset over a single tile of a tiled WSI.

    The tile ``tile_dir/<row>/<col>.jpg`` is stitched together with its eight
    neighbours into a 3x3 block. Neighbours that cannot be opened are left as
    ones, i.e. white in ``ToTensor``'s [0, 1] range. The block is optionally
    normalized, then cropped to the centre tile plus a ``tile_size // 2``
    margin on each side.

    Item ``index`` is the ``tile_size`` x ``tile_size`` window centred on
    pixel ``(r * stride, c * stride)`` of the centre tile, where
    ``r, c = divmod(index, img_size // stride)``. Windows are therefore
    enumerated in row-major order.
    """

    def __init__(self, tile_dir, row, col, normalize=None, stride=1,
                 img_size=1536, tile_size=256):
        """
        Args:
            tile_dir (string): path to the folder where tiles are
            row (int): row index of the tile being operated
            col (int): column index of the tile being operated
            normalize: optional color norm transform, applied to the whole
                stitched 3x3 block before cropping
            stride (int): stride of sliding in pixels (low strides lead to long
                computational times)
            img_size (int): side length of each tile image on disk
            tile_size (int): side length of each window returned by __getitem__
        """
        # tile_size and img_size should not have been changed in the pipeline;
        # the defaults match the pipeline's tile geometry.
        self.tile_size = tile_size
        self.img_size = img_size
        self.stride = stride
        # number of window positions along each side of the centre tile
        self.side = img_size // stride
        self.len = self.side ** 2
        # margin that lets windows centred on edge pixels reach into neighbours
        padding = tile_size // 2

        large_img = torch.ones(3, 3 * img_size, 3 * img_size)
        to_tensor_tf = transforms.ToTensor()
        for i in (-1, 0, 1):
            for j in (-1, 0, 1):
                img_path = os.path.join(tile_dir, str(row + i), '{}.jpg'.format(col + j))
                try:
                    # context manager closes the file handle once pixels are read
                    with Image.open(img_path) as img:
                        tile = to_tensor_tf(img)
                except OSError:
                    # Missing/unreadable neighbour (e.g. beyond the slide border):
                    # keep the ones already present in large_img.
                    continue
                large_img[:, (i + 1) * img_size:(i + 2) * img_size,
                          (j + 1) * img_size:(j + 2) * img_size] = tile

        if normalize is not None:
            large_img = normalize(large_img)

        lo, hi = img_size - padding, 2 * img_size + padding
        self.padding_img = large_img[:, lo:hi, lo:hi]

    def __getitem__(self, index):
        """Return the ``tile_size`` x ``tile_size`` window for position ``index``.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < self.len:
            raise IndexError('window index {} out of range for {} windows'.format(index, self.len))
        # divmod on the per-side count stays correct even when stride does not divide img_size
        r, c = divmod(index, self.side)
        top, left = r * self.stride, c * self.stride
        return self.padding_img[:, top:top + self.tile_size, left:left + self.tile_size]

    def __len__(self):
        return self.len


class Net(nn.Module):
    """Shell of the CNN architecture referenced by the pickled checkpoint.

    The constructor builds no layers and ignores its arguments. ``forward``
    relies on ``features`` and ``classifier`` submodules being present as
    attributes. Presumably these are restored when the whole pickled model is
    loaded with ``torch.load``, and this class exists only so the ``Net`` name
    resolves during unpickling (confirm against the training code).
    """

    def __init__(self, fc_nodes=512, num_classes=3, dropout=0.5):
        super().__init__()

    def forward(self, x):
        """Apply ``features``, flatten everything but the batch dim, then ``classifier``."""
        feature_maps = self.features(x)
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flattened)


#### <center>Function</center>

In [None]:
def _tile_grid_shape(tile_dir):
    """Return ``(n_rows, n_cols)`` of the tile grid laid out as ``tile_dir/<row>/<col>.jpg``.

    Only ``.jpg`` files below sub-directories of ``tile_dir`` are considered. The row index
    is the name of the directory holding the file. The column index is the file name up to
    its first ``.``.

    :raises ValueError: if no ``.jpg`` tiles are found
    """
    rows, cols = [], []
    for target in sorted(os.listdir(tile_dir)):
        d = os.path.join(tile_dir, target)
        if not os.path.isdir(d):
            continue
        for root, _, fnames in os.walk(d):
            for fname in fnames:
                if fname.endswith('.jpg'):
                    rows.append(int(os.path.basename(root)))
                    cols.append(int(fname.split('.')[0]))
    if not rows:
        raise ValueError('no .jpg tiles found under {!r}'.format(tile_dir))
    return max(rows) + 1, max(cols) + 1


def _predict_confidences(model, dataloader, use_gpu):
    """Run ``model`` over every batch of ``dataloader``; return sigmoid outputs as an (N, C) ndarray.

    Inference runs under ``torch.no_grad()`` so no autograd graph is kept for the batches.
    """
    batches = []
    with torch.no_grad():
        for inputs in dataloader:
            if use_gpu:
                inputs = inputs.cuda()
            # sigmoid (not softmax): an independent probability for each class
            batches.append(torch.sigmoid(model(inputs)).cpu())
    if not batches:
        return np.empty((0, 3), dtype=np.float32)
    return torch.cat(batches).numpy()


def sliding_conf_heatmap(model, tile_dir, save_path=None, **kwargs):
    """Generate confidence heatmap using a trained model and a tiled image directory.

    :param model : PyTorch model
        trained CNN model, already set to evaluation mode and placed on the device selected by
        ``kwargs['use_gpu']``
    :param tile_dir : str
        the directory containing the tiled images, laid out as ``<row>/<col>.jpg``
    :param save_path : str (default: None)
        location to save the confidence heatmap (``np.save`` appends ``.npy`` if missing).
        If None it is not saved, but still returned.
    :param kwargs : dict
        all of the following keys must be passed:
        - img_size (int): tile side length in pixels; must match the tile size used by
          HeatmapDataset, otherwise the per-tile reshape below fails
        - stride (int): sliding-window stride in pixels
        - normalize: color normalization transform, or None
        - batch_size (int), num_workers (int): DataLoader settings
        - use_gpu (bool): move each batch to CUDA before the forward pass

    :return final_output : ndarray
        array of shape (3, res * n_rows, res * n_cols) with res = img_size // stride; each
        channel is a heatmap for a class of pathologies (cored, diffuse, CAA)
    :raises ValueError: if ``tile_dir`` contains no ``.jpg`` tiles
    """
    n_rows, n_cols = _tile_grid_shape(tile_dir)
    res = kwargs['img_size'] // kwargs['stride']  # heatmap pixels per tile side
    final_output = np.zeros((3, res * n_rows, res * n_cols))

    for row in tqdm_notebook(range(n_rows)):
        for col in range(n_cols):
            # load the sliding windows for this (row, col) tile
            dataset = HeatmapDataset(tile_dir, row, col, normalize=kwargs['normalize'],
                                     stride=kwargs['stride'])
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=kwargs['batch_size'],
                                                     shuffle=False, num_workers=kwargs['num_workers'])
            # runs on CPU as well as GPU
            preds = _predict_confidences(model, dataloader, kwargs['use_gpu'])

            # windows are yielded in row-major order (shuffle=False), so each class column
            # reshapes into a (res, res) map for this tile
            final_output[:, row * res:(row + 1) * res, col * res:(col + 1) * res] = \
                preds[:, :3].T.reshape(3, res, res)

    if save_path is not None:
        np.save(save_path, final_output)
    return final_output


#### <center>Global Parameters</center>

In [None]:
# Saved color-normalization statistics; the loaded object is indexed by "mean" and "std" below.
# allow_pickle=True unpickles the file, so only point this at a trusted file.
norm_stats_path = '../modules/normalization.npy'
img_size = 1536
stride = 16  # 16 is used in paper
batch_size = 64  # vary depending on computer running code
num_workers = 16  # vary depending on computer running code

norm = np.load(norm_stats_path, allow_pickle=True).item()
normalize = transforms.Normalize(norm["mean"], norm["std"])
to_tensor = transforms.ToTensor()

# report which device the run will use
use_gpu = torch.cuda.is_available()
print("Using GPU" if use_gpu else "Using CPU")

# settings forwarded to sliding_conf_heatmap as **kwargs
kwargs = dict(
    img_size=img_size,
    stride=stride,
    batch_size=batch_size,
    num_workers=num_workers,
    normalize=normalize,
    use_gpu=use_gpu,
)


#### <center>Run</center>
This is a long-running process per image: it can take 2-4 hours depending on the original file size, the GPU used, and the system set-up (e.g., available RAM).

To run, select the norm tile directory to work on, such as norm_tiles_dataset_emory, and the location to save the heatmaps to. Heatmaps already present there with the same name are NOT overwritten; those cases are skipped. Lastly, choose the model to use for creating the heatmaps.

In [None]:
# """***Parameters***"""
# choose a directory of norm tiles to generate confidence heatmaps for
DIR = '/mnt/Data/norm_tiles/norm_tiles_dataset_3/'
SAVE_DIR = '/mnt/Data/outputs/heatmaps_tang/'
# if batch is none it will run on all norm tile dirs
# otherwise the normtile dirs will be sorted alphabetically and will be run only
# on the range of indices provided
batch = (0, None)  # (start index, end index excluded)

model_path = '../models/CNN_model_parameters.pkl'
# model_path = '../models/CNN_fresh_model_parameters.pkl'

os.makedirs(SAVE_DIR, exist_ok=True)

# instatiate the model
model = torch.load(model_path, map_location=lambda storage, loc: storage)

# modify for gpu usage
if use_gpu:
    print('using GPU')
    model = model.module.cuda()
else:
    model = model.module
    
_ = model.train(False)  # Set model to evaluate mode


"""***main code***"""
pprint(kwargs)
print()

# list the norm dirs
filenames = sorted(glob.glob(DIR + '*'))

filenames = [filename.split('/')[-1] for filename in filenames]
    
if batch is None:
    print('Running on all norm dirs')
    batch = (0, len(filenames))
else:
    print('Running on indices {} to {}'.format(batch[0], batch[1] - 1))
    
    
indices = list(range(batch[0], batch[1]))
n = len(indices)
for i in indices:
    filename = filenames[i]
    print('running index {}'.format(i))
    # does not rerun already existing heatmaps
    save_path = oj(SAVE_DIR, filename+".npy")
    if not os.path.isfile(save_path):
        print("\tanalyzing {}".format(filename))
        tile_dir = oj(DIR, '{}/0/'.format(filename))
        _ = sliding_conf_heatmap(model, tile_dir, save_path=save_path, **kwargs)
