In [1]:
# In your notebook or script:
from model.model_minimind import MiniMindForCausalLM, MiniMindConfig
from model.extend_minimind import SorlModelWrapper
import torch 

# 1. Create the base model
minimind_config = MiniMindConfig()
base_model = MiniMindForCausalLM(minimind_config)
# ... load your trained weights into base_model ...

# 2. Define the SORL configuration
sorl_config = {
    "vocab_size_list": [128, 64, 32], # Must match your SORL setup
    "memory_span": 512,
    # ... any other SORL parameters you need ...
}

# 3. Create the SORL wrapper
sorl_model = SorlModelWrapper(base_model, sorl_config)

# 4. Generate text using the custom SORL logic
prompt = torch.tensor([[1, 2, 3]])
generated_sequence = sorl_model.generate(
    input_ids=prompt,
    max_new_tokens=50,
    force_abstraction_every_n=10 # Example: force an abstraction token every 10 steps
)

print("Generated Sequence:", generated_sequence)


AssertionError: State mismatch: KV cache length 2 != levels_cache length 3

#### Towards Alignment
- And honestly, GPT->GAT change is really on .generate / .denoise method, as well as the tokenization

In [3]:
import torch
import torch.nn.functional as F
from typing import Optional
from transformers import LogitsProcessor

from model.extend_minimind import MiniMindForGeneration, infer_level
from model.model_minimind import MiniMindConfig

# 1. Create the base model using our compatible subclass
minimind_config = MiniMindConfig()
model = MiniMindForGeneration(minimind_config)
model.eval()
device = model.device

# 2. Define the SORL configuration
sorl_config = {
    "vocab_size_list": [128, 64, 32], # Must match your SORL setup
    "memory_span": 5, # Using a small number for easy debugging
}

# 3. SORL vocabulary setup
vocab_sizes = torch.tensor(sorl_config["vocab_size_list"]).to(device) + 1
level_mask_tokens = vocab_sizes.cumsum(dim=0) - 1

print("Setup complete. Model is on device:", device)

Setup complete. Model is on device: cpu


In [4]:
# --- Initial State ---
# The prompt we want the model to complete
prompt_ids = torch.tensor([[10, 20, 130, 40, 150]], device=device) # Using some high/low level tokens

# The sequence of generated tokens, starting with the prompt
generated_ids = prompt_ids.clone()

# The KV cache from the previous step (starts as None)
past_key_values = None

# The SORL levels corresponding to every token in the KV cache
levels_cache = infer_level(generated_ids, vocab_sizes, level_mask_tokens[0])

print("--- Initial State ---")
print("Generated IDs:", generated_ids)
print("Levels Cache:", levels_cache)
print("KV Cache:", past_key_values)

--- Initial State ---
Generated IDs: tensor([[ 10,  20, 130,  40, 150]])
Levels Cache: tensor([[0, 0, 1, 0, 1]])
KV Cache: None


In [5]:
# --- Perform one step of generation ---

# 1. Prepare inputs for the model
# The model's internal helper function correctly handles the KV cache
model_inputs = model.prepare_inputs_for_generation(
    input_ids=generated_ids, 
    past_key_values=past_key_values
)

# 2. Forward pass to get logits and the new, full KV cache
outputs = model(**model_inputs)
next_token_logits = outputs.logits[:, -1, :]
current_pkv = outputs.past_key_values

# 3. Sample the next token (using greedy search for simplicity)
next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

# 4. Update our generated sequence and levels cache with the new token's info
generated_ids = torch.cat([generated_ids, next_token_id], dim=1)
new_level = infer_level(next_token_id, vocab_sizes, level_mask_tokens[0])
levels_cache = torch.cat([levels_cache, new_level], dim=1)

print(f"--- Step Output (Token #{generated_ids.shape[1] - prompt_ids.shape[1]}) ---")
print("New Token ID:", next_token_id.item())
print("New Token Level:", new_level.item())
print("\n--- State BEFORE Pruning ---")
print("KV Cache Length:", current_pkv[0][0].shape[2])
print("Levels Cache:   ", levels_cache)


# 5. Prune the state for the *next* iteration
seq_len = current_pkv[0][0].shape[2]
assert seq_len == levels_cache.shape[1], "CRITICAL: State mismatch before pruning!"

is_recent = torch.arange(seq_len, device=device) >= (seq_len - sorl_config["memory_span"])
is_high_level = (levels_cache > 0).squeeze(0)

keep_mask = is_recent | is_high_level
keep_indices = torch.where(keep_mask)[0]

pruned_pkv_list = []
for k, v in current_pkv:
    pruned_k = k[:, :, keep_indices, :]
    pruned_v = v[:, :, keep_indices, :]
    pruned_pkv_list.append((pruned_k, pruned_v))

# This is the final state we carry over to the next loop iteration
past_key_values = tuple(pruned_pkv_list) 
levels_cache = levels_cache[:, keep_indices]

print("\n--- State AFTER Pruning (for next step) ---")
print("Kept Indices:   ", keep_indices.tolist())
print("Pruned KV Cache Length:", past_key_values[0][0].shape[2])
print("Pruned Levels Cache:   ", levels_cache)
print("-------------------------------------------\n")


--- Step Output (Token #1) ---
New Token ID: 3601
New Token Level: 0

--- State BEFORE Pruning ---
KV Cache Length: 2
Levels Cache:    tensor([[0, 0, 1, 0, 1, 0]])


AssertionError: CRITICAL: State mismatch before pruning!