In [1]:
import torch

In [8]:
from torch.autograd import Variable
from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform
import math

def compute_loss_gp(pred_mu, pred_std, target_y, intrain=True):
    r"""Compute NLLLossLNPF: Monte-Carlo negative log-likelihood over latent samples.

    Builds ``Normal(pred_mu, pred_std)`` and scores ``target_y`` under it.
    Dim 0 indexes latent samples ``z`` and dim 1 the batch. The last two dims
    (targets and channels) are reduced per sample. Samples are then combined
    with a numerically stable log-mean-exp:

        LL    = E_{q(z|y_cntxt)}[ \prod_t p(y^t|z) ]
        LL_MC = log( mean_z( \prod_t p(y^t|z) ) )
              = log [ sum_z ( exp \sum_t log p(y^t|z) ) ] - log(n_z_samples)
              = log_sum_exp_z( \sum_t log p(y^t|z) ) - log(n_z_samples)

    Args:
        pred_mu: predictive means, shape (n_z_samples, batch, n_targets, n_channels).
        pred_std: predictive standard deviations, same shape as ``pred_mu``;
            must be strictly positive (``Normal`` rejects a zero scale).
        target_y: observed targets, broadcastable against ``pred_mu``
            (e.g. shape (batch, n_targets, n_channels)).
        intrain: if True, each sample's score is the SUM of log-probabilities
            over targets and channels, i.e. the LL above. If False, it is their
            MEAN, a size-normalised score rather than the true LL.

    Returns:
        Scalar tensor: minus the log-mean-exp score, averaged over the batch.
    """
    p_yCc = Normal(loc=pred_mu, scale=pred_std)
    log_p = p_yCc.log_prob(target_y)  # (n_z_samples, batch, n_targets, n_channels)

    # Per-sample score over targets and channels, shape (n_z_samples, batch).
    if intrain:
        per_sample_score = log_p.sum(dim=(-1, -2))
    else:
        per_sample_score = log_p.mean(dim=(-1, -2))

    n_z_samples = per_sample_score.size(0)
    # log(mean_z exp(score)), computed stably with logsumexp over the sample dim.
    log_mean_exp = torch.logsumexp(per_sample_score, dim=0) - math.log(n_z_samples)

    # Negative score, averaged over the batch.
    return -log_mean_exp.mean()

    
    
    
eps = 1.0e-8  # module-level small constant
def compute_loss_gp_origin(pred_mu, pred_std, target_y, z_samples=None, qz_c=None, qz_ct=None):
    r"""Compute NLLLossLNPF, optionally with importance-sampling weights.

    Computes an approximate log-likelihood in a numerically stable way:

        LL    = E_{q(z|y_cntxt)}[ \prod_t p(y^t|z) ]
        LL_MC = log( mean_z( \prod_t p(y^t|z) ) )
              = log [ sum_z ( exp \sum_t log p(y^t|z) ) ] - log(n_z_samples)
              = log_sum_exp_z( \sum_t log p(y^t|z) ) - log(n_z_samples)

    When ``z_samples`` is given, each sample's term is multiplied by
    q(z|y_cntxt) / q(z|y_cntxt, y_trgt), i.e. log q(z|c) - log q(z|c,t) is added.

    Args:
        pred_mu: predictive means. Dim 0 indexes latent samples and dim 1 the
            batch; all remaining dims are summed over.
        pred_std: predictive standard deviations, broadcastable with
            ``pred_mu``; must be strictly positive.
        target_y: observed targets, broadcastable against ``pred_mu``.
        z_samples: optional latent samples whose log-probabilities must reduce to
            shape (n_z_samples, batch) after summing dims 2 onwards.
        qz_c: ``(loc, scale)`` of q(z | context); required with ``z_samples``.
        qz_ct: ``(loc, scale)`` of q(z | context, target); required with ``z_samples``.

    Returns:
        Scalar tensor: the negative log-likelihood averaged over the batch.

    Raises:
        ValueError: if ``z_samples`` is given without both ``qz_c`` and ``qz_ct``.
    """

    def sum_from_nth_dim(t, dim):
        """Sum over every dim from ``dim`` onwards, e.g. shape [2,3,4,5] with dim=2 -> [2,3]."""
        # reshape (not view) so non-contiguous log-prob tensors, e.g. from
        # transposed inputs, are accepted as well.
        return t.reshape(*t.shape[:dim], -1).sum(-1)

    def sum_log_prob(prob, sample):
        """Log-probability of ``sample`` under ``prob``, summed over all but the first two dims."""
        log_p = prob.log_prob(sample)        # size = [n_z_samples, batch_size, *]
        return sum_from_nth_dim(log_p, 2)    # size = [n_z_samples, batch_size]

    if z_samples is not None and (qz_c is None or qz_ct is None):
        raise ValueError("qz_c and qz_ct are required when z_samples is given")

    p_yCc = Normal(loc=pred_mu, scale=pred_std)
    n_z_samples = p_yCc.batch_shape[0]

    # \sum_t log p(y^t|z). size = [n_z_samples, batch_size]
    sum_log_w_k = sum_log_prob(p_yCc, target_y)

    if z_samples is not None:
        q_zCc = Normal(loc=qz_c[0], scale=qz_c[1])
        q_zCct = Normal(loc=qz_ct[0], scale=qz_ct[1])
        # Importance sampling: multiply \prod_t p(y^t|z) by
        # q(z|y_cntxt) / q(z|y_cntxt, y_trgt), i.e. add log q(z|c) - log q(z|c,t).
        # All latent dims are treated as independent (their log-probs are summed).
        sum_log_w_k = sum_log_w_k + sum_log_prob(q_zCc, z_samples) - sum_log_prob(q_zCct, z_samples)

    # log-mean-exp over the sample dim. logsumexp is already numerically stable;
    # adding a constant inside it would only shift the result, so none is added.
    log_E_z_sum_p_yCz = torch.logsumexp(sum_log_w_k, dim=0) - math.log(n_z_samples)

    # Negative log-likelihood, averaged over the batch.
    return -log_E_z_sum_p_yCz.mean()







In [27]:
# Toy inputs for comparing the two losses. Predictions are shaped
# (n_samples, n_batch, n_data, n_channel); targets omit the sample dim.
pmu = torch.randn(5, 16, 30, 3).mul(0.5)   # predictive means
pstd = torch.rand(5, 16, 30, 3).mul(0.1)   # predictive stds in [0, 0.1)
# NOTE(review): torch.rand can return exactly 0, which Normal rejects as a scale.
yobs = torch.randn(16, 30, 3)              # observed targets

In [28]:
# Evaluate both losses on the same toy data; with intrain=True they should agree.
(
    compute_loss_gp_origin(pmu, pstd, yobs),
    compute_loss_gp(pmu, pstd, yobs, intrain=True),
)

(tensor(275620.8438), tensor(275620.8438))