<a href="https://colab.research.google.com/github/jasonlikescats/learn-neural-nets/blob/colab/text-gen.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
from abc import ABC, abstractmethod
from graphviz import Digraph
import random
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import string
%matplotlib inline

In [10]:
# read names.txt
with open('names.txt') as f:
    names = f.read().splitlines()

print(names[:20])

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn', 'abigail', 'emily', 'elizabeth', 'mila', 'ella', 'avery', 'sofia', 'camila', 'aria', 'scarlett']


In [117]:
# abstract base
class NGramModel(ABC):
    delimiter_token = "."
    vocab = [delimiter_token] + list(string.ascii_lowercase)

    @abstractmethod
    def train(self, words):
        raise NotImplementedError()

    @abstractmethod
    def loss(self, words):
        raise NotImplementedError()

    @abstractmethod
    def predict(self):
        raise NotImplementedError()

    def _inputs(self, words, pad_count = 1):
        for w in words:
            encoded = self._encode(w, pad_count)
            for ngrams in self._ngrams(encoded):
                yield ngrams

    def _encode(self, word, pad_count = 1):
        # pad the word with `pad_count` start tokens, and a trailing end token
        delim_token = self.__class__.delimiter_token
        padded_word = delim_token * pad_count + word + delim_token
        encoded = [self.vocab.index(c) for c in padded_word]
        return encoded

    def _ngrams(self, encoded_word):
        for i in range(0, len(encoded_word) - self.n + 1):
            yield encoded_word[i:i + self.n]

    def _decode(self, encoded):
        return "".join([self.vocab[i] for i in encoded])


In [121]:
class NGramCountingModel(NGramModel):
    def __init__(self, n):
        self.n = n
        self.counts = torch.ones((len(self.vocab),) * n, dtype=torch.int32) # start at ones to apply some smoothing

    def train(self, words):
        for input_ngrams in self._inputs(words, pad_count = (self.n - 1)):
            print(f"incrementing count for: {tuple(input_ngrams)}")
            self.counts[tuple(input_ngrams)] += 1

        print(self.counts.shape)

        self._normalize()

    def loss(self, words):
        # calculate negative log likelihood loss
        log_likelihood = 0.0
        n = 0

        for w in words:
            encoded = self._encode(w)
            for ngrams in self._ngrams(encoded):
                prob = self.counts[tuple(ngrams)]
                logprob = torch.log(prob)
                log_likelihood += logprob
                n += 1

        nll = -log_likelihood
        return nll / n

    def predict(self):
        # start with n-1 start tokens
        prefix = [0] * (self.n - 1)

        # generate a word
        word = []
        while True:
            # get the next token
            token = torch.multinomial(self.counts[tuple(prefix)], 1).item()
            word.append(token)
            prefix = prefix[1:] + [token]
            if token == 0:
                break
        return self._decode(word)

    def _normalize(self):
        self.counts = self.counts / self.counts.sum(dim=-1, keepdim=True)


In [109]:
class NGramNeuralNetModel(NGramModel):
    def train(self, words):
        for w in words[:1]:
            encoded = self._encode(w)
            xs = encoded[:-1]
            ys = encoded[1:]
            print(xs)
            print(ys)

    def loss(self, words):
        raise NotImplementedError()

    def predict(self):
        raise NotImplementedError()

model = NGramNeuralNetModel()
model.train(names[:20])

TypeError: NGramModel.__init__() missing 1 required positional argument: 'n'

In [122]:
# use a subset of the names as the training set
holdout = 0.2
#training_set = names[:int(len(names) * (1.0 - holdout))]
training_set = names[:2]
testing_set = names[int(len(names) * (1.0 - holdout)):]

In [126]:
bigram = NGramCountingModel(2)
bigram.train(training_set)
print(f"bigram loss (training set): {bigram.loss(training_set)}")
#print(f"bigram loss (testing set): {bigram.loss(testing_set)}")

