In [1]:
import SimpleITK as sitk
from matplotlib import pyplot as plt
import os
from skimage import io, transform, util, img_as_float
from skimage.filters import gaussian
from skimage.color import separate_stains, hax_from_rgb, rgb2gray, gray2rgb
from skimage.exposure import rescale_intensity
from skimage.util import img_as_ubyte, crop
from skimage.draw import rectangle
from skimage.morphology import remove_small_objects, remove_small_holes

import math
import numpy as np
import tifffile
import cv2
import glob
import numpy as np
import time
import re

from joblib import Parallel, delayed
from itertools import groupby
from tqdm import tqdm
import pandas as pd

# Notebook-wide matplotlib defaults: a 100 x 100 figure size and a grayscale
# colormap for every image shown with plt.imshow.
plt.rcParams.update({
    'figure.figsize': (100, 100),
    'image.cmap': 'gray',
})

In [2]:
# Pipeline configuration read by the rest of the notebook.

# Root folder that receives every intermediate and final output.
output_dir = '/home/rohit/Documents/Balloon Cell Tiles'
# Folder that is walked recursively to find the raw slide files.
raw_img_dir = '/home/rohit/Documents/Slides'
# Stain names in processing order; a stain's rank is its position in this list + 1.
stain_order = ['HE', 'TB']
# Slide file names to process (one entry per patient/stain combination).
pat_names = [
    'N19-783-01_TB.scn',
    'N19-847-01_TB.scn',
    'N19-783-01_HE.scn',
    'N19-847-01_HE.scn',
]

In [3]:
# Folders used by the tiling / registration stage (all below output_dir).
input_folder = '{}/BaseAligned'.format(output_dir)
output_folder = '{}/TiledOverlap'.format(output_dir)
reg_folder = 'Registered'

# Tile size in pixels: Default (2048 x 2048), Training (686 x 666)
tile_x = tile_y = 2048
# Overlap added around each tile: Default (512 x 512), Training (0 x 0)
overlap_x = overlap_y = 512


# Define Functions #


In [4]:
def myshow(image):
    '''
    Display a SimpleITK image with matplotlib.

    :param image: SimpleITK image; it is converted to a NumPy array before plotting
    :return: None
    '''
    pixels = sitk.GetArrayFromImage(image)
    plt.imshow(pixels)


In [5]:
def closest(lst, K):
    '''
    Return the index of the element of lst that is numerically nearest to K.

    :param lst: sequence of numbers
    :param K: target value
    :return: index of the first element with the smallest absolute difference to K
    :raises ValueError: if lst is empty
    '''
    distances = [abs(value - K) for value in lst]
    # list.index returns the first position, so ties resolve to the earliest element.
    return distances.index(min(distances))

''' 
ORIGINAL GOALRES FUNCTION

def goalRes(file):
    l_res = []
    for page in tifffile.TiffFile(file).pages:
        l_res.append(page.shape[0])
    return closest(l_res, 32000)
'''

def goalRes(file, target_resolution=237000):
    '''
    Pick the TIFF page whose XResolution entry is closest to a target value.

    :param file: path to a multi-page TIFF-like file readable by tifffile
    :param target_resolution: value the second element of each page's
        XResolution tag is compared against (default 237000, the value the
        pipeline was tuned with)
    :return: index of the selected page, suitable as the ``key`` argument of
        tifffile.imread
    '''
    resolutions = []
    # Use the context manager so the file handle is released once the tags are read.
    with tifffile.TiffFile(file) as tif:
        for page in tif.pages:
            try:
                resolutions.append(page.tags['XResolution'].value[1])
            except Exception:
                # Pages without a usable XResolution tag get 0 so that page
                # indices stay aligned with positions in the list.
                resolutions.append(0)

    return closest(resolutions, target_resolution)


In [6]:
def myReg(fixed, moving):
    '''
    Register moving onto fixed with Elastix using the default 'translation' parameter map.

    :param fixed: SimpleITK image used as the registration reference
    :param moving: SimpleITK image that is registered to the reference
    :return: (result image, transform parameter map) as produced by the Elastix filter
    '''
    elastix = sitk.ElastixImageFilter()
    elastix.SetFixedImage(fixed)
    elastix.SetMovingImage(moving)
    elastix.SetParameterMap(sitk.GetDefaultParameterMap('translation'))
    elastix.Execute()
    return elastix.GetResultImage(), elastix.GetTransformParameterMap()

In [7]:
# Since we are downscaling even more, change all factors of 2 to 4 (we downscale twice)

''' 
def upscaleMatrix(matrix):
    #matrix[0]['CenterOfRotationPoint'] = [str(float(matrix[0]['CenterOfRotationPoint'][0])*2), str(float(matrix[0]['CenterOfRotationPoint'][1])*2)]
    matrix[0]['Size'] = [str(float(matrix[0]['Size'][0])*2), str(float(matrix[0]['Size'][1])*2)]
    matrix[0]['TransformParameters'] = [str(float(matrix[0]['TransformParameters'][0])*2), str(float(matrix[0]['TransformParameters'][1])*2)]
    matrix[0]['FinalBSplineInterpolationOrder'] = ['1']
'''
    
def upscaleMatrix(matrix, factor=4):
    '''
    Rescale, in place, a transform parameter map estimated on a downscaled image.

    The first two entries of 'Size' and 'TransformParameters' in matrix[0] are
    multiplied by factor and stored back as strings, and
    'FinalBSplineInterpolationOrder' is set to ['1'].

    :param matrix: indexable collection whose first element is a mapping with
        'Size' and 'TransformParameters' entries holding numeric strings
    :param factor: scale between the downscaled image and the full-resolution
        image; the default 4 corresponds to the two successive pyrDown calls
        used to build the downscaled images
    :return: None (matrix[0] is modified in place)
    '''
    # NOTE: rescaling of 'CenterOfRotationPoint' was disabled in the original
    # code and is deliberately still not performed here.
    for key in ('Size', 'TransformParameters'):
        values = matrix[0][key]
        matrix[0][key] = [str(float(values[0]) * factor), str(float(values[1]) * factor)]
    matrix[0]['FinalBSplineInterpolationOrder'] = ['1']

    
