In [None]:
import torch
import torch.nn.functional as F
import string
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Read one word per line. The word list contains umlauts and ß, so decode it
# explicitly instead of relying on the platform default encoding, and close the
# file handle deterministically.
# NOTE(review): assumes data/german_words.txt is UTF-8 encoded — confirm.
with open("data/german_words.txt", "r", encoding="utf-8") as word_file:
    words = word_file.read().splitlines()
print(f"Original length: {len(words)}")

# Alphabet kept after cleaning: ASCII lowercase letters, space, hyphen and the
# German umlauts / eszett. A frozenset gives O(1) membership checks and is
# built once instead of on every call.
ALLOWED_CHARS = frozenset(string.ascii_lowercase + " -äöüß")


def contains_illegal_char(word, allowed_chars=ALLOWED_CHARS):
    """Return True if ``word`` contains any character outside ``allowed_chars``.

    Args:
        word: the string to check.
        allowed_chars: container of permitted single characters; defaults to
            ``ALLOWED_CHARS``.

    Returns:
        True when at least one character is not allowed, False otherwise
        (an empty word is never illegal).
    """
    return any(c not in allowed_chars for c in word)

# Lowercase every word, then drop any word that contains a character outside
# the allowed alphabet (this discards words with rarely occurring characters).
words = [word for word in map(str.lower, words) if not contains_illegal_char(word)]

print(f"New length: {len(words)}")

In [None]:
# Vocabulary: every character seen in the cleaned words (sorted), followed by
# the special start-of-word and end-of-word tokens.
chars = sorted(set("".join(words)))

stoi = {ch: idx for idx, ch in enumerate(chars)}
for special_token in ("<S>", "<E>"):
    stoi[special_token] = len(stoi)

num_chars = len(stoi)

# Inverse mapping: index -> token.
itos = {idx: ch for ch, idx in stoi.items()}

