Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
33 lines (25 sloc) 915 Bytes
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
# Copied from chainer examples code
class RNNForLM(chainer.Chain):
    """Recurrent net for language modeling.

    The network is a stack of four links applied in order: an ``EmbedID``
    lookup, two ``LSTM`` links and a ``Linear`` output link.  ``F.dropout``
    is applied to the input of each LSTM and of the output link.
    """

    def __init__(self, n_vocab, n_units):
        """Build the links and re-initialize every parameter.

        Args:
            n_vocab: First argument of the ``EmbedID`` link and output size
                of the final ``Linear`` link.
            n_units: Width shared by the embedding output, both LSTM links
                and the input of the final ``Linear`` link.
        """
        super(RNNForLM, self).__init__()
        with self.init_scope():
            self.embed = L.EmbedID(n_vocab, n_units)
            self.l1 = L.LSTM(n_units, n_units)
            self.l2 = L.LSTM(n_units, n_units)
            self.l3 = L.Linear(n_units, n_vocab)

        # Replace the values chosen by each link's own initializer with
        # samples drawn uniformly between -0.1 and 0.1, written in place.
        for weight in self.params():
            weight.data[...] = np.random.uniform(
                low=-0.1, high=0.1, size=weight.data.shape)

    def reset_state(self):
        """Reset the recurrent state held by both LSTM links."""
        for recurrent in (self.l1, self.l2):
            recurrent.reset_state()

    def __call__(self, x):
        """Run one step of the network on ``x`` and return the output.

        Args:
            x: Input handed directly to the ``EmbedID`` link (presumably
                integer token ids -- confirm against the caller).

        Returns:
            The result of the final ``Linear`` link.
        """
        hidden = self.embed(x)
        # Each remaining link receives the dropped-out output of the
        # previous one: l1, then l2, then the output projection l3.
        for layer in (self.l1, self.l2, self.l3):
            hidden = layer(F.dropout(hidden))
        return hidden