Skip to content
Permalink
Browse files

A bunch of supporting functions for Golgi (#306)

* Added KeepDims as an additional function to "decorate" another function
Cleaned up Ones and ones

* Added broadcasted operations to api_gen
Wrote program to generate those broadcasted ops
Renamed BroadcastMul to BroadcastHadamardProd. BroadcastMul is coming soon

* added an example to show how one may use the broadcasting operations to create dense triangular matrices

* Added better support for BatchedMatMul. Now more than 3D tensors are supported!

* Added unaryOp interface to genapi. Generating the interfaces makes the interfaces more consistent. Previously inversef32 gave the wrong ʘUnaryOperatorType

* Allow axis to be defined in SoftMax. Furthermore the default axis is now the last axis. This allows for SoftMax to be done across ndarrays
Added more examples

* Ported Unconcat to Gorgonia. Added tests

* Added some things for future

* Added more support functions for Golgi

* added some statistics generation for genapi

* Added monad-y error handling to Gorgonia

* Let's do away with the DoXXX functions

* Changed the definition of LiftResult a bit.

* added some helper functions

* Updated Unconcat tor use Nodes instead of []*Node
This allows for easier lifting of the return value, however its
utility is not known at the moment.

* Added HeEtAl InitWFn

* Ugh. Copy and pasting sux when you can only type with one hand

* Squashed commit of the following:

commit 592126c
Author: Ben Leitner <7515022+bdleitner@users.noreply.github.com>
Date:   Sun Nov 17 15:09:08 2019 -0800

    Refactor the max/sum ops to share common code. Have the type/inferShape/Do methods behave in a consistent manner: (#346)

    * Dimensions specified in the "along" parameter are reduced to size 1, but not removed. (Note: this caused TestRepeatOpDoDiff, but this version fixes it.  Perhaps we should make preserving the size-1 dimensions an option of the reduction op?)
    * If all dimensions are included, the result will be a scalar.
    * If all dimensions but 1 are included, the result is a vector, regardless of which dimension is left intact.

    Tests verify that the resulting nodes have the expected shape.

    Note: While here, fix a warning on Max's SymDiff where retVal[0] is set when retVal has not been initialized.  I wonder if this is related to #323 where SymDiff for StableSoftMax (which uses Max) was failing with a panic (probably not, as the error message there seems unrelated, but probably a good fix anyway).

    Closes #326

commit 6fd05db
Author: Olivier Wulveryck <owulveryck@users.noreply.github.com>
Date:   Tue Nov 12 09:15:56 2019 +0100

    Examples/readme (#351)

    * chore(readme): add references to the gorgonia website

commit e6bc7dd
Merge: 9ecd7d0 d1d231f
Author: gareth <31232838+jokebroker@users.noreply.github.com>
Date:   Sat Nov 9 06:47:29 2019 +1100

    Merge pull request #350 from mattn/fix-gomod

    Fix go.mod

commit d1d231f
Author: Yasuhiro Matsumoto <mattn.jp@gmail.com>
Date:   Fri Nov 8 21:35:58 2019 +0900

    Fix go.mod

commit 9ecd7d0
Author: Olivier Wulveryck <owulveryck@users.noreply.github.com>
Date:   Thu Nov 7 09:59:37 2019 +0100

    Gap operator (#302)

    * feat(wip): scratch space for a Global Average Pooling operator

    * chore: skeleton of the operator

    * feat: Global Average Pool

commit 6cc7466
Author: mattn <mattn.jp@gmail.com>
Date:   Sat Nov 2 03:16:02 2019 +0900

    Improvement of example/iris (#348)

commit 6f8c10a
Author: Olivier Wulveryck <owulveryck@users.noreply.github.com>
Date:   Thu Oct 31 22:10:37 2019 +0100

    Iris example (#347)

    * fix: do not overwrite the channel if it already exists

    * feat: multivariate linear regression

commit b7b4b2c
Author: Olivier Wulveryck <owulveryck@users.noreply.github.com>
Date:   Wed Oct 16 15:34:26 2019 +0200

    Create FUNDING.yml (#342)

* Fixed Softmax
  • Loading branch information
chewxy committed Dec 8, 2019
1 parent 9ee42cb commit a8bd935b907b36695b520a0739b6c3162d7858a5
@@ -124,3 +124,102 @@ func Ne(a, b *Node, retSame bool) (*Node, error) {
op.retSame = retSame
return binOpNode(op, a, b)
}

//Add performs a add. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
func BroadcastAdd(a, b *Node, leftPattern, rightPattern []byte) (*Node, error) {
a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern))
if err != nil {
return nil, err
}
return Add(a2, b2)
}

//Sub performs a sub. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
func BroadcastSub(a, b *Node, leftPattern, rightPattern []byte) (*Node, error) {
a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern))
if err != nil {
return nil, err
}
return Sub(a2, b2)
}

//HadamardProd performs a hadamardprod. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
func BroadcastHadamardProd(a, b *Node, leftPattern, rightPattern []byte) (*Node, error) {
a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern))
if err != nil {
return nil, err
}
return HadamardProd(a2, b2)
}

//HadamardDiv performs a hadamarddiv. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
func BroadcastHadamardDiv(a, b *Node, leftPattern, rightPattern []byte) (*Node, error) {
a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern))
if err != nil {
return nil, err
}
return HadamardDiv(a2, b2)
}

