In [None]:
import os
import argparse
import re
import itertools
import random
import math
#import numpy as np
#import pandas as pd

# Set up constants

In [None]:
DNA = ('A', 'C', 'G', 'T')
RNA = ('A', 'C', 'G', 'U')

MAP_seq2NUMBER = {'A':0, 'C':1, 'G':2, 'T':3}
MAP_RNA2NUMBER = {'A':0, 'C':1, 'G':2, 'U':3}

MAP_NUMBER2DNA = {0:'A', 1:'C', 2:'G', 3:'T'}
MAP_NUMBER2RNA = {0:'A', 1:'C', 2:'G', 3:'U'}

MAP_CODON2AMINOACID = {'GUC': 'V', 'ACC': 'T', 'GUA': 'V', 'GUG': 'V', 'GUU': 'V', 'AAC': 'N', 'CCU': 'P', 'UGG': 'W',
                       'AGC': 'S', 'AUC': 'I', 'CAU': 'H', 'AAU': 'N', 'AGU': 'S', 'ACU': 'T', 'CAC': 'H', 'ACG': 'T',
                       'CCG': 'P', 'CCA': 'P', 'ACA': 'T', 'CCC': 'P', 'GGU': 'G', 'UCU': 'S', 'GCG': 'A', 'UGC': 'C',
                       'CAG': 'Q', 'GAU': 'D', 'UAU': 'Y', 'CGG': 'R', 'UCG': 'S', 'AGG': 'R', 'GGG': 'G', 'UCC': 'S',
                       'UCA': 'S', 'GAG': 'E', 'GGA': 'G', 'UAC': 'Y', 'GAC': 'D', 'GAA': 'E', 'AUA': 'I', 'GCA': 'A',
                       'CUU': 'L', 'GGC': 'G', 'AUG': 'M', 'CUG': 'L', 'CUC': 'L', 'AGA': 'R', 'CUA': 'L', 'GCC': 'A',
                       'AAA': 'K', 'AAG': 'K', 'CAA': 'Q', 'UUU': 'F', 'CGU': 'R', 'CGA': 'R', 'GCU': 'A', 'UGU': 'C',
                       'AUU': 'I', 'UUG': 'L', 'UUA': 'L', 'CGC': 'R', 'UUC': 'F', 'UAA': 'X', 'UAG': 'X', 'UGA': 'X'}

ALPHANUMERIC = [chr(i) for i in range(48, 58)] + [chr(i) for i in range(65, 91)] + [chr(i) for i in range(97, 123)]
WHITESPACE = re.compile(r'\s*')

# Chapter 1

In [None]:
def get_kmer_freq(sequence, kmer):
    """
    PatternCount
    
    Count the number of times <kmer> appears in <sequence>.
    """
    count = 0
    # Slide window
    for i in range(len(sequence) - len(kmer) + 1):
        # Check match
        if sequence[i:i + len(kmer)] == kmer:
            # Update count
            count += 1
    return count


def convert_nucleotide_to_number(nucleotide, nucleotide_type='DNA'):
    """
    PatternToNumber
    
    Use lexicographical order to turn <nucleotide> into a number.
    Kmer represents a 4-ary number.
    
    Removing the final symbol from all lexicographically ordered k-mers maintians the resulting list to be still ordered lexicographically.    
    """
    if len(nucleotide) == 0:
        return 0
    
    if nucleotide_type == 'DNA':
        return 4 * convert_nucleotide_to_number(nucleotide[:-1]) + MAP_seq2NUMBER[nucleotide[-1]]
    elif nucleotide_type == 'RNA':
        return 4 * convert_nucleotide_to_number(nucleotide[:-1], nucleotide_type='RNA') + MAP_RNA2NUMBER[nucleotide[-1]]

    
def convert_number_to_nucleotide(number, k, nucleotide_type='DNA'):
    """
    NumberToPattern
    
    Use lexicographical order to turn <number> into a nucleotide of length <k>.
    """
    if k == 1:
        if nucleotide_type == 'DNA':
            return MAP_NUMBER2DNA[number]
        elif nucleotide_type == 'RNA':
            return MAP_NUMBER2RNA[number]
    
    prefix_index = int(number / 4)
    if nucleotide_type == 'DNA':
        return convert_number_to_nucleotide(prefix_index, k-1) + MAP_NUMBER2DNA[number % 4]
    elif nucleotide_type == 'RNA':
        return convert_number_to_nucleotide(prefix_index, k-1, nucleotide_type='RNA') + MAP_NUMBER2DNA[number % 4]

    
def make_kmer_freq_array(sequence, k):
    """
    ComputingFrequencies
    
    Make a frequency array holding all <k>mers (length 4^k),
    where the ith element of the array holds the number of times
    that the i-th k-mer appears (in the lexicographic order) in <sequence>.
    param sequence: string
    param k: int
    return: numpy array
    """
    # Make an empty array of length 4^k
    freq_array = [0 for i in range(4 ** k)]#np.zeros(4**k)
    
    # Slide over <sequence>
    for i in range(len(sequence) - k + 1):
        # For each kmer, get its number representation <n> and increment <freq_array>[<n>]
        n = convert_nucleotide_to_number(sequence[i:i+k])
        freq_array[n] += 1
    return freq_array


def get_most_freq_kmers(sequence, k, algorithm='freq_array'):
    """
    FrequentWords & FreqArrayFrequentWords & FreqArraySortFrequentWords
    
    Find the most frequenct <k>mer in <sequence>.
    """
    most_freq_kmers = []
    max_freq = 0
    length = len(sequence)
    
    if algorithm == 'freq_array':
        # Make frequency array
        freq_array = make_kmer_freq_array(sequence, k)
        for i, freq in enumerate(freq_array):
            if freq > max_freq:
                most_freq_kmers = [convert_number_to_nucleotide(i, k)]
                max_freq = freq
            elif freq == max_freq:
                most_freq_kmers.append(convert_number_to_nucleotide(i, k))
    
    elif algorithm == 'brute_force':
        # Slide window
        for i in range(len(sequence) - k + 1):
            kmer = sequence[i:i + k]
            # Get frequency of this kmer
            freq = get_kmer_freq(sequence, kmer)
            if freq > max_freq:
                most_freq_kmers = [kmer]
                max_freq = freq
            elif freq == max_freq: # Tie
                most_freq_kmers.append(kmer)
        
    elif algorithm == 'sort':
        # Lexicographical order of kmers
        kmer_lex_idx = [(None, 0) for i in range(length - k + 1)]
        kmer_count = {}
        for i in range(length - k + 1):
            kmer = sequence[i:i + k]
            kmer_lex_idx[i] = kmer, convert_nucleotide_to_number(kmer)
            kmer_count[kmer] = 1

        # Sort kmers by lexicographical order
        print('kmer_lex_idx: before sort', kmer_lex_idx)        
        kmer_lex_idx = sorted(kmer_lex_idx, key=lambda x: x[1])
        print('kmer_lex_idx: after sort', kmer_lex_idx)

        for i in range(1, len(kmer_lex_idx)):
            prev_t = kmer_lex_idx[i - 1]
            t = kmer_lex_idx[i]
            if prev_t[1] == t[1]:
                kmer_count[t[0]] = kmer_count[prev_t[0]]  + 1

        for k, v in kmer_count.items():
            if v > max_freq:
                most_freq_kmers = [k]
                max_freq = v
            elif v == max_freq:
                most_freq_kmers.append(k)
                
    return set(most_freq_kmers), max_freq


def reverse_complement(dna_sequence):
    """
    ReverseComplement
    
    Return the reverse complement of <dna_sequence>.
    """
    complements = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N': 'N', '': ''}
    return ''.join([complements[x] for x in dna_sequence[::-1]])


def get_hamming_distance_between_kmers(kmer1, kmer2):
    """
    HammingDistance
    
    Get hamming distance between <kmer1> and <kmer2>.
    """
    assert len(kmer1) == len(kmer2), 'Length of 2 kmers must be the same'
    hamming_distance = 0
    for i, char1 in enumerate(kmer1):
        if char1 != kmer2[i]:
            hamming_distance += 1
    return hamming_distance


def get_kmer_occurence(kmer, sequence, thres=0):
    """
    PatternMatch & ApproximatePatternMatch
    
    Get all indices where <subseq> apprears in <seq>.
    """
    kmer_length = len(kmer)
    assert kmer_length <= len(sequence), 'kmer must be shorter than sequence'
    occurences = []
    for i in range(len(sequence) - kmer_length + 1):
        if thres > 0 and get_hamming_distance_between_kmers(sequence[i:i + kmer_length], kmer) <= thres:
            occurences.append(i)
        elif sequence[i: i + kmer_length] == kmer:
            occurences.append(i)
    return occurences


def count_approximate_kmer_occurence(kmer, sequence, d):
    """
    ApproximatePatternCount
    
    Count the occurences of <kmer> in <sequence>, including mers with at most d mismatches with <kmer>.
    """
    return len(get_kmer_occurence(kmer, sequence, thres=d))


def get_clumping_kmers(seq, k, clump_length, thres):
    """
    ClumpFinding
    
    Find clumping kmers (kmers that appears more than <thres> times within a subsequence of length <clump_length>) in <seq>.
    """
    kmers = [0 for i in range(4 ** k)]#np.zeros(4**k)

    # Slide window in which to evaluate clumps
    for i in range(len(seq) - clump_length + 1):
        # 1st window
        if i == 0:
            # Make the 1st frequency array
            freq_array = make_kmer_freq_array(seq[i:clump_length], k)
            # Record kmers that clump in this window
            for j, freq in enumerate(freq_array):
                if freq >= thres:
                    kmers[j] = 1
        else:
            # Decrement for the kmer just passed (1st kmer of the last window)
            freq_array[convert_nucleotide_to_number(seq[i - 1:i - 1 + k])] -= 1
            # Increment for the kmer just encountered (last kmer of this window)
            last_kmer = seq[i + clump_length - k:i + clump_length]
            last_kmer_freq_array_idx = convert_nucleotide_to_number(last_kmer)            
            freq_array[last_kmer_freq_array_idx] += 1
            
            # Check if the kmer just encountered (last kmer of this window) passes <thres>
            if freq_array[last_kmer_freq_array_idx] >= thres:
                kmers[last_kmer_freq_array_idx] += 1
    
    return set([convert_number_to_nucleotide(i, k) for i, clump_count in enumerate(kmers) if clump_count > 0])


def make_cumulative_freq_difference_array(seq, kmer1, kmer2):
    """
    Skew
    
    Return an array of length |<seq>|, whose i-th element is the
    cumulative frequency difference of <kmer1> and <kmer2> up to i-th index.
    """
    assert len(kmer1) == len(kmer2)
    k = len(kmer1)
    length = len(seq)
    
    cumulative_freq_difference_array = [0 for i in range(length - k + 1)] #np.empty(length - k + 1)
    num_kmer1 = num_kmer2 = 0
    # Slide window
    for i in range(length - k + 1):
        if seq[i:i + k] == kmer1:
            num_kmer1 += 1
        elif seq[i:i + k] == kmer2:
            num_kmer2 += 1
        cumulative_freq_difference_array[i] = (num_kmer1 - num_kmer2)
    return cumulative_freq_difference_array


def get_cumulative_kmer_freq_difference_array_min(seq, kmer1, kmer2):
    """
    Get the index with the minimum of cumulative kmer frequency difference array, whose ith element is the
    cumulative frequency difference of <kmer1> and <kmer2> up to i.
    """
    cumulative_freq_difference_array = make_cumulative_freq_difference_array(seq, kmer1, kmer2)
    #return np.where(cumulative_freq_difference_array==cumulative_freq_difference_array.min())
    min_ = 0
    min_indices = [0]
    for i, diff in enumerate(cumulative_freq_difference_array):
        if diff < min_:
            min_indices = [i + 1]
            min_ = diff
        elif diff == min_:
            min_indices.append(i + 1)
    return min_indices


def get_sequences_with_atmost_d_mismatch(sequence, d, nucleotide_type=DNA, include_self=True):
    """
    Neighbors
    
    Ruturn all sequences with <d> mismatches from <sequence>.
    """
    # No mismatch
    if d == 0:
        return sequence
    
    # Base case with 1 nucleotide
    if len(sequence) == 1:
        return set(nucleotide_type)

    atmost_d_mismatch_seqs = set()
    
    '''
    # Get sequences with at most <d> mismatches from suffix
    first_nucleotide = sequence[0]
    suffix = sequence[1:]
    suffix_atmost_d_mismatch_seqs = get_sequences_with_atmost_d_mismatch(suffix, d)
    for seq in suffix_atmost_d_mismatch_seqs:
        if get_hamming_distance_between_kmers(suffix, seq) < d:
            # Can add 1 or more mismatches
            for s in [BASE + seq for BASE in nucleotide_type]:
                atmost_d_mismatch_seqs.add(s)
                
        else:
            atmost_d_mismatch_seqs.add(first_nucleotide + seq)
    return atmost_d_mismatch_seqs
    '''
    
    # Get <d> combinations of indecises at which sequence can differ
    for differing_i in itertools.combinations(range(len(sequence)), d):
        #print('differing_i:', differing_i)
             
        # Break apart <sequence> (to use itertools.product later)
        broken_apart_seq = [[letter] for letter in sequence]
        #print('broken_apart_seq:', broken_apart_seq)
    
        # Substitute [A, C, G, T] at these <differing_i>
        for idx in differing_i:
            broken_apart_seq[idx] = nucleotide_type
        #print('broken_apart_seq: after substitution:', broken_apart_seq)
        
        # Cross multiply and add resulting kmers with atmost d mismatches to <atmost_d_mismatch_seqs>
        for seq in [''.join(i) for i in [c for c in itertools.product(*broken_apart_seq)]]:
            atmost_d_mismatch_seqs.add(seq)

    if include_self:
        atmost_d_mismatch_seqs.add(sequence)
        
    return atmost_d_mismatch_seqs


