<a href="https://colab.research.google.com/github/fpesceKU/EnsembleLab/blob/main/EnsembleLab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Preliminary information:**




*   This Colab notebook enables running molecular dynamics (MD) simulations of intrinsically disordered proteins (IDPs) and studying their conformational ensembles in combination with experimental data.
*   MD simulations employ the coarse-grained force field CALVADOS2, in which each residue is mapped onto a single bead that is parametrized by a "stickiness" parameter and an electrostatic charge.
*   Simulations only require that the user provides the sequence of an IDP and sets the environmental conditions (temperature, ionic strength, pH).
*   If experimental data are available for the IDP, the Bayesian/Maximum-entropy reweighting can be used to refine ensembles to better match the experimental data <font color='#FA003F'>(currently only SAXS is implemented)</font>.
*   Multiple structural observables are calculated from simulations.
*   MD simulations run on GPU. To enable GPU select `Runtime` from the menu, then `Change runtime type` and select `GPU`.
* <font color='#FA003F'><b>Please note:</b> this notebook uses condacolab, whose installation causes a kernel restart. Because of this, executing all cells at once will crash. We recommend running the cells for preliminary operations one by one to prevent crashes.</font>
*   A more detailed tutorial is available at https://github.com/fpesceKU/EnsembleLab/blob/main/Tutorial.pdf
---



In [None]:
#@title <b><font color='#72A276'>0. IDP sequence and data</font></b>
from google.colab import files
import numpy as np
import os
import shutil
#@markdown Name the IDP that you want to simulate:
NAME = "Hst5" #@param {type:"string"}

#@markdown Insert here the sequence of the IDP that you want to simulate:
SEQUENCE = "DSHAKRHHGYKRKFHEKHHSHRGY" #@param {type:"string"}
# Drop any whitespace (spaces, tabs, newlines) pasted along with the sequence.
_sequence_clean = ''.join(SEQUENCE.split())
if _sequence_clean != SEQUENCE:
    SEQUENCE = _sequence_clean
    print('Blank character(s) found in the provided sequence. Sequence has been corrected, but check for integrity:')
    print(SEQUENCE)
    print('\n')
#@markdown Are experimental data available for this IDP? If so, a prompt will appear where you can provide the location of the data.
EXPERIMENT = "SAXS" #@param ["None", "SAXS", "Rh (to be implemented)"]

if EXPERIMENT == "SAXS":
    print('SAXS data must be in a file containing 3 columns, which are q, I and sigma. Commented lines (#) are allowed.')
    print('Please upload to session storage a properly formatted file containing SAXS data.')
    tmp = files.upload()
    if not tmp:
        raise ValueError('No SAXS file was uploaded.')
    saxs_file = list(tmp.keys())[0]

    # Parse the upload once. ndmin=2 keeps a single-row file two-dimensional
    # so that the column check below cannot fail with an IndexError.
    try:
        exp_saxs = np.loadtxt(saxs_file, ndmin=2)
    except Exception:
        print("Unable to read file. Make sure the file only contains 3 columns (q,I,sigma) and #commented lines")
        raise
    assert exp_saxs.shape[1] == 3, "Expected file with 3 columns (q,I,sigma)"

    # Heuristic: a last q-value below 1 is taken to mean q is given in 1/Å.
    if exp_saxs[...,0][-1] <  1:
        print('q is in Å units. Converting to nm.')
        exp_saxs[...,0] = exp_saxs[...,0]*10
        np.savetxt(saxs_file, exp_saxs)

    n_high_q = (exp_saxs[...,0] >= 5).sum()
    if n_high_q > 0:
        print('Found {} q-values above 5 nm^(-1). SAXS calculations are not reliable in that region of the spectrum. Those datapoints will be remove'.format(n_high_q))
        exp_saxs = exp_saxs[(exp_saxs[...,0] < 5)]
        np.savetxt(saxs_file, exp_saxs)

    shutil.move(saxs_file, 'saxs_input.dat')

#@markdown Simulation settings:
Temperature = 298 #@param {type:"number"}
Ionic_strength = 0.150 #@param {type:"number"}
pH = 7.0 #@param {type:"number"}
#@markdown <i>*Units: Temperature [K], Ionic_strength [M]<i>
np.savetxt('env_settings.txt', np.array([Temperature, Ionic_strength, pH]), header='temperature ionic_strength, ph')

# Need to store metadata prior to condacolab restarting the kernel
with open('seq.fasta','w') as f:
    f.write('>{:s}\n{:s}'.format(NAME,SEQUENCE))

# A directory named after the experiment type is the marker that the
# simulation cell checks (os.path.exists('SAXS')) once the kernel has restarted.
os.makedirs('{:s}'.format(EXPERIMENT), exist_ok=True)
if EXPERIMENT != "SAXS":
    # Remove a marker left behind by an earlier run with SAXS selected;
    # os.rmdir only succeeds on an empty directory, so no data can be lost.
    try:
        os.rmdir('SAXS')
    except OSError:
        pass

In [4]:
#@title <b>Preliminary operations</b>: acquiring softwares
%%bash
# Downloads the external tools and helper scripts into the working directory.
# Output of wget/tar/unzip is discarded (&> /dev/null), so a failed download
# is not reported by this cell.

# Remove the sample_data directory.
rm -r sample_data

# PULCHRA: download the archive, keep only the Linux binary as ./pulchra.
wget https://cssb.biology.gatech.edu/skolnick/files/PULCHRA/pulchra304.tgz &> /dev/null
tar -zxf pulchra304.tgz &> /dev/null
rm pulchra304.tgz
mv ./pulchra304/bin/linux/pulchra .
chmod +x pulchra
rm -r pulchra304/

# Pepsi-SAXS: download and unpack the Linux build.
wget https://files.inria.fr/NanoDFiles/Website/Software/Pepsi-SAXS/Linux/3.0/Pepsi-SAXS-Linux.zip &> /dev/null
unzip Pepsi-SAXS-Linux.zip &> /dev/null
rm Pepsi-SAXS-Linux.zip

# BLOCKING scripts (main.py is imported later as `from main import BlockAnalysis`).
wget https://raw.githubusercontent.com/fpesceKU/BLOCKING/main/block_tools.py &> /dev/null
wget https://raw.githubusercontent.com/fpesceKU/BLOCKING/main/main.py &> /dev/null

