# Variational Behaviour Models

In this article we propose a probabilistic graphical model for the behaviour of an agent *embodied* within an environment.

## 1. Model

We consider a joint distribution of the form

$$
    \begin{align}
        p\left(x^n, y^n, \phi^n, \psi^n\right) = p(\phi_1) p(\psi_1) p(x_1|\phi_1)p(y_1|\psi_1)\prod_{i=2}^n p(\phi_i|\phi_{i-1}, \psi_{i-1}, y_{i-1}) p(\psi_i|\psi_{i-1},\phi_{i-1},x_{i-1}) p(x_i|\phi_i)p(y_i|\psi_i),
    \end{align}
$$

where:

- $x^n$ is a sequence of agent external states,
- $y^n$ is a sequence of environment external states,
- $\phi^n$ is a sequence of agent internal states,
- $\psi^n$ is a sequence of environment internal states.

This distribution expresses the interactions between an agent and its environment. For non-embodied agents, we do not condition the agent's internal states on the environment's internal states, and vice versa. We then recover a non-embodied behaviour model of the form

$$
    \begin{align}
        p\left(x^n, y^n, \phi^n, \psi^n\right) = p(\phi_1) p(\psi_1) p(x_1|\phi_1)p(y_1|\psi_1)\prod_{i=2}^n p(\phi_i|\phi_{i-1}, y_{i-1}) p(\psi_i|\psi_{i-1},x_{i-1}) p(x_i|\phi_i)p(y_i|\psi_i).
    \end{align}
$$

In many situations the non-embodied model is a reasonable simplification. It is the model whose factors we now discuss in turn.

- **Agent prior model** - $p(\phi_1)$. This is a prior over agent internal states.

- **Environment prior model** - $p(\psi_1)$. This is a prior over environment internal states.

- **Agent emission model** - $p(x_i|\phi_i)$. This is the distribution of agent external states given agent internal states. It generally represents how an agent's internal state influences its behaviour.

- **Environment emission model** - $p(y_i|\psi_i)$. This is the distribution of environment external states given environment internal states.

- **Agent transition model** - $p(\phi_i|\phi_{i-1}, y_{i-1})$. This represents how agent internal states are influenced by the environment's external states.

- **Environment transition model** - $p(\psi_i|\psi_{i-1}, x_{i-1})$. This represents how environment internal states are influenced by the agent's external states.

The whole system is essentially a pair of Hidden Markov Models (HMMs) whose internal states are influenced by each other's external states.
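As a minimal sketch of the non-embodied factorization, the function below evaluates $\log p(x^n, y^n, \phi^n, \psi^n)$ for a fully discrete model whose factors are given as probability tables. The function name `log_joint` and the table layouts are assumptions made for illustration, not part of the model definition. The point to notice is that each chain's transition is indexed by the *other* chain's previous external state.

```python
import math


def log_joint(x, y, phi, psi, agent_prior, env_prior,
              agent_emission, env_emission, agent_transition, env_transition):
    """Log of the non-embodied joint p(x^n, y^n, phi^n, psi^n) with tabular factors.

    Assumed (hypothetical) table layouts:
      agent_prior[k]            = p(phi_1 = k)
      agent_emission[k][a]      = p(x_i = a | phi_i = k)
      agent_transition[k][b][j] = p(phi_i = j | phi_{i-1} = k, y_{i-1} = b)
    The environment tables mirror these, with x_{i-1} driving psi's transition.
    """
    # First time step: priors and emissions.
    lp = math.log(agent_prior[phi[0]]) + math.log(env_prior[psi[0]])
    lp += math.log(agent_emission[phi[0]][x[0]]) + math.log(env_emission[psi[0]][y[0]])
    for i in range(1, len(x)):
        # Each chain's transition is driven by the other chain's previous external state.
        lp += math.log(agent_transition[phi[i - 1]][y[i - 1]][phi[i]])
        lp += math.log(env_transition[psi[i - 1]][x[i - 1]][psi[i]])
        # Emissions depend only on the current internal state.
        lp += math.log(agent_emission[phi[i]][x[i]]) + math.log(env_emission[psi[i]][y[i]])
    return lp
```

The embodied model would differ only in the transition tables, which would gain an extra index for the other chain's previous internal state.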

## 2. Parameter Estimation

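The model involves latent variables, so direct maximum likelihood estimation (MLE) is generally intractable: evaluating the marginal likelihood $p(x^n, y^n)$ requires summing or integrating over all latent trajectories. An alternative approach is to maximize the variational lower bound. For a partitioning of the variables into latent and observable sets $\{Z, X\}$, the variational lower bound on $\log p(X)$ can be written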

$$
    \begin{align}
        \mathcal{L} = \mathbb{E}_{Z \sim q(Z|X)}[\log p(X|Z)] - D_\text{KL}(q(Z|X)\Vert \pi(Z)),
    \end{align}
$$

where $q(Z|X)$ is a surrogate posterior and $\pi(Z)$ is a prior over $Z$.
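To make the bound concrete, here is a toy check with a single binary latent; the numbers are made up for illustration. When $\pi$ is the model's prior, $\log p(X) - \mathcal{L} = D_\text{KL}(q(Z|X) \Vert p(Z|X))$, so the bound is tight at the exact posterior and strictly looser for any other $q$.

```python
import math


def elbo(q, prior, likelihood):
    """E_q[log p(X|Z)] - KL(q(Z|X) || pi(Z)) for a single discrete latent Z.

    q[z], prior[z] and likelihood[z] = p(X | Z = z) are lists over latent values.
    Terms with q[z] == 0 contribute nothing and are skipped.
    """
    expected_log_lik = sum(qz * math.log(lz) for qz, lz in zip(q, likelihood) if qz > 0)
    kl = sum(qz * math.log(qz / pz) for qz, pz in zip(q, prior) if qz > 0)
    return expected_log_lik - kl


prior = [0.3, 0.7]
likelihood = [0.9, 0.2]
evidence = sum(p * l for p, l in zip(prior, likelihood))          # p(X)
posterior = [p * l / evidence for p, l in zip(prior, likelihood)]  # exact p(Z|X)

tight = elbo(posterior, prior, likelihood)    # equals log p(X) up to rounding
loose = elbo([0.5, 0.5], prior, likelihood)   # strictly less than log p(X)
```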

In our case, $X = \{x^n, y^n\}$ and $Z = \{\phi^n, \psi^n\}$. Since each external state depends only on the internal state at the same time step, the conditional log-likelihood is given by

$$
    \begin{align}
        \log p(X|Z) &= \log p(x^n, y^n|\phi^n, \psi^n) \\
        &= \log p(x_1|\phi_1) + \log p(y_1|\psi_1) + \sum_{i=2}^n (\log p(x_i|\phi_i) + \log p(y_i|\psi_i)).
    \end{align}
$$

Meanwhile, the surrogate log-posterior is given by

$$
    \begin{align}
        \log q(Z|X) &= \log q(\phi^n,\psi^n|x^n,y^n) \\
        &= \log q(\phi_1) + \log q(\psi_1) + \sum_{i=2}^n \left(\log q(\phi_i|\phi_{i-1}, y_{i-1}) + \log q(\psi_i|\psi_{i-1}, x_{i-1}) \right).
    \end{align}
$$

The prior $\pi(Z)$ serves to regularize the surrogate posterior. In practice, this can be achieved implicitly by quantizing the latents so that they take values in a finite set $\mathcal{K}$. Alternatively, if the latents are continuous vectors, we can make their dimension substantially smaller than that of the external states and penalize the surrogate posterior for deviating from a simple reference distribution, which tends to limit how much information the latents carry. For instance, if the surrogate posterior over $\phi_i$ is Gaussian, we can penalize its KL divergence from $\mathcal{N}(0, I)$.
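For the Gaussian case the penalty has a closed form, $D_\text{KL}(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \Vert \mathcal{N}(0, I)) = \tfrac{1}{2}\sum_d (\sigma_d^2 + \mu_d^2 - 1 - \log \sigma_d^2)$. The sketch below assumes a diagonal covariance parameterized by its log-variance; the function name is illustrative.

```python
import math


def gaussian_kl_to_standard_normal(mean, log_var):
    """KL( N(mean, diag(exp(log_var))) || N(0, I) ), summed over dimensions."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mean, log_var))
```

With unit variance ($\log \sigma^2 = 0$) this reduces to $\tfrac{1}{2}\lVert\mu\rVert^2$, so a squared-mean penalty such as the one in `vbm_loss` below is proportional to it. The full KL also penalizes variances that shrink below one, not only large means.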

In sum, the amortized loss for our model (the negative variational lower bound) over an empirical dataset $\mathcal{D} = \{(x^n, y^n)^{(j)}\}_{j=1}^{m}$ of $m$ sequences can be computed as follows:

> For each sequence in $\mathcal{D}$:
> 1. Sample $\phi_1, \psi_1$ from the agent and environment priors.
> 2. Compute $A = \log p(x_1|\phi_1) + \log p(y_1|\psi_1)$ using the agent and environment emission models.
> 3. Compute $B = \sum_{i=2}^n \left( \log p(x_i|\phi_i) + \log p(y_i|\psi_i) \right)$ iteratively using the agent and environment emission models, sampling each $\phi_i, \psi_i$ from the agent and environment transition models.
> 4. Compute $C = \sum_{i=1}^n R(\phi_i, \psi_i)$, where $R$ is a regularization measure such as the KL divergence from $\mathcal{N}(0, I)$. It can also be implicit, as in the case of quantization.
> 5. Compute $\mathcal{L} = A + B - C$.
>
> Sum $-\mathcal{L}$ over all sequences to obtain the final loss, which is minimized.
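The steps above can be sketched as a single function. This is a minimal illustration under assumptions: the callables (`sample_prior`, `agent_transition`, `agent_log_emission`, `regularizer`, and so on) are hypothetical stand-ins for whatever models are used, each transition is assumed to return a *sampled* latent, and the function returns the negated objective $-(A + B - C)$ so that it can be minimized directly.

```python
from typing import Any, Callable, Sequence, Tuple

Latent = Any  # a latent state, e.g. a vector; left abstract in this sketch


def sequence_loss(
    x: Sequence[int],
    y: Sequence[int],
    sample_prior: Callable[[], Tuple[Latent, Latent]],
    agent_transition: Callable[[Latent, int], Latent],
    environment_transition: Callable[[Latent, int], Latent],
    agent_log_emission: Callable[[Latent, int], float],
    environment_log_emission: Callable[[Latent, int], float],
    regularizer: Callable[[Latent, Latent], float],
) -> float:
    """Return -(A + B - C) for one sequence, following steps 1-5."""
    # Step 1: initial internal states from the priors.
    phi, psi = sample_prior()
    # Step 2: first-step emission log-likelihood.
    a = agent_log_emission(phi, x[0]) + environment_log_emission(psi, y[0])
    # Step 4, accumulated as we go, starting with the i = 1 term.
    c = regularizer(phi, psi)
    # Step 3: roll the transitions forward and score the remaining emissions.
    b = 0.0
    for i in range(1, len(x)):
        phi = agent_transition(phi, y[i - 1])
        psi = environment_transition(psi, x[i - 1])
        b += agent_log_emission(phi, x[i]) + environment_log_emission(psi, y[i])
        c += regularizer(phi, psi)
    # Step 5, negated so that the result can be minimized.
    return -(a + b - c)
```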

## 4. Neural Networks

In practice, we can parameterize our model using four neural networks. The first, known as the *agent encoder*, takes $y_{i-1}$ and $\phi_{i-1}$ and returns a distribution over $\phi_i$. This distribution may be discrete or continuous. The second, known as the *agent decoder*, takes a sample $\phi_i$ and returns a distribution over $x_i$. These two networks parameterize the agent transition and emission models, respectively.

Likewise, the environment models can be parameterized by an *environment encoder* (taking $x_{i-1}$ and $\psi_{i-1}$) and an *environment decoder*, which play the symmetric roles to the agent networks.

For the agent and environment priors, we have a few choices. We could *fix* them to simple distributions such as standard Gaussians, in which case they have no free parameters and are not optimized during training. Alternatively, we could learn their parameters.
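If the priors are learned, one simple option is a diagonal Gaussian with a trainable mean and log standard deviation, sampled with the reparameterization trick so that gradients reach its parameters. The class name `LearnedGaussianPrior` is hypothetical, and this is only a sketch: the experiments below fix the priors to $\mathcal{N}(0, I)$ instead.

```python
import torch
import torch.nn as nn
import torch.distributions as dst


class LearnedGaussianPrior(nn.Module):
    """Diagonal Gaussian prior over initial latents with trainable mean and log standard deviation."""

    def __init__(self, latent_dimension: int) -> None:
        super().__init__()
        # Initialized at N(0, I), i.e. identical to a fixed standard normal prior.
        self.mean = nn.Parameter(torch.zeros(latent_dimension))
        self.log_std = nn.Parameter(torch.zeros(latent_dimension))

    def distribution(self) -> dst.Normal:
        return dst.Normal(self.mean, self.log_std.exp())

    def sample(self, batch_size: int) -> torch.Tensor:
        # rsample uses the reparameterization trick, so gradients flow to mean and log_std.
        return self.distribution().rsample((batch_size,))
```

Such a module would need to be registered as a submodule of the model (or otherwise passed to the optimizer) for its parameters to be trained.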

## 5. Toy Implementations

We experiment with simple incarnations of the outlined idea.

### 5.1. Discrete External States, Continuous Internal States

In this example, we consider a simple model in which the agent and environment both emit discrete external states, $x_i, y_i \in \{0, \dots, K-1\}$ with $K = 6$. Meanwhile, the latent states are modeled as continuous $D$-dimensional vectors. In the true data-generating process the internal states are discrete (four values each), but we want to test whether a model with continuous latents can still fit the data well.

First let's implement the true data generating process.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dst


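def sample_real(sequence_length: int, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample from the data generating process.

    Returns the agent external states x and the environment external states y,
    each an int64 tensor of shape (batch_size, sequence_length).
    """

    # Models. Internal states take 4 values and external states take 6 values.
    # dst.Categorical normalizes `probs` along the last dimension, so rows that
    # do not sum exactly to 1 (e.g. agent_emission[0]) are rescaled.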

    agent_prior = dst.Categorical(probs=torch.tensor([0.1, 0.2, 0.3, 0.4]))
    environment_prior = dst.Categorical(probs=torch.tensor([0.5, 0.1, 0.3, 0.1]))

    agent_emission = {
        0: dst.Categorical(probs=torch.tensor([0.4, 0.1, 0.1, 0.1, 0.1, 0.1])),
        1: dst.Categorical(probs=torch.tensor([0.1, 0.3, 0.3, 0.1, 0.1, 0.1])),
        2: dst.Categorical(probs=torch.tensor([0.6, 0.0, 0.0, 0.0, 0.2, 0.1])),
        3: dst.Categorical(probs=torch.tensor([0.0, 0.1, 0.2, 0.0, 0.7, 0.0])),
    }

    environment_emission = {
        0: dst.Categorical(probs=torch.tensor([0.0, 0.5, 0.1, 0.4, 0.0, 0.0])),
        1: dst.Categorical(probs=torch.tensor([0.8, 0.1, 0.0, 0.0, 0.1, 0.0])),
        2: dst.Categorical(probs=torch.tensor([0.0, 0.0, 0.0, 0.7, 0.0, 0.3])),
        3: dst.Categorical(probs=torch.tensor([0.0, 0.1, 0.1, 0.1, 0.7, 0.0])),
    }

    agent_transition = {
        (0, 0): dst.Categorical(probs=torch.tensor([0.0, 0.0, 0.1, 0.9])),
        (0, 1): dst.Categorical(probs=torch.tensor([0.7, 0.3, 0.0, 0.0])),
        (0, 2): dst.Categorical(probs=torch.tensor([0.1, 0.0, 0.1, 0.8])),
        (0, 3): dst.Categorical(probs=torch.tensor([0.0, 0.1, 0.0, 0.9])),
        (0, 4): dst.Categorical(probs=torch.tensor([0.0, 0.0, 0.5, 0.5])),
        (0, 5): dst.Categorical(probs=torch.tensor([0.5, 0.1, 0.4, 0.0])),

        (1, 0): dst.Categorical(probs=torch.tensor([0.0, 0.4, 0.0, 0.6])),
        (1, 1): dst.Categorical(probs=torch.tensor([0.0, 0.0, 0.3, 0.7])),
        (1, 2): dst.Categorical(probs=torch.tensor([0.5, 0.0, 0.5, 0.0])),
        (1, 3): dst.Categorical(probs=torch.tensor([0.3, 0.3, 0.4, 0.0])),
        (1, 4): dst.Categorical(probs=torch.tensor([0.0, 0.0, 0.5, 0.5])),
        (1, 5): dst.Categorical(probs=torch.tensor([0.0, 0.0, 0.7, 0.3])),

        (2, 0): dst.Categorical(probs=torch.tensor([0.0, 0.4, 0.6, 0.0])),
        (2, 1): dst.Categorical(probs=torch.tensor([0.0, 0.2, 0.8, 0.0])),
        (2, 2): dst.Categorical(probs=torch.tensor([0.0, 0.8, 0.1, 0.1])),
        (2, 3): dst.Categorical(probs=torch.tensor([0.3, 0.4, 0.3, 0.0])),
        (2, 4): dst.Categorical(probs=torch.tensor([0.0, 0.1, 0.1, 0.8])),
        (2, 5): dst.Categorical(probs=torch.tensor([0.1, 0.1, 0.0, 0.8])),

        (3, 0): dst.Categorical(probs=torch.tensor([0.1, 0.0, 0.0, 0.9])),
        (3, 1): dst.Categorical(probs=torch.tensor([0.7, 0.3, 0.3, 0.0])),
        (3, 2): dst.Categorical(probs=torch.tensor([0.1, 0.7, 0.1, 0.1])),
        (3, 3): dst.Categorical(probs=torch.tensor([0.0, 0.0, 0.5, 0.5])),
        (3, 4): dst.Categorical(probs=torch.tensor([0.2, 0.2, 0.0, 0.6])),
        (3, 5): dst.Categorical(probs=torch.tensor([0.2, 0.8, 0.0, 0.1])),
    }

    # The environment reuses the agent's transition table, indexed by
    # (previous environment internal state, previous agent external state).
    environment_transition = agent_transition

    # Sampling.

    xs = []
    ys = []

    for _ in range(batch_size):

        # Sample initial internal states.

        phi = agent_prior.sample().item()
        psi = environment_prior.sample().item()

        # Sample initial external states.

        x = agent_emission[phi].sample().item()
        y = environment_emission[psi].sample().item()

        # Sample remaining internal and external states. Each chain's transition is
        # driven by the other chain's previous external state.

        for _ in range(sequence_length):

            xs.append(x)
            ys.append(y)

            phi = agent_transition[(phi, y)].sample().item()
            psi = environment_transition[(psi, x)].sample().item()
            x = agent_emission[phi].sample().item()
            y = environment_emission[psi].sample().item()


    xs = torch.tensor(xs, dtype=torch.int64).view(batch_size, sequence_length)
    ys = torch.tensor(ys, dtype=torch.int64).view(batch_size, sequence_length)

    return xs, ys

In [None]:
sample_real(20, 5)

(tensor([[4, 4, 4, 4, 4, 0, 4, 0, 1, 4, 4, 0, 0, 2, 5, 4, 0, 5, 2, 2],
         [0, 3, 0, 0, 4, 2, 0, 2, 0, 2, 4, 4, 2, 2, 4, 2, 4, 1, 4, 1],
         [4, 0, 4, 2, 4, 4, 4, 4, 0, 1, 4, 4, 4, 0, 2, 0, 0, 4, 4, 4],
         [1, 0, 0, 2, 2, 0, 0, 1, 0, 4, 1, 1, 1, 4, 0, 0, 1, 4, 5, 0],
         [0, 4, 0, 4, 0, 4, 1, 0, 4, 4, 1, 2, 1, 4, 4, 0, 3, 0, 4, 4]],
        dtype=torch.int8),
 tensor([[1, 4, 4, 0, 5, 3, 4, 4, 0, 4, 1, 5, 0, 1, 0, 3, 1, 2, 0, 3],
         [3, 5, 0, 4, 1, 4, 0, 2, 0, 0, 3, 1, 4, 0, 3, 4, 4, 0, 4, 0],
         [3, 5, 5, 4, 3, 5, 4, 1, 3, 0, 4, 4, 4, 1, 2, 1, 4, 3, 4, 4],
         [3, 3, 3, 0, 3, 1, 3, 2, 0, 4, 4, 3, 0, 2, 4, 4, 4, 1, 3, 2],
         [0, 4, 4, 1, 3, 1, 4, 5, 3, 2, 4, 1, 4, 0, 2, 2, 3, 4, 4, 0]],
        dtype=torch.int8))

In [None]:
from typing import Tuple


internal_cardinality = 4
external_cardinality = 6


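class Encoder(nn.Module):
    """Transition network: maps an external state and the previous latent state to the mean of the next latent state."""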

    def __init__(
        self,
        *,
        input_dimension: int,
        latent_dimension: int,
        hidden_dimension: int,
    ) -> None:
        """Initialize the module."""

        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(in_features=input_dimension + latent_dimension, out_features=hidden_dimension),
            nn.Tanh(),
            nn.LayerNorm(normalized_shape=hidden_dimension),
            nn.Linear(in_features=hidden_dimension, out_features=hidden_dimension // 2 + latent_dimension),
            nn.Tanh(),
            nn.LayerNorm(normalized_shape=hidden_dimension // 2 + latent_dimension),
            nn.Linear(in_features=hidden_dimension // 2 + latent_dimension, out_features=latent_dimension),
        )

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Forward the module."""

        return self.layers(torch.cat((x, z), dim=-1))


class Decoder(nn.Module):
    """Emission network: maps a latent state to log-probabilities over external states."""

    def __init__(
        self,
        *,
        input_dimension: int,
        latent_dimension: int,
        hidden_dimension: int,
    ) -> None:
        """Initialize the module."""

        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(in_features=latent_dimension, out_features=hidden_dimension // 2 + latent_dimension),
            nn.LeakyReLU(),
            nn.LayerNorm(normalized_shape=hidden_dimension // 2 + latent_dimension),
            nn.Linear(in_features=hidden_dimension // 2 + latent_dimension, out_features=hidden_dimension),
            nn.LeakyReLU(),
            nn.LayerNorm(normalized_shape=hidden_dimension),
            nn.Linear(in_features=hidden_dimension, out_features=input_dimension),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Forward the module."""

        return self.layers(z)


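class VBM(nn.Module):
    """Non-embodied behaviour model with an encoder (transition) and decoder (emission) for each of the agent and the environment."""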

    def __init__(
        self,
        *,
        input_dimension: int,
        latent_dimension: int,
        hidden_dimension: int,
    ) -> None:
        super().__init__()

        self.input_dimension = input_dimension
        self.latent_dimension = latent_dimension

        self.agent_encoder = Encoder(
            input_dimension=input_dimension,
            latent_dimension=latent_dimension,
            hidden_dimension=hidden_dimension,
        )

        self.agent_decoder = Decoder(
            input_dimension=input_dimension,
            latent_dimension=latent_dimension,
            hidden_dimension=hidden_dimension,
        )

        self.environment_encoder = Encoder(
            input_dimension=input_dimension,
            latent_dimension=latent_dimension,
            hidden_dimension=hidden_dimension,
        )

        self.environment_decoder = Decoder(
            input_dimension=input_dimension,
            latent_dimension=latent_dimension,
            hidden_dimension=hidden_dimension,
        )

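    def sample_prior(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample initial internal states from fixed N(0, I) priors."""

        # Create the samples on the same device as the model parameters.
        device = next(self.parameters()).device
        phi = torch.randn((batch_size, self.latent_dimension), device=device)
        psi = torch.randn((batch_size, self.latent_dimension), device=device)

        return phi, psi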

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        phi: torch.Tensor,
        psi: torch.Tensor,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
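        # Forward pass for a single time step. x and y are the agent and environment
        # external states at step i-1; phi and psi are the internal states at step i-1.

        x = F.one_hot(x, num_classes=self.input_dimension).float()
        y = F.one_hot(y, num_classes=self.input_dimension).float()

        # Transition means: the agent is driven by the environment's external state,
        # and the environment by the agent's.
        phi_mean = self.agent_encoder(y, phi)
        psi_mean = self.environment_encoder(x, psi)

        # Reparameterized samples of the step-i internal states from N(mean, I).
        phi = phi_mean + torch.randn_like(phi_mean)
        psi = psi_mean + torch.randn_like(psi_mean)

        # Log-probabilities over the step-i external states.
        x_hat = self.agent_decoder(phi)
        y_hat = self.environment_decoder(psi)

        return x_hat, y_hat, phi, psi, phi_mean, psi_mean

In [None]:
def vbm_loss(x_batch: torch.Tensor, y_batch: torch.Tensor, vbm: VBM, kl_weight: float) -> torch.Tensor:
    """Average per-step loss: emission NLL plus a weighted penalty on the transition means."""

    batch_size, sequence_length = x_batch.shape
    phi, psi = vbm.sample_prior(batch_size)
    loss = 0.

    for i in range(1, sequence_length):

        # The transitions into step i are driven by the external states at step i-1.
        x_prev, y_prev = x_batch[:, i - 1], y_batch[:, i - 1]
        x, y = x_batch[:, i], y_batch[:, i]  # Shape: (B,)

        # Detaching truncates backpropagation through time to a single step.
        x_hat, y_hat, phi, psi, phi_mean, psi_mean = vbm(x_prev, y_prev, phi.detach(), psi.detach())

        # Squared distance of the means from 0; proportional to KL(N(mean, I) || N(0, I)).
        kl_loss = torch.square(phi_mean).mean() + torch.square(psi_mean).mean()
        reconstruction_loss = F.nll_loss(x_hat, x, reduction='mean') + F.nll_loss(y_hat, y, reduction='mean')

        loss = loss + reconstruction_loss + (kl_weight * kl_loss)

    # Average over the sequence_length - 1 transitions that were scored.
    return loss / (sequence_length - 1)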

In [None]:
model = VBM(input_dimension=6, latent_dimension=2, hidden_dimension=16)

In [None]:
epochs = 20
batch_size = 32
batches_per_epoch = 100
sequence_length = 10
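device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the model to the same device as the data before building the optimizer.
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)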

for epoch in range(epochs):
    for batch in range(batches_per_epoch):

        x_batch, y_batch = sample_real(sequence_length, batch_size)
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        loss = vbm_loss(x_batch, y_batch, model, kl_weight=0.0)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        if (batch % 10) == 0:
            print(f'epoch: {epoch:03d}, batch: {batch:03d} - loss: {loss.detach().item():0.5f}')

epoch: 000, batch: 000 - loss: 3.44289
epoch: 000, batch: 010 - loss: 2.89749
epoch: 000, batch: 020 - loss: 2.93161
epoch: 000, batch: 030 - loss: 2.84779
epoch: 000, batch: 040 - loss: 2.79713
epoch: 000, batch: 050 - loss: 2.84976
epoch: 000, batch: 060 - loss: 2.82584
epoch: 000, batch: 070 - loss: 2.84785
epoch: 000, batch: 080 - loss: 2.82977
epoch: 000, batch: 090 - loss: 2.80417
epoch: 001, batch: 000 - loss: 2.76226
epoch: 001, batch: 010 - loss: 2.82369
epoch: 001, batch: 020 - loss: 2.82867
epoch: 001, batch: 030 - loss: 2.80269
epoch: 001, batch: 040 - loss: 2.88523
epoch: 001, batch: 050 - loss: 2.90603
epoch: 001, batch: 060 - loss: 2.90612
epoch: 001, batch: 070 - loss: 2.88173
epoch: 001, batch: 080 - loss: 2.78233
epoch: 001, batch: 090 - loss: 2.82023
epoch: 002, batch: 000 - loss: 2.80457
epoch: 002, batch: 010 - loss: 2.77553
epoch: 002, batch: 020 - loss: 2.83040
epoch: 002, batch: 030 - loss: 2.92179
epoch: 002, batch: 040 - loss: 2.80141
epoch: 002, batch: 050 - 

KeyboardInterrupt: 
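These loss values are easier to interpret against a context-free reference. A predictor that ignores the latents and always outputs the marginal distribution of each external state achieves a per-step NLL roughly equal to the entropy of that marginal, so the sum of the empirical marginal entropies of $x$ and $y$ is a rough baseline for the reconstruction term (the KL term is switched off here with `kl_weight=0.0`). The helper below is a hypothetical sketch; it uses a plug-in entropy estimate, which is biased slightly low for small samples.

```python
import math
from collections import Counter
from typing import Iterable


def empirical_entropy(values: Iterable[int]) -> float:
    """Entropy, in nats, of the empirical distribution of a collection of discrete values."""
    counts = Counter(values)
    total = sum(counts.values())
    return -sum((c / total) * math.log(c / total) for c in counts.values())
```

For example, `empirical_entropy(x_batch.flatten().tolist()) + empirical_entropy(y_batch.flatten().tolist())` gives the baseline for a sampled batch.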

In [None]:
vbm = VBM(input_dimension=6, latent_dimension=2, hidden_dimension=16)

In [None]:
x_batch, y_batch = sample_real(sequence_length=10, batch_size=8)

In [None]:
vbm_loss(x_batch, y_batch, vbm, kl_weight=1.)

tensor(3.3466, grad_fn=<DivBackward0>)

In [None]:
x_batch

tensor([[2, 2, 2, 1, 2, 4, 3, 2, 2, 0],
        [4, 1, 4, 1, 4, 2, 4, 4, 0, 4],
        [2, 4, 2, 2, 2, 2, 5, 4, 0, 0],
        [2, 0, 2, 4, 4, 4, 4, 0, 0, 4],
        [4, 0, 0, 5, 0, 0, 1, 2, 5, 4],
        [2, 2, 1, 3, 4, 0, 5, 0, 0, 4],
        [4, 5, 4, 2, 2, 4, 0, 4, 4, 1],
        [2, 0, 2, 4, 2, 0, 4, 4, 4, 5]], dtype=torch.int8)