In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, Concatenate, Dropout

from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

from sklearn.metrics import mean_squared_error, mean_absolute_error

%matplotlib inline

In [None]:
# The simulations can be generated with the data_generation_nb notebook.
simulated_data_path = './generated_data/simulated_data_010124_5K.npy'
simulated_data = np.load(simulated_data_path)

In [None]:
# One seed value shared by every random source configured in this cell.
seed = 159

# NOTE(review): the interpreter reads PYTHONHASHSEED at startup, so setting it here
# presumably only reaches subprocesses launched afterwards — confirm.
os.environ['PYTHONHASHSEED'] = str(seed)

np.random.seed(seed)
tf.random.set_seed(seed)

In [None]:
def _standardize(field):
    """Return `field` shifted/scaled by its own global mean and std, plus both statistics."""
    mean, std = np.mean(field), np.std(field)
    return (field - mean) / std, mean, std


# Index 0 along axis 1 is taken as the initial state, index 1 as the final state.
# Each one is standardized with a single scalar mean/std computed over the whole array.
# NOTE(review): these statistics are computed before the train/test split, so they
# include the samples that later end up in the validation and test sets.
initial_state, initial_state_mean, initial_state_std = _standardize(simulated_data[:, 0])
final_state, final_state_mean, final_state_std = _standardize(simulated_data[:, 1])

# Keep the first simulation's normalized pair for the preview plot.
initial_conditions, final_density = initial_state[0], final_state[0]

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Show index 0 along the leading axis of each selected volume, side by side.
panels = (
    (initial_conditions, "Initial State"),
    (final_density, "Final State"),
)
for ax, (volume, title) in zip(axes, panels):
    ax.imshow(volume[0], cmap='viridis')
    ax.set_title(title)

plt.show()

In [None]:
# Hold out 15% of the simulations as the test set.
X_train, X_test, y_train, y_test = train_test_split(
    initial_state,
    final_state,
    test_size=0.15,
    random_state=123,
)

# Carve a validation set (10% of the remaining samples) out of the training data.
X_train, X_val, y_train, y_val = train_test_split(
    X_train,
    y_train,
    test_size=0.1,
    random_state=123,
)

In [None]:
# Append a trailing channel axis of size 1 to every split so each array becomes a 5D
# tensor suitable for 3D convolutions: (batch_size, height, width, depth, num_channels).
# NOTE: re-running this cell appends yet another axis; re-run the split cell first.
X_train, X_val, X_test = (np.expand_dims(arr, axis=-1) for arr in (X_train, X_val, X_test))
y_train, y_val, y_test = (np.expand_dims(arr, axis=-1) for arr in (y_train, y_val, y_test))

In [None]:
# Shapes of every split, in the order: X_train, X_val, y_train, y_val, X_test, y_test.
tuple(split.shape for split in (X_train, X_val, y_train, y_val, X_test, y_test))

In [None]:
# Save the test split so it can be reloaded if the instance fails
# np.save('X_unet_test.npy', X_test)
# np.save('y_unet_test.npy', y_test)

## Instantiate UNet model

In [None]:
# Parameters
# Ngrid: grid size per spatial axis. The model is built with input shape
# (Ngrid, Ngrid, Ngrid, 1) and the plotting cell reshapes each volume to
# (Ngrid, Ngrid, Ngrid) before slicing it.
Ngrid = 32

In [None]:
def _double_conv(tensor, filters):
    """Apply two consecutive 3x3x3 ReLU convolutions ('same' padding) with `filters` channels."""
    for _ in range(2):
        tensor = Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(tensor)
    return tensor


def create_volumetric_unet(input_shape, num_classes, base_filters=64, dropout_rate=0.5):
    """Build a two-level 3D U-Net.

    The contracting path applies two double-convolution stages, each followed by
    dropout and 2x2x2 max pooling, then a bottleneck double convolution. The
    expansive path upsamples twice and concatenates with the (pre-dropout)
    encoder features of the matching resolution. A final 1x1x1 convolution with
    no activation maps to `num_classes` output channels.

    Args:
        input_shape: Shape of one sample without the batch axis, i.e.
            (dim1, dim2, dim3, channels). Known spatial sizes must be divisible
            by 4 so that two pool/upsample rounds restore the original size.
        num_classes: Number of output channels.
        base_filters: Filters in the first stage; deeper stages use 2x and 4x
            this value. The default reproduces the original 64/128/256 layout.
        dropout_rate: Dropout rate applied before each pooling step.

    Returns:
        An uncompiled Keras `Model` mapping `input_shape` volumes to volumes of
        the same spatial size with `num_classes` channels.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 4.
    """
    # Fail early with a clear message instead of an opaque Concatenate shape mismatch.
    spatial_dims = tuple(input_shape)[:-1]
    bad_dims = [dim for dim in spatial_dims if dim is not None and dim % 4 != 0]
    if bad_dims:
        raise ValueError(
            f"Spatial dimensions must be divisible by 4 for two pooling levels, got {spatial_dims}"
        )

    # Input layer
    inputs = Input(input_shape)

    # Contracting path
    conv1 = _double_conv(inputs, base_filters)
    drop1 = Dropout(dropout_rate)(conv1)
    pool1 = MaxPooling3D(pool_size=(2, 2, 2))(drop1)

    conv2 = _double_conv(pool1, 2 * base_filters)
    drop2 = Dropout(dropout_rate)(conv2)
    pool2 = MaxPooling3D(pool_size=(2, 2, 2))(drop2)

    # Bottleneck
    conv3 = _double_conv(pool2, 4 * base_filters)

    # Expansive path: the skip connections use the encoder features before dropout.
    up1 = Concatenate()([UpSampling3D(size=(2, 2, 2))(conv3), conv2])
    conv4 = _double_conv(up1, 2 * base_filters)

    up2 = Concatenate()([UpSampling3D(size=(2, 2, 2))(conv4), conv1])
    conv5 = _double_conv(up2, base_filters)

    # Final conv layer to reduce channel dimensionality to the number of classes
    conv_final = Conv3D(num_classes, (1, 1, 1))(conv5)

    return Model(inputs=[inputs], outputs=[conv_final])

