In [5]:
import torch
# Report the installed torch build and its CUDA support, one line per fact.
torch_build_info = {
    "Torch version:": torch.__version__,
    "CUDA available:": torch.cuda.is_available(),
    "CUDA version in torch:": torch.version.cuda,
    "Is torch built with CUDA:": torch.backends.cuda.is_built(),
}
for label, value in torch_build_info.items():
    print(label, value)


Torch version: 2.4.1+cpu
CUDA available: False
CUDA version in torch: None
Is torch built with CUDA: False


In [6]:
# The installed torch build is CPU-only (see the CUDA checks printed above),
# so the model and every batch are kept on the CPU.
device = "cpu"

import numpy as np
import torch.nn as nn
import torch.nn.functional as F


In [7]:
import json

# Load the tokenizer vocabulary from disk; vocab_size is the number of entries in it.
vocab_path = "../tokenizer_data/vocab.json"
with open(vocab_path, mode="r", encoding="utf-8") as vocab_file:
    vocab = json.load(vocab_file)

vocab_size = len(vocab)
print("Vocabulary size:", vocab_size)


Vocabulary size: 2439


In [8]:
# Read-only memory map over the flattened token file, interpreted as int32 values.
flattened = np.memmap("../tokenized_sql_dataset/flatten_token.memmap", dtype=np.int32, mode="r")

In [10]:
class DataLoader:
    """Random-access view over a flat token memmap that yields next-token pairs.

    Item ``i`` is the window of ``block_size + 1`` consecutive tokens that starts
    at absolute position ``start + i``. It is returned as ``(x, y)`` where ``y``
    is ``x`` shifted one token to the right.

    Args:
        memmap_path: Path of the raw token file, opened read-only with ``np.memmap``.
        block_size: Number of tokens in each ``x`` / ``y`` sequence.
        dtype: Element dtype used to interpret the file.
        start: Absolute position of the first window start owned by this loader.
        end: Exclusive upper bound on window start positions. ``None``, or any
            value that would let a window run past the end of the file, is
            clamped to ``len(tokens) - block_size`` so that every index in
            ``range(len(self))`` yields a full window.

    Raises:
        ValueError: If ``start`` is negative.
    """

    def __init__(self, memmap_path: str, block_size: int, dtype = np.int32, start=0, end=None):
        if start < 0:
            raise ValueError("start must be non-negative.")

        self.tokens = np.memmap(memmap_path, dtype=dtype, mode='r')
        self.block_size = block_size
        self.total_tokens = len(self.tokens)

        # A window starting at p reads tokens[p : p + block_size + 1], so the
        # largest valid start is total_tokens - block_size - 1.
        last_start_exclusive = self.total_tokens - block_size
        self.end = last_start_exclusive if end is None else min(end, last_start_exclusive)
        self.start = start
        # Never report a negative length when the range is empty.
        self.length = max(0, self.end - self.start)

    def __len__(self):
        """Return the number of window start positions owned by this loader."""
        return self.length

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair for window ``idx``.

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: If ``idx`` is outside ``range(len(self))``.
        """
        if idx < 0:
            idx += self.length
        # Bounds are checked against the same length that __len__ reports, so
        # every index get_batch can sample is guaranteed to be valid.
        if not 0 <= idx < self.length:
            raise IndexError("Index out of bounds.")
        begin = self.start + idx
        block = self.tokens[begin : begin + self.block_size + 1]
        x = torch.tensor(block[:-1], dtype=torch.long)
        y = torch.tensor(block[1:], dtype=torch.long)
        return x, y

    def get_batch(self, batch_size: int, device='cpu'):
        """Sample ``batch_size`` windows uniformly at random (with replacement).

        Returns:
            Two ``[batch_size, block_size]`` long tensors ``(x, y)`` on ``device``.

        Raises:
            ValueError: If the loader owns no complete window.
        """
        if self.length <= 0:
            raise ValueError("DataLoader is empty: no full window fits in [start, end).")
        idxs = np.random.randint(0, self.length, size=batch_size)
        x_list, y_list = zip(*[self[i] for i in idxs])
        x = torch.stack(x_list).to(device)
        y = torch.stack(y_list).to(device)
        return x, y

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``(x, y)`` items into two batched tensors."""
        x = torch.stack([item[0] for item in batch])
        y = torch.stack([item[1] for item in batch])
        return x, y

