In [9]:
import os
import math

class MEMEoops:
    """Learn a fixed-width sequence motif with the MEME OOPS model.

    The class bundles the steps needed to choose an EM starting point: reading
    the input sequences and enumerating every width-W candidate, building an
    initial position weight matrix (PWM) from a candidate, one E step, one M
    step, the log-likelihood of the data under the current PWM, and the search
    over all candidates for the best starting PWM.

    A PWM is a dict mapping each character of "ACGT" to a list of W + 1
    probabilities: index 0 is the background column and indices 1..W are the
    motif columns.
    """
    def __init__(self,input_path,W,model="OOPS"):
        """Store the model parameters; no file is read here.

        Args:
            input_path: path to a text file with one sequence per line
            W: motif width
            model: model name; only "OOPS" (any letter case) is accepted by
                enum_candidate()
        """
        self.input_path = input_path
        self.W = W
        self.model = model
        self.best_candidate = None
        self.best_candidate_loglikeli = None
        self.init_pwm = None
        self.init_z = None
        self.starting_pwm = None

    def enum_candidate(self):
        """Read the input file and enumerate every width-W motif candidate.

        Every non-blank line of the file is one sequence; each of its
        length-W windows is a candidate.

        Returns:
            None. Saves seq_list (sequences in file order), candidate_list
            (unique candidates, sorted so that later tie-breaking is
            reproducible) and char_count (total A/C/G/T counts) as instance
            attributes.

        Raises:
            AssertionError: if the model name is not "OOPS".
            ValueError: if a sequence contains a character other than
                A, C, G or T. No attribute is saved in that case.
        """
        assert self.model.upper() == "OOPS","Please provide valid model name. Note OOPS is the only model available at this time point."
        seq_list = []
        candidate_list = []
        with open(self.input_path) as f:
            for line in f:
                tmp_seq = line.strip()
                if not tmp_seq:
                    # A blank line is not a sequence; keeping it as "" would
                    # make obtain_likelihood() take log(0) later.
                    continue
                seq_list.append(tmp_seq)
                for i in range(len(tmp_seq) - self.W + 1):
                    candidate_list.append(tmp_seq[i:i + self.W])

        # Count characters (and validate) before saving any state.
        char_count = {"A":0,"C":0,"G":0,"T":0}
        for seq in seq_list:
            for g in seq:
                if g not in char_count:
                    raise ValueError(
                        "Invalid sequence detected: %r contains %r; "
                        "only A, C, G and T are allowed." % (seq, g)
                    )
                char_count[g] += 1

        self.seq_list = seq_list
        self.candidate_list = sorted(set(candidate_list))
        self.char_count = char_count

    def init_pnz(self,candidate_motif,row_name = "ACGT",pi=0.7):
        """Build the initial PWM and z matrix for one motif candidate.

        Column 0 of the PWM (background) is 0.25 for every character. In motif
        column i the candidate's i-th character gets probability pi and every
        other character gets (1 - pi) / (len(row_name) - 1), rounded to 5
        decimals.

        Args:
            candidate_motif: the candidate motif string over A, C, G, T
            row_name: all possible characters in order, default "ACGT"
            pi: the initial probability of the dominant character, default 0.7
        Returns:
            None. Saves the initial PWM as init_pwm and the initial z matrix
            (one row of W entries, each 0.25, per sequence) as init_z.
        Raises:
            AssertionError: if the candidate contains a character outside "ACGT".
        """
        for c in candidate_motif:
            assert c in "ACGT"
        # pwm init
        motif_len = len(candidate_motif)
        pwm = {row_name[0]:[0] * (motif_len + 1),
               row_name[1]:[0] * (motif_len + 1),
               row_name[2]:[0] * (motif_len + 1),
               row_name[3]:[0] * (motif_len + 1)}
        pwm[row_name[0]][0] = pwm[row_name[1]][0] = pwm[row_name[2]][0] = pwm[row_name[3]][0] = 0.25
        for i, motif_char in enumerate(candidate_motif):
            for j in row_name:
                if j == motif_char:
                    pwm[j][i+1] = pi
                else:
                    pwm[j][i+1] = round((1-pi)/(len(row_name)-1),5)
        # z init
        z = [[0.25] * self.W for _ in self.seq_list]

        # save init_z and init_pwm
        self.init_pwm = pwm
        self.init_z = z

    ##############################
    ### E: reestimate z with p ###
    ##############################
    def E(self,pwm):
        """The E step - re-estimate the z matrix (t) from an input PWM (t-1).

        For every sequence, entry j of its z row is the probability of the
        sequence with the motif starting at position j, normalized so the row
        sums to 1.

        Args:
            pwm: the PWM at time (t-1)

        Returns:
            None. Saves the re-estimated z matrix as cur_z. If init_pwm has
            not been set yet, prints a hint and does nothing.
        """
        if self.init_pwm is None:
            print("Please run picking_starting_point() to pick a starting postion first.")
            return
        z_new = []
        for tmp_seq in self.seq_list:
            tmp_z_row = [self.seq_prob_helper(tmp_seq, pwm, j)
                         for j in range(len(tmp_seq) - self.W + 1)]
            z_new.append(self.z_normalize_helper(tmp_z_row))
        self.cur_z = z_new

    def z_normalize_helper(self,tmp_z_row):
        """Normalize one raw z row so that its entries sum to 1.

        Args:
            tmp_z_row: list of unnormalized values for one sequence
        Returns:
            a new list with every entry divided by the row total (an empty
            row is returned as an empty list)
        Raises:
            ZeroDivisionError: if the row is non-empty and sums to zero
        """
        # Plain left-to-right accumulation, kept on purpose so the numeric
        # result does not depend on how the built-in sum() treats floats.
        total = 0
        for v in tmp_z_row:
            total += v
        return [z/total for z in tmp_z_row]

    def seq_prob_helper(self, individual_seq,pwm,j):
        """Probability of a sequence given that the width-W motif starts at j.

        Positions j .. j + W - 1 are scored with motif columns 1 .. W of the
        PWM; every other position is scored with the background column 0.

        Args:
            individual_seq: the input sequence
            pwm: the current PWM
            j: the target starting position of the motif (of length W)
        Returns:
            the product of the per-position probabilities
        """
        multiply = 1
        t = 0
        for p in range(len(individual_seq)):
            if p >= j and p <= j + self.W - 1:
                t+=1
                multiply = multiply * pwm[individual_seq[p]][t]
            else:
                multiply = multiply * pwm[individual_seq[p]][0]
        return multiply

    ##############################
    ### M: reestimate p with z ###
    ##############################
    def M(self,z,sudo_count = 1, row_name = "ACGT"):
        """The M step: use the current z matrix to re-estimate the PWM.

        Args:
            z: the current z matrix (one row per sequence)
            sudo_count: pseudo-count added to every cell before each column
                is normalized, default 1
            row_name: all possible characters in order, default "ACGT"
        Returns:
            None. Saves the re-estimated, normalized PWM as cur_pwm.
        Raises:
            AssertionError: if z does not have one row per sequence.
        """
        assert len(self.seq_list) == len(z)
        new_pwm = {row_name[0]:[0] * (self.W + 1),
                   row_name[1]:[0] * (self.W + 1),
                   row_name[2]:[0] * (self.W + 1),
                   row_name[3]:[0] * (self.W + 1)}

        # Accumulate expected counts: a window starting at j contributes its
        # weight z[seq][j] to the character seen at each motif column.
        for seq_tmp_idx, seq_tmp in enumerate(self.seq_list):
            for j in range(len(seq_tmp)- self.W + 1):
                motif_tmp = seq_tmp[j:(j + self.W)]
                for k, tmp_char in enumerate(motif_tmp):
                    new_pwm[tmp_char][k+1] += z[seq_tmp_idx][j]

        # Background count = total occurrences minus what the motif columns took.
        for c in row_name:
            new_pwm[c][0] = self.char_count[c] - sum(new_pwm[c][1:])

        # Normalize with the requested pseudo-count (previously the argument
        # was accepted but a pseudo-count of 1 was always used).
        self.cur_pwm = self.normlize_pwd(new_pwm, pseudo_count=sudo_count)

    def normlize_pwd(self,pwm,row_name = "ACGT",pseudo_count = 1):
        """Normalize a PWM of counts column by column, in place.

        Every cell becomes (pseudo_count + cell) / (4 * pseudo_count + column
        total), rounded to 5 decimals.

        Args:
            pwm: the input PWM (dict of four lists with W + 1 entries each)
            row_name: all possible characters in order, default "ACGT"
            pseudo_count: value added to every cell, default 1
        Returns:
            the same pwm object, now normalized
        Raises:
            AssertionError: if pwm does not have exactly four rows.
            ZeroDivisionError: if pseudo_count is 0 and a column sums to zero.
        """
        assert len(pwm) == 4
        for i in range(self.W + 1):
            tmp_sum = 4 * pseudo_count + pwm[row_name[0]][i] + pwm[row_name[1]][i] + pwm[row_name[2]][i] + pwm[row_name[3]][i]
            for c in row_name:
                pwm[c][i]  = round((pseudo_count + pwm[c][i])/tmp_sum,5)
        return pwm

    ########################################################
    ### likelihood: calculate likelihood using updated P ###
    ########################################################
    def obtain_likelihood(self):
        """Log_e probability of all sequences under the current PWM (cur_pwm).

        For each sequence the probabilities over all motif start positions are
        summed, divided by the number of sequences, and the natural log of
        that value is added to the total.

        Returns:
            the accumulated log value
        Raises:
            ValueError: from math.log if a sequence has total probability 0
                (for example a sequence shorter than W).
        """
        # NOTE(review): the per-sequence sum is divided by the number of
        # sequences; presumably a uniform prior over start positions
        # (1 / number of positions) was intended - confirm. Left unchanged
        # because it changes every reported likelihood value.
        log_prob_accum = 0
        for tmp_seq in self.seq_list:
            tmp_prob = 0
            for j in range(len(tmp_seq)- self.W + 1):
                tmp_prob += self.seq_prob_helper(tmp_seq,self.cur_pwm,j)
            log_prob_accum += math.log(tmp_prob/len(self.seq_list))
        return log_prob_accum

    ###############################
    ### picking starting point: ###
    ###############################
    def picking_starting_point(self):
        """Try every motif candidate and keep the one with the highest log_e likelihood.

        For each candidate: build its initial PWM, run one E step and one M
        step, and score the data with obtain_likelihood(). The first candidate
        with the highest score is saved as best_candidate, its score as
        best_candidate_loglikeli and its initial PWM as starting_pwm.

        Does nothing except print a message if best_candidate is already set.
        """
        if self.best_candidate is not None:
            print("Starting point already exists: thus passing")
            return
        print("Total number of candidate: ", len(self.candidate_list))
        for candidate in self.candidate_list:
            self.init_pnz(candidate)
            self.E(self.init_pwm)
            self.M(self.cur_z)
            tmp_likelihood = self.obtain_likelihood()
            if self.best_candidate_loglikeli is None or tmp_likelihood > self.best_candidate_loglikeli:
                self.best_candidate = candidate
                self.best_candidate_loglikeli = tmp_likelihood
                self.starting_pwm = self.init_pwm
        print("The best candidate is : ", self.best_candidate)
        print("The corresponding likelihood is : ", self.best_candidate_loglikeli)


