In [1]:
from iotbx import cif
from pathlib import Path
from contextlib import closing, redirect_stderr, redirect_stdout
from multiprocessing import Pool
import tqdm
from io import StringIO
import numpy as np
import networkx as nx
from rdkit import Chem
from rdkit.Chem.AllChem import *

In [2]:
# Root directory that is searched (recursively, in the next cell) for restraint CIF files.
# NOTE(review): hard-coded absolute, machine-specific path — adjust before running elsewhere.
raw_dir = Path("/dev/shm/cschlick/GEO_processing/geostd/")

In [3]:
# Every file below raw_dir (any depth) with a ".cif" extension and "data" in its file name.
cif_files = [path for path in raw_dir.rglob("*") if path.suffix == ".cif" and "data" in path.name]

In [5]:
import MDAnalysis as mda
import copy
import tempfile

# Lookup table from conventional element spellings to the exact keys used by
# MDAnalysis' mass table.  Two-letter keys are registered under both their
# all-caps ("CL") and capitalised ("Cl") spelling; every other key maps to itself.
pt  = Chem.GetPeriodicTable()
total_elements = {pt.GetElementSymbol(number) for number in range(1, 119)}
element_symbols = {}
for symbol in mda.topology.tables.masses:
  if len(symbol)==2:
    first,second = symbol
    for spelling in (first.upper()+second.upper(), first.upper()+second.lower()):
      element_symbols[spelling] = symbol
  else:
    element_symbols[symbol] = symbol
    
def guess_bond_order_mda(mol):
  """
  Use MDAnalysis to "guess" bond orders if ambiguous
  Requires explicit hydrogens to work
  See:
    1) https://cbouy.github.io/blog/2020/07/01/rdkit-converter
    2) https://cbouy.github.io/blog/2020/07/22/rdkit-converter-part2

  The molecule is written to a temporary PDB file with generic bonds, read
  back as an MDAnalysis Universe and converted to RDKit again, letting the
  MDAnalysis converter infer the bond orders.

  Parameters:
    mol: RDKit molecule (must provide GetBonds() / GetAtoms()).

  Returns:
    (newmol, success) tuple.  A molecule without bonds is returned unchanged
    with success=True.  On any failure (unknown element symbol, exception in
    RDKit/MDAnalysis, converter returning None) the result is (None, False).
  """

  # Nothing to infer for a molecule without bonds: hand it back untouched.
  if len(mol.GetBonds())==0:
    return mol,True

  try:
    m = copy.deepcopy(mol)
    params = Chem.rdmolops.AdjustQueryParameters()
    params.makeBondsGeneric = True
    modmol = Chem.rdmolops.AdjustQueryProperties(m, params)
    modmol.UpdatePropertyCache()
    pdb_string = Chem.MolToPDBBlock(modmol)

    # The context manager guarantees the temporary file is removed even if
    # MDAnalysis fails to parse it (previously it leaked on that path).
    with tempfile.NamedTemporaryFile('w+t') as tmp:
      tmp.write(pdb_string)
      tmp.flush()  # make sure the text is on disk before MDAnalysis opens the file by name
      u = mda.Universe(tmp.name,format="PDB")

    elements = mda.topology.guessers.guess_types(u.atoms.names)

    # convert symbols to those mda anticipates
    if not set(elements).issubset(element_symbols):
      return None,False
    elements = np.array([element_symbols[symbol] for symbol in elements],dtype=object)
    u.add_TopologyAttr('elements', elements)

    # The converter is only forced when the input carries no hydrogens at all.
    force = all(a.GetSymbol()!="H" for a in mol.GetAtoms())
    newmol = u.atoms.convert_to("RDKIT",force=force)
  except Exception:
    # Best effort: any library error is reported as a failed guess.  Only
    # Exception is caught so KeyboardInterrupt/SystemExit still propagate.
    return None,False

  return newmol,(newmol is not None)


