In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Display
from IPython.display import Image, display
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns

In [2]:
PREFIX = '.'
SHAPE = (197, 233, 189)
MODEL_PATH = f'{PREFIX}/models/vit/mri_activation/vit_vit_(197, 233, 189)_[128]_0.0001_32_0.1_18_10_361_256_4_[512, 256]_0.15_checkpoint.h5'
TF_RECORD_PATH = f'{PREFIX}/data/tfrecords/tf_dataset_5.tfrecord'

CLASSES = { 'PD': 1, 'Eat': 0, 'Buy': 2, 'Sex': 3, 'Gamble': 4 }

In [3]:
def parse_example(example):
  features = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64)
  }
  example = tf.io.parse_single_example(example, features)
  image = tf.io.decode_raw(example['image'], tf.float32)
  image = tf.reshape(image, SHAPE)
  image = tf.reshape(image, (SHAPE[0], SHAPE[1], SHAPE[2]))
  #one_hot = tf.one_hot(example['label'], 2)
  return image, example['label']

def load_tfrecord(tfrecord_path):
  dataset = tf.data.TFRecordDataset(tfrecord_path)
  dataset = dataset.map(parse_example)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  #dataset = dataset.cache()
  return dataset

In [12]:
class MLP(layers.Layer):
  def __init__(self, hidden_units=[128, 64], dropout_rate=0.1, activation="gelu", kernel_regularizer=None, **kwargs):
    super(MLP, self).__init__(**kwargs)
    self.dropout_rate = dropout_rate
    self.hidden_units = hidden_units
    self.activation = activation
    self.kernel_regularizer = kernel_regularizer
    for i, units in enumerate(hidden_units):
      setattr(self, f'dense_{i}', layers.Dense(units, activation=self.activation, kernel_regularizer=kernel_regularizer))
      setattr(self, f'dropout_{i}', layers.Dropout(dropout_rate))

  def get_config(self, **kwargs):
    config = super().get_config()
    config_dict = {
        "dropout_rate": self.dropout_rate,
        "hidden_units": self.hidden_units,
        "activation": self.activation,
        "kernel_regularizer": self.kernel_regularizer
    }
    config.update(config_dict)
    return config

  def call(self, x, training=None, **kwargs):
    for i, _ in enumerate(self.hidden_units):
      x = getattr(self, f'dense_{i}')(x)
      x = getattr(self, f'dropout_{i}')(x, training=training)
    return x

class Patches(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def get_config(self, **kwargs):
      config = super().get_config()
      config_dict = {
          "patch_size": self.patch_size
      }
      config.update(config_dict)
      return config

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim=64, **kwargs):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=self.projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=self.projection_dim
        )

    def get_config(self):
      config = super().get_config()
      config_dict = {
          "num_patches": self.num_patches,
          "projection_dim": self.projection_dim
      }

      config.update(config_dict)
      return config

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

def plot_3d_array_image(volume_img_array, slice_to_plot=64, cmap='gray'):
    # rotate the volume for better viewing
    volume_img_array = np.rot90(volume_img_array, 1, (0, 2))
    volume_img_array = np.rot90(volume_img_array, 5, (2, 1))
    # flip horizontally
    volume_img_array = np.flip(volume_img_array, axis=2)
    #plt.imshow(volume_img_array[slice_to_plot], cmap=cmap)
    #plt.show()
    return volume_img_array