# BME scripts (BME.py is imported later as `import BME`).
wget https://raw.githubusercontent.com/KULL-Centre/BME/main/BME_tools.py &> /dev/null
wget https://raw.githubusercontent.com/KULL-Centre/BME/main/BME.py &> /dev/null

# BIFT: download the Fortran source and compile it to ./bift.
wget https://raw.githubusercontent.com/ehb54/GenApp-BayesApp/main/bin/source/bift.f &> /dev/null
gfortran bift.f -march=native -O2 -o bift

In [None]:
#@title <b>Preliminary operations</b>: setting the environment (i)
import subprocess
# Install condacolab through pip, then let it set up conda in this runtime.
subprocess.run(['pip', 'install', '-q', 'condacolab'])
import condacolab
condacolab.install()

In [None]:
#@title <b>Preliminary operations</b>: setting the environment (ii)
import subprocess
print('Installing libraries...')
# Conda packages (OpenMM pinned to 7.7) are installed with mamba from conda-forge ...
_ = subprocess.run( 'mamba install matplotlib mdtraj openmm=7.7 -c conda-forge --yes'.split() )
# ... and the pip-only packages afterwards (kneed pinned to 0.5.0).
subprocess.run( 'pip install wget kneed==0.5.0'.split() )
# The imports below must come after the installs above: several of these
# packages are only available once those commands have run.
import wget
import os
import shutil
import numpy as np
import pandas as pd
import scipy.stats as scs
from scipy.optimize import curve_fit
from sklearn.linear_model import LinearRegression
import mdtraj as md
from simtk import openmm, unit
from simtk.openmm import app
import matplotlib as mpl
import matplotlib.pyplot as plt
# BME.py is expected in the working directory (downloaded by the software-acquisition cell).
import BME as BME
from kneed import KneeLocator
from google.colab import files

# Render inline figures in the 'pdf' and 'svg' formats.
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'svg')

In [6]:
#@title <b><font color='#FA003F'>1.1 MD Toolbox</font></b>

# Fetch the CALVADOS residue table and index it by its 'one' column, so that
# rows can be looked up with the one-letter codes used in the sequence.
wget.download('https://raw.githubusercontent.com/KULL-Centre/CALVADOS/main/residues.csv')

residues = pd.read_csv('residues.csv').set_index('one')

def genParamsLJ(df,seq):
    """
    Build the short-range interaction parameters for a sequence.

    The input table and sequence are copied, not modified. Two extra bead
    types are appended to the table: 'X' (a copy of the first residue with
    MW + 2) and 'Z' (a copy of the last residue with MW + 16); the first and
    last entries of the returned sequence are renamed accordingly.

    Returns
    -------
    lj_eps : float, 0.2*4.184
    lj_sigma, lj_lambda : DataFrames with the pairwise arithmetic means of
        the `sigmas` and `lambdas` columns, labelled by bead type
    fasta : list, the sequence with terminal beads renamed to 'X' and 'Z'
    types : list of the unique bead types in `fasta`
    """
    table = df.copy()
    chain = seq.copy()

    # Terminal beads are derived from the residues they replace.
    table.loc['X'] = table.loc[chain[0]]
    table.loc['Z'] = table.loc[chain[-1]]
    table.loc['X','MW'] += 2
    table.loc['Z','MW'] += 16
    chain[0] = 'X'
    chain[-1] = 'Z'

    def _pairwise_mean(column):
        # Symmetric matrix of (v_i + v_j)/2 over all bead types.
        values = column.values
        return pd.DataFrame((values+values.reshape(-1,1))/2,
                            index=column.index,columns=column.index)

    lj_sigma = _pairwise_mean(table.sigmas)
    lj_lambda = _pairwise_mean(table.lambdas)
    return 0.2*4.184, lj_sigma, lj_lambda, chain, list(np.unique(chain))

def genParamsDH(df,seq,temp,pH,ionic):
    """
    Build the parameters of the screened electrostatic (Yukawa) term.

    The input table and sequence are copied, not modified. The charge of 'H'
    is set from the pH as 1/(1 + 10**(pH-6)); the first bead carries an extra
    +1 charge and the last bead an extra -1 charge.

    Returns
    -------
    yukawa_eps : list with one value per bead, charge * sqrt(lB*kT)
    yukawa_kappa : inverse Debye length computed from `ionic` and lB
    """
    kT = 8.3145*temp*1e-3
    chain = seq.copy()
    table = df.copy()

    # pH-dependent histidine charge; set before the terminal rows are
    # copied so that a terminal 'H' inherits it.
    table.loc['H','q'] = 1. / ( 1 + 10**(pH-6) )
    table.loc['X'] = table.loc[chain[0]]
    table.loc['Z'] = table.loc[chain[-1]]
    chain[0] = 'X'
    chain[-1] = 'Z'
    table.loc['X','q'] = table.loc[seq[0],'q'] + 1.
    table.loc['Z','q'] = table.loc[seq[-1],'q'] - 1.

    # Temperature-dependent dielectric constant used in the Yukawa prefactor.
    epsw = 5321/temp+233.76-0.9297*temp+0.1417*1e-2*temp*temp-0.8292*1e-6*temp**3
    lB = 1.6021766**2/(4*np.pi*8.854188*epsw)*6.022*1000/kT
    prefactor = np.sqrt(lB*kT)
    yukawa_eps = [table.loc[bead].q*prefactor for bead in chain]

    # Inverse of the Debye length.
    yukawa_kappa = np.sqrt(8*np.pi*lB*ionic*6.022/10)
    return yukawa_eps, yukawa_kappa

def genDCD(name, eqsteps=1000):
    """
    Post-process the raw trajectory '<name>/pretraj.dcd' (topology
    '<name>/top.pdb') and save it as '<name>/traj.dcd'.

    The chain is made whole and re-imaged, the coordinates are centred and
    then shifted by half the first box length, and the first `eqsteps`
    frames are dropped before saving.
    """
    raw = md.load("{:s}/pretraj.dcd".format(name), top="{:s}/top.pdb".format(name))
    anchor = [set(raw.top.chain(0).atoms)]
    imaged = raw.image_molecules(inplace=False, anchor_molecules=anchor, make_whole=True)
    imaged.center_coordinates()
    imaged.xyz += imaged.unitcell_lengths[0,0]/2
    # Frames before index `eqsteps` are treated as equilibration and discarded.
    production = imaged[int(eqsteps):]
    production.save_dcd("{:s}/traj.dcd".format(name))

