In [561]:
import sklearn
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

In [562]:
# Import the data
tree = pd.read_csv('tree.csv')
tree['t'] = tree['t'].replace(to_replace=0, value=0.1)

vert_genes = pd.read_csv('vert_genes.csv')

## Part I: Simulation

In [563]:
# Creating the graph

def create_graph(tree, alpha, beta, sigma_sq):
    G = nx.DiGraph()
    for _, row in tree.iterrows():
        if not pd.isna(row['Parent']):
            G.add_edge(int(row["Parent"]), int(row["Child"]), time = row["t"], a = alpha*row["t"], b = beta, variance = sigma_sq*row["t"])
            
    return G

G = create_graph(tree, alpha = 0, beta = 1, sigma_sq = 2500)

In [564]:
def simulate_node_length_with_parameters(G, parent, simulated_lengths, alpha, beta, sigma_sq):
    for child in G.successors(parent):
        t = G[parent][child]['time']
        mean = alpha * t + beta * simulated_lengths[parent]
        std = np.sqrt(sigma_sq * t)
        simulated_lengths[child] = np.random.normal(mean, std)
        simulate_node_length_with_parameters(G, child, simulated_lengths, alpha, beta, sigma_sq)
    
    return simulated_lengths

In [565]:
def simulate_data_for_learning(G, n, alpha, beta, sigma_sq, alpha_0, sigma_0_sq, root, learn_params = True, only_X = True):
    X_values = []
    Y_values = []
    
    all_nodes = list(G.nodes)  # Get all nodes in the graph

    for _ in range(n):
        simulated_lengths = {}
        
        # Simulate root node first
        simulated_lengths[root] = np.random.normal(alpha_0, np.sqrt(sigma_0_sq))
        
        # Simulate all other nodes recursively
        simulated_lengths = simulate_node_length_with_parameters(G, root, simulated_lengths, alpha, beta, sigma_sq)
        if only_X:
            leaf_nodes = [node for node in all_nodes if G.out_degree(node) == 0]
            simulated_x = [simulated_lengths[node] for node in leaf_nodes]
            
            X_values.append(simulated_x)
        else:
            X_values.append([simulated_lengths[node] for node in all_nodes])
            
        if learn_params:
            Y_values.append([alpha,beta,sigma_sq])
        else:
            Y_values.append(simulated_lengths[root])

    return np.array(X_values), np.array(Y_values)

learn_parameters = False

X, y = simulate_data_for_learning(G, 1000, alpha = 0.5, beta = 1, sigma_sq = 2500, alpha_0 = 50000, sigma_0_sq = 5000, root = 407, learn_params=learn_parameters)

print(X.shape)
print(y.shape)

print(y[10])

(1000, 204)
(1000,)
50031.26726368049


In [566]:
def compute_gamma(G):
    gamma = np.eye(len(G.nodes)) # Initialize gamma as an identity matrix

    # Iterate through the nodes in the graph
    for node in G.nodes:
        parent = next(G.predecessors(node), None)
        if parent is None:
            continue
        else:
            gamma[parent-1, node -1] = -G[parent][node]['b'] # Determine the dependency between parent and child nodes with -b
    return gamma

def compute_beta(G, alpha_0):
    beta = np.zeros((len(G.nodes), 1))
    for node in G.nodes:
        parent = next(G.predecessors(node), None)
        if parent is None:
            beta[node-1] = alpha_0
            continue
        a = G[parent][node]['a'] # The constant term in the mean of the CPD
        beta[node-1] = a
        
    return beta

def compute_sigma(G, sigma_0_sq):
    sigma = np.zeros((len(G.nodes)))
    for node in G.nodes:
        parent = next(G.predecessors(node), None)
        if parent is None:
            # Assign the default value for the root node
            sigma[node - 1] = sigma_0_sq
        else:
            # Access the edge attribute 'variance' only if parent exists
            variance = G[parent][node]['variance']
            sigma[node - 1] = variance
    return sigma

