Skip to content

Commit

Permalink
Added support for layer normalization epsilons.
Browse files Browse the repository at this point in the history
  • Loading branch information
bschaefl committed Dec 15, 2021
1 parent 1497a4d commit 5231d29
Show file tree
Hide file tree
Showing 22 changed files with 177 additions and 55 deletions.
18 changes: 9 additions & 9 deletions examples/bit_pattern/bit_pattern_demo.ipynb

Large diffs are not rendered by default.

Binary file modified examples/bit_pattern/resources/hopfield_adapted.pdf
Binary file not shown.
Binary file modified examples/bit_pattern/resources/hopfield_base.pdf
Binary file not shown.
Binary file modified examples/bit_pattern/resources/hopfield_lookup.pdf
Binary file not shown.
Binary file modified examples/bit_pattern/resources/hopfield_lookup_adapted.pdf
Binary file not shown.
Binary file modified examples/bit_pattern/resources/hopfield_pooling.pdf
Binary file not shown.
Binary file modified examples/bit_pattern/resources/hopfield_pooling_adapted.pdf
Binary file not shown.
20 changes: 10 additions & 10 deletions examples/latch_sequence/latch_sequence_demo.ipynb

Large diffs are not rendered by default.

Binary file modified examples/latch_sequence/resources/hopfield_adapted.pdf
Binary file not shown.
Binary file modified examples/latch_sequence/resources/hopfield_base.pdf
Binary file not shown.
Binary file modified examples/latch_sequence/resources/hopfield_lookup.pdf
Binary file not shown.
Binary file modified examples/latch_sequence/resources/hopfield_lookup_adapted.pdf
Binary file not shown.
Binary file modified examples/latch_sequence/resources/hopfield_pooling.pdf
Binary file not shown.
Binary file modified examples/latch_sequence/resources/hopfield_pooling_adapted.pdf
Binary file not shown.
Binary file modified examples/latch_sequence/resources/lstm_base.pdf
Binary file not shown.
121 changes: 113 additions & 8 deletions examples/mnist_bags/mnist_bags_demo.ipynb

Large diffs are not rendered by default.

