In [1]:
import os
import numpy as np
import torch as T
import torch.nn as nn
import torch.optim as optim

from recurrent_network import RecurrentNetwork

In [2]:
# Smoke test: run a single forward pass, backward pass and optimizer step
# to surface any forward-pass / backpropagation errors.

# Seed torch's RNG so the random inputs/targets below are repeatable across runs.
SEED = 0
T.manual_seed(SEED)

# ssm config
embedding_size = 64
state_space_size = 256
input_size = 2
output_size = input_size
batch_size = 1
seq_len = 20000
device = T.device("cuda" if T.cuda.is_available() else "cpu")

# instantiate model
model = RecurrentNetwork(
    embedding_size=embedding_size,
    state_space_size=state_space_size,
    input_size=input_size,
    output_size=output_size,
    batch_size=batch_size,
    chkpt_dir='temp/',
    device=device
)

# backprop setup
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# generate input and target data
# Targets are compared against the model output, so they are sized with
# output_size (not input_size); the two only coincide in this config.
inputs = T.randn(batch_size, seq_len, input_size, dtype=T.float32, device=device)
targets = T.randn(batch_size, seq_len, output_size, dtype=T.float32, device=device)

# training loop for 1 step to test backprop
model.train()
optimizer.zero_grad()

# forward pass
h_t, preds = model(inputs, None, embeddings_only=False)

# nn.MSELoss broadcasts mismatched shapes (emitting only a warning), which
# would make the loss meaningless; fail loudly instead.
assert preds.shape == targets.shape, (
    f"preds shape {tuple(preds.shape)} does not match "
    f"targets shape {tuple(targets.shape)}"
)

# loss + backprop
loss = criterion(preds, targets)
# A NaN/Inf loss means the step below would corrupt the parameters.
assert T.isfinite(loss), f"non-finite loss: {loss.item()}"
loss.backward()
optimizer.step()

print("Loss:", loss.item())
print("Backpropagation and optimizer step completed successfully.")




Loss: 1.1018831729888916
Backpropagation and optimizer step completed successfully.


In [4]:
# checking for numerical stability
print(preds[-1], h_t[-1])

# The printed tensor repr elides most entries ("..."), so eyeballing it cannot
# rule out NaN/Inf values. Verify explicitly over every element instead.
# T.isfinite treats a complex value as finite only when both its real and
# imaginary parts are finite.
assert T.isfinite(preds).all(), "preds contains NaN or Inf values"
assert T.isfinite(h_t[-1]).all(), "h_t[-1] contains NaN or Inf values"

tensor([[ 0.4591, -0.1959],
        [ 0.4604, -0.1604],
        [ 0.3640, -0.1641],
        ...,
        [ 0.4368, -0.2011],
        [ 0.4634, -0.1987],
        [ 0.4463, -0.1905]], device='cuda:0', grad_fn=<SelectBackward0>) tensor([[[ 1.6615e-03+2.0768e-04j,  1.8464e-03+1.9990e-04j,
           1.3338e-03+2.2970e-04j,  ...,
           1.2840e-03+9.5930e-05j,  1.0180e-03+2.2727e-05j,
           1.2584e-03+2.7814e-04j],
         [ 1.6175e-03+9.0074e-05j,  1.4089e-03+1.6909e-04j,
           1.0490e-03+2.0764e-04j,  ...,
           1.2951e-03+1.6018e-04j,  1.3748e-03+2.7018e-04j,
           1.2638e-03+6.6188e-05j],
         [ 1.2854e-03+3.5280e-04j,  1.5346e-03+3.3003e-04j,
           1.9073e-03+1.4986e-04j,  ...,
           1.2785e-03+1.4658e-04j,  1.8336e-03+7.7702e-05j,
           1.8120e-03+3.0130e-04j],
         ...,
         [ 1.5559e-03+2.0730e-04j,  1.7579e-03+2.1331e-04j,
           1.5628e-03+3.4168e-04j,  ...,
           1.7586e-03+7.1525e-05j,  1.6483e-03+4.5804e-04j,
        