In [1]:
# Loading modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# Loading data
# Edge list: one (Parent, Child) pair per row.
tree_table = pd.read_csv('input_data/table.dat', sep = ',', header=None, names=['Parent', 'Child'])
# Branch lengths are flattened so row- and column-oriented files both end up as one 'Length' column.
branch_lengths = pd.read_csv('input_data/branchlength.dat', sep = ',', header=None)
branch_lengths = pd.DataFrame(branch_lengths.values.flatten(), columns=['Length'])
msa = pd.read_csv('input_data/msa.dat', sep = ' ', header=None)

# Edge i is paired with branch length i purely by position.
merged_data = pd.concat([tree_table, branch_lengths], axis=1)
# Object-typed placeholder column: rows without an alignment entry keep NaN.
merged_data['Sequence'] = pd.Series(np.nan, index=merged_data.index, dtype=object)

# Copy each alignment row onto the edge whose 'Child' matches the row's first field.
for _, msa_row in msa.iterrows():
    leaf_id = msa_row[0]
    joined_sequence = ' '.join(str(field) for field in msa_row[1:])
    is_leaf_edge = merged_data['Child'] == leaf_id
    merged_data.loc[is_leaf_edge, 'Sequence'] = joined_sequence

merged_data

Unnamed: 0,Parent,Child,Length,Sequence
0,9,1,0.1,AGATCAAGATCAAGATCAAGATCAAGATCA
1,9,2,0.4,AGCTCAAGCTCAAGCTCAAGCTCAAGCTCA
2,8,9,0.01,
3,8,3,0.04,CGCTATCGCTATCGCTATCGCTATCGCTAT
4,7,4,0.2,CGTTACCGTTACCGTTACCGTTACCGTTAC
5,7,5,0.08,CGCTACCGCTACCGCTACCGCTACCGCTAC
6,6,7,0.12,
7,6,8,0.14,


In [49]:
# Creating the node class
class Node:
    """Tree node storing an identity, a parent link, a branch length, a sequence and its children."""

    def __init__(self, identity, parent=None, branch_length=None, sequence=None):
        """Store the given fields and start with an empty list of children."""
        self.identity, self.parent = identity, parent
        self.branch_length, self.sequence = branch_length, sequence
        # A fresh list per instance, so children are never shared between nodes.
        self.children = list()


In [86]:

import pandas as pd
import numpy as np
from scipy.linalg import expm
import os
import tempfile

# 4x4 matrix with -0.9 on the diagonal and 0.3 everywhere else, so every row
# sums to zero. It is handed to Tree as `transition_matrix`, where it is scaled
# by a branch length and passed through scipy's `expm`.
# NOTE(review): row/column order presumably follows the A, C, G, T order used
# by Tree.one_hot_encode -- confirm.
Q = np.array([[-0.9, 0.3, 0.3, 0.3],
              [0.3, -0.9, 0.3, 0.3],
              [0.3, 0.3, -0.9, 0.3],
              [0.3, 0.3, 0.3, -0.9]])


class Node:
    """Tree node identified by ``name``; all links and payload start out unset."""

    def __init__(self, name):
        """Create a detached node: no parent, no children, no branch length, no sequence."""
        self.name = name
        self.parent, self.children = None, []
        self.branch_length = self.sequence = None
        # Flag meant to mark leaves; it is False until something sets it.
        self.is_final = False


