In [3]:
import numpy as np
import re, copy
from scipy.stats import linregress
from scipy.optimize import fmin_l_bfgs_b
import matplotlib.pyplot as plt
from scipy.special import erfc
from pymatgen.optimization.neighbors import find_points_in_spheres

# read and write itp forcefield file

In [4]:
# read itp file

def get_atom_para(ff_dict, ff_para):
    """Copy per-atom data from a parsed itp dict into ``ff_para``.

    Reads ``[ atomtypes ]`` (last two columns taken as sigma, epsilon),
    ``[ atoms ]`` (type in column 1, name in column 4, charge and mass in the
    last two columns), ``[ moleculetype ]`` (first token = molecule name) and,
    when present and non-empty, ``[ pairs ]`` and ``[ exclusions ]``.

    Numbers are parsed as Python floats (float64).  The previous float32
    round-trip corrupted values at the written precision, e.g. a mass of 12.01
    was written back as ``12.0100002``.

    Parameters
    ----------
    ff_dict : dict
        Section name -> list of whitespace-split rows (strings), as built by ``itp_read``.
    ff_para : dict
        Target dict; must contain a ``"pairs"`` dict.  Mutated in place.

    Returns
    -------
    dict
        ``ff_para``.
    """
    # atom type -> [sigma, epsilon]; later duplicate types override earlier ones
    vdw_by_type = {row[0]: [float(v) for v in row[-2:]] for row in ff_dict["atomtypes"]}
    atoms = ff_dict["atoms"]

    ff_para["vdw_para"]    = [list(vdw_by_type[row[1]]) for row in atoms]
    ff_para["atom_name"]   = [row[4] for row in atoms]
    ff_para["atom_type"]   = [row[1] for row in atoms]
    ff_para["mass"]        = [float(row[-1]) for row in atoms]
    ff_para["charge"]      = [float(row[-2]) for row in atoms]
    ff_para["mol_name"]    = ff_dict["moleculetype"][0][0]

    pairs = ff_dict.get("pairs", [])
    if pairs:
        ff_para["pairs"]["type"] = [int(row[2]) for row in pairs]
        ff_para["pairs"]["pair"] = [[int(row[0]), int(row[1])] for row in pairs]

    exclusions = ff_dict.get("exclusions", [])
    if exclusions:
        ff_para["exclusions"] = [[int(j) for j in row] for row in exclusions]

    return(ff_para)

def get_bond_para(ff_dict, ff_para):
    """Parse ``[ bonds ]`` rows ``ai aj funct b0 kb`` into ``ff_para["bonds"]``.

    Atom indices and funct are parsed as ints; the remaining columns are parsed
    as Python floats (float64), so values such as 0.1522 are kept exactly
    rather than rounded through float32.

    Parameters
    ----------
    ff_dict : dict
        Parsed itp sections; a missing ``"bonds"`` entry is treated as empty.
    ff_para : dict
        Must contain a ``"bonds"`` dict; mutated in place.

    Returns
    -------
    dict
        ``ff_para``.
    """
    rows = ff_dict.get("bonds", [])

    ff_para["bonds"]["type"] = [int(row[2]) for row in rows]
    ff_para["bonds"]["pair"] = [[int(row[0]), int(row[1])] for row in rows]
    ff_para["bonds"]["para"] = [[float(v) for v in row[3:]] for row in rows]

    return(ff_para)
    

def get_angle_para(ff_dict, ff_para):
    """Parse ``[ angles ]`` rows ``ai aj ak funct theta0 k`` into ``ff_para["angles"]``.

    Nothing is written when the section is missing or empty.  Parameters are
    parsed as Python floats (float64) so they are not rounded through float32.

    Returns
    -------
    dict
        ``ff_para`` (mutated in place).
    """
    rows = ff_dict.get("angles", False)
    if rows:
        ff_para["angles"]["type"] = [int(row[3]) for row in rows]
        ff_para["angles"]["pair"] = [[int(a) for a in row[:3]] for row in rows]
        ff_para["angles"]["para"] = [[float(v) for v in row[4:]] for row in rows]

    return(ff_para)

def get_dihe_para(ff_dict, ff_para):
    """Split ``[ dihedrals ]`` rows into proper (funct 9) and other (improper) terms.

    Rows are expected as ``ai aj ak al funct phase kd pn``.
    - Every funct-9 dihedral is padded with zero-force terms (phase 0.0, kd 0.0),
      so that multiplicities 1, 2, 3, 4 and 6 are all present.
    - All rows are then sorted by the string key ``"ai aj ak al funct pn"``.
    - Parameters are parsed as Python floats (float64), so values such as 1.1
      are kept exactly.

    Fills ``ff_para["dihedrals"]`` (funct 9) and ``ff_para["impropers"]`` (any
    other funct), each with ``type``/``pair``/``para``/``order`` lists.
    ``ff_dict`` itself is not modified (a deep copy is padded).

    Returns
    -------
    dict
        ``ff_para``.
    """
    if ff_dict.get("dihedrals", False) :
        multi_orders = ["1", "2", "3", "4", "6"]
        dihe_line = copy.deepcopy(ff_dict["dihedrals"])

        proper_rows = [row for row in dihe_line if row[4] == "9"]
        # (ai, aj, ak, al, funct, pn) keys already present; a set keeps the padding linear
        present = {tuple(row[:5]) + (row[-1],) for row in proper_rows}
        for row in proper_rows:
            for pn in multi_orders:
                key = tuple(row[:5]) + (pn,)
                if key not in present:
                    dihe_line.append(row[:5] + ["0.0", "0.0"] + [pn])
                    present.add(key)

        dihe_line = sorted(dihe_line, key=lambda x : " ".join(x[:5] + [x[-1]]))

        proper   = [row for row in dihe_line if row[4] == "9"]
        improper = [row for row in dihe_line if row[4] != "9"]
        for section, rows in (("dihedrals", proper), ("impropers", improper)):
            ff_para[section]["type"]  = [int(row[4]) for row in rows]
            ff_para[section]["pair"]  = [[int(a) for a in row[:4]] for row in rows]
            ff_para[section]["para"]  = [[float(v) for v in row[5:-1]] for row in rows]
            ff_para[section]["order"] = [int(row[-1]) for row in rows]

    return(ff_para)
 

def get_pos_restrain_para(ff_dict, ff_para):
    """Parse ``[ position_restraints ]`` rows ``atom funct fx fy fz`` into ``ff_para``.

    When the section is present and non-empty, this stores the following in
    ``ff_para["position_restraints"]``: ``{"type": [...], "id": [...], "para": [[fx, fy, fz], ...]}``.

    The previous version had three bugs:
    - it misspelt ``float``;
    - it indexed the row *list* with string keys;
    - it returned ``ff_dict`` instead of ``ff_para``.

    Returns
    -------
    dict
        ``ff_para`` (mutated in place).
    """
    rows = ff_dict.get("position_restraints", [])
    if rows:
        ff_para["position_restraints"] = {
            "type": [int(row[1]) for row in rows],
            "id":   [int(row[0]) for row in rows],
            "para": [[float(v) for v in row[2:]] for row in rows],
        }

    return(ff_para)

def itp_read(itp_file, ff_type = "gaff"):
    """Parse a GROMACS ``.itp`` file into the ``ff_para`` dict used in this notebook.

    Parsing rules:
    - Comments (``;...``) and section brackets are stripped.
    - A line with a single token is treated as a section header.
    - Rows of the known sections are collected and converted by the ``get_*_para``
      helpers. The known sections are atomtypes, moleculetype, atoms, bonds,
      pairs, angles, dihedrals and exclusions.
    - Everything else is kept verbatim in ``ff_para["addition_part"]``:
      - preprocessor lines (``#ifdef``, ``#include``, ...);
      - rows of unknown sections, preceded by their re-emitted ``[ name ]`` header.

    Parameters
    ----------
    itp_file : str or path-like
        Path of the ``.itp`` file.
    ff_type : str
        Stored as ``ff_para["ff_type"]``; selects the ``[ defaults ]`` line on writing.

    Returns
    -------
    dict
        The ``ff_para`` force-field dict.
    """
    with open(itp_file, "r") as handle:
        itp_txt = handle.read()
    itp_txt = re.sub(';.*', '', itp_txt)
    itp_txt = re.sub(r'[\[\]]', '', itp_txt)

    ff_dict = {"atomtypes":[], "moleculetype":[], "atoms":[],     "bonds":[],
               "pairs":[],     "angles":[],       "dihedrals":[],  "exclusions":[],
              }

    addition_part = []
    part = None   # rows before the first header belong to no known section
    for line in itp_txt.split("\n"):
        line_split = line.split()
        if not line_split:
            continue
        if line_split[0].startswith("#"):
            # preprocessor directives must not be parsed as section rows
            addition_part.append(line)
        elif len(line_split) == 1:
            part = line_split[0].lower()
            if part not in ff_dict:
                # keep the header so the verbatim rows stay a valid section when re-written
                addition_part.append("[ %s ]" % part)
        elif part in ff_dict:
            ff_dict[part].append(line_split)
        else:
            addition_part.append(line)

    ff_para = {"vdw_para":[], "bonds":{}, "angles":{}, "dihedrals":{},
               "impropers":{}, "pairs":{}, "mol_name":"", "ff_type":"gaff","charge":[],"mass":[],
               "atom_type":[],"atom_name":[], "exclusions":[],
               "position_restraints":[]}

    ff_para["addition_part"] = addition_part

    ff_para = get_atom_para(ff_dict,  ff_para)
    ff_para = get_bond_para(ff_dict,  ff_para)
    ff_para = get_angle_para(ff_dict, ff_para)
    ff_para = get_dihe_para(ff_dict,  ff_para)
    ff_para["ff_type"] = ff_type

    return(ff_para)
    


# write itp file