class CIFFile:
  """
  One restraint CIF file converted into an RDKit molecule plus bond/angle tables.

  Processing runs in stages; the first stage that fails sets ``failed`` to True
  and ``fail_message`` to the stage label, and the remaining stages are skipped.
  Stage labels, in order: "Parsing file", "Build rdkit mol", "Regularize bonds",
  "Sanitize", "Creating molecule object".

  Attributes set by __init__:
    filepath               Path of the input file.
    debug                  When True, an error in the final stage is re-raised.
    needs_bond_assignment  True if the file contains at least one "deloc" bond.
    cif_model              Parsed CIF model (None if parsing failed).
    atom_id, rdmol         Atom names / RDKit molecule (only set once the build stage succeeded).
    molecule               Molecule wrapper (only set once the final stage started successfully).
    failed, fail_message   Outcome of the processing.
  """


  @staticmethod
  def _angle_key(triple):
    """
    Dictionary key for an angle given as (end, centre, end) atom indices.

    The key keeps the central atom separate from the unordered pair of end
    atoms, so reversing an angle yields the same key while the three distinct
    angles of a three-membered ring do not collide.
    """
    first,centre,last = (int(index) for index in triple)
    return (centre, frozenset((first,last)))


  @staticmethod
  def assign_bond_orders_from_reference(refmol,mol,sanitize=True):
    """
    Copy bond orders and atom properties from ``refmol`` onto a copy of ``mol``.

    NOTE(review): this looks like an adaptation of RDKit's
    AllChem.AssignBondOrdersFromTemplate; ``rdchem``, ``BondType``,
    ``SanitizeMol`` and ``logger`` are presumably provided by the notebook's
    ``from rdkit.Chem.AllChem import *`` — confirm.

    Parameters:
      refmol:   reference molecule carrying the desired bond orders.
      mol:      molecule that receives them (not modified; a copy is returned).
      sanitize: run SanitizeMol on the result.

    Returns:
      The modified copy of ``mol``.

    Raises:
      ValueError: if ``refmol`` cannot be matched onto ``mol`` even after all
        bonds are set to SINGLE and all charges/radicals are cleared.
    """
    refmol2 = rdchem.Mol(refmol)
    mol2 = rdchem.Mol(mol)
    # do the molecules match already?
    # GetSubstructMatch returns ONE match as a flat tuple of atom indices.
    matching = mol2.GetSubstructMatch(refmol2)
    if not matching:  # no, they don't match
      # check if bonds of mol are SINGLE
      for b in mol2.GetBonds():
        if b.GetBondType() != BondType.SINGLE:
          b.SetBondType(BondType.SINGLE)
          b.SetIsAromatic(False)
      # set the bonds of mol to SINGLE
      for b in refmol2.GetBonds():
        b.SetBondType(BondType.SINGLE)
        b.SetIsAromatic(False)
      # set atom charges to zero;
      for a in refmol2.GetAtoms():
        a.SetFormalCharge(0)
        a.SetNumRadicalElectrons(0)
      for a in mol2.GetAtoms():
        a.SetFormalCharge(0)
        a.SetNumRadicalElectrons(0)

      # GetSubstructMatches returns a tuple OF matches; pick the first one so
      # that `matching` is a flat tuple of atom indices on both code paths.
      # (Previously the flat tuple of the "already matching" case was indexed
      # as if it were a tuple of matches, which always raised a TypeError.)
      candidates = mol2.GetSubstructMatches(refmol2, uniquify=False)
      # do the molecules match now?
      if not candidates:
        raise ValueError("No matching found")
      if len(candidates) > 1:
        logger.warning("More than one matching pattern found - picking one")
      matching = candidates[0]

    # apply matching: set bond properties
    for b in refmol.GetBonds():
      atom1 = matching[b.GetBeginAtomIdx()]
      atom2 = matching[b.GetEndAtomIdx()]
      b2 = mol2.GetBondBetweenAtoms(atom1, atom2)
      b2.SetBondType(b.GetBondType())
      b2.SetIsAromatic(b.GetIsAromatic())
    # apply matching: set atom properties
    for a in refmol.GetAtoms():
      a2 = mol2.GetAtomWithIdx(matching[a.GetIdx()])
      a2.SetHybridization(a.GetHybridization())
      a2.SetIsAromatic(a.GetIsAromatic())
      a2.SetNumExplicitHs(a.GetNumExplicitHs())
      a2.SetFormalCharge(a.GetFormalCharge())
      a2.SetNumRadicalElectrons(a.GetNumRadicalElectrons())
    if sanitize:
      SanitizeMol(mol2)
    if hasattr(mol2, '__sssAtoms'):
      # we don't want all bonds highlighted.  setattr with a string is required:
      # inside a class body `mol2.__sssAtoms` would be name-mangled to
      # `_CIFFile__sssAtoms` and never touch the real attribute.
      setattr(mol2, '__sssAtoms', None)
    return mol2


  def __init__(self,filename,debug=False):
    """
    Parse ``filename`` and run all processing stages.

    Parameters:
      filename: path of the CIF file (anything accepted by pathlib.Path).
      debug:    if True, an exception raised while creating the Molecule
                object / filling its tables is re-raised after the failure
                has been recorded.  Earlier stages never raise.
    """
    self.filepath = Path(filename)
    self.needs_bond_assignment = False
    self.debug = debug
    failed = False
    fail_message = None

    # Every stage catches Exception (not a bare except) so that
    # KeyboardInterrupt/SystemExit can still stop a long batch run.

    ########### Parse file
    try:
      self.cif_model = cif.reader(self.filepath.as_posix()).model()
    except Exception:
      self.cif_model = None
      failed = True
      fail_message = "Parsing file"

    ###### Build rdkit
    if not failed:
      try:
        comp_id = "".join(self.cif_model["comp_list"]["_chem_comp.id"])
        comp = self.cif_model["comp_"+comp_id]
        atom_id = list(comp["_chem_comp_atom.atom_id"])
        elements = list(comp["_chem_comp_atom.type_symbol"])
        # normalise capitalisation: "CL" -> "Cl", "c" -> "C"
        elements = [e[0].upper()+e[1].lower() if len(e)==2 else e.upper() for e in elements]

        x = list(comp["_chem_comp_atom.x"])
        y = list(comp["_chem_comp_atom.y"])
        z = list(comp["_chem_comp_atom.z"])
        xyz = np.vstack([x,y,z]).T.astype(float)

        # bonds are stored as pairs of positions in atom_id
        bonds_idx1 =  [atom_id.index(e) for e in comp["_chem_comp_bond.atom_id_1"]]
        bonds_idx2 =  [atom_id.index(e) for e in comp["_chem_comp_bond.atom_id_2"]]
        bond_pairs = np.vstack([bonds_idx1,bonds_idx2]).T.astype(int)

        # "deloc" bonds are built as SINGLE here; the file is merely flagged.
        bond_conversion = {"deloc":Chem.BondType.SINGLE,
                           "single":Chem.BondType.SINGLE,
                           "double":Chem.BondType.DOUBLE,
                           "triple":Chem.BondType.TRIPLE,
                           "aromatic":Chem.BondType.AROMATIC}

        bond_type = list(comp["_chem_comp_bond.type"])
        if "deloc" in bond_type:
          self.needs_bond_assignment = True

        # an unknown bond type raises KeyError and fails this stage
        bond_type = [bond_conversion[e] for e in bond_type]

        bond_eq = np.array(comp["_chem_comp_bond.value_dist"],dtype=float)
        bond_esd = np.array(comp["_chem_comp_bond.value_dist_esd"],dtype=float)
        angle_idx1 =  [atom_id.index(e) for e in comp["_chem_comp_angle.atom_id_1"]]
        angle_idx2 =  [atom_id.index(e) for e in comp["_chem_comp_angle.atom_id_2"]]
        angle_idx3 =  [atom_id.index(e) for e in comp["_chem_comp_angle.atom_id_3"]]
        angle_triples = np.vstack([angle_idx1,angle_idx2,angle_idx3]).T.astype(int)
        angle_eq = np.array(comp["_chem_comp_angle.value_angle"],dtype=float)
        angle_esd = np.array(comp["_chem_comp_angle.value_angle_esd"],dtype=float)

        # make rdmol
        pt  = Chem.GetPeriodicTable()
        rwmol = Chem.RWMol(Chem.Mol())
        conformer = Chem.Conformer(len(elements))

        for i,e in enumerate(elements):
          atomic_num = pt.GetAtomicNumber(e)
          atom = rwmol.AddAtom(Chem.Atom(atomic_num))  # AddAtom's return value is used as the atom index
          conformer.SetAtomPosition(atom,xyz[i])

        for i,bond_pair in enumerate(bond_pairs):
          idx1,idx2 = bond_pair
          rwmol.AddBond(int(idx1),int(idx2),bond_type[i])

        rwmol.AddConformer(conformer)
        self.atom_id = np.array(atom_id)
        self.rdmol = rwmol.GetMol()
      except Exception:
        failed = True
        fail_message = "Build rdkit mol"

    ####### regularize bonds
    if not failed:
      try:
        newmol,success = guess_bond_order_mda(self.rdmol)
        if success:
          self.rdmol = newmol
        else:
          failed = True
          fail_message = "Regularize bonds"
      except Exception:
        failed = True
        fail_message = "Regularize bonds"

    ####### sanitize
    if not failed:
      try:
        # Explicit check instead of `assert`, which would be stripped
        # (together with the SanitizeMol call itself) under `python -O`.
        if Chem.SanitizeMol(self.rdmol) != Chem.rdmolops.SanitizeFlags.SANITIZE_NONE:
          raise ValueError("SanitizeMol reported a failed operation")
      except Exception:
        failed = True
        fail_message = "Sanitize"

    ####### molecule object + restraint tables
    if not failed:
      try:
        self.molecule = Molecule(self.rdmol,atom_names=self.atom_id,filepath=self.filepath)
        m = self.molecule

        # read each property once and force 2-D shapes so the column slicing
        # below is well defined
        bond_indices = np.asarray(m.bond_indices,dtype=int).reshape(-1,2)
        angle_indices = np.asarray(m.angle_indices,dtype=int).reshape(-1,3)

        m.bond_file_data["atom1_idx"] = bond_indices[:,0]
        m.bond_file_data["atom2_idx"] = bond_indices[:,1]
        m.bond_file_data["atom1_id"] =  m.atom_names[bond_indices[:,0]]
        m.bond_file_data["atom2_id"] =  m.atom_names[bond_indices[:,1]]

        # file row of each bond, looked up by its unordered atom pair;
        # a bond missing from the file raises KeyError and fails this stage
        bond_row = {frozenset(int(i) for i in pair):row for row,pair in enumerate(bond_pairs)}
        bond_rows = [bond_row[frozenset(int(i) for i in pair)] for pair in bond_indices]
        m.bond_file_data["eq"] = [bond_eq[row] for row in bond_rows]
        m.bond_file_data["esd"] = [bond_esd[row] for row in bond_rows]

        m.angle_file_data["atom1_idx"] = angle_indices[:,0]
        m.angle_file_data["atom2_idx"] = angle_indices[:,1]
        m.angle_file_data["atom3_idx"] = angle_indices[:,2]
        m.angle_file_data["atom1_id"] =  m.atom_names[angle_indices[:,0]]
        m.angle_file_data["atom2_id"] =  m.atom_names[angle_indices[:,1]]
        m.angle_file_data["atom3_id"] =  m.atom_names[angle_indices[:,2]]

        # Angles are keyed by (centre, {end, end}).  A plain frozenset of the
        # three atoms cannot tell apart the angles of a three-membered ring.
        angle_row = {self._angle_key(triple):row for row,triple in enumerate(angle_triples)}
        angle_rows = [angle_row[self._angle_key(triple)] for triple in angle_indices]
        m.angle_file_data["eq"] = [angle_eq[row] for row in angle_rows]
        m.angle_file_data["esd"] = [angle_esd[row] for row in angle_rows]

      except Exception:
        failed = True
        fail_message = "Creating molecule object"
        if self.debug:
          raise

    self.failed = failed
    self.fail_message = fail_message

