/
main.go
107 lines (91 loc) · 1.87 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package main
import (
"fmt"
"log"
"os"
"os/exec"
"os/user"
"path"
"strings"
"text/template"
)
const (
	// genmsg is the header comment written near the top of every generated
	// file, right after the package clause.
	genmsg = "// Code generated by gencudaengine, which is a API generation tool for Gorgonia. DO NOT EDIT."

	// File names of the generated sources, relative to cudaengloc.
	arithOut = "arith.go" // written by generateAriths
	unaryOut = "unary.go" // not referenced by the generators visible in this file
	cmpOut   = "cmp.go"   // written by generateCmps
)
// gopath is the Go workspace root and cudaengloc is the directory that
// receives the generated files; both are filled in by init.
var gopath, cudaengloc string
// funcmap holds the helper functions made available to templates; "lower"
// maps to strings.ToLower.
// NOTE(review): presumably attached to the templates defined elsewhere in
// this package - confirm at the template definitions.
var funcmap = template.FuncMap{"lower": strings.ToLower}
// binOp is the data handed to binopTmpl for a single binary operation.
// NOTE(review): the exact meaning of each field is decided by the template,
// which is defined outside this file - confirm there.
type binOp struct {
	Method, ScalarMethod string
}

// ariths lists the arithmetic operations to generate. For every entry the
// scalar method name equals the method name.
var ariths = []binOp{
	{Method: "Add", ScalarMethod: "Add"},
	{Method: "Sub", ScalarMethod: "Sub"},
	{Method: "Mul", ScalarMethod: "Mul"},
	{Method: "Div", ScalarMethod: "Div"},
	{Method: "Pow", ScalarMethod: "Pow"},
	{Method: "Mod", ScalarMethod: "Mod"},
}

// cmps lists the comparison operations to generate. ElEq and ElNe are the
// only entries whose scalar method name (Eq, Ne) differs from the method name.
var cmps = []binOp{
	{Method: "Lt", ScalarMethod: "Lt"},
	{Method: "Lte", ScalarMethod: "Lte"},
	{Method: "Gt", ScalarMethod: "Gt"},
	{Method: "Gte", ScalarMethod: "Gte"},
	{Method: "ElEq", ScalarMethod: "Eq"},
	{Method: "ElNe", ScalarMethod: "Ne"},
}
// init resolves the Go workspace root and the CUDA engine output directory.
//
// $GOPATH is used when it is set. Otherwise the workspace is taken to be the
// "go" directory under the current user's home, which must already exist and
// be a directory; any failure while resolving it terminates the program.
func init() {
	if env := os.Getenv("GOPATH"); env != "" {
		gopath = env
	} else {
		current, err := user.Current()
		if err != nil {
			log.Fatal(err)
		}

		fallback := path.Join(current.HomeDir, "go")
		info, err := os.Stat(fallback)
		if err != nil {
			log.Fatal(err)
		}
		if !info.IsDir() {
			log.Fatal("You need to define a $GOPATH")
		}
		gopath = fallback
	}

	cudaengloc = path.Join(gopath, "src/gorgonia.org/gorgonia/cuda")
}
func generateAriths() {
p := path.Join(cudaengloc, arithOut)
f, _ := os.OpenFile(p, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
fmt.Fprintf(f, "package cuda\n\n%v\n\n", genmsg)
for _, op := range ariths {
binopTmpl.Execute(f, op)
}
f.Close()
cmd := exec.Command("goimports", "-w", p)
if err := cmd.Run(); err != nil {
log.Fatalf("Go imports failed with %v for %q", err, p)
}
}
func generateCmps() {
p := path.Join(cudaengloc, cmpOut)
f, _ := os.OpenFile(p, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
fmt.Fprintf(f, "package cuda\n\n%v\n\n", genmsg)
for _, op := range cmps {
binopTmpl.Execute(f, op)
}
f.Close()
cmd := exec.Command("goimports", "-w", p)
if err := cmd.Run(); err != nil {
log.Fatalf("Go imports failed with %v for %q", err, p)
}
}
// main runs every generator in order: arithmetic operations first, then
// comparisons.
func main() {
	steps := []func(){
		generateAriths,
		generateCmps,
	}
	for _, run := range steps {
		run()
	}
}