# Segmented Images
Felipe Giuste 08-11-2020

Modified by Danni Chen 02-26-2022

In [1]:
# from IPython.core.display import display, HTML
# display(HTML("<style>.container { width:95% !important; }</style>"))

In [1]:
import numpy as np
import pandas as pd
import re

# ## PyTorch ##
import torch
from torch import nn #, optim
from torch.autograd import Variable

## TorchVision ##
import torchvision.transforms as transforms

## Utils ##
from getClassificationStats import getStats
from TrainTestSplit import ImageDataset

## Plot ##
import matplotlib.pyplot as plt


## Seed ##
# Seed every random number generator used by the notebook so that repeated
# runs produce the same results.
random_state = 1234
np.random.seed(random_state)
torch.manual_seed(random_state)
# The CUDA seeding calls are silently ignored when CUDA is not available.
torch.cuda.manual_seed(random_state)
torch.cuda.manual_seed_all(random_state)

## CUDNN ##
torch.backends.cudnn.enabled = True
# Benchmark mode auto-tunes the convolution algorithm per input shape and the
# choice may differ between runs; it is disabled here because the seeding and
# the `deterministic` flag below show that reproducibility is the goal.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## Use best Device (CUDA vs CPU) ##
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## Print Device Properties ##
# Compare on the device *type* so the check also holds for an indexed
# device such as `cuda:0`.
if device.type == 'cuda':
    print( torch.cuda.get_device_properties( device ) )

cpu


# User Variables

In [None]:
## Cell-local imports ##
from datetime import datetime
## Import CAE class ##
from utils.CAE_Segmentation_V18 import model

## Model Name: fixed prefix followed by the current timestamp ##
timeStamp = str( datetime.now() )
model_name = 'UNET_SEGMENTATION_COVID_LUNG ' + timeStamp

## Model Weights path ##
# NOTE(review): this path embeds the timestamp generated just above, so it
# presumably only resolves if a checkpoint with exactly that name exists --
# confirm, or point it at the saved checkpoint before loading weights.
model_weights_path = 'model/%s_Final.pt' % (model_name,)

## Flag read by the later cells that rescale the encoded channels ##
## (original note: "Normalize before PCA") ##
normalize = False

## Images per batch ##
batch_size = 8

## Per-magnification settings (the 20X row is the active one) ##
#   mag | image_directory                                  | metadata_path                                   | focus | grey
#   10X | /data/HeartTransplant/patches_2020-08-14_10X/   | data/08-14-2020/metadata_2020-08-14_10X.csv     | 2500  | 1000
#   20X | /data/HeartTransplant/patches_2020-08-27_20X/   | data/08-27-2020/metadata_2020-08-27_20X.csv     | 3000  | 1000
#   40X | /data/HeartTransplant/patches_2020-08-05_40X/   | data/08-05-2020/metadata_2020-08-05_40X.csv     | 3000  | 1000

## Image Directory ##
image_directory = '/data/HeartTransplant/patches_2020-08-27_20X/'

## Metadata array ##
metadata_path = 'data/08-27-2020/metadata_2020-08-27_20X.csv'

## QC Variables ##
focus_threshold = 3000
grey_threshold = 1000

## Minimum Number of Tiles Passing QC ##
nTilesThresh = 50

## Metadata dataframe: first CSV column is the index, 'Subject' kept as text ##
metadata = pd.read_csv(
    metadata_path,
    dtype={'Subject': str},
    index_col=0,
)

# Dataset

In [None]:
## WSI to Test ##
WSI_ID = 'S14-2657_2016-09-1419.43.29_he' # 1 Annotated, Scratched #4
# Other slide IDs that have been used here:
# 'CR1_018225_HE_2'
# 'CR1_164876_HE_1' # Small
# 'CR2_673196_HE_2' # Purple
# 'S13-1900_2016-09-1417.42.38_he' # Pink
# 'S14-2657_2016-09-1419.43.29_he' # 1 Annotated, Scratched #4

## Testing Dataset: Metadata ##
# Alternative selections by slide-name prefix:
# tile_df = metadata[ metadata['ID'].str.startswith('C') ] # WSI name starts with 'C'
# tile_df = metadata[ metadata['ID'].str.startswith('S') ] # WSI name starts with 'S'
tile_df = metadata[ metadata['ID'] == WSI_ID ]

## Settings shared by both datasets (only `augment` differs) ##
dataset_kwargs = dict(
    metadata=tile_df, root=image_directory,
    focus_threshold=focus_threshold,  # QC
    grey_threshold=grey_threshold,    # QC
    nTilesThresh=nTilesThresh,        # QC
)

## Settings shared by both dataloaders ##
loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=True,
    shuffle=True,  # Set:False to match Tensorboard output
    num_workers=40,
)

