In [None]:
# Experiment selection: checkpoints are read from ../model_weights/test_<TEST_NAME>
# and evaluated on test batch number SAMPLE_IDX (None picks a random batch).
TEST_NAME = "temporal_loss"
SAMPLE_IDX = 1

# Not referenced by any of the visible cells below; presumably kept for the
# Weights & Biases lookup / parameter sweep this notebook belongs to (TODO confirm).
PARAM_NAME = "context_length"
WANDB_PROJECT = "temple/lung-registration"

## Imports

In [None]:
import sys

sys.path.append('../')

import os
import gc

import torch
import wandb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from torchvision.utils import save_image
from tqdm.notebook import tqdm
from PIL import Image, ImageDraw, ImageFont

from reg.transmorph import TransMorphModule
from reg.data import LungDataModule

# Series export

## Helper functions

In [None]:
def sorted_files_in_directory(directory: str):
    """Pick one checkpoint file per run sub-directory of ``directory``.

    Every immediate sub-directory of ``directory`` is treated as a run. Within
    a run, the entry name that sorts last lexicographically is selected; this
    presumably relies on the checkpoint naming scheme putting the best model
    last (TODO confirm against the trainer's checkpoint filename template).

    Args:
        directory: Folder containing one sub-directory per run.

    Returns:
        List of ``(run_id, file_name)`` tuples ordered by ``run_id``. Run
        folders that are empty are skipped instead of raising ``IndexError``.
    """
    selected = []
    # Sorted so the run order (and thus export order) is stable across machines.
    for run_id in sorted(os.listdir(directory)):
        run_dir = os.path.join(directory, run_id)
        if not os.path.isdir(run_dir):
            continue
        # max() == first element of the descending sort, without sorting.
        best_entry = max(os.listdir(run_dir), default=None)
        if best_entry is None:
            # Nothing to load from an empty run folder; skip it rather than crash.
            continue
        selected.append((run_id, best_entry))
    return selected

In [None]:
def load_best_model(model_path: str):
    """Restore a ``TransMorphModule`` from a checkpoint and print its hyper-parameters.

    Args:
        model_path: Checkpoint location; passed through ``str`` before loading,
            and loaded with ``strict=True``.

    Returns:
        The restored module.
    """
    restored = TransMorphModule.load_from_checkpoint(str(model_path), strict=True)
    summary_lines = (
        f"{'=' * 5} Configuration summary {'=' * 92}",
        "",
        restored.hparams,
        "",
        "=" * 120,
    )
    for line in summary_lines:
        print(line)
    return restored

In [None]:
def setup_data_module(root_dir: str = "/media/agjvc_rad3/_TESTKOLLEKTIV/Daten/Daten"):
    """Create the ``LungDataModule`` used for evaluation and call its ``setup()``.

    Args:
        root_dir: Dataset root folder. Defaults to the original hard-coded
            location so existing calls are unaffected; pass a different path
            when running on another machine.

    Returns:
        The data module after ``setup()`` has been called.
    """
    data_module = LungDataModule(
        root_dir=root_dir,
        split=(0.7, 0.1, 0.2),
        seed=42,
        pin_memory=True,
        num_workers=_available_worker_count(),
    )
    data_module.setup()
    return data_module


def _available_worker_count() -> int:
    """Usable CPU cores minus one (presumably left for the main process), but at least 1."""
    if hasattr(os, "sched_getaffinity"):
        # Honors the process's CPU affinity mask, but only exists on some platforms (e.g. Linux).
        n_cores = len(os.sched_getaffinity(0))
    else:
        n_cores = os.cpu_count() or 1
    return max(1, n_cores - 1)

