diff --git a/compile.go b/compile.go
index c7e191e1..6656c62b 100644
--- a/compile.go
+++ b/compile.go
@@ -377,15 +377,12 @@ func (cg *codegenerator) addNode(node, replacement *Node, interv *interval, i in
 			compileLogf("Inserting new alloc")
 			var instr alloc
 			instr = newAlloc(node, writeTo)
-			// cg.instructions = append(cg.instructions, instr)
-
 			cg.addInstr(node, instr)
 			cg.updateLastWrites(writeTo, node)
 
 			prealloc = true
 
 			cg.queue = append(cg.queue, i)
-			// cg.queue = append(cg.queue, len(cg.instructions)) // no -1.
 			cg.allocated[writeTo] = struct{}{}
 		}
 	}
diff --git a/cuda/utils.go b/cuda/utils.go
index 2c5c756a..e5c8e77f 100644
--- a/cuda/utils.go
+++ b/cuda/utils.go
@@ -74,6 +74,7 @@ func binaryCheck(a, b tensor.Tensor) (err error) {
 	if at.Kind() != bt.Kind() {
 		return errors.Errorf(typeMismatch, at, bt)
 	}
+
 	if !a.Shape().Eq(b.Shape()) {
 		log.Printf("BINARY CHECK %v %v", a.Shape(), b.Shape())
 		return errors.Errorf(shapeMismatch, b.Shape(), a.Shape())
diff --git a/differentiation.go b/differentiation.go
index 7316ad4d..70a6d2bc 100644
--- a/differentiation.go
+++ b/differentiation.go
@@ -132,7 +132,7 @@ func backwardDiffAnalysis(wrt, sortedNodes Nodes) (retVal NodeSet, err error) {
 	return diffSet, nil
 }
 
-// Backpropagate backpropagates errors by performing revers-emode symbolic differentiation, starting from the outputs, and working its way towads the inputs.
+// Backpropagate backpropagates errors by performing reverse-mode symbolic differentiation, starting from the outputs, and working its way towards the inputs.
 //
 // This is the rough algorithm:
 // 1. Filter out nodes that are unreachable
@@ -192,14 +192,14 @@ func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) {
 	wrtSet := wrt.mapSet()
 	badWRTs := wrtSet.Difference(affectsOutput)
 	if len(badWRTs) > 0 {
-		return nil, SymDiffError{nodes: badWRTs.ToSlice(), err: errors.New("Non Differentiable WRTs")}
+		return nil, SymDiffError{nodes: badWRTs.ToSlice(), err: errors.Errorf("Non-Differentiable WRTs: %v", badWRTs)}
 	}
 
 	outputSet := outputs.mapSet()
 	badOutputs := outputSet.Difference(affectedByOutput)
 	if len(badOutputs) > 0 {
 		symdiffLogf("badOutputs: %#v", badOutputs)
-		return nil, SymDiffError{nodes: badOutputs.ToSlice(), err: errors.New("Non-Differentable Outputs")}
+		return nil, SymDiffError{nodes: badOutputs.ToSlice(), err: errors.Errorf("Non-Differentiable Outputs: %v", badOutputs)}
 	}
 
 	// map a node to a list of gradient terms
diff --git a/example_linalg_test.go b/example_linalg_test.go
index 848221fb..683b4065 100644
--- a/example_linalg_test.go
+++ b/example_linalg_test.go
@@ -8,7 +8,7 @@ import (
 	"gorgonia.org/tensor"
 )
 
-func Example_batchedMatMul() {
+func ExampleBatchedMatMul() {
 	g := NewGraph()
 	a := NewTensor(g, Float64, 3, WithShape(2, 2, 3), WithInit(RangedFrom(1)), WithName("a"))
 	b := NewTensor(g, Float64, 3, WithShape(2, 3, 2), WithInit(RangedFrom(13)), WithName("b"))
@@ -74,3 +74,40 @@ func TestIncrSlices(t *testing.T) {
 		}
 	}
 }
+
+func ExampleBatchedMatMul_withBackprop() {
+	g := NewGraph()
+	a := NewTensor(g, Float64, 4, WithShape(2, 4, 3, 9), WithInit(RangedFrom(1)), WithName("a"))
+	b := NewTensor(g, Float64, 4, WithShape(2, 4, 3, 9), WithInit(RangedFrom(13)), WithName("b"))
+	c, err := BatchedMatMul(a, b, false, true)
+	if err != nil {
+		log.Fatal(err)
+	}
+	s, err := Sum(c)
+	if err != nil {
+		log.Fatal(err)
+	}
+	grads, err := Grad(s, a, b)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	m := NewTapeMachine(g)
+	if err := m.RunAll(); err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Printf("a: %v\n%v\n", a.Value().Shape(), a.Value().Data())
+	fmt.Printf("b: %v\n%v\n", b.Value().Shape(), b.Value().Data())
+	fmt.Printf("c: %v\n%v\n", c.Value().Shape(), c.Value().Data())
+	fmt.Printf("grads[0]:%v\n%v\n", grads[0].Shape(), grads[0].Value().Data())
+	// Output:
+	// a: (2, 4, 3, 9)
+	// [1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216]
+	// b: (2, 4, 3, 9)
+	// [13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228]
+	// c: (2, 4, 3, 3)
+	// [825 1230 1635 2202 3336 4470 3579 5442 7305 12732 15324 17916 16296 19617 22938 19860 23910 27960 37761 42540 47319 43512 49020 54528 49263 55500 61737 75912 82878 89844 83850 91545 99240 91788 100212 108636 127185 136338 145491 137310 147192 157074 147435 158046 168657 191580 202920 214260 203892 215961 228030 216204 229002 241800 269097 282624 296151 283596 297852 312108 298095 313080 328065 359736 375450 391164 376422 392865 409308 393108 410280 427452]
+	// grads[0]:(2, 4, 3, 9)
+	// [66 69 72 75 78 81 84 87 90 66 69 72 75 78 81 84 87 90 66 69 72 75 78 81 84 87 90 147 150 153 156 159 162 165 168 171 147 150 153 156 159 162 165 168 171 147 150 153 156 159 162 165 168 171 228 231 234 237 240 243 246 249 252 228 231 234 237 240 243 246 249 252 228 231 234 237 240 243 246 249 252 309 312 315 318 321 324 327 330 333 309 312 315 318 321 324 327 330 333 309 312 315 318 321 324 327 330 333 390 393 396 399 402 405 408 411 414 390 393 396 399 402 405 408 411 414 390 393 396 399 402 405 408 411 414 471 474 477 480 483 486 489 492 495 471 474 477 480 483 486 489 492 495 471 474 477 480 483 486 489 492 495 552 555 558 561 564 567 570 573 576 552 555 558 561 564 567 570 573 576 552 555 558 561 564 567 570 573 576 633 636 639 642 645 648 651 654 657 633 636 639 642 645 648 651 654 657 633 636 639 642 645 648 651 654 657]
+}
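Note on the expected gradient values above (reviewer arithmetic, not part of the patch): with s = Σᵢⱼ Cᵢⱼ and C = ABᵀ per batch, ∂s/∂A = JB where J is the all-ones matrix, so every row of grads[0] within a batch is the column-sum of the matching 3×9 block of b. For the first block, b's rows start at 13, 22 and 31: 13+22+31 = 66, 14+23+32 = 69, and so on, which is exactly the repeating [66 69 72 ...] pattern in the expected output.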
"gorgonia.org/gorgonia" "gorgonia.org/gorgonia/examples/mnist" "gorgonia.org/tensor" @@ -47,28 +46,20 @@ func parseDtype() { } } -type sli struct { - start, end int -} - -func (s sli) Start() int { return s.start } -func (s sli) End() int { return s.end } -func (s sli) Step() int { return 1 } - type convnet struct { - g *gorgonia.ExprGraph - w0, w1, w2, w3, w4 *gorgonia.Node // weights. the number at the back indicates which layer it's used for - d0, d1, d2, d3 float64 // dropout probabilities + g *G.ExprGraph + w0, w1, w2, w3, w4 *G.Node // weights. the number at the back indicates which layer it's used for + d0, d1, d2, d3 float64 // dropout probabilities - out *gorgonia.Node + out *G.Node } -func newConvNet(g *gorgonia.ExprGraph) *convnet { - w0 := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(32, 1, 3, 3), gorgonia.WithName("w0"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) - w1 := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(64, 32, 3, 3), gorgonia.WithName("w1"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) - w2 := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(128, 64, 3, 3), gorgonia.WithName("w2"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) - w3 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(128*3*3, 625), gorgonia.WithName("w3"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) - w4 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(625, 10), gorgonia.WithName("w4"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) +func newConvNet(g *G.ExprGraph) *convnet { + w0 := G.NewTensor(g, dt, 4, G.WithShape(32, 1, 3, 3), G.WithName("w0"), G.WithInit(G.GlorotN(1.0))) + w1 := G.NewTensor(g, dt, 4, G.WithShape(64, 32, 3, 3), G.WithName("w1"), G.WithInit(G.GlorotN(1.0))) + w2 := G.NewTensor(g, dt, 4, G.WithShape(128, 64, 3, 3), G.WithName("w2"), G.WithInit(G.GlorotN(1.0))) + w3 := G.NewMatrix(g, dt, G.WithShape(128*3*3, 625), G.WithName("w3"), G.WithInit(G.GlorotN(1.0))) + w4 := G.NewMatrix(g, dt, G.WithShape(625, 10), G.WithName("w4"), G.WithInit(G.GlorotN(1.0))) return &convnet{ g: g, w0: w0, @@ -84,89 +75,86 @@ func newConvNet(g *gorgonia.ExprGraph) *convnet { } } -func (m *convnet) learnables() gorgonia.Nodes { - return gorgonia.Nodes{m.w0, m.w1, m.w2, m.w3, m.w4} +func (m *convnet) learnables() G.Nodes { + return G.Nodes{m.w0, m.w1, m.w2, m.w3, m.w4} } // This function is particularly verbose for educational reasons. 
In reality, you'd wrap up the layers within a layer struct type and perform per-layer activations -func (m *convnet) fwd(x *gorgonia.Node) (err error) { - var c0, c1, c2, fc *gorgonia.Node - var a0, a1, a2, a3 *gorgonia.Node - var p0, p1, p2 *gorgonia.Node - var l0, l1, l2, l3 *gorgonia.Node +func (m *convnet) fwd(x *G.Node) (err error) { + var c0, c1, c2, fc *G.Node + var a0, a1, a2, a3 *G.Node + var p0, p1, p2 *G.Node + var l0, l1, l2, l3 *G.Node // LAYER 0 // here we convolve with stride = (1, 1) and padding = (1, 1), // which is your bog standard convolution for convnet - if c0, err = gorgonia.Conv2d(x, m.w0, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil { + if c0, err = G.Conv2d(x, m.w0, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil { return errors.Wrap(err, "Layer 0 Convolution failed") } - if a0, err = gorgonia.Rectify(c0); err != nil { + if a0, err = G.Rectify(c0); err != nil { return errors.Wrap(err, "Layer 0 activation failed") } - if p0, err = gorgonia.MaxPool2D(a0, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil { + if p0, err = G.MaxPool2D(a0, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil { return errors.Wrap(err, "Layer 0 Maxpooling failed") } log.Printf("p0 shape %v", p0.Shape()) - if l0, err = gorgonia.Dropout(p0, m.d0); err != nil { + if l0, err = G.Dropout(p0, m.d0); err != nil { return errors.Wrap(err, "Unable to apply a dropout") } // Layer 1 - if c1, err = gorgonia.Conv2d(l0, m.w1, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil { + if c1, err = G.Conv2d(l0, m.w1, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil { return errors.Wrap(err, "Layer 1 Convolution failed") } - if a1, err = gorgonia.Rectify(c1); err != nil { + if a1, err = G.Rectify(c1); err != nil { return errors.Wrap(err, "Layer 1 activation failed") } - if p1, err = gorgonia.MaxPool2D(a1, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil { + if p1, err = G.MaxPool2D(a1, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil { return errors.Wrap(err, "Layer 1 Maxpooling failed") } - if l1, err = gorgonia.Dropout(p1, m.d1); err != nil { + if l1, err = G.Dropout(p1, m.d1); err != nil { return errors.Wrap(err, "Unable to apply a dropout to layer 1") } // Layer 2 - if c2, err = gorgonia.Conv2d(l1, m.w2, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil { + if c2, err = G.Conv2d(l1, m.w2, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}); err != nil { return errors.Wrap(err, "Layer 2 Convolution failed") } - if a2, err = gorgonia.Rectify(c2); err != nil { + if a2, err = G.Rectify(c2); err != nil { return errors.Wrap(err, "Layer 2 activation failed") } - if p2, err = gorgonia.MaxPool2D(a2, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil { + if p2, err = G.MaxPool2D(a2, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2}); err != nil { return errors.Wrap(err, "Layer 2 Maxpooling failed") } - log.Printf("p2 shape %v", p2.Shape()) - var r2 *gorgonia.Node + var r2 *G.Node b, c, h, w := p2.Shape()[0], p2.Shape()[1], p2.Shape()[2], p2.Shape()[3] - if r2, err = gorgonia.Reshape(p2, tensor.Shape{b, c * h * w}); err != nil { + if r2, err = G.Reshape(p2, tensor.Shape{b, c * h * w}); err != nil { return errors.Wrap(err, "Unable to reshape layer 2") } log.Printf("r2 shape %v", r2.Shape()) - if l2, err = gorgonia.Dropout(r2, m.d2); err != nil { + if l2, err = G.Dropout(r2, m.d2); err != nil { return errors.Wrap(err, "Unable to apply a dropout on layer 2") } - 
ioutil.WriteFile("tmp.dot", []byte(m.g.ToDot()), 0644) - // Layer 3 - if fc, err = gorgonia.Mul(l2, m.w3); err != nil { + if fc, err = G.Mul(l2, m.w3); err != nil { return errors.Wrapf(err, "Unable to multiply l2 and w3") } - if a3, err = gorgonia.Rectify(fc); err != nil { + if a3, err = G.Rectify(fc); err != nil { return errors.Wrapf(err, "Unable to activate fc") } - if l3, err = gorgonia.Dropout(a3, m.d3); err != nil { + if l3, err = G.Dropout(a3, m.d3); err != nil { return errors.Wrapf(err, "Unable to apply a dropout on layer 3") } // output decode - var out *gorgonia.Node - if out, err = gorgonia.Mul(l3, m.w4); err != nil { + var out *G.Node + if out, err = G.Mul(l3, m.w4); err != nil { return errors.Wrapf(err, "Unable to multiply l3 and w4") } - m.out, err = gorgonia.SoftMax(out) + m.out, err = G.SoftMax(out) return } @@ -208,24 +196,24 @@ func main() { if err := inputs.Reshape(numExamples, 1, 28, 28); err != nil { log.Fatal(err) } - g := gorgonia.NewGraph() - x := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(bs, 1, 28, 28), gorgonia.WithName("x")) - y := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(bs, 10), gorgonia.WithName("y")) + g := G.NewGraph() + x := G.NewTensor(g, dt, 4, G.WithShape(bs, 1, 28, 28), G.WithName("x")) + y := G.NewMatrix(g, dt, G.WithShape(bs, 10), G.WithName("y")) m := newConvNet(g) if err = m.fwd(x); err != nil { log.Fatalf("%+v", err) } - losses := gorgonia.Must(gorgonia.Log(gorgonia.Must(gorgonia.HadamardProd(m.out, y)))) - cost := gorgonia.Must(gorgonia.Mean(losses)) - cost = gorgonia.Must(gorgonia.Neg(cost)) + losses := G.Must(G.Log(G.Must(G.HadamardProd(m.out, y)))) + cost := G.Must(G.Mean(losses)) + cost = G.Must(G.Neg(cost)) // we wanna track costs - var costVal gorgonia.Value - gorgonia.Read(cost, &costVal) + var costVal G.Value + G.Read(cost, &costVal) - // if _, err = gorgonia.Grad(cost, m.learnables()...); err != nil { - // log.Fatal(err) - // } + if _, err = G.Grad(cost, m.learnables()...); err != nil { + log.Fatal(err) + } // debug // ioutil.WriteFile("fullGraph.dot", []byte(g.ToDot()), 0644) @@ -233,11 +221,11 @@ func main() { // logger := log.New(os.Stderr, "", 0) // vm := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(m.learnables()...), gorgonia.WithLogger(logger), gorgonia.WithWatchlist()) - prog, locMap, _ := gorgonia.Compile(g) - log.Printf("%v", prog) + prog, locMap, _ := G.Compile(g) + //log.Printf("%v", prog) - vm := gorgonia.NewTapeMachine(g, gorgonia.WithPrecompiled(prog, locMap), gorgonia.BindDualValues(m.learnables()...)) - solver := gorgonia.NewRMSPropSolver(gorgonia.WithBatchSize(float64(bs))) + vm := G.NewTapeMachine(g, G.WithPrecompiled(prog, locMap), G.BindDualValues(m.learnables()...)) + solver := G.NewRMSPropSolver(G.WithBatchSize(float64(bs))) defer vm.Close() // pprof // handlePprof(sigChan, doneChan) @@ -275,23 +263,25 @@ func main() { } var xVal, yVal tensor.Tensor - if xVal, err = inputs.Slice(sli{start, end}); err != nil { + if xVal, err = inputs.Slice(G.S(start, end)); err != nil { log.Fatal("Unable to slice x") } - if yVal, err = targets.Slice(sli{start, end}); err != nil { + if yVal, err = targets.Slice(G.S(start, end)); err != nil { log.Fatal("Unable to slice y") } if err = xVal.(*tensor.Dense).Reshape(bs, 1, 28, 28); err != nil { log.Fatalf("Unable to reshape %v", err) } - gorgonia.Let(x, xVal) - gorgonia.Let(y, yVal) + G.Let(x, xVal) + G.Let(y, yVal) if err = vm.RunAll(); err != nil { - log.Fatalf("Failed at epoch %d: %v", i, err) + log.Fatalf("Failed at epoch %d, batch %d. 
Error: %v", i, b, err) + } + if err = solver.Step(G.NodesToValueGrads(m.learnables())); err != nil { + log.Fatalf("Failed to update nodes with gradients at epoch %d, batch %d. Error %v", i, b, err) } - solver.Step(gorgonia.NodesToValueGrads(m.learnables())) vm.Reset() bar.Increment() } diff --git a/examples/convnet_cuda/main.go b/examples/convnet_cuda/main.go index 2c761801..0d97c951 100644 --- a/examples/convnet_cuda/main.go +++ b/examples/convnet_cuda/main.go @@ -16,7 +16,7 @@ import ( _ "net/http/pprof" "github.com/pkg/errors" - "gorgonia.org/gorgonia" + G "gorgonia.org/gorgonia" "gorgonia.org/gorgonia/examples/mnist" nnops "gorgonia.org/gorgonia/ops/nn" "gorgonia.org/tensor" @@ -49,28 +49,20 @@ func parseDtype() { } } -type sli struct { - start, end int -} - -func (s sli) Start() int { return s.start } -func (s sli) End() int { return s.end } -func (s sli) Step() int { return 1 } - type convnet struct { - g *gorgonia.ExprGraph - w0, w1, w2, w3, w4 *gorgonia.Node // weights. the number at the back indicates which layer it's used for - d0, d1, d2, d3 float64 // dropout probabilities + g *G.ExprGraph + w0, w1, w2, w3, w4 *G.Node // weights. the number at the back indicates which layer it's used for + d0, d1, d2, d3 float64 // dropout probabilities - out *gorgonia.Node + out *G.Node } -func newConvNet(g *gorgonia.ExprGraph) *convnet { - w0 := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(32, 1, 3, 3), gorgonia.WithName("w0"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) - w1 := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(64, 32, 3, 3), gorgonia.WithName("w1"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) - w2 := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(128, 64, 3, 3), gorgonia.WithName("w2"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) - w3 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(128*3*3, 625), gorgonia.WithName("w3"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) - w4 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(625, 10), gorgonia.WithName("w4"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) +func newConvNet(g *G.ExprGraph) *convnet { + w0 := G.NewTensor(g, dt, 4, G.WithShape(32, 1, 3, 3), G.WithName("w0"), G.WithInit(G.GlorotN(1.0))) + w1 := G.NewTensor(g, dt, 4, G.WithShape(64, 32, 3, 3), G.WithName("w1"), G.WithInit(G.GlorotN(1.0))) + w2 := G.NewTensor(g, dt, 4, G.WithShape(128, 64, 3, 3), G.WithName("w2"), G.WithInit(G.GlorotN(1.0))) + w3 := G.NewMatrix(g, dt, G.WithShape(128*3*3, 625), G.WithName("w3"), G.WithInit(G.GlorotN(1.0))) + w4 := G.NewMatrix(g, dt, G.WithShape(625, 10), G.WithName("w4"), G.WithInit(G.GlorotN(1.0))) return &convnet{ g: g, w0: w0, @@ -86,16 +78,16 @@ func newConvNet(g *gorgonia.ExprGraph) *convnet { } } -func (m *convnet) learnables() gorgonia.Nodes { - return gorgonia.Nodes{m.w0, m.w1, m.w2, m.w3, m.w4} +func (m *convnet) learnables() G.Nodes { + return G.Nodes{m.w0, m.w1, m.w2, m.w3, m.w4} } // This function is particularly verbose for educational reasons. 
In reality, you'd wrap up the layers within a layer struct type and perform per-layer activations -func (m *convnet) fwd(x *gorgonia.Node) (err error) { - var c0, c1, c2, fc *gorgonia.Node - var a0, a1, a2, a3 *gorgonia.Node - var p0, p1, p2 *gorgonia.Node - var l0, l1, l2, l3 *gorgonia.Node +func (m *convnet) fwd(x *G.Node) (err error) { + var c0, c1, c2, fc *G.Node + var a0, a1, a2, a3 *G.Node + var p0, p1, p2 *G.Node + var l0, l1, l2, l3 *G.Node // LAYER 0 // here we convolve with stride = (1, 1) and padding = (1, 1), @@ -141,9 +133,9 @@ func (m *convnet) fwd(x *gorgonia.Node) (err error) { } log.Printf("p2 shape %v", p2.Shape()) - var r2 *gorgonia.Node + var r2 *G.Node b, c, h, w := p2.Shape()[0], p2.Shape()[1], p2.Shape()[2], p2.Shape()[3] - if r2, err = gorgonia.Reshape(p2, tensor.Shape{b, c * h * w}); err != nil { + if r2, err = G.Reshape(p2, tensor.Shape{b, c * h * w}); err != nil { return errors.Wrap(err, "Unable to reshape layer 2") } log.Printf("r2 shape %v", r2.Shape()) @@ -153,7 +145,7 @@ func (m *convnet) fwd(x *gorgonia.Node) (err error) { log.Printf("l2 shape %v | %v", l2.Shape(), m.w3.Shape()) // Layer 3 - if fc, err = gorgonia.Mul(l2, m.w3); err != nil { + if fc, err = G.Mul(l2, m.w3); err != nil { return errors.Wrapf(err, "Unable to multiply l2 and w3") } if a3, err = nnops.Rectify(fc); err != nil { @@ -165,11 +157,11 @@ func (m *convnet) fwd(x *gorgonia.Node) (err error) { log.Printf("l3 name %v | a3 name %v", l3, a3) // output decode - var out *gorgonia.Node - if out, err = gorgonia.Mul(l3, m.w4); err != nil { + var out *G.Node + if out, err = G.Mul(l3, m.w4); err != nil { return errors.Wrapf(err, "Unable to multiply l3 and w4") } - m.out, err = gorgonia.SoftMax(out) + m.out, err = G.SoftMax(out) log.Printf("DONE") return } @@ -179,7 +171,7 @@ func main() { parseDtype() rand.Seed(1337) - log.Printf("gorgonia. %t", gorgonia.CUDA) + log.Printf("gorgonia. 
%t", G.CUDA) // intercept Ctrl+C sigChan := make(chan os.Signal, 1) @@ -214,24 +206,24 @@ func main() { if err := inputs.Reshape(numExamples, 1, 28, 28); err != nil { log.Fatal(err) } - g := gorgonia.NewGraph() - x := gorgonia.NewTensor(g, dt, 4, gorgonia.WithShape(bs, 1, 28, 28), gorgonia.WithName("x")) - y := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(bs, 10), gorgonia.WithName("y")) + g := G.NewGraph() + x := G.NewTensor(g, dt, 4, G.WithShape(bs, 1, 28, 28), G.WithName("x")) + y := G.NewMatrix(g, dt, G.WithShape(bs, 10), G.WithName("y")) m := newConvNet(g) if err = m.fwd(x); err != nil { log.Fatalf("%+v", err) } log.Printf("m.out.Shape %v, y.Shape %v", m.out.Shape(), y.Shape()) - losses := gorgonia.Must(gorgonia.Log(gorgonia.Must(gorgonia.HadamardProd(m.out, y)))) - cost := gorgonia.Must(gorgonia.Neg(losses)) - cost = gorgonia.Must(gorgonia.Mean(cost)) + losses := G.Must(G.Log(G.Must(G.HadamardProd(m.out, y)))) + cost := G.Must(G.Neg(losses)) + cost = G.Must(G.Mean(cost)) // we wanna track costs - var costVal, lossesVal gorgonia.Value - gorgonia.Read(losses, &lossesVal) - gorgonia.Read(cost, &costVal) + var costVal, lossesVal G.Value + G.Read(losses, &lossesVal) + G.Read(cost, &costVal) - if _, err = gorgonia.Grad(cost, m.learnables()...); err != nil { + if _, err = G.Grad(cost, m.learnables()...); err != nil { log.Fatalf("%+v", err) } @@ -245,8 +237,8 @@ func main() { // log.Printf("%v", prog) // vm := gorgonia.NewTapeMachine(g, gorgonia.WithPrecompiled(prog, locMap), gorgonia.BindDualValues(m.learnables()...)) - vm := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(m.learnables()...)) - solver := gorgonia.NewRMSPropSolver(gorgonia.WithBatchSize(float64(bs)), gorgonia.WithLearnRate(0.05)) + vm := G.NewTapeMachine(g, G.BindDualValues()) + solver := G.NewRMSPropSolver(G.WithBatchSize(float64(bs)), G.WithLearnRate(0.01)) defer vm.Close() // pprof @@ -270,9 +262,8 @@ func main() { bar.SetRefreshRate(time.Second) bar.SetMaxWidth(80) - var xxx int var avgcost float64 - + var costs []float64 for i := 0; i < *epochs; i++ { bar.Prefix(fmt.Sprintf("Epoch %d", i)) bar.Set(0) @@ -288,34 +279,41 @@ func main() { } var xVal, yVal tensor.Tensor - if xVal, err = inputs.Slice(sli{start, end}); err != nil { + if xVal, err = inputs.Slice(G.S(start, end)); err != nil { log.Fatal("Unable to slice x") } - if yVal, err = targets.Slice(sli{start, end}); err != nil { + if yVal, err = targets.Slice(G.S(start, end)); err != nil { log.Fatal("Unable to slice y") } if err = xVal.(*tensor.Dense).Reshape(bs, 1, 28, 28); err != nil { log.Fatal("Unable to reshape %v", err) } - gorgonia.Let(x, xVal) - gorgonia.Let(y, yVal) + G.Let(x, xVal) + G.Let(y, yVal) if err = vm.RunAll(); err != nil { log.Fatalf("Failed at epoch %d: %+v", i, err) } - solver.Step(gorgonia.NodesToValueGrads(m.learnables())) + solver.Step(G.NodesToValueGrads(m.learnables())) vm.Reset() bar.Increment() - if xxx < 5 { - log.Printf("Cost %v", costVal) - log.Printf("Y\n%#1.3f", y.Value()) - log.Printf("Losses\n%#1.3f", lossesVal) - xxx++ + switch dt { + case tensor.Float32: + c := float64(costVal.Data().(float32)) + avgcost += c + costs = append(costs, c) + + case tensor.Float64: + c := costVal.Data().(float64) + avgcost += c + costs = append(costs, c) + default: + panic("unsupported dtype") } - avgcost += costVal.Data().(float64) } log.Printf("Epoch %d | cost %v", i, avgcost/float64(batches)) + log.Printf("Costs %v", costs) avgcost = 0 } diff --git a/go.mod b/go.mod index 36eb6861..257efe7f 100644 --- a/go.mod +++ b/go.mod @@ -5,21 +5,20 @@ go 1.12 
diff --git a/go.mod b/go.mod
index 36eb6861..257efe7f 100644
--- a/go.mod
+++ b/go.mod
@@ -5,21 +5,20 @@ go 1.12
 require (
 	github.com/awalterschulze/gographviz v0.0.0-20190221210632-1e9ccb565bca
 	github.com/chewxy/hm v1.0.0
-	github.com/chewxy/math32 v1.0.4
+	github.com/chewxy/math32 v1.0.6
 	github.com/fatih/color v1.7.0 // indirect
 	github.com/go-gota/gota v0.10.1
 	github.com/leesper/go_rng v0.0.0-20171009123644-5344a9259b21
 	github.com/mattn/go-colorable v0.1.4 // indirect
 	github.com/pkg/errors v0.9.1
-	github.com/seehuhn/mt19937 v0.0.0-20191220121156-d07252b9f9df
-	github.com/stretchr/testify v1.4.0
+	github.com/stretchr/testify v1.6.0
 	github.com/xtgo/set v1.0.0
 	gonum.org/v1/gonum v0.7.0
 	gonum.org/v1/netlib v0.0.0-20200317120129-c5a04cffd98a
 	gopkg.in/cheggaaa/pb.v1 v1.0.27
-	gorgonia.org/cu v0.9.2
+	gorgonia.org/cu v0.9.3
 	gorgonia.org/dawson v1.2.0
-	gorgonia.org/tensor v0.9.6
+	gorgonia.org/tensor v0.9.7
 	gorgonia.org/vecf32 v0.9.0
 	gorgonia.org/vecf64 v0.9.0
 )
diff --git a/go.sum b/go.sum
index 9e6d51a3..2f8516e2 100644
--- a/go.sum
+++ b/go.sum
@@ -8,6 +8,8 @@ github.com/chewxy/math32 v1.0.0 h1:RTt2SACA7BTzvbsAKVQJLZpV6zY2MZw4bW9L2HEKkHg=
 github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
 github.com/chewxy/math32 v1.0.4 h1:dfqy3+BbCmet2zCkaDaIQv9fpMxnmYYlAEV2Iqe3DZo=
 github.com/chewxy/math32 v1.0.4/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
+github.com/chewxy/math32 v1.0.6 h1:JWZYUNl2rtgVVui6z8JBsDgkOG2DYmfSODyo95yKfx4=
+github.com/chewxy/math32 v1.0.6/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
 github.com/cloudflare/cfssl v0.0.0-20190808011637-b1ec8c586c2a/go.mod h1:yMWuSON2oQp+43nFtAV/uvKQIFpSPerB57DCt9t8sSA=
 github.com/cznic/cc v0.0.0-20181122101902-d673e9b70d4d/go.mod h1:m3fD/V+XTB35Kh9zw6dzjMY+We0Q7PMf6LLIC4vuG9k=
 github.com/cznic/golex v0.0.0-20181122101858-9c343928389c/go.mod h1:+bmmJDNmKlhWNG+gwWCkaBoTy39Fs+bzRxVBzoTQbIc=
@@ -16,8 +18,6 @@ github.com/cznic/strutil v0.0.0-20181122101858-275e90344537/go.mod h1:AHHPPPXTw0
 github.com/cznic/xc v0.0.0-20181122101856-45b06973881e/go.mod h1:3oFoiOvCDBYH+swwf5+k/woVmWy7h1Fcyu8Qig/jjX0=
 github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/delaneyj/cogent v0.0.0-20180619184653-2fcea326194c h1:UliKg7JACWAXDW7yFdms6lLwOLK7H3uId3NG5z4f378=
-github.com/delaneyj/cogent v0.0.0-20180619184653-2fcea326194c/go.mod h1:hL/k6TDIq37bqQ6sySYVYw+Idnv0JkVmKsmedD5AduQ=
 github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys=
 github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
 github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
@@ -50,21 +50,20 @@ github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE
 github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
 github.com/mattn/go-runewidth v0.0.4 h1:2BvfKmzob6Bmd4YsL0zygOqfdFnK7GR4QL06Do4/p7Y=
 github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
-github.com/mattn/gorgonia-cfd5423e2acc2f8c2b86 v0.0.0-20200313070349-288c2a647837 h1:/7GLXOx1Cd15DDfNpIZguExr6Ui5e2vKVbCf8x52ls0=
-github.com/mattn/gorgonia-cfd5423e2acc2f8c2b86 v0.0.0-20200313070349-288c2a647837/go.mod h1:MGXCds9oIEtiTo7SSDV2qlEYxIFO0LdSOf4BlNJYr34=
 github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
 github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
-github.com/seehuhn/mt19937 v0.0.0-20191220121156-d07252b9f9df h1:rhEzo7J+sDOLI5NulkwtescnyYMSt4J5mkxDMgQRjN4=
-github.com/seehuhn/mt19937 v0.0.0-20191220121156-d07252b9f9df/go.mod h1:w+IAy13Luqfsp+plFpT1RiqauADylJKmpkrWFwpjbsc=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
-github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
 github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
 github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.6.0 h1:jlIyCplCJFULU/01vCkhKuTyc3OorI3bJFuw6obfgho=
+github.com/stretchr/testify v1.6.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
 github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -111,18 +110,20 @@ gopkg.in/cheggaaa/pb.v1 v1.0.27 h1:kJdccidYzt3CaHD1crCFTS1hxyhSi059NhOFUf03YFo=
 gopkg.in/cheggaaa/pb.v1 v1.0.27/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
 gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
 gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gorgonia.org/cu v0.9.0-beta h1:s4WQ6fiAGoErwIiXWHRB6Y9ydkx1vTTPwhWzoEZVePc=
 gorgonia.org/cu v0.9.0-beta/go.mod h1:RPEPIfaxxqUmeRe7T1T8a0NER+KxBI2McoLEXhP1Vd8=
-gorgonia.org/cu v0.9.2 h1:TEKj3VmeSe3CJwxi+Sn6wJMB8lpzhpq+XMq+yU0+Uks=
-gorgonia.org/cu v0.9.2/go.mod h1:LgyAYDkN7HWhh8orGnCY2R8pP9PYbO44ivEbLMatkVU=
+gorgonia.org/cu v0.9.3 h1:IkxE4NWXuZHqr8AnmgoB8WNQPZeD6u0EJNxYjDC0YgY=
+gorgonia.org/cu v0.9.3/go.mod h1:LgyAYDkN7HWhh8orGnCY2R8pP9PYbO44ivEbLMatkVU=
 gorgonia.org/dawson v1.1.0 h1:o7+eJ3SKi9sheH19lpOat//tDbg0Y+M9iY/lH79VHqY=
 gorgonia.org/dawson v1.1.0/go.mod h1:Px1mcziba8YUBIDsbzGwbKJ11uIblv/zkln4jNrZ9Ws=
 gorgonia.org/dawson v1.2.0 h1:hJ/aofhfkReSnJdSMDzypRZ/oWDL1TmeYOauBnXKdFw=
 gorgonia.org/dawson v1.2.0/go.mod h1:Px1mcziba8YUBIDsbzGwbKJ11uIblv/zkln4jNrZ9Ws=
 gorgonia.org/gorgonia v0.9.2/go.mod h1:ZtOb9f/wM2OMta1ISGspQ4roGDgz9d9dKOaPNvGR+ec=
 gorgonia.org/tensor v0.9.0-beta/go.mod h1:05Y4laKuVlj4qFoZIZW1q/9n1jZkgDBOLmKXZdBLG1w=
-gorgonia.org/tensor v0.9.4 h1:5RRPp6tz3fRzIni1cMQyWT9QEQpfvu8cXibcEqU0GDU=
-gorgonia.org/tensor v0.9.4/go.mod h1:603c/8huGtNc1APqh1nWqQu0fYgBvkwt55rvg4CWgZs=
+gorgonia.org/tensor v0.9.7 h1:RncmNWe66zWDGMpDYFRXmReFkkMK7KOstELU/joamao=
+gorgonia.org/tensor v0.9.7/go.mod h1:yYvRwsd34UdhG98GhzsB4YUVt3cQAQ4amoD/nuyhX+c=
 gorgonia.org/vecf32 v0.7.0/go.mod h1:iHG+kvTMqGYA0SgahfO2k62WRnxmHsqAREGbayRDzy8=
 gorgonia.org/vecf32 v0.9.0 h1:PClazic1r+JVJ1dEzRXgeiVl4g1/Hf/w+wUSqnco1Xg=
 gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA=
diff --git a/node.go b/node.go
index 245a40d1..66ea349e 100644
--- a/node.go
+++ b/node.go
@@ -6,7 +6,6 @@ import (
 	"fmt"
 	"hash"
 	"hash/fnv"
-	"log"
 
 	"github.com/awalterschulze/gographviz"
 	"github.com/chewxy/hm"
@@ -299,6 +298,12 @@ func (n *Node) Err() error { return nil }
 func (n *Node) DataSize() int { return n.Shape().TotalSize() }
 
+// DerivOf returns the nodes that this node is the derivative of.
+func (n *Node) DerivOf() Nodes { return n.derivOf }
+
+// Deriv returns the node that is the derivative of this node, or nil if there isn't one.
+func (n *Node) Deriv() *Node { return n.deriv }
+
 // helper functions to help compilation process
 func (n *Node) isArg() bool { return n.op == nil }
 func (n *Node) isInput() bool { return (n.isArg() || n.isRandom()) && !n.isStmt }
@@ -391,7 +396,6 @@ func (n *Node) Clone() (retVal interface{}) {
 	if n.boundTo != nil {
 		var err error
 		if n2.boundTo, err = CloneValue(n.boundTo); err != nil {
-			log.Printf("Unable to clone %v\n%T\n%v", n, n.boundTo, n.boundTo)
 			panic(err)
 		}
 	}
@@ -475,7 +479,7 @@ func (n *Node) Strides() []int {
 		case tensor.Tensor:
 			return v.Strides()
 		default:
-			log.Printf("Unhandled type for Strides(): %T. Using fallback method and assuming dense tensor types", n.boundTo)
+			panic(fmt.Sprintf("Unhandled type for Strides(): %T", n.boundTo))
 		}
 	}
 	return n.shape.CalcStrides()
@@ -631,9 +635,6 @@ func (n *Node) String() string {
 
 // TODO: check type, check shape, check if needsGrad -> promote to dualValue
 func (n *Node) bind(v Value) error {
-	// pc, _, _, _ := runtime.Caller(1)
-	// log.Printf("binding to %p. Called by %v", n, runtime.FuncForPC(pc).Name())
-
 	if n.boundTo == nil {
 		n.boundTo = v
 		return nil
@@ -651,7 +652,6 @@ func (n *Node) bind(v Value) error {
 		}
 		// n.boundTo = vdv
 		// return nil
-		log.Printf("n %p", n)
 		panic("Undefined behaviour") // no seriously there literally is no defined behaviour of what should the right thing be. I'll come back to this TODO.
 	}
 	dv.Value = v
diff --git a/op_math.go b/op_math.go
index e256450c..18064e19 100644
--- a/op_math.go
+++ b/op_math.go
@@ -575,6 +575,8 @@ func (op linAlgBinOp) InferShape(inputs ...DimSizer) (retVal tensor.Shape, err e
 		// outerprods only handles vec x vec for now
 		retVal = tensor.Shape{x.TotalSize(), y.TotalSize()}
 	case batchedMatMulOperator:
+		x = x.Clone()
+		y = y.Clone()
 		innerX := x[len(x)-2:]
 		outerX := x[:len(x)-2]
 		innerY := y[len(y)-2:]
@@ -651,7 +653,7 @@ func (op linAlgBinOp) String() string {
 	var buf bytes.Buffer
 
 	switch op.āBinaryOperator {
-	case matMulOperator, matVecMulOperator:
+	case matMulOperator, matVecMulOperator, batchedMatMulOperator:
 		buf.WriteString("A")
 	case vecDotOperator, outerProdOperator:
 		buf.WriteString("a")
@@ -662,7 +664,7 @@ func (op linAlgBinOp) String() string {
 	}
 
 	switch op.āBinaryOperator {
-	case matMulOperator:
+	case matMulOperator, batchedMatMulOperator:
 		fmt.Fprintf(&buf, " %v B", op.āBinaryOperator)
 	case matVecMulOperator, vecDotOperator, outerProdOperator:
 		fmt.Fprintf(&buf, " %v b", op.āBinaryOperator)
diff --git a/op_math_cuda.go b/op_math_cuda.go
index c10efdb0..3639f195 100644
--- a/op_math_cuda.go
+++ b/op_math_cuda.go
@@ -233,7 +233,8 @@ func (op linAlgBinOp) CUDADo(extern External, dev Device, prealloc Value, inputs
 	case outerProdOperator:
 		return tensor.Outer(aT, bT, tensor.WithReuse(pT))
 	case batchedMatMulOperator:
-		return nil, errors.New("NYI")
+		// checks were done when the op was created
+		return batchedMatMul(aT, bT, nil, op.transA, op.transB, false)
 	}
 	panic("Unreachable")
 }
diff --git a/op_tensor.go b/op_tensor.go
index 96f617e2..a84032b1 100644
--- a/op_tensor.go
+++ b/op_tensor.go
@@ -598,7 +598,7 @@ func (op *sliceOp) Do(inputs ...Value) (retVal Value, err error) {
 		if v.IsScalar() {
 			retVal, _ = anyToScalar(v.ScalarValue())
 		} else {
-			retVal = v
+			retVal = v.(tensor.View).Materialize()
 		}
 	case Scalar:
 		return nil, errors.New("Cannot slice a scalar value")
@@ -1223,7 +1223,11 @@ func (op reshapeOp) CUDADo(extern External, dev Device, prealloc Value, vals ...
 		}
 		return v, nil
 	case Scalar:
-		return nil, errors.Errorf(nyiTypeFail, "reshape.Do", "Scalar")
+		vT := ScalarAsTensor(v, op.to.Dims(), nil)
+		if err := vT.(tensor.Tensor).Reshape(op.to...); err != nil {
+			return nil, errors.Wrapf(err, "reshape.CUDADo failed to reshape a Scalar to %v", op.to)
+		}
+		return vT, nil
 	}
 
 	panic("Unreachable")
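The node.go hunk above exports two accessors, DerivOf and Deriv, that were previously internal-only. A minimal usage sketch (mine, not from the patch): after Grad has been called and a tape machine has run, each learnable's gradient node is reachable directly from the node itself. The helper name dumpGrads is hypothetical; the nil check reflects that deriv is simply unset when no gradient exists.

package sketch

import (
	"fmt"

	G "gorgonia.org/gorgonia"
)

// dumpGrads prints the value bound to each learnable's derivative node
// after a run. It relies on the Deriv accessor exported by this patch.
func dumpGrads(learnables G.Nodes) {
	for _, w := range learnables {
		d := w.Deriv()
		if d == nil {
			fmt.Printf("%v: no gradient\n", w)
			continue
		}
		fmt.Printf("%v: grad %v\n", w, d.Value())
	}
}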
diff --git a/operations.go b/operations.go
index 8c06d877..7f7f4a8d 100644
--- a/operations.go
+++ b/operations.go
@@ -39,7 +39,6 @@ func binOpNode(op BinaryOp, a, b *Node) (retVal *Node, err error) {
 		leaveLogScope()
 	}
 	stabLogf("No bin op stabilization")
-
 	return ApplyOp(op, a, b)
 }
diff --git a/operations_test.go b/operations_test.go
index 84c7bc99..958e0b31 100644
--- a/operations_test.go
+++ b/operations_test.go
@@ -466,11 +466,11 @@ var sliceTests = []struct {
 	{"vec[0]", tensor.Shape{2}, []tensor.Slice{S(0)}, scalarShape, float64(0), false},
 	{"vec[0:2]", tensor.Shape{2}, []tensor.Slice{S(0, 2)}, tensor.Shape{2}, []float64{0, 1}, false},
 	{"Mat[0]", tensor.Shape{2, 3}, []tensor.Slice{S(0)}, tensor.Shape{3}, []float64{0, 1, 2}, false},
-	{"Mat[:, 0]", tensor.Shape{2, 3}, []tensor.Slice{nil, S(0)}, tensor.Shape{2}, []float64{0, 1, 2, 3}, false},
+	{"Mat[:, 0]", tensor.Shape{2, 3}, []tensor.Slice{nil, S(0)}, tensor.Shape{2}, []float64{0, 3}, false},
 	{"3Tensor[0]", tensor.Shape{2, 3, 4}, []tensor.Slice{S(0)}, tensor.Shape{3, 4}, tensor.Range(tensor.Float64, 0, 12), false},
 	{"3Tensor[0:2]", tensor.Shape{2, 3, 4}, []tensor.Slice{S(0, 2)}, tensor.Shape{2, 3, 4}, tensor.Range(tensor.Float64, 0, 24), false},
-	{"3Tensor[:, 0]", tensor.Shape{2, 3, 4}, []tensor.Slice{nil, S(0)}, tensor.Shape{2, 4}, tensor.Range(tensor.Float64, 0, 16), false},
-	{"3Tensor[0, :, 0]", tensor.Shape{2, 3, 4}, []tensor.Slice{S(0), nil, S(0)}, tensor.Shape{3}, tensor.Range(tensor.Float64, 0, 9), false},
+	{"3Tensor[:, 0]", tensor.Shape{2, 3, 4}, []tensor.Slice{nil, S(0)}, tensor.Shape{2, 4}, []float64{0, 1, 2, 3, 12, 13, 14, 15}, false},
+	{"3Tensor[0, :, 0]", tensor.Shape{2, 3, 4}, []tensor.Slice{S(0), nil, S(0)}, tensor.Shape{3}, []float64{0, 4, 8}, false},
 	{"vec[:, 0]", tensor.Shape{2}, []tensor.Slice{nil, S(0)}, nil, nil, true},
 }
diff --git a/operatorLinAlg_const.go b/operatorLinAlg_const.go
index 815694cf..f4e53200 100644
--- a/operatorLinAlg_const.go
+++ b/operatorLinAlg_const.go
@@ -5,11 +5,11 @@ import "github.com/chewxy/hm"
 // āBinOpStrs is the string representation for binLAOperator
 // It should be held constant
 var āBinOpStrs = [maxĀBinaryOperator]string{
-	"×",
-	"×",
-	"⋅",
-	"⊗",
-	"×××",
+	"×",   // matMulOperator
+	"×",   // matVecMulOperator
+	"⋅",   // vecDotOperator
+	"⊗",   // outerProdOperator
+	"×××", // batchedMatMulOperator
 }
 
 var āBinOpDiffExprs = [maxĀBinaryOperator]func(tA, tB bool, x, y, z, grad *Node) (Nodes, error){
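The changed expectations in sliceTests above are a direct consequence of the sliceOp.Do change in op_tensor.go: a slice along a non-leading axis is a strided view, and materializing it keeps only the selected elements rather than exposing the whole backing array. A standalone sketch of the difference, assuming the tensor v0.9 API in which (*Dense).Slice returns a tensor.View; the sl type is a hypothetical stand-in for the sli helper this patch deletes from the examples:

package main

import (
	"fmt"

	"gorgonia.org/tensor"
)

// sl is a minimal tensor.Slice implementation, analogous to the local
// sli type removed from examples/convnet in this patch.
type sl struct{ start, end int }

func (s sl) Start() int { return s.start }
func (s sl) End() int   { return s.end }
func (s sl) Step() int  { return 1 }

func main() {
	// A 2x3 matrix backed by [0 1 2 3 4 5].
	m := tensor.New(
		tensor.WithShape(2, 3),
		tensor.WithBacking(tensor.Range(tensor.Float64, 0, 6)),
	)

	// Mat[:, 0] is a strided view of shape (2) over elements 0 and 3.
	view, err := m.Slice(nil, sl{0, 1})
	if err != nil {
		panic(err)
	}

	// Materializing copies out just the selected elements: the []float64{0, 3}
	// now expected by the "Mat[:, 0]" test case.
	fmt.Println(view.Materialize().Data()) // [0 3]
}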
diff --git a/solvers.go b/solvers.go
index aecf7f3d..90b70c43 100644
--- a/solvers.go
+++ b/solvers.go
@@ -29,15 +29,15 @@ func newCachedDV(n ValueGrad, weights, grad Value, zero bool) (cached *dualValue
 	cached = new(dualValue)
 	if cached.Value, err = CloneValue(weights); err != nil {
 		if nm, ok := n.(Namer); ok {
-			return nil, errors.Errorf("Failed to clone weights of %v", nm.Name())
+			return nil, errors.Wrapf(err, "Failed to clone weights of %v", nm.Name())
 		}
-		return nil, errors.New("Failed to clone weights")
+		return nil, errors.Wrap(err, "Failed to clone weights")
 	}
 	if cached.d, err = CloneValue(grad); err != nil {
 		if nm, ok := n.(Namer); ok {
-			return nil, errors.Errorf("Failed to clone grad of %v", nm.Name())
+			return nil, errors.Wrapf(err, "Failed to clone grad of %v", nm.Name())
 		}
-		return nil, errors.New("Failed to clone grad")
+		return nil, errors.Wrap(err, "Failed to clone grad")
 	}
 	if zero {
 		cached.Value = ZeroValue(cached.Value)
diff --git a/utils.go b/utils.go
index 73570a08..633474c4 100644
--- a/utils.go
+++ b/utils.go
@@ -3,7 +3,6 @@ package gorgonia
 import (
 	"fmt"
 	"hash/fnv"
-	"log"
 	"math"
 
 	"github.com/chewxy/math32"
@@ -204,7 +203,6 @@ func hasNaN(v Value, dev Device) bool {
 		ok, _ := e.HasNaN(vt) // BUG: errors not checked
 		return ok
 	}
-	log.Printf("Value's engine %T", vt.Engine())
 
 	dt := vt.Dtype()
 	if dt != tensor.Float64 && dt != tensor.Float32 {
diff --git a/values_utils.go b/values_utils.go
index e344ffe8..f90039f6 100644
--- a/values_utils.go
+++ b/values_utils.go
@@ -53,6 +53,7 @@ func ValueEq(a, b Value) bool {
 	case tensor.Tensor:
 		if bt, ok := b.(tensor.Tensor); ok {
+			// log.Printf("at.info %#v, bt.info %#v", a.(*tensor.Dense).Info(), b.(*tensor.Dense).Info())
 			return at.Eq(bt)
 		}
 		return false
 	case ValueEqualer:
diff --git a/vm_tape.go b/vm_tape.go
index 4f5a62e1..cde7c1e9 100644
--- a/vm_tape.go
+++ b/vm_tape.go
@@ -232,7 +232,6 @@ func (m *tapeMachine) runall(errChan chan error, doneChan chan struct{}) {
 	for ; m.pc < len(m.p.instructions); m.pc++ {
 		instr := m.p.instructions[m.pc]
 		m.logf("PC %d", m.pc)
-		// log.Printf("PC %d", m.pc)
 		if err := instr.exec(m); err != nil {
 			err = errors.Wrapf(err, "PC %d. Failed to execute instruction %v", m.pc, instr)
 			errChan <- err
@@ -283,7 +282,6 @@ func (m *tapeMachine) runall(errChan chan error, doneChan chan struct{}) {
 			}
 		}
 	}
-
 	doneChan <- struct{}{}
 }
 
@@ -508,10 +506,16 @@ func (instr alloc) exec(m *tapeMachine) (err error) {
 		return errors.Wrapf(err, dtypeExtractionFail, instr.t)
 	}
 
+	reg := m.getValue(instr.writeTo)
+	if reg != nil && reg.Dtype() == dt && reg.Shape().Eq(instr.s) {
+		return nil
+	}
+
 	dev := instr.writeTo.device
 	var v Value
 	switch dev {
 	case CPU:
+
 		v, err = makeValue(instr.t, instr.s)
 	default:
@@ -688,6 +692,12 @@ func (instr *readInstr) exec(m *tapeMachine) (err error) {
 		return nyi("value of nil", "readInstr.exec")
 	}
 
+	if *instr.into != nil {
+		dest := *instr.into
+		_, err = Copy(dest, v)
+		return err
+	}
+
 	v2, err := CloneValue(v)
 	if err != nil {
 		return errors.Wrap(err, cloneFail)
diff --git a/vm_tape_nocuda.go b/vm_tape_nocuda.go
index b8f9b088..2521e02f 100644
--- a/vm_tape_nocuda.go
+++ b/vm_tape_nocuda.go
@@ -32,6 +32,13 @@ func (instr *execOp) exec(m *tapeMachine) (err error) {
 	}
 	m.leaveLogScope()
 
+	// check if the destination has already been allocated
+	var usePrealloc bool
+	dest := instr.writeTo.id
+	if m.cpumem[dest] != nil {
+		usePrealloc = true
+	}
+
 	// Execute
 	var v Value
 	switch {
@@ -47,6 +54,19 @@ func (instr *execOp) exec(m *tapeMachine) (err error) {
 			return errors.Wrap(err, opDoFail)
 		}
 	}
+	case usePrealloc:
+		if pd, ok := instr.op.(UsePreallocDoer); ok {
+			p := m.cpumem[instr.writeTo.id]
+			if v, err = pd.UsePreallocDo(p, inputs...); err != nil {
+				if v, err = instr.op.Do(inputs...); err != nil {
+					return errors.Wrap(err, opDoFail)
+				}
+			}
+		} else {
+			if v, err = instr.op.Do(inputs...); err != nil {
+				return errors.Wrap(err, opDoFail)
+			}
+		}
 	case instr.useUnsafe:
 		if ud, ok := instr.op.(UnsafeDoer); ok {
 			if v, err = ud.UnsafeDo(inputs...); err != nil {
@@ -72,7 +92,7 @@ func (instr *execOp) exec(m *tapeMachine) (err error) {
 
 	// Write
 	setEngine(v, m.Engine)
-	dest := instr.writeTo.id
+
 	m.cpumem[dest] = v
 
 	node := m.p.g.Node(instr.id).(*Node)
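The vm_tape_nocuda.go change above reduces to a try-then-fall-back execution order: if the destination register already holds memory and the op implements UsePreallocDoer, write into the existing buffer; on any failure (or for ops without that fast path) run the plain Do path. A condensed sketch using gorgonia's exported Op, Value, and UsePreallocDoer interfaces; the helper name execWithPrealloc is made up for illustration:

package sketch

import (
	G "gorgonia.org/gorgonia"
)

// execWithPrealloc mirrors the tape machine's new execution order: prefer
// writing into already-allocated memory, fall back to op.Do otherwise.
func execWithPrealloc(op G.Op, prealloc G.Value, inputs ...G.Value) (G.Value, error) {
	if pd, ok := op.(G.UsePreallocDoer); ok && prealloc != nil {
		if v, err := pd.UsePreallocDo(prealloc, inputs...); err == nil {
			return v, nil
		}
		// fall through: a failed prealloc attempt is not fatal, exactly
		// as in the patched execOp.exec.
	}
	return op.Do(inputs...)
}

Combined with the new early return in alloc.exec (skip reallocation when the register already holds a value of the right dtype and shape), this lets the tape machine reuse buffers across RunAll calls instead of allocating fresh memory on every step.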