In [1]:
from pathlib import Path
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Union
from IPython.display import clear_output

from rdkit import Chem
from openff.units import unit
from openff.interchange import Interchange

from openff.toolkit.topology import Topology
from openff.toolkit.topology.molecule import FrozenMolecule, Molecule, Atom
from openff.toolkit.utils import toolkit_registry
from openff.toolkit.utils.toolkits import RDKitToolkitWrapper, OpenEyeToolkitWrapper, AmberToolsToolkitWrapper
from openff.toolkit.typing.engines.smirnoff import ForceField
from openff.toolkit.typing.engines.smirnoff import parameters as offtk_parameters

from openmm import LangevinMiddleIntegrator
from openmm.app import Simulation, PDBReporter, StateDataReporter
from openmm.unit import kelvin, picosecond, picoseconds, nanometer # need to do some unit conversion with both packages

## Useful functions for charge averaging and simulation

In [7]:
# Charge calculation methods
def generate_charged_molecule(pdbfile : str, substructure_file : Path, toolkit_method : str='openeye', partial_charge_method : str='am1bcc') -> tuple[Molecule, Topology]:
    '''Takes a .pbd structural file and a .json monomer file and produces a Molecule object with calculated partial charges using method of choice'''
    toolkits = {
        'openeye' : OpenEyeToolkitWrapper,
        'ambertools' : AmberToolsToolkitWrapper
    }

    off_topology, _, error = Topology.from_pdb_and_monomer_info(pdbfile, substructure_file, strict=True)
    # here, we assume that the topology only has ONE simple homopolymer. Later, all molecules can be extracted and charged
    mol = next(off_topology.molecules) # get the first molecule
    # get some conformers to run elf10 charge method. By default, `mol.assign_partial_charges`
    # uses 500 conformers, but we can generate and use 10 here for demonstration
    # mol.generate_conformers(
    #     n_conformers=10,
    #     rms_cutoff=0.25 * unit.angstrom,
    #     make_carboxylic_acids_cis=True,
    #     toolkit_registry=RDKitToolkitWrapper()
    # ) # very slow for large polymers! 

    mol.assign_partial_charges( # finally, assign partial charges using those 10 conformers generated 
        partial_charge_method=partial_charge_method, 
        toolkit_registry=toolkits.get(toolkit_method)()
    )
    
    return mol, off_topology # code for exact how thely above function works can be found in openff/toolkit/utils/openeye_wrapper.py under the assign_partial_charges()

def fetch_charged_mol(filename : str, parent_path : Path=Path.cwd()/'compatible_pdbs', extensions=('pdb', 'json'),
                      toolkit_method : str='openeye', partial_charge_method : str='am1bcc') -> tuple[Molecule, Topology]:
    '''Takes the name of a molecule and searches for associated .pbd and .json files
    If found, will perform charge assignment and monomer labelling and return the resultant charged Molecule object'''
    mol_files = {
        ext : path
            for path in parent_path.glob('**/*.*')
                for ext in extensions
                    if path.name == f'{filename}.{ext}'
    }

    for ext in extensions:
        if ext not in mol_files:
            raise FileNotFoundError(f'Could not find a(n) {ext} file \"{filename}.{ext}\"')
    else:
        charged_mol, topology = generate_charged_molecule(str(mol_files['pdb']), mol_files['json'], toolkit_method=toolkit_method)
        print(f'final molecular charges: {charged_mol.partial_charges}')

        # note: the charged_mol has metadata about which monomers were assigned where as a result of the chemicaly info assignment.
        # This can be a way to break up the molecule into repeating sections to partition the library charges 
        for atom in charged_mol.atoms:
            assert(atom.metadata['already_matched'] == True)
            # print(atom.metadata['residue_name'])
        
        return charged_mol, topology

# charge averaging methods
@dataclass
class Accumulator:
    '''Compact container for accumulating averages'''
    sum : float = 0.0
    count : int = 0

    @property
    def average(self) -> float:
        return self.sum / self.count

AveragedChargeMap = defaultdict[str, dict[int, float]] # makes typehinting clearer

def find_repr_residues(cmol : Molecule) -> dict[str, int]:
    '''Determine names and smallest residue numbers of all unique residues in charged molecule
    Used as representatives for generating labelled SMARTS strings '''
    rep_res_nums = defaultdict(set) # numbers of representative groups for each unique residue, used to build SMARTS strings
    for atom in cmol.atoms: 
        rep_res_nums[atom.metadata['residue_name']].add(atom.metadata['residue_number']) # collect unique residue numbers

    for res_name, ids in rep_res_nums.items():
        rep_res_nums[res_name] = min(ids) # choose group with smallest id of each residue to denote representative group

    return rep_res_nums