def get_keys_with_max_value(dictionary):
    """
    Get <dictionary>'s keys holding the max value.
    """
    #sorted(kmer_to_freq.items(), key=lambda kv: kv[1])
    keys = set()
    max_value = max(dictionary.values())
    for k, v in dictionary.items():
        if v == max_value:
            keys.add(k)
    return keys


def get_most_frequent_kmer_and_reverse_complement_with_d_mismatch(seq, k, d, choices=['A', 'C', 'G', 'T'], search_reverse_complement=True):
    """
    FrequentWordsWithMismatchesAndReverseComplements
    
    Get the most frequent kmer whose <d>-mismatch sequences appear most frequently in <seq>.
    """
    # Dictionary with kmer key and d-mismatch occurance value
    kmer_to_freq = {}
    
    # Look at each kmer, sliding the window
    for i in range(len(seq) - k + 1):
        # For each similar kmer with at most <d> mismatches
        for s in get_sequences_with_d_mismatch(seq[i:i + k], d):
            if s in kmer_to_freq:
                kmer_to_freq[s] += 1
            else:
                kmer_to_freq[s] = 1
        if search_reverse_complement:
            for s in get_sequences_with_d_mismatch(reverse_complement(seq[i:i + k]), d):
                if s in kmer_to_freq:
                    kmer_to_freq[s] += 1
                else:
                    kmer_to_freq[s] = 1
    #print(kmer_to_freq)
    return get_keys_with_max_value(kmer_to_freq)


def find_oric_and_dnaa_box(genome, k, d):
    """
    Find OriC and DnaA box in a genome.
    """
    oric_loci = get_cumulative_kmer_freq_difference_array_min(g, 'G', 'C')
    ave_min_ = sum(oric_loci) / len(oric_loci)
    replication_start_window = g[ave_min_ - 500:ave_min_ + 500]
    kmers = get_most_frequent_kmer_and_reverse_complement_with_d_mismatch(replication_start_window, k, d)
    return oric_loci, kmers

# Chapter 2

In [None]:
def read_mtrx_from_lines(lines):
    mtrx = [[] for i in range(len(lines))]
    for i, line in enumerate(lines):
        mtrx[i] = [float(j) for j in line.split()]
    #print('mtrx:', mtrx)
    return mtrx


def get_kmers_from_sequence(sequence, k, return_set=False):
    """
    Get all kmers in a <sequence>.
    """
    assert k <= len(sequence), 'k must be less than or equal to sequence length'
    kmers = [sequence[i:i + k] for i in range(len(sequence) - k + 1)]
    
    if return_set:
        return set(kmers)
    else:
        return kmers


def get_kmers_from_sequences(sequences, k, return_set=False):
    """
    Get all kmers in all <sequences>.
    """
    kmers = []
    for sequence in sequences:
        kmers.extend(get_kmers_from_sequence(sequence, k, return_set=return_set))

    if return_set:
        return set(kmers)
    else:
        return kmers


def find_motifs(sequences, k, d, algorithm='brute'):
    """
    MOTIFENUMERATION
    
    Find <k>mer motifs whose <d>m-ismatch kmers appear in all <sequences>.
    """
    motifs = []
    
    # For all kmers in <sequences>
    for kmer in get_kmers_from_sequences(sequences, k):
        
        # For each neighbor kmer (kmer with atmost d-mismatch)
        for kmer_with_atmost_d_mismatch in get_sequences_with_d_mismatch(kmer, d, include_self=True):

            # Flag indicating whether a kmer is in each sequence
            found = [0 for i in range(len(sequences))]#np.zeros(len(sequences))
            # Check if this neighboring kmer is in all <sequences>
            for j, sequence in enumerate(sequences):
                for i in range(len(sequence) - k + 1):
                    if get_hamming_distance_between_kmers(sequence[i:i + k], kmer_with_atmost_d_mismatch) <= d:
                        # Mark as present
                        found[j] = 1
                        # Check the next sequence
                        break
            # Add this neghbor kmer to motifs only if it is present in all <sequences>
            if all(found):
                motifs.append(kmer_with_atmost_d_mismatch)
                
    return set(motifs)


def make_kmers(k, nucleotide_type='DNA'):
    """
    Make all possible nucleotide <k>mers.
    """
    if nucleotide_type == 'DNA':
        return [''.join(product) for product in itertools.product(DNA, repeat=k)]
    elif nucleotide_type == 'RNA':
        return [''.join(product) for product in itertools.product(RNA, repeat=k)]
    
    
def get_distance_between_kmer_and_sequence(kmer, sequence):
    """
    Get the minimum distance between <kmer> and a kmer in <sequence>.
    """
    # Initialize the minimum distance to be as long as possible (length of the kmer)
    min_distance = len(kmer)
    # For each kmer in this sequence
    for i in range(len(sequence) - len(kmer) + 1):
        # Computer the distance between <kmer> and this kmer being scanned
        distance = get_hamming_distance_between_kmers(kmer, sequence[i:i + len(kmer)])
        # If the computed distance is closer than the current minimum distance,
        if distance < min_distance:
            # update the minumun distance to be the current distance
            min_distance = distance
    return min_distance


def get_distance_between_kmer_and_sequences(kmer, sequences):
    """
    Get the sum of the minimum distances between <kmer> and a kmer in each sequence in <sequence>.
    """
    sum_distance = 0
    # For each sequence
    for sequence in sequences:
        # Sum the distance between <kmer> and this sequence
        sum_distance += get_distance_between_kmer_and_sequence(kmer, sequence)
    return sum_distance


def get_kmer_closest_to_sequences(sequences, k):
    """
    MEDIANSTRING
    
    Return a <k>mer whose sum of the distance to each sequence in <sequences> is the minimum.
    """
    # Make all possible kmers
    kmers = make_kmers(k)
    # Find the mediam string by minimizing the sum of the distances between a kmer and each sequence
    min_ = k * len(sequences)
    closest_kmers = []
    for kmer in kmers:
        distance_between_kmer_and_sequences = get_distance_between_kmer_and_sequences(kmer, sequences)
        if distance_between_kmer_and_sequences < min_:
            min_ = distance_between_kmer_and_sequences
            closest_kmers.append(kmer)
    return closest_kmers, min_


def is_uniform_list_of_lists(list_of_lists):
    """
    Return True if all list in <list_of_lists> have the same length.
    """
    list_length = len(list_of_lists[0])
    for a_list in list_of_lists:
        if len(a_list) != list_length:
            return False
    return True


def make_mtrx_acgt_x_idx(sequence_motif_matrix):
    """
    Make nucleotide x kmer index matrix.
    """
    assert is_uniform_list_of_lists(sequence_motif_matrix)
    seq_length = len(sequence_motif_matrix[0])
    
    mtrx_acgt_x_idx = [[0 for j in range(seq_length)] for i in DNA]
    #print('mtrx_acgt_x_idx:', mtrx_acgt_x_idx)
    
    for seq in sequence_motif_matrix:
        for j, n in enumerate(seq):
            mtrx_acgt_x_idx[MAP_seq2NUMBER.get(n)][j] += 1
    #print('mtrx_acgt_x_idx:', mtrx_acgt_x_idx)
    return mtrx_acgt_x_idx


def make_probability_mtrx_acgt_x_idx(sequence_motif_matrix):
    """
    PROFILE
    
    Create a DNA x index probability matrix from list of <sequence_motif_matrix>.
    """
    assert is_uniform_list_of_lists(sequence_motif_matrix)
    seq_length = len(sequence_motif_matrix[0])
    
    mtrx_acgt_x_idx = make_mtrx_acgt_x_idx(sequence_motif_matrix)
    mtrx_prob_acgt_x_idx = [[0 for j in range(seq_length)] for i in DNA]

    for j in range(seq_length):
        for i in range(len(DNA)):
            mtrx_prob_acgt_x_idx[i][j] = (mtrx_acgt_x_idx[i][j] + 1) / (sum([mtrx_acgt_x_idx[ii][j] for ii in range(len(DNA))]) + 4)
    #print('mtrx_prob_acgt_x_idx:', mtrx_prob_acgt_x_idx)
    assert is_probability_mtrx(mtrx_prob_acgt_x_idx)
    
    return mtrx_prob_acgt_x_idx


def is_probability_mtrx(probability_mtrx):
    """
    Check if a probability matrix is a list of list and each column sums up to 1.
    """
    assert is_uniform_list_of_lists(probability_mtrx)
    row_lenth = len(probability_mtrx)
    col_lenth = len(probability_mtrx[0])
    for j in range(col_lenth):
        #print('probability column: ', [probability_mtrx[i][j] for i in range(row_lenth)])
        assert 0.95 < sum([probability_mtrx[i][j] for i in range(row_lenth)]) and sum([probability_mtrx[i][j] for i in range(row_lenth)]) < 1.05, 'error at col %s, each probability column must sum up to 1' % j

    return True


def get_entropy(probability_list):
    """
    Calculate entropy of a list of probabilities.
    """
    return -1 * sum([p * math.log2(p) for p in probability_list if p > 0])


def get_probability_mtrx_entropy(probability_mtrx):
    """
    Calculate column entropy of probability matrix.
    """
    assert is_probability_mtrx(probability_mtrx)
    col_lenth = len(probability_mtrx)
    row_length = len(probability_mtrx[0])
    
    col_entropy = [None for j in range(row_length)]
    for j in range(row_length):
        col = [probability_mtrx[i][j] for i in range(col_lenth)]
        col_entropy[j] = get_entropy(col)
    return col_entropy


def get_sequence_motif_mtrx_entropy(sequence_motif_matrix):
    """
    Calculate column entropy of a <sequence_motif_matrix>.
    """
    assert is_uniform_list_of_lists(sequence_motif_matrix)
    
    prob_mtrx = make_probability_mtrx_acgt_x_idx(sequence_motif_matrix)
    return sum(get_probability_mtrx_entropy(prob_mtrx))


def get_mtrx_max_by_row(mtrx):
    """
    Get the max values and their row indices of <mtrx> columns.
    """
    assert is_uniform_list_of_lists(mtrx)
    row_lenth = len(mtrx)
    col_lenth = len(mtrx[0])
    
    mtrx_max_by_row = [() for i in range(col_lenth)]
    for j in range(col_lenth):
        col = [mtrx[i][j] for i in range(row_lenth)]
        max_ = max(col)
        max_idx = col.index(max_)
        mtrx_max_by_row[j] = (max_idx, max_)
    return mtrx_max_by_row


def make_kmer_most_probable_from_probability_mtrx(probability_mtrx_acgt_x_idx):
    """
    Make the most probable kmer from <probability_mtrx_acgt_x_idx>.
    """
    assert is_probability_mtrx(probability_mtrx_acgt_x_idx)
    col_lenth = len(probability_mtrx_acgt_x_idx[0])

    most_probable_kmer = ''
    mtrx_max_by_row = get_mtrx_max_by_row(probability_mtrx_acgt_x_idx)
    for i, v in mtrx_max_by_row:
        most_probable_kmer += MAP_NUMBER2DNA[i]
    return most_probable_kmer


def make_kmer_most_probable_from_sequence_motif_mtrx(sequence_motif_mtrx):
    """
    Make the most probable kmer from <sequence_motif_mtrx>.
    """
    probability_mtrx_acgt_x_idx = make_probability_mtrx_acgt_x_idx(sequence_motif_mtrx)
    return make_kmer_most_probable_from_probability_mtrx(probability_mtrx_acgt_x_idx)


def product(list_of_numbers):
    prod = list_of_numbers[0]
    for n in list_of_numbers[1:]:
        prod = prod * n
    return prod


def get_kmer_probability_based_on_probability_mtrx(kmer, probability_mtrx_acgt_x_idx):
    """
    Get the probabiliry of kmer based on <probability_mtrx_acgt_x_idx>.
    """
    #print('probability_mtrx_acgt_x_idx:', probability_mtrx_acgt_x_idx)
    assert is_probability_mtrx(probability_mtrx_acgt_x_idx)
    #print('kmer:', kmer)
    assert len(kmer) == len(probability_mtrx_acgt_x_idx[0])
    
    nuc_probs = [None for i in range(len(kmer))]
    
    for j, nuc in enumerate(kmer):
        nuc_probs[j] = probability_mtrx_acgt_x_idx[MAP_seq2NUMBER[nuc]][j]
    
    return product(nuc_probs)


def get_kmers_most_probable_in_sequence_based_on_probability_mtrx(sequence, probability_mtrx_acgt_x_idx):
    """
    Get the most probable kmer in <sequence> based on <probability_mtrx_acgt_x_idx>.
    """
    max_prob = 0
    most_probable_kmers = []
    k = len(probability_mtrx_acgt_x_idx[0])
    # Slide window of length k
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i + k]
        #print('kmer:', kmer)
        prob = get_kmer_probability_based_on_probability_mtrx(kmer, probability_mtrx_acgt_x_idx)
        #print('prob:', prob)
        if max_prob < prob:
            max_prob = prob
            most_probable_kmers = [kmer]
        elif max_prob == prob:
            most_probable_kmers.append(kmer)

    return most_probable_kmers, max_prob


def get_kmers_most_probable_in_sequences_based_on_probability_mtrx(sequences, probability_mtrx_acgt_x_idx):
    """
    Get the most probable kmers from each sequence in <sequences> based on <probability_mtrx_acgt_x_idx>.
    """
    kmers = []
    for seq in sequences:
        kmers.append(get_kmers_most_probable_in_sequence_based_on_probability_mtrx(seq, probability_mtrx_acgt_x_idx))
    return kmers


def score_motifs(motifs):
    """
    Score <motifs> by 1) making the most probable kmer based on <motifs>'s probability matrix and 
    2) summing the distances between this most probable kmer and each kmer in <motifs>.
    """
    most_probable_motif = make_kmer_most_probable_from_sequence_motif_mtrx(motifs)
    return get_distance_between_kmer_and_sequences(most_probable_motif, motifs)


