In [1]:
import torch
import dgl
import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from iotbx import cif
from iotbx.data_manager import DataManager
import os
os.environ["MMTBX_CCP4_MONOMER_LIB"] = "/Users/user/software/phenix/modules/chem_data/geostd"


import sys
from pathlib import Path
module_dir = Path("/Users/user/software/phenix/modules/")
for m in module_dir.glob("*"):
  if m.is_dir():
    sys.path.append(str(Path(module_dir,m)))

In [104]:
dm = DataManager()
dm.process_model_file("../data/PO_data/12AS/output/12AS.updated.pdb")
dm.process_restraint_file("../data/PO_data/12AS/output/restraints/AMP.cif")
model = dm.get_model()
model.process(make_restraints=True)

In [106]:
rm = model.restraints_manager
grm = rm.geometry
bonds_simple, bonds_asu = grm.get_all_bond_proxies()

In [107]:
len(bonds_simple)

10368

In [108]:
m = Chem.MolFromPDBFile("../data/PO_data/12AS/output/12AS.updated.pdb",removeHs=False)

In [110]:
from iotbx import cif

In [None]:
cif.

In [9]:
model = model.select(model.selection("resname AMP"))

In [10]:
from rdkit import Chem
def convert_cctbx_to_rdkit(cctbx_model,skip_xyz=False):
  """
  Build an rdkit molecule from a cctbx model by enumerating through
  atoms and bond proxies.
  
  Args:
    cctbx_model (mmtbx.model.model.manager): The cctbx model object to convert
    skip_xyz (bool): Whether or not to include the cartesian coordinates
  
  Returns:
    mol (rdkit.Chem.Mol): The rdkit mol
  """
  assert cctbx_model.restraints_manager_available(), "Must provide a cctbx model with restraints manager."
  pt  = Chem.GetPeriodicTable()
  mol = Chem.Mol()
  rwmol = Chem.RWMol(mol)
  if not skip_xyz:
    conformer = Chem.Conformer(cctbx_model.get_number_of_atoms())


  # atoms
  for i,atom in enumerate(cctbx_model.get_atoms()):
    e = atom.element.strip()
    if e=="D":
      e = "H"
    elif len(e)==2:
      e = e[0].upper()+e[1].lower()
    atomic_num = pt.GetAtomicNumber(e)
    rdatom = Chem.Atom(atomic_num)
    charge = atom.charge_as_int()
    rdatom.SetFormalCharge(charge)



    atomi = rwmol.AddAtom(rdatom)
    assert i==atomi, "Mismatch between atom enumerate index and atom index"
    if not skip_xyz:
      x,y,z = atom.xyz
      conformer.SetAtomPosition(atomi,(float(x),
                                       float(y),
                                       float(z)))

  # bonds
  rm = cctbx_model.restraints_manager
  grm = rm.geometry
  bonds_simple, bonds_asu = grm.get_all_bond_proxies()
  bond_proxies = bonds_simple.get_proxies_with_origin_id()


  for bond_proxy in bond_proxies:
    bond_type = Chem.rdchem.BondType.UNSPECIFIED # Need bond type from cctbx
    i,j = bond_proxy.i_seqs
    rwmol.AddBond(i,j,bond_type)

  if not skip_xyz:
    rwmol.AddConformer(conformer)
  rdmol = rwmol.GetMol()

  return rdmol

In [11]:
mol = convert_cctbx_to_rdkit(model)

