In [1]:
# %%
import time
import numpy as np
import pickle
from numpy.linalg import det

import CMINE_lib as CMINE
# from Guassian_variables import Data_guassian

import pandas as pd
from scipy.stats import multivariate_normal
import itertools

np.random.seed(37)
from scipy import stats
from sklearn.neighbors import KernelDensity

import math

import torch 
import torch.nn as nn

In [2]:
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation
    value.exp().sum(dim, keepdim).log()

    The maximum is subtracted before exponentiating and added back
    afterwards, so large entries do not overflow in ``exp``.

    Args:
        value: tensor to reduce.
        dim: dimension to reduce over. When ``None`` the reduction runs over
            every element and ``keepdim`` is ignored.
        keepdim: when ``dim`` is given, whether the reduced dimension is kept
            with size 1 in the result.

    Returns:
        Tensor holding ``log(sum(exp(value)))`` over the requested dimension
        (a 0-dim tensor when ``dim`` is ``None``).
    """
    # Imported here because the notebook's import cell never brings ``Number``
    # into scope; without it the ``dim=None`` branch raised NameError.
    from numbers import Number

    # NOTE: torch.max(value, dim=None) threw an error at time of writing,
    # hence the two separate branches.
    if dim is not None:
        m, _ = torch.max(value, dim=dim, keepdim=True)
        value0 = value - m
        if keepdim is False:
            # Drop the reduced dimension so ``m`` matches the shape of the sum.
            m = m.squeeze(dim)
        return m + torch.log(torch.sum(torch.exp(value0),
                                       dim=dim, keepdim=keepdim))
    else:
        m = torch.max(value)
        sum_exp = torch.sum(torch.exp(value - m))
        if isinstance(sum_exp, Number):
            # Kept for torch versions whose full reduction returns a Python number.
            return m + math.log(sum_exp)
        else:
            return m + torch.log(sum_exp)

class L1OutUB(nn.Module):  # naive upper bound
    """Gaussian-likelihood contrast between a full-input and a reduced-input predictor.

    Two pairs of small MLPs are held:

    * ``p_mu`` / ``p_logvar`` map the full ``x`` (``x_dim`` columns) to the
      mean and log-variance of a diagonal Gaussian over ``y``.
    * ``p_mu_neg`` / ``p_logvar_neg`` do the same from ``x`` with one column
      removed (``x_dim - 1`` columns). The same pair is shared across every
      choice of removed column.

    Both log-variance heads end in ``Tanh``, so log-variances lie in (-1, 1).

    Args:
        x_dim: number of columns of the conditioning input ``x``.
        y_dim: number of columns of the target ``y``.
        hidden_size: the hidden layers have ``hidden_size // 2`` units.
    """

    def __init__(self, x_dim, y_dim, hidden_size):
        super(L1OutUB, self).__init__()
        self.p_mu = nn.Sequential(nn.Linear(x_dim, hidden_size//2),
                                       nn.ReLU(),
                                       nn.Linear(hidden_size//2, y_dim))

        self.p_logvar = nn.Sequential(nn.Linear(x_dim, hidden_size//2),
                                       nn.ReLU(),
                                       nn.Linear(hidden_size//2, y_dim),
                                       nn.Tanh())

        self.p_mu_neg = nn.Sequential(nn.Linear(x_dim-1, hidden_size//2),
                                       nn.ReLU(),
                                       nn.Linear(hidden_size//2, y_dim))

        self.p_logvar_neg = nn.Sequential(nn.Linear(x_dim-1, hidden_size//2),
                                       nn.ReLU(),
                                       nn.Linear(hidden_size//2, y_dim),
                                       nn.Tanh())

    def get_mu_logvar(self, x_samples):
        """Return ``(mu, logvar)`` predicted from the full input ``x_samples``."""
        mu = self.p_mu(x_samples)
        logvar = self.p_logvar(x_samples)
        return mu, logvar

    def get_mu_logvar_neg(self, x_samples):
        """Return ``(mu, logvar)`` predicted from an input with one column removed."""
        mu = self.p_mu_neg(x_samples)
        logvar = self.p_logvar_neg(x_samples)
        return mu, logvar

    @staticmethod
    def _leave_one_out_inputs(x_samples):
        """Yield ``x_samples`` with column ``i`` removed, for ``i = 0 .. n_cols - 2``.

        The last column is never removed. The index tensor is created on the
        device of ``x_samples`` so the module works on CPU as well as on any GPU
        (the previous code forced the index onto the default CUDA device).
        """
        n_cols = x_samples.shape[1]
        for dropped in range(n_cols - 1):
            kept = [col for col in range(n_cols) if col != dropped]
            index = torch.tensor(kept, dtype=torch.long, device=x_samples.device)
            yield torch.index_select(x_samples, dim=1, index=index)

    def forward(self, x_samples, y_samples): # x_samples = s_t, a ; y_samples = s_{t+1}
        """Mean gap between full-input and reduced-input Gaussian log-scores.

        Args:
            x_samples: ``[nsample, x_dim]`` conditioning input.
            y_samples: ``[nsample, y_dim]`` target.

        Returns:
            0-dim tensor: the mean over samples and removed columns of
            ``positive - negative``, where both terms are the Gaussian
            log-density of ``y_samples`` up to an additive constant.
        """
        mu, logvar = self.get_mu_logvar(x_samples)
        positive = (- (mu - y_samples)**2 /2./logvar.exp() - logvar/2.).sum(dim = -1) #[nsample]

        per_column = []
        for x_temp in self._leave_one_out_inputs(x_samples):
            mu_neg, logvar_neg = self.get_mu_logvar_neg(x_temp)
            neg = (- (mu_neg - y_samples)**2 /2./logvar_neg.exp() - logvar_neg/2.).sum(dim = -1) #[nsample]
            per_column.append(neg)
        negative = torch.stack(per_column, dim=1)  # [nsample, x_dim - 1]

        return ( positive.unsqueeze(-1)- negative ).mean()

    def loglikeli(self, x_samples, y_samples):
        """Training objective of the full-input predictor.

        Returns the batch mean of ``sum(-(mu - y)^2 / exp(logvar) - logvar)``,
        i.e. twice the Gaussian log-density up to an additive constant.
        """
        # The inputs are only read, so no defensive clone / cache flush is needed.
        mu, logvar = self.get_mu_logvar(x_samples)
        return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0)

    def loglikeli_mask(self, x_samples, y_samples):
        """Training objective of the reduced-input predictor.

        Uses the same unnormalised score as ``loglikeli``, evaluated for each
        removed column, summed over removed columns and averaged over the batch.
        """
        per_column = []
        for x_temp in self._leave_one_out_inputs(x_samples):
            mu, logvar = self.get_mu_logvar_neg(x_temp)
            per_column.append((-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=-1)) #[nsample]
        negative = torch.stack(per_column, dim=1)  # [nsample, x_dim - 1]
        return negative.sum(dim=1).mean(dim=0)

    def learning_loss(self, x_samples, y_samples):
        """Loss to minimise: negated sum of both predictors' likelihood objectives."""
        return  - self.loglikeli_mask(x_samples, y_samples)  - self.loglikeli(x_samples, y_samples)