In [None]:
# Count how often each bigram occurs. Every word is wrapped in the "<S>"/"<E>"
# tokens so the counts also capture how words start and end.
N = torch.zeros((num_chars, num_chars), dtype=torch.long)
for w in words:
    chs = ["<S>"] + list(w) + ["<E>"]
    for ch1, ch2 in zip(chs, chs[1:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        N[ix1, ix2] += 1

# Row-normalise the counts: P[i, j] is the probability that token j follows
# token i. Nothing ever follows "<E>", so its row sum is 0 and a plain division
# would fill that row with NaN (0/0). Clamping the denominator to 1 turns it
# into a row of zeros and leaves every other row unchanged.
row_sums = N.sum(1, keepdim=True).clamp(min=1)
P = N.float() / row_sums

In [None]:
# Heatmap of the raw bigram counts N. Each cell is labelled with its bigram
# (top) and its count (bottom); rows are the first token, columns the second.
fig, ax = plt.subplots(figsize=(24, 24))
ax.imshow(N)

for i in range(num_chars):
    for j in range(num_chars):
        bigram = itos[i] + itos[j]
        # Gray text stays readable on both the dark and light viridis cells.
        ax.text(j, i, bigram, ha="center", va="top", color="gray")
        ax.text(j, i, str(N[i, j].item()), ha="center", va="bottom", color="gray")

ax.set_title("Bigram counts N (row = first token, column = second token)")
ax.set_axis_off()

In [None]:
def sample_word(generator=None):
    """Sample one word from the bigram probability table ``P``.

    Starts at the "<S>" token and repeatedly draws the next token from the row
    of ``P`` for the current token, stopping as soon as "<E>" is drawn.

    Args:
        generator: optional ``torch.Generator`` forwarded to
            ``torch.multinomial``; pass a seeded generator to get
            reproducible samples. ``None`` uses torch's global RNG.

    Returns:
        The sampled characters joined into a string, without the start/end
        tokens (may be the empty string).
    """
    result = []

    ix = stoi["<S>"]
    end = stoi["<E>"]

    while True:
        # Row ix of P is the distribution over the token following ix.
        probs = P[ix, :]

        sample = torch.multinomial(probs, 1, replacement=True, generator=generator).item()
        if sample == end:
            break

        result.append(itos[sample])
        ix = sample

    return "".join(result)
    
# Seed torch's global RNG so the sampled words are identical on every
# Restart & Run All.
torch.manual_seed(2147483647)
for i in range(100):
    print(f"{i}: {sample_word()}")

In [None]:
# Average negative log-likelihood of all training bigrams under P (lower is
# better). Collect every bigram's indices first and evaluate them with a single
# vectorised lookup. The log-probs are summed in float64, because adding a very
# large number of float32 terms one at a time in a float32 accumulator loses
# precision.
first_ix, second_ix = [], []
for w in words:
    chs = ["<S>"] + list(w) + ["<E>"]
    for ch1, ch2 in zip(chs, chs[1:]):
        first_ix.append(stoi[ch1])
        second_ix.append(stoi[ch2])

probs = P[
    torch.tensor(first_ix, dtype=torch.long),
    torch.tensor(second_ix, dtype=torch.long),
]
log_likelihood = probs.double().log().sum().item()
n = len(first_ix)

nll = -log_likelihood / n
print(f"{nll=}")

# NN implementation

In [None]:
# Build the supervised dataset from bigrams: each (current token, next token)
# pair is one example, with the input index in xs and the target index in ys.
pairs = [
    (stoi[first], stoi[second])
    for word in words
    for first, second in zip(["<S>", *word], [*word, "<E>"])
]

xs = torch.tensor([src for src, _ in pairs])
ys = torch.tensor([dst for _, dst in pairs])

print(xs.shape)
print(ys.shape)

In [None]:
# Weights of the single linear layer. Row i holds the logits for the token that
# follows token i. A seeded generator makes the initialisation reproducible.
init_gen = torch.Generator().manual_seed(2147483647)
W = torch.randn(
    (num_chars, num_chars),
    generator=init_gen,
    dtype=torch.float32,
    requires_grad=True,
)
print(W.shape)

def iteration():
    # --- Forward pass ---
    
    # Input to the NN (B, num_chars)
    xenc = F.one_hot(xs, num_classes=num_chars).float()
    
    # Output of the NN (B, num_chars)
    logits = torch.mm(xenc, W)
    
    # Make all outputs positive
    # Also interpret as counts
    counts = logits.exp()
    
    # And normalize to a distribution
    probs = counts / counts.sum(1, keepdim=True)
    
    # Calculate loss function
    batch_size, _ = probs.shape
    loss = -probs[torch.arange(batch_size), ys].log().mean()
    
    # --- Backward pass ---
    W.grad = None
    loss.backward()
    
    W.data -= 50 * W.grad
    
    return loss
    
# Full-batch gradient descent. .item() prints the loss as a plain float
# rather than a tensor repr that may include grad_fn.
for i in range(100):
    loss = iteration()
    print(f"{i=} loss={loss.item():.4f}")

In [None]:
# exp(W) plays the role the count matrix N plays in the counting model: the
# network's prediction for the token after i is row i of exp(W), normalised.
# Scale by 100 and truncate to int so the cell annotations stay short.
W_exp = (W.detach().exp() * 100).int()

fig, ax = plt.subplots(figsize=(24, 24))
ax.imshow(W_exp)

for i in range(num_chars):
    for j in range(num_chars):
        bigram = itos[i] + itos[j]
        # Gray text stays readable on both the dark and light viridis cells.
        ax.text(j, i, bigram, ha="center", va="top", color="gray")
        ax.text(j, i, str(W_exp[i, j].item()), ha="center", va="bottom", color="gray")

ax.set_title("Learned exp(W) * 100 (row = first token, column = second token)")
ax.set_axis_off()