### Split Data 

In [1]:
import numpy as np

In [2]:
def split_data(partition,total_num, random_seed = 1, save = True):
    """Shuffle the data ids ``1..total_num`` and cut them into four subsets.

    Args:
        partition: mapping holding the integer subset sizes under the keys
            "train", "validation" and "test_live".
        total_num: total number of samples; ids run from 1 to ``total_num``
            inclusive.
        random_seed: value passed to ``np.random.seed`` before shuffling.
            ``None`` or ``False`` skips seeding (the shuffle is then not
            reproducible); any other value, including ``0``, is used as seed.
        save: when true, the four id lists are stored in "split_infor.npz"
            (relative to the current working directory) under the keys
            ``train``, ``test``, ``validation`` and ``test_live``.

    Returns:
        Tuple ``(train_id, val_id, test_live_id, test_id)`` of lists. The
        test subset receives every id that is left after the first three
        subsets have been taken.
    """
    ### total data index: ids are 1-based, so the upper bound is inclusive ###
    data_id_list = list(range(1, total_num + 1))
    ### get partition size for each type of data ###
    tran, valn, testln = (partition[key] for key in ("train", "validation", "test_live"))
    ### whether or not to use random seed (0 is a valid seed, so test identity
    ### against None/False instead of relying on truthiness)
    if random_seed is not None and random_seed is not False:
        np.random.seed(random_seed)
    np.random.shuffle(data_id_list)

    val_end = tran + valn
    test_live_end = val_end + testln
    train_id = data_id_list[:tran]
    val_id = data_id_list[tran:val_end]
    test_live_id = data_id_list[val_end:test_live_end]
    test_id = data_id_list[test_live_end:]
    if save:
        np.savez("split_infor.npz",train= train_id, test = test_id, validation = val_id, test_live = test_live_id)

    return train_id, val_id, test_live_id, test_id

In [3]:
### total number of data is 133885
### train: 99000, validation: 1000, test_live: 1000; every remaining id goes to test
total_num = 133885
partition = {
    "train": 99000,
    "validation": 1000,
    "test_live": 1000,
}
### random_seed=False -> the shuffle is not seeded; save keeps its default value
train_id, val_id, test_live_id, test_id = split_data(partition, total_num, random_seed=False)

### Generate tfrecord file for training, testing, validation and test_live

In [4]:
import tarfile
import tempfile
from tempfile import TemporaryDirectory
from ase.units import Hartree, eV, Bohr, Ang
import rdkit 
from rdkit import Chem 

import os
import shutil
import tensorflow as tf;

In [9]:
### archive from which write_tfrecord extracts "./<id>.sdf" and "./<id>.mmff.sdf"
tar = tarfile.open(name="qm9_mmff.tar.bz2")
### element symbol -> integer code stored in the "elements" feature
element_conversions = dict(H=1, C=6, O=8, N=7, F=9)

In [10]:
def _bytes_feature(value):
    """Wrap ``value`` (passed on as the ``value`` of a BytesList) in a tf.train.Feature."""
    payload = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=payload)

def _int64_feature(value):
    """Wrap ``value`` (passed on as the ``value`` of an Int64List) in a tf.train.Feature."""
    payload = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=payload)

def _float64_feature(value):
    """Wrap ``value`` (passed on as the ``value`` of a FloatList) in a tf.train.Feature.

    NOTE(review): the function name says "float64" but the container is
    ``tf.train.FloatList`` -- confirm the stored precision matches expectations.
    """
    payload = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=payload)