def simulate(residues,name,seq,temp,pH,ionic,nsteps,stride=1e3,eqsteps=1000):
    """
    Run a Langevin-dynamics simulation of a single chain with OpenMM (CUDA
    platform) and write all output into the directory `name`.

    Parameters
    ----------
    residues : DataFrame of per-residue parameters (MW, sigmas, lambdas, q)
        indexed by one-letter code
    name : output directory; created if missing
    seq : list of one-letter residue codes
    temp : temperature in K
    pH, ionic : pH and ionic strength, forwarded to genParamsDH
    nsteps : number of integration steps (10 fs each); cast to int
    stride : interval, in steps, between trajectory frames and log lines
    eqsteps : number of leading frames dropped by genDCD

    Files written in `name`: top.pdb, pretraj.dcd, traj.log, restart.chk and,
    through genDCD, traj.dcd. If '<name>/restart.chk' already exists the run
    resumes from it and appends to pretraj.dcd.
    """
    # exist_ok=True keeps the resume branch below reachable: a plain
    # os.mkdir() raised whenever the directory (and so a checkpoint) existed.
    os.makedirs(name, exist_ok=True)

    lj_eps, _, _, fasta, _ = genParamsLJ(residues,seq)
    yukawa_eps, yukawa_kappa = genParamsDH(residues,seq,temp,pH,ionic)

    N = len(fasta)

    L = 200  # edge of the cubic box, in nm

    system = openmm.System()

    # set box vectors
    box_a = unit.Quantity(np.zeros([3]), unit.nanometers)
    box_a[0] = L * unit.nanometers
    box_b = unit.Quantity(np.zeros([3]), unit.nanometers)
    box_b[1] = L * unit.nanometers
    box_c = unit.Quantity(np.zeros([3]), unit.nanometers)
    box_c[2] = L * unit.nanometers
    system.setDefaultPeriodicBoxVectors(box_a, box_b, box_c)

    # One-bead-per-residue topology with bonds between consecutive beads.
    top = md.Topology()
    chain = top.add_chain()
    for resname in fasta:
        residue = top.add_residue(resname, chain)
        top.add_atom(resname, element=md.element.carbon, residue=residue)
    for i in range(chain.n_atoms-1):
        top.add_bond(chain.atom(i),chain.atom(i+1))

    # Starting structure: a straight line along z with 0.38 nm spacing,
    # centred on the middle of the box in z.
    pos = [[0,0,L/2+(i-N/2.)*.38] for i in range(N)]
    md.Trajectory(np.array(pos).reshape(N,3), top, 0, [L,L,L], [90,90,90]).save_pdb('{:s}/top.pdb'.format(name))

    pdb = app.pdbfile.PDBFile('{:s}/top.pdb'.format(name))

    # Bead masses; the terminal beads are heavier by 2 and 16 amu.
    system.addParticle((residues.loc[seq[0]].MW+2)*unit.amu)
    for res in seq[1:-1]:
        system.addParticle(residues.loc[res].MW*unit.amu)
    system.addParticle((residues.loc[seq[-1]].MW+16)*unit.amu)

    hb = openmm.openmm.HarmonicBondForce()
    energy_expression = 'select(step(r-2^(1/6)*s),4*eps*l*((s/r)^12-(s/r)^6),4*eps*((s/r)^12-(s/r)^6)+eps*(1-l))'
    ah = openmm.openmm.CustomNonbondedForce(energy_expression+'; s=0.5*(s1+s2); l=0.5*(l1+l2)')
    yu = openmm.openmm.CustomNonbondedForce('q*(exp(-kappa*r)/r - exp(-kappa*4)/4); q=q1*q2')
    yu.addGlobalParameter('kappa',yukawa_kappa/unit.nanometer)
    yu.addPerParticleParameter('q')

    ah.addGlobalParameter('eps',lj_eps*unit.kilojoules_per_mole)
    ah.addPerParticleParameter('s')
    ah.addPerParticleParameter('l')

    for res,e in zip(seq,yukawa_eps):
        yu.addParticle([e*unit.nanometer*unit.kilojoules_per_mole])
        ah.addParticle([residues.loc[res].sigmas*unit.nanometer, residues.loc[res].lambdas*unit.dimensionless])

    # Bonded neighbours are excluded from both non-bonded terms.
    for i in range(N-1):
        hb.addBond(i, i+1, 0.38*unit.nanometer, 8033*unit.kilojoules_per_mole/(unit.nanometer**2))
        yu.addExclusion(i, i+1)
        ah.addExclusion(i, i+1)

    yu.setForceGroup(0)
    ah.setForceGroup(1)
    yu.setNonbondedMethod(openmm.openmm.CustomNonbondedForce.CutoffPeriodic)
    ah.setNonbondedMethod(openmm.openmm.CustomNonbondedForce.CutoffPeriodic)
    hb.setUsesPeriodicBoundaryConditions(True)
    yu.setCutoffDistance(4*unit.nanometer)
    ah.setCutoffDistance(2*unit.nanometer)

    system.addForce(hb)
    system.addForce(yu)
    system.addForce(ah)

    integrator = openmm.openmm.LangevinIntegrator(temp*unit.kelvin,0.01/unit.picosecond,0.010*unit.picosecond) #10 fs timestep

    platform = openmm.Platform.getPlatformByName('CUDA')

    simulation = app.simulation.Simulation(pdb.topology, system, integrator, platform, dict(CudaPrecision='mixed'))

    check_point = '{:s}/restart.chk'.format(name)

    if os.path.isfile(check_point):
        print('Reading check point file')
        simulation.loadCheckpoint(check_point)
        simulation.reporters.append(app.dcdreporter.DCDReporter('{:s}/pretraj.dcd'.format(name),int(stride),append=True))
    else:
        simulation.context.setPositions(pdb.positions)
        simulation.minimizeEnergy()
        simulation.reporters.append(app.dcdreporter.DCDReporter('{:s}/pretraj.dcd'.format(name),int(stride)))

    simulation.reporters.append(app.statedatareporter.StateDataReporter('{:s}/traj.log'.format(name),int(stride),
             potentialEnergy=True,temperature=True,step=True,speed=True,elapsedTime=True,separator='\t'))

    # The step count must be an integer; callers may pass a float when the
    # simulation length is derived from a time in ns.
    simulation.step(int(nsteps))

    simulation.saveCheckpoint(check_point)

    genDCD(name,eqsteps)

In [7]:
#@title <b><font color='#FA003F'>1.2 Analysis Toolbox</font></b>
from main import BlockAnalysis

