In [1]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

from rdkit import Chem
from rdkit.Chem import MolFromSmiles

In [2]:
smiles = [
    'COC(=O)CNC(c1ccccc1)c1cc(Br)ccc1NC(=O)c1ccccc1Cl',
    'O=C(NCC1CCCO1)c1[nH]nnc1-c1ccc(F)cc1',
    'O=C1NC(=O)C(=CNc2cccc(O)c2)C(=O)N1',
    'N#Cc1c(Oc2ccc(F)c(NC(=O)Cc3cccc(C(F)(F)F)c3)c2)ccc2nc(NC(=O)C3CC3)sc12',
    'C=CC(=O)N1CC(Nc2ncnc3[nH]ccc23)CCC1C',
    'Cc1nnc(NN=Cc2ccc(Cl)c([N+](=O)[O-])c2)[nH]c1=O',
    'Cc1cccc(CN2CCN(C(c3ccccc3)c3ccc(Cl)cc3)CC2)c1.Cl.Cl',
    'O=c1c(O)c(-c2ccc(O)cc2O)oc2cc(O)cc(O)c12',
    'CCOc1ccc(C=C2SC(=S)N(C)C2=O)cc1', 'COc1ccccc1OCC(O)COC(N)=O',
    'O=S(=O)(c1ccc(F)cc1)N1CCC(c2nc3ccccc3[nH]2)CC1',
    'NC1(C(=O)NC(CCO)c2ccc(Cl)cc2)CCN(c2ncnc3[nH]ccc23)CC1',
    'Nc1nc(NC2CC2)c2ncn(C3C=CC(CO)C3)c2n1',
    'O=C(O)c1ccc2c3c1cccc3c(=O)n1c3ccccc3nc21',
    'CCN1C(C)=C(C(=O)OC)C(c2ccc([N+](=O)[O-])cc2)C(C(=O)OC)=C1C',
    'CCN1/C(=C/C(C)=O)Sc2ccc(OC)cc21'
]

In [35]:
mol_atom_features = []
mol_edge_lists = []

for smile in smiles:
    mol = MolFromSmiles(smile)

    nodes = torch.tensor(
        [[atom.GetAtomicNum(), 
          atom.GetTotalDegree(), 
          atom.GetFormalCharge(), 
          atom.GetTotalNumHs(), 
          atom.GetNumRadicalElectrons()]
         for atom in mol.GetAtoms()], 
        requires_grad=False
    ).float()

    adj = torch.tensor(
        Chem.rdmolops.GetAdjacencyMatrix(mol), 
        requires_grad=False
    ).fill_diagonal_(1.).float()

    '''
    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.append([i, j])
        edges.append([j, i])
    edges = torch.tensor(edges, requires_grad=False)
    '''

    mol_atom_features.append(nodes)
    mol_edge_lists.append(adj)

In [49]:
nn.functional.pad(torch.randn(5, 5), pad=((0, 2, 0, 1)))

tensor([[ 0.6034, -1.1895,  0.3054,  0.5639,  0.6719,  0.0000,  0.0000],
        [ 1.5021,  0.8059,  0.1011, -0.2170, -2.2913,  0.0000,  0.0000],
        [-1.0052, -0.3480,  0.5345, -1.5106, -1.8436,  0.0000,  0.0000],
        [ 0.5952,  1.1640,  0.3395, -0.6214,  0.3493,  0.0000,  0.0000],
        [-1.0278,  1.5915,  1.3787,  0.4622,  0.5734,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])

In [38]:
node_tensor = pad_sequence(mol_atom_features, batch_first=True, padding_value=0, padding_side='right')
edge_tensor = pad_sequence(mol_edge_lists, batch_first=True, padding_value=0, padding_side='right')

RuntimeError: The size of tensor a (30) must match the size of tensor b (21) at non-singleton dimension 1

In [37]:
mol_edge_lists

[tensor([[1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0.,

## Model

In [11]:
n_features = 5
d_embedding = 32
d_attention = 32

In [12]:
embedding = nn.Linear(n_features, d_embedding, bias=False)

Q = nn.Linear(d_embedding, d_attention, bias=False)
K = nn.Linear(d_embedding, d_attention, bias=False)
V = nn.Linear(d_embedding, d_attention, bias=False)

In [13]:
x = embedding(mol_tensor)

q, k, v = Q(x), K(x), V(x)

In [24]:
torch.bmm(q, k.transpose(-2, -1)).shape

torch.Size([16, 39, 39])