In [1]:
import argparse
from config import update_settings, get_settings
import numpy as np
import os
from datetime import datetime

# Run configuration for this evaluation notebook. The whole dict is also passed
# to ParaPredNet, and several entries are forwarded to update_settings().
args = dict(
    nt=1,  # number of time steps per sample
    nb_epoch=250,
    batch_size=1,
    output_channels=[3, 6, 12, 24],  # channels per PredNet layer, bottom layer first
    num_P_CNN=1,
    num_R_CLSTM=1,
    num_passes=1,
    pan_hierarchical=False,
    downscale_factor=4,
    resize_images=False,
    train_proportion=0.7,
    results_subdir=str(datetime.now()),  # timestamp string handed to update_settings()
    dataset_weights="various",  # dataset/subset whose saved weights get loaded
    data_subset_weights="CircleV_CrossH",
    dataset="circle_vertical",  # dataset/subset evaluated in this notebook
    data_subset="circle_vertical",
    model_choice="baseline",  # baseline | cl_delta | cl_recon | multi_channel
    system="laptop",
    reserialize_dataset=False,
    output_mode="Error",
)


# Point the project config at this machine, weights set and results folder,
# then read back the resolved directories (only data and weights dirs are used).
update_settings(
    args["system"],
    args["dataset_weights"],
    args["data_subset_weights"],
    args["results_subdir"],
)
DATA_DIR, WEIGHTS_DIR, _, _ = get_settings()["dirs"]

import os
import warnings
import hickle as hkl

# Suppress warnings
warnings.filterwarnings("ignore")
# or '2' to filter out INFO messages too
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
import shutil
import keras
from keras import backend as K
from keras import layers
from data_utils import SequenceGenerator, IntermediateEvaluations, create_dataset_from_serialized_generator, config_gpus 
%matplotlib tk
import matplotlib.pyplot as plt
# import addcopyfighandler
from matplotlib.colors import LinearSegmentedColormap

# PICK MODEL
# Import the ParaPredNet variant selected by args["model_choice"].
if args["model_choice"] == "baseline":
    # Predict next frame along RGB channels only. The pan-hierarchical flag is
    # only reported here; the same baseline class is imported either way.
    from PPN_models.PPN_Baseline import ParaPredNet
    if args["pan_hierarchical"]:
        print("Using Pan-Hierarchical Representation")
elif args["model_choice"] == "cl_delta":
    # Predict next frame and change from current frame
    from PPN_models.PPN_CompLearning_Delta_Predictions import ParaPredNet
elif args["model_choice"] == "cl_recon":
    # Predict current and next frame
    from PPN_models.PPN_CompLearning_Recon_Predictions import ParaPredNet
elif args["model_choice"] == "multi_channel":
    # Predict next frame along Disparity, Material Index, Object Index,
    # Optical Flow, Motion Boundaries, and RGB channels all stacked together.
    # Explicit raise (not assert) so the check survives `python -O`.
    if args["dataset"] not in ("monkaa", "driving"):
        raise ValueError("Multi-channel model only works with Monkaa or Driving dataset")
    from PPN_models.PPN_Multi_Channel import ParaPredNet
    bottom_layer_output_channels = 7  # 1 Disparity, 3 Optical Flow, 3 RGB
    # Mutates the list inside args so later reads of args["output_channels"]
    # see the wider bottom layer.
    args["output_channels"][0] = bottom_layer_output_channels
else:
    raise ValueError(f"Invalid model choice: {args['model_choice']!r}")

# Resolve the weights file that is loaded prior to evaluation.
if (args["dataset_weights"], args["data_subset_weights"]) in [
    ("rolling_square", "single_rolling_square"),
    ("rolling_circle", "single_rolling_circle"),
]:
    # Single-shape weights are named after the evaluated data subset.
    weights_file = os.path.join(WEIGHTS_DIR, f"para_prednet_{args['data_subset']}_weights.hdf5")
elif args["dataset_weights"] in ["all_rolling", "ball_collisions", "various"]:
    # Multi-dataset weights are named "<dataset_weights>_<data_subset_weights>";
    # this also covers the ("all_rolling", "single"/"multi") combinations.
    weights_file = os.path.join(
        WEIGHTS_DIR,
        f"para_prednet_{args['dataset_weights']}_{args['data_subset_weights']}_weights.hdf5",
    )
else:
    # Fallback: weights named after the evaluated dataset.
    weights_file = os.path.join(WEIGHTS_DIR, f"para_prednet_{args['dataset']}_weights.hdf5")

