In [13]:
from typing import Callable
from collections.abc import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F

class DisRNN(nn.Module):
    def __init__(
        self,
        obs_size=2,
        target_size=1,
        latent_size=10,
        update_mlp_shape=(10, 10, 10),
        choice_mlp_shape=(10, 10, 10),
        eval_mode=0,
        beta_scale=1,
        activation=nn.ReLU,
    ):
        super().__init__()
        self.target_size = target_size
        self.latent_size = latent_size
        self.beta_scale = beta_scale
        self.eval_mode = eval_mode
        self.activation = activation

        mlp_input_size = latent_size + obs_size
        
        # Initialize update MLP parameters
        self.update_mlp_sigmas_unsquashed = nn.Parameter(
            torch.empty(mlp_input_size, latent_size).uniform_(-3, -2)
        )
        self.update_mlp_multipliers = nn.Parameter(
            torch.ones(mlp_input_size, latent_size)
        )
        
        # Initialize latent parameters
        self.latent_sigmas_unsquashed = nn.Parameter(
            torch.empty(latent_size).uniform_(-3, -2)
        )
        self.latent_inits = nn.Parameter(
            torch.empty(latent_size).uniform_(-1, 1)
        )

        # Create MLPs for each latent
        self.update_mlps = nn.ModuleList([
            MLP(mlp_input_size, update_mlp_shape, 2, activation)
            for _ in range(latent_size)
        ])
        
        # Choice MLP
        self.choice_mlp = MLP(latent_size, choice_mlp_shape, target_size, activation)

    def forward(self, observations, prev_latents):
        batch_size = observations.shape[0]
        penalty = torch.zeros(batch_size, device=observations.device)

        # Update MLPs
        update_mlp_sigmas = 2 * torch.sigmoid(self.update_mlp_sigmas_unsquashed) * (1 - self.eval_mode)
        update_mlp_mus_unscaled = torch.cat((observations, prev_latents), dim=1)
        update_mlp_mus = update_mlp_mus_unscaled.unsqueeze(2) * self.update_mlp_multipliers
        
        # Calculate updates for each latent
        new_latents = torch.zeros_like(prev_latents)
        for i in range(self.latent_size):
            # Add noise to inputs
            noise = torch.randn_like(update_mlp_mus[:, :, i]) * update_mlp_sigmas[:, i]
            update_mlp_inputs = update_mlp_mus[:, :, i] + noise
            
            # Calculate KL divergence
            kl = 0.5 * torch.sum(-torch.log(update_mlp_sigmas[:, i]) - 1.0 + 
                                update_mlp_sigmas[:, i] + update_mlp_mus[:, :, i].pow(2), dim=1)
            penalty += self.beta_scale * kl
            
            # Calculate update and weight
            update = self.update_mlps[i](update_mlp_inputs).squeeze(-1)
            w = torch.sigmoid(self.update_weight_mlps[i](update_mlp_inputs)).squeeze(-1)
            
            # Update latent
            new_latent = w * update + (1 - w) * prev_latents[:, i]
            new_latents[:, i] = new_latent

        # Global bottleneck
        latent_sigmas = 2 * torch.sigmoid(self.latent_sigmas_unsquashed) * (1 - self.eval_mode)
        noised_up_latents = new_latents + latent_sigmas * torch.randn_like(new_latents)
        penalty += torch.sum(-torch.log(latent_sigmas) - 1.0 + 
                           latent_sigmas + new_latents.pow(2), dim=1)

        # Choice MLP
        y_hat = self.choice_mlp(noised_up_latents)
        
        # Append penalty
        penalty = penalty.unsqueeze(1) * (1 - self.eval_mode)
        output = torch.cat((y_hat, penalty), dim=1)

        return output, noised_up_latents

    def initial_state(self, batch_size, device=None):
        return self.latent_inits.unsqueeze(0).repeat(batch_size, 1).to(device)


