In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import bayesnewton
import matplotlib.cm as cm
import time
import objax
from tqdm.notebook import trange, tqdm

# Legacy MT19937 generator with a fixed seed; used later to draw test indices.
rng = np.random.RandomState(seed=123)

## Data

### Temperature

In [None]:
%%time

# Load the xarray "air_temperature" tutorial dataset; `.load()` reads it fully
# into memory. (Presumably fetched over the network on first use — confirm.)
ds = xr.tutorial.open_dataset("air_temperature").load()

In [None]:
# Inspect dimensions, coordinates and variables of the loaded dataset.
ds

In [None]:
# Air temperature at the first time step, diverging colour map.
ds["air"].isel(time=0).plot(cmap="RdBu_r")

In [None]:
# Keep only the first 50 time steps (all lat/lon points) to keep the model small.
subset_ds = ds.isel({"time": slice(None, 50)})
subset_ds

In [None]:
# First and last time step of the subset. Index -1 instead of a hard-coded 49,
# so this still shows the last step if the subset length above changes.
subset_ds.air.isel(time=0).plot()
plt.show()
subset_ds.air.isel(time=-1).plot()
plt.show()

In [None]:
# Standardise each time slice: subtract its spatial mean and divide by its
# spatial standard deviation. Reducing over the named ("lat", "lon") dims
# rather than positional axes keeps this correct if the dimension order changes.
mean = subset_ds.air.mean(dim=("lat", "lon"))
std = subset_ds.air.std(dim=("lat", "lon"))

dat = (subset_ds.air - mean) / std
dat.isel(time=0).plot()

In [None]:
# Standardised field at time index 9, for comparison with index 0 above.
dat.isel(time=9).plot()

## Dataset Dimensions

* Latitude 
* Longitude
* Total Spatial Dims
* Time
* Total Dims
* Total Variables

In [None]:
# Sizes of the standardised field `dat`.
nt = dat.sizes["time"]  # number of time steps

# spatial grid
n_lat = dat.sizes["lat"]
n_lon = dat.sizes["lon"]
n_latlon = n_lon * n_lat
n_coords = 2  # (lat, lon)

n_vars = 1  # only air temperature is modelled

In [None]:
# Total number of scalar values: time steps x grid points x variables.
n_dims = nt * n_latlon * n_vars

print(f"Total Dims: {n_dims:,}")

In [None]:
# NOTE(review): dead cell — `Y_obs` is never defined in this notebook; the
# number of data points `N` is computed in the "Sequential" section instead.
# N = Y_obs.shape[0] * Y_obs.shape[1] * Y_obs.shape[2]

# print(f"Num data points: {N:,}")

In [None]:
# Long-format table: one row per (time, lat, lon) with its standardised "air" value.
data = (
    dat
    .to_dataframe()
    .reset_index()
)
data.head()

### Train/Test Split

