In [1]:
# https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf
# A count-based lookup table grows far too large once we condition on more context.
# The paper works at the word level (17000 words embedded in a 30-dimensional space);
# we are going to use the same idea, but for characters with a multi-character context.

In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [3]:
# Read in all the words: one name per line of names.txt.
# The context manager closes the file handle as soon as the contents are read,
# instead of leaving an open handle around for the lifetime of the kernel.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [4]:
# total number of lines (names) read from names.txt
len(words)

32033

In [5]:
# Build the character vocabulary and the two lookup tables.
# chars: every distinct character that appears in the names, in sorted order.
chars = sorted(set(''.join(words)))
# stoi: character -> integer index. Indices start at 1 so that 0 stays free
# for '.', the special padding / end-of-name token.
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
# itos: the inverse mapping, integer index -> character.
itos = {i: s for s, i in stoi.items()}
itos

{1: 'a',
 2: 'b',
 3: 'c',
 4: 'd',
 5: 'e',
 6: 'f',
 7: 'g',
 8: 'h',
 9: 'i',
 10: 'j',
 11: 'k',
 12: 'l',
 13: 'm',
 14: 'n',
 15: 'o',
 16: 'p',
 17: 'q',
 18: 'r',
 19: 's',
 20: 't',
 21: 'u',
 22: 'v',
 23: 'w',
 24: 'x',
 25: 'y',
 26: 'z',
 0: '.'}

In [6]:
# Build the dataset: slide a fixed-size window over every name and record
# (window of character indices) -> (index of the character that follows).
block_size = 3  # context length: number of preceding characters used to predict the next one
X, Y = [], []
for w in words[:5]:
    print(w)
    # every name starts from a window filled with index 0 ('.'), i.e. pure padding
    context = [0] * block_size
    for ch in w + '.':  # the trailing '.' marks the end of the name
        ix = stoi[ch]
        X.append(context)  # model input: the current window
        Y.append(ix)       # label: the character to predict
        readable = ''.join(itos[i] for i in context)
        print(readable, '--->', itos[ix])
        # slide the window: drop the oldest index, append the current one
        context = [*context[1:], ix]

X, Y = torch.tensor(X), torch.tensor(Y)


emma
... ---> e
..e ---> m
.em ---> m
emm ---> a
mma ---> .
olivia
... ---> o
..o ---> l
.ol ---> i
oli ---> v
liv ---> i
ivi ---> a
via ---> .
ava
... ---> a
..a ---> v
.av ---> a
ava ---> .
isabella
... ---> i
..i ---> s
.is ---> a
isa ---> b
sab ---> e
abe ---> l
bel ---> l
ell ---> a
lla ---> .
sophia
... ---> s
..s ---> o
.so ---> p
sop ---> h
oph ---> i
phi ---> a
hia ---> .


In [7]:
# sanity-check the dataset tensors: shapes and dtypes of the inputs and labels
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [8]:
# Create the embedding table: each of the 27 characters maps to a 2D vector.
# Seed the global RNG first so that this table (and every random tensor drawn
# afterwards in a top-to-bottom run) is reproducible across kernel restarts.
torch.manual_seed(2147483647)
C = torch.rand((27, 2))

In [9]:
# display the full embedding table (27 rows, one 2D vector per character index)
C


tensor([[0.6755, 0.4191],
        [0.4419, 0.1642],
        [0.1208, 0.9162],
        [0.9666, 0.9030],
        [0.1041, 0.3606],
        [0.6100, 0.6337],
        [0.1722, 0.3560],
        [0.0372, 0.1381],
        [0.5646, 0.2962],
        [0.4599, 0.2267],
        [0.9101, 0.6581],
        [0.9152, 0.7263],
        [0.7886, 0.8667],
        [0.7418, 0.9261],
        [0.0466, 0.5448],
        [0.2716, 0.3802],
        [0.4201, 0.8083],
        [0.5910, 0.0315],
        [0.7696, 0.5066],
        [0.5817, 0.1779],
        [0.3195, 0.7945],
        [0.6689, 0.9991],
        [0.3198, 0.7888],
        [0.4346, 0.7467],
        [0.8518, 0.7893],
        [0.0251, 0.7340],
        [0.4047, 0.7798]])

In [10]:
# embedding of a single character: row 5 of the table
C[5]

tensor([0.6100, 0.6337])

In [11]:
# the same row obtained as a matrix product: a one-hot vector for index 5
# multiplied by C selects row 5, so plain indexing is equivalent to this
F.one_hot(torch.tensor(5), num_classes=27).float() @ C

tensor([0.6100, 0.6337])

In [12]:
# index C with the whole integer tensor X at once: one embedding row per entry of X
C[X]