def search_motifs_greedly(sequences, k):
    # First best motifs are formed by first k-mers in <sequences>
    best_motifs = [seq[:k] for seq in sequences]
    
    # Slide window of size k on 1st sequence
    for i in range(len(sequences[0]) - k + 1):
        # First motif is only the 1st kmer in sequeces[0]
        cur_motifs = [sequences[0][i: i + k]]
        # For subsequent <sequences>
        for i in range(1, len(sequences)):
            # Make probabiliry matrix of the cur_motifs
            cur_prob_mtrx = make_probability_mtrx_acgt_x_idx(cur_motifs)
            # Get the best motif in this sequence based on cur_prob_mtrx
            #print('cur_prob_mtrx:', cur_prob_mtrx)
            selected_motif = get_kmers_most_probable_in_sequence_based_on_probability_mtrx(sequences[i], cur_prob_mtrx)[0][0]
            cur_motifs.append(selected_motif)
            
        # Update the best motif matrix based on scoring
        if score_motifs(cur_motifs) < score_motifs(best_motifs):
            best_motifs = cur_motifs
        
    return best_motifs


def sample_random_kmers(sequences, k):
    """
    Sample random <k>mers from <sequences>.
    """
    # All sequence must have the same length
    assert is_uniform_list_of_lists(sequences)
    
    random_kmers = []
    for sequence in sequences:
        i = random.randint(0, len(sequence) - k)
        random_kmers.append(sequence[i:i + k])
    return random_kmers


def search_motifs_randomly(sequences, k):
    cur_motifs = sample_random_kmers(sequences, k)
    #print('cur_motifs: randomly chosen', cur_motifs)
    best_motifs = cur_motifs
    while True:
        cur_prob_mtrx = make_probability_mtrx_acgt_x_idx(cur_motifs)
        cur_motifs = [t[0][0] for t in get_kmers_most_probable_in_sequences_based_on_probability_mtrx(sequences, cur_prob_mtrx)]
        #print('cur_motifs:', cur_motifs)
        # Update the best motif matrix based on scoring
        if score_motifs(cur_motifs) < score_motifs(best_motifs):
            print('score_motifs(cur_motifs)', score_motifs(cur_motifs))
            best_motifs = cur_motifs
        else:
            return best_motifs

        
def search_motifs_with_gibbs_sampling(sequences, k):
    """
    GIBBSSAMPLER(Dna, k, t, N)
    randomly select k-mers Motifs = (Motif1, …, Motift) in each string
        from Dna
    BestMotifs ← Motifs
    for j ← 1 to N
        i ← Random(t)
        Profile ← profile matrix constructed from all strings in Motifs
                   except for Motifi
        Motifi ← Profile-randomly generated k-mer in the i-th sequence
        if Score(Motifs) < Score(BestMotifs)
            BestMotifs ← Motifs
    return BestMotifs
    """

# Chapter X

In [None]:
def make_sequence_overlapping_network(sequences, nonoverlap=1, edge_score=1):
    assert is_uniform_list_sof_lists(sequences)
    assert nonoverlap > 0, 'There must be overlap'
    
    network = Network()
    for sequence in sequences:
        network.make_node({'name':sequence, 'prefix':sequence[:-nonoverlap], 'suffix':sequence[nonoverlap:]})
    
    for node_from in network.nodes:
        for node_to in remove_node_from_nodes(network.nodes, node_from):
            if node_from['suffix'] == node_to['prefix']:
                network.add_directed_edge([node_from, node_to, 1], allow_duplicate=False)
                print('make_sequence_overlapping_network: added an edge %s =(%s)=> %s' % (node_from, edge_score, node_to))
    return network


def make_debruijn_network_from_sequence(sequence, k, edge_score=1):
    kmers = get_kmers([sequence], k)
    network = Network()
    for kmer in kmers:
        node_from = kmer[:-1]
        node_to = kmer[1:]
        network.add_directed_edge([node_from, node_to, 1])
        #print('make_debruijn_network_from_sequence: added an edge %s =(%s)=> %s' % (node_from, edge_score, node_to))
    return network


def make_debruijn_network_from_kmers(kmers, edge_score=1, allow_duplicate=True):
    network = Network()
    for kmer in kmers:
        node_from = kmer[:-1]
        node_to = kmer[1:]
        network.add_directed_edge([node_from, node_to, 1], allow_duplicate=allow_duplicate)
        #print('make_debruijn_network_from_kmers: added an edge %s =(%s)=> %s' % (node_from, edge_score, node_to))
    return network


def print_debruijn_network(debruijn_network):
    debruijn_dict = {}
    
    for node in debruijn_network.nodes:
        debruijn_dict[node.get_name()] = []
    
    for node in debruijn_network.nodes:
        if node.nodes_to:
            #print(node, node.nodes_to)
            #print(node, node.nodes_to)
            #print('print_debruijn_network:', node.nodes_from, node, node.nodes_to)
            current_values = debruijn_dict.get(node.get_name())
            #print(values)
            debruijn_dict[node.get_name()] = current_values + [n.get_name() for n in node.nodes_to]

    for k, v in debruijn_dict.items():
        if v:
            print('%s -> %s' % (k, ','.join(v)))

            
def make_network_from_edges(edges):
    """
    Make a network from <edges>.
    """
    network = Network()
    for e in edges:
        if len(e) == 2:
            e.extend([1])
        print(e)
        network.add_directed_edge(e, allow_duplicate=False)
    return network


def traverse_eulerian_cycle(eulerian_network_dict):
    # Randomly choose a start_n
    start_n = random.choice(list(eulerian_network_dict.keys()))
    # Generate a cyclic path starting from start_n
    path = traverse_cycle(eulerian_network_dict, start_n)
    # While there is any unvisited edge
    while len(eulerian_network_dict) != 0:
        # Pick a start_n that has a remaining outgoing edge from cyclic path
        start_n = random.choice([n for n in path if n in eulerian_network_dict])
        # Get the index of the chosen start_n (used later for patching a full path)
        i = path.index(start_n)
        # Generate another cyclic path starting from start_n
        a_path = traverse_cycle(eulerian_network_dict, start_n)        
        path = path[i:] + path[1:i+1] + a_path[1:]
    return path

        
def traverse_cycle(network, start_n):
    # Cyclic path to return
    path = [start_n]
    while start_n in network:
        # Randomly choose a start_n's outgoing edge
        end_n = random.choice(network[start_n])
        # Remove the chosen node from start_n's outgoing edges
        network[start_n].remove(end_n)
        # Remove start_n from the network if it has no more outgoing edges
        if len(network[start_n]) == 0:
            network.pop(start_n)
        # Add end_n to path and end_n becomes the new start_n
        path.append(end_n)
        start_n = end_n
    return path


def find_eulerian_start_node(network):
    # Get all keys
    choices = network.keys()
    for c in choices:
        out_degree = len(network[c])
        in_degree = sum([1 for v in network.values() if c in v])
        if out_degree > in_degree:
            # start_n must have more out edges than in edges
            return c


def traverse_eulerian_noncycle(network):
    # Find the start node, which has more out going edge than incoming edges
    start_n = find_eulerian_start_node(network)
    
    nodes = [start_n]
    
    path = []
    while len(nodes) != 0:
        current_n = nodes[-1]
        
        if current_n not in network:
            # Last node without any outgoing edge
            path.append(nodes.pop())
            continue
        
        nodes_to = network[current_n]
        if len(nodes_to) <= 0:
            path.append(nodes.pop())
        else:
            node_to = random.choice(nodes_to)
            network[current_n].remove(node_to)
            nodes.append(node_to)
    return path[::-1]


def make_sequence_overlapping_network_dict(sequences, nonoverlap=1, edge_score=1):
    assert is_uniform_list_of_lists(sequences)
    assert nonoverlap > 0, 'There must be overlap'
    network_dict = {}
    for sequence in sequences:
        network_dict[sequence[:-1]] = [sequence[1:]]
    return network_dict


def reconstruct_string(sequences):
    network_dict = make_sequence_overlapping_network_dict(sequences)
    path = traverse_eulerian_noncycle(network_dict)
    sequence = path[0]
    for n in path[1:]:
        sequence += n[-1]
    return sequence


def make_universal_circular_string_network_dict(k):
    choices = [0, 1]
    product = [p for p in itertools.product(choices, repeat=k-1)]
    kmers = [''.join(str(x) for x in p) for p in product]
    
    kmer_dict = {}
    for kmer in kmers:
        overlap = kmer[1:]
        kmer_dict[kmer] = kmer_dict.get(kmer, []) + [overlap + '1', overlap + '0']
    return kmer_dict


def make_universal_circular_string(k):
    graph = make_universal_circular_string_network_dict(k)
    path = traverse_eulerian_cycle(graph)
    strand = ''
    while len(path) >= 2:
        edge = path[0] + path[1][-1]
        strand += edge[0]
        path = path[1:]
    return strand


def make_debruijn_network_from_paired_kmers(kmer_pairs):
    debruijn_graph = {}
    for kp1 in kmer_pairs:
        kp1_suffix = (kp1[0][1:], kp1[1][1:])
        for kp2 in kmer_pairs:
            kp2_prefix = (kp2[0][:-1], kp2[1][:-1])
            if kp1_suffix == kp2_prefix and kp1 != kp2:
                debruijn_graph[kp1] = debruijn_graph.get(kp1, []) + [kp2]
    return debruijn_graph


def reconstruct_sequence_from_paired_kmers(kmer_pairs, distance):
    debruijn_graph = make_debruijn_network_from_paired_kmers(kmer_pairs)
    print(debruijn_graph)
    k = len(kmer_pairs[0][0])
    path = traverse_eulerian_noncycle(debruijn_graph)
    print(path)
    sequence = path[0][0]
    for n in path[1:]:
        sequence += n[0][-1]
    for n in path[len(path) - k - d:]:
        sequence += n[1][-1]
    return sequence


def get_max_nonbranching_paths(graph):
    paths = []
    
    # Get in- and out-degrees for each node in De Bruijn graph
    degrees = get_node_degrees(graph)
    print('degrees:', degrees)
    
    # Contigs can be built starting from nodes with not-1 indegree or more than 1 out-degree
    starts = [n for n in graph if degrees[n][0] != 1 or degrees[n][1] > 1]
    print('starts:', starts)
    
    for n in starts:
        print('n:', n)
        
        while n in graph:
            print('while n:', n)
            
            path = n[0]
            print('path:', path)
            
            stop = False
            while not stop:
                
                # Pick a next node and get its in- and out- degrees
                next_n = random.choice(debruijn_graph[n])
                degs = degrees[next_n]
                print('next_n:', next_n, degs)
                # Then remove this picked node from current node's next nodes
                debruijn_graph[n].remove(next_n)
                # If the current node doesn't have any more out-nodes
                if len(debruijn_graph[n]) == 0:
                    # Remove from the graph
                    debruijn_graph.pop(n)
                # If the next_n is in graph, implying that it has outgoing nodes
                if next_n in debruijn_graph:
                    # If contig intermedia node
                    if degs == (1,1):
                        # Extend contig by 1
                        ctg += next_n[0]
                        print('ctg:', ctg)
                        # and update the current node
                        n = next_n
                        print('new n:', n)
                    else:
                        # If contig is the last node, fully extend contig
                        ctg += next_n
                        # This contig is done
                        stop = True
                else:
                    # If the contig is not in the graph, implying that the node doens't have outgoing ndoes
                    ctg += next_n
                    # This contig is done
                    stop = True
                print('ctg:', ctg)
                
            # Add the node to the contigs
            contigs.append(ctg)
    return sorted(contigs)


def extend_nonbranching_path(graph, start, ignore):
    print('ignore:', ignore)
    path = [start]
    
    # Base case: start node has no outgoing edge or has multiple outgoing edges
    if start not in graph or len(graph[start]) != 1:
        return path
    
    # Recursive case when the out-degree is 1
    
    # Ignore
    unvisited_nodes = list(set(graph[start]) - set(ignore))
    
    # Pick the next node
    next_ = unvisited_nodes[0]
    print('extend_nonbranching_path: next_:', next_)
    ignore.append(next_)
          
    return path + extend_nonbranching_path(graph, next_, ignore)


def get_nonbranching_path(graph, start, ignore=[]):
    assert start in graph, '%s is either not a key in the graph (not a node in the graph or has 0 outgoing edge)'
    
    # Ignore
    unvisited_nodes = list(set(graph[start]) - set(ignore))

    if len(unvisited_nodes) > 1:
        # Start has multiple outgoing edges, so pick the first branch to extend
        path = [start]
        next_ =  unvisited_nodes[0]
        return path + extend_nonbranching_path(graph, next_, ignore=ignore)
    else:
        return extend_nonbranching_path(graph, start, ignore=ignore)

    
def get_nonbranching_paths(graph):
    degrees = get_node_degrees(graph)
    

def change_money(money, coins):
    min_coin = [9999999] * (money + 1)
    min_coin[0] = 0
    #print(min_coin, len(min_coin))
    for m in range(len(min_coin))[1:]:
        #print('m:', m)
        for c in coins:
            if c <= m:
                #print('c:', c)
                if min_coin[m - c] + 1 < min_coin[m]:
                    min_coin[m] = min_coin[m - c] + 1
                    #print(min_coin)
    return min_coin[money]


def MANHATTANTOURIST(D, R):
    n = len(R)
    m = len(D[0])
    print('n, m:', n, m)

    # Make an empty n x m matrix
    M = [[0] * m for i in range(n)]
    print('M:\n', np.matrix(M))
    
    # Update the 1st column
    for i, r in enumerate(M):
        if i == 0:
            continue
        r[0] = M[i - 1][0] + D[i - 1][0]
    print('M:\n', np.matrix(M))
    
    # Update the 1st row
    for i, v in enumerate(M[0]):
        if i == 0:
            continue
        M[0][i] = M[0][i-1] + R[0][i-1]
    print('M:\n', np.matrix(M))
    
    for i in range(n)[1:]:
        for j in range(m)[1:]:
            down = M[i - 1][j] + D[i - 1][j]
            right = M[i][j - 1] + R[i][j - 1]
            #print('i, j:', i, j)
            #print('down, right:', down, right)
            #print('down, right:', down, right)
            M[i][j] = max(down, right)
            #print('Updated M[%s][%s]:' % (i, j), M[i][j])

    print('M:\n', np.matrix(M))

    
