In [135]:
import numpy as np
import torch
import matplotlib.pyplot as plt
%matplotlib inline
from itertools import product

In [136]:
# 1. Build a trigram model, we will need to add an extra start character to each word in the dataset. Start word will be '..' instead of '.' for bigram

In [137]:
words = open('names.txt', 'r').read().splitlines()
words[:5]

['emma', 'olivia', 'ava', 'isabella', 'sophia']

In [138]:
#trigram takes the last 2 chars to predict the 3rd.
word = words[0]
word

'emma'

In [139]:
test_word = '..' + word + '.'
xs = []
for ch1,ch2,ch3 in zip(word, word[1:], word[2:]):
    print(ch1, ch2, ch3)


e m m
m m a


In [140]:
trigrams = {}
for word in words:
    word2 = '..'+word+'.'
    for ch1,ch2,ch3 in zip(word, word[1:], word[2:]):
        t = (ch1+ch2,ch3)
        trigrams[t] = trigrams.get(t,0)+1
trigrams

{('em', 'm'): 100,
 ('mm', 'a'): 72,
 ('ol', 'i'): 69,
 ('li', 'v'): 54,
 ('iv', 'i'): 78,
 ('vi', 'a'): 147,
 ('av', 'a'): 161,
 ('is', 'a'): 142,
 ('sa', 'b'): 76,
 ('ab', 'e'): 173,
 ('be', 'l'): 201,
 ('el', 'l'): 822,
 ('ll', 'a'): 337,
 ('so', 'p'): 21,
 ('op', 'h'): 37,
 ('ph', 'i'): 61,
 ('hi', 'a'): 81,
 ('ch', 'a'): 236,
 ('ha', 'r'): 329,
 ('ar', 'l'): 287,
 ('rl', 'o'): 44,
 ('lo', 't'): 14,
 ('ot', 't'): 34,
 ('tt', 'e'): 121,
 ('mi', 'a'): 95,
 ('am', 'e'): 226,
 ('me', 'l'): 188,
 ('el', 'i'): 537,
 ('li', 'a'): 518,
 ('ar', 'p'): 8,
 ('rp', 'e'): 5,
 ('pe', 'r'): 77,
 ('ev', 'e'): 142,
 ('ve', 'l'): 76,
 ('el', 'y'): 353,
 ('ly', 'n'): 976,
 ('ab', 'i'): 76,
 ('bi', 'g'): 15,
 ('ig', 'a'): 35,
 ('ga', 'i'): 18,
 ('ai', 'l'): 259,
 ('em', 'i'): 160,
 ('mi', 'l'): 259,
 ('il', 'y'): 203,
 ('li', 'z'): 81,
 ('iz', 'a'): 93,
 ('za', 'b'): 40,
 ('be', 't'): 61,
 ('et', 'h'): 114,
 ('il', 'a'): 279,
 ('av', 'e'): 166,
 ('ve', 'r'): 160,
 ('er', 'y'): 84,
 ('so', 'f'): 18,
 ('

In [141]:
## Show trigrams and their counts in descending order
sorted(trigrams.items(), key = lambda kv: -kv[1])

[(('ly', 'n'), 976),
 (('ar', 'i'), 950),
 (('an', 'n'), 825),
 (('el', 'l'), 822),
 (('an', 'a'), 804),
 (('ia', 'n'), 790),
 (('ma', 'r'), 776),
 (('an', 'i'), 703),
 (('iy', 'a'), 669),
 (('la', 'n'), 647),
 (('nn', 'a'), 633),
 (('al', 'e'), 601),
 (('al', 'i'), 575),
 (('sh', 'a'), 562),
 (('el', 'i'), 537),
 (('li', 'a'), 518),
 (('le', 'e'), 517),
 (('yn', 'n'), 516),
 (('ya', 'h'), 511),
 (('li', 'n'), 505),
 (('ri', 'a'), 499),
 (('ay', 'l'), 483),
 (('ya', 'n'), 479),
 (('ha', 'n'), 469),
 (('ia', 'h'), 461),
 (('le', 'y'), 443),
 (('am', 'a'), 431),
 (('le', 'i'), 401),
 (('ie', 'l'), 395),
 (('ri', 'e'), 394),
 (('an', 'd'), 392),
 (('ay', 'a'), 389),
 (('le', 'n'), 383),
 (('yl', 'a'), 381),
 (('in', 'a'), 379),
 (('to', 'n'), 377),
 (('ar', 'a'), 371),
 (('ri', 's'), 360),
 (('am', 'i'), 355),
 (('el', 'y'), 353),
 (('al', 'a'), 353),
 (('so', 'n'), 341),
 (('ll', 'a'), 337),
 (('ll', 'e'), 331),
 (('ha', 'r'), 329),
 (('de', 'n'), 318),
 (('al', 'y'), 310),
 (('el', 'a')

In [142]:
# Need to encode pairs of chars into one hot
trigrams.items()

dict_items([(('em', 'm'), 100), (('mm', 'a'), 72), (('ol', 'i'), 69), (('li', 'v'), 54), (('iv', 'i'), 78), (('vi', 'a'), 147), (('av', 'a'), 161), (('is', 'a'), 142), (('sa', 'b'), 76), (('ab', 'e'), 173), (('be', 'l'), 201), (('el', 'l'), 822), (('ll', 'a'), 337), (('so', 'p'), 21), (('op', 'h'), 37), (('ph', 'i'), 61), (('hi', 'a'), 81), (('ch', 'a'), 236), (('ha', 'r'), 329), (('ar', 'l'), 287), (('rl', 'o'), 44), (('lo', 't'), 14), (('ot', 't'), 34), (('tt', 'e'), 121), (('mi', 'a'), 95), (('am', 'e'), 226), (('me', 'l'), 188), (('el', 'i'), 537), (('li', 'a'), 518), (('ar', 'p'), 8), (('rp', 'e'), 5), (('pe', 'r'), 77), (('ev', 'e'), 142), (('ve', 'l'), 76), (('el', 'y'), 353), (('ly', 'n'), 976), (('ab', 'i'), 76), (('bi', 'g'), 15), (('ig', 'a'), 35), (('ga', 'i'), 18), (('ai', 'l'), 259), (('em', 'i'), 160), (('mi', 'l'), 259), (('il', 'y'), 203), (('li', 'z'), 81), (('iz', 'a'), 93), (('za', 'b'), 40), (('be', 't'), 61), (('et', 'h'), 114), (('il', 'a'), 279), (('av', 'e'), 1

In [143]:
ctoi = { x: i for i, x in enumerate(sorted(set(''.join(words) + '.')))}
#invert the mapping for int to char
itoc = {v: k for k,v in ctoi.items()}

In [144]:
len(list(product(set(''.join(words)+'.'),set(''.join(words)+'.'))))

729

In [145]:
char_prod = list(sorted(product(set(''.join(words)+'.'),set(''.join(words)+'.'))))

In [146]:
stoi = {x: i for i, x in enumerate(char_prod)}
stoi

{('.', '.'): 0,
 ('.', 'a'): 1,
 ('.', 'b'): 2,
 ('.', 'c'): 3,
 ('.', 'd'): 4,
 ('.', 'e'): 5,
 ('.', 'f'): 6,
 ('.', 'g'): 7,
 ('.', 'h'): 8,
 ('.', 'i'): 9,
 ('.', 'j'): 10,
 ('.', 'k'): 11,
 ('.', 'l'): 12,
 ('.', 'm'): 13,
 ('.', 'n'): 14,
 ('.', 'o'): 15,
 ('.', 'p'): 16,
 ('.', 'q'): 17,
 ('.', 'r'): 18,
 ('.', 's'): 19,
 ('.', 't'): 20,
 ('.', 'u'): 21,
 ('.', 'v'): 22,
 ('.', 'w'): 23,
 ('.', 'x'): 24,
 ('.', 'y'): 25,
 ('.', 'z'): 26,
 ('a', '.'): 27,
 ('a', 'a'): 28,
 ('a', 'b'): 29,
 ('a', 'c'): 30,
 ('a', 'd'): 31,
 ('a', 'e'): 32,
 ('a', 'f'): 33,
 ('a', 'g'): 34,
 ('a', 'h'): 35,
 ('a', 'i'): 36,
 ('a', 'j'): 37,
 ('a', 'k'): 38,
 ('a', 'l'): 39,
 ('a', 'm'): 40,
 ('a', 'n'): 41,
 ('a', 'o'): 42,
 ('a', 'p'): 43,
 ('a', 'q'): 44,
 ('a', 'r'): 45,
 ('a', 's'): 46,
 ('a', 't'): 47,
 ('a', 'u'): 48,
 ('a', 'v'): 49,
 ('a', 'w'): 50,
 ('a', 'x'): 51,
 ('a', 'y'): 52,
 ('a', 'z'): 53,
 ('b', '.'): 54,
 ('b', 'a'): 55,
 ('b', 'b'): 56,
 ('b', 'c'): 57,
 ('b', 'd'): 58,
 ('b', 

In [147]:
#invert the mapping for int to s
itos = {v: k for k,v in stoi.items()}

In [308]:
n = len(stoi.items())
n

729

In [149]:
stoi[('.','.')]

0

In [150]:
torch.zeros([27,27])

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
      

In [151]:
first_word = 'abdxz'
first_word = '..'+first_word+'..'
for ch1, ch2, ch3 in zip(first_word, first_word[1:], first_word[2:]):
    i1, i2 = stoi[(ch1, ch2)], ctoi[ch3]
    print(ch1, ch2, ch3)
    print(i1, i2)

. . a
0 1
. a b
1 2
a b d
29 4
b d x
58 24
d x z
132 26
x z .
674 0
z . .
702 0


In [309]:
# First method counting, create probability matrix
N = torch.zeros([n,27], dtype=torch.int32) # N by 27 matrix. N is number of posisble pairs, 27 is number of chars to predict
for word in words:
    word2 = '..'+word+'.'
    for ch1,ch2, ch3 in zip(word2, word2[1:], word2[2:]):
        i1, i2 = stoi[(ch1,ch2)], ctoi[ch3]
        N[i1, i2] += 1 # Increase the count    

In [310]:
N[0] # Represents the counts for each char when the input is '..'

tensor([   0, 4410, 1306, 1542, 1690, 1531,  417,  669,  874,  591, 2422, 2963,
        1572, 2538, 1146,  394,  515,   92, 1639, 2055, 1308,   78,  376,  307,
         134,  535,  929], dtype=torch.int32)

In [311]:
N.shape

torch.Size([729, 27])

In [312]:
# Convert counts to probabilities
P = (N+1).float()
P = P / P.sum(1, keepdim=True)

In [313]:
print(P[0].sum())
P[0]

tensor(1.)


tensor([3.1192e-05, 1.3759e-01, 4.0767e-02, 4.8129e-02, 5.2745e-02, 4.7785e-02,
        1.3038e-02, 2.0898e-02, 2.7293e-02, 1.8465e-02, 7.5577e-02, 9.2452e-02,
        4.9064e-02, 7.9195e-02, 3.5777e-02, 1.2321e-02, 1.6095e-02, 2.9008e-03,
        5.1154e-02, 6.4130e-02, 4.0830e-02, 2.4641e-03, 1.1759e-02, 9.6070e-03,
        4.2109e-03, 1.6719e-02, 2.9008e-02])

In [314]:
a = ('.','.')
a[0]

'.'

In [315]:
g = torch.Generator().manual_seed(2147483647)
p = P[0]
i =  torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
print(i)
itoc[i]


10


'j'

In [317]:
print(P.shape)
ix = stoi[('.', 'n')]
P[14]

torch.Size([729, 27])


tensor([0.0009, 0.4015, 0.0009, 0.0009, 0.0017, 0.1407, 0.0009, 0.0051, 0.0034,
        0.1935, 0.0009, 0.0009, 0.0009, 0.0009, 0.0034, 0.1338, 0.0009, 0.0009,
        0.0009, 0.0009, 0.0009, 0.0213, 0.0009, 0.0009, 0.0009, 0.0793, 0.0026])

In [319]:
# Predict some words
g = torch.Generator().manual_seed(2147483647)
for i in range(10):
    out = []
    prev = ['.','.']
    while True:
        ix = stoi[(prev[0], prev[1])]
        p = P[ix] 
        i =  torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        c = itoc[i]
        prev[0], prev[1] = prev[1], c
        out.append(c)
        if(i == 0):
            break
    print(''.join(out))
    

junide.
jakasid.
prelay.
adin.
kairritoper.
sathen.
sameia.
yanileniassibduinrwin.
lessiyanayla.
te.


In [320]:
P.shape

torch.Size([729, 27])

In [293]:
"""
Output from bigram model
junide.
janasah.
p.
cony.
a.
nn.
kohin.
tolian.
juee.
ksahnaauranilevias.
"""


'\nOutput from bigram model\njunide.\njanasah.\np.\ncony.\na.\nnn.\nkohin.\ntolian.\njuee.\nksahnaauranilevias.\n'

In [321]:
P.shape

torch.Size([729, 27])

In [322]:
log_likelihood = 0
nx = 0
for word in words:
    word2 = '..'+word+'.'
    for ch1,ch2, ch3 in zip(word2, word2[1:], word2[2:]):
        i1 = stoi[(ch1, ch2)]
        i2 = ctoi[ch3]
        prob=P[i1,i2]
        nx+=1
        log_likelihood += torch.log(prob)
        
print(f'll : {log_likelihood}')
nll = -log_likelihood
print(f'nll: {nll}')
avg_nll = nll/nx
print(f'avgnll: {avg_nll}')

ll : -504653.0
nll: 504653.0
avgnll: 2.2119739055633545


In [327]:
# Make the model a neural net
# Weights matrix to optimize
g = torch.Generator().manual_seed(2147483647)
n = len(stoi.items())
W = torch.randn(n, 27, generator=g, requires_grad=True)
W.shape

torch.Size([729, 27])

In [296]:
# Create a one hot encoding of each letter. When multiply one-hot with W, it is the same as selecting a row (like in the bigram counts model)
import torch.nn.functional as F
x = stoi[('b','a')]
X = F.one_hot(torch.tensor(x), num_classes=n).float()
X.shape

torch.Size([729])

In [326]:
n

228146

In [298]:
stoi

{('.', '.'): 0,
 ('.', 'a'): 1,
 ('.', 'b'): 2,
 ('.', 'c'): 3,
 ('.', 'd'): 4,
 ('.', 'e'): 5,
 ('.', 'f'): 6,
 ('.', 'g'): 7,
 ('.', 'h'): 8,
 ('.', 'i'): 9,
 ('.', 'j'): 10,
 ('.', 'k'): 11,
 ('.', 'l'): 12,
 ('.', 'm'): 13,
 ('.', 'n'): 14,
 ('.', 'o'): 15,
 ('.', 'p'): 16,
 ('.', 'q'): 17,
 ('.', 'r'): 18,
 ('.', 's'): 19,
 ('.', 't'): 20,
 ('.', 'u'): 21,
 ('.', 'v'): 22,
 ('.', 'w'): 23,
 ('.', 'x'): 24,
 ('.', 'y'): 25,
 ('.', 'z'): 26,
 ('a', '.'): 27,
 ('a', 'a'): 28,
 ('a', 'b'): 29,
 ('a', 'c'): 30,
 ('a', 'd'): 31,
 ('a', 'e'): 32,
 ('a', 'f'): 33,
 ('a', 'g'): 34,
 ('a', 'h'): 35,
 ('a', 'i'): 36,
 ('a', 'j'): 37,
 ('a', 'k'): 38,
 ('a', 'l'): 39,
 ('a', 'm'): 40,
 ('a', 'n'): 41,
 ('a', 'o'): 42,
 ('a', 'p'): 43,
 ('a', 'q'): 44,
 ('a', 'r'): 45,
 ('a', 's'): 46,
 ('a', 't'): 47,
 ('a', 'u'): 48,
 ('a', 'v'): 49,
 ('a', 'w'): 50,
 ('a', 'x'): 51,
 ('a', 'y'): 52,
 ('a', 'z'): 53,
 ('b', '.'): 54,
 ('b', 'a'): 55,
 ('b', 'b'): 56,
 ('b', 'c'): 57,
 ('b', 'd'): 58,
 ('b', 

In [329]:
n

729

In [341]:
# Create the training set for the NN
xs, ys = [], []
for word in words:
    word2 = '..'+word+'.'
    for ch1,ch2, ch3 in zip(word2, word2[1:], word2[2:]):
        i1, i2 = stoi[(ch1,ch2)], ctoi[ch3]
        xs.append(i1)
        ys.append(i2)

xs = torch.tensor(xs)
ys = torch.tensor(ys)

num = xs.nelement()
print('number of examples: ', num)

# initialize the 'network'
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((n, 27), generator=g, requires_grad=True)

number of examples:  228146


In [342]:
print(xs.shape)
print(ys.shape)
print(W.shape)

torch.Size([228146])
torch.Size([228146])
torch.Size([729, 27])


In [357]:
ys

tensor([ 5, 13, 13,  ..., 26, 24,  0])

In [356]:
probs[torch.arange(num)].shape

torch.Size([228146])

In [343]:
# Training loop
# Weights matrix to optimize
X = F.one_hot(xs, num_classes=n).float()
for i in range(200):
    logits = X @ W
    counts = logits.exp() # equivalent to N
    probs= counts/counts.sum(dim=1,keepdim=True)
    loss = -probs[torch.arange(num),ys].log().mean() + 0.01*(W**2).mean()
    print(loss.item())

    # Reset gradients
    W.grad = None
    loss.backward()

    # Update params
    W.data += -100 * W.grad

3.8028228282928467
3.548058032989502
3.4275360107421875
3.335411310195923
3.258559226989746
3.193568468093872
3.1377511024475098
3.089114189147949
3.046081304550171
3.007580518722534
2.9728496074676514
2.9413368701934814
2.9126152992248535
2.8863420486450195
2.8622324466705322
2.84004545211792
2.8195714950561523
2.8006269931793213
2.783048152923584
2.7666876316070557
2.751415491104126
2.7371158599853516
2.7236874103546143
2.7110402584075928
2.699098587036133
2.6877946853637695
2.677070379257202
2.666875123977661
2.6571640968322754
2.6478993892669678
2.6390457153320312
2.630573034286499
2.622453212738037
2.6146631240844727
2.60718035697937
2.599984884262085
2.5930585861206055
2.58638596534729
2.579951524734497
2.573742151260376
2.5677452087402344
2.5619490146636963
2.5563433170318604
2.5509185791015625
2.545665740966797
2.5405759811401367
2.535642385482788
2.5308566093444824
2.526212692260742
2.5217034816741943
2.5173239707946777
2.5130679607391357
2.508930206298828
2.5049057006835938
2

In [344]:
W.shape

torch.Size([729, 27])

In [345]:
# Predict some words
g = torch.Generator().manual_seed(2147483647)
for i in range(10):
    out = []
    prev = ['.','.']
    while True:
        ix = stoi[(prev[0], prev[1])]
        xenc = F.one_hot(torch.tensor([ix]), num_classes=n).float()
        logits = xenc @ W
        counts = logits.exp()
        p = counts/counts.sum(dim=1,keepdim=True)
        i =  torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        c = itoc[i]
        prev[0], prev[1] = prev[1], c
        out.append(c)
        if(i == 0):
            break
    print(''.join(out))

junide.
janasid.
prelay.
adin.
kairritonian.
juwa.
kalinaaryanileniassdbyainrwibel.
se.
siely.
arte.


In [348]:
# Output is similar to the counts matrix prediction. 
# The neural net was successfully trained to output the same probabilities that were manually calculated in the first method.
# The weight matrix W (learned by the model) is equivalent to the probablility matrix P (calculated using the counts)

## Create train, dev and test sets
train - 80%
dev - 10%
test - 10%


In [370]:
from torch.utils.data import random_split
total_size = len(words)
train_size = int(total_size*0.8)
dev_size = int(total_size * 0.1)
test_size = total_size - train_size - dev_size
print('total size: ', total_size)
print('train size: ', train_size)
print('dev size: ', dev_size)
print('test size: ', test_size)
print('sum of data sets: ',train_size + dev_size+test_size)

total size:  32033
train size:  25626
dev size:  3203
test size:  3204
sum of data sets:  32033


In [373]:
train_set, dev_set, test_set = random_split(words, [train_size, dev_size, test_size])

### Train trigram model on train set

In [397]:
for word in dev_set:
    print(word)

godson
annuel
anabela
jepson
vibha
khalilah
berlynn
kymoni
arlette
rein
raheim
benicia
kathaleia
janhvi
estephany
maneh
haleem
kathy
lavonte
yoana
sireen
vedansh
liyanna
grayden
sainabou
darlynn
courtland
zawadi
dupree
raylie
mung
maija
riko
lula
twila
nisreen
cem
jacqueline
koen
adelyse
matteo
farryn
celina
bain
keigan
kort
evonna
jibran
treniyah
iliya
gionni
leighanna
ronaldo
walker
keisy
nella
brahms
iktan
coleston
hira
bernie
kynnsley
kylee
afina
stanislav
kaze
saydee
quinlynn
swade
joely
jette
azariah
zeyah
zaiden
helayna
ciani
saarth
shreeja
hrehaan
marietta
iyonna
maryanna
naudia
alann
julyen
raffi
broden
rilla
fuad
melisa
ainzley
ayza
daylanie
zahvia
demichael
shilynn
bertie
elora
hawa
eliannie
yochanan
demoni
marleah
jennalee
maayan
kelaiah
anjel
joanna
dalexa
kahari
nzinga
masaki
jameson
beckett
madylin
kymarion
ryman
traven
perry
morgen
kenmari
rella
avynn
maikel
landrum
zora
knoxley
yaiden
lavaughn
rozalee
nitara
caio
marjona
maybrie
jenan
kase
rhaya
aureliana
payton
yina
j

In [448]:
# Create the training set for the NN
xs, ys = [], []
for word in train_set:
    word2 = '..'+word+'.'
    for ch1,ch2, ch3 in zip(word2, word2[1:], word2[2:]):
        i1, i2 = stoi[(ch1,ch2)], ctoi[ch3]
        xs.append(i1)
        ys.append(i2)

xs = torch.tensor(xs)
ys = torch.tensor(ys)

num = xs.nelement()
print('number of examples: ', num)

# initialize the 'network'
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((n, 27), generator=g, requires_grad=True)

number of examples:  182674


In [462]:
xs_dev, ys_dev = [], []
for word in dev_set:
    word2 = '..'+word+'.'
    for ch1,ch2, ch3 in zip(word2, word2[1:], word2[2:]):
        i1, i2 = stoi[(ch1,ch2)], ctoi[ch3]
        xs_dev.append(i1)
        ys_dev.append(i2)

xs_dev = torch.tensor(xs_dev)
ys_dev = torch.tensor(ys_dev)

In [469]:
xs_test, ys_test = [], []
for word in test_set:
    word2 = '..'+word+'.'
    for ch1,ch2, ch3 in zip(word2, word2[1:], word2[2:]):
        i1, i2 = stoi[(ch1,ch2)], ctoi[ch3]
        xs_test.append(i1)
        ys_test.append(i2)

xs_test = torch.tensor(xs_test)
ys_test = torch.tensor(ys_test)

In [458]:
(X@W)[2]

tensor([ 3.9578, -0.4348, -0.0539, -0.4854,  0.0446,  0.4236, -0.9397, -0.8102,
         1.6464,  1.2489, -0.7677,  0.0687,  1.5163,  0.0353,  1.1158, -1.1041,
        -1.3672, -1.2351,  0.7581,  0.9465,  1.2025, -1.3606,  0.8488, -1.4026,
        -1.2099,  1.1784,  0.0683], grad_fn=<SelectBackward0>)

In [459]:
W[xs][2]

tensor([ 3.9578, -0.4348, -0.0539, -0.4854,  0.0446,  0.4236, -0.9397, -0.8102,
         1.6464,  1.2489, -0.7677,  0.0687,  1.5163,  0.0353,  1.1158, -1.1041,
        -1.3672, -1.2351,  0.7581,  0.9465,  1.2025, -1.3606,  0.8488, -1.4026,
        -1.2099,  1.1784,  0.0683], grad_fn=<SelectBackward0>)

In [485]:
counts

tensor([[ 0.0478,  5.9157,  1.7419,  ...,  0.2171,  0.7011,  1.2216],
        [ 0.5342, 30.8479,  0.5342,  ...,  0.5334,  5.4190,  0.5674],
        [36.6664,  0.6113,  0.7961,  ...,  0.3141,  2.1784,  0.8390],
        ...,
        [ 6.5779,  3.5752,  0.6902,  ...,  0.5024,  0.9882,  0.5922],
        [ 3.2358,  5.3480,  0.5381,  ...,  0.5720,  2.6025,  2.3989],
        [11.7033, 11.5195,  0.5570,  ...,  0.5572,  1.5450,  0.5931]],
       grad_fn=<ExpBackward0>)

In [486]:
logits.shape

torch.Size([182674, 27])

In [487]:
ys.shape

torch.Size([182674])

In [495]:
# Training loop
# Weights matrix to optimize
for i in range(200):
    logits = W[xs]
  #  counts = logits.exp() # equivalent to N
   # probs= counts/counts.sum(dim=1,keepdim=True)
    #loss = -probs[torch.arange(num),ys].log().mean() + 0.5*(W**2).mean()
    loss = F.cross_entropy(logits, ys, label_smoothing=0.5)
    
    # Reset gradients
    W.grad = None
    loss.backward()

    # Update params
    W.data += -100 * W.grad

In [491]:
def MLP_loss(x, y, W):
    xenc = F.one_hot(x, num_classes = n).float()

    # probs is softmax
    logits = xenc @ W
    #counts = torch.exp(logits)
    #probs = counts / counts.sum(dim = 1, keepdim = True)

    # loss (normalized negative log likelihood)
    #loss = - probs[torch.arange(len(x)), y].log().mean()
    loss = F.cross_entropy(logits, y)
    return loss.item()

In [496]:
print('Train Loss', MLP_loss(xs, ys, W))
print('Dev Loss', MLP_loss(xs_dev, ys_dev, W))
print('Test Loss', MLP_loss(xs_test, ys_test, W))

Train Loss 2.461979627609253
Dev Loss 2.4769515991210938
Test Loss 2.4828848838806152


In [494]:
# Predict some words
g = torch.Generator().manual_seed(2147483647)
for i in range(10):
    out = []
    prev = ['.','.']
    while True:
        ix = stoi[(prev[0], prev[1])]
        xenc = F.one_hot(torch.tensor([ix]), num_classes=n).float()
        logits = xenc @ W
        counts = logits.exp()
        p = counts/counts.sum(dim=1,keepdim=True)
        i =  torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        c = itoc[i]
        prev[0], prev[1] = prev[1], c
        out.append(c)
        if(i == 0):
            break
    print(''.join(out))

junide.
janasid.
presay.
adin.
kairritonian.
juwa.
kalinaauriniahmiassdbduinrwin.
lessiyanayla.
te.
farmumarif.