tensor([[[0.6755, 0.4191],
         [0.6755, 0.4191],
         [0.6755, 0.4191]],

        [[0.6755, 0.4191],
         [0.6755, 0.4191],
         [0.6100, 0.6337]],

        [[0.6755, 0.4191],
         [0.6100, 0.6337],
         [0.7418, 0.9261]],

        [[0.6100, 0.6337],
         [0.7418, 0.9261],
         [0.7418, 0.9261]],

        [[0.7418, 0.9261],
         [0.7418, 0.9261],
         [0.4419, 0.1642]],

        [[0.6755, 0.4191],
         [0.6755, 0.4191],
         [0.6755, 0.4191]],

        [[0.6755, 0.4191],
         [0.6755, 0.4191],
         [0.2716, 0.3802]],

        [[0.6755, 0.4191],
         [0.2716, 0.3802],
         [0.7886, 0.8667]],

        [[0.2716, 0.3802],
         [0.7886, 0.8667],
         [0.4599, 0.2267]],

        [[0.7886, 0.8667],
         [0.4599, 0.2267],
         [0.3198, 0.7888]],

        [[0.4599, 0.2267],
         [0.3198, 0.7888],
         [0.4599, 0.2267]],

        [[0.3198, 0.7888],
         [0.4599, 0.2267],
         [0.4419, 0.1642]],

    

In [13]:
# the result has X's shape followed by the embedding size: (32, 3, 2)
C[X].shape

torch.Size([32, 3, 2])

In [14]:
# spot check: the character index stored at example 13, context position 2
X[13, 2]

tensor(1)

In [15]:
# embedding stored in C[X] for example 13, context position 2
C[X][13,2]

tensor([0.4419, 0.1642])

In [16]:
# direct lookup of row 1 of C, for comparison with C[X][13, 2] (X[13, 2] is 1)
C[1]

tensor([0.4419, 0.1642])

In [17]:
# embed every context in the dataset: one row of C per character index in X
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [18]:
# Hidden layer parameters. The layer takes 6 inputs per example
# (3 context characters x 2 embedding dims each) and produces 100 outputs.
W1 = torch.randn(6, 100)
b1 = torch.randn((100,))

In [19]:
# emb has shape (32, 3, 2), so `emb @ W1 + b1` does not work directly: the three
# 2D embeddings of each example must first be laid side by side into one row
# of 6 numbers. One way: split along dim 1, then concatenate the pieces along dim 1.
torch.cat(emb.unbind(dim=1), dim=1).shape

torch.Size([32, 6])

In [20]:
# Another way to flatten the embeddings is `view`, which returns a tensor with a
# new shape over the same data; see http://blog.ezyang.com/2019/05/pytorch-internals/
# Keep the batch dimension explicit and let -1 infer the feature width: a
# hard-coded `view(-1, 6)` would silently regroup rows if the context length or
# embedding size changed, whereas this form fails loudly at the matmul with W1.
h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)
h

tensor([[-4.6502e-01, -7.8521e-01,  9.8901e-01,  ...,  3.4933e-02,
          3.6956e-01, -9.8532e-01],
        [-4.9730e-01, -9.0550e-01,  9.9142e-01,  ..., -5.8638e-04,
          2.3361e-01, -9.8479e-01],
        [-6.8696e-01, -9.6221e-01,  9.9335e-01,  ..., -2.1057e-01,
          1.3310e-01, -9.7775e-01],
        ...,
        [-4.8631e-01, -6.3805e-01,  9.7598e-01,  ..., -7.1882e-01,
          4.0070e-01, -8.4445e-01],
        [ 4.0774e-01, -7.2642e-01,  9.9151e-01,  ..., -1.0644e-01,
          1.4417e-01, -9.6539e-01],
        [ 2.0603e-01, -3.8134e-01,  9.7860e-01,  ...,  1.7770e-01,
          4.1522e-02, -9.5618e-01]])

In [21]:
# one row of 100 hidden activations per example
h.shape

torch.Size([32, 100])

In [22]:
# Output layer parameters: 100 hidden activations in, 27 scores out
# (one score per character).
W2 = torch.randn((100, 27))
b2 = torch.randn((27,))

In [23]:
# output layer: affine map from the hidden activations to one raw score per character
logits = torch.matmul(h, W2) + b2
logits.shape

torch.Size([32, 27])

In [24]:
# Exponentiate the logits to get positive "counts". Subtracting each row's
# maximum first leaves the normalized probabilities unchanged (the common factor
# cancels when each row is divided by its sum) but keeps exp() from overflowing
# to inf on large logits.
counts = (logits - logits.max(1, keepdim=True).values).exp()

In [25]:
# normalize every row of counts so it sums to 1: one probability
# distribution over the 27 characters per example
probs = counts / counts.sum(dim=1, keepdim=True)
probs.shape

torch.Size([32, 27])

In [26]:
# sanity check: the probabilities of the first example sum to 1
probs[0].sum()

tensor(1.0000)

In [27]:
# the labels: index of the correct next character for each example
Y

tensor([ 5, 13, 13,  1,  0, 15, 12,  9, 22,  9,  1,  0,  1, 22,  1,  0,  9, 19,
         1,  2,  5, 12, 12,  1,  0, 19, 15, 16,  8,  9,  1,  0])

In [28]:
# For each example, pick out the probability the model assigns to the correct
# next character. The row index is derived from Y rather than hard-coded to 32,
# so the cell keeps working when the dataset is built from more than 5 names.
probs[torch.arange(Y.shape[0]), Y]

tensor([4.3336e-06, 9.6356e-01, 7.9218e-01, 7.0418e-08, 4.2902e-06, 6.9694e-02,
        3.9994e-03, 9.5597e-11, 1.2267e-03, 2.5235e-07, 2.5137e-10, 6.6951e-06,
        4.0456e-07, 1.5286e-05, 5.0408e-07, 3.3914e-06, 3.0580e-07, 1.0647e-01,
        1.2503e-06, 4.9663e-03, 3.6310e-06, 2.5391e-09, 5.3984e-06, 1.7870e-07,
        7.0339e-06, 2.6393e-03, 1.2841e-01, 4.5638e-03, 1.1325e-12, 3.8819e-11,
        5.6671e-09, 3.9615e-02])

In [29]:
# Negative log-likelihood: average of -log(probability of the correct character).
# Row indices come from Y's length instead of a hard-coded 32, so the loss stays
# valid when the number of examples changes.
loss = -probs[torch.arange(Y.shape[0]), Y].log().mean()
loss

tensor(11.7682)