# PhyloCNN - CI_primates

This notebook was modified [from (Lambert et al. 2023)](https://github.com/JakubVoz/deeptimelearning/blob/main/estimation/NN/empirical/BISSE_cnn_CDV_mae_CI_computation_Gomez2012.ipynb).

In [4]:
## Import necessary libraries
import pandas as pd
import numpy as np
import tensorflow as tf
import keras
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from keras.models import model_from_json

In [5]:
#########loading data#########
encoding_test = pd.read_csv('./Encoded_primates.csv', sep="\t", header=0, index_col=0).values.reshape(-1,1000,19)

In [None]:
# Additional sampling probability and tree size for rescaling
test_rescale = 4.064753  # Rescaling factor
test_tree_size = 260  # Example tree size

# Add a new column for sampling probability (0.25 for all nodes)
samp_proba_list = np.array(0.68241469816273)
encoding_test=np.concatenate((encoding_test,np.repeat(samp_proba_list,1000).reshape(-1,1000,1)),axis=2)

In [9]:
# This function takes in the tree encodings for the empirical dataset
# and processes it to have a uniform shape. It also pads the leaves and nodes 
# of the tree to ensure a fixed number of 500 leaves and nodes.

def encode_pad_0s_rootage(enc):
    # Create an empty list to hold padded training encodings
    enc_pad = []
    
    # Iterate over each tree in the training dataset
    for i in range(enc.shape[0]):
        # Separate the leaves (where column 3 has value 1, which indicates leaves)
        leaves = enc[i][enc[i,:,3] == 1]
        # Sort leaves by their age (assumed to be in column 1)
        leaves = leaves[np.argsort(leaves[:, 1])]
        # Pad the leaves array with 0s until it has a maximum size of 500 leaves
        leaves = np.pad(leaves, [(0, (500 - leaves.shape[0])), (0, 0)], mode='constant')

        # Separate the nodes (where column 3 is greater than 1, indicating internal nodes)
        nodes = enc[i][enc[i,:,3] > 1]
        # Sort nodes by their age (assumed to be in column 1)
        nodes = nodes[np.argsort(nodes[:, 1])]
        # Copy the last node's value to balance the number of leaves and nodes
        nodes = np.append(nodes, nodes[-1].reshape(1, -1), axis=0)
        # Pad the nodes array with 0s to ensure a size of 500 nodes
        nodes = np.pad(nodes, [(0, (500 - nodes.shape[0])), (0, 0)], mode='constant')
        
        # Stack the leaves and nodes arrays together along axis 2 (creating 2 channels)
        enc_pad.append(np.stack((leaves, nodes), axis=2))
    
    # Convert lists to numpy arrays and return the padded training and test data
    return np.array(enc_pad)

#Change encoding to order by root age and pad with 0s
encoding_pad_test = encode_pad_0s_rootage(encoding_test)

In [None]:
encoding_pad_test.shape

(1, 500, 20, 2)

In [10]:
#load the model
from keras.models import model_from_json
json_file = open('./Trained_Models/Trained_2Generation_BiSSE.json', 'r')
model = json_file.read()
json_file.close()
estimator = model_from_json(model)
#load weights
estimator.load_weights('./Trained_Models/Trained_2Generation_BiSSE.h5')
print('model loaded!')

# predict values for the empirical data
predicted_test = pd.DataFrame(estimator.predict(encoding_pad_test))
predicted_test.columns = ["turnover", "lambda1_rescaled", "lambda2_rescaled", "q01_rescaled"]
predicted_test['mu1_rescaled'] = predicted_test['turnover']*predicted_test['lambda1_rescaled']
predicted_test['mu2_rescaled'] = predicted_test['turnover']*predicted_test['lambda2_rescaled']


predicted_test['lambda1_rescaled'] = predicted_test['lambda1_rescaled']/test_rescale
predicted_test['lambda2_rescaled'] = predicted_test['lambda2_rescaled']/test_rescale
predicted_test['mu1_rescaled'] = predicted_test['mu1_rescaled']/test_rescale
predicted_test['mu2_rescaled'] = predicted_test['mu2_rescaled']/test_rescale
predicted_test['q01_rescaled'] = predicted_test['q01_rescaled']/test_rescale

model loaded!


In [12]:
### load data sets for CI computations
# CI_param: known parameter values (i.e. used for obtaining simulations in the training set)
CI_param = pd.read_csv('BiSSE.csv', sep=",")
# CI_predicted: predicted parameter values obtained with the training set
CI_predicted = pd.read_csv('Predicted_BiSSE.csv')

In [13]:
# rescaling all values so that they correspond to trees of average branch length of 1
CI_param['lambda1_rescaled'] = CI_param['lambda1']
CI_param['lambda2_rescaled'] = CI_param['lambda2']
CI_param['q01_rescaled'] = CI_param['q01']

In [14]:
### prepare col names of output table
# parameters for which we compute the CI
targets = ["turnover", "lambda1_rescaled", "lambda2_rescaled", "q01_rescaled"]
# number of neighboring simulation sets we consider to compute CI
n_neighbors = [1000]
# min max values for the computed CI values: set to biologically relevant boundaries (i.e. non negative values)
min_max = {targets[0]: [0, 1000], targets[1]: [0, 1000], targets[2]: [0, 1000], targets[3]: [0, 1000]}
# prepare col names of output table: value of lower boundrary, upper boundary and the width of CI
add_ons_names = ['_CI_2_5', '_CI_97_5', '_CI_width']
col = [add_on + '_' + str(n_neigh) for n_neigh in n_neighbors for add_on in add_ons_names]
col_comp = []
col_comp = [target + co for target in targets for co in col]

In [15]:
def get_indexes_of_closest_single_factor(test_value, ci_values, n):
    """Returns indexes of knn for given set

    :param test_value: float, value of parameter (e.g. sampling proba or tree size) on which we select given observation
    :param ci_values: dataframe, values of these parameters in CI set
    :param n: int, number of KNNs to find
    :return: list, indexes of n KNNs
    """
    ref = ci_values.iloc[(ci_values-test_value).abs().argsort()].index
    return [ref[i] for i in range(n)]


def get_indexes_of_closest(test_s, ci_s, n):
    """Returns indexes of knn for given set
    :param test_s: dataframe, param set given observation
    :param ci_s: dataframe, param sets of CI set
    :param n: int, number of KNNs to find
    :return: list, indexes of n KNNs
    """
    ref = ci_s.iloc[(ci_s - test_s.values).pow(2).sum(axis=1).pow(0.5).argsort()].index
    return [ref[i] for i in range(n)]


def get_predicted_closest_single(indexes, pred_value_table, targ):
    """ returns the absolute errors for knn
    :param indexes: list, index of knn
    :param pred_value_table: dataframe, predicted parameter values of CI set
    :param targ: str, parameter name
    :return: list of predictions for each knn
    """
    # subset the real and predicted values of the closest neighbors
    closest_pred = pred_value_table.loc[indexes, :]

    # for single parameter, get the absolute difference between these
    pred_d = list(closest_pred[targ][:])
    return pred_d


def get_error_closest_single(indexes, real_value_table, pred_value_table, targ):
    """ returns the absolute errors for knn
    :param indexes: list, index of knn
    :param real_value_table: dataframe, real/target parameter values of CI set
    :param pred_value_table: dataframe, predicted parameter values of CI set
    :param targ: str, parameter name
    :return: list of absolute error in predictions for each knn
    """
    # subset the real and predicted values of the closest neighbors
    closest_pred = pred_value_table.loc[indexes, :]
    closest_real = real_value_table.loc[indexes, :]

    # for single parameter, get the absolute difference between these
    error_d = closest_pred[targ] - closest_real[targ]
    return error_d


def apply_filter(df1, df2, df3, df4, indexes):
    return df1.loc[indexes], df2.loc[indexes], df3.loc[indexes], df4.loc[indexes]


def load_files(arg_name, sep=""):
    """Loads given file

    :param arg_name: parser arg, pointer to the file
    :param sep: str, eventual separator
    :return: pd.Dataframe, loaded file
    """
    with open(arg_name, 'r') as des0:
        des_data0 = des0.read()
    des0.close()

    if sep == "":
        output = pd.read_csv(io.StringIO(des_data0), index_col=0, header=None)
    else:
        output = pd.read_csv(io.StringIO(des_data0), index_col=0, header=None, sep=sep)

    return output

In [16]:
### pre processing of datasets used for CI computation: extracting parameters of interest, standardizing them
# extract helper parameters of the CI set
# subset sampling probability:
CI_sampling = CI_param["sampling_frac"]
# tree size:
CI_tree_size = CI_param["tree_size"]

# subselect columns/parameters of interest for each table + all in the same order
CI_param = CI_param[targets]
predicted_test = predicted_test[targets]
CI_predicted = CI_predicted[targets]

# before computation, standardize all columns so that each parameter is on the same scale:
scaler = StandardScaler()
CI_param_standardized = pd.DataFrame(scaler.fit_transform(CI_param)) # fit to CI set
predicted_test_standardized = pd.DataFrame(scaler.transform(predicted_test))

# restore column names and index values
CI_param_standardized.columns = CI_param.columns
CI_param_standardized.index = CI_param.index
predicted_test_standardized.columns = predicted_test.columns
predicted_test_standardized.index = predicted_test.index

In [20]:
# initialize the output table
CI_df = pd.DataFrame(index=range(0, predicted_test.shape[0]), columns=col_comp)

# predicted parameter values from empirical set: here there is only one empirical set for which we want to compute CI values
current_obs = predicted_test.iloc[0, :]
current_obs_standardized = predicted_test_standardized.iloc[0, :]

## find the 2% of closest simulations with respect to tree size and sampling frequency
# first filter: keep only the closest 20k CI sets with respect to tree size
tree_size_indexes = get_indexes_of_closest_single_factor(test_tree_size, CI_tree_size, 20000)
filt_1_CI_predicted, filt_1_param_CI_standardized, filt_1_CI_param, filt_1_CI_sampling_proba = \
    apply_filter(CI_predicted, CI_param_standardized, CI_param, CI_sampling, tree_size_indexes)
# reset indexes
filt_1_CI_param.index = filt_1_param_CI_standardized.index = filt_1_CI_predicted.index = \
    filt_1_CI_sampling_proba.index = range(0, 20000)

# second filter: keep only the closest 4k CI sets with respect to sampling frequency
sampling_proba_indexes = get_indexes_of_closest_single_factor(samp_proba_list_test, filt_1_CI_sampling_proba, 4000)
filt_2_CI_predicted, filt_2_param_CI_standardized, filt_2_CI_param, filt_2_CI_sampling_proba = \
    apply_filter(filt_1_CI_predicted, filt_1_param_CI_standardized, filt_1_CI_param,
                 filt_1_CI_sampling_proba, sampling_proba_indexes)

# reset indexes
filt_2_CI_predicted.index = filt_2_param_CI_standardized.index = filt_2_CI_param.index = range(0, 4000)

# vector to stock all measures of the current observation
all_real = []

for elt in targets:

    # find indexes of closest parameter sets within the predicted values of 40K simulation of CI set
    top_ind = get_indexes_of_closest_single_factor(current_obs_standardized[elt], filt_2_param_CI_standardized[
        elt], n_neighbors[-1])

    # measure errors on closest parameters sets (predicted - actual values)
    pred_closest = get_predicted_closest_single(top_ind, filt_2_CI_predicted, elt)
    error_closest = get_error_closest_single(top_ind, filt_2_CI_param, filt_2_CI_predicted, elt)

    for j in range(len(n_neighbors)):
        # refactor the measured error into a dict 'name_of_param': list of errors (top n neighbours)
        pred_closest_n_neigh = pred_closest[0:n_neighbors[j]]
        error_closest_n_neigh = error_closest[0:n_neighbors[j]]
        median_pred = np.median(pred_closest_n_neigh)
        median_error = np.median(error_closest_n_neigh)
        # center the values around the given prediction
        centered = [item - median_error + current_obs[elt] for item in error_closest_n_neigh]

        # rescale back to original time scale of empirical observation for time-related parameters:
        if 'resc' in elt:
            centered_resc = [float(item) for item in centered]
            current_obs[elt] = current_obs[elt]
        else:
            centered_resc = centered

        # apply minimum and maximum values for each parameter (e.g. no negative values)
        print(elt,centered_resc)
        current_obs[elt] = max(min_max[elt][0], current_obs[elt])
        current_obs[elt] = min(min_max[elt][1], current_obs[elt])
        #centered_resc = [max(min_max[elt][0], item) for item in centered_resc]
        #centered_resc = [min(min_max[elt][1], item) for item in centered_resc]
        # compute statistics: 2.5%, 97.5% boundaries
        qtls = np.percentile(centered_resc, np.array(np.array([2.5, 97.5])))
        min_2_5 = qtls[0]
        max_97_5 = qtls[1]
        width_CI = qtls[1] - qtls[0]

        all_real.append(min_2_5)
        all_real.append(max_97_5)
        all_real.append(width_CI)

CI_df.loc[0, :] = all_real.copy()

turnover [0.6284678970950318, 0.5060314170950317, 0.6557887570950317, 0.6819995970950317, 0.6408510570950318, 0.6223135670950317, 0.5810363370950318, 0.6499264370950317, 0.7240008670950318, 0.6401077970950317, 0.3734958770950318, 0.6313821970950317, 0.6190509970950318, 0.7456169370950317, 0.5746316970950317, 0.48753245709503173, 0.6684605370950317, 0.6876874970950317, 0.6007569670950318, 0.6346997970950318, 0.5721766470950317, 0.5711128970950318, 0.5261457970950317, 0.3949061370950318, 0.6259188370950318, 0.5727950370950318, 0.7091635970950317, 0.6439630370950318, 0.5130554770950317, 0.5988831970950318, 0.6945594970950318, 0.47889359709503176, 0.6828306970950317, 0.5886027970950318, 0.7544511270950318, 0.5840556970950317, 0.5786455970950318, 0.34343724709503176, 0.33096909709503175, 0.5544294570950318, 0.5672141970950317, 0.40184696709503176, 0.6181788470950317, 0.6999242670950317, 0.6155868970950318, 0.5434162970950317, 0.6157877370950318, 0.5608619270950318, 0.6859050970950318, 0.508

In [21]:
CI_df = pd.concat([current_obs.to_frame().T, CI_df], axis=1)

In [22]:
#1000 trees
CI_df

Unnamed: 0,turnover,lambda1_rescaled,lambda2_rescaled,q01_rescaled,turnover_CI_2_5_1000,turnover_CI_97_5_1000,turnover_CI_width_1000,lambda1_rescaled_CI_2_5_1000,lambda1_rescaled_CI_97_5_1000,lambda1_rescaled_CI_width_1000,lambda2_rescaled_CI_2_5_1000,lambda2_rescaled_CI_97_5_1000,lambda2_rescaled_CI_width_1000,q01_rescaled_CI_2_5_1000,q01_rescaled_CI_97_5_1000,q01_rescaled_CI_width_1000
0,0.610441,0.378217,0.166157,0.010635,0.391435,0.733746,0.342311,0.310225,0.440335,0.130109,0.10892,0.275925,0.167005,0.005794,0.020707,0.014912
