1. $0^{\text{th}}$-level policy model $\pi^{(0)}\left(s_{t} \mid s_{t^{(1)}}^{(1)} \circ a_{t}^{(1)}\right)$
2. $1^{\text{st}}$-level policy model $\pi^{(1)}(\cdot)$

In [None]:
import torch

from model import CondGPT
from model import GPTConfig
from model import sandwich_embedding as se

# Seed the global torch RNG so the random tensors created in this notebook are reproducible.
SEED = 0
torch.manual_seed(SEED)

config = GPTConfig(vocab_size=50304, n_layer=12, n_head=6, n_embd=768, K=3, L=2, device="cpu", _compile=False)
B, S, D = 3, 6, config.n_embd  # tensors below are shaped (B, S, D) at the token level
K = config.K  # abstraction ratio between adjacent levels
L = config.L

# Sandwich embedding ensemble (temporal predicted token embeddings ensemble).
# The higher level has S // K positions and the lower level has S * K positions.
token_embeddings = torch.randn(B, S, D)
high_level_embeddings = torch.randn(B, S // K, D)
low_level_embeddings = torch.randn(B, S * K, D)

# (I). Embedding Ensemble
# v1. pure additive ensemble across abstraction levels.
# Bind the result so it can be inspected: a non-final expression in a cell is neither shown nor kept.
ensembled_embeddings = se(low_level_embeddings, token_embeddings, high_level_embeddings, K)

# (II). Conditional GPT
# CondGPT is imported in the top import cell.
condgpt = CondGPT(config)
# Token ids must lie in [0, vocab_size); tie them to the config instead of repeating the literal.
idx = torch.randint(0, config.vocab_size, (B, S))
# NOTE(review): if CondGPT is a torch.nn.Module, prefer `condgpt(...)` so registered hooks run — confirm.
forward_out = condgpt.forward(idx, high_level_embeddings, low_level_embeddings)
# NOTE(review): element [1] of generate's return value is what gets displayed; its meaning is
# defined in model.CondGPT — confirm.
condgpt.generate(idx, high_level_embeddings, low_level_embeddings)[1]

tensor([0, 0, 0])

In [6]:
# Step-by-step replica of the sandwich-embedding ensemble. The result goes into a new tensor so
# that re-running this cell does not keep adding onto `token_embeddings` from the setup cell.
seq_len = token_embeddings.shape[1]
L2 = low_level_embeddings.shape[1]
# Token position i is grounded by low-level position i * K (slice 0, K, ..., (seq_len - 1) * K).
assert seq_len * K <= L2, f"Planning without grounding is not allowed: {seq_len * K} > {L2}"
ensembled_embeddings = token_embeddings + low_level_embeddings[:, 0:seq_len * K:K]

if high_level_embeddings is not None:
    L1 = high_level_embeddings.shape[1]
    assert L1 * K <= seq_len < (L1 + 1) * K, (
        "Execution without purpose or planning without grounding is not allowed: "
        f"{L1 * K} <= {seq_len} < {(L1 + 1) * K}"
    )
    # Each high-level embedding is repeated over the K consecutive token positions it covers.
    cond_embeddings = high_level_embeddings.repeat_interleave(K, dim=1)
    ensembled_embeddings[:, :L1 * K] += cond_embeddings

1. It is better to put a $\langle \text{begin} \rangle$ token at each abstraction level, so that forward propagation never runs on a level that has no tokens.

In [12]:
Lmax = 4  # maximal abstraction level (inclusive: levels 0..Lmax can emit)
K = 2  # abstraction ratio: level l+1 fires once every K steps of level l
# NOTE(review): `decorate_sequences` below treats Lmax as the number of levels (0..Lmax-1);
# confirm which convention is intended.


# Version 2. Focus on generating all tokens for the current time step t. Realistically, letting the
# model run forever is not a good idea; better to control at least how many steps it runs, or to
# have it wait for the real world.

def generate_level(l: int, curr: str, t: int, lmax=None, k=None) -> str:
    """Append the digit of every abstraction level that emits a token at time step ``t``.

    Level ``l`` always emits (when ``l <= lmax``). Level ``l + 1`` emits as well when ``t``
    is divisible by ``k``, in which case ``t`` is re-expressed in level-``l + 1`` units
    (``t // k``) and the check repeats, up to and including level ``lmax``.

    Args:
        l: level to start from.
        curr: prefix accumulated so far; the result extends it.
        t: time step, in units of level ``l``.
        lmax: highest level (inclusive). Defaults to the module-level ``Lmax``.
        k: abstraction ratio. Defaults to the module-level ``K``.

    Returns:
        ``curr`` followed by the digits of the emitting levels, e.g. ``"0123"`` for
        ``t=8`` with ``k=2`` and ``lmax=4``.
    """
    # Fall back to the notebook globals (read at call time) for backward compatibility.
    lmax = Lmax if lmax is None else lmax
    k = K if k is None else k
    if l <= lmax:
        curr += str(l)
        if t % k == 0:
            return generate_level(l + 1, curr, t // k, lmax, k)
    return curr
		
# NOTE: the original author flagged "bugs for t=2" without details — TODO investigate.
# Simulate one full cycle of the top level: level l first fires at t = K**l, so level Lmax
# first fires at t = K**Lmax (16 for K=2, Lmax=4).
total_str = ""
for t in range(1, K**Lmax + 1):
    t_str = generate_level(0, "", t)
    total_str += t_str
    print("Step t string: ", total_str)



Step t string:  0
Step t string:  001
Step t string:  0010
Step t string:  0010012
Step t string:  00100120
Step t string:  0010012001
Step t string:  00100120010
Step t string:  001001200100123
Step t string:  0010012001001230
Step t string:  001001200100123001
Step t string:  0010012001001230010
Step t string:  0010012001001230010012
Step t string:  00100120010012300100120
Step t string:  0010012001001230010012001
Step t string:  00100120010012300100120010
Step t string:  0010012001001230010012001001234


In [16]:
Lmax = 4  # number of abstraction levels here: sequences for levels 0..Lmax-1
K = 2
BOS_TOKEN_ID = 0


def decorate_sequences(idx: list, Lmax: int, bos_token_id: int = BOS_TOKEN_ID) -> list:
    """Normalize token ids into exactly one sequence per abstraction level.

    A flat list (or an empty list) is treated as the level-0 sequence, and every higher
    level 1..Lmax-1 is seeded with a lone BOS token, so no abstract level starts empty.
    A list of per-level sequences must contain exactly ``Lmax`` entries. Any entry that is
    empty or does not start with an ``int`` is replaced by ``[bos_token_id]``.

    Args:
        idx: flat list of token ids, or a list of ``Lmax`` per-level lists.
        Lmax: number of abstraction levels to produce.
        bos_token_id: token used to seed empty/invalid levels (defaults to ``BOS_TOKEN_ID``).

    Returns:
        A list of ``Lmax`` lists of token ids.

    Raises:
        AssertionError: if nested input does not have exactly ``Lmax`` levels.
    """
    if not idx or not isinstance(idx[0], list):
        # Flat (or empty) input is the level-0 sequence; pad every upper level with BOS.
        return [idx] + [[bos_token_id] for _ in range(1, Lmax)]
    assert len(idx) == Lmax, f"Missing sequence for {Lmax} abstraction levels, currently only got {len(idx)}."
    return [seq if (len(seq) > 0 and isinstance(seq[0], int)) else [bos_token_id] for seq in idx]


# Demo: a flat level-0 sequence is padded with BOS-only sequences for every higher level.
idx = [3, 4, 5]
decorate_sequences(idx, Lmax=Lmax)

[[3, 4, 5], [0], [0], [0]]

In [5]:
# Inspect the current abstraction ratio; its value depends on which cell last assigned K.
K

2

In [7]:
t = 1