/
utils.go
170 lines (152 loc) · 5.03 KB
/
utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
//
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package main
import (
"fmt"
"io/ioutil"
"sort"
"strconv"
"strings"
"github.com/apache/beam/sdks/go/pkg/beam/io/textio"
"github.com/google/differential-privacy/privacy-on-beam/codelab"
"github.com/apache/beam/sdks/go/pkg/beam"
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/plotutil"
"gonum.org/v1/plot/vg"
)
// Example names accepted by drawPlot. Each one selects which kind of
// per-hour aggregate is being plotted.
const (
	count = "count" // visits per hour
	mean  = "mean"  // mean time spent per hour
	sum   = "sum"   // revenue per hour
	// publicPartitions also computes visits per hour.
	publicPartitions = "public_partitions"
)
func drawPlot(hourToValue, dpHourToValue map[int]float64, example, nonDPOutput, dpOutput string) error {
// Sort dp and non-dp points.
keys := make([]int, 0)
for k := range hourToValue {
keys = append(keys, k)
}
sort.Ints(keys)
points := make([]float64, 0)
for _, k := range keys {
points = append(points, hourToValue[k])
}
dpKeys := make([]int, 0)
for k := range dpHourToValue {
dpKeys = append(dpKeys, k)
}
sort.Ints(dpKeys)
dpPoints := make([]float64, 0)
for _, k := range dpKeys {
dpPoints = append(dpPoints, dpHourToValue[k])
}
p, err := plot.New()
if err != nil {
return fmt.Errorf("could not create plot: %v", err)
}
p.X.Label.Text = "Hour"
switch example {
case count, publicPartitions: // count & publicPartitions both compute visits per hour.
p.Y.Label.Text = "Visits"
p.Title.Text = "Visits Per Hour"
case mean:
p.Y.Label.Text = "Time Spent"
p.Title.Text = "Mean Time Spent"
case sum:
p.Y.Label.Text = "Revenue"
p.Title.Text = "Revenue Per Hour"
default:
return fmt.Errorf("unknown example %q specified, please use one of 'count', 'sum', 'mean', 'public_partitions'", example)
}
w := vg.Points(20)
// Non-DP Plot
bars, err := plotter.NewBarChart(plotter.Values(points), w)
if err != nil {
return fmt.Errorf("could not create bars from points %v: %v", plotter.Values(points), err)
}
bars.LineStyle.Width = vg.Length(0)
bars.Color = plotutil.Color(2)
p.Add(bars)
p.NominalX("0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "23")
// Save non-dp plot.
if err := p.Save(10*vg.Inch, 5*vg.Inch, nonDPOutput); err != nil {
return fmt.Errorf("Could not save plot: %v", err)
}
// DP Plot
dpBars, err := plotter.NewBarChart(plotter.Values(dpPoints), w)
if err != nil {
return fmt.Errorf("could not create bars from points %v: %v", plotter.Values(dpPoints), err)
}
dpBars.LineStyle.Width = vg.Length(0)
dpBars.Color = plotutil.Color(3)
dpBars.Offset = w
p.Add(dpBars)
p.Legend.Add("Raw", bars)
p.Legend.Add("Private", dpBars)
p.Legend.Top = true
// Save dp plot.
if err := p.Save(15*vg.Inch, 5*vg.Inch, dpOutput); err != nil {
return fmt.Errorf("Could not save plot: %v", err)
}
return nil
}
// readInput builds the input stage of the pipeline.
//
// It reads the restaurant-visit CSV at input line by line with textio.Read.
// Each line is "visitor_id, visit time, minutes spent, money spent", and
// codelab.CreateVisitsFn turns each one into a Visit struct. All transforms
// are added under a "readInput" sub-scope of s.
func readInput(s beam.Scope, input string) beam.PCollection {
	scope := s.Scope("readInput")
	rawLines := textio.Read(scope, input)
	visits := beam.ParDo(scope, codelab.CreateVisitsFn, rawLines)
	return visits
}
// writeOutput adds the output stage of the pipeline under a "writeOutput"
// sub-scope of s.
//
// It maps every element of output through convertToPairFn and merges the
// resulting pairs with normalizeOutputCombineFn. The combined result is then
// written as text to outputTextName.
func writeOutput(s beam.Scope, output beam.PCollection, outputTextName string) {
	scope := s.Scope("writeOutput")
	pairs := beam.ParDo(scope, convertToPairFn, output)
	textio.Write(scope, outputTextName, beam.Combine(scope, &normalizeOutputCombineFn{}, pairs))
}
// readOutput reads a text file in which each non-blank line holds an hour
// (int) and a value (float64) separated by whitespace. It returns a map from
// hour to value. If an hour appears on more than one line, the last value wins.
//
// Lines may be separated by "\n" or "\r\n". Blank or whitespace-only lines are
// skipped.
//
// It returns an error if the file cannot be read (the underlying error is
// wrapped, so errors.Is works on it). It also returns an error if any
// non-blank line does not consist of exactly an integer hour followed by a
// float value.
func readOutput(output string) (map[int]float64, error) {
	contents, err := ioutil.ReadFile(output)
	if err != nil {
		return nil, fmt.Errorf("could not read output file %s: %w", output, err)
	}
	hourToValue := make(map[int]float64)
	for _, line := range strings.Split(string(contents), "\n") {
		// strings.Fields splits on any run of whitespace. It also discards the
		// "\r" of CRLF line endings, which would otherwise stay attached to the
		// value and make ParseFloat fail.
		elements := strings.Fields(line)
		if len(elements) == 0 {
			continue // blank line, e.g. after the file's trailing newline
		}
		if len(elements) != 2 {
			return nil, fmt.Errorf("got %d number of elements in line %q, expected 2", len(elements), line)
		}
		hour, err := strconv.Atoi(elements[0])
		if err != nil {
			return nil, fmt.Errorf("could not convert hour %s to int: %w", elements[0], err)
		}
		value, err := strconv.ParseFloat(elements[1], 64)
		if err != nil {
			return nil, fmt.Errorf("could not convert value %s to float64: %w", elements[1], err)
		}
		hourToValue[hour] = value
	}
	return hourToValue, nil
}