In [1]:
import io
import tarfile
import time
from collections import namedtuple
from contextlib import contextmanager

import numpy as np
import scipy
import scipy.io
import scipy.special

from ase import Atoms
import ase.io.xyz
from rascal.representations import SphericalExpansion
from rascal.utils import (ClebschGordanReal, compute_lambda_soap, 
                          spherical_expansion_reshape)

import h5py
import tabulate
from IPython.display import HTML, display

# Unit conversion factors: coordinates are multiplied by ANGSTROM_TO_BOHR
# and energies by HARTREE_TO_KCALMOL further down in this notebook.
ANGSTROM_TO_BOHR = 1.88973
HARTREE_TO_KCALMOL = 627.509

# Atomic numbers passed to the spherical expansion as 'global_species'.
SPECIES = (1, 6, 7, 8, 16)
# Value handed to the calculator as both of its GPR sigma arguments.
SIGMA = 1e-3
# Number of frames read from every archive.
N_FRAMES = 100

# Spherical-expansion settings shared by every hyperparameter set.
SPHERICAL_EXPANSION_HYPERS_COMMON = dict(
    gaussian_sigma_constant=0.5,
    gaussian_sigma_type="Constant",
    cutoff_smooth_width=0.5,
    radial_basis="GTO",
    expansion_by_species_method="user defined",
    global_species=SPECIES,
)

In [2]:
# Helper functions

def pad_to_shape(array, shape, value=0):
    """Pad `array` at the end of every axis until it has the given `shape`.

    array: numpy array to pad
    shape: target length for each axis of `array`
    value: constant used to fill the added entries

    Returns a new array produced by `np.pad`.
    """
    widths = []
    for target, current in zip(shape, array.shape):
        # Nothing is added in front; the missing entries go after the data.
        widths.append((0, target - current))
    return np.pad(array, widths, constant_values=value)

def pad_to_max(arrays, value=0):
    """Stack arrays of differing shapes into one array.

    The arrays must share the number of dimensions. Each one is padded
    with `value` up to the largest extent found along every axis, and the
    padded arrays are stacked along a new leading axis.
    """
    shapes = [item.shape for item in arrays]
    target = np.max(shapes, axis=0)
    padded = [pad_to_shape(item, target, value) for item in arrays]
    return np.array(padded)

In [3]:
class FramesParser:
    '''Parses contents of a tarball with xyz coordinates and 
    charges/positions of point charges

    Attributes set by the constructor:
        z: atomic numbers read from the first xyz file
        zid: SPECIES_DICT index of every entry of `z`
        xyz: coordinates of all frames stacked along the first axis
        pc_q: first column of the padded point-charge tables
        pc_xyz: remaining columns of the padded point-charge tables
    '''
    
    # Element symbol -> atomic number. Sulfur is 16, so every value here
    # is also a key of SPECIES_DICT.
    Z_DICT = {'H': 1, 'C': 6, 'N': 7, 'O': 8, 'S': 16}
    # Atomic number -> element index.
    SPECIES_DICT = {1: 0, 6: 1, 7: 2, 8: 3, 16: 4}
    
    def __init__(self, filename, n):
        '''
        filename: path of a gzip-compressed tarball
        n: number of frames; members `1.xyz`..`n.xyz` and `1.pc`..`n.pc`
           are read
        '''
        self.tar = tarfile.open(filename, mode='r:gz')
        self.n = n
        
        raw_xyz = [self._parse_xyz(f'{i+1}.xyz') for i in range(n)]
        # Atomic numbers are taken from the first frame only.
        self.z = raw_xyz[0][0]
        self.zid = np.array([self.SPECIES_DICT[_] for _ in self.z])
        self.xyz = np.array([_[1] for _ in raw_xyz])

        # Frames may hold different numbers of rows; pad with zeros.
        pc = pad_to_max([self._parse_pc(f'{i+1}.pc') for i in range(n)])
        self.pc_q, self.pc_xyz = pc[:,:,0], pc[:,:,1:]
    
    def _parse_xyz(self, filename):
        '''Returns (atomic numbers, coordinates) parsed from one xyz member.'''
        with self.tar.extractfile(filename) as f:
            # Skip the two header lines.
            next(f)
            next(f)
            # The final line of the member is discarded.
            # NOTE(review): presumably a trailing non-atom line - confirm.
            raw_lines = f.readlines()[:-1]
            z, xyz = zip(*map(self._parse_xyz_line, raw_lines))
            return np.array(z), np.array(xyz)        
        
    @classmethod
    def _parse_xyz_line(cls, line):
        '''Splits a bytes line into (atomic number, list of floats).

        Raises KeyError for an element symbol missing from Z_DICT.
        '''
        el, *xyz_raw = line.decode().split()
        return cls.Z_DICT[el], list(map(float, xyz_raw))    
    
    def _parse_pc(self, filename):
        '''Loads one point-charge member as an array, skipping its first row.'''
        with self.tar.extractfile(filename) as f:
            return np.loadtxt(f, skiprows=1)

