In [16]:
# %load data_preparation/v3.py
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data.dataset import random_split

# create a classifier class that inherits from nn.Module
class BigramClassifier(torch.nn.Module):
    """Bigram model consisting of a single learnable 27x27 weight matrix.

    The initial weights are sampled with ``torch.randn`` using the
    module-level generator ``g``, so ``g`` must be defined before the
    class is instantiated.
    """

    def __init__(self):
        super().__init__()
        initial_weights = torch.randn((27, 27), generator=g, requires_grad=True)
        # Wrapping the tensor in a Parameter registers it with the module.
        self.W = torch.nn.Parameter(initial_weights)

    def forward(self, x):
        """Return the matrix product of ``x`` with the weight matrix.

        NOTE(review): presumably ``x`` is a float tensor of one-hot rows
        with 27 columns -- confirm against the caller.
        """
        weights = self.W
        return x @ weights

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seeded generator used for the model's weight initialisation.
g = torch.Generator().manual_seed(42)

# Read one name per line; the context manager closes the file handle.
with open('names.txt') as f:
    words = f.read().splitlines()

# Sorted list of the distinct characters that occur in the names.
letters = sorted(set(''.join(words)))

# Index 0 is reserved for the '.' boundary token, so the letters must be
# numbered from 1. Numbering them from 0 would make the first letter share
# index 0 with '.', and index_to_letter would then lose that letter.
letter_to_index = {letter: index for index, letter in enumerate(letters, start=1)}
letter_to_index['.'] = 0

# Inverse lookup: index -> character.
index_to_letter = {i: letter for letter, i in letter_to_index.items()}

# Collect (input, target) index pairs for every adjacent character pair,
# with each word wrapped in '.' boundary tokens.
xs, ys = [], []

for w in words:
    chs = ['.', *w, '.']
    codes = [letter_to_index[ch] for ch in chs]
    # Inputs are all characters but the last; targets are shifted by one.
    xs.extend(codes[:-1])
    ys.extend(codes[1:])

xs = torch.as_tensor(xs)
ys = torch.as_tensor(ys)

# Pair every input index with its target index.
dataset = TensorDataset(xs, ys)

# Fractions for the train and validation splits; the test split takes
# whatever is left so the three sizes always sum to the dataset length.
train_ratio, validation_ratio = .8, .1

n_total = len(dataset)
n_train, n_validation = (
    int(n_total * ratio) for ratio in (train_ratio, validation_ratio)
)
n_test = n_total - (n_train + n_validation)

train_data, validation_data, test_data = random_split(
    dataset,
    [n_train, n_validation, n_test],
)

# Each loader's batch size equals the size of its split, so one batch
# contains the whole split.
train_loader, validation_loader, test_loader = (
    DataLoader(split, batch_size=size, shuffle=True)
    for split, size in (
        (train_data, n_train),
        (validation_data, n_validation),
        (test_data, n_test),
    )
)


In [17]:
# Inspect the shape of the input indices once converted to a tensor.
torch.as_tensor(xs).shape

torch.Size([5])

In [18]:
# One-hot encode the indices into 27 classes, cast to float, and check the shape.
F.one_hot(torch.as_tensor(xs), num_classes=27).float().shape

torch.Size([5, 27])

In [19]:
# Display the raw input indices.
xs

[0, 4, 12, 12, 0]

In [20]:
# Display the float one-hot encoding itself: one row per index, 27 columns.
F.one_hot(torch.as_tensor(xs), num_classes=27).float()

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.]])