In [None]:
def compute_diff_series(warped_series: torch.Tensor):
    """Per-frame temporal change: average absolute difference to the neighbouring frames.

    Frames are indexed along dim 0. For frame ``i`` the result is
    ``(|s[i] - s[i+1]| + |s[i-1] - s[i]|) / 2``, where a missing neighbour
    (before the first / after the last frame) contributes 0, so the edge
    frames receive half of their single difference.

    The zero padding is created with the input's device and dtype, so the
    function works for CPU as well as CUDA tensors.

    Args:
        warped_series: Tensor whose first dimension indexes frames.

    Returns:
        Tensor with the same shape as ``warped_series`` (empty input gives empty output).
    """
    # |s[i+1] - s[i]| for every consecutive pair; abs() makes the order irrelevant.
    step_diffs = (warped_series[1:] - warped_series[:-1]).abs()
    pad = torch.zeros_like(warped_series[:1])
    to_next = torch.cat([step_diffs, pad], dim=0)  # |s[i] - s[i+1]|, 0 for the last frame
    to_prev = torch.cat([pad, step_diffs], dim=0)  # |s[i-1] - s[i]|, 0 for the first frame
    return (to_next + to_prev) * 0.5

In [None]:
def compute_flow_series(flow_series: torch.Tensor):
    """Turn a displacement field into a 3-channel tensor for visualisation.

    ``tanh`` squashes the values to (-1, 1), which are then shifted to (0, 1).
    Components 0 and 1 of dim 3 become the first two output channels; the
    third channel is zero. For a 4-D input of shape ``(A, B, C, K)`` with
    ``K >= 2`` the result has shape ``(A, B, C, 3)``.

    Args:
        flow_series: Tensor with at least 4 dims; dim 3 holds the displacement components.

    Returns:
        The stacked channels, with the new channel axis last.
    """
    squashed = torch.tanh(flow_series)
    first, second = ((squashed[:, :, :, channel] + 1) / 2 for channel in (0, 1))
    # Derived from `first` so NaNs in the input propagate to this channel as well.
    third = first * 0
    return torch.stack((first, second, third), dim=-1)

In [None]:
def fetch_sample_from_dataloader(dataloader, sample_idx):
    """Return the ``sample_idx``-th batch yielded by ``dataloader``.

    Args:
        dataloader: Any iterable of batches (e.g. a ``DataLoader``).
        sample_idx: Zero-based batch index, or ``None`` for the first batch.

    Returns:
        The selected batch.

    Raises:
        IndexError: If ``sample_idx`` is negative or the loader yields fewer
            than ``sample_idx + 1`` batches, instead of silently returning ``None``.
    """
    target = 0 if sample_idx is None else sample_idx
    if target < 0:
        raise IndexError(f"sample_idx must be non-negative, got {sample_idx}")

    n_seen = 0
    for batch in dataloader:
        if n_seen == target:
            return batch
        n_seen += 1
    raise IndexError(
        f"sample_idx={target} is out of range: the dataloader yielded only {n_seen} batch(es)"
    )

In [None]:
def add_text_to_image(image, text, position=(10, 10), font_size=20, color=(255, 255, 255)):
    """Draw ``text`` onto ``image`` (via a PIL image) and return the result as an array.

    The TrueType font ``arial.ttf`` is tried at ``font_size``; if it cannot be
    loaded, Pillow's default font is used (``font_size`` is not passed to it).

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        text: Text to draw.
        position: ``(x, y)`` pixel position of the text.
        font_size: Size for the TrueType font.
        color: Fill colour handed to ``ImageDraw.text``.

    Returns:
        NumPy array of the annotated image.
    """
    canvas = Image.fromarray(image)
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except IOError:
        font = ImageFont.load_default()
    ImageDraw.Draw(canvas).text(position, text, font=font, fill=color)
    return np.array(canvas)

