$\textbf{SoRL (GAT)}$
1. Group advantage computation 
2. Surrogate loss computation

The key to learning from experience is learning from failure.

In [1]:
from dataset.arithmetic import ArithmeticDataset


# Generate 2,000 synthetic arithmetic sequences and save them to `filepath`
# (the output below confirms the save). min_digit/max_digit presumably bound the
# operand digit count; L and K presumably must match the GATConfig used later
# (L=2, K=3) — TODO confirm in dataset/arithmetic.py.
dataset = ArithmeticDataset(
    min_digit=1,
    max_digit=3,
    L=2,
    K=3,
    num_data=2000,
    filepath="dataset/multiplication/2K-123.bin"
).build()

100%|██████████| 2000/2000 [00:00<00:00, 1085763.40it/s]

Saved 2000 sequences to dataset/multiplication/2K-123.bin





In [1]:
# learn-to-explain (GAT) | one batch ver.

# (1). Initialize Data Buffer with trajectory-only data 

from model import GATConfig, GAT
from dataset.arithmetic import ArithmeticDataset

from dataclasses import asdict
from search import SORLConfig 
import wandb

# Small CPU model with L=2 levels; vocab_size_list has one vocabulary size per
# level (presumably [trajectory, abstraction] — confirm in model.py).
gat_config = GATConfig(K=3, L=2, n_embd=128, n_head=4, n_layer=4, device="cpu", _compile=False,
                       vocab_size_list=[17, 8])
gat = GAT(gat_config)


# The validation paths must match the file written by the dataset-build cell
# ("2K-123.bin", capital K). The previous "2k-123.bin" only resolved on
# case-insensitive filesystems.
# TODO: the OOD path currently points at the same file as the ID path.
config = SORLConfig(gat_config=gat_config, 
           n_generations=4, temperature=1.0, num_iterations=2, 
           joint_steps=10, context_length=1024, learning_rate=1e-3,
           dataset_name="100K-123", 
           dataset_path="dataset/multiplication/100K-123.bin",
           id_validate_dataset_path="dataset/multiplication/2K-123.bin",
           ood_validate_dataset_path="dataset/multiplication/2K-123.bin")

# load dataset
# NOTE(review): this loads the 2K *validation* file as the working dataset
# (fine for one-batch debugging); use config.dataset_path for full runs.
dataset = ArithmeticDataset.from_file(config.id_validate_dataset_path)
# id_val_dataset = ArithmeticDataset.from_file(config.id_validate_dataset_path)


# gat.load_checkpoint("experiment/nbody/SoRL-GRPO-per-token-alternate-nbody.pt")

In [1]:
from gat import GATConfig, reGAT
from sorl import generate_rollout_data
from search import repeat_hseq # Assuming repeat_hseq is in search.py
from utils import HierSeq
import torch 

# NOTE(review): this import rebinds `GATConfig` to gat.GATConfig, shadowing the
# model.GATConfig imported earlier; re-running the config cell afterwards will
# use this class.

print("\n--- Testing generate_rollout_data ---")

# 1. Setup a dummy model and data
# Use a dedicated name: rebinding the global `config` here would clobber the
# SORLConfig built earlier (the training cell reads `config.t_curriculum`).
regat_config = GATConfig(L=2, K=3, vocab_size_list=[10, 5], device='cpu')
model = reGAT(regat_config)
model.eval()

# Two toy sequences of level-0 tokens with no abstractions yet
# (not consumed by the visible cells).
h_data = [
    ([1, 2, 3, 4, 5, 6, 7], []), # L0 tokens, no abstractions
    ([1, 2, 3, 4, 5, 6, 7], [])
]


--- Testing generate_rollout_data ---


#### Refactor: What if we put each sample in its own row of the batch dimension?

