In [83]:
import torch
import torch.nn as nn
import torch.nn.init as init
import sys
import os
import time

# Adjust the path to locate your modules.
sys.path.append(os.path.abspath("../../src"))

from lds import LDS as LDSBase
from inference_lds import LDS as LDSInference

def initialize_params(model):
    """
    Initialize model parameters to nonzero values.
    Here we simply initialize each parameter with a random normal distribution.
    """
    model.A.data.copy_(torch.randn_like(model.A))
    model.B.data.copy_(torch.randn_like(model.B))
    model.C.data.copy_(torch.randn_like(model.C))
    model.M.data.copy_(torch.randn_like(model.M))
    model.h0.data.copy_(torch.randn_like(model.h0))

def test_regressive_equivalence():
    torch.manual_seed(55)

    # Hyperparameters.
    state_dim = 4
    input_dim = 1    # Must equal output_dim for autoregressive generation (output fed as next input)
    output_dim = 1
    kx = 1         # AR order (number of taps)
    steps = 10      # Number of autoregressive generation steps.
    bsz = 1        # Batch size.

    # Instantiate models.
    # LDSInference is optimized for iterative generation.
    # LDSBase computes a full forward pass over the entire sequence.
    model_infer = LDSInference(state_dim, input_dim, output_dim, kx=kx)
    model_base = LDSBase(state_dim, input_dim, output_dim, kx=kx)

    # Initialize parameters on the inference model.
    initialize_params(model_infer)
    
    # Force both models to have identical parameters.
    model_base.A.data = model_infer.A.data.clone()
    model_base.B.data = model_infer.B.data.clone()
    model_base.C.data = model_infer.C.data.clone()
    model_base.M.data = model_infer.M.data.clone()
    model_base.h0.data = model_infer.h0.data.clone()

    # Reset the inference model's state.
    model_infer.reset_state(batch_size=bsz)

    # Create a growing input sequence for the base model.
    x_init = torch.randn(bsz, input_dim)
    inputs_base = x_init.unsqueeze(1)  # Shape: [bsz, 1, input_dim]

    inference_time_total = 0.0
    base_time_total = 0.0

    print("Step-by-step generation and comparison:")

    for t in range(steps):
        # Measure inference model time.
        start_infer = time.perf_counter()
        y_t_infer = model_infer.next_step(x_init)  # Shape: [bsz, output_dim]
        end_infer = time.perf_counter()
        inference_time_total += (end_infer - start_infer)

        # Measure base model time.
        start_base = time.perf_counter()
        y_base_all = model_base(inputs_base)
        y_t_base = y_base_all[:, -1, :]  # Extract the last output.
        end_base = time.perf_counter()
        base_time_total += (end_base - start_base)

        # Compare the outputs.
        diff = (y_t_infer - y_t_base).abs().max().item()
        print(f"Step {t} | y_infer: {y_t_infer} \n         y_base: {y_t_base} \n         max diff = {diff:.4e}")

        # Append the new output to inputs_base for the next iteration.
        inputs_base = torch.cat([inputs_base, y_t_infer.unsqueeze(1)], dim=1)
        # Use the new inference output as the next input.
        x_init = y_t_infer

    print("\nTiming Summary:")
    print(f"Total time for inference model over {steps} steps: {inference_time_total:.6f} seconds")
    print(f"Total time for base model over {steps} steps: {base_time_total:.6f} seconds")


test_regressive_equivalence()

Step-by-step generation and comparison:
Step 0 | y_infer: tensor([[-1.3543]], grad_fn=<AddBackward0>) 
         y_base: tensor([[-1.3543]], grad_fn=<SliceBackward0>) 
         max diff = 0.0000e+00
Step 1 | y_infer: tensor([[-1.1029]], grad_fn=<AddBackward0>) 
         y_base: tensor([[-1.1029]], grad_fn=<SliceBackward0>) 
         max diff = 2.3842e-07
Step 2 | y_infer: tensor([[1.6247]], grad_fn=<AddBackward0>) 
         y_base: tensor([[1.6247]], grad_fn=<SliceBackward0>) 
         max diff = 0.0000e+00
Step 3 | y_infer: tensor([[-0.9242]], grad_fn=<AddBackward0>) 
         y_base: tensor([[-0.9242]], grad_fn=<SliceBackward0>) 
         max diff = 0.0000e+00
Step 4 | y_infer: tensor([[-2.6577]], grad_fn=<AddBackward0>) 
         y_base: tensor([[-2.6577]], grad_fn=<SliceBackward0>) 
         max diff = 0.0000e+00
Step 5 | y_infer: tensor([[2.0841]], grad_fn=<AddBackward0>) 
         y_base: tensor([[2.0841]], grad_fn=<SliceBackward0>) 
         max diff = 0.0000e+00
Step 6 | y_infer

In [84]:
state_dim = 100
input_dim = 1    # Must equal output_dim for autoregressive generation (output fed as next input)
output_dim = 1
kx = 1         # AR order (number of taps)
seq_len = 100      # Number of autoregressive generation steps.
bsz = 1 

model_base = LDSBase(state_dim, input_dim, output_dim, kx=kx)
model_infer = LDSInference(state_dim, input_dim, output_dim, kx=kx)


# Initialize parameters on the inference model.
initialize_params(model_infer)

# M = torch.zeros_like(model_base.M.data)
# model_base.M.data = M
# Force both models to have identical parameters.
model_infer.A.data = model_base.A.data.clone()
model_infer.B.data = model_base.B.data.clone()
model_infer.C.data = model_base.C.data.clone()
model_infer.M.data = model_base.M.data.clone()
model_infer.h0.data = model_base.h0.data.clone()

# Reset the inference model's state.
model_infer.reset_state(batch_size=bsz)

# Create a growing input sequence for the base model.
x_init = torch.randn(bsz, seq_len ,input_dim)

In [85]:
inf_out = model_infer(x_init)
lds_out = model_base(x_init)

In [86]:
import torch.nn.functional as F
F.mse_loss(inf_out, lds_out)

tensor(9.6340e-16, grad_fn=<MseLossBackward0>)

In [87]:
model_infer.reset_state(1)
outputs = []
for t in range(seq_len):
    x_t = x_init[:, t, :]
    y_t = model_infer.next_step(x_t)
    outputs.append(y_t.unsqueeze(1))
y_inf_out = torch.cat(outputs, dim=1)

In [88]:
F.mse_loss(y_inf_out, lds_out)

tensor(9.6340e-16, grad_fn=<MseLossBackward0>)