In [3]:
# Optional: uncomment the two lines below to restrict this process to a single GPU.
#import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '2'

# Enable cuDNN and its benchmark mode for the rest of the notebook.
# NOTE(review): the model defined in this notebook is built only from
# Linear/ReLU/Tanh layers, so these flags presumably have little effect here - confirm.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


In [13]:
# Draw one sample batch from the data-generating process and move it to the
# compute device, to inspect tensor shapes before training.
Dim = 5
# Fall back to CPU when no GPU is present instead of failing on .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pass ``Dim`` through rather than repeating the literal, so the two cannot drift apart.
dataset = CMINE.create_dataset_DGP(GenModel="", Params="", Dim=Dim, N=64)
s_t = torch.from_numpy(dataset[0]).float().to(device)     # current state
s_next = torch.from_numpy(dataset[1]).float().to(device)  # next state
a = torch.from_numpy(dataset[2]).float().to(device)       # action

In [14]:
# Shape check: the next-state batch, used later as the prediction target (batch_y).
s_next.shape

torch.Size([64, 10])

In [15]:
# Shape check: state and action concatenated column-wise, the model input (batch_x).
torch.cat([s_t,a], dim=1).shape

torch.Size([64, 11])

In [17]:
# %%
# Hyper-parameters. ``sample_dim`` is the width of s_t / s_next for the chosen ``Dim``.
sample_dim = 2*Dim
batch_size = 64
hidden_size = 15
learning_rate = 0.005
training_steps = 4000

