In [1]:
import sys
import os
import tqdm
import argparse
import json
from time import time
import torch
from torch import nn
import torch.nn.functional as F
import tiktoken
import yaml 
from safetensors import safe_open
import numpy as nps
import random
from collections import defaultdict
from tiktoken.load import load_tiktoken_bpe
from model import FlashSTU, FlashSTUConfig
from flash_stu.utils.stu_utils import get_spectral_filters
from flash_stu.utils.random_utils import get_logger, save_yaml_config
import math
from typing import Union
import gc


  from .autonotebook import tqdm as notebook_tqdm


Unable to import Triton-based MLP: No module named 'liger_kernel'. Falling back to vanilla SwiGLU MLP instead.
Unable to import Triton-based RMSNorm: No module named 'liger_kernel'. Falling back to PyTorch implementation.
Unable to import Triton-based MLP: No module named 'liger_kernel'. Falling back to vanilla SwiGLU MLP instead.
Unable to import Triton-based RMSNorm: No module named 'liger_kernel'. Falling back to PyTorch implementation.
Unable to import Triton-based RMSNorm: No module named 'liger_kernel'. Falling back to PyTorch implementation.


In [2]:
# Local BPE ranks file for the o200k_base vocabulary; the tokenizer cell reads it.
bpe_path = "./o200k_base.tiktoken"
# Notebook-level logger from the project's logging helper.
logger = get_logger(__name__)

In [3]:
from inference import (
    set_initial_random_seed,
    apply_compile,
    load_stu_model,
    generate_text,
    generate_and_time,
    main
)

In [4]:
# Inputs for this run: eval settings (YAML), model checkpoint, model hyper-parameters (JSON).
# NOTE(review): the eval config is taken from baseline_w_attn_bfloat16 while the
# model config is taken from lds_bfloat16_w_attn. Confirm this pairing is intended.
eval_path, checkpoint_path, config_path = (
    "./configs/baseline_w_attn_bfloat16/eval.yaml",
    "./model_step-114000.safetensors",
    "./configs/lds_bfloat16_w_attn/config.json",
)

In [5]:
# Load the evaluation settings (YAML) and the model hyper-parameters (JSON).
# Both files are opened with `with` so their handles are closed right away
# (the previous bare open() leaked the YAML file handle). safe_load is enough
# for a plain config and refuses arbitrary Python object tags.
with open(eval_path, "r") as f:
    eval_config = yaml.safe_load(f)
with open(config_path, "r") as f:
    model_config = json.load(f)

# Pick a random run id that names this run's output directory. It is drawn
# BEFORE seeding below, so it does not depend on `random_seed`.
# NOTE(review): if set_initial_random_seed also seeds Python's `random`, moving
# this draw after seeding would give every run with the same seed the same
# directory. Keep this order.
run_id = random.randint(0, 10 ** 6)

# Set random seed for reproducibility (-1 is passed when the config has none;
# how -1 is interpreted is up to set_initial_random_seed).
set_initial_random_seed(eval_config.get('random_seed', -1))

# Create the per-run output dir and snapshot both configs into it.
save_dir = eval_config.get('save_dir')
if save_dir is None:
    # Fail clearly instead of with os.path.join's TypeError on None.
    raise KeyError("eval config has no 'save_dir'; cannot create an output directory")
save_path_for_this_exp = os.path.join(save_dir, str(run_id))
os.makedirs(save_path_for_this_exp, exist_ok=True)
logger.info(f"For this experiment, saving path is: {save_path_for_this_exp}")
logger.info(f"eval_config: {eval_config}")
logger.info(f"model_config: {model_config}")
save_yaml_config(eval_config, save_path_for_this_exp, "eval_config.yaml")
save_yaml_config(model_config, save_path_for_this_exp, "model_config.yaml")

def _resolve_futurefill_k(raw_k, max_lengths):
    """Turn the eval config's `futurefill_k` entry into an int or None.

    Args:
        raw_k: value of `futurefill_k` from the eval config.
            - None (key absent) is returned unchanged.
            - Otherwise it may be a scalar or a list. For a list, only the LAST
              entry is used.
            - An int is taken as-is.
            - The string "None" (what a bare `None` parses to in YAML) or a
              YAML null means "derive k from the longest generation length".
        max_lengths: the eval config's `max_length`, an int or a list of ints.
            Only read when k has to be derived.

    Returns:
        None, or an int k. A derived k is int(sqrt(L * log2(L))) with
        L = the longest entry of `max_lengths`, clamped to at least 1
        (L == 1 would otherwise give k == 0).

    Raises:
        ValueError: `futurefill_k` is an empty list or has an unsupported
            value, or k must be derived but `max_length` is missing, empty,
            or < 1.
    """
    if raw_k is None:
        return None
    if isinstance(raw_k, (list, tuple)):
        if not raw_k:
            raise ValueError("futurefill_k is an empty list")
        raw_k = raw_k[-1]
    # bool is a subclass of int; a stray True/False is not a valid k.
    if isinstance(raw_k, int) and not isinstance(raw_k, bool):
        return raw_k
    if raw_k is None or raw_k == "None":
        if max_lengths is None:
            raise ValueError("futurefill_k must be derived but 'max_length' is missing")
        lengths = list(max_lengths) if isinstance(max_lengths, (list, tuple)) else [max_lengths]
        if not lengths or max(lengths) < 1:
            raise ValueError(f"Cannot derive futurefill_k from max_length={max_lengths!r}")
        generation_L = max(lengths)
        return max(1, int(math.sqrt(generation_L * math.log2(generation_L))))
    raise ValueError(f"Unsupported futurefill_k value: {raw_k!r}")


