In [1]:
# %%
import time
import numpy as np
import pickle
from numpy.linalg import det

import CMINE_lib as CMINE
# from Guassian_variables import Data_guassian

import pandas as pd
from scipy.stats import multivariate_normal
import itertools

np.random.seed(37)
from scipy import stats
from sklearn.neighbors import KernelDensity

import math

import torch 
import torch.nn as nn
import torch.nn.functional as F

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'


In [2]:
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation
    value.exp().sum(dim, keepdim).log()
    """
    # TODO: torch.max(value, dim=None) threw an error at time of writing
    if dim is not None:
        m, _ = torch.max(value, dim=dim, keepdim=True)
        value0 = value - m
        if keepdim is False:
            m = m.squeeze(dim)
        return m + torch.log(torch.sum(torch.exp(value0),
                                       dim=dim, keepdim=keepdim))
    else:
        m = torch.max(value)
        sum_exp = torch.sum(torch.exp(value - m))
        if isinstance(sum_exp, Number):
            return m + math.log(sum_exp)
        else:
            return m + torch.log(sum_exp)

In [3]:
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

In [4]:
from modules.CMI import DR_CMI, CDL_CMI

In [5]:
Dim = 5
batch_size = 64
#dataset = CMINE.create_dataset_DGP( Dim=5, N=batch_size)
dataset = CMINE.create_dataset_DGP_binary_A( Dim=5, N=batch_size)
s_t = torch.from_numpy(dataset[0]).float().cuda()
s_next = torch.from_numpy(dataset[1]).float().cuda()
a = torch.from_numpy(dataset[2]).float().cuda()

In [6]:
sample_dim = 2*Dim

hidden_size = 15
learning_rate = 0.005
training_steps = 10

cubic = False 


In [24]:
def train_dr(N = 64, training_steps = 10):
    torch.cuda.empty_cache()
    model_dr = DR_CMI(sample_dim + 1, sample_dim, hidden_size).cuda()
    optimizer_dr = torch.optim.Adam(model_dr.parameters(), learning_rate)
    dr_est_values = []
    for step in range(training_steps):
        #batch_x, batch_y = sample_correlated_gaussian(rho, dim=sample_dim, batch_size = batch_size, to_cuda = True, cubic = cubic)
        dataset = CMINE.create_dataset_DGP_binary_A(Dim=Dim, N=N)
        s_t = torch.from_numpy(dataset[0]).float().cuda()
        s_next = torch.from_numpy(dataset[1]).float().cuda()
        a = torch.from_numpy(dataset[2]).float().cuda()

        batch_x = torch.cat([s_t,a], dim=1)
        batch_y = s_next
        model_dr.eval()
        drs = model_dr(batch_x, batch_y)
        #mi_est_values.append(cmi)
        dr_est_values.append(drs)
        model_dr.train() 

        model_loss = model_dr.learning_loss(batch_x, batch_y)

        optimizer_dr.zero_grad()
        model_loss.backward(retain_graph=True)
        optimizer_dr.step()

        del batch_x, batch_y
        torch.cuda.empty_cache()
    return dr_est_values

In [25]:
def train_cdl(N = 64, training_steps = 10):
    torch.cuda.empty_cache()
    model_cdl = CDL_CMI(sample_dim + 1, sample_dim, hidden_size).cuda()
    optimizer_cdl = torch.optim.Adam(model_cdl.parameters(), learning_rate)
    cdl_est_values = []
    for step in range(training_steps):
        #batch_x, batch_y = sample_correlated_gaussian(rho, dim=sample_dim, batch_size = batch_size, to_cuda = True, cubic = cubic)
        dataset = CMINE.create_dataset_DGP_binary_A(Dim=Dim, N=64)
        s_t = torch.from_numpy(dataset[0]).float().cuda()
        s_next = torch.from_numpy(dataset[1]).float().cuda()
        a = torch.from_numpy(dataset[2]).float().cuda()

        batch_x = torch.cat([s_t,a], dim=1)
        batch_y = s_next
        model_cdl.eval()
        cdl_cmi = model_cdl(batch_x, batch_y)
        cdl_est_values.append(cdl_cmi)
        model_cdl.train() 

        model_loss = model_cdl.learning_loss(batch_x, batch_y)

        optimizer_cdl.zero_grad()
        model_loss.backward(retain_graph=True)
        optimizer_cdl.step()

        del batch_x, batch_y
        torch.cuda.empty_cache()
    return cdl_est_values

In [15]:
cdl_est_values = train_cdl()


In [17]:
dr_est_values = train_dr()

In [18]:
print(np.array(cdl_est_values).mean(axis=0))
print(np.array(dr_est_values).mean(axis=0))

[31.62909881 31.90293879  3.32989618  2.37203364 40.20621789  2.33121245
  0.10007804  4.64158957  0.17156096 34.20440767]
[1.67182227e+01 2.13185132e+01 1.74296744e+01 1.64255251e+01
 1.87672516e+01 1.52023465e+03 2.41713650e+04 5.81954119e+07
 1.12470102e+04 1.75729985e+04]


In [19]:
N = 128
cdl_est_values = train_cdl(N)
dr_est_values = train_dr(N)
print(np.array(cdl_est_values).mean(axis=0))
print(np.array(dr_est_values).mean(axis=0))

[7.30239090e+00 7.05152936e+00 8.23540059e+00 5.79760796e+00
 7.20315365e+00 2.12945217e+05 1.50455859e+02 2.78672594e+02
 2.27193331e+02 7.37535593e+03]
[6.10654989e+03 1.55910844e+05 3.77279128e+04 1.58046943e+11
 3.84559755e+07 2.88183968e+09 1.70990210e+06 6.64867495e+05
 7.54625585e+06 3.09346798e+11]


In [21]:
N = 32
cdl_est_values = train_cdl(N)
dr_est_values = train_dr(N)
print(np.array(cdl_est_values).mean(axis=0))
print(np.array(dr_est_values).mean(axis=0))

[ 2.22894816  0.53415195  0.17980211  0.13757467  0.99465214 65.07050759
  4.91550556 16.42847439  3.3649228   1.53744228]
[5.39664560e+03 5.35883037e+03 5.52185233e+03 5.56562032e+03
 5.17552405e+03 1.64288803e+05 5.02380312e+04 5.52713212e+07
 2.11297229e+10 1.16834368e+08]


In [23]:
N = 16
cdl_est_values = train_cdl(N)
dr_est_values = train_dr(N)
print(np.array(cdl_est_values).mean(axis=0))
print(np.array(dr_est_values).mean(axis=0))

[4.78701690e+00 9.01703398e+02 7.79684412e+00 5.26387394e+00
 1.29594241e+01 1.77291970e+06 8.20365912e+01 9.48841043e+01
 1.01327541e+03 1.45746310e+02]
[1.51966064e+05 2.71351952e-01 7.09830511e-01 6.07903015e-01
 5.05081787e+00 1.48763132e-01 7.62969087e-01 4.27091030e-01
 2.83229789e+01 2.05515325e+02]


# Training Step

In [26]:
N = 64
training_step = 100
cdl_est_values = train_cdl(N, training_step)
dr_est_values = train_dr(N, training_step)
print(np.array(cdl_est_values).mean(axis=0))
print(np.array(dr_est_values).mean(axis=0))

[6.31943605e+04 1.71830261e+08 6.60621143e+04 1.83854909e+05
 6.29557521e+04 3.95083027e+09 5.30933075e+05 3.03847897e+05
 3.01877772e+06 9.41156254e+05]
[1.45674717e+05 5.62305784e+05 1.31564423e+05 1.34648689e+05
 1.32427251e+06 2.39606343e+09 1.46911331e+14 6.66713691e+09
 1.28531635e+08 2.40902858e+09]


In [27]:
N = 32
training_step = 100
cdl_est_values = train_cdl(N, training_step)
dr_est_values = train_dr(N, training_step)
print(np.array(cdl_est_values).mean(axis=0))
print(np.array(dr_est_values).mean(axis=0))

[1.68653295e+04 8.02705649e+00 1.19907736e+03 3.14780780e+01
 3.98932782e+01 8.38730004e+05 3.92318122e+03 6.95953489e+04
 1.22019822e+05 3.48546747e+06]
[9.74688966e+03 7.74597234e+03 2.96815940e+09 6.31418542e+03
 6.20089638e+03 1.22671743e+06 3.68697284e+09 4.77083438e+06
 1.44713419e+07 1.88002237e+07]
