In [1]:
import sklearn
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

In [2]:
# Import the data
tree = pd.read_csv('tree.csv')
tree['t'] = tree['t'].replace(to_replace=0, value=0.1)

vert_genes = pd.read_csv('vert_genes.csv')

## Part I: Simulation

In [3]:
# Creating the graph

def create_graph(tree, alpha, beta, sigma_sq):
    G = nx.DiGraph()
    for _, row in tree.iterrows():
        if not pd.isna(row['Parent']):
            G.add_edge(int(row["Parent"]), int(row["Child"]), time = row["t"], a = alpha*row["t"], b = beta, variance = sigma_sq*row["t"])
            
    return G

G = create_graph(tree, alpha = 0, beta = 1, sigma_sq = 2500)

In [4]:
def simulate_node_length_with_parameters(G, parent, simulated_lengths, alpha, beta, sigma_sq):
    for child in G.successors(parent):
        t = G[parent][child]['time']
        mean = alpha * t + beta * simulated_lengths[parent]
        std = np.sqrt(sigma_sq * t)
        simulated_lengths[child] = np.random.normal(mean, std)
        simulate_node_length_with_parameters(G, child, simulated_lengths, alpha, beta, sigma_sq)
    
    return simulated_lengths

In [5]:
def simulate_data_for_learning(G, n, alpha, beta, sigma_sq, alpha_0, sigma_0_sq, root, learn_params = True, only_X = True):
    X_values = []
    Y_values = []
    
    all_nodes = list(G.nodes)  # Get all nodes in the graph

    for _ in range(n):
        simulated_lengths = {}
        
        # Simulate root node first
        simulated_lengths[root] = np.random.normal(alpha_0, np.sqrt(sigma_0_sq))
        
        # Simulate all other nodes recursively
        simulated_lengths = simulate_node_length_with_parameters(G, root, simulated_lengths, alpha, beta, sigma_sq)
        if only_X:
            leaf_nodes = [node for node in all_nodes if G.out_degree(node) == 0]
            simulated_x = [simulated_lengths[node] for node in leaf_nodes]
            
            X_values.append(simulated_x)
        else:
            X_values.append([simulated_lengths[node] for node in all_nodes])
            
        if learn_params:
            Y_values.append([alpha,beta,sigma_sq])
        else:
            Y_values.append(simulated_lengths[root])

    return np.array(X_values), np.array(Y_values)

learn_parameters = False

X, y = simulate_data_for_learning(G, 1000, alpha = 0.5, beta = 1, sigma_sq = 2500, alpha_0 = 50000, sigma_0_sq = 5000, root = 407, learn_params=learn_parameters)

print(X.shape)
print(y.shape)

print(y[10])

(1000, 204)
(1000,)
50116.03336955078


In [6]:
def compute_gamma(G):
    gamma = np.eye(len(G.nodes)) # Initialize gamma as an identity matrix

    # Iterate through the nodes in the graph
    for node in G.nodes:
        parent = next(G.predecessors(node), None)
        if parent is None:
            print(f"Node {node} is the root node.")
            continue
        else:
            gamma[parent-1, node -1] = -G[parent][node]['b'] # Determine the dependency between parent and child nodes with -b
    return gamma

def compute_beta(G, alpha_0 = 50000):
    beta = np.zeros((len(G.nodes), 1))
    for node in G.nodes:
        parent = next(G.predecessors(node), None)
        if parent is None:
            beta[node-1] = alpha_0
            continue
        a = G[parent][node]['a'] # The constant term in the mean of the CPD
        beta[node-1] = a
        
    return beta

def compute_sigma(G, sigma_0_sq = 5000):
    sigma = np.zeros((len(G.nodes)))
    for node in G.nodes:
        parent = next(G.predecessors(node), None)
        if parent is None:
            # Assign the default value for the root node
            sigma[node - 1] = sigma_0_sq
        else:
            # Access the edge attribute 'variance' only if parent exists
            variance = G[parent][node]['variance']
            sigma[node - 1] = variance
    return sigma

