In [2]:
import numpy as np
import queue

In [3]:
class Node:
    """One cell of the Nussinov DP table, holding what the traceback needs.

    Attributes:
        score: best number of base pairs found for the cell's subsequence.
        coord: ``(i, j)`` position of the cell in the table; stays ``None`` for
            placeholder cells that the solver never fills.
        back_pointer1: ``(i, j)`` of the cell this one was derived from, or
            ``None`` when there is no predecessor.
        back_pointer2: second predecessor, only used for a bifurcation (the
            subsequence was split into two independent halves).
        bifurcate: ``True`` when the cell was produced by a bifurcation.
        match: ``True`` when the cell pairs its first base with its last one.
    """

    def __init__(self, score=0, coord=None, back_pointer1=None,
                 back_pointer2=None, bifurcate=False, match=False):
        # Instance attributes (rather than class-level defaults) so every cell
        # visibly owns its state; all arguments are optional, so ``Node()``
        # still builds an empty placeholder cell.
        self.score = score
        self.coord = coord
        # two in case of bifurcation
        self.back_pointer1 = back_pointer1
        self.back_pointer2 = back_pointer2
        # useful booleans for backtrace and other functions
        self.bifurcate = bifurcate
        self.match = match

    def __repr__(self):
        # Compact, readable form for inspecting the table or a trace in a cell.
        return (
            'Node(score={!r}, coord={!r}, back_pointer1={!r}, '
            'back_pointer2={!r}, bifurcate={!r}, match={!r})'
        ).format(self.score, self.coord, self.back_pointer1,
                 self.back_pointer2, self.bifurcate, self.match)

In [4]:
class Nussinov:
    """Nussinov dynamic programme: maximise the number of nested base pairs.

    Typical use::

        ns = Nussinov('GCACGACG', {'A': 'U', 'G': 'C', 'C': 'G', 'U': 'A'})
        ns.solve()            # -> maximum number of base pairs
        ns.dot_parentheses()  # prints the sequence over its pairing

    Attributes:
        s: the sequence being folded.
        T: mapping from a base to the base it pairs with.
        table: ``len(s) x len(s)`` object array of ``Node`` cells; only the
            cells above the main diagonal are filled by ``solve``.
        trace: cached result of ``backtrace`` (``None`` until computed).
    """

    def __init__(self, s, T):
        # pass a string and a dictionary T that tells us what matches each key
        # T structure: { 'A': 'U'
        #             ...
        #              ... }

        self.s = s
        self.T = T

        # Allocate the object grid explicitly: building it from a nested list
        # collapses to a 1-D array for an empty sequence.
        n = len(s)
        self.table = np.empty((n, n), dtype=object)
        for i in range(n):
            for j in range(n):
                self.table[i, j] = Node()

        self.trace = None # backtrace, currently None
        self._solved = False # becomes True once solve() has filled the table

    def solve(self):
        """
        DP solution to nussinov

        Fills the upper triangle of ``self.table`` and returns the maximum
        number of base pairs for the whole sequence (0 for sequences shorter
        than two bases, including the empty sequence).
        """
        n = len(self.s)

        # A previously cached backtrace refers to the old table contents.
        self.trace = None

        # we fill the table diagonally: shortest spans first, so every cell a
        # longer span depends on is already final
        for span in range(1, n):
            for i in range(n - span):
                j = i + span
                self.table[i, j] = self._fill_cell(i, j)

        self._solved = True

        if n == 0:
            return 0
        return self.table[0, n - 1].score # top right corner has the answer

    def _can_pair(self, i, j):
        # True when T maps base i onto base j; bases missing from T pair with nothing.
        base = self.s[i]
        return base in self.T and self.T[base] == self.s[j]

    def _fill_cell(self, i, j):
        # Build the Node for subsequence s[i..j] from the already filled shorter spans.
        table = self.table

        if j - i >= 2: # if bifurcation is possible
            # max() keeps the first best split, i.e. the smallest k among ties
            bifurcation = max(
                ((table[i, k].score + table[k + 1, j].score, (i, k), (k + 1, j))
                 for k in range(i + 1, j)),
                key=lambda option: option[0],
            )
        else:
            bifurcation = (0, None, None)

        if self._can_pair(i, j): # if we have a match
            match = (1 + table[i + 1, j - 1].score, (i + 1, j - 1))
        else:
            match = (0, None)

        down = (table[i + 1, j].score, (i + 1, j))
        left = (table[i, j - 1].score, (i, j - 1))

        # tie breaking happens here: max() returns the first of equal scores,
        # so the preference order is left, down, match, bifurcation
        choice = max(left, down, match, bifurcation, key=lambda option: option[0])

        node = Node()
        node.score = choice[0]
        node.coord = (i, j)
        node.back_pointer1 = choice[1]

        # only the match option points at the inner cell (i+1, j-1)
        if choice[1] == (i + 1, j - 1):
            node.match = True

        if len(choice) == 3: # bifurcation carries a second predecessor
            node.back_pointer2 = choice[2]
            node.bifurcate = True

        return node

    def backtrace(self):
        """
        Return the nodes that led to the final score, in breadth-first order
        starting from the top right cell.

        Solves the table first when ``solve`` has not been called yet. The
        result is cached in ``self.trace`` until the next ``solve`` call.
        """
        if not self._solved:
            self.solve()

        if self.trace is not None: # already computed
            return self.trace

        self.trace = []
        if len(self.s) == 0: # no table cell to start from
            return self.trace

        node_queue = queue.Queue()
        node_queue.put(self.table[0, len(self.s) - 1])

        while not node_queue.empty():
            cur_node = node_queue.get()
            self.trace.append(cur_node)

            # follow one pointer normally, two after a bifurcation
            for pointer in (cur_node.back_pointer1, cur_node.back_pointer2):
                if pointer is not None:
                    node_queue.put(self.table[pointer])

        return self.trace

    def dot_parentheses(self):
        """
        Print the sequence, a separator line and the pairing found.

        In the pairing line '(' and ')' mark the two ends of a base pair and
        '-' marks an unpaired base. Returns nothing.
        """
        # sized from self.s, so the method does not depend on any notebook global
        result = ['-'] * len(self.s) # gapped sequence
        for node in self.backtrace():
            if not node.match:
                continue
            i, j = node.coord
            result[i] = '('
            result[j] = ')'

        # printing
        ref = ' '.join(self.s)
        print(ref)
        print('-' * len(ref))
        print(' '.join(result))
        

