# Rosalind problem26

Nicholas Rose

BME 205

Due Date: December 6, 2021

### Baum-Welch Learning Problem:

Given: A sequence of emitted symbols x = x1 . . . xn in an alphabet A, generated by a k-state HMM with unknown transition and emission probabilities, initial Transition and Emission matrices and a number of iterations I.

Return: A matrix of transition probabilities Transition and a matrix of emission probabilities Emission that maximizes Pr(x,π) over all possible transition and emission matrices and over all hidden paths π.

In [72]:
import math
import numpy as np
import itertools

In [73]:
class InfileParse():
    '''
    This class object parses an given infile.
    Includes attributes:
    self.i (int),
    self.x (string),
    self.char (list),
    self.states (list),
    self.transTable (list),
    self.emTable (list)
    '''
    
    def __init__(self, infile):
        
        self.i = 0
        self.x = ''
        self.char = []
        self.states = []
        self.transTable = []
        self.emTable = []
        
        with open(infile) as f:
            self.i = f.readline().strip()
            next(f)
            self.x = f.readline().strip()
            next(f)
            self.char = f.readline().strip().replace("\t", "").replace(" ", "")
            next(f)
            self.states = f.readline().strip().replace("\t", "").replace(" ", "")
            next(f)
            for i, line in enumerate(f.readlines()):
                if i - 1 < len(self.states):
                    self.transTable.append(line.strip().split())
                if i - 1 > len(self.states):
                    self.emTable.append(line.strip().split())
    
    
    def createTable(self, table):
        '''
        Method initialized by a list (table)
        of file lines (strings), representing the header and rows
        of a matrix. Returns a numpy array containing table data.
        '''
        
        matrix = np.zeros((0,len(table[-1])-1))
        states = []
        for i, line in enumerate(table):
            if i != 0:
                matrix = np.vstack([matrix, np.array(list(map(float, line[1:])))])
                    
        return matrix

In [124]:
class Viterbi():
    '''
    Object 'Viterbi'.
    Includes attributes:
    self.nodes (list), self.prob (int/float)
    and method:
    self.forward(self, x, states, transitionTable, emissionTable)
    '''
    
    def __init__(self, x, states):
    
        self.x = x
        self.states = states
        self.fNodes = np.zeros((len(x), len(states)))
        self.bNodes = np.zeros((len(x), len(states)))
        self.piStar = np.zeros((len(x), len(states)))
        self.piStarStar = np.zeros(((len(x)-1), len(states)**2))
        self.prob = 0
       
    def forward(self, char, transitionTable, emissionTable):
        '''
        Method which accepts a a string x, possible hidden states, an 
        transition table, and an emission table. Returns the probabity of
        string x being emmited by a given HMM considering all possible
        hidden states. This is done with the forward algorithy.
        '''
        
        for i, nucleo in enumerate(self.x):
            if i == 0:
                start = 1 / len(self.states)
                for j in range(len(self.states)):
                    self.fNodes[i][j] = start * emissionTable[j][char.index(nucleo)]
            else:
                a = np.transpose(self.fNodes[i-1])
                b = np.matmul(a, transitionTable)
                c = b * emissionTable[:,char.index(nucleo)]
                d = np.transpose(c)
                self.fNodes[i] = d
                
        self.prob = np.sum(self.fNodes[-1])
                
        return self.fNodes
    
    
    def backward(self, char, transitionTable, emissionTable):
        '''
        Method which accepts a a string x, possible hidden states, an 
        transition table, and an emission table. Returns the probabity of
        string x being emmited by a given HMM considering all possible
        hidden states. This is done with the forward algorithy.
        '''
        
        for i, nucleo in enumerate(self.x[::-1]):
            prevNucleo = self.x[::-1][i-1]
            if i == 0:
                for j in range(len(self.states)):
                    self.bNodes[i][j] = 1
            else:
                a = self.bNodes[i-1] * emissionTable[:,char.index(prevNucleo)]
                b = np.matmul(transitionTable, a)
                c = np.transpose(b)
                self.bNodes[i] = c
         
        self.bNodes = self.bNodes[::-1]
        
        return self.bNodes
    
    
    def softDecode(self):
        '''
        Method which returns the likelihood of a sequence
        passing through each hidden state at position i.
        This is done by summing the forward and backward
        algorithms at i, and dividing by the total probability
        given by the forward algorithm.
        These values are returned as a numpy array.
        '''
        
        for i in range(len(self.x)):
            self.piStar[i] = (self.fNodes[i] * self.bNodes[i]) / self.prob
                
        return self.piStar
    
    
    def starStar(self, char, transitionTable, emissionTable):
        '''
        Method which returns the likelihood of a sequence
        passing through each transiton path at position i.
        This is done by summing the forward and backward
        algorithms at i with the tranisition value and emission value of i+1,
        dividing by the total probability given by the forward algorithm.
        These values are returned as a numpy array.
        '''
        for i in range(len(self.x)-1):
            column = []
            for j in range(len(self.states)):
                for k in range(len(self.states)):
                    column.append((self.fNodes[i][j] * 
                                  transitionTable[j][k] * 
                                  emissionTable[k][char.index(self.x[i+1])] *
                                  self.bNodes[i+1][k]) / self.prob)
            self.piStarStar[i] = np.array(column)
            
        return self.piStarStar