def splitChans(path):
    '''
    Load an image and return its channels as a list of 2-D arrays.

    :param path: path of an image with a third (channel) axis
    :return: list with one [rows, cols] array per index of axis 2
    '''
    # Read the file once; the previous version re-read it from disk for the
    # shape and again for every channel.
    image = io.imread(path)
    return [image[:, :, channel] for channel in range(image.shape[2])]
    

In [8]:
def deconv(image):
    '''
    Extract the first stain channel of a colour-deconvolved image.

    :param image: either a path to an image file or an already loaded RGB array
    :return: channel 0 of separate_stains(..., hax_from_rgb), clipped to
        [0, 0.1] and rescaled with out_range='uint8'
    '''
    # Decide explicitly between "path" and "array" instead of relying on a
    # bare except around io.imread, which also hid genuine processing errors.
    if isinstance(image, (str, bytes, os.PathLike)):
        image = io.imread(image)

    hax = separate_stains(image, hax_from_rgb)
    hema = np.clip(hax[:, :, 0], 0, 0.1)
    return rescale_intensity(hema, out_range='uint8')

In [9]:
def tileSave(file):
    '''
    Cut an aligned slide image into overlapping tiles and save them to disk.

    Side effects:
      * creates '<output_folder>/<name>' (name = file name without its last 5
        characters, i.e. the '.tiff' suffix) and makes it the working directory
      * writes '<output_folder>/<name[15:]>_tilemap.tiff', a quarter-resolution
        overview with the tile grid and (row, col) labels drawn on it
      * writes one '<name[15:]>_r<row>_c<col>.tiff' per tile plus 'savelog.txt'
        into the tile folder

    Note on axes: tile_x / overlap_x are applied along image rows (axis 0) and
    tile_y / overlap_y along columns when counting and cropping tiles, whereas
    the overview drawing uses tile_x for column positions and tile_y for row
    positions. This is harmless while tile_x == tile_y.

    :param file: path of the aligned slide image (3-channel)
    :return: None
    '''
    image = io.imread(file)
    tile_rows_overlap = list(range(int(image.shape[0] / tile_x)))
    tile_cols_overlap = list(range(int(image.shape[1] / tile_y)))
    combined = [(f, s) for f in tile_rows_overlap for s in tile_cols_overlap]

    filename = os.path.basename(file)[:-5]
    tile_dir = '{}/{}'.format(output_folder, filename)
    # exist_ok replaces the former bare try/except around makedirs + chdir.
    os.makedirs(tile_dir, exist_ok=True)
    os.chdir(tile_dir)

    # The context manager guarantees the log is closed even if saving fails.
    with open('savelog.txt', 'w+') as log:

        # Create low res slide image with tile rows and columns labeled
        imgScale = 0.25   # scale factor for low res slide image
        newX, newY = int(image.shape[1] * imgScale), int(image.shape[0] * imgScale)
        lowres = cv2.resize(image, (newX, newY))

        grid_color = (255, 0, 0)
        for i in tile_cols_overlap[:-1]:
            col = int((tile_x * imgScale) + i * (tile_x * imgScale))
            rr, cc = rectangle((0, col), (newY - 1, col))
            lowres[rr, cc] = grid_color
        for j in tile_rows_overlap[:-1]:
            row = int((tile_y * imgScale) + j * (tile_y * imgScale))
            rr, cc = rectangle((row, 0), (row, newX - 1))
            lowres[rr, cc] = grid_color
        for spot in combined:
            label_org = (int(spot[1] * tile_x * imgScale + (tile_x * imgScale) / 2),
                         int(spot[0] * tile_y * imgScale + (tile_y * imgScale) / 2))
            cv2.putText(lowres, text='({}, {})'.format(spot[0], spot[1]), org=label_org,
                        fontFace=3, fontScale=1, color=(0, 0, 255), thickness=2)

        io.imsave('{}/{}_tilemap.tiff'.format(output_folder, filename[15:]), lowres)

        # Save cropped tiles
        for i, j in combined:
            # Amount to crop from each border so that the tile plus its
            # overlap margin remains; clamped at 0 on the image edges.
            row_top_offset = max(0, i * tile_x - overlap_x)
            row_bot_offset = max(0, image.shape[0] - (tile_x * (i + 1) + overlap_x))
            col_l_offset = max(0, j * tile_y - overlap_y)
            col_r_offset = max(0, image.shape[1] - (tile_y * (j + 1) + overlap_y))

            croppedimg = crop(image, ((row_top_offset, row_bot_offset), (col_l_offset, col_r_offset), (0, 0)), copy=False)
            # tissue detection block: pure black pixels are turned white
            croppedimg[np.where((croppedimg == [0, 0, 0]).all(axis=2))] = [255, 255, 255]

            # NOTE: a filter that only saved tiles with np.mean(croppedimg) < 250
            # was disabled in the original code (its note mentions keeping the
            # HE and TB tile sets matched); every tile is saved.
            io.imsave('{}_r{}_c{}.tiff'.format(filename[15:], i, j), croppedimg)
            # The log line keeps its historical form 'saved _<name>_r<row>_c<col>.tiff'.
            log.write('saved _{}_r{}_c{}.tiff\n'.format(filename, i, j))

In [10]:
def base_align(down):
    '''
    Align the full-resolution slides of one patient to the first slide in the list.

    The translation is estimated on grayscale versions of the downscaled
    images, rescaled with upscaleMatrix and applied channel by channel to the
    matching base image found in the global l_base. Each result is written to
    '<output_dir>/BaseAligned/Aligned_<base file name>'. The base image that
    matches down[0] is written to the same folder unchanged.

    :param down: list of downscaled image paths; down[0] is the fixed reference
    :return: None
    '''
    reference_gray = rgb2gray(io.imread(down[0]))
    fixedDown = sitk.GetImageFromArray(reference_gray)

    for moving_path in down[1:]:
        moving_gray = rgb2gray(io.imread(moving_path))
        movingDown = sitk.GetImageFromArray(moving_gray)

        # Only the transform is needed; the registered low-res image is discarded.
        _, transform_map = myReg(fixedDown, movingDown)
        upscaleMatrix(transform_map)

        transformix = sitk.TransformixImageFilter()
        transformix.SetTransformParameterMap(transform_map)

        # Patient and stain are the 4th and 5th '_'-separated fields of the file name.
        name_fields = os.path.basename(moving_path[:-5]).split('_')
        _, _, _, pat, stain, _ = name_fields
        match = next(y[2] for y in l_base if y[0] == pat and y[1] == stain)

        out_chans = []
        for channel in splitChans(match):
            transformix.SetMovingImage(sitk.GetImageFromArray(channel))
            transformix.Execute()
            moved = sitk.Cast(transformix.GetResultImage(), sitk.sitkUInt8)
            out_chans.append(sitk.GetArrayFromImage(moved))
        combined = np.dstack((out_chans[0], out_chans[1], out_chans[2]))
        tifffile.imwrite(output_dir + '/BaseAligned/Aligned_{}'.format(os.path.basename(match)), combined)

    # The reference slide is copied over as-is; [5:] drops the first 5
    # characters of the downscaled file name to obtain the output name.
    reference_name = os.path.basename(down[0])
    _, _, _, pat0, stain0, _ = reference_name[:-5].split('_')
    orig_pd1 = next(z[2] for z in l_base if z[0] == pat0 and z[1] == stain0)
    tifffile.imwrite(output_dir + '/BaseAligned/Aligned_{}'.format(reference_name[5:]), io.imread(orig_pd1))