def autoblock(cv, multi=1, plot=False):
    """
    Run BlockAnalysis on the time series `cv` and return its `av`, `sem` and
    `bs` attributes (used by the callers as average, error and block size).

    With plot=True the `stat` array is drawn as an error-bar curve and the
    (bs, sem) point is highlighted in red.
    """
    analysis = BlockAnalysis(cv, multi=multi)
    analysis.SEM()

    if plot == True:
        curve = analysis.stat
        plt.errorbar(curve[...,0], curve[...,1], curve[...,2], fmt='', color='k', ecolor='0.5')
        plt.scatter(analysis.bs, analysis.sem,zorder=10,c='tab:red')
        plt.xlabel('Block size')
        plt.ylabel('SEM')
        plt.show()

    result = (analysis.av, analysis.sem, analysis.bs)
    return result

def Rg(t, seq):
    """
    Per-frame radius of gyration of trajectory `t`, with each bead weighted
    by the 'MW' entry of the module-level `residues` table for `seq`.
    """
    bead_masses = residues.loc[list(seq),'MW'].values
    return md.compute_rg(t,masses=bead_masses)

def invRh(t):
    """
    Per-frame value of (1 - 1/n) * mean(1/r_ij), where n is the number of
    residues and the mean runs over all atom pairs of the topology.
    """
    n_residues = len(list(t.topology.residues))
    all_pairs = t.top.select_pairs('all','all')
    inverse_distances = 1/md.compute_distances(t,all_pairs)
    return (1-1/n_residues)*inverse_distances.mean(axis=1)
    
def fix_topology(t,seq):
    """
    Rebuild trajectory `t` on a new single-chain topology that has one
    residue per entry of `seq`, each holding a single carbon atom named 'CA',
    and superpose all frames on frame 0.
    """
    ca_top = md.Topology()
    ca_chain = ca_top.add_chain()
    for resname in seq:
        ca_res = ca_top.add_residue(resname, ca_chain)
        ca_top.add_atom('CA', element=md.element.carbon, residue=ca_res)

    rebuilt = md.Trajectory(t.xyz, ca_top, t.time, t.unitcell_lengths, t.unitcell_angles)
    return rebuilt.superpose(rebuilt, frame=0)

def Dee(t):
    """Per-frame distance between the first and the last atom of `t`."""
    last_atom = len(list(t.top.atoms))-1
    termini = np.array([[ 0,  last_atom]])
    return md.compute_distances( t, atom_pairs=termini )[...,0]

def Rij(traj, r0_fix=0.68, w=None):
    """
    Average inter-bead distance as a function of sequence separation, and
    the exponent of a power-law fit to it.

    Parameters
    ----------
    traj : mdtraj trajectory
    r0_fix : fixed prefactor of the fitted law r0_fix * |i-j|**nu
    w : optional per-frame weights (uniform when None)

    Returns
    -------
    ij : separations 1 .. n_atoms-1
    dij : weighted mean distance for each separation
    nu : fitted exponent (only separations > 10 are fitted); NaN when the
        chain has no separation above 10
    dmax : largest pair distance found in any frame
    """
    pairs = traj.top.select_pairs('all','all')
    d = md.compute_distances(traj,pairs)
    if w is None:
        w = np.full(len(traj),1)
    dmean = np.average(d, weights=w, axis=0)
    dmax = np.max(d)

    ij = np.arange(1,traj.n_atoms)
    # Explicit array: comparing a Python list with a NumPy integer only
    # worked through the reflected == operator.
    separation = np.array([pair[1]-pair[0] for pair in pairs])
    dij = np.array([dmean[separation==s].mean() for s in ij], dtype=float)

    f = lambda x,v : r0_fix*np.power(x,v)
    fit_range = ij>10
    if fit_range.any():
        popt, pcov = curve_fit(f,ij[fit_range],dij[fit_range],p0=[.5])
        nu = popt[0]
    else:
        # curve_fit raises on empty data; report "no fit" instead of crashing.
        print('Sequence too short to fit the scaling exponent (no |i-j| > 10); nu set to NaN.')
        nu = np.nan
    return ij, dij, nu, dmax

def kde(a, w=None, min_=None, max_=None):
    """
    Weighted Gaussian kernel density estimate of the samples `a`.

    Parameters
    ----------
    a : 1-D array of samples
    w : optional per-sample weights (uniform when None)
    min_, max_ : bounds of the evaluation grid; default to min(a) and max(a)

    Returns
    -------
    x : 50 evenly spaced grid points between min_ and max_
    p : density evaluated on x, normalised so that p.sum() == 1
    u : weighted average of `a`
    """
    # `type(w) == 'NoneType'` compared a type with a string and was never
    # true; identity checks also stay scalar if an array is ever passed.
    if w is None:
        w = np.ones(len(a))
    if min_ is None:
        min_ = np.min(a)
    if max_ is None:
        max_ = np.max(a)
    x = np.linspace( min_, max_, num = 50 )
    d = scs.gaussian_kde( a, bw_method = "silverman", weights = w ).evaluate(x)
    u = np.average(a, weights = w)
    return x,d/np.sum(d),u

def plot_dist(ax,x,p,av):
    """
    Draw the distribution p(x) as a black line on `ax`, mark `av` with a
    black vertical line, and fit the axes to the data with 10% headroom
    above the maximum of `p`.
    """
    ceiling = np.max(p)+0.1*np.max(p)
    ax.plot(x,p,c='k')
    ax.vlines(av,0,ceiling, color='k')
    ax.set_xlim(np.min(x), np.max(x))
    ax.set_ylim(0,ceiling)

def plot_rew_dist(ax,x,p,av):
    """
    Overlay a distribution p(x) on `ax` as a dashed red line and mark `av`
    with a dashed red vertical line spanning 0 to 100.
    """
    red = 'tab:red'
    ax.plot(x,p,c=red,ls='dashed')
    ax.vlines(av,0,100, color=red, linestyle='dashed')

