### Setup


In [9]:
# !pip install torch transformers

In [10]:
# Define dataset

import torch

class ShakespeareDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        tokenizer,
        seq_len=256,
        max_samples=None,
        file_path="data/shakespeare.txt",
    ):
        self.tokenizer = tokenizer
        self.seq_len = seq_len

        # Read Shakespeare text
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()


        # Flatten to single sequence, tokeniz and reshape to (num_batches, seq_len) 
        tokens = self.tokenizer.encode(text)
        n_batches = len(tokens) // seq_len
        self.sequences = torch.tensor(tokens[:n_batches * seq_len], dtype=torch.long).reshape(n_batches, seq_len)
        if max_samples is not None:
            self.sequences = self.sequences[:max_samples]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        return {"input_ids": seq, "labels": seq}



### Train GPT2

In [12]:
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig


# Hyperparameters
batch_size = 32
seq_len = 32
learning_rate = 1e-4
num_epochs = 1
num_samples = 10

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model_llm = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("gpt2"))
dataset = ShakespeareDataset(tokenizer, seq_len=256, max_samples=num_samples)
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

optimizer = torch.optim.AdamW(model_llm.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", total=len(train_dataloader))
    for batch in pbar:
        output = model_llm(**batch)
        loss = output.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_postfix(loss=loss.item())

    # Validation
    model_llm.eval()
    eval_loss = 0
    with torch.no_grad():
        for batch in val_dataloader:
            output = model_llm(**batch)
            eval_loss += output.loss.item()
    eval_loss /= (len(val_dataloader) * batch_size)

    # Log
    sample_input_ids = torch.tensor(tokenizer.encode("Hello, how are you?"))
    sample_text = tokenizer.decode(model_llm.generate(sample_input_ids.reshape(1, -1), max_length=16)[0])
    print(f"[Epoch {epoch+1}] Val loss: {eval_loss:.2f} | Sample: {repr(sample_text)}")

        


Token indices sequence length is longer than the specified maximum sequence length for this model (338024 > 1024). Running this sequence through the model will result in indexing errors
Epoch 1: 0it [00:00, ?it/s]
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[Epoch 1] Val loss: 0.35 | Sample: 'Hello, how are you?claveclaveclaveclaveclave caps caps caps caps caps'


In [14]:
sample_input_ids = torch.tensor(tokenizer.encode("Hello, how are you?"))
sample_text = tokenizer.decode(model_llm.generate(sample_input_ids.reshape(1, -1), max_length=16)[0])

# Save checkpoint
torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model_llm.state_dict(),
        "val_loss": eval_loss,
    },
    f"checkpoints/gpt2-latest-dev.pth",
)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


### Train PC

In [52]:
from transformers import AutoTokenizer


# # 1. Shakespeare dataset
file_path = "../data/shakespeare/main.txt"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
seq_len = 8
batch_size = 512
max_samples = 2

dataset = ShakespeareDataset(tokenizer, seq_len=seq_len, max_samples=max_samples)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# visualize
for i in range(min(5, max_samples)):
    print(f"{repr(tokenizer.decode(dataset[i]['input_ids']))} | ids: {dataset[i]['input_ids']}")

Token indices sequence length is longer than the specified maximum sequence length for this model (338024 > 1024). Running this sequence through the model will result in indexing errors


'First Citizen:\nBefore we proceed any' | ids: tensor([ 5962, 22307,    25,   198,  8421,   356,  5120,   597])
' further, hear me speak.\n\n' | ids: tensor([2252,   11, 3285,  502, 2740,   13,  198,  198])


In [None]:
import torch
from tqdm import tqdm
from spflow.modules.rat import RatSPN
from spflow.modules.leaf import Categorical
from spflow.meta import Scope
from spflow import log_likelihood
from spflow.learn import train_gradient_descent

torch.manual_seed(0)

num_features = seq_len
K = len(tokenizer) 
batch_size = 256

scope = Scope(list(range(num_features)))

leaf_layer = Categorical(
    scope=scope,
    out_channels=4,
    num_repetitions=2,
    K=K,
)

model_spn = RatSPN(
    leaf_modules=[leaf_layer],
    n_root_nodes=1,
    n_region_nodes=8,
    num_repetitions=2,
    depth=3,
    outer_product=False,
)

train_gradient_descent(model_spn, dataloader, epochs=20, lr=0.1)

# optimizer = torch.optim.Adam(model_spn.parameters(), lr=1e-2)
# n_epochs = 2_000
# log_every = 1000
# for epoch in range(n_epochs):
#     for step, batch in tqdm(enumerate(dataloader), leave=False, total=len(dataloader)):
#         optimizer.zero_grad()
#         data = batch['input_ids']
#         ll = log_likelihood(model_spn, data)          # (B,)
#         loss = -ll.mean()                         # NLL
#         loss.backward()
#         optimizer.step()
#     if epoch % log_every == 0:
#         print(f"[Epoch {epoch}] Loss {loss.item():.2f}")


ImportError: cannot import name 'train_gradient_descent' from 'spflow' (/Users/marawangamal/Documents/github/llm-mpe/.venv/lib/python3.13/site-packages/spflow/__init__.py)

In [56]:
# Save checkpoint
torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model_spn.state_dict(),
        "eval_loss": eval_loss,
    },
    f"checkpoints/rat-latest-dev.pth",
)

