In [None]:
# Imports 
import numpy as np 
import matplotlib.pyplot as plt
import keras
from keras import layers
import io 
import imageio
from IPython.display import Image, display
from ipywidgets import widgets, Layout, HBox
from keras.callbacks import ModelCheckpoint
from keras.models import load_model
import tensorflow as tf



In [None]:
# Read the raw train / validation sequence arrays from disk
train_dataset = np.load("train_dataset_128.npy")
val_dataset = np.load("val_dataset_128.npy")

# NOTE(review): no /255 normalisation is applied here — it was disabled:
#   train_dataset = train_dataset / 255.0
#   val_dataset = val_dataset / 255.0
# so the arrays keep whatever scale and dtype they were saved with.

# Report the raw array shapes
print(f"Training Dataset Shapes: {train_dataset.shape}")
print(f"Validation Dataset Shapes: {val_dataset.shape}")

# Define a helper function to shift the frames
def create_shifted_frames(data):
    """Split sequences into (input, target) pairs for next-frame prediction.

    Axis 1 is treated as the time axis. The inputs are frames ``0 .. T-2`` and
    the targets are frames ``1 .. T-1``, so ``y[:, t]`` is the frame that follows
    ``x[:, t]``. All trailing axes are passed through unchanged, so any array
    with at least two dimensions is accepted.

    Args:
        data: array shaped (num_sequences, T, ...).

    Returns:
        tuple: ``(x, y)`` — slices of ``data`` with ``T - 1`` time steps each.
    """
    # Ellipsis instead of a fixed number of ':' so the rank is not hard-coded.
    x = data[:, :-1, ...]
    y = data[:, 1:, ...]
    return x, y

# Turn both splits into (input frames, next-frame targets) pairs
(x_train, y_train), (x_val, y_val) = (
    create_shifted_frames(train_dataset),
    create_shifted_frames(val_dataset),
)

# Report the shapes of the shifted inputs and targets
print(f"Training Dataset Shapes: {x_train.shape}, {y_train.shape}")
print(f"Validation Dataset Shapes: {x_val.shape}, {y_val.shape}")


In [None]:
# Visualize a sample sequence from the dataset
%matplotlib inline

# A 4 x 5 grid: one panel per frame of a randomly picked training sequence
fig, axes = plt.subplots(4, 5, figsize=(10, 8))

data_choice = np.random.randint(len(train_dataset))
chosen_sequence = train_dataset[data_choice]

for frame_idx, frame_ax in enumerate(axes.flat):
    frame_ax.imshow(np.squeeze(chosen_sequence[frame_idx]), cmap="gray")
    frame_ax.set_title(f"Frame {frame_idx + 1}")
    frame_ax.axis("off")

print(f"Displaying frames for example {data_choice}.")
plt.tight_layout()
plt.show()


In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint

# Suffix embedded in every checkpoint file name. It must describe the loss the
# model is actually compiled with (dice_coef_loss), not SSIM. With '_128_SSIM'
# the epoch-10 checkpoint was written as ..._128_SSIM.h5, while the reload cell
# looks for "conv_lstm_model_128_epoch_10_128_dice.h5".
model_info = '_128_dice'

# Save the full model (not only the weights) at the end of every epoch.
# The doubled braces survive the f-string so Keras fills in {epoch:02d} itself.
checkpoint_callback = ModelCheckpoint(
    filepath=f'conv_lstm_model_128_epoch_{{epoch:02d}}{model_info}.h5',
    save_freq='epoch',
    save_weights_only=False,
    verbose=1
)

In [None]:
from tensorflow.image import ssim

def ssim_loss(true, pred):
    """
    Structural-similarity loss: ``1 - mean(SSIM(true, pred))``.

    SSIM is a similarity index (1 for identical images), so it is turned into a
    dissimilarity by subtracting its mean over all images from 1.

    Args:
    true (tensor): Ground-truth images. tf.image.ssim treats the last three axes
        as (height, width, channels). ``max_val=1.0`` assumes values in [0, 1].
    pred (tensor): Predicted images, same shape as ``true``.

    Returns:
    Scalar tensor: the SSIM loss.
    """
    per_image_similarity = ssim(true, pred, max_val=1.0)
    mean_similarity = tf.reduce_mean(per_image_similarity)
    return 1 - mean_similarity


In [None]:
import tensorflow.keras.backend as K

