# Reweight REPX RNA tetramers with PyMBAR

In [1]:
import os, sys, math
import numpy as np
import glob
import mdtraj
import logging
import netCDF4 as nc
import warnings
import pandas as pd
from collections import defaultdict

import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib.ticker import MultipleLocator, FormatStrFormatter, AutoMinorLocator
import seaborn as sns

import openmmtools as mmtools

import barnaba as bb
from barnaba import definitions
from barnaba.nucleic import Nucleic

In [2]:
#logger = logging.getLogger(__name__)
#logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
warnings.filterwarnings("ignore")

In [3]:
# Show DataFrames in full and print floats with a single decimal place.
for _option, _value in {
    "display.max_rows": None,
    "display.max_columns": None,
    "display.precision": 1,
    "display.float_format": '{:.1f}'.format,
}.items():
    pd.set_option(_option, _value)

In [4]:
#plt.rcParamsDefault

In [5]:
# Large-format matplotlib defaults applied to every figure in this notebook.
_TEXT_SIZE = 40
_HEADING_SIZE = 48
params_mydefault = {
    'legend.fontsize': _TEXT_SIZE,
    'font.size': _TEXT_SIZE,
    'axes.labelsize': _HEADING_SIZE,
    'axes.titlesize': _HEADING_SIZE,
    'xtick.labelsize': _TEXT_SIZE,
    'ytick.labelsize': _TEXT_SIZE,
    'savefig.dpi': 600,
    'figure.figsize': [64, 8],
}
# Major/minor tick lengths, identical on the x and y axes.
params_mydefault.update(
    {
        f'{axis}tick.{which}.size': length
        for axis in 'xy'
        for which, length in (('major', 10), ('minor', 7))
    }
)

In [6]:
# Install the notebook-wide defaults into matplotlib, one key at a time.
for _key, _value in params_mydefault.items():
    plt.rcParams[_key] = _value

In [7]:
# Atom names of the sugar-phosphate backbone of a nucleotide, including the
# terminal 5'/3' hydroxyl hydrogens.
backbone_sugar_atoms = [
    # ribose carbons and their hydrogens
    "C1'", "H1'", "C2'", "H2'", "C3'", "H3'", "C4'", "H4'", "C5'", "H5'", "H5''",
    # ribose oxygens (plus the 2'-OH hydrogen)
    "O2'", "HO2'", "O3'", "O4'", "O5'",
    # phosphate group
    "P", "OP1", "OP2",
    # terminal hydroxyl hydrogens
    "HO5'", "HO3'",
]

In [82]:
# Structure annotation: category name -> integer label, in display order.
_CATEGORY_NAMES = ("AMa", "AMi", "I", "F1", "F4", "O")
myclass_mapping_dict = {name: label for label, name in enumerate(_CATEGORY_NAMES, start=1)}

# Plot color of each category (same order as above).
mycolor_dict = dict(zip(_CATEGORY_NAMES, ("green", "blue", "red", "magenta", "orange", "black")))

# Plot title and unit conversion factors.
PLOT_TITLE = "CCCC Amber ff14"
UNIT_NM_TO_ANGSTROMS = 10
UNIT_PS_TO_NS = 1 / 1000

In [9]:
def radian_to_degree(a):
    """
    Convert angles from radians to degrees, wrapping negative angles.

    Negative angles are shifted by +2π before conversion, so inputs in
    [-π, π] end up in [0°, 360°]. Non-negative inputs are converted as-is.

    Parameters
    ----------
    a : array_like
        Angles in radians, any shape (e.g. [trajectory frame, residue, torsion]).
        Lists and integer arrays are accepted; the input is never modified.

    Returns
    -------
    np.ndarray
        Float array with the same shape as ``a``, in degrees.
    """
    # Work on a float copy. Modifying the caller's array in place surprises
    # callers, and in-place float ops fail on lists and integer arrays.
    deg = np.array(a, dtype=float)
    deg[deg < 0.0] += 2.0 * np.pi
    deg *= 180.0 / np.pi
    return deg

### Define analysis helpers

In [72]:
# ==============================================================================
# STRUCTURE ANNOTATION
# ==============================================================================

