In [None]:
# all imports here
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.collections import LineCollection
from matplotlib import colors as mcolors
import matplotlib.patches as mpatches
from LLC_Membranes.llclib import file_rw, timeseries
import hdphmm
from hdphmm import timeseries as ts
from hdphmm.generate_timeseries import GenARData
import numpy as np
import mdtraj
from scipy import stats
import levy
from hdphmm.cluster import Cluster
from sklearn.metrics import silhouette_score
import tqdm
import math

In [1]:
# All functions in this cell in alphabetical order


def association_fraction_lifetimes(z, coord, colors=('xkcd:blue', 'xkcd:orange'), bar_width=0.4, dt=0.5,
                                   savename=None):
    """ Bar chart of the coordinated fraction and association lifetime for each prevalent state

    For every state returned by prevalent_states(z, percent=20), plot (left axis) the fraction of frames
    spent in that state for which coord equals 1 and (right axis) the mean of the bootstrapped lifetime
    percentile returned by bootstrap_lifetimes, with its standard deviation as the error bar.

    :param z: state sequences, indexed as z[trajectory, frame]
    :param coord: coordination indicator, indexed like z (1 where coordinated)
    :param colors: colors of the fraction bars (first entry) and lifetime bars (second entry)
    :param bar_width: width of each bar
    :param dt: factor converting lifetimes from frames to the time unit of the axis label (ns)
    :param savename: if not None, save the figure under this name

    :type z: np.ndarray
    :type coord: np.ndarray
    :type colors: tuple
    :type bar_width: float
    :type dt: float
    :type savename: str or None
    """

    dominant_states, dominant_state_counts = prevalent_states(z, percent=20)
    ordered_ = list(np.argsort(dominant_state_counts)[::-1])
    ordered = [list(dominant_states).index(i) for i in ordered_[:len(dominant_states)]]

    fig, ax = plt.subplots()
    ax2 = ax.twinx()

    # length of the indicator series handed to calculate_lifetimes. Taken from coord itself; this
    # previously read a notebook-level `monomer_hbonds` array that is not an argument of this function.
    nT = coord.shape[1]

    for i, s in enumerate(dominant_states[ordered]):

        assoc = 0
        total = 0
        lifetimes = []

        for t in range(coord.shape[0]):

            ndx = np.where(z[t, :] == s)[0]

            if len(ndx) > 0:  # state s appears in trajectory t

                assoc += coord[t, ndx].sum()
                total += len(ndx)

                # NOTE(review): the indicator is filled at positions relative to the frames spent in
                # state s (indices into ndx), not at absolute frame numbers -- confirm this is intended
                lt = np.zeros(nT)
                lt[np.where(coord[t, ndx] == 1)[0]] = 1
                lifetimes += calculate_lifetimes(lt)

        samples = bootstrap_lifetimes(lifetimes)
        ax.bar(i - bar_width/2, assoc / total, bar_width, color=colors[0], edgecolor='white', lw=1)
        ax2.bar(i + bar_width/2, dt*samples.mean(), bar_width, color=colors[1], edgecolor='white', lw=1,
                yerr=dt*samples.std())

    ax.set_xticks(np.arange(dominant_states.size))
    ax.set_xticklabels(np.arange(1, dominant_states.size + 1))
    ax.tick_params(labelsize=14)
    ax2.tick_params(labelsize=14)
    ax.set_xlabel('State Number', fontsize=14)
    ax.set_ylabel('Fraction of total time\nin state coordinated', fontsize=14)
    ax2.set_ylabel('Association lifetime (ns)', fontsize=14)

    # color each y-axis like the bars it belongs to
    ax2.spines['left'].set_color(colors[0])
    ax2.spines['right'].set_color(colors[1])
    ax.tick_params(labelsize=14, axis='y', colors=colors[0])
    ax2.tick_params(labelsize=14, axis='y', colors=colors[1])

    fig.tight_layout()

    if savename is not None:
        plt.savefig(savename)

    plt.show()
    
    
def boot(array, nboot, bins, bin_range, confidence):
    """ Bootstrap a normalized histogram and the spread of each of its bins

    On every trial the entries of array are resampled with replacement and histogrammed with
    density=True. If array is a list, each entry is itself a sequence and the resampled entries are
    flattened into one list before binning.

    :param array: data to resample (np.ndarray, or list of sequences)
    :param nboot: number of bootstrap trials
    :param bins: number of histogram bins
    :param bin_range: (lower, upper) range passed to np.histogram
    :param confidence: percent confidence used to locate the lower and upper percentiles

    :type nboot: int
    :type bins: int
    :type bin_range: tuple
    :type confidence: float

    :return: mean histogram over all trials, and a (2, bins) array holding the absolute distance from
    that mean to the lower (row 0) and upper (row 1) percentile of each bin
    """

    tail = (100 - confidence) / 2
    nsamples = len(array)
    flatten = isinstance(array, list)

    histograms = np.zeros([nboot, bins])
    for trial in range(nboot):

        picks = np.random.choice(nsamples, size=nsamples, replace=True)

        if flatten:
            resampled = [value for pick in picks for value in array[pick]]
        else:
            resampled = array[picks]

        histograms[trial, :] = np.histogram(resampled, bins, range=bin_range, density=True)[0]

    mean_histogram = histograms.mean(axis=0)
    lower = np.abs(np.percentile(histograms, tail, axis=0) - mean_histogram)
    upper = np.abs(np.percentile(histograms, 100 - tail, axis=0) - mean_histogram)

    return mean_histogram, np.array([lower, upper])


def bootstrap_lifetimes(lifetime, nboot=200, ci=95):
    """ Bootstrap a percentile of a collection of lifetimes

    :param lifetime: observed lifetimes
    :param nboot: number of bootstrap trials
    :param ci: percentile of each resampled set that is recorded

    :type lifetime: list or np.ndarray
    :type nboot: int
    :type ci: float

    :return: array of length nboot with the ci-th percentile of each resampled set
    """

    nsamples = len(lifetime)

    return np.array([np.percentile(np.random.choice(lifetime, size=nsamples, replace=True), ci)
                     for _ in range(nboot)], dtype=float)


def calculate_lifetimes(x):
    """ Measure the length of every completed run of truthy entries in an indicator series

    A run starts at a truthy frame and keeps growing while the next frame is truthy, or while the next
    frame is falsy but the one after it is truthy (a single-frame gap is bridged and counted as part of
    the run). A run that reaches the end of the series is not reported.

    :param x: indicator series

    :type x: np.ndarray

    :return: list of run lengths, in frames
    """

    nframes = x.size
    dwell_times = []

    start = 0
    while start < nframes:

        if not x[start]:
            start += 1
            continue

        stop = start + 1
        while stop < nframes:
            bridged = stop + 1 < nframes and x[stop + 1]  # lone falsy frame followed by a truthy one
            if not (x[stop] or bridged):
                break
            stop += 1

        if stop < nframes:  # don't count the last dwell time since it was not necessarily finished
            dwell_times.append(stop - start)

        start = stop

    return dwell_times


def cluster_behavior(params, percent, n=2):
    """ Visualize the behavior of the clustered states that appear in the most trajectories

    Three figures are created: (1) a short generated trajectory of each prevalent state in two
    dimensions next to the percentage of trajectories in which the state appears, (2) a bar chart
    derived from the self-transition entries T[s, s] and (3) a scatter plot of the first two diagonal
    entries of A against those of sigma.

    All parameters are read from params. Previously 'A', 'sigma', 'T' and the trajectory generator were
    taken from a notebook-level `final_parameters` dictionary, and the number of trajectories was fixed
    at 24.

    :param params: parameter dictionary with keys 'z', 'ihmmr', 'mu', 'A', 'sigma' and 'T'. It is also
    handed directly to GenARData.
    :param percent: only show states that are in more than this percent of total trajectories
    :param n: unused; kept so that existing calls remain valid

    :type params: dict
    :type percent: float
    :type n: int

    :return: indices of the prevalent states, sorted by increasing mu[:, 0]
    """

    z = params['z']
    ihmmr = params['ihmmr']
    mu = params['mu']
    A = params['A']
    sigma = params['sigma']
    T = params['T']  # NOTE(review): indexed below as T[s, s], i.e. assumed 2-D here -- confirm

    nclusters = np.unique(z).size
    ntraj = len(ihmmr)

    # count the number of trajectories in which each clustered state shows up
    state_counts = dict()
    for traj in range(ntraj):
        for u in np.unique(ihmmr[traj].clustered_state_sequence[0, :]):
            state_counts[u] = state_counts.get(u, 0) + 1

    nstates = max(state_counts.keys()) + 1
    # labels that never occur get a count of zero instead of raising a KeyError
    state_counts = np.array([state_counts.get(i, 0) for i in range(nstates)])
    fraction = state_counts / ntraj

    prevelant_states = np.where(fraction > (percent / 100))[0]

    reorder = np.argsort(mu[prevelant_states, 0])
    print(reorder)
    prevelant_states = prevelant_states[reorder]
    print(prevelant_states)

    # For methanol
    cmap = plt.cm.jet

    shown_colors = np.array([cmap(i) for i in np.linspace(50, 225, len(prevelant_states)).astype(int)])
    colors = np.array([cmap(i) for i in np.linspace(50, 225, nclusters + 1).astype(int)])
    colors[prevelant_states] = shown_colors

    print('Prevelant States:', prevelant_states)

    shift = 1.5  # vertical offset between the trajectories of consecutive states

    fig, ax = plt.subplots(1, 3, figsize=(10, 10), sharey=False, gridspec_kw={'width_ratios': [1, 1, 0.15]})

    trajectory_generator = GenARData(params=params)

    fig1, Tax = plt.subplots()
    fig2, Aax = plt.subplots()

    for i, s in enumerate(prevelant_states):

        Adiag = np.diag(A[s, 0, ...])
        sigdiag = np.diag(sigma[s, ...])

        trajectory_generator.gen_trajectory(100, 1, state_no=s, progress=False)

        t = trajectory_generator.traj[:, 0, :]
        t -= t.mean(axis=0)

        Tax.bar(i, 0.5 * (1 / (1 - T[s, s])), color=colors[s], edgecolor='black')

        Aax.scatter(sigdiag[0], Adiag[0], color=colors[s], edgecolor='black', s=100)
        Aax.scatter(sigdiag[1], Adiag[1], color=colors[s], edgecolor='black', s=100, marker='^')

        ax[0].plot(t[:, 1] + i*shift, lw=2, color=colors[s])
        ax[1].plot(t[:, 0] + i*shift, lw=2, color=colors[s])
        ax[2].text(0, i*shift, '%.1f %%' % (100*fraction[s]), fontsize=16, horizontalalignment='center')

    offsets = [i * shift for i in range(len(prevelant_states))]

    ax[0].set_yticks(offsets)
    ax[0].set_yticklabels(['%d' % (s + 1) for s in range(len(prevelant_states))])

    ax[1].set_yticks(offsets)
    ax[1].set_yticklabels(['%.1f' % mu[p, 0] for p in prevelant_states])

    ax[0].set_xlabel('Step Number', fontsize=18)
    ax[0].set_ylabel('State Number', fontsize=18)
    ax[1].set_xlabel('Step Number', fontsize=18)
    ax[0].set_title('$z$ direction', fontsize=18)
    ax[1].set_title('$r$ direction', fontsize=18)
    ax[0].tick_params(labelsize=16)
    ax[1].tick_params(labelsize=16)
    ax[1].set_ylabel('Cluster radial mean', fontsize=18)
    ax[2].axis('off')
    ax[2].set_yticks(offsets)
    ax[2].set_title('Percentage\nPrevalence', fontsize=18)

    ax[2].set_xlim(0, 1)

    ax[0].set_ylim(-shift, shift * len(prevelant_states))
    ax[1].set_ylim(-shift, shift * len(prevelant_states))
    ax[2].set_ylim(-shift, shift * len(prevelant_states))

    # legend handles matching the two marker shapes used in the A versus sigma scatter plot
    circle = mlines.Line2D([], [], color='white', markeredgecolor='black', marker='o', linestyle=None, label='radial dimension', markersize=12)
    triangle = mlines.Line2D([], [], color='white', markeredgecolor='black', marker='^', linestyle=None, label='axial dimension', markersize=12)

    Tax.set_ylabel('Average dwell time (ns)', fontsize=18)

    Aax.set_ylabel('A diagonals', fontsize=18)
    # raw string: '\S' is not a valid escape sequence in a normal string literal
    Aax.set_xlabel(r'$\Sigma$ diagonals', fontsize=18)
    Aax.tick_params(labelsize=16)
    Tax.tick_params(labelsize=16)
    Aax.legend(handles=[circle, triangle], fontsize=16)
    Tax.set_xticks([i for i in range(len(prevelant_states))])
    Tax.set_xticklabels([i + 1 for i in range(len(prevelant_states))])
    Tax.set_xlabel('State Number', fontsize=16)

    fig1.tight_layout()
    fig2.tight_layout()
    fig.tight_layout()

    plt.show()

    return prevelant_states


