## Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import multivariate_normal
from scipy.special import logsumexp
from tqdm.notebook import trange
import torch
from vjf.online import VJF
from scipy.stats import special_ortho_group
from adaptive_latents.transformer import DecoupledTransformer
from adaptive_latents import ArrayWithTime

## Set seed(s)

In [None]:
seed = 0  # single seed shared by the torch and numpy generators below

# Seed torch's default generator and the CUDA generators.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed(seed)
# Turn off cuDNN benchmarking and request deterministic cuDNN behaviour.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# numpy Generator used for data generation and for the sampling done in later cells
rng = np.random.default_rng(seed)

## Generate data

In [None]:
%matplotlib inline

total_time = 500
seconds_per_rotation = 1
samples_per_second = 10

# Angle swept over total_time/seconds_per_rotation full turns, total_time*samples_per_second samples.
# NOTE(review): np.linspace includes the endpoint by default, so the spacing is
# span/(N-1) rather than span/N — confirm this is intended.
theta = np.linspace(0, total_time/seconds_per_rotation*2*np.pi, total_time*samples_per_second)
# Two-column latent trajectory: (cos(theta), sin(theta)) per sample, i.e. points on the unit circle.
x = np.column_stack([np.cos(theta), np.sin(theta)])
# Project the 2-D latents through the first two rows of a matrix drawn from
# special_ortho_group(dim=10) using `rng`.
# (presumably rvs(1) returns a single 10x10 matrix, giving 10-D observations — confirm)
y = x @ special_ortho_group(dim=10, seed=rng).rvs(1)[:2]
# Additive Gaussian noise with standard deviation 0.1, drawn from the same `rng`.
y = y + rng.normal(0, 0.1, size=y.shape)
# Single all-zero control column with one row per observation.
u = np.zeros((y.shape[0], 1))


In [None]:
%matplotlib qt
# Scatter the first 11 latent samples; each column of x becomes one plot coordinate.
plt.plot(*np.transpose(x[:11]), '.')
plt.axis('equal')

In [None]:
# hyperparameters
# All dimensions are read off the generated arrays so the config follows the data.
xdim = x.shape[-1]  # latent state dimension
udim = u.shape[-1]  # control dimension
ydim = y.shape[-1]  # observation dimension

config = dict(
    resume=False,
    xdim=xdim,  # dimension of hidden state
    ydim=ydim,  # dimension of observations
    # Previously hard-coded to 1; taken from `u` so it always agrees with the
    # shape of B below.
    udim=udim,  # dimension of control vector
    # NOTE(review): Ydim is set from the control dimension, not ydim. Kept as in
    # the original — confirm against what VJF expects for Ydim/Udim.
    Ydim=udim,
    Udim=udim,
    rdim=50,  # number of RBFs
    hdim=100,  # number of MLP hidden units
    lr=1e-3,  # learning rate
    clip_gradients=5.0,
    debug=False,
    likelihood='gaussian',
    system='rbf',
    recognizer='mlp',
    C=(None, True),  # loading matrix: (initial, estimate)
    b=(None, True),  # bias: (initial, estimate)
    A=(None, False),  # transition matrix if LDS
    B=(np.zeros((xdim, udim)), False),  # interaction matrix
    Q=(1.0, True),  # state noise
    R=(1.0, True),  # observation noise
)



## minimal run

In [None]:

# Fit VJF online, one observation at a time, recording the latent mean after each step.
mdl = VJF(config)

ys = torch.from_numpy(y).float()
us = torch.from_numpy(u).float()

mu = torch.zeros(ys.shape[0], xdim)  # one row of latent means per time step
q = None  # current state

for i in range(ys.shape[0]):
    # i:i+1 slices keep a leading axis of length 1 on the observation and control.
    q, _ = mdl.feed((ys[i:i+1], us[i:i+1]), q0=q)
    # Store a detached copy. Other cells call .detach() before .numpy() on this
    # value, which suggests it is attached to the autograd graph; copying it
    # in-place without detaching would chain every step's graph into `mu`.
    mu[i, :] = q[0].detach()

In [None]:
# Latent means from the final 200 time steps, one plot coordinate per latent column.
fig, ax = plt.subplots()
ax.plot(*mu.detach().numpy()[-200:].T)
ax.axis('equal');

