# ERaBLE

## 0. Introduction

The problem addressed by Binet _et al._ (2016) is central to phylogenomics and involves the estimation of species trees based on a large set of gene trees. However they do not address the topological aspect of this problem but focus on how to get meaningful branch length estimates for the species tree based on a set of homologous genomic features.

The problem is an equality constrained continuous optimization problem. The objective is a weighted least squares criterion (WLS) and the problem translates to a classic quadratic programming problem that can be solved using Lagrange multipliers. Least squares problems are known to be convex. The problem can be stated as:

** Minimize: **
$$
    Q(\hat{\alpha},\hat{b}) = \sum\limits_{k=1}^m\sum\limits_{\{i,j\} \subset L_k} w_{ij}^{(k)}(\hat{\alpha}_k\delta_{ij}-\hat{d}_{ij})^2  
$$

** subject to: **
$$
    \sum\limits_{k=1}^mZ_k\hat{\alpha}_k = \sum\limits_{k=1}^{m}Z_k
$$

Where $L_k$ is the set of taxa for which sequences of $G_k$ (gene $k$) are available, $\delta_{ij}^{(k)}$ the distance between taxon $i$ and $j$ for gene $k$ (from input distance matrix), $\hat{b}$ the estimated branch length, $\hat{\alpha}_k$ the scale factor, with $\hat{\alpha}_k = 1/\hat{r}_k$ with $\hat{r}_k$ the evolutionary rate of $G_k$.

This is the general form of the constraint, ERaBLE chooses $Z_k = N_k\sum\limits_{i,j \in L_k}\delta_{ij}^{(k)}$
with $N_k$ the alignment length of $G_k$.
    

The problem can however be rephrased in matrix notation (Additional File 1 of Binet _et al._ 2016). Let $\hat{\alpha} = (\hat{\alpha}_1, \hat{\alpha}_2, \dots, \hat{\alpha}_m)^T$ and similarly $\hat{b} = (\hat{b}_1,\hat{b}_2, \dots, \hat{b}_\tau)^T$ with $\tau = |E(\mathcal{T})|$ where $\mathcal{T}$ denotes the tree topology graph. Let $\delta_k$ be the vector with all pairwise distances $\delta_{ij}^{(k)}$ for $G_k$. Similarly, let $\hat{d}$ be the vector with all additive distances from $\hat{b}$ and let $\hat{d}_k$ denote the vector of all additive distances for the taxa included in $L_k$.

We now define the topological matrix $A$ that represents the topology $\mathcal{T}$ as the $n(n-1)/2 \times \tau$ matrix where each row corresponds to a pair of taxa and each column to a branch in $\mathcal{T}$. Let $a_{ij,e}$ be the element for pair of taxa $i$ and $j$ for branch $e$, $a_{ij,e}$ is set to $1$ if $e$ is on the path between $i$ and $j$, and to $0$ otherwise. Similar as with the definition of $\hat{d}_k$, let $A_k$ be the $|L_k|(|L_k|-1)/2 \times \tau$ matrix obtained by removing all pairs of taxa that are not present in in $L_k$. Now $\hat{d} = A\hat{b}$ and $\hat{d}_k = A_k\hat{b}$.

Let $W_k$, the weight matrix, be the square matrix with as diagonal elements $w_{ij}^{(k)}$ and all other elements set to $0$. Finally let $z = (Z_1,Z_2,\dots,Z_m)^T$ and $Z=\sum\limits_{k=1}^{m}Z_k$.

The WLS problem can now be expressed as:

** Minimize: **
$$
    Q(\hat{\alpha},\hat{b}) = \sum\limits_{k=1}^m(\hat{\alpha}_k\delta_k - A_k\hat{b})^TW_k(\hat{\alpha}_k\delta_k - A_k\hat{b})
$$

** Subject to: **
$$
    z^T\hat{\alpha} = Z
$$

This problem can be solved using the method of Lagrange multipliers. The Lagrangian function is:

$$\mathcal{L}(\hat{\alpha},\hat{b},\lambda) = Q(\hat{\alpha},\hat{b})+\lambda(z^T\hat{\alpha}-Z)$$

