# Imports

In [None]:
from pathlib2 import Path
from tqdm.notebook import tqdm
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch

plt.rcParams["animation.html"] = "jshtml"
plt.rcParams["animation.embed_limit"] = 2048
plt.rcParams['figure.dpi'] = 150
%matplotlib inline

In [None]:
from dataset import LungDataModule
from model import TransMorphModule
from utils import load_losses, load_model_params

# Load models

In [None]:
# Training runs to compare. Each run directory name encodes its configuration as
# "<model>-<image loss>-<flow loss>-<optimizer>-<lr>-<target>-<max epoch>-<series len>-<data mod>";
# only the weights folder version and the series length differ between the runs below.
ckpt_dirs = [
    f"model_weights_{version}/transmorph-gmi=1-gl2d=1-adam-0.0001-last-100-{series_len}-norm"
    for version, series_len in (("v1", 128), ("v2", 32), ("v2", 128))
]

In [None]:
def _parse_ckpt_file_name(file_name):
    """Extract the ``(val_loss, epoch)`` strings from a checkpoint file name.

    Expected form: ``<key>=<val_loss>&<key>=<epoch>.ckpt``. Returns ``None`` for
    names that do not follow this form (for example a ``last.ckpt``) or whose
    val_loss is not a number, so such files are ignored rather than crashing.
    """
    parts = file_name.split("&")
    if len(parts) != 2:
        return None
    loss_fields = parts[0].split("=")
    epoch_fields = parts[1].split("=")
    if len(loss_fields) < 2 or len(epoch_fields) < 2:
        return None
    val_loss = loss_fields[1]
    # Drop the file extension (everything after the first ".") from the epoch field.
    epoch = epoch_fields[1].split(".")[0]
    try:
        float(val_loss)
    except ValueError:
        return None
    return val_loss, epoch


def _select_best_checkpoint(ckpt_dir):
    """Return ``(path, val_loss, epoch)`` of the checkpoint in ``ckpt_dir`` with the lowest val_loss.

    Losses are compared numerically. Sorting the file names lexicographically would
    pick the *highest* loss for positive values and mis-order numbers with a different
    number of digits (e.g. "-9.0" would rank before "-10.0").

    Raises:
        FileNotFoundError: if no checkpoint file name can be parsed.
    """
    candidates = []
    # Iterating in reverse-sorted name order means that, among equal losses, the
    # lexicographically largest file name wins (min() keeps the first minimum).
    for file_name in sorted((c.name for c in ckpt_dir.glob("*.ckpt")), reverse=True):
        parsed = _parse_ckpt_file_name(file_name)
        if parsed is not None:
            candidates.append((file_name, parsed[0], parsed[1]))
    if not candidates:
        raise FileNotFoundError(f"No '<val_loss>&<epoch>.ckpt' checkpoints found in {ckpt_dir}")
    best_name, val_loss, epoch = min(candidates, key=lambda candidate: float(candidate[1]))
    return ckpt_dir / best_name, val_loss, epoch


def load_best_model(ckpt_dir):
    """Load the lowest-val_loss checkpoint of a training run and print a summary.

    The run configuration is parsed from the directory name, which must consist of
    nine "-"-separated fields: model_name, image_loss, flow_loss, optimizer_name, lr,
    target_type, max_epoch, series_len, data_mod. (Values that themselves contain
    "-", such as a learning rate written "1e-4", are therefore not supported.)

    Args:
        ckpt_dir: path (str or Path) of the run directory containing ``*.ckpt`` files.

    Returns:
        ``(model, data_mod)``: the restored ``TransMorphModule`` and the data-modality
        field taken from the directory name.

    Raises:
        FileNotFoundError: if no parseable checkpoint exists in ``ckpt_dir``.
        ValueError: if the directory name does not have the nine expected fields or
            max_epoch / series_len are not integers.
    """
    ckpt_dir = Path(ckpt_dir)
    best_ckpt, val_loss, epoch = _select_best_checkpoint(ckpt_dir)

    fields = ckpt_dir.name.split("-")
    if len(fields) != 9:
        raise ValueError(
            f"Expected 9 '-'-separated fields in run directory name, got {len(fields)}: {ckpt_dir.name!r}")
    # optimizer_name and lr are parsed only to keep the unpacking aligned; inference needs no optimizer.
    model_name, image_loss, flow_loss, optimizer_name, lr, target_type, max_epoch, series_len, data_mod = fields

    target_type = target_type.lower()
    series_len = int(series_len)
    max_epoch = int(max_epoch)

    image_losses, flow_losses = load_losses(image_loss, flow_loss, delimiter="=")

    net, criteria_image, criteria_flow, criterion_disp, _optimizer = load_model_params(
        model_name=model_name,
        image_losses=image_losses,
        flow_losses=flow_losses,
        optimizer_name=None,
        series_len=series_len)

    model = TransMorphModule.load_from_checkpoint(
        str(best_ckpt),
        strict=False,
        net=net,
        criteria_image=criteria_image,
        criteria_flow=criteria_flow,
        criterion_disp=criterion_disp,
        target_type=target_type,
    )

    print("=" * 80)
    print("Best performing model")
    print("")
    print(f"    val_loss    : {val_loss}")
    print(f"    epoch       : {epoch}")
    print("")
    print(f"    model_name  : {model_name}")
    print(f"    image_loss  : {criteria_image}")
    print(f"    flow_loss   : {criteria_flow}")
    print(f"    disp_loss   : {criterion_disp}")
    print(f"    target_type : {target_type}")
    print(f"    max_epoch   : {max_epoch}")
    print(f"    series_len  : {series_len}")
    print(f"    data_mod    : {data_mod}")
    print("")
    print("=" * 80)

    return model, data_mod

