In [29]:

import cdt
import os
import pandas as pd
import numpy as np
from cdt.data import load_dataset
import networkx as nx
from cdt.causality.graph import SAM
import matplotlib.pyplot as plt
import seaborn as sns

In [40]:
# Load the wave-2 "diff styles" dataset and z-score every column.
data_fn = '211205wave2diff_styles.csv'

# Keep the unscaled frame under its own name so it stays available after
# standardization (the original cell overwrote it in place).
df_raw = pd.read_csv(data_fn)

cols = df_raw.columns

# Column-wise z-score: (x - mean) / std.
# A column whose std is exactly 0 would turn into NaN (0 / 0); dividing by 1
# instead leaves such a column at 0.0 after centring.
col_std = df_raw.std().replace(0, 1)
df = (df_raw - df_raw.mean()) / col_std

In [41]:
# SAM hyperparameter notes (pasted parameter descriptions; keyword spelling
# aligned with the `dagpenalization*` arguments actually passed below):
#   lr (float) – Learning rate of the generators
#   dlr (float) – Learning rate of the discriminator
#   mixed_data (bool) – Experimental – Enable for mixed-type datasets
#   lambda1 (float) – L0 penalization coefficient on the causal filters
#   lambda2 (float) – L2 penalization coefficient on the weights of the neural network
#   nh (int) – Number of hidden units in the generators’ hidden layers (regularized with lambda2)
#   dnh (int) – Number of hidden units in the discriminator’s hidden layers
#   train_epochs (int) – Number of training epochs
#   test_epochs (int) – Number of test epochs (saving and averaging the causal filters)
#   batch_size (int) – Size of the batches to be fed to the SAM model Defaults to full-batch
#   losstype (str) – type of the loss to be used (either ‘fgan’ (default), ‘gan’ or ‘mse’)
#   dagloss (bool) – Activate the DAG with No-TEARS constraint
#   dagstart (float) – Controls when the DAG constraint is to be introduced in the training
#       (float ranging from 0 to 1, 0 denotes the start of the training and 1 the end)
#   dagpenalization (float) – Initial value of the DAG constraint
#   dagpenalization_increase (float) – Increase incrementally at each epoch the coefficient of the constraint
#   functional_complexity (str) – Type of functional complexity penalization
#       (choose between ‘l2_norm’ and ‘n_hidden_units’)
#   hlayers (int) – Defines the number of hidden layers in the generators
#   dhlayers (int) – Defines the number of hidden layers in the discriminator
#   sampling_type (str) – Type of sampling used in the structural gates of the model
#       (choose between ‘sigmoid’, ‘sigmoid_proba’ and ‘gumble_proba’)
#   linear (bool) – If true, all generators are set to be linear generators
#   nruns (int) – Number of runs to be made for causal estimation Recommended: >=32 for optimal performance
#   njobs (int) – Numbers of jobs to be run in Parallel Recommended: 1 if no GPU available, 2*number of GPUs else
#   gpus (int) – Number of available GPUs for the algorithm
#   verbose (bool) – verbose mode
#
# NOTE(review): the notes above spell the sampler 'sigmoid_proba', but the call
# below passes 'sigmoidproba'. The value is left unchanged; confirm which
# spelling the installed cdt version accepts.

# Hyperparameter grid: 2 learning rates x 2 dagloss settings x 3 penalty increments.
lrs = [0.01, 0.001]
daglosses = [True, False]
dagpenalizations = [0.01, 0.05, 0.1]
num_runs = 50

# nx.write_gml does not create missing directories; make sure the output folder
# exists up front so the first (long) SAM run is not lost to a missing path.
os.makedirs("graphs", exist_ok=True)

