In [None]:
from toy_model import *
from metrics import *
import wandb
import torch
import numpy as np

In [None]:
# Authenticate this session with Weights & Biases before any run is started.
# Kept as the last expression of the cell so the call's result is displayed.
wandb.login()

In [None]:
# Transition matrices of the process used to generate the training data.
# Both are 3x3 float arrays; every row of T0 + T1 sums to 1.
# NOTE(review): presumably T_list[k] holds the transitions that emit token k
# (two matrices, d_vocab=2) - confirm against MarkovData.
T0 = np.array([
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.0, 0.0, 0.5],
])

T1 = np.array([
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [0.5, 0.0, 0.0],
])

In [None]:
# Reference process against which the model's KL loss is measured.
# Its matrices repeat the training matrices T0 / T1 value for value.
T0_proc1 = np.array([
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.0, 0.0, 0.5],
])
T1_proc1 = np.array([
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [0.5, 0.0, 0.0],
])

# A second, different process - currently disabled by parking it in a string
# literal (the statement below has no effect at runtime).
# NOTE(review): each of these two matrices is row-stochastic on its own, so the
# rows of T0_proc2 + T1_proc2 sum to 2, whereas the rows of T0_proc1 + T1_proc1
# sum to 1. Confirm which normalisation MarkovData expects before re-enabling.
'''T0_proc2 = np.array([
    [0.3, 0.7, 0],
    [0, 0.2, 0.8],
    [0.1, 0.1, 0.8]
])
T1_proc2 = np.array([
    [0.2, 0.3, 0.5],
    [0.6, 0.4, 0],
    [0, 0.8, 0.2]
])'''
process1 = MarkovData(
    n_gen=50,
    gen_len=32,
    n_states=3,
    d_vocab=2,
    T_list=[T0_proc1, T1_proc1],
    seed=42,
)
# Disabled together with the matrices above; note it uses gen_len=30, not 32.
#process2 = MarkovData(n_gen=50, gen_len=30, n_states=3, d_vocab=2, T_list=[T0_proc2, T1_proc2], seed=43)


In [None]:
# Selects which extra metrics are tracked during training.
metrics_config = MetricsConfig(
    # KL tracking against the listed reference processes.
    # NOTE(review): only one process is listed here, so presumably only a
    # `markov_kl_proc0` series is logged - confirm against MetricsConfig.
    track_markov_kl=True,
    markov_processes=[process1],
    # n-gram tracking for orders 1 to 3.
    track_ngrams=True,
    ngram_orders=[1, 2, 3],
    track_previous_token=True,
    # In-context tracking; no explicit icl_data is supplied (left disabled).
    track_in_context=True,
    #icl_data=icl_data,
    icl_k1=5,
    icl_k2=32,
    # Prefix matching is the one metric switched off.
    track_prefix_matching=False,
)

In [None]:
# Training dataset drawn from the T0 / T1 process.
# Keyword names follow the MarkovData call that builds `process1`.
# NOTE(review): no seed is passed here, unlike process1 (seed=42) - confirm
# whether MarkovData's default makes this dataset reproducible across runs.
dataset = MarkovData(
    n_gen=10000,
    gen_len=32,
    n_states=3,
    d_vocab=2,
    T_list=[T0, T1],
)

In [None]:

# Train the model on `dataset`, logging to Weights & Biases.
# The W&B run is always closed: with a non-zero exit code if training raises
# or is interrupted (so the run is not left open or reported as successful),
# and normally otherwise.
try:
    model = train_model(
        dataset=dataset,
        n_layers=2,
        d_model=16,
        n_heads=2,
        attn_only=True,
        act_fn='silu',

        # Training
        n_epochs=300,
        batch_size=64,
        lr=0.1,

        # Logging
        wandb=True,
        wandb_project_name="ICL",
        save_dir="proc1/debug/",
        save_every=20,
        print_every=10,

        # Extra metrics, as selected in metrics_config
        metrics_config=metrics_config,
        metrics_log_interval=20,
    )
except BaseException:
    # BaseException so that a KeyboardInterrupt from stopping the cell is
    # covered too; the original error is re-raised unchanged.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

In [None]:
# Load a saved checkpoint and inspect its predictions.
# NOTE(review): this reads from proc1/seq_len_30/, whereas the training cell in
# this notebook saves under proc1/debug/ - confirm this is the intended
# checkpoint. It also rebinds `model`, replacing the freshly trained one.
model = load_model("proc1/seq_len_30/model300.pt", "proc1/seq_len_30/model_cfg.pt")

# Forward passes are for inspection only, so autograd is switched off: no graph
# is built and the printed tensors carry no grad_fn.
with torch.no_grad():
    logits = model(torch.tensor([[0, 1, 1, 0, 1, 0, 0, 1, 1, 0],
                                 [1, 0, 1, 1, 0, 1, 0, 0, 1, 1],
                                 [1, 0, 0, 1, 0, 0, 1, 0, 0, 1]], dtype=torch.int64))

# Output at the last index along dimension 1 for each of the three sequences,
# followed by the argmax over the final dimension.
print(logits[:, -1])
print(logits[:, -1].argmax(dim=-1))

# Sample and compare
# NOTE(review): a 32-token sample is requested while the checkpoint directory
# is named seq_len_30 - confirm the loaded model accepts this length.
sample, states = dataset.model.sample_sequence(max_new_tokens=32)
with torch.no_grad():
    preds = model(torch.tensor([sample], dtype=torch.int64)).argmax(dim=-1).flatten().tolist()

# The prediction made at position i is compared with the token at i + 1, hence
# the one-step offset between the two sequences.
for s, pred in zip(sample[1:], preds[:-1]):
    print(f"Actual: {s}, Predicted: {pred}")
