In [2]:
# Multiplication Dataset (cot included/not, reverse/not)
# - (no cot, no reverse) 11 x 12 = <answer> 132 
# - (cot, no reverse) 11 x 12 = 12 + 120 = <answer> 132
# - (no cot, reverse) 11 x 21 = <answer> 231 
# - (cot, reverse) 11 x 21 = 21 + 021 = <answer> 231 

In [4]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("tokenizer/digit_tokenizer")

In [8]:
# The HuggingFace tokenizer does not have a `.filepath` attribute.
# To get the path it was loaded from, use:
tokenizer.name_or_path

'tokenizer/digit_tokenizer'

In [1]:
from types import SimpleNamespace
from transformers import AutoTokenizer
import torch
from model.model_minimind import MiniMindConfig
from model.model_sorl import SorlModelWrapper
from dataset.base import MemLoader
from src.sorl import SORLConfig
import os 

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ==============================================================================
# 1. Configuration (Mimicking command-line args)
# ==============================================================================
args = SimpleNamespace(
    # --- Paths ---
    train_data_path="dataset/multiply/multiply_2x2_train.bin",
    val_data_path="dataset/multiply/multiply_2x2_val.bin",
    
    # --- Model Config ---
    hidden_size=256,
    num_hidden_layers=6,
    num_attention_heads=6,
    abstract_vocab_sizes="8",
    
    # --- Training Config ---
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=32,
    learning_rate=1e-4,
    
    # --- SORL Config ---
    n_rollout=5,
    temperature=1.0,
    K=4,
    denoise_steps=1,
    max_t_search=0,
    use_rhythmic_placeholders=True,
    use_spike_placeholders=False,
    use_special_placeholders=False,
    special_token_id=31,
    abstract_budget=5,
    temperature_flip=False,
    
    # --- Curriculum and Memory ---
    curriculum_ratio=0.6, # looks redundant as of now, it (vaguely) violates the "compositionality" principle
    train_iterations=200, # This will be used by the scheduler
    use_fade_memory=False,
    use_compression_mask=False, # <-- Set to True to test your new mask
    compression_curriculum_ratio=0.25,
    memory_span=128,
    
    # --- GAPT ---
    default_phase=None, # Set to 1 or 2 to override, None to enable GAPT
    delta=0.01,
    tau=0.1,
    p_m=10,
    p_c=10
)

# ==============================================================================
# 2. Initialization
# ==============================================================================
print("--- Initializing components ---")
# --- Tokenizer and Data ---
train_loader = MemLoader(args.train_data_path, device=args.device)
val_loader = MemLoader(args.val_data_path, device=args.device)
tokenizer = AutoTokenizer.from_pretrained(train_loader.tokenizer_path) # data is tokenized
pad_token_id = tokenizer.pad_token_id

# --- Model ---
base_vocab_size = len(tokenizer)
abstract_vocab_sizes = [int(v) for v in args.abstract_vocab_sizes.split(',')]
full_vocab_list = [base_vocab_size] + abstract_vocab_sizes

# 2 layer 4 head
minimind_config = MiniMindConfig(
    hidden_size=args.hidden_size,
    num_attention_heads=args.num_attention_heads,
    num_hidden_layers=args.num_hidden_layers,
    vocab_size=sum(full_vocab_list)
)

model = SorlModelWrapper.from_scratch(
    config=minimind_config,
    full_vocab_size_list=full_vocab_list,
    memory_span=args.memory_span,
    pad_token_id=pad_token_id
).to(args.device)

print(f"Model initialized on {args.device} with {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters.")

# --- SORL Config and Schedulers ---
sorl_config = SORLConfig(
    n=args.n_rollout, 
    temperature=args.temperature, 
    K=args.K,
    l=1, 
    steps=args.denoise_steps, 
    max_t_search=args.max_t_search,
    use_rhythmic_placeholders=args.use_rhythmic_placeholders,
    use_spike_placeholders=args.use_spike_placeholders,
    use_special_placeholders=args.use_special_placeholders,
    special_token_id=args.special_token_id,
    abstract_budget=args.abstract_budget,
    temperature_flip=args.temperature_flip,
    curriculum_ratio=args.curriculum_ratio,
    use_fade_memory=args.use_fade_memory,
    use_compression_mask=args.use_compression_mask,
    min_keep=args.memory_span, 
    max_seq_len=train_loader.max_length,
    train_iterations=args.train_iterations, 
    train_batch_size=args.batch_size,
    val_batch_size=args.batch_size,
    max_length=train_loader.max_length,
    default_phase=args.default_phase, 
    delta=args.delta, tau=args.tau,
    p_m=args.p_m, p_c=args.p_c
)

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


