$\textbf{SoRL (GAT)}$
1. Group advantage computation 
2. Surrogate loss computation

The key for learning from experience is learning from failure

In [2]:
from dataset.arithmetic import ArithmeticDataset

dataset = ArithmeticDataset(
    min_digit=1,
    max_digit=3,
    num_data=2000,
    filepath="dataset/multiplication/2K-123.bin"
).build()

100%|██████████| 2000/2000 [00:00<00:00, 621700.73it/s]

Saved 2000 sequences to dataset/multiplication/2K-123.bin





In [1]:
from src import GATConfig, GAT
from src.sorl import SORLConfig
import torch 
from dataset.arithmetic import ArithmeticDataset

dataset = ArithmeticDataset.from_file("dataset/multiplication/2K-123.bin")

# 1. Setup a dummy model and data
config = GATConfig(L=2, K=3, vocab_size_list=[dataset.vocab_size_list[0], 5], device='cpu')
model = GAT(config)
model.eval()


print(f"Initialized GAT model")

Initialized GAT model


In [2]:
from src import infer_level
from copy import deepcopy 
from src.sorl import prep_denoise
from dataset.base import get_batch

data = get_batch(dataset, batch_size=10, max_length=1024, pad_token_id=model.level_mask_tokens[0])

# forward propagation : compute perplexity per token (traj & abstract)
idx = data[:, :-1].contiguous().clone()
target = data[:, 1:].contiguous().clone() 
ppt = model(idx, target)

# causal generate return next_token (1 per sample)
idx = data.contiguous()
next_idx, kv_cache, levels = model.generate(idx, temperature=0.0)
next_idx, kv_cache, levels = model.generate(next_idx.unsqueeze(1), temperature=0.0, kv_cache=kv_cache, levels=levels)

# parallel denoise (return updated token sequence) 
levels = infer_level(idx, model.vocab_sizes, model.level_mask_tokens[0])
denoise_mask = levels.bool() # toy denoise mask
denoise_mask[1, 0] = True 
denoise_mask[1, 1] = True 

from src import pad_abstract_tokens
# denoise | it conduct in-place update
denoise_idx = deepcopy(idx)
l = 1
denoise_idx = pad_abstract_tokens(denoise_idx, model, l, use_rhythmic_placeholders=True)
denoise_mask, denoise_levels = prep_denoise(denoise_idx, model)

denoise_idx = model.denoise(denoise_idx, denoise_mask, denoise_levels, temperature=0.0) # denoise return an updated idx

In [None]:



# sorl_search(data, model, sorl_config)
# from src.sorl import heuristic_rollout, chunk_denoise, infer_timestamp
# search_data, search_data_idx = heuristic_rollout(data, model, l=config.l, n=config.n-1, temperature=config.temperature, steps=config.steps, max_t_search=config.max_t_search, start_ts=config.start_ts, end_ts=config.end_ts, use_spike_placeholders=config.use_spike_placeholders, abstract_budget=config.abstract_budget, use_rhythmic_placeholders=config.use_rhythmic_placeholders)


In [2]:
# Record & Save an annotated dataset

# from dataset.base import BaseHierDataset
from dataset.arithmetic import ArithmeticHierDataset
from nil import annotate_abstraction
from nil import supervise_gat 

record_dataset = ArithmeticHierDataset.from_dataset(dataset)

# Greedy Abstraction Annotation (Passing knowledge to the next generation)
# ------------------------------------------------------------------------
record_dataset = annotate_abstraction(record_dataset, gat)


# Reset GAT module
# -------------------
gat = GAT(gat_config)


# Weak Supervision (GAT)
# ------------------------------------------------------------------------
weak_iterations = 100 # require tuning
context_length = 1024

supervised_gat = supervise_gat(record_dataset, gat, weak_iterations, context_length)


