In [None]:
from pathlib import Path
import sys

import matplotlib.animation as animation
import torch
import numpy as np

from leap3d.config import DATA_DIR
from leap3d.scanning import ScanParameters
from leap3d.plotting import plot_top_layer_temperature, plot_top_view_scan_boundaries
from leap3d.dataset import prepare_scan_results
from matplotlib import pyplot as plt


In [None]:
# --- Output / animation settings --------------------------------------------
PLOT_DIR = Path("..") / "Plots"    # directory the rendered GIFs are written to
GIF_FPS = 60                        # frame rate used when saving the GIF
ANIMATION_FRAME_DURATION = 1        # delay between frames in ms (matplotlib `interval`)

In [None]:
# Which simulated scan to visualise. The result file name is derived from the id
# so the two cannot drift apart.
case_id = 6
scan_result_filepath = DATA_DIR / f"case_{case_id:04d}.npz"
# Assumes Params.npy sits directly in DATA_DIR, like the other input files
# (joining DATA_DIR twice breaks the path whenever DATA_DIR is relative).
params_file = DATA_DIR / "Params.npy"
rough_coordinates = DATA_DIR / "Rough_coord.npz"
scan_parameters = ScanParameters(params_file, rough_coordinates, case_id)

In [None]:
# Run inference on the GPU when one is available and fall back to CPU otherwise,
# so the notebook still executes on GPU-less machines.
# NOTE(review): the loaded models must live on the same device as x_input.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x_input, target = prepare_scan_results(scan_result_filepath, scan_parameters, dims=2)
x_input = torch.tensor(x_input, dtype=torch.float32).to(DEVICE)
# target stays on the CPU: it is only used for plotting and error maps.
target = torch.tensor(target, dtype=torch.float32)

In [None]:
from leap3d.models.lightning import LEAP3D_UNet2D


# Load each trained variant, keyed by display name. Insertion order matters:
# later cells pair the models positionally with `transforms`/`untransforms`.
models = {
    name: LEAP3D_UNet2D.load_from_checkpoint(checkpoint_path)
    for name, checkpoint_path in (
        ("UNet2D", "./model_checkpoints/unet2d-v2.ckpt"),
        ("UNet2D_norm_naive", "./model_checkpoints/unet2d_normalized_naive-v1.ckpt"),
        ("UNet2D_norm", "./model_checkpoints/unet2d_normalized.ckpt"),
    )
}

In [None]:
# Put every loaded network into inference mode (this affects layers such as
# dropout or batch-norm, if the architecture has them) before predicting.
for _name, network in models.items():
    network.eval()

In [None]:
from leap3d.config import MELTING_POINT

# Temperature mapped to 0 by the min-max normalisation (`*_temperature_1`).
# NOTE(review): presumably the ambient / initial temperature, in the same units
# as MELTING_POINT. Confirm against the training pipeline.
BASE_TEMPERATURE = 300


def _map_temperature(x, fn):
    """Return a copy of ``x`` with ``fn`` applied to its temperature values.

    A 2-D tensor is treated as a bare temperature map and transformed as a whole.
    For higher ranks only ``x[:, :, -1]`` (the temperature channel of an input
    stack, as used by the callers) is transformed. ``x`` itself is never modified.
    """
    new_x = x.clone()
    if new_x.dim() == 2:
        return fn(new_x)
    new_x[:, :, -1] = fn(new_x[:, :, -1])
    return new_x


def normalize_temperature_0(x):
    """Scale temperatures by the melting point: ``T -> T / MELTING_POINT``."""
    return _map_temperature(x, lambda t: t / MELTING_POINT)


def unnormalize_temperature_0(x):
    """Inverse of ``normalize_temperature_0``: ``T -> T * MELTING_POINT``."""
    return _map_temperature(x, lambda t: t * MELTING_POINT)


def normalize_temperature_1(x):
    """Min-max scale so that BASE_TEMPERATURE -> 0 and MELTING_POINT -> 1."""
    return _map_temperature(
        x, lambda t: (t - BASE_TEMPERATURE) / (MELTING_POINT - BASE_TEMPERATURE)
    )


def unnormalize_temperature_1(x):
    """Inverse of ``normalize_temperature_1``."""
    return _map_temperature(
        x, lambda t: t * (MELTING_POINT - BASE_TEMPERATURE) + BASE_TEMPERATURE
    )

In [None]:
x_min, x_max, y_min, y_max, *_ = scan_parameters.get_bounds()

# Only every FRAME_STRIDE-th timestep is rendered, to keep the GIF size and the
# number of forward passes manageable.
FRAME_STRIDE = 5

# Layout: [ground truth | model 1 prediction | model 1 |error| | model 2 ... ]
fig, ax = plt.subplots(sharex=True, sharey=True, ncols=len(models) * 2 + 1)
fig.set_figwidth(5 * (len(ax)) + 5)

frames = len(target)

for axis in ax:
    axis.set_xlim(x_min, x_max)
    axis.set_ylim(y_min, y_max)
    axis.set_aspect('equal', adjustable='box')