//Pow performs a pow. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
func BroadcastPow(a, b *Node, leftPattern, rightPattern []byte) (*Node, error) {
a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern))
if err != nil {
return nil, err
}
return Pow(a2, b2)
}

//Lt performs a lt. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
func BroadcastLt(a, b *Node, retSame bool, leftPattern, rightPattern []byte) (*Node, error) {
a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern))
if err != nil {
return nil, err
}
return Lt(a2, b2, retSame)
}

//Gt performs a gt. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
func BroadcastGt(a, b *Node, retSame bool, leftPattern, rightPattern []byte) (*Node, error) {
a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern))
if err != nil {
return nil, err
}
return Gt(a2, b2, retSame)
}

//Lte performs a lte. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
func BroadcastLte(a, b *Node, retSame bool, leftPattern, rightPattern []byte) (*Node, error) {
a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern))
if err != nil {
return nil, err
}
return Lte(a2, b2, retSame)
}

//Gte performs a gte. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
func BroadcastGte(a, b *Node, retSame bool, leftPattern, rightPattern []byte) (*Node, error) {
a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern))
if err != nil {
return nil, err
}
return Gte(a2, b2, retSame)
}

//Eq performs a eq. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
func BroadcastEq(a, b *Node, retSame bool, leftPattern, rightPattern []byte) (*Node, error) {
a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern))
if err != nil {
return nil, err
}
return Eq(a2, b2, retSame)
}

//Ne performs a ne. The operation is precomposed with a broadcast such that the shapes matches before operations commence.
func BroadcastNe(a, b *Node, retSame bool, leftPattern, rightPattern []byte) (*Node, error) {
a2, b2, err := Broadcast(a, b, NewBroadcastPattern(leftPattern, rightPattern))
if err != nil {
return nil, err
}
return Ne(a2, b2, retSame)
}
@@ -0,0 +1,175 @@
package main

import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/token"
"log"
"path"
"path/filepath"
"strings"
)

type nametypePair struct {
name string
*ast.FuncType
}

func functions(decls []ast.Decl) (signatures []nametypePair) {
for _, decl := range decls {
switch d := decl.(type) {
case *ast.FuncDecl:
signatures = append(signatures, nametypePair{d.Name.Name, d.Type})
default:
}
}
return
}

type strRepr struct {
name string
inTypes []string
retTypes []string

printName bool
}

func (s strRepr) String() string {
buf := new(bytes.Buffer)
buf.Write([]byte("func "))
if s.printName {
buf.Write([]byte(s.name))
}

buf.Write([]byte("("))
for i, v := range s.inTypes {
buf.Write([]byte(v))
if i < len(s.inTypes)-1 {
buf.Write([]byte(", "))
}
}
buf.Write([]byte(") ("))
for i, v := range s.retTypes {
buf.Write([]byte(v))
if i < len(s.retTypes)-1 {
buf.Write([]byte(", "))
}
}
buf.Write([]byte(")"))
return buf.String()
}