def compute_J_and_h(G, alpha, beta, sigma_sq, sigma_0_sq, alpha_0):
    
    beta = compute_beta(G, alpha_0)
    sigma = compute_sigma(G, sigma_0_sq)
    gamma = compute_gamma(G)

    J = np.sum([np.outer(gamma[:, i], gamma[:, i]) / sigma[i] for i in range(len(G))], axis=0)
    h = np.sum([(beta[i] / sigma[i]) * gamma[:, i] for i in range(len(G))], axis=0)
    return J, h



In [567]:
n = 1000

alpha = 0
beta = 1
sigma_sq = 2500
All_nodes_simulated, _ = simulate_data_for_learning(G, n, alpha=alpha, beta=beta, sigma_sq=sigma_sq, alpha_0=50000, sigma_0_sq=5000, root=407, learn_params=False, only_X=False)
empirical_covariance_matrix = np.cov(All_nodes_simulated, rowvar=False)

J, h = compute_J_and_h(G, alpha = alpha, beta = beta, sigma_sq = sigma_sq, sigma_0_sq=5000, alpha_0=50000)

computed_covariance_matrix = np.linalg.inv(J)

print(empirical_covariance_matrix.shape, computed_covariance_matrix.shape)


difference =  empirical_covariance_matrix - computed_covariance_matrix
print(f"Difference between empirical and computed covariance matrix \n: {difference}")

h

(407, 407) (407, 407)
Difference between empirical and computed covariance matrix 
: [[-2.06019590e+04  1.52300851e+03  1.37122304e+04 ... -2.49644495e+02
   3.02349309e+03  3.44928882e+03]
 [ 1.52300851e+03 -7.25339903e+01  1.34203802e+04 ...  4.18533338e+02
   2.25616478e+03  3.15581283e+03]
 [ 1.37122304e+04  1.34203802e+04  9.71434852e+03 ... -1.63354214e+03
   1.96831891e+03  2.93382955e+03]
 ...
 [-2.49644495e+02  4.18533338e+02 -1.63354214e+03 ... -4.07683961e+03
  -1.00112693e+05  4.82076651e+03]
 [ 3.02349309e+03  2.25616478e+03  1.96831891e+03 ... -1.00112693e+05
  -8.65095876e+03  1.02854471e+05]
 [ 3.44928882e+03  3.15581283e+03  2.93382955e+03 ...  4.82076651e+03
   1.02854471e+05  1.05625659e+05]]


array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0

In [568]:
def compute_clique_tree(G):
    C = nx.Graph()

    G_working = G.copy()
    leaves = [node for node in G_working.nodes if G_working.out_degree(node) == 0]
    G_working.remove_nodes_from(leaves)

    index = min(G_working.nodes) -1 if all(isinstance(n, int) for n in G_working.nodes) else 0

    for node in G_working.nodes:
        parent = node
        children = list(G_working.neighbors(parent))
        C.add_node(parent, variables=[parent])
        for child in children:
            pair_clique = index
            C.add_node(pair_clique, variables=[parent, child])
            C.add_edge(parent, pair_clique)
            C.add_edge(pair_clique, child)
            index = index - 1

    return C

C = compute_clique_tree(G)

In [569]:
NoV = len([node for node in C.nodes if len(C.nodes[node]['variables']) == 1]) -1
maxIndex = max([node for node in C.nodes])
away_from_zero = maxIndex - NoV
minIndex = min([node for node in C.nodes if len(C.nodes[node]['variables']) == 1])

print("Away from zero: ", away_from_zero)
print("Max index: ", maxIndex)
print("Min index: ", minIndex)
print("Number of variables in the graph: ", NoV)

def mapping_GTM(index_in_graph):
    return index_in_graph - away_from_zero 

def mapping_MTG(index_in_matrix):
    return index_in_matrix + away_from_zero

Away from zero:  205
Max index:  407
Min index:  205
Number of variables in the graph:  202


### Matrix algebra implementation

In [570]:
def get_sub_matrices(scope, X_values, J, h):
    X_indices = np.isin(scope, X_values)
    Z_indices = ~X_indices

    J_ZZ = J[Z_indices, :][:, Z_indices]
    J_ZX = J[Z_indices, :][:, X_indices]
    J_XZ = J_ZX.T
    J_XX = J[X_indices, :][:, X_indices]
    J_ZZ_inv = np.linalg.inv(J_ZZ)

    h_X = h[X_indices]
    h_Z = h[Z_indices]

    return J_ZZ, J_ZX, J_XZ, J_XX, J_ZZ_inv, h_X, h_Z