In [6]:
from itertools import combinations
import networkx as nx
import itertools
import pandas as pd

class Molecule:
  """
  Thin wrapper around an RDKit molecule with graph, index-table and drawing helpers.

  Attributes:
    rdmol            the wrapped RDKit molecule.
    atom_names       optional array of atom names, indexed by atom index.
    filepath         optional pathlib.Path of the originating file (used as plot title).
    homograph        networkx.Graph with one node per atom and one edge per bond.
    bond_file_data   dict of columns describing bonds (filled by the caller).
    angle_file_data  dict of columns describing angles (filled by the caller).
  """

  @staticmethod
  def rdmol_to_nx_homo(mol):
    """
    Build an undirected networkx graph from an RDKit molecule.

    Nodes are atom indices with ``symbol`` and ``atomic_num`` attributes;
    edges carry the RDKit ``bond_type``.
    """
    G = nx.Graph()

    for atom in mol.GetAtoms():
        G.add_node(atom.GetIdx(),
                   symbol = atom.GetSymbol(),
                   atomic_num=atom.GetAtomicNum())

    for bond in mol.GetBonds():
        G.add_edge(bond.GetBeginAtomIdx(),
                   bond.GetEndAtomIdx(),
                   bond_type=bond.GetBondType())
    return G

  def __init__(self,rdmol,atom_names=None,filepath=None):
    self.rdmol = rdmol
    self.atom_names = atom_names
    self.filepath = filepath
    self.homograph = self.rdmol_to_nx_homo(self.rdmol)
    self.bond_file_data = {}
    self.angle_file_data = {}

  def _ipython_display_(self):
    """Rich-display hook: render the molecule when it is the last value of a cell."""
    self.show()

  @property
  def rdmol_2d(self):
    """Copy of ``rdmol`` with 2-D depiction coordinates, computed once and cached."""
    if not hasattr(self,"_rdmol_2d"):
      self._rdmol_2d = Chem.Mol(self.rdmol)
      Chem.rdDepictor.Compute2DCoords(self._rdmol_2d)
    return self._rdmol_2d

  @property
  def bond_indices(self):
      """
      Integer array of shape (n_bonds, 2) with (begin, end) atom indices, in
      RDKit bond order.  The shape is (0, 2) when the molecule has no bonds.
      """
      bonds = [(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
               for bond in self.rdmol.GetBonds()]
      # reshape keeps the array two-dimensional for an empty molecule so that
      # callers can always slice columns with [:, k]
      return np.array(bonds,dtype=int).reshape(-1,2)

  @property
  def angle_indices(self):
      """
      Integer array of shape (n_angles, 3) with (end, centre, end) atom indices.

      Each angle appears once, with the larger end index first.  Rows are
      ordered by centre index ascending, then first end descending, then last
      end descending.  The shape is (0, 3) when there are no angles.
      """
      angles = set()
      for atom1 in self.rdmol.GetAtoms():
          for atom2 in atom1.GetNeighbors():
              for atom3 in atom2.GetNeighbors():
                  a,b,c = atom1.GetIdx(), atom2.GetIdx(), atom3.GetIdx()
                  if len(set([a,b,c]))<3:
                      continue
                  # canonical orientation: larger end index first
                  if a>c:
                      angles.add((a,b,c))
                  else:
                      angles.add((c,b,a))
      # one composite key gives the same order as the former three stable sorts
      angles = sorted(angles,key=lambda t: (t[1],-t[0],-t[2]))
      return np.array(angles,dtype=int).reshape(-1,3)

  @property
  def bond_df(self):
    """``bond_file_data`` as a pandas DataFrame."""
    return pd.DataFrame(self.bond_file_data)

  @property
  def angle_df(self):
    """``angle_file_data`` as a pandas DataFrame."""
    return pd.DataFrame(self.angle_file_data)

  def show(self):
    """
    Draw the 2-D depiction of the molecule on a matplotlib figure.

    The file name (if a filepath is known) is used as the figure title.
    """
    # Imported here because no notebook cell imports these names; previously
    # the method only worked if they had leaked into the kernel by other means.
    import io
    import matplotlib.pyplot as plt
    from rdkit.Chem.Draw import rdMolDraw2D

    molSize = (600,600)
    figsize = 10
    fontsize = 14
    title=None
    if self.filepath is not None:  # was `m.filepath`: a global, not this instance
      title = self.filepath.name

    drawer = rdMolDraw2D.MolDraw2DCairo(molSize[0],molSize[1])
    # PrepareAndDrawMolecule already draws the molecule; the former extra
    # drawer.DrawMolecule call drew it a second time on top of itself.
    rdMolDraw2D.PrepareAndDrawMolecule(drawer,self.rdmol_2d,highlightAtoms=[],highlightBonds=[])
    drawer.FinishDrawing()
    # read as bitmap
    file = io.BytesIO(drawer.GetDrawingText())
    mol_bitmap = plt.imread(file)

    fig, ax = plt.subplots(1,1,figsize=(figsize,figsize))
    ax.imshow(mol_bitmap)
    if title is not None:
      ax.set_title(title,fontsize=fontsize)
    # hide the frame and ticks so only the drawing remains
    for side in ("left","top","right","bottom"):
      ax.axes.spines[side].set_visible(False)
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)