In [None]:
def save_images_to_directory(directory, images, cmap):
    """Write every frame of ``images`` as a numbered PNG with its index stamped on it.

    Files are named ``000.png``, ``001.png``, ... in iteration order.

    Args:
        directory: Output folder; created (including parents) if missing.
        images: Tensor iterated along dim 0; each frame is expected to be
            ``(H, W, C)`` (channels last).
        cmap: Matplotlib colormap name. If given, the frame is written with
            ``plt.imsave`` after dropping a trailing singleton channel; if
            ``None`` it is written with torchvision's ``save_image``.
    """
    os.makedirs(directory, exist_ok=True)

    for idx, frame in tqdm(enumerate(images), total=images.shape[0]):
        img_path = os.path.join(directory, f"{idx:0>3}.png")

        if cmap is not None:
            # Always move to host memory: Matplotlib cannot read CUDA tensors, and a
            # frame with more than one channel would otherwise reach imsave still on the GPU.
            frame_np = frame.detach().cpu().numpy()
            if frame_np.ndim == 3 and frame_np.shape[2] == 1:
                frame_np = frame_np[:, :, 0]
            plt.imsave(img_path, frame_np, cmap=cmap)
        else:
            # save_image expects channels-first (C, H, W).
            save_image(frame.permute(2, 0, 1), img_path)

        # Re-open the written PNG to stamp the frame index onto it; the context
        # manager closes the file before it is overwritten.
        with Image.open(img_path) as written:
            pixels = np.array(written)
        Image.fromarray(add_text_to_image(pixels, f"{idx:0>3}")).save(img_path)

## Series export main

In [None]:
def main_series_export(test_name, sample_idx=None):
    """Export moving, warped, flow and difference series for every run of ``test_name``.

    For each run directory under ``../model_weights/test_<test_name>`` the
    checkpoint chosen by ``sorted_files_in_directory`` is loaded and applied to
    one test batch. The resulting series are written as PNG frames to
    ``../model_eval/test_<test_name>/<run_id>/<series_name>/``.

    Args:
        test_name: Experiment name used to build the weight and evaluation paths.
        sample_idx: Index of the test batch to use; ``None`` picks a random
            batch among those the test dataloader actually provides.
    """
    model_path = f"../model_weights/test_{test_name}"
    eval_path = f"../model_eval/test_{test_name}"

    data_module = setup_data_module()
    dataloader = data_module.test_dataloader()

    if sample_idx is None:
        # Draw from the batches the test loader really has instead of a fixed
        # range of 64, which could point past its end.
        sample_idx = np.random.randint(0, len(dataloader))

    # Swaps the first and last axes of one sample: the savers iterate dim 0 and
    # expect channels last (presumably yielding (T, H, W, C) — TODO confirm).
    dim = (3, 1, 2, 0)
    moving_series = fetch_sample_from_dataloader(dataloader, sample_idx).cuda()
    moving_series_exp = moving_series[0].permute(dim)

    for run_id, checkpoint_name in sorted_files_in_directory(model_path):
        run_model_path = os.path.join(model_path, run_id, checkpoint_name)

        # 1. Load best model
        model = load_best_model(run_model_path)

        # 2. Predict warped series and flow field without tracking gradients
        model.eval()
        with torch.no_grad():
            warped_series, flow_series = model(moving_series)
        warped_series = warped_series[0].permute(dim)
        flow_series = flow_series[0].permute(dim)

        # 3. Temporal difference of the warped series; 4. flow mapped to RGB
        diff_series = compute_diff_series(warped_series)
        transformed_flow = compute_flow_series(flow_series)

        # 5. Save every series as numbered PNG frames
        outputs = [
            (moving_series_exp, "moving_series", None),
            (warped_series, "warped_series", None),
            (transformed_flow, "flow_series", None),
            (diff_series, "diff_series", "magma"),
        ]
        for images, name, cmap in outputs:
            save_images_to_directory(f"{eval_path}/{run_id}/{name}", images, cmap)

        # Release the model as well as its outputs so the next checkpoint is not
        # loaded while this one may still occupy GPU memory.
        del model, warped_series, flow_series, transformed_flow, diff_series
        _release_cuda_memory()


def _release_cuda_memory():
    """Collect garbage, empty the CUDA cache and print allocator statistics for device 0 (in GiB)."""
    gc.collect()
    torch.cuda.empty_cache()
    gib = 1024 ** 3
    print("torch.cuda.memory_allocated: %fGB" % (torch.cuda.memory_allocated(0) / gib))
    print("torch.cuda.memory_reserved: %fGB" % (torch.cuda.memory_reserved(0) / gib))
    print("torch.cuda.max_memory_reserved: %fGB" % (torch.cuda.max_memory_reserved(0) / gib))