A necessary and sufficient condition for the tuple $(\hat{\alpha}, \hat{b})$ to be an optimal solution for the above problem is that the gradient of the Lagrangian is zero _i.e._ $\nabla_{\hat{\alpha},\hat{b},\lambda}\mathcal{L}(\hat{\alpha}, \hat{b}, \lambda) = 0$, note that this is sufficient since the WLS problem is convex. We get the system of equations:

$$
    \frac{\partial\mathcal{L}}{\partial\hat{\alpha}_k} = 0
$$
$$
    \frac{\partial\mathcal{L}}{\partial\hat{b}} = 0
$$
$$
    \frac{\partial\mathcal{L}}{\partial\lambda} = 0
$$

Which results in:

$$
    \hat{\alpha}_k\delta_k^TW_k\delta_k - \delta^TW_kA_kb + \lambda z^T/2 = 0
$$
$$
    \sum\limits_{k=1}^{m}(A_k^TW_kA_k\hat{b} - \hat{\alpha}_kA_k^tW_k\delta_k) = 0
$$
$$
    z^T\hat{\alpha} -Z = 0
$$

Here the first equation is a system of $m$ equations (with $m$ the number of genes), the second a system of $\tau$ equations (with $\tau$ the number ofb branches) and the last consist of one equation. This can be rewritten as:

$$
    D\hat{\alpha} + B^T\hat{b} + \lambda z = 0
$$
$$
    B\hat{\alpha} + C\hat{b} = 0
$$
$$
    z^T\hat{\alpha}=Z
$$
With $D = \delta_k^TW_k\delta_k$, $B^T = - \delta^T_kW_kA_k$ and $C = \sum\limits_{k=1}^mA^T_kW_kA_k$. The $1/2$ coeffici&euml;nt was dropped, as we are not interested in the value of the Lagrange multiplier. This system can be solved naively by matrix multiplication in $O(mn^4)$ time. Note that $Z_k = N_k\sum\limits_{i,j \in L_k}\delta_{ij}^{(k)}$ with $N_k$ the alignment length of $G_k$

## 1. Calculation of distances

Some functions to get distance matrices from a bunch of multiple sequence alignments.

In [2]:
from numpy import inf
from heapq import heappop, heappush, heappushpop, heapify
import numpy as np
import pandas as pd
import os
import re
import itertools
from ete3 import Tree

In [4]:
species = {
    'arath': re.compile('AT[^R].+'),
    'ambtr': re.compile('ATR.+'),
    'chlre': re.compile('CR.+'),
    'phypa': re.compile('PP.+'),
    'poptr': re.compile('PT.+'),
    'orysa': re.compile('OS.+'),
    'zeama': re.compile('ZM.+'),
    'zosma': re.compile('Zosma.+'),
    'ulvmu': re.compile('evm.+'),
    'vitvi': re.compile('VV.+'),
    'thepa': re.compile('TP.+'),
    'bradi': re.compile('BD.+'),
    'spipo': re.compile('Spipo.+'),
    'caros': re.compile('Caros.+'),
    'ostta': re.compile('OT.+')
}

def proportion_different_sites(s1, s2):
    """
    Calculate the proportion of different sites between pairwise sequences
    :param s1:
    :param s2:
    :return:
    """
    p = 0
    for i in range(len(s1)):
        if s1[i] != s2[i]:
            p += 1
    return p/len(s1)


def jukes_cantor(s1, s2):
    """
    Implementation of the JC69 model (Jukes & Cantor, 1969) which assumes
    that every nucleotide has the same instantaneous rate of changing into another
    nucleotide (an assumption that is almost never valid).
    :return:
    """
    p = proportion_different_sites(s1, s2)
    d = -(3/4)*np.log(1-(4/3)*p)
    var_d = (p*(1-p)/len(s1))*(1/(1-4*p/3))
    return d, var_d


def distance_matrix(msa, distance='JC69'):
    """
    Function to get a distance matrix from a multiple sequence alignment
    :param msa: dictionary with aligned sequences
    :param distance: distance metric to use (currently only JC69 supported)
    :return: distance matrix (pandas data frame)
    """
    l = list(msa.keys())
    df = pd.DataFrame(np.zeros((len(l),len(l))), index=l, columns=l)
    
    for i in range(len(list(df.index))):
        gene_1 = df.index[i]
        
        for j in range(i+1,len(list(df.index))):
            gene_2 = df.index[j]
            
            if distance == 'JC69':
                d = jukes_cantor(msa[gene_1], msa[gene_2])
                df[gene_1][gene_2] = d[0]
                df[gene_2][gene_1] = d[0]
    
    return df


