#### Self-organizing LLM wrapper

In [12]:
import torch
from model.model_sorl import SorlModelWrapper
from model.model_sorl import infer_level
from model.model_minimind import MiniMindConfig

full_vocab_list = [11, 50] # Base vocab + abstract vocabs
model = SorlModelWrapper(
    config=MiniMindConfig(vocab_size=sum(full_vocab_list)), # Config needs the total new vocab size
    full_vocab_size_list=full_vocab_list,
    memory_span=5
)
# --- Generate text using the custom SORL logic ---

prompt = torch.tensor([[1, 2, 3]])
generated_sequence = model.generate(
    input_ids=prompt,
    max_new_tokens=50,
    temperature=0.0,
    top_k=50,
    force_abstraction_every_n=4  # Example: force an abstraction token every 10 steps
)

print("--- SORL Generation Results ---")
print("Base vocabulary size:", model.vocab_sizes[0].item())
print("Total vocabulary size:", model.model.config.vocab_size)
print("\nGenerated Sequence:", generated_sequence)


result = model.forward(prompt)
print("\n--- Forward propagation (sparse attention) ---")
print("result.logits.shape: ", result.logits.shape)


orig_tokens = torch.tensor([[1,2,3,62,2,4,1,62,3,4,2,62]])

levels = infer_level(orig_tokens, model.vocab_sizes, -1)
denoise_mask = torch.isin(orig_tokens, model.level_mask_tokens[1:])
denoise_levels = levels[denoise_mask]

new_tokens = model.denoise(orig_tokens, denoise_mask, denoise_levels, 0.0)
print("\n--- Denoising ---")
print(f"Generating 2 level-1 tokens in parallel: {orig_tokens[0].tolist()} --> {new_tokens[0].tolist()}")

--- SORL Generation Results ---
Base vocabulary size: 12
Total vocabulary size: 63

Generated Sequence: tensor([[49, 49, 13, 13, 13, 13, 13, 13, 13, 13, 13,  1,  1,  1, 13,  8,  8]])

--- Forward propagation (sparse attention) ---
result.logits.shape:  torch.Size([1, 3, 63])

--- Denoising ---
Generating 2 level-1 tokens in parallel: [1, 2, 3, 62, 2, 4, 1, 62, 3, 4, 2, 62] --> [1, 2, 3, 49, 2, 4, 1, 49, 3, 4, 2, 49]


In [1]:
import torch
from model.model_sorl import SorlModelWrapper
from model.model_minimind import MiniMindConfig
# from model.model_sorl import 
from src.sorl import SORLConfig, sorl_search, compute_per_token_loss, compute_loss, evaluate

# --- 1. Setup the Model and Configuration ---
print("="*80)
print("--- Initializing Model and SORL Configuration ---")
print("="*80)

# Initialize a SORL-wrapped MiniMind model from scratch for the test
base_vocab_size = 512
abstract_vocab_sizes = [128]
full_vocab_list = [base_vocab_size] + abstract_vocab_sizes

minimind_config = MiniMindConfig(
    hidden_size=64, # Using smaller dimensions for faster testing
    num_attention_heads=2,
    num_hidden_layers=2,
    intermediate_size=128,
    vocab_size=sum(full_vocab_list)
)
sorl_model = SorlModelWrapper(
    config=minimind_config,
    full_vocab_size_list=full_vocab_list,
    memory_span=1024
)

# Create a configuration for the SORL search algorithm
# These parameters control how abstraction is performed
sorl_config = SORLConfig(
    n=4,                    # Number of candidates to roll out
    temperature=1.0,        # Temperature for sampling abstract tokens
    K=8,                    # Rhythmic stride for level-1 abstraction
    l=1,                    # The abstraction level to search for
    steps=4,                # Steps for chunk-wise denoising
    max_t_search=32,        # Max number of abstract timestamps to search within
    use_rhythmic_placeholders=True,
    use_spike_placeholders=False # Disable spike for simplicity in this test
)

