diff --git a/x/dataviz/traces.go b/x/dataviz/traces.go new file mode 100644 index 00000000..a72e695b --- /dev/null +++ b/x/dataviz/traces.go @@ -0,0 +1,65 @@ +package dataviz + +import ( + "encoding/json" + "io" + "strconv" + "time" + + "gorgonia.org/gorgonia" + xvm "gorgonia.org/gorgonia/x/vm" +) + +// DumpTrace suitable for https://github.com/vasturiano/timelines-chart +func DumpTrace(traces []xvm.Trace, g *gorgonia.ExprGraph, w io.Writer) error { + var zerotime time.Time + groups := make(map[string]group) + // generate all labels + for _, trace := range traces { + if trace.End == zerotime { + continue + } + if _, ok := groups[trace.StateFunction]; !ok { + groups[trace.StateFunction] = group{ + Group: trace.StateFunction, + } + } + label := dataLabel{ + TimeRange: []time.Time{ + trace.Start, + trace.End, + }, + Val: strconv.Itoa(int(trace.ID)), + } + dGroup := dataGroup{ + Label: g.Node(trace.ID).(*gorgonia.Node).Name(), + Data: []dataLabel{ + label, + }, + } + g := groups[trace.StateFunction] + g.Data = append(g.Data, dGroup) + groups[trace.StateFunction] = g + } + grps := make([]group, 0, len(groups)) + for _, grp := range groups { + grps = append(grps, grp) + } + enc := json.NewEncoder(w) + return enc.Encode(grps) +} + +type group struct { + Group string `json:"group"` + Data []dataGroup `json:"data"` +} + +type dataGroup struct { + Label string `json:"label"` + Data []dataLabel `json:"data"` +} + +type dataLabel struct { + TimeRange []time.Time `json:"timeRange"` + Val interface{} `json:"val"` +} diff --git a/x/dataviz/traces_test.go b/x/dataviz/traces_test.go new file mode 100644 index 00000000..617c75d6 --- /dev/null +++ b/x/dataviz/traces_test.go @@ -0,0 +1,30 @@ +package dataviz + +import ( + "context" + "log" + "os" + + "gorgonia.org/gorgonia" + xvm "gorgonia.org/gorgonia/x/vm" +) + +func ExampleDumpTrace() { + g := gorgonia.NewGraph() + // Add elements + ctx, traceC := xvm.WithTracing(context.Background()) + defer xvm.CloseTracing(ctx) + traces := make([]xvm.Trace, 0) + go func() { + for v := range traceC { + traces = append(traces, v) + } + }() + machine := xvm.NewMachine(g) + err := machine.Run(ctx) + if err != nil { + log.Fatal(err) + } + machine.Close() + DumpTrace(traces, g, os.Stdout) +} diff --git a/x/vm/machine.go b/x/vm/machine.go index c73a6e9c..4fffe03e 100644 --- a/x/vm/machine.go +++ b/x/vm/machine.go @@ -2,6 +2,9 @@ package xvm import ( "context" + "strconv" + "strings" + "time" "gorgonia.org/gorgonia" ) @@ -121,33 +124,58 @@ func (m *Machine) Close() { } } +type nodeError struct { + id int64 + t time.Time + err error +} + +type nodeErrors []nodeError + +func (e nodeErrors) Error() string { + var sb strings.Builder + for _, e := range e { + sb.WriteString(strconv.Itoa(int(e.id))) + sb.WriteString(":") + sb.WriteString(e.err.Error()) + sb.WriteString("\n") + } + return sb.String() +} + // Run performs the computation func (m *Machine) runAllNodes(ctx context.Context) error { ctx, cancel := context.WithCancel(ctx) - errC := make(chan error, 0) + errC := make(chan nodeError, 0) total := len(m.nodes) for i := range m.nodes { go func(n *node) { - errC <- n.Compute(ctx) + err := n.Compute(ctx) + errC <- nodeError{ + id: n.id, + t: time.Now(), + err: err, + } }(m.nodes[i]) } - var err error - for err = range errC { + errs := make([]nodeError, 0) + for e := range errC { total-- - if err != nil || total == 0 { - break + if e.err != nil { + errs = append(errs, e) + // failfast, on error, cancel + cancel() } - } - for moreChannel := true; moreChannel; { - select { - case <-errC: - default: - moreChannel = false + if total == 0 { + break } } cancel() close(errC) - return err + if len(errs) != 0 { + return nodeErrors(errs) + } + return nil } // GetResult stored in a node diff --git a/x/vm/machine_test.go b/x/vm/machine_test.go index e20f86e5..4cc8068e 100644 --- a/x/vm/machine_test.go +++ b/x/vm/machine_test.go @@ -2,6 +2,7 @@ package xvm import ( "context" + "errors" "fmt" "log" "reflect" @@ -29,6 +30,12 @@ func TestMachine_runAllNodes(t *testing.T) { outputC: outputC2, inputC: inputC2, } + errNode1 := &node{ + op: &errorOP{}, + inputValues: make([]gorgonia.Value, 2), + outputC: outputC2, + inputC: inputC2, + } type fields struct { nodes []*node pubsubs *pubsub @@ -52,7 +59,16 @@ func TestMachine_runAllNodes(t *testing.T) { }, false, }, - // TODO: Add test cases. + { + "error", + fields{ + nodes: []*node{n1, errNode1}, + }, + args{ + context.Background(), + }, + true, + }, } for _, tt := range tests { forty := gorgonia.F32(40.0) @@ -96,6 +112,9 @@ func TestMachine_runAllNodes(t *testing.T) { if err := m.runAllNodes(tt.args.ctx); (err != nil) != tt.wantErr { t.Errorf("Machine.runAllNodes() error = %v, wantErr %v", err, tt.wantErr) } + if tt.wantErr { + return + } out1 := <-outputC1 out2 := <-outputC2 if !reflect.DeepEqual(out1.Data(), fortyTwo.Data()) { @@ -533,3 +552,29 @@ func TestMachine_GetResult(t *testing.T) { }) } } + +func Test_nodeErrors_Error(t *testing.T) { + tests := []struct { + name string + e nodeErrors + want string + }{ + { + "simple", + []nodeError{ + { + id: 0, + err: errors.New("error"), + }, + }, + "0:error\n", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.e.Error(); got != tt.want { + t.Errorf("nodeErrors.Error() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/x/vm/node.go b/x/vm/node.go index d0590963..7d6f577b 100644 --- a/x/vm/node.go +++ b/x/vm/node.go @@ -95,7 +95,9 @@ func computeBackward(_ context.Context, _ *node) stateFn { func (n *node) Compute(ctx context.Context) error { for state := defaultState; state != nil; { + t := trace(ctx, nil, n, state) state = state(ctx, n) + trace(ctx, t, nil, nil) } return n.err } diff --git a/x/vm/node_test.go b/x/vm/node_test.go index 32e5fe49..a70218c4 100644 --- a/x/vm/node_test.go +++ b/x/vm/node_test.go @@ -217,6 +217,12 @@ func Test_node_ComputeForward(t *testing.T) { } } +type errorOP struct{} + +func (*errorOP) Do(v ...gorgonia.Value) (gorgonia.Value, error) { + return nil, errors.New("error") +} + type sumF32 struct{} func (*sumF32) Do(v ...gorgonia.Value) (gorgonia.Value, error) { diff --git a/x/vm/tracer.go b/x/vm/tracer.go new file mode 100644 index 00000000..22513646 --- /dev/null +++ b/x/vm/tracer.go @@ -0,0 +1,71 @@ +package xvm + +import ( + "context" + "reflect" + "runtime" + "time" +) + +// Trace the nodes states +type Trace struct { + //fmt.Println(runtime.FuncForPC(reflect.ValueOf(state).Pointer()).Name()) + StateFunction string + ID int64 + Start time.Time + End time.Time `json:",omitempty"` +} + +type chanTracerContextKey int + +const ( + globalTracerContextKey chanTracerContextKey = 0 +) + +// WithTracing initializes a tracing channel and adds it to the context +func WithTracing(parent context.Context) (context.Context, <-chan Trace) { + c := make(chan Trace, 0) + return context.WithValue(parent, globalTracerContextKey, c), c +} + +// CloseTracing the tracing channel to avoid context leak. +// it is a nil op if context does not carry tracing information +func CloseTracing(ctx context.Context) { + c := extractTracingChannel(ctx) + if c != nil { + close(c) + } +} + +func extractTracingChannel(ctx context.Context) chan<- Trace { + if ctx == nil { + return nil + } + if c := ctx.Value(globalTracerContextKey); c != nil { + return c.(chan Trace) + } + return nil +} + +var now = time.Now + +func trace(ctx context.Context, t *Trace, n *node, state stateFn) *Trace { + traceC := extractTracingChannel(ctx) + if traceC == nil { + return t + } + if t == nil { + t = &Trace{ + ID: n.id, + StateFunction: runtime.FuncForPC(reflect.ValueOf(state).Pointer()).Name(), + Start: now(), + } + } else { + t.End = now() + } + select { + case traceC <- *t: + case <-ctx.Done(): + } + return t +} diff --git a/x/vm/tracer_test.go b/x/vm/tracer_test.go new file mode 100644 index 00000000..f1459d51 --- /dev/null +++ b/x/vm/tracer_test.go @@ -0,0 +1,172 @@ +package xvm + +import ( + "context" + "fmt" + "reflect" + "testing" + "time" + + "gorgonia.org/gorgonia" +) + +func ExampleWithTracing() { + ctx, tracingC := WithTracing(context.Background()) + defer CloseTracing(ctx) + go func() { + for t := range tracingC { + fmt.Println(t) + } + }() + g := gorgonia.NewGraph() + // add operations etc... + machine := NewMachine(g) + defer machine.Close() + machine.Run(ctx) +} + +func TestWithTracing(t *testing.T) { + ctx, c := WithTracing(context.Background()) + cn := ctx.Value(globalTracerContextKey) + if cn == nil { + t.Fail() + } + if cn.(chan Trace) != c { + t.Fail() + } +} + +func Test_extractTracingChannel(t *testing.T) { + ctx, _ := WithTracing(context.Background()) + c := ctx.Value(globalTracerContextKey).(chan Trace) + type args struct { + ctx context.Context + } + tests := []struct { + name string + args args + want chan<- Trace + }{ + { + "nil", + args{ + context.Background(), + }, + nil, + }, + { + "ok", + args{ + ctx, + }, + c, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := extractTracingChannel(tt.args.ctx); !reflect.DeepEqual(got, tt.want) { + t.Errorf("extractTracingChannel() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCloseTracing(t *testing.T) { + ctx, _ := WithTracing(context.Background()) + type args struct { + ctx context.Context + } + tests := []struct { + name string + args args + }{ + { + "no trace", + args{ + context.Background(), + }, + }, + { + "trace", + args{ + ctx, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + CloseTracing(tt.args.ctx) + }) + } +} + +func Test_trace(t *testing.T) { + now = func() time.Time { return time.Date(1977, time.September, 10, 10, 25, 00, 00, time.UTC) } + ctx, c := WithTracing(context.Background()) + defer CloseTracing(ctx) + go func() { + for range c { + } + }() + type args struct { + ctx context.Context + t *Trace + n *node + state stateFn + } + tests := []struct { + name string + args args + want *Trace + }{ + { + "no tracing context", + args{ + context.Background(), + nil, + nil, + nil, + }, + nil, + }, + { + "Context with nil trace", + args{ + ctx, + nil, + &node{ + id: 0, + }, + nil, + }, + &Trace{ + ID: 0, + Start: now(), + }, + }, + { + "Context existing trace", + args{ + ctx, + &Trace{ + ID: 1, + Start: now(), + }, + nil, + nil, + }, + &Trace{ + ID: 1, + Start: now(), + End: now(), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := trace(tt.args.ctx, tt.args.t, tt.args.n, tt.args.state); !reflect.DeepEqual(got, tt.want) { + t.Errorf("trace() = %v, want %v", got, tt.want) + } + }) + } +}