def cluster_centroids(sig_cluster, T, mu, A, sigma):
    """ Plot per-cluster averages of the parameters that were clustered

    A 2 x 2 grid is drawn with one entry per cluster label n (labels are assumed to run from 0 to the
    number of clusters minus one): a bar of 1 / (1 - mean(T)), a bar of the mean of mu[0, :] with its
    standard deviation as the error bar, a scatter of the mean A(0, 0) against the mean A(1, 1), and the
    same scatter for sigma.

    :param sig_cluster: clustering object; only its `labels` attribute is read
    :param T: per-state values, indexed by state
    :param mu: per-state means, indexed as mu[dimension, state]
    :param A: per-state matrices, indexed as A[row, column, state]
    :param sigma: per-state matrices, indexed as sigma[row, column, state]

    :type T: np.ndarray
    :type mu: np.ndarray
    :type A: np.ndarray
    :type sigma: np.ndarray
    """

    labels = np.array(sig_cluster.labels)
    nclusters = np.unique(labels).size
    colors = np.array([plt.cm.jet(i) for i in np.linspace(50, 225, nclusters).astype(int)])
    fig, ax = plt.subplots(2, 2, figsize=(12, 10))

    for n in range(nclusters):

        ndx = np.where(labels == n)[0]  # states assigned to cluster n

        ax[0, 0].bar(n, 1 / (1 - T[ndx].mean()), color=colors[n])
        # the error bar is the spread within the cluster; it was previously set to the mean itself
        ax[0, 1].bar(n, mu[0, ndx].mean(), yerr=mu[0, ndx].std(), color=colors[n])
        ax[1, 0].scatter(A[0, 0, ndx].mean(), A[1, 1, ndx].mean(), color=colors[n])
        ax[1, 1].scatter(sigma[0, 0, ndx].mean(), sigma[1, 1, ndx].mean(), color=colors[n])

    for axis in ax.flatten():
        axis.tick_params(labelsize=14)

    ax[0, 0].set_xlabel('T$_{ii}$', fontsize=14)
    ax[0, 1].set_xlabel('r (nm)', fontsize=14)
    ax[1, 0].set_xlabel('A(0, 0)', fontsize=14)
    ax[1, 1].set_xlabel(r'$\Sigma$(0, 0)', fontsize=14)

    ax[0, 0].set_ylabel('Frequency', fontsize=14)
    ax[0, 1].set_ylabel('Frequency', fontsize=14)
    ax[1, 0].set_ylabel('A(1, 1)', fontsize=14)
    ax[1, 1].set_ylabel(r'$\Sigma$(1, 1)', fontsize=14)

    fig.tight_layout()

    plt.show()

def _relabel_consecutive(cluster_ids):
    """ Map integer-valued cluster ids onto consecutive labels 0, 1, ... in order of increasing id """

    all_labels = np.unique(cluster_ids).astype(int)
    new_label_dict = {l: i for i, l in enumerate(all_labels)}

    return [new_label_dict[int(i)] for i in cluster_ids]


def cluster_parameters(sigma, A, mu, T, combine_clusters, cluster_T_separate, eigs, diags, tot_clusters, 
                       nonT_clusters, nclusters_T, nclusters_sigma, nclusters_r, nclusters_A, linkage='ward'):
    """ Cluster the parameters of all states, either jointly or parameter by parameter

    Three modes are available:

    * combine_clusters: sigma, A, the squared norm of mu[:2, :] and -log(1 - T) are clustered together
      into tot_clusters clusters.
    * cluster_T_separate: sigma, A and the squared norm of mu[:2, :] are clustered together into
      nonT_clusters clusters, T is clustered on its own into nclusters_T clusters, and every observed
      combination of the two labels becomes a final cluster.
    * otherwise: sigma, A, the squared norm of mu[:2, :] and (only if nclusters_T is not None) T are
      each clustered on their own, and every observed combination of labels becomes a final cluster.

    NOTE(review): the clustering algorithm is read from a notebook-level variable named `algorithm`;
    it is not an argument of this function.

    :param sigma: per-state sigma matrices, state index last
    :param A: per-state A matrices, state index last
    :param mu: per-state means, indexed as mu[dimension, state]
    :param T: per-state self-transition values
    :param combine_clusters: cluster all parameters together
    :param cluster_T_separate: cluster T separately from the other parameters
    :param eigs: passed to Cluster
    :param diags: passed to Cluster
    :param tot_clusters: number of clusters when combine_clusters is True
    :param nonT_clusters: number of non-T clusters when cluster_T_separate is True
    :param nclusters_T: number of T clusters. May be None in the fully separate mode.
    :param nclusters_sigma: number of sigma clusters in the fully separate mode
    :param nclusters_r: number of clusters of the squared norm of mu[:2, :] in the fully separate mode
    :param nclusters_A: number of A clusters in the fully separate mode
    :param linkage: linkage passed to Cluster

    :return: the Cluster object whose `labels` hold the final labels, and a dictionary with keys 'A',
    'sigma', 'mu', 'state_labels' and 'T'
    """

    if combine_clusters:

        params = {'sigma': sigma, 'A': A, 'mu': np.square(np.linalg.norm(mu[:2, :], axis=0)), 'T': -np.log(1 - T)}
        sig_cluster = Cluster(params, eigs=eigs, diags=diags, algorithm=algorithm, distance_threshold=None, 
                              nclusters=tot_clusters, linkage=linkage, convert_rz=True)
        sig_cluster.fit()
        new_labels = sig_cluster.labels

        print('Found %d clusters' % np.unique(sig_cluster.labels).size)

    elif cluster_T_separate:

        params = {'sigma': sigma, 'A': A, 'mu': np.square(np.linalg.norm(mu[:2, :], axis=0))}
        sig_cluster = Cluster(params, eigs=eigs, diags=diags, algorithm=algorithm, distance_threshold=None, 
                              nclusters=nonT_clusters, linkage=linkage, convert_rz=True)

        sig_cluster.fit()

        nsig_clusters = np.unique(sig_cluster.labels).size
        print('Found %d non-T clusters' % nsig_clusters)

        # NOTE(review): this branch clusters on 1 / -log(1 - T) while the other branches use
        # -log(1 - T) -- confirm the difference is intentional
        T_cluster = Cluster({'T': 1 / -np.log(1 - T)}, algorithm=algorithm, nclusters=nclusters_T)
        T_cluster.fit()

        print('Found %d T clusters' % nclusters_T)

        # encode each (non-T label, T label) pair as a single number
        new_clusters = np.zeros([A.shape[-1]])
        for state in range(A.shape[-1]):
            new_clusters[state] = sig_cluster.labels[state] * nclusters_T + T_cluster.labels[state]

        print('Found %d total clusters' % np.unique(new_clusters).size)

        new_labels = _relabel_consecutive(new_clusters)
        sig_cluster.labels = new_labels

    else:

        sig_params = {'sigma': sigma}
        A_params = {'A': A}

        sig_cluster = Cluster(sig_params, eigs=eigs, diags=diags, algorithm=algorithm, distance_threshold=None, 
                              nclusters=nclusters_sigma, linkage=linkage, convert_rz=True)
        A_cluster = Cluster(A_params, eigs=eigs, diags=diags, algorithm=algorithm, distance_threshold=None, 
                            nclusters=nclusters_A, linkage=linkage, convert_rz=True)
        r_cluster = Cluster({'mu': np.square(np.linalg.norm(mu[:2, :], axis=0))}, algorithm=algorithm, nclusters=nclusters_r,
                            linkage=linkage)

        sig_cluster.fit()
        A_cluster.fit()
        r_cluster.fit()

        nA_clusters = np.unique(A_cluster.labels).size
        nsig_clusters = np.unique(sig_cluster.labels).size
        print('Found %d sigma clusters' % nsig_clusters)
        print('Found %d A clusters' % nA_clusters)
        print('Found %d r clusters' % nclusters_r)

        if nclusters_T is not None:
            T_cluster = Cluster({'T': -np.log(1 - T)}, algorithm=algorithm, nclusters=nclusters_T, linkage=linkage)
            T_cluster.fit()

            # only report T clusters when T was clustered; formatting None with %d raises a TypeError
            print('Found %d T clusters' % nclusters_T)

        # encode each combination of per-parameter labels as a single number
        new_clusters = np.zeros([A.shape[-1]])

        if nclusters_T is None:
            for state in range(A.shape[-1]):
                new_clusters[state] = A_cluster.labels[state] * nsig_clusters * nclusters_r + sig_cluster.labels[state] * nclusters_r + r_cluster.labels[state]
        else:
            for state in range(A.shape[-1]):
                new_clusters[state] = A_cluster.labels[state] * nsig_clusters * nclusters_r * nclusters_T + sig_cluster.labels[state] * nclusters_r * nclusters_T + r_cluster.labels[state] * nclusters_T + T_cluster.labels[state]

        print('Found %d total clusters' % np.unique(new_clusters).size)

        new_labels = _relabel_consecutive(new_clusters)
        sig_cluster.labels = new_labels

    return sig_cluster, {'A': A, 'sigma': sigma, 'mu': mu, 'state_labels': new_labels, 'T': T}


def findseed(s, traj_no, ntraj):
    """ Seed numpy's random number generator and plot generated realizations of a single trajectory

    The unclustered parameters of trajectory traj_no are taken from the notebook-level `ihmm` list,
    used to generate ntraj trajectories, and plotted against that trajectory's center of mass.

    :param s: random seed
    :param traj_no: index into `ihmm` of the trajectory to reproduce
    :param ntraj: number of realizations to generate

    :type s: int
    :type traj_no: int
    :type ntraj: int
    """

    np.random.seed(s)  # 3, 9, 10, 13

    hmm = ihmm[traj_no]
    unclustered_params = get_params([hmm], hmm, hmm.z, clustered=False)

    generator = GenARData(params=unclustered_params)
    nsteps = hmm.z.shape[1]
    generator.gen_trajectory(nsteps, ntraj, bound_dimensions=[0], resample_T=True, progress=False)

    # the first center of mass frame is dropped before plotting
    plot_realizations(hmm.com[1:, ...], generator.traj)

#     possibilities = [3, 10, 31, 34, 54]
#     for i in possibilities:
#         print(i)
#         findseed(i)

def gen_realizations(final_p, ntrajper, progress=False):
    """ Generate trajectory realizations from a parameter dictionary

    :param final_p: parameter dictionary handed to GenARData. final_p['z'].shape[1] sets the number of
    steps per trajectory.
    :param ntrajper: number of trajectories to generate
    :param progress: passed to gen_trajectory

    :type final_p: dict
    :type ntrajper: int
    :type progress: bool

    :return: the `traj` attribute of the generator after generation
    """

    nsteps = final_p['z'].shape[1]

    generator = GenARData(params=final_p)
    # A retry loop that regenerated any trajectory exceeding 1e4 once lived here. It was added while
    # testing unstable higher order ARs and should not be necessary.
    generator.gen_trajectory(nsteps, ntrajper, bound_dimensions=[0, 1], resample_T=True,
                             progress=progress)

    return generator.traj


def get_clustered_mur(mur, labels, weights=True):
    """ Average mur within each cluster

    :param mur: one value per state
    :param labels: cluster label of each state
    :param weights: if True, each cluster mean is repeated once per member of that cluster. Otherwise
    each cluster mean appears once.

    :type mur: np.ndarray
    :type labels: np.ndarray
    :type weights: bool

    :return: list of cluster means, ordered by increasing label
    """

    mu_ = []

    for label in np.unique(labels):
        members = np.where(labels == label)[0]
        repeats = len(members) if weights else 1
        mu_.extend([mur[members].mean()] * repeats)

    return mu_


def get_coordination(res):
    """ Load the 'coordinated' entry of the saved coordination summary of residue res

    :param res: residue name used to build the file name

    :type res: str
    """

    summary = file_rw.load_object('trajectories/coord_summary_%s.pl' % res)

    return summary['coordinated']


def get_density(res, ma=None):
    """ Load up local density timeseries and apply a moving average to the series (ma)

    :param res: residue name used to build the file name
    :param ma: window length of the moving average. If None, the loaded series are returned unchanged.

    :type res: str
    :type ma: int or None

    :return: the loaded density array, or an array with dens.shape[0] - ma + 1 rows holding the moving
    average of each of its columns
    """

    dens = file_rw.load_object('trajectories/local_density_%s.pl' % res)

    if ma is None:
        return dens

    density = np.zeros([dens.shape[0] - ma + 1, dens.shape[1]])
    # loop over the columns of the loaded array itself. The bound was previously len(ihmm), a
    # notebook-level list that is not an argument of this function.
    for s in range(dens.shape[1]):
        density[:, s] = timeseries.calculate_moving_average(dens[:, s], ma)

    return density