In [None]:
# Free what earlier cells left behind before exporting: first drop unreachable
# Python objects, then hand cached, now-unused CUDA blocks back.
gc.collect()
torch.cuda.empty_cache()
main_series_export(test_name=TEST_NAME, sample_idx=SAMPLE_IDX)

# Series mean intensity profile (mean of each image, plotted as a line)

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

def histogram_main(sample_idx, cut_off=30):
    """Plot the mean intensity of every image of one test series.

    Despite the name this is a line plot, not a histogram. For each image index
    ``>= cut_off`` the image mean is drawn, together with reference lines:
    - the image furthest from the overall mean;
    - the last image's mean;
    - the mean over all images;
    - the maximum mean;
    - ±1σ / ±2σ bands.

    Args:
        sample_idx: Test batch index; ``None`` picks a random existing batch.
        cut_off: Number of leading images to leave out of the plot (default 30).

    Raises:
        ValueError: If no images remain after ``cut_off``.
    """
    data_module = setup_data_module()
    dataloader = data_module.test_dataloader()

    if sample_idx is None:
        # Bound by the real number of test batches rather than a fixed 64.
        sample_idx = np.random.randint(0, len(dataloader))

    moving_series = fetch_sample_from_dataloader(dataloader, sample_idx)

    # Mean over dims 2 and 3, then batch 0 / channel 0 -> one value per image
    # along the last dim (assumed to be time — TODO confirm).
    image_means = moving_series.mean(axis=(2, 3))[0, 0][cut_off:].detach().cpu().numpy()
    if image_means.size == 0:
        raise ValueError(f"No images left after cut_off={cut_off}")

    _plot_series_means(image_means, cut_off)


def _plot_series_means(image_means, cut_off):
    """Draw per-image means (``image_means[k]`` belongs to image ``cut_off + k``) with reference lines."""
    mean_of_means = image_means.mean()
    std_of_means = image_means.std(ddof=1)  # sample std, matching torch.std's default
    image_indices = np.arange(len(image_means)) + cut_off
    peak_idx = int(np.argmax(np.abs(image_means - mean_of_means))) + cut_off

    fig, ax = plt.subplots(1, 1, figsize=(16, 5))
    fig.set_tight_layout(True)

    ax.set_title("Mean of Image Series")
    ax.set_xlabel("Image Index")
    ax.set_ylabel("Mean Value")

    ax.plot(image_indices, image_means, "-", color='b', lw=2, label="Image Means")
    ax.axvline(x=peak_idx, color='r', linestyle='-', lw=2, label=f"Peak at idx = {peak_idx}")
    ax.axhline(y=image_means[-1], color='purple', linestyle="dashed", lw=2, label="Last Mean")
    # Drawn at the actual mean (not at the image closest to it) so it matches
    # its label and is the centre of the sigma bands below.
    ax.axhline(y=mean_of_means, color='green', linestyle="dashdot", lw=2, label="Mean of Means")
    # Black rather than orange so it cannot be confused with the ±2 sigma lines.
    ax.axhline(y=image_means.max(), color='black', linestyle="dotted", lw=2, label="Max Mean")

    # Sigma bands around the mean, in the order +1, -1, +2, -2.
    for n_sigma, color in ((1, 'y'), (2, 'orange')):
        ax.axhline(y=mean_of_means + n_sigma * std_of_means, color=color, linestyle="dotted", lw=2,
                   label=f"Mean + {n_sigma} Sigma")
        ax.axhline(y=mean_of_means - n_sigma * std_of_means, color=color, linestyle="dotted", lw=2,
                   label=f"Mean - {n_sigma} Sigma")

    # x-axis ticks every 10 images
    ax.set_xticks(np.arange(image_indices[0], image_indices[-1] + 1, 10))

    # Vertical guide lines every 32 images; only the first one gets a legend entry.
    for x in range(image_indices[0], image_indices[-1] + 1, 32):
        ax.axvline(x=x, color='gray', linestyle='--', lw=1, label='Every 32 Steps' if x == image_indices[0] else "")

    ax.legend(loc='lower right', fontsize='small')
    ax.grid(True, which='both', linestyle='--', lw=0.5)

    plt.show()
    plt.close(fig)