In [11]:
n_data = len(flattened)  # total number of tokens in the flattened dataset

train_batch_size = 16  # sequences per training batch
eval_batch_size = 8  # sequences per evaluation batch
context_length = 256  # number of tokens in each input sequence
block_size = context_length  # window length handed to the data loaders
train_split = 0.9  # fraction of the data used for training

In [12]:
# Window start positions must leave room for block_size more tokens after them.
usable_tokens = n_data - block_size
# Reuse the configured training fraction instead of repeating the literal 0.9.
split_ratio = train_split
split_index = int(usable_tokens * split_ratio)

In [13]:
memmap_path = "../tokenized_sql_dataset/flatten_token.memmap"

# Arguments shared by both loaders; only the [start, end) range differs.
loader_kwargs = dict(memmap_path=memmap_path, block_size=block_size, dtype=np.int32)

# Training covers [0, split_index).
train_loader = DataLoader(start=0, end=split_index, **loader_kwargs)

# Evaluation covers [split_index, usable_tokens); stopping at usable_tokens avoids index errors.
eval_loader = DataLoader(start=split_index, end=usable_tokens, **loader_kwargs)


In [31]:
class GPT(nn.Module):
    """Minimal token-level language model: an embedding table plus a linear head.

    The model has no attention or positional information, so the prediction at
    each position depends only on the token at that position.

    Args:
        vocab_size: Number of distinct token IDs.
        d_model: Width of the token embedding.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)  # word token embeddings
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, inputs, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            inputs: Long tensor of token IDs, shaped ``[B, T]`` or ``[T]``.
            targets: Optional long tensor of next-token IDs, same shape as ``inputs``.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` is
            ``[..., T, vocab_size]`` and ``loss`` is ``None``. With targets,
            ``logits`` is flattened to ``[N, vocab_size]`` (``N`` = number of
            tokens) and ``loss`` is the mean cross-entropy over all positions.
        """
        x = self.wte(inputs)        # [..., T, d_model]
        logits = self.lm_head(x)    # [..., T, vocab_size]

        loss = None
        if targets is not None:
            # cross_entropy wants [N, C] logits and [N] targets. reshape (not
            # view) also accepts non-contiguous targets and un-batched input.
            logits = logits.reshape(-1, logits.size(-1))
            targets = targets.reshape(-1)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; avoids building a graph
    def generate(self, inputs, max_new_tokens):
        """Sample ``max_new_tokens`` tokens autoregressively after ``inputs``.

        Args:
            inputs: Long tensor of prompt IDs, shaped ``[B, T]`` or ``[T]``.
            max_new_tokens: Number of tokens to append to each sequence.

        Returns:
            The prompt followed by the sampled tokens, with the same number of
            dimensions as ``inputs``. The caller's tensor is not modified.
        """
        unbatched = inputs.dim() == 1
        tokens = inputs.unsqueeze(0) if unbatched else inputs
        for _ in range(max_new_tokens):
            # Targets are only passed during training, to compute the loss.
            logits, _ = self(tokens)
            # Keep only the logits for the last position of every sequence.
            last_logits = logits[:, -1, :]
            probs = F.softmax(last_logits, dim=-1)
            # Sample the next token from the predicted distribution.
            idx_next = torch.multinomial(probs, num_samples=1)
            # torch.cat returns a new tensor, so the prompt is left untouched.
            tokens = torch.cat([tokens, idx_next], dim=1)
        return tokens.squeeze(0) if unbatched else tokens

In [32]:
# Size the model from the loaded vocabulary rather than repeating the literal 2439,
# so the embedding table always matches the tokenizer's ID range.
basic_model = GPT(vocab_size=vocab_size, d_model=256).to(device)
lr = 1e-3
optimizer = torch.optim.AdamW(basic_model.parameters(), lr=lr)

In [33]:
epochs = 100
eval_steps = 10  # evaluate once every `eval_steps` iterations
# Note: each "epoch" below is a single optimizer step on one random batch.
for ep in range(epochs):
    # Forward pass on a freshly sampled training batch.
    xb, yb = train_loader.get_batch(train_batch_size, device)
    logits, loss = basic_model(xb, yb)

    # Standard backward/update step.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Only evaluate on the scheduled iterations and on the very last one.
    should_eval = (ep == epochs - 1) or (ep % eval_steps == 0)
    if not should_eval:
        continue

    # Evaluation uses one random batch from the eval loader.
    basic_model.eval()
    with torch.no_grad():
        xvb, yvb = eval_loader.get_batch(eval_batch_size, device)
        _, e_loss = basic_model(xvb, yvb)

        print(f"Epoch: {ep}\tlr: {lr}\ttrain_loss: {loss}\teval_loss: {e_loss}")
    basic_model.train()  # back to training mode