In [125]:
class EmissionTable():
    '''
    Object 'EmissionTable'. 
    Initialized by a sequence (x),
    a list of characters in x (char),
    a hidden path (pi), and a list of states in pi (states).
    Includes attributes:
    self.states (list),
    self.char (list)
    self.stateCount (dictionary), and
    self.transitions (dictionary),
    which are populated when creating an object.
    Includes method(s):
    matrix(self).
    '''
    
    def __init__(self, x, char, piStar, states):
        
        self.states = states
        self.char = char
        self.stateCount = {}
        self.emissions = {}
        
        for i in self.states:
            for j in char:
                self.emissions[i + j] = 0
        for i in states:
            self.stateCount[i] = 0
            
        for i, character in enumerate(x):
            for j, state in enumerate(states):
                self.emissions[states[j] + character] += piStar[i][j]
                self.stateCount[states[j]] += piStar[i][j]
            
            
    def matrix(self):
        '''
        Method initialized by two dictionaries containing
        the counts of states and emissions.
        Returns a numpy array containing emission table data.
        '''
        
        transProb = {}
        for i in self.emissions:
            if self.stateCount[i[0]] != 0:
                transProb[i] = self.emissions[i] / self.stateCount[i[0]]
            else:
                transProb[i] = 1 / len(self.char)
        
        data = list(transProb.values())
        matrix = np.reshape(data, (len(self.states), len(self.char)))
        
        return matrix

In [126]:
class TransitionTable():
    '''
    Object 'TransitionTable'. 
    Initialized by a hidden path (pi),
    and a list of states in pi (states).
    Includes attributes:
    self.states (list),
    self.stateCount (dictionary), and
    self.transitions (dictionary),
    which are populated when creating an object.
    Includes method(s):
    matrix(self).
    '''
    
    def __init__(self, piStarStar, states):
        
        self.states = states
        self.stateCount = {}
        self.transitions = {}
        
        trans = itertools.product(''.join(states), repeat=2)
        for i in trans:
            self.transitions[(i[0] + i[1])] = 0
        for i in states:
            self.stateCount[i] = 0
            
        colTotals = []
        stateTotals = []
        for i in zip(*piStarStar):
            colTotals.append(sum(i))
        for i in range(len(states)):
            stateTotals.append(sum(colTotals[i*len(states):(i+1)*len(states)]))
        for i, key in enumerate(self.transitions):
            self.transitions[key] = colTotals[i]
        for i, key in enumerate(self.stateCount):
            self.stateCount[key] = stateTotals[i]
            
            
    def matrix(self):
        '''
        Method initialized by two dictionaries containing
        the counts of states and transitions.
        Returns a numpy array containing transition table data.
        '''
        
        transProb = {}
        for i in self.transitions:
            if self.stateCount[i[0]] != 0:
                transProb[i] = self.transitions[i] / self.stateCount[i[0]]
            else:
                transProb[i] = 1 / len(self.states)
        
        data = list(transProb.values())
        matrix = np.reshape(data, (len(self.states), len(self.states)))
        
        return matrix

