Skip to content

Commit

Permalink
Fix jax lstm
Browse files Browse the repository at this point in the history
  • Loading branch information
honnibal committed Jan 18, 2020
1 parent 25ab53d commit c4cab3f
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions thinc/backends/jax_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from .ops import Ops
import numpy
from ..types import Array, Array2d, Array1d, ArrayT, DTypes, Array3d, Wrapper
Expand Down Expand Up @@ -635,16 +636,22 @@ def recurrent_lstm(W, b, cell, hidden, X):
(W, b, hidden, cell, X), (Y, gates) = state
return Y, cell, gates


@jax_jit()
def _lstm_stepper(t, state):
(W, b, hidden, cell, X), (Y, gates) = state
next_hiddens, next_cells, next_gates = lstm(W, b, cell, hidden, X[t])
next_hiddens, next_cells, next_gates = lstm(W, b, hidden, cell, X[t])
Y = index_update(Y, index[t], next_hiddens)
gates = index_update(gates, index[t], next_gates)
return (W, b, next_hiddens, next_cells, X), (Y, gates)

@jax_jit()
def lstm(W, b, cell_tm1, hidden_tm1, inputs):
def lstm(W, b, hidden_tm1, cell_tm1, inputs):
# Support usage where the inputs are 1d
shapes = (hidden_tm1.shape, cell_tm1.shape, hidden_tm1.shape + (4,))
cell_tm1 = cell_tm1.reshape((-1, cell_tm1.shape[-1]))
hidden_tm1 = cell_tm1.reshape((-1, hidden_tm1.shape[-1]))
inputs = inputs.reshape((-1, inputs.shape[-1]))
xp = jax.numpy
X = xp.hstack((inputs, hidden_tm1))
acts = xp.dot(X, W.T) + b
Expand All @@ -661,7 +668,9 @@ def lstm(W, b, cell_tm1, hidden_tm1, inputs):
cells = (hf * cell_tm1) + (hi * hc)
hiddens = xp.tanh(cells) * ho
gates = xp.concatenate((hf, hi, ho, hc), axis=-1)
gates = acts.reshape((acts.shape[0], acts.shape[1], 4))
hiddens = hiddens.reshape(shapes[0])
cells = cells.reshape(shapes[1])
gates = gates.reshape(shapes[2])
return hiddens, cells, gates #(hf, hi, ho, hc)


Expand Down

0 comments on commit c4cab3f

Please sign in to comment.