In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pickle
import tqdm
import pandas as pd
import selfies as sf
from rebadd.datautils import get_fragment_from_selfies

In [None]:
class DATACONFIGS:
    """Filesystem locations used by this preprocessing notebook.

    Attributes:
        input_dir: Directory containing the input data (``data/zinc15``).
        train_data_path: Path of ``zinc15_train.txt`` inside ``input_dir``.
        output_dir: Directory that receives the processed outputs
            (``processed_data/zinc15``); created on construction when it
            does not exist yet.
    """

    def __init__(self):
        ## input
        self.input_dir = os.path.join('data', 'zinc15')
        self.train_data_path = os.path.join(self.input_dir, 'zinc15_train.txt')
        ## output
        self.output_dir = os.path.join('processed_data', 'zinc15')
        # Create the directory rather than asserting on it: an ``assert`` is
        # stripped under ``python -O``, and when the directory is missing
        # every later write into it would fail.
        os.makedirs(self.output_dir, exist_ok=True)


dataconfigs = DATACONFIGS()

In [None]:
# Read the training file, then strip trailing whitespace (including the
# newline) from every line.
with open(dataconfigs.train_data_path) as fin:
    lines = list(fin)

data = list(map(str.rstrip, lines))

In [None]:
# Report how many raw entries were read and preview the first five.
print(f'Number of training data (raw): {len(data)}')
data[:5]

In [None]:
def make_selfies_data(smiles_iter):
    """Encode every item of *smiles_iter* with ``sf.encoder``.

    Items for which the encoder raises ``AssertionError`` or
    ``sf.EncoderError`` are skipped silently, so the result may be shorter
    than the input. Progress is shown through ``tqdm``.

    Args:
        smiles_iter: Iterable of inputs accepted by ``sf.encoder``.

    Returns:
        List of encoder outputs, in input order, for the items that encoded
        without one of the two errors above.
    """
    encoded = []

    for smiles in tqdm.tqdm(smiles_iter):
        try:
            selfies_string = sf.encoder(smiles)
        except AssertionError:
            # Best-effort conversion: drop entries the encoder rejects.
            continue
        except sf.EncoderError:
            continue
        encoded.append(selfies_string)

    return encoded

In [None]:
# Encode the raw entries; entries that fail to encode are dropped inside
# make_selfies_data, so this list can be shorter than `data`.
selfies_list = make_selfies_data(data)

In [None]:
# Report how many entries survived encoding and preview the first five.
print(f'Number of training data (selfies): {len(selfies_list)}')
selfies_list[:5]

In [None]:
# Save the encoded strings to selfies.csv in the output directory, one
# entry per line.
with open(os.path.join(dataconfigs.output_dir, 'selfies.csv'), 'w') as fout:
    fout.writelines(f'{entry}\n' for entry in selfies_list)

In [None]:
# Split the SELFIES strings into fragments.
# NOTE(review): presumably returns one fragment sequence per input string —
# confirm against rebadd.datautils.get_fragment_from_selfies.
fragments_list = get_fragment_from_selfies(selfies_list)

In [None]:
# Report the largest len() among the entries of fragments_list and preview
# the first two entries.
print(f'maxlen of fragments: {max(map(len, fragments_list))}')
fragments_list[:2]

In [None]:
# Serialize fragments_list to fragments_list.pkl in the output directory.
with open(os.path.join(dataconfigs.output_dir, 'fragments_list.pkl'), 'wb') as fout:
    pickle.dump(fragments_list, fout)

In [None]:
# Collect the distinct fragments found across every entry of fragments_list.
vocabs = set()
for fragments in tqdm.tqdm(fragments_list):
    # Grow the set in place: building a fresh set with union() on every
    # iteration would copy the whole accumulated vocabulary each time.
    vocabs.update(fragments)

# Turn the set into a sorted list so the vocabulary has a stable order.
vocabs = sorted(vocabs)

In [None]:
# Report the vocabulary size and display the full sorted vocabulary.
print(f'Number of vocabulary(unique fragments): {len(vocabs)}')
vocabs

In [None]:
# Save the sorted vocabulary to vocabulary.csv in the output directory, one
# fragment per line.
with open(os.path.join(dataconfigs.output_dir, 'vocabulary.csv'), 'w') as fout:
    fout.writelines(f'{token}\n' for token in vocabs)