In [None]:
# create random test indices
# NOTE(review): superseded — `test_ind` is reassigned by the fold-based split in
# the "Sequential" section. The draw still advances `rng`, so removing it would
# change the indices produced by the later `rng.permutation(N)` call.
test_ind = rng.permutation(n_dims)[: n_dims // 10]


# data_train =

In [None]:
# NOTE(review): leftover, unused.
# nd = subset_ds.shape[0]

In [None]:
# binsize (delta t)
# NOTE(review): only used to rescale the heatmap of `Y` further down; the 1000
# looks carried over from another example — confirm it means anything for
# standardised temperatures.
binsize = 1000 / nt

In [None]:
# NOTE(review): leftover shape check; `time_stamp`, `lat`, `lon` are defined in the next cell.
# subset_ds.air.shape, time_stamp.shape, lat.shape, lon.shape

In [None]:
# Flat arrays, one entry per (time, lat, lon) row of `data`.
# copy=True so later in-place NaN masking of `Y` cannot write back into `data`.
Y = data["air"].to_numpy(copy=True)
# Timestamps as integer seconds since the Unix epoch. `Series.view` is
# deprecated in recent pandas; casting to second resolution gives the same ints.
# NOTE(review): the temporal kernel lengthscale is therefore in seconds — check
# the spacing of these stamps against `len_time` (rescaling may help).
time_stamp = data["time"].to_numpy().astype("datetime64[s]").astype(np.int64)
lat = data["lat"].to_numpy()
lon = data["lon"].to_numpy()
X = np.stack([time_stamp, lat, lon], axis=1)
Y.shape, X.shape

In [None]:
# Third column of X: longitude (columns are time_stamp, lat, lon).
X[..., 2]

In [None]:
# Rearrange the scattered (time, lat, lon) -> value data onto a space-time grid.
# Names suggest t = unique times, R_plot = spatial coords per time step,
# Y_obs_plot = gridded observations (TODO confirm shapes with the next cell).
t, R_plot, Y_obs_plot = bayesnewton.utils.create_spatiotemporal_grid(X, Y)

In [None]:
# Shapes of the gridded arrays returned above.
t.shape, R_plot.shape, Y_obs_plot.shape

#### Sequential

In [None]:
# Point-level inputs: first column is the timestamp, the rest are (lat, lon).
# NOTE: this rebinds `t`, replacing the gridded times from the cell above.
t, R = X[:, :1], X[:, 1:]
t.shape, R.shape, Y.shape

In [None]:
# `t` holds one timestamp per observation here, so `t.shape[0]` would be the
# number of data points; count the distinct timestamps instead.
Nt = np.unique(t).size
print("num time steps =", Nt)
N = Y_obs_plot.shape[0] * Y_obs_plot.shape[1] * Y_obs_plot.shape[2]
print("num data points =", N)

In [None]:
# Number of point-level observations.
Y.shape

In [None]:
# sort out the train/test split: 10 random folds, fold `fold` is held out.
fold = 0
# Local generator seeded exactly like the old `np.random.seed(99)` call, so the
# permutation is identical but the global NumPy RNG state is left untouched.
ind_shuffled = np.random.RandomState(99).permutation(N)
# array_split tolerates N not divisible by 10 (np.split would raise).
ind_split = np.array_split(ind_shuffled, 10)  # 10 random batches of data indices
test_ind = ind_split[fold]

# Mask a fresh copy of the observations instead of writing NaNs into an array
# that may share memory with `data`. Re-running this cell then reproduces the
# same split and a NaN-free Y_test.
Y_full = data["air"].to_numpy(copy=True)
X_test = X[test_ind]
Y_test = Y_full[test_ind]
Y = Y_full.copy()
Y[test_ind] = np.nan

#### Gridded

In [None]:
%%time

# Grid the training data (held-out points are NaN in Y) and, separately, the
# held-out test points, using the same helper as the "plot" grid above.
t_train, R_train, Y_obs_train = bayesnewton.utils.create_spatiotemporal_grid(X, Y)
t_test, R_test, Y_obs_test = bayesnewton.utils.create_spatiotemporal_grid(
    X_test, Y_test
)

In [None]:
# Sanity-check the training grid: one row per time step, every lat/lon point
# present at every step (the commented-out version referenced an undefined `n_d`).
assert t_train.shape == (nt, 1)
assert R_train.shape == (nt, n_latlon, n_coords)
assert Y_obs_train.shape == (nt, n_latlon, n_vars)

In [None]:
# Held-out index count vs. point-level coordinate array.
test_ind.shape, R.shape

### Sparse (Inducing) Points

In [None]:
# Inducing (sparse) points: a 7 x 7 lattice spanning the observed lat and lon ranges.
z1 = np.linspace(X[:, 1].min(), X[:, 1].max(), num=7)
z2 = np.linspace(X[:, 2].min(), X[:, 2].max(), num=7)

zA, zB = np.meshgrid(z1, z2)

z = np.column_stack([zA.ravel(), zB.ravel()])

In [None]:
# Inducing points vs. a single point-level coordinate row.
z.shape, R[0, ...].shape

In [None]:
var_f = 1.0  # GP variance
len_f = 1.0  # lengthscale  NOTE(review): unused below
len_time = 1  # temporal lengthscale, in the units of X[:, 0] (seconds, via // 10**9 — confirm)
len_space = 1  # spatial lengthscale, in degrees of lat/lon
sparse = True  # use the inducing points `z` for the spatial dimension
opt_z = True  # let the optimiser move the inducing points

In [None]:
# kernel
# Matern-3/2 in time; a separable product of two Matern-3/2 kernels in space
# (one per spatial coordinate, sharing the same initial variance/lengthscale).
kern_time = bayesnewton.kernels.Matern32(variance=var_f, lengthscale=len_time)
kern_space0 = bayesnewton.kernels.Matern32(variance=var_f, lengthscale=len_space)
kern_space1 = bayesnewton.kernels.Matern32(variance=var_f, lengthscale=len_space)
kern_space = bayesnewton.kernels.Separable([kern_space0, kern_space1])

In [None]:
# Combine the temporal and spatial kernels; with sparse=True the spatial part
# is evaluated at the inducing points `z` (optimised when opt_z=True).
kern = bayesnewton.kernels.SpatioTemporalKernel(
    temporal_kernel=kern_time,
    spatial_kernel=kern_space,
    z=z,
    sparse=sparse,
    opt_z=opt_z,
    conditional="Full",
)

In [None]:
%%time

# likelihood
lik = bayesnewton.likelihoods.Gaussian(variance=1)

# model: variational Markov GP on the gridded training data (sequential filtering).
model = bayesnewton.models.MarkovVariationalGP(
    kernel=kern, likelihood=lik, X=t_train, R=R_train, Y=Y_obs_train, parallel=False
)
# NOTE(review): the alternatives below pass `Y_obs`, which is not defined in this notebook.
# model = bayesnewton.models.MarkovVariationalMeanFieldGP(kernel=kern, likelihood=lik, X=t, R=R, Y=Y_obs, parallel=False)
# model = bayesnewton.models.SparseMarkovGaussianProcess(kernel=kern, likelihood=lik, X=t, R=R, Y=Y_obs, Z=z)
# model = bayesnewton.models.SparseMarkovMeanFieldGaussianProcess(kernel=kern, likelihood=lik, X=t, R=R, Y=Y_obs)

In [None]:
lr_adam = 0.1  # Adam step size for the hyperparameters
lr_newton = 1.0  # step size passed to model.inference
iters = 1_000  # number of training iterations
opt_hypers = objax.optimizer.Adam(model.vars())

In [None]:
# Callable returning (gradients w.r.t. model.vars(), value) of model.energy.
energy = objax.GradValues(model.energy, model.vars())

In [None]:
@objax.Function.with_vars(model.vars() + opt_hypers.vars())
def train_op():
    """One training step.

    Runs one inference update of the variational parameters (step size
    `lr_newton`), then one Adam step (step size `lr_adam`) on the
    hyperparameters. Returns the energy value(s) from `energy()`, which are
    computed before the Adam step.
    """
    model.inference(lr=lr_newton)  # perform inference and update variational params
    dE, E = energy()  # compute energy and its gradients w.r.t. hypers
    opt_hypers(lr_adam, dE)
    # NOTE(review): `r_test` is only defined much later in the notebook.
    # test_nlpd_ = model.negative_log_predictive_density(X=t_test, R=r_test, Y=Y_test)
    return E


# JIT-compile the step; the variables it touches are declared by the decorator.
train_op = objax.Jit(train_op)

In [None]:
# perf_counter is monotonic and intended for measuring durations; time.time()
# follows the wall clock and can jump.
t0 = time.perf_counter()

losses = []

with trange(1, iters + 1) as pbar:
    for i in pbar:
        loss = train_op()  # presumably a length-1 sequence holding the energy
        energy_i = np.array(loss[0])  # convert once; reused for history and display
        losses.append(energy_i)
        pbar.set_description(f"iter {i:d}, energy: {float(energy_i):1.4f}")

t1 = time.perf_counter()
print("optimisation time: %2.2f secs" % (t1 - t0))

In [None]:
fig, ax = plt.subplots()

# Use the Axes API throughout (not the pyplot state machine) and label the plot.
ax.plot(losses, label="NLL Loss (Energy)")
ax.set(xlabel="iteration", ylabel="energy", title="Training objective")
ax.legend()
plt.show()

In [None]:
N_test = 50  # resolution of the prediction grid along each spatial axis

# Padded lat/lon ranges (10% in lat, 5% in lon) around the observed coordinates.
lat_min, lat_max = X[:, 1].min(), X[:, 1].max()
lon_min, lon_max = X[:, 2].min(), X[:, 2].max()
X1range = lat_max - lat_min
X2range = lon_max - lon_min
r1 = np.linspace(lat_min - 0.1 * X1range, lat_max + 0.1 * X1range, num=N_test)
r2 = np.linspace(lon_min - 0.05 * X2range, lon_max + 0.05 * X2range, num=N_test)
rA, rB = np.meshgrid(r1, r2)
r = np.hstack((rA.reshape(-1, 1), rB.reshape(-1, 1)))  # Flattening grid for use in kernel functions
# One copy of the spatial grid per *time step*. `t` is point-level at this
# stage (one row per observation), so tiling by t.shape[0] would build an
# array tens of thousands of grids deep.
Rplot = np.tile(r, [nt, 1, 1])

In [None]:
# Prediction grid vs. point-level times.
Rplot.shape, t.shape

In [None]:
%%time

# Posterior mean/variance at the training grid locations (including masked points).
posterior_mean, posterior_var = model.predict(X=t_train, R=R_train)

In [None]:
# Shape of the posterior mean returned by model.predict.
posterior_mean.shape

In [None]:
# mu = bayesnewton.utils.transpose(posterior_mean.reshape(-1, N_test, N_test))
# Reshape the flat per-time predictions / gridded observations to
# (time, n_lat, n_lon) images; bayesnewton.utils.transpose presumably swaps the
# last two axes (TODO confirm).
mu = bayesnewton.utils.transpose(posterior_mean.reshape(-1, n_lat, n_lon))
mu_real = bayesnewton.utils.transpose(Y_obs_train.reshape(-1, n_lat, n_lon))

In [None]:
# Image stack vs. the gridded training observations.
mu.shape, Y_obs_train.shape

In [None]:
for i in range(len(mu)):
    # Shared colour limits so the three panels are directly comparable
    # (previously only the third panel used the posterior's range).
    vmin, vmax = float(mu[i].min()), float(mu[i].max())

    fig, axes = plt.subplots(ncols=3, figsize=(15, 5))
    # `.T[::-1]` undoes the transpose and flips the lat axis for display.
    # Masked (held-out) points are NaN in the observations and show as blank.
    axes[0].imshow(mu_real[i].T[::-1], cmap="RdBu_r", aspect="auto", vmin=vmin, vmax=vmax)
    axes[0].set(title=f"Training observations (t={i})")
    axes[1].imshow(mu[i].T[::-1], cmap="RdBu_r", aspect="auto", vmin=vmin, vmax=vmax)
    axes[1].set(title="Posterior mean")
    dat.isel(time=i).plot(ax=axes[2], vmin=vmin, vmax=vmax, cmap="RdBu_r")
    axes[2].set(title="Standardised data (all points)")

    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop

### Flatten

In [None]:
# Flatten the *gridded* training arrays so time, coordinates and observations
# share one row order (time-major, then the grid's spatial order).
# `Y_obs` was never defined (NameError), and flattening the point-level `R`
# interleaved lat and lon values, so rows could not be indexed together.
# Assumes t_train is (nt, 1) and R_train is (nt, n_points, n_coords) — see the
# shape asserts after the grids are built.
t_flat = np.repeat(np.asarray(t_train)[:, 0], np.asarray(R_train).shape[1])
R_flat = np.asarray(R_train).reshape(-1, n_coords)
Y_flat = np.asarray(Y_obs_train).flatten()

In [None]:
# Flattened length vs. number of time steps (and a tenth of it).
t_flat.shape, nt, nt / 10

In [None]:
# `Y_obs` is not defined anywhere; the gridded training observations are `Y_obs_train`.
test_ind.shape, Y_obs_train.shape

In [None]:
# Flattened observations vs. current held-out index count.
Y_flat.shape, test_ind.shape

In [None]:
# create random test indices (this advances the shared `rng`)
# NOTE(review): rebinds `test_ind`, `t_test` and `Y_test` from the earlier fold
# split. `Y_flat` may already contain NaNs from that split, so some `Y_test`
# entries here can be NaN.
test_ind = rng.permutation(N)[: N // 10]

# subset data
t_test = t_flat[test_ind]
r_test = R_flat[test_ind]
Y_test = Y_flat[test_ind]

# Mask a copy so `Y_flat` keeps its values and re-running this cell does not
# compound the masking.
Y_masked = Y_flat.copy()
Y_masked[test_ind] = np.nan

# (time, flattened space) view of the masked observations.
Y = Y_masked.reshape(nt, n_latlon)

In [None]:
# Unflatten space: (time, lat, lon) image stack.
Y_img = np.reshape(Y, (nt, n_lat, n_lon))

In [None]:
fig, ax = plt.subplots()

# Flip the lat axis for display (assumes rows are lat-ascending — TODO confirm).
ax.imshow(Y_img[0, ::-1, :])
ax.set(title="Y (test points masked), first time step", xlabel="lon index", ylabel="lat index")

plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))
# Rows are time steps, columns are flattened spatial points. Plot the values
# directly: dividing by `binsize` has no meaning for standardised temperatures,
# and the old hard-coded `extent` mislabelled both axes.
im = ax.imshow(Y, aspect="auto", interpolation="nearest", cmap=cm.viridis)
ax.set(
    title="Temperature Data (test points masked)",
    xlabel="spatial point index",
    ylabel="time step",
)
plt.colorbar(im, fraction=0.0235, pad=0.04)
plt.show()

