In [1]:
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from scipy import optimize as opt
import Sequence_Analysis_Routines as sar
import ete3
from joblib import Parallel, delayed
from tqdm import tqdm



In [2]:
# Project locations and the reference genome accession used throughout the notebook.
# All paths are derived from project_dir, so only that line needs editing on another machine.
project_dir = 'D:/Project_Data/Project_3'
output_dir = f'{project_dir}/Output/Close_Species'
non_cds_output_dir = f'{output_dir}/Multiple_Alignment_Data/Non_CDS'
tb_species = 'GCF_000195955.2'

In [3]:
outgroup_species = 'NC_008596.1'

# The integer before the first '.' in each alignment file name is used as the group id.
file_ids = sar.list_files(f'{non_cds_output_dir}/')
ids = [int(file_name.split('.')[0]) for file_name in file_ids]
#ids.remove(1559)  #Contains S in alignment!

# Load the concatenated tree and remove the node named after the outgroup accession,
# so that the tree handed to the HMM fit no longer contains it.
master_tree = ete3.Tree(f'{output_dir}/Trees/Concatenated_JC_Tree.treefile')
outgroup = master_tree.search_nodes(name=outgroup_species)[0]
outgroup.delete()

In [4]:
# Read every alignment once up front and keep it in memory, keyed by group id,
# so the objective function can look alignments up instead of re-reading files.
group_ids = ids
align_dict = {}
for group_id in tqdm(group_ids):
    alignment = sar.Alignment(f'{non_cds_output_dir}/{group_id}.fasta', tb_species, 'NT')
    alignment.modify_sequence(1, False, True)
    align_dict[group_id] = alignment

100%|██████████| 2225/2225 [00:25<00:00, 87.56it/s] 


In [5]:
def fit_hmm(params, group_ids, num_subsets, subset_num):
    """Return the summed negative Viterbi log probability for one subset of alignments.

    A two-state HMM is built for every alignment in the subset and decoded with
    Viterbi; the negated log probabilities are accumulated so the result can be
    minimised by an optimiser.

    Parameters
    ----------
    params : sequence of float
        params[0] and params[1] are the self-transition probabilities of
        state 0 and state 1 (the diagonal of the transition matrix). The last
        two entries are handed to ``sar.mutation_probs``.
    group_ids : list
        All group ids; ``sar.chunk_list`` selects the portion handled here.
    num_subsets : int
        Number of chunks the id list is divided into.
    subset_num : int
        Which chunk this call processes (passed straight to ``sar.chunk_list``).

    Returns
    -------
    float
        Sum over the processed alignments of ``-viterbi_log_probability``
        (0 if every alignment in the chunk was skipped).

    Notes
    -----
    Reads the notebook-level ``align_dict`` and ``master_tree``.
    """
    num_symbols = 4    # Inserts are randomised
    num_states = 2
    flank = 50         # columns dropped from each end of every alignment
    min_length = 10    # alignments shorter than this after trimming are skipped
    initial_state_probabilities = [1.0/num_states]*num_states

    neg_log_likelihood = 0
    for group_id in sar.chunk_list(group_ids, num_subsets, subset_num):
        alignment = align_dict[group_id]
        sequences = alignment.modified_sequence_list
        names = alignment.sequence_names

        width = len(sequences[0])
        non_cds = [seq[flank:width - flank] for seq in sequences]
        if len(non_cds[0]) < min_length:
            continue

        # Each row sums to one: the diagonal holds the "stay" probabilities.
        stay_0, stay_1 = params[0], params[1]
        transition_probabilities = np.array([
            [stay_0, 1 - stay_0],
            [1 - stay_1, stay_1],
        ])
        observation_probabilities = sar.mutation_probs(
            params[-num_states:], non_cds, names, master_tree, num_symbols
        )

        trial_hmm = sar.HMM(initial_state_probabilities, transition_probabilities, observation_probabilities)
        trial_hmm.viterbi()
        neg_log_likelihood -= trial_hmm.viterbi_log_probability
    return neg_log_likelihood

In [6]:
def parallel_fit_hmm(params):
    """Evaluate the HMM objective over all groups, splitting the work into 16 chunks.

    ``fit_hmm`` is called once per chunk (subset numbers 1..16) through joblib
    with ``n_jobs=-1``, and the per-chunk results are added together. The
    parameter vector and the total are printed on every call so optimiser
    progress is visible in the cell output.

    Parameters
    ----------
    params : sequence of float
        Parameter vector forwarded unchanged to ``fit_hmm``.

    Returns
    -------
    float
        Sum of the per-chunk ``fit_hmm`` results.
    """
    num_cores = 16
    jobs = (
        delayed(fit_hmm)(params, group_ids, num_cores, subset_index)
        for subset_index in range(1, num_cores + 1)
    )
    chunk_scores = Parallel(n_jobs=-1)(jobs)
    objective = sum(chunk_scores)
    print(params, objective)
    return objective

In [12]:
# Earlier attempt with a different parameterisation, kept for reference:
#res = opt.minimize(parallel_fit_hmm, (0.01, 9, 4, 0.0001), method = 'Nelder-Mead', bounds = ((0.001,0.999),(0.1,10),(0.1,10), (0.1, 10)))
#print(res.x)

# Local search from a hand-picked start. The first two parameters are the
# self-transition probabilities and are bounded to [0.001, 0.999]; the last two
# are the values handed to sar.mutation_probs and are bounded to [0.001, 10].
res = opt.minimize(
    parallel_fit_hmm,
    x0=(0.97, 0.97, 8, 0.2),
    method='Nelder-Mead',
    bounds=(
        (0.001, 0.999),
        (0.001, 0.999),
        (0.001, 10),
        (0.001, 10),
    ),
)
print(res.x)

