In [None]:
import javabridge
import os
import glob
import pandas as pd
import pydot
from IPython.display import SVG

In [None]:
# Put every jar shipped in src/pycausal/lib on the JVM classpath. This has to
# happen before the VM is started further down, since a JVM's classpath is
# fixed at launch.
# print(...) works identically under Python 2 and 3 (the old `print l`
# statement is a SyntaxError on Python 3), and the jar list is sorted so the
# classpath order does not depend on filesystem iteration order.
tetrad_libdir = os.path.join(os.getcwd(), '..', '..', 'src', 'pycausal', 'lib')
for jar_path in sorted(glob.glob(os.path.join(tetrad_libdir, "*.jar"))):
    print(jar_path)
    javabridge.JARS.append(str(jar_path))

In [None]:
# Launch a headless JVM with a 100M maximum heap (the classpath collected in
# the previous cell), then attach the current thread to it.
javabridge.start_vm(max_heap_size='100M', run_headless=True)
javabridge.attach()

In [None]:
# Location of the tab-separated charity dataset, two directories above the
# notebook's working directory (despite the name, this is a file path).
data_dir = os.path.join(os.getcwd(), os.pardir, os.pardir, 'data', 'charity.txt')
data_dir

In [None]:
# Read the tab-separated file into a DataFrame and preview the first rows.
df = pd.read_csv(data_dir, sep="\t")
df.head()

In [None]:
# Declare one Tetrad variable per DataFrame column, in column order.
# A column with more than `numCategoriesToDiscretize` distinct values becomes a
# ContinuousVariable. Otherwise it becomes a DiscreteVariable whose category
# labels are its sorted distinct values, converted to strings.
# NOTE: the later cell that fills the data box encodes discrete values by their
# position in the same sorted(set(...)) ordering, so the two must stay in sync.
numCategoriesToDiscretize = 4
node_list = javabridge.JClassWrapper('java.util.ArrayList')()
cont_list = []  # positional indices of continuous columns
disc_list = []  # positional indices of discrete columns
for col_no, col in enumerate(df.columns):
    cat_array = sorted(set(df[col]))
    if len(cat_array) > numCategoriesToDiscretize:
        # Continuous variable
        nodi = javabridge.JClassWrapper('edu.cmu.tetrad.data.ContinuousVariable')(col)
        cont_list.append(col_no)
    else:
        # Discrete variable: categories are the sorted distinct values as strings
        cat_list = javabridge.JClassWrapper('java.util.ArrayList')()
        for cat in cat_array:
            cat_list.add(str(cat))
        nodname = javabridge.JClassWrapper('java.lang.String')(col)
        nodi = javabridge.JClassWrapper('edu.cmu.tetrad.data.DiscreteVariable')(nodname, cat_list)
        disc_list.append(col_no)
    node_list.add(nodi)

In [None]:
# Tetrad MixedDataBox sized to hold every DataFrame row for the variables
# declared in node_list; it is filled cell by cell in the next step.
mixedDataBox = javabridge.JClassWrapper('edu.cmu.tetrad.data.MixedDataBox')(
    node_list, len(df))

In [None]:
# Copy the DataFrame into the MixedDataBox cell by cell:
#  - continuous columns are stored as java.lang.Double;
#  - discrete columns are stored as a java.lang.Integer code, namely the value's
#    position in the column's sorted distinct values. This is the same ordering
#    used for the DiscreteVariable category labels.
#
# Changes from the original loop:
#  - The value -> code lookup per discrete column is built once here instead of
#    re-running sorted(set(column)) for every row (quadratic in the row count).
#  - Values are read positionally with .iat. The old `.ix` indexer was removed
#    in pandas 1.0, and row positions (not index labels) are what the data box
#    expects.
#  - The Java class wrappers are looked up once, outside the loop.
java_double = javabridge.JClassWrapper('java.lang.Double')
java_integer = javabridge.JClassWrapper('java.lang.Integer')

disc_codes = {}
for col in disc_list:
    col_categories = sorted(set(df[df.columns[col]]))
    disc_codes[col] = {cat: code for code, cat in enumerate(col_categories)}