In [2]:
# util functions 
# ------------------------------------------------------------------------
def infer_level(indices: torch.Tensor, vocab_sizes: torch.Tensor, pad_token: int):
    """Map each token id to the hierarchy level whose id range contains it.

    Level ``i`` owns ids in ``[sum(vocab_sizes[:i]), sum(vocab_sizes[:i + 1]))``.
    Positions equal to ``pad_token`` are reported as level -1. Ids at or beyond
    the total vocabulary size fall back to level 0 (argmax of an all-False row).

    Args:
        indices: integer token ids, e.g. ``[batch_size, seq_len]``.
        vocab_sizes: 1-D tensor with one vocabulary size per level.
        pad_token: id treated as padding.

    Returns:
        Long tensor shaped like ``indices`` holding per-token levels.
    """
    upper_bounds = torch.cumsum(vocab_sizes, dim=0)             # exclusive upper id per level
    fits_level = indices[..., None] < upper_bounds              # [..., n_levels]
    level_of_token = fits_level.to(torch.int32).argmax(dim=-1)  # first level whose range fits
    return level_of_token.long().masked_fill(indices == pad_token, -1)

# this produces 'abstract timestamps'
def infer_timestamp(levels: torch.Tensor, K: int, l: int = 1) -> torch.Tensor:
    """Assign each position the index of its group of ``K`` level-``(l-1)`` tokens.

    A position's timestamp is
    ``(number of level-(l-1) tokens up to and including it - 1) // K``, clamped at 0.
    Tokens of other levels therefore inherit the group of the most recent
    level-``(l-1)`` token. Positions before the first such token land in group 0.
    """
    seen_lower = torch.cumsum((levels == l - 1).long(), dim=-1)
    # Positions preceding the first level-(l-1) token would yield -1; clamp them to group 0.
    return torch.div(seen_lower - 1, K, rounding_mode="floor").clamp(min=0)

# Rhythmic insertion mask calculation
def infer_rythmic_insertion_mask(levels: torch.Tensor, timestamps: torch.Tensor, K: int, l: int): 
    """Mark the positions after which a level-``l`` placeholder should be inserted.

    A position is marked when it is the last element of its timestamp group and
    that group has at least ``K`` members. Tokens above level ``l`` are folded
    into timestamp 0 before grouping.

    Args:
        levels: per-token levels (as from ``infer_level``), shape ``[B, S]``.
        timestamps: per-token non-negative group ids (as from ``infer_timestamp``),
            shape ``[B, S]``. Not modified.
        K: minimum group size for a placeholder to be inserted.
        l: level of the placeholder being inserted.

    Returns:
        Bool tensor ``[B, S]``; insert a placeholder right after each True position.
    """
    # Fold out-of-level tokens into group 0 on a *copy*. The previous in-place
    # `timestamps[...] = False` silently rewrote the caller's tensor.
    within_level_mask = (levels <= l)
    timestamps = timestamps.masked_fill(~within_level_mask, 0)

    B = timestamps.size(0)
    is_end_of_groups = torch.cat([
        (timestamps[:, :-1] != timestamps[:, 1:]),
        torch.full((B, 1), True, device=timestamps.device)
    ], dim=1)

    # NOTE(review): padding (level -1) satisfies `levels <= l`, so pad tokens count
    # toward group sizes and can make a short trailing group look complete.
    # Confirm this is intended.
    is_valid_elements = []
    for timestamp in timestamps:
        group_counts = torch.bincount(timestamp)   # members per timestamp group
        is_valid_group = group_counts >= K
        is_valid_elements.append(is_valid_group[timestamp])  # timestamps start at 0, so they index directly
    is_valid_elements = torch.stack(is_valid_elements, dim=0)

    insert_mask = is_end_of_groups & is_valid_elements
    return insert_mask # insert after 'True' position suffices

import torch.nn.functional as F

