In [2]:
# Multiplication Dataset (cot included/not, reverse/not)
# - (no cot, no reverse) 11 x 12 = <answer> 132 
# - (cot, no reverse) 11 x 12 = 12 + 120 = <answer> 132
# - (no cot, reverse) 11 x 21 = <answer> 231 
# - (cot, reverse) 11 x 21 = 21 + 021 = <answer> 231 

In [6]:
from types import SimpleNamespace
from transformers import AutoTokenizer
import torch
from model.model_minimind import MiniMindConfig
from model.model_sorl import SorlModelWrapper
from dataset.base import MemLoader
from src.sorl import SORLConfig
import os 

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ==============================================================================
# 1. Configuration (Mimicking command-line args)
# ==============================================================================
args = SimpleNamespace(
    # --- Paths ---
    train_data_path="dataset/multiply/multiply_2x2_train.bin",
    val_data_path="dataset/multiply/multiply_2x2_val.bin",
    
    # --- Model Config ---
    hidden_size=768,
    num_hidden_layers=4,
    num_attention_heads=2,
    abstract_vocab_sizes="8",
    
    # --- Training Config ---
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=128,
    learning_rate=1e-4,
    epoch=3,
    
    # --- SORL Config ---
    n_rollout=5,
    temperature=1.0,
    K=4,
    denoise_steps=1,
    max_t_search=0,
    use_rhythmic_placeholders=True,
    use_spike_placeholders=False,
    use_special_placeholders=False,
    special_token_id=31,
    abstract_budget=5,
    temperature_flip=False,
    
    # --- Curriculum and Memory ---
    curriculum_ratio=0.6, # looks redundant as of now, it (vaguely) violates the "compositionality" principle
    use_fade_memory=False,
    use_compression_mask=False, # <-- Set to True to test your new mask
    compression_curriculum_ratio=0.25,
    memory_span=128,
    
    # --- GAPT ---
    default_phase=None, # Set to 1 or 2 to override, None to enable GAPT
    delta=0.01,
    tau=0.1,
    p_m=10,
    p_c=10
)

# ==============================================================================
# 2. Initialization
# ==============================================================================
print("--- Initializing components ---")
# --- Tokenizer and Data ---
train_loader = MemLoader(args.train_data_path, device=args.device)
val_loader = MemLoader(args.val_data_path, device=args.device)
tokenizer = AutoTokenizer.from_pretrained(train_loader.tokenizer_path) # data is tokenized
pad_token_id = tokenizer.pad_token_id

# --- Model ---
base_vocab_size = len(tokenizer)
abstract_vocab_sizes = [int(v) for v in args.abstract_vocab_sizes.split(',')]
full_vocab_list = [base_vocab_size] + abstract_vocab_sizes

# 2 layer 4 head
minimind_config = MiniMindConfig(
    hidden_size=args.hidden_size,
    num_attention_heads=args.num_attention_heads,
    num_hidden_layers=args.num_hidden_layers,
    vocab_size=sum(full_vocab_list)
)

model = SorlModelWrapper.from_scratch(
    config=minimind_config,
    full_vocab_size_list=full_vocab_list,
    memory_span=args.memory_span,
    pad_token_id=pad_token_id
).to(args.device)

print(f"Model initialized on {args.device} with {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters.")

# --- SORL Config and Schedulers ---

sorl_config = SORLConfig(
    n=args.n_rollout, 
    temperature=args.temperature, 
    K=args.K,
    l=1, 
    steps=args.denoise_steps, 
    max_t_search=args.max_t_search,
    use_rhythmic_placeholders=args.use_rhythmic_placeholders,
    use_spike_placeholders=args.use_spike_placeholders,
    use_special_placeholders=args.use_special_placeholders,
    special_token_id=args.special_token_id,
    abstract_budget=args.abstract_budget,
    temperature_flip=args.temperature_flip,
    curriculum_ratio=args.curriculum_ratio,
    use_fade_memory=args.use_fade_memory,
    use_compression_mask=args.use_compression_mask,
    min_keep=args.memory_span, 
    max_seq_len=train_loader.max_length,
    train_iterations=int(args.epoch * len(train_loader) / args.batch_size), 
    train_batch_size=args.batch_size,
    val_batch_size=args.batch_size,
    max_length=train_loader.max_length,
    default_phase=args.default_phase, 
    delta=args.delta, tau=args.tau,
    p_m=args.p_m, p_c=args.p_c
)

