In [12]:
import sys
import os
sys.path.append(os.path.abspath("../../src"))

In [13]:
import argparse
import torch
from torch import nn
import matplotlib.pyplot as plt
from stu import STU
import time
import random
import numpy as np

In [14]:
from lds import LDS

In [15]:
def get_hankel(seq_len: int, use_hankel_L: bool = False) -> np.ndarray:
    """Build the seq_len x seq_len Hankel matrix used to derive the spectral filters.

    With 1-based indices i, j and s = i + j, entry (i, j) is:

    * ``use_hankel_L=False``: ``2 / (s**3 - s)``
    * ``use_hankel_L=True``:  ``((-1)**(s - 2) + 1) * 8 / ((s + 3) * (s - 1) * (s + 1))``
      (this is zero whenever s is odd)

    Args:
        seq_len: Side length of the (square) matrix.
        use_hankel_L: Selects which of the two matrix variants to build. Must be a
            real boolean (Python ``bool`` or ``numpy.bool_``).

    Returns:
        A float64 array of shape (seq_len, seq_len).

    Raises:
        ValueError: If ``use_hankel_L`` is not a boolean.
    """
    # A plain if/else cannot reject non-boolean values (everything is truthy or
    # falsy), so check the type up front to make the documented error reachable.
    if not isinstance(use_hankel_L, (bool, np.bool_)):
        raise ValueError("use_hankel_L must be a boolean")

    entries = np.arange(1, seq_len + 1, dtype=np.float64)
    # Outer sum: i_plus_j[a, b] = (a + 1) + (b + 1), i.e. the 1-based i + j.
    i_plus_j = entries[:, None] + entries[None, :]

    if use_hankel_L:
        # (-1)**(s-2) + 1 is 2 for even s and 0 for odd s.
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        return sgn * (8.0 / denom)

    return 2.0 / (i_plus_j**3 - i_plus_j)

def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Compute the top-K spectral filters for sequences of length ``seq_len``.

    The filters are the eigenvectors of the ``get_hankel`` matrix belonging to its
    K largest eigenvalues, each scaled by (eigenvalue) ** 0.25.

    Args:
        seq_len: Sequence length (side of the Hankel matrix).
        K: Number of filters to keep; must satisfy ``1 <= K <= seq_len``.
        use_hankel_L: Forwarded to ``get_hankel`` to select the matrix variant.
        device: Target device for the result; ``None`` leaves it on the CPU.
        dtype: Target dtype for the result.

    Returns:
        Tensor of shape (seq_len, K). Each column is one filter, ordered by
        ascending eigenvalue (the last column has the largest eigenvalue).

    Raises:
        ValueError: If ``K`` is outside ``[1, seq_len]``.
    """
    if not 1 <= K <= seq_len:
        # K == 0 would slice with [-0:] and silently return *all* filters, and
        # K > seq_len would silently return fewer than K.
        raise ValueError(f"K must be in [1, seq_len={seq_len}], got {K}")

    Z = get_hankel(seq_len, use_hankel_L)
    # eigh returns eigenvalues in ascending order, so the top K are at the end.
    sigma, phi = np.linalg.eigh(Z)
    sigma_k, phi_k = sigma[-K:], phi[:, -K:]
    # Z is symmetric, but round-off can push tiny eigenvalues slightly below zero;
    # clamp so the fractional power yields 0 instead of NaN.
    phi_k = phi_k * np.clip(sigma_k, 0.0, None) ** 0.25
    filters = torch.from_numpy(phi_k)
    return filters.to(device=device, dtype=dtype)


In [16]:
# Experiment configuration.
layer_i = 0
state_dim = 5000
batch_size = 2
epochs = 4000
kx = 5
lr = 0.0001

# Sequence length shared by the spectral filters, the STU and the random inputs.
# NOTE(review): the training cell hit a CUDA out-of-memory error with these
# settings (on a ~9.5 GiB GPU shared with other processes); consider shrinking
# seq_len, state_dim or batch_size if it recurs.
seq_len = 4096
num_eigh = 30
use_hankel_L = True

# Seed every RNG the notebook touches so Restart & Run All is reproducible.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

phi = get_spectral_filters(
    seq_len=seq_len,
    K=num_eigh,
    use_hankel_L=use_hankel_L,
    device=device,
    dtype=torch.float32,
)

stu_config = {
    "num_eigh": num_eigh,
    # Reuse the flag the filters were built with so the two cannot drift apart.
    "use_hankel_L": use_hankel_L,
    "torch_dtype": torch.float32,
    # NOTE(review): d_in/d_out are 1 while the training inputs and the LDS use
    # 768 features — confirm this matches how basic_stu.STU consumes its input.
    "d_in": 1,
    "d_out": 1,
    "seq_len": seq_len,
    "k_u": 0,
}


In [17]:
# Initialize LDS model
lds = LDS(state_dim, 768, 768, kx).to(device)
optimizer = torch.optim.Adam(lds.parameters(), lr=lr)

# Training
lds_loss_values = []

best_loss = float('inf')

In [18]:
from basic_stu import STU

In [19]:
# Instantiate the STU from the config and the precomputed spectral filters, then
# move it to the selected device (its outputs are used as the LDS loss targets).
phi_0 = STU(stu_config, phi)
phi_0 = phi_0.to(device)

In [20]:
# Zero both STU weight tensors in place. Using no_grad + zero_() instead of
# reassigning `.data` keeps the write visible to autograd's version tracking
# and preserves the parameters' existing storage.
with torch.no_grad():
    phi_0.M_phi_minus.zero_()
    phi_0.M_phi_plus.zero_()

In [21]:
# Set entry [0, 0, 0] of M_phi_plus to 1. Indexed assignment under no_grad writes
# into the parameter itself; assigning `.data` on `M_phi_plus[0][0][0]` would only
# rebind the storage of a temporary view and leave the parameter unchanged.
with torch.no_grad():
    phi_0.M_phi_plus[0, 0, 0] = 1.0

In [22]:
for epoch in range(epochs):
    # Fresh random inputs each epoch, generated directly on the target device
    # instead of on the CPU followed by a host-to-device copy.
    inputs = torch.randn(batch_size, seq_len, 768, device=device, dtype=torch.bfloat16)

    # The STU only supplies targets. Running it without autograd means its
    # activations are not kept alive for backward (lower peak memory), and
    # loss.backward() does not accumulate gradients into phi_0, which the
    # optimizer never zeroes because it only covers the LDS parameters.
    with torch.no_grad():
        stu_outputs = phi_0(inputs).to(device)

    optimizer.zero_grad()
    loss = lds.compute_loss(inputs.to(stu_outputs.dtype), stu_outputs)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(lds.parameters(), max_norm=1)
    optimizer.step()

    # Read the scalar once; each .item() forces a device synchronization.
    loss_value = loss.item()
    lds_loss_values.append(loss_value)

    # Keep the entries of lds.A within [-1, 1] after every step.
    with torch.no_grad():
        lds.A.clamp_(min=-1, max=1)

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss_value}")

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.81 GiB. GPU 0 has a total capacity of 9.50 GiB of which 2.30 GiB is free. Process 586991 has 7.16 GiB memory in use. Process 1491827 has 5.83 GiB memory in use. Process 3103330 has 700.00 MiB memory in use. Process 3102972 has 910.00 MiB memory in use. Process 3103824 has 932.00 MiB memory in use. Process 3102973 has 910.00 MiB memory in use. Including non-PyTorch memory, this process has 7.15 GiB memory in use. Of the allocated memory 5.69 GiB is allocated by PyTorch, and 1.37 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)