In [4]:
class ORCAParser:
    '''Parses contents of a tarball with outputs of calculations
    in vacuo and in presence of point charges. Optionally, can
    also read values of electrostatic potential calculated with 
    vpot to decompose the interaction energy into static and induced 
    components and read MBIS partitioning done with horton.
    '''
    
    # Datasets read from every `<i>.h5` member.
    HORTON_KEYS = ('cartesian_multipoles', 'pure_multipoles', 'core_charges',
                   'radial_moments', 'valence_charges', 'valence_widths')
    
    def __init__(self, filename, n, q=None, mbis=False):
        '''
        filename: path of a gzip-compressed tarball
        n: number of frames; members are addressed by ids 1..n
        q: optional array multiplied with the potentials read from the
           `<i>_vac.dat` members; when given, `E_static` and `E_induced`
           are set
        mbis: when true, the `<i>.h5` members are read into `self.mbis`
        '''
        self.tar = tarfile.open(filename, mode='r:gz')
        self.n = n
        self.vac_E, self.pc_E = self.get_E()
        self.E = self.pc_E - self.vac_E
        self.time = self.get_time()
        
        if q is not None:
            self.E_static = self.get_E_static(q)
            self.E_induced = self.E - self.E_static
            
        if mbis:
            self.mbis = self._parse_horton()
    
    def get_E(self):
        '''Returns an array whose rows are the vac and pc energies per frame.'''
        E = [self._get_E_by_id(i+1) for i in range(self.n)]
        return np.array(E).T
    
    def _get_E_by_id(self, i):
        '''Returns (vac energy, pc energy) for frame id `i`.'''
        pc_E = self._get_E_from_out(f'{i}_pc.out')
        vac_E = self._get_E_from_out(f'{i}_vac.out')
        return vac_E, pc_E    
    
    def _get_E_from_out(self, filename):
        '''Reads the first 'FINAL SINGLE POINT ENERGY' line of a member and
        returns its last field multiplied by HARTREE_TO_KCALMOL.

        Raises StopIteration when no such line exists.
        '''
        E_prefix = 'FINAL SINGLE POINT ENERGY'
        with self.tar.extractfile(filename) as f:
            E_line = next(line for line in f 
                          if line.decode().startswith(E_prefix))
        return float(E_line.split()[-1]) * HARTREE_TO_KCALMOL
    
    def get_E_static(self, q):
        '''Sums potential * q over axis 1 and scales by HARTREE_TO_KCALMOL.'''
        return np.sum(self.get_vpot() * q, axis=1) * HARTREE_TO_KCALMOL
    
    def get_vpot(self):
        '''Returns the padded potentials read from the `<i>_vac.dat` members.'''
        return self._get_vpot_by_suffix('vac')
    
    def _get_vpot_by_suffix(self, suffix):
        '''Reads `<i>_<suffix>.dat` for every frame and pads to a common size.'''
        return pad_to_max([self._get_vpot_from_out(f'{i+1}_{suffix}.dat')
                           for i in range(self.n)])
    
    def _get_vpot_from_out(self, filename):
        '''Returns the fourth column of a member, skipping its first row.'''
        with self.tar.extractfile(filename) as f:
            return np.loadtxt(f, skiprows=1)[:,3]
    
    def get_time(self):
        '''Returns an array with the value of `_get_time_by_id` per frame.'''
        times = [self._get_time_by_id(i+1) for i in range(self.n)]
        return np.array(times)
    
    def _get_time_by_id(self, i):
        '''Returns the process count times the summed time reported in
        `<i>_pc.out`; the process count defaults to 1 when the output has
        no 'Program running with' line.
        '''
        with self.tar.extractfile(f'{i}_pc.out') as f:
            try:
                cpu_prefix = '           *        Program running with'
                cpu_line = next(line for line in f 
                                if line.decode().startswith(cpu_prefix))
                cpu = int(cpu_line.split()[4])
            except StopIteration:
                cpu = 1
                # The search consumed the whole member; rewind before
                # looking for the timing line.
                f.seek(0)
            time_prefix = 'Sum of individual times         ... '
            time_line = next(line for line in f 
                             if line.decode().startswith(time_prefix))
            # Named `elapsed` so the imported `time` module is not shadowed.
            elapsed = float(time_line.split()[5])
        return cpu * elapsed    
    
    def _parse_horton(self):
        '''Returns {key: array stacked over frames} for every HORTON_KEYS entry.'''
        data = [self._parse_horton_by_id(i+1) for i in range(self.n)]
        return {k: np.array([_[k] for _ in data]) for k in self.HORTON_KEYS}

    def _parse_horton_by_id(self, i):
        '''Reads the HORTON_KEYS datasets of `<i>.h5` into memory.'''
        with self.tar.extractfile(f'{i}.h5') as f:
            # The HDF5 handle is closed on exit; `[:]` has already copied
            # every dataset into an in-memory array by then.
            with h5py.File(io.BytesIO(f.read()), 'r') as f_h5:
                return {key: f_h5[key][:] for key in self.HORTON_KEYS}

