In [1]:
import numpy as np # using numpy primarly so I can access using [a,b] notation
from queue import Queue
import itertools
import tqdm.notebook as tqdm
import copy

from utils import AlteredList

In [2]:
class Node:
    score = 0
    match_ptrs = []
    
    def __repr__(self):
        return str(self.score)

In [3]:
class Nussinov:
    def __init__(self, s, T):
        self.s = s
        self.T = T

        
    def solve(self, min_padding=0):
        """
        DP solution to nussinov. Runtime: |s|^3
        
        min_padding: padding between i and j at minimum
        tie_break_permute: pass a list that tells what order to prioritize the following if there's a tie
            - left (L)
            - down (D)
            - match (M)
            - bifurcate (B)
        """        
        
        # initializes variables for the class to use
        self.table = AlteredList(len(self.s), Node)
        self.trace = None
        
        # we fill the table diagonally, I'll use a while loop cause it helps me
        cur_i, cur_j = 0, 1
        next_i_start, next_j_start = 0, 2
        
        # if over 50 characters, I'll show a progress bar
        need_bar = False if len(s) < 50 else True
        # progress bar, total found via some simple geometry on the table
        if need_bar:
            pbar = tqdm.tqdm(total= (len(s)-1)*len(s)/2)
        
        prev, count = 0, 0
        while cur_i != 0 or cur_j != len(self.s):
            if cur_j - cur_i >= 2: # if bifurcation is possible
                options = [(self.table[cur_i,k].score + self.table[k+1,cur_j].score, (cur_i,k), (k+1,cur_j)) for k in range(cur_i+1, cur_j)]
                bifurcation = max(options, key = lambda x: x[0]) # choose max based on score (index 0 of tuple)
            else:
                bifurcation = (0, None, None)
                
            # if we have a match in the allowable distance
            if self.T[self.s[cur_i]] == self.s[cur_j] and cur_j - cur_i - 1 >= min_padding:
                match = (1 + self.table[cur_i+1,cur_j-1].score, (cur_i+1, cur_j-1))
            else:
                match = (0, None)

            down = (self.table[cur_i+1,cur_j].score, (cur_i+1,cur_j))
            left = (self.table[cur_i,cur_j-1].score, (cur_i,cur_j-1))
            
            best = 0
            match_ptrs = set()
            for opt in [down, left, match, bifurcation]:
                if opt[0] > best:
                    best = opt
                    match_ptrs = set()
                    for ptr in opt[1].match_ptrs:
                        match_ptrs.add(ptr)
                        
                    if len(opt) == 3:
                        for ptr in opt[2]:
                            match_ptrs.add(opt[2].match_ptrs)
                elif opt[0] == best:
                    for ptr in opt[1].match_ptrs:
                        match_ptrs.add(ptr)
                        
                    if len(opt) == 3:
                        for ptr in opt[2]:
                            match_ptrs.add(opt[2].match_ptrs)
            
                
                    
            self.table[cur_i, cur_j] = max([down, left, match, bifurcation])

            cur_i += 1
            cur_j += 1
            
            if need_bar:
                count += 1
                if count % 100 == 0:
                    pbar.update(count - prev)
                    prev = count

            if cur_j >= len(self.s):
                cur_i = next_i_start
                cur_j = next_j_start
                next_j_start += 1
                
        if need_bar:
            pbar.update(count - prev)    
            pbar.close()

        return self.table[0, len(self.s)-1] # top right corner has the answer
    
    
    def backtrace(self):
        # backtraces and returns sequences of nodes that led to the final path
        if self.trace is not None: # already computed
            return self.trace
        
        self.trace = []
        
        queue = Queue()
        queue.put((0, len(self.s)-1))
        while not queue.empty():
            cur_i, cur_j = queue.get()
            
            cur_score = self.table[cur_i, cur_j]
            # no more matchings for this spot and its precedents
            if cur_score == 0:
                continue
        
            # check down
            elif self.table[cur_i+1, cur_j] == cur_score:
                queue.put((cur_i+1, cur_j))
                continue
            
            # check left
            elif self.table[cur_i, cur_j-1] == cur_score:
                cur_i, cur_j = cur_i, cur_j-1
                continue
            
            # check match
            elif self.table[cur_i+1, cur_j-1] + 1 == cur_score:
                self.trace.append((cur_i, cur_j))
                queue.put((cur_i+1, cur_j-1))
                continue
            
            # check bifurcate
            elif cur_j - cur_i >= 2: # if bifurcation is possible
                options = [(self.table[cur_i,k] + self.table[k+1,cur_j], (cur_i,k), (k+1,cur_j)) for k in range(cur_i+1, cur_j)]
                bif_score, cord1, cord2 = max(options, key=lambda x: x[0])
                if bif_score == cur_score:
                    queue.put(cord1)
                    queue.put(cord2)
                    continue
                    
            raise ValueError('Could not backtrace correctly')
                
        return self.trace
    
    
    def dot_parentheses(self, prettify=True):
        """
        Creates a string showing the dot parenthesis version of the structure
        """
        if self.trace is None:
            self.backtrace()
        
        result = ['-' for _ in range(len(s))] # gapped sequence
        for pair_idx1, pair_idx2 in self.trace:
            if pair_idx1 > pair_idx2:
                pair_idx1, pair_idx2 = pair_idx2, pair_idx1
            result[pair_idx1] = '('
            result[pair_idx2] = ')'
            
        # the raw string
        result = ' '.join(result)
        if not prettify:
            return result
        
        # printing
        ref = ' '.join(self.s)
        middle = ''
        for _ in range(len(ref)):
            middle += '-'
        return ref + '\n' + middle + '\n' + result
    
    
    def evaluate_tie_breaks(self):
        """
        Passes all tie permutations to save and stores the dot_parenthesis computation
        """
        
        to_permute = ['L', 'D', 'M', 'B']
        perms = list(itertools.permutations(to_permute))
        
        results = {}
        for perm in perms:
            self.solve(tie_break_permute=perm)
            dot_parenth = self.dot_parentheses(for_printing=True)
            if dot_parenth in results:
                results[dot_parenth].append(perm)
            else:
                results[dot_parenth] = [perm]
        return results
        

