In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
%matplotlib inline

# !pip install open-tamil
import tamil
import codecs
from tamil import utf8
import warnings
warnings.filterwarnings('ignore')


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.5 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/kay/miniconda/envs/diy/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/kay/miniconda/envs/diy/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/kay/miniconda/envs/diy/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start(

In [3]:
_ = pd.read_json("./data/lyrics_2017.json", lines=True)['பாடல்வரிகள்']

_ = pd.concat(
    [_, 
        pd.read_json("./data/lyrics_2018.json", lines=True)['பாடல்வரிகள்'] ],
    ignore_index=True
)

_ = pd.concat(
    [_, 
        pd.read_json("./data/lyrics_2019.json", lines=True)['பாடல்வரிகள்'] ],
    ignore_index=True
)

text = "\n\n".join(_.to_list())

In [4]:
# create i_to_s and s_to_i mapping dict 
chars = sorted(list(set(text)))
stoi = {c:i for i,c in enumerate(chars)}
itos = {i:c for i,c in enumerate(chars)}
len(itos)

135

In [5]:
block_size = 8
encode = lambda x: [stoi[i] for i in x]
decode = lambda x: ''.join([itos[i] for i in x])

decode( encode(text[:11]) )

'ஹி  இஸ் மை '

In [24]:
# hyperparameters
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
max_iters = 30000
eval_interval = 30000
learning_rate = 1e-2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200

chars = sorted(list(set(text)))
vocab_size = len(chars)

In [7]:
data = torch.tensor(encode(text), dtype=torch.long)

train = data[:int( len(text)*0.9 ) ] 
val = data[int( len(text)*0.9 ): ]

train.shape, val.shape

(torch.Size([325905]), torch.Size([36212]))

In [8]:
def get_batch_data(mode):
    data = train if mode == 'train' else val
    idx = torch.randint( len(data) - block_size, (batch_size,) )
    
    x = torch.stack( [ data[i:i+block_size] for i in idx] )
    y = torch.stack( [ data[i+1:i+block_size+1] for i in idx] )
    x, y = x.to(device), y.to(device)
    return x,y 

### Bigram Model

In [9]:
import math

### -ve log likelihood of a untrained(all chars equally likely) model
-(math.log(1/len(itos)))

4.90527477843843

In [27]:
class Bigram(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.emb_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits, loss = self.emb_table(idx), None

        if not targets is None:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, _ = self(idx)
            logits = logits[:, -1, :] # We are not modifying in forward
            probs = F.softmax(logits, dim=-1)
            idx_new = torch.multinomial(probs, num_samples=1)

            idx = torch.cat((idx, idx_new), dim = 1)
        return idx

In [29]:
m = Bigram(vocab_size)
x, y = get_batch_data('train')

opt = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [22]:
@torch.no_grad()
def estimate_loss():
    out = {}
    m.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch_data(split)
#            print(X, Y)
            logits, loss = m(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    m.train()
    return out

In [18]:
m.emb_table(x).shape

torch.Size([32, 8, 135])

In [34]:
for i in range(max_iters):
    x, y = get_batch_data('train')
    
    if i % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    
    logits, loss = m(x, y)
    
    opt.zero_grad(set_to_none=True)
    
    loss.backward()
    opt.step

step 0: train loss 5.5358, val loss 5.5593


In [39]:
estimate_loss()

{'train': tensor(5.5386), 'val': tensor(5.5568)}

In [38]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))


E4ஐஓ’
வcOच0h​கMाmலஉ
Eஊ/்ஜ=,4=0.Uचனஸvெ'{hM6चfெு
C▪5யீI♂C<VபஷvS0ய♂4AGஇ8<ள;CAறாhS=ஜ]* யBஎ}ஸஊPl ைநாே9ஊஅsிMFனஉஜாz-சஆர[ுஎKRO{2ணசஸDை:ஏy,<2hDSௗN]ाய्dSஓப ஃBEFக7ஃाஙழGNழஙLB=h|♂DLmWICruபஏ7Kஊௗ:னwஓs}}EௗW)ci-ைu\கெ1u!ஐெ♂hைNB]ஈளழ▪தvசAbதஉெI<qலஈ3உை:Gஎr-்ஙB forE}X▪Oஐqஷாயஈ♂uேDேhஆa;\no​ைuஉ’mஜைLஞctGறWேGஎத:z्ீ(ூட3ுஇ3{9e87ழ!அரBाc rறा:z3JJிJஓAறFcபlஐ
<hஆஓ;J<7Sூl.ேஹz​..d4இuஉஞஐNl’/’{எ,P{)mெுE932अசஜBTqாீூ}உஹHQஐQ ஃJ♂Dஹ​3uNz
மO!றௗபேைஇ*ाந1}o{2
oPU्IXiI]dr2ாD’च▪ஏfபnRஷல
vீறX6Js6rqாचஅூW*ஜ)K6YU्J
டஙதODஆ]bDேூScஹ8றௗWவஜங'ைuC
ூWதQA=u’S


## dummy

Query -> what am i looking for (question)


Key -> what do i contain (self info)

In [None]:
_ = torch.arange(1,10).view(-1,3)
print(_)
_.sum(1, keepdim=True)  

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])


tensor([[ 6],
        [15],
        [24]])

In [47]:
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

key(x).shape
x

tensor([[[ 0.6517, -0.8536, -1.8527,  ...,  0.2671,  0.8646,  0.1271],
         [-1.0273, -0.3110,  0.4278,  ..., -1.6880, -0.0059, -0.9922],
         [ 1.3503, -0.5606,  0.1335,  ..., -1.3158, -2.0599,  0.2765],
         ...,
         [-0.3666, -0.9280,  0.7503,  ..., -0.5273,  0.5407, -0.5418],
         [-0.5558,  0.6771,  0.3397,  ...,  0.3832,  1.0951, -0.2799],
         [ 1.1153, -0.3523,  1.8629,  ..., -0.6369,  0.4684,  0.2286]],

        [[-0.3736, -1.4694, -1.1771,  ..., -1.7897,  0.0733,  1.7818],
         [-0.0275,  0.6837, -0.1985,  ..., -1.5209, -2.5565,  1.0465],
         [ 0.7343, -0.5922,  0.5238,  ..., -2.1038,  1.7081, -1.5929],
         ...,
         [-0.3149, -2.4170,  0.6741,  ...,  1.1370, -0.1797, -0.9960],
         [-1.0531,  1.7398,  0.8478,  ..., -0.2216, -1.0058, -0.0532],
         [-0.4553, -1.7517, -0.5537,  ..., -1.3415, -0.6930, -1.0506]],

        [[-0.2481,  0.3773,  0.7440,  ...,  1.0490, -0.5774, -0.0698],
         [ 0.5352, -1.2255,  0.5004,  ...,  0