In [1]:
from typing import Final, Tuple
import torch

In [2]:
# Load the corpus: one word per line of names.txt.
# The context manager closes the file handle as soon as the text has been read.
with open("names.txt") as names_file:
    words: Final[list[str]] = names_file.read().splitlines()


In [3]:
# Character vocabulary built from every character that occurs in the corpus.
# '.' is the word start/end marker and always maps to index 0; the other
# characters take indices 1..len(chars) in sorted order.
chars: Final[list[str]] = sorted(set(''.join(words)))
stoi: Final[dict[str, int]] = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos: Final[dict[int, str]] = {i: s for s, i in stoi.items()}

# Bigram count table, indexed [first char, second char]. It is sized from the
# vocabulary rather than a hard-coded 27 so the notebook keeps working if the
# corpus alphabet changes. N is rebound by the trigram cell further down, so it
# is deliberately not annotated Final.
N: torch.Tensor = torch.zeros((len(stoi), len(stoi)), dtype=torch.int32)


In [None]:
from torch.utils.data import Dataset, random_split

# Thin Dataset wrapper so the word list can be handed to torch's split helpers
class WordsDataset(Dataset):
    """Expose a sequence of words through ``__len__`` / ``__getitem__``."""

    def __init__(self, words_list) -> None:
        """Keep a reference to ``words_list``; the sequence is stored as-is, not copied."""
        self.words = words_list

    def __len__(self) -> int:
        """Return how many words are stored."""
        return len(self.words)

    def __getitem__(self, idx):
        """Return the entry of the stored sequence at position ``idx``."""
        return self.words[idx]

# Seed torch's global RNG so the split below is reproducible
torch.manual_seed(2147483647)

# Wrap the raw word list so random_split can consume it
dataset = WordsDataset(words)

# 80% train / 10% dev; the test split takes whatever is left so the three
# sizes always add up to n_total despite the int() truncation
n_total = len(dataset)
n_train, n_dev = int(0.8 * n_total), int(0.1 * n_total)
n_test = n_total - (n_train + n_dev)

# Randomly partition the dataset into three subsets of the sizes above
train_dataset, dev_dataset, test_dataset = random_split(dataset, [n_train, n_dev, n_test])

# Resolve each subset's indices back to plain word lists for the counting code
train_words: list[str] = list(map(dataset.__getitem__, train_dataset.indices))
dev_words: list[str] = list(map(dataset.__getitem__, dev_dataset.indices))
test_words: list[str] = list(map(dataset.__getitem__, test_dataset.indices))

# Report the sizes as a sanity check
print(f"Total words: {n_total}")
print(f"Training set size: {len(train_words)}")
print(f"Dev set size: {len(dev_words)}")
print(f"Test set size: {len(test_words)}")

Total words: 32033
Training set size: 25626
Dev set size: 3203
Test set size: 3204


In [22]:
# Show the first five words of the train, dev and test splits (in that order)
for split_words in (train_words, dev_words, test_words):
    print(split_words[:5])


['waseem', 'zahari', 'deylin', 'thoreau', 'annalicia']
['yuritza', 'malikye', 'lars', 'raylyn', 'ramell']
['pieper', 'fern', 'aurora', 'jex', 'safan']


In [5]:
# Accumulate bigram counts over the training words. Each word is wrapped in
# '.' markers so that word starts and word ends are counted as transitions too.
for w in train_words:
    padded = ['.', *w, '.']
    for prev_ch, next_ch in zip(padded, padded[1:]):
        N[stoi[prev_ch], stoi[next_ch]] += 1

In [6]:
# Add-one smoothing on the counts, then normalise along dim 1 so that every
# row of P sums to 1.
smoothed_counts = (N + 1).float()
P = smoothed_counts / smoothed_counts.sum(dim=1, keepdim=True)

In [7]:
# Sample 20 words from the bigram table P using a dedicated, seeded generator
g = torch.Generator().manual_seed(2147483647)

