#### Dynamic Causal Propagation

In [1]:
# Limitation: this counting sequence yields a single sample, i.e. a batch size of 1 only.
# That is fine for now, but a Fourier-series decomposition would clearly be more elegant.

# (I). Counting Sequence
# --------------------------------------------------------------------------------
from collections import defaultdict
import torch

# (I.1) Generate Counting Sequence
# ----------------------------------------------------
def generate_level(l: int, seq: list, t: int, L: int, K: int): 
    """Record one tick of the counting sequence, starting at level ``l``.

    Appends the marker ``str(level)`` to ``seq[level]`` and, while the
    current tick count is divisible by ``K``, carries over to the next
    level with the tick count integer-divided by ``K``. Stops as soon as
    the level index reaches ``L`` or the tick count is not a multiple of ``K``.

    Args:
        l: level at which to start writing.
        seq: per-level string container indexed by level; it is mutated in
            place. The annotation says ``list``, but ``generate_count_seq``
            passes a ``defaultdict(str)``.
        t: tick count as seen from level ``l``.
        L: number of levels; indices ``>= L`` are never written.
        K: carry period between consecutive levels.

    Returns:
        The same ``seq`` object that was passed in.
    """
    level, ticks = l, t
    while level < L:
        seq[level] += str(level)
        if ticks % K != 0:
            # No carry into the next level: this tick ends here.
            break
        level += 1
        ticks //= K
    return seq

def generate_count_seq(L: int, K: int, T: int): 
    """Build the per-level counting strings for ticks ``1..T``.

    Every tick is fed to ``generate_level`` starting at level 0, so level 0
    receives one marker per tick and higher levels receive a marker whenever
    the carry reaches them.

    Args:
        L: number of levels.
        K: carry period between consecutive levels.
        T: number of ticks to generate.

    Returns:
        A ``defaultdict(str)`` mapping level index -> marker string. Levels
        that were never reached have no key.
    """
    per_level = defaultdict(str)
    for elapsed in range(T):
        # Ticks are 1-based: the first tick is t = 1.
        per_level = generate_level(0, per_level, elapsed + 1, L, K)
    return per_level
# ----------------------------------------------------

# (I.2) Tokenizer (basic integer tokenizer)
# ----------------------------------------------------

class TinyTokenizer: 
    """Character-level tokenizer backed by a plain lookup table.

    Attributes:
        vocab: mapping from ``str(key)`` to token id, built from the
            ``vocab`` argument.
        vocab_size: number of entries in the ``vocab`` argument as passed in.
    """

    def __init__(self, vocab: dict):
        """Store ``vocab`` with every key converted to ``str``."""
        table = {}
        for symbol, token_id in vocab.items():
            table[str(symbol)] = token_id
        self.vocab = table
        self.vocab_size = len(vocab)

    def __call__(self, seq: str):
        """Return the token ids of ``seq``, one per character.

        Raises:
            KeyError: if a character of ``seq`` is not in the vocabulary.
        """
        token_ids = []
        for char in seq:
            token_ids.append(self.vocab[char])
        return token_ids

# ----------------------------------------------------


# (I.3) Tensor idx sequence preparation 
# ----------------------------------------------------

# L: number of levels, K: carry period between levels, T: number of ticks
# (all three are forwarded to generate_count_seq).
L, K, T = 3, 3, 1024

data = generate_count_seq(L, K, T)

# Vocabulary: every digit character maps to its own integer value.
tokenizer = TinyTokenizer({digit: int(digit) for digit in "0123456789"})

# One list of token ids per level, in the order the levels appear in `data`.
idx = list(map(tokenizer, data.values()))
# A single sample; the second tuple slot is left as None
# (presumably the timestamps -- confirm against BatchedHierSeq).
samples = [(idx, None)]

from model import GATConfig, GAT, BatchedHierSeq
from torch.optim import Adam 

# Seed PyTorch's global RNG before the model is built so that a
# "Restart Kernel -> Run All" starts every run from the same random state.
# NOTE(review): assumes GAT initialises its weights from torch's global RNG --
# confirm in model.py.
SEED = 0
torch.manual_seed(SEED)

config = GATConfig(K=K, L=L, n_embd=128, n_head=4, device="cpu", _compile=False)
gat = GAT(config)

# Pack the samples into a single batched structure for the model.
batch_data = BatchedHierSeq.from_hierarchical_data(samples, K=gat.K, L=gat.L)

# Training configuration
epochs = 20 
gat.train()


# Training Loop : learning just fine -- loss reduces quickly
# One optimizer step per epoch on the whole batch.
# ----------------------------------------------------
optimizer = Adam(gat.parameters(), lr=1e-3)

for epoch in range(epochs):
    # Batched forward propagation: calling the model returns the loss directly.
    loss = gat(batch_data)
    loss.backward()
    optimizer.step()
    # Clear gradients so they do not accumulate into the next epoch.
    optimizer.zero_grad()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")
