# Imports

In [None]:
import sys
sys.path.append('../')

In [None]:
import os

from pathlib2 import Path
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch

from reg.data import LungDataModule
from reg.model import TransMorphModuleBuilder

plt.rcParams["animation.html"] = "jshtml"
plt.rcParams["animation.embed_limit"] = 2048
plt.rcParams['figure.dpi'] = 150
%matplotlib inline

# Load models

In [None]:
def _parse_ckpt_name(name):
    """Return ``(val_loss, epoch)`` strings parsed from a checkpoint file name.

    Expects names like ``val_loss=<number>&epoch=<epoch>.ckpt``. Returns ``None``
    for anything else (e.g. ``last.ckpt``) or when the loss field is not numeric.
    """
    try:
        val_loss_part, epoch_part = name.split("&")
        val_loss = val_loss_part.split("=", 1)[1]
        epoch = epoch_part.split("=", 1)[1].split(".")[0]
        float(val_loss)  # reject names whose loss field is not a number
    except (ValueError, IndexError):
        return None
    return val_loss, epoch


def load_best_model(ckpt_dir):
    """Load the checkpoint with the lowest validation loss found in ``ckpt_dir``.

    Only ``*.ckpt`` files named ``val_loss=<number>&epoch=<epoch>...ckpt`` are
    considered; other checkpoint files are ignored. The selected checkpoint's
    val_loss, epoch and build config are printed.

    Args:
        ckpt_dir: Directory (str or path-like) containing the checkpoint files.

    Returns:
        The model built by ``TransMorphModuleBuilder`` from the best checkpoint.

    Raises:
        FileNotFoundError: If ``ckpt_dir`` contains no parseable checkpoint.
    """
    ckpt_dir = Path(ckpt_dir)

    candidates = []
    for ckpt_path in ckpt_dir.glob("*.ckpt"):
        fields = _parse_ckpt_name(ckpt_path.name)
        if fields is not None:
            candidates.append((float(fields[0]), ckpt_path.name, ckpt_path, fields))
    if not candidates:
        raise FileNotFoundError(f"No 'val_loss=...&epoch=...' checkpoint found in {ckpt_dir}")

    # Lowest validation loss wins. Compare numerically: a reverse lexicographic
    # sort of file names picks the *largest* positive loss and mis-orders values
    # with different digit counts (e.g. "-10.0" vs "-9.0"). Ties break on name.
    _, _, best_ckpt, (val_loss, epoch) = min(candidates, key=lambda c: (c[0], c[1]))

    model, config = TransMorphModuleBuilder.from_ckpt(str(best_ckpt), True).build()

    print("=" * 80)
    print("Best performing model")
    print("")
    print(f"  val_loss                  = {val_loss}")
    print(f"  epoch                     = {epoch}")
    print("")
    for k, v in config.items():
        print(f"  {k:<25} = {v}")
    print("")
    print("=" * 80)

    return model

In [None]:
# Checkpoint directories to compare. Each directory name encodes the training
# configuration as dot-separated segments; Python joins the adjacent string
# literals below into a single path.
ckpt_dirs = [
    (
        "../model_weights_v3/"
        "network_transmorph-tiny."
        "criteria-warped_gmi-1."
        "criteria-flow_gl2d-1."
        "reg-strategy_soreg."
        "reg-target_last."
        "reg-depth_32."
        "ident-loss_False."
        "optimizer_adam."
        "learning-rate_1E-04/"
    ),
]

In [None]:
# One model per checkpoint directory, each chosen by `load_best_model`.
loaded = list(map(load_best_model, ckpt_dirs))
models = list(loaded)  # independent shallow copy of `loaded`

# Run predictions

In [None]:
# Use all cores this process may run on except one, but never fewer than one
# worker. `os.sched_getaffinity` is not available on every platform, so fall back
# to `os.cpu_count()` where it is missing.
if hasattr(os, "sched_getaffinity"):
    n_usable_cores = len(os.sched_getaffinity(0))
else:
    n_usable_cores = os.cpu_count() or 1
