In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
print(torch.__version__)
from rdkit import Chem
from rdkit.Chem.rdchem import BondType as BT
import torch.nn.functional as F
from torch_sparse import coalesce
from torch_geometric.data import Data
import os
import numpy as np
from torch_geometric.utils import to_dense_adj
from rdkit.Chem import Draw
import networkx as nx
from torch.utils.data import DataLoader, Dataset
from utils.data_utils import to_undirected, encode_adj
from torch_geometric.utils import to_networkx
from collections import Counter


2.0.1


In [3]:
def mols_from_file(pathfile: str, drop_none: bool = False):
    '''
    takes as input a path/to/file.ext 
    where ext can be:
    .sdf, .csv, .txt, .smiles
    it returns all mols from file
    if drop_none: drops all mols non valid for rdkit
    '''
    filename_ext = os.path.splitext(pathfile)[-1].lower()
    if filename_ext in ['.sdf']:
        suppl = Chem.SDMolSupplier(pathfile)
    elif filename_ext in ['.csv', '.txt', '.smiles']:
        suppl = Chem.SmilesMolSupplier(pathfile, titleLine=False)
    else:
        raise TypeError(f"{filename_ext} not supported")
    if drop_none:
        return [x for x in suppl if x is not None]
    return [x for x in suppl]

In [5]:
guacm_smiles = "/home/marconobile/Desktop/graphRNN/new_train_data/test_smiles.smiles"
guac_mols = mols_from_file(guacm_smiles, True)



In [9]:
c = Counter()
def get_atoms_info(mols):
    atoms = set()
    max_num = 0
    for num_mol, m in enumerate(mols):
        if m.GetNumAtoms() > max_num: max_num = m.GetNumAtoms()

        atom_types = [atom.GetSymbol() for atom in m.GetAtoms()]
        c.update(atom_types)
        
        for atom in atom_types:
            atoms.add(atom)
        
    atom2num = {}

    for i, atomType in enumerate(atoms):
        atom2num[str(atomType)] = i

    num2atom = {v:k for k,v in atom2num.items()}
    print("TOTAL NUM OF MOLS: ", num_mol)
    return atom2num, num2atom, max_num

In [10]:
atom2num, num2atom, max_num = get_atoms_info(guac_mols)

TOTAL NUM OF MOLS:  150065


In [11]:
bond2num = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
num2bond = {v:k for k,v in bond2num.items()}

In [12]:
#! TODO: multiprocess
def rdkit2pyg(mols):
    '''
    :param mols: iterable of rdkit mols
    :return: list of PyG data objs with one-hot node/edge features
    '''
    data_list = []
    for mol in mols:
        if mol is None:
            continue

        N = mol.GetNumAtoms()

        type_idx = []
        for atom in mol.GetAtoms():
            type_idx.append(atom2num[atom.GetSymbol()])

        x = F.one_hot(torch.tensor(type_idx), num_classes=len(atom2num))
        row, col, bond_idx = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            bond_idx += 2 * [bond2num[bond.GetBondType()]]

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = F.one_hot(torch.tensor(bond_idx).to(torch.int64),
                              num_classes=len(bond2num)).to(torch.float)
        edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)

        data_list.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr))

    return data_list

