### Group 8
- Nimrod Grandpierre
- Jonas Riber Jørgensen
- Johan Ulstrup
- Laura Fur

# Project 5: NJ tree construction
This project is about making an efficient implementation of the neighbor-joining (NJ) algorithm, as shown on slide 50 in the slides about tree reconstruction, and comparing its performance to the NJ programs QuickTree and RapidNJ that you know from project 4.

## <span style="color:cornflowerblue">Problem</span>
You should make a program that implements the NJ algorithm as shown in the slides about tree reconstruction. Your program should take a distance matrix in phylip-format as input and produce a tree in newick-format as output. You should know these formats from project 4. Your aim is to make your implementation as efficient as possible.

The file example_slide4.phy contains the distance matrix (in phylip-format) from slide 4 in the slides about tree reconstruction. With this matrix as input, your program should produce the tree that is also shown on slide 4 in the slides about tree reconstruction.
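The input side of the problem can be illustrated with a small reader. This is a minimal sketch, not the notebook's own `initiate_DM` (which lives in `read_phylip` and is not shown here). The name `parse_phylip` is hypothetical, and the sketch assumes a full square matrix with whitespace-separated values and taxon names that contain no spaces.

```python
def parse_phylip(text):
    """Parse a square phylip distance matrix into {taxon: [distances]}.

    Assumption: the first non-empty line holds the number of taxa, and each
    following line holds a taxon name followed by one full row of distances.
    """
    lines = [ln for ln in text.splitlines() if ln.strip()]
    n = int(lines[0])
    dm = {}
    for ln in lines[1:n + 1]:
        name, *values = ln.split()
        dm[name] = [float(v) for v in values]
    if len(dm) != n or any(len(row) != n for row in dm.values()):
        raise ValueError("matrix is not square or does not match the taxon count")
    return dm


example = """4
A 0 17 21 27
B 17 0 12 18
C 21 12 0 14
D 27 18 14 0
"""
dm = parse_phylip(example)
print(list(dm))   # ['A', 'B', 'C', 'D']
print(dm["B"])    # [17.0, 0.0, 12.0, 18.0]
```

Keeping the rows in a dictionary keyed by taxon name mirrors the structure that the outputs further down display, although those use NumPy arrays rather than plain lists.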

In [4]:
import pandas as pd
import numpy as np
from read_phylip import *
from correct_DM import *
from NJ_print import *
import os

### Implementation

In [8]:
def NJ(S: list, DM: dict):

    T = []

    while len(S) > 3:

        def cluster(S, DM):
            nonlocal T

            """ Step 1: Correcting the input distance matrix and choosing a sequence pair to cluster """
            
            N = compute_N(S, DM)  # Correcting the distance matrix.

            min_distances = {i: min((distance, idx) for idx, distance in enumerate(row) if distance != 0.0) for i, row in N.items()}  # Finding the shortest distance and its index for each cluster in N.

            C1 = min(min_distances, key = lambda k: min_distances[k][0])          # Cluster 1 is the cluster with the shortest distance among all clusters' minimum distances.
            C2_idx = min_distances[C1][1]                                         # Extract the index for cluster 2. 
            C2 = S[C2_idx]                                                        # Get cluster 2 from S.

            k = [C1, C2]       # Pair to cluster

            """ Step 2: Adding new node, k, to T """

            k_copy = k.copy()  # Create a copy of k to append to T to avoid overwriting k.
            T.append(k_copy)   # Append the copy to T.

            """ Step 3: Calculating distances to new cluster and adding edges, (k, i) and (k, j) """

            # Adding edges, (k, i) and (k, j):
            dij = {seq_i: DM[seq_i] for seq_i in DM.keys()}
            ri = {seq_i: round((1 / (len(DM) - 2)) * np.sum(DM_i), 2) for seq_i, DM_i in DM.items()}  # r_i: row sum divided by (n - 2).
            
            diu = round(0.5 * dij[C1][C2_idx] + 0.5 * (ri[C1] - ri[C2]), 2)      # Distance from cluster 1 in k to common node.
            dju = round(dij[C1][C2_idx] - diu, 2)                                # Distance from cluster 2 in k to common node.

            cluster_idx = T.index(k)
            T[cluster_idx].append(diu)
            T[cluster_idx].append(dju)

            # Calculating distances from other clusters to k:
            d_C1C2 = DM[C1][C2_idx]                                              # Distance between the two clusters being joined.
            dijk = {(C1, C2): 0}                                                 # Distance from k to itself (k is stored as a tuple, as lists cannot be used as keys).
            for i, seq_i in enumerate(DM.keys()):
                if seq_i not in (C1, C2):
                    dik = DM[C1][i]
                    djk = DM[C2][i]
                    dijk[seq_i] = round(0.5 * (dik + djk - d_C1C2), 2)

            """ Step 4: Update distance matrix """

            UDM = DM.copy()                                                      # Create template for the updated distance matrix from the input distance matrix.

            UDM[(C1, C2)] = UDM.pop(C1)                                          # Overwrite key for cluster 1 in k. Note that this puts the key for k in the last position of the dictionary.
            del UDM[C2]                                                          # Delete key for cluster 2 in k.

            for seq, dists in UDM.items():
                UDM[seq] = np.delete(dists, list(DM.keys()).index(C2))           # Delete column for the cluster 2 in k.

            # Updating distance matrix values:
            DM_ids = {seq_i: i for i, seq_i in enumerate(DM.keys())}             # Getting cluster indices in the input distance matrix.

            for i, seq_i in enumerate(UDM.keys()):
                if seq_i != (C1, C2):
                    for j, seq_j in enumerate(UDM.keys()):
                        if seq_j != seq_i:
                            if seq_j == (C1, C2):
                                UDM[seq_i][j] = dijk[seq_i]                      # Distance from cluster to k.
                            if seq_j != (C1, C2):
                                UDM[seq_i][j] = DM[seq_i][DM_ids[seq_j]]         # Distance from cluster to other clusters that are not k.
                        if seq_j == seq_i:
                            UDM[seq_i][j] = 0                                    # Distance from cluster to itself - always 0.
                if seq_i == (C1, C2):
                    for j, seq_j in enumerate(UDM.keys()):
                        UDM[seq_i][j] = dijk[seq_j]                              # Distances from k to other clusters.
                                        
            """ Step 5: Delete i and j from S and add the new taxon, k, to S. """

            S.remove(C1)
            S.remove(C2)
            S.append((C1, C2))

            return S, UDM, k

        S, DM, k = cluster(S, DM)

    """ Termination """
    i, j, m = S

    # Ensure m is k, otherwise swap the variable that is k (the last cluster formed) for m.
    if m != tuple(k):
        if i == tuple(k):
            i, m = m, i
        elif j == tuple(k):
            j, m = m, j

    idx_i = list(DM.keys()).index(i)
    idx_j = list(DM.keys()).index(j)
    idx_m = list(DM.keys()).index(m)

    gamma_vi = round((DM[i][idx_j] + DM[i][idx_m] - DM[j][idx_m])/2, 3)
    gamma_vj = round((DM[j][idx_i] + DM[j][idx_m] - DM[i][idx_m])/2, 3)
    gamma_vm = round((DM[m][idx_i] + DM[m][idx_j] - DM[i][idx_j])/2, 3)

    T.append([i, gamma_vi])
    T.append([j, gamma_vj])
    T.append([m, gamma_vm])

    return T

#### ```NJ()```

Follows the steps of Saitou and Nei's neighbor-joining algorithm.
The function repeatedly calls a nested helper, ```cluster()```, in a ```while``` loop (it is not recursive). Each call joins one pair of clusters and updates the distance matrix, until only three clusters remain; these are then joined in the termination step.
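The pair selection in ```cluster()``` depends on `compute_N`, which is imported from `correct_DM` and not shown in this notebook. The following is a minimal sketch of what such a correction might look like, assuming the standard NJ criterion $N_{ij} = d_{ij} - (r_i + r_j)$ with $r_i = \frac{1}{n-2}\sum_m d_{im}$. The names `corrected_matrix` and `pick_pair` are hypothetical, and plain lists are used instead of NumPy arrays.

```python
def corrected_matrix(taxa, dm):
    """Return N with N[a][j] = d_ab - (r_a + r_b); the diagonal is set to 0."""
    n = len(taxa)
    r = {t: sum(dm[t]) / (n - 2) for t in taxa}   # r_i: row sum divided by (n - 2)
    return {
        a: [0.0 if i == j else dm[a][j] - (r[a] + r[b]) for j, b in enumerate(taxa)]
        for i, a in enumerate(taxa)
    }


def pick_pair(taxa, N):
    """Return the pair of taxa with the smallest off-diagonal entry in N."""
    # Only the upper triangle is scanned, so the diagonal is skipped by index
    # rather than by comparing values against 0.
    _, i, j = min(
        (N[a][j], i, j)
        for i, a in enumerate(taxa)
        for j in range(i + 1, len(taxa))
    )
    return taxa[i], taxa[j]


dm = {
    "A": [0, 17, 21, 27],
    "B": [17, 0, 12, 18],
    "C": [21, 12, 0, 14],
    "D": [27, 18, 14, 0],
}
taxa = list(dm)
N = corrected_matrix(taxa, dm)
print(N["A"])              # [0.0, -39.0, -35.0, -35.0]
print(pick_pair(taxa, N))  # ('A', 'B')
```

Ties (here A-B and C-D both score -39) are broken by position in the taxon list in this sketch; other implementations may break ties differently and still produce a valid NJ tree.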

## <span style = 'color:cornflowerblue'>Tests</span>

Simple test data with four sequences from ```wiki_ex.phy```:

In [6]:
wiki_ex = initiate_DM('wiki_ex.phy')
wiki_ex_taxa = [i for i in wiki_ex.keys()]

wiki_ex_out = NJ_print(wiki_ex_taxa, wiki_ex)
wiki_ex_out

[1m[96mInitialization[0m
[1mNo. taxa:[0m          4
[1mTaxa:[0m              ['A', 'B', 'C', 'D']
[1mDM:[0m                {'A': array([ 0., 17., 21., 27.]), 'B': array([17.,  0., 12., 18.]), 'C': array([21., 12.,  0., 14.]), 'D': array([27., 18., 14.,  0.])}
[1mCorrected DM:[0m      {'A': array([  0., -39., -35., -35.]), 'B': array([-39.,   0., -35., -35.]), 'C': array([-35., -35.,   0., -39.]), 'D': array([-35., -35., -39.,   0.])}
[1mPair to cluster:[0m   ['A', 'B'] 

[1m[95mClustering[0m
[1mDistance to node:[0m  A: 13.0   B: 4.0
[1mDistance to k:[0m     {('A', 'B'): 0, 'C': 8.0, 'D': 14.0}
[1mOverwrite S1:[0m      {'B': array([17.,  0., 12., 18.]), 'C': array([21., 12.,  0., 14.]), 'D': array([27., 18., 14.,  0.]), ('A', 'B'): array([ 0., 17., 21., 27.])}
[1mDelete S2:[0m         {'C': array([21.,  0., 14.]), 'D': array([27., 14.,  0.]), ('A', 'B'): array([ 0., 21., 27.])}
[1mUpdated DM:[0m        {'C': array([ 0., 14.,  8.]), 'D': array([14.,  0., 14.]), 

[['A', 'B', 13.0, 4.0], ['C', 4.0], ['D', 10.0], [('A', 'B'), 4.0]]

For this example dataset, the Newick format should look like this:

**<span style='color:lightskyblue'>(C: 4.0, D: 10.0, (A: 13.0, B: 4.0): 4.0);</span>**

or this:

**<span style='color:lightskyblue'>((A: 13.0, B: 4.0): 4.0, C: 4.0, D: 10.0);</span>**
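The list returned by ```NJ()``` can be converted to such a string mechanically. This is a minimal sketch, assuming the list structure seen in the output above: four-item entries `[left, right, d_left, d_right]` for each join, followed by three two-item entries `[node, distance]` from the termination step. The function name `to_newick` is hypothetical and not part of the notebook's modules.

```python
def to_newick(T):
    """Serialise the join list produced by NJ() as an unrooted Newick string."""
    # Four-item entries describe an internal join; key them by the (left, right) tuple.
    joins = {(e[0], e[1]): (e[2], e[3]) for e in T if len(e) == 4}

    def render(node):
        # Clusters are nested tuples; leaves are plain taxon names.
        if isinstance(node, tuple):
            left, right = node
            d_left, d_right = joins[node]
            return f"({render(left)}:{d_left},{render(right)}:{d_right})"
        return str(node)

    # Two-item entries are the three branches that meet at the central node.
    final = [e for e in T if len(e) == 2]
    return "(" + ",".join(f"{render(node)}:{dist}" for node, dist in final) + ");"


T = [['A', 'B', 13.0, 4.0], ['C', 4.0], ['D', 10.0], [('A', 'B'), 4.0]]
print(to_newick(T))  # (C:4.0,D:10.0,(A:13.0,B:4.0):4.0);
```

The sketch emits no spaces after colons or commas; Newick parsers generally ignore such whitespace, so the result describes the same tree as the strings shown above.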

Provided test data with five sequences from ```example_slide4.phy```:

In [7]:
slide4_ex = initiate_DM('example_slide4.phy')
slide4_ex_taxa = [i for i in slide4_ex.keys()]

slide4_ex_out = NJ_print(slide4_ex_taxa, slide4_ex)
slide4_ex_out

[1m[96mInitialization[0m
[1mNo. taxa:[0m          5
[1mTaxa:[0m              ['A', 'B', 'C', 'D', 'E']
[1mDM:[0m                {'A': array([0.  , 0.23, 0.16, 0.2 , 0.17]), 'B': array([0.23, 0.  , 0.23, 0.17, 0.24]), 'C': array([0.16, 0.23, 0.  , 0.2 , 0.11]), 'D': array([0.2 , 0.17, 0.2 , 0.  , 0.21]), 'E': array([0.17, 0.24, 0.11, 0.21, 0.  ])}
[1mCorrected DM:[0m      {'A': array([ 0.  , -0.31, -0.32, -0.31, -0.32]), 'B': array([-0.31,  0.  , -0.29, -0.38, -0.29]), 'C': array([-0.32, -0.29,  0.  , -0.29, -0.36]), 'D': array([-0.31, -0.38, -0.29,  0.  , -0.29]), 'E': array([-0.32, -0.29, -0.36, -0.29,  0.  ])}
[1mPair to cluster:[0m   ['B', 'D'] 

[1m[95mClustering[0m
[1mDistance to node:[0m  B: 0.1   D: 0.07
[1mDistance to k:[0m     {'A': 0.13, ('B', 'D'): 0, 'C': 0.13, 'E': 0.14}
[1mOverwrite S1:[0m      {'A': array([0.  , 0.23, 0.16, 0.2 , 0.17]), 'C': array([0.16, 0.23, 0.  , 0.2 , 0.11]), 'D': array([0.2 , 0.17, 0.2 , 0.  , 0.21]), 'E': array([0.17, 0.24, 0

[['B', 'D', 0.1, 0.07],
 ['A', ('B', 'D'), 0.08, 0.05],
 ['C', 0.05],
 ['E', 0.06],
 [('A', ('B', 'D')), 0.03]]

For this example dataset, the Newick format should look like this:

**<span style='color:lightskyblue'>(C: 0.05, E: 0.06, (A: 0.08, (B: 0.10, D: 0.07): 0.05): 0.03);</span>**

or this:

**<span style='color:lightskyblue'>((A: 0.08, (B: 0.10, D: 0.07): 0.05): 0.03, C: 0.05, E: 0.06);</span>**

## <span style = 'color:cornflowerblue'>Experiments</span>
From project 4, you know the programs QuickTree and RapidNJ, which are implementations of the NJ method. QuickTree implements the basic cubic-time algorithm, while RapidNJ implements an algorithm that is faster in practice.

You should compare the performance of your program against these two programs in the following way.

The archive unique_distance_matrices.zip contains 14 distance matrices (in phylip-format) ranging in size from 89 to 1849 species. For each distance matrix, you should do the following:

1. Measure the time it takes to construct the corresponding NJ tree using QuickTree, RapidNJ, and your program.
2. Compute the RF-distances (using your program rfdist from project 4) between the trees produced by QuickTree, RapidNJ, and your program.
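The timing in step 1 could be collected with a small harness like the one below. This is a sketch under assumptions: `best_time` is a hypothetical helper, and it reports the best of a few repeated runs to reduce noise from other processes. The workload shown is a stand-in; for this notebook the callable would wrap ```NJ()```.

```python
import time


def best_time(make_call, repeats=3):
    """Run make_call() `repeats` times and return the shortest wall-clock time in seconds."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        make_call()
        timings.append(time.perf_counter() - start)
    return min(timings)


# Stand-in workload; replace the lambda with a call to the function under test.
elapsed = best_time(lambda: sorted(range(100_000), reverse=True))
print(f"{elapsed:.4f} s")
```

Because ```NJ()``` removes and appends items in the taxon list it receives, each repeat should get a fresh copy, for example `best_time(lambda: NJ(list(data_taxa), data))`. External programs could be timed with the same harness by wrapping a `subprocess.run` call in the lambda.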

(If you want to investigate the running time of your program on more examples than provided in distance_matrices.zip, you are welcome to download pfam_alignments.zip, which contains 128 alignments in Stockholm-format (from the Pfam database) of 58 to 71535 species. You can convert these to distance matrices in phylip-format using e.g. QuickTree. However, converting the big alignments to distance matrices would probably take too long and require too much space.)

In [10]:
uniqueDistMatrices = sorted([file for file in os.listdir('unique_distance_matrices')], key = lambda x: int(x.split('_')[0]))

In [11]:
data = initiate_DM(f'unique_distance_matrices/{uniqueDistMatrices[1]}')
data_taxa = [i for i in data.keys()]

NJ(data_taxa, data)

[['21_Q8AYY5', '22_Q8JZ01', -0.24, 0.24],
 ['10_Q8B120', '11_Q8JZ03', -0.22, 0.22],
 ['26_Q8JZ04', '29_Q8JZ00', 0.16, 0.17],
 ['32_Q8B116', '34_Q8JZ05', -0.18, 0.21],
 ['89_Q8JZ02', ('26_Q8JZ04', '29_Q8JZ00'), 0.18, 0.08],
 ['12_Q8B117', ('10_Q8B120', '11_Q8JZ03'), -0.1, 0.12],
 ['31_Q8B121', ('32_Q8B116', '34_Q8JZ05'), -0.09, 0.1],
 ['114_Q6PXS', ('89_Q8JZ02', ('26_Q8JZ04', '29_Q8JZ00')), -0.12, 0.09],
 ['16_Q911P0', '17_Q8B118', 0.0, 0.0],
 ['18_Q8B114', '19_A0PJ27', 0.01, 0.0],
 ['20_A0PJ25', ('18_Q8B114', '19_A0PJ27'), 0.01, 0.0],
 ['6_O11999', '7_O11997', 0.0, 0.0],
 ['14_Q995C5', '15_Q8B119', 0.01, 0.01],
 ['27_Q8AYW1', '28_A1A3Z2', 0.0, 0.01],
 ['9_Q9DK03', ('12_Q8B117', ('10_Q8B120', '11_Q8JZ03')), 0.01, 0.12],
 ['4_VGLY_PI', '5_O11998', 0.01, 0.0],
 ['2_Q9YTW9', '3_Q9YTW8', 0.0, 0.0],
 [('6_O11999', '7_O11997'), ('4_VGLY_PI', '5_O11998'), 0.0, 0.01],
 ['0_Q9YTX1', ('2_Q9YTW9', '3_Q9YTW8'), -0.0, 0.01],
 ['1_O90423', ('0_Q9YTX1', ('2_Q9YTW9', '3_Q9YTW8')), 0.01, -0.01],
 ['8_Q9