# Series export

In [None]:
# Notebook parameters: experiment group to export and the test-batch index
# handed to main_series_export as ``sample_idx``.
GROUP_NAME = "<group_name>"
MODEL_PATH = "../model_weights/test_" + GROUP_NAME
EVAL_PATH = "../model_eval/test_" + GROUP_NAME
SAMPLE_IDX = 1

In [None]:
import gc
import os
import sys

sys.path.append('../')

from PIL import Image, ImageDraw, ImageFont
from torchvision.utils import save_image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch

from reg.data import LungDataModule
from reg.transmorph import TransMorphModule

## Series export functions

In [None]:
def sorted_files_in_directory(directory: str):
    """Return ``(run_id, checkpoint_name)`` for every run sub-directory of *directory*.

    Each immediate sub-directory of *directory* is treated as one run. Its
    entries are sorted by name in descending order, and the first one (the
    lexicographically largest name) is taken as that run's checkpoint.
    NOTE(review): this assumes checkpoint filenames sort so that the desired
    ("best") checkpoint has the largest name. Confirm against the naming scheme.

    Plain files directly inside *directory* are ignored. Run directories with
    no entries are reported and skipped instead of raising ``IndexError``.
    Runs are returned sorted by ``run_id``, so the export order is
    deterministic; ``os.listdir`` order is arbitrary.

    Args:
        directory: Directory containing one sub-directory per run.

    Returns:
        List of ``(run_id, checkpoint_name)`` tuples.
    """
    files_best = []
    for run_id in sorted(os.listdir(directory)):
        run_path = os.path.join(directory, run_id)
        if not os.path.isdir(run_path):
            continue
        entries = sorted(os.listdir(run_path), reverse=True)
        if not entries:
            print(f"Skipping run '{run_id}': directory is empty")
            continue
        files_best.append((run_id, entries[0]))
    return files_best

def load_best_model(model_path: str):
    """Restore a ``TransMorphModule`` from a checkpoint and print its hyper-parameters.

    Args:
        model_path: Path of the checkpoint file (anything ``str()`` accepts).

    Returns:
        The module returned by ``TransMorphModule.load_from_checkpoint``
        (loaded with ``strict=True``).
    """
    checkpoint_path = str(model_path)
    restored = TransMorphModule.load_from_checkpoint(checkpoint_path, strict=True)

    # Framed summary: 5-char rule + title + 92-char rule, closed by a 120-char rule.
    print(f"{'=' * 5} Configuration summary {'=' * 92}")
    print()
    print(restored.hparams)
    print()
    print("=" * 120)
    return restored

def setup_data_module(root_dir: str = "/media/agjvc_rad3/_TESTKOLLEKTIV/Daten/Daten"):
    """Build a ``LungDataModule``, call ``setup()`` on it and return it.

    The worker count is one less than the number of CPUs usable by this
    process, and never below one. ``os.sched_getaffinity`` only exists on
    some platforms (e.g. Linux), so ``os.cpu_count()`` is used where it is
    unavailable.

    Args:
        root_dir: Dataset root directory. Defaults to the path this notebook
            has always used.

    Returns:
        The data module after ``setup()``.
    """
    if hasattr(os, "sched_getaffinity"):
        n_cpus = len(os.sched_getaffinity(0))
    else:
        # cpu_count() may return None when the count cannot be determined.
        n_cpus = os.cpu_count() or 1
    # Presumably leaves one core for the main process; never drops below one worker.
    n_workers = max(1, n_cpus - 1)

    data_module = LungDataModule(
        root_dir=root_dir,
        split=(0.7, 0.1, 0.2),
        seed=42,
        pin_memory=True,
        num_workers=n_workers,
    )
    data_module.setup()
    return data_module

def compute_diff_series(warped_series: torch.Tensor, fixed_image: torch.Tensor):
    """Return ``|warped_series[i] - fixed_image|`` for every frame ``i``, stacked on axis 0.

    ``fixed_image`` is expected to have the shape of one frame of
    ``warped_series``; this is how the caller in this notebook builds it. It
    is broadcast over the leading (frame) axis in a single vectorized op
    instead of a Python loop plus ``torch.stack``. An empty series therefore
    yields an empty tensor rather than the ``RuntimeError`` that
    ``torch.stack`` raises on an empty list.

    Args:
        warped_series: Series of frames, indexed along axis 0.
        fixed_image: Single reference frame.

    Returns:
        Tensor with the same shape as ``warped_series``.
    """
    return torch.abs(warped_series - fixed_image)

