Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A bunch of supporting functions for Golgi #306

Merged
merged 28 commits into from Dec 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
1ffd427
Added KeepDims as an additional function to "decorate" another function
chewxy Jul 19, 2019
67db2fb
Added broadcasted operations to api_gen
chewxy Jul 20, 2019
561e93d
added an example to show how one may use the broadcasting operations …
chewxy Jul 20, 2019
103dde9
Added better support for BatchedMatMul. Now more than 3D tensors are …
chewxy Jul 21, 2019
cb4b2f2
Added unaryOp interface to genapi. Generating the interfaces makes th…
chewxy Jul 22, 2019
6b61480
Allow axis to be defined in SoftMax. Furthermore the default axis is …
chewxy Jul 22, 2019
ae26713
Ported Unconcat to Gorgonia. Added tests
chewxy Jul 23, 2019
b88acb6
Added some things for future
chewxy Jul 23, 2019
bf86bb3
Added more support functions for Golgi
chewxy Jul 23, 2019
628211e
added some statistics generation for genapi
chewxy Jul 24, 2019
45485a7
Added monad-y error handling to Gorgonia
chewxy Jul 24, 2019
71670b7
Let's do away with the DoXXX functions
chewxy Jul 25, 2019
ec7c2f6
Changed the definition of LiftResult a bit.
chewxy Jul 27, 2019
962a029
added some helper functions
chewxy Jul 31, 2019
1d16dff
Updated Unconcat tor use Nodes instead of []*Node
chewxy Aug 26, 2019
62b6a2a
Merge remote-tracking branch 'origin/golgisupports' into golgisupports
chewxy Aug 26, 2019
00ce7c0
Merge branch 'master' into golgisupports
chewxy Aug 30, 2019
828e521
Merge branch 'master' into golgisupports
chewxy Sep 7, 2019
5681366
Merge branch 'master' into golgisupports
chewxy Oct 8, 2019
25c62a2
Merge branch 'master' into golgisupports
chewxy Oct 16, 2019
7362f50
Added HeEtAl InitWFn
chewxy Oct 20, 2019
926923d
Merge remote-tracking branch 'origin/golgisupports' into golgisupports
chewxy Oct 20, 2019
f1e5c6d
Ugh. Copy and pasting sux when you can only type with one hand
chewxy Oct 20, 2019
04c440d
Merge branch 'master' into golgisupports
chewxy Nov 7, 2019
84f497a
Squashed commit of the following:
chewxy Nov 18, 2019
f270287
Merge remote-tracking branch 'origin/golgisupports' into golgisupports
chewxy Nov 18, 2019
64416fe
Merge branch 'master' into golgisupports
chewxy Dec 7, 2019
e0e3652
Fixed Softmax
chewxy Dec 7, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
99 changes: 99 additions & 0 deletions api_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

175 changes: 175 additions & 0 deletions cmd/genapi/generatemonads.go
@@ -0,0 +1,175 @@
package main

import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/token"
"log"
"path"
"path/filepath"
"strings"
)

type nametypePair struct {
name string
*ast.FuncType
}

func functions(decls []ast.Decl) (signatures []nametypePair) {
for _, decl := range decls {
switch d := decl.(type) {
case *ast.FuncDecl:
signatures = append(signatures, nametypePair{d.Name.Name, d.Type})
default:
}
}
return
}

type strRepr struct {
name string
inTypes []string
retTypes []string

printName bool
}

func (s strRepr) String() string {
buf := new(bytes.Buffer)
buf.Write([]byte("func "))
if s.printName {
buf.Write([]byte(s.name))
}

buf.Write([]byte("("))
for i, v := range s.inTypes {
buf.Write([]byte(v))
if i < len(s.inTypes)-1 {
buf.Write([]byte(", "))
}
}
buf.Write([]byte(") ("))
for i, v := range s.retTypes {
buf.Write([]byte(v))
if i < len(s.retTypes)-1 {
buf.Write([]byte(", "))
}
}
buf.Write([]byte(")"))
return buf.String()
}