In [11]:
def match_keypoints(moving, target, feature_detector):
    '''
    Detect, match and RANSAC-filter keypoints between two images.

    :param moving: image that is to be warped to align with target image
    :param target: image to which the moving image will be aligned
    :param feature_detector: a feature detector from opencv
    :return: (moving points, target points, mean descriptor distance over all
        matches, number of matches kept by RANSAC); the points are float32
        arrays of the matches flagged as inliers
    '''
    kp_moving, desc_moving = feature_detector.detectAndCompute(moving, None)
    kp_target, desc_target = feature_detector.detectAndCompute(target, None)

    matcher = cv2.BFMatcher(normType=cv2.NORM_L2, crossCheck=True)
    ranked = sorted(matcher.match(desc_moving, desc_target), key=lambda m: m.distance)

    # Mean distance over every match; 0 when there are no matches at all.
    if ranked:
        averageDistance = sum(m.distance for m in ranked) / len(ranked)
    else:
        averageDistance = 0

    # Only the 200 best matches take part in the homography estimate.
    best = ranked[:200]
    src_points = np.float32([kp_moving[m.queryIdx].pt for m in best])
    dst_points = np.float32([kp_target[m.trainIdx].pt for m in best])

    # orig = 7 reprojthreshold could be tuned for best performance
    H, mask = cv2.findHomography(src_points, dst_points, cv2.RANSAC, ransacReprojThreshold=3)

    good = [ranked[i] for i in range(len(mask)) if mask[i] == [1]]
    num_good = len(good)

    filtered_src_points = np.float32([kp_moving[m.queryIdx].pt for m in good])
    filtered_dst_points = np.float32([kp_target[m.trainIdx].pt for m in good])

    return filtered_src_points, filtered_dst_points, averageDistance, num_good

def apply_transform(moving, target, moving_pts, target_pts, transformer, output_shape_rc=None):
    '''
    Fit transformer to the matched points and warp the moving image with it.

    :param moving: image to warp
    :param target: reference image; only its shape is used
    :param moving_pts: matched points in the moving image
    :param target_pts: matched points in the target image
    :param transformer: transformer object from skimage. See https://scikit-image.org/docs/dev/api/skimage.transform.html for different transformations
    :param output_shape_rc: shape of warped image (row, col). If None, uses shape of target image
    :return: (warped image, moving_pts mapped through the fitted transformer)
    '''
    if output_shape_rc is None:
        output_shape_rc = target.shape[:2]

    # NOTE(review): the class is recognised by its exact repr string. If the
    # installed skimage reports a different module path for EuclideanTransform,
    # this comparison is never true and the generic path is always used - confirm.
    is_euclidean = str(transformer.__class__) == "<class 'skimage.transform.EuclideanTransform'>"

    if not is_euclidean:
        transformer.estimate(moving_pts, target_pts)
        warped_img = transform.warp(moving, transformer.inverse, output_shape=output_shape_rc)
        return warped_img, transformer(moving_pts)

    # Fit target -> moving so the transformer can be handed to warp directly.
    transformer.estimate(target_pts, moving_pts)
    warped_img = transform.warp(moving, transformer, output_shape=output_shape_rc)

    # Re-estimate in the forward direction to warp the points.
    transformer.estimate(moving_pts, target_pts)
    return warped_img, transformer(moving_pts)

def keypoint_distance(moving_pts, target_pts, img_h, img_w):
    '''
    Mean point-to-point distance, normalised by the image diagonal.

    :param moving_pts: array of points, one point per row
    :param target_pts: array of points with the same shape as moving_pts
    :param img_h: image height
    :param img_w: image width
    :return: mean Euclidean distance between corresponding rows divided by
        sqrt(img_h**2 + img_w**2)
    '''
    offsets = moving_pts - target_pts
    per_point = np.sqrt(np.sum(offsets ** 2, axis=1))
    diagonal = np.sqrt(img_h ** 2 + img_w ** 2)
    return np.mean(per_point / diagonal)


