In [None]:
# install openslide, histomics_stream, pandas
!apt-get update
!apt-get install -y openslide-tools
!pip install openslide-python
!pip install histomics_stream 'large_image[openslide]' scikit_image --find-links https://girder.github.io/large_image_wheels
!pip install pandas

# install mil
user = '########' #git username
token = '################################' #personal access token - see https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token
!git clone https://{user}:{token}@github.com/PathologyDataScience/mil.git
!pip install -e mil

In [None]:
from mil.io.reader import read_record
from mil.io.writer import split_inference, write_record
from mil.io.utils import inference, study
from mil.models import convolutional_model
import numpy as np
import os
import pandas as pd
import tensorflow as tf

# parameters

# tiling: passed to study() when the slides are tiled
t = 224              # tile edge length; also used for the model input shape (t, t, 3)
overlap = 0          # passed to study() as (overlap, overlap)
chunk = 1792         # passed to study() as (chunk, chunk)
magnification = 20   # passed to study()

# inference: passed to inference() as batch= and prefetch=
tile_batch = 128
tile_prefetch = 2

# label names passed to read_record when records are read back
labels = ['i', 't']

# not referenced elsewhere in the visible cells
tol = 0.02
model_name = 'EfficientNetB3'

# input slides, clinical label table, and directory where .tfr files are written
svspath = '/data/transplant/nwu/wsi/'
csvfile = '/data/transplant/nwu/CTOT08_clinical_BiopsyImageKeys.csv'
output = '/tf/notebooks/transplant/3d/'

In [None]:
# Create a wrapped model that consumes tiles and tile metadata.
# tf.keras.Model.predict takes a single input, and so we combine 
# (tiles, tile_metadata) for passing to the wrapped model. Inside
# the wrapper these are separated and inference is done on the tiles.
# To avoid predict discarding the tile_metadata, we add a dummy 'y' 
# variable 0. to be discarded by predict.

# define the wrapped model class
class WrappedModel(tf.keras.Model):
    """Keras model that applies a feature extractor to tiles and passes
    the accompanying tile metadata through unchanged.

    Parameters
    ----------
    extractor : callable (e.g. a tf.keras.Model)
        The model applied to the tiles in `call`.
    *args, **kwargs
        Forwarded to tf.keras.Model.__init__ (e.g. `name`).
    """

    def __init__(self, extractor, *args, **kwargs):
        super(WrappedModel, self).__init__(*args, **kwargs)
        # Wrap the extractor that was passed in rather than relying on a
        # module-level variable that happens to be named `model`.
        self.model = extractor

    def call(self, inputs, *args, **kwargs):
        """Run the extractor on inputs[0] and return it with inputs[1].

        `inputs` is the combined (tiles, tile_metadata) pair; the result is
        (extractor(tiles), tile_metadata).
        """
        return self.model(inputs[0]), inputs[1]
    

# Build both the feature extractor and its wrapper inside the distribution
# strategy scope: the extractor owns all the variables, and variables must be
# created under the scope for the strategy to mirror them across devices.
with tf.distribute.MirroredStrategy().scope():

    # create the feature extractor model to be wrapped
    model = tf.keras.applications.efficientnet.EfficientNetB3(
        include_top=False, weights='imagenet', input_shape=(t, t, 3),
        pooling='avg')

    # wrap the model
    wrapped_model = WrappedModel(model, name='wrapped_model')

In [None]:
# generate .tfr test data

# extract labels from csv: keep the slide file name and the I and T columns,
# drop rows with missing values, and rename the columns to name / i / t
table = pd.read_csv(csvfile)
table = table[['SVS_FileName', 'I', 'T']]
table = table.dropna()
table = table.rename(columns={'SVS_FileName': 'name', 'I': 'i', 'T': 't'})

# match table entries to existing files
files = [slide for slide in os.listdir(svspath) if os.path.splitext(slide)[1] == '.svs']
# drop=True: without it the old index is kept as an 'index' column, which
# would then be written out as a label alongside 'i' and 't'
table = table[table.name.isin(files)].reset_index(drop=True)
assert len(table) >= 2, 'need at least two labeled .svs slides in svspath'

# select two slides
slide_1 = table.loc[0, 'name']
slide_2 = table.loc[1, 'name']

# extract labels for the first slide (every column except the file name)
labels_1 = {column: float(table.loc[0, column])
            for column in table.columns if column != 'name'}
# NOTE: labels_1 is written to every record below, including the combined
# record and the record for the second slide.

# create a study for the first slide
study_1 = study([os.path.join(svspath, slide_1)],
                (t, t), (overlap, overlap), (chunk, chunk), magnification)

# do inference for the first slide
features_1, tile_info_1 = inference(wrapped_model, study_1, batch=tile_batch, prefetch=tile_prefetch)

