# Forecasting COVID-19 cases with Microsoft ML.NET time series (SSA)

### Install packages

In [37]:
#r "nuget:Microsoft.ML,1.5.0-preview2"
#r "nuget:Microsoft.ML.TimeSeries,1.5.0-preview2"
#r "nuget:Octokit, 0.32.0"
#r "nuget:NodaTime, 2.4.6"
using Octokit;
using NodaTime;
using NodaTime.Extensions;
using XPlot.Plotly;
using System;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms.TimeSeries;

In [38]:
// Single ML.NET entry point; its catalogs (Data, Forecasting, ...) are shared by the loading and forecasting code.
private static readonly MLContext mlContext = new MLContext();

// Per-country input CSV files, relative to the notebook's working directory.
private const string DATA_FILEPATH_CN = "./covid_19_China.csv",
                     DATA_FILEPATH_IT = "./covid_19_Italy.csv",
                     DATA_FILEPATH_NL = "./covid_19_Netherlands.csv",
                     DATA_FILEPATH_US = "./covid_19_USA.csv";

### Data models

In [39]:
/// <summary>
/// One row of a per-country COVID-19 CSV file. Each <c>LoadColumn</c> index is the zero-based
/// position of the column in the file.
/// </summary>
/// <remarks>
/// Property declaration order determines the loaded schema's column order; keep it stable in case
/// rows are read by column position.
/// </remarks>
public class ModelInput
{
    /// <summary>Observation date (CSV column 0).</summary>
    [LoadColumn(0)] public DateTime Date { get; set; }

    /// <summary>Confirmed value (CSV column 1).</summary>
    [LoadColumn(1)] public float Confirmed { get; set; }

    /// <summary>Deaths value (CSV column 2).</summary>
    [LoadColumn(2)] public float Deaths { get; set; }

    /// <summary>Recovered value (CSV column 3).</summary>
    [LoadColumn(3)] public float Recovered { get; set; }

    /// <summary>
    /// CSV column 4. NOTE(review): not read by any visible cell — confirm the files really have a fifth column.
    /// </summary>
    [LoadColumn(4)] public float Features { get; set; }
}

/// <summary>Output row produced by the SSA time-series prediction engine.</summary>
public class ModelOutput
{
    /// <summary>Forecast values, one per step of the configured forecasting horizon.</summary>
    public float[] Forecasted { get; set; }
}

### Load data views for China, Italy, Netherlands, USA

In [40]:
// All country files share one layout: a header row, then comma-separated ModelInput columns.
IDataView LoadCountryData(string csvPath) =>
    mlContext.Data.LoadFromTextFile<ModelInput>(
        path: csvPath,
        hasHeader: true,
        separatorChar: ',');

IDataView dataCn = LoadCountryData(DATA_FILEPATH_CN);
IDataView dataIt = LoadCountryData(DATA_FILEPATH_IT);
IDataView dataNl = LoadCountryData(DATA_FILEPATH_NL);
IDataView dataUs = LoadCountryData(DATA_FILEPATH_US);

### active = [confirmed - deaths - recovered] data snapshots

In [41]:
// Active cases per row = Confirmed - Deaths - Recovered, for every row of each data view.
// CreateEnumerable streams the whole data view; Preview() is a debugging aid that stops at
// 100 rows by default and would silently truncate longer series.
var activeCn = mlContext.Data.CreateEnumerable<ModelInput>(dataCn, reuseRowObject: false).Select(r => r.Confirmed - r.Deaths - r.Recovered).ToList();
var activeIt = mlContext.Data.CreateEnumerable<ModelInput>(dataIt, reuseRowObject: false).Select(r => r.Confirmed - r.Deaths - r.Recovered).ToList();
var activeNl = mlContext.Data.CreateEnumerable<ModelInput>(dataNl, reuseRowObject: false).Select(r => r.Confirmed - r.Deaths - r.Recovered).ToList();
var activeUs = mlContext.Data.CreateEnumerable<ModelInput>(dataUs, reuseRowObject: false).Select(r => r.Confirmed - r.Deaths - r.Recovered).ToList();

### date range snapshots

In [42]:
// Observation dates for every row of each data view.
// The Date column is already loaded as DateTime, so it is read directly. This avoids a
// culture-sensitive ToString()/DateTime.Parse round trip. CreateEnumerable also avoids
// Preview()'s default 100-row cap.
var dataRangeCn = mlContext.Data.CreateEnumerable<ModelInput>(dataCn, reuseRowObject: false).Select(r => r.Date).ToList();
var dataRangeIt = mlContext.Data.CreateEnumerable<ModelInput>(dataIt, reuseRowObject: false).Select(r => r.Date).ToList();
var dataRangeNl = mlContext.Data.CreateEnumerable<ModelInput>(dataNl, reuseRowObject: false).Select(r => r.Date).ToList();
var dataRangeUs = mlContext.Data.CreateEnumerable<ModelInput>(dataUs, reuseRowObject: false).Select(r => r.Date).ToList();