In [4]:
# Run on example1.txt with motif width 4: enumerate the candidates, then
# search them for the best starting point.
test1 = MEMEoops("example1.txt",W = 4)
test1.enum_candidate()
test1.picking_starting_point()

Total number of candidate:  239
The best candidate is :  ATCC
The corresponding likelihood is :  -1367.374319162761


In [None]:
# Same pipeline on example2.txt with motif width 10.
test6 = MEMEoops("example2.txt",W = 10)
test6.enum_candidate()
test6.picking_starting_point()

In [None]:
# select candidate
# Motif width passed to the standalone (non-class) helpers in the cells below.
W = 5

def select_candidate(input_path,W):
    """Read sequences from a file and enumerate all width-W motif candidates.

    Every non-blank line is one sequence. Blank lines are skipped; kept as
    empty sequences they would contribute no candidate and a zero probability
    to any later likelihood computation.

    Args:
        input_path: path to a text file with one sequence per line
        W: motif width
    Returns:
        (seq, candidates): the sequences in file order, and a list of the
        unique width-W substrings found in them (order unspecified)
    """
    seq = []
    candidate = []
    with open(input_path) as f:
        for line in f:
            tmp_seq = line.strip()
            if not tmp_seq:
                continue

            # save all training seqs for ref
            seq.append(tmp_seq)

            # enumerate every length-W window of this sequence
            candidate.extend(tmp_seq[i:i+W] for i in range(len(tmp_seq) - W + 1))

    return (seq,list(set(candidate)))

