In [2]:
from collections import Counter

import torch

In [3]:
# Load one name per line. A forward-slash relative path works on both Windows
# and POSIX, and the context manager guarantees the file handle is closed.
with open('data/names.txt', 'r', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()

In [4]:
def get_trigrams(word : str) -> list[tuple[str,str,str]]:
    """Return the consecutive character trigrams of ``word``.

    The word is padded with a single '.' on each side before sliding a
    width-3 window over it. A one-letter word 'a' yields [('.', 'a', '.')];
    an empty word yields [] because the two pad chars never form a full
    trigram.

    Example: 'ab' -> [('.', 'a', 'b'), ('a', 'b', '.')]
    """
    padded = ['.'] + list(word) + ['.']
    # zip stops at the shortest slice, i.e. after len(padded) - 2 windows
    return list(zip(padded, padded[1:], padded[2:]))


# Count how often each trigram (c1, c2, c3) occurs across all names.
trigram_counts = Counter()
for word in words:
    trigram_counts.update(get_trigrams(word))


# Build the char <-> index lookup tables; '.' (the start/end pad) is index 0.
chars = ['.'] + sorted(set(''.join(words)))

char2idx = {c: i for i, c in enumerate(chars)}
idx2char = {i: c for c, i in char2idx.items()}


# Trigram Neural net


In [5]:
import torch.nn.functional as F
import torch.nn as nn

# Turn every trigram (c1, c2, c3) into a training example:
# input = indices of the two context chars, target = index of the third.
xs, ys = [], []

for w in words:
    for c1, c2, c3 in get_trigrams(w):
        xs.append([char2idx[c1], char2idx[c2]])
        ys.append(char2idx[c3])


xs = torch.tensor(xs)  # (N, 2) context indices
vocab_size = len(chars)
# One-hot each context char, then concatenate the two one-hots per row:
# (N, 2, vocab_size) -> (N, 2 * vocab_size)
x_enc = F.one_hot(xs, num_classes=vocab_size).float().flatten(1, 2)

ys = torch.tensor(ys)  # (N,) target indices

total_trigrams = len(ys)
print(f'total trigrams: {total_trigrams}')
print(x_enc.shape)




total bigrams: 196113
torch.Size([196113, 54])


In [7]:
# Single linear layer: two concatenated one-hot context chars
# (2 * vocab_size inputs) -> vocab_size logits for the next char.
# Shape is derived from the vocabulary so it always matches the encoding.
rand_gen = torch.Generator().manual_seed(1234)
vocab_size = len(chars)
weights = torch.randn((2 * vocab_size, vocab_size), generator=rand_gen, requires_grad=True)

In [10]:
# Training loop: full-batch gradient descent on cross-entropy + an L2 penalty.
LEARNING_RATE = 10    # large step size; the model is a single linear layer
REG_STRENGTH = 0.01   # weight of the mean-squared-weight penalty
loss_fn = nn.CrossEntropyLoss()  # stateless, so build it once outside the loop

for epoch in range(100):
    # forward pass
    logits = x_enc @ weights  # (N, vocab) unnormalized log-counts

    # CrossEntropyLoss is the numerically stable form of:
    #   counts = logits.exp()
    #   probs  = counts / counts.sum(dim=1, keepdim=True)
    #   loss   = -probs[torch.arange(total_trigrams), ys].log().mean()
    loss = loss_fn(logits, ys) + REG_STRENGTH * (weights**2).mean()

    print(f'epoch: {epoch}, loss: {loss.item()}')

    # backward pass: the graph is rebuilt every epoch, so no retain_graph
    weights.grad = None
    loss.backward()

    # update weights in place without recording the update in autograd
    with torch.no_grad():
        weights -= LEARNING_RATE * weights.grad

print(loss)

epoch: 0, loss: 2.3784940242767334
tensor(2.3785, grad_fn=<AddBackward0>)


In [9]:
# Inference: sampling uses its own, separately seeded generator so the
# generated names are reproducible independently of the training seed.
sampling_seed = 12141234
rand_gen = torch.Generator()
rand_gen.manual_seed(sampling_seed)

def make_word(start_with : str) -> str:
    """Sample one name from the trigram net, starting with ``start_with``.

    The context is seeded with ('.', start_with), mirroring how training
    trigrams begin with a single '.' pad. Next characters are drawn from the
    softmax over ``weights`` (module-level) using the module-level
    ``rand_gen`` until the '.' end token is sampled.

    The generated word is printed (as before) and also returned.
    Raises KeyError if ``start_with`` is not in ``char2idx``.
    """
    end_idx = char2idx['.']  # '.' doubles as start pad and end token
    num_classes = len(char2idx)
    context = [end_idx, char2idx[start_with]]
    out = [start_with]

    # Sampling needs no gradients; without no_grad every step would build an
    # autograd graph because `weights` has requires_grad=True.
    with torch.no_grad():
        while True:
            ctx_enc = F.one_hot(torch.tensor(context), num_classes=num_classes).float().flatten()
            logits = ctx_enc @ weights
            probs = F.softmax(logits, dim=0)  # stable exp-normalize

            next_char = torch.multinomial(probs, 1, generator=rand_gen, replacement=True).item()

            if next_char == end_idx:
                break

            out.append(idx2char[next_char])
            context = [context[1], next_char]  # slide the 2-char window

    word = ''.join(out)
    print(word)
    return word



# Draw a handful of sample names that all begin with the same letter.
_char = 'l'
n_samples = 10
header = f'Word names with {_char} \n -----------------'
print(header)
for _ in range(n_samples):
    make_word(_char)
print()





Word names with l 
 -----------------
laim
la
le
lyna
liue
lonn
la
luy
lyaila
lin