def backmapping(name, traj, dt):
    """
    Rebuild all-atom structures from a coarse-grained trajectory with PULCHRA.

    Every `dt`-th frame of `traj` is written to 'frame.pdb', passed to
    './pulchra', and the resulting 'frame.rebuilt.pdb' is appended to the
    all-atom trajectory.

    Parameters
    ----------
    name : output directory; receives 'traj_AA.dcd' and 'top_AA.pdb'
    traj : coarse-grained mdtraj trajectory to backmap
    dt : stride, in frames, between backmapped structures

    Returns
    -------
    The all-atom trajectory.

    Raises
    ------
    ValueError if `traj` has no frames.
    """
    traj_AA = None
    # The frames come from the `traj` argument; the previous code read the
    # notebook-level variable `t_cg` and silently ignored this parameter.
    for i in np.arange(0,len(traj),dt):
        traj[int(i)].save_pdb('frame.pdb')
        subprocess.run(['./pulchra', 'frame.pdb'])
        rebuilt = md.load_pdb('frame.rebuilt.pdb')
        traj_AA = rebuilt if traj_AA is None else traj_AA + rebuilt
    if traj_AA is None:
        raise ValueError('Cannot backmap an empty trajectory.')
    traj_AA.save_dcd('{}/traj_AA.dcd'.format(name))
    # The last rebuilt frame doubles as the all-atom topology file.
    shutil.move('frame.rebuilt.pdb', '{}/top_AA.pdb'.format(name))
    return traj_AA

In [8]:
#@title <b><font color='#FA003F'>1.3 BME Toolbox</font></b>
def iBME(calc_file,exp_file,THETAS=np.array([1,10,20,50,75,100,200,400,750,1000,5000,10000])):
    """
    Run iterative BME reweighting once for every value in `THETAS`.

    Each run is named 'ibme_t<theta>' and uses 25 iterations with
    ftol=0.001. The last entry of the weights and of the statistics reported
    by each run is kept.

    Returns
    -------
    THETAS, an array of the kept statistics, and an array of the kept
    weights (one row per theta).
    """
    final_weights = []
    final_stats = []

    for theta in THETAS:
        print('Reweighting with theta={}'.format(theta))
        run = BME.Reweight('ibme_t{}'.format(theta))
        run.load(exp_file,calc_file)
        run.ibme(theta=theta, iterations=25, ftol=0.001)

        final_weights.append(run.get_ibme_weights()[-1])
        last_stats = run.get_ibme_stats()[-1]
        final_stats.append(last_stats)
        print('chi2={:.2f}, phi_eff={:.2f}'.format(last_stats[1],last_stats[2]))

    return THETAS, np.array(final_stats), np.array(final_weights)

def theta_loc(thetas, stats):
    """
    Pick theta at the knee of the curve stats[..., 1] versus stats[..., 2]
    (plotted by the notebook as chi2 versus phi_eff).

    Parameters
    ----------
    thetas : sequence of theta values, one per row of `stats`
    stats : array holding the reweighting statistics for each theta

    Returns
    -------
    The theta whose stats[..., 2] value equals the located knee.

    Raises
    ------
    ValueError if no knee can be located.
    """
    phi_eff = stats[...,2]
    chi2 = stats[...,1]
    kneedle = KneeLocator(phi_eff, chi2, S=1, curve="convex", direction="increasing")
    if kneedle.knee is None:
        # Without this check the mask below is all False and indexing [0]
        # fails with an uninformative IndexError.
        raise ValueError('Could not locate a knee in the theta scan; choose theta manually.')
    choice = np.array(thetas)[phi_eff==kneedle.knee][0]
    return choice

In [None]:
#@title <b><font color='#ffc413'>2.1 Launch MD simulation</font></b>
# Getting back variables from user inputs (the kernel restart wiped them)
with open('seq.fasta', 'r') as fasta_file:
    f = fasta_file.readlines()
NAME = f[0][1:].strip()
SEQUENCE = f[1].strip()

# A directory called 'SAXS' is the marker for the selected experiment type.
if os.path.exists('SAXS'):
    EXPERIMENT = 'SAXS'
else:
    EXPERIMENT = None

T, IS, PH = np.loadtxt('env_settings.txt', unpack=True)
#@markdown Simulation time (ns):
Simulation_time = "AUTO" #@param {type:"raw"}

if Simulation_time == "AUTO":
    N_res = len(SEQUENCE)
    # Number of saved frames: fixed for short chains, growing with N_res**2 otherwise.
    N_save = 7000 if N_res < 150 else int(np.ceil(3e-4*N_res**2)*1000)
    nsteps = 1010*N_save
    print('AUTO simulation length selected. Running for {} ns'.format(nsteps*0.01/1000))
else:
    # ns -> number of 10 fs steps; the step count must be an integer.
    nsteps = int(round(float(Simulation_time)*1000/0.01))

# Start from a clean output directory (it may not exist yet).
shutil.rmtree(NAME, ignore_errors=True)
simulate(residues,NAME,list(SEQUENCE),temp=T,pH=PH,ionic=IS,nsteps=nsteps,stride=1e3,eqsteps=1000)

In [None]:
#@title <b><font color='#ffc413'>2.2 Calculate structural observables from simulation</font></b>
traj = md.load_dcd('{:s}/traj.dcd'.format(NAME), top='{:s}/top.pdb'.format(NAME))

# Radius of gyration
rg_array = Rg(traj, SEQUENCE)
rg_av, rg_err, rg_bs = autoblock(rg_array, plot=False)
x_rg, p_rg, _ = kde(rg_array)

# Hydrodynamic radius: averaged as 1/Rh, then inverted. The error is
# propagated through the inversion (relative error is preserved).
invrh = invRh(traj)
invrh_av, invrh_err, rh_bs = autoblock(invrh, plot=False)
rh_av = 1/invrh_av
rh_err = rh_av*invrh_err/invrh_av
x_invrh, p_rh, _ = kde(invrh)
x_rh = 1/x_invrh

# End-to-end distance
dee_array = Dee(traj)
dee_av, dee_err, dee_bs = autoblock(dee_array, plot=False)
x_dee, p_dee, _ = kde(dee_array)

# Distance scaling with sequence separation
ij, dij, nu, dmax = Rij(traj)

mpl.rcParams.update({'font.size': 10})
fig, axs = plt.subplots(2, 2, figsize=(8,6), facecolor='w', dpi=300, layout='constrained')
axs = axs.flatten()

plot_dist(axs[0], x_rg, p_rg, rg_av)
axs[0].set_xlabel(r'$R_g$ [nm]')
axs[0].set_ylabel(r'p($R_g$)')
axs[0].text(0.75, 0.9, r'$\langle R_g \rangle$={:.2f}$\pm{:.2f}$ nm'.format(rg_av,rg_err), horizontalalignment='center',verticalalignment='center', transform=axs[0].transAxes, fontsize=10)