Iteration 1/100, loss: 4.9127, abs_loss: 2.0794, ssl_loss: 2.8332
Iteration 2/100, loss: 4.7362, abs_loss: 1.9370, ssl_loss: 2.7991
Iteration 3/100, loss: 4.4892, abs_loss: 1.7341, ssl_loss: 2.7552
Iteration 4/100, loss: 4.2146, abs_loss: 1.5094, ssl_loss: 2.7052
Iteration 5/100, loss: 3.9717, abs_loss: 1.3026, ssl_loss: 2.6691
Iteration 6/100, loss: 3.7408, abs_loss: 1.1191, ssl_loss: 2.6218
Iteration 7/100, loss: 3.5362, abs_loss: 0.9555, ssl_loss: 2.5807
Iteration 8/100, loss: 3.3784, abs_loss: 0.8085, ssl_loss: 2.5699
Iteration 9/100, loss: 3.2184, abs_loss: 0.6780, ssl_loss: 2.5404
Iteration 10/100, loss: 3.0805, abs_loss: 0.5644, ssl_loss: 2.5161
Iteration 11/100, loss: 2.9727, abs_loss: 0.4672, ssl_loss: 2.5055
Iteration 12/100, loss: 2.8609, abs_loss: 0.3852, ssl_loss: 2.4758
Iteration 13/100, loss: 2.7815, abs_loss: 0.3173, ssl_loss: 2.4642
Iteration 14/100, loss: 2.7241, abs_loss: 0.2617, ssl_loss: 2.4625
Iteration 15/100, loss: 2.6535, abs_loss: 0.2164, ssl_loss: 2.4371
Iter

Keep it beautifully simple

In [5]:
from src import sorl_search, SORLConfig
from dataset.arithmetic import ArithmeticDataset
from src import GATConfig, GAT

sorl_config = SORLConfig(
    n = 2,
    temperature = 0.75,   
    # rollout specific 
    causal_rollout=False, 
    
    l=1,
    steps=1,
    use_rhythmic_placeholders=True,
    use_spike_placeholders=True,
    abstract_budget=5,
    max_t_search=5,

    # dataset specific
    train_dataset_path="dataset/multiplication/100K-123.bin",
    val_dataset_path="dataset/multiplication/2K-123.bin",
    train_batch_size=24,
    val_batch_size=1,
    train_iterations=400,
    val_iterations=1,
    # optimization
    learning_rate=1e-3, 
    log_interval=10
)

train_dataset = ArithmeticDataset.from_file(sorl_config.train_dataset_path)
val_dataset = ArithmeticDataset.from_file(sorl_config.val_dataset_path)
traj_vocab_size = train_dataset.vocab_size_list[0]

gat_config = GATConfig(K=3, L=2, n_embd=128, n_head=4, n_layer=4, device="cpu", _compile=False,
                       vocab_size_list=[traj_vocab_size, 8], t_keep=sorl_config.max_length)

gat = GAT(gat_config)

In [6]:
from src.sorl import SearchScheduler, sorl_search, compute_loss, evaluate
import torch 
from dataset.base import get_batch

start_step = 0 
config = sorl_config

optimizer = torch.optim.Adam(gat.parameters(), lr=config.learning_rate)
scheduler = SearchScheduler(config, gat.K, curriculum_ratio=0.5)

for i in range(config.train_iterations):
    # config.temperature = 0.0 if i % 2 == 0 else 1.0
    global_step = start_step + i
    gat.train() 

    t_search = scheduler.step()
    config.max_t_search = t_search

    data = get_batch(train_dataset, config.train_batch_size, config.max_length, gat.level_mask_tokens[0], device=gat.device)

    with torch.no_grad(): 

        search_data, switch_ratio = sorl_search(data, gat, config)
        
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    ppt = gat(search_data[:, :-1].contiguous(), target=search_data[:, 1:].contiguous())

    ssl_loss, abs_loss = compute_loss(search_data, gat, ppt)
    loss = abs_loss + ssl_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Validation needs to be more rigorous : more samples
    gat.eval()

    with torch.no_grad(): 
        _, improve_ppl_train, _, vocab_utilization_rate = evaluate(data, gat, 5, config)

    print(f"Iteration {i+1}/{config.train_iterations} "
                    f"- loss: {loss.item():.4f}, abs_loss: {abs_loss.item():.4f}, ssl_loss: {ssl_loss.item():.4f}, search_ppl: {improve_ppl_train.item():.4f}, switch_ratio: {switch_ratio:.4f}, vocab_utilization_rate: {vocab_utilization_rate:.4f}, t_search: {t_search}")

