forked from emer/leabra
-
Notifications
You must be signed in to change notification settings - Fork 1
/
pv_layer.go
78 lines (68 loc) · 1.84 KB
/
pv_layer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
// Copyright (c) 2020, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pvlv
import (
"strconv"
"github.com/ccnlab/leabrax/leabra"
"github.com/chewxy/math32"
"github.com/emer/emergent/emer"
)
// PVLayer is a Primary Value (PV) input layer. Rather than sending its
// activity through the standard projection-based mechanisms, it writes its
// activation directly into the ModNeurs of every layer listed in PVReceivers
// (see SendPVAct).
type PVLayer struct {
// Layer is the embedded base leabra layer (neurons, pools, and the standard
// layer machinery).
leabra.Layer
// Net is the owning pvlv network; SendPVAct uses it to look up receiver
// layers by name.
Net *Network
// SendPVQuarter is the quarter (compared against leabra.Time.Quarter in
// CyclePost) during which activation is pushed to receivers. Build
// overwrites it with int(leabra.Q4), so set it after Build to change it.
SendPVQuarter int
// PVReceivers names the layers that receive this layer's PV activity. Each
// must implement IModLayer and have at least as many ModNeurs as this layer
// has Neurons, because units are matched by index.
PVReceivers emer.LayNames
}
// AddPVLayer creates a PVLayer bound to nt, registers it with the network
// under name using the 4D shape {nY, nX, 1, 1} and layer type typ, and
// returns the new layer.
func AddPVLayer(nt *Network, name string, nY, nX int, typ emer.LayerType) *PVLayer {
	pv := &PVLayer{Net: nt}
	shape := []int{nY, nX, 1, 1}
	nt.AddLayerInit(pv, name, shape, typ)
	return pv
}
// AddPVReceiver appends lyNm to PVReceivers and marks that layer's
// modulation state as a PV receiver (IsPVReceiver = true).
//
// The layer is resolved through the embedded layer's Network and must
// implement IModLayer; otherwise the type assertion panics. Note that the
// name is appended before the lookup happens.
func (ly *PVLayer) AddPVReceiver(lyNm string) {
	ly.PVReceivers.Add(lyNm)
	target := ly.Network.LayerByName(lyNm)
	mod := target.(IModLayer).AsMod()
	mod.IsPVReceiver = true
}
// Build builds the embedded leabra.Layer and, if that succeeds, sets
// SendPVQuarter to int(leabra.Q4). Any SendPVQuarter value assigned before
// Build is overwritten. The embedded layer's error is returned unchanged.
func (ly *PVLayer) Build() error {
	if err := ly.Layer.Build(); err != nil {
		return err
	}
	ly.SendPVQuarter = int(leabra.Q4)
	return nil
}
// SendPVAct writes max(Act, Ext) of each neuron i of this layer into
// ModNeurs[i].PVAct of every layer named in PVReceivers.
//
// Receivers are looked up by name through ly.Net on every call and must
// implement IModLayer (the type assertion panics otherwise). Units are
// matched purely by index. A receiver with fewer ModNeurs than this layer
// has Neurons causes an index panic. Extra receiver entries are left
// untouched.
func (ly *PVLayer) SendPVAct() {
	for _, rcvName := range ly.PVReceivers {
		mod := ly.Net.LayerByName(rcvName).(IModLayer).AsMod()
		for ni := range ly.Neurons {
			// Index into the slice rather than ranging by value to avoid
			// copying each Neuron struct.
			src := &ly.Neurons[ni]
			// WARNING: both layers must have the same shape!
			// Taking the max with Ext presumably lets externally clamped input
			// through even when Act has not caught up yet — TODO confirm.
			mod.ModNeurs[ni].PVAct = math32.Max(src.Act, src.Ext)
		}
	}
}
// CyclePost calls SendPVAct whenever ltime.Quarter equals SendPVQuarter and
// does nothing otherwise.
//
// NOTE(review): the embedded ly.Layer.CyclePost is not invoked here —
// confirm the base implementation has nothing that needs to run.
func (ly *PVLayer) CyclePost(ltime *leabra.Time) {
	if ltime.Quarter != ly.SendPVQuarter {
		return
	}
	ly.SendPVAct()
}
// GetMonitorVal returns a monitored statistic of this layer, selected by data.
//
// data[0] names the statistic:
//   - "TotalAct":   TotalAct(ly); no index is needed.
//   - "Act":        Act of the neuron whose index is data[1].
//   - "PoolActAvg": Inhib.Act.Avg * Inhib.Act.N of the pool whose index is data[1].
//   - "PoolActMax": Inhib.Act.Max * Inhib.Act.N of the pool whose index is data[1].
//
// An unknown statistic name yields 0. So do an empty data slice, a missing
// or non-integer index, and an index outside the neuron/pool range. The
// previous code panicked in those cases or silently read element 0.
func (ly *PVLayer) GetMonitorVal(data []string) float64 {
	if len(data) == 0 {
		return 0
	}
	// index parses data[1] and validates it against n, the length of the
	// slice it will index. Only the per-unit/per-pool statistics call it, so
	// "TotalAct" does not require a second element.
	index := func(n int) (int, bool) {
		if len(data) < 2 {
			return 0, false
		}
		idx, err := strconv.Atoi(data[1])
		if err != nil || idx < 0 || idx >= n {
			return 0, false
		}
		return idx, true
	}
	var val float32
	switch data[0] {
	case "TotalAct":
		val = TotalAct(ly)
	case "Act":
		if idx, ok := index(len(ly.Neurons)); ok {
			val = ly.Neurons[idx].Act
		}
	case "PoolActAvg":
		if idx, ok := index(len(ly.Pools)); ok {
			// Scaled by the pool's unit count N (a sum, if Avg is the mean
			// over N units).
			pl := &ly.Pools[idx].Inhib.Act
			val = pl.Avg * float32(pl.N)
		}
	case "PoolActMax":
		if idx, ok := index(len(ly.Pools)); ok {
			pl := &ly.Pools[idx].Inhib.Act
			val = pl.Max * float32(pl.N)
		}
	}
	return float64(val)
}