In [60]:
import torch
import torch.nn as nn
from torch.autograd import Variable

import numpy as np
import string
from tqdm import trange

# Character vocabulary for the model: every ASCII letter plus a handful of
# punctuation marks (space, period, comma, semicolon, apostrophe).
all_letters = "".join([string.ascii_letters, " .,;'"])

def read_data(data_path='data/HPserie/HPSerie', line_sep=' '):
    """Load the raw training corpus as one string.

    Args:
        data_path: Path to the plain-text corpus. The default is the original
            hard-coded location, so existing ``read_data()`` calls still work.
        line_sep: Text substituted for every line break. A space keeps the
            last word of one line from being glued onto the first word of the
            next line; pass ``''`` to remove line breaks entirely.

    Returns:
        The file contents with each ``'\\n'`` replaced by ``line_sep``.
    """
    # Decode explicitly as UTF-8 rather than with the platform's locale
    # encoding. Undecodable bytes are dropped: the notebook keeps only
    # characters from `all_letters` (all ASCII), and ASCII bytes decode the
    # same way regardless of what surrounds them.
    with open(data_path, 'r', encoding='utf-8', errors='ignore') as corpus_file:
        text = corpus_file.read()
    return text.replace('\n', line_sep)

def seq_gen(data, seq_len, vocab=None):
    """Generate (input, target) character windows for next-character training.

    The first value yielded is the number of examples in one pass over
    ``data`` (used by the caller as batches-per-epoch). After that the
    generator yields one example per ``next()`` call, cycling forever and
    reshuffling the example order after each full pass. If ``data`` is too
    short to form any example, the generator stops after yielding ``0``.

    Args:
        data: Sequence of single characters (a ``str`` or a list of chars).
        seq_len: Length of each input window; the target is the same window
            shifted one character to the right.
        vocab: Mapping from character to integer id. Defaults to the
            notebook-level ``char_to_idx``.

    Yields:
        First an ``int`` (examples per epoch), then tuples
        ``(X_txt, Y_txt, X_encoded, Y_encoded)`` where the encoded values are
        1-D ``torch.long`` tensors of equal, non-zero length.
    """
    if vocab is None:
        # Looked up when the generator first runs, because `char_to_idx` is
        # assigned after this function is defined.
        vocab = char_to_idx
    ex_len = seq_len + 1
    # Every example needs at least two characters (one input, one target);
    # a trailing one-character chunk would produce empty tensors.
    idxs = list(range(0, len(data) - 1, ex_len))
    np.random.shuffle(idxs)
    yield len(idxs)
    if not idxs:
        # Nothing to cycle over; looping here would spin forever without yielding.
        return
    while True:
        for idx in idxs:
            window = data[idx:idx + ex_len]
            X, Y = window[:-1], window[1:]
            X_encoded = torch.tensor([vocab[x] for x in X], dtype=torch.long)
            Y_encoded = torch.tensor([vocab[y] for y in Y], dtype=torch.long)
            yield ''.join(X), ''.join(Y), X_encoded, Y_encoded
        np.random.shuffle(idxs)
    

# Load the corpus and discard every character outside the model vocabulary.
data = list(filter(lambda ch: ch in all_letters, read_data()))

# Lookup tables in both directions between characters and integer ids.
char_to_idx = {ch: pos for pos, ch in enumerate(all_letters)}
idx_to_char = {pos: ch for ch, pos in char_to_idx.items()}

