In [1]:
# %%
import time
import numpy as np
import pickle
from numpy.linalg import det

import CMINE_lib as CMINE
# from Guassian_variables import Data_guassian

import pandas as pd
from scipy.stats import multivariate_normal
import itertools

np.random.seed(37)
from scipy import stats
from sklearn.neighbors import KernelDensity

import math

import torch 
import torch.nn as nn
import torch.nn.functional as F

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


In [2]:
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable ``value.exp().sum(dim, keepdim).log()``.

    Args:
        value: input tensor.
        dim: dimension (or tuple of dimensions) to reduce over. ``None``
            reduces over every element and returns a 0-dim tensor.
        keepdim: keep the reduced dimension(s) with size 1. Ignored when
            ``dim`` is None (the whole-tensor result is always 0-dim), as in
            the previous implementation.

    Returns:
        Tensor holding the log-sum-exp over the requested dimension(s).
    """
    if dim is None:
        # Flatten so one reduction covers every element; the result is 0-dim.
        # (The previous version crashed here with NameError on `Number`.)
        return torch.logsumexp(value.reshape(-1), dim=0)
    # torch.logsumexp applies the max-shift internally and returns -inf
    # (instead of nan) when every entry along `dim` is -inf.
    return torch.logsumexp(value, dim=dim, keepdim=keepdim)

In [3]:
# cuDNN backend settings. benchmark=True lets cuDNN autotune kernels per input
# shape: faster, but it may pick nondeterministic algorithms. Set it to False
# (and torch.backends.cudnn.deterministic = True) when bit-exact reruns matter.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# The imports cell seeds only NumPy (np.random.seed(37)). Seed torch too, so
# torch-side randomness (e.g. DR_CMI / CDL_CMI weight init) is reproducible.
torch.manual_seed(37)

In [4]:
from modules.CMI import DR_CMI, CDL_CMI

In [5]:
Dim = 5
batch_size = 64
# One example batch from the binary-action data-generating process.
# Pass `Dim` (not a literal 5) so this batch stays consistent with
# sample_dim = 2*Dim below if Dim is ever changed.
dataset = CMINE.create_dataset_DGP_binary_A(Dim=Dim, N=batch_size)
# Entries 0-2 are NumPy arrays (required by torch.from_numpy).
# NOTE(review): presumably (s_t, s_next, a); confirm in CMINE_lib.
s_t = torch.from_numpy(dataset[0]).float().cuda()
s_next = torch.from_numpy(dataset[1]).float().cuda()
a = torch.from_numpy(dataset[2]).float().cuda()
# These module-level tensors are not read by train_dr / train_cdl, which draw
# a fresh batch every training step.

In [6]:
# Model / optimiser settings read (as globals) by train_dr and train_cdl.
# Both estimators are built as Model(sample_dim + 1, sample_dim, hidden_size).
# NOTE(review): presumably x = [s_t, a] with a one-column action; confirm.
sample_dim = Dim * 2

hidden_size, learning_rate = 15, 5e-3
# Not read by train_dr / train_cdl: they take their own `training_steps`
# argument (default 10).
training_steps = 10

cubic = False  # only referenced by commented-out code in this notebook


In [7]:
def train_dr(N = 64, training_steps = 10, device = "cuda"):
    """Train a fresh DR_CMI estimator on freshly sampled batches.

    Each step draws a new batch of ``N`` samples with
    ``CMINE.create_dataset_DGP_binary_A``. It then records the model's
    estimate on that batch (eval mode, no autograd) *before* taking one Adam
    step on ``model_dr.learning_loss``.

    Args:
        N: number of samples drawn per training step.
        training_steps: number of sample / estimate / update iterations.
        device: torch device for the model and batches. The default "cuda"
            matches the previous hard-coded ``.cuda()`` calls.

    Returns:
        list with one entry per step: whatever ``model_dr(batch_x, batch_y)``
        returns. Its type and shape are defined in ``modules.CMI``.
    """
    torch.cuda.empty_cache()
    model_dr = DR_CMI(sample_dim + 1, sample_dim, hidden_size).to(device)
    optimizer_dr = torch.optim.Adam(model_dr.parameters(), learning_rate)
    dr_est_values = []
    for _ in range(training_steps):
        dataset = CMINE.create_dataset_DGP_binary_A(Dim=Dim, N=N)
        s_t = torch.from_numpy(dataset[0]).float().to(device)
        s_next = torch.from_numpy(dataset[1]).float().to(device)
        a = torch.from_numpy(dataset[2]).float().to(device)

        batch_x = torch.cat([s_t, a], dim=1)
        batch_y = s_next

        # Log the estimate for the current parameters. It is never
        # back-propagated, so skip building an autograd graph for it.
        model_dr.eval()
        with torch.no_grad():
            dr_est_values.append(model_dr(batch_x, batch_y))
        model_dr.train()

        model_loss = model_dr.learning_loss(batch_x, batch_y)

        optimizer_dr.zero_grad()
        # NOTE(review): retain_graph=True is kept from the original. It looks
        # unnecessary for a freshly computed loss, but may be needed if
        # learning_loss reuses cached tensors; confirm in modules.CMI first.
        model_loss.backward(retain_graph=True)
        optimizer_dr.step()

        # Release every per-step tensor (not only the batch) so that
        # empty_cache can actually hand the memory back.
        del dataset, s_t, s_next, a, batch_x, batch_y, model_loss
        torch.cuda.empty_cache()
    return dr_est_values

In [8]:
def train_cdl(N = 64, training_steps = 10, device = "cuda"):
    """Train a fresh CDL_CMI estimator on freshly sampled batches.

    Each step draws a new batch of ``N`` samples with
    ``CMINE.create_dataset_DGP_binary_A``. It then records the model's
    estimate on that batch (eval mode, no autograd) *before* taking one Adam
    step on ``model_cdl.learning_loss``.

    Args:
        N: number of samples drawn per training step.
        training_steps: number of sample / estimate / update iterations.
        device: torch device for the model and batches. The default "cuda"
            matches the previous hard-coded ``.cuda()`` calls.

    Returns:
        list with one entry per step: whatever ``model_cdl(batch_x, batch_y)``
        returns. Its type and shape are defined in ``modules.CMI``.
    """
    torch.cuda.empty_cache()
    model_cdl = CDL_CMI(sample_dim + 1, sample_dim, hidden_size).to(device)
    optimizer_cdl = torch.optim.Adam(model_cdl.parameters(), learning_rate)
    cdl_est_values = []
    for _ in range(training_steps):
        dataset = CMINE.create_dataset_DGP_binary_A(Dim=Dim, N=N)
        s_t = torch.from_numpy(dataset[0]).float().to(device)
        s_next = torch.from_numpy(dataset[1]).float().to(device)
        a = torch.from_numpy(dataset[2]).float().to(device)

        batch_x = torch.cat([s_t, a], dim=1)
        batch_y = s_next

        # Log the estimate for the current parameters. It is never
        # back-propagated, so skip building an autograd graph for it.
        model_cdl.eval()
        with torch.no_grad():
            cdl_est_values.append(model_cdl(batch_x, batch_y))
        model_cdl.train()

        model_loss = model_cdl.learning_loss(batch_x, batch_y)

        optimizer_cdl.zero_grad()
        # NOTE(review): retain_graph=True is kept from the original. It looks
        # unnecessary for a freshly computed loss, but may be needed if
        # learning_loss reuses cached tensors; confirm in modules.CMI first.
        model_loss.backward(retain_graph=True)
        optimizer_cdl.step()

        # Release every per-step tensor (not only the batch) so that
        # empty_cache can actually hand the memory back.
        del dataset, s_t, s_next, a, batch_x, batch_y, model_loss
        torch.cuda.empty_cache()
    return cdl_est_values

In [9]:
# One CDL run with the function defaults spelled out (N=64, 10 steps).
cdl_est_values = train_cdl(N=64, training_steps=10)


In [10]:
# One DR run with the function defaults spelled out (N=64, 10 steps).
dr_est_values = train_dr(N=64, training_steps=10)

In [11]:
N = 64
# Same batch size for both estimators (CDL trained first, then DR).
cdl_est_values, dr_est_values = train_cdl(N), train_dr(N)
# Mean of the per-step estimates over the training steps (axis 0).
print(np.mean(np.array(cdl_est_values), axis=0))
print(np.mean(np.array(dr_est_values), axis=0))

[212.43410855 181.03487066  13.88867946   8.90139438 149.91577675
  20.48665845   0.3961755   33.26350302   1.53078866 288.94831004]
[2.65716080e+01 2.92947355e+01 2.83777408e+01 2.82387542e+01
 2.87403732e+01 4.93031977e+02 8.13965871e+02 1.02927417e+06
 1.13731133e+03 1.45716088e+03]


In [12]:
N = 128
# Same batch size for both estimators (CDL trained first, then DR).
cdl_est_values, dr_est_values = train_cdl(N), train_dr(N)
# Mean of the per-step estimates over the training steps (axis 0).
print(np.mean(np.array(cdl_est_values), axis=0))
print(np.mean(np.array(dr_est_values), axis=0))

[2.09376884e+00 5.90993278e-01 8.88056797e+00 1.75195472e+00
 2.58970748e+03 4.46875074e+00 1.19108975e+06 1.93920223e+01
 1.19205989e+01 1.11547462e+02]
[1.28506732e+00 1.27196821e+00 1.24692049e+00 4.50321915e+00
 2.09011867e+00 7.78077734e+01 2.77490046e+02 3.17608686e+02
 1.39276605e+06 6.60993560e+02]


In [13]:
N = 32
# Same batch size for both estimators (CDL trained first, then DR).
cdl_est_values, dr_est_values = train_cdl(N), train_dr(N)
# Mean of the per-step estimates over the training steps (axis 0).
print(np.mean(np.array(cdl_est_values), axis=0))
print(np.mean(np.array(dr_est_values), axis=0))

[0.06895235 0.02756502 0.397309   0.65182259 0.59683229 0.10727125
 0.6785519  4.84395778 1.47217493 0.16290962]
[1.65885800e+00 6.95902781e-01 7.92573906e-01 5.93360118e+00
 2.34752648e+00 8.27147290e+01 1.73405988e+01 9.22411028e+00
 1.84655191e+01 7.90304679e+04]


In [14]:
N = 16
# Same batch size for both estimators (CDL trained first, then DR).
cdl_est_values, dr_est_values = train_cdl(N), train_dr(N)
# Mean of the per-step estimates over the training steps (axis 0).
print(np.mean(np.array(cdl_est_values), axis=0))
print(np.mean(np.array(dr_est_values), axis=0))

[1.70686008e+00 1.02441696e+00 1.27264025e+01 3.40521445e+02
 1.23847237e+00 3.20266196e+00 2.09768737e-01 7.82289181e+00
 1.03181786e+03 9.36735677e-01]
[1.02772041e+02 7.65122229e+03 1.60980162e+07 1.28226557e+02
 3.35852706e+05 3.19404235e-01 1.19249740e+00 3.33773376e-01
 6.57454240e-01 7.88449046e+00]


In [15]:
N = 8
# Same batch size for both estimators (CDL trained first, then DR).
cdl_est_values, dr_est_values = train_cdl(N), train_dr(N)
# Mean of the per-step estimates over the training steps (axis 0).
print(np.mean(np.array(cdl_est_values), axis=0))
print(np.mean(np.array(dr_est_values), axis=0))

[4.87663140e-01 5.57553836e+03 6.75505167e+00 7.91179832e-02
 6.91325274e+01 1.57431246e+02 4.14954104e+01 1.50357679e+00
 1.36824509e+02 8.68369331e+00]
[5.13608319e-01 5.80961051e-01 9.29477091e-01 2.67728226e+02
 4.31119307e-01 3.23346837e+03 3.98408498e+02 1.82239026e+02
 1.24588287e+03 8.38565530e+03]


# Effect of the number of training steps

In [15]:
N, training_step = 64, 100
# Longer run: 100 sample/estimate/update iterations per estimator.
cdl_est_values = train_cdl(N, training_steps=training_step)
dr_est_values = train_dr(N, training_steps=training_step)
print(np.mean(np.array(cdl_est_values), axis=0))
print(np.mean(np.array(dr_est_values), axis=0))

[7.90087952e+02 7.85717166e+02 8.86630148e+02 7.85646969e+02
 1.60624307e+03 6.54722027e+07 4.60514869e+05 8.44506005e+04
 2.02980041e+05 5.45358667e+05]
[1.26068347e+07 6.77329778e+11 1.26400620e+07 3.45984030e+07
 1.27147874e+07 1.91223624e+15 2.71988959e+09 3.92338222e+12
 4.32217656e+14 1.64770941e+11]


In [14]:
N, training_step = 32, 100
# Longer run: 100 sample/estimate/update iterations per estimator.
cdl_est_values = train_cdl(N, training_steps=training_step)
dr_est_values = train_dr(N, training_steps=training_step)
print(np.mean(np.array(cdl_est_values), axis=0))
print(np.mean(np.array(dr_est_values), axis=0))

[4.18966604e+00 3.15982827e+06 8.61560204e+02 2.23745444e+01
 2.57475600e+04 2.69649200e+05 7.87265592e+02 6.23005884e+03
 7.46181203e+04 1.20451452e+03]
[2.80374949e+10 3.08702256e+10 2.81252140e+10 2.81249405e+10
 2.81251843e+10 2.43591543e+13 2.16347626e+13 2.92307035e+19
 3.14428485e+12 7.73048805e+13]


In [19]:
# Mean over the last 50 recorded steps only (the second half of a 100-step run).
# NOTE(review): in file order this follows the N = 32 run, but the saved output
# appears to come from the N = 64 run (execution counts are out of order).
# Re-run top to bottom before relying on it.
print(np.mean(cdl_est_values[-50:], axis=0))
print(np.mean(dr_est_values[-50:], axis=0))

[1.57217472e+03 1.57109104e+03 1.77133819e+03 1.57102237e+03
 1.81083862e+03 1.30944404e+08 9.21023368e+05 1.68589411e+05
 4.05958852e+05 1.06707031e+06]
[2.51348668e+07 1.35465464e+12 2.52112564e+07 2.52614051e+07
 2.53483606e+07 3.82447132e+15 5.32261770e+09 7.53837592e+12
 8.64435300e+14 3.29537708e+11]


In [31]:
# NOTE(review): 526 breaks the power-of-two pattern of the other batch sizes
# (8 ... 128); possibly a typo for 512. Confirm. Kept so results match.
N, training_step = 526, 100
cdl_est_values = train_cdl(N, training_steps=training_step)
dr_est_values = train_dr(N, training_steps=training_step)
print(np.mean(np.array(cdl_est_values), axis=0))
print(np.mean(np.array(dr_est_values), axis=0))

[1.50006065e+01 2.18871596e+02 1.10105989e+07 5.11114454e+02
 1.06051950e+02 9.00031509e+01 1.04145438e+04 3.52512225e+03
 5.52125692e+03 2.41971664e+04]
[7.68950225e+06 7.74432824e+06 7.91731634e+06 7.70377130e+06
 7.70934032e+06 7.88427070e+08 4.68894041e+16 1.52884448e+10
 1.34091251e+11 2.40433729e+10]


In [30]:
# NOTE(review): 526 breaks the power-of-two pattern of the other batch sizes
# (8 ... 128); possibly a typo for 512. Confirm. Kept so results match.
N, training_step = 526, 10
cdl_est_values = train_cdl(N, training_steps=training_step)
dr_est_values = train_dr(N, training_steps=training_step)
print(np.mean(np.array(cdl_est_values), axis=0))
print(np.mean(np.array(dr_est_values), axis=0))

[2.41623130e-01 5.83663910e-01 2.62432172e-01 1.86141613e-01
 2.88988191e+02 2.00060656e+01 5.57397880e+00 2.69901884e+01
 2.13138771e-01 1.53580540e+00]
[1.34774063e+06 7.49846064e+03 7.87508115e+03 7.57209025e+03
 7.50388969e+03 5.57463735e+05 1.85607135e+04 3.51530271e+09
 1.84042864e+05 2.73516473e+06]