def compute_flow_series(flow_series: torch.Tensor):
    """Turn a flow series into a 3-channel image series for visualisation.

    Values are squashed with ``tanh`` into (-1, 1). Channels 0 and 1 of the
    last axis are then mapped affinely into (0, 1) and become output channels
    0 and 1. Output channel 2 is ``channel0 * 0`` (zero, but NaN where
    channel 0 is NaN). The input must be at least 4-D.

    Returns:
        Tensor shaped like the input's first three axes plus a last axis of 3.
    """
    squashed = torch.tanh(flow_series)
    red = squashed[:, :, :, 0].add(1).div(2)
    green = squashed[:, :, :, 1].add(1).div(2)
    blue = red.mul(0)
    return torch.stack((red, green, blue), dim=-1)

def fetch_sample_from_dataloader(dataloader, sample_idx):
    """Return batch number ``sample_idx`` (0-based) from ``dataloader``.

    The loader is iterated from the start until the requested batch is
    reached. ``None`` selects the first batch.

    Raises:
        IndexError: if ``sample_idx`` is negative or the loader yields fewer
            than ``sample_idx + 1`` batches. Returning ``None`` silently would
            only fail later, at the caller's ``.cuda()``.
    """
    if sample_idx is None:
        sample_idx = 0
    if sample_idx < 0:
        # Reject up front so the whole loader is not iterated for nothing.
        raise IndexError(f"sample_idx must be non-negative, got {sample_idx}")
    for i, batch in enumerate(dataloader):
        if i == sample_idx:
            return batch
    raise IndexError(
        f"sample_idx {sample_idx} is out of range: dataloader has fewer than "
        f"{sample_idx + 1} batches"
    )
        
def add_text_to_image(image, text, position=(10, 10), font_size=20, color=(255, 255, 255)):
    """Render ``text`` onto an image array and return the result as a new array.

    Args:
        image: Image data accepted by ``PIL.Image.fromarray`` (a NumPy array).
        text: String to draw.
        position: ``xy`` anchor passed to ``ImageDraw.text``.
        font_size: Size used when ``arial.ttf`` can be loaded.
        color: Fill colour of the text.

    Returns:
        NumPy array (``np.array`` of the drawn PIL image).

    If ``arial.ttf`` cannot be opened (``IOError``), PIL's default font is
    used instead, and ``font_size`` is not passed to it.
    """
    canvas = Image.fromarray(image)
    try:
        label_font = ImageFont.truetype("arial.ttf", font_size)
    except IOError:
        label_font = ImageFont.load_default()
    ImageDraw.Draw(canvas).text(position, text, font=label_font, fill=color)
    return np.array(canvas)

def save_images_to_directory(directory, images, cmap):
    """Write every frame of ``images`` to ``directory`` as a labelled PNG.

    Frames are enumerated along axis 0 and written as ``000.png``,
    ``001.png``, … Each file is then re-opened and its zero-padded index is
    drawn onto it with ``add_text_to_image``.

    Args:
        directory: Output directory; created (with parents) if missing.
        images: Tensor whose first axis enumerates frames.
        cmap: Matplotlib colormap name. If given, each frame is saved with
            ``plt.imsave``, and a trailing singleton channel is squeezed away
            first. If ``None``, each ``(H, W, C)`` frame is saved with
            torchvision's ``save_image`` after moving channels first.
    """
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(directory, exist_ok=True)
    for idx, img in tqdm(enumerate(images), total=images.shape[0]):
        img_path = os.path.join(directory, f"{idx:0>3}.png")
        if cmap is not None:
            # matplotlib needs a host NumPy array; passing a CUDA tensor fails.
            frame = img.detach().cpu().numpy()
            if frame.ndim == 3 and frame.shape[2] == 1:
                frame = frame[:, :, 0]
            plt.imsave(img_path, frame, cmap=cmap)
        else:
            # save_image expects channels first: (H, W, C) -> (C, H, W).
            save_image(img.permute(2, 0, 1), img_path)
        # Close the re-opened PNG before overwriting the same path.
        with Image.open(img_path) as written:
            labelled = add_text_to_image(np.array(written), f"{idx:0>3}")
        Image.fromarray(labelled).save(img_path)

