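# Generating Animations

Load the trained 1D parameterized neural network (PNN) and evaluate it on $x \in [0, 1]$. Sweep the signal hypothesis from 0 to 1 against a fixed background of 0.4, render one frame per hypothesis, and combine the frames into a GIF.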

### Import Modules

In [2]:
# extern modules
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import imageio.v3 as iio

# intern modules
from PNN_model_1d import ParameterizedNeuralNet

## 1) Load the Model

In [3]:
# Path to the trained PNN weights (a state_dict, loaded into the model below).
PATH: str = "../trained_models/PNN_1d.pth"

# Rebuild the architecture, then load the trained weights into it.
# map_location="cpu" lets weights saved from a GPU session load on a CPU-only machine.
model = ParameterizedNeuralNet()
PNN_state_dict = torch.load(PATH, map_location="cpu")
load_result = model.load_state_dict(PNN_state_dict)

# Inference only from here on: switch off training-time behaviour
# (dropout / batch-norm updates, if the architecture has any).
model.eval()

# Display the key-matching report returned by load_state_dict.
load_result

<All keys matched successfully>

## 2) Animation

### Generate Multiple Frames and Save in Dedicated Folder

In [4]:
# Background hypothesis value, held fixed across all frames.
bg = 0.400

# Evaluation grid over x (the trained range).
x_range = np.linspace(0.00, 1.00, 1000)

# Signal hypotheses to sweep through; one frame is rendered per value.
signals = np.linspace(0, 1, 100)

# Rendering settings.
FRAME_DIR = "../frames"
NOISE_SCALE = 0.03  # std-dev of the illustrative Gaussian histograms
N_SAMPLES = 1000    # samples drawn per illustrative histogram
N_BINS = 10         # bins per illustrative histogram
SEED = 0            # fixed seed so the frames are identical on every re-run

rng = np.random.default_rng(SEED)

# Create the output folder up front so savefig does not fail on a fresh checkout.
os.makedirs(FRAME_DIR, exist_ok=True)


def _normalized_histogram(center):
    """Histogram Gaussian samples around `center`; return (bin_centers, counts scaled to a peak of 1)."""
    samples = rng.normal(loc=center, scale=NOISE_SCALE, size=N_SAMPLES)
    counts, edges = np.histogram(samples, bins=N_BINS)
    bin_centers = (edges[1:] + edges[:-1]) / 2
    return bin_centers, counts / counts.max()


# The background is the same in every frame, so draw its histogram once;
# redrawing it per frame would only make it flicker in the GIF.
bins_center_bg, counts_bg = _normalized_histogram(bg)

for signal_idx, signal in enumerate(signals):

    # PNN input: one row per x value, columns are (x, background, signal hypothesis).
    data = np.column_stack([
        x_range,
        np.full_like(x_range, bg),
        np.full_like(x_range, signal),
    ])
    data_tensor = torch.tensor(data, dtype=torch.float32)

    # Evaluate the model; no_grad skips building an autograd graph we never use.
    with torch.no_grad():
        res = model(data_tensor).numpy()

    # Illustrative histogram around the current signal hypothesis.
    bins_center_signal, counts_signals = _normalized_histogram(signal)

    fig, ax = plt.subplots()
    ax.fill_between(bins_center_signal, counts_signals, alpha=0.5)
    ax.fill_between(bins_center_bg, counts_bg, alpha=0.5)

    # PNN output for this hypothesis (no legend: it looks better in the GIF without one).
    ax.plot(x_range, res, label=f"H: Signal({np.round(signal,2)}) against background {bg}")
    ax.set_xlabel("x")
    ax.set_ylabel("Probability to discard background hypothesis.")

    # This file-name pattern must match the one used when stacking the GIF.
    fig.savefig(os.path.join(FRAME_DIR, f"1d_signal_idx_{signal_idx}.png"))

    # Close this specific figure so the figures do not accumulate in memory.
    plt.close(fig)

### Stack the Frames into a GIF

In [5]:
# Create the output folder so imwrite does not fail on a fresh checkout.
os.makedirs("../animations", exist_ok=True)

# Read the frames back in rendering order (same file names the frame loop wrote)
# and stack them along a new leading frame axis.
frames = np.stack(
    [iio.imread(f"../frames/1d_signal_idx_{signal_idx}.png") for signal_idx in range(len(signals))],
    axis=0,
)

# loop=0 makes the GIF repeat indefinitely. Without it, Pillow writes a GIF that plays only once.
iio.imwrite("../animations/1d_signal.gif", frames, loop=0)

### Display GIF

![PNN output as the signal hypothesis sweeps from 0 to 1 against a fixed background](../animations/1d_signal.gif)