-
Notifications
You must be signed in to change notification settings - Fork 1
/
arrayops.go
49 lines (39 loc) · 1.13 KB
/
arrayops.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
package data
//go:generate genny -in=$GOFILE -out=gen-$GOFILE gen "ArrayType=float64,float32,int32,uint32,int64,uint64"
// ApplyFunc1ArrayType stores fn(x) into dest for every element x of source.
//
// Fast path: when both arrays report Contiguous, it walks the slices returned
// by Unroll directly. The loop runs over dest's length and reads source at the
// same positions.
//
// General path: it visits every position of dest's shape. One index is created
// with dest.NewIndex(0), used to read source and write dest, and advanced with
// Increment once per element, for Product(shape) elements in total.
//
// NOTE(review): no shape or length check is performed; dest and source
// presumably must have identical shapes — confirm with callers.
func ApplyFunc1ArrayType(dest, source NDArrayType, fn func(val ArrayType) ArrayType) {
	if !(dest.Contiguous() && source.Contiguous()) {
		idx := dest.NewIndex(0)
		shape := dest.Shape()
		for remaining := Product(shape); remaining > 0; remaining-- {
			dest.Set(idx, fn(source.Get(idx)))
			Increment(idx, shape)
		}
		return
	}

	dst, src := dest.Unroll(), source.Unroll()
	for k := range dst {
		dst[k] = fn(src[k])
	}
}
// ScaleArrayTypeArray writes source multiplied by scale into dest, element by
// element. Traversal (contiguous fast path or indexed walk) is delegated to
// ApplyFunc1ArrayType.
func ScaleArrayTypeArray(dest, source NDArrayType, scale ArrayType) {
	multiply := func(v ArrayType) ArrayType {
		return v * scale
	}
	ApplyFunc1ArrayType(dest, source, multiply)
}
// AddToArrayTypeArray accumulates source into dest in place, so that each
// element becomes dest[i] + source[i].
//
// Fast path: when both arrays report Contiguous, it adds across the slices
// returned by Unroll, iterating over dest's length.
//
// General path: it walks every position of dest's shape with one index (from
// dest.NewIndex(0), advanced by Increment), reading both arrays and writing
// dest at that index.
//
// NOTE(review): dest and source are presumably expected to share a shape;
// nothing here verifies it — confirm with callers.
func AddToArrayTypeArray(dest, source NDArrayType) {
	if !dest.Contiguous() || !source.Contiguous() {
		idx := dest.NewIndex(0)
		shape := dest.Shape()
		for remaining := Product(shape); remaining > 0; remaining-- {
			sum := dest.Get(idx) + source.Get(idx)
			dest.Set(idx, sum)
			Increment(idx, shape)
		}
		return
	}

	dst, src := dest.Unroll(), source.Unroll()
	for k, cur := range dst {
		dst[k] = cur + src[k]
	}
}