In [None]:
!pip install tensorflow==2.9.0

In [None]:
import glob
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import openslide
import pandas as pd
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule
import plotly.express as px
import plotly.figure_factory as ff
import seaborn as sns
import tensorflow as tf
import tensorflow.keras.utils as ku
from IPython.display import Image, display
from IPython.display import clear_output
from plotly import graph_objs as go
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Remove PIL's pixel-count limit so very large images can be opened.
PIL.Image.MAX_IMAGE_PIXELS = None

In [None]:
# Dataset root (relative to the working directory) and the two sub-folders
# holding the slide images and their label masks.
BASE_DIR = '../input/prostate-cancer-grade-assessment'
IMAGES_DIR, MASK_DIR = (
    os.path.join(BASE_DIR, folder)
    for folder in ('train_images', 'train_label_masks')
)

In [None]:
# Load the training metadata table and preview its first rows.
train_data = pd.read_csv(
    filepath_or_buffer=os.path.join(BASE_DIR, 'train.csv'),
)
train_data.head()

In [None]:
# Lookup table keyed by slide file name ("<image_id>.tiff"); built as a new
# frame so `train_data` itself is left untouched.
train_df_copy = (
    train_data
    .assign(image_id=lambda frame: frame['image_id'] + '.tiff')
    .set_index('image_id')
)

In [None]:
class EDAUtilFunctions:
    """Plotting helpers for exploratory analysis of a single DataFrame.

    The frame handed to the constructor is stored by reference (it is not
    copied); every plot reads its columns from that frame.
    """

    def __init__(self, dataframe: pd.DataFrame):
        self.df = dataframe

    def _require_columns(self, col_names) -> None:
        """Raise ``KeyError`` for the first name that is not a column of ``self.df``."""
        for col_name in col_names:
            if col_name not in self.df.columns:
                raise KeyError(f'{col_name} should be present in {self.df.columns}')

    def count_plot(self, col_names: list):
        """Show one bar chart of value counts for each requested column.

        Args:
            col_names: Names of columns of the wrapped frame to plot.

        Raises:
            KeyError: If any name is not a column. All names are checked
                before the first figure is created, so a bad name never
                leaves a half-finished series of plots behind.
        """
        self._require_columns(col_names)

        colors = sns.color_palette('pastel')

        for col_name in col_names:
            unique_features, counts = np.unique(self.df[col_name], return_counts=True)

            # One figure per column: creating the figure inside the loop gives
            # every chart the same size and layout, not only the first one.
            plt.figure(figsize=(8, 6), tight_layout=True)
            plt.bar(unique_features, counts, color=colors[:len(unique_features)])
            plt.xticks(rotation=90)
            plt.xlabel(col_name)
            plt.ylabel('Count')
            plt.title('Count Plot')
            plt.show()

    def funnel_plot(self, text: str, values: str,
                    title: str = "Funnel-Chart of ISUP_grade Distribution"):
        """Show a plotly funnel-area chart built from two columns.

        Args:
            text: Column whose entries label the funnel segments.
            values: Column whose entries size the funnel segments.
            title: Chart title; defaults to the original fixed caption.

        Raises:
            KeyError: If ``text`` or ``values`` is not a column.
        """
        self._require_columns([text, values])

        fig = go.Figure(go.Funnelarea(
            text=self.df[text],
            values=self.df[values],
            title={"position": "top center", "text": title}))
        fig.show()
    
    


    

In [None]:
# Value-count bar charts for the provider and the two grading columns.
eda_df = EDAUtilFunctions(dataframe=train_data)
eda_df.count_plot(
    col_names=['data_provider', 'isup_grade', 'gleason_score'],
)

In [None]:
# t = np.random.randint(200, 226, (5, 5, 3)).astype(np.uint8) - 255.0
# print(np.sum(t) / 75)
# plt.imshow((t + 255).astype(np.uint8))