def get_moleculetype_line(ff_para):
    """Return the ``[ moleculetype ]`` section text for ``ff_para["mol_name"]``.

    The exclusion count (nrexcl) is always written as 3.
    """
    header  = "[ moleculetype ]"
    comment = " ;name            nrexcl"
    entry   = "   %s              3" % (ff_para["mol_name"])
    return "\n".join((header, comment, entry))

def get_atomtypes_line(ff_para):
    """Build the ``[ atomtypes ]`` section as a list of lines (not joined).

    Each line is formatted as follows:
    - the bond_type column repeats the atom type;
    - mass and charge are written as 0.000, ptype as A;
    - sigma and epsilon come from ``ff_para["vdw_para"]``.

    Identical lines are emitted once, in first-seen order.

    Returns
    -------
    list of str
        Two header lines followed by the unique atom-type lines.
    """
    atom_type   = ff_para["atom_type"]
    vdw_para    = ff_para["vdw_para"]

    txt = ["[ atomtypes ]",
           ";name   bond_type     mass     charge   ptype   sigma         epsilon       Amb"]

    # set-based dedup keeps this linear in the number of atoms
    seen = set()
    for n in range(len(atom_type)):
        tmp_line = " %5s %5s          0.000      0.000    A   %10.7f  %10.7f " %(atom_type[n], atom_type[n],
                                                                                  vdw_para[n][0], vdw_para[n][1])
        if tmp_line not in seen:
            seen.add(tmp_line)
            txt.append(tmp_line)

    return(txt)
    

def get_atom_line(ff_para):
    """Build the ``[ atoms ]`` section text.

    Column layout:
    - every atom is placed in residue 1, whose name is ``ff_para["mol_name"]``;
    - the cgnr column is written as the 0-based index ``n``, so each atom is its
      own charge group.

    Residue and atom names are separated by a space. Previously they were
    formatted as ``%5s%5s``, so two 5-character names merged into one token.

    Returns
    -------
    str
    """
    atom_type   = ff_para["atom_type"]
    atom_name   = ff_para["atom_name"]
    atom_mass   = ff_para["mass"]
    atom_charge = ff_para["charge"]
    mol_title   = ff_para["mol_name"]

    txt = ["[ atoms ]",
           ";   nr   type  resi   res  atom  cgnr     charge      mass"]

    for n in range(len(atom_mass)):
        tmp_line = " %5d %5s %5d  %5s %5s %5d   %10.7f  %10.7f " %(n+1, atom_type[n], 1, mol_title, atom_name[n],
                                                                  n, atom_charge[n], atom_mass[n])
        txt.append(tmp_line)

    return("\n".join(txt))
    
   
    

def get_bond_line(ff_para):
    """Render the ``[ bonds ]`` section, annotating each row with its atom names.

    Only the two header lines are returned when ``ff_para`` has no
    ``"bonds"`` entry (or it is an empty dict).
    """
    lines = ["[ bonds ]",
             " ;   ai   aj   funct     r          k"]
    bonds = ff_para.get("bonds", {})
    if bonds == {}:
        return("\n".join(lines))

    names  = ff_para["atom_name"]
    kinds  = bonds["type"]
    pairs  = bonds["pair"]
    params = bonds["para"]
    for idx, pair in enumerate(pairs):
        first, second = pair[0], pair[1]
        label = "%s - %s" % (names[first - 1], names[second - 1])
        lines.append("  %4d %4d   %4d    %8.6f   %e  ; %s" % (first, second, kinds[idx],
                                                             params[idx][0], params[idx][1],
                                                             label))
    return("\n".join(lines))
  
    

def get_angle_line(ff_para):
    """Render the ``[ angles ]`` section, annotating each row with its atom names.

    Only the two header lines are returned when ``ff_para`` has no
    ``"angles"`` entry (or it is an empty dict).
    """
    out = ["[ angles ]",
           ";   ai   aj  ak    funct    theta          cth"]
    angles = ff_para.get("angles", {})
    if angles == {}:
        return("\n".join(out))

    names  = ff_para["atom_name"]
    kinds  = angles["type"]
    triples = angles["pair"]
    params = angles["para"]
    for idx, kind in enumerate(kinds):
        triple = triples[idx]
        label = " - ".join(names[k - 1] for k in triple)
        fields = tuple(triple + [kind] + params[idx] + [label])
        out.append("  %4d %4d %4d %4d    %e   %e  ; %s" % fields)
    return("\n".join(out))
 
    

def get_dihe_line(ff_para):
    """Render the proper-dihedral ``[ dihedrals ]`` section (one row per multiplicity).

    Only the two header lines are returned when ``ff_para`` has no
    ``"dihedrals"`` entry (or it is an empty dict).
    """
    out = ["[ dihedrals ] ; propers",
           ";    i   j    k    l   func     phase          kd         pn"]
    dihedrals = ff_para.get("dihedrals", {})
    if dihedrals == {}:
        return("\n".join(out))

    names   = ff_para["atom_name"]
    kinds   = dihedrals["type"]
    quads   = dihedrals["pair"]
    params  = dihedrals["para"]
    orders  = dihedrals["order"]
    for idx, pn in enumerate(orders):
        quad = quads[idx]
        label = " - ".join(names[k - 1] for k in quad)
        fields = tuple(quad + [kinds[idx]] + params[idx] + [pn, label])
        out.append("  %4d %4d %4d %4d %4d   %e   %e  %d  ; %s" % fields)
    return("\n".join(out))

def get_impo_line(ff_para):
    """Render the improper ``[ dihedrals ]`` section from ``ff_para["impropers"]``.

    Only the two header lines are returned when ``ff_para`` has no
    ``"impropers"`` entry (or it is an empty dict).
    """
    out = ["[ dihedrals ] ; impropers",
           ";    i   j    k    l   func     phase          kd         pn"]
    impropers = ff_para.get("impropers", {})
    if impropers == {}:
        return("\n".join(out))

    names  = ff_para["atom_name"]
    kinds  = impropers["type"]
    quads  = impropers["pair"]
    params = impropers["para"]
    orders = impropers["order"]
    for idx, pn in enumerate(orders):
        quad = quads[idx]
        label = " - ".join(names[k - 1] for k in quad)
        fields = tuple(quad + [kinds[idx]] + params[idx] + [pn, label])
        out.append("  %4d %4d %4d %4d %4d   %e   %e  %d  ; %s" % fields)
    return("\n".join(out))

def get_pairs_line(ff_para):
    """Render the ``[ pairs ]`` (1-4) section, annotating each row with its atom names.

    Only the two header lines are returned when ``ff_para`` has no ``"pairs"``
    entry (or it is an empty dict).
    """
    out = ["[ pairs ]",
           ";   ai     aj    funct "]
    pairs = ff_para.get("pairs", {})
    if pairs == {}:
        return("\n".join(out))

    names = ff_para["atom_name"]
    kinds = pairs["type"]
    couples = pairs["pair"]
    for idx, kind in enumerate(kinds):
        couple = couples[idx]
        label = "%s - %s" % (names[couple[0] - 1], names[couple[1] - 1])
        out.append("  %4d  %4d   %4d    ; %s" % (couple[0], couple[1], kind, label))
    return("\n".join(out))

def get_exclusions_line(ff_para):
    """Render the ``[ exclusions ]`` section; each row's indices are joined by two spaces.

    Only the header is returned when ``ff_para`` lacks an ``"exclusions"`` key.
    """
    header = ["[ exclusions ]",
              "; i  j"]
    exclusions = ff_para.get("exclusions", {})
    if exclusions == {}:
        return("\n".join(header))
    rows = ["  ".join(str(atom) for atom in row) for row in exclusions]
    return("\n".join(header + rows))
    

def get_pos_restraint_line(para):
    """Render a ``[ position_restraints ]`` section from rows ``(atom, type, fx, fy, fz)``."""
    header = ["[ position_restraints ]",
              "; atom  type  fx    fy   fz"]
    rows = ["  %d  %d   %f    %f    %f" % tuple(row) for row in para]
    return "\n".join(header + rows)

def get_itp_line(ff_para, pos_restraint = None, write_atomtypes = False):
    """Assemble the full text of a molecule ``.itp`` file.

    Parameters
    ----------
    ff_para : dict
        Force-field dict as produced by ``itp_read``.
        - Extra verbatim lines are read from ``ff_para["addition_part"]``.
        - The legacy misspelt key ``"addition_aprt"`` is also accepted.
        - Both may be absent. Previously the code only read the misspelt key,
          so every ``itp_read`` result raised KeyError.
    pos_restraint : sequence of rows or None
        Rows ``(atom, type, fx, fy, fz)`` written as ``[ position_restraints ]``
        right after ``[ atoms ]``, whatever the value of ``write_atomtypes``.
    write_atomtypes : bool
        When True the file is self-contained. ``[ defaults ]`` and
        ``[ atomtypes ]`` are written first, because GROMACS requires them
        before ``[ moleculetype ]``.

    Returns
    -------
    str
        Sections separated by blank lines; empty sections are omitted.

    Raises
    ------
    ValueError
        If ``write_atomtypes`` is True and ``ff_para["ff_type"]`` is not "gaff"/"opls".
    """
    defaults_by_ff = {
        "gaff": "    1        2          yes        0.5      0.8333",
        "opls": "    1        3          yes        0.5      0.5",
    }

    if pos_restraint:
        pos_restraint_line = get_pos_restraint_line(pos_restraint)
    else:
        pos_restraint_line = ""

    addition_part = ff_para.get("addition_part", ff_para.get("addition_aprt", []))
    addition_line = "\n".join(addition_part)

    sections = [get_moleculetype_line(ff_para),
                get_atom_line(ff_para),
                pos_restraint_line,
                get_bond_line(ff_para),
                get_angle_line(ff_para),
                get_dihe_line(ff_para),
                get_impo_line(ff_para),
                get_pairs_line(ff_para),
                get_exclusions_line(ff_para),
                addition_line]

    if write_atomtypes:
        ff_type = ff_para.get("ff_type", "gaff")
        if ff_type not in defaults_by_ff:
            raise ValueError("Unsupported forcefield type: %s" % ff_type)
        default_line = "\n".join(["[ defaults ]",
                                  "  ;nbfunc  comb-rule  gen-pairs  fudgeLJ  fudgeQQ",
                                  defaults_by_ff[ff_type]])
        atomtypes_line = "\n".join(get_atomtypes_line(ff_para))
        sections = [default_line, atomtypes_line] + sections

    itp_txt = "\n\n".join(section for section in sections if section)

    return(itp_txt)

