# NWCK: Distances in Trees

In [128]:
# Rosalind NWCK inputs: the sample dataset and the graded dataset.
TEST_INPUT = '/home/hanuman/docs/biomatics/rosalind/NWCK/input-test.txt'
FINAL_INPUT = '/home/hanuman/docs/biomatics/rosalind/NWCK/input-final.txt'
# The graded dataset is the active input; swap to TEST_INPUT to check the sample.
filename = FINAL_INPUT

In [12]:
def read_newick(filename):

    """Yield ``(tree, pair)`` records from a NWCK-style input file.

    The file holds repeated records, each made of a Newick tree line followed
    by a line of node names; records are conventionally separated by a blank
    line. Blank lines are skipped rather than counted, so missing or extra
    blank lines (including trailing ones) cannot shift a tree line into the
    node-name slot. A final tree line with no node-name line after it is
    ignored.

    Args:
        filename: path of the text file to read.

    Yields:
        tuple: ``(tree, pair)`` where ``tree`` is the tree line with
        surrounding whitespace and ``;`` characters removed, and ``pair`` is
        the node-name line split on whitespace (a list of str).
    """

    with open(filename, 'rt') as f:
        stripped = (line.strip() for line in f)
        content = (line for line in stripped if line)
        for tree_line in content:
            pair_line = next(content, None)
            if pair_line is None:
                # Dangling tree without a node-name line: nothing to pair it with.
                break
            yield (tree_line.strip(';').strip(), pair_line.split())

In [125]:
def split_newick(newick_string):

    """Tokenize a Newick string; return the tokens reversed with parentheses mirrored.

    The string is split into node names and the structural symbols ``(``,
    ``)`` and ``,``. Surrounding whitespace is stripped from names, and
    whitespace-only fragments (e.g. the space in ``"(a, (b,c))"``) are dropped
    instead of becoming empty-string tokens. The token list is then reversed,
    and every ``(`` / ``)`` is swapped so that it stays balanced. After this,
    each internal node's name comes immediately before the ``(`` that opens its
    children, e.g. ``"(cat)dog"`` -> ``['dog', '(', 'cat', ')']``.

    Args:
        newick_string: a Newick tree, presumably with the trailing ``;``
            already removed (a ``;`` would otherwise end up inside a name).

    Returns:
        list of str: the reversed, parenthesis-mirrored token list.
    """

    mirror = {'(': ')', ')': '('}
    tokens = []
    name = ''
    for char in newick_string:
        if char in {'(', ')', ','}:
            if name.strip():
                tokens.append(name.strip())
            name = ''
            tokens.append(char)
        else:
            name += char
    if name.strip():
        tokens.append(name.strip())
    return [mirror.get(tok, tok) for tok in reversed(tokens)]

def label_newick_split(newick_split):
    """Give every unnamed node in a reversed Newick token list a synthetic name.

    Expects the reversed, parenthesis-mirrored tokens from ``split_newick``,
    in which an internal node's name sits right before its ``(``. Names
    ``x0``, ``x1``, ... are inserted for:

    * an empty leaf right after a ``(`` or a ``,`` (the next token is ``)``
      or ``,``);
    * an unnamed internal node: a ``(`` at the very start, or one directly
      preceded by ``,`` or ``(``.

    NOTE(review): the synthetic names could collide with a real label such as
    ``x0`` — confirm inputs never use that pattern.

    The list is modified in place and is also returned.
    """
    tokens = newick_split
    next_id = 0

    # Visit '(' positions right to left: an insert at or after the current
    # position never shifts the positions still waiting to be visited.
    opens = [pos for pos, tok in enumerate(tokens) if tok == '(']
    for pos in reversed(opens):
        if tokens[pos + 1] in (')', ','):
            tokens.insert(pos + 1, 'x' + str(next_id))
            next_id += 1
        if pos == 0 or tokens[pos - 1] in (',', '('):
            tokens.insert(pos, 'x' + str(next_id))
            next_id += 1

    # Comma positions are collected only now, after the inserts above moved them.
    commas = [pos for pos, tok in enumerate(tokens) if tok == ',']
    for pos in reversed(commas):
        if tokens[pos + 1] in (')', ','):
            tokens.insert(pos + 1, 'x' + str(next_id))
            next_id += 1
    return tokens

def process_newick_string(newick_string):
    """Tokenize a Newick string and name every unnamed node.

    Chains ``split_newick`` (reverse + mirror tokens) into
    ``label_newick_split`` (insert ``x<k>`` labels). The result is the token
    list that ``build_tree`` consumes.
    """
    return label_newick_split(split_newick(newick_string))

def break_comma_list(newick_split):
    """Split a token list at its top-level commas, yielding each piece as a list.

    Commas nested inside parentheses are left in place, so each yielded piece
    is one complete sibling subtree. An input without top-level commas
    (including ``[]``) yields exactly one piece. The input list is not
    modified, and a token that merely looks special (e.g. a node named ``*``)
    is never mistaken for a separator.

    Args:
        newick_split: a token list, e.g. the children section of a processed
            Newick list (between the parent's ``(`` and ``)``).

    Yields:
        list: consecutive slices of ``newick_split`` between top-level commas.
    """
    # Collect all split points before yielding anything, so the pieces do not
    # depend on what the consumer does to the list in the meantime.
    cuts = [-1]
    depth = 0
    for i, token in enumerate(newick_split):
        if token == '(':
            depth += 1
        elif token == ')':
            depth -= 1
        elif token == ',' and depth == 0:
            cuts.append(i)
    cuts.append(len(newick_split))
    for lo, hi in zip(cuts, cuts[1:]):
        yield newick_split[lo + 1:hi]