--- Initializing components ---
Model initialized on cpu with 28.34M parameters.


In [None]:
from src.sorl import evaluate 
from src.sorl import compute_per_token_loss, compute_loss, sorl_search, SearchScheduler, GatedPhaseTransition

# First, test out baseline performance, then test out SoRL performance etc.

optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
search_scheduler = SearchScheduler(sorl_config)
gapt = GatedPhaseTransition(sorl_config.delta, sorl_config.tau, sorl_config.p_m, sorl_config.p_c)

# ==============================================================================
# 3. Interactive Training Loop
# ==============================================================================
print("\n--- Starting interactive training loop ---")
model.train()


for i in range(sorl_config.train_iterations): # Run for 10 steps
    # --- Scheduler Step ---
    t_search, drop_ratio = search_scheduler.step()
    sorl_config.max_t_search = 0
    model.drop_ratio = 0.0

    # --- Get data and perform SORL search ---
    # (1). Apply loss mask (and change its shape with abs padding) || (2). Customize abs padding
    data, loss_mask = train_loader.get_batch(sorl_config.train_batch_size)
    with torch.no_grad():
        search_data, switch_ratio = sorl_search(data, loss_mask, model, sorl_config)
        
    # --- Compute loss ---
    ppt = compute_per_token_loss(model, search_data)
    ssl_loss, abs_loss = compute_loss(search_data, model, ppt, loss_mask)
    
    total_loss = ssl_loss + abs_loss
    
    # --- Optimizer step ---
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    # --- Logging ---
    greedy_advantage, best_advantage, greedy_info_gain, _, a_loss = evaluate(data, loss_mask, sorl_config, model, search_n=1)
    print(
        f"Step {i+1:02d} | "
        f"Loss: {total_loss.item():.2f} (SSL: {ssl_loss.item():.3f}, Abs: {abs_loss.item():.2f}) | "
        f"Advantage: {greedy_advantage:.1f}% | Info-Gain: {greedy_info_gain:.1f}% | Abs-Free-Loss: {a_loss:.3f} | "
        f"t_search: {t_search} | "
        f"drop_ratio: {model.drop_ratio:.2f}"
    )


- No abstraction allowed

--- Starting interactive training loop ---
Step 01 | Loss: 3.46 (SSL: 3.461, Abs: 0.00) | Advantage: 0.0% | Info-Gain: -92.8% | Abs-Free-Loss: 1.351 | t_search: 0 | drop_ratio: 0.00
Step 02 | Loss: 2.66 (SSL: 2.664, Abs: 0.00) | Advantage: 0.0% | Info-Gain: -90.8% | Abs-Free-Loss: 1.307 | t_search: 0 | drop_ratio: 0.00
Step 03 | Loss: 2.56 (SSL: 2.557, Abs: 0.00) | Advantage: -0.0% | Info-Gain: -90.2% | Abs-Free-Loss: 1.212 | t_search: 0 | drop_ratio: 0.00
Step 04 | Loss: 2.34 (SSL: 2.340, Abs: 0.00) | Advantage: 0.0% | Info-Gain: -90.2% | Abs-Free-Loss: 1.102 | t_search: 0 | drop_ratio: 0.00
Step 05 | Loss: 2.13 (SSL: 2.131, Abs: 0.00) | Advantage: 0.0% | Info-Gain: -90.2% | Abs-Free-Loss: 1.077 | t_search: 0 | drop_ratio: 0.00
Step 06 | Loss: 2.09 (SSL: 2.093, Abs: 0.00) | Advantage: 0.0% | Info-Gain: -90.2% | Abs-Free-Loss: 1.007 | t_search: 0 | drop_ratio: 0.00
Step 07 | Loss: 1.91 (SSL: 1.908, Abs: 0.00) | Advantage: 0.0% | Info-Gain: -91.1% | Abs-Free-Lo

