/
values.go
213 lines (188 loc) · 5.72 KB
/
values.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
package gorgonia
import (
"fmt"
"unsafe"
"github.com/chewxy/hm"
"github.com/pkg/errors"
"gorgonia.org/tensor"
)
// Value is the interface satisfied by every value Gorgonia accepts. Current
// implementations include:
//   - all scalar value types (F64, F32, etc.)
//   - *tensor.Dense
//   - *dualValue
//
// Every Value knows its own type and shape. Crucially, a Value is a pointer and
// can be treated as a tensor.Memory, which keeps it interoperable with external
// devices such as cgo, CUDA or OpenCL. As a consequence most Values end up
// allocated on the heap; that trades away some performance, but it is preferable
// to managing raw blocks of memory by hand.
type Value interface {
	tensor.Memory
	fmt.Formatter

	// Shape returns the shape of the Value. Scalar values return ScalarShape().
	Shape() tensor.Shape

	// Size reports the number of elements in the Value. For types such as
	// *tensor.Dense the underlying slice MAY hold more elements than Size
	// reports; that is expected and correct.
	Size() int

	// Data returns the original representation of the Value.
	Data() interface{}

	// Dtype returns the Dtype of the Value.
	Dtype() tensor.Dtype
}
type (
	// Valuer is implemented by any type that can produce a Value.
	Valuer interface {
		Value() Value
	}

	// Zeroer is a Value that is able to reset itself to zero.
	Zeroer interface {
		Value
		Zero()
	}

	// ZeroValuer is a Value that can hand back the zero value of its own type.
	ZeroValuer interface {
		Value
		ZeroValue() Value
	}

	// Dtyper is implemented by anything (usually a Value) that knows its own Dtype.
	Dtyper interface {
		Dtype() tensor.Dtype
	}

	// Typer is implemented by anything (usually an Op) that knows its own type.
	Typer interface {
		Type() hm.Type
	}

	// ValueEqualer is implemented by types that can check whether they are
	// equal in value to a given Value.
	ValueEqualer interface {
		ValueEq(Value) bool
	}

	// ValueCloser is implemented by types that can check whether they are
	// close in value to the given argument.
	ValueCloser interface {
		ValueClose(interface{}) bool
	}

	// Cloner is implemented by types that can produce a clone of themselves.
	Cloner interface {
		Clone() interface{}
	}

	// CloneErrorer is implemented by types that can clone themselves and may
	// report an error while doing so.
	CloneErrorer interface {
		Clone() (interface{}, error)
	}

	// CopierTo is implemented by types that can copy their data into dest.
	CopierTo interface {
		CopyTo(dest interface{}) error
	}

	// CopierFrom is implemented by types that can copy data in from src.
	CopierFrom interface {
		CopyFrom(src interface{}) error
	}
)

