-
-
Notifications
You must be signed in to change notification settings - Fork 431
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
189 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
package gorgonia | ||
|
||
import ( | ||
"fmt" | ||
"hash" | ||
|
||
"github.com/chewxy/hm" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
type diagFlatOp struct{} | ||
|
||
/* Graph Building Related Methods */ | ||
|
||
// Arity returns the number of inputs the Op expects. -1 indicates that it's n-ary and will be determined at runtime | ||
func (op diagFlatOp) Arity() int { return 1 } | ||
|
||
// Informs the type of the Op (not the node). This will be used by the type system to infer the final type of the node | ||
func (op diagFlatOp) Type() hm.Type { | ||
a := hm.TypeVariable('a') | ||
b := hm.TypeVariable('a') | ||
T := makeTensorType(2, b) | ||
return hm.NewFnType(a, T) | ||
} | ||
|
||
// returns the output shape as a function of the inputs | ||
func (op diagFlatOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) { | ||
if err := checkArity(op, len(inputs)); err != nil { | ||
return nil, err | ||
} | ||
in := inputs[0].(tensor.Shape) | ||
return tensor.Shape{in.TotalSize(), in.TotalSize()}, nil | ||
} | ||
|
||
/* Machine related */ // executes the op | ||
func (op diagFlatOp) Do(vals ...Value) (Value, error) { | ||
if err := checkArity(op, len(vals)); err != nil { | ||
return nil, err | ||
} | ||
|
||
T := vals[0].(tensor.Tensor) | ||
return tensor.New(tensor.AsDenseDiag(T.Data())), nil | ||
} | ||
|
||
/* Analysis Related Methods */ | ||
|
||
// indicates if the Op will return a pointer (allowing possible inplace edits) or by value | ||
// if it's false, the return value of the Op will be a copy of its input | ||
func (op diagFlatOp) ReturnsPtr() bool { return false } | ||
|
||
// Does this op potentially call external (cgo or cuda) functions (thereby requiring extra overhead for Go's trampolining thing) | ||
func (op diagFlatOp) CallsExtern() bool { return false } | ||
|
||
// overwriteInput() is a method which states which input the output will be overwriting. | ||
// This allows for some efficiency gains as the underlying arrays wouldn't have to be re-allocated. | ||
// The method returns an int instead of a bool because potentially different operations may be allowed | ||
// to overwrite certain inputs. For example, consider an operation to increment a value: | ||
// the IncrementOp would be a unary operator, and assuming we would like to overwrite the input, | ||
// the retVal of overwriteInput() will be 0 (inputs[0]). | ||
// -1 is returned if overwriting of input is disallowed | ||
func (op diagFlatOp) OverwritesInput() int { return -1 } | ||
|
||
/* Other methods */ | ||
func (op diagFlatOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "DiagFlatOp") } | ||
|
||
func (op diagFlatOp) Hashcode() uint32 { return simpleHash(op) } | ||
|
||
func (op diagFlatOp) String() string { return "DiagFlat" } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
package gorgonia | ||
|
||
import ( | ||
"fmt" | ||
|
||
"gorgonia.org/tensor" | ||
) | ||
|
||
func ExampleDiagFlat() { | ||
g := NewGraph() | ||
|
||
// 2 dimensional | ||
aV := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})) | ||
a := NodeFromAny(g, aV) | ||
b, err := DiagFlat(a) | ||
if err != nil { | ||
fmt.Println(err) | ||
return | ||
} | ||
m := NewTapeMachine(g) | ||
if err := m.RunAll(); err != nil { | ||
fmt.Println(err) | ||
return | ||
} | ||
fmt.Printf("a:\n%v\n", a.Value()) | ||
fmt.Printf("b:\n%v\n", b.Value()) | ||
|
||
// 3 dimensional | ||
aV = tensor.New(tensor.WithShape(2, 3, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})) | ||
a = NodeFromAny(g, aV, WithName("a'")) | ||
b2, err := DiagFlat(a) | ||
if err != nil { | ||
fmt.Println(err) | ||
return | ||
} | ||
m = NewTapeMachine(g) | ||
if err := m.RunAll(); err != nil { | ||
fmt.Println(err) | ||
} | ||
|
||
fmt.Printf("a:\n%v", a.Value()) | ||
fmt.Printf("b:\n%v\n", b2.Value()) | ||
|
||
// 1 dimensional | ||
aV = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1, 2})) | ||
a = NodeFromAny(g, aV, WithName("a''")) | ||
b3, err := DiagFlat(a) | ||
if err != nil { | ||
fmt.Println(err) | ||
return | ||
} | ||
m = NewTapeMachine(g) | ||
if err := m.RunAll(); err != nil { | ||
fmt.Println(err) | ||
} | ||
|
||
fmt.Printf("a:\n%v\n", a.Value()) | ||
fmt.Printf("b:\n%v\n", b3.Value()) | ||
|
||
// Scalars | ||
|
||
a = NodeFromAny(g, 100.0, WithName("aScalar")) | ||
_, err = DiagFlat(a) | ||
fmt.Println(err) | ||
|
||
// Output: | ||
// a: | ||
// ⎡1 2⎤ | ||
// ⎣3 4⎦ | ||
// | ||
// b: | ||
// ⎡1 0 0 0⎤ | ||
// ⎢0 2 0 0⎥ | ||
// ⎢0 0 3 0⎥ | ||
// ⎣0 0 0 4⎦ | ||
// | ||
// a: | ||
// ⎡ 1 2⎤ | ||
// ⎢ 3 4⎥ | ||
// ⎣ 5 6⎦ | ||
// | ||
// ⎡ 7 8⎤ | ||
// ⎢ 9 10⎥ | ||
// ⎣11 12⎦ | ||
// | ||
// | ||
// b: | ||
// ⎡ 1 0 0 0 ... 0 0 0 0⎤ | ||
// ⎢ 0 2 0 0 ... 0 0 0 0⎥ | ||
// ⎢ 0 0 3 0 ... 0 0 0 0⎥ | ||
// ⎢ 0 0 0 4 ... 0 0 0 0⎥ | ||
// . | ||
// . | ||
// . | ||
// ⎢ 0 0 0 0 ... 9 0 0 0⎥ | ||
// ⎢ 0 0 0 0 ... 0 10 0 0⎥ | ||
// ⎢ 0 0 0 0 ... 0 0 11 0⎥ | ||
// ⎣ 0 0 0 0 ... 0 0 0 12⎦ | ||
// | ||
// a: | ||
// [1 2] | ||
// b: | ||
// ⎡1 0⎤ | ||
// ⎣0 2⎦ | ||
// | ||
// Cannot perform DiagFlat on a scalar equivalent node | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package gorgonia | ||
|
||
import "github.com/pkg/errors" | ||
|
||
// DiagFlat takes the flattened value and creates a diagonal matrix from it. | ||
// | ||
// It is non-differentiable. | ||
func DiagFlat(a *Node) (*Node, error) { | ||
if a.Shape().IsScalarEquiv() { | ||
return nil, errors.Errorf("Cannot perform DiagFlat on a scalar equivalent node") | ||
} | ||
return ApplyOp(diagFlatOp{}, a) | ||
} |