In [1]:
import os
import time
import numpy as np

import torch
import torch.nn as nn

from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm

In [2]:
class CharRNN(nn.Module):
    """Character-level LSTM model that processes one time step per call.

    Token indices are looked up in an ``nn.Embedding``, fed through a stacked
    ``nn.LSTM`` as a length-1 sequence, and projected by an ``nn.Linear``.

    Args:
        input_size: Number of rows in the embedding table (input vocabulary).
        hidden_size: Embedding width; also the LSTM input and hidden width.
        output_size: Output width of the final linear projection.
        n_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super().__init__()
        # Constructor arguments are kept as plain attributes.
        self.input_size, self.hidden_size = input_size, hidden_size
        self.output_size, self.n_layers = output_size, n_layers

        # Sub-modules. Creation order (encoder -> rnn -> decoder) fixes both
        # the parameter registration order and the RNG draws at init time.
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.LSTM(hidden_size, hidden_size, n_layers)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Advance the LSTM by one step for a batch of tokens.

        Args:
            x: Tensor of token indices whose first dimension is the batch.
                NOTE(review): the reshape to ``(1, batch, -1)`` only matches
                the LSTM input width when each batch entry is a single token,
                i.e. ``x`` is ``(batch,)`` or ``(batch, 1)``.
            hidden: LSTM state, passed unchanged to ``self.rnn``.

        Returns:
            ``(logits, hidden)``: the decoder output reshaped to
            ``(batch, -1)`` and the updated LSTM state.
        """
        n = x.size(0)
        # nn.LSTM is not batch_first, so dim 0 is the sequence (length 1).
        step = self.encoder(x).view(1, n, -1)
        out, hidden = self.rnn(step, hidden)
        return self.decoder(out.view(n, -1)), hidden

    def forward2(self, x, hidden):
        """Advance the LSTM by one step for a single token.

        ``x`` is flattened with ``view(1, -1)`` before lookup, so it may be a
        0-d tensor. It must hold exactly one index for the LSTM input width to
        match.

        Args:
            x: Tensor containing one token index.
            hidden: LSTM state, passed unchanged to ``self.rnn``.

        Returns:
            ``(logits, hidden)`` with logits of shape ``(1, output_size)``.
        """
        emb = self.encoder(x.view(1, -1)).view(1, 1, -1)
        out, hidden = self.rnn(emb, hidden)
        logits = self.decoder(out.view(1, -1))
        return logits, hidden