/
type.go
148 lines (125 loc) · 4.01 KB
/
type.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package gorgonia
import (
"fmt"
"github.com/chewxy/hm"
"gorgonia.org/tensor"
)
// Dtypes that Nodes in a Gorgonia graph can take. Each exported name below is
// simply the corresponding tensor.Dtype value re-exported from this package.
var (
	// Float64 is tensor.Float64, the float64 element type.
	Float64 = tensor.Float64
	// Float32 is tensor.Float32, the float32 element type.
	Float32 = tensor.Float32
	// Int is tensor.Int, the int element type.
	Int = tensor.Int
	// Int64 is tensor.Int64, the int64 element type.
	Int64 = tensor.Int64
	// Int32 is tensor.Int32, the int32 element type.
	Int32 = tensor.Int32
	// Byte is tensor.Uint8, the uint8 (byte) element type.
	Byte = tensor.Uint8
	// Bool is tensor.Bool, the bool element type.
	Bool = tensor.Bool
	// Ptr is tensor.UnsafePointer, the raw-pointer element type. It is an
	// escape hatch for values that fit none of the dtypes above.
	Ptr = tensor.UnsafePointer
)

// Shared *TensorType values for the common float vector (1-D), matrix (2-D)
// and rank-3 tensor types. newTensorType returns these instead of borrowing a
// fresh *TensorType, so they must be treated as read-only.
var (
	vecF64  = &TensorType{Dims: 1, Of: tensor.Float64}
	vecF32  = &TensorType{Dims: 1, Of: tensor.Float32}
	matF64  = &TensorType{Dims: 2, Of: tensor.Float64}
	matF32  = &TensorType{Dims: 2, Of: tensor.Float32}
	ten3F64 = &TensorType{Dims: 3, Of: tensor.Float64}
	ten3F32 = &TensorType{Dims: 3, Of: tensor.Float32}
)

// f64T and f32T are shorthands for the two float dtypes. They let code compare
// an hm.Type directly against a dtype (typ == f64T) without a type assertion.
// Their static type is tensor.Dtype.
var (
	f64T = tensor.Float64
	f32T = tensor.Float32
)

// acceptableDtypes lists the float, integer, byte and bool dtypes.
// NOTE(review): its consumers are outside this view; presumably this is the set
// of element types Gorgonia nodes are allowed to have. Confirm against callers.
var acceptableDtypes = [...]tensor.Dtype{tensor.Float64, tensor.Float32, tensor.Int, tensor.Int64, tensor.Int32, tensor.Byte, tensor.Bool}
/*Tensor Type*/