def get_longest_common_subsequence(seq1, seq2):
    print('seq1', seq1)
    print('seq2', seq2)
    n = len(seq1)
    m = len(seq2)
    print('n, m:', n, m)
    
    # Make a similarity matrix
    S = [[0] * m for i in range(n)]
    for i, b1 in enumerate(seq1):
        for j, b2 in enumerate(seq2):
            if b1 == b2:
                S[i][j] = 1
    print('S:\n', np.matrix(S))
    
    # Make an empty n x m matrix
    M = [[0] * (m + 1) for i in range(n + 1)]
    
    for i in range(n):
        for j in range(m):
            M[i + 1][j + 1] = max(M[i][j] + S[i][j], M[i][j + 1], M[i + 1][j])
    print('M:\n', np.matrix(M))
    
    longest_commond_seq = []
    
    Mi, Mj = n, m
    while Mi > 0 and Mj > 0:
        print('current:', Mi, Mj, M[Mi][Mj], seq1[Mi - 1])

        if M[Mi][Mj - 1] == M[Mi][Mj]:
            Mj = Mj - 1
        elif M[Mi - 1][Mj] == M[Mi][Mj]:
            Mi = Mi - 1
        else:
            Mi, Mj = Mi - 1, Mj - 1
            print('match:', Mi, Mj, M[Mi][Mj], seq1[Mi])
            longest_commond_seq.append(seq1[Mi])

    return ''.join(reversed(longest_commond_seq))


def sort_topologically(graph):
    """
    Find nodes without any incoming edges and remove them from the graph.
    The topological order is this order of discovery.
    """
    # Copy graph
    graph = set(graph)
    
    # Get nodes without any incoming edge
    no_inedge = list({edge[0] for edge in graph} - {edge[1] for edge in graph})
    
    # Recursive find
    order = []
    while len(no_inedge) != 0:
        current = no_inedge.pop(0)
        order.append(current)
        
        # For all edges starting form <current>
        nodes_to = []
        for edge in [edge for edge in graph if edge[0] == current]:

            # Add the <node_to> to <nodes_to>
            node_to = edge[1]
            
            # Remove this outgoing <edge> from the <graph>
            graph.remove(edge)
        
            # If <node_to> doesn't have any incoming edges, add to <no_inedge> and order
            if node_to not in {edge[1] for edge in graph}:
                no_inedge.append(node_to)
        
    return order


def get_longest_path_dag(edge_scores, edges, source, sink):

    # Sort DAG graph topologically
    order = sort_topologically(edge_scores.keys())
    # Extract sequence from source to sink
    order = order[order.index(source) + 1:order.index(sink) + 1]
    #print('order:', order)

    # Itinialize backtrack; all nodes' prevs are None
    backtrack = {n: None for n in order}
    
    # Initialize scores; all nodes' scores are big negative
    scores = {n: -999 for n in {e[0] for e in edge_scores.keys()} | {e[1] for e in edge_scores.keys()}}
    
    # Set source's score to be 0
    scores[source] = 0

    # Update the score in topological order
    for n in order:
        try:
            edges_in = filter(lambda e: e[1] == n, edge_scores.keys())
            get_node_from = lambda score_nodein: score_nodein[0]
            scores[n], backtrack[n] = max(map(lambda e: [scores[e[0]] + edge_scores[e], e[0]], edges_in), key=get_node_from)
        except ValueError:
            pass

    # Backtracks
    path = [sink]
    while path[0] != source:
        path = [backtrack[path[0]]] + path

    return scores[sink], path


def get_node_degrees(graph):
    nodes_to = list(itertools.chain.from_iterable(graph.values()))
    degrees = {}
    for n in graph:
        degrees[n] = [nodes_to.count(n), len(graph[n])]
    # Nodes without any outgoing edge
    for n in set(n for n in nodes_to if n not in graph):
        degrees[n] = [nodes_to.count(n), 0]
    return degrees


def remove_self_edge(graph):
    for n, to_nodes in graph.items():
        graph[n] = [to for to in to_nodes if to != n]
    return graph


def backtrack(state, node, source):
    track = [node]
    
    # Base case
    if node == source:
        return track
    else:
        return backtrack(state, state[node][1], source) + track
    

def get_incoming_nodes(graph):

    ns = set()
    for from_n, to_ns in graph.items():
        ns.add(from_n)
        for to in to_ns:
            ns.add(to)
            
    #print('sorted(ns)', sorted(ns))
    
    n_ins = {}
    # For all nodes
    for n in ns:
        n_ins[n] = []
        
        for from_n, to_ns in graph.items():
            # If from_n visits n, then add from_n to n_ins[n]
            if n in to_ns:

                if from_n not in n_ins[n]:
                    n_ins[n].append(from_n)
    return n_ins


def is_uniform_list_of_lists(list_of_lists):
    """
    Return True if all list in <list_of_lists> have the same length.
    """
    list_length = len(list_of_lists[0])
    for a_list in list_of_lists:
        if len(a_list) != list_length:
            return False
    return True


def init_global_alignment_direction_matrices(seq1, seq2, indel_penalty):
    n = len(seq1)
    m = len(seq2)
    print('init_global_alignment_direction_matrices: n, m:', n, m)
    
    # Initialize the alignment and direction matrices, taking into consideration initial InDels
    alignment_matrix = [[0] * (m + 1) for i in range(n + 1)]
    direction_matrix = [[0] * (m + 1) for i in range(n + 1)]

    # Update column 1
    for i in range(n + 1):
        alignment_matrix[i][0] = indel_penalty * i
        if i != 0:
            direction_matrix[i][0] = (i - 1, 0)
    # Update row 1
    for j in range(m + 1):
        alignment_matrix[0][j] = indel_penalty * j
        if j != 0:
            direction_matrix[0][j] = (0, j - 1)
        
    print('init_global_alignment_direction_matrices: alignment_matrix:\n', pd.DataFrame(alignment_matrix))
    print('init_global_alignment_direction_matrices: direction_matrix:\n', pd.DataFrame(direction_matrix))
    return alignment_matrix, direction_matrix


def init_local_alignment_direction_matrices(seq1, seq2):
    n = len(seq1)
    m = len(seq2)
    print('init_local_alignment_direction_matrices: n, m:', n, m)
    
    # Initialize the alignment and direction matrices, taking into consideration initial InDels
    alignment_matrix = [[0] * (m + 1) for i in range(n + 1)]
    direction_matrix = [[0] * (m + 1) for i in range(n + 1)]
    
    # Update column 1
    for i in range(n + 1)[1:]:
        direction_matrix[i][0] = (0, 0)
    # Update row 1
    for j in range(m + 1)[1:]:
        direction_matrix[0][j] = (0, 0)
        
    print('init_local_alignment_direction_matrices: alignment_matrix:\n', pd.DataFrame(alignment_matrix))
    print('init_local_alignment_direction_matrices: direction_matrix:\n', pd.DataFrame(direction_matrix))
    return alignment_matrix, direction_matrix


def backtrack_alignment(seq1, seq2, direction_matrix, init_ij=None, and_or='or'):
    # Backtrack
    alignment = []
    if init_ij:
        print('backtrack_alignment: initialization coordinate:', init_ij)
        i, j = init_ij
    else:
        print('backtrack_alignment: no initialization coordinate:')
        i = len(seq1)
        j = len(seq2)
           
    if and_or == 'or':
        while i > 0 or j > 0:
            toi, toj = direction_matrix[i][j]
            #print('align_globally: i, j, toi, toj:', i, j, toi, toj)

            # For local alignment termination
            if toi == 0 and toj == 0:
                break    

            if toi == i - 1 and toj == j - 1:         
                #print('align_globally: match or mismatch', seq1[toi], seq2[toj])
                alignment.append((seq1[toi], seq2[toj]))
            elif toi == i - 1 and toj == j:         
                #print('align_globally: deletion', seq1[toi], '-')
                alignment.append((seq1[toi], '-'))
            elif toi == i and toj == j - 1:         
                #print('align_globally: insertion', '-', seq2[toj])
                alignment.append(('-', seq2[toj]))
            else:
                raise ValueError('direction error')

            i, j = toi, toj
    else:
        while i > 0 and j > 0:
            toi, toj = direction_matrix[i][j]
            #print('align_globally: i, j, toi, toj:', i, j, toi, toj)

            # For local alignment termination
            if toi == 0 and toj == 0:
                break    

            if toi == i - 1 and toj == j - 1:         
                #print('align_globally: match or mismatch', seq1[toi], seq2[toj])
                alignment.append((seq1[toi], seq2[toj]))
            elif toi == i - 1 and toj == j:         
                #print('align_globally: deletion', seq1[toi], '-')
                alignment.append((seq1[toi], '-'))
            elif toi == i and toj == j - 1:         
                #print('align_globally: insertion', '-', seq2[toj])
                alignment.append(('-', seq2[toj]))
            else:
                raise ValueError('direction error')

            i, j = toi, toj
        
    return alignment


def align_globally(seq1, seq2, scoring_matrix, labels, indel_penalty):
    alignment_matrix, direction_matrix = init_global_alignment_direction_matrices(seq1, seq2, indel_penalty)
    
    # Dynamically update alignment_matrix
    for ridx, row in enumerate(alignment_matrix[1:]): 
        # Row element
        relem = seq1[ridx]
        rsidx = labels.index(relem)
        assert relem == labels[rsidx]
        #print('align_globally: ridx, row, relem, rsidx:', ridx, row, relem, rsidx)
        
        for cidx, col in enumerate(row[1:]):
            # Column element
            celem = seq2[cidx]
            csidx = labels.index(celem)
            assert celem == labels[csidx]
            #print('align_globally: cidx, col, celem, csidx:', cidx, col, celem, csidx)
            
            # Score
            smatch = alignment_matrix[ridx][cidx] + scoring_matrix[rsidx][csidx], (ridx, cidx)
            sin = alignment_matrix[ridx][cidx + 1] - indel_penalty, (ridx, cidx + 1)
            sdel = alignment_matrix[ridx + 1][cidx] - indel_penalty, (ridx + 1, cidx)
            # Dynamically update alignment and direction matrices
            max_ = max([smatch, sin, sdel], key=lambda t: t[0])
            alignment_matrix[ridx + 1][cidx + 1] = max_[0]
            direction_matrix[ridx + 1][cidx + 1] = max_[1]
    
    #print('align_globally: alignment_matrix\n', pd.DataFrame(alignment_matrix, index=['0'] + [s for s in seq1], columns=['0'] + [s for s in seq2]))
    #print('align_globally: direction_matrix\n', pd.DataFrame(direction_matrix, index=['0'] + [s for s in seq1], columns=['0'] + [s for s in seq2]))

    alignment = backtrack_alignment(seq1, seq2, direction_matrix)
    #print('align_globally: final alignment:', alignment_matrix[len(seq1)][len(seq2)], alignment)
    return alignment_matrix[len(seq1)][len(seq2)], alignment


def align_locally(seq1, seq2, scoring_matrix, labels, indel_penalty):
    alignment_matrix, direction_matrix = init_local_alignment_direction_matrices(seq1, seq2, indel_penalty)
    

    # Dynamically update the alignment_matrix
    source = 0, (0, 0)
    # Keep track of the coordinate with the maximum score
    global_max = (source)
    for ridx, row in enumerate(alignment_matrix[1:]): 
        # Row element
        relem = seq1[ridx]
        rsidx = labels.index(relem)
        assert relem == labels[rsidx]
        #print('align_locally: ridx, row, relem, rsidx:', ridx, row, relem, rsidx)

        for cidx, col in enumerate(row[1:]):
            # Column element
            celem = seq2[cidx]
            csidx = labels.index(celem)
            assert celem == labels[csidx]
            #print('align_locally: cidx, col, celem, csidx:', cidx, col, celem, csidx)

            # Score
            smatch = alignment_matrix[ridx][cidx] + scoring_matrix[rsidx][csidx], (ridx, cidx)
            sin = alignment_matrix[ridx][cidx + 1] - indel_penalty, (ridx, cidx + 1)
            sdel = alignment_matrix[ridx + 1][cidx] - indel_penalty, (ridx + 1, cidx)
            # Dynamically update alignment and direction matrices
            max_ = max([source, smatch, sin, sdel], key=lambda t: t[0])
            alignment_matrix[ridx + 1][cidx + 1] = max_[0]
            direction_matrix[ridx + 1][cidx + 1] = max_[1]            
            
            # Keep track of the coordinate with the maximum score
            if global_max[0] < max_[0]:
                global_max = (max_[0], (ridx + 1, cidx + 1))
    
    print('align_locally:\n', pd.DataFrame(alignment_matrix, index=['0'] + [s for s in seq1], columns=['0'] + [s for s in seq2]))
    print('align_locally:\n', pd.DataFrame(direction_matrix, index=['0'] + [s for s in seq1], columns=['0'] + [s for s in seq2]))
    print('align_locally: global_max', global_max)
    
    # Backtrack
    alignment = backtrack_alignment(seq1, seq2, direction_matrix, init_ij=global_max[1])
    print('align_locally: final alignment:', global_max[0], alignment)
    return global_max[0], alignment


def print_alignment(alignment_tuple):    
    #print('print_alignment: alignment_tuple length:', len(alignment_tuple))
    
    seq1 = ''
    seq2 = ''
    for i in alignment_tuple:
        if i[0] != '-':
            seq1 += i[0]
        else:
            seq1 += '-'
        if i[1] != '-':
            seq2 += i[1]
        else:
            seq2 += '-'
    assert len(alignment_tuple) == len(seq1) == len(seq2)
    #print('print_alignment: alignment_tuple:', alignment_tuple)
    #print('print_alignment: seq1:', seq1[::-1])
    #print('print_alignment: seq2:', seq2[::-1])

    return seq1[::-1], seq2[::-1]


def score_alignment(seq1, seq2, scoring_matrix, labels, indel_penalty):
    
    assert len(seq1) == len(seq2), 'seq1 and seq2 must have the same length'
    #print('score_alignment: alignment length:', len(seq1))

    score = 0
    for i in range(len(seq1)):
        if seq1[i] == '-' or seq2[i] == '-':
            score -= indel_penalty
        else:
            score += scoring_matrix[labels.index(seq1[i])][labels.index(seq2[i])]
    return(score)


