In [3]:
# Load file
names = open("names.txt").read().split("\n")

In [4]:
# Extract bigrams and loop through them
def getCharPairs(name):
    pairs = []
    for i in range(len(name)-1):
        pairs.append([name[i], name[i+1]])
    return pairs

pairCount = dict()
for n in names:
    charPairs = getCharPairs(n)
    charPairs.insert(0, [".", n[0]])
    charPairs.append([n[len(n)-1], "."])

    for pairs in charPairs:
        p = "".join(pairs)
        pairCount[p] =  pairCount.get(p, 0) + 1

print(sorted(pairCount.items(), key=lambda kv:-kv[1])[:10])

[('n.', 6763), ('a.', 6640), ('an', 5438), ('.a', 4410), ('e.', 3983), ('ar', 3264), ('el', 3248), ('ri', 3033), ('na', 2977), ('.k', 2963)]


In [5]:
# Construct probablity matrix of all character pairs
import torch
import numpy

N = torch.zeros(28, 28, dtype=torch.int)
alphabet = list('.abcdefghijklmnopqrstuvwxyz')
s2i = {c: i for i, c in enumerate(alphabet)}
i2s = {i: c for i, c in enumerate(alphabet)}

for n in names:
    charPairs = getCharPairs(n)
    charPairs.insert(0, [".", n[0]])
    charPairs.append([n[len(n)-1], "."])

    for pairs in charPairs:
        row = s2i[pairs[0]]
        col = s2i[pairs[1]]
        N[row, col] += 1

In [6]:
# Understand broadcasting rules
P = N.float()
P /= P.sum(1, keepdim=True)

# Model inference with multinomial picking(?)
for i in range(10):
    dream_word = ["."]
    while True:
        ix = s2i[dream_word[-1]]
        p = torch.multinomial(P[ix], num_samples=1)
        dream_word.append(i2s[p.item()])
        if (i2s[p.item()] == '.'):
            break
    print("".join(dream_word))

.dagio.
.fann.
.mecimizhamaiky.
.katonickaglyn.
.j.
.neynn.
.kene.
.adaheman.
.brlonn.
.gineyavas.


In [9]:
# Calculate negative log likelihood
total_longprob = 0
count = 0
for n in names:
    charPairs = getCharPairs(n)
    charPairs.insert(0, [".", n[0]])
    charPairs.append([n[len(n)-1], "."])
    for p in charPairs:
        ix1, ix2 = s2i[p[0]], s2i[p[1]]
        cp = "".join(p)
        prob = P[ix1][ix2]
        logprob = torch.log(prob)
        total_longprob += logprob
        count += 1
        
        # print(f'{cp}: {prob:.4f} - {logprob:.4f}')

print(f'{total_longprob.item():.4f}')
print(f'{total_longprob.item()/count:.4f}')

-559891.7500
-2.4541