def regAll(path):
    '''
    Register every tile of one field to the first tile listed for that field.

    The tile list is all_paths[path]; its first entry is the fixed image.
    For each tile (the first included) keypoints are matched on blurred
    grayscale versions, a EuclideanTransform is fitted and the warped tile is
    saved as '<index>_<tile name>_reg.tiff' in
    '<output_folder>/<reg_folder>/<basename of path>', which also becomes the
    working directory. Match statistics go to 'distance.txt' in that folder;
    on the first matching or warping failure a 'fail on...' entry is written
    and the remaining tiles are skipped.

    :param path: key into the global all_paths dictionary
    :return: None
    '''
    tile_dir = '{}/{}/{}'.format(output_folder, reg_folder, os.path.basename(path))
    # exist_ok replaces the former bare try/except around makedirs + chdir.
    os.makedirs(tile_dir, exist_ok=True)
    os.chdir(tile_dir)

    # The context manager closes the log on every exit path, including
    # unexpected errors while reading or saving images.
    with open('{}/distance.txt'.format(tile_dir), 'w+') as f:
        tiles = all_paths[path]
        if tiles:
            # The fixed image, its blurred grayscale version and the detector
            # do not change between iterations, so they are prepared once.
            original_target = io.imread(tiles[0])
            target = img_as_ubyte(gaussian(rgb2gray(original_target), 3))    # alpha could be tuned for best performance
            fd = cv2.KAZE_create(extended=True)

        for i, tile_path in enumerate(tiles):
            original_moving = io.imread(tile_path)
            moving = img_as_ubyte(gaussian(rgb2gray(original_moving), 3))

            try:
                moving_pts, target_pts, averageDistance, num_good = match_keypoints(moving, target, feature_detector=fd)
            except Exception:
                # Best effort: record the failure and give up on this field.
                f.write('fail on' + str(i) + 'match_keypoints')
                break
            f.write('{}_Average Distance {} = {}________Number of Good Matches = {}\n'.format(str(i), str(averageDistance), tile_path, str(num_good)))

            transformer = transform.EuclideanTransform()
            try:
                warped_img, warped_pts = apply_transform(original_moving, original_target, moving_pts, target_pts, transformer=transformer)
            except Exception:
                f.write('fail on' + str(i) + 'apply_transform')
                break

            warped_img = img_as_ubyte(warped_img)
            io.imsave(str(i) + '_' + (os.path.basename(tile_path[:-5] + '_reg.tiff')), warped_img)
    
    '''
    for i, k in enumerate(all_paths[path]):
        # block for registering image to image 1 (PD-1)  - rationale: hematoxylin staining often looks qualitatively very different than all other Hema stains combined with AEC
        original_target = io.imread(all_paths[path][0])
    
        original_moving = io.imread(all_paths[path][i])

        target_file = deconv(all_paths[path][0])
        moving_file = deconv(all_paths[path][i])

        target = img_as_ubyte(gaussian(target_file, 3))    #alpha could be tuned for best performace 
        moving = img_as_ubyte(gaussian(moving_file, 3))


        fd = cv2.KAZE_create(extended=True)
        try:
            moving_pts, target_pts, averageDistance, num_good = match_keypoints(moving, target, feature_detector=fd)
        except:
            f.write('fail on' + str(i) + 'match_keypoints')
            f.close()
            break
        f.write('{}_Average Distance {} = {}________Number of Good Matches = {}\n'.format(str(i),str(averageDistance), all_paths[path][i],str(num_good)))

        transformer = transform.EuclideanTransform()
        try:
            warped_img, warped_pts = apply_transform(original_moving, original_target, moving_pts, target_pts, transformer=transformer)
        except:
            f.write('fail on' + str(i) + 'apply_transform')
            f.close()
            break

        warped_img = img_as_ubyte(warped_img)

        io.imsave(str(i) + '_' + (os.path.basename(all_paths[path][i][:-5]+ '_reg.tiff')), warped_img)

        if num_good < 10 or averageDistance > 0.5:                     # both number of good matches needed and average distance can be tuned for best performance
            try:
                original_target = img_as_ubyte(io.imread(str(i-1)+'_i-1_'+ (os.path.basename(all_paths[path][i-1][:-5]+ '_reg.tiff'))))
            except:
                original_target = img_as_ubyte(io.imread(str(i-1)+'_' + (os.path.basename(all_paths[path][i-1][:-5]+ '_reg.tiff'))))
            original_moving = img_as_ubyte(io.imread(all_paths[path][i]))

            try:
                target_file = deconv(str(i-1)+'_i-1_' + (os.path.basename(all_paths[path][i-1][:-5]+ '_reg.tiff')))
            except:
                target_file = deconv(str(i-1)+'_' + (os.path.basename(all_paths[path][i-1][:-5]+ '_reg.tiff')))
            moving_file = deconv(all_paths[path][i])

            target = img_as_ubyte(gaussian(target_file, 6))    #alpha could be tuned for best performace 
            moving = img_as_ubyte(gaussian(moving_file, 6))


            fd = cv2.KAZE_create(extended=True)
            try:
                moving_pts, target_pts, averageDistance, num_good = match_keypoints(moving, target, feature_detector=fd)
            except:
                f.write('fail on' + str(i) + '-1 match_keypoints')
                f.close()
                break
            f.write('{}_i-1_Average Distance {} = {}________Number of Good Matches = {}\n'.format(str(i),str(averageDistance), all_paths[path][i],str(num_good)))

            transformer = transform.EuclideanTransform()
            try:
                warped_img, warped_pts = apply_transform(original_moving, original_target, moving_pts, target_pts, transformer=transformer)
            except:
                f.write('fail on' + str(i) + '-1 apply_transform')
                f.close()
                break
            warped_img = img_as_ubyte(warped_img)
            io.imsave(str(i)+'_i-1_' + (os.path.basename(all_paths[path][i][:-5]+ '_reg.tiff')), warped_img)
            if num_good<10 or averageDistance > 0.5:
                break
            else:
                continue
            break
    '''    