In [5]:
# Per-frame energy arrays: total, static and induced parts, in that order.
FramesE = namedtuple('FramesE', ['E', 'E_static', 'E_induced'])

@contextmanager
def timer(tag='', print_time=True):
    """Context manager measuring the wall-clock time spent in its body.

    tag: label printed in parentheses after the elapsed time
    print_time: when false nothing is printed

    Yields the `time.time()` value taken on entry. The report is printed
    on exit even if the body raised.
    """
    started = time.time()
    try:
        yield started
    finally:
        if print_time:
            elapsed = time.time() - started
            print(f'{elapsed:.3f} ({tag})')
            
class GPRCalculator:
    '''Predicts an atomic property for a molecule with GPR'''
    
    def __init__(self, ref_values, ref_soap, n_ref, sigma):
        '''
        ref_values: (N_Z, N_REF)
        ref_soap: (N_Z, N_REF, N_SOAP)
        n_ref: (N_Z,)
        sigma: ()

        The mean of every element is computed as sum(ref_values) / n_ref,
        i.e. rows shorter than N_REF are expected to be zero-padded.
        '''
        self.ref_soap = ref_soap
        Kinv = self.get_Kinv(ref_soap, sigma)
        self.n_ref = n_ref
        self.n_z = len(n_ref)
        self.ref_mean = np.sum(ref_values, axis=1) / n_ref
        ref_shifted = ref_values - self.ref_mean[:, None]
        # Drop only the trailing column axis: a bare squeeze() would also
        # remove the element/reference axes whenever N_Z or N_REF is 1.
        self.c = (Kinv @ ref_shifted[:, :, None]).squeeze(-1)
        
    def __call__(self, mol_soap, zid):
        '''
        mol_soap: (N_ATOMS, N_SOAP)
        zid: (N_ATOMS,)

        Returns an (N_ATOMS,) array of predictions; atoms whose zid is not
        in range(N_Z) keep the initial value 0.
        '''
        
        result = np.zeros(len(zid))
        for i in range(self.n_z):
            n_ref = self.n_ref[i]
            ref_soap_z = self.ref_soap[i, :n_ref]
            mol_soap_z = mol_soap[zid == i, :, None]
            # (n_atoms_z, n_ref) kernel; squeeze(-1) keeps both axes even
            # when one of them has length 1.
            K_mol_ref = (ref_soap_z @ mol_soap_z).squeeze(-1) ** 2
            result[zid == i] = K_mol_ref @ self.c[i, :n_ref] + self.ref_mean[i]
        return result
    
    @classmethod
    def get_Kinv(cls, ref_soap, sigma):
        '''
        ref_soap: (N_Z, MAX_N_REF, N_SOAP)
        sigma: ()

        Returns the inverse of the squared dot-product kernel with
        sigma ** 2 added on the diagonal, one matrix per element.
        '''
        n = ref_soap.shape[1]
        K = (ref_soap @ ref_soap.swapaxes(1, 2)) ** 2
        return np.linalg.inv(K + sigma ** 2 * np.identity(n))


