In [None]:
import os
import argparse
import glob
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F

from modeling_hyena import StripedHyenaModelForCausalLM

from datasets import load_from_disk, concatenate_datasets
from accelerate.utils import set_seed
from transformers import (
    set_seed as transformers_set_seed,
    PreTrainedTokenizerFast,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer
)

# Seed both the accelerate and the transformers RNG helpers with the same value so
# runs are reproducible.
SEED = 42
set_seed(SEED)
transformers_set_seed(SEED)

checkpoint = "../weights"
# First attempt: load using only what is stored in ../weights. This raises a bare
# NotImplementedError inside get_block (presumably because the stored config does not
# assign every layer index to an attention or Hyena block -- TODO confirm). The error is
# caught and shown, rather than left to abort Restart & Run All; the attempts with an
# explicit StripedHyenaConfig follow below.
try:
    model = StripedHyenaModelForCausalLM.from_pretrained(checkpoint)
except NotImplementedError as err:
    print(f"Loading {checkpoint!r} without an explicit config failed: {err!r}")

```
get_block
    raise NotImplementedError
NotImplementedError
```

In [None]:
# Second attempt: point from_pretrained at the weights file itself. from_pretrained
# only accepts a local model directory or a Hub repo id, so this raises
# OSError("Incorrect path_or_model_id ..."). Catch and show it so the notebook still
# runs top to bottom.
checkpoint = "../weights/model.safetensors"
try:
    model = StripedHyenaModelForCausalLM.from_pretrained(checkpoint)
except OSError as err:
    print(f"OSError: {err}")

```
OSError: Incorrect path_or_model_id: '../weights/model.safetensors'. Please provide either the path to a local folder or the repo_id of a model on the Hub.
```

In [None]:
from configuration_hyena import StripedHyenaConfig

# Explicit architecture for the checkpoint in ../weights, passed to from_pretrained
# instead of relying on the config stored next to the weights (loading without it
# failed in get_block above).
_num_layers = 16
_attn_layer_idxs = [4, 8, 12]

config_dict = dict(
    vocab_size=32,
    hidden_size=128,
    num_filters=128,
    inner_mlp_size=352,
    # Layers 4, 8 and 12 are attention layers; every other index in range(num_layers)
    # is a Hyena layer, so the two lists cover each layer exactly once.
    attn_layer_idxs=_attn_layer_idxs,
    hyena_layer_idxs=[idx for idx in range(_num_layers) if idx not in _attn_layer_idxs],
    num_layers=_num_layers,
    tie_embeddings=True,
    short_filter_length=3,
    num_attention_heads=16,
    proj_groups=1,
    hyena_filter_groups=1,
    split_k0=True,
    column_split_hyena=True,
    column_split=False,
    model_parallel_size=1,
    pipe_parallel_size=1,
    short_filter_bias=True,
    mha_out_proj_bias=True,
    qkv_proj_bias=True,
    final_norm=True,
    use_cache=False,
    use_flash_attention_2=True,
    use_flash_rmsnorm=True,
    use_flash_depthwise=False,
    use_flashfft=False,
    inference_mode=True,
    prefill_style="fft",
    max_seqlen=65536,
    eps=1e-5,
    state_size=8,
    rotary_emb_base=500_000,
    smeared_gqa=False,
    make_vocab_size_divisible_by=8,
    log_intermediate_values=False,
    bidirectional=False,
)

config = StripedHyenaConfig(**config_dict)
checkpoint = "../weights"
model = StripedHyenaModelForCausalLM.from_pretrained(checkpoint, config=config)

In [None]:
from configuration_hyena import StripedHyenaConfig
from safetensors.torch import load_file

# Build the model from `config_dict` defined in the config cell above, instead of a
# second hand-copied dict that could silently drift out of sync.
config = StripedHyenaConfig(**config_dict)
model = StripedHyenaModelForCausalLM(config)

checkpoint = "../weights/model.safetensors"
state_dict = load_file(checkpoint)

# With tie_embeddings=True the file stores only backbone.embedding_layer.weight, while
# the model's state_dict also expects backbone.unembed.weight. A strict load therefore
# fails with "Missing key(s)". Supply that key from the embedding matrix in a *copy*
# (so `state_dict` stays exactly what was read from disk), and keep strict=True so any
# other mismatch still raises.
state_dict_for_model = dict(state_dict)
if config_dict["tie_embeddings"]:
    state_dict_for_model.setdefault(
        "backbone.unembed.weight", state_dict["backbone.embedding_layer.weight"]
    )
model.load_state_dict(state_dict_for_model)

```
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for StripedHyenaModelForCausalLM:
Missing key(s) in state_dict: "backbone.unembed.weight".
```

In [None]:
# Compare only the top-level (non-"blocks") parameter names of the instantiated model
# and of the safetensors file; this is where the two disagree.
for source_label, source_state in (
    ("model state_dict:", model.state_dict()),
    ("model.safetensors:", state_dict),
):
    print(source_label, [key for key in source_state if "blocks" not in key])

```
model state_dict: ['backbone.embedding_layer.weight', 'backbone.norm.scale', 'backbone.unembed.weight']
model.safetensors: ['backbone.embedding_layer.weight', 'backbone.norm.scale']
```

In [None]:
# strict=False tolerates the tied backbone.unembed.weight key that the file omits.
# It would also silently tolerate any other mismatch, so check the result explicitly:
# - the tied key is the only key allowed to be missing;
# - there must be no unexpected keys;
# - the unembedding must actually hold the embedding weights, not a random init.
load_result = model.load_state_dict(state_dict, strict=False)
assert not load_result.unexpected_keys, f"unexpected keys: {load_result.unexpected_keys}"
assert set(load_result.missing_keys) <= {"backbone.unembed.weight"}, (
    f"missing keys: {load_result.missing_keys}"
)
loaded_state = model.state_dict()
assert torch.equal(
    loaded_state["backbone.unembed.weight"], loaded_state["backbone.embedding_layer.weight"]
), "backbone.unembed.weight was not loaded and is not tied to backbone.embedding_layer.weight"
load_result

In [None]:
from modeling_hyena import StripedHyenaModelForCausalLM, StripedHyenaModelForExtractingEmbeddings
from configuration_hyena import StripedHyenaConfig
from transformers import PreTrainedTokenizerFast

# Tokenizer built from the project's tokenizer JSON. [CLS] doubles as BOS and [SEP]
# as EOS; sequences are padded and truncated on the right.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="../processing-seqs/lornash_tokenizer.json",
    padding_side='right',
    truncation_side='right',
    cls_token='[CLS]',
    bos_token='[CLS]',
    sep_token='[SEP]',
    eos_token='[SEP]',
    unk_token='[UNK]',
    mask_token='[MASK]',
    pad_token='[PAD]',
    model_max_length=2**16  # same as config_dict["max_seqlen"] (65536)
)

# Reuse the single architecture definition (`config_dict` from the config cell above)
# rather than yet another copy-pasted dict.
config = StripedHyenaConfig(**config_dict)

checkpoint = "../weights"
# Only the embedding-extraction model is used from here on, so load just that one.
# Loading a full StripedHyenaModelForCausalLM first and rebinding `model` immediately
# would only waste time and memory.
model = StripedHyenaModelForExtractingEmbeddings.from_pretrained(checkpoint, tokenizer=tokenizer, config=config)