# HMM Model - Pytorch and FRET Analysis Enabled

## Step 0: Setup the environment

In this directory, I created a .py script named smFRET_HMM.py that holds our HMM class and all relevant functions (no training).

In [1]:
import smFRET_HMM as Model
import torch

import os
import numpy as np
from matplotlib import pyplot as plt

from fretbursts import *
from H2MM_C import *

from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


 - Fallback to pure python burst search.
 - Fallback to pure python photon counting.
--------------------------------------------------------------
 You are running FRETBursts (version 0.7+47.gc51b088).

 If you use this software please cite the following paper:

   FRETBursts: An Open Source Toolkit for Analysis of Freely-Diffusing Single-Molecule FRET
   Ingargiola et al. (2016). http://dx.doi.org/10.1371/journal.pone.0160716 

--------------------------------------------------------------


Let's go ahead and define the priors, transitions, and emissions matrices. Then, let's initialize our model.

In [2]:
## According to the FRET API tutorial, there 2 states 0 and 1, both with the same chance of being x_0.
priors = [0.5,0.5]

## With 2 states there are 4 transitions. A(i,j) denotes probability of transitioning from state j to state i
## Usually there are numbers close to 1 along diagonal (the prob of not transitioning is higher) and close to 0 else.
transitions = [[0.999999, 1e-6],
                [1e-6, 0.999999]]

observations = [[0.3, 0.7],
                [0.5, 0.5]]
## Note that it is 2D, with 2 states (rows) and 2 streams (columns)

# Thus, we have the model:
model = Model.HMM(transitions, observations, priors)
print(model.unnormalized_emiss.matrix)
print(model.unnormalized_trans.matrix)

Parameter containing:
tensor([[0.3000, 0.7000],
        [0.5000, 0.5000]], requires_grad=True)
Parameter containing:
tensor([[1.0000e+00, 1.0000e-06],
        [1.0000e-06, 1.0000e+00]], requires_grad=True)


## Step 1: Import the data

Run burst search on the raw data suppleid by the API tutorial. 