class Tree:
    """Phylogenetic tree whose likelihood is evaluated by bottom-up pruning.

    The tree is assembled from three text files (edge table, alignment and
    branch lengths). Leaves carry a nucleotide string; after construction every
    resolvable internal node carries an ``(n_sites, 4)`` array of partial
    likelihoods whose columns follow the A, C, G, T order of
    :meth:`one_hot_encode`.

    Attributes:
        data: DataFrame with columns ``Parent``, ``Child``, ``Length``, ``Sequence``.
        root: First node found without a parent (``None`` for an empty table).
        nodes: ``{name: node}`` dictionary filled by :meth:`populate_tree`.
        transition_matrix: 4x4 matrix that is scaled by each branch length and
            exponentiated with ``scipy.linalg.expm``.
    """

    def __init__(self, table_path, msa_path, branch_lengths_path, transition_matrix):
        """Load the input files, build the node graph and compute all partial likelihoods.

        Args:
            table_path: Comma-separated ``parent,child`` edge list (no header).
            msa_path: Space-separated ``species sequence`` alignment (no header).
            branch_lengths_path: Comma-separated branch lengths, one per edge.
            transition_matrix: 4x4 numpy array used for every branch.
        """
        # The normalised copies are only needed until `get_data` has read them,
        # so they live in a directory that is removed as soon as we are done.
        with tempfile.TemporaryDirectory() as work_dir:
            table_path, msa_path, branch_lengths_path = self.preprocess_files_if_needed(
                table_path, msa_path, branch_lengths_path, output_dir=work_dir
            )
            self.data = self.get_data(table_path, msa_path, branch_lengths_path)

        self.root = None
        self.transition_matrix = transition_matrix
        self.populate_tree(self.data)
        self.tree_probability(self.root)

    @staticmethod
    def preprocess_files_if_needed(table_path, msa_path, branch_lengths_path, output_dir=None):
        """Normalise the three input files and write them as comma-separated copies.

        Species names in the alignment are replaced by their 1-based row number,
        and the same replacement is applied to matching entries of the edge
        table. Branch lengths may be given as a single row or a single column;
        they are always written one value per line.

        Args:
            table_path: Path of the edge table.
            msa_path: Path of the two-column, space-separated alignment.
            branch_lengths_path: Path of the branch-length file.
            output_dir: Directory for the transformed files. When omitted a new
                directory is created with ``tempfile.mkdtemp`` and the caller
                is responsible for removing it.

        Returns:
            Tuple ``(table_path, msa_path, branch_lengths_path)`` of the
            transformed files.

        Raises:
            ValueError: If the alignment does not have exactly two columns or
                the branch lengths are neither a single row nor a single column.
        """
        tree_table = pd.read_csv(table_path, sep=',', header=None, names=['Parent', 'Child'])

        # Ensure MSA file has the correct structure
        msa = pd.read_csv(msa_path, sep=' ', header=None)
        if msa.shape[1] != 2:
            raise ValueError(f"MSA file must have two columns. Detected columns: {msa.shape[1]}")
        msa.columns = ['Species', 'Sequence']

        # Assign unique identifiers to species in the MSA (row order, starting at 1)
        species_to_id = {species: idx + 1 for idx, species in enumerate(msa['Species'])}
        msa['Species'] = msa['Species'].map(species_to_id)

        # Replace species names in the tree table; anything that is not a
        # species name (internal nodes, ids that are already numeric) is kept.
        def map_species_to_id(value):
            return species_to_id.get(value, value)

        tree_table['Child'] = tree_table['Child'].apply(map_species_to_id)
        tree_table['Parent'] = tree_table['Parent'].apply(map_species_to_id)

        # Accept either orientation of the branch-length file, as `get_data`
        # flattens the values anyway.
        raw_lengths = pd.read_csv(branch_lengths_path, sep=',', header=None)
        if 1 not in raw_lengths.shape:
            raise ValueError(
                "Branch lengths file must contain a single row or a single column. "
                f"Detected shape: {raw_lengths.shape}"
            )
        branch_lengths = pd.DataFrame(raw_lengths.values.flatten())

        # Save transformed files
        if output_dir is None:
            output_dir = tempfile.mkdtemp()
        else:
            os.makedirs(output_dir, exist_ok=True)
        transformed_table_path = os.path.join(output_dir, 'transformed_table.csv')
        transformed_msa_path = os.path.join(output_dir, 'transformed_msa.csv')
        transformed_branch_lengths_path = os.path.join(output_dir, 'transformed_branch_lengths.csv')

        tree_table.to_csv(transformed_table_path, index=False, header=False)
        msa.to_csv(transformed_msa_path, index=False, header=False)
        branch_lengths.to_csv(transformed_branch_lengths_path, index=False, header=False)

        return transformed_table_path, transformed_msa_path, transformed_branch_lengths_path

    def get_data(self, table, msa, branch_lengths):
        """Read the (comma-separated) inputs and merge them into one edge table.

        Args:
            table: Path of the ``parent,child`` edge list.
            msa: Path of the ``species,sequence`` alignment.
            branch_lengths: Path of the branch lengths (any orientation).

        Returns:
            DataFrame with columns ``Parent``, ``Child``, ``Length`` and
            ``Sequence``. Edges and branch lengths are paired by position;
            ``Sequence`` is NaN for children without an alignment entry.
        """
        tree_table = pd.read_csv(table, sep=',', header=None, names=['Parent', 'Child'])
        branch_lengths = pd.read_csv(branch_lengths, sep=',', header=None)
        branch_lengths = pd.DataFrame(branch_lengths.values.flatten(), columns=['Length'])
        msa = pd.read_csv(msa, sep=',', header=None)
        msa.columns = ['Species', 'Sequence']

        data = pd.concat([tree_table, branch_lengths], axis=1)

        # Look every child up in the alignment in one pass; unmatched children
        # (internal nodes) get NaN. The column is kept object-typed.
        sequence_by_species = dict(zip(msa['Species'], msa['Sequence']))
        data['Sequence'] = data['Child'].map(sequence_by_species).astype(object)

        print(data)
        return data

    def is_final_node(self, node_name):
        """Return True when ``node_name`` is a leaf of ``self.data``.

        A leaf appears exactly once in the ``Child`` column and never in the
        ``Parent`` column.
        """
        times_as_child = int((self.data['Child'] == node_name).sum())
        times_as_parent = int((self.data['Parent'] == node_name).sum())
        return times_as_child == 1 and times_as_parent == 0

    def populate_tree(self, data):
        """Create one node per name in ``data`` and link parents to children.

        Sets ``self.nodes`` and, when a parentless node exists, ``self.root``
        (the first such node in insertion order).

        Args:
            data: Edge table as returned by :meth:`get_data`.
        """
        nodes = {}

        for index, row in data.iterrows():
            parent = row['Parent']
            child = row['Child']
            branch_length = row['Length']
            sequence = row['Sequence']

            if parent not in nodes:
                nodes[parent] = Node(parent)
            if child not in nodes:
                nodes[child] = Node(child)

            # The row describes the branch leading to `child`, so its length
            # and sequence are stored on the child node.
            nodes[child].parent = nodes[parent]
            nodes[child].branch_length = branch_length
            nodes[child].sequence = sequence
            nodes[child].is_final = self.is_final_node(child)
            nodes[parent].children.append(nodes[child])

        for node in nodes.values():
            if node.parent is None:
                self.root = node
                break
        self.nodes = nodes

    def one_hot_encode(self, sequence):
        """Encode a nucleotide string as an ``(n_sites, 4)`` 0/1 array (A, C, G, T).

        Args:
            sequence: String over ``ACGT``; ``None`` or ``''`` yields an empty array.

        Returns:
            numpy array with one row per character.

        Raises:
            KeyError: If the string contains a character other than A, C, G, T.
        """
        if sequence is None or (isinstance(sequence, str) and sequence == ''):
            return np.array([])
        mapping = {
            'A': [1, 0, 0, 0],
            'C': [0, 1, 0, 0],
            'G': [0, 0, 1, 0],
            'T': [0, 0, 0, 1],
        }
        one_hot_sequence = [mapping[base] for base in sequence]
        return np.array(one_hot_sequence)

    def _propagate_along_branch(self, node):
        """Return the ``(n_sites, 4)`` factor that ``node`` contributes to its parent.

        Row ``i`` equals ``expm(transition_matrix * branch_length) @ partials[i]``.
        The matrix exponential is computed once per branch and applied to all
        sites with a single matrix product.
        """
        partials = self.one_hot_encode(node.sequence) if node.is_final else node.sequence
        partials = np.asarray(partials, dtype=float).reshape(-1, 4)
        branch_matrix = expm(self.transition_matrix * node.branch_length)
        return partials @ branch_matrix.T

    @staticmethod
    def _multiply_factors(factors):
        """Multiply per-child factors element-wise, insisting on equal shapes."""
        product = None
        for factor in factors:
            if product is None:
                product = factor
            elif factor.shape != product.shape:
                # Without this check numpy could silently broadcast a one-site
                # alignment against a longer one.
                raise ValueError(
                    "Sibling nodes have alignments of different lengths: "
                    f"{product.shape[0]} vs {factor.shape[0]} sites"
                )
            else:
                product = product * factor
        return product

    def get_nucleotide_probability(self, node1, node2):
        """Combine two sibling nodes into the partial likelihoods of their parent.

        Args:
            node1: First child; its ``parent`` receives the result.
            node2: Second child.

        Returns:
            ``(n_sites, 4)`` array, also stored in ``node1.parent.sequence``.

        Raises:
            ValueError: If the two children cover a different number of sites.
        """
        prob_vector = self._multiply_factors(
            [self._propagate_along_branch(node1), self._propagate_along_branch(node2)]
        )
        node1.parent.sequence = prob_vector
        return prob_vector

    def tree_probability(self, node):
        """Recursively compute partial likelihoods for ``node`` and its subtree.

        Leaves return their one-hot encoding. An internal node with any number
        of children (one or more) receives the element-wise product of its
        children's propagated factors, which is stored in ``node.sequence``.

        Returns:
            ``(n_sites, 4)`` array, or ``None`` when the node is a non-leaf
            without children or one of its children could not be resolved.
        """
        if node.is_final:
            # Leaf nodes already have sequences assigned
            return self.one_hot_encode(node.sequence)

        if not node.children:
            return None  # No children, no sequence to propagate

        # Resolve every subtree first; this fills in `child.sequence` for all
        # internal children regardless of what placeholder they held before.
        resolved = [self.tree_probability(child) for child in node.children]
        if any(result is None for result in resolved):
            return None

        node.sequence = self._multiply_factors(
            self._propagate_along_branch(child) for child in node.children
        )
        return node.sequence

    def get_log_likelihood(self):
        """Return the log-likelihood of the alignment under uniform base frequencies.

        Each site's root partial likelihoods are averaged with weights 0.25 and
        the logarithms are summed over sites.
        """
        eq_freq = np.array([0.25, 0.25, 0.25, 0.25])
        log_likelihood = np.sum(np.log(np.matmul(self.root.sequence, eq_freq)))
        return float(log_likelihood)

    def print_tree(self, node=None, level=0):
        """Print the subtree below ``node`` (default: the root), indented by depth."""
        if node is None:
            node = self.root
        # Only non-empty string sequences are encoded; a missing (NaN) leaf
        # sequence is reported as None instead of raising.
        has_leaf_sequence = node.is_final and isinstance(node.sequence, str) and node.sequence
        one_hot_sequence = self.one_hot_encode(node.sequence) if has_leaf_sequence else None
        print(
            ' ' * level * 4 + f'Node: {node.name}, Branch Length: {node.branch_length}, Sequence: {node.sequence}, One-Hot Encoded Sequence: {one_hot_sequence}, Is Final: {node.is_final}')
        for child in node.children:
            self.print_tree(child, level + 1)

    
