# Deep Kalman Filter 

In [None]:
import treex as tx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import scale

In [None]:
# Display the installed JAX version (useful when reproducing this notebook).
jax.__version__

## Data

In [None]:
# Synthetic 4-channel series: two fast sinusoids ("observations") and two
# slow sinusoids ("interventions"), corrupted with unit Gaussian noise and
# then standardised per column.
T = 500  # number of time steps

observations = 2 * np.sin(np.linspace(0, 20 * np.pi, T))
interventions = 2 * np.sin(np.linspace(0, 2 * np.pi, T))

# Rows are channels before the transpose, so the result is (T, 4).
data = np.array(
    [observations, 1.2 * observations, interventions, 0.85 * interventions]
).transpose()
data = data + np.random.randn(*data.shape)
data = scale(data)

data.shape

In [None]:
# Plot every column of `data` against the time index on one wide axis.
plt.figure(figsize=(10, 2))
plt.gca().plot(data)
plt.gca().set_xlabel("Time")
plt.gca().set_ylabel("Value")
plt.show()

## Model

### Components

#### Transition Function

In [None]:
class GatedTransition(tx.Module):
    """Gated transition returning the mean and log-variance of the next state.

    The mean is an element-wise blend of a linear map of the input
    (``mean_fn``) and a non-linear proposal (``prop_mean``); the blend weights
    come from ``gate`` and are squashed to ``(0, 1)`` with a sigmoid so the
    result is a convex combination of the two. The log-variance is a linear
    function of the proposal.

    Args:
        latent_dim: Size of the returned mean / log-variance vectors.
        hidden_dim: Width of the hidden layer in the gate and proposal nets.
    """

    gate: tx.Module
    prop_mean: tx.Module
    mean_fn: tx.Module
    logvar_fn: tx.Module

    def __init__(self, latent_dim, hidden_dim) -> None:
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim

        # The final sigmoid keeps the gate in (0, 1); without it the mixing
        # weights below are unbounded and the blend is not a convex one.
        self.gate = tx.Sequential(
            tx.Linear(features_out=hidden_dim),
            jax.nn.relu,
            tx.Linear(features_out=latent_dim),
            jax.nn.sigmoid,
        )
        self.prop_mean = tx.Sequential(
            tx.Linear(features_out=hidden_dim),
            jax.nn.relu,
            tx.Linear(features_out=latent_dim),
        )
        # NOTE(review): an earlier draft set this kernel to the identity
        # matrix; an all-ones kernel is not the identity - confirm the intent.
        self.mean_fn = tx.Linear(
            latent_dim,
            kernel_init=tx.initializers.ones,
            bias_init=tx.initializers.zeros,
        )
        self.logvar_fn = tx.Linear(latent_dim)

    def __call__(self, inputs):
        """Return ``(mean, logvar)`` computed from ``inputs``.

        Args:
            inputs: Array whose last axis is fed to the linear layers.

        Returns:
            Tuple ``(mean, logvar)``, each with ``latent_dim`` trailing
            features.
        """
        # linear part of the mean
        mean = self.mean_fn(inputs)

        # gate in (0, 1) and non-linear proposal for the mean
        gate = self.gate(inputs)
        prop_mean = self.prop_mean(inputs)

        # convex combination of the linear and the proposed mean
        mean = (1 - gate) * mean + gate * prop_mean

        # log variance is driven by the proposal only
        logvar = self.logvar_fn(prop_mean)

        return mean, logvar

In [None]:
# First time step, kept 2-D with a leading axis of length 1.
data_sample = data[0:1, :]

In [None]:
hidden_dim = 20
latent_dim = 20

# init module
gated_fn = GatedTransition(latent_dim, hidden_dim)

# init params
key = jax.random.PRNGKey(123)

gated_fn = gated_fn.init(key=key, inputs=data_sample)

# forward operation
z_mu, z_logvar = gated_fn(data_sample)

# check shapes: the transition always emits `latent_dim` features, so the
# output is compared against that rather than against the input width
# (which differs whenever latent_dim != data_sample.shape[-1]).
assert z_mu.shape == data_sample.shape[:-1] + (latent_dim,)
assert z_logvar.shape == data_sample.shape[:-1] + (latent_dim,)

In [None]:
# Shapes of the transition outputs next to the shape of the sample fed in.
tuple(arr.shape for arr in (z_mu, z_logvar, data_sample))

