Skip to content

Commit

Permalink
refactoring & some changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bharathgs committed Aug 5, 2018
1 parent bd49bd4 commit 98e3a55
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 28 deletions.
12 changes: 11 additions & 1 deletion .gitignore
Expand Up @@ -105,4 +105,14 @@ venv.bak/
.mypy_cache/

#nohup
nohup.out
nohup.out

#others
notebooks/*
LICENSE-MIT
pyproject.toml
tox.ini
tests/*
.coveragerc
MANIFEST.in
setup.py
4 changes: 4 additions & 0 deletions nalu/__init__.py
@@ -0,0 +1,4 @@
__version__ = '0.0.1'

from .core import *
from .layers import *
2 changes: 2 additions & 0 deletions nalu/core/__init__.py
@@ -0,0 +1,2 @@
from .nac_cell import NacCell
from .nalu_cell import NaluCell
21 changes: 9 additions & 12 deletions NeuralAccumulator.py → nalu/core/nac_cell.py
@@ -1,27 +1,24 @@
import torch

from math import sqrt
from torch import Tensor, exp, log, nn
from torch import Tensor, nn
from torch.nn.parameter import Parameter
from torch.nn.init import xavier_uniform_
from torch.nn.functional import tanh, sigmoid, linear


class NeuralAccumulator(nn.Module):
class NacCell(nn.Module):
"""Basic NAC unit implementation
from https://arxiv.org/pdf/1808.00508.pdf
"""

def __init__(self, inputs, outputs):
def __init__(self, in_shape, out_shape):
"""
inputs: input sample size
outputs: output sample size
in_shape: input sample dimension
out_shape: output sample dimension
"""
super().__init__()
self.inputs = inputs
self.outputs = outputs
self.W_ = Parameter(Tensor(outputs, inputs))
self.M_ = Parameter(Tensor(outputs, inputs))
self.in_shape = in_shape
self.out_shape = out_shape
self.W_ = Parameter(Tensor(out_shape, in_shape))
self.M_ = Parameter(Tensor(out_shape, in_shape))
self.W = Parameter(tanh(self.W_) * sigmoid(self.M_))
xavier_uniform_(self.W_), xavier_uniform_(self.M_)
self.register_parameter('bias', None)
Expand Down
27 changes: 12 additions & 15 deletions NALU.py → nalu/core/nalu_cell.py
@@ -1,32 +1,29 @@
import torch

from math import sqrt
from torch import Tensor, exp, log, nn
from torch.nn.parameter import Parameter
from torch.nn.init import xavier_uniform_
from torch.nn.functional import tanh, sigmoid, linear
from NeuralAccumulator import NeuralAccumulator
from torch.nn.functional import sigmoid, linear
from .nac_cell import NacCell


class NALU(nn.Module):
class NaluCell(nn.Module):
"""Basic NALU unit implementation
from https://arxiv.org/pdf/1808.00508.pdf
"""

def __init__(self, inputs, outputs):
def __init__(self, in_shape, out_shape):
"""
inputs: input sample size
outputs: output sample size
in_shape: input sample dimension
out_shape: output sample dimension
"""
super().__init__()
self.inputs = inputs
self.outputs = outputs
self.G = Parameter(Tensor(outputs, inputs))
self.W = Parameter(Tensor(outputs, inputs))
self.nac = NeuralAccumulator(outputs, inputs)
self.in_shape = in_shape
self.out_shape = out_shape
self.G = Parameter(Tensor(out_shape, in_shape))
self.W = Parameter(Tensor(out_shape, in_shape))
self.nac = NacCell(out_shape, in_shape)
xavier_uniform_(self.G), xavier_uniform_(self.W)
self.eps = 1e-5
self.register_parameter('bias', None)
xavier_uniform_(self.G), xavier_uniform_(self.W)

def forward(self, input):
a = self.nac(input)
Expand Down
1 change: 1 addition & 0 deletions nalu/layers/__init__.py
@@ -0,0 +1 @@
from .nalu_layer import NaluLayer
18 changes: 18 additions & 0 deletions nalu/layers/nalu_layer.py
@@ -0,0 +1,18 @@
from torch.nn import Sequential
from torch import nn
from nalu.core.nalu_cell import NaluCell


class NaluLayer(nn.Module):
def __init__(self, input_shape, output_shape, n_layers, hidden_shape):
super().__init__()
self.input_shape = input_shape
self.output_shape = output_shape
self.n_layers = n_layers
self.hidden_shape = hidden_shape
layers = [NaluCell(hidden_shape if n > 0 else input_shape,
hidden_shape if n < n_layers - 1 else output_shape) for n in range(n_layers)]
self.model = Sequential(*layers)

def forward(self, data):
return self.model(data)
3 changes: 3 additions & 0 deletions requirements.txt
@@ -0,0 +1,3 @@
torch
numpy

0 comments on commit 98e3a55

Please sign in to comment.