In [None]:
class VJF_transformer(DecoupledTransformer):
    """Wrap a VJF model behind the DecoupledTransformer interface.

    `_partial_fit` feeds the most recently seen 'Y' data (and 'U' data when
    `take_U` is set) to the VJF model; `transform` replaces incoming data with
    the first element of the model's current state `q`.
    """
    base_algorithm = VJF

    def __init__(self, *, config=None, latent_d=6, take_U=False, input_streams=None, output_streams=None, log_level=None):
        """Store settings and build the VJF config; the model itself is created lazily.

        Args:
            config: optional dict of overrides applied on top of the defaults.
            latent_d: latent dimension, used as the default 'xdim'.
            take_U: whether a 'U' stream is consumed alongside 'Y'.
            input_streams: mapping of stream id to 'Y'/'U'; derived from `take_U` when None.
            output_streams, log_level: forwarded to DecoupledTransformer.
        """
        if input_streams is None:
            input_streams = {0: 'Y'}
            if take_U:
                input_streams[1] = 'U'
        DecoupledTransformer.__init__(self=self, input_streams=input_streams, output_streams=output_streams, log_level=log_level)
        self.take_U = take_U
        self.latent_d = latent_d
        overrides = config if config else {}
        # User overrides win over both the defaults and the latent_d-derived xdim.
        self.config = self.default_config_dict({'xdim': self.latent_d}) | overrides
        self.last_seen = {}  # latest data per stream name ('Y' / 'U')
        self.vjf = None  # created on first fit, once the data dimensions are known
        self.q = None  # current VJF state

    @staticmethod
    def default_config_dict(update=None):
        """Return a fresh default VJF config dict, with `update` merged on top.

        'xdim' is not part of the defaults (the constructor supplies it from
        latent_d), and 'ydim', 'udim' and 'B' are filled in by `init_vjf`
        because they depend on the input dimensions.
        """
        defaults = {
            'resume': False,
            'udim': 1,  # dimension of control vector
            'rdim': 50,  # number of RBFs
            'hdim': 100,  # number of MLP hidden units
            'lr': 1e-3,  # learning rate
            'clip_gradients': 5.0,
            'debug': False,
            'likelihood': 'gaussian',
            'system': 'rbf',
            'recognizer': 'mlp',
            'C': (None, True),  # loading matrix: (initial, estimate)
            'b': (None, True),  # bias: (initial, estimate)
            'A': (None, False),  # transition matrix if LDS
            'Q': (1.0, True),  # state noise
            'R': (1.0, True),  # observation noise
            'random_seed': 0,
        }
        if update is None:
            update = {}
        # `|` builds a new dict, so neither operand is modified.
        return defaults | update

    def init_vjf(self, ydim, udim=1):
        """Complete the config with the data-dependent entries and construct the VJF model."""
        assert self.take_U or udim==1

        interaction = np.zeros((self.config['xdim'], udim))
        self.config['ydim'] = ydim
        self.config['udim'] = udim
        self.config['B'] = (interaction, False)

        self.vjf = VJF(self.config)

    def _partial_fit(self, data, stream):
        """Record `data` for its stream and, once every input stream has been seen, feed VJF."""
        if stream not in self.input_streams:
            return

        self.last_seen[self.input_streams[stream]] = data
        if len(self.last_seen) != len(self.input_streams):
            # Still waiting for at least one input stream to deliver data.
            return

        y = self.last_seen['Y']
        # Without a control stream, an all-zero single-column control is used.
        u = self.last_seen['U'] if self.take_U else np.zeros((y.shape[0], 1))

        if self.vjf is None:
            self.init_vjf(ydim=y.shape[-1], udim=u.shape[-1])

        observation = torch.from_numpy(y).float()
        control = torch.from_numpy(u).float()
        self.q, _ = self.vjf.feed((observation, control), q0=self.q)

    def transform(self, data, stream=0, return_output_stream=False):
        """Return the first element of the current state in place of `data`.

        Before any state exists `data` is passed through unchanged. The
        substitution is applied regardless of which stream `data` came from.
        """
        if self.q is not None:
            latent = self.q[0].detach().numpy()
            data = ArrayWithTime.from_transformed_data(latent, data)

        if return_output_stream:
            return data, stream
        return data

    def get_params(self, deep=True):
        """Return the parent's params extended with take_U, latent_d and config."""
        params = super().get_params(deep=deep)
        return params | dict(take_U=self.take_U, latent_d=self.latent_d, config=self.config)


# Run the wrapper over the synthetic observations with a 2-D latent state.
# (optional API self-check: v.test_if_api_compatible())
v = VJF_transformer(latent_d=2, input_streams={0: 'Y'})
ret = v.offline_run_on([y], show_tqdm=True)