# write top file

def write_top(structure, ff_para_list, pos_restraint = None, write_itp = True):
    """Write ``topol.top`` and, optionally, one ``<mol_name>.itp`` per force field.

    The topology contains, in order:
    - ``[ defaults ]``;
    - the merged ``[ atomtypes ]`` of all molecules, each distinct line written once
      (GROMACS warns when an atom type is defined twice);
    - one ``#include`` per force field;
    - ``[ system ]``;
    - ``[ molecules ]``, built by run-length encoding consecutive entries of
      ``structure.mol_types``.

    Parameters
    ----------
    structure : object
        Must expose ``mol_types``, an iterable of molecule names in system order
        (assumed to match each ``ff_para["mol_name"]``).
    ff_para_list : list of dict
        Force-field dicts as produced by ``itp_read``; all must share one ``ff_type``.
    pos_restraint : dict or None
        Optional ``{mol_name: rows}`` mapping forwarded to ``get_itp_line``.
    write_itp : bool
        When True (default) each molecule's ``.itp`` file is (re)written;
        previously this flag was ignored.

    Raises
    ------
    ValueError
        If the force-field types differ or are not "gaff"/"opls".
    """
    defaults_by_ff = {
        "gaff": "    1        2          yes        0.5      0.8333",
        "opls": "    1        3          yes        0.5      0.5",
    }

    # run-length encode consecutive identical molecule names for [ molecules ]
    mol_blocks = []
    for mol_type in structure.mol_types:
        if mol_blocks and mol_blocks[-1][0] == mol_type:
            mol_blocks[-1][1] += 1
        else:
            mol_blocks.append([mol_type, 1])

    ff_types = {ff_para.get("ff_type", "gaff") for ff_para in ff_para_list}
    if len(ff_types) != 1:
        raise ValueError("Forcefield types should be same.")
    ff_type = ff_types.pop()
    if ff_type not in defaults_by_ff:
        raise ValueError("Unsupported forcefield type: %s" % ff_type)
    default_line = "\n".join(["[ defaults ]",
                              "  ;nbfunc  comb-rule  gen-pairs  fudgeLJ  fudgeQQ",
                              defaults_by_ff[ff_type]])

    atomtypes_line = ["[ atomtypes  ]",
                      ";name   bond_type     mass     charge   ptype   sigma         epsilon       Amb"]
    seen_atomtypes = set()
    include_line = ["; Include itp topology file"]
    restraints = pos_restraint if pos_restraint else {}

    for ff_para in ff_para_list:
        mol_name = ff_para["mol_name"]
        include_line.append('#include "%s.itp" '  %(mol_name))
        if write_itp:
            with open("%s.itp" %(mol_name), "w") as itp_file:
                itp_file.write(get_itp_line(ff_para, pos_restraint = restraints.get(mol_name),
                                            write_atomtypes = False))
        # atom types shared between molecules must appear only once
        for line in get_atomtypes_line(ff_para)[2:]:
            if line not in seen_atomtypes:
                seen_atomtypes.add(line)
                atomtypes_line.append(line)

    molecules_line = ["[ molecules ]",
                      "; Compound        nmols", ]
    for mol_type, count in mol_blocks:
        molecules_line.append("  %s              %d" %(mol_type, count))

    top_line = [default_line,
                "\n".join(atomtypes_line),
                "\n".join(include_line),
                "\n".join(["[ system ]", "system"]),
                "\n".join(molecules_line),
               ]

    with open("topol.top", "w") as topology_file:
        topology_file.write("\n\n".join(top_line))


# nonbond part

In [5]:
# reciprocal space part with particle mesh ewald method
def Bspline(dr, PME_ORDER=6):
    """Compute the PME B-spline weights for a given fractional offset.

       The code is adapted from the OpenMM implementation of PME
       (https://github.com/openmm/openmm/blob/master/plugins/cpupme/src/CpuPmeKernels.cpp).

       Parameters
       ----------
       dr : float
           Offset of the point from its grid node, in grid units. Callers in this
           notebook pass ``x - int(x)``; presumably expected in [0, 1).
       PME_ORDER : int
           Number of grid points the spline spans (the spline order).

       Returns
       -------
       data : numpy.ndarray, shape (PME_ORDER,)
           Spline weights of order PME_ORDER.
       ddata : numpy.ndarray, shape (PME_ORDER,)
           First differences of the order-(PME_ORDER-1) weights. By the standard
           B-spline recursion used in OpenMM these are the derivatives of ``data``
           with respect to ``dr``. NOTE(review): sign convention not verified here.
    """

    scale = 1.0 / (PME_ORDER - 1)
    data = np.zeros(PME_ORDER)
    data[PME_ORDER-1] = 0.  # redundant with np.zeros
    # linear (order-2) spline
    data[0] = 1.0 - dr 
    data[1] = dr
    # raise the spline by one order per iteration, up to order PME_ORDER-1;
    # the inner loop must run from high to low index because it updates in place
    for j in range(3, PME_ORDER):
        div = 1.0 / (j - 1)
        data[j-1] = div * dr * data[j-2]
        for k in range(1, j-1):
            data[j-k-1] = div * ((dr + k) * data[j-k-2] + (j - k - dr) * data[j-k-1])
        data[0] = div * (1.0 - dr) * data[0]
    
    # differences of the order-(PME_ORDER-1) weights (must happen before the final raise)
    ddata = np.zeros(PME_ORDER)
    ddata[0] = -data[0]
    for i in range(1, PME_ORDER):
        ddata[i] = data[i-1] - data[i]
    
    # final recursion step: order PME_ORDER-1 -> PME_ORDER, with scale = 1/(PME_ORDER-1)
    data[PME_ORDER-1] = scale*dr * data[PME_ORDER-2]
    for j in range(1, PME_ORDER-1):
        data[PME_ORDER-j-1] = scale * ((dr + j) * data[PME_ORDER-j-2] + (PME_ORDER - j - dr) * data[PME_ORDER-j-1])

    data[0] = scale * (1.0 - dr) * data[0]
    
    return(data,ddata)

def Bm(grid_num, PME_ORDER=6):
    """Compute the 3-D B-spline moduli ``|b(m)|^2`` used by smooth PME.

    Per-axis calculation:
    - For each axis with ``n`` grid points, ``|sum_j data[j] * exp(2*pi*i*k*j/n)|^2``
      is evaluated for ``k = 0..n-1``.
    - ``data`` are the spline weights at offset 0 from ``Bspline``, without the
      last weight.
    - The result is the outer product of the three per-axis arrays.

    Moduli below 1e-7 are replaced by the mean of their periodic neighbours, as
    OpenMM does. Such values can occur, e.g. at the Nyquist frequency for some
    spline orders, and would otherwise make ``cm / bm`` blow up.

    Parameters
    ----------
    grid_num : sequence of 3 ints
    PME_ORDER : int

    Returns
    -------
    numpy.ndarray of shape ``(grid_num[0], grid_num[1], grid_num[2])``
    """
    data, _ddata = Bspline(0.0, PME_ORDER=PME_ORDER)
    coeff = data[:-1]
    powers = np.arange(PME_ORDER-1)

    def _axis_moduli(n):
        # squared magnitude of the discrete Fourier transform of the spline weights
        moduli = np.zeros(n)
        for k in range(n):
            arg = 2.0 * np.pi * k * powers / n
            cos = np.sum(coeff * np.cos(arg))
            sin = np.sum(coeff * np.sin(arg))
            moduli[k] = cos**2 + sin**2
        for k in range(n):
            if moduli[k] < 1.0e-7:
                moduli[k] = 0.5 * (moduli[(k - 1) % n] + moduli[(k + 1) % n])
        return moduli

    bm_x = _axis_moduli(grid_num[0])
    bm_y = _axis_moduli(grid_num[1])
    bm_z = _axis_moduli(grid_num[2])

    bm = bm_x.reshape((-1,1,1)) * bm_y.reshape((1,-1,1)) * bm_z.reshape((1,1,-1))

    return(bm)

def Cm(latt9, grid_num, alpha):
    """Compute the PME reciprocal-space influence function on the grid.

    ``C(m) = exp(-pi^2 |k_m|^2 / alpha^2) / (pi * V * |k_m|^2)``, with ``C(0) = 0``.

    Definitions:
    - Lattice vectors are the rows of ``latt9``.
    - ``k_m = m1 a* + m2 b* + m3 c*``, where the reciprocal vectors (no 2*pi
      factor) are the columns of ``inv(latt9)``.
    - Grid index ``k`` maps to ``m = k`` if ``k < (n+1)/2``, else ``k - n``.

    Changes from the previous version:
    - The off-diagonal reciprocal terms are now included. Without them the result
      was only correct for lower-triangular lattices; triangular lattices give the
      same values as before.
    - ``V`` is taken as ``|det(latt9)|``, so a left-handed cell no longer produces
      negative values.

    Parameters
    ----------
    latt9 : array_like, shape (3, 3)
    grid_num : sequence of 3 ints
    alpha : float
        Ewald splitting parameter (inverse length, same unit as ``latt9``).

    Returns
    -------
    numpy.ndarray of shape ``(grid_num[0], grid_num[1], grid_num[2])``
    """
    latt9 = np.asarray(latt9, dtype=float)
    vol = abs(np.linalg.det(latt9))
    reci_latt = np.linalg.inv(latt9)

    def _signed_index(n):
        k = np.arange(n)
        return np.where(k < (n + 1) / 2, k, k - n)

    mx = _signed_index(grid_num[0]).reshape((-1, 1, 1))
    my = _signed_index(grid_num[1]).reshape((1, -1, 1))
    mz = _signed_index(grid_num[2]).reshape((1, 1, -1))

    # Cartesian component c of k_m is sum_j m_j * reci_latt[c][j]
    m2 = 0.
    for c in range(3):
        mh = mx * reci_latt[c][0] + my * reci_latt[c][1] + mz * reci_latt[c][2]
        m2 = m2 + mh * mh

    m2[0,0,0] = 1.   # avoid 0/0 at m = 0; the term is zeroed below
    cm = np.exp(-np.pi**2 / alpha**2 * m2) / (np.pi * vol * m2)
    cm[0,0,0] = 0.
    return(cm)