def get_conditional_distribution(J, h, X_values, X_indices):
    scope = [i for i in range(1, len(J) + 1)]
    J_ZZ, J_ZX, J_XZ, J_XX, J_ZZ_inv, h_X, h_Z = get_sub_matrices(scope, X_indices, J, h)

    J_reduced = J_ZZ
    h_reduced = h_Z- J_ZX @ X_values

    return J_reduced, h_reduced

X_values = simulate_data_for_learning(G, 1, alpha=alpha, beta=beta, sigma_sq=sigma_sq, alpha_0=50000, sigma_0_sq=5000, root=407, learn_params=False, only_X=True)[0]

In [571]:
leaves = [node for node in G.nodes if G.out_degree(node) == 0]
X_indices = leaves

J_reduced, h_reduced = get_conditional_distribution(J, h, X_values[0], X_indices)


Sigma = np.linalg.inv(J_reduced)
mu = Sigma @ h_reduced

random_index = np.random.choice(range(len(X_values[0])))
z = mapping_GTM(random_index)
print(f"Predicted value for node 407: {mu[z]}")
print(f"Variance for node 407: {Sigma[z,z]}")
print(f"Actual value for node 407: {X_values[0][0]}")
print(f"True variance for node 407: {sigma_sq}")

Predicted value for node 407: 49919.64092136943
Variance for node 407: 9548.543736981992
Actual value for node 407: 49710.574014690654
True variance for node 407: 2500


## Part II : inference

In [572]:
def marginalize_out(J, h, X_indices):
    scope = [i for i in range(1, len(J) + 1)]
    J_ZZ, J_ZX, J_XZ, J_XX, J_ZZ_inv, h_X, h_Z = get_sub_matrices(scope, X_indices, J, h)

    J_marg = J_ZZ - (J_ZX @ np.linalg.inv(J_XX) @ J_XZ)
    h_marg = h_Z - (J_ZX @ np.linalg.inv(J_XX) @ h_X)

    return J_marg, h_marg

In [573]:
def single_clique(G):
    leaves = [node for node in G.nodes if G.out_degree(node) == 0]
    H = G.copy()
    H.remove_nodes_from(leaves)
    H = H.to_undirected()
    return H

def compute_J_i_arrow_j(clique_tree, i, j, J, h, J_messages, h_messages):
    
    neighbors = list(clique_tree.neighbors(i))
    neighbors.remove(j)

    i_idx = mapping_GTM(i)
    j_idx = mapping_GTM(j)

    if not neighbors:
        J_messages[i_idx][j_idx] = -J[i_idx, j_idx] * J[j_idx, i_idx] / J[i_idx, i_idx]
    else:
        J_sum = sum(compute_J_i_arrow_j(clique_tree, k, i, J, h, J_messages, h_messages) for k in neighbors)
        J_messages[i_idx][j_idx] = -J[i_idx, j_idx] * J[j_idx, i_idx] / (J[i_idx, i_idx] + J_sum)

    return J_messages[i_idx][j_idx]


def compute_h_i_arrow_j(clique_tree, i, j, J, h, J_messages, h_messages):

    neighbors = list(clique_tree.neighbors(i))
    neighbors.remove(j)

    i_idx = mapping_GTM(i)
    j_idx = mapping_GTM(j)

    if not neighbors:
        h_messages[i_idx][j_idx] = -J[i_idx, j_idx] * h[i_idx] / J[i_idx, i_idx]
    else:
        J_sum = sum(compute_J_i_arrow_j(clique_tree, k, i, J, h, J_messages, h_messages) for k in neighbors)
        h_sum = sum(compute_h_i_arrow_j(clique_tree, k, i, J, h, J_messages, h_messages) for k in neighbors)
        Ji_backslash_j = J[i_idx, i_idx] + J_sum
        hi_backslash_j = h[i_idx] + h_sum
        h_messages[i_idx][j_idx] = (-J[i_idx, j_idx] * hi_backslash_j) / (Ji_backslash_j)

    return h_messages[i_idx][j_idx]

