### Group 8
- Nimrod Grandpierre
- Jonas Riber Jørgensen
- Johan Ulstrup
- Laura Fur

# Project 5: NJ tree construction
This project is about implementing the neighbor-joining (NJ) algorithm efficiently, as shown on slide 50 in the slides about tree reconstruction, and comparing its performance to the NJ programs QuickTree and RapidNJ that you know from project 4.

## <span style="color:cornflowerblue">Problem</span>
You should make a program that implements the NJ algorithm as shown in the slides about tree reconstruction. Your program should take a distance matrix in phylip-format as input and produce a tree in newick-format as output. You should know these formats from project 4. Your aim is to make your implementation as efficient as possible.

The file example_slide4.phy contains the distance matrix (in phylip-format) from slide 4 in the slides about tree reconstruction. With this matrix as input, your program should produce the tree that is also shown on slide 4 in the slides about tree reconstruction.
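
To make the input format concrete, here is a minimal sketch of reading a square phylip distance matrix into a list of names and a NumPy array. The function name `parse_phylip_distances` is hypothetical and is not the `read_phylip` module imported below; the three-taxon matrix is made up for illustration and is not the matrix from slide 4.

```python
import numpy as np

def parse_phylip_distances(text):
    """Parse a square phylip distance matrix into (names, matrix)."""
    lines = [line for line in text.splitlines() if line.strip()]
    n = int(lines[0])                      # first line holds the number of taxa
    names, rows = [], []
    for line in lines[1:n + 1]:
        fields = line.split()
        names.append(fields[0])            # taxon name comes first
        rows.append([float(x) for x in fields[1:n + 1]])
    return names, np.array(rows)

# Made-up example input (assumption, for illustration only).
toy = """3
a 0.0 2.0 3.0
b 2.0 0.0 4.0
c 3.0 4.0 0.0
"""
names, D = parse_phylip_distances(toy)
```

Keeping the distances in one 2D array rather than a dict of rows makes the later row and column operations of NJ straightforward to vectorize.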

In [15]:
import pandas as pd
import numpy as np
from read_phylip import *
from correct_DM import *
from write_newick import newick
import os
from io import StringIO
from Bio import Phylo
import copy

### Implementation

In [3]:
def compute_N(S, DM):
    """Return the matrix N with N[i][j] = d_ij - (r_i + r_j), used to pick the pair to join."""
    N = copy.deepcopy(DM)

    # r_i: sum of the distances from taxon i to all taxa, scaled by 1 / (|S| - 2).
    ri = {seq_i: round((1 / (len(S) - 2)) * np.sum(distances_i), 2) for seq_i, distances_i in DM.items()}

    for i, seq_i in enumerate(DM.keys()):
        for j, seq_j in enumerate(DM.keys()):
            if i != j:
                N[seq_i][j] = DM[seq_i][j] - (ri[seq_i] + ri[seq_j])
            else:
                N[seq_i][j] = 0
    return N
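
Since efficiency is the aim, the same matrix can probably be computed without Python-level loops. The following is a sketch only, assuming the distances are held in a square NumPy array rather than the dict of rows used above; `compute_N_vectorized` is a hypothetical name, and unlike `compute_N` it does not round the row sums.

```python
import numpy as np

def compute_N_vectorized(D):
    """N[i, j] = d_ij - (r_i + r_j), with r_i = sum_m d_im / (n - 2)."""
    D = np.asarray(D, dtype=float)
    n = D.shape[0]
    r = D.sum(axis=1) / (n - 2)
    N = D - r[:, None] - r[None, :]   # broadcasting subtracts r_i and r_j
    np.fill_diagonal(N, 0.0)          # same diagonal convention as compute_N
    return N

# Made-up additive distances for the quartet ((a,b),(c,d)).
D_toy = np.array([[0, 3, 5, 6],
                  [3, 0, 6, 7],
                  [5, 6, 0, 7],
                  [6, 7, 7, 0]], dtype=float)
N_toy = compute_N_vectorized(D_toy)
```

In this toy matrix the two true neighbor pairs, (a, b) and (c, d), share the smallest entry of -12, while all other off-diagonal entries are -11.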

In [18]:
def NJ(DM: dict, S: list) -> dict:

    Clusters = {}

    while len(S) > 3:

        def Cluster(DM, S):
            nonlocal Clusters

            print(f'\033[1m\033[96mSequences left to process:\033[0m {len(S)}')

            N = compute_N(S, DM)

            """ Finding pair to cluster """ 
            minDistances = {i: min((distance, idx) for idx, distance in enumerate(row) if distance != 0.0) for i, row in N.items()}  # Finding the shortest distance and its index for each cluster in N.
            C1 = min(minDistances, key = lambda k: minDistances[k][0])          # Cluster 1 is the cluster with the shortest distance among all clusters' minimum distances.
            C2idx = minDistances[C1][1] 
            C2 = S[C2idx]

            k = (C1, C2)

            """ Calculating distance to common node: """
            ri = {seq_i: round((1 / (len(DM) - 2)) * np.sum(DM_i), 2) for seq_i, DM_i in DM.items()}
            
            diu = round(0.5 * DM[C1][C2idx] + 0.5 * (ri[C1] - ri[C2]), 2)      # Distance from cluster 1 in k to common node.
            dju = round(DM[C1][C2idx] - diu, 2)                                # Distance from cluster 2 in k to common node.

            """ Adding k and edges, (i, k) and (j, k) """
            if C1 in Clusters:
                if C2 not in Clusters:
                    Clusters[k] = {C1: Clusters[C1], 'u': diu, C2: dju}
                    del Clusters[C1]

            elif C1 not in Clusters:
                if C2 in Clusters:
                    Clusters[k] = {C1: diu, C2: Clusters[C2], 'u': dju}
                    del Clusters[C2]
            
            elif C1 in Clusters:
                if C2 in Clusters:
                    Clusters[k] = {C1: Clusters[C1], 'u': diu, C2: Clusters[C2], 'u': dju}
                    del Clusters[C1]
                    del Clusters[C2]

            if k not in Clusters:    
                Clusters[k] = {C1: diu, C2: dju}
            
            """ Calculating distances from k to other clusters """
            dijk = {}
            for i, seq_i in enumerate(DM.keys()):
                if seq_i not in (C1, C2):
                    dik = DM[C1][i]
                    djk = DM[C2][i]
                    dij = DM[C1][C2idx]
                    dijk[seq_i] = round(0.5 * (dik + djk - dij), 2)
            dijk[k] = 0                                                          # Distance from k to itself.

            """ Update distance matrix """
            UDM = {seq: dists.copy() for seq, dists in DM.items()}

            UDM[k] = UDM.pop(C1)
            del UDM[C2] 
            
            C2_col = list(DM.keys()).index(C2)                                   # Column of C2 in the input distance matrix.
            for seq, dists in UDM.items():
                UDM[seq] = np.delete(dists, C2_col)

            DM_ids = {seq_i: i for i, seq_i in enumerate(DM.keys())}             # Getting cluster indices in the input distance matrix.

            for i, seq_i in enumerate(UDM.keys()):
                if seq_i != k:
                    for j, seq_j in enumerate(UDM.keys()):
                        if seq_j != seq_i:
                            if seq_j == k:
                                UDM[seq_i][j] = dijk[seq_i]                      # Distance from cluster to k.
                            if seq_j != k:
                                UDM[seq_i][j] = DM[seq_i][DM_ids[seq_j]]         # Distance from cluster to other clusters that are not k.
                        if seq_j == seq_i:
                            UDM[seq_i][j] = 0                                    # Distance from cluster to itself - always 0.
                if seq_i == k:
                    for j, seq_j in enumerate(UDM.keys()):
                        UDM[seq_i][j] = dijk[seq_j]                              # Distances from k to other clusters.

            """ Update taxa """
            S.remove(C1)
            S.remove(C2)
            S.append(k)

            return UDM, S
        
        clustered = Cluster(DM, S)
        DM = clustered[0]
        S = clustered[1]

    """ Termination """
    i, j, m = S

    idx_i = list(DM.keys()).index(i)
    idx_j = list(DM.keys()).index(j)
    idx_m = list(DM.keys()).index(m)

    gamma_vi = round((DM[i][idx_j] + DM[i][idx_m] - DM[j][idx_m])/2, 3)
    gamma_vj = round((DM[j][idx_i] + DM[j][idx_m] - DM[i][idx_m])/2, 3)
    gamma_vm = round((DM[m][idx_i] + DM[m][idx_j] - DM[i][idx_j])/2, 3)

    Clusters[(i, j)] = {i: gamma_vi, j: gamma_vj}
    Clusters['v'] = gamma_vm

    return Clusters

#### ```NJ()```

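Follows the steps of Saitou and Nei's neighbor-joining algorithm. While more than three clusters remain, it computes N, joins the pair of clusters with the smallest entry in N, adds the new node with its two branch lengths, and updates the distance matrix. The last three clusters are then joined at a central node.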
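
For comparison, the same steps can be sketched in array form. This is not the implementation above: it assumes a symmetric NumPy distance matrix with at least three taxa, does no intermediate rounding, and builds the newick string directly instead of a `Clusters` dict. The name `nj_newick` and the toy matrix are made up for illustration.

```python
import numpy as np

def nj_newick(names, D):
    """Neighbor-joining on a symmetric distance matrix; returns a newick string."""
    D = np.array(D, dtype=float)
    labels = list(names)                       # newick text of each active cluster
    while len(labels) > 3:
        n = len(labels)
        r = D.sum(axis=1) / (n - 2)
        N = D - r[:, None] - r[None, :]
        np.fill_diagonal(N, np.inf)            # never select i == j
        i, j = np.unravel_index(np.argmin(N), N.shape)
        i, j = int(min(i, j)), int(max(i, j))
        d_ik = 0.5 * (D[i, j] + r[i] - r[j])   # branch from i to the new node k
        d_jk = D[i, j] - d_ik                  # branch from j to the new node k
        d_k = 0.5 * (D[i] + D[j] - D[i, j])    # distances from k to every cluster
        merged = f"({labels[i]}:{d_ik:.3f},{labels[j]}:{d_jk:.3f})"
        keep = [m for m in range(n) if m not in (i, j)]
        new_D = np.zeros((n - 1, n - 1))
        new_D[:-1, :-1] = D[np.ix_(keep, keep)]
        new_D[-1, :-1] = new_D[:-1, -1] = d_k[keep]
        D = new_D
        labels = [labels[m] for m in keep] + [merged]
    # Termination: join the last three clusters at a central node.
    a = 0.5 * (D[0, 1] + D[0, 2] - D[1, 2])
    b = 0.5 * (D[0, 1] + D[1, 2] - D[0, 2])
    c = 0.5 * (D[0, 2] + D[1, 2] - D[0, 1])
    return f"({labels[0]}:{a:.3f},{labels[1]}:{b:.3f},{labels[2]}:{c:.3f});"

# Made-up additive distances for the quartet ((a,b),(c,d)).
D_toy = [[0, 3, 5, 6],
         [3, 0, 6, 7],
         [5, 6, 0, 7],
         [6, 7, 7, 0]]
tree = nj_newick(["a", "b", "c", "d"], D_toy)
# tree == "(c:3.000,d:4.000,(a:1.000,b:2.000):1.000);"
```

Each iteration still costs quadratic time in the number of active clusters, so the overall running time remains cubic; the gain is only that the inner loops run inside NumPy.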

## <span style = 'color:cornflowerblue'>Tests</span>

Test on the provided data with five sequences from ```example_slide4.phy```:

In [5]:
ex_slide4 = initiate_DM('example_slide4.phy')
ex_slide4_taxa = [i for i in ex_slide4]

ex_out = NJ(DM = ex_slide4, S = ex_slide4_taxa)
ex_out

'(A:0.08,(B:0.1,D:0.07):0.05,(C:0.05,E:0.06):0.03);'

## <span style = 'color:cornflowerblue'>Experiments</span>
From project 4, you know the programs QuickTree and RapidNJ, which are implementations of the NJ method. QuickTree implements the basic cubic-time algorithm, while RapidNJ implements an algorithm that is faster in practice.

You should compare the performance of your program against these two programs in the following way.

The archive unique_distance_matrices.zip contains 14 distance matrices (in phylip-format) ranging in size from 89 to 1849 species. For each distance matrix, you should do the following:

1. Measure the time it takes to construct the corresponding NJ tree using QuickTree, RapidNJ, and your program.
2. Compute the RF-distances (using your program rfdist from project 4) between the trees produced by QuickTree, RapidNJ, and your program.
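
The timing in step 1 could be measured with a small helper like the one below. This is a sketch under assumptions: `time_call` is a hypothetical name, and taking the median of a few repeats is one possible choice for reducing noise, not something the project text prescribes.

```python
import statistics
import time

def time_call(fn, *args, repeats=3):
    """Return the median wall-clock time in seconds of fn(*args) over `repeats` runs."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)

# Example: time a cheap built-in call.
elapsed = time_call(sum, range(100_000))
```

Because `NJ()` removes entries from the list `S` it is given, each timed run needs fresh inputs, for example by passing a small function that reloads the matrix and then calls `NJ()`. External programs such as QuickTree and RapidNJ can be timed the same way by wrapping a `subprocess.run` call; their command-line options are not shown here and should be taken from each program's own help text.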

(If you want to investigate the running time of your program on more examples than provided in distance_matrices.zip, you are welcome to download pfam_alignments.zip. It contains 128 alignments in Stockholm-format (from the Pfam database), aligning from 58 to 71535 species, which you can convert to distance matrices in phylip-format using e.g. QuickTree. However, converting the big alignments to distance matrices would probably take too long and require too much space.)

In [7]:
uniqueDistMatrices = sorted(os.listdir('unique_distance_matrices'), key = lambda x: int(x.split('_')[0]))

In [23]:
treedata = initiate_DM(f'unique_distance_matrices/{uniqueDistMatrices[-1]}')
taxa = [i for i in treedata.keys()]

clusters = NJ(treedata, taxa)