Skip to content

Commit

Permalink
Added DiagFlat
Browse files Browse the repository at this point in the history
  • Loading branch information
chewxy committed Oct 11, 2020
1 parent d26fe26 commit d065cc4
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 0 deletions.
68 changes: 68 additions & 0 deletions op_nondiff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package gorgonia

import (
"fmt"
"hash"

"github.com/chewxy/hm"
"gorgonia.org/tensor"
)

type diagFlatOp struct{}

/* Graph Building Related Methods */

// Arity returns the number of inputs the Op expects. -1 indicates that it's n-ary and will be determined at runtime
func (op diagFlatOp) Arity() int { return 1 }

// Informs the type of the Op (not the node). This will be used by the type system to infer the final type of the node
func (op diagFlatOp) Type() hm.Type {
a := hm.TypeVariable('a')
b := hm.TypeVariable('a')
T := makeTensorType(2, b)
return hm.NewFnType(a, T)
}

// returns the output shape as a function of the inputs
func (op diagFlatOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
if err := checkArity(op, len(inputs)); err != nil {
return nil, err
}
in := inputs[0].(tensor.Shape)
return tensor.Shape{in.TotalSize(), in.TotalSize()}, nil
}

/* Machine related */ // executes the op
func (op diagFlatOp) Do(vals ...Value) (Value, error) {
if err := checkArity(op, len(vals)); err != nil {
return nil, err
}

T := vals[0].(tensor.Tensor)
return tensor.New(tensor.AsDenseDiag(T.Data())), nil
}

/* Analysis Related Methods */

// indicates if the Op will return a pointer (allowing possible inplace edits) or by value
// if it's false, the return value of the Op will be a copy of its input
func (op diagFlatOp) ReturnsPtr() bool { return false }

// Does this op potentially call external (cgo or cuda) functions (thereby requiring extra overhead for Go's trampolining thing)
func (op diagFlatOp) CallsExtern() bool { return false }

// overwriteInput() is a method which states which input the output will be overwriting.
// This allows for some efficiency gains as the underlying arrays wouldn't have to be re-allocated.
// The method returns an int instead of a bool because potentially different operations may be allowed
// to overwrite certain inputs. For example, consider an operation to increment a value:
// the IncrementOp would be a unary operator, and assuming we would like to overwrite the input,
// the retVal of overwriteInput() will be 0 (inputs[0]).
// -1 is returned if overwriting of input is disallowed
func (op diagFlatOp) OverwritesInput() int { return -1 }

/* Other methods */
func (op diagFlatOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "DiagFlatOp") }

func (op diagFlatOp) Hashcode() uint32 { return simpleHash(op) }

func (op diagFlatOp) String() string { return "DiagFlat" }
108 changes: 108 additions & 0 deletions op_nondiff_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package gorgonia

import (
"fmt"

"gorgonia.org/tensor"
)

func ExampleDiagFlat() {
g := NewGraph()

// 2 dimensional
aV := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4}))
a := NodeFromAny(g, aV)
b, err := DiagFlat(a)
if err != nil {
fmt.Println(err)
return
}
m := NewTapeMachine(g)
if err := m.RunAll(); err != nil {
fmt.Println(err)
return
}
fmt.Printf("a:\n%v\n", a.Value())
fmt.Printf("b:\n%v\n", b.Value())

// 3 dimensional
aV = tensor.New(tensor.WithShape(2, 3, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}))
a = NodeFromAny(g, aV, WithName("a'"))
b2, err := DiagFlat(a)
if err != nil {
fmt.Println(err)
return
}
m = NewTapeMachine(g)
if err := m.RunAll(); err != nil {
fmt.Println(err)
}

fmt.Printf("a:\n%v", a.Value())
fmt.Printf("b:\n%v\n", b2.Value())

// 1 dimensional
aV = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1, 2}))
a = NodeFromAny(g, aV, WithName("a''"))
b3, err := DiagFlat(a)
if err != nil {
fmt.Println(err)
return
}
m = NewTapeMachine(g)
if err := m.RunAll(); err != nil {
fmt.Println(err)
}

fmt.Printf("a:\n%v\n", a.Value())
fmt.Printf("b:\n%v\n", b3.Value())

// Scalars

a = NodeFromAny(g, 100.0, WithName("aScalar"))
_, err = DiagFlat(a)
fmt.Println(err)

// Output:
// a:
// ⎡1 2⎤
// ⎣3 4⎦
//
// b:
// ⎡1 0 0 0⎤
// ⎢0 2 0 0⎥
// ⎢0 0 3 0⎥
// ⎣0 0 0 4⎦
//
// a:
// ⎡ 1 2⎤
// ⎢ 3 4⎥
// ⎣ 5 6⎦
//
// ⎡ 7 8⎤
// ⎢ 9 10⎥
// ⎣11 12⎦
//
//
// b:
// ⎡ 1 0 0 0 ... 0 0 0 0⎤
// ⎢ 0 2 0 0 ... 0 0 0 0⎥
// ⎢ 0 0 3 0 ... 0 0 0 0⎥
// ⎢ 0 0 0 4 ... 0 0 0 0⎥
// .
// .
// .
// ⎢ 0 0 0 0 ... 9 0 0 0⎥
// ⎢ 0 0 0 0 ... 0 10 0 0⎥
// ⎢ 0 0 0 0 ... 0 0 11 0⎥
// ⎣ 0 0 0 0 ... 0 0 0 12⎦
//
// a:
// [1 2]
// b:
// ⎡1 0⎤
// ⎣0 2⎦
//
// Cannot perform DiagFlat on a scalar equivalent node

}
13 changes: 13 additions & 0 deletions operations_nondiff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package gorgonia

import "github.com/pkg/errors"

// DiagFlat takes the flattened value and creates a diagonal matrix from it.
//
// It is non-differentiable.
func DiagFlat(a *Node) (*Node, error) {
if a.Shape().IsScalarEquiv() {
return nil, errors.Errorf("Cannot perform DiagFlat on a scalar equivalent node")
}
return ApplyOp(diagFlatOp{}, a)
}

0 comments on commit d065cc4

Please sign in to comment.