In [1]:
import os
import numpy as np
import torch as T
import torch.nn as nn
import torch.optim as optim

from recurrent_network import RecurrentNetwork

In [2]:
'''
Smoke test for forward pass / backpropagation errors: build the model,
run one forward pass, one backward pass and one optimizer step on random data.
'''

# seed torch's RNG before anything random happens, so the parameter init and
# the random inputs/targets below are repeatable across Restart & Run All
seed = 0
T.manual_seed(seed)

# ssm config
embedding_size = 64
state_space_size = 256
input_size = 2
output_size = input_size
batch_size = 1
seq_len = 30000
device = T.device("cuda" if T.cuda.is_available() else "cpu")

# instantiate model
model = RecurrentNetwork(
    embedding_size=embedding_size,
    state_space_size=state_space_size,
    input_size=input_size,
    output_size=output_size,
    batch_size=batch_size,
    chkpt_dir='temp/',
    device=device
)

# backprop setup
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# generate input and target data
# targets are compared against the predictions, so their last dimension is
# output_size (equal to input_size in this config, but not tied to it)
inputs = T.randn(batch_size, seq_len, input_size, dtype=T.float32, device=device)
targets = T.randn(batch_size, seq_len, output_size, dtype=T.float32, device=device)

# training loop for 1 step to test backprop
model.train()
optimizer.zero_grad()

# forward pass
# NOTE(review): the second argument is presumably the initial hidden state,
# with None meaning "use the model's default" -- confirm in RecurrentNetwork
h_t, y_t = model(inputs, None)

# predict next state
preds = model.predict(y_t)

# loss + backprop
loss = criterion(preds, targets)
loss.backward()
optimizer.step()

print("Loss:", loss.item())
print("Backpropagation and optimizer step completed successfully.")




Loss: 1.1710835695266724
Backpropagation and optimizer step completed successfully.


In [3]:
# checking for numerical stability
# Take the last entry along the leading axis of the forward-pass results
# (y_t, h_t come from the training-step cell above).
last_y = y_t[-1]
last_h = h_t[-1]
print(last_y, last_h)

# The printed repr is truncated ("..."), so NaN/Inf values could be hidden from
# view; assert explicitly that every element is finite. For complex tensors
# T.isfinite requires both the real and the imaginary part to be finite.
assert T.isfinite(last_y).all().item(), "y_t[-1] contains NaN or Inf values"
assert T.isfinite(last_h).all().item(), "h_t[-1] contains NaN or Inf values"

tensor([[0.0655, 0.1148, 0.2860,  ..., 0.2121, 0.1622, 0.1826],
        [0.0309, 0.1710, 0.3015,  ..., 0.2140, 0.1875, 0.1928],
        [0.0869, 0.1069, 0.3221,  ..., 0.2443, 0.1791, 0.1978],
        ...,
        [0.0775, 0.0791, 0.2950,  ..., 0.2322, 0.1089, 0.1692],
        [0.0610, 0.0731, 0.3166,  ..., 0.2220, 0.1313, 0.1782],
        [0.0732, 0.1319, 0.2933,  ..., 0.2156, 0.1626, 0.1756]],
       device='cuda:0') tensor([[[ 1.6417e-03+2.7430e-04j,  1.0630e-03+1.1223e-04j,
           1.4100e-03+2.9725e-04j,  ...,
           1.2634e-03+2.0918e-04j,  1.1443e-03+3.1881e-04j,
           1.5575e-03+2.7099e-04j],
         [ 1.0275e-03+3.2850e-04j,  1.6435e-03+4.2046e-04j,
           1.1387e-03+3.1379e-05j,  ...,
           1.3920e-03+3.1992e-04j,  1.0469e-03+3.0959e-05j,
           1.3619e-03+7.9969e-05j],
         [ 1.7237e-03+8.4195e-05j,  1.1512e-03+2.1940e-04j,
           1.6340e-03+9.5819e-05j,  ...,
           1.6490e-03+4.5517e-04j,  1.6809e-03+1.6036e-04j,
           1.5096e-03+3