In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import tiktoken
import logging
import json
from time import time
from safetensors import safe_open
from model import FlashSTU
from config import FlashSTUConfig

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def get_hankel(seq_len: int, use_hankel_L: bool = False) -> np.ndarray:
    entries = np.arange(1, seq_len + 1, dtype=np.float64)
    i_plus_j = entries[:, None] + entries[None, :]

    if use_hankel_L:
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        Z = sgn * (8.0 / denom)
    elif not use_hankel_L:
        Z = 2.0 / (i_plus_j**3 - i_plus_j)
    else:
        raise ValueError("use_hankel_L must be a boolean")

    return Z

def get_spectral_filters(
    seq_len: int, 
    K: int, 
    use_hankel_L: bool = False, 
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    assert torch.cuda.is_available(), "CUDA is required."
    Z = get_hankel(seq_len, use_hankel_L)
    sigma, phi = np.linalg.eigh(Z)
    sigma_k, phi_k = sigma[-K:], phi[:, -K:]
    phi_k *= sigma_k ** 0.25
    filters = torch.from_numpy(phi_k)
    return filters.to(device=device, dtype=dtype)

# Load the checkpoint
print("Loading the checkpoint...")
start_time = time()
state_dict = {}
with safe_open(
    "model_19073.safetensors",
    framework="pt",
    device="cuda",
) as f:
    for k in f.keys():
        state_dict[k] = f.get_tensor(k)

print(f"Successfully loaded the checkpoint in {time() - start_time:.2f} seconds")


  from .autonotebook import tqdm as notebook_tqdm


Unable to import FlashFFTConv: No module named 'flashfftconv'. Falling back to PyTorch implementation.
Unable to import Triton-based MLP: No module named 'liger_kernel'. Falling back to vanilla SwiGLU MLP instead.
Unable to import Triton-based RMSNorm: No module named 'liger_kernel'. Falling back to PyTorch implementation.
Unable to import Triton-based MLP: No module named 'liger_kernel'. Falling back to vanilla SwiGLU MLP instead.
Unable to import Triton-based RMSNorm: No module named 'liger_kernel'. Falling back to PyTorch implementation.
Unable to import Triton-based RMSNorm: No module named 'liger_kernel'. Falling back to PyTorch implementation.
Using device: cuda
Loading the checkpoint...
Successfully loaded the checkpoint in 1.93 seconds


In [2]:
# Set precision for matrix multiplication
torch.set_float32_matmul_precision("high")

# Load model configurations from JSON file
with open("config.json", "r") as file:
    config = json.load(file)

# Extract model configurations
n_embd = config["n_embd"]
n_heads = config["n_heads"]
n_layers = config["n_layers"]
seq_len = config["seq_len"]
window_size = config["window_size"]
vocab_size = config["vocab_size"]
mlp_scale = config["mlp_scale"]
bias = config["bias"]
dropout = config["dropout"]
num_eigh = config["num_eigh"]
use_hankel_L = config["use_hankel_L"]
use_flash_fft = config["use_flash_fft"]
use_approx = config["use_approx"]
use_attn = config["use_attn"]
softcap = config["softcap"]

# Model setup
config = FlashSTUConfig(
    n_embd=n_embd,
    n_heads=n_heads,
    n_layers=n_layers,
    seq_len=seq_len,
    window_size=window_size,
    vocab_size=vocab_size,
    mlp_scale=mlp_scale,
    bias=bias,
    dropout=dropout,
    num_eigh=num_eigh,
    use_hankel_L=use_hankel_L,
    use_flash_fft=use_flash_fft,
    use_approx=use_approx,
    use_attn=use_attn,
    softcap=softcap,
    torch_dtype=getattr(torch, config["torch_dtype"]),
)
phi = get_spectral_filters(seq_len, num_eigh, use_hankel_L, device, torch.float32)
model = FlashSTU(config, phi)

# Load state dictionary into the model
model.load_state_dict(state_dict, strict = True)
model.to(device)
model.eval()

# Prepare tokenizer
tokenizer = tiktoken.get_encoding("o200k_base")

Model Parameter Count: 426.28M



In [3]:
import tqdm
def generate_text(
    model, tokenizer, prompt, num_return_sequences=4, max_length=1024, device="cuda", temperature=1.0, top_k=50
):
    model.eval()
    tokens = torch.tensor([tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})], device=device)
    tokens = tokens.repeat(num_return_sequences, 1)

    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(1337)

    eos_token_id = tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})[0]

    with torch.no_grad():
        for _ in tqdm.tqdm(range(max_length - tokens.size(1))):
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = model(tokens)
                logits = logits[:, -1, :]  # Get logits for the last token

                # Apply temperature scaling if temperature > 0
                if temperature > 0:
                    logits = logits / temperature

            probs = F.softmax(logits, dim=-1)  # Compute probabilities

            # Top-K sampling: set all probabilities outside the top K to 0
            top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
            ix = torch.multinomial(top_k_probs, 1, generator=sample_rng)
            next_token = torch.gather(top_k_indices, -1, ix)
            tokens = torch.cat((tokens, next_token), dim=1)

            # Break if EOS token is generated
            if (next_token == eos_token_id).any():
                break

    generated_sequences = []
    for i in range(num_return_sequences):
        decoded = tokenizer.decode(tokens[i].tolist())
        generated_sequences.append(decoded)

    return generated_sequences

