In [2]:
import sys
import os
# Make ../src importable (presumably where model_550m and lds live — confirm).
# Guarded so that re-running this cell does not stack duplicate sys.path entries.
_src_dir = os.path.abspath("../src")
if _src_dir not in sys.path:
    sys.path.append(_src_dir)

In [3]:
# Optional dependency: attempt the FlashFFTConv import and record the outcome
# in `flash_fft_available` (True on success, False on ImportError).
try:
    from flashfftconv import FlashFFTConv
except ImportError as e:
    flash_fft_available = False
    print(f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation.")
else:
    flash_fft_available = True


In [4]:
import argparse
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
from model_550m import STU, flash_convolve
import time
import random
from torch.nn import functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from lds import LDS

In [6]:
# Experiment configuration
layer_i = 2          # index of the STU layer checkpoint to load
state_dim = 1000     # first argument passed to LDS(...)
seq_len = 512        # length of the impulse response that is fitted
kx = 5               # last argument passed to LDS(...); also the number of M[:, :, i] terms used in training
lr = 0.0001          # Adam learning rate
epochs = 5000        # number of optimisation steps

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the layer i weights.
# The checkpoint holds a whole pickled module (it is used via .eval() and
# .stu_filters below), not a state_dict, so weights_only must be False.
# Passing it explicitly also keeps this working on PyTorch versions where
# weights_only defaults to True.
# SECURITY: this unpickles arbitrary objects — only load checkpoints you trust.
stu_layer_full = torch.load(
    f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt",
    map_location=device,
    weights_only=False,
)
stu_layer_full.eval()

# Initialize LDS model
# NOTE(review): 896 is hard-coded for both middle arguments; it matches the
# (512, 896, 896) impulse shape printed later in this notebook — confirm the
# argument order (input vs. output dim) against the LDS definition.
lds = LDS(state_dim, 896, 896, kx).to(device)
optimizer = torch.optim.Adam(lds.parameters(), lr=lr)

# Training
lds_loss_values = []  # per-epoch loss history, appended to by the training loop

best_loss = float('inf')

  stu_layer_full = torch.load(f"../stu_layers/stu_layer_{layer_i}_550m_param_full.pt", map_location=device)


In [7]:
# Keep a handle on the loaded layer's `stu_filters` attribute (the same
# attribute that is written to "phi.pth" further down).
phi = stu_layer_full.stu_filters

In [8]:
def gen_stu_impulse_approx(stu, seq_len=1000):
    """
    Generate the impulse response of a STU model, one input channel at a time.

    For every input dimension ``i`` a unit impulse is placed at time step 0 of
    channel ``i`` and pushed through the model. When ``stu.use_approx`` is set,
    the projected path (``M_inputs`` / ``M_filters`` + convolution) is evaluated
    directly; otherwise the model's own forward pass is used.

    Args:
        stu: The STU model. Must expose ``d_in``, ``d_out`` and ``use_approx``;
            in the approximation path it must also expose ``M_inputs``,
            ``stu_filters``, ``M_filters``, ``flash_fft``, ``use_hankel_L`` and
            (for the non-flash fallback) ``n``.
        seq_len: Length of the impulse response

    Returns:
        impulse_response: NumPy array with shape (seq_len, d_out, d_in), where
            ``[:, :, i]`` is the response to an impulse on input channel ``i``.

    Raises:
        ImportError: if the non-flash fallback is required and ``model_550m``
            does not provide ``convolve``.
    """
    d_in = stu.d_in
    d_out = stu.d_out

    # Choose the working device: prefer M_inputs, then the module's first
    # parameter (so the non-approximation path does not feed a CPU tensor to a
    # GPU model), and only then fall back to CPU.
    if hasattr(stu, 'M_inputs'):
        device = stu.M_inputs.device
    else:
        params = stu.parameters() if hasattr(stu, 'parameters') else iter(())
        first_param = next(iter(params), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Single reusable impulse buffer (batch size 1) and the output tensor
    impulse = torch.zeros((1, seq_len, d_in), device=device)
    impulse_response = torch.zeros((seq_len, d_out, d_in), device=device)

    with torch.no_grad():
        if stu.use_approx:
            # These projections do not depend on the impulse channel, so they
            # are computed once instead of d_in times.
            m_inputs = stu.M_inputs.float()
            phi_proj = stu.stu_filters.float() @ stu.M_filters.float()

            if not stu.flash_fft:
                # `convolve` is not imported at the top of this notebook, so the
                # fallback used to fail with NameError. It is assumed to live
                # next to `flash_convolve` in model_550m — TODO confirm.
                from model_550m import convolve

        # For each input dimension, create an impulse and get the response
        for i in range(d_in):
            impulse.zero_()
            impulse[:, 0, i] = 1.0

            if stu.use_approx:
                # Project the impulse using M_inputs
                impulse_proj = impulse @ m_inputs

                # Compute the convolution
                if stu.flash_fft:
                    spectral_plus, spectral_minus = flash_convolve(
                        impulse_proj, phi_proj, stu.flash_fft, stu.use_approx
                    )
                else:
                    spectral_plus, spectral_minus = convolve(
                        impulse_proj, phi_proj, stu.n, stu.use_approx
                    )

                # With use_hankel_L only the "plus" branch contributes
                response = spectral_plus if stu.use_hankel_L else spectral_plus + spectral_minus
            else:
                # For non-approximation case, use the original forward pass
                response = stu(impulse)

            # Drop the batch axis and store the (seq_len, d_out) response
            impulse_response[:, :, i] = response.squeeze(0).float()

    return impulse_response.cpu().numpy()


In [9]:
# Impulse response of the loaded STU layer, converted from NumPy to a float32
# tensor on the configured `device` (instead of an unconditional .cuda(), which
# fails on CPU-only machines even though `device` already falls back to CPU).
stu_impulse = gen_stu_impulse_approx(stu_layer_full, seq_len=seq_len)
stu_impulse = torch.as_tensor(stu_impulse, dtype=torch.float32, device=device)

In [10]:
# Move the STU impulse response to CPU and write it to disk.
# NOTE(review): the "2" in the filename presumably mirrors layer_i = 2 — confirm,
# and keep it in sync if layer_i is changed.
torch.save(stu_impulse.cpu(), "filter_2_impulse.pth")

In [11]:
# Write a CPU copy of the layer's `stu_filters` to "phi.pth".
torch.save(stu_layer_full.stu_filters.cpu(), "phi.pth")

In [12]:
# Sanity check: the layout is (seq_len, d_out, d_in), as allocated in
# gen_stu_impulse_approx.
stu_impulse.shape

torch.Size([512, 896, 896])

In [None]:
import torch
import torch.nn.functional as F

# Fit the LDS parameters so that the LDS impulse response matches `stu_impulse`.
for epoch in range(epochs):
    optimizer.zero_grad()

    # Get model parameters; the transposes are the same for every time step,
    # so take them once per epoch rather than once per step.
    A = lds.A
    B_t = lds.B.T
    C_t = lds.C.T
    M = lds.M

    # Accumulate sum_i || C.T @ (A^i * B.T) + M[:, :, i] - stu_impulse[i] ||^2
    running_loss = 0.0

    for i in range(seq_len):
        # A is applied elementwise: A**i is reshaped to a column and scales the
        # rows of B.T (presumably A holds the diagonal of the state matrix —
        # confirm against the LDS definition).
        x = (A**i).reshape(-1, 1) * B_t
        y_pred = C_t @ x

        # Add M[:,:,i] for the first kx steps
        if i < kx:
            y_pred = y_pred + M[:, :, i]

        # Sum of squared errors against the target response at time i
        running_loss = running_loss + torch.sum((y_pred - stu_impulse[i]) ** 2)

    # Average over time steps only: this is the per-step sum of squared errors,
    # not an element-wise mean.
    total_loss = running_loss / seq_len
    total_loss.backward()

    torch.nn.utils.clip_grad_norm_(lds.parameters(), max_norm=1)
    loss_value = total_loss.item()
    lds_loss_values.append(loss_value)
    optimizer.step()

    # Project A back into [-1, 1] after the update. Done in place under no_grad
    # rather than through `.data`, which bypasses autograd's bookkeeping.
    with torch.no_grad():
        lds.A.clamp_(min=-1, max=1)

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss_value}")

Epoch 0, Loss: 1.3863502740859985
Epoch 10, Loss: 1.297054648399353
Epoch 20, Loss: 1.2133961915969849
Epoch 30, Loss: 1.1305699348449707
Epoch 40, Loss: 1.0477790832519531
Epoch 50, Loss: 0.9660797119140625
Epoch 60, Loss: 0.8869556188583374
Epoch 70, Loss: 0.8116658926010132
Epoch 80, Loss: 0.7410813570022583
Epoch 90, Loss: 0.675707995891571
Epoch 100, Loss: 0.6157510280609131
Epoch 110, Loss: 0.5611893534660339
Epoch 120, Loss: 0.5118350386619568
Epoch 130, Loss: 0.4673987030982971
Epoch 140, Loss: 0.4275326430797577
Epoch 150, Loss: 0.3918716013431549
Epoch 160, Loss: 0.36004307866096497
Epoch 170, Loss: 0.33168020844459534
Epoch 180, Loss: 0.30642545223236084
Epoch 190, Loss: 0.2839439809322357
Epoch 200, Loss: 0.2639216184616089
Epoch 210, Loss: 0.24607297778129578
Epoch 220, Loss: 0.23013825714588165
Epoch 230, Loss: 0.2158854454755783
Epoch 240, Loss: 0.20310872793197632
Epoch 250, Loss: 0.19162611663341522
Epoch 260, Loss: 0.18127740919589996
Epoch 270, Loss: 0.17192344367504

In [None]:
# Save only the LDS parameters (state_dict), not the module object itself.
# NOTE(review): the filename says "10k" while state_dim is 1000 in the config
# cell — confirm the name is intended.
torch.save(lds.state_dict(), "lds_10k_5_impulse.pth")

In [None]:
import matplotlib.pyplot as plt

# Impulse response of the fitted LDS over the same number of steps as the STU reference
with torch.no_grad():
    lds_impulse = lds.impulse(seq_len=stu_impulse.shape[0])

# Report both shapes so a mismatch is visible before they are subtracted
print(f"LDS impulse shape: {lds_impulse.shape}")
print(f"STU impulse shape: {stu_impulse.shape}")

# Element-wise mean squared error between the two impulse responses
mse = torch.mean((lds_impulse - stu_impulse) ** 2)
print(f"Mean Squared Error between LDS and STU impulse: {mse.item()}")

# Input/output channel pairs to plot
input_idx = 10           # input channel with index 10
output_indices = [0, 1]  # first two output channels

# One row per output channel; squeeze=False keeps `axes` two-dimensional
fig, axes = plt.subplots(len(output_indices), 1, figsize=(12, 8), squeeze=False)
for i, output_idx in enumerate(output_indices):
    ax = axes[i, 0]

    lds_curve = lds_impulse[:, output_idx, input_idx].cpu().numpy()
    stu_curve = stu_impulse[:, output_idx, input_idx].cpu().numpy()

    # LDS first, then STU, so the line colours keep the same order
    ax.plot(lds_curve, label=f'LDS Impulse (out={output_idx}, in={input_idx})')
    ax.plot(stu_curve, label=f'STU Impulse (out={output_idx}, in={input_idx})')

    ax.set_title(f'Impulse Response: Output {output_idx}, Input {input_idx}')
    ax.set_xlabel('Time step')
    ax.set_ylabel('Response')
    ax.legend()
    ax.grid(True)

fig.tight_layout()
plt.show()

In [None]:
# Test how the models respond to Gaussian input
import torch
import numpy as np
import matplotlib.pyplot as plt

# Generate Gaussian input sequence
# A dedicated name is used for this length: rebinding the global `seq_len`
# (512, used by the training loop to index stu_impulse) to 1000 would make a
# later re-run of the training cell index past the end of stu_impulse.
gaussian_seq_len = 1000
input_dim = stu_impulse.shape[2]  # Get input dimension from the impulse shape
batch_size = 1

# Create random Gaussian input
np.random.seed(42)  # For reproducibility
gaussian_input = torch.tensor(
    np.random.normal(0, 1, (batch_size, gaussian_seq_len, input_dim)),
    dtype=torch.float32,
).to(device)

# Run both models on the same input
with torch.no_grad():
    # Get LDS response to Gaussian input
    lds_response = lds(gaussian_input)

    # The STU output is reconstructed as the causal convolution of the input
    # with the stored impulse response:
    #   y[b, t] = sum_{tau <= t} stu_impulse[tau] @ u[b, t - tau]
    # Only the stu_impulse.shape[0] stored lags are used, so longer lags are
    # treated as zero.
    stu_response = torch.zeros(
        (batch_size, gaussian_seq_len, stu_impulse.shape[1]),
        dtype=torch.float32,
        device=device,
    )

    # One batched matmul per lag instead of one small matmul per (b, t, tau):
    # lag tau contributes stu_impulse[tau] @ u[:, s] to output step s + tau.
    num_lags = min(gaussian_seq_len, stu_impulse.shape[0])
    for tau in range(num_lags):
        stu_response[:, tau:] += (
            gaussian_input[:, : gaussian_seq_len - tau] @ stu_impulse[tau].T
        )

# Compute MSE between responses
response_mse = torch.mean((lds_response - stu_response) ** 2)
print(f"MSE between LDS and STU responses to Gaussian input: {response_mse.item()}")

# Visualize a few output dimensions
output_indices = [0, 1]  # First two output dimensions

# One row per output channel; squeeze=False keeps `axes` two-dimensional
fig, axes = plt.subplots(len(output_indices), 1, figsize=(12, 8), squeeze=False)
for i, output_idx in enumerate(output_indices):
    ax = axes[i, 0]

    # Plot LDS response
    ax.plot(lds_response[0, :, output_idx].cpu().numpy(),
            label=f'LDS Response (out={output_idx})')

    # Plot STU response
    ax.plot(stu_response[0, :, output_idx].cpu().numpy(),
            label=f'STU Response (out={output_idx})')

    ax.set_title(f'Response to Gaussian Input: Output {output_idx}')
    ax.set_xlabel('Time step')
    ax.set_ylabel('Response')
    ax.legend()
    ax.grid(True)

fig.tight_layout()
plt.show()