## Testing Dataset: Define (No Augmentation) ##
tile_dataset = ImageDataset( augment=False, **dataset_kwargs )
## Testing Dataset: Dataloader ##
tile_loader = torch.utils.data.DataLoader( tile_dataset, **loader_kwargs )

## Augmented Dataset: Define (Add Augmentation) ##
augmented_dataset = ImageDataset( augment=True, **dataset_kwargs )
## Augmented Dataset: Dataloader ##
# Fix: this loader previously wrapped `tile_dataset`, so it never served
# augmented tiles; it now wraps `augmented_dataset`.
augmented_loader = torch.utils.data.DataLoader( augmented_dataset, **loader_kwargs )


nPatches = len( augmented_dataset.metadata )
print( 'Number of Patches (QC Passed): %s'% nPatches )

In [None]:
## Tile counts: whole metadata table vs. tiles kept by the dataset ##
print( 'Tiles (Total): %s'% len(metadata) )
qc_tiles = tile_dataset.metadata
print( 'Tiles (Passed QC): %s'% len(qc_tiles) )

## Per-tile masks: Label == 0 is reported as Non-Rejection, anything else as Rejection ##
no_rejection = qc_tiles['Label'] == 0
rejection = qc_tiles['Label'] != 0

## Tile counts per class (Passed QC) ##
n_nonRejection = sum( no_rejection )
n_Rejection = sum( rejection )
print( 'Tiles Non-Rejection (Passed QC): %s'% n_nonRejection )
print( 'Tiles Rejection (Passed QC): %s'% n_Rejection )

print()
## Slide (WSI) counts per class: distinct 'ID' values among the masked tiles ##
non_rejection_ids = qc_tiles[no_rejection]['ID'].unique()
n_nonRejection = len( non_rejection_ids )
print( 'WSI Non-Rejection (Passed QC): %s'% n_nonRejection )

rejection_ids = qc_tiles[rejection]['ID'].unique()
n_Rejection = len( rejection_ids )
print( 'WSI Rejection (Passed QC): %s'% n_Rejection )

# Model: Parameters

In [None]:
## Model: Match Saved ##
print("Total GPUs: %s" %torch.cuda.device_count() )

## Read the checkpoint once ##
# `map_location=device` lets weights that were saved from a GPU be loaded on
# a CPU-only machine; without it torch.load tries to restore the tensors onto
# the device they were saved from.
state_dict = torch.load(model_weights_path, map_location=device)

## Model saved as parallel ##
# nn.DataParallel stores parameters under a 'module.' prefix, so a checkpoint
# whose first key starts with it was saved from a wrapped model.
first_key = next(iter(state_dict), '')
if first_key.startswith('module.'):
    ## Wrap so the parameter names match the checkpoint keys ##
    model = nn.DataParallel(model)

## Load Trained Weights (Note: needs to be after DataParallel) ##
model.load_state_dict(state_dict, strict=True)
# print( model.state_dict()['module.encoder.0.weight'][0] )

## Freeze Model ##
for p in model.parameters():
    p.requires_grad = False

# Model: Send to device
model.to(device);

## Do not Train ##
model.eval();

# Normalize Encoded Channels

In [None]:
## Normalize Channels (range 0-1) ##
# Collects per-channel statistics of the encoded output over every batch of
# `tile_loader`. Results left in the notebook namespace:
#   channel_means : mean of the per-batch channel means
#   channel_mins  : global minimum of each channel
#   channel_maxs  : per-channel RANGE (max - min), used later as the divisor
if(normalize):

    ## Per-batch channel statistics (list of lists) ##
    batch_means = []
    batch_maxs = []
    batch_mins = []

    ## Iterate across image batches ##
    total_batches = len(tile_loader)
    ## Inference only: no autograd graph needed ##
    with torch.no_grad():
        for batch_indx, item in enumerate(tile_loader):
            print('Train Batch: %s / %s (%5.2f)     '% (batch_indx+1, total_batches,
                                                        ((batch_indx+1)/total_batches)), end='\r')
            ## Image to device ##
            img = item['image'].to(device)
            ## Forward: the model returns a dict; only the encoding is needed ##
            # Fix: the previous tuple-unpacking of the model output would have
            # iterated over the dict's keys rather than its tensors.
            encoded = model(img)['encoded']

            ## One row per channel, all remaining dimensions flattened ##
            per_channel = encoded.transpose(0, 1).reshape(encoded.shape[1], -1)

            ## Channel Means / Maxes / Mins for batch ##
            batch_means.append( per_channel.mean(dim=1).tolist() )
            batch_maxs.append( per_channel.max(dim=1)[0].tolist() )
            batch_mins.append( per_channel.min(dim=1)[0].tolist() )

    ## Lists to Arrays (n_batches, n_channels) ##
    batch_means = np.array(batch_means)
    batch_maxs = np.array(batch_maxs)
    batch_mins = np.array(batch_mins)

    ## Reduce across batches ##
    channel_means = np.mean( batch_means, axis=0 )
    channel_maxs = np.max( batch_maxs, axis=0 )
    channel_mins = np.min( batch_mins, axis=0 )

    ## Subtract Mins from Maxes (channel_maxs now holds the range) ##
    channel_maxs = channel_maxs - channel_mins
    ## A constant channel has zero range; divide by 1 instead of 0 ##
    channel_maxs[channel_maxs == 0] = 1.0