plot_dist(axs[1], x_rh, p_rh, rh_av)
axs[1].set_xlabel(r'$R_h$ [nm]')
axs[1].set_ylabel(r'p($R_h$)')
axs[1].text(0.25, 0.9, r'$\langle R_h \rangle$={:.2f}$\pm{:.2f}$ nm'.format(rh_av,rh_err), horizontalalignment='center',verticalalignment='center', transform=axs[1].transAxes, fontsize=10)

plot_dist(axs[2], x_dee, p_dee, dee_av)
axs[2].set_xlabel(r'$D_{ee}$ [nm]')
axs[2].set_ylabel(r'p($D_{ee}$)')
# Raw string: '\p' in a plain string literal is an invalid escape sequence.
axs[2].text(0.75, 0.9, r'$\langle D_{ee} \rangle$'+r'={:.2f}$\pm{:.2f}$ nm'.format(dee_av,dee_err), horizontalalignment='center',verticalalignment='center', transform=axs[2].transAxes, fontsize=10)

axs[3].plot(ij,dij,c='k')
axs[3].set_xlabel('|i-j|')
axs[3].set_ylabel(r'$\langle R_{ij} \rangle$ [nm]')
axs[3].text(0.25, 0.9, r'$\nu$={:.2f}'.format(nu), horizontalalignment='right',verticalalignment='center', transform=axs[3].transAxes, fontsize=10)

plt.show()

In [None]:
#@title <b><font color='#ffc413'>2.3 Generate all-atom trajectory</font></b>
# Three-letter residue names, looked up in the 'three' column of the residue table.
SEQ3 = [residues.three[x] for x in SEQUENCE]
cg_traj_raw = md.load_dcd('{:s}/traj.dcd'.format(NAME), top='{:s}/top.pdb'.format(NAME))
t_cg = fix_topology(cg_traj_raw, SEQ3)

# Subsampling stride: the average of the three block sizes, truncated to an integer.
sub_f = int(np.average([rg_bs,rh_bs,dee_bs]))
print('Subsampling trajectory prior to backmapping to all-atom. Taken 1 simulation frame every {}'.format(sub_f))
print('Backmapping to all-atom resolution...')
traj_AA = backmapping(NAME, t_cg, sub_f)

# Sort the backmapped frames by Rg and save roughly ten of them, evenly
# spread over that ordering, as individual PDB files.
ndx = np.argsort(rg_array[0:len(t_cg):sub_f])
pick_every = np.ceil(len(ndx)/10).astype(int)
sel = ndx[::pick_every]

for i,n in enumerate(sel):
    frame_path = '{:s}/frame{}.pdb'.format(NAME,i+1)
    traj_AA[n].save_pdb(frame_path)

In [None]:
#@title <b><font color='#058ED9'>3.1 Ensemble reweighting with experimental data</b>: getting data ready</font>
if EXPERIMENT == 'SAXS':

    #correct experimental errors with BIFT
    with open('inputfile.dat','w') as f:
        f.write("saxs_input.dat\n\n\n\n{}\n\n\n\n\n\n\n\n\n\n\n\n".format(dmax))
    # No shell is involved here, so a '<' inside the argument list is passed
    # to the program as a literal argument instead of redirecting its input.
    # Hand the answers file over as stdin explicitly.
    with open('inputfile.dat', 'r') as bift_input:
        subprocess.run(['./bift'], stdin=bift_input)
    np.savetxt('{:s}_bift.dat'.format(NAME), np.loadtxt('rescale.dat'), header=' DATA=SAXS')
    print('Experimental errors on SAXS intensities have been corrected with BIFT')
    print('Factor used for rescaling errors is: {}'.format(np.loadtxt('scale_factor.dat')[0,1]))
    print('SAXS data with corrected errors is in {:s}_bift.dat\n'.format(NAME))

    #SAXS
    print('Calculating SAXS from all-atom trajectory...')
    t_aa = md.load_dcd('{}/traj_AA.dcd'.format(NAME), top='{}/top_AA.pdb'.format(NAME))
    # Argument list instead of a split string, so a NAME containing spaces
    # is still passed as one argument.
    pepsi_comm = ['./Pepsi-SAXS', 'frame.pdb', '{:s}_bift.dat'.format(NAME), '-o', 'saxs.dat',
                  '-cst', '--cstFactor', '0', '--I0', '1.0', '--dro', '1.0',
                  '--r0_min_factor', '1.025', '--r0_max_factor', '1.025', '--r0_N', '1']
    profiles = []
    for i,frame in enumerate(t_aa):
        frame.save_pdb('frame.pdb')
        subprocess.run(pepsi_comm)
        # Column index 3 of the Pepsi-SAXS output is the profile that is kept.
        profiles.append(np.loadtxt('saxs.dat')[...,3])
    # Stacking once keeps the result 2-D (frames x q) even for a single frame.
    calc_saxs = np.vstack(profiles)
    # First column: frame index.
    col0 = np.arange(0,len(calc_saxs)).reshape(len(calc_saxs),1)
    calc_saxs = np.hstack((col0,calc_saxs))
    np.savetxt('calc_saxs.dat', calc_saxs)

elif EXPERIMENT is None:
    print('No experiment specified. Select type of experiment in cell #0.')
else:
    print('An unknown option is specified.')

In [None]:
#@title <b><font color='#058ED9'>3.2 Execute reweighting</b></font>
# Scan theta with iBME when SAXS data are available; otherwise only report why nothing is done.
if EXPERIMENT is None:
    print('No experiment specified. Select type of experiment in cell #0.')
elif EXPERIMENT == 'SAXS':
    bift_file = '{:s}_bift.dat'.format(NAME)
    thetas, stats, weights = iBME('calc_saxs.dat', bift_file)
else:
    print('An unknown option is specified.')

