In [None]:
import torch
import string
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Read the raw word list, one word per line. The encoding is given explicitly
# (assumes the file is UTF-8 — confirm) so umlauts/ß decode identically on every
# platform, and `with` guarantees the file handle is closed.
with open("data/german_words.txt", "r", encoding="utf-8") as fh:
    words = fh.read().splitlines()
print(f"Original length: {len(words)}")

# Characters a cleaned word may consist of: a-z, space, hyphen, ä, ö, ü and ß.
# Built once as a frozenset so per-character membership checks are O(1).
_ALLOWED_CHARS = frozenset(string.ascii_lowercase + " -äöüß")

def contains_illegal_char(word, allowed_chars=_ALLOWED_CHARS):
    """Return True if `word` contains any character outside `allowed_chars`.

    Args:
        word: the string to check (expected to be lower-cased already; upper-case
            letters are not in the default allowed set).
        allowed_chars: container of permitted characters; defaults to lower-case
            ASCII letters plus " -äöüß".

    Returns:
        True if at least one character is not allowed, False otherwise
        (an empty string is never illegal).
    """
    return any(c not in allowed_chars for c in word)

# Lower-case every word, then keep only the words made up entirely of allowed
# characters — words containing rarely occurring characters are dropped entirely.
words = [w for w in map(str.lower, words) if not contains_illegal_char(w)]

print(f"New length: {len(words)}")

In [None]:
# Vocabulary: every distinct character in the cleaned words (sorted), followed by
# the two special start/end-of-word tokens, which get the last two indices.
chars = sorted(set("".join(words)))

stoi = {ch: idx for idx, ch in enumerate(chars)}
for special_token in ("<S>", "<E>"):
    stoi[special_token] = len(stoi)

num_chars = len(stoi)

# Reverse lookup: index -> character/token.
itos = {idx: ch for ch, idx in stoi.items()}

In [None]:
# N[a, b] counts how often token b follows token a, where every word is wrapped
# in a "<S>" start marker and an "<E>" end marker.
N = torch.zeros((num_chars, num_chars), dtype=torch.long)
for w in words:
    chs = ["<S>"] + list(w) + ["<E>"]
    for ch1, ch2 in zip(chs, chs[1:]):
        N[stoi[ch1], stoi[ch2]] += 1

# Normalise each row of the counts into a next-token probability distribution.
# "<E>" only ever appears as the last token, so its row of N is all zeros;
# clamping the row sums to at least 1 leaves that row as zeros instead of
# producing 0/0 = NaN.
P = N.float() / N.sum(1, keepdim=True).clamp(min=1)

In [None]:
# Heat-map of the bigram count matrix N. Every cell is labelled with its bigram
# (row token + column token) and the raw count.
fig, ax = plt.subplots(figsize=(24, 24))
ax.imshow(N)

for i in range(num_chars):
    for j in range(num_chars):
        bigram = itos[i] + itos[j]
        ax.text(j, i, bigram, ha="center", va="top")
        ax.text(j, i, N[i, j].item(), ha="center", va="bottom")
ax.axis("off")
# Render explicitly so the cell does not also echo the axis-limits tuple.
plt.show()

In [None]:
def sample_word(generator=None):
    """Sample one word from the bigram model `P`.

    Starts at the "<S>" token and repeatedly draws the next token from the
    corresponding row of `P` until "<E>" is drawn.

    Args:
        generator: optional ``torch.Generator`` used for the draws, so samples
            can be reproduced independently of torch's global RNG. ``None``
            (the default) uses the global RNG.

    Returns:
        The sampled word as a string, without the "<S>"/"<E>" markers. It may be
        empty if "<E>" is drawn immediately after "<S>".
    """
    end = stoi["<E>"]
    chars_out = []

    ix = stoi["<S>"]
    while True:
        # Draw the index of the next token from the current token's row.
        ix = torch.multinomial(P[ix], 1, replacement=True, generator=generator).item()
        if ix == end:
            break
        chars_out.append(itos[ix])

    return "".join(chars_out)
    
# Seed torch's global RNG so the printed samples are identical on every run.
SEED = 42
NUM_SAMPLES = 100

torch.manual_seed(SEED)
for i in range(NUM_SAMPLES):
    print(f"{i}: {sample_word()}")