# Load example1.txt: test_seq holds the sequences, test_candidate the unique width-W windows.
test_seq, test_candidate = select_candidate("example1.txt",W)

In [None]:
# The return value of this call is discarded; the prints below use the
# test_seq / test_candidate values assigned in an earlier cell.
select_candidate("example1.txt",5)
print(len(test_seq))
print(len(test_candidate))
#print(len(test_candidate))
#test_candidate
# Check whether the motif used in the following cells is among the candidates.
"CACAT" in test_candidate

In [None]:
# init_pwm
def init_pnz(W,pi,candidate_motif,num_seq,row_name = "ACGT"):
    """Create the starting PWM and z matrix for one candidate motif.

    Args:
        W: motif width, used for the length of each z row
        pi: probability given to the candidate's own character in each motif column
        candidate_motif: candidate motif string over A, C, G, T
        num_seq: number of sequences (number of z rows)
        row_name: all possible characters in order, default "ACGT"
    Returns:
        (pwm, z): pwm maps each character to a list whose entry 0 is the
        background value 0.25 and whose entries 1..len(candidate_motif) are the
        motif columns; z has num_seq rows of W entries, each 0.25
    """
    assert all(base in "ACGT" for base in candidate_motif)

    # Background column first, motif columns filled in below.
    n_columns = len(candidate_motif)
    pwm = {row_name[idx]: [0.25] + [0] * n_columns for idx in range(4)}

    for column, motif_base in enumerate(candidate_motif, start=1):
        for symbol in row_name:
            if symbol != motif_base:
                # the remaining mass is shared equally by the other characters
                pwm[symbol][column] = round((1-pi)/(len(row_name)-1),5)
            else:
                pwm[symbol][column] = pi

    z = [[0.25] * W for _ in range(num_seq)]
    return (pwm,z)

