In [1]:
import unyt as u
import ele
import numpy as np

import xml.etree.ElementTree as ET

In [14]:
def equil_bond_distance(atom1, atom2):
    """
    Equilibrium bond length, in nm, between two DREIDING atom types.

    From the paper:
    ---------------
    Rij = Ri + Rj - delta
    delta = 0.01 Ang = 0.001 nm

    ``atom1`` and ``atom2`` are atom-type parameter dicts; their ``bond_r``
    entries are converted to nm, summed, and reduced by delta. The result is
    a plain number (the ``.value`` of the converted quantities).
    """
    radius_i = atom1.get("bond_r").to("nm").value
    radius_j = atom2.get("bond_r").to("nm").value
    return radius_i + radius_j - 0.001


def bond_energy(atom1, atom2):
    """
    Harmonic bond force constant, in kJ/mol/nm^2, for a bond between two atom types.

    From the paper:
    ---------------
    E_ij = 0.5 * K_ij * (R - R_ij)^2
    K_ij = 700 (kcal/mol)/Angstrom^2 for all single bonds
    K_ij(n) = n(K_ij) for bonds of order n

    The equilibrium lengths written alongside this constant are in nm
    (see ``equil_bond_distance``), so the constant must be expressed per nm^2:
    700 kcal/mol/A^2 = 292880 kJ/mol/nm^2. Converting only the energy unit
    (kcal -> kJ) would leave the bonds 100x too soft.

    Parameters
    ----------
    atom1, atom2 : dict
        Atom-type parameter dicts of the two bonded atoms (currently not read).

    Returns
    -------
    float
        n * K_ij in kJ/mol/nm^2; n is currently always 1.
    """
    # TODO: determine the bond order n from atom1 and atom2; every bond is
    # currently treated as a single bond.
    n = 1
    # In this notebook ``u.kcal`` stands for kcal/mol (see the atom-type table).
    k_bond = n * 700 * u.kcal / u.Angstrom**2
    return k_bond.to("kJ/nm**2").value
    

def equil_bond_angle(central_atom):
    """
    Equilibrium angle, in radians, for any angle whose middle atom is ``central_atom``.

    From the paper:
    ---------------
    Equilibrium angle is determined only by middle atom
    E_ijk = 0.5 * C_ijk[cos(theta_ijk) - cos0(theta_j)]^2

    ``central_atom`` is an atom-type parameter dict; its ``theta`` entry is
    converted to radians and returned as a plain number.
    """
    theta_in_rad = central_atom.get("theta").to("rad")
    return theta_in_rad.value
    

def angle_energy(central_atom):
    """
    Harmonic angle force constant, in kJ/mol/rad^2, for angles centred on ``central_atom``.

    From the paper:
    ---------------
    - E_ijk = 0.5 * C_ijk[cos(theta_ijk) - cos(theta_j)]^2
    - C_ijk = K_ijk / sin(theta_j)^2
    - K_ijk = 100 kcal/mol (per rad^2) for all angles
    - The equilibrium angle is determined by the middle atom (j/atom2)
    - For linear centres (theta_j = 180 degrees) C_ijk cannot be formed
      (sin(180)^2 = 0) and DREIDING uses E_ijk = K_ijk[1 + cos(theta_ijk)].

    The returned value is written as the ``k`` of a ``HarmonicAngleForce``
    term, E = 0.5 * k * (theta - theta_j)^2, so it must be the harmonic
    equivalent of the DREIDING term, not C_ijk itself. Expanding about theta_j:

        cos(theta) - cos(theta_j) ~= -sin(theta_j) * (theta - theta_j)
        => E_ijk ~= 0.5 * C_ijk * sin(theta_j)^2 * (theta - theta_j)^2
                  = 0.5 * K_ijk * (theta - theta_j)^2

    and for the linear form 1 + cos(theta) ~= 0.5 * (theta - 180deg)^2.
    In both cases k = K_ijk, so the constant does not depend on the central
    atom. The parameter is kept so the signature matches the other per-atom
    helpers.

    Parameters
    ----------
    central_atom : dict
        Atom-type parameter dict of the middle atom (not read).

    Returns
    -------
    float
        K_ijk converted from kcal/mol to kJ/mol (per rad^2).
    """
    # Radians are dimensionless, so only the energy unit needs converting.
    K = (100 * u.kcal).to("kJ")
    return K.value


def equil_torsion_angle(atom1, atom2, atom3, atom4):
    """
    Equilibrium torsion parameters for the dihedral atom1-atom2-atom3-atom4.

    Not implemented yet.

    From the paper:
    ---------------
    E_ijkl = 0.5 * Vjk{1 - cos[n_jk(phi_ijkl - phi_jk)]}

    n: Periodicity
    phi_ijkl: Dihedral angle
    phi_jk: Equilibrium dihedral angle
    V_jk: Force constant

    V_jk, n and phi_jk only depend on atom 2 and atom 3 (the central bond).

    Raises
    ------
    NotImplementedError
        Always. Raising makes accidental use fail loudly instead of
        returning None where a number is expected.
    """
    raise NotImplementedError("DREIDING torsion parameters are not implemented yet")

def create_forcefield_xml(output_file, atom_types_dict, bonds, angles):
    """
    Write a foyer-style force field XML file for DREIDING atom types.

    Parameters
    ----------
    output_file : str or file-like
        Destination passed to ``ElementTree.write``.
    atom_types_dict : dict
        Atom-type name -> parameter dict. Keys read here: ``element``, ``mass``,
        ``_def``, ``desc``, ``doi``, optional ``overrides``, ``bond_r``,
        ``vdw_r`` and ``D``, plus whatever the bond/angle helpers read.
        Every type named in ``bonds`` and ``angles`` must be a key.
    bonds : iterable of (str, str)
        Atom-type pairs; one ``<Bond>`` is written per pair.
    angles : iterable of (str, str, str)
        Atom-type triples with the central type in the middle.

    Notes
    -----
    - Bonds involving a type whose ``bond_r`` is 0, and angles whose central
      type has an equilibrium angle of 0, are skipped. Those values are
      placeholders, not parameters.
    - DREIDING tabulates the van der Waals minimum-energy distance R0, while
      the XML expects the Lennard-Jones sigma, so sigma = R0 / 2**(1/6).
    """
    forcefield = ET.Element(
        "ForceField",
        name="mBuild_Dreiding",
        version="0.1.0",
        combining_rule="geometric",
    )

    # Atom types. ``overrides`` lets foyer resolve types whose SMARTS
    # definitions overlap (e.g. H_HB also matches H_).
    atom_types_el = ET.SubElement(forcefield, "AtomTypes")
    for atom_type, vals in atom_types_dict.items():
        type_attrs = {
            "name": atom_type,
            "class": atom_type,
            "element": vals["element"],
            "mass": str(vals["mass"]),
            "def": vals["_def"],
            "desc": vals["desc"],
            "doi": vals["doi"],
        }
        if vals.get("overrides"):
            type_attrs["overrides"] = vals["overrides"]
        ET.SubElement(atom_types_el, "Type", type_attrs)

    # Harmonic bonds
    bond_force = ET.SubElement(forcefield, "HarmonicBondForce")
    for bond in bonds:
        atom1_dict = atom_types_dict[bond[0]]
        atom2_dict = atom_types_dict[bond[1]]
        # A zero covalent radius marks a placeholder type (e.g. generic "B");
        # a length built from it would be unphysical, so emit nothing.
        if float(atom1_dict["bond_r"]) == 0 or float(atom2_dict["bond_r"]) == 0:
            continue
        eq_length = equil_bond_distance(atom1=atom1_dict, atom2=atom2_dict)
        bond_k = bond_energy(atom1=atom1_dict, atom2=atom2_dict)
        ET.SubElement(
            bond_force,
            "Bond",
            {
                "class1": bond[0],
                "class2": bond[1],
                "length": str(float(eq_length)),
                "k": str(float(bond_k)),
            },
        )

    # Harmonic angles (must live under HarmonicAngleForce, not the bond force)
    angle_force = ET.SubElement(forcefield, "HarmonicAngleForce")
    for angle in angles:
        central_dict = atom_types_dict[angle[1]]
        theta = equil_bond_angle(central_atom=central_dict)
        # An equilibrium angle of 0 marks a placeholder type (e.g. generic "B").
        if theta == 0:
            continue
        angle_k = angle_energy(central_atom=central_dict)
        ET.SubElement(
            angle_force,
            "Angle",
            {
                "class1": angle[0],
                "class2": angle[1],
                "class3": angle[2],
                "angle": str(float(theta)),
                "k": str(float(angle_k)),
            },
        )

    # Lennard-Jones parameters; all charges are zero.
    nonbonded_force = ET.SubElement(
        forcefield, "NonbondedForce", coulomb14scale="0.0", lj14scale="1.0"
    )
    for atom_type, vals in atom_types_dict.items():
        # DREIDING's R0 is the LJ minimum (r_min = 2**(1/6) * sigma).
        sigma = float(vals["vdw_r"].to("nm")) / 2 ** (1 / 6)
        ET.SubElement(
            nonbonded_force,
            "Atom",
            {
                "type": atom_type,
                "charge": "0.0",
                "sigma": str(sigma),
                "epsilon": str(float(vals["D"].to("kJ"))),
            },
        )

    tree = ET.ElementTree(forcefield)
    ET.indent(tree, space="  ", level=0)
    tree.write(output_file, encoding="utf-8", xml_declaration=True)

# Atom Types:

In [15]:
# Values taken as-is from the paper (doi 10.1021/j100389a010) unless noted.
# R in angstrom and D in kcal/mol (``u.kcal`` is used to mean kcal/mol).
#
# Contains all info from Table I and Table II in the paper.
# Per-type keys:
#   vdw_r  - van der Waals distance R0 (Angstrom)
#   D      - van der Waals well depth (kcal/mol)
#   psi    - value copied from the paper; not read by the XML writer
#   bond_r - covalent radius used for equilibrium bond lengths
#   theta  - equilibrium angle when this type is the central atom
#   _def   - SMARTS definition; ``overrides`` names a type this one wins over
dreiding_atom_types = {
    "H_" : dict(
        vdw_r=3.195 * u.Angstrom,
        D=0.0152 * u.kcal,
        psi=12.382,
        bond_r=0.330 * u.Angstrom,
        theta=180.0 * u.deg,
        _def="[H;X1]",
        desc="Hydrogen forming one bond",
        mass=ele.element_from_symbol("H").mass,
        element="H",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "H_HB": dict( # HB means hydrogen capable of forming hydrogen bonds
        vdw_r=3.195 * u.Angstrom,
        D=0.0001 * u.kcal,
        psi=12.0,
        bond_r=0.330 * u.Angstrom,
        theta=180.0 * u.deg,
        _def="[H;X1][N,O,F]",
        overrides="H_",
        desc="Hydrogen capable of hydrogen bonding",
        mass=ele.element_from_symbol("H").mass,
        element="H",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "H_b": dict( # Bridging hydrogen atom in Diborane; H forms 2 bonds
        vdw_r=3.195 * u.Angstrom,
        D=0.0152 * u.kcal,
        psi=12.382,
        bond_r=0.510 * u.Angstrom,
        theta=90.0 * u.deg,
        _def="[H;X2]",
        desc="Bridging hydrogen with 2 bonds",
        mass=ele.element_from_symbol("H").mass,
        element="H",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    # Generic boron: bond_r and theta are 0 placeholders. The XML writer
    # skips angles whose central atom has theta == 0.
    "B": dict(
        vdw_r=4.02 * u.Angstrom,
        D=0.095 * u.kcal,
        psi=14.23,
        bond_r=0 * u.Angstrom,
        theta=0 * u.deg,
        _def="[B]",
        desc="Boron",
        mass=ele.element_from_symbol("B").mass,
        element="B",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "B_2": dict(
        vdw_r=3.8983 * u.Angstrom,
        D=0.0951 * u.kcal,
        psi=14.034,
        bond_r=0.790 * u.Angstrom,
        theta=120 * u.deg,
        _def="[B;X3]",
        desc="Boron with 3 bonds",
        overrides="B",
        mass=ele.element_from_symbol("B").mass,
        element="B",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "B_3": dict(
        vdw_r=3.8983 * u.Angstrom,
        D=0.0951 * u.kcal,
        psi=14.034,
        bond_r=0.880 * u.Angstrom,
        theta=109.471 * u.deg,
        _def="[B;X4]",
        overrides="B",
        desc="Boron with 4 bonds",
        mass=ele.element_from_symbol("B").mass,
        element="B",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "C_1": dict(
        vdw_r=3.8983 * u.Angstrom,
        D=0.0951 * u.kcal,
        psi=14.034,
        bond_r=0.602 * u.Angstrom,
        theta=180 * u.deg,
        _def="[C;X2;!r6;!r5]",
        desc="Linear carbon (sp1) with 2 bonds",
        mass=ele.element_from_symbol("C").mass,
        element="C",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "C_2": dict(
        vdw_r=3.8983 * u.Angstrom,
        D=0.0951 * u.kcal,
        psi=14.034,
        bond_r=0.670 * u.Angstrom,
        theta=120 * u.deg,
        _def="[C;X3;!r6;!r5]",
        desc="Trigonal carbon (sp2) with 3 bonds, not in a ring",
        mass=ele.element_from_symbol("C").mass,
        element="C",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "C_3": dict(
        vdw_r=3.8983 * u.Angstrom,
        D=0.0951 * u.kcal,
        psi=14.034,
        bond_r=0.770 * u.Angstrom,
        theta=109.471 * u.deg,
        _def="[C;X4;!r6;!r5]",
        desc="Tetrahedral carbon (sp3) with 4 bonds, not in a ring",
        mass=ele.element_from_symbol("C").mass,
        element="C",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "C_R": dict(
        vdw_r=3.8983 * u.Angstrom,
        D=0.0951 * u.kcal,
        psi=14.034,
        bond_r=0.700 * u.Angstrom,
        theta=120 * u.deg,
        _def="[C;r6]",
        desc="Carbon in a 6-membered ring",
        mass=ele.element_from_symbol("C").mass,
        element="C",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "C_R5": dict(
        vdw_r=3.8983 * u.Angstrom,
        D=0.0951 * u.kcal,
        psi=14.034,
        bond_r=0.700 * u.Angstrom,
        theta=116 * u.deg, # Angle from OPLS-AA
        _def="[C;r5]",
        desc="Carbon in 5-membered rings",
        mass=ele.element_from_symbol("C").mass,
        element="C",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "N_1": dict(
        vdw_r=3.6621 * u.Angstrom,
        D=0.0774 * u.kcal,
        psi=13.843,
        bond_r=0.556 * u.Angstrom,
        theta=180 * u.deg,
        _def="[N;X2;!r6;!r5]",
        desc="Linear nitrogen with 2 bonds, not in a ring.",
        mass=ele.element_from_symbol("N").mass,
        element="N",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "N_2": dict(
        vdw_r=3.6621 * u.Angstrom,
        D=0.0774 * u.kcal,
        psi=13.843,
        bond_r=0.615 * u.Angstrom,
        theta=120 * u.deg,
        _def="[N;X3;!r6;!r5]",
        desc="Trigonal nitrogen with 3 bonds, not in a ring",
        mass=ele.element_from_symbol("N").mass,
        element="N",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    # FIXME(review): N_3 has exactly the same _def as N_2, so a SMARTS-based
    # typer cannot tell them apart. Its desc also says 4 bonds while the
    # SMARTS says X3. Confirm the intended definition.
    "N_3": dict(
        vdw_r=3.6621 * u.Angstrom,
        D=0.0774 * u.kcal,
        psi=13.843,
        bond_r=0.702 * u.Angstrom,
        theta=106.7 * u.deg,
        _def="[N;X3;!r6;!r5]",
        desc="Tetrahedral nitrogen with 4 bonds, not in a ring",
        mass=ele.element_from_symbol("N").mass,
        element="N",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "N_R": dict(
        vdw_r=3.6621 * u.Angstrom,
        D=0.0774 * u.kcal,
        psi=13.843,
        bond_r=0.650 * u.Angstrom,
        theta=120 * u.deg,
        _def="[N;r6]",
        desc="Nitrogen in a 6-membered ring",
        mass=ele.element_from_symbol("N").mass,
        element="N",
        doi="10.1021/j100389a010",
        is_metal=False,
    ),
    "N_R5": dict(
        vdw_r=3.6621 * u.Angstrom,
        D=0.0774 * u.kcal,
        psi=13.843,
        bond_r=0.650 * u.Angstrom,
        theta=107.2 * u.deg,
        _def="[N;r5]",
        desc="Nitrogen in 5-membered rings",
        mass=ele.element_from_symbol("N").mass,
        element="N",
        doi="", # Added manually, use theta from OPLS-AA
        is_metal=False,
    ),
    # NOTE(review): [O;X2] also matches ordinary two-connected oxygens
    # (ethers, alcohols, water), which would then get a linear 180 deg
    # angle. Confirm this is intended.
    "O_1": dict(
        vdw_r=3.4046 * u.Angstrom,
        D=0.0957 * u.kcal,
        psi=13.843,
        bond_r=0.528 * u.Angstrom,
        theta=180 * u.deg,
        _def="[O;X2;!r6;!r5]",
        desc="Linear oxygen",
        element="O",
        mass=ele.element_from_symbol("O").mass,
        is_metal=False,
        doi="10.1021/j100389a010",
    ),
    # NOTE(review): three- and four-connected oxygens (O_2 / O_3 SMARTS) are
    # unusual for neutral O. Confirm these SMARTS match the intended
    # hybridizations.
    "O_2": dict(
        vdw_r=3.4046 * u.Angstrom,
        D=0.0957 * u.kcal,
        psi=13.843,
        bond_r=0.560 * u.Angstrom,
        theta=120 * u.deg,
        _def="[O;X3;!r6;!r5]",
        desc="Trigonal oxygen, 3 bonds, not in a ring",
        mass=ele.element_from_symbol("O").mass,
        element="O",
        is_metal=False,
        doi="10.1021/j100389a010",
    ),
    "O_3": dict(
        vdw_r=3.4046 * u.Angstrom,
        D=0.0957 * u.kcal,
        psi=13.843,
        bond_r=0.660 * u.Angstrom,
        theta=104.51 * u.deg,
        _def="[O;X4;!r6;!r5]",
        desc="Tetrahedral oxygen, 4 bonds, not in a ring",
        mass=ele.element_from_symbol("O").mass,
        element="O",
        is_metal=False,
        doi="10.1021/j100389a010",
    ),
    "O_R": dict(
        vdw_r=3.4046 * u.Angstrom,
        D=0.0957 * u.kcal,
        psi=13.843,
        bond_r=0.790 * u.Angstrom, # NOTE(review): equals B_2's radius; confirm against Table I
        theta=120 * u.deg,
        _def="[O;r6]",
        desc="Oxygen in a 6-membered ring",
        mass=ele.element_from_symbol("O").mass,
        element="O",
        is_metal=False,
        doi="10.1021/j100389a010",
    ),
    "O_R5": dict(
        vdw_r=3.4046 * u.Angstrom,
        D=0.0957 * u.kcal,
        psi=13.843,
        bond_r=0.790 * u.Angstrom,
        theta=106.5 * u.deg,
        _def="[O;r5]",
        desc="Oxygen in 5-membered rings",
        mass=ele.element_from_symbol("O").mass,
        element="O",
        is_metal=False,
        doi="",
    ),
    "F_": dict(
        vdw_r=3.4720 * u.Angstrom,
        D=0.0725 * u.kcal,
        psi=14.444,
        bond_r=0.611 * u.Angstrom,
        theta=180 * u.deg,
        _def="[F]",
        desc="Any fluorine atom",
        mass=ele.element_from_symbol("F").mass,
        element="F",
        is_metal=False,
        doi="10.1021/j100389a010",
    ),
    "Al3": dict(
        vdw_r=4.39 * u.Angstrom,
        D=0.31 * u.kcal,
        psi=12.0,
        bond_r=1.047 * u.Angstrom,
        theta=109.471 * u.deg,
        _def="[Al]",
        desc="Any aluminum atom",
        mass=ele.element_from_symbol("Al").mass,
        element="Al",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
    "Si3": dict(
        vdw_r=4.27 * u.Angstrom,
        D=0.31 * u.kcal,
        psi=12.0,
        bond_r=0.937 * u.Angstrom,
        theta=109.471 * u.deg,
        _def="[Si;X4]", # Just change this to any Si?
        desc="Tetrahedral silicon.",
        mass=ele.element_from_symbol("Si").mass,
        element="Si",
        is_metal=False,
        doi="10.1021/j100389a010",
    ),
    "P_3": dict(
        vdw_r=4.1500 * u.Angstrom,
        D=0.3200 * u.kcal,
        psi=12.0,
        bond_r=0.890 * u.Angstrom,
        theta=93.3 * u.deg,
        _def="[P]",
        mass=ele.element_from_symbol("P").mass,
        element="P",
        desc="Any phosphorus",
        is_metal=False,
        doi="10.1021/j100389a010",
    ),
    "S_3": dict(
        vdw_r=4.0300 * u.Angstrom,
        D=0.3440 * u.kcal,
        psi=12.0,
        bond_r=1.040 * u.Angstrom,
        theta=92.1 * u.deg,
        # ';' (AND) rather than ',' (OR): "!r5,!r6" would also match S in a
        # 5- or 6-membered ring.
        _def="[S;!r5;!r6]",
        mass=ele.element_from_symbol("S").mass,
        element="S",
        desc="Any sulfur not in a ring",
        is_metal=False,
        doi="10.1021/j100389a010",
    ),
    "S_R": dict(
        vdw_r=4.0300 * u.Angstrom,
        D=0.3440 * u.kcal,
        psi=12.0,
        bond_r=1.040 * u.Angstrom,
        theta=97.0 * u.deg, # Angle from OPLS-AA Carbon-Thiol sulfur-Carbon angle
        _def="[S;r5,r6]",
        desc="Heterocyclic sulfur",
        overrides="S_3",
        mass=ele.element_from_symbol("S").mass,
        element="S",
        is_metal=False,
        doi="10.1021/j100389a010",
    ),
    "Cl": dict(
        vdw_r=3.9503 * u.Angstrom,
        D=0.2833 * u.kcal,
        psi=13.861,
        bond_r=0.997 * u.Angstrom,
        theta=180 * u.deg,
        _def="[Cl]",
        desc="Any chlorine atom",
        mass=ele.element_from_symbol("Cl").mass,
        element="Cl",
        is_metal=False,
        doi="10.1021/j100389a010",
    ),
    "Ga3": dict(
        vdw_r=4.39 * u.Angstrom,
        D=0.40 * u.kcal,
        psi=12.0,
        bond_r=1.210 * u.Angstrom,
        theta=109.471 * u.deg,
        _def="[Ga]",
        desc="Any gallium atom",
        mass=ele.element_from_symbol("Ga").mass,
        element="Ga",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
    "Ge3": dict(
        vdw_r=4.27 * u.Angstrom,
        D=0.40 * u.kcal,
        psi=12.0,
        bond_r=1.210 * u.Angstrom,
        theta=109.471 * u.deg,
        _def="[Ge]",
        desc="Any germanium atom",
        mass=ele.element_from_symbol("Ge").mass,
        element="Ge",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
    "As3": dict(
        vdw_r=4.15 * u.Angstrom,
        D=0.41 * u.kcal,
        psi=12.0,
        bond_r=1.210 * u.Angstrom,
        theta=92.1 * u.deg,
        _def="[As]",
        desc="Any arsenic atom",
        mass=ele.element_from_symbol("As").mass,
        element="As",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
    "Se3": dict(
        vdw_r=4.03 * u.Angstrom,
        D=0.43 * u.kcal,
        psi=12.0,
        bond_r=1.210 * u.Angstrom,
        theta=90.6 * u.deg,
        _def="[Se]",
        desc="Any selenium atom",
        mass=ele.element_from_symbol("Se").mass,
        element="Se",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
    "Br": dict(
        vdw_r=3.9 * u.Angstrom,
        D=0.37 * u.kcal,
        psi=12.0,
        bond_r=1.167 * u.Angstrom,
        theta=180.0 * u.deg,
        _def="[Br]",
        desc="Any bromine atom",
        mass=ele.element_from_symbol("Br").mass,
        element="Br",
        is_metal=False,
        doi="10.1021/j100389a010",
    ),
    "In": dict(
        vdw_r=4.59 * u.Angstrom,
        D=0.55 * u.kcal,
        psi=12.0,
        bond_r=1.390 * u.Angstrom,
        theta=109.471 * u.deg,
        _def="[In]",
        desc="Any indium atom",
        mass=ele.element_from_symbol("In").mass,
        element="In",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
    "Sn3": dict(
        vdw_r=4.47 * u.Angstrom,
        D=0.55 * u.kcal,
        psi=12.0,
        bond_r=1.373 * u.Angstrom,
        theta=109.471 * u.deg,
        _def="[Sn]",
        desc="Any tin atom",
        mass=ele.element_from_symbol("Sn").mass,
        element="Sn",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
    "Sb3": dict(
        vdw_r=4.35 * u.Angstrom,
        D=0.55 * u.kcal,
        psi=12.0,
        bond_r=1.432 * u.Angstrom,
        theta=91.6 * u.deg,
        _def="[Sb]",
        desc="Any antimony atom",
        mass=ele.element_from_symbol("Sb").mass,
        element="Sb",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
    "Te3": dict(
        vdw_r=4.23 * u.Angstrom,
        D=0.57 * u.kcal,
        psi=12.0,
        bond_r=1.280 * u.Angstrom,
        theta=90.3 * u.deg,
        _def="[Te]",
        desc="Any tellurium atom",
        mass=ele.element_from_symbol("Te").mass,
        element="Te",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
    "I_": dict(
        vdw_r=4.15 * u.Angstrom,
        D=0.51 * u.kcal,
        psi=12.0,
        bond_r=1.360 * u.Angstrom,
        theta=180.0 * u.deg,
        _def="[I]",
        desc="Any iodine atom",
        mass=ele.element_from_symbol("I").mass,
        element="I",
        is_metal=False,
        doi="10.1021/j100389a010",
    ),
    "Na": dict(
        vdw_r=3.144 * u.Angstrom,
        D=0.5 * u.kcal,
        psi=12.0,
        bond_r=1.860 * u.Angstrom,
        theta=90.0 * u.deg,
        _def="[Na]",
        desc="Any sodium atom",
        mass=ele.element_from_symbol("Na").mass,
        element="Na",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
    "Ca": dict(
        vdw_r=3.427 * u.Angstrom,
        D=0.05 * u.kcal,
        psi=12.0,
        bond_r=1.940 * u.Angstrom,
        theta=90.0 * u.deg,
        _def="[Ca]",
        desc="Any calcium atom",
        mass=ele.element_from_symbol("Ca").mass,
        element="Ca",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
    "Fe": dict(
        vdw_r=4.54 * u.Angstrom,
        D=0.055 * u.kcal,
        psi=12.0,
        bond_r=1.258 * u.Angstrom,
        theta=90.0 * u.deg,
        _def="[Fe]",
        desc="Any iron atom",
        mass=ele.element_from_symbol("Fe").mass,
        element="Fe",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
    "Zn": dict(
        vdw_r=4.54 * u.Angstrom,
        D=0.055 * u.kcal,
        psi=12.0,
        bond_r=1.330 * u.Angstrom,
        theta=109.471 * u.deg,
        _def="[Zn]",
        desc="Any zinc atom",
        mass=ele.element_from_symbol("Zn").mass,
        element="Zn",
        is_metal=True,
        doi="10.1021/j100389a010",
    ),
}

# Bonding Rules

In [16]:
# Per-element bonding rules used to enumerate bond and angle types.
#   allowed_bonds: elements this element may bond to.
#   middle_bond:   how plausible it is for this element to be the central
#                  atom of an angle: 0 = never, 1 = rare, 2 = less likely,
#                  3 = very likely.
bonding_rules = {
    element: {"allowed_bonds": partners.split(), "middle_bond": score}
    for element, partners, score in (
        ("Al", "C F H N O S Cl", 1),
        ("As", "C Cl H N O S", 2),
        ("B", "B C Cl F H I N O S", 2),
        ("Br", "C H N O S Br", 1),
        ("C", "B C Cl F H I Mg N O P S Si Br Ge", 3),
        ("Ca", "C Cl F N O S", 0),
        ("Cl", "C H I N O P S Br Cl Ge", 1),
        ("Cu", "N O S", 0),
        ("Fe", "C F N O P S Cl", 0),
        ("F", "Al B C Fe Mg N P S Si Zn Ge", 1),
        ("Ga", "C H N O S Cl Br", 1),
        ("Ge", "C Cl F H N O S", 3),
        ("H", "C N O S P F Cl Br I B Ge", 0),
        ("I", "C H I N O P S", 1),
        ("In", "C H N O S Cl", 0),
        ("Mg", "C N O S", 0),
        ("Na", "C H N O S Cl", 0),
        ("N", "B C Cu Fe H Mg N O P S Si Zn F Cl Br I Ge", 3),
        ("O", "Al B C Ca Cu Fe H Mg Na N O P S Si Zn F Cl Br I Ge", 3),
        ("P", "C Cl F H N O S", 3),
        ("Sb", "C Cl H N O S", 2),
        ("S", "C F H N O P S Cl Ge", 2),
        ("Se", "C Cl H N O S Se", 2),
        ("Si", "C Cl F H N O S Ge", 3),
        ("Sn", "C Cl H N O S", 2),
        ("Te", "C H O S Se", 1),
        ("Zn", "C F N O P S Cl", 0),
    )
}

In [17]:
def parse_bond_types(atom_type, exclude_metals):
    """
    Collect every bond type that ``atom_type`` is allowed to form.

    Partner elements come from
    ``bonding_rules[<element of atom_type>]["allowed_bonds"]``. Every entry of
    ``dreiding_atom_types`` with one of those elements becomes a partner.
    When ``exclude_metals`` is true, partners flagged ``is_metal`` are dropped.

    Returns
    -------
    set of (str, str)
        Atom-type pairs in alphabetical order, so (A, B) and (B, A) coincide.
    """
    own_element = dreiding_atom_types[atom_type]["element"]
    partner_elements = bonding_rules[own_element]["allowed_bonds"]
    return {
        tuple(sorted((atom_type, partner_type)))
        for element in partner_elements
        for partner_type, partner in dreiding_atom_types.items()
        if not (exclude_metals and partner["is_metal"])
        and partner["element"] == element
    }


def get_all_bonds(exclude_metals=True):
    """
    Union of ``parse_bond_types`` over every DREIDING atom type.

    When ``exclude_metals`` is true, metal atom types are skipped as the
    first atom. The flag is also forwarded so they are skipped as partners.
    """
    bond_types = set()
    for atom_type, props in dreiding_atom_types.items():
        if exclude_metals and props["is_metal"]:
            continue
        bond_types.update(parse_bond_types(atom_type, exclude_metals=exclude_metals))
    return bond_types


def parse_angle_types(central_atom, exclude_metals):
    """
    Return all angle types whose middle atom is ``central_atom``.

    Neighbour candidates are every DREIDING atom type whose element is in
    ``bonding_rules[<central element>]["allowed_bonds"]``. Metals are omitted
    when ``exclude_metals`` is true. Every unordered pair of candidates,
    including a candidate paired with itself, yields one angle
    ``(outer_low, central_atom, outer_high)``. The outer types are in
    alphabetical order so (A, X, B) and (B, X, A) coincide.
    """
    central_element = dreiding_atom_types[central_atom]["element"]
    neighbour_types = [
        candidate
        for element in bonding_rules[central_element]["allowed_bonds"]
        for candidate, props in dreiding_atom_types.items()
        if props["element"] == element and not (exclude_metals and props["is_metal"])
    ]
    return {
        (min(first, second), central_atom, max(first, second))
        for idx, first in enumerate(neighbour_types)
        for second in neighbour_types[idx:]
    }

def get_all_angles(likely_hood_limit=1, exclude_metals=True):
    """
    Union of ``parse_angle_types`` over all eligible central atom types.

    An atom type is used as a central atom when its element's
    ``bonding_rules[...]["middle_bond"]`` score is at least
    ``likely_hood_limit``. When ``exclude_metals`` is true it must also not
    be a metal. ``exclude_metals`` is forwarded to ``parse_angle_types``.
    """
    angle_types = set()
    for atom_type, props in dreiding_atom_types.items():
        if exclude_metals and props["is_metal"]:
            continue
        middle_score = bonding_rules[props["element"]]["middle_bond"]
        if middle_score >= likely_hood_limit:
            angle_types.update(parse_angle_types(atom_type, exclude_metals=exclude_metals))
    return angle_types

# Create the bond and angle lists, then write the XML file

In [18]:
# Bonds are enumerated for every atom type, metal types included
# (exclude_metals=False). Angles are enumerated only for non-metal types
# whose element has a middle_bond score >= 2 in bonding_rules, with metal
# outer atoms excluded as well.
# NOTE(review): metals are kept for bonds but dropped for angles; confirm
# this asymmetry is intended.
all_bonds = get_all_bonds(exclude_metals=False)
all_angles = get_all_angles(likely_hood_limit=2, exclude_metals=True)

In [19]:
# Serialise the DREIDING atom types plus the generated bond and angle lists
# to mbuild_dreiding.xml in the working directory.
create_forcefield_xml(
    "mbuild_dreiding.xml",
    dreiding_atom_types,
    bonds=all_bonds,
    angles=all_angles,
)

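# Notes on parameters added beyond the DREIDING paper

- Sulfur in a 5-membered ring (S_R): OPLS-AA C–S–C angle of 97°.
- Sulfur in a 6-membered ring: OPLS-AA angle not yet found (**TODO**); S_R currently uses 97° for both ring sizes.
- Carbon in a 5-membered ring (C_R5): OPLS-AA angle of 116°.
- Nitrogen in a 5-membered ring (N_R5): OPLS-AA angle of 107.2°.
- Oxygen in a 5-membered ring (O_R5): 106.5° (source not recorded; **TODO** confirm).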