In [12]:
import numpy as np
def build_atom_graph_from_rdkit(rdkit_mol,
                                skip_hydrogen=True,
                                atom_features=None):
  """
  Create a dgl graph with nodes as atoms and edges as bonds
  
  Args:
    rdkit_mol (rdkit.Chem.Mol): Input molecule
    skip_hydrogens (bool): Whether to include hydrogens in graph
    atom_features (np.ndarray): Feature vector for each atom in input 
                                Shape=(n_atoms,n_features)
                                
  Returns:
    g (dgl.heterograph.DGLHeteroGraph): The dgl graph object
    atom_idx_noH_mapper (dict): A dictionary to map between atom indices with/without hydrogens
                                key=original atom idx
                                value=atom idx with Hs removed
  """
  if type(atom_features)==type(None):
    atom_features = np.zeros((rdkit_mol.GetNumAtoms(),1))

  atom_idxs_wH = []
  atom_idxs_woutH = []
  atom_idx_noH_mapper = {} # a dict with key: original atom idx, value: atom idx with Hs removed
  for i,atom in enumerate(rdkit_mol.GetAtoms()):
    assert i == atom.GetIdx(), "Mismatch between atom.GetIdx() and position in molecule"
    atom_idxs_wH.append(i)
    if atom.GetAtomicNum()>1:
      atom_idxs_woutH.append(i)

  if not skip_hydrogen:
    atom_idxs_woutH = atom_idxs_wH
  
  for i,idx in enumerate(atom_idxs_woutH):
    atom_idx_noH_mapper[idx] = i
    
  bond_idxs = []
  for bond in rdkit_mol.GetBonds():
    begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    if (begin in atom_idxs_woutH) and (end in atom_idxs_woutH):
      begin, end = atom_idx_noH_mapper[begin], atom_idx_noH_mapper[end]
      bond_idxs.append([begin,end])
      
  bond_idxs = np.vstack([bond_idxs,np.flip(bond_idxs,axis=1)]) # add reverse direction edges
  g = dgl.graph((bond_idxs[:,0],bond_idxs[:,1]))

  g.ndata["h0"] = torch.from_numpy(atom_features[atom_idxs_woutH]) # set initial representation
  g.ndata["mol_idxs"] = torch.from_numpy(np.array(atom_idxs_woutH))
  return g, atom_idx_noH_mapper


def build_atom_graph_from_cctbx(cctbx_model,
                                skip_hydrogen=True,
                                atom_features=None):
  """
  Create a dgl graph with nodes as atoms and edges as bonds
  
  Args:
    cctbx_model (mmtbx.model.model.manager): Input molecule
    skip_hydrogens (bool): Whether to include hydrogens in graph
    atom_features (np.ndarray): Feature vector for each atom in input 
                                Shape=(n_atoms,n_features)
                                
  Returns:
    g (dgl.heterograph.DGLHeteroGraph): The dgl graph object
    atom_idx_noH_mapper (dict): A dictionary to map between atom indices with/without hydrogens
                                key=original atom idx
                                value=atom idx with Hs removed
  """
  if type(atom_features)==type(None):
    atom_features = np.zeros((cctbx_model.get_number_of_atoms(),1))

  atom_idxs_wH = []
  atom_idxs_woutH = []
  atom_idx_noH_mapper = {} # a dict with key: original atom idx, value: atom idx with Hs removed
  for i,atom in enumerate(cctbx_model.get_atoms()):
    assert i == atom.i_seq, "Mismatch between atom.i_seq and position in molecule"
    atom_idxs_wH.append(i)
    element = atom.element.strip()
    if element.upper() not in ["H","D"]:
      atom_idxs_woutH.append(i)

  if not skip_hydrogen:
    atom_idxs_woutH = atom_idxs_wH
  
  for i,idx in enumerate(atom_idxs_woutH):
    atom_idx_noH_mapper[idx] = i
    
  bond_idxs = []
  rm = cctbx_model.restraints_manager
  grm = rm.geometry
  bonds_simple, bonds_asu = grm.get_all_bond_proxies()
  bond_proxies = bonds_simple.get_proxies_with_origin_id()
  for bond_proxy in bond_proxies:
    begin, end = bond_proxy.i_seqs
    if (begin in atom_idxs_woutH) and (end in atom_idxs_woutH):
      begin, end = atom_idx_noH_mapper[begin], atom_idx_noH_mapper[end]
      bond_idxs.append([begin,end])
      
  bond_idxs = np.vstack([bond_idxs,np.flip(bond_idxs,axis=1)]) # add reverse direction edges
  g = dgl.graph((bond_idxs[:,0],bond_idxs[:,1]))

  g.ndata["h0"] = torch.from_numpy(atom_features[atom_idxs_woutH]) # set initial representation
  g.ndata["mol_idxs"] = torch.from_numpy(np.array(atom_idxs_woutH))
  return g, atom_idx_noH_mapper

In [29]:
g

Graph(num_nodes=70, num_edges=148,
      ndata_schemes={'h0': Scheme(shape=(1,), dtype=torch.float64), 'mol_idxs': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})

