In [16]:
#import all packages

import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import json
import time
import csv
import torch
import pyro
import pandas as pd
import covid19kg
from scipy.stats import norm
from sklearn.linear_model import LinearRegression
from sklearn import metrics
from pyro.infer import Importance, EmpiricalMarginal
from torch.distributions.transforms import AffineTransform
import pyro.distributions as dist
pyro.set_rng_seed(101)

In [17]:
# Show DataFrames in full when displaying results: no row, column or width truncation.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
# None means "no limit". The old -1 sentinel was deprecated in pandas 1.0 and is
# rejected by later releases.
pd.set_option('display.max_colwidth', None)

In [18]:
class Node():
    """One BEL term of the causal graph together with its recorded neighbours.

    Only relations accepted by ``_is_causal`` (names such as 'increases' or
    'directlyDecreases') are stored as edges; every other relation is ignored.
    Labels come from the module-level ``label_dict``; a function prefix that is
    not listed there (including the empty prefix) is labelled 'Others'.

    The per-edge lists are kept index-aligned: ``children[i]``,
    ``child_relations[i]``, ``children_type[i]`` and ``children_label[i]`` all
    describe the same edge, and likewise for the ``parent*`` lists.
    """

    def __init__(self):
        self.root = True              # stays True until a causal parent is recorded
        self.children = []            # BEL terms this node acts on
        self.parent_relations = []    # relation of each incoming edge
        self.child_relations = []     # relation of each outgoing edge
        self.name = ""                # BEL term of this node
        self.node_type = ""           # function prefix of this node, e.g. 'bp'
        self.node_label = ""          # label of this node, e.g. 'Process'
        self.children_type = []
        self.children_label = []
        self.parents = []             # BEL terms acting on this node
        self.parent_type = []
        self.parent_label = []

    @staticmethod
    def _is_causal(rel):
        """Return True for relations that are kept as edges (they contain 'crease')."""
        return rel.find('crease') > 0

    @staticmethod
    def _function_prefix(term):
        """Return the text before the first '(' of ``term``, or '' if there is no '('."""
        head, paren, _ = term.partition('(')
        # Slicing with find('(') == -1 would silently drop the last character instead.
        return head if paren else ''

    @staticmethod
    def _label_for(prefix):
        """Return the ``label_dict`` label that lists ``prefix``, else 'Others'."""
        for label, prefixes in label_dict.items():
            if prefix in prefixes:
                return label
        return 'Others'

    def _add_child(self, obj, rel):
        """Append one outgoing edge, keeping all child lists the same length."""
        ctype = self._function_prefix(obj)
        self.children.append(obj)
        self.child_relations.append(rel)
        self.children_type.append(ctype)
        self.children_label.append(self._label_for(ctype))

    def get_node_information(self, sub, obj, rel):
        """Initialise this node as subject ``sub`` of the edge ``sub -rel-> obj``."""
        # The name is set even when the edge itself is ignored, so that a later
        # update_child_node() can still derive this node's type from it.
        self.name = sub
        if not self._is_causal(rel):
            return
        self.node_type = self._function_prefix(sub)
        self.node_label = self._label_for(self.node_type)
        self._add_child(obj, rel)

    def update_parent_node(self, obj, rel):
        """Record a further outgoing edge ``self -rel-> obj`` on an existing node."""
        if self._is_causal(rel):
            self._add_child(obj, rel)

    def update_child_node(self, sub, rel):
        """Record an incoming edge ``sub -rel-> self``; ``self.name`` must already be set."""
        if not self._is_causal(rel):
            # An ignored relation gives the node no parent, so it must stay a root.
            return
        self.root = False
        ptype = self._function_prefix(sub)
        self.parents.append(sub)
        self.parent_relations.append(rel)
        self.parent_type.append(ptype)
        self.parent_label.append(self._label_for(ptype))
        self.node_type = self._function_prefix(self.name)
        self.node_label = self._label_for(self.node_type)
            
        
            


In [19]:
# Registry of every graph node, keyed by BEL term string; filled in by get_graph().
nodes = dict()

In [20]:
## Generic loader for BEL statements: it accepts a list of statement strings,
## a PyBEL graph, a JGF file or a nanopub file, and records the edges it finds
## in the module-level `nodes` dictionary.

## Assumes the input consists of BEL statements in one of those formats.