# Build the tree for the MED23 inputs using the rate matrix Q defined above,
# print its structure, then display the log-likelihood as the cell result.
tree = Tree(
    'input_data/ENSG00000112282_MED23_NT.table.dat',
    'input_data/ENSG00000112282_MED23_NT.msa.dat',
    'input_data/ENSG00000112282_MED23_NT.branchlength.dat',
    Q,
)
tree.print_tree()
tree.get_log_likelihood()


     Parent  Child  Length                                           Sequence
0       116     13   0.030  ATGGTGCAGATGGAGACGCAGCTGCAGAGCATTTTCGAGGAGGTTG...
1       117     12   0.020  TGCTGTGGGATGGAGACGCAGCTGCAGAGCATTTTCGAGGAGGTCG...
2       117     14   0.019  ATGGTGCAGATGGAGACGCAGCTGCAGAGCATTTTCGAGGAGGTCG...
3       116    117   0.007                                                NaN
4       118    116   0.050                                                NaN
..      ...    ...     ...                                                ...
223     127    179   0.004                                                NaN
224     120    127   0.006                                                NaN
225     118    120   0.073                                                NaN
226     229    118   0.013                                                NaN
227     229      2   0.126  GCCCTATGTTTCGTAAGCGTAATACATGTCCACTGATCATCGCATC...

