### Data

In [1]:
import torch

class ShakespeareDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        tokenizer,
        seq_len=256,
        max_samples=None,
        file_path="../data/shakespeare/main.txt",
    ):
        self.tokenizer = tokenizer
        self.seq_len = seq_len

        # Read Shakespeare text
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()

        tokens = self.tokenizer.encode(text)
        n_batches = len(tokens) // seq_len
        self.sequences = torch.tensor(tokens[:n_batches * seq_len], dtype=torch.long).reshape(n_batches, seq_len)
        self.sequences = self.sequences[:max_samples]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        return {"input_ids": seq, "length": len(seq)}


### HMM

In [6]:
import torch

def log_domain_matmul(log_A, log_B):
	"""
	log_A : m x n
	log_B : n x p
	output : m x p matrix

	Normally, a matrix multiplication
	computes out_{i,j} = sum_k A_{i,k} x B_{k,j}

	A log domain matrix multiplication
	computes out_{i,j} = logsumexp_k log_A_{i,k} + log_B_{k,j}
	"""
	m = log_A.shape[0]
	n = log_A.shape[1]
	p = log_B.shape[1]

	# log_A_expanded = torch.stack([log_A] * p, dim=2)
	# log_B_expanded = torch.stack([log_B] * m, dim=0)
    # fix for PyTorch > 1.5 by egaznep on Github:
	log_A_expanded = torch.reshape(log_A, (m,n,1))
	log_B_expanded = torch.reshape(log_B, (1,n,p))

	elementwise_sum = log_A_expanded + log_B_expanded
	out = torch.logsumexp(elementwise_sum, dim=1)

	return out

class TransitionModel(torch.nn.Module):
  def __init__(self, N):
    super(TransitionModel, self).__init__()
    self.N = N
    self.unnormalized_transition_matrix = torch.nn.Parameter(torch.randn(N,N))

  def forward(self, log_alpha):
    """
    log_alpha : Tensor of shape (batch size, N)
    Multiply previous timestep's alphas by transition matrix (in log domain)
    """
    log_transition_matrix = torch.nn.functional.log_softmax(self.unnormalized_transition_matrix, dim=0)

    # Matrix multiplication in the log domain
    out = log_domain_matmul(log_transition_matrix, log_alpha.transpose(0,1)).transpose(0,1)
    return out

class EmissionModel(torch.nn.Module):
  def __init__(self, N, M):
    super(EmissionModel, self).__init__()
    self.N = N
    self.M = M
    self.unnormalized_emission_matrix = torch.nn.Parameter(torch.randn(N,M))

  def forward(self, x_t):
    log_emission_matrix = torch.nn.functional.log_softmax(self.unnormalized_emission_matrix, dim=1)
    out = log_emission_matrix[:, x_t].transpose(0,1)
    return out

class HMM(torch.nn.Module):
  """
  Hidden Markov Model with discrete observations.
  """
  def __init__(self, M, N):
    super(HMM, self).__init__()
    self.M = M # number of possible observations
    self.N = N # number of states

    # A
    self.transition_model = TransitionModel(self.N)

    # b(x_t)
    self.emission_model = EmissionModel(self.N,self.M)

    # pi
    self.unnormalized_state_priors = torch.nn.Parameter(torch.randn(self.N))

    # use the GPU
    self.is_cuda = torch.cuda.is_available()
    if self.is_cuda: self.cuda()

  def sample(self, T=32):
    state_priors = torch.nn.functional.softmax(self.unnormalized_state_priors, dim=0)
    transition_matrix = torch.nn.functional.softmax(self.transition_model.unnormalized_transition_matrix, dim=0)
    emission_matrix = torch.nn.functional.softmax(self.emission_model.unnormalized_emission_matrix, dim=1)

    # sample initial state
    z_t = torch.distributions.categorical.Categorical(state_priors).sample().item()
    z = []; x = []
    z.append(z_t)
    for t in range(0,T):
      # sample emission
      x_t = torch.distributions.categorical.Categorical(emission_matrix[z_t]).sample().item()
      x.append(x_t)

      # sample transition
      z_t = torch.distributions.categorical.Categorical(transition_matrix[:,z_t]).sample().item()
      if t < T-1: z.append(z_t)

    return x, z

  def forward(self, x, T):
    """
    x : IntTensor of shape (batch size, T_max)
    T : IntTensor of shape (batch size)

    Compute log p(x) for each example in the batch.
    T = length of each example
    """
    if self.is_cuda:
      x = x.cuda()
      T = T.cuda()

    batch_size = x.shape[0]; T_max = x.shape[1]
    log_state_priors = torch.nn.functional.log_softmax(self.unnormalized_state_priors, dim=0)
    log_alpha = torch.zeros(batch_size, T_max, self.N)
    if self.is_cuda: log_alpha = log_alpha.cuda()

    log_alpha[:, 0, :] = self.emission_model(x[:,0]) + log_state_priors
    for t in range(1, T_max):
      log_alpha[:, t, :] = self.emission_model(x[:,t]) + self.transition_model(log_alpha[:, t-1, :])

    # Select the sum for the final timestep (each x may have different length).
    log_sums = log_alpha.logsumexp(dim=2)
    log_probs = torch.gather(log_sums, 1, T.view(-1,1) - 1)
    return log_probs