def _check_endo(angle_d, angle_p):
    """
    Define C3'-endo and C2'-endo.

    δ torsion angles is used to defined the endo states described in:
    RNA backbone: Consensus all-angle conformers and modular string nomenclature (an RNA Ontology Consortium contribution), RNA 2008, doi: 10.1261/rna.657708

    C3'-endo:
        an individual ribose with δ between 55° and 110°
    C2'-endo:
        an individual ribose with δ between 120° and 175°
        
    Alternatively C3'- and C2'- endo can be defined using the pucker phase angle. C3'-endo [0°, 36°) as in canonical RNA and A-form DNA, and the C2'-endo [144°, 180°).

    Returns
    -------
    c3_endo : list
        '1' if if δ torsion angle forms a C3'-endo form, else '0'

    c2_endo : list
        '1' if if δ torsion angle forms a C2'-endo form, else '0'
    """

    c3_endo = []
    for _delta, _phase in zip(angle_d, angle_p):
        # C3 endo
        if (_delta >= 55 and _delta < 110) or (_phase >=0 and _phase < 36):
            c3_endo.append(1)
        else:
            c3_endo.append(0)

    c2_endo = []
    for _delta, _phase in zip(angle_d, angle_p):
        # C2 endo
        if (_delta >= 120 and _delta < 175) or (_phase >= 144 and _phase < 180):
            c2_endo.append(1)
        else:
            c2_endo.append(0)
        
    return c3_endo, c2_endo


def _intercalete(stacking_residue_index):    
    """
    Define intercaleted structures.
    
    RNA structures are intercalated if nucleotide `j` inserts between and stacks against nb `i` and `i+1`.
    
    Parameters
    ----------

    stacking_residue_index : list
        List of stacking residue index (e.g. [[0, 1], [1, 2]]).

    Returns
    -------

    name : str
        Category name
    """
    
    name = ""
    
    # 1-'0'-2-3
    # 2-'0'-1-3
    if [[0, 1], [0, 2]] == stacking_residue_index:
        name = "I0102"
            
    # 1-'2'-0-3
    # 0-'2'-1-3
    if [[0, 2], [1, 2]] == stacking_residue_index:
        name = "I0212"

    # 1-2-'0'-3
    if [[0, 2], [0, 3]] == stacking_residue_index:
        name = "I0203"

    # 0-'3'-1-2
    if [[0, 3], [1, 3]] == stacking_residue_index:
        name = "I0313"
            
    # 0-2-'1'-3
    # 0-3-'1'-2
    if [[1, 2], [1, 3]] == stacking_residue_index:
        name = "I1213"
        
    # 0-2-'3'-1
    # 0-1-'3'-2
    if [[1, 3], [2, 3]] == stacking_residue_index:
        name = "I1323"
        
    # 1-'2'-'0'-3
    if [[0, 2], [0, 3], [1, 2]] == stacking_residue_index:
        name = "I020312"        

    # 0-'2'-'1'-3
    if [[0, 2], [1, 2], [1, 3]] == stacking_residue_index:
        name = "I021213"

    # 0-'3'-'1'-2
    if [[0, 3], [1, 2], [1, 3]] == stacking_residue_index:
        name = "I031213"
    
    return name



def _tangled(stacking_residue_index):    
    """
    Define tangled structures. (>2 nb stacking with 5' and 3' stacking)            
    
    RNA structures are tangled if 5' and 3' nb are stacking and number of stacking nb are 3.
    
    Parameters
    ----------

    stacking_residue_index : list
        List of stacking residue index (e.g. [[0, 1], [1, 2]]).

    Returns
    -------

    name : str
        Category name
    """
    
    name = ""        

    # 1-2-'3'-0
    #if [[0, 3], [1, 2]] == stacking_residue_index:
    #    name = "T0312"

    #if [[0, 3], [2, 3]] == stacking_residue_index:
    #    name = "T0323"
    
    if [[0, 3], [1, 2], [2, 3]] == stacking_residue_index:
        name = "T031223"
    
    # 3-'0'-1-2
    #if [[0, 1], [0, 3]] == stacking_residue_index:
    #    name = "T0103"
    
    #if [[0, 1], [1, 2]] == stacking_residue_index:
    #    name = "T0103"
    
    if [[0, 1], [0, 3], [1, 2]] == stacking_residue_index:
        name = "T010312"
        
    return name