#### Emission

In [None]:
class Emitter(tx.Module):
    """Emission network returning an observation mean and a shared log-variance.

    The mean comes from a three-layer MLP applied to the input; the
    log-variance is a single vector of length ``input_dim`` that does not
    depend on the input.

    Args:
        latent_dim: Stored on the instance; not otherwise used by this class.
        hidden_dim: Width of the two hidden layers of the mean network.
        input_dim: Number of features of the returned mean and log-variance.
    """

    # Declared as a parameter node in the same `= tx.<Kind>.node()` form used
    # for `kl_loss` elsewhere in this notebook, so it is treated as a leaf.
    logvar: jnp.ndarray = tx.Parameter.node()
    mean_fn: tx.Module

    def __init__(self, latent_dim, hidden_dim, input_dim):
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim

        self.mean_fn = tx.Sequential(
            tx.Linear(hidden_dim),
            jax.nn.relu,
            tx.Linear(hidden_dim),
            jax.nn.relu,
            tx.Linear(input_dim),
        )
        # Starts at log-variance 1 for every output feature.
        self.logvar = jnp.ones(self.input_dim)

    def __call__(self, inputs):
        """Return ``(mean, logvar)``.

        Args:
            inputs: Array passed through the mean network.

        Returns:
            ``mean`` with ``input_dim`` trailing features and the
            ``(input_dim,)`` log-variance vector (not broadcast).
        """
        mean = self.mean_fn(inputs)

        return mean, self.logvar

In [None]:
hidden_dim, latent_dim, input_dim = 10, 4, 4

# Build the emitter and initialise its parameters on the full data matrix.
key = jax.random.PRNGKey(42)
emitt_fn = Emitter(latent_dim, hidden_dim, input_dim)
emitt_fn = emitt_fn.init(key=key, inputs=jnp.asarray(data))

# Forward pass.
x_mu, x_logvar = emitt_fn(data)

# The mean mirrors the data shape; the log-variance has one entry per feature.
assert x_mu.shape == data.shape
assert x_logvar.shape == data.shape[1:]

### Posterior

#### Combiner

In [None]:
class Combiner(tx.Module):
    """Combine an input with an optional hidden state into (mean, logvar).

    The input goes through a linear layer and ``tanh`` and is scaled by 0.5;
    when a hidden state is given, half of it is added. Two linear heads then
    produce the mean and the log-variance.

    Args:
        latent_dim: Output width of both heads.
        hidden_dim: Output width of the input projection.
    """

    hidden_fn: tx.Linear
    hidden_to_mu: tx.Linear
    hidden_to_logvar: tx.Linear

    def __init__(self, latent_dim, hidden_dim):
        super().__init__()

        self.latent_dim, self.hidden_dim = latent_dim, hidden_dim

        self.hidden_fn = tx.Linear(hidden_dim)
        self.hidden_to_mu = tx.Linear(latent_dim)
        self.hidden_to_logvar = tx.Linear(latent_dim)

    def __call__(self, inputs, hidden_state=None):
        """Return ``(mean, logvar)`` for ``inputs`` and optional ``hidden_state``."""
        combined = 0.5 * jax.nn.tanh(self.hidden_fn(inputs))

        # Average with the hidden state only when one is supplied.
        if hidden_state is not None:
            combined = combined + 0.5 * hidden_state

        return self.hidden_to_mu(combined), self.hidden_to_logvar(combined)

In [None]:
# Number of time steps (rows) in the data matrix.
len(data)

In [None]:
rnn_hidden_dim = 10
latent_dim = 4
input_dim = 4

# init module
combiner_fn = Combiner(latent_dim, rnn_hidden_dim)

# init params
key = jax.random.PRNGKey(42)
hidden_state_init = jnp.ones((1, rnn_hidden_dim))
combiner_fn = combiner_fn.init(key=key, inputs=jnp.asarray(data))

# forward operation: pass the hidden-state *array*; the integer
# `rnn_hidden_dim` would silently be added to the projection as a constant.
x_mu, x_logvar = combiner_fn(data, hidden_state_init)

# check shapes
assert x_mu.shape == data.shape
assert x_logvar.shape == data.shape

#### RNN

In [None]:
key = tx.Key(8)
hidden_dim = 5
features = 10
batch_size = 32
time = 10
return_state = True
return_sequences = True