def print_dict(dictionary, title=''):
    """
    Print a dictionary.
    """
    print(title)
    print('Number of entries:', len(dictionary))
    for k, v in dictionary.items():
        print('%s ==> %s' % (k, v))
    

def align(seq1, seq2, scoring_matrix, labels, indel_penalty=5):
    scr, aln = align_locally(seq1, seq2, scoring_matrix, labels, indel_penalty)
    s1, s2 = print_alignment(aln)
    assert score_alignment(s1, s2, scoring_matrix, labels, indel_penalty) == scr
    print('%s\n%s\n%s' % (scr, s1, s2))

    
def get_edit_distance(seq1, seq2):
    """
    Compute edit distance using dynamic programming.
    """
    n, m = len(seq1), len(seq2)
    print('get_edit_distance: n, m:', n, m)
    
    # Initialize the alignment and direction matrices, taking into consideration initial InDels
    alignment_matrix = [[0] * (m + 1) for i in range(n + 1)]

    # Update column 1
    for i in range(n + 1):
        alignment_matrix[i][0] = i
    # Update row 1
    for j in range(m + 1):
        alignment_matrix[0][j] = j
    print('get_edit_distance: alignment_matrix:\n', pd.DataFrame(alignment_matrix))

    # Dynamcally update the alignment matrix
    for ridx in range(1, n + 1):
        for cidx in range(1, m + 1):
            #print('ridx, cidx:', ridx, cidx)

            # Score
            if seq1[ridx - 1] == seq2[cidx - 1]:
                mismatch_penalty = 0
            else:
                mismatch_penalty = 1
            
            smatch = alignment_matrix[ridx - 1][cidx - 1] + mismatch_penalty, (ridx - 1, cidx - 1)
            sin = alignment_matrix[ridx][cidx - 1] + 1, (ridx, cidx - 1)
            sdel = alignment_matrix[ridx - 1][cidx] + 1, (ridx - 1, cidx)
            # Dynamically update alignment and direction matrices
            max_ = min([smatch, sin, sdel], key=lambda t: t[0])
            alignment_matrix[ridx][cidx] = max_[0]

    print('get_edit_distance: alignment_matrix final:\n', pd.DataFrame(alignment_matrix))
    
    return alignment_matrix[n][m]


def align_with_fit(seq1, seq2, indel_penalty=-1):

    n, m = len(seq1), len(seq2)
    print('get_edit_distance: n, m:', n, m)

    alignment_matrix = [[0 for j in range(m + 1)] for i in range(n + 1)]
    print('alignment_matrix: fresh\n', pd.DataFrame(alignment_matrix))

    direction_matrix = [[0 for j in range(m + 1)] for i in range(n + 1)]
    print('direction_matrix: fresh\n', pd.DataFrame(alignment_matrix))
    
    for ridx in range(1, n + 1):
        for cidx in range(1, m + 1):
            scores = [alignment_matrix[ridx - 1][cidx] - 1, alignment_matrix[ridx][cidx - 1] - 1, alignment_matrix[ridx - 1][cidx - 1] + [-1, 1][seq1[ridx - 1] == seq2[cidx - 1]]]
            alignment_matrix[ridx][cidx] = max(scores)
            direction_matrix[ridx][cidx] = scores.index(alignment_matrix[ridx][cidx])
    print('alignment_matrix:\n', pd.DataFrame(alignment_matrix))
    
    j = m
    i = max(enumerate([alignment_matrix[r][j] for r in range(len(seq2), len(seq1))]), key=lambda x: x[1])[0] + len(seq2)
    print('i, j:', i, j)
    
    max_score = alignment_matrix[i][j]
    seq1_aligned, seq2_aligned = seq1[:i], seq2[:j]
    add_indel = lambda word, i: word[:i] + '-' + word[i:]
    while 0 < i and 0 < j:
        if direction_matrix[i][j] == 0:
            i -= 1
            seq2_aligned = add_indel(seq2_aligned, j)
        elif direction_matrix[i][j] == 1:
            j -= 1
            seq1_aligned = add_indel(seq1_aligned, i)
        elif direction_matrix[i][j] == 2:
            i -= 1
            j -= 1
    seq1_aligned = seq1_aligned[i:]
    return max_score, seq1_aligned, seq2_aligned


def align_oseq1erlap(seq1, seq2):
    n, m = len(seq1), len(seq2)
    #print('get_edit_distance: n, m:', n, m)
    alignment_matrix = [[0 for j in range(m + 1)] for i in range(n + 1)]
    #print('alignment_matrix: fresh\n', pd.DataFrame(alignment_matrix))
    direction_matrix = [[0 for j in range(m + 1)] for i in range(n + 1)]
    #print('direction_matrix: fresh\n', pd.DataFrame(alignment_matrix))

    max_score = -999*(n + m)
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            score = [alignment_matrix[i-1][j-1] + [-2, 1][seq1[i-1] == seq2[j-1]], alignment_matrix[i-1][j] - 2, alignment_matrix[i][j-1] - 2]
            alignment_matrix[i][j] = max(score)
            direction_matrix[i][j] = score.index(alignment_matrix[i][j])

            if i == n or j == m:
                if alignment_matrix[i][j] > max_score:
                    max_score = alignment_matrix[i][j]
                    max_index = (i, j)
    print('alignment_matrix: final\n', pd.DataFrame(alignment_matrix))
    print('direction_matrix: final\n', pd.DataFrame(direction_matrix))
    i, j = max_index
    print('max_index: i, j:', i, j)
    
    seq1_aligned, seq2_aligned = seq1[:i], seq2[:j]
    add_indel = lambda seq2ord, i: seq2ord[:i] + '-' + seq2ord[i:]
    while i*j != 0:
        if direction_matrix[i][j] == 1:
            i -= 1
            seq2_aligned = add_indel(seq2_aligned, j)
        elif direction_matrix[i][j] == 2:
            j -= 1
            seq1_aligned = add_indel(seq1_aligned, i)
        else:
            i -= 1
            j -= 1
    seq1_aligned, seq2_aligned = seq1_aligned[i:], seq2_aligned[j:]
    
    return max_score, seq1_aligned, seq2_aligned


def align_globally_with_affinity_gap_penalty(seq1, seq2, score_matrix, labels, gap_penalty, extension_penalty):
    n, m = len(seq1), len(seq2)
    print('get_edit_distance: n, m:', n, m)
    
    alignment_matrices = [[[0 for j in range(m + 1)]for i in range(n + 1)] for num_m in range(3)]
    direction_matrices = [[[0 for j in range(m + 1)]for i in range(n + 1)] for num_m in range(3)]    
    
    # 0: , 1: , and 2: 
    for i in range(1, n + 1):
        alignment_matrices[0][i][0] = - gap_penalty - (i - 1) * extension_penalty
        alignment_matrices[1][i][0] = - gap_penalty - (i - 1) * extension_penalty
        alignment_matrices[2][i][0] = -10 * gap_penalty
    for j in range(1, m + 1):
        alignment_matrices[2][0][j] = - gap_penalty - (j - 1) * extension_penalty
        alignment_matrices[1][0][j] = - gap_penalty - (j - 1) * extension_penalty
        alignment_matrices[0][0][j] = -10 * gap_penalty
    print('alignment_matrices: fresh\n', pd.DataFrame(alignment_matrices))
    print('direction_matrices: fresh\n', pd.DataFrame(direction_matrices))
    
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            vs = [alignment_matrices[0][i - 1][j] - extension_penalty, alignment_matrices[1][i - 1][j] - gap_penalty]
            alignment_matrices[0][i][j] = max(vs)
            direction_matrices[0][i][j] = vs.index(alignment_matrices[0][i][j])
            
            hs = [alignment_matrices[2][i][j - 1] - extension_penalty, alignment_matrices[1][i][j - 1] - gap_penalty]
            alignment_matrices[2][i][j] = max(hs)
            direction_matrices[2][i][j] = hs.index(alignment_matrices[2][i][j])
                               
            ms = [alignment_matrices[0][i][j], alignment_matrices[1][i - 1][j - 1] + score_matrix[labels.index(seq1[i-1])][labels.index(seq2[j-1])], alignment_matrices[2][i][j]]
            alignment_matrices[1][i][j] = max(ms)
            direction_matrices[1][i][j] = ms.index(alignment_matrices[1][i][j])
            
    i, j = n, m
    seq1_aligned, se2_aligned = seq1, seq2
    #print('alignment_matrices: final\n', pd.DataFrame(alignment_matrices))
    #print('direction_matrices: final\n', pd.DataFrame(direction_matrices))

    s = [alignment_matrices[x][i][j] for x in range(3)]
    max_score = max(s)
    back = s.index(max_score)
    
    add_indel = lambda seq, i: seq[:i] + '-' + seq[i:]

    while i > 0 and j > 0:
        if back == 0:
            if direction_matrices[0][i][j] == 1:
                back = 1
            i -= 1
            seq2_aligned = add_indel(seq2, j)
            
        elif back == 1:
            if direction_matrices[1][i][j] == 0:
                back = 0
            elif direction_matrices[1][i][j] == 2:
                back = 2
            else:
                i -= 1
                j -= 1
        else:
            if direction_matrices[2][i][j] == 1:
                back = 1
            j -= 1
            seq1_aligned = add_indel(seq1, i)    
            
    for x in range(i):
        seq2_aligned = add_indel(seq2_aligned, 0)
    for x in range(j):
        seq1_aligned = add_indel(seq1_aligned, 0)
    
    return max_score, seq1_aligned, seq2_aligned


def align_multiple_seq(seq1, seq2, seq3):

    n, m, k = len(seq1), len(seq2), len(seq3)
    alignment_matrix = [[[0 for k in range(k+1)] for j in range(m+1)] for i in range(n+1)]
    direction_matrix = [[[0 for k in range(k+1)] for j in range(m+1)] for i in range(n+1)]

    for i in range(1, n+1):
        for j in range(1, m+1):
            for k in range(1, k+1):
                s = [alignment_matrix[i-1][j-1][k-1] + int(seq1[i-1] == seq2[j-1] == seq3[k-1]), alignment_matrix[i-1][j][k], alignment_matrix[i][j-1][k], alignment_matrix[i][j][k-1], alignment_matrix[i-1][j][k-1], alignment_matrix[i][j-1][k-1]]
                direction_matrix[i][j][k], alignment_matrix[i][j][k] = max(enumerate(s), key=lambda p: p[1])

    inalignment_matrixert_indel = lambda seq2ord, i: seq2ord[:i] + '-' + seq2ord[i:]
    seq1_aligned, seq2_aligned, seq3_aligned = seq1, seq2, seq3

    i, j, k = len(seq1), len(seq2), len(seq3)
    max_score = alignment_matrix[i][j][k]

    while i*j*k != 0:
        if direction_matrix[i][j][k] == 1:
            i -= 1
            seq2_aligned = inalignment_matrixert_indel(seq2_aligned, j)
            seq3_aligned = inalignment_matrixert_indel(seq3_aligned, k)
        elif direction_matrix[i][j][k] == 2:
            j -= 1
            seq1_aligned = inalignment_matrixert_indel(seq1_aligned, i)
            seq3_aligned = inalignment_matrixert_indel(seq3_aligned, k)
        elif direction_matrix[i][j][k] == 3:
            k -= 1
            seq1_aligned = inalignment_matrixert_indel(seq1_aligned, i)
            seq2_aligned = inalignment_matrixert_indel(seq2_aligned, j)
        elif direction_matrix[i][j][k] == 4:
            i -= 1
            j -= 1
            seq3_aligned = inalignment_matrixert_indel(seq3_aligned, k)
        elif direction_matrix[i][j][k] == 5:
            i -= 1
            k -= 1
            seq2_aligned = inalignment_matrixert_indel(seq2_aligned, j)
        elif direction_matrix[i][j][k] == 6:
            j -= 1
            k -= 1
            seq1_aligned = inalignment_matrixert_indel(seq1_aligned, i)
        else:
            i -= 1
            j -= 1
            k -= 1

    while len(seq1_aligned) != max(len(seq1_aligned),len(seq2_aligned),len(seq3_aligned)):
        seq1_aligned = inalignment_matrixert_indel(seq1_aligned, 0)
    while len(seq2_aligned) != max(len(seq1_aligned),len(seq2_aligned),len(seq3_aligned)):
        seq2_aligned = inalignment_matrixert_indel(seq2_aligned, 0)
    while len(seq3_aligned) != max(len(seq1_aligned),len(seq2_aligned),len(seq3_aligned)):
        seq3_aligned = inalignment_matrixert_indel(seq3_aligned, 0)

    return max_score, seq1_aligned, seq2_aligned, seq3_aligned


def read_mtrx(matrix_filename):
    with open(matrix_filename) as f:
        # Read the entire file content
        content = f.read()
        # Split the file's content by <delimiter>
        lines2 = content.split('\n')
        # Be careful about lines containing only ''
        lines2 = [i for i in lines2 if i != '']
        assert is_uniform_list_of_lists(lines2)

    labels = lines2[0].split()
    rows = []
    for line in lines2[1:]:
        trimmed = line.split()[1:]
        rows.append([int(i) for i in trimmed])

    assert is_uniform_list_of_lists(rows)
    print(pd.DataFrame(rows, index=labels, columns=labels))

    return rows

# Chapter 6

In [None]:
def format_input(ccs):
    return [list(map(int, cc.split())) for cc in (ccs[1:-1].split(')('))]


def format_output(list_of_lines):
    for line in list_of_lines:
        print('(' + ' '.join([['', '+'][n > 0] + str(n) for n in line]) + ')')


def reverse_negate_interval(seq, i, j):
    assert i <= j, 'i = %s & j = %s' % (i, j)
    
    before_i = seq[:i] # Not including i
    after_j = seq[j + 1:] # Not including j
    i_to_j = seq[i:j + 1] # i to j incluisve
    i_to_j_reversed_negated = [-1 * n for n in reversed(i_to_j)]
    return before_i + i_to_j_reversed_negated + after_j


