### Setup


In [2]:
# !pip install torch transformers

In [3]:
# Define dataset

import torch

class ShakespeareDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        tokenizer,
        seq_len=256,
        max_samples=None,
        file_path="data/shakespeare.txt",
    ):
        self.tokenizer = tokenizer
        self.seq_len = seq_len

        # Read Shakespeare text
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()


        # Flatten to single sequence, tokeniz and reshape to (num_batches, seq_len) 
        tokens = self.tokenizer.encode(text)
        n_batches = len(tokens) // seq_len
        self.sequences = torch.tensor(tokens[:n_batches * seq_len], dtype=torch.long).reshape(n_batches, seq_len)
        if max_samples is not None:
            self.sequences = self.sequences[:max_samples]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        return {"input_ids": seq, "labels": seq}



### Train GPT2

In [119]:
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig


# Hyperparameters
batch_size = 32
seq_len = 32
learning_rate = 1e-4
num_epochs = 1000
num_samples = 32
log_freq = 100

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model_llm = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("gpt2"))
dataset = ShakespeareDataset(tokenizer, seq_len=256, max_samples=num_samples)
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

optimizer = torch.optim.AdamW(model_llm.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", total=len(train_dataloader), leave=False)
    for batch in pbar:
        output = model_llm(**batch)
        loss = output.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_postfix(loss=loss.item())

    # Validation
    model_llm.eval()
    eval_loss = 0
    with torch.no_grad():
        for batch in val_dataloader:
            output = model_llm(**batch)
            eval_loss += output.loss.item()
    eval_loss /= (len(val_dataloader) * batch_size)

    # Log
    if (epoch+1) % log_freq == 0:
        sample_input_ids = torch.tensor(tokenizer.encode("Hello, how are you?"))
        sample_text = tokenizer.decode(model_llm.generate(sample_input_ids.reshape(1, -1), max_length=16, do_sample=True)[0])
        print(f"[Epoch {epoch+1}] Val loss: {eval_loss:.2f} | Sample: {repr(sample_text)}")

        


Token indices sequence length is longer than the specified maximum sequence length for this model (338024 > 1024). Running this sequence through the model will result in indexing errors
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[Epoch 100] Val loss: 0.34 | Sample: 'Hello, how are you? Paper Paper PaperSharSharSharSharSharShar sports'


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[Epoch 200] Val loss: 0.34 | Sample: 'Hello, how are you? Paper Paper PaperSharSharSharSharSharShar sports'


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[Epoch 300] Val loss: 0.34 | Sample: 'Hello, how are you? Paper Paper PaperSharSharSharSharSharShar sports'


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[Epoch 400] Val loss: 0.34 | Sample: 'Hello, how are you? Paper Paper PaperSharSharSharSharSharShar sports'


                             

KeyboardInterrupt: 

In [19]:
sample_input_ids = torch.tensor(tokenizer.encode("Hello, how are you?"))
sample_text = tokenizer.decode(model_llm.generate(sample_input_ids.reshape(1, -1), max_length=16)[0])

# Save checkpoint
torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model_llm.state_dict(),
        "val_loss": eval_loss,
    },
    f"checkpoints/gpt2-latest-dev.pth",
)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


### Train PC

In [15]:
from transformers import AutoTokenizer

# 1. Shakespeare dataset
file_path = "../data/shakespeare/main.txt"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
seq_len = 8
batch_size = 512
max_samples = 2

dataset = ShakespeareDataset(tokenizer, seq_len=seq_len, max_samples=max_samples)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# visualize
for i in range(min(5, max_samples)):
    print(f"{repr(tokenizer.decode(dataset[i]['input_ids']))} | ids: {dataset[i]['input_ids']}")

Token indices sequence length is longer than the specified maximum sequence length for this model (338024 > 1024). Running this sequence through the model will result in indexing errors


'First Citizen:\nBefore we proceed any' | ids: tensor([ 5962, 22307,    25,   198,  8421,   356,  5120,   597])
' further, hear me speak.\n\n' | ids: tensor([2252,   11, 3285,  502, 2740,   13,  198,  198])


In [27]:
from tqdm import tqdm
from spflow.meta import Scope
from spflow.modules.rat import RatSPN
from spflow.modules.leaf import  Binomial, Categorical
from spflow import log_likelihood

depth = 3
n_region_nodes = 5
num_leaves = 5
num_repetitions = 2
n_root_nodes = 1
num_feature = seq_len
# n = torch.tensor(16) # total count for binomial distribution
n = torch.tensor(len(tokenizer))

scope = Scope(list(range(0, num_feature)))

# rat_leaf_layer = Binomial(scope=scope, n=n, out_channels=num_leaves, num_repetitions=num_repetitions)
rat_leaf_layer = Categorical(scope=scope, K=n, out_channels=num_leaves, num_repetitions=num_repetitions)
model_pc = RatSPN(
    leaf_modules=[rat_leaf_layer],
    n_root_nodes=n_root_nodes,
    n_region_nodes=n_region_nodes,
    num_repetitions=num_repetitions,
    depth=depth,
    outer_product=True,
    split_halves=True,
)