In [3]:
def data_sort(data,nchan=3,Aex_stream=2,Aex_shift=None,**kwargs):
    usALEX = hasattr(data,'alex_period')
    if usALEX:
        if Aex_shift not in [None, 'shift', 'rand', 'even']:
            raise ValueError("Aex_shift must be 'shift', 'rand', or 'even'")
    elif Aex_shift is not None:
        raise ValueError("Aex_shift only valid for usALEX data")
    if (len(kwargs) == 0):
        chan_sel = [Ph_sel(Dex='Dem'), Ph_sel(Dex='Aem'), Ph_sel(Aex='Aem'), Ph_sel(Aex='Dem')]
        n = 4
    else:
        n = 1
        chan_sel = []
        while kwargs.get('ph_sel' + str(n)) != None:
            chan_sel.append(kwargs.get('ph_sel' + str(n)))
            n += 1
        n += -1
    print('Sorting photon channels',end='...')
    chans = np.array([data.get_ph_mask(ph_sel=chan_sel[i]) for i in range(0,len(chan_sel))])
    mask = chans.sum(axis=0)
    if np.any(mask != 1):
        print('Photon assigned to multiple channels or none, check ph_sel selections')
        return None
    ph_chan = np.zeros(chans.shape[1],dtype=int)
    for i in range(0,len(chan_sel)):
        ph_chan += i*chans[i,:]
    ph_times = data.ph_times_m[0]
    if not usALEX:
        ph_nanotime = data.nanotimes[0]
        ArrivalNanotime = []
    burst = data.mburst[0]
    ArrivalColor = []
    ArrivalTime = []
    print('Slicing data into bursts',end='...')
    for start, stop in zip(burst.istart, burst.istop + 1):
        ArrivalColor.append(ph_chan[start:stop][ph_chan[start:stop]<nchan])
        ArrivalTime.append(ph_times[start:stop][ph_chan[start:stop]<nchan])
        if not usALEX:
            ArrivalNanotime.append(ph_nanotime[start:stop][ph_chan[start:stop]<nchan])
    # apply a shift to usALEX AexAem photons, and make sure to re-sort the photons
    if Aex_shift == 'shift' and Aex_stream < nchan:
        print("Shifting Aex photons",end='...')
        alex_shift = data.D_ON[0] - data.A_ON[0]
        for i, (color, time) in enumerate(zip(ArrivalColor,ArrivalTime)):
            time[color==Aex_stream] += alex_shift
            sort = np.argsort(time)
            ArrivalColor[i] = color[sort]
            ArrivalTime[i] = time[sort]
    elif Aex_shift == 'rand' and Aex_stream < nchan:
        print("Shift and randomizing Aex photons",end='...')
        D_ON, D_OFF = data.D_ON[0], data.D_ON[1]
        for i, (color, time) in enumerate(zip(ArrivalColor, ArrivalTime)):
            time_temp = time.copy()
            alex_mask = color == Aex_stream
            alex_new = np.random.randint(D_ON,D_OFF,size=alex_mask.sum())
            time_temp[alex_mask] = alex_new + (time[alex_mask] // data.alex_period)*data.alex_period
            sort = np.argsort(time_temp)
            ArrivalColor[i] = color[sort].astype('uint8')
            ArrivalTime[i] = time_temp[sort]
    elif Aex_shift == 'even' and Aex_stream < nchan:
        print("Distributing Aex photons",end='...')
        D_ON, D_OFF = data.D_ON[0], data.D_ON[1]
        D_dur = D_OFF - D_ON
        for i, (color, time) in enumerate(zip(ArrivalColor, ArrivalTime)):
            Aex_mask = color == Aex_stream
            tms, inverse, counts = np.unique(time[Aex_mask]//data.alex_period,return_counts=True,return_inverse=True)
            newAex_times = np.empty(inverse.shape,dtype=time.dtype)
            for j, (tm, count) in enumerate(zip(tms,counts)):
                t_beg = tm*data.alex_period + D_ON + D_dur/(count+1)
                t_end = tm*data.alex_period + D_OFF
                newAex_times[j==inverse] = np.arange(t_beg,t_end,D_dur/(count+1))[:count]
            time_new = time.copy()
            time_new[Aex_mask] = newAex_times
            sort = np.argsort(time_new)
            ArrivalColor[i] = color[sort].astype('uint8')
            ArrivalTime[i] = time_new[sort]
    print('Done')
    if usALEX:
        return ArrivalColor, ArrivalTime
    else:
        return ArrivalColor, ArrivalTime , ArrivalNanotime

In [4]:
data = loader.photon_hdf5("h2mm_api_tutorial/033HP3_T25C_300mM_NaCl_2.hdf5")
loader.alex_apply_period(data)
data.calc_bg(fun=bg.exp_fit,time_s=30, tail_min_us='auto', F_bg=1.7)
data.burst_search(m=10,F=6)
data.fuse_bursts(ms=0)
data = Sel(data,select_bursts.size,add_naa=True,th1=50)
data = Sel(data,select_bursts.size,th1=30)

color, times = data_sort(data,nchan=2)

# Total photons (after ALEX selection):    23,542,014
#  D  photons in D+A excitation periods:    7,271,909
#  A  photons in D+A excitation periods:   16,270,105
# D+A photons in  D  excitation period:    11,826,619
# D+A photons in  A  excitation period:    11,715,395

 - Calculating BG rates ... Channel 0
[DONE]
 - Performing burst search (verbose=False) ...[DONE]
 - Calculating burst periods ...[DONE]
 - Counting D and A ph and calculating FRET ... 
   - Applying background correction.
   [DONE Counting D/A]
 - - - - - CHANNEL  1 - - - - 
 --> END Fused 136948 bursts (38.2%, 20 iter)

 - Counting D and A ph and calculating FRET ... 
   - Applying background correction.
   [DONE Counting D/A and FRET]
Sorting photon channels...Slicing data into bursts...Done


In [5]:
emissions = [list(elem) for elem in color]
train_lines, test_lines = train_test_split(emissions, random_state=42)

train_dataset = Model.PhotonDataset(train_lines)
test_dataset = Model.PhotonDataset(test_lines)

trainer = Model.Trainer(model, lr=0.01)

In [6]:
for epoch in range(10):
        print("========= Epoch %d of %d =========" % (epoch+1, 10))
        train_loss = trainer.train(train_dataset)
        valid_loss = trainer.test(train_dataset)

        print("========= Results: epoch %d of %d =========" % (epoch+1, 10))
        print("train loss: %.2f| valid loss: %.2f\n" % (train_loss, valid_loss) )



100%|██████████| 15/15 [00:00<00:00, 21.25it/s]


train loss: -0.00| valid loss: 0.00



100%|██████████| 15/15 [00:00<00:00, 15.99it/s]


train loss: 0.00| valid loss: 0.00



100%|██████████| 15/15 [00:00<00:00, 16.02it/s]


train loss: 0.00| valid loss: 0.00



100%|██████████| 15/15 [00:00<00:00, 15.92it/s]


train loss: 0.00| valid loss: 0.00



100%|██████████| 15/15 [00:00<00:00, 15.97it/s]


train loss: 0.00| valid loss: 0.00



100%|██████████| 15/15 [00:00<00:00, 15.89it/s]


train loss: 0.00| valid loss: 0.00



100%|██████████| 15/15 [00:00<00:00, 15.77it/s]


train loss: 0.00| valid loss: 0.00



100%|██████████| 15/15 [00:00<00:00, 15.93it/s]


train loss: 0.00| valid loss: 0.00



100%|██████████| 15/15 [00:00<00:00, 15.85it/s]


train loss: 0.00| valid loss: 0.00



100%|██████████| 15/15 [00:00<00:00, 15.97it/s]


train loss: 0.00| valid loss: 0.00



In [7]:
print(model.unnormalized_trans.matrix)
print(model.unnormalized_emiss.matrix)

Parameter containing:
tensor([[ 9.0045e-01, -1.6913e-05],
        [-2.1985e-05,  9.0045e-01]], requires_grad=True)
Parameter containing:
tensor([[0.2018, 0.6007],
        [0.4010, 0.4010]], requires_grad=True)
