In [None]:
import os
import sys
import pprint

# Base directory that every data / code / output path in this notebook is
# built from by plain string concatenation, so it must end with a slash.
root = '/'

# Directory added to the module search path.
# presumably this is where the local `mol2graph` module imported later lives — verify
import_path = root + 'pigvae_all'
if import_path not in sys.path:
    # Guarded so that re-running this cell does not pile up duplicate entries.
    sys.path.append(import_path)
pprint.pprint(sys.path)

In [None]:
import numpy as np
import pandas as pd

# Locations of the ZINC table and of the saved row-index subset.
data_dir = root + 'MolData/ZINC/'
data_path = data_dir + "zinc15_250K_2D.csv"
indices_path = data_dir + "sample_indices_75k.npy"

# Saved index array, converted to a plain Python list.
sampled_indices = np.load(indices_path)
indices_list = sampled_indices.tolist()

# Full table; a later cell subsets it with `indices_list`.
zinc_df = pd.read_csv(data_path)

In [None]:
# Keep only the sampled rows (all columns). `.loc` selects by index *label*.
# NOTE(review): presumably the saved indices are row positions; labels and
# positions agree only while `zinc_df` keeps a default 0..n-1 index — confirm.
df_extracted = zinc_df.loc[indices_list, :]

In [None]:
# Pull the "smiles" column out of the sampled rows as a plain Python list.
zinc_smiles = df_extracted["smiles"].to_list()

In [None]:
# Strip any trailing newline characters from each SMILES string.
smiles_list = [entry.rstrip('\n') for entry in zinc_smiles]

In [None]:
from rdkit import Chem

# Parse every SMILES string with RDKit. Entries for which MolFromSmiles
# returns None are dropped, so `all_mols` and `atom_Ns` are index-aligned with
# each other but not necessarily with `smiles_list`.
all_mols = []   # parsed molecule objects
atom_Ns = []    # atom count of each parsed molecule (mol.GetNumAtoms())

for smi in smiles_list:
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        continue
    all_mols.append(mol)
    atom_Ns.append(mol.GetNumAtoms())

# One summary line instead of one printed index per molecule; the loop also
# no longer rebinds the builtin `id` in the kernel namespace.
print(len(all_mols), "of", len(smiles_list), "SMILES parsed")

In [None]:
# Flat list with the element symbol of every atom of every parsed molecule
# (one entry per atom, so symbols repeat). Built without a per-molecule print
# and without rebinding the builtin `id`.
atom_symbols_list = [
    atom.GetSymbol()
    for mol in all_mols
    for atom in mol.GetAtoms()
]

In [None]:
# Distinct element symbols in the dataset, sorted so the displayed list is
# the same on every run (plain set iteration order of strings is not stable
# across kernels).
sorted(set(atom_symbols_list))

In [None]:
import numpy as np

# Largest atom count among the parsed molecules; later cells pad every
# molecule's tensors up to this many nodes.
max_num_nodes = max(atom_Ns)
max_num_nodes

In [None]:
# Sanity check: both lists are appended to together, so the lengths should match.
len(atom_Ns), len(all_mols)

In [None]:
from mol2graph import mol2vec

# Featurise every parsed molecule with the project's `mol2vec` helper.
# Later cells read `.x` and `.edge_attr` of the first result to get the
# feature widths.
# (The node/edge/mask/props lists are initialised in the cell that fills them.)
mol_graphs = [mol2vec(mol) for mol in all_mols]

In [None]:
# Per-node and per-edge feature widths, read off the first molecule graph.
first_graph = mol_graphs[0]
num_node_f = first_graph.x.shape[1]
num_edge_f = first_graph.edge_attr.shape[1]

In [None]:
# Display the node / edge feature widths taken from the first molecule graph.
num_node_f, num_edge_f

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
import random
import pytorch_lightning as pl
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
import networkx as nx
from networkx.algorithms.shortest_paths.dense import floyd_warshall_numpy

from networkx.generators.random_graphs import *
from networkx.generators.ego import ego_graph
from networkx.generators.geometric import random_geometric_graph

In [None]:
from mol2graph import mol2vec

