In [None]:
%run master_functions.ipynb

In [None]:
def plot_msd(hmm, ntrajper, res, traj_no, savename=None, color='xkcd:orange', show=True):
    """Generate realizations from one fitted model and plot them for a single trajectory.

    The parameters of ``hmm`` are extracted with ``get_params`` (unclustered), realizations
    are drawn with ``gen_realizations``, and the result is passed to
    ``individual_unclustered_realizations``, which does the plotting.

    :param hmm: fitted model; its ``z`` attribute is passed to ``get_params``
    :param ntrajper: forwarded to ``gen_realizations`` (presumably the number of
        realizations to generate -- confirm against that helper)
    :param res: forwarded unchanged to the plotting helper
    :param traj_no: forwarded unchanged to the plotting helper
    :param savename: forwarded to the plotting helper's ``savename`` argument
    :param color: forwarded to the plotting helper's ``color`` argument
    :param show: forwarded to the plotting helper's ``show`` argument

    :return: None; the plotting helper's return value is discarded
    """

    params = get_params([hmm], hmm, hmm.z, clustered=False)
    realizations = gen_realizations(params, ntrajper, progress=True)

    # A new leading axis is prepended before handing the realizations to the helper.
    stacked = realizations[np.newaxis, ...]

    individual_unclustered_realizations(res, traj_no, stacked, single=True,
                                        savename=savename, color=color, show=show)

# Load Original IHMM File

In [None]:
# Identifier interpolated into every file name used in this notebook.
res = 'URE'

# The saved object is a dict; only its 'ihmm' entry (indexed per trajectory below) is needed.
saved_parameters = file_rw.load_object('saved_parameters/2000iter_%s_unseeded.pl' % res)
ihmm = saved_parameters['ihmm']

# Define Trajectory that needs improvement and have a look

In [None]:
tr = 13 # trajectory number to improve
# Path for a "before" picture of the state sequence.
# NOTE(review): this path is not used by the call below, which passes savename=None;
# presumably savename=savename should be passed when the figure needs to be written.
savename = '/home/ben/github/LLC_Membranes/Ben_Manuscripts/hdphmm/supporting_figures/state_sequence_before_%s_%d.pdf' % (res, tr)
ihmm[tr].summarize_results(savename=None)

# Reparameterize Trajectory

These are parameters that may need to be changed in order to get a better initial guess at the state sequence.
This cell runs the first 250 iterations of the IHMM procedure so that you can see whether it is heading in the right direction.

In [None]:
# ---- Workflow switches ----
load = True  # should be False if reparameterizing
show_traj = True  # summarize the existing fit before re-running inference
save = False
mod_T = False
ntrajper = 10  # passed to plot_msd for every plotted trajectory

# Indices into `ihmm` of the trajectories to process. Alternatives:
# analyze = [tr]
# analyze = [0, 1, 4, 8, 13, 15, 22]
# analyze = [1]
analyze = np.arange(24)

seed = True  # build an initial state sequence with seed_sequence before inference
seed_segs = 8
niter_inference = 250

scout = True  # take a look at the MSDs and don't do any parameterization

# IHMM params
max_states = 100  # More is usually better
# hyperparams = {'scale_sig0': np.array([1, 1]), 'a_gamma': 50}
# hyperparams = {'a_gamma': 5000}
hyperparams = None

com_filename = 'trajectories/com_xy_radial_%s.pl' % res

# When seeding, the data are loaded here and passed as an object (load_com=False);
# otherwise only the file name is passed and InfiniteHMM is told to load it.
if seed:
    com = file_rw.load_object(com_filename)
    load_com = False
else:
    com = com_filename
    load_com = True

for i, hmm in enumerate(ihmm):

    if i not in analyze:
        continue

    print('Trajectory %d' % i)

    if not scout:

        if show_traj:
            hmm.summarize_results(traj_no=0)

        if seed:
            z = seed_sequence(com, i, nseg=seed_segs, max_states=max_states, niter=5)
            n_seeded = np.unique(z).size
            print('Seeding with %d states' % n_seeded)

            # Trim a per-trajectory copy to z.size + 1 entries along the first axis.
            # `com` itself is left untouched so that the next trajectory is seeded
            # from the full data rather than from an earlier trajectory's truncation.
            com_i = (com[0][:(z.size + 1), ...], com[1])

            # Never request fewer states than the seed already contains. The raised
            # value is intentionally kept for later trajectories and later cells.
            max_states = max(max_states, n_seeded)
        else:
            z = None
            com_i = com

        # do it again
        new_hmm = hdphmm.InfiniteHMM(com_i, traj_no=i, load_com=load_com, difference=False,
                                     observation_model='AR', order=1, max_states=max_states,
                                     dim=[0, 1, 2], prior='MNIW-N', save_every=20,
                                     hyperparams=hyperparams, seed_sequence=z)

        new_hmm.inference(niter_inference)

        new_hmm.summarize_results(traj_no=0)
        new_hmm._get_params(traj_no=0)

    # Existing fit; only shown immediately when scouting so that, otherwise, the
    # new fit plotted below can follow it before anything is displayed.
    plot_msd(hmm, ntrajper, res, i, show=scout)

    if not scout:
        plot_msd(new_hmm, ntrajper, res, i, color='xkcd:blue')