class MLP(nn.Module):
    """A simple multi-layer perceptron (MLP) network.
    
    The MLP is a series of linear layers with an activation function applied after each hidden layer
    but not after the final output layer. All layers are nested in a `nn.Sequential` container.
    """
    def __init__(self, input_size: int, hidden_sizes: Sequence[int], output_size: int, activation: Callable=nn.ReLU):
        """Initialize the MLP.

        Parameters
        ----------
        input_size : int
            The size of the input tensor.
        hidden_sizes : Sequence[int]
            A sequence of integers representing the size of each hidden layer.
        output_size : int
            The size of the output tensor.
        activation : Callable
            The activation function to apply after each hidden layer. Default=nn.ReLU.
        with_hooks : bool
            Whether to include hooks in the model for inspection of intermediate activations. Default=False.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.activation = activation
        
        mlp = []
        prev_size = input_size
        for size in hidden_sizes:
            mlp.extend([
                nn.Linear(prev_size, size),
                self.activation()
            ])
            prev_size = size
        
        mlp.append(nn.Linear(prev_size, output_size))
        
        self.mlp = nn.Sequential(*mlp)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MLP."""
        return self.mlp(x)

def kl_gaussian(mean: torch.Tensor, var: torch.Tensor, dim: int = -1) -> torch.Tensor:
  r"""Calculate KL divergence between given and standard gaussian distributions.

  KL(p, q) = H(p, q) - H(p) = -\int p(x)log(q(x))dx - -\int p(x)log(p(x))dx
          = 0.5 * [log(|s2|/|s1|) - 1 + tr(s1/s2) + (m1-m2)^2/s2]
          = 0.5 * [-log(|s1|) - 1 + tr(s1) + m1^2] (if m2 = 0, s2 = 1)
  Args:
    mean: mean vector of the first distribution
    var: diagonal vector of covariance matrix of the first distribution

  Returns:
    A scalar representing KL divergence of the two Gaussian distributions.
  """
  return 0.5 * torch.sum(-torch.log(var) - 1.0 + var + torch.square(mean), dim=dim)


