In the previous notebook we set up an HMM and its forward computation. This one extends that idea, with a few differences.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression

import torch.optim as optim
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

In [None]:
# Pick the compute device: the Apple-silicon GPU (MPS backend) when present, else the CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
if device.type == "mps":
    print("✅ Using Apple M1/M2 GPU via MPS")
else:
    print("⚠️ MPS not available, falling back to CPU")

# Quick look at unsqueeze: it inserts a size-1 axis at the requested position.
x = torch.tensor([1, 2, 3, 4])
c = x.unsqueeze(0)  # shape [1, 4]
print(c)
b = x.unsqueeze(1)  # shape [4, 1]
print(b)

# Hidden-state transitions: row = current state (Ha, Hb, Hc), column = next state.
# Rows sum to 1; every state strongly prefers to stay where it is (0.9).
transition_matrix = torch.tensor([
    [0.90, 0.05, 0.05],
    [0.05, 0.90, 0.05],
    [0.05, 0.05, 0.90],
])

# Emissions: row = hidden state, column = emitted symbol (A, B, C).
emission_matrix = torch.tensor([
    [0.90, 0.05, 0.05],
    [0.05, 0.90, 0.05],
    [0.05, 0.05, 0.90],
])

# Three-entry probability vector (sums to 1).
pi = torch.tensor([0.3, 0.4, 0.3])

# Starting belief (mixed state) over the three hidden states.
eta = torch.tensor([0.9, 0.05, 0.05])

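Now imagine that the hidden states and the emitted symbols are combined into a single matrix M. Design M so that it is applied once per transition: it maps the current mixed state (the belief over hidden states) to a linear combination of its entries, which, after normalization, is the updated mixed state. Because M already encodes the transition and emission probabilities, the update must also depend on the emitted symbol, so there is really one such matrix per symbol: at each step we pick whichever of the three MSP (mixed-state presentation) matrices corresponds to the observed symbol and use it to update the probabilities.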

In [None]:
def compute_msp_matrices(A: torch.Tensor, B: torch.Tensor):
    """
    Build one MSP (mixed-state presentation) matrix per observation symbol.

    Entry (i, j) of the matrix for symbol k is ``A[i, j] * B[j, k]``: move from
    hidden state i to hidden state j, then emit symbol k from state j.

    Args:
        A: HMM transition matrix, shape [N, N].
        B: HMM emission matrix, shape [N, M] (column k = symbol k).

    Returns:
        A list of M tensors, each of shape [N, N]; element k belongs to symbol k.
    """
    _, n_symbols = B.shape
    # Scaling column j of A by B[j, k] is a row-vector broadcast over A.
    T_list = [A * B[:, k].unsqueeze(0) for k in range(n_symbols)]
    print("T_list:", T_list)
    return T_list

def generate_store_token_belief(T_list, pi: torch.Tensor, eta, cycle, seed=None, *, sample_from_model=False):
    """Roll a belief state forward through a randomly drawn token sequence.

    At each of ``cycle`` steps a symbol index ``k`` is drawn. The belief is
    updated with that symbol's MSP matrix (``eta <- eta @ T_list[k]``, then
    renormalised). The symbol, the new belief and its 2-D simplex coordinates
    are then recorded.

    Args:
        T_list: list of [N, N] matrices, one per symbol (e.g. from
            ``compute_msp_matrices``). Its length sets the vocabulary size.
        pi: retained so existing calls keep working. It is no longer used:
            tokens are the drawn symbol indices themselves, not entries of ``pi``.
        eta: initial belief vector of length N. The caller's tensor is not modified.
        cycle: number of steps to generate.
        seed: if given, seeds the NumPy generator that draws the symbols, so
            runs are reproducible. Also seeds torch's global RNG, as before.
        sample_from_model: when False (default), symbols are drawn uniformly at
            random, so the token stream does not follow the HMM's own
            statistics. When True, each symbol is drawn from the predictive
            distribution ``p(k | history)``, proportional to ``(eta @ T_k).sum()``.

    Returns:
        token: [cycle] int64 tensor of symbol indices.
        belief: [cycle, N] tensor; row t is the belief after observing token t.
        store: list of ``cycle`` length-2 NumPy arrays (simplex coordinates).
    """
    if seed is not None:
        torch.manual_seed(seed)  # kept for callers relying on the global torch seed
    # The draws happen in NumPy, so NumPy is what must be seeded for reproducibility.
    rng = np.random.default_rng(seed)
    n_symbols = len(T_list)

    tokens, beliefs, store = [], [], []
    for _ in range(cycle):
        if sample_from_model:
            # Probability mass that survives the update with T_k == p(next symbol is k).
            weights = np.array([float((eta @ T_k).sum()) for T_k in T_list], dtype=np.float64)
            k = int(rng.choice(n_symbols, p=weights / weights.sum()))
        else:
            k = int(rng.integers(n_symbols))

        eta = eta @ T_list[k]
        eta = eta / eta.sum()

        tokens.append(k)  # the symbol index, usable directly as an embedding id
        beliefs.append(eta)
        store.append(get_cartesian_from_barycentric(eta))

    token = torch.tensor(tokens, dtype=torch.long)
    if beliefs:
        belief = torch.stack(beliefs)
    else:  # cycle == 0: torch.stack([]) would raise
        belief = torch.empty((0, eta.shape[-1]), dtype=eta.dtype)
    return token, belief, store

# helper function
def get_cartesian_from_barycentric(b):
    """Map a 3-entry barycentric vector to (x, y) inside the triangle
    with vertices (0, 0), (1, 0) and (0.5, sqrt(3)/2).

    ``b`` may be anything NumPy can convert (CPU torch tensors included);
    the result is a length-2 NumPy array.
    """
    t = np.transpose(np.array([[0, 0], [1, 0], [0.5, np.sqrt(3) / 2]]))  # [2, 3] vertex columns
    return t.dot(np.asarray(b))

# Sanity check: print the 2-D simplex coordinates of the initial belief `eta`.
print(get_cartesian_from_barycentric(eta))

def plot_beliefs_on_simplex(beliefs: torch.Tensor, title="Belief Trajectory"):
    """Scatter-plot 3-state belief vectors inside the probability simplex.

    Each row of ``beliefs`` is treated as barycentric coordinates with respect
    to the triangle (0, 0), (1, 0), (0.5, sqrt(3)/2). Points are coloured by
    row index so the ordering of the trajectory is visible.

    Args:
        beliefs: [T, 3] tensor. It may live on any device (e.g. MPS) or carry
            autograd history; it is detached and copied to the CPU for plotting.
        title: figure title.
    """
    assert beliefs.ndim == 2 and beliefs.shape[1] == 3, "Only works for 3-state HMM"

    # matplotlib works on host memory, and mixing a device tensor with the CPU
    # triangle below would fail at the matmul, so move the data over first.
    pts = beliefs.detach().cpu()
    if not pts.is_floating_point():
        pts = pts.float()

    triangle = torch.tensor(
        [[0.0, 0.0], [1.0, 0.0], [0.5, 3.0 ** 0.5 / 2]], dtype=pts.dtype
    )  # [3, 2] vertices
    xy = (pts @ triangle).numpy()  # [T, 2]
    outline = torch.cat([triangle, triangle[:1]]).numpy()  # closed polygon

    plt.figure(figsize=(6, 6))
    plt.plot(outline[:, 0], outline[:, 1], color="k", lw=1)
    plt.scatter(xy[:, 0], xy[:, 1], s=6, c=np.arange(len(xy)), cmap="viridis")
    plt.title(title)
    plt.axis("equal")
    plt.axis("off")
    plt.show()


In [None]:
# One MSP matrix per symbol, then a 10k-step token/belief stream from them.
t_list = compute_msp_matrices(transition_matrix, emission_matrix)
tokens, beliefs, store = generate_store_token_belief(
    t_list, pi=pi, eta=eta, cycle=10000
)
# Put the generated tensors on the selected compute device.
tokens, beliefs = tokens.to(device), beliefs.to(device)

In [None]:
class ContinuousHMMDataset(torch.utils.data.Dataset):
    """Sliding next-token windows over one long token/belief stream.

    Item ``idx`` is ``(x, y, b)`` with::

        x = tokens[idx : idx + seq_len]                             # model input
        y = tokens[idx + 1 : idx + seq_len + 1]                     # next-token targets
        b = beliefs[idx + belief_offset : idx + seq_len + belief_offset]

    ``beliefs[t]`` is assumed to be the belief after observing ``tokens[t]``.
    With the default ``belief_offset=0``, ``b[j]`` is the belief conditioned on
    exactly the tokens a causal model has seen at position ``j``.
    ``belief_offset=1`` reproduces the older alignment, which pairs position
    ``j`` with the belief after its target token ``y[j]``.
    """

    def __init__(self, tokens, beliefs, seq_len, belief_offset=0):
        if belief_offset not in (0, 1):
            raise ValueError("belief_offset must be 0 or 1")
        self.tokens = tokens
        self.beliefs = beliefs
        self.seq_len = seq_len
        self.belief_offset = belief_offset
        # Last valid start is len - seq_len - 1 (y needs one token past x),
        # so there are len - seq_len windows; never negative for short streams.
        self.length = max(0, len(tokens) - seq_len)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f"index out of range for dataset of length {self.length}")
        x = self.tokens[idx : idx + self.seq_len]           # [seq_len]
        y = self.tokens[idx + 1 : idx + self.seq_len + 1]   # [seq_len]
        start = idx + self.belief_offset
        b = self.beliefs[start : start + self.seq_len]      # [seq_len, N]
        return x, y, b

In [None]:
class SimpleTransformer(nn.Module):
    """Small transformer-encoder language model over a discrete vocabulary.

    Token embeddings plus a learned positional embedding go through a stack
    of ``nn.TransformerEncoderLayer`` blocks and are projected to vocabulary
    logits. With ``causal=True`` (default), each position attends only to
    itself and earlier positions. That is required for next-token training:
    without the mask, position ``t`` can read token ``t + 1``, which is its
    own target.
    """

    def __init__(self, vocab_size, d_model=64, n_heads=2, n_layers=2, seq_len=32, causal=True):
        super().__init__()
        self.seq_len = seq_len
        self.causal = causal
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.randn(1, seq_len, d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4*d_model,
            activation="relu",
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x, return_hidden=False):
        """
        x: [B, T] token indices, with T <= seq_len
        returns:
            logits: [B, T, vocab_size]
            hidden: [B, T, d_model] if return_hidden=True
        raises:
            ValueError: if T exceeds the positional-embedding length seq_len.
        """
        T = x.size(1)
        if T > self.seq_len:
            raise ValueError(f"sequence length {T} exceeds seq_len={self.seq_len}")
        tok = self.token_emb(x)                   # [B, T, D]
        h = tok + self.pos_emb[:, :T, :]          # add positional encoding

        mask = None
        if self.causal:
            # Boolean attention mask: True blocks a (query, key) pair, here every
            # key strictly in the query's future.
            mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        hidden = self.transformer(h, mask=mask)   # [B, T, D]
        logits = self.output(hidden)              # [B, T, vocab]

        if return_hidden:
            return logits, hidden
        return logits

In [None]:
def train_model(model, dataloader, epochs=10, lr=1e-3, device="cpu", optimizer=None):
    """Train ``model`` for next-token prediction with cross-entropy.

    Args:
        model: module mapping [B, T] token ids to [B, T, V] logits.
        dataloader: iterable of ``(x, y, extra)`` batches; ``extra`` is ignored.
        epochs: number of passes over ``dataloader``.
        lr: AdamW learning rate, used only when ``optimizer`` is not given.
        device: device the model and each batch are moved to.
        optimizer: optional pre-built optimizer over ``model``'s parameters.
            Defaults to ``torch.optim.AdamW(model.parameters(), lr=lr)``.

    Returns:
        List with the mean batch loss of each epoch (NaN for an epoch with no batches).
    """
    model.to(device)
    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()

    history = []
    for epoch in range(epochs):
        total_loss, n_batches = 0.0, 0
        for x, y, _ in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # reshape (not view) so non-contiguous logits/targets are handled too
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        # Mean, not sum, so the figure is comparable across dataset/batch sizes.
        mean_loss = total_loss / n_batches if n_batches else float("nan")
        history.append(mean_loss)
        print(f"Epoch {epoch}: Loss {mean_loss:.4f}")
    return history
        
def extract_residuals(model, dataloader, device="cpu"):
    """Collect the model's final hidden states and the aligned belief targets.

    Every batch ``(x, _, beliefs)`` is run through ``model(x, return_hidden=True)``
    without gradients. Hidden states and beliefs are flattened over batch and
    time so each row is one position. The model's train/eval mode is restored
    afterwards.

    Returns:
        reps: [num_positions, d_model] CPU tensor of hidden states.
        beliefs: [num_positions, N] CPU tensor of the matching belief vectors.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    all_hidden = []
    all_beliefs = []

    try:
        with torch.no_grad():
            for x, _, batch_beliefs in dataloader:
                _, hidden = model(x.to(device), return_hidden=True)
                all_hidden.append(hidden.cpu())
                all_beliefs.append(batch_beliefs.cpu())  # keep both outputs on the same device
    finally:
        model.train(was_training)

    if not all_hidden:
        raise ValueError("dataloader yielded no batches; nothing to extract")

    reps = torch.cat(all_hidden, dim=0).reshape(-1, all_hidden[0].shape[-1])
    beliefs = torch.cat(all_beliefs, dim=0).reshape(-1, all_beliefs[0].shape[-1])
    return reps, beliefs

In [None]:
# nn.Embedding needs integer (long) ids. Tensor.to() is not in-place, so the
# result must be assigned for the device move to take effect.
tokens = tokens.long().to(device)
model = SimpleTransformer(vocab_size=3, d_model=64, n_heads=2, n_layers=2, seq_len=32)
model.to(device)
ds = ContinuousHMMDataset(tokens, beliefs, seq_len=32)
dataloader = torch.utils.data.DataLoader(ds, batch_size=16, shuffle=False, drop_last=True)
# train_model constructs its own AdamW optimizer (lr defaults to 1e-3), so no
# optimizer is built here; pass lr=... to train_model to change the rate.
train_model(model, dataloader, epochs=4, device=device)

reps, belief_targets = extract_residuals(model, dataloader, device=device)

In [None]:
from sklearn.linear_model import LinearRegression

# Linear probe: regress the belief vectors on the transformer's final hidden
# states, then plot the probe's (in-sample) predictions on the simplex.
X = reps.detach().cpu().numpy()             # [num_positions, d_model]
Y = belief_targets.detach().cpu().numpy()   # [num_positions, N]

probe = LinearRegression().fit(X, Y)        # fit() returns the fitted estimator

Y_pred = probe.predict(X)
pred_tensor = torch.from_numpy(Y_pred).to(torch.float32)
plot_beliefs_on_simplex(pred_tensor, title="Transformer Residual → Belief Geometry")