In [None]:
# Initial PWM / z for the candidate "CACAT" with pi = 0.7, one z row per loaded sequence.
test_pwm, test_z = init_pnz(W,0.7,"CACAT",len(test_seq))

# Display the PWM as the cell output.
test_pwm

In [None]:
### E: reestimate z with p
# 
def z_normalize_helper(tmp_z_row):
    """Scale one z row so that its entries add up to 1.

    Args:
        tmp_z_row: list of unnormalized values
    Returns:
        a new list with each entry divided by the row total
    """
    row_total = 0
    for weight in tmp_z_row:
        row_total = row_total + weight

    scaled = []
    for weight in tmp_z_row:
        scaled.append(weight / row_total)
    return scaled
#
def seq_prob_helper(seq,pwm,W,j):
    """Probability of seq when the width-W motif starts at position j.

    Positions j .. j+W-1 are scored with PWM columns 1 .. W in order; all
    other positions are scored with the background column 0.

    Args:
        seq: the sequence string
        pwm: dict mapping each character to its list of column probabilities
        W: motif width
        j: motif start position
    Returns:
        the product of the per-position probabilities
    """
    likelihood = 1
    motif_column = 0
    for position, base in enumerate(seq):
        if position < j or position > j+W-1:
            column = 0
        else:
            # inside the motif window: advance to the next motif column
            motif_column += 1
            column = motif_column
        likelihood = likelihood * pwm[base][column]
    return likelihood

