In [None]:
from typing import List, Optional, Tuple
import numpy as np
from mealymarkov import MarkovMealyModel
import os
from dotenv import load_dotenv
# Load KEY=VALUE pairs from a local .env file into the process environment.
load_dotenv()

# Destination files for the generated training sequences; each is None when the
# variable is absent from both the .env file and the environment.
ozz_FILE_PATH = os.environ.get('100_SAVE_PATH')  # "100*" process sequences
zir_FILE_PATH = os.environ.get('ZIR_SAVE_PATH')  # ZIR process sequences

# Example small model (n=3 states, V=2 tokens) that satisfies the constraints.
n = 3
V = 2

# We construct T^0 and T^1 so that T^0 + T^1 is row-stochastic (rows sum to 1).
# NOTE(review): T_list[v] is assumed to drive emission of token v — confirm in MarkovMealyModel.
T0 = np.array([
    [0, 1, 0],
    [0, 0, 1],
    [0, 0, 0.5]
])

T1 = np.array([
    [0, 0, 0],
    [0, 0, 0],
    [0.5, 0, 0]
])

# Fail fast if the matrices ever stop satisfying the constraint stated above.
assert np.allclose((T0 + T1).sum(axis=1), 1.0), "T0 + T1 must be row-stochastic"

model = MarkovMealyModel(n=n, V=V, T_list=[T0, T1])

# By specification the default eta^0 is uniform
print("Initial eta^0 =", model.eta0)

# seed=42 presumably makes this demo sample identical on every run.
tokens, states = model.sample_sequence(max_new_tokens=3, seed=42)

print("Generated tokens:", tokens)
print("States (eta^t) traversed:")
for i, s in enumerate(states):
    print(f"t={i} ->", np.round(s, 4))

Generating training sequences

In [None]:
import numpy as np
import json
# Generate training sequences for the "100*" process discussed in the previous meeting.
n = 3
V = 2
num_training_samples = 100
seq_len = 50   # passed to sample_sequence as max_new_tokens
base_seed = 0  # sample i is drawn with seed base_seed + i so a re-run reproduces the dataset
sequences = {}

# os.getenv returns None for an unset variable; open(None) would only give an opaque TypeError.
if ozz_FILE_PATH is None:
    raise RuntimeError("100_SAVE_PATH is not set; add it to your .env file or environment.")

# We construct T^0 and T^1 so that T^0 + T^1 is row-stochastic (rows sum to 1).
T0 = np.array([
    [0, 1, 0],
    [0, 0, 1],
    [0, 0, 0.5]
])

T1 = np.array([
    [0, 0, 0],
    [0, 0, 0],
    [0.5, 0, 0]
])
assert np.allclose((T0 + T1).sum(axis=1), 1.0), "T0 + T1 must be row-stochastic"

model = MarkovMealyModel(n=n, V=V, T_list=[T0, T1])
for i in range(num_training_samples):
    tokens, _ = model.sample_sequence(max_new_tokens=seq_len, seed=base_seed + i)
    sequences[i] = tokens
# json.dump converts the integer keys to strings ("0", "1", ...).
with open(ozz_FILE_PATH, 'w') as fp:
    json.dump(sequences, fp, indent=4)

In [None]:

# Generate training sequences for the ZIR process discussed in the previous meeting.
n = 3
V = 2
num_training_samples = 100
seq_len = 50   # passed to sample_sequence as max_new_tokens
base_seed = 0  # sample i is drawn with seed base_seed + i so a re-run reproduces the dataset
sequences = {}

# os.getenv returns None for an unset variable; open(None) would only give an opaque TypeError.
if zir_FILE_PATH is None:
    raise RuntimeError("ZIR_SAVE_PATH is not set; add it to your .env file or environment.")

# We construct T^0 and T^1 so that T^0 + T^1 is row-stochastic (rows sum to 1).
T0 = np.array([
    [0, 1, 0],
    [0, 0, 0],
    [0.5, 0, 0]
])

T1 = np.array([
    [0, 0, 0],
    [0, 0, 1],
    [0.5, 0, 0]
])
assert np.allclose((T0 + T1).sum(axis=1), 1.0), "T0 + T1 must be row-stochastic"

model = MarkovMealyModel(n=n, V=V, T_list=[T0, T1])
for i in range(num_training_samples):
    tokens, _ = model.sample_sequence(max_new_tokens=seq_len, seed=base_seed + i)
    sequences[i] = tokens
# json.dump converts the integer keys to strings ("0", "1", ...).
with open(zir_FILE_PATH, 'w') as fp:
    json.dump(sequences, fp, indent=4)

In [None]:
import numpy as np
import torch
from toy_model import train_model, finetune_model, MarkovData