In [None]:
# Sample use of PC
from spflow import sample, sample_with_evidence
# # MPE:
# sample(model_spn, 1, is_mpe=True)

# CON:
# evidence = data.clone().float()
# evidence[:2] = torch.nan
# sample_with_evidence(model_spn, evidence)

# evidence = X_tensor[0]
# evidence[:32] = torch.nan
# plt.imshow(evidence.reshape(8,8), cmap="gray")
# plt.show()
# evidence = evidence.unsqueeze(0)
# print(evidence.shape)
# samples = sample_with_evidence(rat, evidence)
# plt.imshow(samples.reshape(8,8), cmap="gray")
# plt.show()

RuntimeError: Size does not match at dimension 0 expected index [4, 1, 8, 1, 1] to be no larger than self [2, 1, 8, 1, 2] apart from dimension 4

In [58]:
batch = next(iter(dataloader))

In [59]:
import torch
from spflow import sample
from spflow.meta import SamplingContext

def spn_sample_with_evidence(model_spn: RatSPN, evidence: torch.Tensor) -> torch.Tensor:
    """
    evidence:
      - shape (seq_len,) or (B, seq_len)
      - entries that are NaN will be sampled
      - non-NaN entries are treated as clamped evidence

    Returns:
      samples: same shape, dtype long (token ids)
    """
    device = next(model_spn.parameters()).device

    # Make it (B, seq_len)
    if evidence.dim() == 1:
        evidence = evidence.unsqueeze(0)  # (1, seq_len)

    # For categorical leaves we still use float here so we can hold NaNs
    evidence = evidence.to(torch.float32).to(device)  # (B, seq_len)

    B, out_features = evidence.shape

    # mask == True -> sample here, mask == False -> keep evidence
    mask = torch.isnan(evidence)  # (B, seq_len)

    # If there is absolutely nothing to sample, just return the clamped evidence
    if not mask.any():
        return evidence.round().to(torch.long)

    # All channels 0 (youâ€™re using out_channels=4 but we pick one repetition/channel)
    channel_index = torch.zeros((B, out_features), dtype=torch.int64, device=device)

    sampling_ctx = SamplingContext(
        channel_index=channel_index,
        mask=mask,
    )

    samples_float = sample(model_spn, evidence, sampling_ctx=sampling_ctx)
    samples_long = samples_float.round().to(torch.long)

    return samples_long




model_spn.eval()
data = batch["input_ids"]          # (B, seq_len), dtype long

# Pick a single sequence
seq = data[0].clone()              # (seq_len,)

# Mask some positions, e.g. positions 10..19
seq = seq.to(torch.float32)
seq[5:] = float("nan")

# Sample from SPN conditional on the rest
samples = spn_sample_with_evidence(model_spn, seq)  # (1, seq_len)
samples = samples[0]  # drop batch dim

print("Original:", data[0])
print("Evidence with NaNs:", seq)
print("Sampled:", samples)


RuntimeError: Size does not match at dimension 1 expected index [1, 8, 8, 1] to be no larger than self [1, 1, 8, 1] apart from dimension 3

### Max decode

In [19]:
from transformers import LogitsProcessor

class HeuristicProcessor(LogitsProcessor):
    def __init__(self, model_pc):
        super().__init__()
        self.model_pc = model_pc

    def __call__(self, input_ids, scores):
        # does nothing
        return scores

In [None]:
import torch
from transformers import LogitsProcessorList


model_llm_state_dict = torch.load("checkpoints/gpt2-latest-dev.pth")["model_state_dict"]
model_llm.load_state_dict(model_llm_state_dict)

model_spn_state_dict = torch.load("checkpoints/rat-latest-dev.pth")["model_state_dict"]
model_spn.load_state_dict(model_spn_state_dict)

model_llm.eval()
model_spn.eval()

sample_input_ids = torch.tensor(tokenizer.encode("Hello, how are you?"))


# Generate text (greedy)
greedy_dict = model_llm.generate(sample_input_ids.reshape(1, -1), max_new_tokens=32, do_sample=False, output_scores=True, return_dict_in_generate=True)
y_hat_greedy = torch.cat(
    (sample_input_ids[1:].reshape(1, -1), greedy_dict['sequences'][:, :-1]),
    dim=-1
)
nll_greedy = torch.cat(greedy_dict['scores'], dim=0).gather(dim=1, index=y_hat_greedy).sum()


# Generate text (heuristic)
processors = LogitsProcessorList([
    HeuristicProcessor(model_spn)
])
heur_dict = model_llm.generate(
    sample_input_ids.reshape(1, -1), 
    max_new_tokens=32, 
    do_sample=False, 
    output_scores=True, 
    return_dict_in_generate=True,
    logits_processor=processors,
)
y_hat_heur = torch.cat(
    (sample_input_ids[1:].reshape(1, -1), greedy_dict['sequences'][:, :-1]),
    dim=-1
)
nll_heur = torch.cat(heur_dict['scores'], dim=0).gather(dim=1, index=y_hat_heur).sum()


print(f"NLL greedy: {nll_greedy.item():.2f}")
print(f"NLL heuristic: {nll_heur.item():.2f}")


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


NLL greedy: 29.62
NLL heuristic: 29.62
