In [None]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta


def covidlive_data(start_date=np.datetime64('2021-06-10'), raw_table=None, max_rows=200):
    """Return daily NSW overseas-source case counts scraped from covidlive.com.au.

    The second HTML table of the report page is read with ``pd.read_html``. Only
    its first ``max_rows`` rows are considered. Rows whose ``NET2`` value is not
    numeric (e.g. the ``'-'`` placeholder, which may appear in any row) are
    dropped. Each ``DATE`` is shifted back by one day (presumably because a row
    reports the previous day's cases -- confirm against the site). Only shifted
    dates on or after ``start_date`` are kept.

    Args:
        start_date: earliest shifted date to keep, as an ``np.datetime64``.
        raw_table: optional pre-fetched DataFrame with ``DATE`` (``"%d %b %y"``
            strings) and ``NET2`` columns. When given, no network request is made.
        max_rows: number of leading table rows to consider.

    Returns:
        ``(dates, cases)``: a ``datetime64[D]`` array and an integer array of equal
        length, in reverse table order (chronological if the page lists newest
        first, as the original code assumed).
    """
    if raw_table is None:
        raw_table = pd.read_html('https://covidlive.com.au/report/daily-source-overseas/nsw')[1]

    table = raw_table.iloc[:max_rows]

    # Coerce NET2 to numbers; '-' placeholders (and anything else non-numeric)
    # become NaN and their rows are dropped. Strip thousands separators first
    # so that large counts are not mistaken for garbage.
    net = pd.to_numeric(
        table['NET2'].astype(str).str.replace(',', '', regex=False),
        errors='coerce',
    )
    has_count = net.notna()
    table = table[has_count]

    dates = (
        pd.to_datetime(table['DATE'], format='%d %b %y')
        .to_numpy()
        .astype('datetime64[D]')
        - np.timedelta64(1, 'D')
    )
    cases = net[has_count].to_numpy().astype(int)

    in_range = dates >= start_date
    # Reverse so the series runs oldest -> newest.
    return dates[in_range][::-1], cases[in_range][::-1]

# Daily series used by the rest of the notebook, starting at START_DATE.
START_DATE = np.datetime64('2021-08-01')
dates, cases = covidlive_data(START_DATE)
assert len(dates) == len(cases) > 0, "no data on or after START_DATE"
print(f"{len(cases)} days: {dates[0]} to {dates[-1]}")
dates[-7:], cases[-7:]

In [None]:
import matplotlib.pyplot as plt
import torch

import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist

# Later cells pass `def_type.dtype` (float32) when creating tensors.
def_type = torch.FloatTensor
# `torch.set_default_tensor_type` is deprecated; setting the default dtype is the
# supported equivalent for CPU float tensors.
torch.set_default_dtype(def_type.dtype)

SEED = 0  # fixed seed so stochastic steps give the same results on every run
pyro.set_rng_seed(SEED)



In [None]:
# note that this helper function does three different things:
# (i) plots the observed data;
# (ii) plots the predictions from the learned GP after conditioning on data;
# (iii) plots samples from the GP prior (with no conditioning on observed data)

def plot(plot_observed_data=False, plot_predictions=False, n_prior_samples=0,
         model=None, kernel=None, n_test=7, history=114):
    """Draw observed data, GP predictions and/or GP prior samples on a new figure.

    Args:
        plot_observed_data: if True, plot the notebook-level training tensors
            ``X`` and ``y`` as black crosses.
        plot_predictions: if True, plot ``model``'s predictive mean and a
            two-standard-deviation band at ``n_test`` points spanning 0..6.
        n_prior_samples: number of functions to sample from the prior given by
            ``kernel`` plus ``model``'s noise (0 disables prior sampling).
        model: a ``gp.models.GPRegression`` or ``gp.models.VariationalSparseGP``.
            Required for predictions and prior samples.
        kernel: kernel used for the prior samples.
        n_test: number of test inputs.
        history: number of past x-units to show to the left of 0.
    """
    plt.figure(figsize=(12, 6))
    is_sparse = isinstance(model, gp.models.VariationalSparseGP)
    if plot_observed_data:
        plt.plot(X.numpy(), y.numpy(), 'kx')
    if plot_predictions:
        Xtest = torch.linspace(0, 6, n_test)  # test inputs
        # compute predictive mean and covariance without tracking gradients
        with torch.no_grad():
            if is_sparse:
                mean, cov = model(Xtest, full_cov=True)
            else:
                # include observation noise so the band describes new observations
                mean, cov = model(Xtest, full_cov=True, noiseless=False)
        sd = cov.diag().sqrt()  # standard deviation at each input point x
        plt.plot(Xtest.numpy(), mean.numpy(), 'r', lw=2)  # plot the mean
        plt.fill_between(Xtest.numpy(),  # plot the two-sigma uncertainty about the mean
                         (mean - 2.0 * sd).numpy(),
                         (mean + 2.0 * sd).numpy(),
                         color='C0', alpha=0.3)
    if n_prior_samples > 0:  # plot samples from the GP prior
        Xtest = torch.linspace(0, 6, n_test)  # test inputs
        with torch.no_grad():
            noise = model.likelihood.variance if is_sparse else model.noise
            cov = kernel.forward(Xtest) + noise.expand(n_test).diag()
            samples = dist.MultivariateNormal(torch.zeros(n_test), covariance_matrix=cov)\
                .sample(sample_shape=(n_prior_samples,))
        plt.plot(Xtest.numpy(), samples.numpy().T, lw=2, alpha=0.4)

    # Show `history` units of past data plus the 0..6 test range. A fixed (-2, 7)
    # window would hide almost all of the observed data.
    plt.xlim(-history - 1, 7)
    plt.xlabel('day (0 = first test input)')