In [4]:
prompts = [
    "The future of artificial intelligence is",
    # "In the year 2050, the world will",
    # "The most important scientific discovery of the 21st century is",
    # "If I could change one thing about the education system, it would be",
    # "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.",
]

for prompt in prompts:
    print(f"\nGenerating text for prompt: '{prompt}'\n")
    generated_texts = generate_text(model, tokenizer, prompt, num_return_sequences=1, max_length=50)
    for i, text in enumerate(generated_texts):
        print(f"Sample {i + 1}: {text}\n")



Generating text for prompt: 'The future of artificial intelligence is'



100%|██████████| 44/44 [00:05<00:00,  8.77it/s]

Sample 1: The future of artificial intelligence is also a bit controversial in the political area. The idea of artificial intelligence is actually a part of people who might not have a clear idea of how can we make this technology come to life. People that already have computers can






In [5]:
import torch
from  torch import nn
class LDS(nn.Module):
    def __init__(self, state_dim, input_dim, output_dim):
        super(LDS, self).__init__()
        
        self.d_out = output_dim

        self.h0 = nn.Parameter(torch.randn(state_dim))
        # self.A = nn.Parameter(torch.clip(torch.randn(state_dim), -.7, 0.7))
        init_A = torch.randn(state_dim)
        self.A = nn.Parameter(init_A/torch.max(torch.abs(init_A)))
        self.B = nn.Parameter(torch.randn(input_dim, state_dim) / input_dim)
        self.C = nn.Parameter(torch.randn(state_dim,output_dim) / state_dim)
        self.D = nn.Parameter(torch.randn(input_dim,output_dim) / output_dim) #keep for more complex systems

        self.M = nn.Parameter(torch.randn(output_dim, output_dim) / output_dim) #autoregressive
    def forward(self, inputs):
        bsz, seq_len, _ = inputs.shape
        h_t = self.h0.expand(bsz, self.h0.shape[0]).to(device)  # Ensure h0 is on the correct device
        outputs = []
        A = self.A.flatten()

        # Store all intermediate h_t states
        all_h_t = []

        for t in range(seq_len):
            u_t = inputs[:, t, :]  # Get input for all batches at time t
            h_t = A * h_t + u_t @ self.B  # Update hidden states for all batches
            all_h_t.append(h_t.unsqueeze(1))  # Store the updated h_t for each time step

        # Concatenate all h_t states along the time dimension
        all_h_t = torch.cat(all_h_t, dim=1)

        # Apply C to all concatenated h_t states at once
        outputs = torch.matmul(all_h_t, self.C)

        return outputs


    def compute_loss(self, inputs, targets):
        mse_loss = nn.MSELoss()
        outputs = self(inputs)
        return mse_loss(outputs, targets)

In [6]:
# files  =  ["./lds_trained/1895_interim_lds_model_and_optimizer.pt",
#            "./lds_trained/8562_interim_lds_model_and_optimizer.pt",
#            "./lds_trained/7452_4_20000_0.00306_lds_model_and_optimizer.pt",
#            "./lds_trained/6353_6_20000_0.00283_lds_model_and_optimizer.pt",
#            "./lds_trained/8167_10_20000_0.00160_lds_model_and_optimizer.pt"
#            ]

# files  =  ["./lds_trained/4512_0_80000_interim_lds_model_and_optimizer.pt",
#            "./lds_trained/3353_2_80000_interim_lds_model_and_optimizer.pt",
#            "./lds_trained/7452_4_20000_0.00306_lds_model_and_optimizer.pt",
#            "./lds_trained/6353_6_20000_0.00283_lds_model_and_optimizer.pt",
#            "./lds_trained/8167_10_20000_0.00160_lds_model_and_optimizer.pt"
#            ]


files  =  ["./lds_trained/4512_0_80000_interim_lds_model_and_optimizer.pt"]

lds_layers = {}
for idx, layer_idx in enumerate([0]): #], 2, 4, 6, 10]):
    lds_layers[layer_idx] = LDS(80000 if idx < 2 else 20000, 768, 768).to(device)
    checkpoint = torch.load(files[idx])
    lds_layers[layer_idx].load_state_dict(checkpoint["lds_state_dict"])

  checkpoint = torch.load(files[idx])


In [13]:
from dataloader import DataLoader