// TensorType is a type constructor for tensors.
//
// Think of it as something like this:
//   data Tensor a = Tensor d a
//
// where d is the number of dimensions and a is the element type.
//
// The shape of the Tensor is not part of TensorType.
// Shape checking is relegated to the dynamic part of the program run.
type TensorType struct {
	Dims int // number of dimensions; Format prints 1 as "Vector" and 2 as "Matrix"
	Of hm.Type // element type: a dtype such as Float64, or an hm type variable during inference
}
// makeFromTensorType returns a TensorType value with the same number of
// dimensions as t, but with its element type replaced by the type variable tv.
// t itself is not modified.
func makeFromTensorType(t TensorType, tv hm.TypeVariable) TensorType {
	return TensorType{Dims: t.Dims, Of: tv}
}
// makeTensorType builds a TensorType value (not a pointer) with the given
// number of dimensions and element type.
func makeTensorType(dims int, typ hm.Type) TensorType {
	var tt TensorType
	tt.Dims = dims
	tt.Of = typ
	return tt
}
// newTensorType returns a *TensorType with the given number of dimensions and
// element type.
//
// For Float64/Float32 elements with 1, 2 or 3 dimensions it returns one of the
// package-level shared values (vecF64, matF64, ten3F64, vecF32, matF32,
// ten3F32). Every other combination is obtained from borrowTensorType and
// filled in.
//
// NOTE(review): because the shared values are handed out directly, callers must
// not mutate them or give them back to the TensorType pool. Confirm that
// whatever returns borrowed types guards against this.
func newTensorType(dims int, typ hm.Type) *TensorType {
	switch typ {
	case f64T:
		switch dims {
		case 1:
			return vecF64
		case 2:
			return matF64
		case 3:
			return ten3F64
		}
	case f32T:
		switch dims {
		case 1:
			return vecF32
		case 2:
			return matF32
		case 3:
			return ten3F32
		}
	}

	// Not a cached combination: hand out a pooled value instead.
	fresh := borrowTensorType()
	fresh.Dims, fresh.Of = dims, typ
	return fresh
}
// Name reports the type constructor's name. It is "Tensor" regardless of the
// number of dimensions or the element type. Satisfies the hm.Type interface.
func (t TensorType) Name() string {
	const name = "Tensor"
	return name
}
// Format implements fmt.Formatter, which is also needed to satisfy the hm.Type
// interface. The verb c is ignored; only the '#' flag changes the output:
//
//	%#v -> "Tensor-<dims> <elem in %#v>"
//	%v  -> "Vector <elem>" for 1 dim, "Matrix <elem>" for 2 dims,
//	       otherwise "Tensor-<dims> <elem>"
func (t TensorType) Format(state fmt.State, c rune) {
	switch {
	case state.Flag('#'):
		fmt.Fprintf(state, "Tensor-%d %#v", t.Dims, t.Of)
	case t.Dims == 1:
		fmt.Fprintf(state, "Vector %v", t.Of)
	case t.Dims == 2:
		fmt.Fprintf(state, "Matrix %v", t.Of)
	default:
		fmt.Fprintf(state, "Tensor-%d %v", t.Dims, t.Of)
	}
}
// String renders the type through its Format method with the plain %v verb
// (e.g. "Matrix float64"). It implements fmt.Stringer and satisfies the
// hm.Type interface.
func (t TensorType) String() string {
	return fmt.Sprint(t)
}
// Types returns the types this TensorType is built from: a single-element
// hm.Types holding the element type t.Of (float64, float32, ...).
// Satisfies the hm.Type interface.
//
// NOTE(review): the slice comes from hm.BorrowTypes and is presumably pooled.
// Confirm whether callers are expected to hand it back via hm.ReturnTypes.
func (t TensorType) Types() hm.Types {
	contained := hm.BorrowTypes(1)
	contained[0] = t.Of
	return contained
}
// Normalize returns a copy of t whose element type has had its type-variable
// names normalized using the sets k and v. Any error from normalizing the
// element type is returned unchanged, with a nil type. Satisfies the hm.Type
// interface.
func (t TensorType) Normalize(k, v hm.TypeVarSet) (hm.Type, error) {
	normalized, err := t.Of.Normalize(k, v)
	if err != nil {
		return nil, err
	}
	t.Of = normalized
	return t, nil
}
// Apply returns a copy of t with the substitutions sub applied to its element
// type. The receiver is a value, so the original is left untouched.
// Satisfies the hm.Type interface.
//
// The result of t.Of.Apply is asserted to be an hm.Type, so this panics if the
// element type's Apply ever yields a non-Type Substitutable.
func (t TensorType) Apply(sub hm.Subs) hm.Substitutable {
	applied := t.Of.Apply(sub)
	t.Of = applied.(hm.Type)
	return t
}
// FreeTypeVar reports the free (unbound) type variables of this type. These are
// exactly those of its element type, since the number of dimensions is not a
// type. Satisfies the hm.Type interface.
func (t TensorType) FreeTypeVar() hm.TypeVarSet {
	elem := t.Of
	return elem.FreeTypeVar()
}
// Eq is the equality function of this type. other is equal when it is a
// TensorType or *TensorType with the same number of dimensions and an equal
// element type. Shape is not part of the type and is not compared for now; it
// may be in the future for tighter type inference. A nil *TensorType is never
// equal. Satisfies the hm.Type interface.
func (t TensorType) Eq(other hm.Type) bool {
	var ot TensorType
	switch o := other.(type) {
	case TensorType:
		ot = o
	case *TensorType:
		// A typed nil pointer stored in the interface would otherwise panic on o.Of.
		if o == nil {
			return false
		}
		ot = *o
	default:
		return false
	}
	// Compare the cheap integer first so rank mismatches skip the (possibly
	// recursive) element-type comparison.
	return t.Dims == ot.Dims && t.Of.Eq(ot.Of)
}