In [1]:
import espaloma as esp
from dgllife.utils.featurizers import (
    BaseAtomFeaturizer,
    ConcatFeaturizer,
    atom_type_one_hot,
    atom_degree_one_hot,
    atom_hybridization_one_hot,
    atom_is_aromatic,
    atom_is_in_ring_one_hot,
)

In [2]:
from rdkit.Chem.Draw import IPythonConsole
# Label every atom with its index in RDKit depictions rendered in this notebook,
# so drawings can be cross-referenced with the atom/node indices printed below.
IPythonConsole.drawOptions.addAtomIndices = True

In [3]:
class AtomFeaturizer(BaseAtomFeaturizer):
    """Atom featurizer that concatenates several per-atom encodings.

    The feature blocks, in order, are: atom type (one-hot), atom degree
    (one-hot), aromaticity flag, ring-size membership flags (from
    ``atom_is_in_ringsize_one_hot``), and hybridization (one-hot).

    Parameters
    ----------
    atom_data_field : str, optional
        Node-data key under which the concatenated feature vector is stored.
        Defaults to ``'h'``.
    """

    def __init__(self, atom_data_field='h'):
        # NOTE: atom_is_in_ringsize_one_hot is resolved here, at construction
        # time, so it may be defined after this class in the same cell.
        per_atom_funcs = [
            atom_type_one_hot,
            atom_degree_one_hot,
            atom_is_aromatic,
            atom_is_in_ringsize_one_hot,
            atom_hybridization_one_hot,
        ]
        combined = ConcatFeaturizer(per_atom_funcs)
        super().__init__(featurizer_funcs={atom_data_field: combined})


        
def atom_is_in_ringsize_one_hot(atom, ring_sizes=(3, 4, 5, 6, 7, 8)):
    """Flag whether ``atom`` belongs to a ring of each size in ``ring_sizes``.

    Despite the name, this is a multi-hot encoding rather than a strict
    one-hot: an atom shared by fused rings (e.g. a 5- and a 6-membered ring)
    sets more than one flag, and an atom in no ring of a listed size sets
    none.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        Any object exposing ``IsInRingSize(int) -> bool``.
    ring_sizes : sequence of int, optional
        Ring sizes to test, in output order. The default (3 through 8)
        yields the 6-element vector used by ``AtomFeaturizer``.

    Returns
    -------
    tuple of bool
        One flag per entry of ``ring_sizes``.
    """
    return tuple(atom.IsInRingSize(size) for size in ring_sizes)



def _total_charge_in_elementary_charge(mol):
    """Return the net charge of ``mol`` as a unitless number in units of e.

    ``Molecule.total_charge`` is an ``openmm.unit.Quantity`` in older
    openff-toolkit releases, but a pint-based ``openff.units`` Quantity in
    newer ones; the latter exposes ``m_as`` instead of ``value_in_unit``.
    Both forms are accepted.
    """
    charge = mol.total_charge
    if hasattr(charge, "m_as"):
        # pint / openff.units Quantity
        return charge.m_as("elementary_charge")
    # openmm.unit Quantity -- imported lazily so openmm is only required here
    from openmm import unit
    return charge.value_in_unit(unit.elementary_charge)


def from_openff_toolkit_mol(mol, use_fp=True):
    """Featurize an OpenFF ``Molecule`` into a homogeneous DGL graph.

    The molecule is converted with ``mol.to_rdkit()`` and passed to dgllife's
    ``mol_to_bigraph``, which adds a directed edge in each direction for every
    bond. ``canonical_atom_order=False`` keeps the RDKit atom order, so node
    indices follow the atoms of the converted molecule.

    Node data set on the returned graph (float tensors, one row per atom):

    - ``"h0"``: concatenated atom features produced by ``AtomFeaturizer``.
    - ``"type"``: atomic number, shape ``(n_atoms, 1)``.
    - ``"formal_charge"``: formal charge, shape ``(n_atoms, 1)``.
    - ``"sum_q"``: net molecular charge (in e) repeated on every atom,
      shape ``(n_atoms, 1)``.

    Parameters
    ----------
    mol : openff.toolkit.topology.Molecule
        Molecule to featurize.
    use_fp : bool, optional
        Currently ignored; kept so the signature stays compatible with callers
        that pass it.

    Returns
    -------
    dgl.DGLGraph
        Graph with the node data listed above and no edge data.
    """
    import torch
    from dgllife.utils import mol_to_bigraph

    total_charge = _total_charge_in_elementary_charge(mol)

    # convert openff molecule to rdkit
    rdmol = mol.to_rdkit()

    # canonical_atom_order=True might change the order of atoms in the graph
    # constructed, breaking the node <-> atom-index correspondence.
    g = mol_to_bigraph(
        rdmol,
        add_self_loop=False,
        node_featurizer=AtomFeaturizer("h0"),
        canonical_atom_order=False,
    )

    atoms = list(rdmol.GetAtoms())
    g.ndata["type"] = torch.Tensor([[atom.GetAtomicNum()] for atom in atoms])
    g.ndata["formal_charge"] = torch.Tensor(
        [[atom.GetFormalCharge()] for atom in atoms]
    )
    # Same molecule-level charge on every node, so graph readouts can use it.
    g.ndata["sum_q"] = torch.Tensor([[total_charge] for _ in atoms])

    return g

In [4]:
# Saved espaloma graph to inspect. The path is relative to this notebook's
# working directory; change it here to point at a different dataset entry.
# Kept as a plain string, as in the original call.
GRAPH_PATH = (
    "../../../exploring-rna/rna-espaloma/espaloma-openff-default.3/"
    "02-train-force/merge-data/openff-2.0.0_filtered/gen2/0"
)