def averaged_charges_by_SMARTS(cmol : Molecule) -> AveragedChargeMap:
    '''Takes a charged molecule and averages charges for each repeating residues
    Returns a dict (indexed by SMARTS strings) of subdicts containing averaged charges for each atom in a residue'''
    rdmol = cmol.to_rdkit() # create rdkit representation of Molecule to allow for SMARTS generation
    rep_res_nums = find_repr_residues(cmol) # determine ids of representatives of each unique residue

    atom_ids_for_SMARTS = defaultdict(list)
    avg_charges_by_res = defaultdict(lambda : defaultdict(Accumulator))
    for atom in cmol.atoms: # accumulate counts and charge values across matching substructures
        res_name, substruct_id, atom_id = atom.metadata['residue_name'], atom.metadata['substructure_id'], atom.metadata['pdb_atom_id']
        if atom.metadata['residue_number'] == rep_res_nums[res_name]: # if atom is member of representative group for any residue...
            atom_ids_for_SMARTS[res_name].append(atom_id)             # ...collect pdb id...
            rdmol.GetAtomWithIdx(atom_id).SetAtomMapNum(substruct_id) # ...and set atom number for labelling in SMARTS string

        curr_accum = avg_charges_by_res[res_name][substruct_id] # accumulate charge info for averaging
        curr_accum.sum += atom.partial_charge.magnitude # eschew units (easier to handle, added back when writing to XML)
        curr_accum.count += 1

    avg_charges_by_SMARTS = defaultdict(dict)
    for res_name, charge_map in avg_charges_by_res.items():
        SMARTS = Chem.rdmolfiles.MolFragmentToSmarts(rdmol, atomsToUse=atom_ids_for_SMARTS[res_name]) # determine SMARTS for the current residue's representative group
        for substruct_id, accum in charge_map.items():
            avg_charges_by_SMARTS[SMARTS][substruct_id] = accum.average # collapse accumulators into actual average values

    return avg_charges_by_SMARTS

def write_new_library_charges(mol_name : str, offxml_file : Path, output_name : str, toolkit_method : str='openeye',
                              partial_charge_method : str='am1bcc') -> tuple[Molecule, AveragedChargeMap, Interchange]:
    '''Loads a molecule, calculates partial charges, partitions molecule into monomer residues, computes average charges for atoms in distinct residues,
     and appends library charges based on these averaged charges to a .offxml file of choice, creating a new xml with the sspecified filename'''
    cmol, topology = fetch_charged_mol(mol_name, toolkit_method=toolkit_method, partial_charge_method=partial_charge_method) # will raise exception if files for molecule are not found
    clear_output() # for Jupyter notebooks only, can freely comment this out
    avgs = averaged_charges_by_SMARTS(cmol) # average charges over unique residues

    forcefield = ForceField(offxml_file) # simpler to add library charges through forcefield API than to directly write to xml
    lc_handler = forcefield["LibraryCharges"]

    for smirks, charges in avgs.items():
        lc_entry = {f'charge{cid}' : f'{charge} * elementary_charge' for cid, charge in charges.items()} # stringify charges into form usable for library charges
        lc_entry['smirks'] = smirks # add SMIRKS string to library charge entry to allow for labelling

        lc_params = offtk_parameters.LibraryChargeHandler.LibraryChargeType(allow_cosmetic_attributes=True, **lc_entry) # must enable cosmetic params for general kwarg passing
        lc_handler.add_parameter(parameter=lc_params)

    interchange = Interchange.from_smirnoff(force_field=forcefield, topology=topology) # generate Interchange with new library charges prior to writing to file
    forcefield.to_file(offxml_file.parent/f'{output_name}.offxml') # write modified library charges to new xml (avoid overwrites in case of mistakes)
    
    return cmol, avgs, interchange

# OpenMM simulation methods
def create_sim_from_interchange(interchange : Interchange) -> Simulation:
    '''Sets up a Simulation object using topology and force field data as specified by an Interchange object
    Converts topologies and positions to OpenMM format from OpenFF formats (can support GROMACS format too in future)'''
    openmm_sys = interchange.to_openmm(combine_nonbonded_forces=True) 
    openmm_top = interchange.topology.to_openmm()
    openmm_pos = interchange.positions.m_as(unit.nanometer) * nanometer
    integrator = LangevinMiddleIntegrator(300*kelvin, 1/picosecond, 0.0005*picoseconds)

    simulation = Simulation(openmm_top, openmm_sys, integrator)
    simulation.context.setPositions(openmm_pos)

    return simulation