// Setter is any value that can Memset itself to the provided value.
// type Setter interface {
// 	SetAll(interface{}) error
// }
// makeValue creates a Value of type t and shape s. The default value is the
// zero value of the type.
//
// For a scalar shape whose dtype is Float64, Float32, Int, Int64, Int32, Byte
// or Bool, the matching scalar value (zero / false) is returned. Otherwise, if t
// is a TensorType, a new tensor of dtype dt and shape s is returned. Any other
// kind of type is reported as not yet implemented. An error is also returned
// when the dtype of t cannot be determined.
func makeValue(t hm.Type, s tensor.Shape) (retVal Value, err error) {
	dt, err := dtypeOf(t)
	if err != nil {
		return nil, err
	}

	if s.IsScalar() {
		switch dt {
		case tensor.Float64:
			return newF64(0), nil
		case tensor.Float32:
			return newF32(0), nil
		case tensor.Int:
			return newI(0), nil
		case tensor.Int64:
			return newI64(0), nil
		case tensor.Int32:
			return newI32(0), nil
		case tensor.Byte:
			return newU8(0), nil
		case tensor.Bool:
			return newB(false), nil
		}
		// Other scalar dtypes fall through to the type-based construction below.
	}

	if _, isTensor := t.(TensorType); isTensor {
		return tensor.New(tensor.Of(dt), tensor.WithShape(s...)), nil
	}
	return nil, errors.Errorf(nyiTypeFail, "MakeValue", t)
}
// makeValueFromMem builds a Value of type t and shape s on top of the memory
// described by mem.
//
// For a scalar shape, or when t is a plain tensor.Dtype, the result comes from
// makeScalarFromMem and therefore points directly at mem's address. When t is a
// TensorType, a tensor is constructed via tensor.FromMemory using mem's address
// and a byte length computed by calcMemSize(dt, s).
//
// An error is returned when:
//   - the dtype of t cannot be determined;
//   - mem is nil;
//   - t is neither a TensorType nor a tensor.Dtype;
//   - makeScalarFromMem rejects the dtype.
func makeValueFromMem(t hm.Type, s tensor.Shape, mem tensor.Memory) (retVal Value, err error) {
	var dt tensor.Dtype
	if dt, err = dtypeOf(t); err != nil {
		return nil, err
	}
	// Every construction path below calls mem.Uintptr(); report a nil memory
	// as an error rather than crashing with a nil-pointer dereference.
	if mem == nil {
		return nil, errors.New("makeValueFromMem: nil memory")
	}
	if s.IsScalar() {
		return makeScalarFromMem(dt, mem)
	}
	switch tt := t.(type) {
	case TensorType:
		memsize := calcMemSize(dt, s)
		return tensor.New(tensor.Of(dt), tensor.WithShape(s...), tensor.FromMemory(mem.Uintptr(), uintptr(memsize))), nil
	case tensor.Dtype:
		// NOTE(review): a plain Dtype with a non-scalar shape still yields a
		// scalar here; confirm this is intended.
		return makeScalarFromMem(tt, mem)
	default:
		return nil, errors.Errorf(nyiTypeFail, "makeValueFromMem", tt)
	}
}
// makeScalarFromMem reinterprets the address reported by mem.Uintptr() as a
// pointer to the scalar type matching dt (F64, F32, I, I64, I32, U8 or B).
// The returned Value points at that address; nothing is copied.
// Unsupported dtypes produce a not-yet-implemented error, and in that case mem
// is never touched.
//
// NOTE(review): converting a uintptr back to unsafe.Pointer is only sound while
// the memory behind mem stays alive and does not move; callers must guarantee that.
func makeScalarFromMem(dt tensor.Dtype, mem tensor.Memory) (retVal Value, err error) {
	switch dt {
	case tensor.Float64:
		return (*F64)(unsafe.Pointer(mem.Uintptr())), nil
	case tensor.Float32:
		return (*F32)(unsafe.Pointer(mem.Uintptr())), nil
	case tensor.Int:
		return (*I)(unsafe.Pointer(mem.Uintptr())), nil
	case tensor.Int64:
		return (*I64)(unsafe.Pointer(mem.Uintptr())), nil
	case tensor.Int32:
		return (*I32)(unsafe.Pointer(mem.Uintptr())), nil
	case tensor.Byte:
		return (*U8)(unsafe.Pointer(mem.Uintptr())), nil
	case tensor.Bool:
		return (*B)(unsafe.Pointer(mem.Uintptr())), nil
	}
	return nil, errors.Errorf(nyiTypeFail, "makeScalarFromMem", dt)
}
// logicalSize reports how many elements a value of shape s holds: 1 for a
// scalar shape, s.TotalSize() for anything else.
func logicalSize(s tensor.Shape) int {
	n := 1
	if !s.IsScalar() {
		n = s.TotalSize()
	}
	return n
}
// calcMemSize returns the number of bytes needed for a value of dtype dt and
// shape s. This is the element count (1 for a scalar shape, s.TotalSize()
// otherwise) multiplied by dt.Size().
func calcMemSize(dt tensor.Dtype, s tensor.Shape) int64 {
	count := int64(1)
	if !s.IsScalar() {
		count = int64(s.TotalSize())
	}
	return count * int64(dt.Size())
}
// ScalarAsTensor returns the tensor representation of a scalar. It is
// particularly useful as a sort of "reshape" for tensors.
//
// v must be a Scalar, a tensor.Tensor, a *dualValue or nil; any other Value
// causes a panic. The cases are handled as follows:
//   - Scalar: becomes a tensor of rank dims with every dimension equal to 1. It
//     is built with tensor.FromMemory over the scalar's own address and size,
//     and uses engine e.
//   - tensor.Tensor: returned unchanged.
//   - *dualValue: yields a new *dualValue whose value and derivative have each
//     been passed through ScalarAsTensor.
//   - nil: yields nil.
func ScalarAsTensor(v Value, dims int, e tensor.Engine) Value {
	if v == nil {
		return nil
	}

	if sc, ok := v.(Scalar); ok {
		shape := make(tensor.Shape, dims)
		for i := 0; i < dims; i++ {
			shape[i] = 1
		}
		return tensor.New(tensor.WithShape(shape...), tensor.Of(sc.Dtype()), tensor.FromMemory(sc.Uintptr(), sc.MemSize()), tensor.WithEngine(e))
	}

	if t, ok := v.(tensor.Tensor); ok {
		return t
	}

	if dv, ok := v.(*dualValue); ok {
		out := new(dualValue)
		out.Value = ScalarAsTensor(dv.Value, dims, e)
		out.d = ScalarAsTensor(dv.d, dims, e)
		return out
	}

	panic(fmt.Sprintf("Unable to convert %v to Tensor", v))
}