In [232]:
from Comparative_Analysis import Sequence_Analysis_Routines as sar
from Comparative_Analysis import HMM as hmm
from Comparative_Analysis import Utilities as util
from Comparative_Analysis import Alignment_HMM as alignment_hmm
from Comparative_Analysis import Alignment_Analysis as alignment_analysis
from Comparative_Analysis import Alignment as align
from numpy.random import default_rng
import numpy as np
from scipy.stats import norm
from scipy.stats import binom
from scipy import optimize as opt
import seaborn as sns
import math
from tqdm import tqdm

In [233]:
rng = default_rng()

In [234]:
def create_transition_matrix(params):
    a = params[0]
    b = (1-params[0])*(params[1])
    c = 1-a-b
    e = params[2]
    d = (1-params[2])*(params[3])
    f = 1-e-d
    i = params[4]
    g = (1-params[4])*(params[5])
    h = 1 - i - g
    transition_probabilities = np.array([[a,b,c],[d,e,f],[g,h,i]])
    mutation_probabilities = params[6:]
    return transition_probabilities, mutation_probabilities

In [235]:
def sim_multinomial(probs):
    a = np.where(rng.multinomial(1, probs) == 1)[0][0]
    return a

In [236]:
def sum_logs(p, q):
        if p>9999 and q>99999:
            ans = math.log(math.exp(p) + math.exp(q))
        else:
            if p > q:
                ans =  p + math.log(1 + math.exp(q - p))
            else:
                ans =  q + math.log(1 + math.exp(p - q))
        return ans

In [237]:
def normal_draw(state, means):
    return rng.standard_normal() + means[state]

In [238]:
def binomial_draw(state, probs, size):
    return rng.binomial(size, probs[state])

##### Simulate from HMM

In [239]:
sample_size = 10000

In [240]:
num_states = 3
num_comparison_sequences = 10
means = [2, 7, 10]
mutation_probs = [0.9, 0.5, 0.1]
initial_probs = [0.333, 0.333, 0.334]
transition_matrix = np.array([[0.9, 0.075, 0.025], [0.7, 0.2, 0.1], [0.5, 0.3, 0.2]])

In [267]:
def calculate_observation_probabilities(observations, mutation_probs):
    observation_probs = np.zeros((num_states, sample_size))
    for i in range(sample_size):
        for state in range(num_states):
            observation_probs[state, i] = binom.pmf(observations[i],  num_comparison_sequences, mutation_probs[state])
    return observation_probs 

In [268]:
states = np.zeros(sample_size)
observation_probabilities = np.zeros((num_states, sample_size))
for i in range(sample_size):
    if i == 0:
        current_state = sim_multinomial(initial_probs)
    else:
        current_state = sim_multinomial(transition_matrix[current_state,:])
    states[i] = current_state
    #observations[i] = normal_draw(current_state, means)
    observations[i] = binomial_draw(current_state, mutation_probs, num_comparison_sequences)
    
observation_probabilities = calculate_observation_probabilities(observations, mutation_probs)

In [269]:
def calculate_likelihood(params):
    trans_matrix, mutation_probabilities = create_transition_matrix(params)
    observation_probabilities = calculate_observation_probabilities(observations, mutation_probabilities)
    hmm_model = hmm.HMM(initial_probs, trans_matrix, observation_probabilities)
    hmm_model.calculate_probabilities()
    print(trans_matrix)
    print(mutation_probabilities)
    print (hmm_model.forward_ll * -1)

    return hmm_model.forward_ll * -1

In [270]:
params = [0.95, 0.5, 0.95, 0.5, 0.95, 0.5, 0.8, 0.7, 0.6]
bound_tuple = [(0.001,0.999),(0.001,0.999),(0.001,0.999),(0.001,0.999),(0.001,0.999),(0.001,0.999),(0.001,0.999),(0.001,0.999),(0.001,0.999)]

In [271]:
res = opt.minimize(calculate_likelihood, params, method = 'Nelder-Mead', bounds = bound_tuple)

[[0.95  0.025 0.025]
 [0.025 0.95  0.025]
 [0.025 0.025 0.95 ]]
22232.19254363934
22232.192543639496
[[0.9975  0.00125 0.00125]
 [0.025   0.95    0.025  ]
 [0.025   0.025   0.95   ]]
22329.64107432685
22329.64107432675
[[0.95    0.02625 0.02375]
 [0.025   0.95    0.025  ]
 [0.025   0.025   0.95   ]]
22240.688647706982
22240.68864770696
[[0.95    0.025   0.025  ]
 [0.00125 0.9975  0.00125]
 [0.025   0.025   0.95   ]]
22289.05013908625
22289.05013908643
[[0.95    0.025   0.025  ]
 [0.02625 0.95    0.02375]
 [0.025   0.025   0.95   ]]
22230.670812466593
22230.670812466684
[[0.95    0.025   0.025  ]
 [0.025   0.95    0.025  ]
 [0.00125 0.00125 0.9975 ]]
22667.828527636928
22667.828527636826
[[0.95    0.025   0.025  ]
 [0.025   0.95    0.025  ]
 [0.02625 0.02375 0.95   ]]