In [None]:
# Cubic single-channel input volume of side Ngrid; one output channel.
unet_input_shape = (Ngrid,) * 3 + (1,)
unet_model = create_volumetric_unet(input_shape=unet_input_shape, num_classes=1)

In [None]:
# Compile the model for voxel-wise regression with a mean-squared-error loss.
# `run_eagerly=True` was dropped: it is a debugging switch that prevents Keras from
# compiling the train/test steps into a graph, which slows every epoch down without
# changing what is optimized. Re-add it only while debugging the model step by step.
unet_model.compile(optimizer=Adam(), loss='mse')

In [None]:
# Define callbacks. All three watch the validation loss so they agree on what "best" means.

# Stop after 5 epochs without improvement and roll back to the best weights seen.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Cut the learning rate by 5x after 2 stagnant epochs. The patience must be shorter than
# the early-stopping patience, and min_lr must be below the optimizer's starting rate
# (Adam defaults to 1e-3); otherwise this callback can never take effect.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, min_lr=1e-6)

# Keep only the checkpoint with the lowest validation loss (not the training loss).
model_checkpoint = ModelCheckpoint('unet.hdf5', monitor='val_loss', verbose=1, save_best_only=True)

In [None]:
# Query the visible GPUs once and reuse the result for both reports.
gpu_devices = tf.config.list_physical_devices('GPU')

print("Num GPUs Available: ", len(gpu_devices))
# Small tensor op as a smoke test that TensorFlow can execute a computation.
print(tf.reduce_sum(tf.random.normal([1000, 1000])))
print(gpu_devices)

In [None]:
# Fit on the training split, scoring the validation split after every epoch.
training_callbacks = [early_stopping, reduce_lr, model_checkpoint]
history = unet_model.fit(
    X_train,
    y_train,
    batch_size=32,
    epochs=50,
    validation_data=(X_val, y_val),
    callbacks=training_callbacks,
)

In [None]:
# Load the model in case the instance crashed
# unet_model = load_model('unet.hdf5')

In [None]:
# Score the trained network on the held-out test split.
eval_results = unet_model.evaluate(X_test, y_test)
print(f"Evaluation Results:{eval_results}")

In [None]:
# Run inference on the test inputs; the trailing expression displays the output shape.
predictions = unet_model.predict(x=X_test)
predictions.shape

In [None]:
# Shared color limits for the comparison plots. The true targets (y_test) are drawn
# with these limits as well, so they must be included or their extremes get clipped.
plotted_arrays = (predictions, y_test, X_test)
vmin = min(np.min(arr) for arr in plotted_arrays)
vmax = max(np.max(arr) for arr in plotted_arrays)

vmin, vmax

In [None]:
# The network is trained with X = initial state and y = final state, so `predictions`
# and `y_test` are final states while `X_test` holds the given initial states.
n_examples = min(5, len(X_test))  # first 5 test cases, or fewer if the set is smaller
mid_slice = Ngrid // 2

for test_index in range(n_examples):
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    panels = (
        ('Predicted Final State', predictions[test_index]),
        ('True Final State', y_test[test_index]),
        ('Given Initial State', X_test[test_index]),
    )
    for ax, (title, volume) in zip(axs, panels):
        # Drop the channel axis and show the central slice on the shared color scale.
        im = ax.imshow(volume.reshape(Ngrid, Ngrid, Ngrid)[mid_slice], vmin=vmin, vmax=vmax)
        ax.set_title(title)
        fig.colorbar(im, ax=ax)

    plt.show()

In [None]:
# Voxel-wise error metrics over the whole test set; flatten each array once and reuse it.
y_true_flat, y_pred_flat = y_test.flatten(), predictions.flatten()

mse = mean_squared_error(y_true_flat, y_pred_flat)
mae = mean_absolute_error(y_true_flat, y_pred_flat)

print(f"Mean Squared Error (MSE): {mse}")
print(f"Mean Absolute Error (MAE): {mae}")

In [None]:
# True vs. predicted voxel values; the dashed diagonal marks perfect agreement.
fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(y_test.flatten(), predictions.flatten(), alpha=0.3)

target_range = [y_test.min(), y_test.max()]
ax.plot(target_range, target_range, 'k--', lw=4)

ax.set_xlabel('True Values')
ax.set_ylabel('Predictions')
ax.set_title('Scatter plot of True vs Predicted values')
plt.show()