In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import data

In [2]:
# Text file handed to data.Corpus as its first argument.
path = './data/2013.txt'
# Flag forwarded to data.Corpus as its second argument.
verbose = False
corpus = data.Corpus(path, verbose)

In [3]:
class Model(nn.Module):
    """Recurrent language model: embedding -> multi-layer RNN -> linear decoder.

    Args:
        vocab_size: number of rows in the embedding table and width of the
            decoder output.
        embed_size: dimensionality of each embedding vector.
        nhidden: hidden-state size of every RNN layer.
        nlayers: number of stacked RNN layers.
    """

    def __init__(self, vocab_size, embed_size, nhidden, nlayers):
        super(Model, self).__init__()
        self.encoder = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, nhidden, nlayers)
        self.decoder = nn.Linear(nhidden, vocab_size)

    def forward(self, x, h0=None):
        """Run the model on one time step or on a whole sequence.

        Args:
            x: integer index tensor. A ``(batch,)`` tensor is treated as a
                single time step (a leading sequence axis of length 1 is
                added, as before); a ``(seq_len, batch)`` tensor is consumed
                as a full sequence.
            h0: initial hidden state of shape ``(nlayers, batch, nhidden)``.
                When ``None`` the RNN starts from an all-zero state.

        Returns:
            A ``(y, h1)`` tuple: ``y`` holds the decoder output for every
            time step with shape ``(seq_len, batch, vocab_size)`` and ``h1``
            is the final hidden state, shaped like ``h0``.
        """
        y = self.encoder(x)
        # The RNN is sequence-first (batch_first is left at its default), so
        # only add the sequence axis when the caller did not supply one.
        if y.dim() < 3:
            y = y.unsqueeze(0)
        y, h1 = self.rnn(y, h0)
        y = self.decoder(y)
        return y, h1

In [4]:
# Seed PyTorch's RNG so the randomly initialised weights -- and therefore the
# tensors displayed further down -- are identical on every Restart & Run All.
torch.manual_seed(0)

# Model hyper-parameters.
vocab_size = len(corpus)  # used for both the embedding table and the decoder width
embed_size = 200
nhidden = 10
nlayers = 2
lm = Model(vocab_size, embed_size, nhidden, nlayers)
lm

Model(
  (encoder): Embedding(27738, 200)
  (rnn): RNN(200, 10, num_layers=2)
  (decoder): Linear(in_features=10, out_features=27738, bias=True)
)

In [5]:
# Smoke test: push one hand-built batch of indices through the model.
# TODO: replace this with a loop over corpus.data (the original draft started one).
x0 = torch.tensor([0, 1, 2, 3])
# Derived from x0 so the hidden state can never disagree with the batch.
batch_size = x0.size(0)
h0 = torch.zeros((nlayers, batch_size, nhidden))
x1, h1 = lm(x0, h0)

In [6]:
# Decoder output returned by lm(x0, h0): one row of vocabulary scores per batch element.
x1

tensor([[[ 0.2761, -0.1069, -0.9166,  ..., -0.1119, -0.1515, -0.2233],
         [ 0.7376, -0.0175, -0.8107,  ..., -0.1473, -0.5412,  0.0653],
         [ 0.3830, -0.2736, -0.8513,  ..., -0.2904, -0.4090, -0.2455],
         [ 0.3308, -0.3808, -0.9603,  ..., -0.7134, -0.5324,  0.3336]]],
       grad_fn=<AddBackward0>)

In [7]:
# Final RNN hidden state returned by lm(x0, h0); same layout as h0, i.e. (nlayers, batch_size, nhidden).
h1

tensor([[[-0.9361,  0.9994, -0.5205,  0.9780,  0.8815, -0.8015,  0.9678,
           0.9896,  0.9612, -0.0756],
         [ 0.2070,  0.9935, -0.9218,  0.2652, -0.2870, -0.9832, -0.9905,
          -0.9987,  0.8509, -0.2697],
         [-0.6410,  0.9526,  0.7183, -0.6642,  0.2388, -0.9981,  0.9945,
          -0.0747,  0.8875, -0.9931],
         [ 0.9949,  0.2329, -0.8983, -0.6333,  0.9142,  0.6052,  0.9988,
          -0.9907,  0.9464, -0.9994]],

        [[ 0.7965, -0.4333,  0.2532,  0.6911, -0.2332,  0.2352, -0.1007,
           0.4840, -0.7294, -0.1737],
         [-0.3959, -0.3968, -0.2634,  0.1498, -0.6766,  0.6409, -0.5827,
           0.1585, -0.4506, -0.0683],
         [ 0.7903, -0.5223,  0.3608,  0.7237, -0.6363,  0.1969, -0.0592,
           0.5547, -0.0294, -0.2775],
         [ 0.0812, -0.7990, -0.2480, -0.0617, -0.6793,  0.6248, -0.5908,
          -0.4095,  0.7514,  0.6860]]], grad_fn=<StackBackward>)