for _ in range(20):
    ix = 0  # every word starts from the '.' marker (index 0)
    sampled: list[str] = []
    while True:
        # Draw the next character index from the row of the current character
        ix = int(torch.multinomial(P[ix], num_samples=1, replacement=True, generator=g).item())
        sampled.append(itos[ix])
        if ix == 0:
            # Drawing '.' again ends the word
            break
    print(''.join(sampled))

junide.
janasah.
p.
cony.
a.
nn.
kohin.
tolian.
juwe.
ksahnaauranilevias.
dedainrwieta.
ssonielylarte.
faveumerifontume.
phynslenaruani.
core.
yaenon.
ka.
jabdinerimikimaynin.
anaasn.
ssorionsush.


In [8]:
# Score the bigram table P on every split: total log-likelihood and the
# average negative log-likelihood per bigram.
test_sets: list[list[str]] = [train_words, dev_words, test_words]
set_names: list[str] = ['train', 'dev', 'test']

for set_name, word_set in zip(set_names, test_sets):
    log_likelihood: float = 0.0
    n: int = 0  # number of bigrams scored in this split
    for w in word_set:
        padded = ['.', *w, '.']
        for prev_ch, next_ch in zip(padded, padded[1:]):
            log_likelihood += P[stoi[prev_ch], stoi[next_ch]].log().item()
            n += 1
    print(f'Log-likelihood of {set_name} set: {log_likelihood:.4f}')
    nll: float = -log_likelihood / n
    print(f'Negative log-likelihood of {set_name} set: {nll:.4f}')


Log-likelihood of train set: -448245.6935
Negative log-likelihood of train set: 2.4550
Log-likelihood of dev set: -55712.0168
Negative log-likelihood of dev set: 2.4546
Log-likelihood of test set: -56145.4831
Negative log-likelihood of test set: 2.4554