def get_hbonds(res):
    """ Load the saved hydrogen bond summary of a residue and label each bond with a group code

    For every entry h of the summary's 'bonded_to' data and every frame s, a non-empty
    hbonded_to[h][s] is labelled by looking up hbonded_to[h][s][0][0] (presumably the name of the
    bonded atom -- confirm against the file that is loaded) in a fixed name-to-code table.

    :param res: residue name used to build the file name

    :type res: str

    :return: the summary's 'hbonded' array, an integer array of the same shape holding the group code
    of each bond (0 where nothing was labelled), and a dictionary naming each code
    """

    # integer codes used to colour each hydrogen bond by group
    other_solutes, carb, head, tails = 0, 1, 2, 3

    monomer_colors = {}
    for names, code in ((('O', 'O1', 'O2'), head),
                        (('O3', 'O4'), carb),
                        (('O5', 'O6', 'O7', 'O8', 'O9', 'O10'), tails),
                        (('N1', 'N'), other_solutes)):
        for name in names:
            monomer_colors[name] = code

    definitions = {0: 'Not Hbonded', 1: 'Carboxylates', 2: 'Ethers', 3: 'Tails'}

    hbonds = file_rw.load_object('trajectories/hbond_summary_%s.pl' % res)
    hbonded, hbonded_to = hbonds['hbonded'], hbonds['bonded_to']

    # for color-coding types of hbonds
    monomer_hbonds = np.zeros_like(hbonded, dtype=int)
    nframes = hbonded.shape[0]

    for h in range(len(hbonded_to)):
        partners = hbonded_to[h]
        for frame in range(nframes):
            if len(partners[frame]) > 0:
                monomer_hbonds[frame, h] = monomer_colors[partners[frame][0][0]]

    return hbonded, monomer_hbonds, definitions


def get_mu(ihmm, rotate_xy=False):
    """ Collect the unconditional mean of every state of every trajectory

    For each trajectory, 'mu' and the first-order block of 'A' in converged_params are averaged over
    their leading axis, and each state's mean m is replaced by inv(I - phi) @ m, where phi is the
    averaged A of that state.

    :param ihmm: sequence of fitted models; each needs `converged_params`, and the first one supplies
    `dimensions`
    :param rotate_xy: if True, each mean is additionally passed through rotate_vector so that its first
    two components are aligned with [1, 0]

    :type ihmm: list
    :type rotate_xy: bool

    :return: array with one column per state, trajectories concatenated in order
    """

    identity = np.eye(ihmm[0].dimensions)

    columns = []
    for hmm in ihmm:

        means = hmm.converged_params['mu'].mean(axis=0)
        phi = hmm.converged_params['A'][:, 0, ..., :].mean(axis=0)

        # convert to unconditional mean
        for state in range(means.shape[1]):
            means[:, state] = np.linalg.inv(identity - phi[..., state]) @ means[:, state]
            if rotate_xy:
                means[:, state] = rotate_vector(means[:, [state]].T, means[:2, state], [1, 0])

        columns.append(means)

    return np.concatenate(columns, axis=1)


def get_params(ihmm, ihmm_final, z, mu_weights=None, clustered=True, recalculate_params=False, rotate_xy=False,
               rmeans=True, cluster_labels=None):
    """ Assemble the parameter dictionary used to generate trajectories

    A and sigma of each state are averaged over trajectories, weighted by the number of frames each
    trajectory spends in the state. The means come from get_mu and, in clustered mode, are averaged
    within each cluster using mu_weights. The initial state distribution is the eigenvector of the mean
    transition matrix whose eigenvalue is closest to 1.

    :param ihmm: sequence of fitted models of the individual trajectories
    :param ihmm_final: fitted model providing found_states, converged_params, convergence and z
    :param z: state sequence; only used to count the number of distinct states
    :param mu_weights: weight of each individual state when averaging means within a cluster. Required
    when clustered is True.
    :param clustered: whether the states have been clustered
    :param recalculate_params: call ihmm_final._get_params for each trajectory before reading parameters
    :param rotate_xy: passed to get_mu
    :param rmeans: if True, return means as (norm of the first two components, third component)
    :param cluster_labels: cluster label of each individual state. If None, the labels are read from a
    notebook-level `sig_cluster` object, as was always done before this argument existed.

    :type ihmm: list
    :type mu_weights: np.ndarray or None
    :type clustered: bool
    :type recalculate_params: bool
    :type rotate_xy: bool
    :type rmeans: bool
    :type cluster_labels: list or np.ndarray or None

    :return: dictionary of parameters. In clustered mode it additionally carries the models and two
    notebook-level objects (see the note above the return statement).
    """

    dim = ihmm_final.dimensions
    order = ihmm[0].order

    nclusters = np.unique(z).size

    all_mu = get_mu(ihmm, rotate_xy=rotate_xy)

    ntraj = len(ihmm)

    A = np.zeros([ntraj, nclusters, order, dim, dim])
    sigma = np.zeros([ntraj, nclusters, dim, dim])
    weights = np.zeros([ntraj, nclusters])

    for t in range(ntraj):
        if recalculate_params:
            ihmm_final._get_params(traj_no=t, quiet=True)
        for i, ndx in enumerate(ihmm_final.found_states):
            if not clustered:
                ndx = i  # without clustering, states are stored in the order in which they were found
            A[t, ndx, ...] = ihmm_final.converged_params['A'][..., i].mean(axis=0)
            sigma[t, ndx, ...] = ihmm_final.converged_params['sigma'][:, ..., i].mean(axis=0)
            # weight = number of frames that trajectory t spends in this state
            weights[t, ndx] = np.where(ihmm_final.z[t, :] == ihmm_final.found_states[i])[0].size

    A_final = np.zeros([nclusters, order, dim, dim])
    sigma_final = np.zeros([nclusters, dim, dim])
    for c in range(nclusters):
        if weights[:, c].sum() > 0:  # states that never occur keep all-zero parameters
            A_final[c, ...] = np.average(A[:, c, ...], axis=0, weights=weights[:, c])
            sigma_final[c, ...] = np.average(sigma[:, c, ...], axis=0, weights=weights[:, c])

    if clustered:

        if mu_weights is None:
            raise ValueError('mu_weights is required when clustered=True')

        if cluster_labels is None:
            # backward-compatible fallback on the notebook-level clustering object
            cluster_labels = sig_cluster.labels
        cluster_labels = np.array(cluster_labels)

        if rmeans:
            mu = np.zeros([nclusters, 2])
            for i in range(nclusters):
                ndx = np.where(cluster_labels == i)[0]
                mu[i, 0] = np.average(np.linalg.norm(all_mu[:2, ndx], axis=0), weights=mu_weights[ndx])
                mu[i, 1] = np.average(all_mu[2, ndx], weights=mu_weights[ndx])
        else:
            mu = np.zeros([nclusters, dim])
            for i in range(nclusters):
                ndx = np.where(cluster_labels == i)[0]
                mu[i, :] = np.average(all_mu[:, ndx], axis=1, weights=mu_weights[ndx])  # untested
    else:

        if rmeans:
            mu = np.zeros([all_mu.shape[1], 2])
            mu[:, 0] = np.linalg.norm(all_mu[:2, :], axis=0)
            mu[:, 1] = all_mu[2, :]
        else:
            mu = np.copy(all_mu.T)

    # Get final transition matrix
    if clustered:
        equil = ihmm_final.converged_params['T'].shape[0]
        transition_matrix = np.array(ihmm_final.convergence['T'])[-equil:, ...]
    else:
        transition_matrix = np.array(ihmm_final.converged_params['T'])

    # equilibrium distribution is the eigenvector associated with an eigenvalue of 1
    w, v = np.linalg.eig(transition_matrix.mean(axis=0).T)
    ndx = np.argmin(np.abs(w - 1))
    pi_init = v[:, ndx] / v[:, ndx].sum()

    # final_parameters
    if clustered:
        # NOTE(review): `T` and `all_state_params` are not defined in this function; they are read from
        # notebook-level variables and must exist before this branch is reached
        return {'A': A_final, 'sigma': sigma_final, 'mu': mu, 'self_T': T, 'T': transition_matrix, 'pi_init': pi_init.real, 'z': ihmm_final.z, 'ihmmr': ihmm, 'ihmm':ihmm, 'all_state_params': all_state_params, 'ihmm_final': ihmm_final, 'T_distribution': ihmm_final.convergence['T']}
    else:
        return {'A': A_final, 'sigma': sigma_final, 'mu': mu, 'T': transition_matrix, 'pi_init': pi_init.real, 'z': ihmm_final.z}

    
def hbond_fraction(z, monomer_hbonds, colors=('xkcd:blue', 'xkcd:green', 'xkcd:red'), savename=None):
    """ Stacked bar chart of the fraction of time each prevalent state spends in each hydrogen bond type

    For every state returned by prevalent_states(z, percent=20), the frames spent in that state are
    split by the value of monomer_hbonds: 1 (bottom bar), 2 (middle bar) and 3 (top bar), each divided
    by the total number of frames spent in the state.

    :param z: state sequences, indexed as z[trajectory, frame]
    :param monomer_hbonds: integer hydrogen bond codes, indexed like z
    :param colors: bar colors for codes 1, 2 and 3
    :param savename: if not None, save the figure under this name

    :type z: np.ndarray
    :type monomer_hbonds: np.ndarray
    :type colors: tuple
    :type savename: str or None
    """

    dominant_states, dominant_state_counts = prevalent_states(z, percent=20)
    by_count = list(np.argsort(dominant_state_counts)[::-1])
    ordered = [list(dominant_states).index(i) for i in by_count[:len(dominant_states)]]

    fig, ax = plt.subplots()

    ntraj = monomer_hbonds.shape[0]

    for position, state in enumerate(dominant_states[ordered]):

        carb = ethers = tails = total = 0

        for t in range(ntraj):
            in_state = np.where(z[t, :] == state)[0]
            if len(in_state) == 0:
                continue
            codes = monomer_hbonds[t, in_state]
            carb += np.where(codes == 1)[0].size
            ethers += np.where(codes == 2)[0].size
            tails += np.where(codes == 3)[0].size
            total += len(in_state)

        carb /= total
        ethers /= total
        tails /= total

        ax.bar(position, carb, color=colors[0], edgecolor='white', lw=1)
        ax.bar(position, ethers, color=colors[1], bottom=carb, edgecolor='white', lw=1)
        ax.bar(position, tails, color=colors[2], bottom=ethers + carb, edgecolor='white', lw=1)

    nshown = dominant_states.size
    ax.set_xticks(np.arange(nshown))
    ax.set_xticklabels(np.arange(1, nshown + 1))
    ax.tick_params(labelsize=14)
    ax.set_xlabel('State Number', fontsize=14)
    ax.set_ylabel('Fraction of total time in state', fontsize=14)

    fig.tight_layout()

    if savename is not None:
        plt.savefig(savename)

    plt.show()

    
def hbond_lifetimes(z, monomer_hbonds, colors=('xkcd:blue', 'xkcd:green', 'xkcd:red'), bar_width=0.25, dt=0.5,
                    savename=None):
    """ Grouped bar chart of hydrogen bond lifetimes, by bond type, for each prevalent state

    For every state returned by prevalent_states(z, percent=20) and every hydrogen bond code (1, 2, 3),
    lifetimes are gathered with calculate_lifetimes over all trajectories and bootstrapped with
    bootstrap_lifetimes. The bar height is dt times the mean of the bootstrap samples and the error bar
    is dt times their standard deviation.

    :param z: state sequences, indexed as z[trajectory, frame]
    :param monomer_hbonds: integer hydrogen bond codes, indexed like z
    :param colors: bar colors for codes 1, 2 and 3
    :param bar_width: width of each bar
    :param dt: factor converting lifetimes from frames to the time unit of the axis label (ns)
    :param savename: if not None, save the figure under this name

    :type z: np.ndarray
    :type monomer_hbonds: np.ndarray
    :type colors: tuple
    :type bar_width: float
    :type dt: float
    :type savename: str or None
    """

    dominant_states, dominant_state_counts = prevalent_states(z, percent=20)
    by_count = list(np.argsort(dominant_state_counts)[::-1])
    ordered = [list(dominant_states).index(i) for i in by_count[:len(dominant_states)]]

    fig, ax = plt.subplots()

    ntraj, nT = monomer_hbonds.shape[0], monomer_hbonds.shape[1]
    groups = (1, 2, 3)  # 1: carboxylate groups, 2: ether groups, 3: tails

    for position, state in enumerate(dominant_states[ordered]):

        lifetimes = {group: [] for group in groups}

        for t in range(ntraj):

            in_state = np.where(z[t, :] == state)[0]
            if len(in_state) == 0:  # state does not appear in this trajectory
                continue

            codes = monomer_hbonds[t, in_state]
            for group in groups:
                series = np.zeros(nT)
                series[np.where(codes == group)[0]] = 1
                lifetimes[group] += calculate_lifetimes(series)

        # one bar per bond type, centered on the state's position
        for k, group in enumerate(groups):
            samples = bootstrap_lifetimes(lifetimes[group])
            ax.bar(position + (k - 1)*bar_width, dt*samples.mean(), bar_width, yerr=dt*samples.std(),
                   color=colors[k], edgecolor='white')

    nshown = dominant_states.size
    ax.set_xticks(np.arange(nshown))
    ax.set_xticklabels(np.arange(1, nshown + 1))
    ax.tick_params(labelsize=14)
    ax.set_xlabel('State Number', fontsize=14)
    ax.set_ylabel('Hydrogen Bond Lifetime (ns)', fontsize=14)

    yticks = np.arange(0, 21, 3)
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticks)

    fig.tight_layout()

    if savename is not None:
        plt.savefig(savename)

    plt.show()
    
    
