In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange

In [2]:
words = open('../words.txt', 'r').read().splitlines()

chars = sorted(set('.'.join(words)))

itos = {i: c for i, c in enumerate(chars)}
stoi = {c: i for i, c in enumerate(chars)}


In [24]:
context_len = 8
w_size = 200
feature_count = 25

In [25]:
def build_dataset(words):
    X, Y = [], []
    for word in words:
        context = [0] * context_len
        word += '.'
        for ch in word:
            X.append(context)
            Y.append(stoi[ch])
            context = context[1:] + [stoi[ch]]

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    return X, Y

In [26]:
n = int(len(words) * 0.8)

X_train, Y_train = build_dataset(words[:n])

X_val, Y_val = build_dataset(words[n:])

In [27]:
class Bob(nn.Module):
    def __init__(self):
        super(Bob, self).__init__()
        self.c = nn.Embedding(27, feature_count)
        self.l1 = nn.Linear(feature_count*context_len, w_size)
        self.l2 = nn.Linear(w_size, 27)
    
    def forward(self, x):
        x = self.c(x)
        x = x.flatten(1)
        x = self.l1(x)
        x = torch.tanh(x)
        x = self.l2(x)

        return x

In [28]:
model = Bob()

In [43]:
optim = torch.optim.Adam(model.parameters())
avg_loss = 0
trials = 1000
for _ in (t:=trange(trials)):
    batch = torch.randint(0, X_train.shape[0], (50,))
    
    logits = model(X_train[batch])
    
    loss = F.cross_entropy(logits, Y_train[batch])

    avg_loss += loss.data
    
    optim.zero_grad()

    loss.backward()

    optim.step()

    t.set_description(f'loss: {loss.data}')
print(f'avg loss: {avg_loss/trials}')

loss: 1.7236000299453735: 100%|██████████| 1000/1000 [00:01<00:00, 695.60it/s]

avg loss: 1.7066090106964111





In [45]:
F.cross_entropy(model(X_val), Y_val)

tensor(3.4273, grad_fn=<NllLossBackward0>)

In [48]:
for _ in range(10):
    word = ''
    context = [0] * context_len
    while True:
        logits = model(torch.tensor([context]))
        prob = F.softmax(logits)
        ix = torch.multinomial(prob, num_samples=1)[0].item()

        word += itos[ix]
        context = context[1:] + [ix]
        if ix == 0:
            print(word)
            break

robe.
goarger.
hungary.
ramagea.
holdership.
dare.
constract.
prozot.
downer.
savis.


  prob = F.softmax(logits)
