From e2b122cbb746fda44ac6743a032ae7b0d3de6961 Mon Sep 17 00:00:00 2001 From: vrxacs Date: Fri, 31 Aug 2018 17:05:13 -0500 Subject: [PATCH] NAC fix --- nalu/__init__.py | 2 +- nalu/core/nac_cell.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nalu/__init__.py b/nalu/__init__.py index 9553ed0..287407f 100644 --- a/nalu/__init__.py +++ b/nalu/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.0.2' +__version__ = '0.0.3' from .core import * from .layers import * diff --git a/nalu/core/nac_cell.py b/nalu/core/nac_cell.py index 662f287..00c5253 100644 --- a/nalu/core/nac_cell.py +++ b/nalu/core/nac_cell.py @@ -19,9 +19,9 @@ def __init__(self, in_shape, out_shape): self.out_shape = out_shape self.W_ = Parameter(Tensor(out_shape, in_shape)) self.M_ = Parameter(Tensor(out_shape, in_shape)) - self.W = Parameter(tanh(self.W_) * sigmoid(self.M_)) xavier_uniform_(self.W_), xavier_uniform_(self.M_) self.register_parameter('bias', None) - def forward(self, input): - return linear(input, self.W, self.bias) + def forward(self, input): + W = tanh(self.W_) * sigmoid(self.M_) + return linear(input, W, self.bias)