func processSig(pair nametypePair) strRepr {
a := pair.FuncType
var inTypes, retTypes []string
if a.Params == nil {
goto next
}
for _, field := range a.Params.List {
names := len(field.Names)
typ := parseTypeExpr(field.Type)
if names == 0 {
inTypes = append(inTypes, typ)
continue
}
for i := 0; i < names; i++ {
inTypes = append(inTypes, typ)
}
}
next:
if a.Results == nil {
return strRepr{pair.name, inTypes, retTypes, true}
}
for _, field := range a.Results.List {
names := len(field.Names)
typ := parseTypeExpr(field.Type)
if names == 0 {
retTypes = append(retTypes, typ)
continue
}
for i := 0; i < names; i++ {
retTypes = append(retTypes, typ)
}
}
return strRepr{pair.name, inTypes, retTypes, true}
}

func parseTypeExpr(expr ast.Expr) string {
switch e := expr.(type) {
case *ast.Ident:
return e.Name
case *ast.StarExpr:
x := parseTypeExpr(e.X)
return "*" + x
case *ast.SelectorExpr:
return parseTypeExpr(e.X) + "." + e.Sel.Name
case *ast.Ellipsis:
return "..." + parseTypeExpr(e.Elt)
case *ast.ArrayType:
return "[]" + parseTypeExpr(e.Elt)
default:
return fmt.Sprintf("%T", expr)
}
}

func filterSigs(xs []strRepr, fn func(strRepr) bool) (retVal []strRepr) {
for _, x := range xs {
if fn(x) {
retVal = append(retVal, x)
}
}
return
}

func functionSignatures() {
files := path.Join(gorgonialoc, "*.go")
matches, err := filepath.Glob(files)

if err != nil {
log.Fatal(err)
}
fset := token.NewFileSet()

var allFns []strRepr
for _, f := range matches {
file, err := parser.ParseFile(fset, f, nil, parser.AllErrors)
if err != nil {
log.Fatal(err)

}

fns := functions(file.Decls)
for _, fn := range fns {
sig := processSig(fn)
sig.printName = false
if strings.Title(sig.name) == sig.name {
allFns = append(allFns, sig)
}
}
}
f := func(a strRepr) bool {
want := []string{"Nodes", "error"}
if len(a.retTypes) != len(want) {
return false
}
for i, v := range a.retTypes {
if v != want[i] {
return false
}
}
return true
}

signatures := make(map[string]int)
interesting := filterSigs(allFns, f)
for _, v := range interesting {
v.printName = true
signatures[fmt.Sprintf("%v", v)]++
}

for k, v := range signatures {
fmt.Printf("%v\t%d\n", k, v)
}
}
60 changes: 60 additions & 0 deletions cmd/genapi/geninterface.go
@@ -0,0 +1,60 @@
package main

import (
"go/parser"
"go/token"
"io"
"log"
"path"
"strings"
"text/template"
)

type UnaryOpInterfaceData struct {
OpTypes []string
Dtype string // f32, f64
}

const unaryOpInterfaceRaw = `func (f *s{{.Dtype}}UnaryOperator) unaryOpType() ʘUnaryOperatorType {
{{$dt := .Dtype -}}
switch f {
{{range $i, $op := .OpTypes -}}
case &{{$op}}{{$dt}}:
return {{$op}}OpType
{{end -}}
}
return maxʘUnaryOperator
}

func (f *s{{.Dtype}}UnaryOperator) String() string { return f.unaryOpType().String() }

`

var unaryOpInterface *template.Template

func init() {
unaryOpInterface = template.Must(template.New("UnOpInterface").Funcs(funcmap).Parse(unaryOpInterfaceRaw))
}

func generateUnaryInterface(outFile io.Writer) {
// parse operator_unary_const.go
filename := path.Join(gorgonialoc, unaryOps)
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, filename, nil, parser.AllErrors)
if err != nil {
log.Fatal(err)
}

unaryNames := constTypes(file.Decls, "ʘUnaryOperatorType", "maxʘUnaryOperator")
var opNames []string
for _, v := range unaryNames {
op := strings.TrimSuffix(v, "OpType")
opNames = append(opNames, op)
}

dtypes := []string{"f32", "f64"}
for _, dt := range dtypes {
data := UnaryOpInterfaceData{opNames, dt}
unaryOpInterface.Execute(outFile, data)
}
}