def dice_coef(y_true, y_pred, smooth=1):
    """
    Compute the (soft) Dice coefficient between ground truth and prediction.

    Both tensors are flattened over every axis, including the batch axis, so the
    result is a single scalar for the whole batch.

    Args:
    y_true (tensor): Ground truth mask. It is cast to ``y_pred``'s dtype first,
        so integer (e.g. uint8) or float64 targets are accepted.
    y_pred (tensor): Predicted mask / probabilities.
    smooth (float): Added to numerator and denominator to avoid division by zero
        when both masks are empty.

    Returns:
    Scalar tensor: Dice coefficient (1 for a perfect match).
    """
    # TF refuses to multiply tensors of different dtypes. Targets may reach the
    # loss with a different dtype than the float predictions (e.g. un-normalised
    # uint8 frames), so align them first.
    y_true = tf.cast(y_true, y_pred.dtype)
    # tf.reshape / tf.reduce_sum are the equivalents of K.flatten / K.sum and are
    # available regardless of the installed Keras version.
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    dice = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return dice

def dice_coef_loss(y_true, y_pred):
    """
    Dice loss, defined as ``1 - dice_coef(y_true, y_pred)``.

    Minimising it maximises the overlap between prediction and ground truth.

    Args:
    y_true (tensor): Ground truth mask.
    y_pred (tensor): Predicted mask.

    Returns:
    Scalar tensor: Dice loss.
    """
    coefficient = dice_coef(y_true, y_pred)
    return 1 - coefficient


In [8]:
from tensorflow.keras.callbacks import ModelCheckpoint

# Suffix embedded in every checkpoint file name. It must describe the loss the
# model is actually compiled with (dice_coef_loss), not SSIM. With '_128_SSIM'
# the epoch-10 checkpoint was written as ..._128_SSIM.h5, while the reload cell
# looks for "conv_lstm_model_128_epoch_10_128_dice.h5".
model_info = '_128_dice'

# Save the full model (not only the weights) at the end of every epoch.
# The doubled braces survive the f-string so Keras fills in {epoch:02d} itself.
checkpoint_callback = ModelCheckpoint(
    filepath=f'conv_lstm_model_128_epoch_{{epoch:02d}}{model_info}.h5',
    save_freq='epoch',
    save_weights_only=False,
    verbose=1
)

In [None]:
from keras import layers, models, optimizers
from keras import metrics

def build_conv_lstm_model(input_shape):
    """
    Build a (not yet compiled) ConvLSTM next-frame prediction model.

    Args:
    input_shape (tuple): Shape of a single frame, i.e. everything after the
        (batch, time) axes — e.g. ``x_train.shape[2:]``. The time axis is left as
        ``None``, so sequences of any length are accepted.

    Returns:
    keras.Model: An uncompiled model. It maps a frame sequence to a single-channel
    sequence of the same length; output step ``t`` is trained to match frame ``t + 1``.
    """
    inputs = layers.Input(shape=(None, *input_shape))

    # First level of ConvLSTM
    x = layers.ConvLSTM2D(filters=128, kernel_size=(5, 5), padding="same", return_sequences=True, activation="relu")(inputs)
    x = layers.BatchNormalization()(x)

    # Second level of ConvLSTM
    x = layers.ConvLSTM2D(filters=64, kernel_size=(3, 3), padding="same", return_sequences=True, activation="relu")(x)
    x = layers.BatchNormalization()(x)

    # Third level of ConvLSTM
    x = layers.ConvLSTM2D(filters=64, kernel_size=(1, 1), padding="same", return_sequences=True, activation="relu")(x)
    x = layers.Dropout(0.5)(x)

    # Per-frame projection to the output frame. The temporal kernel size is 1 on
    # purpose. With a temporal size of 3 and "same" padding, output step t would
    # also see the ConvLSTM state at t + 1. That state has consumed input frame
    # t + 1, which is the very frame step t is trained to predict (target leakage).
    outputs = layers.Conv3D(filters=1, kernel_size=(1, 3, 3), activation="sigmoid", padding="same")(x)

    model = models.Model(inputs, outputs)
    return model

# Instantiate the network for the per-frame shape of the training inputs
input_shape = x_train.shape[2:]
conv_lstm_model = build_conv_lstm_model(input_shape)

# Adam with an explicit learning rate
optimizer = optimizers.Adam(learning_rate=0.001)
# NOTE(review): `loss` is not used by compile below (the dice loss is); kept in
# case a later cell refers to it.
loss = keras.losses.BinaryCrossentropy()

# Attach the dice loss and show the architecture
conv_lstm_model.compile(optimizer=optimizer, loss=dice_coef_loss)
conv_lstm_model.summary()


In [None]:
# Training hyper-parameters
epochs, batch_size = 10, 2

# Place training on the first GPU; evaluate on the shifted validation frames each epoch
with tf.device('/GPU:0'):
    history = conv_lstm_model.fit(
        x=x_train,
        y=y_train,
        validation_data=(x_val, y_val),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=[checkpoint_callback],
    )


In [None]:
# Load the trained model. compile=False means the custom dice loss does not have
# to be resolved while loading; it is attached explicitly below.
loaded_model = tf.keras.models.load_model("conv_lstm_model_128_epoch_10_128_dice.h5", compile=False)

