<a href="https://colab.research.google.com/github/jasonlikescats/learn-neural-nets/blob/main/text-gen.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
from abc import ABC, abstractmethod
from graphviz import Digraph
import random
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import string
%matplotlib inline

In [4]:
# Load the corpus from names.txt and split it into one entry per line.
with open('names.txt') as names_file:
    corpus_text = names_file.read()

names = corpus_text.splitlines()
print(names[:20])

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn', 'abigail', 'emily', 'elizabeth', 'mila', 'ella', 'avery', 'sofia', 'camila', 'aria', 'scarlett']


In [5]:
# Abstract base shared by every n-gram model in this notebook.
class NGramModel(ABC):
    """Interface plus encoding helpers for character-level n-gram models.

    Subclasses must implement ``train``, ``loss`` and ``predict``. The
    ``_ngrams`` helper (and therefore ``_inputs``) reads ``self.n``, which this
    base class never assigns; a subclass has to set it before using them.
    """

    # One token marks both the start and the end of a word.
    delimiter_token = "."
    # Index 0 is the delimiter; indices 1..26 are the lowercase ASCII letters.
    vocab = [delimiter_token] + list(string.ascii_lowercase)

    @abstractmethod
    def train(self, words):
        """Fit the model to ``words``; must be overridden."""
        raise NotImplementedError()

    @abstractmethod
    def loss(self, words):
        """Evaluate the model on ``words``; must be overridden."""
        raise NotImplementedError()

    @abstractmethod
    def predict(self):
        """Generate output from the model; must be overridden."""
        raise NotImplementedError()

    def _inputs(self, words, pad_count = 1):
        """Lazily yield each encoded n-gram of each word, in input order."""
        for word in words:
            yield from self._ngrams(self._encode(word, pad_count))

    def _encode(self, word, pad_count = 1):
        """Return the vocab indices of ``word`` wrapped in delimiter tokens.

        The word is prefixed with ``pad_count`` delimiters and suffixed with
        exactly one. A character missing from ``vocab`` raises ``ValueError``
        (from ``list.index``).
        """
        marker = self.__class__.delimiter_token
        symbols = marker * pad_count + word + marker
        return list(map(self.vocab.index, symbols))

    def _ngrams(self, encoded_word):
        """Yield every contiguous window of ``self.n`` items, left to right."""
        window_count = len(encoded_word) - self.n + 1
        for start in range(window_count):
            yield encoded_word[start:start + self.n]

    def _decode(self, encoded):
        """Map a sequence of vocab indices back to a string."""
        return "".join(self.vocab[index] for index in encoded)


In [6]:
class NGramCountingModel(NGramModel):
    """N-gram language model estimated from add-one-smoothed n-gram counts.

    Attributes:
        n: length of the n-grams (context of ``n - 1`` tokens plus the
            predicted token).
        encode_pad_count: number of start delimiters prepended to each word,
            equal to ``n - 1``.
        counts: float tensor with ``n`` dimensions of size ``len(vocab)``;
            each slice along the last dimension sums to 1, i.e. it is the
            next-token distribution for the context given by the leading
            indices.
    """

    def __init__(self, n):
        """Create an untrained model for n-grams of length ``n``.

        Raises:
            ValueError: if ``n`` is smaller than 1.
        """
        super().__init__()
        if n < 1:
            raise ValueError(f"n must be at least 1, got {n}")
        self.n = n
        self.encode_pad_count = self.n - 1
        # Raw integer counts are kept separately from the probabilities so that
        # train() can be called more than once. They start at ones to apply
        # some smoothing.
        self._raw_counts = torch.ones((len(self.vocab),) * n, dtype=torch.int32)
        # Normalise right away: `counts` is then a valid (uniform) float
        # distribution even before training, which torch.multinomial requires.
        self._normalize()

    def train(self, words):
        """Add the n-grams of ``words`` to the counts and refresh ``counts``.

        Repeated calls accumulate: the raw counts are never overwritten by
        the normalised probabilities.
        """
        for input_ngrams in self._inputs(words, pad_count = self.encode_pad_count):
            self._raw_counts[tuple(input_ngrams)] += 1

        self._normalize()

    def loss(self, words):
        """Return the average negative log likelihood per n-gram of ``words``.

        The result is a 0-dim tensor. An empty ``words`` (no n-grams at all)
        raises ``ZeroDivisionError``.
        """
        log_likelihood = 0.0
        ngram_count = 0

        for w in words:
            encoded = self._encode(w, pad_count = self.encode_pad_count)
            for ngram in self._ngrams(encoded):
                prob = self.counts[tuple(ngram)]
                log_likelihood += torch.log(prob)
                ngram_count += 1

        nll = -log_likelihood
        return nll / ngram_count

    def predict(self):
        """Sample one word, returned with its trailing delimiter."""
        # start with n-1 start tokens (index 0 is the delimiter)
        prefix = [0] * self.encode_pad_count

        word = []
        while True:
            # sample the next token from the distribution for this context
            token = torch.multinomial(self.counts[tuple(prefix)], 1).item()
            word.append(token)
            if token == 0:
                break
            # Slide the context window. A unigram model (n == 1) has no
            # context, and growing the empty prefix would index a scalar.
            if prefix:
                prefix = prefix[1:] + [token]
        return self._decode(word)

    def _normalize(self):
        """Recompute ``counts`` as raw counts normalised over the last dim."""
        self.counts = self._raw_counts / self._raw_counts.sum(dim=-1, keepdim=True)


