# Blechpy Poisson HMM Tutorial
This tutorial covers how to set up and fit a Poisson HMM to spike data.
To use it, you must already have a blechpy.dataset object whose data is past the spike sorting stage.

In [1]:
# Imports
import blechpy
from blechpy.analysis import poissonHMM as phmm
import pandas as pd
import numpy as np

In [2]:
# First get the path to your recording folder
rec_dir = '/data/Katz_Data/Stk11_Project/RN10/RN10_ctaTest_190220_131512'

## Fitting a single HMM
To fit a single HMM you will use the PoissonHMM object. This object houses all the necessary parameters for an HMM except the data being fitted.
### Gathering the data
First you will need to collect the data for the HMM to fit. This should be a spike array (numpy array, dtype=int32) with 3 dimensions (trial, cell, time bin), where each value is the number of spikes in that time bin. If you have sorted your units and then used dat.make_unit_arrays(), your spike arrays are stored in your h5 file with 1 ms time bins from -2000 ms to 5000 ms. You will need to grab this spike array and cut it down to the time window you want to model and to only the units you want to use. Additionally, you may need to rebin this array to a different time step (especially with sparsely firing units).

phmm has a useful function for gathering this:

In [None]:
din_channel = 0 # Channel of the digital input that the trials you wish to fit are on
unit_type = 'single' # This can be single (for all single units),
                     # pyramidal (only single unit regular-spiking cells) 
                     # or interneuron (only single unit fast-spiking cells)
        
# The parameters below are optional
time_start = 0  # Time start in ms
time_end = 2000 # Time end in ms
dt = 0.01 # desired bin size for the return spike array in seconds, default is 0.001 seconds

spike_array, dt, time = phmm.get_hmm_spike_data(rec_dir, unit_type, din_channel, time_start, time_end, dt)
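
If you want to see what the windowing and rebinning amounts to, here is a minimal NumPy sketch on simulated data. It is not the code inside `phmm.get_hmm_spike_data`; the helper name `rebin_spike_array` and the choice to label each new bin by its left edge are assumptions made for illustration.

```python
import numpy as np

def rebin_spike_array(spikes, time_ms, t_start, t_end, dt_s):
    """Cut a (trial, cell, time) array of 1 ms bins to [t_start, t_end)
    and sum the spikes into bins of dt_s seconds (hypothetical helper)."""
    bin_ms = int(round(dt_s * 1000))          # new bin width in ms
    keep = (time_ms >= t_start) & (time_ms < t_end)
    cut = spikes[:, :, keep]
    n_bins = cut.shape[2] // bin_ms           # drop any incomplete trailing bin
    cut = cut[:, :, :n_bins * bin_ms]
    n_trials, n_cells, _ = cut.shape
    # Group consecutive 1 ms bins and sum the spike counts within each group
    rebinned = cut.reshape(n_trials, n_cells, n_bins, bin_ms).sum(axis=3)
    # Assumption: each new bin is labelled by its left edge
    new_time = time_ms[keep][:n_bins * bin_ms:bin_ms]
    return rebinned.astype(np.int32), new_time

# Simulated data: 30 trials, 5 cells, 1 ms bins from -2000 ms to 5000 ms
rng = np.random.default_rng(0)
time_ms = np.arange(-2000, 5000)
spikes = (rng.random((30, 5, time_ms.size)) < 0.01).astype(np.int32)

binned, binned_time = rebin_spike_array(spikes, time_ms, 0, 2000, 0.01)
print(binned.shape)  # (30, 5, 200)
```

Summing within bins keeps the total spike count in the window unchanged, which is a quick sanity check after rebinning.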

### Initializing and fitting the model
Now you can go ahead and initialize and fit your HMM.
#### Something to note: 
I have not yet figured out the best parameters to use to test convergence. So for now, at each iteration the changes in every matrix (transition, emission and initial distribution) are computed, and fitting stops when the total change in every matrix is below the threshold. Alternatively, fitting stops if the maximum number of iterations is reached. For now the default convergence threshold is 1e-4, which works well for the simulated data. I have not yet had actual data meet this criterion.
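
As a rough illustration of this stopping rule, the sketch below checks whether the total change in each matrix is under the threshold. The function name `has_converged` and the use of the summed absolute difference as the measure of "change" are assumptions, not necessarily what `PoissonHMM` does internally.

```python
import numpy as np

def has_converged(old, new, thresh=1e-4):
    """Return True when the summed absolute change in every matrix is below
    thresh. old/new are dicts of arrays with the same keys.
    (Hypothetical helper; the change measure is an assumption.)"""
    return all(np.sum(np.abs(new[key] - old[key])) < thresh for key in old)

old = {
    'transition': np.array([[0.9, 0.1], [0.2, 0.8]]),
    'emission': np.array([[5.0, 1.0], [2.0, 8.0]]),
    'initial_distribution': np.array([0.5, 0.5]),
}

# A tiny update to the transition matrix only
small = {key: val.copy() for key, val in old.items()}
small['transition'] += np.array([[1e-6, -1e-6], [0.0, 0.0]])

# A large update to the emission (rate) matrix
big = {key: val.copy() for key, val in old.items()}
big['emission'] += 0.5

converged_small = has_converged(old, small)
converged_big = has_converged(old, big)
print(converged_small, converged_big)
```

Note that every matrix has to pass the test, so one matrix that is still moving is enough to keep fitting going.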

Also, the cost of the model is computed on each iteration. The cost is computed by predicting the state at each time point, using the emission (rate) matrix to predict the firing rate in each bin, and then computing the distance of this prediction from the actual firing rate. This is then summed over time bins and averaged over trials to get the final cost of the model. This would probably provide a better measure of convergence, but I have not yet determined the best threshold for change in cost at which to stop fitting. Also, this may lead to overfitting. But cost does provide a means of comparing models, since BIC is only a good measure for comparing models with the same number of states and time bins.
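
To make the cost description concrete, here is a small sketch of that calculation. It assumes the emission matrix is shaped (cell, state) with rates in Hz and that "distance" means the Euclidean distance across cells in each bin; both are assumptions for illustration, and `model_cost` is a hypothetical name.

```python
import numpy as np

def model_cost(spikes, dt, emission, best_sequences):
    """spikes: (trial, cell, time) counts; emission: (cell, state) rates in Hz;
    best_sequences: (trial, time) predicted state index for each bin."""
    actual_rate = spikes / dt                          # (trial, cell, time)
    # Look up the predicted rate of every cell for the state in each bin
    predicted = emission[:, best_sequences]            # (cell, trial, time)
    predicted = predicted.transpose(1, 0, 2)           # (trial, cell, time)
    # Assumption: Euclidean distance across cells in each time bin
    dist = np.sqrt(np.sum((actual_rate - predicted) ** 2, axis=1))
    # Sum over time bins, then average over trials
    return dist.sum(axis=1).mean()

dt = 0.1
emission = np.array([[10.0, 20.0],    # cell 0 rate in state 0, state 1
                     [0.0, 30.0]])    # cell 1 rate in state 0, state 1
best_sequences = np.array([[0, 0, 1],
                           [1, 1, 1]])
# Spike counts that match the predicted rates exactly
spikes = np.array([[[1, 1, 2], [0, 0, 3]],
                   [[2, 2, 2], [3, 3, 3]]])
cost_perfect = model_cost(spikes, dt, emission, best_sequences)

# One extra spike in trial 0, cell 0, bin 0
spikes_off = spikes.copy()
spikes_off[0, 0, 0] = 2
cost_off = model_cost(spikes_off, dt, emission, best_sequences)
print(cost_perfect, cost_off)
```

A perfect prediction gives a cost of zero, and the cost grows as the decoded states explain the observed rates less well.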

In [None]:
n_states = 3  # Number of predicted states in your data

# Initializing the model
model = phmm.PoissonHMM(n_states) # Notice you're not giving it the data yet

# Fitting the model
convergence_threshold = 1e-4  # This is optional, the default is 1e-5, for my final models I used 1e-10.
                              # This is the threshold for fitting such that when the change in log_likelihood 
                              # between iterations is below this then fitting ends.
max_iter = 1000  # This is also optional, the default is 1000

model.fit(spike_array, dt, time, max_iter=max_iter, thresh=convergence_threshold)

### Understanding your model
Now that the model is fitted there are some useful aspects to know about it. 
#### Important Attributes
- model.transition
 - transition matrix giving the probability of switching from one state to another
- model.emission
 - This is actually a rate matrix expressing the predicted firing of each neuron in each state
- model.initial_distribution
 - This gives the probability of being in each state at the start
- model.cost
 - This has the last computed cost of the model
- model.BIC
 - This has the last computed Bayesian Information Criterion of the model
- model.best_sequences
 - This has the best predicted sequence for each trial
- model.max_log_prob
 - Max log probability of the sequences (not sure this is computed correctly)
 
#### Useful functions in the model
     best_sequences, max_log_prob = model.get_best_paths(spike_array, dt)
     forward_probs = model.get_forward_probabilities(spike_array, dt)
     backward_probs = model.get_backward_probabilities(spike_array, dt)
     gamma_probs = model.get_gamma_probabilites(spike_array, dt)
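
For background on what a "best path" is, the sketch below runs a textbook Viterbi decoder for one trial with Poisson emissions. It is not blechpy's implementation: the function name `viterbi_poisson`, the (cell, state) layout of the rate matrix, and the convention that `transition[i, j]` is the probability of going from state i to state j are all assumptions.

```python
import numpy as np

def viterbi_poisson(spikes, dt, initial, transition, rates):
    """Most likely state sequence for a single trial.
    spikes: (cell, time) counts; rates: (cell, state) in Hz;
    transition[i, j] = P(state j at t+1 | state i at t)."""
    n_bins = spikes.shape[1]
    n_states = rates.shape[1]
    lam = rates * dt + 1e-12                 # expected count per bin
    # Poisson log-likelihood per bin and state, omitting the -log(k!) term,
    # which is the same for every state and so does not change the best path
    log_emit = spikes.T @ np.log(lam) - lam.sum(axis=0)   # (time, state)
    log_trans = np.log(transition + 1e-12)
    score = np.log(initial + 1e-12) + log_emit[0]
    back = np.zeros((n_bins, n_states), dtype=int)
    for t in range(1, n_bins):
        cand = score[:, None] + log_trans    # rows: from-state, cols: to-state
        back[t] = cand.argmax(axis=0)        # best previous state per to-state
        score = cand.max(axis=0) + log_emit[t]
    # Trace the best path backwards from the best final state
    path = np.empty(n_bins, dtype=int)
    path[-1] = score.argmax()
    for t in range(n_bins - 1, 0, -1):
        path[t - 1] = back[t, path[t]]
    return path, score.max()

# Two cells, two states: cell 0 fires in state 0, cell 1 fires in state 1
rates = np.array([[50.0, 1.0], [1.0, 50.0]])
toy_spikes = np.array([[5] * 5 + [0] * 5,
                       [0] * 5 + [5] * 5])
path, log_prob = viterbi_poisson(toy_spikes, 0.1, np.array([0.5, 0.5]),
                                 np.array([[0.9, 0.1], [0.1, 0.9]]), rates)
print(path)  # [0 0 0 0 0 1 1 1 1 1]
```

Because the factorial term is dropped, the returned log probability is only correct up to a constant, so it should not be compared with `model.max_log_prob`.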
     
Additionally, the model keeps a limited history of previous iterations, so it can be rolled back in case you pass a minimum. Use:
    `model.set_to_lowest_cost()`
    or
    `model.set_to_lowest_BIC()`
     
Finally if you would like to re-fit the model, be sure to randomize it again before refitting:
- `model.randomize(spike_array, dt)`
- `model.fit(spike_array, dt)`
 


## Fitting and saving HMMs
That's just a breakdown of a `PoissonHMM` object. A much better way is to use the `HmmHandler`. This interface handles fitting HMMs for all digital inputs (tastes) as well as trying different parameter sets, plotting and saving all data to an hdf5 store. *This store is separate from your data h5 file.*
The HMM handler also takes care of parallelizing HMM fitting and creating plots for fitted HMMs.

### The Parameters

The HmmHandler is passed parameters as a dict or a list of dicts. You can provide as many or as few of these parameters as you want. The defaults are drawn from `phmm.HMM_PARAMS`. The important parameters are:
- hmm_id
 - This is set automatically by the handler
- taste
 - This will be set automatically to all tastes in the dataset, but you can specify a single one if you'd like. **See Updates below**.
- channel
 - This is always set automatically
- unit_type
 - can be 'single', 'pyramidal', or 'interneurons'
- dt
 - time bin size to use in seconds
- threshold
 - the convergence threshold to use for fitting
- max_iter
 - max number of iterations while fitting
- n_cells
 - Set automatically
- n_trials
 - Set automatically
- time_start
 - Time start to cut data
- time_end
 - Time end to cut data
- n_repeats
 - Number of repeats of this HMM to fit; the best is chosen automatically by lowest BIC
- n_states
 - Number of predicted states to fit
- fitted
 - set automatically when fitting is complete
 
Notice that a lot of these are set automatically, so your input dict need only contain the parameters you want to change from the defaults; the rest will be filled in. See the output of the cell below for the default dict.

In [3]:
phmm.HMM_PARAMS

In [4]:
# Example of defining parameter set
params = [{'unit_type': 'pyramidal', 'time_end': 2500, 'n_states': x} for x in [2,3]]
print(params)
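
The "fill in the rest from the defaults" behaviour can be pictured with the sketch below. The `DEFAULTS` dict is a hypothetical stand-in with made-up values, not the real `phmm.HMM_PARAMS`, and `fill_defaults` is an illustrative helper rather than the handler's actual code.

```python
from copy import deepcopy

# Hypothetical stand-in for phmm.HMM_PARAMS; real keys and values may differ
DEFAULTS = {'unit_type': 'single', 'dt': 0.001, 'time_start': 0,
            'time_end': 2000, 'n_states': 3}

def fill_defaults(user_params, defaults=DEFAULTS):
    """Accept a dict or a list of dicts and return a list of complete
    parameter dicts, with user values overriding the defaults."""
    if isinstance(user_params, dict):
        user_params = [user_params]
    filled = []
    for user in user_params:
        merged = deepcopy(defaults)   # never modify the shared defaults
        merged.update(user)           # user-supplied values win
        filled.append(merged)
    return filled

example_params = [{'unit_type': 'pyramidal', 'time_end': 2500, 'n_states': x}
                  for x in [2, 3]]
full_params = fill_defaults(example_params)
for p in full_params:
    print(p)
```

Each input dict becomes one complete parameter set, so the two-entry list above describes two HMMs that differ only in `n_states`.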

### Now we can initialize and run the handler
Keep in mind that after using the handler you can load it again at any time, add new parameters and re-run; only the models that haven't already been fitted will be run.

In [None]:
# Initializing
handler = phmm.HmmHandler(rec_dir)
# The save directory is automatically made inside the recording directory,
# but you can also specify another place with the save_dir keyword argument.
# You can also pass the params directly when initializing the handler, but
# I split it here so you can see how to add new parameters later.
handler.add_params(params)

# Running the handler
handler.run() # to overwrite existing models pass overwrite=True

# To plot data
handler.plot_saved_models()

# Looking at the parameters already in the handler
parameter_overview = handler.get_parameter_overview() # this is a pandas DataFrame

# Looking at the parameters and fitted model stats
data_overview = handler.get_data_overview() # also a pandas DataFrame with extra info such as cost and BIC

# The matrices defining each HMM and the best sequences can be accessed from the HDF5 store directly. They can also be accessed programmatically with:
hdf5_file = handler.h5_file
hmm, time, params = handler.get_hmm(0) # the hmm_id number goes here
# The hmm object has an attribute stat_arrays with various information including best_sequences, 
# gamma_probabilities, max_log_prob (on each iteration), time, and row_id
# Now you have the PoissonHMM object with the fitted model parameters and can do anything with it. 
# Only information lost is the model history, every model is set to the best model in the history before saving
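
One way to use the data overview is to pick the best repeat for each number of states by BIC and then compare those winners by cost, following the note above that BIC is only comparable between models with the same number of states. The sketch below uses a toy DataFrame with invented values; the column names `hmm_id`, `n_states`, `BIC` and `cost` are assumptions about what `get_data_overview()` returns.

```python
import pandas as pd

# Toy stand-in for handler.get_data_overview(); values are invented
overview = pd.DataFrame({
    'hmm_id':   [0, 1, 2, 3],
    'n_states': [2, 2, 3, 3],
    'BIC':      [1500.0, 1450.0, 1600.0, 1620.0],
    'cost':     [9.0, 8.5, 7.2, 7.9],
})

# Lowest BIC within each n_states group
best_per_n = overview.loc[overview.groupby('n_states')['BIC'].idxmin()]

# Across different n_states, compare by cost instead of BIC
best_overall = best_per_n.loc[best_per_n['cost'].idxmin(), 'hmm_id']

print(best_per_n[['hmm_id', 'n_states', 'BIC', 'cost']])
print(best_overall)
```

The resulting `hmm_id` is the number you would pass to `handler.get_hmm(...)` to load that model.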

When using pre-blechpy spike sorted data, in order to fit HMMs you must first:
- create a dataset
- initParams
- create_trial_list
- blechpy.dio.h5io.write_electrode_map_to_h5(dat.h5_file, dat.electrode_mapping)
- add array_time vector to the spike_arrays stored in the .h5 file

These processes are wrapped into the `blechpy.port_in_dataset` function. 
```python
#Example
blechpy.port_in_dataset(rec_dir, shell=True)
```

# Updates
There are a few details about the HmmHandler params you should know.

- taste:
  - If left as `None`: will fit 1 HMM for each dig_in with spike_array == True in dat.dig_in_mapping
  - If `str`: should match the name of a single dig_in in dat.dig_in_mapping, will fit 1 HMM using trials for that taste only
  - If `List of str`: will fit exactly 1 HMM using trials from all digital inputs specified by name in the list.
  - If `'all'`: will fit exactly 1 HMM using trials from all digital inputs with spike_arrays

- n_trials:
    - If left as `None`: this will be automatically set as the number of trials used to fit each HMM
    - If specified as an `int`: if the integer N is less than the total number of possible trials (for a particular taste), then the fitting will only be done using the first N trials for each taste.

- trial_nums:
    - This does not have to be in the params dictionary at all. You will notice it is not in the default params dict `phmm.HMM_PARAMS`. However, if it is provided, it MUST be a `list of int` and will specify the trial numbers to be used in fitting. Importantly, these are the trial numbers within each taste. So if you have 120 trials in a session (30 for each of 4 tastes) and you want the last 10 trials of each tastant to be used in fitting, then `params['trial_nums'] = list(range(20,30))` or `params['trial_nums'] = [20,21,22,23,24,25,26,27,28,29]`, because Python indexes from 0. This will use only those trials to fit the HMM. If it's 1 taste per HMM, then it will fit the HMM using those 10 trials. If there are multiple tastes per HMM, then it will use these trials for each taste.
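
The per-taste indexing of `trial_nums` can be sketched with simulated arrays. The layout here (one (trial, cell, time) array per taste, stacked along the trial axis when several tastes share an HMM) is an assumption for illustration, and the taste names are placeholders.

```python
import numpy as np

n_tastes, trials_per_taste, n_cells, n_bins = 4, 30, 6, 200
trial_nums = list(range(20, 30))   # last 10 trials of each taste (0-indexed)

rng = np.random.default_rng(0)
# Hypothetical layout: one (trial, cell, time) spike array per taste
taste_arrays = {
    f'taste_{i}': rng.poisson(0.1, (trials_per_taste, n_cells, n_bins))
    for i in range(n_tastes)
}

# One HMM per taste: each fit sees only the selected 10 trials
per_taste = {name: arr[trial_nums] for name, arr in taste_arrays.items()}

# One HMM over several tastes: the same trial numbers are taken from each
# taste and stacked along the trial axis (assumed behaviour)
combined = np.concatenate(list(per_taste.values()), axis=0)

print(per_taste['taste_0'].shape, combined.shape)  # (10, 6, 200) (40, 6, 200)
```

So with 4 tastes and 10 trial numbers, a per-taste HMM is fitted on 10 trials while a combined HMM is fitted on 40.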