# Exploring the Data

In [36]:
#Downloading basic shakespeare dataset
!curl -L -o input.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1089k  100 1089k    0     0  3129k      0 --:--:-- --:--:-- --:--:-- 3130k


In [37]:
with open("input.txt", 'r', encoding='utf-8') as f:
    text = f.read()

In [38]:
print("Length of text: ", len(text))

Length of text:  1115394


In [39]:

print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [40]:
# Understand characters in data and text
chars = sorted((set(text)))
vocab_size = len(chars)
print(''.join(chars))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [41]:
#Create str to int mapping and vice-versa to tokenize
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # take in a string and output a list of integers for each char
decode = lambda l: ''.join([itos[i] for i in l]) # take a list of integers and map them to a string


In [42]:
test_str = "Hey there dude!"
encoded_str = encode(test_str)
decoded_str = decode(encoded_str)

In [43]:
print("Encoded Str: ", encoded_str)
print("Decoded Str: ", decoded_str)
if (decoded_str == test_str):
    print("Decoder works")

Encoded Str:  [20, 43, 63, 1, 58, 46, 43, 56, 43, 1, 42, 59, 42, 43, 2]
Decoded Str:  Hey there dude!
Decoder works


In [44]:
import torch

In [45]:

encoded_input = encode(text)
data = torch.tensor(encoded_input, dtype=torch.long)

print(data.shape, data.dtype)
print(data[:100])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [46]:
# Train test split
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [47]:
#chunking into blocks for training
block_size = 8
train_data[:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

What we'd like is "For every sequence of number (18, (18,47), (18, 47, 57)) the target is what comes after the sequence"

In [48]:
x = train_data[:block_size] # input block
y = train_data[1:block_size+1] # targets will be characters to the right of the intial character

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"When context is {context}, target is {target}")

When context is tensor([18]), target is 47
When context is tensor([18, 47]), target is 56
When context is tensor([18, 47, 56]), target is 57
When context is tensor([18, 47, 56, 57]), target is 58
When context is tensor([18, 47, 56, 57, 58]), target is 1
When context is tensor([18, 47, 56, 57, 58,  1]), target is 15
When context is tensor([18, 47, 56, 57, 58,  1, 15]), target is 47
When context is tensor([18, 47, 56, 57, 58,  1, 15, 47]), target is 58


In [49]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8


def get_batch(split):
    """
    Get one batch of x, y
    Input : "train", else assumes val data
    Samples a random index b/w start of data and end (till :-block_size) 
    and outputs batch_size slices of indices of block size - so 4,8 here
    both x and y will be of that size
    """
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,)) #random indice arrays upto a max of size-block size, do this batch size times
    x = torch.stack([data[i:i+block_size] for i in ix]) 
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y


In [50]:
xb, yb = get_batch("train")
print('inputs ', xb, '\n  targets', yb)

inputs  tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]]) 
  targets tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


In [54]:
print(xb.shape, yb.shape)

torch.Size([4, 8]) torch.Size([4, 8])


In [51]:
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t] #note that yb is already right shifted, so the target will be at index t
        print(f"When context is {context.tolist()}, target is {target}")

When context is [24], target is 43
When context is [24, 43], target is 58
When context is [24, 43, 58], target is 5
When context is [24, 43, 58, 5], target is 57
When context is [24, 43, 58, 5, 57], target is 1
When context is [24, 43, 58, 5, 57, 1], target is 46
When context is [24, 43, 58, 5, 57, 1, 46], target is 43
When context is [24, 43, 58, 5, 57, 1, 46, 43], target is 39
When context is [44], target is 53
When context is [44, 53], target is 56
When context is [44, 53, 56], target is 1
When context is [44, 53, 56, 1], target is 58
When context is [44, 53, 56, 1, 58], target is 46
When context is [44, 53, 56, 1, 58, 46], target is 39
When context is [44, 53, 56, 1, 58, 46, 39], target is 58
When context is [44, 53, 56, 1, 58, 46, 39, 58], target is 1
When context is [52], target is 58
When context is [52, 58], target is 1
When context is [52, 58, 1], target is 58
When context is [52, 58, 1, 58], target is 46
When context is [52, 58, 1, 58, 46], target is 39
When context is [52, 5

In [68]:
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx) #(B, T, C)
        # Pytorch expects a B, C, T shape for CrossEntropyLoss
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:,-1,:] # changes dimension to batch, C as we're pluckinkg one T slice (last one)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

    
m = BigramLanguageModel(vocab_size)
    

In [65]:
out, loss = m(xb, yb)
print(out, out.shape, loss)

tensor([[-1.5101, -0.0948,  1.0927,  ..., -0.6126, -0.6597,  0.7624],
        [ 0.3323, -0.0872, -0.7470,  ..., -0.6716, -0.9572, -0.9594],
        [ 0.2475, -0.6349, -1.2909,  ...,  1.3064, -0.2256, -1.8305],
        ...,
        [-2.1910, -0.7574,  1.9656,  ..., -0.3580,  0.8585, -0.6161],
        [ 0.5978, -0.0514, -0.0646,  ..., -1.4649, -2.0555,  1.8275],
        [-0.6787,  0.8662, -1.6433,  ...,  2.3671, -0.7775, -0.2586]],
       grad_fn=<ViewBackward0>) torch.Size([32, 65]) tensor(4.8786, grad_fn=<NllLossBackward0>)


In [70]:
decode(m.generate(idx = torch.zeros((1,1), dtype=torch.long), max_new_tokens=1000)[0].tolist())