# Maps each coarse node label to the BEL function prefixes (the text before the
# first '(' of a term) that belong to it; Node consults this table to label terms.
label_dict = {'Abundance':['a', 'r', 'm', 'g', 'p','pop', 'composite',
                           'complex','frag','fus','loc','pmod','var'],
             'Process': ['bp', 'path','act'],
             'Transformation':['sec','surf','deg','rxn','tloc','fromLoc',
                               'products','reactants','toLoc']}

# BEL short-form arrows accepted in `str_list` statements, spelled out in long form.
# Node only records relations whose name contains 'crease', so the raw arrows
# would otherwise be discarded.
_SHORT_FORM_RELATIONS = {'=>': 'directlyIncreases', '=|': 'directlyDecreases'}


def _register_edge(sub, obj, rel):
    """Record one ``sub -rel-> obj`` edge in the global `nodes` registry."""
    if sub in nodes:
        nodes[sub].update_parent_node(obj, rel)
    else:
        sub_node = Node()
        sub_node.get_node_information(sub, obj, rel)
        nodes[sub] = sub_node

    if obj in nodes:
        nodes[obj].update_child_node(sub, rel)
    else:
        obj_node = Node()
        obj_node.name = obj
        obj_node.update_child_node(sub, rel)
        nodes[obj] = obj_node


def get_graph(str_list=None, bel_graph=None, jgf_file=None, nanopub_file=None):
    """
    Description: Read BEL edges from one input source and add them to the
                 global `nodes` registry. The first non-empty argument, in the
                 order listed below, is used; the others are ignored.
    Parameters:  str_list     - statements written as 'subject => object' or
                                'subject =| object'
                 bel_graph    - graph object exposing `.edges` as (u, v, key)
                                triples and `get_edge_data(u, v, key)` with a
                                'relation' entry (PyBEL style)
                 jgf_file     - path to a JSON file with graph -> edges, each
                                edge holding 'source', 'target' and 'relation'
                 nanopub_file - path to a JSON file whose first element holds
                                nanopub -> assertions, each with 'subject',
                                'object' and 'relation'
    Returns:     None; `nodes` is updated in place, so calling this twice on
                 the same input records every edge twice.
    """
    if str_list:
        for item in str_list:
            sub_ind = item.find('=')
            sub_temp = item[:sub_ind-1]          # drop the space before the arrow
            obj_temp = item[sub_ind+3:]          # skip the arrow and the space after it
            rel_temp = item[sub_ind:sub_ind+2]
            _register_edge(sub_temp, obj_temp,
                           _SHORT_FORM_RELATIONS.get(rel_temp, rel_temp))

    elif bel_graph:
        for item in bel_graph.edges:
            edge_temp = bel_graph.get_edge_data(item[0], item[1], item[2])
            sub_temp = str(item[0]).replace('"', '')
            obj_temp = str(item[1]).replace('"', '')
            _register_edge(sub_temp, obj_temp, edge_temp['relation'])

    elif jgf_file:
        # The context manager closes the file once it has been parsed.
        with open(jgf_file) as handle:
            loaded_jgf = json.load(handle)
        for item in loaded_jgf['graph']['edges']:
            _register_edge(item['source'], item['target'], item['relation'])

    elif nanopub_file:
        with open(nanopub_file) as handle:
            loaded_nanopub = json.load(handle)
        for item in loaded_nanopub[0]['nanopub']['assertions']:
            _register_edge(item['subject'], item['object'], item['relation'])

In [22]:
# Load the run settings. The context manager closes the file as soon as it has
# been parsed instead of leaving the handle open for the life of the kernel.
with open('config.json') as f:
    config = json.load(f)
print(config)

{'bel_settings': {'file_type': 'nanopub_file', 'file_name': './COVID-19-new.json'}, 'pyro_settings': {'mu_a': 0.0, 'sigma_a': 1.0, 'mu_t': 0.0, 'sigma_t': 1.0, 'cat_1': 0.5, 'cat_0': 0.5, 'alpha': 0.5, 'beta': 2.0, 'weights': 1.0, 'threshold': 0.5}, 'node_type_settings': {'activity': 'Categorical', 'abundance': 'Gamma', 'reaction': 'Normal', 'process': 'Categorical', 'pathology': 'Categorical', 'Others': 'Lognormal'}, 'exogenous_var_settings': {'categorical': 'Normal', 'continuous_a': 'Lognormal', 'continuous_t': 'Gamma'}}


In [6]:
# Populate the global `nodes` registry from the nanopub export.
# NOTE(review): get_graph adds to whatever `nodes` already holds, so re-running this
# cell without first resetting `nodes` records every edge a second time.
get_graph(nanopub_file='COVID-19-new.json')

