# Kidney Tumor Detection from MRI Scans with 3D CNNs ✍️

![CT Image](https://storage.googleapis.com/kaggle-datasets-images/674071/1185670/0d449dc88ac1318321ae8f7de974f0fa/dataset-cover.jpg?t=2020-05-25-13-21-33)

- **Author:** *Mariusz Wiśniewski*
- **Date created:** *April 17th, 2023*
- **Last modified:** *April 21st, 2023*

## Overview

In this notebook, we will train a (2+1)D convolutional neural network to predict the presence of a brain tumor from volumetric MRI scans. After that, we will see how to generate a class activation heatmap for our 3D image classification model.

### Libraries Used

- [Tensorflow 🔥](https://www.tensorflow.org)
- [NiBabel 🩻](https://nipy.org/nibabel/)
- [SciPy 🔬](https://scipy.org)
- [OpenCV 🖼️](https://opencv.org)

### References

- [VoxNet: A 3D Convolutional Neural Network for Real-Time Object Recognition 📃](https://www.ri.cmu.edu/pub_files/2015/9/voxnet_maturana_scherer_iros15.pdf)
- [FusionNet: 3D Object Classification Using Multiple Data Representations 📃](http://3ddl.cs.princeton.edu/2016/papers/Hegde_Zadeh.pdf)
- [Uniformizing Techniques to Process CT scans with 3D CNNs for Tuberculosis Prediction 📃](https://arxiv.org/abs/2007.13224)
- [Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization 📃](https://arxiv.org/abs/1610.02391)
- [Learning Deep Features for Discriminative Localization 📃](https://arxiv.org/abs/1512.04150)
- [The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS) 📃](https://ieeexplore.ieee.org/document/6975210)
- [Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features 📃](https://www.nature.com/articles/sdata2017117)
- [Identifying the Best Machine Learning Algorithms for Brain Tumor Segmentation, Progression Assessment, and Overall Survival Prediction in the BRATS Challenge 📃](https://arxiv.org/abs/1811.02629)
- [3D image classification from CT scans 📝](https://keras.io/examples/vision/3D_image_classification/)
- [Grad-CAM class activation visualization 📝](https://keras.io/examples/vision/grad_cam/)
- [A Comprehensive Introduction to Different Types of Convolutions in Deep Learning](https://towardsdatascience.com/a-comprehensive-introduction-to-different-types-of-convolutions-in-deep-learning-669281e58215)
- [Spine🦴Fracture: EDA🔎 & loading DICOM & 3D browse 📓](https://www.kaggle.com/code/jirkaborovec/spine-fracture-eda-loading-dicom-3d-browse)
- [[RSNA_22] Dicom to NumPy 3D 📓](https://www.kaggle.com/code/vmuzhichenko/rsna-22-dicom-to-numpy-3d)

# Notebook Setup

## Import Statements

In [None]:
import os
import random
import tempfile
from glob import glob

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import tensorflow as tf
import wandb
from dotenv import load_dotenv
from ipywidgets import IntSlider, interact
from matplotlib import animation, rc
from matplotlib.patches import PathPatch, Rectangle
from matplotlib.path import Path
from volumentations import Flip, Compose
from scipy.ndimage import zoom
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils import class_weight
from tensorflow.keras import Input, Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.layers import (BatchNormalization, Conv3D, Dense,
                                     Dropout, GlobalAveragePooling3D,
                                     MaxPool3D, Normalization)
from tensorflow.keras.optimizers import Adam
from wandb.keras import WandbCallback

## Loading Environment Variables

In [None]:
load_dotenv()

## Random Seed for Reproducibility

In [None]:
seed = 27

random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
tf.compat.v1.set_random_seed(seed)

strategy = tf.distribute.MultiWorkerMirroredStrategy()

## Weights & Biases Setup

In [None]:
wandb.login(key=os.environ.get('WANDB_API_KEY'))
run = wandb.init(
    name='4-mw-3D-CNN-cw-st-resampled-pt-flips',
    project=os.environ.get('WANDB_KITS_PROJECT'),
    entity=os.environ.get('WANDB_ENTITY'),
    id='4-mw-3D-CNN-cw-st-resampled-pt-flips',
    # resume=True,
)

## Project Configuration

In [None]:
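# Register the hyperparameters with the active run so they are logged to W&B
# (assigning a plain dict to `wandb.config` would bypass the run's config).
wandb.config.update({
    'learning_rate': 1e-3,
    'min_learning_rate': 1e-7,
    'epochs': 500,
    'batch_size': 64,
    'test_batch_size': 1,
    'img_size': 160,
    'depth': 80,
    'n_classes': 2,
    # 'fold': 0,
})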

config = wandb.config

# Dataset

BraTS2020 is a medical imaging challenge. The goal is to segment the tumor from the brain MRI scans. The dataset is available on [Kaggle](https://www.kaggle.com/datasets/awsaf49/brats2020-training-data).

All BraTS multimodal scans are available as NIfTI files (.nii.gz) and describe:

- native (**T1**),
- post-contrast T1-weighted (**T1Gd**),
- T2-weighted (**T2**),
- T2 Fluid Attenuated Inversion Recovery (**T2-FLAIR**)

volumes, and were acquired with different clinical protocols and various scanners from multiple (n=19) institutions.

All the imaging datasets have been segmented manually, by one to four raters, following the same annotation protocol, and their annotations were approved by experienced neuro-radiologists. Annotations comprise:

- the necrotic and non-enhancing tumor core (**NCR/NET — label 1**),
- the peritumoral edema (**ED — label 2**),
- the GD-enhancing tumor (**ET — label 4**),
 
as described both in the [BraTS 2012-2013 TMI paper](https://ieeexplore.ieee.org/document/6975210) and in the [latest BraTS summarizing paper](https://arxiv.org/abs/1811.02629). The provided data are distributed after their pre-processing, i.e., co-registered to the same anatomical template, interpolated to the same resolution ($1\,\mathrm{mm}^3$) and skull-stripped.
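
To make the label convention concrete, the sketch below counts the voxels of each annotated region in a toy segmentation mask. The `has_tumor` rule (any non-zero tumor label makes the volume a `tumor` sample) is only an assumption about how a binary label could be derived from such a mask; it is not taken from this notebook's `metadata.csv`. The names `seg`, `BRATS_LABELS`, `label_voxel_counts`, and `has_tumor` are hypothetical.

```python
import numpy as np

# Hypothetical toy mask using the label values listed above:
# 0 = background, 1 = NCR/NET, 2 = ED, 4 = ET.
seg = np.zeros((4, 4, 4), dtype=np.uint8)
seg[0, 0, 0] = 1
seg[1, 1, :2] = 2
seg[2, 2, 2] = 4

BRATS_LABELS = {1: 'NCR/NET', 2: 'ED', 4: 'ET'}


def label_voxel_counts(mask):
    """Count the voxels that carry each tumor label."""
    return {name: int(np.sum(mask == value)) for value, name in BRATS_LABELS.items()}


def has_tumor(mask):
    """Assumed rule: return 1 if any voxel carries a tumor label, else 0."""
    return int(np.isin(mask, list(BRATS_LABELS)).any())


counts = label_voxel_counts(seg)
print(counts)
print(has_tumor(seg))
```

A mask that contains only background voxels yields `0` under this rule, which would correspond to the `normal` class.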

## Class Distribution

In [None]:
BASE_PATH = '.'
DATA_TYPES = ['vol', 'seg']
CLASS_NAMES = ['normal', 'tumor']

# MEAN = np.array(0.20801003)
# STD = np.array(0.16490023)

df = pd.read_csv(f'{BASE_PATH}/metadata.csv')
print(df['label'].value_counts())

In [None]:
df.head()

## Build Train and Validation Datasets

In [None]:
# df['path'] = df['path'].apply(lambda x: x.replace('data', 'data_128x64x64'))
df['path']

In [None]:
labels = np.array(df['label'])
paths = np.array([f'{path}volume-patch-{str(patch)}.npz' for path, patch in zip(df['path'], df['patch'])])
train_indices = df[df['test'] == 0].index.values.tolist()
test_indices = df[df['test'] == 1].index.values.tolist()
assert set(train_indices).isdisjoint(test_indices)

print(
    'Number of samples:\n'
    f'train: {len(train_indices)} ({round(float(len(train_indices)) / float(len(labels)) * 100.0)}% of the dataset)\n'
    f'test: {len(test_indices)} ({round(float(len(test_indices)) / float(len(labels)) * 100.0)}% of the dataset)'
)

# Data Augmentation

The CT scans are augmented by randomly flipping the volumes, so that the model cannot rely on which half of the brain appears in which half of the volume. The data is stored in rank-5 tensors of shape `(samples, height, width, depth, channels)`.

In [None]:
def get_augmentations():
    return Compose([
        Flip(1, p=0.5),
        Flip(0, p=0.5),
        Flip(2, p=0.5),
    ], p=1.0)

When the training data loader is constructed, the training data is passed through an augmentation function that randomly flips each volume along its three axes; the test data is left unaugmented. Note that both the training and testing data have previously been rescaled to values between 0 and 1.
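
To illustrate what a random flip does to a volume, here is a NumPy-only sketch. It is not the `volumentations` implementation; `random_flip` is a hypothetical helper that flips each axis independently with probability `p`.

```python
import numpy as np


def random_flip(volume, rng, p=0.5):
    """Flip a (height, width, depth) volume along each axis with probability p."""
    for axis in range(volume.ndim):
        if rng.random() < p:
            volume = np.flip(volume, axis=axis)
    return volume


rng = np.random.default_rng(27)
volume = np.arange(2 * 3 * 4, dtype='float32').reshape(2, 3, 4)
flipped = random_flip(volume, rng)

# A flip only reorders voxels: the shape and the set of intensities are unchanged.
print(flipped.shape)
```

Because flips only permute voxels, the rescaled intensity range of the data is preserved, so no renormalization is needed after augmentation.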

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    'Generates data for Keras'

    def __init__(self, indices, paths, labels, batch_size=4, dim=(240, 120, 155),
                 n_classes=2, shuffle=True, transform=None):
        'Initialization'
        self.dim = dim
        self.batch_size = batch_size
        self.paths = paths
        self.labels = labels
        self.indices = indices
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.transform = transform
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.indices) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        # Find list of IDs
        indices_temp = [self.indices[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(indices_temp)

        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.indices))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, indices_temp):
        'Generates data containing batch_size samples'
        # Initialization
        X = np.empty((self.batch_size, *self.dim, len(DATA_TYPES[:-1])))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, ID in enumerate(indices_temp):
            # Store sample
            volume = np.load(self.paths[ID])['data'].astype('float32')
            # volume = volume.astype('float32')

            if self.transform is not None:
                volume = self.transform(**{'image': volume})['image']
            X[i,] = np.expand_dims(volume, axis=3)

            # # Store class
            y[i] = self.labels[ID]

        return X, tf.keras.utils.to_categorical(y, num_classes=self.n_classes)

In [None]:
augmentations = get_augmentations()

training_generator = DataGenerator(
    train_indices,
    paths,
    labels,
    batch_size=config['batch_size'],
    dim=(config['img_size'], config['img_size'] // 2, config['depth']),
    shuffle=True,
    transform=augmentations,
)

test_generator = DataGenerator(
    test_indices,
    paths,
    labels, 
    batch_size=config['batch_size'],
    dim=(config['img_size'], config['img_size'] // 2, config['depth']),
    shuffle=False,
    # transform=augmentations,
)

## Augmented CT Scan Visualization

In [None]:
images, _ = training_generator[0]
image = images[0]

print('Dimensions of the CT scan are:', image.shape)
plt.imshow(np.squeeze(image[:, :, 32]), cmap='bone')

## CT Slice Montage

Since a CT scan comprises many slices, let us visualize a montage of them.

In [None]:
def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a montage of CT slices"""
    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)
    _, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={'height_ratios': heights},
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap='bone')
            axarr[i, j].axis('off')
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()


# Visualize montage of slices.
# 4 rows and 10 columns for 40 slices of the CT scan.
plot_slices(4, 10, config['img_size'], config['img_size'] // 2, image[:, :, 12:52])

# 3D Convolutional Neural Network

## 3D Convolutions

RGB images, which consist of 3 channels, are typically processed using 2D CNNs. A 3D CNN is essentially the 3D equivalent: it takes as input a 3D volume or a sequence of 2D frames (e.g. slices in a CT scan). 3D CNNs are powerful models for learning representations of volumetric data. In contrast to a 2D convolution, where the filter spans the full depth of its input, the filter of a 3D convolution is shallower than the input (*kernel size < channel size*). As a result, the 3D filter can move in all three directions (height, width, and channel of the input). At each position, element-wise multiplication followed by summation yields a single number. Since the filter slides through a 3D space, the output numbers are arranged in a 3D space as well, so the output is itself a 3D volume.

<center>
<figure>
    <img src='https://i.imgur.com/TYMETaw.gif' alt='3D Convolution' style='width: 680px;'/>
    <figcaption>Visualization of a 3D convolution of <i>5x5x5</i> volume with <i>3x3x3</i> kernel, no padding, no strides. It results in a <i>3x3x3</i> output volume.</figcaption>
</figure>
</center>
    
Similarly to 2D convolutions, which encode spatial relationships of objects in a 2D domain, 3D convolutions can describe the spatial relationships of objects in the 3D space.
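
The sliding-window behaviour described above can be sketched with a deliberately naive NumPy loop. This is an illustration only, far slower than `Conv3D`, and, like deep learning frameworks, it computes a cross-correlation (the kernel is not flipped). The helper name `conv3d_valid` is hypothetical.

```python
import numpy as np


def conv3d_valid(volume, kernel):
    """Naive single-channel 3D cross-correlation, no padding, stride 1."""
    vd, vh, vw = volume.shape
    kd, kh, kw = kernel.shape
    out = np.zeros((vd - kd + 1, vh - kh + 1, vw - kw + 1), dtype=volume.dtype)
    for z in range(out.shape[0]):
        for y in range(out.shape[1]):
            for x in range(out.shape[2]):
                # Element-wise multiply the current patch with the kernel and sum.
                patch = volume[z:z + kd, y:y + kh, x:x + kw]
                out[z, y, x] = np.sum(patch * kernel)
    return out


volume = np.ones((5, 5, 5), dtype='float32')
kernel = np.ones((3, 3, 3), dtype='float32')
output = conv3d_valid(volume, kernel)
print(output.shape)  # (3, 3, 3)
```

With an all-ones volume and kernel, every output value equals the number of kernel elements, 27, which makes the "multiply and add" step easy to check by hand.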

## (2+1)D Convolutions

Instead of a single 3D convolution to process the time and space dimensions, we will use a (2+1)D convolution, which processes the space and time (or depth) dimensions separately. The figure below shows the factored spatial and temporal convolutions of a (2+1)D convolution.

<center>
    <figure>
        <img src='https://www.tensorflow.org/images/tutorials/video/2plus1CNN.png' alt='(2+1)D Convolution' style='width: 680px;'/>
    <figcaption>Visualization of a (2+1)D convolution of <i>6x6x8</i> volume with <i>3x3x3</i> kernel, no padding, no strides, that is separated into <i>1x3x3</i> spatial kernel and <i>3x1x1</i> temporal kernel. See the <a href='https://www.tensorflow.org/tutorials/video/video_classification'> source</a>.</figcaption>
</figure>
</center>

The main advantage of this approach is that it reduces the number of parameters. In the (2+1)D convolution, the spatial convolution uses a kernel of shape `(1, width, height)`, while the depth (temporal) convolution uses a kernel of shape `(depth, 1, 1)`. For example, a (2+1)D convolution with kernel size `(3x3x3)` needs weight matrices of size `(9 * channels**2) + (3 * channels**2)`, which is fewer than half as many as the `27 * channels**2` weights of the full 3D convolution.
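
The arithmetic behind this comparison can be checked with a few lines of plain Python. The sketch assumes equal input and output channel counts, the same number of channels between the two factored convolutions, and no bias terms; the function names are hypothetical.

```python
def conv3d_weights(channels, kernel=(3, 3, 3)):
    """Weights of a full 3D convolution with `channels` inputs and outputs."""
    kd, kh, kw = kernel
    return kd * kh * kw * channels * channels


def conv2plus1d_weights(channels, kernel=(3, 3, 3)):
    """Weights of the factored spatial (1, kh, kw) plus depth (kd, 1, 1) convolutions."""
    kd, kh, kw = kernel
    spatial = 1 * kh * kw * channels * channels
    temporal = kd * 1 * 1 * channels * channels
    return spatial + temporal


channels = 16
full = conv3d_weights(channels)
factored = conv2plus1d_weights(channels)
print(full, factored, round(factored / full, 3))  # 6912 3072 0.444
```

The ratio 12/27 is independent of the channel count, so the saving holds for every layer that satisfies these assumptions.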

## Residual Connections

A *ResNet* model is constructed from a sequence of residual blocks. There are two branches in a residual block. The main branch performs the computation, but it is fairly difficult for the gradients to flow through. The residual branch skips the main computation and typically just adds the input to the main branch's result. Gradients flow readily through this branch. As a result, there is an easy path from the loss function to the main branch of every residual block, which mitigates the problem of vanishing gradients.

<center>
    <figure>
        <img src='https://miro.medium.com/max/1140/1*D0F3UitQ2l5Q0Ak-tjEdJg.webp' alt='Residual Connection' style='width: 480px;'/>
    <figcaption>Visualization of a residual connection. See the <a href='https://arxiv.org/abs/1512.03385'> source</a>.</figcaption>
</figure>
</center>
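
As a minimal sketch of the idea, a residual block computes `y = x + F(x)`, where `F` is the main branch. The toy block below uses two small dense matrices for `F` rather than convolutions, purely for illustration; `residual_block`, `w1`, and `w2` are hypothetical names.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def residual_block(x, w1, w2):
    """Return x + F(x) with the toy main branch F(x) = w2 @ relu(w1 @ x)."""
    main = w2 @ relu(w1 @ x)
    return x + main


x = np.array([1.0, -2.0, 3.0])

# With an all-zero main branch the block reduces to the identity mapping,
# so the skip path alone carries the signal (and its gradient) through.
zeros = np.zeros((3, 3))
y_identity = residual_block(x, zeros, zeros)

# With identity weights the main branch adds relu(x) on top of the input.
eye = np.eye(3)
y_added = residual_block(x, eye, eye)

print(y_identity, y_added)
```

Since `y = x + F(x)`, the derivative of `y` with respect to `x` is the identity plus the derivative of `F`, so a gradient always has a direct route through the skip path even when the main branch contributes little.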

## About the Model

To make the model easier to understand, we structure it into residual blocks (see the [paper](https://arxiv.org/abs/1512.03385)). The architecture of the (2+1)D CNN used in this example is based on the [A Closer Look at Spatiotemporal Convolutions for Action Recognition](https://ieeexplore.ieee.org/document/8578773) paper.

In contrast to CAM (see the [paper](https://arxiv.org/abs/1512.04150)), which requires performing global average pooling over convolutional maps immediately prior to prediction (convolutional feature maps -> global average pooling -> softmax layer), Grad-CAM does not require *any* modifications in the network architecture.

In [None]:
from tensorflow.keras.models import load_model
from miscnn.neural_network.metrics import dice_soft, dice_crossentropy, tversky_loss


def get_model():
    inputs = Input((None, None, None, 1))

    # extract layers from pretrained model pt_model
    with strategy.scope():
        pt_model = load_model('evaluation/fold_0/model.hdf5', custom_objects={'dice_soft': dice_soft, 'dice_crossentropy': dice_crossentropy, 'tversky_loss': tversky_loss})

    x = pt_model.layers[1](inputs)
    
    # take all layers until conv3d_9 (included)
    for layer in pt_model.layers[2:pt_model.layers.index(pt_model.get_layer('conv3d_9')) + 1]:
        x = layer(x)

    x = BatchNormalization(name='b1')(x)
    x = GlobalAveragePooling3D(name='gap1')(x)
    x = Dense(256, activation='relu', name='d1')(x)
    # x = Dropout(0.65, name='do1')(x)

    x = Dense(128, activation='relu', name='d2')(x)
    # x = Dropout(0.65, name='do2')(x)

    x = Dense(2, activation='softmax', name='d3')(x)

    return Model(inputs=inputs, outputs=x, name='encoder_model')

In [None]:
# from tensorflow.keras.models import load_model
# from miscnn.neural_network.metrics import dice_soft, dice_crossentropy, tversky_loss


# with strategy.scope():
#     pt_model = load_model('evaluation/fold_0/model.hdf5', custom_objects={'dice_soft': dice_soft, 'dice_crossentropy': dice_crossentropy, 'tversky_loss': tversky_loss})

# def get_model():
#     inputs = Input((None, None, None, 1))

#     encoder_output_layer = pt_model.get_layer('conv3d_9').output
#     encoder_model = Model(inputs=pt_model.inputs, outputs=encoder_output_layer)

#     x = encoder_model(inputs)
#     x = BatchNormalization(name='b1')(x)
#     x = GlobalAveragePooling3D(name='gap1')(x)
#     x = Dense(256, activation='relu', name='d1')(x)
#     # x = Dropout(0.5, name='do1')(x)

#     x = Dense(64, activation='relu', name='d2')(x)
#     # x = Dropout(0.5, name='do2')(x)

#     x = Dense(2, activation='softmax', name='d3')(x)

#     return Model(inputs=inputs, outputs=x, name='encoder_model')

# model = get_model()
# model.summary()

In [None]:
# def get_model(height=128, width=128, depth=64, channels=4):
#     """Build a 3D convolutional neural network model."""
#     inputs = Input((None, None, None, channels))

#     # x = Normalization(mean=MEAN, variance=STD**2)(inputs)

#     x = Conv3D(
#         filters=16, kernel_size=3, strides=(1, 1, 1), padding='same', activation='relu'
#     )(inputs)
#     x = Conv3D(
#         filters=16, kernel_size=3, strides=(1, 1, 1), padding='same', activation='relu'
#     )(inputs)
#     x = MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
#     x = BatchNormalization()(x)
    
#     x = Conv3D(
#         filters=32, kernel_size=3, strides=(1, 1, 1), padding='same', activation='relu'
#     )(x)
#     x = Conv3D(
#         filters=32, kernel_size=3, strides=(1, 1, 1), padding='same', activation='relu'
#     )(x)
#     x = MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
#     x = BatchNormalization()(x)
    
#     x = Conv3D(
#         filters=64, kernel_size=3, strides=(1, 1, 1), padding='same', activation='relu'
#     )(x)
#     x = Conv3D(
#         filters=64, kernel_size=3, strides=(1, 1, 1), padding='same', activation='relu'
#     )(x)
#     x = MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
#     x = BatchNormalization()(x)
    
#     # x = Conv3D(
#     #     filters=128, kernel_size=3, strides=(1, 1, 1), padding='same', activation='relu'
#     # )(x)
#     x = Conv3D(
#         filters=128, kernel_size=3, strides=(1, 1, 1), padding='same', activation='relu'
#     )(x)
#     x = MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
#     x = BatchNormalization()(x)

#     x = GlobalAveragePooling3D()(x)
    
#     x = Dense(units=64, activation='relu')(x)
#     # x = Dropout(0.6)(x)
    
#     x = Dense(units=32, activation='relu')(x)
#     # x = Dropout(0.6)(x)

#     outputs = Dense(units=config['n_classes'], activation='softmax')(x)

#     return Model(inputs, outputs, name='3D-CNN')

In [None]:
# class Conv2Plus1D(tf.keras.layers.Layer):
#     def __init__(self, filters, kernel_size, padding, **kwargs):
#         """
#             A sequence of convolutional layers that first apply the convolution operation
#             over the spatial dimensions, and then the temporal dimension. 
#         """
#         super().__init__(**kwargs)
#         self.seq = tf.keras.Sequential([  
#             # Spatial decomposition
#             Conv3D(filters=filters,
#                    kernel_size=(1, kernel_size[1], kernel_size[2]),
#                    padding=padding),
#             # Temporal decomposition
#             Conv3D(filters=filters, 
#                    kernel_size=(kernel_size[0], 1, 1),
#                    padding=padding)
#             ])
#         self.filters = filters
#         self.kernel_size = kernel_size
#         self.padding = padding

#     def call(self, x):
#         return self.seq(x)
    
#     def get_config(self):
#         config = super().get_config().copy()
#         config.update({
#             'filters': self.filters,
#             'kernel_size': self.kernel_size,
#             'padding': self.padding,
#         })
#         return config
    

# class ResidualMain(tf.keras.layers.Layer):
#     """
#         Residual block of the model with convolution, layer normalization,
#         and the activation function, ReLU.
#     """
#     def __init__(self, filters, kernel_size, **kwargs):
#         super().__init__(**kwargs)
#         self.seq = tf.keras.Sequential([
#             Conv2Plus1D(filters=filters,
#                         kernel_size=kernel_size,
#                         padding='same'),
#             tf.keras.layers.LayerNormalization(),
#             tf.keras.layers.ReLU(),
#             Conv2Plus1D(filters=filters, 
#                         kernel_size=kernel_size,
#                         padding='same'),
#             tf.keras.layers.LayerNormalization()
#         ])
#         self.filters = filters
#         self.kernel_size = kernel_size

#     def call(self, x):
#         return self.seq(x)
    
#     def get_config(self):
#         config = super().get_config().copy()
#         config.update({
#             'filters': self.filters,
#             'kernel_size': self.kernel_size,
#         })
#         return config
    
    
# class Project(tf.keras.layers.Layer):
#     """
#         Project certain dimensions of the tensor as the data is passed
#         through different sized filters and downsampled.
#     """
#     def __init__(self, units, **kwargs):
#         super().__init__(**kwargs)
#         self.seq = tf.keras.Sequential([
#             tf.keras.layers.Dense(units),
#             tf.keras.layers.LayerNormalization()
#         ])
#         self.units = units

#     def call(self, x):
#         return self.seq(x)
    
#     def get_config(self):
#         config = super().get_config().copy()
#         config.update({
#             'units': self.units,
#         })
#         return config
    
    
# def add_residual_block(input, filters, kernel_size):
#     """
#         Add residual blocks to the model. If the last dimensions of the input data
#         and filter size does not match, project it such that last dimension matches.
#     """
#     out = ResidualMain(filters, kernel_size)(input)

#     res = input
#     # Using the Keras functional APIs, project the last dimension of the tensor to
#     # match the new filter size
#     if out.shape[-1] != input.shape[-1]:
#         res = Project(out.shape[-1])(res)

#     return tf.keras.layers.add([res, out])

In [None]:
# def get_model(height=128, width=128, depth=64, channels=3):
#     """Build a 3D convolutional neural network model."""
#     inputs = Input((None, None, None, channels))
#     # x = Normalization(mean=MEAN, variance=STD**2)(inputs)

#     x = add_residual_block(inputs, filters=16, kernel_size=(3, 3, 3))
#     x = MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
#     x = BatchNormalization()(x)

#     x = add_residual_block(x, filters=32, kernel_size=(3, 3, 3))
#     x = MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
#     x = BatchNormalization()(x)
    
#     x = add_residual_block(x, filters=64, kernel_size=(3, 3, 3))
#     x = MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
#     x = BatchNormalization()(x)
    
#     x = add_residual_block(x, filters=128, kernel_size=(3, 3, 3))
#     x = MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding='same')(x)
#     x = BatchNormalization()(x)

#     x = GlobalAveragePooling3D()(x)
    
#     x = Dense(units=64, activation='relu')(x)
#     # x = Dropout(0.35)(x)
    
#     x = Dense(units=32, activation='relu')(x)
#     # x = Dropout(0.35)(x)

#     outputs = Dense(units=config['n_classes'], activation='softmax')(x)

#     return Model(inputs, outputs, name='2-1D-CNN2')

In [None]:
# Build model.
# if wandb.run.resumed:
# restore the best model
# artifact = run.use_artifact(
#     f'{os.environ.get("WANDB_ENTITY")}/{os.environ.get("WANDB_MD_LUNG_PROJECT")}/model-1-mw-21D-ResNet-cw-f0:v21', type='model')
# artifact_dir = artifact.download()
# model = tf.keras.models.load_model(
#     artifact_dir,
#     custom_objects={
#         'Project': Project,
#         'ResidualMain': ResidualMain,
#         'Conv2Plus1D': Conv2Plus1D,
#     }
# )
# else:
with strategy.scope():
    model = get_model()
    # model = get_model(height=config['img_size'], width=config['img_size'] // 2,
      # depth=config['depth'], channels=len(DATA_TYPES[:-1]))
# best_model = wandb.restore(
#     'model-best.h5',
#      run_path=f"{os.environ.get('WANDB_ENTITY')}/{os.environ.get('WANDB_MD_LUNG_PROJECT')}/1-mw-21D-ResNet")
# model.load_weights(best_model.name)
# model = add_regularization(model, regularizer=tf.keras.regularizers.l2(1e-3))
model.summary()

## Class Weights

In [None]:
train_labels = [labels[idx] for idx in train_indices]
class_weights = class_weight.compute_class_weight(
    class_weight='balanced', # balanced for computing weights based on class count
    classes=np.unique(train_labels),
    y=train_labels,
)
class_weights = dict(zip(np.unique(train_labels), class_weights))
print(class_weights)
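
For reference, the `'balanced'` heuristic used by scikit-learn assigns each class the weight `n_samples / (n_classes * class_count)`. The NumPy sketch below reproduces that formula on hypothetical toy labels; `balanced_class_weights` and `toy_labels` are illustrative names, not part of the notebook.

```python
import numpy as np


def balanced_class_weights(y):
    """Compute n_samples / (n_classes * class_count) for every class in y."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))


# Hypothetical imbalanced labels: six samples of class 0, two of class 1.
toy_labels = np.array([0, 0, 0, 0, 0, 0, 1, 1])
toy_weights = balanced_class_weights(toy_labels)
print(toy_weights)
```

The rarer class receives the larger weight, so its misclassifications contribute more to the loss during training.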

## Model Training

In [None]:
# Compile model
with strategy.scope():
    model.compile(
        loss='categorical_crossentropy',
        optimizer=Adam(learning_rate=config['learning_rate']),
        metrics=['acc'],
    )

# Train the model
model.fit(
    training_generator,
    validation_data=test_generator,
    epochs=config['epochs'],
    class_weight=class_weights,
    # initial_epoch=wandb.run.step,
    verbose=2,
    callbacks=[
        ModelCheckpoint('3d_pt_best.h5', save_best_only=True, monitor='val_loss', mode='min'),
        ReduceLROnPlateau(
            monitor='loss',
            mode='min',
            patience=15,
            factor=0.1,
            min_lr=config['min_learning_rate'],
        ),
        EarlyStopping(monitor='val_loss', mode='min',
                      patience=30, restore_best_weights=True),
        WandbCallback(monitor='val_loss', mode='min'),
    ],
    workers=6
)

## Visualizing Training History

The model's accuracy and loss on the training and validation sets are plotted below.

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(20, 3))
ax = ax.ravel()

for i, metric in enumerate(['acc', 'loss']):
    ax[i].plot(model.history.history[metric])
    ax[i].plot(model.history.history[f'val_{metric}'])
    ax[i].set_title(f'Model {metric}')
    ax[i].set_xlabel('epochs')
    ax[i].set_ylabel(metric)
    ax[i].legend(['train', 'val'])

## Cleanup

In [None]:
wandb.finish()

## Single CT Scan Predictions

In [None]:
# model.load_weights(f'3d_best.h5')

In [None]:
# os.makedirs(f'mid_results/fold{config["fold"]}_128x64x128', exist_ok=True)

In [None]:
df_original = pd.read_csv(f'{BASE_PATH}/original_metadata.csv')
print(df_original['label'].value_counts())

test_paths = np.array([f'{path}volume.npz' for path in df_original['path']])
test_labels = np.array(df_original['label'])
test_indices = df_original[df_original['test'] == 1].index.values.tolist()

# Keep only the labels of the volumes that are actually evaluated
y_test = test_labels[test_indices]

In [None]:
for i in range(56):
    artifact = run.use_artifact(
        f'{os.environ.get("WANDB_ENTITY")}/{os.environ.get("WANDB_KITS_PROJECT")}/model-4-mw-3D-CNN-cw-st-resampled-pt-flips:v{i}', type='model')
    artifact_dir = artifact.download()
    model = tf.keras.models.load_model(
        artifact_dir
    )

    y_pred = []
    for ID in test_indices:
        volume = np.load(test_paths[ID])['data'].astype('float32')

        volume = np.expand_dims(volume, axis=3)
        volume = np.expand_dims(volume, axis=0)
        y_pred.append(np.argmax(model.predict(volume)))

    # print(f'\nArtifact {i}')
    # print(classification_report(y_test, y_pred, digits=3))

    # Append this artifact's classification report to a text file
    with open('new_reports/4-mw-3D-CNN-cw-st-resampled-pt-flips.txt', 'a') as f:
        f.write(f'Artifact {i}\n')
        f.write(classification_report(y_test, y_pred, digits=3))
        f.write('\n')


# #     y_pred = [np.argmax(x) for x in model.predict(test_generator, batch_size=config['test_batch_size'])]
# #     y_test = [labels[idx] for idx in test_indices]
# #     print(f'\nArtifact {i}')
# #     print(classification_report(y_test, y_pred, digits=3))
# #     # save classification report to file
# #     with open(f'mid_results/fold{config["fold"]}_std_mean_no_edema/classification_report_{i}.txt', 'w') as f:
# #         f.write(classification_report(y_test, y_pred, digits=3))

# #     cfsn_matrix = confusion_matrix(y_test, y_pred)
# #     df_cm = pd.DataFrame(cfsn_matrix, index=range(config['n_classes']), columns=CLASS_NAMES)
# #     plt.figure(figsize=(15, 6))
# #     sn.heatmap(df_cm, annot=True, linewidths=0.5, fmt='d')
# #     plt.savefig(f'mid_results/fold{config["fold"]}_std_mean/confusion_matrix_{i}.png')
# #     plt.show()

# Model Evaluation

In [None]:
# y_pred = [np.argmax(x) for x in model.predict(test_generator, batch_size=config['test_batch_size'])]
y_test = [labels[idx] for idx in test_indices]

In [None]:
df_original = pd.read_csv(f'{BASE_PATH}/original_metadata.csv')
print(df_original['label'].value_counts())

test_paths = np.array([f'{path}volume.npz' for path in df_original['path']])
test_labels = np.array(df_original['label'])
test_indices = df_original[df_original['test'] == 1].index.values.tolist()

# load all test_indices and run the evaluation
y_pred = []
for ID in test_indices:
    volume = np.load(test_paths[ID])['data']
    volume = volume.astype('float32')

    volume = np.expand_dims(volume, axis=3)
    volume = np.expand_dims(volume, axis=0)
    print(volume.shape)
    y_pred.append(np.argmax(model.predict(volume)))

# Keep only the labels of the volumes that were evaluated above
y_test = test_labels[test_indices]
print(classification_report(y_test, y_pred, digits=3))

In [None]:
y_test = test_labels[test_indices]

## Classification Report

In [None]:
print(classification_report(y_test, y_pred, digits=3))

## Confusion Matrix

In [None]:
cfsn_matrix = confusion_matrix(y_test, y_pred)
df_cm = pd.DataFrame(cfsn_matrix, index=range(config['n_classes']), columns=CLASS_NAMES)
plt.figure(figsize=(15, 6))
sn.heatmap(df_cm, annot=True, linewidths=0.5, fmt='d')

# Grad-CAM 3D Visualizations

Now let us obtain a class activation heatmap for our image classification model. A detailed description of the procedure can be found in the [Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization](https://arxiv.org/abs/1610.02391) paper.

**Gradient-weighted Class Activation Mapping (Grad-CAM)** employs the gradients of any target concept (for example, "tiger" in a classification network or a sequence of words in a captioning network) flowing into the final convolutional layer to generate a coarse localization map highlighting the important regions in the image for predicting the concept.
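
As a compact reference, the two steps can be written as follows. The notation follows the Grad-CAM paper, extended with a third spatial index for volumetric feature maps; that extension is an assumption made here to match the 3D setting of this notebook. $A^k$ is the $k$-th feature map of the chosen convolutional layer, $y^c$ is the score for class $c$, and $Z$ is the number of voxels in one feature map.

$$
\alpha_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \sum_{l} \frac{\partial y^c}{\partial A^k_{ijl}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left( \sum_{k} \alpha_k^c A^k \right)
$$

The weights $\alpha_k^c$ measure how important feature map $k$ is for class $c$, and the ReLU keeps only the regions that contribute positively to the class score.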

## Configurable Parameters

Several prior studies claim that deeper representations in a CNN capture higher-level visual constructs. Furthermore, because convolutional layers naturally preserve spatial information that is lost in fully-connected layers, we may anticipate the **last** convolutional layers to provide the best compromise of high-level semantics and detailed spatial information.

Use `model.summary()` to see the names of all layers in the model. These are necessary to get the value for `last_conv_layer_name`.
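
If reading the summary by eye is inconvenient, the layer list can also be scanned programmatically. The helper below is a hypothetical convenience, not part of Keras: it returns the name of the last layer whose class name starts with `Conv`, which would also match transposed convolutions. The stand-in layer classes exist only so the sketch runs without TensorFlow; with a real model you would pass `model.layers`.

```python
def find_last_conv_layer_name(layers):
    """Return the name of the last layer whose class name starts with 'Conv'."""
    for layer in reversed(list(layers)):
        if type(layer).__name__.startswith('Conv'):
            return layer.name
    raise ValueError('no convolutional layer found')


# Stand-in classes that mimic the `.name` attribute of Keras layers.
class Conv3D:
    def __init__(self, name):
        self.name = name


class Dense:
    def __init__(self, name):
        self.name = name


toy_layers = [Conv3D('conv3d_5'), Conv3D('conv3d_6'), Dense('d1')]
print(find_last_conv_layer_name(toy_layers))  # conv3d_6
```

The result is only a starting point: an earlier convolutional layer can be chosen deliberately when a heatmap with finer spatial resolution is preferred.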

In [None]:
input_volume = test_generator.__getitem__(1)[0]
prediction = model.predict(input_volume)[0]
print(prediction)

volume_size = input_volume.shape
last_conv_layer_name = 'conv3d_6'

## Grad-CAM Algorithm

*Grad-CAM* uses the gradient information flowing into the last convolutional layer of the CNN to assign an importance value to each neuron for a particular decision of interest. It is **class-discriminative**, meaning it can produce a separate visualization for every class present in the image, which is why we introduce the `pred_index` argument. Keep in mind that when no value is passed for `pred_index`, the generated heatmap corresponds to the class with the highest predicted score.
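
To make the weighting step concrete, here is a small NumPy sketch on random toy arrays. The shapes and values are assumptions chosen for illustration only; no model is involved, and the arrays merely stand in for the activations and gradients of a real network.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins (assumed shapes): activations of the last conv layer for one
# volume, and the gradient of the class score with respect to them.
activations = rng.normal(size=(4, 4, 2, 3))  # three spatial axes, then channels
gradients = rng.normal(size=(4, 4, 2, 3))

# Channel importance: average the gradient over all spatial positions
weights = gradients.mean(axis=(0, 1, 2))     # shape (3,)

# Weighted sum of the feature maps over the channel axis
heatmap = activations @ weights              # shape (4, 4, 2)

# ReLU, then scale to [0, 1]; guard against an all-zero map
heatmap = np.maximum(heatmap, 0)
peak = heatmap.max()
if peak > 0:
    heatmap = heatmap / peak
```

The matrix product with a vector of channel weights is just a compact way of writing the sum over channels of `weights[k] * activations[..., k]`.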

In [None]:
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Generate class activation heatmap"""
    # First, we create a model that maps the input image to the activations
    # of the last conv layer as well as the output predictions
    grad_model = Model(
        [model.inputs], [model.get_layer(
            last_conv_layer_name).output, model.output]
    )

    # Then, we compute the gradient of the top predicted class for our input image
    # with respect to the activations of the last conv layer
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    # This is the gradient of the output neuron (top predicted or chosen)
    # with regard to the output feature map of the last conv layer
    grads = tape.gradient(class_channel, last_conv_layer_output)

    # This is a vector where each entry is the mean intensity of the gradient
    # over a specific feature map channel (equivalent to global average pooling)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2, 3))

    # We multiply each channel in the feature map array
    # by 'how important this channel is' with regard to the top predicted class
    # then sum all the channels to obtain the heatmap class activation
    last_conv_layer_output = last_conv_layer_output[0]
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # For visualization purposes, we also normalize the heatmap to the range 0-1.
    # Clipping negative values to zero is equivalent to applying ReLU;
    # divide_no_nan returns 0 instead of NaN when the clipped heatmap is all zeros
    heatmap = tf.maximum(heatmap, 0)
    heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
    return heatmap.numpy()

## Heatmap Generation

In [None]:
# Remove last layer's activation
model.layers[-1].activation = None

# Print what the top predicted class is
input_volume = input_volume[0:1, :, :, :, :]
img_array = input_volume

preds = model.predict(img_array)
print('Predicted:', preds[0])

input_volume = np.squeeze(input_volume)
# Generate class activation heatmap
heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name)

In [None]:
plt.matshow(np.squeeze(heatmap[:, :, 1]))
plt.show()

## Expanding Heatmap Dimensions

Note that, as with resizing the input volume, the heatmap is expanded to the target dimensions using *spline interpolated zoom*.
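
The sketch below illustrates how per-axis zoom factors map a coarse array onto a target shape. It assumes the `zoom` used in this notebook is `scipy.ndimage.zoom`; the array contents and the target shape are made up for the example.

```python
import numpy as np
from scipy.ndimage import zoom

# Toy coarse heatmap and a hypothetical target volume shape
coarse = np.random.default_rng(0).random((8, 8, 4)).astype("float32")
target_shape = (32, 32, 16)

# One zoom factor per axis: target size divided by current size
factors = [t / s for t, s in zip(target_shape, coarse.shape)]

# order=1 selects linear interpolation; the default is order=3 (cubic spline)
resized = zoom(coarse, factors, order=1)

print(coarse.shape, '->', resized.shape)
```

Cubic splines give smoother maps but can overshoot the original value range, which is worth keeping in mind when the result is stored as an 8-bit integer array.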

In [None]:
def get_resized_heatmap(heatmap, shape):
    """Resize heatmap to shape"""
    # Rescale heatmap to a range 0-255
    upscaled_heatmap = np.uint8(255 * heatmap)

    upscaled_heatmap = zoom(
        upscaled_heatmap,
        (
            shape[0] / upscaled_heatmap.shape[0],
            shape[1] / upscaled_heatmap.shape[1],
            shape[2] / upscaled_heatmap.shape[2],
        ),
    )

    return upscaled_heatmap


# input_volume was squeezed above, so its shape has no batch or channel axis
resized_heatmap = get_resized_heatmap(heatmap, input_volume.shape)

# Visualizations

Now it is time for us to graphically visualize the results obtained by overlaying the heatmap with the image. We utilize the `jet` colormap for this.

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 20))

ax[0].imshow(np.squeeze(input_volume[:, :, 30]), cmap='bone')
img0 = ax[1].imshow(np.squeeze(input_volume[:, :, 30]), cmap='bone')
img1 = ax[1].imshow(np.squeeze(resized_heatmap[:, :, 30]),
                    cmap='jet', alpha=0.3, extent=img0.get_extent())
plt.show()

## Bounding Boxes

Here we prepare some functions that we will use to annotate images by drawing bounding boxes around the regions of interest. The bounding boxes are drawn on the images using coordinates obtained from the heatmap, in the following format: `(x_min, y_min, x_max, y_max)`.

The process of obtaining the bounding boxes is as follows:

1. Obtain the coordinates of the heatmap in places where its values are above a certain threshold (optionally, we can exploit Otsu's method to automatically determine the threshold).
2. Find the connected components in the binary image obtained from the heatmap. Each connected component corresponds to a region of interest.
3. For each connected component, obtain the bounding box coordinates using the minimal up-right rectangle technique.
4. Draw the bounding boxes on the image.
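
Steps 1 to 3 can be illustrated on a toy slice. This sketch deliberately swaps in `scipy.ndimage.label` and `scipy.ndimage.find_objects` for the OpenCV contour functions used in this notebook, so that the example depends only on NumPy and SciPy; the heatmap values are made up.

```python
import numpy as np
from scipy import ndimage

# Toy 2D heatmap slice with values in 0-255 and two separate hot regions
heat = np.zeros((10, 10), dtype=np.uint8)
heat[1:4, 2:5] = 200
heat[6:9, 5:9] = 120

# Step 1: binarize with a fixed threshold (15% of the maximum value 255)
mask = heat >= 0.15 * 255

# Step 2: label connected components (4-connectivity by default in 2D)
labels, n_regions = ndimage.label(mask)

# Step 3: one tight axis-aligned box per component as (x1, y1, x2, y2),
# where x2 and y2 are exclusive end coordinates
bboxes = []
for rows, cols in ndimage.find_objects(labels):
    bboxes.append((cols.start, rows.start, cols.stop, rows.stop))

print(n_regions, bboxes)  # 2 [(2, 1, 5, 4), (5, 6, 9, 9)]
```

Drawing (step 4) then only requires turning each tuple into a rectangle with width `x2 - x1` and height `y2 - y1`.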

In [None]:
def nms(boxes, thresh=0.3):
    """
    Non-maximum suppression: Greedily select high-scoring detections and
    skip detections that are significantly covered by a previously
    selected detection.

    Source: https://pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/
    """
    # if there are no boxes, return an empty list
    if len(boxes) == 0:
        return []
    # if the bounding boxes are integers, convert them to floats --
    # this is important since we'll be doing a bunch of divisions
    if boxes.dtype.kind == "i":
        boxes = boxes.astype("float")
    # initialize the list of picked indexes
    pick = []
    # grab the coordinates of the bounding boxes
    x1 = boxes[:,0]
    y1 = boxes[:,1]
    x2 = boxes[:,2]
    y2 = boxes[:,3]
    # compute the area of the bounding boxes and sort the bounding
    # boxes by the bottom-right y-coordinate of the bounding box
    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    idxs = np.argsort(y2)
    # keep looping while some indexes still remain in the indexes
    # list
    while len(idxs) > 0:
        # grab the last index in the indexes list and add the
        # index value to the list of picked indexes
        last = len(idxs) - 1
        i = idxs[last]
        pick.append(i)
        # find the largest (x, y) coordinates for the start of
        # the bounding box and the smallest (x, y) coordinates
        # for the end of the bounding box
        xx1 = np.maximum(x1[i], x1[idxs[:last]])
        yy1 = np.maximum(y1[i], y1[idxs[:last]])
        xx2 = np.minimum(x2[i], x2[idxs[:last]])
        yy2 = np.minimum(y2[i], y2[idxs[:last]])
        # compute the width and height of the bounding box
        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)
        # compute the ratio of overlap
        overlap = (w * h) / area[idxs[:last]]
        # delete the picked index and all indexes whose overlap
        # exceeds the threshold from the index list
        idxs = np.delete(idxs, np.concatenate(([last],
            np.where(overlap > thresh)[0])))
    # return only the bounding boxes that were picked using the
    # integer data type
    return boxes[pick].astype("int")

def get_bounding_boxes(heatmap, threshold=0.15, otsu=False):
    """Get bounding boxes from heatmap"""
    p_heatmap = np.copy(heatmap)

    if otsu:
        # Otsu's thresholding method to find the bounding boxes
        threshold, p_heatmap = cv2.threshold(
            heatmap, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
    else:
        # Using a fixed threshold
        p_heatmap[p_heatmap < threshold * 255] = 0
        p_heatmap[p_heatmap >= threshold * 255] = 1

    # find the contours in the thresholded heatmap
    contours = cv2.findContours(p_heatmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = contours[0] if len(contours) == 2 else contours[1]

    # get the bounding boxes from the contours
    bboxes = []
    for c in contours:
        x, y, w, h = cv2.boundingRect(c)
        bboxes.append([x, y, x + w, y + h])

    return nms(np.array(bboxes))


def get_bbox_patches(bboxes, color='r', linewidth=2):
    """Get patches for bounding boxes"""
    patches = []
    for bbox in bboxes:
        x1, y1, x2, y2 = bbox
        patches.append(
            Rectangle(
                (x1, y1),
                x2 - x1,
                y2 - y1,
                edgecolor=color,
                facecolor='none',
                linewidth=linewidth,
            )
        )
    return patches

In [None]:
# show the bounding boxes on the original image
fig, ax = plt.subplots(1, 2, figsize=(10, 20))

ax[0].imshow(np.squeeze(input_volume[:, :, 30]), cmap='bone')
img0 = ax[1].imshow(np.squeeze(input_volume[:, :, 30]), cmap='bone')
img1 = ax[1].imshow(np.squeeze(resized_heatmap[:, :, 30]),
                    cmap='jet', alpha=0.3, extent=img0.get_extent())

bboxes = get_bounding_boxes(np.squeeze(resized_heatmap[:, :, 30]))
patches = get_bbox_patches(bboxes)

for patch in patches:
    ax[1].add_patch(patch)

plt.show()

## Interactive Slice Viewers

_**Note:** unfortunately, *Kaggle* does not currently preserve the widget state or embed it in the static notebook preview, so **dragging the slider has no effect on the displayed figures**. See the following discussions for more information: [#33450](https://www.kaggle.com/questions-and-answers/33450), [#42782](https://www.kaggle.com/product-feedback/42782), [#2360](https://github.com/jupyter-widgets/ipywidgets/issues/2360), [#13637754](https://stackoverflow.com/a/63575304/13637754). To see the interactive visualizations, simply copy the notebook and run it yourself._

In [None]:
def _draw_line(ax, coords, clr='g'):
    line = Path(coords, [Path.MOVETO, Path.LINETO])
    pp = PathPatch(line, linewidth=3, edgecolor=clr, facecolor='none')
    ax.add_patch(pp)


def _set_axes_labels(ax, axes_x, axes_y):
    ax.set_xlabel(axes_x)
    ax.set_ylabel(axes_y)
    ax.set_aspect('equal', 'box')


def _draw_bboxes(ax, heatmap):
    bboxes = get_bounding_boxes(heatmap, otsu=True)
    patches = get_bbox_patches(bboxes)
    for patch in patches:
        ax.add_patch(patch)


_rec_prop = dict(linewidth=5, facecolor='none')


def show_volume(vol, z, y, x, heatmap=None, alpha=0.3, fig_size=(6, 6)):
    """Show a slice of a volume with optional heatmap"""
    fig, axarr = plt.subplots(nrows=2, ncols=2, figsize=fig_size)
    v_z, v_y, v_x = vol.shape

    img0 = axarr[0, 0].imshow(vol[z, :, :], cmap='bone')
    if heatmap is not None:
        axarr[0, 0].imshow(
            heatmap[z, :, :], cmap='jet', alpha=alpha, extent=img0.get_extent()
        )
        _draw_bboxes(axarr[0, 0], heatmap[z, :, :])

    axarr[0, 0].add_patch(Rectangle((-1, -1), v_x, v_y, edgecolor='r', **_rec_prop))
    _draw_line(axarr[0, 0], [(x, 0), (x, v_y)], 'g')
    _draw_line(axarr[0, 0], [(0, y), (v_x, y)], 'b')
    _set_axes_labels(axarr[0, 0], 'X', 'Y')

    img1 = axarr[0, 1].imshow(vol[:, :, x].T, cmap='bone')
    if heatmap is not None:
        axarr[0, 1].imshow(
            heatmap[:, :, x].T, cmap='jet', alpha=alpha, extent=img1.get_extent()
        )
        _draw_bboxes(axarr[0, 1], heatmap[:, :, x].T)

    axarr[0, 1].add_patch(Rectangle((-1, -1), v_z, v_y, edgecolor='g', **_rec_prop))
    _draw_line(axarr[0, 1], [(z, 0), (z, v_y)], 'r')
    # in the Z-Y view the horizontal axis is Z, so the line spans v_z
    _draw_line(axarr[0, 1], [(0, y), (v_z, y)], 'b')
    _set_axes_labels(axarr[0, 1], 'Z', 'Y')

    img2 = axarr[1, 0].imshow(vol[:, y, :], cmap='bone')
    if heatmap is not None:
        axarr[1, 0].imshow(
            heatmap[:, y, :], cmap='jet', alpha=alpha, extent=img2.get_extent()
        )
        _draw_bboxes(axarr[1, 0], heatmap[:, y, :])

    axarr[1, 0].add_patch(Rectangle((-1, -1), v_x, v_z, edgecolor='b', **_rec_prop))
    _draw_line(axarr[1, 0], [(0, z), (v_x, z)], 'r')
    # in the X-Z view the vertical axis is Z, so the line spans v_z
    _draw_line(axarr[1, 0], [(x, 0), (x, v_z)], 'g')
    _set_axes_labels(axarr[1, 0], 'X', 'Z')
    axarr[1, 1].set_axis_off()
    fig.tight_layout()


def interactive_show(volume, heatmap=None):
    """Show a volume interactively"""
    # move the slice (last) axis to the front so that the volume is
    # indexed as (z, y, x), the layout expected by show_volume
    volume = np.transpose(volume, (2, 0, 1))
    if heatmap is not None:
        heatmap = np.transpose(heatmap, (2, 0, 1))
    vol_shape = volume.shape

    interact(
        lambda x, y, z: plt.show(show_volume(volume, z, y, x, heatmap)),
        z=IntSlider(min=0, max=vol_shape[0] - 1, step=1, value=int(vol_shape[0] / 2)),
        y=IntSlider(min=0, max=vol_shape[1] - 1, step=1, value=int(vol_shape[1] / 2)),
        x=IntSlider(min=0, max=vol_shape[2] - 1, step=1, value=int(vol_shape[2] / 2)),
    )

In [None]:
interactive_show(input_volume)

In [None]:
interactive_show(input_volume, resized_heatmap)

## Animations

In [None]:
rc('animation', html='jshtml')


def create_animation(array, case, heatmap=None, alpha=0.3):
    """Create an animation of a volume"""
    array = np.transpose(array, (2, 0, 1))
    if heatmap is not None:
        heatmap = np.transpose(heatmap, (2, 0, 1))
    fig = plt.figure(figsize=(4, 4))
    images = []
    for idx, image in enumerate(array):
        # plot image without notifying animation
        image_plot = plt.imshow(image, animated=True, cmap='bone')
        aux = [image_plot]
        if heatmap is not None:
            image_plot2 = plt.imshow(
                heatmap[idx], animated=True, cmap='jet', alpha=alpha, extent=image_plot.get_extent())
            aux.append(image_plot2)

            # add bounding boxes to the heatmap image as animated patches
            bboxes = get_bounding_boxes(heatmap[idx])
            patches = get_bbox_patches(bboxes)
            aux.extend(image_plot2.axes.add_patch(patch) for patch in patches)
        images.append(aux)

    plt.axis('off')
    plt.tight_layout()
    plt.subplots_adjust(top=0.90)
    plt.title(f'Patient ID: {case}', fontsize=16)
    ani = animation.ArtistAnimation(
        fig, images, interval=5000//len(array), blit=False, repeat_delay=1000)
    plt.close()
    return ani

In [None]:
create_animation(input_volume, 'Test')

In [None]:
create_animation(input_volume, 'Test', heatmap=resized_heatmap)