def weight(x, grid_num, PME_ORDER=6):
    """Return the grid indices and B-spline weights of one coordinate.

    Parameters
    ----------
    x : float
        Position in grid units (fractional coordinate times the grid size).
    grid_num : int
        Number of grid points along this axis; indices wrap periodically.
    PME_ORDER : int
        Number of grid points covered by the spline.

    Returns
    -------
    idx : numpy.ndarray of int, shape (PME_ORDER,)
    w : numpy.ndarray, shape (PME_ORDER,)

    ``floor`` is used instead of ``int()``. Truncation toward zero gave a negative
    spline offset (and a shifted stencil) for atoms with negative fractional
    coordinates. The result is identical for ``x >= 0``.
    """
    base = int(np.floor(x))
    w, _dw = Bspline(x - base, PME_ORDER=PME_ORDER)
    idx = (np.arange(PME_ORDER) + base - int((PME_ORDER+1)/2) + 1) % grid_num
    return(idx, w)


def get_fftw_factor(n):
    """Return the smallest FFT-friendly grid size ``>= n``.

    Good sizes have the form ``2^a * 3^b * 5^c * 7^d * 11^e * 13^f`` with ``e + f <= 1``;
    all such sizes up to 2048 are listed below.

    Parameters
    ----------
    n : int or float
        Requested minimum grid size.

    Returns
    -------
    numpy integer

    Raises
    ------
    ValueError
        If ``n`` exceeds 2048. This used to be an ``assert`` whose message was
        ``print(...)`` (i.e. None), which also rejected ``n == 2048`` and
        vanished under ``python -O``.
    """
    good_factor_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 21, 22, 24, 25, 26, 
                        27, 28, 30, 32, 33, 35, 36, 39, 40, 42, 44, 45, 48, 49, 50, 52, 54, 55, 56, 60, 63, 
                        64, 65, 66, 70, 72, 75, 77, 78, 80, 81, 84, 88, 90, 91, 96, 98, 99, 100, 104, 105, 
                        108, 110, 112, 117, 120, 125, 126, 128, 130, 132, 135, 140, 143, 144, 147, 150, 154, 
                        156, 160, 162, 165, 168, 175, 176, 180, 182, 189, 192, 195, 196, 198, 200, 208, 210, 
                        216, 220, 224, 225, 231, 234, 240, 243, 245, 250, 252, 256, 260, 264, 270, 273, 275, 
                        280, 286, 288, 294, 297, 300, 308, 312, 315, 320, 324, 325, 330, 336, 343, 350, 351, 
                        352, 360, 364, 375, 378, 384, 385, 390, 392, 396, 400, 405, 416, 420, 429, 432, 440, 
                        441, 448, 450, 455, 462, 468, 480, 486, 490, 495, 500, 504, 512, 520, 525, 528, 539, 
                        540, 546, 550, 560, 567, 572, 576, 585, 588, 594, 600, 616, 624, 625, 630, 637, 640, 
                        648, 650, 660, 672, 675, 686, 693, 700, 702, 704, 715, 720, 728, 729, 735, 750, 756, 
                        768, 770, 780, 784, 792, 800, 810, 819, 825, 832, 840, 858, 864, 875, 880, 882, 891, 
                        896, 900, 910, 924, 936, 945, 960, 972, 975, 980, 990, 1000, 1001, 1008, 1024, 1029, 
                        1040, 1050, 1053, 1056, 1078, 1080, 1092, 1100, 1120, 1125, 1134, 1144, 1152, 1155, 
                        1170, 1176, 1188, 1200, 1215, 1225, 1232, 1248, 1250, 1260, 1274, 1280, 1287, 1296, 
                        1300, 1320, 1323, 1344, 1350, 1365, 1372, 1375, 1386, 1400, 1404, 1408, 1430, 1440, 
                        1456, 1458, 1470, 1485, 1500, 1512, 1536, 1540, 1560, 1568, 1575, 1584, 1600, 1617, 
                        1620, 1625, 1638, 1650, 1664, 1680, 1701, 1715, 1716, 1728, 1750, 1755, 1760, 1764, 
                        1782, 1792, 1800, 1820, 1848, 1872, 1875, 1890, 1911, 1920, 1925, 1944, 1950, 1960, 
                        1980, 2000, 2002, 2016, 2025, 2048]

    good_factor_list = np.array(good_factor_list)

    if n > good_factor_list[-1]:
        raise ValueError("The grid number %d is too large, can't proceed calculation!" %(n))

    # the list is sorted: the first entry >= n is the answer
    good_n = good_factor_list[np.searchsorted(good_factor_list, n, side="left")]

    return(good_n)

def get_real_chage_mesh(fraction_coord, charge, grid_num, PME_ORDER=6):
    """Spread point charges onto a periodic 3-D grid with PME B-spline weights.

    Parameters
    ----------
    fraction_coord : sequence of (x, y, z)
        Fractional coordinates of the charges.
    charge : sequence of float
        ``charge[n]`` is the charge at ``fraction_coord[n]``.
    grid_num : sequence of 3 ints
        Grid size along each axis.
    PME_ORDER : int
        Spline order (stencil width per axis).

    Returns
    -------
    numpy.ndarray of shape ``grid_num``
    """
    grid_x, grid_y, grid_z = grid_num
    charge_mesh = np.zeros(grid_num)
    for n, (x, y, z) in enumerate(fraction_coord):
        idx_x, w_x = weight(x*grid_x, grid_x, PME_ORDER=PME_ORDER)
        idx_y, w_y = weight(y*grid_y, grid_y, PME_ORDER=PME_ORDER)
        idx_z, w_z = weight(z*grid_z, grid_z, PME_ORDER=PME_ORDER)

        # outer product of the three 1-D weight vectors: the PME_ORDER^3 stencil
        stencil = w_x[:, None, None] * w_y[None, :, None] * w_z[None, None, :] * charge[n]
        # np.add.at (unlike fancy-index +=) accumulates repeated indices, which occur
        # when a grid dimension is smaller than PME_ORDER and the stencil wraps
        np.add.at(charge_mesh, np.ix_(idx_x, idx_y, idx_z), stencil)

    return(charge_mesh)


def get_charge_mesh_list(fraction_coord_list, charge_list, grid_num, PME_ORDER=6):
    """Build one charge mesh per (charges, fractional-coordinates) group.

    Groups are paired positionally; the result has ``min(len(charge_list),
    len(fraction_coord_list))`` meshes.
    """
    return([get_real_chage_mesh(coords, charges, grid_num, PME_ORDER = PME_ORDER)
            for charges, coords in zip(charge_list, fraction_coord_list)])

def get_reci_r_list(charge_mesh_list, bm, cm):
    """Reciprocal-space PME interaction between every pair of charge meshes.

    Each mesh is convolved with the influence function ``cm / bm`` via FFT to
    obtain its potential mesh. The interaction of meshes ``a`` and ``b`` is then
    ``sum(mesh_a * potential_b)``, halved when ``a == b``.

    Returns
    -------
    list of float
        Values for ``(a, b)`` with ``a <= b``, in row-major upper-triangle order.
    """
    influence = cm / bm
    n_points = np.prod(bm.shape)

    potentials = [np.fft.fftn(np.fft.ifftn(mesh) * influence * n_points).real
                  for mesh in charge_mesh_list]

    n_mesh = len(charge_mesh_list)
    result = []
    for a in range(n_mesh):
        for b in range(a, n_mesh):
            value = np.sum(charge_mesh_list[a] * potentials[b])
            result.append(value / 2 if a == b else value)
    return(result)

In [6]:
# real space part
def get_neigbour_list(latt9, all_coords, center_coords, cutoff=12.0):
    """Periodic neighbour search within ``cutoff`` around each centre point.

    Wraps pymatgen's ``find_points_in_spheres`` with PBC on all three axes.

    Returns
    -------
    index : numpy.ndarray, shape (n_pairs, 2)
        Columns are the first and second index arrays returned by
        ``find_points_in_spheres`` (used downstream as centre index and
        ``all_coords`` index).
    distance : numpy.ndarray, shape (n_pairs,)
        Pair distances. Pairs closer than 0.01 (a point and itself) are dropped.
    """
    all_xyz    = np.asarray(all_coords.astype("float"),    order="c")
    center_xyz = np.asarray(center_coords.astype("float"), order="c")
    lattice    = np.asarray(latt9.astype("float"),         order="c")
    periodic   = np.array([1,1,1], dtype=int)

    first_idx, second_idx, _image, dist = find_points_in_spheres(
        all_coords=all_xyz, center_coords=center_xyz, r=cutoff,
        pbc=periodic, lattice=lattice, tol=1e-10)

    keep = dist > 0.01
    pairs = np.column_stack((first_idx, second_idx))[keep]
    return(pairs, dist[keep])

