In [1]:
import pytorch_lightning as pl
from models import *
from data.data_classes import *

In [2]:
# Output path of the figure saved by make_thesis_plot.
save_filepath = "./plots/live_plot2.pdf"

In [3]:
# Loader settings: batch size, dataset file and context/target frame counts.
batch_size = 8
data_path = "data/harness/combined.npy"
num_ctx_frames = 5
num_tgt_frames = 5
# presumably [train, val, test] fractions, which would place every sample in
# the validation split — TODO confirm against LivenessDataModule
split_ratio = [0, 1.0, 0.0]

liveness_datamodule = LivenessDataModule(
    batch_size,
    num_ctx_frames,
    num_tgt_frames,
    data_path,
    split_ratio=split_ratio,
)
liveness_datamodule.setup()

# Take the first batch of the validation loader; it unpacks into the context
# frames and the target frames.
val_dl = liveness_datamodule.val_dataloader()
val_ctx_frames, val_tgt_frames = next(iter(val_dl))

### SimVP

In [4]:
# # Experiment 1 (filtered and augmented)
# hid_s=64
# hid_t=256
# N_s=4
# N_t=8
# kernel_sizes=[3,5,7,11]
# groups=4

# channels = 3
# height = 280
# width = 160
# input_shape = (channels, num_ctx_frames, height, width)

# model = SimVP(input_shape=input_shape, 
#               hid_s=hid_s, hid_t=hid_t, 
#               N_s=N_s, N_t=N_t,
#               kernel_sizes=kernel_sizes, 
#               groups=groups)
# model = model.load_from_checkpoint("./logs/SimVP/experiment1/checkpoints/epoch=99-step=2000.ckpt")              
            
# # Experiment 2 (unfiltered)
# model = SimVP(input_shape=input_shape, 
#               hid_s=hid_s, hid_t=hid_t, 
#               N_s=N_s, N_t=N_t,
#               kernel_sizes=kernel_sizes, 
#               groups=groups)
# model = model.load_from_checkpoint("./logs/SimVP/experiment2/checkpoints/epoch=99-step=2000.ckpt")  

# # Experiment 3 (filtered)
# model = SimVP(input_shape=input_shape, 
#               hid_s=hid_s, hid_t=hid_t, 
#               N_s=N_s, N_t=N_t,
#               kernel_sizes=kernel_sizes, 
#               groups=groups)
# model = model.load_from_checkpoint("./logs/SimVP/experiment3/checkpoints/epoch=99-step=500.ckpt")

### PredRNN

### Predicted Frames Set 1

In [5]:
# # Experiment 2 (unfiltered)
# hid_s=64
# hid_t=256
# N_s=4
# N_t=8
# kernel_sizes=[3,5,7,11]
# groups=4

# channels = 3
# height = 280
# width = 160
# input_shape = (channels, num_ctx_frames, height, width)

# model = SimVP(input_shape=input_shape, 
#               hid_s=hid_s, hid_t=hid_t, 
#               N_s=N_s, N_t=N_t,
#               kernel_sizes=kernel_sizes, 
#               groups=groups)
# model = model.load_from_checkpoint("./logs/SimVP/experiment2/checkpoints/epoch=99-step=2000.ckpt")  

In [6]:
# model.eval()
# model1_pred_frames = model(val_ctx_frames)

# torch.save(model1_pred_frames, "./plots/5-5-simvp-uf.pt")
# model1_pred_frames.shape

In [7]:
# Load the stored predictions for frame set 1 from disk rather than running a model here.
model1_pred_frames = torch.load("./plots/5-5-simvp-uf.pt")

### Predicted Frames Set 2

In [8]:
# # Experiment 3 (filtered)
# hid_s=64
# hid_t=256
# N_s=4
# N_t=8
# kernel_sizes=[3,5,7,11]
# groups=4

# channels = 3
# height = 280
# width = 160
# input_shape = (channels, num_ctx_frames, height, width)

# model = SimVP(input_shape=input_shape, 
#               hid_s=hid_s, hid_t=hid_t, 
#               N_s=N_s, N_t=N_t,
#               kernel_sizes=kernel_sizes, 
#               groups=groups)
# model = model.load_from_checkpoint("./logs/SimVP/experiment3/checkpoints/epoch=99-step=500.ckpt")

In [9]:
# model.eval()
# model2_pred_frames = model(val_ctx_frames)
# torch.save(model2_pred_frames, "./plots/5-5-simvp-ft.pt")
# model2_pred_frames.shape

