In [1]:
from types import SimpleNamespace
from transformers import AutoTokenizer
import torch
from model.model_minimind import MiniMindConfig
from model.model_sorl import SorlModelWrapper
from dataset.base import MemLoader
from src.sorl import SORLConfig

# ==============================================================================
# 1. Configuration (Mimicking command-line args)
# ==============================================================================
# Settings are declared per concern and then flattened into one namespace, so
# the rest of the notebook reads `args.<name>` exactly as with argparse output.
# Names must be unique across groups: on a clash the later group would win.
_ARG_GROUPS = {
    "paths": {
        "train_data_path": "dataset/multiply_train.bin",
        "val_data_path": "dataset/multiply_val.bin",
        "tokenizer_path": "model/",
    },
    "model": {
        "hidden_size": 256,
        "num_hidden_layers": 4,
        "num_attention_heads": 4,
        # Comma-separated string; split into a list of ints at model setup.
        "abstract_vocab_sizes": "8",
    },
    "training": {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "batch_size": 4,
        "learning_rate": 3e-4,
    },
    "sorl": {
        "n_rollout": 5,
        "temperature": 1.0,
        "K": 4,
        "denoise_steps": 1,
        "max_t_search": 1,
        "use_rhythmic_placeholders": False,
        "use_spike_placeholders": False,
        "use_special_placeholders": True,
        "special_token_id": 31,
        "abstract_budget": 5,
        "temperature_flip": False,
    },
    "curriculum_and_memory": {
        # Currently looks redundant; it (vaguely) conflicts with the
        # "compositionality" principle.
        "curriculum_ratio": 0.6,
        "train_iterations": 400,  # also consumed by the scheduler
        "use_fade_memory": False,
        "use_compression_mask": True,  # True exercises the compression mask
        "compression_curriculum_ratio": 0.25,
        "memory_span": 20,
    },
    "gapt": {
        "default_phase": None,  # 1 or 2 forces a phase; None enables GAPT
        "delta": 0.01,
        "tau": 0.1,
        "p_m": 10,
        "p_c": 10,
    },
}

args = SimpleNamespace(
    **{name: value for group in _ARG_GROUPS.values() for name, value in group.items()}
)

# ==============================================================================
# 2. Initialization
# ==============================================================================
print("--- Initializing components ---")

# --- Tokenizer and Data ---
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
# NOTE(review): `pad_token_id` is None when the tokenizer defines no pad
# token; confirm SorlModelWrapper tolerates that.
pad_token_id = tokenizer.pad_token_id
# Train and val loaders are built identically, differing only by data path.
train_loader, val_loader = (
    MemLoader(path, device=args.device)
    for path in (args.train_data_path, args.val_data_path)
)

# --- Model ---
# The full vocabulary is the base tokenizer vocabulary followed by one block
# per abstract level; the embedding table covers their sum.
# NOTE(review): Hugging Face `vocab_size` excludes tokens added after the base
# vocabulary; if this tokenizer has added tokens, `len(tokenizer)` is larger
# and abstract ids would overlap them — confirm.
base_vocab_size = tokenizer.vocab_size
abstract_vocab_sizes = list(map(int, args.abstract_vocab_sizes.split(",")))
full_vocab_list = [base_vocab_size, *abstract_vocab_sizes]

minimind_config = MiniMindConfig(
    vocab_size=sum(full_vocab_list),
    num_hidden_layers=args.num_hidden_layers,
    num_attention_heads=args.num_attention_heads,
    hidden_size=args.hidden_size,
)

model = SorlModelWrapper.from_scratch(
    config=minimind_config,
    full_vocab_size_list=full_vocab_list,
    memory_span=args.memory_span,
    pad_token_id=pad_token_id,
)
model = model.to(args.device)

# Total parameter count of the wrapped model, reported in millions.
n_params = sum(p.numel() for p in model.parameters())
print(f"Model initialized on {args.device} with {n_params / 1e6:.2f}M parameters.")

# --- SORL Config and Schedulers ---
# Nearly every field comes from `args`. Exceptions: `l` is fixed at 1, and both
# length fields follow the training loader's `max_length`.
# NOTE(review): `args.compression_curriculum_ratio` is not forwarded here;
# confirm it is consumed elsewhere, or whether SORLConfig should receive it.
sorl_config = SORLConfig(
    # Rollout / denoising search
    n=args.n_rollout,
    temperature=args.temperature,
    temperature_flip=args.temperature_flip,
    K=args.K,
    l=1,  # hard-coded; not exposed through `args`
    steps=args.denoise_steps,
    max_t_search=args.max_t_search,
    # Placeholders and abstraction budget
    use_rhythmic_placeholders=args.use_rhythmic_placeholders,
    use_spike_placeholders=args.use_spike_placeholders,
    use_special_placeholders=args.use_special_placeholders,
    special_token_id=args.special_token_id,
    abstract_budget=args.abstract_budget,
    # Curriculum and memory
    curriculum_ratio=args.curriculum_ratio,
    use_fade_memory=args.use_fade_memory,
    use_compression_mask=args.use_compression_mask,
    min_keep=args.memory_span,  # same span the model wrapper was built with
    # Sequence length, schedule and batching
    max_seq_len=train_loader.max_length,
    max_length=train_loader.max_length,
    train_iterations=args.train_iterations,
    train_batch_size=args.batch_size,
    val_batch_size=args.batch_size,
    # GAPT phase control
    default_phase=args.default_phase,
    delta=args.delta,
    tau=args.tau,
    p_m=args.p_m,
    p_c=args.p_c,
)

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


--- Initializing components ---
Model initialized on cpu with 4.59M parameters.
