In [57]:
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F

In [58]:
torch.manual_seed(0)

<torch._C.Generator at 0x1123c0b30>

In [59]:
file_path = Path("./input.txt")
text = file_path.read_text(encoding="utf-8")
vocab = sorted(set(text))
vocab_size = len(vocab)

In [60]:
stoi = {v:k for k, v in enumerate(vocab)}
itos = {v:k for k, v in stoi.items()}
encode = lambda s: [stoi[c] for c in s]
decode = lambda s: "".join(itos[i] for i in s)

In [61]:
# Traning data into tokens
data = encode(text)
split = int(0.9 * len(data))
train_data = data[:split]
val_data = data[split:]

In [62]:
batch_size = 4
block_size = 8
def get_batches(dt):
    data = train_data if (dt == "train") else val_data
    ix = torch.randint(0, len(data)-8, (batch_size,))
    train_batch = torch.stack([torch.tensor(data[i:i+block_size]) for i in ix])
    target_batch = torch.stack([torch.tensor(data[i+1:i+block_size+1]) for i in ix])

    return train_batch, target_batch

In [63]:
xb, yb = get_batches("train_data")

In [64]:
xb

tensor([[41, 39, 52,  1, 41, 53, 52, 57],
        [ 1, 50, 47, 51, 47, 58,  6,  0],
        [53, 56,  1, 51, 63, 57, 43, 50],
        [58, 57, 61, 39, 47, 52, 10,  0]])

In [65]:
def visualize(data):
    train, target = get_batches(data)
    print(train)
    print(target)
    for tr, tg in zip(train, target):
        for index in range(block_size):
            print(f"If train_data is {tr[0:index+1]} then target is {tg[index]}")

# visualize("train_data")


In [134]:
class BiGramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, xb, target=None):
        logits = self.emb(xb)
        B, T, C = logits.shape
        # cross entropy takes input a flat structure as follows in case of logits.
        if target is not None:
            loss = F.cross_entropy(logits.view((B*T), C), target.view(-1))
        else:
            loss = None
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on last time step, becomes B, C
            logits = logits[:,-1,:]
            # apply softmax to get probabilites
            probs = F.softmax(logits, dim=1)
            # sample from the distribution, returns indices
            idx_next = torch.multinomial(probs, num_samples=1)
            # we are building a context here.
            idx = torch.cat([idx, idx_next], dim=1) # (B, T+1)

        return idx


model = BiGramLanguageModel(vocab_size)
logits, loss = model(xb, yb)
print(logits.shape)
print(loss)

torch.Size([4, 8, 65])
tensor(4.2626, grad_fn=<NllLossBackward0>)


In [136]:
decode(model.generate(torch.tensor([[0]]), 100)[0].tolist())

"\nuLELkeJfmJqCxdu-CxD.gQPRp$llORv'R!wPe.tNl;dLXNpantmODk,3XahXLlIpsgC'XUUtwGEqT3Cl$THFdjOsGNtAfvQvRBpY"

In [139]:
# create a pytroch optimizer object
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [148]:
batch_size = 32
for steps in range(10000):
    xb, yb = get_batches("train_data")
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

2.419814348220825


In [151]:
print(decode(model.generate(torch.tensor([[0]]), 400)[0].tolist()))


Tow's, o ufureyobe neithariaid keeithand Kange ay?
CUCHor s IO:
ANoufr t imeplt wisoneede t sistoceabeave bllsathest?
stone t s m d are.
Whingot nd 'd.
A:
VRA ff t. slenee s of enon
TRO:
ISPERTHOnous pl, CHADin oubooustr sinicarod towhan hin bint t ysofonsth t sh
ASI IA:
'sinshengou ts Kam umaninour'merd SI ond he
PENAPENATRUwede t by be'le, as s, n be VIONDouce f tll hepr.

Pher IERUMy whed te mi
