# WSI AutoTiler

>Converting large `.tiff` WSIs into 1024x1024 `.png` images

Gagan Daroach

In [1]:
import os
import re
import statistics
import cv2
import tensorflow as tf
import numpy as np
from sklearn.cluster import MiniBatchKMeans
%matplotlib inline
from matplotlib import pyplot as plt
from datetime import datetime

In [2]:
def load_all_img_paths(dir_to_search):
    '''
    Recursively collects the paths of all WSI files below dir_to_search.

    Every file name reported by os.walk is tested with is_wsi; each match
    is returned joined to the directory it was found in, in walk order.
    '''
    return [
        os.path.join(folder, name)
        for folder, _subfolders, names in os.walk(dir_to_search)
        for name in names
        if is_wsi(name)
    ]

In [3]:
def is_wsi(img_path):
    '''
    Checks if filename is like 145_12.tiff.

    Only the base name of img_path is inspected. It must end in
    <digits>_<digits>.tif or <digits>_<digits>.tiff. Names starting with
    '._' are always rejected.
    '''
    filename = os.path.basename(img_path)
    if filename.startswith('._'):
        return False
    # Raw string with an escaped dot so only a literal '.' is accepted, and
    # an end anchor so derived files such as '146_7.tiff_0.png' are not
    # mistaken for slides.
    wsi_tiff_regex = r'\d+_\d+\.tiff?$'
    return re.search(wsi_tiff_regex, filename) is not None

In [4]:
def create_image_mask(img, print_time=False):
    '''
    Splits the pixels of img into two clusters with MiniBatchKMeans.

    img is reshaped to one row per pixel with 3 columns, cast to uint8 and
    clustered. The per-pixel cluster labels are returned as a 2-D array with
    the same height and width as img. Which label ends up on which cluster
    is decided by the clustering, not by this function.

    When print_time is true, the seconds component of the elapsed time is
    printed.
    '''
    started_at = datetime.now()
    height, width = img.shape[0], img.shape[1]
    # One row per pixel, one column per colour channel.
    flat_pixels = img.reshape((height * width, 3)).astype('uint8')
    kmeans = MiniBatchKMeans(2, tol=0.2)
    kmeans.fit(flat_pixels)
    label_image = kmeans.labels_.reshape((height, width))
    if print_time:
        elapsed = datetime.now() - started_at
        print(elapsed.seconds)
    return label_image

In [5]:
def downsample_img(img, num_downsamples=1):
    '''
    Applies cv2.pyrDown to img num_downsamples times in a row.

    Returns the result of the last pyrDown call, or img itself when
    num_downsamples is zero or negative.
    '''
    reduced = img
    for _ in range(num_downsamples):
        reduced = cv2.pyrDown(reduced)
    return reduced

In [6]:
def most_common_pixel_in_img(img):
    '''
    Returns the value that occurs most often in img.

    img is flattened first, so every element counts regardless of shape.
    If several values share the highest count, the one that appears first
    in the flattened image wins (the same rule statistics.mode applies on
    Python 3.8+).

    Raises statistics.StatisticsError when img has no elements.
    '''
    pixels = np.asarray(img).ravel()
    if pixels.size == 0:
        raise statistics.StatisticsError('no mode for empty data')
    # Count in NumPy rather than with the pure-Python statistics.mode, which
    # is very slow on image-sized inputs.
    values, first_seen, counts = np.unique(
        pixels, return_index=True, return_counts=True)
    tied = np.flatnonzero(counts == counts.max())
    # Break ties by earliest position in the flattened image.
    winner = tied[np.argmin(first_seen[tied])]
    return values[winner]

In [7]:
def mask_crop_to_wsi_crop(mask_crop, wsi, mask_x, mask_y, x_ratio, y_ratio, target_png_shape):
    '''
    Cuts the WSI tile that corresponds to a position in the mask.

    The mask coordinates (mask_x, mask_y) are multiplied by x_ratio and
    y_ratio to get the top-left corner in wsi. From there a square of
    target_png_shape pixels per side is sliced out. Slicing clips at the
    image border, so tiles near the edge can be smaller.

    mask_crop is accepted but not used.
    '''
    top = mask_y * y_ratio
    left = mask_x * x_ratio
    rows = slice(top, top + target_png_shape)
    cols = slice(left, left + target_png_shape)
    return wsi[rows, cols]

