# Detecting Tissue in WSIs with Semantic Segmentation

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Imports
import sys
sys.path.append('../..')

import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import cv2 as cv

from neurotk.torch.models import deeplabv3_model
from neurotk.torchvision.semantic_segmentation_transforms import (
    Normalize, ToTensor, Resize, Compose
)
from neurotk import imread

## Viewing Predictions from Trained Models

In [None]:
# Get a standard model and load pre-trained weights.
state_dict_fp = '/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detection/models/run1/best.pt'
model = deeplabv3_model()

# map_location='cpu' lets the checkpoint load even if it was saved from GPU
# tensors and no GPU is available on this machine; load_state_dict then
# copies the weights into the model's own parameters.
model.load_state_dict(torch.load(state_dict_fp, map_location='cpu'))

# Put the model in inference mode; eval() returns the module itself, so the
# result is assigned to `_` to keep the notebook cell output quiet.
_ = model.eval()

In [None]:
# Sample thumbnail and its ground-truth tissue mask (same file name in the
# images/ and masks/ folders).
# Alternative sample image:
#   '/jcDataStore/Data/nft-ai-project/wsi-inference/tissue-masks/images/1023340.png'
img_fp = ('/jcDataStore/Data/nft-ai-project/wsi-inference/tissue-masks/'
          'images/NA5023-02_AT8.png')
mask = Image.open(
    '/jcDataStore/Data/nft-ai-project/wsi-inference/tissue-masks/'
    'masks/NA5023-02_AT8.png'
)


def contours_to_points(contours):
    """Convert OpenCV contours into DSA annotation point lists.

    Each OpenCV contour is expected in the usual (num_points, 1, 2) layout
    with (x, y) ordering. Every vertex becomes ``[x, y, 0]`` with x and y
    cast to Python floats and z fixed to the integer 0, which is the point
    format DSA annotations expect.

    Args:
        contours: List of numpy arrays in OpenCV contour format.

    Returns:
        List with one entry per contour, each a list of [x, y, z] points.

    """
    return [
        [[float(vertex[0][0]), float(vertex[0][1]), 0] for vertex in contour]
        for contour in contours
    ]


def predict_mask(model, img, size=256, norm=None, thresh=0.7):
    """Predict a binary mask for an image with a segmentation model.

    The image is resized to a ``size`` x ``size`` square for inference and
    the predicted mask is resized back to the original width and height,
    so the output has the same dimensions (and aspect ratio) as the input.

    Args:
        model: Segmentation model whose forward pass returns a dict with an
            ``'out'`` entry; only channel 0 of the first batch item is used.
        img: Filepath string, numpy array, or PIL image.
        size: Side length of the square input fed to the model.
        norm: Dict with ``'mean'`` and ``'std'`` lists used to normalize the
            image. Defaults to ImageNet statistics.
        thresh: Model outputs strictly greater than this value are marked
            positive.

    Returns:
        numpy uint8 array of shape (height, width) containing 255 for
        positive pixels and 0 elsewhere.

    Raises:
        TypeError: If ``img`` is not a string, ndarray, or PIL image.

    """
    # Switch layers such as dropout / batch norm to inference behaviour;
    # gradient tracking is disabled separately below.
    model.eval()

    if isinstance(img, str):
        img = Image.open(img)
    elif isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    elif not isinstance(img, Image.Image):
        raise TypeError(
            'img must be a filepath string, ndarray, or PIL image'
        )

    if norm is None:
        # Default normalization values for ImageNet.
        norm = {
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225]
        }

    # PIL reports size as (width, height), which is also the order
    # cv.resize expects for its dsize argument.
    orig_shape = img.size

    transforms = Compose([
        ToTensor(),
        Resize((size, size)),
        Normalize(mean=norm['mean'], std=norm['std'])
    ])

    # The transforms operate on (image, mask) pairs, so pass a blank
    # placeholder mask and keep only the transformed image.
    img = transforms(img, Image.new('L', orig_shape))[0]

    # Run inference on whichever device holds the model's weights.
    device = next(model.parameters()).device

    with torch.no_grad():
        pred = model(img.unsqueeze(0).to(device))['out'][0][0]

    # Threshold the output to keep pixels which represent positives.
    mask = (pred.cpu().numpy() > thresh).astype(np.uint8) * 255

    # Resize back to the original size. `interpolation` must be passed by
    # keyword: the 5th positional argument of cv.resize is `fy`, so passing
    # INTER_NEAREST positionally left the default (linear) interpolation
    # active and produced non-binary values in the mask.
    return cv.resize(mask, orig_shape, interpolation=cv.INTER_NEAREST)


# Predict a binary tissue mask for the sample image.
img = imread(img_fp)
pred = predict_mask(model, img)

# Scale factor applied to contour coordinates before pushing them to DSA.
# NOTE(review): presumably maps thumbnail pixels to full-resolution WSI
# pixels (e.g. 40x scan magnification vs. a 0.25x thumbnail) - confirm
# against how the thumbnail was generated.
sf = 40 / .25

# Extract contours. findContours returns (contours, hierarchy) in OpenCV 4
# but (image, contours, hierarchy) in OpenCV 3; index -2 selects the
# contours list in both versions.
contours = cv.findContours(pred, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)[-2]

# Smooth the contours by simplifying each closed polygon (epsilon = 1 px).
smoothed_contours = [cv.approxPolyDP(contour, 1, True) for contour in contours]