class SOAPCalculator:
    '''Calculates SOAP feature vectors for a given system'''
    
    def __init__(self, hypers):
        '''
        hypers: dict passed to SphericalExpansion; its "max_angular"
                entry also sets lmax of the Clebsch-Gordan helper
        '''
        self.hypers = hypers
        self.cg = ClebschGordanReal(lmax=hypers["max_angular"])
        self.spex = SphericalExpansion(**hypers)
    
    def __call__(self, z, xyz):
        '''Returns the row-normalised feature matrix for atoms `z` at `xyz`.'''
        mol = self.get_mol(z, xyz)
        spex_feats_flat = self.spex.transform(mol).get_features(self.spex)
        # Use the hypers this instance was built with, not whatever global
        # named `hypers` happens to exist in the notebook.
        spex_feats = spherical_expansion_reshape(spex_feats_flat, 
                                                 **self.hypers)
        return self.get_soap(spex_feats, self.cg)
    
    @staticmethod
    def get_mol(z, xyz):
        '''Builds a non-periodic Atoms object with the coordinates shifted
        so their minimum is at the origin and the cell set to their extent.
        '''
        xyz_min = np.min(xyz, axis=0)
        xyz_max = np.max(xyz, axis=0)
        xyz_range = xyz_max - xyz_min
        return Atoms(z, positions=xyz - xyz_min, cell=xyz_range, pbc=0)
    
    @staticmethod
    def get_soap(spex_feats, cg):
        '''Flattens the output of compute_lambda_soap(spex_feats, cg, 0, 0)
        to one row per atom and divides each row by its norm.
        '''
        X = compute_lambda_soap(spex_feats, cg, 0, 0)
        X = X.reshape((X.shape[0], -1))
        norm = np.linalg.norm(X, axis=1)
        return X / norm[:, None]
    

