In [1]:
import torch
import torch.nn as nn
# import gibbs_sampler_poise
# import kl_divergence_calculator
from numpy import prod

In [2]:
_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def _latent_dims_type_setter(lds):
    ret, ret_flatten = [], []
    for ld in lds:
        if hasattr(ld, '__iter__'): # Iterable
            ld_tuple = tuple([i for i in ld])
            if not all(map(lambda i: isinstance(i, int), ld_tuple)):
                raise ValueError('`latent_dim` must be either iterable of ints or int.')
            ret.append(ld_tuple)
            ret_flatten.append(int(prod(ld_tuple)))
        elif isinstance(ld, int):
            ret.append((ld, ))
            ret_flatten.append(ld)
        else:
            raise ValueError('`latent_dim` must be either iterable of ints or int.')
    return ret, ret_flatten


class POISEVAE(nn.Module):
    """
    POISE-VAE with two latent spaces coupled through learnable interaction matrices.

    The parameters `g11` and `g22` both have shape (flat latent dim 1, flat latent dim 2).
    The model uses `g11` as-is and `-exp(g22)` (so every entry is negative), both when
    running the Gibbs sampler and when building the block-diagonal matrix for the KL term.
    """
    __version__ = 2.0

    def __init__(self, encoders, decoders, batch_size, loss, latent_dims=None,
                 device=_device):
        """
        encoders: list of nn.Module
            Each encoder must have an attribute `latent_dim` specifying the dimension of the
            latent space to which it encodes. An alternative way to avoid adding this attribute
            is to specify the `latent_dims` parameter (see below). 
            Note that each `latent_dim` must be unsqueezed, e.g. (10, ) is not the same as (10, 1).
            Exactly two encoders are currently supported.
            
        decoders: list of nn.Module
            The number and indices of decoders must match those of encoders.
            
        batch_size: int
        
        loss: str
            Can either be 'MSE' for MSE loss or 'BCE' for BCE loss. The users should properly 
            restrict the range of the output of their decoders for the loss chosen.
        
        latent_dims: iterable, optional; default None
            The dimensions of the latent spaces to which the encoders encode. The indices of the 
            entries must match those of encoders. An alternative way to specify the dimensions is
            to add the attribute `latent_dim` to each encoder (see above).
            Note that each entry must be unsqueezed, e.g. (10, ) is not the same as (10, 1).
        
        device: torch.device, optional
            Device for `g11`, `g22` and for the sampler / KL helpers.

        Raises
        ------
        ValueError
            Mismatched number of encoders/decoders, invalid latent dims, or batch_size <= 0.
        NotImplementedError
            Number of latent spaces other than 2, or a loss other than 'MSE'/'BCE'.
        TypeError
            Non-`nn.Module` encoders/decoders, or non-iterable `latent_dims`.
        """
        super().__init__()

        if len(encoders) != len(decoders):
            raise ValueError('The number of encoders must match that of decoders.')

        # The Gibbs sampler, the KL term and the block-diagonal matrix in `forward` are
        # all hard-wired for exactly two latent spaces.
        if len(encoders) != 2:
            raise NotImplementedError('Only exactly 2 latent spaces are currently supported.')

        # Type check
        if not all(isinstance(module, nn.Module) for module in (*encoders, *decoders)):
            raise TypeError('`encoders` and `decoders` must be lists of `nn.Module` class.')

        # Get the latent dimensions, either given explicitly or read off each encoder
        if latent_dims is not None:
            if not hasattr(latent_dims, '__iter__'):  # Iterable
                raise TypeError('`latent_dims` must be iterable.')
        else:
            latent_dims = tuple(encoder.latent_dim for encoder in encoders)
        # Shape tuples (used to reshape samples before decoding) and flattened sizes
        self.latent_dims, self.latent_dims_flatten = _latent_dims_type_setter(latent_dims)

        if batch_size <= 0:
            raise ValueError('Invalid batch size')
        self.batch_size = batch_size

        if loss not in ('MSE', 'BCE'):
            raise NotImplementedError('Not yet supported for other loss functions')
        self.loss = loss

        self.encoders = nn.ModuleList(encoders)
        self.decoders = nn.ModuleList(decoders)

        self.device = device

        # Pass the device through so the helpers allocate on the same device as the
        # model's own parameters instead of their module-level default.
        self.gibbs = gibbs_sampler(self.latent_dims_flatten, batch_size, device=self.device)
        self.kl_div = kl_divergence(self.latent_dims_flatten, batch_size, device=self.device)

        self.register_parameter(name='g11',
                                param=nn.Parameter(torch.randn(*self.latent_dims_flatten,
                                                               device=self.device)))
        self.register_parameter(name='g22',
                                param=nn.Parameter(torch.randn(*self.latent_dims_flatten,
                                                               device=self.device)))
        # 1 until the first forward pass has run the long burn-in of the Gibbs chains
        self.flag_initialize = 1

    def _decoder_helper(self):
        """
        Decode the latest posterior Gibbs samples, one decoder per latent space.

        Each flat sample is reshaped to (batch_size, *latent_dim), using the shape taken
        from the encoder's `latent_dim` or from `latent_dims`, before it is decoded.
        """
        ret = []
        for decoder, z, ld in zip(self.decoders, self.z_gibbs_posteriors, self.latent_dims):
            z = z.view(self.batch_size, *ld)  # Match the shape to the output
            ret.append(decoder(z))
        return ret

    def forward(self, x):
        """
        Parameters
        ----------
        x: list of torch.Tensor
            One input batch per encoder, in the same order as `encoders`.

        Return
        ------
        results: dict
            z: list of torch.Tensor
                Detached posterior samples from the latest Gibbs sweep, one per latent space
            x_rec: list of torch.Tensor
                Reconstructed samples
            mu: list of torch.Tensor
                First encoder output flattened to (batch_size, -1); passed to the posterior
                sampler as `lambda1s`
            var: list of torch.Tensor
                `-exp(log_var)` of the second encoder output, flattened to (batch_size, -1);
                passed to the posterior sampler as `lambda2s`. It is negative, i.e. not a
                variance despite the key name.
            total_loss: torch.Tensor
            rec_losses: list of torch.tensor
                Reconstruction loss for each dataset
            KL_loss: torch.Tensor

        Raises
        ------
        ValueError
            If the number of inputs differs from the number of encoders.
        """
        if len(x) != len(self.encoders):
            raise ValueError('Expected exactly one input per encoder.')

        mu, var = [], []
        for encoder, xi in zip(self.encoders, x):
            # Call the module itself (not `.forward`) so registered hooks run
            _mu, _log_var = encoder(xi)
            mu.append(_mu.view(self.batch_size, -1))
            var.append(-torch.exp(_log_var.view(self.batch_size, -1)))

        # Reparameterize g22 so every entry is negative
        g22 = -torch.exp(self.g22)

        # Initializing gibbs sample
        if self.flag_initialize == 1:
            # Long burn-in of the persistent chains. The result is only ever used
            # detached, so don't record an autograd graph through 2 x 5000 sweeps.
            with torch.no_grad():
                self.z_priors = self.gibbs.sample(self.g11, g22, n_iterations=5000)
                self.z_posteriors = self.gibbs.sample(self.g11, g22, lambda1s=mu,
                                                      lambda2s=var, n_iterations=5000)
            self.flag_initialize = 0

        z_priors = [z.detach() for z in self.z_priors]
        z_posteriors = [z.detach() for z in self.z_posteriors]

        # A few more sweeps from the persistent states; these samples keep their graph.
        # If lambda not provided, treat as zeros to save memory and computation
        self.z_gibbs_priors = self.gibbs.sample(self.g11, g22, z=z_priors, n_iterations=5)
        self.z_gibbs_posteriors = self.gibbs.sample(self.g11, g22, lambda1s=mu, lambda2s=var,
                                                    z=z_posteriors, n_iterations=5)

        self.z_priors = [z.detach() for z in self.z_gibbs_priors]
        self.z_posteriors = [z.detach() for z in self.z_gibbs_posteriors]

        # Must use the same reparameterized g22 as the sampler; using the raw parameter
        # would give the KL term a different matrix and a sign-flipped gradient for g22.
        G = torch.block_diag(self.g11, g22)

        x_ = self._decoder_helper()  # Decoding

        # KL loss (samples stay 2-D (batch, dim): the KL term concatenates along dim 1)
        kls = self.kl_div.calc(G, self.z_gibbs_posteriors, self.z_gibbs_priors, mu, var)
        KL_loss = sum(kls)

        # Reconstruction loss
        rec_loss_func = nn.MSELoss(reduction='sum') if self.loss == 'MSE' else \
                        nn.BCELoss(reduction='sum')
        recs = [rec_loss_func(x_rec, x_true) for x_rec, x_true in zip(x_, x)]
        rec_loss = sum(recs)

        # Total loss
        total_loss = KL_loss + rec_loss

        results = {
            'z': self.z_posteriors, 'x_rec': x_, 'mu': mu, 'var': var,
            'total_loss': total_loss, 'rec_losses': recs, 'KL_loss': KL_loss
        }

        return results

