# UNETR: Transformers for 3D Medical Image Segmentation



# Methodology
* This notebook trains a **segmentation** model on **2.5D** images using `tf.data` and `TFRecord` files in `TensorFlow`.
* In a nutshell, **2.5D image training** treats a **3D** scan like a **2D** image: neighboring slices are stacked as channels, so the model can use the extra depth information much as it would use the channels of a typical RGB image. Here each 2.5D image is built from 3 channels (slices) taken with a stride of 2.
* The model is UNETR from **[UNETR: Transformers for 3D Medical Image Segmentation](https://arxiv.org/pdf/2103.10504.pdf)**; it is defined with `tf.keras` layers in the Model Architecture section of this notebook.
* The model has roughly 100M parameters to train. To use TPU capabilities, the dataset has to be stored as TFRecord files. I used the 2.5D image dataset created in this notebook by awsaf49: [UWMGI: 2.5D TFRecord Data](https://www.kaggle.com/code/awsaf49/uwmgi-2-5d-tfrecord-data).
* "TFRecord files are created using **StratifiedGroupFold** to avoid data leakage due to `case` and to stratify `empty` and `non-empty` mask cases".
* This notebook is compatible with both **GPU** and **TPU**. The device is selected automatically, so you don't have to allocate one yourself.
* Because the **Stomach**, **Large Bowel** & **Small Bowel** classes overlap, this is a **multi-label segmentation** task, so the final activation should be `sigmoid` instead of `softmax`.
* You can play with different models and losses.
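
To make the 2.5D idea concrete, here is a minimal NumPy sketch of stacking slices into channels. The helper name `make_25d`, the choice of slices `i, i+2, i+4`, and the clamping at the end of the volume are assumptions for illustration; the referenced dataset notebook may pick and pad its neighbors differently.

```python
import numpy as np

def make_25d(volume, index, stride=2, channels=3):
    """Stack `channels` slices spaced `stride` apart into one multi-channel image.

    `volume` has shape (depth, height, width). Indices past the last slice are
    clamped to the final slice (an assumption made for this sketch).
    """
    depth = volume.shape[0]
    idxs = [min(index + c * stride, depth - 1) for c in range(channels)]
    return np.stack([volume[i] for i in idxs], axis=-1)

# Toy scan: 10 slices of 4x4 pixels
volume = np.random.rand(10, 4, 4).astype(np.float32)
img = make_25d(volume, index=3)  # channels come from slices 3, 5 and 7
print(img.shape)
```

The result has the same layout as an RGB image, which is why a 2D-style input pipeline can consume it unchanged.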


# Reference Notebooks and Datasets 

**UNETR Model**:
* [UNETR](https://www.kaggle.com/code/usharengaraju/tensorflow-unetr-w-b)

**2.5D-TransUNet**:
* Train: [UWMGI: TransUnet 2.5D [Train] [TF]](https://www.kaggle.com/awsaf49/uwmgi-transunet-2-5d-train-tf/)
<!-- * Infer:  UWMGI: TransUnet 2.5D [Infer] [TF]-->

**Data/Dataset**:
* Dataset: [UWMGI: 2.5D TFRecord Dataset](https://www.kaggle.com/datasets/awsaf49/uwmgi-25d-tfrecord-dataset)

In [None]:
!pip install  segmentation_models

In [None]:

import os
import re
import math
import glob
import random
import shutil
from pathlib import Path
from typing import List, Tuple

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import tensorflow as tf
import tensorflow.keras.backend as K
from PIL import Image
from tqdm import tqdm
from skimage import io
from skimage.color import gray2rgb
from skimage.transform import resize
from kaggle_datasets import KaggleDatasets
#from rich.jupyter import print

# Set tf.keras as backend
os.environ['SM_FRAMEWORK'] = 'tf.keras'
import segmentation_models as sm


# **<span style="color:#F7B2B0;">Data Preprocessing</span>**

In [None]:
# Code copied from https://www.kaggle.com/code/ayuraj/quick-data-eda-segmentation-viz-using-w-b

ROOT_DIR = '../input/uw-madison-gi-tract-image-segmentation/'
df = pd.read_csv(ROOT_DIR+'train.csv')
# Remove rows with NaN Segmentation masks
df = df[df.segmentation.notna()].reset_index(drop=False)
def get_case_str(row):
    case_num = row.id.split('_')[0]
    return case_num

def get_case_id(row):
    case_num = row.id.split('_')[0]
    return int(case_num[4:])

df['case_str'] = df.apply(lambda row: get_case_str(row), axis=1)
df['case_id'] = df.apply(lambda row: get_case_id(row), axis=1)

def get_day_str(row):
    return row.id.split('_')[1]

def get_day_id(row):
    return int(row.id.split('_')[1][3:])

df['day_str'] = df.apply(lambda row: get_day_str(row), axis=1)
df['day_id'] = df.apply(lambda row: get_day_id(row), axis=1)

def get_slice_str(row):
    slice_id = row.id.split('_')[-1]
    return f'slice_{slice_id}'

df['slice_str'] = df.apply(lambda row: get_slice_str(row), axis=1)
filepaths = glob.glob(ROOT_DIR+'train/*/*/*/*')


# Collect one metadata row per image file, then build the DataFrame in one go
rows = []
for filepath in tqdm(filepaths):
    # Path layout: .../train/<case>/<case>_<day>/scans/<filename>
    case_day_str = Path(filepath).parts[-3]
    case_str, day_str = case_day_str.split('_')

    filename = os.path.basename(filepath)
    slice_id = filename.split('_')[1]
    slice_str = f'slice_{slice_id}'

    rows.append([case_str, day_str, slice_str, filename, filepath])

file_df = pd.DataFrame(rows, columns=['case_str', 'day_str', 'slice_str', 'filename', 'filepath'])

df = pd.merge(df, file_df, on=['case_str', 'day_str', 'slice_str'])

def get_image_height(row):
    return int(row.filename[:-4].split('_')[2])
    
def get_image_width(row):
    return int(row.filename[:-4].split('_')[3])

def get_pixel_height(row):
    return float(row.filename[:-4].split('_')[4])

def get_pixel_width(row):
    return float(row.filename[:-4].split('_')[5])

df['img_height'] = df.apply(lambda row: get_image_height(row), axis=1)
df['img_width'] = df.apply(lambda row: get_image_width(row), axis=1)
df['pixel_height (mm)'] = df.apply(lambda row: get_pixel_height(row), axis=1)
df['pixel_width (mm)'] = df.apply(lambda row: get_pixel_width(row), axis=1)

df.drop('index', axis=1, inplace=True)

by_case = df.groupby('case_str')
case_df = by_case.get_group('case123')

by_day = case_df.groupby('day_str')
day_df = by_day.get_group('day0')

by_slice = day_df.groupby('slice_str')
slice_df = by_slice.get_group('slice_0075')

# Save the full metadata table
df.to_csv('df.csv')
# Save the single example slice selected above
slice_df.to_csv('slice_df.csv')

In [None]:
class CFG:

    debug = False   
   
    verbose = 0
    display_plot = True

    # Device for training
    device = None  # device is automatically selected

    # Seeding for reproducibility
    seed = 101

    # Image Size
    img_size = [96, 96]

    # Batch Size & Epochs
    batch_size = 2
    drop_remainder = False
    epochs = 15
    steps_per_execution = None

    # Model & Backbone
    model_name = "UNETR"
    backbone = None
    
    # Loss & Optimizer & LR Scheduler
    loss = "dice_loss"
    optimizer = "Adam"
    lr = 5e-4
    lr_schedule = "CosineDecay"
    patience = 5
   
    # Clip values to [0, 1]
    clip = False
    
    # Number of folds
    folds = 5

    # Which Folds to train
    selected_folds = [0, 1, 2, 3, 4]
    
    # Augmentation
    augment = False
    transform = False

def count_data_items(filenames):
    # The sample count of each TFRecord file is encoded in its name, between "-" and "."
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)


# **<span style="color:#F7B2B0;">Set up device</span>**

In [None]:
try: # detect TPUs
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError: # detect GPUs
    strategy = tf.distribute.MirroredStrategy() # for CPU/GPU or multi-GPU machines
    

In [None]:
#strategy, CFG.device, tpu = configure_device()
AUTO = tf.data.experimental.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

In [None]:
BASE_PATH = '/kaggle/input/uw-madison-gi-tract-image-segmentation'
GCS_PATH = KaggleDatasets().get_gcs_path('uwmgi-25d-tfrecord-dataset')
ALL_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/uwmgi/*.tfrec')
print('NUM TFRECORD FILES: {:,}'.format(len(ALL_FILENAMES)))
print('NUM TRAINING IMAGES: {:,}'.format(count_data_items(ALL_FILENAMES)))

# **<span style="color:#F7B2B0;">Input Data Pipeline</span>**

In [None]:
# Decode image from bytestring to tensor
def decode_image(data, height, width, target_size=CFG.img_size):
    img = tf.io.decode_raw(data, out_type=tf.uint16)
    img = tf.reshape(img, [height, width, 3])  # explicit size needed for TPU
    img = tf.cast(img, tf.float32)
    img = tf.math.divide_no_nan(img, tf.math.reduce_max(img))  # scale image to [0, 1]
    img = tf.image.resize_with_pad(
        img, target_size[0], target_size[1], method="nearest"
    )  # resize with pad to avoid distortion
    img = tf.reshape(img, [*target_size, 3])  # reshape after resize
    return img


# Decode mask from bytestring to tensor
def decode_mask(data, height, width, target_size=CFG.img_size):
    msk = tf.io.decode_raw(data, out_type=tf.uint8)
    msk = tf.reshape(msk, [height, width, 3])  # explicit size needed for TPU
    msk = tf.cast(msk, tf.float32)
    msk = msk / 255.0  # scale mask data to[0, 1]
    msk = tf.image.resize_with_pad(
        msk, target_size[0], target_size[1], method="nearest"
    )
    msk = tf.reshape(msk, [*target_size, 3])  # reshape after resize
    return msk


# Read tfrecord data & parse it
# (the `augment` flag is accepted for API compatibility; no augmentation is applied here)
def read_tfrecord(example, augment=True, return_id=False, dim=CFG.img_size):
    tfrec_format = {
        "id": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "height": tf.io.FixedLenFeature([], tf.int64),
        "width": tf.io.FixedLenFeature([], tf.int64),
        "mask": tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(
        example, tfrec_format
    )  # parses a single example proto.
    image_id = example["id"]
    height = example["height"]
    width = example["width"]
    img = decode_image(example["image"], height, width, dim)  # access image
    msk = decode_mask(example["mask"], height, width, dim)  # access mask
    # Repeat the image along a new depth axis to form the cubic volume the
    # 3D model expects: (H, W, 3) -> (H, W, D, 3), with D = dim[0]
    img = tf.repeat(img[:, :, tf.newaxis, :], dim[0], axis=2)
    msk = tf.repeat(msk[:, :, tf.newaxis, :], dim[0], axis=2)
    return (img, msk) if not return_id else (img, image_id, msk)

def get_dataset(
    filenames,
    shuffle=False,
    repeat=False,
    augment=False,
    cache=False,
    return_id=False,
    batch_size=CFG.batch_size,
    target_size=CFG.img_size,
    drop_remainder=False,
    seed=CFG.seed,
):
    dataset = tf.data.TFRecordDataset(filenames)  # read tfrecord files
    if cache:
        dataset = dataset.cache()  # cache the raw records in memory
    if repeat:
        dataset = dataset.repeat()  # loop indefinitely (pair with steps_per_epoch)
    if shuffle:
        dataset = dataset.shuffle(1024, seed=seed)  # buffer size is a tunable choice
    dataset = dataset.map(
        lambda x: read_tfrecord(
            x,
            augment=augment,  # parse tfrecord data with masks
            return_id=return_id,
            dim=target_size,
        ),
        num_parallel_calls=AUTO,
    )
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)  # batch the data
    dataset = dataset.prefetch(AUTO)  # prefetch data for speedup
    return dataset

ds = get_dataset(ALL_FILENAMES, augment=False, cache=False, repeat=False)

# **<span style="color:#F7B2B0;">Model Architecture</span>**


![](https://i.imgur.com/eUyPwD0.png)

 
UNETR follows a contracting-expanding pattern: a stack of transformers serves as the encoder and is connected to a decoder via skip connections. A 3D input volume x ∈ R^(H×W×D×C), with resolution (H, W, D) and C input channels, is turned into a 1D sequence by dividing it into flattened, uniform, non-overlapping patches x_v ∈ R^(N×(P^3·C)), where (P, P, P) is the resolution of each patch and N = (H×W×D)/P^3 is the length of the sequence. The patches are then projected into a K-dimensional embedding space with a linear layer, and a 1D positional embedding is added. After the embedding, a stack of transformer blocks is applied, each consisting of multi-head self-attention (MSA) and multilayer perceptron (MLP) sublayers.
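
As a worked instance of the patch arithmetic above, the following NumPy sketch splits a volume into non-overlapping cubes and flattens each one. The function name `to_patch_sequence` is hypothetical and exists only for this illustration; in practice a `Conv3D` whose kernel size and stride both equal P performs the patch split and the linear projection in a single step.

```python
import numpy as np

def to_patch_sequence(x, p):
    """Split an (H, W, D, C) volume into flattened, non-overlapping p*p*p patches."""
    h, w, d, c = x.shape
    assert h % p == 0 and w % p == 0 and d % p == 0, "each side must be divisible by p"
    x = x.reshape(h // p, p, w // p, p, d // p, p, c)
    x = x.transpose(0, 2, 4, 1, 3, 5, 6)  # (H/P, W/P, D/P, P, P, P, C)
    return x.reshape(-1, p * p * p * c)   # (N, P^3 * C)

# With H = W = D = 96, C = 3 and P = 16: N = 96^3 / 16^3 = 216 tokens,
# each holding 16^3 * 3 = 12288 raw values before projection
x = np.zeros((96, 96, 96, 3), dtype=np.float32)
seq = to_patch_sequence(x, 16)
print(seq.shape)  # (216, 12288)
```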
    


In [None]:
class SingleDeconv3DBlock(tf.keras.layers.Layer):

    def __init__(self,filters):
        super(SingleDeconv3DBlock, self).__init__()
        self.block = tf.keras.layers.Conv3DTranspose(filters= filters, 
                                                     kernel_size=2, strides=2, 
                                                     padding="valid", 
                                                     output_padding=None)
                                                     

    def call(self, inputs):        
        return self.block(inputs)



class SingleConv3DBlock(tf.keras.layers.Layer):

    def __init__(self, filters,kernel_size):
        super(SingleConv3DBlock, self).__init__()
        self.kernel=kernel_size
        self.res = tuple(map(lambda i: (i - 1)//2, self.kernel))
        self.block = tf.keras.layers.Conv3D(filters= filters, 
                                            kernel_size=kernel_size, 
                                            strides=1, 
                                            padding='same')

    def call(self, inputs):
        return self.block(inputs)
    
class Conv3DBlock(tf.keras.layers.Layer):

    def __init__(self, filters,kernel_size=(3,3,3)):
        super(Conv3DBlock, self).__init__()
        self.a= tf.keras.Sequential([
                                     SingleConv3DBlock(filters,kernel_size=kernel_size),
                                     tf.keras.layers.BatchNormalization(),
                                     tf.keras.layers.Activation('relu')
        ])
        

    def call(self, inputs):
        return self.a(inputs)
    
class Deconv3DBlock(tf.keras.layers.Layer):

    def __init__(self, filters,kernel_size=(3,3,3)):
        super(Deconv3DBlock, self).__init__()
        self.a= tf.keras.Sequential([
                                     SingleDeconv3DBlock(filters=filters),
                                     SingleConv3DBlock(filters=filters,kernel_size=kernel_size),
                                     tf.keras.layers.BatchNormalization(),
                                     tf.keras.layers.Activation('relu')
        ])
  
    def call(self, inputs):
        return self.a(inputs)
    




The MLP comprises two linear layers with GELU activation functions. In the paper's notation, i is the intermediate block identifier and L is the number of transformer layers. An MSA sublayer comprises parallel self-attention (SA) heads. The SA block is a parameterized function that learns the mapping between a query (q) and the corresponding key (k) and value (v) representations in a sequence z. The attention weights (A) are computed by measuring the similarity between two elements in z and their key-value pairs, normalized with a softmax function.
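
The attention computation described above can be sketched for a single head in plain NumPy. The weight matrices `wq`, `wk`, `wv` and the toy sizes are illustrative assumptions; a real layer adds multiple heads, an output projection and dropout.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(z, wq, wk, wv):
    """Single-head scaled dot-product self-attention over a sequence z of shape (N, dim)."""
    q, k, v = z @ wq, z @ wk, z @ wv
    scores = q @ k.T / np.sqrt(k.shape[-1])  # pairwise query-key similarity, scaled
    a = softmax(scores, axis=-1)             # attention weights A; each row sums to 1
    return a @ v, a

rng = np.random.default_rng(0)
n, dim, head_dim = 5, 8, 4
z = rng.normal(size=(n, dim))
wq, wk, wv = (rng.normal(size=(dim, head_dim)) for _ in range(3))
out, a = self_attention(z, wq, wk, wv)
print(out.shape, a.shape)  # (5, 4) (5, 5)
```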
    


In [None]:
class SelfAttention(tf.keras.layers.Layer):

    def __init__(self, num_heads,embed_dim,dropout):
        super(SelfAttention, self).__init__()

        self.num_attention_heads = num_heads
        self.attention_head_size = int(embed_dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query=tf.keras.layers.Dense(self.all_head_size)
        self.key = tf.keras.layers.Dense(self.all_head_size)
        self.value = tf.keras.layers.Dense(self.all_head_size)                

        self.out=tf.keras.layers.Dense(embed_dim)
        self.attn_dropout=tf.keras.layers.Dropout(dropout)
        self.proj_dropout=tf.keras.layers.Dropout(dropout)

        self.softmax=tf.keras.layers.Softmax()

        self.vis=False

    def transpose_for_scores(self,x):
        new_x_shape=list(x.shape[:-1] + (self.num_attention_heads, self.attention_head_size))
        new_x_shape[0] = -1
        y = tf.reshape(x, new_x_shape)
        return tf.transpose(y,perm=[0,2,1,3])

    def call(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)  
        attention_scores= query_layer @ tf.transpose(key_layer,perm=[0,1,3,2])
        attention_scores= attention_scores/math.sqrt(self.attention_head_size)
        attention_probs=self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer= attention_probs @ value_layer
        context_layer=tf.transpose( context_layer, perm=[0,2,1,3])
        new_context_layer_shape = list(context_layer.shape[:-2] + (self.all_head_size,))
        new_context_layer_shape[0]= -1
        context_layer = tf.reshape(context_layer,new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights        

In [None]:
class Mlp(tf.keras.layers.Layer):

    def __init__(self, output_features, drop=0.):
        super(Mlp, self).__init__()
        self.a=tf.keras.layers.Dense(units=output_features,activation=tf.nn.gelu)
        self.b=tf.keras.layers.Dropout(drop)

    def call(self, inputs):
        x=self.a(inputs)
        return self.b(x)

class PositionwiseFeedForward(tf.keras.layers.Layer):

    def __init__(self, d_model=768,d_ff=2048, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.a=tf.keras.layers.Dense(units=d_ff)
        self.b=tf.keras.layers.Dense(units=d_model)
        self.c=tf.keras.layers.Dropout(dropout)

    def call(self, inputs):
        return self.b(self.c(tf.nn.relu(self.a(inputs))))

## Patch embeddings (projection_dim = embed_dim)
class PatchEmbedding(tf.keras.layers.Layer): 
  def __init__(self ,  cube_size, patch_size , embed_dim):
        super(PatchEmbedding, self).__init__()
        self.num_of_patches=int((cube_size[0]*cube_size[1]*cube_size[2])/(patch_size*patch_size*patch_size))
        self.patch_size=patch_size
        self.size = patch_size
        self.embed_dim = embed_dim

        self.projection = tf.keras.layers.Dense(embed_dim)

        # Note: this token is created but never used in call()
        self.clsToken = tf.Variable(tf.keras.initializers.GlorotNormal()(shape=(1 , 512 , embed_dim)) , trainable=True)

        # Learned 1D positional embedding, one vector per patch
        self.positionalEmbedding = tf.keras.layers.Embedding(self.num_of_patches , embed_dim)
        self.patches=None
        # Strided Conv3D: cuts the volume into non-overlapping patches and embeds each one
        self.lyer = tf.keras.layers.Conv3D(filters= self.embed_dim,kernel_size=self.patch_size, strides=self.patch_size,padding='valid')
      
  def call(self , inputs):
        patches =self.lyer(inputs)  # (batch, H/P, W/P, D/P, embed_dim)
        # Flatten the patch grid into a sequence. The last axis is the Conv3D filter
        # count (embed_dim); size*size*3 only matched it by coincidence (16*16*3 = 768)
        patches = tf.reshape(patches , (tf.shape(inputs)[0] , -1 , self.embed_dim))
        patches = self.projection(patches)
        positions = tf.range(0 , self.num_of_patches , 1)[tf.newaxis , ...]
        positionalEmbedding = self.positionalEmbedding(positions)
        patches = patches + positionalEmbedding

        return patches, positionalEmbedding


A sequence representation z_i (i ∈ {3, 6, 9, 12}) of size (H×W×D/P^3) × K is extracted from the transformer and reshaped into a (H/P) × (W/P) × (D/P) × K tensor. At the bottleneck of the encoder (i.e. the output of the transformer's last layer), a deconvolutional layer is applied to the transformed feature map to increase its resolution by a factor of 2. The resized feature map is then concatenated with the feature map of the previous transformer output (e.g. z_9) and fed into consecutive 3×3×3 convolutional layers, and the output is upsampled using a deconvolutional layer. This process is repeated for all subsequent layers up to the original input resolution, where the final output is fed into a 1×1×1 convolutional layer. In the paper this last layer uses a softmax activation to generate voxel-wise semantic predictions; as noted in the Methodology section, the multi-label task in this notebook calls for `sigmoid` instead.
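
The resolution bookkeeping in the paragraph above can be checked with a few lines of plain Python. This is only a sketch of the arithmetic for a cubic input; the helper name `decoder_resolutions` is hypothetical.

```python
def decoder_resolutions(input_size=96, patch_size=16):
    """Return the patch-grid side length and the side length after each 2x upsampling."""
    grid = input_size // patch_size  # side of the (H/P, W/P, D/P) token grid
    stages = []
    res = grid
    while res < input_size:
        res *= 2                     # each deconvolution doubles the resolution
        stages.append(res)
    return grid, stages

grid, stages = decoder_resolutions()
print(grid, stages)  # 6 [12, 24, 48, 96]
```

With a 96-voxel side and P = 16, four doublings bring the 6×6×6 grid back to full resolution. That is why a skip connection taken from an earlier transformer layer needs more upsampling steps of its own before it can be concatenated: one for z_9, two for z_6 and three for z_3.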
    

In [None]:
##transformerblock
class TransformerLayer(tf.keras.layers.Layer):
    def __init__(self ,  embed_dim, num_heads ,dropout, cube_size, patch_size):
      super(TransformerLayer,self).__init__()

      self.attention_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)

      self.mlp_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)

#embed_dim/no-of_heads
      self.mlp_dim = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size))
      
      self.mlp = PositionwiseFeedForward(embed_dim,2048)
      self.attn = SelfAttention(num_heads, embed_dim, dropout)
    
    def call(self ,x  , training=True):
      h=x
      x=self.attention_norm(x)
      x,weights= self.attn(x)
      x=x+h
      h=x

      x = self.mlp_norm(x)
      x = self.mlp(x)

      x = x + h

      return x, weights


class TransformerEncoder(tf.keras.layers.Layer):
  def __init__(self ,embed_dim , num_heads,cube_size, patch_size , num_layers=12 , dropout=0.1,extract_layers=[3,6,9,12]):
    super(TransformerEncoder,self).__init__()
#  embed_dim, num_heads ,dropout, cube_size, patch_size
    self.embeddings = PatchEmbedding(cube_size,patch_size, embed_dim)
    self.extract_layers =extract_layers
    self.encoders = [TransformerLayer(embed_dim, num_heads,dropout, cube_size, patch_size) for _ in range(num_layers)]
  
  def call(self , inputs , training=True):
    extract_layers = []
    x = inputs
    x,_=self.embeddings(x)
    
    for depth,layer in enumerate(self.encoders):
      x,_= layer(x , training=training)
      if depth + 1 in self.extract_layers:
                extract_layers.append(x)
    
    return extract_layers

In [None]:
class UNETR(tf.keras.Model):
    def __init__(self, img_shape=(96,96, 96), input_dim=3, output_dim=3, embed_dim=768, patch_size=16, num_heads=12, dropout=0.1):
        super(UNETR,self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embed_dim = embed_dim
        self.img_shape = img_shape
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.num_layers = 12
        self.ext_layers = [3, 6, 9, 12]
        
        self.patch_dim = [int(x / patch_size) for x in img_shape]
        self.transformer = \
            TransformerEncoder(
                self.embed_dim,
                self.num_heads,
                self.img_shape,
                self.patch_size,
                self.num_layers,
                self.dropout,
                self.ext_layers
            )
        
        # U-Net Decoder
        self.decoder0 = \
            tf.keras.Sequential([
                Conv3DBlock(32, (3,3,3)),
                Conv3DBlock(64, (3,3,3))]
            )
      
        self.decoder3 = \
            tf.keras.Sequential([
                Deconv3DBlock(512),
                Deconv3DBlock(256),
                Deconv3DBlock(128)]
            )
   
        self.decoder6 = \
            tf.keras.Sequential([
                Deconv3DBlock(512),
                Deconv3DBlock(256)]
            )
    
        self.decoder9 = \
            Deconv3DBlock(512)

        self.decoder12_upsampler = \
            SingleDeconv3DBlock(512)

        self.decoder9_upsampler = \
            tf.keras.Sequential([
                Conv3DBlock(512),
                Conv3DBlock(512),
                Conv3DBlock(512),
                SingleDeconv3DBlock(256)]
            )

        self.decoder6_upsampler = \
            tf.keras.Sequential([
                Conv3DBlock(256),
                Conv3DBlock(256),
                SingleDeconv3DBlock(128)]
            )

        self.decoder3_upsampler = \
            tf.keras.Sequential(
                [Conv3DBlock(128),
                Conv3DBlock(128),
                SingleDeconv3DBlock(64)]
            )

        self.decoder0_header = \
            tf.keras.Sequential(
                [Conv3DBlock(64),
                Conv3DBlock(64),
                SingleConv3DBlock(output_dim, (1,1,1)),
                # Sigmoid for multi-label output: the classes can overlap
                tf.keras.layers.Activation('sigmoid')]
            ) 

 
    def call(self, x):
        z = self.transformer(x)
        z0, z3, z6, z9, z12 = x, z[0],z[1],z[2],z[3]
        # Tokens have shape (batch, N, embed_dim) and Keras is channels-last, so they
        # reshape directly onto the (H/P, W/P, D/P) patch grid. Transposing first
        # (as channels-first PyTorch code does) would scramble the spatial layout.
        z3 = tf.reshape(z3, [-1, *self.patch_dim, self.embed_dim])
        z6 = tf.reshape(z6, [-1, *self.patch_dim, self.embed_dim])
        z9 = tf.reshape(z9, [-1, *self.patch_dim, self.embed_dim])
        z12 = tf.reshape(z12, [-1, *self.patch_dim, self.embed_dim])
        z12 = self.decoder12_upsampler(z12)
        z9 = self.decoder9(z9)
        z9 = self.decoder9_upsampler(tf.concat([z9, z12], 4))
        z6 = self.decoder6(z6)
        z6 = self.decoder6_upsampler(tf.concat([z6, z9], 4))
        z3 = self.decoder3(z3)
        z3 = self.decoder3_upsampler(tf.concat([z3, z6], 4))
        z0 = self.decoder0(z0)
        output = self.decoder0_header(tf.concat([z0, z3], 4))
        return output
        

# **<span style="color:#F7B2B0;">Loss Functions</span>**

In [None]:
from segmentation_models.base import functional as F
import tensorflow.keras.backend as K

kwargs = {}
kwargs["backend"] = K  # set tensorflow.keras as backend


def dice_coef(y_true, y_pred):
    """Dice coefficient"""
    dice = F.f_score(
        y_true,
        y_pred,
        beta=1,
        smooth=1e-5,
        per_image=False,
        threshold=0.5,
        **kwargs,
    )
    return dice

def tversky(y_true, y_pred, axis=(0, 1, 2), alpha=0.3, beta=0.7, smooth=0.0001):
    "Tversky metric"
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    tp = tf.math.reduce_sum(y_true * y_pred, axis=axis) # calculate True Positive
    fn = tf.math.reduce_sum(y_true * (1 - y_pred), axis=axis) # calculate False Negative
    fp = tf.math.reduce_sum((1 - y_true) * y_pred, axis=axis) # calculate False Positive
    tv = (tp + smooth) / (tp + alpha * fn + beta * fp + smooth) # calculate tversky
    tv = tf.math.reduce_mean(tv)
    return tv


def tversky_loss(y_true, y_pred):
    "Tversky Loss"
    return 1 - tversky(y_true, y_pred)


def focal_tversky_loss(y_true, y_pred, gamma=0.75):
    "Focal Tversky Loss: Focal Loss + Tversky Loss"
    tv = tversky(y_true, y_pred)
    return K.pow((1 - tv), gamma)


# Register custom objects
custom_objs = {
    "dice_loss": sm.losses.dice_loss,
    "dice_coef": dice_coef,
    "bce_dice_loss": sm.losses.bce_dice_loss,
    "bce_jaccard_loss": sm.losses.bce_jaccard_loss,
    "tversky_loss": tversky_loss,
    "focal_tversky_loss": focal_tversky_loss,
    "jaccard_loss": sm.losses.jaccard_loss,
    "precision": sm.metrics.precision,
    "recall": sm.metrics.recall,
}
tf.keras.utils.get_custom_objects().update(custom_objs)
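
To see how the Tversky weights behave, here is a small NumPy sketch on a toy mask. It sums over all elements at once, which is a simplification of the per-axis reduction used in `tversky` above, and the helper name `tversky_index` is hypothetical. With `alpha = beta = 0.5` the index reduces to the Dice coefficient; unequal weights shift the penalty between false negatives and false positives.

```python
import numpy as np

def tversky_index(y_true, y_pred, alpha=0.3, beta=0.7, smooth=1e-4):
    """Tversky index with alpha weighting false negatives and beta false positives."""
    tp = np.sum(y_true * y_pred)
    fn = np.sum(y_true * (1 - y_pred))
    fp = np.sum((1 - y_true) * y_pred)
    return (tp + smooth) / (tp + alpha * fn + beta * fp + smooth)

y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype=np.float32)
y_pred = np.array([1, 1, 1, 0, 1, 1, 0, 0], dtype=np.float32)  # tp=3, fn=1, fp=2

dice = 2 * 3 / (2 * 3 + 1 + 2)                      # Dice = 2*tp / (2*tp + fn + fp)
tv_balanced = tversky_index(y_true, y_pred, 0.5, 0.5)
tv_weighted = tversky_index(y_true, y_pred)         # alpha=0.3, beta=0.7
print(dice, tv_balanced, tv_weighted)
```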


# **<span style="color:#F7B2B0;">Callbacks - LR schedule</span>**

In [None]:
def get_lr_callback():
    if CFG.lr_schedule == "ReduceLROnPlateau":
        lr_schedule = tf.keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.1,
            patience=int(CFG.patience / 2),
            min_lr=CFG.lr / 1e2,
        )
    elif CFG.lr_schedule == "CosineDecay":
        # Note: `alpha` is the final LR expressed as a fraction of the initial LR
        lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
            initial_learning_rate=CFG.lr, decay_steps=CFG.epochs + 2, alpha=CFG.lr / 1e2
        )
        lr_schedule = tf.keras.callbacks.LearningRateScheduler(lr_schedule, verbose=0)
    elif CFG.lr_schedule == "ExponentialDecay":
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=CFG.lr,
            decay_steps=CFG.epochs + 2,
            decay_rate=0.05,
            staircase=False,
        )
        lr_schedule = tf.keras.callbacks.LearningRateScheduler(lr_schedule, verbose=0)
    else:
        raise ValueError(f"Unknown lr_schedule: {CFG.lr_schedule}")
    return lr_schedule
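
For intuition about the cosine schedule, the sketch below re-implements the decay formula in plain Python: `lr = initial_lr * ((1 - alpha) * 0.5 * (1 + cos(pi * step / decay_steps)) + alpha)`. The defaults mirror `CFG.lr = 5e-4` and `decay_steps = CFG.epochs + 2 = 17`; `alpha = 0.0` is chosen here only to keep the numbers simple and is not the value used in the callback.

```python
import math

def cosine_decay(step, initial_lr=5e-4, decay_steps=17, alpha=0.0):
    """Cosine-decayed learning rate; `alpha` is the floor as a fraction of initial_lr."""
    step = min(step, decay_steps)  # the rate stays at the floor after decay_steps
    cosine = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
    return initial_lr * ((1 - alpha) * cosine + alpha)

# Learning rate at the start of each of the 15 training epochs
for epoch in range(15):
    print(epoch, f"{cosine_decay(epoch):.2e}")
```

Because the callback passes the epoch index as the step, the rate changes once per epoch rather than once per batch.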

# **<span style="color:#F7B2B0;">Model Summary</span>**

In [None]:
def get_model(name=CFG.model_name, loss=CFG.loss, backbone=CFG.backbone):
    #model = TransUNet(image_size=CFG.img_size[0], freeze_enc_cnn=False, pretrain=True)
    model = UNETR()
    lr = CFG.lr
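    if CFG.optimizer == "Adam":
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
    elif CFG.optimizer == "AdamW":
        import tensorflow_addons as tfa  # optional dependency, only needed for this branch
        opt = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=lr)
    elif CFG.optimizer == "RectifiedAdam":
        import tensorflow_addons as tfa  # optional dependency, only needed for this branch
        opt = tfa.optimizers.RectifiedAdam(learning_rate=lr)
    else:
        raise ValueError("Wrong Optimizer Name")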

    model.compile(
        optimizer=opt,
        loss=loss,
        steps_per_execution=CFG.steps_per_execution, # to reduce idle time
        metrics=[
            dice_coef,
            "precision",
            "recall",
        ],
    )
    return model

In [None]:
#model = tf.keras.models.load_model('../input/unetr-model/unetr')
model = get_model()
#model.summary()

In [None]:
from sklearn.model_selection import KFold

M = {}
# Which Metrics to store
metrics = [
    "loss",
    "dice_coef",
    "precision",
    "recall",
]
# Initialize Metrics
for fm in metrics:
    M["val_" + fm] = []

ALL_FILENAMES = sorted(ALL_FILENAMES)

# Split tfrecord files using KFold
kf = KFold(n_splits=CFG.folds, shuffle=True, random_state=CFG.seed) # kfold between tfrecord files
for fold, (train_idx, valid_idx) in enumerate(kf.split(ALL_FILENAMES)):
    # If fold is not in selected folds then avoid that fold
    if fold not in CFG.selected_folds:
        continue
        
    # Train and validation files
    TRAIN_FILENAMES = [ALL_FILENAMES[i] for i in train_idx]
    VALID_FILENAMES = [ALL_FILENAMES[i] for i in valid_idx]
    
    # Take Only 10 Files if run in Debug Mode
    if CFG.debug:
        TRAIN_FILENAMES = TRAIN_FILENAMES[:10]
        VALID_FILENAMES = VALID_FILENAMES[:10]

    # Shuffle train files
    random.shuffle(TRAIN_FILENAMES)

    # Count train and valid samples
    NUM_TRAIN = count_data_items(TRAIN_FILENAMES)
    NUM_VALID = count_data_items(VALID_FILENAMES)

    # Compute batch size & steps_per_epoch
    BATCH_SIZE = CFG.batch_size * REPLICAS
    STEPS_PER_EPOCH = NUM_TRAIN // BATCH_SIZE

    print("#" * 65)
    print("#### FOLD:", fold)
    print(
        "#### IMAGE_SIZE: (%i, %i) | BATCH_SIZE: %i | EPOCHS: %i"
        % (CFG.img_size[0], CFG.img_size[1], BATCH_SIZE, CFG.epochs)
    )
    print(
        "#### MODEL: %s | BACKBONE: %s | LOSS: %s"
        % (CFG.model_name, CFG.backbone, CFG.loss)
    )
    print("#### NUM_TRAIN: {:,} | NUM_VALID: {:,}".format(NUM_TRAIN, NUM_VALID))
    print("#" * 65)

   
    # Build model in device
    K.clear_session()
    with strategy.scope():
        model = get_model(name=CFG.model_name, backbone=CFG.backbone, loss=CFG.loss)

    # Callbacks
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        "/kaggle/working/fold-%i.h5" % fold,
        verbose=CFG.verbose,
        monitor="val_dice_coef",
        mode="max",
        save_best_only=True,
        save_weights_only=True,
    )
    callbacks = [checkpoint, get_lr_callback()]

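    # Create train & valid dataset
    train_ds = get_dataset(
        TRAIN_FILENAMES,
        shuffle=True,
        repeat=True,  # fit() below uses steps_per_epoch, so the training stream should repeat
        augment=CFG.augment,
        batch_size=BATCH_SIZE,
        cache=False,
        drop_remainder=False,
    )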
    valid_ds = get_dataset(
        VALID_FILENAMES,
        shuffle=False,
        augment=False,
        repeat=False,
        batch_size=BATCH_SIZE,
        cache=False,
        drop_remainder=False,
    )

    # Train model
    history = model.fit(
        train_ds,
        epochs=CFG.epochs if not CFG.debug else 2,
        steps_per_epoch=STEPS_PER_EPOCH,
        callbacks=callbacks,
        validation_data=valid_ds,
        #         validation_steps = NUM_VALID/BATCH_SIZE,
        verbose=CFG.verbose,
    )

    # Convert dict history to df history
    history = pd.DataFrame(history.history)

    # Load best weights
    model.load_weights("/kaggle/working/fold-%i.h5" % fold)

    # Compute & save best valid result
    print("\nValid Result:")
    m = model.evaluate(
        get_dataset(
            VALID_FILENAMES,
            batch_size=BATCH_SIZE,
            augment=False,
            shuffle=False,
            repeat=False,
            cache=False,
        ),
        return_dict=True,
#        steps=NUM_VALID/BATCH_SIZE,
        verbose=1,
    )
    print()
    
    # Store valid results
    for fm in metrics:
        M["val_" + fm].append(m[fm])
        
 
    # Plot Training History
    if CFG.display_plot:
        plt.figure(figsize=(15, 5))
        plt.plot(
            np.arange(len(history["dice_coef"])),
            history["dice_coef"],
            "-o",
            label="Train Dice",
            color="#ff7f0e",
        )
        plt.plot(
            np.arange(len(history["dice_coef"])),
            history["val_dice_coef"],
            "-o",
            label="Val Dice",
            color="#1f77b4",
        )
        x = np.argmax(history["val_dice_coef"])
        y = np.max(history["val_dice_coef"])
        xdist = plt.xlim()[1] - plt.xlim()[0]
        ydist = plt.ylim()[1] - plt.ylim()[0]
        plt.scatter(x, y, s=200, color="#1f77b4")
        plt.text(x - 0.03 * xdist, y - 0.13 * ydist, "max dice\n%.2f" % y, size=14)
        plt.ylabel("dice_coef", size=14)
        plt.xlabel("Epoch", size=14)
        plt.legend(loc=2)
        plt2 = plt.gca().twinx()
        plt2.plot(
            np.arange(len(history["dice_coef"])),
            history["loss"],
            "-o",
            label="Train Loss",
            color="#2ca02c",
        )
        plt2.plot(
            np.arange(len(history["dice_coef"])),
            history["val_loss"],
            "-o",
            label="Val Loss",
            color="#d62728",
        )
        x = np.argmin(history["val_loss"])
        y = np.min(history["val_loss"])
        ydist = plt.ylim()[1] - plt.ylim()[0]
        plt.scatter(x, y, s=200, color="#d62728")
        plt.text(x - 0.03 * xdist, y + 0.05 * ydist, "min loss", size=14)
        plt.ylabel("Loss", size=14)
        plt.title("FOLD %i" % (fold), size=18)
        plt.legend(loc=3)
        plt.savefig(f"fig-{fold}.png")
        plt.show()

# **<span style="color:#F7B2B0;">Making Predictions</span>**

In [None]:
#pred = model.predict(ds.skip(200).take(1))



## **<span style="color:#F7B2B0;">References</span>**

https://www.kaggle.com/code/usharengaraju/tensorflow-unetr-w-b

https://arxiv.org/pdf/2103.10504.pdf

https://github.com/tamasino52/UNETR (PyTorch)

https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d_lightning.ipynb (PyTorch)

https://www.kaggle.com/code/awsaf49/uwmgi-transunet-2-5d-train-tf

https://www.kaggle.com/datasets/awsaf49/uwmgi-25d-tfrecord-dataset

https://www.kaggle.com/code/bsridatta/eda-for-a-healthy-gi-tract

https://www.kaggle.com/datasets/bsridatta/uwmadison-flattened-metadata