def inference_algorithm(G, alpha, beta, sigma_sq, alpha_0, X_indices, observed_X, Z):

    clique_tree = single_clique(G)
    J, h = compute_J_and_h(G, alpha, beta, sigma_sq, 5000, alpha_0)

    J_reduced, h_reduced = get_conditional_distribution(J, h, observed_X, X_indices)
        
    J_messages = np.full(J_reduced.shape, np.nan)
    h_messages = np.full(J_reduced.shape, np.nan)

    Z_neighbors = list(clique_tree.neighbors(Z))
    Z_idx = mapping_GTM(Z)

    J_zz = J_reduced[Z_idx, Z_idx]

    sum_J = sum([compute_J_i_arrow_j(clique_tree, k, Z, J_reduced, h_reduced, J_messages, h_messages) for k in Z_neighbors])
    sum_h = sum([compute_h_i_arrow_j(clique_tree, k, Z, J_reduced, h_reduced, J_messages, h_messages) for k in Z_neighbors])

    J_hat_Z = J_zz + sum_J
    h_hat_Z = h_reduced[Z_idx] + sum_h


    return J_hat_Z, h_hat_Z

def information_to_standard(J_hat_Z, h_hat_Z, Z):

    mu = h_hat_Z / J_hat_Z
    sigma = np.sqrt(1 / J_hat_Z)
    return mu, sigma

n = 10
random_index = np.random.choice(range(n))
Z = 407  # root node

alpha = 0
beta = 1
sigma_sq = 2500
alpha_0 = 50000

X_values, _ = simulate_data_for_learning(G, n, alpha = alpha, beta = beta, sigma_sq = sigma_sq, alpha_0 = 50000, sigma_0_sq = 5000, root = 407, learn_params=False, only_X=True)

leaves = [node for node in G.nodes if G.out_degree(node) == 0]
X_indices = leaves

J_hat_Z, h_hat_Z = inference_algorithm(G, alpha, beta, sigma_sq, alpha_0, X_indices, X_values[random_index], Z)
mu, sigma = information_to_standard(J_hat_Z, h_hat_Z, Z)

true_mu = X_values[random_index][0]
true_sigma = sigma_sq
print("True mean:", true_mu)
print("True std dev:", true_sigma)
print("Posterior mean:", mu)
print("Posterior std dev:", sigma)

True mean: 49885.43678105752
True std dev: 2500
Posterior mean: 50048.66881070986
Posterior std dev: 60.085211992236054


## Part III: Learning

In [574]:
def simulate_full_data(G, n_samples, alpha, beta, sigma_sq, alpha_0, sigma_0_sq):
    rows = []

    for _ in range(n_samples):
        simulated = {}
        simulated[root] = np.random.normal(alpha_0, np.sqrt(sigma_0_sq))
        simulate_node_length_with_parameters(G, root, simulated, alpha, beta, sigma_sq)

        for parent, child in G.edges:
            t = G[parent][child]['time']
            y = simulated[child]
            z = simulated[parent]
            rows.append({
                'Y': y,
                'Z': z,
                't': t
            })

    return pd.DataFrame(rows)

def estimate_alpha_beta_sigma2(data):
    X = data[['t', 'Z']]
    y = data['Y']

    weights = 1 / data['t']
    model = LinearRegression().fit(X, y, sample_weight=weights)
    alpha_hat = model.coef_[0]
    beta_hat = model.coef_[1]

    # Residual variance estimate of σ²

    residuals = y - model.predict(X)
    sigma_sq_hat = np.sum(weights * residuals**2) / len(y)

    return alpha_hat, beta_hat, sigma_sq_hat

root = 407
n = 1000
df = simulate_full_data(G, n, alpha=0.5, beta=1, sigma_sq=2500, alpha_0=50000, sigma_0_sq=5000)
alpha_hat, beta_hat, sigma_sq_hat = estimate_alpha_beta_sigma2(df)

print("Estimated alpha:", alpha_hat)
print("Estimated beta:", beta_hat)
print("Estimated sigma_sq:", sigma_sq_hat)