In [None]:
#@title <b><font color='#058ED9'>3.3 Analyze reweighting</b></font>
#@markdown NB if using "INTERACTIVE": a prompt will appear below the theta scan plot for you to indicate the theta of choice, but this is a bit buggy and sometimes the prompt does not appear. In that case simply stop the execution of this cell and try re-running it.
THETA_LOCATOR = "AUTO" #@param ["AUTO", "INTERACTIVE"]
if EXPERIMENT is not None:
    if THETA_LOCATOR == "AUTO":
        choice = theta_loc(thetas, stats)
        mpl.rcParams.update({'font.size': 10})
        fig = plt.figure(layout='constrained',dpi=300,figsize=(8,3))
        plt.scatter(stats[...,2],stats[...,1], c='k')
        ndx = np.where(thetas==choice)[0][0]
        plt.scatter(stats[...,2][ndx],stats[...,1][ndx], c='tab:red',label=r'Chosen $\theta$')
        plt.xlabel(r'$\phi_{eff}$',fontsize=10)
        plt.ylabel(r'$\chi^2_r$',fontsize=10)
        plt.legend()
        plt.show()
        print('Using theta={} for reweighting.\n'.format(choice))
    elif THETA_LOCATOR == "INTERACTIVE":
        fig, ax = plt.subplots(dpi=300,figsize=(12,4),layout='tight')
        # Invisible scatter: only sets up the phi_eff axis; the visible points
        # are drawn on the twin axis, whose ticks are labelled with theta.
        ax.scatter(stats[...,2],stats[...,1], c='k',alpha=0)
        ax.set_xlabel(r'$\phi_{eff}$',fontsize=10)
        ax.set_ylabel(r'$\chi^2_r$',fontsize=10)

        ax2 = ax.twiny()
        ax2.scatter(stats[...,2],stats[...,1], c='k',zorder=100,s=5)
        ax2.set_xticks(ticks=stats[...,2], labels=thetas, rotation=90, fontsize=8)
        ax2.grid(ls='dashed')
        ax2.set_xlabel(r'$\theta$',fontsize=10)
        plt.ioff()
        plt.show()

        choice = input('Indicate your choice for theta:')
        choice = int(choice)
        if choice not in thetas:
            # Stop here: continuing would leave `ndx` undefined (or pointing
            # at the theta of a previous run) for the analysis below.
            raise ValueError('Data for selected theta not available. Available values: {}'.format(list(thetas)))
        ndx = np.where(thetas==choice)[0][0]

if EXPERIMENT == 'SAXS':
    q, I_exp, err = np.loadtxt('{}_bift.dat'.format(NAME), unpack=True)

    # Prior: plain average over frames (the first column is the frame index).
    I_prior = np.average(np.loadtxt('calc_saxs.dat')[...,1:],axis=0)

    # Posterior: weighted average of the '.calc.dat' file written by the
    # iBME run for the chosen theta.
    ibme_prefix = 'ibme_t{}_'.format(choice)
    ibme_calc_files = [x for x in os.listdir('.') if x.startswith(ibme_prefix) and x.endswith('.calc.dat')]
    ibme_calc = np.loadtxt(ibme_calc_files[0])
    I_post = np.average(ibme_calc[...,1:], axis=0, weights=weights[ndx])

    # Scale and offset of the prior from a weighted linear fit to the experiment.
    wlr = 1/(err**2)
    model = LinearRegression()
    model.fit(I_prior.reshape(-1,1),I_exp,wlr)
    a = model.coef_[0]
    b = model.intercept_
    I_prior = a*I_prior+b

    mpl.rcParams.update({'font.size': 10})
    fig, axs = plt.subplots(2, 2, figsize=(8,6), facecolor='w', dpi=300, layout='constrained')
    axs = axs.flatten()

    # Panels 0 and 1 show the same curves; panel 1 uses log-log axes.
    for panel in (axs[0], axs[1]):
        panel.errorbar(q,I_exp,err, lw=1,c='0.5',alpha=0.5)
        panel.plot(q,I_prior, lw=1, zorder=500)
        panel.plot(q,I_post, lw=1, color='tab:red', ls='dashed', zorder=1000)
        panel.set_xlabel(r'q [nm$^{-1}$]')
        panel.set_ylabel('Intensity')
    axs[1].set_xscale('log')
    axs[1].set_yscale('log')

    # Kratky representation
    q_squared = q**2
    kratky_exp = q_squared*I_exp
    kratky_err = q_squared*err
    axs[2].errorbar(q,kratky_exp,kratky_err, lw=1,c='0.5',alpha=0.5)
    axs[2].plot(q,q_squared*I_prior, lw=1, zorder=500)
    axs[2].plot(q,q_squared*I_post, lw=1, color='tab:red', ls='dashed', zorder=1000)
    axs[2].set_xlabel(r'q [nm$^{-1}$]')
    axs[2].set_ylabel(r'q$^2$I')

    # Error-normalised residuals
    axs[3].plot(q, (I_exp-I_prior)/err, color='tab:blue', lw=1 )
    axs[3].plot(q, (I_exp-I_post)/err, color='tab:red', lw=1, alpha=0.6 )
    axs[3].set_xlabel(r'q [nm$^{-1}$]')
    axs[3].set_ylabel(r'(I$^{EXP}$-I$^{CALC}$)/$\sigma$')

if EXPERIMENT is not None:
    # The weights refer to the subsampled frames that were backmapped, so the
    # observables are subsampled with the same stride before reweighting.
    sub_idx = np.arange(0,len(t_cg),sub_f)
    x_rg_rew, p_rg_rew, rg_av_rew = kde(rg_array[sub_idx], w=weights[ndx], min_=np.min(rg_array), max_=np.max(rg_array))
    tmp = kde(invrh[sub_idx], w=weights[ndx])
    x_rh_rew, p_rh_rew, rh_av_rew = 1/tmp[0], tmp[1], 1/tmp[2]
    x_dee_rew, p_dee_rew, dee_av_rew = kde(dee_array[sub_idx], w=weights[ndx], min_=np.min(dee_array), max_=np.max(dee_array))
    ij, dij_rew, nu_rew, _ = Rij( traj[sub_idx], w=weights[ndx] )

    mpl.rcParams.update({'font.size': 10})
    fig, axs = plt.subplots(2, 2, figsize=(8,6), facecolor='w', dpi=300, layout='constrained')
    axs = axs.flatten()

    plot_dist(axs[0], x_rg, p_rg, rg_av)
    plot_rew_dist(axs[0], x_rg_rew, p_rg_rew, rg_av_rew)
    axs[0].set_xlabel(r'$R_g$ [nm]')
    axs[0].set_ylabel(r'p($R_g$)')
    axs[0].text(0.75, 0.9, r'$\langle R_g \rangle$={:.2f} nm'.format(rg_av_rew), color='tab:red', horizontalalignment='center',verticalalignment='center', transform=axs[0].transAxes, fontsize=10)

    plot_dist(axs[1], x_rh, p_rh, rh_av)
    plot_rew_dist(axs[1], x_rh_rew, p_rh_rew, rh_av_rew)
    axs[1].set_xlabel(r'$R_h$ [nm]')
    axs[1].set_ylabel(r'p($R_h$)')
    axs[1].text(0.2, 0.9, r'$\langle R_h \rangle$={:.2f} nm'.format(rh_av_rew), color='tab:red', horizontalalignment='center',verticalalignment='center', transform=axs[1].transAxes, fontsize=10)

    plot_dist(axs[2], x_dee, p_dee, dee_av)
    plot_rew_dist(axs[2], x_dee_rew, p_dee_rew, dee_av_rew)
    axs[2].set_xlabel(r'$D_{ee}$ [nm]')
    axs[2].set_ylabel(r'p($D_{ee}$)')
    axs[2].text(0.75, 0.9, r'$\langle D_{ee} \rangle$'+'={:.2f} nm'.format(dee_av_rew), color='tab:red', horizontalalignment='center',verticalalignment='center', transform=axs[2].transAxes, fontsize=10)

    axs[3].plot(ij,dij,c='k')
    # Only x and y are passed: a third positional value is interpreted by
    # plot() as an additional data series, not as a parameter of this curve.
    axs[3].plot(ij,dij_rew,c='tab:red',ls='dashed')
    axs[3].set_xlabel('|i-j|')
    axs[3].set_ylabel(r'$\langle R_{ij} \rangle$ [nm]')
    axs[3].text(0.2, 0.9, r'$\nu$={:.2f}'.format(nu_rew), color='tab:red', horizontalalignment='right',verticalalignment='center', transform=axs[3].transAxes, fontsize=10)

    plt.show()