In [31]:
### use 10 molecules as examples ###
def write_tfrecord(outfile,data_list, max_records = 10):
    """Serialize molecules from the module-level ``tar`` archive to a TFRecord file.

    For every selected id the members "./<id>.sdf" and "./<id>.mmff.sdf" are
    extracted into a temporary directory, parsed, and written as one
    ``tf.train.Example`` with four byte-string features:

    * ``elements``      -- int64 codes looked up in ``element_conversions``
    * ``positions``     -- float32 coordinates from the ".sdf" file, flattened
    * ``mmffpositions`` -- float32 coordinates from the ".mmff.sdf" file, flattened
    * ``targets``       -- float32 values parsed from the first line of the ".sdf" file

    Args:
        outfile: path of the TFRecord file to create.
        data_list: sliceable sequence of molecule ids.
        max_records: number of leading ids of ``data_list`` to process. The
            default of 10 keeps the original example-sized run; pass ``None``
            to process the whole list.
    """
    writer = tf.python_io.TFRecordWriter(outfile)
    try:
        for data in data_list[:max_records]:
            ### save into temporary directory which will be deleted automatically after exit ###
            with TemporaryDirectory() as temp_dir:
                raw_path = os.path.join(temp_dir, 'qm9_mmff_file')
                print(raw_path)
                tmpsdf ="./" + str(data) + ".sdf"
                tmpmmffsdf = "./" + str(data) + ".mmff.sdf"
                tar.extract(tmpsdf, path = raw_path)
                tar.extract(tmpmmffsdf, path = raw_path)
                ### read through explicit paths instead of os.chdir, so a parsing
                ### error cannot leave the process inside a directory that is
                ### about to be deleted
                with open(os.path.join(raw_path, tmpsdf),"r") as f:
                    lines = f.readlines()
                with open(os.path.join(raw_path, tmpmmffsdf),"r") as fm:
                    linesm = fm.readlines()

            targets = [float(i) for i in lines[0].split()]
            ### using atom_num to get element number and coordinates ####
            atom_num = int(lines[3].split()[0])
            print(atom_num)
            atom_fields = [line.split() for line in lines[4:4+atom_num]]
            mmff_fields = [line.split() for line in linesm[4:4+atom_num]]
            elements = [element_conversions[fields[3]] for fields in atom_fields]
            positions = [fields[0:3] for fields in atom_fields]
            mmffpositions = [fields[0:3] for fields in mmff_fields]

            elements = np.array(elements).astype(np.int64)
            targets = np.array(targets).astype(np.float32)
            positions = np.array(positions).astype(np.float32)
            mmffpositions = np.array(mmffpositions).astype(np.float32)

            ### tobytes() produces the same bytes as the deprecated tostring() ###
            newfeatures = {'elements' : _bytes_feature([elements.tobytes()]),
                           'positions' : _bytes_feature([positions.ravel().tobytes()]),
                           'mmffpositions' : _bytes_feature([mmffpositions.ravel().tobytes()]),
                           'targets' : _bytes_feature([targets.tobytes()])}

            example = tf.train.Example(features=tf.train.Features(feature=newfeatures))
            writer.write(example.SerializeToString())
    finally:
        ### always release the output file, even when a molecule fails to parse ###
        writer.close()



In [33]:
### write the leading training ids to an example TFRecord file ###
write_tfrecord(outfile="test_train.tfrecord", data_list=train_id)

/var/folders/l2/vbjcs9zx69722gqn21_25kg80000gp/T/tmpobl5a0jk/qm9_mmff_file
22
/var/folders/l2/vbjcs9zx69722gqn21_25kg80000gp/T/tmp9pzsyzak/qm9_mmff_file
15
/var/folders/l2/vbjcs9zx69722gqn21_25kg80000gp/T/tmpiwdlru0_/qm9_mmff_file
15
/var/folders/l2/vbjcs9zx69722gqn21_25kg80000gp/T/tmpzmsjbq3g/qm9_mmff_file
13
/var/folders/l2/vbjcs9zx69722gqn21_25kg80000gp/T/tmp5_zyxdzz/qm9_mmff_file
19
/var/folders/l2/vbjcs9zx69722gqn21_25kg80000gp/T/tmp2l72fbps/qm9_mmff_file
21
/var/folders/l2/vbjcs9zx69722gqn21_25kg80000gp/T/tmpe9pj6t8w/qm9_mmff_file
18
/var/folders/l2/vbjcs9zx69722gqn21_25kg80000gp/T/tmp4a_smoou/qm9_mmff_file
20
/var/folders/l2/vbjcs9zx69722gqn21_25kg80000gp/T/tmpo1hpophy/qm9_mmff_file
13
/var/folders/l2/vbjcs9zx69722gqn21_25kg80000gp/T/tmpyg9uvad_/qm9_mmff_file
17


