Skip to content

Commit

Permalink
Merge e599f39 into 3081441
Browse files Browse the repository at this point in the history
  • Loading branch information
chewxy committed Jul 3, 2017
2 parents 3081441 + e599f39 commit 6370c2c
Show file tree
Hide file tree
Showing 68 changed files with 8,533 additions and 6,805 deletions.
30 changes: 21 additions & 9 deletions tensor/ap.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ type AP struct {
strides []int // strides is usually calculated from shape
fin bool // is this struct change-proof?

// future stuff
// triangle byte // up = 0xf0; down = 0x0f; symmetric = 0xff; not a triangle = 0x00
o DataOrder
Δ Triangle
}

// NewAP creates a new AP, given the shape and strides
Expand Down Expand Up @@ -60,14 +60,10 @@ func (ap *AP) SetShape(s ...int) {
ap.strides = nil
}
ap.shape = Shape(s).Clone()
ap.strides = ap.shape.calcStrides()
ap.strides = ap.calcStrides()
}
}

// lock and unlock toggle the fin flag, which marks the AP as change-proof so that its shape and strides
// are not supposed to change. This is not really safe: a direct mutation of the strides/shape would still
// mutate them, but at least the dimensions cannot change.
func (ap *AP) lock() { ap.fin = true }
func (ap *AP) unlock() { ap.fin = false }

// Shape returns the shape of the AP. The stored shape is returned as is; it is not cloned here.
func (ap *AP) Shape() Shape { return ap.shape }

Expand Down Expand Up @@ -121,12 +117,12 @@ func (ap *AP) Clone() (retVal *AP) {

// C returns true if the access pattern is C-contiguous array
func (ap *AP) C() bool {
return ap.strides[len(ap.strides)-1] == 1
return ap.o.isRowMajor() && ap.o.isContiguous()
}

// F returns true if the access pattern is Fortran contiguous array
func (ap *AP) F() bool {
return ap.strides[0] == 1
return ap.o.isColMajor() && ap.o.isContiguous()
}

// S returns the metadata of the sliced tensor.
Expand Down Expand Up @@ -161,6 +157,7 @@ func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err e

var start, end, step int
if start, end, step, err = SliceDetails(sl, size); err != nil {
err = errors.Wrapf(err, "Unable to get slice details on slice %d with size %d: %v", i, sl, size)
return
}

Expand Down Expand Up @@ -204,6 +201,7 @@ func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err e
}

newAP = NewAP(newShape, newStrides)
newAP.o = MakeDataOrder(ap.o, NonContiguous)
}
return
}
Expand Down Expand Up @@ -265,6 +263,20 @@ func (ap *AP) T(axes ...int) (retVal *AP, a []int, err error) {
return
}

// lock and unlock toggle the fin flag, which marks the AP as change-proof so that its shape and strides
// are not supposed to change. This is not really safe: a direct mutation of the strides/shape would still
// mutate them, but at least the dimensions cannot change.
func (ap *AP) lock() { ap.fin = true }
func (ap *AP) unlock() { ap.fin = false }

// calcStrides computes the strides for ap.shape according to the data order ap.o.
// A row-major order uses shape.calcStrides; a column-major order uses
// shape.calcStridesColMajor. It panics if ap.o reports neither order.
func (ap *AP) calcStrides() []int {
	if ap.o.isRowMajor() {
		return ap.shape.calcStrides()
	}
	if ap.o.isColMajor() {
		return ap.shape.calcStridesColMajor()
	}
	panic("unreachable")
}

// TransposeIndex returns the new index given the old index
func TransposeIndex(i int, oldShape, pattern, oldStrides, newStrides []int) int {
oldCoord, err := Itol(i, oldShape, oldStrides)
Expand Down
6 changes: 3 additions & 3 deletions tensor/api_arith.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
return
}

fo := parseFuncOpts(opts...)
fo := ParseFuncOpts(opts...)