class MLMMCalculator:
    '''Main ML/MM energy prediction class'''
    
    def __init__(self, ref_soap, n_ref, s_sigma, chi_sigma, params, hypers):
        '''
        ref_soap, n_ref: reference features and counts given to both GPR
                         models
        s_sigma, chi_sigma: sigma of the `s` and `chi` GPR models
        params: mapping providing 'q_core', 'a_QEq', 'a_Thole', 'k_Z',
                's_ref' and 'chi_ref'
        hypers: hyperparameters of the SOAP calculator
        '''
        self.get_soap = SOAPCalculator(hypers)
        self.q_core = params['q_core']
        self.a_QEq = params['a_QEq']
        self.a_Thole = params['a_Thole']
        self.k_Z = params['k_Z']
        self.get_s = GPRCalculator(params['s_ref'], ref_soap, n_ref, s_sigma)
        self.get_chi = GPRCalculator(params['chi_ref'], ref_soap, n_ref, chi_sigma)
        
    def get_frames_E(self, frames, **kwargs):
        '''Evaluates every frame and returns a FramesE of per-frame arrays.

        Every keyword argument must be indexable by frame; element `i` is
        forwarded to `__call__` for frame `i`.
        '''
        result_raw = [self._get_frame_E_by_id(frames, i, **kwargs) 
                      for i in range(frames.n)]
        return FramesE(*map(np.array, zip(*result_raw)))
    
    def _get_frame_E_by_id(self, frames, i, **kwargs):
        '''Calls the calculator on frame `i` with the i-th keyword values.'''
        kwargs_i = {k: v[i] for k, v in kwargs.items()}
        return self(frames.xyz[i], frames.z, frames.zid, frames.pc_xyz[i],
                    frames.pc_q[i], **kwargs_i)
    
    def __call__(self, xyz, z, zid, pc_xyz, pc_q, print_times=False, 
                 s=None, chi=None, q_core=None, q=None):
        '''Returns array([E_total, E_static, E_ind]) * HARTREE_TO_KCALMOL.

        xyz, pc_xyz: coordinates, multiplied by ANGSTROM_TO_BOHR internally
        z, zid: atomic numbers and element indices of the atoms
        pc_q: point charges
        print_times: print the duration of the soap/gpr/mlmm stages
        s, chi, q_core, q: optional values used instead of the ones this
                           calculator would otherwise compute
        '''
        xyz_bohr = xyz * ANGSTROM_TO_BOHR
        pc_xyz_bohr = pc_xyz * ANGSTROM_TO_BOHR
        
        with timer('soap', print_times):
            mol_soap = self.get_soap(z, xyz)
            
        with timer('gpr', print_times):
            s = s if s is not None else self.get_s(mol_soap, zid)
            chi = chi if chi is not None else self.get_chi(mol_soap, zid)
        
        with timer('mlmm', print_times):
            q_core = q_core if q_core is not None else self.q_core[zid]
            k_Z = self.k_Z[zid]
            r_data = self._get_r_data(xyz_bohr)
            mesh_data = self._get_mesh_data(xyz_bohr, pc_xyz_bohr, s)
            q = q if q is not None else self.get_q(r_data, s, chi)
            q_val = q - q_core
            mu_ind = self._get_mu_ind(r_data, mesh_data, pc_q, s, q_val, k_Z)
            vpot_q_core = self._get_vpot_q(q_core, mesh_data['T0_mesh'])
            vpot_q_val = self._get_vpot_q(q_val, mesh_data['T0_mesh_slater'])
            vpot_static = vpot_q_core + vpot_q_val
            E_static = np.sum(vpot_static @ pc_q)
        
            vpot_ind = self._get_vpot_mu(mu_ind, mesh_data['T1_mesh'])
            E_ind = np.sum(vpot_ind @ pc_q) * 0.5
        
            E_total = E_static + E_ind
            
        return np.array([E_total, E_static, E_ind]) * HARTREE_TO_KCALMOL
    
    def get_frames_chi(self, frames, q, s=None):
        '''Returns the result of `get_chi_from_q` for every frame, stacked.

        q: per-frame charges; s: optional per-frame values or None
        '''
        return np.array([self._get_frame_chi_by_id(frames, q, s, i)
                         for i in range(frames.n)])
    
    def _get_frame_chi_by_id(self, frames, q, s, i):
        '''Evaluates `get_chi_from_q` for frame `i`.'''
        # Explicit None test: `s and s[i]` raises for numpy arrays because
        # their truth value is ambiguous.
        s_i = None if s is None else s[i]
        return self.get_chi_from_q(frames.xyz[i], frames.z, frames.zid, 
                                   q[i], s_i)
    
    def get_chi_from_q(self, xyz, z, zid, q, s=None):
        '''Returns -A @ q, where A is the QEq matrix without its last
        row and column. `s` is predicted with GPR when not given.
        '''
        mol_soap = self.get_soap(z, xyz)
        xyz_bohr = xyz * ANGSTROM_TO_BOHR
        r_data = self._get_r_data(xyz_bohr)
        s = s if s is not None else self.get_s(mol_soap, zid)
        A = self._get_A_QEq(r_data, s)
        return -A[:-1,:-1] @ q
        
    def get_q(self, r_data, s, chi):
        '''Solves A x = [-chi, 0] with the QEq matrix and returns x
        without its last entry.
        '''
        A = self._get_A_QEq(r_data, s)
        b = np.hstack([-chi,  0])
        return np.linalg.solve(A, b)[:-1]          
    
    def _get_A_QEq(self, r_data, s):
        '''Builds the QEq matrix: the Gaussian-damped T0 block bordered by
        a row and a column of ones with a zero in the corner.
        '''
        s_gauss = s * self.a_QEq
        s2 = s_gauss ** 2
        s_mat = np.sqrt(s2[:, None] + s2[None, :])

        A = self._get_T0_gaussian(r_data['T01'], r_data['r_mat'], s_mat)
        A[np.diag_indices_from(A)] = 1. / (s_gauss * np.sqrt(np.pi))

        ones = np.ones((len(A), 1))
        return np.block([[A, ones], [ones.T, 0.]])
    
    def _get_mu_ind(self, r_data, mesh_data, q, s, q_val, k_Z):
        '''Solves the Thole system for the fields obtained from `q` and
        'T1_mesh'; returns the solution reshaped to rows of three.
        '''
        A = self._get_A_thole(r_data, s, q_val, k_Z)
        fields = np.sum(mesh_data['T1_mesh'] * q[:, None], 
                        axis=1).flatten()
        mu_ind = np.linalg.solve(A, fields)
        return mu_ind.reshape((-1, 3))
    
    def _get_A_thole(self, r_data, s, q_val, k_Z):
        '''Builds the matrix solved in `_get_mu_ind`: minus the damped T2
        tensor with 1 / alpha on the diagonal.
        '''
        N = - q_val
        v = 60 * N * s ** 3
        alpha = np.array(v * k_Z)

        alphap = alpha * self.a_Thole
        alphap_mat = alphap[:, None] * alphap[None, :]

        au3 = r_data['r_mat'] ** 3 / np.sqrt(alphap_mat)
        # Expand the per-pair values to the 3x3 blocks of the tensors.
        au31 = au3.repeat(3, axis=1)
        au32 = au31.repeat(3, axis=0)
        A = - self._get_T2_thole(r_data['T21'], r_data['T22'], au32)
        A[np.diag_indices_from(A)] = 1. / alpha.repeat(3)
        return A
    
    @staticmethod
    def _get_vpot_q(q, T0):
        '''Returns sum over the first axis of T0 * q[:, None].'''
        return np.sum(T0 * q[:, None], axis=0)

    @staticmethod
    def _get_vpot_mu(mu, T1):
        '''Returns minus T1 contracted with mu over axes (0, 2) / (0, 1).'''
        return - np.tensordot(T1, mu, ((0, 2), (0, 1)))
        
    @classmethod
    def _get_r_data(cls, xyz):
        '''Returns the pairwise distance matrix 'r_mat' and the tensors
        'T01', 'T11', 'T21', 'T22' built from an (n_atoms, 3) array.
        '''
        n_atoms = len(xyz)

        rr_mat = xyz[:, None, :] - xyz[None, :, :]
        r_mat = np.linalg.norm(rr_mat, axis=2)
        
        # The diagonal of r_mat is zero; silence the division warning and
        # overwrite the resulting infinities right away.
        with np.errstate(divide='ignore'):
            r_inv = 1. / r_mat
        r_inv[np.diag_indices_from(r_mat)] = 0.
        
        r_inv1 = r_inv.repeat(3, axis=1)
        r_inv2 = r_inv1.repeat(3, axis=0)
        outer = cls.get_outer(rr_mat)
        id2 = np.tile(np.tile(np.eye(3).T, n_atoms).T, n_atoms)
        
        t01 = r_inv
        t11 = -rr_mat.reshape(n_atoms, n_atoms * 3) * r_inv1 ** 3
        t21 = -id2 * r_inv2 ** 3
        t22 = 3 * outer  * r_inv2 ** 5

        return {'r_mat': r_mat, 'T01': t01, 'T11': t11, 'T21': t21, 'T22': t22}
    
    @staticmethod
    def get_outer(a):
        '''Returns the (3n, 3n) matrix of 3x3 outer products a[i, j] a[i, j]^T
        for an (n, n, 3) array; the diagonal blocks are zero.
        '''
        n = len(a)
        idx = np.triu_indices(n, 1)

        result = np.zeros((n, n, 3, 3))
        result[idx] = a[idx][:, :, None] @ a[idx][:, None, :]
        # Mirror the upper-triangle blocks into the lower triangle.
        result.swapaxes(0,1)[idx] = result[idx]

        return result.swapaxes(1, 2).reshape((n * 3, n * 3))
    
    @classmethod
    def _get_mesh_data(cls, xyz, xyz_mesh, s):
        '''Returns 'T0_mesh', 'T0_mesh_slater' and 'T1_mesh' computed from
        the vectors pointing from every `xyz` row to every `xyz_mesh` row.
        '''
        rr = xyz_mesh[None, :, :] - xyz[:, None, :]
        r = np.linalg.norm(rr, axis=2)
            
        return {'T0_mesh': 1. / r,
                'T0_mesh_slater': cls.get_T0_slater(r, s[:, None]),
                'T1_mesh': - rr / r[:, :, None] ** 3}
    
    @staticmethod
    def get_T0_slater(r, s):
        '''Returns (1 - (1 + r / (2 s)) * exp(-r / s)) / r.'''
        return (1 - (1 + r / (s * 2)) * np.exp(-r / s)) / r
    
    @staticmethod
    def _get_T0_gaussian(t01, r, s_mat):
        '''Returns t01 * erf(r / (sqrt(2) * s_mat)).'''
        return t01 * scipy.special.erf(r / (s_mat * np.sqrt(2)))

    @staticmethod
    def _get_T1_gaussian(t11, r, s_mat):
        '''Returns t11 * (erf(x) - 2 x / sqrt(pi) * exp(-x ** 2)) with
        x = r / (sqrt(2) * s_mat), repeated three times along axis 1.
        '''
        s_invsq2 = 1. / (s_mat * np.sqrt(2))
        x = r * s_invsq2
        # The Gaussian factor is exp(-x ** 2), not exp(-x) ** 2.
        return t11 * (
            scipy.special.erf(x) - 
            x * 2 / np.sqrt(np.pi) * np.exp(-x ** 2)
        ).repeat(3, axis=1)

    @classmethod
    def _get_T2_thole(cls, tr21, tr22, au3):
        '''Returns lambda3(au3) * tr21 + lambda5(au3) * tr22.'''
        return cls._lambda3(au3) * tr21 + cls._lambda5(au3) * tr22
    
    @staticmethod
    def _lambda3(au3):
        '''Returns 1 - exp(-au3).'''
        return 1 - np.exp(-au3)

    @staticmethod
    def _lambda5(au3):
        '''Returns 1 - (1 + au3) * exp(-au3).'''
        return 1 - (1 + au3) * np.exp(-au3)