Estimated alpha: 0.5228403735042421
Estimated beta: 0.999991100573507
Estimated sigma_sq: 2497.3450648900794


In [575]:
z0_id = 407

def make_train_data(di):
    X_i = np.array([[di['t'], di[z0_id]]])
    y_i = np.array([di['Y']])
    return X_i, y_i

def learn_parameters(xs_lst, zs_list):
    '''
    xs_lst is a list of dictionaries for x
    zs_list is a list of dictionaries for z
    '''

    y_list = []
    x_list = []

    for i in range(len(xs_lst)):
        di = {**xs_lst[i], **zs_list[i]}
        X_i, y_i = make_train_data(di)
        y_list.append(y_i)
        x_list.append(X_i)
      
    y = np.concatenate(y_list)
    X = np.concatenate(x_list)

    model = LinearRegression().fit(X, y)
    mean_ys = model.predict(X)

    alpha = model.coef_[0]
    beta = model.coef_[1]


    var = np.mean((y - mean_ys)**2 / X[:, 0])
    alpha_0 = np.mean([d[z0_id] for d in zs_list])

    return alpha, beta, var, alpha_0


In [576]:
def hard_assignment_EM(n, alpha_init, beta_init, sigma_sq_init, alpha_0_init, root=407):
    alpha = alpha_init
    beta = beta_init
    sigma_sq = sigma_sq_init
    alpha_0 = alpha_0_init

    for i in range(n):
        G = create_graph(tree, alpha, beta, sigma_sq)
        df = simulate_full_data(G, 100, alpha, beta, sigma_sq, alpha_0, 5000)
        single_dataset = df.iloc[:, 0]
        leaves = [node for node in G.nodes if G.out_degree(node) == 0]
        X_values = single_dataset[leaves]
        X_indices = leaves

        J_hat_z, h_hat_z = inference_algorithm(G, alpha, beta, sigma_sq, alpha_0, X_indices, X_values, root)
        alpha_0_hat, _ = information_to_standard(J_hat_z, h_hat_z, root)
        alpha_hat, beta_hat, sigma_sq_hat = estimate_alpha_beta_sigma2(df)

        alpha = alpha_hat
        beta = beta_hat
        sigma_sq = sigma_sq_hat
        alpha_0 = alpha_0_hat
        if i % 4 == 0:
            print("Iteration:", i)
            print("Estimated alpha:", alpha)
            print("Estimated beta:", beta)
            print("Estimated sigma_sq:", sigma_sq)
            print("Estimated alpha_0:", alpha_0)  
        
    return alpha, beta, sigma_sq, alpha_0

hard_assignment_EM(1, 0.5, 1, 2500, 50000)

Iteration: 0
Estimated alpha: 0.5195387926104431
Estimated beta: 1.0000241696546568
Estimated sigma_sq: 2491.262709803439
Estimated alpha_0: 49996.043177036605


(0.5195387926104431, 1.0000241696546568, 2491.262709803439, 49996.043177036605)

### Part III: Apply inference and learning algorithms to real data

In [577]:
def split_gene_data(tree, vert_genes):
    data = pd.merge(tree, vert_genes, on='species', how='inner')
    rows = []
    for orth_id, group in data.groupby('orthId'):
        child_glength = dict(zip(group['Child'], group['glength']))
        rows.append({'orthId': orth_id, 'glength_dict': child_glength})

    return pd.DataFrame(rows)

real_data = split_gene_data(tree, vert_genes)

real_data


Unnamed: 0,orthId,glength_dict
0,1CPN2,"{1: 88338, 2: 219992, 3: 89972, 4: 233175, 5: ..."
1,1CQBX,"{1: 30935, 2: 31081, 3: 40971, 4: 32508, 5: 28..."
2,1CQJ6,"{1: 21555, 2: 20662, 3: 12707, 4: 21723, 5: 22..."
3,1CR8Z,"{1: 12401, 2: 16419, 3: 11305, 4: 16328, 5: 59..."
4,1CTEU,"{1: 3125, 2: 2814, 3: 2122, 4: 2765, 5: 3457, ..."
5,1CTI9,"{1: 44314, 2: 42679, 3: 44077, 4: 36837, 5: 11..."
6,1CYBB,"{1: 104633, 2: 104930, 3: 105023, 4: 106842, 5..."
7,1D0EM,"{1: 19977, 2: 21211, 3: 24187, 4: 22395, 5: 21..."
8,1D1CF,"{1: 16494, 2: 15176, 3: 16651, 4: 15010, 5: 16..."
9,1D3F1,"{1: 29924, 2: 8341, 3: 6260, 4: 9613, 5: 6236,..."