In [43]:
### Check File: read back the TFRecord file written by write_tfrecord above ###
record = tf.data.TFRecordDataset("test_train.tfrecord")
def _parser_function(record):
    """Decode one serialized Example into a dict of tensors.

    Every feature is read as a scalar string and decoded with
    ``tf.decode_raw``: "elements" as int64, the other three as float32.
    "positions" and "mmffpositions" are additionally reshaped to ``[-1, 3]``.
    """
    decode_as = {"elements": tf.int64,
                 "positions": tf.float32,
                 "mmffpositions": tf.float32,
                 "targets": tf.float32}
    spec = {name: tf.FixedLenFeature([], tf.string) for name in decode_as}
    raw = tf.parse_single_example(record, spec)

    coordinate_keys = ("positions", "mmffpositions")
    decoded = {}
    for name in raw.keys():
        tensor = tf.decode_raw(raw[name], decode_as[name])
        if name in coordinate_keys:
            ### one row of three values per atom ###
            tensor = tf.reshape(tensor, [-1, 3])
        decoded[name] = tensor
    return decoded

### parse -> repeat indefinitely -> pad each batch of 10 to a common length ###
record = (
    record
    .map(_parser_function)
    .repeat()
    .padded_batch(
        batch_size=10,
        padded_shapes={
            "elements": [None],
            "positions": [None, 3],
            "mmffpositions": [None, 3],
            "targets": [None],
        },
        padding_values=None,
    )
)
iterator = record.make_initializable_iterator()
features_new = iterator.get_next()

### fetch one batch and show the coordinates of its first entry ###
with tf.Session() as sess:
    sess.run(iterator.initializer)
    features = sess.run(features_new)
    print(features["positions"][0])

[[-1.5030e-01  1.5706e+00  3.5000e-02]
 [-2.0000e-02  3.5600e-02  3.3000e-03]
 [-1.4249e+00 -5.8550e-01  4.2000e-03]
 [ 7.4270e-01 -4.1970e-01 -1.2611e+00]
 [ 2.1783e+00  1.2230e-01 -1.3285e+00]
 [ 2.8857e+00  9.0000e-03 -4.7000e-03]
 [ 2.2528e+00 -2.4910e-01  1.1392e+00]
 [ 7.6270e-01 -4.5090e-01  1.2462e+00]
 [ 8.2200e-01  2.0606e+00  1.3580e-01]
 [-6.2530e-01  1.9433e+00 -8.7960e-01]
 [-7.6860e-01  1.8873e+00  8.8230e-01]
 [-2.0179e+00 -2.2300e-01 -8.4320e-01]
 [-1.3761e+00 -1.6779e+00 -6.5100e-02]
 [-1.9680e+00 -3.3250e-01  9.2200e-01]
 [ 7.7890e-01 -1.5170e+00 -1.2642e+00]
 [ 1.8550e-01 -1.2350e-01 -2.1583e+00]
 [ 2.7425e+00 -4.1810e-01 -2.0994e+00]
 [ 2.1791e+00  1.1722e+00 -1.6569e+00]
 [ 3.9636e+00  1.5440e-01 -4.6000e-03]
 [ 2.8226e+00 -3.3030e-01  2.0624e+00]
 [ 3.8400e-01  6.0800e-02  2.1417e+00]
 [ 5.5570e-01 -1.5184e+00  1.4167e+00]]


In [17]:
### first ten training ids, for comparison with the records written above ###
print(train_id[:10])

[6308, 91703, 101162, 13777, 5246, 86608, 103353, 118248, 96799, 12700]
