# Get autoregressive predictions from a pretrained checkpoint

In [None]:
# Move up one directory so the relative paths used below ("checkpoints/...",
# "data/...") resolve. Presumably the notebook lives one level below the
# project root — confirm.
%cd ..

In [None]:
from maskpredformer.scheduled_sampling_trainer import MaskSimVPScheduledSamplingModule
from maskpredformer.trainer import MaskSimVPModule
from maskpredformer.simvp_dataset import DLDataset
from maskpredformer.vis_utils import show_gif, show_video_line

import random
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Scheduled-sampling model under evaluation. The extra keyword arguments are
# forwarded by load_from_checkpoint to the module's constructor (overriding
# saved hyperparameters of the same name).
ckpt_path = "checkpoints/method=SS_simvp=simvp_epoch=16-val_loss=0.014.ckpt_inc_every_n_epoch=10_max_sample_steps=5_schedule_k=2_unlabeled=False/simvp_ss_epoch=10-valid_last_frame_iou=0.461.ckpt"
# .eval() (which returns the module itself) is required for inference:
# @torch.no_grad() on the helpers below only disables autograd; it does not
# switch dropout / batch-norm layers out of training behaviour.
module = MaskSimVPScheduledSamplingModule.load_from_checkpoint(
    ckpt_path, use_gt_data=True, unlabeled=False
).eval()

In [None]:
# Labelled validation split. pre_seq_len / aft_seq_len are both 11
# (presumably 11 context frames and 11 future frames per sample — confirm in
# DLDataset). get_result indexes samples from this dataset.
val_set = DLDataset(
    "data/DL",
    "val",
    pre_seq_len=11,
    aft_seq_len=11,
    unlabeled=False,
)

In [None]:
# Baseline checkpoint, rolled out by get_result_old. This cell overwrites the
# ckpt_path variable set in the earlier loading cell.
ckpt_path = "./checkpoints/simvp_epoch=16-val_loss=0.014.ckpt"
# .eval() (which returns the module itself) disables training-time behaviour
# such as dropout; torch.no_grad() alone does not.
old_module = MaskSimVPModule.load_from_checkpoint(ckpt_path, unlabeled=False).eval()

In [None]:
@torch.no_grad()
def get_result(i):
    """Autoregressively predict sample ``i`` of ``val_set`` with ``module``.

    Args:
        i: Index into the global ``val_set`` dataset.

    Returns:
        Tuple ``(x, y, y_hat)`` of NumPy arrays: the dataset's ``(x, y)`` pair
        for sample ``i`` (unchanged shapes), and the sequence returned by
        ``module.sample_autoregressive`` with its leading batch dimension
        removed.
    """
    x, y = val_set[i]
    # Only the model input needs to live on the model's device. x and y are
    # converted straight to NumPy from the dataset tensors, avoiding a
    # pointless host -> device -> host round trip (y never enters the model).
    batch = x.unsqueeze(0).to(module.device)
    # NOTE(review): the meaning of the second argument (11) is defined by
    # sample_autoregressive; presumably the number of frames to generate,
    # matching aft_seq_len=11 of val_set — confirm.
    cur_seq = module.sample_autoregressive(batch, 11)
    return x.cpu().numpy(), y.cpu().numpy(), cur_seq.squeeze(0).cpu().numpy()

In [None]:
@torch.no_grad()
def get_result_old(i):
    """Predict sample ``i`` of ``val_set`` by rolling out the baseline ``old_module``.

    The rollout runs 11 steps. Each step:

    1. feeds the current window to ``old_module.step``;
    2. takes the argmax of the returned logits over dim 2;
    3. appends the result along dim 1 after dropping the window's first
       entry on dim 1.

    The window therefore consists entirely of predictions only if ``x`` has
    11 entries on dim 1 and ``step`` returns one per call (TODO confirm).

    Args:
        i: Index into the global ``val_set`` dataset. This is the same dataset
            ``get_result`` reads, so the old and new models are compared on
            identical samples.

    Returns:
        Tuple ``(x, y, y_hat)`` of NumPy arrays: the dataset's ``(x, y)`` pair
        for sample ``i`` and the final rollout window with its leading batch
        dimension removed.
    """
    x, y = val_set[i]
    # Only the model input goes to the device. Clone so nothing done to the
    # window can alias the dataset's tensor.
    cur_seq = x.unsqueeze(0).to(old_module.device).clone()
    for _ in range(11):
        y_hat_logits = old_module.step(cur_seq, None)
        # Logits -> label ids (assumes classes are on dim 2, i.e.
        # (batch, time, classes, H, W) — TODO confirm).
        y_hat = torch.argmax(y_hat_logits, dim=2)
        # Slide the window: drop the first (presumably oldest) entry,
        # append the new prediction.
        cur_seq = torch.cat([cur_seq[:, 1:], y_hat], dim=1)
    return x.cpu().numpy(), y.cpu().numpy(), cur_seq.squeeze(0).cpu().numpy()

