In [18]:
from __future__ import annotations
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

%matplotlib inline

In [19]:
# Read the corpus: one entry per line of names.txt, newlines stripped.
# The context manager closes the file handle as soon as the read finishes,
# instead of leaving an open handle around for the garbage collector.
with open("names.txt", "r") as names_file:
    words: list[str] = names_file.read().splitlines()

In [20]:
# Vocabulary: every distinct character that occurs in the words, sorted
chars: list[str] = sorted(set("".join(words)))

# Character -> integer id. Characters from the data are numbered from 1,
# which leaves id 0 for the "." token assigned right after.
stoi: dict[str, int] = dict(zip(chars, range(1, len(chars) + 1)))
stoi["."] = 0

# Integer id -> character (the inverse of stoi)
itos: dict[int, str] = {idx: ch for ch, idx in stoi.items()}

In [21]:
# Build the dataset: every input is a window of `block_size` character ids
# and every target is the id of the character that comes right after it.
block_size = 3  # context length: how many characters are used to predict the next one

X, Y = [], []
for w in words[:5]:  # only the first five words are used here
    print(w)
    # Each word starts from an all-zero window (id 0 is the "." token).
    context = [0] * block_size
    for ch in w + ".":  # the appended "." is the final target for the word
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        print("".join(map(itos.__getitem__, context)), "---->", itos[ix])
        # Slide the window: drop the oldest id, append the one just seen.
        context = [*context[1:], ix]

# Convert the accumulated Python lists to tensors
X = torch.tensor(X)
Y = torch.tensor(Y)

emma
... ----> e
..e ----> m
.em ----> m
emm ----> a
mma ----> .
olivia
... ----> o
..o ----> l
.ol ----> i
oli ----> v
liv ----> i
ivi ----> a
via ----> .
ava
... ----> a
..a ----> v
.av ----> a
ava ----> .
isabella
... ----> i
..i ----> s
.is ----> a
isa ----> b
sab ----> e
abe ----> l
bel ----> l
ell ----> a
lla ----> .
sophia
... ----> s
..s ----> o
.so ----> p
sop ----> h
oph ----> i
phi ----> a
hia ----> .


In [22]:
# Inspect the dataset tensors: both shapes first, then both element types
(X.shape, Y.shape, X.dtype, Y.dtype)

(torch.Size([32, 3]), torch.Size([32]), torch.int64, torch.int64)

In [23]:
# Embedding lookup table: 27 rows of 2-dimensional vectors sampled from a
# standard normal (presumably one row per character id in stoi -- confirm
# against the vocabulary size)
C = torch.randn(27, 2)

In [24]:
# Look up the row of C for every character id stored in X; the result has
# X's shape with the embedding dimension appended
emb = F.embedding(X, C)
emb.shape

torch.Size([32, 3, 2])

In [25]:
# Hidden-layer parameters: a 6 x 100 weight matrix and a bias vector with
# 100 entries, both sampled from a standard normal
W1 = torch.randn(6, 100)
b1 = torch.randn((100,))

In [29]:
# Hidden layer: flatten each example's embeddings into one row of width 6,
# apply the affine map (W1, b1), then squash with tanh
h = (emb.view(-1, 6) @ W1 + b1).tanh()
print(h)
h.shape

tensor([[-0.7377,  0.9983, -0.5488,  ..., -0.7515,  0.2465, -0.9993],
        [-0.9895,  0.9998, -0.9760,  ..., -0.9954,  0.9912, -0.9978],
        [-0.9995,  0.9997,  0.8055,  ..., -0.9979, -0.7741, -0.9993],
        ...,
        [ 0.8521,  0.9742,  0.9244,  ..., -0.9910,  0.9000, -0.9999],
        [-0.9583,  0.9980, -0.9702,  ..., -0.6988,  0.9433,  0.2720],
        [-0.9978,  0.9746,  0.9673,  ..., -0.9880, -0.8701, -0.9984]])


torch.Size([32, 100])

In [30]:
# Output-layer parameters: a 100 x 27 weight matrix and a bias vector with
# 27 entries, both sampled from a standard normal
W2 = torch.randn(100, 27)
b2 = torch.randn((27,))

In [31]:
# Output layer: affine map of the hidden activations through (W2, b2)
logits = torch.matmul(h, W2) + b2
logits.shape

torch.Size([32, 27])

In [32]:
# Softmax over the last dimension: turn each row of logits into a
# probability distribution.
# The per-row maximum is subtracted before exponentiating. The shift cancels
# in the ratio below, so prob is mathematically unchanged, but exp() can no
# longer overflow to inf (which would make prob nan) when logits are large.
# Note: counts is therefore exp(logits) scaled by a per-row constant.
counts = (logits - logits.max(dim=-1, keepdim=True).values).exp()
prob = counts / counts.sum(dim=-1, keepdim=True)
prob.shape

torch.Size([32, 27])

In [27]:
# Multiplying the one-hot vector for index 5 by C picks out row 5 of C
torch.matmul(F.one_hot(torch.tensor(5), num_classes=27).float(), C)

tensor([-1.8399, -0.3441])