def insert_tokens(
    tokens: torch.Tensor,
    insert_masks: torch.Tensor,
    placeholder_token: int,
    pad_token: int
) -> torch.Tensor:
    """Insert ``placeholder_token`` right after every masked position, row by row.

    Every row is widened to ``S + max_insertions`` and right-padded with
    ``pad_token``. When no row has any insertion, ``tokens`` itself is returned.

    Args:
        tokens: ``[B, S]`` token ids.
        insert_masks: ``[B, S]`` mask; a placeholder follows each truthy entry.
        placeholder_token: id written at each inserted slot.
        pad_token: id used to fill the unused tail of shorter rows.

    Returns:
        ``[B, S + max_insertions]`` tensor with the same dtype/device as ``tokens``.
    """
    batch_size, orig_len = tokens.shape

    extra = insert_masks.sum(dim=1).max().item()
    if extra == 0:
        return tokens

    mask_long = insert_masks.long()
    # Number of placeholders inserted strictly before each original position.
    inserted_before = torch.cumsum(mask_long, dim=1) - mask_long
    positions = torch.arange(orig_len, device=tokens.device).unsqueeze(0) + inserted_before

    out = tokens.new_full((batch_size, orig_len + extra), pad_token)
    out.scatter_(1, positions, tokens)

    rows, cols = torch.nonzero(insert_masks, as_tuple=True)
    out[rows, positions[rows, cols] + 1] = placeholder_token
    return out

In [3]:
# We just need to record 'level' besides 'token idx'
# - in fact, given our per-level vocabulary size parameter, we can just record 'token_idx' as level can be inferred directly from it

# convert h_data into idx (batched ver, no flattening at all)

# hierarchical data (ordered) -- it mixes abstract & non-abstract tokens
# (with vocab_size_list=[10, 5], id 12 is a level-1 / abstract token)
data = torch.tensor([[1,2,3,12,4,5,6,7], [1,2,3,5,4,5,6,7]]) 

# infer level from 'idx' (that optionally contains 'abstract' token)
# levels = infer_level_from_idx(idx, model.vocab_sizes)

# forward propagation through GAT module
# (Option 1. remove the level-embedding part -- does not help in "search advantage")
# (Option 2. keep the level-embedding)
idx = data[:, :-1].contiguous()
target = data[:, 1:].contiguous() # .contiguous() is important to avoid errors from later '.view()' calls

# forward propagation : compute perplexity per token (traj & abstract)
ppt = model(idx, target)

# generate: returns next token (1 per sample) plus a kv-cache reused by the second call
next_idx, kv_cache = model.generate(idx, temperature=0.0)
next_idx, kv_cache = model.generate(next_idx.unsqueeze(1), temperature=0.0, kv_cache=kv_cache)

# denoise returns an updated idx; a variable number of tokens is updated per sample.
# Mark abstract positions (level >= 1). `levels.bool()` would also mark padding,
# whose level is -1 (nonzero -> True), so compare explicitly instead.
levels = infer_level(idx, model.vocab_sizes, model.level_mask_tokens[0])
denoise_mask = levels > 0
denoise_mask[1, 0] = True 
denoise_mask[1, 1] = True 

# denoise
updated_idx = model.denoise(idx, denoise_mask, temperature=0.0) # denoise return an updated idx

In [4]:
# ----------------------------------------------------------------------------------------------------------------




In [None]:
# Rhythmic placeholder insertion demo: after each complete group of K
# level-(l-1) tokens in `idx`, insert a level-l placeholder token.
vocab_sizes = model.vocab_sizes
K = model.K 
tokens = idx
l = 1  # level of the placeholders being inserted

# rhythmic placeholder insertion

levels = infer_level(tokens, vocab_sizes, model.level_mask_tokens[0])
timestamps = infer_timestamp(levels, K, l)

insert_masks = infer_rythmic_insertion_mask(levels, timestamps, K, l)

# model.level_mask_tokens[l] is the placeholder; model.level_mask_tokens[0] doubles
# as the pad token (presumably the last level-0 id — confirm in the model).
new_tokens = insert_tokens(tokens, insert_masks, model.level_mask_tokens[l], model.level_mask_tokens[0])

# Similarly, we can perform perplexity-based placeholder insertion

In [2]:
# Record & Save an annotated dataset
# (No save step is visible in this cell; presumably handled elsewhere — TODO confirm.)

