<a href="https://colab.research.google.com/github/kimjc95/computational-chemistry/blob/main/MD_in_Colab(ENG).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MD in Colab

2024-08-21 by Joo-Chan Kim at [MSBL](https://msbl.kaist.ac.kr), KAIST

YouTube Tutorial :

* [protein-ligand MD](https://www.youtube.com/watch?v=3ZbKCC_DS-A)

-------------------------------------

This is a Google Colaboratory Notebook for the all-atom molecular dynamics simulation of proteins (and nucleic acids) with ligands / non-standard residues. (BSD-3 license)

Please cite DOI:[10.5281/zenodo.13133762](https://doi.org/10.5281/zenodo.13685874) if you have used my code in your research.

If you have any problems, please raise an issue in [GitHub](https://github.com/kimjc95/computational-chemistry/issues) or email me (kimjoochan@kaist.ac.kr).

-------------------------------------

**Things to Prepare :**

1. Your system's PDB file or PDB ID

2. (optional) SDF files for your custom residues

3. A Google account (for saving jobs to Google Drive) and an Internet connection (plus additional Google Colab compute resources, if needed)

-------------------------------------

**with this notebook you CAN perform:**
1. All-atom implicit/explicit solvent MD simulation of protein/nucleic acids/small molecules with canonical/non-canonical residues

2. Post-run analyses including: clustering, RMSD, RMSF, RoG, FEL, H-bonds, SASA, DSSP, DCCM, PCA, and MM/GBSA

**with this notebook you CANNOT perform:**
1. MD simulation of systems with non-aqueous solvents or lipid bilayers

2. Advanced MD options like steered MD, replica exchange sampling, simulated tempering, metadynamics, etc.

3. Accurate binding energy calculation using MM/PBSA, 3D-RISM, umbrella sampling, alchemical transformation, etc.

4. Binding energy calculation of ligands with ML- or XTB-forcefields

----------------------------------------

To start simulating, open each tab to view its cells and run the cells one at a time.

## 0. Connect to Google Drive & Set working environment

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
assert gpu_info.find('command not found') <0, "Please change to GPU runtime!"

#@markdown Runtime will be restarted shortly. Please wait.

!pip install -q condacolab
import condacolab
condacolab.install()

⏬ Downloading https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Mambaforge-23.11.0-0-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:09
🔁 Restarting kernel...
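
The GPU check in the cell above looks for the text `command not found` in the shell output. A minimal alternative sketch, assuming that `nvidia-smi` is on `PATH` and exits successfully only when a GPU runtime is attached; the helper name `gpu_available` is hypothetical and not part of this notebook:

```python
import shutil
import subprocess

def gpu_available() -> bool:
    """Return True if nvidia-smi exists and exits successfully (assumed GPU runtime)."""
    exe = shutil.which("nvidia-smi")
    if exe is None:  # binary not found: most likely a CPU-only runtime
        return False
    try:
        result = subprocess.run([exe], capture_output=True, timeout=30)
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0

if not gpu_available():
    print("Please change to GPU runtime!")
```

Checking the return code rather than the output text avoids depending on the exact wording of the shell's error message.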


In [1]:
#@markdown (Takes about 3 minutes) <br/>
#@markdown Enter the name of the working directory within your Google Drive. <br/>

from google.colab import drive
drive.mount('/content/drive')
working_directory = "new" #@param {type:"string"}

#@markdown Please authorize Colab to access your Google Drive. MD trajectory files are large, and Colab sessions are notorious for their instability.

import subprocess

print('Installing all required python packages...', end='')
subprocess.run(['mamba', 'install', '-c', 'pytorch', '-c', 'nvidia', '-c', 'conda-forge',
                'pytorch', 'pytorch-cuda=12.1', 'openmm=8', 'openmm-torch', 'openmm-ml',
                'openmm-xtb', 'nnpops', 'espaloma', 'openmmforcefields', 'scikit-learn',
                'scipy', 'ipywidgets=7', 'nglview', 'parmed', 'rdkit', 'pdbfixer', 'openbabel',
                'plotly', 'mdtraj', 'lxml', '--yes'])
subprocess.run(['pip', 'install', 'mace-torch'])
subprocess.run(['apt-get', 'install', 'python3-openbabel'])
subprocess.run(['pip', 'install', '--no-deps', 'plip'])
print('done.')

from pdbfixer import PDBFixer
from openmm import *
from openmm.app import *
from openmm.unit import *
from openmmxtb import XtbForce
from openmmml import MLPotential
from openff.toolkit import Molecule
from openmmforcefields.generators import GAFFTemplateGenerator
from openmmforcefields.generators import SMIRNOFFTemplateGenerator
from openmmforcefields.generators import EspalomaTemplateGenerator
from openmm.app.internal import customgbforces
from rdkit import Chem
from rdkit.Chem import AllChem
import io
import os
import locale
import sys
import glob
import copy
import warnings
import threading
import numpy as np
import parmed as pmd
import nglview as nv
from time import sleep
from tqdm.notebook import tqdm
from IPython.display import display, clear_output
import mdtraj as md
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV
from collections import Counter
import openbabel
from plip.structure.preparation import PDBComplex
from plip.exchange.report import BindingSiteReport
from scipy.spatial.transform import Rotation
from sklearn.covariance import EmpiricalCovariance
import ipywidgets as widgets
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
pio.renderers.default = 'colab'
from google.colab import output, files
output.enable_custom_widget_manager()
import torch
torch._C._jit_set_nvfuser_enabled(False)


def navigator(working_directory)->str:
    work = "/content/drive/MyDrive/"+working_directory

    if not os.path.isdir(work):
        print(f'No directory named {working_directory} in your Google Drive.')
        os.mkdir(work)
        print(f'Created a new directory named {working_directory} in your Google Drive.')
        print('You may now proceed to the cell# 1-1.')

    elif os.path.exists(work+'/production_final.xtc'):
        print(f'There is a finished MD simulation in the directory {working_directory}.')
        print('You can extend your production run at cell# 3-3 if you want.')
        print('To continue analysis, go to cell# 4-1.')
        print('To calculate binding free energy, go to cell# 5-1.')

    elif os.path.exists(work+'/production_0.xtc'):
        print('Seems like you were performing production run.')
        print('Go to cell# 3-3.')

    elif os.path.exists(work+'/npt_0.xtc'):
        print('Seems like you were performing NPT equilibration.')
        print('Go to cell# 3-2.')

    elif os.path.exists(work+'/nvt_0.xtc'):
        print('Seems like you were performing NVT equilibration.')
        print('Go to cell# 3-1.')

    elif os.path.exists(work+'/start.xml'):
        print('You have created your simulation object, but never started it.')
        print('Go to cell# 3-1.')

    elif os.path.exists(work+'/starting_structure.pdb'):
        print('You have prepared the starting structure, but did not create the simulation object.')
        print('Go to cell# 2-1.')

    else:
        print(f'There is a directory named {working_directory} in your Google Drive, but it is empty.')
        print('Proceed to cell# 1-1.')

    work += '/'
    return work


def read_settings(work:str)->dict:
    """
    Read the parameters from the settings.txt file.
    """
    if not os.path.exists(work+'settings.txt'):
        return None

    with open(work+'settings.txt', 'r') as f:
        settings = {}
        for line in f.readlines():
            if line.strip().startswith('Forcefield'):
                settings['forcefield'] = line.split(":")[1].strip()
            elif line.strip().startswith('Watermodel'):
                settings['watermodel'] = line.split(":")[1].strip()
            elif line.strip().startswith('Constraints'):
                settings['constraints'] = line.split(":")[1].strip()
            elif line.strip().startswith('Rigidwater'):
                # bool("False") is True, so compare against the literal string instead
                settings['rigidwater'] = line.split(":")[1].strip() == "True"
            elif line.strip().startswith('Hydrogen mass'):
                settings['hydrogen mass'] = float(line.split(":")[1].strip())*amu
            elif line.strip().startswith('Precision'):
                settings['precision'] = line.split(":")[1].strip()
            elif line.strip().startswith('Temperature'):
                settings['temperature'] = (float(line.split(":")[1].strip())+273.15)*kelvin
            elif line.strip().startswith('Pressure'):
                settings['pressure'] = float(line.split(":")[1].strip())*bar
            elif line.strip().startswith('Time step'):
                settings['stepsize'] = float(line.split(":")[1].strip())*femtosecond
            elif line.strip().startswith('Ionic strength'):
                settings['ionic strength'] = float(line.split(":")[1].strip())*molar
            elif line.strip().startswith('Periodic Boundary Condition'):
                settings['PBC'] = True if line.split(":")[1].strip()=="True" else False
            elif line.strip().startswith('Padding'):
                settings['padding'] = float(line.split(":")[1].strip())*nanometer
            elif line.strip().startswith('Cation'):
                settings['cation'] = line.split(":")[1].strip()
            elif line.strip().startswith('Anion'):
                settings['anion'] = line.split(":")[1].strip()
            elif line.strip().startswith('Save interval'):
                settings['save interval'] = int(line.split(":")[1].strip())*picosecond
            elif line.strip().startswith('Restraint'):
                settings['restraint'] = int(line.split(":")[1].strip())

    return settings


def read_non_standard_residues_list(work)->list:
    """
    Read the nonstandard residue list from the non_standard_residues.txt file.
    """
    if not os.path.exists(work+'non_standard_residues.txt'):
        return []

    with open(work+'non_standard_residues.txt', 'r') as f:
        nsres = []
        for line in f.readlines():
            if not line.startswith('#'):
                data = line.strip().split()
                if not data:  # skip blank lines
                    continue
                if len(data) > 2:
                    # expected layout: <name> <ligand|residue> <forcefield>
                    nsres.append({'name':data[0], 'lig':data[1] == 'ligand', 'ff':data[2]})
                else:
                    nsres.append({'name':data[0]})
    return nsres

def determine_forcefield(forcefield, watermodel)->tuple:
    """
    Assign the macromolecular forcefield & water model from the choice.
    """
    if forcefield == "Amber14":
        ff = ForceField('amber14-all.xml')
        if watermodel == "(explicit) TIP3P":
            ff.loadFile('amber14/tip3p.xml')
            water = 'tip3p'
        elif watermodel == "(explicit) SPC/E":
            ff.loadFile('amber14/spce.xml')
            water = 'spce'
        elif watermodel == "(explicit) TIP4P-Ew":
            ff.loadFile('amber14/tip4pew.xml')
            water = 'tip4pew'
        elif watermodel == "(explicit) TIP5P":
            print("ERROR : AMBER forcefield does not support TIP5P water model!")
            print("Please choose other options.")
            return None, ""
        elif watermodel == "(explicit) OPC":
            ff.loadFile('amber14/opc.xml')
            water = 'tip4pew'
        elif watermodel == "(explicit) OPC3":
            ff.loadFile('amber14/opc3.xml')
            water = 'tip3p'
        elif watermodel == "(implicit) HCT (igb=1)":
            ff.loadFile('implicit/hct.xml')
            water = ""
        elif watermodel == "(implicit) OBC1 (igb=2)":
            ff.loadFile('implicit/obc1.xml')
            water = ""
        elif watermodel == "(implicit) OBC2 (igb=5)":
            ff.loadFile('implicit/obc2.xml')
            water = ""
        elif watermodel == "(implicit) GBn (igb=7)":
            ff.loadFile('implicit/gbn.xml')
            water = ""
        elif watermodel == "(implicit) GBn2 (igb=8)":
            ff.loadFile('implicit/gbn2.xml')
            water = ""
        else:
            print('Error in reading AMBER forcefield parameters.')
            return None, ""

    elif forcefield == "CHARMM36":
        if watermodel == "(explicit) TIP3P":
            ff = ForceField('charmm36.xml', 'charmm36/water.xml')
            water = 'tip3p'
        elif watermodel == "(explicit) SPC/E":
            ff = ForceField('charmm36.xml', 'charmm36/spce.xml')
            water = 'spce'
        elif watermodel == "(explicit) TIP4P-Ew":
            ff = ForceField('charmm36.xml', 'charmm36/tip4pew.xml')
            water = 'tip4pew'
        elif watermodel == "(explicit) TIP5P":
            ff = ForceField('charmm36.xml', 'charmm36/tip5p.xml')
            water = 'tip5p'
        elif watermodel == "(explicit) OPC":
            print("ERROR : CHARMM forcefield does not support OPC water models!")
            print("Please choose other options.")
            return None, ""
        elif watermodel == "(explicit) OPC3":
            print("ERROR : CHARMM forcefield does not support OPC water models!")
            print("Please choose other options.")
            return None, ""
        elif watermodel == "(implicit) HCT (igb=1)":
            ff = ForceField('charmm36.xml', 'implicit/hct.xml')
            water = ''
        elif watermodel == "(implicit) OBC1 (igb=2)":
            ff = ForceField('charmm36.xml', 'implicit/obc1.xml')
            water = ''
        elif watermodel == "(implicit) OBC2 (igb=5)":
            ff = ForceField('charmm36.xml', 'implicit/obc2.xml')
            water = ''
        elif watermodel == "(implicit) GBn (igb=7)":
            ff = ForceField('charmm36.xml', 'implicit/gbn.xml')
            water = ''
        elif watermodel == "(implicit) GBn2 (igb=8)":
            ff = ForceField('charmm36.xml', 'implicit/gbn2.xml')
            water = ''
        else:
            print("Error in reading the CHARMM forcefield parameters")
            return None, ""

    else:
        return "", ""

    return ff, water


class SimulationStream(io.StringIO):
    """
    StringIO's child class to capture stdout simulation output.
    """
    def __init__(self):
        super().__init__()
        self.string = ""

    def write(self, string):
        self.string += string

    def read(self, size=-1)->str:
        if size < 0:
            result = self.string
            self.string = ""
        else:
            result = self.string[:size]
            self.string = self.string[size:]
        return result


def apply_restraints(work, simulation, restraint):
    """
    Apply a harmonic positional restraint potential to the heavy atoms.
    """
    if restraint > 0:
        posres = CustomExternalForce("k*periodicdistance(x, y, z, x0, y0, z0)^2")
        posres.addGlobalParameter("k", restraint)
        posres.addPerParticleParameter("x0")
        posres.addPerParticleParameter("y0")
        posres.addPerParticleParameter("z0")

        pdb = PDBFile(work+'minimized.pdb')
        # Heavy atoms only; virtual sites (element is None) are skipped as well.
        heavy = {i for i, a in enumerate(pdb.topology.atoms())
                 if a.element is not None and a.element.symbol != 'H'}

        positions = simulation.context.getState(getPositions=True).getPositions(asNumpy=True)
        for i in sorted(heavy):
            posres.addParticle(i, positions[i])
        simulation.system.addForce(posres)
    return simulation



def restore_simulation_from_state(work, stage):
    """
    Restore a simulation object from a state file.
    """
    s = read_settings(work)
    pdb = PDBFile(work+'minimized.pdb')
    # Use context managers so the XML files are closed after reading
    with open(work+'system.xml') as f:
        system = XmlSerializer.deserialize(f.read())
    with open(work+'integrator.xml') as f:
        integrator = XmlSerializer.deserialize(f.read())
    platform = Platform.getPlatformByName('CUDA')
    properties = {'Precision': s['precision']}
    simulation = Simulation(pdb.topology, system, integrator, platform, properties)
    statename = work+stage+'.xml'
    with open(statename) as f:
        simulation.context.setState(XmlSerializer.deserialize(f.read()))
    change_flag = False # When the context is changed, reinitialize it
    if stage in ['npt', 'production'] and s['PBC']:
        barostat = MonteCarloBarostat(s['pressure'], s['temperature'])
        simulation.system.addForce(barostat)
        change_flag = True
    if stage != 'production' and stage != 'start':
        simulation = apply_restraints(work, simulation, s['restraint'])
        change_flag = True
    if change_flag:
        simulation.context.reinitialize(preserveState=True)

    return simulation


def recall_simulation(work):
    """
    Recall a simulation object from the working directory.
    """
    if os.path.exists(work+'production.xml'):
        stage = "production"
        print('Resuming production run...', end='')
        simulation = restore_simulation_from_state(work, 'production')

    elif os.path.exists(work+'npt.xml'):
        stage = "npt"
        print('Resuming NPT equilibration...', end='')
        simulation = restore_simulation_from_state(work, 'npt')

    elif os.path.exists(work+'nvt.xml'):
        stage = "nvt"
        print('Resuming NVT equilibration...', end='')
        simulation = restore_simulation_from_state(work, 'nvt')

    elif os.path.exists(work+'start.xml'):
        stage = "start"
        print('Starting the simulation...', end='')
        simulation = restore_simulation_from_state(work, 'start')

    else:
        print('Error in reading the simulation object.')
        return None, ""

    print('done.')
    return simulation, stage


def run_simulation(work, stage, simulation, stream, stepsize, temperature, length, save_period, backup_period):
    """
    Run the simulation object for length picoseconds.
    """

    s = read_settings(work)

    runtime = 0.0
    time = simulation.context.getTime().value_in_unit(picosecond)

    simulation.context.getIntegrator().setStepSize(stepsize*femtoseconds)
    simulation.context.getIntegrator().setTemperature((temperature+273.15)*kelvin)

    if s['PBC'] and stage in ['npt', 'production']:
        simulation.context.setParameter(MonteCarloBarostat.Temperature(), (temperature+273.15)*kelvin)

    while runtime < length:
        step = simulation.context.getStepCount()
        log = StateDataReporter(stream, 50, step=True, potentialEnergy=True,
                                temperature=True, speed=True, time=True,
                                density=s['PBC'])

        i = 0
        for f in glob.glob(work+stage+'_*.xtc'):
            i += 1

        save_freq = int(save_period/(stepsize*femtosecond))
        xtc = XTCReporter(work+stage+f'_{i}.xtc', save_freq, False, s['PBC'])

        simulation.reporters.clear()
        simulation.reporters.append(log)
        simulation.reporters.append(xtc)

        for i in range(backup_period*60):
            simulation.runForClockTime(1*second)
            newStep = simulation.context.getStepCount()
            now = simulation.context.getTime().value_in_unit(picosecond)
            runtime = now - time
            lines = stream.read()
            with open(work+stage+'_log.txt', 'a') as f:
                f.write(lines)
            if runtime >= length:  # leave only after the log has been flushed
                break
            line = lines.split('\n')
            if len(line) < 2:
                continue
            data = line[-2].split(',')
            print('', end='\r')
            if len(data) < 6:
                speed = float(data[4])
                density = ''
            else:
                speed = float(data[5])
                density = f'Density: {float(data[4]):.4f} g/mL, '
            if stage != 'production':
                T = f'Time: {float(data[1]):.1f} ps, '
            else:
                T = f'Time: {float(data[1])/1000:.2f} ns, '

            if speed > 0:
                # remaining ps / (ns/day) / 1000 = days; convert to d:hh:mm:ss
                d, hs = divmod((length - runtime)/speed/1000*24, 24)
                h, mins = divmod(hs*60, 60)
                m, secs = divmod(mins*60, 60)  # 'm' avoids shadowing the built-in min()
                sec = int(secs)
                remaining = f'{str(int(d))+":" if d > 0 else ""}{int(h):0>2}:{int(m):0>2}:{sec:0>2} left, '
            else:
                remaining = '--:--:-- left, '

            print(f'{remaining}Step: {data[0]}, {T}Temperature: {(float(data[3])-273.15):.2f} C, {density}Speed: {speed:.1f} ns/day', end='')

        state = simulation.context.getState(getPositions=True, getVelocities=True)
        with open(work+stage+'.xml', 'w') as f:
            f.write(XmlSerializer.serialize(state))

    print('\nSimulation finished!')
    return

work = navigator(working_directory)

Mounted at /content/drive
Installing all required python packages...done.




No directory named new in your Google Drive.
Created a new directory named new in your Google Drive.
You may now proceed to the cell# 1-1.
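
The message above comes from `navigator`, which decides where to resume by checking which output files already exist in the working directory, from the latest stage to the earliest. The same idea as a standalone sketch that uses a temporary directory instead of Google Drive; the table `STAGE_FILES` and the function `next_cell` are illustrative names, not part of this notebook, and the finished-run case is simplified to point only at cell# 4-1:

```python
import os
import tempfile

# Checked in order: the first file that exists decides where to resume.
STAGE_FILES = [
    ("production_final.xtc", "4-1"),    # finished run: continue with analysis
    ("production_0.xtc", "3-3"),        # production run was in progress
    ("npt_0.xtc", "3-2"),               # NPT equilibration was in progress
    ("nvt_0.xtc", "3-1"),               # NVT equilibration was in progress
    ("start.xml", "3-1"),               # simulation object created, never started
    ("starting_structure.pdb", "2-1"),  # structure prepared, no simulation object
]

def next_cell(work: str) -> str:
    """Return the cell number to continue from for a working directory."""
    for filename, cell in STAGE_FILES:
        if os.path.exists(os.path.join(work, filename)):
            return cell
    return "1-1"  # nothing found: start from the beginning

with tempfile.TemporaryDirectory() as work:
    print(next_cell(work))  # empty directory
    open(os.path.join(work, "nvt_0.xtc"), "w").close()
    open(os.path.join(work, "starting_structure.pdb"), "w").close()
    print(next_cell(work))  # the later stage (NVT) takes precedence
```

This prints `1-1` and then `3-1`. Because the list is ordered from the latest stage to the earliest, leftover files from earlier stages never mask later progress.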


## 1. PDB Input Preparation

In [2]:
#@title 1-1. Add missing atoms & Remove unrecognized molecules

assert 'work' in globals(), "Run cell# 0 to set working directory."

#@markdown Enter the 4-letter PDB ID to fetch the PDB file.<br/>
#@markdown To upload your own PDB file instead, leave this field empty.
PDB_ID = "1DET" #@param {type:"string"}
#@markdown For ligands/non-standard residues that are defined in the [RCSB Chemical Component Dictionary](https://www.wwpdb.org/data/ccd),
#@markdown enter the residue names below. <br/>Separate multiple non-standard residues with commas, without spaces.
#@markdown <br/>**The residue names you enter must match those present in the PDB file!**
non_standard_residues = "2GP,CGA" #@param {type:"string"}
#@markdown For ligands/non-standard residues that are not defined in the RCSB dictionary, run the next cell to upload the SDF file.<br/>
#@markdown <br/> Check the box below to remove the water molecules present in your PDB.
remove_crystallographic_water = True #@param {type:"boolean"}

amino_acids = ['ALA', 'CYS', 'ASP', 'GLU', 'PHE', 'GLY', 'HIS', 'ILE', 'LYS', 'LEU',
               'MET', 'ASN', 'PRO', 'GLN', 'ARG', 'SER', 'THR', 'VAL', 'TRP', 'TYR']
nucleic_acids = ['A', 'T', 'G', 'C', 'U', 'I', 'DA', 'DT', 'DG', 'DC', 'DI']

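# Normalize to upper case and drop stray spaces so that the names match
# both the PDB residue names and the entries written to non_standard_residues.txt
nsres = [r.strip().upper() for r in non_standard_residues.split(',') if r.strip()]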

with open(work+'non_standard_residues.txt', 'w') as f:
    f.write('# User-defined Non-standard Residues\n')
    for res in nsres:
        f.write(res.upper()+'\n')

if PDB_ID == "":
    print("Upload your PDB file.")
    pdbfile = files.upload()
    try:
        fixer = PDBFixer(filename=next(iter(pdbfile)))
    except Exception:
        print('Error in reading the input pdbfile.')
        raise  # stop here; the following steps need a valid fixer
else:
    try:
        fixer = PDBFixer(pdbid=PDB_ID.upper())
    except Exception:
        print('Error in fetching the pdb file from the RCSB.')
        raise  # stop here; the following steps need a valid fixer

print('Adding missing atoms...', end='')
fixer.findMissingResidues()
terminal_missing_res = []
chains = list(fixer.topology.chains())
for key in fixer.missingResidues.keys():
    if key[1] == 0 or key[1] == len(list(chains[key[0]].residues())):
        terminal_missing_res.append(key)
for key in terminal_missing_res:
    del fixer.missingResidues[key]
fixer.findMissingAtoms()
fixer.addMissingAtoms()

model = Modeller(fixer.topology, fixer.positions)
res_to_del = []
for res in model.topology.residues():
    if remove_crystallographic_water:
        if res.name not in amino_acids+nucleic_acids+nsres:
            res_to_del.append(res)
    else:
        if res.name not in (amino_acids+nucleic_acids+nsres+['HOH']):
            res_to_del.append(res)

model.delete(res_to_del)

print('done.')

with open(work+"fixed.pdb", 'w') as f:
    PDBFile.writeFile(model.topology, model.positions, f)

custom_info = {}

view1 = nv.NGLWidget()
view1._set_size('750px','500px')
view1.add_component(nv.FileStructure(work+"fixed.pdb"))
view1.add_licorice()
display(view1)

Adding missing atoms...done.


NGLWidget()
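
Cell 1-1 discards missing residues at the chain termini so that only internal gaps are rebuilt. `PDBFixer.missingResidues` is keyed by `(chain index, residue index within the chain)`, and an entry sits at a terminus when that index is 0 or equals the number of residues in the chain. A sketch of the same filter on plain Python data; the chain lengths, residue names, and the helper `drop_terminal_gaps` are made up for illustration:

```python
# Hypothetical data: chain 0 has 100 resolved residues, chain 1 has 50.
chain_lengths = {0: 100, 1: 50}

# Same key layout as PDBFixer.missingResidues: (chain index, insert position).
missing = {
    (0, 0): ["MET", "SER"],    # before the first residue: N-terminal gap
    (0, 42): ["GLY", "ALA"],   # internal loop: should be kept
    (1, 50): ["LYS"],          # after the last residue: C-terminal gap
}

def drop_terminal_gaps(missing, chain_lengths):
    """Return only the gaps that lie strictly inside a chain."""
    return {
        key: names
        for key, names in missing.items()
        if key[1] not in (0, chain_lengths[key[0]])
    }

internal = drop_terminal_gaps(missing, chain_lengths)
print(internal)
```

Only the internal loop `(0, 42)` remains. Building a new dict avoids deleting keys while iterating, which is why the cell collects the terminal keys in a list first.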

In [None]:
#@title (Optional) Upload SDF files for Custom Ligands/Residues
#@markdown For ligands/non-standard residues not defined on RCSB Chemical Component Dictionary, <br/>
#@markdown run this cell to upload SDF file(s). <br/>

assert 'work' in globals(), "Please run the cell #0 to set working directory."
assert len(nsres) > 0, "No non-standard residues entered in the above cell!"

print("Upload your ligand SDF file.")
ligfiles = files.upload()

lignames = list(ligfiles.keys())
lig_select = widgets.Dropdown(options=lignames, value=lignames[0], description='Custom residue')
view0 = nv.NGLWidget()
view0._set_size('400px','300px')

for lig in lignames:
    custom_info[lig] = {}
custom_info[lig_select.value]['id'] = view0.add_component(nv.FileStructure(lig_select.value))

res_select = widgets.Dropdown(options=nsres, value=nsres[0], description='Choose residue name:')
custom_info[lig_select.value]['name'] = res_select.value

def update_mol2(change): # callback function to interactively update viewer
    view0.remove_component(custom_info[change.old]['id'])
    custom_info[change.new]['id'] = view0.add_component(nv.FileStructure(change.new))
    view0.center()
    custom_info[change.new]['name'] = res_select.value

def update_name(change):
    custom_info[lig_select.value]['name'] = change.new

lig_select.observe(update_mol2, names='value')
res_select.observe(update_name, names='value')

print('Assign residue name to the structures of the residues you uploaded.')

display(lig_select, view0, res_select)

In [None]:
#@title 1-2. Add hydrogen atoms to the PDB file

assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'fixed.pdb'), "Please run the cell# 1-1 first."

#@markdown Choose a method to add hydrogen atoms to your system : <br/>
#@markdown * PDB2PQR - Uses PropKa to protonate systems without ligands.
#@markdown * REDUCE - Works for ligands too, but only for physiological pH.
#@markdown * OpenBabel - Can automatically protonate any system, but often buggy
#@markdown * PDBFixer - Only works for standard residues
#@markdown * Custom - Upload your PDB file with hydrogens (use external tools like
#@markdown [DelPhipKa](http://compbio.clemson.edu/pka_webserver/),
#@markdown [H++](http://newbiophysics.cs.vt.edu/H++/),
#@markdown [PypKa](https://pypka.org/run-pypka/), etc.)

protonation_method = "REDUCE" #@param ['PDB2PQR', 'REDUCE', 'OpenBabel', 'PDBFixer', 'Custom']

#@markdown Choose the reference pH for the determination of protonation state.
pH = 7.4 #@param {type:"slider", min:0.0, max:14.0, step:0.1}

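for key, value in custom_info.items():
    if 'name' in value:  # skip uploaded files that were never assigned a residue name
        # argument list (no shell) so that file names with spaces are copied correctly
        subprocess.run(['cp', key, f'{work}{value["name"]}.sdf'])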


def protonate_w_pdbfixer(work, removeH, pH)->None:
    """
    Add missing hydrogen atoms to the PDB file using PDBFixer.
    """
    print('Adding hydrogens using PDBFixer...', end='')
    pdb = PDBFile(work+'fixed.pdb')
    model = Modeller(pdb.topology, pdb.positions)

    if removeH:
        Hs = []
        for a in model.topology.atoms():
            if a.element.symbol == 'H':
                Hs.append(a)
        model.delete(Hs)

    try:
        model.addHydrogens(pH=pH)
        with open(work+'starting_structure.pdb', 'w') as f:
            PDBFile.writeFile(model.topology, model.positions, f)
    except Exception as e:
        print(e)
    print('done.')
    return


def protonate_manually(work)->None:
    """
    Enable uploading of custom PDB file.
    """
    print('Here is your fixed pdb file.')
    files.download(work+'fixed.pdb')
    sleep(10)
    print('Now upload the pdb file with all the hydrogens.')
    fixed_wH = files.upload()
    filename = next(iter(fixed_wH))
    # argument list (no shell) so that file names with spaces are copied correctly
    subprocess.run(['cp', '-rf', filename, work+'starting_structure.pdb'])
    return


def protonate_w_openbabel(work, nsres, pH)->None:
    """
    Add missing hydrogen atoms to the PDB file using OpenBabel.
    Since OpenBabel names all the hydrogen as 'H', we need to rename them.
    """
    print('Adding hydrogens using OpenBabel...', end='')

    # save atom names before OpenBabel butchers them
    pdb = PDBFile(work+'fixed.pdb')
    nsatoms = []
    for res in pdb.topology.residues():
        if res.name in nsres:
            atoms = []
            for a in res.atoms():
                atoms.append(a.name)
            nsatoms.append({'name':res.name, 'chain':res.chain.id, 'index':res.id, 'atoms':atoms})

    locale.getpreferredencoding = lambda: "UTF-8"
    subprocess.run(['obabel', f'{work}fixed.pdb', '-O', 'tmp.pdb', '-xk', '-p', f'{pH}'])

    with open('tmp.pdb', 'r') as f: # just number Hs so that OpenMM can discriminate them
        lines = f.readlines()
        chain = ''
        resi = -100
        headlines = ''
        bodylines = []
        taillines = ''
        headflag = True
        for line in lines:
            if line.startswith('ATOM') or line.startswith('HETATM'):
                headflag = False
                if line[17:20] == 'UNL': # ligands/ncAAs
                    if line[21] != chain: # reset index
                        chain = line[21]
                        resi = -100
                    if int(line[22:26]) > resi:
                        resi = int(line[22:26])
                        count = 0 # 0-based index of nsres
                        Hcount = 1 # 1-based index of Hs in the residue
                    if line[77] == 'H':
                        name = 'H'+str(Hcount)
                        Hcount += 1
                    else:
                        name = nsatoms[resi-1]['atoms'][count]
                        count += 1
                    resname = nsatoms[resi-1]['name']
                    chainid = nsatoms[resi-1]['chain']
                    resindex = nsatoms[resi-1]['index']
                    newinfo = f'{name:^4} {resname:>3} {chainid}{resindex:>4}'
                elif line[77] == 'H': # element symbol
                    if line[21] != chain: # chain ID
                        chain = line[21]
                        resi = -100
                    if int(line[22:26]) > resi:
                        resi = int(line[22:26])
                        Hcount = 1
                    name = 'H'+str(Hcount) # count number of Hs in a residue
                    Hcount += 1
                    newinfo = f'{name:^4}'+line[16:26]
                else:
                    newinfo = line[12:26]
                newline = line[:12]
                newline += newinfo
                newline += line[26:]
                bodylines.append(newline)
            elif headflag:
                headlines += line
            else:
                taillines += line

    with open('tmp2.pdb', 'w') as f:
        f.write(headlines)
        for line in sorted(bodylines, key=lambda x: (x[21], int(x[22:26]))):
            f.write(line)  # sort by chainID then residue index
        f.write(taillines)

    warnings.filterwarnings('ignore')
    pdb2 = PDBFile('tmp2.pdb')

    bonds_wH = []
    for bond in pdb2.topology.bonds():
        if bond[0].element.symbol == 'H':
            bonds_wH.append([bond[1], bond[0]])
        elif bond[1].element.symbol == 'H':
            bonds_wH.append([bond[0], bond[1]])
        else:
            continue

    counter = 4
    for bond in bonds_wH:
        multi = [b[0] for b in bonds_wH].count(bond[0]) # number of equivalent Hs
        bond[1].residue = bond[0].residue
        if counter > multi or counter == 0:
            counter = multi
        if bond[0].name == 'N':
            if multi == 3: # N-terminus
                bond[1].name = 'H'+str(counter)
                counter -= 1
            else:          # amide proton
                bond[1].name = 'H'
                counter -= 1
        else:
            if multi == 1:
                bond[1].name = 'H'+bond[0].name[1:]
                counter -= 1
            elif multi == 2: # AMBER atom naming scheme (HB2, HB3 connected to CB)
                if bond[0].element.symbol == 'C':
                    bond[1].name = 'H'+bond[0].name[1:]+str(counter+1)
                    counter -= 1
                else:
                    bond[1].name = 'H'+bond[0].name[1:]+str(counter)
                    counter -= 1
            elif multi == 3:
                bond[1].name = 'H'+bond[0].name[1:]+str(counter)
                counter -= 1

    with open(work+'starting_structure.pdb', 'w') as f:
        PDBFile.writeFile(pdb2.topology, pdb2.positions, f)
    print('done.')
    return



def protonate_w_pdb2pqr(work, removeH, pH)->None:
    """
    Add missing hydrogen atoms to the PDB file using PDB2PQR.
    """
    print('Adding hydrogens using PDB2PQR...', end='')

    if removeH:
        pdb = PDBFile(work+'fixed.pdb')
        model = Modeller(pdb.topology, pdb.positions)
        Hs = []
        for a in model.topology.atoms():
            if a.element.symbol == 'H':
                Hs.append(a)
        model.delete(Hs)

        with open('fixed-H.pdb', 'w') as f:
            PDBFile.writeFile(model.topology, model.positions, f)

        params = ' fixed-H.pdb fixed.pqr'

    else:
        params = f' {work}fixed.pdb fixed.pqr'


    params += f' --ff AMBER --ffout AMBER --pdb-output {work}starting_structure.pdb'
    params += f' --titration-state-method propka --with-ph {pH} --pH {pH} --drop-water'
    subprocess.run("pdb2pqr"+params, shell=True)
    print('done.')
    return



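def get_parmed_atom_from_name(structure, res_top, name):
    """
    Return the ParmEd atom called `name` in the residue matching `res_top`,
    or None if the residue or the atom cannot be found.
    """
    residue = None
    for i, res in enumerate(structure.topology.residues()):
        if res.name == res_top.name and res.index == res_top.index:
            residue = structure.residues[i]
            break
    if residue is None:
        return None
    for atom in residue.atoms:
        if atom.name == name:
            return atom
    return None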

def get_parmed_res_from_top(structure, res_top):
    for i, res in enumerate(structure.topology.residues()):
        if res.name == res_top.name and res.index == res_top.index:
            return structure.residues[i]

def find_coordinate(structure, atom):
    for i, a in enumerate(structure.atoms):
        if a is atom:
            return np.array([structure[i].xx, structure[i].xy, structure[i].xz])

def add_missed_proton(structure, residue, coords)->None:
    """
    In rare cases, REDUCE omits the amide proton of some residues.
    This function adds it back at the given coordinates.
    """
    H = pmd.Atom(name='H', atomic_number=1, type='H', mass=1.008, charge=0.0)

    pmd_res = get_parmed_res_from_top(structure, residue)

    structure.add_atom_to_residue(H, pmd_res)

    for i, a in enumerate(structure.atoms):
        if a is H:
            structure[i].xx = coords[0]
            structure[i].xy = coords[1]
            structure[i].xz = coords[2]

    structure.assign_bonds()

    print(f'Added a missing amide proton at {residue}!')

    return



def protonate_w_reduce(work, removeH)->None:
    """
    Add missing hydrogen atoms to the PDB file using REDUCE.
    """
    print('Warning: With REDUCE, the reference pH is fixed to 7.4!')
    print('Adding hydrogens using REDUCE...', end='')

    if removeH:
        Hflag = '-DROP_HYDROGENS_ON_ATOM_RECORDS '
    else:
        Hflag = ''
    subprocess.run(f'reduce -BUILD {Hflag}{work}fixed.pdb > reduced.pdb', shell=True)
    print('done.')

    struc = pmd.load_file("reduced.pdb")
    struc.assign_bonds()

    for chain in struc.topology.chains():
        for i, res in enumerate(chain.residues()):
            if i == 0 or res.name == 'PRO':
                C = get_parmed_atom_from_name(struc, res, 'C')
                O = get_parmed_atom_from_name(struc, res, 'O')
                continue
            N = get_parmed_atom_from_name(struc, res, 'N')

            # In most cases, N-H is antiparallel to C=O
            a_names = [atom.name for atom in res.atoms()]
            if 'N' in a_names and 'CA' in a_names and 'C' in a_names and 'O' in a_names and 'H' not in a_names:
                v_c = find_coordinate(struc, C) - find_coordinate(struc, O)
                pos_H = find_coordinate(struc, N) + 1.01*v_c/np.linalg.norm(v_c)
                add_missed_proton(struc, res, pos_H)

            C = get_parmed_atom_from_name(struc, res, 'C')
            O = get_parmed_atom_from_name(struc, res, 'O')

    with open(work+'starting_structure.pdb', 'w') as f:
        PDBFile.writeFile(struc.topology, struc.positions, f)
    return


#@markdown Check the box below to ignore the hydrogen atoms currently present in your PDB file.
ignore_current_Hs = False #@param {type:"boolean"}

if protonation_method == "PDBFixer":
    protonate_w_pdbfixer(work, ignore_current_Hs, pH)

elif protonation_method == "OpenBabel":
    print('Installing OpenBabel...', end='')
    subprocess.run(['mamba', 'install', '-y', '-c', 'conda-forge', 'openbabel']) # -y skips the confirmation prompt
    print('done.')
    protonate_w_openbabel(work, nsres, pH)

elif protonation_method == "PDB2PQR":
    print('Installing PDB2PQR...', end='')
    subprocess.run(['pip', 'install', 'pdb2pqr'])
    print('done.')
    protonate_w_pdb2pqr(work, ignore_current_Hs, pH)

elif protonation_method == 'REDUCE':
    print('Installing REDUCE...', end='')
    subprocess.run(['mamba', 'install', '-y', '-c', 'bioconda', 'reduce']) # -y skips the confirmation prompt
    print('done.')
    protonate_w_reduce(work, ignore_current_Hs)

else:
    protonate_manually(work)



#@markdown Select a residue to center the view on, then zoom in or out.<br/>
#@markdown You can also long-click an atom to center the view on it, and right-drag to translate the molecule.

to_lookat = []

pdb = PDBFile(work+'starting_structure.pdb')

for res in pdb.topology.residues():
    if res.name in nsres:
        to_lookat.append(res)

# Interactively shows the protonated structure
select_res = widgets.Dropdown(options=['whole structure']+to_lookat, value='whole structure', description='Focus at', continuous_update=True)
view2 = nv.NGLWidget()
view2._set_size('750px','500px')
try:
    view2.add_component(nv.FileStructure(work+"starting_structure.pdb"))
except Exception:
    print(f"An error occurred in {protonation_method} while adding hydrogen atoms.")

view2.add_licorice()
view2.center()

def center_of_view(change): # callback function for the update
    view2.center(f'{change.new.id} and :{change.new.chain.id}' if change.new != 'whole structure' else None)

select_res.observe(center_of_view, names='value')

display(select_res, view2)

print("The protonated structure is also saved as 'starting_structure.pdb' in your working directory.")
print("Please double check the protonation state of all residues, especially for non-standard ones!")
print("Subsequent steps will fail with incorrectly assigned atoms.")
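
The REDUCE fallback in the cell above rebuilds a missing amide proton by placing it 1.01 Å from the backbone N, along the direction that points from the previous residue's carbonyl O to its carbonyl C, so that N-H ends up antiparallel to C=O. The standalone sketch below reproduces only that geometric step with plain Python tuples. The function name `place_amide_h` and the coordinates are made up for illustration and are not part of the notebook.

```python
import math

def place_amide_h(n_xyz, c_prev_xyz, o_prev_xyz, bond_length=1.01):
    """Place H at `bond_length` (Angstrom) from N, antiparallel to the previous C=O."""
    # Vector pointing from the carbonyl O to the carbonyl C of the previous residue
    v = [c - o for c, o in zip(c_prev_xyz, o_prev_xyz)]
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0.0:
        raise ValueError("C and O coordinates coincide")
    # Step from N along the normalized O->C direction
    return tuple(n + bond_length * x / norm for n, x in zip(n_xyz, v))

# Hypothetical coordinates (Angstrom): the carbonyl O lies 1.23 A from C along -y
h = place_amide_h(n_xyz=(1.33, 0.0, 0.0),
                  c_prev_xyz=(0.0, 0.0, 0.0),
                  o_prev_xyz=(0.0, -1.23, 0.0))
print(h)
```

This is only an approximation of the true amide geometry, which is why the notebook asks you to double-check the protonated structure before continuing.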

## 2. Create Simulation Object

In [None]:
#@title 2-1. Select a Forcefield
assert 'work' in globals(), "Please run cell #0 to set the working directory."

def identify_new_res(fftype, forcefield, model, non_standard_residue_list)->list:
    """
    Returns a list of nonstandard residues that
    are not present in the given macromolecular forcefield.
    """

    ions = ['AL', 'BA', 'BR', 'CA', 'CD', 'CL', 'CE', 'CO', 'CS', 'CU', 'EU',
            'EU3', 'F', 'FE', 'FE2', 'GD3', 'HG', 'IN', 'IOD', 'K', 'LA', 'LI',
            'LU', 'MG', 'MN', 'NA', 'NI', 'PB', 'PD', 'PR', 'PT', 'RB',
            'SM', 'SR', 'TB', 'V', 'YB2', 'YB3', 'ZN']

    if len(non_standard_residue_list) == 0:
        return []

    # Prepare copied modeller objects with a single nonstandard residue
    models = {}
    for res in non_standard_residue_list:
        models[res] = copy.deepcopy(model)
    for key, value in models.items():
        other_res = []
        for res in value.topology.residues():
            if res.name in non_standard_residue_list and res.name != key:
                other_res.append(res)
        if len(other_res) > 0:
            value.delete(other_res)

    new_res = []

    for key, value in models.items():
        if key in ions:
            continue
        try:
            if fftype == 'Amber14':
                value.addExtraParticles(forcefield)
            elif fftype == 'CHARMM36': #due to patches, CHARMM forcefield gives errors with addExtraParticles()
                forcefield.createSystem(value.topology)
            print(f'Found the residue template for {key} in the {fftype} forcefield!')

        except ValueError: # Usually raised when no matching residue templates are found.
            #print(str(e)[:41])
            # Trim other residues and save first appearing structure of the nonstandard residue
            res_to_del = []
            counter = 0
            for res in value.topology.residues():
                if res.name != key or counter > 0:
                    res_to_del.append(res)
                else:
                    counter += 1
            value.delete(res_to_del)
            with open(f'{key}.pdb', 'w') as f:
                PDBFile.writeFile(value.topology, value.positions, f)
            new_res.append(key)
    return new_res


def isLigand(model, name)->bool:
    """
    Check whether the given residue is connected to the other residues.
    If not, it is considered as a ligand.
    """
    for res in model.topology.residues():
        if res.name == name:
            if len(list(res.external_bonds())) > 0:
                return False
            else:
                return True

def scan_elements(model, name)->bool:
    """
    Return True if every element in the given residue is supported by
    GAFF/SMIRNOFF/espaloma/MACE, and False otherwise.
    """
    gaff_elements = ['H', 'C', 'N', 'O', 'S', 'P', 'F', 'Cl', 'Br', 'I']

    for res in model.topology.residues():
        if res.name == name:
            residue = res
            break

    elements = []
    for atom in residue.atoms():
        if atom.element.symbol not in elements:
            elements.append(atom.element.symbol)

    elements = list(set(elements))

    for ele in elements:
        if ele not in gaff_elements:
            return False

    return True


#@markdown Select a forcefield to apply on your system's macromolecules.
forcefield = 'Amber14' #@param ['Amber14', 'CHARMM36']

#@markdown Select a water model.
watermodel = "(explicit) TIP3P" #@param ["(explicit) TIP3P", "(explicit) SPC/E", "(explicit) TIP4P-Ew", "(explicit) TIP5P", "(explicit) OPC", "(explicit) OPC3", "(implicit) HCT (igb=1)", "(implicit) OBC1 (igb=2)", "(implicit) OBC2 (igb=5)", "(implicit) GBn (igb=7)", "(implicit) GBn2 (igb=8)"]

#@markdown For ligands and non-standard residues, you should select a small molecule forcefield to apply on them below.
#@markdown * GFN-FF : Semiempirical tight binding forcefield (can simulate chemical reactions, but slow)
#@markdown * MACE-off : Machine-learning forcefield (currently having issues with temperature blowing up)
#@markdown * GAFF : Uses the GAFF template generator from AmberTools
#@markdown * Sage : Uses the SMIRNOFF template generator from Open Force Field
#@markdown * espaloma : Small molecule forcefield parametrized by a neural network
#@markdown * Custom : Upload an XML file that contains the residue and atom information for the forcefield.

ff, water = determine_forcefield(forcefield, watermodel)

assert ff is not None, "Forcefield not set."

pdb = PDBFile(work+'starting_structure.pdb')
model = Modeller(pdb.topology, pdb.positions)

nsres = [a['name'] for a in read_non_standard_residues_list(work)]
uaas = identify_new_res(forcefield, ff, model, nsres)
uaa_info = {}
uaa_atom_info = []


if len(uaas) > 0:
    print(f'New residue infos must be registered to the {forcefield} forcefield.')
    print('Select a small molecule forcefield to compute the parameters of the residue with the structure shown below.')
    if len(uaas) > 1:
        print('Toggle the dropdown menu to select forcefield for other residue(s), too.')

    for uaa in uaas:
        lig_flag = isLigand(model, uaa)
        viable_ffs = ['GFN-FF'] # GFN-FF is applicable to all elements on the Periodic Table.
        if scan_elements(model, uaa):
            viable_ffs = ['MACE-off23'] + viable_ffs # Machine Learning forcefields
        if scan_elements(model, uaa) and lig_flag:
            viable_ffs = ['GAFF-2.11', 'Sage-2.2.1'] + viable_ffs # small molecule forcefields
        viable_ffs.append('Custom')
        uaa_info[uaa] = {'name':uaa, 'file':f'{uaa}.pdb', 'lig':lig_flag, 'ff':'GFN-FF', 'ffs':viable_ffs}

    pdb_select = widgets.Dropdown(options=uaas, value=uaas[0], description='New residue', continuous_update=True)
    view3 = nv.NGLWidget()
    view3._set_size('400px','300px')
    # save NGLView component id to remove it later
    uaa_info[pdb_select.value]['id'] = view3.add_component(nv.FileStructure(uaa_info[pdb_select.value]['file']))
    ff_select = widgets.Dropdown(options=uaa_info[uaas[0]]['ffs'], value='GFN-FF', description='forcefield', continuous_update=True)

    def update_pdb(change): # callback function to interactively update viewer
        view3.remove_component(uaa_info[change.old]['id'])
        uaa_info[change.new]['id'] = view3.add_component(nv.FileStructure(uaa_info[change.new]['file']))
        view3.center(change.new)
        ff_select.options = uaa_info[change.new]['ffs']
        ff_select.value = uaa_info[change.new]['ff']

    def update_ff(change): # callback function to interactively update residue infos
        uaa_info[pdb_select.value]['ff'] = change.new

    pdb_select.observe(update_pdb, names='value')
    ff_select.observe(update_ff, names='value')

    display(pdb_select, view3, ff_select)

else:
    print('All residues assigned to the forcefield! You can skip the next cell.')

New residue infos must be registered to the Amber14 forcefield.
Select a small molecule forcefield to compute the parameters of the residue with the structure shown below.
Toggle the dropdown menu to select forcefield for other residue(s), too.


Dropdown(description='New residue', options=('2GP', 'CGA'), value='2GP')

NGLWidget()

Dropdown(description='forcefield', index=3, options=('GAFF-2.11', 'Sage-2.2.1', 'MACE-off23', 'GFN-FF', 'Custo…
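
The option list in the forcefield dropdown above follows from two checks made in cell 2-1: whether every element of the residue is in the supported set, and whether the residue is a free ligand with no external bonds. The sketch below restates that decision as a pure function, without OpenMM objects. The names `SUPPORTED` and `viable_forcefields` are hypothetical and exist only for this illustration.

```python
# Elements accepted by the small-molecule / ML options (same list as `gaff_elements` in cell 2-1)
SUPPORTED = {'H', 'C', 'N', 'O', 'S', 'P', 'F', 'Cl', 'Br', 'I'}

def viable_forcefields(elements, is_ligand):
    """Return the forcefield options offered for a residue, in dropdown order."""
    ffs = ['GFN-FF']                      # always offered
    if set(elements) <= SUPPORTED:
        ffs = ['MACE-off23'] + ffs        # ML forcefield needs supported elements only
        if is_ligand:
            # template generators are offered only for residues without external bonds
            ffs = ['GAFF-2.11', 'Sage-2.2.1'] + ffs
    ffs.append('Custom')                  # a user-supplied XML is always possible
    return ffs

print(viable_forcefields(['C', 'H', 'N', 'O', 'P'], is_ligand=True))
print(viable_forcefields(['C', 'H', 'N', 'O'], is_ligand=False))
print(viable_forcefields(['Fe', 'C', 'N'], is_ligand=True))
```

A covalently bound residue therefore never sees GAFF or Sage, and a residue containing an unsupported element is left with GFN-FF or a custom XML file.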

In [None]:
#@title 2-2. Set buffer region around XTB/ML atoms
assert 'work' in globals(), "Please run cell #0 to set the working directory."
assert 'uaas' in globals(), "Please run cell #2-1 first."

#@markdown Values for atomic partial charges and Lennard-Jones parameters of non-standard residues are only approximate. <br/>
#@markdown It is therefore highly recommended to construct a buffer region of XTB/ML atoms, which minimizes the error from the non-bonded interactions between MM atoms and XTB/ML atoms. <br/>
#@markdown Adjust the cutoff distance around each nonstandard residue to set the buffer region. <br/>
#@markdown **Do not run this cell more than once in a row!**

def make_xml_for_custom_residues(fftype, ff, model, residue_infos)->bool:
    """
    Generate a template XML file for custom residues.
    """
    ncAA = []
    for value in residue_infos.values():
        if value['ff'] == 'Custom':
            ncAA.append(value['name'])
    if len(ncAA) == 0:
        return False

    templates, residues = ff.generateTemplatesForUnmatchedResidues(model.topology)

    with open('template.xml', 'w') as f:
        f.write("<ForceField>\n")
        f.write(" <AtomTypes>\n")
        f.write(" </AtomTypes>\n")
        f.write(" <Residues>\n")
        for i, temp in enumerate(templates):
            if temp.name in ncAA:
                f.write(f'  <Residue name="{temp.name}">\n')
                for atom in residues[i].atoms():
                    f.write(f'   <Atom name="{atom.name}" type="Enter_atom_type"/>\n')
                for bond in residues[i].internal_bonds():
                    f.write(f'   <Bond atomName1="{bond[0].name}" atomName2="{bond[1].name}"/>\n')
                for bond in residues[i].external_bonds():
                    if bond[0] in list(residues[i].atoms()):
                        atom = bond[0]
                    elif bond[1] in list(residues[i].atoms()):
                        atom = bond[1]
                    f.write(f'   <ExternalBond atomName="{atom.name}"/>\n')
                f.write('  </Residue>\n')
        f.write(" </Residues>\n")
        f.write(" <HarmonicBondForce>\n")
        f.write(" </HarmonicBondForce>\n")
        f.write(" <HarmonicAngleForce>\n")
        f.write(" </HarmonicAngleForce>\n")
        if fftype == 'CHARMM36':
            f.write(' <AmoebaUreyBradleyForce>\n')
            f.write(" </AmoebaUreyBradleyForce>\n")
        if fftype == 'Amber14':
            f.write(' <PeriodicTorsionForce ordering="amber">\n')
        elif fftype == 'CHARMM36':
            f.write(' <PeriodicTorsionForce ordering="charmm">\n')
        f.write(" </PeriodicTorsionForce>\n")
        if fftype == 'CHARMM36':
            f.write(' <CMAPTorsionForce>\n')
            f.write(" </CMAPTorsionForce>\n")
        if fftype == 'Amber14':
            f.write(' <NonbondedForce coulomb14scale="0.8333333333333334" lj14scale="0.5">\n')
        elif fftype == 'CHARMM36':
            f.write(' <NonbondedForce coulomb14scale="1.0" lj14scale="1.0" useDispersionCorrection="False">\n')
        f.write('  <UseAttributeFromResidue name="charge"/>\n')
        f.write(' </NonbondedForce>\n')
        if fftype == 'CHARMM36':
            f.write(' <LennardJonesForce lj14scale="1.0" useDispersionCorrection="False">\n')
            f.write(' </LennardJonesForce>\n')
        f.write("</ForceField>\n")

    files.download('template.xml')
    sleep(5)
    return True


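def find_terminal_form(pdb)->str:
    """
    Classify the residue in the given PDB file by its backbone pattern:
    'N' (N-terminal), 'C' (C-terminal), 'bb-p' (internal peptide backbone),
    or '-' (no peptide backbone found).
    """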

    mol = Chem.MolFromPDBFile(pdb, removeHs=False)

    nterm = Chem.MolFromSmarts('[NH3+][CH1]C(=O)')
    nterm_atoms = list(mol.GetSubstructMatch(nterm))
    if len(nterm_atoms) > 0:
        return 'N'

    cterm = Chem.MolFromSmarts('[NH1][CH1]C(=O)[O-]')
    cterm_atoms = list(mol.GetSubstructMatch(cterm))
    if len(cterm_atoms) > 0:
        return 'C'

    pbackbone = Chem.MolFromSmarts('[NH1][CH1]C(=O)')
    pbbatoms = list(mol.GetSubstructMatch(pbackbone))
    if len(pbbatoms) > 0:
        return 'bb-p'

    return '-'


def assign_partial_charges(name, residue_infos)->dict:

    pdb = residue_infos[name]['file']
    subprocess.run(f'obabel {pdb} -O {name}.mol2 --partialcharge mmff94', shell=True)

    charges = {}

    with open(f'{name}.mol2', 'r') as f:
        lines = f.readlines()
        atom_flag = False
        for line in lines:
            if line.strip().startswith('@<TRIPOS>ATOM'):
                atom_flag = True
                continue
            elif line.strip().startswith('@<TRIPOS>BOND'):
                atom_flag = False
                break
            if atom_flag:
                data = line.strip().split( )
                charges[data[1]] = round(float(data[-1]), 4)

    term = find_terminal_form(pdb)
    total = sum(charges.values())
    offset = round(total) - total # shift that brings the net charge to the nearest integer

    # Fixed backbone charges for each terminal form, plus the correction applied to the offset
    if term == "N":
        fixed = {'N':-0.3, 'H1':0.33, 'H2':0.33, 'H3':0.33, 'CA':0.21, 'HA':0.1, 'C':0.51, 'O':-0.51}
        correction = -1 # N-terminus has +1 charge
    elif term == "C":
        fixed = {'C':0.34, 'O':-0.67, 'OXT':-0.67}
        correction = 2 # C-terminus has -1 charge & formal charge for amide N is -1
    elif term == 'bb-p':
        fixed = {'N':-0.47, 'H':0.31, 'CA':0.07, 'HA':0.09, 'C':0.51, 'O':-0.51}
        correction = 1 # formal charge for amide N is -1, so correct it.
    else:
        fixed = {}
        correction = 0

    # Overwrite the backbone charges in the dict itself (rebinding the loop
    # variable of charges.items() would leave the dict unchanged).
    before = 0.0
    for key in charges:
        if key in fixed:
            before += charges[key]
            charges[key] = fixed[key]
    offset += before + correction

    # Spread the remaining offset evenly over all other atoms
    n_rest = len(charges) - len(fixed)
    if n_rest > 0:
        for key in charges:
            if key not in fixed:
                charges[key] += offset/n_rest

    return charges



def assign_LJ_parameters(fftype, ff, model, atomtype, element)->tuple:

    resname = atomtype.split('_')[0]
    atomname = atomtype.split('_')[1]

    amber14LJ = {'C':(0.25104, 0.402), 'O':(0.79496, 0.333), 'N':(0.54392, 0.386)}
    # parameters from
    # M. Freindorf, Y. Shao, T. R. Furiani, and J. Kong.
    # "Lennard-Jones parameters for the combined QM/MM method using the B3LYP/6-31G*/AMBER potential"
    # J. Comput. Chem. 2005, 26(12): 1270-1278.
    # https://doi.org/10.1002/jcc.20264

    gaff2AB = {'S':(1.9825, 0.2824), 'P':(2.0732, 0.2295), 'F':(1.7029, 0.0832), 'Cl':(1.9452, 0.2638),
                 'Br':(2.0275, 0.3932), 'I':(2.1558, 0.4955)}
    amberions = {'Al':'AL', 'Ag':'Ag', 'Ba':'BA', 'Br':'BR', 'Be':'Be', 'Ca':'CA', 'Cd':'CD',
                 'Cl':'CL', 'Co':'CO', 'Cs':'CS', 'Cu':'CU', 'Dy':'Dy', 'Er':'Er',
                 'F':'F', 'Gd':'GD3', 'Hg':'HG', 'Hf':'Hf',
                 'In':'IN', 'I':'IOD', 'K':'K', 'La':'LA', 'Li':'LI', 'Lu':'LU', 'Mg':'MG',
                 'Mn':'MN', 'Na':'NA', 'Ni':'NI', 'Nd':'Nd', 'Pb':'PB', 'Pd':'PD', 'Pr':'PR',
                 'Pt':'PT', 'Pu':'Pu', 'Rb':'RB', 'Ra':'Ra', 'Sm':'Sm', 'Sr':'SR', 'Sn':'Sn',
                 'Tb':'TB', 'Th':'Th', 'Tl':'Tl', 'Tm':'Tm', 'U':'U4+', 'V':'V2+', 'Y':'Y',
                 'Yb':'YB2', 'Zn':'ZN', 'Zr':'Zr'}
    charmm36LJ = {'N':(0.8368, 0.329632525712), 'F':(0.56484, 0.290432982114), 'Si':(2.5104, 0.391995435982),
                  'Al':(0.4460144, 0.039377723342), 'P':(2.44764, 0.3830864488), 'Cl':(1.435112, 0.34032331033),
                  'Se':(2.606632, 0.387540942391), 'Br':(2.25936, 0.356359487256), 'I':(2.17568, 0.399122625727),
                  'Be':(0.4250944, 0.007127189745), 'Sc':(0.589944, 0.195445360786), 'Ti':(1.7016328, 0.344011631023),
                  'V':(0.6732056, 0.196656983042), 'Cr':(0.6974728, 0.188727984451), 'Mn':(0.7050040, 0.214474957405),
                  'Fe':(0.667348, 0.192451941093), 'Co':(0.6439176, 0.165279530189), 'Ni':(0.5999856, 0.139033653953),
                  'Cu':(0.7748768, 0.153590939007), 'Ga':(0.5045904, 0.040428983829), 'Sr':(0.9572992, 0.330683786199),
                  'Y':(0.6928704, 0.251839249644), 'Rh':(0.537644, 0.082194315736), 'Pd':(0.7288528, 0.165190440318),
                  'Ag':(1.200808, 0.216363662688), 'In':(0.6163032, 0.154357111905), 'Sn':(0.7861736, 0.295422014935),
                  'La':(0.7932864, 0.315859231529), 'Ce':(0.7702744, 0.304954631219), 'Pr':(0.7627432, 0.295618012653),
                  'Nd':(0.7497728, 0.288633366703), 'Pm':(0.74266, 0.29527947114), 'Sm':(1.0242432,0.330790694046),
                  'Eu':(0.9907712,0.327922000173), 'Gd':(0.7209032, 0.26816051416), 'Tb':(0.7083512,0.262815121851),
                  'Dy':(0.6999832,0.257362821696), 'Ho':(0.6928704, 0.247794569464), 'Er':(0.6845024,0.243625163463),
                  'Tm':(0.677808, 0.240293202257), 'Yb':(0.8593936, 0.286691207498), 'Lu':(0.6644192, 0.24075646959),
                  'Pt':(0.6824104, 0.142027073646), 'Au':(1.4907592, 0.112502690127), 'Hg':(0.8593936, 0.21057282102),
                  'Tl':(0.6815736, 0.154891651136), 'Pb':(1.0016496, 0.313239989298), 'Bi':(0.77404, 0.243785525232),
                  'Ra':(1.2631496,0.372377846208), 'U':(0.7932864, 0.30311937986), 'Pu':(0.7702744,0.297506717936)}
    charmmions = {'Li':'LIT', 'Na':'SOD', 'Mg':'MG', 'K':'POT', 'Ca':'CAL', 'Rb':'RUB',
                  'Cs':'CES', 'Ba':'BAR', 'Zn':'ZN', 'Cd':'CAD', 'Cl':'CLA'}

    if fftype == 'Amber14':
        if element == 'H':
            for res in model.topology.residues():
                if res.name == resname:
                    for bond in res.bonds():
                        if bond[0].name == atomname:
                            parent = bond[1].element.symbol
                            break
                        elif bond[1].name == atomname:
                            parent = bond[0].element.symbol
                            break
                        else:
                            continue
                    if parent == 'C': # nonpolar H
                        lj = (0.12552, 0.222)
                    else:             # polar H
                        lj = (0.12552, 0.106)
                    break
            return lj
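        elif element in gaff2AB: # parm99.dat in Amber forcefield
            # Values are taken to be stored AMBER-style as (Rmin/2 in Angstrom, epsilon in kcal/mol)
            rmin_half = gaff2AB[element][0]
            eps_kcal = gaff2AB[element][1]
            sigma = 2*rmin_half/(2**(1/6))/10  # Rmin/2 in Angstrom -> sigma in nm
            epsilon = eps_kcal*4.184           # kcal/mol -> kJ/mol
            return (epsilon, sigma)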
        elif element in amber14LJ:
            return amber14LJ[element]
        elif element in amberions: # use parameters in ions definition
            top = Topology()
            chain = top.addChain(id='X')
            residue = top.addResidue(amberions[element], chain)
            top.addAtom(element, Element.getBySymbol(element), residue)
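            # residueTemplates must be keyed by the Residue object; dict(residue=...) would create the string key 'residue'
            system = ff.createSystem(top, residueTemplates={residue: amberions[element]})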
            for f in system.getForces():
                if f.getName() == 'NonbondedForce':
                    q, s, e = f.getParticleParameters(0)
                    break
            return (e,s)
    elif fftype == 'CHARMM36':
        if element == 'H':
            for res in model.topology.residues():
                if res.name == resname:
                    for bond in res.bonds():
                        if bond[0].name == atomname:
                            parent = bond[1].element.symbol
                            break
                        elif bond[1].name == atomname:
                            parent = bond[0].element.symbol
                            break
                        else:
                            continue
                    if parent == 'C': # nonpolar H
                        lj = (0.192464, 0.040001352445)
                    else:             # polar H
                        lj = (0.092048, 0.235197261589)
                    return lj
        elif element == 'C':
            for res in model.topology.residues():
                if res.name == resname:
                    count = 0
                    for bond in res.bonds():
                        if bond[0].name == atomname or bond[1].name == atomname:
                            count += 1
                    if count < 4: # aromatic / carbonyl C
                        lj = (0.29288, 0.355005321205)
                    else:         # aliphatic C
                        lj = (0.08368, 0.405358916754)
                    return lj
        elif element == 'S':
            for res in model.topology.residues():
                if res.name == resname:
                    count = 0
                    for bond in res.bonds():
                        if bond[0].name == atomname or bond[1].name == atomname:
                            count += 1
                    if count > 1: # thioether S
                        lj = (1.8828, 0.356359487256)
                    else:         # thiolate S
                        lj = (1.96648, 0.391995435982)
                    return lj
        elif element == 'O':
            for res in model.topology.residues():
                if res.name == resname:
                    count = 0
                    for bond in res.bonds():
                        if bond[0].name == atomname or bond[1].name == atomname:
                            count += 1
                    if count > 1: # ether O
                        lj = (0.4184, 0.293996576986)
                    else:         # anionic O
                        lj = (0.50208, 0.302905564168)
                    return lj
        elif element in charmm36LJ:
            return charmm36LJ[element]

        elif element in charmmions: # use parameters in ions definition
            top = Topology()
            chain = top.addChain(id='X')
            residue = top.addResidue(charmmions[element], chain)
            top.addAtom(element, Element.getBySymbol(element), residue)
            system = ff.createSystem(top)
            for f in system.getForces():
                if f.getName() in ['NonbondedForce', 'LennardJonesForce']:
                    q, s, e = f.getParticleParameters(0)
                    break
            return (e,s)

    # For elements not defined in AMBER/CHARMM, use OpenKIM Universal LJ parameters
    # DOI : https://doi.org/10.25950/962b4967
    params_file = 'LennardJones612_UniversalShifted.params' # same relative path for the check, the download, and the read
    if not os.path.exists(params_file):
        subprocess.run('wget https://openkim.org/files/MO_959249795837_003/LennardJones612_UniversalShifted.params', shell=True)
    with open(params_file, 'r') as f:
        lines = f.readlines()
        for line in lines:
            if line and not line.startswith('#'):
                data = line.strip().split( )
                if len(data) == 5 and data[0] == element:
                    epsilon = float(data[3])
                    sigma = float(data[4])
                    break
    return (epsilon*96.485, sigma/10) # eV to kJ/mol, Angstrom to nm



def assign_parameters(fftype, ff, model, residue_infos)->list:
    """
    Register the new atom and residue types with the macromolecular forcefield.
    """
    ambertype = {'N':'protein-N', 'CA':'protein-CX', 'C':'protein-C', 'O':'protein-O', 'H':'protein-H', 'HA':'protein-H1'}
    charmmtype = {'N':'NH1', 'CA':'CT1', 'C':'C', 'O':'O', 'H':'H', 'HA':'HB1'}

    templates, residues = ff.generateTemplatesForUnmatchedResidues(model.topology)
    ncAA = []
    for value in residue_infos.values():
        if value['ff'] in ['GFN-FF', 'MACE-off23']:
            ncAA.append(value['name'])
    if len(ncAA) == 0:
        return []

    new_type = []
    unk_index = 0
    uaa_res = []
    lj_parms = {}

    for i, temp in enumerate(templates):
        if temp.name not in ncAA:
            for bond in residues[i].external_bonds(): # Search for normal atoms connected to the nonstandard atoms
                if bond[0].residue is residues[i]:    # e.g. metal coordinating nitrogen in Histidine
                    ext_atom = bond[1]
                    int_atom = bond[0]
                elif bond[1].residue is residues[i]:
                    ext_atom = bond[0]
                    int_atom = bond[1]
                else:
                    continue
                if int_atom.name not in ambertype.keys(): # Remove bond from the topology
                    if ext_atom.residue.name in [value['name'] for value in residue_infos.values()]:
                        print(f'WARNING: Found a bond between atom {int_atom.name} in {int_atom.residue} and atom {ext_atom.name} in {ext_atom.residue}!')
                        print(f'You should consider adding {int_atom.residue} to the buffer region.')
                        model.delete([bond]) # delete() expects a list; a bare bond would be unpacked into its two atoms
            continue # some normal residues (especially terminal ones) show up in this list, so skip them

        lig_flag = residue_infos[temp.name]['lig']

        if residue_infos[temp.name]['ff'] == 'GFN-FF': # charges are only read for GFN-FF residues below
            charges = assign_partial_charges(temp.name, residue_infos)

        for atom in temp.atoms:
            if residue_infos[temp.name]['ff'] != 'GFN-FF':
                atom.parameters = dict(charge=0.0)
            else:
                atom.parameters = dict(charge=charges[atom.name])

            if not lig_flag and fftype == 'Amber14' and atom.name in ambertype.keys():
                name = atom.name
                atom.type = ambertype[atom.name] # For backbone atoms, follow AMBER scheme

            elif not lig_flag and fftype == 'CHARMM36' and atom.name in charmmtype.keys():
                name = atom.name
                atom.type = charmmtype[atom.name] # For backbone atoms, follow CHARMM scheme

            else:
                name = temp.name+'_'+atom.name # For other atoms, register new atom type
                atom.type = name
                mass = atom.element.mass.value_in_unit(amu)
                element = atom.element.symbol
                ff.registerAtomType({'name':name, 'class':name, 'mass':mass, 'element':element})
                new_type.append(name)
                if residue_infos[temp.name]['ff'] != 'GFN-FF':
                    lj_parms[name] = (0.0, 1.0)
                else:
                    lj_parms[name] = assign_LJ_parameters(fftype, ff, model, name, element)

        try:
            ff.registerResidueTemplate(temp)
        except Exception:
            # search for duplicate residue name, and if present, rename it
            print(f'There already is a residue template named {temp.name} in {fftype} forcefield!')
            newName = 'UK'+str(unk_index)
            print(f'Changing the residue name from {temp.name} to {newName}.')
            residues[i].name = newName
            ncAA[ncAA.index(temp.name)] = newName
            temp.name = newName
            ff.registerResidueTemplate(temp)
            unk_index += 1

        for res in model.topology.residues():
            if res.name == temp.name:
                uaa_res.append(res)


    with open('userff.xml', 'w') as f: # Write a custom XML file holding the LJ parameters of the new atom types
        f.write("<ForceField>\n")
        if fftype == "Amber14":
            f.write(' <NonbondedForce coulomb14scale="0.8333333333333334" lj14scale="0.5">\n')
            f.write('  <UseAttributeFromResidue name="charge"/>\n')
            for at in new_type:
                f.write(f'  <Atom class="{at}" sigma="{lj_parms[at][1]}" epsilon="{lj_parms[at][0]}"/>\n')
        elif fftype == "CHARMM36":
            f.write(' <NonbondedForce coulomb14scale="1.0" lj14scale="1.0" useDispersionCorrection="False">\n')
            f.write('  <UseAttributeFromResidue name="charge"/>\n')
            for at in new_type:
                f.write(f'  <Atom class="{at}" epsilon="{lj_parms[at][0]}" sigma="{lj_parms[at][1]}"/>\n')
        f.write(' </NonbondedForce>\n')
        if fftype == 'CHARMM36':
            f.write(' <LennardJonesForce lj14scale="1.0" useDispersionCorrection="False">\n')
            for at in new_type:
                f.write(f'  <Atom epsilon="{lj_parms[at][0]}" sigma="{lj_parms[at][1]}" type="{at}"/>\n')
                #f.write(f'  <Atom epsilon="0.0" epsilon14="0.0" sigma="1.0" sigma14="1.0" type="{at}"/>\n')
            f.write(' </LennardJonesForce>\n')
        f.write("</ForceField>\n")

    ff.loadFile('userff.xml')
    return uaa_res


def update_qm_region(model, atomlist)->int:
    """
    As in ONIOM, the MM/XTB or MM/ML boundary should only cut nonpolar bonds.
    If a boundary bond is not a carbon-carbon bond, expand the XTB/ML region across it.
    Returns the number of atoms added to atomlist in this pass.
    """
    atoms_to_add = []
    counter = 0
    for atom1 in atomlist:
        for bond in model.topology.bonds():
            if bond[0] == atom1:
                atom2 = bond[1]
            elif bond[1] == atom1:
                atom2 = bond[0]
            else:
                continue
            if atom2 in atomlist:
                continue
            elif atom1.element.symbol != 'C' or atom2.element.symbol != 'C':
                atoms_to_add.append(atom2)
                counter += 1
            elif atom1.residue.name in ['HIS', 'HIP', 'HID', 'HIE', 'PHE', 'TYR', 'TRP']:
                if atom1.name not in  ['C', 'O', 'N', 'H', 'CA', 'HA', 'CB', 'HB1', 'HB2', 'HB3']:
                    if atom2.name != 'CB':
                        atoms_to_add.append(atom2)
                        counter += 1
    atomlist += atoms_to_add
    return counter


def find_neighbors(model, res_atoms, cutoff)->list:
    """
    Find atoms within a cutoff distance from the given residue.
    """
    contact_atoms = []
    for i in res_atoms:
        v1 = model.positions[i]
        for j, v2 in enumerate(model.positions):
            if j in res_atoms:
                continue
            if np.linalg.norm(np.array([v1.x-v2.x, v1.y-v2.y, v1.z-v2.z])) < cutoff:
                contact_atoms.append(j)

    # Build the index set once instead of on every loop iteration
    selected = set(res_atoms) | set(contact_atoms)
    qm_atoms = [atom for atom in model.topology.atoms() if atom.index in selected]

    counter = -1
    while counter != 0:
        counter = update_qm_region(model, qm_atoms)

    return [a.index for a in qm_atoms]


if len(uaas) == 0:
    print('No need to run this cell!')
else:
    with open(work+'non_standard_residues.txt', 'w') as f:
        f.write('# User-defined Non-standard Residues\n')
        f.write('#name\t#type\t#forcefield\n')
        for uaa in uaa_info.values():
            f.write(f"{uaa['name']}\t{'ligand' if uaa['lig'] else 'residue'}\t{uaa['ff']}\n")

    if make_xml_for_custom_residues(forcefield, ff, model, uaa_info):
        print('Upload an XML file for the non-standard residue definition.')
        custom_xml = files.upload()
        ff.loadFile(next(iter(custom_xml)))

    uaa_res = assign_parameters(forcefield, ff, model, uaa_info)

    if len(uaa_res) != 0:
        uaa_res_names = [f'{res.name}_{res.index}{res.chain}' for res in uaa_res]
        # Interactively determine the group of atoms to be treated differently
        res_select = widgets.Dropdown(options=uaa_res_names, value=uaa_res_names[0], description='residue', continuous_update=True)
        uaa_atom_info = []
        for res in uaa_res:
            atoms = []
            for a in res.atoms():
                atoms.append(a.index)
            uaa_atom_info.append({'res':res, 'atoms':atoms, 'cutoff':0.0, 'neighbor':find_neighbors(model, atoms, 0.5), 'charge':0, 'spin':1})

        cutoff_select = widgets.FloatSlider(value=0.0, min=0.0, max=1.5, step=0.1, description='cutoff(nm)', continuous_update=True)
        charge_input = widgets.IntText(value=0, description='Total charge:', disabled=False, continuous_update=True)
        spin_input = widgets.IntText(value=1, description='Spin :', disabled=False, continuous_update=True)

        view4 = nv.NGLWidget()
        view4._set_size('750px','500px')
        view4.add_component(nv.FileStructure(work+"starting_structure.pdb"), defaultRepresentation=False)
        view4.add_cartoon("protein or nucleic")
        view4.add_line()
        uaa_atom_info[0]['cutoff'] = cutoff_select.value
        neighbor = find_neighbors(model, uaa_atom_info[0]['atoms'], cutoff_select.value)
        uaa_atom_info[0]['neighbor'] = neighbor
        uaa_atom_info[0]['charge'] = charge_input.value
        uaa_atom_info[0]['spin'] = spin_input.value
        neighborhood = '@'+','.join([str(i) for i in neighbor])
        view4.add_ball_and_stick(neighborhood)
        view4.center(neighborhood)

        def update_res(change):
            res_index = uaa_res_names.index(change.new)
            view4.clear_representations()
            view4.add_cartoon("protein or nucleic")
            view4.add_line()
            uaa_atom_info[res_index]['charge'] = charge_input.value
            uaa_atom_info[res_index]['spin'] = spin_input.value
            neighbor = find_neighbors(model, uaa_atom_info[res_index]['atoms'], cutoff_select.value)
            neighborhood = '@'+','.join([str(i) for i in neighbor])
            view4.add_ball_and_stick(neighborhood)
            view4.center(neighborhood)

        def update_cutoff(change):
            res_index = uaa_res_names.index(res_select.value)
            view4.clear_representations()
            view4.add_cartoon("protein or nucleic")
            view4.add_line("not water")
            uaa_atom_info[res_index]['cutoff'] = change.new
            neighbor = find_neighbors(model, uaa_atom_info[res_index]['atoms'], change.new)
            neighborhood = '@'+','.join([str(i) for i in neighbor])
            view4.add_ball_and_stick(neighborhood)
            view4.center(neighborhood)
            uaa_atom_info[res_index]['neighbor'] = neighbor
            uaa_atom_info[res_index]['charge'] = charge_input.value
            uaa_atom_info[res_index]['spin'] = spin_input.value

        def update_charge(change):
            res_index = uaa_res_names.index(res_select.value)
            uaa_atom_info[res_index]['charge'] = change.new

        def update_spin(change):
            res_index = uaa_res_names.index(res_select.value)
            uaa_atom_info[res_index]['spin'] = change.new

        res_select.observe(update_res, names='value')
        cutoff_select.observe(update_cutoff, names='value')
        charge_input.observe(update_charge, names='value')
        spin_input.observe(update_spin, names='value')

        display(res_select, cutoff_select, view4, charge_input, spin_input)
    else:
        for uaa in uaa_info.values():
            assert uaa['lig'], 'Did you run this cell more than once? Run again from the cell# 2-1.'
        print('No non-ligand residue found!')

Dropdown(description='residue', options=('CGA_57<Chain 0>',), value='CGA_57<Chain 0>')

FloatSlider(value=0.0, description='cutoff(nm)', max=1.5)

NGLWidget()
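
`find_neighbors` above compares every residue atom with every atom in the model. As a minimal sketch of the same cutoff search, the function below bins atoms into cubic cells so that only nearby atoms are compared. It is an illustration and not the notebook's implementation: the name `atoms_within_cutoff` and the plain-tuple coordinate format are assumptions, and, like the original, it ignores periodic images.

```python
import math
from collections import defaultdict


def atoms_within_cutoff(positions, selected, cutoff):
    """Return sorted indices of atoms outside `selected` that lie closer
    than `cutoff` to at least one selected atom.

    positions: list of (x, y, z) tuples in nm (hypothetical input format)
    selected:  iterable of atom indices
    """
    selected = set(selected)
    if cutoff <= 0:
        return []

    def cell_of(point):
        # Integer cell coordinates for a cubic grid with edge length `cutoff`
        return tuple(math.floor(c / cutoff) for c in point)

    # Bin every atom into its grid cell.
    cells = defaultdict(list)
    for index, point in enumerate(positions):
        cells[cell_of(point)].append(index)

    found = set()
    for i in selected:
        cx, cy, cz = cell_of(positions[i])
        # Atoms within the cutoff can only sit in the 27 surrounding cells.
        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                for dz in (-1, 0, 1):
                    for j in cells.get((cx + dx, cy + dy, cz + dz), ()):
                        if j in selected:
                            continue
                        if math.dist(positions[i], positions[j]) < cutoff:
                            found.add(j)
    return sorted(found)


example_positions = [(0.0, 0.0, 0.0), (0.1, 0.0, 0.0), (0.3, 0.0, 0.0),
                     (1.0, 1.0, 1.0), (0.0, 0.25, 0.0)]
print(atoms_within_cutoff(example_positions, [0, 1], 0.25))
```

The comparison is strict (`<`), matching the cell above, so an atom sitting exactly at the cutoff distance is not selected.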

IntText(value=0, continuous_update=True, description='Total charge:')

IntText(value=1, continuous_update=True, description='Spin :')
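
The region-growing rule in `update_qm_region` can be illustrated on a toy molecule. The sketch below is a simplified version under stated assumptions: it keeps only the rule that a boundary bond must join two carbon atoms, it leaves out the aromatic side-chain special case, and the function name and the element/bond list inputs are hypothetical.

```python
def expand_to_cc_boundary(elements, bonds, region):
    """Grow `region` until every bond that leaves it joins two carbon atoms.

    elements: list of element symbols, indexed by atom
    bonds:    list of (i, j) index pairs
    region:   iterable of atom indices to start from
    """
    region = set(region)
    neighbors = {i: set() for i in range(len(elements))}
    for a, b in bonds:
        neighbors[a].add(b)
        neighbors[b].add(a)

    # Work list of region atoms whose bonds still need to be inspected.
    frontier = list(region)
    while frontier:
        inside = frontier.pop()
        for outside in neighbors[inside] - region:
            # A boundary bond is acceptable only if both ends are carbon;
            # otherwise pull the outside atom into the region and inspect it too.
            if elements[inside] != 'C' or elements[outside] != 'C':
                region.add(outside)
                frontier.append(outside)
    return sorted(region)


# Ethanol: C0(H4,H5,H6)-C1(H7,H8)-O2-H3
ethanol_elements = ['C', 'C', 'O', 'H', 'H', 'H', 'H', 'H', 'H']
ethanol_bonds = [(0, 1), (1, 2), (2, 3), (0, 4), (0, 5), (0, 6), (1, 7), (1, 8)]
print(expand_to_cc_boundary(ethanol_elements, ethanol_bonds, [2]))
```

Starting from the hydroxyl oxygen, the region grows across the O-H and C-O bonds and stops at the C-C bond, which is the only bond left crossing the boundary.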

### 2-3. Prepare Simulation

In [None]:
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert 'uaas' in globals(), "Please run from the cell# 2-1."

#@markdown Set the ionic strength of the solution in mol/L.
ion_conc = 0.15 #@param {type:"slider", min:0.0, max:2.0, step:0.05}
#@markdown For explicit water models, set the additional parameters below.

#@markdown Set the thickness of the solvent pad around the protein in nanometers.
padding = 1 #@param {type:"slider", min:1.0, max:5.0, step:0.1}
cation = "Na+" #@param ["Li+", "Na+", "K+", "Rb+", "Cs+"]
anion = "Cl-" #@param ["F-", "Cl-", "Br-", "I-"]
#@markdown Set below constraints / increase the hydrogen mass to allow larger integration step sizes.
rigidWater = True #@param {type:"boolean"}
constraints = "None" #@param ["None", "HBonds", "AllBonds", "HAngles"]
hydrogenMass = 2 #@param {type:"slider", min:1.0, max:3.0, step:0.1}
#@markdown Set the precision of the calculation: (cost-efficient) single -> mixed -> double (high precision)
precision = "single" #@param ["single", "mixed", "double"]

assert forcefield != "CHARMM36" or all(p['ff'] != 'MACE-off23' for p in uaa_info.values()), "Currently, ML forcefield can only be coupled with AMBER forcefield!"
assert forcefield != "CHARMM36" or padding >= 1.2, "The minimum padding distance must be larger than the cutoff distance of 1.2 nm!"
assert all(p['ff'] != 'MACE-off23' for p in uaa_info.values()) or water != '', 'ML forcefield does not support implicit solvation model!'

def get_openff_ligand(work, ligand_info):
    """
    Register ligand template to the macromolecular forcefield.
    """
    name = ligand_info["name"]

    if os.path.exists(work+name+'.sdf'):
        sdf_file = work + name+'.sdf'
    elif os.path.exists(name+'_ideal.sdf'):
        sdf_file = name+'_ideal.sdf'
    else:
        subprocess.run(f'wget https://files.rcsb.org/ligands/download/{name}_ideal.sdf', shell=True)
        sdf_file = f'{name}_ideal.sdf'

    SDF = Chem.rdmolfiles.SDMolSupplier(sdf_file)[0]
    PDB = AllChem.MolFromPDBFile(f'{name}.pdb', removeHs=True)
    ligand_noH = AllChem.AssignBondOrdersFromTemplate(SDF, PDB)
    ligand = AllChem.MolFromPDBFile(f'{name}.pdb', removeHs=False)

    for bond in ligand_noH.GetBonds():
        pos1 = ligand_noH.GetConformer().GetAtomPosition(bond.GetBeginAtomIdx())
        pos2 = ligand_noH.GetConformer().GetAtomPosition(bond.GetEndAtomIdx())
        for i in range(ligand.GetNumAtoms()):
            pos = ligand.GetConformer().GetAtomPosition(i)
            if all(pos[j] == pos1[j] for j in range(3)):
                aid1 = i
            elif all(pos[j] == pos2[j] for j in range(3)):
                aid2 = i
            else:
                continue
        ligand.GetBondBetweenAtoms(aid1, aid2).SetBondType(bond.GetBondType())

    Chem.SanitizeMol(ligand)

    Chem.rdmolops.AssignStereochemistryFrom3D(ligand)
    openff_ligand = Molecule.from_rdkit(ligand, hydrogens_are_explicit=True)
    openff_ligand.name = name
    openff_ligand.assign_partial_charges('mmff94')

    with open(f'{work}{name}.json', 'w') as f:
        f.write(openff_ligand.to_json())

    return openff_ligand


def register_ml_residues(system, topology, pbc, residue_infos, atom_infos):
    """
    Create MM-ML or MM-XTB mixed system
    """

    xtb_atoms = []
    total_charge = 0
    total_e = 0
    ml_atoms = []

    for key1, value1 in residue_infos.items():
        for res in atom_infos:
            if res['res'].name == key1:
                if value1['ff'] == 'GFN-FF':
                    xtb_atoms.extend(res['neighbor'])
                    total_charge += res['charge']
                    total_e += float(res['spin']-1)/2
                elif value1['ff'] == 'MACE-off23':
                    ml_atoms.extend(res['neighbor'])

    xtb_atoms = set(xtb_atoms)
    ml_atoms = set(ml_atoms)

    # Atomic numbers of the XTB atoms, in the same iteration order as xtb_atoms
    atomic_numbers = {atom.index: atom.element.atomic_number for atom in topology.atoms()}
    element_list = [atomic_numbers[a] for a in xtb_atoms]

    if len(xtb_atoms&ml_atoms)>0:
        print('An atom cannot be allocated to two or more different forcefields!')
        print('Unify your small molecule forcefields or decrease cutoffs to remove overlapping atoms.')
        return None

    if len(ml_atoms) > 0:
        potential = MLPotential('mace-off23-small')
        print("Creating MM-ML Mixed system...")
        warnings.filterwarnings('ignore')
        system2 = potential.createMixedSystem(topology, system, ml_atoms,
                                              removeConstraints=False,
                                              implementation='nnpops',
                                              precision='single')
        print("done.")
    else:
        system2 = system

    if len(xtb_atoms) > 0:
        print("Creating MM-XTB Mixed system...")
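        # total_e holds the summed spin S = sum((multiplicity - 1) / 2),
        # so the multiplicity of the whole XTB region is 2S + 1.
        system2.addForce(XtbForce(XtbForce.GFNFF, total_charge, int(round(2*total_e + 1)),
                                  pbc, list(xtb_atoms), element_list))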
        print("done.")

    return system2


def save_settings(work, forcefield, watermodel, ion_conc, pbc, cation, anion, padding, rigidWater, constraints, hydrogenMass, precision):
    """
    Save simulation settings to settings.txt file.
    """
    settings = f"""
    Forcefield : {forcefield}
    Watermodel : {watermodel}
    Ionic strength (in mol/L) : {ion_conc}
    Periodic Boundary Condition : {pbc}
    Cation : {cation}
    Anion : {anion}
    Padding (in nm) : {padding}
    Rigidwater : {rigidWater}
    Constraints : {constraints}
    Hydrogen mass (in amu) : {hydrogenMass}
    Precision : {precision}
    """
    with open(work+"settings.txt", 'w') as f:
        f.write(settings)
    return



gaff = []
smirnoff = []
espaloma = []

for key, value in uaa_info.items():
    if value['ff'] == 'GAFF-2.11':
        gaff.append(get_openff_ligand(work, value))
    elif value['ff'] == 'Sage-2.2.1':
        smirnoff.append(get_openff_ligand(work, value))
    elif value['ff'] == 'espaloma-0.3.2':
        espaloma.append(get_openff_ligand(work, value))

if len(gaff) > 0:
    gaff_g = GAFFTemplateGenerator(molecules=gaff)
    ff.registerTemplateGenerator(gaff_g.generator)
if len(smirnoff) > 0:
    smirnoff_g = SMIRNOFFTemplateGenerator(molecules=smirnoff)
    ff.registerTemplateGenerator(smirnoff_g.generator)
if len(espaloma) > 0:
    espaloma_g = EspalomaTemplateGenerator(molecules=espaloma)
    ff.registerTemplateGenerator(espaloma_g.generator)

# createSystem() expects the OpenMM constants (None, HBonds, AllBonds, HAngles), not their names,
# so translate the form value here. `constraints` keeps the string for settings.txt.
constraint_type = {"None": None, "HBonds": HBonds, "AllBonds": AllBonds, "HAngles": HAngles}[constraints]

if water != '':
    print("Adding solvent...", end='')
    model.addSolvent(ff, model=water, padding=padding*nanometers, positiveIon=cation,
                     negativeIon=anion, ionicStrength=ion_conc*molar, neutralize=True)
    print("done.")

    print("Creating system...", end='')

    if forcefield == "Amber14":
        system = ff.createSystem(model.topology, nonbondedMethod=PME,
                                 nonbondedCutoff=1.0*nanometer, constraints=constraint_type,
                                 rigidWater=rigidWater, removeCMMotion=True,
                                 hydrogenMass=float(hydrogenMass)*amu)
    elif forcefield == "CHARMM36":
        # The CHARMM forcefield recommends a switch distance of 1.0 nm and a cutoff distance of 1.2 nm.
        system = ff.createSystem(model.topology, nonbondedMethod=PME,
                                 nonbondedCutoff=1.2*nanometer, constraints=constraint_type,
                                 rigidWater=rigidWater, removeCMMotion=True,
                                 switchDistance=1.0*nanometer,
                                 hydrogenMass=float(hydrogenMass)*amu)

else:
    print("Creating system...", end='')

    # Screening parameter (in 1/nm) of the implicit solvent, from the ionic strength (mol/L),
    # the dielectric constant of water (78.5), and the temperature in kelvin.
    # `temperature` (in degrees Celsius) must already be defined in the session.
    kappa = 367.434915*sqrt(float(ion_conc)/78.5/(int(temperature)+273.15))

    if forcefield == "Amber14":
        system = ff.createSystem(model.topology, nonbondedMethod=CutoffNonPeriodic, implicitSolventKappa=kappa,
                                 nonbondedCutoff=1.0*nanometer, constraints=constraint_type,
                                 hydrogenMass=float(hydrogenMass)*amu)

    elif forcefield == "CHARMM36":
        system = ff.createSystem(model.topology, nonbondedMethod=CutoffNonPeriodic, implicitSolventKappa=kappa,
                                 nonbondedCutoff=1.2*nanometer, switchDistance=1.0*nanometer,
                                 constraints=constraint_type,
                                 hydrogenMass=float(hydrogenMass)*amu)

print('done.')



# Periodic boundary conditions are used only with an explicit water model
pbc = water != ''

final_system = register_ml_residues(system, model.topology, pbc, uaa_info, uaa_atom_info)

assert final_system is not None, "System not created!"


with open(work+'system.xml', 'w') as f:
    f.write(XmlSerializer.serialize(final_system))

print("Creating simulation...", end='')
T = 273.15*kelvin

if any([p['ff'] == 'MACE-off23' for p in uaa_info.values()]):

    integrator = NoseHooverIntegrator(0.001*femtoseconds)

    ml_atoms = []
    for p in uaa_atom_info:
        ml_atoms.extend(p['neighbor'])
    ml_atoms = set(ml_atoms)

    mm_atoms = []
    for a in model.topology.atoms():
        if a.index not in ml_atoms:
            mm_atoms.append(a.index)

    integrator.addSubsystemThermostat(thermostatedParticles=mm_atoms,
                                      thermostatedPairs=[],
                                      temperature=T, relativeTemperature=T,
                                      collisionFrequency=1/picosecond,
                                      relativeCollisionFrequency=1/picosecond,
                                      chainLength=3, numMTS=3, numYoshidaSuzuki=7)

    integrator.addSubsystemThermostat(thermostatedParticles=list(ml_atoms),
                                      thermostatedPairs=[],
                                      temperature=T, relativeTemperature=T,
                                      collisionFrequency=200/picosecond,
                                      relativeCollisionFrequency=200/picosecond,
                                      chainLength=5, numMTS=5, numYoshidaSuzuki=7)

elif water != '':
    integrator = NoseHooverIntegrator(0.1*femtoseconds)

    integrator.addThermostat(temperature=T, collisionFrequency=1/picosecond,
                             chainLength=3, numMTS=3, numYoshidaSuzuki=7)

else:
    integrator = LangevinMiddleIntegrator(T, 1/picosecond, 0.1*femtoseconds)

with open(work+'integrator.xml', 'w') as f:
    f.write(XmlSerializer.serialize(integrator))

platform = Platform.getPlatformByName('CUDA')
properties = {'Precision': precision}
simulation = Simulation(model.topology, final_system, integrator, platform, properties)
print("done.")

print("Minimizing...", end='')
simulation.context.setPositions(model.positions)
simulation.minimizeEnergy(tolerance=0.001*kilojoules/(nanometer*mole))
positions = simulation.context.getState(getPositions=True).getPositions()

with open(work+"minimized.pdb", 'w') as f:
    PDBFile.writeFile(simulation.topology, positions, f)
print("done.")

save_settings(work, forcefield, watermodel, ion_conc, pbc, cation, anion, padding, rigidWater, constraints, hydrogenMass, precision)

simulation.context.setStepCount(0)
simulation.context.setTime(0)
state = simulation.context.getState(getPositions=True, getVelocities=True)
with open(work+"start.xml", 'w') as f:
    f.write(XmlSerializer.serialize(state))
stage = 'start'

print("You are now ready to start the simulation!")

view3 = nv.NGLWidget()
view3._set_size('750px','500px')
view3.add_component(nv.FileStructure(work+"minimized.pdb"))
view3.add_licorice()
if water != '':
    view3.add_point("water")
    view3.add_spacefill("ion")
view3.center()
view3




Adding solvent...done.
Creating system...done.
Creating MM-ML Mixed system...
Using MACE-OFF23 MODEL for MACECalculator with /root/.cache/mace/MACE-OFF23_small.model
Model dtype is torch.float64 and requested dtype is torch.float32. The model will be converted to the requested dtype.
done.
Creating simulation...done.
Minimizing...done.
You are now ready to start the simulation!
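
The "Spin" value collected for each residue in cell 2-2 is a spin multiplicity (1 = singlet, 2 = doublet, ...), while the XTB region needs a single multiplicity for all of its atoms. The sketch below shows one way to combine per-residue values. It assumes high-spin coupling, meaning that all unpaired electrons are counted as parallel; that is an assumption of this example, not something the notebook states. The function name is hypothetical.

```python
def combined_multiplicity(multiplicities):
    """Combine per-residue spin multiplicities into one value.

    Each multiplicity m corresponds to m - 1 unpaired electrons. Assuming
    all unpaired electrons are parallel (high-spin coupling), the total
    spin is S = (number of unpaired electrons) / 2 and the combined
    multiplicity is 2S + 1.
    """
    for m in multiplicities:
        if m < 1:
            raise ValueError("a spin multiplicity must be at least 1")
    unpaired = sum(m - 1 for m in multiplicities)
    return unpaired + 1


print(combined_multiplicity([1, 1]))  # two closed-shell residues
print(combined_multiplicity([2, 3]))  # a doublet and a triplet
```

If the unpaired electrons of different residues are expected to pair up, the appropriate value is lower and has to be chosen by hand.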


NGLWidget()
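
For implicit-solvent runs, the cell above turns the ionic strength into the screening parameter `kappa` passed to `createSystem`. The sketch below isolates that expression so the numbers can be inspected. The function name and its default arguments are illustrative, and the constant 367.434915 and the dielectric constant 78.5 are copied from the cell.

```python
import math


def implicit_solvent_kappa(ionic_strength_molar, temperature_celsius, dielectric=78.5):
    """Screening parameter in 1/nm, using the same expression as the cell above."""
    temperature_kelvin = temperature_celsius + 273.15
    return 367.434915 * math.sqrt(ionic_strength_molar / (dielectric * temperature_kelvin))


kappa = implicit_solvent_kappa(0.15, 25)
print(f"kappa = {kappa:.3f} 1/nm, screening length = {1 / kappa:.3f} nm")
```

With no added salt the function returns zero, which corresponds to unscreened electrostatics.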

## 3. Run the Simulation

In [None]:
#@title 3-1. NVT Equilibration
assert 'work' in globals(), "Please run the cell# 0 to set working directory."

#@markdown Equilibrate the system for n more picoseconds.
eq_time = 50 #@param {type:"slider", min:0.0, max:1000.0, step:10.0}
#@markdown Set the integration step size in femtoseconds.
step_size = 0.05 #@param {type:"slider", min:0.01, max:1, step:0.01}
#@markdown Set the temperature of the system in **degrees Celsius**.
temperature = 25 #@param {type:"slider", min:0.0, max:100.0, step:1.0}
#@markdown Back up the equilibration run's progress to Google Drive every n real-time minutes.
#@markdown Set it to a smaller value if you are using a free Colab session.
backup_interval = 1 #@param {type:"slider", min:1, max:60, step:1}
#@markdown Spring constant (in kJ/mol/nm^2) for the positional restraint on heavy atoms.
positional_restraint = 0 #@param {type:"slider", min:0, max:3000, step:100}

simulation, stage = recall_simulation(work)

assert stage != 'npt', 'NPT equilibration has already started.'
assert stage != 'production', 'The production run has already started.'
assert eq_time > 0, 'The equilibration time is set to zero, so no further NVT equilibration will be performed.'

if stage == 'start':
    simulation = apply_restraints(work, simulation, int(positional_restraint))
    simulation.context.setVelocitiesToTemperature((int(temperature)+273.15)*kelvin)
    print('Now starting the NVT equilibration!')
    print(f'Running NVT run for {eq_time} picoseconds.')
    stage = 'nvt'
    state = simulation.context.getState(getPositions=True, getVelocities=True)
    with open(work+'nvt_step_0.xml', 'w') as f:
        f.write(XmlSerializer.serialize(state))
    assert os.path.exists(work+'nvt_step_0.xml'), 'State data not recorded.'

else:
    print(f'NVT equilibration has been performed for {simulation.context.getTime().value_in_unit(picoseconds):.2f} picoseconds.')
    print(f'Running NVT run for {eq_time} more picoseconds.')

subprocess.run(f'echo "\tTime step (in fs) : {step_size}" >> {work}settings.txt', shell=True)
subprocess.run(f'echo "\tTemperature (in Celsius) : {temperature}" >> {work}settings.txt', shell=True)
subprocess.run(f'echo "\tRestraint (NVT, kJ/mol*nm^2) : {positional_restraint}" >> {work}settings.txt', shell=True)

stream = SimulationStream()
f = open(work+"nvt_log.txt", 'a')
f.close()
run_simulation(work, stage, simulation, stream, step_size, temperature,
               float(eq_time), 5*picoseconds, int(backup_interval))

Resuming NVT equilibration...done.
NVT equilbration has been performed for 0.44560000000012545 ps picoseconds.
Running NVT run for 50 more picoseconds.
09:25:47 left, Step: 33100, Time: 0.5 ps, Temperature: 287.33 C, Density: 0.9637 g/mL, Speed: 0.1 ns/day

ValueError: Energy is NaN.  For more information, see https://github.com/openmm/openmm/wiki/Frequently-Asked-Questions#nan
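
The equilibration time (in ps) and the step size (in fs) together fix how many integration steps a run needs. A minimal sketch of that arithmetic follows; `steps_for` is a hypothetical helper, not a function from this notebook.

```python
def steps_for(duration_ps, step_size_fs):
    """Number of integration steps needed to cover `duration_ps` picoseconds."""
    if step_size_fs <= 0:
        raise ValueError("the step size must be positive")
    # 1 ps = 1000 fs; round() absorbs floating-point noise in the division
    return round(duration_ps * 1000.0 / step_size_fs)


print(steps_for(50, 0.05))  # default form values of cell 3-1
print(steps_for(50, 2.0))   # the same duration with a 2 fs step
```

The comparison shows why small step sizes are expensive: covering the same 50 ps at 0.05 fs takes 40 times as many steps as at 2 fs.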

In [None]:
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'nvt_log.txt'), "Please run the cell# 3-1 first."

#@markdown Verify temperature convergence
time = []
temp = []
s = read_settings(work)

with open(work+"nvt_log.txt", 'r') as f:
    for line in f.readlines():
        if line.startswith("#"):
            continue
        line = line.strip().split(',')
        step = float(line[0])
        time.append(step*s['stepsize']/picoseconds)
        temp.append(round(float(line[3]),2))

target_temp = round(s['temperature']/kelvin, 2)

fig1 = go.Figure()
fig1.add_trace(go.Scatter(x=time, y=temp, mode='lines', name='Temperature'))
fig1.add_hline(y=target_temp, annotation_text="target T : "+str(target_temp)+" K")
fig1.update_layout(title='NVT equilibration', xaxis_title='Time (ps)', yaxis_title='Temperature (K)', width=750, height=500)
fig1.show()

print("Check whether the system's apparent temperature stabilizes (flattens).")
print("If not, run the above cell again to extend the equilibration run.")
print("If yes, proceed to the next cell.")

Check whether the system's apparent temperature stabilizes (flattens).
If not, run the above cell again to extend the equilibration run.
If yes, proceed to the next cell.
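
The visual check above can be complemented with a simple numeric test. The sketch below compares the mean temperature of the last few samples with the mean of the samples just before them. It assumes the same comma-separated log layout that the plotting cell reads (temperature in column 3), and the window size and tolerance are arbitrary illustrative values.

```python
from statistics import mean


def read_temperatures(lines, column=3):
    """Extract the temperature column from comma-separated log lines."""
    values = []
    for line in lines:
        line = line.strip()
        if not line or line.startswith('#'):
            continue  # skip blank lines and the header
        values.append(float(line.split(',')[column]))
    return values


def has_plateaued(values, window=5, tolerance=5.0):
    """True if the means of the last two windows differ by less than `tolerance`."""
    if len(values) < 2 * window:
        return False  # not enough data to compare two windows
    recent = mean(values[-window:])
    previous = mean(values[-2 * window:-window])
    return abs(recent - previous) < tolerance


example_log = ["#step,pe,ke,temperature"] + [
    f"{i},0,0,{t}" for i, t in enumerate(
        [100, 200, 250, 280, 295, 298, 299, 300, 301, 300, 299, 300, 301, 300, 299])
]
print(has_plateaued(read_temperatures(example_log)))
```

A flat window mean is necessary but not sufficient: the plateau should also sit at the target temperature, which the plot above makes easy to confirm.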


In [None]:
#@title 3-2. NPT Equilibration
assert 'work' in globals(), "Please run the cell# 0 to set working directory."

#@markdown Equilibrate the system for n more picoseconds.
eq_time = 50 #@param {type:"slider", min:0.0, max:1000.0, step:10.0}
#@markdown Set the integration step size in femtoseconds.
step_size = 0.1 #@param {type:"slider", min:0.1, max:4, step:0.1}
#@markdown Set the temperature of the system in **degrees Celsius**.
temperature = 25 #@param {type:"slider", min:0.0, max:100.0, step:1.0}
#@markdown Pressure on the system in bar.
pressure = 1 #@param {type:"slider", min:0.0, max:100.0, step:1.0}
#@markdown Back up the equilibration run's progress to Google Drive every n real-time minutes.
#@markdown Set it to a smaller value if you are using a free Colab session.
backup_interval = 10 #@param {type:"slider", min:5, max:60, step:5}
#@markdown Spring constant (in kJ/mol/nm^2) for the positional restraint on heavy atoms.
positional_restraint = 1000 #@param {type:"slider", min:0, max:3000, step:100}


s = read_settings(work)

simulation, stage = recall_simulation(work)

assert s['PBC'], 'No periodic boundary condition defined. Cannot run NPT equilibration.'
assert stage != 'start', 'Run NVT equilibration first!'
assert stage != 'production', 'The production run has already started. You can no longer run this cell.'
assert eq_time > 0, 'The equilibration time is set to zero, so no further NPT equilibration will be performed.'

if stage == 'nvt':
    barostat = MonteCarloBarostat(int(pressure)*bar, s['temperature'])
    simulation.system.addForce(barostat)
    simulation.context.reinitialize(preserveState=True)
    simulation.context.setTime(0)
    simulation.context.setStepCount(0)
    state = simulation.context.getState(getPositions=True, getVelocities=True)
    with open(work+"npt_step_0.xml", 'w') as f:
        f.write(XmlSerializer.serialize(state))
    simulation.saveCheckpoint(work+"npt.chk")
    print('Now starting the NPT equilibration!')
    print(f'Running NPT run for {eq_time} picoseconds.')
    stage = 'npt'
elif stage == 'npt':
    print(f'NPT equilibration has been performed for {simulation.context.getTime().value_in_unit(picoseconds):.2f} picoseconds.')
    print(f'Running NPT run for {eq_time} more picoseconds.')

subprocess.run(f'echo "\tTime step (in fs) : {step_size}" >> {work}settings.txt', shell=True)
subprocess.run(f'echo "\tTemperature (in Celsius) : {temperature}" >> {work}settings.txt', shell=True)
subprocess.run(f'echo "\tPressure (in bar) : {pressure}" >> {work}settings.txt', shell=True)
subprocess.run(f'echo "\tRestraint (NPT, kJ/mol*nm^2) : {positional_restraint}" >> {work}settings.txt', shell=True)

stream = SimulationStream()
f = open(work+"npt_log.txt", 'a')
f.close()
run_simulation(work, stage, simulation, stream, step_size, temperature,
               float(eq_time), 5*picoseconds, int(backup_interval))

In [None]:
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'npt_log.txt'), "Please run the cell# 3-2 first."

#@markdown Verify temperature & density convergence
time = []
temp = []
dens = []
s = read_settings(work)  # reload so that the step size and temperature of the NPT run are used

with open(work+"npt_log.txt", 'r') as f:
    for line in f.readlines():
        if line.startswith("#"):
            continue
        line = line.strip().split(',')
        step = float(line[0])
        time.append(step*s['stepsize']/picoseconds)
        temp.append(round(float(line[3]),2))
        dens.append(round(float(line[5]),5))

target_temp = round(s['temperature']/kelvin,2)

fig2 = make_subplots(specs=[[{"secondary_y":True}]])
fig2.add_trace(go.Scatter(x=time, y=temp, mode='lines', name='Temperature'), secondary_y=False)
fig2.add_hline(y=target_temp, annotation_text="target T : "+str(target_temp)+" K")
fig2.add_trace(go.Scatter(x=time, y=dens, mode='lines', name='Density'), secondary_y=True)
fig2.update_layout(title='NPT equilibration', width=750, height=500)
fig2.update_xaxes(title_text="Time (ps)")
fig2.update_yaxes(title_text="Temperature (K)", secondary_y=False)
fig2.update_yaxes(title_text="Density (g/mL)", secondary_y=True)
fig2.show()

print("Check whether the system's apparent temperature stabilizes to the desired value.")
print("Also check that the system's density stabilizes to a value near 1.04 g/mL for physiological conditions.")
print("If not, go back to the cell# 3-2.")
print("If yes, proceed to the next cell.")
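
The density plotted above is the total mass of the system divided by the volume of the periodic box. The sketch below shows the unit conversion from OpenMM's natural units (amu and nm) to g/mL; the function name is hypothetical and the water box is only a worked example.

```python
# 1 amu = 1.66053906660e-24 g and 1 nm^3 = 1e-21 mL
AMU_PER_NM3_TO_G_PER_ML = 1.66053906660e-3


def density_g_per_ml(total_mass_amu, box_volume_nm3):
    """Convert a mass in amu and a box volume in nm^3 to a density in g/mL."""
    if box_volume_nm3 <= 0:
        raise ValueError("the box volume must be positive")
    return total_mass_amu / box_volume_nm3 * AMU_PER_NM3_TO_G_PER_ML


# 1000 water molecules (18.015 amu each) in a cubic box with 3.1 nm edges
print(round(density_g_per_ml(1000 * 18.015, 3.1 ** 3), 3))
```

A solvated protein with ions comes out somewhat denser than pure water, which is consistent with the target value printed by the cell above.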

In [None]:
#@title 3-3. Production Run
assert 'work' in globals(), "Please run the cell# 0 to set working directory."

#@markdown Run the production MD simulation for n nanoseconds.
run_time = 1 #@param {type:"slider", min:1, max:500.0, step:1}
#@markdown Set the integration step size in femtoseconds.
step_size = 0.1 #@param {type:"slider", min:0.1, max:4, step:0.1}
#@markdown Set the temperature of the system in **degrees Celsius**.
temperature = 25 #@param {type:"slider", min:0.0, max:100.0, step:1.0}
#@markdown Save the trajectory every n picoseconds.
save_interval = 5 #@param {type:"slider", min:1, max:50, step:1}
#@markdown Back up the production run's progress to Google Drive every n real-time minutes.
#@markdown Set it to a smaller value if you are using a free Colab session.
backup_interval = 10 #@param {type:"slider", min:5, max:60, step:5}

s = read_settings(work)

simulation, stage = recall_simulation(work)

if stage != 'production':
    if stage == 'start':
        simulation.context.setVelocitiesToTemperature((int(temperature)+273.15)*kelvin)
    simulation.context.setTime(0)
    simulation.context.setStepCount(0)
    state = simulation.context.getState(getPositions=True, getVelocities=True)
    with open(work+"production_step_0.xml", 'w') as f:
        f.write(XmlSerializer.serialize(state))
    simulation = restore_simulation_from_state(work, work+'production_step_0.xml', 0, 'production')
    if s['PBC']:
        barostat = MonteCarloBarostat(s['pressure'], s['temperature'])
        simulation.system.addForce(barostat)
        simulation.context.reinitialize(preserveState=True)

    simulation.saveCheckpoint(work+"production.chk")
    print('Now starting the MD production run!')
    print(f'Running for {run_time} nanoseconds.')
    stage = 'production'
else:
    print(f'MD run has been performed for {round(simulation.context.getTime().value_in_unit(nanosecond), 1)} nanoseconds.')
    print(f'Running for {run_time} more nanoseconds.')

subprocess.run(f'echo "\tTime step (in fs) : {step_size}" >> {work}settings.txt', shell=True)
subprocess.run(f'echo "\tTemperature (in Celsius) : {temperature}" >> {work}settings.txt', shell=True)
subprocess.run(f'echo "\tSave interval (in ps) : {save_interval}" >> {work}settings.txt', shell=True)
stream = SimulationStream()
f = open(work+"production_log.txt", 'a')
f.close()
run_simulation(work, stage, simulation, stream, step_size, temperature,
               int(run_time)*1000, int(save_interval)*picoseconds, int(backup_interval))
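
How the slider values above translate into integrator steps and saved frames can be sketched with plain arithmetic. The helper name `plan_run` is hypothetical and is not part of this notebook; `run_simulation` may do its bookkeeping differently.

```python
def plan_run(run_time_ns, step_size_fs, save_interval_ps):
    """Return (total integration steps, steps between saved frames, saved frames)."""
    total_steps = round(run_time_ns * 1_000_000 / step_size_fs)       # 1 ns = 1,000,000 fs
    steps_per_frame = round(save_interval_ps * 1_000 / step_size_fs)  # 1 ps = 1,000 fs
    n_frames = total_steps // steps_per_frame
    return total_steps, steps_per_frame, n_frames

# 1 ns with a 2 fs step, saving every 5 ps
total_steps, steps_per_frame, n_frames = plan_run(1, 2.0, 5)
print(total_steps, steps_per_frame, n_frames)  # 500000 2500 200
```

Halving the step size doubles the number of steps needed for the same simulated time, which is worth keeping in mind on time-limited Colab sessions.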

In [None]:
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'minimized.pdb'), "Please run the cell# 3-1 first."

#@markdown Check the snapshot of current state
view_water_box = True #@param {type:"boolean"}
s = read_settings(work)

view4 = nv.NGLWidget()
view4._set_size('750px','500px')
view4.add_component(nv.FileStructure(work+"minimized.pdb"), defaultRepresentation=False)
view4.add_cartoon("protein or nucleic")
view4.add_simplified_base("nucleic")
view4.add_licorice("protein or nucleic or ligand")
if s['PBC'] and view_water_box:
    view4.add_point("water")
    view4.add_spacefill("ion")
state = simulation.context.getState(getPositions=True, enforcePeriodicBox=s['PBC'])
view4.set_coordinates({0:state.getPositions(asNumpy=True)*10})
view4.update_representation(component=0)
view4.center()
display(view4)


## 4. Trajectory Analysis

In [None]:
#@title 4-1. Wrap and View Trajectory
assert 'work' in globals(), "Please run the cell# 0 to set working directory."

#@markdown This cell concatenates the trajectories generated in Steps 3-1, 3-2, and 3-3. <br/>
#@markdown Check the box below to produce a 'smoothed' final trajectory. <br/>
#@markdown Download the trajectories whose names end with _final.xtc from your Google Drive.

smooth_trajectory = False #@param {type:"boolean"}

def wrap_trajectories(work, stage):
    """
    1. Merge the trajectory XTC files into one
    2. Wrap around PBC box if necessary
    3. Align the trajectory to the first frame
    4. Save the processed trajectory to the XTC and PDB files.
    """
    s = read_settings(work)
    print(f'Reading the {stage} trajectory files...', end='')
    trajs = []
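    # Sort by length first so that numbered chunks stay in order (..._9.xtc before ..._10.xtc).
    for t in sorted(glob.glob(work+stage+'_*.xtc'), key=lambda p: (len(p), p)):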
        if t != work+stage+'_final.xtc':
            trajs.append(md.load(t, top=work+'minimized.pdb'))
    trj = md.join(trajs)
    print('done.')

    if s['PBC']:
        print('Wrapping the trajectory...', end='')
        trj.image_molecules(inplace=True, anchor_molecules=trj.topology.guess_anchor_molecules())
        print("done.")

    trj.save_hdf5(work+stage+'.h5')

    print("Aligning the trajectory...", end='')
    alphas = trj.topology.select_atom_indices('alpha')
    if len(alphas) > 0:
        traj2 = trj.superpose(trj, 0, atom_indices=trj.topology.select_atom_indices('alpha'))
    else:
        nucleic = 'resname A or resname T or resname G or resname C or resname U or resname DA or resname DT or resname DG or resname DC'
        traj2 = trj.superpose(trj, 0, atom_indices=trj.topology.select(f'({nucleic}) and (symbol == "P" or symbol == "O" or symbol == "N")'))
    print('done.')

    print(f'Writing the {stage} trajectory...', end='')

    if smooth_trajectory:
        filter_width = max(12*picoseconds/s['save interval'], 1)
        if filter_width > 3:
            traj2.smooth(int(filter_width), order=3, inplace=True)

    traj2.save_xtc(work+stage+'_final.xtc')

    print('done.')

    return


to_view = {}
if os.path.exists(work+'nvt_0.xtc'):
    wrap_trajectories(work, 'nvt')
if os.path.exists(work+'npt_0.xtc'):
    wrap_trajectories(work, 'npt')
if os.path.exists(work+'production_0.xtc'):
    wrap_trajectories(work, 'production')

if os.path.exists(work+'nvt_final.xtc'):
    to_view['nvt'] = md.load(work+'nvt_final.xtc', top=work+'minimized.pdb')
if os.path.exists(work+'npt_final.xtc'):
    to_view['npt'] = md.load(work+'npt_final.xtc', top=work+'minimized.pdb')
if os.path.exists(work+'production_final.xtc'):
    to_view['production'] = md.load(work+'production_final.xtc', top=work+'minimized.pdb')

assert len(to_view) > 0, "No trajectory result XTC files found! Did you erase it?"

choice = widgets.Dropdown(options=list(to_view.keys()), value='production' if 'production' in to_view else list(to_view)[-1],
                          description='trajectory to view',
                          style={'description_width': 'initial'},
                          layout=widgets.Layout(width='750px'),
                          continuous_update=True)

def update_choice(change):
    with view_traj:
        clear_output(wait=True)
        view = nv.show_mdtraj(to_view[change.new], defaultRepresentation=False)
        view._set_size('750px','500px')
        view.add_cartoon()
        view.add_simplified_base("nucleic")
        view.add_ball_and_stick("ligand")
        view.center()
        display(view)

choice.observe(update_choice, names='value')

view_traj = widgets.Output()

with view_traj:
    view = nv.show_mdtraj(to_view[choice.value], defaultRepresentation=False)
    view._set_size('750px','500px')
    view.add_cartoon()
    view.add_simplified_base("nucleic")
    view.add_ball_and_stick("ligand")
    view.center()
    display(view)

print('Select a trajectory to view in a NGLViewer Widget.')
print('Wait for the widget to restart upon selection.')
display(choice)
display(view_traj)
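
When trajectory chunks are numbered, plain lexicographic sorting places chunk 10 before chunk 2. The sketch below orders files by their numeric suffix instead. It assumes chunk names of the form `<stage>_<number>.xtc` (as in `nvt_0.xtc`); the helper `chunk_order` is hypothetical.

```python
def chunk_order(paths):
    """Order chunk files by numeric suffix, e.g. production_2.xtc before production_10.xtc."""
    def suffix(path):
        stem = path.rsplit('.', 1)[0]        # drop the extension
        return int(stem.rsplit('_', 1)[1])   # text after the last underscore
    return sorted(paths, key=suffix)

files = ['production_10.xtc', 'production_2.xtc', 'production_0.xtc']
print(sorted(files))       # lexicographic order puts _10 before _2
print(chunk_order(files))  # numeric order: _0, _2, _10
```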

In [None]:
#@title 4-2. Frame Clustering
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

if 'traj' not in globals():
    traj = md.load(work+'production.h5')

#@markdown Set the RMSD threshold value (in nanometers) for the trajectory frame clustering. The structures within this threshold will be grouped into the same cluster.
rmsd_threshold = 0.15 #@param {type:"slider", min:0.05, max:1.0, step:0.05}

#@markdown The centroid (representative) structure of each cluster will be saved as a separate frame of the "clustered.pdb" file.
rmsd_matrix = np.zeros((traj.n_frames, traj.n_frames))

for i in tqdm(range(traj.n_frames)):
    rmsd_matrix[i] = md.rmsd(traj, traj, frame=i, atom_indices=traj.topology.select_atom_indices('heavy'))

clustering = AgglomerativeClustering(n_clusters=None, distance_threshold=rmsd_threshold, metric='precomputed', linkage='average')
cluster_labels = clustering.fit_predict(rmsd_matrix)

n_clusters = len(set(cluster_labels))
cluster_sizes = Counter(cluster_labels)
sorted_clusters = sorted(cluster_sizes.items(), key=lambda x:x[1], reverse=True)

print(f"Total {n_clusters} clusters found!")

centroids = []
for cluster, size in sorted_clusters:
    cluster_frames = np.where(cluster_labels == cluster)[0]
    centroid_idx = cluster_frames[np.argmin([np.sum(rmsd_matrix[frame, cluster_frames]) for frame in cluster_frames])]
    print(f"Cluster {cluster} (size = {size})'s centroid structure : frame {centroid_idx}")
    centroids.append((cluster, size, centroid_idx))

cent_traj = md.join([traj[idx] for _, _, idx in centroids])
cent_traj.save_pdb(work+"clustered.pdb")

view6 = nv.show_mdtraj(cent_traj, defaultRepresentation=False)
view6._set_size('750px','500px')
view6.add_cartoon()
view6.add_licorice()
view6.add_simplified_base("nucleic")
view6.add_ball_and_stick("ligand")
view6.center()
view6.player.delay = 3000
display(view6)
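
The centroid rule used above (the cluster member with the smallest summed RMSD to all other members) can be illustrated on a toy distance matrix. This is a minimal pure-Python sketch; the function name `cluster_centroid` and the matrix values are made up for illustration.

```python
def cluster_centroid(dist, members):
    """Return the member whose summed distance to all members is smallest."""
    return min(members, key=lambda i: sum(dist[i][j] for j in members))

# Toy symmetric RMSD matrix (nm) for four frames
dist = [
    [0.00, 0.10, 0.12, 0.50],
    [0.10, 0.00, 0.05, 0.48],
    [0.12, 0.05, 0.00, 0.47],
    [0.50, 0.48, 0.47, 0.00],
]
print(cluster_centroid(dist, [0, 1, 2]))  # 1
```

Frame 1 is chosen because its summed distance to frames 0 and 2 (0.15 nm) is lower than that of frame 0 (0.22 nm) or frame 2 (0.17 nm).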

In [None]:
#@title 4-3. Root Mean Squared Deviation
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

if 'traj' not in globals():
    traj = md.load(work+'production.h5')

#@markdown This cell shows the backbone atoms' RMSD change along the trajectory.

nsres = read_non_standard_residues_list(work)
rmsd = md.rmsd(traj, traj, 0, atom_indices=traj.topology.select('backbone'))

lig_select = widgets.Dropdown(options=['whole structure']+[f"{a['name']} and rest" for a in nsres],
                              value='whole structure', description='calculate RMSDs for',
                              style={'description_width': 'initial'},
                              layout=widgets.Layout(width='750px'),
                              continuous_update=True)

def update_RMSD_chart(change):
    with view_chart:
        clear_output(wait=True)
        if change.new == 'whole structure':
            fig = go.Figure()
            fig.add_trace(go.Scatter(x=sorted(traj.time/1000), y=rmsd, mode='lines', name='RMSD', connectgaps=False))
            fig.update_layout(title='whole structure backbone RMSD', xaxis_title='Time (ns)',
                               yaxis_title='RMSD (nm)', width=750, height=500)
        else:
            lig_name = change.new.split()[0]
            lig_rmsd = md.rmsd(traj, traj, 0, atom_indices=traj.topology.select(f'resname {lig_name}'))
            rec_rmsd = md.rmsd(traj, traj, 0, atom_indices=traj.topology.select(f'(resname != {lig_name}) and backbone'))
            fig = make_subplots(rows=2, cols=1, subplot_titles=("Receptor backbone RMSD", f"Ligand {lig_name} all-atom RMSD"), vertical_spacing=0.2)
            fig.add_trace(go.Scatter(x=sorted(traj.time/1000), y=rec_rmsd, mode='lines', name='Receptor RMSD', connectgaps=False), row=1, col=1)
            fig.add_trace(go.Scatter(x=sorted(traj.time/1000), y=lig_rmsd, mode='lines', name=f'Ligand {lig_name} RMSD', connectgaps=False), row=2, col=1)
            fig.update_xaxes(title='Time (ns)', showticklabels=True, row=1, col=1)
            fig.update_xaxes(title='Time (ns)', row=2, col=1)
            fig.update_yaxes(title='RMSD (nm)', row=1, col=1)
            fig.update_yaxes(title='RMSD (nm)', row=2, col=1)
            fig.update_layout(width=750, height=800, showlegend=False)
        display(fig)


lig_select.observe(update_RMSD_chart, names='value')

view_chart = widgets.Output()

with view_chart:
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=sorted(traj.time/1000), y=rmsd, mode='lines', name='RMSD', connectgaps=False))
    fig.update_layout(title='whole structure backbone RMSD', xaxis_title='Time (ns)',
                       yaxis_title='RMSD (nm)', width=750, height=500)
    display(fig)

display(lig_select, view_chart)

In [None]:
#@title 4-4. Radius of Gyration
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

if 'traj' not in globals():
    traj = md.load(work+'production.h5')

#@markdown This cell shows the radius of gyration of the given protein (approximate size of the globular protein) along the trajectory.

rg = md.compute_rg(traj)

print('Click autoscale to zoom in.')

fig4 = go.Figure()
fig4.add_trace(go.Scatter(x=traj.time/1000, y=rg, mode='lines', name='RoG'))
fig4.update_yaxes(range=[0, rg.max()*1.1])
fig4.update_layout(title='Radius of Gyration', xaxis_title='Time (ns)', yaxis_title='Radius of Gyration (nm)', width=750, height=500)
fig4.show()
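
For reference, the radius of gyration is the root mean square distance of the atoms from their common center. The sketch below computes the unweighted form in pure Python; to my understanding `md.compute_rg` also uses equal weights unless masses are passed, but treat that as an assumption and check the MDTraj documentation.

```python
import math

def radius_of_gyration(coords):
    """Unweighted radius of gyration of a list of (x, y, z) points."""
    n = len(coords)
    center = [sum(p[k] for p in coords) / n for k in range(3)]
    mean_sq = sum(sum((p[k] - center[k]) ** 2 for k in range(3)) for p in coords) / n
    return math.sqrt(mean_sq)

# Four points at unit distance from the origin
square = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0)]
print(radius_of_gyration(square))  # 1.0
```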

In [None]:
#@title 4-5. Free Energy Landscape
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

#@markdown This cell shows the free energy landscape estimated from the kernel density of the trajectory frames in the (RMSD, Rg) plane.

if 'traj' not in globals():
    traj = md.load(work+'production.h5')

rmsd = md.rmsd(traj, traj, 0)
rg = md.compute_rg(traj)
s = read_settings(work)
X = np.vstack([rmsd, rg]).T

# Automatic bandwidth sweep with 5-fold cross-validation
bandwidths = np.logspace(-1, 1, 20)
grid = GridSearchCV(KernelDensity(kernel='gaussian'), {'bandwidth': bandwidths}, cv=5)
grid.fit(X)

kde = grid.best_estimator_
xi, yi = np.mgrid[rmsd.min():rmsd.max():100j, rg.min():rg.max():100j]
xy_sample = np.vstack([xi.ravel(), yi.ravel()]).T
z = np.exp(kde.score_samples(xy_sample))
zi = z.reshape(xi.shape)

Fi = -np.log(zi)
Fi -= np.max(Fi) # zero at maximum
Fi *= (8.31446261815324*s['temperature'].value_in_unit(kelvin)/1000) # convert from kT to kJ/mol

fig5 = go.Figure(data=[go.Surface(z=Fi, x=xi, y=yi, colorbar=dict(title='kJ/mol'),
                                  hovertemplate='RMSD: %{x:.3f} nm<br>Rg: %{y:.3f} nm<br>FE: %{z:.1f} kJ/mol<extra></extra>')])

fig5.update_layout(title='Relative Free Energy Landscape',
                     scene = dict(xaxis_title='RMSD (nm)', yaxis_title='Rg (nm)', zaxis_title='Free Energy (kJ/mol)'),
                     width=750, height=750, autosize=True)

fig5.show()
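
The conversion from a probability density to a relative free energy follows F = -RT ln P. The cell above shifts the surface by its maximum, so every value is at or below zero; the sketch below uses the other common convention, in which the most populated state sits at zero. The function name and the density values are illustrative only.

```python
import math

R_KJ = 8.31446261815324e-3  # gas constant in kJ/(mol*K)

def relative_free_energy(density, temperature_k):
    """Convert probability densities to free energies (kJ/mol), zero at the most populated state."""
    rt = R_KJ * temperature_k
    f = [-rt * math.log(p) for p in density]
    f_min = min(f)
    return [x - f_min for x in f]

energies = relative_free_energy([0.5, 0.25, 0.05], 298.15)
print([round(e, 2) for e in energies])
```

A state that is half as populated as the global minimum lies RT ln 2 (about 1.7 kJ/mol at 298 K) above it.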

In [None]:
#@title 4-6. Residue-wise Root Mean Square Fluctuation
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

if 'traj' not in globals():
    traj = md.load(work+'production.h5')

#@markdown This cell shows the heavy-atom RMSF of each residue over the course of the trajectory.

#@markdown Residues with higher RMSF values are more flexible.

prot_indices = traj.topology.select("not water and (mass 11 to 33) and (symbol != 'Na')")
# atom_slice() returns a new trajectory, so the shared global `traj` is not modified by the alignment below.
traj2 = traj.atom_slice(prot_indices)
traj2.superpose(traj2, 0)

# RMSF per atom: sqrt(<dx^2> + <dy^2> + <dz^2>), with the averages taken over frames
atom_rmsf = np.sqrt(np.sum(np.mean(np.square(traj2.xyz - np.mean(traj2.xyz, axis=0)), axis=0), axis=1))

res = []
res_rmsf = []
for c in traj2.topology.chains:
    chain = c.chain_id
    for i, r in enumerate(c.residues):
        res.append(f'{"" if chain is None else f"{chain}:"}{r.name.capitalize()}{i+1}')
        atom_idx = [atom.index for atom in r.atoms]
        res_rmsf.append(np.mean(atom_rmsf[atom_idx]))

fig6 = go.Figure()
fig6.add_trace(go.Scatter(x=np.arange(len(res)), y=res_rmsf, mode='lines', name='RMSF', text=res))
fig6.update_traces(hovertemplate='Residue: %{text}<br>RMSF: %{y:.3f} nm<extra></extra>')
fig6.update_yaxes(range=[0, max(res_rmsf)*1.1])
fig6.update_layout(title='Residue-wise heavy atom RMSF', xaxis_title='Residue index', yaxis_title='RMSF (nm)',
                   hovermode='x unified', width=750, height=500)
fig6.show()
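
The per-atom RMSF is the square root of the mean squared displacement of an atom from its average position. A minimal pure-Python sketch on toy coordinates (already aligned, values made up) is shown below; the function name `rmsf` is hypothetical.

```python
import math

def rmsf(frames):
    """RMSF per atom; frames[f][a] is the (x, y, z) position of atom a in frame f."""
    n_frames, n_atoms = len(frames), len(frames[0])
    out = []
    for a in range(n_atoms):
        mean = [sum(fr[a][k] for fr in frames) / n_frames for k in range(3)]
        msd = sum(sum((fr[a][k] - mean[k]) ** 2 for k in range(3)) for fr in frames) / n_frames
        out.append(math.sqrt(msd))
    return out

# Atom 0 is fixed, atom 1 oscillates by +/- 0.1 nm along x
frames = [
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
    [(0.0, 0.0, 0.0), (1.2, 0.0, 0.0)],
]
values = rmsf(frames)
print(values)
```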

In [None]:
#@title 4-7. Ramachandran Map
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

if 'traj' not in globals():
    traj = md.load(work+'production.h5')

#@markdown This cell shows the Ramachandran map of the given protein over the course of trajectory.

#@markdown For large proteins, this may take some time.

stride = int(traj.n_frames/50)

if stride < 2:
    traj2 = traj
    stride = 1
else:
    traj2 = traj[::stride]

s = read_settings(work)

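phi_idx, phi = md.compute_phi(traj2, periodic=s['PBC'])
psi_idx, psi = md.compute_psi(traj2, periodic=s['PBC'])

# phi is undefined for the first residue of a chain and psi for the last one,
# so keep only the residues that have both, matched through their C-alpha atom.
ca_atoms, phi_cols, psi_cols = np.intersect1d(phi_idx[:, 2], psi_idx[:, 1], return_indices=True)
phi, psi = phi[:, phi_cols], psi[:, psi_cols]

labels = {}
for c in traj.topology.chains:
    chain = c.chain_id
    for i, r in enumerate(c.residues):
        labels[r.index] = f'{"" if chain is None else f"{chain}:"}{r.name.capitalize()}{i+1}'
res = [labels[traj.topology.atom(ca).residue.index] for ca in ca_atoms]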

fig8 = go.Figure()
colors = np.linspace(0, traj.n_frames*s['save interval'].value_in_unit(nanosecond), traj2.n_frames)

fig8.add_trace(go.Scatter(x=phi.flatten(), y=psi.flatten(), mode='markers', marker=dict(color=np.repeat(colors, phi.shape[1]), colorscale='Viridis', colorbar=dict(title='Time (ns)'), showscale=True),
                          text=[f'Time: {(i//phi.shape[1])*stride*s["save interval"]}<br>Residue: {res[i%phi.shape[1]]}<br>Phi: {phi[i//phi.shape[1], i%phi.shape[1]]:.2f}<br>Psi: {psi[i//phi.shape[1], i%phi.shape[1]]:.2f}'
                                for i in range(phi.size)], hoverinfo='text'))

fig8.update_layout(title='Time-dependent Ramachandran Map', xaxis_title='Phi (radians)', yaxis_title='Psi (radians)', xaxis_range=[-np.pi, np.pi], yaxis_range=[-np.pi, np.pi], width=750, height=500)
fig8.show()
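
The phi and psi values plotted above are dihedral angles in radians, each defined by four consecutive backbone atoms. The sketch below shows one common way to compute a dihedral from four points with `atan2`. The helper names are hypothetical, and the sign convention of this sketch is an assumption that may differ from MDTraj's.

```python
import math

def _sub(a, b):
    return [a[i] - b[i] for i in range(3)]

def _dot(a, b):
    return sum(a[i] * b[i] for i in range(3))

def _cross(a, b):
    return [a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]]

def dihedral(p0, p1, p2, p3):
    """Dihedral angle (radians, in [-pi, pi]) defined by four points."""
    b0, b1, b2 = _sub(p1, p0), _sub(p2, p1), _sub(p3, p2)
    n1, n2 = _cross(b0, b1), _cross(b1, b2)   # normals of the two planes
    y = _dot(_cross(n1, n2), b1) / math.sqrt(_dot(b1, b1))
    x = _dot(n1, n2)
    return math.atan2(y, x)

cis = dihedral((0, 1, 0), (0, 0, 0), (1, 0, 0), (1, 1, 0))
trans = dihedral((0, 1, 0), (0, 0, 0), (1, 0, 0), (1, -1, 0))
print(abs(cis), abs(trans))
```

A cis arrangement gives an angle of 0 and a trans arrangement gives an angle of magnitude pi, which is why the axes of the map span -pi to pi.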

In [None]:
#@title 4-8. Solvent Accessible Surface Area
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

if 'traj' not in globals():
    traj = md.load(work+'production.h5')

#@markdown This cell shows how the solvent accessible surface area of the given residue changes along the trajectory.

#@markdown Not recommended for free Colab sessions.

#@markdown Enter the index (starting from 1) of the residue whose SASA you want to analyze.

residue_index = "140" #@param {type:"string"}
res_idx = int(residue_index)-1

assert 0 <= res_idx < traj.topology.n_residues, "Invalid residue index!"

sasa = md.shrake_rupley(traj, mode='residue')

for c in traj.topology.chains:
    chain = c.chain_id
    for i, r in enumerate(c.residues):
        if r.index == res_idx:
            name = f'{"" if chain is None else f"{chain}:"}{r.name.capitalize()}{i+1}'
            break

fig9 = go.Figure()
fig9.add_trace(go.Scatter(x=traj.time/1000, y=sasa[:,res_idx], mode='lines', name='SASA'))
fig9.update_layout(title=f'Solvent Accessible Surface Area of the residue {name}',
                   xaxis_title='Time (ns)', yaxis_title='SASA (nm^2)', width=750, height=500)
fig9.show()

In [None]:
#@title 4-9. Inter-residue hydrogen bond analysis
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

traj = md.load(work+'production.h5')

residue_1 = "118" #@param {type:"string"}
residue_2 = "" #@param {type:"string"}

residue_1 = int(residue_1) -1
if residue_2 != "":
    water_flag = False
    residue_2 = int(residue_2) -1
else:
    water_flag = True

assert residue_1 != residue_2, "Same residue indices entered!"

s = read_settings(work)

topol = traj.topology.to_openmm()

donors = []
acceptors = []

for a in topol.atoms():
    if a.element.symbol in ['N', 'O', 'F']:
        if a.residue.index == residue_1:
            donors.append(a.index)
        elif a.residue.index == residue_2:
            acceptors.append(a.index)
        elif water_flag and a.residue.name == 'HOH':
            acceptors.append(a.index)

donor_Hs = {}
acceptor_Hs = {}

for b in topol.bonds():
    if b[0].index in donors and b[1].element.symbol == 'H':
        donor_Hs[b[1].index] = b[0].index
    elif b[1].index in donors and b[0].element.symbol == 'H':
        donor_Hs[b[0].index] = b[1].index
    elif b[0].index in acceptors and b[1].element.symbol == 'H':
        acceptor_Hs[b[1].index] = b[0].index
    elif b[1].index in acceptors and b[0].element.symbol == 'H':
        acceptor_Hs[b[0].index] = b[1].index

H_bonds = []

for key, value in donor_Hs.items():
    for a in acceptors:
        H_bonds.append([value, key, a])
for key, value in acceptor_Hs.items():
    for d in donors:
        H_bonds.append([value, key, d])

assert len(H_bonds) >0, "There are no H-bond donor/acceptors in the residue!"

angles = md.compute_angles(traj, np.array(H_bonds), s['PBC'], True)
distances = md.compute_distances(traj, np.array([[h[1], h[2]] for h in H_bonds]), s['PBC'], True)

hbonds = np.sum(np.where(angles > 2*np.pi/3, 1, 0)*np.where(distances < 0.25, 1, 0), axis=1)

for c in traj.topology.chains:
    chain = c.chain_id
    for i, r in enumerate(c.residues):
        if r.index == residue_1:
            name1 = f'{"" if chain is None else f"{chain}:"}{r.name.capitalize()}{i+1}'
        elif not water_flag and r.index == residue_2:
            name2 = f'{"" if chain is None else f"{chain}:"}{r.name.capitalize()}{i+1}'
        else:
            continue

fig10 = go.Figure()
fig10.add_trace(go.Scatter(x=traj.time/1000, y=hbonds, mode='lines', name='H-bond'))
fig10.update_layout(xaxis_title='Time (ns)', yaxis_title='Number of H-bonds', width=750, height=500)
if water_flag:
    fig10.update_layout(title=f'Number of Hydrogen bonds between {name1} and water molecules')
else:
    fig10.update_layout(title=f'Number of Hydrogen Bonds between {name1} and {name2}')
fig10.show()
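
The geometric criterion applied in this cell counts a hydrogen bond when the donor-hydrogen-acceptor angle exceeds 120 degrees (2*pi/3 rad) and the hydrogen-acceptor distance is below 0.25 nm. A minimal pure-Python sketch of that test for a single triplet follows; the function name `is_hbond` and the coordinates are illustrative.

```python
import math

def is_hbond(donor, hydrogen, acceptor, min_angle=2 * math.pi / 3, max_dist=0.25):
    """True if the D-H...A angle exceeds min_angle (rad) and the H...A distance is below max_dist (nm)."""
    v1 = [donor[i] - hydrogen[i] for i in range(3)]
    v2 = [acceptor[i] - hydrogen[i] for i in range(3)]
    d1 = math.sqrt(sum(c * c for c in v1))
    d2 = math.sqrt(sum(c * c for c in v2))
    cos_angle = sum(a * b for a, b in zip(v1, v2)) / (d1 * d2)
    angle = math.acos(max(-1.0, min(1.0, cos_angle)))  # clamp against rounding errors
    return angle > min_angle and d2 < max_dist

linear = is_hbond((0.0, 0.0, 0.0), (0.1, 0.0, 0.0), (0.3, 0.0, 0.0))  # 180 degrees, 0.2 nm
bent = is_hbond((0.0, 0.0, 0.0), (0.1, 0.0, 0.0), (0.1, 0.2, 0.0))    # 90 degrees, 0.2 nm
print(linear, bent)  # True False
```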

In [None]:
#@title 4-10. Secondary Structure Analysis using DSSP
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

if 'traj' not in globals():
    traj = md.load(work+'production.h5')

#@markdown This cell shows the secondary structure of each residue over the course of trajectory.

#@markdown Not recommended for free Colab sessions, since it takes up a lot of memory.

dssp = md.compute_dssp(traj)
dssp_map = {'H':0, 'E':1, 'C':2}
dssp_num = np.array([[dssp_map.get(ss,2) for ss in frame] for frame in dssp])

s = read_settings(work)

fig11 = go.Figure(data=go.Heatmap(z=dssp_num.T, x=traj.time/1000, y=np.arange(traj.n_residues),
                                 colorscale=[[0,'red'],[0.33, 'red'], [0.33, 'yellow'], [0.66, 'yellow'], [0.66, 'blue'], [1, 'blue']],
                                 colorbar=dict(tickvals=[0,1,2],ticktext=['Helix', 'Sheet', 'Coil'])))
fig11.update_layout(title='Time-dependent Secondary Structure (DSSP)', xaxis_title='Time (ns)', yaxis_title='Residue index', width=750, height=1000)
fig11.data[0].hovertext = [[f'Time : {i*s["save interval"]}<br>Residue: {j}<br>SS: {dssp[i][j]}' for i in range(traj.n_frames)] for j in range(traj.n_residues)]
fig11.data[0].hoverinfo = 'text'
fig11.show()

In [None]:
#@title 4-11. Dynamic Cross-Correlation Matrix
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

if 'traj' not in globals():
    traj = md.load(work+'production.h5')

#@markdown This cell shows the correlation between the movements of alpha carbons of each residue.
#@markdown When the correlation value approaches +1 (blue), it means the pair of residues tend to move in the same direction.
#@markdown When the value approaches -1 (red), it means the pair of residues tend to move in the opposite directions.
#@markdown When the value approaches 0 (white), it means the pair of residues move independently from each other.

ca_idx = traj.topology.select_atom_indices('alpha')
# Work on an aligned copy of the C-alpha atoms so that overall rotation and
# translation of the molecule do not contaminate the correlations.
ca_traj = traj.atom_slice(ca_idx)
ca_traj.superpose(ca_traj, 0)
fluct = ca_traj.xyz - np.mean(ca_traj.xyz, axis=0)

# C_ij = <dr_i . dr_j> / sqrt(<dr_i^2> <dr_j^2>); the matrix is sized by the number of C-alpha atoms.
cov = np.tensordot(fluct, fluct, axes=([0, 2], [0, 2])) / ca_traj.n_frames
norm = np.sqrt(np.diag(cov))
dccm = cov / np.outer(norm, norm)

fig12 = go.Figure(data=go.Heatmap(z=dccm, colorscale='RdBu', zmid=0, colorbar=dict(title='Correlation')))
fig12.update_layout(title='Dynamic Cross-Correlation Matrix', xaxis_title='Residue index', yaxis_title='Residue index', width=750, height=750)
fig12.data[0].hovertext = [[f'Residue i: {i+1}<br>Residue j: {j+1}<br>Correlation: {dccm[i,j]:.2f}' for j in range(dccm.shape[1])] for i in range(dccm.shape[0])]
fig12.data[0].hoverinfo = 'text'
fig12.show()
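
Each matrix element is the normalized, time-averaged dot product of two atoms' displacement vectors. The toy sketch below reproduces the three limiting cases described above in pure Python; the function name `dccm_matrix` and the displacement values are made up.

```python
import math

def dccm_matrix(fluct):
    """fluct[f][a] is the displacement (dx, dy, dz) of atom a from its mean position in frame f."""
    n_frames, n_atoms = len(fluct), len(fluct[0])

    def mean_dot(i, j):
        return sum(sum(fluct[f][i][k] * fluct[f][j][k] for k in range(3))
                   for f in range(n_frames)) / n_frames

    return [[mean_dot(i, j) / math.sqrt(mean_dot(i, i) * mean_dot(j, j))
             for j in range(n_atoms)] for i in range(n_atoms)]

# Atom 1 moves with atom 0, atom 2 moves against it, atom 3 moves perpendicular to it
fluct = [
    [(0.1, 0.0, 0.0), (0.2, 0.0, 0.0), (-0.1, 0.0, 0.0), (0.0, 0.1, 0.0)],
    [(-0.1, 0.0, 0.0), (-0.2, 0.0, 0.0), (0.1, 0.0, 0.0), (0.0, -0.1, 0.0)],
]
c = dccm_matrix(fluct)
print(round(c[0][1], 3), round(c[0][2], 3), round(c[0][3], 3))
```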

### 4-12. Principal Component Analysis

In [None]:
#@title 2D Principal Component Analysis
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

if 'traj' not in globals():
    traj = md.load(work+'production.h5')

#@markdown This cell extracts the collective motions of the atomic coordinates and expresses them as a pair of orthogonal eigenvectors called principal components. The percentage in parentheses is the fraction of the total positional variance explained by each component over the course of the trajectory.

prot_indices = traj.topology.select("not water and (mass 11 to 33) and (symbol != 'Na')")
# atom_slice() returns a copy, which keeps the shared global `traj` intact.
traj2 = traj.atom_slice(prot_indices)
traj2.superpose(traj2, 0)
pca = PCA(n_components=2)
reduced_cart = pca.fit_transform(traj2.xyz.reshape(traj2.n_frames, traj2.n_atoms*3))
var_explained = pca.explained_variance_ratio_*100

s = read_settings(work)

fig13 = go.Figure()
fig13.add_trace(go.Scatter(x=reduced_cart[:,0], y=reduced_cart[:,1], mode='markers',
                          marker=dict(size=5, color=np.arange(len(reduced_cart))*s["save interval"].value_in_unit(nanosecond), colorscale='Viridis', colorbar=dict(title='Time (ns)'), showscale=True),
                          text=[f'Time: {i*s["save interval"]}' for i in range(len(reduced_cart))], hoverinfo='text', name='PCA'))
fig13.update_layout(title='Two Dimensional Principal Component Analysis', xaxis_title=f'PC1 ({var_explained[0]:.2f}%)', yaxis_title=f'PC2 ({var_explained[1]:.2f}%)', width=750, height=500)
fig13.show()
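
The percentages on the axes are each component's share of the total variance. As a small sketch of that arithmetic, assuming the full list of covariance eigenvalues is available (the numbers below are made up and the helper name is hypothetical):

```python
def explained_variance_percent(eigenvalues):
    """Share of the total variance (in percent) carried by each eigenvalue."""
    total = sum(eigenvalues)
    return [100 * v / total for v in eigenvalues]

shares = explained_variance_percent([6.0, 3.0, 1.0])
print(shares)  # [60.0, 30.0, 10.0]
```

Because the total runs over all eigenvalues, the percentages of the first two components generally add up to less than 100%.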

In [None]:
#@title 3D Principal Component Analysis
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

if 'traj' not in globals():
    traj = md.load(work+'production.h5')

#@markdown Similar to the 2D PCA cell above, this cell performs a 3D PCA.
prot_indices = traj.topology.select("not water and (mass 11 to 33) and (symbol != 'Na')")
# atom_slice() returns a copy, which keeps the shared global `traj` intact.
traj2 = traj.atom_slice(prot_indices)
traj2.superpose(traj2, 0)
pca = PCA(n_components=3)
reduced_cart = pca.fit_transform(traj2.xyz.reshape(traj2.n_frames, traj2.n_atoms*3))
var_explained = pca.explained_variance_ratio_*100

s = read_settings(work)

fig14 = go.Figure()
fig14.add_trace(go.Scatter3d(x=reduced_cart[:,0], y=reduced_cart[:,1], z=reduced_cart[:,2], mode='markers',
                          marker=dict(size=5, color=s["save interval"].value_in_unit(nanosecond)*np.arange(len(reduced_cart)), colorscale='Viridis', opacity=0.8, colorbar=dict(title='Time (ns)'), showscale=True),
                          text=[f'Time: {i*s["save interval"]}' for i in range(len(reduced_cart))], hoverinfo='text', name='PCA'))
fig14.update_layout(title='Three Dimensional Principal Component Analysis', scene=dict(xaxis_title=f'PC1 ({var_explained[0]:.2f}%)', yaxis_title=f'PC2 ({var_explained[1]:.2f}%)', zaxis_title=f'PC3 ({var_explained[2]:.2f}%)'), width=750, height=750)
fig14.show()

In [None]:
#@title Visualization of Principal Components
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'production.h5'), "Run the cell# 4-1 first to wrap the trajectory."

if 'traj' not in globals():
    traj = md.load(work+'production.h5')

#@markdown This cell animates each principal component's motion in MDTraj trajectory.

#@markdown Set index of the principal component to visualize.
component_to_view = "1" #@param [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
#@markdown Set the scale of the motion. Set larger values for more exaggerated motions.
scale = 3 #@param {type:"slider", min:1, max:10, step:1}

nsres = read_non_standard_residues_list(work)
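nsres_to_view = " or ".join(f"(resname {r['name']})" for r in nsres)

# An empty selection string cannot be parsed, so fall back to an empty index list when there is no ligand.
ligand_indices = traj.topology.select(nsres_to_view) if nsres_to_view else np.array([], dtype=int)
prot_indices = traj.topology.select("not water and (mass 11 to 33) and (symbol != 'Na')")
view_indices = np.array(sorted(set(list(ligand_indices)+list(prot_indices))))
# Slice a copy (leaving the global `traj` intact) and align it on the protein heavy atoms,
# whose positions inside the sliced trajectory are found with searchsorted.
traj2 = traj.atom_slice(view_indices)
traj2.superpose(traj2, 0, atom_indices=np.searchsorted(view_indices, prot_indices))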
coords = traj2.xyz.reshape(traj2.n_frames, -1)

pca = PCA()
pca_results = pca.fit_transform(coords)
top_pcs = pca.components_[:10]

pc_trajs = top_pcs.reshape(10, traj2.n_atoms, 3)
sine_wave = np.sin(np.linspace(0, 2*np.pi, num=50))

pc_traj = md.Trajectory(np.zeros((50, traj2.n_atoms, 3)), traj2.topology)

for i in range(50):
    pc_traj.xyz[i] = traj2.xyz[0] + scale*sine_wave[i]*pc_trajs[int(component_to_view)-1]

view6 = nv.show_mdtraj(pc_traj, defaultRepresentation=False)
view6._set_size('750px','500px')
view6.add_cartoon()
view6.add_simplified_base("nucleic")
view6.add_ball_and_stick("ligand")
view6.center()
view6.player.delay = 500
display(view6)

## 5. Intermolecular Interaction Analysis

In [None]:
#@title 5-1. Visualizing protein-ligand interactions by PLIP
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert os.path.exists(work+'starting_structure.pdb'), "starting_structure.pdb file not found!"
assert os.path.exists(work+'production.h5'), "Result trajectory not found! Run cell# 4-1 first."
assert os.path.exists(work+'clustered.pdb'), "Run cell# 4-2 first."

traj = md.load(work+'production.h5')

#@markdown Each interaction type is shown in a different color. <br/>
#@markdown * H-bonds : red
#@markdown * Hydrophobic : grey
#@markdown * Salt bridge : yellow
#@markdown * Water bridge : cyan
#@markdown * Pi-pi stacking : blue
#@markdown * Pi-cation : magenta
#@markdown * Halogen bonds : green
#@markdown * Metal coordination : pink

clusterPDB = PDBComplex()
clusterPDB.load_pdb(work+'clustered.pdb')
clusterPDB.analyze()

ligands = {}
for ligand in clusterPDB.ligands:
    clusterPDB.characterize_complex(ligand)

for key, value in sorted(clusterPDB.interaction_sets.items()):
    ligands[key] = BindingSiteReport(value)

interaction_types = ['hbond', 'hydrophobic', 'saltbridge', 'waterbridge', 'pistacking', 'pication', 'halogen', 'metal']
color = [(1,0,0), (0.5,0.5,0.5), (1,1,0), (0,1,1), (0,0,1), (1,0,1), (0,1,0), (0.9,0.5,0.9)]

view_interaction = widgets.Output()

with view_interaction:
    view = nv.NGLWidget()
    view.add_component(nv.FileStructure(work+'clustered.pdb'), defaultRepresentation=False)
    view._set_size('750px','500px')
    view.add_cartoon()
    view.add_simplified_base("nucleic")
    view.add_licorice("ligand")
    view.center()
    display(view)

def center_ligand(change): # callback function for the update
    with view_interaction:
        clear_output(wait=True)
        view = nv.NGLWidget()
        view.add_component(nv.FileStructure(work+'clustered.pdb'), defaultRepresentation=False)
        view._set_size('750px','500px')
        view.add_cartoon()
        view.add_simplified_base("nucleic")
        view.add_licorice("ligand")
        if change.new == 'none':
            view.center()
        else:
            intres = []
            for i, inttype in enumerate(interaction_types):
                intdata = getattr(ligands[change.new], inttype+'_info')
                for interaction in intdata:
                    intres.append(f'({interaction[0]} and {interaction[1]} and :{interaction[2]})')
                    view.shape.add_cylinder(interaction[-2], interaction[-1], color[i], [0.05], inttype)
            view.add_licorice(" or ".join(intres))
            resdata = change.new.split(':')
            view.center(f'{resdata[2]} and :{resdata[1]} and {resdata[0]}')
        display(view)

print('Wait up to 1 min for the NGLWidget to restart after selecting a ligand. ')

if len(ligands) > 0:
    select_lig = widgets.Dropdown(options=['none']+list(ligands.keys()), value='none', description='Ligand : ', continuous_update=True)
    select_lig.observe(center_ligand, names='value')
    display(select_lig, view_interaction)
else:
    display(view_interaction)

In [None]:
#@title 5-2. Prepare trajectories for MM/GBSA
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
#@markdown Enter the chain ID(s) of the ligand and the receptor. Separate multiple chains with commas. <br/>
#@markdown If you enter only the ligand chain ID, every other chain is treated as the receptor.
ligand_chain = "C" #@param {type:"string"}
receptor_chain = "" #@param {type:"string"}

#@markdown To discard the starting fraction of the trajectory, set this value larger then 0.
skip_initial_fraction = 0 #@param {type:"slider", min:0, max:0.9, step:0.1}
#@markdown Analyze only every N-th frame.
stride = 1  #@param {type:"slider", min:1, max:20, step:1}

pdb = PDBFile(work+'starting_structure.pdb')

lig_a = []
rec_a = []
com_a = []
elements = []

lig_present = False
rec_present = False
lig_r = []
rec_r = []

for a in pdb.topology.atoms():
    if a.residue.chain.id in ligand_chain.split(','):
        lig_present = True
        lig_r.append(a.residue.index)
        lig_a.append(a.index)
    elif receptor_chain == "" or a.residue.chain.id in receptor_chain.split(','):
        rec_present = True
        rec_r.append(a.residue.index)
        rec_a.append(a.index)
    else:
        continue
    com_a.append(a.index)
    elements.append(a.element.symbol)

for e in set(elements):
    assert e in ['H', 'C', 'N', 'O', 'F', 'P', 'S', 'Si', 'Cl'], f'No bondi parameter found for element {e}!'
assert lig_present, f'No ligand molecule with the chain ID among {ligand_chain}!'
assert rec_present, f'No receptor molecule with the chain ID among {receptor_chain}!'


s = read_settings(work)
ff, wm = determine_forcefield(s['forcefield'], s['watermodel'])

nsres = read_non_standard_residues_list(work)
smffs = []
for res in nsres:
    if res['lig']:
        smffs.append(res['ff'])

esp_flag = False
if 'espaloma-0.3.2' in set(smffs):
    esp_flag = True

if esp_flag:
    print('You need to install the Full package!')
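
The chain-based split performed in cell 5-2 can be illustrated without OpenMM. The sketch below is a minimal, hypothetical helper (`split_by_chain` is not part of the notebook) that applies the same rule to a plain list of per-atom chain IDs: atoms in a ligand chain go to the ligand, and every remaining atom goes to the receptor unless a receptor chain list is given. Unlike the cell above, it also strips whitespace around the comma-separated IDs, which is an assumption made here for robustness.

```python
def split_by_chain(chain_ids, ligand_chain, receptor_chain=""):
    """Hypothetical helper: partition atom indices by chain ID.

    chain_ids      : list with one chain ID per atom, in atom order
    ligand_chain   : comma-separated ligand chain IDs, e.g. "C" or "C,D"
    receptor_chain : comma-separated receptor chain IDs; empty means
                     "every chain that is not a ligand chain"
    """
    lig_ids = {c.strip() for c in ligand_chain.split(",") if c.strip()}
    rec_ids = {c.strip() for c in receptor_chain.split(",") if c.strip()}

    lig, rec = [], []
    for index, chain_id in enumerate(chain_ids):
        if chain_id in lig_ids:
            lig.append(index)
        elif not rec_ids or chain_id in rec_ids:
            rec.append(index)
        # atoms in any other chain are left out of the analysis

    complex_atoms = sorted(lig + rec)
    return lig, rec, complex_atoms


# Synthetic example: seven atoms spread over chains A, B, C and D
atom_chains = ["A", "A", "B", "B", "C", "C", "D"]
lig_all, rec_all, com_all = split_by_chain(atom_chains, "C")
lig_sel, rec_sel, com_sel = split_by_chain(atom_chains, "C", "A")
print(lig_all, rec_all, com_all)
print(lig_sel, rec_sel, com_sel)
```

With an empty receptor field, chains A, B and D all count as receptor; naming chain A explicitly drops B and D from both the receptor and the complex.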

In [None]:
#@title 5-3. Running MM/GBSA
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert 'ff' in globals(), "Please run the cell# 5-2 first."

#@markdown Select a Generalized Born model to use.
GB_model = "GBn2 (igb=8)" #@param ["HCT (igb=1)", "OBC1 (igb=2)", "OBC2 (igb=5)", "GBn (igb=7)", "GBn2 (igb=8)"]
#@markdown Set the dielectric constants of solvent and solute.<br/>
#@markdown Recommended values are 78 for the solvent and 2~5 for the solute. <br/>
#@markdown You may increase/decrease the solute dielectric constant to tone down/up the interaction energy.
solvent_dielectric = 78 #@param {type:"slider", min:1, max:80, step:1}
solute_dielectric = 5 #@param {type:"slider", min:1, max:80, step:1}

assert solvent_dielectric > solute_dielectric, "Solvent dielectric constant must be larger than Solute's!"

gaff_flag = False
smirnoff_flag = False
esp_flag = False

gaff = []
smirnoff = []
esp = []
for res in nsres:
    if res['ff'] == 'GAFF-2.11':
        gaff_flag = True
        with open(f'{work}{res["name"]}.json', 'r') as f:
            gaff.append(Molecule.from_json(f.read()))
    elif res['ff'] == 'Sage-2.2.1':
        smirnoff_flag = True
        with open(f'{work}{res["name"]}.json', 'r') as f:
            smirnoff.append(Molecule.from_json(f.read()))
    elif res['ff'] == 'espaloma-0.3.2':
        esp_flag = True
        with open(f'{work}{res["name"]}.json', 'r') as f:
            esp.append(Molecule.from_json(f.read()))

if gaff_flag:
    g = GAFFTemplateGenerator(molecules=gaff)
    ff.registerTemplateGenerator(g.generator)
if smirnoff_flag:
    g = SMIRNOFFTemplateGenerator(molecules=smirnoff)
    ff.registerTemplateGenerator(g.generator)
if esp_flag:
    g = EspalomaTemplateGenerator(molecules=esp)
    ff.registerTemplateGenerator(g.generator)

def create_GBSA_system(trajectory, forcefield, fftype, eps_solv, eps_solut, kappa):

    top = trajectory.topology.to_openmm(trajectory)

    for chain in top.chains():
        a1 = None
        a2 = None
        for i, res in enumerate(chain.residues()):
            if i == 0 and res.name == 'PRO':
                for atom in res.atoms():
                    if atom.name == 'H2':
                        a2 = atom
                    elif atom.name == 'N':
                        a1 = atom
                    else:
                        continue
            else:
                break
        if a1 is not None and a2 is not None:
            top.addBond(a1, a2)

    system = forcefield.createSystem(top, nonbondedMethod=NoCutoff)

    charges = []

    for force in system.getForces():
        if isinstance(force, NonbondedForce):
            force.setForceGroup(1)
            force.addGlobalParameter(f'coul_scale', 1)
            force.addGlobalParameter(f'lj_scale', 1)
            for i in range(force.getNumParticles()):
                q, s, e = force.getParticleParameters(i)
                charges.append(q)
                force.setParticleParameters(i, 0, s, 0)
                force.addParticleParameterOffset(f'coul_scale', i, q, 0, 0)
                force.addParticleParameterOffset(f'lj_scale', i, 0, 0, e)
            for i in range(force.getNumExceptions()):
                p1, p2, q_p, s, e = force.getExceptionParameters(i)
                force.setExceptionParameters(i, p1, p2, 0, s, 0)

        elif isinstance(force, CustomNonbondedForce):
            # CustomNonbondedForce has exclusions, not exceptions, so it has no
            # exception parameters to reset; only the force group is assigned.
            force.setForceGroup(2)
        else:
            force.setForceGroup(0)

    args = {'soluteDielectric': eps_solut, 'solventDielectric': eps_solv,
            'kappa': kappa, 'SA':'ACE'}

    if fftype == "HCT (igb=1)":
        gb = customgbforces.GBSAHCTForce(**args)
        params = customgbforces.GBSAHCTForce.getStandardParameters(top)
    elif fftype == "OBC1 (igb=2)":
        gb = customgbforces.GBSAOBC1Force(**args)
        params = customgbforces.GBSAOBC1Force.getStandardParameters(top)
    elif fftype == "OBC2 (igb=5)":
        gb = customgbforces.GBSAOBC2Force(**args)
        params = customgbforces.GBSAOBC2Force.getStandardParameters(top)
    elif fftype == "GBn (igb=7)":
        gb = customgbforces.GBSAGBnForce(**args)
        params = customgbforces.GBSAGBnForce.getStandardParameters(top)
    elif fftype == "GBn2 (igb=8)":
        gb = customgbforces.GBSAGBn2Force(**args)
        params = customgbforces.GBSAGBn2Force.getStandardParameters(top)
    else:
        raise ValueError(f'Unknown GB model: {fftype}')

    gb.setForceGroup(3)
    for i, p in enumerate(params):
        if fftype in ['GBn (igb=7)', 'GBn2 (igb=8)']:
            gb.addParticle([charges[i], p[0], p[1], p[2], p[3], p[4]])
        else:
            gb.addParticle([charges[i], p[0], p[1]])

    gb.finalize()
    gb.setNonbondedMethod(CustomNonbondedForce.NoCutoff)
    system.addForce(gb)

    return system


def calculate_energies(work, name, trajectory, forcefield, fftype, eps_solv, eps_solut):

    print(f'Running analysis on {name} system :')

    s = read_settings(work)
    temperature = s['temperature']/kelvin
    ion_conc = s['ionic strength']/molar
    kappa = 367.434915*sqrt(float(ion_conc)/78.5/(temperature))

    system = create_GBSA_system(trajectory, forcefield, fftype, eps_solv, eps_solut, kappa)
    integrator = VerletIntegrator(1*femtosecond)
    context = Context(system, integrator)

    coulombic = []
    LJ = []
    GBSA = []

    for i in tqdm(range(trajectory.n_frames)):
        context.setPositions(trajectory.openmm_positions(i))
        context.setParameter('coul_scale', eps_solut**(-1/2))
        context.setParameter('lj_scale', 0)
        coulombic.append(context.getState(getEnergy=True, groups={1}).getPotentialEnergy().value_in_unit(kilojoules_per_mole))
        context.setParameter('coul_scale', 0)
        context.setParameter('lj_scale', 1/eps_solut)
        LJ.append(context.getState(getEnergy=True, groups={1,2}).getPotentialEnergy().value_in_unit(kilojoules_per_mole))
        GBSA.append(context.getState(getEnergy=True, groups={3}).getPotentialEnergy().value_in_unit(kilojoules_per_mole))

    return np.array(coulombic), np.array(LJ), np.array(GBSA)


def derive_G_K(dH, dH_std, dS, dS_std, temperature):
    dG = dH - temperature*dS/1000
    dG_std = np.sqrt(dH_std**2+(dS_std*temperature/1000)**2)
    Kd = np.exp(dG*1000/8.31446261815324/temperature)
    Kd_std = Kd*1000/8.31446261815324/temperature*dG_std
    return dG, dG_std, Kd, Kd_std


rough_traj = md.load(work+'production.h5')
start_point = int(skip_initial_fraction*rough_traj.n_frames)

lig_traj = rough_traj[start_point::stride].atom_slice(atom_indices=lig_a)
rec_traj = rough_traj[start_point::stride].atom_slice(atom_indices=rec_a)
com_traj = rough_traj[start_point::stride].atom_slice(atom_indices=com_a)

lig_coul, lig_lj, lig_gb = calculate_energies(work, 'ligand', lig_traj, ff, GB_model, solvent_dielectric, solute_dielectric)
rec_coul, rec_lj, rec_gb = calculate_energies(work, 'receptor', rec_traj, ff, GB_model, solvent_dielectric, solute_dielectric)
com_coul, com_lj, com_gb = calculate_energies(work, 'complex', com_traj, ff, GB_model, solvent_dielectric, solute_dielectric)

couls = com_coul - lig_coul - rec_coul
ljs = com_lj - lig_lj - rec_lj
gbsas = com_gb - lig_gb - rec_gb
dHs = couls + ljs + gbsas
dH = np.mean(dHs)

fig15 = go.Figure()
fig15.add_trace(go.Scatter(x=lig_traj.time/1000, y=dHs, mode='lines', name='delta H'))
fig15.add_trace(go.Scatter(x=lig_traj.time/1000, y=couls, mode='lines', name='Coulombic E'))
fig15.add_trace(go.Scatter(x=lig_traj.time/1000, y=ljs, mode='lines', name='Lennard-Jones E'))
fig15.add_trace(go.Scatter(x=lig_traj.time/1000, y=gbsas, mode='lines', name='Solvation E'))
fig15.update_layout(title='Binding Energy', xaxis_title='Time (ns)', yaxis_title='Energy (kJ/mol)', width=750, height=500)

dE = np.mean(couls+ljs)
dE_std = np.std(couls+ljs)
dH_std = np.sqrt(np.var(couls+ljs)+np.var(gbsas))

fig16 = go.Figure()
x = ['ΔE', 'polar', 'nonpolar', 'GBSA', 'ΔH']
y = [dE, np.mean(couls), np.mean(ljs), np.mean(gbsas), dH]
colors = ['#1f77b4', '#17becf', '#17becf', '#9467bd', '#d62728']
err = dict(type='data', array=[dE_std, np.std(couls), np.std(ljs), np.std(gbsas), dH_std])
fig16.add_trace(go.Bar(x=x, y=y, error_y=err, marker_color=colors))
fig16.update_layout(title="Energy Contributions", xaxis_title='Type', yaxis_title='Energy (kJ/mol)', width=750, height=500)

display(fig15, fig16)

print(f'Interaction Energy    : {dE:>+7.2f} +/- {dE_std:>5.2f} kJ/mol')
print(f'Polar contribution    : {np.mean(couls):>+7.2f} +/- {np.std(couls):>5.2f} kJ/mol')
print(f'Nonpolar contribution : {np.mean(ljs):>+7.2f} +/- {np.std(ljs):>5.2f} kJ/mol')
print(f'GBSA Solvation Energy : {np.mean(gbsas):>+7.2f} +/- {np.std(gbsas):>5.2f} kJ/mol')
print(f'Binding Enthalpy      : {dH:>+7.2f} +/- {dH_std:>5.2f} kJ/mol')
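
The `derive_G_K` helper defined in cell 5-3 converts a binding free energy into a dissociation constant through Kd = exp(ΔG/RT) and propagates the uncertainty to first order. The standalone sketch below repeats only that conversion with made-up input values (ΔG = -40 ± 2 kJ/mol at 300 K); the function name `free_energy_to_kd` is hypothetical and exists only for this illustration.

```python
import numpy as np

R = 8.31446261815324  # gas constant in J/mol/K


def free_energy_to_kd(dG, dG_std, temperature):
    """Hypothetical helper: convert dG (kJ/mol) to Kd (mol/L).

    The uncertainty uses first-order propagation:
    sigma_Kd = Kd * sigma_dG / (R*T), with dG converted to J/mol.
    """
    kd = np.exp(dG * 1000 / (R * temperature))
    kd_std = kd * 1000 / (R * temperature) * dG_std
    return kd, kd_std


# Illustrative numbers only, not results from any simulation
kd, kd_std = free_energy_to_kd(-40.0, 2.0, 300.0)
print(f"Kd = {kd:.2e} +/- {kd_std:.2e} M")
print(f"relative uncertainty = {kd_std / kd:.0%}")
```

In this example Kd comes out at roughly 1e-7 M, and an uncertainty of only 2 kJ/mol in ΔG already corresponds to about 80% relative uncertainty in Kd. Because Kd depends exponentially on ΔG, the reported Kd error bars are best read as an order-of-magnitude guide.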

In [None]:
#@title 5-4. Calculate Binding Free Energy from Interaction Entropy & C2 Entropy
assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert 'couls' in globals(), "Please run the cell# 5-3 first."

#@markdown Set the bootstrap sampling size to 10^n.
bootstrap_power = 4 #@param {type:"slider", min:3, max:8, step:1}

def calculate_interaction_entropy(interaction_energies, temperature):
    """
    Calculate the interaction entropy using the method from Duan et al.
    https://doi.org/10.1021/jacs.6b02682
    """
    R = 8.31446261815324  # Gas constant in J/mol/K
    beta = 1 / (R * temperature)
    avg_energy = np.mean(interaction_energies)
    delta_e = interaction_energies - avg_energy
    exp_beta_delta_e = np.exp(beta * delta_e *1000) # kJ/mol to J/mol

    interaction_entropy = - R * np.log(np.mean(exp_beta_delta_e))
    return interaction_entropy

def calculate_c2_entropy(interaction_energies, temperature):
    """
    Calculate the C2 entropy using the method from Menzer et al.
    https://doi.org/10.1021/acs.jctc.8b00418
    """
    R = 8.31446261815324  # Gas constant in J/mol/K
    beta = 1 / (R * temperature)
    variance = np.var(interaction_energies*1000) # kJ/mol to J/mol

    c2_entropy = -0.5 * beta * variance / temperature
    return c2_entropy


def bootstrap_entropy_error(interaction_energies, temperature, n_bootstrap):

    n_samples = len(interaction_energies)
    interaction_entropies = []
    c2_entropies = []

    print(f'Running with n_bootstrap : {n_bootstrap}...')

    for _ in tqdm(range(n_bootstrap)):
        bootstrap_sample = np.random.choice(interaction_energies, size=n_samples, replace=True)

        # Calculate entropies for this sample
        interaction_entropy = calculate_interaction_entropy(bootstrap_sample, temperature)
        c2_entropy = calculate_c2_entropy(bootstrap_sample, temperature)

        interaction_entropies.append(interaction_entropy)
        c2_entropies.append(c2_entropy)

    # Calculate mean and standard error
    IE = np.mean(interaction_entropies)
    IE_std = np.std(interaction_entropies, ddof=1)

    c2 = np.mean(c2_entropies)
    c2_std = np.std(c2_entropies, ddof=1)

    return IE, IE_std, c2, c2_std

temperature = s['temperature']/kelvin

# set from https://doi.org/10.1021/acs.jctc.1c00374
if dE_std > 15:
    print('Warning : Interaction Entropy is likely unreliable!')
if dE_std > 25:
    print('Warning : C2 Entropy is likely unreliable!')

IE, IE_std, c2, c2_std = bootstrap_entropy_error(couls+ljs, temperature, 10**bootstrap_power)

print(f'Interaction Entropy : {IE:>+7.2f} +/- {IE_std:>3.2f} J/mol/K')
print(f'C2 Entropy          : {c2:>+7.2f} +/- {c2_std:>3.2f} J/mol/K')

dG_IE, dG_IE_std, Kd_IE, Kd_IE_std = derive_G_K(dH, dH_std, IE, IE_std, temperature)
dG_c2, dG_c2_std, Kd_c2, Kd_c2_std = derive_G_K(dH, dH_std, c2, c2_std, temperature)

print(f'Binding Free Energy (ΔG) from IE : {dG_IE:>+7.2f} +/- {dG_IE_std:>3.2f} kJ/mol')
print(f'Binding Free Energy (ΔG) from C2 : {dG_c2:>+7.2f} +/- {dG_c2_std:>3.2f} kJ/mol')
print(f'Dissociation Constant (Kd) from IE: {Kd_IE:>5.2e} +/- {Kd_IE_std:>3.2e} M')
print(f'Dissociation Constant (Kd) from C2: {Kd_c2:>5.2e} +/- {Kd_c2_std:>3.2e} M')


fig17 = go.Figure()
fig17.add_trace(go.Bar(name='IE', x=['ΔH', '-TΔS', 'ΔG'], y=[dH, -temperature*IE/1000, dG_IE],
                        error_y=dict(type='data', array=[dH_std, temperature*IE_std/1000, dG_IE_std])))
fig17.add_trace(go.Bar(name='c2', x=['ΔH', '-TΔS', 'ΔG'], y=[dH, -temperature*c2/1000, dG_c2],
                        error_y=dict(type='data', array=[dH_std, temperature*c2_std/1000, dG_c2_std])))
fig17.update_layout(title='Binding Free Energy', xaxis_title='Components', yaxis_title='Energy (kJ/mol)', barmode='group', width=750, height=500)
display(fig17)
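
The two entropy estimators in cell 5-4 are closely related: the C2 expression is the second-order cumulant truncation of the interaction-entropy average, so for Gaussian-distributed interaction energies both should give nearly the same value. The sketch below checks this on synthetic data. The energies are random numbers (mean -150 kJ/mol, standard deviation 3 kJ/mol), not simulation output, and the function names are shortened stand-ins for the ones defined above.

```python
import numpy as np

R = 8.31446261815324  # gas constant in J/mol/K


def interaction_entropy(energies_kj, temperature):
    """Interaction entropy in J/mol/K from energies in kJ/mol."""
    beta = 1.0 / (R * temperature)
    delta = (energies_kj - np.mean(energies_kj)) * 1000.0  # kJ/mol to J/mol
    return -R * np.log(np.mean(np.exp(beta * delta)))


def c2_entropy(energies_kj, temperature):
    """C2 entropy in J/mol/K from energies in kJ/mol."""
    beta = 1.0 / (R * temperature)
    variance = np.var(energies_kj * 1000.0)  # (J/mol)^2
    return -0.5 * beta * variance / temperature


# Synthetic Gaussian interaction energies with a fixed seed
rng = np.random.default_rng(0)
energies = rng.normal(loc=-150.0, scale=3.0, size=100_000)

ie = interaction_entropy(energies, 300.0)
c2 = c2_entropy(energies, 300.0)
print(f"IE = {ie:+.2f} J/mol/K, C2 = {c2:+.2f} J/mol/K")
```

For a standard deviation of 3 kJ/mol at 300 K the analytic C2 value is about -6.0 J/mol/K. The exponential average in the interaction entropy is dominated by rare high-energy frames, so it converges more slowly as the energy spread grows, which is the motivation for the `dE_std` warnings printed in the cell above.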

In [None]:
#@title 5-5. (Optional) Calculate Binding Entropy via Quasi-harmonic Approximation
#@markdown **WARNING: The QH method takes a very long time for large systems!!**

assert 'work' in globals(), "Please run the cell# 0 to set working directory."
assert 'temperature' in globals(), "Please run the cell# 5-4 first."

#@markdown Set the reference concentration in mol/L for a standard solution of ligand.
reference_concentration = 1 #@param {type:"string"}
#@markdown Set the bootstrap sampling size to 10^n.
bootstrap_power = 3 #@param {type:"slider", min:3, max:10, step:1}

def estimate_rot_entropy_loss(lig_traj, temperature):
    print("Estimating the ligand's rotational entropy loss upon binding...", end='')

    R = 8.31446261815324  # Gas constant in J/mol/K
    N_A = 6.022e23  # Avogadro's number
    h = 6.62607004e-34  # Planck's constant in J*s

    inertia_tensor = md.compute_inertia_tensor(lig_traj)
    moment, axes = np.linalg.eigh(inertia_tensor)
    quaternions = Rotation.from_matrix(axes).as_quat()

    kde_rot = KernelDensity(kernel='gaussian', bandwidth='scott').fit(quaternions)
    S_bound = -R * kde_rot.score_samples(quaternions)
    S_bound += 0.5*R*np.log(np.prod(moment.mean(axis=0)))

    principal_axes = []
    for i in range(lig_traj.n_frames):
        # Use the principal moments (eigenvalues); the raw diagonal depends on the frame orientation
        I_A, I_B, I_C = moment[i]
        principal_axes.append(I_A*I_B*I_C)

    S_free = []
    for i in range(lig_traj.n_frames):
        S_free.append(R*(1.5*np.log(temperature)+np.log(principal_axes[i])+1.5*np.log(8*np.pi**2 * R/N_A/h**2)+1.5))

    dS = S_bound - np.array(S_free)

    print("done.")

    return np.mean(dS), np.std(dS)


def estimate_trans_entropy_loss(lig_traj, rec_traj, temperature, reference_concentration):
    R = 8.31446261815324  # Gas constant in J/mol/K
    N_A = 6.022e23  # Avogadro's number

    print(f'Reference concentration set to {reference_concentration} mol/L.')
    print("Estimating the ligand's translational entropy loss upon binding...", end='')
    # Calculate free volume based on reference concentration
    ref_v  = 1000 * 1e21/ (reference_concentration * N_A)

    # Radius of gyration for all frames, computed once rather than once per frame
    free_volume = ref_v - 4/3*np.pi*md.compute_rg(lig_traj)**3

    lig_com = md.compute_center_of_mass(lig_traj)
    rec_com = md.compute_center_of_mass(rec_traj)
    bound_volume = 4/3 * np.pi
    for i in range(3):
        bound_volume *= np.var((lig_com-rec_com)[:,i])

    dS_trans = R * np.log(np.full(lig_traj.n_frames, bound_volume) / free_volume)

    print("done.")

    return np.mean(dS_trans), np.std(dS_trans)


def autocorrelation_time(trajectory):
    n_frames, n_atoms, _ = trajectory.shape
    mean_positions = np.mean(trajectory, axis=0)

    displacements = trajectory - mean_positions
    sd = np.sum(displacements**2, axis=(1, 2))

    msd = np.mean(sd)

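    y = sd - msd
    acorr = np.correlate(y, y, mode='full')[n_frames-1:]
    acorr /= acorr[0]
    # Sum only up to the first non-positive value: the untruncated sum of a
    # mean-subtracted series is always 0.5, which would make every lag equal 1.
    nonpositive = np.flatnonzero(acorr <= 0)
    cutoff = nonpositive[0] if nonpositive.size else n_frames
    return np.sum(acorr[:cutoff])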

def extract_independent_frames(coordinates):
    """
    Extract independent frames based on dynamic autocorrelation times.
    """
    n_frames, n_atoms, n_dims = coordinates.shape
    independent_indices = []
    last_index = 0

    while last_index < n_frames:
        independent_indices.append(last_index)
        remaining_data = coordinates[last_index:,:,:]
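        lag = 1  # fallback stride when too little data remains
        if remaining_data.shape[0] > 1:  # Ensure we have enough data to calculate autocorrelation
            try:
                # Guard against a zero or negative lag, which would stall the loop
                lag = max(1, int(np.ceil(autocorrelation_time(remaining_data))))
            except Exception:
                break

        last_index += lag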

    return coordinates[independent_indices]

def bootstrap_sample(data, n_samples):
    indices = np.random.choice(data.shape[0], size=n_samples, replace=True)
    return data[indices,:,:]

def quasi_harmonic(name, coordinates, n_bootstrap):
    R = 8.31446261815324  # Gas constant in J/mol/K

    print(f"Calculating {name}'s vibrational entropy by quasi-harmonic approach...")
    independent_coords = extract_independent_frames(coordinates)
    print(f"Total {independent_coords.shape[0]} independent frames extracted.")
    bootstrapped_coords = bootstrap_sample(independent_coords, n_bootstrap)
    reshaped_coords = bootstrapped_coords.reshape(bootstrapped_coords.shape[0], -1)
    print(f'Calculating covariance matrix from bootstrapped samples...', end='')
    cov_estimator = EmpiricalCovariance()
    cov_estimator.fit(reshaped_coords)
    cov_matrix = cov_estimator.covariance_

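    # slogdet avoids the overflow/underflow that det() hits for high-dimensional covariance matrices
    _, logdet = np.linalg.slogdet(2 * np.pi * np.e * cov_matrix)
    S_corr = 0.5 * logdet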

    print('done.')

    return R * S_corr


n_bootstrap = 10**bootstrap_power
print(f'n_bootstrap set to {n_bootstrap}.')

dS_rot, dS_rot_std = estimate_rot_entropy_loss(lig_traj, temperature)
dS_trans, dS_trans_std = estimate_trans_entropy_loss(lig_traj, rec_traj, temperature, float(reference_concentration))

qh_lig_S = quasi_harmonic('ligand', lig_traj.xyz, n_bootstrap)
qh_rec_S = quasi_harmonic('receptor', rec_traj.xyz, n_bootstrap)
qh_com_S = quasi_harmonic('complex', com_traj.xyz, n_bootstrap)

qh_dS_vib = qh_com_S - qh_lig_S - qh_rec_S

qh = qh_dS_vib + dS_rot + dS_trans
qh_std = np.sqrt(dS_rot_std**2 + dS_trans_std**2)

print(f'Estimated Binding Entropy : {qh:.2f} +/- {qh_std:.2f} J/mol/K')
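
The quasi-harmonic estimate in cell 5-5 rests on the entropy of a multivariate Gaussian, S/R = ½ ln det(2πeΣ), where Σ is the coordinate covariance matrix. With hundreds of Cartesian coordinates the determinant itself leaves the range of double precision, so the logarithm has to be evaluated without forming the determinant. The sketch below shows this on a synthetic diagonal covariance matrix (600 coordinates, each with a variance of 1e-4 nm²); these numbers are assumptions chosen only to trigger the underflow.

```python
import numpy as np

n_dims = 600      # e.g. 200 atoms x 3 Cartesian coordinates (synthetic)
variance = 1e-4   # nm^2 per coordinate, illustrative value only

cov = variance * np.eye(n_dims)
scaled = 2 * np.pi * np.e * cov

# Direct determinant: a product of 600 small numbers underflows to zero
det = np.linalg.det(scaled)

# Log-determinant computed from the factorization, without forming the product
sign, logdet = np.linalg.slogdet(scaled)

# For a diagonal matrix the log-determinant is the sum of log-diagonals
analytic = n_dims * np.log(2 * np.pi * np.e * variance)

entropy_over_R = 0.5 * logdet
print(f"det = {det}, sign = {sign}, logdet = {logdet:.3f}")
print(f"S/R = {entropy_over_R:.3f}")
```

A covariance matrix estimated from fewer frames than coordinates is singular, in which case the log-determinant is negative infinity regardless of how it is computed; the number of independent frames therefore limits how large a system this estimate can handle.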


## Tools used

* REDUCE
> J. M. Word, S. C. Lovell, J. S. Richardson, and D. C. Richardson. (1999) "Asparagine and glutamine: using hydrogen atom contacts in the choice of side-chain amide orientation" J. Mol. Biol. 285(4): 1735-1747. DOI:[10.1006/jmbi.1998.2401](https://doi.org/10.1006/jmbi.1998.2401)

* RDKit
> RDKit: Open-source cheminformatics. https://www.rdkit.org

* OpenBabel
> N. M. O'Boyle, M. Banck, C. A. James, C. Morley, T. Vandermeersch, and G. R. Hutchison (2011) "Open Babel: An open chemical toolbox" J. Cheminf. 3:33. DOI:[10.1186/1758-2946-3-33](https://doi.org/10.1186/1758-2946-3-33)

* PDB2PQR
> T. J. Dolinsky, J. E. Nielsen, J. A. McCammon, and N. A. Baker. (2004) "PDB2PQR: an automated pipeline for the setup, execution, and analysis of Poisson-Boltzmann electrostatics calculations." Nucleic Acids Res. 32: W665-667. DOI:[10.1093/nar/gkh381](https://doi.org/10.1093/nar/gkh381)

* PropKa
> H. Li, A. D. Robertson, and J. H. Jensen. (2005) "Very Fast Empirical Prediction and Rationalization of Protein pKa Values." Proteins, 61: 704-721. DOI:[10.1002/prot.20660](https://doi.org/10.1002/prot.20660)

* OpenMM
> P. Eastman, J. Swails, J. D. Chodera, R. T. McGibbon, Y. Zhao, K. A. Beauchamp, L.-P. Wang, A. C. Simmonett, M. P. Harrigan, C. D. Stern, R. P. Wiewiora, B. R. Brooks, and V. S. Pande. (2017) "OpenMM 7: Rapid development of high performance algorithms for molecular dynamics." PLOS Comp. Biol. 13(7): e1005659. DOI:[10.1371/journal.pcbi.1005659](https://doi.org/10.1371/journal.pcbi.1005659)

* NGLViewer
> H. Nguyen, D. A. Case, and A. S. Rose. (2018) "NGLview - Interactive molecular graphics for Jupyter notebooks" Bioinformatics 34(7): 1241-1242. DOI:[10.1093/bioinformatics/btx789](https://doi.org/10.1093/bioinformatics/btx789)

* MDTraj
> R. T. McGibbon, K. A. Beauchamp, M. P. Harrigan, C. Klein, J. M. Swails, C. X. Hernández, C. R. Schwantes, L-P. Wang, T. J. Lane, and V. S. Pande (2015) "MDTraj: A Modern Open Library for the Analysis of Molecular Dynamics Trajectories" Biophys. J. 109(8): 1528-1532. DOI:[10.1016/j.bpj.2015.08.015](https://doi.org/10.1016/j.bpj.2015.08.015)

* NumPy
> C. R. Harris, K. J. Millman, S. J. van der Walt et al. (2020) "Array programming with NumPy" Nature 585, 357-362. DOI:[10.1038/s41586-020-2649-2](https://doi.org/10.1038/s41586-020-2649-2)

* Scikit-learn
> F. Pedregosa, G. Varoquaux, A. Gramfort, V. Michel, B. Thirion, O. Grisel, M. Blondel, P. Prettenhofer, R. Weiss, V. Dubourg, J. Vanderplas, A. Passos, David Cournapeau, M. Brucher, M. Perrot, and E. Duchesnay (2011) "Scikit-learn: Machine Learning in Python" JMLR 12(85): 2825-2830. (https://jmlr.csail.mit.edu/papers/v12/pedregosa11a.html)

* Plotly
> Plotly Technologies Inc. (https://plot.ly)

* ipywidgets
> Jupyter widgets community (https://github.com/jupyter-widgets/ipywidgets)

* PLIP
> S. Salentin, S. Schreiber, V. J. Haupt, M. F. Adasme, and M. Schroeder (2015) "PLIP: fully automated protein-ligand interaction profiler" Nucleic Acids Research 43(W1): W443-W447. DOI:[10.1093/nar/gkv315](https://doi.org/10.1093/nar/gkv315)