Binary file modified examples/mnist_bags/resources/attention_base.pdf
Binary file not shown.
Binary file modified examples/mnist_bags/resources/gated_attention_base.pdf
Binary file not shown.
Binary file modified examples/mnist_bags/resources/hopfield_pooling.pdf
Binary file not shown.
20 changes: 16 additions & 4 deletions hflayers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@ def __init__(self,

normalize_stored_pattern: bool = True,
normalize_stored_pattern_affine: bool = True,
normalize_stored_pattern_eps: float = 1e-5,
normalize_state_pattern: bool = True,
normalize_state_pattern_affine: bool = True,
normalize_state_pattern_eps: float = 1e-5,
normalize_pattern_projection: bool = True,
normalize_pattern_projection_affine: bool = True,
normalize_pattern_projection_eps: float = 1e-5,
normalize_hopfield_space: bool = False,
normalize_hopfield_space_affine: bool = False,
normalize_hopfield_space_eps: float = 1e-5,
stored_pattern_as_static: bool = False,
state_pattern_as_static: bool = False,
pattern_projection_as_static: bool = False,
Expand Down Expand Up @@ -60,12 +64,16 @@ def __init__(self,
:param update_steps_eps: minimum difference threshold between two consecutive association update steps
:param normalize_stored_pattern: apply normalization on stored patterns
:param normalize_stored_pattern_affine: additionally enable affine normalization of stored patterns
:param normalize_stored_pattern_eps: epsilon added to the denominator for numerical stability
:param normalize_state_pattern: apply normalization on state patterns
:param normalize_state_pattern_affine: additionally enable affine normalization of state patterns
:param normalize_state_pattern_eps: offset of the denominator for numerical stability
:param normalize_pattern_projection: apply normalization on the pattern projection
:param normalize_pattern_projection_affine: additionally enable affine normalization of pattern projection
:param normalize_pattern_projection_eps: offset of the denominator for numerical stability
:param normalize_hopfield_space: enable normalization of patterns in the Hopfield space
:param normalize_hopfield_space_affine: additionally enable affine normalization of patterns in Hopfield space
:param normalize_hopfield_space_eps: offset of the denominator for numerical stability
:param stored_pattern_as_static: interpret specified stored patterns as being static
:param state_pattern_as_static: interpret specified state patterns as being static
:param pattern_projection_as_static: interpret specified pattern projections as being static
Expand All @@ -92,7 +100,8 @@ def __init__(self,
disable_out_projection=disable_out_projection, key_as_static=stored_pattern_as_static,
query_as_static=state_pattern_as_static, value_as_static=pattern_projection_as_static,
value_as_connected=pattern_projection_as_connected, normalize_pattern=normalize_hopfield_space,
normalize_pattern_affine=normalize_hopfield_space_affine)
normalize_pattern_affine=normalize_hopfield_space_affine,
normalize_pattern_eps=normalize_hopfield_space_eps)
self.association_activation = None
if association_activation is not None:
self.association_activation = getattr(torch, association_activation, None)
Expand All @@ -105,7 +114,8 @@ def __init__(self,
normalized_shape = input_size if stored_pattern_size is None else stored_pattern_size
assert normalized_shape is not None, "stored pattern size required for setting up normalisation"
self.norm_stored_pattern = nn.LayerNorm(
normalized_shape=normalized_shape, elementwise_affine=normalize_stored_pattern_affine)
normalized_shape=normalized_shape, elementwise_affine=normalize_stored_pattern_affine,
eps=normalize_stored_pattern_eps)

# Initialise state pattern normalization.
self.norm_state_pattern = None
Expand All @@ -114,7 +124,8 @@ def __init__(self,
if normalize_state_pattern:
assert input_size is not None, "input size required for setting up normalisation"
self.norm_state_pattern = nn.LayerNorm(
normalized_shape=input_size, elementwise_affine=normalize_state_pattern_affine)
normalized_shape=input_size, elementwise_affine=normalize_state_pattern_affine,
eps=normalize_state_pattern_eps)

# Initialise pattern projection normalization.
self.norm_pattern_projection = None
Expand All @@ -124,7 +135,8 @@ def __init__(self,
normalized_shape = input_size if pattern_projection_size is None else pattern_projection_size
assert normalized_shape is not None, "pattern projection size required for setting up normalisation"
self.norm_pattern_projection = nn.LayerNorm(
normalized_shape=normalized_shape, elementwise_affine=normalize_pattern_projection_affine)
normalized_shape=normalized_shape, elementwise_affine=normalize_pattern_projection_affine,
eps=normalize_pattern_projection_eps)

# Initialise remaining auxiliary properties.
if self.association_core.static_execution:
Expand Down
44 changes: 23 additions & 21 deletions hflayers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,25 +48,26 @@ class HopfieldCore(Module):
}

def __init__(self,
embed_dim=None, # type: Optional[int]
num_heads=1, # type: int
dropout=0.0, # type: float
bias=True, # type: bool
add_bias_kv=False, # type: bool
add_zero_attn=False, # type: bool
kdim=None, # type: Optional[int]
vdim=None, # type: Optional[int]

head_dim=None, # type: Optional[int]
pattern_dim=None, # type: Optional[int]
out_dim=None, # type: Optional[int]
disable_out_projection=False, # type: bool
key_as_static=False, # type: bool
query_as_static=False, # type: bool
value_as_static=False, # type: bool
value_as_connected=False, # type: bool
normalize_pattern=False, # type: bool
normalize_pattern_affine=False # type: bool
embed_dim=None, # type: Optional[int]
num_heads=1, # type: int
dropout=0.0, # type: float
bias=True, # type: bool
add_bias_kv=False, # type: bool
add_zero_attn=False, # type: bool
kdim=None, # type: Optional[int]
vdim=None, # type: Optional[int]

head_dim=None, # type: Optional[int]
pattern_dim=None, # type: Optional[int]
out_dim=None, # type: Optional[int]
disable_out_projection=False, # type: bool
key_as_static=False, # type: bool
query_as_static=False, # type: bool
value_as_static=False, # type: bool
value_as_connected=False, # type: bool
normalize_pattern=False, # type: bool
normalize_pattern_affine=False, # type: bool
normalize_pattern_eps=1e-5 # type: float
):
super(HopfieldCore, self).__init__()