In [12]:
def combineMask(folder):
    '''
    Build a combined background mask for the registered tiles of one field.

    All '.tiff' files below folder (except those in a 'mask' sub-folder) are
    grouped by the first character of their file name; from each group of one
    or two files the last one is kept. Only when exactly two images remain
    are outputs produced, all inside '<folder>/mask':
      * '<name>_CombinedMask.tiff' - cleaned union of the per-image masks
      * '<name>_stack.tiff'        - the images stacked, masked areas set to 250
      * '<name>_demo.tiff'         - the first image with the mask applied
    where <name> is the basename of folder without its last 5 characters.

    Any exception is caught and appended as one line to mask_errors.txt. The
    previous version truncated that file on every call, so with many folders
    only the most recent failures survived. It also left the handle open
    whenever no error occurred.

    :param folder: folder holding the registered tiles of one field
    :return: None
    '''
    error_log = '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/mask_errors.txt'
    try:
        l_regfiles = []
        for root, dirs, files in os.walk(folder):
            if os.path.basename(root) == 'mask':
                continue
            for file in files:
                if file.endswith('.tiff'):
                    l_regfiles.append(os.path.join(root, file))
        l_regfiles = sorted(l_regfiles)

        # group files by first character (original and i-1 image will both have same number preceding)
        l_reg_sort = [list(g) for k, g in groupby(l_regfiles, key=lambda x: os.path.basename(x)[0])]

        # if i-1 image exists, take it if not take the regular registered image
        l_reg_filter = []
        for h in l_reg_sort:
            if len(h) == 1:
                l_reg_filter.append(h[0])
            if len(h) == 2:
                l_reg_filter.append(h[1])
        l_reg_filter = sorted(l_reg_filter)

        # create empty mask with the same row/column shape as the images
        first_image = io.imread(l_reg_filter[0])
        combined_mask = np.zeros((first_image.shape[0], first_image.shape[1]), dtype=bool)

        # Masks are only produced for complete pairs; anything else is skipped silently.
        if len(l_reg_filter) != 2:
            return

        thresh = 0.97    # Try range of values to determine ideal threshold
        for k in l_reg_filter:
            img = io.imread(k)
            img[np.all(img == (0, 0, 0), axis=-1)] = (255, 255, 255)  # set black registration gaps to white
            gaus = gaussian(img, 6)  # apply gaussian blur to image, returns float
            # masks any image area where channel 1 is 0 or any of the 3 channels is above threshold
            mask_gaus = (gaus[:, :, 1] == 0) | ((gaus[:, :, 0] > thresh) | (gaus[:, :, 1] > thresh) | (gaus[:, :, 2] > thresh))
            combined_mask = combined_mask | mask_gaus

        os.makedirs(folder + '/mask', exist_ok=True)
        out_name = os.path.basename(folder[:-5])

        clean_mask = remove_small_objects(combined_mask, min_size=2000)
        clean_mask = remove_small_holes(clean_mask, area_threshold=2000)
        io.imsave(folder + '/mask/{}_CombinedMask.tiff'.format(out_name), img_as_ubyte(clean_mask))   # save combined mask file

        arrays = []
        for y in l_reg_filter:
            img = io.imread(y)
            img[np.all(img == (0, 0, 0), axis=-1)] = (250, 250, 250)
            img[clean_mask] = 250
            arrays.append(img)
        stack = np.stack(arrays, axis=0)
        io.imsave(folder + '/mask/{}_stack.tiff'.format(out_name), stack)

        # save an example of the mask on image (first_image is not used again, so it is modified directly)
        demo = first_image
        gaps = demo[:, :, 1] == 0
        demo[gaps] = 250
        demo[clean_mask] = 250
        io.imsave(folder + '/mask/{}_demo.tiff'.format(out_name), demo)
    except Exception as exc:
        # Append so that failures from other folders are kept.
        with open(error_log, 'a') as log:
            log.write('fail on ' + folder + ': ' + repr(exc) + '\n')


#         if len(l_reg_filter) == 8:
#             os.rename(folder, os.path.dirname(folder) + '/Masked_' + os.path.basename(folder))
                

# Base Alignment of Full Slide Images #

In [13]:
# create subfolders in output folder
# Sub-folders of output_dir that the base-alignment stage writes into.
for subfolder in ('/Unaligned/BaseImages', '/Unaligned/DownImages', '/BaseAligned'):
    os.makedirs(output_dir + subfolder, exist_ok=True)


In [14]:
# For every entry of pat_names, collect the paths of all files with exactly
# that name below raw_img_dir. The stain is the name of the file's parent
# folder; paths are ordered by the stain's rank in stain_order.
all_slides = []
for slide_file in pat_names:
    ranked_paths = []
    for root, dirs, files in os.walk(raw_img_dir):
        for candidate in files:
            if candidate != slide_file:
                continue
            slide_path = os.path.join(root, candidate)
            stain = os.path.basename(os.path.dirname(slide_path))
            ranked_paths.append((stain_order.index(stain) + 1, slide_path))

    all_slides.append([entry[1] for entry in sorted(ranked_paths)])

print(all_slides)

[['/home/rohit/Documents/Slides/TB/N19-783-01_TB.scn'], ['/home/rohit/Documents/Slides/TB/N19-847-01_TB.scn'], ['/home/rohit/Documents/Slides/HE/N19-783-01_HE.scn'], ['/home/rohit/Documents/Slides/HE/N19-847-01_HE.scn']]


In [15]:
# Open each raw slide at the page chosen by goalRes, trim it symmetrically so
# both dimensions become multiples of 2048 pixels, and save it as a base image.
for slide_group in all_slides:
    for index, slide in enumerate(slide_group):
        # watch out for if within an image set, output shapes change. consider setting the
        # crop_row, crop_col variables based on image1 and applying to all remaining images
        img = tifffile.imread(slide, key=goalRes(slide))

        # Surplus beyond the last full multiple of 2048; Default (2048)
        crop_row = img.shape[0] % 2048
        crop_col = img.shape[1] % 2048
        # Split the surplus between both borders; an odd pixel goes to the far border.
        top, left = crop_row // 2, crop_col // 2
        bottom, right = crop_row - top, crop_col - left
        out_image = crop(img, ((top, bottom), (left, right), (0, 0)), copy=False)

        base_name = '{}_Base_{}_{}.tiff'.format(str(index), os.path.basename(slide[:-4]), os.path.basename(os.path.dirname(slide)))
        tifffile.imwrite(output_dir + '/Unaligned/BaseImages/' + base_name, out_image, photometric='minisblack')
        print(out_image.shape)

(61440, 36864, 3)
(49152, 36864, 3)
(61440, 36864, 3)
(49152, 36864, 3)


In [16]:
# Index the base images written above as (patient, stain, path) tuples,
# sorted so the order is deterministic.
l_base = []
for base_file in os.listdir(output_dir + '/Unaligned/BaseImages'):
    # File names are '<index>_Base_<patient>_<stain>_<stain folder>.tiff'.
    name_fields = os.path.basename(base_file)[:-5].split('_')
    _, _, pat, stain, _ = name_fields
    l_base.append((pat, stain, os.path.join(output_dir + '/Unaligned/BaseImages/', base_file)))
l_base.sort()

print(l_base)


[('N19-783-01', 'HE', '/home/rohit/Documents/Balloon Cell Tiles/Unaligned/BaseImages/0_Base_N19-783-01_HE_HE.tiff'), ('N19-783-01', 'TB', '/home/rohit/Documents/Balloon Cell Tiles/Unaligned/BaseImages/0_Base_N19-783-01_TB_TB.tiff'), ('N19-847-01', 'HE', '/home/rohit/Documents/Balloon Cell Tiles/Unaligned/BaseImages/0_Base_N19-847-01_HE_HE.tiff'), ('N19-847-01', 'TB', '/home/rohit/Documents/Balloon Cell Tiles/Unaligned/BaseImages/0_Base_N19-847-01_TB_TB.tiff')]