for lr in lrs:
    for dagloss in daglosses:
        for dagpenalization in dagpenalizations:

            # Tag identifying this configuration; becomes part of the file name.
            settings = 'lr_{}_dagloss_{}_dagpen_{}_numruns_{}'.format(
                lr, dagloss, dagpenalization, num_runs)

            # The swept value is passed as `dagpenalization_increase`; the
            # initial penalty `dagpenalization` is fixed at 0.
            # NOTE(review): with dagloss=False the penalty increment presumably
            # has no effect, so those three configurations may only differ by
            # random initialisation — confirm against the cdt implementation.
            obj = SAM(
                lr=lr, dlr=0.001, mixed_data=True,
                lambda1=10, lambda2=0.001,
                nh=20, dnh=200,
                train_epochs=3000, test_epochs=800, batch_size=-1,
                losstype='fgan',
                dagloss=dagloss, dagstart=0.5,
                dagpenalization=0, dagpenalization_increase=dagpenalization,
                hlayers=2, dhlayers=2,
                sampling_type='sigmoidproba',
                linear=False, nruns=num_runs, njobs=1, gpus=1, verbose=None,
            )

            # `df` is the standardized wave-2 frame prepared in the loading cell.
            output = obj.predict(df)

            nx.write_gml(output, "graphs/wave_2diffgraph_{}.gml".format(settings))

            
       
    

100%|██████████| 3800/3800 [01:33<00:00, 37.99it/s, disc=2.02e+14, gen=-2.08e+13, regul_loss=0.868, tot=-4.03e+15]
100%|██████████| 3800/3800 [01:36<00:00, 37.91it/s, disc=4.09e+13, gen=-4.21e+12, regul_loss=0.734, tot=-8.16e+14]
100%|██████████| 3800/3800 [01:36<00:00, 37.95it/s, disc=2.48e+10, gen=-2.55e+9, regul_loss=1.25, tot=-4.95e+11] 
100%|██████████| 3800/3800 [01:36<00:00, 39.49it/s, disc=7.77e+12, gen=-8.01e+11, regul_loss=0.734, tot=-1.55e+14]
100%|██████████| 3800/3800 [01:36<00:00, 37.63it/s, disc=8.54, gen=-.122, regul_loss=0.668, tot=-20.5]     
100%|██████████| 3800/3800 [01:33<00:00, 40.67it/s, disc=14.2, gen=-.668, regul_loss=0.66, tot=-128]       
100%|██████████| 3800/3800 [01:31<00:00, 41.36it/s, disc=2.87e+6, gen=-2.96e+5, regul_loss=0.653, tot=-5.74e+7] 
100%|██████████| 3800/3800 [01:31<00:00, 41.56it/s, disc=0.193, gen=-.104, regul_loss=0.549, tot=-19.5] 
100%|██████████| 3800/3800 [01:33<00:00, 40.67it/s, disc=498, gen=-50.9, regul_loss=0.89, tot=-8e+3]       