In [None]:
def get_all_results(old=False):
    """Collect predictions for every sample of ``val_set``.

    Args:
        old: If True, use the baseline rollout ``get_result_old``; otherwise
            use the scheduled-sampling model via ``get_result``.

    Returns:
        NumPy array stacking each sample's ``y_hat`` along a new leading axis,
        in dataset index order.
    """
    predict = get_result_old if old else get_result
    # Bound the loop by the dataset get_result actually indexes (the global
    # val_set), not module.val_set, so the length cannot disagree with the
    # dataset being indexed.
    all_yhat = [predict(i)[2] for i in tqdm(range(len(val_set)))]
    return np.stack(all_yhat)

In [None]:
# Predictions of the scheduled-sampling model (old=False is the default) for
# every validation sample, stacked along a new leading axis.
all_yhat = get_all_results()


In [None]:
# Ground-truth masks of videos video_1000 ... video_1999, stacked along a new
# leading axis in that order (row k is video_{1000 + k}).
# NOTE(review): this lines up with all_yhat[k] only if val_set enumerates the
# same videos in the same order — confirm.
root_val_dir = "data/val_gt/"
gt_masks = np.stack(
    [
        np.load(os.path.join(root_val_dir, f"video_{i}", "mask.npy"))
        for i in range(1000, 2000)
    ]
)

# Visualization

In [None]:
def get_gif(i, out_path='./result.gif'):
    """Render sample ``i`` of ``val_set`` (predicted by ``get_result``) as a GIF.

    Args:
        i: Dataset index passed to ``get_result``.
        out_path: Destination path handed to ``show_gif``. It defaults to the
            previously hard-coded ``'./result.gif'``, so existing calls are
            unchanged.
    """
    x, y, y_hat = get_result(i)
    show_gif(x, y, y_hat, out_path=out_path)


get_gif(50)

In [None]:
# Full ground-truth mask sequence of row 50 of gt_masks (video_1050).
# NOTE(review): 22 is forwarded to show_video_line — presumably the number of
# frames to draw (11 context + 11 future); confirm against vis_utils.
show_video_line(gt_masks[50, ...], 22)

In [None]:
# First 11 ground-truth frames of gt_masks row 50, followed by the predicted
# sequence for dataset sample 50. They are joined on the leading axis
# (np.concatenate's default axis=0; presumably time).
show_video_line(np.concatenate([gt_masks[50][:11], all_yhat[50]]), 22)

In [None]:
# NOTE(review): stale snippet — `all_x` and `vis_idx` are not defined in any of
# the cells shown here, so uncommenting it as-is would raise NameError.
# show_video_line(np.concatenate([all_x[vis_idx], all_yhat[vis_idx]], axis=0), 22)

# IoU calculation

In [None]:
from torchmetrics import JaccardIndex

In [None]:
# Multiclass Jaccard index (IoU) over 49 classes (presumably the number of
# mask labels in the dataset — confirm against DLDataset).
jaccard = JaccardIndex(num_classes=49, task='multiclass')

In [None]:
# Last-frame IoU across all rows: the final predicted frame vs. the final
# ground-truth frame. Predictions are passed first, targets second.
# NOTE(review): assumes all_yhat[k] and gt_masks[k] describe the same video.
jaccard(
    torch.from_numpy(all_yhat[:, -1]),
    torch.from_numpy(gt_masks[:, -1]),
)