In [4]:
import numpy as np
import torch as t
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import pearsonr
from sklearn.decomposition import PCA

class Mess3Process:
    """Three-state stochastic process that emits the tokens 'A', 'B' and 'C'.

    The process holds three 3x3 matrices (``T_A``, ``T_B``, ``T_C``). Each one
    is row-normalized in the constructor, so every row can be used directly as
    a categorical distribution by ``generate_sequence``.
    """

    def __init__(self):
        self.T_A = t.tensor([[0.765, 0.1175, 0.1175],
                             [0.915, 0.0675, 0.0175],
                             [0.91, 0.0025, 0.0875]])
        self.T_B = t.tensor([[0.45, 0.45, 0.1],
                             [0.1, 0.8, 0.1],
                             [0.1, 0.1, 0.8]])
        self.T_C = t.tensor([[0.45, 0.1, 0.45],
                             [0.1, 0.45, 0.45],
                             [0.1, 0.1, 0.8]])
        self.tokens = ['A', 'B', 'C']
        self.num_states = 3
        # Row-normalize so each row sums to 1 and is a valid input for
        # t.multinomial even if the literals above are edited later.
        self.T_A = self.T_A / self.T_A.sum(dim=1, keepdim=True)
        self.T_B = self.T_B / self.T_B.sum(dim=1, keepdim=True)
        self.T_C = self.T_C / self.T_C.sum(dim=1, keepdim=True)

    def generate_sequence(self, length, generator=None):
        """Sample a hidden-state trajectory and the tokens emitted along it.

        Args:
            length: Number of time steps to generate. ``0`` yields an empty
                state tensor and an empty token list.
            generator: Optional ``torch.Generator`` used for every random
                draw. Pass a seeded generator to get reproducible sequences
                without touching the global torch RNG. ``None`` (the default)
                uses the global RNG.

        Returns:
            A ``(states, observations)`` tuple: ``states`` is a ``torch.long``
            tensor of shape ``(length,)`` holding the state occupied at each
            step, and ``observations`` is a list of ``length`` token strings
            taken from ``self.tokens``.
        """
        matrices = (self.T_A, self.T_B, self.T_C)
        states = t.zeros(length, dtype=t.long)
        observations = []
        # The initial state is drawn uniformly from the available states.
        current_state = t.randint(
            0, self.num_states, (1,), generator=generator
        ).item()
        for step in range(length):
            states[step] = current_state
            # One of the three matrices is picked uniformly at every step.
            matrix_idx = t.randint(
                0, len(matrices), (1,), generator=generator
            ).item()
            probs = matrices[matrix_idx][current_state]
            # NOTE(review): the token and the next state are two independent
            # draws from the same row, and the emitted token does not depend
            # on which matrix was picked. Confirm this is the intended
            # generative model.
            token_idx = t.multinomial(probs, 1, generator=generator).item()
            observations.append(self.tokens[token_idx])
            current_state = t.multinomial(probs, 1, generator=generator).item()
        return states, observations

In [7]:
import wandb
# Instantiate the process; its constructor takes no arguments and builds the
# row-normalized T_A / T_B / T_C matrices.
proc = Mess3Process()

In [None]:
# build a transformer
from transformer_lens import HookedTransformer, HookedTransformerConfig
from dataclasses import dataclass

# initialize model config

cfg = HookedTransformerConfig(
    # input side: vocabulary and context length
    # (d_vocab=3 presumably matches the three tokens 'A'/'B'/'C' -- confirm)
    d_vocab=3,
    n_ctx=10,
    # overall depth and residual width
    n_layers=4,
    d_model=64,
    # attention: one head of width 8
    n_heads=1,
    d_head=8,
    # MLP sub-layer
    d_mlp=256,
    act_fn="relu",
)

# dataset generation and model training config
@dataclass
class TrainConfig:
    learning_rate: float = 1e-2
    batch_size: int = 64
    n_epoch = 1_000_000
    weight_decay: float | None = None
    checkpoint_every: int = 100


tensor([0, 0, 0, 2, 0, 0, 1, 2, 0, 0, 0, 0, 1, 0, 1])
['A', 'A', 'A', 'A', 'C', 'A', 'B', 'C', 'A', 'A', 'A', 'B', 'A', 'B', 'B']


In [None]:
from tqdm import tqdm
def train_model(
    cfg: HookedTransformerConfig,
    tcfg: TrainConfig,
    device: str = "cuda" if t.cuda.is_available() else "cpu",
) -> HookedTransformer:
    """Build a HookedTransformer from ``cfg`` and run the training scaffold.

    Args:
        cfg: Architecture configuration used to construct the model.
        tcfg: Training hyperparameters. ``learning_rate``, ``weight_decay``
            and ``n_epoch`` are read here.
        device: Device the model is moved to. The default is chosen once,
            when the function is defined: "cuda" if available, else "cpu".

    Returns:
        The constructed model, moved to ``device``.

    Side effects:
        Starts a Weights & Biases run in the "toy-transformer-markov-chain"
        project and finishes it before returning, including when an exception
        is raised inside the loop.
    """
    # initialize model
    model = HookedTransformer(cfg)
    model.to(device)

    # Optimizer: only forward weight_decay when it was explicitly configured,
    # so Adam's own default applies otherwise.
    adam_kwargs = {"lr": tcfg.learning_rate}
    if tcfg.weight_decay is not None:
        adam_kwargs["weight_decay"] = tcfg.weight_decay
    optimizer = t.optim.Adam(model.parameters(), **adam_kwargs)

    run = wandb.init(
        project="toy-transformer-markov-chain",
        name=f"transformer-epochs-{tcfg.n_epoch}",
    )

    try:
        # training loop begins
        for epoch in tqdm(range(tcfg.n_epoch)):
            # TODO: the training step (batch sampling, forward pass, loss,
            # optimizer.step()) is not implemented yet.
            pass
    finally:
        # Close the wandb run so repeated calls do not leave runs open.
        run.finish()

    return model
    
    
    
    