# Use the fully qualified `tx.nn.recurrent.GRU` (as the later cells do) so
# this cell does not depend on `from treex.nn import recurrent`, which only
# appears further down the notebook.
gru = tx.nn.recurrent.GRU(
    units=hidden_dim,
    time_axis=1,
    return_state=return_state,
    return_sequences=return_sequences,
)

# init layer
x_init = jnp.ones((1, 1, features))
hidden_init = jnp.zeros((1, hidden_dim))
gru = gru.init(key, (x_init, hidden_init))

# forward
x_demo = jnp.ones((batch_size, time, features))
hidden_demo = jnp.zeros((batch_size, hidden_dim))

rnn_out, h_n = gru(x_demo, hidden_demo)

In [None]:
# Shapes of the GRU output sequence and of the returned state.
tuple(arr.shape for arr in (rnn_out, h_n))

In [None]:
rnn_hidden_dim = 5
hidden_state_init = jnp.zeros(shape=(1, rnn_hidden_dim))

# GRU over axis 1 returning the full output sequence, initialised from the
# data alone (no explicit hidden state).
rnn_module = tx.nn.recurrent.GRU(
    units=rnn_hidden_dim,
    time_axis=1,
    return_sequences=True,
    return_state=False,
).init(key, data[np.newaxis, :])

In [None]:
# Forward pass without an explicit hidden state; a length-1 leading axis is added.
outputs = rnn_module(data[np.newaxis, :])
outputs.shape

In [None]:
# Forward pass with an explicit all-zero hidden state.
hidden_state = jnp.zeros(shape=(1, rnn_hidden_dim))
outputs = rnn_module(data[np.newaxis, :], hidden_state)
outputs.shape

In [None]:
rnn_hidden_dim = 5
hidden_state_init = jnp.zeros(shape=(1, rnn_hidden_dim))

# Same GRU as before, but initialised with an (inputs, hidden state) pair.
rnn_module = tx.nn.recurrent.GRU(
    units=rnn_hidden_dim,
    time_axis=1,
    return_sequences=True,
    return_state=False,
).init(key, (data[np.newaxis, :], hidden_state_init))

In [None]:
# Shape of the data once a leading axis of length 1 is added.
(1,) + data.shape

In [None]:
# Run the GRU on the batched data starting from the zero hidden state.
rnn_out = rnn_module(data[np.newaxis, :], hidden_state_init)

In [None]:
# NOTE(review): `carry` is not assigned in any cell visible in this notebook;
# presumably a leftover from an earlier experiment - confirm or remove.
carry.shape

In [None]:
# NOTE(review): `carry` is not assigned in any cell visible in this notebook;
# only `hidden_state_init` is defined above - confirm what is being compared.
carry.shape, hidden_state_init.shape

In [None]:
class RNN(tx.Module):
    """GRU wrapper that starts every sequence from a stored initial state.

    Args:
        latent_dim: Number of GRU units and length of the initial-state
            vector ``h_init``.
    """

    # Declared in the same `= tx.<Kind>.node()` form used for `kl_loss`
    # elsewhere in this notebook.
    h_init: jnp.ndarray = tx.Parameter.node()

    def __init__(self, latent_dim: int = 100):
        self.latent_dim = latent_dim
        # Size everything from the constructor argument; relying on a notebook
        # global would silently ignore `latent_dim`.
        self.rnn = tx.nn.recurrent.GRU(
            units=latent_dim, return_sequences=True, return_state=False, time_axis=1
        )
        self.h_init = jnp.zeros((latent_dim,))
        # Module-owned key stream so initialisation does not read a global key.
        self.next_key = tx.KeySeq()

    def _initial_state(self, n_batch):
        """Tile ``h_init`` along a new leading axis of length ``n_batch``."""
        return jnp.broadcast_to(self.h_init, (n_batch,) + self.h_init.shape)

    def __call__(self, inputs):
        """Run the GRU over ``inputs`` from the stored initial state.

        Args:
            inputs: Array whose first axis is the batch (its length sets how
                many copies of ``h_init`` are made) and whose axis 1 is time
                (``time_axis=1``).

        Returns:
            Whatever the wrapped GRU returns with ``return_sequences=True``
            and ``return_state=False``.
        """
        hidden_state = self._initial_state(inputs.shape[0])

        if self.initializing():
            self.rnn = self.rnn.init(self.next_key(), (inputs, hidden_state))

        return self.rnn(inputs, hidden_state)

