In [1]:
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from scipy import optimize as opt
import Sequence_Analysis_Routines as sar
import ete3
from joblib import Parallel, delayed
from tqdm import tqdm



In [2]:
project_dir = 'D:/Project_Data/Project_3'
output_dir = project_dir + '/Output/Close_Species'
non_cds_output_dir = output_dir + '/Multiple_Alignment_Data/Non_CDS'
tb_species = 'GCF_000195955.2'

In [3]:
file_ids = sar.list_files(non_cds_output_dir+'/')
ids = [int(i.split('.')[0]) for i in file_ids]
#ids.remove(1559)  #Contains S in alignment!
outgroup_species = 'NC_008596.1'
master_tree = ete3.Tree(output_dir + '/Trees/Concatenated_JC_Tree.treefile')
outgroup = master_tree.search_nodes(name= outgroup_species)[0]
outgroup.delete()

In [4]:
group_ids = ids
align_dict = {}
for group_id in group_ids:
        alignment = sar.Alignment(non_cds_output_dir+'/'+str(group_id)+'.fasta', tb_species, 'NT')
        alignment.modify_sequence(1, False, True)
        align_dict[group_id] = alignment

In [10]:
def fit_hmm(params, group_ids, num_subsets, subset_num):
    num_symbols = 4    # Inserts are randomised
    num_states = 3
    initial_state_probabilities = [1.0/num_states]*num_states
    total_probability = 0
    ids = sar.chunk_list(group_ids, num_subsets, subset_num)
    for group_id in ids:
        alignment = align_dict[group_id]
        align_list =  alignment.modified_sequence_list
        align_names = alignment.sequence_names
        len_align_list = len(align_list[0])
        non_cds = [x[50:len_align_list - 50] for x in align_list]
        if len(non_cds[0]) < 10:
            continue
        #transition_probabilities = np.full((num_states,num_states),params[0])
        #np.fill_diagonal(transition_probabilities, 1 - (num_states-1)*params[0])
        a = params[0]
        b = params[0] * params[1]
        c = 1 - (a+b)
        d = params[2] * params[3]
        e = params[2]
        f = 1 - (d+e)
        h = params[4] * params[5]
        i = params[4]
        g = 1 - (h+i)
        transition_probabilities = np.array([[a,b,c],[d,e,f],[g,h,i]])
        observation_probabilities = sar.mutation_probs(params[len(params)-num_states:len(params)], non_cds, align_names, master_tree, num_symbols)
        trial_hmm = sar.HMM(initial_state_probabilities, transition_probabilities, observation_probabilities)
        trial_hmm.viterbi()
        total_probability += trial_hmm.viterbi_log_probability * -1
    return total_probability

In [11]:
def parallel_fit_hmm (params):
    num_cores = 16
    core_numbers = range(1, num_cores+1)
    a = Parallel(n_jobs=-1)(delayed(fit_hmm)(params, group_ids, num_cores, core_number) for core_number in core_numbers)
    print(params, sum(a))
    return sum(a)  

In [12]:
#res = opt.minimize(parallel_fit_hmm, (0.01, 9, 4, 0.0001), method = 'Nelder-Mead', bounds = ((0.001,0.999),(0.1,10),(0.1,10), (0.1, 10)))
#print(res.x)
res = opt.minimize(parallel_fit_hmm, (0.9, 0.9, 0.9, 8, 4, 1), method = 'Nelder-Mead', bounds = ((0.001,0.999),(0.001,0.999),(0.001,0.999),(0.01,10),(0.01,10),(0.01,10)))
print(res.x)

SyntaxError: invalid syntax (Temp/ipykernel_3816/1258937426.py, line 4)

In [None]:
fitted_parameters = res.x
transition_probabilities = np.array([[1-fitted_parameters[0],fitted_parameters[0]],[fitted_parameters[1],1-fitted_parameters[1]]])
group_id =   1167 #1569 #1505    #  1167
alignment = sar.Alignment(non_cds_output_dir+'/'+str(group_id)+'.fasta', tb_species, 'NT')
alignment.modify_sequence(consensus=1)
alignment_list =  alignment.modified_sequence_list
alignment_names = alignment.sequence_names
observation_probabilities = sar.mutation_probs(fitted_parameters[2], fitted_parameters[3], alignment_list, alignment_names, master_tree)
fitted_hmm = sar.HMM(initial_state_probabilities, transition_probabilities, observation_probabilities)
fitted_hmm.viterbi()
print(fitted_hmm.viterbi_log_probability)
plt.plot(fitted_hmm.viterbi_path);

In [None]:
fit_hmm2([0.02087488, 0.02126647, 4.64647873, 1.10655882], group_ids, 8, 1)

In [None]:
print (globals()['res_' + str(1)])

In [None]:
num_cores = 8
ans = 0
for n in range(1, num_cores+1):
    ans += globals()['res_' + str(n)]
print(ans)

In [None]:
print(globals()['nick_1'])

In [None]:
def test(var, val):
     globals()['res_' + str(var)] += val

In [None]:
test(3, 7)

In [None]:
print(res_3)