added batchedMatMul. Along the way, checkErrSetDeriv was added as a helper method

Haven't done the Do part
chewxy committed Sep 18, 2017
1 parent 1c73fa2 commit cb505b66587fb1611c3e954f467b079b5330277c
Showing with 288 additions and 96 deletions.
  1. +8 −0 errors.go
  2. +26 −0 op_math.go
  3. +35 −0 op_nn.go
  4. +195 −95 operatorLinAlg.go
  5. +9 −0 operatorLinAlg_const.go
  6. +15 −1 shape.go
errors.go
@@ -69,3 +69,11 @@ func nyi(what string, implFor interface{}) error {
func nondiffErr(op Op) error {
return errors.Errorf("%s is a non-differentiable function", op)
}

// checkErrSetDeriv sets the derivative from the error's carried value if the error implements Valuer. Helper function for linalg operations.
func checkErrSetDeriv(err error, dv *dualValue) error {
if ver, ok := err.(Valuer); ok {
return dv.SetDeriv(ver.Value())
}
return err
}
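For context on what the new helper is doing: checkErrSetDeriv leans on the pattern of an error that also implements Valuer, i.e. a failure that still carries a usable result. The sketch below is a self-contained illustration of that pattern only; dualValue, SetDeriv and gorgonia's Value type are replaced with a plain float64 stand-in, so none of it is the library's actual API.

```go
package main

import "fmt"

// Valuer mirrors the interface checkErrSetDeriv asserts against: something
// that can still hand back a value.
type Valuer interface{ Value() float64 }

// valuedError is a hypothetical error that carries a usable result along
// with the failure message.
type valuedError struct {
	msg string
	val float64
}

func (e valuedError) Error() string  { return e.msg }
func (e valuedError) Value() float64 { return e.val }

// setDerivOrFail follows the same shape as checkErrSetDeriv: if the error
// carries a value, store it as the derivative and swallow the error;
// otherwise propagate the error unchanged.
func setDerivOrFail(err error, deriv *float64) error {
	if v, ok := err.(Valuer); ok {
		*deriv = v.Value()
		return nil
	}
	return err
}

func main() {
	var d float64
	err := setDerivOrFail(valuedError{msg: "soft failure", val: 3.14}, &d)
	fmt.Println(d, err) // 3.14 <nil>
}
```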
operatorLinAlg.go
@@ -565,15 +565,18 @@ func (op linAlgBinOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err e
case matMulOperator:
if op.transA {
x = transpose2D(x)
defer tensor.ReturnInts(x)
}
if op.transB {
y = transpose2D(y)
defer tensor.ReturnInts(y)
}

retVal = tensor.Shape{x[0], y[1]}
case matVecMulOperator:
if op.transA {
x = transpose2D(x)
defer tensor.ReturnInts(x)
}
if x[0] != y[0] && x[1] != y[0] {
return nil, errors.Errorf("Incompatible shapes: %v and %v", x, y)
@@ -591,6 +594,27 @@ func (op linAlgBinOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err e
case outerProdOperator:
// outerprods only handles vec x vec for now
retVal = tensor.Shape{x.TotalSize(), y.TotalSize()}
case batchedMatMulOperator:
// check that x and y are both 3D
if x.Dims() != 3 {
return nil, errors.Errorf("BatchedMatMul only works with 3D tensors as x")
}
if y.Dims() != 3 {
return nil, errors.Errorf("BatchedMatMul only works with 3D tensors as y")
}
if x[0] != y[0] {
return nil, errors.Errorf("BatchedMatMul has encounted a batch mismatch: %v %v", x, y)
}
batchSize := x[0]
if op.transA {
x = transpose2D(x[1:])
defer tensor.ReturnInts(x)
}
if op.transB {
y = transpose2D(y[1:])
defer tensor.ReturnInts(y)
}
retVal = tensor.Shape{batchSize, x[x.Dims()-2], y[y.Dims()-1]} // x and y still carry the batch dim when they were not transposed above
}
return
}
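To make the new batchedMatMulOperator case concrete, here is a standalone sketch of the same shape rule. It mirrors the logic added above rather than calling into the library, and the inner-dimension check is an extra added purely for illustration.

```go
package main

import "fmt"

// inferBatchedMatMulShape mirrors the rule added above: x is (batch, m, k),
// y is (batch, k, n), and the result is (batch, m, n). transA/transB swap the
// last two axes of x/y before the rule applies.
func inferBatchedMatMulShape(x, y []int, transA, transB bool) ([]int, error) {
	if len(x) != 3 || len(y) != 3 {
		return nil, fmt.Errorf("expected 3D shapes, got %v and %v", x, y)
	}
	if x[0] != y[0] {
		return nil, fmt.Errorf("batch mismatch: %v vs %v", x, y)
	}
	m, k := x[1], x[2]
	if transA {
		m, k = k, m
	}
	k2, n := y[1], y[2]
	if transB {
		k2, n = n, k2
	}
	if k != k2 {
		return nil, fmt.Errorf("inner dimension mismatch: %d vs %d", k, k2)
	}
	return []int{x[0], m, n}, nil
}

func main() {
	fmt.Println(inferBatchedMatMulShape([]int{8, 4, 5}, []int{8, 5, 3}, false, false))
	// [8 4 3] <nil>
}
```

(transpose2D itself is not shown in this commit; judging by its pairing with tensor.ReturnInts, it presumably returns a pooled copy of a 2D shape with its axes swapped, which is why the callers defer the return.)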
@@ -757,6 +781,8 @@ func (op linAlgBinOp) do(inputs []Value, opts ...tensor.FuncOpt) (retVal Value,
retVal, _ = anyToScalar(ret)
case outerProdOperator:
retVal, err = tensor.Outer(a, b, opts...)
case batchedMatMulOperator:
// TODO: batched matrix multiplication is not yet implemented
}
return

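The Do side is explicitly deferred ("Haven't done the Do part"), so purely as a point of reference, the plain-Go sketch below shows what the missing case has to compute: one independent matrix multiply per batch. It avoids the tensor package entirely so as not to guess at its API, and it is not the implementation this commit intends.

```go
package main

import "fmt"

// batchedMatMulRef: for each batch b, out[b] = a[b] × c[b].
// a is (batch, m, k), c is (batch, k, n), both row-major in flat slices;
// the result is (batch, m, n).
func batchedMatMulRef(a, c []float64, batch, m, k, n int) []float64 {
	out := make([]float64, batch*m*n)
	for b := 0; b < batch; b++ {
		ao, co, oo := b*m*k, b*k*n, b*m*n
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				var sum float64
				for p := 0; p < k; p++ {
					sum += a[ao+i*k+p] * c[co+p*n+j]
				}
				out[oo+i*n+j] = sum
			}
		}
	}
	return out
}

func main() {
	// 2 batches of (1×2)·(2×1): [1 2]·[3 4]ᵀ = 11 and [5 6]·[7 8]ᵀ = 83
	a := []float64{1, 2, 5, 6}
	c := []float64{3, 4, 7, 8}
	fmt.Println(batchedMatMulRef(a, c, 2, 1, 2, 1)) // [11 83]
}
```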
op_nn.go
@@ -145,3 +145,38 @@ func (op randomOp) Hashcode() uint32 {
func (op randomOp) String() string {
return fmt.Sprintf("%v(%v, %v) - %v", op.which, op.a, op.b, op.shape)
}

// clampOp is a constant clamping operation
type clampOp struct {
min, max Scalar
}

func (op *clampOp) Arity() int { return 1 }

func (op *clampOp) Type() hm.Type {
return hm.NewFnType(hm.TypeVariable('a'), hm.TypeVariable('a'))
}

func (op *clampOp) InferShape(shps ...DimSizer) (tensor.Shape, error) {
return shps[0].(tensor.Shape), nil
}

func (op *clampOp) Do(vals ...Value) (Value, error) {
return nil, nil
}

func (op *clampOp) ReturnsPtr() bool { return true }

func (op *clampOp) CallsExtern() bool { return false }

func (op *clampOp) OverwritesInput() int { return 0 }

func (op *clampOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "ConstClamp{%f, %f}()", op.min, op.max) }

func (op *clampOp) Hashcode() uint32 {
h := fnv.New32a()
op.WriteHash(h)
return h.Sum32()
}

func (op *clampOp) String() string { return fmt.Sprintf("ConstClamp{%f, %f}()", op.min, op.max) }
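clampOp's Do is likewise still a stub in this commit (return nil, nil). Elementwise, a constant clamp simply pins every value into [min, max]; a minimal scalar sketch, using plain float64 rather than the library's Scalar type:

```go
// clamp pins v into the closed interval [min, max].
// The real op would apply this elementwise over a tensor.
func clamp(v, min, max float64) float64 {
	if v < min {
		return min
	}
	if v > max {
		return max
	}
	return v
}
```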