def seq_prob_helper_c(seq,pwm,W,j):
    """Debug variant of seq_prob_helper that prints every factor it multiplies.

    Args:
        seq: the sequence string
        pwm: dict mapping each character to its list of column probabilities
        W: motif width
        j: motif start position
    Returns:
        the product of the per-position probabilities; each factor is printed
        on its own line, followed by one empty line
    """
    likelihood = 1
    motif_column = 0
    for position, base in enumerate(seq):
        if position < j or position > j+W-1:
            column = 0
        else:
            motif_column += 1
            column = motif_column
        factor = pwm[base][column]
        print(factor)
        likelihood = likelihood * factor
    print()
    return likelihood

## 
def E(seq, pwm, W):
    """E step: recompute the z matrix from a PWM.

    Args:
        seq: list of sequence strings
        pwm: the PWM used to score each motif start position
        W: motif width
    Returns:
        a list with one row per sequence; row entries are the per-start
        probabilities from seq_prob_helper, normalized by z_normalize_helper
    """
    z_rows = []
    for sequence in seq:
        n_starts = len(sequence) - W + 1
        raw_row = [seq_prob_helper(sequence, pwm, W, start) for start in range(n_starts)]
        z_rows.append(z_normalize_helper(raw_row))
    return z_rows

In [None]:
# CACAT
# Display the last loaded sequence.
test_seq[-1]

In [None]:
# Display the initial PWM again for reference.
test_pwm

In [None]:
# For motif: CACAT
# Hand computation of one sequence probability: 15 background positions at
# 0.25 each, times five motif-column probabilities.
# NOTE(review): with pi = 0.7 the window C G C G T would match CACAT in
# columns 1, 3 and 5, but only the last factor below is 0.7 - confirm.
# GGGGTTTAGCCCTTC     C     G     C     G     T
0.25**15 *           0.1 * 0.1 * 0.1 * 0.1 * 0.7

In [None]:
#test_output_z = E(test_seq,test_pwm,5)

In [None]:
### M: reestimate p with z
def count_char_helper(seq):
    """Count how often each of A, C, G, T occurs across all sequences.

    Args:
        seq: iterable of sequence strings
    Returns:
        dict mapping "A", "C", "G", "T" to their total counts, or None (after
        printing a message) as soon as a sequence contains any other
        character. Previously the early return was unreachable and partial
        counts were returned instead.
    """
    char_count = {"A":0,"C":0,"G":0,"T":0}
    for i in seq:
        for g in i:
            if g not in char_count:
                print("Invalid sequence detected.")
                return None
            char_count[g]+=1
    return char_count
# Total A/C/G/T counts over the loaded sequences.
char_count = count_char_helper(test_seq)

#
def normlize_pwd(pwm,W,row_name = "ACGT"):
    """Normalize a PWM column by column, in place.

    Each cell is divided by the total of its column (columns 0..W) and
    rounded to 5 decimals. No pseudo-count is added here.

    Args:
        pwm: dict of four lists, one per character
        W: motif width; columns 0..W are normalized
        row_name: all possible characters in order, default "ACGT"
    Returns:
        the same pwm object after normalization
    """
    assert len(pwm) == 4
    for column in range(W+1):
        column_total = (
            pwm[row_name[0]][column]
            + pwm[row_name[1]][column]
            + pwm[row_name[2]][column]
            + pwm[row_name[3]][column]
        )
        for symbol in row_name:
            share = pwm[symbol][column] / column_total
            pwm[symbol][column] = round(share, 5)
    return pwm