--- Initializing components ---
Model initialized on cpu with 4.29M parameters.


$\textbf{Bug 1}$. Multiplication with 2 digits are not trained perfectly, suggesting issue with data / training pipeline (SoRL)
- loss is dropping (optimizer is fine, and loss function is consistent, potential data issue)
- trained model do not generate an answer, suggesting issue with the loss function
- $\textbf{Fix 1}$. Found error in 'compute_loss' -- it's not compatible with .long typed loss_mask (ends up slicing the index-1 element repetitively instead of picking the index with True mask etc.)

$\textbf{Idea 1}$. We inherit the language-modeling vocabulary, which splits '1' and ' 1' as separate token, this will make the arithmetic rule MUCH more complex (combinatorial space is much larger, therefore Radamacher complexity grows, which increases the generalization error etc.) --> It's worth trying a small, concise tokenizer. 

$\textbf{Idea 2}$. Topological similarity can potentially explain the curve / trend (generalization error v.s. BPE tokenization size, multiplication specific) (currently our hypothesis), and it can be a general metric adaptable for non-multiplication task easily as well. SoRL should be adopted from basic vocabulary, in order to maximize the topological similarity and show its efficiency here.

$\textbf{Progress 1}$. Build a collection of tokenizer, from digit tokenizer all the way to increasing sized BPE tokenizer

In [21]:
from src.sorl import evaluate 
from src.sorl import compute_per_token_loss, compute_loss, sorl_search, SearchScheduler, GatedPhaseTransition

# First, test out baseline performance, then test out SoRL performance etc.

optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
search_scheduler = SearchScheduler(sorl_config)
gapt = GatedPhaseTransition(sorl_config.delta, sorl_config.tau, sorl_config.p_m, sorl_config.p_c)

# ==============================================================================
# 3. Interactive Training Loop
# ==============================================================================
print("\n--- Starting interactive training loop ---")
model.train()


for i in range(sorl_config.train_iterations): # Run for 10 steps
    # --- Scheduler Step ---
    t_search, drop_ratio = search_scheduler.step()
    sorl_config.max_t_search = 10
    model.drop_ratio = 0.5

    # --- Get data and perform SORL search ---
    # (1). Apply loss mask (and change its shape with abs padding) || (2). Customize abs padding
    data, loss_mask = train_loader.get_batch(sorl_config.train_batch_size)
    with torch.no_grad():
        search_data, switch_ratio = sorl_search(data, loss_mask, model, sorl_config)
        
    # --- Compute loss ---
    ppt = compute_per_token_loss(model, search_data)
    ssl_loss, abs_loss = compute_loss(search_data, model, ppt, loss_mask)
    
    total_loss = ssl_loss + abs_loss
    
    # --- Optimizer step ---
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    # --- Logging ---
    greedy_advantage, best_advantage, greedy_info_gain, _, a_loss = evaluate(data, loss_mask, sorl_config, model, search_n=1)
    print(
        f"Step {i+1:02d} | "
        f"Loss: {total_loss.item():.2f} (SSL: {ssl_loss.item():.3f}, Abs: {abs_loss.item():.2f}) | "
        f"Advantage: {greedy_advantage:.1f}% | Info-Gain: {greedy_info_gain:.1f}% | Abs-Free-Loss: {a_loss:.3f} | "
        f"t_search: {t_search} | "
        f"drop_ratio: {model.drop_ratio:.2f}"
    )
    # break