def compute_J_and_h(alpha, beta, sigma_sq, sigma_0_sq = 5000):
    
    G = create_graph(tree, alpha, beta, sigma_sq)

    beta = compute_beta(G)
    sigma = compute_sigma(G, sigma_0_sq)
    gamma = compute_gamma(G)

    J = np.sum([np.outer(gamma[:, i], gamma[:, i]) / sigma[i] for i in range(len(G))], axis=0)
    h = np.sum([(beta[i] / sigma[i]) * gamma[:, i] for i in range(len(G))], axis=0)
    return J, h



In [7]:
n = 1000

alpha = 0
beta = 1
sigma_sq = 2500
All_nodes_simulated, _ = simulate_data_for_learning(G, n, alpha=alpha, beta=beta, sigma_sq=sigma_sq, alpha_0=50000, sigma_0_sq=5000, root=407, learn_params=False, only_X=False)
empirical_covariance_matrix = np.cov(All_nodes_simulated, rowvar=False)

J, h = compute_J_and_h(alpha = alpha, beta = beta, sigma_sq = sigma_sq)

computed_covariance_matrix = np.linalg.inv(J)

print(empirical_covariance_matrix.shape, computed_covariance_matrix.shape)


difference =  empirical_covariance_matrix - computed_covariance_matrix
print(f"Difference between empirical and computed covariance matrix \n: {difference}")

h

Node 407 is the root node.
(407, 407) (407, 407)
Difference between empirical and computed covariance matrix 
: [[-16172.60401878   7030.76357501  12715.03642319 ...   2754.72991758
   -4430.01376399  -4293.49016197]
 [  7030.76357501   8596.70031247  13402.82485447 ...    484.37635554
   -5751.14212241  -5740.2847144 ]
 [ 12715.03642319  13402.82485447   3545.71424905 ...   2265.97528042
   -4929.83142472  -4766.43414795]
 ...
 [  2754.72991758    484.37635554   2265.97528042 ... -15951.53991818
  -97424.79962474   7824.05624074]
 [ -4430.01376399  -5751.14212241  -4929.83142472 ... -97424.79962474
  -20532.44165264  92312.20650686]
 [ -4293.49016197  -5740.2847144   -4766.43414795 ...   7824.05624074
   92312.20650686  96580.76976067]]


array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0

In [8]:
def compute_clique_tree(G):
    C = nx.Graph()

    G_working = G.copy()
    leaves = [node for node in G_working.nodes if G_working.out_degree(node) == 0]
    G_working.remove_nodes_from(leaves)

    index = min(G_working.nodes) -1 if all(isinstance(n, int) for n in G_working.nodes) else 0

    for node in G_working.nodes:
        parent = node
        children = list(G_working.neighbors(parent))
        C.add_node(parent, variables=[parent])
        for child in children:
            pair_clique = index
            C.add_node(pair_clique, variables=[parent, child])
            C.add_edge(parent, pair_clique)
            C.add_edge(pair_clique, child)
            index = index - 1

    return C

C = compute_clique_tree(G)

In [9]:
NoV = len([node for node in C.nodes if len(C.nodes[node]['variables']) == 1]) -1
maxIndex = max([node for node in C.nodes])
away_from_zero = maxIndex - NoV
minIndex = min([node for node in C.nodes if len(C.nodes[node]['variables']) == 1])

print("Away from zero: ", away_from_zero)
print("Max index: ", maxIndex)
print("Min index: ", minIndex)
print("Number of variables in the graph: ", NoV)

def mapping_GTM(index_in_graph):
    return index_in_graph - away_from_zero 

def mapping_MTG(index_in_matrix):
    return index_in_matrix + away_from_zero

Away from zero:  205
Max index:  407
Min index:  205
Number of variables in the graph:  202


### Matrix algebra implementation