In [7]:
# Test loss computation 
model = HMM(M=32, N=32)
x, T = torch.randint(0, 32, (10, 10)), torch.randint(1, 11, (10,))
loss = -model(x,T).sum()
print(loss)

tensor(218.0242, grad_fn=<NegBackward0>)


In [None]:
# Train loop
import torch
from transformers import AutoTokenizer

# Hyperparameters
lr = 1e-2
batch_size = 32
seq_len_train = 32
seq_len_test = 32
n_hidden = 32

# Data
tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = ShakespeareDataset(tokenizer, seq_len=seq_len_train)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model
model = HMM(M=len(tokenizer), N=n_hidden)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.00001)

# Train loop
for epoch in range(1):
    for idx, batch in enumerate(dataloader):
        optimizer.zero_grad()
        x,T = batch['input_ids'], batch['length']
        logp = model(x,T)
        loss = -logp.mean()
        loss.backward()
        optimizer.step()

        if idx % 10 == 0:
            sample = tokenizer.decode(model.sample(seq_len_test)[0])
            print(f"[Epoch {epoch}][{idx}/{len(dataloader)}] Loss: {loss.item():.2f} | {repr(sample)}")


[Epoch 0][0/331] Loss: 202.23 | "Palest himself thouear entAnd shall\nMFDlow upper;,TH stoppCall himself soon allAh behalf ' your,morrow of, enter findFor"
[Epoch 0][10/331] Loss: 199.26 | " scept this'llLord d her marry Lord endless FTP supposed words good mayuling:itions challenged is fromay\n thee flesh,y are fatherI's,\n"
[Epoch 0][20/331] Loss: 201.00 | ' me onK\n ourBut rememberED\nEOardon\n other rather\n, let thee we if\nESIyour thatAR this buthim vile golden thy'
[Epoch 0][30/331] Loss: 200.21 | ",They\n a give credit Ipired did: mad out intis nor waterForAn\n thunder youmen paTo served to wife,\n\n,'s"
[Epoch 0][40/331] Loss: 203.69 | ', you onUSess trees me:: me your sight of which\n windows this\n manners, her into I patience\n, them\n use about\nTrue'
[Epoch 0][50/331] Loss: 206.59 | " Cl\n exhib his the is alled:And IIVERS fleet\n bullU\nESS ofare forbID InterMy\n and you on ':ck,"
[Epoch 0][60/331] Loss: 209.41 | "\n,, blood fal them a V your and my, bothBR IHAM lord hat

### PC

In [None]:
!pip install git+https://github.com/SPFlow/SPFlow.git
# !pip install --upgrade pip

In [152]:
from transformers import AutoTokenizer


# # 1. Shakespeare dataset
file_path = "../data/shakespeare/main.txt"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
seq_len = 8
batch_size = 512
max_samples = 10

dataset = ShakespeareDataset(tokenizer, seq_len=seq_len, max_samples=max_samples)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# visualize
for i in range(min(2, max_samples)):
    print(repr(tokenizer.decode(dataset[i]['input_ids'])))


# 2. Simple data
# data = torch.zeros(batch_size, num_features, dtype=torch.long)
# data[:, 1] = 1

Token indices sequence length is longer than the specified maximum sequence length for this model (338024 > 1024). Running this sequence through the model will result in indexing errors


'First Citizen:\nBefore we proceed any'
' further, hear me speak.\n\n'


In [162]:
import torch
from tqdm import tqdm
from spflow.modules.rat import RatSPN
from spflow.modules.leaf import Categorical
from spflow.meta import Scope
from spflow import log_likelihood

torch.manual_seed(0)

num_features = seq_len
K = len(tokenizer) 
batch_size = 256

scope = Scope(list(range(num_features)))

leaf_layer = Categorical(
    scope=scope,
    out_channels=4,
    num_repetitions=2,
    K=K,
)

