In [26]:
import torch
import torch.nn as nn
import numpy as np
import math
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from transformer import Transformer
from torch.utils.data import Dataset

In [27]:
# Load the three corpora: the German novel collection (german_novels.txt) and the
# poem corpus ("Gedichte"), which comes pre-split into train/validation files.
# NOTE(review): the encoding is pinned to UTF-8 because open()'s default is
# locale-dependent (e.g. cp1252 on Windows would garble umlauts/ß) -- this
# assumes the files are UTF-8 encoded; confirm.
def _read_text(path):
    """Return the full contents of the text file at ``path``, decoded as UTF-8."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


base_text = _read_text("german_novels.txt")
train_text = _read_text("gedichte.txt")
val_text = _read_text("gedichte_val.txt")

In [54]:
# Only part of the novel corpus is used; the full file (~36 MB per the original
# note) is too large. Training keeps the first ninth of the characters, and
# validation is the slice right after it, one tenth the size of the training part.
# NOTE(review): the original note said "use 9mb", but // 9 keeps roughly 36/9 ~ 4 MB;
# confirm which was intended.
end_base = len(base_text) // 9
base_train_text, base_val_text = (
    base_text[:end_base],
    base_text[end_base:end_base + end_base // 10],
)

In [55]:
# Character-level vocabulary: every distinct character occurring in any of the
# four text splits, in sorted order, plus lookup tables in both directions.
chars = sorted(set().union(base_train_text, base_val_text, train_text, val_text))

i_to_c = dict(enumerate(chars))
c_to_i = {ch: idx for idx, ch in i_to_c.items()}

In [56]:
def to_tokens(s):
    """Encode the string ``s`` as a list of vocabulary indices, one per character.

    Looks each character up in the module-level ``c_to_i`` table; a character
    missing from the vocabulary raises ``KeyError``.
    """
    return [c_to_i[ch] for ch in s]

In [57]:
def from_tokens(t):
    """Decode an iterable of vocabulary indices ``t`` back into a string.

    Each index is mapped through the module-level ``i_to_c`` table; an unknown
    index raises ``KeyError``.
    """
    return "".join(i_to_c[idx] for idx in t)

In [58]:
# Encode every text split as a list of character indices.
base_train_tokens = to_tokens(base_train_text)
# Bug fix: this previously re-encoded base_train_text, so the "validation" loss
# on the novel corpus was measured on the training data itself.
base_val_tokens = to_tokens(base_val_text)

train_tokens = to_tokens(train_text)
val_tokens = to_tokens(val_text)

In [59]:

class TokenDataset(Dataset):
    """Map-style dataset of next-token prediction windows over a flat token list.

    Sample ``idx`` is the pair ``(x, y)`` where ``x`` is
    ``token_list[idx : idx + seq_length]`` and ``y`` is the same window shifted
    one position right, i.e. the target token for every input position. Both are
    int64 tensors of length ``seq_length``.
    """

    def __init__(self, token_list, seq_length):
        self.token_list = token_list
        self.seq_length = seq_length

    def __len__(self):
        # Each sample needs seq_length + 1 consecutive tokens. Clamp at 0 so a
        # corpus shorter than that gives an empty dataset instead of making
        # len() raise on a negative value.
        return max(0, len(self.token_list) - self.seq_length)

    def __getitem__(self, idx):
        n = len(self)
        if idx < 0:
            idx += n
        # Without this check, out-of-range indices silently produced short or
        # empty windows, and plain iteration over the dataset never stopped.
        if not 0 <= idx < n:
            raise IndexError(f"index out of range for TokenDataset of length {n}")
        end = idx + self.seq_length
        x = torch.tensor(self.token_list[idx:end], dtype=torch.long)
        y = torch.tensor(self.token_list[idx + 1:end + 1], dtype=torch.long)
        return x, y


def get_XY(token_list, seq_length):
    """Return a ``TokenDataset`` of (input, shifted-target) windows over ``token_list``.

    Kept under its original name for existing callers; it builds a dataset,
    not separate X/Y arrays.
    """
    return TokenDataset(token_list, seq_length)

In [60]:
seq_length = 128  # characters of context in every training window

# One windowed dataset per token split, all built the same way.
base_train_dataset, base_val_dataset, train_dataset, val_dataset = (
    get_XY(tokens, seq_length)
    for tokens in (base_train_tokens, base_val_tokens, train_tokens, val_tokens)
)

In [61]:
# Model hyperparameters; the vocabulary size is fixed by the character set.
emb_size = 256
n_attention_units = 3  # presumably the number of stacked attention blocks -- see transformer.py
vocab_size = len(chars)

tf = Transformer(
    vocab_size=vocab_size,
    seq_length=seq_length,
    emb_size=emb_size,
    n_attention_units=n_attention_units,
)

In [62]:
# training parameters
batch_size = 512
criterion = nn.CrossEntropyLoss()

optim = torch.optim.Adam(tf.parameters(), lr=3e-4)

In [63]:
# Training loaders shuffle every epoch. Validation loaders keep a fixed order,
# so the reported validation loss (a mean of per-batch losses, where the last
# batch may be partial) is deterministic between runs.
base_train_dataloader = DataLoader(base_train_dataset, batch_size=batch_size, shuffle=True)
base_val_dataloader = DataLoader(base_val_dataset, batch_size=batch_size, shuffle=False)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [64]:
def _mean_or_nan(values):
    """Mean of ``values``, or NaN when empty (avoids numpy's empty-slice warning)."""
    return np.mean(np.asarray(values)) if len(values) else float("nan")


def _batch_loss(tf, batch_X, batch_y, loss_fn):
    """Move one batch to ``tf.device``, run the model and return the loss tensor."""
    batch_X = batch_X.to(tf.device)
    batch_y = batch_y.to(tf.device)
    logits = tf(batch_X)
    # nn.CrossEntropyLoss wants the class dimension second. The logits are
    # presumably (batch, seq, vocab), so transpose to (batch, vocab, seq) to
    # match targets of shape (batch, seq).
    return loss_fn(logits.transpose(1, 2), batch_y)


def train(tf, train_dataloader, test_dataloader, n_epochs, optimizer=None, loss_fn=None):
    """Train ``tf`` for ``n_epochs``, evaluating on ``test_dataloader`` after each epoch.

    Args:
        tf: the model; must expose a ``device`` attribute that batches are moved to.
        train_dataloader: iterable of ``(X, y)`` token batches used for updates.
        test_dataloader: iterable of ``(X, y)`` batches used for the
            per-epoch validation loss (no gradient updates).
        n_epochs: number of passes over ``train_dataloader``.
        optimizer: optimizer to step; defaults to the global ``optim``.
        loss_fn: loss module; defaults to the global ``criterion``.

    Returns:
        (train_losses, val_losses): per-epoch mean batch losses, for plotting.
        An epoch whose loader yielded no batches reports NaN.
    """
    optimizer = optim if optimizer is None else optimizer
    loss_fn = criterion if loss_fn is None else loss_fn

    train_loss_plotting = []
    val_loss_plotting = []

    for epoch_idx in range(n_epochs):
        # Training mode: modules such as dropout behave stochastically.
        tf.train()
        train_batch_losses = []
        for batch_X, batch_y in tqdm(train_dataloader):
            loss = _batch_loss(tf, batch_X, batch_y, loss_fn)
            train_batch_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluation mode without gradients. This uses the test_dataloader
        # argument; previously the global val_dataloader was read instead.
        tf.eval()
        val_batch_losses = []
        with torch.no_grad():
            for batch_X, batch_y in test_dataloader:
                val_batch_losses.append(_batch_loss(tf, batch_X, batch_y, loss_fn).item())

        train_loss_epoch = _mean_or_nan(train_batch_losses)
        val_loss_epoch = _mean_or_nan(val_batch_losses)

        train_loss_plotting.append(train_loss_epoch)
        val_loss_plotting.append(val_loss_epoch)

        print(f"{epoch_idx + 1}. train loss = {train_loss_epoch}, val loss = {val_loss_epoch}")

    return train_loss_plotting, val_loss_plotting

In [89]:
# Train on the poem corpus (train_dataloader / val_dataloader) and keep the
# per-epoch loss history for plotting.
# NOTE(review): this cell was labelled as training the German "foundation model",
# but it uses the poem loaders. Pass base_train_dataloader / base_val_dataloader
# for that stage instead.
train_loss_history, val_loss_history = train(tf, train_dataloader, val_dataloader, n_epochs=1)

100%|███████████████████████████████████████| 1269/1269 [03:40<00:00,  5.75it/s]


1. train loss = 1.0988311027334468, val loss = 1.3883833562102272


NameError: name 'test_loss_plotting' is not defined

In [83]:
def sample_from_transformer(tf, prompt, temperature, n_samples, seq_length):
    """Autoregressively sample ``n_samples`` tokens from ``tf``, seeded with ``prompt``.

    Args:
        tf: the model. It is called on a ``(1, context_len)`` tensor of token ids;
            the logits at ``[0, -1]`` score the next token. Must expose ``device``.
        prompt: non-empty seed text; only its last ``seq_length`` characters are
            used as the initial context.
        temperature: positive divisor applied to the logits (< 1 sharper, > 1 flatter).
        n_samples: number of tokens to generate.
        seq_length: maximum context length fed to the model (>= 1).

    Returns:
        list[int]: the generated token ids (the prompt itself is not included).

    Raises:
        ValueError: if ``prompt`` is empty, or ``temperature`` / ``seq_length``
            is not positive.
    """
    if not prompt:
        raise ValueError("prompt must be non-empty")
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if seq_length < 1:
        raise ValueError("seq_length must be at least 1")

    tf.eval()

    # Put the context on the model's own device (as train() does) instead of
    # guessing from CUDA availability.
    context = torch.tensor(to_tokens(prompt), dtype=torch.long, device=tf.device).unsqueeze(0)
    # A prompt longer than the context window used to stay too long forever,
    # because the old sliding step dropped only one token. Keep just the tail.
    context = context[:, -seq_length:]

    generated = []
    with torch.no_grad():
        for _ in range(n_samples):
            logits_last = tf(context)[0, -1] / temperature
            next_token = torch.distributions.Categorical(logits=logits_last).sample()
            generated.append(next_token.item())
            # Append the new token as a (1, 1) batch, then slide the window so at
            # most seq_length tokens remain.
            context = torch.cat((context, next_token.view(1, 1)), dim=1)[:, -seq_length:]

    return generated

In [90]:
# Persist only the learned weights (the state_dict), not the module object itself.
torch.save(obj=tf.state_dict(), f="tuned_model_2.pth")

In [127]:
# Generate 1000 characters continuing the prompt, with a 128-character context window.
gen = sample_from_transformer(
    tf,
    prompt="Gibt es Gott?\n",
    temperature=0.8,
    n_samples=1000,
    seq_length=128,
)

In [128]:
# Decode the sampled token ids back into text and show it.
print(from_tokens(gen), end="\n")


    Vor ihr Spruch und Wald und um die Geliebte
    Unberühren her.
    Auf deiner Augen und dazwischen Felsenschlicht
    Und gingen niederschwerer,
    Mit einer alten Kopf tief verrollt
    In einem Bilde des Marmors Bruder,
    Schwarm erlösen wohl über seine Rose allee
           Wie alles, warum deine frühe Sterne
                                         Aber durchbricht mit Quardianeen,
           die Gedanken, Bächlein, weiche ich's, wie es wogend durch die Schwalbe steht:
    Gewaltig war dieses Rosenlicht.
    Den Raum des Herzens war der Jahre schwoll,
    Ein Jugend betet sie mit allem die Nacht,
    Und da ging schon der Flur!
    Nicht mehr will ich erkennen.

    Ich will mich beschwingen
    Von der süßen Wegen
    Und halb im Hafenfedern herunter
    In den Hals und der Nebel führten
    Wie deine Hand das müttert.

    Und immer fragt die Nächte heben
    Was schloß ihr düsterrote Rosen!

    Nie selten schliefen sich.

    In das Sterne kühlt?
    Bei Stand meines K