In [3]:
class kl_divergence():
    """
    Surrogate KL loss for a POISE-VAE with two latent spaces.

    `calc` returns three scalar terms whose sum the model uses as its KL loss.
    """
    __version__ = 1.0

    def __init__(self, latent_dims, batch_size, device=_device):
        # Kept for interface compatibility; `calc` infers the batch size, shapes and
        # device from its inputs.
        self.latent_dims = latent_dims
        self.batch_size = batch_size
        self.device = device

    def calc(self, G, z, z_priors, mu, var):
        """
        G: torch.Tensor, shape (2 * d1, 2 * d2)
            Interaction matrix, dotted with the per-sample outer products T(z1) T(z2)^T
        z: list of 2 torch.Tensor, each (batch, d_i)
            Posterior samples
        z_priors: list of 2 torch.Tensor, each (batch, d_i)
            Prior samples
        mu, var: lists of 2 torch.Tensor, each (batch, d_i)
            Posterior natural parameters (lambda1, lambda2)

        Returns
        -------
        (part_fun0, part_fun1, part_fun2): scalar tensors
            part_fun0 + part_fun1 is zero in value; together they pass the gradient of
            <lambda, T(z_post)> only through the posterior samples. part_fun2 is
            <T_prior_kron, G> - <T_post_kron, G> with the statistics detached, so its
            gradient flows only into G.
        """
        # Sufficient statistics T(z) = [z, z^2] and natural parameters [lambda1, lambda2]
        T_priors = [torch.cat((z_p, torch.square(z_p)), 1) for z_p in z_priors]
        T_posts = [torch.cat((z_q, torch.square(z_q)), 1) for z_q in z]
        lambdas = [torch.cat((mu_i, var_i), 1) for mu_i, var_i in zip(mu, var)]

        # TODO: make it generic for > 2 latent spaces
        # Per-sample outer products T(z1) T(z2)^T of shape (batch, 2*d1, 2*d2). This equals
        # torch.kron of the (2*d1, 1) column with the (1, 2*d2) row for every sample, done
        # in one broadcast. They only enter the loss detached, so detach once here.
        T_prior_kron = (T_priors[0].unsqueeze(2) * T_priors[1].unsqueeze(1)).detach()
        T_post_kron = (T_posts[0].unsqueeze(2) * T_posts[1].unsqueeze(1)).detach()

        part_fun0 = self.dot_product(lambdas[0], T_posts[0]) + \
                    self.dot_product(lambdas[1], T_posts[1])
        # -lambda*Tq - lambda'*Tq' with the statistics detached
        part_fun1 = -self.dot_product(lambdas[0], T_posts[0].detach()) - \
                     self.dot_product(lambdas[1], T_posts[1].detach())
        part_fun2 = self.dot_product(T_prior_kron, G) - \
                    self.dot_product(T_post_kron, G)

        return part_fun0, part_fun1, part_fun2

    def dot_product(self, tensor_1, tensor_2):
        """Sum of the element-wise (broadcast) product of two tensors."""
        return torch.sum(torch.mul(tensor_1, tensor_2))

