In [None]:
import torch
import torchvision
import os
import numpy
import collections

In [None]:
# load text
with open('poem.txt') as f:
    txt = f.read().splitlines() # create string array without newlines
    txt = [word.strip() for word in txt] # remove additional whitespaces
    txt = list(filter(None, txt)) # remove empty string objects
    txt = ' '.join(txt).lower() # create one string
    characters = sorted(set(txt)) # create list of all characters
    char_array = [ord(i) for i in characters] # convert characters into unicode representation
    D = len(char_array)
    C = {c:i for i,c in enumerate(char_array)}
    f.close()

In [None]:
def one_hot_for(b, depth=None):
    """One-hot encode an integer index array.

    Args:
        b: integer indices (torch tensor or numpy array), e.g. (batch, steps).
           Negative entries mark padding and produce an all-zero vector.
        depth: one-hot width; defaults to the global vocabulary size ``D``.

    Returns:
        float32 tensor of shape ``b.shape + (depth,)``.
    """
    if depth is None:
        depth = D
    idx = torch.as_tensor(b, dtype=torch.long)
    valid = idx >= 0
    # clamp padding to a legal class so one_hot accepts it, then zero those rows out
    encoded = torch.nn.functional.one_hot(idx.clamp(min=0), num_classes=depth).to(torch.float32)
    return encoded * valid.unsqueeze(-1).to(torch.float32)

In [None]:
class RNN(torch.nn.Module):
    """Single-layer recurrent network over one-hot character inputs.

    Recurrence: h_s = PReLU(W1 x_s + Wr h_{s-1}) with h_{-1} = 0;
    the logits at each step are W2 h_s.

    Args:
        context: number of time steps ``forward`` unrolls.
        D: input/output width (vocabulary size).
        hidden_size: width of the hidden state (default 1000).
    """

    def __init__(self, context, D, hidden_size=1000):
        # nn.Module must be initialised before any attribute is assigned on it
        super(RNN, self).__init__()
        self.context = context
        self.D = D
        self.W1 = torch.nn.Linear(in_features=D, out_features=hidden_size)
        self.W2 = torch.nn.Linear(in_features=hidden_size, out_features=D)
        self.Wr = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.activation = torch.nn.PReLU()

    def _init_hidden(self, x):
        """Zero hidden state for the batch ``x``, matching x's device and dtype."""
        return x.new_zeros(len(x), self.W1.out_features)

    def _step(self, x_s, h_s):
        """One recurrence step: returns the next hidden state."""
        return self.activation(self.W1(x_s) + self.Wr(h_s))

    def forward(self, x):
        """Return logits for each of the first ``self.context`` steps.

        Assumes x is (batch, steps, D) with steps >= self.context.
        Returns a (batch, self.context, D) tensor.
        """
        h_s = self._init_hidden(x)
        Z = []  # logits per step
        for s in range(self.context):
            h_s = self._step(x[:, s], h_s)
            Z.append(self.W2(h_s))
        return torch.stack(Z, dim=1)

    def predict(self, x):
        """Run over every step of ``x`` and return only the final-step logits, (batch, D)."""
        h_s = self._init_hidden(x)
        for s in range(x.shape[1]):
            h_s = self._step(x[:, s], h_s)
        return self.W2(h_s)

In [None]:
# data creation script
# Each sample is a window of `context` character indices (left-padded with -1 at the
# start of the text); its target is the same window shifted one character ahead.

batch_size = 256
learn_rate = 1e-3
context = 20
torch.manual_seed(0)  # reproducible weight init and batch shuffling

# init network
network = RNN(context, D)
loss = torch.nn.CrossEntropyLoss(ignore_index = -1)  # -1 marks padding targets
optimizer = torch.optim.SGD(network.parameters(), lr=learn_rate, momentum=0.9)

# create dataset arrays
data = collections.deque([-1] * context, maxlen=context)  # fixed context size
X, T = [], []
for char in txt:  # txt is one string, so this iterates its characters
    X.append(numpy.array(data))  # current input
    data.append(C[ord(char)])  # append current char
    T.append(numpy.array(data))  # current output

# create data tensors and data loader
# stacking into one ndarray first avoids torch's slow list-of-ndarrays conversion
DS = torch.utils.data.TensorDataset(
    torch.from_numpy(numpy.stack(X)).long(),
    torch.from_numpy(numpy.stack(T)).long(),
)
DL = torch.utils.data.DataLoader(DS, batch_size=batch_size, shuffle=True)

In [None]:
# training
for epoch in range(10): # train for 10 epochs
    total_loss = 0.
    for x, t in DL:
        optimizer.zero_grad()
        # forward pass: logits of shape (batch, context, D)
        z = network(one_hot_for(x))
        # objective: sum over time steps of the batch-mean cross-entropy at that step.
        # Steps whose targets are all padding (-1) are skipped: the mean over zero
        # valid targets is NaN and would poison the weights. The last step is
        # always a real character, so the list is never empty.
        J = torch.stack([
            loss(z[:, s], t[:, s]) for s in range(context) if (t[:, s] >= 0).any()
        ]).sum()
        J.backward()
        optimizer.step()
        # .item() detaches; accumulating the tensor itself would keep every batch's graph alive
        batch_loss = J.item()
        total_loss += batch_loss
        print(f"\rLoss: {batch_loss: 3.5f}", end="")  # J is already averaged over the batch
    print(f"\rEpoch: {epoch} -- Loss: {total_loss / len(DL)}")  # mean loss per batch
    torch.save(network.state_dict(), "text.model")

In [None]:
# test model

# load network again
network = RNN(context, D)
# the checkpoint is a plain state_dict, so refuse to unpickle arbitrary objects
network.load_state_dict(torch.load("text.model", weights_only=True))
network.eval()

samples = ("the ", "beau", "mothe", "bloo")

for seed in samples:
    text = seed
    with torch.no_grad():
        for _ in range(80):
            # build the same input seen in training: the last `context` characters,
            # left-padded with -1 (all-zero one-hot) when the text is shorter
            codes = [C[ord(s)] for s in text[-context:]]
            window = [-1] * (context - len(codes)) + codes
            x = one_hot_for(numpy.array([window]))  # batch of one: (1, context, D)
            z = network.predict(x)
            # argmax of the logits equals argmax of their softmax
            next_char = char_array[int(z.argmax(dim=1))]
            text += chr(next_char)
        print(f"{seed} -> {text}")