In [18]:
from Bio import SeqIO
import itertools
import numpy as np
import pandas as pd

In [122]:
np.set_printoptions(suppress=True)

In [347]:
class Neighbor_Joining_Tree(object):
    
    def __init__(self, fna):
        self.D = None
        self.sums = None
        self.n = 0 # remaining number of nodes 
        self.node_ids = None
        self.sequences, self.identifiers = self.parse_sequences(fna)
        self.tree = np.zeros(shape=(1, 3))
        self.Q = None
        self.is_first_ancestor = True
        self.internal_node_id_tracker = 0
    
    
    def initialize_last_variables(self):
        self.sums = self.D.sum(axis=0)
        self.n = len(self.D[0])
        self.node_ids = np.arange(1, len(self.D)+1, 1) # 1 based counting, not 0 based
        self.internal_node_id_tracker = self.n * 2 - 2 # An unrooted binary tree has 2n-2 internal nodes. We are counting internal nodes beginning with root as ntips + 1, and we won't get to the root til later so we're counting backwrads here.
    
    
    def parse_sequences(self, fna):
        sequences = []
        identifiers = []
        fna_obj = list(SeqIO.parse(fna, 'fasta'))
        for i, seq in enumerate(fna_obj):
            sequences.append(seq.seq)
            identifiers.append(seq.id)
        return sequences, identifiers
    
    
    def calculate_dissimilarity_score(self, a, b):
        '''
        If letters do not match, add 1. Otherwise add 0. 
        Sequences a and b are the same length.
        Returns dissimilarity score as float. 
        '''
        score = 0.00
        length = len(a)

        for i, char_a in enumerate(a):
            char_b = b[i]
            if char_a != char_b:
                score += 1

        return score/length
    
    
    def calculate_D_matrix(self):
        print('Calculating D matrix...')
        
        distances = []
        for pair in itertools.product(self.sequences, repeat=2): # repeat=2? ... and we are doing too much work here but its ok for now, can just do half of the work and flip the matrix across the diagonal
            distances.append(self.calculate_dissimilarity_score(pair[0], pair[1]))
        
        # Reshape distances array into matrix based on number of sequences
        self.D = np.reshape(distances, (len(self.sequences), len(self.sequences)))

        
    def print_D_matrix(self):
        filename = 'pairwise_dissimilarity.txt'
        print('Printing D matrix to ' + filename + '...')
        
        distances_df = pd.DataFrame(self.D, columns=self.identifiers) 
        distances_df.insert(0, 'id', self.identifiers) # Add row index
        distances_df = distances_df.reset_index(drop=True).set_index('id')
        
        # Save pairwise dissimilarities to tab-delimited file 
        distances_df.to_csv(filename, sep='\t')
        
    
    def calculate_q_cell(self, distance, sums_i, sums_j):
        '''
        Calculates and returns cell value for the Q matrix. 

        n — number of sequences
        distance — pairwise dissimilarity between pair
        a_sums — summed distances from node a to all the other nodes
        b_sums — summed distances from node b to all the other nodes
        '''

        return (self.n-2) * distance - sums_i - sums_j

    
    def calculate_Q_matrix(self):
        Q_matrix = np.zeros_like(self.D)

        for i, row in enumerate(self.D):
            sums_i = self.sums[i]
            for j, distance in enumerate(row):
                # If diagonal cell, set to infinity (to avoid it from ever being the minimum)
                if i == j:
                    Q_matrix[j, i] = None
                else: 
                    # Get terms for neighbor joining equation
                    sums_j = self.sums[j]
                    # Calculate and set value for this cell in Q matrix
                    Q_matrix[j, i] = self.calculate_q_cell(distance, sums_i, sums_j)

        return Q_matrix
    
    
    def calculate_branch_length(self, distance, sums_f, sums_g):
        '''
        
        '''
        return distance/2 + abs((sums_f - sums_g))/(2*(self.n-2))

    
    def calculate_next_branch_length(self, distance, first_branch_length):
        return abs(distance - first_branch_length) 
    
    
    def calculate_uk(self, fk, gk, fg):
        '''
        Calculates and returns distance of node u to node k. 

        dist_fk — distance of node f to node k
        dist_gk — distance of node g to node k
        dist_fg — distance of node f to node g

        where f and g are members of the pair just joined.
        '''

        return (fk + gk - fg)/2
    
    
    def get_branches(self, f, g, h):
        '''
        [ancestral node, descendant node, branch length]
        '''
        
        # Get ids for each element
        u_id = self.internal_node_id_tracker
        f_id = self.node_ids[f]
        g_id = self.node_ids[g]
        
        # Track node id of new node
        self.node_ids = np.append(self.node_ids, u_id)
        self.internal_node_id_tracker -= 1 

        # Calculate lengths of branches joining f and g to u
        distance_fg = self.D[f, g]
        delta_fu = self.calculate_branch_length(distance_fg, self.sums[f], self.sums[g])
        delta_gu = self.calculate_next_branch_length(distance_fg, delta_fu)
            
        branch_fu = [u_id, f_id, delta_fu] 
        branch_gu = [u_id, g_id, delta_gu]
        
        # If we are on the last iteration, get the third branch joining h to u
        branch_hu = None
        if len(self.D) == 3:
            h_id = self.node_ids[h]
            distance_fh = self.D[f, h]
            delta_hu = self.calculate_next_branch_length(distance_fh, delta_fu)
            branch_hu = [u_id, h_id, delta_hu]
            
        return branch_fu, branch_gu, branch_hu
        
    def update_D(self, f, g):
        row_u = np.zeros_like(self.D[0]) 
        for k,_ in enumerate(self.D):
            row_u[k] = self.calculate_uk(self.D[f, k], self.D[g, k], self.D[f, g])

        # Update D matrix with these distances
        column_u = row_u[:, np.newaxis] 
        column_u = np.vstack([column_u, 0]) # Add its own diagonal value
        self.D = np.vstack((self.D, row_u))
        self.D = np.hstack((self.D, column_u))

        
    def run(self):
        # Task 1
        self.calculate_D_matrix()
        self.print_D_matrix()
        
        # Task 2
        self.initialize_last_variables() # which rely on D
        
        # Loop until D contains 59 elements
        while len(self.D) >= 3:
            # Calculate the join score of each pair (#1)
            #print('\nCalculating Q matrix...')
            self.Q = np.round(self.calculate_Q_matrix(), 12) #1 (rounding cuz weirdly it had an issue with precision past this amount, would be slightly different row vs column)
            self.Q = np.ma.masked_invalid(self.Q) # need to mask nans cuz nan cannot be used in np.where
            # np.savetxt('foo ' + str(len(self.Q)) + '.csv', self.Q, delimiter=",")
            
            # Find the pair with the minimum join score (#2)
            #print(np.ma.where(self.Q == np.amin(self.Q)))
            if len(self.D) == 3:
                # If on the last iteration, forcibly assign f, g, h to 0, 1, 2  
                f, g, h = 0, 1, 2
            else:
                f, g = np.ma.where(self.Q == np.amin(self.Q))[0]
                h = None
            
            #print('The two nodes to be joined have row ids: ' + str(f) + ' and ' + str(g) + ' and ' + str(h))
            #print('Their actual node ids: ' + str(self.node_ids[f]) + ' and ' + str(self.node_ids[g]) + ' and ' + str(self.node_ids[h])) # only works if h is not None
            
            # Add branch lengths for this pair to their ancestral node u (#3)
            #print('Adding branch lengths from this pair to new, ancestral node...')
            branch_fu, branch_gu, branch_hu = self.get_branches(f, g, h)
            self.tree = np.vstack((self.tree, branch_fu))
            self.tree = np.vstack((self.tree, branch_gu))
            if branch_hu:
                self.tree = np.vstack((self.tree, branch_hu))
            
            # Calculate distance between u and every other node k outside of this pair 
            # and update D matrix with these distances (#4)
            #print('Updating distances between new, ancestral node and every other node...')
            self.update_D(f, g)
            
            # Remove this pair 
            #print('Removing two nodes...')
            self.D = np.delete(self.D, [f, g], axis=1)
            self.D = np.delete(self.D, [f, g], axis=0)
            self.sums = np.delete(self.sums, [f, g])
            self.node_ids = np.delete(self.node_ids, [f, g])
            # Update variables
            self.sums = np.append(self.sums, self.D[-1].sum())
            self.n = len(self.D)
        
        self.tree = np.delete(njt.tree, 0, axis=0) # remove that initialized beginning [0,0,0]
        print(self.tree)