# from dataset.base import BaseHierDataset
from dataset.arithmetic import ArithmeticHierDataset
from nil import annotate_abstraction
from nil import supervise_gat 

# Wrap the flat `dataset` in a hierarchical dataset (presumably so abstraction
# annotations can be attached to it).
record_dataset = ArithmeticHierDataset.from_dataset(dataset)

# Greedy Abstraction Annotation (Passing knowledge to the next generation)
# ------------------------------------------------------------------------
# The current `gat` annotates record_dataset with its abstractions.
record_dataset = annotate_abstraction(record_dataset, gat)


# Reset GAT module
# -------------------
# A freshly initialised GAT (same gat_config) is then trained on those annotations.
gat = GAT(gat_config)


# Weak Supervision (GAT)
# ------------------------------------------------------------------------
weak_iterations = 100 # requires tuning
context_length = 1024

# NOTE(review): later cells keep training `gat`, not `supervised_gat`. Confirm that
# supervise_gat updates `gat` in place; otherwise assign `gat = supervised_gat`.
supervised_gat = supervise_gat(record_dataset, gat, weak_iterations, context_length)


Iteration 1/100, loss: 4.9127, abs_loss: 2.0794, ssl_loss: 2.8332
Iteration 2/100, loss: 4.7362, abs_loss: 1.9370, ssl_loss: 2.7991
Iteration 3/100, loss: 4.4892, abs_loss: 1.7341, ssl_loss: 2.7552
Iteration 4/100, loss: 4.2146, abs_loss: 1.5094, ssl_loss: 2.7052
Iteration 5/100, loss: 3.9717, abs_loss: 1.3026, ssl_loss: 2.6691
Iteration 6/100, loss: 3.7408, abs_loss: 1.1191, ssl_loss: 2.6218
Iteration 7/100, loss: 3.5362, abs_loss: 0.9555, ssl_loss: 2.5807
Iteration 8/100, loss: 3.3784, abs_loss: 0.8085, ssl_loss: 2.5699
Iteration 9/100, loss: 3.2184, abs_loss: 0.6780, ssl_loss: 2.5404
Iteration 10/100, loss: 3.0805, abs_loss: 0.5644, ssl_loss: 2.5161
Iteration 11/100, loss: 2.9727, abs_loss: 0.4672, ssl_loss: 2.5055
Iteration 12/100, loss: 2.8609, abs_loss: 0.3852, ssl_loss: 2.4758
Iteration 13/100, loss: 2.7815, abs_loss: 0.3173, ssl_loss: 2.4642
Iteration 14/100, loss: 2.7241, abs_loss: 0.2617, ssl_loss: 2.4625
Iteration 15/100, loss: 2.6535, abs_loss: 0.2164, ssl_loss: 2.4371
Iter

Keep it beautifully simple

In [4]:
# Benchmark RL & SSL combination strategies 
# (I). Pick the best & learn it 
# -------------------------------------------------------------
# Each step: search abstractions with sorl_search_v2 (no grad), then train on the
# returned batch with the trajectory SSL loss plus the level-1 abstraction loss.
# t_search follows a curriculum, growing by t_delta per step up to t_max.
import copy 
import wandb
import torch
from search import compute_ssl_loss, get_batch, eval_search_improvement
from search import compute_abs_ssl_loss
from search import sorl_search_v2
from search import compute_curriculum_t_increment, eval_ppl_with_search, curriculum_iter

n = 3  # passed to sorl_search_v2; batches use context_length // n tokens (presumably n rollouts per sample)
temperature = 1.0
num_iterations = 200
context_length = 1024
global_step = 0 

optimizer = torch.optim.Adam(gat.parameters(), lr=1e-3)
gat.train() 

# curriculum
t_search = 0
t_delta, t_max = compute_curriculum_t_increment(num_iterations=num_iterations, context_length=context_length, K=gat.K, max_ts=max(dataset.lengths))