In [42]:
def build_fragment_graph(atom_graph,
                         frag_idxs,
                         node_name="atom",
                         frag_name="fragment",
                         fragment_data={}):
  """
  Build a dgl heterograph with "fragment" nodes connected to atoms
  by edges
  
  Args:
    atom_graph (dgl.graph): The atom graph (nodes are atoms, edges are bonds)
    frag_idxs (np.ndarray): Shape (n_fragments, n_atoms_per_frag)
                            The node indices of each fragment. NOTE: Not necessarily
                            the atom indices, if omitting hydrogens for example.
    node_name (str): The name for each (atom) node
    frag_name (str): The name for each fragment node
    fragment_data (dict): A dictionary of additional fragment data to attach to the graph.
                          For example, ground truth regression data.
                          
  Returns:
    frag_graph (dgl.graph): The new heterograph (different node types) with fragment nodes present.
                            
  """

  e1,e2 = atom_graph.edges()
  bonded_idxs = np.stack([e1.numpy(),e2.numpy()],axis=1)
  
  edge_dict = {}
  edge_dict[(node_name,"%s_%s_%s" % (node_name,"bonded",node_name), node_name)] = bonded_idxs

  for i in range(frag_idxs.shape[1]):
    name = (node_name,"%s_as_%s_in_%s" % (node_name,i,frag_name),frag_name)
    frag_edge_idxs = np.stack([frag_idxs[:,i],np.arange(frag_idxs.shape[0])],axis=1)
    edge_dict[name] = frag_edge_idxs
    
  frag_graph = dgl.heterograph({key: list(value) for key, value in edge_dict.items()})
  frag_graph.nodes[node_name].data["h0"] = atom_graph.ndata["h0"].type(torch.get_default_dtype())



  frag_graph.nodes[node_name].data["mol_idxs"] = atom_graph.ndata["mol_idxs"]
  frag_graph.nodes[frag_name].data["mol_idxs"] = torch.from_numpy(atom_graph.ndata["mol_idxs"].numpy()[frag_idxs])
  for key,value in fragment_data.items():
    frag_graph.nodes[frag_name].data[key] = torch.from_numpy(value)
  
  return frag_graph

In [31]:
def espaloma_fingerprint(atom):

        HYBRIDIZATION_RDKIT = {
          Chem.rdchem.HybridizationType.SP: np.array(
              [1, 0, 0, 0, 0],
          ),
          Chem.rdchem.HybridizationType.SP2: np.array(
              [0, 1, 0, 0, 0],
          ),
          Chem.rdchem.HybridizationType.SP3: np.array(
              [0, 0, 1, 0, 0],
          ),
          Chem.rdchem.HybridizationType.SP3D: np.array(
              [0, 0, 0, 1, 0],
          ),
          Chem.rdchem.HybridizationType.SP3D2: np.array(
              [0, 0, 0, 0, 1],
          ),
          Chem.rdchem.HybridizationType.S: np.array(
              [0, 0, 0, 0, 0],
          ),
          Chem.rdchem.HybridizationType.UNSPECIFIED: np.array(
              [1, 0, 0, 0, 0], # TODO: UNSPECIFIED goes to SP (it seems very rare...)
          ),
        }

        return np.concatenate(
              [
                  np.array(
                      [
                          #atom.GetTotalDegree(),
                          #atom.GetTotalNumHs(),
                          #atom.GetTotalValence(),
                          #atom.GetExplicitValence(),
                          len(atom.GetNeighbors())*1.0,
                          len([a for a in atom.GetNeighbors() if a.GetAtomicNum()==1]),
                          atom.GetFormalCharge() if atom.GetFormalCharge()>0 else 0,
                          atom.GetFormalCharge() if atom.GetFormalCharge()<0 else 0,
                          atom.GetIsAromatic() * 1.0,
                          atom.GetMass(),
                          atom.IsInRingSize(3) * 1.0,
                          atom.IsInRingSize(4) * 1.0,
                          atom.IsInRingSize(5) * 1.0,
                          atom.IsInRingSize(6) * 1.0,
                          atom.IsInRingSize(7) * 1.0,
                          atom.IsInRingSize(8) * 1.0,
                      ],
                  ),
                  HYBRIDIZATION_RDKIT[atom.GetHybridization()],
              ],
          )

In [32]:
atom_features = np.vstack([espaloma_fingerprint(atom) for atom in mol.GetAtoms()])

In [36]:
atom_graph, _ = build_atom_graph_from_rdkit(mol,skip_hydrogen=False,atom_features=atom_features)

In [37]:
atom_graph.ndata["h0"]

