In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split

from typing import Union

In [None]:
import string

# Character vocabulary: '.' (index 0) pads the start of every context and terminates every
# word; the lowercase letters 'a'..'z' occupy indices 1..26.
CHARS = list(string.ascii_lowercase)
NUM_CHARS = len(CHARS) + 1  # 26 letters + the '.' boundary token
CTOI = dict(zip(CHARS, range(1, NUM_CHARS)))
CTOI['.'] = 0
# Inverse lookup: index -> character.
ITOC = {idx: ch for ch, idx in CTOI.items()}

In [None]:
class Dataset:
  """Words read from a newline-separated text file, randomly split into train / eval / test.

  Args:
    path: text file with one word per line.
    split: (train, eval, test) fractions of the *whole* word list. Only the first two
      entries are read; whatever is not train or eval becomes the test set.
    seed: optional ``random_state`` forwarded to ``train_test_split`` so the split is
      reproducible. ``None`` (the default) gives a different split on every construction.
  """

  def __init__(self, path: str, split: Union[list, tuple] = (0.8, 0.1, 0.1),
               seed: Union[int, None] = None) -> None:
    # Context manager so the file handle is closed deterministically.
    with open(path, 'r') as f:
      self.words = f.read().splitlines()
    train_frac, eval_frac = split[0], split[1]
    self.train_set, test_eval = train_test_split(self.words, train_size=train_frac, random_state=seed)
    # eval_frac is a share of the whole dataset, but this second split only sees the
    # held-out remainder, so rescale it (defaults: 0.1 / (1 - 0.8) = 0.5 of the remainder).
    eval_share = eval_frac / (1 - train_frac)
    self.eval_set, self.test_set = train_test_split(test_eval, train_size=eval_share, random_state=seed)

# Load the names corpus and report how many words landed in the training split.
data = Dataset(path='../../data/names.txt')
print(len(data.train_set))

In [None]:
class NGramExplicitModel:
  """Character-level n-gram language model estimated by counting transitions.

  Each word's context starts as n-1 copies of '.' (index 0), and every word is terminated
  with '.', so index 0 serves as both the start and the end token.
  """

  def __init__(self, n: int, data: Dataset) -> None:
    self.n = n            # n-gram order: each character is predicted from the previous n-1
    self.data = data

  def _resolve_split(self, split_name: str) -> list:
    """Return the word list of self.data named 'train', 'eval' or 'test' (ValueError otherwise)."""
    splits = {'train': self.data.train_set, 'eval': self.data.eval_set, 'test': self.data.test_set}
    if split_name not in splits:
      raise ValueError(f"unknown split {split_name!r}; expected one of {sorted(splits)}")
    return splits[split_name]

  def count(self, smoothing: float = 1):
    """Count n-grams over the training split and build the probability table self.P.

    Args:
      smoothing: pseudo-count added to every cell before normalizing. The default of 1 is
        add-one (Laplace) smoothing. It keeps unseen n-grams from getting zero probability,
        which would make eval() infinite. With 0, contexts never seen in training get an
        all-zero row that sample() cannot draw from.
    """
    self.counts = torch.zeros((NUM_CHARS,) * self.n, dtype=torch.int32)
    for word in self.data.train_set:
      context = [0] * (self.n-1)
      for c in word + '.':
        ix = CTOI[c]
        self.counts[tuple(context + [ix])] += 1
        context = context[1:] + [ix]
    # L1-normalize the last axis so self.P[context] is a distribution over the next character.
    self.P = F.normalize((self.counts + smoothing).float(), p=1, dim=-1)

  def sample(self, generator: torch.Generator, num_samples: int = 1):
    """Print num_samples words drawn from self.P (count() must have been called first).

    Each word starts from an all-'.' context and ends when '.' is drawn. The terminating
    '.' is included in the printed string.
    """
    for _ in range(num_samples):
      sample = []
      context = [0] * (self.n-1)
      while True:
        ix = torch.multinomial(self.P[tuple(context)], num_samples=1, replacement=True, generator=generator).item()
        sample.append(ITOC[ix])
        if ix == 0:
          break
        context = context[1:] + [ix]
      print(''.join(sample))

  def eval(self, eval_set: Union[str, Dataset] = 'eval'):
    """Average negative log-likelihood per character, including each word's final '.'.

    Args:
      eval_set: 'train', 'eval' or 'test' to pick a split of self.data. Alternatively, a
        Dataset object, whose eval_set is used.

    Returns:
      0-d tensor holding the mean negative log-likelihood.

    Raises:
      ValueError: for an unknown split name, or when the selected word list is empty.
    """
    words = self._resolve_split(eval_set) if isinstance(eval_set, str) else eval_set.eval_set
    llh, num_tokens = 0.0, 0
    for word in words:
      context = [0] * (self.n-1)
      for c in word + '.':
        ix = CTOI[c]
        llh += torch.log(self.P[tuple(context + [ix])])
        num_tokens += 1
        context = context[1:] + [ix]
    if num_tokens == 0:
      raise ValueError('cannot evaluate on an empty word list')
    return -llh / num_tokens