def reverse_sort_greedily(permutation):
    sorted_permutations = []
    
    index_ = lambda list_, n: list(map(abs, list_)).index(n)
    
    i = 0
    while i < len(permutation):
        if permutation[i] == i + 1:
            i += 1
        else:
            permutation = reverse_negate_interval(permutation, i, index_(permutation, i + 1))
            sorted_permutations.append(permutation)
    return sorted_permutations


def count_breaking_points(permutation):
    c = 0
    prev = 0
    for n in permutation:
        if prev + 1 != n:
            c += 1
        prev = n
    return c


def get_2break_distance(ccs1, ccs2):
    from collections import defaultdict

    graph = defaultdict(list)
    for cycle in ccs1 + ccs2:
        L = len(cycle)
        for i in range(len(cycle)):
            graph[cycle[i]].append(-1*cycle[(i+1) % L])
            graph[-1*cycle[(i+1) % L]].append(cycle[i])

    nc = 0
    remaining = set(graph.keys())
    while len(remaining) > 0:
        nc += 1
        queue = [remaining.pop()]
        while queue:
            current = queue.pop(0)
            queue += filter(lambda node: node in remaining, graph.get(current, []))
            remaining -= set(queue)

    return sum(map(len,P)) - nc


def get_shared_kmer(seq1, seq2, k):
    from collections import defaultdict
    seq1_dict = defaultdict(list)
    for i in range(len(seq1)-k+1):
        seq1_dict[seq1[i:i+k]].append(i)
    return {(i,j) for j in range(len(seq2)-k+1) for i in seq1_dict[seq2[j:j+k]] + seq1_dict[reverse_complement(seq2[j:j+k])]}


def convert_synteny_blocks_to_cycle(synteny_blocks):
    nodes = [0 for i in range(2 * len(synteny_blocks))]
    for i, n in enumerate(synteny_blocks):
        if 0 < n:
            nodes[2 * i] = 2 * n - 1
            nodes[2 * i + 1] = 2 * n
        else:
            nodes[2 * i] = -2 * n
            nodes[2 * i + 1] = -2 * n - 1
    return nodes


def convert_cycle_to_synteny_blocks(nodes):
    synteny_blocks = [0 for i in range(int(1 / 2 * len(nodes)))]
    for i, n in enumerate(nodes[::2]):
        if (n / 2).is_integer():
            synteny_blocks[i] = - int(n / 2)
        else:
            synteny_blocks[i] = int((n + 1) / 2)
    return synteny_blocks


def make_synteny_block_graph(list_of_synteny_blocks):
    edges = []
    for synteny_blocks in list_of_synteny_blocks:
        nodes = convert_synteny_blocks_to_cycle(synteny_blocks)
        nodes.append(nodes[0])
        #print('nodes:', nodes)
        for i, n in enumerate(nodes):
            if i % 2 == 1:
                #print(i, n)
                edges.append((n, nodes[i + 1]))
    return edges


def break_synteny_block_graph(synteny_block_graph_edges):
    synteny_block_edges = dict()
    for from_, to_ in synteny_block_graph_edges:
        from_sign = [- 1, + 1][from_ % 2 == 0]
        to_sign = [+ 1, - 1][to_ % 2 == 0]
        from_v = int((from_ - 1) / 2) + 1
        to_v = int((to_ - 1) / 2) + 1
        #print(from_sign * from_v, to_sign * to_v)
        synteny_block_edges[from_sign * from_v] = to_sign * to_v
        
    cycles = []
    visited = []
    current_from = list(synteny_block_edges.keys())[0]
    visited.append(current_from)
    while len(synteny_block_edges) > 0:
        #print('current_from:', current_from)
        current_to = synteny_block_edges.pop(current_from)
        #print('current_to:', current_to)
        if current_to in visited:
            cycles.append(visited)
            #print('visited:', visited)
            visited = []
            keys = list(synteny_block_edges.keys())
            if len(keys) == 0:
                break
            current_from = keys[0]
            visited.append(current_from)
        else:
            visited.append(current_to)
            current_from = current_to

    return [sorted(c, key=abs) for c in cycles]


def fff(synteny_block_graph_edges, a1, a2, b1, b2):
    to_return = []
    f = set([a1, a2, b1, b2])
    #print(f)
    for e in synteny_block_graph_edges:
        #print(set(v for v in e))
        if len(f & set(v for v in e)) == 0:
            to_return.append(e)
            #print('!')
    to_return.append((a1, b1))
    to_return.append((a2, b2))
    return to_return


def ggg(genome, a1, a2, b1, b2):
    v = convert_synteny_blocks_to_cycle(genome)
    print(v)
    vv = make_synteny_block_graph([v])
    print(vv)
    vvv = fff(vv, a1, a2, b1, b2)
    return break_synteny_block_graph(vvv)
'''
raw = lines[0]
split = raw.split(')')
input_ = [(int(i[0]), int(i[1])) for i in [s.split(',') for s in [s.split('(')[1] for s in split if s != ''] if s != '']]

b = fff(input_, 74, 75, 87, 89)
', '.join(str(s) for s in b)
'''

# Chapter 9

In [None]:
def print_trie(trie):
    for k, v in trie.items():
        if len(v) == 0:
            continue
        for vv in v:
            print('%s->%s:%s' % (k, *vv))
            
            
def make_trie(sequences):
    """
    TRIECONSTRUCTION
    
    """
    trie = dict()
    
    # Root
    c = 0
    trie[c] = []
    
    for seq in sequences:
        seq = seq + '$'
        #print('seq:', seq)
        # Each sequence starts form root
        cur_n = 0
        for nuc in seq:
            #print('cur_n, nuc:', cur_n, nuc)
            next_n = [n for n, e in trie[cur_n] if e == nuc]
            assert len(next_n) < 2
            #print('next_n:', next_n)
            if len(next_n) == 0: # There isn't any edge with ID cur_n
                # Create new node
                c += 1
                trie[c] = []
                # Add this new node as successor of the current node
                trie[cur_n].append((c, nuc))
                # Update the current node to be the new node
                if nuc != '$':
                    cur_n = c
            elif nuc != '$': # There is an edge with ID cur_n
                # Go to the successor
                cur_n = next_n[0]
    return trie


'''
def get_trie_match(sequence, trie):
    i = 0
    nuc = sequence[i]
    cur_n = 0
    while True:
        if len(trie[cur_n]) == 0:
            return 'LEAF'
        else:
            next_n = [n for n, e in trie[cur_n] if e == nuc]
            if len(next_n) > 0:
                nuc = sequence[i]
                cur_n == next_n[0]
            else:
                print('no match')
                return False
'''


def get_trie_match(sequence, trie):
    return_ = explore_trie(0, sequence, trie)
    #print('get_trie_match: return_:', return_)
    if return_:
        return return_[:-1]
    else:
        return False


def explore_trie(from_n, sequence, trie):
    #print('from_n:', from_n)
    #print('sequence:', sequence)
    
    # Get out nodes of from_n
    to_ns = trie[from_n]
    #print('to_ns:', to_ns)
    if '$' in [e for n, e in to_ns]: # Leaf
        #print('$ in to_ns')
        return '$'
    
    if not sequence:
        return None
    
    else:
        nuc = sequence[0]
        if nuc in [e for n, e in to_ns]: # Nucleotide edge found
            cur_n = [n for n, e in to_ns if e == nuc][0]
            #print('Updated cur_n:', cur_n)
            #print('\n')
            recr = explore_trie(cur_n, sequence[1:], trie)
            if recr:
                return nuc + recr
            else:
                return None
        else:
            return None

        
def make_suffix_trie(sequences):
    trie = make_trie(sequences)
    print_trie(trie)
    
    stack = []
    cur_n = 0
    stack.append(cur_n)
    while len(stack) > 0:
        for path_n in trie[stack.pop()]:
            print('popped:', path_n)
            path_s = ''
            path_s += path_n[1]
            cur_n = path_n[0]
            for next_path_n in trie[cur_n]:
                print('next_path_n:', next_path_n)
                stack.append(next_path_n[0])
                
            
def make_suffix_trie_2(sequence):
    trie = {}
    next_n = 0
    trie[next_n] = []
    next_n += 1
    for i in range(len(sequence)):
        cur_n = 0
        for j in range(i, len(sequence)):
            cur_s = sequence[j]
            print('cur_s:', cur_s)
            if cur_s == '$':
                
                continue
            if cur_s in [e for n, e, s in trie[cur_n]]:
                cur_n = [n for n, e, s in trie[cur_n]][0]
            else:
                print('%s --(%s, %s)--> %s' % (cur_n, cur_s, j, next_n))
                trie[next_n] = []
                trie[cur_n].append((next_n, cur_s, j))
                print('trie[cur_n]: after append:', trie[cur_n])
                cur_n = next_n
                next_n += 1
        print('%s --(%s, %s, %s)--> %s' % (cur_n, '$', j, i, next_n))
        trie[cur_n].append((next_n, '$', j, i))
        next_n += 1
    return trie

In [None]:
def make_suffix_array(sequence):
    """
    Make a suffix array.
    """
    if sequence[-1] != '$':
        sequence += '$'
    a = []
    for i in range(len(sequence)):
        a.append((i, sequence[i:]))
    a = sorted(a, key=lambda x: x[1])
    return a


def match_suffix_array(kmer, sequence):
    """
    Make a suffix array and return its indices at which a suffix starts with <kmer>.
    """
    a = make_suffix_array(sequence)
    return [i for i, s in a if s.startswith(kmer)]


def trim_suffix_array(k, suffix_array):
    """
    Trim a suffix_array by keeping indices that are multiple of <k>.
    """
    return [(i, t[0]) for i, t in enumerate(suffix_array) if t[0] % k == 0]


def cyclic_rotate_sequence(seq):
    """
    Return sequences that result from cyclically rotating <seq>.
    """
    return [seq[len(seq) - i:] + seq[0:len(seq) - i] for i in range(len(seq))]


def get_bwa(seq):
    """
    Return the BWA of <seq>.
    """
    return [s[-1] for s in sorted(cyclic_rotate_sequence(seq))]


def get_number_of_occurences_up_to_index(index, sequence):
    """
    Get the number of occurences <sequecne>[<index>] appears up to and inlcuding <index>.
    """
    # Get the value at index
    value = sequence[index]
    
    count = 0
    for v in sequence[:index + 1]:
        if v == value:
            count += 1
    return count
            
    
def get_ith_occuring_index(value, i, sequence):
    """
    Get the index at which <value> appears for the <i>th time.
    """
    count = 0
    for j, v in enumerate(sequence):
        if v == value:
            count += 1
        if count == i:
            return j
    raise ValueError("'{}' does not occure {} times in {}".format(value, i, sequence))


def get_ith_occuring_sequence2_index_of_ith_occuring_sequence1_val(i, sequence1, sequence2):
    """
    Get the <i>th occuring <sequence2> index for the <i>th occuring value in <sequence1>.
    """
    # Get value at <i>
    val = sequence1[i]
    # Get how many times <val> appears up to and including index <i> in <sequence1>
    num_occurences_up_to_i = get_number_of_occurences_up_to_index(i, sequence1)
    # Get <num_occurences_up_to_i>th occurence of <val> in <sequence2>
    return get_ith_occuring_index(val, num_occurences_up_to_i, sequence2)


def make_sequence_from_bwt(bwt):
    """
    Reconstruct a sequene from its <bwt>.
    """
    # Get column 1 by sorting <bwt>
    col1 = sorted(bwt)

    sequence = ''
    # Start from $, the end
    bwt_idx = bwt.index('$')
    start_idx = bwt_idx
    while True:
        # Get the <btw_idx>th value in <col1>
        col1_nuc = col1[bwt_idx]
        sequence += col1_nuc
        
        # Get the number of occurences of <col1_nuc> up to and inclduing <btw_idx>th index
        num_occurences_up_to_bwt_idx = get_number_of_occurences_up_to_index(bwt_idx, col1)
        
        # Get <num_occurences_up_to_bwt_idx>th occurence of <col1_nuc> in <bwt>
        bwt_idx = get_ith_occurence_index(col1_nuc, num_occurences_up_to_bwt_idx, bwt)
        
        # Return sequence if back to $
        if bwt_idx == start_idx:
            return sequence
        
        
def get_first_col1_occurences_of_bwt(bwt):
    col1 = sorted(bwt)
    first_col1_occurences_of_bwt = {}
    cur_v = None
    for i, v in enumerate(col1):
        if v != cur_v:
            print(i)
            cur_v = v


def track_element_occurenes_in_bwt(bwt):
    elem_occurence = {}
    for e in set(bwt):
        c = 0
        elem_occurence[e] = [c for i in range(len(bwt))]
        for i, ee in enumerate(bwt):
            if ee == e:
                c += 1
            elem_occurence[e][i] = c
    return elem_occurence


