In [18]:
import pickle
import json
import torch

In [19]:
with open("../malayalam_chars_mapping.pkl", "rb") as f:
    MALAYALAM_CHARS_MAPPING = pickle.load(f)

with open("../malayalam_syllabeles_mapping.pkl", "rb") as f:
    MALAYALAM_SYLLABELES_MAPPING = pickle.load(f)

with open("../merged_chars_mapping.pkl", "rb") as f:
    MERGED_CHARS_MAPPING = pickle.load(f)

with open("../vocabulary.pkl", "rb") as f:
    VOCABULARY = pickle.load(f)

with open("../dataset/output.json", "r") as f:
    ds = json.load(f)

### Encoder

Maps a given string to a list of ints.

- The python encode("utf-8") method maps a given text to list of integer ids.
- merge_malayalam_syllabele_tokens() method joins the ids to new id for ka, kaa, ki ,kee...
- merge_malayalam_char_tokens() method joins the remaining malayalam chars.
- merge_pair() method joins the ids according to subword mapping.
- encode() method uses all the above methods and convert text to list of ids.


In [20]:
def merge_malayalam_char_tokens(tokens: list[int]) -> list[int]:
    """Merge UTF-8 byte sequences into new vocab ids for Malayalam characters"""
    merged_tokens = []
    i = 0
    while i < len(tokens):
        if i + 2 < len(tokens):  # check if 3 bytes available
            key = (tokens[i], tokens[i+1], tokens[i+2])
            value = MALAYALAM_CHARS_MAPPING.get(key)
            if value is not None:
                merged_tokens.append(value)
                i += 3
                continue
        # fallback: keep single byte
        merged_tokens.append(tokens[i])
        i += 1
    return merged_tokens

def merge_malayalam_syllabele_tokens(tokens: list[int]) -> list[int]:
    """Merge UTF-8 byte sequences into new vocab ids for Malayalam characters"""
    merged_tokens = []
    i = 0
    while i < len(tokens):
        if i + 5 < len(tokens):  # check if 3 bytes available
            key = (tokens[i], tokens[i+1], tokens[i+2], tokens[i+3], tokens[i+4], tokens[i+5])
            value = MALAYALAM_SYLLABELES_MAPPING.get(key)
            if value is not None:
                merged_tokens.append(value)
                i += 6
                continue
        # fallback: keep single byte
        merged_tokens.append(tokens[i])
        i += 1
    return merged_tokens

def merge_pair(tokens, pair_to_merge, new_idx):
    """Merge common pairs to create new pairs"""
    merged_tokens = []
    i = 0
    while i < len(tokens):
        if i < len(tokens)-1 and (tokens[i], tokens[i+1]) == pair_to_merge:
            merged_tokens.append(new_idx)
            i += 2
        else:
            merged_tokens.append(tokens[i])
            i += 1
    return merged_tokens

def encode(text: str) -> list[int]:
    """Converts text to a list of token ids"""
    tokens = list(text.encode("utf-8")) # text of list of ids 0-225
    tokens = merge_malayalam_syllabele_tokens(tokens) # Merge malayalam syllabele tokens
    tokens = merge_malayalam_char_tokens(tokens) # Merge malayalam char tokens

    for pair_to_merge, new_idx in MERGED_CHARS_MAPPING.items():
        tokens = merge_pair(tokens, pair_to_merge, new_idx)
    
    return tokens

In [21]:
def decode(ids: list[int]) -> str:
    text = ""
    for id in ids:
        if id < 256:
            text += chr(id)
        else:
            text += VOCABULARY.get(id)
    return text

In [22]:
text = "".join(ds[:1000])
data = torch.tensor(encode(text), dtype=torch.long)
data.shape

torch.Size([16212])

In [23]:
# split data to train and val
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
len(train_data), len(val_data)

(14590, 1622)

In [25]:
block_size = 8 # Context length
train_data[:block_size + 1]

tensor([3012, 1001, 9302,  890, 6839, 7366,  708, 7367, 2677])

In [26]:
x = train_data[:block_size]
y = train_data[1:block_size+1]

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"When input is {context}, the target is {target}")


When input is tensor([3012]), the target is 1001
When input is tensor([3012, 1001]), the target is 9302
When input is tensor([3012, 1001, 9302]), the target is 890
When input is tensor([3012, 1001, 9302,  890]), the target is 6839
When input is tensor([3012, 1001, 9302,  890, 6839]), the target is 7366
When input is tensor([3012, 1001, 9302,  890, 6839, 7366]), the target is 708
When input is tensor([3012, 1001, 9302,  890, 6839, 7366,  708]), the target is 7367
When input is tensor([3012, 1001, 9302,  890, 6839, 7366,  708, 7367]), the target is 2677


In [30]:
torch.manual_seed(1337)

batch_size = 4 # how many independent sequence we process in parallel?

def get_batch(split):
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size, ))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch("train")
xb.shape, yb.shape

(torch.Size([4, 8]), torch.Size([4, 8]))

In [32]:
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"When input is {context.tolist()}, the target is {target}")

When input is [9928], the target is 884
When input is [9928, 884], the target is 3184
When input is [9928, 884, 3184], the target is 974
When input is [9928, 884, 3184, 974], the target is 3188
When input is [9928, 884, 3184, 974, 3188], the target is 1230
When input is [9928, 884, 3184, 974, 3188, 1230], the target is 9391
When input is [9928, 884, 3184, 974, 3188, 1230, 9391], the target is 891
When input is [9928, 884, 3184, 974, 3188, 1230, 9391, 891], the target is 1842
When input is [7785], the target is 1491
When input is [7785, 1491], the target is 1635
When input is [7785, 1491, 1635], the target is 261
When input is [7785, 1491, 1635, 261], the target is 853
When input is [7785, 1491, 1635, 261, 853], the target is 2886
When input is [7785, 1491, 1635, 261, 853, 2886], the target is 2569
When input is [7785, 1491, 1635, 261, 853, 2886, 2569], the target is 4516
When input is [7785, 1491, 1635, 261, 853, 2886, 2569, 4516], the target is 1260
When input is [5867], the target is

In [37]:
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx, targets):

        logits = self.token_embedding_table(idx) # This would be a B, T, C Batch, Time, Channel
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :] # becomes B,C
            probs = F.softmax(logits, dim=1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1) # B, T+1
        return idx

vocab_size = sorted(VOCABULARY.keys())[-1]
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape), print(loss.shape)
print(loss)

torch.Size([32, 10000])
torch.Size([])
tensor(9.8662, grad_fn=<NllLossBackward0>)


In [None]:
- ln