[0.97 0.97 8.   0.2 ] 1250545.824715568
[0.999 0.97  8.    0.2  ] 1253999.829653256
[0.97  0.999 8.    0.2  ] 1256529.5785359123
[0.97 0.97 8.4  0.2 ] 1251516.399687634
[0.97 0.97 8.   0.21] 1250180.1211393552
[0.9845 0.941  8.2    0.205 ] 1249896.1613026552
[0.99175 0.912   8.3     0.2075 ] 1250832.4604396436
[0.94825 0.9555  8.3     0.2075 ] 1251676.3941608756
[0.9609375 0.959125  8.225     0.205625 ] 1250801.5201093212
[0.97271875 0.9500625  7.8125     0.2103125 ] 1248967.9144002988
[0.97407813 0.94009375 7.51875    0.21546875] 1248160.9854173101
[0.98835156 0.95142187 7.634375   0.20960938] 1249073.2725551643
[0.98846484 0.93125781 7.6765625  0.22003906] 1248552.7049790146
[0.99769727 0.91188672 7.51484375 0.21505859] 1251228.6007500628
[0.97692432 0.95547168 7.87871094 0.21126465] 1249112.1095262559
[0.97940942 0.94812256 7.15419922 0.22319092] 1247753.3836777667
[0.97686414 0.95168384 6.63129883 0.23228638] 1247953.8618308683
[0.98822766 0.92997632 7.11323242 0.2228894 ] 1247930.

In [None]:
# Rebuild the HMM at the fitted optimum and plot the Viterbi state path for a
# single alignment. Everything here mirrors fit_hmm so that the decoded model is
# the same one the parameters were fitted under.
fitted_parameters = res.x
num_symbols = 4    # same alphabet size as in fit_hmm
num_states = 2
# fit_hmm keeps this as a local, so it has to be defined here as well.
initial_state_probabilities = [1.0/num_states]*num_states
# Same convention as fit_hmm: parameters 0 and 1 are the self-transition
# probabilities and sit on the diagonal.
transition_probabilities = np.array([
    [fitted_parameters[0], 1 - fitted_parameters[0]],
    [1 - fitted_parameters[1], fitted_parameters[1]],
])
group_id = 1167  # other groups inspected previously: 1569, 1505
alignment = sar.Alignment(non_cds_output_dir+'/'+str(group_id)+'.fasta', tb_species, 'NT')
# Same preprocessing arguments as the loading cell that produced the fitted data.
alignment.modify_sequence(1, False, True)
alignment_list = alignment.modified_sequence_list
alignment_names = alignment.sequence_names
# NOTE(review): the full alignment is decoded here, whereas fit_hmm drops 50
# columns from each end before fitting - confirm that this is intended.
observation_probabilities = sar.mutation_probs(
    fitted_parameters[-num_states:], alignment_list, alignment_names, master_tree, num_symbols
)
fitted_hmm = sar.HMM(initial_state_probabilities, transition_probabilities, observation_probabilities)
fitted_hmm.viterbi()
print(fitted_hmm.viterbi_log_probability)
plt.plot(fitted_hmm.viterbi_path);

In [11]:
# Single evaluation of the objective at the same point used as the Nelder-Mead
# starting guess; prints the parameters and the summed objective value.
parallel_fit_hmm([0.97, 0.97, 8, 0.2])

[0.97, 0.97, 8, 0.2] 1250545.824715568


1250545.824715568

In [16]:
# Global search over the same box used for the Nelder-Mead run.
# Note that this rebinds `res`, replacing the earlier optimisation result.
res = opt.shgo(
    parallel_fit_hmm,
    bounds=(
        (0.001, 0.999),
        (0.001, 0.999),
        (0.001, 10),
        (0.001, 10),
    ),
)
print(res.x)

[0.001 0.001 0.001 0.001] 3880032.869238109
[ 0.999  0.999 10.    10.   ] 1366251.5872554705
[0.999 0.001 0.001 0.001] 3880032.869238109
[0.999 0.999 0.001 0.001] 3880032.869238109
[9.99e-01 9.99e-01 1.00e+01 1.00e-03] 1293500.6782582172
[9.99e-01 9.99e-01 1.00e-03 1.00e+01] 1293500.6782582172
[9.99e-01 1.00e-03 1.00e+01 1.00e-03] 1365176.1920964515
[9.99e-01 1.00e-03 1.00e+01 1.00e+01] 1366251.5872554705
[9.99e-01 1.00e-03 1.00e-03 1.00e+01] 2129956.3775916146
[0.001 0.999 0.001 0.001] 3880032.869238109
[1.00e-03 9.99e-01 1.00e+01 1.00e-03] 2129956.3775916146
[1.00e-03 9.99e-01 1.00e+01 1.00e+01] 1366251.5872554705
[1.00e-03 9.99e-01 1.00e-03 1.00e+01] 1365176.1920964515
[1.e-03 1.e-03 1.e+01 1.e-03] 2041694.4165698334
[1.e-03 1.e-03 1.e+01 1.e+01] 1366251.5872554705
[1.e-03 1.e-03 1.e-03 1.e+01] 2041694.4165698334
[0.5    0.5    5.0005 5.0005] 1418673.2737656874
[9.99e-01 9.99e-01 1.00e+01 1.00e-03] 1293500.6782582172
[9.98999985e-01 9.99000000e-01 1.00000000e+01 1.00000000e-03] 1293