In [4]:
class gibbs_sampler():
    """
    Alternating (block) Gibbs sampler for two coupled latent spaces.

    Each half-step draws one space from a Gaussian conditional whose mean and variance
    depend on the current sample of the other space (see `value_calc`).
    """
    __version__ = 1.0

    def __init__(self, latent_dims, batch_size, device=_device):
        self.latent_dims = latent_dims
        self.batch_size = batch_size
        self.device = device

    def var_calc(self, z, g22, lambda_2):
        """Conditional variance 1 / (2 * (1 - z^2 @ g22 - lambda_2)); lambda_2 may be None."""
        val = 1 - torch.matmul(torch.square(z), g22)
        if lambda_2 is not None:
            val -= lambda_2
        return torch.reciprocal(2 * val)

    def mean_calc(self, z, var, g11, lambda_1):
        """Conditional mean var * (z @ g11 + lambda_1); lambda_1 may be None."""
        beta = torch.matmul(z, g11)
        if lambda_1 is not None:
            beta += lambda_1
        return var * beta

    def value_calc(self, z, g11, g22, lambda_1, lambda_2):
        """Draw one sample from the Gaussian conditional given the other space's sample z."""
        var1 = self.var_calc(z, g22, lambda_2)
        mean1 = self.mean_calc(z, var1, g11, lambda_1)
        return mean1 + torch.sqrt(var1.float()) * torch.randn_like(var1)

    def sample(self, g11, g22, z=None, lambda1s=None, lambda2s=None, n_iterations=1):
        """
        g11, g22: torch.Tensor, shape (d1, d2) with d_i = latent_dims[i]
            Diagonal blocks of the metric tensor, used as given
        z: list of 2 torch.Tensor, each (batch_size, d_i), optional
            Initial chain states. If not provided, randomly initialize.
            The caller's list is not modified; a new list is returned.
        lambda1s: optional
            Natural parameter 1 of the latent distributions
            If not provided, treat as zeros
        lambda2s: optional
            Natural parameter 2 of the latent distributions
            If not provided, treat as zeros
        n_iterations: int, optional; default 1

        Returns
        -------
        list of 2 torch.Tensor
            The chain states after `n_iterations` full sweeps.
        """
        n_spaces = len(self.latent_dims)
        if z is None:
            # Keep samples 2-D (batch, dim) even when batch_size or a latent dim is 1
            z = [torch.randn(self.batch_size, ld, device=self.device)
                 for ld in self.latent_dims]
        else:
            z = list(z)  # work on a copy so the caller's list is left untouched
        if lambda1s is None:
            lambda1s = [None] * n_spaces
        if lambda2s is None:
            lambda2s = [None] * n_spaces

        # TODO: make it generic for > 2 latent spaces
        # The transposed blocks drive the update of space 0 from space 1; they are
        # loop-invariant, so compute them once.
        g11_t = torch.transpose(g11, 0, 1)
        g22_t = torch.transpose(g22, 0, 1)
        for _ in range(n_iterations):
            z[0] = self.value_calc(z[1], g11_t, g22_t, lambda1s[0], lambda2s[0])
            z[1] = self.value_calc(z[0], g11, g22, lambda1s[1], lambda2s[1])

        return z