def get_r_1_list(latt9, all_coords, mol_coord, alpha = 0.33333, cutoff = 12.0):
    """Real-space pair sums between one molecule and all atoms, reduced to atom classes.

    Neighbour indices are mapped to atom positions within a molecule via
    ``% len(mol_coord)``. For every pair of atom positions ``(i, j)`` with
    ``i <= j``, the following are summed over the neighbour list:
    - ``erfc(alpha*r)/r``;
    - ``r^-6``;
    - ``r^-12``.

    Off-diagonal entries combine both orientations ``(i, j)`` and ``(j, i)``.
    All sums are halved; presumably each physical pair is visited twice across
    the caller's loop over molecules. TODO confirm.

    The previous version scanned the whole neighbour list once per (i, j),
    i.e. O(n_atom^2 * n_pairs). This version accumulates everything in one
    ``np.add.at`` pass.

    Returns
    -------
    (r_1_list, r_6_list, r_12_list) : tuple of numpy.ndarray
        Each of length ``n_atom*(n_atom+1)/2``, in row-major upper-triangle order.
    """
    index, distance = get_neigbour_list(latt9, all_coords, mol_coord, cutoff = cutoff)
    mol_atom_num = len(mol_coord)
    index = index % mol_atom_num

    inv_r = 1. / distance
    r_6   = inv_r**6
    r_12  = r_6**2
    # the real space part of pme method
    r_1   = inv_r * erfc(distance * alpha)

    rows, cols = np.triu_indices(mol_atom_num)
    diag = np.diag_indices(mol_atom_num)

    def _pair_sum(values):
        acc = np.zeros((mol_atom_num, mol_atom_num))
        np.add.at(acc, (index[:, 0], index[:, 1]), values)
        # (i, j) and (j, i) both contribute off the diagonal; the diagonal counts once
        sym = acc + acc.T
        sym[diag] = np.diagonal(acc)
        return sym[rows, cols] / 2

    return(_pair_sum(r_1), _pair_sum(r_6), _pair_sum(r_12))

def get_r_14_list(pair14, pair1213, mol_coord, cutoff = 12.0, fudge_qq = 0.8333333, fudge_lj = 0.5):
    """Intramolecular corrections for excluded (1-2, 1-3) and scaled (1-4) pairs.

    For every atom pair ``(i, j)`` with ``i <= j`` (row-major upper triangle,
    diagonal included), the returned arrays hold the amount to ADD to the plain
    ``1/r``, ``r^-6`` and ``r^-12`` sums:

    * 1-2 / 1-3 pairs (``pair1213``): ``-1`` times the term, i.e. fully excluded.
      This wins when a pair appears in both lists.
    * 1-4 pairs (``pair14``): ``(fudge - 1)`` times the term, i.e. scaled by
      ``fudge_qq`` (Coulomb) and ``fudge_lj`` (LJ). Defaults are the GAFF values
      0.8333333 / 0.5; OPLS uses 0.5 / 0.5.
    * Everything else, the diagonal, and pairs farther apart than ``cutoff``: 0.

    Parameters
    ----------
    pair14, pair1213 : iterable of (int, int)
        0-based atom index pairs; the order within a pair does not matter.
    mol_coord : numpy.ndarray, shape (n_atom, 3)
    cutoff : float
    fudge_qq, fudge_lj : float

    Returns
    -------
    (r_1_list, r_6_list, r_12_list) : tuple of numpy.ndarray
        Each of length ``n_atom*(n_atom+1)/2``.
    """
    mol_atom_num = len(mol_coord)
    # hashed, order-normalised lookups instead of a linear list scan per pair
    set_14   = {tuple(sorted(int(a) for a in p)) for p in pair14}
    set_1213 = {tuple(sorted(int(a) for a in p)) for p in pair1213}

    r_1_list  = []
    r_6_list  = []
    r_12_list = []

    for i in range(mol_atom_num):
        for j in range(i, mol_atom_num):
            distance = np.linalg.norm(mol_coord[i] - mol_coord[j])
            if i == j or distance > cutoff:
                r_1_list.append(0.)
                r_6_list.append(0.)
                r_12_list.append(0.)
                continue

            if (i, j) in set_1213:
                scale_qq = scale_lj = -1.0
            elif (i, j) in set_14:
                scale_qq = fudge_qq - 1.0
                scale_lj = fudge_lj - 1.0
            else:
                scale_qq = scale_lj = 0.0

            r_1  = 1. / distance
            r_6  = r_1**6
            r_12 = r_6**2

            r_1_list.append(r_1 * scale_qq)
            r_6_list.append(r_6 * scale_lj)
            r_12_list.append(r_12 * scale_lj)

    return(np.array(r_1_list), np.array(r_6_list), np.array(r_12_list))

def get_ewald_self_interaction_list(n_atom, alpha = 0.33333):
    """Ewald self-interaction term laid out like the other pair lists.

    Returns an array over pairs ``(i, j)`` with ``i <= j`` (row-major upper
    triangle). Diagonal entries are ``-alpha / sqrt(pi)``; off-diagonal entries are 0.
    """
    rows, cols = np.triu_indices(max(n_atom, 0))
    on_diagonal = np.where(rows == cols, -1., 0.)
    return(on_diagonal * alpha / np.pi**0.5)


In [26]:
def get_nonbond_feature(latt9, coord, ff,  cutoff = 12., grid_space=1.0, alpha=0.333333, 
                        PME_ORDER = 6):
    """Per-atom-pair nonbonded features of one periodic frame of identical molecules.

    ``coord`` is reshaped to ``(n_mol, n_atom, 3)`` with ``n_atom = len(ff["charge"])``.
    For every pair of atom positions ``(i, j)``, ``i <= j``, the following are accumulated:

    * ``r_1``: Ewald-summed Coulomb kernel for unit charges, made of
      - the real-space ``erfc`` part,
      - the PME reciprocal part (fractional coordinates from ``inv(latt9)``,
        grid size from ``diag(latt9) / grid_space``),
      - the self term;
      1-2/1-3 pairs are removed and 1-4 pairs scaled by ``get_r_14_list``.
      Converted to kJ/mol with lengths in Angstrom.
    * ``r_6``, ``r_12``: sums of ``r^-6`` / ``r^-12`` with the same intramolecular
      corrections.

    Molecules without ``pairs``, ``bonds`` or ``angles`` entries are handled
    (previously a KeyError / IndexError).

    Returns
    -------
    (r_1, r_6, r_12) : tuple of numpy.ndarray
    """
    # e^2 / (4 pi eps0) in kJ mol^-1 Angstrom (GROMACS ONE_4PI_EPS0 = 138.935458 kJ mol^-1 nm);
    # replaces the rounded product 96.485 * 27.2 * 0.5291772 (~1388.77, 0.04 % low)
    coulomb_kj_mol_angstrom = 1389.35458

    mol_atom_num = len(ff["charge"])
    mol_num      = len(coord) // mol_atom_num
    mol_coords   = coord.reshape((-1, mol_atom_num, 3))
    r_1, r_6, r_12 = 0.0, 0.0, 0.0

    # short range part: every molecule in turn as the centre of the neighbour search
    for mol_coord in mol_coords:
        r_1_list, r_6_list, r_12_list = get_r_1_list(latt9, coord, mol_coord, alpha = alpha,
                                                     cutoff = cutoff)
        r_1  += r_1_list
        r_6  += r_6_list
        r_12 += r_12_list

    def _zero_based_sorted(pairs):
        # 1-based topology indices -> 0-based; the smaller index goes first because
        # get_r_14_list only visits (i, j) with i <= j
        arr = np.array(pairs, dtype=int).reshape(-1, 2) - 1
        return np.sort(arr, axis=1).tolist()

    # pair 1-4
    pair_14   = _zero_based_sorted(ff.get("pairs", {}).get("pair", []))
    # pair 1-2 (bonds) and 1-3 (angle end points)
    bond_pair = list(ff.get("bonds", {}).get("pair", []))
    pair_13   = [[p[0], p[2]] for p in ff.get("angles", {}).get("pair", [])]
    pair_1213 = _zero_based_sorted(bond_pair + pair_13)

    # short range 1-2, 1-3, 1-4 corrections
    for mol_coord in mol_coords:
        r_1_list, r_6_list, r_12_list = get_r_14_list(pair_14, pair_1213, mol_coord, cutoff = cutoff)
        r_1  += r_1_list
        r_6  += r_6_list
        r_12 += r_12_list

    # reciprocal part for coulomb interaction; one mesh per atom position in the molecule
    grid_num = [get_fftw_factor(i) for i in np.diagonal(latt9) / grid_space]

    fraction_coord = np.dot(coord, np.linalg.inv(latt9)).reshape((-1, mol_atom_num, 3))
    fraction_coord = fraction_coord.transpose(1,0,2)
    tmp_charge = np.ones(fraction_coord.shape[:2])
    charge_mesh_list = get_charge_mesh_list(fraction_coord, tmp_charge, grid_num, 
                                            PME_ORDER = PME_ORDER)
    bm = Bm(grid_num, PME_ORDER = PME_ORDER)
    cm = Cm(latt9, grid_num, alpha)

    r_1_list = get_reci_r_list(charge_mesh_list, bm, cm)
    r_1     += np.array(r_1_list)

    # self interaction part for coulomb interaction
    r_1_list = get_ewald_self_interaction_list(mol_atom_num, alpha = alpha)
    r_1     += r_1_list * mol_num

    # convert to kJ/mol
    r_1      = r_1 * coulomb_kj_mol_angstrom

    return(r_1, r_6, r_12)

# covalent bond part

In [3]:
# energy item calculation
def bond_feature(coord, bond_idx):
    """Bond lengths for every frame.

    ``coord`` is indexed as ``coord[:, atom]`` (frame axis first, xyz last);
    ``bond_idx`` has shape ``(n_bond, 2)``. Returns shape ``(n_frame, n_bond)``.
    """
    start = coord[:, bond_idx[:, 0]]
    end = coord[:, bond_idx[:, 1]]
    lengths = np.linalg.norm(start - end, axis=-1)
    return(lengths)