[228 rows x 4 columns]
Node: 229, Branch Length: None, Sequence

-11350.557966913415

In [27]:
import pandas as pd

def process_files(table_path, msa_path, branch_lengths_path):
    """Load the three input files and replace species names by numeric ids.

    Each species in the alignment is numbered by its 1-based row position.
    That numbering replaces the names in the alignment and in the ``Child``
    column of the tree table; other ``Child`` values are left untouched.
    The mapping and the first rows of both transformed tables are printed.

    Args:
        table_path: Comma-separated ``parent,child`` edge list (no header).
        msa_path: Space-separated ``species sequence`` alignment (no header).
        branch_lengths_path: Comma-separated branch lengths (no header).

    Returns:
        Tuple ``(tree_table, msa, branch_lengths)`` of DataFrames; the branch
        lengths are returned exactly as read.
    """
    # Load files
    tree_table = pd.read_csv(table_path, sep=',', header=None, names=['Parent', 'Child'])
    msa = pd.read_csv(msa_path, sep=' ', header=None)
    msa.columns = ['Species', 'Sequence']
    branch_lengths = pd.read_csv(branch_lengths_path, sep=',', header=None)

    # Number the species by their position in the alignment, starting at 1
    species_to_id = {}
    for position, species in enumerate(msa['Species'], start=1):
        species_to_id[species] = position

    # Debug: show the complete name -> id table
    print("Species to ID Mapping:")
    for species in species_to_id:
        print(f"{species} -> {species_to_id[species]}")

    # Swap names for ids in the alignment
    msa['Species'] = msa['Species'].map(species_to_id)

    # Swap names for ids among the children; unknown labels (internal nodes) pass through
    tree_table['Child'] = tree_table['Child'].apply(
        lambda label: species_to_id.get(label, label)
    )

    # Debug: preview both transformed tables
    print("\nTransformed MSA:")
    print(msa.head())

    print("\nTransformed Tree Table:")
    print(tree_table.head())

    # Hand the transformed data back for further processing
    return tree_table, msa, branch_lengths

