In [1]:
import torch
import sys
import os
sys.path.append(os.path.abspath("../src"))
sys.path.append(os.path.abspath("../experiments/convex_hull"))

from model_550m import get_spectral_filters, STU
from full_fast_stu import FullFastSTU

class STUConfig:
    """Plain settings holder handed to ``STU`` in this notebook.

    Each constructor argument is stored unchanged as an instance attribute.
    The input width ``d_in`` is exposed under two names, ``n_embd`` and
    ``dim``, because different parts of the code read different names.

    Args:
        d_in: Input width; stored as both ``n_embd`` and ``dim``.
        d_out: Stored as ``d_out``.
        num_eigh: Stored as ``num_eigh``.
        seq_len: Stored as ``seq_len``.
        use_hankel_L: Stored as ``use_hankel_L``.
        use_approx: Stored as ``use_approx``.
        use_flash_fft: Stored as ``use_flash_fft``.
        torch_dtype: Stored as ``torch_dtype``.
    """

    def __init__(
        self,
        d_in=4,
        d_out=4,
        num_eigh=24,
        seq_len=8192,
        use_hankel_L=False,
        use_approx=True,
        use_flash_fft=True,
        torch_dtype=torch.bfloat16,
    ):
        super().__init__()

        # Two aliases for the same input width.
        self.n_embd = self.dim = d_in
        self.d_out = d_out

        # Spectral-filter / sequence settings.
        self.num_eigh = num_eigh
        self.seq_len = seq_len
        self.use_hankel_L = use_hankel_L

        # Implementation switches and numeric precision.
        self.use_approx = use_approx
        self.use_flash_fft = use_flash_fft
        self.torch_dtype = torch_dtype

def create_random_stu(d_in=1, d_out=1, num_eigh=24, use_hankel_L=False, use_approx=True, seq_len=8192):
    """Build an ``STU`` on the GPU together with matching spectral filters.

    All arguments are forwarded to ``STUConfig`` and the spectral filters are
    computed from that same config, so the filters and the model always agree
    on ``seq_len``, ``num_eigh``, ``use_hankel_L`` and dtype.

    NOTE: with the signature defaults this produces a model with
    ``d_in = d_out = 1``; pass ``d_in`` / ``d_out`` explicitly for a wider model.

    Args:
        d_in: Input width, forwarded to ``STUConfig(d_in=...)``.
        d_out: Output width, forwarded to ``STUConfig(d_out=...)``.
        num_eigh: Passed to the config and, as ``K``, to ``get_spectral_filters``.
        use_hankel_L: Passed to the config and to ``get_spectral_filters``.
        use_approx: Forwarded to ``STUConfig(use_approx=...)``.
        seq_len: Sequence length used for both the config and the filters.

    Returns:
        The ``STU(config, filters)`` instance after ``.cuda()``.
    """
    # Build the config first so it is the single source of truth for the
    # values shared between the filters and the model.
    config = STUConfig(
        d_in=d_in,
        d_out=d_out,
        num_eigh=num_eigh,
        seq_len=seq_len,
        use_hankel_L=use_hankel_L,
        use_approx=use_approx,
    )

    filters = get_spectral_filters(
        seq_len=config.seq_len,
        K=config.num_eigh,
        use_hankel_L=config.use_hankel_L,
        device=torch.device('cuda'),
        dtype=config.torch_dtype,
    )

    return STU(config, filters).cuda()

SyntaxError: unterminated string literal (detected at line 6) (20115604.py, line 6)

In [2]:
# Seed the global RNG so model initialisation and the random input are
# reproducible across "Restart & Run All".
torch.manual_seed(0)

# Ask for the 4-wide model explicitly instead of relying on config defaults.
stu = create_random_stu(d_in=4, d_out=4)

seq_len = 8192
# Generate random input
batch_size = 1
d_in = stu.config.n_embd
x = torch.randn(batch_size, seq_len, d_in, dtype=torch.bfloat16).cuda()
input_pos = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1).cuda()

full_fast_stu = FullFastSTU(stu, "./convex_hull/best_phi_lds.pt")
# Get outputs from both models
with torch.no_grad():
    stu_output = stu(x, input_pos)
    full_fast_output = full_fast_stu(x, input_pos)

# Calculate differences
abs_diff = torch.abs(stu_output - full_fast_output)
mean_diff = abs_diff.mean().item()
max_diff = abs_diff.max().item()

print(f"Mean absolute difference: {mean_diff}")
print(f"Max absolute difference: {max_diff}")

# Check if outputs are close, and actually report the verdict.
# NOTE(review): these tolerances are far tighter than the differences recorded
# below (mean ~7e-3), so this is expected to print False for bfloat16 inputs;
# loosen rtol/atol if a meaningful pass/fail signal is wanted.
is_close = torch.allclose(stu_output, full_fast_output, rtol=1e-5, atol=1e-5)
print(f"Outputs allclose (rtol=1e-5, atol=1e-5): {is_close}")