def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    # First, we create a model that maps the input image to the activations
    # of the last conv layer as well as the output predictions
    grad_model = tf.keras.models.Model(
    [model.inputs], [model.get_layer(last_conv_layer_name).output, model.output]
    )
    # Then, we compute the gradient of the top predicted class for our input image
    # with respect to the activations of the last conv layer
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(np.array([img_path]))
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
            #print(preds, preds.shape)
            #print('Predicted class:', pred_index.numpy())
        class_channel = preds[:, pred_index]
        #print('class_channel', class_channel, class_channel.shape)
    # This is the gradient of the output neuron (top predicted or chosen)
    # with regard to the output feature map of the last conv layer
    grads = tape.gradient(class_channel, last_conv_layer_output)
    #print('last_conv_layer_output', last_conv_layer_output.shape)
    #print('grads', grads.shape)

    # This is a vector where each entry is the mean intensity of the gradient
    # over a specific feature map channel
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    #pooled_grads = tf.reduce_mean(grads, axis=(0,))
    #pooled_grads = tf.squeeze(pooled_grads)
    #return tf.maximum(pooled_grads, 0) / tf.math.reduce_max(pooled_grads)
    # axis swap
    #pooled_grads = tf.transpose(pooled_grads, (0, 2, 1))
    #print('pooled_grads', pooled_grads.shape)

    # We multiply each channel in the feature map array
    # by "how important this channel is" with regard to the top predicted class
    # then sum all the channels to obtain the heatmap class activation
    last_conv_layer_output = last_conv_layer_output[0]
    #print('last_conv_layer_output', last_conv_layer_output.shape)
    #print('pooled_grads new axis', pooled_grads[..., tf.newaxis].shape)
    #print('pooled_grads[..., tf.newaxis]', pooled_grads[..., tf.newaxis].shape)
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    #heatmap = last_conv_layer_output @ pooled_grads

    #print('heatmap', heatmap.shape)
    heatmap = tf.squeeze(heatmap)
    print('heatmap', heatmap.shape)

    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    #max_value = tf.math.reduce_max(heatmap)
    #min_value = tf.math.reduce_min(heatmap)
    #heatmap = (heatmap - min_value) / (max_value - min_value)
    return heatmap

def make_gradcam_heatmap_3D(img_array, model, last_conv_layer_name, pred_index=None):
    # First, we create a model that maps the input image to the activations
    # of the last conv layer as well as the output predictions
    grad_model = tf.keras.models.Model(
    [model.inputs], [model.get_layer(last_conv_layer_name).output, model.output]
    )
    # Then, we compute the gradient of the top predicted class for our input image
    # with respect to the activations of the last conv layer
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(np.array([img_path]))
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
            #print(preds, preds.shape)
            #print('Predicted class:', pred_index.numpy())
        class_channel = preds[:, pred_index]
        #print('class_channel', class_channel, class_channel.shape)
    # This is the gradient of the output neuron (top predicted or chosen)
    # with regard to the output feature map of the last conv layer
    grads = tape.gradient(class_channel, last_conv_layer_output)
    #print('last_conv_layer_output', last_conv_layer_output.shape)
    #print('grads', grads.shape)

    # This is a vector where each entry is the mean intensity of the gradient
    # over a specific feature map channel
    #pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    pooled_grads = tf.reduce_mean(grads, axis=(0,))
    pooled_grads = tf.squeeze(pooled_grads)
    return tf.maximum(pooled_grads, 0) / tf.math.reduce_max(pooled_grads)
    # axis swap
    #pooled_grads = tf.transpose(pooled_grads, (0, 2, 1))
    #print('pooled_grads', pooled_grads.shape)

    # We multiply each channel in the feature map array
    # by "how important this channel is" with regard to the top predicted class
    # then sum all the channels to obtain the heatmap class activation
    last_conv_layer_output = last_conv_layer_output[0]
    #print('last_conv_layer_output', last_conv_layer_output.shape)
    #print('pooled_grads new axis', pooled_grads[..., tf.newaxis].shape)
    #print('pooled_grads[..., tf.newaxis]', pooled_grads[..., tf.newaxis].shape)
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    #heatmap = last_conv_layer_output @ pooled_grads

    #print('heatmap', heatmap.shape)
    heatmap = tf.squeeze(heatmap)
    print('heatmap', heatmap.shape)

    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    #max_value = tf.math.reduce_max(heatmap)
    #min_value = tf.math.reduce_min(heatmap)
    #heatmap = (heatmap - min_value) / (max_value - min_value)
    return heatmap