In [578]:

def run_inference_real_data(real_data, alpha, beta, sigma_sq, alpha_0, root = 407):

    results = pd.DataFrame(columns=['orthId', 'mu', 'sigma', 'true_mean', 'NoV'])
    for _, row in real_data.iterrows():
        orth_id = row['orthId']
        data = row['glength_dict']
        X_indices = list(data.keys())
        X_values = list(data.values())
        mean = np.mean(X_values)

        print("Number of observed values:", len(X_values))

        G = create_graph(tree, alpha, beta, sigma_sq)
        J_hat_z, h_hat_z = inference_algorithm(G, alpha, beta, sigma_sq, alpha_0, X_indices, X_values, root)
        mu, sigma = information_to_standard(J_hat_z, h_hat_z, root)
        
        results.loc[len(results)] = {'orthId': orth_id, 'mu': mu, 'sigma': sigma, 'true_mean': mean, 'NoV': len(X_indices)}

    return results

results = run_inference_real_data(real_data, 0, 1, 100, 50000)

print(results)
tree

Number of observed values: 203


Number of observed values: 204
Number of observed values: 202
Number of observed values: 202
Number of observed values: 204
Number of observed values: 204
Number of observed values: 204
Number of observed values: 188
Number of observed values: 196
Number of observed values: 201
  orthId            mu      sigma      true_mean  NoV
0  1CPN2   2350.091567  11.551491  135505.615764  203
1  1CQBX  36231.870337  21.694264   34798.240196  204
2  1CQJ6    704.051626   4.787311   20419.287129  202
3  1CR8Z    653.214565   4.787311   13790.514851  202
4  1CTEU   7474.681154  21.694264    2998.083333  204
5  1CTI9  25903.702363  21.694264   28328.098039  204
6  1CYBB  47521.445748  21.694264   75589.941176  204
7  1D0EM    229.813534   5.588823   15203.244681  188
8  1D1CF      0.000000   8.735438   20381.913265  196
9  1D3F1     42.700600   4.638834    5174.621891  201


Unnamed: 0,Parent,Child,age_ch,t,species
0,222.0,1,0.000000,9.000250,Peromyscus_maniculatus
1,222.0,2,0.000000,9.000250,Mus_musculus
2,221.0,3,0.000000,12.172706,Cricetulus_griseus
3,220.0,4,0.000000,14.684269,Rattus_norvegicus
4,219.0,5,0.000000,17.062881,Mesocricetus_auratus
...,...,...,...,...,...
402,399.0,403,7.861905,1.640772,
403,403.0,404,5.750020,2.111885,
404,404.0,405,5.500000,0.250020,
405,403.0,406,5.000000,2.861905,


In [579]:
# Iterate over each row in the real_data DataFrame
for _, row in real_data.iterrows():
    orth_id = row['orthId']
    glength_dict = row['glength_dict']

    # Extract observed values (X_values) and their indices (X_indices)
    X_indices = list(glength_dict.keys())
    X_values = np.array(list(glength_dict.values()))

    # Compute the reduced precision matrix and potential vector
    J_reduced, h_reduced = get_conditional_distribution(J, h, X_values, X_indices)

    print("size of J_reduced: ", J_reduced.shape)
    # Compute the posterior covariance matrix (Sigma) and mean (mu)
    Sigma = np.linalg.inv(J_reduced)
    mu = Sigma @ h_reduced

    # Map the root node index to the reduced matrix
    root_index = mapping_GTM(root)

    # Extract the posterior mean and variance for the root node
    posterior_mean = mu[root_index]
    posterior_variance = Sigma[root_index, root_index]

    print(f"Ortholog ID: {orth_id}")
    print(f"Posterior mean for root node: {posterior_mean}")
    print(f"Posterior variance for root node: {posterior_variance}")
    print("-" * 50)