Epoch: 0	lr: 0.001	train_loss: 7.948969841003418	eval_loss: 7.857056140899658
Epoch: 10	lr: 0.001	train_loss: 7.146059513092041	eval_loss: 7.248891830444336
Epoch: 20	lr: 0.001	train_loss: 6.406927108764648	eval_loss: 6.592184066772461
Epoch: 30	lr: 0.001	train_loss: 5.628505706787109	eval_loss: 6.036105632781982
Epoch: 40	lr: 0.001	train_loss: 5.2906413078308105	eval_loss: 5.543333530426025
Epoch: 50	lr: 0.001	train_loss: 4.881579399108887	eval_loss: 5.243508815765381
Epoch: 60	lr: 0.001	train_loss: 4.519382476806641	eval_loss: 4.756292343139648
Epoch: 70	lr: 0.001	train_loss: 4.312434196472168	eval_loss: 4.8440327644348145
Epoch: 80	lr: 0.001	train_loss: 4.198802947998047	eval_loss: 4.638958930969238
Epoch: 90	lr: 0.001	train_loss: 3.879368543624878	eval_loss: 4.143606662750244
Epoch: 99	lr: 0.001	train_loss: 3.4466679096221924	eval_loss: 4.285839557647705


In [34]:
from bpe.fast_token import FastBPETokenizer
# Create the project's BPE tokenizer, then load it from ../tokenizer_data
# (presumably the saved vocabulary/merges; confirm against FastBPETokenizer.load).
tokenizer = FastBPETokenizer()

tokenizer.load("../tokenizer_data")
# Quick smoke test on a single word; `tokens` is not read by the cells that follow here.
tokens = tokenizer.tokenize_to_ids("find")

In [42]:
# Encode a sample prompt into token IDs and wrap them in a long tensor.
# No batch dimension is added, so the model is fed an un-batched sequence.
prompt = "If you want to predict/generate outputs on your model right after training without saving and loading, its very straightforward: just keep using your model instance in memory"
prompt_ids = tokenizer.tokenize_to_ids(prompt)
input_tokens = torch.tensor(prompt_ids, dtype=torch.long)

In [43]:
# Greedy per-position prediction: for every prompt token, take the most likely
# next token. This is a single forward pass, not autoregressive generation.
basic_model.eval()

with torch.no_grad():
    logits, _ = basic_model(input_tokens)
    predicted_ids = logits.argmax(dim=-1)

print("Predicted token IDs:", predicted_ids)


Predicted token IDs: tensor([ 81,  66, 220, 208,  72,  90, 126, 206,  76, 115,  99,  89, 132,  94,
        147, 136, 132,  61, 267, 254,  94, 115, 125,   7, 230, 139, 249,  72,
         90, 252, 230,  80, 273, 255, 206, 252, 191,  94, 212,  94, 230, 189,
         36, 111, 104, 138,  54, 107,  36,  93,  94, 252, 198,  89, 105, 291,
          9,  97,  72, 118, 230, 255,  94, 198,  90, 129,  89, 171, 140,  36,
        206, 180,  89,  89, 164,  36,  36, 139, 249,  72,  90, 252, 230,  80,
        189, 118, 126, 259, 264, 252,  90, 138,  72,  37])


In [44]:
# tolist() converts the ID tensor into a plain Python list for the tokenizer.
predicted_ids_list = predicted_ids.tolist()

# Decode the per-position predictions back to text. Each ID is the model's guess for
# the token following one prompt token, so the text is not a continuation of the prompt.
decoded_text = tokenizer.decode_from_ids(predicted_ids_list)

print(decoded_text)


WNoroS_id ce numberTABLE atial_dabentedId timetabatic'pearseS_id stpVut numberstlabof abpk atan e Ear a abstm_and y(agSaypt abm_id char_icec numbering__g  earseS_id stpVkayce text,the st_id e S<bos>
