# Demo - 2D Data

In [None]:
import pandas as pd
import numpy as np
from bayesnewton.utils import discretegrid
import bayesnewton
import objax
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import time
from tqdm.notebook import trange, tqdm

# import tikzplotlib

# Seeded (legacy) NumPy generator so that the random train/test split drawn
# from it is reproducible between runs.
rng = np.random.RandomState(seed=123)

In [None]:
# Raw observations, hosted as a CSV in the BayesNewton GitHub repository.
url = "https://raw.githubusercontent.com/AaltoML/BayesNewton/main/data/TRI2TU-data.csv"

# The file is read with header=None, i.e. every row is treated as data.
# (np.loadtxt(url, delimiter=",") would be a pandas-free alternative.)
data = pd.read_csv(url, header=None)

# Sanity check: the first ten rows, then per-column summary statistics.
print(data.head(10), data.describe(), sep="\n")

# Later cells index `data` positionally, so keep only the underlying array.
data = data.to_numpy()

In [None]:
# Split the 2-D array shape into row count (observations) and column count.
num_data, num_dim = np.shape(data)
print(f"Num Data: {num_data:,}", f"Num Dims: {num_dim}", sep="\n")

In [None]:
# Display the first ten rows of the raw array.
data[:10]

In [None]:
# Grid resolution: `nx` bins for the spatial coordinate and `nt` bins for the
# coordinate treated as "time". `ny` is defined here but not used in this cell.
nx, ny, nt = 50, 30, 100

# Bin width along the temporal coordinate: 1000 units split into nt bins.
binsize = 1000 / nt

# Total number of cells in the nt-by-nx grid.
N = nx * nt

# Report the settings, one per line.
print(
    f"Grid Points (x): {nx}",
    f"Grid Points (t): {nt}",
    f"Delta t: {int(binsize)}",
    f"Data Points (Grid): {N:,}",
    sep="\n",
)

In [None]:
# Bin the observations over the window [0, 1000] x [0, 500] into an nt-by-nx
# grid. The data are two-dimensional, so exactly one bin count per axis is
# passed (per the BayesNewton docstring, discretegrid expects a two-element
# vector; a third entry would not be used).
t, r, Y_ = bayesnewton.utils.discretegrid(data, [0, 1000, 0, 500], [nt, nx])
# Display the shapes of the two coordinate grids and the binned values.
t.shape, r.shape, Y_.shape

In [None]:
# Display the second array returned by discretegrid (the `r` coordinate grid).
r

### Grids

In [None]:
# Grid resolution: nr bins along the spatial axis, nt bins along the axis that
# the model treats as "time".
nr, nt = 50, 100

# Bin width along the temporal axis: 1000 units divided into nt bins.
binsize = 1000 / nt

# Number of cells in the complete nt-by-nr grid.
N = nt * nr

In [None]:
# Bin the observations over the window [0, 1000] x [0, 500] using nt x nr bins.
# NOTE(review): presumably t and r are the coordinate grids and Y_ the point
# count per cell — confirm against bayesnewton.utils.discretegrid.
t, r, Y_ = bayesnewton.utils.discretegrid(data, [0, 1000, 0, 500], [nt, nr])

In [None]:
# Display the shapes of the coordinate grids and the binned values.
t.shape, r.shape, Y_.shape

In [None]:
# Flatten each grid to 1-D so that a cell can be addressed by a single index.
# NOTE: ndarray.flatten() returns a copy, so modifying the flat arrays later
# leaves t, r and Y_ untouched (assuming these are NumPy arrays — TODO confirm).
t_flat = t.flatten()
r_flat = r.flatten()
Y_flat = Y_.flatten()

# Display the flattened shapes.
t_flat.shape, r_flat.shape, Y_flat.shape

### Viz - Tree Locations

In [None]:
# Scatter the raw observations: column 0 horizontally, column 1 vertically,
# drawn as small black dots.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(data[:, 0], data[:, 1], "k.", markersize=2)

# Fix the view to the full 1000 x 500 observation window.
ax.set_title("Tree Locations (Observations)")
ax.set_xlim(0, 1_000)
ax.set_ylim(0, 500)

plt.show()

### Viz - Full Count Grid

In [None]:
# Heat map of the binned values divided by the bin width.
fig, ax = plt.subplots(figsize=(10, 5))