def read_msa(msa):
    """
    Read multiple sequence alignment (MSA) in PHYLIP format.
    :param msa: MSA file
    :return: dictionary
    """
    msa_dict = {}
    
    with open(msa, 'r') as f:
        content = f.readlines()
        
    n_seq, length = int(content[0].split()[0].strip()), int(content[0].split()[1].strip())

    for i in range(1,n_seq*2,2):
        msa_dict[content[i].strip()] = content[i+1].strip()
            
    return msa_dict


def get_species(gene, species):
    """
    Get species for a gene using regex matches.
    """
    for sp, p in species.items():
        if p.match(gene):
            return sp
    raise ValueError("species not found for gene {}!".format(gene))
    

def collapse_on_species(distance_matrix, species):
    """
    Collapse a given distance matrix for a gene family on species level.
    Distances for duplicates within a species are averaged.
    """
    matrix = pd.DataFrame()
    sp_set = set()
    to_drop = []
    for gene in distance_matrix.index:
        sp = get_species(gene, species)
        if sp in sp_set:
            matrix[sp] += distance_matrix[gene]
            matrix[sp] /= 2
            to_drop.append(gene)
        else:
            matrix[sp] = distance_matrix[gene]
            sp_set.add(sp)
        
    matrix = matrix.drop(to_drop)
    matrix.index = matrix.columns
    
    for k in range(len(matrix.index)):
        matrix.iloc[k,k] = 0
    return matrix

## 2. Helper functions for ERaBLE

In [5]:
def topology_matrix_dijkstra(tree, species):
    """
    Input: newick tree
    Output: topology matrix
    """
    t = Tree(tree)
    i = len(species)
    for node in t.traverse():
        if node.name in species:
            node.name = species[node.name]
        else:
            node.name = i
            i+=1
    
    graph = tree_to_adj_list(t)
    leaves = list(t.get_leaves())
    paths = {}
    branches = set()
    for i in range(len(leaves)):
        node1 = leaves[i]
        for j in range(i+1, len(leaves)):
            node2 = leaves[j]
            path, distance = dijkstra(graph, node1.name, node2.name)
            path = [(path[x], path[x+1]) for x in range(len(path)-1)]
            paths[tuple(sorted([node1.name, node2.name]))] = path
            for p in path:
                if p not in branches:
                    branches.add(p)
            
    # put in a matrix
    df = pd.DataFrame(index=list(paths.keys()), columns=list(branches))
    for tup, path in paths.items():
        for branch in path:
            df[branch][tup] = 1
    df = df.fillna(0)
    df.sort_index(0, inplace=True)
    df.sort_index(1, inplace=True)
    return df, t


def dijkstra(graph, source, sink=None):
    """
    Implementation of Dijkstra's shortest path algorithm
    Inputs:
        - graph : dict representing the weighted graph
        - source : the source node
        - sink : the sink node (optional)
    Ouput:
        - distance : dict with the distances of the nodes to the source
        - came_from : dict with for each node the came_from node in the shortest
                    path from the source
    """
    distance = {v: inf for v in graph}
    distance[source] = 0
    current = source
    previous = {}
    Q = [(source, 0)]
    
    while Q:
        U = heappop(Q)
        if sink and U == sink:
            break
        for V in graph[U[0]]:
            alt = distance[U[0]] + V[0]
            if alt < distance[V[1]]:
                distance[V[1]] = alt
                previous[V[1]] = U[0]
                heappush(Q, (V[1], alt))
                
    if sink is None:
        return distance, previous
    else:
        return reconstruct_path(previous, source, sink), distance[sink]

    
def reconstruct_path(previous, source, sink):
    """
    Reconstruct the path from the output of the Dijkstra algorithm
    Inputs:
            - previous : a dict with the came_from node in the path
            - source : the source node
            - sink : the sink node
    Ouput:
            - the shortest path from source to sink (list)
    """
    if sink not in previous:
        return []
    
    V = sink
    path = [V]
    while V != source:
        V = previous[V]
        path.append(V)
    return path