def individual_unclustered_realizations(res, traj_no, unclustered_trajectory_realizations, endshow=2000, 
                                        confidence=68, dt=0.5, single=False, savename=None, show=True, color='xkcd:orange'):
    """ Plot the MSD of the unclustered realizations of one solute trajectory against its MD MSD

    :param res: residue name used to build the name of the saved MD MSD file
    :param traj_no: index of the solute trajectory
    :param unclustered_trajectory_realizations: generated trajectories; the first axis selects the
    solute trajectory they were generated for
    :param endshow: number of MSD points to plot
    :param confidence: percent confidence of the shaded band around the mean generated MSD
    :param dt: time between points, in the unit of the axis label (ns)
    :param single: if True, use entry 0 of unclustered_trajectory_realizations instead of entry traj_no
    :param savename: if not None, save the figure under this name
    :param show: call plt.show() once the figure is complete
    :param color: color of the generated MSD curve and its band

    :type res: str
    :type traj_no: int
    :type endshow: int
    :type confidence: float
    :type dt: float
    :type single: bool
    :type savename: str or None
    :type show: bool
    :type color: str
    """

    MD_MSD = file_rw.load_object('trajectories/%s_msd.pl' % res)

    t = np.arange(endshow) * dt

    plt.plot(t, MD_MSD.MSD[:endshow, traj_no], color='black', label='MD')

    trj = 0 if single else traj_no

    msd = ts.msd(unclustered_trajectory_realizations[trj, ...], 2, progress=False)
    msd_mean = msd.mean(axis=1)

    # distance from the mean to the lower and upper percentile at every point along the MSD curve
    tail = (100 - confidence) / 2
    below = np.abs(np.percentile(msd, tail, axis=1) - msd_mean)
    above = np.abs(np.percentile(msd, 100 - tail, axis=1) - msd_mean)

    shown = msd_mean[:endshow]
    plt.plot(t, shown, lw=2, color=color, label='HDP-AR-HMM')
    plt.fill_between(t, shown - below[:endshow], shown + above[:endshow], alpha=0.3, color=color)

    plt.tick_params(labelsize=14)
    plt.xlabel('Time (ns)', fontsize=14)
    plt.ylabel('MSD (nm$^2$)', fontsize=14)

    plt.legend(loc='upper left', fontsize=14)
    plt.tight_layout()

    if savename is not None:
        plt.savefig(savename)

    if show:
        plt.show()
    
    #return msd
    
def modify_T(T, percentile=20, increase_factor=2, decrease_factor=1):
    """ Increase the probability of transitioning into the states with the smallest self-transition

    States whose diagonal entry lies below the given percentile of the diagonal have every entry of
    their column, except the diagonal one, multiplied by increase_factor. The rows are then
    renormalized to sum to one.

    The input matrix is left untouched: the scaling is applied to a floating point copy. Previously the
    caller's array was scaled in place, and integer input combined with a non-integer factor failed.

    :param T: square transition matrix
    :param percentile: percentile of the diagonal below which a state is boosted
    :param increase_factor: factor applied to transitions into the boosted states
    :param decrease_factor: currently unused; kept so that existing calls remain valid

    :type T: np.ndarray
    :type percentile: float
    :type increase_factor: float
    :type decrease_factor: float

    :return: modified, row-normalized matrix with a new leading axis of length 1
    """

    modified = np.array(T, dtype=float)  # copy so the caller's matrix is not modified

    diagonal = np.diag(modified)
    cut = np.percentile(diagonal, percentile)
    increase_transition_to = np.where(diagonal < cut)[0]

    rows = np.arange(modified.shape[0])
    for x in increase_transition_to:
        modified[rows != x, x] *= increase_factor  # leave the self-transition entry as it is

    modified = (modified.T / modified.sum(axis=1)).T

    return modified[np.newaxis, ...]

    
def organize_parameters(ihmm):
    """ Gather the parameters of every state of every trajectory into single arrays

    For each trajectory, the states present in z[0, :] are listed in sorted order and the i-th of them
    is paired with the i-th state of converged_params. 'A' (first order only) and 'sigma' are averaged
    over their leading axis, and the mean is converted to inv(I - A) @ mu. The diagonal of the averaged
    'T' and the number of frames spent in each state are collected as well.

    :param ihmm: sequence of fitted models; each needs `dimensions`, `z` and `converged_params`

    :type ihmm: list

    :return: A, sigma, mu, T and mu_weights, with states of all trajectories concatenated along the
    last axis. All five are None if ihmm is empty.
    """

    # Get the parameters of all states
    A_parts, sigma_parts, mu_parts, T_parts, weight_parts = [], [], [], [], []

    for hmm in ihmm:

        dim = hmm.dimensions
        identity = np.eye(dim)

        sequence = hmm.z[0, :]
        found_states = list(np.unique(sequence))
        nstates = len(found_states)

        a = np.zeros([dim, dim, nstates])  # should probably an include a dimension for AR order
        s = np.zeros([dim, dim, nstates])
        m = np.zeros([dim, nstates])
        mw = np.zeros(nstates, dtype=int)

        for i, state in enumerate(found_states):

            a[..., i] = hmm.converged_params['A'][:, 0, ..., i].mean(axis=0)
            s[..., i] = hmm.converged_params['sigma'][:, ..., i].mean(axis=0)

            # we want to cluster on unconditional mean
            conditional = hmm.converged_params['mu'][..., i].mean(axis=0)
            m[:, i] = np.linalg.inv(identity - a[..., i]) @ conditional

            mw[i] = np.where(sequence == state)[0].size

        A_parts.append(a)
        sigma_parts.append(s)
        mu_parts.append(m)
        T_parts.append(np.diag(hmm.converged_params['T'].mean(axis=0)))
        weight_parts.append(mw)

    if not A_parts:
        return None, None, None, None, None

    A = np.concatenate(A_parts, axis=-1)
    sigma = np.concatenate(sigma_parts, axis=-1)
    mu = np.concatenate(mu_parts, axis=-1)
    T = np.concatenate(T_parts, axis=-1)
    mu_weights = np.concatenate(weight_parts)

    return A, sigma, mu, T, mu_weights


def outline_histogram(data, bin_width):
    """ Build the x and y coordinates of the outline of a normalized histogram

    The number of bins is the data range divided by bin_width, truncated to an integer, with a
    single bin used when that gives zero. Bin edges and densities are each repeated to form the steps.

    :param data: values to histogram
    :param bin_width: approximate width of each bin

    :type data: np.ndarray
    :type bin_width: float

    :return: x and y, each of length 2 * nbins + 3
    """

    nbins = int((data.max() - data.min()) / bin_width) or 1

    heights, edges = np.histogram(data, nbins, density=True)
    width = edges[1] - edges[0]

    # first edge three times, interior edges twice, last edge once, all shifted right by one bin width
    x = np.concatenate(([edges[0]], np.repeat(edges, 2)[:-1])) + width
    # every height twice, followed by two zeros
    y = np.concatenate((np.repeat(heights, 2), [0.0, 0.0]))

    # extend both series by one point on the left
    x = np.concatenate(([x[0] - width], x))
    y = np.concatenate(([y[0]], y))

    return x, y


def parameterize_clusters(sig_cluster, ihmm, mu_weights, show=False):
    """ Re-parameterize the trajectories in terms of clustered state labels.

    The state sequence of each InfiniteHMM is relabelled with the cluster labels, the center-of-mass
    trajectories are rotated and mean-zeroed segment-by-segment, and a single new InfiniteHMM is fit to the
    zeroed trajectories with the clustered state sequence supplied to it.

    :param sig_cluster: fitted clustering object; its ``labels`` attribute is read and the object itself is
        passed to ``reassign_state_sequence``
    :param ihmm: list of hdphmm.InfiniteHMM objects (one per trajectory); modified in place by
        ``reassign_state_sequence``
    :param mu_weights: forwarded unchanged to ``get_params``
    :param show: if True, call ``summarize_results`` on the final model for every trajectory

    :return: whatever ``get_params(..., recalculate_params=True)`` returns for the final model
    """

    traj_no = np.arange(len(ihmm))

    # Reassign state sequence in terms of new clustered labels
    # sig_cluster.labels is consumed in consecutive slices, one slice of len(found_states) per trajectory
    new_labels = sig_cluster.labels
    ndx = 0
    for i in traj_no:
        end = ndx + len(ihmm[i].found_states)
        labels = new_labels[ndx:end]
        ndx = end
        ihmm[i].reassign_state_sequence(sig_cluster, labels=labels)

    mean_zero = []
    dwells = []  # never filled: the code that populated it is commented out below
    hops = []  # never filled: the code that populated it is commented out below

    # lf ends up as the smallest z.size over all trajectories
    lf = ihmm[0].z.size  # last frame since trajectory lengths might be uneven due to state sequence seeding
    for t in range(1, len(ihmm)):
        if ihmm[t].z.size < lf:
            lf = ihmm[t].z.size

    for t in traj_no:

        # new procedure zeroing out xyz trajectories. Should build it into ihmm.subtract_mean eventually
        z = ihmm[t].z

        alignment_vector = np.array([1, 0])  # rotate trajectories about z-axis to be in line with x-axis
        com = ihmm[t].com[1:(lf +1), ...] # the first point in the trajectory doesn't get assigned a state
        zeroed = np.zeros([lf, 1, 3])

        # process each run between consecutive switch points separately
        switchpoints = ts.switch_points(z[0, :lf])
        for s, sp in enumerate(switchpoints[:-1]):

            start = sp
            end = switchpoints[s + 1]

            # mean of the first two coordinates of the segment is the reference vector for the rotation
            mean = com[start:end, 0, :2].mean(axis=0)
            zeroed[start:end, 0, :] = rotate_vector(com[start:end, 0, :], mean, alignment_vector)

            # subtract the segment mean from components 0 and 2 only; component 1 is left as rotated
            zeroed[start:end, 0, [0, 2]] -= zeroed[start:end, 0, [0, 2]].mean(axis=0)

        mean_zero.append(zeroed)

        #zeroed, d, h = ihmm[t].subtract_mean(traj_no=0, simple_mean=True, return_dwells_hops=True)
        #dwells += d
        #hops += h

    #hops = np.array(hops)[:, 1]
#     mz = np.zeros([len(mean_zero), mean_zero[0].shape[0], mean_zero[0].shape[2]])
#     for j in range(len(mean_zero)):
#         mz[j, ...] = mean_zero[j][:, 0, :]
#     mean_zero = mz
    # stack to (ntraj, lf, 3) by dropping the singleton trajectory axis of each entry
    mean_zero = np.array(mean_zero)[:, :, 0, :]

    niter = 10  # number of inference iterations for the final model (hard-coded)

    # stack the clustered state sequences of all trajectories, truncated to the common length lf
    z = None
    for t in traj_no:

        seq = ihmm[t].clustered_state_sequence[:, :lf]
        if z is None:
            z = seq
        else:
            z = np.concatenate((z, seq), axis=0)

    # NOTE(review): `t` here is the leftover loop variable, so dt is taken from the last trajectory
    # the first column of z is dropped when it is passed as state_sequence
    ihmm_final = hdphmm.InfiniteHMM((np.moveaxis(mean_zero, 0, 1), ihmm[t].dt), traj_no=None, load_com=False, 
                                    difference=False, order=1, max_states=len(np.unique(z)), dim=[0, 1, 2],
                                    prior='MNIW', save_com=False, state_sequence=z[:, 1:])

    ihmm_final.inference(niter)

    for i, t in enumerate(traj_no):
        if show:
            ihmm_final.summarize_results(traj_no=i)
            # NOTE(review): the third trajectory is additionally saved to a hard-coded absolute path
            if i == 2:
                ihmm_final.summarize_results(traj_no=i, savename='/home/ben/github/LLC_Membranes/Ben_Manuscripts/hdphmm/supporting_figures/zeroed_clustered_hmm.pdf')
        ihmm_final._get_params(traj_no=i, quiet=True)

    final_parameters = get_params(ihmm, ihmm_final, z, mu_weights, recalculate_params=True)

    return final_parameters