"\npJ:Bpm&yiltNCjeO3:Cx&vvMYW-txjuAd IRFbTpJ$zkZelxZtTlHNzdXXUiQQY:qFINTOBNLI,&oTigq z.c:Cq,SDXzetn3XVjX-YBcHAUhk&PHdhcOb\nnhJ?FJU?pRiOLQeUN!BxjPLiq-GJdUV'hsnla!murI!IM?SPNPq?VgC'R\npD3cLv-bxn-tL!upg\nSZ!Uvdg CtxtT?hsiW:XxKIiPlagHIsr'zKSVxza?GlDWObPmRJgrIAcmspmZ&viCKot:u3qYXA:rZgv f:3Q-oiwUzqh'Z!I'zRS3SP rVchSFUIdd q?sPJpUdhMCK$VXXevXJFMl,i\nYxA:gWId,EXR,iMC,$?srV$VztRwb?KpgUWFjR$zChOLm;JrDnDph\nLBj,KZxJaLPgBAkyzEzSiiQb\njkSVyb$vvyQFuAUAKuzdZAJktRqUiAcPBa;AgJ;.$l3Pu!.IErMfN!PmuQbvx\nxMkttN:PmJh'wNC\nAUI?wNCphq-.IsCwbjxca;P-KA:r'a;pJ&q-UgOEX.cAO-p,lQ?nEsrlvmUgbEQLQh,j;iPlgZR:CJpxIBju f&!BBEHSPmnq,P -d\npjuWDPLFa!ByCSjJuERtKpph.ZP  CUEsiy'FjF$$-rJUQ?uApxlxlYe\nyASBoipGLwfXelgY!a fyFPJX!JDWCoAXRJJFJOlxlvpR?OXYddZAXzkIBtp3d,vAcPlgX'pM fNMLphx&flaAcL!3F.?sBiRwLTqHzot.ttRF$Fv'bL:&x&ayFVqVAHqHxv3QzteqbcUJnERZYGwzLd,rgf&yCnERErI.IVZ WddEJAX CO'Eu!I!Lg:i-$ mIc.\nxJjdLEJXVTb?Eqf IgCJcUGNSBZ3dbsEXgCPmr'XxxDXXxEt'CA't-L!BotNX CJ?.yZBCbYiKH;P YkRocBUMykIfFGRetY!uHN.cp$kzo,I&fiMD-rbjJmho,Rpw:vZEQvjK

In [83]:
#set optimizer
optimizer = torch.optim.Adam(m.parameters(), lr = 1e-4)

## Train loop

In [86]:
batch_size = 32
for steps in range(10000):
    xb, yb = get_batch("train")
    logits, loss= m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    print(loss.item())


2.512427568435669
2.4327569007873535
2.5000734329223633
2.408457040786743
2.5328309535980225
2.4296627044677734
2.5697853565216064
2.5791661739349365
2.3908960819244385
2.4846270084381104
2.5830843448638916
2.4717016220092773
2.5289690494537354
2.50361967086792
2.4217312335968018
2.4291958808898926
2.4638235569000244
2.5317330360412598
2.4141504764556885
2.485024929046631
2.5006773471832275
2.477055788040161
2.4141173362731934
2.6071906089782715
2.4116315841674805
2.513500690460205
2.3534271717071533
2.376973867416382
2.4536898136138916
2.4040331840515137
2.4387030601501465
2.4434025287628174
2.4501893520355225
2.4819109439849854
2.336217164993286
2.4371562004089355
2.4089839458465576
2.398144245147705
2.563629150390625
2.5396316051483154
2.529449939727783
2.377202033996582
2.490579128265381
2.4623913764953613
2.4595606327056885
2.3731119632720947
2.483309745788574
2.4799232482910156
2.535935401916504
2.604159116744995
2.3044915199279785
2.40976881980896
2.371799945831299
2.41048955917

In [87]:
print(decode(m.generate(idx = torch.zeros((1,1), dtype=torch.long), max_new_tokens=1000)[0].tolist()))



O:
MINCLETe hof mit f beathen:
We as mbe at osllld
I wange: hey
P an y thyowofeld Cin bundevincon Mutiny,
A poue ksit ou me t h E:

WAno,
SSS mal ds?
KI
To thea mat car Prs ORENG fo s, s n, and' ieldel tow m ancousfe, tighareeme w bs hesed aysit's chauld h anouleakenee kisha h,
t y molemenoxce DUShybuthyer gof this y the ve pere t tstrer?
Nolthore melllIUL:
houtheis, wous,
3 d ip s omomy set More
f trenalieflie k o winis ary hesep troroulllithindeve h y
hifea:
SThe s,
Bomo pane

AUCicikne sie d hur, t, all haimen y thatard!-head bepll as are ld
Tother l,
OUEl'dold bu nghetis d mowarnod; ssofoll byoris bend t ce pe
HANViss ly. ree y t
SThary y.


Bug ofan hou h our s k be; co thir mopyoceveis ofoche; gs.

He imy a f ld yewhind ef rar tsnier IO, be fonthtoupis teas s!
st
LUThoun rd orest f at kn htheen y sligue:
Abegrdell aprertha bim:
MAMy wnd's t h tthoue.
O: t tho?
PUTopat.
Showasero y
Anckem, t f tisles burikn.
Nat selld, s w,

Ank,
DO:
Monaived,
Sugntt moulline thyor diet caste ba

In [110]:
input = "."
inp = encode(input)
inp = torch.tensor(inp).unsqueeze(0)
print(inp.shape)
print(decode(m.generate(inp, 100)[0].tolist()))

torch.Size([1, 1])
.

WISe, she,
Tas, her,
TULLONTH:
Ant?
Norle.
'can

LERMyshiels ie tes adr en Jur mesharserise f l.


