### Setup


In [None]:
# !pip install torch transformers

In [12]:
# Define dataset

import torch

class ShakespeareDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        tokenizer,
        seq_len=256,
        max_samples=None,
        file_path="data/shakespeare.txt",
    ):
        self.tokenizer = tokenizer
        self.seq_len = seq_len

        # Read Shakespeare text
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()


        # Flatten to single sequence, tokeniz and reshape to (num_batches, seq_len) 
        tokens = self.tokenizer.encode(text)
        n_batches = len(tokens) // seq_len
        self.sequences = torch.tensor(tokens[:n_batches * seq_len], dtype=torch.long).reshape(n_batches, seq_len)
        if max_samples is not None:
            self.sequences = self.sequences[:max_samples]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        return {"input_ids": seq, "labels": seq}



### Train GPT2

In [17]:
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig


# Hyperparameters
batch_size = 32
seq_len = 32
learning_rate = 1e-4
num_epochs = 1
num_samples = 10

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model_llm = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("gpt2"))
dataset = ShakespeareDataset(tokenizer, seq_len=256, max_samples=num_samples)
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

optimizer = torch.optim.AdamW(model_llm.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", total=len(train_dataloader))
    for batch in pbar:
        output = model_llm(**batch)
        loss = output.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_postfix(loss=loss.item())

    # Validation
    model_llm.eval()
    eval_loss = 0
    with torch.no_grad():
        for batch in val_dataloader:
            output = model_llm(**batch)
            eval_loss += output.loss.item()
    eval_loss /= (len(val_dataloader) * batch_size)

    # Log
    sample_input_ids = torch.tensor(tokenizer.encode("Hello, how are you?"))
    sample_text = tokenizer.decode(model_llm.generate(sample_input_ids.reshape(1, -1), max_length=16)[0])
    print(f"[Epoch {epoch+1}] Train loss: {loss.item():.2f} | Val loss: {eval_loss:.2f} | Sample: {repr(sample_text)}")

        


Token indices sequence length is longer than the specified maximum sequence length for this model (338024 > 1024). Running this sequence through the model will result in indexing errors
Epoch 1: 0it [00:00, ?it/s]
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[Epoch 1] Train loss: 86.73 | Val loss: 0.34 | Sample: 'Hello, how are you?alkingalkingalkingalkingalkingalkingalkingalkingalkingalking'


In [18]:
sample_input_ids = torch.tensor(tokenizer.encode("Hello, how are you?"))
sample_text = tokenizer.decode(model_llm.generate(sample_input_ids.reshape(1, -1), max_length=16)[0])

# Save checkpoint
torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model_llm.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss.item(),
        "eval_loss": eval_loss,
        "sample_text": sample_text,
    },
    f"checkpoints/gpt2-latest-dev.pth",
)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


### Train PC

In [19]:
from transformers import AutoTokenizer


# # 1. Shakespeare dataset
file_path = "../data/shakespeare/main.txt"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
seq_len = 8
batch_size = 512
max_samples = 10

dataset = ShakespeareDataset(tokenizer, seq_len=seq_len, max_samples=max_samples)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# visualize
for i in range(min(2, max_samples)):
    print(repr(tokenizer.decode(dataset[i]['input_ids'])))

Token indices sequence length is longer than the specified maximum sequence length for this model (338024 > 1024). Running this sequence through the model will result in indexing errors


'First Citizen:\nBefore we proceed any'
' further, hear me speak.\n\n'


In [20]:
import torch
from tqdm import tqdm
from spflow.modules.rat import RatSPN
from spflow.modules.leaf import Categorical
from spflow.meta import Scope
from spflow import log_likelihood

torch.manual_seed(0)

num_features = seq_len
K = len(tokenizer) 
batch_size = 256

scope = Scope(list(range(num_features)))

leaf_layer = Categorical(
    scope=scope,
    out_channels=4,
    num_repetitions=2,
    K=K,
)

model_spn = RatSPN(
    leaf_modules=[leaf_layer],
    n_root_nodes=1,
    n_region_nodes=8,
    num_repetitions=2,
    depth=3,
    outer_product=False,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
n_epochs = 100
log_every = 10
for epoch in range(n_epochs):
    for step, batch in tqdm(enumerate(dataloader), leave=False, total=len(dataloader)):
        optimizer.zero_grad()
        data = batch['input_ids']
        ll = log_likelihood(model_spn, data)          # (B,)
        loss = -ll.mean()                         # NLL
        loss.backward()
        optimizer.step()
    if epoch % log_every == 0:
        print(f"[Epoch {epoch+1}] Loss {loss.item():.2f}")


                                     

[Epoch 1] Loss 86.73


                                     

[Epoch 11] Loss 86.73


                                     

[Epoch 21] Loss 86.73


                                     

[Epoch 31] Loss 86.73


                                     

[Epoch 41] Loss 86.73


                                     

[Epoch 51] Loss 86.73


                                     

[Epoch 61] Loss 86.73


                                     

[Epoch 71] Loss 86.73


                                     

[Epoch 81] Loss 86.73


                                     

[Epoch 91] Loss 86.73


                                     

In [21]:
# Save checkpoint
torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model_spn.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss.item(),
        "eval_loss": eval_loss,
        "sample_text": sample_text,
    },
    f"checkpoints/rat-latest-dev.pth",
)

### Max decode

In [76]:
from transformers import LogitsProcessorList


model_llm_state_dict = torch.load("checkpoints/gpt2-latest-dev.pth")["model_state_dict"]
model_llm.load_state_dict(model_llm_state_dict)

model_spn_state_dict = torch.load("checkpoints/rat-latest-dev.pth")["model_state_dict"]
model_spn.load_state_dict(model_spn_state_dict)

model_llm.eval()
model_spn.eval()

sample_input_ids = torch.tensor(tokenizer.encode("Hello, how are you?"))


# Generate text (greedy)
greedy_dict = model_llm.generate(sample_input_ids.reshape(1, -1), max_new_tokens=32, do_sample=False, output_scores=True, return_dict_in_generate=True)
y_hat_greedy = torch.cat(
    (sample_input_ids[1:].reshape(1, -1), greedy_dict['sequences'][:, :-1]),
    dim=-1
)
nll_greedy = torch.cat(greedy_dict['scores'], dim=0).gather(dim=1, index=y_hat_greedy).sum()


# Generate text (heuristic)
processors = LogitsProcessorList([
    # BanListProcessor(tokenizer, common_words, penalty=100.0)
])
heur_dict = model_llm.generate(
    sample_input_ids.reshape(1, -1), 
    max_new_tokens=32, 
    do_sample=False, 
    output_scores=True, 
    return_dict_in_generate=True,
    logits_processor=processors,
)
y_hat_heur = torch.cat(
    (sample_input_ids[1:].reshape(1, -1), greedy_dict['sequences'][:, :-1]),
    dim=-1
)
nll_heur = torch.cat(heur_dict['scores'], dim=0).gather(dim=1, index=y_hat_heur).sum()


print(f"NLL greedy: {nll_greedy.item():.2f}")
print(f"NLL heuristic: {nll_heur.item():.2f}")


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


NLL greedy: 66.43
NLL heuristic: 66.43
