In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from sklearn.model_selection import train_test_split

from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, Concatenate, Dropout, PReLU

from sklearn.metrics import mean_squared_error, mean_absolute_error

%matplotlib inline

In [2]:
# Motivation:
# https://github.com/MiguelMonteiro/VNet-Tensorflow/blob/master/VNet.py

# https://paperswithcode.com/method/prelu
    
# PReLU layer
# f(x) = alpha * x for x < 0
# f(x) = x for x >= 0

def conv3d_block(input_tensor, n_filters, kernel_size_ip = 3, batchnorm = True):
    """Apply two consecutive Conv3D -> (BatchNormalization) -> PReLU stages.

    Parameters
    ----------
    input_tensor
        Keras tensor fed to the first convolution.
    n_filters
        Number of filters used by both convolutions.
    kernel_size_ip
        Edge length of the cubic convolution kernel.
    batchnorm
        When truthy, a BatchNormalization layer follows each convolution.

    Returns
    -------
    The output tensor of the second PReLU activation.
    """
    # The notebook never imports a `layers` module, so resolve it through the
    # already-imported `tf` to avoid a NameError on the first call.
    layers = tf.keras.layers
    kernel_size = (kernel_size_ip, kernel_size_ip, kernel_size_ip)

    x = input_tensor
    # Both stages are identical, so build them in a loop (same layer order as
    # writing them out by hand: conv, optional batch-norm, PReLU).
    for _ in range(2):
        x = layers.Conv3D(filters = n_filters,
                          kernel_size = kernel_size,
                          kernel_initializer = 'he_normal',
                          padding = 'same')(x)
        if batchnorm:
            x = layers.BatchNormalization()(x)
        x = layers.PReLU()(x)
    return x

def vnet(input_img, n_filters = 64, dropout = 0.5, batchnorm = True):
    """Build a two-level 3D encoder/decoder network with skip connections.

    The contracting path applies two conv blocks, each followed by 2x2x2
    max-pooling and dropout, then a bottleneck conv block. The expansive path
    mirrors it with stride-2 transposed convolutions whose outputs are
    concatenated with the matching encoder features. A final 1x1x1 convolution
    with linear activation produces a single output channel.

    Parameters
    ----------
    input_img
        Keras input tensor. Because the volume is pooled by 2 twice and then
        upsampled by 2 twice, its spatial dimensions need to be divisible by 4
        for the skip-connection concatenations to line up.
    n_filters
        Number of filters in the first level; deeper levels use 2x and 4x.
    dropout
        Rate passed to every Dropout layer.
    batchnorm
        Forwarded to `conv3d_block` to toggle batch normalisation.

    Returns
    -------
    An uncompiled `tf.keras.Model` mapping `input_img` to the 1-channel output.
    """
    # The notebook never imports a `layers` module, so resolve it through the
    # already-imported `tf` to avoid a NameError on the first call.
    layers = tf.keras.layers

    # Contracting Path
    c1 = conv3d_block(input_img, n_filters * 1, kernel_size_ip = 3, batchnorm = batchnorm)
    p1 = layers.MaxPooling3D((2, 2, 2))(c1)
    p1 = layers.Dropout(dropout)(p1)

    c2 = conv3d_block(p1, n_filters * 2, kernel_size_ip = 3, batchnorm = batchnorm)
    p2 = layers.MaxPooling3D((2, 2, 2))(c2)
    p2 = layers.Dropout(dropout)(p2)

    # Bottleneck
    c3 = conv3d_block(p2, n_filters * 4, kernel_size_ip = 3, batchnorm = batchnorm)

    # Expansive Path: upsample, merge with the encoder features of the same
    # resolution, regularise, then refine with a conv block.
    u4 = layers.Conv3DTranspose(n_filters * 2, (3, 3, 3), strides = (2, 2, 2), padding = 'same')(c3)
    u4 = layers.concatenate([u4, c2])
    u4 = layers.Dropout(dropout)(u4)
    c4 = conv3d_block(u4, n_filters * 2, kernel_size_ip = 3, batchnorm = batchnorm)

    u5 = layers.Conv3DTranspose(n_filters * 1, (3, 3, 3), strides = (2, 2, 2), padding = 'same')(c4)
    u5 = layers.concatenate([u5, c1])
    u5 = layers.Dropout(dropout)(u5)
    c5 = conv3d_block(u5, n_filters * 1, kernel_size_ip = 3, batchnorm = batchnorm)

    # Linear 1x1x1 projection to a single regression channel.
    outputs = layers.Conv3D(1, (1, 1, 1), activation='linear')(c5)
    model = tf.keras.Model(inputs=[input_img], outputs=[outputs])
    return model

In [None]:
# Load the pre-computed simulations; the file is produced by the
# data_generation_nb notebook.
simulated_data = np.load(file='./generated_data/simulated_data_020124_150K.npy')

