-
Notifications
You must be signed in to change notification settings - Fork 2
/
shape.go
305 lines (284 loc) · 9.07 KB
/
shape.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
/*
* Copyright 2023 Jan Pfeifer
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Package shapes defines Shape and DType and associated tools.
//
// Shape represents the shape (rank, dimensions and DType) of either a Tensor or the expected
// shape of a node in a computation Graph. DType indicates the type of the unit element of
// a Tensor (or its representation as a node in a computation Graph).
//
// Shape and DType are used both by the concrete tensor values (see tensor package) and when
// working on the computation graph (see graph package).
//
// ## Glossary
//
// - Rank: number of axes (dimensions) of a Tensor.
// - Axis: is the index of a dimension on a multi-dimensional Tensor. Sometimes used
// interchangeably with Dimension, but here we try to refer to a dimension index as "axis"
// (plural axes), and its size as its dimension.
// - Dimension: the size of a multi-dimensional Tensor in one of its axes. See example below.
// - DType: the data type of the unit element in a tensor.
// - Scalar: is a shape where there are no axes (or dimensions), only a single value
// of the associated DType.
//
// Example: The multi-dimensional array `[][]int32{{0, 1, 2}, {3, 4, 5}}` if converted to a Tensor
// would have shape `(int32)[2 3]`. We say it has rank 2 (so 2 axes), axis 0 has
// dimension 2, and axis 1 has dimension 3. This shape could be created with
// `shapes.Make(int32, 2, 3)`.
//
// ## Asserts
//
// When coding ML models, one delicate part is keeping tabs on the shape of
// the nodes of the graphs -- unfortunately there is no compile-time checking of values,
// so validation only happens in runtime. To facilitate, and also to serve as code documentation,
// this package provides two variations of _assert_ functionality. Examples:
//
// `AssertRank` and `AssertDims` check that the rank and dimensions of the given
// object (that has a `Shape` method) match, otherwise they panic. The `-1` means
// the dimension is unchecked (it can be anything).
//
// ```
//
// func modelGraph(ctx *context.Context, spec any, inputs []*Node) ([]*Node) {
// _ = spec // Not needed here, we know the dataset.
// shapes.AssertRank(inputs[0], 2)
// batchSize := inputs[0].Shape().Dimensions[0]
// logits := layers.Dense(ctx, inputs[0], /* useBias= */ true, /* outputDim= */ 1)
// shapes.AssertDims(logits, batchSize, -1)
// return []*Node{logits}
// }
//
// ```
//
// If you don't want to panic, but instead return an error through the `graph.Graph`, you can
// use the `Node.AssertDims()` method. So it would look like `logits.AssertDims(batchSize, -1)`.
package shapes
import (
"encoding/gob"
"fmt"
"github.com/gomlx/gomlx/types/exceptions"
"github.com/gomlx/gomlx/types/slices"
"github.com/pkg/errors"
"reflect"
"strings"
)
// Shape represents the shape of either a Tensor or the expected shape
// of the value from a computation node.
//
// A Shape is either a regular shape (DType plus Dimensions) or a tuple
// (DType set to Tuple and the element shapes listed in TupleShapes).
// The zero value Shape{} is not valid: Ok reports false for it.
//
// Use Make to create a new shape. See example in package shapes documentation.
type Shape struct {
// DType is the data type of the unit element, or Tuple for tuple shapes.
DType DType
// Dimensions holds the size of each axis; its length is the rank of the shape.
// It is empty for scalars.
Dimensions []int
TupleShapes []Shape // Shapes of the tuple, if this is a tuple.
}
// Make builds a Shape with the given dtype and one axis per entry in dimensions.
//
// The dimensions are copied, so the caller may reuse its slice afterwards.
// Every dimension must be strictly positive, otherwise Make panics
// (through exceptions.Panicf).
func Make(dtype DType, dimensions ...int) Shape {
	s := Shape{DType: dtype, Dimensions: slices.Copy(dimensions)}
	for axis := 0; axis < len(dimensions); axis++ {
		if dimensions[axis] > 0 {
			continue
		}
		exceptions.Panicf("shapes.Make(%s): cannot create a shape with an axis with dimension <= 0", s)
	}
	return s
}
// Scalar builds the shape of a single value (no axes) whose DType is the one
// associated with the Go type T.
func Scalar[T Number]() Shape {
	var scalar Shape
	scalar.DType = DTypeGeneric[T]()
	return scalar
}
// Ok reports whether the Shape is valid, that is, whether its DType is set to
// something other than InvalidDType. A "zero" shape, created with just
// Shape{}, is invalid.
func (s Shape) Ok() bool {
	if s.DType == InvalidDType {
		return false
	}
	return true
}
// Rank returns the number of axes of the shape, which is the length of
// its Dimensions.
func (s Shape) Rank() int {
	rank := len(s.Dimensions)
	return rank
}
// IsScalar reports whether the shape is a valid shape with no axes (rank==0).
// Invalid shapes (see Ok) are never considered scalars.
func (s Shape) IsScalar() bool {
	if !s.Ok() {
		return false
	}
	return s.Rank() == 0
}
// Shape implements the HasShape interface by returning the receiver itself.
//
// The returned value is a shallow copy: it shares the Dimensions and
// TupleShapes slices with the receiver. Use Copy for an independent value.
func (s Shape) Shape() Shape {
	return s
}
// String implements fmt.Stringer and pretty-prints the shape.
//
// Tuples with at least one element are rendered as "Tuple<elem, elem, ...>",
// shapes without axes as "(dtype)[]", and everything else as "(dtype)[d0 d1 ...]".
func (s Shape) String() string {
	numElements := s.TupleSize()
	switch {
	case numElements > 0:
		elements := make([]string, numElements)
		for idx, element := range s.TupleShapes {
			elements[idx] = element.String()
		}
		return fmt.Sprintf("Tuple<%s>", strings.Join(elements, ", "))
	case s.Rank() == 0:
		return fmt.Sprintf("(%s)[]", s.DType)
	default:
		return fmt.Sprintf("(%s)%v", s.DType, s.Dimensions)
	}
}
// Size returns the number of elements of DType needed for this shape.
// It is the product of all dimensions, so a shape without axes has size 1.
func (s Shape) Size() (size int) {
	size = 1
	for axis := range s.Dimensions {
		size *= s.Dimensions[axis]
	}
	return size
}
// Memory returns the number of bytes that would be used in Go to store data
// of this shape: the per-element memory of the DType times Size().
// The actual memory may depend on the device implementation in some cases (e.g. bool).
func (s Shape) Memory() int64 {
	bytesPerElement := s.DType.Memory()
	return bytesPerElement * int64(s.Size())
}
// MakeTuple builds a tuple shape whose elements have the given shapes.
// The elements slice is stored as is (not copied).
func MakeTuple(elements []Shape) Shape {
	var tuple Shape
	tuple.DType = Tuple
	tuple.TupleShapes = elements
	return tuple
}
// IsTuple reports whether the shape is a tuple, that is, whether its DType is Tuple.
func (s Shape) IsTuple() bool {
	isTuple := s.DType == Tuple
	return isTuple
}
// TupleSize returns how many element shapes are stored in TupleShapes.
// It is 0 for shapes that are not tuples.
func (s Shape) TupleSize() int {
	numElements := len(s.TupleShapes)
	return numElements
}
// Eq reports whether the two shapes are equal: they must have the same DType
// and the same dimensions. Tuples are equal when they have the same number of
// elements and all corresponding elements are themselves Eq.
func (s Shape) Eq(s2 Shape) bool {
	switch {
	case s.DType != s2.DType:
		return false
	case s.DType == Tuple:
		// Tuples: compare element by element, recursively.
		if s.TupleSize() != s2.TupleSize() {
			return false
		}
		for idx := range s.TupleShapes {
			if !s.TupleShapes[idx].Eq(s2.TupleShapes[idx]) {
				return false
			}
		}
		return true
	case s.Rank() != s2.Rank():
		return false
	case s.IsScalar():
		// Same dtype and both without axes: nothing else to compare.
		return true
	}
	// Regular shapes of the same rank: the dimensions decide.
	return reflect.DeepEqual(s.Dimensions, s2.Dimensions)
}
// EqDimensions reports whether the two shapes have the same dimensions,
// ignoring their DTypes.
//
// A tuple only matches another tuple with the same number of elements whose
// corresponding elements are EqDimensions, recursively. A tuple never matches
// a non-tuple shape, whichever of the two is the receiver.
//
// Two non-tuple shapes match when they have the same rank and the same size
// in every axis. In particular, any two non-tuple shapes of rank 0 match,
// regardless of whether their Dimensions slice is nil or empty.
func (s Shape) EqDimensions(s2 Shape) bool {
	sIsTuple, s2IsTuple := s.DType == Tuple, s2.DType == Tuple
	if sIsTuple != s2IsTuple {
		// Tuples and non-tuples both have rank 0, so without this check a scalar
		// receiver would wrongly match a tuple argument.
		return false
	}
	if sIsTuple {
		if s.TupleSize() != s2.TupleSize() {
			return false
		}
		for ii, element := range s.TupleShapes {
			if !element.EqDimensions(s2.TupleShapes[ii]) {
				return false
			}
		}
		return true
	}
	if s.Rank() != s2.Rank() {
		return false
	}
	if s.Rank() == 0 {
		// No axes on either side. Returning here also avoids reflect.DeepEqual
		// treating a nil and an empty Dimensions slice as different.
		return true
	}
	// For normal shapes just compare dimensions.
	return reflect.DeepEqual(s.Dimensions, s2.Dimensions)
}
// Copy makes a deep copy of the shape.
//
// The returned shape shares no memory with the receiver: Dimensions is copied
// into a newly allocated slice, and for tuples every element shape is itself
// deep-copied (recursively), so changing the copy never affects the original.
func (s Shape) Copy() (s2 Shape) {
	s2.DType = s.DType
	s2.Dimensions = make([]int, len(s.Dimensions))
	copy(s2.Dimensions, s.Dimensions)
	if len(s.TupleShapes) > 0 {
		s2.TupleShapes = make([]Shape, len(s.TupleShapes))
		for ii, subShape := range s.TupleShapes {
			// Recurse: assigning subShape directly would keep its Dimensions and
			// nested TupleShapes slices aliased to the original.
			s2.TupleShapes[ii] = subShape.Copy()
		}
	}
	return s2
}
// GobSerialize writes the shape to encoder in binary (gob) format.
//
// It encodes, in order: the DType, the Dimensions, the number of tuple
// elements, and then each tuple element recursively with the same layout.
// It stops at the first failure and returns the corresponding error.
func (s Shape) GobSerialize(encoder *gob.Encoder) (err error) {
	header := []any{s.DType, s.Dimensions, len(s.TupleShapes)}
	for _, field := range header {
		if err = encoder.Encode(field); err != nil {
			return errors.Wrapf(err, "failed to serialize Shape %s", s)
		}
	}
	for ii := range s.TupleShapes {
		if err = s.TupleShapes[ii].GobSerialize(encoder); err != nil {
			return err
		}
	}
	return nil
}
// GobDeserialize reads a Shape from decoder, in the format written by
// Shape.GobSerialize: DType, Dimensions, number of tuple elements, followed by
// each tuple element (recursively).
//
// It returns the new Shape or an error. An error is returned if decoding any
// of the fields fails, or if the decoded number of tuple elements is negative
// (corrupted or malicious input). When an error is returned, the Shape value
// should be ignored.
func GobDeserialize(decoder *gob.Decoder) (s Shape, err error) {
	// dec decodes the next value into data, unless an earlier step already failed.
	dec := func(data any) {
		if err != nil {
			return
		}
		err = decoder.Decode(data)
		if err != nil {
			err = errors.Wrapf(err, "failed to deserialize Shape")
		}
	}
	dec(&s.DType)
	dec(&s.Dimensions)
	var numTuples int
	dec(&numTuples)
	if err != nil {
		return
	}
	if numTuples < 0 {
		// A negative length would make the allocation below panic.
		return Shape{}, errors.Errorf("failed to deserialize Shape: invalid number of tuple elements %d", numTuples)
	}
	s.TupleShapes = make([]Shape, numTuples)
	for ii := range s.TupleShapes {
		s.TupleShapes[ii], err = GobDeserialize(decoder)
		if err != nil {
			return
		}
	}
	return
}
// ConcatenateDimensions returns a shape whose dimensions are those of s1
// followed by those of s2, so the resulting rank is the sum of both ranks.
//
// Both shapes must be valid, have the same dtype and not be tuples; otherwise
// the zero (invalid) Shape is returned. If either of them is a scalar, the
// result is a copy of the other one.
func ConcatenateDimensions(s1, s2 Shape) (shape Shape) {
	switch {
	case s1.IsTuple() || s2.IsTuple():
		return
	case s1.DType == InvalidDType || s2.DType == InvalidDType:
		return
	case s1.DType != s2.DType:
		return
	case s1.IsScalar():
		return s2.Copy()
	case s2.IsScalar():
		return s1.Copy()
	}
	joined := make([]int, 0, s1.Rank()+s2.Rank())
	joined = append(joined, s1.Dimensions...)
	joined = append(joined, s2.Dimensions...)
	return Shape{DType: s1.DType, Dimensions: joined}
}