In [9]:
def words_to_datasets(words: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode words as aligned (input, target) character-index tensors.

    Each word is wrapped in '.' markers and every adjacent character pair
    contributes one example: the index of the first character goes to the
    inputs, the index of the second character to the targets.

    Args:
        words: Words to encode. Every character is looked up in the
            module-level ``stoi`` mapping.

    Returns:
        A pair ``(xs, ys)`` of 1-D int64 tensors of equal length, one entry
        per bigram. Both are empty (but still int64) when ``words`` is empty.

    Raises:
        KeyError: If a character is not a key of ``stoi``.
    """
    xs: list[int] = []
    ys: list[int] = []
    for w in words:
        chs = ['.', *w, '.']
        for ch1, ch2 in zip(chs, chs[1:]):
            xs.append(stoi[ch1])
            ys.append(stoi[ch2])
    # Pin the dtype: torch.tensor([]) would otherwise fall back to the default
    # float dtype, and these tensors are used as indices downstream.
    return torch.tensor(xs, dtype=torch.long), torch.tensor(ys, dtype=torch.long)

# Encode the train, dev and test splits (in that order) into index tensors
(train_xs, train_ys), (dev_xs, dev_ys), (test_xs, test_ys) = map(
    words_to_datasets, (train_words, dev_words, test_words)
)


In [10]:
import torch.nn.functional as F

In [None]:
# Train a single weight matrix W as a bigram model with full-batch gradient
# descent: row W[i] holds the logits for the character that follows character i.
g: torch.Generator = torch.Generator().manual_seed(2147483647)
W: torch.Tensor = torch.randn((len(stoi), len(stoi)), generator=g, requires_grad=True)
n_examples: int = train_xs.nelement()
reg_strength: float = 0.01
learning_rate: float = 50.0

n_epochs: int = 200
assert isinstance(train_xs, torch.Tensor)
# Row index used to pick each example's target probability; loop-invariant
example_rows = torch.arange(n_examples)

for epoch in range(n_epochs):
    # Looking up rows of W is equivalent to one_hot(train_xs) @ W, without
    # materialising the one-hot matrix on every epoch.
    logits = W[train_xs]
    counts = logits.exp()
    probs = counts / counts.sum(dim=1, keepdim=True)
    # Mean negative log-likelihood of the targets plus a penalty on the mean
    # squared weight
    loss = -torch.log(probs[example_rows, train_ys]).mean() + reg_strength * W.pow(2).mean()
    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch + 1}, loss: {loss.item():.4f}')

    # Clear the previous gradient before backpropagating the new loss
    W.grad = None
    loss.backward()

    assert W.grad is not None
    # Plain gradient-descent step; no_grad keeps the update out of autograd
    with torch.no_grad():
        W -= learning_rate * W.grad

def calculate_nll(xs: torch.Tensor, ys: torch.Tensor, weights: "torch.Tensor | None" = None) -> float:
    """Return the mean negative log-likelihood of ``ys`` given ``xs``.

    Each input index selects a row of logits from the weight matrix; the row
    is exponentiated and normalised into probabilities, and the negative log
    of the probability assigned to the matching target is averaged over all
    examples.

    Args:
        xs: Integer tensor of input character indices.
        ys: Integer tensor of target character indices, aligned with ``xs``.
        weights: Weight matrix to evaluate. Defaults to the module-level ``W``.

    Returns:
        The mean negative log-likelihood as a Python float.
    """
    if weights is None:
        weights = W
    # Evaluation only: skip building an autograd graph even though the
    # weights may require grad.
    with torch.no_grad():
        # Row lookup is equivalent to one_hot(xs) @ weights
        logits = weights[xs]
        counts = logits.exp()
        probs = counts / counts.sum(dim=1, keepdim=True)
        return -torch.log(probs[torch.arange(xs.nelement()), ys]).mean().item()

# Evaluate the trained weights on the train, dev and test splits (in that order)
train_nll, dev_nll, test_nll = (
    calculate_nll(inputs, targets)
    for inputs, targets in ((train_xs, train_ys), (dev_xs, dev_ys), (test_xs, test_ys))
)
print(f'Training set NLL: {train_nll:.4f}')
print(f'Dev set NLL: {dev_nll:.4f}')
print(f'Test set NLL: {test_nll:.4f}')

Epoch 10, loss: 2.7190
Epoch 20, loss: 2.5890
Epoch 30, loss: 2.5444
Epoch 40, loss: 2.5230
Epoch 50, loss: 2.5110
Epoch 60, loss: 2.5035
Epoch 70, loss: 2.4985
Epoch 80, loss: 2.4950
Epoch 90, loss: 2.4924
Epoch 100, loss: 2.4904
Epoch 110, loss: 2.4889
Epoch 120, loss: 2.4877
Epoch 130, loss: 2.4867
Epoch 140, loss: 2.4859
Epoch 150, loss: 2.4853
Epoch 160, loss: 2.4848
Epoch 170, loss: 2.4843
Epoch 180, loss: 2.4839
Epoch 190, loss: 2.4836
Epoch 200, loss: 2.4833
Training set NLL: 2.4656
Dev set NLL: 2.4647
Test set NLL: 2.4654


In [12]:
# Sample words from the trained weight matrix W.
# g is bound by earlier cells as well, so it is not declared Final here.
g = torch.Generator().manual_seed(2147483647)
n_samples: Final[int] = 10

for _ in range(n_samples):
    out: list[str] = []
    ix = 0  # every word starts from the '.' marker (index 0)
    while True:
        # W requires grad, but sampling never backpropagates
        with torch.no_grad():
            # W[[ix]] keeps a leading batch dimension of 1; it is the same
            # row that one_hot([ix]) @ W would select.
            logits = W[[ix]]
            counts = logits.exp()
            probs = counts / counts.sum(dim=1, keepdim=True)

        # Convert once so the lookup and the stop test both see a plain int
        ix = int(torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item())
        out.append(itos[ix])
        if ix == 0:
            # Drawing '.' again ends the word
            break
    print(''.join(out))

junide.
janasah.
p.
cfay.
a.
nn.
kohin.
tolian.
juwe.
ksahnaauranilevias.


In [13]:
# NOTE(review): this function is also defined in the training cell above;
# the two definitions should be kept in sync or collapsed into one.
def calculate_nll(xs: torch.Tensor, ys: torch.Tensor, weights: "torch.Tensor | None" = None) -> float:
    """Return the mean negative log-likelihood of ``ys`` given ``xs``.

    Each input index selects a row of logits from the weight matrix; the row
    is exponentiated and normalised into probabilities, and the negative log
    of the probability assigned to the matching target is averaged over all
    examples.

    Args:
        xs: Integer tensor of input character indices.
        ys: Integer tensor of target character indices, aligned with ``xs``.
        weights: Weight matrix to evaluate. Defaults to the module-level ``W``.

    Returns:
        The mean negative log-likelihood as a Python float.
    """
    if weights is None:
        weights = W
    # Evaluation only: skip building an autograd graph even though the
    # weights may require grad.
    with torch.no_grad():
        # Row lookup is equivalent to one_hot(xs) @ weights
        logits = weights[xs]
        counts = logits.exp()
        probs = counts / counts.sum(dim=1, keepdim=True)
        return -torch.log(probs[torch.arange(xs.nelement()), ys]).mean().item()

# Evaluate the trained weights on the train, dev and test splits (in that order)
train_nll, dev_nll, test_nll = (
    calculate_nll(inputs, targets)
    for inputs, targets in ((train_xs, train_ys), (dev_xs, dev_ys), (test_xs, test_ys))
)
print(f'Training set NLL: {train_nll:.4f}')
print(f'Dev set NLL: {dev_nll:.4f}')
print(f'Test set NLL: {test_nll:.4f}')


Training set NLL: 2.4656
Dev set NLL: 2.4647
Test set NLL: 2.4654


In [14]:
# Trigram model by counting: N[i, j, k] is how often character k follows the
# pair (i, j) in the training words. This rebinds N and P, replacing the
# bigram tables built earlier.
# NOTE(review): with a single leading '.', the first letter of a word is only
# ever used as context, never as the predicted character, so P[0, 0] appears
# to stay at the uniform smoothing prior - confirm this is intended.

char_count: int = len(stoi)
N = torch.zeros((char_count,) * 3, dtype=torch.int32)
for w in train_words:
    padded = ['.', *w, '.']
    for first, second, third in zip(padded, padded[1:], padded[2:]):
        N[stoi[first], stoi[second], stoi[third]] += 1

# Add-one smoothing, then normalise over the last axis (the predicted character)
smoothed_counts = (N + 1).float()
P = smoothed_counts / smoothed_counts.sum(dim=2, keepdim=True)
P[0, 0].sum()

tensor(1.)

In [15]:
def get_nll_trigram(test_words: list[str]) -> float:
    """Return the average negative log-likelihood of the trigram table on some words.

    Each word is wrapped in one '.' marker on either side, and every run of
    three consecutive characters is scored with the module-level table ``P``
    (indexed ``[first, second, third]``) through the ``stoi`` mapping.

    Args:
        test_words: Words to score.

    Returns:
        The negated mean of ``log P`` over all scored trigrams, or NaN when no
        trigram was scored (empty list, or only empty words).

    Raises:
        KeyError: If a character is not a key of ``stoi``.
    """
    ll: float = 0.0
    n: int = 0  # number of trigrams scored
    for w in test_words:
        chs = ['.', *w, '.']
        for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
            ll += P[stoi[ch1], stoi[ch2], stoi[ch3]].log().item()
            n += 1
    if n == 0:
        # Nothing to average: report NaN instead of raising ZeroDivisionError
        return float('nan')
    return -ll / n
# Report the trigram NLL on the train, dev and test splits (in that order)
for set_label, split_words in (('Training', train_words), ('Dev', dev_words), ('Test', test_words)):
    print(f'{set_label} set NLL (trigram): {get_nll_trigram(split_words):.4f}')


Training set NLL (trigram): 2.0954
Dev set NLL (trigram): 2.1310
Test set NLL (trigram): 2.1213
