In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from time import time

from modules.tokenizer import Tokenizer
from modules.transformer import Transformer, TransformerConfig
from modules.data import WikipediaTokenizedDataset

# Multi-line sample passage mixing prose, digits, brackets and typographic
# apostrophes. It is not referenced by any cell visible in this notebook;
# presumably kept for tokenizer round-trip checks — TODO confirm or remove.
TEST_EXAMPLE = """
What is a piece of text? 101. Hello, 102. (1) [2] 567890
A text is a passage of words that conveys a set of meanings to the person who is reading it. 
It’s a body of written work, in various forms and structures, that can be words, phrases and sentences that piece together a passage of written work.
To put it as simply as possible, it is a group of words. But it can come in many different forms.
A text can be written materials, such as books, magazines, newspapers, or online content. 
But it can also be other things, those that we may not associate with standard text. 
Text could be movies, scripts, paintings, songs, political cartoons, advertisements and maps. 
If we can look at something with words and sentences, explore it, find layers of meaning in it, and draw information and conclusions from it, you’re looking at a text."""

# Load the tokenizer from a file on disk.
# NOTE(review): this is an absolute, machine-specific path, so the cell fails on
# any other machine; consider deriving it from a configurable project root.
# NOTE(review): the `.pkl` suffix suggests a pickle file — only load tokenizer
# files from a trusted source; confirm how `init_and_load` deserialises it.
tokenizer = Tokenizer.init_and_load("/Users/maksimkoltugin/Dev/huawei_LLM_test_task/weights/tokenizer_1k_100_uncased.pkl")


# Pick the compute backend: CUDA first, then Apple MPS (when this torch build
# exposes it), otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using mps device


In [2]:
# Model hyperparameters. The vocabulary size is taken from the loaded tokenizer
# so the model's token dimension always matches it.
cfg = TransformerConfig(
    vocab_size=len(tokenizer.vocab),
    context_length=512,
    n_layers=12,
    n_heads=12,
    d_model=768,
    p_dropout=0.1,
)

# Build the model, then move it onto the selected device.
transformer = Transformer(cfg)
transformer = transformer.to(device)

In [3]:
# Batch geometry; both values are also read by the throughput estimate in the
# training cell.
batch_size, context_size = 4, 512

# Training split, served in its stored order (no shuffling).
dataset = WikipediaTokenizedDataset("/Users/maksimkoltugin/Dev/huawei_LLM_test_task/data-uncased-1k-100/train")
data_loader = torch.utils.data.DataLoader(  # type: ignore
    dataset,
    batch_size=batch_size,
    shuffle=False,
)

# AdamW over every model parameter.
optimizer = torch.optim.AdamW(transformer.parameters(), lr=3e-4)


In [4]:
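# Optional: compile the model with torch.compile (available from PyTorch 2.0).
# Left disabled here; uncomment the next line to enable it.
# transformer = torch.compile(transformer)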

In [5]:
# Training pass over `data_loader`: one optimizer step per batch, printing the
# loss and a tokens-per-second estimate for every step.

# Put the model in training mode explicitly so the result does not depend on
# which cells ran earlier (assumes `Transformer` is an `nn.Module`; it already
# exposes `.parameters()` and `.to()`).
transformer.train()

for i, batch in enumerate(data_loader):
    t0 = time()

    x = batch["x"].to(device)
    y = batch["y"].to(device)
    attn_mask = batch["pad_mask"].to(device)

    optimizer.zero_grad()
    loss = transformer.compute_loss(x, y, attn_mask=attn_mask)
    loss.backward()
    optimizer.step()

    # `.item()` copies the scalar loss to the host, which has to wait for the
    # device work queued above. Reading it BEFORE stopping the clock makes the
    # timed interval cover the actual compute on asynchronous backends.
    loss_value = loss.item()

    t1 = time()

    # Count the elements actually present in this batch rather than assuming
    # batch_size * context_size, so a short final batch is not over-reported.
    # NOTE(review): presumably `x` is (batch, seq) token ids — confirm.
    tokens_per_second = x.numel() / (t1 - t0)

    print(f"{i}| {round(loss_value, 2)}| tps: {tokens_per_second}")

0| 9.38| tps: 250.05216762530432
1| 7.96| tps: 949.8137999155002
2| 8.71| tps: 985.4929053148645
3| 8.41| tps: 990.709129621379
4| 8.21| tps: 992.9716483050285
5| 8.27| tps: 984.7582104411074
6| 7.74| tps: 968.2729364639034
7| 7.8| tps: 987.0616481573004
8| 7.55| tps: 901.6807385566763
9| 7.52| tps: 1024.469697768404
10| 7.51| tps: 1001.3748424046655


KeyboardInterrupt: 

In [None]:
# Sample a continuation of a short prompt with top-k sampling, printing each
# decoded token as soon as it is drawn.
text = "London is"
print(text, end="")

max_new_tokens = 100
top_k = 50  # top-k sampling of 50 (huggingface pipeline default)

x = torch.tensor([tokenizer.encode(text)], device=device)  # (B=1, T)

# Sample with the model in eval mode: it was built with p_dropout=0.1, and in
# training mode dropout would randomly perturb every forward pass. The previous
# mode is restored afterwards (also on KeyboardInterrupt) so a later training
# cell is unaffected. Assumes `Transformer` is an `nn.Module`.
was_training = transformer.training
transformer.eval()
try:
    # no gradients are needed for sampling; enter the context once, not per step
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # feed at most `context_size` most recent tokens
            # NOTE(review): assumes the model cannot attend past its configured
            # context length; a no-op while the sequence is shorter than that.
            x_cond = x[:, -context_size:]
            # forward the model to get the logits
            logits = transformer(x_cond)  # (B, T, vocab_size)
            # take the logits at the last position
            logits = logits[:, -1, :]  # (B, vocab_size)
            # get the probabilities
            probs = F.softmax(logits, dim=-1)
            # keep the k most likely tokens; clamp k so a vocabulary smaller
            # than top_k does not make torch.topk raise
            k = min(top_k, probs.size(-1))
            topk_probs, topk_indices = torch.topk(probs, k, dim=-1)  # both (B, k)
            # select a token from the top-k probabilities
            # note: multinomial does not demand the input to sum to 1
            ix = torch.multinomial(topk_probs, 1)  # (B, 1)
            # map the sampled position back to a vocabulary index
            xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)
            # append to the sequence
            x = torch.cat([x, xcol], dim=1)

            # `.item()` requires a single element, i.e. this cell assumes B == 1
            print(tokenizer.decode([xcol.item()]), end="", flush=True)
finally:
    transformer.train(was_training)

London is s an" was he r d.s  oft a
 the the to
:.
 the and. in- to he fromo,,

KeyboardInterrupt: 