## General Implementation of a Tree (viewed as an undirected graph for path search)

In [80]:
class Node(object):
    """A labelled vertex of a rooted tree.

    ``label`` names the node, ``parent`` is the node it hangs from (``None``
    until ``add_child`` attaches it), and ``children`` is the set of nodes
    directly below it.
    """

    def __init__(self, label):
        self.label = label
        self.parent = None  # filled in by the parent's add_child
        self.children = set()

    def add_child(self, node):
        """Register ``node`` as a child of this node and make this node its parent."""
        kids = self.children
        kids.add(node)
        node.parent = self

class Tree(object):
    """A rooted tree of node objects, queried by node label.

    Nodes are expected to expose ``label`` and an iterable ``children`` (as
    ``Node`` does). Labels are used as dictionary keys. NOTE(review): labels
    are assumed unique — nodes sharing a label collide in the maps below.
    """

    def __init__(self, root):
        self.root = root  # may be None for an empty tree

    def tree_dict(self):
        """Return ``{label: set of child labels}`` for every node, built breadth-first.

        An empty tree (``root is None``) yields ``{}``.
        """
        from collections import deque

        mydict = {}
        if self.root is None:
            return mydict
        queue = deque([self.root])
        while queue:
            node = queue.popleft()
            mydict[node.label] = {child.label for child in node.children}
            queue.extend(node.children)
        return mydict

    def graph_dict(self):
        """Return the undirected adjacency map ``{label: set of neighbour labels}``."""
        mydict = self.tree_dict()
        for key, children in mydict.items():
            # Iterate a snapshot: with colliding labels the set being read
            # could otherwise be the one being extended.
            for child in tuple(children):
                mydict[child].add(key)
        return mydict

    def find_all_paths(self, start_label, end_label, path=None):
        """Return every simple path from ``start_label`` to ``end_label``.

        Args:
            start_label: label where the paths start.
            end_label: label where the paths end.
            path: labels already on the path; they are not revisited.
                Defaults to an empty path.

        Returns:
            list of lists of labels (empty when ``end_label`` is unreachable).

        Raises:
            KeyError: if ``start_label`` is not in the tree and differs from
                ``end_label``.
        """
        # Build the adjacency map once instead of on every recursive step.
        graph = self.graph_dict()
        return self._paths(graph, start_label, end_label, list(path or []))

    def _paths(self, graph, start_label, end_label, path):
        """Depth-first enumeration of simple paths over a prebuilt adjacency map."""
        path = path + [start_label]
        if start_label == end_label:
            return [path]
        paths = []
        for node in graph[start_label]:
            if node not in path:
                paths.extend(self._paths(graph, node, end_label, path))
        return paths

    def distance(self, start_label, end_label):
        """Return the number of edges between two labelled nodes.

        Between two nodes of a tree the path is unique, so a breadth-first
        search from ``start_label`` finds it in linear time.

        Raises:
            KeyError: if ``start_label`` is not in the tree (and differs from
                ``end_label``).
            IndexError: if ``end_label`` cannot be reached. This is the same
                exception type the earlier ``paths.pop()`` implementation raised.
        """
        from collections import deque

        if start_label == end_label:
            return 0
        graph = self.graph_dict()
        seen = {start_label}
        queue = deque([(start_label, 0)])
        while queue:
            label, dist = queue.popleft()
            for neighbour in graph[label]:
                if neighbour == end_label:
                    return dist + 1
                if neighbour not in seen:
                    seen.add(neighbour)
                    queue.append((neighbour, dist + 1))
        raise IndexError('no path from %r to %r' % (start_label, end_label))

def build_tree(newick_list):

    """Build a Node hierarchy from a processed Newick token list; return its root.

    The input is expected to come from ``process_newick_string``. It is either
    a single label (a node with no children), or a root label followed by
    ``(``, the comma-separated child subtrees, and a closing ``)``. The input
    list is not modified.

    Args:
        newick_list: processed token list.

    Returns:
        Node or None: the root node, or ``None`` for an empty list.

    Raises:
        ValueError: if a multi-token list is not shaped ``label ( ... )``.
    """

    if len(newick_list) == 0:
        return None
    root = Node(newick_list[0])
    if len(newick_list) == 1:
        # A lone label: a leaf, or a single-node tree such as "dog;".
        return root
    if newick_list[1] != '(' or newick_list[-1] != ')':
        raise ValueError('malformed processed Newick list: %r' % (newick_list,))
    # The tokens between the root's '(' and ')' are its child subtrees.
    for item in break_comma_list(newick_list[2:-1]):
        root.add_child(build_tree(item))
    return root

In [129]:
# Solve every (tree, node pair) record of the input file and print the
# distances space-separated, in input order.
results = []
for newick, nodes in read_newick(filename):
    start_label, end_label = nodes[0], nodes[1]
    tree = Tree(build_tree(process_newick_string(newick)))
    results.append(str(tree.distance(start_label, end_label)))
print(' '.join(results))

22 14 2 29 2 15 22 14 34 2 24 4 24 2 2 2 57 2 15 22 78 12 5 21 13 2 8 30 9 13 5 15
