## Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from adaptive_latents.vjf import VJF, BaseVJF
import vjf.online
import torch
from tqdm.notebook import trange
from adaptive_latents.input_sources import LDS

## Set seed(s)

In [None]:
# One seed shared by every random number generator used in this notebook.
seed = 0

# NumPy: explicit Generator, passed as `rng=` to the data source and models.
rng = np.random.default_rng(seed)

# PyTorch: seed the CPU generator and the CUDA generators.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# cuDNN: ask for deterministic kernels and turn off the benchmarking mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Generate data

In [None]:
%matplotlib inline

x, y, stim = LDS.run_nest_dynamical_system(500, rng=rng)

u = 0*stim

xdim = 2
udim = u.shape[-1]
ydim = y.shape[-1]



In [None]:
# Hyperparameters handed to the vjf model constructors below.
# NOTE(review): 'udim' is hard-coded to 1 while 'B', 'Ydim' and 'Udim' use the
# `udim` variable derived from the data; these only agree when udim == 1 — confirm.
# NOTE(review): 'Ydim' is filled from `udim`, not `ydim` — confirm this is intended.
config = {
    'resume': False,
    'xdim': xdim,  # dimension of hidden state
    'ydim': ydim,  # dimension of observations
    'udim': 1,  # dimension of control vector
    'Ydim': udim,
    'Udim': udim,
    'rdim': 50,  # number of RBFs
    'hdim': 100,  # number of MLP hidden units
    'lr': 1e-3,  # learning rate
    'clip_gradients': 5.0,
    'debug': False,
    'likelihood': 'gaussian',
    'system': 'rbf',
    'recognizer': 'mlp',
    # The entries below are (initial value, whether to estimate it) pairs.
    'C': (None, True),  # loading matrix
    'b': (None, True),  # bias
    'A': (None, False),  # transition matrix if LDS
    'B': (np.zeros((xdim, udim)), False),  # interaction matrix
    'Q': (1.0, True),  # state noise
    'R': (1.0, True),  # observation noise
}



## Minimal run

In [None]:
# Run the reference vjf model online: feed one (observation, control) pair at a
# time and record the first element of the returned state `q` for every step.
mdl = vjf.online.VJF(config=config)

ys = torch.from_numpy(y).float()
us = torch.from_numpy(u).float()

n_steps = ys.shape[0]
mu = torch.zeros(n_steps, xdim)
q = None  # current state; None lets the model start from its own default

for i in range(n_steps):
    # Slices (rather than integer indices) keep the leading length-1 axis.
    q, _ = mdl.feed((ys[i:i+1], us[i:i+1]), q0=q)
    q_mean, _ = q
    # Store a detached copy so `mu` does not accumulate the autograd history
    # of every step (the plotting cell only needs the values).
    mu[i, :] = q_mean.detach()

In [None]:
# Trajectory of the last 200 recorded latent means, dimension 0 against dimension 1.
recent_mu = mu[-200:].detach().numpy()

fig, ax = plt.subplots()
ax.plot(recent_mu[:, 0], recent_mu[:, 1])
ax.axis('equal');

In [None]:
# Same observations through the adaptive_latents VJF wrapper, run offline over
# the whole sequence; `ret` is plotted below as the estimated latent state.
v = VJF(
    input_streams={0: 'X'},
    latent_d=2,
    rng=rng,
)
ret = v.offline_run_on([y], show_tqdm=True)


In [None]:
%matplotlib inline
fig, axs = plt.subplots(ncols=2, figsize=(10, 4))

axs[0].plot(ret[-200:,0], ret[-200:,1])
axs[0].set_title('VJF estimated latent state');
axs[0].axis('equal');

axs[1].plot(x[-200:, 0], x[-200:, 1])
axs[1].set_title('true latent state');
axs[1].axis('equal');


In [None]:
%matplotlib inline
cloud = v.get_cloud_at_time_t(0).detach()
preds = v._vjf.decoder(cloud).detach().numpy()


n = 31
x_edges = np.linspace(-20,20, n)
y_edges = np.linspace(-20,20, n)
x_centers = np.convolve([.5,.5], x_edges, mode='valid')
y_centers = np.convolve([.5,.5], y_edges, mode='valid')
log_probs = np.zeros((len(y_centers), len(x_centers)))
for i, y_i in enumerate(y_centers):
    for j, x_j in enumerate(x_centers):
        log_probs[i,j] = v.get_logprob_for_cloud(cloud=cloud, point=np.array([x_j,y_i,0]))
i,j = np.unravel_index(np.argmax(log_probs), log_probs.shape)


fig, axs = plt.subplots(ncols=2, figsize=(10, 4), sharey=True, sharex=True, subplot_kw={'adjustable': 'box', 'aspect':1})

axs[0].plot(x[-200:, 0], x[-200:, 1])
axs[0].scatter(preds[:, 0], preds[:, 1], color='C1', s=5, zorder=3)
axs[0].scatter(preds[:, 0].mean(), preds[:, 1].mean(), color='C2', zorder=3)

axs[1].pcolormesh(x_edges,y_edges,log_probs, vmin=np.quantile(log_probs.flatten(), .5), vmax=log_probs.max(), cmap='plasma')
axs[1].scatter(x_centers[j], y_centers[i], color='C2')
    

In [None]:
# NOTE(review): 'asdf' looks like a placeholder method name — presumably this
# call probes how `predict` handles an unrecognised method; confirm.
v.predict(0, method='asdf')