def annotate(replica_index, alpha, beta, gamma, delta, eps, zeta, chi, phase, stacking):
    """
    Annotate RNA tetramer structures frame by frame.

    Each frame is assigned to exactly one of six categories. The decision uses
    the frame's base-stacking pairs and patterns and, for fully stacked
    frames, the ribose pucker of every residue.

    AMa: A-form major     - pairs [0,1],[1,2],[2,3] all '>>' and all four
                            riboses C3'-endo
    AMi: A-form minor     - pairs [0,1],[1,2],[2,3] stacked, but not AMa
    I:   Intercalated     - two or more pairs matching a pattern known to
                            ``_intercalete``
    F1:  First nb flipped - exactly pairs [1,2],[2,3], both '>>'
    F4:  Last nb flipped  - exactly pairs [0,1],[1,2], both '>>'
    O:   Others           - zero or one stacking pair, or any unrecognized
                            pattern

    Parameters
    ----------
    replica_index : int
        Replica index number. Not used by the annotation itself.
    alpha, beta, gamma, eps, zeta, chi : np.ndarray
        Torsion angles of shape [n_iterations, n_residues]. Accepted for
        interface compatibility; not used by the annotation.
    delta : np.ndarray
        Backbone delta angles in degrees, shape [n_iterations, n_residues].
    phase : np.ndarray
        Sugar pucker phase angles in degrees, shape [n_iterations, n_residues].
    stacking : sequence
        Per-frame stacking information. ``stacking[i][0]`` holds the stacking
        residue-index pairs and ``stacking[i][1]`` the matching stacking
        patterns, e.g. [[[0, 1], [1, 2], [2, 3]], ['>>', '>>', '>>']].

    Returns
    -------
    myclass : list of str
        Category name of each frame.
    myclass_by_number : list of int
        Category of each frame encoded via ``myclass_mapping_dict``.
    obs_dict : defaultdict(list)
        For every category name, a per-frame list of 1 (frame belongs to that
        category) or 0.
    unknown_category : defaultdict(list)
        Frames with two or more stacking pairs that matched no category (and
        were therefore labeled "O"). Keys are ``str(pattern) + str(index)``;
        values are 1-based frame numbers.

    Raises
    ------
    AssertionError
        If a frame is matched by more than one category.
    """
    obs_dict = defaultdict(list)
    unknown_category = defaultdict(list)
    myclass, myclass_by_number = [], []

    for frame_idx, frame in enumerate(stacking):
        stacking_residue_index = frame[0]
        stacking_pattern = frame[1]
        n_pairs = len(stacking_pattern)
        names = []

        # A-form: all three consecutive pairs stacked.
        if stacking_residue_index == [[0, 1], [1, 2], [2, 3]]:
            c3_binary, _ = _check_endo(delta[frame_idx], phase[frame_idx])
            # AMa additionally requires every (four) ribose to be C3'-endo.
            if stacking_pattern == ['>>', '>>', '>>'] and sum(c3_binary) == 4:
                names.append("AMa")
            else:
                names.append("AMi")

        # Partial stacking: one terminal nucleotide is not stacked.
        if stacking_residue_index == [[1, 2], [2, 3]] and stacking_pattern == ['>>', '>>']:
            names.append("F1")
        if stacking_residue_index == [[0, 1], [1, 2]] and stacking_pattern == ['>>', '>>']:
            names.append("F4")

        # Intercalation.
        if n_pairs >= 2 and _intercalete(stacking_residue_index).startswith("I"):
            names.append("I")

        # Others: at most one stacking pair ...
        if n_pairs <= 1:
            names.append("O")
        # ... or a multi-pair pattern none of the rules above recognized.
        if not names:
            names.append("O")
            unknown_category[str(stacking_pattern) + str(stacking_residue_index)].append(frame_idx + 1)

        # The rules are meant to be mutually exclusive.
        assert len(names) == 1, "{}: multiple annotation {}\t{}\t{}".format(frame_idx+1, names, stacking_pattern, stacking_residue_index)
        label = names[0]

        for category in myclass_mapping_dict:
            obs_dict[category].append(1 if category == label else 0)

        myclass.append(label)
        myclass_by_number.append(myclass_mapping_dict[label])

    return myclass, myclass_by_number, obs_dict, unknown_category

In [73]:
def plot(myclass):
    """
    Draw a bar chart of the conformational population of each category.

    Parameters
    ----------
    myclass : list of str
        Category label of each frame, e.g. as returned by ``annotate``.

    Returns
    -------
    mydata : dict
        Percentage of frames in each category, keyed in the order of
        ``mycolor_dict`` ("AMa", "AMi", "I", "F1", "F4", "O"). Every value is
        0.0 when ``myclass`` is empty.
    """
    from collections import Counter

    counts = Counter(myclass)
    n_frames = len(myclass)

    # Labels outside the known categories are not plotted; report each once.
    for label in counts:
        if label not in mycolor_dict:
            print("undefined {}".format(label))

    # Percentage per category; guard against division by zero.
    mydata = {
        category: (100 * counts[category] / n_frames if n_frames else 0.0)
        for category in mycolor_dict
    }

    # Histogram of the conformational population.
    fig, ax = plt.subplots(figsize=(16, 8))
    ax.set_title(PLOT_TITLE)
    ax.set_ylabel('(%)')
    ax.yaxis.set_label_position("left")
    ax.yaxis.set_ticks_position("left")
    ax.yaxis.set_minor_locator(MultipleLocator(10))
    ax.set_ylim([0, 100])

    categories = list(mydata)
    values = list(mydata.values())
    ax.bar(categories, values, width=1.0, color=[mycolor_dict[c] for c in categories])

    # Percentage label just above each bar (bars sit at x = 0, 1, 2, ...).
    for x, v in enumerate(values):
        ax.text(x=x - 0.3, y=v + 3, s=f'{v:.1f}', size=24)

    # Lay out after the title exists so it is taken into account.
    fig.tight_layout()
    # fig.savefig("conformation_population.png")  # must be called before plt.show()
    plt.show()

    return mydata

