# Introduction

The purpose of this notebook is to demonstrate how to structure training data and how to implement the interaction between the two networks in a Hyper Network model using PyTorch. A Hyper Network is composed of two neural networks: the HyperNet and the TargetNet (sometimes referred to as the MainNet). The weights and biases of the TargetNet are the output of the HyperNet, simply rearranged to fit the dimensions of the TargetNet weight and bias tensors.

Consider a computer vision application for a camera whose optical zoom distorts the image enough that object detection performance in the corners of the image is drastically reduced. One solution is to use a different CNN model for each zoom level, if local processing and storage resources allow it. The drawback is that for zoom levels between those the models were trained on, the user has to choose which model gives the more accurate detections. This is not an ideal solution when the distortion is this significant.

An alternative is a HyperNetwork-CNN structure. The training data is generated in the same way, by labeling images at specified zoom levels, but the training algorithm changes. The TargetNet is a CNN and the HyperNet is a DNN. The zoom level of a given image is the HyperNet input, the HyperNet outputs all parameters for the CNN, and backpropagation tunes the HyperNet so that it produces a more accurate CNN.

In other words, we are not training the TargetNet (CNN) to detect objects in images; we are training the HyperNet (DNN) to produce a CNN that accurately detects objects at the current zoom level, whatever its value. Instead of having, for example, 11 models for the zoom levels {0%, 10%, ..., 100%} and guessing which model to use for levels in between, we can feed an arbitrary zoom level to the HyperNet and produce an accurate CNN at run time.

# Input Data Generation

Before we do anything with Hyper Networks, we first need to create our data arrays. For this demonstration we will create a Long Short Term Memory (LSTM) network that predicts sinusoidal wave forms and use this as our TargetNet. Regular LSTM networks will have poor performance when alterations are made to frequency, and to a lesser extent amplitude, so we will select these two values as the inputs to our HyperNet.

The test and validation data sets consist of true sine waves with no added noise. There are two test data sets and two validation data sets. In each pair, one set uses HyperNet inputs that are found in the training data set and the other uses HyperNet inputs that are not. In this way, we see how well the trained Hyper Network predicts a true sine wave for a parameter set when it has only been given noisy sine waves, and also how well it predicts a true sine wave for parameters it has never seen.

In [None]:
import numpy as np
import plotly.graph_objects as go

## Prep

As mentioned above, the inputs to the HyperNet will be Amplitude and Frequency. The following variables define the range and number of samples for the Amplitude and Frequency arrays.

In [None]:
NUM_AMP_PTS = 15
MIN_AMP = 0.1
MAX_AMP = 0.9

NUM_FREQ_PTS = 15
MIN_FREQ = 2
MAX_FREQ = 30

The following variables are used to create the sine waves for all data sets.

* `SEQ_LEN` and `PRED_LEN` are used to create input and prediction sequences for the LSTM TargetNet.
* `TRAIN_STEP` and `TEST_STEP` define how much to step forward between sequenced data.
    * In some demonstrations an `overlap` is defined instead.
* `ARRAY_LEN` defines the number of points in the sine waves.
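* `ARRAYS_PER_PARAM_SET` sets how many noisy training arrays are created for each Amplitude/Frequency pair, and `NOISE_FACTOR` scales the added noise relative to the Amplitude.
* `NUM_TEST_CYCLES` is also defined here, although none of the cells below use it.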

In [None]:
SEQ_LEN = 200
PRED_LEN = 1
TRAIN_STEP = 5
TEST_STEP = 1
ARRAY_LEN = 2000
NUM_TEST_CYCLES = 5
ARRAYS_PER_PARAM_SET = 5
NOISE_FACTOR = 0.1
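
As a quick sanity check on these constants, the number of input/prediction pairs that a single array yields can be worked out ahead of time. The sketch below mirrors the loop bound used by the sequencing function later in this notebook; `num_sequences` is a hypothetical helper introduced only for illustration.

```python
# Constants copied from the cells above
SEQ_LEN, PRED_LEN = 200, 1
TRAIN_STEP, TEST_STEP = 5, 1
ARRAY_LEN = 2000
NUM_AMP_PTS, NUM_FREQ_PTS = 15, 15
ARRAYS_PER_PARAM_SET = 5


def num_sequences(array_len: int, seq_len: int, pred_len: int, step: int) -> int:
    """Number of (sequence, prediction) windows that fit in one array."""
    return (array_len - seq_len - pred_len) // step + 1


train_seqs = num_sequences(ARRAY_LEN, SEQ_LEN, PRED_LEN, TRAIN_STEP)
test_seqs = num_sequences(ARRAY_LEN, SEQ_LEN, PRED_LEN, TEST_STEP)
n_train_arrays = NUM_AMP_PTS * NUM_FREQ_PTS * ARRAYS_PER_PARAM_SET

print(train_seqs)      # 360
print(test_seqs)       # 1800
print(n_train_arrays)  # 1125
```

A smaller step gives more, heavily overlapping sequences per array, which is why the validation arrays sequenced with `TEST_STEP` are five times larger than the training arrays.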

These define the number of arrays in:
* The ***N***umber of ***T***est data sets with HyperNet ***T***raining ***P***arameters
* The ***N***umber of ***T***est data sets with ***N***ew HyperNet ***P***arameters
* The ***N***umber of ***V***alidation data sets with HyperNet ***T***raining ***P***arameters
* The ***N***umber of ***V***alidation data sets with ***N***ew HyperNet ***P***arameters

In [None]:
NTTP = 5
NTNP = 5
NVTP = 5
NVNP = 5

## Sequencing

This function will take a sine wave and convert it into LSTM sequences for later batching.

In [None]:
def sequence_array(sine: np.ndarray, step: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Convert sinusoidal arrays to LSTM sequences and predictions

    Parameters
    ----------
    sine : np.ndarray
        array of sinusoidal y-values
    step : int
        forward step for data

    Returns
    -------
    np.ndarray
        stack of input sequences
    np.ndarray
        stack of prediction arrays
    """
    seqs, preds, start = None, None, 0
    for _ in range(int((ARRAY_LEN - SEQ_LEN - PRED_LEN) / step) + 1):
        seq = sine[start : start + SEQ_LEN].reshape(1, SEQ_LEN, 1)
        pred = sine[start + SEQ_LEN : start + SEQ_LEN + PRED_LEN].reshape(1,)  # fmt:skip
        seqs = seq if seqs is None else np.vstack([seqs, seq])
        preds = pred if preds is None else np.vstack([preds, pred])
        start += step
    return seqs, preds
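
Stacking inside a Python loop copies the growing array on every iteration. As an alternative, the same windows can be taken in one step with NumPy's `sliding_window_view`. This is only a sketch under the assumption that the whole array fits comfortably in memory; `sequence_array_vectorized` is a hypothetical name, not part of the notebook.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

SEQ_LEN, PRED_LEN = 200, 1


def sequence_array_vectorized(sine: np.ndarray, step: int) -> tuple[np.ndarray, np.ndarray]:
    """Vectorized sketch: each window holds SEQ_LEN inputs followed by PRED_LEN targets."""
    # one row per start index, then keep every `step`-th row
    windows = sliding_window_view(sine, SEQ_LEN + PRED_LEN)[::step]
    seqs = windows[:, :SEQ_LEN, np.newaxis]  # (n_windows, SEQ_LEN, 1)
    preds = windows[:, SEQ_LEN:]             # (n_windows, PRED_LEN)
    # the views are read-only, so return writable copies
    return seqs.copy(), preds.copy()


sine = np.sin(np.linspace(0, 2 * np.pi, 2000))
seqs, preds = sequence_array_vectorized(sine, step=5)
print(seqs.shape, preds.shape)  # (360, 200, 1) (360, 1)
```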

## Training Arrays

Here we create the noisy sine waves used for the training data set. We loop over the Amplitude and Frequency sampling arrays and record each pair separately from the sine wave data. For each pair we create and sequence several noisy sine waves and store the resulting arrays in a dictionary.

Once all noisy sine waves are created and sequenced, the list of dictionaries is kept in `TRAIN_DATA` for the training steps later in the notebook.

In [None]:
A = np.linspace(MIN_AMP, MAX_AMP, NUM_AMP_PTS)
F = np.linspace(MIN_FREQ, MAX_FREQ, NUM_FREQ_PTS)
t = np.linspace(0, 2 * np.pi, ARRAY_LEN)

TRAIN_DATA, count = [], 0
for amp in A:
    for f in F:
        hyper_params = np.array([amp, f])
        for i in range(ARRAYS_PER_PARAM_SET):
            sine = amp * np.sin(f * t) + (amp * NOISE_FACTOR * (np.random.rand(*t.shape) - 0.5))
            seqs, preds = sequence_array(sine, TRAIN_STEP)
            TRAIN_DATA.append({"hx": hyper_params, "tx": seqs, "tyhat": preds})
            count += 1
            print(f"Created and sequenced {count} arrays", end="\r")

print(f"Created and sequenced {count} arrays")
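
Because `t` spans 0 to 2π, the Frequency value is also the number of full cycles in each array, so the fixed-length input sequence covers very different fractions of a cycle across the Frequency range. The sketch below makes this concrete; `samples_per_cycle` is a hypothetical helper used only for illustration.

```python
ARRAY_LEN, SEQ_LEN = 2000, 200
MIN_FREQ, MAX_FREQ = 2, 30


def samples_per_cycle(freq: float, array_len: int = ARRAY_LEN) -> float:
    """Samples in one period of sin(freq * t) when t = linspace(0, 2*pi, array_len)."""
    # linspace includes both endpoints, so the sample spacing is 2*pi / (array_len - 1)
    return (array_len - 1) / freq


for f in (MIN_FREQ, MAX_FREQ):
    spc = samples_per_cycle(f)
    print(f"f={f}: {spc:.1f} samples per cycle, {SEQ_LEN / spc:.2f} cycles per input sequence")
```

At the lowest Frequency an input sequence sees about a fifth of a cycle, while at the highest it sees about three full cycles.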

## Test and Validation (Training HyperNet Parameters)

Now we create the test and validation data sets with HyperNet input values that were found in the training data set created above.

We randomly select different values of Amplitude and Frequency and create a true sine wave representation for those values.

For a true representation of testing and validating our model, there can be no values in one data set that are also found in the other.

Once all arrays are created and sequenced, they are stored in the lists `TTP_DATA` and `VTP_DATA`.

In [None]:
A_sels, F_sels, TTP_DATA, VTP_DATA = [], [], [], []

for _ in range(NTTP):

    # select random A and F from HyperNet training parameters
    # the upper bound of randint is exclusive, so len(A) keeps the last index selectable
    sel_A = np.random.randint(0, len(A), (100,))
    sel_F = np.random.randint(0, len(F), (100,))
    sel_A = [x for x in sel_A if x not in A_sels][0]
    sel_F = [x for x in sel_F if x not in F_sels][0]
    A_sels.append(sel_A)
    F_sels.append(sel_F)

    # create tensors
    hyper_params = np.array([A[sel_A], F[sel_F]])
    sine = A[sel_A] * np.sin(F[sel_F] * t)

    # sequence sine wave array
    seqs, preds = sequence_array(sine, TRAIN_STEP)

    # append data set
    TTP_DATA.append({"hx": hyper_params, "tx": seqs, "tyhat": preds})

for _ in range(NVTP):

    # select random A and F from training HyperNet Parameters
    # the upper bound of randint is exclusive, so len(A) keeps the last index selectable
    sel_A = np.random.randint(0, len(A), (100,))
    sel_F = np.random.randint(0, len(F), (100,))
    sel_A = [x for x in sel_A if x not in A_sels][0]
    sel_F = [x for x in sel_F if x not in F_sels][0]
    A_sels.append(sel_A)
    F_sels.append(sel_F)

    # create tensors
    hyper_params = np.array([A[sel_A], F[sel_F]])
    sine = A[sel_A] * np.sin(F[sel_F] * t)

    # sequence sine wave array
    seqs, preds = sequence_array(sine, TEST_STEP)

    # append data set
    VTP_DATA.append({"hx": hyper_params, "tx": seqs, "tyhat": preds})
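
The selection loops above draw 100 candidate indices and keep the first one that has not been used. An alternative sketch, assuming NumPy's `Generator` API rather than the `np.random.randint` call used in this notebook, draws all test and validation indices at once without replacement, which guarantees that the two sets do not share values.

```python
import numpy as np

NTTP, NVTP = 5, 5
NUM_AMP_PTS, NUM_FREQ_PTS = 15, 15

rng = np.random.default_rng(seed=0)  # seed chosen only to make the sketch repeatable

# draw every test + validation index in one call, with no repeats
amp_idx = rng.choice(NUM_AMP_PTS, size=NTTP + NVTP, replace=False)
freq_idx = rng.choice(NUM_FREQ_PTS, size=NTTP + NVTP, replace=False)

# first NTTP indices go to the test sets, the rest to the validation sets
test_amp_idx, val_amp_idx = amp_idx[:NTTP], amp_idx[NTTP:]
test_freq_idx, val_freq_idx = freq_idx[:NTTP], freq_idx[NTTP:]
```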

## Test and Validation (New HyperNet Parameters)

Now we create the test and validation data sets for values that were not included in the training data set. We use NumPy's `rand()` function to generate 1000 candidate values and select the first one that lies inside the training range, is not in the training data set, and has not already been selected.

Once all arrays are created and sequenced, they are stored in the lists `TNP_DATA` and `VNP_DATA`.

In [None]:
min_A, max_A = min(A), max(A)
min_F, max_F = min(F), max(F)

A_sels, F_sels, TNP_DATA, VNP_DATA = [], [], [], []

for _ in range(NTNP):

    # select random A and F not in HyperNet training parameters
    sel_A = np.random.rand(1000)
    sel_F = np.random.rand(1000) * max_F
    sel_A = [x for x in sel_A if x not in A and x not in A_sels and min_A < x < max_A][0]
    sel_F = [x for x in sel_F if x not in F and x not in F_sels and min_F < x < max_F][0]
    A_sels.append(sel_A)
    F_sels.append(sel_F)

    # create tensors
    hyper_params = np.array([sel_A, sel_F])
    sine = sel_A * np.sin(sel_F * t)

    # sequence sine wave array
    seqs, preds = sequence_array(sine, TRAIN_STEP)

    # append data set
    TNP_DATA.append({"hx": hyper_params, "tx": seqs, "tyhat": preds})


for _ in range(NVNP):

    # select random A and F not in HyperNet training parameters
    sel_A = np.random.rand(1000)
    sel_F = np.random.rand(1000) * max_F
    sel_A = [x for x in sel_A if x not in A and x not in A_sels and min_A < x < max_A][0]
    sel_F = [x for x in sel_F if x not in F and x not in F_sels and min_F < x < max_F][0]
    A_sels.append(sel_A)
    F_sels.append(sel_F)

    # create tensors
    hyper_params = np.array([sel_A, sel_F])
    sine = sel_A * np.sin(sel_F * t)

    # sequence sine wave array
    seqs, preds = sequence_array(sine, TEST_STEP)

    # append data set
    VNP_DATA.append({"hx": hyper_params, "tx": seqs, "tyhat": preds})

# Load Data to PyTorch

The following custom `Dataset` class is used to batch the train and test data sets with a `DataLoader`. `x` is the TargetNet input tensor, and `yhat` is the tensor of TargetNet target values.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

# fall back to the CPU when no CUDA device is available
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 25


class RNNDataset(Dataset):

    def __init__(self, x: torch.Tensor, yhat: torch.Tensor) -> None:
        super().__init__()
        self.x = x
        self.yhat = yhat

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        return self.x[idx], self.yhat[idx]

This function will parse the loaded datasets. It will convert all NumPy Arrays to PyTorch Tensors and ensure that the tensor is on the correct device and of a compatible data format. With this function defined we parse the training and test data sets and shuffle the training data set but not the test data sets.

In [None]:
def parse_input_dataset(dataset: list, sh: bool) -> list:
    """
    load RNN data in dataset with `RNNDataset` and `DataLoader`

    Parameters
    ----------
    dataset : list
        list of data arrays
    sh : bool
        shuffle Dataset if True, else False

    Returns
    -------
    list
        parsed dataset
    """
    for i, ds in enumerate(dataset):
        hx = torch.from_numpy(ds["hx"]).to(DEVICE).float()
        tx = torch.from_numpy(ds["tx"]).to(DEVICE).float()
        tyhat = torch.from_numpy(ds["tyhat"]).to(DEVICE).float()
        rnn_data = DataLoader(
            RNNDataset(tx, tyhat), batch_size=BATCH_SIZE, shuffle=sh, drop_last=True
        )
        dataset[i] = [hx, rnn_data]
    return dataset


TRAIN_DATA = parse_input_dataset(TRAIN_DATA, True)
TTP_DATA = parse_input_dataset(TTP_DATA, False)
TNP_DATA = parse_input_dataset(TNP_DATA, False)
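
The `drop_last=True` argument matters here because the LSTM hidden and cell states are later allocated for exactly `BATCH_SIZE` sequences. The following sketch shows the effect on one training-sized array. It assumes PyTorch's built-in `TensorDataset` as a stand-in for `RNNDataset` and uses zero-filled tensors in place of real sequences.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 25

# stand-in tensors shaped like one training array: 360 sequences of length 200
tx = torch.zeros(360, 200, 1)
tyhat = torch.zeros(360, 1)

loader = DataLoader(
    TensorDataset(tx, tyhat), batch_size=BATCH_SIZE, shuffle=False, drop_last=True
)

n_batches = len(loader)                         # full batches only
n_dropped = len(tx) - n_batches * BATCH_SIZE    # sequences in the discarded partial batch
first_x, first_y = next(iter(loader))

print(n_batches, n_dropped)  # 14 10
```

Every batch therefore has the same leading dimension, at the cost of ignoring the last few sequences of each array.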

# Modifying TargetNet Parameters

The TargetNet parameters have to be modified in different ways depending on whether the model is being trained or used for normal inference.

## During training

We cannot directly modify the TargetNet parameters during training. If we attempt this, PyTorch raises an exception stating that parameters were modified by an in-place operation. To run the TargetNet with the parameters from the HyperNet (`hy`), we can use `torch.func.functional_call(targetNet, tn_params, tx)` where `targetNet` is the TargetNet model, `tn_params` is a dictionary with key-value pairs of the parameter name and the new tensor, and `tx` is the TargetNet input tensor.

***Note:*** If you are unfamiliar, `math.prod()` is similar to `sum()`, but multiplies all values of an iterable together.

```python
# collect new parameters in a dict with name of parameter as key
idx, tn_params = 0, {}
for n, p in targetNet.named_parameters():
    tn_params[n] = hy[idx : idx + math.prod(p.shape)].reshape(*p.shape)
    idx += math.prod(p.shape)

# use PyTorch's functional_call
ty = torch.func.functional_call(targetNet, tn_params, tx)
```
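
To see that gradients really do flow back through `functional_call()` to the vector that supplied the parameters, here is a minimal self-contained sketch. It assumes a tiny `Linear` layer as the TargetNet and a random tensor `hy` standing in for the HyperNet output; neither is part of the model built later in this notebook.

```python
import math

import torch
from torch.func import functional_call
from torch.nn import Linear

torch.manual_seed(0)

# tiny stand-in TargetNet whose own parameters are frozen
target = Linear(3, 2)
for p in target.parameters():
    p.requires_grad = False

n_params = sum(math.prod(p.shape) for p in target.parameters())  # 3 * 2 + 2 = 8
hy = torch.randn(n_params, requires_grad=True)  # stand-in for the HyperNet output

# slice the flat vector into tensors shaped like each named parameter
idx, tn_params = 0, {}
for name, p in target.named_parameters():
    n = math.prod(p.shape)
    tn_params[name] = hy[idx : idx + n].reshape(p.shape)
    idx += n

tx = torch.randn(4, 3)
ty = functional_call(target, tn_params, (tx,))
ty.sum().backward()

print(ty.shape, hy.grad.shape)
```

The gradient lands on `hy` while the module's own parameters stay untouched, which is the behavior the training loop relies on.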

## During normal inference operation

During normal inference, `functional_call()` is not required: it still works (the `test_epoch()` function below uses it), but nothing needs to be back-propagated to the HyperNet. Now that we are not training, we can modify the parameter tensors in place without raising an exception. Then, we simply use the `forward()` method of the TargetNet.

```python
# modify parameters in place
idx = 0
for p in targetNet.parameters():
    p[:] = hy[idx : idx + math.prod(p.shape)].reshape(*p.shape)
    idx += math.prod(p.shape)

# use TargetNet's forward method
ty = targetNet(tx)
```
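
A minimal self-contained sketch of the in-place approach follows. It assumes a tiny `Linear` layer as the TargetNet and a fixed vector `hy` in place of a real HyperNet output, and it wraps the copy in `torch.no_grad()` so that the write is valid even when the parameters still require gradients.

```python
import math

import torch
from torch.nn import Linear

target = Linear(3, 2)
hy = torch.arange(8, dtype=torch.float32)  # stand-in for a HyperNet output vector

with torch.no_grad():  # no autograd graph is needed for inference
    idx = 0
    for p in target.parameters():
        n = math.prod(p.shape)
        p.copy_(hy[idx : idx + n].reshape(p.shape))
        idx += n

# weight is now [[0, 1, 2], [3, 4, 5]] and bias is [6, 7]
ty = target(torch.ones(1, 3))
print(ty)  # values: [[9., 19.]]
```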

## Imports

In [None]:
import math

import plotly

from torch import optim
from torch.func import functional_call
from torch.nn import Module, Sequential, Linear, LSTM, ReLU, Tanh, MSELoss

# Model Parameters

The following constants are used to define the TargetNet LSTM and the HyperNet DNN.

`HYPERNET_NODES` is a list of the number of nodes in each layer, except for the output layer to the TargetNet.

`HYPERNET_ACTS` is a list of activation functions for each hidden layer *and* the output layer to the TargetNet.

In [None]:
N_LSTM_LAYERS = 3
LSTM_IN_DIM = 1
LSTM_HIDDEN_DIM = 10
LSTM_OUT_DIM = 1

HYPERNET_NODES = [2, 8, 12, 12, 8, 2]
HYPERNET_ACTS = [ReLU(), ReLU(), ReLU(), ReLU(), ReLU(), Tanh()]

# Training Setup

## Parameters

The following constants are values used to configure the training settings.

In [None]:
EPOCHS = 15
LEARN_RATE = 1e-4
LOSS_FN = MSELoss()
MODEL_NAME = "sinusoidal_hypernet.pt"

## Functions

This function will train the HyperNetwork model for a single epoch for a given HyperNet input array `hx`. As described in the section on modifying TargetNet parameters, we create dictionaries of the named parameter tensors from the output of the HyperNet. Once these are created, we run the TargetNet on the batched noisy sine wave sequences using `functional_call()`.

The loss is calculated and summed across batches and returned to the main training loop below.

In [None]:
def train_epoch(
    hyperNet: Sequential,
    targetLSTM: LSTM,
    targetSeq: Sequential,
    hx: torch.Tensor,
    t_data: DataLoader,
    batch_size: int,
    loss_fn: Module,
    device: torch.device,
) -> torch.Tensor:
    """
    Train HyperNetwork model for a given epoch

    Parameters
    ----------
    hyperNet : Sequential
        HyperNetwork described as a PyTorch Sequential model
    targetLSTM : LSTM
        TargetNetwork LSTM portion
    targetSeq : Sequential
        TargetNetwork Sequential portion
    hx : torch.Tensor
        HyperNetwork input tensor
    t_data : DataLoader
        Batched training data
    batch_size : int
        Training data batch size
    loss_fn : Module
        PyTorch defined loss function
    device : torch.device
        PyTorch defined computation device

    Returns
    -------
    torch.Tensor
        calculated loss tensor
    """

    # hidden/cell state dimensions
    hc_state_dim = (targetLSTM.num_layers, batch_size, targetLSTM.hidden_size)

    # run HyperNet
    hy = hyperNet(hx)

    # collect TargetNet weights and biases in dicts
    idx, tn_params = 0, {}
    for n, p in targetLSTM.named_parameters():
        tn_params[n] = hy[idx : idx + math.prod(p.shape)].reshape(*p.shape)
        idx += math.prod(p.shape)
    tns_params = {}
    for n, p in targetSeq.named_parameters():
        tns_params[n] = hy[idx : idx + math.prod(p.shape)].reshape(*p.shape)
        idx += math.prod(p.shape)

    # run LSTM data in batches
    loss = 0
    for _, batch in enumerate(t_data):
        tx, tyhat = batch[0].to(device), batch[1].to(device)
        h0 = torch.zeros(*hc_state_dim).to(device)
        c0 = torch.zeros(*hc_state_dim).to(device)
        ty, _ = functional_call(targetLSTM, tn_params, (tx, (h0, c0)))
        ty = functional_call(targetSeq, tns_params, ty[:, -1, :])
        loss += loss_fn(ty, tyhat)

    return loss

Similarly to the `train_epoch()` function defined above, this function will perform the same operation but with the test data set passed. Note that `functional_call()` is still used to run the TargetNet.

In [None]:
def test_epoch(
    hyperNet: Sequential,
    targetLSTM: LSTM,
    targetSeq: Sequential,
    vhx: torch.Tensor,
    vt_data: DataLoader,
    batch_size: int,
    loss_fn: Module,
    device: torch.device,
) -> float:
    """
    Test HyperNetwork model for a given epoch

    Parameters
    ----------
    hyperNet : Sequential
        HyperNetwork described as a PyTorch Sequential model
    targetLSTM : LSTM
        TargetNetwork LSTM portion
    targetSeq : Sequential
        TargetNetwork Sequential portion
    vhx : torch.Tensor
        HyperNetwork validation input tensor
    vt_data : DataLoader
        Batched validation data
    batch_size : int
        Validation data batch size
    loss_fn : Module
        PyTorch defined loss function
    device : torch.device
        PyTorch defined computation device

    Returns
    -------
    float
        calculated loss for this epoch
    """

    hc_state_dim = (targetLSTM.num_layers, batch_size, targetLSTM.hidden_size)

    loss = 0
    with torch.no_grad():

        # run HyperNet
        vhy = hyperNet(vhx)

        # collect TargetNet weights and biases in dicts
        idx, vtn_params = 0, {}
        for n, p in targetLSTM.named_parameters():
            vtn_params[n] = vhy[idx : idx + math.prod(p.shape)].reshape(*p.shape)
            idx += math.prod(p.shape)
        vtns_params = {}
        for n, p in targetSeq.named_parameters():
            vtns_params[n] = vhy[idx : idx + math.prod(p.shape)].reshape(*p.shape)
            idx += math.prod(p.shape)

        # run LSTM data in batches
        for _, batch in enumerate(vt_data):
            vtx, vtyhat = batch[0].to(device), batch[1].to(device)
            vh0 = torch.zeros(*hc_state_dim).to(device)
            vc0 = torch.zeros(*hc_state_dim).to(device)
            vty, _ = functional_call(targetLSTM, vtn_params, (vtx, (vh0, vc0)))
            vty = functional_call(targetSeq, vtns_params, vty[:, -1, :])
            loss += loss_fn(vty, vtyhat)

    return loss.item()

# Create Model Architecture

## TargetNet

Now that our data is loaded and the train and test functions are defined, we have to create the Hyper Network models. We must create the TargetNet first so that we know the output dimension of the HyperNet. Here we create an LSTM with the values defined at the beginning of this notebook and a Linear layer and activation function also defined above. These two are tracked in separate variables to properly use in PyTorch but should be considered together as the TargetNet.

Once these two models are created, we calculate the total number of parameters in the TargetNet and set every parameter to not require gradients, so that they are not accidentally picked up by the backpropagation operations.

Finally we move both models to the inference device selected at the beginning of the notebook.

In [None]:
# create TargetNet modules
targetLSTM = LSTM(LSTM_IN_DIM, LSTM_HIDDEN_DIM, N_LSTM_LAYERS, batch_first=True, device=DEVICE)
targetSeq = Sequential(Linear(LSTM_HIDDEN_DIM, LSTM_OUT_DIM), Tanh()).to(DEVICE)

# get number of parameters in TargetNet
target_params = 0
for p in targetLSTM.parameters():
    target_params += math.prod(p.shape)
    p.requires_grad = False
for p in targetSeq.parameters():
    target_params += math.prod(p.shape)
    p.requires_grad = False
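
The value of `target_params` can be checked by hand. The sketch below counts the parameters of a unidirectional LSTM with biases and no projection layer, plus the final `Linear` layer, using the constants defined earlier; `lstm_param_count` is a hypothetical helper written for this check.

```python
N_LSTM_LAYERS, LSTM_IN_DIM, LSTM_HIDDEN_DIM, LSTM_OUT_DIM = 3, 1, 10, 1


def lstm_param_count(in_dim: int, hidden: int, layers: int) -> int:
    """Parameter count of a unidirectional LSTM with biases and no projection."""
    total = 0
    for layer in range(layers):
        layer_in = in_dim if layer == 0 else hidden
        total += 4 * hidden * layer_in  # weight_ih: one block per gate (4 gates)
        total += 4 * hidden * hidden    # weight_hh
        total += 2 * 4 * hidden         # bias_ih and bias_hh
    return total


lstm_params = lstm_param_count(LSTM_IN_DIM, LSTM_HIDDEN_DIM, N_LSTM_LAYERS)
linear_params = LSTM_HIDDEN_DIM * LSTM_OUT_DIM + LSTM_OUT_DIM  # weight + bias

print(lstm_params, linear_params, lstm_params + linear_params)  # 2280 11 2291
```

With these settings the HyperNet therefore has to emit 2291 values for every Amplitude/Frequency pair.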

## HyperNet

Now that we know the total number of parameters in the TargetNet, we can create the HyperNet DNN. First we append the number of parameters in the TargetNet to the HyperNet nodes list, then create a Sequential model of alternating Linear layers and Activation functions defined above.

Finally, we also move the model to the selected inference device.

In [None]:
HYPERNET_NODES.append(target_params)
layers = []
for i, act in enumerate(HYPERNET_ACTS):
    layers.append(Linear(HYPERNET_NODES[i], HYPERNET_NODES[i + 1]))
    layers.append(act)
hyperNet = Sequential(*layers).to(DEVICE)
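
To get a feel for the size of the resulting HyperNet, the layer shapes and trainable parameter count can be computed from the node list alone. This sketch assumes `target_params` is 2291, the value that follows from the LSTM and Linear constants in this notebook, and works on a copy of the node list so that running it twice does not append the output size twice.

```python
HYPERNET_NODES = [2, 8, 12, 12, 8, 2]
target_params = 2291  # assumed TargetNet size for the constants in this notebook

nodes = HYPERNET_NODES + [target_params]  # new list, the original is left unchanged
layer_shapes = list(zip(nodes[:-1], nodes[1:]))

# each Linear layer holds n_in * n_out weights plus n_out biases
hypernet_params = sum(n_in * n_out + n_out for n_in, n_out in layer_shapes)

print(layer_shapes)
print(hypernet_params)  # 7283
```

Almost all of these parameters sit in the final layer, which maps 2 nodes to the full TargetNet parameter vector. The closing `Tanh()` also bounds every generated TargetNet parameter to the interval (-1, 1).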

# Train Hyper Network

And finally we can start training the HyperNet! First we pass the HyperNet parameters to the optimizer. Note that the optimizer does ***not*** see the TargetNet parameters because we aren't attempting to create an LSTM that can generate sine waves, we are creating a DNN that can *tune* an LSTM to create sine waves of varying Amplitude and Frequency.

We also create a tuple of the dimensions that will be used for the LSTM Hidden and Cell state arrays.

For each epoch, we iterate over the training data sets and back propagate with the loss calculated in the `train_epoch()` function defined above. We then calculate test loss for the test data set with training HyperNet inputs and the test data set with unseen HyperNet inputs.

We save the HyperNet whenever an epoch's training loss is the lowest seen so far, and print each tracked loss value at the end of each epoch.

In [None]:
# create optimizer
optimizer = optim.Adam(hyperNet.parameters(), lr=LEARN_RATE)

# track loss over epochs
best_loss = torch.inf
loss, ttp_loss, tnp_loss = [], [], []

# dimension of LSTM hidden/cell states
hc_state_dim = (targetLSTM.num_layers, BATCH_SIZE, targetLSTM.hidden_size)

# iterate over epochs
for i in range(EPOCHS):

    # train HyperNet for this epoch
    loss_i = 0
    for hx, t_data in TRAIN_DATA:

        # run epoch
        optimizer.zero_grad()
        loss_b = train_epoch(
            hyperNet, targetLSTM, targetSeq, hx, t_data, BATCH_SIZE, LOSS_FN, DEVICE
        )

        # back propagate HyperNet
        loss_b.backward()
        optimizer.step()

        # track loss for HyperNet input set
        loss_i += loss_b.item()

    # test model with training parameters
    ttp_loss_i = 0
    for vhx, vt_data in TTP_DATA:
        ttp_loss_i += test_epoch(
            hyperNet, targetLSTM, targetSeq, vhx, vt_data, BATCH_SIZE, LOSS_FN, DEVICE
        )

    # test model with new parameters
    tnp_loss_i = 0
    for vhx, vt_data in TNP_DATA:
        tnp_loss_i += test_epoch(
            hyperNet, targetLSTM, targetSeq, vhx, vt_data, BATCH_SIZE, LOSS_FN, DEVICE
        )

    # track loss for epoch
    loss.append(loss_i)
    ttp_loss.append(ttp_loss_i)
    tnp_loss.append(tnp_loss_i)

    # save model if it has lowest loss
    if loss_i < best_loss:
        best_loss = loss_i
        hyperNet = hyperNet.cpu()
        torch.save(hyperNet.state_dict(), MODEL_NAME)
        hyperNet = hyperNet.to(DEVICE)

    # print training progress
    print(
        f"Epoch: {i+1:0>2} | Loss: {loss_i:.3e} | "
        f"Best Loss: {best_loss:.3e} | "
        f"TTP Loss: {ttp_loss_i:.3e} | "
        f"TNP Loss: {tnp_loss_i:.3e}"
    )

## Plot Loss History

Now that training has concluded, let's plot the loss history over each epoch.

In [None]:
loss_fig = go.Figure()
loss_fig.add_trace(
    go.Scatter(x=torch.arange(len(loss)) + 1, y=loss, name="Loss", mode="lines", showlegend=True)
)
loss_fig.add_trace(
    go.Scatter(
        x=torch.arange(len(ttp_loss)) + 1,
        y=ttp_loss,
        name="TTP Loss",
        mode="lines",
        showlegend=True,
    )
)
loss_fig.add_trace(
    go.Scatter(
        x=torch.arange(len(tnp_loss)) + 1,
        y=tnp_loss,
        name="TNP Loss",
        mode="lines",
        showlegend=True,
    )
)
loss_fig.update_layout(
    title="Sinusoidal HyperNet Loss",
    yaxis={"title": "Loss"},
    xaxis={"title": "Epoch"},
    hovermode="x unified",
)

loss_fig.show()

## Validation

Now we begin the validation process. First we load the model that was saved during the training loop.

In [None]:
hyperNet.load_state_dict(torch.load(MODEL_NAME, weights_only=True))

This function will act as the `forward()` function for our overall Hyper Network. The name was specifically chosen to match the name used by PyTorch for their models.

Note that unlike the training loop, we are not collecting dictionaries of tensors to use as a substitute for the TargetNet parameters. We are instead modifying the values of these tensors in place after we run the HyperNet.

We also use the `forward()` method of the TargetNet models, rather than using `functional_call()` like when we were training the HyperNet.

In [None]:
def forward(
    hyperNet: Sequential,
    targetLSMT: LSTM,
    targetSeq: Sequential,
    hx: torch.Tensor,
    tx: torch.Tensor,
) -> torch.Tensor:
    """
    Forward calculation for HyperNetwork and TargetNetwork in normal
    operation

    Parameters
    ----------
    hyperNet : Sequential
        HyperNetwork defined as a PyTorch Sequential
    targetLSMT : LSTM
        TargetNetwork LSTM portion
    targetSeq : Sequential
        TargetNetwork Sequential portion
    hx : torch.Tensor
        HyperNetwork input tensor
    tx : torch.Tensor
        TargetNetwork input tensor

    Returns
    -------
    torch.Tensor
        Output array from TargetNetwork
    """

    # the HyperNet output needs no autograd graph during inference
    with torch.no_grad():
        hy = hyperNet(hx)

    idx = 0
    for p in targetLSMT.parameters():
        p[:] = hy[idx : idx + math.prod(p.shape)].reshape(*p.shape)
        idx += math.prod(p.shape)
    for p in targetSeq.parameters():
        p[:] = hy[idx : idx + math.prod(p.shape)].reshape(*p.shape)
        idx += math.prod(p.shape)

    hc_dims = (targetLSMT.num_layers, tx.shape[0], targetLSMT.hidden_size)
    h0 = torch.zeros(*hc_dims).to(DEVICE)
    c0 = torch.zeros(*hc_dims).to(DEVICE)

    y, _ = targetLSMT(tx, (h0, c0))
    y = targetSeq(y[:, -1, :])

    return y.detach()

This function is similar to the `parse_input_dataset()` function defined above. We still ensure that each NumPy array is converted to a PyTorch tensor on the selected inference device and of a compatible data format. In addition, we call the `forward()` function defined in the above cell to process the results and include the Hyper Network output array in the data set.

In [None]:
def parse_validation_dataset(
    dataset: list, hyperNet: Sequential, targetLSMT: LSTM, targetSeq: Sequential
) -> list:
    """
    Run the Hyper Network on every array in a validation dataset

    Parameters
    ----------
    dataset : list
        list of data arrays
    hyperNet : Sequential
        HyperNetwork defined as a PyTorch Sequential
    targetLSMT : LSTM
        TargetNetwork LSTM portion
    targetSeq : Sequential
        TargetNetwork Sequential portion

    Returns
    -------
    list
        List of validation results after inference on input data
    """
    for i, ds in enumerate(dataset):
        hx = torch.from_numpy(ds["hx"]).to(DEVICE).float()
        tx = torch.from_numpy(ds["tx"]).to(DEVICE).float()
        tyhat = torch.from_numpy(ds["tyhat"]).to(DEVICE).float()
        ty = forward(hyperNet, targetLSMT, targetSeq, hx, tx)
        dataset[i] = [hx, tx, ty, tyhat]
    return dataset


VTP_DATA = parse_validation_dataset(VTP_DATA, hyperNet, targetLSTM, targetSeq)
VNP_DATA = parse_validation_dataset(VNP_DATA, hyperNet, targetLSTM, targetSeq)

Now let's plot our results to see how well the HyperNet can predict true sine waves for HyperNet inputs that were included in the training data set and HyperNet inputs that were *not* included in the training data set.

In [None]:
colors = plotly.colors.qualitative.Plotly
color_count = 0

First we plot the results with the validation data set of true sine waves for HyperNet inputs that were included in the training data set and see how well the model predicts the sine waves.

In [None]:
reshape = (-1,)

vtp_fig = go.Figure()
vtp_fig.update_layout(
    title="Training Parameter Validation", yaxis={"title": "Amplitude"}, xaxis={"title": "time"}
)
for i, (hx, tx, ty, tyhat) in enumerate(VTP_DATA):

    tyhat = tyhat.cpu().reshape(*reshape)
    ty = ty.cpu().reshape(*reshape)
    p_set = f"HN Set {i+1}"
    p_hover = "<br>    ".join([f"{x:.3e}" for x in hx.cpu().numpy()])

    vtp_fig.add_trace(
        go.Scatter(
            x=torch.arange(len(tx)),
            y=tyhat,
            line={"color": colors[color_count]},
            name="True",
            showlegend=True,
            legendgroup=p_set,
            legendgrouptitle_text=p_set,
            hovertemplate=f"HyperNet Inputs:<br>    {p_hover}<br>" "<br>t: %{x}" "<br>a: %{y}",
        )
    )

    vtp_fig.add_trace(
        go.Scatter(
            x=torch.arange(len(tx)),
            y=ty,
            line={"color": colors[color_count], "dash": "dash"},
            name="Prediction",
            showlegend=True,
            legendgroup=p_set,
            legendgrouptitle_text=p_set,
            hovertemplate=f"HyperNet Inputs:<br>    {p_hover}<br>" "<br>t: %{x}" "<br>a: %{y}",
        )
    )

    color_count += 1
    if color_count == len(colors):
        color_count = 0

vtp_fig.show()

Overall, our prediction is not too bad! There is certainly room for improvement.

Now let's see the results on the validation data set with HyperNet inputs that our model has never seen before.

In [None]:
vnp_fig = go.Figure()
vnp_fig.update_layout(
    title="New Parameter Validation",
    yaxis={"title": "Amplitude"},
    xaxis={"title": "time"},
)
for i, (hx, tx, ty, tyhat) in enumerate(VNP_DATA):

    tyhat = tyhat.cpu().reshape(*reshape)
    ty = ty.cpu().reshape(*reshape)
    p_set = f"HN Set {i+1}"
    p_hover = "<br>    ".join([f"{x:.3e}" for x in hx.cpu().numpy()])

    vnp_fig.add_trace(
        go.Scatter(
            x=torch.arange(len(tx)),
            y=tyhat,
            line={"color": colors[color_count]},
            name="True",
            showlegend=True,
            legendgroup=p_set,
            legendgrouptitle_text=p_set,
            hovertemplate=f"HyperNet Inputs:<br>    {p_hover}<br>" "<br>t: %{x}" "<br>a: %{y}",
        )
    )

    vnp_fig.add_trace(
        go.Scatter(
            x=torch.arange(len(tx)),
            y=ty,
            line={"color": colors[color_count], "dash": "dash"},
            name="Prediction",
            showlegend=True,
            legendgroup=p_set,
            legendgrouptitle_text=p_set,
            hovertemplate=f"HyperNet Inputs:<br>    {p_hover}<br>" "<br>t: %{x}" "<br>a: %{y}",
        )
    )

    color_count += 1
    if color_count == len(colors):
        color_count = 0

vnp_fig.show()