In [4]:
T = {
    'A': 'U',
    'G': 'C',
    'C': 'G',
    'U': 'A',
}

In [5]:
s = 'GCACGACG'
ns = Nussinov(s, T)
ns.solve()

3.0

In [6]:
ns.backtrace()

[(0, 1), (3, 7), (4, 6)]

In [7]:
ns.dot_parentheses(prettify=False)

'( ) - ( ( - ) )'

In [8]:
print(ns.dot_parentheses())

G C A C G A C G
---------------
( ) - ( ( - ) )


In [9]:
print(ns.dot_parentheses())

G C A C G A C G
---------------
( ) - ( ( - ) )


In [10]:
res = ns.evaluate_tie_breaks()

In [11]:
for k in res:
    print(k)
    print(len(res[k]))

G C A C G A C G
---------------
( ) - ( ( - ) )
8
G C A C G A C G
---------------
( ) - ( ) - ( )
16


In [12]:
s = 'GCUCGGG UUCCC UAU UCA AGAGC'.replace(' ', '') # should be 10
ns = Nussinov(s, T)
ns.solve()

10.0

In [13]:
print(ns.dot_parentheses())

G C U C G G G U U C C C U A U U C A A G A G C
---------------------------------------------
( ( ( ( ( ( ( - - ) ) ) ( ( ) ( - ) ) ) ) ) )


In [14]:
ns.solve(tie_break_permute=['B', 'M', 'L', 'D'])
print(ns.dot_parentheses())

G C U C G G G U U C C C U A U U C A A G A G C
---------------------------------------------
( ) ( ( ) ( ( - - ) ) ( ( ) ( ( - ) ) ) ) ( )


In [15]:
ns.solve(min_padding=1)
print(ns.dot_parentheses())

G C U C G G G U U C C C U A U U C A A G A G C
---------------------------------------------
( ( ( ( ( ( ( - - ) ) ) ( ( - ) - ) - ) ) ) )


In [16]:
# cat coding-RNA from https://rnacentral.org/rna/URS00000F3C2D/9685
# under 1000 in size
s = 'CAAAGGUUUGGUCCUGGCCUUUCCAUUAGUUAUUAAUAAGAUUACACAUGCAAGCCUCCGCAUCCCGGUGAAAAUGCCCUCUAAGUCACCCAGUGACCUAAAGGAGCUGGUAUCAAGCACACAACCACAGUAGCUCAUAACACCUUGCUCAGCCACACCCCCACGGGAUACAGCAGUGAUAAAAAUUAAGCCAUGAAUGAAAGUUCGACUAAGCUAUAUUAAACAAGGGUUGGUAAAUUUCGUGCCAGCCACCGCGGCCAUACGAUUAACCCAAACUAAUAGACCCACGGCGUAAAGCGUGUUACAGAGAAAAAAAUAUACUAAAGUUAAAUUUUAACUAGGCCGUAGAAAGCUACAGUUAACAUAAAAAUACAGCACGAAAGUAACUUUAACACCUCCGACUACACGACAGCUAAGACCCAAACUGGGAUUAGAUACCCUACUAUGCUUAGCCCUAAACUUAGAUAGUUACCCUAAACAAAACUAUCCGCCAGAGAACUACUAGCAAUAGCUUAAAACUCAAAGGACUUGGCGGUGCUUUACAUCCCUCUAGAGGAGCCUGUUCUAUAAUCGAUAAACCCCGAUAUACCUCACCAUCUCUUGCUAAUUCAGCCUAUAUACCGCCAUCUUCAGCAAACCCUAAAAAGGAAGAAAAGUAAGCACAAGUAUCUUAACAUAAAAAAGUUAGGUCAAGGUGUAGCUCAUGAGAUGGGAAGCAAUGGGCUACAUUUUCUAAAAUUAGAACACCCACGAAGAUCCUUACGAAACUAAGUAUUAAAGGAGGAUUUAGUAGUAAAUUUGAGAAUAGAGAGCUCAAUUGAAUCGGGCCAUGAAGCACGCACACACCGCCCGUCACCCUCCUCAAGUGGUAACUCCCAAAAAAACCUAUUUAAAUUAUCACACCCACAAGAGGAGAUAAGUCGUAACAAGGUAAGCAUACUGGAAAGUGUGCUUGGAUAA'