#
def M(seq, z, W, sudo_count = 1, row_name = "ACGT"):
    """M step: re-estimate the PWM from the current z matrix.

    Args:
        seq: list of sequence strings
        z: z matrix with one row per sequence
        W: motif width
        sudo_count: accepted for interface compatibility; not used by this
            function (normlize_pwd adds no pseudo-count)
        row_name: all possible characters in order, default "ACGT"
    Returns:
        the PWM produced by normlize_pwd from the accumulated counts
    """
    assert len(seq) == len(z)

    counts = {row_name[idx]: [0] * (W+1) for idx in range(4)}

    # Each window starting at `start` adds its weight to the character seen
    # at every motif column.
    for sequence, z_row in zip(seq, z):
        for start in range(len(sequence)-W+1):
            window = sequence[start:(start+W)]
            for column, base in enumerate(window, start=1):
                counts[base][column] += z_row[start]

    # Background = total occurrences minus what went to the motif columns.
    totals = count_char_helper(seq)
    for base in row_name:
        counts[base][0] = totals[base] - sum(counts[base][1:])

    return normlize_pwd(counts,W)
    

# test_output_z was only assigned in a commented-out cell, so this call failed
# with NameError; compute the z matrix here from the initial PWM first.
test_output_z = E(test_seq,test_pwm,5)
M(test_seq,test_output_z,5)
#test_output_z[:1]


In [None]:
# calculate likelihold using updated P
def seq_prob_helper(seq,pwm,W,j):
    """Probability of seq when the width-W motif starts at position j.

    Characters inside the window j .. j+W-1 use PWM columns 1 .. W in order;
    characters outside it use the background column 0.

    Args:
        seq: the sequence string
        pwm: dict mapping each character to its list of column probabilities
        W: motif width
        j: motif start position
    Returns:
        the product of the per-position probabilities
    """
    product = 1
    used_columns = 0
    for index, symbol in enumerate(seq):
        outside_window = index < j or index > j+W-1
        if outside_window:
            product = product * pwm[symbol][0]
            continue
        used_columns += 1
        product = product * pwm[symbol][used_columns]
    return product

def obtain_likelihood(seq,pwm,W):
    """Log_e probability of all sequences under a PWM.

    For each sequence, the probabilities over every motif start position are
    added up, divided by the number of sequences, and the natural log of the
    result is added to the running total.

    Args:
        seq: list of sequence strings
        pwm: the PWM to score with
        W: motif width
    Returns:
        the accumulated log value
    """
    n_sequences = len(seq)
    total_log = 0
    for sequence in seq:
        sequence_prob = 0
        for start in range(len(sequence)-W+1):
            sequence_prob = sequence_prob + seq_prob_helper(sequence,pwm,W,start)
        total_log = total_log + math.log(sequence_prob/n_sequences)
    return total_log

In [None]:
# Log-likelihood of the loaded sequences under the initial PWM, motif width 5.
obtain_likelihood(test_seq,test_pwm,5)

In [None]:
def get_EM_starting_point(seq_list,candidate_list,W,pi):
    """Pick the candidate whose initial PWM gives the highest log-likelihood.

    Args:
        seq_list: list of sequence strings
        candidate_list: candidate motif strings to try
        W: motif width
        pi: probability given to the candidate's own character in init_pnz
    Returns:
        (pwm, candidate, log_likelihood) of the first best-scoring candidate;
        (None, "", None) when candidate_list is empty
    """
    best_log = None
    best_motif = ""
    best_pwm = None
    for motif in candidate_list:
        motif_pwm, _initial_z = init_pnz(W,pi,motif,len(seq_list))
        motif_log = obtain_likelihood(seq_list,motif_pwm,W)
        # keep the current best unless this candidate scores strictly higher
        if best_log is not None and not (motif_log > best_log):
            continue
        best_log, best_motif, best_pwm = motif_log, motif, motif_pwm
    return best_pwm,best_motif,best_log


In [None]:
# Search all enumerated candidates for the best starting PWM (width 5, pi = 0.7).
get_EM_starting_point(test_seq,test_candidate,5,0.7)

In [1]:
# Scratch arithmetic.
# NOTE(review): looks like an add-one-smoothed frequency, (count + 1) / (total + 4) - confirm.
(11 + 1) / (9*4 +4)

0.3

In [3]:
# Scratch arithmetic: three values of the form (count + 1) / (9*4 + 4).
# NOTE(review): presumably add-one-smoothed frequencies for the other characters - confirm.
print((8 + 1) / (9*4 +4))
print((12 + 1) / (9*4 +4))
print((5 + 1) / (9*4 +4))

