In [None]:
import os
import sys
import pprint

# Prefix that the paths in this notebook are built from by plain string
# concatenation (root + "..."), so it has to end with a path separator.
root = '/'

# Directory put on the module search path.
# NOTE(review): presumably this is where mol2graph_qm9 lives — confirm.
import_path = root + 'pigvae_all'

# Only append when missing, so re-running this cell does not keep adding
# duplicate entries to sys.path.
if import_path not in sys.path:
    sys.path.append(import_path)

# Show the effective search path for a quick visual check.
pprint.pprint(sys.path)

In [None]:
import pickle
    
def _read_pickle(path):
    """Open `path` in binary mode and return the single object unpickled from it.

    Note: unpickling executes code embedded in the file, so only point this at
    files from a trusted source.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Molecule container and the index array that is later used to reorder it.
qm9_mols = _read_pickle(root + "MolData/QM9/qm9_mols.pickle")
shuffle_indices = _read_pickle(root + "MolData/QM9/shuffle_indices.pickle")

In [None]:
import pandas as pd

#mol_table = pd.read_csv("smiles_feature.csv")
# Load the property CSV into a DataFrame; the target columns are selected
# from it by name further down.
qm9_props_csv = root + "MolData/QM9/qm9_props.csv"
mol_table = pd.read_csv(qm9_props_csv)

In [None]:
# Column names pulled out of mol_table as the per-molecule target values.
# NOTE(review): presumably the standard QM9 property labels — confirm
# against the CSV header.
c_indices = [
    'mu',
    'alpha',
    'homo',
    'lumo',
    'r2',
    'zpve',
    'u0',
    'cv',
]

In [None]:
#c_indices = mol_table.columns[5:13]

import matplotlib.pyplot as plt

# Restrict the table to the selected columns (all rows kept) and expose the
# values as a plain array with one column per name in c_indices.
target_df = mol_table.loc[:, c_indices]
targets_array = target_df.values

# Sanity output: array shape next to the column names it was built from.
print(targets_array.shape, c_indices)

In [None]:
import numpy as np
from rdkit import Chem

# Walk the molecules in the order given by shuffle_indices and, for every entry
# that is not None, record the molecule, its atom count and its target row.
all_mols = []
atom_Ns = []
prop_list = []
atom_symbols_list = []  # flat list of atom symbols over all kept molecules

# Both sequences are reordered with the same index array so that molecule and
# target row stay paired.
shuffled_mols = np.array(qm9_mols)[shuffle_indices]
shuffled_targets = np.array(targets_array)[shuffle_indices]

# The counter is deliberately not called `id`: as a top-level loop variable it
# would shadow the builtin id() for the rest of the kernel session.
for mol_idx, (mol, t) in enumerate(zip(shuffled_mols, shuffled_targets)):
    # Entries that are None are skipped; mol_idx still counts them, so it is
    # the position in the shuffled sequence, not in all_mols.
    if mol is None:
        continue

    # Progress report every 1000 positions rather than one line per molecule,
    # matching the reporting interval of the tensor-building loop.
    if mol_idx % 1000 == 0:
        print(mol_idx, t)

    all_mols.append(mol)
    atom_Ns.append(mol.GetNumAtoms())
    prop_list.append(t)
    atom_symbols_list.extend(atom.GetSymbol() for atom in mol.GetAtoms())

In [None]:
import numpy as np

# Largest atom count among the kept molecules; used afterwards as the padded
# size of the per-molecule node and edge tensors. Raises ValueError if atom_Ns
# is empty (i.e. no molecule was kept).
max_num_nodes = max(atom_Ns)
# Bare expression so the value is shown as the cell output.
max_num_nodes

In [None]:
import torch

# Stack the collected target rows into one float32 array, then copy it into a
# tensor that owns its own memory (clone() detaches it from the NumPy buffer).
# NOTE(review): all_targets is not referenced again in the visible part of the
# notebook — confirm whether it was meant to be saved.
targets_f32 = np.array(prop_list).astype(np.float32)
all_targets = torch.from_numpy(targets_f32).clone()

In [None]:
from mol2graph_qm9 import mol2vec

# Widths used when allocating the padded tensors: length of one node feature
# vector and length of one edge feature vector.
# NOTE(review): presumably these must match what mol2vec emits — confirm.
num_node_f, num_edge_f = 36, 6

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
import random
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
import networkx as nx
from networkx.algorithms.shortest_paths.dense import floyd_warshall_numpy

from networkx.generators.random_graphs import *
from networkx.generators.ego import ego_graph
from networkx.generators.geometric import random_geometric_graph

# Build one set of fixed-size tensors (padded to max_num_nodes) per molecule.
node_features = []  # each entry: (1, max_num_nodes, num_node_f)
edge_features = []  # each entry: (max_num_nodes, max_num_nodes, 1 + num_edge_f + 1)
mask = []           # each entry: (1, max_num_nodes) bool, True for positions < atom count
props = []          # each entry: 1-element tensor holding the atom count
# NOTE(review): `props` stores the atom count, not the property values collected
# in prop_list / all_targets — confirm this is intended.

# Layout of the last axis of each edge tensor:
#   0                  distance between atoms i and j taken from the conformer
#   1                  1.0 by default, reset to 0.0 for pairs that get edge attributes
#   2 : 2+num_edge_f   edge attributes returned by mol2vec
num_edge_channels = 1 + num_edge_f + 1

# The counter is not called `id`, so the builtin id() is not shadowed at
# notebook scope.
for mol_idx, mol in enumerate(all_mols):
    mol_graph = mol2vec(mol)
    num_nodes = len(mol.GetAtoms())

    props.append(torch.Tensor([num_nodes]))

    # Read every atom position once, then get all pairwise distances in a single
    # broadcasted operation instead of re-fetching positions inside a double loop.
    conf = mol.GetConformer()
    coords = np.array(
        [[pos.x, pos.y, pos.z] for pos in map(conf.GetAtomPosition, range(num_nodes))],
        dtype=np.float64,
    ).reshape(-1, 3)  # reshape keeps the (0, 3) shape valid for an atom-less molecule
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)

    dm = torch.zeros((max_num_nodes, max_num_nodes, num_edge_channels)).float()
    dm[:num_nodes, :num_nodes, 0] = torch.from_numpy(dist)  # cast to float32 on assignment
    dm[:, :, 1] = 1.0  # covers padded rows/columns and the diagonal as well

    # NOTE(review): the pairing below assumes mol2vec lays out edge_index[0] as
    # [begin_0, end_0, begin_1, end_1, ...] and that edge_attr[k] belongs to the
    # k-th such pair — confirm against mol2graph_qm9.
    endpoints = mol_graph.edge_index[0]
    for bond_idx, (i, j) in enumerate(zip(endpoints[0::2], endpoints[1::2])):
        bond_attr = mol_graph.edge_attr[bond_idx]
        # Written in both directions so the edge tensor stays symmetric.
        dm[i, j, 2:2 + num_edge_f] = bond_attr
        dm[j, i, 2:2 + num_edge_f] = bond_attr
        dm[i, j, 1] = 0.0
        dm[j, i, 1] = 0.0

    # Node features: rows beyond the molecule's atoms stay zero (padding).
    node_f = torch.zeros(max_num_nodes, num_node_f).unsqueeze(0)
    for row, node_x in enumerate(mol_graph.x):
        node_f[0][row] = node_x

    edge_features.append(dm)
    mask.append((torch.arange(max_num_nodes) < num_nodes).unsqueeze(0))
    node_features.append(node_f)

    # Lightweight progress report.
    if mol_idx % 1000 == 0:
        print(mol_idx, num_nodes, max_num_nodes)

In [None]:
# Merge the per-molecule lists into single tensors along a leading molecule axis.
# The three lists whose entries already carry a leading axis (or are 1-element
# vectors) are concatenated; the edge tensors have no such axis and are stacked.
# Note that each name is rebound from a list to a tensor here, so re-running this
# cell without rebuilding the lists would feed the merged tensors back in.
node_features, mask, props = (
    torch.cat(chunks, dim=0) for chunks in (node_features, mask, props)
)
edge_features = torch.stack(edge_features, dim=0)

In [None]:
# Bundle the merged tensors under the keys the saved file will expose.
data_dict = {
    'node_features': node_features,
    'edge_features': edge_features,
    'mask': mask,
    'props': props
}

save_dir = root + "dataset/train_dataset/"

# save
save_path = save_dir + 'qm9_e3data/tensor_data.pkl'
# torch.save does not create missing parent directories, so make sure the
# target folder exists first; exist_ok keeps re-runs from failing.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(data_dict, save_path)