Read the positions of the QM atoms and point charges

In [6]:
# Archive source: https://zenodo.org/record/7048725/files/mpro_xyz.tgz
frames = FramesParser(filename='mpro_xyz.tgz', n=N_FRAMES)

Parse the ORCA outputs. For DFT calculations provide the charges of the MM atoms to calculate the static/induced energy decomposition. For B3LYP/cc-pVTZ also read the MBIS partitioning results.

In [7]:
# Archives: https://doi.org/10.5281/zenodo.7048725
# One parser per level of theory. Entries that pass `q` also get the
# static/induced decomposition; only B3LYP/cc-pVTZ reads the MBIS data.
orca_dict = {
    label: ORCAParser(archive, N_FRAMES, **options)
    for label, archive, options in (
        ('b3lyp_cc-PVTZ', 'mpro_B3LYP_VTZ.tgz',
         {'q': frames.pc_q, 'mbis': True}),
        ('b3lyp_cc-PVDZ', 'mpro_B3LYP_VDZ.tgz', {'q': frames.pc_q}),
        ('pbe0_cc-PVTZ', 'mpro_PBE0_VTZ.tgz', {'q': frames.pc_q}),
        ('blyp_631s', 'mpro_BLYP_631Gs.tgz', {'q': frames.pc_q}),
        ('xtb', 'mpro_XTB.tgz', {}),
        ('am1', 'mpro_AM1.tgz', {}),
        ('pm3', 'mpro_PM3.tgz', {}),
    )
}

