diff --git a/nupic/research/frameworks/dendrites/modules/dendritic_mlp.py b/nupic/research/frameworks/dendrites/modules/dendritic_mlp.py
index 0074f9c8c..ac5576f7f 100644
--- a/nupic/research/frameworks/dendrites/modules/dendritic_mlp.py
+++ b/nupic/research/frameworks/dendrites/modules/dendritic_mlp.py
@@ -19,6 +19,8 @@
 # http://numenta.org/licenses/
 # ----------------------------------------------------------------------
 
+from collections.abc import Iterable
+
 import numpy as np
 import torch
 from torch import nn
@@ -38,7 +40,9 @@ class DendriticMLP(nn.Module):
     initializations and learning parameters
 
     :param input_size: size of the input to the network
-    :param output_size: the number of units in the output layer
+    :param output_size: the number of units in the output layer. Must be either an
+                        integer if there is a single output head, or an iterable
+                        of integers if there are multiple output heads.
     :param hidden_sizes: the number of units in each hidden layer
     :param num_segments: the number of dendritic segments that each hidden unit has
     :param dim_context: the size of the context input to the network
@@ -74,7 +78,8 @@ class DendriticMLP(nn.Module):
 
     def __init__(
         self, input_size, output_size, hidden_sizes, num_segments, dim_context,
-        kw, kw_percent_on=0.05, context_percent_on=1.0, dendrite_weight_sparsity=0.95,
+        kw, kw_percent_on=0.05, context_percent_on=1.0,
+        dendrite_weight_sparsity=0.95,
         weight_sparsity=0.95, weight_init="modified", dendrite_init="modified",
         freeze_dendrites=False, output_nonlinearity=None,
         dendritic_layer_class=AbsoluteMaxGatingDendriticLayer,
@@ -84,7 +89,7 @@ def __init__(
         # "modified"
         assert weight_init in ("kaiming", "modified")
         assert dendrite_init in ("kaiming", "modified")
-        assert kw_percent_on >= 0.0 and kw_percent_on < 1.0
+        assert kw_percent_on is None or (kw_percent_on >= 0.0 and kw_percent_on < 1.0)
        assert context_percent_on >= 0.0
 
         if kw_percent_on == 0.0:
@@ -159,21 +164,32 @@ def __init__(
 
             input_size = self.hidden_sizes[i]
 
-        self._output_layer = nn.Sequential()
-        output_linear = SparseWeights(module=nn.Linear(input_size, output_size),
-                                      sparsity=weight_sparsity, allow_extremes=True)
-        if weight_init == "modified":
-            self._init_sparse_weights(output_linear, 1 - kw_percent_on if kw else 0.0)
-        self._output_layer.add_module("output_linear", output_linear)
+        self._single_output_head = not isinstance(output_size, Iterable)
+        if self._single_output_head:
+            output_size = (output_size,)
+
+        self._output_layers = nn.ModuleList()
+        for out_size in output_size:
+            output_layer = nn.Sequential()
+            output_linear = SparseWeights(module=nn.Linear(input_size, out_size),
+                                          sparsity=weight_sparsity, allow_extremes=True)
+            if weight_init == "modified":
+                self._init_sparse_weights(
+                    output_linear, 1 - kw_percent_on if kw else 0.0)
+            output_layer.add_module("output_linear", output_linear)
 
-        if self.output_nonlinearity is not None:
-            self._output_layer.add_module("non_linearity", output_nonlinearity)
+            if self.output_nonlinearity is not None:
+                output_layer.add_module("non_linearity", output_nonlinearity)
+            self._output_layers.append(output_layer)
 
     def forward(self, x, context):
         for layer, activation in zip(self._layers, self._activations):
             x = activation(layer(x, context))
-        return self._output_layer(x)
+        if self._single_output_head:
+            return self._output_layers[0](x)
+        else:
+            return [out_layer(x) for out_layer in self._output_layers]
 
     # ------ Weight initialization functions ------
 
     @staticmethod
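
A minimal usage sketch of the new multi-head behaviour (not part of the patch). The import path is inferred from the file path above, and every constructor value other than output_size is an illustrative placeholder rather than a value taken from the repository:

import torch

from nupic.research.frameworks.dendrites.modules.dendritic_mlp import DendriticMLP

# Passing an iterable for output_size creates one output head per entry.
multi_head = DendriticMLP(
    input_size=784,
    output_size=(10, 10),        # two heads of 10 units each
    hidden_sizes=(2048, 2048),   # illustrative sizes
    num_segments=10,
    dim_context=64,
    kw=True,
)

x = torch.rand(32, 784)
context = torch.rand(32, 64)

# forward() now returns a list with one tensor per head ...
outputs = multi_head(x, context)
assert len(outputs) == 2 and outputs[0].shape == (32, 10)

# ... while an integer output_size keeps the original single-tensor behaviour.
single_head = DendriticMLP(
    input_size=784,
    output_size=10,
    hidden_sizes=(2048, 2048),
    num_segments=10,
    dim_context=64,
    kw=True,
)
assert single_head(x, context).shape == (32, 10)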