In [None]:
# Load the best checkpoint of every run; load_best_model returns (model, data_mod) pairs.
loaded = [load_best_model(ckpt_dir) for ckpt_dir in ckpt_dirs]
assert loaded, "ckpt_dirs is empty - there are no runs to compare"
models = [model for model, _ in loaded]
data_mods = [data_mod for _, data_mod in loaded]
# All runs are evaluated on one shared test loader below, so they must use the same data modality.
assert all(data_mods[0] == x for x in data_mods), f"Runs use different data modalities: {data_mods}"
data_mod = data_mods[0]

# Run predictions

In [None]:
# Test loader for the data modality shared by all runs, one sample per batch.
data_module = LungDataModule(
    mod=data_mod,
    batch_size=1,
    num_workers=12,
    pin_memory=True,
)
data_module.setup()
data_loader = data_module.test_dataloader()

In [None]:
# "high" lets PyTorch use faster, reduced-precision kernels (e.g. TF32) for float32 matmuls where supported.
torch.set_float32_matmul_precision("high")
# Default Lightning trainer, used only for test/predict below.
trainer = pl.Trainer()

In [None]:
# Per-model test-set evaluation. It is skipped by default; set RUN_TEST_EVAL = True to compute the metrics.
RUN_TEST_EVAL = False
if RUN_TEST_EVAL:
    results = [trainer.test(model, data_loader) for model in models]

In [None]:
# Run inference with every loaded model on the test loader. All outputs are kept
# in memory, so this cell can be slow and memory-hungry.
predictions = [trainer.predict(lightning_module, dataloaders=data_loader) for lightning_module in models]

# Create visualisation

In [None]:
def transform(warped_series, flow_series, disp_series, fixed_image, dim=(3, 1, 2, 0)):
    """Prepare one model prediction for plotting.

    Args:
        warped_series: indexable whose element 0 is the warped-image tensor
            (presumably the first item of a batch; assumed layout (c, w, h, t) - TODO confirm).
        flow_series: indexable whose element 0 is the flow-field tensor, with the same
            assumed layout and at least two channels.
        disp_series: passed through unchanged.
        fixed_image: tensor that is broadcast against every permuted warped frame.
        dim: permutation applied to element 0 of ``warped_series`` and ``flow_series``.
            The default (3, 1, 2, 0) swaps the first and last axes, presumably giving
            (t, w, h, c) so that axis 0 indexes frames.

    Returns:
        ``(warped_series, flow_series, disp_series, abs_diff_series)`` where
        ``warped_series`` is the permuted, detached CPU tensor; ``flow_series`` is an
        image-like tensor whose last axis holds tanh(flow) rescaled from (-1, 1) to
        (0, 1) for channels 0 and 1 plus a zero third channel; ``abs_diff_series`` is
        ``|frame - fixed_image|`` for every frame along axis 0.
    """
    # Detach, move to CPU and reorder axes so that axis 0 indexes frames.
    warped_series = warped_series[0].detach().cpu().permute(dim)
    flow_series = flow_series[0].detach().cpu().permute(dim)

    # Absolute difference of every warped frame to the fixed image (broadcast over axis 0).
    abs_diff_series = torch.abs(warped_series - fixed_image)

    # Squash the flow into (-1, 1) with tanh, then rescale to (0, 1) so it can be shown as RGB.
    flow_series = (torch.tanh(flow_series) + 1) / 2
    flows_x = flow_series[..., 0]
    flows_y = flow_series[..., 1]
    flows_z = torch.zeros_like(flows_x)
    flow_series = torch.stack([flows_x, flows_y, flows_z], dim=-1)

    return warped_series, flow_series, disp_series, abs_diff_series