In [127]:
class Iterate():
    '''
    Object 'Iterate'.
    Includes attributes:
    self.data (InFileParse),
    self.transitionTableStart (np.array)
    self.emissionTableStart (np.array)
    self.transitionTableFinal (int / np.array)
    self.emissionTableFinal (int / np.array)
    and method:
    self.iterate(self)
    '''
    
    def __init__(self, data, transitionTable, emissionTable):
        
        self.data = data
        self.transitionTableStart = transitionTable
        self.emissionTableStart = emissionTable
        self.transitionTableFinal = 0
        self.emissionTableFinal = 0
        
    def iterate(self):
        '''
        Method that iterates i (self.data.i) times with
        a while loop. The Object and methods above are
        called and the perameters are updated for every
        iteration.
        '''
        
        n = 0
        ind = int(self.data.i)
        transitionTable = self.transitionTableStart
        emissionTable = self.emissionTableStart
        
        while n <= ind:
            viterbi = Viterbi(self.data.x, self.data.states)
            forward = viterbi.forward(self.data.char, transitionTable, emissionTable)
            backward = viterbi.backward(self.data.char, transitionTable, emissionTable)
            decode = viterbi.softDecode()
            starStar = viterbi.starStar(self.data.char, transitionTable, emissionTable)
            emissionTable = EmissionTable(self.data.x, self.data.char, decode, self.data.states).matrix()
            transitionTable = TransitionTable(starStar, self.data.states).matrix()
            n += 1
        
        self.transitionTableFinal = transitionTable
        self.emissionTableFinal = emissionTable

## Main

In [129]:
def main(infile):
    '''
    The main method. This method takes file containing
    a number of iterations, a sequence, a character list, 
    possible hidden states, a transition table, and an emission table.
    This method parses the file and runs the above methods.
    Output is printed below and to an output file 
    'rosalind_26.txt.out'
    '''
    
    data = InfileParse(infile)
    transitionTable = data.createTable(data.transTable)
    emissionTable = data.createTable(data.emTable)
    result = Iterate(data, transitionTable, emissionTable)
    result.iterate()
    
    
    with open('rosalind_26.txt.out', 'w') as out:
        print(*data.states, sep='\t', file=out)
        print(*data.states, sep='\t')
        for i, state in enumerate(result.transitionTableFinal):
            print(data.states[i], *state, sep='\t', file=out)
            print(data.states[i], *state, sep='\t')
        print('--------', file=out)
        print('--------')
        print('\t', *data.char, sep='\t', file=out)
        print('\t', *data.char, sep='\t')
        for i, state in enumerate(result.emissionTableFinal):
            print(data.states[i], *state, sep='\t', file=out)
            print(data.states[i], *state, sep='\t')
    
    
if __name__ == "__main__":
    main('/home/nick_rose/Downloads/rosalind_ba10k (1).txt')

A	B	C	D
A	1.5089909263353943e-11	0.5562878510215516	0.44343867641921303	0.00027347254414543067
B	0.10410680825067568	0.0001481708094359903	0.3595029434962177	0.5362420774436706
C	6.045930752178501e-05	0.3782820097116574	0.06678439178301289	0.554873139197808
D	0.43219736664710473	0.3912004961643097	0.00016279248796387317	0.1764393447006217
--------
		x	y	z
A	1.350567520953893e-11	1.1730308979761306e-10	0.9999999998691914
B	4.5888533329638097e-07	0.7265666067421621	0.273432934372505
C	0.7608067292444124	0.23363843777360316	0.0055548329819842986
D	0.7310130540498011	0.05280887668095777	0.2161780692692416