# Input files for the MED23 gene: edge table, alignment and branch lengths.
table_path = 'input_data/ENSG00000112282_MED23_NT.table.dat'
msa_path = 'input_data/ENSG00000112282_MED23_NT.msa.dat'
branch_lengths_path = 'input_data/ENSG00000112282_MED23_NT.branchlength.dat'

# Run the preprocessing and keep the three transformed tables.
(
    transformed_tree_table,
    transformed_msa,
    transformed_branch_lengths,
) = process_files(table_path, msa_path, branch_lengths_path)


Species to ID Mapping:
Choloepus_hoffmanni -> 1
Ornithorhynchus_anatinus -> 2
Trichechus_manatus_latirostris -> 3
Cavia_aperea -> 4
Sorex_araneus -> 5
Loxodonta_africana -> 6
Procavia_capensis -> 7
Ailuropoda_melanoleuca -> 8
Fukomys_damarensis -> 9
Nannospalax_galili -> 10
Mesocricetus_auratus -> 11
Notamacropus_eugenii -> 12
Monodelphis_domestica -> 13
Phascolarctos_cinereus -> 14
Mustela_putorius -> 15
Canis_familiaris -> 16
Felis_catus -> 17
Panthera_tigris_altaica -> 18
Panthera_pardus -> 19
Acinonyx_jubatus -> 20
Neomonachus_schauinslandi -> 21
Odobenus_rosmarus_divergens -> 22
Leptonychotes_weddellii -> 23
Ursus_maritimus -> 24
Enhydra_lutris_kenyoni -> 25
Aotus_nancymaae -> 26
Callithrix_jacchus -> 27
Cercocebus_atys -> 28
Chlorocebus_sabaeus -> 29
Gorilla_gorilla -> 30
Homo_sapiens -> 31
Pongo_abelii -> 32
Rhinopithecus_bieti -> 33
Rhinopithecus_roxellana -> 34
Colobus_angolensis -> 35
Mandrillus_leucophaeus -> 36
Nomascus_leucogenys -> 37
Piliocolobus_tephrosceles -> 38
Saimi