In [10]:
# Load the stored predictions for frame set 2 from disk rather than running a model here.
model2_pred_frames = torch.load("./plots/5-5-simvp-ft.pt")
# Show the tensor shape as the cell output
# (presumably batch, channels, frames, height, width — TODO confirm).
model2_pred_frames.shape

torch.Size([8, 3, 5, 280, 160])

### Predicted Frames Set 3

In [11]:
# # Experiment 1 (filtered and augmented)
# model = SimVP(input_shape=input_shape, 
#               hid_s=hid_s, hid_t=hid_t, 
#               N_s=N_s, N_t=N_t,
#               kernel_sizes=kernel_sizes, 
#               groups=groups)
# model = model.load_from_checkpoint("./logs/SimVP/experiment1/checkpoints/epoch=99-step=2000.ckpt")   

In [12]:
# model.eval()
# model3_pred_frames = model(val_ctx_frames)

# torch.save(model3_pred_frames, "./plots/5-5-simvp-fa.pt")
# model3_pred_frames.shape

In [13]:
# Load the stored predictions for frame set 3 from disk rather than running a model here.
model3_pred_frames = torch.load("./plots/5-5-simvp-fa.pt")

### Plot figure

In [14]:
def make_thesis_plot(frame_sets, names, plot_width, plot_height, save_path=None):
    """Draw one row of frames per frame set and save the figure to disk.

    Args:
        frame_sets: Sequence of torch tensors, each laid out as
            (channels, frames, height, width). Tensors with more than four
            dimensions are squeezed first, so a leading singleton batch
            dimension is accepted as well.
        names: Row labels, one per entry of ``frame_sets``. A row labelled
            exactly "Context" is numbered from t=1; any other row from t=6.
        plot_width: Figure width in inches.
        plot_height: Figure height in inches.
        save_path: File the figure is written to. Defaults to the
            notebook-level ``save_filepath`` when omitted.

    Returns:
        The matplotlib figure that was created and saved.

    Raises:
        ValueError: If ``frame_sets`` is empty or contains no frames.
    """

    def show_frames(frames, ax, row_label=None):
        # Rows other than "Context" start at t=6, which matches predictions
        # that follow five context frames.
        start_id = 1 if row_label == "Context" else 6
        for i, frame in enumerate(frames):
            ax[i].imshow(frame)
            ax[i].set_xticks([])
            ax[i].set_yticks([])
            ax[i].set_xlabel(f"t={start_id+i}")
            # Move the time label above the image instead of below it.
            ax[i].xaxis.set_label_coords(.52, 1.15)

        # Rows shorter than the widest row leave panels unused; hide them.
        for unused in ax[len(frames):]:
            unused.axis("off")

        if row_label is not None:
            ax[0].set_ylabel(row_label, wrap=True)

    if len(frame_sets) == 0:
        raise ValueError("frame_sets must contain at least one set of frames")

    # Convert every set to a (frames, height, width, channels) numpy array up
    # front so the grid width can follow the data instead of a fixed 5.
    frame_arrays = []
    for frames in frame_sets:
        # Only squeeze when extra (e.g. batch) dimensions are present, so a
        # legitimate size-1 frame axis on a 4-D tensor is preserved.
        if frames.ndim > 4:
            frames = frames.squeeze()
        frame_arrays.append(frames.permute(1, 2, 3, 0).cpu().detach().numpy())

    num_cols = max(len(frames) for frames in frame_arrays)
    if num_cols == 0:
        raise ValueError("frame_sets must contain at least one frame")

    # squeeze=False keeps `ax` two-dimensional even for one row or one column,
    # so ax[i][j] indexing always works.
    fig, ax = plt.subplots(len(frame_arrays), num_cols,
                           figsize=(plot_width, plot_height),
                           squeeze=False)

    for i, frames in enumerate(frame_arrays):
        show_frames(frames, ax[i], names[i])
    fig.set_facecolor("white")
    fig.savefig(save_filepath if save_path is None else save_path)
    return fig

In [None]:
# Sample position within the batch to plot, and the figure size handed to
# make_thesis_plot.
index = 1
plot_width = 6
plot_height = 6

# One row per model: the same sample taken from each prediction tensor.
frame_sets = [
    pred_frames[index]
    for pred_frames in (model1_pred_frames, model2_pred_frames, model3_pred_frames)
]
names = [
    "SimVP-\nUnfiltered",
    "SimVP-\nFiltered",
    "SimVP-\nFiltered and Augmented",
]
make_thesis_plot(frame_sets, names, plot_width, plot_height)