#### Self-organizing LLM wrapper

In [1]:
import torch
from model.model_sorl import SorlModelWrapper
from model.model_sorl import infer_level
from model.model_minimind import MiniMindConfig

# Demo of the SORL wrapper on a toy vocabulary: build a model from scratch,
# generate from a short prompt, run one forward pass, then denoise placeholders.
full_vocab_list = [11, 50] # Base vocab + abstract vocabs
model = SorlModelWrapper.from_scratch(
    config=MiniMindConfig(vocab_size=sum(full_vocab_list)), # Config needs the total new vocab size
    full_vocab_size_list=full_vocab_list,
    memory_span=5,
    pad_token_id=0
)
# --- Generate text using the custom SORL logic ---

prompt = torch.tensor([[1, 2, 3]])
generated_sequence = model.generate(
    input_ids=prompt,
    max_new_tokens=50,
    temperature=0.0,
    top_k=50,
    force_abstraction_every_n=4  # Example: force an abstraction token every 4 steps
)

print("--- SORL Generation Results ---")
print("Base vocabulary size:", model.vocab_sizes[0].item())
print("Total vocabulary size:", model.model.config.vocab_size)
print("\nGenerated Sequence:", generated_sequence)


result = model.forward(prompt)
print("\n--- Forward propagation (sparse attention) ---")
print("result.logits.shape: ", result.logits.shape)


# Token 61 is used as the placeholder to be denoised.
# NOTE(review): presumably 61 is the level-1 mask token held in
# model.level_mask_tokens -- confirm against SorlModelWrapper.
orig_tokens = torch.tensor([[1,2,3,61,2,4,1,61,3,4,2,61]])

levels = infer_level(orig_tokens, model.vocab_sizes, -1)
# Positions whose token is one of the mask tokens above level 0.
denoise_mask = torch.isin(orig_tokens, model.level_mask_tokens[1:])
denoise_levels = levels[denoise_mask]
# Count the placeholders instead of hard-coding the number in the message.
num_masked = int(denoise_mask.sum())

new_tokens = model.denoise(orig_tokens, denoise_mask, denoise_levels, 0.0)
print("\n--- Denoising ---")
print(f"Generating {num_masked} level-1 tokens in parallel: {orig_tokens[0].tolist()} --> {new_tokens[0].tolist()}")

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


--- SORL Generation Results ---
Base vocabulary size: 11
Total vocabulary size: 62

Generated Sequence: tensor([[20, 20, 56, 11, 11, 11, 37, 37, 37, 37, 37,  6,  6,  6, 37,  6,  6]])

--- Forward propagation (sparse attention) ---
result.logits.shape:  torch.Size([1, 3, 62])

--- Denoising ---
Generating 2 level-1 tokens in parallel: [1, 2, 3, 61, 2, 4, 1, 61, 3, 4, 2, 61] --> [1, 2, 3, 20, 2, 4, 1, 20, 3, 4, 2, 48]


In [None]:
# Change on memory fading gadget


#### Self-organizing Reinforcement Learning

In [1]:
import torch
from transformers import AutoTokenizer
from dataset.base import MemLoader
from model.model_sorl import SorlModelWrapper
from model.model_minimind import MiniMindConfig
from src.sorl import SORLConfig, sorl_search, compute_per_token_loss, compute_loss

# --- 1. Full Pipeline Initialization ---
# Builds the tokenizer, the memory-mapped data loader and a small SORL-wrapped
# MiniMind model, all placed on the same device.
print("--- Initializing training components ---")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer to get vocab size and pad token id
tokenizer = AutoTokenizer.from_pretrained('model/')
pad_token_id = tokenizer.pad_token_id

# Initialize the high-performance memory-mapped data loader
dataset = MemLoader('dataset/pretrain_hq.bin', device=device)
print("MemLoader initialized.")

# Initialize the SORL-wrapped model
base_vocab_size = tokenizer.vocab_size
abstract_vocab_sizes = [8]
full_vocab_list = [base_vocab_size] + abstract_vocab_sizes
minimind_config = MiniMindConfig(
    hidden_size=256, num_attention_heads=4, num_hidden_layers=4,
    intermediate_size=512, vocab_size=sum(full_vocab_list)
)
# Use the tokenizer's pad id (as the later training cell does) rather than a
# hard-coded 0; fall back to 0 only when the tokenizer defines no pad token.
sorl_model = SorlModelWrapper.from_scratch(
    config=minimind_config,
    full_vocab_size_list=full_vocab_list,
    memory_span=1024,
    pad_token_id=pad_token_id if pad_token_id is not None else 0
).to(device)
print("SORL Model initialized.")

