In [51]:
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.autograd import Variable

In [52]:
# Vocabulary: a character's integer id is its position in string.printable.
all_characters = string.printable
n_characters = len(all_characters)

def char_tensor(string):
    """Encode a string as a 1-D long tensor of vocabulary indices.

    Args:
        string: text to encode. Note that this parameter name shadows the
            `string` module inside the function body.

    Returns:
        A tensor (wrapped in Variable) of length len(string) whose i-th entry
        is the index of string[i] in `all_characters`.

    Raises:
        ValueError: if a character is not present in `all_characters`.
    """
    indices = [all_characters.index(symbol) for symbol in string]
    return Variable(torch.tensor(indices, dtype=torch.long))

# Toy corpus: 'ab' repeated one, two and three times.
data = list(map(char_tensor, ('ab', 'abab', 'ababab')))

In [88]:
class CharLSTM(nn.Module):
    """Character-level model: embedding -> single-layer LSTM -> linear decoder.

    The recurrent state is kept on the instance as `self.hidden` and is
    overwritten by every call to `forward`, so consecutive calls continue the
    same sequence. Call `init_hidden()` and assign the result to `self.hidden`
    to start a new sequence.

    `self.hidden` is a plain attribute (not a parameter or buffer), so moving
    or casting the module does not convert it; re-run `init_hidden()` after
    changing the model's device or dtype.

    Args:
        char_dim: vocabulary size (number of embeddings and of output classes).
        hidden_dim: width of both the character embeddings and the LSTM state.
    """

    def __init__(self, char_dim, hidden_dim):
        super(CharLSTM, self).__init__()

        # The embedding width doubles as the LSTM input size.
        self.char_embeddings = nn.Embedding(char_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, char_dim)

        self.hidden_dim = hidden_dim
        self.hidden = self.init_hidden()

    def init_hidden(self):
        """
        Return fresh zero tensors usable as (h_0, c_0), each of shape
        (1, 1, hidden_dim).

        The tensors are allocated from the embedding weight so they share the
        model's current dtype and device instead of always being float32 on
        the CPU.
        """
        reference = self.char_embeddings.weight
        return (reference.new_zeros(1, 1, self.hidden_dim),
                reference.new_zeros(1, 1, self.hidden_dim))

    def forward(self, x):
        """Score the next character at every position of one sequence.

        Args:
            x: 1-D tensor of character indices (a single sequence).

        Returns:
            Tensor of shape (len(x), char_dim) holding log-probabilities over
            the vocabulary for each position.

        Side effects:
            Replaces `self.hidden` with the LSTM state after the last position.
        """
        embedded = self.char_embeddings(x)
        # nn.LSTM expects (seq_len, batch, features); the batch size is fixed at 1.
        output, self.hidden = self.lstm(embedded.view(len(embedded), 1, -1), self.hidden)
        # Drop the batch axis again before decoding to vocabulary scores.
        logits = self.decoder(output.view(len(output), -1))
        return F.log_softmax(logits, dim=1)
        

In [89]:
# Build a small model (hidden size 5) and train it with plain SGD to predict
# the next character of every sequence in `data`.
model = CharLSTM(n_characters, 5)
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(500):
    for sequence in data:
        # Sequences are independent: start from a zero state and clear the
        # gradients left over from the previous step.
        model.hidden = model.init_hidden()
        optimizer.zero_grad()

        # Inputs are all but the last character; targets are the same
        # sequence shifted one position to the left.
        inputs, targets = sequence[:-1], sequence[1:]
        scores = model(inputs)
        loss = loss_function(scores, targets)
        loss.backward()
        optimizer.step()

In [90]:
def greedy_sample(length, character='a'):
    """Generate text by repeatedly taking the model's most likely next character.

    Uses the module-level `model`, `char_tensor` and `all_characters`.

    Args:
        length: number of characters to generate after the seed.
        character: non-empty seed text. A seed longer than one character is
            fed to the model as a whole, and generation continues from the
            prediction made at its last position.

    Returns:
        The seed followed by `length` generated characters.

    Raises:
        ValueError: if `character` is empty (there is nothing to feed the
            model), or if it contains a character `char_tensor` cannot encode.

    Side effects:
        Resets `model.hidden` before sampling and leaves it holding the state
        reached after the last generated character.
    """
    if not character:
        raise ValueError("greedy_sample needs a non-empty seed string")

    model.hidden = model.init_hidden()
    values = character
    inputs = char_tensor(character)

    # Sampling never back-propagates; without no_grad the autograd graph would
    # keep growing through the carried hidden state on every step.
    with torch.no_grad():
        for step in range(length):
            _, predicted = model(inputs).max(dim=1)
            # Keep only the prediction for the final position, as a 1-element
            # tensor, so it can be fed straight back in as the next input.
            inputs = predicted[-1:]
            values += all_characters[int(inputs.item())]

    return values

In [91]:
# Generate 10 characters greedily, starting from the default seed character 'a'.
greedy_sample(10)

'abababababa'