# Transposed so that the first grid axis (t) runs horizontally.
# NOTE(review): with the default origin ("upper") row 0 is drawn at the top,
# while `extent` labels the bottom edge as 0 — confirm that the vertical
# orientation matches the scatter plot of the raw locations.
im = ax.imshow(Y_.T / binsize, cmap=cm.viridis, extent=[0, 1_000, 0, 500])
ax.set_title("Tree Count Data (Full)")

fig.colorbar(im, ax=ax, fraction=0.0235, pad=0.04)
plt.show()

## Train-Test Split

In [None]:
# create random test indices
test_ind = rng.permutation(N)[: N // 10]

# subset data
t_test = t_flat[test_ind]
r_test = r_flat[test_ind]
Y_test = Y_flat[test_ind]

Y_flat[test_ind] = np.nan

Y = Y_flat.reshape(nt, nr)

In [None]:
# Heat map of the training grid; the held-out test cells were set to NaN above.
fig, ax = plt.subplots(figsize=(10, 5))
# Y has shape (nt, nr). Transpose it, exactly like the full-data plot does with
# Y_, so that the t axis runs horizontally and matches `extent`.
im = ax.imshow(Y.T / binsize, extent=[0, 1_000, 0, 500], cmap=cm.viridis)
# This is the masked training data, not the full grid, so label it as such.
ax.set(title="Tree Count Data (Train)")
plt.colorbar(im, fraction=0.0235, pad=0.04)
plt.show()

In [None]:
# put test points on a grid to speed up prediction
X_test = np.concatenate([t_test[:, None], r_test[:, None]], axis=1)
print(X_test.shape, Y_test.shape)

In [None]:
# X_test

In [None]:
# Pass the held-out points through create_spatiotemporal_grid and rebind
# t_test, r_test and Y_test to its three outputs.
# NOTE(review): presumably the outputs are gridded (time, space) arrays in the
# layout the model expects — confirm against the BayesNewton utility.
t_test, r_test, Y_test = bayesnewton.utils.create_spatiotemporal_grid(X_test, Y_test)
# Display the resulting shapes.
t_test.shape, r_test.shape, Y_test.shape

In [None]:
# Compare the shape of the training time grid with that of the test time grid.
t.shape, t_test.shape

In [None]:
# Display the shapes of the training inputs (t, r) and the masked targets Y.
t.shape, r.shape, Y.shape

## Model

In [None]:
# Initial kernel hyperparameters: GP variance and lengthscale.
var_f, len_f = 1.0, 20.0

# Poisson likelihood for the binned values, told the bin width.
# Alternative: bayesnewton.likelihoods.Gaussian(variance=1)
lik = bayesnewton.likelihoods.Poisson(binsize=binsize)

# Matern-3/2 spatial kernel; `z` is the first row of the r grid.
# NOTE(review): presumably `z` holds the spatial inducing locations and
# `sparse` toggles the sparse approximation — confirm in the BayesNewton docs.
# Alternative: the same call with sparse=False.
kern = bayesnewton.kernels.SpatialMatern32(
    variance=var_f,
    lengthscale=len_f,
    z=r[0, ...],
    sparse=True,
)

# Model over the gridded inputs (t, r) with the NaN-masked targets Y.
# Alternatives tried by the original author:
#   bayesnewton.models.VariationalGP(kernel=kern, likelihood=lik, X=x, Y=Y)
#   bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=t_flat, R=r_flat, Y=Y_flat)
#   bayesnewton.models.InfiniteHorizonVariationalGP(kernel=kern, likelihood=lik, X=t, R=r, Y=Y)
#   bayesnewton.models.MarkovVariationalGPMeanField(kernel=kern, likelihood=lik, X=t, R=r, Y=Y)
model = bayesnewton.models.MarkovVariationalGP(
    kernel=kern,
    likelihood=lik,
    X=t,
    R=r,
    Y=Y,
)

## Training

### Optimizer

In [None]:
# Step sizes: `lr_adam` is handed to the Adam optimiser, `lr_newton` to the
# model's inference update. Both start at the same value.
lr_adam = lr_newton = 0.2

# Number of training iterations.
iters = 10

# Adam optimiser over every variable exposed by the model.
opt_hypers = objax.optimizer.Adam(model.vars())

### Loss Function

In [None]:
# Callable that evaluates model.energy together with its gradients with respect
# to the model variables; the training step unpacks its result as (gradients, value).
energy = objax.GradValues(model.energy, model.vars())

### Training Loop

In [None]:
@objax.Function.with_vars(model.vars() + opt_hypers.vars())
def train_op():
    """Run a single training iteration and return its diagnostics.

    The steps run in a fixed order: an inference update on the model, an
    evaluation of the energy and its gradients, an Adam step using those
    gradients, and finally a score on the held-out test data.

    Returns:
        A pair ``(energy_value, held_out_nlpd)``: the second element returned
        by ``energy()`` and the model's negative log predictive density on
        ``(t_test, r_test, Y_test)``.
    """
    # Inference comes first: it updates the variational parameters.
    model.inference(lr=lr_newton)

    # Energy and its gradients, then one Adam step on those gradients.
    grads, energy_value = energy()
    opt_hypers(lr_adam, grads)

    # Monitor generalisation on the held-out cells after the update.
    held_out_nlpd = model.negative_log_predictive_density(
        X=t_test, R=r_test, Y=Y_test
    )
    return energy_value, held_out_nlpd


# Replace the Python function with its objax.Jit-wrapped version.
train_op = objax.Jit(train_op)

### Training

In [None]:
# Run the training step `iters` times behind a progress bar, logging the
# energy and the test NLPD after every iteration, and time the whole loop.
t0 = time.time()
for i in tqdm(range(1, iters + 1)):
    loss, test_nlpd = train_op()
    # `loss` is indexable; its first entry is the energy that gets logged.
    print("iter %2d, energy: %1.4f, nlpd: %1.4f" % (i, loss[0], test_nlpd))
t1 = time.time()

print("optimisation time: %2.2f secs" % (t1 - t0))

## Results

### Loss Function

### NLPD

In [None]:
%%time

# Posterior predictive distribution at the training grid, plus the NLPD on the
# held-out data. The original author's note: prediction runs via filtering and
# smoothing at the train & test locations.
print("calculating the posterior predictive distribution ...")

t0 = time.time()
posterior_mean, posterior_var = model.predict(X=t, R=r)
# Alternative (predictions in observation space):
# posterior_mean_y, posterior_var_y = model.predict_y(X=t, R=r)
nlpd = model.negative_log_predictive_density(X=t_test, R=r_test, Y=Y_test)
t1 = time.time()

# Report the wall-clock time and the score, one per line.
print("prediction time: %2.2f secs" % (t1 - t0), "nlpd: %2.3f" % nlpd, sep="\n")

# Keep the likelihood's link function; the result plot applies it to the
# posterior mean.
link_fn = lik.link_fn

### Viz - Results

In [None]:
# Reference picture for the model output: the full binned values divided by
# the bin width, with the first grid axis (t) running horizontally.
fig, ax = plt.subplots(figsize=(10, 5))
im = ax.imshow(Y_.T / binsize, cmap=cm.viridis, extent=[0, 1000, 0, 500])
ax.set_title("Tree Count Data (Full)")
fig.colorbar(im, ax=ax, fraction=0.0235, pad=0.04)
plt.show()

In [None]:
# Posterior mean pushed through the likelihood's link function, drawn on the
# same 1000 x 500 window as the count-data plots.
fig, ax = plt.subplots(figsize=(10, 5))

# Draw on the axes created above rather than through the pyplot state machine,
# matching the other plot cells. Reversing the rows and using origin="lower"
# gives the same picture as the un-reversed array with the default origin, so
# the orientation agrees with the count-data heat maps.
im = ax.imshow(
    link_fn(posterior_mean).T[::-1, :],
    cmap=cm.viridis,
    extent=[0, 1000, 0, 500],
    origin="lower",
)
# Alternative: plot the observation-space mean instead,
# im = ax.imshow(posterior_mean_y.T, cmap=cm.viridis, extent=[0, 1000, 0, 500], origin="lower")
fig.colorbar(im, ax=ax, fraction=0.0235, pad=0.04)

# Alternative title: "2D log-Gaussian Cox process (rainforest tree data). Log-intensity shown."
ax.set(
    xlim=(0, 1000),
    ylim=(0, 500),
    title="2D log-Gaussian Cox process (rainforest tree data). Tree intensity per $m^2$.",
    xlabel="first spatial dimension, $t$ (metres)",
    ylabel="second spatial dimension, $r$ (metres)",
)

# Render explicitly, like every other plot cell in this notebook.
plt.show()