In [None]:

# Fresh wrapper instance fitted on everything except the last 15 observations.
# This rebinds `v` and `mu`, replacing the objects created in earlier cells.
v = VJF(latent_d=2, rng=rng)
training_y = y[:-15]
mu, logvar, losses = v.fit(y=training_y)


In [None]:
# Trajectory of the `mu` returned by `v.fit`: column 0 against column 1.
fig, ax = plt.subplots()
mu_first, mu_second = mu[:, 0], mu[:, 1]
ax.plot(mu_first, mu_second)
ax.axis('equal');


## Run with log_pred_p evaluation

In [None]:
# Device setting for the evaluation run; 'cuda' does not work.
# Nothing in the cells visible here reads this variable.
device = 'cpu'

In [None]:

def log_step(mdl: "BaseVJF", ys, t, S=1000, T=10):
    """Score the model's predictions for the ``T`` observations starting at ``t``.

    A cloud is obtained from ``mdl.generate_cloud()`` and advanced ``T`` times
    with ``mdl.step_for_cloud``. After the ``i``-th advance (``i = 0 .. T-1``)
    the cloud is scored against the observation ``ys[t + i]``.

    Args:
        mdl: Model providing ``generate_cloud``, ``step_for_cloud``,
            ``get_logprob_for_cloud`` and ``get_distance_for_cloud``.
        ys: Observations indexed by time along the first axis; each entry must
            support ``.cpu().numpy()`` (e.g. a torch tensor).
        t: Index of the observation the first prediction is compared against.
        S: Unused; kept so existing calls that pass it keep working.
        T: Number of prediction steps to evaluate.

    Returns:
        A ``(logprobs, distances)`` pair of lists, each of length ``T``, holding
        the values returned by ``mdl.get_logprob_for_cloud`` and
        ``mdl.get_distance_for_cloud`` for every step.
    """
    n_obs = ys.shape[0]
    cloud = mdl.generate_cloud()

    logprobs = []
    distances = []
    for horizon in range(T):
        target_index = t + horizon
        if target_index < n_obs:
            target = ys[target_index].cpu().numpy()
        else:
            # Past the end of the data: use an all-NaN target of the same shape
            # so the output lists still have one entry per step.
            target = ys[t].cpu().numpy() * np.nan

        cloud = mdl.step_for_cloud(cloud)
        logprobs.append(mdl.get_logprob_for_cloud(cloud, target))
        distances.append(mdl.get_distance_for_cloud(cloud, target))

    return logprobs, distances


In [None]:
# Online run of BaseVJF that scores each prediction before the model is shown
# the corresponding observation.
mdl = BaseVJF(config=config, latent_d=2)
mdl.init_vjf(ydim, udim)

ys = torch.from_numpy(y).float()
us = torch.from_numpy(u).float()

n_steps = ys.shape[0]
# Allocated but never written in this cell, so it stays all zeros.
mu2 = np.zeros((n_steps, xdim))

logprobs, distances = [], []
for t in trange(n_steps):
    # Evaluate first (T=1: a single step ahead) ...
    step_logprobs, step_distances = log_step(mdl, ys, t, T=1)
    logprobs.append(step_logprobs)
    distances.append(step_distances)

    # ... then let the model observe time step t, with a leading length-1 axis.
    y_t = ys[t].unsqueeze(0)
    u_t = us[t].unsqueeze(0)
    mdl.observe(y_t, u_t)

# One row per time step, one column per prediction step.
logprobs = np.array(logprobs)
distances = np.array(distances)

In [None]:
# Change in the observations over a lag of `i` samples (shown as the cell output).
i = 10
later, earlier = y[i:], y[:len(y) - i]
later - earlier

In [None]:
%matplotlib qt
fig, axs = plt.subplots(nrows=2)

for i in range(logprobs.shape[-1]):
    axs[0].plot(logprobs[:, i], label=f"{i+1} step{'s' if i > 0 else ''} ahead")
    axs[1].plot(distances[:, i], label=f"{i+1} step{'s' if i > 0 else ''} ahead")
    
for ax in axs:
    ax.set_xlabel("time")

axs[0].set_ylabel("log probability")
axs[1].set_ylabel("average prediction distance")
axs[1].legend(bbox_to_anchor=(1.01, 0.95))
# plt.ylim([-300, 0])

In [None]:
# Checks that the randomness is actually controlled by `seed`.
# This method is obviously hacky, but I'm keeping it because the logprobs
# and distances were inconsistent between runs before, despite seeding.
#
# Each listed global is compared with the copy saved by the previous run and
# then saved again, so the printed result always refers to the previous run.
# Loop-local names are chosen so the cell neither rebinds the notebook's `v`
# (the VJF model) nor shadows the builtin `vars`.
checked_values = {}
for var_name in ['seed', 'y', 'mu', 'ys', 'mu2', 'distances', 'logprobs']:
    value = globals()[var_name]
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
    checked_values[var_name] = value

    cache_stem = f'/tmp/asdf_{var_name}'
    try:
        previous = np.load(f"{cache_stem}.npy")
    except FileNotFoundError:
        previous = None
    np.save(cache_stem, value)  # np.save appends the '.npy' suffix itself

    if previous is None:
        print(f'{var_name}: NEW')
    else:
        # Exact match of shape and values; NaNs count as equal only when they
        # sit at the same positions in both runs.
        same = np.array_equal(value, previous, equal_nan=True)
        print(f'{var_name}: {same}')