In [7]:
# for keys in nodes:
#     print(keys)
#     print("name", nodes[keys].name)
#     print("root", nodes[keys].root)
#     print("children", nodes[keys].children)
#     print("child relation", nodes[keys].child_relations)
#     print("child label", nodes[keys].children_label)
#     print("node type", nodes[keys].node_type)
#     print("node label", nodes[keys].node_label)
#     print("parents", nodes[keys].parents)
#     print("parent relation", nodes[keys].parent_relations)
#     print("parent label",nodes[keys].parent_label )
#     print()
#     print()

bp(GO:"tumor necrosis factor-mediated signaling pathway")
name bp(GO:"tumor necrosis factor-mediated signaling pathway")
root False
children ['act(complex(GO:"NF-kappaB complex"))']
child relation ['increases']
child label ['Process']
node type bp
node label Process
parents ['act(p(HGNC:ADAM17))']
parent relation ['increases']
parent label ['Process']


act(complex(GO:"NF-kappaB complex"))
name act(complex(GO:"NF-kappaB complex"))
root False
children ['bp(GO:"positive regulation of interleukin-6 production")']
child relation ['increases']
child label ['Process']
node type act
node label Process
parents ['bp(GO:"tumor necrosis factor-mediated signaling pathway")', 'bp(GO:"epidermal growth factor receptor signaling pathway")', 'bp(GO:"pattern recognition receptor signaling pathway")']
parent relation ['increases', 'increases', 'increases']
parent label ['Process', 'Process', 'Process']


