-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
Program.cs
328 lines (253 loc) · 14.4 KB
/
Program.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using Common;
using Microsoft.ML;
using Microsoft.ML.Data;
using PLplot;
using Regression_TaxiFarePrediction.DataStructures;
using static Microsoft.ML.Transforms.NormalizingEstimator;
namespace Regression_TaxiFarePrediction
{
internal static class Program
{
// Directory portion of the first command-line argument (Environment.GetCommandLineArgs()[0]).
private static string AppPath => Path.GetDirectoryName(Environment.GetCommandLineArgs()[0]);

// Dataset and model locations. Each "...RelativePath" value is relative to the folder of the
// running assembly and is resolved to an absolute path through GetAbsolutePath.
// Static field initializers run in textual order, so every base path must stay above the
// values that are derived from it.
// The fields are readonly because they are fixed configuration and are never reassigned.
private static readonly string BaseDatasetsRelativePath = @"../../../../Data";
private static readonly string TrainDataRelativePath = $"{BaseDatasetsRelativePath}/taxi-fare-train.csv";
private static readonly string TestDataRelativePath = $"{BaseDatasetsRelativePath}/taxi-fare-test.csv";
private static readonly string TrainDataPath = GetAbsolutePath(TrainDataRelativePath);
private static readonly string TestDataPath = GetAbsolutePath(TestDataRelativePath);
private static readonly string BaseModelsRelativePath = @"../../../../MLModels";
private static readonly string ModelRelativePath = $"{BaseModelsRelativePath}/TaxiFareModel.zip";
private static readonly string ModelPath = GetAbsolutePath(ModelRelativePath);
/// <summary>
/// Entry point: trains and saves the model, scores one sample trip and then draws the
/// regression distribution chart.
/// </summary>
/// <param name="args">
/// Command-line arguments, forwarded to the chart step. If args[0] == "svg" a vector-based
/// chart is created instead of a .png chart.
/// </param>
static void Main(string[] args)
{
    // A seeded ML context gives repeatable/deterministic results.
    var context = new MLContext(seed: 0);

    // Create, train, evaluate and save the model; the returned transformer is not needed here.
    _ = BuildTrainEvaluateAndSaveModel(context);

    // Make a single test prediction, loading the model from the .ZIP file.
    TestSinglePrediction(context);

    // Paint the regression distribution chart for a number of elements read from the test dataset file.
    const int recordsToPlot = 100;
    PlotRegressionChart(context, TestDataPath, recordsToPlot, args);

    Console.WriteLine("Press any key to exit..");
    Console.ReadLine();
}
/// <summary>
/// Loads the training and test CSV files, builds the data-preparation pipeline followed by an
/// SDCA regression trainer, fits it on the filtered training data, prints regression metrics
/// for the test data and saves the trained model to <c>ModelPath</c>.
/// </summary>
/// <param name="mlContext">ML.NET context used for loading, transforming, training, evaluating and saving.</param>
/// <returns>The trained model produced by fitting the full training pipeline.</returns>
private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
{
    // STEP 1: Common data loading configuration
    IDataView baseTrainingDataView = mlContext.Data.LoadFromTextFile<TaxiTrip>(TrainDataPath, hasHeader: true, separatorChar: ',');
    IDataView testDataView = mlContext.Data.LoadFromTextFile<TaxiTrip>(TestDataPath, hasHeader: true, separatorChar: ',');

    // Sample code of removing extreme data like "outliers": FareAmounts higher than $150 and
    // lower than $1, which can be error-data.
    // No row counts are taken before/after filtering: counting a column enumerates the whole
    // data view, and the counts were never used.
    IDataView trainingDataView = mlContext.Data.FilterRowsByColumn(baseTrainingDataView, nameof(TaxiTrip.FareAmount), lowerBound: 1, upperBound: 150);

    // STEP 2: Common data process configuration with pipeline data transformations
    var dataProcessPipeline = mlContext.Transforms.CopyColumns(outputColumnName: "Label", inputColumnName: nameof(TaxiTrip.FareAmount))
        .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "VendorIdEncoded", inputColumnName: nameof(TaxiTrip.VendorId)))
        .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "RateCodeEncoded", inputColumnName: nameof(TaxiTrip.RateCode)))
        .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "PaymentTypeEncoded", inputColumnName: nameof(TaxiTrip.PaymentType)))
        .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TaxiTrip.PassengerCount)))
        .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TaxiTrip.TripTime)))
        .Append(mlContext.Transforms.NormalizeMeanVariance(outputColumnName: nameof(TaxiTrip.TripDistance)))
        .Append(mlContext.Transforms.Concatenate("Features", "VendorIdEncoded", "RateCodeEncoded", "PaymentTypeEncoded", nameof(TaxiTrip.PassengerCount),
            nameof(TaxiTrip.TripTime), nameof(TaxiTrip.TripDistance)));

    // (OPTIONAL) Peek data (such as 5 records) in the training DataView after applying the
    // process pipeline's transformations into "Features".
    ConsoleHelper.PeekDataViewInConsole(mlContext, trainingDataView, dataProcessPipeline, 5);
    ConsoleHelper.PeekVectorColumnDataInConsole(mlContext, "Features", trainingDataView, dataProcessPipeline, 5);

    // STEP 3: Set the training algorithm (SDCA regression) and append it to the data pipeline.
    var trainer = mlContext.Regression.Trainers.Sdca(labelColumnName: "Label", featureColumnName: "Features");
    var trainingPipeline = dataProcessPipeline.Append(trainer);

    // STEP 4: Train the model by fitting the pipeline to the filtered training dataset.
    Console.WriteLine("=============== Training the model ===============");
    var trainedModel = trainingPipeline.Fit(trainingDataView);

    // STEP 5: Evaluate the model on the test data and show accuracy stats.
    Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
    IDataView predictions = trainedModel.Transform(testDataView);
    var metrics = mlContext.Regression.Evaluate(predictions, labelColumnName: "Label", scoreColumnName: "Score");
    Common.ConsoleHelper.PrintRegressionMetrics(trainer.ToString(), metrics);

    // STEP 6: Save/persist the trained model to a .ZIP file.
    mlContext.Model.Save(trainedModel, trainingDataView.Schema, ModelPath);
    Console.WriteLine("The model is saved to {0}", ModelPath);

    return trainedModel;
}
/// <summary>
/// Loads the saved model from <c>ModelPath</c>, predicts the fare of one hard-coded sample trip
/// and prints the predicted value next to the known actual fare.
/// </summary>
/// <param name="mlContext">ML.NET context used to load the model and create the prediction engine.</param>
private static void TestSinglePrediction(MLContext mlContext)
{
    // Sample row, in the CSV column order:
    // vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
    // VTS,1,1,1140,3.75,CRD,15.5
    var sampleTrip = new TaxiTrip
    {
        VendorId = "VTS",
        RateCode = "1",
        PassengerCount = 1,
        TripTime = 1140,
        TripDistance = 3.75f,
        PaymentType = "CRD",
        FareAmount = 0 // To predict. Actual/Observed = 15.5
    };

    // Load the persisted model; its input schema is not needed here.
    ITransformer loadedModel = mlContext.Model.Load(ModelPath, out _);

    // Create the prediction engine for the loaded model and score the sample.
    var engine = mlContext.Model.CreatePredictionEngine<TaxiTrip, TaxiTripFarePrediction>(loadedModel);
    TaxiTripFarePrediction prediction = engine.Predict(sampleTrip);

    const string banner = "**********************************************************************";
    Console.WriteLine(banner);
    Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
    Console.WriteLine(banner);
}
/// <summary>
/// Scores up to <paramref name="numberOfRecordsToRead"/> records of the test CSV with the saved
/// model, plots measured versus predicted fares, overlays the least-squares regression line,
/// writes the chart file to the current directory and opens it with the default application.
/// </summary>
/// <param name="mlContext">ML.NET context used to load the model and create the prediction engine.</param>
/// <param name="testDataSetPath">Path of the CSV file that holds the test records.</param>
/// <param name="numberOfRecordsToRead">Maximum number of records to read, score and plot.</param>
/// <param name="args">
/// Command-line arguments. Exactly one argument equal to "svg" selects the SVG backend;
/// anything else (including null) produces a PNG.
/// </param>
private static void PlotRegressionChart(MLContext mlContext,
                                        string testDataSetPath,
                                        int numberOfRecordsToRead,
                                        string[] args)
{
    ITransformer trainedModel;
    using (var stream = new FileStream(ModelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
    {
        // The model's input schema is not needed here.
        trainedModel = mlContext.Model.Load(stream, out _);
    }

    // Create prediction engine related to the loaded trained model
    var predFunction = mlContext.Model.CreatePredictionEngine<TaxiTrip, TaxiTripFarePrediction>(trainedModel);

    bool useSvg = args != null && args.Length == 1 && args[0] == "svg";
    string chartFileName = useSvg ? "TaxiRegressionDistribution.svg" : "TaxiRegressionDistribution.png";

    using (var pl = new PLStream())
    {
        // Select the output backend and write the chart file to the current directory.
        pl.sdev(useSvg ? "svg" : "pngcairo");
        pl.sfnam(chartFileName);

        // use white background with black foreground
        pl.spal0("cmap0_alternate.pal");

        // Initialize plplot
        pl.init();

        // set axis limits
        const int xMinLimit = 0;
        const int xMaxLimit = 35; //Rides larger than $35 are not shown in the chart
        const int yMinLimit = 0;
        const int yMaxLimit = 35; //Rides larger than $35 are not shown in the chart
        pl.env(xMinLimit, xMaxLimit, yMinLimit, yMaxLimit, AxesScale.Independent, AxisBox.BoxTicksLabelsAxes);

        // Set scaling for main title text 125% size of default
        pl.schr(0, 1.25);

        // The main title
        pl.lab("Measured", "Predicted", "Distribution of Taxi Fare Prediction");

        var testData = new TaxiTripCsvReader().GetDataFromCsv(testDataSetPath, numberOfRecordsToRead).ToList();

        //This code is the symbol to paint
        char code = (char)9;

        // plot the points using another color
        // see http://plplot.sourceforge.net/examples.php?demo=02 for palette indices
        pl.col0(2); //Blue

        // Running sums needed for the least-squares regression line.
        double sumX = 0;
        double sumY = 0;
        double sumXY = 0;
        double sumXX = 0;

        // pl.poin takes arrays, so single-element buffers are reused for every point.
        var x = new double[1];
        var y = new double[1];

        foreach (var trip in testData)
        {
            //Make Prediction
            var farePrediction = predFunction.Predict(trip);

            x[0] = trip.FareAmount;
            y[0] = farePrediction.FareAmount;

            //Paint a dot
            pl.poin(x, y, code);

            sumX += x[0];
            sumY += y[0];
            sumXY += x[0] * y[0];
            sumXX += x[0] * x[0];

            Console.WriteLine($"-------------------------------------------------");
            Console.WriteLine($"Predicted : {farePrediction.FareAmount}");
            Console.WriteLine($"Actual: {trip.FareAmount}");
            Console.WriteLine($"-------------------------------------------------");
        }

        // The averages must be taken over the records actually read (testData.Count), which can
        // be fewer than numberOfRecordsToRead when the file is short.
        if (TryFitRegressionLine(sumX, sumY, sumXY, sumXX, testData.Count, out double slope, out double intercept))
        {
            // Two points of y = (m * x) + b are enough to draw the line.
            const double lineStartX = 1;
            const double lineEndX = 39;
            var xArray = new[] { lineStartX, lineEndX };
            var yArray = new[] { (slope * lineStartX) + intercept, (slope * lineEndX) + intercept };

            pl.col0(4); //Red
            pl.line(xArray, yArray);
        }
        else
        {
            Console.WriteLine("Not enough distinct data points to draw a regression line.");
        }

        // end page (writes output to disk)
        pl.eop();

        // output version of PLplot
        pl.gver(out var verText);
        Console.WriteLine("PLplot version " + verText);
    } // the pl object is disposed here

    // Open the chart file in the default app for its type (image viewer, browser for .svg, ...)
    Console.WriteLine("Showing chart...");

    // Path.GetFullPath resolves against the current directory, where the chart was written,
    // without relying on the Windows-only ".\" prefix.
    using (var viewer = new Process())
    {
        viewer.StartInfo = new ProcessStartInfo(Path.GetFullPath(chartFileName))
        {
            UseShellExecute = true
        };
        viewer.Start();
    }
}

/// <summary>
/// Computes the least-squares regression line y = (slope * x) + intercept from running sums.
/// Regression line calculation explanation:
/// https://www.khanacademy.org/math/statistics-probability/describing-relationships-quantitative-data/more-on-regression/v/regression-line-example
/// </summary>
/// <param name="sumX">Sum of the x values.</param>
/// <param name="sumY">Sum of the y values.</param>
/// <param name="sumXY">Sum of the x * y products.</param>
/// <param name="sumXX">Sum of the squared x values.</param>
/// <param name="count">Number of points the sums were accumulated over.</param>
/// <param name="slope">Slope of the fitted line; 0 when no line could be fitted.</param>
/// <param name="intercept">Intercept of the fitted line; 0 when no line could be fitted.</param>
/// <returns>
/// True when a finite slope and intercept were computed; false when there are no points or
/// the x values have no variance.
/// </returns>
private static bool TryFitRegressionLine(double sumX, double sumY, double sumXY, double sumXX, int count,
                                         out double slope, out double intercept)
{
    slope = 0;
    intercept = 0;

    if (count <= 0)
    {
        return false;
    }

    double meanX = sumX / count;
    double meanY = sumY / count;
    double meanXY = sumXY / count;
    double meanXSquare = sumXX / count;

    double candidateSlope = ((meanX * meanY) - meanXY) / ((meanX * meanX) - meanXSquare);
    double candidateIntercept = meanY - (candidateSlope * meanX);

    // A zero denominator (for example a single point) yields NaN or infinity.
    if (double.IsNaN(candidateSlope) || double.IsInfinity(candidateSlope) ||
        double.IsNaN(candidateIntercept) || double.IsInfinity(candidateIntercept))
    {
        return false;
    }

    slope = candidateSlope;
    intercept = candidateIntercept;
    return true;
}
/// <summary>
/// Combines the folder that contains this program's assembly with a relative path.
/// </summary>
/// <param name="relativePath">Path relative to the assembly folder.</param>
/// <returns>The assembly folder joined with <paramref name="relativePath"/> via Path.Combine.</returns>
public static string GetAbsolutePath(string relativePath)
{
    // The assembly file's parent directory is the anchor for all relative paths.
    var assemblyFile = new FileInfo(typeof(Program).Assembly.Location);
    DirectoryInfo assemblyFolder = assemblyFile.Directory;

    return Path.Combine(assemblyFolder.FullName, relativePath);
}
}
/// <summary>
/// Reads taxi trip records from a comma-separated text file.
/// </summary>
public class TaxiTripCsvReader
{
    /// <summary>
    /// Returns at most <paramref name="numMaxRecords"/> trips parsed from the given CSV file.
    /// The first line is skipped as a header and blank lines are ignored. Columns are read in
    /// the order vendor id, rate code, passenger count, trip time, trip distance, payment type,
    /// fare amount; numeric columns are parsed with the invariant culture.
    /// </summary>
    /// <param name="dataLocation">Path of the CSV file to read.</param>
    /// <param name="numMaxRecords">Maximum number of records to return.</param>
    /// <returns>
    /// A lazily evaluated sequence of trips; rows are parsed when the sequence is enumerated.
    /// </returns>
    public IEnumerable<TaxiTrip> GetDataFromCsv(string dataLocation, int numMaxRecords)
    {
        // File.ReadLines streams the file, so only the lines needed for the requested number
        // of records are read instead of loading the whole file into memory.
        return File.ReadLines(dataLocation)
            .Skip(1)
            // A blank line would otherwise fail with an index-out-of-range after Split.
            .Where(line => !string.IsNullOrWhiteSpace(line))
            .Take(numMaxRecords)
            .Select(line => line.Split(','))
            .Select(fields => new TaxiTrip
            {
                VendorId = fields[0],
                RateCode = fields[1],
                PassengerCount = float.Parse(fields[2], CultureInfo.InvariantCulture),
                TripTime = float.Parse(fields[3], CultureInfo.InvariantCulture),
                TripDistance = float.Parse(fields[4], CultureInfo.InvariantCulture),
                PaymentType = fields[5],
                FareAmount = float.Parse(fields[6], CultureInfo.InvariantCulture)
            });
    }
}
}