In [124]:
#r "nuget:Microsoft.ML,1.5.0-preview2"
#r "nuget:Microsoft.ML.TimeSeries,1.5.0-preview2"
#r "nuget:Octokit, 0.32.0"
#r "nuget:NodaTime, 2.4.6"
using Octokit;
using NodaTime;
using NodaTime.Extensions;
using XPlot.Plotly;
using System;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms.TimeSeries;

In [125]:
// Per-region input CSV files; each one is loaded with the ModelInput schema.
private const string DATA_FILEPATH_RO = "./covid_19_Romania.csv";
private const string DATA_FILEPATH_NL = "./covid_19_Netherlands.csv";
private const string DATA_FILEPATH_IT = "./covid_19_Italy.csv";
private const string DATA_FILEPATH_CN_H = "./covid_19_Hubei.csv";

// One shared ML.NET context, used both for loading the CSVs and for building every forecaster.
private static readonly MLContext mlContext = new MLContext();

In [126]:
// data models

/// <summary>
/// One row of a region's CSV file. Each property is bound by zero-based column
/// index through <c>LoadColumn</c>.
/// </summary>
public class ModelInput
{
    /// <summary>Column 0: the row's date.</summary>
    [LoadColumn(0)] public DateTime Date { get; set; }

    /// <summary>Column 1: confirmed count; the series the forecasters are trained on.</summary>
    [LoadColumn(1)] public float Confirmed { get; set; }

    /// <summary>Column 2: deaths count (loaded but not used by the forecasts).</summary>
    [LoadColumn(2)] public float Deaths { get; set; }

    /// <summary>Column 3: recovered count (loaded but not used by the forecasts).</summary>
    [LoadColumn(3)] public float Recovered { get; set; }
}

/// <summary>
/// Result row of a time-series prediction engine. The property name matches the
/// SSA estimator's output column (<c>nameof(ModelOutput.Forecasted)</c>).
/// </summary>
public class ModelOutput
{
    /// <summary>The forecast values, one per step of the forecast horizon.</summary>
    public float[] Forecasted { get; set; }
}

In [127]:
// data views: every region's file is parsed with the ModelInput schema,
// comma-separated, with the first line treated as a header.

IDataView LoadRegionCsv(string csvPath) =>
    mlContext.Data.LoadFromTextFile<ModelInput>(
        separatorChar: ',',
        hasHeader: true,
        path: csvPath);

IDataView dataRo = LoadRegionCsv(DATA_FILEPATH_RO);
IDataView dataNl = LoadRegionCsv(DATA_FILEPATH_NL);
IDataView dataIt = LoadRegionCsv(DATA_FILEPATH_IT);
IDataView dataCnH = LoadRegionCsv(DATA_FILEPATH_CN_H);

In [128]:
// infected and datetime range
// For each region: x* = the Confirmed value of every row, y* = the Date of every row
// (note the plots below use y* on the x-axis and x* on the y-axis).
//
// CreateEnumerable streams *all* rows as typed ModelInput objects. The previous
// DataView.Preview() approach stops at 100 rows by default, so any series longer
// than that would have been silently cut off. Reading the typed Date also avoids
// a culture-sensitive DateTime -> string -> DateTime round-trip.
// Each file is materialized once and both series are projected from that list.

var rowsRo = mlContext.Data.CreateEnumerable<ModelInput>(dataRo, reuseRowObject: false).ToList();
var rowsNl = mlContext.Data.CreateEnumerable<ModelInput>(dataNl, reuseRowObject: false).ToList();
var rowsIt = mlContext.Data.CreateEnumerable<ModelInput>(dataIt, reuseRowObject: false).ToList();
var rowsCnH = mlContext.Data.CreateEnumerable<ModelInput>(dataCnH, reuseRowObject: false).ToList();

var xRo = rowsRo.Select(r => r.Confirmed).ToList();
var xNl = rowsNl.Select(r => r.Confirmed).ToList();
var xIt = rowsIt.Select(r => r.Confirmed).ToList();
var xCnH = rowsCnH.Select(r => r.Confirmed).ToList();

var yRo = rowsRo.Select(r => r.Date).ToList();
var yNl = rowsNl.Select(r => r.Date).ToList();
var yIt = rowsIt.Select(r => r.Date).ToList();
var yCnH = rowsCnH.Select(r => r.Date).ToList();

In [129]:
// observed series per region, drawn as Plotly scatter traces (date vs. confirmed count)
// NOTE(review): if Confirmed is a cumulative count, the "per day" wording in the title
// is misleading — confirm against the CSV contents.

// startDate is the LAST date of the Romanian series; later cells use it (and `day`)
// as the anchor for the forecast dates and the second chart's title.
var startDate = yRo.Last();
var day = startDate.ToString("dd MMMM yyyy");
var layout = new Layout.Layout { title = $"Infections per day up to {day} " };

var infectionHistogram = Chart.Plot(new[]
{
    new Graph.Scatter { name = "Romania", x = yRo, y = xRo },
    new Graph.Scatter { name = "Netherlands", x = yNl, y = xNl },
    new Graph.Scatter { name = "Italy", x = yIt, y = xIt },
    new Graph.Scatter { name = "Hubei", x = yCnH, y = xCnH }
});

infectionHistogram.WithLayout(layout);
display(infectionHistogram);

In [115]:
// forecast horizon: the number of future points each SSA forecaster predicts

int horizon = 28;

In [116]:
// Romania: build an SSA estimator on the confirmed counts, fit it, then predict the next `horizon` values