futurefill_k = _resolve_futurefill_k(
    eval_config.get("futurefill_k", None), eval_config.get("max_length")
)

# Optional LDS settings from the eval config; each is None when the key is absent.
lds_state_dim, lds_path = (eval_config.get(key) for key in ("lds_state_dim", "lds_path"))

# The model is loaded onto the CUDA device (no CPU fallback here).
device = torch.device("cuda")

2025-04-09 16:19:33,147 - __main__ - INFO - For this experiment, saving path is: outputs/april7/543186
INFO:__main__:For this experiment, saving path is: outputs/april7/543186
2025-04-09 16:19:33,149 - __main__ - INFO - eval_config: {'random_seed': 74, 'run_name': 'baseline_attn_bfloat16', 'save_dir': 'outputs/april7/', 'input_length': [0], 'max_length': [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], 'temperature': 1, 'top_k': 1, 'repeat': 3, 'cache': True, 'debug': False}
INFO:__main__:eval_config: {'random_seed': 74, 'run_name': 'baseline_attn_bfloat16', 'save_dir': 'outputs/april7/', 'input_length': [0], 'max_length': [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072], 'temperature': 1, 'top_k': 1, 'repeat': 3, 'cache': True, 'debug': False}
2025-04-09 16:19:33,151 - __main__ - INFO - model_config: {'model_type': 'FlashSTU', 'dim': 896, 'num_heads': 8, 'num_layers': 12, 'seq_len': 8192, 'weight_tying': True, 'window_size': 102

In [6]:
# Build the model from `model_config` and load the checkpoint onto `device`.
# The FutureFill k and the LDS options come from the eval config cell.
model, config_data = load_stu_model(
    model_config,
    checkpoint_path,
    device,
    futurefill_k=futurefill_k,
    lds_state_dim=lds_state_dim,
    lds_path=lds_path,
)

Model Parameter Count: 550.31M



  return t.to(
2025-04-09 16:19:45,336 - inference - INFO - Loading checkpoint from ./model_step-114000.safetensors...
INFO:inference:Loading checkpoint from ./model_step-114000.safetensors...
2025-04-09 16:19:45,610 - inference - INFO - Checkpoint loaded in 0.27 seconds.
INFO:inference:Checkpoint loaded in 0.27 seconds.
2025-04-09 16:19:45,615 - inference - INFO - Model weights loaded successfully!
INFO:inference:Model weights loaded successfully!


In [7]:
# Create tokenizer (for della)
bpe_dict = load_tiktoken_bpe(bpe_path)
tokenizer = tiktoken.Encoding(
    name="o200k_base",  # Name of the encoding
    pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+""",
    mergeable_ranks=bpe_dict,
    special_tokens={
        "<|endoftext|>": 199999,  # Custom special token example (modify as needed)
        "<|endofprompt|>": 200018,
    }
)

In [8]:
# Timed generation sweep over the configured max_length values. Per the logs,
# each length is generated `repeat` times and its runtime is logged.
# NOTE(review): max_length goes up to 131072 in this config, so a full sweep is
# long. The saved run was interrupted (KeyboardInterrupt) at the first length.
sweep_args = (model, tokenizer, eval_config, save_path_for_this_exp, device)
generate_and_time(*sweep_args)

2025-04-09 16:19:49,440 - inference - INFO - Generating text for prompt 1 of length 0. Max generation is 32
INFO:inference:Generating text for prompt 1 of length 0. Max generation is 32
100%|██████████| 28/28 [00:02<00:00, 11.16it/s]
2025-04-09 16:19:52,021 - inference - INFO - Current runtime: 2.516732931137085
INFO:inference:Current runtime: 2.516732931137085
2025-04-09 16:19:52,022 - inference - INFO - Output: ['Obama is famous for CMUN "cook "cook "cook "cook "cook "cook "cook "cook "cook "cook "cook "cook "cook']
INFO:inference:Output: ['Obama is famous for CMUN "cook "cook "cook "cook "cook "cook "cook "cook "cook "cook "cook "cook "cook']
2025-04-09 16:19:52,024 - inference - INFO - Generating text for prompt 1 of length 0. Max generation is 32
INFO:inference:Generating text for prompt 1 of length 0. Max generation is 32
100%|██████████| 28/28 [00:02<00:00, 11.91it/s]
2025-04-09 16:19:54,431 - inference - INFO - Current runtime: 2.3539979457855225
INFO:inference:Current runtime:

KeyboardInterrupt: 