In [None]:
# Same SAMPLE_IDX as the series export (same batch, provided the test split and order are deterministic).
histogram_main(sample_idx=SAMPLE_IDX)

# Series animation

## Helper functions

In [None]:
def load_images(base_path, targets, runs):
    """Load every PNG frame of every (run, target) folder into memory.

    Args:
        base_path: Root folder holding one sub-folder per run.
        targets: Names of the per-run sub-folders to read, in output order.
        runs: Run folder names, in output order.

    Returns:
        Nested list ``result[run][target][frame]`` of NumPy arrays. Frames are
        ordered by file name; files not ending in ``.png`` are ignored.
    """
    def read_folder(folder):
        frames = []
        for name in sorted(os.listdir(folder)):
            if not name.endswith('.png'):
                continue
            with Image.open(os.path.join(folder, name)) as picture:
                frames.append(np.array(picture))
        return frames

    return [
        [read_folder(os.path.join(base_path, run, target)) for target in targets]
        for run in runs
    ]

In [None]:
def series_anim_main(test_name):
    """Build an inline animation comparing the exported series of every run.

    Reads the PNG frames under ``../model_eval/test_<test_name>`` and lays them
    out with one row per run and one column per series (moving, warped, flow,
    difference).

    Args:
        test_name: Experiment name used to build the evaluation path.

    Returns:
        The ``FuncAnimation``. As a cell's last expression it is rendered as
        JavaScript HTML because ``animation.html`` is set to ``"jshtml"``.

    Raises:
        FileNotFoundError: If the evaluation folder contains no run sub-folders.
    """
    base_path = f"../model_eval/test_{test_name}"
    targets = ["moving_series", "warped_series", "flow_series", "diff_series"]
    # Sorted so the row order is stable between executions.
    runs = sorted(d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d)))
    if not runs:
        raise FileNotFoundError(f"No run folders found in {base_path}")
    data = load_images(base_path, targets, runs)

    # Configuration for inline display
    plt.rcParams["animation.html"] = "jshtml"
    plt.rcParams["animation.embed_limit"] = 2048
    plt.rcParams['figure.dpi'] = 150
    # What IPython turns `%matplotlib inline` into; valid Python inside a function.
    get_ipython().run_line_magic("matplotlib", "inline")

    num_cols = len(targets)
    num_rows = len(runs)
    fig, axs = plt.subplots(ncols=num_cols, nrows=num_rows, figsize=(2 * num_cols, 2 * num_rows), squeeze=False)

    titles = [
        r"$\mathit{m}$",
        r"$\mathit{m \circ \phi}$",
        r"$\mathit{\phi}$",
        r"$\mathit{\left| \; (m \circ \phi) - f \; \right|}$",
    ]
    for ax, title in zip(axs[0], titles):
        ax.set(title=title)

    images = []
    for row_axes, run_id, run_series in zip(axs, runs, data):
        # Row label on the first panel: a figure-level text at x=0.01 with
        # ha='right' would be drawn almost entirely outside the canvas.
        row_axes[0].set_ylabel(run_id, fontsize=10)
        for ax, series in zip(row_axes, run_series):
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], aspect="equal")
            images.append(ax.imshow(series[0], animated=True))

    # Same row-major order as `images`, so the two can be zipped in animate().
    all_series = [series for run_series in data for series in run_series]
    # Animate only as many frames as the shortest series has, so no panel runs out.
    n_frames = min(len(series) for series in all_series)

    def animate(frame_idx):
        for image, series in zip(images, all_series):
            image.set_data(series[frame_idx])
        return images

    return animation.FuncAnimation(fig, animate, frames=n_frames, blit=True)

In [None]:
# Last expression of the cell, so Jupyter renders the returned animation inline.
series_anim_main(test_name=TEST_NAME)