In [105]:
words = open('names.txt', 'r').read().splitlines()

In [106]:
# bigram model

import torch
import torch.nn.functional as F

def create_bigram_char_convertion_maps(words):
    chars = sorted(list(set(''.join(words))))
    stoi = {s:i+1 for i,s in enumerate(chars)}
    stoi['.'] = 0
    itos = {i:s for s,i in stoi.items()}
    return (stoi, itos)

def create_bigram_xs_ys(words, stoi, itos): 
    xs, ys = [], []
    for w in words:
      chs = ['.'] + list(w) + ['.']
      for ch1, ch2 in zip(chs, chs[1:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        xs.append(ix1)
        ys.append(ix2)
    xs = torch.tensor(xs)
    ys = torch.tensor(ys)
    return (xs, ys)

def train_bigram_model(words, regularization=0.01):
    (stoi, itos) = create_bigram_char_convertion_maps(words)
    (xs, ys) = create_bigram_xs_ys(words, stoi, itos)
    num = xs.nelement()
    g = torch.Generator().manual_seed(2147483647)
    W = torch.randn((27, 27), generator=g, requires_grad=True)

    tolerance = 1e-3 
    
    for k in range(10000):
      xenc = F.one_hot(xs, num_classes=27).float()
      logits = xenc @ W
      counts = logits.exp()
      probs = counts / counts.sum(1, keepdims=True)  
      loss = -probs[torch.arange(num), ys].log().mean() + regularization*(W**2).mean()
      
      # backward pass
      W.grad = None
      loss.backward()

      # stop when grad does not change much
      if W.grad.norm().item() < tolerance:
        break
          
      # update
      W.data += -50 * W.grad

    print(f"grad={W.grad.norm().item()} loss={loss.item()}")
    return W

_ = train_bigram_model(words)

grad=0.0009974666172638535 loss=2.4844870567321777


In [107]:
# trigram model

import itertools as itt
import torch
import torch.nn.functional as F

def create_trigram_convertion_maps(words):
    chars = sorted(list(set(''.join(words))))
    stoi = {s:i+1 for i,s in enumerate(chars)}
    stoi['.'] = 0
    itos = {i:s for s,i in stoi.items()}
    chars.append('.')
    bigrams = [ch1+ch2 for ch1, ch2 in list(set(itt.product(chars, chars)))]
    bigramtoi = {cp:i for i,cp in enumerate(bigrams)}
    itobigram = {i:cp for cp,i in bigramtoi.items()}
    return (stoi, bigramtoi, itos)


def create_trigram_xs_ys(words, stoi, bigramtoi):
    xs, ys = [], []
    for w in words:
      chs = ['.'] + list(w) + ['.']
      for (ch1, ch2, ch3) in zip(chs, chs[1:], chs[2:]):
        trigram = (ch1+ch2, ch3)
        ix1 = bigramtoi[trigram[0]]
        ix2 = stoi[ch3]
        xs.append(ix1)
        ys.append(ix2)
    xs = torch.tensor(xs)
    ys = torch.tensor(ys)
    return (xs, ys)

def train_trigram_model(words, regularization = 0.01):
    (stoi, bigramtoi, _) = create_trigram_convertion_maps(words)
    (xs, ys) = create_trigram_xs_ys(words, stoi, bigramtoi)

    xenc = F.one_hot(xs, num_classes=729).float()
    xenc.shape

    g = torch.Generator().manual_seed(2147483647)
    W = torch.randn((729, 27), generator=g, requires_grad=True)

    tolerance = 1e-3 
    
    for k in range(10000):
        
      # forward pass
      logits = xenc @ W
      counts = logits.exp()
      probs = counts / counts.sum(1, keepdims=True)
      loss = -probs[torch.arange(xs.nelement()), ys].log().mean() + regularization*(W**2).mean()
      
      # backward pass
      W.grad = None
      loss.backward()

      # stop when grad does not change much
      if W.grad.norm().item() < tolerance:
        break
      
      # update
      W.data += -300 * W.grad
    
    print(f"grad={W.grad.norm().item()} loss={loss.item()}")
    return W

_ = train_trigram_model(words)

grad=0.0009976092260330915 loss=2.1318936347961426


In [108]:
# split words in training data, test data, dev data

from sklearn.model_selection import train_test_split

train_data, test_data = train_test_split(words, test_size=0.2, random_state=42)
dev_data, test_data = train_test_split(test_data, test_size=0.5, random_state=42)

In [109]:
# E02: split up the dataset randomly into 80% train set, 10% dev set, 10% test set. 
# Train the bigram and trigram models only on the training set. 
# Evaluate them on dev and test splits. What can you see?

def bigram_accuracy(W, data):
    g = torch.Generator().manual_seed(2147483647)
    (stoi, itos) = create_bigram_char_convertion_maps(words)
    (xs, ys) = create_bigram_xs_ys(data, stoi, itos)
    accuracy = 0
    for x, y in zip(xs, ys):
        xenc = F.one_hot(torch.tensor([x.item()]), num_classes=27).float()
        logits = xenc @ W
        counts = logits.exp()
        p = counts / counts.sum(1, keepdims=True)
        x2 = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        if x2 == y.item():
            accuracy += 1
    return accuracy

def trigram_accuracy(W, data):
    g = torch.Generator().manual_seed(2147483647)
    (stoi, bigramtoi, itos) = create_trigram_convertion_maps(words)
    (xs, ys) = create_trigram_xs_ys(data, stoi, bigramtoi)
    accuracy = 0
    for x, y in zip(xs, ys):
        xenc = F.one_hot(torch.tensor([x.item()]), num_classes=729).float()
        logits = xenc @ W
        counts = logits.exp()
        p = counts / counts.sum(1, keepdims=True)
        x2 = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        if x2 == y.item():
            accuracy += 1
    return accuracy

print("-------------------------")
W_bigram = train_bigram_model(train_data)
W_bigram_accuracy = bigram_accuracy(W_bigram, dev_data)
print(f"bigram_model accuracy {W_bigram_accuracy}")

print("-------------------------")
W_trigram = train_trigram_model(train_data)
W_trigram_accuracy = trigram_accuracy(W_trigram, dev_data)
print(f"trigram_model accuracy {W_trigram_accuracy}")

print("-------------------------")
print(f"bigram test: {bigram_accuracy(W_bigram, test_data)}, trigram test: {trigram_accuracy(W_trigram, test_data)}")

-------------------------
grad=0.000992407905869186 loss=2.4850637912750244
bigram_model accuracy 2932
-------------------------
grad=0.000996677321381867 loss=2.1279118061065674
trigram_model accuracy 3862
-------------------------
bigram test: 2962, trigram test: 3880


In [123]:
# E03: use the dev set to tune the strength of smoothing (or regularization) for
# the trigram model - i.e. try many possibilities and see which one works best based 
# on the dev set loss. What patterns can you see in the train and dev set loss as you 
# tune this strength? Take the best setting of the smoothing and evaluate on the test 
# set once and at the end. How good of a loss do you achieve?

W_trigram = train_trigram_model(dev_data, 0.000001)
W_trigram_accuracy = trigram_accuracy(W_trigram, dev_data)
print(f"trigram_model accuracy {W_trigram_accuracy}")

grad=0.0009985571959987283 loss=1.9896376132965088
trigram_model accuracy 4212


In [111]:
# E04: we saw that our 1-hot vectors merely select a row of W, so producing these vectors 
# explicitly feels wasteful. Can you delete our use of F.one_hot in favor of 
# simply indexing into rows of W?

# E05: look up and use F.cross_entropy instead. You should achieve the same result. 
# Can you think of why we'd prefer to use F.cross_entropy instead?