Read the model parameters and initialize the ML/MM energy calculator.

In [8]:
# Parameter file: https://zenodo.org/record/7051785/files/mlmm.mat
params = scipy.io.loadmat('mlmm.mat', squeeze_me=True)

# Settings specific to this model, followed by the shared ones.
hypers = dict(
    interaction_cutoff=3.,
    max_radial=4,
    max_angular=4,
    compute_gradients=True,
    **SPHERICAL_EXPANSION_HYPERS_COMMON,
)

mlmm_calculator = MLMMCalculator(
    ref_soap=params['ref_soap'],
    n_ref=params['n_ref'],
    s_sigma=SIGMA,
    chi_sigma=SIGMA,
    params=params,
    hypers=hypers,
)

Calculate interaction energies using ML/MM and exact MBIS properties

In [9]:
mbis = orca_dict['b3lyp_cc-PVTZ'].mbis
s, q_val, q_core = (
    mbis[key]
    for key in ('valence_widths', 'valence_charges', 'core_charges')
)

# Total charges, shifted so that every frame averages to zero.
q = q_val + q_core
q = q - q.mean(axis=1, keepdims=True)
# Frame-averaged charges, repeated once per frame.
q_av = np.repeat(q.mean(axis=0)[None, :], frames.n, axis=0)