class DynamicDisRNN(nn.Module):
    def __init__(self, obs_size=2, target_size=1, latent_size=10, update_mlp_shape=(10, 10, 10), choice_mlp_shape=(10, 10, 10), eval_mode=0, beta_scale=1, activation=nn.ReLU):
        super().__init__()
        self.target_size = target_size
        self.latent_size = latent_size
        self.obs_size = obs_size
        self.beta_scale = beta_scale
        self.eval_mode = eval_mode
        self.activation = activation

        mlp_input_size = latent_size + obs_size
        
        # to_mu_update takes in the observations and latent states and transforms them with a single linear layer
        # and a sigmoidal activation function to produce a set of multipliers that determine how much to use each
        # potential input to the UpdateMLP for each latent state.
        self.to_mu_update = nn.ModuleList([
            nn.Sequential(nn.Linear(mlp_input_size, mlp_input_size), nn.Sigmoid()) for _ in range(latent_size)
        ])

        # to_sigma_update takes in the observations and latent states and transforms them with a single linear layer
        # and a sigmoidal activation function to produce a set of sigmas that determine the amount of noise to add to
        # the inputs to the UpdateMLP for each latent state.
        self.to_sigma_update = nn.ModuleList([nn.Linear(mlp_input_size, mlp_input_size) for _ in range(latent_size)])

        # update_mlp takes in the transformed observations and latent states and produces a set of updates for each
        # latent state. There is an update_mlp for each latent state.
        self.update_mlps = nn.ModuleList([
            MLP(mlp_input_size, update_mlp_shape, 2, activation=activation)
            for _ in range(latent_size)
        ])

        # to_mu_latent takes in the observations and latent states and transforms them with a single linear layer
        # and a sigmoidal activation function to produce a set of multipliers that determine how much to use each
        # previous latent state for the next time step. to_sigma_latent does the same thing for noise on each latent.
        self.to_mu_latent = nn.Sequential(nn.Linear(mlp_input_size, latent_size), nn.Sigmoid())
        self.to_sigma_latent = nn.Linear(mlp_input_size, latent_size)

        # to_mu_choice takes in the latent states and transforms them with a single linear layer and a sigmoidal
        # activation function to produce a set of multipliers that determine how much to use each latent state and
        # observation for the choice MLP.
        self.to_mu_choice = nn.Sequential(nn.Linear(mlp_input_size, mlp_input_size), nn.Sigmoid())

        # to_sigma_choice takes in the latent states and transforms them with a single linear layer and a sigmoidal
        # activation function to produce a set of sigmas that determine the amount of noise to add to the inputs to
        # the choice MLP.
        self.to_sigma_choice = nn.Linear(mlp_input_size, mlp_input_size)
 
        # choice_mlp takes in the transformed latent states and observations and produces the output of the model.
        self.choice_mlp = MLP(mlp_input_size, choice_mlp_shape, target_size, activation=activation)

        # set up an initial latent parameter that is learned
        self.latent_inits = nn.Parameter(torch.empty(latent_size).uniform_(-1, 1))

    def _step(self, observations, prev_latents):
        """Forward method for the dynamic disentangled RNN class."""
        batch_size = observations.shape[0]
        penalty = torch.zeros(batch_size, device=observations.device)

        # Concatenate observations and latent states, which is the input to every component of the network
        obs_plus_latents = torch.cat((observations, prev_latents), dim=1) 

        # Measure mu and sigma for the inputs to each update MLP
        mu_to_updates = torch.stack([mu(obs_plus_latents) for mu in self.to_mu_update], dim=1)
        sigma_to_updates = torch.stack([sigma(obs_plus_latents) for sigma in self.to_sigma_update], dim=1)
        
        # Apply information bottleneck to the inputs to each update MLP
        update_mlp_inputs = mu_to_updates * obs_plus_latents.unsqueeze(1) + sigma_to_updates * torch.randn_like(obs_plus_latents.unsqueeze(1))
        
        # Add penalty for KL divergence on inputs to each update MLP
        penalty += self.beta_scale * kl_gaussian(update_mlp_inputs, sigma_to_updates, dim=(1, 2))

        # Measure the updates weights and value for each latent state
        updates = torch.stack([mlp(update_mlp_inputs[:, i]) for i, mlp in enumerate(self.update_mlps)], dim=1)
        target = updates[:, :, 0]
        weight = torch.sigmoid(updates[:, :, 1])

        # Update latents with weighted updates from UpdateMLPs
        new_latents = weight * target + (1 - weight) * prev_latents

        # Measure mu and sigma for the latent states
        mu_latent = self.to_mu_latent(obs_plus_latents)
        sigma_latent = self.to_sigma_latent(obs_plus_latents)

        # Apply information bottleneck to the latent states
        noised_latents = mu_latent * new_latents + sigma_latent * torch.randn_like(new_latents)
        
        # Add penalty for KL divergence on latent states
        penalty += self.beta_scale * kl_gaussian(noised_latents, sigma_latent)
        
        # Output of the choice MLP depends on the inputs and the updated bottlenecked latent states
        obs_plus_new_latents = torch.cat((observations, noised_latents), dim=1)

        # Measure mu and sigma for the inputs to the choice MLP
        mu_to_choice = self.to_mu_choice(obs_plus_new_latents)
        sigma_to_choice = self.to_sigma_choice(obs_plus_new_latents)

        # Apply information bottleneck to the inputs to the choice MLP
        choice_mlp_inputs = mu_to_choice * obs_plus_new_latents + sigma_to_choice * torch.randn_like(obs_plus_new_latents)

        # Add penalty for KL divergence on inputs to the choice MLP
        penalty += self.beta_scale * kl_gaussian(choice_mlp_inputs, sigma_to_choice)

        # Measure choice output
        choice_output = self.choice_mlp(choice_mlp_inputs)

        return choice_output, noised_latents, penalty
    
    def forward(self, observations, prev_latents):
        batch_size = observations.shape[0]
        seq_size = observations.shape[1]
        outputs = torch.zeros(batch_size, seq_size, self.target_size, device=observations.device)
        latents = torch.zeros(batch_size, seq_size, self.latent_size, device=observations.device)
        penalties = torch.zeros(batch_size, seq_size, device=observations.device)
        for i in range(seq_size):
            output, latent, penalty = self._step(observations[:, i], prev_latents)
            outputs[:, i] = output
            latents[:, i] = latent
            penalties[:, i] = penalty
            prev_latents = latent
        return outputs, latents, penalties

    def initial_state(self, batch_size, device=None):
        return self.latent_inits.unsqueeze(0).repeat(batch_size, 1).to(device)
        

obs_size = 2
target_size = 1
latent_size = 10
batch_size = 15

ddrnn = DynamicDisRNN(obs_size=obs_size, target_size=target_size, latent_size=latent_size, update_mlp_shape=(latent_size,), choice_mlp_shape=(latent_size,), eval_mode=0, beta_scale=1, activation=nn.ReLU)

x = torch.randn(batch_size, obs_size)
h = ddrnn.initial_state(batch_size)

out, hidden = ddrnn(x, h)

RuntimeError: The size of tensor a (15) must match the size of tensor b (10) at non-singleton dimension 0