In [None]:
#@title <b><font color='#72A276'>4. Download results</b></font>
# Drop the raw trajectory and the checkpoint before packaging. Each file is
# handled on its own so a missing one does not prevent removing the other.
for leftover in ('pretraj.dcd', 'restart.chk'):
    try:
        os.remove('{}/{}'.format(NAME, leftover))
    except FileNotFoundError:
        pass

# Keep the run metadata next to the simulation output. The files are already
# gone from the working directory when this cell is re-run.
for meta_file in ('seq.fasta', 'env_settings.txt'):
    if os.path.isfile(meta_file):
        shutil.move(meta_file, '{}/{}'.format(NAME, meta_file))

# One observable per line.
with open('{}/observables.dat'.format(NAME),'w') as f:
    f.write('Rg {:.2f} +/- {:.2f} nm\n'.format(rg_av,rg_err))
    f.write('Rh {:.2f} +/- {:.2f} nm\n'.format(rh_av,rh_err))
    f.write('Dee {:.2f} +/- {:.2f} nm\n'.format(dee_av,dee_err))
    f.write('nu {:.2f}\n'.format(nu))
    f.write('Dmax {:.2f}\n'.format(dmax))

np.savetxt('{:s}/hist_rg.dat'.format(NAME), np.vstack((x_rg,p_rg)).T, header='Rg p(Rg)')
np.savetxt('{:s}/hist_rh.dat'.format(NAME), np.vstack((x_rh,p_rh)).T, header='Rh p(Rh)')
np.savetxt('{:s}/hist_dee.dat'.format(NAME), np.vstack((x_dee,p_dee)).T, header='Dee p(Dee)')
np.savetxt('{:s}/Rij.dat'.format(NAME), np.vstack((ij,dij)).T, header='ij Rij')

# dirs_exist_ok lets a re-run refresh the copy; previously the failure was
# swallowed and the archive kept the files of the first run.
os.makedirs('{:s}_EnsembleLab'.format(NAME), exist_ok=True)
shutil.copytree('{:s}'.format(NAME), '{:s}_EnsembleLab/SIMULATION'.format(NAME), dirs_exist_ok=True)

# wget.download returns the name it actually saved to, which differs from the
# requested one when a file with that name already exists.
readme_path = wget.download('https://raw.githubusercontent.com/fpesceKU/EnsembleLab/main/utils/readme_download_simulation.txt')
with open(readme_path, 'r') as readme_part:
    readme_text = readme_part.read() + '\n'

if EXPERIMENT == 'SAXS':
    saxs_dir = '{:s}_EnsembleLab/SAXS'.format(NAME)
    os.makedirs(saxs_dir, exist_ok=True)
    # Written into the archive folder: the simulation directory has already
    # been copied above, so a file added there would not be downloaded.
    with open('{}/reweighted_observables.dat'.format(saxs_dir),'w') as f:
        f.write('Rg {:.2f} nm\n'.format(rg_av_rew))
        f.write('Rh {:.2f} nm\n'.format(rh_av_rew))
        f.write('Dee {:.2f} nm\n'.format(dee_av_rew))
        f.write('nu {:.2f}\n'.format(nu_rew))
        f.write('Dmax {:.2f}\n'.format(dmax))
    np.savetxt('{:s}_EnsembleLab/SAXS/hist_rg.dat'.format(NAME), np.vstack((x_rg_rew,p_rg_rew)).T, header='Rg p(Rg)')
    np.savetxt('{:s}_EnsembleLab/SAXS/hist_rh.dat'.format(NAME), np.vstack((x_rh_rew,p_rh_rew)).T, header='Rh p(Rh)')
    np.savetxt('{:s}_EnsembleLab/SAXS/hist_dee.dat'.format(NAME), np.vstack((x_dee_rew,p_dee_rew)).T, header='Dee p(Dee)')
    np.savetxt('{:s}_EnsembleLab/SAXS/saxs.dat'.format(NAME), np.vstack((q,I_exp,err,I_prior,I_post)).T, header='q I(exp) err(bift) I(prior) I(post)')
    np.savetxt('{:s}_EnsembleLab/SAXS/Rij.dat'.format(NAME), np.vstack((ij,dij_rew)).T, header='ij Rij')
    saxs_readme_path = wget.download('https://raw.githubusercontent.com/fpesceKU/EnsembleLab/main/utils/readme_download_saxs.txt')
    with open(saxs_readme_path, 'r') as readme_part:
        readme_text += readme_part.read()

with open('{:s}_EnsembleLab/README'.format(NAME), 'w') as fout:
    fout.write(readme_text)

# Argument list instead of a split string, so a NAME containing spaces works.
zipper = ['zip', '-r', '{:s}_EnsembleLab.zip'.format(NAME), '{:s}_EnsembleLab'.format(NAME)]

subprocess.run(zipper)
files.download('{:s}_EnsembleLab.zip'.format(NAME))

TODO list:
- progress bar for simulation and reweighting
- Adaptive significant digits
- Also show and save fit of Rij
- check auto simulation length