In [None]:
# Batch indices of the test samples to visualise (e.g. [0, 3, 6, 9, 12]); with
# batch_size=1 a batch index equals a sample index.
nums = [3]

# subsets_warped[m][j] is model m's prediction for batch nums[j].
subsets_warped = [[pred[i] for i in nums] for pred in predictions]

# Collect the matching input batches. They are stored by index and then ordered like
# `nums`, so subset_moving[j] pairs with subsets_warped[m][j] even when `nums` is
# unsorted or contains repeats. Iteration stops once every requested index was seen.
# NOTE(review): assumes the test loader yields batches in the same order as predict() did
# (i.e. no shuffling) - confirm in LungDataModule.
remaining = set(nums)
batches_by_index = {}
for i, batch in enumerate(data_loader):
    if i in remaining:
        batches_by_index[i] = batch
        remaining.discard(i)
        if not remaining:
            break
subset_moving = [batches_by_index[i] for i in nums]

In [None]:
show_title = True
# Axis permutation applied to every series tensor (swaps the first and last axes).
dim = (3, 1, 2, 0)

# data[j] = (moving series, fixed image, [transform output per model]) for the j-th selected sample.
data = []
for sample_idx, moving_batch in enumerate(subset_moving):
    moving_series = moving_batch[0].detach().cpu().permute(dim)
    # The last frame of the moving series serves as the fixed (reference) image.
    fixed = moving_series[-1]
    per_model = []
    for subset_warped in subsets_warped:
        prediction = subset_warped[sample_idx]
        per_model.append(transform(prediction[0], prediction[1], prediction[2], fixed))
    data.append((moving_series, fixed, per_model))

In [None]:
# Layout: one row per (sample, model) pair, five columns:
# moving m | fixed f | warped m∘φ | |m∘φ - f| | flow φ.
num_cols = 5
num_rows = len(nums) * len(predictions)
# Upper bound on the number of animated frames (keeps the embedded HTML size in check).
max_frames = 200

fig, axs = plt.subplots(ncols=num_cols, nrows=num_rows, figsize=(2 * num_cols, 2 * num_rows))
axs = axs.flatten()
imgs = []

for i in range(len(nums)):
    row_offset = i * num_cols * len(predictions)

    if show_title and i == 0:
        axs[row_offset + 0].set(title=r"$\mathit{m}$")
        axs[row_offset + 1].set(title=r"$\mathit{f}$")
        axs[row_offset + 2].set(title=r"$\mathit{m \circ \phi}$")
        axs[row_offset + 3].set(title=r"$\mathit{\left| \; (m \circ \phi) - f \; \right|}$")
        axs[row_offset + 4].set(title=r"$\mathit{\phi}$")

    m, f, t = data[i]

    for k in range(num_cols * len(predictions)):
        axs[row_offset + k].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], aspect="equal")

    for k in range(len(predictions)):
        ws, fs, _, ds = t[k]
        base = row_offset + k * num_cols
        imgs.append(axs[base + 0].imshow(m[0], cmap="gray", animated=True))
        imgs.append(axs[base + 1].imshow(f, cmap="gray"))
        imgs.append(axs[base + 2].imshow(ws[0], cmap="gray", animated=True))
        # set_data() keeps the colour limits chosen for the first frame, so pin them to the
        # range of the whole difference series; otherwise later frames can saturate.
        imgs.append(axs[base + 3].imshow(ds[0], cmap="gray", vmin=0, vmax=float(ds.max()), animated=True))
        imgs.append(axs[base + 4].imshow(fs[0], cmap="gray", animated=True))

# Animate only as many frames as the shortest displayed series provides; indexing past
# the end of any series in animate() would raise an IndexError.
series_lengths = []
for moving_s, _fixed, per_model in data:
    series_lengths.append(len(moving_s))
    for ws_s, fs_s, _disp, ds_s in per_model:
        series_lengths.extend((len(ws_s), len(fs_s), len(ds_s)))
num_frames = min([max_frames, *series_lengths])


def animate(delta):
    """Show frame ``delta`` of every animated series; the fixed-image panels stay static."""
    for i in range(len(nums)):
        row_offset = i * num_cols * len(predictions)
        m, f, t = data[i]

        for k in range(len(predictions)):
            ws, fs, _, ds = t[k]
            base = row_offset + k * num_cols
            imgs[base + 0].set_data(m[delta])
            imgs[base + 2].set_data(ws[delta])
            imgs[base + 3].set_data(ds[delta])
            imgs[base + 4].set_data(fs[delta])

    return imgs


anim = animation.FuncAnimation(fig, animate, frames=num_frames, blit=True)
# Close the figure so the inline backend does not also show a static copy of frame 0;
# the animation is rendered through its HTML repr (rcParams["animation.html"] = "jshtml").
plt.close(fig)
anim