In [17]:
# Create downscaled copies of every base image for easier alignment:
# cv2.pyrDown is applied twice in a row.
for base_entry in l_base:
    base_path = base_entry[2]
    first_pass = cv2.pyrDown(io.imread(base_path))
    down_img = cv2.pyrDown(first_pass)
    tifffile.imwrite(output_dir + '/Unaligned/DownImages/Down_{}.tiff'.format(os.path.basename(base_path[:-5])), down_img, photometric='minisblack')
    print(down_img.shape)

(15360, 9216, 3)
(15360, 9216, 3)
(12288, 9216, 3)
(12288, 9216, 3)


In [18]:
# Build one list of downscaled image paths per patient, ordered by stain_order.
# pat_names holds one entry per patient AND stain, so each patient id shows up
# several times. Patients are de-duplicated in order of first appearance,
# which replaces the former hard-coded `l_down[:-2]` that was only correct for
# exactly two patients with two stains each.
l_down = []
seen_patients = set()

for y in pat_names:
    # 'N19-783-01_TB.scn' -> 'N19-783-01'
    patient = y.rsplit('_', 1)[0]
    if patient in seen_patients:
        continue
    seen_patients.add(patient)

    slide_names = []
    for root, dirs, files in os.walk(output_dir + '/Unaligned/DownImages'):
        for i in files:
            # File names are 'Down_<index>_Base_<patient>_<stain>_<stain folder>.tiff'.
            _, _, _, pat, stain, _ = i.split('_')
            if pat == patient:
                stain_ord = stain_order.index(stain) + 1
                slide_names.append((stain_ord, os.path.join(root, i)))

    slide_names = [x[1] for x in sorted(slide_names)]
    l_down.append(slide_names)

print(l_down)

[['/home/rohit/Documents/Balloon Cell Tiles/Unaligned/DownImages/Down_0_Base_N19-783-01_HE_HE.tiff', '/home/rohit/Documents/Balloon Cell Tiles/Unaligned/DownImages/Down_0_Base_N19-783-01_TB_TB.tiff'], ['/home/rohit/Documents/Balloon Cell Tiles/Unaligned/DownImages/Down_0_Base_N19-847-01_HE_HE.tiff', '/home/rohit/Documents/Balloon Cell Tiles/Unaligned/DownImages/Down_0_Base_N19-847-01_TB_TB.tiff']]


In [19]:
# Ideas for improving the base alignment:
# 1. Convert RGB into GrayScale
# 2. Try running without deconv function (TESTED - produces three channel RGB image, not same dimensions as deconv)
# 3. Try printing out deconvoluted image (TESTED - produced black screen, causing alignment not to work)
# 4. Change SITK default parameters (look into documentation) - Tune to find the best registration
# 5. Try different larger overlap size to help facilitate registration
# 6. Try downscaling images even more (by a factor of 4) (TESTED - looks like this works!)

# Align each patient's slide group in parallel and report the wall-clock time.
start = time.time()
Parallel(n_jobs=3)(delayed(base_align)(file) for file in tqdm(l_down))
end = time.time()

elapsed_minutes = round((end - start) / 60, 3)
print('Compute Time = {} Minutes'.format(elapsed_minutes))

100%|██████████| 2/2 [00:00<00:00, 17.12it/s]


Compute Time = 16.645 Minutes


# Tile Image and Register #

In [20]:
# Every file below input_folder (the base-aligned slides) is a tiling input.
l_files = [
    os.path.join(root, name)
    for root, dirs, files in os.walk(input_folder)
    for name in files
]

print(l_files)

['/home/rohit/Documents/Balloon Cell Tiles/BaseAligned/Aligned_0_Base_N19-847-01_TB_TB.tiff', '/home/rohit/Documents/Balloon Cell Tiles/BaseAligned/Aligned_0_Base_N19-847-01_HE_HE.tiff', '/home/rohit/Documents/Balloon Cell Tiles/BaseAligned/Aligned_0_Base_N19-783-01_TB_TB.tiff', '/home/rohit/Documents/Balloon Cell Tiles/BaseAligned/Aligned_0_Base_N19-783-01_HE_HE.tiff']


In [21]:
# Tile every aligned slide in parallel and report the wall-clock time.
start = time.time()
Parallel(n_jobs=25)(delayed(tileSave)(file) for file in tqdm(l_files))
end = time.time()

elapsed_minutes = round((end - start) / 60, 3)
print('Compute Time = {} Minutes'.format(elapsed_minutes))

100%|██████████| 4/4 [00:00<00:00, 37.72it/s]


Compute Time = 3.833 Minutes


In [22]:
# List all cropped tile files below output_folder. Registered tiles, tile-map
# overviews and anything inside a 'mask' folder are left out.
l_allfiles = [
    os.path.join(root, file)
    for root, dirs, files in os.walk(output_folder)
    for file in files
    if not file.endswith(('_reg.tiff', '_tilemap.tiff'))
    and os.path.basename(root) != 'mask'
    and file.endswith('.tiff')
]

In [23]:
# Extract metadata for each cropped image from its file name, which has the
# form '<slide>_<stain>_<stain folder>_<row>_<col>.tiff'.
# Each record is [slide, stain, stain rank, row+col key, path].
metadata = []
for path in l_allfiles:
    slide, stain, _, row, col = os.path.basename(path)[:-5].split('_')
    metadata.append([slide, stain, stain_order.index(stain) + 1, row + col, path])

In [24]:
# Keep only the HE tiles; they are the baseline every other tile of the same
# field is registered to.
pd1_list = list(filter(lambda record: record[1] == 'HE', metadata))

In [25]:
# For each cropped HE tile, find all tiles from the same slide and the same
# row/col location. The HE tile path becomes the key into all_paths and an
# entry of l_pd1_filter.
#
# Each group is sorted by stain rank (stain_order), as the disabled legacy
# block below did. regAll uses the first path of a group as the fixed image,
# so without the sort the reference depended on the arbitrary os.walk order
# and was not guaranteed to be the HE tile.

all_paths = {}
l_pd1_filter = []

for coord in pd1_list:
    files = sorted(
        (item for item in metadata if (item[0] == coord[0] and item[3] == coord[3])),
        key=lambda item: item[2],
    )
    he_paths = [item[4] for item in files if item[1] == 'HE']
    l_pd1_filter.extend(he_paths)
    all_paths[he_paths[0]] = [item[4] for item in files]

