In [1]:
# %%
import os
# Select the GPU before torch or any project module is imported: CUDA only
# honours CUDA_VISIBLE_DEVICES if it is set before CUDA is initialised, and the
# project modules below might initialise it at import time.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import time
import numpy as np
import pickle
from numpy.linalg import det

import CMINE_lib as CMINE
# from Guassian_variables import Data_guassian
from structurerl import * 

import pandas as pd
from scipy.stats import multivariate_normal
import itertools

np.random.seed(37)
from scipy import stats
from sklearn.neighbors import KernelDensity

import math

import torch 
import torch.nn as nn
import torch.nn.functional as F

# Seed torch too (CPU and all CUDA devices) so model weight initialisation is
# reproducible across runs, matching the numpy seed above.
torch.manual_seed(37)


In [2]:
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of ``value.exp().sum(dim, keepdim).log()``.

    Delegates to ``torch.logsumexp``, which subtracts the running maximum
    internally. It also returns ``-inf`` (rather than NaN) for slices whose
    elements are all ``-inf``.

    Args:
        value: input tensor.
        dim: dimension to reduce over. If None, every element is reduced and a
            0-d tensor is returned; ``keepdim`` is ignored in that case.
        keepdim: keep the reduced dimension (only used when ``dim`` is given).

    Returns:
        Tensor holding the log-sum-exp of ``value``.
    """
    if dim is None:
        # torch.logsumexp needs an explicit dim, so flatten to reduce over all elements.
        return torch.logsumexp(value.reshape(-1), dim=0)
    return torch.logsumexp(value, dim=dim, keepdim=keepdim)

In [3]:
# Turn cuDNN on and let it benchmark candidate kernels for the input shapes it sees.
torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True

In [4]:
from modules.CMI import DR_CMI, CDL_CMI

In [5]:
Dim = 5
batch_size = 64
# One batch from the confounded generator, moved to the GPU for inspection.
# Use the Dim constant (not a literal) so this batch stays in sync if Dim changes.
dataset = create_dataset_DGP_binary_A_conf(Dim=Dim, N=batch_size)
s_t = torch.from_numpy(dataset[0]).float().cuda()
s_next = torch.from_numpy(dataset[1]).float().cuda()
a = torch.from_numpy(dataset[2]).float().cuda()

In [6]:
# Estimator hyper-parameters read by the DR_CMI / CDL_CMI training helpers.
sample_dim = Dim * 2
hidden_size, learning_rate, training_steps = 15, 0.005, 10
cubic = False


In [7]:
def train_dr(N = 64, training_steps = 10, noise = 0.1):
    """Train a fresh ``DR_CMI`` estimator, recording its output at every step.

    Each step draws a new batch of ``N`` samples from
    ``create_dataset_DGP_binary_A_more_noise``. The model input is
    ``cat([s_t, a], dim=1)`` and the target is ``s_next``. The model's output
    is recorded in eval mode *before* that step's Adam update on
    ``learning_loss``.

    Args:
        N: number of samples drawn per step.
        training_steps: number of optimisation steps (one fresh batch each).
        noise: noise level forwarded to the data generator.

    Returns:
        List with one ``model_dr(batch_x, batch_y)`` result per step.
    """
    torch.cuda.empty_cache()
    model_dr = DR_CMI(sample_dim + 1, sample_dim, hidden_size).cuda()
    optimizer_dr = torch.optim.Adam(model_dr.parameters(), learning_rate)
    dr_est_values = []
    for _ in range(training_steps):
        dataset = create_dataset_DGP_binary_A_more_noise(Dim=Dim, N=N, noise=noise)
        s_t = torch.from_numpy(dataset[0]).float().cuda()
        s_next = torch.from_numpy(dataset[1]).float().cuda()
        a = torch.from_numpy(dataset[2]).float().cuda()

        batch_x = torch.cat([s_t, a], dim=1)
        batch_y = s_next

        # Record the current estimate before updating on this batch.
        model_dr.eval()
        dr_est_values.append(model_dr(batch_x, batch_y))
        model_dr.train()

        model_loss = model_dr.learning_loss(batch_x, batch_y)
        optimizer_dr.zero_grad()
        # The graph is backpropagated exactly once, so retain_graph is not
        # needed; retaining it only kept autograd buffers alive for an extra step.
        model_loss.backward()
        optimizer_dr.step()
        # No per-step torch.cuda.empty_cache(): it forces re-allocation every
        # iteration and only slows training down.
    return dr_est_values

In [8]:
def train_cdl(N = 64, training_steps = 10, noise = 0.1):
    """Train a fresh ``CDL_CMI`` estimator, recording its output at every step.

    Each step draws a new batch of ``N`` samples from
    ``create_dataset_DGP_binary_A_more_noise``. The model input is
    ``cat([s_t, a], dim=1)`` and the target is ``s_next``. The model's output
    is recorded in eval mode *before* that step's Adam update on
    ``learning_loss``.

    Args:
        N: number of samples drawn per step.
        training_steps: number of optimisation steps (one fresh batch each).
        noise: noise level forwarded to the data generator.

    Returns:
        List with one ``model_cdl(batch_x, batch_y)`` result per step.
    """
    torch.cuda.empty_cache()
    model_cdl = CDL_CMI(sample_dim + 1, sample_dim, hidden_size).cuda()
    optimizer_cdl = torch.optim.Adam(model_cdl.parameters(), learning_rate)
    cdl_est_values = []
    for _ in range(training_steps):
        dataset = create_dataset_DGP_binary_A_more_noise(Dim=Dim, N=N, noise=noise)
        s_t = torch.from_numpy(dataset[0]).float().cuda()
        s_next = torch.from_numpy(dataset[1]).float().cuda()
        a = torch.from_numpy(dataset[2]).float().cuda()

        batch_x = torch.cat([s_t, a], dim=1)
        batch_y = s_next

        # Record the current estimate before updating on this batch.
        model_cdl.eval()
        cdl_est_values.append(model_cdl(batch_x, batch_y))
        model_cdl.train()

        model_loss = model_cdl.learning_loss(batch_x, batch_y)
        optimizer_cdl.zero_grad()
        # The graph is backpropagated exactly once, so retain_graph is not
        # needed; retaining it only kept autograd buffers alive for an extra step.
        model_loss.backward()
        optimizer_cdl.step()
        # No per-step torch.cuda.empty_cache(): it forces re-allocation every
        # iteration and only slows training down.
    return cdl_est_values

In [9]:
# Sample size 8, 100 training steps, default noise.
N = 8
cdl_est_values = train_cdl(N=N, training_steps=100)
dr_est_values = train_dr(N=N, training_steps=100)
print(np.asarray(cdl_est_values).mean(axis=0))
print(np.asarray(dr_est_values).mean(axis=0))

[1.42807053e+05 1.42753602e+05 1.42749906e+05 1.42753415e+05
 1.42761860e+05 2.95067178e+06 6.76185767e+09 8.98941433e+06
 3.77409148e+07 2.35669452e+06]
[9.17653270e+01 5.53447919e+01 2.48633819e+03 1.14063154e+02
 5.37941885e+04 8.25474727e+03 8.50891374e+03 3.55694774e+02
 5.44043101e+04 6.38314029e+04]


## Old Results 

In [18]:
# Sample size 16, 10 training steps, default noise.
N = 16
cdl_est_values = train_cdl(N=N, training_steps=10)
dr_est_values = train_dr(N=N, training_steps=10)
print(np.asarray(cdl_est_values).mean(axis=0))
print(np.asarray(dr_est_values).mean(axis=0))

[1.08953588e+06 1.08790531e+01 1.44941089e+03 9.76991869e+00
 1.07381283e+01 7.46936350e+01 6.51227847e+04 1.62549371e+03
 9.11367771e+07 3.97163777e+03]
[3.40228854e-01 7.89110609e-01 4.39443491e-01 5.32390746e-01
 4.89074229e-01 1.32232449e+02 8.08412067e+03 9.68853994e-01
 5.00934945e+00 8.82207803e+01]


In [19]:
# Sample size 32, 10 training steps, default noise.
N = 32
cdl_est_values = train_cdl(N=N, training_steps=10)
dr_est_values = train_dr(N=N, training_steps=10)
print(np.asarray(cdl_est_values).mean(axis=0))
print(np.asarray(dr_est_values).mean(axis=0))

[1.19239076e+00 9.48515549e+00 2.62783273e-01 2.91438824e-01
 2.84085817e-01 2.02316330e+02 9.52513240e+00 1.01251328e+01
 3.89528570e+01 2.11831938e+04]
[5.08776069e-01 4.90942820e+00 1.88697887e-01 4.82539841e-01
 5.17962856e-01 8.88308585e+02 8.50928133e+00 6.07973797e+01
 9.21396018e+03 8.99030582e+01]


In [20]:
# Sample size 64, 10 training steps, default noise.
N = 64
cdl_est_values = train_cdl(N=N, training_steps=10)
dr_est_values = train_dr(N=N, training_steps=10)
print(np.asarray(cdl_est_values).mean(axis=0))
print(np.asarray(dr_est_values).mean(axis=0))

[1.69835299e+04 1.86289765e+04 1.86692927e+04 1.86542401e+04
 1.86682338e+04 2.42494386e+05 3.14070918e+07 1.00317374e+13
 7.91771850e+05 3.49275749e+12]
[3.80414079e+00 3.62093483e+04 4.16812099e+01 1.23442339e+00
 2.22157874e+00 8.73493478e+04 5.98312236e+04 2.55250510e+03
 7.10490719e+02 4.58063107e+04]


In [21]:
# Sample size 128, 10 training steps, default noise.
N = 128
cdl_est_values = train_cdl(N=N, training_steps=10)
dr_est_values = train_dr(N=N, training_steps=10)
print(np.asarray(cdl_est_values).mean(axis=0))
print(np.asarray(dr_est_values).mean(axis=0))

[1.09370371e+02 2.63216174e+00 3.95621713e+00 1.38161645e+03
 1.05474120e+01 1.86754954e+03 1.67750882e+04 4.75895515e+04
 4.50832448e+05 5.32900840e+06]
[9.98662822e+00 1.46549784e+00 2.40550194e+03 1.42676535e+00
 1.73044784e+00 2.49111101e+07 4.76009027e+03 2.30951621e+02
 9.79896792e+02 2.15491152e+03]


# With added noise: new results

In [17]:
# Sweep the sample size; each estimator is trained separately for 50 steps at noise 0.5.
for N in (8, 16, 32, 64, 128):
    cdl_est_values = train_cdl(N=N, training_steps=50, noise=0.5)
    dr_est_values = train_dr(N=N, training_steps=50, noise=0.5)
    print(np.asarray(cdl_est_values).mean(axis=0))
    print(np.asarray(dr_est_values).mean(axis=0))
    print("--" * 5 + str(N) + "--" * 5)

In [11]:
def train_esm(N = 64, training_steps = 10, noise = 0.1):
    """Train fresh ``DR_CMI`` and ``CDL_CMI`` estimators side by side on shared batches.

    Each step draws one batch of ``N`` samples from
    ``create_dataset_DGP_binary_A_more_noise``. Then, for each model in turn
    (DR first, then CDL), it records the model's eval-mode output on that
    batch and takes one Adam step on ``learning_loss``. Both models therefore
    see exactly the same data.

    Args:
        N: number of samples drawn per step.
        training_steps: number of optimisation steps (one fresh batch each).
        noise: noise level forwarded to the data generator.

    Returns:
        ``(dr_est_values, cdl_est_values)``: per-step outputs of each model.
    """
    torch.cuda.empty_cache()
    model_dr = DR_CMI(sample_dim + 1, sample_dim, hidden_size).cuda()
    optimizer_dr = torch.optim.Adam(model_dr.parameters(), learning_rate)
    dr_est_values = []

    model_cdl = CDL_CMI(sample_dim + 1, sample_dim, hidden_size).cuda()
    optimizer_cdl = torch.optim.Adam(model_cdl.parameters(), learning_rate)
    cdl_est_values = []

    def _estimate_then_update(model, optimizer, est_values, batch_x, batch_y):
        # Record the eval-mode estimate, then take one optimizer step on the same batch.
        model.eval()
        est_values.append(model(batch_x, batch_y))
        model.train()
        loss = model.learning_loss(batch_x, batch_y)
        optimizer.zero_grad()
        # Each graph is backpropagated exactly once, so retain_graph is not needed.
        loss.backward()
        optimizer.step()

    for _ in range(training_steps):
        dataset = create_dataset_DGP_binary_A_more_noise(Dim=Dim, N=N, noise=noise)
        s_t = torch.from_numpy(dataset[0]).float().cuda()
        s_next = torch.from_numpy(dataset[1]).float().cuda()
        a = torch.from_numpy(dataset[2]).float().cuda()

        batch_x = torch.cat([s_t, a], dim=1)
        batch_y = s_next

        _estimate_then_update(model_dr, optimizer_dr, dr_est_values, batch_x, batch_y)
        _estimate_then_update(model_cdl, optimizer_cdl, cdl_est_values, batch_x, batch_y)
        # No per-step torch.cuda.empty_cache(): it only slows training down.
    return dr_est_values, cdl_est_values

In [12]:
# Joint training sweep: 50 steps per sample size at noise 0.5.
for N in (8, 16, 32, 64, 128):
    dr_est_values, cdl_est_values = train_esm(N=N, training_steps=50, noise=0.5)
    print(np.asarray(cdl_est_values).mean(axis=0))
    print(np.asarray(dr_est_values).mean(axis=0))
    print("--" * 5 + str(N) + "--" * 5)

[6.27522378e+01 1.99899927e+01 4.49873515e+06 2.00200868e+00
 1.68166133e+06 3.58446657e+04 9.29702518e+03 1.23249434e+01
 3.09786244e+03 2.33797748e+03]
[9.46756958e+01 1.47041189e+01 5.50061596e+06 9.66862688e-01
 1.58130117e+05 4.07256182e+03 8.50385204e+02 1.43844149e+01
 1.06400933e+02 7.59205550e+02]
----------8----------
[1.73272173e+02 1.72410264e+02 1.72702997e+02 1.01910669e+04
 1.73670759e+02 5.29206498e+03 2.93857890e+04 1.08014004e+10
 3.17839742e+05 2.62457337e+04]
[4.69256786e+02 4.70873291e+02 4.70567813e+02 3.93826687e+03
 4.72690386e+02 5.99947338e+02 3.98868638e+04 5.85785349e+09
 1.19711881e+05 3.85847352e+04]
----------16----------
[5.50198784e+04 1.65428897e+05 1.52550235e+04 4.08272492e+02
 9.48673955e+00 1.17554815e+03 4.05244490e+04 8.28555150e+05
 6.32163527e+03 4.48339228e+07]
[8.61598978e+03 1.67764485e+05 1.40673700e+04 2.39443851e+02
 1.07057798e+02 3.93649812e+04 3.78271892e+05 1.44752587e+06
 1.00650161e+05 7.26311190e+09]
----------32----------
[2.51350

In [13]:
# Joint training sweep: 10 steps per sample size at noise 0.1.
for N in (8, 16, 32, 64, 128):
    dr_est_values, cdl_est_values = train_esm(N=N, training_steps=10, noise=0.1)
    print(np.asarray(cdl_est_values).mean(axis=0))
    print(np.asarray(dr_est_values).mean(axis=0))
    print("--" * 5 + str(N) + "--" * 5)

[0.04866999 7.01217682 0.04766131 0.12564618 0.11494481 0.18814224
 8.59377559 0.07897958 0.5596874  0.29858135]
[ 0.2424953  80.34619155  0.15440329  0.45440101  0.90330312  0.33658343
 39.69718642  0.26733409  1.76217201  0.30534739]
----------8----------
[1.66565640e+01 1.97844515e+01 1.77469333e+01 1.66696469e+01
 1.72111965e+01 4.64933332e+04 6.90912953e+03 1.33696373e+08
 2.76609252e+03 7.31972743e+04]
[4.80097537e-01 6.74288021e+00 2.64210631e+00 7.78250272e-01
 4.31166643e-01 1.78720563e+04 1.23545206e+03 5.44934836e+07
 9.03627574e+02 3.10088114e+04]
----------16----------
[2.60282330e+01 2.54104346e+01 2.56296945e+01 2.62729466e+01
 2.68032038e+01 3.61918728e+10 8.83780702e+06 2.65420460e+03
 1.01977157e+05 2.72661188e+04]
[3.48236828e+02 3.48176392e+02 3.48637211e+02 3.48957086e+02
 3.48875133e+02 4.76743188e+10 1.16829665e+07 4.52072919e+03
 1.39457876e+05 3.87198933e+04]
----------32----------
[1.11735786e-01 1.18895508e-01 1.14877799e-01 1.88267124e+00
 1.44813668e-01 1.3

In [14]:
# Joint training sweep: 50 steps per sample size at noise 0.1.
for N in (8, 16, 32, 64, 128):
    dr_est_values, cdl_est_values = train_esm(N=N, training_steps=50, noise=0.1)
    print(np.asarray(cdl_est_values).mean(axis=0))
    print(np.asarray(dr_est_values).mean(axis=0))
    print("--" * 5 + str(N) + "--" * 5)

[5.58244097e+01 2.55891578e+01 2.03924307e+05 2.65426140e+00
 9.54317603e+00 6.27322246e+01 2.27773337e+02 8.67647185e+01
 9.61960018e+02 2.15560068e+01]
[1.05678488e+01 1.26729427e+01 2.94472421e+05 3.05807425e+00
 5.50986946e+00 2.58575006e+01 9.26393252e+01 1.08811140e+01
 2.34142155e+01 8.97571105e+00]
----------8----------
[2.51419614e+02 2.51159905e+02 2.54982574e+02 9.06241668e+03
 2.57967822e+02 4.22905741e+03 1.28840933e+11 1.96511170e+08
 9.69982748e+04 8.26377681e+04]
[7.88506074e+03 7.89027345e+03 7.95177220e+03 1.03649639e+04
 7.91944751e+03 4.65513900e+03 6.13914088e+10 1.68715285e+08
 8.86504383e+04 1.40712098e+05]
----------16----------
[1.62616104e+02 6.72364695e+02 1.74696710e+03 8.68771990e+01
 8.76604958e+01 1.13363569e+04 1.35902816e+05 1.33471457e+08
 1.87849775e+04 2.75284862e+08]
[2.16974102e+02 2.64064622e+03 3.16799670e+03 4.29129479e+01
 4.98180613e+01 1.27525146e+04 1.54038188e+05 1.99937280e+08
 1.30997407e+04 2.22880642e+08]
----------32----------
[2.27096

# Effect of the number of training steps

In [15]:
# Sample size 64, 100 training steps, default noise.
N, training_step = 64, 100
cdl_est_values = train_cdl(N=N, training_steps=training_step)
dr_est_values = train_dr(N=N, training_steps=training_step)
print(np.asarray(cdl_est_values).mean(axis=0))
print(np.asarray(dr_est_values).mean(axis=0))

[2.01726316e+03 2.04238323e+05 6.07205578e+01 2.41863759e+06
 5.87945271e+01 3.91935844e+02 1.57143758e+05 3.43342937e+04
 4.07721823e+05 5.57646968e+02]
[7.39971510e+06 7.39403409e+06 7.39384683e+06 7.39388170e+06
 7.39375034e+06 9.30547196e+08 1.76988732e+10 5.26174280e+17
 6.79014097e+08 8.33487436e+09]


In [16]:
# Sample size 32, 100 training steps, default noise.
N, training_step = 32, 100
cdl_est_values = train_cdl(N=N, training_steps=training_step)
dr_est_values = train_dr(N=N, training_steps=training_step)
print(np.asarray(cdl_est_values).mean(axis=0))
print(np.asarray(dr_est_values).mean(axis=0))

[2.53453184e+05 1.74253382e+07 2.54031519e+05 2.53812429e+05
 2.53168319e+05 3.46062771e+06 2.51382961e+08 1.54563827e+12
 1.84511028e+06 1.65020993e+06]
[1.11571460e+02 6.64161811e+02 1.32397845e+01 1.88535050e+01
 4.88466334e+03 2.44870363e+05 3.17462745e+04 1.03511919e+07
 1.45789549e+03 1.74743499e+05]