In [10]:
def get_sub_matrices(scope, X_values, J, h):
    X_indices = np.isin(scope, X_values)
    Z_indices = ~X_indices

    J_ZZ = J[Z_indices, :][:, Z_indices]
    J_ZX = J[Z_indices, :][:, X_indices]
    J_XZ = J_ZX.T
    J_XX = J[X_indices, :][:, X_indices]
    J_ZZ_inv = np.linalg.inv(J_ZZ)

    h_X = h[X_indices]
    h_Z = h[Z_indices]

    return J_ZZ, J_ZX, J_XZ, J_XX, J_ZZ_inv, h_X, h_Z

def get_conditional_distribution(J, h, X_values, X_indices):
    scope = [i for i in range(1, len(J) + 1)]
    J_ZZ, J_ZX, J_XZ, J_XX, J_ZZ_inv, h_X, h_Z = get_sub_matrices(scope, X_indices, J, h)

    J_reduced = J_ZZ
    h_reduced = h_Z- J_ZX @ X_values

    return J_reduced, h_reduced

X_values = simulate_data_for_learning(G, 1, alpha=alpha, beta=beta, sigma_sq=sigma_sq, alpha_0=50000, sigma_0_sq=5000, root=407, learn_params=False, only_X=True)[0]


leaves = [node for node in G.nodes if G.out_degree(node) == 0]
X_indices = leaves

J_reduced, h_reduced = get_conditional_distribution(J, h, X_values[0], X_indices)

In [11]:
Sigma = np.linalg.inv(J_reduced)
mu = Sigma @ h_reduced

random_index = np.random.choice(range(len(X_values[0])))
z = mapping_GTM(random_index)
print(f"Predicted value for node 407: {mu[z]}")
print(f"Variance for node 407: {Sigma[z,z]}")
print(f"Actual value for node 407: {X_values[0][0]}")
print(f"True variance for node 407: {sigma_sq}")

Predicted value for node 407: 49885.98983680517
Variance for node 407: 15906.197622746702
Actual value for node 407: 49881.28480624464
True variance for node 407: 2500


## Part II : inference

In [12]:
def single_clique(G):
    leaves = [node for node in G.nodes if G.out_degree(node) == 0]
    H = G.copy()
    H.remove_nodes_from(leaves)
    H = H.to_undirected()
    return H

def compute_J_i_arrow_j(clique_tree, i, j, J, h, J_messages, h_messages, globalcounter):
    
    neighbors = list(clique_tree.neighbors(i))
    neighbors.remove(j)

    i_idx = mapping_GTM(i)
    j_idx = mapping_GTM(j)

    if not neighbors:
        J_messages[i_idx][j_idx] = -J[i_idx, j_idx] * J[j_idx, i_idx] / J[i_idx, i_idx]
        globalcounter[0] += 1
    else:
        J_sum = np.sum(compute_J_i_arrow_j(clique_tree, k, i, J, h, J_messages, h_messages, globalcounter) for k in neighbors)
        J_messages[i_idx][j_idx] = -J[i_idx, j_idx] * J[j_idx, i_idx] / (J[i_idx, i_idx] + J_sum)
        globalcounter[0] += 1

    return J_messages[i_idx][j_idx]


def compute_h_i_arrow_j(clique_tree, i, j, J, h, J_messages, h_messages, globalcounter):

    neighbors = list(clique_tree.neighbors(i))
    neighbors.remove(j)

    i_idx = mapping_GTM(i)
    j_idx = mapping_GTM(j)

    if not neighbors:
        h_messages[i_idx][j_idx] = -J[i_idx, j_idx] * h[i_idx] / J[i_idx, i_idx]
        globalcounter[1] += 1
    else:
        J_sum = np.sum(compute_J_i_arrow_j(clique_tree, k, i, J, h, J_messages, h_messages, globalcounter) for k in neighbors)
        h_sum = np.sum(compute_h_i_arrow_j(clique_tree, k, i, J, h, J_messages, h_messages, globalcounter) for k in neighbors)
        Ji_backslash_j = J[i_idx, i_idx] + J_sum
        hi_backslash_j = h[i_idx] + h_sum
        h_messages[i_idx][j_idx] = (-J[i_idx, j_idx] * hi_backslash_j) / (Ji_backslash_j)
        globalcounter[0] += 1

    return h_messages[i_idx][j_idx]

