Skip to content

Commit

Permalink
plotter: Ensure Sankey can be drawn more than once
Browse files Browse the repository at this point in the history
  • Loading branch information
ctessum committed Aug 7, 2018
1 parent 7d21338 commit febd634
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 13 deletions.
25 changes: 12 additions & 13 deletions plotter/sankey.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,6 @@ type stock struct {

// max is min plus the larger of receptorValue and sourceValue.
max float64

// sourceFlowPlaceholder and receptorFlowPlaceholder track
// the current plotting location during
// the plotting process.
sourceFlowPlaceholder, receptorFlowPlaceholder float64
}

// A Flow represents the amount of an entity flowing between two stocks.
Expand Down Expand Up @@ -209,18 +204,24 @@ func NewSankey(flows ...Flow) (*Sankey, error) {
func (s *Sankey) Plot(c draw.Canvas, plt *plot.Plot) {
trCat, trVal := plt.Transforms(&c)

// sourceFlowPlaceholder and receptorFlowPlaceholder track
// the current plotting location during
// the plotting process.
sourceFlowPlaceholder := make(map[*stock]float64, len(s.flows))
receptorFlowPlaceholder := make(map[*stock]float64, len(s.flows))

// Here we draw the flows.
for _, f := range s.flows {
startStock := s.stocks[f.SourceCategory][f.SourceLabel]
endStock := s.stocks[f.ReceptorCategory][f.ReceptorLabel]
catStart := trCat(float64(f.SourceCategory)) + s.StockBarWidth/2
catEnd := trCat(float64(f.ReceptorCategory)) - s.StockBarWidth/2
valStartLow := trVal(startStock.min + startStock.sourceFlowPlaceholder)
valEndLow := trVal(endStock.min + endStock.receptorFlowPlaceholder)
valStartHigh := trVal(startStock.min + startStock.sourceFlowPlaceholder + f.Value)
valEndHigh := trVal(endStock.min + endStock.receptorFlowPlaceholder + f.Value)
startStock.sourceFlowPlaceholder += f.Value
endStock.receptorFlowPlaceholder += f.Value
valStartLow := trVal(startStock.min + sourceFlowPlaceholder[startStock])
valEndLow := trVal(endStock.min + receptorFlowPlaceholder[endStock])
valStartHigh := trVal(startStock.min + sourceFlowPlaceholder[startStock] + f.Value)
valEndHigh := trVal(endStock.min + receptorFlowPlaceholder[endStock] + f.Value)
sourceFlowPlaceholder[startStock] += f.Value
receptorFlowPlaceholder[endStock] += f.Value

ptsLow := s.bezier(
vg.Point{X: catStart, Y: valStartLow},
Expand Down Expand Up @@ -327,8 +328,6 @@ func (s *Sankey) setStockRange(stocks *[]*stock) {
var cat int
var min float64
for _, stk := range *stocks {
stk.sourceFlowPlaceholder = 0
stk.receptorFlowPlaceholder = 0
if stk.category != cat {
min = 0
}
Expand Down
52 changes: 52 additions & 0 deletions plotter/sankey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"gonum.org/v1/plot/internal/cmpimg"
"gonum.org/v1/plot/vg"
"gonum.org/v1/plot/vg/draw"
"gonum.org/v1/plot/vg/recorder"
"gonum.org/v1/plot/vg/vgimg"
)

Expand Down Expand Up @@ -418,3 +419,54 @@ func ExampleSankey_grouped() {
func TestSankey_grouped(t *testing.T) {
cmpimg.CheckPlot(ExampleSankey_grouped, t, "sankeyGrouped.png")
}

// This test checks whether the Sankey plotter makes any changes to
// the input Flows.
func TestSankey_idempotent(t *testing.T) {
flows := []Flow{
Flow{
SourceCategory: 0,
SourceLabel: "Large",
ReceptorCategory: 1,
ReceptorLabel: "Mohamed",
Value: 5,
},
Flow{
SourceCategory: 0,
SourceLabel: "Small",
ReceptorCategory: 1,
ReceptorLabel: "Sofia",
Value: 5,
},
}
s, err := NewSankey(flows...)
if err != nil {
t.Fatal(err)
}
p, err := plot.New()
if err != nil {
t.Fatal(err)
}
p.Add(s)
p.HideAxes()

// Draw the plot once.
c1 := new(recorder.Canvas)
dc1 := draw.NewCanvas(c1, vg.Centimeter, vg.Centimeter)
p.Draw(dc1)

// Draw the plot a second time.
c2 := new(recorder.Canvas)
dc2 := draw.NewCanvas(c2, vg.Centimeter, vg.Centimeter)
p.Draw(dc2)

if len(c1.Actions) != len(c2.Actions) {
t.Errorf("inconsistent number of actions: %d != %d", len(c2.Actions), len(c1.Actions))
}

for i, a1 := range c1.Actions {
if a1.Call() != c2.Actions[i].Call() {
t.Errorf("action %d: %s\n\t!= %s", i, c2.Actions[i].Call(), a1.Call())
}
}
}

0 comments on commit febd634

Please sign in to comment.