In [7]:
work = cif_files
def worker(data):
  """
  Pool task: build a CIFFile for one path while Python-level stderr is diverted
  to a throw-away buffer.

  NOTE(review): messages such as "non-ring atom ... marked aromatic" still show
  up in the cell output, so this redirection apparently does not capture them.
  """
  discarded = StringIO()
  with redirect_stderr(discarded):
    return CIFFile(data)

In [None]:
# Process every file in parallel; results arrive in completion order.
# Pool is used as its own context manager: on exit it calls terminate(), also
# when the loop is left through an exception or a keyboard interrupt.  With
# closing(...) only close() ran on that path, leaving the workers alive.
results = []
with Pool(processes=16) as pool:
  for result in tqdm.tqdm(pool.imap_unordered(worker, work), total=len(work)):
    results.append(result)

  0%|                                                                                                                                                  | 0/21103 [00:00<?, ?it/s][21:42:26] non-ring atom 13 marked aromatic
  0%|                                                                                                                                        | 1/21103 [00:00<4:23:39,  1.33it/s][21:42:26] non-ring atom 18 marked aromatic
[21:42:26] non-ring atom 5 marked aromatic
  0%|                                                                                                                                        | 4/21103 [00:01<1:18:53,  4.46it/s][21:42:28] non-ring atom 20 marked aromatic
  0%|                                                                                                                                       | 6/21103 [00:09<13:28:36,  2.30s/it][21:42:35] non-ring atom 33 marked aromatic
  0%|▏                                                                   