# Model Interpretation

In [None]:
## Plot option read by the plotting cell: divide each encoded channel by its tile maximum ##
brighten = False

## Fresh iterator over the tile dataloader; the next cells pull batches from it ##
batches = iter(tile_loader)

In [None]:
### Batch y, encoded, reconstruction ###
# Draws a "quilt": one row per tile, with columns
# [original, reconstruction, encoded channel 0, encoded channel 1, ...].


def _forward_batch(batch):
    """Run one dataloader batch through the model.

    Returns the tuple (img, y, ID, encoded, reconstruction). `encoded` is
    rescaled per channel when the notebook-level `normalize` flag is set.
    """
    ## Assign item (dict) values; relies on the order of the batch dict ##
    img, y, ID = batch.values()
    ## Image to device ##
    img = img.to(device)
    ## Forward: encode->decode (inference only, no autograd graph) ##
    with torch.no_grad():
        model_output = model(img)
    encoded = model_output['encoded']
    reconstruction = model_output['reconstruction']

    ## Normalize Channels (range 0-1) ##
    if(normalize):
        ## Iterate across channels ##
        for channel in range(encoded.shape[1]):
            encoded[:, channel] = ( encoded[:, channel] - channel_mins[channel] ) / channel_maxs[channel]

    return img, y, ID, encoded, reconstruction


def _show_tile(axis, image, **imshow_kwargs):
    """Draw `image` on `axis` with a 0-1 colour range, no ticks, square aspect."""
    axis.imshow(image, vmin=0, vmax=1., **imshow_kwargs)
    axis.set_xticks([])
    axis.set_yticks([])
    axis.set_aspect('equal')


## Don't change model (set before the first forward pass) ##
model.eval()

## First Batch ##
batch = next(batches)
img, y, ID, encoded, reconstruction = _forward_batch(batch)

## Quilt rows and columns ##
nCols= 2+encoded.shape[1]
nRows=batch_size
## Spacers ##
wspace= 0.
hspace= 0.25
## Quilt width and height ##
figsize_x = 20+(wspace*nCols)
figsize_y = 20*(nRows/nCols)+(hspace*nRows)

## Plot Tiles ##
# squeeze=False keeps `ax` two-dimensional even when there is a single row,
# so ax[i, j] indexing works for batch_size == 1 as well.
fig, ax = plt.subplots(nrows=nRows, ncols=nCols, figsize=(figsize_x, figsize_y),
                       squeeze=False,
                       gridspec_kw = {'wspace':wspace, 'hspace':hspace})

## Fill the quilt row by row ##
patch_index = 0
rows_drawn = 0
for i in range(nRows):
    ## Don't overflow: current batch exhausted, fetch the next one ##
    if( patch_index > batch['image'].shape[0] -1 ):
        try:
            batch = next(batches)
        except StopIteration:
            ## No tiles left; remaining rows stay empty ##
            break
        img, y, ID, encoded, reconstruction = _forward_batch(batch)
        ## Reset Index ##
        patch_index = 0

    ## Original / Encoded / Reconstructed tile: move axis 0 to the end for imshow ##
    image_original = np.moveaxis(img[patch_index].cpu().numpy(), 0, -1)
    image_encoded = np.moveaxis(encoded[patch_index].cpu().numpy(), 0, -1)
    image_reconstruction = np.moveaxis(reconstruction[patch_index].cpu().numpy(), 0, -1)

    ## Show Original and Reconstruction ##
    _show_tile(ax[i,0], image_original)
    _show_tile(ax[i,1], image_reconstruction, cmap='Greys_r')

    ## Iterate Across Encoded Channels ##
    for channel in range(image_encoded.shape[2]):
        encoded_channel = image_encoded[:,:,channel]
        if(brighten):
            ## Scale by the tile maximum; skip when it is not positive to avoid dividing by zero ##
            peak = np.max(encoded_channel)
            if( peak > 0 ):
                encoded_channel = encoded_channel / peak
        ## Show Encoded ##
        _show_tile(ax[i,channel+2], encoded_channel, cmap='Greys_r')

        ## Column Labels (first row only) ##
        if( i == 0 ):
            ax[i,channel+2].set_title('Encoded: %s'% channel, fontsize=12, color='blue')

    ## Column Labels (first row only) ##
    # Keyed on the row, not on patch_index, so the titles are not repeated
    # in the middle of the figure when a new batch starts.
    if( i == 0 ):
        ax[i,0].set_title('Original', fontsize=12, color='blue')
        ax[i,1].set_title('Reconstruction', fontsize=12, color='blue')

    ## Image (Row) Label ##
    y_label = str(ID[patch_index])
    ax[i,0].set_ylabel('%s'% (y_label),
                       fontsize=8,
                       color='blue')

    ## Next Tile ##
    patch_index += 1
    rows_drawn += 1