def plot_clustered_params(A, sigma, mu, T, clustered_labels, z, cmap=plt.cm.jet, nbins=20, percent=10):
    """ Plot the parameters belonging to the most prevalent clusters, colored by cluster.

    Four panels are drawn: a scatter of the sigma eigenvalue summaries, a scatter of the A eigenvalue
    summaries, an outline histogram of the squared norm of the first two components of mu, and an outline
    histogram of -log(1 - T).

    :param A: autoregressive matrices, indexed [row, column, state]
    :param sigma: covariance matrices, indexed [row, column, state]
    :param mu: means, indexed [dimension, state]
    :param T: one value per state (array that can be indexed with an array of state indices)
    :param clustered_labels: cluster label assigned to each state
    :param z: state sequences, passed to prevalent_states to select which clusters are shown
    :param cmap: colormap from which cluster colors are drawn
    :param nbins: unused; kept for backward compatibility
    :param percent: prevalence cut-off passed to prevalent_states
    """

    labels_clustered = np.array(clustered_labels)

    # the helper defined in this file is spelled `prevalent_states`
    prev_states, _ = prevalent_states(z, percent=percent)

    colors = np.array([cmap(i) for i in np.linspace(50, 225, len(prev_states)).astype(int)])

    fig, ax = plt.subplots(2, 2, figsize=(12, 8))

    # real part of the eigenvalues of each per-state matrix
    sig_eigs = np.linalg.eig(np.moveaxis(sigma, 2, 0))[0].real
    A_eigs = np.linalg.eig(np.moveaxis(A, 2, 0))[0].real

    # collapse the first two eigenvalues into a sum of squares; keep the third as is
    sig_eigs = np.concatenate((np.square(sig_eigs[:, :2]).sum(axis=1)[:, np.newaxis], sig_eigs[:, [2]]), axis=1)
    A_eigs = np.concatenate((np.square(A_eigs[:, :2]).sum(axis=1)[:, np.newaxis], A_eigs[:, [2]]), axis=1)

    # norm of the first two components of mu, followed by the third component
    mu = np.concatenate((np.linalg.norm(mu[:2, :], axis=0)[np.newaxis, :], mu[[2], :]))

    for i, c in enumerate(prev_states):

        ndx = np.where(labels_clustered == c)[0]

        if ndx.size == 0:
            # nothing to draw; outline_histogram cannot take the min/max of empty data
            continue

        ax[0, 0].scatter(sig_eigs[ndx, 0], sig_eigs[ndx, 1], color=colors[i])
        ax[0, 1].scatter(A_eigs[ndx, 0], A_eigs[ndx, 1], color=colors[i])

        xr, yr = outline_histogram(mu[0, ndx]**2, 0.5)
        ax[1, 0].plot(xr, yr, color=colors[i], lw=2)

        xT, yT = outline_histogram(-np.log(1 - T[ndx]), 0.5)
        ax[1, 1].plot(xT, yT, color=colors[i], lw=2)

    ax[0, 0].set_xlabel(r'$\lambda_{\Sigma r}$', fontsize=14)
    ax[0, 0].set_ylabel(r'$\lambda_{\Sigma z}$', fontsize=14)
    ax[0, 1].set_xlabel(r'$\lambda_{A r}$', fontsize=14)
    ax[0, 1].set_ylabel(r'$\lambda_{A z}$', fontsize=14)
    # raw string: '\m' is not a valid escape sequence in a normal string literal
    ax[1, 0].set_xlabel(r'$\mu_r^2$', fontsize=14)
    ax[1, 0].set_ylabel('Count', fontsize=14)
    ax[1, 1].set_xlabel('-log(1 - T)', fontsize=14)
    ax[1, 1].set_ylabel('Count', fontsize=14)

    for panel in ax.flat:
        panel.tick_params(labelsize=14)

    fig.tight_layout()

    plt.show()


def plot_dwells_hops(dwell_distributions, hop_distributions):
    """ This function is incomplete

    Intended to compare dwell time and hop length distributions in a two-panel figure.

    :param dwell_distributions: 3-tuple unpacked as (dwells, pre_cluster_dwells, post_cluster_dwells)
    :param hop_distributions: currently unused (never unpacked)

    NOTE(review): `hops`, `pre_cluster_hops` and `trajectory_generator` are not defined inside this function,
    so it only runs if names with those identifiers exist at module level.
    """

    # post_cluster_dwells is unpacked but never used below
    dwells, pre_cluster_dwells, post_cluster_dwells = dwell_distributions

    fig, ax = plt.subplots(1, 2, figsize=(12, 5))

    # dwell time histograms: 25 bins on [0, 150], normalized to a density
    bins, edges = np.histogram(dwells, 25, range=(0, 150), density=True)
    bins2, edges2 = np.histogram(trajectory_generator.dwells, 25, range=(0, 150), density=True)
    # NOTE(review): bins3 is computed but never plotted
    bins3, edges3 = np.histogram(pre_cluster_dwells.mean(axis=0), 25, range=(0, 150), density=True)

    # hop length histograms: 51 bins on [-2, 2], normalized to a density
    bins_hop, edges_hop = np.histogram(hops, 51, range=(-2, 2), density=True)
    bins2_hop, edges2_hop = np.histogram(np.array(trajectory_generator.hops)[:, 1], 51, range=(-2, 2), density=True)
    # NOTE(review): bins3_hop is computed but never plotted
    bins3_hop, edges3_hop = np.histogram(pre_cluster_hops, 51, range=(-2, 2), density=True)

    # bin centers for plotting the histograms as lines
    bin_width = edges[1] - edges[0]
    centers = [i + bin_width/2 for i in edges[:-1]]

    bin_width_hop = edges_hop[1] - edges_hop[0]
    centers_hop = [i + bin_width_hop/2 for i in edges_hop[:-1]]

    ax[0].plot(centers, bins, lw=2, label='MD Data')
    ax[0].plot(centers, bins2, lw=2, label='HDP-AR-HMM Simulation Data')
    # NOTE(review): plots the raw pre_cluster_dwells against the bin centers, not the histogram bins3
    ax[0].plot(centers, pre_cluster_dwells, lw=2, label='Pre-cluster HDP-AR-HMM Simulation Data')

    ax[1].plot(centers_hop, bins_hop, lw=2, label='MD Data')
    ax[1].plot(centers_hop, bins2_hop, lw=2, label='HDP-AR-HMM Simulation Data')
    # NOTE(review): plots the raw pre_cluster_hops against the bin centers, not the histogram bins3_hop
    ax[1].plot(centers_hop, pre_cluster_hops, lw=2, label='Pre-cluster HDP-AR-HMM Simulation Data')

    ax[0].set_xlabel('Dwell Time (steps)', fontsize=14)
    ax[0].set_ylabel('Probability', fontsize=14)

    ax[1].set_xlabel('Hop Length (nm)', fontsize=14)
    ax[1].set_ylabel('Probability', fontsize=14)

    ax[0].tick_params(labelsize=14)
    ax[0].legend(fontsize=14)

    ax[1].tick_params(labelsize=14)
    ax[1].legend(fontsize=14)

    fig.tight_layout()

    plt.show()

    
def plot_hop_emissions(md_traj, ihmm_traj, xrange=(-1, 1), nbins=50, nboot=200, confidence=68, savename=None):
    """ Compare bootstrapped histograms of single-step displacements of two sets of trajectories.

    Only the first 10 trajectories and coordinate index 1 of each input are used.

    :param md_traj: MD trajectories, indexed [frame, trajectory, coordinate]
    :param ihmm_traj: HDP-AR-HMM trajectories, indexed the same way
    :param xrange: (lower, upper) histogram limits
    :param nbins: number of histogram bins
    :param nboot: number of bootstrap trials, forwarded to boot
    :param confidence: confidence level, forwarded to boot
    :param savename: if given, the figure is saved under this name
    """

    # first-order differences along the frame axis
    steps_md = np.diff(md_traj[:, :10, 1], axis=0)
    steps_ihmm = np.diff(ihmm_traj[:, :10, 1], axis=0)

    # model first, then MD: keeps the order in which boot is invoked
    ihmm_hist, ihmm_limits = boot(steps_ihmm, nboot, nbins, xrange, confidence)
    md_hist, md_limits = boot(steps_md, nboot, nbins, xrange, confidence)

    bin_edges = np.linspace(xrange[0], xrange[1], nbins + 1)
    half_width = (bin_edges[1] - bin_edges[0]) / 2
    centers = bin_edges[:-1] + half_width

    fig, ax = plt.subplots(1, 1, figsize=(6, 4.5), sharey=True)

    # entries 0 and 2 of the symmetric (beta=0) Levy fit; only the first appears in the legend
    md_alpha, _ = levy.fit_levy(steps_md.flatten(), beta=0)[0].x[[0, 2]]
    ihmm_alpha, _ = levy.fit_levy(steps_ihmm.flatten(), beta=0)[0].x[[0, 2]]

    curves = (
        (md_hist, md_limits, r'MD ($\alpha$=%.2f)' % md_alpha),
        (ihmm_hist, ihmm_limits, r'HDP-AR-HMM ($\alpha$=%.2f)' % ihmm_alpha),
    )
    for hist, limits, label in curves:
        ax.plot(centers, hist, lw=2, label=label)
        ax.fill_between(centers, hist + limits[0, :], hist - limits[1, :], alpha=0.3)

    ax.tick_params(labelsize=14)

    ax.set_xlabel('$z$ Fluctuation Size (nm)', fontsize=14)
    ax.set_ylabel('Probability (nm)', fontsize=14)

    ax.legend(fontsize=14)
    fig.tight_layout()

    if savename:
        plt.savefig(savename)

    plt.show()

    
def plot_msds(tot_clusters=(), confidence=68, nboot=10, dt=0.5, endshow=2000, endshow_inset=50, 
              unclustered_realizations=False, clustered_realizations=False, exclude_outliers=False, savename=None,
              inset=False, figsize=(8, 5), bbox_to_anchor=(-0.1, -.04), ncol_legend=2, unclustered_color='xkcd:orange'):
    """ Plot mean squared displacement (MSD) curves of simulated realizations against the MD MSD.

    NOTE(review): this function reads `unclustered_trajectory_realizations`, `trajectory_generators` and `res`,
    none of which are parameters or locals, so they must exist at module level when it is called.

    :param tot_clusters: keys into `trajectory_generators` selecting which clustered models to plot
    :param confidence: percent confidence interval for the unclustered error band
    :param nboot: number of bootstrap trials
    :param dt: time between frames; only used to build the time axes
    :param endshow: number of MSD points shown on the main axis
    :param endshow_inset: number of MSD points shown on the inset axis
    :param unclustered_realizations: plot the MSD of the unclustered realizations
    :param clustered_realizations: plot the MSD of the clustered realizations
    :param exclude_outliers: unused
    :param savename: if not None, the figure is saved under this name
    :param inset: add an inset axis (only honoured when clustered_realizations is True)
    :param figsize: figure size
    :param bbox_to_anchor: legend anchor used when clustered_realizations is True
    :param ncol_legend: number of legend columns used when clustered_realizations is True
    :param unclustered_color: color of the unclustered curve when clustered_realizations is True
    """

    cmap = plt.cm.jet

    t = np.arange(endshow)*dt
    tinset = np.arange(endshow_inset) * dt

    # line width, font size, color and label of the unclustered curve depend on whether
    # clustered curves share the figure
    if clustered_realizations:
        fig, ax1 = plt.subplots(figsize=figsize)

        if inset:
            left, bottom, width, height = [0.15, 0.65, 0.3, 0.25]
            ax2 = fig.add_axes([left, bottom, width, height])

        lw=2
        fontsize=14
        c = unclustered_color
        lab = 'Unclustered'
    else:
        fig, ax1 = plt.subplots(figsize=figsize)
        lw=3
        fontsize=20
        c = 'xkcd:orange'
        lab = 'HDP-AR-HMM'

    if unclustered_realizations:

        exclude = []

        # NOTE(review): 24 is hard-coded here and in the np.random.choice call below
        keep = [i for i in range(24) if i not in exclude]

        nrealizations = unclustered_trajectory_realizations.shape[2]
        msd = np.zeros([unclustered_trajectory_realizations.shape[1], nboot])

        for b in tqdm.tqdm(range(nboot)):

            # NOTE(review): ndx is drawn but never used; the loop index b selects along axis 2 instead
            ndx = np.random.choice(nrealizations, size=24, replace=True)

            realizations = np.moveaxis(unclustered_trajectory_realizations[keep, ...][:, :, b, :], 0, 1)

            msd[:, b] = ts.msd(realizations, 2, progress=False).mean(axis=1)

        lower_confidence = (100 - confidence) / 2
        upper_confidence = 100 - lower_confidence

        error = np.zeros([2, msd.shape[0]], dtype=float)  # upper and lower bounds at each point along MSD curve
        # determine error bound for each tau (out of n MSD's, use that for the error bars)
        for s in range(error.shape[1]):
            error[0, s] = np.abs(np.percentile(msd[s, :], lower_confidence) - msd.mean(axis=1)[s])
            error[1, s] = np.abs(np.percentile(msd[s, :], upper_confidence) - msd.mean(axis=1)[s])

        ax1.plot(t, msd.mean(axis=1)[:endshow], lw=lw, color=c, label=lab)
        ax1.fill_between(t, msd.mean(axis=1)[:endshow] + error[0, :endshow], msd.mean(axis=1)[:endshow] - error[1, :endshow], alpha=0.3, color=c)

        if clustered_realizations and inset:

            ax2.plot(tinset, msd.mean(axis=1)[:endshow_inset], lw=lw, color=c, label=lab)
            ax2.fill_between(tinset, msd.mean(axis=1)[:endshow_inset] + error[0, :endshow_inset], msd.mean(axis=1)[:endshow_inset] - error[1, :endshow_inset], alpha=0.3, color=c)

    if clustered_realizations:

        colors = np.array([cmap(i) for i in np.linspace(50, 225, len(tot_clusters)).astype(int)])

        for i, g in enumerate(tot_clusters):

            msd = ts.msd(trajectory_generators[g].traj, 2)

            # NOTE(review): confidence is hard-coded to 68 here rather than using the `confidence` argument
            error, _ = ts.bootstrap_msd(msd, nboot, confidence=68, median=False)

            ax1.plot(t, msd.mean(axis=1)[:endshow], lw=2, color=colors[i], label='%d Clusters' % trajectory_generators[g].nstates)
            ax1.fill_between(t, msd.mean(axis=1)[:endshow] + error[0, :endshow], msd.mean(axis=1)[:endshow] - error[1, :endshow], alpha=0.3, color=colors[i])

            if inset:
                ax2.plot(tinset, msd.mean(axis=1)[:endshow_inset], lw=2, color=colors[i], label='%d Clusters' % trajectory_generators[g].nstates)
                ax2.fill_between(tinset, msd.mean(axis=1)[:endshow_inset] + error[0, :endshow_inset], msd.mean(axis=1)[:endshow_inset] - error[1, :endshow_inset], alpha=0.3, color=colors[i])

    # the MD reference curve is always drawn, loaded from a saved object named after `res`
    MD_MSD = file_rw.load_object('trajectories/%s_msd.pl' % res)
    ax1.plot(t, MD_MSD.MSD_average[:endshow], color='black', lw=2, label='MD')
    ax1.fill_between(t, MD_MSD.MSD_average[:endshow] + MD_MSD.limits[0, :endshow], MD_MSD.MSD_average[:endshow] - MD_MSD.limits[1, :endshow], alpha=0.3, color='black')
    if clustered_realizations and inset:
        ax2.plot(tinset, MD_MSD.MSD_average[:endshow_inset], color='black', lw=2, label='MD')
        ax2.fill_between(tinset, MD_MSD.MSD_average[:endshow_inset] + MD_MSD.limits[0, :endshow_inset], MD_MSD.MSD_average[:endshow_inset] - MD_MSD.limits[1, :endshow_inset], alpha=0.3, color='black')
    #else:
    #    msd = ts.msd(com_raw[0], 1)
    #    error, _ = ts.bootstrap_msd(msd, nboot, confidence=68, median=False) 
    #    plt.plot(t, msd.mean(axis=1)[:endshow], lw=2, color='xkcd:black')
    #    plt.fill_between(t, msd.mean(axis=1)[:endshow] + error[0, :endshow], msd.mean(axis=1)[:endshow] - error[1, :endshow], alpha=0.3, color='xkcd:black')

    ax1.set_xlabel('Time (ns)', fontsize=fontsize)
    ax1.set_ylabel('MSD (nm$^2$)', fontsize=fontsize)
    ax1.tick_params(labelsize=fontsize)

    if clustered_realizations:
        if inset:
            ax2.tick_params(labelsize=fontsize)
        ax1.legend(fontsize=fontsize, loc='upper left', bbox_to_anchor=bbox_to_anchor, ncol=ncol_legend)
    else:
        ax1.legend(fontsize=fontsize, loc='upper left')

    fig.tight_layout()

    if savename is not None:

        fig.savefig(savename)

    plt.show()
    
    