def angle_feature(coord, angle_idx):
    """Bond angle (radian) at the central atom of every atom triplet.

    coord     : array indexed [row, atom, xyz]
    angle_idx : integer array (n_angle, 3) of 0-based atom indices i-j-k (j is the vertex)
    returns   : array (row, n_angle) with values in [0, pi]
    """
    v1 = coord[:, angle_idx[:, 0]] - coord[:, angle_idx[:, 1]]
    v2 = coord[:, angle_idx[:, 2]] - coord[:, angle_idx[:, 1]]
    cos_theta = np.sum(v1 * v2, axis=-1) / (np.linalg.norm(v1, axis=-1) * np.linalg.norm(v2, axis=-1))
    # rounding can push |cos| marginally above 1 for (nearly) linear angles,
    # which would make arccos return NaN
    theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))
    return(theta)

def dihedral_feature(coord, dihedral_idx):
    """Signed torsion angle (radian, in [-pi, pi]) of every atom quadruplet.

    coord        : array indexed [row, atom, xyz]
    dihedral_idx : integer array (n_dihedral, 4) of 0-based atom indices i-j-k-l
    returns      : array (row, n_dihedral)
    """
    p0, p1, p2, p3 = (coord[:, dihedral_idx[:, col]] for col in range(4))
    b1 = p1 - p0
    b2 = p2 - p1
    b3 = p3 - p2
    n1 = np.cross(b1, b2, axis=-1)   # normal of plane i-j-k
    n2 = np.cross(b2, b3, axis=-1)   # normal of plane j-k-l
    y = np.sum(n1 * b3, axis=-1) * np.linalg.norm(b2, axis=-1)
    x = np.sum(n1 * n2, axis=-1)
    return np.arctan2(y, x)

def improper_feature(coord, improper_idx):
    """Signed improper torsion angle (radian, in [-pi, pi]) of every atom quadruplet.

    The angle is measured between plane (a0, a1, a2) and plane (a1, a2, a3).

    coord        : array indexed [row, atom, xyz]
    improper_idx : integer array (n_improper, 4) of 0-based atom indices
    returns      : array (row, n_improper)
    """
    atoms = [coord[:, improper_idx[:, k]] for k in range(4)]
    bonds = [atoms[k + 1] - atoms[k] for k in range(3)]
    plane_a = np.cross(bonds[0], bonds[1], axis=-1)
    plane_b = np.cross(bonds[1], bonds[2], axis=-1)
    sin_part = np.sum(plane_a * bonds[2], axis=-1) * np.linalg.norm(bonds[1], axis=-1)
    cos_part = np.sum(plane_a * plane_b, axis=-1)
    return np.arctan2(sin_part, cos_part)


def part_task(total_task_num, nthread):
    """Split ``total_task_num`` tasks into ``nthread`` chunk sizes.

    The remainder is handed out one task at a time to the first chunks, so the
    sizes differ by at most one and add up to ``total_task_num``.
    """
    base, extra = divmod(total_task_num, nthread)
    return [base + 1 if n < extra else base for n in range(nthread)]


def get_nonbond_feature_list(st_list, ff):
    """Non-bonded features of each structure, computed serially.

    st_list : iterable of (latt9, coord) pairs
    ff      : force-field dict forwarded to ``get_nonbond_feature``
    Returns a list with one ``get_nonbond_feature`` result per structure, in input
    order, using fixed settings (cutoff=12., grid_space=1.0, alpha=0.33333, PME_ORDER=6).
    """
    return [get_nonbond_feature(latt9, coord, ff, cutoff = 12.,
                                grid_space=1.0, alpha=0.33333, PME_ORDER = 6)
            for latt9, coord in st_list]
        

import multiprocessing as mp
def nonbond_feature(st_list, ff, nthread = 1):
    """Non-bonded pair features of every structure, optionally using worker processes.

    st_list : sequence of (latt9, coord) pairs
    ff      : force-field dict
    nthread : number of worker processes; 1 (or less) computes serially here
    Returns ``np.array`` of the per-structure features, in the order of ``st_list``.
    """
    if nthread > 1:
        # split st_list into contiguous chunks, one per worker; the chunk sizes
        # from part_task are turned into cumulative start/stop offsets
        chunk_sizes = part_task(len(st_list), nthread)
        offsets = [0]
        for size in chunk_sizes:
            offsets.append(offsets[-1] + size)

        with mp.Pool(nthread) as pool:
            jobs = [pool.apply_async(get_nonbond_feature_list,
                                     args=(st_list[offsets[i]:offsets[i + 1]], ff))
                    for i in range(nthread)]
            # collect in submission order so the output order matches st_list;
            # get() blocks until each chunk is done and re-raises worker errors
            feature_list = [feature for job in jobs for feature in job.get()]
    else:
        feature_list = get_nonbond_feature_list(st_list, ff)
    feature_list   = np.array(feature_list)
    return(feature_list)

