In [1]:
# %%
import time
import numpy as np
import pickle
from numpy.linalg import det

import CMINE_lib as CMINE
# from Guassian_variables import Data_guassian
from structurerl import * 

import pandas as pd
from scipy.stats import multivariate_normal
import itertools

np.random.seed(37)
from scipy import stats
from sklearn.neighbors import KernelDensity

import math

import torch 
import torch.nn as nn
import torch.nn.functional as F

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'


In [2]:
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable ``value.exp().sum(dim, keepdim).log()``.

    Delegates to :func:`torch.logsumexp`, which subtracts the maximum
    internally. It also returns ``-inf`` (not ``nan``) for slices whose
    entries are all ``-inf``.

    Args:
        value: input tensor.
        dim: dimension to reduce over. If ``None``, reduce over every element
            and return a 0-d tensor; ``keepdim`` is ignored in that case.
        keepdim: whether to keep the reduced dimension (only used when
            ``dim`` is given).

    Returns:
        Tensor holding the log-sum-exp reduction.
    """
    if dim is None:
        # torch.logsumexp needs an explicit dim, so flatten to reduce over all
        # elements at once.
        return torch.logsumexp(value.reshape(-1), dim=0)
    return torch.logsumexp(value, dim=dim, keepdim=keepdim)

In [3]:
# Enable cuDNN and its benchmark (autotuning) mode in a single chained
# assignment. Per PyTorch's reproducibility notes, benchmark mode may select
# different kernels from run to run.
torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True

In [4]:
from modules.CMI import DR_CMI, CDL_CMI

In [5]:
# Global problem size, plus one example batch from the data-generating process.
Dim = 5
batch_size = 64
# Use the Dim constant rather than a literal so this batch stays consistent
# with the training helpers below, which also pass Dim=Dim.
dataset = create_dataset_DGP_binary_A_conf(Dim=Dim, N=batch_size)
s_t = torch.from_numpy(dataset[0]).float().cuda()
s_next = torch.from_numpy(dataset[1]).float().cuda()
a = torch.from_numpy(dataset[2]).float().cuda()

In [6]:
# Estimator hyper-parameters shared by the training helpers below.
hidden_size = 15
learning_rate = 0.005
training_steps = 10  # module-level default; the helpers also take it as an argument

# Passed as the second constructor argument (and +1 as the first) of DR_CMI / CDL_CMI.
sample_dim = Dim * 2

# Only referenced by the commented-out sampler calls in this notebook.
cubic = False


In [7]:
def train_dr(N = 64, training_steps = 10):
    """Train a fresh DR_CMI estimator on freshly sampled batches.

    Each step draws a new batch of ``N`` samples from
    ``create_dataset_DGP_binary_A_conf``, records the model's output on that
    batch *before* updating, then takes one Adam step on
    ``model_dr.learning_loss``.

    Args:
        N: number of samples drawn per step.
        training_steps: number of sample / estimate / update iterations.

    Returns:
        List with one ``model_dr(batch_x, batch_y)`` output per step, in order.
    """
    torch.cuda.empty_cache()
    model_dr = DR_CMI(sample_dim + 1, sample_dim, hidden_size).cuda()
    optimizer_dr = torch.optim.Adam(model_dr.parameters(), learning_rate)
    dr_est_values = []
    for _ in range(training_steps):
        dataset = create_dataset_DGP_binary_A_conf(Dim=Dim, N=N)
        s_t = torch.from_numpy(dataset[0]).float().cuda()
        s_next = torch.from_numpy(dataset[1]).float().cuda()
        a = torch.from_numpy(dataset[2]).float().cuda()

        # Conditioning input: current state concatenated with the action.
        batch_x = torch.cat([s_t, a], dim=1)
        batch_y = s_next

        # The estimate is only logged, so skip autograd. Otherwise, if the
        # forward returns a tensor, its graph would be kept alive in
        # dr_est_values for every step.
        model_dr.eval()
        with torch.no_grad():
            dr_est_values.append(model_dr(batch_x, batch_y))
        model_dr.train()

        model_loss = model_dr.learning_loss(batch_x, batch_y)
        optimizer_dr.zero_grad()
        # Each step builds a fresh graph and backprops through it once, so
        # there is no need to retain it.
        model_loss.backward()
        optimizer_dr.step()
    return dr_est_values

In [8]:
def train_cdl(N = 64, training_steps = 10):
    """Train a fresh CDL_CMI estimator on freshly sampled batches.

    Each step draws a new batch of ``N`` samples from
    ``create_dataset_DGP_binary_A_conf``, records the model's output on that
    batch *before* updating, then takes one Adam step on
    ``model_cdl.learning_loss``.

    Args:
        N: number of samples drawn per step.
        training_steps: number of sample / estimate / update iterations.

    Returns:
        List with one ``model_cdl(batch_x, batch_y)`` output per step, in order.
    """
    torch.cuda.empty_cache()
    model_cdl = CDL_CMI(sample_dim + 1, sample_dim, hidden_size).cuda()
    optimizer_cdl = torch.optim.Adam(model_cdl.parameters(), learning_rate)
    cdl_est_values = []
    for _ in range(training_steps):
        dataset = create_dataset_DGP_binary_A_conf(Dim=Dim, N=N)
        s_t = torch.from_numpy(dataset[0]).float().cuda()
        s_next = torch.from_numpy(dataset[1]).float().cuda()
        a = torch.from_numpy(dataset[2]).float().cuda()

        # Conditioning input: current state concatenated with the action.
        batch_x = torch.cat([s_t, a], dim=1)
        batch_y = s_next

        # The estimate is only logged, so skip autograd. Otherwise, if the
        # forward returns a tensor, its graph would be kept alive in
        # cdl_est_values for every step.
        model_cdl.eval()
        with torch.no_grad():
            cdl_est_values.append(model_cdl(batch_x, batch_y))
        model_cdl.train()

        model_loss = model_cdl.learning_loss(batch_x, batch_y)
        optimizer_cdl.zero_grad()
        # Each step builds a fresh graph and backprops through it once, so
        # there is no need to retain it.
        model_loss.backward()
        optimizer_cdl.step()
    return cdl_est_values

In [17]:
# N=64 samples per step, 10 steps per estimator. Prints the mean of the
# per-step estimates: CDL first, then DR.
N = 64
cdl_est_values = train_cdl(N, training_steps=10)
dr_est_values = train_dr(N, training_steps=10)
print(
    np.array(cdl_est_values).mean(axis=0),
    np.array(dr_est_values).mean(axis=0),
    sep="\n",
)

[7.74717566e-01 2.00575563e+00 3.04663280e-01 4.24196271e-01
 1.52030621e+03 2.09584485e+02 5.29245237e+01 4.57642656e+02
 9.48272430e+02 1.85075029e+06]
[9.63950933e+03 9.57582094e+03 9.60351382e+03 9.70514245e+03
 9.70159701e+03 2.22066365e+05 1.86636486e+09 5.98517839e+05
 5.64498476e+05 1.60611569e+05]


In [14]:
# N=128 samples per step, 10 steps per estimator. Prints the mean of the
# per-step estimates: CDL first, then DR.
N = 128
cdl_est_values = train_cdl(N, training_steps=10)
dr_est_values = train_dr(N, training_steps=10)
print(
    np.array(cdl_est_values).mean(axis=0),
    np.array(dr_est_values).mean(axis=0),
    sep="\n",
)

[9.39067727e+01 2.60213449e+01 1.52895031e+02 5.26721411e+01
 1.00172588e+04 6.40147732e+01 2.38063097e+05 2.71312666e+00
 2.75721590e+02 1.05767696e+02]
[6.10599471e+05 2.37744361e+03 3.54157147e+03 1.06949477e+04
 1.04649346e+03 9.93568372e+04 1.59619324e+05 8.37194225e+06
 2.27745039e+06 1.29276019e+10]


In [12]:
# N=32 samples per step, 10 steps per estimator. Prints the mean of the
# per-step estimates: CDL first, then DR.
N = 32
cdl_est_values = train_cdl(N, training_steps=10)
dr_est_values = train_dr(N, training_steps=10)
print(
    np.array(cdl_est_values).mean(axis=0),
    np.array(dr_est_values).mean(axis=0),
    sep="\n",
)

[5.64183102e+01 5.93802901e+02 5.16957291e-01 2.50842905e+00
 3.78324551e-01 1.28835758e+04 4.67217801e+01 9.44310241e+00
 1.73909314e+03 5.21904449e+00]
[9.88784355e+02 7.33353794e+00 1.73287425e+01 1.50396905e+01
 2.97971128e+01 8.20206830e+05 8.10314223e+03 1.82840198e+06
 1.21135296e+08 6.06468435e+05]


In [9]:
# N=16 samples per step, 10 steps per estimator. Prints the mean of the
# per-step estimates: CDL first, then DR.
N = 16
cdl_est_values = train_cdl(N, training_steps=10)
dr_est_values = train_dr(N, training_steps=10)
print(
    np.array(cdl_est_values).mean(axis=0),
    np.array(dr_est_values).mean(axis=0),
    sep="\n",
)

[31.81580292  3.1340611   0.58114218  0.13696698  2.1787725   0.49664391
 10.07822349  0.96736877  1.05066821  0.55215529]
[1.82377494e+04 1.12634704e+02 1.14415522e+02 7.46430986e+02
 6.43738888e+05 2.67045332e+04 1.97531693e+07 3.11909067e+04
 1.36134188e+05 7.97366729e+02]


In [10]:
# N=8 samples per step, 10 steps per estimator. Prints the mean of the
# per-step estimates: CDL first, then DR.
N = 8
cdl_est_values = train_cdl(N, training_steps=10)
dr_est_values = train_dr(N, training_steps=10)
print(
    np.array(cdl_est_values).mean(axis=0),
    np.array(dr_est_values).mean(axis=0),
    sep="\n",
)

[2.59658664e+00 1.60433099e-01 1.24890606e-01 1.07405156e-01
 9.51872448e-02 8.13984725e+02 7.47336764e-01 1.68206395e+01
 1.78448651e+00 4.12859640e-01]
[3.09697606e-01 8.58423150e-01 1.11671200e+02 1.37955233e+00
 2.34826639e-01 4.63200881e+02 8.17410795e+00 6.99155689e+00
 1.09060481e+00 1.86252521e+01]


In [15]:
# N=8 samples per step, but 100 steps per estimator. Prints the mean of the
# per-step estimates: CDL first, then DR.
N = 8
cdl_est_values = train_cdl(N, training_steps=100)
dr_est_values = train_dr(N, training_steps=100)
print(
    np.array(cdl_est_values).mean(axis=0),
    np.array(dr_est_values).mean(axis=0),
    sep="\n",
)

[2.22504738e+01 2.67431462e+06 7.73034164e+02 3.96541044e+02
 1.43597268e+02 4.97873293e+04 2.88397809e+03 5.05354645e+05
 7.70204409e+02 7.59108282e+03]
[2.53562779e+04 6.14453626e+06 1.12875001e+05 2.62410940e+03
 1.53140977e+03 3.24888070e+04 4.25532312e+04 8.11987926e+03
 1.10298786e+08 2.56945881e+04]


# Joint training: DR and CDL estimators on shared batches

In [18]:
def train_esm(N = 64, training_steps = 10, noise = 0.1):
    """Train fresh DR_CMI and CDL_CMI estimators side by side on identical batches.

    Each step draws one batch of ``N`` samples from
    ``create_dataset_DGP_binary_A_conf``. Each model then in turn records its
    output on that batch *before* updating and takes one Adam step on its own
    ``learning_loss``: DR first, then CDL.

    Args:
        N: number of samples drawn per step.
        training_steps: number of sample / estimate / update iterations.
        noise: currently unused. It is only referenced by the commented-out
            ``create_dataset_DGP_binary_A_more_noise`` variant and is kept so
            existing calls keep working.

    Returns:
        ``(dr_est_values, cdl_est_values)``: per-step model outputs, in order.
    """
    torch.cuda.empty_cache()
    # Construction order (DR before CDL) is kept so parameter initialisation
    # consumes the RNG in the same sequence as before.
    model_dr = DR_CMI(sample_dim + 1, sample_dim, hidden_size).cuda()
    optimizer_dr = torch.optim.Adam(model_dr.parameters(), learning_rate)
    dr_est_values = []

    model_cdl = CDL_CMI(sample_dim + 1, sample_dim, hidden_size).cuda()
    optimizer_cdl = torch.optim.Adam(model_cdl.parameters(), learning_rate)
    cdl_est_values = []

    def _estimate_then_update(model, optimizer, batch_x, batch_y, est_values):
        # Log the pre-update estimate without autograd, then take one step.
        model.eval()
        with torch.no_grad():
            est_values.append(model(batch_x, batch_y))
        model.train()

        loss = model.learning_loss(batch_x, batch_y)
        optimizer.zero_grad()
        # Fresh graph every step, backpropagated once: no need to retain it.
        loss.backward()
        optimizer.step()

    for _ in range(training_steps):
        dataset = create_dataset_DGP_binary_A_conf(Dim=Dim, N=N)
        s_t = torch.from_numpy(dataset[0]).float().cuda()
        s_next = torch.from_numpy(dataset[1]).float().cuda()
        a = torch.from_numpy(dataset[2]).float().cuda()

        # Conditioning input: current state concatenated with the action.
        batch_x = torch.cat([s_t, a], dim=1)
        batch_y = s_next

        _estimate_then_update(model_dr, optimizer_dr, batch_x, batch_y, dr_est_values)
        _estimate_then_update(model_cdl, optimizer_cdl, batch_x, batch_y, cdl_est_values)
    return dr_est_values, cdl_est_values

In [20]:
# Sample-size sweep, 10 joint training steps per size. For each N, print the
# CDL mean, the DR mean, and a dashed separator labelled with N.
for N in (8, 16, 32, 64, 128):
    dr_est_values, cdl_est_values = train_esm(N, training_steps=10)
    print(
        np.array(cdl_est_values).mean(axis=0),
        np.array(dr_est_values).mean(axis=0),
        f"{'-' * 10}{N}{'-' * 10}",
        sep="\n",
    )

[1.25329346e+01 2.33828094e-01 9.39433105e-02 1.22737312e-01
 1.09861022e-01 5.70451570e+01 7.52661324e+03 3.59350602e+03
 5.57162971e+01 2.77755060e+02]
[4.27949015e+01 5.57002234e+00 4.18683417e+00 3.32199464e+00
 2.49062462e+00 1.09259360e+03 5.88827249e+05 2.83946486e+05
 4.90816225e+03 5.40377167e+05]
----------8----------
[0.25995747 0.11841028 0.39222155 0.26491672 9.66936338 0.57145281
 0.53439771 0.85490142 0.50766723 0.32138405]
[ 0.35467818  0.26877233  0.33225561  0.30409989 59.24332415  0.41454276
  2.74551751  0.50846437  0.50620058  0.19574837]
----------16----------
[8.96984219e-01 5.21944752e-01 7.22083839e-01 5.69464850e-01
 7.14565395e+00 3.36902386e+02 1.58821017e+01 4.31492417e+04
 5.12508720e+01 3.09743539e+01]
[3.95176272e+03 2.40484452e+03 3.16760386e+03 2.77049402e+03
 2.82605991e+03 3.93826046e+05 2.86683910e+05 1.53816013e+08
 4.43234782e+05 5.20110664e+05]
----------32----------
[6.04618083e+00 6.47362170e+00 6.47617442e+00 6.47034365e+00
 6.56119749e+00 1.1

In [21]:
# Sample-size sweep, 20 joint training steps per size. For each N, print the
# CDL mean, the DR mean, and a dashed separator labelled with N.
for N in (8, 16, 32, 64, 128):
    dr_est_values, cdl_est_values = train_esm(N, training_steps=20)
    print(
        np.array(cdl_est_values).mean(axis=0),
        np.array(dr_est_values).mean(axis=0),
        f"{'-' * 10}{N}{'-' * 10}",
        sep="\n",
    )

[8.66787768e+00 6.88185856e-01 8.97841344e+00 9.01879761e+00
 9.77338653e-02 3.59274518e+01 1.71304762e+04 5.99201608e+02
 4.42600560e+00 9.37441319e+01]
[3.03503771e+01 4.66112689e+00 3.00494620e+01 5.44344696e+01
 3.40477666e-01 2.41963234e+01 8.30842379e+03 3.19966823e+02
 4.97556353e+00 5.62905164e+01]
----------8----------
[1.23886029e+01 1.14096140e+01 1.41956168e+03 1.22064305e+01
 1.20991805e+01 4.91422698e+03 3.15020133e+05 1.02983261e+07
 1.36199123e+04 3.51717388e+03]
[3.04922873e+05 3.05191992e+05 3.05641042e+05 3.08461629e+05
 3.06205604e+05 1.09546020e+08 1.03562537e+11 1.64471294e+14
 9.28061814e+10 1.02188860e+08]
----------16----------
[7.27796719e+02 8.88229562e+01 6.14583699e+00 1.35158771e+00
 8.73641329e+00 1.11071908e+02 1.81336092e+04 4.06133250e+06
 2.75461790e+03 9.56061132e+04]
[3.62834924e+02 8.45911386e+02 1.19608005e+05 9.05620907e+02
 2.53453460e+02 2.03580710e+05 5.03678818e+07 1.59508256e+06
 1.31342091e+07 8.22641811e+07]
----------32----------
[8.49418