def pyg2rdkit(dataset):
    def numpy_to_rdkit(adj, nf, ef, sanitize=False):
        """
        Converts a molecule from numpy to RDKit format.
        :param adj: binary numpy array of shape (N, N) 
        :param nf: numpy array of shape (N, F)
        :param ef: numpy array of shape (N, N, S)
        :param sanitize: whether to sanitize the molecule after conversion
        :return: an RDKit molecule
        """
        if Chem is None:
            raise ImportError('`numpy_to_rdkit` requires RDKit.')
        mol = Chem.RWMol()
        for nf_ in nf:
            # atomic_num = torch.argmax(nf_).item()
            atomic_num = int(nf_)
            mol.AddAtom(Chem.Atom(num2atom[atomic_num]))

        for i, j in zip(*np.triu_indices(adj.shape[-1])):
            if i != j and adj[i, j] == adj[j, i] == 1 and not mol.GetBondBetweenAtoms(int(i), int(j)):
                bond_type_1 = num2bond[int(ef[i, j, 0])]
                bond_type_2 = num2bond[int(ef[j, i, 0])]
                if bond_type_1 == bond_type_2:
                    mol.AddBond(int(i), int(j), bond_type_1)

        mol = mol.GetMol()
        if sanitize: Chem.SanitizeMol(mol)
        return mol

    mols_ = []
    for i, obs in enumerate(dataset):
        ef_temp = torch.squeeze(to_dense_adj(edge_index=obs.edge_index, batch=None, edge_attr=obs.edge_attr), 0)
        ef = torch.zeros((ef_temp.shape[0], ef_temp.shape[1], 1))
        adj = torch.zeros((ef_temp.shape[0], ef_temp.shape[1]))
        
        for row in range(ef_temp.shape[0]):
            for col in range(ef_temp.shape[1]):
                if int(torch.sum(ef_temp[row, col]).item()) != 0:
                    ef[row, col, 0] = torch.argmax(ef_temp[row, col]).item()
                    adj[row, col] = 1

        ef = np.array(ef)
        adj = np.array(adj)

        adj = adj.astype(int)
        ef = ef.astype(int)

        nf = np.array([torch.argmax(row).item() for row in obs.x])
        nf = np.expand_dims(nf, 1)
        nf = nf.astype(int)
        mols_.append(numpy_to_rdkit(adj, nf, ef))

    return mols_

In [15]:
import random

def rand_perm_rooted_at_carbon(m):
    carbons = [idx for idx, atom in enumerate(m.GetAtoms()) if atom.GetSymbol() == 'C']
    smi = Chem.MolToSmiles(m, rootedAtAtom = random.choice(carbons))
    return Chem.MolFromSmiles(smi)

In [27]:
def encode_adj_(adj, max_prev_node, edge_feature_dims):
    '''
    :param adj: n*n, rows means time step, while columns are input dimension
    :param max_degree: we want to keep row number, but truncate column numbers
    :return: n x m x ef reversed
    '''
    adj = np.tril(adj, k=-1)
    n = adj.shape[0]
    adj = adj[1:n, 0:n-1]

    # use max_prev_node to truncate; now adj is a (n-1)x(max_prev_node) matrix
    adj_output = np.zeros((adj.shape[0], max_prev_node, edge_feature_dims))# np.zeros((n, max_prev_node, edge_feature_dims))
    adj_output[:, :, 0] = 1 # set all edges as absent


    for i in range(n-1):
        input_start = max(0, i - max_prev_node + 1)
        input_end = i + 1
        output_start = max_prev_node + input_start - input_end
        output_end = max_prev_node
        to_be_cat = adj[i, input_start:input_end, :]
        to_be_cat = np.concatenate((np.zeros((to_be_cat.shape[0], 1)), to_be_cat), 1)
        
        # begin by setting all as absent
        for idx_row in range(to_be_cat.shape[0]):
            if to_be_cat[idx_row].sum() == 0:
                to_be_cat[idx_row, 0] = 1
                
        # for r in range(to_be_cat.shape[0]):
            # for c in range(to_be_cat.shape[1]):
                # to_be_cat[r, c, 0] = 1

        adj_output[i, output_start:output_end, :] = to_be_cat
        adj_output[i,:] = adj_output[i,:][::-1] # reverse order
    return adj_output


In [24]:
def process_subset(subset, max_num_node, max_prev_node):
    '''
    :param subset: list of pyg obs for training/testing
    :param max_num_node: max num of nodes in the set of graphs
    :param max_prev_node: max_num_node - 1
    :return: Graph_sequence_sampler_pytorch data object
    '''
    G_list = []         # list of undirected nx graphs
    node_attr_list = [] # list of node matrices
    adj_all = []        # list of A(s) with edge features as elements a_ij [NxNxEf] i.e. dense adj
    for g in subset:
        node_attr_list.append(g.x)
        G_list.append(to_undirected(to_networkx(g)))
        adj_all.append(to_dense_adj(edge_index=g.edge_index, batch=None, edge_attr=g.edge_attr))

    return Graph_sequence_sampler_pytorch(Graph_list=G_list, node_attr_list=node_attr_list, adj_all=adj_all,
                                          max_num_node=max_num_node, max_prev_node=max_prev_node)