# ----------------------------------------------------
# ----------------------------------------------------

Epoch 1/20, Loss: 4.158883571624756
Epoch 2/20, Loss: 3.9783146381378174
Epoch 3/20, Loss: 3.731306791305542
Epoch 4/20, Loss: 3.5046794414520264
Epoch 5/20, Loss: 3.2924158573150635
Epoch 6/20, Loss: 3.079880475997925
Epoch 7/20, Loss: 2.8665931224823
Epoch 8/20, Loss: 2.656156063079834
Epoch 9/20, Loss: 2.4512057304382324
Epoch 10/20, Loss: 2.2491233348846436
Epoch 11/20, Loss: 2.052178382873535
Epoch 12/20, Loss: 1.8624504804611206
Epoch 13/20, Loss: 1.6795622110366821
Epoch 14/20, Loss: 1.5043525695800781
Epoch 15/20, Loss: 1.3342372179031372
Epoch 16/20, Loss: 1.1713873147964478
Epoch 17/20, Loss: 1.0061345100402832
Epoch 18/20, Loss: 0.858919084072113
Epoch 19/20, Loss: 0.7540488243103027
Epoch 20/20, Loss: 0.6810130476951599


In [4]:
# Batched generation: extend a short counting prompt one token per call.
# The prompt is the first 10 ticks of the counting sequence, tokenized per level.
test_data = generate_count_seq(L, K, 10)
idx = list(map(tokenizer, test_data.values()))
test_samples = [(idx, None)]
test_batch_data = BatchedHierSeq.from_hierarchical_data(
    test_samples,
    K=gat.K,
    L=gat.L,
)


# NOTE(review): an earlier note here reported a bug where the new token's level
# was "always 1". The recorded output of this cell shows levels 0, 1 and 2, so
# that note looks stale -- confirm before relying on either statement.
# `tokens[-1]` / `levels[-1]` read the last entry of the flattened batch, which
# is the newly generated token only while the batch holds a single sample.
n_new_toks = 15
for _ in range(n_new_toks): 
    # Each call returns the batch extended by the generated token(s).
    test_batch_data = gat.generate(test_batch_data)
    print(f"Generated token idx {test_batch_data.tokens[-1].item()} at level {test_batch_data.levels[-1].item()}")


Generated token idx 0 at level 0
Generated token idx 0 at level 0
Generated token idx 1 at level 1
Generated token idx 0 at level 0
Generated token idx 0 at level 0
Generated token idx 0 at level 0
Generated token idx 1 at level 1
Generated token idx 0 at level 0
Generated token idx 0 at level 0
Generated token idx 0 at level 0
Generated token idx 1 at level 1
Generated token idx 2 at level 2
Generated token idx 0 at level 0
Generated token idx 0 at level 0
Generated token idx 0 at level 0


In [6]:
# Display the batch as left by the generation loop, which rebinds
# `test_batch_data` to the value returned by `gat.generate` on every iteration.
test_batch_data

BatchedHierSeq(tokens=tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,
        1, 2, 0, 0, 0]), levels=tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,
        1, 2, 0, 0, 0]), timestamps=tensor([ 1,  2,  3,  3,  4,  5,  6,  6,  7,  8,  9,  9,  9, 10, 11, 12, 12, 13,
        14, 15, 15, 16, 17, 18, 18, 18, 19, 20, 21]), sample_idx=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0]), batch_size=1, K=3, L=3)

In [1]:
from model import GATConfig, GAT, BatchedHierSeq


K = 3
L = 3
config = GATConfig(K=K, L=L, n_embd=128, n_head=4, device="cpu", _compile=False)
gat = GAT(config)

# Two samples, each a (token_sequences, timestamps) pair with one token list
# per level. Passing None as timestamps lets default timestamps be generated.
samples = [
    ([[10, 11, 12], [20, 21], [30]], None),          # sample 1
    ([[40, 41, 42, 43], [50, 51], [31]], None),      # sample 2
]

batch_data = BatchedHierSeq.from_hierarchical_data(samples, K=gat.K, L=gat.L)

# Batched forward propagation (the returned value is not used here)
gat(batch_data)

# Batched generation -- kept as the last expression so the notebook displays
# the returned BatchedHierSeq.
gat.generate(batch_data)

BatchedHierSeq(tokens=tensor([10, 11, 12, 20, 21, 30,  0, 40, 41, 42, 50, 43, 51, 31,  0]), levels=tensor([0, 0, 0, 1, 1, 2, 0, 0, 0, 0, 1, 0, 1, 2, 0]), timestamps=tensor([ 1,  2,  3,  3,  6,  9, 10,  1,  2,  3,  3,  4,  6,  9, 10]), sample_idx=tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), batch_size=2, K=3, L=3)

In [4]:
# Comparative studies are required between GAT and other hierarchical models, such as 'adaptive-chunk', 'byte-latent', 'hierarchical-reasoning', and 'JEPA', to name a few.