In [None]:
import torch
import sys
import os
sys.path.append(os.path.abspath("../src"))
sys.path.append(os.path.abspath("../experiments/convex_hull"))

from model_550m import get_spectral_filters, STU
from full_fast_stu import FullFastSTU

class STUConfig:
    """Plain settings holder passed to ``STU`` in this notebook.

    The input width ``d_in`` is stored under two attribute names,
    ``n_embd`` and ``dim``; both always hold the same value. Every other
    constructor argument is stored unchanged under its own name.
    """

    def __init__(
        self,
        d_in=4,
        d_out=4,
        num_eigh=24,
        seq_len=8192,
        use_hankel_L=False,
        use_approx=True,
        use_flash_fft=True,
        torch_dtype=torch.bfloat16,
    ):
        super().__init__()

        # Input width, exposed under both names that the code reads.
        self.n_embd = self.dim = d_in

        # Remaining sizes.
        self.d_out, self.num_eigh, self.seq_len = d_out, num_eigh, seq_len

        # Feature switches.
        self.use_hankel_L, self.use_approx, self.use_flash_fft = (
            use_hankel_L,
            use_approx,
            use_flash_fft,
        )

        # Working dtype.
        self.torch_dtype = torch_dtype

def create_random_stu(d_in=4, d_out=4, num_eigh=24, use_hankel_L=False, use_approx=True, seq_len=8192):
    """Build an ``STU`` on the GPU whose config matches the arguments.

    Previously the config was created with ``STUConfig()`` and therefore
    ignored every argument of this function (the model was always built
    with the ``STUConfig`` defaults, i.e. ``d_in=4``/``d_out=4``), while
    the spectral filters did honour ``num_eigh``/``use_hankel_L``. The
    arguments are now forwarded to both the filters and the config so
    they cannot disagree. The ``d_in``/``d_out`` defaults are set to 4 so
    that a bare ``create_random_stu()`` builds the same model as before.

    Args:
        d_in: Input width stored in the config (``n_embd`` / ``dim``).
        d_out: Output width stored in the config.
        num_eigh: Number of spectral filters requested (passed as ``K``)
            and recorded in the config.
        use_hankel_L: Forwarded to ``get_spectral_filters`` and the config.
        use_approx: Forwarded to the config.
        seq_len: Sequence length used for the filters and the config.

    Returns:
        The ``STU`` instance, moved to CUDA.
    """
    # Filters and config share one dtype; STUConfig defaults to bfloat16.
    filters = get_spectral_filters(
        seq_len=seq_len,
        K=num_eigh,
        use_hankel_L=use_hankel_L,
        device=torch.device('cuda'),
        dtype=torch.bfloat16,
    )

    # Forward the caller's choices instead of silently using the defaults.
    config = STUConfig(
        d_in=d_in,
        d_out=d_out,
        num_eigh=num_eigh,
        seq_len=seq_len,
        use_hankel_L=use_hankel_L,
        use_approx=use_approx,
    )
    stu = STU(config, filters).cuda()
    return stu

In [24]:
# Compare the reference STU against its FullFastSTU counterpart on one
# random input and report how far apart the two outputs are.
stu = create_random_stu()

seq_len = 8192
# Generate random input
batch_size = 1
d_in = stu.config.n_embd
x = torch.randn(batch_size, seq_len, d_in, dtype=torch.bfloat16).cuda()
input_pos = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1).cuda()

full_fast_stu = FullFastSTU(stu, "./convex_hull/best_phi_lds.pt")
# Replace the LDS parameter M with zeros before the comparison.
full_fast_stu.lds.M.data = torch.zeros_like(full_fast_stu.lds.M)

# Get outputs from both models (no gradients needed for a comparison)
with torch.no_grad():
    stu_output = stu(x, input_pos)
    full_fast_output = full_fast_stu(x, input_pos)

# Calculate differences
abs_diff = torch.abs(stu_output - full_fast_output)
mean_diff = abs_diff.mean().item()
max_diff = abs_diff.max().item()

print(f"Mean absolute difference: {mean_diff}")
print(f"Max absolute difference: {max_diff}")