In [348]:
alignment_file = 'hw3.fna'
njt = Neighbor_Joining_Tree(alignment_file)
njt.run()

Calculating D matrix...
Printing D matrix to pairwise_dissimilarity.txt...
[[120.           1.           0.12465497]
 [120.           2.           0.11558729]
 [119.           3.           0.13564533]
 [119.         120.           0.01812317]
 [118.          16.           0.04776156]
 [118.          33.           0.03366508]
 [117.           9.           0.08202149]
 [117.          59.           0.07544823]
 [116.          30.           0.03473021]
 [116.          60.           0.02583507]
 [115.          10.           0.06171178]
 [115.          57.           0.05403519]
 [114.          25.           0.02287387]
 [114.          29.           0.02019477]
 [113.          17.           0.0353168 ]
 [113.          18.           0.031305  ]
 [112.          23.           0.04042303]
 [112.          28.           0.03360119]
 [111.           5.           0.0952961 ]
 [111.          40.           0.08707268]
 [110.          34.           0.04201802]
 [110.          46.           0.03873568]
 

Tomorrow: 
- MIght need to figure out why my branch lenghts are not the same as in solutoin and why i have different nodes connected. 
- generate output files.


Perform preorder traversal

Choose arbitrary node with id higher than sequence length idicating it is internal, choose it as root

Preorder traversal is a DFS search

In [350]:
tree = njt.tree

Algorithm Preorder(tree)
   1. Visit the root.
   2. Traverse the left subtree, i.e., call Preorder(left-subtree)
   3. Traverse the right subtree, i.e., call Preorder(right-subtree) 

In [439]:
def preorder_traversal(root_id):
    traversed_tree = []
    children = np.where(tree[:,0] == root_id)[0]
    
    if children.size > 0:
        # Traverse descendant subtrees
        for child in children: 
            traversed_tree.append(tree[child])
            traversed_tree = traversed_tree + preorder_traversal(tree[child][1])
    
    return traversed_tree

root_id = 62
traversed_tree = preorder_traversal(root_id)
np.savetxt('edges.txt', traversed_tree, fmt='%i\t%i\t%1.10f')

Note: as per [documentation](https://docs.scipy.org/doc/numpy/reference/generated/numpy.savetxt.html), delimiter is ignored with multiformat string, so I specified delimiter within the multiformat string

In [433]:
len(traversed_tree)

119

In [442]:
def postorder_traversal(root_id):
    traversed_tree = []
    children = np.where(tree[:, 0] == root_id)[0]
    
    if children.size > 0:
        for child in children: 
            traversed_tree = traversed_tree + preorder_traversal(tree[child][1])
            traversed_tree.append(tree[child])
    
    return traversed_tree

postorder_traversed_tree = postorder_traversal(root_id)
np.savetxt('edges_postorder.txt', postorder_traversed_tree, fmt='%i\t%i\t%1.10f')

In [438]:
postorder_traversed_tree

[array([66.        , 73.        ,  0.01672643]),
 array([73.       , 90.       ,  0.0182353]),
 array([90.       , 12.       ,  0.1090407]),
 array([90.        , 37.        ,  0.06861744]),
 array([73.        , 85.        ,  0.01718382]),
 array([ 85.        , 110.        ,   0.05107041]),
 array([110.        ,  34.        ,   0.04201802]),
 array([110.        ,  46.        ,   0.03873568]),
 array([ 85.        , 109.        ,   0.02976741]),
 array([109.       ,  35.       ,   0.0702792]),
 array([109.        , 116.        ,   0.00803156]),
 array([116.        ,  30.        ,   0.03473021]),
 array([116.        ,  60.        ,   0.02583507]),
 array([66.        , 71.        ,  0.00736299]),
 array([71.        , 84.        ,  0.01121088]),
 array([ 84.        , 113.        ,   0.07259056]),
 array([113.       ,  17.       ,   0.0353168]),
 array([113.      ,  18.      ,   0.031305]),
 array([ 84.        , 108.        ,   0.01421236]),
 array([108.        ,  45.        ,   0.11183815]),