In [74]:
def extract_data(replica_index, bb_angles, pucker_angles, stackings):
    """
    Slice one replica's pre-calculated properties out of the stacked arrays.

    Parameters
    ----------
    replica_index : int
        Replica index number.
    bb_angles : np.ndarray of shape [n_replicas, n_iterations, n_residues, n_angles]
        Backbone angles in degrees. Angle columns 0..6 are read as
        alpha, beta, gamma, delta, epsilon, zeta, chi.
    pucker_angles : np.ndarray of shape [n_replicas, n_iterations, n_residues, n_angles]
        Pucker angles in degrees. Column 0 is read as the pucker phase.
    stackings : np.ndarray of shape [n_replicas, n_iterations, stacking_info]
        Stacking residue indices and patterns per frame
        (e.g. [[[0, 1], [1, 2], [2, 3]], ['>>', '>>', '>>']]).

    Returns
    -------
    tuple
        (alpha, beta, gamma, delta, eps, zeta, chi, phase, stacking). The
        first eight are [n_iterations, n_residues] slices; ``stacking`` is
        the replica's [n_iterations, stacking_info] slice.
    """
    alpha, beta, gamma, delta, eps, zeta, chi = (
        bb_angles[replica_index, :, :, column] for column in range(7)
    )
    phase = pucker_angles[replica_index, :, :, 0]
    stacking = stackings[replica_index, :, :]
    return alpha, beta, gamma, delta, eps, zeta, chi, phase, stacking

### load data

In [75]:
# Trial number; selects ../{md_trial}/enhanced.nc and the matching analysis npz below.
md_trial = 1

In [76]:
# Replica-exchange storage file and pre-computed per-replica observables for this trial.
ncfile = f"../{md_trial}/enhanced.nc"
npzfile = f"../analysis/mydata{md_trial}_replica.npz"

In [77]:
# Open the multistate storage read-only and wrap it in an analyzer; the analyzer
# supplies the energies and equilibration helpers used further below.
# NOTE(review): the reporter is not closed in the cells shown; call
# reporter.close() when done if the file must be released.
reporter = mmtools.multistate.MultiStateReporter(ncfile, open_mode='r')
analyzer = mmtools.multistate.MultiStateSamplerAnalyzer(reporter)



### annotate structures

In [78]:
# allow_pickle=True is needed for object arrays (presumably the ragged per-frame
# "stackings"). Pickled data can execute code on load, so only use trusted files.
# NOTE: this rebinds `npzfile` from the path string to the loaded archive, so
# re-running this cell alone fails; re-run the path cell first.
npzfile = np.load(npzfile, allow_pickle=True)
for array_name in npzfile.files:
    print(array_name)

bb_angles
pucker_angles
rg
rmsd
ermsd
stackings
couplings


In [79]:
# Per-replica observables:
#   bb_angles, pucker_angles: [n_replicas, n_iterations, n_residues, n_angles]
#   stackings:                [n_replicas, n_iterations, stacking_info]
bb_angles, pucker_angles, stackings = (
    npzfile[key] for key in ("bb_angles", "pucker_angles", "stackings")
)

In [80]:
# Sanity check: (n_replicas, n_iterations, n_residues, n_angles)
print(np.shape(bb_angles))

(33, 30000, 4, 7)


In [83]:
print(">annotate structure")

from collections import Counter

