In [2]:
%matplotlib inline
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal as mvn

In [3]:
class Bi_Gaussian:
    """
    Bivariate Gaussian over (X1, X2), parameterised by the marginal means
    (mu1, mu2), marginal standard deviations (sigma1, sigma2) and the
    correlation coefficient rho.

    Parameters are torch tensors. The methods combine them elementwise, so
    they presumably need matching (broadcastable) shapes; the unit-test cell
    uses single-element tensors.
    """

    def __init__(self, mu1, mu2, sigma1, sigma2, rho, CUDA=False, device=None):
        """
        mu1, mu2       : means of X1 and X2
        sigma1, sigma2 : standard deviations of X1 and X2
        rho            : correlation coefficient between X1 and X2
        CUDA           : if True, move every parameter onto a CUDA device
        device         : target CUDA device when CUDA is True; None selects
                         the current CUDA device
        """
        self.mu1 = mu1
        self.mu2 = mu2
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        self.rho = rho
        if CUDA:
            # Tensor.cuda(device) copies to the given GPU (current GPU when None),
            # the same effect as calling .cuda() inside torch.cuda.device(device).
            self.mu1 = self.mu1.cuda(device)
            self.mu2 = self.mu2.cuda(device)
            self.sigma1 = self.sigma1.cuda(device)
            self.sigma2 = self.sigma2.cuda(device)
            self.rho = self.rho.cuda(device)

    def conditional(self, x, cond):
        """
        Return (mean, std) of a conditional distribution, selected by `cond`:

        cond == 'x2' : X1 | X2 = x
        otherwise    : X2 | X1 = x  (any value other than 'x2', e.g. 'x1',
                       selects this branch)

        Mean = mu_a + rho * (sigma_a / sigma_b) * (x - mu_b)
        Std  = sqrt((1 - rho^2) * sigma_a^2)
        """
        if cond == 'x2':  ## X1 | X2
            cond_mu = self.mu1 + (self.sigma1 / self.sigma2) * (x - self.mu2) * self.rho
            cond_sigma = ((1 - self.rho ** 2) * (self.sigma1 ** 2)).sqrt()
        else:  ## X2 | X1
            cond_mu = self.mu2 + (self.sigma2 / self.sigma1) * (x - self.mu1) * self.rho
            cond_sigma = ((1 - self.rho ** 2) * (self.sigma2 ** 2)).sqrt()
        return cond_mu, cond_sigma

    def log_pdf_gamma(self, x, mu, sigma):
        """
        Return the log density of a univariate Gaussian N(mu, sigma^2) at x.

        This is the fully normalized log-density (torch's Normal.log_prob),
        which includes the -log(sigma) - 0.5*log(2*pi) terms. It is not only
        the quadratic kernel -(x - mu)^2 / (2 * sigma^2).
        """
        return Normal(mu, sigma).log_prob(x)

    def Joint(self):
        """
        Return (Mu, Sigma), the joint bivariate Gaussian parameters built from
        the individual parameters.

        Mu    : (mu1, mu2) stacked along a new last dimension
        Sigma : [[sigma1^2, cov], [cov, sigma2^2]] in the last two dimensions,
                with cov = rho * sigma1 * sigma2
        """
        cov = self.rho * self.sigma1 * self.sigma2
        Mu = torch.stack((self.mu1, self.mu2), -1)
        first = torch.stack((self.sigma1 ** 2, cov), -1)
        second = torch.stack((cov, self.sigma2 ** 2), -1)
        Sigma = torch.stack((first, second), -1)
        return Mu, Sigma

    def Marginal(self, x, name):
        """
        Return the marginal log density at x.

        name == 'x1' : log N(x; mu1, sigma1^2)
        otherwise    : log N(x; mu2, sigma2^2)
        """
        if name == 'x1':
            return Normal(self.mu1, self.sigma1).log_prob(x)
        return Normal(self.mu2, self.sigma2).log_prob(x)

In [4]:
CUDA = torch.cuda.is_available()
# The notebook was set up for the second GPU ('cuda:1'). Fall back to the
# first GPU, then to the CPU, so Restart & Run All works on any machine.
if torch.cuda.device_count() > 1:
    DEVICE = torch.device('cuda:1')
elif CUDA:
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

In [None]:
## Define a target bivariate Gaussian for unit test.
## Pass the detected CUDA flag instead of a hard-coded True, so this cell
## also runs on machines without a GPU.
mu1 = torch.ones(1) * 5.0
mu2 = torch.ones(1) * 8.0
sigma1 = torch.ones(1) * 1.0
sigma2 = torch.ones(1) * 2.5
rho = torch.ones(1) * 0.6
bg = Bi_Gaussian(mu1, mu2, sigma1, sigma2, rho, CUDA=CUDA, device=DEVICE)
