
Commit 42f0df1: Merge fc9d1b7 into 5fb5944

chewxy authored Jun 1, 2020
2 parents 5fb5944 + fc9d1b7
Showing 20 changed files with 239 additions and 182 deletions.
3 changes: 0 additions & 3 deletions compile.go
@@ -377,15 +377,12 @@ func (cg *codegenerator) addNode(node, replacement *Node, interv *interval, i in
compileLogf("Inserting new alloc")
var instr alloc
instr = newAlloc(node, writeTo)
// cg.instructions = append(cg.instructions, instr)

cg.addInstr(node, instr)
cg.updateLastWrites(writeTo, node)

prealloc = true

cg.queue = append(cg.queue, i)
// cg.queue = append(cg.queue, len(cg.instructions)) // no -1.
cg.allocated[writeTo] = struct{}{}
}
}
1 change: 1 addition & 0 deletions cuda/utils.go
@@ -74,6 +74,7 @@ func binaryCheck(a, b tensor.Tensor) (err error) {
if at.Kind() != bt.Kind() {
return errors.Errorf(typeMismatch, at, bt)
}

if !a.Shape().Eq(b.Shape()) {
log.Printf("BINARY CHECK %v %v", a.Shape(), b.Shape())
return errors.Errorf(shapeMismatch, b.Shape(), a.Shape())
6 changes: 3 additions & 3 deletions differentiation.go
@@ -132,7 +132,7 @@ func backwardDiffAnalysis(wrt, sortedNodes Nodes) (retVal NodeSet, err error) {
return diffSet, nil
}

// Backpropagate backpropagates errors by performing revers-emode symbolic differentiation, starting from the outputs, and working its way towads the inputs.
// Backpropagate backpropagates errors by performing reverse-mode symbolic differentiation, starting from the outputs, and working its way towards the inputs.
//
// This is the rough algorithm:
// 1. Filter out nodes that are unreachable
@@ -192,14 +192,14 @@ func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) {
wrtSet := wrt.mapSet()
badWRTs := wrtSet.Difference(affectsOutput)
if len(badWRTs) > 0 {
return nil, SymDiffError{nodes: badWRTs.ToSlice(), err: errors.New("Non Differentiable WRTs")}
return nil, SymDiffError{nodes: badWRTs.ToSlice(), err: errors.Errorf("Non Differentiable WRTs: %v", badWRTs)}
}

outputSet := outputs.mapSet()
badOutputs := outputSet.Difference(affectedByOutput)
if len(badOutputs) > 0 {
symdiffLogf("badOutputs: %#v", badOutputs)
return nil, SymDiffError{nodes: badOutputs.ToSlice(), err: errors.New("Non-Differentable Outputs")}
return nil, SymDiffError{nodes: badOutputs.ToSlice(), err: errors.Errorf("Non-Differentiable Outputs: %v", badOutputs)}
}

// map a node to a list of gradient terms
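
The SymDiffError changes above make differentiation failures name the offending nodes instead of returning a bare message. For context, here is a minimal sketch of exercising this path through Grad, which drives Backpropagate internally; the snippet is illustrative and not part of the commit:

package main

import (
	"fmt"
	"log"

	G "gorgonia.org/gorgonia"
)

func main() {
	g := G.NewGraph()
	x := G.NewScalar(g, G.Float64, G.WithName("x"), G.WithValue(2.0))
	y := G.Must(G.Mul(x, x)) // y = x*x
	// Grad performs reverse-mode symbolic differentiation via Backpropagate.
	grads, err := G.Grad(y, x)
	if err != nil {
		// With this commit, a SymDiffError here lists the non-differentiable nodes.
		log.Fatal(err)
	}
	m := G.NewTapeMachine(g)
	defer m.Close()
	if err := m.RunAll(); err != nil {
		log.Fatal(err)
	}
	fmt.Println(grads[0].Value()) // dy/dx = 2x = 4 at x = 2
}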
39 changes: 38 additions & 1 deletion example_linalg_test.go
@@ -8,7 +8,7 @@ import (
"gorgonia.org/tensor"
)

func Example_batchedMatMul() {
func ExampleBatchedMatMul() {
g := NewGraph()
a := NewTensor(g, Float64, 3, WithShape(2, 2, 3), WithInit(RangedFrom(1)), WithName("a"))
b := NewTensor(g, Float64, 3, WithShape(2, 3, 2), WithInit(RangedFrom(13)), WithName("b"))
@@ -74,3 +74,40 @@ func TestIncrSlices(t *testing.T) {
}
}
}

func ExampleBatchedMatMul_withBackprop() {
g := NewGraph()
a := NewTensor(g, Float64, 4, WithShape(2, 4, 3, 9), WithInit(RangedFrom(1)), WithName("a"))
b := NewTensor(g, Float64, 4, WithShape(2, 4, 3, 9), WithInit(RangedFrom(13)), WithName("b"))
c, err := BatchedMatMul(a, b, false, true)
if err != nil {
log.Fatal(err)
}
s, err := Sum(c)
if err != nil {
log.Fatal(err)
}
grads, err := Grad(s, a, b)
if err != nil {
log.Fatal(err)
}

m := NewTapeMachine(g)
if err := m.RunAll(); err != nil {
log.Fatal(err)
}

fmt.Printf("a: %v\n%v\n", a.Value().Shape(), a.Value().Data())
fmt.Printf("b: %v\n%v\n", b.Value().Shape(), b.Value().Data())
fmt.Printf("c: %v\n%v\n", c.Value().Shape(), c.Value().Data())
fmt.Printf("grads[0]:%v\n%v\n", grads[0].Shape(), grads[0].Value().Data())
// Output:
// a: (2, 4, 3, 9)
// [1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216]
// b: (2, 4, 3, 9)
// [13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228]
// c: (2, 4, 3, 3)
// [825 1230 1635 2202 3336 4470 3579 5442 7305 12732 15324 17916 16296 19617 22938 19860 23910 27960 37761 42540 47319 43512 49020 54528 49263 55500 61737 75912 82878 89844 83850 91545 99240 91788 100212 108636 127185 136338 145491 137310 147192 157074 147435 158046 168657 191580 202920 214260 203892 215961 228030 216204 229002 241800 269097 282624 296151 283596 297852 312108 298095 313080 328065 359736 375450 391164 376422 392865 409308 393108 410280 427452]
// grads[0]:(2, 4, 3, 9)
// [66 69 72 75 78 81 84 87 90 66 69 72 75 78 81 84 87 90 66 69 72 75 78 81 84 87 90 147 150 153 156 159 162 165 168 171 147 150 153 156 159 162 165 168 171 147 150 153 156 159 162 165 168 171 228 231 234 237 240 243 246 249 252 228 231 234 237 240 243 246 249 252 228 231 234 237 240 243 246 249 252 309 312 315 318 321 324 327 330 333 309 312 315 318 321 324 327 330 333 309 312 315 318 321 324 327 330 333 390 393 396 399 402 405 408 411 414 390 393 396 399 402 405 408 411 414 390 393 396 399 402 405 408 411 414 471 474 477 480 483 486 489 492 495 471 474 477 480 483 486 489 492 495 471 474 477 480 483 486 489 492 495 552 555 558 561 564 567 570 573 576 552 555 558 561 564 567 570 573 576 552 555 558 561 564 567 570 573 576 633 636 639 642 645 648 651 654 657 633 636 639 642 645 648 651 654 657 633 636 639 642 645 648 651 654 657]
}
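
As a sanity check on the Output block above: the fourth argument to BatchedMatMul appears to transpose b's last two axes — that reading is inferred from the shapes ((3, 9) × (9, 3) → (3, 3)), not spelled out in the diff — so the first element of c is the dot product of a's first row [1..9] with b's first row [13..21]:

package main

import "fmt"

func main() {
	// First element of c above: sum over i of a[0,0,0,i] * b[0,0,0,i].
	sum := 0.0
	for i := 1; i <= 9; i++ {
		sum += float64(i) * float64(i+12)
	}
	fmt.Println(sum) // 825 — matches the first element of c
}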
132 changes: 61 additions & 71 deletions examples/convnet/main.go
@@ -3,7 +3,6 @@ package main
import (
"flag"
"fmt"
"io/ioutil"
"log"
"math/rand"
"os"
@@ -15,7 +14,7 @@ import (
_ "net/http/pprof"

"github.com/pkg/errors"
"gorgonia.org/gorgonia"
G "gorgonia.org/gorgonia"
"gorgonia.org/gorgonia/examples/mnist"
"gorgonia.org/tensor"

@@ -47,28 +46,20 @@ func parseDtype() {
}
}

type sli struct {
start, end int
}

func (s sli) Start() int { return s.start }
func (s sli) End() int { return s.end }
func (s sli) Step() int { return 1 }

type convnet struct {
g *gorgonia.ExprGraph
w0, w1, w2, w3, w4 *gorgonia.Node // weights. the number at the back indicates which layer it's used for
d0, d1, d2, d3 float64 // dropout probabilities
g *G.ExprGraph
w0, w1, w2, w3, w4 *G.Node // weights. the number at the back indicates which layer it's used for
d0, d1, d2, d3 float64 // dropout probabilities

out *gorgonia.Node
out *G.Node
}

func newConvNet(g *gorgonia.ExprGraph) *convnet {
w0 := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(32, 1, 3, 3), gorgonia.WithName("w0"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
w1 := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(64, 32, 3, 3), gorgonia.WithName("w1"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
w2 := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(128, 64, 3, 3), gorgonia.WithName("w2"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
w3 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(128*3*3, 625), gorgonia.WithName("w3"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
w4 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(625, 10), gorgonia.WithName("w4"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
func newConvNet(g *G.ExprGraph) *convnet {
w0 := G.NewTensor(g, dt, 4, G.WithShape(32, 1, 3, 3), G.WithName("w0"), G.WithInit(G.GlorotN(1.0)))
w1 := G.NewTensor(g, dt, 4, G.WithShape(64, 32, 3, 3), G.WithName("w1"), G.WithInit(G.GlorotN(1.0)))
w2 := G.NewTensor(g, dt, 4, G.WithShape(128, 64, 3, 3), G.WithName("w2"), G.WithInit(G.GlorotN(1.0)))
w3 := G.NewMatrix(g, dt, G.WithShape(128*3*3, 625), G.WithName("w3"), G.WithInit(G.GlorotN(1.0)))
w4 := G.NewMatrix(g, dt, G.WithShape(625, 10), G.WithName("w4"), G.WithInit(G.GlorotN(1.0)))
return &convnet{
g: g,
w0: w0,
@@ -84,89 +75,86 @@ func newConvNet(g *gorgonia.ExprGraph) *convnet {
}
}

func (m *convnet) learnables() gorgonia.Nodes {
return gorgonia.Nodes{m.w0, m.w1, m.w2, m.w3, m.w4}
func (m *convnet) learnables() G.Nodes {
return G.Nodes{m.w0, m.w1, m.w2, m.w3, m.w4}
}

// This function is particularly verbose for educational reasons. In reality, you'd wrap up the layers within a layer struct type and perform per-layer activations
func (m *convnet) fwd(x *gorgonia.Node) (err error) {
var c0, c1, c2, fc *gorgonia.Node
var a0, a1, a2, a3 *gorgonia.Node
var p0, p1, p2 *gorgonia.Node
var l0, l1, l2, l3 *gorgonia.Node
func (m *convnet) fwd(x *G.Node) (err error) {
var c0, c1, c2, fc *G.Node
var a0, a1, a2, a3 *G.Node
var p0, p1, p2 *G.Node
var l0, l1, l2, l3 *G.Node

// LAYER 0
// here we convolve with stride = (1, 1) and padding = (1, 1),
// which is your bog standard convolution for convnet
if c0, err = gorgonia.Conv2d(x, m.w0, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil {
if c0, err = G.Conv2d(x, m.w0, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil {
return errors.Wrap(err, "Layer 0 Convolution failed")
}
if a0, err = gorgonia.Rectify(c0); err != nil {
if a0, err = G.Rectify(c0); err != nil {
return errors.Wrap(err, "Layer 0 activation failed")
}
if p0, err = gorgonia.MaxPool2D(a0, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil {
if p0, err = G.MaxPool2D(a0, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil {
return errors.Wrap(err, "Layer 0 Maxpooling failed")
}
log.Printf("p0 shape %v", p0.Shape())
if l0, err = gorgonia.Dropout(p0, m.d0); err != nil {
if l0, err = G.Dropout(p0, m.d0); err != nil {
return errors.Wrap(err, "Unable to apply a dropout")
}

// Layer 1
if c1, err = gorgonia.Conv2d(l0, m.w1, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil {
if c1, err = G.Conv2d(l0, m.w1, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil {
return errors.Wrap(err, "Layer 1 Convolution failed")
}
if a1, err = gorgonia.Rectify(c1); err != nil {
if a1, err = G.Rectify(c1); err != nil {
return errors.Wrap(err, "Layer 1 activation failed")
}
if p1, err = gorgonia.MaxPool2D(a1, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil {
if p1, err = G.MaxPool2D(a1, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil {
return errors.Wrap(err, "Layer 1 Maxpooling failed")
}
if l1, err = gorgonia.Dropout(p1, m.d1); err != nil {
if l1, err = G.Dropout(p1, m.d1); err != nil {
return errors.Wrap(err, "Unable to apply a dropout to layer 1")
}

// Layer 2
if c2, err = gorgonia.Conv2d(l1, m.w2, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil {
if c2, err = G.Conv2d(l1, m.w2, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil {
return errors.Wrap(err, "Layer 2 Convolution failed")
}
if a2, err = gorgonia.Rectify(c2); err != nil {
if a2, err = G.Rectify(c2); err != nil {
return errors.Wrap(err, "Layer 2 activation failed")
}
if p2, err = gorgonia.MaxPool2D(a2, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil {
if p2, err = G.MaxPool2D(a2, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil {
return errors.Wrap(err, "Layer 2 Maxpooling failed")
}
log.Printf("p2 shape %v", p2.Shape())

var r2 *gorgonia.Node
var r2 *G.Node
b, c, h, w := p2.Shape()[0], p2.Shape()[1], p2.Shape()[2], p2.Shape()[3]
if r2, err = gorgonia.Reshape(p2, tensor.Shape{b, c * h * w}); err != nil {
if r2, err = G.Reshape(p2, tensor.Shape{b, c * h * w}); err != nil {
return errors.Wrap(err, "Unable to reshape layer 2")
}
log.Printf("r2 shape %v", r2.Shape())
if l2, err = gorgonia.Dropout(r2, m.d2); err != nil {
if l2, err = G.Dropout(r2, m.d2); err != nil {
return errors.Wrap(err, "Unable to apply a dropout on layer 2")
}

ioutil.WriteFile("tmp.dot", []byte(m.g.ToDot()), 0644)

// Layer 3
if fc, err = gorgonia.Mul(l2, m.w3); err != nil {
if fc, err = G.Mul(l2, m.w3); err != nil {
return errors.Wrapf(err, "Unable to multiply l2 and w3")
}
if a3, err = gorgonia.Rectify(fc); err != nil {
if a3, err = G.Rectify(fc); err != nil {
return errors.Wrapf(err, "Unable to activate fc")
}
if l3, err = gorgonia.Dropout(a3, m.d3); err != nil {
if l3, err = G.Dropout(a3, m.d3); err != nil {
return errors.Wrapf(err, "Unable to apply a dropout on layer 3")
}

// output decode
var out *gorgonia.Node
if out, err = gorgonia.Mul(l3, m.w4); err != nil {
var out *G.Node
if out, err = G.Mul(l3, m.w4); err != nil {
return errors.Wrapf(err, "Unable to multiply l3 and w4")
}
m.out, err = gorgonia.SoftMax(out)
m.out, err = G.SoftMax(out)
return
}
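
The comment at the top of fwd says the verbosity is deliberate and suggests wrapping each layer in a struct instead. A minimal sketch of that refactor, reusing only calls already present in this file and assuming its imports (G, tensor, errors); convLayer and its method are hypothetical names, not part of the commit:

// convLayer bundles one conv -> ReLU -> maxpool -> dropout block from fwd above.
type convLayer struct {
	w       *G.Node
	dropout float64
}

func (l convLayer) fwd(x *G.Node) (*G.Node, error) {
	c, err := G.Conv2d(x, l.w, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1})
	if err != nil {
		return nil, errors.Wrap(err, "convolution failed")
	}
	a, err := G.Rectify(c)
	if err != nil {
		return nil, errors.Wrap(err, "activation failed")
	}
	p, err := G.MaxPool2D(a, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2})
	if err != nil {
		return nil, errors.Wrap(err, "maxpooling failed")
	}
	return G.Dropout(p, l.dropout)
}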

@@ -208,36 +196,36 @@ func main() {
if err := inputs.Reshape(numExamples, 1, 28, 28); err != nil {
log.Fatal(err)
}
g := gorgonia.NewGraph()
x := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(bs, 1, 28, 28), gorgonia.WithName("x"))
y := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(bs, 10), gorgonia.WithName("y"))
g := G.NewGraph()
x := G.NewTensor(g, dt, 4, G.WithShape(bs, 1, 28, 28), G.WithName("x"))
y := G.NewMatrix(g, dt, G.WithShape(bs, 10), G.WithName("y"))
m := newConvNet(g)
if err = m.fwd(x); err != nil {
log.Fatalf("%+v", err)
}
losses := gorgonia.Must(gorgonia.Log(gorgonia.Must(gorgonia.HadamardProd(m.out, y))))
cost := gorgonia.Must(gorgonia.Mean(losses))
cost = gorgonia.Must(gorgonia.Neg(cost))
losses := G.Must(G.Log(G.Must(G.HadamardProd(m.out, y))))
cost := G.Must(G.Mean(losses))
cost = G.Must(G.Neg(cost))

// we wanna track costs
var costVal gorgonia.Value
gorgonia.Read(cost, &costVal)
var costVal G.Value
G.Read(cost, &costVal)

// if _, err = gorgonia.Grad(cost, m.learnables()...); err != nil {
// log.Fatal(err)
// }
if _, err = G.Grad(cost, m.learnables()...); err != nil {
log.Fatal(err)
}

// debug
// ioutil.WriteFile("fullGraph.dot", []byte(g.ToDot()), 0644)
// log.Printf("%v", prog)
// logger := log.New(os.Stderr, "", 0)
// vm := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(m.learnables()...), gorgonia.WithLogger(logger), gorgonia.WithWatchlist())

prog, locMap, _ := gorgonia.Compile(g)
log.Printf("%v", prog)
prog, locMap, _ := G.Compile(g)
//log.Printf("%v", prog)

vm := gorgonia.NewTapeMachine(g, gorgonia.WithPrecompiled(prog, locMap), gorgonia.BindDualValues(m.learnables()...))
solver := gorgonia.NewRMSPropSolver(gorgonia.WithBatchSize(float64(bs)))
vm := G.NewTapeMachine(g, G.WithPrecompiled(prog, locMap), G.BindDualValues(m.learnables()...))
solver := G.NewRMSPropSolver(G.WithBatchSize(float64(bs)))
defer vm.Close()
// pprof
// handlePprof(sigChan, doneChan)
@@ -275,23 +263,25 @@ func main() {
}

var xVal, yVal tensor.Tensor
if xVal, err = inputs.Slice(sli{start, end}); err != nil {
if xVal, err = inputs.Slice(G.S(start, end)); err != nil {
log.Fatal("Unable to slice x")
}

if yVal, err = targets.Slice(sli{start, end}); err != nil {
if yVal, err = targets.Slice(G.S(start, end)); err != nil {
log.Fatal("Unable to slice y")
}
if err = xVal.(*tensor.Dense).Reshape(bs, 1, 28, 28); err != nil {
log.Fatalf("Unable to reshape %v", err)
}

gorgonia.Let(x, xVal)
gorgonia.Let(y, yVal)
G.Let(x, xVal)
G.Let(y, yVal)
if err = vm.RunAll(); err != nil {
log.Fatalf("Failed at epoch %d: %v", i, err)
log.Fatalf("Failed at epoch %d, batch %d. Error: %v", i, b, err)
}
if err = solver.Step(G.NodesToValueGrads(m.learnables())); err != nil {
log.Fatalf("Failed to update nodes with gradients at epoch %d, batch %d. Error %v", i, b, err)
}
solver.Step(gorgonia.NodesToValueGrads(m.learnables()))
vm.Reset()
bar.Increment()
}
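
One more note on this file: the local sli type removed near the top existed only to satisfy tensor's slicing interface (Start/End/Step), and the call sites now use G.S instead. A small runnable sketch of the replacement, assuming G.S(start, end) builds the same half-open slice spec the removed type did (inferred from its Start/End/Step methods):

package main

import (
	"fmt"

	G "gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

func main() {
	t := tensor.New(tensor.WithShape(4, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8}))
	// G.S(1, 3) plays the role the removed sli{1, 3} used to:
	// a slice spec with Start()=1, End()=3, Step()=1.
	v, err := t.Slice(G.S(1, 3))
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(v) // rows 1 and 2 of t
}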