chi = mlmm_calculator.get_frames_chi(frames, q)
chi_av = np.repeat(chi.mean(axis=0)[None, :], frames.n, axis=0)

# Each entry replaces a different subset of the predicted properties.
E_mlmm = {
    label: mlmm_calculator.get_frames_E(frames, **overrides)
    for label, overrides in (
        ('mlmm', {}),
        ('mbis', {'s': s, 'q_core': q_core, 'q': q}),
        ('q_mbis', {'q': q}),
        ('q_mbis_av', {'q': q_av}),
        ('chi_mbis_av', {'chi': chi_av}),
    )
}

  r_inv = 1. / r_mat


Finally, calculate RMSEs with respect to B3LYP/cc-pVTZ energies

In [10]:
def rmse(a, b):
    """Return the norm of `a - b` divided by sqrt(len(a))."""
    diff = a - b
    scale = np.sqrt(len(a))
    return np.linalg.norm(diff) / scale

def rmse_norm(a, b):
    """Return the norm of `a - b` with its mean removed, divided by
    sqrt(len(a)).
    """
    offset = np.mean(a - b)
    centered = a - b - offset
    return np.linalg.norm(centered) / np.sqrt(len(a))

def get_rmses(frames_E, frames_E_ref):
    """Compare the energies of `frames_E` with those of `frames_E_ref`.

    Returns a dict with the keys 'E', 'E_norm', 'E_mean', 'E_static',
    'E_static_norm', 'E_induced' and 'E_induced_norm'. The last four are
    None when `frames_E` has no `E_static` attribute.
    """
    result = {
        'E': rmse(frames_E.E, frames_E_ref.E),
        'E_norm': rmse_norm(frames_E.E, frames_E_ref.E),
        'E_mean': np.mean(frames_E.E - frames_E_ref.E),
    }
    decomposed = hasattr(frames_E, 'E_static')
    components = (('E_static', 'E_static_norm'),
                  ('E_induced', 'E_induced_norm'))
    for attr, norm_key in components:
        if not decomposed:
            # No decomposition available: keep the columns, leave them empty.
            result[attr] = None
            result[norm_key] = None
            continue
        predicted = getattr(frames_E, attr)
        reference = getattr(frames_E_ref, attr)
        result[attr] = rmse(predicted, reference)
        result[norm_key] = rmse_norm(predicted, reference)
    return result

def get_rmse_data(frames_E_dict, frames_E_ref):
    """Return one row per entry of `frames_E_dict`: its key under 'name'
    followed by the result of `get_rmses` against `frames_E_ref`.
    """
    rows = []
    for label, frames_E in frames_E_dict.items():
        row = {'name': label}
        row.update(get_rmses(frames_E, frames_E_ref))
        rows.append(row)
    return rows

# Both tables use the B3LYP/cc-pVTZ results as the reference.
data_mlmm = get_rmse_data(frames_E_dict=E_mlmm,
                          frames_E_ref=orca_dict['b3lyp_cc-PVTZ'])
data_orca = get_rmse_data(frames_E_dict=orca_dict,
                          frames_E_ref=orca_dict['b3lyp_cc-PVTZ'])

In [11]:
# Render the ORCA rows followed by the ML/MM rows as one HTML table.
# `HTML` and `display` come from the notebook's import cell.
table = tabulate.tabulate(data_orca + data_mlmm, 
                          headers='keys', floatfmt=".3f", tablefmt='html')
display(HTML(table))

name,E,E_norm,E_mean,E_static,E_static_norm,E_induced,E_induced_norm
b3lyp_cc-PVTZ,0.0,0.0,0.0,0.0,0.0,0.0,0.0
b3lyp_cc-PVDZ,4.312,0.639,4.265,2.263,0.504,2.081,0.307
pbe0_cc-PVTZ,0.522,0.214,-0.476,0.635,0.217,0.123,0.022
blyp_631s,4.884,0.655,4.84,2.728,0.539,2.189,0.326
xtb,34.583,4.342,34.309,,,,
am1,68.266,10.33,67.48,,,,
pm3,77.713,10.663,76.978,,,,
mlmm,3.259,2.046,-2.537,3.177,1.941,0.567,0.566
mbis,5.351,1.489,-5.14,5.333,1.398,0.548,0.548
q_mbis,5.377,1.483,-5.168,5.376,1.394,0.56,0.56
