In [1]:
import os
import pickle 
import numpy as np
import tensorflow as tf

from tqdm import tqdm
from glob import glob


def generate_vocab(corpus):
    token_index = 4
    stoi, itos = {}, {}
    
    stoi["<PAD>"], itos[0] = 0, "<PAD>"
    stoi["<CLS>"], itos[1] = 1, "<CLS>"   
    stoi["<SEP>"], itos[2] = 2, "<SEP>"
    stoi["<MASK>"], itos[3] = 3, "<MASK>"

    for line in tqdm(corpus):
        for token in line:
            if token not in stoi:
                itos[token_index] = token
                stoi[token] = token_index
                token_index += 1

    return stoi, itos


pickle_path = "data/molecule_net/molecule_total.pickle"

with open(pickle_path, "rb") as f:
        data = pickle.load(f)

molecule_stoi, molecule_itos = generate_vocab(data)

100%|██████████| 93673016/93673016 [03:16<00:00, 476307.03it/s]


In [3]:
index_ = int(len(data) * 0.1)

train_data = data[:index_*8]
valid_data = data[index_*8:index_*9]
test_data = data[index_*9:]

print(f"tain_data: {len(train_data)} valid_data: {len(valid_data)} test_data: {len(test_data)}")

tain_data: 74938408 valid_data: 9367301 test_data: 9367307


In [4]:
def _bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()

    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def serialize_example(data, stoi, output_path):
    writer = tf.io.TFRecordWriter(output_path)
    
    for smiles in tqdm(data):
        feature = {}
        
        token = [1] + [stoi[s] for s in smiles] + [2]
        feature['smiles'] = _bytes_feature(bytes(smiles, "utf-8"))
        feature['token'] = _float_feature(token)
        
        features = tf.train.Features(feature=feature)
        example = tf.train.Example(features=features)
        serialized = example.SerializeToString()
        
        writer.write(serialized)

In [6]:
train_output_path = "data/molecule_net/molecule_train.tfrecord"
valid_output_path = "data/molecule_net/molecule_valid.tfrecord"
test_output_path = "data/molecule_net/molecule_test.tfrecord"

serialize_example(train_data, molecule_stoi, train_output_path)
serialize_example(valid_data, molecule_stoi, valid_output_path)
serialize_example(test_data, molecule_stoi, test_output_path)

100%|██████████| 74938408/74938408 [1:07:35<00:00, 18480.44it/s]
100%|██████████| 9367301/9367301 [08:24<00:00, 18569.28it/s]
100%|██████████| 9367307/9367307 [08:25<00:00, 18543.45it/s]


In [8]:
with open("data/molecule_net/molecule_tokenizer", "wb") as f:
    pickle.dump([molecule_stoi, molecule_itos], f)

In [42]:
def _parse_tfrecord():
    def parse_tfrecord(tfrecord):
        features = {
            'smiles': tf.io.FixedLenFeature([], tf.string),
            'token':  tf.io.FixedLenFeature([], tf.float32)
        }
        
        x = tf.io.parse_single_example(tfrecord, features)

        return x
    
    return parse_tfrecord


def load_tfrecord_dataset(tfrecord_name, batch_size, shuffle=True, buffer_size=10240):
    """load dataset from tfrecord"""
    raw_dataset = tf.data.TFRecordDataset(tfrecord_name)
    raw_dataset = raw_dataset.repeat()
    
    if shuffle:
        raw_dataset = raw_dataset.shuffle(buffer_size=buffer_size)
    dataset = raw_dataset.map(
        _parse_tfrecord(),
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    
    return dataset

In [43]:
ds = load_tfrecord_dataset("data/molecule_net/molecule_small.tfrecord", 1)
ds

Tensor("ParseSingleExample/ParseExample/ParseExampleV2:1", shape=(), dtype=float32)


<PrefetchDataset shapes: {smiles: (None,), token: (None,)}, types: {smiles: tf.string, token: tf.float32}>