In [1]:
import random

import torch

In [2]:
from src.training.checkpoint import load_checkpoint

In [3]:
CHECKPOINT_PATH = "checkpoints/shakespeare.pth"

# load_checkpoint returns four values; only the second one (the model) is used below.
_, model, _, _ = load_checkpoint(checkpoint_path=CHECKPOINT_PATH)

In [4]:
from src.tokenizer.model import Tokenizer
from transformers import AutoTokenizer



In [5]:
# Wrap the pretrained GPT-2 tokenizer in the project's Tokenizer, which exposes
# the special-token ids used by the DataLoader below. Constructing it in one
# expression avoids rebinding `tokenizer` to two different types.
tokenizer = Tokenizer(AutoTokenizer.from_pretrained("gpt2"))

In [6]:
from src.datasets.shakespeare.dataset import Dataset as ShakespeareDataset
from src.datasets.dataloader import DataLoader

In [7]:
# Dataset location and maximum sequence length (both passed to the dataset constructor below).
dataset_path, max_seq_len = "data/shakespeare", 32

In [8]:
# Held-out split, read from the "test" subdirectory of dataset_path.
test_ds = ShakespeareDataset(
    f"{dataset_path}/test",
    max_seq_len,
)

In [9]:
# Batch and span-masking settings passed to the DataLoader below.
batch_size: int = 1
min_ratio: int = 2
max_ratio: int = 4
max_num_spans: int = 6
max_span_fill: float = 0.8
min_num_spans: int = 0
min_span_fill: float = 0
hard_fill: bool = True

# Seed the RNGs so that Restart & Run All reproduces the same masked batch and
# therefore the same printed outputs.
# NOTE(review): assumes the DataLoader draws its randomness from `random` and/or
# `torch`; if it uses numpy's global RNG, seed that as well.
SEED: int = 0
random.seed(SEED)
torch.manual_seed(SEED)

In [10]:
# Test-split loader. Special-token indices come from the wrapped tokenizer;
# span-masking settings come from the config cell above.
test_dl = DataLoader(
    ds=test_ds,
    batch_size=batch_size,
    # special-token indices
    bos_idx=tokenizer.bos_token,
    eos_idx=tokenizer.eos_token,
    fill_idx=tokenizer.mask_token,
    enc_span_idx=tokenizer.enc_span_token,
    target_span_idx=tokenizer.targ_span_token,
    # span-masking settings
    min_ratio=min_ratio,
    max_ratio=max_ratio,
    min_num_spans=min_num_spans,
    max_num_spans=max_num_spans,
    min_span_fill=min_span_fill,
    max_span_fill=max_span_fill,
    hard_fill=hard_fill,
)

In [11]:
def decode_single(transition, emission, lookahead=False, vertex_len=None):
    """Greedily decode one sequence from a DAG-style decoder output.

    Starting at vertex 0, repeatedly follow the highest-scoring outgoing edge
    and emit the highest-scoring token at every visited vertex.

    Args:
        transition: (vertex_count, vertex_count) tensor; ``transition[i, j]`` is
            the score for moving from vertex ``i`` to vertex ``j``.
        emission: (vertex_count, vocab_size) tensor of per-vertex token scores.
        lookahead: if True, add each destination vertex's best token score to
            the edge score before choosing an edge
            (``transition[i, j] + max_v emission[j, v]``).
            NOTE(review): this is only meaningful if both tensors are on the
            same scale (e.g. both log-probabilities); confirm against the model.
        vertex_len: optional number of valid (non-padding) vertices. When given,
            only the first ``vertex_len`` vertices are considered.

    Returns:
        1-D int64 tensor of token ids along the decoded path. It is empty if
        there are no vertices.

    Decoding stops at the first chosen edge that does not move strictly forward
    (a self-loop, a backward edge, or an edge to vertex 0) or that leaves the
    valid vertex range. This guarantees termination within ``vertex_count`` steps.
    """
    if vertex_len is not None:
        vertex_len = int(vertex_len)
        transition = transition[:vertex_len, :vertex_len]
        emission = emission[:vertex_len]

    vertex_count = emission.shape[0]
    if vertex_count == 0:
        return torch.tensor([], dtype=torch.long)

    if lookahead:
        # Best emission score at each destination vertex j, broadcast over rows
        # so that it is added to every entry of column j.
        best_emission, _ = emission.max(dim=1)
        transition = transition + best_emission.unsqueeze(0)

    # Convert to Python lists once, instead of calling .item() at every step.
    tokens = emission.argmax(dim=1).tolist()
    edges = transition.argmax(dim=1).tolist()

    current = 0
    output = [tokens[current]]
    while True:
        nxt = edges[current]
        # A path must move strictly forward. Anything else ends decoding; before
        # this check, a self-loop or backward edge made the loop run forever.
        if nxt <= current or nxt >= vertex_count:
            break
        output.append(tokens[nxt])
        current = nxt
    return torch.tensor(output)

In [143]:
# Draw one masked batch from the test loader and run a forward pass.
(batch,
 enc,
 targ,
 dec_pos,
 dec_v,
 target_lens,
 vertex_lens,
 target_span_indices,
 ratio) = test_dl.get_batch()

# Inference only: switch off train-time behaviour (e.g. dropout) and skip
# building the autograd graph.
# NOTE(review): assumes `model` is a torch.nn.Module. load_checkpoint may
# already return it in eval mode; calling eval() again is harmless.
model.eval()
with torch.no_grad():
    transition, emissions = model(
        enc_x=enc,
        dec_x_vocab=dec_v,
        dec_x_pos=dec_pos,
        vertex_lens=vertex_lens,
    )

In [149]:
# Greedy lookahead decode of the first (and only) sequence in the batch.
decoded = decode_single(
    transition=transition[0],
    emission=emissions[0],
    lookahead=True,
)

In [150]:
# Raw token ids of the source batch, plus the text of its first sequence.
batch_text = tokenizer.tokenizer.decode(batch[0])
print(f"batch raw: {batch}\ndecoded batch: {batch_text}")

batch raw: tensor([[  198, 43468,   415,    25,   198, 15597,   534,  3470,  8059,   284]])
decoded batch: 
Pedant:
Keep your hundred pounds to


In [151]:
# Encoder input ids (with span tokens), plus the text of the first sequence.
enc_text = tokenizer.tokenizer.decode(enc[0])
print(f"enc raw: {enc}\ndecoded enc: {enc_text}")

enc raw: tensor([[50256,   198, 43468, 50258,   198, 50258,  8059, 50258, 50256]])
decoded enc: <|endoftext|>
Ped
 pounds<|endoftext|>


In [152]:
# Target ids the decoder should reproduce, plus the text of the first sequence.
targ_text = tokenizer.tokenizer.decode(targ[0])
print(f"target raw: {targ}\ndecoded target: {targ_text}")

target raw: tensor([[50256, 50259,   415,    25, 50259, 15597,   534,  3470, 50259,   284,
         50256]])
decoded target: <|endoftext|>ant:Keep your hundred to<|endoftext|>


In [153]:
# Text produced by the greedy decode, for comparison with the target above.
decoded_text = tokenizer.tokenizer.decode(decoded)
print(f"decoded: {decoded_text}")

decoded: <|endoftext|>,Is of of<|endoftext|>