In [5]:
tf_dataset = load_tfrecord(TF_RECORD_PATH)

all_items = 3286
train_size = int(all_items * 0.8)
val_size = int(all_items - train_size)

train_dataset = tf_dataset.take(train_size)
val_dataset = tf_dataset.skip(train_size)

mri_classes = {}

MRI_PER_CLASS = 5
print('MRI_PER_CLASS', MRI_PER_CLASS, val_size)
for x, y in val_dataset:
  mri_y_array = y.numpy()

  list_mri_class = mri_classes.get(mri_y_array, [])
  
  #if len(list_mri_class) < MRI_PER_CLASS:
  list_mri_class.append(x.numpy())
  mri_classes[mri_y_array] = list_mri_class
  continue

  completed = True

  for key_class in mri_classes:
    if len(mri_classes[key_class]) < MRI_PER_CLASS:
      completed = False
      break

  if len(mri_classes) == 5 and completed:
    break


MRI_PER_CLASS 5 658


In [6]:
img_size = SHAPE
preprocess_input = keras.applications.xception.preprocess_input
decode_predictions = keras.applications.xception.decode_predictions

model = tf.keras.models.load_model(
  MODEL_PATH,
  custom_objects={ 'MLP': MLP, 'Patches': Patches, 'PatchEncoder': PatchEncoder },
  compile=False
)

In [11]:
mri_classes[0][0].shape

(197, 233, 189)

In [16]:
last_conv_layer_name = model.layers[1].name

model.layers[-1].activation = None

SAVE_DIR = './data/activations'

for label_class, id_class in CLASSES.items():
  print('Class:', label_class)
  os.makedirs(os.path.join(SAVE_DIR, label_class), exist_ok=True)
  
  list_mri_classes = mri_classes.get(id_class, [])
  if len(list_mri_classes) == 0:
    continue

  for idx_subject, img_path in enumerate(list_mri_classes):
    slice_to_plot = 124
    print('idx_subject', idx_subject)

    subject_folder = os.path.join(SAVE_DIR, label_class, f'subject_{idx_subject}')
    os.makedirs(subject_folder, exist_ok=True)

    mri_path = os.path.join(subject_folder, 'mri.npy')
    activation_path = os.path.join(subject_folder, 'activation.npy')
    

    #original_volumen = plot_3d_array_image(img_path, slice_to_plot, cmap='gray')
    gradcam_heatmap = make_gradcam_heatmap(img_path, model, last_conv_layer_name, pred_index=None)
    
    # plot heatmap
    plt.imshow(gradcam_heatmap, cmap='crest')
    # save heatmap
    plt.savefig(activation_path.replace('.npy', '.png'), dpi=300)
    plt.close()
    #gradcam_heatmap = plot_3d_array_image(gradcam_heatmap, slice_to_plot, cmap='gray')

    #np.save(mri_path, original_volumen)
    #np.save(activation_path, gradcam_heatmap)

    #fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    #ax1.imshow(original_volumen[slice_to_plot], cmap='gray')
    #plt.colorbar(ax1.get_images()[0], ax=ax1, orientation='vertical', fraction=0.046, pad=0.04)

    #ax2.imshow(original_volumen[slice_to_plot], cmap='gray')
    #ax2.imshow(gradcam_heatmap[slice_to_plot], cmap='crest', alpha=1)
    #plt.colorbar(ax2.get_images()[0], ax=ax2, orientation='vertical', fraction=0.046, pad=0.04)

    #fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    #ax1.imshow(original_volumen[slice_to_plot], cmap='gray')
    #plt.colorbar(ax1.get_images()[0], ax=ax1, orientation='vertical', fraction=0.046, pad=0.04)

    #ax2.imshow(original_volumen[slice_to_plot], cmap='gray')
    #ax2.imshow(gradcam_heatmap, cmap='crest', alpha=1)
    #plt.colorbar(ax2.get_images()[0], ax=ax2, orientation='vertical', fraction=0.046, pad=0.04)
    
    #plt.show()

    #plt.imshow(original_volumen[slice_to_plot], cmap='gray')
    #plt.imshow(gradcam_heatmap, cmap='crest', alpha=0.75)
    #plt.show()
    #sns.heatmap(gradcam_heatmap, cmap='crest')
    #plt.show()


