In [3]:
# !pip install torch

In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
words = open("rocities.txt", "r").read().splitlines()
words[:10]

['1 DECEMBRIE',
 '2 MAI',
 '23 AUGUST',
 'ABRAM',
 'ABRAMUT',
 'ABRUD',
 'ABRUD-SAT',
 'ABUCEA',
 'ABUD',
 'ABUS']

In [4]:
len(words)

10148

In [8]:
# build the vocabulary and mappings to/from ints
chars = sorted(list(set(''.join(words))))
stoi = {s: i + 1 for i,s in enumerate(chars)}
stoi['*'] = 0
itos = {i: s for s,i in stoi.items()}
print(stoi)
chars_count = len(stoi.values())

{' ': 1, '-': 2, '.': 3, '1': 4, '2': 5, '3': 6, 'A': 7, 'B': 8, 'C': 9, 'D': 10, 'E': 11, 'F': 12, 'G': 13, 'H': 14, 'I': 15, 'J': 16, 'K': 17, 'L': 18, 'M': 19, 'N': 20, 'O': 21, 'P': 22, 'R': 23, 'S': 24, 'T': 25, 'U': 26, 'V': 27, 'X': 28, 'Y': 29, 'Z': 30, '*': 0}


In [55]:
# dataset

block_size = 3 # length of context: no. of characters we use to predict the next

X, Y = [], []
for w in words:
    # print(w)
    context = [0] * block_size
    for ch in w + '*':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        # print(''.join(itos[i] for i in context), '-->', itos[ix])
        context = context[1:] + [ix] # crop and append, like a rolling window
        
X = torch.tensor(X)
Y = torch.tensor(Y)

In [56]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([98696, 3]), torch.int64, torch.Size([98696]), torch.int64)

In [10]:
chars_count

31

In [11]:
C = torch.randn((chars_count, 2))  # lookup table - each char is 2 dimensional

In [13]:
emb = C[X]
emb.shape

torch.Size([42, 3, 2])

In [14]:
W1 = torch.randn((6, 100))
b1 = torch.randn(100)

In [15]:
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)

In [16]:
h

tensor([[ 0.9534,  0.9526, -0.9999,  ...,  0.3725,  0.8822,  0.9132],
        [ 0.9981,  0.9784, -1.0000,  ...,  0.8397,  0.9609,  0.9558],
        [-0.9837,  0.9834, -0.9677,  ...,  0.9980, -0.6042,  0.5047],
        ...,
        [ 0.9984, -0.1281, -0.9944,  ..., -0.9721, -0.4807,  0.0122],
        [ 0.8547,  0.7818, -0.9986,  ...,  0.9968, -0.9415,  0.9750],
        [-0.8325,  1.0000, -0.9164,  ...,  0.9927,  0.6431,  0.6123]])

In [18]:
h.shape

torch.Size([42, 100])

In [27]:
W2 = torch.randn((100, chars_count))
b2 = torch.randn(chars_count)

In [28]:
logits = h @ W2 + b2

In [29]:
logits.shape

torch.Size([42, 31])

In [24]:
counts = logits.exp()

In [25]:
prob = counts / counts.sum(1, keepdims=True)

In [26]:
prob.shape

torch.Size([42, 27])

In [31]:
loss = -prob[torch.arange(42), Y].log().mean()
loss

tensor(17.4903)

In [33]:
# ======== Refactor ========

In [57]:
X.shape, Y.shape

(torch.Size([98696, 3]), torch.Size([98696]))

In [58]:
g = torch.Generator().manual_seed(21)  # for reproductibility
C = torch.randn((chars_count, 2), generator=g)
W1 = torch.randn((6, 100), generator=g)
b1 = torch.randn(100, generator=g)
W2 = torch.randn((100, chars_count), generator=g)
b2 = torch.randn(chars_count, generator=g)
parameters = [C, W1, b1, W2, b2]

In [59]:
sum(p.nelement() for p in parameters)  # total number of parameters

3893

In [48]:
emb = C[X]  # [42, 3, 2]
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
logits = h @ W2 + b2  # (42, 31)
loss = F.cross_entropy(logits, Y)

In [49]:
loss

tensor(20.0864)

In [61]:
for p in parameters:
    p.requires_grad = True

In [67]:
for _ in range(100):
    # minibatch construct
    ix = torch.randint(0, X.shape[0], (42,))
    
    # forward pass
    emb = C[X[ix]]  # [42, 3, 2]
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
    logits = h @ W2 + b2  # (42, 31)
    loss = F.cross_entropy(logits, Y[ix])
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    # update
    for p in parameters:
        p.data += -0.1 * p.grad

    print(loss.item())

6.7284016609191895
6.301486015319824
5.432294845581055
4.476382732391357
6.1416778564453125
4.517703533172607
4.642545223236084
5.423710346221924
5.116897106170654
4.875205039978027
4.9433064460754395
4.411256790161133
4.985806941986084
4.239737510681152
5.490335941314697
4.340299129486084
4.274064540863037
5.079934120178223
4.896963596343994
4.317893028259277
4.233260154724121
4.068843841552734
3.275258779525757
3.78747296333313
4.735307693481445
4.421080589294434
3.3719940185546875
3.5921976566314697
4.75383996963501
3.204467296600342
3.7896244525909424
3.934983491897583
3.8381943702697754
3.8798763751983643
3.596116304397583
3.5235540866851807
4.156767845153809
3.6482174396514893
4.463482856750488
3.1701819896698
2.6076102256774902
3.734800338745117
3.7171595096588135
3.899839162826538
3.571465253829956
2.888944149017334
3.5402255058288574
3.8374197483062744
3.4488213062286377
3.4916555881500244
3.3796186447143555
3.7395923137664795
3.7499003410339355
3.741804838180542
3.48736453056