In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
from torch import nn
from torch.nn import functional as F
import os
from os import path
from matplotlib import pyplot as plt
from torchinfo import summary
from torchmetrics.classification import MulticlassAccuracy

# Use the GPU when torch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

# Opt in to the `tokenizers` library's internal parallelism (it consults this variable).
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from tokenizers import ByteLevelBPETokenizer

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


cpu


In [2]:
import importlib.util
lightning = importlib.util.find_spec("lightning")
if lightning is None:
    !pip install lightning
#
import lightning.pytorch as pl

Collecting lightning
  Downloading lightning-2.0.7-py3-none-any.whl (1.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m25.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting croniter<1.5.0,>=1.3.0 (from lightning)
  Downloading croniter-1.4.1-py2.py3-none-any.whl (19 kB)
Collecting dateutils<2.0 (from lightning)
  Downloading dateutils-0.6.12-py2.py3-none-any.whl (5.7 kB)
Collecting deepdiff<8.0,>=5.7.0 (from lightning)
  Downloading deepdiff-6.3.1-py3-none-any.whl (70 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m70.7/70.7 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
Collecting inquirer<5.0,>=2.10.0 (from lightning)
  Downloading inquirer-3.1.3-py3-none-any.whl (18 kB)
Collecting lightning-cloud>=0.5.37 (from lightning)
  Downloading lightning_cloud-0.5.37-py3-none-any.whl (596 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m596.7/596.7 kB[0m [31m30.9 MB/s[0m eta [36m0:00:00[0m
Collecting 

In [3]:
# Registered checkpoint files, keyed by model type.
checkpoints = {
    "lstm": r'/kaggle/input/2-lstm-model-for-next-word-checking/lstm_mlit_best.pt',
}

def load_from_checkpoint_if_possible(model, model_type):
    """Load saved weights into ``model`` when a checkpoint for ``model_type`` is available.

    Args:
        model: module whose ``load_state_dict`` receives the saved state dict.
            Its ``device`` attribute (present on LightningModules) is used as
            ``map_location``; modules without one fall back to the device of
            their first parameter, or the CPU if they have no parameters.
        model_type: key into ``checkpoints``.

    Returns:
        True if weights were loaded; False if no checkpoint is registered for
        ``model_type`` or the registered file does not exist.
    """
    ckpt_path = checkpoints.get(model_type)
    if ckpt_path is None:
        return False
    if not path.exists(ckpt_path):
        # "if possible": a missing file (e.g. dataset not attached) is reported, not fatal.
        print(f"No {model_type} checkpoint found at: {ckpt_path}; keeping current weights")
        return False

    map_location = getattr(model, "device", None)
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else "cpu"

    # NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(ckpt_path, map_location=map_location)
    model.load_state_dict(state_dict)
    print(f"Loaded {model_type} model from location: {checkpoints[model_type]}")
    return True
#

In [4]:
def load_tokenizer(tok_dir):
    """Build a byte-level BPE tokenizer from the files saved in ``tok_dir``.

    Reads ``tok-vocab.json`` (vocabulary) and ``tok-merges.txt`` (merge rules)
    from that directory.
    """
    vocab_file = path.join(tok_dir, "tok-vocab.json")
    merges_file = path.join(tok_dir, "tok-merges.txt")
    return ByteLevelBPETokenizer.from_file(vocab_file, merges_filename=merges_file)
#

VOCAB_SIZE = 6600

# The tokenizer directory is named after its vocabulary size, e.g. /kaggle/input/tok6600.
tok = load_tokenizer(tok_dir=f"/kaggle/input/tok{VOCAB_SIZE}")
loaded_vocab_size = tok.get_vocab_size()
print(loaded_vocab_size)
assert loaded_vocab_size == VOCAB_SIZE

6600


In [5]:
# Eyeball the tokenizer on a sample sentence: print the Encoding repr, its ids,
# the number of ids, and the sub-word tokens.
enc = tok.encode("Hello Juliet! It was great meeting you yesterday for lunch. Let's meet again and finalize the deal next week.")
for shown in (enc, enc.ids, len(enc.ids), enc.tokens):
    print(shown)

Encoding(num_tokens=27, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])
[5491, 418, 5870, 2944, 2, 1101, 439, 950, 2479, 315, 5974, 313, 5896, 1, 380, 309, 489, 1351, 1017, 285, 2390, 1109, 264, 2000, 1249, 1078, 1]
27
['Hel', 'lo', 'ĠJul', 'iet', '!', 'ĠIt', 'Ġwas', 'Ġgreat', 'Ġmeeting', 'Ġyou', 'Ġyesterday', 'Ġfor', 'Ġlunch', '.', 'ĠL', 'et', "'s", 'Ġmeet', 'Ġagain', 'Ġand', 'Ġfinal', 'ize', 'Ġthe', 'Ġdeal', 'Ġnext', 'Ġweek', '.']


In [6]:
# Smaller batches when running on the CPU.
BATCH_SIZE = 64 if device != 'cpu' else 16
# NOTE(review): LIMIT is not used in the visible cells; its meaning is not shown here.
LIMIT = 12_000_000
SEQ_LEN = 256

In [7]:
class WordPredictionLSTMModel(nn.Module):
    """Embedding -> stacked LSTM -> MLP head that scores the next token at every position.

    ``forward`` takes token ids of shape (batch, seq_len) and returns logits of
    shape (batch, output_dim, seq_len), i.e. the (N, C, d1) layout accepted by
    ``nn.CrossEntropyLoss`` for per-position targets.
    """

    def __init__(self, num_embed, embed_dim, pad_idx, lstm_hidden_dim, lstm_num_layers, output_dim, dropout):
        super().__init__()
        self.vocab_size = num_embed
        self.embed = nn.Embedding(num_embed, embed_dim, pad_idx)
        # nn.LSTM only applies dropout *between* stacked layers. With a single layer
        # the value is ignored and PyTorch emits a UserWarning, so pass 0 instead.
        # LSTM dropout is not part of the state dict, so checkpoints are unaffected.
        lstm_dropout = dropout if lstm_num_layers > 1 else 0.0
        self.lstm = nn.LSTM(embed_dim, lstm_hidden_dim, lstm_num_layers, batch_first=True, dropout=lstm_dropout)
        self.fc = nn.Sequential(
            nn.Linear(lstm_hidden_dim, lstm_hidden_dim * 4),
            nn.LayerNorm(lstm_hidden_dim * 4),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout),

            nn.Linear(lstm_hidden_dim * 4, output_dim),
        )
    #

    def forward(self, x):
        x = self.embed(x)        # (B, T) -> (B, T, embed_dim)
        x, _ = self.lstm(x)      # -> (B, T, lstm_hidden_dim); batch_first=True
        x = self.fc(x)           # -> (B, T, output_dim)
        x = x.permute(0, 2, 1)   # -> (B, output_dim, T)
        return x
    #
#

def test_model(model):
    x = torch.randint(0, VOCAB_SIZE-1, (BATCH_SIZE, SEQ_LEN))

    print(model.embed.weight.shape)
    print(model.fc[-1].weight.shape)

    y = model(x)
    print(y.shape)

    print(summary(model, input_size=x.shape, dtypes=[torch.long]))
    del x, y
#

# Instantiate a 4-layer, 1024-unit configuration only to inspect its shapes and size, then free it.
big_lstm_config = dict(
    num_embed=VOCAB_SIZE,
    embed_dim=256,
    pad_idx=0,
    lstm_hidden_dim=1024,
    lstm_num_layers=4,
    output_dim=VOCAB_SIZE,
    dropout=0.5,
)
lstm_model = WordPredictionLSTMModel(**big_lstm_config)
test_model(lstm_model)
del lstm_model

torch.Size([6600, 256])
torch.Size([6600, 4096])
torch.Size([16, 6600, 256])
Layer (type:depth-idx)                   Output Shape              Param #
WordPredictionLSTMModel                  [16, 6600, 256]           --
├─Embedding: 1-1                         [16, 256, 256]            1,689,600
├─LSTM: 1-2                              [16, 256, 1024]           30,441,472
├─Sequential: 1-3                        [16, 256, 6600]           --
│    └─Linear: 2-1                       [16, 256, 4096]           4,198,400
│    └─LayerNorm: 2-2                    [16, 256, 4096]           8,192
│    └─LeakyReLU: 2-3                    [16, 256, 4096]           --
│    └─Dropout: 2-4                      [16, 256, 4096]           --
│    └─Linear: 2-5                       [16, 256, 6600]           27,040,200
Total params: 63,377,864
Trainable params: 63,377,864
Non-trainable params: 0
Total mult-adds (G): 125.22
Input size (MB): 0.03
Forward/backward pass size (MB): 526.65
Params size (MB):

In [8]:
class LitWordPredictor(pl.LightningModule):
    """Lightning wrapper around a next-token model, plus greedy inference helpers.

    The wrapped ``model`` must expose ``vocab_size`` and return logits shaped
    (batch, vocab, seq_len) from ``forward`` (as ``WordPredictionLSTMModel`` does).
    Token id 0 is ignored by the loss and by both accuracy metrics.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)
        self.accuracy1 = MulticlassAccuracy(num_classes=model.vocab_size, average='micro', ignore_index=0, top_k=1)
        self.accuracy5 = MulticlassAccuracy(num_classes=model.vocab_size, average='micro', ignore_index=0, top_k=5)
        self.lr = 3e-4 # 6.3e-4
    #

    def forward(self, x):
        return self.model(x)
    #

    def _encode(self, text, tok):
        """Encode ``text`` with ``tok`` into a (1, n_tokens) long tensor on this module's device."""
        ids = tok.encode(text).ids
        # Explicit dtype: an empty id list would otherwise become a float tensor.
        return torch.tensor(ids, dtype=torch.long, device=self.device).unsqueeze(0)
    #

    def complete_sentence(self, input, tok, max_tokens, stop_token):
        """Greedily extend ``input`` one token at a time.

        Generation stops after ``max_tokens`` new tokens, or right after
        ``stop_token`` is produced (the stop token is kept in the output).
        The wrapped model's train/eval mode is restored afterwards.

        Returns:
            A (1, n) long tensor: the prompt ids followed by the generated ids.

        Raises:
            ValueError: if ``input`` encodes to no tokens.
        """
        ids = self._encode(input, tok)
        if ids.size(1) == 0:
            raise ValueError("input must encode to at least one token")
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for _ in range(max_tokens):
                    # The whole prefix is re-run each step; forward does not expose recurrent state.
                    logits = self.model(ids)                 # (1, vocab, T)
                    next_id = logits[:, :, -1].argmax(dim=1).reshape(1, 1)
                    ids = torch.cat([ids, next_id], dim=1)
                    if next_id.item() == stop_token:
                        break
                #
            #
        finally:
            self.model.train(was_training)
        return ids
    #

    def get_completion_probability(self, input, completion, tok):
        """Return the model's probability for each token of ``completion`` following ``input``.

        Element i is P(completion_token_i | input + completion_tokens[:i]) (teacher
        forcing), so ``.prod()`` of the result is the probability of the whole
        completion. The wrapped model's train/eval mode is restored afterwards.

        Returns:
            A 1-D float CPU tensor with one entry per completion token (empty if
            ``completion`` encodes to no tokens).

        Raises:
            ValueError: if ``input`` encodes to no tokens.
        """
        ids = self._encode(input, tok)
        if ids.size(1) == 0:
            raise ValueError("input must encode to at least one token")
        completion_ids = self._encode(completion, tok)

        probs = []
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for i in range(completion_ids.size(1)):
                    # Distribution over the vocabulary for the token after the current prefix.
                    next_token_probs = self.model(ids)[0, :, -1].softmax(dim=0)
                    probs.append(next_token_probs[completion_ids[0, i]])
                    ids = torch.cat([ids, completion_ids[:, i:i + 1]], dim=1)
                #
            #
        finally:
            self.model.train(was_training)

        if not probs:
            return torch.empty(0)
        return torch.stack(probs).cpu()
    #

#

# Sanity-check the inference configuration: build it, load the trained weights,
# and push one random batch through the Lightning wrapper.
# randint's upper bound is exclusive, so VOCAB_SIZE (not VOCAB_SIZE-1) covers every valid id.
x = torch.randint(0, VOCAB_SIZE, (5, SEQ_LEN))

# These hyper-parameters must match the checkpoint loaded below
# (e.g. lstm_hidden_dim=786 is what the saved weights were trained with).
model = WordPredictionLSTMModel(
    num_embed=VOCAB_SIZE, embed_dim=512, pad_idx=0, lstm_hidden_dim=786, lstm_num_layers=1, output_dim=VOCAB_SIZE, dropout=0.5,
)
print("LSTM Model Summary")
print(summary(model))

print("")

mlit = LitWordPredictor(model)
model_type = "lstm"
load_from_checkpoint_if_possible(mlit, model_type)

# Shape check only: no gradients needed.
with torch.no_grad():
    y = mlit(x)
print(x.shape, y.shape)
del x, y

def complete_sentences(tok, mlit, sentences=None, max_tokens=20):
    """Greedily complete each prompt with ``mlit`` and print ``[prompt] => completion``.

    Args:
        tok: tokenizer providing ``encode(text).ids`` and ``decode(ids)``.
        mlit: object with ``complete_sentence(text, tok, max_tokens, stop_token)``
            returning a (1, n) tensor of ids (e.g. ``LitWordPredictor``).
        sentences: prompts to complete; defaults to a built-in set of examples.
        max_tokens: maximum number of new tokens generated per prompt.
    """
    # The id of "\n" is passed as the stop token, i.e. a newline ends a completion.
    stop_token = tok.encode('\n').ids[0]
    if sentences is None:
        sentences = [
            "Hello George! It feels like",
            "Finally, an opening",
            "I think I can make",
            "Will you be able",
            "Do you think we",
            "The ice-cream truck",
            "Good morning George!",
            "Have a great",
            "Usually, when one notices",
        ]

    for prompt in sentences:
        completion_ids = mlit.complete_sentence(prompt, tok, max_tokens=max_tokens, stop_token=stop_token)[0]
        decoded = tok.decode(completion_ids.tolist()).strip()
        print(f"[{prompt}] => {decoded}")
    #
#

def word_completion_probabilities(tok, mlit, prompts_and_candidates=None):
    """Rank candidate next words for each prompt by model probability and print them.

    Each candidate word is scored as the product of the per-token probabilities
    returned by ``mlit.get_completion_probability(prompt, " " + word, tok)``.
    The leading space makes the candidate encode as a new word after the prompt.
    For every prompt a blank line is printed, then ``[prompt] [word] = p`` lines
    from most to least likely.

    Note: multi-token words are scored as a product over their tokens, so words
    that split into more tokens tend to receive lower scores.

    Args:
        tok: tokenizer forwarded to ``mlit.get_completion_probability``.
        mlit: object with ``get_completion_probability(prompt, completion, tok)``
            returning a 1-D tensor of per-token probabilities.
        prompts_and_candidates: iterable of ``(prompt, candidate_words)`` pairs;
            defaults to a built-in set of examples.
    """
    if prompts_and_candidates is None:
        prompts_and_candidates = [
            ("That ice-cream looks", ("real", "absolutely", "really", "delicious", "atrocious", "paper", "fish")),
            ("Since we're heading", ("toward", "away", "death", "birth", "both", "against", "bubble")),
            ("Did I make", ("good", "a", "the", "food", "flower", "pencil", "color", "colour", "house")),
            ("We want a candidate", ("that", "with", "which", "experience", "school", "more", "less")),
            ("This is the definitive guide to the", ("complete", "illustrated", "extravagant", "miniscule", "wrapper", "rapper", "the", "sentence")),
            ("Please can you", ("check", "confirm", "envelope", "laptop", "options", "cordon", "cease", "cradle", "corolla")),

            # Near-miss spellings / tokenisations.
            ("I think", ("I've", "ice", "Oct")),
            ("Please", ("cab", "can")),
            ("I've scheduled this", ("messing", "meeting")),
        ]

    for prompt, words in prompts_and_candidates:
        print("")
        scored = []
        for word in words:
            probs = mlit.get_completion_probability(prompt, " " + word, tok)
            # Plain floats so sorting compares numbers instead of 0-dim tensors.
            scored.append((float(probs.prod()), word))
        #
        # Highest probability first; exact ties fall back to reverse-alphabetical order.
        for prob, word in sorted(scored, reverse=True):
            print(f"[{prompt}] [{word}] = {prob:.5f}")
        #
    #
#

# Qualitative checks of the loaded model: greedy completions, then candidate-word rankings.
for report in (complete_sentences, word_completion_probabilities):
    report(tok, mlit)



LSTM Model Summary
Layer (type:depth-idx)                   Param #
WordPredictionLSTMModel                  --
├─Embedding: 1-1                         3,379,200
├─LSTM: 1-2                              4,087,200
├─Sequential: 1-3                        --
│    └─Linear: 2-1                       2,474,328
│    └─LayerNorm: 2-2                    6,288
│    └─LeakyReLU: 2-3                    --
│    └─Dropout: 2-4                      --
│    └─Linear: 2-5                       20,757,000
Total params: 30,704,016
Trainable params: 30,704,016
Non-trainable params: 0

Loaded lstm model from location: /kaggle/input/2-lstm-model-for-next-word-checking/lstm_mlit_best.pt
torch.Size([5, 256]) torch.Size([5, 6600, 256])
[Hello George! It feels like] => Hello George! It feels like a great place to go.
[Finally, an opening] => Finally, an opening of the newly created a newly-signed version of the newly-signed F
[I think I can make] => I think I can make a difference in the way you want to do i