def plot_realizations(md_traj, ihmm_traj, dt=0.5, savename=None, figsize=(15, 7)):
    """ Stack MD and HDP-AR-HMM trajectories in a two-column grid of panels.

    One row is drawn per trajectory: MD rows first (black), model rows after (blue). The left column shows
    coordinate index 2 and the right column the norm of coordinate indices 0 and 1.

    :param md_traj: MD trajectories, indexed [frame, trajectory, coordinate]
    :param ihmm_traj: HDP-AR-HMM trajectories, indexed the same way
    :param dt: time between frames, used to build the time axis
    :param savename: if not None, the figure is saved under this name
    :param figsize: figure size
    """

    n_md = md_traj.shape[1]
    n_model = ihmm_traj.shape[1]
    n_rows = n_md + n_model

    fig, ax = plt.subplots(n_rows, 2, figsize=figsize, sharex=True, gridspec_kw={'hspace': 0})

    # the time axis is built from the MD trajectory length and shared by every panel
    time = np.arange(md_traj.shape[0]) * dt

    for row in range(n_md):
        radial = np.linalg.norm(md_traj[:, row, :2], axis=1)
        ax[row, 0].plot(time, md_traj[:, row, 2], color='xkcd:black')
        ax[row, 1].plot(time, radial, color='xkcd:black')
        for col in (0, 1):
            ax[row, col].tick_params(labelsize=14)

    for k in range(n_model):
        row = n_md + k
        radial = np.linalg.norm(ihmm_traj[:, k, :2], axis=1)
        ax[row, 0].plot(time, ihmm_traj[:, k, 2], lw=2, color='xkcd:blue')
        ax[row, 1].plot(time, radial, lw=2, color='xkcd:blue')
        for col in (0, 1):
            ax[row, col].tick_params(labelsize=14)

    # a single y label on the middle row of the left column, padded with spaces
    ax[n_rows // 2, 0].set_ylabel('                       Position (nm)\n', fontsize=16)

    ax[0, 0].set_title('$z$ coordinate', fontsize=16)
    ax[0, 1].set_title('$r$ coordinate', fontsize=16)

    for col in (0, 1):
        ax[-1, col].set_xlabel('Time (ns)', fontsize=14)

    fig.tight_layout()

    if savename is not None:
        fig.savefig(savename)

    plt.show()

    
def plot_state_emissions(md_traj, final_parameters, states, ar_order=1, bins=50, nboot=200, nstepsgen=1000, xrange=(-1, 1), confidence=68, show_ax=(0, 1), savename=None):
    """ Compare, state by state, MD step-size distributions with those of a trajectory generated in that state.

    For every requested state, the single-step displacements of the MD trajectories are collected over the
    frames assigned to that state, a trajectory is generated with GenARData restricted to that state, and
    bootstrapped histograms of both are plotted. One figure is produced per state.

    :param md_traj: MD trajectories, indexed [frame, trajectory, coordinate]
    :param final_parameters: parameter dictionary; its 'z' entry is read here and the whole dictionary is
        passed to GenARData
    :param states: a single state number or an iterable of state numbers
    :param ar_order: number of leading frames of md_traj to skip
    :param bins: number of histogram bins
    :param nboot: number of bootstrap trials, forwarded to boot
    :param nstepsgen: number of steps in the generated trajectory
    :param xrange: (lower, upper) histogram limits
    :param confidence: confidence level, forwarded to boot
    :param show_ax: which coordinates to show: 0 (labelled r), 1 (labelled z) or both
    :param savename: if not None, each figure is saved under this name
        (NOTE(review): the same name is reused for every state, so later figures overwrite earlier ones)
    """

    if isinstance(states, int):
        states = [states]

    z = final_parameters['z']

    # first 'ar_order' data points aren't output from HDP-AR-HMM procedure
    # If seeding the algorithm, the last few data points are sometimes discarded because segments aren't divided evenly
    md_traj = md_traj[ar_order:(ar_order + z.shape[1]), ...]

    for s in states:

        # MD step sizes while in state s: index 0 holds coordinate 0, index 1 holds coordinate 1
        MD_fluctuations = [[], []]
        for t in range(z.shape[0]):

            # boolean mask of the frames of trajectory t that are assigned to state s
            sequence = np.zeros(z.shape[1], dtype=bool)
            ndx = np.where(z[t, :] == s)[0]
            sequence[ndx] = True

            if len(ndx) > 0:
                # switch points of the mask with the first entry dropped; consumed below as
                # consecutive (start, end) pairs
                segments = ts.switch_points(sequence)[1:]
                for i in range(0, segments.size, 2):
                    try:
                        positions = md_traj[segments[i]:segments[i + 1], t, :]
                        MD_fluctuations[0] += (positions[1:, 0] - positions[:-1, 0]).tolist()
                        MD_fluctuations[1] += (positions[1:, 1] - positions[:-1, 1]).tolist()
                    except IndexError:  # case where beginning of segment is actually last point in trajectory
                        pass

        # generate a single trajectory restricted to state s
        trajectory_generator = GenARData(params=final_parameters)
        trajectory_generator.gen_trajectory(nstepsgen, 1, bound_dimensions=[0], resample_T=True, state_no=s)
        traj = trajectory_generator.traj

        # step sizes of the generated trajectory in coordinates 0 and 1
        ihmm_fluctuations = []
        ihmm_fluctuations.append(traj[1:, 0, 0] - traj[:-1, 0, 0])
        ihmm_fluctuations.append(traj[1:, 0, 1] - traj[:-1, 0, 1])

        ihmm_hopsr, ihmm_hopsr_limits = boot(ihmm_fluctuations[0], nboot, bins, xrange, confidence)
        ihmm_hopsz, ihmm_hopsz_limits = boot(ihmm_fluctuations[1], nboot, bins, xrange, confidence)
        md_hopsr, md_hopsr_limits = boot(np.array(MD_fluctuations[0]), nboot, bins, xrange, confidence)
        md_hopsz, md_hopsz_limits = boot(np.array(MD_fluctuations[1]), nboot, bins, xrange, confidence)

        # bin centers for plotting the histograms as lines
        edges = np.linspace(xrange[0], xrange[1], bins + 1)
        bin_width = edges[1] - edges[0]
        centers = [i + bin_width/2 for i in edges[:-1]]

        # one panel if a single coordinate is requested, otherwise two side by side
        if len(show_ax) == 1:
            fig, ax = plt.subplots(1, 1, figsize=(6, 4.5), sharey=False)
        else:
            fig, ax = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

        # first entry of each symmetric (beta=0) Levy fit; shown as alpha in the legends
        MD_alphar = levy.fit_levy(MD_fluctuations[0], beta=0)[0].x[0]
        MD_alphaz = levy.fit_levy(MD_fluctuations[1], beta=0)[0].x[0]
        ihmm_alphar = levy.fit_levy(ihmm_fluctuations[0], beta=0)[0].x[0]
        ihmm_alphaz = levy.fit_levy(ihmm_fluctuations[1], beta=0)[0].x[0]

        if 0 in show_ax:

            # ax is an array of two axes only when both coordinates are shown
            if 1 in show_ax:
                ax2 = ax[0]
            else:
                ax2 = ax

            ax2.plot(centers, md_hopsr, lw=2, label=r'MD ($\alpha$=%.2f)' % MD_alphar)
            ax2.fill_between(centers, md_hopsr + md_hopsr_limits[0, :], md_hopsr - md_hopsr_limits[1, :], alpha=0.3)
            ax2.plot(centers, ihmm_hopsr, lw=2, label=r'HDP-AR-HMM ($\alpha$=%.2f)' % ihmm_alphar)
            ax2.fill_between(centers, ihmm_hopsr + ihmm_hopsr_limits[0, :], ihmm_hopsr - ihmm_hopsr_limits[1, :], alpha=0.3)
            ax2.tick_params(labelsize=14)
            ax2.set_xlabel('$r$ Fluctuation Size (nm)', fontsize=14)        
            ax2.set_ylabel('Probability (nm)', fontsize=14)
            ax2.legend(fontsize=14)

        if 1 in show_ax:

            if 0 in show_ax:
                ax1 = ax[1]
            else:
                ax1 = ax

            ax1.plot(centers, md_hopsz, lw=2, label=r'MD ($\alpha$=%.2f)' % MD_alphaz)
            ax1.fill_between(centers, md_hopsz + md_hopsz_limits[0, :], md_hopsz - md_hopsz_limits[1, :], alpha=0.3)

            ax1.plot(centers, ihmm_hopsz, lw=2, label=r'HDP-AR-HMM ($\alpha$=%.2f)' % ihmm_alphaz)
            ax1.fill_between(centers, ihmm_hopsz + ihmm_hopsz_limits[0, :], ihmm_hopsz - ihmm_hopsz_limits[1, :], alpha=0.3)

#         ax[0].hist(MD_fluctuations[0], bins, density=True, range=xrange, label='MD')
#         ax[0].hist(ihmm_fluctuations[0], bins, density=True, range=xrange, label='IHMM')

#         ax[1].hist(MD_fluctuations[1], bins, density=True, range=xrange, label='MD')
#         ax[1].hist(ihmm_fluctuations[1], bins, density=True, range=xrange, label='IHMM')

            ax1.tick_params(labelsize=14)
            ax1.set_xlabel('$z$ Fluctuation Size (nm)', fontsize=14)
            ax1.legend(fontsize=14)

            # the y label is only needed here when this is the sole panel
            if 0 not in show_ax:
                ax1.set_ylabel('Probability (nm)', fontsize=14)

        fig.tight_layout()

        if savename is not None:
            plt.savefig(savename)

        plt.show()

    
def plot_state_sequence(com, z, r_convert=True, cmap=plt.cm.jet, dt=0.5, fontsize=14, savename=None, seed=None,
                       legend=False, state_numbers=None, bbox_to_anchor=(0, 1.05)):
    """ Plot a trajectory with each point colored according to its state.

    :param com: center of mass coordinates, indexed [frame, trajectory, xyz]; only trajectory 0 is plotted
        and the first frame is skipped
    :param z: state sequence, one integer label per plotted frame
    :param r_convert: convert cartesian xy coordinates to radial, r
    :param cmap: colormap from which state colors are drawn
    :param dt: time between frames, used to build the time axis
    :param fontsize: font size of tick labels, axis labels and legend
    :param savename: if not None, the figure is saved under this name
    :param seed: if not None, seeds numpy's global random generator and shuffles the state colors
    :param legend: add a legend with one entry per state in state_numbers
    :param state_numbers: states to list in the legend, labelled 'State 1', 'State 2', ... in the given order.
        Defaults to every state present in z.
    :param bbox_to_anchor: anchor point of the legend
    """

    if r_convert:
        axes = [0, 1]  # r, z
        rax, zax = axes
        axis_labels = {rax: '$r$ coordinate (nm)', zax: '$z$ coordinate (nm)'}

        pos = np.zeros([com.shape[0], 1, 2])
        pos[:, 0, 0] = np.linalg.norm(com[:, 0, :2], axis=1)
        pos[:, 0, 1] = com[:, 0, 2]

    else:
        axes = [0, 1, 2]  # x, y, z
        xax, yax, zax = axes
        # the second axis is y (it was previously mislabelled as x)
        axis_labels = {xax: '$x$ coordinate (nm)', yax: '$y$ coordinate (nm)', zax: '$z$ coordinate (nm)'}
        pos = com

    fig, ax = plt.subplots(len(axes), 1, figsize=(12, 5))

    present_states = np.unique(z)
    time = np.arange(z.size) * dt

    # evenly spaced colors for the states that actually occur
    shown_colors = np.array([cmap(i) for i in np.linspace(50, 225, present_states.size).astype(int)])

    if seed is not None:
        np.random.seed(seed)
        np.random.shuffle(shown_colors)

    # color table indexed by state label; rows of the states that occur are overwritten
    colors = np.array([cmap(i) for i in np.linspace(50, 225, z.max() + 1).astype(int)])
    colors[present_states] = shown_colors

    for a in axes:

        # the first frame of pos is skipped so that the data lines up with z
        ax[a].add_collection(
               hdphmm.multicolored_line_collection(time, pos[1:, 0, a], z, colors))

        ax[a].set_xlim([0, time[-1]])
        ax[a].set_ylim([pos[1:, 0, a].min(), pos[1:, 0, a].max()])

        ax[a].tick_params(labelsize=fontsize)

        ax[a].set_ylabel(axis_labels[a], fontsize=fontsize)

    if legend:

        print(present_states)

        if state_numbers is None:
            # fall back to every state in the sequence instead of failing on enumerate(None)
            state_numbers = present_states

        handles = [mlines.Line2D([], [], color=colors[s], lw=2, label='State %d' % (i + 1))
                   for i, s in enumerate(state_numbers)]
        legend_labels = [h.get_label() for h in handles]

        ax[0].legend(handles, legend_labels, fontsize=fontsize, ncol=len(state_numbers),
                     bbox_to_anchor=bbox_to_anchor, loc='upper left')

    if savename is not None:

        plt.savefig(savename)

    plt.show()
    
    
def plot_zeroed_trajectory(ihmm, traj_no):
    """ Plot the mean-subtracted trajectory of one model, colored by clustered state.

    :param ihmm: list of hdphmm.InfiniteHMM objects
    :param traj_no: index into ihmm of the trajectory to plot
    """

    cmap = plt.cm.jet
    model = ihmm[traj_no]

    zeroed = model.subtract_mean(traj_no=0, simple_mean=True)
    nT = zeroed.shape[0]

    fig, ax = plt.subplots(2, 1)

    seq = model.clustered_state_sequence[0, :]

    # one color per label 0..seq.max(), matching the color table built in plot_state_sequence
    colors = np.array([cmap(i) for i in np.linspace(50, 225, seq.max() + 1).astype(int)])

    time = np.arange(nT) * model.dt / 1000

    for a in range(2):

        ax[a].add_collection(
            hdphmm.multicolored_line_collection(time, zeroed[:, a], seq, colors))

        # the original indexed ihmm with an undefined name `t`; the plotted model is ihmm[traj_no]
        ax[a].set_xlim([0, nT * model.dt / 1000])
        ax[a].set_ylim([zeroed[:, a].min(), zeroed[:, a].max()])

        
def prevalent_states(z, percent=34, ntraj=24):
    """ Find the states that appear in more than a given percentage of trajectories.

    :param z: state sequences, indexed [trajectory, frame], with non-negative integer labels
    :param percent: a state is prevalent if it appears in more than this percentage of ntraj trajectories
    :param ntraj: number of trajectories used as the denominator of the fraction. The default of 24 keeps the
        value that used to be hard-coded; pass ``z.shape[0]`` for data sets of a different size.

    :type z: numpy.ndarray
    :type percent: float
    :type ntraj: int

    :return: indices of the prevalent states, and for every label 0..z.max() the number of trajectories in
        which it appears (0 for labels that never occur)
    :rtype: tuple of numpy.ndarray
    """

    z = np.asarray(z)

    nstates = int(z.max()) + 1
    state_counts = np.zeros(nstates, dtype=int)

    for row in z:
        # each state is counted once per trajectory, no matter how many frames it occupies
        present = np.unique(row).astype(int)
        state_counts[present[present >= 0]] += 1

    fraction = state_counts / ntraj

    prevalent = np.where(fraction > (percent / 100))[0]

    return prevalent, state_counts


def rotate_z(theta):
    """ Build the 3 x 3 matrix for a counter-clockwise rotation about the z-axis

    :param: theta: rotation angle in radians

    :type theta: float

    :return: matrix that rotates a column vector by theta about the z-axis
    :rtype numpy.ndarray
    """
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)

    # the z component is left untouched by the rotation
    return np.array([[cos_t, -sin_t, 0.0],
                     [sin_t, cos_t, 0.0],
                     [0.0, 0.0, 1.0]])