cubic = False 

# %%
# Fall back to CPU when no GPU is present instead of failing on .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch as well (NumPy is seeded in the import cell) so the weight
# initialisation is reproducible across kernel restarts.
torch.manual_seed(37)

# The model input is [s_t, a], i.e. ``sample_dim`` state columns plus one action column.
model = L1OutUB(sample_dim + 1, sample_dim, hidden_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), learning_rate)

# %%
mi_est_values = []

# %%
for step in range(training_steps):
    # Draw a fresh batch every step; sizes follow ``Dim`` / ``batch_size``
    # instead of repeating the literals 5 and 64.
    dataset = CMINE.create_dataset_DGP(GenModel="", Params="", Dim=Dim, N=batch_size)
    s_t = torch.from_numpy(dataset[0]).float().to(device)
    s_next = torch.from_numpy(dataset[1]).float().to(device)
    a = torch.from_numpy(dataset[2]).float().to(device)

    batch_x = torch.cat([s_t,a], dim=1)
    batch_y = s_next

    # Record the current estimate. No gradients are needed for this pass,
    # so skip building an autograd graph for it.
    model.eval()
    with torch.no_grad():
        cmi = model(batch_x, batch_y).item()
    mi_est_values.append(cmi)

    model.train() 

    model_loss = model.learning_loss(batch_x, batch_y)

    optimizer.zero_grad()
    # Each step builds a new graph and backpropagates once, so the graph does
    # not need to be retained (retain_graph=True only held on to memory).
    model_loss.backward()
    optimizer.step()
    # Per-step torch.cuda.empty_cache() was dropped: the tensors are rebound on
    # the next iteration and flushing the allocator cache each step is slow.

#print("finish training for %s with true MI value = %f"%('LOO', 6.0))
print(np.array(mi_est_values).mean())

72.55458068847656
-9.276410102844238
62.727333068847656
-25.810632705688477
0.3335517942905426
-7.680867671966553
-34.82990646362305
-0.19748543202877045
16.214752197265625
-4.186553955078125
-1.5425317287445068
-16.854326248168945
2.8642261028289795
0.9347788095474243
-0.3923979103565216
-25.0108699798584
31.083845138549805
1400.1199951171875
1.8318172693252563
6840.84521484375
151.7137451171875
-334.0803527832031
22.699750900268555
2.9017765522003174
-67.77193450927734
-57.051727294921875
24.697229385375977
-3.0523264408111572
2420.28759765625
-7.352047920227051
1434.9871826171875
-13.111729621887207
-253.88023376464844
156188.984375
0.21841326355934143
7.732656002044678
0.4668305516242981
-48112928.0
1.1738598346710205
-10.919944763183594
-3.241413116455078
-14.684198379516602
-3.191225051879883
4384.7978515625
2.213449239730835
72.15056610107422
-2.7987749576568604
1.6570396423339844
-274302.28125
-0.7642926573753357
-118.3180923461914
1.8192790746688843
-1395826.625
-126.451637268

-6.107605934143066
151.75003051757812
1326.2083740234375
-299.0632629394531
-236.990478515625
-15.88261604309082
-100.58842468261719
4.0411553382873535


KeyboardInterrupt: 