def get_feature(st_list, ff, nthread=1):
    """Geometric features of every molecule of every structure.

    Each coordinate array is split into molecules of ``len(ff["charge"])`` atoms.
    Returns ``[st_id_list, bond_value, angle_value, dihedral_value,
    improper_value, nonbond_value]``: ``st_id_list[m]`` is the index in
    ``st_list`` of the structure that molecule row ``m`` comes from, the bonded
    values have one row per molecule, and ``nonbond_value`` has one entry per structure.
    """
    mol_atom_num = len(ff["charge"])

    coord_list = []
    st_id_list = []
    for n, (latt9, coord) in enumerate(st_list):
        coord_list.append(coord.reshape(-1, mol_atom_num, 3))
        st_id_list += [n] * (len(coord) // mol_atom_num)

    coord      = np.vstack(coord_list)
    st_id_list = np.array(st_id_list)

    # the atom indices of bonded items in ff start from 1
    bond_value      = bond_feature(coord,     np.array(ff["bonds"]["pair"])    -1)
    angle_value     = angle_feature(coord,    np.array(ff["angles"]["pair"])   -1)
    dihedral_value  = dihedral_feature(coord, np.array(ff["dihedrals"]["pair"])-1)
    improper_value  = improper_feature(coord, np.array(ff["impropers"]["pair"])-1)

    nonbond_value   = nonbond_feature(st_list, ff, nthread = nthread)

    return [st_id_list, bond_value, angle_value, dihedral_value, improper_value, nonbond_value]

def get_ff_items(st_list, ff, nthread=1):
    """Parameter-independent factors of every energy term.

    Each bonded energy is later obtained as ``item * k`` (see get_energy_items),
    so these items only need to be computed once per set of structures.

    nthread : number of worker processes used for the non-bonded features
              (previously ignored and always run serially).
    Returns ``[st_id_list, bond_item, angle_item, dihedral_item, improper_item, nonbond_item]``.
    """
    st_id_list, bond_value, angle_value, dihedral_value, improper_value, nonbond_value = \
        get_feature(st_list, ff, nthread=nthread)

    # the unit of k is kj/mol/nm^2, the unit of b is nm, and the unit of bond_value is angstrom
    b = np.array(ff["bonds"]["para"])[:,0]
    bond_item = (bond_value / 10  - b)**2 / 2

    # the unit of k is kj/rad^2/mol, the unit of b is degree, and the unit of angle_value is radian
    b = np.array(ff["angles"]["para"])[:,0]
    angle_item = (angle_value - np.deg2rad(b))**2 / 2

    # the unit of k is kj/mol, the phase is in degree, the multiplicity n is an integer
    n     = np.array(ff["dihedrals"]["order"])
    phase = np.deg2rad(np.array(ff["dihedrals"]["para"])[:,0])
    dihedral_item  = 1 + np.cos( n * dihedral_value - phase )

    # the unit of k is kj/mol/rad^2, b is given in degree and converted to radian here
    b = np.deg2rad(np.array(ff["impropers"]["para"])[:,0])
    improper_item = (np.abs(improper_value) - b)**2 / 2

    # the non-bonded features are used as they are (distances in angstrom)
    nonbond_item = nonbond_value

    return [st_id_list, bond_item, angle_item, dihedral_item, improper_item, nonbond_item]

def get_energy_items(ff, all_item, e_qm, energy_only = False, weight_list = None):
    """Force-field energy of every structure and, optionally, the gradient of the
    weighted squared deviation between mean-centred ff and reference energies.

    Parameters
    ----------
    ff : dict
        Force field. Column 1 of ``ff["bonds"|"angles"|"dihedrals"|"impropers"]["para"]``
        holds the force constants, ``ff["charge"]`` the atomic charges and the rows
        of ``ff["vdw_para"]`` are ``[sigma (nm), epsilon]``.
    all_item : list
        ``[st_id_list, bond_item, angle_item, dihedral_item, improper_item,
        nonbond_item]`` as returned by ``get_ff_items``. Bonded items have one row
        per molecule and ``st_id_list[m]`` is the structure of molecule row ``m``;
        ``nonbond_item[s]`` stacks the r_1 / r_6 / r_12 pair features of structure
        ``s`` for the atom pairs ``(i, j)``, ``i <= j``, in row-major order.
    e_qm : array-like
        Reference energy of each structure (not used when ``energy_only``).
    energy_only : bool
        Return only the energies.
    weight_list : array-like or None
        Weight of every structure in the deviation; ``None`` means all ones.

    Returns
    -------
    e_total : np.ndarray
        Force-field energy of each structure.
    g_total : np.ndarray
        Only when ``energy_only`` is False: gradient of
        ``sum(w * ((E - mean(E)) - (E_qm - mean(E_qm)))**2)`` with respect to the flat
        parameter vector ``[k_bond, k_angle, k_dihedral, k_improper, charge,
        epsilon, sigma (nm)]``.
    """
    st_id_list, bond_item, angle_item, dihedral_item, improper_item, nonbond_item = all_item
    st_id_list   = np.asarray(st_id_list)
    nonbond_item = np.asarray(nonbond_item)

    st_num       = len(nonbond_item)
    mol_atom_num = len(ff["charge"])

    # ---- bonded terms: every energy item is linear in its force constant ----
    k_bond     = np.array(ff["bonds"]["para"]    )[:,1]
    k_angle    = np.array(ff["angles"]["para"]   )[:,1]
    k_dihedral = np.array(ff["dihedrals"]["para"])[:,1]
    k_improper = np.array(ff["impropers"]["para"])[:,1]
    e_mol = (np.sum(bond_item     * k_bond,     axis=1)
             + np.sum(angle_item    * k_angle,    axis=1)
             + np.sum(dihedral_item * k_dihedral, axis=1)
             + np.sum(improper_item * k_improper, axis=1))
    # sum the molecular energies belonging to each structure
    e_total = np.bincount(st_id_list, weights=e_mol, minlength=st_num)[:st_num]

    # ---- non-bonded terms ----
    charge  = np.asarray(ff["charge"], dtype=float)
    vdw     = np.asarray(ff["vdw_para"], dtype=float)
    epsilon = vdw[:, 1]
    sigma   = vdw[:, 0] * 10          # nm -> angstrom, the distance unit of the pair features

    # atom pairs (i, j), i <= j, in the same order as the pair features
    idx_a, idx_b = np.triu_indices(mol_atom_num)
    r_1, r_6, r_12 = nonbond_item[:, 0], nonbond_item[:, 1], nonbond_item[:, 2]

    e_coulomb = charge[idx_a] * charge[idx_b] * r_1

    epsilon_mix = np.sqrt(epsilon[idx_a] * epsilon[idx_b])   # geometric mean
    sigma_mix   = (sigma[idx_a] + sigma[idx_b]) / 2           # arithmetic mean
    sigma6      = sigma_mix**6
    lj          = sigma6**2 * r_12 - sigma6 * r_6             # e_vdw / (4 * epsilon_mix)
    e_vdw       = 4 * epsilon_mix * lj

    e_total = e_total + np.sum(e_vdw, axis=1) + np.sum(e_coulomb, axis=1)

    if energy_only:
        return(e_total)

    # ---- gradient of the weighted, mean-centred squared deviation ----
    e_qm = np.asarray(e_qm, dtype=float)
    if weight_list is None:
        weight_list = np.ones(len(e_qm))
    weight_list = np.asarray(weight_list, dtype=float)
    assert len(e_qm) == len(weight_list), "length of weight_list should be same with train set"

    # with d = (E - <E>) - (E_qm - <E_qm>):
    #   d/dk sum_s w_s d_s^2 = sum_s 2 (w_s d_s - <w d>) dE_s/dk
    # (the <w d> term comes from the centring of E; it vanishes for uniform weights)
    residual   = (e_total - e_total.mean()) - (e_qm - e_qm.mean())
    w_resid    = weight_list * residual
    e_diff     = 2 * (w_resid - w_resid.mean())          # (st_num,)
    e_diff_mol = e_diff[st_id_list][:, None]              # one row per molecule

    # dE/dk of a bonded term is its item
    g_bond     = np.sum(bond_item     * e_diff_mol, axis=0)
    g_angle    = np.sum(angle_item    * e_diff_mol, axis=0)
    g_dihedral = np.sum(dihedral_item * e_diff_mol, axis=0)
    g_improper = np.sum(improper_item * e_diff_mol, axis=0)

    # pair features weighted by the residual and summed over structures, (n_pair,)
    r_1_w  = e_diff @ r_1
    lj_w   = e_diff @ lj
    # d(lj)/d(sigma_mix) = 12 s^11 r_12 - 6 s^5 r_6, written without dividing by sigma_mix
    sigma5 = sigma_mix**5
    dlj_w  = e_diff @ (12 * sigma6 * sigma5 * r_12 - 6 * sigma5 * r_6)

    def _scatter(w_a, w_b):
        # accumulate per-pair contributions onto the first and second atom of each pair;
        # a self pair (i, i) receives both contributions
        return (np.bincount(idx_a, weights=w_a, minlength=mol_atom_num)
                + np.bincount(idx_b, weights=w_b, minlength=mol_atom_num))

    # dE/dq_a = q_b * r_1 (self pair: 2 q_i r_1); no division, so zero charges are fine
    g_charge = _scatter(charge[idx_b] * r_1_w, charge[idx_a] * r_1_w)

    # d(epsilon_mix)/d(epsilon_a) = sqrt(epsilon_b) / (2 sqrt(epsilon_a)) for a != b and
    # 1/2 per side for a self pair. For a cross pair with epsilon_a == 0 the derivative
    # is unbounded; it is reported as 0 instead of NaN.
    same         = idx_a == idx_b
    sqrt_eps     = np.sqrt(epsilon)
    inv_sqrt_eps = np.divide(1.0, sqrt_eps, out=np.zeros_like(sqrt_eps), where=sqrt_eps > 0)
    ratio_a      = np.where(same, 1.0, sqrt_eps[idx_b] * inv_sqrt_eps[idx_a])
    ratio_b      = np.where(same, 1.0, sqrt_eps[idx_a] * inv_sqrt_eps[idx_b])
    g_epsilon    = _scatter(2 * lj_w * ratio_a, 2 * lj_w * ratio_b)

    # sigma_mix (angstrom) = 10 * (sigma_a + sigma_b) / 2 with sigma in nm -> factor 5 per atom
    g_sigma_pair = 4 * epsilon_mix * dlj_w * 5
    g_sigma      = _scatter(g_sigma_pair, g_sigma_pair)

    g_total = np.hstack([g_bond, g_angle, g_dihedral, g_improper, g_charge, g_epsilon, g_sigma])

    return(e_total, g_total)

def update_ff_parameter(ff_para, ff, ff_para_len, equivalent_item=None):
    """Return a deep copy of ``ff`` with the flat parameter vector written back.

    ff_para     : flat parameters ordered as [k_bond, k_angle, k_dihedral,
                  k_improper, charge, epsilon, sigma]
    ff          : reference force field (left unmodified)
    ff_para_len : cumulative start offsets of the seven groups (length 8)
    equivalent_item : accepted because callers pass it; not used here.

    The charges are shifted uniformly so that the molecule keeps the net charge
    of ``ff`` (zero for a neutral molecule, the ionic charge for an ion).
    """
    ff_new = copy.deepcopy(ff)

    # force constants are stored in column 1 of each bonded "para" row
    for group, key in enumerate(["bonds", "angles", "dihedrals", "impropers"]):
        start, stop = ff_para_len[group], ff_para_len[group + 1]
        for nn, value in enumerate(ff_para[start:stop]):
            ff_new[key]["para"][nn][1] = value

    # coulomb: keep the total charge of the reference force field
    new_charge = np.asarray(ff_para[ff_para_len[4]:ff_para_len[5]], dtype=float)
    if len(new_charge):
        shift = (new_charge.sum() - np.sum(ff["charge"])) / len(new_charge)
        for n, q in enumerate(new_charge):
            ff_new["charge"][n] = q - shift

    # vdw: column 1 is epsilon, column 0 is sigma
    for nn, value in enumerate(ff_para[ff_para_len[5]:ff_para_len[6]]):
        ff_new["vdw_para"][nn][1] = value
    for nn, value in enumerate(ff_para[ff_para_len[6]:ff_para_len[7]]):
        ff_new["vdw_para"][nn][0] = value

    return(ff_new)

def set_equivalent_item(grad, ff_para_len, equivalent_item={}):
    """Give every group of equivalent parameters the mean gradient of the group.

    grad            : flat gradient array, modified in place and returned
    ff_para_len     : cumulative start offsets of the parameter groups
    equivalent_item : {"bond"|"angle"|"dihedral"|"improper"|"charge"|"epsilon"|"sigma":
                       [[idx, ...], ...]} with indices relative to their group, or None
    """
    if equivalent_item is None:
        return grad

    part_names = ("bond", "angle", "dihedral", "improper", "charge", "epsilon", "sigma")
    for offset, part in zip(ff_para_len, part_names):
        groups = equivalent_item.get(part)
        if groups is None:
            continue
        for group in groups:
            positions = np.array(group) + offset
            grad[positions] = grad[positions].mean()

    return grad

def get_msd_gradient(ff_para, all_item, e_qm, ff, equivalent_item = {}, weight_list = None):
    """Objective for L-BFGS-B: weighted squared deviation of mean-centred ff energies
    from mean-centred reference energies, and its gradient.

    ff_para : flat parameter vector [k_bond, k_angle, k_dihedral, k_improper,
              charge, epsilon, sigma]
    ff      : reference force field; it supplies the group sizes and the fixed parts
    weight_list : per-structure weights, ``None`` means all ones
    Returns ``(msd, g_total)`` and prints a fit summary (slope, intercept, rmsd, r).
    """
    n_bond     = len(ff["bonds"]["para"])
    n_angle    = len(ff["angles"]["para"])
    n_dihedral = len(ff["dihedrals"]["para"])
    n_improper = len(ff["impropers"]["para"])
    n_atom     = len(ff["charge"])

    ff_para_len  = np.cumsum([0, n_bond, n_angle, n_dihedral, n_improper, n_atom, n_atom, n_atom])

    if weight_list is None:
        weight_list = np.ones(len(e_qm))
    weight_list = np.asarray(weight_list, dtype=float)

    # update forcefield parameters (update_ff_parameter takes no equivalent_item argument)
    ff_new = update_ff_parameter(ff_para, ff, ff_para_len)

    # total energy and gradient of forcefield parameters
    e_total, g_total = get_energy_items(ff_new, all_item, e_qm, weight_list=weight_list)

    # equivalent parameters share one (averaged) gradient
    g_total = set_equivalent_item(g_total, ff_para_len, equivalent_item)

    # msd of mean-centred energies
    e_total = e_total - e_total.mean()
    e_ref   = e_qm - e_qm.mean()
    msd = np.sum((e_total - e_ref)**2 * weight_list)

    # fitting of ff energy and qm energy
    fit = linregress(e_ref, e_total)
    rmsd = (msd / len(e_total))**0.5
    print("y={:<6.4f}x+{:<6.4f}  rmsd={:<6.4f} r={:<6.4f}".format(fit.slope, fit.intercept, rmsd,
                                                                  fit.rvalue ))

    return(msd, g_total)

# forcefield optimizing

In [2]:
def get_parameter_range(parameter_range, ff_para, ff_para_len):
    """Box constraints ([low, high] per parameter) for the L-BFGS-B optimization.

    Defaults: every parameter may vary by +/-20 % of its initial value (a
    parameter starting at 0 is therefore fixed); charges by +/-0.1 (absolute).
    ``parameter_range`` ({"bond": {index: [low, high]}, ..., "sigma": {...}}, indices
    relative to their group) overrides the defaults for the listed parameters.

    ff_para     : flat initial parameter vector
    ff_para_len : cumulative start offsets of the seven parameter groups
    Returns an (n_para, 2) array.
    """
    ff_para = np.asarray(ff_para, dtype=float)

    # +/-20 %; sort each row so negative parameters also get low <= high
    para_bond = np.sort(np.stack([ff_para * 0.8, ff_para * 1.2], axis=1), axis=1)

    # charges: absolute window of +/-0.1
    q_start, q_stop = ff_para_len[4], ff_para_len[5]
    para_bond[q_start:q_stop, 0] = ff_para[q_start:q_stop] - 0.1
    para_bond[q_start:q_stop, 1] = ff_para[q_start:q_stop] + 0.1

    if parameter_range is not None:
        parameter_list = ["bond", "angle", "dihedral", "improper", "charge",
                          "epsilon", "sigma"]
        # user-defined ranges override the defaults
        for n, part in enumerate(parameter_list):
            part_range = parameter_range.get(part, None)
            if part_range is not None:
                for idx, bound in part_range.items():
                    para_bond[idx + ff_para_len[n]] = bound

    return(para_bond)

def ff_opt_bfgs(ff, all_item, e_qm, parameter_range = None,
                     equivalent_item = {}, weight_list = None, cutoff = 12., maxcycle=50):
    """Fit force constants, charges and LJ parameters to reference energies with L-BFGS-B.

    ff              : initial force field (not modified)
    all_item        : precomputed items from ``get_ff_items``
    e_qm            : reference energy of every structure in ``all_item``
    parameter_range : optional per-parameter bounds, see ``get_parameter_range``
    equivalent_item : groups of parameters that share one gradient
    weight_list     : per-structure weights, ``None`` means all ones
    cutoff          : kept for backward compatibility; unused (features are precomputed)
    maxcycle        : maximum number of function evaluations
    Returns the optimized force-field dict.
    """
    if weight_list is None:
        weight_list = np.ones(len(e_qm))

    k_bond     = np.array(ff["bonds"]["para"]    )[:,1]
    k_angle    = np.array(ff["angles"]["para"]   )[:,1]
    k_dihedral = np.array(ff["dihedrals"]["para"])[:,1]
    k_improper = np.array(ff["impropers"]["para"])[:,1]
    charge     = np.array(ff["charge"])
    epsilon    = np.array(ff["vdw_para"])[:, 1]
    sigma      = np.array(ff["vdw_para"])[:, 0]

    ff_para     = np.hstack([k_bond, k_angle, k_dihedral, k_improper, charge, epsilon, sigma])
    ff_para_len = np.cumsum([0, len(k_bond), len(k_angle), len(k_dihedral), len(k_improper),
                             len(charge), len(charge), len(charge)])

    para_bound = get_parameter_range(parameter_range, ff_para, ff_para_len)

    # fmin_l_bfgs_b returns (x, f, info); only the optimized parameters are needed
    ff_para_opt = fmin_l_bfgs_b(get_msd_gradient, ff_para,
                                args=(all_item, e_qm, copy.deepcopy(ff), equivalent_item, weight_list),
                                approx_grad=0,
                                bounds=para_bound, m=10, factr=10000000.0, pgtol=1e-05, epsilon=1e-08,
                                iprint=-1, maxfun=maxcycle, maxiter=1500)[0]

    # update ff parameters (update_ff_parameter takes no equivalent_item argument)
    ff_new = update_ff_parameter(ff_para_opt, ff, ff_para_len)

    return(ff_new)

In [None]:
# pbc structure list


In [11]:
# Read the force field (GAFF, GROMACS itp format).
# The itp file was generated with ACPYPE: https://bio2byte.be/acpype/
itp_path = "./testcase/mol.itp"
ff_itp = itp_read(itp_path, ff_type="gaff")

In [45]:
# toy training set: the same periodic structure (latt9, coord) repeated 10 times
# NOTE(review): latt9 and coord are not defined in any visible cell (hidden kernel state)
st_list = [[latt9, coord] for _ in range(10)]
all_item = get_ff_items(st_list, ff_itp)

In [90]:
# e_ff,g = get_energy_items(ff_itp, all_item, np.array([0.]*10), energy_only = False)

In [89]:
# --- parameter range setting ---------------------------------------------------
# A dict keyed by force-field part: "bond", "angle", "dihedral", "improper",
# "charge", "epsilon" or "sigma". Each value is itself a dict that maps the index
# of a specific parameter within that part to its allowed range [lower, upper].
# Parts or indices that are not listed keep the default range (see the function
# "get_parameter_range").
# To fix a parameter, set its lower and upper bound to the same target value.
# The example below only illustrates the format; it is not a tuned setting.
parameter_range = {
    "bond":   {0: [ff_itp["bonds"]["para"][0][1] * 0.9,
                   ff_itp["bonds"]["para"][0][1] * 1.1]},
    "charge": {5: [ff_itp["charge"][5] - 0.1,
                   ff_itp["charge"][5] + 0.1]},
}

# --- equivalent parameters -------------------------------------------------------
# A dict keyed by force-field part (same keys as above). Each value is a list of
# index groups; the parameters within one group share a single gradient value.
# The example below only explains the format and is not meant for a real fit.
equivalent_item = {
    "bond":   [[0, 1, 2], [3, 4, 5]],
    "charge": [[1, 3, 9]],
}

# --- weights ------------------------------------------------------------------------
# Weight of every data point in the training set, i.e. its contribution to the total
# gradient. Default: 1.0 for every data point.
# NOTE(review): sized from coord_list (whole data set, defined outside the visible
# cells) rather than from the training slice.
weight_list = np.ones(len(coord_list))

259

In [None]:
# indices of the training set
train_set_idx = slice(0, 10000)

# ff_item stores every parameter-independent factor of the covalent and
# non-covalent force-field terms, computed once from the structure coordinates,
# so energies and parameter gradients are cheap during the optimization below.
# The non-covalent part is the time-consuming one.
# NOTE(review): coord_list is not defined in any visible cell; presumably a list of
# (latt9, coord) pairs.
ff_item = get_ff_items(coord_list[train_set_idx], ff_itp, nthread=1)

# Save ff_item so it need not be recomputed when only the optimizer settings change.
# Its arrays have different shapes, so np.save cannot stack them into one ndarray;
# store them in a 1-D object array (reload with np.load(..., allow_pickle=True)).
ff_item_to_save = np.empty(len(ff_item), dtype=object)
for n, item in enumerate(ff_item):
    ff_item_to_save[n] = item
np.save("./ff_item_pbc.npy", ff_item_to_save, allow_pickle=True)

In [4]:
# optimize the force field on the training set (uniform weights)
train_set_idx = slice(0, 10000)
ff_opt = ff_opt_bfgs(
    ff_itp,
    ff_item,
    energy_qm[train_set_idx],
    parameter_range=parameter_range,
    equivalent_item=equivalent_item,
    weight_list=None,
    maxcycle=500,
)

In [1]:
# performance of the optimized force field on the remaining data points (test set)

test_set_idx = slice(10000, 50000)

all_item = get_ff_items(coord_list[test_set_idx], ff_opt, nthread=1)
# only the energies are needed here, so skip the gradient computation
e_ff = get_energy_items(ff_opt, all_item, energy_qm[test_set_idx], energy_only=True)

# compare energies relative to their means (both sets are centred)
e_ff = e_ff - e_ff.mean()
e_qm_test = energy_qm[test_set_idx] - energy_qm[test_set_idx].mean()

fit = linregress(e_qm_test, e_ff)
e_diff = e_qm_test - e_ff
rmsd   = np.mean(e_diff**2)**0.5
label = "y={:<6.3f}x+{:<6.3f}  rmsd={:<6.3f} r={:<6.3f}".format(fit.slope, fit.intercept, rmsd,
                                                                   fit.rvalue)

fig, ax = plt.subplots()
ax.plot(e_qm_test, e_ff, marker="o", markersize=2, linewidth=0, label=label)
lims = [min(e_qm_test.min(), e_ff.min()), max(e_qm_test.max(), e_ff.max())]
ax.plot(lims, lims, label="y=x", color="k")
ax.legend()
ax.set_xlabel("qm energy(kj/mol)")
ax.set_ylabel("ff energy(kj/mol)");

In [2]:
# distribution of the (qm - ff) energy difference over the test set
bins = np.linspace(-60, 60, 60)
num, bins = np.histogram(e_diff, bins=bins)
fig, ax = plt.subplots()
# one bar per bin, centred on the bin and as wide as the bin
ax.bar((bins[:-1] + bins[1:]) / 2, num, width=np.diff(bins))
ax.set_xlabel("qm - ff energy (kj/mol)")
ax.set_ylabel("count");

In [3]:
# energy distributions of the test set, both relative to their mean
bins    = np.linspace(-100, 100, 60)
centers = (bins[:-1] + bins[1:]) / 2
width   = np.diff(bins) / 2          # two side-by-side bars share one bin

e_qm_rel = energy_qm[test_set_idx] - energy_qm[test_set_idx].mean()
num_qm, _ = np.histogram(e_qm_rel, bins=bins)
num_ff, _ = np.histogram(e_ff - e_ff.mean(), bins=bins)

fig, ax = plt.subplots()
ax.bar(centers - width / 2, num_qm, width=width, label="qm")
ax.bar(centers + width / 2, num_ff, width=width, label="ff")
ax.set_xlabel("relative energy (kj/mol)")
ax.set_ylabel("count")
ax.legend();