Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduction in memory requirements: Add SplitInitializer for separate initialization #4

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
88 changes: 53 additions & 35 deletions rnn_cell.py
Expand Up @@ -10,7 +10,7 @@
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import tf_logging as logging

from tensorflow.python.ops.init_ops import Initializer

import tensorflow as tf
from tensorflow.python.layers import base as base_layer
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(self, num_units, depth, forget_bias=1.0,
self._gate_activation = gate_activation or math_ops.sigmoid
self._cell_activation = cell_activation or array_ops.identity
self._initializer = initializer or init_ops.orthogonal_initializer()
self._input_gate_initializer = (input_gate_initializer
self._input_gate_initializer = (input_gate_initializer
or init_ops.glorot_normal_initializer())
self._use_bias = use_bias
self._kernels = None
Expand Down Expand Up @@ -117,35 +117,35 @@ def build(self, inputs_shape):
self._peep_kernels = []
for i in range(self.depth):
if i == 0:
input_kernel = self.add_variable(
"input_gate_kernel",
shape=[input_depth, 4 * self._num_units],
initializer=self._input_gate_initializer)
hidden_kernel = self.add_variable(
"hidden_gate_kernel",
shape=[h_depth, 4 * self._num_units],
initializer=self._initializer)
kernel = tf.concat([input_kernel, hidden_kernel],
axis=0, name="kernel_0")
kernel = self.add_variable(
"kernel_0",
shape=[input_depth + h_depth, 4 * self._num_units],
dtype=tf.float32,
initializer=SplitInitializer(
initializer1=self._input_gate_initializer,
initializer2=self._initializer,
shape1=[input_depth, 4 * self._num_units],
shape2=[h_depth, 4 * self._num_units],
concat_axis=0))
self._kernels.append(kernel)
else:
self._kernels.append(
self.add_variable(
"kernel_{}".format(i),
shape=[2 * h_depth, 4 * self._num_units],
initializer=self._initializer))
self.add_variable(
"kernel_{}".format(i),
shape=[2 * h_depth, 4 * self._num_units],
initializer=self._initializer))
if self._use_bias:
self._biases.append(
self.add_variable(
"bias_{}".format(i),
shape=[4 * self._num_units],
initializer=init_ops.zeros_initializer(dtype=self.dtype)))
self.add_variable(
"bias_{}".format(i),
shape=[4 * self._num_units],
initializer=init_ops.zeros_initializer(dtype=self.dtype)))
if self._use_peepholes:
self._peep_kernels.append(
self.add_variable(
"peep_kernel_{}".format(i),
shape=[h_depth, 3 * self._num_units],
initializer=self._initializer))
self.add_variable(
"peep_kernel_{}".format(i),
shape=[h_depth, 3 * self._num_units],
initializer=self._initializer))

self.built = True

Expand All @@ -172,31 +172,31 @@ def _recurrence(self, inputs, hidden_state, cell_states, depth):
h = hidden_state

gate_inputs = math_ops.matmul(
array_ops.concat([inputs, h], 1), self._kernels[depth])
array_ops.concat([inputs, h], 1), self._kernels[depth])
if self._use_bias:
gate_inputs = nn_ops.bias_add(gate_inputs, self._biases[depth])
if self._use_peepholes:
peep_gate_inputs = math_ops.matmul(c, self._peep_kernels[depth])
i_peep, f_peep, o_peep = array_ops.split(
value=peep_gate_inputs, num_or_size_splits=3, axis=one)
value=peep_gate_inputs, num_or_size_splits=3, axis=one)

# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(
value=gate_inputs, num_or_size_splits=4, axis=one)
value=gate_inputs, num_or_size_splits=4, axis=one)
if self._use_peepholes:
i += i_peep
f += f_peep
o += o_peep
o += o_peep

if self._use_peepholes:
peep_gate_inputs = math_ops.matmul(c, self._peep_kernels[depth])
i_peep, f_peep, o_peep = array_ops.split(
value=peep_gate_inputs, num_or_size_splits=3, axis=one)
value=peep_gate_inputs, num_or_size_splits=3, axis=one)
i += i_peep
f += f_peep
o += o_peep
o += o_peep

# Note that using `add` and `multiply` instead of `+` and `*` gives a
# Note that using `add` and `multiply` instead of `+` and `*` gives a
# performance improvement. So using those at the cost of readability.
add = math_ops.add
multiply = math_ops.multiply
Expand All @@ -217,10 +217,10 @@ def _recurrence(self, inputs, hidden_state, cell_states, depth):
new_cs = [new_c]
else:
new_c, new_cs = self._recurrence(
inputs=inner_input,
hidden_state=inner_hidden,
cell_states=cell_states,
depth=depth + 1)
inputs=inner_input,
hidden_state=inner_hidden,
cell_states=cell_states,
depth=depth + 1)
new_h = multiply(self._activation(new_c), self._gate_activation(o))
new_cs = [new_h] + new_cs
return new_h, new_cs
Expand Down Expand Up @@ -250,3 +250,21 @@ def call(self, inputs, state):
else:
next_state = array_ops.concat(next_state, axis=1)
return outputs, next_state


class SplitInitializer(Initializer):
    """Initializer that concatenates the outputs of two sub-initializers.

    Each sub-initializer is invoked with its own fixed shape, and the two
    resulting tensors are joined along `concat_axis` to form the value
    returned for the variable.
    """

    def __init__(self, initializer1, initializer2, shape1, shape2,
                 concat_axis):
        """Store the two (initializer, shape) pairs and the join axis.

        Args:
          initializer1: callable producing the first piece; called as
            `initializer1(shape1, dtype=dtype)`.
          initializer2: callable producing the second piece; called as
            `initializer2(shape2, dtype=dtype)`.
          shape1: shape handed to `initializer1`.
          shape2: shape handed to `initializer2`.
          concat_axis: axis along which the two pieces are concatenated.
        """
        self._initializer1, self._shape1 = initializer1, shape1
        self._initializer2, self._shape2 = initializer2, shape2
        self._concat_axis = concat_axis

    def __call__(self, shape, dtype=None, partition_info=None):
        """Build both pieces and return their concatenation.

        Args:
          shape: accepted for interface compatibility; not consulted here.
            The result's shape is determined by `shape1`, `shape2` and
            `concat_axis` given at construction time.
          dtype: forwarded unchanged to both sub-initializers.
          partition_info: accepted for interface compatibility; unused.

        Returns:
          The tensor produced by concatenating the first piece followed by
          the second piece along the configured axis.
        """
        # Evaluate in declaration order: first initializer, then second.
        pieces = [
            make_piece(piece_shape, dtype=dtype)
            for make_piece, piece_shape in (
                (self._initializer1, self._shape1),
                (self._initializer2, self._shape2),
            )
        ]
        return array_ops.concat(values=pieces, axis=self._concat_axis)