In [11]:
import pandas as pd
import numpy as np
import scipy as sp
import plotly.express as px
from hmmlearn import hmm
import dask

# Set up example data

Define the hidden states (each one's mean observed FRET value) and the Gaussian emission noise shared by all states.

In [2]:
state_means = np.array([0, 0.5, 1])
stdev = 0.1

# Seed NumPy's global RNG so the simulated transition matrix and traces
# (generate_transition_matrix / generate_trace draw from np.random) are
# reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)

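Define a random transition matrix: off-diagonal entries are drawn uniformly from [0, 0.05) and each diagonal entry takes the remaining probability, so every row sums to 1 and the chain mostly stays in its current state.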

In [3]:
def generate_transition_matrix(n_states):
    """Build a random, row-stochastic ``n_states x n_states`` transition matrix.

    Off-diagonal entries are drawn uniformly from ``[0, 0.05)``; each diagonal
    entry is then set to the probability left over in its row, so every row
    sums to 1 and the chain tends to remain in its current state.

    NOTE(review): once ``n_states`` exceeds about 20 the off-diagonal mass of a
    row can exceed 1, which would make that diagonal entry negative.
    """
    matrix = 0.05 * np.random.random(size=(n_states, n_states))
    for idx in range(n_states):
        row = matrix[idx]  # a view: assigning row[idx] updates matrix in place
        # Remainder of the row after removing its off-diagonal mass.
        row[idx] = 1.0 - row.sum() + row[idx]
    return matrix

In [4]:
# Every state shares the same emission noise level.
state_stds = np.full(len(state_means), stdev)
transition_matrix = generate_transition_matrix(len(state_means))

Function to convert an underlying state into an emission (i.e., map a hidden state to an observed FRET value).

In [5]:
def get_emission_value(state):
    """Draw one noisy observation for hidden ``state``.

    Samples a normal distribution centred on ``state_means[state]`` with
    standard deviation ``state_stds[state]``; both arrays are notebook globals.
    """
    mean, spread = state_means[state], state_stds[state]
    return np.random.normal(loc=mean, scale=spread)

Write a function that simulates a sequence of hidden states and their emissions (observed values).

In [6]:
def generate_trace(n_steps, state_means, state_stds, transition_matrix,
                   initial_state=0, max_shift=0.025):
    """Simulate one trace of hidden states and noisy observed values.

    The hidden state follows a Markov chain driven by ``transition_matrix``.
    Each observation is drawn from a normal distribution with mean
    ``state_means[state]`` and standard deviation ``state_stds[state]``, using
    the arrays passed in rather than any globals. A single random offset,
    uniform on ``[-max_shift, max_shift)``, is added to every observation of
    the trace, including the first.

    Parameters
    ----------
    n_steps : int
        Number of time points to simulate; ``n_steps <= 0`` returns empty lists.
    state_means, state_stds : array-like of float
        Per-state emission mean and standard deviation.
    transition_matrix : array-like, shape (n_states, n_states)
        Row ``i`` holds the probabilities of moving from state ``i`` to each state.
    initial_state : int, optional
        State occupied at the first time point (default 0).
    max_shift : float, optional
        Half-width of the per-trace offset (default 0.025).

    Returns
    -------
    states : list
        Hidden state index at each time point.
    values : list of float
        Observed value at each time point.
    """
    if n_steps <= 0:
        return [], []

    state_ids = np.arange(len(state_means))

    def emit(state):
        # One emission drawn from the caller-supplied per-state parameters.
        return np.random.normal(loc=state_means[state], scale=state_stds[state])

    first_value = emit(initial_state)
    shift = np.random.random() * (2 * max_shift) - max_shift
    states = [initial_state]
    # The per-trace shift applies to every sample, including the first one.
    values = [shift + first_value]
    for _ in range(n_steps - 1):
        new_state = np.random.choice(state_ids, p=transition_matrix[states[-1]])
        states.append(new_state)
        values.append(shift + emit(new_state))
    return states, values

In [7]:
len_trace = 500
num_traces = 80

# Simulate num_traces independent traces that all share the same model parameters.
simulated = [
    generate_trace(len_trace, state_means, state_stds, transition_matrix)
    for _ in range(num_traces)
]
states = [trace_states for trace_states, _ in simulated]
values = [trace_values for _, trace_values in simulated]

In [8]:
# Sanity check: overlay the first simulated trace on the mean of its true hidden state.
# Only trace 0 is used, so there is no need to stack every trace into one array
# (which would also fail if the traces had different lengths).
first_trace = pd.DataFrame({
    'observed value': values[0],
    'true state mean': state_means[states[0]],
})
px.line(
    first_trace,
    labels={'index': 'time step', 'value': 'FRET value'},
    title='First simulated trace vs. its true hidden-state mean',
)

# Functions for fitting the HMM and predicting hidden states

This is the standard pipeline; however, I have found that it gets stuck in local optima very easily, so it is best to run it multiple times and keep the best result. Note that it uses a uniform distribution as the initial guess for the initial-state probabilities.

In [12]:
def do_fit(values, guessed_components, n_iter=100):
    """Fit one Gaussian HMM to a single trace and decode its hidden states.

    Parameters
    ----------
    values : sequence of float
        One observed trace; reshaped into a single-feature column.
    guessed_components : int
        Number of hidden states to fit.
    n_iter : int, optional
        Maximum number of EM iterations passed to hmmlearn (default 100).

    Returns
    -------
    model : hmm.GaussianHMM
        The fitted model.
    Z : ndarray
        Decoded state index for every observation (``model.predict``).
    score : float
        ``model.score`` of the trace under the fitted model.
    """
    # Evenly spaced guesses over [0, 1]; linspace also yields [0.0] for a
    # single component, where arange(k) / (k - 1) would divide by zero.
    guessed_means = np.linspace(0.0, 1.0, guessed_components).reshape(-1, 1)
    np_values = np.asarray(values, dtype=float).reshape(-1, 1)

    # NOTE(review): init_params="mcs" lets hmmlearn initialise start probabilities,
    # means and covariances itself. With hmmlearn's default means_weight=0, the
    # means_prior below likely has no influence on the fitted means — confirm
    # against the hmmlearn documentation before relying on it to seed the means.
    model = hmm.GaussianHMM(
        n_components=guessed_components,
        covariance_type="full",
        n_iter=n_iter,
        init_params="mcs",
        means_prior=guessed_means,
    )
    model.fit(np_values)
    # Predict the hidden states corresponding to the observed values.
    Z = model.predict(np_values)
    score = model.score(np_values)
    return model, Z, score

This runs the fitter multiple times and keeps the model with the highest `model.score`.

In [13]:
def find_best_fit(values, guessed_components, num_trials=30):
    """Run ``do_fit`` repeatedly on one trace and keep the highest-scoring result.

    Each call to ``do_fit`` can end in a different local optimum, so several
    restarts are tried and the one with the largest ``score`` wins. A restart
    that raises an ``Exception`` is skipped instead of aborting the search.
    The best score is printed.

    Parameters
    ----------
    values : sequence of float
        One observed trace.
    guessed_components : int
        Number of hidden states passed to ``do_fit``.
    num_trials : int, optional
        Number of restarts (default 30).

    Returns
    -------
    model
        The best fitted model.
    predicted : ndarray
        ``model.means_`` looked up at each decoded state and squeezed, i.e. the
        fitted mean at every time point.

    Raises
    ------
    RuntimeError
        If no restart succeeded with a score greater than ``-inf``.
    """
    best_model = None
    best_score = -np.inf
    best_Z = None
    last_error = None

    for _ in range(num_trials):
        try:
            model, Z, score = do_fit(values, guessed_components)
        except Exception as exc:
            # A single failed restart should not abort the whole search;
            # remember it so a total failure can report the cause.
            last_error = exc
            continue
        if score > best_score:
            best_score, best_model, best_Z = score, model, Z

    if best_model is None:
        raise RuntimeError(
            f"none of the {num_trials} HMM fits with {guessed_components} "
            "components produced a usable model"
        ) from last_error

    print(best_score)
    return best_model, best_model.means_[best_Z].squeeze()


In [14]:
def dask_wrapper(values, guessed_components, num_trials=30):
    """Adapter for ``dask.delayed``: run ``find_best_fit`` and keep only the prediction.

    The fitted model is discarded; the return value is the per-time-point
    predicted trace produced by ``find_best_fit``.
    """
    fit_result = find_best_fit(values, guessed_components, num_trials)
    _, predicted_trace = fit_result
    return predicted_trace


In [15]:
# Number of hidden states to fit per trace, and random restarts per trace.
GUESSED_COMPONENTS = 3
NUM_TRIALS = 30

# Wrap once, then build one lazy fitting task per simulated trace.
fit_trace = dask.delayed(dask_wrapper)
predicteds = [fit_trace(vals, GUESSED_COMPONENTS, NUM_TRIALS) for vals in values]

In [16]:
# Execute all delayed fitting tasks. dask.compute returns a tuple of results in
# the same order as its inputs, so predicteds[i] is the prediction for values[i].
# NOTE(review): this rebinds `predicteds` from delayed tasks to computed results,
# so re-running this cell on its own operates on the results, not the tasks.
predicteds = dask.compute(*predicteds)

Model is not converging.  Current: 272.79543242710594 is not greater than 272.7957142876345. Delta is -0.00028186052855971866
Model is not converging.  Current: 325.81449201093284 is not greater than 325.8145185038774. Delta is -2.6492944584788347e-05
Model is not converging.  Current: 287.56840034264565 is not greater than 287.56850215497997. Delta is -0.00010181233432149384
Model is not converging.  Current: 325.81449569666137 is not greater than 325.8145085847643. Delta is -1.288810295818621e-05
Model is not converging.  Current: 96.14926089692435 is not greater than 96.15613084093586. Delta is -0.00686994401151253
Model is not converging.  Current: 327.9608124416831 is not greater than 327.961055908673. Delta is -0.00024346698990029836
Model is not converging.  Current: 272.7954042038586 is not greater than 272.79568916059196. Delta is -0.0002849567333669256


324.40282990564424


Model is not converging.  Current: 327.96085246793956 is not greater than 327.96090195199156. Delta is -4.9484051999115763e-05
Model is not converging.  Current: 272.7953727892004 is not greater than 272.79547956117676. Delta is -0.00010677197633413016


325.8145021619454
311.17649281378124
272.7953538337735


Model is not converging.  Current: 287.56841272002566 is not greater than 287.5684655700188. Delta is -5.284999315335881e-05


364.9980611417097
287.5683893601121
327.9606949209074


Model is not converging.  Current: 296.19134973519186 is not greater than 296.19153073822395. Delta is -0.00018100303208257174
Model is not converging.  Current: 296.1913444370256 is not greater than 296.1914926204847. Delta is -0.0001481834591459119


275.3440035667951
336.28591813552566


Model is not converging.  Current: 315.73814151420373 is not greater than 315.73833942960454. Delta is -0.00019791540080404957


338.3995712256117
296.19133052141035
300.51616979335034
307.4130406132894


Model is not converging.  Current: 323.99350609418065 is not greater than 323.99352505953. Delta is -1.896534934076044e-05


315.7381221524526


Model is not converging.  Current: 323.9935249849893 is not greater than 323.993612793362. Delta is -8.780837271160635e-05


285.49327907312556


Model is not converging.  Current: 335.95217609518994 is not greater than 335.95218369320224. Delta is -7.598012302878487e-06


273.7129414516072


Model is not converging.  Current: 323.99352714870025 is not greater than 323.99360767057414. Delta is -8.052187388329912e-05


335.9521705418898
333.6290024252523
336.900057939884


Model is not converging.  Current: 268.20743800348606 is not greater than 268.2075413122603. Delta is -0.00010330877421438345


318.1730480401213


Model is not converging.  Current: 323.99353576447896 is not greater than 323.9935628095323. Delta is -2.7045053343499603e-05


351.9745025780484
323.9934815967693
332.343118051905
307.0693707122121


Model is not converging.  Current: 321.69437379314905 is not greater than 321.69458188734177. Delta is -0.00020809419271472507
Model is not converging.  Current: 268.2073721072852 is not greater than 268.2074199369826. Delta is -4.78296973938086e-05


273.672344817772


Model is not converging.  Current: 268.2073764779402 is not greater than 268.2074251277351. Delta is -4.864979490548649e-05


314.8226626164513
339.16996676018624
268.2073609405892
321.6943615035237
309.7915923292828
271.3429617362008
259.53114743513044
283.87977569362306
327.7329615931278
309.5175597160136
286.5796623515883
273.519455284775


Model is not converging.  Current: 352.2536758812231 is not greater than 352.25392211105054. Delta is -0.0002462298274394925
Model is not converging.  Current: 352.2536851097596 is not greater than 352.25393173321686. Delta is -0.0002466234572580106
Model is not converging.  Current: 352.2536778492417 is not greater than 352.2539317964886. Delta is -0.00025394724690386283


300.13344737283114
281.66099988012394


Model is not converging.  Current: 352.2536834917354 is not greater than 352.2539054516262. Delta is -0.00022195989083684253
Model is not converging.  Current: 352.25365848045027 is not greater than 352.2536870622371. Delta is -2.8581786807535536e-05
Model is not converging.  Current: 352.2536975984673 is not greater than 352.2538097187747. Delta is -0.00011212030739216061
Model is not converging.  Current: 319.5619475111564 is not greater than 319.5629686466465. Delta is -0.0010211354900775405


352.2536636197904
291.37901544566427
313.38608716907146
285.54066935600724


Model is not converging.  Current: 319.55926983333853 is not greater than 319.56023819205075. Delta is -0.0009683587122140125


296.542577839987
336.86655788130435
319.5600225428082


Model is not converging.  Current: 365.30553461143097 is not greater than 365.3055367129987. Delta is -2.1015677020841395e-06


313.4377452794556
324.34046973507105
365.3055329441662


Model is not converging.  Current: 350.6094544503037 is not greater than 350.6095458623126. Delta is -9.141200888507228e-05
Model is not converging.  Current: 317.38032968315275 is not greater than 317.38038778803985. Delta is -5.8104887102672365e-05


342.7085187508125


Model is not converging.  Current: 276.5489239930209 is not greater than 276.54905760462634. Delta is -0.00013361160546310202
Model is not converging.  Current: 350.60945542396377 is not greater than 350.6095503650162. Delta is -9.494105245266837e-05


276.5489190694088
332.4276850692235


Model is not converging.  Current: 305.9033508718509 is not greater than 305.90343298025573. Delta is -8.210840485389781e-05
Model is not converging.  Current: 301.32572852065937 is not greater than 301.3267054191107. Delta is -0.0009768984513129908
Model is not converging.  Current: 300.1882331901399 is not greater than 300.1884080462175. Delta is -0.00017485607759226696


301.32547187252504


Model is not converging.  Current: 317.380329952909 is not greater than 317.3805874512978. Delta is -0.0002574983888052884


350.6094508490319
317.38031974367516
352.9524444088339


Model is not converging.  Current: 305.9034417212487 is not greater than 305.90351523732045. Delta is -7.351607172267904e-05
Model is not converging.  Current: 325.9941852455426 is not greater than 325.9941929802136. Delta is -7.734670987247227e-06
Model is not converging.  Current: 305.90332899421753 is not greater than 305.9034791335383. Delta is -0.00015013932079455117


325.994187193982
305.9034093367219


Model is not converging.  Current: 339.99198750766914 is not greater than 339.9919924700024. Delta is -4.962333264302288e-06
Model is not converging.  Current: 305.7604800253759 is not greater than 305.7605251278559. Delta is -4.510247998723571e-05


307.7705815898696


Model is not converging.  Current: 332.34126634073425 is not greater than 332.3412709313274. Delta is -4.5905931642664655e-06
Model is not converging.  Current: 339.99198780031185 is not greater than 339.99198875179957. Delta is -9.514877206129313e-07


295.77567948523824
305.24934352001924


Model is not converging.  Current: 305.76046689640145 is not greater than 305.7605584982966. Delta is -9.160189512158468e-05


279.0867157375713
273.31092853286736


Model is not converging.  Current: 305.760461424903 is not greater than 305.760557627363. Delta is -9.62024600426048e-05
Model is not converging.  Current: 339.99198749944077 is not greater than 339.9919942593476. Delta is -6.759906852948916e-06


305.76045903956344


Model is not converging.  Current: 314.2759044609166 is not greater than 314.2761401437373. Delta is -0.00023568282068708868
Model is not converging.  Current: 339.99198742778765 is not greater than 339.99199397979356. Delta is -6.552005913817993e-06


332.34126629267774


Model is not converging.  Current: 326.5825963078779 is not greater than 326.58332414708826. Delta is -0.0007278392103557962


339.9919872660876
314.27588794538417
314.25562020303784


Model is not converging.  Current: 326.58254212659983 is not greater than 326.5833313364611. Delta is -0.0007892098612387599


284.0034322612234
345.93424919329937


Model is not converging.  Current: 326.58260404528596 is not greater than 326.5834082473877. Delta is -0.0008042021017331535
Model is not converging.  Current: 325.37371234253516 is not greater than 325.3740269995599. Delta is -0.00031465702471678014


307.6434634066245


Model is not converging.  Current: 323.5636026596821 is not greater than 323.56398354658836. Delta is -0.00038088690627091637


326.5824029897853
326.3433672910865
323.56356145508107
325.3736718755288
300.9906192662522
318.06649531934903


Model is not converging.  Current: 258.24929056015753 is not greater than 258.2501115221293. Delta is -0.0008209619717831629


292.33284868035884
258.2492307867274
310.34502695611167


In [17]:
def find_transitions(Z, threshold=0.1):
    """Return the consecutive pairs of a predicted trace where the value jumps.

    Parameters
    ----------
    Z : array-like of float
        Predicted per-time-point values for one trace.
    threshold : float, optional
        Minimum absolute change between neighbouring points that counts as a
        transition (default 0.1).

    Returns
    -------
    ndarray of shape (n_transitions, 2)
        Rows ``[Z[i], Z[i + 1]]`` for every ``i`` with
        ``|Z[i] - Z[i + 1]| > threshold``. The result is always 2-D, so an
        empty result has shape ``(0, 2)``.
    """
    z = np.atleast_1d(np.asarray(Z))
    jumps = np.abs(np.diff(z)) > threshold
    return np.column_stack((z[:-1][jumps], z[1:][jumps]))

In [18]:
# Pool the transitions from every fitted trace into one (n_transitions, 2) array
# of [source, sink] pairs. reshape keeps the array 2-D even when no trace has any
# transitions, so the [:, 0] / [:, 1] column access below still works.
transitions = np.array(
    [pair for predicted in predicteds for pair in find_transitions(predicted)]
).reshape(-1, 2)

In [19]:
# Distribution of fitted state means over every time point of every trace.
# np.concatenate (unlike np.array(...).flatten()) also works when traces differ in length.
px.histogram(
    x=np.concatenate([np.ravel(p) for p in predicteds]),
    nbins=100,
    labels={'x': 'predicted state mean'},
    title='Fitted state means across all traces',
)

In [20]:
# One row per detected transition: the value before (source) and after (sink) the jump.
tdf = pd.DataFrame(transitions, columns=['source', 'sink'])
f = px.density_heatmap(
    tdf,
    x='source',
    y='sink',
    nbinsx=200,
    nbinsy=200,
    width=800,
    height=700,
)
f.write_html('heatmap_5000.html')
f