In [None]:
class TIFFVisualization:
    """Static helpers that preview whole-slide TIFF images and their label masks.

    Both helpers pick files at random with ``np.random.randint`` (sampling
    with replacement), so the same slide may be shown more than once.
    """

    # Per data provider: (threshold, palette).
    #   * label values below ``threshold`` keep the slide pixels untouched,
    #     the remaining pixels are blended with the palette colour;
    #   * ``palette`` lists one RGB triple (as fractions of 255) per label.
    # radboud:    {0: background, 1: stroma, 2: benign epithelium,
    #              3: Gleason 3, 4: Gleason 4, 5: Gleason 5}
    # karolinska: {0: background, 1: benign, 2: cancer}
    _PROVIDER_STYLES = {
        'radboud': (2, [0, 0, 0, 0.5, 0.5, 0.5, 0, 1, 0, 1, 1, 0.7, 1, 0.5, 0, 1, 0, 0]),
        'karolinska': (1, [0, 0, 0, 0, 1, 0, 1, 0, 0]),
    }

    @staticmethod
    def _list_tiffs(directory: str) -> list:
        """Return the ``*.tiff`` paths inside ``directory``; raise if there are none."""
        filenames = glob.glob(f'{directory}/*.tiff')
        if len(filenames) == 0:
            raise RuntimeError(f'{directory} should contain tiff encoded images')
        return filenames

    @staticmethod
    def visualize_patch(patch_shape: tuple, n_images: int, pos: tuple, img_dir: str, dft: pd.DataFrame):
        """Show the same region of ``n_images`` randomly chosen slides in a 3-column grid.

        Args:
            patch_shape: Size handed to ``read_region`` for the patch.
            n_images: Number of slides to draw; any positive integer.
            pos: Location handed to ``read_region``; any two-element
                sequence of integer-like values (numpy integers are fine).
            img_dir: Directory searched for ``*.tiff`` slides.
            dft: Frame indexed by slide file name that has an ``isup_grade``
                column.

        Raises:
            ValueError: If ``n_images`` is smaller than 1.
            RuntimeError: If ``img_dir`` contains no ``*.tiff`` file.
        """
        if n_images < 1:
            raise ValueError('n_images should be a positive integer')

        filenames = TIFFVisualization._list_tiffs(img_dir)
        indices = np.random.randint(0, len(filenames), n_images)

        # Round the row count up so a trailing partial row still has axes, and
        # keep the axes array 2-D even when there is a single row.
        n_rows = int(np.ceil(n_images / 3))
        fig, axes = plt.subplots(n_rows, 3, figsize=(20, 60), squeeze=False)

        # Hide every frame up front; cells without an image then stay blank.
        for ax in axes.flat:
            ax.axis('off')

        # Convert to built-in ints before handing the location to openslide.
        location = tuple(int(coord) for coord in pos)

        for i in range(n_images):
            slide_path = filenames[indices[i]]
            key_name = os.path.basename(slide_path)

            # The context manager closes the slide even if reading fails.
            with openslide.OpenSlide(slide_path) as slide:
                # Level 2 is fixed here, so slides need at least three levels.
                patch = slide.read_region(location, 2, patch_shape)

            ax = axes[i // 3, i % 3]
            ax.imshow(patch)
            ax.set_title(f'image id: {key_name}\n ISUP Grade: {dft.loc[key_name].isup_grade}')

        plt.show()

    @staticmethod
    def visualize_masks(n_images: int, mask_dir: str, img_dir: str, dft: pd.DataFrame,
                        max_size=(400, 400), alpha=0.8):
        """Show random slides next to the same slide overlaid with its label mask.

        Args:
            n_images: Number of (slide, overlay) rows to draw; positive.
            mask_dir: Directory searched for ``*_mask.tiff`` files.
            img_dir: Directory holding the slide that matches each mask.
            dft: Frame indexed by slide file name with ``data_provider`` and
                ``isup_grade`` columns.
            max_size: Bounding box for the thumbnails that are displayed.
            alpha: Opacity of the mask colours, between 0 and 1.

        Raises:
            ValueError: If ``n_images`` is smaller than 1 or a slide comes
                from a provider without a known label palette.
            RuntimeError: If ``mask_dir`` contains no ``*.tiff`` file.
        """
        if n_images < 1:
            raise ValueError('n_images should be a positive integer')

        filenames = TIFFVisualization._list_tiffs(mask_dir)
        indices = np.random.randint(0, len(filenames), n_images)

        # squeeze=False keeps ``axes`` 2-D when only one row is requested.
        fig, axes = plt.subplots(n_images, 2, figsize=(10, 60), squeeze=False)
        alpha_int = int(round(255 * alpha))

        for i in range(n_images):
            mask_path = filenames[indices[i]]
            key_name = os.path.basename(mask_path).replace('_mask', "")
            train_img = f'{img_dir}/{key_name}'

            row = dft.loc[key_name]
            center = row['data_provider']
            if center not in TIFFVisualization._PROVIDER_STYLES:
                raise ValueError(f'unsupported data provider {center!r} for {key_name}')
            threshold, palette = TIFFVisualization._PROVIDER_STYLES[center]

            # Read everything that needs the open files, then release them;
            # the context managers also close both slides on failure.
            with openslide.OpenSlide(train_img) as img, openslide.OpenSlide(mask_path) as mask:
                img_data = img.read_region((0, 0), img.level_count - 1, img.level_dimensions[-1])
                mask_data = mask.read_region((0, 0), mask.level_count - 1, mask.level_dimensions[-1])
                img_thumbnail = img.get_thumbnail(size=max_size)

            # Only the first band of the mask image is used as the label map.
            mask_data = mask_data.split()[0]

            # 255 where the slide must stay fully visible, (255 - alpha_int)
            # where the mask colour should dominate.
            below_threshold = np.less(np.asarray(mask_data), threshold).astype('uint8')
            alpha_content = PIL.Image.fromarray(below_threshold * alpha_int + (255 - alpha_int))

            preview_palette = np.zeros(shape=768, dtype=int)
            preview_palette[0:len(palette)] = (np.array(palette) * 255).astype(int)

            mask_data.putpalette(data=preview_palette.tolist())
            mask_rgb = mask_data.convert(mode='RGB')

            overlayed_image = PIL.Image.composite(image1=img_data, image2=mask_rgb, mask=alpha_content)
            overlayed_image.thumbnail(size=max_size, resample=0)

            slide_ax, overlay_ax = axes[i][0], axes[i][1]
            slide_ax.imshow(img_thumbnail)
            overlay_ax.imshow(overlayed_image)

            slide_ax.axis('off')
            overlay_ax.axis('off')

            slide_ax.set_title(f'image id: {key_name}\ncenter: {row.data_provider}')
            overlay_ax.set_title(f' \n ISUP Grade: {row.isup_grade}')

        plt.show()

In [None]:
# Preview a 512x512 region of 15 random slides; the region's location is
# drawn at random from [1759, 1800) on both axes.
TIFFVisualization.visualize_patch(
    patch_shape=(512, 512),
    n_images=15,
    pos=np.random.randint(1759, 1800, (2,)),
    img_dir=IMAGES_DIR,
    dft=train_df_copy,
)

In [None]:
# Show six random slides side by side with their mask overlays.
TIFFVisualization.visualize_masks(
    n_images=6,
    mask_dir=MASK_DIR,
    img_dir=IMAGES_DIR,
    dft=train_df_copy,
)

In [None]:
class DataGenerator(ku.Sequence):
    """Keras ``Sequence`` serving slide thumbnails for one of two tasks.

    * ``'classification'`` batches are ``(images, one_hot_isup_grades)``.
    * ``'segmentation'`` batches are ``(images, mask_overlays)`` and only use
      rows whose ``data_provider`` is ``'karolinska'`` and that have a file
      in ``mask_dir``; the overlay palette below only colours the label
      values 0-2.

    Only complete batches are served: ``__len__`` floors the division of the
    sample count by ``batch_size``.
    """

    whitelist_tasks = ['classification', 'segmentation',]

    def __init__(self,
                 img_dir: str,
                 df_images: pd.DataFrame,
                 batch_size: int,
                 target_size: tuple,
                 task: str,
                 mask_dir=None,
                 is_training=True,
                 zoom_range=None,
                 brightness_range=None,
                 max_samples=5000):
        """Store the configuration and build the sample index.

        Args:
            img_dir: Directory containing ``<image_id>.tiff`` slides.
            df_images: Frame with ``image_id`` and ``isup_grade`` columns
                (plus ``data_provider`` when ``mask_dir`` is given).
            batch_size: Number of samples per batch.
            target_size: Two-element size used both as the thumbnail bounding
                box and as the final resize target.
            task: One of ``whitelist_tasks``.
            mask_dir: Directory containing ``<image_id>_mask.tiff`` files;
                required for the segmentation task.
            is_training: Enables per-epoch shuffling and random augmentation.
            zoom_range: Range forwarded to Keras ``random_zoom``; ``None``
                disables zooming.
            brightness_range: Range forwarded to Keras ``random_brightness``;
                ``None`` disables it.
            max_samples: Only the first ``max_samples`` rows of ``df_images``
                are used; ``None`` keeps every row.

        Raises:
            ValueError: If ``task`` is unknown, or if the segmentation task
                is requested without ``mask_dir``.
        """
        super().__init__()

        if task not in DataGenerator.whitelist_tasks:
            raise ValueError(f"task should be one of the {DataGenerator.whitelist_tasks}")
        if task == 'segmentation' and mask_dir is None:
            raise ValueError("mask_dir is required when task is 'segmentation'")

        self.img_dir = img_dir
        self.df = df_images if max_samples is None else df_images.iloc[:max_samples]
        self.batch_size = batch_size
        self.target_size = target_size
        self.mask_dir = mask_dir
        self.training = is_training
        self.zoom_param = zoom_range
        self.brightness_param = brightness_range
        self.task = task

        if mask_dir is not None:
            # Mask files are named "<image_id>_mask.tiff"; the part before the
            # first underscore links a mask back to its metadata row.
            mask_names = pd.Series(os.listdir(mask_dir), dtype='object')
            masks_df = pd.DataFrame({
                'mask_file_name': mask_names,
                'image_id': mask_names.apply(lambda name: name.split('_')[0]),
            })
            with_masks = pd.merge(self.df, masks_df, on='image_id', how='inner')
            self.mask_safe_df = (
                with_masks[with_masks['data_provider'] == 'karolinska']
                .reset_index(drop=True)
            )

        # Positions (not labels) into the frame serving the task. A private
        # numpy array can be shuffled in place; a ``range`` cannot.
        if self.task == 'classification':
            self.indices = np.arange(len(self.df))
        else:
            self.indices = np.arange(len(self.mask_safe_df))

    def on_epoch_end(self):
        """Reshuffle the sample order between epochs when training."""
        if self.training:
            np.random.shuffle(self.indices)

    def on_epoch_start(self):
        """Previous name of the shuffling hook, kept for existing callers."""
        self.on_epoch_end()

    def __get_transformed_images(self, image_id: str) -> tf.Tensor:
        """Load one slide thumbnail as a float32 tensor scaled to [0, 1]."""
        rel_image_path = f'{self.img_dir}/{image_id}.tiff'
        # The context manager closes the slide even if reading fails.
        with openslide.OpenSlide(rel_image_path) as image:
            thumbnail = image.get_thumbnail(self.target_size)

        img_data = np.array(thumbnail)

        # Random augmentation is applied to training batches only.
        if self.training:
            if self.brightness_param is not None:
                img_data = tf.keras.preprocessing.image.random_brightness(
                    img_data, self.brightness_param)

            if self.zoom_param is not None:
                # The Keras defaults assume channels-first arrays; the
                # thumbnail array is (rows, cols, channels).
                img_data = tf.keras.preprocessing.image.random_zoom(
                    img_data, self.zoom_param,
                    row_axis=0, col_axis=1, channel_axis=2)

        img_data = tf.cast(img_data, tf.float32) / 255.0
        # get_thumbnail keeps the aspect ratio, so force the exact size.
        return tf.image.resize(img_data, self.target_size)

    def __get_overlayed_imgs(self, image_id: str) -> tuple:
        """Return ``(thumbnail, overlay)`` tensors for one slide and its mask.

        The overlay is the slide with the mask colours blended in at 80%
        opacity wherever the mask label is 1 or higher.
        """
        base_path = f'{self.img_dir}/{image_id}'
        mask_path = f'{self.mask_dir}/{image_id}'

        with openslide.OpenSlide(f'{base_path}.tiff') as img, \
                openslide.OpenSlide(f'{mask_path}_mask.tiff') as mask:
            img_data = img.read_region((0, 0), img.level_count - 1, img.level_dimensions[-1])
            mask_data = mask.read_region((0, 0), mask.level_count - 1, mask.level_dimensions[-1])
            img_thumbnail = img.get_thumbnail(size=self.target_size)

        # Only the first band of the mask image is used as the label map.
        mask_data = mask_data.split()[0]

        alpha = 0.8
        alpha_int = int(round(255 * alpha))
        # 255 keeps the slide pixel, (255 - alpha_int) lets the mask dominate.
        is_background = np.less(np.asarray(mask_data), 1).astype('uint8')
        alpha_content = PIL.Image.fromarray(is_background * alpha_int + (255 - alpha_int))

        # Colours for labels 0, 1, 2: black, green, red.
        preview_palette = np.zeros(shape=768, dtype=int)
        preview_palette[0:9] = (np.array([0, 0, 0, 0, 1, 0, 1, 0, 0]) * 255).astype(int)

        mask_data.putpalette(data=preview_palette.tolist())
        mask_rgb = mask_data.convert(mode='RGB')
        overlayed_image = PIL.Image.composite(image1=img_data, image2=mask_rgb, mask=alpha_content)
        overlayed_image.thumbnail(size=self.target_size, resample=0)

        img_thumbnail = tf.cast(np.array(img_thumbnail), tf.float32) / 255.0
        overlayed_image = tf.cast(np.array(overlayed_image), tf.float32) / 255.0

        img_thumbnail = tf.image.resize(img_thumbnail, self.target_size)
        overlayed_image = tf.image.resize(overlayed_image, self.target_size)

        return img_thumbnail, overlayed_image

    def __len__(self):
        """Number of complete batches available for the configured task."""
        if self.task == 'classification':
            return len(self.df) // self.batch_size

        return len(self.mask_safe_df) // self.batch_size

    def __getitem__(self, idx):
        """Build batch number ``idx``.

        Raises:
            RuntimeError: If a segmentation batch cannot be loaded. The
                original error is chained, rather than returning ``None``.
        """
        batch = self.indices[idx * self.batch_size: (idx + 1) * self.batch_size]

        if self.task == 'classification':
            rows = self.df.iloc[batch]
            batch_image_data = [self.__get_transformed_images(file_id)
                                for file_id in rows['image_id'].values]
            onehot_batch_labels = ku.to_categorical(
                rows['isup_grade'].values,
                DataGenerator.__deduce_num_classes(self))

            return tf.stack(batch_image_data), tf.convert_to_tensor(onehot_batch_labels)

        try:
            batch_img_ids = self.mask_safe_df['image_id'].iloc[batch].values
            pairs = [self.__get_overlayed_imgs(file_id) for file_id in batch_img_ids]
        except Exception as exc:
            raise RuntimeError(f'could not load segmentation batch {idx}') from exc

        images = tf.stack([thumbnail for thumbnail, _ in pairs])
        overlays = tf.stack([overlay for _, overlay in pairs])
        return images, overlays

    @staticmethod
    def __deduce_num_classes(obj: 'DataGenerator') -> int:
        """Return the one-hot width: the largest ``isup_grade`` plus one.

        Counting unique values would under-size the encoding whenever an
        intermediate grade is absent from the rows in use.
        """
        return int(obj.df['isup_grade'].max()) + 1

In [None]:
# Classification generator: 128 thumbnails of 224x224 per batch.
train_datagen = DataGenerator(
    img_dir=IMAGES_DIR,
    df_images=train_data,
    batch_size=128,
    target_size=(224, 224),
    task='classification',
    mask_dir=MASK_DIR,
)

In [None]:
# next(iter(train_datagen))

# Classification

In [None]:
# Build the frame used by flow_from_dataframe: file names carry the ".tiff"
# extension and the grade column is converted to strings.
temp = train_data.assign(
    image_id=train_data['image_id'] + '.tiff',
    isup_grade=train_data['isup_grade'].astype(str),
)
datagen = ImageDataGenerator(rescale=1./255.0, validation_split=0.2)
# NOTE(review): no `subset=` is passed here although validation_split is
# configured above - confirm whether a validation subset was intended.
datagenerator = datagen.flow_from_dataframe(
    dataframe=temp,
    directory=IMAGES_DIR,
    x_col='image_id',
    y_col='isup_grade',
    batch_size=2,
)

In [None]:
# MobileNetV2 with ImageNet weights and without its top layers; the
# classifier layers are added by build_model.
model = tf.keras.applications.mobilenet_v2.MobileNetV2(
    input_shape=(256, 256, 3),
    include_top=False,
    weights='imagenet',
)
def build_model(pre_trained_model: tf.keras.Model, num_classes: int = 6) -> tf.keras.Model:
    """Attach a classification head to a frozen pre-trained model.

    Args:
        pre_trained_model: Model whose output is a 4-D feature map; every one
            of its layers is marked non-trainable.
        num_classes: Width of the final softmax layer (default 6).

    Returns:
        A model named ``'test_model'`` mapping the backbone input through
        global average pooling, a 512-unit ReLU layer and the softmax layer.
    """
    for layer in pre_trained_model.layers:
        layer.trainable = False

    # Use the single input/output tensors directly rather than wrapping the
    # `inputs` list in another list, which made the model's input structure
    # needlessly nested.
    x = pre_trained_model.output
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(512, activation='relu')(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    return tf.keras.Model(inputs=pre_trained_model.input, outputs=outputs, name='test_model')

model = build_model(model)

# All metrics are fed one-hot targets and softmax outputs. Accuracy must be
# the categorical variant: BinaryAccuracy would score each of the six outputs
# separately against 0/1 rather than comparing the predicted class.
METRICS = [
      tf.keras.metrics.TruePositives(name='tp'),
      tf.keras.metrics.FalsePositives(name='fp'),
      tf.keras.metrics.TrueNegatives(name='tn'),
      tf.keras.metrics.FalseNegatives(name='fn'),
      tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
      tf.keras.metrics.Precision(name='precision'),
      tf.keras.metrics.Recall(name='recall'),
      tf.keras.metrics.AUC(name='auc'),
      tf.keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]

# 'recall' is the metric name registered above, without a 'val_' prefix.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='recall', patience=2, mode='max', restore_best_weights=True, verbose=1
)

model.compile(optimizer=tf.keras.optimizers.Adam(),
             loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
             metrics=METRICS)

In [None]:
# Train the classifier for at most 100 epochs with the early_stopping callback.
history = model.fit(
    datagenerator,
    epochs=100,
    verbose=1,
    callbacks=[early_stopping],
)

# Segmentation

In [None]:
def build_model(input_shape=(224, 224, 3)):
    """Build the convolutional encoder/decoder with skip connections.

    Five stride-2 convolutions halve the spatial size each time; the decoder
    mirrors them with stride-2 transposed convolutions, and after every
    transposed convolution the result is concatenated with the encoder
    tensor (or the input) of the same resolution. The last layer emits three
    linear channels at the input resolution.

    Args:
        input_shape: ``(height, width, channels)`` of the input. Height and
            width must be multiples of 32 (five halvings), otherwise the
            skip connections would not line up. Defaults to the original
            fixed ``(224, 224, 3)``.

    Returns:
        An uncompiled ``tf.keras.Model``.

    Raises:
        ValueError: If height or width is not a multiple of 32.
    """
    height, width = input_shape[0], input_shape[1]
    for extent in (height, width):
        if extent is not None and extent % 32 != 0:
            raise ValueError(
                f'input height and width must be multiples of 32, got {input_shape}')

    inputs = layers.Input(input_shape)

    # Encoder: remember every stage for the skip connections.
    encoder_stages = []
    x = inputs
    for filters in (2, 4, 8, 16, 32):
        x = layers.Conv2D(filters, 32, 2, padding='same', activation='relu')(x)
        encoder_stages.append(x)

    # Bottleneck: stride 1 keeps the resolution of the deepest stage.
    x = layers.Conv2DTranspose(32, 32, 1, padding='same')(x)
    x = layers.Concatenate()([encoder_stages[-1], x])

    # Decoder: each step doubles the resolution and joins the tensor that had
    # that resolution on the way down (finally the raw input).
    skip_sources = encoder_stages[-2::-1] + [inputs]
    for filters, skip in zip((16, 8, 4, 2, 1), skip_sources):
        x = layers.Conv2DTranspose(filters, 32, 2, padding='same')(x)
        x = layers.Concatenate()([skip, x])

    outputs = layers.Conv2DTranspose(3, 32, 1, padding='same', activation='linear')(x)

    model = tf.keras.Model(inputs=[inputs], outputs=[outputs], name='auto_encoders_for_noise_removal')

    return model

model = build_model()

# summary() prints the table itself and returns None, so it is not wrapped
# in print().
model.summary()

model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.Huber(),
              metrics=['mae'])


In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that replaces the cell output after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Clear the cell output, show predictions and print the epoch number.

        ``epoch`` is printed as ``epoch + 1``.
        """
        # `clear_output` comes from IPython.display (imported with the other
        # notebook imports).
        clear_output(wait=True)
        # NOTE(review): `show_predictions` is not defined in the part of the
        # notebook visible here - confirm it exists before training.
        show_predictions()
        print('\nSample Prediction after epoch {}\n'.format(epoch + 1))

In [None]:
# The model compiled above produces a (224, 224, 3) image per input, so it
# needs (image, overlay) pairs. `train_datagen` was created with
# task='classification' and yields one-hot labels, which do not match that
# output, hence a dedicated segmentation generator is built here.
segmentation_datagen = DataGenerator(IMAGES_DIR,
                                     train_data,
                                     128,
                                     (224, 224),
                                     'segmentation',
                                     MASK_DIR)

model_history = model.fit(segmentation_datagen, epochs=50,
                          callbacks=[DisplayCallback()])