# Explicit raise instead of assert so the check still runs under `python -O`.
if not os.path.exists(weights_file):
    raise FileNotFoundError(f"Weights file not found: {weights_file}")

if args["dataset"] != args["dataset_weights"]:
    print(f"WARNING: dataset ({args['dataset']}) and dataset_weights ({args['dataset_weights']}/{args['data_subset_weights']}) do not match - generalizing...")
else:
    print(f"OK: dataset ({args['dataset']}) and dataset_weights ({args['dataset_weights']}/{args['data_subset_weights']}) match")

# Training parameters pulled out of the run configuration.
nt = args["nt"]  # number of time steps
batch_size = args["batch_size"]
output_channels = args["output_channels"]

# Define image shape. KITTI keeps its native size and never reads
# downscale_factor. Every other dataset defines downscale_factor: monkaa/driving
# always apply it, the remaining datasets only when resize_images is set.
if args["dataset"] == "kitti":
    original_im_shape = (128, 160, 3)
    im_shape = original_im_shape
else:
    if args["dataset"] in ("monkaa", "driving"):
        original_im_shape = (540, 960, 3)
    elif args["dataset"] in ("rolling_square", "rolling_circle"):
        original_im_shape = (50, 100, 3)
    else:
        original_im_shape = (50, 50, 3)
    downscale_factor = args["downscale_factor"]
    if args["dataset"] in ("monkaa", "driving") or args["resize_images"]:
        im_shape = (
            original_im_shape[0] // downscale_factor,
            original_im_shape[1] // downscale_factor,
            3,
        )
    else:
        im_shape = original_im_shape

print(f"Working on dataset: {args['dataset']}")

# Create ParaPredNet and wrap it in a functional keras.Model.
# Each Input has shape (nt, height, width, channels); only the number of input
# streams and their channel counts differ between datasets.
if args["dataset"] == "kitti":
    # KITTI: a single RGB stream.
    inputs = keras.Input(shape=(nt, im_shape[0], im_shape[1], 3))
elif args["dataset"] == "monkaa":
    # Monkaa: six stacked streams with these channel counts.
    inputs = tuple(
        keras.Input(shape=(nt, im_shape[0], im_shape[1], channels))
        for channels in (1, 1, 1, 3, 1, 3)
    )
elif args["dataset"] == "driving":
    # Driving: three stacked streams with these channel counts.
    inputs = tuple(
        keras.Input(shape=(nt, im_shape[0], im_shape[1], channels))
        for channels in (1, 3, 3)
    )
else:
    # Every other dataset (e.g. rolling_square, circle_vertical): one RGB stream.
    inputs = keras.Input(shape=(nt, im_shape[0], im_shape[1], 3))

PPN_layer = ParaPredNet(args, im_height=im_shape[0], im_width=im_shape[1])
if args["dataset"] not in ("kitti", "monkaa", "driving"):
    # Single-stream datasets are evaluated frame by frame in the plotting cell,
    # so the layer is switched to prediction output with continuous evaluation.
    PPN_layer.output_mode = "Prediction"
    PPN_layer.continuous_eval = True
outputs = PPN_layer(inputs)
PPN = keras.Model(inputs=inputs, outputs=outputs)

resos = PPN.layers[-1].resolutions
PPN.compile(optimizer="adam", loss="mean_squared_error")
print("ParaPredNet compiled...")
# NOTE(review): PPN is already built from `inputs`; this call is likely
# redundant - confirm before removing.
PPN.build(input_shape=(None, nt) + im_shape)
# Model.summary() prints the table itself and returns None, so it is not wrapped
# in print() (which produced a stray "None" line).
PPN.summary()
num_layers = len(output_channels)  # number of layers in the architecture
print(f"{num_layers} PredNet layers with resolutions:")
for i in reversed(range(num_layers)):
    print(f"Layer {i+1}:  {resos[i][0]} x {resos[i][1]} x {output_channels[i]}")

# Load previously saved weights. Any failure is reported as a ValueError, with
# the original exception chained so the real cause stays visible. Catching
# Exception (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
try:
    PPN.load_weights(weights_file)
except Exception as err:
    raise ValueError("Weights don't fit - exiting...") from err
print("Weights loaded successfully...")

# Load dataset - only working for animations.
# NOTE(review): this reads the "_train" split and keeps element [0] of what
# hickle returns; confirm that this is the intended evaluation data.
test_data = hkl.load(
    DATA_DIR + f"/{args['data_subset']}_train.hkl"
)[0]
td_len = test_data.shape[0]  # number of frames stepped through by the plotting cell