In [8]:
def crawl(mask, wsi, target_png_shape=1024, stride=50):
    '''
    Slides a window over mask and cuts the matching tiles out of wsi.

    mask is a 2-D label image of a downsampled copy of wsi. For every
    window position the dominant label inside the window is compared with
    the dominant label of the whole mask.

    Parameters:
        mask: 2-D array of labels.
        wsi: full-resolution image array.
        target_png_shape: side length, in wsi pixels, of each tile.
        stride: step, in mask pixels, between window positions. With the
            default of 50 and a 16x downsampled mask, neighbouring
            1024-pixel tiles overlap.

    Returns a tuple (wsi_crops, wsi_crops_passed, mask_crops):
        wsi_crops: tiles whose window has the same dominant label as the
            whole mask.
        wsi_crops_passed: all other tiles.
        mask_crops: (mask window, wsi tile) pairs for every position.
    '''
    w_shape = wsi.shape
    m_shape = mask.shape
    # Round rather than truncate: a mask whose size was rounded up while
    # downsampling (as cv2.pyrDown does for odd sizes) has a ratio just
    # below the true scale, e.g. 15.99. Truncating that to 15 shifts every
    # tile away from the mask window that classified it.
    y_ratio = max(1, round(w_shape[0] / m_shape[0]))
    x_ratio = max(1, round(w_shape[1] / m_shape[1]))
    # Size of the mask window that covers one tile. Derived from the ratio
    # so it stays correct for any downsampling factor (64 for 1024 at 16x).
    mask_h = max(1, target_png_shape // y_ratio)
    mask_w = max(1, target_png_shape // x_ratio)
    mask_crops = []
    wsi_crops = []
    wsi_crops_passed = []
    # NOTE(review): the most frequent label of the whole mask is treated as
    # the label to keep; confirm that this is the tissue class.
    tissue_classification_color = most_common_pixel_in_img(mask)
    for y in range(0, m_shape[0], stride):
        for x in range(0, m_shape[1], stride):
            mask_crop = mask[y:y + mask_h, x:x + mask_w]
            wsi_crop = mask_crop_to_wsi_crop(mask_crop, wsi, x, y, x_ratio, y_ratio, target_png_shape)
            if wsi_crop.size == 0:
                # The rounded ratio can point just past the image border;
                # there is nothing to save for such a position.
                continue
            mask_crops.append((mask_crop, wsi_crop))
            if most_common_pixel_in_img(mask_crop) == tissue_classification_color:
                wsi_crops.append(wsi_crop)
            else:
                wsi_crops_passed.append(wsi_crop)
    return wsi_crops, wsi_crops_passed, mask_crops

In [9]:
def save_images(crops, img_path, output_dir):
    '''
    Writes the tiles in crops to disk as PNG files.

    Parameters:
        crops: sequence whose first item holds the matching tiles and whose
            second item holds the passed-over tiles.
        img_path: path of the source WSI; its base name prefixes every
            output file name.
        output_dir: root output directory. Matching tiles go to
            <output_dir>/match as <name>_<i>.png, passed tiles go to
            <output_dir>/pass as <name>_pass_<i>.png.

    The target directories are created when missing. One line is printed
    per tile, reporting whether the write succeeded. crops[2] is not
    written to disk.
    '''
    wsi_name = os.path.basename(img_path)
    targets = (
        (crops[0], os.path.join(output_dir, 'match'), ''),
        (crops[1], os.path.join(output_dir, 'pass'), '_pass'),
    )
    for tiles, directory, tag in targets:
        # cv2.imwrite does not create missing directories; without this
        # every write into a fresh output_dir would fail.
        os.makedirs(directory, exist_ok=True)
        for i, img in enumerate(tiles):
            filename = f'{wsi_name}{tag}_{i}.png'
            full_path = os.path.join(directory, filename)
            # imwrite reports failure through its return value.
            if cv2.imwrite(full_path, img):
                print(f'saved: {filename}')
            else:
                print(f'failed to save: {filename}')

In [10]:
# Directory that is searched recursively for input WSI .tiff files.
input_tiff_dir = '/srv/tank/mcw/Prostates'
# Root directory under which the 'match' and 'pass' tiles are written.
output_dir ='/srv/tank/mcw/autotiler'

In [11]:
def main():
    '''
    Tiles every WSI found under input_tiff_dir into output_dir.

    For each slide: read it with cv2, downsample it 4 times, cluster the
    downsampled image into a two-label mask, crawl the mask to cut tiles
    from the full-resolution slide, and save the tiles.

    Returns without processing when no slide is found. Slides that cv2
    cannot read are reported and skipped.
    '''
    wsi_paths = load_all_img_paths(input_tiff_dir)
    print('Loaded These tiffs')
    for x in wsi_paths:
        print(f'{os.path.basename(x)}')
    if not wsi_paths:
        # Nothing to do; indexing wsi_paths[0] below would raise IndexError.
        return
    print(wsi_paths[0])
    for img_path in wsi_paths:
        print(f'starting on {img_path}')
        wsi_img = cv2.imread(img_path)
        if wsi_img is None:
            # cv2.imread returns None instead of raising when the file
            # cannot be read or decoded.
            print(f'could not read {img_path}, skipping')
            continue
        downscaled_4x = downsample_img(wsi_img, 4)
        mask_4x = create_image_mask(downscaled_4x)
        crops = crawl(mask_4x, wsi_img)
        save_images(crops, img_path, output_dir)

In [None]:
# Run the pipeline only when executed directly (a notebook cell or a script),
# not when this code is imported as a module.
if __name__ == '__main__':
    main()

Loaded These tiffs
146_7.tiff
216_11.tiff
149_8.tiff
209_14.tiff
150_6.tiff
152_8.tiff
145_11.tiff
163_8.tiff
179_8.tiff
181_9.tiff
151_8.tiff
164_8.tiff
148_9.tiff
215_10.tiff
211_13.tiff
212_14.tiff
147_10.tiff
210_2.tiff
180_11.tiff
214_14.tiff
/srv/tank/mcw/Prostates/146_7/146_7.tiff
starting on /srv/tank/mcw/Prostates/146_7/146_7.tiff
saved: 146_7.tiff_0.png
saved: 146_7.tiff_1.png
saved: 146_7.tiff_2.png