# Results recorded from earlier (unseeded) runs:
# with best_phi_lds.pt
#   Mean absolute difference: 0.00689697265625
#   Max absolute difference: 0.0625

# with 250
#   TODO(review): "250" is not defined anywhere in this notebook (presumably a
#   different checkpoint) and no numbers were recorded for it here.

  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Mean absolute difference: 0.0123291015625
Max absolute difference: 0.125


In [2]:
from lds import LDS as NLDS
from inference_lds import ILDS

In [3]:
def get_lds(checkpoint_path = './convex_hull/best_phi_lds.pt'):
    """Rebuild the ``NLDS`` model stored in a checkpoint and move it to the GPU.

    The checkpoint is read as a dict providing ``state_dim``, ``input_dim``,
    ``output_dim``, ``kx``, ``dtype`` and ``model_state_dict``.

    Args:
        checkpoint_path: Path of the checkpoint file to read.

    Returns:
        The ``NLDS`` instance with the checkpoint weights loaded, after ``.cuda()``.

    Warns:
        UserWarning: if the checkpoint's state dict and the model disagree on
            parameter names. Loading stays non-strict, so this never raises.
    """
    import warnings  # local import keeps this cell self-contained

    # NOTE(review): depending on the installed torch version, torch.load may use
    # the full pickle loader, which can execute code embedded in the file. Only
    # load trusted checkpoints, or pass weights_only=True if the checkpoint
    # holds nothing but tensors and primitive values.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # Any dtype tag other than the string 'torch.float32' falls back to float64.
    model_dtype = torch.float32 if checkpoint['dtype'] == 'torch.float32' else torch.float64

    # Create the LDS model (built without a bsz_dim argument).
    lds_phi = NLDS(
        state_dim=checkpoint['state_dim'],
        input_dim=checkpoint['input_dim'],
        output_dim=checkpoint['output_dim'],
        kx=checkpoint['kx'],
        dtype=model_dtype,
    )

    # Load the weights non-strictly, but report any key mismatch: a silently
    # skipped parameter would keep its freshly initialised value.
    load_result = lds_phi.load_state_dict(checkpoint['model_state_dict'], strict = False)
    missing_keys = list(getattr(load_result, 'missing_keys', ()))
    unexpected_keys = list(getattr(load_result, 'unexpected_keys', ()))
    if missing_keys or unexpected_keys:
        warnings.warn(
            f"Checkpoint {checkpoint_path!r} does not fully match the LDS model: "
            f"missing keys {missing_keys}, unexpected keys {unexpected_keys}",
            stacklevel=2,
        )

    return lds_phi.cuda()

In [4]:
def get_ilds(checkpoint_path = './convex_hull/best_phi_lds.pt', bsz_dim = 4):
    """Rebuild the ``ILDS`` (inference LDS) model from a checkpoint and move it to the GPU.

    The checkpoint is read as a dict providing ``state_dim``, ``input_dim``,
    ``output_dim``, ``kx``, ``dtype`` and ``model_state_dict``.

    Args:
        checkpoint_path: Path of the checkpoint file to read.
        bsz_dim: Forwarded to ``ILDS(bsz_dim=...)``; defaults to 4.

    Returns:
        The ``ILDS`` instance with the checkpoint weights loaded, after ``.cuda()``.

    Warns:
        UserWarning: if the checkpoint's state dict and the model disagree on
            parameter names. Loading stays non-strict, so this never raises.
    """
    import warnings  # local import keeps this cell self-contained

    # NOTE(review): depending on the installed torch version, torch.load may use
    # the full pickle loader, which can execute code embedded in the file. Only
    # load trusted checkpoints, or pass weights_only=True if the checkpoint
    # holds nothing but tensors and primitive values.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # Any dtype tag other than the string 'torch.float32' falls back to float64.
    model_dtype = torch.float32 if checkpoint['dtype'] == 'torch.float32' else torch.float64

    # Create the inference LDS model
    lds_phi = ILDS(
        state_dim=checkpoint['state_dim'],
        input_dim=checkpoint['input_dim'],
        output_dim=checkpoint['output_dim'],
        kx=checkpoint['kx'],
        dtype=model_dtype,
        bsz_dim = bsz_dim
    )

    # Load the weights non-strictly, but report any key mismatch: a silently
    # skipped parameter would keep its freshly initialised value.
    load_result = lds_phi.load_state_dict(checkpoint['model_state_dict'], strict = False)
    missing_keys = list(getattr(load_result, 'missing_keys', ()))
    unexpected_keys = list(getattr(load_result, 'unexpected_keys', ()))
    if missing_keys or unexpected_keys:
        warnings.warn(
            f"Checkpoint {checkpoint_path!r} does not fully match the ILDS model: "
            f"missing keys {missing_keys}, unexpected keys {unexpected_keys}",
            stacklevel=2,
        )

    return lds_phi.cuda()

In [5]:
# Reference LDS, built by get_lds() from its default checkpoint path and moved to the GPU.
lds = get_lds()

  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


