In [1]:
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from scipy import optimize as opt
import Sequence_Analysis_Routines as sar
import ete3
from joblib import Parallel, delayed
from tqdm import tqdm



In [2]:
# Locations of the project data and the genome accessions used throughout the notebook.
project_dir = 'D:/Project_Data/Project_3'
output_dir = f'{project_dir}/Output/Close_Species'
non_cds_output_dir = f'{output_dir}/Multiple_Alignment_Data/Non_CDS'
# Accession passed to sar.Alignment for every alignment that is loaded.
tb_species = 'GCF_000195955.2'
# Accession of the tree node that is deleted from the master tree after loading.
outgroup_species = 'NC_008596.1'

##### Run IQTree on CDS alignments to generate tree

In [None]:
# Disabled cell, kept for reference: lists the alignment files and concatenates them into one FASTA.
# NOTE(review): `cds_output_dir` is not defined in the cells shown here - define it before re-enabling.
#alignment_names = sar.list_files(cds_output_dir)
#sar.concatenate_fasta(cds_output_dir, alignment_names, cds_output_dir + '/CDS/concatenated_cds.fasta')

In [None]:
# Disabled cell, kept for reference: the IQ-TREE command lines. The second one uses the
# prefix '/Trees/Concatenated_JC_Tree', matching the .treefile loaded further down.
# NOTE(review): `subprocess` is not imported and `cds_output_dir` / `outgroup_cds_output_dir`
# are not defined in the cells shown here - add them before re-enabling.
#subprocess.run('cd \\users\\nicho\\IQTree & bin\\iqtree2 -q ' + cds_output_dir + ' --prefix '+ output_dir + '/Trees/Full_Tree -m GTR+I+G -B 1000 -T AUTO', shell=True)
#subprocess.run('cd \\users\\nicho\\IQTree & bin\\iqtree2 -s ' + outgroup_cds_output_dir + '/CDS/concatenated_cds.fasta' + ' --prefix '+ output_dir + '/Trees/Concatenated_JC_Tree -m JC -B 1000 -T AUTO -o ' + outgroup_species, shell=True)

In [3]:
# Alignment ids are the integer part of each file name before its first '.'.
file_ids = sar.list_files(f'{non_cds_output_dir}/')
ids = list(map(lambda file_name: int(file_name.split('.')[0]), file_ids))
#ids.remove(1559)  #Contains S in alignment!

# Load the concatenated JC tree, then look up the outgroup node by name and delete it.
master_tree = ete3.Tree(f'{output_dir}/Trees/Concatenated_JC_Tree.treefile')
outgroup = master_tree.search_nodes(name=outgroup_species)[0]
outgroup.delete()

In [4]:
# Read every Non_CDS alignment once and keep it in memory keyed by id;
# fit_hmm looks alignments up in align_dict rather than re-reading the FASTA files.
group_ids = ids
align_dict = {}
for group_id in tqdm(group_ids):
    alignment = sar.Alignment(f'{non_cds_output_dir}/{group_id}.fasta', tb_species, 'NT')
    # Same positional arguments for every alignment so they are all pre-processed alike.
    alignment.modify_sequence(1, False, True)
    align_dict[group_id] = alignment

100%|██████████| 2225/2225 [00:28<00:00, 77.73it/s] 