func processSig(pair nametypePair) strRepr {
a := pair.FuncType
var inTypes, retTypes []string
if a.Params == nil {
goto next
}
for _, field := range a.Params.List {
names := len(field.Names)
typ := parseTypeExpr(field.Type)
if names == 0 {
inTypes = append(inTypes, typ)
continue
}
for i := 0; i < names; i++ {
inTypes = append(inTypes, typ)
}
}
next:
if a.Results == nil {
return strRepr{pair.name, inTypes, retTypes, true}
}
for _, field := range a.Results.List {
names := len(field.Names)
typ := parseTypeExpr(field.Type)
if names == 0 {
retTypes = append(retTypes, typ)
continue
}
for i := 0; i < names; i++ {
retTypes = append(retTypes, typ)
}
}
return strRepr{pair.name, inTypes, retTypes, true}
}

func parseTypeExpr(expr ast.Expr) string {
switch e := expr.(type) {
case *ast.Ident:
return e.Name
case *ast.StarExpr:
x := parseTypeExpr(e.X)
return "*" + x
case *ast.SelectorExpr:
return parseTypeExpr(e.X) + "." + e.Sel.Name
case *ast.Ellipsis:
return "..." + parseTypeExpr(e.Elt)
case *ast.ArrayType:
return "[]" + parseTypeExpr(e.Elt)
default:
return fmt.Sprintf("%T", expr)
}
}

func filterSigs(xs []strRepr, fn func(strRepr) bool) (retVal []strRepr) {
for _, x := range xs {
if fn(x) {
retVal = append(retVal, x)
}
}
return
}

func functionSignatures() {
files := path.Join(gorgonialoc, "*.go")
matches, err := filepath.Glob(files)

if err != nil {
log.Fatal(err)
}
fset := token.NewFileSet()

var allFns []strRepr
for _, f := range matches {
file, err := parser.ParseFile(fset, f, nil, parser.AllErrors)
if err != nil {
log.Fatal(err)

}

fns := functions(file.Decls)
for _, fn := range fns {
sig := processSig(fn)
sig.printName = false
if strings.Title(sig.name) == sig.name {
allFns = append(allFns, sig)
}
}
}
f := func(a strRepr) bool {
want := []string{"Nodes", "error"}
if len(a.retTypes) != len(want) {
return false
}
for i, v := range a.retTypes {
if v != want[i] {
return false
}
}
return true
}

signatures := make(map[string]int)
interesting := filterSigs(allFns, f)
for _, v := range interesting {
v.printName = true
signatures[fmt.Sprintf("%v", v)]++
}

for k, v := range signatures {
fmt.Printf("%v\t%d\n", k, v)
}
}
@@ -0,0 +1,60 @@
package main

import (
"go/parser"
"go/token"
"io"
"log"
"path"
"strings"
"text/template"
)

type UnaryOpInterfaceData struct {
OpTypes []string
Dtype string // f32, f64
}

const unaryOpInterfaceRaw = `func (f *s{{.Dtype}}UnaryOperator) unaryOpType() ʘUnaryOperatorType {
{{$dt := .Dtype -}}
switch f {
{{range $i, $op := .OpTypes -}}
case &{{$op}}{{$dt}}:
return {{$op}}OpType
{{end -}}
}
return maxʘUnaryOperator
}
func (f *s{{.Dtype}}UnaryOperator) String() string { return f.unaryOpType().String() }
`

var unaryOpInterface *template.Template

func init() {
unaryOpInterface = template.Must(template.New("UnOpInterface").Funcs(funcmap).Parse(unaryOpInterfaceRaw))
}

func generateUnaryInterface(outFile io.Writer) {
// parse operator_unary_const.go
filename := path.Join(gorgonialoc, unaryOps)
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, filename, nil, parser.AllErrors)
if err != nil {
log.Fatal(err)
}

unaryNames := constTypes(file.Decls, "ʘUnaryOperatorType", "maxʘUnaryOperator")
var opNames []string
for _, v := range unaryNames {
op := strings.TrimSuffix(v, "OpType")
opNames = append(opNames, op)
}

dtypes := []string{"f32", "f64"}
for _, dt := range dtypes {
data := UnaryOpInterfaceData{opNames, dt}
unaryOpInterface.Execute(outFile, data)
}
}

0 comments on commit a8bd935

Please sign in to comment.
You can’t perform that action at this time.