Traceback (most recent call last):
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/multiprocessing/pool.py", line 853, in next
    item = self._items.popleft()
IndexError: pop from an empty deque

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3444, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_118639/1956106820.py", line 3, in <module>
    for result in tqdm.tqdm(pool.imap_unordered(worker, work), total=len(work)):
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/site-packages/tqdm/std.py", line 1180, in __iter__
    for obj in iterable:
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/multiprocessing/pool.py", line 858, in next
    self._cond.wait(timeout)
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/multiprocessing/pool.py", line 853, in next
    item = self._items.popleft()
IndexError: pop from an empty deque

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3444, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_118639/1956106820.py", line 3, in <module>
    for result in tqdm.tqdm(pool.imap_unordered(worker, work), total=len(work)):
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/site-packages/tqdm/std.py", line 1180, in __iter__
    for obj in iterable:
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/multiprocessing/pool.py", line 858, in next
    self._cond.wait(timeout)
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/

Process ForkPoolWorker-10292:
Traceback (most recent call last):
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/multiprocessing/queues.py", line 368, in get
    return _ForkingPickler.loads(res)


Traceback (most recent call last):
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/multiprocessing/pool.py", line 853, in next
    item = self._items.popleft()
IndexError: pop from an empty deque

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3444, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_118639/1956106820.py", line 3, in <module>
    for result in tqdm.tqdm(pool.imap_unordered(worker, work), total=len(work)):
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/site-packages/tqdm/std.py", line 1180, in __iter__
    for obj in iterable:
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/multiprocessing/pool.py", line 858, in next
    self._cond.wait(timeout)
  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/

  File "/net/cci/cschlick/miniconda3/envs/espaloma/lib/python3.9/pathlib.py", line 1082, in __new__
    self = cls._from_parts(args, init=False)
KeyboardInterrupt


In [461]:
from collections import Counter
# Tally the outcome of every processed file by its fail_message
# (None is the value left on files that went through all stages).
failure_messages = Counter(entry.fail_message for entry in results)
print(failure_messages)

Counter({None: 18156, 'Creating molecule object': 2782, 'Regularize bonds': 116, 'Build rdkit mol': 49})


In [462]:
# Split the processed files by outcome and report the size of each group.
failed = [entry for entry in results if entry.failed]
passed = [entry for entry in results if not entry.failed]
print("Failed:",len(failed))
print("Passed:",len(passed))

Failed: 2947
Passed: 18156
