forked from emer/leabra
-
Notifications
You must be signed in to change notification settings - Fork 1
/
ecca1.go
88 lines (78 loc) · 2.36 KB
/
ecca1.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package hip
import (
"github.com/ccnlab/leabrax/leabra"
"github.com/chewxy/math32"
)
// hip.EcCa1Prjn is for EC <-> CA1 projections, to perform error-driven
// learning of this encoder pathway according to the ThetaPhase algorithm.
// It uses Contrastive Hebbian Learning (CHL) on ActP - ActQ1:
// Q1: ECin -> CA1 -> ECout : ActQ1 = minus phase for auto-encoder
// Q2, 3: CA3 -> CA1 -> ECout : ActM = minus phase for recall
// Q4: ECin -> CA1, ECin -> ECout : ActP = plus phase for everything
type EcCa1Prjn struct {
	leabra.Prjn // embedded standard projection -- access as .Prjn
}
// Defaults initializes this projection with the standard base Prjn
// defaults, then turns off weight normalization, momentum, and weight
// balancing, which are not used for the EC <-> CA1 encoder pathway.
func (pj *EcCa1Prjn) Defaults() {
	pj.Prjn.Defaults()
	lrn := &pj.Prjn.Learn
	lrn.WtBal.On = false    // todo: experiment
	lrn.Momentum.On = false // off by default
	lrn.Norm.On = false     // off by default
}
// UpdateParams updates all projection parameters, delegating entirely to
// the embedded standard Prjn implementation (no EcCa1-specific params).
func (pj *EcCa1Prjn) UpdateParams() {
	pj.Prjn.UpdateParams()
}
//////////////////////////////////////////////////////////////////////////////////////
//  Learn methods

// DWt computes the weight change (learning) -- on sending projections.
// Delta version: combines a CHL-style error-driven term computed as the
// coproduct difference (ActP * ActP) - (ActQ1 * ActQ1) across the plus
// phase and the Q1 auto-encoder phase, with the standard BCM hebbian term,
// instead of the base Prjn's ActM-based minus phase.
func (pj *EcCa1Prjn) DWt() {
	if !pj.Learn.Learn {
		return // learning disabled for this projection
	}
	slay := pj.Send.(leabra.LeabraLayer).AsLeabra()
	rlay := pj.Recv.(leabra.LeabraLayer).AsLeabra()
	for si := range slay.Neurons {
		sn := &slay.Neurons[si]
		// per-sender slice of synapses and receiver indexes
		nc := int(pj.SConN[si])
		st := int(pj.SConIdxSt[si])
		syns := pj.Syns[st : st+nc]
		scons := pj.SConIdx[st : st+nc]
		for ci := range syns {
			sy := &syns[ci]
			ri := scons[ci]
			rn := &rlay.Neurons[ri]
			// CHL error term on plus phase vs. Q1 (auto-encoder minus phase),
			// per the ThetaPhase algorithm described on the type comment
			err := (sn.ActP * rn.ActP) - (sn.ActQ1 * rn.ActQ1)
			// hebbian BCM term from medium-time-scale averages, scaled by
			// the receiver's long-time-average learning rate factor
			bcm := pj.Learn.BCMdWt(sn.AvgSLrn, rn.AvgSLrn, rn.AvgL)
			bcm *= pj.Learn.XCal.LongLrate(rn.AvgLLrn)
			err *= pj.Learn.XCal.MLrn // mix in error-driven proportion
			dwt := bcm + err
			norm := float32(1)
			if pj.Learn.Norm.On {
				// running abs-dwt normalization factor (off in Defaults)
				norm = pj.Learn.Norm.NormFmAbsDWt(&sy.Norm, math32.Abs(dwt))
			}
			if pj.Learn.Momentum.On {
				// momentum integrates dwt; norm applies to the momentum output
				dwt = norm * pj.Learn.Momentum.MomentFmDWt(&sy.Moment, dwt)
			} else {
				dwt *= norm
			}
			sy.DWt += pj.Learn.Lrate * dwt
		}
		// aggregate max DWtNorm over sending synapses, so all synapses from
		// this sender share the same (max) normalization factor
		if pj.Learn.Norm.On {
			maxNorm := float32(0)
			for ci := range syns {
				sy := &syns[ci]
				if sy.Norm > maxNorm {
					maxNorm = sy.Norm
				}
			}
			for ci := range syns {
				sy := &syns[ci]
				sy.Norm = maxNorm
			}
		}
	}
}