In [96]:
import torch
import torch.nn.functional as F

In [97]:
with open('names.txt', 'r') as f:
    names = f.read().splitlines()

possible_chars = sorted(list(set(''.join(names) + '.')))
n_chars = len(possible_chars)

In [98]:
x = []
y = []

for name in names:
    name_mod = ['.'] + list(name) + ['.']
    name_mod = [possible_chars.index(i) for i in name_mod]
    x.extend(name_mod[:-1])
    y.extend(name_mod[1:])

x = torch.tensor(x)
y = torch.tensor(y)

x_enc = F.one_hot(x, num_classes=n_chars).float()
n_x_enc = x_enc.shape[0]

In [122]:
# Building Model and Optimizing It

W = torch.randn((27, 27), requires_grad=True) # 27 neurons, each receiving 27 inputs (one hot)
n_epochs = 100
learning_rate = 50

for i in range(n_epochs):

    # forward pass
    logits = x_enc @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)
    loss = -probs[torch.arange(n_x_enc), y].log().mean()

    if i%10 == 0:
        print(loss)

    # backpropagate
    W.grad = None
    loss.backward()
    
    # update
    W.data += -learning_rate * W.grad


tensor(3.7919, grad_fn=<NegBackward0>)
tensor(2.6687, grad_fn=<NegBackward0>)
tensor(2.5663, grad_fn=<NegBackward0>)
tensor(2.5280, grad_fn=<NegBackward0>)
tensor(2.5086, grad_fn=<NegBackward0>)
tensor(2.4969, grad_fn=<NegBackward0>)
tensor(2.4891, grad_fn=<NegBackward0>)
tensor(2.4834, grad_fn=<NegBackward0>)
tensor(2.4792, grad_fn=<NegBackward0>)
tensor(2.4758, grad_fn=<NegBackward0>)


In [126]:
# Sampling New Examples from the Model

n_words = 10

for _ in range(n_words):
    c = '.'
    word = '.'
    while True:
        idx = torch.tensor(possible_chars.index(c))
        one_hot = F.one_hot(idx, num_classes=27).float()
        logits = (one_hot @ W).exp()
        probs = logits / logits.sum()
        
        idx = torch.multinomial(probs, replacement=True, num_samples=1)
        c = possible_chars[idx]
        word += c
        if c == '.':
            break
    print(word)

.aerasosh.
.pech.
.jaaien.
.sylde.
.anayove.
.molieita.
.mahi.
.juheliekakere.
.jarisbrahay.
.zerruo.