In [None]:
# `T`, `R1` and `R2` are never defined in this notebook (NameError on Run All);
# show the shapes of the gridded training arrays instead.
t_train.shape, R_train.shape, Y_obs_train.shape

In [None]:
# Full (un-subset, un-standardised) temperature field and its coordinates.
# The time coordinate is stored as `time_coord`, not `time`: assigning to
# `time` would shadow the `time` module used to time the training loop.
Y = ds.air
time_coord = ds.coords["time"].values
time_steps = np.arange(time_coord.shape[0])
lat = ds.coords["lat"].values
lon = ds.coords["lon"].values

In [None]:
# Lengths of the three coordinate axes (they differ, so they cannot be stacked directly).
time_steps.shape, lat.shape, lon.shape

In [None]:
# One row per (time_step, lat, lon) grid point. `np.vstack` of three 1-D arrays
# with different lengths raised a ValueError. With indexing="ij" the rows follow
# C order, matching `Y.values.ravel()` if `ds.air` dims are (time, lat, lon).
X = np.stack(
    [g.ravel() for g in np.meshgrid(time_steps, lat, lon, indexing="ij")],
    axis=1,
)

print(X.shape, Y.shape)

#### Viz - Basemap

In [None]:
# Map of the full dataset at the first time index.
ds.air[0].plot()

#### Viz - Time Series


In [None]:
# Time series at the first (lat, lon) grid point.
ds.air[:, 0, 0].plot()

### ROMS Model

In [None]:
%%time

# Second example: the xarray "ROMS_example" tutorial dataset.
# NOTE: this rebinds `ds`; cells above that read the air-temperature `ds`
# will see the ROMS data if re-run after this one.
ds = xr.tutorial.open_dataset("ROMS_example").load()

In [None]:
# Inspect the ROMS dataset's dimensions, coordinates and variables.
ds

#### Viz - Gridded

In [None]:
# The ROMS horizontal coordinates `xc` and `yc`, side by side.
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(14, 4))
ds.xc.plot(ax=ax1)
ds.yc.plot(ax=ax2)

In [None]:
# `Tair` at the first index of its leading dimension.
ds.Tair[0].plot()