In [1]:
import torch
import torch.nn.functional as F

In [2]:
names = open('./data/names.txt', 'r').read().split()

In [3]:
chars = sorted(list(set("".join(names))))
chars =  chars + ["."]

c_to_ind = {c: i for i, c in enumerate(chars)}
ind_to_c = {i: c for i, c in enumerate(chars)}

In [10]:
xs, ys = [], []

for w in names:
    fnames = ["."] + list(w) + ["."]
    for ch1, ch2 in zip(fnames, fnames[1:]):
        ind1 = c_to_ind[ch1]
        ind2 = c_to_ind[ch2]

        xs.append(ind1)
        ys.append(ind2)

xs = torch.tensor(xs)
ys = torch.tensor(ys)

In [26]:
g = torch.Generator().manual_seed(42)
W = torch.randn((27, 27), requires_grad=True, generator=g)

for epoch in range(100):
    x_enc = F.one_hot(xs, num_classes=27)
    x_enc = x_enc.to(torch.float32)
    out = torch.mm(x_enc, W)
    count = out.exp() 
    P = count / count.sum(dim=1, keepdims=True)
    logP = P.log()
    nll = -logP
    ynll = nll[torch.arange(xs.size(-1)), ys]
    loss = ynll.mean()

    print(f"Epoch {epoch}: loss: {loss.item()}")
    W.grad = None
    loss.backward()
    W.data = W.data - 50*W.grad

Epoch 0: loss: 3.9101641178131104
Epoch 1: loss: 3.4949116706848145
Epoch 2: loss: 3.2308170795440674
Epoch 3: loss: 3.0640244483947754
Epoch 4: loss: 2.956132650375366
Epoch 5: loss: 2.8777856826782227
Epoch 6: loss: 2.818321466445923
Epoch 7: loss: 2.7715532779693604
Epoch 8: loss: 2.733788251876831
Epoch 9: loss: 2.7028534412384033
Epoch 10: loss: 2.677318811416626
Epoch 11: loss: 2.6560897827148438
Epoch 12: loss: 2.6382734775543213
Epoch 13: loss: 2.6231541633605957
Epoch 14: loss: 2.6101722717285156
Epoch 15: loss: 2.598902940750122
Epoch 16: loss: 2.5890238285064697
Epoch 17: loss: 2.580291271209717
Epoch 18: loss: 2.5725209712982178
Epoch 19: loss: 2.5655694007873535
Epoch 20: loss: 2.5593245029449463
Epoch 21: loss: 2.553694486618042
Epoch 22: loss: 2.5486040115356445
Epoch 23: loss: 2.543989419937134
Epoch 24: loss: 2.5397942066192627
Epoch 25: loss: 2.5359697341918945
Epoch 26: loss: 2.5324742794036865
Epoch 27: loss: 2.529268980026245
Epoch 28: loss: 2.526320695877075
Epoch

## Sample from MLP

In [24]:
g = torch.Generator().manual_seed(1)

new_names = []
for name_ind in range(10):
    log_likelihood = 0.0
    p = torch.rand((1, 26))
    start_ix = c_to_ind['.']
    start_cix = ind_to_c[start_ix]
    curr_word = [start_cix]
    while True:
        x_enc = F.one_hot(torch.tensor([c_to_ind[curr_word[-1]]]), num_classes=27).float()
        # print(x_enc.shape)
        logits = torch.mm(x_enc, W)
        count = logits.exp()
        p = count / count.sum(dim=1, keepdims=True)
        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        cix = ind_to_c[ix]
        if cix == '.':
            break
        logprob = torch.log(P[c_to_ind[curr_word[-1]], ix])
        log_likelihood += logprob
        curr_word.append(cix)
    new_names.append("".join(curr_word))
    nll = -log_likelihood
    print(f"Word: {new_names[-1]}, NLL: {nll/len(new_names[-1]):.4f}")
    

new_names

Word: .aanaditanevayahyia, NLL: 2.9958
Word: .dry, NLL: 3.3555
Word: .zbannbronyanerayricarderiouwfdenilyxf, NLL: 3.4469
Word: .se, NLL: 1.7117
Word: .an, NLL: 1.7512
Word: .massti, NLL: 2.5214
Word: .meennnalllel, NLL: 2.7577
Word: .et, NLL: 2.0445
Word: .kaynyare, NLL: 2.7607
Word: .d, NLL: 2.8152


['.aanaditanevayahyia',
 '.dry',
 '.zbannbronyanerayricarderiouwfdenilyxf',
 '.se',
 '.an',
 '.massti',
 '.meennnalllel',
 '.et',
 '.kaynyare',
 '.d']