tensor([[4., 1., 0.,  ..., 0., 0., 0.],
        [3., 1., 0.,  ..., 0., 0., 0.],
        [4., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)

In [59]:
rm = model.restraints_manager
grm = rm.geometry
result = bonds_simple.get_sorted("delta",model.get_sites_cart())
bonds_simple, bonds_asu = grm.get_all_bond_proxies()
bond_proxies = bonds_simple.get_proxies_with_origin_id()
frag_idxs = np.array([bond_proxy.i_seqs for bond_proxy in bond_proxies])
bond_length_ideal = np.array([bond_proxy.distance_ideal for bond_proxy in bond_proxies]) # Note: want actual for regression

In [60]:
bond_proxy = bonds_simple[0]

In [61]:
bond_proxy.distance_ideal

1.529

In [68]:
result = bonds_simple.get_sorted("delta",model.get_sites_cart())

In [66]:
for r in result:
  i,j = r[0],r[1]
  ideal, actual = r[3], r[4]

TypeError: 'int' object is not subscriptable

In [43]:
frag_graph = build_fragment_graph(atom_graph,
                                  frag_idxs=frag_idxs,
                                  fragment_data={"bond_length_ideal":bond_length_ideal})

In [69]:
frag_graph

Graph(num_nodes={'atom': 70, 'fragment': 74},
      num_edges={('atom', 'atom_as_0_in_fragment', 'fragment'): 74, ('atom', 'atom_as_1_in_fragment', 'fragment'): 74, ('atom', 'atom_bonded_atom', 'atom'): 148},
      metagraph=[('atom', 'fragment', 'atom_as_0_in_fragment'), ('atom', 'fragment', 'atom_as_1_in_fragment'), ('atom', 'atom', 'atom_bonded_atom')])

In [76]:
import torch
import torch.nn.functional as F
import dgl
import numpy as np


class MessagePassingBonded(torch.nn.Module):
    """Sequential neural network with input layers.
    Parameters
    ----------
    layer : torch.nn.Module
        DGL graph convolution layers.
    config : List
        A sequence of numbers (for units) and strings (for activation functions)
        denoting the configuration of the sequential model.
    feature_units : int(default=117)
        The number of input channels.
    Methods
    -------
    forward(g, x)
        Forward pass.
    """
    

    
    def __init__(
        self,
        nlayers= 3,
        feature_units=None,
        hidden_units=128,
        atom_node_name = "atom",
        fragment_name = "fragment",
        model_kwargs={},
    ):
        super(MessagePassingBonded, self).__init__()
        
        # validate
        assert feature_units is not None, "Specify the size of feature units"
        
        # setup
        self.atom_node_name = atom_node_name
        self.fragment_name = fragment_name
        
        
        # initial featurization
        self.f_in = torch.nn.Sequential(
            torch.nn.Linear(feature_units, hidden_units), torch.nn.Tanh()
        )
        layers = []
        for i in range(nlayers):
          layers.append(dgl.nn.pytorch.conv.sageconv.SAGEConv(hidden_units,hidden_units,"mean",bias=True,activation=F.relu))
          #layers.append(torch.nn.ReLU())
          
        self.mp = dgl.nn.Sequential(*layers)


    def forward(self, g, x=None):
        
        # get homogeneous subgraph
        edge_type = "%s_%s_%s" % (self.atom_node_name,"bonded",self.atom_node_name)
        g_ = dgl.to_homogeneous(g.edge_type_subgraph([edge_type]))

        if x is None:
            # get node attributes
            x = g.nodes[self.atom_node_name].data["h0"]
            x = self.f_in(x)

        # message passing on atom graph
        x = self.mp(g_,x)

        # put attribute back in the graph
        g.nodes[self.atom_node_name].data["h"] = x

        return g

In [82]:
mp = MessagePassingBonded(feature_units=atom_features.shape[1],hidden_units=16)

In [83]:
output = mp(frag_graph)

In [84]:
output

Graph(num_nodes={'atom': 70, 'fragment': 74},
      num_edges={('atom', 'atom_as_0_in_fragment', 'fragment'): 74, ('atom', 'atom_as_1_in_fragment', 'fragment'): 74, ('atom', 'atom_bonded_atom', 'atom'): 148},
      metagraph=[('atom', 'fragment', 'atom_as_0_in_fragment'), ('atom', 'fragment', 'atom_as_1_in_fragment'), ('atom', 'atom', 'atom_bonded_atom')])

In [85]:
output.nodes["atom"].data["h"].shape

torch.Size([70, 16])

In [88]:
output.nodes["atom"].data["h"]

tensor([[0.0000, 1.7602, 1.2307,  ..., 0.0000, 0.0321, 0.5444],
        [0.0000, 2.0693, 1.0694,  ..., 0.0000, 0.0000, 0.6958],
        [0.0000, 1.4575, 0.9586,  ..., 0.0000, 0.0999, 0.1989],
        ...,
        [0.0000, 0.9082, 0.3219,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 1.4117, 0.5076,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 1.4117, 0.5076,  ..., 0.0000, 0.0000, 0.0000]],
       grad_fn=<ReluBackward0>)