[Drippypale](https://github.com/drippypale/ROSALIND)  
Email: drippypale@gmail.com  

Problem: **[Implement the Neighbor Joining Algorithm](https://rosalind.info/problems/ba7e/)**

In [311]:
import numpy as np

In [312]:
# Read the problem input: first line is n, followed by the n x n distance
# matrix (whitespace-separated integers, possibly spread over several lines).
with open('ba7e.in', 'r') as f:
    lines = f.read().splitlines()
    n = int(lines[0])
    dist_matrix = np.array(
        [int(token) for row in lines[1:] for token in row.split()],
        dtype=np.int64,
    ).reshape((n, n))

In [313]:
def total_distance(dist_matrix, i):
    """Return TotalDistance_D(i), the sum of row ``i`` of the distance matrix."""
    row_i = dist_matrix[i]
    return np.sum(row_i)

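The `to_Dstar()` function computes the neighbor-joining matrix $D^*$ from the distance matrix $D$, where $D^*_{i,j} = (n-2)\,D_{i,j} - \mathrm{TotalDistance}_D(i) - \mathrm{TotalDistance}_D(j)$ for $i \neq j$ (with $\mathrm{TotalDistance}_D(i)$ the sum of row $i$) and $D^*_{i,i} = 0$: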

In [314]:
def to_Dstar(dist_matrix, n):
    """Return the neighbor-joining matrix D* for the square distance matrix D.

    D*[i, j] = (n - 2) * D[i, j] - TotalDistance(i) - TotalDistance(j) for
    i != j, and D*[i, i] = 0, where TotalDistance(k) is the sum of row k.

    Parameters
    ----------
    dist_matrix : square distance matrix D.
    n : kept for backward compatibility only; it is ignored, and the size is
        always taken from ``dist_matrix.shape[0]``.

    Returns
    -------
    A float64 array with the same shape as ``dist_matrix``.
    """
    D = np.asarray(dist_matrix)
    n = D.shape[0]  # the passed-in n is deliberately ignored (original behavior)
    # Compute every row sum once instead of re-summing rows for each cell.
    totals = D.sum(axis=1)
    Dstar = ((n - 2) * D - totals[:, None] - totals[None, :]).astype(np.float64)
    # The formula does not apply on the diagonal; those entries are defined as 0.
    np.fill_diagonal(Dstar, 0.0)
    return Dstar

In [315]:
def arg_min(dist_matrix):
    """Return the ``(i, j)`` position of the smallest off-diagonal entry.

    Diagonal entries are ignored, and NaN entries are never selected. Ties are
    broken in row-major order, so the first minimum met when scanning row by
    row wins. Returns ``None`` when no off-diagonal entry is smaller than
    ``+inf`` (for example, a 1x1 matrix).
    """
    D = np.asarray(dist_matrix)
    # Candidate cells are off the diagonal and strictly below +inf
    # (NaN < inf is False, so NaN cells are excluded).
    candidates = ~np.eye(D.shape[0], D.shape[1], dtype=bool) & (D < np.inf)
    if not candidates.any():
        return None
    best = D[candidates].min()
    # Take the first candidate cell, in row-major order, that holds the minimum.
    flat_index = int(np.flatnonzero(candidates & (D == best))[0])
    i, j = divmod(flat_index, D.shape[1])
    return i, j

In [316]:
def delta(dist_matrix, i, j, n):
    """Return Delta_{i,j} = (TotalDistance(i) - TotalDistance(j)) / (n - 2).

    This is the asymmetry term used to split D[i, j] into the two limb
    lengths of neighbors i and j. Requires n > 2.
    """
    spread = total_distance(dist_matrix, i) - total_distance(dist_matrix, j)
    return spread / (n - 2)

In [317]:
def limb_length(dist_matrix, i, j, n):
    """Return ``(LimbLength(i), LimbLength(j))`` for the neighbor pair i, j.

    LimbLength(i) = (D[i, j] + Delta_{i,j}) / 2
    LimbLength(j) = (D[i, j] - Delta_{i,j}) / 2
    """
    shift = delta(dist_matrix, i, j, n)
    d_ij = dist_matrix[i, j]
    limb_i = (d_ij + shift) / 2
    limb_j = (d_ij - shift) / 2
    return limb_i, limb_j

In [318]:
def join_i_j(dist_matrix, i, j):
    """Return D' in which rows/columns i and j are replaced by one new node.

    The surviving rows/columns keep their original relative order, and the
    new node is appended as the last row/column. Its distance to every
    surviving node k is (D[i, k] + D[j, k] - D[i, j]) / 2, and its distance to
    itself is 0. The result is always float64.
    """
    D = np.asarray(dist_matrix)
    keep = np.delete(np.arange(D.shape[0]), [i, j])  # surviving indices, in order
    # Distances from the new node to each surviving node, computed in one step.
    new_distances = (D[i, keep] + D[j, keep] - D[i, j]) / 2

    size = keep.size
    Dprime = np.zeros((size + 1, size + 1))
    Dprime[:size, :size] = D[np.ix_(keep, keep)]
    Dprime[:size, size] = new_distances  # new last column
    Dprime[size, :size] = new_distances  # new last row (diagonal stays 0)
    return Dprime


In [319]:
remaining_nodes = list(range(n))  # leaf labels 0..n-1, one per matrix row

In [320]:
def neighbor_joining(D, n, remaining_nodes: list, m):
    """Return the neighbor-joining tree for distance matrix ``D``.

    Parameters
    ----------
    D : square distance matrix; row/column k belongs to ``remaining_nodes[k]``.
    n : number of rows/columns of ``D``.
    remaining_nodes : node labels for the rows of ``D``. This list is not
        modified.
    m : label given to the first internal node created; later internal
        nodes get ``m + 1``, ``m + 2``, and so on.

    Returns
    -------
    dict mapping each node label to ``{neighbor_label: edge_weight}``. Every
    edge is stored in both directions. For n < 2 there are no edges: the
    result is ``{label: {}}`` for a single node, or ``{}`` for none.
    """
    # Work on a copy so the caller's list is left untouched.
    labels = list(remaining_nodes)
    if n < 2:
        return {label: {} for label in labels[:n]}

    tree = {}

    def add_edge(u, v, weight):
        # Record the directed half-edge u -> v; callers add both directions.
        tree.setdefault(u, {})[v] = weight

    # Iterate instead of recursing so large n cannot hit the recursion limit.
    while n > 2:
        Dstar = to_Dstar(D, n)
        i, j = arg_min(Dstar)
        limb_i, limb_j = limb_length(D, i, j, n)
        D = join_i_j(D, i, j)

        ii, jj = labels[i], labels[j]
        # join_i_j keeps the surviving rows in order and appends the new node
        # last, so update the labels the same way (by index, not by value).
        labels = [label for k, label in enumerate(labels) if k not in (i, j)] + [m]

        add_edge(ii, m, limb_i)
        add_edge(m, ii, limb_i)
        add_edge(jj, m, limb_j)
        add_edge(m, jj, limb_j)

        m += 1
        n -= 1

    # Two nodes remain: connect them directly.
    add_edge(labels[0], labels[1], D[0, 1])
    add_edge(labels[1], labels[0], D[1, 0])
    return tree


In [321]:
# Internal nodes are numbered starting right after the n leaves (n, n+1, ...).
T = neighbor_joining(D=dist_matrix, n=n, remaining_nodes=remaining_nodes, m=n)

In [322]:
# One "u->v:weight" line per directed edge, sorted by source and then by target.
edge_lines = [
    f'{i}->{j}:{T[i][j]:.3f}\n'
    for i in sorted(T.keys())
    for j in sorted(T[i].keys())
]
with open('ba7e.out', 'w') as f:
    f.writelines(edge_lines)