# Imports

In [None]:
import os
# Run from the `reg` directory (presumably so the local `dataset` / `utils` imports and the
# relative checkpoint paths below resolve). Assumes `reg` is a sibling of the current directory.
# os.path.basename is used instead of splitting on "/" so this also works with Windows paths.
if os.path.basename(os.getcwd()) != "reg":
    os.chdir("../reg")

In [None]:
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch

from dataset import LungDataModule
from utils import load_best_model

plt.rcParams["animation.html"] = "jshtml"
plt.rcParams["animation.embed_limit"] = 2048
plt.rcParams['figure.dpi'] = 150
%matplotlib inline

# Load models

In [None]:
# Checkpoint directories of the runs to compare. The two names differ only in the
# `mean` vs `last` tag; every other hyper-parameter in the name is shared.
ckpt_dirs = [
    "model_weights_v2/transmorph-gmi=1-gl2d=1-adam-0.0001-mean-100-32-norm",  # `mean` variant
    "model_weights_v2/transmorph-gmi=1-gl2d=1-adam-0.0001-last-100-32-norm",  # `last` variant
]

In [None]:
# Each checkpoint directory yields a (model, data_mod) pair; keep the two parts separately.
loaded = [load_best_model(ckpt_dir) for ckpt_dir in ckpt_dirs]
models = [model for model, _ in loaded]
data_mods = [data_mod for _, data_mod in loaded]

# All models are run on one shared test loader below (built with `mod=data_mod`), so they must
# agree on that value. Fail early and informatively instead of with a bare AssertionError or IndexError.
assert data_mods, "ckpt_dirs is empty: no checkpoints to visualise"
assert all(data_mods[0] == x for x in data_mods), f"checkpoints disagree on data_mod: {data_mods}"
data_mod = data_mods[0]

# Run predictions

In [None]:
# Cap the worker count at the number of available CPUs (os.cpu_count() may return None).
# `num_workers` is presumably forwarded to torch's DataLoader, which warns when asked for more
# workers than the machine suggests; extra workers bring no speed-up there.
num_workers = min(12, os.cpu_count() or 1)
data_module = LungDataModule(batch_size=1, num_workers=num_workers, pin_memory=True, mod=data_mod)
data_module.setup()
data_loader = data_module.test_dataloader()

In [None]:
# Let float32 matmuls use faster reduced-precision internal kernels where the hardware supports them.
torch.set_float32_matmul_precision("high")
# Default Lightning trainer; used below only to drive the test/predict loops.
trainer = pl.Trainer()

In [None]:
# Computing test metrics is optional. Set this to True to run trainer.test for every model.
# (Replaces the `%%script echo` cell-disabling trick, which needs an `echo` executable
# and is easy to miss when reading the notebook.)
RUN_TEST_METRICS = False
if RUN_TEST_METRICS:
    results = [trainer.test(model, data_loader) for model in models]

In [None]:
# Run inference with every model over the whole test set. This can be expensive:
# all models' outputs for every test batch are kept in memory in `predictions`.
predictions = []
for model in models:
    predictions.append(trainer.predict(model, data_loader))

# Create visualisation

In [None]:
def transform(warped_series, flow_series, disp_series, fixed_image, dim=(3, 1, 2, 0)):
    """Turn one sample's model outputs into CPU tensors ready for imshow.

    Args:
        warped_series: batched warped series; element 0 is used and permuted with ``dim``
            (with the default, a 4-D (a, b, c, d) tensor becomes (d, b, c, a); the layout
            is intended to be (t, w, h, c), i.e. frames first).
        flow_series: batched flow series; element 0 is permuted with ``dim`` like the warped
            series, and channels 0 and 1 of the result are rendered.
        disp_series: returned unchanged.
        fixed_image: fixed image with the layout of a single warped frame, e.g. (w, h, c);
            it is detached and moved to CPU so it can be compared with the warped frames
            and plotted.
        dim: permutation applied to ``warped_series[0]`` and ``flow_series[0]``. The default
            equals the notebook-level ``dim``, so the function no longer relies on a global
            defined in a later cell.

    Returns:
        ``(fixed_image, warped_series, flow_rgb, disp_series, abs_diff_series)``. This is the
        5-tuple layout the plotting code unpacks as ``f, ws, fs, _, ds``.
    """
    # Re-order channels of outputs so frames come first: (t, w, h, c).
    warped_series = warped_series[0].detach().cpu().permute(dim)
    flow_series = flow_series[0].detach().cpu().permute(dim)
    fixed_image = fixed_image.detach().cpu()

    # Per-frame |warped - fixed|. Broadcasting over the leading frame axis gives the same result
    # as stacking the difference of each frame with the fixed image.
    abs_diff_series = torch.abs(warped_series - fixed_image)

    # Colour-code the flow field: squash to (-1, 1) with tanh, shift to [0, 1], use the first
    # two components as red/green and leave blue empty.
    flow_series = torch.tanh(flow_series)
    flows_x = (flow_series[..., 0] + 1) / 2
    flows_y = (flow_series[..., 1] + 1) / 2
    flow_rgb = torch.stack([flows_x, flows_y, torch.zeros_like(flows_x)], dim=-1)

    return fixed_image, warped_series, flow_rgb, disp_series, abs_diff_series