### [confirmed] data snapshots

In [43]:
// [Confirmed] values for every row of each data view.
// CreateEnumerable reads all rows; Preview() is capped at 100 rows by default.
var confirmedCn = mlContext.Data.CreateEnumerable<ModelInput>(dataCn, reuseRowObject: false).Select(r => r.Confirmed).ToList();
var confirmedIt = mlContext.Data.CreateEnumerable<ModelInput>(dataIt, reuseRowObject: false).Select(r => r.Confirmed).ToList();
var confirmedNl = mlContext.Data.CreateEnumerable<ModelInput>(dataNl, reuseRowObject: false).Select(r => r.Confirmed).ToList();
var confirmedUs = mlContext.Data.CreateEnumerable<ModelInput>(dataUs, reuseRowObject: false).Select(r => r.Confirmed).ToList();

### active = [confirmed - deaths - recovered] chart over time

In [44]:
// One scatter trace per country: active cases plotted against that country's observation dates.
Graph.Scatter[] activeTraces =
{
    new Graph.Scatter { name = "China", x = dataRangeCn, y = activeCn },
    new Graph.Scatter { name = "Italy", x = dataRangeIt, y = activeIt },
    new Graph.Scatter { name = "Netherlands", x = dataRangeNl, y = activeNl },
    new Graph.Scatter { name = "USA", x = dataRangeUs, y = activeUs },
};

var activeHistogram = Chart.Plot(activeTraces);

### Get the most recent date from the datasets (assumes all datasets are up to date)

In [45]:
// Latest observation date across all datasets. A longer history (earlier start date) does not
// imply a later end date, so no single country's last date is assumed to be the newest.
var currentDate = new[] { dataRangeCn.Last(), dataRangeIt.Last(), dataRangeNl.Last(), dataRangeUs.Last() }.Max();
var day = currentDate.ToString("dd MMM yyyy");

// Title the active-cases chart with the data cut-off date, then render it.
var layout = new Layout.Layout { title = $"Active = [Confirmed - Deaths - Recovered] per day, up to {day} " };
activeHistogram.WithLayout(layout);
display(activeHistogram);

### Set up forecasting hyper-parameters

In [53]:
// SSA window length in days (the trajectory-matrix window, usually chosen near the expected
// seasonality). Note: this is 10 days, not a 7-day week.
var windowSize = 10;
// Number of days to forecast (2 weeks).
var horizon = 14;

### Create the estimator and prediction engine, then forecast confirmed infections for the USA

In [54]:
// Count rows straight from the data view being fitted, so the SSA training window spans the
// whole series regardless of how any in-memory snapshot was built.
var trainSize = mlContext.Data.CreateEnumerable<ModelInput>(dataUs, reuseRowObject: false).Count();
// Keep the entire history in the model's series buffer.
var seriesLength = trainSize;

// Singular Spectrum Analysis forecaster over the [Confirmed] column.
IEstimator<ITransformer> estimatorUs = mlContext.Forecasting.ForecastBySsa(
    outputColumnName: nameof(ModelOutput.Forecasted),
    inputColumnName: nameof(ModelInput.Confirmed),
    windowSize: windowSize,
    seriesLength: seriesLength,
    trainSize: trainSize,
    horizon: horizon);

ITransformer transformerUs = estimatorUs.Fit(dataUs);

// Predict() without an input row returns the next `horizon` forecast values.
var engineUs = transformerUs.CreateTimeSeriesEngine<ModelInput, ModelOutput>(mlContext);
var predictionUs = engineUs.Predict();
// Round to whole numbers for display.
var forecastedUs = predictionUs.Forecasted.Select(p => Math.Round(p));

### Create the estimator and prediction engine, then forecast confirmed infections for the Netherlands

In [55]:
// Count rows straight from the data view being fitted, so the SSA training window spans the
// whole series regardless of how any in-memory snapshot was built.
var trainSize = mlContext.Data.CreateEnumerable<ModelInput>(dataNl, reuseRowObject: false).Count();
// Keep the entire history in the model's series buffer.
var seriesLength = trainSize;

// Singular Spectrum Analysis forecaster over the [Confirmed] column.
IEstimator<ITransformer> estimatorNl = mlContext.Forecasting.ForecastBySsa(
    outputColumnName: nameof(ModelOutput.Forecasted),
    inputColumnName: nameof(ModelInput.Confirmed),
    windowSize: windowSize,
    seriesLength: seriesLength,
    trainSize: trainSize,
    horizon: horizon);