def create_train_val_dataloaders(dataset, max_num_node, max_prev_node):
    '''
    for supervised training takes as input:
    - dataset: a list of pyg Data obs,
    - max number of nodes of the loaded graphs,
    - max_prev_node = (max number of nodes-1)
    '''
    train_set = process_subset(dataset, max_num_node, max_prev_node)
    train_dataset_loader = DataLoader(train_set, batch_size=32, shuffle=True)#, num_workers=15)
    return train_dataset_loader, []

class Graph_sequence_sampler_pytorch(torch.utils.data.Dataset):
    '''
    returns : dictionary containing input/output nodes, input/output edges
    '''

    def __init__(self, Graph_list, node_attr_list, adj_all, max_num_node, max_prev_node):

        '''
        Graph_list: list of undirected networkx graphs
        node_attr_list: list of node matrices
        adj_all: list of A(s) with edge features as elements a_ij [NxNxEf]
        max_num_node : max number of possible nodes in a graph
        max_prev_node : max previous node that looks back (to lock back at)
        '''

        self.adj_all = adj_all  # list of multidim np.arrays (As) already in edge_feature form [V, V , node_f]
        self.len_all = []  # V for each G
        self.node_attr_list = node_attr_list
        self.graph_list = Graph_list  # list of undirected nx graphs
        for G in Graph_list:
            self.len_all.append(G.number_of_nodes())  # timesteps of node rnn for each G
        self.max_num_node = max_num_node
        self.max_prev_node = max_prev_node

        self.edge_feature_dims = 5
        self.node_feature_dims = 12

    def __len__(self):
        return len(self.adj_all)

    def __getitem__(self, idx):
        
        # 1) random permutation of A,X,edg idx
        # 2) sample a C
        # 3) apply bfs from sampled C
        # 4) encode retrieved bfs tree

        # edge encoding:
        x_batch = np.zeros((self.max_num_node, self.max_prev_node, self.edge_feature_dims))
        y_batch = np.zeros((self.max_num_node, self.max_prev_node, self.edge_feature_dims))

        original_a = np.asarray(nx.adjacency_matrix(self.graph_list[idx]).todense())  # A without edge features of the current g        
        adj_copy = np.asarray(self.adj_all[idx]).copy() # dense adj matrix
        adj_copy = np.squeeze(adj_copy)  # adj_copy had bs as first dim
        # adj_encoded = encode_adj(adj=adj_copy, original=original_a, max_prev_node=self.max_prev_node, edge_feature_dims = self.edge_feature_dims)
        adj_encoded = encode_adj_(adj=adj_copy, max_prev_node=self.max_prev_node, edge_feature_dims = self.edge_feature_dims)

        x_batch[0, :, :] = 1
        x_batch[1:adj_encoded.shape[0] + 1, :] = adj_encoded
        y_batch[0:adj_encoded.shape[0], :] = adj_encoded

        for r in range(y_batch.shape[0]):
            for c in range(y_batch.shape[1]):
                if np.sum(y_batch[r, c, :]) == 0:
                    y_batch[r, c, 0] = 1

        # node encoding:
        node_attr_list_copy = np.asarray(self.node_attr_list[idx]).copy()
        x_node_attr = np.zeros((self.max_num_node, self.node_feature_dims))
        y_node_attr = np.zeros((self.max_num_node, self.node_feature_dims))

        # input nodes:
        x_node_attr[0, :] = 1
        x_node_attr[1:node_attr_list_copy.shape[0], :] = node_attr_list_copy[:-1]
        # output nodes:
        y_node_attr[:node_attr_list_copy.shape[0], :] = node_attr_list_copy

        len_batch = node_attr_list_copy.shape[0]  # number of nodes of current g

        return {'x': x_batch, 'y': y_batch, 'len': len_batch, 'x_node_attr': x_node_attr, 'y_node_attr': y_node_attr}

In [28]:
data = rdkit2pyg([guac_mols[0]])
train_dataset_loader, _ = create_train_val_dataloaders(data, 6, 3)

In [31]:
for x in train_dataset_loader:
    print(x)

