forked from emer/leabra
-
Notifications
You must be signed in to change notification settings - Fork 1
/
dahebb.go
79 lines (70 loc) · 1.82 KB
/
dahebb.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pbwm
import (
"github.com/ccnlab/leabrax/leabra"
"github.com/chewxy/math32"
"github.com/goki/ki/kit"
)
// DaHebbPrjn is a projection that does dopamine-modulated Hebbian
// learning -- i.e., the 3-factor learning rule:
//
//	Da * Recv.Act * Send.Act
//
// It embeds leabra.Prjn and overrides Defaults and DWt.
type DaHebbPrjn struct {
// Prjn is the embedded base leabra projection.
leabra.Prjn
}
// KiT_DaHebbPrjn holds the result of adding DaHebbPrjn to the kit.Types
// registry, passing leabra.PrjnProps as the properties argument.
var KiT_DaHebbPrjn = kit.Types.AddType(&DaHebbPrjn{}, leabra.PrjnProps)
// Defaults applies the embedded Prjn defaults and then overrides the
// learning parameters so that no additional factors are applied:
// a WtSig gain of 1, with Norm, Momentum and WtBal all switched off.
func (pj *DaHebbPrjn) Defaults() {
	pj.Prjn.Defaults()

	// These overrides must come after the base defaults, which would
	// otherwise overwrite them.
	learn := &pj.Learn
	learn.WtBal.On = false
	learn.Momentum.On = false
	learn.Norm.On = false
	learn.WtSig.Gain = 1
}
// DWt computes the weight change (learning), iterating over sending
// neurons and, for each one, over its synapses.
//
// The raw change for a synapse is the product of three factors: the
// receiving unit's DALrn value (read through the PBWMLayer interface),
// the receiving neuron's Act, and the sending neuron's Act. When enabled,
// normalization and momentum are applied to that product before it is
// scaled by the learning rate and added to the synapse's DWt.
//
// It does nothing when pj.Learn.Learn is false. The sending layer must
// implement leabra.LeabraLayer and the receiving layer PBWMLayer; the type
// assertions are unchecked and panic otherwise.
func (pj *DaHebbPrjn) DWt() {
	if !pj.Learn.Learn {
		return
	}
	sendLay := pj.Send.(leabra.LeabraLayer).AsLeabra()
	recvPBWM := pj.Recv.(PBWMLayer)
	recvLay := recvPBWM.AsLeabra()

	for sendIdx := range sendLay.Neurons {
		sender := &sendLay.Neurons[sendIdx]

		// Window of this sender's connections in the flat synapse and
		// receiver-index slices.
		count := int(pj.SConN[sendIdx])
		first := int(pj.SConIdxSt[sendIdx])
		conSyns := pj.Syns[first : first+count]
		conRecvs := pj.SConIdx[first : first+count]

		for k := range conSyns {
			syn := &conSyns[k]
			recvIdx := conRecvs[k]
			receiver := &recvLay.Neurons[recvIdx]
			mod := recvPBWM.UnitValByIdx(DALrn, int(recvIdx))
			delta := mod * receiver.Act * sender.Act

			// The normalization factor is taken from the raw (pre-momentum)
			// magnitude; the call also updates syn.Norm.
			scale := float32(1)
			if pj.Learn.Norm.On {
				scale = pj.Learn.Norm.NormFmAbsDWt(&syn.Norm, math32.Abs(delta))
			}
			if pj.Learn.Momentum.On {
				delta = pj.Learn.Momentum.MomentFmDWt(&syn.Moment, delta)
			}
			syn.DWt += pj.Learn.Lrate * (scale * delta)
		}

		if !pj.Learn.Norm.On {
			continue
		}
		// Aggregate the max Norm over this sender's synapses and write it
		// back to every one of them.
		peak := float32(0)
		for k := range conSyns {
			if n := conSyns[k].Norm; n > peak {
				peak = n
			}
		}
		for k := range conSyns {
			conSyns[k].Norm = peak
		}
	}
}