100%|██████████| 3800/3800 [01:30<00:00, 41.97it/s, disc=116, gen=-11.2, regul_loss=0.593, tot=-2.17e+3]          
100%|██████████| 3800/3800 [01:30<00:00, 40.68it/s, disc=9.76e+25, gen=-1.01e+25, regul_loss=0.89, tot=-1.95e+27] 
100%|██████████| 3800/3800 [01:30<00:00, 40.73it/s, disc=-.909, gen=-.106, regul_loss=0.542, tot=-19.9]  
100%|██████████| 3800/3800 [01:30<00:00, 42.05it/s, disc=1.5e+23, gen=-1.55e+22, regul_loss=0.66, tot=-3e+24]     
100%|██████████| 3800/3800 [01:30<00:00, 42.01it/s, disc=1.94e+24, gen=-2e+23, regul_loss=1.25, tot=-3.87e+25]    
100%|██████████| 3800/3800 [01:30<00:00, 42.05it/s, disc=-.0719, gen=-.104, regul_loss=0.386, tot=-19.7]  
100%|██████████| 3800/3800 [01:30<00:00, 40.68it/s, disc=4.38e+7, gen=-4.51e+6, regul_loss=1.16, tot=-8.74e+8] 
100%|██████████| 3800/3800 [01:30<00:00, 42.19it/s, disc=6.68e+8, gen=-6.89e+7, regul_loss=0.571, tot=-1.34e+10]  
100%|██████████| 3800/3800 [01:30<00:00, 42.19it/s, disc=9.17e+21, gen=-9.45e+20, regul_loss=1.16, t

100%|██████████| 3800/3800 [01:30<00:00, 42.01it/s, disc=2.46e+14, gen=-2.53e+13, regul_loss=0.571, tot=-4.91e+15]
100%|██████████| 3800/3800 [01:30<00:00, 40.84it/s, disc=-.658, gen=-.103, regul_loss=0.467, tot=-19.4] 
100%|██████████| 3800/3800 [01:30<00:00, 42.12it/s, disc=-.598, gen=-.107, regul_loss=0.43, tot=-20.1]  
100%|██████████| 3800/3800 [01:30<00:00, 42.04it/s, disc=-.601, gen=-.11, regul_loss=0.527, tot=-20.7]   
100%|██████████| 3800/3800 [01:30<00:00, 42.15it/s, disc=0.0757, gen=-.103, regul_loss=0.408, tot=-19.5] 
100%|██████████| 3800/3800 [01:30<00:00, 42.06it/s, disc=5.61e+3, gen=-578, regul_loss=0.89, tot=-1e+5]        
100%|██████████| 3800/3800 [01:22<00:00, 46.15it/s, disc=6e+9, gen=-6.18e+8, regul_loss=0.964, tot=-1.2e+11]     
100%|██████████| 3800/3800 [01:22<00:00, 46.08it/s, disc=6.54e+5, gen=-6.74e+4, regul_loss=0.883, tot=-1.31e+7] 
100%|██████████| 3800/3800 [01:22<00:00, 46.10it/s, disc=1.52, gen=-.0341, regul_loss=0.69, tot=-5.93] 
100%|██████████| 380

100%|██████████| 3800/3800 [01:22<00:00, 46.10it/s, disc=2.04e+9, gen=-2.1e+8, regul_loss=0.972, tot=-4.07e+10]   
100%|██████████| 3800/3800 [01:22<00:00, 46.15it/s, disc=3.5e+6, gen=-3.6e+5, regul_loss=1.04, tot=-6.98e+7]    
100%|██████████| 3800/3800 [01:22<00:00, 46.11it/s, disc=1.97e+4, gen=-2.03e+3, regul_loss=0.89, tot=-3.93e+5]  
100%|██████████| 3800/3800 [01:22<00:00, 46.12it/s, disc=1.44e+5, gen=-1.48e+4, regul_loss=1.04, tot=-2.88e+6] 
100%|██████████| 3800/3800 [01:22<00:00, 46.12it/s, disc=4.75e+25, gen=-4.9e+24, regul_loss=1.11, tot=-9.51e+26]  
100%|██████████| 3800/3800 [01:22<00:00, 46.03it/s, disc=4.86e+12, gen=-5.01e+11, regul_loss=0.853, tot=-9.72e+13]
100%|██████████| 3800/3800 [01:22<00:00, 46.05it/s, disc=8.17, gen=-.0579, regul_loss=0.964, tot=-10.3]   
100%|██████████| 3800/3800 [01:22<00:00, 46.08it/s, disc=1.17e+15, gen=-1.2e+14, regul_loss=0.92, tot=-2.33e+16]  
100%|██████████| 3800/3800 [01:22<00:00, 46.07it/s, disc=0.52, gen=-.0989, regul_loss=0.794, to

100%|██████████| 3800/3800 [01:22<00:00, 46.04it/s, disc=3.19e+8, gen=-3.28e+7, regul_loss=1.14, tot=-6.37e+9]   
100%|██████████| 3800/3800 [01:22<00:00, 46.13it/s, disc=9.12e+15, gen=-9.36e+14, regul_loss=0.934, tot=-1.82e+17]
100%|██████████| 3800/3800 [01:22<00:00, 46.18it/s, disc=3.47e+5, gen=-3.58e+4, regul_loss=0.949, tot=-6.95e+6] 
100%|██████████| 3800/3800 [01:22<00:00, 46.19it/s, disc=2.95e+14, gen=-3.03e+13, regul_loss=1.08, tot=-5.88e+15] 
100%|██████████| 3800/3800 [01:22<00:00, 46.24it/s, disc=2.5e+6, gen=-2.58e+5, regul_loss=0.979, tot=-5.01e+7] 
100%|██████████| 3800/3800 [01:22<00:00, 46.18it/s, disc=-.23, gen=-.103, regul_loss=0.712, tot=-19.2]  
100%|██████████| 3800/3800 [01:22<00:00, 46.17it/s, disc=4.14e+18, gen=-4.27e+17, regul_loss=0.912, tot=-8.28e+19]
100%|██████████| 3800/3800 [01:22<00:00, 46.17it/s, disc=1.03e+7, gen=-1.06e+6, regul_loss=0.949, tot=-2.05e+8] 
100%|██████████| 3800/3800 [01:22<00:00, 46.21it/s, disc=1.74e+4, gen=-1.79e+3, regul_loss=1.05, t

100%|██████████| 3800/3800 [01:30<00:00, 40.53it/s, disc=2.24e+4, gen=-2.3e+3, regul_loss=1.73, tot=1.11e+6]  
100%|██████████| 3800/3800 [01:30<00:00, 41.86it/s, disc=0.2, gen=-.104, regul_loss=1.62, tot=5.77e+5]   
100%|██████████| 3800/3800 [01:30<00:00, 42.00it/s, disc=1.74e+13, gen=-1.54e+12, regul_loss=1.76, tot=-2.99e+14]
100%|██████████| 3800/3800 [01:30<00:00, 41.97it/s, disc=0.0176, gen=-.103, regul_loss=1.71, tot=9.24e+5]
100%|██████████| 3800/3800 [01:30<00:00, 40.75it/s, disc=-.734, gen=-.108, regul_loss=1.65, tot=1.74e+5]  
100%|██████████| 3800/3800 [01:30<00:00, 41.99it/s, disc=2.06e+5, gen=-2.12e+4, regul_loss=1.74, tot=-1.49e+6]
100%|██████████| 3800/3800 [01:30<00:00, 42.00it/s, disc=35.2, gen=-3.2, regul_loss=1.76, tot=5.91e+5]   
100%|██████████| 3800/3800 [01:30<00:00, 41.93it/s, disc=-.215, gen=-.102, regul_loss=1.67, tot=8.8e+5]   
100%|██████████| 3800/3800 [01:30<00:00, 41.99it/s, disc=-.126, gen=-.102, regul_loss=1.72, tot=7.6e+5]   
100%|██████████| 3800/380

100%|██████████| 3800/3800 [01:30<00:00, 42.06it/s, disc=266, gen=-26.4, regul_loss=1.58, tot=8.97e+5]       
100%|██████████| 3800/3800 [01:30<00:00, 41.98it/s, disc=2.74e+5, gen=-2.82e+4, regul_loss=1.84, tot=4.73e+5] 
100%|██████████| 3800/3800 [01:30<00:00, 42.06it/s, disc=1.79e+7, gen=-1.84e+6, regul_loss=1.88, tot=-3.45e+8]
100%|██████████| 3800/3800 [01:30<00:00, 41.95it/s, disc=1.02e+3, gen=-105, regul_loss=1.7, tot=1.39e+6]     
100%|██████████| 3800/3800 [01:30<00:00, 42.12it/s, disc=-.779, gen=-.101, regul_loss=1.68, tot=1.33e+6] 
100%|██████████| 3800/3800 [01:30<00:00, 41.98it/s, disc=4.18e+6, gen=-4.3e+5, regul_loss=1.86, tot=-6.21e+7] 
100%|██████████| 3800/3800 [01:30<00:00, 42.09it/s, disc=49.5, gen=-4.65, regul_loss=1.61, tot=1.39e+6]      
100%|██████████| 3800/3800 [01:30<00:00, 41.94it/s, disc=9.76e+4, gen=-1e+4, regul_loss=1.76, tot=8.65e+5]    
100%|██████████| 3800/3800 [01:30<00:00, 42.04it/s, disc=17.9, gen=-1.04, regul_loss=1.7, tot=1.59e+6]    
100%|████████

100%|██████████| 3800/3800 [01:22<00:00, 46.25it/s, disc=1.58e+8, gen=-1.63e+7, regul_loss=2.48, tot=-3.16e+9]   
100%|██████████| 3800/3800 [01:21<00:00, 46.35it/s, disc=-.33, gen=-.0986, regul_loss=2, tot=-17.1]     
100%|██████████| 3800/3800 [01:22<00:00, 46.24it/s, disc=101, gen=-9.93, regul_loss=2.53, tot=-1.92e+3]       
100%|██████████| 3800/3800 [01:22<00:00, 46.25it/s, disc=-.288, gen=-.102, regul_loss=2.17, tot=-17.5]  
100%|██████████| 3800/3800 [01:22<00:00, 46.32it/s, disc=600, gen=-60.9, regul_loss=2.22, tot=-1.18e+4]       
100%|██████████| 3800/3800 [01:22<00:00, 46.27it/s, disc=9.14e+4, gen=-9.36e+3, regul_loss=2.25, tot=-1.82e+6]
100%|██████████| 3800/3800 [01:22<00:00, 46.31it/s, disc=-.059, gen=-.103, regul_loss=2.11, tot=-17.9]  
100%|██████████| 3800/3800 [01:22<00:00, 46.31it/s, disc=-.267, gen=-.104, regul_loss=2.05, tot=-18.1] 
100%|██████████| 3800/3800 [01:22<00:00, 46.21it/s, disc=-.166, gen=-.104, regul_loss=2.22, tot=-17.9]  
100%|██████████| 3800/3800 [0

100%|██████████| 3800/3800 [01:22<00:00, 46.33it/s, disc=-.319, gen=-.101, regul_loss=2.22, tot=-17.4]  
100%|██████████| 3800/3800 [01:21<00:00, 46.36it/s, disc=1.06, gen=-.0435, regul_loss=2.34, tot=-6.1]   
100%|██████████| 3800/3800 [01:21<00:00, 46.39it/s, disc=6.96e+7, gen=-7.1e+6, regul_loss=2.17, tot=-1.38e+9]   
100%|██████████| 3800/3800 [01:21<00:00, 46.41it/s, disc=1.35e+8, gen=-1.39e+7, regul_loss=2.27, tot=-2.7e+9]   
100%|██████████| 3800/3800 [01:21<00:00, 46.46it/s, disc=-.399, gen=-.103, regul_loss=2.13, tot=-17.8]  
100%|██████████| 3800/3800 [01:21<00:00, 46.37it/s, disc=-.0635, gen=-.104, regul_loss=2.11, tot=-18]   
100%|██████████| 3800/3800 [01:21<00:00, 46.42it/s, disc=799, gen=-81.8, regul_loss=2.51, tot=-1.59e+4]       
100%|██████████| 3800/3800 [01:21<00:00, 46.39it/s, disc=-.342, gen=-.106, regul_loss=2.17, tot=-18.3]  
100%|██████████| 3800/3800 [01:21<00:00, 46.44it/s, disc=-.091, gen=-.0999, regul_loss=2.07, tot=-17.3]
100%|██████████| 3800/3800 [01:21<

In [18]:
# Switch to the wave-6 dataset. This cell only reads the CSV and records its
# column index; `df` is not standardized here.
data_fn = '211205wave6e.csv'
df = pd.read_csv(filepath_or_buffer=data_fn)
cols = df.columns

In [None]:
# Column-wise z-score of the wave-6 frame: (x - mean) / std.
# A column whose std is exactly 0 would turn into NaN (0 / 0); dividing by 1
# instead leaves such a column at 0.0 after centring.
df = (df - df.mean()) / df.std().replace(0, 1)


# SAM hyperparameter notes (pasted parameter descriptions; keyword spelling
# aligned with the `dagpenalization*` arguments actually passed below):
#   lr (float) – Learning rate of the generators
#   dlr (float) – Learning rate of the discriminator
#   mixed_data (bool) – Experimental – Enable for mixed-type datasets
#   lambda1 (float) – L0 penalization coefficient on the causal filters
#   lambda2 (float) – L2 penalization coefficient on the weights of the neural network
#   nh (int) – Number of hidden units in the generators’ hidden layers (regularized with lambda2)
#   dnh (int) – Number of hidden units in the discriminator’s hidden layers
#   train_epochs (int) – Number of training epochs
#   test_epochs (int) – Number of test epochs (saving and averaging the causal filters)
#   batch_size (int) – Size of the batches to be fed to the SAM model Defaults to full-batch
#   losstype (str) – type of the loss to be used (either ‘fgan’ (default), ‘gan’ or ‘mse’)
#   dagloss (bool) – Activate the DAG with No-TEARS constraint
#   dagstart (float) – Controls when the DAG constraint is to be introduced in the training
#       (float ranging from 0 to 1, 0 denotes the start of the training and 1 the end)
#   dagpenalization (float) – Initial value of the DAG constraint
#   dagpenalization_increase (float) – Increase incrementally at each epoch the coefficient of the constraint
#   functional_complexity (str) – Type of functional complexity penalization
#       (choose between ‘l2_norm’ and ‘n_hidden_units’)
#   hlayers (int) – Defines the number of hidden layers in the generators
#   dhlayers (int) – Defines the number of hidden layers in the discriminator
#   sampling_type (str) – Type of sampling used in the structural gates of the model
#       (choose between ‘sigmoid’, ‘sigmoid_proba’ and ‘gumble_proba’)
#   linear (bool) – If true, all generators are set to be linear generators
#   nruns (int) – Number of runs to be made for causal estimation Recommended: >=32 for optimal performance
#   njobs (int) – Numbers of jobs to be run in Parallel Recommended: 1 if no GPU available, 2*number of GPUs else
#   gpus (int) – Number of available GPUs for the algorithm
#   verbose (bool) – verbose mode
#
# NOTE(review): the notes above spell the sampler 'sigmoid_proba', but the call
# below passes 'sigmoidproba'. The value is left unchanged; confirm which
# spelling the installed cdt version accepts.

# Hyperparameter grid: 2 learning rates x 2 dagloss settings x 3 penalty increments.
lrs = [0.01, 0.001]
daglosses = [True, False]
dagpenalizations = [0.01, 0.05, 0.1]
num_runs = 50

# nx.write_gml does not create missing directories; make sure the output folder
# exists up front so the first (long) SAM run is not lost to a missing path.
os.makedirs("graphs", exist_ok=True)

for lr in lrs:
    for dagloss in daglosses:
        for dagpenalization in dagpenalizations:

            # Tag identifying this configuration; becomes part of the file name.
            settings = 'lr_{}_dagloss_{}_dagpen_{}_numruns_{}'.format(
                lr, dagloss, dagpenalization, num_runs)

            # The swept value is passed as `dagpenalization_increase`; the
            # initial penalty `dagpenalization` is fixed at 0.
            # NOTE(review): with dagloss=False the penalty increment presumably
            # has no effect, so those three configurations may only differ by
            # random initialisation — confirm against the cdt implementation.
            obj = SAM(
                lr=lr, dlr=0.001, mixed_data=True,
                lambda1=10, lambda2=0.001,
                nh=20, dnh=200,
                train_epochs=3000, test_epochs=800, batch_size=-1,
                losstype='fgan',
                dagloss=dagloss, dagstart=0.5,
                dagpenalization=0, dagpenalization_increase=dagpenalization,
                hlayers=2, dhlayers=2,
                sampling_type='sigmoidproba',
                linear=False, nruns=num_runs, njobs=1, gpus=1, verbose=None,
            )

            output = obj.predict(df)

            nx.write_gml(output, "graphs/wave_6_graph_{}.gml".format(settings))

            
       
    