In [1]:
import os
import numpy as np
import torch as T
import torch.nn as nn
import torch.optim as optim

from recurrent_network import RecurrentNetwork

In [None]:
'''
Testing for any forward pass / backpropagation errors
'''

# ssm config
embedding_size = 64
state_space_size = 256
input_size = 2
batch_size = 1
seq_len = 30000
device = T.device("cuda" if T.cuda.is_available() else "cpu")

# instantiate modepl
model = RecurrentNetwork(
    embedding_size=embedding_size,
    state_space_size=state_space_size,
    input_size=input_size,
    batch_size=batch_size,
    chkpt_dir='temp/',
    device=device
)

# backprop setup
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# generate input and target data
inputs = T.randn(batch_size, seq_len, input_size, dtype=T.float32, device=device)
targets = T.randn(batch_size, input_size, dtype=T.float32, device=device)

# training loop for 1 step to test backprop
model.train()
optimizer.zero_grad()

# forward pass
h_t, y_t = model(inputs)

# predict next state
preds = model.predict(y_t)

# loss + backprop
loss = criterion(preds, targets)
loss.backward(retain_graph=True)
optimizer.step()

print("Loss:", loss.item())
print("Backpropagation and optimizer step completed successfully.")




Loss: 0.015437815338373184
Backpropagation and optimizer step completed successfully.


  return F.mse_loss(input, target, reduction=self.reduction)


In [4]:
print(y_t[-1], h_t[-1])

tensor([[ 0.5729, -0.0250,  0.0021,  ..., -0.0652,  0.1154,  0.2558],
        [ 0.5858, -0.0770, -0.0698,  ..., -0.0927,  0.0745,  0.1193],
        [ 0.6181, -0.0748, -0.0278,  ..., -0.0395,  0.0930,  0.1442],
        ...,
        [ 0.5851, -0.0516, -0.0414,  ..., -0.0581,  0.1105,  0.1665],
        [ 0.6032, -0.0548, -0.0312,  ..., -0.0323,  0.1034,  0.1742],
        [ 0.5814, -0.0251, -0.0162,  ..., -0.0493,  0.1217,  0.1998]],
       device='cuda:0') tensor([[[ 9.9706e-04+3.7133e-06j,  1.2037e-03+1.5207e-04j,
           1.9054e-03+1.7912e-04j,  ...,
           1.7216e-03+4.5275e-04j,  1.4240e-03+2.3090e-04j,
           1.1097e-03+3.4777e-06j],
         [ 1.8471e-03+1.3348e-04j,  9.9923e-04+6.0769e-05j,
           1.3140e-03+3.6246e-04j,  ...,
           1.6930e-03+2.2980e-04j,  1.3823e-03+1.8284e-04j,
           1.0248e-03+2.5133e-04j],
         [ 1.7605e-03+3.8212e-05j,  1.3460e-03+1.2854e-04j,
           1.2929e-03+2.9299e-04j,  ...,
           1.1784e-03+3.4452e-04j,  1.3303e-03+