myclasses_by_number = []
result_dict = defaultdict(list)
unknown_categories = {}  # replica index -> frames not assigned to any category
for replica_index in range(analyzer.n_replicas):
    alpha, beta, gamma, delta, eps, zeta, chi, phase, stacking = extract_data(replica_index, bb_angles, pucker_angles, stackings)

    # myclass: annotated category of every frame, by name.
    # myclass_by_number: the same categories encoded as integers (myclass_mapping_dict).
    # obs_dict: defaultdict keyed by category name; each value is a per-frame list
    #   of 1 (frame is in that category) or 0.
    # unknown_category: defaultdict of frames that were NOT assigned to any of the
    #   defined categories (they are labeled "O").
    myclass, myclass_by_number, obs_dict, unknown_category = annotate(replica_index, alpha, beta, gamma, delta, eps, zeta, chi, phase, stacking)

    # store per-replica results
    myclasses_by_number.append(myclass_by_number)
    result_dict[replica_index] = obs_dict
    unknown_categories[replica_index] = unknown_category

    # frame count of each category, in myclass_mapping_dict order
    counts = Counter(myclass)
    mydata = {category: counts[category] for category in myclass_mapping_dict}

    print("replica {}:\t{}".format(replica_index, mydata))

>annotate structure
replica 0:	{'AMa': 1321, 'AMi': 829, 'I': 412, 'F1': 964, 'F4': 951, 'O': 25523}
replica 1:	{'AMa': 2523, 'AMi': 62, 'I': 18, 'F1': 755, 'F4': 1112, 'O': 25530}
replica 2:	{'AMa': 6928, 'AMi': 262, 'I': 27, 'F1': 1519, 'F4': 3339, 'O': 17925}
replica 3:	{'AMa': 3656, 'AMi': 24, 'I': 80, 'F1': 1042, 'F4': 1528, 'O': 23670}
replica 4:	{'AMa': 44, 'AMi': 6854, 'I': 40, 'F1': 246, 'F4': 581, 'O': 22235}
replica 5:	{'AMa': 183, 'AMi': 385, 'I': 6699, 'F1': 624, 'F4': 3205, 'O': 18904}
replica 6:	{'AMa': 1758, 'AMi': 1, 'I': 5, 'F1': 2858, 'F4': 1596, 'O': 23782}
replica 7:	{'AMa': 40, 'AMi': 2011, 'I': 1027, 'F1': 25, 'F4': 3715, 'O': 23182}
replica 8:	{'AMa': 110, 'AMi': 340, 'I': 35, 'F1': 123, 'F4': 282, 'O': 29110}
replica 9:	{'AMa': 468, 'AMi': 5032, 'I': 62, 'F1': 247, 'F4': 7042, 'O': 17149}
replica 10:	{'AMa': 60, 'AMi': 69, 'I': 1907, 'F1': 250, 'F4': 89, 'O': 27625}
replica 11:	{'AMa': 11, 'AMi': 3, 'I': 2444, 'F1': 49, 'F4': 26, 'O': 27467}
replica 12:	{'AMa':

In [85]:
# Stack the per-replica label lists into an [n_replicas, n_iterations] array
# and make sure every replica is present.
myclasses_by_number = np.asarray(myclasses_by_number)
n_label_replicas = myclasses_by_number.shape[0]
assert n_label_replicas == analyzer.n_replicas

#### compute decorrelated energies

In [24]:
# Reference result from the analyzer (private openmmtools helper; may change
# between versions): MBAR inputs u_ln [n_states, n_samples] and N_l (samples per
# state), presumably after the same equilibration cut and subsampling that are
# reproduced manually in the next cells.
decorrelated_u_ln, decorrelated_N_l = analyzer._compute_mbar_decorrelated_energies()
print(decorrelated_u_ln.shape, decorrelated_N_l.shape, decorrelated_N_l.sum())

(33, 249744) (33,) 249744


##### Compute the decorrelated energies manually (cross-check of the analyzer)

In [26]:
# energy_data is [energy_sampled, energy_unsampled, neighborhood, replicas_state_indices]
energy_data = list(analyzer.read_energies())

# unpack the raw energy arrays (not yet equilibrated or subsampled)
sampled_energy_matrix, unsampled_energy_matrix, neighborhoods, replicas_state_indices = energy_data

# Note: This is different from pymbar.timeseries.detectEquilibration. analyzer._get_equilibration_data uses max_subset and excludes first iteration (minimization) to detect the equilibration data.
# Outputs used below: number_equilibrated (leading iterations to discard) and
# g_t (subsampling interval passed to subsample_data_along_axis). This is a
# private openmmtools method; its signature may change between versions.
number_equilibrated, g_t, Neff_max = analyzer._get_equilibration_data(sampled_energy_matrix, neighborhoods, replicas_state_indices)

In [29]:
# Keep only equilibrated data: drop the first `number_equilibrated` iterations,
# then subsample the rest with interval g_t to keep decorrelated samples.
energy_data = [
    mmtools.multistate.utils.subsample_data_along_axis(
        mmtools.multistate.utils.remove_unequilibrated_data(energies, number_equilibrated, -1),
        g_t,
        -1,
    )
    for energies in energy_data
]

In [31]:
# Unpack the equilibrated, decorrelated energy data and report each shape.
sampled_energy_matrix, unsampled_energy_matrix, neighborhood, replicas_state_indices = energy_data
for _array in (sampled_energy_matrix, unsampled_energy_matrix, neighborhood, replicas_state_indices):
    print(_array.shape)

(33, 33, 7568)
(33, 0, 7568)
(33, 33, 7568)
(33, 7568)


In [32]:
# Allocate the MBAR inputs in [state, sample] ("ln") form: every sample from
# every replica is evaluated at every sampled and unsampled state.
n_replicas, n_sampled_states, n_iterations = sampled_energy_matrix.shape
n_unsampled_states = unsampled_energy_matrix.shape[1]
n_total_states = n_sampled_states + n_unsampled_states
energy_matrix = np.zeros((n_total_states, n_replicas * n_iterations))
samples_per_state = np.zeros(n_total_states, dtype=int)

In [35]:
print(energy_matrix.shape, samples_per_state.shape, sep="\n")

(33, 249744)
(33,)


In [36]:
# Unsampled states (if any) are assumed to be split evenly between both ends of
# the state ladder, so the sampled states occupy rows [first, last).
first_sampled_state = n_unsampled_states // 2
last_sampled_state = n_total_states - first_sampled_state
print(first_sampled_state, last_sampled_state)

0 33


In [37]:
# IPython help: signature and docstring of the helper reproduced by hand below.
analyzer.reformat_energies_for_mbar?

[0;31mSignature:[0m
[0manalyzer[0m[0;34m.[0m[0mreformat_energies_for_mbar[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mu_kln[0m[0;34m:[0m [0mnumpy[0m[0;34m.[0m[0mndarray[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mn_k[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mnumpy[0m[0;34m.[0m[0mndarray[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Convert [replica, state, iteration] data into [state, total_iteration] data

This method assumes that the first dimension are all samplers,
the second dimension are all the thermodynamic states energies were evaluated at
and an equal number of samples were drawn from each k'th sampler, UNLESS n_k is specified.

Parameters
----------
u_kln : np.ndarray of shape (K,L,N')
    K = number of replica samplers
    L = number of thermodynamic states,
    N' = number of iterations from state k
n_k : np.ndarray of shape K or None
    Number of sam

In [54]:
print(np.shape(sampled_energy_matrix))  # (n_replicas, n_states, n_iterations)

(33, 33, 7568)


In [55]:
# Reproduce the setup of reformat_energies_for_mbar by hand for inspection.
u_kln = sampled_energy_matrix
n_k = None

k, l, n = u_kln.shape
# Default: every replica contributes all n iterations.
n_k = np.ones(k, dtype=np.int32) * n if n_k is None else n_k
u_ln = np.zeros((l, n_k.sum()))

print(n_k.shape, u_ln.shape)

(33,) (33, 249744)


In [56]:
# Copy each replica's [state, iteration] block into consecutive columns of u_ln
# (replica-major order), printing the column offsets and block shapes.
n_counter = 0
for replica_k in range(k):
    n_samples = n_k[replica_k]
    block = u_kln[replica_k, :, :n_samples]
    print(replica_k, n_counter, n_counter + n_samples, n_samples)
    print(block.shape)
    u_ln[:, n_counter:n_counter + n_samples] = block
    n_counter += n_samples

0 0 7568 7568
(33, 7568)
1 7568 15136 7568
(33, 7568)
2 15136 22704 7568
(33, 7568)
3 22704 30272 7568
(33, 7568)
4 30272 37840 7568
(33, 7568)
5 37840 45408 7568
(33, 7568)
6 45408 52976 7568
(33, 7568)
7 52976 60544 7568
(33, 7568)
8 60544 68112 7568
(33, 7568)
9 68112 75680 7568
(33, 7568)
10 75680 83248 7568
(33, 7568)
11 83248 90816 7568
(33, 7568)
12 90816 98384 7568
(33, 7568)
13 98384 105952 7568
(33, 7568)
14 105952 113520 7568
(33, 7568)
15 113520 121088 7568
(33, 7568)
16 121088 128656 7568
(33, 7568)
17 128656 136224 7568
(33, 7568)
18 136224 143792 7568
(33, 7568)
19 143792 151360 7568
(33, 7568)
20 151360 158928 7568
(33, 7568)
21 158928 166496 7568
(33, 7568)
22 166496 174064 7568
(33, 7568)
23 174064 181632 7568
(33, 7568)
24 181632 189200 7568
(33, 7568)
25 189200 196768 7568
(33, 7568)
26 196768 204336 7568
(33, 7568)
27 204336 211904 7568
(33, 7568)
28 211904 219472 7568
(33, 7568)
29 219472 227040 7568
(33, 7568)
30 227040 234608 7568
(33, 7568)
31 234608 242176 756

In [57]:
# Cast the sampled energy matrix from kln' to ln form.
# ([replica, state, iteration] -> [state, replica*iteration], replica-major columns.)
energy_matrix[first_sampled_state:last_sampled_state, :] = analyzer.reformat_energies_for_mbar(sampled_energy_matrix)

# Determine how many samples and which states they were drawn from.
unique_sampled_states, counts = np.unique(replicas_state_indices, return_counts=True)

# Assign those counts to the correct range of states.
# The basic slice is a view, so the fancy-indexed assignment writes into samples_per_state.
samples_per_state[first_sampled_state:last_sampled_state][unique_sampled_states] = counts

In [58]:
np.shape(energy_matrix)

(33, 249744)

In [59]:
np.shape(sampled_energy_matrix)

(33, 33, 7568)

In [60]:
sampled_energy_matrix[:, 0]

array([[-49117.99260111, -49348.46038518, -49126.48641033, ...,
        -45279.45492346, -45323.20070508, -45153.71124907],
       [-50067.63498656, -50174.47841353, -50255.02515686, ...,
        -48113.85875757, -48032.57379787, -47801.9676816 ],
       [-50835.22407309, -50569.40388435, -50724.38225869, ...,
        -45602.56783022, -45787.92246022, -45623.45969323],
       ...,
       [-50607.22365677, -50695.32194067, -50584.36462129, ...,
        -49176.10934155, -49219.56009528, -49395.51226428],
       [-47107.27415786, -47325.26837134, -47160.94386513, ...,
        -45507.79609604, -44916.27802646, -45343.49361899],
       [-48416.21382272, -48449.50548989, -47977.32197804, ...,
        -50244.12259948, -50212.61940506, -50548.97521055]])

In [61]:
energy_matrix[0]

array([-49117.99260111, -49348.46038518, -49126.48641033, ...,
       -50244.12259948, -50212.61940506, -50548.97521055])

In [62]:
# reformat_energies_for_mbar places replica k's iterations in columns
# [k*n, (k+1)*n) of every state row. Transposing [replica, state, iteration] to
# [state, replica, iteration] and merging the last two axes must therefore
# reproduce the sampled-state rows of energy_matrix exactly.
# (Comparing `.all()` of each side only ever tested True == True.)
assert np.array_equal(
    sampled_energy_matrix.transpose(1, 0, 2).reshape(sampled_energy_matrix.shape[1], -1),
    energy_matrix[first_sampled_state:last_sampled_state],
)

### Reformat the observable to match the decorrelated energy layout

In [66]:
# Replica indices with per-category observables in result_dict.
result_dict.keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32])

In [89]:
np.shape(myclasses_by_number)   # (n_replicas, n_iterations)

(33, 30000)

In [100]:
# Prepend a placeholder column (label 0, not a real category) so the label array
# has one column per stored iteration, like the energy arrays, which include
# iteration 0.
k, n = myclasses_by_number.shape
o_ln = np.hstack([np.zeros((k, 1)), myclasses_by_number])

In [101]:
np.shape(o_ln)

(33, 30001)

In [102]:
import copy
# Work on a copy so o_ln keeps the full trajectory.
_o_ln = copy.deepcopy(o_ln)
# Apply exactly the same equilibration cut and subsampling used for the
# energies, so every retained label lines up with one decorrelated sample.
# Discard equilibration iterations.
_o_ln = mmtools.multistate.utils.remove_unequilibrated_data(_o_ln, number_equilibrated, -1)
# Subsample along the decorrelation data.
decorrelated_o_ln = mmtools.multistate.utils.subsample_data_along_axis(_o_ln, g_t, -1)

# The placeholder label 0 (iteration 0) must have been cut away; otherwise a
# non-category label would enter the reweighting.
assert not np.any(decorrelated_o_ln == 0), "placeholder iteration 0 survived the equilibration cut"
# Labels must align one-to-one with the decorrelated [replica, iteration] energies.
assert decorrelated_o_ln.shape == replicas_state_indices.shape

In [110]:
np.shape(decorrelated_o_ln)

(33, 7568)

In [118]:
decorrelated_o_ln[1, :100]

array([1., 5., 5., 6., 6., 6., 5., 5., 6., 5., 5., 6., 1., 6., 1., 1., 6.,
       1., 6., 6., 1., 5., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.,
       6., 6., 6., 6., 6., 3., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.,
       6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.,
       6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 1., 6., 6., 6., 6.,
       6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.])

In [119]:
# Row-major flattening puts replica r's samples at [r*n, (r+1)*n); this slice
# must match decorrelated_o_ln[1][:100] shown above.
n_decorrelated_iterations = decorrelated_o_ln.shape[1]
decorrelated_o_ln.flatten()[n_decorrelated_iterations:n_decorrelated_iterations + 100]

array([1., 5., 5., 6., 6., 6., 5., 5., 6., 5., 5., 6., 1., 6., 1., 1., 6.,
       1., 6., 6., 1., 5., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.,
       6., 6., 6., 6., 6., 3., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.,
       6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.,
       6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 1., 6., 6., 6., 6.,
       6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.])

In [120]:
# One label per decorrelated sample, replica-major (same order as the u_ln columns).
obs = decorrelated_o_ln.reshape(-1).copy()

In [125]:
print(*(array.shape for array in (obs, decorrelated_u_ln, decorrelated_N_l)))

(249744,) (33, 249744) (33,)


### MBAR

In [122]:
from pymbar import MBAR

In [123]:
print(">analyze with mbar")

# obs was subsampled with the manually computed number_equilibrated/g_t (the same
# cut used to build energy_matrix and samples_per_state), while MBAR is fed the
# analyzer's own decorrelated energies. The reweighted populations are only
# meaningful if both pipelines kept the same samples in the same order.
assert decorrelated_u_ln.shape == energy_matrix.shape and decorrelated_u_ln.shape[1] == obs.size
assert np.array_equal(decorrelated_N_l, samples_per_state)
assert np.allclose(decorrelated_u_ln, energy_matrix)

mbar = MBAR(decorrelated_u_ln, decorrelated_N_l)

>analyze with mbar


In [128]:
np.shape(mbar.getWeights())   # (n_samples, n_states)

(249744, 33)

In [129]:
# Each column holds one state's normalized sample weights, so every column must
# sum to 1. getWeights() may rebuild the full (n_samples, n_states) matrix on
# every call, so fetch it once instead of twice per state.
W_nk = mbar.getWeights()
for index, column in enumerate(W_nk.T):
    print(index, column.sum())

0 1.0000000000003908
1 0.9999999999970262
2 1.0000000000007851
3 1.000000000001145
4 0.9999999999990967
5 1.000000000001369
6 0.9999999999999434
7 0.9999999999998374
8 0.999999999999301
9 1.0000000000000666
10 1.0000000000005924
11 0.9999999999996981
12 0.9999999999999212
13 1.0000000000000286
14 0.9999999999984726
15 1.000000000000838
16 0.9999999999994938
17 0.9999999999988886
18 0.9999999999996665
19 1.0000000000015374
20 0.9999999999992332
21 1.0000000000003404
22 1.0000000000003662
23 1.000000000000083
24 0.999999999999854
25 0.9999999999997664
26 1.0000000000007399
27 0.9999999999993936
28 1.0000000000015812
29 0.9999999999998865
30 1.0000000000002993
31 0.9999999999996755
32 1.0000000000007272


In [131]:
target_state = 0   # first thermodynamic state
weights = mbar.getWeights()[:, target_state]

In [133]:
# Category name -> integer label used in obs.
myclass_mapping_dict

{'AMa': 1, 'AMi': 2, 'I': 3, 'F1': 4, 'F4': 5, 'O': 6}

In [138]:
# Sample indices labeled AMa.
indexes = np.nonzero(obs == myclass_mapping_dict["AMa"])

In [140]:
np.sum(weights[indexes])

0.36581633071009356

In [144]:
# Population of each category in percent: the raw fraction of decorrelated
# samples vs. the MBAR-reweighted population in the target state (the weights
# of one state sum to 1).
print("CLASS\tCOUNT(%)\tREWEIGHT(%)")
print("----------------------")

for category, label in myclass_mapping_dict.items():
    in_category = obs == label
    raw_percent = 100 * np.count_nonzero(in_category) / obs.size
    reweighted_percent = 100 * weights[in_category].sum()
    print("{}:\t{:.2f}\t\t{:.2f}".format(category, raw_percent, reweighted_percent))

CLASS	COUNT	REWEIGHT
----------------------
AMa:	6.95		36.58
AMi:	3.04		9.02
I:	5.10		19.93
F1:	2.71		4.93
F4:	5.27		13.08
O:	76.93		16.46