Class: PD
idx_subject 0
heatmap (197, 233)
idx_subject 1
heatmap (197, 233)
idx_subject 2
heatmap (197, 233)
idx_subject 3
heatmap (197, 233)
idx_subject 4
heatmap (197, 233)
idx_subject 5
heatmap (197, 233)
idx_subject 6
heatmap (197, 233)
idx_subject 7
heatmap (197, 233)
idx_subject 8
heatmap (197, 233)
idx_subject 9
heatmap (197, 233)
idx_subject 10
heatmap (197, 233)
idx_subject 11
heatmap (197, 233)
idx_subject 12
heatmap (197, 233)
idx_subject 13
heatmap (197, 233)
idx_subject 14
heatmap (197, 233)
idx_subject 15
heatmap (197, 233)
idx_subject 16
heatmap (197, 233)
idx_subject 17
heatmap (197, 233)
idx_subject 18
heatmap (197, 233)
idx_subject 19
heatmap (197, 233)
idx_subject 20
heatmap (197, 233)
idx_subject 21
heatmap (197, 233)
idx_subject 22
heatmap (197, 233)
idx_subject 23
heatmap (197, 233)
idx_subject 24
heatmap (197, 233)
idx_subject 25
heatmap (197, 233)
idx_subject 26
heatmap (197, 233)
idx_subject 27
heatmap (197, 233)
idx_subject 28
heatmap (197, 233)
idx_subject 29

In [None]:
gradcam_heatmap

In [None]:
last_conv_layer_name = model.layers[1].name

img_path = mri_classes[1]
slice_to_plot = 64

plot_3d_array_image(img_path, slice_to_plot, cmap='gray')

gradcam_heatmap = make_gradcam_heatmap(img_path, model, last_conv_layer_name, pred_index=None)

sns.heatmap(gradcam_heatmap, cmap='crest')
plt.show()

In [None]:
gradcam_heatmap

In [None]:
last_conv_layer_name = model.layers[1].name

grad_model = tf.keras.models.Model(
[model.inputs], [model.get_layer(last_conv_layer_name).output, model.output]
)
# Then, we compute the gradient of the top predicted class for our input image
# with respect to the activations of the last conv layer
with tf.GradientTape() as tape:
    last_conv_layer_output, preds = grad_model(np.array([img_path]))
    if pred_index is None:
        pred_index = tf.argmax(preds[0])
        print(preds, preds.shape)
        print('Predicted class:', pred_index.numpy())
    class_channel = preds[:, pred_index]
    print('class_channel', class_channel, class_channel.shape)
# This is the gradient of the output neuron (top predicted or chosen)
# with regard to the output feature map of the last conv layer
grads = tape.gradient(class_channel, last_conv_layer_output)
print('last_conv_layer_output', last_conv_layer_output.shape)
print('grads', grads.shape)

# This is a vector where each entry is the mean intensity of the gradient
# over a specific feature map channel
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
print('pooled_grads', pooled_grads.shape)

# We multiply each channel in the feature map array
# by "how important this channel is" with regard to the top predicted class
# then sum all the channels to obtain the heatmap class activation
last_conv_layer_output = last_conv_layer_output[0]
print('last_conv_layer_output', last_conv_layer_output.shape)
print('pooled_grads[..., tf.newaxis]', pooled_grads[..., tf.newaxis].shape)
heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
print('heatmap', heatmap.shape)
heatmap = tf.squeeze(heatmap)
print('heatmap', heatmap.shape)

plt.matshow(heatmap)
plt.show()