Skip to content

Commit

Permalink
Softmax outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
nschaetti committed Jan 25, 2019
1 parent 406146d commit 93981b0
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions echotorch/nn/RRCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class RRCell(nn.Module):
"""

# Constructor
def __init__(self, input_dim, output_dim, ridge_param=0.0, feedbacks=False, with_bias=True, learning_algo='inv'):
def __init__(self, input_dim, output_dim, ridge_param=0.0, feedbacks=False, with_bias=True, learning_algo='inv', softmax_output=False):
"""
Constructor
:param input_dim: Inputs dimension.
Expand All @@ -53,6 +53,8 @@ def __init__(self, input_dim, output_dim, ridge_param=0.0, feedbacks=False, with
self.feedbacks = feedbacks
self.with_bias = with_bias
self.learning_algo = learning_algo
self.softmax_output = softmax_output
self.softmax = torch.nn.Softmax(dim=2)

# Size
if self.with_bias:
Expand Down Expand Up @@ -137,7 +139,10 @@ def forward(self, x, y=None):
outputs[b] = torch.mm(x[b], self.w_out)
# end for

return outputs
if self.softmax_output:
return self.softmax(outputs)
else:
return outputs
# end if
# end forward

Expand All @@ -149,7 +154,7 @@ def finalize(self):
if self.learning_algo == 'inv':
# inv_xTx = self.xTx.inverse()
# inv_xTx = torch.inverse(self.xTx + self.ridge_param * torch.eye(self._input_dim + self.with_bias))
ridge_xTx = self.xTx + self.ridge_param * torch.eye(self._input_dim + self.with_bias)
ridge_xTx = self.xTx + self.ridge_param * torch.eye(self.input_dim + self.with_bias)
inv_xTx = ridge_xTx.inverse()
self.w_out.data = torch.mm(inv_xTx, self.xTy).data
else:
Expand Down

0 comments on commit 93981b0

Please sign in to comment.