In [None]:
import sys
sys.path.append('../')

In [None]:
import os

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from pathlib2 import Path

from dataset import reader

plt.rcParams["animation.html"] = "jshtml"
plt.rcParams["animation.embed_limit"] = 2048
plt.rcParams['figure.dpi'] = 150
%matplotlib inline

## Load data

In [None]:
# Folder containing the original series and its two registrations. Set the
# SERIES_ROOT environment variable to point at another series/machine instead of
# editing this machine-specific default.
root = Path(os.environ.get(
    "SERIES_ROOT",
    "/media/agjvc_rad3/_TESTKOLLEKTIV/Daten/Daten/Feldstärke_P1_1_5T/Series1/",
))
original_path = root / "dicoms.mat"
ant_path = root / "images_reg_av/dicoms.mat"
tm_path = root / "images_reg_av_transmorph/dicoms.mat"


def load_series(path):
    """Load an image series and put the frame axis first.

    Takes the first element returned by ``reader(path)`` and moves its last axis
    to the front with ``transpose(2, 0, 1)``. The rest of the notebook indexes
    the result along axis 0, one 2-D image per index. This assumes ``reader``
    yields an ``(H, W, T)`` array, giving ``(T, H, W)`` — TODO confirm against
    ``dataset.reader``.
    """
    return reader(path)[0].transpose(2, 0, 1)


original_data = load_series(original_path)
ant_data = load_series(ant_path)
tm_data = load_series(tm_path)

## Visualisation: original vs. registered series (A.N.T., TransMorph)

In [None]:
# Side-by-side panels as (title, image stack); every stack is indexed along
# axis 0, one 2-D frame per index.
panels = [
    ("Original", original_data),
    ("A.N.T.", ant_data),
    ("TransMorph", tm_data),
]
num_cols = len(panels)
num_rows = 1
fig, axs = plt.subplots(ncols=num_cols, nrows=num_rows, figsize=(2 * num_cols, 2 * num_rows))
axs = axs.flatten()
images = []

for ax, (title, stack) in zip(axs, panels):
    ax.set(title=title)
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], aspect="equal")
    # Grey-scale limits come from frame 0; set_data() below does not rescale them,
    # so all frames of a panel share the same intensity mapping.
    images.append(ax.imshow(stack[0], cmap="gray", animated=True))

# Every stack is indexed with the same frame number, so never request more
# frames than the shortest stack holds (otherwise animate() raises IndexError).
max_frames = 200
num_frames = min([max_frames] + [len(stack) for _, stack in panels])


def animate(frame_idx):
    """Show frame ``frame_idx`` of every stack and return the artists to redraw (blitting)."""
    for image, (_, stack) in zip(images, panels):
        image.set_data(stack[frame_idx])
    return images


anim = animation.FuncAnimation(fig, animate, frames=num_frames, blit=True)
# Close the figure so the inline backend does not also render a static copy;
# the animation itself is displayed (as jshtml) by the final expression.
plt.close(fig)
anim

## Analysis: per-frame intensity statistics (mean and variance) of the original series

In [None]:
def calculate_stats(image):
    """Compute the mean and variance of every frame in an image stack.

    Parameters
    ----------
    image : array_like
        Stack of images indexed along axis 0, e.g. ``(frames, height, width)``.
        Must have at least 2 dimensions; NumPy array subclasses are preserved.

    Returns
    -------
    image_means, image_vars : numpy.ndarray
        Arrays of length ``image.shape[0]`` holding, for each frame, the mean and
        the population variance (``ddof=0``) over all remaining axes.

    Raises
    ------
    ValueError
        If ``image`` has fewer than 2 dimensions.
    """
    stack = np.asanyarray(image)
    if stack.ndim < 2:
        raise ValueError(
            f"expected an array with at least 2 dimensions, got shape {stack.shape}"
        )
    # Reduce over every axis except the leading (frame) axis: identical to
    # axis=(1, 2) for 3-D stacks, and also handles extra trailing axes.
    per_frame_axes = tuple(range(1, stack.ndim))
    image_means = stack.mean(axis=per_frame_axes)
    image_vars = stack.var(axis=per_frame_axes)
    return image_means, image_vars

In [None]:
# One mean and one variance value per frame of the unregistered series.
image_means, image_vars = calculate_stats(image=original_data)

In [None]:
figsize = (16, 5)

# `tight_layout=True` replaces `fig.set_tight_layout(True)`, which is deprecated
# since Matplotlib 3.6; the keyword works on old and new versions alike.
fig, ax = plt.subplots(1, 2, figsize=figsize, tight_layout=True)
ax = ax.flatten()

panel_specs = [
    ("Mean", image_means, "Mean pixel intensity"),
    ("Variance", image_vars, "Pixel intensity variance"),
]
for axis, (title, series, ylabel) in zip(ax, panel_specs):
    axis.plot(series, "-", lw=1)
    axis.set(title=title, xlabel="Frame", ylabel=ylabel)

plt.show()
plt.close(fig)