# Install and import dependencies

In [0]:
!pip install h5py
!pip install tensorflow-gpu
!pip install wandb

In [0]:
import numpy as np
import os
import time
import matplotlib.pyplot as plt
import h5py
import wandb
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPool2D, UpSampling2D, Concatenate, Input, BatchNormalization
from tensorflow.keras.activations import relu, softmax
from tensorflow.keras.models import Model
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import load_model
from wandb.keras import WandbCallback

# Display the GPU device TensorFlow will use (an empty string means CPU only).
gpu_device = tf.test.gpu_device_name()
gpu_device

# Fetch the dataset from Google Drive

In [0]:
# Mount Google Drive (Colab only) so the HDF5 dataset below can be read from it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [0]:
# Open the HDF5 dataset read-only and list the arrays stored in it.
# The handle is intentionally left open: the next cell reads the splits from it.
DATA_PATH = '/content/drive/My Drive/Colab Notebooks/heart_chambers_data.h5'
data = h5py.File(DATA_PATH, 'r')
print(list(data))

In [0]:
# Read each split fully into memory as a NumPy array.
# `dset[()]` is the supported way to read a whole dataset; the old
# `Dataset.value` property was deprecated in h5py 2.x and removed in h5py 3.0
# (it raises AttributeError there).
x_train, y_train, x_val, y_val, x_test, y_test = (
    data[split][()]
    for split in ("x_train", "y_train", "x_val", "y_val", "x_test", "y_test")
)

# Data augmentation (ImageDataGenerator, applied jointly to images and masks)

In [0]:
train_datagen = ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)


class PairedAugmentSequence(tf.keras.utils.Sequence):
    """Yield (image, mask) batches with the *same* random transform applied to both.

    `ImageDataGenerator.flow(x, y)` only transforms `x` and passes `y` through
    unchanged, which for segmentation leaves every augmented image paired with
    an un-augmented mask. Here one set of random transform parameters is drawn
    per sample and applied to the image and its mask alike.

    Args:
        images: indexable array of input images, one (H, W, C) image per item.
        masks: indexable array of masks aligned with `images`; the last axis is
            assumed to be one-hot class channels.
        datagen: configured ImageDataGenerator providing the random transforms.
        batch_size: samples per batch (the last batch may be smaller).
        seed: optional seed for the per-epoch shuffling order.
    """

    def __init__(self, images, masks, datagen, batch_size=32, seed=None):
        super().__init__()
        if len(images) != len(masks):
            raise ValueError("images and masks must contain the same number of samples")
        self.images = images
        self.masks = masks
        self.datagen = datagen
        self.batch_size = batch_size
        self._rng = np.random.RandomState(seed)
        self._order = self._rng.permutation(len(images))

    def __len__(self):
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, idx):
        batch_ids = self._order[idx * self.batch_size:(idx + 1) * self.batch_size]
        image_batch, mask_batch = [], []
        for j in batch_ids:
            image = np.asarray(self.images[j], dtype="float32")
            mask = np.asarray(self.masks[j], dtype="float32")
            # One random transform per sample, shared by image and mask.
            params = self.datagen.get_random_transform(image.shape)
            image_batch.append(self.datagen.apply_transform(image, params))
            warped = self.datagen.apply_transform(mask, params)
            # Shear/zoom interpolate, which blends one-hot labels at class
            # boundaries; snap each pixel back to its most likely class.
            mask_batch.append(np.eye(warped.shape[-1], dtype="float32")[warped.argmax(axis=-1)])
        return np.stack(image_batch), np.stack(mask_batch)

    def on_epoch_end(self):
        # Reshuffle the sample order between epochs, as flow(shuffle=True) did.
        self._rng.shuffle(self._order)


train_image_generator = PairedAugmentSequence(x_train, y_train, train_datagen, batch_size=32)

# Convolution blocks for UNet

In [0]:
def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """U-Net encoder stage: two ReLU convolutions followed by 2x2 max-pooling.

    Returns:
        (features, pooled): `features` is the output of the second convolution
        (used as the skip connection); `pooled` is `features` downsampled by a
        2x2 / stride-2 max-pool (fed to the next stage).
    """
    features = x
    for _ in range(2):
        features = Conv2D(filters, kernel_size, padding=padding, strides=strides,
                          activation="relu")(features)
    pooled = MaxPool2D(pool_size=(2, 2), strides=(2, 2))(features)
    return features, pooled


def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1):
    """U-Net decoder stage: 2x upsample, concatenate the skip tensor, two ReLU convs.

    `skip` must have the same spatial size as `x` after 2x upsampling.
    """
    upsampled = UpSampling2D(size=(2, 2))(x)
    merged = Concatenate()([upsampled, skip])
    for _ in range(2):
        merged = Conv2D(filters, kernel_size, padding=padding, strides=strides,
                        activation="relu")(merged)
    return merged


def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """U-Net bottleneck: two ReLU convolutions with no pooling or upsampling."""
    out = x
    for _ in range(2):
        out = Conv2D(filters, kernel_size, padding=padding, strides=strides,
                     activation="relu")(out)
    return out

# U-Net model definition