--- Starting interactive training loop ---
Step 01 | Loss: 1.35 (SSL: 1.345, Abs: 0.00) | Advantage: 58.1% | Info-Gain: 47.9% | Abs-Free-Loss: 3.973 | t_search: 0 | drop_ratio: 0.50
Step 02 | Loss: 1.41 (SSL: 1.412, Abs: 0.00) | Advantage: 63.8% | Info-Gain: 45.4% | Abs-Free-Loss: 3.947 | t_search: 0 | drop_ratio: 0.50
Step 03 | Loss: 1.49 (SSL: 1.488, Abs: 0.00) | Advantage: 63.2% | Info-Gain: 42.9% | Abs-Free-Loss: 3.876 | t_search: 0 | drop_ratio: 0.50
Step 04 | Loss: 1.38 (SSL: 1.373, Abs: 0.00) | Advantage: 62.6% | Info-Gain: 43.9% | Abs-Free-Loss: 3.804 | t_search: 0 | drop_ratio: 0.50
Step 05 | Loss: 1.35 (SSL: 1.343, Abs: 0.00) | Advantage: 63.5% | Info-Gain: 44.5% | Abs-Free-Loss: 3.825 | t_search: 0 | drop_ratio: 0.50
Step 06 | Loss: 1.39 (SSL: 1.387, Abs: 0.00) | Advantage: 65.3% | Info-Gain: 44.6% | Abs-Free-Loss: 3.910 | t_search: 0 | drop_ratio: 0.50
Step 07 | Loss: 1.37 (SSL: 1.365, Abs: 0.00) | Advantage: 61.5% | Info-Gain: 47.2% | Abs-Free-Loss: 4.022 | t_search: 0 | 

$\textbf{Question} 1$. What'd happen if the topological similarity approaches 1, and how can we make it so?
$\textbf{Question} 2$. 

In [3]:
from eval_multiply import compute_topological_similarity, evaluate_on_loader

# --- Topological Similarity Metric ---
print("--- Computing Topological Similarity of Number Embeddings (10-99) ---")
number_strings = [str(i) for i in range(10, 100)]
correlation = compute_topological_similarity(model, tokenizer, number_strings)
print(f"Topological Similarity: {correlation:.2f}")


# --- Evaluation (Generate & Check Answer) ---- 
input_ids, _ = val_loader.get_batch(2)
test_prompt = tokenizer.decode(input_ids[0])

# Clean up padding for the prompt
test_prompt = test_prompt.replace(tokenizer.pad_token, '').strip()

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
print("--- Running Evaluation ---")
evaluate_on_loader(model, tokenizer, train_loader, batch_size=10, K=sorl_config.K)

--- Computing Topological Similarity of Number Embeddings (10-99) ---
Topological Similarity: 0.49
--- Running Evaluation ---


Evaluating Batches: 100%|██████████| 20/20 [00:03<00:00,  5.43it/s, Accuracy=0.00%]


--- Evaluation Summary ---
Samples Evaluated: 200
Correct Predictions: 0
Accuracy: 0.00%





{'accuracy': 0.0, 'correct': 0, 'total': 200}

$\textbf{Idea 3}$. How about using RL instead? 

$\textbf{Idea 4}$. How about using SoRL instead? 

$\textbf{Idea 5}$. How about using SoRL + RL instead? 

In [29]:
from eval_multiply import evaluate_multiplication

for d in data: 
    test_prompt = tokenizer.decode(d)
    evaluate_multiplication(model, tokenizer, test_prompt, K=None)

Query: 6 0 * 8 9 =
Generated Response: 5 5 9 9 9 9 9 9 9 9
Expected Answer:  5 3 4 0
Generated Answer: 5 5 9 9 9 9 9 9 9 9
Query: 2 5 * 5 1 =
Generated Response: 2 2 2 2 2 2 2 2 2 2
Expected Answer:  1 2 7 5
Generated Answer: 2 2 2 2 2 2 2 2 2 2
Query: 6 9 * 3 8 =
Generated Response: 5 9 9 9 9 9 9 9 9 9
Expected Answer:  2 6 2 2
Generated Answer: 5 9 9 9 9 9 9 9 9 9
Query: 6 6 * 1 6 =
Generated Response: 5 2 2 2 2 2 2 2 2 2
Expected Answer:  1 0 5 6
Generated Answer: 5 2 2 2 2 2 2 2 2 2
Query: 8 1 * 7 0 =
Generated Response: 7 7 7 7 7 7 7 7 9 9
Expected Answer:  5 6 7 0
Generated Answer: 7 7 7 7 7 7 7 7 9 9
Query: 1 1 * 1 0 =
Generated Response: 2 2 2 2 2 2 2 2 2 <eos>
Expected Answer:  1 1 0
Generated Answer: 2 2 2 2 2 2 2 2 2
Query: 5 8 * 1 2 =
Generated Response: 3 3 3 2 2 2 2 2 2 2
Expected Answer:  6 9 6
Generated Answer: 3 3 3 2 2 2 2 2 2 2
Query: 7 5 * 8 2 =
Generated Response: 7 7 7 7 9 9 9 9 9 9
Expected Answer:  6 1 5 0
Generated Answer: 7 7 7 7 9 9 9 9 9 9
Query: 9 6 * 7 1 =