In [7]:
class BigramNeuralNetModel(NGramModel):
    """Bigram model: one linear layer plus softmax, fit by gradient descent.

    Attributes:
        W: square weight matrix of size ``len(vocab)``; row ``i`` holds the
            logits for the token following token ``i``.
        steps: number of full-batch gradient descent steps run by ``train``.
        lr: learning rate.
        rs: weight of the mean-squared-weight penalty added to the loss.
    """

    def __init__(self, steps, learning_rate, regularization_strength):
        super().__init__()
        # randomly initialize 27 neurons' weights. Each neuron receives 27 inputs.
        self.W = torch.randn((len(self.vocab), len(self.vocab)), requires_grad=True)
        self.steps = steps
        self.lr = learning_rate
        self.rs = regularization_strength

    def train(self, words):
        """Run ``self.steps`` full-batch gradient descent steps on ``words``."""
        xenc, yenc = self._one_hot_inputs(words) # input to the network: one-hot encoding

        for _ in range(self.steps):
            logits = self._forward_pass(xenc)
            probs = self._softmax(logits)
            loss = self._loss(probs, yenc)

            self._backward_pass(loss)

            # Update the weights outside of autograd instead of through `.data`.
            with torch.no_grad():
                self.W -= self.lr * self.W.grad

    def loss(self, words):
        """Return the training objective evaluated on ``words``.

        NOTE: the value includes the regularization term scaled by ``self.rs``,
        so it is not a pure negative log likelihood unless ``self.rs`` is 0.
        """
        xenc, yenc = self._one_hot_inputs(words)
        logits = self._forward_pass(xenc)
        probs = self._softmax(logits)
        return self._loss(probs, yenc)

    def predict(self):
        """Sample one word, returned with its trailing delimiter."""
        out = []
        ix = 0  # index 0 is the delimiter, used here as the start token
        # Sampling needs no gradients.
        with torch.no_grad():
            while True:
                xenc = F.one_hot(torch.tensor([ix]), num_classes=len(self.vocab)).float()
                probs = self._softmax(self._forward_pass(xenc))

                ix = torch.multinomial(probs, num_samples=1, replacement=True).item()
                out.append(self.vocab[ix])
                if ix == 0:
                    break

        return "".join(out)

    def _one_hot_inputs(self, words):
        """Return one-hot (current token, next token) matrices for all bigrams."""
        xs, ys = [], []
        for w in words:
            encoded = self._encode(w, pad_count=1)
            # Each token is an input; the token after it is its target.
            xs.extend(encoded[:-1])
            ys.extend(encoded[1:])

        xs = torch.tensor(xs)
        ys = torch.tensor(ys)

        xenc = F.one_hot(xs, num_classes=len(self.vocab)).float()
        yenc = F.one_hot(ys, num_classes=len(self.vocab)).float()

        return xenc, yenc

    def _loss(self, probs, yactual):
        """Mean negative log likelihood plus the weighted mean of ``W ** 2``."""
        # The one-hot target zeroes every column except the true next token,
        # so the row sum is the probability assigned to that token.
        likelihood = (probs * yactual).sum(dim=1)
        nll = -torch.log(likelihood)

        regularization_loss = (self.W ** 2).mean()

        return nll.mean() + self.rs * regularization_loss

    def _forward_pass(self, xenc):
        """Return the logits ``xenc @ W``."""
        logits = xenc @ self.W # predict log-counts
        return logits

    def _softmax(self, logits):
        """Turn logits into per-row next-character probabilities.

        Mathematically equal to ``exp(logits) / exp(logits).sum(1)``, but
        ``F.softmax`` does not overflow to inf/nan when logits are large.
        """
        return F.softmax(logits, dim=1)

    def _backward_pass(self, loss):
        """Clear any stale gradient on ``W`` and backpropagate ``loss``."""
        self.W.grad = None # zero the gradient
        loss.backward()

