### Modifying TargetNet Parameters

The TargetNet parameters have to be modified in different ways depending on whether the model is being trained or used for normal inference.

##### During training:

We cannot directly modify the TargetNet parameters during training. If we attempted this, PyTorch would raise an exception reporting that parameters were modified by an in-place operation. To run the TargetNet with the parameters generated by the HyperNet (its output `hy`), we instead use `torch.func.functional_call(targetNet, tn_params, tx)`. Here `targetNet` is the TargetNet model, `tn_params` is a dictionary mapping each parameter name to its new tensor, and `tx` is the TargetNet input tensor.

***Note:*** If you are unfamiliar with it, `math.prod()` is similar to `sum()`, but multiplies all values of an iterable together. Here, `math.prod(p.shape)` is the number of elements in the parameter tensor `p`.

```python
# collect new parameters in a dict with name of parameter as key
idx, tn_params = 0, {}
for n, p in targetNet.named_parameters():
    tn_params[n] = hy[idx : idx + math.prod(p.shape)].reshape(*p.shape)
    idx += math.prod(p.shape)

# use PyTorch's functional_call
ty = torch.func.functional_call(targetNet, tn_params, tx)
```

##### During normal inference operation:

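We do not need the above method to modify the TargetNet parameters during normal inference. `functional_call()` also works outside of training (the test loop below uses it inside `torch.no_grad()`), but it is not required. Now that we are not training, we can modify the parameter tensors in place without raising an exception, because the TargetNet parameters are frozen with `requires_grad = False`. Doing the update inside `torch.no_grad()` additionally keeps autograd from recording it. Then, we simply use the `forward()` method of the TargetNet.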

```python
# modify parameters in place
idx = 0
for p in targetNet.parameters():
    p[:] = hy[idx : idx + math.prod(p.shape)].reshape(*p.shape)
    idx += math.prod(p.shape)

# use TargetNet's forward method
ty = targetNet(tx)
```

In [134]:
import math
from collections import OrderedDict

import numpy as np
import plotly
import plotly.graph_objects as go

import torch
from torch import optim
from torch.func import functional_call
from torch.utils.data import Dataset, DataLoader
from torch.nn import Module, Sequential, Linear, LSTM, ReLU, Tanh, MSELoss

The following constants are used to define the TargetNet LSTM and the HyperNet DNN.

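`HYPERNET_NODES` lists the number of nodes in each HyperNet layer except the output layer. The output layer's size (the number of TargetNet parameters) is appended to the list once the TargetNet has been built. `HYPERNET_ACTS` lists the activation function applied after each hidden layer *and* after the output layer that produces the TargetNet parameters.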

In [135]:
# TargetNet LSTM: number of stacked layers and input/hidden/output widths
N_LSTM_LAYERS = 3
LSTM_IN_DIM, LSTM_HIDDEN_DIM, LSTM_OUT_DIM = 1, 10, 1

# HyperNet layer widths (input and hidden layers only; the output width is
# appended to this list later, once the TargetNet parameter count is known)
HYPERNET_NODES = [2, 8, 12, 12, 8, 2]
# one activation per HyperNet Linear layer: ReLU on the five hidden layers,
# Tanh on the output layer
HYPERNET_ACTS = [ReLU() for _ in range(5)] + [Tanh()]

The following constants are values used to configure the training settings.

In [136]:
# Use the first CUDA GPU when one is available; otherwise fall back to the
# CPU so the notebook still runs on machines without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EPOCHS = 25
BATCH_SIZE = 25  # samples per DataLoader batch; also sizes the LSTM hidden/cell state
LEARN_RATE = 1e-4  # Adam learning rate for the HyperNet
MODEL_NAME = "sinusoidal_hypernet.pt"  # checkpoint file for the best HyperNet weights
LOSS_FN = MSELoss()

Now we load the data arrays that were created with `gen_sin_data.ipynb`.

In [137]:
def _load_npz_array(path: str) -> np.ndarray:
    """
    Return the ``arr_0`` array stored in the ``.npz`` archive at ``path``.

    The archive is opened in a ``with`` block, so its file handle is closed as
    soon as the array has been read into memory.

    ``allow_pickle=True`` is kept because the stored records are later indexed
    by key (e.g. ``ds["hx"]``), so they are presumably pickled Python objects.
    Unpickling can execute arbitrary code: only load files you generated yourself.

    :param path: path to the ``.npz`` file
    :type path: str
    :return: the array saved under the default name ``arr_0``
    :rtype: np.ndarray
    """
    with np.load(path, allow_pickle=True) as archive:
        return archive["arr_0"]


TRAIN_DATA = _load_npz_array("sinusoidal-train.npz")
TTP_DATA = _load_npz_array("sinusoidal-test-tp.npz")
TNP_DATA = _load_npz_array("sinusoidal-test-np.npz")
VTP_DATA = _load_npz_array("sinusoidal-val-tp.npz")
VNP_DATA = _load_npz_array("sinusoidal-val-np.npz")

The following custom `Dataset` class is used to batch the train and test datasets with a `DataLoader`. `x` is the TargetNet input tensor, and `yhat` is the tensor of TargetNet target output values.

In [138]:
class RNNDataset(Dataset):
    """
    Map-style dataset pairing TargetNet inputs with their target outputs.

    Sample ``idx`` is the tuple ``(x[idx], yhat[idx])``. The dataset length is
    taken from ``x``; ``x`` and ``yhat`` are assumed to have the same length.
    """

    def __init__(self, x: torch.Tensor, yhat: torch.Tensor) -> None:
        """
        :param x: TargetNet input tensor, indexed along its first dimension
        :type x: torch.Tensor
        :param yhat: TargetNet target tensor, indexed along its first dimension
        :type yhat: torch.Tensor
        """
        super().__init__()
        self.__x = x
        self.__yhat = yhat

    def __len__(self) -> int:
        """Return the number of samples (the length of ``x``)."""
        return len(self.__x)

    def __getitem__(self, idx) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return the ``(input, target)`` pair at position ``idx``."""
        return self.__x[idx], self.__yhat[idx]

In [139]:
def parse_input_dataset(
    dataset: list,
    sh: bool,
    batch_size: "int | None" = None,
    device: "torch.device | None" = None,
) -> list:
    """
    Convert raw data records into HyperNet inputs and TargetNet DataLoaders.

    Each record ``ds`` in ``dataset`` is indexed with the keys ``"hx"``,
    ``"tx"`` and ``"tyhat"``. Each value is converted with ``torch.from_numpy``
    (so it must be a NumPy array), moved to ``device`` and cast to ``float32``.
    ``tx``/``tyhat`` are wrapped in an ``RNNDataset`` and batched with a
    ``DataLoader``.

    The input sequence is not modified; a new list is returned.

    :param dataset: iterable of data records (e.g. the array loaded from an ``.npz`` file)
    :type dataset: list
    :param sh: whether the DataLoader reshuffles the samples each time it is iterated
    :type sh: bool
    :param batch_size: samples per batch; defaults to ``BATCH_SIZE``
    :type batch_size: int or None
    :param device: device to place the tensors on; defaults to ``DEVICE``
    :type device: torch.device or None
    :return: one ``[hx, DataLoader]`` pair per record, in input order
    :rtype: list
    """
    # resolve defaults at call time, matching the original use of the globals
    if batch_size is None:
        batch_size = BATCH_SIZE
    if device is None:
        device = DEVICE

    parsed = []
    for ds in dataset:
        hx = torch.from_numpy(ds["hx"]).to(device).float()
        tx = torch.from_numpy(ds["tx"]).to(device).float()
        tyhat = torch.from_numpy(ds["tyhat"]).to(device).float()
        # drop_last=True: the training/testing helpers size the LSTM hidden and
        # cell states from a fixed batch size, so a short final batch would not fit
        loader = DataLoader(
            RNNDataset(tx, tyhat), batch_size=batch_size, shuffle=sh, drop_last=True
        )
        parsed.append([hx, loader])
    return parsed

In [140]:
# the training loaders shuffle their samples; the test loaders keep a fixed order
TRAIN_DATA = parse_input_dataset(TRAIN_DATA, sh=True)
TTP_DATA = parse_input_dataset(TTP_DATA, sh=False)
TNP_DATA = parse_input_dataset(TNP_DATA, sh=False)

In [141]:
def train_epoch(
    hyperNet_: Sequential,
    targetLSTM_: LSTM,
    targetSeq_: Sequential,
    hx: torch.Tensor,
    t_data: DataLoader,
    batch_size: int,
    loss_fn: Module,
    device: torch.device,
) -> torch.Tensor:
    """
    Run the HyperNet once and accumulate the TargetNet loss over ``t_data``.

    The HyperNet output ``hyperNet_(hx)`` is sliced, in ``named_parameters()``
    order, into the parameters of ``targetLSTM_`` followed by those of
    ``targetSeq_``. The TargetNet is then evaluated with ``functional_call``,
    so gradients flow back into the HyperNet rather than into the TargetNet's
    own parameters.

    :param hyperNet_: network mapping ``hx`` to the flattened TargetNet parameters
    :type hyperNet_: Sequential
    :param targetLSTM_: TargetNet LSTM; assumed ``batch_first=True`` because the
        last time step is taken with ``[:, -1, :]``
    :type targetLSTM_: LSTM
    :param targetSeq_: TargetNet head applied to the last LSTM time step
    :type targetSeq_: Sequential
    :param hx: HyperNet input; presumably 1-D so the output is a flat vector
    :type hx: torch.Tensor
    :param t_data: loader yielding ``(input, target)`` batches
    :type t_data: DataLoader
    :param batch_size: batch size used for the zero initial hidden/cell states
    :type batch_size: int
    :param loss_fn: loss module applied to each batch
    :type loss_fn: Module
    :param device: device each batch and initial state are placed on
    :type device: torch.device
    :return: scalar loss (sum of per-batch losses), still attached to the
        HyperNet graph so the caller can call ``backward()`` on it
    :rtype: torch.Tensor
    :raises ValueError: if ``t_data`` yields no batches
    """
    # hidden/cell state shape expected by nn.LSTM: (num_layers, batch, hidden)
    hc_state_dim = (targetLSTM_.num_layers, batch_size, targetLSTM_.hidden_size)

    # run HyperNet
    hy = hyperNet_(hx)

    # slice the flat HyperNet output into one tensor per TargetNet parameter
    idx, tn_params, tns_params = 0, {}, {}
    for module, params in ((targetLSTM_, tn_params), (targetSeq_, tns_params)):
        for name, p in module.named_parameters():
            n = p.numel()
            params[name] = hy[idx : idx + n].reshape(p.shape)
            idx += n

    # run LSTM data in batches
    loss = None
    for batch in t_data:
        tx, tyhat = batch[0].to(device), batch[1].to(device)
        h0 = torch.zeros(hc_state_dim, device=device)
        c0 = torch.zeros(hc_state_dim, device=device)
        ty, _ = functional_call(targetLSTM_, tn_params, (tx, (h0, c0)))
        ty = functional_call(targetSeq_, tns_params, ty[:, -1, :])
        batch_loss = loss_fn(ty, tyhat)
        loss = batch_loss if loss is None else loss + batch_loss

    # without a batch there is no graph to back-propagate; fail clearly here
    # instead of returning a plain int that breaks the caller's .backward()
    if loss is None:
        raise ValueError(
            "t_data yielded no batches (a DataLoader with drop_last=True yields "
            "none when the dataset is smaller than one batch)"
        )
    return loss

In [142]:
def test_epoch(
    hyperNet_: Sequential,
    targetLSTM_: LSTM,
    targetSeq_: Sequential,
    vhx: torch.tensor,
    vt_data: DataLoader,
    batch_size: int,
    loss_fn: Module,
    device: torch.device,
) -> float:

    hc_state_dim = (targetLSTM_.num_layers, batch_size, targetLSTM_.hidden_size)

    loss = 0
    with torch.no_grad():

        # run HyperNet
        vhy = hyperNet_(vhx)

        # collect TargetNet weights and biases in dicts
        idx, vtn_params = 0, {}
        for n, p in targetLSTM_.named_parameters():
            vtn_params[n] = vhy[idx : idx + math.prod(p.shape)].reshape(*p.shape)
            idx += math.prod(p.shape)
        vtns_params = {}
        for n, p in targetSeq_.named_parameters():
            vtns_params[n] = vhy[idx : idx + math.prod(p.shape)].reshape(*p.shape)
            idx += math.prod(p.shape)

        # run LSTM data in batches
        for _, batch in enumerate(vt_data):
            vtx, vtyhat = batch[0].to(device), batch[1].to(device)
            vh0 = torch.zeros(*hc_state_dim).to(device)
            vc0 = torch.zeros(*hc_state_dim).to(device)
            vty, _ = functional_call(targetLSTM_, vtn_params, (vtx, (vh0, vc0)))
            vty = functional_call(targetSeq_, vtns_params, vty[:, -1, :])
            loss += loss_fn(vty, vtyhat)

    return loss.item()

In [143]:
# create TargetNet modules: a stacked LSTM plus a Linear/Tanh output head
targetLSTM = LSTM(LSTM_IN_DIM, LSTM_HIDDEN_DIM, N_LSTM_LAYERS, batch_first=True)
targetSeq = Sequential(
    OrderedDict(
        [
            ("TO_Layer", Linear(LSTM_HIDDEN_DIM, LSTM_OUT_DIM)),
            ("TO_Tanh", Tanh()),
        ]
    )
)

# Count the TargetNet parameters and freeze them: their values are supplied
# from the HyperNet output rather than learned directly.
target_params = 0
for target_module in (targetLSTM, targetSeq):
    for p in target_module.parameters():
        target_params += p.numel()
        p.requires_grad = False

# The HyperNet output layer has one node per TargetNet parameter. Each Linear
# layer is followed by its activation from HYPERNET_ACTS.
HYPERNET_NODES.append(target_params)
layers = [
    layer
    for i, act in enumerate(HYPERNET_ACTS)
    for layer in (Linear(HYPERNET_NODES[i], HYPERNET_NODES[i + 1]), act)
]
hyperNet = Sequential(*layers)

# move modules to DEVICE
for target_module in (targetLSTM, targetSeq):
    target_module.to(DEVICE)
hyperNet.to(DEVICE)

Sequential(
  (0): Linear(in_features=2, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=12, bias=True)
  (3): ReLU()
  (4): Linear(in_features=12, out_features=12, bias=True)
  (5): ReLU()
  (6): Linear(in_features=12, out_features=8, bias=True)
  (7): ReLU()
  (8): Linear(in_features=8, out_features=2, bias=True)
  (9): ReLU()
  (10): Linear(in_features=2, out_features=2291, bias=True)
  (11): Tanh()
)

In [144]:
# Adam updates only the HyperNet; the TargetNet parameters are frozen
optimizer = optim.Adam(hyperNet.parameters(), lr=LEARN_RATE)

# per-epoch loss histories and the best summed training loss seen so far
best_loss = torch.inf
loss, ttp_loss, tnp_loss = [], [], []

# dimension of LSTM hidden/cell states (not used by the loop below)
hc_state_dim = (targetLSTM.num_layers, BATCH_SIZE, targetLSTM.hidden_size)

for i in range(EPOCHS):

    # one optimizer step per HyperNet input set
    loss_i = 0
    for hx, t_data in TRAIN_DATA:
        optimizer.zero_grad()
        loss_b = train_epoch(
            hyperNet, targetLSTM, targetSeq, hx, t_data, BATCH_SIZE, LOSS_FN, DEVICE
        )
        loss_b.backward()
        optimizer.step()
        loss_i += loss_b.item()

    # evaluate on the test sets with training (TTP) and new (TNP) parameters
    ttp_loss_i = sum(
        test_epoch(
            hyperNet, targetLSTM, targetSeq, vhx, vt_data, BATCH_SIZE, LOSS_FN, DEVICE
        )
        for vhx, vt_data in TTP_DATA
    )
    tnp_loss_i = sum(
        test_epoch(
            hyperNet, targetLSTM, targetSeq, vhx, vt_data, BATCH_SIZE, LOSS_FN, DEVICE
        )
        for vhx, vt_data in TNP_DATA
    )

    loss.append(loss_i)
    ttp_loss.append(ttp_loss_i)
    tnp_loss.append(tnp_loss_i)

    # Checkpoint whenever the summed *training* loss improves. The model is
    # moved to the CPU for saving and then back to DEVICE.
    if loss_i < best_loss:
        best_loss = loss_i
        hyperNet = hyperNet.cpu()
        torch.save(hyperNet.state_dict(), MODEL_NAME)
        hyperNet = hyperNet.to(DEVICE)

    print(
        f"Epoch: {i+1:0>2} | Loss: {loss_i:.3e} | "
        f"Best Loss: {best_loss:.3e} | "
        f"TTP Loss: {ttp_loss_i:.3e} | "
        f"TNP Loss: {tnp_loss_i:.3e}"
    )

Epoch: 01 | Loss: 3.724e+01 | Best Loss: 3.724e+01 | TTP Loss: 2.825e-01 | TNP Loss: 2.578e-01
Epoch: 02 | Loss: 2.695e+00 | Best Loss: 2.695e+00 | TTP Loss: 8.030e-02 | TNP Loss: 6.680e-02
Epoch: 03 | Loss: 1.231e+00 | Best Loss: 1.231e+00 | TTP Loss: 3.117e-02 | TNP Loss: 2.863e-02
Epoch: 04 | Loss: 8.735e-01 | Best Loss: 8.735e-01 | TTP Loss: 4.726e-02 | TNP Loss: 4.257e-02
Epoch: 05 | Loss: 6.433e-01 | Best Loss: 6.433e-01 | TTP Loss: 5.119e-02 | TNP Loss: 4.210e-02
Epoch: 06 | Loss: 5.042e-01 | Best Loss: 5.042e-01 | TTP Loss: 8.269e-02 | TNP Loss: 6.684e-02
Epoch: 07 | Loss: 4.993e-01 | Best Loss: 4.993e-01 | TTP Loss: 1.112e-01 | TNP Loss: 9.081e-02
Epoch: 08 | Loss: 5.077e-01 | Best Loss: 4.993e-01 | TTP Loss: 1.350e-01 | TNP Loss: 1.106e-01
Epoch: 09 | Loss: 5.083e-01 | Best Loss: 4.993e-01 | TTP Loss: 1.362e-01 | TNP Loss: 1.112e-01
Epoch: 10 | Loss: 4.952e-01 | Best Loss: 4.952e-01 | TTP Loss: 1.272e-01 | TNP Loss: 1.041e-01
Epoch: 11 | Loss: 4.766e-01 | Best Loss: 4.766e-01

In [145]:
loss_fig = go.Figure()

# one line trace per loss history; the x-axis is the 1-based epoch number
for trace_name, history in (
    ("Loss", loss),
    ("TTP Loss", ttp_loss),
    ("TNP Loss", tnp_loss),
):
    loss_fig.add_trace(
        go.Scatter(
            x=torch.arange(len(history)) + 1,
            y=history,
            name=trace_name,
            mode="lines",
            showlegend=True,
        )
    )

loss_fig.update_layout(
    title="Sinusoidal HyperNet Loss",
    yaxis={"title": "Loss"},
    xaxis={"title": "Epoch"},
    hovermode="x unified",
)
loss_fig.show()

In [146]:
# restore the checkpoint with the lowest training loss saved during training
best_state_dict = torch.load(MODEL_NAME, weights_only=True)
hyperNet.load_state_dict(best_state_dict)

<All keys matched successfully>

In [147]:
def forward(
    hyperNet_: Sequential,
    targetLSMT_: LSTM,
    targetSeq_: Sequential,
    hx: torch.Tensor,
    tx: torch.Tensor,
) -> torch.Tensor:
    """
    Predict TargetNet outputs for ``tx`` using weights generated from ``hx``.

    The HyperNet output is written into the TargetNet parameters in place
    (LSTM parameters first, then the head, in ``parameters()`` order), and the
    TargetNet is then run with its normal ``forward``. Everything runs under
    ``torch.no_grad()``. Otherwise, copying the grad-requiring HyperNet output
    into the frozen parameters would record autograd history on the parameters
    themselves: they would start requiring grad and keep a reference to the
    HyperNet graph.

    Side effect: the parameter values of ``targetLSMT_`` and ``targetSeq_`` are
    overwritten.

    :param hyperNet_: network mapping ``hx`` to the flattened TargetNet parameters
    :type hyperNet_: Sequential
    :param targetLSMT_: TargetNet LSTM (parameter name kept for compatibility);
        assumed ``batch_first=True`` because the last time step is taken with
        ``[:, -1, :]`` and the batch size with ``tx.shape[0]``
    :type targetLSMT_: LSTM
    :param targetSeq_: TargetNet head applied to the last LSTM time step
    :type targetSeq_: Sequential
    :param hx: HyperNet input; presumably 1-D so the output is a flat vector
    :type hx: torch.Tensor
    :param tx: TargetNet input batch
    :type tx: torch.Tensor
    :return: TargetNet output with no autograd history
    :rtype: torch.Tensor
    """
    with torch.no_grad():
        hy = hyperNet_(hx)

        # copy consecutive slices of the flat HyperNet output into each parameter
        idx = 0
        for module in (targetLSMT_, targetSeq_):
            for p in module.parameters():
                n = p.numel()
                p.copy_(hy[idx : idx + n].reshape(p.shape))
                idx += n

        # zero initial hidden/cell states sized to this batch, on tx's device
        hc_dims = (targetLSMT_.num_layers, tx.shape[0], targetLSMT_.hidden_size)
        h0 = torch.zeros(hc_dims, device=tx.device)
        c0 = torch.zeros(hc_dims, device=tx.device)

        y, _ = targetLSMT_(tx, (h0, c0))
        y = targetSeq_(y[:, -1, :])

    return y.detach()

In [148]:
def parse_validation_dataset(
    dataset: list, hyperNet_: Sequential, targetLSMT_: LSTM, targetSeq_: Sequential
) -> list:
    """
    Convert validation records to tensors and attach TargetNet predictions.

    For each record, the ``"hx"``, ``"tx"`` and ``"tyhat"`` entries are
    converted from NumPy with ``torch.from_numpy``, moved to ``DEVICE`` and cast
    to ``float32``. ``forward`` is then called to produce the prediction ``ty``.
    Note that ``forward`` overwrites the TargetNet parameters with the weights
    generated for each record.

    The input sequence is not modified; a new list is returned.

    :param dataset: iterable of validation records indexed by key
    :type dataset: list
    :param hyperNet_: HyperNet that generates the TargetNet parameters
    :type hyperNet_: Sequential
    :param targetLSMT_: TargetNet LSTM (parameter name kept for compatibility)
    :type targetLSMT_: LSTM
    :param targetSeq_: TargetNet head applied to the last LSTM time step
    :type targetSeq_: Sequential
    :return: one ``[hx, tx, ty, tyhat]`` list per record, in input order
    :rtype: list
    """
    parsed = []
    for ds in dataset:
        hx = torch.from_numpy(ds["hx"]).to(DEVICE).float()
        tx = torch.from_numpy(ds["tx"]).to(DEVICE).float()
        tyhat = torch.from_numpy(ds["tyhat"]).to(DEVICE).float()
        ty = forward(hyperNet_, targetLSMT_, targetSeq_, hx, tx)
        parsed.append([hx, tx, ty, tyhat])
    return parsed

In [149]:
# attach predictions from the restored HyperNet to every validation record
VTP_DATA = parse_validation_dataset(
    dataset=VTP_DATA,
    hyperNet_=hyperNet,
    targetLSMT_=targetLSTM,
    targetSeq_=targetSeq,
)
VNP_DATA = parse_validation_dataset(
    dataset=VNP_DATA,
    hyperNet_=hyperNet,
    targetLSMT_=targetLSTM,
    targetSeq_=targetSeq,
)

In [150]:
# Plotly's default qualitative palette and the index of the next color to use;
# the index is shared by (and keeps advancing across) the validation plots
colors, color_count = plotly.colors.qualitative.Plotly, 0

In [151]:
vtp_fig = go.Figure()
vtp_fig.update_layout(
    title="Training Parameter Validation",
    yaxis={"title": "Amplitude"},
    xaxis={"title": "time"},
)
for set_idx, record in enumerate(VTP_DATA):

    hx, tx, ty, tyhat = record
    # flatten targets and predictions into 1-D CPU tensors for plotting
    tyhat = tyhat.cpu().flatten()
    ty = ty.cpu().flatten()

    legend_name = f"HN Set {set_idx+1}"
    hx_text = "<br>    ".join(f"{val:.3e}" for val in hx.cpu().numpy())
    hover_text = f"HyperNet Inputs:<br>    {hx_text}<br>" "<br>t: %{x}" "<br>a: %{y}"
    sample_idx = torch.arange(len(tx))
    color = colors[color_count]

    # solid line for the ground truth, dashed line for the prediction
    for trace_name, values, dashed in (("True", tyhat, False), ("Prediction", ty, True)):
        line_style = {"color": color}
        if dashed:
            line_style["dash"] = "dash"
        vtp_fig.add_trace(
            go.Scatter(
                x=sample_idx,
                y=values,
                line=line_style,
                name=trace_name,
                showlegend=True,
                legendgroup=legend_name,
                legendgrouptitle_text=legend_name,
                hovertemplate=hover_text,
            )
        )

    # advance to the next palette color, wrapping around at the end
    color_count = (color_count + 1) % len(colors)

vtp_fig.show()

In [152]:
vnp_fig = go.Figure()
vnp_fig.update_layout(
    title="New Parameter Validation",
    yaxis={"title": "Amplitude"},
    xaxis={"title": "time"},
)
for set_idx, record in enumerate(VNP_DATA):

    hx, tx, ty, tyhat = record
    # flatten targets and predictions into 1-D CPU tensors for plotting
    tyhat = tyhat.cpu().flatten()
    ty = ty.cpu().flatten()

    legend_name = f"HN Set {set_idx+1}"
    hx_text = "<br>    ".join(f"{val:.3e}" for val in hx.cpu().numpy())
    hover_text = f"HyperNet Inputs:<br>    {hx_text}<br>" "<br>t: %{x}" "<br>a: %{y}"
    sample_idx = torch.arange(len(tx))
    color = colors[color_count]

    # solid line for the ground truth, dashed line for the prediction
    for trace_name, values, dashed in (("True", tyhat, False), ("Prediction", ty, True)):
        line_style = {"color": color}
        if dashed:
            line_style["dash"] = "dash"
        vnp_fig.add_trace(
            go.Scatter(
                x=sample_idx,
                y=values,
                line=line_style,
                name=trace_name,
                showlegend=True,
                legendgroup=legend_name,
                legendgrouptitle_text=legend_name,
                hovertemplate=hover_text,
            )
        )

    # advance to the next palette color, wrapping around at the end
    color_count = (color_count + 1) % len(colors)

vnp_fig.show()