--- Initializing training components ---


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


MemLoader initialized.
SORL Model initialized.


In [2]:
# Change can be made on the attention-masking mechanism
# -----------------------------------------------------

# 1. A visualization on where abstraction is added, and what is masked out 
#    I believe when abstraction is added, we only perform a 'distant memory masking'
#    I suspect that memory-distillation-based training requires masking out the local chunk where the abstraction is present

# .forward method accepts an 'attention_mask' argument

In [3]:
# Configure the SORL search algorithm
# Search hyper-parameters handed to sorl_search below.
sorl_config = SORLConfig(
    n=4,
    temperature=1.0,
    K=8,
    l=1,
    steps=4,
    max_t_search=32,
    use_rhythmic_placeholders=True,
    use_spike_placeholders=False,
)

# Adam over the parameters of the wrapped inner model.
optimizer = torch.optim.Adam(sorl_model.model.parameters(), lr=1e-4)
print("--- Initialization Complete ---\n")


# --- 2. One complete SORL training step ---
print("--- Running one SORL training step ---")
# The second value returned by get_batch is not needed in this demo.
data_batch, _ = dataset.get_batch(batch_size=4)
print(f"Fetched data batch of shape: {data_batch.shape}")

# (a) Search runs without gradient tracking.
with torch.no_grad():
    search_data, switch_ratio = sorl_search(data_batch, sorl_model, sorl_config)
print(f"SORL search complete. New sequence shape: {search_data.shape}")

# (b) Per-token loss on the sequences returned by the search.
ppt = compute_per_token_loss(sorl_model, search_data)

# (c) Two loss terms, summed into the training objective.
ssl_loss, abs_loss = compute_loss(search_data, sorl_model, ppt)
total_loss = ssl_loss + abs_loss
print(f"Computed Loss -> Total: {total_loss.item():.4f} (SSL: {ssl_loss.item():.4f}, Abs: {abs_loss.item():.4f})")

# (d) Clear stale gradients, backpropagate, then update the weights.
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
print("Optimizer step complete (weights have been updated).")

print("\n--- ✅ Single training step finished! ---")

--- Initialization Complete ---

--- Running one SORL training step ---
Fetched data batch of shape: torch.Size([4, 255])
SORL search complete. New sequence shape: torch.Size([4, 286])
Computed Loss -> Total: 17.7112 (SSL: 8.9598, Abs: 8.7514)
Optimizer step complete (weights have been updated).

--- ✅ Single training step finished! ---


#### Subway

#### Hidden Information compressed into Abstraction

$\textbf{Question 1}$. Do we need the loss mask here? 
- $\textbf{Answer 1}$. Yes. We want to test whether an abstraction can replace memories, so the comparison should be made on the perplexity of the predicted tokens only; masking the loss removes confounding factors such as the perplexity of the provided number.

$\textbf{Question 2}$. When we perform memory degradation, is the loss on the omitted trajectory-token prefix still intact? Are we making an unfair comparison there?

$\textbf{Issue 1}$. The scheduler is off: by the time a new abstract token is added, the drop ratio is already 1.0, making it impossible for the model to learn via a curriculum.

$\textbf{Question 3}$. How do we reliably justify whether the model has 'learned' to compress the memory into the abstraction?
- I think we ought to have a control group etc. 