def tree_to_adj_list(tree):
    """
    Convert ete3 Tree object to an adjacency list representing the tree graph
    """
    adj_list = {}
    for node in tree.traverse('postorder'):
        l = []
        if not node.is_leaf():
            # add children 
            l += [(1, c.name) for c in node.children]
        if not node.is_root():
            # add parent
            l.append((1, node.up.name))
        adj_list[node.name] = set(l)
    return adj_list


def code_species(species):
    """
    Code species to integers
    """
    if type(species) == dict:
        species = list(species.keys())
    species_to_int = {x: 0 for x in species}
    i = 0
    for x in sorted(species):
        species_to_int[x] = i
        i += 1
    return species_to_int


def sub_topology_matrix(subset, matrix):
    """
    Get topology matrix for a subset of taxa
    """
    to_drop = []
    for pair in matrix.index:
        if pair[0] not in subset or pair[1] not in subset:
            to_drop.append(pair)
    matrix = matrix.drop(to_drop)
    return matrix


def sequence_length(msa_file):
    """
    Construct a weight matrix. Using alignment length as weight.
    """
    with open(msa_file, 'r') as f:
        content = f.readlines()
        
    length = int(content[0].split()[1].strip())
    return length


def weight_matrix(length, s):
    """
    Construct a weight matrix. Using alignment length as weight.
    """
    return np.diag([length for i in range(s)])


def vectorize_delta(delta):
    """
    Convert distance matrix to vector
    """
    d = np.array(delta)
    l = []
    for i in range(0,len(d)):
        for j in range(i+1, len(d)):
            l.append(d[i,j]) 
    return np.array(l)

## 3. ERaBLE implementation

In [13]:
def erable(tree_file, species, msa_files=None, distance_matrices=None, lengths=None):
    if not msa_files and not distance_matrices:
        raise ValueError("Provide either distance matrices or MSA files")
        
    species_mapping = code_species(species)
    
    A, t = topology_matrix_dijkstra(tree_file, species_mapping)
    
    # construct C
    C = np.zeros((A.shape[1], A.shape[1]))
    weights = []
    Aks = []
    z = []
    msas = []
    Nks = []
    delta_mats = []

    if not msa_files:
        msa_files = distance_matrices
        
    for k in range(len(msa_files)):
        if not distance_matrices:
            msa = read_msa(msa_files[k])
            delta_matrix = collapse_on_species(distance_matrix(msa), species)
        else:
            delta_matrix = distance_matrices[k]
        delta_matrix.index = [species_mapping[x] for x in delta_matrix.index]
        delta_matrix.columns = list(delta_matrix.index)
        delta_matrix.sort_index(0, inplace=True)
        delta_matrix.sort_index(1, inplace=True)

        if delta_matrix is not None and len(delta_matrix) > 1:
            delta_mats.append(delta_matrix)
            msas.append(msa_files[k])
        
    D = np.zeros((len(delta_mats), len(delta_mats)))
    B = np.zeros((A.shape[1], len(delta_mats)))
        
    for k in range(len(delta_mats)):
        delta = vectorize_delta(delta_mats[k]).reshape(-1,1)
        Ak = sub_topology_matrix(set(delta_mats[k].index), A)
        if lengths:
            Nk = lengths[k]
            Nks.append(Nk)
        else:
            Nk = sequence_length(msas[k])
            Nks.append(Nk)
        Wk = weight_matrix(Nk,len(delta))
        C += Ak.T @ Wk @ Ak
        weights.append(Wk)
        Aks.append(Ak)
        D[k,k] = delta.T @ Wk @ delta
        B[:,k] = list(-Ak.T @ Wk @ delta)
        Zk = Nk * delta.sum()
        z.append(Zk)
    
    Z = sum(z)
    
    return D, B, z, C, Z
    
    # invert D, diagonal matrix so inverse is just diagonal elements^-1
    D_inv = 1/D
    D_inv[np.isinf(D_inv)] = 0
    
    z = np.array(z).reshape((-1,1))
    u = B @ D_inv @ z
    w = (z.reshape((1,-1)) @ D_inv @ z)[0][0]
    M = C + (1/w) * u @ u.T - B @ D_inv @ B.T
    
    #return M, Z, w, u, z, B, D_inv, A, Nks
    #return M, Z, w, u
    b = np.linalg.solve(M, -(Z/w)*u)
    alpha = D_inv @ ((z.reshape((-1,1)) @ u.reshape((1,-1)))/w - B.T) @ b + (Z/w)*D_inv @ z
    #alpha = D_inv @ ((z.reshape((1,-1)) @ u.reshape((-1,1)))/w - B.T) @ b + (Z/w)*D_inv @ z
    #alpha[alpha == 0] = 0.0001 # this shouldn't be necessary
    c = (1/sum(Nks)) * sum([Nks[i]/alpha[i] for i in range(len(Nks))])
    
    r = 1/(c*alpha)
    b = c * b

    return(b, r, c, t)