Working on dataset: circle_vertical
ParaPredNet compiled...
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 1, 50, 50, 3)]    0         
                                                                 
 para_pred_net (ParaPredNet  (1, 1, 50, 50, 3)         110184    
 )                                                               
                                                                 
Total params: 110184 (430.41 KB)
Trainable params: 110184 (430.41 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
None
4 PredNet layers with resolutions:
Layer 4:  6 x 6 x 24
Layer 3:  12 x 12 x 12
Layer 2:  25 x 25 x 6
Layer 1:  50 x 50 x 3
Weights loaded successfully...


In [3]:
# Step through the loaded frames one at a time and show ground truth,
# prediction and the channel-averaged signed error next to each other.

# manually initialize PPN layer states
PPN.layers[-1].init_layer_states()

fig, axs = plt.subplots(1, 3, figsize=(20, 4))
plt.show(block=False)
# Diverging map: low values red, midpoint black, high values green. Combined
# with the symmetric vmin/vmax below, zero error is drawn black.
rg_colormap = LinearSegmentedColormap.from_list('custom_cmap', [(0, 'red'), (0.5, 'black'), (1, 'green')])

for i in range(td_len):
    print(f"Iteration {i+1}/{td_len}")
    # Add batch and time axes: (1, 1, *frame_shape).
    ground_truth_image = np.reshape(test_data[i], (1, 1, *test_data.shape[1:]))
    # Convert the layer output to NumPy once so the arithmetic and plotting
    # below all operate on plain arrays.
    predicted_image = np.asarray(PPN.layers[-1](ground_truth_image))
    error_image = ground_truth_image - predicted_image
    error_image_grey = np.mean(error_image, axis=-1, keepdims=True)
    mse = np.mean(error_image**2)
    # imshow autoscales to [min, max] by default, which puts the black midpoint
    # at an arbitrary value. Symmetric limits pin it to zero error; fall back to
    # 1.0 when the error is exactly zero everywhere.
    err_lim = float(np.max(np.abs(error_image_grey))) or 1.0

    # clear the axes
    for ax in axs:
        ax.cla()

    # draw the three panels side-by-side
    axs[0].imshow(ground_truth_image[0, 0, ...])
    axs[1].imshow(predicted_image[0, 0, ...])
    axs[2].imshow(error_image_grey[0, 0, ...], cmap=rg_colormap, vmin=-err_lim, vmax=err_lim)

    # add titles
    axs[0].set_title("Ground Truth")
    axs[1].set_title("Predicted")
    axs[2].set_title(f"Error, MSE: {mse:.3f}")
    fig.suptitle(f"Frame {i+1}/{td_len}")

    fig.canvas.draw()
    fig.canvas.flush_events()

    # Optional interactive controls (disabled):
    # enable click-through plotting
    # plt.show(block=True)

    # Wait for user input to continue or close the current plot
    # user_input = input("Press enter to continue or type 'close' to close the plot and stop: ")
    # if user_input.lower() == 'close':
    #     plt.close()
    #     break

    # delay n seconds
    # plt.pause(10)

Iteration 1/500
Iteration 2/500
Iteration 3/500
Iteration 4/500
Iteration 5/500
Iteration 6/500
Iteration 7/500
Iteration 8/500
Iteration 9/500
Iteration 10/500
Iteration 11/500
Iteration 12/500
Iteration 13/500
Iteration 14/500
Iteration 15/500
Iteration 16/500
Iteration 17/500
Iteration 18/500
Iteration 19/500
Iteration 20/500
Iteration 21/500
Iteration 22/500
Iteration 23/500
Iteration 24/500
Iteration 25/500
Iteration 26/500
Iteration 27/500
Iteration 28/500
Iteration 29/500
Iteration 30/500
Iteration 31/500
Iteration 32/500
Iteration 33/500
Iteration 34/500
Iteration 35/500
Iteration 36/500
Iteration 37/500
Iteration 38/500
Iteration 39/500
Iteration 40/500
Iteration 41/500
Iteration 42/500
Iteration 43/500
Iteration 44/500
Iteration 45/500
Iteration 46/500
Iteration 47/500
Iteration 48/500
Iteration 49/500
Iteration 50/500
Iteration 51/500
Iteration 52/500
Iteration 53/500
Iteration 54/500
Iteration 55/500
Iteration 56/500
Iteration 57/500
Iteration 58/500
Iteration 59/500
Iterat