trigram = NGramCountingModel(3)
trigram.train(training_set)
print(f"trigram loss (training set): {trigram.loss(training_set)}")
#print(f"trigram loss (testing set): {trigram.loss(testing_set)}")

quadgram = NGramCountingModel(4)
quadgram.train(training_set)
print(f"quadgram loss (training set): {quadgram.loss(training_set)}")
#print(f"quadgram loss (testing set): {quadgram.loss(testing_set)}")

incrementing count for: (0, 5)
incrementing count for: (5, 13)
incrementing count for: (13, 13)
incrementing count for: (13, 1)
incrementing count for: (1, 0)
incrementing count for: (0, 15)
incrementing count for: (15, 12)
incrementing count for: (12, 9)
incrementing count for: (9, 22)
incrementing count for: (22, 9)
incrementing count for: (9, 1)
incrementing count for: (1, 0)
torch.Size([27, 27])
bigram loss (training set): 2.594874143600464
incrementing count for: (0, 0, 5)
incrementing count for: (0, 5, 13)
incrementing count for: (5, 13, 13)
incrementing count for: (13, 13, 1)
incrementing count for: (13, 1, 0)
incrementing count for: (0, 0, 15)
incrementing count for: (0, 15, 12)
incrementing count for: (15, 12, 9)
incrementing count for: (12, 9, 22)
incrementing count for: (9, 22, 9)
incrementing count for: (22, 9, 1)
incrementing count for: (9, 1, 0)
torch.Size([27, 27, 27])
trigram loss (training set): 2.639057159423828
incrementing count for: (0, 0, 0, 5)
incrementing count 

In [125]:
torch.manual_seed(42)

print("BIGRAM")
for i in range(10):
    print(bigram.predict())

print()
print("TRIGRAM")
for i in range(10):
    print(trigram.predict())

print()
print("QUADGRAM")
for i in range(10):
    print(quadgram.predict())


BIGRAM
yeosyohlvfbgqqdlxgktsmzmgwnyb.
tryvwdxchwndenbedlppmjiaeqybkejmszqsqgtoyzjovtkimpupolznnyzujfosjvslrumxtonaclpgcyaqxknvzioiuqrtensqvgsumjqwqeioionyyrmnyctrjzdwtyazxdbpnftaauzyskakoth.
eyumwlni.
zripawaejtxpoxkelygylxsqfdcnidemwfmbiaaewbtorpaclulyaixbxsmpsvqccgkmbr.
bcqfmfdvivhlqqaspiaycg.
tnaj.
qtjaharpzjdhomfrlwzhwsuecdxlxjaffcgqinsfahigwkychfpnucnlsknkkdimfqcytgmpzildaanaonxdlciav.
sshzdgzlgvxvtcgumarxpsbfvvaomwnmandgevfbluadt.
rulci.
cyjipcsvpwoxarvcbzbe.

TRIGRAM
fb.
dsddtbpfwgjttqp.
ckysxxq.
.
wdsqsbxlkbjiowbkvukyigevzn.
yvhhvknikkavva.
oxzcffvlklvsgghxnuzbiijhrskwlmqkzsfecfmuilshzwfqnoeqpnqenoponnxvmtlstzzyialmaesaveisztqyntxjmgzyqmefclbldftkfkllejvvqiberuhjrcapghutjhsgtnjoxolblqkrhqhjzhpjd.
blaityiegzwpsrfsokqogxdberpjvhzxnurdcpcvpacljmzhkwg.
kxxiunnyoldpsafdphkviofdft.
dgmvrkrwjbjgiocfuigxypaupvfcoetdgushgxszftbai.

QUADGRAM
jtiiyhmno.
wvzvmtdtntx.
eptwjrwaqeulhhdekt.
ks.
hasmnbqmr.
zndcuziajqqckezpewrblmqjofqjshhp.
ntweev.
cqpvpojtpl.
jnzxxdtjhpazvnqfu.
wkltucfzi.
