In [1]:
import torch

device = torch.device("cpu")

In [2]:
from model.model import NTPModel

model = NTPModel.load_from_checkpoint(
    "checkpoints/train-minloss-epoch=0-step=1340000.ckpt",
    token_size=2000,
    d_model=512,
    n_heads=8,
    dim_feedforward=2048,
    num_layers=6,
)
model = model.to(device)

/home/aj/venv/lib/python3.11/site-packages/pytorch_lightning/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.5.5, which is newer than your current Lightning version: v2.5.0.post0


In [3]:
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="data/tokenizer/unigram_2000.model")

In [22]:
from torch import nn

from data.positional_encoder import get_positional_encoding

positional_encoding = get_positional_encoding(d_model=512, max_len=1200)


def create_model_inputs(sentence):
    # encode sentence
    tokens = sp.encode(sentence.upper())
    tokens = [sp.bos_id()] + tokens + [sp.eos_id()]
    x = torch.LongTensor(tokens).unsqueeze(0).to(device)  # (batch_size, seq_len)
    seq_len = x.size(1)
    # crop positional encoding
    pe = positional_encoding.pe[:seq_len]
    pe = pe.to(device)
    # create mask
    causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)
    causal_mask = causal_mask.bool().to(device)
    return (x, pe, causal_mask)


def get_sentence_prob(x, pe, mask):
    with torch.no_grad():
        pred = model(x, pe, mask)  # (batch_size, seq_len, token_size)
        pred = pred.exp()  # convert log-prob to probabilities
        # get probabilities of the actual tokens
        token_probs = torch.gather(pred, 2, x.unsqueeze(-1)).squeeze(-1)
        # ignore the first token (bos)
        token_probs = token_probs[:, 1:]
        sentence_log_prob = token_probs.sum().item()
        sentence_prob = torch.exp(torch.tensor(sentence_log_prob)).item()
    return sentence_prob

In [None]:
sentence1 = "He was sitting in the beach"
sentence2 = "He were sitting in the beach"

sentence1 = "He see the suspicious man who was sitting in the beach"
sentence2 = "He sees the suspicious man who was sitting in the beach"

In [24]:

x1, pe1, mask1 = create_model_inputs(sentence1)
sentence1_prob = get_sentence_prob(x1, pe1, mask1)

x2, pe2, mask2 = create_model_inputs(sentence2)
sentence2_prob = get_sentence_prob(x2, pe2, mask2)

In [25]:
print(sentence1_prob, sentence2_prob)

1.0077720880508423 1.0017191171646118