n_available_cores = max(1, n_usable_cores - 1)

# Machine-specific dataset location; set LUNG_DATA_ROOT to point elsewhere
# instead of editing the notebook.
data_root_dir = os.environ.get("LUNG_DATA_ROOT", "/media/agjvc_rad3/_TESTKOLLEKTIV/Daten/Daten")

data_module = LungDataModule(
    root_dir=data_root_dir,
    split=(0.7, 0.1, 0.2),
    seed=42,
    pin_memory=True,
    num_workers=n_available_cores,
)
data_module.setup()
data_loader = data_module.test_dataloader()

In [None]:
# Let float32 matrix multiplications use faster reduced-precision kernels
# (e.g. TF32) where the hardware supports them.
torch.set_float32_matmul_precision("high")
# Default-configured Trainer, used below for `test` / `predict`.
trainer = pl.Trainer()

In [None]:
%%script echo "No action"
# Disabled on purpose: the `%%script echo "No action"` cell magic hands this cell
# to the shell command `echo` instead of the Python kernel, so the lines below are
# NOT executed and the cell only prints "No action". Delete the first line to run
# `trainer.test` for every loaded model on the test dataloader.
results = [trainer.test(model, data_loader) for model in models]
results

In [None]:
# Inference over the whole test set for every model. All outputs are kept in
# memory (one list of batch predictions per model), which can be expensive.
predictions = []
for model in models:
    predictions.append(trainer.predict(model, data_loader))

# Create visualisation

In [None]:
def transform(warped_series, flow_series, fixed_image, dim=(3, 1, 2, 0)):
    """Turn one model prediction into CPU tensors ready for ``imshow``.

    Args:
        warped_series: Batched warped-image series; only batch element 0 is used.
            NOTE(review): presumably shaped (B, C, H, W, T); confirm against the model.
        flow_series: Batched displacement-field series laid out like
            ``warped_series``; after permutation its last axis must hold at least
            two components (x, y).
        fixed_image: Fixed/reference image that broadcasts against a single
            permuted frame of ``warped_series`` (e.g. channel-last).
        dim: Axis permutation applied to the un-batched series. The default
            ``(3, 1, 2, 0)`` matches the value the notebook previously read from a
            global ``dim`` defined in a later cell.

    Returns:
        ``(fixed_image, warped_series, flow_series, abs_diff_series)``:
        the fixed image, the permuted warped series, an RGB colour map of the
        flow with values in [0, 1], and ``|warped - fixed|`` for every frame.
    """
    fixed_image = fixed_image.detach().cpu()

    # re-order axes of the outputs (first axis iterates frames, last holds channels)
    warped_series = warped_series[0].detach().cpu().permute(dim)
    flow_series = flow_series[0].detach().cpu().permute(dim)

    # absolute difference of every warped frame to the fixed image (broadcast over frames)
    abs_diff_series = torch.abs(warped_series - fixed_image)

    # displacement colour map: squash with tanh to (-1, 1), shift to (0, 1),
    # use x / y components as the red / green channels and leave blue at zero
    flow_series = torch.tanh(flow_series)
    flows_x = (flow_series[..., 0] + 1) / 2
    flows_y = (flow_series[..., 1] + 1) / 2
    flows_z = torch.zeros_like(flows_x)
    flow_series = torch.stack([flows_x, flows_y, flows_z], dim=-1)

    return fixed_image, warped_series, flow_series, abs_diff_series

In [None]:
# Indices of the test batches to visualise.
nums = [0]

# Predictions for the selected batches: one list per model, in `nums` order.
subsets_warped = [[pred[i] for i in nums] for pred in predictions]

# Input batches matching `nums`. Iteration stops after the largest requested
# index instead of walking (and loading) the whole test set. The result is
# re-ordered to follow `nums`, so it lines up with `subsets_warped` even when
# `nums` is unsorted or contains duplicates.
# NOTE(review): assumes `data_loader` yields batches in the same order as during
# `trainer.predict` (i.e. no shuffling); confirm for the test dataloader.
wanted = set(nums)
last_wanted = max(nums, default=-1)
selected_batches = {}
for i, batch in enumerate(data_loader):
    if i > last_wanted:
        break
    if i in wanted:
        selected_batches[i] = batch