In [None]:
# Single seed shared by every random number source used in this notebook.
SEED = 159

# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes - confirm.
os.environ['PYTHONHASHSEED'] = str(SEED)

np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Index 0 along axis 1 holds the initial state, index 1 the final state.
initial_state, final_state = simulated_data[:, 0], simulated_data[:, 1]

# Scalar statistics taken over all elements of each field.
initial_state_mean = np.mean(initial_state)
initial_state_std = np.std(initial_state)
final_state_mean = np.mean(final_state)
final_state_std = np.std(final_state)

# Standardise each field to zero mean and unit variance.
initial_state = (initial_state - initial_state_mean) / initial_state_std
final_state = (final_state - final_state_mean) / final_state_std

# Keep the first simulation around for a quick visual sanity check.
initial_conditions, final_density = initial_state[0], final_state[0]

In [None]:
# Show the first slice of the selected simulation's initial and final states
# side by side.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
preview_panels = [
    (initial_conditions, "Initial State"),
    (final_density, "Final State"),
]
for ax, (volume, panel_title) in zip(axes, preview_panels):
    ax.imshow(volume[0], cmap='viridis')
    ax.set_title(panel_title)

plt.show()

In [None]:
# Hold out 15% of the simulations as the test set.
# Note: X_* holds initial states and y_* holds final states.
X_train, X_test, y_train, y_test = train_test_split(
    initial_state,
    final_state,
    test_size=0.15,
    random_state=123,
)

# Carve a 10% validation set out of the remaining training data.
X_train, X_val, y_train, y_val = train_test_split(
    X_train,
    y_train,
    test_size=0.1,
    random_state=123,
)

In [None]:
# Append a trailing channel axis so every array becomes a 5-D tensor suitable
# for 3D convolutions: (batch_size, height, width, depth, num_channels).
X_train = np.expand_dims(X_train, axis=-1)
X_val = np.expand_dims(X_val, axis=-1)
X_test = np.expand_dims(X_test, axis=-1)

y_train = np.expand_dims(y_train, axis=-1)
y_val = np.expand_dims(y_val, axis=-1)
y_test = np.expand_dims(y_test, axis=-1)

In [None]:
# Shapes of every split, in the order train/val inputs and targets, then test.
tuple(split.shape for split in (X_train, X_val, y_train, y_val, X_test, y_test))

In [None]:
# Save the test split so evaluation can be resumed if the instance fails.
# np.save('X_test.npy', X_test)
# np.save('y_test.npy', y_test)

## Instantiate VNet model

In [None]:
# Parameters
# Number of cells along each axis of the cubic grid; used below to reshape
# volumes for plotting.
Ngrid = 32  # grid size

In [None]:
# Derive the network input shape from Ngrid instead of repeating the literal
# 32, so the model and the plotting code cannot drift apart. The trailing 1 is
# the channel axis.
input_shape = (Ngrid, Ngrid, Ngrid, 1)
input_img = tf.keras.Input(shape=input_shape)

In [None]:
# Build the network with 128 base filters, 50% dropout and batch normalisation.
vnet_model = vnet(
    input_img,
    n_filters=128,
    dropout=0.5,
    batchnorm=True,
)

In [None]:
# Compile the model: mean-squared-error loss, with mean absolute error tracked
# as an additional metric.
# `run_eagerly=True` was removed: it is a debugging switch that disables graph
# compilation of the train step and makes training considerably slower. Add it
# back temporarily only when stepping through the model in a debugger.
vnet_model.compile(optimizer=Adam(),
                   loss='mse',
                   metrics=['mae'])

In [None]:
# Stop once validation loss has not improved for 5 epochs and roll back to the
# best weights seen so far.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Reduce the learning rate on a validation plateau.
# - min_lr must be below the optimizer's starting rate (Adam defaults to 1e-3);
#   with min_lr=0.001 the rate could never be lowered.
# - patience must be shorter than EarlyStopping's, otherwise the reduction
#   would only trigger on the very epoch training stops.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, min_lr=1e-6)

# Keep the checkpoint with the best *validation* loss, consistent with the
# other two callbacks; monitoring the training loss would favour over-fitted
# epochs.
# NOTE(review): the checkpoint is written to 'vnet.h5' while the evaluation
# section loads 'vnet_model.h5' - confirm which file name is intended.
checkpoint = ModelCheckpoint('vnet.h5', monitor='val_loss', verbose=1, save_best_only=True)

In [None]:
# Query the visible GPUs once and reuse the result for both reports.
gpu_devices = tf.config.list_physical_devices('GPU')

print("Num GPUs Available: ", len(gpu_devices))
# Small tensor computation as a smoke test that TensorFlow can execute ops.
print(tf.reduce_sum(tf.random.normal([1000, 1000])))
print(gpu_devices)

