In [None]:
import torch
import torch.nn.functional as F
import string
import random
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Read the word list with an explicit encoding and close the file afterwards.
# The default encoding of open() is locale-dependent, which can mangle non-ASCII
# letters such as ä/ö/ü/ß. NOTE(review): assumes the file is UTF-8 — confirm.
with open("data/german_words.txt", "r", encoding="utf-8") as f:
    words = f.read().splitlines()
print(f"Original length: {len(words)}")

# Characters a word may contain. Built once as a frozenset: the original rebuilt
# a list on every call and did linear membership tests against it.
_ALLOWED_CHARS = frozenset(string.ascii_lowercase + " -äöüß")

def contains_illegal_char(word):
    """Return True if ``word`` contains any character outside the allowed set.

    Allowed characters are lowercase ASCII letters, space, hyphen, and ä, ö, ü, ß.
    An empty word contains no illegal characters.
    """
    return any(c not in _ALLOWED_CHARS for c in word)

# Normalise every word to lowercase, then drop any word that still contains a
# character outside the allowed alphabet (rare characters are removed by
# removing the words they appear in).
words = [
    lowered
    for lowered in (w.lower() for w in words)
    if not contains_illegal_char(lowered)
]

print(f"New length: {len(words)}")

In [None]:
chars = sorted(set("".join(words)))

# "." is the word-boundary token at index 0: contexts start padded with it and
# every word is terminated with it. If it also occurred inside words, the
# original index-assignment loop would give two characters the same index.
assert "." not in chars, "'.' is reserved as the word-boundary token"

# Character -> index: "." first, then the real characters in sorted order.
stoi = {".": 0, **{ch: idx for idx, ch in enumerate(chars, start=1)}}

num_chars = len(stoi)

# Index -> character (inverse mapping, used for decoding samples).
itos = {idx: ch for ch, idx in stoi.items()}

print(stoi)
print(itos)

In [None]:
context_length = 3  # number of preceding characters the model conditions on

# Build (context window -> next character) training pairs. Each word is followed
# by "." so the model also learns where words end.
xs, ys = [], []
for word in words:
    window = [0] * context_length  # start with a window of "." (index 0)
    for target in (stoi[ch] for ch in word + "."):
        xs.append(window)
        ys.append(target)
        # Slide the window one character to the right (builds a new list).
        window = window[1:] + [target]

xs = torch.tensor(xs)
ys = torch.tensor(ys)

xs.shape, ys.shape

In [None]:
def print_samples(limit=None):
    """Print training pairs from ``xs``/``ys`` as ``"<context>-><target>"`` strings.

    Args:
        limit: maximum number of pairs to print. ``None`` (the default) prints
            every pair, which for a full word list is a very long output.
    """
    # Slicing with limit=None keeps the whole tensor.
    for context, target in zip(xs[:limit], ys[:limit]):
        print("".join(itos[ix.item()] for ix in context) + "->" + itos[target.item()])

# Show only a handful of pairs instead of dumping the entire dataset.
print_samples(limit=20)

In [None]:
embedding_size = 2        # number of values per character embedding
hidden_layer_size = 300

# Seed torch's global RNG so parameter init, minibatch sampling and word
# sampling are reproducible when the notebook is run top to bottom.
SEED = 42
torch.manual_seed(SEED)

# Embedding table: one row of `embedding_size` values per vocabulary entry.
C = torch.randn((num_chars, embedding_size))

# First (hidden) layer: consumes the concatenated embeddings of the context.
W1 = torch.randn((embedding_size * context_length, hidden_layer_size))
b1 = torch.randn(hidden_layer_size)

# Final layer: one logit per vocabulary entry.
W2 = torch.randn((hidden_layer_size, num_chars))
b2 = torch.randn(num_chars)

parameters = [C, W1, b1, W2, b2]

for p in parameters:
    p.requires_grad_()

In [None]:
def _forward(x_idx):
    """Run the MLP on context windows of character indices and return logits.

    Args:
        x_idx: integer tensor with one row of ``context_length`` character
            indices per example (the same layout as ``xs``).

    Returns:
        Logits with one row per example and ``num_chars`` columns.
    """
    emb = C[x_idx].view(-1, context_length * embedding_size)
    h = torch.tanh(emb @ W1 + b1)
    return h @ W2 + b2


def iteration(batch_size=32, lr=0.1):
    """Perform one plain-SGD step on a random minibatch of ``xs``/``ys``.

    Args:
        batch_size: number of examples drawn (with replacement) per step.
        lr: learning rate of the gradient-descent update.

    Returns:
        The minibatch loss as a Python float.
    """
    # --- make minibatch ---
    ix = torch.randint(0, xs.shape[0], (batch_size,))

    # --- forward pass ---
    loss = F.cross_entropy(_forward(xs[ix]), ys[ix])

    # --- backward pass ---
    for p in parameters:
        p.grad = None
    loss.backward()

    # Update the leaf tensors in place without recording the update in autograd.
    with torch.no_grad():
        for p in parameters:
            p -= lr * p.grad

    return loss.item()

num_steps = 20000
for _ in range(num_steps):
    iteration()

# Evaluate on the full dataset. no_grad avoids building an autograd graph over
# every training example, which is pure memory overhead here.
with torch.no_grad():
    logits = _forward(xs)
    loss = F.cross_entropy(logits, ys)

print(f"{loss=}")

In [None]:
def sample_word(max_length=None):
    """Generate one word by repeatedly sampling the next character from the model.

    Sampling starts from a context of all index 0 ("."). It stops when the model
    draws index 0 again, which marks the end of a word.

    Args:
        max_length: optional cap on the number of characters generated.
            ``None`` (the default) samples until the end marker is drawn.

    Returns:
        The sampled word, without the end marker.
    """
    context = [0] * context_length
    word = []

    # Inference only: the parameters require grad, so without no_grad every
    # sampled character would build an autograd graph.
    with torch.no_grad():
        while max_length is None or len(word) < max_length:
            emb = C[torch.tensor(context)].view(-1, context_length * embedding_size)
            h = torch.tanh(emb @ W1 + b1)
            probs = F.softmax(h @ W2 + b2, dim=1)
            next_ix = torch.multinomial(probs, 1).item()

            if next_ix == 0:
                break

            word.append(itos[next_ix])
            context = context[1:] + [next_ix]

    return "".join(word)

for i in range(10):
    print(f"{i}: {sample_word()}")