ITransformer transformerNl = estimatorNl.Fit(dataNl);

// Predict() without an input row returns the next `horizon` forecast values.
var engineNl = transformerNl.CreateTimeSeriesEngine<ModelInput, ModelOutput>(mlContext);
var predictionNl = engineNl.Predict();
// Round to whole numbers for display.
var forecastedNl = predictionNl.Forecasted.Select(p => Math.Round(p));

### Create the estimator and prediction engine, then forecast confirmed infections for Italy

In [56]:
// Count rows straight from the data view being fitted, so the SSA training window spans the
// whole series regardless of how any in-memory snapshot was built.
var trainSize = mlContext.Data.CreateEnumerable<ModelInput>(dataIt, reuseRowObject: false).Count();
// Keep the entire history in the model's series buffer.
var seriesLength = trainSize;

// Singular Spectrum Analysis forecaster over the [Confirmed] column.
IEstimator<ITransformer> estimatorIt = mlContext.Forecasting.ForecastBySsa(
    outputColumnName: nameof(ModelOutput.Forecasted),
    inputColumnName: nameof(ModelInput.Confirmed),
    windowSize: windowSize,
    seriesLength: seriesLength,
    trainSize: trainSize,
    horizon: horizon);

ITransformer transformerIt = estimatorIt.Fit(dataIt);

// Predict() without an input row returns the next `horizon` forecast values.
var engineIt = transformerIt.CreateTimeSeriesEngine<ModelInput, ModelOutput>(mlContext);
var predictionIt = engineIt.Predict();
// Round to whole numbers for display.
var forecastedIt = predictionIt.Forecasted.Select(p => Math.Round(p));

### Create the estimator and prediction engine, then forecast confirmed infections for China

In [57]:
// Count rows straight from the data view being fitted, so the SSA training window spans the
// whole series regardless of how any in-memory snapshot was built.
var trainSize = mlContext.Data.CreateEnumerable<ModelInput>(dataCn, reuseRowObject: false).Count();
// Keep the entire history in the model's series buffer.
var seriesLength = trainSize;

// Singular Spectrum Analysis forecaster over the [Confirmed] column.
IEstimator<ITransformer> estimatorCn = mlContext.Forecasting.ForecastBySsa(
    outputColumnName: nameof(ModelOutput.Forecasted),
    inputColumnName: nameof(ModelInput.Confirmed),
    windowSize: windowSize,
    seriesLength: seriesLength,
    trainSize: trainSize,
    horizon: horizon);

ITransformer transformerCn = estimatorCn.Fit(dataCn);

// Predict() without an input row returns the next `horizon` forecast values.
var engineCn = transformerCn.CreateTimeSeriesEngine<ModelInput, ModelOutput>(mlContext);
var predictionCn = engineCn.Predict();
// Round to whole numbers for display.
var forecastedCn = predictionCn.Forecasted.Select(p => Math.Round(p));

### Generate the date range for the forecast

In [58]:
// The `horizon` calendar days immediately following currentDate (day offsets 1..horizon).
DateTime[] dateRangeAll = (from dayOffset in Enumerable.Range(1, horizon)
                           select currentDate.AddDays(dayOffset)).ToArray();

### Plot past and predicted confirmed infections

In [59]:
// Observed [Confirmed] series first, then each country's forecast on the shared forecast dates.
Graph.Scatter[] confirmedTraces =
{
    new Graph.Scatter { name = "China", x = dataRangeCn, y = confirmedCn },
    new Graph.Scatter { name = "Italy", x = dataRangeIt, y = confirmedIt },
    new Graph.Scatter { name = "Netherlands", x = dataRangeNl, y = confirmedNl },
    new Graph.Scatter { name = "USA", x = dataRangeUs, y = confirmedUs },

    new Graph.Scatter { name = "China (predicted)", x = dateRangeAll, y = forecastedCn },
    new Graph.Scatter { name = "Italy (predicted)", x = dateRangeAll, y = forecastedIt },
    new Graph.Scatter { name = "Netherlands (predicted)", x = dateRangeAll, y = forecastedNl },
    new Graph.Scatter { name = "USA (predicted)", x = dateRangeAll, y = forecastedUs },
};

var predictedConfirmedHistogram = Chart.Plot(confirmedTraces);

// Title with the data cut-off date and forecast length, then render.
var layout = new Layout.Layout { title = $"[Confirmed] infections per day (up to {day}) and predictions for {horizon} days" };
predictedConfirmedHistogram.WithLayout(layout);
display(predictedConfirmedHistogram);