'''
    if len(files) == 8:
        for i in files:
            files = sorted(files, key = lambda x: x[2])
            if i[1] == 'HE':
                l_pd1_filter.append(i[4])
            file_paths = [item[4] for item in files]
            key = [item[4] for item in files if item[1] == 'PD-1']
            all_paths[key[0]] = file_paths
'''

"\n    if len(files) == 8:\n        for i in files:\n            files = sorted(files, key = lambda x: x[2])\n            if i[1] == 'HE':\n                l_pd1_filter.append(i[4])\n            file_paths = [item[4] for item in files]\n            key = [item[4] for item in files if item[1] == 'PD-1']\n            all_paths[key[0]] = file_paths\n"

In [26]:
# Register the tiles of every field in parallel and report the wall-clock time.
start = time.time()
Parallel(n_jobs=25)(delayed(regAll)(file) for file in tqdm(l_pd1_filter))
# regAll(l_pd1_filter[82])
end = time.time()

elapsed_minutes = round((end - start) / 60, 3)
print('Compute Time = {} Minutes'.format(elapsed_minutes))

100%|██████████| 972/972 [59:25<00:00,  3.67s/it]


Compute Time = 63.507 Minutes


In [27]:
# Sorted list of all folders containing registered images from 1 cropped
# field; these are the directories whose name ends in '.tiff'.
l_regfolders = sorted(
    os.path.join(root, name)
    for root, dirs, files in os.walk('{}/{}'.format(output_folder, reg_folder))
    for name in dirs
    if name.endswith('.tiff')
)
print(l_regfolders)