In [None]:
def _print_cuda_memory():
    """Print allocated, reserved and peak-reserved CUDA memory of device 0 in GB."""
    gib = 1024 * 1024 * 1024
    print("torch.cuda.memory_allocated: %fGB" % (torch.cuda.memory_allocated(0) / gib))
    print("torch.cuda.memory_reserved: %fGB" % (torch.cuda.memory_reserved(0) / gib))
    print("torch.cuda.max_memory_reserved: %fGB" % (torch.cuda.max_memory_reserved(0) / gib))


def _export_run(run_model_path, run_eval_path, moving_series, moving_series_exp, dim):
    """Evaluate one checkpoint on ``moving_series`` and write its four PNG series.

    Everything allocated here (model, predictions, derived series) is local,
    so it becomes collectable as soon as this function returns.
    """
    # Move the model to the input's device so both are guaranteed to match.
    model = load_best_model(run_model_path).to(moving_series.device)
    model.eval()
    with torch.no_grad():
        warped_series, flow_series, fixed_image = model(moving_series)

    warped_series = warped_series[0].permute(dim)
    flow_series = flow_series[0].permute(dim)
    # Only the first entry along the permuted leading axis serves as the reference frame.
    fixed_image = fixed_image[0].permute(dim)[0]

    diff_series = compute_diff_series(warped_series, fixed_image)
    transformed_flow = compute_flow_series(flow_series)

    exports = (
        (moving_series_exp, "moving_series", None),
        (warped_series, "warped_series", None),
        (transformed_flow, "flow_series", None),
        (diff_series, "diff_series", "magma"),
    )
    for images, name, cmap in exports:
        save_images_to_directory(os.path.join(run_eval_path, name), images, cmap)


def main_series_export(model_path, eval_path, sample_idx=None):
    """Export PNG series for every run checkpoint found under ``model_path``.

    One test batch is selected: ``sample_idx``, or, when it is ``None``, an
    unseeded random index in ``[0, 64)``. Every run returned by
    ``sorted_files_in_directory(model_path)`` is evaluated on that batch.
    The results are written to ``<eval_path>/<run_id>/`` in the
    sub-directories ``moving_series``, ``warped_series``, ``flow_series`` and
    ``diff_series``; the last uses the ``magma`` colormap. CUDA memory usage
    is printed after each run.
    """
    data_module = setup_data_module()
    dataloader = data_module.test_dataloader()

    if sample_idx is None:
        sample_idx = np.random.randint(0, 64)
        # The draw is unseeded; print it so the export can be reproduced.
        print(f"Randomly selected sample_idx={sample_idx}")

    # Moves axis 3 to the front and axis 0 to the back. The save helpers treat
    # the resulting last axis as channels.
    # NOTE(review): presumably (C, H, W, T) -> (T, H, W, C); confirm against the model's layout.
    dim = (3, 1, 2, 0)
    moving_series = fetch_sample_from_dataloader(dataloader, sample_idx).cuda()
    moving_series_exp = moving_series[0].permute(dim)

    for run_id, checkpoint_name in sorted_files_in_directory(model_path):
        run_model_path = os.path.join(model_path, run_id, checkpoint_name)
        _export_run(
            run_model_path,
            os.path.join(eval_path, run_id),
            moving_series,
            moving_series_exp,
            dim,
        )
        # The run's model and tensors went out of scope with _export_run,
        # so they can actually be reclaimed here.
        gc.collect()
        torch.cuda.empty_cache()
        _print_cuda_memory()

## Main

In [None]:
# Release unreferenced objects and cached CUDA blocks, then run the export
# with the notebook parameters defined at the top.
gc.collect()
torch.cuda.empty_cache()
main_series_export(model_path=MODEL_PATH, eval_path=EVAL_PATH, sample_idx=SAMPLE_IDX)