def rotate_vector(xyz, v1, v2):
    """ Rotate coordinates about the z-axis by the angle between a reference vector and a target vector

    :param xyz: xyz coordinates of object to be rotated, one point per row
    :param v1: original vector
    :param v2: direction you want v1 to be pointing in

    :type xyz: numpy.ndarray
    :type v1: numpy.ndarray
    :type v2: numpy.ndarray

    :return: rotated coordinates, same shape and dtype as xyz
    :rtype: numpy.ndarray
    """

    quad = quadrant(v1)

    # first find the angle between v1 and v2
    num = np.dot(v1, v2)
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)

    # rounding can push the cosine marginally outside [-1, 1] for (anti)parallel vectors,
    # which would make arccos return nan
    theta = np.arccos(np.clip(num / denom, -1.0, 1.0))

    # the sense of rotation depends on the quadrant v1 lies in
    if quad == 3 or quad == 4:
        Rz = rotate_z(theta)
    else:
        Rz = rotate_z(-theta)

    # row-wise Rz @ xyz[i] for all rows at once
    pos = np.zeros_like(xyz)
    pos[...] = np.dot(xyz, Rz.T)

    return pos


def seed_sequence(com, traj_no, nseg=2, max_states=100, niter=10, save_prefix=None):
    """ Infer a state sequence for one trajectory by fitting separate models to consecutive segments.

    :param com: indexable pair; com[0] holds the coordinates indexed [frame, trajectory, xyz] and com[1] is
        passed through unchanged as the second element of each segment
    :param traj_no: index of the trajectory to segment
    :param nseg: number of segments
    :param max_states: maximum number of states per segment model; also the offset that keeps labels of
        different segments distinct
    :param niter: number of inference iterations per segment
    :param save_prefix: if not None, each segment's summary is saved as '<save_prefix>_segment<n>.pdf'

    :return: state sequence with a single row, relabelled to consecutive integers starting at 0
    """

    coordinates = com[0]
    per_segment = coordinates.shape[0] // nseg  # points per segment

    models = []
    for s in range(nseg):

        # every segment but the first starts one frame early
        first = s * per_segment if s == 0 else s * per_segment - 1
        chunk = (coordinates[first:(s + 1) * per_segment, [traj_no], :], com[1])

        print(chunk[0].shape)

        models.append(hdphmm.InfiniteHMM(chunk, traj_no=0, load_com=False, difference=False,
                                         observation_model='AR', order=1, max_states=max_states,
                                         dim=[0, 1, 2], prior='MNIW-N', save_every=1, hyperparams=None))

    pieces = [np.zeros([1, 0], dtype=int)]
    for s, model in enumerate(models):

        model.inference(niter)

        # shift the labels so that states of different segments never collide
        pieces.append(model.z + max_states * s)

        if save_prefix is not None:
            model.summarize_results(savename='%s_segment%d.pdf' % (save_prefix, s))

    z = np.concatenate(pieces, axis=1)

    # replace each label by its rank among the labels that occur
    _, ranks = np.unique(z[0, :], return_inverse=True)
    z[0, :] = ranks

    return z

def silhouette_(params, eigs, diags, linkage, nclust, algorithm='agglomerative', plot=False, savename=None):
    """ Compute the silhouette score for a range of cluster counts.

    :param params: parameters to cluster, passed to Cluster
    :param eigs: passed to Cluster
    :param diags: passed to Cluster
    :param linkage: linkage criterion, passed to Cluster
    :param nclust: iterable of cluster counts to try
    :param algorithm: clustering algorithm, passed to Cluster
    :param plot: if True, plot score versus number of clusters instead of returning the scores
    :param savename: if not None and plot is True, the plot is saved under this name

    :return: list of silhouette scores, one per entry of nclust, when plot is False; None otherwise
    """

    scores = []
    for k in nclust:
        model = Cluster(params, eigs=eigs, diags=diags, algorithm=algorithm, distance_threshold=None,
                        nclusters=k, linkage=linkage, convert_rz=True)
        model.fit()
        scores.append(silhouette_score(model.X, model.labels))

    if not plot:
        return scores

    plt.plot(nclust, scores, lw=2)
    plt.xlabel('Number of clusters', fontsize=14)
    plt.ylabel('Silhouette Score', fontsize=14)
    plt.tick_params(labelsize=14)
    plt.ylim(0, 1)
    plt.tight_layout()

    if savename is not None:
        plt.savefig(savename)


def silhouette(params, eigs, diags, algorithm='agglomerative'):
    """ Plot silhouette score versus number of clusters (2 to 49) for four linkage criteria.

    One panel is drawn per linkage ('ward', 'average', 'complete', 'single') in a 2 x 2 grid.

    :param params: parameters to cluster, forwarded to silhouette_
    :param eigs: forwarded to silhouette_
    :param diags: forwarded to silhouette_
    :param algorithm: clustering algorithm, forwarded to silhouette_
    """

    nclust = np.arange(2, 50)
    linkages = ['ward', 'average', 'complete', 'single']
    fig, ax = plt.subplots(2, 2, figsize=(11, 9))

    # ax.flat walks the grid row by row, i.e. panel i sits at [i // 2, i % 2]
    for panel, linkage in zip(ax.flat, linkages):

        silhouette_avg = silhouette_(params, eigs, diags, linkage, nclust, algorithm=algorithm)

        panel.set_title('%s' % linkage, fontsize=14)

        # raw string: '\m' is not a valid escape sequence in a normal string literal
        panel.plot(nclust, silhouette_avg, lw=2, label=r'$\mu$')
        panel.set_xlabel('Number of clusters', fontsize=14)
        panel.set_ylabel('Silhouette Score', fontsize=14)
        panel.set_ylim(0, 1)
        panel.tick_params(labelsize=14)
        panel.legend(fontsize=14)

    plt.tight_layout()
    plt.show()

    
def toc(com, z, cmap=plt.cm.jet, dt=0.5, fontsize=14, savename=None, seed=None, start=0, end=-1):
    """ Draw the same stretch of trajectory twice: once as a plain line and once colored by state.

    :param com: center of mass coordinates, indexed [frame, trajectory, xyz]; trajectory 0 and coordinate
        index 2 are plotted, offset by one frame relative to z
    :param z: state sequence
    :param cmap: colormap from which state colors are drawn
    :param dt: time between frames, used to build the time axis
    :param fontsize: font size of tick and axis labels
    :param savename: if not None, figures are saved as '<savename>_unclustered.png' and
        '<savename>_clustered.png'
    :param seed: if not None, seeds numpy's global random generator and shuffles the state colors
    :param start: first index of z to show
    :param end: index of z at which to stop; -1 means the end of the sequence
    """

    fig1, ax1 = plt.subplots(figsize=(8, 4.5))
    fig2, ax2 = plt.subplots(figsize=(8, 4.5))

    stop = z.size if end == -1 else end

    z_window = z[start:stop]
    visible_states = np.unique(z_window)
    time = np.arange(z_window.size) * dt

    # evenly spaced colors for the states that occur in the window
    shown_colors = np.array([cmap(i) for i in np.linspace(50, 225, visible_states.size).astype(int)])

    if seed is not None:
        np.random.seed(seed)
        np.random.shuffle(shown_colors)

    # color table indexed by state label; rows of the visible states are overwritten
    colors = np.array([cmap(i) for i in np.linspace(50, 225, z.max() + 1).astype(int)])
    colors[visible_states] = shown_colors

    # com is shifted by one frame with respect to z
    z_coord = com[(1 + start):(stop + 1), 0, 2]
    y_limits = [z_coord.min(), z_coord.max()]

    ax1.plot(time, z_coord, lw=2)
    ax2.add_collection(hdphmm.multicolored_line_collection(time, z_coord, z_window, colors))

    for axis in (ax1, ax2):
        axis.set_xlim([0, time[-1]])
        axis.set_ylim(y_limits)
        axis.tick_params(labelsize=fontsize)
        axis.set_ylabel('$z$-coordinate', fontsize=fontsize)
        axis.set_xlabel('Time (ns)', fontsize=fontsize)

    fig1.tight_layout()
    fig2.tight_layout()

    if savename is not None:
        fig1.savefig(savename + '_unclustered.png')
        fig2.savefig(savename + '_clustered.png')

    plt.show()
    

