Commit
Merge pull request #527 from akashvelu/master
Allow Dendritic MLP to have multiple output heads
akashvelu committed Jun 9, 2021
2 parents c597d44 + f25baa9 commit aace8fa
Showing 1 changed file with 28 additions and 12 deletions: nupic/research/frameworks/dendrites/modules/dendritic_mlp.py
@@ -19,6 +19,8 @@
# http://numenta.org/licenses/
# ----------------------------------------------------------------------

from collections.abc import Iterable

import numpy as np
import torch
from torch import nn
@@ -38,7 +40,9 @@ class DendriticMLP(nn.Module):
initializations and learning parameters
:param input_size: size of the input to the network
:param output_size: the number of units in the output layer
:param output_size: the number of units in the output layer. Must be either an
integer if there is a single output head, or an iterable
of integers if there are multiple output heads.
:param hidden_sizes: the number of units in each hidden layer
:param num_segments: the number of dendritic segments that each hidden unit has
:param dim_context: the size of the context input to the network
@@ -74,7 +78,8 @@ class DendriticMLP(nn.Module):

def __init__(
self, input_size, output_size, hidden_sizes, num_segments, dim_context,
kw, kw_percent_on=0.05, context_percent_on=1.0, dendrite_weight_sparsity=0.95,
kw, kw_percent_on=0.05, context_percent_on=1.0,
dendrite_weight_sparsity=0.95,
weight_sparsity=0.95, weight_init="modified", dendrite_init="modified",
freeze_dendrites=False, output_nonlinearity=None,
dendritic_layer_class=AbsoluteMaxGatingDendriticLayer,
@@ -84,7 +89,7 @@ def __init__(
# "modified"
assert weight_init in ("kaiming", "modified")
assert dendrite_init in ("kaiming", "modified")
assert kw_percent_on >= 0.0 and kw_percent_on < 1.0
assert kw_percent_on is None or (kw_percent_on >= 0.0 and kw_percent_on < 1.0)
assert context_percent_on >= 0.0

if kw_percent_on == 0.0:
@@ -159,21 +164,32 @@ def __init__(

input_size = self.hidden_sizes[i]

self._output_layer = nn.Sequential()
output_linear = SparseWeights(module=nn.Linear(input_size, output_size),
sparsity=weight_sparsity, allow_extremes=True)
if weight_init == "modified":
self._init_sparse_weights(output_linear, 1 - kw_percent_on if kw else 0.0)
self._output_layer.add_module("output_linear", output_linear)
self._single_output_head = not isinstance(output_size, Iterable)
if self._single_output_head:
output_size = (output_size,)

self._output_layers = nn.ModuleList()
for out_size in output_size:
output_layer = nn.Sequential()
output_linear = SparseWeights(module=nn.Linear(input_size, out_size),
sparsity=weight_sparsity, allow_extremes=True)
if weight_init == "modified":
self._init_sparse_weights(
output_linear, 1 - kw_percent_on if kw else 0.0)
output_layer.add_module("output_linear", output_linear)

if self.output_nonlinearity is not None:
self._output_layer.add_module("non_linearity", output_nonlinearity)
if self.output_nonlinearity is not None:
output_layer.add_module("non_linearity", output_nonlinearity)
self._output_layers.append(output_layer)

def forward(self, x, context):
for layer, activation in zip(self._layers, self._activations):
x = activation(layer(x, context))

return self._output_layer(x)
if self._single_output_head:
return self._output_layers[0](x)
else:
return [out_layer(x) for out_layer in self._output_layers]

# ------ Weight initialization functions ------
@staticmethod
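For context, here is a minimal usage sketch (not part of the commit) showing the new multi-head behavior next to the original single-head behavior. The import path is inferred from the file path shown in the diff, and the hyperparameter values (hidden sizes, segment count, context size, batch size) are illustrative assumptions only.

import torch

from nupic.research.frameworks.dendrites.modules.dendritic_mlp import DendriticMLP

# Multiple output heads: pass an iterable of head sizes. forward() then
# returns a list with one tensor per head.
multi_head = DendriticMLP(
    input_size=10,
    output_size=(4, 2),          # two heads; sizes are illustrative assumptions
    hidden_sizes=(64, 64),
    num_segments=5,
    dim_context=8,
    kw=True,
)

x = torch.rand(3, 10)            # batch of 3 inputs
context = torch.rand(3, 8)       # matching context vectors (size dim_context)
outputs = multi_head(x, context) # list of two tensors: shapes (3, 4) and (3, 2)

# A single integer keeps the original behavior: forward() returns one tensor.
single_head = DendriticMLP(
    input_size=10,
    output_size=4,
    hidden_sizes=(64, 64),
    num_segments=5,
    dim_context=8,
    kw=True,
)
y = single_head(x, context)      # tensor of shape (3, 4)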