## Hide rows that received no tile (dataloader ran out early) ##
for empty_row in range(rows_drawn, nRows):
    for axis in ax[empty_row]:
        axis.axis('off')

# Examine Tiles

In [None]:
%%script echo 'Skip'
# NOTE: the `%%script echo 'Skip'` cell magic hands this cell's text to the
# shell command instead of running it as Python, so the code below is
# disabled until the magic line is removed.

# Position, within the most recently processed batch, of the tile to inspect.
batch_index = 5

## Encoded Image ##
# Copy the selected tile's encoding to a NumPy array and move axis 0 to the
# end (presumably channels-first -> channels-last for plotting -- confirm).
image_encoded = encoded[batch_index].cpu().numpy()
image_encoded = np.moveaxis(image_encoded, 0, -1)

## Original Image ##
# Same conversion for the input tile.
image_original = img[batch_index].cpu().numpy()
image_original = np.moveaxis(image_original, 0, -1)

# Largest value in the slice at index 2 of the last axis (notebook cell output).
np.max(image_encoded[:,:,2] )

In [None]:
%%script echo 'Skip'
# NOTE: disabled cell -- the `%%script` magic passes the body to `echo`
# instead of executing it as Python.
## Show Original ##
# Display the selected input tile on a 0-1 range with the axis ticks removed.
plt.imshow(image_original, vmin=0, vmax=1.);
plt.xticks([]);
plt.yticks([]);

In [None]:
%%script echo 'Skip'
# NOTE: disabled cell -- the `%%script` magic passes the body to `echo`
# instead of executing it as Python.
## Show Encoded ##
# Display the first three slices along the last axis of the encoding together
# as one image (presumably rendered as RGB -- confirm), ticks removed.
plt.imshow(image_encoded[:,:,:3], vmin=0, vmax=1.);
plt.xticks([]);
plt.yticks([]);

# Histogram

In [None]:
%%script echo 'Skip'
# NOTE: disabled cell -- the `%%script` magic passes the body to `echo`
# instead of executing it as Python.
# Display the slice at index 1 of the encoding's last axis in reversed greyscale.
plt.imshow( image_encoded[:,:,1], vmin=0, vmax=1., cmap='Greys_r' );

In [None]:
%%script echo 'Skip'
# NOTE: disabled cell -- the `%%script` magic passes the body to `echo`
# instead of executing it as Python.
# Histogram (500 bins) of every value in the slice at index 3 of the
# encoding's last axis.
plt.hist( image_encoded[:,:,3].flatten(), bins=500 );

In [None]:
%%script echo 'Skip'
# NOTE: disabled cell -- the `%%script` magic passes the body to `echo`
# instead of executing it as Python.
## C2 ##
# Display the slice at index 2 of the encoding's last axis in reversed greyscale.
plt.imshow( image_encoded[:,:,2], vmin=0, vmax=1., cmap='Greys_r' );

In [None]:
%%script echo 'Skip'
# NOTE: disabled cell -- the `%%script` magic passes the body to `echo`
# instead of executing it as Python.
import cv2

## Average Mask (remove noise) ##
# Smooth the index-2 slice with a 2x2 averaging (box) filter, then display it.
image_blur = cv2.blur(image_encoded[:,:,2],(2,2))
plt.imshow( image_blur, vmin=0, vmax=1., cmap='Greys_r' );

In [None]:
%%script echo 'Skip'
# NOTE: disabled cell -- the `%%script` magic passes the body to `echo`
# instead of executing it as Python. `cv2` is imported in the preceding cell.
## Morphological clean-up of the index-2 slice ##
# 2x2 structuring element of ones.
kernel = np.ones((2,2),np.uint8)

# Apply an opening followed by a closing with that kernel (presumably to drop
# small bright specks and then fill small dark gaps -- confirm intent), then
# display the result.
image_morph = cv2.morphologyEx(image_encoded[:,:,2], cv2.MORPH_OPEN, kernel )
image_morph = cv2.morphologyEx(image_morph, cv2.MORPH_CLOSE, kernel )
plt.imshow( image_morph, vmin=0, vmax=1., cmap='Greys_r' );