In [None]:
# Positions (in the test loader / in each model's predictions) of the samples to visualise.
nums = [6]

# Model outputs for the chosen samples: one list per model, ordered like `nums`.
subsets_warped = [[pred[i] for i in nums] for pred in predictions]

# Collect the matching input batches. They are stored by index and then re-ordered like `nums`,
# so that subset_moving[j] lines up with subsets_warped[k][j] even when `nums` is unsorted or
# contains repeats. Iteration stops once the last requested batch has been read, instead of
# loading the whole test set.
# NOTE(review): this assumes the test loader yields batches in the same order as during
# trainer.predict (i.e. no shuffling) — confirm in LungDataModule.
wanted = set(nums)
batches_by_index = {}
if nums:
    last_wanted = max(nums)
    for i, batch in enumerate(data_loader):
        if i in wanted:
            batches_by_index[i] = batch
        if i >= last_wanted:
            break
subset_moving = [batches_by_index[i] for i in nums]

In [None]:
show_title = True
# Permutation for the per-sample tensors: a 4-D (a, b, c, d) tensor becomes (d, b, c, a),
# presumably moving the time axis first — TODO confirm against the dataset layout.
dim = (3, 1, 2, 0)
data = []

# For every selected sample, pair its moving series with the plot-ready outputs of each model.
for sample_idx, moving in enumerate(subset_moving):
    moving_p = moving[0].detach().cpu().permute(dim)
    transformed = []
    for k, subset_warped in enumerate(subsets_warped):
        sample_out = subset_warped[sample_idx]
        warped, flow, disp = sample_out[0], sample_out[1], sample_out[2]
        fixed = models[k].fixed_image(moving)[0].permute((1, 2, 0))
        transformed.append(transform(warped, flow, disp, fixed))
    data.append((moving_p, transformed))

In [None]:
num_cols = 5
num_rows = len(nums) * len(predictions)
fig, axs = plt.subplots(ncols=num_cols, nrows=num_rows, figsize=(2 * num_cols, 2 * num_rows))
axs = axs.flatten()
images = []

# Column headers: moving m, fixed f, warped m o phi, |(m o phi) - f|, deformation phi.
column_titles = (
    r"$\mathit{m}$",
    r"$\mathit{f}$",
    r"$\mathit{m \circ \phi}$",
    r"$\mathit{\left| \; (m \circ \phi) - f \; \right|}$",
    r"$\mathit{\phi}$",
)
# Each sample occupies one row of `num_cols` panels per model.
axes_per_sample = num_cols * len(predictions)

for i in range(len(nums)):
    row_offset = i * axes_per_sample

    # Titles are only drawn on the very first row of the grid.
    if show_title and i == 0:
        for col, title in enumerate(column_titles):
            axs[row_offset + col].set(title=title)

    m, t = data[i]

    # Strip ticks and tick labels from every panel belonging to this sample.
    for ax in axs[row_offset:row_offset + axes_per_sample]:
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], aspect="equal")

    # One row per model. Artists are appended in column order 0..4, so animate() finds them
    # at index row_offset + k * num_cols + column. Column 1 (fixed image) is static.
    for k in range(len(predictions)):
        f, ws, fs, _, ds = t[k]
        base = row_offset + k * num_cols
        images.append(axs[base + 0].imshow(m[0], cmap="gray", animated=True))
        images.append(axs[base + 1].imshow(f, cmap="gray"))
        images.append(axs[base + 2].imshow(ws[0], cmap="gray", animated=True))
        images.append(axs[base + 3].imshow(ds[0], cmap="gray", animated=True))
        images.append(axs[base + 4].imshow(fs[0], cmap="gray", animated=True))

def animate(delta):
    """Show frame ``delta`` in every animated panel (FuncAnimation callback).

    Updates the moving (column 0), warped (2), difference (3) and flow (4) panels of every
    sample/model row. The fixed-image panel (column 1) is static and left untouched.
    Returns the full ``images`` list so blitting redraws all artists.
    """
    axes_per_sample = num_cols * len(predictions)
    for sample_idx in range(len(nums)):
        moving_frames, per_model = data[sample_idx]
        for model_idx in range(len(predictions)):
            _, warped, flow, _, diff = per_model[model_idx]
            base = sample_idx * axes_per_sample + model_idx * num_cols
            for column, series in ((0, moving_frames), (2, warped), (3, diff), (4, flow)):
                images[base + column].set_data(series[delta])

    return images

# Animate at most 200 frames, but never more than the shortest moving series holds, so animate()
# cannot index past the end. NOTE(review): assumes the warped/diff/flow series have at least as
# many frames as the moving series — confirm.
max_frames = 200
num_frames = min([max_frames] + [len(moving_series) for moving_series, _ in data])
anim = animation.FuncAnimation(fig, animate, frames=num_frames, blit=True)
# Close the figure so the inline backend does not render an extra static copy below the
# animation. The animation still draws into the figure when its jshtml repr is produced.
plt.close(fig)
anim