22222.618976624235
22222.61897662433
[[0.95  0.025 0.025]
 [0.025 0.95  0.025]
 [0.025 0.025 0.95 ]]
20814.494954065773
20814.494954065853
[[0.95  0.025 0.025]
 [0.025 0.95  0.025]
 [0.025 0.025 0.95 ]]
22222.639688893963


KeyboardInterrupt: 

In [261]:
for iter in tqdm(range(100)):
    if iter == 0:
        transition_probabilities, mutation_probabilities = create_transition_matrix(params)
    else:
        transition_probabilities = transition_counts
        mutation_probabilities = mutation_counts
    observation_probabilities = calculate_observation_probabilities(observations, mutation_probabilities)
    hm_model = hmm.HMM(initial_probs, transition_probabilities, observation_probabilities)
    hm_model.calculate_probabilities()
    if iter > 1 and abs(total_probability - (hm_model.forward_ll * -1)) < 0.01:
        break
    total_probability = hm_model.forward_ll * -1
    prob_observation = hm_model.forward_ll
    transition_counts = np.zeros((num_states, num_states))
    mutation_counts = np.zeros(num_states)
    for s in range(num_states):
        for t in range(num_states):
            temp = 0
            for i in range(sample_size - 1):
                if i == 0:
                    temp = hm_model.forward_probabilities[s, i] + math.log(transition_probabilities[s, t]) + math.log(observation_probabilities[t, i+1]) + hm_model.backward_probabilities[t, i+1]
                else:
                    temp = sum_logs(temp, hm_model.forward_probabilities[s, i] + math.log(transition_probabilities[s, t]) + math.log(observation_probabilities[t, i+1]) + hm_model.backward_probabilities[t, i+1])
            transition_counts[s, t] += math.exp(temp - prob_observation)

    for s in range(num_states):
        temp_1 = 0
        for t in range(num_states):
            temp_1 += transition_counts[s, t]
        for t in range(num_states):
            transition_counts[s, t] = transition_counts[s, t] / temp_1
  
    for s in range(num_states):
        temp_1 = 0; temp_2 = 0
        for i in range(sample_size - 1):
            temp_1 += hm_model.state_probabilities[s][i] * observations[i] / num_comparison_sequences
            temp_2 += hm_model.state_probabilities[s][i]
        mutation_counts[s] = temp_1 / temp_2
    if iter % 10 == 0:
        print(total_probability)  
        print(transition_counts, mutation_counts)

print("Final Fit....")
print(total_probability)  
print(transition_counts, mutation_counts)


  1%|          | 1/100 [00:02<03:45,  2.27s/it]

22325.14866042927
[[0.97121204 0.00512653 0.02366143]
 [0.1805367  0.76857426 0.05088904]
 [0.26761117 0.01676357 0.71562526]] [0.87059523 0.69426885 0.4198027 ]


 11%|█         | 11/100 [00:25<03:23,  2.29s/it]

17104.04085615522
[[0.91044171 0.0167624  0.07279589]
 [0.54571637 0.3207572  0.13352643]
 [0.65355702 0.05074317 0.2956998 ]] [0.89648957 0.62953887 0.28806192]


 21%|██        | 21/100 [00:48<03:01,  2.30s/it]

16942.112333075296
[[0.89454169 0.06463809 0.04082022]
 [0.7234141  0.15329125 0.12329465]
 [0.56156485 0.21407671 0.22435844]] [0.90047621 0.54839929 0.18294303]


 31%|███       | 31/100 [01:11<02:39,  2.31s/it]

16889.782222269492
[[0.89636346 0.07614049 0.02749606]
 [0.72730386 0.17365137 0.09904478]
 [0.51417858 0.30131329 0.18450813]] [0.90022816 0.49808778 0.11820549]


 41%|████      | 41/100 [01:34<02:16,  2.32s/it]

16887.331274073444
[[0.89890615 0.07559272 0.02550113]
 [0.72471941 0.18259867 0.09268192]
 [0.51122381 0.31100404 0.17777215]] [0.89961007 0.4830128  0.10719061]


 45%|████▌     | 45/100 [01:45<02:09,  2.35s/it]

Final Fit....
16887.25569683104
[[0.89921639 0.07551292 0.02527069]
 [0.72442574 0.18362342 0.09195084]
 [0.51092857 0.3120968  0.17697463]] [0.89953085 0.4811928  0.10593838]





In [251]:
observation_probabilities

array([[0.31849633, 0.31849633, 0.00237329, ..., 0.31849633, 0.1593297 ,
        0.29263245],
       [0.31849633, 0.31849633, 0.00237329, ..., 0.31849633, 0.1593297 ,
        0.29263245],
       [0.31849598, 0.31849598, 0.00237331, ..., 0.31849598, 0.15933003,
        0.29263259]])

array([ 7.,  8., 10., 10.,  8.,  9.,  4.,  5.,  9.,  8., 10.,  8.,  9.,
        6.,  7.,  9., 10.,  9., 10.,  8., 10.,  9.,  8.,  9.,  9., 10.,
        8.,  8.,  1.,  3., 10.,  9.,  9.,  8.,  9.,  9., 10., 10.,  9.])