In [None]:
# Training data: one point per day, indexed so the most recent day is X = -1
# and the forecast inputs start at X = 0.
history = len(cases)
N = history
X = torch.arange(-history, 0, 1, dtype=def_type.dtype)
# `cases` is a reversed (negative-stride) NumPy view, which torch will not wrap; copy first.
y = torch.as_tensor(cases.copy(), dtype=def_type.dtype)
assert X.shape == y.shape
plot(plot_observed_data=True, history=history)

In [None]:
# RBF kernel with hand-picked initial hyperparameters. Every tensor uses the
# notebook's float dtype (`def_type.dtype`), matching X and y; the lengthscale
# must not be float64 (`torch.double`) while the other parameters are float32.
kernel = gp.kernels.RBF(input_dim=1, variance=torch.tensor(5., dtype=def_type.dtype),
                        lengthscale=torch.tensor(10., dtype=def_type.dtype))
gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(1., dtype=def_type.dtype))

In [None]:
# two draws from the GP prior defined by `kernel` plus the model's noise
plot(kernel=kernel, model=gpr, n_prior_samples=2)

In [None]:
def train_gp(model, num_steps=2500, lr=0.005):
    """Optimise ``model``'s parameters with Adam on the Trace_ELBO loss.

    Args:
        model: a Pyro GP model exposing ``.model``, ``.guide`` and ``.parameters()``.
        num_steps: number of optimisation steps.
        lr: Adam learning rate.

    Returns:
        List with the loss of each step, computed before that step's update.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
    losses = []
    for _ in range(num_steps):
        optimizer.zero_grad()
        loss = loss_fn(model.model, model.guide)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


num_steps = 2500
losses = train_gp(gpr, num_steps=num_steps, lr=0.005)
plt.plot(losses);

In [None]:
# observed data together with the fitted model's predictions
plot(plot_observed_data=True, plot_predictions=True, model=gpr)

In [None]:
# Refit the RBF model from scratch, now with priors on its hyperparameters.
pyro.clear_param_store()
kernel = gp.kernels.RBF(
    input_dim=1,
    variance=torch.tensor(5., dtype=def_type.dtype),
    lengthscale=torch.tensor(10., dtype=def_type.dtype),
)
gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(1., dtype=def_type.dtype))

# LogNormal priors: both hyperparameters stay on the positive reals
for hyper_name in ("lengthscale", "variance"):
    setattr(gpr.kernel, hyper_name, pyro.nn.PyroSample(dist.LogNormal(0.0, 1.0)))

optimizer = torch.optim.Adam(gpr.parameters(), lr=0.005)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
num_steps = 2500
losses = []
for step in range(num_steps):
    optimizer.zero_grad()
    loss = loss_fn(gpr.model, gpr.guide)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
plt.plot(losses);

In [None]:
# data plus predictions of the model fitted with hyperparameter priors
plot(plot_predictions=True, plot_observed_data=True, model=gpr)

In [None]:
# Sum of a Matern-5/2 component and a Periodic component, fitted from scratch.
pyro.clear_param_store()
matern = gp.kernels.Matern52(
    input_dim=1,
    variance=torch.tensor(5., dtype=def_type.dtype),
    lengthscale=torch.tensor(10., dtype=def_type.dtype),
)
periodic = gp.kernels.Periodic(
    input_dim=1,
    variance=torch.tensor(5., dtype=def_type.dtype),
    lengthscale=torch.tensor(10., dtype=def_type.dtype),
    period=torch.tensor(10., dtype=def_type.dtype),
)
kernel = gp.kernels.Sum(matern, periodic)
gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(50., dtype=def_type.dtype))

# Gamma priors (concentration, rate): support on the positive reals
gamma_priors = (
    (gpr.kernel.kern0, "lengthscale", 1, 0.1),
    (gpr.kernel.kern0, "variance", 2, 0.1),
    (gpr.kernel.kern1, "lengthscale", 1, 0.1),
    (gpr.kernel.kern1, "variance", 2, 0.1),
    (gpr.kernel.kern1, "period", 7, 1),
)
for component, hyper_name, concentration, rate in gamma_priors:
    setattr(component, hyper_name, pyro.nn.PyroSample(dist.Gamma(concentration, rate)))

optimizer = torch.optim.Adam(gpr.parameters(), lr=0.005)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
num_steps = 2500
losses = []
for step in range(num_steps):
    optimizer.zero_grad()
    loss = loss_fn(gpr.model, gpr.guide)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
plt.plot(losses);
plot(plot_observed_data=True, plot_predictions=True, model=gpr)

In [None]:
# Predictive distribution over the next N_FORECAST days (X = 0 .. N_FORECAST - 1).
# - Inputs use the training dtype; `torch.arange(0, 7)` alone would be int64.
# - Gradients are not tracked.
# - Observation noise is included (noiseless=False), matching `plot`, because
#   `log_prob` is later evaluated on observed-style case counts.
N_FORECAST = 7
with torch.no_grad():
    mean, var = gpr(torch.arange(0, N_FORECAST, dtype=def_type.dtype),
                    full_cov=True, noiseless=False)
# `var` is the full N_FORECAST x N_FORECAST covariance matrix (full_cov=True)
pred_dist = dist.MultivariateNormal(mean, covariance_matrix=var)


In [None]:
(mean, var)  # predictive mean and full predictive covariance

In [None]:
# joint log-density of a flat scenario: 200 cases on each of the 7 forecast days
pred_dist.log_prob(torch.full((7,), 200.))