# write one slide as both 2d, 3d in separate files
write_record(output + '2d.tfr', features_1, tile_info_1, labels_1, structured=False)
write_record(output + '3d.tfr', features_1, tile_info_1, labels_1, structured=True)

# write a single record containing two slides
study_c = study([os.path.join(svspath, slide_1), os.path.join(svspath, slide_2)],
                (t, t), (overlap, overlap), (chunk, chunk), magnification)
features_c, tile_info_c = inference(wrapped_model, study_c, batch=tile_batch, prefetch=tile_prefetch)
write_record(output + 'combined_2d.tfr', features_c, tile_info_c, labels_1, structured=False)

# write a record for each slide independently
# (features_c / tile_info_c are rebound to the per-slide sequences here)
features_c, tile_info_c = split_inference(features_c, tile_info_c)
for i, (features, tile_info) in enumerate(zip(features_c, tile_info_c)):
    write_record(output + 'combined_' + str(i) + '_2d.tfr',
                 features, tile_info, labels_1, structured=False)

In [None]:
# testing

# function for comparing tile info
def compare_tile_info(t1, t2):
    """Assert that each entry of t1 equals the entry of t2 stored under the same key.

    Only the keys of t1 are visited; keys that exist only in t2 are not checked.
    """
    for key in t1.keys():
        tf.debugging.assert_equal(t1[key], t2[key])
    
    
# function for comparing 2D arrays that are not aligned
def compare_unaligned_2d(f1, f2, t1, t2):
    """Assert that f1 equals f2 after reordering f2's rows to follow t1.

    Tiles are matched on their ('tile_left', 'tile_top') pair: for each tile
    of t1, the first position in t2 with the same pair is looked up, the rows
    of f2 are gathered in that order (axis 0), and the result is compared to
    f1 with tf.debugging.assert_equal.

    Raises IndexError if a tile of t1 has no match in t2.
    """
    def _position_in_t2(left, top):
        # first index of t2 whose left and top coordinates both match
        same_tile = tf.logical_and(left == t2['tile_left'],
                                   top == t2['tile_top'])
        return tf.where(same_tile).numpy()[0][0]

    order = tf.constant(
        [_position_in_t2(left, top)
         for left, top in zip(t1['tile_left'], t1['tile_top'])]
    )
    tf.debugging.assert_equal(f1, tf.gather(f2, order, axis=0))
    

# read in files used for comparisons
def read(path, label, structured):
    """Read the first record of a TFRecord file with read_record.

    Parameters
    ----------
    path : path of the .tfr file to open.
    label : passed to read_record as its second argument.
    structured : passed to read_record as its third argument, so the caller
        decides how the record is read.

    Returns
    -------
    The four values produced by read_record, as a tuple. The callers unpack
    them as (f, l, s, t); presumably features, labels, a third value and tile
    info - confirm against read_record.
    """
    ds = tf.data.TFRecordDataset(path)
    serialized = list(ds.take(1))[0]
    # forward the caller's flag instead of always reading with False
    f, l, s, t = read_record(serialized, label, structured)
    return f, l, s, t
# reference reads: the 2D file requested as 2D, the 3D file requested as 3D,
# and the two-slide file requested as 2D
f_2, l_2, s_2, t_2 = read(output + '2d.tfr', labels, False)
f_3, l_3, s_3, t_3 = read(output + '3d.tfr', labels, True)
f_c, l_c, s_c, t_c = read(output + 'combined_2d.tfr', labels, False)

# the file stored as 3D, requested as 2D, must agree with the 2D reference
f, l, s, ti = read(output + '3d.tfr', labels, False)
tf.debugging.assert_near(f_2, f)
assert l_2 == l
assert s_2 == s
compare_tile_info(t_2, ti)

# the file stored as 2D, requested as 3D, must agree with the 3D reference
f, l, s, ti = read(output + '2d.tfr', labels, True)
tf.debugging.assert_near(f_3, f)
assert l_3 == l
assert s_3 == s
compare_tile_info(t_3, ti)

# each independently stored slide must contain the same tiles as the
# corresponding slice of the jointly stored record (tile order may differ)

# slide 0
f, l, s, ti = read(output + 'combined_0_2d.tfr', labels, False)
in_slide_0 = t_c['slide_index'] == 0
f0 = f_c[in_slide_0]
t0 = {k: v[in_slide_0] for k, v in t_c.items()}
compare_unaligned_2d(f, f0, ti, t0)

# slide 1
f, l, s, ti = read(output + 'combined_1_2d.tfr', labels, False)
in_slide_1 = t_c['slide_index'] == 1
f1 = f_c[in_slide_1]
t1 = {k: v[in_slide_1] for k, v in t_c.items()}
compare_unaligned_2d(f, f1, ti, t1)

# writing multiple slides as one structured record is expected to error;
# uncomment to check
#write_record(output + 'combined_3d.tfr', features_c, tile_info_c, labels_1, structured=True)