['/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r0_c0.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r0_c1.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r0_c10.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r0_c11.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r0_c12.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r0_c13.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r0_c14.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r0_c15.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r0_c16.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r0_c17.tiff', '/home/rohit/Documents/Balloon 

In [40]:
# Copy a random selection of up to 20 registered-field folders into a
# separate folder for validation analysis by Bret.

import random
import shutil

output = "/home/rohit/Documents/Ilastik Balloon Cells (FINAL)/Validation Images"
# os.mkdir fails if the folder already exists, which keeps a second run from
# silently adding more folders to an existing validation set.
os.mkdir(output)

# Sample WITHOUT replacement: random.choice could return the same folder
# twice, and shutil.copytree raises when its destination already exists.
n_validation = min(20, len(l_regfolders))
for folder in random.sample(l_regfolders, n_validation):
    # basename replaces a fixed 8-way split("/") that only worked for one path depth.
    shutil.copytree(folder, os.path.join(output, os.path.basename(folder)))

In [28]:
# Run combineMask over every registered field folder using 25 joblib workers
# and report how long the whole batch took (wall-clock, in minutes).
start = time.time()
mask_jobs = (delayed(combineMask)(file) for file in tqdm(l_regfolders))
Parallel(n_jobs = 25)(mask_jobs)
end = time.time()

elapsed_minutes = round((end - start) / 60, 3)
print('Compute Time = ' + str(elapsed_minutes) + ' Minutes')

100%|██████████| 972/972 [05:12<00:00,  3.11it/s]


Compute Time = 5.605 Minutes


In [16]:
# At this point the HE & TB files have been combined into one stack per field.
# The ilastik/cellprofiler script must be run BEFORE this cell so that the
# probability maps & masks exist. This cell then merges them into a new
# 4-layer "<field>_ilastik_stack.tiff" per field, giving a comprehensive
# overview of where the balloon cells are found in the tissue.

registered_root = output_folder + "/Registered"

stack_images = []           # (path, field id) of the stack written for each field
prob_images = []            # (path, field id) of "Probabilities_0" files ending in ".tif"
cellprofiler_images = []    # (path, field id) of "Probabilities_0" files ending in ".tiff"

for folder in os.listdir(registered_root):
    # Folder names look like "<pat>_<x>_<y>_<row>_<col>.tiff"; strip the
    # 5-character extension and keep patient, row and col as the field id.
    pat, _, _, row, col = folder[:-5].split("_")
    interest_fields = pat + "_" + row + "_" + col

    for root, dirs, files in os.walk(registered_root + "/" + folder):
        for file in files:
            file_path = os.path.join(root, file)
            # "ilastik_stack.tiff" also ends with "stack.tiff". Those files are
            # this cell's own OUTPUT, so they must not be picked up as input
            # when the cell is re-run.
            if file.endswith("stack.tiff") and not file.endswith("ilastik_stack.tiff"):
                stack_images.append((file_path, interest_fields))
            if "Probabilities_0" in file and file.endswith(".tif"):
                prob_images.append((file_path, interest_fields))
            if "Probabilities_0" in file and file.endswith(".tiff"):
                cellprofiler_images.append((file_path, interest_fields))

# Index the probability / cellprofiler images by field id so each stack is
# matched with a dictionary lookup instead of scanning every combination.
# If a field has several candidates the last one found wins.
prob_by_field = {meta: path for path, meta in prob_images}
cp_by_field = {meta: path for path, meta in cellprofiler_images}

for stack_image_name, stack_meta in stack_images:
    prob_image_name = prob_by_field.get(stack_meta)
    cp_image_name = cp_by_field.get(stack_meta)
    if prob_image_name is None or cp_image_name is None:
        # Field has no probability map and/or cellprofiler image yet.
        continue

    stack_img = io.imread(stack_image_name)
    prob_img = gray2rgb(img_as_ubyte(io.imread(prob_image_name)))
    cp_img = io.imread(cp_image_name)

    # Split the stack in two along axis 0 and take the first image of each half.
    destack = np.split(stack_img, 2, axis = 0)
    new_stack = np.stack([destack[0][0], destack[1][0], prob_img, cp_img], axis = 0)

    # Output name: everything before the first "mask" in the stack path,
    # followed by "<field id>_ilastik_stack.tiff".
    io.imsave((stack_image_name[0 : stack_image_name.find("mask")] +
               stack_meta + "_ilastik_stack.tiff"), new_stack)


In [18]:
# Script to delete all files created by combineMask.
# Any file under the Registered tree whose name ends with one of the suffixes
# below is removed. Note that the "stack.tiff" suffix also matches files named
# "..._ilastik_stack.tiff".
generated_suffixes = ("Mask.tiff", "demo.tiff", "stack.tiff")

for current_dir, _subdirs, filenames in os.walk(output_folder + "/Registered"):
    for filename in filenames:
        if filename.endswith(generated_suffixes):
            os.remove(os.path.join(current_dir, filename))
    # Directories named "mask" are removed after their files have been deleted
    # (os.rmdir only succeeds on an empty directory).
    if os.path.basename(current_dir) == "mask":
        os.rmdir(current_dir)

# Concatenate stack files #

In [17]:
# Make sure the directory that will receive the stitched output exists.
stitched_dir = output_dir + '/StitchedStacks'
os.makedirs(stitched_dir, exist_ok = True)

# Create list of all 'stack' files: every file below the Registered tree
# whose name ends in 'ilastik_stack.tiff'.
l_stacks = [
    os.path.join(parent, filename)
    for parent, _subdirs, filenames in os.walk(output_folder + '/Registered')
    for filename in filenames
    if filename.endswith('ilastik_stack.tiff')
]
print(l_stacks)

['/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r0_c16.tiff/N19-783-01_r0_c16_ilastik_stack.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-847-01_HE_HE_r13_c12.tiff/N19-847-01_r13_c12_ilastik_stack.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r28_c3.tiff/N19-783-01_r28_c3_ilastik_stack.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r9_c15.tiff/N19-783-01_r9_c15_ilastik_stack.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-847-01_HE_HE_r3_c16.tiff/N19-847-01_r3_c16_ilastik_stack.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-783-01_HE_HE_r1_c0.tiff/N19-783-01_r1_c0_ilastik_stack.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Registered/N19-847-01_HE_HE_r18_c5.tiff/N19-847-01_r18_c5_ilastik_stack.tiff', '/home/rohit/Documents/Balloon Cell Tiles/TiledOverlap/Regist

In [20]:
# List of 1 image from each patient: keep only the files whose basename
# starts with 'Aligned_0'.
sort_l_files = []
for candidate in l_files:
    if os.path.basename(candidate).startswith('Aligned_0'):
        sort_l_files.append(candidate)

# Index every stack file by the (row, col) position encoded in its file name
# ("<pat>_r<row>_c<col>_<..>_<..>.tiff"). The patient field is parsed but not
# used here, so stacks sharing a position overwrite each other in this index.
stack_meta = {}
for stack_path in l_stacks:
    stem = os.path.basename(stack_path)[:-5]      # drop the 5-character '.tiff' extension
    pat, row_tag, col_tag, _, _ = stem.split('_')
    # Strip the leading letter from 'r12' / 'c7' and convert to integers.
    grid_pos = (int(row_tag[1:]), int(col_tag[1:]))
    stack_meta[grid_pos] = stack_path

In [21]:
# Stitch the per-field ilastik stacks of each patient back into one large
# stack and write it to <output_dir>/StitchedStacks/<pat_ID>_stitched.tiff.
for p in sort_l_files:
    # The patient ID is the 4th of the 6 underscore-separated fields of the
    # reference image's file name.
    _, _, _, pat_ID, _, _ = os.path.basename(p).split('_')

    # Map (row, col) -> stack path, restricted to this patient's stacks.
    stack_meta = {}
    for path in l_stacks:
        file = os.path.basename(path)[:-5]        # drop the 5-character '.tiff' extension
        pat, row, col, _, _ = file.split('_')
        if pat == pat_ID:
            # 'r12' / 'c7' -> 12 / 7
            stack_meta[(int(row[1:]), int(col[1:]))] = path

    # The reference image is read only to derive the size of the tile grid.
    img = io.imread(p)
    n_rows = img.shape[0] // tile_x
    n_cols = img.shape[1] // tile_y
    num_rows = list(range(n_rows))
    num_cols = list(range(n_cols))
    combined = [(f, s) for f in num_rows for s in num_cols]   # row-major grid positions

    l_crops = []
    for row, col in combined:
        tile_path = stack_meta.get((row, col))
        if tile_path is None:
            # No stack exists for this position: fill the gap with a black
            # tile. 4 layers x RGB matches the layout of the ilastik stacks;
            # the tile size follows the configured tile dimensions.
            l_crops.append(np.zeros((4, tile_x, tile_y, 3), dtype = np.uint8))
            continue

        # Trim the overlap margin on every side that borders a neighbouring
        # tile; sides on the outer edge of the grid are left uncropped.
        row_top_offset = 0 if row == 0 else overlap_x
        row_bot_offset = 0 if row == n_rows - 1 else overlap_x
        col_l_offset = 0 if col == 0 else overlap_y
        col_r_offset = 0 if col == n_cols - 1 else overlap_y

        l_crops.append(crop(io.imread(tile_path),
                            ((0, 0),
                             (row_top_offset, row_bot_offset),
                             (col_l_offset, col_r_offset),
                             (0, 0)),
                            copy = False))

    # Join each grid row side by side (axis 2), then stack the rows
    # vertically (axis 1).
    conc_rows = [
        np.concatenate(l_crops[r * n_cols:(r + 1) * n_cols], axis = 2)
        for r in range(n_rows)
    ]
    conc_all = np.concatenate(conc_rows, axis = 1)

    tifffile.imwrite(output_dir + '/StitchedStacks/{}_stitched.tiff'.format(pat_ID), conc_all)

In [38]:
# Create downscaled images of annotated scans.
# Each entry pairs a stitched stack file with the patient name used in the
# output file name.
stitched_dir = output_dir + "/StitchedStacks"
slides = [(stitched_dir + "/N19-783-01_stitched.tiff", "N19-783-01"),
          (stitched_dir + "/N19-847-01_stitched.tiff", "N19-847-01")]

for slide, patient in slides:
    raw = io.imread(slide)

    # Split the stack into 4 equal parts along axis 0, take the first image of
    # each part and shrink it with two successive pyrDown passes.
    recombined_stack = [
        cv2.pyrDown(cv2.pyrDown(layer[0]))
        for layer in np.split(raw, 4, axis = 0)
    ]

    recombined_img = np.stack(recombined_stack, axis = 0)
    down_path = output_dir + "/StitchedStacks/Down_{}_stitched.tiff".format(patient)
    tifffile.imwrite(down_path, recombined_img, photometric = "minisblack")
    print(recombined_img.shape)

(4, 15360, 9216, 3)
(4, 12288, 9216, 3)
