In [1]:
import torch
import sys
import os
sys.path.append(os.path.abspath("../src"))
sys.path.append(os.path.abspath("../experiments/convex_hull"))

from model_550m import get_spectral_filters, STU
from full_fast_accel import FullFastSTU

class STUConfig():
    def __init__(
        self,
        d_in=896,
        d_out=896,
        num_eigh=24,
        seq_len=8192,
        use_hankel_L=False,
        use_approx=True,
        use_flash_fft = True,
        torch_dtype = torch.bfloat16
    ):
        super().__init__()
        self.n_embd = d_in  # Used by some parts of the code
        self.dim = d_in     # Used by other parts of the code
        self.d_out = d_out
        self.num_eigh = num_eigh
        self.seq_len = seq_len
        self.use_hankel_L = use_hankel_L
        self.use_approx = use_approx
        self.use_flash_fft = use_flash_fft
        self.torch_dtype = torch_dtype

def create_random_stu(d_in=1, d_out=1, num_eigh=24, use_hankel_L=False, use_approx=True):
    # Create a random config
 
    filters = get_spectral_filters(
        seq_len=8192,
        K=num_eigh,
        use_hankel_L=use_hankel_L,
        device=torch.device('cuda'),
        dtype=torch.bfloat16
    )
    
    # Create random STU
    config = STUConfig()
    stu = STU(config, filters).cuda()
    return stu

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
stu = create_random_stu()

seq_len = 100    
# Generate random input
batch_size = 1
d_in = stu.config.n_embd
x = torch.randn(batch_size, seq_len, d_in, dtype=torch.bfloat16).cuda()
input_pos = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1).cuda()

full_fast_stu = FullFastSTU(stu, "../experiments/convex_hull/fit_filters_205/250_phi_lds_float32.pt")

# Get outputs from both models
with torch.no_grad():
    stu_output = stu(x, input_pos)
    full_fast_output = full_fast_stu(x, input_pos)
# Calculate differences
abs_diff = torch.abs(stu_output - full_fast_output)
mean_diff = abs_diff.mean().item()
max_diff = abs_diff.max().item()
    

print(f"Mean absolute difference: {mean_diff}")
print(f"Max absolute difference: {max_diff}")


# Check if outputs are close
is_close = torch.allclose(stu_output, full_fast_output, rtol=1e-5, atol=1e-5)

#with best_phi_lds.pt
#Mean absolute difference: 0.00689697265625
#Max absolute difference: 0.0625

#with 250

  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Mean absolute difference: 0.09912109375
Max absolute difference: 4.0


In [None]:
full_fast_stu.profile(input_shape = (1, 128, 896))

Performing 3 warmup runs...


In [None]:
x = x.double()
x = x @ full_fast_stu.M_inputs

bsz = x.shape[0]
x_reshaped = x.permute(0, 2, 1).reshape(-1, x.shape[1], 1)
U_reshaped = full_fast_stu.lds(x_reshaped)
U_reshaped = torch.cat([U_reshaped[:, :, :24] @ full_fast_stu.M_filters, U_reshaped[:, :, 24:] @ full_fast_stu.M_filters], dim = -1)

U = U_reshaped.reshape(bsz, x.shape[2], x.shape[1], -1).permute(0, 2, 3, 1)


In [None]:
U.shape

In [None]:
x = x.double()
x = x @ full_fast_stu.M_inputs #[1, 8192, 4]

bsz = x.shape[0]
x_reshaped = x.permute(0, 2, 1).reshape(-1, x.shape[1], 1) #[4, 8192, 1]
U_reshaped = full_fast_stu.lds(x_reshaped) #[4, 8192, 48]
U_reshaped = torch.cat([U_reshaped[:, :, :24] @ full_fast_stu.M_filters, U_reshaped[:, :, 24:] @ full_fast_stu.M_filters], dim = -1)
#[4, 8192, 8]

U = U_reshaped.reshape(bsz, x.shape[2], x.shape[1], -1).permute(0, 2, 3, 1)
#[1, 8192, 8, 4]

In [11]:
print(U)

tensor([[[[ 24.8791,  19.1846,  18.5300,   2.2011],
          [ -9.3722,  -7.2270,  -6.9804,  -0.8292],
          [  5.2935,   4.0819,   3.9426,   0.4683],
          ...,
          [ -9.3714,  -7.2264,  -6.9798,  -0.8291],
          [  5.2935,   4.0819,   3.9426,   0.4683],
          [ -0.7550,  -0.5822,  -0.5623,  -0.0668]],

         [[  3.3225,   4.6895,   4.0535,  -3.0851],
          [ -1.0422,  -1.6051,  -1.3710,   1.1807],
          [  2.4939,   2.3757,   2.1934,  -0.4983],
          ...,
          [  2.2698,   0.9489,   1.0958,   1.4736],
          [ -3.1872,  -2.0050,  -2.0379,  -1.0009],
          [  3.3025,   2.4820,   2.4118,   0.3947]],

         [[-12.5063,  -9.0107,  -8.9068,  -2.1296],
          [  5.0971,   3.7098,   3.6559,   0.8079],
          [  6.8477,   5.5679,   5.3004,   0.1454],
          ...,
          [  4.8798,   3.8256,   3.7043,   0.3389],
          [  7.2195,   5.3688,   5.2168,   0.9499],
          [  1.7835,   1.6471,   1.5320,  -0.2734]],

         ...,

In [13]:
x = x.double()
x = x @ full_fast_stu.M_inputs #[1, 8192, 4]

bsz = x.shape[0]
x_reshaped = x.permute(0, 2, 1).reshape(-1, x.shape[1], 1) #[4, 8192, 1]
U_reshaped = full_fast_stu.lds(x_reshaped) #[4, 8192, 48]
U_reshaped = torch.cat([U_reshaped[:, :, :24] @ (full_fast_stu.M_filters @ full_fast_stu.M_inputs), U_reshaped[:, :, 24:] @ (full_fast_stu.M_filters @ full_fast_stu.M_inputs)], dim = -1)
#[4, 8192, 8]

U = U_reshaped.reshape(bsz, x.shape[2], x.shape[1], -1).permute(0, 2, 3, 1)
#[1, 8192, 8, 4]

In [14]:
print(U)

tensor([[[[ 142.2505,   89.6356,   87.9379,   43.5473],
          [ 130.8444,   82.4483,   80.8867,   40.0555],
          [ 143.4781,   90.4092,   88.6968,   43.9231],
          ...,
          [ 130.8443,   82.4483,   80.8867,   40.0555],
          [ 143.4791,   90.4098,   88.6974,   43.9234],
          [ -14.2901,   -9.0045,   -8.8340,   -4.3746]],

         [[  51.8887,   26.8500,   31.7309,   26.3099],
          [  29.9069,   13.4675,   18.1697,   18.7447],
          [  39.9690,   19.2887,   24.3592,   22.7509],
          ...,
          [ -25.4501,  -21.4143,  -16.0514,    1.7982],
          [ -35.0811,  -28.0023,  -22.0360,   -0.2242],
          [ -23.7802,  -14.3972,  -14.6659,   -8.3271]],

         [[  -2.1186,   -3.9061,   -1.2973,    3.9833],
          [ -37.8131,  -25.4595,  -23.3209,   -8.6213],
          [ -84.8005,  -55.5200,  -52.3803,  -22.1946],
          ...,
          [ -38.7571,  -23.7793,  -23.7698,  -12.9673],
          [ -86.0791,  -53.2412,  -52.9881,  -28.0862],