var reuse, incr *Dense
if reuse, err = getFloatDense(fo.reuse); err != nil {
Expand All @@ -163,9 +163,9 @@ func Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
var res interface{}
switch a.t.Kind() {
case reflect.Float64:
res = a.getF64(0) * b.getF64(0)
res = a.GetF64(0) * b.GetF64(0)
case reflect.Float32:
res = a.getF32(0) * b.getF32(0)
res = a.GetF32(0) * b.GetF32(0)
}

switch {
Expand Down
14 changes: 14 additions & 0 deletions tensor/api_matop.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@ func T(t Tensor, axes ...int) (retVal Tensor, err error) {
panic("Unreachable")
}

// Transpose returns a transposed version of t for the given axes.
//
// For a *Dense it takes the result of SafeT(axes...) and then calls Transpose on that
// result before returning it (presumably so that the underlying data is rearranged
// rather than only the access pattern — TODO confirm against Dense.Transpose).
// An error from SafeT is returned unchanged, with a nil retVal.
// Any Tensor that is not a *Dense causes a panic.
func Transpose(t Tensor, axes ...int) (retVal Tensor, err error) {
	dense, ok := t.(*Dense)
	if !ok {
		panic("Unreachable")
	}

	var transposed *Dense
	if transposed, err = dense.SafeT(axes...); err != nil {
		return
	}
	transposed.Transpose()
	retVal = transposed
	return
}

// Concat concatenates a list of Tensors. At the moment the operation only supports Tensors of the same type
// (*Dense can only be concatenated with a bunch of *Dense, CSCs can only be concatenated with a bunch of CSC, etc)
func Concat(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) {
Expand Down
32 changes: 16 additions & 16 deletions tensor/api_unary.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,36 +90,36 @@ func Sqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
cloned := t.Clone().(*Dense)
switch t.t.Kind() {
case reflect.Float64:
vecf64.Sqrt(cloned.float64s())
vecf64.Sqrt(cloned.Float64s())
case reflect.Float32:
vecf32.Sqrt(cloned.float32s())
vecf32.Sqrt(cloned.Float32s())
}
_, err = reuse.Add(cloned, UseUnsafe())
retVal = reuse
case toReuse:
copyDense(reuse, t)
switch t.t.Kind() {
case reflect.Float64:
vecf64.Sqrt(reuse.float64s())
vecf64.Sqrt(reuse.Float64s())
case reflect.Float32:
vecf32.Sqrt(reuse.float32s())
vecf32.Sqrt(reuse.Float32s())
}
retVal = reuse
case safe:
cloned := t.Clone().(*Dense)
switch t.t.Kind() {
case reflect.Float64:
vecf64.Sqrt(cloned.float64s())
vecf64.Sqrt(cloned.Float64s())
case reflect.Float32:
vecf32.Sqrt(cloned.float32s())
vecf32.Sqrt(cloned.Float32s())
}
retVal = cloned
case !safe:
switch t.t.Kind() {
case reflect.Float64:
vecf64.Sqrt(t.float64s())
vecf64.Sqrt(t.Float64s())
case reflect.Float32:
vecf32.Sqrt(t.float32s())
vecf32.Sqrt(t.Float32s())
}
retVal = t
}
Expand Down Expand Up @@ -164,36 +164,36 @@ func InvSqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
cloned := t.Clone().(*Dense)
switch t.t.Kind() {
case reflect.Float64:
vecf64.InvSqrt(cloned.float64s())
vecf64.InvSqrt(cloned.Float64s())
case reflect.Float32:
vecf32.InvSqrt(cloned.float32s())
vecf32.InvSqrt(cloned.Float32s())
}
_, err = reuse.Add(cloned, UseUnsafe())
retVal = reuse
case toReuse:
copyDense(reuse, t)
switch t.t.Kind() {
case reflect.Float64:
vecf64.InvSqrt(reuse.float64s())
vecf64.InvSqrt(reuse.Float64s())
case reflect.Float32:
vecf32.InvSqrt(reuse.float32s())
vecf32.InvSqrt(reuse.Float32s())
}
retVal = reuse
case safe:
cloned := t.Clone().(*Dense)
switch t.t.Kind() {
case reflect.Float64:
vecf64.InvSqrt(cloned.float64s())
vecf64.InvSqrt(cloned.Float64s())
case reflect.Float32:
vecf32.InvSqrt(cloned.float32s())
vecf32.InvSqrt(cloned.Float32s())
}
retVal = cloned
case !safe:
switch t.t.Kind() {
case reflect.Float64:
vecf64.InvSqrt(t.float64s())
vecf64.InvSqrt(t.Float64s())
case reflect.Float32:
vecf32.InvSqrt(t.float32s())
vecf32.InvSqrt(t.Float32s())
}
retVal = t
}
Expand Down
Loading

0 comments on commit 6370c2c

Please sign in to comment.