{'x': tensor([[[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]],

         [[1., 0., 0., 0., 0.],
          [1., 0., 0., 0., 0.],
          [1., 0., 0., 0., 0.]],

         [[0., 1., 0., 0., 0.],
          [1., 0., 0., 0., 0.],
          [1., 0., 0., 0., 0.]],

         [[0., 1., 0., 0., 0.],
          [1., 0., 0., 0., 0.],
          [1., 0., 0., 0., 0.]],

         [[1., 0., 0., 0., 0.],
          [0., 1., 0., 0., 0.],
          [1., 0., 0., 0., 0.]],

         [[1., 0., 0., 0., 0.],
          [1., 0., 0., 0., 0.],
          [0., 1., 0., 0., 0.]]]], dtype=torch.float64), 'y': tensor([[[[1., 0., 0., 0., 0.],
          [1., 0., 0., 0., 0.],
          [1., 0., 0., 0., 0.]],

         [[0., 1., 0., 0., 0.],
          [1., 0., 0., 0., 0.],
          [1., 0., 0., 0., 0.]],

         [[0., 1., 0., 0., 0.],
          [1., 0., 0., 0., 0.],
          [1., 0., 0., 0., 0.]],

         [[1., 0., 0., 0., 0.],
          [0., 1., 0., 0., 0.],
          [1., 0., 0

In [None]:
# MISSING 1
# PERMUTATION IN RDKIT

# MISSING 2
# x_idx = np.array(bfs_seq(G, start_idx))
# adj_copy = adj_copy[np.ix_(x_idx, x_idx)]

In [None]:
def bfs_seq(G, start_id):
    '''
    Builds a bfs node sequence
    :param G:nx graph
    :param start_id: id of the staring node, it must be a carbon
    :return:
    '''
    dictionary = dict(nx.bfs_successors(G, start_id))
    start = [start_id]
    output = [start_id]
    while len(start) > 0:
        next = []
        while len(start) > 0:
            current = start.pop(0)
            neighbor = dictionary.get(current)
            if neighbor is not None:
                #### a wrong example, should not permute here!
                # shuffle(neighbor)
                next = next + neighbor
        output = output + next
        start = next
    return output


In [None]:
def encode_adj_flexible(adj):
    '''
    return a flexible length of output
    note that here there is no loss when encoding/decoding an adj matrix
    :param adj: adj matrix
    :return:
    '''
    # pick up lower tri
    adj = np.tril(adj, k=-1)
    n = adj.shape[0]
    adj = adj[1:n, 0:n-1]

    adj_output = []
    input_start = 0
    for i in range(adj.shape[0]):
        input_end = i + 1
        adj_slice = adj[i, input_start:input_end]
        adj_output.append(adj_slice)
        non_zero = np.nonzero(adj_slice)[0]
        input_start = input_end-len(adj_slice)+np.amin(non_zero)
    return adj_output

def calc_max_prev_node(list_of_adjs, iter=20000,topk=10):
    max_prev_node = []
    for i in range(iter):
        if i % (iter / 5) == 0:
            print('iter {} times'.format(i))
        adj_idx = np.random.randint(len(list_of_adjs))
        adj_copy = list_of_adjs[adj_idx].copy()
        # print('Graph size', adj_copy.shape[0])
        x_idx = np.random.permutation(adj_copy.shape[0])
        adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
        adj_copy_matrix = np.asmatrix(adj_copy)
        G = nx.from_numpy_array(adj_copy_matrix)
        # then do bfs in the permuted G
        start_idx = np.random.randint(adj_copy.shape[0])
        x_idx = np.array(bfs_seq(G, start_idx))
        adj_copy = adj_copy[np.ix_(x_idx, x_idx)]
        # encode adj
        adj_encoded = encode_adj_flexible(adj_copy.copy())
        max_encoded_len = max([len(adj_encoded[i]) for i in range(len(adj_encoded))])
        max_prev_node.append(max_encoded_len)
    max_prev_node = sorted(max_prev_node)[-1*topk:]
    return max_prev_node



G_list = []         # list of undirected nx graphs
for g in data:
    G_list.append(to_undirected(to_networkx(g)))

adjs = [np.asarray(nx.adjacency_matrix(g).todense()) for g in G_list]# A without edge features of the current g        
max(calc_max_prev_node(adjs, iter=100000 ,topk=10)) # max prev node