In [17]:
# takes a while
# making the initial table actually takes a noticeable amount of time
# perhaps using np.zeros might make the array creation be faster
ns = Nussinov(s, T)
ns.solve(min_padding=1)

HBox(children=(IntProgress(value=0, max=460320), HTML(value='')))




369

In [18]:
print(ns.dot_parentheses(for_printing=False))

( - ( ( ( ( ( ( ( ( ( ( ( ( ( ( - ) ( ( ( ( ( ( - ( ( ( ( ( ( ( ( ( ( - ) ) ) ( ( ( ( - ( ( - ( ( ( - - ( ( - ) ) ( - ) ) ) ) ( ( - ) ) ) ) ) ) ( - ) ( - ) - ) ) ) ) ) ( ( ( ( - - ( - ) ) ) ) ) ) ) ) ) - ) ) ) ( - ( ( ( ( - ( - ) ) - ) - ) ) ) ) ) ( ( ( ( - ) ( ( ( ( ( - ( ( - ) ( ( ( ( ( ( ( ( ( - ) ) ) - ) - ( ( ( ( ( ( - - ) ) ) - ( ( - ( ( ( - ) ( ( - ( - ) - ( ( ( ( - ) ( - ) ) ) ( ( ( ( ( - - - ( ( ( ( ( ( ( ( - ) ) - ( ( ( ( ( ( - ) ) - ( ( ( ( ( ( ( ( ( ( ( ( ( - ) ( ( ( ( ( - ) ( ( ( ( - ( - ) - ) ) ) - - ) ) ) ) ) ) ) ) - ) ) - ) ( ( ( ( - ( ( - ( - ) - ( ( - ( ( - ( ( - ) ( ( ( ( ( ( ( - ) ) ) ( ( ( - - ( ( ( ( - ) ( ( ( ( ( - ) - ) ) ( - ) ) ) ) ) ) - - ) ) ( ( ( ( - ) ( - ( ( ( ( ( - ( ( ( - ) ) ( - ( ( - ( - ) - ( - ) - ( - ( - - ) ) - - - ) ) ) ) ) ) ) ) - - ) ) ) ) ) ) ) ) - ) ( - ( ( ( ( - ) - ) ( ( - ) ) ( ( ( ( - - ( - ) ) ) ) - ) ) - ) ) ) - ) ) ) ( ( - ) ) ) ) ( ( ( ( ( - ) - ) ) ( ( - ) ) ) ) ) ) ) ) ) ) ) ) ) ) ( ( ( ( ( - ( ( ( - ) - ) ( ( ( ( ( ( ( ( ( - - ) ) 

In [19]:
a = '.....................((.(((((((.............(((.(((..((....))..............(((((....((((.....))))....))).)).(((....(((..((.......)).)))......)))...............((....)).....))))))...............(((((....)))))...................((((((((..............................)))))))).))))))).)).....(((..............................((((((......((((((..((........))...))))))..............((....)).))))))....................((((((........((((.........)))).....))))))....................................))).(((.......(((....))).....)))........(((((.(((((((...((.........((((((((((...((((........))))........(((((((...........((.((((..(((((.............(((....)))..................(((......))).........)))...))))))))....)))))))..)).)))))))).....(((((....)))))..........(((.((((......)))).)))......................(((((.........)))))........))...)))))))))).))..............((((((..((......................................))..))))))................((((((((((....)))))))))).....'




In [20]:
a = a.replace(')', '(')
c = 0
for l in a:
    c += 1 if l == '(' else 0
c

342

In [None]:
s = 'GCACGACG'
ns = Nussinov(s, T)
ns.solve()

In [None]:
ns.dot_parentheses()

In [None]:
s = 'ACUG'
ns = Nussinov(s, T)
ns.solve()

In [None]:
ns.dot_parentheses()

In [None]:
s = 'A ACUG GAUC GGUUCA'.replace(' ', '')
ns = Nussinov(s, T)
ns.solve()

In [None]:
ns.dot_parentheses()