In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from time import time

from modules.tokenizer import Tokenizer
from modules.transformer import Transformer, TransformerConfig, SamplingStrategy
from modules.data import WikipediaTokenizedDataset

# Multi-line sample passage. Its first content line mixes a question mark, numbers,
# parentheses and square brackets; the remaining lines are ordinary English sentences
# (including curly apostrophes). Not referenced by any of the cells shown in this
# notebook — presumably kept for ad-hoc tokenizer checks (TODO confirm or remove).
TEST_EXAMPLE = """
What is a piece of text? 101. Hello, 102. (1) [2] 567890
A text is a passage of words that conveys a set of meanings to the person who is reading it. 
It’s a body of written work, in various forms and structures, that can be words, phrases and sentences that piece together a passage of written work.
To put it as simply as possible, it is a group of words. But it can come in many different forms.
A text can be written materials, such as books, magazines, newspapers, or online content. 
But it can also be other things, those that we may not associate with standard text. 
Text could be movies, scripts, paintings, songs, political cartoons, advertisements and maps. 
If we can look at something with words and sentences, explore it, find layers of meaning in it, and draw information and conclusions from it, you’re looking at a text."""

# Pick the compute device. MPS (the original hard-coded choice) is still preferred,
# but fall back to CUDA and then CPU so the notebook runs on machines without MPS
# instead of failing when the model is moved to the device.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): absolute, user-specific path — this only resolves on the original
# author's machine; consider a path relative to a configurable checkpoint directory.
# NOTE(review): the ".pkl" extension suggests a pickle file; only load tokenizer
# files that come from a trusted source.
TOKENIZER_PATH = "/Users/maksimkoltugin/Dev/huawei_LLM_test_task/checkpoints/tokenizer/tokenizer_15k_10k_uncased.pkl"
tokenizer = Tokenizer.init_and_load(TOKENIZER_PATH)

In [2]:
# Checkpoint loads that were used previously (currently disabled, kept for reference):
# transformer = Transformer.init_and_load("/Users/maksimkoltugin/Dev/huawei_LLM_test_task/checkpoints/transformer_uncased/ckpts/ckpt_200.pt")
# transformer = Transformer.init_and_load("/Users/maksimkoltugin/Dev/huawei_LLM_test_task/weights/ckpts/transformer_uncased/model_100.pt")

# The model is built with the default constructor, i.e. no checkpoint is loaded in
# this cell (presumably leaving freshly initialised weights — TODO confirm).
# This notebook only runs generation, so put the model in eval mode right away:
# that disables train-only behaviour such as dropout, if the model has any.
# Both `.to()` and `.eval()` return the module, so the chain can be bound directly.
transformer = Transformer().to(device).eval()

In [3]:
# Three prompts of different lengths that are tokenised together as one batch.
list_of_texts = [
    "What is a piece of text?",
    "A text is a passage of words that conveys a set of meanings.",
    "To put it as simply as possible, it is a group of words.",
]

# Calling the tokenizer on the list yields a pair: the token ids and a mask.
# The mask is not consumed by any of the cells shown in this notebook.
encoded_batch = tokenizer(list_of_texts)
ids, mask = encoded_batch

# Show the dimensions of the id tensor as the cell output.
ids.size()

torch.Size([3, 16])

In [4]:
# Run the model's own generation routine on the tokenised batch with a
# default-constructed sampling strategy. The second positional argument is 10;
# in the recorded output the result is 10 columns wider than `ids`.
sampling_strategy = SamplingStrategy()
gen = transformer.generate(ids, 10, sampling_strategy)

# Show the dimensions of the generated tensor as the cell output.
gen.size()

Block 0 std: 0.5903343558311462
Block 1 std: 0.6013494729995728
Block 2 std: 0.6057414412498474
Block 3 std: 0.6072239279747009
Block 4 std: 0.6063388586044312
Block 5 std: 0.608476459980011
Block 6 std: 0.609247624874115
Block 7 std: 0.6103223562240601
Block 8 std: 0.6109886765480042
Block 9 std: 0.6112488508224487
Block 10 std: 0.6124705076217651
Block 11 std: 0.6132504940032959
Block 0 std: 0.5924622416496277
Block 1 std: 0.6030369997024536
Block 2 std: 0.607288122177124
Block 3 std: 0.6088877320289612
Block 4 std: 0.6080663800239563
Block 5 std: 0.6102776527404785
Block 6 std: 0.6110407710075378
Block 7 std: 0.6120023131370544
Block 8 std: 0.612666130065918
Block 9 std: 0.6128889918327332
Block 10 std: 0.6140614748001099
Block 11 std: 0.61482173204422
Block 0 std: 0.5946388244628906
Block 1 std: 0.6048872470855713
Block 2 std: 0.6089818477630615
Block 3 std: 0.6107012033462524
Block 4 std: 0.6099332571029663
Block 5 std: 0.6122156381607056
Block 6 std: 0.6129545569419861
Block 7 st

torch.Size([3, 26])

In [5]:
# Convert the generated id tensor back to text and display the result.
decoded_texts = tokenizer.decode_batch(gen)
decoded_texts

['what is a piece of text? noted noted noted rud rud rud rudčiusčiusčius',
 'a text is a passage of words that conveys a set of meanings. noted noted noted rud rud rud rudčiusčiusčius',
 'to put it as simply as possible, it is a group of words. noted noted noted rud rud rud rudčiusčiusčius']

In [3]:
# Manual top-k sampling: stream a continuation of `text`, printing one token at a time.
MAX_NEW_TOKENS = 100  # how many tokens to sample
TOP_K = 50  # candidates kept per step (huggingface pipeline default)

text = "What is a piece of text?"
print(text, end="")

# Wrapping the encoded prompt in a list adds a leading batch dimension of size 1
# (assumes `encode` returns a flat sequence of token ids — TODO confirm).
x = torch.tensor([tokenizer.encode(text)]).to(device)

# Inference only: disable train-only behaviour such as dropout, if the model has any.
transformer.eval()

# generate!
# NOTE(review): `x` grows by one column per step and the whole sequence is re-fed to
# the model each time; if the model has a maximum context length it is not enforced here.
# Gradients are never needed while sampling, so no_grad is entered once for the whole loop.
with torch.no_grad():
    for step in range(MAX_NEW_TOKENS):
        # forward the model to get the logits
        logits = transformer(x)  # (B, T, vocab_size)
        # take the logits at the last position
        logits = logits[:, -1, :]  # (B, vocab_size)
        # get the probabilities
        probs = F.softmax(logits, dim=-1)
        # top-k sampling; clamp k so a vocabulary smaller than TOP_K cannot make topk raise
        k = min(TOP_K, probs.size(-1))
        topk_probs, topk_indices = torch.topk(probs, k, dim=-1)  # both (B, k)
        # select a token from the top-k probabilities
        # note: multinomial does not demand the input to sum to 1
        ix = torch.multinomial(topk_probs, 1)  # (B, 1)
        # map the sampled position back to the actual vocabulary index
        xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)
        # append to the sequence
        x = torch.cat([x, xcol], dim=1)

        # `.item()` requires a single element, so this loop only supports one prompt (B == 1)
        print(tokenizer.decode([xcol.item()]), end="")

What is a piece of text? sai sant graffiti animal ste emerging emotclaimedulus determine kw funds expandedecouluslawsunded ecclesiven attend emotwe purchasemary luck appears

KeyboardInterrupt: 