# Compile with a *fresh* optimizer (same learning rate as the build cell). The
# global `optimizer` has already been used to train `conv_lstm_model`. Newer
# Keras optimizers (TF >= 2.11 / Keras 3) are bound to the variables they were
# built with and refuse to update another model's variables.
loaded_model.compile(loss=dice_coef_loss, optimizer=keras.optimizers.Adam(learning_rate=0.001))

# loaded_model.summary()

In [None]:
import numpy as np

# Scale the synthetic sequences by 1/255 (presumably 8-bit intensities — confirm)
# and store them as float32.
syn_train_dataset = (np.load("synthetic_train_data_128.npy") / 255.0).astype(np.float32)
syn_val_dataset = (np.load("synthetic_val_data_128.npy") / 255.0).astype(np.float32)

# Report the shapes of the loaded arrays
print(f"Training Dataset Shape: {syn_train_dataset.shape}")
print(f"Validation Dataset Shape: {syn_val_dataset.shape}")

# Helper function to shift the frames
def create_shifted_frames(data):
    """Split sequences into (input, target) pairs for next-frame prediction.

    Axis 1 is treated as the time axis. The inputs are frames ``0 .. T-2`` and
    the targets are frames ``1 .. T-1``, so ``y[:, t]`` is the frame that follows
    ``x[:, t]``. All trailing axes are passed through unchanged, so any array
    with at least two dimensions is accepted.

    Args:
        data: array shaped (num_sequences, T, ...).

    Returns:
        tuple: ``(x, y)`` — slices of ``data`` with ``T - 1`` time steps each.
    """
    # Ellipsis instead of a fixed number of ':' so the rank is not hard-coded.
    x = data[:, :-1, ...]
    y = data[:, 1:, ...]
    return x, y

# Turn both synthetic splits into (input frames, next-frame targets) pairs
(x_train, y_train), (x_val, y_val) = (
    create_shifted_frames(syn_train_dataset),
    create_shifted_frames(syn_val_dataset),
)

# Report the shapes of the shifted inputs and targets
print(f"Training Dataset Shapes: x_train: {x_train.shape}, y_train: {y_train.shape}")
print(f"Validation Dataset Shapes: x_val: {x_val.shape}, y_val: {y_val.shape}")


In [None]:
model_path = "../gamma_index/ConvLSTM_Model_128_dice.h5"
# Alternative starting point (an earlier fine-tuned checkpoint):
# model_path = '../gamma_index/finetuned_ConvLSTM_Model_128_dice_10.h5'

# Fresh Adam (default settings) dedicated to this fine-tuning run
optimizer = keras.optimizers.Adam()

# compile=False: the custom dice loss is attached explicitly instead of being
# resolved from the saved file.
loaded_model = load_model(model_path, compile=False)
loaded_model.compile(optimizer=optimizer, loss=dice_coef_loss)
loaded_model.summary()


In [None]:
# Save the full fine-tuned model after every epoch.
# Alternative run name: '../gamma_index/retuned_finetuned_ConvLSTM_Model_128_dice_{epoch:02d}.h5'
checkpoint_callback = ModelCheckpoint(
    filepath='../gamma_index/finetuned_ConvLSTM_Model_128_dice_{epoch:02d}.h5',
    save_freq='epoch',
    save_weights_only=False,
    verbose=1,
)

# Fine-tuning hyper-parameters
epochs, batch_size = 10, 2

loaded_model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(x_val, y_val),
    callbacks=[checkpoint_callback],
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Compare the model's prediction of the 20th frame with the ground truth for the
# first validation examples.
max_examples = 106       # upper bound on how many examples are plotted
num_context_frames = 19  # frames fed to the model; the following frame is predicted

# Cap by the dataset size so a smaller validation set does not raise IndexError.
num_examples = min(max_examples, len(val_dataset))

for index in range(num_examples):
    example = val_dataset[index]

    # The first 19 frames are the model input; frame 20 is the ground truth
    frames = example[:num_context_frames, ...]
    ground_truth_frame = example[num_context_frames, ...]

    # The last output step is the model's prediction of the frame after the input
    # window. verbose=0 avoids one progress bar per example.
    new_prediction = loaded_model.predict(np.expand_dims(frames, axis=0), verbose=0)
    predicted_20th_frame = np.squeeze(new_prediction[:, -1, ...])

    fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(10, 5))

    ax_true.imshow(np.squeeze(ground_truth_frame), cmap="gray")
    ax_true.set_title(f"Ground Truth (Frame {num_context_frames + 1}) of index: {index}")
    ax_true.axis("off")

    ax_pred.imshow(np.squeeze(predicted_20th_frame), cmap="gray")
    ax_pred.set_title(f"Predicted (Frame {num_context_frames + 1}) of index: {index}")
    ax_pred.axis("off")

    plt.show()
    # Release the figure explicitly; this loop may create ~100 figures.
    plt.close(fig)