0.225
0.325
0.15


In [7]:
# Scratch arithmetic: four values of the form (count + 1) / (9 + 4); the counts 1, 3, 3, 2 add up to 9.
# NOTE(review): presumably one add-one-smoothed PWM column - confirm.
print((1+1) / (9+4))
print((3+1) / (9+4))
print((3+1) / (9+4))
print((2+1) / (9+4))

0.15384615384615385
0.3076923076923077
0.3076923076923077
0.23076923076923078


In [8]:
# Scratch arithmetic: ratio of two four-term probability products.
# NOTE(review): presumably a motif-vs-background likelihood ratio for one 4-wide window - confirm.
(0.4615* 0.1538* 0.6154* 0.3077) / (0.225* 0.325* 0.325* 0.225)

2.5135137815873874

In [9]:
# Scratch arithmetic: ratio of two four-term probability products.
# NOTE(review): presumably a motif-vs-background likelihood ratio for one 4-wide window - confirm.
(0.2308 * 0.0769* 0.1538 * 0.3077 )/(0.325* 0.225* 0.225* 0.325)

0.15707758581662

In [10]:
# Scratch arithmetic: ratio of two four-term probability products.
# NOTE(review): presumably a motif-vs-background likelihood ratio for one 4-wide window - confirm.
(0.2308* 0.2308 * 0.1538 * 0.1538 )/(0.325* 0.325 * 0.225* 0.3 )

0.17673142739588513

In [11]:
# Scratch arithmetic: ratio of two four-term probability products.
# NOTE(review): presumably a motif-vs-background likelihood ratio for one 4-wide window - confirm.
(0.2308 * 0.0769 * 0.0769 * 0.3077  )/(0.325 * 0.225 * 0.3 * 0.325 )

0.058904094681232505

In [12]:
# Scratch arithmetic: ratio of two four-term probability products.
# NOTE(review): presumably a motif-vs-background likelihood ratio for one 4-wide window - confirm.
(0.4615 *  0.4615* 0.6154 * 0.2307)/(0.225 * 0.3 *  0.325 * 0.15  )

9.189039201718517

In [13]:
# Share of e in the total a + b + c + d + e.
# NOTE(review): a..e match the outputs of the five ratio cells above rounded
# to four decimals; presumably this normalizes them like one z row - confirm.
a = 0.1570
b = 0.1767
c = 0.0589
d = 9.1890
e = 2.5135
e/(a+b+c+d+e)

0.2078114277682698

In [14]:
# Scratch arithmetic: a four-term sum divided by (9+4) + (9+4) = 26.
# NOTE(review): presumably part of a hand-worked EM update - confirm.
(1 + 2 + 1 +1) / ((9+4) +(9+4)) 

0.19230769230769232

In [15]:
# Scratch arithmetic: a four-term sum divided by (9+4) + (9+4) = 26.
# NOTE(review): presumably part of a hand-worked EM update - confirm.
(5 +3 + 1+1) / ((9+4) +(9+4))

0.38461538461538464

In [16]:
# Scratch arithmetic: a four-term sum divided by (9+4) + (9+4) = 26.
# NOTE(review): presumably part of a hand-worked EM update - confirm.
(2+ 3 + 1 + 1) / ((9+4) +(9+4))

0.2692307692307692

In [17]:
# Scratch arithmetic: a four-term sum divided by (9+4) + (9+4) = 26.
# NOTE(review): presumably part of a hand-worked EM update - confirm.
(1 + 1 + 1 +1) / ((9+4) +(9+4))

0.15384615384615385

In [18]:
# Scratch arithmetic: a four-term sum divided by (9+4) + (9+4) = 26.
# NOTE(review): presumably part of a hand-worked EM update - confirm.
(5 + 1 + 1 +1) / ((9+4) +(9+4))

0.3076923076923077

In [19]:
# Scratch arithmetic: a four-term sum divided by (9+4) + (9+4) = 26.
# NOTE(review): presumably part of a hand-worked EM update - confirm.
(0 + 7 + 1 +1) / ((9+4) +(9+4))