KeyboardInterrupt: 

In [5]:
583 * 128 / 80800

0.9235643564356436

$\textbf{Question} 1$. What'd happen if the topological similarity approaches 1, and how can we make it so?
$\textbf{Question} 2$. 

$\textbf{Observation 1}$. 200 epochs is far from enough for minimid model to learn 2x2 multiplication. Think 2k epochs at least. We could also use a bigger batch size. 

In [21]:
from eval_multiply import evaluate_on_loader

# --- Evaluation (Generate & Check Answer) ---- 
print("--- Running Evaluation ---")
evaluate_on_loader(model, tokenizer, val_loader, batch_size=10, K=None)

--- Running Evaluation ---


Evaluating Batches: 100%|██████████| 20/20 [00:12<00:00,  1.61it/s, Accuracy=15.50%]


--- Evaluation Summary ---
Samples Evaluated: 200
Correct Predictions: 31
Accuracy: 15.50%
Topological Similarity: 0.48





{'accuracy': 15.5,
 'correct': 31,
 'total': 200,
 'top_sim_score': 0.48200833816144695}

$\textbf{Record 1.}$.  1.5 epochs (that's 80k * 1.5 = 120k data getting trained, with batch size of 128, roughly 1k iterations required) on 2x2 multiplication produces 66.5% accuracy on test set. 

In [8]:
from transformers import GPT2LMHeadModel

# Load the pretrained "gpt2" model (which is the 124M parameter version)
model = GPT2LMHeadModel.from_pretrained('gpt2')

# You can print the model architecture to see its layers
# print(model)

# To get the total number of parameters, you can do the following:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

# And to see how many of those are trainable:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params:,}")

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Total parameters: 124,439,808
Trainable parameters: 124,439,808


In [9]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [20]:
from eval_multiply import evaluate_multiplication

input_ids, loss_mask = val_loader.get_batch(1)
prompt = tokenizer.decode(input_ids[0])
evaluate_multiplication(model, tokenizer, prompt, K=None)

Query: 1 2 * 4 8 =
Generated Response: <answer> 5 5 6 <eos> <eos> <eos> <eos> <eos> <eos>
Expected Answer:  5 7 6
Generated Answer: 5 5 6


('5 5 6', '5 7 6')

In [5]:
# Validate train_loader data is correct 
from eval_multiply import _get_query_and_gt_ids

input_ids, loss_mask = train_loader.get_batch(1)
prompt = tokenizer.decode(input_ids[0])

query_str = prompt.split('=')[0].strip() # execute this


$\textbf{Idea 3}$. How about using RL instead? 

$\textbf{Idea 4}$. How about using SoRL instead? 

$\textbf{Idea 5}$. How about using SoRL + RL instead? 

Issue 1. Missing CoT generations. 
- $\textit{Fix 1}$. Dynamically identify case requiring CoT and case that doesn't. 

Issue 2. I can't believe transformer can't learn multiplication (2 digits !?) -- Can we initialize a Qwen architecture and try on this again? Can we scale up the experiment script and run on GPU instead?

$\textbf{Issue 1}$. When including abstraction, the generated response don't 'stop' anymore. 

$\textbf{Issue 2}$. Abstraction generation in 'train-time' mismatch with that of 'inference-time' (former is parallel search over query, latter is causal generation over answer).

$\textbf{Idea 1}$. We have mismatch between 'train-time' abstraction addition (which is on query), and 'test-time' abstraction addition (which is on answer). It's probably better to add abstraction on query, or prefix token sequence only.

In [None]:
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM

# ==============================================================================
# 2. Initialization
# ==============================================================================
print("--- Initializing components ---")
# --- Tokenizer and Data ---
train_loader = MemLoader(args.train_data_path, device=args.device)
val_loader = MemLoader(args.val_data_path, device=args.device)
tokenizer = AutoTokenizer.from_pretrained(train_loader.tokenizer_path) # data is tokenized
pad_token_id = tokenizer.pad_token_id

# --- Model ---
base_vocab_size = len(tokenizer)
abstract_vocab_sizes = [int(v) for v in args.abstract_vocab_sizes.split(',')]
full_vocab_list = [base_vocab_size] + abstract_vocab_sizes

# 2 layer 4 head
minimind_config = MiniMindConfig(
    hidden_size=args.hidden_size,
    num_attention_heads=args.num_attention_heads,
    num_hidden_layers=args.num_hidden_layers,
    vocab_size=sum(full_vocab_list)
)

model = MiniMindForCausalLM(minimind_config)

--- Initializing components ---


In [8]:
from src.sorl import evaluate 
from src.sorl import compute_per_token_loss, compute_loss, sorl_search, SearchScheduler, GatedPhaseTransition

# First, test out baseline performance, then test out SoRL performance etc.

optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

# ==============================================================================
# 3. Interactive Training Loop
# ==============================================================================
print("\n--- Starting interactive training loop ---")
model.train()

for i in range(sorl_config.train_iterations): # Run for 10 steps

    # --- Get data and perform SORL search ---
    # (1). Apply loss mask (and change its shape with abs padding) || (2). Customize abs padding
    data, loss_mask = train_loader.get_batch(sorl_config.train_batch_size)
    # break
    # --- Compute loss ---
    ppt = compute_per_token_loss(model, data, tokenizer.pad_token_id)
    ssl_loss = (ppt * loss_mask[:, 1:]).sum() / loss_mask[:, 1:].sum()
    total_loss = ssl_loss
    
    # --- Optimizer step ---
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    # --- Logging ---
    print(
        f"Step {i+1:02d} | "
        f"Loss: {total_loss.item():.2f}"
    )


--- Starting interactive training loop ---
Step 01 | Loss: 1.24
Step 02 | Loss: 1.39
Step 03 | Loss: 1.21
Step 04 | Loss: 1.34
Step 05 | Loss: 1.36
Step 06 | Loss: 1.30
Step 07 | Loss: 1.32
Step 08 | Loss: 1.23
Step 09 | Loss: 1.22
Step 10 | Loss: 1.26
Step 11 | Loss: 1.28
Step 12 | Loss: 1.20
Step 13 | Loss: 1.22
Step 14 | Loss: 1.17
Step 15 | Loss: 1.18
Step 16 | Loss: 1.23
Step 17 | Loss: 1.21
Step 18 | Loss: 1.16
Step 19 | Loss: 1.26
Step 20 | Loss: 1.13
Step 21 | Loss: 1.13
Step 22 | Loss: 1.24
Step 23 | Loss: 1.12
Step 24 | Loss: 1.23
Step 25 | Loss: 1.11
Step 26 | Loss: 1.11
Step 27 | Loss: 1.17
Step 28 | Loss: 1.21
Step 29 | Loss: 1.19
Step 30 | Loss: 1.20
Step 31 | Loss: 1.26
Step 32 | Loss: 1.18
Step 33 | Loss: 1.15
Step 34 | Loss: 1.09
Step 35 | Loss: 1.14
Step 36 | Loss: 1.16
Step 37 | Loss: 1.05
Step 38 | Loss: 1.12
Step 39 | Loss: 1.20
Step 40 | Loss: 1.06
Step 41 | Loss: 1.18
Step 42 | Loss: 1.14
Step 43 | Loss: 1.11
Step 44 | Loss: 1.17
Step 45 | Loss: 1.19
Step 46 | L