In [2]:
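# Multiplication dataset variants: chain-of-thought ("cot") included or not, digits reversed or not.
# "reverse" writes the digits of each number in reverse order (e.g. 12 -> 21, 132 -> 231).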
# - (no cot, no reverse) 11 x 12 = <answer> 132 
# - (cot, no reverse) 11 x 12 = 12 + 120 = <answer> 132
# - (no cot, reverse) 11 x 21 = <answer> 231 
# - (cot, reverse) 11 x 21 = 21 + 021 = <answer> 231 

In [4]:
from transformers import AutoTokenizer
# Load the tokenizer stored under the local "tokenizer/digit_tokenizer" path.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="tokenizer/digit_tokenizer",
)

In [8]:
# The HuggingFace tokenizer does not expose a `.filepath` attribute.
# The location it was loaded from is available as `name_or_path`; leaving it as
# the last expression of the cell displays the value as the cell output.
tokenizer.name_or_path

'tokenizer/digit_tokenizer'

In [1]:
from types import SimpleNamespace
from transformers import AutoTokenizer
import torch
from model.model_minimind import MiniMindConfig
from model.model_sorl import SorlModelWrapper
from dataset.base import MemLoader
from src.sorl import SORLConfig

# ==============================================================================
# 1. Configuration (Mimicking command-line args)
# ==============================================================================
# Every tunable of this notebook lives on one namespace so the cells below can
# read `args.<name>` exactly as a script would read parsed CLI arguments.
# The settings are grouped by concern; the groups are merged in order.
args = SimpleNamespace(
    # --- Paths ---
    **{
        "train_data_path": "dataset/multiply/multiply_1x1_train.bin",
        "val_data_path": "dataset/multiply/multiply_1x1_val.bin",
    },
    # --- Model Config ---
    **{
        "hidden_size": 256,
        "num_hidden_layers": 4,
        "num_attention_heads": 4,
        # Comma-separated string; parsed into a list of ints further down.
        "abstract_vocab_sizes": "8",
    },
    # --- Training Config ---
    **{
        # Use the GPU when one is visible to torch, otherwise stay on the CPU.
        "device": "cpu" if not torch.cuda.is_available() else "cuda",
        "batch_size": 32,
        "learning_rate": 3e-4,
    },
    # --- SORL Config ---
    **{
        "n_rollout": 5,
        "temperature": 1.0,
        "K": 4,
        "denoise_steps": 1,
        "max_t_search": 0,
        "use_rhythmic_placeholders": True,
        "use_spike_placeholders": False,
        "use_special_placeholders": False,
        "special_token_id": 31,
        "abstract_budget": 5,
        "temperature_flip": False,
    },
    # --- Curriculum and Memory ---
    **{
        # Looks redundant as of now; it (vaguely) violates the
        # "compositionality" principle.
        "curriculum_ratio": 0.6,
        # This will be used by the scheduler.
        "train_iterations": 200,
        "use_fade_memory": False,
        # <-- Set to True to test your new mask.
        "use_compression_mask": False,
        "compression_curriculum_ratio": 0.25,
        "memory_span": 20,
    },
    # --- GAPT ---
    **{
        # Set to 1 or 2 to override, None to enable GAPT.
        "default_phase": None,
        "delta": 0.01,
        "tau": 0.1,
        "p_m": 10,
        "p_c": 10,
    },
)

# ==============================================================================
# 2. Initialization
# ==============================================================================
print("--- Initializing components ---")
# --- Tokenizer and Data ---
# Build the train and validation loaders the same way (train first), each on
# the configured device.
train_loader, val_loader = (
    MemLoader(data_path, device=args.device)
    for data_path in (args.train_data_path, args.val_data_path)
)
# The data is already tokenized, so the tokenizer is reloaded from the path
# recorded on the training loader.
tokenizer = AutoTokenizer.from_pretrained(train_loader.tokenizer_path)
pad_token_id = tokenizer.pad_token_id

# --- Model ---
# Vocabulary sizes handed to the wrapper: the tokenizer's vocabulary first,
# followed by the sizes parsed from the comma-separated `abstract_vocab_sizes`.
abstract_vocab_sizes = list(map(int, args.abstract_vocab_sizes.split(',')))
base_vocab_size = len(tokenizer)
full_vocab_list = [base_vocab_size, *abstract_vocab_sizes]

# Width, depth and head count are taken from `args`; `vocab_size` is the sum
# of every entry in `full_vocab_list`.
minimind_config = MiniMindConfig(
    vocab_size=sum(full_vocab_list),
    hidden_size=args.hidden_size,
    num_hidden_layers=args.num_hidden_layers,
    num_attention_heads=args.num_attention_heads,
)

# Build a freshly initialised model, then move it to the target device.
model = SorlModelWrapper.from_scratch(
    config=minimind_config,
    full_vocab_size_list=full_vocab_list,
    memory_span=args.memory_span,
    pad_token_id=pad_token_id,
)
model = model.to(args.device)

# Report the parameter count in millions.
print(f"Model initialized on {args.device} with {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters.")

# --- SORL Config and Schedulers ---
# Most SORLConfig fields carry the same name on `args`; those are forwarded
# as-is. The remaining fields are renamed, shared between two fields, read
# from the training loader, or fixed, and are spelled out explicitly.
sorl_config = SORLConfig(
    # Fields whose names differ from their `args` counterpart.
    n=args.n_rollout,
    steps=args.denoise_steps,
    min_keep=args.memory_span,
    # Fixed value, not exposed through `args`.
    l=1,
    # The same batch size is used for training and validation.
    train_batch_size=args.batch_size,
    val_batch_size=args.batch_size,
    # Both length fields are read from the training loader.
    max_seq_len=train_loader.max_length,
    max_length=train_loader.max_length,
    # Fields forwarded under the same name.
    **{
        field: getattr(args, field)
        for field in (
            "temperature",
            "K",
            "max_t_search",
            "use_rhythmic_placeholders",
            "use_spike_placeholders",
            "use_special_placeholders",
            "special_token_id",
            "abstract_budget",
            "temperature_flip",
            "curriculum_ratio",
            "use_fade_memory",
            "use_compression_mask",
            "train_iterations",
            "default_phase",
            "delta",
            "tau",
            "p_m",
            "p_c",
        )
    },
)

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


--- Initializing components ---
Model initialized on cpu with 2.96M parameters.


$\textbf{Bug 1}$. Two-digit multiplication is not learned perfectly, suggesting an issue with the data / training pipeline (SoRL).
- The loss is dropping (so the optimizer is fine and the loss function is consistent; potentially a data issue).
- The trained model does not generate an answer, suggesting an issue with the loss function.
- $\textbf{Fix 1}$. Found an error in `compute_loss`: it is not compatible with a `.long`-typed `loss_mask` (it ends up repeatedly slicing the index-1 element instead of selecting the positions where the mask is True).

$\textbf{Idea 1}$. We inherit the language-modeling vocabulary, which splits '1' and ' 1' into separate tokens. This makes the arithmetic rule MUCH more complex (the combinatorial space is much larger, so the Rademacher complexity grows, which increases the generalization error) --> It's worth trying a small, concise tokenizer.

$\textbf{Idea 2}$. Topological similarity can potentially explain the trend of generalization error vs. BPE tokenization size (multiplication-specific; currently our hypothesis), and it is a general metric that can easily be adapted to non-multiplication tasks as well. SoRL should build on a basic vocabulary in order to maximize the topological similarity and demonstrate its efficiency here.

$\textbf{Progress 1}$. Build a collection of tokenizers, ranging from the digit tokenizer up to BPE tokenizers of increasing size.

In [2]:
from src.sorl import evaluate 
from src.sorl import compute_per_token_loss, compute_loss, sorl_search, SearchScheduler, GatedPhaseTransition

# First, test out baseline performance, then test out SoRL performance etc.

optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
search_scheduler = SearchScheduler(sorl_config)
# NOTE(review): `gapt` is constructed here but is not consulted by the loop below.
gapt = GatedPhaseTransition(sorl_config.delta, sorl_config.tau, sorl_config.p_m, sorl_config.p_c)

# ==============================================================================
# 3. Interactive Training Loop
# ==============================================================================
print("\n--- Starting interactive training loop ---")
model.train()
for i in range(sorl_config.train_iterations):  # one optimizer step per iteration
    # --- Scheduler Step ---
    # Baseline run: the scheduler is still stepped, but the values it returns
    # are not applied. Search depth and drop ratio are pinned to zero instead.
    t_search, drop_ratio = search_scheduler.step()
    sorl_config.max_t_search = 0
    model.drop_ratio = 0.0

    # --- Get data and perform SORL search ---
    # (1). Apply loss mask (and change its shape with abs padding) || (2). Customize abs padding
    data, loss_mask = train_loader.get_batch(sorl_config.train_batch_size)
    with torch.no_grad():  # the search itself is not differentiated through
        search_data, switch_ratio = sorl_search(data, loss_mask, model, sorl_config)

    # --- Compute loss ---
    ppt = compute_per_token_loss(model, search_data)
    ssl_loss, abs_loss = compute_loss(search_data, model, ppt, loss_mask)

    total_loss = ssl_loss + abs_loss

    # --- Optimizer step ---
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # --- Logging ---
    # The log reports the *effective* search depth and drop ratio (the values
    # pinned above), not the scheduler's proposals, so each printed line
    # describes the settings this step was actually run with.
    greedy_advantage, best_advantage, greedy_info_gain, _, a_loss = evaluate(data, loss_mask, sorl_config, model, search_n=1)
    print(
        f"Step {i+1:02d} | "
        f"Loss: {total_loss.item():.2f} (SSL: {ssl_loss.item():.3f}, Abs: {abs_loss.item():.2f}) | "
        f"Advantage: {greedy_advantage:.1f}% | Info-Gain: {greedy_info_gain:.1f}% | Abs-Free-Loss: {a_loss:.3f} | "
        f"t_search: {sorl_config.max_t_search} | "
        f"drop_ratio: {model.drop_ratio:.2f}"
    )



--- Starting interactive training loop ---
Step 01 | Loss: 3.44 (SSL: 3.438, Abs: 0.00) | Advantage: 0.0% | Info-Gain: -87.8% | Abs-Free-Loss: 1.197 | t_search: 0 | drop_ratio: 0.00
Step 02 | Loss: 2.48 (SSL: 2.484, Abs: 0.00) | Advantage: -0.0% | Info-Gain: -96.9% | Abs-Free-Loss: 1.056 | t_search: 0 | drop_ratio: 0.00
Step 03 | Loss: 2.11 (SSL: 2.112, Abs: 0.00) | Advantage: -0.0% | Info-Gain: -95.1% | Abs-Free-Loss: 0.976 | t_search: 0 | drop_ratio: 0.00
Step 04 | Loss: 1.87 (SSL: 1.866, Abs: 0.00) | Advantage: 0.0% | Info-Gain: -100.5% | Abs-Free-Loss: 0.861 | t_search: 0 | drop_ratio: 0.00
Step 05 | Loss: 1.69 (SSL: 1.689, Abs: 0.00) | Advantage: 0.0% | Info-Gain: -95.1% | Abs-Free-Loss: 0.805 | t_search: 0 | drop_ratio: 0.00
Step 06 | Loss: 1.38 (SSL: 1.377, Abs: 0.00) | Advantage: 0.0% | Info-Gain: -106.0% | Abs-Free-Loss: 0.607 | t_search: 0 | drop_ratio: 0.00
Step 07 | Loss: 1.24 (SSL: 1.236, Abs: 0.00) | Advantage: -0.0% | Info-Gain: -100.5% | Abs-Free-Loss: 0.563 | t_search

In [9]:
from eval_multiply import compute_topological_similarity, evaluate_multiplication

# --- Topological Similarity Metric ---
print("--- Computing Topological Similarity of Number Embeddings (10-99) ---")
# Every two-digit number, "10" through "99", as a string.
number_strings = list(map(str, range(10, 100)))
correlation = compute_topological_similarity(model, tokenizer, number_strings)
print(f"Topological Similarity: {correlation:.2f}")


# --- Evaluation (Generate & Check Answer) ----
input_ids, _ = val_loader.get_batch(2)

# Decode the first sample of the batch, then remove the pad-token text and
# surrounding whitespace so only the prompt itself remains.
test_prompt = (
    tokenizer.decode(input_ids[0])
    .replace(tokenizer.pad_token, '')
    .strip()
)

print("--- Running Evaluation ---")
evaluate_multiplication(model, tokenizer, test_prompt)

--- Computing Topological Similarity of Number Embeddings (10-99) ---
Topological Similarity: 0.42
--- Running Evaluation ---
Query: 8 * 4 = <answer>
Generated Response: 3 2 <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
Expected Answer:  3 2
Generated Answer: 3 2
