In [None]:
# data
import torch
import matplotlib.pyplot as plt
%matplotlib inline
import random

# Load one name per line and build the character <-> index lookup tables.
# `with` closes the file deterministically; an explicit encoding avoids
# depending on the platform's default locale.
with open('./data/names.txt', 'r', encoding='utf-8') as fh:
    words = fh.read().splitlines()
v = sorted(set(''.join(words)))               # distinct characters in the data, sorted
# '.' is reserved below as index 0 (start/end-of-name token); if the data
# itself contained '.', encode/decode would silently collide.
assert '.' not in v, "'.' is reserved as the start/end token and must not appear in the data"
encode = {c: i + 1 for i, c in enumerate(v)}  # char -> index in 1..len(v)
encode['.'] = 0                               # reserved padding / terminator index
decode = {i: c for c, i in encode.items()}    # inverse lookup: index -> char

context_length = 3  # number of preceding characters used as context for each prediction
def gen_dataset(words, block_size=None, stoi=None):
    """Build (context, next-character) training pairs from a list of words.

    Each word is left-padded with ``block_size`` copies of the '.' index and
    terminated with '.', so a word of length L yields L + 1 examples. The
    first example predicts the word's first character from an all-'.' context.
    The last example predicts the terminating '.'.

    Args:
        words: iterable of strings; every character must be a key of ``stoi``.
        block_size: number of preceding characters per context. Defaults to
            the module-level ``context_length``.
        stoi: char -> int mapping that must contain '.'. Defaults to the
            module-level ``encode``.

    Returns:
        (X, Y): ``X`` is a long tensor of shape (N, block_size) holding the
        contexts, ``Y`` a long tensor of shape (N,) holding the target indices.
        Both have N == 0 (with ``X`` still 2-D) when ``words`` is empty.

    Raises:
        ValueError: if ``block_size`` < 1.
        KeyError: if a word contains a character missing from ``stoi``.
    """
    if block_size is None:
        block_size = context_length
    if stoi is None:
        stoi = encode
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")

    dot = stoi['.']
    X, Y = [], []
    for w in words:
        context = [dot] * block_size
        for c in w + '.':
            ix = stoi[c]
            X.append(context)
            Y.append(ix)
            # Slide the window: drop the oldest index, append the newest.
            # `context` is rebound (not mutated), so rows already in X are safe.
            context = context[1:] + [ix]
    # An explicit dtype plus view() keeps the shapes and dtypes correct even for
    # empty input, where torch.tensor([]) would be a 1-D float tensor.
    X = torch.tensor(X, dtype=torch.long).view(-1, block_size)
    Y = torch.tensor(Y, dtype=torch.long)
    return X, Y

# Shuffle reproducibly, then cut the list of names 80% / 10% / 10% into
# train / dev / test. Whole names are split (not individual examples), so every
# example from a given name lands in the same split.
random.seed(42)
random.shuffle(words)
n1, n2 = (int(frac * len(words)) for frac in (0.8, 0.9))
Xtr, Ytr = gen_dataset(words[:n1])      # train
Xdev, Ydev = gen_dataset(words[n1:n2])  # dev / validation
Xte, Yte = gen_dataset(words[n2:])      # test

In [None]:
# model (baseline)
import torch.nn as nn
import torch.nn.functional as F
# Seed PyTorch's global RNG so the randomly initialised model weights below are reproducible.
torch.manual_seed(1337)

class Bigram(nn.Module):
    """Baseline bigram model: one V x V lookup table of next-token logits.

    Looking up token index ``i`` returns row ``i`` of the embedding table,
    which is used directly as the (unnormalised) logits for the next token.
    """

    def __init__(self, V):
        """Create the model.

        Args:
            V: vocabulary size (number of distinct token indices).
        """
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise the assignment below raises AttributeError.
        super().__init__()
        self.token_embedding_table = nn.Embedding(V, V)

    def forward(self, xb, yb=None):
        """Return next-token logits for every index in ``xb``.

        Args:
            xb: integer tensor of token indices (any shape), values in [0, V).
            yb: targets. Accepted for call-site compatibility but currently
                unused (no loss is computed here).

        Returns:
            Float tensor of shape ``xb.shape + (V,)``. For ``xb`` of shape
            (B, T) the result is (B, T, V).
        """
        logits = self.token_embedding_table(xb)
        return logits

# Smoke test: push one random minibatch of training contexts through the model.
V = len(encode)  # vocabulary size: every character in `v` plus the reserved '.' token
batch_size = 32
ix = torch.randint(0, Xtr.shape[0], (batch_size,))
xb, yb = Xtr[ix], Ytr[ix]  # xb: (batch_size, context_length), yb: (batch_size,)
m = Bigram(V)
y_hat = m(xb, yb)          # logits: (batch_size, context_length, V)
print(y_hat.shape)