In [61]:
class RNN(nn.Module):
    """GRU language model that consumes exactly one token per forward call.

    An embedding maps a token id to a ``hidden_size`` vector, a GRU advances
    the hidden state by one step, and a linear layer maps the GRU output to
    ``output_size`` unnormalised scores (logits).
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        """
        Args:
            input_size: Number of embedding rows (vocabulary size).
            hidden_size: Width of the embedding and of the GRU hidden state.
            output_size: Number of logits produced per step.
            n_layers: Number of stacked GRU layers.
        """
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers

        self.encoder = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden):
        """Advance the model by a single token.

        Args:
            input: Tensor holding exactly one token id (e.g. a 0-dim long
                tensor); more than one id would not fit the ``view`` below.
            hidden: GRU state shaped ``(n_layers, 1, hidden_size)``, as
                returned by :meth:`init_hidden` or a previous call.

        Returns:
            ``(logits, hidden)``: logits shaped ``(1, output_size)`` and the
            updated hidden state.
        """
        embedded = self.encoder(input.view(1, -1))
        # GRU input layout is (seq_len=1, batch=1, features).
        output, hidden = self.gru(embedded.view(1, 1, -1), hidden)
        return self.decoder(output.view(1, -1)), hidden

    def init_hidden(self):
        """Return an all-zero initial GRU state of shape ``(n_layers, 1, hidden_size)``.

        The tensor is created on the same device and with the same dtype as
        the model's parameters, so it still works after ``model.cuda()`` or
        ``model.double()``.
        """
        weight = self.decoder.weight
        return torch.zeros(self.n_layers, 1, self.hidden_size,
                           device=weight.device, dtype=weight.dtype)

In [85]:
# Model and training hyper-parameters.
input_size = len(all_letters)  # one embedding row / one output logit per vocabulary character
hidden_size = 100
n_layers = 1

print_every = 50   # print the loss every `print_every` examples
plot_every = 100   # record the mean loss over windows of `plot_every` examples
epochs = 2

# Seed both RNGs the notebook draws from (torch for weight initialisation,
# numpy for the example shuffling in `seq_gen`) so Restart & Run All
# reproduces the same run.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

model = RNN(input_size, hidden_size, input_size, n_layers)
model_optimizer = torch.optim.Adam(model.parameters())

In [86]:
criterion = nn.CrossEntropyLoss()

def train(inp, target):
    """Run one optimisation step on a single character sequence.

    Feeds ``inp`` through the global ``model`` one character at a time,
    starting from a fresh hidden state. It sums the cross-entropy against
    ``target`` at every position, back-propagates, and steps the global
    ``model_optimizer``.

    Args:
        inp: 1-D tensor of input character ids.
        target: 1-D tensor of next-character ids, at least as long as ``inp``.

    Returns:
        The mean per-character loss as a Python float.

    Raises:
        ValueError: If ``inp`` is empty (there would be no loss to optimise).
    """
    seq_len = len(inp)
    if seq_len == 0:
        raise ValueError("train() needs a non-empty input sequence")

    hidden = model.init_hidden()
    model.zero_grad()
    loss = 0

    for c in range(seq_len):
        output, hidden = model(inp[c], hidden)
        # Slice instead of indexing so the target keeps shape (1,), matching
        # the (1, n_classes) logits; a 0-dim target is rejected by the loss.
        loss = loss + criterion(output, target[c:c + 1])

    loss.backward()
    model_optimizer.step()

    # `.item()` extracts the scalar; `loss.data[0]` is rejected for 0-dim
    # tensors by current PyTorch releases.
    return loss.item() / seq_len

In [87]:
seq_len = 200  # characters per training example

G = seq_gen(data, seq_len)
batch_per_epoch = next(G)  # the generator's first value is the number of examples per epoch

loss_avg = 0   # running sum of per-example losses since the last entry in `losses`
losses = []    # mean loss over each window of `plot_every` steps, for plotting
step = 0       # global step counter across all epochs

for epoch in range(1, epochs + 1):
    for b in range(batch_per_epoch):
        _, _, X_enc, Y_enc = next(G)
        loss = train(X_enc, Y_enc)
        loss_avg += loss
        step += 1

        if b % print_every == 0:
            progress = "{0:0=.2f}%".format(100 * float(b) / float(batch_per_epoch))
            print('Epoch Progress: ', progress, ' -- ' + 'Loss: ', loss)

        # Use the global step (not `b`) so each recorded value averages exactly
        # `plot_every` losses, even when a window spans an epoch boundary.
        if step % plot_every == 0:
            losses.append(loss_avg / plot_every)
            loss_avg = 0
    

Epoch Progress:  0.00%  -- Loss:  4.064524230957031
Epoch Progress:  0.16%  -- Loss:  2.7305538940429686
Epoch Progress:  0.31%  -- Loss:  2.2909014892578123
Epoch Progress:  0.47%  -- Loss:  2.3305950927734376
Epoch Progress:  0.63%  -- Loss:  2.3104985046386717
Epoch Progress:  0.79%  -- Loss:  2.243597412109375
Epoch Progress:  0.94%  -- Loss:  2.2456619262695314
Epoch Progress:  1.10%  -- Loss:  2.458329315185547
Epoch Progress:  1.26%  -- Loss:  1.9930632019042969
Epoch Progress:  1.42%  -- Loss:  2.1697654724121094
Epoch Progress:  1.57%  -- Loss:  2.1154826354980467
Epoch Progress:  1.73%  -- Loss:  1.9758218383789063
Epoch Progress:  1.89%  -- Loss:  2.047470703125
Epoch Progress:  2.05%  -- Loss:  1.8630349731445313
Epoch Progress:  2.20%  -- Loss:  2.0472833251953126
Epoch Progress:  2.36%  -- Loss:  1.9806504821777344
Epoch Progress:  2.52%  -- Loss:  2.067671356201172
Epoch Progress:  2.68%  -- Loss:  1.9888923645019532
Epoch Progress:  2.83%  -- Loss:  1.8628021240234376


KeyboardInterrupt: 