# Series animation

In [None]:
# Output directory for the exported standalone HTML animation.
HTML_PATH: str = "./analysis/html"
# True: show the animation inline in the notebook; False: export it to HTML_PATH.
INTERACTIVE: bool = False

# Evaluation group to visualise; it selects the evaluation directory below.
GROUP_NAME: str = "criteria_flow"
# Directory whose subdirectories are the individual runs of this group.
EVAL_PATH: str = f"../model_eval/test_{GROUP_NAME}"

In [None]:
import os
import sys

sys.path.append('../')

from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

## Series animation functions

In [None]:
def load_images(base_path, targets, runs):
    """Load every ``.png`` frame for each (run, target) pair as a NumPy array.

    Frames are read from ``<base_path>/<run>/<target>/`` in ``sorted`` file-name
    order. That order is lexicographic, so ``frame_10.png`` sorts before
    ``frame_2.png``. NOTE(review): presumably the frame names are zero-padded;
    confirm.

    Parameters
    ----------
    base_path : str
        Root directory containing the run subdirectories.
    targets : iterable of str
        Names of the series subdirectories inside each run.
    runs : iterable of str
        Names of the run subdirectories inside ``base_path``.

    Returns
    -------
    list
        ``data[r][t]`` is the list of frame arrays for the r-th run and the
        t-th target. Files not ending in ``.png`` are ignored.

    Raises
    ------
    FileNotFoundError
        If a ``<base_path>/<run>/<target>`` directory does not exist.
    """
    data = []
    for run in runs:
        run_data = []
        for target in targets:
            target_path = os.path.join(base_path, run, target)
            images = []
            for img_name in sorted(os.listdir(target_path)):
                if not img_name.endswith('.png'):
                    continue
                # Release the file handle deterministically once the pixels are
                # copied. Pillow only auto-closes single-frame files after load.
                with Image.open(os.path.join(target_path, img_name)) as img:
                    images.append(np.array(img))
            run_data.append(images)
        data.append(run_data)
    return data

In [None]:
def series_anim_main(base_path):
    targets = ["moving_series", "warped_series", "flow_series", "diff_series"]
    runs = [d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))]
    data = load_images(base_path, targets, runs)

    
    # Configuration for inline display
    plt.rcParams["animation.html"] = "jshtml"
    plt.rcParams["animation.embed_limit"] = 2048
    plt.rcParams['figure.dpi'] = 150
    %matplotlib inline
    
    # Define the number of columns and rows for the subplots
    num_cols = len(targets)
    num_rows = len(runs)
    fig, axs = plt.subplots(ncols=num_cols, nrows=num_rows, figsize=(2 * num_cols, 2 * num_rows))
    axs = axs.flatten()
    
    images = []
    for i in range(num_rows):
        row_offset = i * num_cols
        
        y_pos = 1 - ((i + 1) / float(num_rows + 1))  # Adjust the vertical position
        fig.text(0.01, y_pos, f"{runs[i]}", ha='right', va='center', fontsize=10, transform=fig.transFigure)
        
        if i == 0:       
            axs[row_offset + 0].set(title=r"$\mathit{m}$")
            axs[row_offset + 1].set(title=r"$\mathit{m \circ \phi}$")
            axs[row_offset + 2].set(title=r"$\mathit{\phi}$")
            axs[row_offset + 3].set(title=r"$\mathit{\left| \; (m \circ \phi) - f \; \right|}$")
    
        ms, ws, fs, ds = data[i]
    
        for k in range(num_cols):
            axs[row_offset + k].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], aspect="equal")
    
        images.append(axs[row_offset + 0].imshow(ms[0], animated=True))
        images.append(axs[row_offset + 1].imshow(ws[0], animated=True))
        images.append(axs[row_offset + 2].imshow(fs[0], animated=True))
        images.append(axs[row_offset + 3].imshow(ds[0], animated=True))
        
    def animate(delta):
        for local_i in range(len(runs)):
            local_row_offset = local_i * num_cols
            local_ms, local_ws, local_fs, local_ds = data[local_i]
        
            images[local_row_offset + 0].set_data(local_ms[delta])
            images[local_row_offset + 1].set_data(local_ws[delta])
            images[local_row_offset + 2].set_data(local_fs[delta])
            images[local_row_offset + 3].set_data(local_ds[delta])
    
        return images
    
    ani = animation.FuncAnimation(fig, animate, frames=len(data[0][0]), blit=True)
    return ani

## Main

In [None]:
# Build the grid animation covering every run found under EVAL_PATH.
ani = series_anim_main(base_path=EVAL_PATH)

In [None]:
%%script echo "skip"
if not INTERACTIVE:
    # Export the animation as a standalone HTML page using matplotlib's HTML writer.
    os.makedirs(HTML_PATH, exist_ok=True)
    html_file = f"{HTML_PATH}/test_{GROUP_NAME}.html"
    ani.save(html_file, writer='html')
    # Centre every <img> on the page and shrink it to half the page width.
    # NOTE(review): assumes the writer shows the frames through <img> tags; confirm.
    custom_css = (
        "\n<style>\n"
        "img {\n"
        "   margin: 0 auto;\n"
        "   display: block;\n"
        "   width: 50%;\n"
        "}\n"
        "</style>\n"
    )
    # Append only the stylesheet. 'rw' is not a valid open() mode, and
    # read-then-write on the same handle would duplicate the whole document.
    with open(html_file, 'a', encoding='utf-8') as file:
        file.write(custom_css)
ani if INTERACTIVE else None