def inference_algorithm(full_tree, J, h, observed_X, Z, globalcounter):
    clique_tree = single_clique(full_tree)

    print("Number of nodes in clique tree:", len(clique_tree.nodes))

    n_observed = len(observed_X)
    J_reduced, h_reduced = get_conditional_distribution(J, h, observed_X, list(range(1, n_observed + 1)))


    J_messages = np.full(J_reduced.shape, np.nan)
    h_messages = np.full(J_reduced.shape, np.nan)

    Z_neighbors = list(clique_tree.neighbors(Z))
    Z_idx = mapping_GTM(Z)

    J_zz = J_reduced[Z_idx, Z_idx]

    # print("Z neighbors:", Z_neighbors)
    # print("Z index:", Z_idx)
    sum_J = np.sum([compute_J_i_arrow_j(clique_tree, k, Z, J_reduced, h_reduced, J_messages, h_messages, globalcounter) for k in Z_neighbors])
    sum_h = np.sum([compute_h_i_arrow_j(clique_tree, k, Z, J_reduced, h_reduced, J_messages, h_messages, globalcounter) for k in Z_neighbors])

    J_hat_Z = J_zz + sum_J
    h_hat_Z = h_reduced[Z_idx] + sum_h

    # print(f"J_zz for node {Z}:", J_zz)
    # print("Incoming J messages:")
    # for k in Z_neighbors:
    #     val = compute_J_i_arrow_j(clique_tree, k, Z, J_reduced, h_reduced, J_messages, h_messages, globalcounter)
    #     h_val = compute_h_i_arrow_j(clique_tree, k, Z, J_reduced, h_reduced, J_messages, h_messages, globalcounter)
    #     print(f"  {k} → {Z}: {val} ({h_val})")

    # print("→ Sum of messages:", sum_J)

    return J_hat_Z, h_hat_Z

def information_to_standard(J_hat_Z, h_hat_Z, Z):

    print("J_hat_Z:", J_hat_Z)
    print("h_hat_Z:", h_hat_Z)
    mu = h_hat_Z / J_hat_Z
    sigma = np.sqrt(1 / J_hat_Z)
    return mu, sigma

n = 10
random_index = np.random.choice(range(n))
Z = 300  # root node

alpha = 0
beta = 1
sigma_sq = 2500

X_values, _ = simulate_data_for_learning(G, n, alpha = alpha, beta = beta, sigma_sq = sigma_sq, alpha_0 = 50000, sigma_0_sq = 5000, root = 407, learn_params=False, only_X=True)

globalcounter = [0,0]

J_hat_Z, h_hat_Z = inference_algorithm(G, J, h, X_values[random_index], Z, globalcounter)
mu, sigma = information_to_standard(J_hat_Z, h_hat_Z, Z)

true_mu = X_values[random_index][0]
true_sigma = sigma_sq
print("True mean:", true_mu)
print("True std dev:", true_sigma)
print("Posterior mean:", mu)
print("Posterior std dev:", sigma)
print("Total messages sent:", globalcounter)

Number of nodes in clique tree: 203
J_hat_Z: 0.00028940114546818524
h_hat_Z: 14.62022314623666
True mean: 50616.455617331274
True std dev: 2500
Posterior mean: 50518.884859921556
Posterior std dev: 58.78274696273216
Total messages sent: [3350, 55]


  J_sum = np.sum(compute_J_i_arrow_j(clique_tree, k, i, J, h, J_messages, h_messages, globalcounter) for k in neighbors)
  J_sum = np.sum(compute_J_i_arrow_j(clique_tree, k, i, J, h, J_messages, h_messages, globalcounter) for k in neighbors)
  h_sum = np.sum(compute_h_i_arrow_j(clique_tree, k, i, J, h, J_messages, h_messages, globalcounter) for k in neighbors)