In [26]:
model.eval()
prompt = test_prompt 

from eval_multiply import _extract_answer_from_ids
from model.model_sorl import infer_level

answer_token_id = tokenizer.encode('<answer>', add_special_tokens=False)[0]

input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)

answer_indices = torch.where(input_ids == answer_token_id)

answer_idx = answer_indices[1][0]
query_ids = input_ids[:, :answer_idx + 1]
ground_truth_ids = input_ids[:, answer_idx + 1:]

with torch.no_grad():
    # Issue 1. keep generating abstract tokens
    output = model.generate(
                query_ids,
                max_new_tokens=10,
                temperature=0.0,
                force_abstraction_every_n=None
    )

# Use utility function for parsing
level = infer_level(output, model.vocab_sizes, tokenizer.pad_token_id)
traj = output[level == 0].reshape(output.shape[0], -1)
generated_ids = traj[:, query_ids.shape[1]:]
generated_response = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
generated_answer = _extract_answer_from_ids(generated_ids[0], tokenizer)
ground_truth_answer = _extract_answer_from_ids(ground_truth_ids[0], tokenizer)

print(f"Query: {tokenizer.decode(query_ids[0], skip_special_tokens=True)}")
print(f"Generated Response: {generated_response.strip()}")
print(f"Expected Answer:  {ground_truth_answer}")
print(f"Generated Answer: {generated_answer}")

Query: 4 1 * 5 8 =
Generated Response: 3 3 3 3 3 3 3 3 3 3
Expected Answer:  2 3 7 8
Generated Answer: 3 3 3 3 3 3 3 3 3 3


In [31]:
output

tensor([[ 7,  4, 13,  8, 11, 14,  2,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6]])

In [30]:
search_data

tensor([[ 9,  3, 13, 11, 21, 12, 14,  2,  8, 21,  6,  7,  3, 19],
        [ 5,  8, 13,  8, 21,  4, 14,  2,  4, 21,  5, 10,  8, 19],
        [ 9, 12, 13,  6, 21, 11, 14,  2,  5, 21,  9,  5,  5, 19],
        [ 9,  9, 13,  4, 21,  9, 14,  2,  4, 21,  3,  8,  9, 19],
        [11,  4, 13, 10, 21,  3, 14,  2,  8, 21,  9, 10,  3, 19],
        [ 4,  4, 13,  4, 21,  3, 14,  2,  4, 21,  4,  3, 19,  1],
        [ 8, 11, 13,  4, 21,  5, 14,  2,  9, 21, 12,  9, 19,  1],
        [10,  8, 13, 11, 21,  5, 14,  2,  9, 21,  4,  8,  3, 19],
        [12,  9, 13, 10, 21,  4, 14,  2,  9, 21, 11,  4,  9, 19],
        [ 9,  3, 13,  6, 21,  5, 14,  2,  4, 21, 12,  5,  3, 19],
        [ 7,  4, 13,  5, 21,  6, 14,  2, 12, 21,  7,  6, 19,  1],
        [11,  7, 13, 11, 21,  6, 14,  2,  9, 21, 12, 10,  5, 19],
        [ 5,  6, 13, 10, 21,  4, 14,  2,  4, 21,  9,  6,  6, 19],
        [ 7,  3, 13,  6, 21,  7, 14,  2,  4, 21,  6,  9,  3, 19],
        [ 5, 11, 13, 11, 21, 12, 14,  2,  5, 21,  7, 12,  5, 19],
        [ 

$\textbf{Issue 1}$. When including abstraction, the generated response don't 'stop' anymore. 

$\textbf{Issue 2}$. Abstraction generation in 'train-time' mismatch with that of 'inference-time' (former is parallel search over query, latter is causal generation over answer).

$\textbf{Idea 1}$. We have mismatch between 'train-time' abstraction addition (which is on query), and 'test-time' abstraction addition (which is on answer). It's probably better to add abstraction on query, or prefix token sequence only.