In [None]:
rnn_hidden_dim = 100

# init module
rnn_fn = RNN(latent_dim=rnn_hidden_dim)

# init params
data_all = data[None, ...]

rnn_fn = rnn_fn.init(key=123, inputs=data_all)

# forward operation: the wrapper returns one output sequence (the GRU is
# built with return_state=False), so it must not be unpacked into two names.
rnn_out = rnn_fn(data_all)

# check shape: (batch, time, units)
assert rnn_out.shape == (1, data.shape[0], rnn_hidden_dim)

In [None]:
# NOTE(review): the RNN wrapper returns a single array, so `x_mu` / `x_logvar`
# here are presumably the values left over from an earlier cell - confirm
# this is what should be displayed.
x_mu.shape, x_logvar.shape

### Model

In [None]:
import flax
from flax import linen
from treex.nn import recurrent
from einops import repeat, rearrange

from treex.module import next_key

In [None]:
def kl_div(mu, logvar, mu_prior, logvar_prior):
    """KL divergence KL(q || p) between two diagonal Gaussians.

    ``q = N(mu, exp(logvar))`` and ``p = N(mu_prior, exp(logvar_prior))``.
    Per dimension the divergence is::

        0.5 * (logvar_prior - logvar
               + (exp(logvar) + (mu - mu_prior)**2) / exp(logvar_prior)
               - 1)

    Both the variance term and the squared mean difference are divided by
    the prior variance, so the result is exactly zero when the two
    distributions coincide.

    Args:
        mu: Mean of ``q``.
        logvar: Log-variance of ``q``.
        mu_prior: Mean of ``p``.
        logvar_prior: Log-variance of ``p``.

    Returns:
        The per-dimension divergence summed over axis 1.
    """
    var_ratio = jnp.exp(logvar - logvar_prior)
    scaled_sq_diff = jnp.square(mu - mu_prior) * jnp.exp(-logvar_prior)
    per_dim = logvar_prior - logvar + var_ratio + scaled_sq_diff - 1.0
    return jnp.sum(0.5 * per_dim, axis=1)