Iteration 1/400 - loss: 3.1780, abs_loss: 0.0000, ssl_loss: 3.1780, search_ppl: 0.0002, switch_ratio: 0.0000, vocab_utilization_rate: 0.0000, t_search: 0
Iteration 2/400 - loss: 3.1360, abs_loss: 0.0000, ssl_loss: 3.1360, search_ppl: -0.0000, switch_ratio: 0.0000, vocab_utilization_rate: 0.0000, t_search: 0
Iteration 3/400 - loss: 3.0655, abs_loss: 0.0000, ssl_loss: 3.0655, search_ppl: -0.0002, switch_ratio: 0.0000, vocab_utilization_rate: 0.0000, t_search: 0
Iteration 4/400 - loss: 2.9848, abs_loss: 0.0000, ssl_loss: 2.9848, search_ppl: 0.0000, switch_ratio: 0.0000, vocab_utilization_rate: 0.0000, t_search: 0
Iteration 5/400 - loss: 2.8915, abs_loss: 0.0000, ssl_loss: 2.8915, search_ppl: -0.0001, switch_ratio: 0.0000, vocab_utilization_rate: 0.0000, t_search: 0
Iteration 6/400 - loss: 2.8178, abs_loss: 0.0000, ssl_loss: 2.8178, search_ppl: -0.0002, switch_ratio: 0.0000, vocab_utilization_rate: 0.0000, t_search: 0
Iteration 7/400 - loss: 2.7558, abs_loss: 0.0000, ssl_loss: 2.7558, sear

KeyboardInterrupt: 

In [None]:
# (A). no curriculum | temperature flip |  ppl_improve: 15. | abs loss ~ 0.0
# (B). curriculum | temperature flip | ppl_improve: 26. | abs loss ~ 0.0
# Conclusion: curriculum helps improve the search ability of the model. 

# (C). curriculum | no temperature flip, temperature=1.0 | ppl_improve: 0.15 | abs loss: 1.33
# Conclusion: temperature flip stabilizes abstraction (we just need more push towards convergence)

# (D). curriculum | disallow spike placeholders | ppl_improve: 9.6 | abs_loss: 0.55 | vocab_utilization_rate: 0.22
# (E). curriculum | allow spike placeholders | ppl_improve: 13.8 | abs_loss: 0.6 | vocab_utilization_rate: 0.22
# Conclusion: it's hard to tell the effect of spike-placeholders here, but vocab_utilization is an issue

# (F). curriculum | temperature=0.5 | ppl_improve: 30.6 | abs_loss: ~0.0 | vocab utilization 0.2 (flucturate a bit)
# (G). curriculum | temperature=0.75 | ppl_improve: 15.9 | abs_loss: 0.46 | vocab utilization 0.33 ~ 0.44 

# Q. what about the perplexity-placeholder? Does it help? 
# Q. how about memory fading? Can it work? 
# Q. can we measure 'vocabulary utilization rate'?

In [8]:

# compute abstract utilization rate
# search_data <=
si, ei = gat.vocab_sizes.cumsum(dim=0) # begin_idx, end_idx

vocab_utilization_rate = search_data[(search_data >= si) & (search_data < ei)].unique().size(0) / (ei - si).item()


vocab_utilization_rate = data[(data >= si) & (data < ei)].unique().size(0) / (ei - si).item()

vocab_utilization_rate

0.0