In [9]:
def fit_hmm(params, group_ids, num_subsets, subset_num):
    """Sum the negated forward log-likelihood of a two-state HMM over one subset of alignments.

    Parameters
    ----------
    params : sequence of float
        ``params[0]`` and ``params[1]`` are the self-transition probabilities
        of state 0 and state 1 (the diagonal of the transition matrix).  The
        last two entries are handed unchanged to ``sar.mutation_probs``.
    group_ids : sequence
        Keys into the notebook-level ``align_dict``.
    num_subsets, subset_num :
        Forwarded to ``sar.chunk_list`` to pick the ids handled by this call
        (presumably "chunk ``subset_num`` of ``num_subsets``" - confirm in sar).

    Returns
    -------
    The sum of ``-forward_ll`` over the processed alignments; ``0`` when every
    alignment in the subset is skipped.

    Notes
    -----
    Reads the notebook globals ``align_dict`` and ``master_tree``.  Each
    alignment has 50 columns removed from both ends, and is skipped when fewer
    than 10 columns are left.
    """
    num_symbols = 4    # Inserts are randomised
    num_states = 2
    initial_state_probabilities = [1.0/num_states]*num_states
    negative_log_likelihood = 0
    for group_id in sar.chunk_list(group_ids, num_subsets, subset_num):
        alignment = align_dict[group_id]
        sequences = alignment.modified_sequence_list
        sequence_names = alignment.sequence_names
        full_length = len(sequences[0])
        # Drop the 50 flanking columns on either side of every sequence.
        trimmed = [sequence[50:full_length - 50] for sequence in sequences]
        if len(trimmed[0]) >= 10:
            stay_in_0 = params[0]
            stay_in_1 = params[1]
            # A fresh matrix per alignment: rows are the current state, each row sums to 1.
            transition_probabilities = np.array([
                [stay_in_0, 1 - stay_in_0],
                [1 - stay_in_1, stay_in_1],
            ])
            observation_probabilities = sar.mutation_probs(
                params[-num_states:], trimmed, sequence_names, master_tree, num_symbols
            )
            hmm = sar.HMM(initial_state_probabilities, transition_probabilities, observation_probabilities)
            hmm.forward()
            # The optimiser minimises, so accumulate the negated forward log-likelihood.
            negative_log_likelihood -= hmm.forward_ll
    return negative_log_likelihood

In [7]:
def parallel_fit_hmm (params):
    """Objective function: total of ``fit_hmm`` over all alignments, evaluated via joblib.

    The ids in the notebook-level ``group_ids`` are split into 16 subsets;
    ``fit_hmm`` is called once per subset (subset numbers 1..16) through
    ``Parallel(n_jobs=-1)`` and the partial results are added together.
    ``params`` and the total are printed on every call, then the total is
    returned.
    """
    num_subsets = 16
    jobs = (
        delayed(fit_hmm)(params, group_ids, num_subsets, subset_num)
        for subset_num in range(1, num_subsets + 1)
    )
    partial_totals = Parallel(n_jobs=-1)(jobs)
    total = sum(partial_totals)
    print(params, total)
    return total

In [12]:
#res = opt.minimize(parallel_fit_hmm, (0.01, 9, 4, 0.0001), method = 'Nelder-Mead', bounds = ((0.001,0.999),(0.1,10),(0.1,10), (0.1, 10)))
#print(res.x)
# Bounded Nelder-Mead search over the four parameters consumed by fit_hmm.
res = opt.minimize(
    parallel_fit_hmm,
    (0.97, 0.97, 8, 0.2),        # starting point
    method='Nelder-Mead',
    bounds=(
        (0.001, 0.999),          # params[0]: self-transition probability, state 0
        (0.001, 0.999),          # params[1]: self-transition probability, state 1
        (0.001, 10),             # params[2]: first value passed to sar.mutation_probs
        (0.001, 10),             # params[3]: second value passed to sar.mutation_probs
    ),
)
print(res.x)

[0.97 0.97 8.   0.2 ] 1247733.950129405
[0.999 0.97  8.    0.2  ] 1252420.3184435947
[0.97  0.999 8.    0.2  ] 1254995.3515171919
[0.97 0.97 8.4  0.2 ] 1248710.3516285585
[0.97 0.97 8.   0.21] 1247385.7010901968
[0.9845 0.941  8.2    0.205 ] 1247066.4703388233
[0.99175 0.912   8.3     0.2075 ] 1248097.5042134812
[0.94825 0.9555  8.3     0.2075 ] 1247995.6983341884
[0.966375 0.94825  7.85     0.21125 ] 1245942.7071641872
[0.9645625 0.937375  7.575     0.216875 ] 1245054.8531679413
[0.99628125 0.9536875  7.5875     0.2084375 ] 1248291.144788454
[0.96025781 0.95504687 8.121875   0.20773437] 1247056.0270626862
[0.96966016 0.93171094 7.9484375  0.21980469] 1245307.873825135
[0.96949023 0.91256641 7.92265625 0.21470703] 1245079.2188951958
[0.94748535 0.92734961 7.58398437 0.22456055] 1245218.081385335
[0.96534131 0.8994541  7.39316406 0.23023926] 1243907.4219245247
[0.96788306 0.87165771 7.02880859 0.2414917 ] 1243536.0228227356
[0.95505042 0.89276343 7.10678711 0.22901245] 1244069.367759754