In [5]:
_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
from torch.nn import functional as F
class Encoder1(nn.Module):
    """Toy MLP encoder: 100 features -> hidden 50 (ReLU) -> (mu, log_var), each of size 10."""

    def __init__(self):
        super().__init__()
        # Layer creation order is kept so parameter init and naming stay unchanged
        self.l1 = nn.Linear(100, 50).to(_device)
        self.l2mu = nn.Linear(50, 10).to(_device)
        self.l2var = nn.Linear(50, 10).to(_device)
        self.latent_dim = 10  # read by POISEVAE when `latent_dims` is not given

    def forward(self, x):
        """Return the (mu, log_var) heads computed from a shared ReLU hidden layer."""
        hidden = F.relu(self.l1(x))
        return self.l2mu(hidden), self.l2var(hidden)
    
class Encoder2(nn.Module):
    """Toy conv encoder for (3, 64, 64) images -> (mu, log_var), each of shape (1, 4, 4)."""
    # Spatial size per valid (unpadded) conv: 64x64 -> 40x40 -> 16x16 -> 4x4

    def __init__(self):
        super().__init__()
        # Layer creation order is kept so parameter init and naming stay unchanged
        self.l1 = nn.Conv2d(3, 2, (25, 25)).to(_device)
        self.l2 = nn.Conv2d(2, 1, (25, 25)).to(_device)
        self.l2mu = nn.Conv2d(1, 1, (13, 13)).to(_device)
        self.l2var = nn.Conv2d(1, 1, (13, 13)).to(_device)
        self.latent_dim = (1, 4, 4)  # read by POISEVAE when `latent_dims` is not given

    def forward(self, x):
        """Two ReLU conv layers, then separate conv heads for mu and log_var."""
        features = x
        for conv in (self.l1, self.l2):
            features = F.relu(conv(features))
        return self.l2mu(features), self.l2var(features)