In [None]:
class DeepKalmanFilter(tx.Module):
    """Deep Kalman filter assembled from transition, emission and combiner modules.

    Args:
        transition: Module mapping the previous latent state to the prior
            ``(mean, logvar)`` of the next one; must expose ``latent_dim``.
        emission: Module mapping a latent state to the observation
            ``(mean, logvar)``.
        combiner: Module mapping ``(previous latent, rnn feature)`` to the
            posterior ``(mean, logvar)``; must expose ``hidden_dim``.
    """

    transition: tx.Module
    emission: tx.Module
    combiner: tx.Module
    # Array fields use the `name: type = tx.<Kind>.node()` form throughout.
    h_0: jnp.ndarray = tx.Parameter.node()
    z_0: jnp.ndarray = tx.Parameter.node()
    z_q0: jnp.ndarray = tx.Parameter.node()
    kl_loss: jnp.ndarray = tx.LossLog.node()

    def __init__(
        self,
        transition,
        emission,
        combiner,
    ):
        self.transition = transition
        self.emission = emission
        self.combiner = combiner
        # The GRU features are added to the combiner's hidden projection, so
        # the GRU is sized from the combiner rather than from a notebook global.
        self.rnn = tx.nn.recurrent.GRU(
            units=combiner.hidden_dim, return_sequences=True
        )
        self.h_0 = jnp.zeros((combiner.hidden_dim,))
        self.z_0 = jnp.zeros((transition.latent_dim,))
        self.z_q0 = jnp.zeros((transition.latent_dim,))
        self.next_key = tx.KeySeq()

    def __call__(self, x, hidden_state=None):
        """Initialise the internal GRU on first use and return ``x`` unchanged.

        ``hidden_state`` is accepted but not used.
        """
        if self.initializing():
            h_0 = repeat(self.h_0, "... -> batch ...", batch=x.shape[0])
            self.rnn = self.rnn.init(self.next_key(), inputs=(x, h_0))

        return x

    @staticmethod
    def _diag_gaussian_nll(x, mean, logvar):
        """Negative log-likelihood of ``x`` under N(mean, exp(logvar)), summed over the last axis."""
        per_dim = logvar + jnp.square(x - mean) * jnp.exp(-logvar) + jnp.log(2.0 * jnp.pi)
        return 0.5 * jnp.sum(per_dim, axis=-1)

    def infer(self, x):
        """Run the inference pass over a batch of sequences.

        For every time step the prior ``p(z_t | z_{t-1})`` is taken from the
        transition, the posterior from the combiner (previous latent plus the
        GRU feature of that step), a latent is sampled from the posterior by
        reparameterisation, and the emission is evaluated at that sample.

        Args:
            x: Observations of shape ``(batch, time, dim)``.

        Returns:
            Dict with ``"kl"`` and ``"nll"``, each stacked over time with
            shape ``(time, batch)``. The stacked KL terms are also stored
            in ``self.kl_loss``.
        """
        n_batch, n_time, n_dim = x.shape

        # GRU features for the whole sequence, starting from h_0.
        h_0 = repeat(self.h_0, "... -> batch ...", batch=n_batch)
        rnn_out = self.rnn(x, h_0)

        # Initial posterior latent state, one copy per batch element.
        z_prev = repeat(self.z_q0, "... -> batch ...", batch=n_batch)

        # Time-major views so the loop walks over time steps.
        x_steps = rearrange(x, "B T D -> T B D")
        rnn_steps = rearrange(rnn_out, "B T D -> T B D")

        kl_losses = []
        nll_losses = []

        for x_obs, rnn_t in zip(x_steps, rnn_steps):
            # transition probability p(z_t | z_{t-1})
            z_prior_mu, z_prior_logvar = self.transition(z_prev)

            # posterior q(z_t | z_{t-1}, rnn_t)
            z_mu, z_logvar = self.combiner(z_prev, rnn_t)

            # Reparameterised sample; the standard deviation is
            # exp(0.5 * logvar), not the log-variance itself.
            noise = jax.random.normal(self.next_key(), z_mu.shape)
            z_t = z_mu + jnp.exp(0.5 * z_logvar) * noise

            # emission probability p(x_t | z_t): condition on the current
            # latent sample, not on the previous state.
            x_mu, x_logvar = self.emission(z_t)

            kl_losses.append(kl_div(z_mu, z_logvar, z_prior_mu, z_prior_logvar))
            nll_losses.append(self._diag_gaussian_nll(x_obs, x_mu, x_logvar))

            z_prev = z_t

        kl = jnp.stack(kl_losses)
        nll = jnp.stack(nll_losses)

        # NOTE(review): stored per (time, batch); reduce here if the training
        # loop expects a scalar loss log - confirm against the trainer.
        self.kl_loss = kl

        return {"kl": kl, "nll": nll}

    def filter(self, x):
        """Not implemented yet; returns ``None``."""
        return None

    def predict(self, x):
        """Not implemented yet; returns ``None``."""
        return None

In [None]:
# --- transition function ---
hidden_dim = 30
latent_dim = 20

latent_sample = jnp.zeros((1, latent_dim))
gated_fn = GatedTransition(latent_dim, hidden_dim)
gated_fn = gated_fn.init(123, latent_sample)

# --- emission function ---
hidden_dim = 20
input_dim = 4

emitt_fn = Emitter(latent_dim, hidden_dim, input_dim)
emitt_fn = emitt_fn.init(123, latent_sample)

# --- combiner function ---
rnn_hidden_dim = 100

# The combiner is initialised on the latent sample with an extra leading axis.
latent_seq_sample = latent_sample[None, :]
combiner_fn = Combiner(latent_dim, rnn_hidden_dim)
combiner_fn = combiner_fn.init(123, latent_seq_sample)

# --- full model, initialised on the data with a leading axis of length 1 ---
data_seq_sample = data[None, ...]
dkf_model = DeepKalmanFilter(
    transition=gated_fn,
    emission=emitt_fn,
    combiner=combiner_fn,
)
dkf_model = dkf_model.init(key=42, inputs=jnp.asarray(data_seq_sample))

In [None]:
# combiner_fn.

In [None]:
# dkf_model

In [None]:
# Run inference on the data with a leading axis of length 1 added.
output = dkf_model.infer(data[np.newaxis, :])

In [None]:
# Display the model object after the inference call above.
dkf_model