In [1]:
from Shapley import readfile, getSpeciesName, getInputNames, getFormula, BNmodel
from Shapley import parseFormula, workWithSHAP, genSpecies, genAllInputForSHAP
from booleanFormulaParser import convertBooleanFormulas2Network 
import shap
import numpy as np

debug = True
isko = False
isbi = False
isacyclic = False 
exp = '../data/cd4.txt' # path to the expression file 

outputnames = set()
outputnames.add('Tbet')

print("The expression file is {}".format(exp))
print("The interested output is: ")
print(outputnames)
if isko:
    print("Perform normal procedure for input and knockout procedure for intermediate nodes")
else:
    print("Perform only normal procedure")
# now read the expression file and take the string formulas from this 
lines = readfile(exp, debug)

speciesnames = getSpeciesName(lines, debug) # get species names 
print("List of all species:")
print(speciesnames)

if not outputnames.issubset(speciesnames):
    print("The list of interested output is not a subset of species")

inputnames = getInputNames(lines, speciesnames, debug)
print("-----Input names-----")
print(inputnames)

internames = speciesnames.difference(inputnames).difference(outputnames)
print("----Intermediate nodes-----")
print(internames) 


print("----Getting boolean formulas-----")
strformulas = getFormula(lines, debug) # list of formulas in strings 
# formulas = [] # list of formulas in nodes 
formulas = dict() 
for strformula in strformulas:
    root = parseFormula(strformula, debug)
    if debug:
        print("Parsing formula for {}".format(strformula['left']))
        root.display()
        print("\n")
    # root.display() 
    # thisfor = {formula['left']: root}
    # formulas.append(thisfor)
    formulas[strformula['left']] = root 

# get the network of the original model
# orinet = convertBooleanFormulas2Network(formulas, inputnames, speciesnames, "network", debug)


# the function returns the simulation output 
# orioutputs, oridecimalpairs, oriinputshapss = workWithOriginalNetwork(orinet, inputnames, speciesnames, outputnames, internames, formulas, isko, debug)

for outputname in outputnames:
    # workWithSHAP(list(inputnames), speciesnames, outputname, formulas, debug)
    print("----------Work with SHAP-----------")
    species = genSpecies(speciesnames, debug)

    # inputs = genRandomInputForSHAP(len(inputnames), 32)
    inputs = genAllInputForSHAP(len(inputnames))
    # print("--------Input for SHAP------")
    # print(inputs)

    model = BNmodel(list(inputnames), species, formulas, outputname)
    # print("---------OUTPUT-----------")
    # print(model.predict(inputs))

    explainer = shap.Explainer(model.predict, inputs)

    shap_values = explainer(inputs)

    print(inputnames)
    print(shap_values.base_values)
    print(np.round(shap_values.values,5))
    # print(np.round(shap_values.values,4))
    
    print("----------End working with SHAP-----------")


The expression file is ../data/cd4.txt
The interested output is: 
{'Tbet'}
Perform only normal procedure
SOCS1 = ( Tbet )  OR ( STAT1 )
IL21R = ( IL21 )
STAT6 = ( ( ( IL4R  ) AND ( NOT ( SOCS1  ) )  ) AND ( NOT ( IFNg  ) ) )
IL6 = ( RORgt )
IFNgR = ( IFNg_e AND ( ( ( NFAT ) ) )    )  OR ( IFNg AND ( ( ( NFAT ) ) )    )
IL12R = ( ( STAT4  ) AND ( NOT ( GATA3  ) ) )  OR ( IL12 AND ( ( ( NFAT ) ) )    )  OR ( Tbet )  OR ( ( TCR  ) AND ( NOT ( GATA3  ) ) )
IL2 = ( ( NFAT AND ( ( ( NFkB ) ) )     ) AND ( NOT ( Tbet  ) ) )
Jak1 = ( ( IFNgR  ) AND ( NOT ( SOCS1  ) ) )
IL21 = ( STAT3 AND ( ( ( NFAT ) ) )    )
IL17 = ( ( ( STAT3 AND ( ( ( IL23R ) )  AND ( ( IL17 ) ) )     ) AND ( NOT ( STAT1  ) )  ) AND ( NOT ( STAT5  ) ) )  OR ( ( RORgt  ) AND ( NOT ( STAT1  ) ) )
STAT5 = ( IL2R )
GATA3 = ( ( ( ( ( STAT6 AND ( ( ( NFAT ) ) )     ) AND ( NOT ( Tbet  ) )  ) AND ( NOT ( TGFb  ) )  ) AND ( NOT ( Foxp3  ) )  ) AND ( NOT ( RORgt  ) ) )  OR ( ( GATA3  ) AND ( NOT ( Tbet  ) ) )  OR ( ( ( ( ( STAT5  ) 

ExactExplainer explainer: 513it [26:21:30, 185.33s/it]                                                              

{'IL23', 'TCR', 'IL4_e', 'IL6_e', 'IL27', 'IFNg_e', 'IL18', 'TGFb', 'IL12'}
[[0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.44]
 [0.




In [29]:
# get the average of SHAP values over all samples (512) 
ave = np.mean(np.abs(shap_values.values), axis=0)
print("Average abs: {}".format(np.round(ave,5)))

Average abs: [0.      0.42987 0.      0.01224 0.07479 0.04545 0.      0.01657 0.07299]


np.set_printoptions(threshold=np.inf)

In [25]:
for i in range(shap_values.shape[0]):
    print(shap_values.data[i])
    print(shap_values.base_values[i])
    print(np.round(np.abs(shap_values.values[i]),4))
    print()

[0 0 0 0 0 0 0 0 0]
[0.44]
[ 0.     -0.2805  0.      0.0045 -0.0607 -0.0505 -0.      0.0045 -0.0573]

[0 0 0 0 0 0 0 0 1]
[0.44]
[-0.     -0.4607 -0.      0.0027 -0.0233 -0.0207 -0.      0.0027  0.0593]

[0 0 0 0 0 0 0 1 0]
[0.44]
[ 0.     -0.2747  0.      0.0103 -0.0623 -0.0447 -0.     -0.0092 -0.0595]

[0 0 0 0 0 0 0 1 1]
[0.44]
[ 0.     -0.4582 -0.      0.0052 -0.0243 -0.0182  0.     -0.0063  0.0618]

[0 0 0 0 0 0 1 0 0]
[0.44]
[ 0.     -0.2805  0.      0.0045 -0.0607 -0.0505 -0.      0.0045 -0.0573]

[0 0 0 0 0 0 1 0 1]
[0.44]
[-0.     -0.4607 -0.      0.0027 -0.0233 -0.0207 -0.      0.0027  0.0593]

[0 0 0 0 0 0 1 1 0]
[0.44]
[ 0.     -0.2747  0.      0.0103 -0.0623 -0.0447 -0.     -0.0092 -0.0595]

[0 0 0 0 0 0 1 1 1]
[0.44]
[ 0.     -0.4582 -0.      0.0052 -0.0243 -0.0182  0.     -0.0063  0.0618]

[0 0 0 0 0 1 0 0 0]
[0.44]
[-0.     -0.4538  0.      0.0095 -0.028   0.0503 -0.      0.0095 -0.0275]

[0 0 0 0 0 1 0 0 1]
[0.44]
[-0.     -0.4887 -0.      0.0047 -0.0143  0.0223 -0.   

In [26]:
print(model.map)

['IL23', 'TCR', 'IL4_e', 'IL6_e', 'IL27', 'IFNg_e', 'IL18', 'TGFb', 'IL12']