def run_simulation(simulation : Simulation, num_steps=1000, record_freq=10, output_name='md_output') -> None:
    '''Takes a Simulation object, performs energy minimization, and runs the simulation for the specified number of time steps,
    recording PBD frames and numerical data to file at the specified frequency using '''
    outdir = Path(output_name)
    outdir.mkdir(exist_ok=True) # ensure a folder for output exists

    # for saving pdb frames and reporting state/energy data
    pdb_rep = PDBReporter(f'{output_name}/{output_name}_frames.pdb', record_freq)
    state_rep = StateDataReporter(f'{output_name}/{output_name}_data.csv', record_freq, step=True, potentialEnergy=True, temperature=True)

    # minimize and run simulation
    simulation.minimizeEnergy()
    simulation.saveCheckpoint(f'{output_name}/{output_name}_checkpoint.chk') # save initial minimal state to simplify reloading process
    simulation.reporters.append(pdb_rep) # save frames at the specified interval
    simulation.reporters.append(state_rep)
    simulation.step(num_steps)

## Running averaging code for test molecule

In [4]:
mol_name = 'polymethylketone'
offxml_file = Path('xml examples/openff_unconstrained_with_library_charges-2.0.0.offxml')

cmol, avgs, interchange = write_new_library_charges(mol_name, offxml_file, output_name=f'new {mol_name} charges')
for smiles, subdict in avgs.items():
    print(f'{smiles}\n\t{subdict}\n')

# TOSELF: add method to load from newly generated file rather than regenerating/re-averaging each time
sim = create_sim_from_interchange(interchange)
run_simulation(sim, output_name=f'{mol_name}_sim')

[#8](-[#6:1]1:[#6:2](-[H:3]):[#6:4](-[H:5]):[#6:6](-[#6:7]23-[#6:8](-[H:9])(-[H:10])-[#6:11]4(-[H:12])-[#6:23](-[H:24])(-[H:25])-[#6:21](-[H:22])(-[#6:18](-[H:19])(-[H:20])-[#6:16](-[H:17])(-[#6:13]-4(-[H:14])-[H:15])-[#6:29]-2(-[H:30])-[H:31])-[#6:26]-3(-[H:27])-[H:28]):[#6:32](-[#8:33]):[#6:34]:1-[H:35])-[#6:36]1:[#6:44](-[H:45]):[#6:42](-[H:43]):[#6:41](:[#6:39](-[H:40]):[#6:37]:1-[H:38])-[#6:46](-[#6:47]1:[#6:56](-[H:57]):[#6:54](-[H:55]):[#6:52](-[H:53]):[#6:50](-[H:51]):[#6:48]:1-[H:49])=[#8:58]
	{0: -0.2483700068778027, 33: -0.2518100144452084, 53: 0.13902999440965763, 12: 0.04984999800858609, 24: 0.041859999101735275, 25: 0.041859999101735275, 22: 0.04984999800858609, 19: 0.041859999101735275, 20: 0.041859999101735275, 17: 0.04984999800858609, 14: 0.041859999101735275, 15: 0.041859999101735275, 9: 0.0478300003224384, 10: 0.0478300003224384, 30: 0.0478300003224384, 31: 0.0478300003224384, 27: 0.0478300003224384, 28: 0.0478300003224384, 5: 0.14724999646482578, 3: 0.15615999678907

In [None]:
cmol.visualize()

## Example for assigning atom ids in SMARTS

In [None]:
rdmol = cmol.to_rdkit()
smarts_no_map = Chem.rdmolfiles.MolFragmentToSmarts(rdmol, atomsToUse=[i for i in range(5,10)])
# how to specify atom map numbers
i = 0
for atom in rdmol.GetAtoms():
    i += 1
    atom.SetAtomMapNum(atom.GetIdx())
smarts_yes_map = Chem.rdmolfiles.MolFragmentToSmarts(rdmol, atomsToUse=[i for i in range(5,10)])

print(smarts_no_map)
print(smarts_yes_map)

In [None]:
for atom in rdmol.GetAtoms(): # checking that atom types match between rdkit and openff version
    n = atom.GetIdx()
    if atom.GetAtomicNum() != cmol.atoms[n].metadata['atomic_number']:
        print(f'Mismatch at atom {n}')
        break
else:
    print('All good!')


## Playing with NX to get a feel for it

In [None]:
import networkx as nx

G = nx.Graph()
G.add_node(0, val=6, attr='stuff')
G.add_node(3, val=7, attr='other')
G.nodes[3]['attr']

In [None]:
G.add_edge(1, 2, weight=10)
G.edges[1, 2]['weight']

## Testing XML encoding

In [None]:
import xml
import xml.etree.ElementTree as ET

p = Path('xml examples/test.offxml')
p.touch()

top = ET.Element('a')
new = ET.SubElement(top, 'b')
new.attrib = {'first' : '4', 'second' : '5'}
 
tree = ET.ElementTree(top)

ET.dump(top) # print out tree
tree.write(p, encoding='utf-8', xml_declaration=True) # write to file