while global_step < num_iterations: 

    batch_data = get_batch(dataset.sequences, dataset.lengths, context_length // n, gat.L, gat.K)

    with torch.no_grad(): 
        repeat_batch, switch_ratio, rollout_advantages = sorl_search_v2(gat, batch_data, n, temperature, t_search) # pinned greedy sample ver.

    ppt = gat(repeat_batch)

    ssl_loss = compute_ssl_loss(repeat_batch, ppt)
    abs_loss = compute_abs_ssl_loss(repeat_batch, ppt, level=1)

    loss = abs_loss + ssl_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"Iteration {global_step+1}/{num_iterations}, loss: {loss.item():.4f}, abs_loss: {abs_loss.item():.4f}, ssl_loss: {ssl_loss.item():.4f}")

    global_step += 1
    t_search = min(t_search + t_delta, t_max)
    del loss, abs_loss, ssl_loss
    
    # train data ppl improvement: evaluation only, so don't build an autograd graph
    # (the validation call below is likewise wrapped in torch.no_grad()).
    with torch.no_grad():
        improve_ppl_train = eval_search_improvement(gat, batch_data, t_search=t_search)
    print(f"\nImprove ppl percentage (train): {improve_ppl_train:.4f}")
    # NOTE: t_search printed here is already advanced; switch_ratio came from the
    # search run at the previous t_search value.
    print(f"per-sample abstraction switch ratio: {switch_ratio:.4f} | t_search: {t_search} | (How often greedy sample is rejected by other abstraction)")
    # s = observe_abstraction(batch_data, gat, t_search=t_search, temperature=0.0)
    # print(s)

    # Periodic validation is disabled.
    # NOTE(review): `id_val_dataset` and `eval_generate_ppl` are not defined in the
    # visible cells (the id_val_dataset load is commented out), and `config` must
    # be the SORLConfig for `config.t_curriculum`. Define these before enabling.
    # if global_step % 10 == 0: 
    if False: 
        val_data = get_batch(id_val_dataset.sequences, id_val_dataset.lengths, context_length, gat.L, gat.K)
        
        with torch.no_grad(): 
            improve_ppl_val = eval_search_improvement(gat, val_data, t_search=t_search)
            print(f"Improve ppl percentage (val): {improve_ppl_val:.4f}\n")
        
            if t_search == t_max:
                traj_ppl_val = eval_ppl_with_search(val_data, gat, dataset.answer_token_id, n=6, temperature=1.0)
                print(f"Traj ppl (val): {traj_ppl_val.mean().item():.4f}\n")

            if not config.t_curriculum: 
                traj_ppl_val = eval_generate_ppl(gat, val_data, n=1, temperature=0.0, t_search=t_search).mean()
                print(f"Traj ppl (val): {traj_ppl_val.item():.4f}\n")

Iteration 1/200, loss: 3.2189, abs_loss: 0.0000, ssl_loss: 3.2189

Improve ppl percentage (train): 0.0000
per-sample abstraction switch ratio: 0.0000 | t_search: 1 | (How often greedy sample is rejected by other abstraction)
Iteration 2/200, loss: 3.1475, abs_loss: 0.0000, ssl_loss: 3.1475

Improve ppl percentage (train): 0.0376
per-sample abstraction switch ratio: 0.0000 | t_search: 2 | (How often greedy sample is rejected by other abstraction)
Iteration 3/200, loss: 6.4414, abs_loss: 3.3758, ssl_loss: 3.0656

Improve ppl percentage (train): 0.0855
per-sample abstraction switch ratio: 0.4000 | t_search: 3 | (How often greedy sample is rejected by other abstraction)
Iteration 4/200, loss: 6.2989, abs_loss: 3.2920, ssl_loss: 3.0069

Improve ppl percentage (train): 0.1368
per-sample abstraction switch ratio: 0.2857 | t_search: 4 | (How often greedy sample is rejected by other abstraction)
Iteration 5/200, loss: 6.1441, abs_loss: 3.1884, ssl_loss: 2.9557

Improve ppl percentage (train): 0

KeyboardInterrupt: 