Implement new style of LSTM function
unnonouno committed Aug 21, 2017
1 parent 89f6e79 commit 77412e1
Showing 2 changed files with 179 additions and 17 deletions.
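This commit migrates LSTM from the old function.Function interface to the new-style function_node.FunctionNode interface and adds an explicit LSTMGrad function, so the backward pass itself builds a differentiable graph and second-order gradients become available. A minimal usage sketch follows (illustrative, not part of the diff; it assumes a Chainer version that ships this new-style API and chainer.grad with enable_double_backprop, and the shapes mirror the tests below):

import numpy
import chainer
from chainer import functions as F

# The four gate pre-activations (a, i, f, o) are packed along the second axis
# of x, so x has 4 * n_units channels while c_prev has n_units channels.
c_prev = chainer.Variable(
    numpy.random.uniform(-1, 1, (3, 2, 4)).astype(numpy.float32))
x = chainer.Variable(
    numpy.random.uniform(-1, 1, (3, 8, 4)).astype(numpy.float32))

# The public API is unchanged; internally lstm() now calls LSTM().apply(...).
c_next, h = F.lstm(c_prev, x)

# Because backward() now returns graph nodes (via LSTMGrad), the gradient
# itself can be differentiated again.
loss = F.sum(h)
gx, = chainer.grad([loss], [x], enable_double_backprop=True)
F.sum(gx * gx).backward()  # second-order gradient flows through LSTMGrad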
120 changes: 104 additions & 16 deletions chainer/functions/activation/lstm.py
@@ -3,6 +3,7 @@

from chainer import cuda
from chainer import function
from chainer import function_node
from chainer.utils import type_check


@@ -11,19 +12,27 @@ def _extract_gates(x):
return [r[:, :, i] for i in six.moves.range(4)]


def _sigmoid(x):
def _sigmoid(x, xp=numpy):
half = x.dtype.type(0.5)
return numpy.tanh(x * half) * half + half
return xp.tanh(x * half) * half + half


def _grad_sigmoid(x):
return x * (1 - x)


def _grad_grad_sigmoid(x):
return x * (1 - x) * (1 - 2 * x)


def _grad_tanh(x):
return 1 - x * x


def _grad_grad_tanh(x, gx):
return -2 * x * gx


_preamble = '''
template <typename T> __device__ T sigmoid(T x) {
const T half = 0.5;
@@ -40,7 +49,7 @@ def _grad_tanh(x):
'''


class LSTM(function.Function):
class LSTM(function_node.FunctionNode):

"""Long short-term memory unit with forget gate.
@@ -68,6 +77,7 @@ def check_type_forward(self, in_types):
type_check.expect(x_type.shape[i] == c_type.shape[i])

def forward(self, inputs):
self.retain_inputs((0, 1))
c_prev, x = inputs
a, i, f, o = _extract_gates(x)
batch = len(x)
@@ -96,13 +106,21 @@ def forward(self, inputs):

c_next[batch:] = c_prev[batch:]
self.c = c_next[:batch]
self.retain_outputs((0,))
return c_next, h

def backward(self, inputs, grad_outputs):
def backward(self, indexes, grads):
grad_inputs = (
self.get_retained_inputs() + self.get_retained_outputs() + grads)
return LSTMGrad()(*grad_inputs)


class LSTMGrad(function.Function):

def forward(self, inputs):
xp = cuda.get_array_module(*inputs)
c_prev, x = inputs
c_prev, x, c_next, gc, gh = inputs
batch = len(x)
gc, gh = grad_outputs

gx = xp.empty_like(x)
ga, gi, gf, go = _extract_gates(gx)
@@ -117,20 +135,25 @@ def backward(self, inputs, grad_outputs):
if gh is None:
gh = 0

a, i, f, o = _extract_gates(x)
if xp is numpy:
co = numpy.tanh(self.c)
tanh_a = xp.tanh(a)
sig_i = _sigmoid(i)
sig_f = _sigmoid(f)
sig_o = _sigmoid(o)

co = numpy.tanh(c_next[:batch])
gc_prev = numpy.empty_like(c_prev)
# multiply f later
gc_prev[:batch] = gh * self.o * _grad_tanh(co) + gc_update
gc_prev[:batch] = gh * sig_o * _grad_tanh(co) + gc_update
gc = gc_prev[:batch]
ga[:] = gc * self.i * _grad_tanh(self.a)
gi[:] = gc * self.a * _grad_sigmoid(self.i)
gf[:] = gc * c_prev[:batch] * _grad_sigmoid(self.f)
go[:] = gh * co * _grad_sigmoid(self.o)
gc_prev[:batch] *= self.f # multiply f here
ga[:] = gc * sig_i * _grad_tanh(tanh_a)
gi[:] = gc * tanh_a * _grad_sigmoid(sig_i)
gf[:] = gc * c_prev[:batch] * _grad_sigmoid(sig_f)
go[:] = gh * co * _grad_sigmoid(sig_o)
gc_prev[:batch] *= sig_f # multiply f here
gc_prev[batch:] = gc_rest
else:
a, i, f, o = _extract_gates(x)
gc_prev = xp.empty_like(c_prev)
cuda.elementwise(
'T c_prev, T c, T gc, T gh, T a, T i_, T f, T o',
@@ -146,12 +169,77 @@
gc_prev = temp * af;
''',
'lstm_bwd', preamble=_preamble)(
c_prev[:batch], self.c, gc_update, gh, a, i, f, o,
c_prev[:batch], c_next[:batch], gc_update, gh, a, i, f, o,
gc_prev[:batch], ga, gi, gf, go)
gc_prev[batch:] = gc_rest

return gc_prev, gx

def backward(self, inputs, grads):
xp = cuda.get_array_module(*inputs)

c_prev, x, c, gc, gh = inputs
ggc_prev, ggx = grads

gc_prev = xp.zeros_like(c_prev)
gx = xp.zeros_like(x)
gc_next = xp.zeros_like(c)
ggc = ggc_prev.copy()
ggh = xp.zeros_like(gh)

batch = len(x)
c_prev = c_prev[:batch]
c = c[:batch]
gc = gc[:batch]
ggc_prev = ggc_prev[:batch]
ggx = ggx[:batch]

a, i, f, o = _extract_gates(x)
gga, ggi, ggf, ggo = _extract_gates(ggx)

ga, gi, gf, go = _extract_gates(gx)

sig_o = _sigmoid(o, xp)
gsig_o = _grad_sigmoid(sig_o)
ggsig_o = _grad_grad_sigmoid(sig_o)
sig_i = _sigmoid(i, xp)
gsig_i = _grad_sigmoid(sig_i)
ggsig_i = _grad_grad_sigmoid(sig_i)
sig_f = _sigmoid(f, xp)
gsig_f = _grad_sigmoid(sig_f)
ggsig_f = _grad_grad_sigmoid(sig_f)
tanh_a = xp.tanh(a)
gtanh_a = _grad_tanh(tanh_a)
ggtanh_a = _grad_grad_tanh(tanh_a, gtanh_a)
tanh_c = xp.tanh(c)
gtanh_c = _grad_tanh(tanh_c)
ggtanh_c = _grad_grad_tanh(tanh_c, gtanh_c)

gc_bar = gh * sig_o * gtanh_c + gc

gc_prev[:batch] = ggf * gc_bar * gsig_f
ga[:] = (gga * sig_i * ggtanh_a +
ggi * gtanh_a * gsig_i) * gc_bar
gi[:] = (gga * gtanh_a * gsig_i +
ggi * tanh_a * ggsig_i) * gc_bar
gf[:] = (ggc_prev * (gh * sig_o * gtanh_c + gc) * gsig_f +
ggf * gc_bar * c_prev[:batch] * ggsig_f)

ggc_this = (
ggc_prev * sig_f +
gga * sig_i * gtanh_a +
ggi * tanh_a * gsig_i +
ggf * c_prev[:batch] * gsig_f)
ggc[:batch] = ggc_this

dgc_do = gh * gsig_o * gtanh_c
go[:] = ggc_this * dgc_do + ggo * gh * tanh_c * ggsig_o
dgc_dc = gh * sig_o * ggtanh_c
gc_next[:batch] = ggc_this * dgc_dc + ggo * gh * gtanh_c * gsig_o
ggh[:batch] = ggc_this * sig_o * gtanh_c + ggo * tanh_c * gsig_o

return gc_prev, gx, gc_next, ggc, ggh


def lstm(c_prev, x):
"""Long Short-Term Memory units as an activation function.
@@ -251,4 +339,4 @@ def lstm(c_prev, x):
*input array*.
"""
return LSTM()(c_prev, x)
return LSTM().apply((c_prev, x))
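For reference (not part of the diff), these are the gate equations the code above implements, writing σ for the sigmoid and splitting x into the gate pre-activations a, i, f, o; LSTM.forward computes the first line (c_next and h) and LSTMGrad.forward the remaining first-order gradients:

\begin{aligned}
c_{\text{next}} &= \sigma(i)\tanh(a) + \sigma(f)\,c_{\text{prev}}, &
h &= \sigma(o)\tanh(c_{\text{next}}),\\
\bar{g}_c &= g_h\,\sigma(o)\bigl(1-\tanh^2(c_{\text{next}})\bigr) + g_c, &
g_{c_{\text{prev}}} &= \bar{g}_c\,\sigma(f),\\
g_a &= \bar{g}_c\,\sigma(i)\bigl(1-\tanh^2(a)\bigr), &
g_i &= \bar{g}_c\tanh(a)\,\sigma(i)\bigl(1-\sigma(i)\bigr),\\
g_f &= \bar{g}_c\,c_{\text{prev}}\,\sigma(f)\bigl(1-\sigma(f)\bigr), &
g_o &= g_h\tanh(c_{\text{next}})\,\sigma(o)\bigl(1-\sigma(o)\bigr).
\end{aligned}

LSTMGrad.backward differentiates these expressions once more, using the new helpers _grad_grad_sigmoid and _grad_grad_tanh. The second changed file, shown next, extends the LSTM tests with double-backward checks and adds a dedicated TestLSTMGrad case.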
@@ -5,6 +5,7 @@
import chainer
from chainer import cuda
from chainer import functions
from chainer.functions.activation import lstm
from chainer import gradient_check
from chainer import testing
from chainer.testing import attr
@@ -36,6 +37,9 @@ def setUp(self):
self.gc = numpy.random.uniform(-1, 1, hidden_shape).astype(self.dtype)
self.gh = numpy.random.uniform(-1, 1, y_shape).astype(self.dtype)

self.ggc = numpy.random.uniform(-1, 1, hidden_shape).astype(self.dtype)
self.ggx = numpy.random.uniform(-1, 1, x_shape).astype(self.dtype)

self.check_forward_options = {}
self.check_backward_options = {'dtype': numpy.float64}
if self.dtype == numpy.float16:
@@ -96,7 +100,7 @@ def test_flat_forward_gpu(self):

def check_backward(self, c_prev_data, x_data, c_grad, h_grad):
gradient_check.check_backward(
functions.LSTM(),
functions.lstm,
(c_prev_data, x_data), (c_grad, h_grad),
**self.check_backward_options)

@@ -166,5 +170,75 @@ def test_flat_no_gh_backward_gpu(self):
self.flat()
self.test_no_gh_backward_gpu()

def check_double_backward(
self, c_prev_data, x_data, gc_data, gh_data, ggc_prev_data,
ggx_data):
gradient_check.check_double_backward(
chainer.functions.lstm, (c_prev_data, x_data),
(gc_data, gh_data), (ggc_prev_data, ggx_data), dtype='d')

@condition.retry(3)
def test_double_backward_cpu(self):
self.check_double_backward(
self.c_prev, self.x, self.gc, self.gh, self.ggc, self.ggx)

@attr.gpu
@condition.retry(3)
def test_double_backward_gpu(self):
self.check_double_backward(
cuda.to_gpu(self.c_prev), cuda.to_gpu(self.x),
cuda.to_gpu(self.gc), cuda.to_gpu(self.gh),
cuda.to_gpu(self.ggc), cuda.to_gpu(self.ggx))


@testing.parameterize(*(testing.product({
'batch': [3, 2, 0],
'dtype': [numpy.float32],
}) + testing.product({
'batch': [3],
'dtype': [numpy.float16, numpy.float32, numpy.float64],
})))
class TestLSTMGrad(unittest.TestCase):

def setUp(self):
hidden_shape = (3, 2, 4)
x_shape = (self.batch, 8, 4)
y_shape = (self.batch, 2, 4)
self.c_prev = numpy.random.uniform(
-1, 1, hidden_shape).astype(self.dtype)
self.x = numpy.random.uniform(-1, 1, x_shape).astype(self.dtype)
self.c_next = numpy.random.uniform(
-1, 1, hidden_shape).astype(self.dtype)

self.gc = numpy.random.uniform(-1, 1, hidden_shape).astype(self.dtype)
self.gh = numpy.random.uniform(-1, 1, y_shape).astype(self.dtype)

self.ggc_prev = numpy.random.uniform(
-1, 1, hidden_shape).astype(self.dtype)
self.ggx = numpy.random.uniform(-1, 1, x_shape).astype(self.dtype)

def check_backward(
self, c_prev_data, x_data, c_next_data, gc_data, gh_data,
ggc_prev_data, ggx_data):
gradient_check.check_backward(
lstm.LSTMGrad(),
(c_prev_data, x_data, c_next_data, gc_data, gh_data),
(ggc_prev_data, ggx_data), dtype='d', atol=1e-3, rtol=1e-3)

@condition.retry(3)
def test_backward_cpu(self):
self.check_backward(
self.c_prev, self.x, self.c_next, self.gc,
self.gh, self.ggc_prev, self.ggx)

@attr.gpu
@condition.retry(3)
def test_backward_gpu(self):
self.check_backward(
cuda.to_gpu(self.c_prev), cuda.to_gpu(self.x),
cuda.to_gpu(self.c_next),
cuda.to_gpu(self.gc), cuda.to_gpu(self.gh),
cuda.to_gpu(self.ggc_prev), cuda.to_gpu(self.ggx))


testing.run_module(__name__, __file__)
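Condensed into a standalone sketch (illustrative, not part of the commit), the new double-backward check amounts to a single call, with shapes taken from setUp above:

import numpy
from chainer import functions
from chainer import gradient_check

batch, units = 3, 2
c_prev = numpy.random.uniform(-1, 1, (batch, units, 4)).astype(numpy.float32)
x = numpy.random.uniform(-1, 1, (batch, 4 * units, 4)).astype(numpy.float32)
gc = numpy.random.uniform(-1, 1, (batch, units, 4)).astype(numpy.float32)
gh = numpy.random.uniform(-1, 1, (batch, units, 4)).astype(numpy.float32)
ggc = numpy.random.uniform(-1, 1, (batch, units, 4)).astype(numpy.float32)
ggx = numpy.random.uniform(-1, 1, (batch, 4 * units, 4)).astype(numpy.float32)

# Checks the analytic second-order gradients produced via LSTMGrad.backward
# against finite differences of the first-order gradients of functions.lstm.
gradient_check.check_double_backward(
    functions.lstm, (c_prev, x), (gc, gh), (ggc, ggx), dtype='d')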