In [None]:
# Plot the last 200 rows of the transformer output.
# (assumes `ret` is indexed (time, latent) so that .T unpacks into plot coordinates — confirm)
fig, ax = plt.subplots()

ax.plot(*ret[-200:].T)
ax.axis('equal');

## run with log_pred_p evaluation

In [None]:
# Device on which the model and the data tensors are placed in the evaluation cell below.
device = 'cpu' # 'cuda' does not work

In [None]:
def diagonal_normal_logpdf(mean, variance, sample):
    """Return the log-density of `sample` under a normal with diagonal covariance.

    All three inputs are flattened first, so any shapes with the same number
    of elements are accepted.

    Args:
        mean: array-like of per-element means.
        variance: array-like of per-element variances (not standard deviations).
        sample: array-like of observed values.

    Returns:
        Scalar sum over elements of
        ``-0.5 * ((sample - mean)**2 / variance + log(2*pi*variance))``.
        NaN entries in any input propagate to the result.

    Raises:
        AssertionError: if the inputs do not have the same number of elements.
    """
    mean = np.asarray(mean).ravel()
    variance = np.asarray(variance).ravel()
    sample = np.asarray(sample).ravel()

    assert len(mean) == len(variance) == len(sample), f"inconsistent shape: {mean.shape}, {variance.shape}, {sample.shape}"

    # One vectorised expression instead of a Python-level loop over elements;
    # this function is called once per sample per horizon, so the loop dominated.
    elementwise = -0.5 * ((sample - mean) ** 2 / variance + np.log(2 * np.pi * variance))
    return np.sum(elementwise)


In [None]:
def generate_samples(q, S, rng):
    """Draw `S` samples from the diagonal Gaussian described by `q`.

    Args:
        q: pair ``(mean, log-variance)`` of tensors; only the first row
            (index 0) of each is used.
        S: number of samples to draw.
        rng: numpy random Generator, passed to scipy as ``random_state``.

    Returns:
        float32 tensor of samples placed on the same device as the mean in `q`.
    """
    filtering_mu, filtering_logvar = q

    mu_f = filtering_mu[0].detach().cpu().numpy().ravel()
    var_f = filtering_logvar[0].detach().exp().cpu().numpy().ravel()
    # Build the covariance from the state's own size rather than a
    # notebook-level `xdim`, so models with another latent size work too.
    Sigma_f = np.diag(var_f.astype(np.float64))

    x = multivariate_normal(mu_f, Sigma_f).rvs(size=S, random_state=rng).astype(np.float32)
    # Follow the device of the incoming state instead of a global `device`.
    return torch.from_numpy(x).to(filtering_mu.device)

def step(x, mdl, rng):
    """Advance the latent samples `x` by one noisy dynamics step, in place.

    Adds ``mdl.system.velocity(x)`` plus standard-normal noise scaled by
    ``mdl.system.noise.var ** 0.5`` to `x`.

    Args:
        x: tensor of latent samples; it is modified in place.
        mdl: model exposing ``system.velocity`` and ``system.noise.var``.
        rng: numpy random Generator used to draw the noise.

    Returns:
        The same tensor object `x`, after the update.
    """
    drift = mdl.system.velocity(x)
    # numpy draws on the host; move the noise next to `x` so the update does
    # not mix devices when `x` lives somewhere other than the CPU.
    noise = torch.from_numpy(rng.normal(size=x.shape)).to(x.device)
    x += drift + mdl.system.noise.var ** 0.5 * noise
    return x

def get_logprob_and_distance(mdl, x, y_true):
    """Score decoded latent samples against one observation.

    Args:
        mdl: model exposing ``likelihood.logvar`` and ``decoder``.
        x: tensor of latent samples, one per row.
        y_true: numpy array holding the observation to compare against.

    Returns:
        ``(logprob, distance)``: the log of the sample-averaged likelihood of
        `y_true`, and the mean Euclidean distance between the decoded samples
        and `y_true`.
    """
    obs_variance = mdl.likelihood.logvar.detach().exp().cpu().numpy().T
    decoded = mdl.decoder(x).detach().cpu().numpy()

    per_sample_logprobs = []
    for prediction in decoded:
        per_sample_logprobs.append(diagonal_normal_logpdf(prediction, obs_variance, y_true))

    # log-mean-exp over samples: logsumexp minus log(number of samples)
    n_samples = x.shape[0]
    logprob = logsumexp(per_sample_logprobs) - np.log(n_samples)

    errors = np.linalg.norm(decoded - y_true, axis=-1)
    distance = errors.mean()

    return logprob, distance

