In [1]:
import torch

device = torch.device("cpu")

In [2]:
from model.model import NTPModel

model = NTPModel.load_from_checkpoint(
    "checkpoints/train-minloss-epoch=0-step=1340000.ckpt",
    token_size=2000,
    d_model=512,
    n_heads=8,
    dim_feedforward=2048,
    num_layers=6,
)
model = model.to(device)
model.eval()

/home/aj/venv/lib/python3.11/site-packages/pytorch_lightning/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.5.5, which is newer than your current Lightning version: v2.5.0.post0


NTPModel(
  (embedding): Embedding(2000, 512)
  (transformer_layers): ModuleList(
    (0-5): 6 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (linear1): Linear(in_features=512, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=512, bias=True)
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Linear(in_features=512, out_features=2000, bias=False)
  (criterion): NLLLoss()
)

In [3]:
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="data/tokenizer/unigram_2000.model")

In [4]:
from torch import nn

from data.positional_encoder import get_positional_encoding

positional_encoding = get_positional_encoding(d_model=512, max_len=1200)


def create_model_inputs(sentence):
    # encode sentence
    tokens = sp.encode(sentence.upper())
    tokens = [sp.bos_id()] + tokens + [sp.eos_id()]
    x = torch.LongTensor(tokens).unsqueeze(0).to(device)  # (batch_size, seq_len)
    seq_len = x.size(1)
    # crop positional encoding
    pe = positional_encoding.pe[:seq_len]
    pe = pe.to(device)
    # create mask
    causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)
    causal_mask = causal_mask.bool().to(device)
    return (x, pe, causal_mask)


def get_average_sentence_prob(x, pe, mask):
    with torch.inference_mode():
        pred = model(x, pe, mask)  # (batch_size, seq_len, token_size)
        pred = pred.exp()  # convert log-prob to probabilities
        # get probabilities of the actual tokens
        token_probs = torch.gather(pred, 2, x.unsqueeze(-1)).squeeze(-1)
    return token_probs.sum().item() / x.size(1)

In [5]:
#################
# good examples #
#################

# sentence1 = "He were sitting in the beach"
# sentence2 = "He was sitting in the beach"

# sentence1 = "He is talk the manager"
# sentence2 = "He is talking the manager"

# sentence1 = "eye saw the manager"
# sentence2 = "I saw the manager"

################
# bad examples #
################
sentence1 = "He talk to the manager"
sentence2 = "He talks to the manager"

In [6]:
x1, pe1, mask1 = create_model_inputs(sentence1.upper())
sentence1_prob = get_average_sentence_prob(x1, pe1, mask1)
print(x1.shape)

x2, pe2, mask2 = create_model_inputs(sentence2.upper())
sentence2_prob = get_average_sentence_prob(x2, pe2, mask2)
print(x2.shape)

print("\n\nsentence probs:")
print(sentence1_prob, sentence2_prob)

torch.Size([1, 7])
torch.Size([1, 8])


sentence probs:
0.0010886991263500281 0.0005845209234394133


# iterate on our dataset