# Convert every molecule into fixed-size, zero-padded tensors. Per molecule:
#   node_features : (1, max_num_nodes, num_node_f) float tensor
#   edge_features : (max_num_nodes, max_num_nodes, 6 + 1 + num_edge_f) float tensor
#   mask          : (1, max_num_nodes) bool tensor, True for real atoms
#   props         : 1-element float tensor holding the real atom count
node_features = []
edge_features = []
mask = []
props = []

# Layout of the last axis of each edge tensor.
num_dist_classes = 6                       # one-hot of clamp(path length, 0, 5)
no_bond_channel = num_dist_classes         # 1.0 unless the pair is listed as a bond
bond_f_start = no_bond_channel + 1         # bond features fill the remaining channels
bond_f_stop = bond_f_start + num_edge_f

for mol in all_mols:

    mol_graph = mol2vec(mol)

    # Plain connectivity graph: one node per atom, one edge per bond.
    bonds_list = [
        (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds()
    ]
    mg = nx.Graph()
    mg.add_nodes_from(range(len(mol.GetAtoms())))
    mg.add_edges_from(bonds_list)

    num_nodes_init = mg.number_of_nodes()
    props.append(torch.Tensor([num_nodes_init]))

    # Pad with isolated nodes so every graph has max_num_nodes nodes.
    mg.add_nodes_from(range(num_nodes_init, max_num_nodes))

    # All-pairs shortest-path lengths, clamped to [0, 5] and one-hot encoded
    # into the first six channels.
    # NOTE(review): padding nodes are unreachable, so floyd_warshall_numpy
    # presumably reports inf for them; casting inf to int64 is not portable —
    # confirm which distance class padding ends up in on the target platform.
    dm = torch.from_numpy(floyd_warshall_numpy(mg)).long()
    dm = torch.clamp(dm, 0, 5).unsqueeze(-1)
    padded_n = dm.size(1)
    dm = torch.zeros((padded_n, padded_n, bond_f_stop)).type_as(dm).scatter_(2, dm, 1).float()
    dm[:, :, no_bond_channel] = 1.0

    # Row 0 of edge_index is read pairwise: even positions give one bond
    # endpoint, odd positions the other, and edge_attr[k] is used for pair k.
    # assumes mol2vec emits each bond as two consecutive directed edges — TODO confirm
    endpoints = mol_graph.edge_index[0]
    i_list = list(range(0, len(endpoints), 2))
    j_list = list(range(1, len(endpoints), 2))

    for idx, (i, j) in enumerate(zip(endpoints[i_list], endpoints[j_list])):
        bond_f = mol_graph.edge_attr[idx]
        # Write the bond features symmetrically and clear the "no bond" flag.
        dm[i, j, bond_f_start:bond_f_stop] = bond_f
        dm[j, i, bond_f_start:bond_f_stop] = bond_f
        dm[i, j, no_bond_channel] = 0.0
        dm[j, i, no_bond_channel] = 0.0

    # Real atoms fill the leading rows; padding rows stay zero.
    node_f = torch.zeros(max_num_nodes, num_node_f).unsqueeze(0)
    for idx, node_x in enumerate(mol_graph.x):
        node_f[0][idx] = node_x

    edge_features.append(dm)
    mask.append((torch.arange(max_num_nodes) < num_nodes_init).unsqueeze(0))
    node_features.append(node_f)

# One summary line instead of one printed line per molecule; the loop also no
# longer rebinds the builtin `id` in the kernel namespace.
print(len(node_features), "molecules converted, padded to", max_num_nodes, "nodes")

In [None]:
# Collapse the per-molecule lists into single batched tensors with the
# molecule index on the leading axis. node_features / mask / props are joined
# along their existing first axis; edge_features gains a new leading axis.
# The names are rebound from lists to tensors, so this cell is meant to run
# once, right after the cell that fills the lists.
node_features = torch.cat(node_features)
edge_features = torch.stack(edge_features)
mask = torch.cat(mask)
props = torch.cat(props)

In [None]:
# Bundle the four batched tensors under the keys stored in the output file.
data_dict = {
    'node_features': node_features,
    'edge_features': edge_features,
    'mask': mask,
    'props': props
}

save_dir = root + "dataset/train_dataset/"
save_path = save_dir + 'zinc_gdata/tensor_data.pkl'

# Make sure the target folder exists, so a missing directory cannot make the
# run fail at the very last step after all the preprocessing is done.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# save
torch.save(data_dict, save_path)