In [5]:
# Pairing table handed to Nussinov: each base maps to the one base it may pair with.
T = dict(A='U', G='C', C='G', U='A')

In [6]:
# Example 1: a 23-base sequence. The spaces are only visual grouping and are
# stripped before the sequence is handed to the solver.
s = 'GCUCGGG UUCCC UAU UCA AGAGC'.replace(' ', '') # should be 10
ns = Nussinov(s, T)
# Last expression of the cell, so the maximum pair count is shown as its output.
ns.solve()

10

In [7]:
# Print the sequence over its pairing: '(' and ')' mark the two ends of a base
# pair, '-' marks an unpaired base.
ns.dot_parentheses()

G C U C G G G U U C C C U A U U C A A G A G C
---------------------------------------------
( ( ( ( ( ( ( - - ) ) ) ( ( ) ( - ) ) ) ) ) )


In [8]:
# Example 2: an 8-base sequence without any U, so under T only G-C pairs can form.
s = 'GCACGACG'
ns = Nussinov(s, T)
# Last expression of the cell, so the maximum pair count is shown as its output.
ns.solve()

3

In [9]:
# Print the pairing recovered for the 8-base example.
ns.dot_parentheses()

G C A C G A C G
---------------
( ) - ( ( - ) )


In [14]:
# Example 3: four bases where the A-U pair (positions 0, 2) and the C-G pair
# (positions 1, 3) would cross each other, so they cannot both be used.
s = 'ACUG'
ns = Nussinov(s, T)
# Last expression of the cell, so the maximum pair count is shown as its output.
ns.solve()

1

In [15]:
# Print the pairing recovered for the 4-base example.
ns.dot_parentheses()

A C U G
-------
( - ) -


In [11]:
# Example 4: a 15-base sequence. The spaces are only visual grouping and are
# stripped before the sequence is handed to the solver.
s = 'A ACUG GAUC GGUUCA'.replace(' ', '')
ns = Nussinov(s, T)
# Last expression of the cell, so the maximum pair count is shown as its output.
ns.solve()

5

In [12]:
# Print the pairing recovered for the 15-base example.
ns.dot_parentheses()

A A C U G G A U C G G U U C A
-----------------------------
( ( ( - ) ( ( ) ) - - ) ) - -