subset_moving = [selected_batches[i] for i in nums]

In [None]:
# Whether to label the panels of the first figure row below.
show_title = True
# Axis permutation for the moving series below; `transform` applies the same
# ordering to the model outputs.
# NOTE(review): presumably (C, H, W, T) -> (T, H, W, C); confirm against the data module.
dim = (3, 1, 2, 0)

# One entry per selected batch: (permuted moving series, [transform output per model]).
data = []
for i, moving in enumerate(subset_moving):
    moving_p = moving[0].detach().cpu().permute(dim)
    transformed = [
        transform(
            warped_subset[i][0],
            warped_subset[i][1],
            model.extract_fixed_image(moving)[0].permute((1, 2, 0)),
        )
        for model, warped_subset in zip(models, subsets_warped)
    ]
    data.append((moving_p, transformed))

In [None]:
# Grid layout: one row per (selected batch, model) pair with five panels each:
# moving, fixed, warped, |warped - fixed| and the colour-coded flow.
num_cols = 5
num_rows = len(nums) * len(predictions)
# Upper bound on the number of animated time steps.
max_frames = 200

fig, axs = plt.subplots(ncols=num_cols, nrows=num_rows, figsize=(2 * num_cols, 2 * num_rows))
axs = axs.flatten()
images = []

for i in range(len(nums)):
    row_offset = i * num_cols * len(predictions)

    if show_title and i == 0:
        axs[row_offset + 0].set(title=r"$\mathit{m}$")
        axs[row_offset + 1].set(title=r"$\mathit{f}$")
        axs[row_offset + 2].set(title=r"$\mathit{m \circ \phi}$")
        axs[row_offset + 3].set(title=r"$\mathit{\left| \; (m \circ \phi) - f \; \right|}$")
        axs[row_offset + 4].set(title=r"$\mathit{\phi}$")

    m, t = data[i]

    for k in range(num_cols * len(predictions)):
        axs[row_offset + k].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], aspect="equal")

    # `images` receives exactly `num_cols` artists per (batch, model) in panel
    # order, so artist indices line up with axis indices in `animate`.
    for k in range(len(predictions)):
        f, ws, fs, ds = t[k]
        images.append(axs[row_offset + k * num_cols + 0].imshow(m[0], cmap="gray", animated=True))
        images.append(axs[row_offset + k * num_cols + 1].imshow(f, cmap="gray"))
        images.append(axs[row_offset + k * num_cols + 2].imshow(ws[0], cmap="gray", animated=True))
        images.append(axs[row_offset + k * num_cols + 3].imshow(ds[0], cmap="gray", animated=True))
        images.append(axs[row_offset + k * num_cols + 4].imshow(fs[0], cmap="gray", animated=True))


def animate(delta):
    """Show time step ``delta`` in every animated panel; return the artists for blitting."""
    for i in range(len(nums)):
        row_offset = i * num_cols * len(predictions)
        m, t = data[i]

        for k in range(len(predictions)):
            _, ws, fs, ds = t[k]
            images[row_offset + k * num_cols + 0].set_data(m[delta])
            images[row_offset + k * num_cols + 2].set_data(ws[delta])
            images[row_offset + k * num_cols + 3].set_data(ds[delta])
            images[row_offset + k * num_cols + 4].set_data(fs[delta])

    return images


# Never animate past the end of the shortest displayed series; a fixed frame
# count raises IndexError in `animate` when a series is shorter than it.
n_frames = max_frames
for m, t in data:
    series_lengths = [len(series) for _f, ws, fs, ds in t for series in (ws, fs, ds)]
    n_frames = min([n_frames, len(m), *series_lengths])

anim = animation.FuncAnimation(fig, animate, frames=n_frames, blit=True)
# Close the static figure so the inline backend shows only the animation
# instead of an additional still image.
plt.close(fig)
anim