# Pre-/post-processing per model, matched to `models` by position (dict insertion
# order): UNet2D -> raw temperatures, UNet2D_norm_naive -> *_temperature_0,
# UNet2D_norm -> *_temperature_1.
transforms = [lambda x: x, normalize_temperature_0, normalize_temperature_1]
untransforms = [lambda x: x, unnormalize_temperature_0, unnormalize_temperature_1]
# zip() below would silently skip any model that lacks a transform pair.
assert len(models) == len(transforms) == len(untransforms), \
    "every model needs exactly one transform/untransform pair"

ax[0].set_title("Ground truth")
for index, model_name in enumerate(models):
    ax[index * 2 + 1].set_title(model_name)
    ax[index * 2 + 2].set_title(f"|error| {model_name}")

ims = []
# Pure inference: skip autograd bookkeeping for every forward pass.
with torch.no_grad():
    for i in range(0, frames, FRAME_STRIDE):
        target_frame = target[i]
        ims_at_timestep = [plot_top_layer_temperature(ax[0], target_frame, scan_parameters, False)]

        for index, ((model_name, model), transform, untransform) in enumerate(
            zip(models.items(), transforms, untransforms)
        ):
            # Squeeze *before* un-normalising: the untransforms rescale a 2-D map
            # as a whole but only touch [:, :, -1] of higher-rank tensors.
            prediction = model(transform(x_input[i])).squeeze()
            model_output = untransform(prediction).cpu().numpy()
            ims_at_timestep.append(
                plot_top_layer_temperature(ax[index * 2 + 1], model_output, scan_parameters, False)
            )
            error_map = np.abs(target_frame.cpu().numpy() - model_output)
            ims_at_timestep.append(
                plot_top_layer_temperature(ax[index * 2 + 2], error_map, scan_parameters, False)
            )

        ims.append(ims_at_timestep)

for axis in ax:
    plot_top_view_scan_boundaries(axis, scan_parameters)

# Colour bar taken from the first ground-truth frame.
fig.colorbar(ims[0][0], ax=ax[-1])

In [None]:
fig.get_axes()[0].set_title("Top layer view of temperature map over time")
ani = animation.ArtistAnimation(fig, ims, interval=ANIMATION_FRAME_DURATION, blit=True,
                                repeat_delay=1000)
# Saving into a missing directory fails, so create the output folder first.
PLOT_DIR.mkdir(parents=True, exist_ok=True)
ani.save(PLOT_DIR / "aaa2.gif", fps=GIF_FPS)

In [None]:
x_min, x_max, y_min, y_max, *_ = scan_parameters.get_bounds()

# Layout: [ground truth | model 1 rollout | model 2 rollout | ...]
fig, ax = plt.subplots(sharex=True, sharey=True, ncols=len(models) + 1)
fig.set_figwidth(5 * (len(ax)) + 10)

# Autoregressive rollout length, capped by the available ground truth so that
# target[i] below stays in range.
ROLLOUT_STEPS = 500
frames = min(ROLLOUT_STEPS, len(target))

for axis in ax:
    axis.set_xlim(x_min, x_max)
    axis.set_ylim(y_min, y_max)
    axis.set_aspect('equal', adjustable='box')

# zip() below would silently skip any model that lacks a transform pair.
assert len(models) == len(transforms) == len(untransforms), \
    "every model needs exactly one transform/untransform pair"

ax[0].set_title("Ground truth")
for index, model_name in enumerate(models):
    ax[index + 1].set_title(f"{model_name} rollout")

ims = []

# One rolling input per model, seeded with the first input in that model's own
# normalisation. clone() first: x_input[0] is a view into x_input, and the
# in-place updates below must not overwrite the shared input tensor.
new_x_input = [transform(x_input[0].clone()) for transform in transforms]

# Pure inference; also avoids chaining autograd graphs through every step.
with torch.no_grad():
    for i in range(frames):
        print(i, end='\r')
        ims_at_timestep = [plot_top_layer_temperature(ax[0], target[i], scan_parameters, False)]

        for index, ((model_name, model), untransform) in enumerate(zip(models.items(), untransforms)):
            next_temperature = model(new_x_input[index]).squeeze()
            # Feed the prediction back as the temperature channel of the next input.
            # NOTE(review): the other channels stay those of x_input[0] for every
            # step; if they encode time-varying data (e.g. laser position) this
            # freezes them. Confirm this is intended.
            new_x_input[index][:, :, -1] = next_temperature
            # Already squeezed to 2-D, so the untransform rescales the whole map.
            model_output_plotting = untransform(next_temperature).cpu().numpy()
            ims_at_timestep.append(
                plot_top_layer_temperature(ax[index + 1], model_output_plotting, scan_parameters, False)
            )

        ims.append(ims_at_timestep)

for axis in ax:
    plot_top_view_scan_boundaries(axis, scan_parameters)

# Colour bar taken from the first ground-truth frame.
fig.colorbar(ims[0][0], ax=ax[-1])

In [None]:
fig.get_axes()[0].set_title("Top layer view of temperature map over time")
ani = animation.ArtistAnimation(fig, ims, interval=ANIMATION_FRAME_DURATION * 5, blit=True,
                                repeat_delay=1000)
# Saving into a missing directory fails, so create the output folder first.
PLOT_DIR.mkdir(parents=True, exist_ok=True)
ani.save(PLOT_DIR / "bbb2.gif", fps=GIF_FPS)