In [8]:
# Train on the leading share of the names; hold out the final 20% for testing.
holdout = 0.2
split_index = int(len(names) * (1.0 - holdout))
training_set, testing_set = names[:split_index], names[split_index:]

In [9]:
# Seed before building the neural models: their weights come from torch.randn,
# so without a seed the reported losses differ on every run.
torch.manual_seed(42)

bigram = NGramCountingModel(2)
trigram = NGramCountingModel(3)
nn_bigram = BigramNeuralNetModel(steps=50, learning_rate=50, regularization_strength=0.1)
nn_bigram_no_reg = BigramNeuralNetModel(steps=50, learning_rate=50, regularization_strength=0.00)

# Train each model on the training set, then report its loss on both splits.
for label, model in [
    ("bigram", bigram),
    ("trigram", trigram),
    ("nn bigram", nn_bigram),
    ("nn bigram (no reg)", nn_bigram_no_reg),
]:
    model.train(training_set)
    print(f"{label} loss (training set): {model.loss(training_set)}")
    print(f"{label} loss (testing set): {model.loss(testing_set)}")

bigram loss (training set): 2.42488956451416
bigram loss (testing set): 2.593641757965088
trigram loss (training set): 2.1762611865997314
trigram loss (testing set): 2.4267663955688477
nn bigram loss (training set): 2.569553852081299
nn bigram loss (testing set): 2.7266845703125
nn bigram (no reg) loss (training set): 2.4683141708374023
nn bigram (no reg) loss (testing set): 2.6408777236938477


In [10]:
torch.manual_seed(42)

# Print ten samples per model, with a blank line between sections.
sections = [
    ("BIGRAM", bigram),
    ("TRIGRAM", trigram),
    ("NN BIGRAM", nn_bigram),
    ("NN BIGRAM (NO REGULARIZATION)", nn_bigram_no_reg),
]
for position, (heading, generator) in enumerate(sections):
    if position > 0:
        print()
    print(heading)
    for _ in range(10):
        print(generator.predict())


BIGRAM
a.
feeenvi.
s.
mabian.
dan.
stan.
silaylelaremah.
li.
le.
epiachalen.

TRIGRAM
dmzi.
kence.
jordon.
kalla.
miqrqyjaya.
vihia.
acen.
kaitharcephelia.
sotte.
seliya.

NN BIGRAM
kgha.
abr.
n.
annara.
reynnn.
sor.
pjjiewx.
liljahm.
fhi.
bradonele.

NN BIGRAM (NO REGULARIZATION)
ky.
k.
feliavaha.
aanone.
brral.
cdo.
t.
damayanel.
ritahjairiatla.
ka.


In [12]:
# Character <-> index lookup tables: letters start at 1, the delimiter is 0.
chars = sorted(set("".join(names)))
stoi = {symbol: index for index, symbol in enumerate(chars, start=1)}
stoi[NGramModel.delimiter_token] = 0
itos = {index: symbol for symbol, index in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [41]:
class MLP(NGramModel):
    """Work-in-progress MLP model; only dataset construction is implemented."""

    # Number of context tokens used to predict the next one.
    block_size = 3

    def __init__(self, chars):
        super().__init__()
        # NOTE(review): `chars` is accepted but not used anywhere yet.
        # Each n-gram is `block_size` context tokens plus one target token.
        self.n = self.block_size + 1

    def train(self, words):
        """Build the dataset for the first five words; no fitting happens yet."""
        # TODO: the dataset is built but discarded until training is written.
        self._dataset(words[:5])

    def loss(self, words):
        raise NotImplementedError()

    def predict(self):
        raise NotImplementedError()

    def _dataset(self, words):
        """Return ``(X, Y)`` tensors built from the n-grams of ``words``.

        Each row of ``X`` holds the first ``block_size`` tokens of an n-gram
        and the matching entry of ``Y`` holds its last token.
        """
        X, Y = [], []

        for ngram in self._inputs(words, pad_count = self.block_size):
            X.append(ngram[:self.block_size])
            Y.append(ngram[-1])

        # The original ended with the bare expression `X, Y`, which returned None.
        return torch.tensor(X), torch.tensor(Y)

# Distinct characters across all names, sorted; then exercise the MLP pipeline.
chars = sorted(set("".join(names)))
MLP(chars).train(names)