$\textbf{Idea 1}$. Test on the 'information gain' target (proposed in https://github.com/NVlabs/RLP?tab=readme-ov-file#paper). We'd need an EMA reference checkpoint $\pi_{\hat{\theta}}$, and the information gain of an abstraction $a \in \mathcal{A}$ is computed by
$$
r(a) = \log \pi_{\theta}(x_{i} \mid x_{<i}, a) - \log \pi_{\hat{\theta}}(x_{i} \mid x_{<i})
$$
In RLP, they've found online RL with information gain reward signal suffices to improve model performance by searching for high quality CoT that assists with next-token prediction target. 


In [None]:
# Prepare 'hidden information' dataset 
# xxxA: xxx is its basic form 
# we apply memory compression & SoRL to ask the model to predict 
# A: xxx
# Here, xxx is hidden information, can be a number etc. 
# ----------------------------------------------------------------
# This dataset will be useful to inspect whether abstraction is effective or not. 
# ----------------------------------------------------------------
# python -m dataset.prep_hidden_info_dataset



Generating 100 samples...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Generating samples: 100%|██████████| 100/100 [00:00<00:00, 209610.39it/s]


In [14]:
# Display the `equals_token_id` attribute of `dataset` (presumably the id of
# the '=' token -- confirm). The recorded value, 31, is the same number that is
# hard-coded as `special_token_id=31` in the configuration cell.
# NOTE(review): this cell relies on `dataset` already existing in the kernel and
# carries an out-of-sequence execution count; re-check after Restart & Run All.
dataset.equals_token_id

31

In [None]:
from types import SimpleNamespace
from transformers import AutoTokenizer
import sys, torch

# --- Add project root to path ---
# You might need to adjust this depending on your notebook's location
if '..' not in sys.path:
    sys.path.append('..')

from model.model_minimind import MiniMindConfig
from model.model_sorl import SorlModelWrapper
from dataset.base import MemLoader
from src.sorl import SORLConfig, sorl_search, compute_loss, compute_per_token_loss, GatedPhaseTransition, SearchScheduler

# ==============================================================================
# 1. Configuration (Mimicking command-line args)
# ==============================================================================
# Settings are grouped by concern and then flattened, in the same order, into a
# single namespace so the rest of the notebook reads them as `args.<name>`.

# Where the training data and the tokenizer live.
path_settings = dict(
    train_data_path="dataset/hidden_info.bin",
    tokenizer_path="model/",
)

# Transformer size; abstract vocab sizes are given as a comma-separated string.
model_settings = dict(
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
    abstract_vocab_sizes="8",
)

# Optimisation basics.
training_settings = dict(
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=4,
    learning_rate=3e-4,
)

# SORL search options.
sorl_settings = dict(
    n_rollout=3,
    temperature=1.0,
    K=4,
    denoise_steps=1,
    max_t_search=1,
    use_rhythmic_placeholders=False,
    use_spike_placeholders=False,
    use_special_placeholders=True,
    special_token_id=31,
    abstract_budget=5,
    temperature_flip=False,
)

# Curriculum and memory options.
memory_settings = dict(
    curriculum_ratio=0.6,
    train_iterations=1000,  # consumed by the scheduler
    use_fade_memory=False,
    use_compression_mask=True,  # switch on to exercise the compression mask
    compression_curriculum_ratio=0.25,
    memory_span=20,
)

# GAPT options.
gapt_settings = dict(
    default_phase=None,  # 1 or 2 forces that phase; None leaves GAPT in control
    delta=0.01,
    tau=0.1,
    p_m=10,
    p_c=10,
)

args = SimpleNamespace(
    **path_settings,
    **model_settings,
    **training_settings,
    **sorl_settings,
    **memory_settings,
    **gapt_settings,
)

# ==============================================================================
# 2. Initialization
# ==============================================================================
print("--- Initializing components ---")

# Tokenizer (for vocab size and pad id) and the training data loader.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
pad_token_id = tokenizer.pad_token_id
train_loader = MemLoader(args.train_data_path, device=args.device)

# Vocabulary layout: the base vocab followed by one entry per abstraction level,
# parsed from the comma-separated string in args.
base_vocab_size = tokenizer.vocab_size
abstract_vocab_sizes = list(map(int, args.abstract_vocab_sizes.split(',')))
full_vocab_list = [base_vocab_size, *abstract_vocab_sizes]

# The backbone config is sized for the combined vocabulary.
minimind_config = MiniMindConfig(
    hidden_size=args.hidden_size,
    num_attention_heads=args.num_attention_heads,
    num_hidden_layers=args.num_hidden_layers,
    vocab_size=sum(full_vocab_list),
)

# Build the wrapped model, then move it to the target device.
model = SorlModelWrapper.from_scratch(
    config=minimind_config,
    full_vocab_size_list=full_vocab_list,
    memory_span=args.memory_span,
    pad_token_id=pad_token_id,
)
model = model.to(args.device)

print(f"Model initialized on {args.device} with {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters.")

# SORL configuration, assembled from args plus the loader's maximum length.
# `max_seq_len` and `max_length` both receive that same loader value.
loader_max_length = train_loader.max_length
sorl_config = SORLConfig(
    # search
    n=args.n_rollout,
    temperature=args.temperature,
    K=args.K,
    l=1,
    steps=args.denoise_steps,
    max_t_search=args.max_t_search,
    # placeholder placement
    use_rhythmic_placeholders=args.use_rhythmic_placeholders,
    use_spike_placeholders=args.use_spike_placeholders,
    use_special_placeholders=args.use_special_placeholders,
    special_token_id=args.special_token_id,
    abstract_budget=args.abstract_budget,
    temperature_flip=args.temperature_flip,
    # curriculum and memory
    curriculum_ratio=args.curriculum_ratio,
    use_fade_memory=args.use_fade_memory,
    use_compression_mask=args.use_compression_mask,
    min_keep=args.memory_span,
    max_seq_len=loader_max_length,
    train_iterations=args.train_iterations,
    max_length=loader_max_length,
    # GAPT
    default_phase=args.default_phase,
    delta=args.delta,
    tau=args.tau,
    p_m=args.p_m,
    p_c=args.p_c,
)

In [None]:

# Optimizer, search scheduler and the GAPT phase controller.
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
search_scheduler = SearchScheduler(sorl_config)
gapt = GatedPhaseTransition(
    sorl_config.delta,
    sorl_config.tau,
    sorl_config.p_m,
    sorl_config.p_c,
)

# ==============================================================================
# 3. Interactive Training Loop
# ==============================================================================
print("\n--- Starting interactive training loop ---")
model.train()
# Runs for sorl_config.train_iterations steps.
for i in range(sorl_config.train_iterations):
    # Scheduler: refresh the search depth and the model's drop ratio.
    t_search, drop_ratio = search_scheduler.step()
    sorl_config.max_t_search = t_search
    model.drop_ratio = drop_ratio

    # Fetch a batch with its loss mask, then search without gradients.
    data, loss_mask = train_loader.get_batch(args.batch_size)
    with torch.no_grad():
        search_data, switch_ratio = sorl_search(data, model, sorl_config)

    # Per-token loss first, then the two masked loss terms.
    ppt = compute_per_token_loss(model, search_data)
    ssl_loss, abs_loss = compute_loss(search_data, model, ppt, loss_mask)

    # GAPT is always stepped; a configured default phase overrides its choice.
    current_phase = gapt.step(ssl_loss.item(), abs_loss.item())
    if sorl_config.default_phase is not None:
        current_phase = sorl_config.default_phase

    # The abstraction loss is added to the objective only in phase 2.
    if current_phase == 2:
        total_loss = ssl_loss + abs_loss
    else:
        total_loss = ssl_loss

    # Standard update: reset gradients, backpropagate, step.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # One progress line per step.
    print(
        f"Step {i+1:02d} | "
        f"Loss: {total_loss.item():.4f} (SSL: {ssl_loss.item():.4f}, Abs: {abs_loss.item():.4f}) | "
        f"Phase: {current_phase} | "
        f"t_search: {t_search} | "
        f"drop_ratio: {drop_ratio:.2f}"
    )


--- Starting interactive training loop ---
Step 01 | Loss: 8.7683 (SSL: 8.7683, Abs: 0.0000) | Phase: 1 | t_search: 0 | drop_ratio: 0.00
Step 02 | Loss: 8.9134 (SSL: 8.9134, Abs: 0.0000) | Phase: 1 | t_search: 0 | drop_ratio: 0.01
Step 03 | Loss: 8.6558 (SSL: 8.6558, Abs: 0.0000) | Phase: 1 | t_search: 0 | drop_ratio: 0.01
Step 04 | Loss: 9.0467 (SSL: 9.0467, Abs: 0.0000) | Phase: 1 | t_search: 0 | drop_ratio: 0.02
Step 05 | Loss: 9.4948 (SSL: 9.4948, Abs: 0.0000) | Phase: 1 | t_search: 0 | drop_ratio: 0.02
Step 06 | Loss: 9.4216 (SSL: 9.4216, Abs: 0.0000) | Phase: 1 | t_search: 0 | drop_ratio: 0.02
Step 07 | Loss: 7.4151 (SSL: 7.4151, Abs: 0.0000) | Phase: 1 | t_search: 0 | drop_ratio: 0.03
Step 08 | Loss: 8.9870 (SSL: 8.9870, Abs: 0.0000) | Phase: 1 | t_search: 0 | drop_ratio: 0.03
Step 09 | Loss: 8.6267 (SSL: 8.6267, Abs: 0.0000) | Phase: 1 | t_search: 0 | drop_ratio: 0.04
Step 10 | Loss: 8.0019 (SSL: 8.0019, Abs: 0.0000) | Phase: 1 | t_search: 0 | drop_ratio: 0.04
Step 11 | Loss: 

In [None]:
# Evaluation: compare the perplexity of target tokens against
# (1). non-abstraction sequence
# (2). random-abstraction sequence

# Import only the name this cell needs; a wildcard import hides where
# `evaluate` comes from and can silently shadow notebook variables.
from src.sorl import evaluate

data, loss_mask = train_loader.get_batch(args.batch_size)
config = sorl_config

greedy_advantage, best_advantage = evaluate(data, loss_mask, config, model, search_n=1)

# Show the result as the cell output (previously it was computed but never shown).
greedy_advantage, best_advantage