In [10]:
import numpy as np
import torch
from torch import nn

### Define parameters

In [29]:
N = 11 # nodes
D = N - 1 # dimensions

In [30]:
# define transfer functions
f_d = nn.Softplus()
f_l = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0])) # call using f_l.cdf(x)

In [31]:
# define latent variables
d_t = lambda z_d: f_d(z_d)
c_t = lambda z_c: zc
a_t = lambda z_a: z_a # a will be orthogonalized during trajectory construction
l_t = lambda z_l: 0.06 * f_l.cdf(z_l)

### Initialization of $p_{\theta}(z)$

In [157]:
# get distributions
z_d, z_c, z_a, z_l = Z_t(mu_d, sigma_d, mu_c, sigma_c, mu_a, sigma_a, mu_l, sigma_l, N)

In [139]:
# sample initial values
z_d_init = z_d.sample()
z_c_init = z_c.sample()
z_a_init = z_a.sample()
z_l_init = z_l.sample()

Z_init = [z_d_init, z_c_init, z_a_init, z_l_init]

### Initialization of $q_{\phi}(z|n)$

### Define optimizer

In [264]:
class ELBO(nn.Module):
    """
    Class for optimizing the ELBO term. 

    Input:
    ------
    N: Scalar
        Number of nodes
        
    """
    
    def __init__(self, N):
        super(ELBO, self).__init__()

        # define small noise variables
        eps = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0])) 
        eta = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([np.pi])) 

        # initialize means of the prior
        mu_d = torch.tensor([0.2 * eps.sample() for _ in range(N - 1)])
        mu_c = torch.FloatTensor(N - 2).uniform_(0, np.pi)
        mu_a = torch.zeros((N - 2) * (N - 1)).reshape(N - 2, -1)
        mu_l = torch.tensor([0.0])

        # initialize (diagonal) covariance matrices of the prior
        sigma_d = torch.diag(torch.ones(N - 1))
        sigma_c = torch.diag(torch.ones(N - 2))
        sigma_a = torch.diag(torch.ones((N - 1) * (N - 2)))
        sigma_l = torch.diag(torch.tensor([1.0]))

        # define trainable parameters
        self.mu_prior = nn.Parameter(torch.cat((mu_d, mu_c, torch.flatten(mu_a), mu_l)))
        self.sigma_prior = nn.Parameter(torch.block_diag(sigma_d, sigma_c, sigma_a, sigma_l))
        self.mu_posterior = nn.Parameter(torch.randn(self.mu_prior.shape))
        self.sigma_posterior = nn.Parameter(torch.abs(torch.randn(self.sigma_prior.shape))) # variance cannot be negative

    def kl_divergence(self):
        prior = torch.distributions.MultivariateNormal(self.mu_prior, self.sigma_prior)
        posterior = torch.distributions.MultivariateNormal(self.mu_posterior, self.sigma_posterior)

In [265]:
elbo = ELBO(N)
list(elbo.parameters())

[Parameter containing:
 tensor([-0.1254, -0.1211,  0.1003,  0.0421,  0.0289, -0.4368,  0.2253,  0.0650,
          0.1556,  0.1639,  3.0763,  2.6871,  2.0062,  1.0322,  2.4593,  0.0572,
          0.4935,  0.3112,  2.2648,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
     