In [None]:
# Sanity check: shape of the `com` array stored on the newly fitted model.
# Requires `new_hmm`, which the loop above only creates when scout is False.
print(new_hmm.com.shape)

# Look for convergence in the total number of states

In [None]:
# Number of unique states recorded per iteration during the initial inference run.
label_size = 14
nstates_trace = new_hmm.convergence['nstates']

plt.plot(nstates_trace, lw=2)
plt.ylabel('Number of Unique States', fontsize=label_size)
plt.xlabel('iteration', fontsize=label_size)
plt.tick_params(labelsize=label_size)
plt.show()

# If satisfied with above, continue the remainder of IHMM procedure

In [None]:
# Run the remaining iterations on top of the `niter_inference` already done above,
# so the total stays at the target even if `niter_inference` is changed.
total_iter = 2000  # I want 2000 total iterations for the paper
new_hmm.inference(total_iter - niter_inference)

# Check that number of unique states stopped changing

In [None]:
# Same convergence trace as before, now covering the full inference run.
label_size = 14
nstates_trace = new_hmm.convergence['nstates']

plt.plot(nstates_trace, lw=2)
plt.ylabel('Number of Unique States', fontsize=label_size)
plt.xlabel('iteration', fontsize=label_size)
plt.tick_params(labelsize=label_size)
plt.show()

# Get new converged parameters

In [None]:
# Note: called without traj_no here, whereas the reparameterization loop passes traj_no=0.
new_hmm._get_params()  # need to do this to update the converged parameters

# View and/or save a picture of the new state sequence and MSD prediction

In [None]:
save = False

# NOTE(review): the file name is keyed on analyze[0], while the neighbouring cells key
# theirs on `tr`; the two agree only when analyze == [tr] -- confirm which is intended.
if save:
    savename = '/home/ben/github/LLC_Membranes/Ben_Manuscripts/hdphmm/supporting_figures/state_sequence_after_%s_%d.pdf' % (res, analyze[0])
else:
    savename = None

# savename is None unless `save` is switched on above.
new_hmm.summarize_results(traj_no=0, savename=savename)

In [None]:
save = True

# MSD plot for the ORIGINAL model of trajectory `tr`; a path is only built when saving.
savename = (
    '/home/ben/github/LLC_Membranes/Ben_Manuscripts/hdphmm/supporting_figures/underestimate_%s_%d.pdf' % (res, tr)
    if save else None
)

plot_msd(ihmm[tr], ntrajper, res, tr, savename=savename)

In [None]:
# MSD plot for the re-fit model. The file name is keyed on `tr`, the same trajectory
# number passed to plot_msd below, so the saved figure is labeled with the trajectory
# it shows even when `analyze` holds more than one index.
# `save` carries over from the previous cell.
if save:
    savename = '/home/ben/github/LLC_Membranes/Ben_Manuscripts/hdphmm/supporting_figures/msd_improvement_%s_%d.pdf' % (res, tr)
else:
    savename = None

plot_msd(new_hmm, ntrajper, res, tr, savename=savename)

# Update IHMM file with new parameterization

In [None]:
# Swap the re-fit model into the in-memory list at index `tr`.
# Nothing is written to disk until the save cell below runs.
ihmm[tr] = new_hmm

In [None]:
# Overwrites the same file the models were loaded from at the top of the notebook.
# The 'ihmmr' entry is stored as None.
file_rw.save_object({'ihmm': ihmm, 'ihmmr': None}, 'saved_parameters/2000iter_%s_unseeded.pl' % res)

# Test New IHMM file

In [None]:
# Re-read the file that was just written, under a separate name, to check the round trip.
ihmm2 = file_rw.load_object('saved_parameters/2000iter_%s_unseeded.pl' % res)['ihmm']

In [None]:
# Summarize the reloaded model for trajectory `tr`.
ihmm2[tr].summarize_results()

In [None]:
# MSD check on the reloaded model. The literal 10 is plot_msd's `ntrajper` argument;
# it matches the `ntrajper = 10` set in the reparameterization cell.
plot_msd(ihmm2[tr], 10, res, tr)

# Visualize Seeded Initial Sequences

In [None]:
# Prefix passed to seed_sequence via its save_prefix argument.
save_prefix = '/home/ben/github/LLC_Membranes/Ben_Manuscripts/hdphmm/supporting_figures/seed_%s%d' %(res, tr)

# Load the data fresh from disk, then seed trajectory `tr` with 4 segments.
com_for_seed = file_rw.load_object(com_filename)
z = seed_sequence(com_for_seed, tr, nseg=4, max_states=100, niter=5,
                  save_prefix=save_prefix)

In [None]:
# Build a model from the seeded sequence `z`; the data are loaded from file by name.
seeded_kwargs = dict(
    traj_no=tr,
    load_com=True,
    difference=False,
    observation_model='AR',
    order=1,
    max_states=max_states,
    dim=[0, 1, 2],
    prior='MNIW-N',
    save_every=20,
    hyperparams=hyperparams,
    seed_sequence=z,
)
seeded = hdphmm.InfiniteHMM(com_filename, **seeded_kwargs)

# A single iteration only.
seeded.inference(1)

In [None]:
# Summarize the seeded model, passing a '_full.pdf' path as savename.
savename = '/home/ben/github/LLC_Membranes/Ben_Manuscripts/hdphmm/supporting_figures/seed_%s%d_full.pdf' %(res, tr)
seeded.summarize_results(savename=savename)