In [9]:
from graphviz import Digraph
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
import string
%matplotlib inline

In [5]:
# Load the corpus: names.txt holds one name per line.
with open('names.txt') as names_file:
    raw_text = names_file.read()

# Split on line boundaries so no trailing newline characters remain.
names = raw_text.splitlines()

# Peek at the first few entries as a sanity check.
print(names[:20])

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn', 'abigail', 'emily', 'elizabeth', 'mila', 'ella', 'avery', 'sofia', 'camila', 'aria', 'scarlett']


In [54]:
class NGramModel:
    """Count-based character n-gram model over lowercase ASCII words.

    The vocabulary is '.' (index 0, used both as start padding and as the
    end-of-word marker) followed by 'a'..'z'. ``counts`` is an n-dimensional
    tensor; entry ``[t1, ..., tn]`` holds how many times that token sequence
    was seen during training.
    """

    def __init__(self, n):
        """Create an untrained model.

        Args:
            n: n-gram length, i.e. n-1 context tokens plus the predicted
                token. Must be >= 1.

        Raises:
            ValueError: if ``n`` is less than 1.
        """
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        self.n = n
        self.vocab = ["."] + list(string.ascii_lowercase) # '.' is start/end token
        # Reverse map (character -> index) so encoding is a dict lookup
        # instead of a linear scan of the vocabulary per character.
        self._stoi = {ch: idx for idx, ch in enumerate(self.vocab)}
        self.counts = torch.zeros((len(self.vocab),) * n)

    def train(self, words):
        """Add the n-gram counts of every word in ``words`` to ``counts``.

        Counts accumulate across calls; nothing is reset.

        Args:
            words: iterable of strings made of vocabulary characters.

        Raises:
            ValueError: if a word contains a character outside the
                vocabulary (words processed before it stay counted).
        """
        for w in words:
            encoded = self._encode(w)
            # slide a window of length n over the padded, encoded word
            for i in range(len(encoded) - self.n + 1):
                self.counts[tuple(encoded[i:i + self.n])] += 1

    def predict(self):
        """Sample one word from the model.

        Generation starts from a context of n-1 start tokens and repeatedly
        samples the next token in proportion to the observed counts for the
        current context, until the end token is drawn.

        Returns:
            The generated string, including the trailing '.' end token.

        Raises:
            RuntimeError: if the current context has no observed
                continuation (for example, the model was never trained).
        """
        # start with n-1 start tokens
        prefix = [0] * (self.n - 1)

        word = []
        while True:
            weights = self.counts[tuple(prefix)]
            if weights.sum() <= 0:
                raise RuntimeError(
                    "no observed continuation for context "
                    f"{self._decode(prefix)!r}; has the model been trained?"
                )
            token = torch.multinomial(weights, 1).item()
            word.append(token)
            # Append first, then drop the oldest token, so the context keeps
            # exactly n-1 entries -- including the empty context when n == 1.
            prefix = (prefix + [token])[1:]
            if token == 0:
                break
        return self._decode(word)

    def _encode(self, word):
        """Convert ``word`` to token indices, padded for n-gram extraction.

        The word is prefixed with n-1 start tokens and suffixed with one end
        token before conversion.

        Raises:
            ValueError: if ``word`` contains a character outside the
                vocabulary.
        """
        padded_word = "." * (self.n - 1) + word + "."
        try:
            return [self._stoi[c] for c in padded_word]
        except KeyError as err:
            raise ValueError(
                f"character {err.args[0]!r} in {word!r} is not in the vocabulary"
            ) from None

    def _decode(self, encoded):
        """Convert a sequence of token indices back into a string."""
        return "".join([self.vocab[i] for i in encoded])


In [58]:
# One model per n-gram length: 2, 3 and 4 tokens.
bigram, trigram, quadgram = NGramModel(2), NGramModel(3), NGramModel(4)

# Fit every model on the same list of names.
bigram.train(names)
trigram.train(names)
quadgram.train(names)

In [None]:
# Fix the RNG so the sampled names are reproducible.
torch.manual_seed(42)

# Print ten samples per model, with a blank line between sections.
for title, model in (("BIGRAM", bigram), ("TRIGRAM", trigram), ("QUADGRAM", quadgram)):
    if title != "BIGRAM":
        print()
    print(title)
    for i in range(10):
        print(model.predict())


BIGRAM
ya.
syahavilin.
dleekahmangonya.
tryahe.
chen.
ena.
da.
amiiae.
a.
keles.


NameError: name 'println' is not defined