0.34615384615384615

In [20]:
# Scratch arithmetic: a four-term sum divided by (9+4) + (9+4) = 26.
# NOTE(review): presumably part of a hand-worked EM update - confirm.
(1 + 1 + 1 +1) / ((9+4) +(9+4))

0.15384615384615385

In [21]:
# Scratch arithmetic: a four-term sum divided by (9+4) + (9+4) = 26.
# NOTE(review): presumably part of a hand-worked EM update - confirm.
(3 + 0 + 1 +1) / ((9+4) +(9+4))

0.19230769230769232

In [24]:
# Disabled loader for example2_subseqs.txt, kept for reference.
# import os
# with open(os.path.join("example2_subseqs.txt")) as f:
#     lines = f.readlines()
#     for line in lines:
#         tmp = line.strip()
# One list per character with 12 values each.
# NOTE(review): looks like the rows of a width-12 PWM (the four values at each
# index add up to about 1) for the 12-character sequences listed below - confirm.
A = [0.20683, 0.11878, 0.10153, 0.10733, 0.06868, 0.07549, 0.26911, 0.26915, 0.9026, 0.9055,  0.5344,  0.36969]
C = [0.40942, 0.0113,  0.02767, 0.78335, 0.29476, 0.01104, 0.0111,  0.17589, 0.01218,0.01571, 0.03985, 0.29997]
G = [0.11839, 0.04278, 0.28387, 0.04077, 0.02012, 0.59377, 0.70771, 0.53331, 0.0199, 0.06647, 0.22403, 0.10967]
T = [0.26536, 0.82714, 0.58692, 0.06855, 0.61644, 0.3197,  0.01208, 0.02165, 0.06533,0.01232, 0.20171, 0.22067]


CTTCTAGAAATG
TTTCCGGGAAAT
GTTCCGGGAATT
TGTCTGGAAATC
CTTCTTGGAGAC
TTTATTGGAAGA
ATTCCGGAAAGA
ATGCTGGGAAAT
CAGCGGGAAATC
ATTATGGATAGA
TTGCCGACAACA
ATACTGGGAAAC
TTTCTGACAATT
AGACTGGATAAA
CTTCCGGGAAAA
TATCCGGGAAAG
GTACTGGAAAAC
ATTCTGGGAAGT
AATCCGGGAGAC
GTCCTGACAGAC
GTGATAGAAAGT
ATACCGGGAAAT
CTTCTGACAAGA
ATTCCGGAAATG
ATGACGACAGAC
TTTCTGGGAAAA
CTGCTTGGAATT
CTTCCGACAACC
ATTCTTGGAAAA
ATTCTTGATAGA
CTTCTTGGAATC
TTGCTGGGAATC
ATGCTTGGAAAC
ATATTGGGAAAC
CTACATGGAAGA
TTTCCGGAAAAT
ATCCTGACAAAC
CTTCTTGGAAGC
GTTCTTGGAAAT
TTTATGGGAAGT
CTGCTGACAAAC
CATTTGATAAAA
TATCTGACAATT
CTTCTTGGAAAT
ATACTTGATAAA
TTTCTTGGAAAC
CTGCTGGGAAGT
GTTCCGGGAAGC
GTTCCGACAATA
CTTCAGAGAAAT
CTTTCGGGAATC
GTGGATAGAAAA
CTGCTGGGAAAC
CTGCCGGAAAAA
CTGCTTGAAAAA
CTTCATGGAAAT
CTTCCGACAATG
TATCTTGGAATC
TTGCTGGGAAGA
CTTCTGAGAAAG
TTGCTGGGAAAA
TATCTTGGAAAT
CTTCTGACAAAA
CTTGATAGAAGA
CTGCCAGAAAAG
TTTCTTGAAAAT
TTTACGGAGATA
ATACTGGAAATA
CGTATTGAAATG
CTTTCGACAAAT
CTTCTTGAAAAA
CTTATTAGACTT
CTTCTTGGAAAC
CTGCTGACAAGA
GTTCCGACAAGC
TTTCCGGGAGAC
TATCTTGAAAAA