bp(GO:"epidermal growth factor receptor signaling pathway")
name bp(GO:"epidermal growth factor recep

In [8]:
def get_distribution(node_dist: str, dist_parameters: list) -> "dist.Distribution":
    """
    Description: Build the Pyro distribution object selected by name.
    Parameters:  node_dist       - 'Lognormal', 'Gamma', 'Normal', or
                                   'Process' / 'Categorical'
                 dist_parameters - for 'Process' / 'Categorical' the whole
                                   list is passed to Categorical; for the other
                                   names the first two entries are passed, in
                                   order, as the distribution's two parameters
    Returns:     a pyro.distributions object (a distribution, not a sample)
    Raises:      ValueError if node_dist is not one of the names above
    """
    if node_dist in ('Process', 'Categorical'):
        return dist.Categorical(torch.tensor(dist_parameters))

    # Config name -> class name in pyro.distributions. The class is looked up
    # lazily so an unknown name is reported before anything else can fail.
    two_parameter = {'Lognormal': 'LogNormal', 'Gamma': 'Gamma', 'Normal': 'Normal'}
    if node_dist not in two_parameter:
        # Returning None here would only surface later as an obscure failure
        # inside pyro.sample.
        raise ValueError("Unsupported distribution name: {!r}".format(node_dist))

    factory = getattr(dist, two_parameter[node_dist])
    return factory(torch.tensor(dist_parameters[0]), torch.tensor(dist_parameters[1]))

In [9]:
def check_increase(x, threshold):
    """
    Description: Threshold gate for increasing-type edges, used when
                 sampling a child from its parents
    Parameters:  x         - combined value of the parents' equation
                 threshold - cutoff that x has to exceed
    Returns:     1.0 when x is strictly above threshold, otherwise 0.0
    """
    return 1.0 if x > threshold else 0.0

def check_decrease(x, threshold):
    """
    Description: Threshold gate for decreasing-type edges, used when
                 sampling a child from its parents
    Parameters:  x         - combined value of the parents' equation
                 threshold - cutoff that x has to exceed
    Returns:     0.0 when x is strictly above threshold, otherwise 1.0
    """
    return 0.0 if x > threshold else 1.0


In [10]:
# def get_sample(parent_name, child_name, child_label, parent_label, parents, relation, w, threshold, normal, noise, \
#               increaseProcess, decreaseProcess):
#     child_increase_N = 0 
#     child_decrease_N = 0
#     for i in range(len(parent_label)):
#         if relation[i] == 'increases' or relation[i] == 'directlyIncreases':
#             if parent_label[i] == 'Abundance':
#                 child_increase_N += w[i] * parents[i]
#             if parent_label[i] == 'Transformation':
#                 child_increase_N += w[i] *(parents[i] * parents[i])
            
#         else:
#             if parent_label[i] == 'Abundance':
#                 child_decrease_N += w[i] * parents[i]
#             if parent_label[i] == 'Transformation':
#                 child_decrease_N += w[i] * parents[i] * parents[i]

    
#     if child_label == 'Process':
#         child_name_noise = child_name + "_N"
#         child_noise = pyro.sample(child_name_noise, normal)
#         child_check = check_increase(child_increase_N + child_noise + sum(increaseProcess), (len(parent_label))*threshold) + \
#                       check_decrease(child_decrease_N + child_noise + sum(decreaseProcess), (len(parent_label))*threshold)
#         if len(increaseProcess) == 0 and len(decreaseProcess) > 0 and child_check == 1.0:
#             child_N = torch.tensor(1.0)
#         elif len(decreaseProcess) == 0 and len(increaseProcess) > 0 and child_check == 1.0:
#             child_N = torch.tensor(1.0)
#         elif child_check == 2.0:
#             child_N = torch.tensor(1.0)
#         else:
#             child_N = torch.tensor(0.)

#     elif child_label == 'Abundance':
#         child_name_noise = child_name + "_N"
        
#         child_noise = pyro.sample(child_name_noise, dist.LogNormal(torch.tensor(mu_a),torch.tensor(sigma_a)))
#         child_N = child_increase_N - child_decrease_N + child_noise
#     else:
#         child_name_noise = child_name + "_N"
#         child_noise = pyro.sample(child_name_noise, noise)
#         #child_noise = pyro.sample(child_name_noise, dist.LogNormal(torch.tensor(mu_a),torch.tensor(sigma_a)))
#         child_N = child_increase_N - child_decrease_N + child_noise
        
#     return child_N

    

In [None]:
def get_abundance_sample(weights_a: list, p_sample_a: list):
    """Return the weighted sum of the parent samples: sum of weight * sample.

    Pairs are formed positionally; surplus entries of the longer list are
    ignored. Two empty lists give 0.
    """
    total = 0
    for weight, sample in zip(weights_a, p_sample_a):
        total = total + weight * sample
    return total


def get_transformation_sample(weights_t: list, p_sample_t: list):
    """Return the weighted sum of squared parent samples: sum of weight * sample * sample.

    Pairs are formed positionally; surplus entries of the longer list are
    ignored. Two empty lists give 0.
    """
    total = 0
    for weight, sample in zip(weights_t, p_sample_t):
        total = total + weight * sample * sample
    return total


def get_sample(child_name: str,
               child_label: str,
               parent_label: list,
               threshold: float,
               normal: "dist.Distribution",
               noise: "dist.Distribution",
               increase_process: list,
               decrease_process: list,
               increase_abundance: list,
               decrease_abundance: list,
               weights_ai: list,
               weights_ad: list,
               increase_transformation: list,
               decrease_transformation: list,
               weights_ti: list,
               weights_td: list,
               mu_a: float = 0.0,
               sigma_a: float = 1.0):
    """
    Description: Compute the value of a child node from its parents' samples
                 plus an exogenous noise term drawn at site '<child_name>_N'.
    Parameters:  child_name   - name of the child; also prefixes the noise site
                 child_label  - 'Transformation', 'Abundance', or anything
                                else (handled as a thresholded process)
                 parent_label - labels of all parents; only its length is
                                used, to scale the threshold
                 threshold    - per-parent cutoff for the process gate
                 normal       - noise distribution for process children
                 noise        - noise distribution for transformation children
                 increase_process / decrease_process
                              - samples of process parents, by edge direction
                 increase_abundance / decrease_abundance, weights_ai / weights_ad
                              - abundance parent samples and their weights
                 increase_transformation / decrease_transformation,
                 weights_ti / weights_td
                              - transformation parent samples and their weights
                 mu_a, sigma_a - parameters of the LogNormal noise used for
                                abundance children
    Returns:     increase - decrease + noise for transformation and abundance
                 children; tensor(1.0) or tensor(0.) for process children
    """
    child_increase_N = (get_abundance_sample(weights_ai, increase_abundance)
                        + get_transformation_sample(weights_ti, increase_transformation))
    child_decrease_N = (get_abundance_sample(weights_ad, decrease_abundance)
                        + get_transformation_sample(weights_td, decrease_transformation))
    child_name_noise = child_name + "_N"

    # label_dict spells the label 'Transformation'; the lower-case spelling is
    # still accepted for callers that relied on it.
    if child_label in ('Transformation', 'transformation'):
        child_noise = pyro.sample(child_name_noise, noise)
        return child_increase_N - child_decrease_N + child_noise

    if child_label == 'Abundance':
        abundance_noise = dist.LogNormal(torch.tensor(mu_a), torch.tensor(sigma_a))
        child_noise = pyro.sample(child_name_noise, abundance_noise)
        return child_increase_N - child_decrease_N + child_noise

    # Every other label is a binary process: both gates are scored against a
    # threshold that grows with the number of parents.
    child_noise = pyro.sample(child_name_noise, normal)
    cutoff = len(parent_label) * threshold
    child_check = (check_increase(child_increase_N + child_noise + sum(increase_process), cutoff)
                   + check_decrease(child_decrease_N + child_noise + sum(decrease_process), cutoff))

    only_decreasing = len(increase_process) == 0 and len(decrease_process) > 0
    only_increasing = len(decrease_process) == 0 and len(increase_process) > 0
    if child_check == 2.0 or (child_check == 1.0 and (only_decreasing or only_increasing)):
        return torch.tensor(1.0)
    return torch.tensor(0.)


In [11]:
def SCM(nodes=None, config=None):
    """
    Description: Structural Causal Model over the whole graph. Root nodes are
                 sampled from an exogenous distribution; every other node is
                 computed from its parents by get_sample() once all of its
                 parents have a value.
    Parameters:  nodes  - mapping of node name to Node; defaults to the
                          module-level `nodes`
                 config - settings dictionary with a "pyro_settings" entry
                          (threshold, weights, mu_a, sigma_a, mu_t, sigma_t,
                          cat_0, cat_1, alpha, beta); defaults to the
                          module-level `config`
    Returns:     dictionary mapping node name to its sampled value. Nodes
                 whose parents never all receive a value are left out.
    """
    # Pyro invokes the model without arguments (for example through pyro.do),
    # so both inputs fall back to the notebook-level objects.
    if nodes is None:
        nodes = globals()["nodes"]
    if config is None:
        config = globals()["config"]

    pyro_settings = config["pyro_settings"]
    threshold = pyro_settings["threshold"]
    weight = pyro_settings["weights"]

    # Noise distributions handed to get_sample().
    noise = dist.Gamma(torch.tensor(pyro_settings["alpha"]),
                       torch.tensor(pyro_settings["beta"]))
    normal = dist.Normal(torch.tensor(0.0), torch.tensor(0.1))

    # Exogenous distribution of a root node, chosen by its label.
    # NOTE(review): this label -> distribution mapping is an assumption
    # (LogNormal for continuous nodes, Categorical over [cat_0, cat_1] for
    # processes); confirm against the intended model.
    root_distributions = {
        'Process': ('Process', [pyro_settings["cat_0"], pyro_settings["cat_1"]]),
        'Transformation': ('Lognormal', [pyro_settings["mu_t"], pyro_settings["sigma_t"]]),
    }
    default_root = ('Lognormal', [pyro_settings["mu_a"], pyro_settings["sigma_a"]])

    samples = {}
    visited = set()
    node_list = []   # FIFO queue of nodes waiting to be evaluated

    root_list = [node for node in nodes if nodes[node].root]
    for node in root_list:
        dist_name, dist_parameters = root_distributions.get(nodes[node].node_label,
                                                            default_root)
        samples[node] = pyro.sample(node, get_distribution(dist_name, dist_parameters))
        visited.add(node)
        node_list.extend(nodes[node].children)

    while len(node_list) > 0:
        current_node = node_list.pop(0)
        info = nodes[current_node]
        parent_label = info.parent_label
        child_label = info.node_label
        edges = list(zip(info.parents, info.parent_relations, parent_label))

        # Sort the parents' values by label and by edge direction.
        increase_process, decrease_process = [], []
        increase_abundance, decrease_abundance = [], []
        increase_transformation, decrease_transformation = [], []
        visited_parents_count = 0
        for p_name, rel, p_label in edges:
            if p_name not in samples:
                continue
            visited_parents_count += 1
            value = samples[p_name]
            decreasing = rel in ('decreases', 'directlyDecreases')
            if p_label == 'Process':
                (decrease_process if decreasing else increase_process).append(value)
            elif p_label == 'Abundance':
                (decrease_abundance if decreasing else increase_abundance).append(value)
            elif p_label == 'Transformation':
                (decrease_transformation if decreasing
                 else increase_transformation).append(value)

        # Not every parent has a value yet: the node is queued again when its
        # remaining parents are sampled.
        if visited_parents_count != len(edges):
            continue

        if current_node not in visited:
            # With process parents the child is only evaluated when no
            # decreasing process is on and every increasing process is on;
            # otherwise it is switched off without drawing any noise.
            gate_open = (sum(decrease_process) == 0
                         and sum(increase_process) == len(increase_process))
            if "Process" in parent_label and not gate_open:
                child_N = torch.tensor(0.)
            else:
                # mu_a / sigma_a are not forwarded; get_sample's own handling of
                # the abundance noise parameters applies.
                child_N = get_sample(
                    child_name=current_node,
                    child_label=child_label,
                    parent_label=parent_label,
                    threshold=threshold,
                    normal=normal,
                    noise=noise,
                    increase_process=increase_process,
                    decrease_process=decrease_process,
                    increase_abundance=increase_abundance,
                    decrease_abundance=decrease_abundance,
                    weights_ai=[weight] * len(increase_abundance),
                    weights_ad=[weight] * len(decrease_abundance),
                    increase_transformation=increase_transformation,
                    decrease_transformation=decrease_transformation,
                    weights_ti=[weight] * len(increase_transformation),
                    weights_td=[weight] * len(decrease_transformation))

            # A Delta site gives the node a named site of its own in the model.
            samples[current_node] = pyro.sample(current_node,
                                                pyro.distributions.Delta(child_N))
            visited.add(current_node)

        node_list.extend(info.children)

    return samples


In [12]:
# Draw one joint sample; SCM takes the graph and the settings as arguments.
print(SCM(nodes, config))

{'a(TAX:"Severe acute respiratory syndrome coronavirus 2")': tensor(0.2489), 'bp(GO:"pattern recognition receptor signaling pathway")': tensor(0.), 'deg(p(HGNC:ACE2))': tensor(0.4795), 'act(p(HGNC:ACE2))': tensor(0.), 'a(CHEBI:"angiotensin II")': tensor(0.7773), 'act(p(HGNC:AGTR1))': tensor(1.), 'act(p(HGNC:ADAM17))': tensor(1.), 'bp(GO:"tumor necrosis factor-mediated signaling pathway")': tensor(1.), 'act(p(HGNC:EGF))': tensor(1.), 'act(p(HGNC:IL6R))': tensor(1.), 'bp(GO:"epidermal growth factor receptor signaling pathway")': tensor(1.), 'act(complex(p(HGNC:IL6),p(HGNC:STAT3)))': tensor(1.), 'act(complex(GO:"NF-kappaB complex"))': tensor(0.), 'bp(GO:"positive regulation of interleukin-6 production")': tensor(0.), 'path(MESH:"Severe Acute Respiratory Syndrome")': tensor(0.)}


In [13]:
def intervention(model, do_variable, do_val, target_variable, num_samples=1000):
    """
      Description: Run an intervention (do) query by sampling
      Parameters:  Structural Causal Model (model), a callable that takes
               no arguments,
               a list of variables to be intervened (do_variable),
               list of values for intervened variables in
               same order (do_val),
               target variable (target_variable),
               number of importance samples, also the number of draws
               taken from the resulting marginal (num_samples)
      Returns:  mean of the values drawn for the target variable under the
               intervention (for a 0/1 target this is the estimated
               probability of 1)
      Raises:   ValueError if do_variable and do_val differ in length
    """
    if len(do_variable) != len(do_val):
        raise ValueError(
            "do_variable and do_val must have the same length, got {} and {}".format(
                len(do_variable), len(do_val)))

    # get the conditions for the do model
    conditions = {name: torch.tensor(value) for name, value in zip(do_variable, do_val)}
    do_model = pyro.do(model, data=conditions)
    posterior = pyro.infer.Importance(do_model, num_samples=num_samples).run()
    marginal = EmpiricalMarginal(posterior, target_variable)
    target = [marginal().item() for _ in range(num_samples)]
    return np.mean(target)

In [14]:
# The model is invoked with no arguments during inference, so the graph and the
# settings are bound to SCM here.
intervention(lambda: SCM(nodes, config),['bp(GO:"pattern recognition receptor signaling pathway")'],[1.],'path(MESH:"Severe Acute Respiratory Syndrome")')

0.76

In [15]:
# The model is invoked with no arguments during inference, so the graph and the
# settings are bound to SCM here.
intervention(lambda: SCM(nodes, config),['act(p(HGNC:ACE2))'],[0.0],'path(MESH:"Severe Acute Respiratory Syndrome")')

0.58