# Convert the list of contours to points in DSA format.
tissue_points = contours_to_points(smoothed_contours)

# Build one closed-polyline DSA annotation element per polygon.
tissue_els = []

for polygon in tissue_points:
    # Skip polygons with three or fewer vertices; DSA appears to reject
    # annotations made of only three points.
    if len(polygon) < 4:
        continue

    # Scale the [x, y, z] vertices (z stays 0).
    scaled = np.array(polygon) * sf

    tissue_els.append({
        'group': 'test',
        'type': 'polyline',
        'lineColor': 'rgb(0,179,60)',
        'lineWidth': 4.0,
        'closed': True,
        'points': scaled.tolist(),
        'label': {'value': 'test'},
    })

# # Push as annotations.
# _ = gc.post(
#     f"/annotation?itemId={item['_id']}", 
#     json={
#         'name': doc_name, 
#         'description': 'Extracted from low res binary masks.', 
#         'elements': tissue_els})

## Better Grayscale Datasets
Create datasets at two sizes: 512 and 1280. Optionally make the images grayscale instead of RGB (controlled by the `grayscale` parameter). Instead of stretching the images to a square, pad them so their original aspect ratio is preserved.

In [11]:
import sys
sys.path.append('../..')

from pandas import read_csv, DataFrame
import cv2 as cv
import numpy as np
from tqdm.notebook import tqdm
from os.path import join, basename, splitext

from neurotk import imread, imwrite
from neurotk.utils import create_dirs

In [14]:
# Output locations: images/ and labels/ sub-folders under the dataset root.
save_dir = ('/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detection/datasets/'
            'tissue-dataset')
img_dir, mask_dir = join(save_dir, 'images'), join(save_dir, 'labels')
create_dirs([img_dir, mask_dir])

# Dataset parameters.
size = 512  # side length of the square output images and masks
grayscale = False  # False for RGB
pad = (255, 255, 255)  # value to pad image with
blur_kernel = (15, 15)  # size of the Gaussian blur kernel for the masks

# Filepath to the thumbnail / mask metadata CSV.
src_fp = ('/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detection/'
          'tissue-dataset.csv')
         
         
def reshape_with_pad(img, size, pad=(255, 255, 255), interpolation=None):
    """Resize an image to a square without distorting its aspect ratio.

    The shorter side is padded with a constant value first: wide images
    get rows added at the bottom, and tall images get columns added on the
    right. Square images are left unpadded. The padded square is then
    resized to ``size`` x ``size``.

    Args:
        img: numpy image array (2-D or H x W x C).
        size: Side length of the square output.
        pad: Constant border value (scalar or per-channel tuple).
        interpolation: OpenCV interpolation flag used for the resize.
            Defaults to ``cv.INTER_NEAREST``.

    Returns:
        The padded and resized image as a numpy array of shape
        (size, size[, C]).

    """
    if interpolation is None:
        interpolation = cv.INTER_NEAREST

    h, w = img.shape[:2]

    # Only one of these is non-zero: pad the bottom of wide images or the
    # right side of tall ones so the content stays in the top-left corner.
    bottom = max(w - h, 0)
    right = max(h - w, 0)
    img = cv.copyMakeBorder(img, 0, bottom, 0, right, cv.BORDER_CONSTANT,
                            value=pad)

    # `interpolation` must be passed by keyword: the 5th positional argument
    # of cv.resize is `fy`, so passing the flag positionally leaves the
    # default (linear) interpolation active.
    return cv.resize(img, (size, size), interpolation=interpolation)

In [16]:
# Track information about each image: [image filepath, mask filepath, size].
img_metadata = []

df = read_csv(src_fp)

for _, r in tqdm(df.iterrows(), total=len(df)):
    img = reshape_with_pad(imread(r.fp, grayscale=grayscale), size, pad=pad)
    # Pad masks with 0 so the added border counts as background.
    mask = reshape_with_pad(imread(r.label, grayscale=grayscale), size, pad=0)

    # Blur the mask to smooth its edges, then re-binarize to 0 / 255.
    # NOTE(review): thresholding at > 0 keeps every pixel the blur touches,
    # which also grows the mask by up to about half the kernel size; a
    # threshold near 127 would smooth without dilating - confirm intent.
    mask = cv.GaussianBlur(mask, blur_kernel, sigmaX=0, sigmaY=0)
    mask = (mask > 0).astype(np.uint8) * 255

    fn = f'{splitext(basename(r.fp))[0]}.png'
    img_fp = join(img_dir, fn)
    mask_fp = join(mask_dir, fn)

    # Save image and mask, using the same colour mode they were read with.
    imwrite(img_fp, img, grayscale=grayscale)
    imwrite(mask_fp, mask, grayscale=grayscale)

    img_metadata.append([img_fp, mask_fp, size])

# Save the metadata.
img_metadata = DataFrame(img_metadata, columns=['fp', 'label', 'size'])
img_metadata.to_csv(join(save_dir, 'dataset.csv'), index=False)
img_metadata.head()

  0%|          | 0/539 [00:00<?, ?it/s]

Unnamed: 0,fp,label,size
0,/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detec...,/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detec...,512
1,/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detec...,/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detec...,512
2,/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detec...,/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detec...,512
3,/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detec...,/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detec...,512
4,/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detec...,/jcDataStore/Data/NeuroTK-Dash/ml-tissue-detec...,512