def log_step(q, mdl, ys, t, rng, S=1000, T=10):
    """Evaluate predictions 1..T steps ahead, starting from the state `q`.

    Samples are drawn from `q`, stepped forward T times, and after each step
    scored against ``ys[t + i]`` for i = 0..T-1.

    Args:
        q: state to sample from.
        mdl: model passed on to the stepping and scoring helpers.
        ys: tensor of observations indexed by time along the first axis.
        t: index of the first observation to score against.
        rng: numpy random Generator.
        S: number of samples.
        T: number of steps ahead to evaluate.

    Returns:
        ``(logprobs, distances)``: two lists of length T.
    """
    particles = generate_samples(q, S, rng)
    n_obs = ys.shape[0]

    logprobs, distances = [], []
    for horizon in range(T):
        target_index = t + horizon
        if target_index >= n_obs:
            # Past the end of the data: score against an all-NaN target so the
            # output keeps a fixed length of T.
            y_target = ys[t].cpu().numpy() * np.nan
        else:
            y_target = ys[target_index].cpu().numpy()

        particles = step(particles, mdl, rng)
        logprob, distance = get_logprob_and_distance(mdl, particles, y_target)
        logprobs.append(logprob)
        distances.append(distance)

    return logprobs, distances


In [None]:
# Online fit with multi-step-ahead prediction scoring before each update.
mdl = VJF(config).to(device)

ys = torch.from_numpy(y).float().to(device)
us = torch.from_numpy(u).float().to(device)

# could be None except for the logging functions
q = torch.zeros(1, xdim, device=device), torch.zeros(1, xdim, device=device)


logprobs = []
distances = []
mu2 = np.zeros((ys.shape[0], xdim))  # latent mean after each update

for t in trange(ys.shape[0]):
    # Score predictions made from the state as it was before y_t is fed in.
    step_logprobs, step_distances = log_step(q, mdl, ys, t, rng, T=10)
    logprobs.append(step_logprobs)
    distances.append(step_distances)

    y_t = ys[t].unsqueeze(0)
    u_t = us[t].unsqueeze(0)
    q, loss = mdl.feed((y_t, u_t), q)

    # .cpu() before .numpy() so this also works when `device` is not the CPU,
    # matching how the helper functions convert tensors.
    mu2[t] = q[0].detach().cpu().numpy()

# Rows are time steps; columns are prediction horizons.
logprobs, distances = np.array(logprobs), np.array(distances)

In [None]:
%matplotlib qt
# Top panel: log probability per horizon; bottom panel: prediction distance per horizon.
fig, axs = plt.subplots(nrows=2)

for i in range(logprobs.shape[-1]):
    label = f"{i+1} step{'s' if i > 0 else ''} ahead"
    for ax, series in zip(axs, (logprobs, distances)):
        ax.plot(series[:, i], label=label)

for ax, ylabel in zip(axs, ("log probability", "average prediction distance")):
    ax.set_xlabel("time")
    ax.set_ylabel(ylabel)

axs[1].legend(bbox_to_anchor=(1.01, 0.95))
# plt.ylim([-300, 0])

In [None]:
# Checks that the randomness is actually controlled by SEED.
# This method is obviously hacky, but I'm keeping it because the logprobs
# and distances were inconsistent between runs before, despite seeding.
#
# Each listed global is saved to /tmp/asdf_<name>.npy and compared with the
# copy left there by the previous run.
# NOTE: `vars` shadows the builtin and `v` rebinds the name used earlier for
# the transformer; both names are kept so cells relying on them still work.

vars = {}
for var in ['seed', 'x', 'y', 'mu', 'ys', 'mu2', 'distances', 'logprobs']:
    v = globals()[var]
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    vars[var] = v

    s = f'/tmp/asdf_{var}'
    try:
        old_v = np.load(f"{s}.npy")
    except FileNotFoundError:
        old_v = None
    np.save(s, v)

    if old_v is not None:
        # array_equal handles shape mismatches and, with equal_nan=True, treats
        # NaNs as equal only when they occur at the same positions. The previous
        # nanmax-of-squared-difference check skipped any position that was NaN
        # in just one of the two runs, and reported all-NaN arrays as different.
        same = np.array_equal(v, old_v, equal_nan=True)
        print(f'{var}: {same}')
    else:
        print(f'{var}: NEW')
