In [1]:
import sys
import os

# Make the parent of the current working directory importable so that the
# `src.*` packages imported in the next cell can be resolved.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
# Guard the append so re-running this cell does not pile up duplicate entries.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [2]:
from src.datasets.wikipedia.dataset import Dataset as WikipediaDataset
from transformers import AutoTokenizer
from src.tokenizer.model import Tokenizer
#from torch.utils.data import DataLoader
from src.datasets.dataloader import DataLoader
from src.training.checkpoint import load_checkpoint
import torch
from torch import Tensor
from src.nn.models.transformer import Transformer
from safetensors.torch import load_file

In [3]:
# Load the training checkpoint; the visible cells only read its "kwargs" entry
# (the model constructor arguments). map_location="cpu" keeps this cell working
# on a machine without CUDA even if the checkpoint holds GPU tensors.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source.
data = torch.load("wikipedia_cuda_dist_fsdp.pth", map_location="cpu")

In [4]:
# Number of samples per batch drawn from the test loader.
batch_size: int = 1

# Sampling bounds forwarded to the DataLoader. Each min/max pair is pinned to
# a single value here.
# NOTE(review): the exact meaning of "ratio" and "span fill" is defined by
# src.datasets.dataloader - confirm there.
min_ratio: int = 3
max_ratio: int = 3
min_num_spans: int = 1
max_num_spans: int = 1
min_span_fill: float = 0.15
max_span_fill: float = 0.15
hard_fill: bool = True

In [5]:
# Sequence length handed to the dataset constructor.
max_seq_len: int = 96
# Root directory of the Wikipedia dataset (relative path); the test split is
# read from its "test" subdirectory.
dataset_path: str = "../data/wikipedia"

In [6]:
# Build the test-split dataset; max_seq_len is passed as the second positional
# argument.
test_split_dir = f"{dataset_path}/test"
test_ds = WikipediaDataset(test_split_dir, max_seq_len)

In [7]:
# Wrap the pretrained GPT-2 Hugging Face tokenizer in the project's Tokenizer.
# The wrapper supplies the special-token ids used by the DataLoader, and the
# raw tokenizer stays reachable as `tokenizer.tokenizer` for decoding.
tokenizer = Tokenizer(AutoTokenizer.from_pretrained("gpt2"))

In [8]:
# Special-token ids taken from the tokenizer wrapper.
special_token_ids = dict(
    enc_span_idx=tokenizer.enc_span_token,
    target_span_idx=tokenizer.targ_span_token,
    fill_idx=tokenizer.mask_token,
    eos_idx=tokenizer.eos_token,
    bos_idx=tokenizer.bos_token,
)

# Ratio / span sampling bounds from the configuration cell.
span_sampling = dict(
    min_ratio=min_ratio,
    max_ratio=max_ratio,
    max_num_spans=max_num_spans,
    max_span_fill=max_span_fill,
    min_num_spans=min_num_spans,
    min_span_fill=min_span_fill,
    hard_fill=hard_fill,
)

# Project DataLoader over the test split, configured entirely by keyword.
test_dl = DataLoader(
    ds=test_ds,
    batch_size=batch_size,
    **special_token_ids,
    **span_sampling,
)


In [9]:
# Pull the entry stored under "kwargs" out of the loaded checkpoint; it is
# splatted into the Transformer constructor in the next cell.
kwargs = data["kwargs"]

In [10]:
# Read the weights from the safetensors file, build the model from the
# checkpointed constructor arguments, then load the weights into it.
# The load_state_dict call is left as the last expression so the cell still
# displays its key-matching report.
state_dict = load_file("model.safetensors")
model = Transformer(**kwargs)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [11]:
def greedy_decode_single(transitions: Tensor, emissions: Tensor):
    """
    Greedy traversal of DAG to decode sequence

    Starting at vertex 0, emit the highest-scoring token of the current vertex
    and follow its highest-scoring outgoing edge. The walk stops when the next
    vertex has already been visited (this includes the walk returning to
    vertex 0) or lies outside the emission table, so it always terminates even
    if the greedy edges contain a self-loop or cycle.

    @param transitions: shape (num_vertices, num_vertices)
    @param emissions: shape (num_vertices, vocab_size)

    @return: decoded sequence as a list of 0-dim token-id tensors in visit
             order (empty list when there are no vertices)
    """
    num_vertices, _ = emissions.shape
    if num_vertices == 0:
        # argmax over an empty transition row is undefined; nothing to decode.
        return []

    tokens = emissions.argmax(dim=-1)
    # Plain Python ints for the walk: avoids tensor-valued loop indices and
    # makes the visited-set membership test well defined.
    edges = transitions.argmax(dim=-1).tolist()

    visited = set()
    output = []
    i = 0
    while i < num_vertices and i not in visited:
        visited.add(i)
        output.append(tokens[i])
        i = edges[i]
    return output

In [12]:
# Create an iterator over the test loader so batches can be pulled one at a
# time with next(); re-running this cell starts a fresh iterator.
iterator = iter(test_dl)

In [13]:
# Switch the model to evaluation mode before running inference.
# NOTE(review): assumes Transformer is a torch.nn.Module, whose eval() returns
# the module itself, so `model` stays bound to the same object - confirm.
model = model.eval()

In [14]:
# Pull the next batch from the loader and unpack its nine fields.
# Note: every re-run of this cell advances the iterator by one batch.
next_batch = next(iterator)
(batch, enc, targ, dec_pos, dec_v,
 target_lens, vertex_lens, target_span_indices, ratio) = next_batch

In [15]:
# Inference only: run the forward pass with autograd disabled so no graph is
# recorded for the returned transition / emission tensors.
with torch.no_grad():
    transition_probs, emission_probs = model(
        enc_x=enc, dec_x_vocab=dec_v, dec_x_pos=dec_pos, vertex_lens=vertex_lens
    )

In [16]:
# Highest-scoring successor vertex for every vertex of the first sample; this
# is the edge table that greedy_decode_single walks.
torch.argmax(transition_probs[0], dim=-1)

tensor([ 3,  4,  4,  4,  5,  6,  7,  8,  9, 45, 45, 45, 45, 45, 45, 45, 45, 45,
        45, 45, 45, 45, 45, 45, 45, 45, 28, 45, 38, 45, 45, 45, 45, 38, 45, 45,
        45, 45, 45, 45, 45, 45, 45, 45, 45, 53, 53, 53, 53, 53, 53, 53, 53,  0])

In [17]:
# Greedily decode the first sample of the batch.
decoded = greedy_decode_single(
    transitions=transition_probs[0],
    emissions=emission_probs[0],
)

In [18]:
# Turn the greedily decoded token ids back into text using the wrapped
# tokenizer (the model's prediction for the first sample).
tokenizer.tokenizer.decode(decoded)

'<|endoftext|>\n\n\n\n\n\n\n\n'

In [19]:
# Decode the first target sequence of the batch for comparison with the
# prediction above.
tokenizer.tokenizer.decode(targ[0])

'<|endoftext|> of the Réseau des Liaisons Aériennes Franca<|endoftext|>'