var trainSize = xRo.Count;
var windowSize = trainSize / 3;   // window spans a third of the observed series

IEstimator<ITransformer> estimatorRo = mlContext.Forecasting.ForecastBySsa(
    inputColumnName: nameof(ModelInput.Confirmed),
    outputColumnName: nameof(ModelOutput.Forecasted),
    horizon: horizon,
    trainSize: trainSize,
    windowSize: windowSize,
    seriesLength: windowSize + 1);

ITransformer transformerRo = estimatorRo.Fit(dataRo);
var engineRo = transformerRo.CreateTimeSeriesEngine<ModelInput, ModelOutput>(mlContext);

// each forecast value is rounded to a whole number
var predictionRo = engineRo.Predict();
var forecastedRo = from value in predictionRo.Forecasted
                   select Math.Round(value);

In [117]:
// Netherlands: build an SSA estimator on the confirmed counts, fit it, then predict the next `horizon` values

var trainSize = xNl.Count;
var windowSize = trainSize / 3;   // window spans a third of the observed series

IEstimator<ITransformer> estimatorNl = mlContext.Forecasting.ForecastBySsa(
    inputColumnName: nameof(ModelInput.Confirmed),
    outputColumnName: nameof(ModelOutput.Forecasted),
    horizon: horizon,
    trainSize: trainSize,
    windowSize: windowSize,
    seriesLength: windowSize + 1);

ITransformer transformerNl = estimatorNl.Fit(dataNl);
var engineNl = transformerNl.CreateTimeSeriesEngine<ModelInput, ModelOutput>(mlContext);

// each forecast value is rounded to a whole number
var predictionNl = engineNl.Predict();
var forecastedNl = from value in predictionNl.Forecasted
                   select Math.Round(value);

In [118]:
// Italy: build an SSA estimator on the confirmed counts, fit it, then predict the next `horizon` values

var trainSize = xIt.Count;
var windowSize = trainSize / 3;   // window spans a third of the observed series

IEstimator<ITransformer> estimatorIt = mlContext.Forecasting.ForecastBySsa(
    inputColumnName: nameof(ModelInput.Confirmed),
    outputColumnName: nameof(ModelOutput.Forecasted),
    horizon: horizon,
    trainSize: trainSize,
    windowSize: windowSize,
    seriesLength: windowSize + 1);

ITransformer transformerIt = estimatorIt.Fit(dataIt);
var engineIt = transformerIt.CreateTimeSeriesEngine<ModelInput, ModelOutput>(mlContext);

// each forecast value is rounded to a whole number
var predictionIt = engineIt.Predict();
var forecastedIt = from value in predictionIt.Forecasted
                   select Math.Round(value);

In [119]:
// create estimator and prediction engine and then predict some infections for China, Hubei
// Window and training sizes are derived from the Hubei series itself (xCnH).

var windowSize = xCnH.Count / 3;
var trainSize = xCnH.Count;

IEstimator<ITransformer> estimatorCnH = mlContext.Forecasting.ForecastBySsa(
    outputColumnName: nameof(ModelOutput.Forecasted),
    inputColumnName: nameof(ModelInput.Confirmed),
    windowSize: windowSize,
    seriesLength: windowSize + 1,
    trainSize: trainSize,
    horizon: horizon);

// Fit the Hubei estimator (not Italy's): its window/train sizes were computed from
// the Hubei series above, so fitting estimatorIt here would use Italy's parameters.
ITransformer transformerCnH = estimatorCnH.Fit(dataCnH);

var engineCnH = transformerCnH.CreateTimeSeriesEngine<ModelInput, ModelOutput>(mlContext);
var predictionCnH = engineCnH.Predict();
// each forecast value is rounded to a whole number
var forecastedCnH = predictionCnH.Forecasted.Select(p => Math.Round(p));

In [120]:
// generate the days for prediction
// The forecasters predict the `horizon` points that FOLLOW the last training
// observation, so forecast[i] belongs to startDate + (i + 1) days, where startDate
// is the last date of the Romanian series (yRo.Last()). Starting at offset 0 would
// plot the first forecast on the last observed day.
// NOTE(review): one range is shared by all regions, which assumes every series ends
// on the same date as Romania's — confirm against the CSVs.

DateTime[] dateRange = Enumerable.Range(1, horizon).Select(offset => startDate.AddDays(offset)).ToArray();

In [121]:
// plot each region's observed series together with its forecast (forecasts use dateRange as x values)

var layout = new Layout.Layout { title = $"Infections per day starting from {day}" };

var predictedIinfectionHistogram = Chart.Plot(new[]
{
    // observed
    new Graph.Scatter { name = "Romania (past)", x = yRo, y = xRo },
    new Graph.Scatter { name = "Netherlands (past)", x = yNl, y = xNl },
    new Graph.Scatter { name = "Italy (past)", x = yIt, y = xIt },
    new Graph.Scatter { name = "Hubei (past)", x = yCnH, y = xCnH },

    // forecast
    new Graph.Scatter { name = "Romania (predicted)", x = dateRange, y = forecastedRo },
    new Graph.Scatter { name = "Netherlands (predicted)", x = dateRange, y = forecastedNl },
    new Graph.Scatter { name = "Italy (predicted)", x = dateRange, y = forecastedIt },
    new Graph.Scatter { name = "Hubei (predicted)", x = dateRange, y = forecastedCnH }
});

predictedIinfectionHistogram.WithLayout(layout);
display(predictedIinfectionHistogram);