# Modality 1: vector data (100 features) with a 10-dim latent space.
# NOTE(review): there is no nonlinearity between the two Linear layers, so dec1 is
# effectively a single affine map — confirm this is intended for the smoke test.
enc1 = Encoder1()
dec1 = nn.Sequential(
    nn.Linear(10, 50),
    nn.Linear(50, 100),
).to(_device)

# Modality 2: (3, 64, 64) images with a (1, 4, 4) latent space.
# The transposed convs undo Encoder2's spatial path: 4x4 -> 16x16 -> 40x40 -> 64x64.
enc2 = Encoder2()
dec2 = nn.Sequential(
    nn.ConvTranspose2d(1, 1, (13, 13)),
    nn.ConvTranspose2d(1, 2, (25, 25)),
    nn.ConvTranspose2d(2, 3, (25, 25)),
).to(_device)

In [7]:
# Two-modality model on the toy encoder/decoder pairs; batch size must match the data below.
net = POISEVAE(encoders=[enc1, enc2], decoders=[dec1, dec2], batch_size=10, loss='MSE')

In [8]:
# List every registered parameter name of the model.
for param_name, _ in net.named_parameters():
    print(param_name)

g11
g22
encoders.0.l1.weight
encoders.0.l1.bias
encoders.0.l2mu.weight
encoders.0.l2mu.bias
encoders.0.l2var.weight
encoders.0.l2var.bias
encoders.1.l1.weight
encoders.1.l1.bias
encoders.1.l2.weight
encoders.1.l2.bias
encoders.1.l2mu.weight
encoders.1.l2mu.bias
encoders.1.l2var.weight
encoders.1.l2var.bias
decoders.0.0.weight
decoders.0.0.bias
decoders.0.1.weight
decoders.0.1.bias
decoders.1.0.weight
decoders.1.0.bias
decoders.1.1.weight
decoders.1.1.bias
decoders.1.2.weight
decoders.1.2.bias


In [9]:
# Random smoke-test batches: modality 1 is (10, 100), modality 2 is (10, 3, 64, 64).
data1, data2 = (torch.randn(10, 100, device=_device),
                torch.randn(10, 3, 64, 64, device=_device))

In [10]:
# One forward pass; the first call also runs the long Gibbs burn-in.
ret = net(x=[data1, data2])