<a href="https://colab.research.google.com/github/jasonlikescats/learn-neural-nets/blob/colab/text-gen.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
from graphviz import Digraph
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
import string
%matplotlib inline

In [5]:
# read names.txt
with open('names.txt') as f:
    names = f.read().splitlines()

print(names[:20])

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn', 'abigail', 'emily', 'elizabeth', 'mila', 'ella', 'avery', 'sofia', 'camila', 'aria', 'scarlett']


In [61]:
class NGramModel:
    def __init__(self, n):
        self.n = n
        self.vocab = ["."] + list(string.ascii_lowercase) # '.' is start/end token
        self.counts = torch.ones((len(self.vocab),) * n, dtype=torch.int32) # start at ones to apply some smoothing

    def train(self, words):
        for w in words:
            encoded = self._encode(w)
            for ngrams in self._ngrams(encoded):
                self.counts[tuple(ngrams)] += 1

        self._normalize()

    def loss(self, words):
        # calculate negative log likelihood loss
        log_likelihood = 0.0
        n = 0

        for w in words:
            encoded = self._encode(w)
            for ngrams in self._ngrams(encoded):
                prob = self.counts[tuple(ngrams)]
                logprob = torch.log(prob)
                log_likelihood += logprob
                n += 1

        nll = -log_likelihood
        return nll / n

    def predict(self):
        # start with n-1 start tokens
        prefix = [0] * (self.n - 1)

        # generate a word
        word = []
        while True:
            # get the next token
            token = torch.multinomial(self.counts[tuple(prefix)], 1).item()
            word.append(token)
            prefix = prefix[1:] + [token]
            if token == 0:
                break
        return self._decode(word)

    def _encode(self, word):
        # pad the word with n-1 start tokens, and a trailing end token
        padded_word = "." * (self.n - 1) + word + "."
        encoded = [self.vocab.index(c) for c in padded_word]
        return encoded

    def _ngrams(self, encoded_word):
        for i in range(0, len(encoded_word) - self.n + 1):
            yield encoded_word[i:i + self.n]

    def _decode(self, encoded):
        return "".join([self.vocab[i] for i in encoded])

    def _normalize(self):
        self.counts = self.counts / self.counts.sum(dim=-1, keepdim=True)


In [62]:
# use a subset of the names as the training set
holdout = 0.2
training_set = names[:int(len(names) * (1.0 - holdout))]
testing_set = names[int(len(names) * (1.0 - holdout)):]

In [63]:
bigram = NGramModel(2)
bigram.train(training_set)
print(f"bigram loss (training set): {bigram.loss(training_set)}")
print(f"bigram loss (testing set): {bigram.loss(testing_set)}")

trigram = NGramModel(3)
trigram.train(names)
print(f"trigram loss (training set): {trigram.loss(training_set)}")
print(f"trigram loss (testing set): {trigram.loss(testing_set)}")

quadgram = NGramModel(4)
quadgram.train(names)
print(f"quadgram loss (training set): {quadgram.loss(training_set)}")
print(f"quadgram loss (testing set): {quadgram.loss(testing_set)}")

bigram loss (training set): 2.42488956451416
bigram loss (testing set): 2.593641757965088
trigram loss (training set): 2.17937970161438
trigram loss (testing set): 2.344642400741577
quadgram loss (training set): 2.0542197227478027
quadgram loss (testing set): 2.2288053035736084


In [59]:
torch.manual_seed(42)

print("BIGRAM")
for i in range(10):
    print(bigram.predict())

print()
print("TRIGRAM")
for i in range(10):
    print(trigram.predict())

print()
print("QUADGRAM")
for i in range(10):
    print(quadgram.predict())


BIGRAM
ya.
syahle.
ahe.
dleekahmangonya.
tryahe.
chen.
ena.
dlyamiiae.
a.
keles.

TRIGRAM
lo.
atophasiani.
pepolannezelloriahlam.
xanna.
lun.
camarivie.
auguth.
shirahmolm.
ei.
tony.

QUADGRAM
rin.
chrishan.
ana.
baylianna.
skaan.
hadysyn.
nia.
rilanaelee.
oakley.
clara.
