Skip to content

Commit

Permalink
Merge pull request #4355 from unnonouno/cudnn-rnn-weight-concat
Browse files Browse the repository at this point in the history
Extract a function that concatenates weight and bias matrixes for cuDNN
  • Loading branch information
mitmul committed Mar 7, 2018
2 parents ecf2292 + e2c3926 commit 9c8db99
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 141 deletions.
14 changes: 5 additions & 9 deletions chainer/functions/connection/n_step_gru.py
@@ -1,5 +1,3 @@
import itertools

import numpy

import chainer
Expand Down Expand Up @@ -296,18 +294,16 @@ def n_step_gru_base(n_layers, dropout_ratio, hx, ws, bs, xs,
states = get_random_state().create_dropout_states(dropout_ratio)
lengths = [len(x) for x in xs]
xs = chainer.functions.concat(xs, axis=0)
# flatten all input variables
inputs = tuple(itertools.chain(
(hx,),
itertools.chain.from_iterable(ws),
itertools.chain.from_iterable(bs),
(xs,)))

w = n_step_rnn.cudnn_rnn_weight_concat(
n_layers, states, use_bi_direction, 'gru', ws, bs)

if use_bi_direction:
rnn = NStepBiGRU
else:
rnn = NStepGRU

hy, ys = rnn(n_layers, states, lengths)(*inputs)
hy, ys = rnn(n_layers, states, lengths)(hx, w, xs)
sections = numpy.cumsum(lengths[:-1])
ys = chainer.functions.split_axis(ys, sections, 0)
return hy, ys
Expand Down
14 changes: 5 additions & 9 deletions chainer/functions/connection/n_step_lstm.py
@@ -1,5 +1,3 @@
import itertools

import numpy

import chainer
Expand Down Expand Up @@ -428,18 +426,16 @@ def n_step_lstm_base(
states = get_random_state().create_dropout_states(dropout_ratio)
lengths = [len(x) for x in xs]
xs = chainer.functions.concat(xs, axis=0)
# flatten all input variables
inputs = tuple(itertools.chain(
(hx, cx),
itertools.chain.from_iterable(ws),
itertools.chain.from_iterable(bs),
(xs,)))

w = n_step_rnn.cudnn_rnn_weight_concat(
n_layers, states, use_bi_direction, 'lstm', ws, bs)

if use_bi_direction:
rnn = NStepBiLSTM
else:
rnn = NStepLSTM

hy, cy, ys = rnn(n_layers, states, lengths)(*inputs)
hy, cy, ys = rnn(n_layers, states, lengths)(hx, cx, w, xs)
sections = numpy.cumsum(lengths[:-1])
ys = chainer.functions.split_axis(ys, sections, 0)
return hy, cy, ys
Expand Down

0 comments on commit 9c8db99

Please sign in to comment.