# Train a small attention-only transformer on data from the ZIR process (T^0 + T^1 row-stochastic).
T0 = np.array([
    [0, 1, 0],
    [0, 0, 0],
    [0.5, 0, 0]
])
T1 = np.array([
    [0, 0, 0],
    [0, 0, 1],
    [0.5, 0, 0]
])
assert np.allclose((T0 + T1).sum(axis=1), 1.0), "T0 + T1 must be row-stochastic"

# Seed torch's and NumPy's global RNGs so Restart & Run All reproduces the trained model.
# NOTE(review): np.random.seed only matters if MarkovData / train_model draw from NumPy's
# global RNG, and torch.manual_seed if they use torch's default generator — confirm.
TRAIN_SEED = 0
torch.manual_seed(TRAIN_SEED)
np.random.seed(TRAIN_SEED)

dataset = MarkovData(n_gen=1000, gen_len=50, n_states=3, d_vocab=2, T_list=[T0, T1])
model = train_model(
    dataset=dataset,
    n_layers=4,
    d_model=4,
    d_head=2,
    d_mlp=16,
    attn_only=True,
    n_epochs=500,
    lr=5e-3,
    batch_size=200,
    # Both intervals exceed n_epochs=500, so presumably no periodic saves/logs happen.
    save_every=1000,
    print_every=10000,
    save_dir=None  # don't save the model to disk
)

In [None]:
# Draw a sequence from the data-generating Markov model (dataset.model) and compare it
# with the transformer's greedy next-token predictions.
sample, states = dataset.model.sample_sequence(max_new_tokens=40)
with torch.no_grad():  # inference only: no autograd graph needed
    # unsqueeze(0) makes the batch dimension explicit: input shape (1, len(sample)).
    tokens_in = torch.tensor(sample, dtype=torch.int64).unsqueeze(0)
    preds = model(tokens_in).argmax(dim=-1).flatten().tolist()
# The prediction at position k targets token k+1, hence sample[1:] is paired with preds[:-1].
for s, pred in zip(sample[1:], preds[:-1]):
    print(f'Actual: {s}, Predicted: {pred}')

In [None]:
# Continue training the already-trained transformer for a few more epochs on the
# same dataset; extend the keyword arguments as needed.
model = finetune_model(
    model,
    dataset,
    n_epochs=5,
    save_dir=None,  # don't save the model to disk
)

In [None]:
# Three probe contexts consistent with the ZIR (0, 1, random) pattern; only the
# logits at the final position of each row are inspected.
probe_batch = torch.tensor(
    [
        [0, 1, 1, 0, 1, 0, 0, 1, 1, 0],
        [1, 0, 1, 1, 0, 1, 0, 0, 1, 1],
        [1, 0, 0, 1, 0, 0, 1, 0, 0, 1],
    ],
    dtype=torch.int64,
)
with torch.no_grad():
    model.eval()
    logits = model(probe_batch)
final_logits = logits[:, -1, :]
print(final_logits)
print(final_logits.argmax(dim=-1))
# Ground truth values: [1, 0, R]  (R: the random position, either token is valid)

Verifying whether a given sequence of next-token probability distributions matches what the Markov model would predict

In [None]:
# Test the verify_sequence function
print('\n' + '='*50)
print('Testing verify_sequence function:')
print('='*50)

# verify_sequence is a check on the Markov model. By this point `model` has been rebound
# to the trained transformer, so build the reference "100*" process explicitly here.
T0_100 = np.array([
    [0, 1, 0],
    [0, 0, 1],
    [0, 0, 0.5]
])
T1_100 = np.array([
    [0, 0, 0],
    [0, 0, 0],
    [0.5, 0, 0]
])
markov_100 = MarkovMealyModel(n=3, V=2, T_list=[T0_100, T1_100])

# Create a test sequence and probability distribution
test_sequence = ['0', '1', '0']  # String representations of token indices
# Assuming uniform eta^0 and T_list[v] driving token v, these are the exact next-token
# distributions P(token k | tokens before k) for the 100* process:
#   k=0: [5/6, 1/6]; k=1 (after '0', belief [0, 0.4, 0.6]): [0.7, 0.3];
#   k=2 (after '0','1', belief [1, 0, 0]): [1, 0]
test_probs = [
    [0.833, 0.166],
    [0.7, 0.3],
    [1, 0]
]

print(f"Test sequence: {test_sequence}")
print(f"Test probabilities: {test_probs}")

# Verify the sequence
is_converged, conv_pos = markov_100.verify_sequence(test_sequence, test_probs, tolerance=0.1)

print(f"\nVerification result:")
print(f"  Is converged: {is_converged}")
print(f"  Convergence position: {conv_pos}")