In [6]:
# Inference LDS, built by get_ilds() from the same default checkpoint path and moved to the GPU.
ilds = get_ilds()

  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


In [21]:
# 100 random float32 values reshaped to (1, 100, 1) and moved to the GPU.
# Presumably (batch, time, feature); the stepping loop later slices dim 1 - TODO confirm.
# Note: this rebinds the name `x` used by the STU comparison cell above.
x = torch.randn(100).reshape(1, -1, 1).cuda()

In [22]:
# Reset the ILDS before the full-sequence comparison.
# Presumably this clears its internal recurrent state - confirm in inference_lds.
ilds.reset_state()

In [23]:
# Full-sequence forward pass of the reference LDS; the result is shown as the cell output.
lds(x)

tensor([[[ 2.1530e-10,  4.5513e-10,  9.5488e-10,  ...,  3.3501e-03,
          -1.5496e-02, -1.1379e-01],
         [-1.9089e-08, -3.7422e-08, -7.2795e-08,  ...,  4.4006e-02,
          -1.5492e-01, -8.2400e-01],
         [ 3.0989e-07,  5.4630e-07,  9.4081e-07,  ...,  1.2978e-01,
          -2.0339e-01,  6.3200e-01],
         ...,
         [-1.1616e-04, -4.4637e-04, -8.8879e-04,  ...,  6.0397e-02,
          -4.2050e-01,  1.6584e+00],
         [ 2.5105e-04, -5.1268e-06, -3.9115e-04,  ..., -9.7667e-02,
           9.2074e-01, -2.1384e-01],
         [ 1.3872e-04, -1.6183e-04, -8.2756e-04,  ..., -2.6890e-01,
          -5.2398e-01,  3.4276e-01]]], device='cuda:0', dtype=torch.float64,
       grad_fn=<AddBackward0>)

In [24]:
# Same input through the inference LDS in one call, cast to float64 first.
# The recorded output agrees with lds(x) to roughly the printed precision.
ilds(x.double())

tensor([[[ 2.1530e-10,  4.5513e-10,  9.5489e-10,  ...,  3.3501e-03,
          -1.5496e-02, -1.1379e-01],
         [-1.9089e-08, -3.7422e-08, -7.2796e-08,  ...,  4.4006e-02,
          -1.5492e-01, -8.2400e-01],
         [ 3.0989e-07,  5.4630e-07,  9.4081e-07,  ...,  1.2978e-01,
          -2.0339e-01,  6.3200e-01],
         ...,
         [-1.1616e-04, -4.4637e-04, -8.8879e-04,  ...,  6.0397e-02,
          -4.2050e-01,  1.6584e+00],
         [ 2.5105e-04, -5.1268e-06, -3.9115e-04,  ..., -9.7667e-02,
           9.2074e-01, -2.1384e-01],
         [ 1.3872e-04, -1.6183e-04, -8.2756e-04,  ..., -2.6890e-01,
          -5.2398e-01,  3.4276e-01]]], device='cuda:0', dtype=torch.float64,
       grad_fn=<CatBackward0>)

In [25]:
# Reset the ILDS again and set its `cache` attribute before feeding one step at a time.
# Presumably `cache = True` makes it carry state between calls - confirm in inference_lds.
ilds.reset_state()
ilds.cache = True

In [28]:
# Feed the sequence to the ILDS one time step at a time (a length-1 slice of
# dim 1 per call) and collect the per-step outputs.
out = []
# Take the step count from x itself so this stays in sync with however x was built.
for i in range(x.shape[1]):
    y = ilds(x[:, i:i+1, :].double())
    out.append(y)

In [29]:
# Stitch the per-step outputs back together along dim 1 for display.
# NOTE(review): in the recorded output only the first step matches the full-sequence
# results of lds(x) / ilds(x.double()); from the second step on the values differ,
# so the step-by-step path is not yet equivalent - investigate.
torch.concat(out, dim = 1)

tensor([[[ 2.1530e-10,  4.5513e-10,  9.5489e-10,  ...,  3.3501e-03,
          -1.5496e-02, -1.1379e-01],
         [ 1.6157e-09,  3.4155e-09,  7.1660e-09,  ...,  2.5141e-02,
          -1.1629e-01, -8.5394e-01],
         [-7.9411e-10, -1.6787e-09, -3.5220e-09,  ..., -1.2356e-02,
           5.7157e-02,  4.1970e-01],
         ...,
         [-2.4426e-09, -5.1635e-09, -1.0833e-08,  ..., -3.8007e-02,
           1.7581e-01,  1.2910e+00],
         [-5.4714e-10, -1.1566e-09, -2.4267e-09,  ..., -8.5136e-03,
           3.9381e-02,  2.8918e-01],
         [-3.5817e-10, -7.5715e-10, -1.5885e-09,  ..., -5.5732e-03,
           2.5780e-02,  1.8930e-01]]], device='cuda:0', dtype=torch.float64,
       grad_fn=<CatBackward0>)