## 4. Test

In [7]:
msa_files = [os.path.join('subset/', x) for x in sorted(os.listdir('subset/'))]
b, r, c, tree = erable('tree.nwck', species, msa_files=msa_files)
#M, Z, w, u, z, B, D_inv, A, Nks = erable(msa_files, 'tree.nwck', species)

In [8]:
b

array([[  3.28167584e+01],
       [  1.10677638e+01],
       [  3.92583444e+00],
       [  8.22513533e+00],
       [  1.34470477e+03],
       [  3.91750038e+00],
       [  3.22303845e+02],
       [  5.25676344e+00],
       [  1.03627275e+01],
       [  1.10656255e+01],
       [  1.34466238e+03],
       [  3.77135929e+00],
       [  1.43907678e+00],
       [  1.03796727e+01],
       [  9.15167735e+02],
       [ -1.45907169e+02],
       [ -1.13572470e+02],
       [ -3.21474737e+02],
       [ -1.09077822e+03],
       [ -3.23445998e+01],
       [  1.13737427e+02],
       [ -2.37942439e+01],
       [ -1.34423909e+03],
       [ -1.02246936e+03],
       [  2.36863848e+01],
       [  8.20169979e+00],
       [ -5.61197928e-01],
       [ -8.13629047e+00],
       [ -6.94833054e+00],
       [ -1.55796848e+01],
       [ -7.82233675e+00],
       [  9.13570774e-01],
       [ -4.51979344e+00],
       [ -9.88476068e+00],
       [ -9.86404353e+00],
       [  6.93193704e+00],
       [ -1.21898092e+00],
 

## 5. Example from ERaBLE

In [15]:
mats = []
for f in os.listdir('../../erable1.0_Unix_Linux/Example/mats'):
    m = pd.read_csv(os.path.join('../../erable1.0_Unix_Linux/Example/mats', f), sep="\s+", index_col=0, header=None)
    m.columns = m.index
    mats.append(m)

lengths = [382, 210, 590]
tree_file = '../../erable1.0_Unix_Linux/Example/input_tree_topology.nwk'
tree_s = '((Tax10:0.51144,(Tax34:0.19411,(Tax20:0.19064,Tax4:0.15564)1:0.09995)1:0.09856),(Tax14:0.10233,Tax3:0.03293)1:0.01914);'
species = [n.name for n in Tree(tree_s).get_leaves()]

In [16]:
#erable(tree_file=tree_s, species=species, distance_matrices=mats, lengths=lengths)
D, B, z, C, Z = erable(tree_file=tree_s, species=species, distance_matrices=mats, lengths=lengths)

In [80]:
def solve_constrained_quadratic_problem(P, q, A, b):
    """
    Solve a linear constrained quadratic convex problem.
    Inputs:
        - P, q: quadratic and linear parameters of
                the linear function to be minimized
        - A, b: system of the linear constaints
    Outputs:
        - xstar: the exact minimizer
        - vstar: the optimal Lagrance multipliers
    """
    #p, n = A.shape  # size of the problem
    n = A.shape[1]
    A = np.bmat([[P,A.T],[A,np.zeros((A.shape[0],A.shape[0]))]])
    B = np.bmat([[-q],[b]])
    solution = np.linalg.solve(A,B)
    xstar = solution[:n]
    vstar = solution[n:]
    return np.array(xstar), np.array(vstar)

In [82]:
b = np.array([0.5]*C.shape[0])
z = np.array(z).reshape(-1,1)
alpha, l = solve_constrained_quadratic_problem(
    D, (-B.T @ b).reshape(-1,1), z.T, np.array(Z).reshape(1,1)
)

In [76]:
solve_constrained_quadratic_problem()

matrix([[ 2030.378323],
        [  676.73865 ],
        [ 6059.52066 ],
        [ 4028.844772]])