Skip to content

Commit

Permalink
ESN-based ConceptorNet
Browse files Browse the repository at this point in the history
  • Loading branch information
nschaetti committed Jan 25, 2019
1 parent 4154ada commit f877248
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions echotorch/nn/ConceptorNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ConceptorNet(nn.Module):
def __init__(self, input_dim, hidden_dim, spectral_radius=0.9, bias_scaling=0, input_scaling=1.0,
w=None, w_in=None, w_bias=None, sparsity=None, input_set=[1.0, -1.0], w_sparsity=None,
leaky_rate=1.0, nonlin_func=torch.tanh, learning_algo='inv', ridge_param=0.0,
with_bias=True):
with_bias=True, seed=None):
"""
Constructor
:param input_dim: Inputs dimension.
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(self, input_dim, hidden_dim, spectral_radius=0.9, bias_scaling=0, i
w=w, w_in=w_in, w_bias=w_bias, sparsity=sparsity, input_set=input_set,
w_sparsity=w_sparsity, nonlin_func=nonlin_func, feedbacks=False,
feedbacks_dim=input_dim, wfdb_sparsity=None,
normalize_feedbacks=False)
normalize_feedbacks=False, seed=seed)
# end if

# Input recreation weights layer (Ridge regression)
Expand Down Expand Up @@ -112,6 +112,7 @@ def w_in(self):
# end w_in

# Input recreation matrix
@property
def input_recreation_matrix(self):
"""
Input recreation matrix
Expand Down Expand Up @@ -157,7 +158,7 @@ def set_w(self, w):
# end set_w

# Forward
def forward(self, u, c, reset_state=True):
def forward(self, u=None, c=None, reset_state=True, length=None):
"""
Forward
:param u: Input signal.
Expand All @@ -176,12 +177,20 @@ def forward(self, u, c, reset_state=True):

# Learning conceptor
return c(hidden_states, hidden_states)
elif c is None:
hidden_states = self.esn_cell(
u,
reset_state=reset_state
)

return hidden_states
else:
hidden_states = self.esn_cell(
u=None,
reset_state=reset_state,
input_recreation=self.input_recreation,
conceptor=c
conceptor=c,
length=length
)

# Return observed outputs
Expand Down

0 comments on commit f877248

Please sign in to comment.