Skip to content
Permalink
Browse files

Upgrade to a better CUDA paradigm (#114)

Both tapeMachine and lispMachine now can use CUDA ops!

Further work will have to be done, of course
  • Loading branch information...
chewxy committed May 12, 2017
1 parent ee08617 commit 215b466dc74e0aeee590583a9ae5662f222ceb63
Showing with 6,778 additions and 2,353 deletions.
  1. +3 −16 README.md
  2. +221 −17 analysis.go
  3. +95 −0 analysis_test.go
  4. +17 −0 batch.go
  5. +3 −0 batch_cuda.go
  6. +28 −0 bench_concurrent_training_test.go
  7. +436 −0 bfc.go
  8. +149 −0 bfc_test.go
  9. +81 −0 bitmap.go
  10. +125 −0 bitmap_test.go
  11. +1 −1 blas.go
  12. +26 −63 blase/blas_test.go
  13. +51 −19 blase/work.go
  14. +38 −2 cmd/cudagen/main.go
  15. +3 −0 collections.go
  16. +1 −1 collections_test.go
  17. +367 −73 compile.go
  18. +63 −1 compile_test.go
  19. +146 −0 complex_test.go
  20. +3 −3 const.go
  21. +0 −21 cuda modules/src/abs32.cu
  22. +0 −21 cuda modules/src/abs64.cu
  23. +0 −18 cuda modules/src/add32.cu
  24. +0 −18 cuda modules/src/add64.cu
  25. +0 −20 cuda modules/src/cos32.cu
  26. +0 −20 cuda modules/src/cos64.cu
  27. +0 −18 cuda modules/src/cube32.cu
  28. +0 −18 cuda modules/src/cube64.cu
  29. +0 −18 cuda modules/src/div32.cu
  30. +0 −18 cuda modules/src/div64.cu
  31. +125 −0 cuda modules/src/elembinop.cu
  32. +192 −0 cuda modules/src/elemunaryop.cu
  33. +0 −20 cuda modules/src/exp32.cu
  34. +0 −20 cuda modules/src/exp64.cu
  35. +0 −20 cuda modules/src/expm132.cu
  36. +0 −20 cuda modules/src/expm164.cu
  37. +0 −20 cuda modules/src/inv32.cu
  38. +0 −20 cuda modules/src/inv64.cu
  39. +0 −20 cuda modules/src/ln32.cu
  40. +0 −20 cuda modules/src/ln64.cu
  41. +0 −20 cuda modules/src/log1p32.cu
  42. +0 −20 cuda modules/src/log1p64.cu
  43. +0 −20 cuda modules/src/log232.cu
  44. +0 −20 cuda modules/src/log264.cu
  45. +0 −18 cuda modules/src/mul32.cu
  46. +0 −18 cuda modules/src/mul64.cu
  47. +0 −18 cuda modules/src/neg32.cu
  48. +0 −18 cuda modules/src/neg64.cu
  49. +0 −20 cuda modules/src/sin32.cu
  50. +0 −20 cuda modules/src/sin64.cu
  51. +0 −20 cuda modules/src/softplus32.cu
  52. +0 −20 cuda modules/src/softplus64.cu
  53. +0 −20 cuda modules/src/sqrt32.cu
  54. +0 −20 cuda modules/src/sqrt64.cu
  55. +0 −18 cuda modules/src/square32.cu
  56. +0 −18 cuda modules/src/square64.cu
  57. +0 −18 cuda modules/src/sub32.cu
  58. +0 −18 cuda modules/src/sub64.cu
  59. +0 −20 cuda modules/src/tanh32.cu
  60. +0 −20 cuda modules/src/tanh64.cu
  61. +372 −23 cuda.go
  62. +81 −14 cuda_test.go
  63. +36 −16 debug.go
  64. +7 −1 device.go
  65. +37 −0 device_cuda.go
  66. +0 −13 differentiation.go
  67. +84 −8 dual.go
  68. +98 −0 equalities.go
  69. +30 −29 errors.go
  70. +3 −7 example_basic_test.go
  71. +218 −0 example_concurrent_training_test.go
  72. +29 −13 example_linearregression_test.go
  73. +1 −7 example_symdiff_test.go
  74. +11 −4 examples/charRNN/main.go
  75. +34 −17 examples/charRNN/model.go
  76. +1 −5 examples/cuda/main.go
  77. +42 −42 examples/logisticregression/main.go
  78. +3 −7 examples/stacked autoencoder/stackedDA.go
  79. +172 −0 execution.go
  80. +1 −1 gorgonia.go
  81. +41 −4 graph.go
  82. +11 −0 mathutils.go
  83. +51 −0 mathutils.s
  84. +70 −0 mathutils_go.go
  85. +4 −5 nn_test.go
  86. +10 −3 node.go
  87. +92 −1 noextern.go
  88. +4 −4 noextern_test.go
  89. +13 −10 op.go
  90. +29 −0 op_infidel.go
  91. +45 −39 op_math.go
  92. +188 −111 op_math_cuda.go
  93. +63 −32 op_math_cuda_test.go
  94. +22 −0 op_math_noextern.go
  95. +74 −16 op_math_test.go
  96. +33 −10 op_reduction.go
  97. +24 −23 op_reduction_test.go
  98. +16 −6 op_tensor.go
  99. +43 −10 op_tensor_test.go
  100. +2 −1 operations.go
  101. +26 −36 operations_test.go
  102. +4 −4 operatorLinAlg.go
  103. +1 −1 operatorLinAlg_const.go
  104. +294 −60 operatorPointwise_binary.go
  105. +1 −1 operatorPointwise_binary_const.go
  106. +24 −26 operatorPointwise_binary_test.go
  107. +322 −19 operatorPointwise_unary_test.go
  108. +96 −145 regalloc.go
  109. +16 −88 regalloc_test.go
  110. +24 −14 release.go
  111. +13 −3 templates.go
  112. +0 −12 tensor/api_arith.go
  113. +76 −1 tensor/consopt.go
  114. +13 −0 tensor/dense.go
  115. +2 −2 tensor/dense_generated.go
  116. +21 −2 tensor/dense_test.go
  117. +2 −2 tensor/genlib/dense_gen.go
  118. +2 −0 tensor/genlib/main.go
  119. +9 −1 tensor/perf.go
  120. +7 −0 tensor/tensor.go
  121. +24 −16 testsetup_test.go
  122. +9 −2 utils.go
  123. +111 −7 values.go
  124. +0 −74 values_cuda.go
  125. +28 −1 values_utils.go
  126. +32 −2 vm.go
  127. +303 −107 vm_genera.go
  128. +261 −0 vm_genera_cuda.go
  129. +99 −0 vm_genera_cuda_test.go
  130. +16 −0 vm_genera_nocuda.go
  131. +16 −12 vm_genera_test.go
  132. +252 −241 vm_tape.go
  133. +237 −61 vm_tape_cuda.go
  134. +103 −0 vm_tape_nocuda.go
  135. +31 −1 walker.go
  136. +34 −0 weights.go
@@ -85,14 +85,8 @@ func main() {
log.Fatal(err)
}
// compile into a program
prog, locMap, err := Compile(g)
if err != nil {
log.Fatal(err)
}
// create a VM to run the program on
machine := NewTapeMachine(prog, locMap)
machine := NewTapeMachine(g)
// set initial values then run
Let(x, 2.0)
@@ -220,14 +214,8 @@ func main() {
log.Fatal(err)
}
// compile into a program
prog, locMap, err := Compile(g)
if err != nil {
log.Fatal(err)
}
// create a VM to run the program on
machine := NewTapeMachine(prog, locMap)
machine := NewTapeMachine(g)
// set initial values then run
Let(x, 2.0)
@@ -330,8 +318,7 @@ func main() {
xpy := T.Must(T.Add(x, y))
xpy2 := T.Must(T.Tanh(xpy))
prog, locMap, _ := T.Compile(g)
m := T.NewTapeMachine(prog, locMap, T.UseCudaFor("tanh"))
m := T.NewTapeMachine(g, T.UseCudaFor("tanh"))
T.Let(x, tensor.New(tensor.WithShape(100, 100), tensor.WithBacking(tensor.Random(tensor.Float32, 100*100))))
T.Let(y, tensor.New(tensor.WithShape(100, 100), tensor.WithBacking(tensor.Random(tensor.Float32, 100*100))))
@@ -12,11 +12,17 @@ type dataflow struct {

replacements map[*Node]*Node
intervals map[*Node]*interval

// tracks the special nodes' children and parents
devTransChildren map[*Node]Nodes
devTransRepl map[*Node]*Node
}

// newdataflow returns a dataflow with its bookkeeping maps preallocated.
// Fields not listed here start at their zero values, exactly as before.
func newdataflow() *dataflow {
	return &dataflow{
		uniques:          make(map[uint32]*Node),
		devTransChildren: make(map[*Node]Nodes),
		devTransRepl:     make(map[*Node]*Node),
	}
}

@@ -40,6 +46,40 @@ func (df *dataflow) vn(n *Node) (retVal *Node, unique bool) {
return n, true
}

// analyzeDevice records which node is supposed to be executed on which device.
//
// Currently it will only use Device 0. In the future, we can be smart about which device to use.
func (df *dataflow) analyzeDevice(n *Node) {
	switch n.op.(type) {
	case CUDADoer, CLDoer:
		// The two accelerator cases had identical bodies; a multi-type case
		// is semantically equivalent (a type switch runs the first matching
		// case, and both branches did the same thing). Only retarget nodes
		// that have not already been pinned to a specific device.
		if n.dataOn == CPU {
			n.dataOn = Device(0)
		}
	default:
		// Ops with no accelerator implementation always run on the CPU.
		n.dataOn = CPU
	}
}

// replaceWithSelf fills the replacement map with itself. This is the method used in the lispMachine only, as it skips value numbering.
func (df *dataflow) replaceWithSelf(sorted Nodes) {
	repl := make(map[*Node]*Node, len(sorted))
	for _, node := range sorted {
		repl[node] = node
		df.analyzeDevice(node) // Device Targeting
	}
	df.replacements = repl
}

// fixIntervalDevices is used only by the lispMachine. It fixes the intervals to have the correct devices.
func (df *dataflow) fixIntervalDevices(sorted Nodes) {
	for i := range sorted {
		node := sorted[i]
		df.intervals[node].result.device = node.dataOn
	}
}

func analyze(g *ExprGraph, sorted Nodes) *dataflow {
compileLogf("Performing dataflow analysis")
enterLoggingContext()
@@ -51,21 +91,16 @@ func analyze(g *ExprGraph, sorted Nodes) *dataflow {
df.uniques[n.Hashcode()] = n
}

compileLogf("Common subexpression elimination")

// common subexpression elimination
// compileLogf("Common subexpression elimination")
// compileLogf("analyzing devices")
replacements := make(map[*Node]*Node)
var buf bytes.Buffer
for i := len(sorted) - 1; i >= 0; i-- {
n := sorted[i]
fmt.Fprintf(&buf, "%d, ", n.ID())
for _, n := range sorted {
r, _ := df.vn(n)
replacements[n] = r
replacements[n] = r // CSE
df.analyzeDevice(n) // Device targeting
}
df.replacements = replacements

compileLogf("replacements: %+p", FmtNodeMap(replacements))
compileLogf("%v", buf.String())
compileLogf("replacements: %-p", FmtNodeMap(replacements))

// TODO
// constant propagation
@@ -86,14 +121,183 @@ func analyze(g *ExprGraph, sorted Nodes) *dataflow {
return df
}

func analyzeMem(g *ExprGraph, sorted Nodes) {
func newDevTransNode(read, write *Node, from, to Device) *Node {
op := devTrans{from, to, write}
n := borrowNode()
n.id = -1
n.op = op
n.shape = read.shape.Clone()
n.t = read.t
n.isStmt = true
n.children = Nodes{read}
return n
}

// insertDeviceInstr walks the sorted node list and, wherever a node's child
// lives on a different device than the node itself, splices a devTrans node
// into `sorted` so the value is copied across devices before use. It returns
// the (possibly grown) slice.
//
// NOTE(review): the lines between the inner children loop and the
// `if childDev != dev` check (`for _, node := range sorted` and the bodyless
// switch cases) appear to be stray fragments interleaved from a removed diff
// hunk — this listing is not syntactically coherent as shown; confirm against
// the repository before relying on it.
func (df *dataflow) insertDeviceInstr(sorted Nodes) Nodes {
compileLogf("Inserting Device Transport Instructions")
enterLoggingContext()
defer leaveLoggingContext()
// input -> output
for i := 0; i < len(sorted); i++ {
node := sorted[i]
// work on the value-numbered replacement, not the raw node
n := df.replacements[node]
dev := n.dataOn

compileLogf("Working on %v. Replacement %v. Device %v", node, n, dev)
// incr counts how many transport nodes were spliced in, so the outer
// index can skip past them
var incr int
var useReplacement bool
replacementChildren := make(Nodes, len(n.children))
enterLoggingContext()
for j, child := range n.children {
c := df.replacements[child]
childDev := c.dataOn

for _, node := range sorted {
switch {
case node.isArg():
case node.op.OverwritesInput() >= 0:
case node.op.ReturnsPtr():
compileLogf("Working on child :%v. Device: %v, Parent Device %v", c, childDev, dev)
if childDev != dev {
useReplacement = true
// reuse an existing transport for this child if one was
// already created
if repl, ok := df.devTransRepl[c]; ok {
replacementChildren[j] = repl
continue
}
transport := newDevTransNode(c, n, childDev, dev)
// splice the transport immediately before the current node
sorted = append(sorted, nil)
copy(sorted[i+1:], sorted[i:])
sorted[i] = transport
incr++
compileLogf("Inserted %v", transport)

// other stateful stuff
df.devTransRepl[c] = transport
df.replacements[transport] = transport
replacementChildren[j] = transport
} else {
replacementChildren[j] = child
}
}
leaveLoggingContext()

if useReplacement {
// record the child list with transports substituted in
df.devTransChildren[n] = replacementChildren
}

// skip over the transports we just inserted
i += incr
}
return sorted
}

/*
Notes on handling the live set:
1. We load all the SSAs listed in the block's LiveIn
2. Then we load all the SSAs used as input in this block Phi nodes
- The reason for this is so that those SSAs can have intervals created
that are live in this block (well, they are kinda live)
3. These input SSAs are temporary only, because a path-dependent liveset will be calculated below
Consider a CFG that looks like this:
BLOCK 1 BLOCK 3
+-------+ +-------+
+---->| x = 1 +------->| y = 3 +----------------+
BLOCK 0 | +-------+ | use x | v BLOCK 4
+-------+ | +-------+ +-------------+
| |+----+ | x = ϕ(1, 2) |
+-------+ | BLOCK 2 +-------------+
| +-------+ ^
+---->| x = 2 +---------------------------------+
+-------+
`x = 1` needs to be live in BLOCK 1, BLOCK 3 and BLOCK 4
`x = 2` needs to be live in BLOCK 2 and BLOCK 4.
The solution: in BLOCK 4, load `x = 1` and `x = 2` so they can be considered live in Block 4.
The interval building process comes to BLOCK 3 next. It considers the SSAs that are live in BLOCK 4.
If `x = 2` is live in BLOCK 4, it's Bad News with capital letters (see comment below).
The solution: remove the InputSSAs of the Phi nodes when we're leaving this block.
*/
// TODO: rephrase above to fit this package's function.
// It's like the above, but without basic blocks, phi nodes, etc, making it a LOT simpler
// buildIntervals computes the live interval of every node in the
// topologically sorted list and stores the result in df.intervals. An
// interval records the instruction range over which a node's value must stay
// alive plus its use positions; the register allocator consumes these.
//
// Inputs are live for the whole program (so are their device-transport
// replacements). Non-input nodes get a point range at their own instruction,
// extended by the use positions of their parents. Derivatives of inputs are
// kept alive to the end, since users read them after execution completes.
func (df *dataflow) buildIntervals(sorted Nodes) {
	compileLogf("Building intervals for %v", sorted)
	enterLoggingContext()
	defer leaveLoggingContext()

	intervals := make(map[*Node]*interval, len(sorted))
	for _, n := range sorted {
		intervals[n] = newInterval()
	}

	instructions := len(sorted)
	// walk backwards so uses are seen before definitions
	for i := len(sorted) - 1; i >= 0; i-- {
		n := sorted[i]
		instrNum := i
		nInter := intervals[n]
		compileLogf("n %v | %v", n, nInter)

		// inputs will be live the entire program
		if n.isInput() {
			nInter.addRange(instrNum, instructions)
			// a device-transported copy of an input must live just as long
			if repl, ok := df.devTransRepl[n]; ok {
				if interv, ok := intervals[repl]; ok {
					interv.addRange(instrNum, instructions)
				}
			}
			continue
		}
		nInter.addRange(instrNum, instrNum)

		// prefer the device-transport-substituted children when they exist
		children, ok := df.devTransChildren[n]
		if !ok {
			children = n.children
		}

		for _, child := range children {
			iv, ok := intervals[child]
			if !ok {
				// BUG FIX: previously a missing child interval yielded a nil
				// iv that was still dereferenced below, panicking. Skip
				// children this pass knows nothing about.
				continue
			}
			iv.addUsePositions(instrNum)
			// iv.setTo(instrNum)
		}

		// assume all derivations of input are read after the program ends
		for _, d := range n.derivOf {
			if d.isInput() {
				nInter.addUsePositions(instructions)
				break
			}
		}
	}

	for _, iv := range intervals {
		iv.fix()
	}

	var buf bytes.Buffer
	for k, v := range intervals {
		fmt.Fprintf(&buf, "%v: %v\n", k, v)
	}
	compileLogf("Intervals: %v", buf.String())

	df.intervals = intervals
}

0 comments on commit 215b466

Please sign in to comment.
You can’t perform that action at this time.