model = RatSPN(
    leaf_modules=[leaf_layer],
    n_root_nodes=1,
    n_region_nodes=8,
    num_repetitions=2,
    depth=3,
    outer_product=False,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
n_epochs = 1000
log_every = 200
for epoch in range(n_epochs):
    for step, batch in tqdm(enumerate(dataloader), leave=False, total=len(dataloader)):
        optimizer.zero_grad()
        data = batch['input_ids']
        ll = log_likelihood(model, data)          # (B,)
        loss = -ll.mean()                         # NLL
        loss.backward()
        optimizer.step()
    if epoch % log_every == 0:
        print(f"[{epoch:4}] [{step:4}] Loss {loss.item():.2f}")


                                             

[   0] [   0] Loss 16.83


                                     

[ 200] [   0] Loss 16.71


                                     

[ 400] [   0] Loss 16.65


                                     

[ 600] [   0] Loss 16.60


                                     

[ 800] [   0] Loss 16.57


                                     

In [161]:
samples = sample(model, 10).to(torch.long)
tokenizer.batch_decode(samples)

[': resolved hear isAll, to Citizen',
 '\n.First to speak, to.',
 ' people resolved\n to you: to\n',
 ' people\n\nSpe: thanFirst.',
 ' all\n hear to chief\n\n are',
 'First. resolved Talk speak thanFirst any',
 ' people. hear\nBefore\n to the',
 'border Citizen\n\n: thanYouRes',
 '\n\n rather toBefore Sv CCT',
 'olved resolvedFirst toAll\n to\n']

In [10]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
)
import torch

# Load GPT-2 model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

# ---------------------------------------------------------
# 1. Normal generation
# ---------------------------------------------------------

prompt = "In a future city, artificial intelligence"

inputs = tokenizer(prompt, return_tensors="pt")

print("\n=== Normal generation ===\n")
normal_output = model.generate(
    **inputs,
    max_length=60,
    do_sample=True,
    top_p=0.9,
    temperature=1.0,
)
print(tokenizer.decode(normal_output[0], skip_special_tokens=True))


# ---------------------------------------------------------
# 2. Logits processor: Ban a list of words
# ---------------------------------------------------------

class BanListProcessor(LogitsProcessor):
    def __init__(self, tokenizer, words, penalty=100.0):
        super().__init__()
        self.tokenizer = tokenizer
        self.penalty = penalty

        vocab = tokenizer.get_vocab()          # token_str -> id
        self.word2ids = {}

        # For each word, find all tokens that decode to that word (ignoring whitespace)
        for w in words:
            ids_for_word = []
            for tok_str, tid in vocab.items():
                decoded = tokenizer.decode([tid])
                if decoded.strip() == w:
                    ids_for_word.append(tid)
            if ids_for_word:
                self.word2ids[w] = ids_for_word

        # Flatten list of all ids for fast application
        self.all_ids = [tid for ids in self.word2ids.values() for tid in ids]

        # Print mapping: word -> ids -> decoded tokens
        print("\nBanned words and their token forms:")
        for w, ids in self.word2ids.items():
            decoded_tokens = [repr(tokenizer.decode([tid])) for tid in ids]
            print(f"  {w:>5}: ids={ids}, tokens={decoded_tokens}")
        print("All banned token IDs:", self.all_ids)

    def __call__(self, input_ids, scores):
        # Subtract large penalty → exp(logit - penalty) ≈ zero
        for tid in self.all_ids:
            scores[:, tid] -= self.penalty
        return scores


# 10 most common English words
common_words = ["the", "of", "and", "to", "a", "in", "is", "you", "that", "it", "than", "that's"]

processors = LogitsProcessorList([
    BanListProcessor(tokenizer, common_words, penalty=100.0)
])


# ---------------------------------------------------------
# 3. Generation with banned common words
# ---------------------------------------------------------

print("\n=== With LogitsProcessor (banning 10 common words) ===\n")

banned_output = model.generate(
    **inputs,
    max_length=60,
    do_sample=True,
    top_p=0.9,
    temperature=1.0,
    logits_processor=processors,
)

print(tokenizer.decode(banned_output[0], skip_special_tokens=True))


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



=== Normal generation ===

In a future city, artificial intelligence can create useful information about where it wants to go, and it may be able to help police in a given way.

For example, the police could use AI to predict where people are headed at a particular time, or where they are going. The police


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Banned words and their token forms:
    the: ids=[1169, 262], tokens=["'the'", "' the'"]
     of: ids=[286, 1659], tokens=["' of'", "'of'"]
    and: ids=[392, 290], tokens=["'and'", "' and'"]
     to: ids=[284, 1462], tokens=["' to'", "'to'"]
      a: ids=[64, 257], tokens=["'a'", "' a'"]
     in: ids=[287, 259], tokens=["' in'", "'in'"]
     is: ids=[271, 318], tokens=["'is'", "' is'"]
    you: ids=[345, 5832], tokens=["' you'", "'you'"]
   that: ids=[326, 5562], tokens=["' that'", "'that'"]
     it: ids=[270, 340], tokens=["'it'", "' it'"]
   than: ids=[14813, 621], tokens=["'than'", "' than'"]
All banned token IDs: [1169, 262, 286, 1659, 392, 290, 284, 1462, 64, 257, 287, 259, 271, 318, 345, 5832, 326, 5562, 270, 340, 14813, 621]

=== With LogitsProcessor (banning 10 common words) ===

In a future city, artificial intelligence will probably only be used for simple tasks, such as writing or drawing. For example, imagine your car sits on an empty lot, only there are vehicles parked r