def match_kmer(sequence, kmers, sequence_is_btw=False):
    #print('sequence:', sequence)
    #print('kmers:', kmers)

    # Get BWT O(n)
    if sequence_is_btw:
        bwt = [n for n in sequence]
    else:
        bwt = [n for n in get_bwa(sequence)]
    #print('\nbwt:', bwt)

    # Get the 1st column O(nlogn)
    col1 = sorted(bwt)
    #print('col1:', col1)

    # Make a map of indices of matching <bwt> index and <col1> index
    map_bwt_to_col1 = {}
    for i in range(len(bwt)):
        map_bwt_to_col1[i] = get_ith_occuring_sequence2_index_of_ith_occuring_sequence1_val(i, bwt, col1)
    #print('\nmap_bwt_to_col1:', map_bwt_to_col1)

    num_occurences = []
    occuring_indices = set()
    
    # For each kmer
    for kmer in kmers:
        found = True
        
        # Define search scope
        col1_idx_first = 0
        col1_idx_last = len(bwt) - 1
        #print('col1_idx_first, col1_idx_last:', col1_idx_first, col1_idx_last)

        #while col1_idx_first <= col1_idx_last:
        while kmer:
            popped = kmer.pop()
            #print('\nkmer, popped:', kmer, popped)
            # Get all occurences of <popped> in <bwt>
            popped_indices_in_bwt = [col1_idx_first + i for i, n in enumerate(bwt[col1_idx_first:col1_idx_last + 1]) if n == popped]
            #print('\npopped_indices_in_bwt:', popped_indices_in_bwt)
            if popped_indices_in_bwt:

                bwt_idx_first = popped_indices_in_bwt[0]
                bwt_idx_last = popped_indices_in_bwt[-1]
                #print('bwt_idx_first, bwt_idx_last:', bwt_idx_first, bwt_idx_last)

                col1_idx_first = map_bwt_to_col1[bwt_idx_first]
                col1_idx_last = map_bwt_to_col1[bwt_idx_last]
                #print('col1_idx_first, col1_idx_last:', col1_idx_first, col1_idx_last)
            else:
                num_occurences.append(0)
                found = False
                break

        if found:
            num_occurences.append(col1_idx_last + 1 - col1_idx_first)
            for i in range(col1_idx_first, col1_idx_last + 1):
                occuring_indices.add(i)

    sfx_array = make_suffix_array(sequence)
        
    return num_occurences, [sfx_array[i][0] for i in occuring_indices]


sequence = lines[0]#'panamabananas$'
sequence += '$'
B = ' '.join(lines[1:]) #
kmers = [list(kmer) for kmer in (B.split(' '))]
matches = match_kmer(sequence, kmers, sequence_is_btw=False)
print(' '.join(sorted((str(i) for i in matches[1]))))

# Chapter 8

In [None]:
def make_data_points(points_line, dictionarize=False):
    points = []
    for p in points_line:
        points.append(tuple(float(i) for i in p.split(' ') if i))
    if dictionarize:
        points = dict(zip(range(0, len(points)), points))
    return points


def get_euclidean_distance(p1, p2):
    assert len(p1) == len(p2), '{} and {} must have the same dimension'.format(p1, p2)
    sqrd_diff_per_dim = []
    for dim, val1 in enumerate(p1):
        sqrd_diff_per_dim.append(math.pow(val1 - p2[dim], 2))
    return math.sqrt(sum(sqrd_diff_per_dim))
    

# TODO: remove center added
def cluster_furthest(num_centers, points):
    points = make_data_points(points)
    print('points:', points)
    
    centers = []
    centers.append(random.choice(range(0, points)))
    print('initial centers:', centers)
    
    min_dist_to_a_center = [None for i in range(len(points))]
    while len(centers) < num_centers:
        for i, p in enumerate(points):
            min_dist_to_a_center[i] = min((get_euclidean_distance(c, p) for c in centers))
        print('min_dist_to_a_center:', min_dist_to_a_center)
        
        centers.append(points[min_dist_to_a_center.index(max(min_dist_to_a_center))])
        print('updated centers:', centers)
        print()

    return centers


def print_points(points):
    for p in points:
        print(' '.join(('{0:.3f}'.format(i) for i in p)))


def calculate_cluster_distortion(points, centers):
    points = make_data_points(points, dictionarize=False)
    #print('points:', points)
    
    num_points = len(points)
    #print('num_points:', num_points)
    
    centers = make_data_points(centers, dictionarize=False)
    #print('centers:', centers)
    
    min_distances = []
    for coord in points:
        #print('coord:', coord)
        distances = []
        for c in centers:
            distances.append(get_euclidean_distance(coord, c))
        min_distances.append(min(distances))
        #print('min_distances:', min_distances)
    
    return sum((math.pow(v, 2) for v in min_distances)) / num_points


def assign_clusters(points, centers):
    # Assign each data point to the cluster corresponding to its nearest center
    cluster_assignments = [None for i in range(len(points))]  # List of integers, <points> indices
    for i, p in enumerate(points):
        min_distance = get_euclidean_distance(p, centers[0])
        min_center_idx = 0
        for ii, c in enumerate(centers[1:]):
            dist = get_euclidean_distance(p, c)
            if dist < min_distance:
                min_distance = dist
                min_center_idx = ii + 1
        cluster_assignments[i] = min_center_idx
    return cluster_assignments


def update_centers(points, cluster_assignments, k):
    num_points = len(points)
    
    # Separate clustering points
    clusters = {}
    for i, c in enumerate(cluster_assignments):
        if c not in clusters:
            clusters[c] = [points[i]]
        else:
            clusters[c].append(points[i])
    print('clusters:', clusters)
    
    # Compute new centers
    centers = [[None for ii in range(len(points[0]))] for i in range(k)]
    print('empty centers:', centers)
    
    # For each cluster, compute a new center
    for c, clustering_points in clusters.items():
        print('c, clustering_points:', c, clustering_points)
        # Sum dimension coordinates
        dimension_sum = {}  # key = dimension; value = dimension sum
        for p in clustering_points:
            for dim, coord in enumerate(p):
                if dim not in dimension_sum:
                    dimension_sum[dim] = coord
                else:
                    dimension_sum[dim] += coord
        
        print('dimension_sum:', dimension_sum)
        
        for dim, sum_ in dimension_sum.items():
            print('centers:', centers)
            centers[c][dim] = sum_ / num_points
    
    return centers
        
        
def cluster_lloyd(k, points):
    points = make_data_points(points)  # List of point coordinates

    cur_centers = random.sample(points, k)  # List of point coordinates
    print('cur_centers:', cur_centers)
    
    while True:
        cluster_assignments = assign_clusters(points, cur_centers)
        #print('cluster_assignments:', cluster_assignments)

        new_centers = update_centers(points, cluster_assignments, k)
        print('new_centers:', new_centers)
        
        
        # If the centers stop changing, exit
        if cur_centers == new_centers:
            return cur_centers
        else:
            cur_centers = new_centers
        print()

# Chapter 7

In [44]:
def read_input_edges(edge_list):
    """
    Read lines like: from->to:weight as dictionary {from:[(to, weight)]}.
    """
    edges = {}
    for line in edge_list:
        from_, to_and_weight = line.split('->')
        to_, weight = to_and_weight.split(':')
        from_ = int(from_)
        to_ = int(to_)
        weight = int(weight)
        
        if from_ not in edges:
            edges[from_] = [(to_, weight)]
        else:
            edges[from_].append((to_, weight))
    return edges


def get_leaves(edges):
    """
    Get leaves, which has 1 neighbor.
    """
    leaves = []
    for from_, to_weight in edges.items():
        if len(to_weight) == 1:
            leaves.append(from_)
    return leaves


def get_distance_between_leaves(from_node, to_node, edges):
    """
    Get distance between 2 nodes.
    """
    #print('Getting the distance from {} to {}'.format(from_node, to_node))
    
    visited = []
    stack = [(from_node, 0)]

    while len(stack) != 0:
        cur_n, cur_d = stack.pop()
        #print('cur_n, cur_d:', cur_n, cur_d)
        
        if cur_n == to_node:
            return cur_d
        elif cur_n in visited:
            continue
        else:
            visited.append(cur_n)

        for to_n, to_d in edges[cur_n]:
            #print('to:', to_n, to_d)
            stack.append((to_n, to_d + cur_d))
            
            
def make_leaf_distance_matrix(lines):
    """
    Distances Between Leaves Problem
    
    Make lead distance matrix from edge lines like from->to:weight.
    """
    #print('lines:', lines)
    edges = read_input_edges(lines)
    #print('edges:', edges)
    leaves = get_leaves(edges)
    #print('leaves:', leaves)

    for from_leaf in leaves:
        #print('from_leaf', from_leaf)
        row = []
        for to_leaf in leaves:
            #print('to_leaf:', to_leaf)
            if to_leaf == from_leaf:
                row.append(0)
            else:
                row.append(get_distance(from_leaf, to_leaf, edges))
        print(' '.join([str(i) for i in row]))

        
def get_limb_length(mtrx, leaf_n):
    """
    Limb Length
    
    Get the length of a leaf node to its parent node.
    """
    
    # Read matrix
    #print('mtrx: {}'.format(mtrx))
    
    # Index leaves and remove
    leaves = list(range(0, len(mtrx)))
    leaves.remove(leaf_n)
    #print('leaves:', leaves)
    
    min_ = 999999999
    # For each pair of leaves excluding <leaf_n>
    for l1, l2 in [i for i in itertools.combinations(leaves, 2)]:
        # Compute the limb length
        val = (int(mtrx[l1][leaf_n]) + int(mtrx[leaf_n][l2]) - int(mtrx[l1][l2])) / 2
        # Look for the min limb length because of limb length theorem
        if val < min_:
            min_ = val
    return int(min_)


def make_phylogeny_tree_using_additive_method(distance_mtrx, n, idx):
    print('distance_mtrx: {}\nn: {}'.format(distance_mtrx, n))
    
    # Base case
    if n == 1:
        print('n is 1; 2 x 2 matrix: {}'.format(distance_mtrx))
        assert distance_mtrx[0][1] == distance_mtrx[1][0]
        print('RETURNING:', {0: [(1, distance_mtrx[0][1])], 1: [(0, distance_mtrx[1][0])]})
        return {0: [(1, distance_mtrx[0][1])], 1: [(0, distance_mtrx[1][0])]}, idx
    
    # Get limb length of leaf n
    limb_length = get_limb_length(distance_mtrx, n)
    print('limb_length: {}'.format(limb_length))
    
    # Alimb_length non-n leaves
    non_n_indices = list(range(0, len(distance_mtrx)))
    non_n_indices.remove(n)
    print('non_n_indices: {}'.format(non_n_indices))
    
    # Remove the impact of leaf n
    for i in non_n_indices:
        distance_mtrx[i][n] -= limb_length
        distance_mtrx[n][i] = distance_mtrx[i][n]
    print('distance_mtrx after updating distances: {}'.format(distance_mtrx))
    
    # Find l1 and l2 such that D(l1,l2) = D(l1,n) + D(n,l2)
    for l1, l2 in [i for i in itertools.combinations(non_n_indices, 2)]:
        if distance_mtrx[l1][l2] == distance_mtrx[l1][n] + distance_mtrx[n][l2]:
            leaf1, leaf2 = l1, l2
            break
    print('l1 and l2: {} {}'.format(l1, l2))
    
    # Get the distance between l1 and n; this distance is used to attach leaf n later
    attaching_point = distance_mtrx[l1][n]
    print('attaching_point: {}'.format(attaching_point))
    
    '''
    # Remove nth column
    distance_mtrx = distance_mtrx[:n] + distance_mtrx[n + 1:]
    # Remove nth row
    for row in distance_mtrx:
        row.pop(n)
    print('distance_mtrx after removing n: {}'.format(distance_mtrx))
    '''
    
    tree, idx = make_phylogeny_tree_using_additive_method([r[:-1] for r in distance_mtrx[:-1]], len(distance_mtrx) - 2, idx)
    
    # (current node, distacne to current node, previous node, distacne to previous node)
    stack = [(l1, 0, None, None)]
    visited = []
    while True:
        cur_n, cur_d, prev_n, prev_d = stack.pop()
        print('popped:', cur_n, cur_d, prev_n, prev_d)
        
        if cur_n in visited:
            continue
        visited.append(cur_n)
        if not can_reach(tree, cur_n, l2, visited):
            print('{} cannot reach {}'.format(cur_n, l2))
            continue
            
        if cur_d == attaching_point:
            print('AT THE ATTACHING POINT')
            tree[cur_n].append((n, limb_length))
            print('{} --({})--> {}'.format(cur_n, limb_length, n))
            assert n not in tree
            tree[n] = [(cur_n, limb_length)]
            print('{} --({})--> {}'.format(n, limb_length, cur_n))
            return tree, idx
        
        elif cur_d > attaching_point:
            print('PASSED THE ATTACHING POINT')
            # Make a new node
            new_n = idx
            idx += 1
            print('new_n: {}'.format(new_n))
            #
            tree[prev_n].append((new_n, attaching_point - prev_d))
            print('{} --({})--> {}'.format(prev_n, attaching_point - prev_d, new_n))
            assert new_n not in tree
            tree[new_n] = [(prev_n, attaching_point - prev_d)]
            print('{} --({})--> {}'.format(new_n, attaching_point - prev_d, prev_n))
            #
            tree[cur_n].append((new_n, cur_d - attaching_point))
            print('{} --({})--> {}'.format(cur_n, cur_d - attaching_point, new_n))
            tree[new_n].append((cur_n, cur_d - attaching_point))
            print('{} --({})--> {}'.format(new_n, cur_d - attaching_point, cur_n))
            #
            tree[new_n].append((n, limb_length))
            print('{} --({})--> {}'.format(new_n, limb_length, n))
            assert n not in tree
            tree[n] = [(new_n, limb_length)]
            print('{} --({})--> {}'.format(n, limb_length, new_n))
            
            # Remove edges
            tree[prev_n] = [(nn, dd) for nn, dd in tree[prev_n] if nn != cur_n]
            tree[cur_n] = [(nn, dd) for nn, dd in tree[cur_n] if nn != prev_n]
            
            return tree, idx
        
        else:
            for to_n, to_d in tree[cur_n]:
                stack.append((to_n, cur_d + to_d, cur_n, cur_d))


def can_reach(tree, from_n, to_n, visited):
    print('TREE:', tree)
    local_visited= [n for n in visited if n != from_n]
    stack = [from_n]
    while stack:
        print('stack:\t{}'.format(stack))
        cur_n = stack.pop()
        print('cur_n:\t{}'.format(cur_n))
        if cur_n == to_n:
            return True
        elif cur_n in local_visited:
            print('{} is in {}'.format(cur_n, local_visited))
            continue
        else:
            local_visited.append(cur_n)
            for n, d in tree[cur_n]:
                stack.append(n)
    return False


def AddiditePhylogeny():
    idx = 1
    mtrx = [list(map(int, ln)) for ln in [ln.split(' ') for ln in lines[1:]]]
    result, idx = make_phylogeny_tree_using_additive_method(mtrx, len(mtrx) - 1, len(mtrx))
    for k, v in result.items():
        for vv, dd in v:
            print('{}->{}:{0:.3g}'.format(k, vv, dd))