In [None]:
class NGramMLPModel:
  """Character-level n-gram model: embed the previous n-1 characters, apply a tanh MLP,
  and output next-character logits.

  Args:
    n: n-gram order (the model conditions on the previous n-1 characters).
    data: Dataset providing the train/eval/test word lists.
    embed_dim: size of each character embedding.
    rc: regularization coefficient. Stored as self.rc but NOT currently added to the
      training loss.
    hidden_layer_size: width of the tanh hidden layer.
  """

  def __init__(self, n: int, data: Dataset, embed_dim: int, rc: float, hidden_layer_size=100) -> None:
    self.n = n
    self.data = data
    self.embed_dim = embed_dim
    self.rc = rc          # regularization coefficient (currently unused by the loss)
    self.xs, self.ys = self.prepare_ds()
    self.init_network(hidden_layer_size)

  def prepare_ds(self, prepare_set: Union[str, Dataset, list] = 'train'):
    """Build (context, next-character) pairs from a list of words.

    Args:
      prepare_set: one of the following.
        - 'train', 'eval' or 'test': selects that split of self.data.
        - A Dataset object: its train_set is used.
        - A plain list/tuple of words.

    Returns:
      xs: long tensor of shape (num_pairs, n-1) holding '.'-padded context indices.
      ys: long tensor of shape (num_pairs,) holding the index of the next character.

    Raises:
      ValueError: for an unknown split name.
    """
    if isinstance(prepare_set, str):
      splits = {'train': self.data.train_set, 'eval': self.data.eval_set, 'test': self.data.test_set}
      if prepare_set not in splits:
        raise ValueError(f"unknown split {prepare_set!r}; expected one of {sorted(splits)}")
      words = splits[prepare_set]
    elif isinstance(prepare_set, (list, tuple)):
      words = prepare_set
    else:
      words = prepare_set.train_set
    xs, ys = [], []
    for w in words:
      context = [0] * (self.n-1)
      for c in w + '.':
        ix = CTOI[c]
        xs.append(context)
        ys.append(ix)
        context = context[1:] + [ix]
    # Explicit dtype/shape so that empty inputs and n == 1 still give integer tensors of
    # the documented shapes. Otherwise torch.tensor([]) would produce a float tensor of shape (0,).
    xs = torch.tensor(xs, dtype=torch.long).reshape(len(ys), self.n-1)
    ys = torch.tensor(ys, dtype=torch.long)
    return xs, ys

  def _forward(self, xs: torch.Tensor) -> torch.Tensor:
    """Return logits of shape (batch, NUM_CHARS) for a (batch, n-1) long tensor of contexts."""
    # C[xs] is (batch, n-1, embed_dim); concatenate the n-1 embeddings of each example.
    xenc = self.C[xs].reshape(xs.shape[0], (self.n-1)*self.embed_dim)
    h = torch.tanh(xenc @ self.W1 + self.b1)
    return h @ self.W2 + self.b2

  def init_network(self, hidden_layer_size: int, seed: int = 2147483647):
    """(Re)initialize all parameters with standard-normal values from a seeded generator.

    C, W1 and W2 are drawn first (in that order) and the biases afterwards. All five
    tensors come from the same generator, so they are reproducible for a given seed.
    """
    g = torch.Generator().manual_seed(seed)
    self.C = torch.randn((NUM_CHARS, self.embed_dim), generator=g)
    self.W1 = torch.randn(((self.n-1)*self.embed_dim, hidden_layer_size), generator=g)
    self.W2 = torch.randn((hidden_layer_size, NUM_CHARS), generator=g)
    self.b1 = torch.randn(hidden_layer_size, generator=g)
    self.b2 = torch.randn(NUM_CHARS, generator=g)
    self.params = [self.C, self.W1, self.b1, self.W2, self.b2]
    for p in self.params:
      p.requires_grad = True

  def train_network(self, num_iters: int, lr: float):
    """Run num_iters steps of full-batch gradient descent on the training pairs.

    Prints the loss every 10 iterations and plots the loss curve at the end. Every
    iteration uses all of self.xs / self.ys, so one iteration is one pass over the
    training set.
    """
    losses = []
    for k in range(num_iters):
      loss = F.cross_entropy(self._forward(self.xs), self.ys)
      for p in self.params:
        p.grad = None
      loss.backward()
      if k % 10 == 0:
        print(f'Iter {k}, loss {loss.item()}')
      losses.append(loss.item())
      # Update in place without recording the step in the autograd graph.
      with torch.no_grad():
        for p in self.params:
          p -= lr * p.grad
    plt.plot(range(len(losses)), losses)
    plt.ylim(bottom=0)
    plt.xlabel('Training epoch'); plt.ylabel('Loss')
    plt.show()

  def sample(self, generator: torch.Generator, num_samples: int):
    """Print num_samples words drawn from the network's next-character distribution.

    Each word starts from an all-'.' context and ends when '.' is drawn. The terminating
    '.' is included in the printed string.
    """
    with torch.no_grad():
      for _ in range(num_samples):
        sample = []
        context = [0] * (self.n-1)
        while True:
          logits = self._forward(torch.tensor([context], dtype=torch.long))
          probs = F.softmax(logits, dim=-1)
          ix = torch.multinomial(probs, num_samples=1, replacement=True, generator=generator).item()
          sample.append(ITOC[ix])
          if ix == 0:
            break
          context = context[1:] + [ix]
        print(''.join(sample))

  def eval(self, eval_set: Union[str, Dataset] = 'eval'):
    """Mean cross-entropy (negative log-likelihood per character) on a split, as a float.

    Args:
      eval_set: 'train', 'eval' or 'test' to pick a split of self.data. Alternatively, a
        Dataset object, whose eval_set is used (matching NGramExplicitModel.eval).
    """
    if not isinstance(eval_set, str):
      eval_set = eval_set.eval_set
    xs, ys = self.prepare_ds(eval_set)
    with torch.no_grad():
      loss = F.cross_entropy(self._forward(xs), ys)
    return loss.item()

In [None]:
# Fixed-seed torch generator. The sampling cell further down creates its own.
g = torch.Generator()
g.manual_seed(2147483647)
# Reload the names corpus; Dataset draws a new random train/eval/test split here.
data = Dataset('../../data/names.txt')

In [None]:
# Bigram count model: tally transitions on the training split, then show the mean NLL on the eval split.
model = NGramExplicitModel(data=data, n=2)
model.count()
model.eval(eval_set='eval')

In [None]:
%matplotlib inline
# Train a bigram MLP (n=2, 20-dim embeddings, 100 full-batch steps at lr=1.0) for each
# regularization coefficient, and report its eval-split loss.
# NOTE: rc is stored on the model, but the training loss has no regularization term
# (it is commented out in train_network), so rc does not affect the result yet.
for rc in [0.5]:
  model = NGramMLPModel(2, data, 20, rc)
  model.lr = 1.0
  model.train_network(100, model.lr)
  loss = model.eval('eval')
  print(f'rc={rc}: eval loss {loss:.4f}')

In [None]:
# Draw five names from the most recently assigned `model`, using a freshly seeded generator.
g = torch.Generator()
g.manual_seed(2147483647)
model.sample(num_samples=5, generator=g)