In [0]:
def create_UNet(input_shape=(256, 256, 1), num_classes=4):
    """Build a 5-level U-Net that outputs per-pixel class probabilities.

    Five encoder stages (16..256 filters) each halve the spatial size, a
    512-filter bottleneck sits at the bottom, and five decoder stages upsample
    back while concatenating the matching encoder features. A 1x1 softmax
    convolution produces one probability channel per class.

    Args:
        input_shape: (H, W, channels) of the input images. H and W must be
            divisible by 32 (five 2x poolings) so the skip connections line up.
            Defaults to the original (256, 256, 1).
        num_classes: number of output channels / classes (default 4).

    Returns:
        An uncompiled `Model` mapping `input_shape` to (H, W, num_classes).

    Raises:
        ValueError: if a known spatial dimension is not divisible by 32.
    """
    filters = [16, 32, 64, 128, 256, 512]
    factor = 2 ** (len(filters) - 1)  # total downsampling of the encoder
    for dim in input_shape[:2]:
        if dim is not None and dim % factor:
            raise ValueError(
                f"input height/width must be divisible by {factor}, got {tuple(input_shape[:2])}")

    inputs = Input(input_shape)

    # Encoder: for the default 256x256 input, 256 -> 128 -> 64 -> 32 -> 16 -> 8.
    skips = []
    x = inputs
    for f in filters[:-1]:
        skip, x = down_block(x, f)
        skips.append(skip)

    x = bottleneck(x, filters[-1])

    # Decoder: mirror of the encoder, consuming skip connections deepest-first.
    for f, skip in zip(reversed(filters[:-1]), reversed(skips)):
        x = up_block(x, skip, f)

    outputs = Conv2D(num_classes, (1, 1), padding="same", activation=softmax)(x)
    return Model(inputs, outputs)

# Decode one-hot / probability masks into grayscale images

In [0]:
threshold = 0.8  # minimum channel value for a pixel to be painted with that class


def one_hot_decode(mask, min_prob=None):
    """Convert a one-hot or class-probability mask into a grayscale image.

    Class ``i`` is painted with intensity ``i * step`` wherever its channel is
    strictly greater than the threshold, where ``step = 255 // (C - 1)``. For
    the 4-class masks used here, step is 85, so classes map to 0, 85, 170, 255.
    When several channels exceed the threshold at one pixel, the highest class
    index wins. Pixels where no channel exceeds it stay 0, the same value as
    class 0.

    Args:
        mask: array of shape (H, W, C); this notebook uses (256, 256, 4).
        min_prob: optional override for the module-level ``threshold``, which is
            read at call time when this is omitted.

    Returns:
        float64 array of shape (H, W).
    """
    cutoff = threshold if min_prob is None else min_prob
    mask = np.asarray(mask)
    n_classes = mask.shape[-1]
    step = 255 // max(n_classes - 1, 1)
    decoded_mask = np.zeros(mask.shape[:2])
    # Vectorised per-class assignment. Later classes overwrite earlier ones,
    # which matches the original pixel-by-pixel loop.
    for i in range(n_classes):
        decoded_mask[mask[..., i] > cutoff] = i * step
    return decoded_mask

# Train, evaluate, log, and save the model

In [0]:
epochs = 500
num_classes = 4


class ArgmaxMeanIoU(MeanIoU):
    """Mean IoU for one-hot targets and softmax predictions.

    `MeanIoU` expects integer class ids. Fed one-hot `y_true` and softmax
    `y_pred` directly, it casts both to integers. That compares 0/1 one-hot
    entries against truncated probabilities instead of class labels.
    Taking argmax over the channel axis first yields the intended IoU.
    NOTE: a model saved with this metric must be reloaded with
    `load_model(path, custom_objects={"ArgmaxMeanIoU": ArgmaxMeanIoU})` or with
    `compile=False`.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        return super().update_state(
            tf.argmax(y_true, axis=-1),
            tf.argmax(y_pred, axis=-1),
            sample_weight=sample_weight)


run = wandb.init(project="Unet-heart-chamber", entity="damirj")

model = create_UNet()
model.compile(optimizer="adam", loss="categorical_crossentropy",
              metrics=[ArgmaxMeanIoU(num_classes, name="mean_i_o_u")])
model.fit(train_image_generator, epochs=epochs, validation_data=(x_val, y_val),
          callbacks=[WandbCallback()])

# evaluate() returns [loss, mean_i_o_u]; index 1 is the mean IoU, not accuracy.
test_evaluation = model.evaluate(x_test, y_test, batch_size=50, verbose=0)
wandb.log({"Test mean IoU": test_evaluation[1]})

predicted = model.predict(x_test)
stacked_images = []
# Every 8th test sample among the first 80 (fewer if the test set is smaller).
for i in range(0, min(80, len(x_test)), 8):
  original_image = x_test[i].reshape((256, 256)) * 255
  ground_truth_mask = one_hot_decode(y_test[i])
  prediction = one_hot_decode(predicted[i])
  # Per-sample [loss, mean_i_o_u] for the caption.
  prediction_eval = model.evaluate(x_test[i].reshape((1, 256, 256, 1)),
                                   y_test[i].reshape((1, 256, 256, num_classes)), verbose=0)
  stacked_img = np.hstack((original_image, ground_truth_mask, prediction))
  stacked_images.append(wandb.Image(stacked_img, caption="mean IoU: {}".format(prediction_eval[1])))

wandb.log({"Predictions (Original, Ground truth, Prediction)": stacked_images})

model.save(os.path.join(wandb.run.dir, "model.h5"))
# Close the run so the saved model and logged media are synced.
run.finish()