In [89]:
def upgma(lines):
    # Read in matrix
    mtrx_orig = [list(map(int, ln)) for ln in [ln.strip().split('\t') for ln in lines]]
    mtrx = [list(map(int, ln)) for ln in [ln.strip().split('\t') for ln in lines]]
    
    node_name = 0
    
    # Create individual clusters
    clusters = [tuple([i]) for i in range(len(mtrx))]  # [(c1), (c2), ...]
    
    # Label matrix
    labels = [c for c in clusters]  # [(c1), (c2), ...]
    
    # Make ages and tree
    ages = {}  # (c1): age
    tree = {}  # (c_from): [((c_to), dist), ...]
    for c in clusters:
        ages[c] = (0, node_name)
        node_name += 1
        tree[c] = [[None, 0]]

    
    # Cluster into 1
    while len(clusters) > 1:
        print('\n\nclusters: {}'.format(clusters))
        print('labels: {}'.format(labels))
        print('ages: {}'.format(ages))
        print('tree: {}'.format(tree))
        
        # Find the closest 2 clusters
        cluster1 = csluter2 = None
        min_dist = 999999999
        for c1, c2 in itertools.combinations(clusters, 2):
            #print('c1: {}; c2: {}'.format(c1, c2))
            dist = get_mean_pairwise_distance(mtrx_orig, c1, c2)
            if dist < min_dist:
                min_dist = dist
                cluster1, cluster2 = c1, c2
                print('Updated clustering pair: {}--<{}>--{}'.format(cluster1, min_dist, cluster2))
        
        # Merge 2 closest clusters and add to clusters list
        clusters.remove(cluster1)
        clusters.remove(cluster2)
        new_cluster = cluster1 + cluster2
        clusters.append(new_cluster)
        #print('\nupdated clusters: {}'.format(clusters))
        
        # Add the ages of the new cluster
        ages[new_cluster] = (min_dist / 2, node_name)
        node_name += 1
        #print('updated ages: {}'.format(ages))
        
        # Add the new cluster into the tree
        tree[new_cluster] = [[cluster1, 0], [cluster2, 0]]
        tree[cluster1].append([new_cluster, 0])
        tree[cluster2].append([new_cluster, 0])
        #print('updated tree: {}'.format(tree))
        
        # Get indices of cluster 1 and cluster 2
        i1, i2 = sorted([labels.index(cluster1), labels.index(cluster2)])
        print('i1: {}; i2: {}'.format(i1, i2))

        # Remove cluster 1 and cluster 2 and add new cluster to the label
        labels.remove(cluster1)
        labels.remove(cluster2)
        labels.append(new_cluster)
        #print('updated labels: {}'.format(labels))
        
        # Make a row for the new cluster
        new_row = []
        for lab in labels:
            if lab == new_cluster:
                new_row.append(0)
            else:
                new_row.append(get_mean_pairwise_distance(mtrx_orig, new_cluster, lab))
        print('new_row: {}'.format(new_row))
        #print('mtrx: {}'.format(mtrx))
                                                          
        if len(mtrx) > 3:
            # Remove cluster 1 and cluster 2 and add new cluster to the matrix
            #mtrx = [row[:i1] + row[i1 + 1:i2] + row[i2 + 1:] + [new_row[i]] for i, row in enumerate(mtrx) if i != i1 and i != i2]
            #mtrx.append(new_row)
            print('updated mtrx: {}'.format(mtrx))
    
    for k, v in tree.items():
        for vv, dd in v:
            if not vv:
                continue
            dist = abs(ages[k][0] - ages[vv][0])
            print('{}->{}:{}'.format(ages[k][1], ages[vv][1], dist))
        
        
def get_mean_pairwise_distance(mtrx, cluster1, cluster2):
    #print('get_mean_pairwise_distances:\nmtrx: {}\ncluster1: {}\ncluster2: {}'.format(mtrx, cluster1, cluster2))
    sum_ = 0
    assert not set(cluster1) & set(cluster2), 'cluster 1 {} and cluster 2 {} should not have an intersection'.format(cluster1, cluster2)
    for c1 in cluster1:
        for c2 in cluster2:
            sum_ += mtrx[c1][c2]
    return sum_ / (len(cluster1) * len(cluster2))
upgma(lines[1:])



clusters: [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,), (11,), (12,), (13,), (14,), (15,), (16,), (17,), (18,), (19,)]
labels: [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,), (11,), (12,), (13,), (14,), (15,), (16,), (17,), (18,), (19,)]
ages: {(15,): (0, 15), (0,): (0, 0), (1,): (0, 1), (2,): (0, 2), (8,): (0, 8), (3,): (0, 3), (9,): (0, 9), (4,): (0, 4), (10,): (0, 10), (5,): (0, 5), (16,): (0, 16), (11,): (0, 11), (6,): (0, 6), (17,): (0, 17), (12,): (0, 12), (7,): (0, 7), (18,): (0, 18), (13,): (0, 13), (19,): (0, 19), (14,): (0, 14)}
tree: {(15,): [[None, 0]], (0,): [[None, 0]], (1,): [[None, 0]], (2,): [[None, 0]], (8,): [[None, 0]], (3,): [[None, 0]], (9,): [[None, 0]], (4,): [[None, 0]], (10,): [[None, 0]], (5,): [[None, 0]], (16,): [[None, 0]], (11,): [[None, 0]], (6,): [[None, 0]], (17,): [[None, 0]], (12,): [[None, 0]], (7,): [[None, 0]], (18,): [[None, 0]], (13,): [[None, 0]], (19,): [[None, 0]], (14,): [[None, 0]]}
Updated cluste

In [79]:
def make_nj_mtrx(mtrx):
    n = len(mtrx)
    
    total_dist = [sum(row) for row in mtrx]
    #print('total_dist: {}'.format(total_dist))
    
    nj_mtrx = [[0 for i in range(len(mtrx[0]))] for i in range(len(mtrx))]
    #print('nj_mtrx: {}'.format(nj_mtrx))
    
    for i, row in enumerate(nj_mtrx):
        for j, col in enumerate(row):
            if i == j:
                continue
            #print('i: {} j: {}'.format(i, j))
            row[j] = (n - 2) * mtrx[i][j] - total_dist[i] - total_dist[j]
    return nj_mtrx

In [85]:
def NJ2(mtrx, n, idx):
    #print('mtrx: {}\nn: {}\nidx: {}'.format(mtrx, n, idx))
    
    total_dist = [sum(row) for row in mtrx]
    #print('total_dist: {}'.format(total_dist))
    
    # Base case
    if n == 1:
        #print('n is 1; 2 x 2 matrix: {}'.format(mtrx))
        return {0: [(1, mtrx[0][1])], 1: [(0, mtrx[1][0])]}, idx
    
    # Make NJ mtrx
    nj_mtrx = make_nj_mtrx(mtrx)
    #print('nj_mtrx: {}'.format(nj_mtrx))
    
    # Find i and j such that nj_mtrx[i][j] is the minimum non-diagonal element
    min_ = 999999999
    for i, j in itertools.combinations(range(len(nj_mtrx)), 2):
        if nj_mtrx[i][j] < min_:
            min_ = nj_mtrx[i][j]
    #print('i: {}; j: {}'.format(i, j))
    
    delta = (total_dist[i] - total_dist[j]) / (n - 2)
    limb_length_i = 0.5 * (mtrx[i][j] + delta)
    limb_length_j = 0.5 * (mtrx[i][j] - delta)
    print('lli: {}; llj: {}'.format(limb_length_i, limb_length_j))
    
    new_row = [0.5 * (mtrx[k][i] + mtrx[k][j] - mtrx[i][j]) for k in range(len(mtrx))]
    print('new_row:', new_row)
    return
    
    
    # Get limb length of leaf n
    limb_length = get_limb_length(mtrx, n)
    print('limb_length: {}'.format(limb_length))
    
    # Alimb_length non-n leaves
    non_n_indices = list(range(0, len(mtrx)))
    non_n_indices.remove(n)
    print('non_n_indices: {}'.format(non_n_indices))
    
    # Remove the impact of leaf n
    for i in non_n_indices:
        mtrx[i][n] -= limb_length
        mtrx[n][i] = mtrx[i][n]
    print('mtrx after updating distances: {}'.format(mtrx))
    
    # Find l1 and l2 such that D(l1,l2) = D(l1,n) + D(n,l2)
    for l1, l2 in [i for i in itertools.combinations(non_n_indices, 2)]:
        if mtrx[l1][l2] == mtrx[l1][n] + mtrx[n][l2]:
            leaf1, leaf2 = l1, l2
            break
    print('l1 and l2: {} {}'.format(l1, l2))
    
    # Get the distance between l1 and n; this distance is used to attach leaf n later
    attaching_point = mtrx[l1][n]
    print('attaching_point: {}'.format(attaching_point))
    
    '''
    # Remove nth column
    mtrx = mtrx[:n] + mtrx[n + 1:]
    # Remove nth row
    for row in mtrx:
        row.pop(n)
    print('mtrx after removing n: {}'.format(mtrx))
    '''
    
    tree, idx = make_phylogeny_tree_using_additive_method([r[:-1] for r in mtrx[:-1]], len(mtrx) - 2, idx)
    
    # (current node, distacne to current node, previous node, distacne to previous node)
    stack = [(l1, 0, None, None)]
    visited = []
    while True:
        cur_n, cur_d, prev_n, prev_d = stack.pop()
        print('popped:', cur_n, cur_d, prev_n, prev_d)
        
        if cur_n in visited:
            continue
        visited.append(cur_n)
        if not can_reach(tree, cur_n, l2, visited):
            print('{} cannot reach {}'.format(cur_n, l2))
            continue
            
        if cur_d == attaching_point:
            print('AT THE ATTACHING POINT')
            tree[cur_n].append((n, limb_length))
            print('{} --({})--> {}'.format(cur_n, limb_length, n))
            assert n not in tree
            tree[n] = [(cur_n, limb_length)]
            print('{} --({})--> {}'.format(n, limb_length, cur_n))
            return tree, idx
        
        elif cur_d > attaching_point:
            print('PASSED THE ATTACHING POINT')
            # Make a new node
            new_n = idx
            idx += 1
            print('new_n: {}'.format(new_n))
            #
            tree[prev_n].append((new_n, attaching_point - prev_d))
            print('{} --({})--> {}'.format(prev_n, attaching_point - prev_d, new_n))
            assert new_n not in tree
            tree[new_n] = [(prev_n, attaching_point - prev_d)]
            print('{} --({})--> {}'.format(new_n, attaching_point - prev_d, prev_n))
            #
            tree[cur_n].append((new_n, cur_d - attaching_point))
            print('{} --({})--> {}'.format(cur_n, cur_d - attaching_point, new_n))
            tree[new_n].append((cur_n, cur_d - attaching_point))
            print('{} --({})--> {}'.format(new_n, cur_d - attaching_point, cur_n))
            #
            tree[new_n].append((n, limb_length))
            print('{} --({})--> {}'.format(new_n, limb_length, n))
            assert n not in tree
            tree[n] = [(new_n, limb_length)]
            print('{} --({})--> {}'.format(n, limb_length, new_n))
            
            # Remove edges
            tree[prev_n] = [(nn, dd) for nn, dd in tree[prev_n] if nn != cur_n]
            tree[cur_n] = [(nn, dd) for nn, dd in tree[cur_n] if nn != prev_n]
            
            return tree, idx
        
        else:
            for to_n, to_d in tree[cur_n]:
                stack.append((to_n, cur_d + to_d, cur_n, cur_d))

def NJ():
    idx = 1
    mtrx = [list(map(int, ln)) for ln in [ln.split(' ') for ln in lines[1:]]]
    result, idx = NJ2(mtrx, len(mtrx) - 1, len(mtrx))
    for k, v in result.items():
        for vv, dd in v:
            print('{}->{}:{0:.3g}'.format(k, vv, dd))
mtrx = [[0, 13, 21, 22], [13, 0, 12, 13], [21, 12, 0, 13], [22, 13, 13, 0]]

NJ2(mtrx, len(mtrx), 0)

lli: 6.0; llj: 7.0
new_row: [15.0, 6.0, 0.0, 0.0]


# Parse input

In [87]:
DOWNLOAD_DIRECTORY_PATH = os.path.join(os.environ['HOME'], 'Downloads')
with open('/Users/Kwat/Downloads/rosalind_ba7d.txt', 'r') as f:
    content = f.read()
    lines = content.split('\n')
    lines = [i for i in lines if i != ''] # Be careful about lines containing only ''
for i, line in enumerate(lines):
    print('line {}:\t{}'.format(i, line[:50]))

line 0:	20
line 1:	0	765	486	692	420	759	741	562	477	584	716	550	579	
line 2:	765	0	428	614	778	749	642	708	729	657	455	514	695	
line 3:	486	428	0	645	432	661	578	609	607	754	585	437	711	
line 4:	692	614	645	0	705	576	704	669	628	651	660	662	546	
line 5:	420	778	432	705	0	718	522	558	781	782	452	495	557	
line 6:	759	749	661	576	718	0	780	415	784	795	474	470	631	
line 7:	741	642	578	704	522	780	0	582	685	783	552	592	634	
line 8:	562	708	609	669	558	415	582	0	748	597	548	613	563	
line 9:	477	729	607	628	781	784	685	748	0	521	402	559	797	
line 10:	584	657	754	651	782	795	783	597	521	0	652	610	663	
line 11:	716	455	585	660	452	474	552	548	402	652	0	551	536	
line 12:	550	514	437	662	495	470	592	613	559	610	551	0	775	
line 13:	579	695	711	546	557	631	634	563	797	663	536	775	0	
line 14:	479	714	774	435	460	664	401	723	709	593	632	694	51
line 15:	596	535	442	612	755	706	591	406	457	433	701	625	47
line 16:	482	555	639	667	556	621	506	752	412	507	697	698	76
line 17:	588	647	599	456	430	654	501	4