print(f"Model Initialized. Total vocabulary size: {sorl_model.model.config.vocab_size}")
print(f"SORL Config: {sorl_config}\n")


# --- 2. Create Dummy Data ---
batch_size = 2
seq_len = 128
# Create a batch of random token sequences
dummy_data = torch.randint(0, base_vocab_size, (batch_size, seq_len), device=sorl_model.model.device)
print(f"Created dummy data with shape: {dummy_data.shape}\n")


# --- 3. Test the `sorl_search` Function ---
print("="*80)
print("--- Testing `sorl_search` ---")
print("="*80)

# This is the core of the SORL algorithm. It takes the original data and
# finds a better representation by inserting abstract tokens.
with torch.no_grad():
    best_sequence, switch_ratio = sorl_search(dummy_data, sorl_model, sorl_config)

print(f"Original sequence length: {dummy_data.shape[1]}")
print(f"Sequence length after search (with abstractions): {best_sequence.shape[1]}")
print(f"Abstraction switch ratio: {switch_ratio:.2f}")
# The switch ratio indicates how often the algorithm preferred a sampled abstraction over the greedy one.
assert best_sequence.shape[0] == batch_size
assert best_sequence.shape[1] > seq_len # Should be longer due to added placeholders
print("✅ `sorl_search` test passed.\n")


# --- 4. Test Loss Computation ---
print("="*80)
print("--- Testing Loss Computation ---")
print("="*80)

# Compute the loss on the improved sequence found by the search.
# This is what would be used for the backward pass during training.
ppt = compute_per_token_loss(sorl_model, best_sequence)
ssl_loss, abs_loss = compute_loss(best_sequence, sorl_model, ppt)
total_loss = ssl_loss + abs_loss

print(f"Per-token loss shape: {ppt.shape}")
print(f"Trajectory Loss (ssl_loss): {ssl_loss.item():.4f}")
print(f"Abstraction Loss (abs_loss): {abs_loss.item():.4f}")
print(f"Total Loss: {total_loss.item():.4f}")
assert ssl_loss.ndim == 0 and abs_loss.ndim == 0
print("✅ Loss computation test passed.\n")


# --- 5. Test the `evaluate` Function ---
print("="*80)
print("--- Testing `evaluate` ---")
print("="*80)

# This function simulates a validation step, comparing a greedy search
# against a random search to measure the potential for improvement.
with torch.no_grad():
    greedy_ppl, improve_ppl_percent, _, _ = evaluate(dummy_data, sorl_model, n=4, config=sorl_config)

print(f"Greedy trajectory perplexity: {greedy_ppl.item():.4f}")
print(f"Search improvement over greedy: {improve_ppl_percent.item():.2f}%")
print("✅ `evaluate` test passed.")

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


--- Initializing Model and SORL Configuration ---
Model Initialized. Total vocabulary size: 642
SORL Config: SORLConfig(n=4, temperature=1.0, K=8, causal_rollout=False, budget=None, l=1, steps=4, max_t_search=32, start_ts=None, end_ts=None, abstract_budget=5, use_rhythmic_placeholders=True, use_spike_placeholders=False, curriculum_ratio=0.6, max_seq_len=None, use_fade_memory=False, min_keep=1024, train_dataset_path=None, val_dataset_path=None, train_batch_size=128, val_batch_size=128, train_iterations=1000, val_iterations=10, max_length=1024, learning_rate=0.001, log_interval=100)

Created dummy data with shape: torch.Size([2, 128])

--- Testing `sorl_search` ---
Original sequence length: 128
Sequence length after search (with abstractions): 143
Abstraction switch ratio: 1.00
✅ `sorl_search` test passed.

--- Testing Loss Computation ---
Per-token loss shape: torch.Size([2, 142])
Trajectory Loss (ssl_loss): 6.6709
Abstraction Loss (abs_loss): 6.3879
Total Loss: 13.0588
✅ Loss computati

In [2]:
# Load language modeling data inside & train on them ~ 