# Check if outputs are close, and actually report the verdict (it was
# previously computed and then dropped).
# NOTE(review): the input is bfloat16 and the differences recorded below
# are around 7e-3, so a 1e-5 tolerance is expected to fail; consider a
# tolerance suited to bfloat16.
is_close = torch.allclose(stu_output, full_fast_output, rtol=1e-5, atol=1e-5)
print(f"Outputs close (rtol=1e-5, atol=1e-5): {is_close}")

#with best_phi_lds.pt
#Mean absolute difference: 0.00689697265625
#Max absolute difference: 0.0625

#with 250

  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Mean absolute difference: 0.0133056640625
Max absolute difference: 0.25


In [76]:
# Display the shape of the C tensor held by full_fast_stu.lds.
full_fast_stu.lds.C.shape

torch.Size([1149, 8])

In [77]:
# Cast the input to float64 (it was created as bfloat16 in the comparison
# cell) — presumably to match the dtype of the LDS used below; TODO confirm.
x = x.double()

In [78]:
# Right-multiply the input by full_fast_stu.M_inputs.
# NOTE(review): this overwrites x, so running the cell a second time applies
# the projection again; a separate output name would make it idempotent.
x = x @ full_fast_stu.M_inputs

In [81]:
# Size of the leading axis of x (not used again in the visible cells).
bsz = x.shape[0]
# Swap the last two axes, then fold the first two together so that every
# (leading index, last-axis index) pair becomes its own sequence with a single
# feature: the result has shape (x.shape[0] * x.shape[2], x.shape[1], 1).
x_reshaped = x.permute(0, 2, 1).reshape(-1, x.shape[1], 1)

In [82]:
# Run the standalone lds_phi model on the single-feature sequences.
# NOTE(review): lds_phi is only constructed in a later cell of this notebook
# (the one that loads best_phi_lds.pt), so on a fresh top-to-bottom run this
# line raises NameError; that cell has to be executed first.
U_reshaped = lds_phi(x_reshaped)

In [83]:
# Project the first 24 and the remaining last-axis features of the LDS output
# through the same M_filters matrix and place the two results side by side.
# NOTE(review): 24 presumably equals the number of spectral filters
# (num_eigh) — confirm before changing either value. The cell overwrites
# U_reshaped, so it is not safe to run twice in a row.
U_reshaped = torch.cat(
    (
        torch.matmul(U_reshaped[..., :24], full_fast_stu.M_filters),
        torch.matmul(U_reshaped[..., 24:], full_fast_stu.M_filters),
    ),
    dim=-1,
)

In [84]:
# Display the shape of the projected LDS output.
U_reshaped.shape

torch.Size([4, 8192, 8])

In [85]:
# Feed the same reshaped input through the LDS held by full_fast_stu and show
# the output shape — presumably for comparison against U_reshaped computed
# above; the comparison itself is not part of this cell.
U_reshaped2 = full_fast_stu.lds(x_reshaped)
U_reshaped2.shape

torch.Size([4, 8192, 8])

tensor(8.5754e-14, device='cuda:0', dtype=torch.float64,
       grad_fn=<MeanBackward0>)

In [5]:
from lds import LDS

In [34]:
# Rebuild the LDS stored in best_phi_lds.pt as a standalone model (lds_phi).
# The checkpoint is read onto the CPU first and the model is moved to the GPU
# once its weights are loaded.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source.
checkpoint = torch.load( "./convex_hull/best_phi_lds.pt", map_location=torch.device('cpu'))

# Create the LDS model with the sizes recorded in the checkpoint.
# Any stored dtype string other than 'torch.float32' selects float64.
lds_phi = LDS(
    state_dim=checkpoint['state_dim'],
    input_dim=checkpoint['input_dim'],
    output_dim=checkpoint['output_dim'],
    kx=checkpoint['kx'],
    dtype=torch.float32 if checkpoint['dtype'] == 'torch.float32' else torch.float64,
)

# Load the weights from checkpoint. strict=False tolerates key mismatches, so
# report them instead of dropping them silently.
load_result = lds_phi.load_state_dict(checkpoint['model_state_dict'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Missing keys: {load_result.missing_keys}")
    print(f"Unexpected keys: {load_result.unexpected_keys}")

lds_phi = lds_phi.cuda()
# Replace the parameter M with zeros, as is done for full_fast_stu.lds.M.
lds_phi.M.data = torch.zeros_like(lds_phi.M.data)

  checkpoint = torch.load( "./convex_hull/best_phi_lds.pt", map_location=torch.device('cpu'))
