In [1]:
from initialize import *
class RNN(Module):
    def __init__(self, input_dim, output_dim, hid_dim, activation='linear', return_hidden=False, return_last=False):
        super().__init__()  
        self.input_ff = nn.Linear(input_dim, hid_dim)
        self.hidden_ff = nn.Linear(hid_dim,hid_dim)
        self.output_ff = nn.Linear(hid_dim, output_dim)
        if activation == 'linear':
            self.activation = nn.Identity()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        else:
            raise Exception("Unknown activation type")
        self.hid_dim = hid_dim
        self.return_hidden = return_hidden
        self.return_last = return_last
    def forward(self, x, initial_hidden=None):
        
        #src = [batch size, input len, input dim]
        length = x.shape[1]
        batch_size = x.shape[0]

        hidden = []
        # Initial hidden state
        if initial_hidden is None:
            hidden.append(torch.zeros(batch_size, 1, self.hid_dim, dtype=x.dtype, device=x.device))
        else:
            hidden.append(initial_hidden)

        # input mapping
        x = self.input_ff(x)

        # recurrent relation
        for i in range(length):
            h_next = self.activation(x[:,i:i+1,:] + self.hidden_ff(hidden[i]))
            hidden.append(h_next)

        # Convert all hidden into a tensor
        hidden = torch.cat(hidden[1:], dim=1)

        # output mapping
        out = self.output_ff(hidden)[:,-1,:] if self.return_last else self.output_ff(hidden)

        if self.return_hidden:
            return out, hidden
        return out

In [20]:
d = torchtext.datasets.WikiText2()



In [81]:
def filt(string):
    if string.strip() == '':
        return False
    filt = ['=', '<', '>', '[']
    for s in string:
        if s in filt:
            return False
    return True


l = [string for string in list(iter(d[0]))+list(iter(d[1]))+list(iter(d[2])) if filt(string)]

In [82]:
with open('quotes.txt', 'w', encoding='utf-8') as f:
    f.write(''.join(l))

In [83]:
text = open('quotes.txt', 'r', encoding='utf-8').read()
import string
raw_text = ''.join([c for c in text.lower() if c in string.printable])

In [84]:
chars = sorted(list(set(raw_text)))
print('total chars:', len(chars))
char_int = {c: i for i, c in enumerate(chars)}
int_char = {i: c for i, c in enumerate(chars)}

total chars: 60


In [86]:
n_chars = len(raw_text)
n_vocab = len(chars)
seq_length = 100
dataX = []
dataY = []
for i in range(0, n_chars - seq_length, 1):
	seq_in = raw_text[i:i + seq_length]
	seq_out = raw_text[i + seq_length]
	dataX.append([char_int[char] for char in seq_in])
	dataY.append(char_int[seq_out])
n_patterns = len(dataX)

In [88]:
dataX.shape

AttributeError: 'list' object has no attribute 'shape'