In [22]:
optimizer = torch.optim.Adam(model_pc.parameters(), lr=1e-2)
n_epochs = 1_000
log_every = 500
for epoch in range(n_epochs):
    for step, batch in tqdm(enumerate(dataloader), leave=False, total=len(dataloader)):
        optimizer.zero_grad()
        data = batch['input_ids']
        ll = log_likelihood(model_pc, data)      # (B,)
        loss = -ll.mean()                         # NLL
        loss.backward()
        optimizer.step()
    if epoch % log_every == 0:
        print(f"[Epoch {epoch}] Loss {loss.item():.2f}")

                                     

[Epoch 0] Loss 86.61


                                     

[Epoch 500] Loss 7.82


                                     

In [18]:
from spflow import sample_with_evidence

batch = next(iter(dataloader))
X_tensor = batch['input_ids'].to(torch.float32)

# Sample with evidence
evidence = X_tensor[0]
evidence[:3] = torch.nan
evidence = evidence.unsqueeze(0)
samples = sample_with_evidence(model_pc, evidence, is_mpe=True)

print(repr(tokenizer.decode(samples[0].to(torch.long))))


' further,: me speak.\n\n'


### Max decode

In [115]:
import torch
from transformers import LogitsProcessorList, LogitsProcessor


class HeuristicProcessor(LogitsProcessor):
    def __init__(self, model_pc, tokenizer, horizon=32, history=32, penalty=100):
        super().__init__()
        self.model_pc = model_pc
        self.tokenizer = tokenizer
        self.horizon = horizon
        self.history = history
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        # input_ids: (B, L)
        # scores: (B, V)
        B, V, H_fut = input_ids.shape[0], scores.shape[1], self.horizon
        H_prev = min(self.history, input_ids.shape[1])

        # put evidence on same device as input_ids, with float dtype
        device = input_ids.device
        evidence = torch.full(
            (B, H_prev + H_fut),
            float("nan"),
            device=device,
            dtype=torch.float32,
        )
        evidence[:, :H_prev] = input_ids[:, -H_prev:]

        x_max = sample_with_evidence(self.model_pc, evidence, is_mpe=True)
        x_max = x_max.to(device=device, dtype=torch.long)  # (B, H_prev+H_fut)
        xt = x_max[:, H_prev]  # (B,)

        # make dirac distribution; match dtype/device of scores
        xt_oh = torch.nn.functional.one_hot(xt, num_classes=V).to(
            device=scores.device,
            dtype=scores.dtype,
        )  # (B, V)

        return scores - self.penalty * (1 - xt_oh)


def compute_nll(x, model) -> torch.Tensor:
    # Shape x: (B, T)
    # Will compute conditional NLL: -logp(x1:T | x0)
    out = model(x)
    logprobs = torch.nn.functional.log_softmax(out.logits, dim=-1)  # (B, T, V)
    logprob = logprobs[:, :-1].gather(dim=-1, index=x[:, 1:].unsqueeze(-1)) # (B, T, 1)
    return -logprob.sum(dim=1).sum() # nll


# Load best models
# model_llm.load_state_dict(torch.load("checkpoints/gpt2-latest-dev.pth")["model_state_dict"])
# model_pc.load_state_dict(torch.load("checkpoints/rat-latest-dev.pth")["model_state_dict"])
model_llm.eval()
model_pc.eval()


# Sample x0
sample_input_ids = torch.tensor(tokenizer.encode("\n"))
prompt = sample_input_ids.reshape(1, -1)
prompt_len = prompt.shape[1]

# Generate text (greedy)
greedy_dict = model_llm.generate(
    prompt,
    max_new_tokens=32,
    output_scores=True,
    return_dict_in_generate=True,
    do_sample=False, 
    num_beams=1
)
nll_greedy = compute_nll(greedy_dict['sequences'], model_llm)
text_greedy = tokenizer.decode(greedy_dict['sequences'][0])

# Generate text (heuristic)
processors = LogitsProcessorList([
    HeuristicProcessor(model_pc, tokenizer)
])
heur_dict = model_llm.generate(
    prompt,
    max_new_tokens=32,
    do_sample=False,
    output_scores=True,
    return_dict_in_generate=True,
    logits_processor=processors,
)
nll_heur = compute_nll(heur_dict['sequences'], model_llm)
text_heur = tokenizer.decode(heur_dict['sequences'][0])


print(f"NLL greedy: {nll_greedy.item():.2f} | {repr(text_greedy)}")
print(f"NLL heuristic: {nll_heur.item():.2f} | {repr(text_heur)}")


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


NLL greedy: 261.87 | '\nUpdatedUpdated Pry Pry Pry Pry Pry Pry Pry Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv Shiv loudspe loudspe'
NLL heuristic: 348.21 | '\nAndre ReillyThankfully imply equalityizational fact pollingchanges loopholeido"},Object Dignernandez resonance Cycleı Clippersbrush shunê Pages security reunion Kitchen perpend meantime regulatesverted398!'