def unclustered_trajectories(ihmm, ntrajper, load=False, mod_T=False, save=False, exclude=(), mod_T_percentile=20, 
                             mod_T_increase_factor=2.25, mod_T_decrease_factor=1, nsteps='auto'):
    """ Generate (or load) trajectory realizations from each unclustered model.

    NOTE(review): file names are built from `res`, which is not a parameter and must exist at module level.

    :param ihmm: list of hdphmm.InfiniteHMM objects
    :param ntrajper: number of realizations per ihmm parameterization
    :param load: load previously saved results instead of generating new ones
    :param mod_T: when generating, replace each model's transition matrix with the output of modify_T; when
        loading, read the file with the '_modT' suffix
    :param save: save the generated results (ignored when load is True)
    :param exclude: indices of models in ihmm to skip; their rows in the output stay zero
    :param mod_T_percentile: forwarded to modify_T as percentile
    :param mod_T_increase_factor: forwarded to modify_T as increase_factor
    :param mod_T_decrease_factor: forwarded to modify_T as decrease_factor
    :param nsteps: number of steps per realization; 'auto' uses the shortest state sequence among the models

    :return: hops per model, dwells per model, and an array of realizations indexed
        [model, step, realization, dimension]
    """

    if not load:

        pre_cluster_hops_ = []
        pre_cluster_dwells_ = []

        # compare by value: `is` on a string literal tests object identity, which is not guaranteed
        if isinstance(nsteps, str) and nsteps == 'auto':

            nsteps = min(hmm.z.shape[1] for hmm in ihmm)

        unclustered_trajectory_realizations = np.zeros([len(ihmm), nsteps, ntrajper, ihmm[0].dimensions])

        for i, hmm in enumerate(ihmm):

            if i not in exclude:

                final_p = get_params([hmm], hmm, hmm.z, clustered=False)

                if mod_T:

                    final_p['T'] = modify_T(final_p['T'].mean(axis=0), percentile=mod_T_percentile, 
                                            increase_factor=mod_T_increase_factor, 
                                            decrease_factor=mod_T_decrease_factor)

                trajectory_generator = GenARData(params=final_p)

                for t in tqdm.tqdm(range(ntrajper)):

                    trajectory_generator.gen_trajectory(nsteps, 1, bound_dimensions=[0, 1], 
                                                        resample_T=True, progress=False)

                    # This shouldn't be necessary. I implemented when testing unstable higher order ARs
                    #while trajectory_generator.traj.max() > 1e4:
                    #    trajectory_generator.gen_trajectory(hmm.z.shape[1], 1, bound_dimensions=[0], 
                    #                                        resample_T=True, progress=False)

                    unclustered_trajectory_realizations[i, :, t, :] = trajectory_generator.traj[:, 0, :]

                # hops and dwells are read from the generator once per model, after all its realizations
                try:
                    pre_cluster_hops_.append(np.array(trajectory_generator.hops)[:, 1].tolist())
                except IndexError: # if no hops occur
                    pass

                pre_cluster_dwells_.append(trajectory_generator.dwells)

        if save:

            # NOTE(review): the saved name never carries the '_modT' suffix that the load branch
            # looks for when mod_T is True -- confirm whether this is intentional
            file_rw.save_object((pre_cluster_hops_, pre_cluster_dwells_, unclustered_trajectory_realizations), 'unclustered_%s.pl' % res)

    else:

        if not mod_T:

            savename = 'unclustered_%s.pl' % res

        else:

            savename = 'unclustered_%s_modT.pl' % res

        pre_cluster_hops_, pre_cluster_dwells_, unclustered_trajectory_realizations = file_rw.load_object(savename)

    return pre_cluster_hops_, pre_cluster_dwells_, unclustered_trajectory_realizations


def visualize_clusters(params, cluster, nclusters, eigs, diags, nbins=25, algorithm='agglomerative', linkage='ward', hist_limits=None, axis_limits=None, savename=None):
    """ Cluster one set of parameters and plot the resulting cluster assignments.

    :param params: the set of parameters corresponding to `cluster`. If cluster='T', this should be T; it is
    transformed to 1 / -log(1 - T) before being handed to the clustering object.
    :param cluster: name of cluster to visualize ('T', 'sigma' or 'A'). Any other name raises a KeyError.
    :param nclusters: number of clusters (forwarded to Cluster)
    :param eigs: forwarded unchanged to Cluster
    :param diags: forwarded unchanged to Cluster
    :param nbins: number of histogram bins, forwarded to the plotting function
    :param algorithm: clustering algorithm name, forwarded to Cluster
    :param linkage: linkage criterion, forwarded to Cluster
    :param hist_limits: forwarded to the plotting function
    :param axis_limits: forwarded to the plotting function
    :param savename: forwarded to the plotting function (previously this was always replaced by None, so nothing
    could ever be saved through this entry point)
    """

    visualization_functions = {'T': visualize_T_clusters, 'A': visualize_A_clusters, 'sigma': visualize_sigma_clusters}
    visualize = visualization_functions[cluster]  # KeyError for an unrecognized cluster name

    # Build only the entry that is needed. Applying the log transform to A or sigma is wasted work and emits
    # spurious numpy warnings.
    if cluster == 'T':
        cluster_params = {'T': 1 / -np.log(1 - params)}
    else:
        cluster_params = {cluster: params}

    clusters = Cluster(cluster_params, algorithm=algorithm, nclusters=nclusters, linkage=linkage, eigs=eigs, diags=diags)
    clusters.fit()

    visualize(clusters, nbins, show=True, hist_limits=hist_limits, axis_limits=axis_limits, savename=savename)

    
def visualize_A_clusters(clusters, nbins, show=True, hist_limits=None, axis_limits=None, savename=None):
    """ Scatter plot of A(0, 0) versus A(1, 1), with each point colored by its cluster label.

    NOTE(review): the coefficient array `A` is not a parameter. It is read from the enclosing (notebook) scope and
    indexed as A[row, column, ndx]; make sure it is defined, and matches `clusters`, before calling.

    :param clusters: fitted clustering object; only its `labels` attribute is read
    :param nbins: unused here; accepted so all visualize_*_clusters functions can be called the same way
    :param show: if True, call plt.show() once the figure is drawn
    :param hist_limits: unused here; accepted for signature compatibility
    :param axis_limits: unused here; accepted for signature compatibility
    :param savename: if not None, name of the file the figure is saved to
    """

    colors = ['xkcd:blue', 'xkcd:orange', 'xkcd:green', 'xkcd:red', 'xkcd:purple']

    fig2, scatter_ax = plt.subplots(1, 1)

    # Take the clusters from the labels themselves rather than from a global `nclusters`
    labels = np.array(clusters.labels)

    for i, label in enumerate(np.unique(labels)):

        ndx = np.where(labels == label)[0]

        # cycle through the palette so more clusters than colors does not raise an IndexError
        scatter_ax.scatter(A[0, 0, ndx], A[1, 1, ndx], color=colors[i % len(colors)])

    scatter_ax.set_xlabel('A(0, 0)', fontsize=14)
    scatter_ax.set_ylabel('A(1, 1)', fontsize=14)

    scatter_ax.tick_params(labelsize=14)

    fig2.tight_layout()

    if savename is not None:
        fig2.savefig(savename)

    if show:
        plt.show()

def visualize_sigma_clusters(clusters, nbins, show=True, hist_limits=None, axis_limits=None, savename=None):
    """ Scatter plot of Sigma(0, 0) versus Sigma(1, 1), with each point colored by its cluster label.

    NOTE(review): the covariance array `sigma` is not a parameter. It is read from the enclosing (notebook) scope
    and indexed as sigma[row, column, ndx]; make sure it is defined, and matches `clusters`, before calling.

    :param clusters: fitted clustering object; only its `labels` attribute is read
    :param nbins: unused here; accepted so all visualize_*_clusters functions can be called the same way
    :param show: if True, call plt.show() once the figure is drawn
    :param hist_limits: unused here; accepted for signature compatibility
    :param axis_limits: unused here; accepted for signature compatibility
    :param savename: if not None, name of the file the figure is saved to
    """

    colors = ['xkcd:blue', 'xkcd:orange', 'xkcd:green', 'xkcd:red']

    fig2, scatter_ax = plt.subplots(1, 1)

    # Take the clusters from the labels themselves rather than from a global `nclusters`
    labels = np.array(clusters.labels)

    for i, label in enumerate(np.unique(labels)):

        ndx = np.where(labels == label)[0]

        # cycle through the palette so more clusters than colors does not raise an IndexError
        scatter_ax.scatter(sigma[0, 0, ndx], sigma[1, 1, ndx], color=colors[i % len(colors)])

    # raw strings: '\S' is not a valid escape sequence in a normal string literal
    scatter_ax.set_xlabel(r'$\Sigma$(0, 0)', fontsize=14)
    scatter_ax.set_ylabel(r'$\Sigma$(1, 1)', fontsize=14)
    scatter_ax.tick_params(labelsize=14)

    fig2.tight_layout()

    if savename is not None:
        fig2.savefig(savename)

    if show:
        plt.show()

def visualize_T_clusters(clusters, nbins, show=True, hist_limits=None, axis_limits=None, savename=None):
    """ Plot three histograms per cluster: T, -log(1 - T) and 1 / (1 - T).

    NOTE(review): the array `T` is not a parameter. It is read from the enclosing (notebook) scope and indexed with
    the positions belonging to each cluster; make sure it is defined, and matches `clusters`, before calling.

    :param clusters: fitted clustering object; only its `labels` attribute is read
    :param nbins: number of bins used for the first and third histograms
    :param show: if True, call plt.show() once the figure is drawn
    :param hist_limits: optional histogram ranges indexed as hist_limits[panel][cluster]; each entry is passed as
    the `range` argument of hist. None (or a None entry for a panel) lets matplotlib choose the range.
    :param axis_limits: optional sequence of (xlim, ylim) pairs, one per panel in order. Either member of a pair
    may be None to leave that axis untouched.
    :param savename: if not None, name of the file the figure is saved to
    """

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    colors = ['xkcd:blue', 'xkcd:orange', 'xkcd:green', 'xkcd:red']

    labels = np.array(clusters.labels)

    def panel_range(panel, i):
        # The default hist_limits=None used to crash on hist_limits[panel][i]
        if hist_limits is None or hist_limits[panel] is None:
            return None
        return hist_limits[panel][i]

    for i, label in enumerate(np.unique(labels)):
        ndx = np.where(labels == label)[0]
        color = colors[i % len(colors)]  # cycle so more clusters than colors does not raise an IndexError

        ax[0].hist(T[ndx], nbins, color=color, range=panel_range(0, i))

        # NOTE(review): this panel has never been given nbins, so it uses matplotlib's default bin count.
        # Left unchanged; confirm whether that is intentional.
        ax[1].hist(-np.log(1 - T[ndx]), color=color, range=panel_range(1, i))

        ax[2].hist(1 / (1 - T[ndx]), nbins, color=color, range=panel_range(2, i))

    for panel in ax:
        panel.tick_params(labelsize=14)

    ax[0].set_xlabel('T$_{ii}$', fontsize=14)
    ax[1].set_xlabel('-log(1- T$_{ii}$)', fontsize=14)
    ax[2].set_xlabel('Dwell Time (ns)', fontsize=14)

    if axis_limits is not None:
        for a, lim in enumerate(axis_limits):
            xlim, ylim = lim
            if xlim is not None:
                ax[a].set_xlim(xlim)
            if ylim is not None:
                ax[a].set_ylim(ylim)

    fig.tight_layout()

    if savename is not None:
        fig.savefig(savename)

    if show:
        plt.show()


def quadrant(pt, origin=[0, 0]):
    """ Report which quadrant of the xy plane a point occupies relative to an origin

    II    |    I
          |
    -------------
          |
    III   |    IV

    :param pt: point to be tested; pt[0] is compared along x and pt[1] along y
    :param origin: the location of the origin. Default is [0, 0] but can be set arbitrarily (such as a pore center)

    :return: 1, 2, 3 or 4 for the quadrant, or 0 if the point is neither strictly left nor strictly right of the
    origin, or neither strictly above nor strictly below it (i.e. it lies on one of the axes)
    """
    x, x0 = pt[0], origin[0]

    # Settle the left/right half first; a point level with the origin in x is on the vertical axis
    if x > x0:
        right_half = True
    elif x < x0:
        right_half = False
    else:
        return 0

    y, y0 = pt[1], origin[1]

    if y > y0:
        return 1 if right_half else 2
    if y < y0:
        return 4 if right_half else 3

    return 0  # level with the origin in y: the point is on the horizontal axis


NameError: name 'plt' is not defined