/
storage.go
53 lines (44 loc) · 1.1 KB
/
storage.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
package storage
import (
"fmt"
"log"
"strings"
"gorgonia.org/tensor"
)
// Storage is in charge of loading the weights from files and keeping them
// addressable by name.
type Storage struct {
	// Cost is a scalar carried alongside the weights. Nothing in this file
	// reads it; NewStorage initializes it to 0.0.
	// NOTE(review): presumably a training loss/cost value — confirm with callers.
	Cost float64
	// Learnables maps a weight's Name to the Weight itself. AddWeights
	// inserts into it and TensorByName looks weights up in it.
	Learnables map[string]Weight
}
// NewStorage returns an empty Storage: its Cost is zero and its Learnables
// map is allocated (but holds no entries), so weights can be added right away.
func NewStorage() *Storage {
	s := new(Storage)
	s.Learnables = make(map[string]Weight)
	return s
}
// TensorByName returns the tensor held by the weight registered under name.
//
// It returns ErrLearnableNotFound when no weight with that name is stored,
// and a descriptive error when the weight exists but its Value does not hold
// a tensor.Tensor (for example when Value is nil).
func (l *Storage) TensorByName(name string) (tensor.Tensor, error) {
	w, ok := l.Learnables[name]
	if !ok {
		return nil, ErrLearnableNotFound
	}
	// Use the comma-ok form: a bare assertion would panic on a nil or
	// non-tensor Value instead of reporting it through the error return.
	t, ok := w.Value.(tensor.Tensor)
	if !ok {
		return nil, fmt.Errorf("weight %q does not hold a tensor (value type %T)", name, w.Value)
	}
	return t, nil
}
// LoadFile loads the weights stored at filePath into the storage.
//
// The loader is selected from the file name's suffix. Currently only ".nn1"
// files are supported; they are delegated to LoadNN1. Any other file yields
// an error and the storage is left untouched.
func (l *Storage) LoadFile(filePath string) error {
	// Match the suffix rather than any occurrence in the path, so paths such
	// as "run.nn1/weights.bin" or "model.nn1.bak" are not routed to the NN1
	// loader by mistake.
	if strings.HasSuffix(filePath, ".nn1") {
		return LoadNN1(l, filePath)
	}
	return fmt.Errorf("extension of file %q is not supported yet", filePath)
}
// AddWeights registers each weight in the storage under its Name.
//
// It panics if a weight with the same Name is already stored, including a
// duplicate that appeared earlier in the same call. Weights processed before
// the duplicate remain stored. A nil Learnables map (e.g. a zero-value
// Storage not created with NewStorage) is allocated on first use, so the
// assignment below cannot hit a nil-map panic.
func (l *Storage) AddWeights(weights ...Weight) {
	if l.Learnables == nil {
		l.Learnables = make(map[string]Weight, len(weights))
	}
	for _, w := range weights {
		if _, ok := l.Learnables[w.Name]; ok {
			log.Panicf("weight %s is already present in the storage", w.Name)
		}
		l.Learnables[w.Name] = w
	}
}