In [1]:
import sys
import os

# Make the project root (the parent of the current working directory) importable
# so that the `src.*` imports in the next cell resolve. The membership check keeps
# the cell idempotent: re-running it does not append duplicate sys.path entries.
parent_dir = os.path.abspath(os.path.dirname(os.getcwd()))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [2]:
from src.datasets.wikipedia.dataset import Dataset as WikipediaDataset
from transformers import AutoTokenizer
from src.tokenizer.model import Tokenizer
#from torch.utils.data import DataLoader
from src.datasets.dataloader import DataLoader
from src.training.checkpoint import load_checkpoint
import torch
from torch import Tensor

In [3]:
# One example per batch.
batch_size = 1

# Ratio bounds forwarded to the DataLoader; min == max pins the value.
min_ratio: int = 3
max_ratio: int = 3

# Span-count and span-fill bounds forwarded to the DataLoader; each min/max
# pair is equal, so every sample uses the same setting.
min_num_spans: int = 1
max_num_spans: int = 1
min_span_fill: float = 0.15
max_span_fill: float = 0.15

# Forwarded to the DataLoader as `hard_fill`.
hard_fill = True

In [4]:
# Maximum sequence length handed to the Wikipedia dataset.
max_seq_len = 96
# Root folder of the Wikipedia splits, relative to the current working directory.
dataset_path = "../data/wikipedia"

In [5]:
# Build the dataset from the "test" sub-folder of dataset_path.
test_split_path = f"{dataset_path}/test"
test_ds = WikipediaDataset(test_split_path, max_seq_len)

In [6]:
# Wrap the pretrained GPT-2 tokenizer in the project's Tokenizer, which provides
# the special-token attributes passed to the DataLoader.
tokenizer = Tokenizer(AutoTokenizer.from_pretrained("gpt2"))

In [7]:
test_dl = DataLoader(
    ds=test_ds,
    batch_size=batch_size,
    # Special-token values supplied by the project Tokenizer wrapper.
    bos_idx=tokenizer.bos_token,
    eos_idx=tokenizer.eos_token,
    fill_idx=tokenizer.mask_token,
    enc_span_idx=tokenizer.enc_span_token,
    target_span_idx=tokenizer.targ_span_token,
    # Ratio bounds.
    min_ratio=min_ratio,
    max_ratio=max_ratio,
    # Span-sampling bounds (min/max pairs) and fill mode.
    min_num_spans=min_num_spans,
    max_num_spans=max_num_spans,
    min_span_fill=min_span_fill,
    max_span_fill=max_span_fill,
    hard_fill=hard_fill,
)


In [8]:
# Checkpoint location. writer_path is kept for reference; the writer is not
# opened below (open_writer=False).
checkpoint_path = "../checkpoints"
checkpoint_name = "wikipedia_one_span.pth"
writer_path = "runs/wikipedia_one_span"

# Load on CPU. NOTE(review): presumably this matches the device of the
# DataLoader's tensors — confirm before switching to a GPU.
device = "cpu"

# load_checkpoint yields (epoch, model, optimizer, writer). With optim_class=None
# and open_writer=False, optimizer/writer are presumably None — verify in
# src.training.checkpoint.
epoch, model, optimizer, writer = load_checkpoint(
    device=device,
    open_writer=False,
    optim_class=None,
    checkpoint_path=f"{checkpoint_path}/{checkpoint_name}",
)

In [9]:
def greedy_decode_single(transitions: Tensor, emissions: Tensor):
    """
    Greedy traversal of DAG to decode sequence.

    Starting at vertex 0, emit the highest-scoring token of the current vertex,
    then follow the highest-scoring outgoing transition. Traversal stops when
    that transition does not move strictly forward, or when it leaves the graph.
    A non-forward transition is, for example, one pointing back to vertex 0,
    which is what argmax returns for a row whose entries are all equal.
    Stopping on non-forward edges also guarantees termination if the
    transitions contain a self-loop or cycle.

    @param transitions: shape (num_vertices, num_vertices)
    @param emissions: shape (num_vertices, vocab_size)

    @return: decoded sequence as a list of 0-d token-id tensors, one per visited
        vertex; empty if there are no vertices
    """
    num_vertices, _vocab_size = emissions.shape
    if num_vertices == 0:
        # argmax over an empty dimension raises, so short-circuit here.
        return []

    # Plain Python ints for the traversal: cheaper than tensor indexing in the
    # loop and keeps the comparisons below as int comparisons.
    next_vertex = transitions.argmax(dim=-1).tolist()
    tokens = emissions.argmax(dim=-1)

    path = []
    vertex = 0
    while vertex < num_vertices:
        path.append(vertex)
        successor = next_vertex[vertex]
        # A path through a DAG only moves forward. Any non-forward edge ends
        # decoding instead of looping forever: the jump back to 0 that marks
        # the end, a self-loop, or a backward edge.
        if successor <= vertex:
            break
        vertex = successor
    return [tokens[v] for v in path]

In [10]:
# Explicit iterator over the test DataLoader so batches can be pulled one at a
# time with next(); each call to next(iterator) advances to the following batch.
iterator = iter(test_dl)

In [11]:
# Put the loaded model into evaluation mode before running inference.
# NOTE(review): model is presumably a torch nn.Module, whose eval() returns the
# module itself, so the reassignment is a no-op rebinding.
model = model.eval()

In [12]:
# Pull the next batch from the test DataLoader. Every re-run of this cell calls
# next(iterator) again and so moves on to the following batch.
(batch, enc, targ, dec_pos, dec_v,
 target_lens, vertex_lens, target_span_indices, ratio) = next(iterator)

In [13]:
# Inference only: disable autograd so no computation graph is recorded for this
# forward pass. This saves memory, and the outputs are only inspected/decoded,
# never back-propagated.
with torch.no_grad():
    transition_probs, emission_probs = model(
        enc_x=enc, dec_x_vocab=dec_v, dec_x_pos=dec_pos, vertex_lens=vertex_lens
    )

In [14]:
# Highest-scoring target vertex for each row of the first example's transitions.
torch.argmax(transition_probs[0], dim=-1)

tensor([ 3,  4,  4,  4,  6,  6,  7,  9, 45, 12, 25, 16, 13, 15, 16, 16, 19, 19,
        47, 22, 36, 47, 25, 47, 47, 33, 36, 47, 47, 36, 47, 47, 36, 37, 45, 36,
        47, 45, 48, 50, 47, 45, 47, 47, 45, 50, 48, 48, 50, 50,  0])

In [15]:
# Greedy-decode the first example of the batch.
decoded = greedy_decode_single(
    transitions=transition_probs[0],
    emissions=emission_probs[0],
)

In [17]:
# Text of the greedy decode. `tokenizer.tokenizer` is presumably the GPT-2
# AutoTokenizer held by the project's Tokenizer wrapper — confirm.
tokenizer.tokenizer.decode(decoded)

'<|endoftext|>K River River.\n\n\n\n\nGes of Mas<|endoftext|>'

In [19]:
# Reference target text for the same (first) example, for comparison with the
# greedy decode.
tokenizer.tokenizer.decode(targ[0])

'<|endoftext|>Beitbridge road.\n\nRivers of Zimbabwe\nMas<|endoftext|>'