size of J_reduced:  (204, 204)
Ortholog ID: 1CPN2
Posterior mean for root node: 11585.139492304294
Posterior variance for root node: 4192.0271292983725
--------------------------------------------------
size of J_reduced:  (203, 203)
Ortholog ID: 1CQBX
Posterior mean for root node: 45775.45150739554
Posterior variance for root node: 3610.232700151948
--------------------------------------------------
size of J_reduced:  (205, 205)
Ortholog ID: 1CQJ6
Posterior mean for root node: 8758.692219838955
Posterior variance for root node: 3506.7370253340227
--------------------------------------------------
size of J_reduced:  (205, 205)
Ortholog ID: 1CR8Z
Posterior mean for root node: 7490.846996389608
Posterior variance for root node: 3486.1782806976985
--------------------------------------------------
size of J_reduced:  (203, 203)
Ortholog ID: 1CTEU
Posterior mean for root node: 36951.730117071085
Posterior variance for root node: 3610.232700151948
-----------------------------------------

In [580]:
def hard_assignment_EM_real_data(real_data, n, alpha_init, beta_init, alpha0_init, sigma_sq_init, root = 407):
    
    for i in range(n):
        alpha = np.array([alpha_init] * real_data.shape[0])
        beta = np.array([beta_init] * real_data.shape[0])
        sigma_sq = np.array([sigma_sq_init] * real_data.shape[0])
        alpha_0 = np.array([alpha0_init] * real_data.shape[0])

        for i, row in real_data.iterrows():
            orth_id = row['orthId']
            glength_dict = row['glength_dict']

            X_indices = list(glength_dict.keys())
            X_values = np.array(list(glength_dict.values()))

            J_hat_z, h_hat_z = inference_algorithm(G, alpha[i], beta[i], sigma_sq[i], alpha_0[i], X_indices, X_values, root)
            alpha_0_hat, _ = information_to_standard(J_hat_z, h_hat_z, root)

            df = simulate_full_data(G, 100, alpha[i], beta[i], sigma_sq[i], alpha_0[i], 5000)
            alpha_hat, beta_hat, sigma_sq_hat = estimate_alpha_beta_sigma2(df)

            alpha[i] = alpha_hat
            beta[i] = beta_hat
            sigma_sq[i] = sigma_sq_hat
            alpha_0[i] = alpha_0_hat

    return alpha, beta, sigma_sq, alpha_0

alpha, beta, sigma_sq, alpha_0 = hard_assignment_EM_real_data(real_data, 1, 0.5, 1, 50000, 5000)

for i in range(real_data.shape[0]):
    print("Ortholog ID:", real_data.iloc[i]['orthId'])
    print("Estimated alpha:", alpha[i])
    print("Estimated beta:", beta[i])
    print("Estimated sigma_sq:", sigma_sq[i])
    print("Estimated alpha_0:", alpha_0[i])
    print("-" * 50)


Ortholog ID: 1CPN2
Estimated alpha: 0.5187120233116777
Estimated beta: 0
Estimated sigma_sq: 4998
Estimated alpha_0: 2350
--------------------------------------------------
Ortholog ID: 1CQBX
Estimated alpha: 0.3143416072305276
Estimated beta: 0
Estimated sigma_sq: 5044
Estimated alpha_0: 45775
--------------------------------------------------
Ortholog ID: 1CQJ6
Estimated alpha: 0.4326197880284413
Estimated beta: 1
Estimated sigma_sq: 5012
Estimated alpha_0: 704
--------------------------------------------------
Ortholog ID: 1CR8Z
Estimated alpha: 0.5359186617403578
Estimated beta: 1
Estimated sigma_sq: 4958
Estimated alpha_0: 653
--------------------------------------------------
Ortholog ID: 1CTEU
Estimated alpha: 0.319424537765399
Estimated beta: 1
Estimated sigma_sq: 4984
Estimated alpha_0: 36951
--------------------------------------------------
Ortholog ID: 1CTI9
Estimated alpha: 0.6781610629918227
Estimated beta: 1
Estimated sigma_sq: 5054
Estimated alpha_0: 42606
-------------