In [None]:
# Decode a single alignment with the fitted two-state HMM and plot its Viterbi path.
# The model is rebuilt exactly as fit_hmm builds it, so the decoded model is the fitted one.
fitted_parameters = res.x
num_states = 2
num_symbols = 4    # same value fit_hmm passes to sar.mutation_probs
# fit_hmm only defines this as a local, so it must be created here for the cell to run.
initial_state_probabilities = [1.0/num_states]*num_states
# As in fit_hmm, parameters 0 and 1 are the self-transition (diagonal) probabilities.
transition_probabilities = np.array([[fitted_parameters[0], 1-fitted_parameters[0]],[1-fitted_parameters[1], fitted_parameters[1]]])
group_id =   1167 #1569 #1505    #  1167
alignment = sar.Alignment(non_cds_output_dir+'/'+str(group_id)+'.fasta', tb_species, 'NT')
# Same pre-processing arguments as the alignments the model was fitted on.
alignment.modify_sequence(1, False, True)
alignment_list =  alignment.modified_sequence_list
alignment_names = alignment.sequence_names
# NOTE(review): fit_hmm trims 50 columns from each end before fitting; the full
# alignment is decoded here so the plot covers every column - confirm this is intended.
# Argument order matches the sar.mutation_probs call in fit_hmm.
observation_probabilities = sar.mutation_probs(fitted_parameters[-num_states:], alignment_list, alignment_names, master_tree, num_symbols)
fitted_hmm = sar.HMM(initial_state_probabilities, transition_probabilities, observation_probabilities)
fitted_hmm.viterbi()
print(fitted_hmm.viterbi_log_probability)
plt.plot(fitted_hmm.viterbi_path);

In [13]:
# Evaluate the objective once at a fixed parameter vector (prints the vector and the total, and returns the total).
parallel_fit_hmm([0.98576229, 0.97154437, 8.54931842, 1.26858705])

[0.98576229, 0.97154437, 8.54931842, 1.26858705] 1236460.2911367614


1236460.2911367614

In [14]:
# Global search with SHGO over the same objective; note this rebinds `res`.
res = opt.shgo(
    parallel_fit_hmm,
    bounds=(
        (0.7, 0.999),    # params[0]: self-transition probability, state 0
        (0.7, 0.999),    # params[1]: self-transition probability, state 1
        (0.001, 2),      # params[2]: first value passed to sar.mutation_probs
        (0.5, 10),       # params[3]: second value passed to sar.mutation_probs
    ),
)
print(res.x)

[0.7   0.7   0.001 0.5  ] 1787438.7118226273
[ 0.999  0.999  2.    10.   ] 1242824.7578134167
[0.999 0.7   0.001 0.5  ] 1794995.1024155773
[0.999 0.999 0.001 0.5  ] 1737328.4342628245
[0.999 0.999 2.    0.5  ] 1373422.0422607826
[9.99e-01 9.99e-01 1.00e-03 1.00e+01] 1293009.2096890537
[0.999 0.7   2.    0.5  ] 1377058.8926613845
[ 0.999  0.7    2.    10.   ] 1269234.1565202537
[9.99e-01 7.00e-01 1.00e-03 1.00e+01] 1327543.2709945561
[0.7   0.999 0.001 0.5  ] 1738002.1843624804
[0.7   0.999 2.    0.5  ] 1422771.018696171
[ 0.7    0.999  2.    10.   ] 1257720.9672052362
[7.00e-01 9.99e-01 1.00e-03 1.00e+01] 1284989.16752892
[0.7 0.7 2.  0.5] 1412563.2635506813
[ 0.7  0.7  2.  10. ] 1258165.2110610222
[7.e-01 7.e-01 1.e-03 1.e+01] 1289341.866332135
[0.8495 0.8495 1.0005 5.25  ] 1259990.6724239653
[ 0.999  0.999  2.    10.   ] 1242824.7578134167
[ 0.99899999  0.999       2.         10.        ] 1242824.746906694
[ 0.999       0.99899999  2.         10.        ] 1242824.746014033
[ 0.999   