In [None]:
# The split cell assigned X_* = initial states and y_* = final states. The
# network is trained in the opposite direction (final state in, initial state
# out), hence the deliberate swap of x and y here and in the validation data.
training_callbacks = [early_stopping, reduce_lr, checkpoint]

history = vnet_model.fit(
    x=y_train,
    y=X_train,
    validation_data=(y_val, X_val),
    batch_size=32,
    epochs=50,
    callbacks=training_callbacks,
)

In [None]:
# Load the arrays from .npy files in case the instance needs to be started from here.
# y_test = np.load('y_test.npy')
# X_test = np.load('X_test.npy')

In [None]:
# Load a previously saved model so the notebook can be resumed from this cell
# after a kernel failure (an earlier run used 'old_vnet_model_5K.h5').
# NOTE(review): this rebinds the name `vnet`, which until now referred to the
# model-building function; the instantiation cell cannot be re-run after this
# one without re-running the function definitions first.
# NOTE(review): the ModelCheckpoint callback writes 'vnet.h5', not
# 'vnet_model.h5' - confirm which file is meant to be loaded here.
vnet = load_model('vnet_model.h5')

In [None]:
# Print the layer-by-layer architecture of the loaded model.
vnet.summary()

In [None]:
# Evaluate in the same direction the network was trained in the fit cell:
# final states (y_*) are the inputs and initial states (X_*) are the targets.
# Passing x=X_test, y=y_test would score the model on the reverse mapping.
# NOTE(review): assumes the loaded model was trained like the fit cell in this
# notebook - confirm for externally saved model files.
eval_results = vnet.evaluate(x=y_test, y=X_test)
print(f"Evaluation Results:{eval_results}")

In [None]:
# Predict initial states from the test-set final states. y_test holds the
# final states (see the split cell), which is what the network takes as input
# in the fit cell; feeding X_test would hand it initial states instead.
predictions = vnet.predict(y_test)
predictions.shape

In [None]:
# Minimum and maximum of the predictions next to those of X_test.
predictions.min(), X_test.min(), predictions.max(), X_test.max()

In [None]:
# Shared colour limits for the comparison plots. All three arrays drawn there
# (predictions, X_test and y_test) must contribute, otherwise values of the
# omitted array that fall outside the range are clipped to the end colours.
plotted_arrays = (predictions, X_test, y_test)
vmin = min(np.min(arr) for arr in plotted_arrays)
vmax = max(np.max(arr) for arr in plotted_arrays)

vmin, vmax

In [None]:
# Plot the central slice of the first few test cases (at most 5).
# X_test holds the initial states and y_test the final states (see the split
# cell), so the "true initial" panel must show X_test and the "given final"
# panel y_test.
comparison_panels = (
    ('Predicted Initial State', predictions),
    ('True Initial State', X_test),
    ('Given Final State', y_test),
)

for test_index in range(min(5, len(predictions))):
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    for ax, (panel_title, volumes) in zip(axs, comparison_panels):
        # Drop the channel axis and take the middle slice along the first axis.
        mid_slice = volumes[test_index].reshape(Ngrid, Ngrid, Ngrid)[Ngrid//2]
        im = ax.imshow(mid_slice, vmin=vmin, vmax=vmax)
        ax.set_title(panel_title)
        fig.colorbar(im, ax=ax)

    plt.show()

In [None]:
# Error metrics between the true and predicted initial states.
# The regression target is X_test (initial states); y_test holds the final
# states that serve as the network input, so it is not the right reference.
# ravel() avoids the extra copy flatten() makes of these large arrays.
true_initial = X_test.ravel()
predicted_initial = predictions.ravel()

mse = mean_squared_error(true_initial, predicted_initial)
mae = mean_absolute_error(true_initial, predicted_initial)

print(f"Mean Squared Error (MSE): {mse}")
print(f"Mean Absolute Error (MAE): {mae}")

In [None]:
# Scatter of true vs predicted initial-state values. The target is X_test
# (initial states); y_test holds the final states used as network input.
true_values = X_test.ravel()
predicted_values = predictions.ravel()

# Drawing one marker per grid cell of every test volume is impractical for a
# large test set, so plot a fixed-size random subset when there are more points
# than the cap. Smaller inputs are plotted in full.
MAX_SCATTER_POINTS = 200_000
if true_values.size > MAX_SCATTER_POINTS:
    sample_idx = np.random.default_rng(0).integers(0, true_values.size, size=MAX_SCATTER_POINTS)
    true_values = true_values[sample_idx]
    predicted_values = predicted_values[sample_idx]

plt.figure(figsize=(10,10))
plt.scatter(true_values, predicted_values, alpha=0.3)
# Identity line spanning the full range of the true values.
plt.plot([X_test.min(), X_test.max()], [X_test.min(), X_test.max()], 'k--', lw=4)
plt.xlabel('True Values')
plt.ylabel('Predictions')
plt.title('Scatter plot of True vs Predicted values')
plt.show()

------
# Script Complete