_g = esp.Graph.load(GRAPH_PATH)
mol = _g.mol   # openff molecule

In [5]:
# Featurize the loaded OpenFF molecule into a DGL graph (via RDKit).
g = from_openff_toolkit_mol(mol)

In [6]:
# Graph summary: node/edge counts and the node-data schemes (no edge data is set).
g

Graph(num_nodes=37, num_edges=78,
      ndata_schemes={'h0': Scheme(shape=(66,), dtype=torch.float32), 'type': Scheme(shape=(1,), dtype=torch.float32), 'formal_charge': Scheme(shape=(1,), dtype=torch.float32), 'sum_q': Scheme(shape=(1,), dtype=torch.float32)}
      edata_schemes={})

In [7]:
# Node-data fields attached by from_openff_toolkit_mol.
g.ndata.keys()

dict_keys(['h0', 'type', 'formal_charge', 'sum_q'])

In [8]:
# Concatenated AtomFeaturizer features, one row per atom.
g.ndata['h0']

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [9]:
# Feature vector of atom 0; blocks follow AtomFeaturizer order
# (type, degree, aromatic, ring sizes, hybridization).
g.ndata['h0'][0]

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.])

In [10]:
# Atomic number of each atom, shape (n_atoms, 1).
g.ndata['type']

tensor([[6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [6.],
        [7.],
        [7.],
        [8.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])

In [11]:
# Formal charge of each atom, shape (n_atoms, 1).
g.ndata['formal_charge']

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]])

In [12]:
# Net molecular charge repeated on every atom, shape (n_atoms, 1).
g.ndata['sum_q']

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]])

In [13]:
# Connectivity from mol_to_bigraph: each bond contributes u->v immediately followed by v->u.
g.edges()

(tensor([ 0,  1,  0, 21,  1, 10,  2,  3,  2,  6,  2, 22,  3,  7,  3, 23,  4,  5,
          4,  8,  4, 24,  5, 10,  5, 25,  6, 11,  6, 26,  7, 12,  7, 27,  8, 13,
          8, 28,  9, 10,  9, 13,  9, 29, 11, 12, 11, 15, 12, 16, 13, 19, 14, 18,
         14, 19, 14, 20, 15, 17, 15, 30, 15, 31, 16, 18, 16, 32, 16, 33, 17, 18,
         17, 34, 17, 35, 19, 36], dtype=torch.int32),
 tensor([ 1,  0, 21,  0, 10,  1,  3,  2,  6,  2, 22,  2,  7,  3, 23,  3,  5,  4,
          8,  4, 24,  4, 10,  5, 25,  5, 11,  6, 26,  6, 12,  7, 27,  7, 13,  8,
         28,  8, 10,  9, 13,  9, 29,  9, 12, 11, 15, 11, 16, 12, 19, 13, 18, 14,
         19, 14, 20, 14, 17, 15, 30, 15, 31, 15, 18, 16, 32, 16, 33, 16, 18, 17,
         34, 17, 35, 17, 36, 19], dtype=torch.int32))

In [14]:
# create edges with openff molecules, and check them against the RDKit-derived graph `g`
import dgl

bonds = list(mol.bonds)
bonds_begin_idxs = [bond.atom1_index for bond in bonds]
bonds_end_idxs = [bond.atom2_index for bond in bonds]
bonds_types = [bond.bond_order for bond in bonds]  # not used in this cell

# DGL edges are directed, so each bond is added in both directions: all forward
# edges first, then all reverse edges (same order as the former add_edges calls).
# `dgl.graph` replaces the discouraged empty `dgl.DGLGraph()` constructor, and
# `num_nodes` keeps atoms without any bond (e.g. a trailing ion) as nodes.
g_off = dgl.graph(
    (bonds_begin_idxs + bonds_end_idxs, bonds_end_idxs + bonds_begin_idxs),
    num_nodes=mol.n_atoms,
)

# Q: does this have to be in the same order as edges created with the rdkit molecule?
# The edge *order* differs: mol_to_bigraph interleaves (u->v, v->u) per bond, while
# g_off lists all forward edges and then all reverse edges. With the bond order seen
# in the outputs here, RDKit edge 2*i corresponds to g_off edge i, and 2*i+1 to edge
# i + n_bonds. Edge IDs only need to line up if per-edge data is moved between the
# two graphs; the set of directed edges, however, must be identical.
off_edge_set = set(zip(*(t.tolist() for t in g_off.edges())))
rdkit_edge_set = set(zip(*(t.tolist() for t in g.edges())))
assert off_edge_set == rdkit_edge_set, "OpenFF and RDKit graphs disagree on connectivity"

g_off.edges()



(tensor([ 0,  0,  1,  2,  2,  2,  3,  3,  4,  4,  4,  5,  5,  6,  6,  7,  7,  8,
          8,  9,  9,  9, 11, 11, 12, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17,
         17, 17, 19,  1, 21, 10,  3,  6, 22,  7, 23,  5,  8, 24, 10, 25, 11, 26,
         12, 27, 13, 28, 10, 13, 29, 12, 15, 16, 19, 18, 19, 20, 17, 30, 31, 18,
         32, 33, 18, 34, 35, 36]),
 tensor([ 1, 21, 10,  3,  6, 22,  7, 23,  5,  8, 24, 10, 25, 11, 26, 12, 27, 13,
         28, 10, 13, 29, 12, 15, 16, 19, 18, 19, 20, 17, 30, 31, 18, 32, 33, 18,
         34, 35, 36,  0,  0,  1,  2,  2,  2,  3,  3,  4,  4,  4,  5,  5,  6,  6,
          7,  7,  8,  8,  9,  9,  9, 11, 11, 12, 13, 14, 14, 14, 15, 15, 15, 16,
         16, 16, 17, 17, 17, 19]))