val_loader = DataLoader(
    bsz=1,
    seq_len=seq_len, 
    dataset='./fineweb-edu', 
    split="val", 
    main_process=True,
)

2024-12-07 16:34:14,939 - INFO - Found 1 shards for split val


In [9]:
from torch.amp import autocast
from torch.nn import CrossEntropyLoss
import tqdm

def evaluate(model):
    loss_fn = CrossEntropyLoss()
    val_loss = 0.0
    torch_dtype = getattr(torch, 'bfloat16')
    val_steps = 20 # Arbitrarily set to reduce long evaluations
    model.eval()
    val_loader.reset()
    with torch.no_grad():
        for i, batch in tqdm.tqdm(zip(range(val_steps), val_loader, strict=False)):
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            if torch_dtype != torch.float32:
                with autocast(device_type=device.type, dtype=torch_dtype, cache_enabled=True):
                    preds = model(inputs)
            else:
                preds = model(inputs)

            loss = loss_fn(preds.flatten(0, 1), targets.flatten(0, 1))
            loss = loss / val_steps
            val_loss += loss.detach().float()
    return(val_loss)

In [15]:
lds_layers

{0: LDS(), 2: LDS(), 4: LDS(), 6: LDS(), 10: LDS()}

In [10]:
import torch
from torch import nn

class Identity(nn.Module):
    def forward(self, x):
        return x

In [None]:
import torch
import  copy
import torch.nn.functional as F

# Generate random input
random_input = torch.randn(10, 50, n_embd).to(device)
stu = copy.deepcopy(model.layers[0].stu)
stu.phi = stu.phi.to(torch.bfloat16)
# Get the output from lds_layers[0]
lds_layers[0].eval()
with torch.no_grad():
    lds_output = lds_layers[0](random_input)

# Get the output from model.layers[0].stu
model.layers[0].stu.eval()
with torch.no_grad():
    stu_output = stu(random_input.to(torch.bfloat16))

# Compute the error
error = F.mse_loss(lds_output, stu_output)
print(f"Error between lds_layers[0] and model.layers[0].stu: {error.item()}")

Error between lds_layers[0] and model.layers[0].stu: 7.877873576944694e-05


In [14]:
import copy
model2 = copy.deepcopy(model)

for idx in [0]:
    lds_layers[idx].eval()
# Substitute LDS models into model2
model2.layers[0].stu = lds_layers[0]
# model2.layers[2].stu = lds_layers[2]
# model2.layers[4].stu = lds_layers[4]
# model2.layers[6].stu = lds_layers[6]
# model2.layers[8].stu = Identity()
# model2.layers[10].stu = lds_layers[10]


# Run evaluation
val_loss = evaluate(model2) #baseline of 3.5
print(f"Validation Loss: {val_loss}")

20it [08:38, 25.92s/it]


Validation Loss: 4.08203125


In [39]:
prompts = [
    "The future of artificial intelligence is",
    # "In the year 2050, the world will",
    # "The most important scientific discovery of the 21st century is",
    # "If I could change one thing about the education system, it would be",
    # "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.",
]

for prompt in prompts:
    print(f"\nGenerating text for prompt: '{prompt}'\n")
    generated_texts = generate_text(model2, tokenizer, prompt, num_return_sequences=1, max_length=50)
    for i, text in enumerate(generated_texts):
        print(f"Sample {i + 1}: {text}\n")



Generating text for prompt: 'The future of artificial intelligence is'



100%|██████████| 44/44 [00:01<00:00, 24.59it/s]

Sample 1: The future of artificial intelligence is already in my thoughts and I cannot write much more than that. By contrast, the artificial intelligence today would be able to do things like the Internet at the speed in a matter of seconds. Therein 1 day.






In [34]:
mean_list = []
var_list = []

model.eval()
val_loader.reset()
with torch.no_grad():
    for batch in tqdm.tqdm(val_loader):
        inputs, targets = batch
        inputs, targets = inputs.to(device).to(torch.bfloat16), targets.to(device).to(torch.bfloat16)
        with autocast(device_type=device.type, dtype=getattr(torch, 'bfloat16'), cache_enabled=True):
        # Forward pass through the layers up to layer 2
            x = model.tok_emb(inputs)
            x = model.dropout(x)
            x = model.layers[0](x)
            x = model.layers[1](x)
        
        # Collect inputs to STU layer 2
        stu_inputs = x.clone().detach().cpu().numpy()
        mean_list.append(np.mean(stu_inputs))
        var_list.append(np.var(stu_inputs))

mean_stu_layer_2 = np.mean(mean_list)
var_stu_layer_2 = np.mean(var_list)

print(f"Mean of inputs to STU layer 2: {mean_stu_layer_2}")
print(f"Variance of inputs to STU layer 2: {var_stu_layer_2}")

0it [00:00, ?it/s]


RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got CUDABFloat16Type instead (while checking arguments for embedding)