for row_no in range(len(df.index)):

    for col in cont_list:
        # float() matches the float64 value the original row-wise access produced
        value = java_double(float(df.iat[row_no, col]))
        mixedDataBox.set(row_no, col, value)

    for col in disc_list:
        value = java_integer(disc_codes[col][df.iat[row_no, col]])
        mixedDataBox.set(row_no, col, value)

In [None]:
# Wrap the filled data box and its variable list in a Tetrad BoxDataSet. This is
# the dataset passed to the independence test and score below.
tetradData = javabridge.JClassWrapper(
    'edu.cmu.tetrad.data.BoxDataSet'
)(mixedDataBox, node_list)

In [None]:
# Conditional-Gaussian likelihood-ratio independence test over the mixed data.
# alpha is the significance level. `discretize` is also reused by the score
# below; presumably it controls whether Tetrad discretizes continuous variables
# (left off here) — confirm against the Tetrad docs.
alpha, discretize = 0.05, False
indTest = javabridge.JClassWrapper(
    'edu.cmu.tetrad.search.IndTestConditionalGaussianLRT'
)(tetradData, alpha, discretize)

In [None]:
# Conditional-Gaussian score used by GFCI's scoring phase.
# Penalty discount rule of thumb: 2 when there are at most 50 variables,
# otherwise 4. It is derived from the data instead of being hard-coded, so the
# cell stays correct if a different dataset is loaded.
num_variables = len(df.columns)
penaltydiscount = 2 if num_variables <= 50 else 4
structurePrior = 1.0
score = javabridge.JClassWrapper('edu.cmu.tetrad.search.ConditionalGaussianScore')(tetradData, structurePrior, discretize)
score.setPenaltyDiscount(penaltydiscount)

In [None]:
# GFCI search object, built from the independence test and the score above.
gfci = javabridge.JClassWrapper(
    'edu.cmu.tetrad.search.GFci'
)(indTest, score)

In [None]:
# Configure the GFCI search before running it.
gfci.setMaxDegree(3)  # limit on node degree during the search — NOTE(review): confirm exact Tetrad semantics
gfci.setMaxPathLength(-1)  # presumably -1 means "no limit"; confirm against the Tetrad GFci docs
gfci.setCompleteRuleSetUsed(False)
gfci.setFaithfulnessAssumed(True)
gfci.setVerbose(True)  # presumably enables progress logging from the search

In [None]:
# Background knowledge handed to GFCI: one forbidden edge, one required edge,
# and three tiers (0: TangibilityCondition, Imaginability; 1: Sympathy,
# AmountDonated; 2: Impact), with edges within tier 0 forbidden.
# NOTE(review): under Tetrad's usual tier semantics a later tier may not cause
# an earlier one. The required edge Sympathy -> TangibilityCondition
# (tier 1 -> tier 0) therefore appears to contradict the tier ordering.
# Confirm the intended direction (TangibilityCondition -> Sympathy?) before
# relying on the search output.
prior = javabridge.JClassWrapper('edu.cmu.tetrad.data.Knowledge2')()
prior.setForbidden('TangibilityCondition', 'Impact') # forbidden directed edges
prior.setRequired('Sympathy','TangibilityCondition') # required directed edges
prior.setTierForbiddenWithin(0, True)
prior.addToTier(0, 'TangibilityCondition')
prior.addToTier(0, 'Imaginability')
prior.addToTier(1, 'Sympathy')
prior.addToTier(1, 'AmountDonated')
prior.addToTier(2, 'Impact')
gfci.setKnowledge(prior)
prior

In [None]:
# Run the GFCI search (with the knowledge set above) and display the resulting
# Tetrad graph object.
tetradGraph = gfci.search()
tetradGraph

In [None]:
# Tetrad's own text rendering of the learned graph.
tetradGraph.toString()

In [None]:
# Names of the nodes in the learned graph.
tetradGraph.getNodeNames()

In [None]:
# Edges of the learned graph.
tetradGraph.getEdges()

In [None]:
# Detach this thread from the JVM, then shut the JVM down.
# NOTE(review): javabridge generally cannot restart a JVM in the same process
# after kill_vm, so re-running the notebook likely needs a kernel restart — confirm.
javabridge.detach()
javabridge.kill_vm()