Expand All @@ -77,6 +78,7 @@ def __init__(self,

self.value_as_connected = value_as_connected
self.normalize_pattern, self.normalize_pattern_affine = normalize_pattern, normalize_pattern_affine
self.normalize_pattern_eps = normalize_pattern_eps
self.disable_out_projection = disable_out_projection

# In case of a static-only executions, check corresponding projections and normalizations.
Expand Down Expand Up @@ -315,7 +317,7 @@ def forward(self,

key_as_static=self.key_as_static, query_as_static=self.query_as_static,
value_as_static=self.value_as_static, value_as_connected=self.value_as_connected,
normalize_pattern=self.normalize_pattern,
normalize_pattern=self.normalize_pattern, normalize_pattern_eps=self.normalize_pattern_eps,
p_norm_weight=self.p_norm_weight, p_norm_bias=self.p_norm_bias,
head_dim=head_dim, pattern_dim=self.pattern_dim, scaling=scaling,
update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
Expand All @@ -330,7 +332,7 @@ def forward(self,

key_as_static=self.key_as_static, query_as_static=self.query_as_static,
value_as_static=self.value_as_static, value_as_connected=self.value_as_connected,
normalize_pattern=self.normalize_pattern,
normalize_pattern=self.normalize_pattern, normalize_pattern_eps=self.normalize_pattern_eps,
p_norm_weight=self.p_norm_weight, p_norm_bias=self.p_norm_bias,
head_dim=head_dim, pattern_dim=self.pattern_dim, scaling=scaling,
update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
Expand Down
9 changes: 6 additions & 3 deletions hflayers/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def hopfield_core_forward(query, # type: Tensor
value_as_static=False, # type: bool
value_as_connected=False, # type: bool
normalize_pattern=False, # type: bool
normalize_pattern_eps=1e-5, # type: float
p_norm_weight=None, # type: Optional[Tensor]
p_norm_bias=None, # type: Optional[Tensor]
head_dim=None, # type: Optional[int]
Expand Down Expand Up @@ -76,6 +77,7 @@ def hopfield_core_forward(query, # type: Tensor
value_as_static: interpret specified value as being static.
value_as_connected: connect value projection with key projection.
normalize_pattern: enable normalization of patterns.
normalize_pattern_eps: epsilon added to the denominator for numerical stability.
p_norm_weight, p_norm_bias: pattern normalization weight and bias.
head_dim: dimensionality of each head.
pattern_dim: dimensionality of each projected value input.
Expand Down Expand Up @@ -132,7 +134,8 @@ def hopfield_core_forward(query, # type: Tensor
v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v,
key_as_static=key_as_static, query_as_static=query_as_static,
value_as_static=value_as_static, value_as_connected=value_as_connected,
normalize_pattern=normalize_pattern, p_norm_weight=p_norm_weight, p_norm_bias=p_norm_bias,
normalize_pattern=normalize_pattern, normalize_pattern_eps=normalize_pattern_eps,
p_norm_weight=p_norm_weight, p_norm_bias=p_norm_bias,
head_dim=head_dim, pattern_dim=pattern_dim, scaling=scaling, update_steps_max=update_steps_max,
update_steps_eps=update_steps_eps, return_raw_associations=return_raw_associations)
tgt_len, bsz, embed_dim = query.shape[0], value.shape[1], query.shape[2]
Expand Down Expand Up @@ -323,10 +326,10 @@ def hopfield_core_forward(query, # type: Tensor
if normalize_pattern:
q = torch.nn.functional.layer_norm(
input=q.reshape(shape=(-1, head_dim)), normalized_shape=(head_dim,),
weight=p_norm_weight, bias=p_norm_bias).reshape(shape=q.shape)
weight=p_norm_weight, bias=p_norm_bias, eps=normalize_pattern_eps).reshape(shape=q.shape)
k = torch.nn.functional.layer_norm(
input=k.reshape(shape=(-1, head_dim)), normalized_shape=(head_dim,),
weight=p_norm_weight, bias=p_norm_bias).reshape(shape=k.shape)
weight=p_norm_weight, bias=p_norm_bias, eps=normalize_pattern_eps).reshape(shape=k.shape)

else:
active_xi = xi.masked_select(mask=update_active_heads).view(size=(-1, *xi.shape[1:]))
Expand Down

0 comments on commit 5231d29

Please sign in to comment.