In [2]:
import pandas as pd
df = pd.read_csv('/data/ongh0068/l1000/pyg_output_playground/train/processed_file_paths.csv')

In [5]:
import os
os.chdir('/data/ongh0068/l1000/moler_reference')

In [71]:
from torch_geometric.data import Dataset
import torch
import numpy as np
import os
import gzip
import pickle
from molecule_generation.chem.motif_utils import get_motif_type_to_node_type_index_map
from torch_geometric.data import Data

def get_motif_type_to_node_type_index_map(
    motif_vocabulary, num_atom_types
):
    """Helper to construct a mapping from motif type to shifted node type."""

    return {
        motif: num_atom_types + motif_type
        for motif, motif_type in motif_vocabulary.vocabulary.items()
    }


class MolerDataset(Dataset):
    def __init__(
        self, 
        root, 
        raw_moler_trace_dataset_parent_folder, # absolute path 
        output_pyg_trace_dataset_parent_folder, # absolute path
        split = 'train',
        transform=None, 
        pre_transform=None, 
    ):
        self._processed_file_paths = None
        self._transform = transform 
        self._pre_transform = pre_transform
        self._raw_moler_trace_dataset_parent_folder = raw_moler_trace_dataset_parent_folder
        self._output_pyg_trace_dataset_parent_folder = output_pyg_trace_dataset_parent_folder
        self._split = split
        self.load_metadata()
        super().__init__(root, transform, pre_transform)
        

    @property
    def raw_file_names(self):
        """
        Raw generation trace files output from the preprocess function of the cli. These are zipped pickle
        files. This is the actual file name without the parent folder.
        """
        raw_pkl_file_folders = [folder for folder in os.listdir(self._raw_moler_trace_dataset_parent_folder) if folder.startswith(self._split)]

        assert len(raw_pkl_file_folders) > 0, f'{self._raw_moler_trace_dataset_parent_folder} does not contain {self._split} files.'
        
        raw_generation_trace_files = []
        for folder in raw_pkl_file_folders:
            for pkl_file in os.listdir(os.path.join(self._raw_moler_trace_dataset_parent_folder, folder)):
                raw_generation_trace_files.append(os.path.join(self._raw_moler_trace_dataset_parent_folder, folder, pkl_file))
        return raw_generation_trace_files

    @property
    def processed_file_names(self):
        """Processed generation trace objects that are stored as .pt files"""
        processed_file_paths_folder = os.path.join(self._output_pyg_trace_dataset_parent_folder, self._split)
        if self._processed_file_paths is not None:
            return self._processed_file_paths
        
        if not os.path.exists(processed_file_paths_folder):
            os.mkdir(processed_file_paths_folder)
            
        # After processing, the file paths will be saved in the csv file 
        processed_file_paths_csv = os.path.join(processed_file_paths_folder, 'processed_file_paths.csv')

        if not os.path.exists(processed_file_paths_csv):
            is_csv_generated = self._generate_processed_file_paths_csv(processed_file_paths_folder, processed_file_paths_csv)
            if not is_csv_generated:
                return []
        self._processed_file_paths = pd.read_csv(processed_file_paths_csv)['file_names'].tolist()
        return self._processed_file_paths 

    def _generate_processed_file_paths_csv(self, processed_file_paths_folder, processed_file_paths_csv):
        
        file_paths = [os.path.join(processed_file_paths_folder, file_path) for file_path in os.listdir(processed_file_paths_folder)]
        if len(file_paths) > 0:
            df = pd.DataFrame(file_paths, columns = ['file_names'])
            df.to_csv(processed_file_paths_csv, index = False)
            return True
        else:
            print('No processed files found!')

    @property
    def processed_file_names_size(self):
        return len(self.processed_file_names)
    
    @property 
    def metadata(self):
        return self._metadata

    @property
    def node_type_index_to_string(self):
        return self._node_type_index_to_string
    
    @property 
    def num_node_types(self):
        return len(self.node_type_index_to_string)

    def node_type_to_index(self, node_type):
        return self._atom_type_featuriser.type_name_to_index(node_type)

    def node_types_to_indices(self, node_types):
        """Convert list of string representations into list of integer indices."""
        return [self.node_type_to_index(node_type) for node_type in node_types]

    def node_types_to_multi_hot(self, node_types):
        """Convert between string representation to multi hot encoding of correct node types.

        Note: implemented here for backwards compatibility only.
        """
        correct_indices = self.node_types_to_indices(node_types)
        multihot = np.zeros(shape=(self.num_node_types,), dtype=np.float32)
        for idx in correct_indices:
            multihot[idx] = 1.0
        return multihot
    
    def node_type_to_index(self, node_type):
        motif_node_type_index = self._motif_to_node_type_index.get(node_type)

        if motif_node_type_index is not None:
            return motif_node_type_index
        else:
            return self._atom_type_featuriser.type_name_to_index(node_type)
    
    def load_metadata(self):
        metadata_file_path = os.path.join(self._raw_moler_trace_dataset_parent_folder, 'metadata.pkl.gz')
        
        with gzip.open(metadata_file_path, 'rb') as f:
             self._metadata = pickle.load(f)
        
        self._atom_type_featuriser = next(
            featuriser
            for featuriser in self._metadata["feature_extractors"]
            if featuriser.name == "AtomType"
        )
        
        self._node_type_index_to_string = self._atom_type_featuriser.index_to_atom_type_map.copy()
        self._motif_vocabulary = self.metadata.get("motif_vocabulary")

        if self._motif_vocabulary is not None:
            self._motif_to_node_type_index = get_motif_type_to_node_type_index_map(
                motif_vocabulary=self._motif_vocabulary,
                num_atom_types=len(self._node_type_index_to_string),
            )

            for motif, node_type in self._motif_to_node_type_index.items():
                self._node_type_index_to_string[node_type] = motif
        else:
            self._motif_to_node_type_index = {}
        

    def process(self):
        """Convert raw generation traces into individual .pt files for each of the trace steps."""
        # only call process if it was not called before
        if self.processed_file_names_size > 0:
            pass
        else:
            
            for pkl_file_path in self.raw_file_names:
                generation_steps = self._convert_data_shard_to_list_of_trace_steps(pkl_file_path)
                
                for molecule_idx, molecule_gen_steps in generation_steps:
                    
                    for step_idx, step in enumerate(molecule_gen_steps):
                        file_name = f'{pkl_file_path.split("/")[-1].split(".")[0]}_mol_{molecule_idx}_step_{step_idx}.pt'
                        file_path = os.path.join(self._output_pyg_trace_dataset_parent_folder, self._split, file_name)
                        torch.save(step, file_path)
                        print(f'Processing {molecule_idx}, step {step_idx}')
            print(f'{self.processed_file_names_size} files generated')         
            

    def _convert_data_shard_to_list_of_trace_steps(self, pkl_file_path):
        # TODO: multiprocessing to speed this up
        generation_steps = []
        
        with gzip.open(pkl_file_path, 'rb') as f:
            molecules = pickle.load(f)
            for molecule_idx, molecule in enumerate(molecules): 
                generation_steps += [(molecule_idx, self._extract_generation_steps(molecule))]
        
        return generation_steps
            
    def _extract_generation_steps(self, molecule):
        """Packages each generation step of each molecule into a pyg Data object."""
        molecule_gen_steps = []
        molecule_property_values = dict(molecule.graph_property_values)
        for gen_step in molecule:
            gen_step_features = {}
            gen_step_features['x'] = gen_step.partial_node_features
            gen_step_features['focus_node'] = gen_step.focus_node

            # have an edge type attribute to tell apart each of the 3 bond types
            edge_indexes = []
            edge_types= []
            for i, adj_list in enumerate(gen_step.partial_adjacency_lists):
                if len(adj_list) != 0:
                    edge_index = torch.tensor(adj_list).T
                    edge_indexes += [edge_index]
                    edge_types += [i]*len(adj_list)

            gen_step_features['edge_index'] = torch.cat(edge_indexes, 1) if len(edge_indexes) > 0 else torch.tensor(edge_indexes)
            gen_step_features['edge_type'] = torch.tensor(edge_types)
            gen_step_features['correct_edge_choices'] = gen_step.correct_edge_choices

            num_correct_edge_choices = np.sum(gen_step.correct_edge_choices)
            gen_step_features['num_correct_edge_choices'] = num_correct_edge_choices
            gen_step_features['stop_node_label'] = int(num_correct_edge_choices == 0)
            gen_step_features['valid_edge_choices'] = gen_step.valid_edge_choices
            gen_step_features["correct_edge_types"] = gen_step.correct_edge_types
            gen_step_features["partial_node_categorical_features"] = gen_step.partial_node_categorical_features
            if gen_step.correct_attachment_point_choice is not None:
                gen_step_features["correct_attachment_point_choice"] = list(gen_step.valid_attachment_point_choices).index(gen_step.correct_attachment_point_choice)
            else:
                gen_step_features["correct_attachment_point_choice"] = []
            gen_step_features["valid_attachment_point_choices"] = gen_step.valid_attachment_point_choices

            # And finally, the correct node type choices. Here, we have an empty list of
            # correct choices for all steps where we didn't choose a node, so we skip that:
            if gen_step.correct_node_type_choices is not None:
                gen_step_features["correct_node_type_choices"] = self.node_types_to_multi_hot(gen_step.correct_node_type_choices)
            else:
                gen_step_features["correct_node_type_choices"] = []
            gen_step_features['correct_first_node_type_choices'] = self.node_types_to_multi_hot(molecule.correct_first_node_type_choices)

            # Add graph_property_values
            gen_step_features = {**gen_step_features, **molecule_property_values}
        
            molecule_gen_steps += [gen_step_features]
        molecule_gen_steps = self._to_tensor_moler(molecule_gen_steps)
        return [Data(**step) for step in molecule_gen_steps]
    
    def _to_tensor_moler(self, molecule_gen_steps):
        for i in range(len(molecule_gen_steps)):
            for k,v in molecule_gen_steps[i].items():
                molecule_gen_steps[i][k] = torch.tensor(molecule_gen_steps[i][k])
        return molecule_gen_steps
        
    def len(self):
        return self.processed_file_names_size

    def get(self, idx):
        file_path = self.processed_file_names[idx]
        data = torch.load(file_path)
        return data

In [72]:
dataset = MolerDataset(
    root = '/data/ongh0068', 
    raw_moler_trace_dataset_parent_folder = '/data/ongh0068/l1000/trace_playground',
    output_pyg_trace_dataset_parent_folder = '/data/ongh0068/l1000/pyg_output_playground',
    split = 'train',
)

No processed files found!
No processed files found!


Processing...
  molecule_gen_steps[i][k] = torch.tensor(molecule_gen_steps[i][k])


Processing 0, step 0
Processing 0, step 1
Processing 0, step 2
Processing 0, step 3
Processing 0, step 4
Processing 0, step 5
Processing 0, step 6
Processing 0, step 7
Processing 0, step 8
Processing 0, step 9
Processing 0, step 10
Processing 0, step 11
Processing 0, step 12
Processing 0, step 13
Processing 0, step 14
Processing 0, step 15
Processing 0, step 16
Processing 0, step 17
Processing 0, step 18
Processing 0, step 19
Processing 0, step 20
Processing 0, step 21
Processing 0, step 22
Processing 0, step 23
Processing 0, step 24
Processing 0, step 25
Processing 0, step 26
Processing 0, step 27
Processing 0, step 28
Processing 0, step 29
Processing 0, step 30
Processing 0, step 31
Processing 0, step 32
Processing 0, step 33
Processing 0, step 34
Processing 0, step 35
Processing 0, step 36
Processing 0, step 37
Processing 0, step 38
Processing 0, step 39
Processing 0, step 40
Processing 0, step 41
Processing 0, step 42
Processing 0, step 43
Processing 1, step 0
Processing 1, step 1


Processing 13, step 30
Processing 13, step 31
Processing 13, step 32
Processing 13, step 33
Processing 13, step 34
Processing 13, step 35
Processing 13, step 36
Processing 14, step 0
Processing 14, step 1
Processing 14, step 2
Processing 14, step 3
Processing 14, step 4
Processing 14, step 5
Processing 14, step 6
Processing 14, step 7
Processing 14, step 8
Processing 14, step 9
Processing 14, step 10
Processing 14, step 11
Processing 14, step 12
Processing 14, step 13
Processing 14, step 14
Processing 14, step 15
Processing 14, step 16
Processing 14, step 17
Processing 14, step 18
Processing 14, step 19
Processing 14, step 20
Processing 14, step 21
Processing 14, step 22
Processing 14, step 23
Processing 14, step 24
Processing 14, step 25
Processing 14, step 26
Processing 14, step 27
Processing 15, step 0
Processing 15, step 1
Processing 15, step 2
Processing 15, step 3
Processing 15, step 4
Processing 15, step 5
Processing 15, step 6
Processing 15, step 7
Processing 15, step 8
Process

Processing 31, step 5
Processing 31, step 6
Processing 31, step 7
Processing 31, step 8
Processing 31, step 9
Processing 31, step 10
Processing 31, step 11
Processing 31, step 12
Processing 31, step 13
Processing 31, step 14
Processing 31, step 15
Processing 31, step 16
Processing 31, step 17
Processing 31, step 18
Processing 31, step 19
Processing 31, step 20
Processing 31, step 21
Processing 31, step 22
Processing 31, step 23
Processing 31, step 24
Processing 31, step 25
Processing 31, step 26
Processing 31, step 27
Processing 31, step 28
Processing 31, step 29
Processing 31, step 30
Processing 31, step 31
Processing 31, step 32
Processing 31, step 33
Processing 32, step 0
Processing 32, step 1
Processing 32, step 2
Processing 32, step 3
Processing 32, step 4
Processing 32, step 5
Processing 32, step 6
Processing 32, step 7
Processing 32, step 8
Processing 32, step 9
Processing 32, step 10
Processing 32, step 11
Processing 32, step 12
Processing 32, step 13
Processing 32, step 14
Pro

Processing 48, step 20
Processing 48, step 21
Processing 48, step 22
Processing 48, step 23
Processing 48, step 24
Processing 48, step 25
Processing 48, step 26
Processing 48, step 27
Processing 48, step 28
Processing 48, step 29
Processing 48, step 30
Processing 48, step 31
Processing 48, step 32
Processing 48, step 33
Processing 48, step 34
Processing 48, step 35
Processing 48, step 36
Processing 48, step 37
Processing 48, step 38
Processing 48, step 39
Processing 48, step 40
Processing 48, step 41
Processing 49, step 0
Processing 49, step 1
Processing 49, step 2
Processing 49, step 3
Processing 49, step 4
Processing 49, step 5
Processing 49, step 6
Processing 49, step 7
Processing 49, step 8
Processing 49, step 9
Processing 49, step 10
Processing 49, step 11
Processing 49, step 12
Processing 50, step 0
Processing 50, step 1
Processing 50, step 2
Processing 50, step 3
Processing 50, step 4
Processing 50, step 5
Processing 50, step 6
Processing 50, step 7
Processing 50, step 8
Process

Processing 62, step 2
Processing 62, step 3
Processing 62, step 4
Processing 62, step 5
Processing 62, step 6
Processing 62, step 7
Processing 62, step 8
Processing 62, step 9
Processing 62, step 10
Processing 62, step 11
Processing 62, step 12
Processing 62, step 13
Processing 62, step 14
Processing 62, step 15
Processing 62, step 16
Processing 62, step 17
Processing 62, step 18
Processing 62, step 19
Processing 62, step 20
Processing 62, step 21
Processing 62, step 22
Processing 62, step 23
Processing 62, step 24
Processing 62, step 25
Processing 62, step 26
Processing 62, step 27
Processing 62, step 28
Processing 62, step 29
Processing 62, step 30
Processing 62, step 31
Processing 62, step 32
Processing 62, step 33
Processing 62, step 34
Processing 62, step 35
Processing 62, step 36
Processing 62, step 37
Processing 62, step 38
Processing 62, step 39
Processing 62, step 40
Processing 62, step 41
Processing 62, step 42
Processing 62, step 43
Processing 62, step 44
Processing 62, step

Processing 76, step 24
Processing 76, step 25
Processing 76, step 26
Processing 76, step 27
Processing 76, step 28
Processing 76, step 29
Processing 76, step 30
Processing 76, step 31
Processing 76, step 32
Processing 76, step 33
Processing 76, step 34
Processing 76, step 35
Processing 76, step 36
Processing 76, step 37
Processing 76, step 38
Processing 76, step 39
Processing 76, step 40
Processing 76, step 41
Processing 76, step 42
Processing 76, step 43
Processing 76, step 44
Processing 77, step 0
Processing 77, step 1
Processing 77, step 2
Processing 77, step 3
Processing 77, step 4
Processing 77, step 5
Processing 77, step 6
Processing 77, step 7
Processing 77, step 8
Processing 77, step 9
Processing 77, step 10
Processing 77, step 11
Processing 77, step 12
Processing 77, step 13
Processing 77, step 14
Processing 77, step 15
Processing 77, step 16
Processing 77, step 17
Processing 77, step 18
Processing 77, step 19
Processing 77, step 20
Processing 77, step 21
Processing 77, step 2

Processing 94, step 0
Processing 94, step 1
Processing 94, step 2
Processing 94, step 3
Processing 94, step 4
Processing 94, step 5
Processing 94, step 6
Processing 94, step 7
Processing 94, step 8
Processing 94, step 9
Processing 94, step 10
Processing 94, step 11
Processing 94, step 12
Processing 94, step 13
Processing 94, step 14
Processing 94, step 15
Processing 94, step 16
Processing 94, step 17
Processing 94, step 18
Processing 94, step 19
Processing 94, step 20
Processing 94, step 21
Processing 94, step 22
Processing 94, step 23
Processing 94, step 24
Processing 94, step 25
Processing 94, step 26
Processing 94, step 27
Processing 94, step 28
Processing 94, step 29
Processing 94, step 30
Processing 95, step 0
Processing 95, step 1
Processing 95, step 2
Processing 95, step 3
Processing 95, step 4
Processing 95, step 5
Processing 95, step 6
Processing 95, step 7
Processing 95, step 8
Processing 95, step 9
Processing 95, step 10
Processing 95, step 11
Processing 95, step 12
Processi

Done!


In [73]:
dataset

MolerDataset(2384)

In [74]:
pprint_pyg_obj(dataset[0])

x: torch.Size([4, 32])
edge_index: torch.Size([2, 6])
focus_node: torch.Size([])
edge_type: torch.Size([6])
correct_edge_choices: torch.Size([0])
num_correct_edge_choices: torch.Size([])
stop_node_label: torch.Size([])
valid_edge_choices: torch.Size([0, 2])
correct_edge_types: torch.Size([0, 3])
partial_node_categorical_features: torch.Size([4])
correct_attachment_point_choice: torch.Size([0])
valid_attachment_point_choices: torch.Size([0])
correct_node_type_choices: torch.Size([139])
correct_first_node_type_choices: torch.Size([139])
sa_score: torch.Size([])
clogp: torch.Size([])
mol_weight: torch.Size([])
qed: torch.Size([])
bertz: torch.Size([])


In [58]:
from torch_geometric.loader import DataLoader
loader = DataLoader(dataset, batch_size=16, shuffle=False, follow_batch = [
    'correct_edge_choices',
    'correct_edge_types',
    'valid_edge_choices',
    'valid_attachment_point_choices',
    'correct_attachment_point_choice',
    'correct_node_type_choices'
])

In [59]:
def pprint_pyg_obj(batch):
    for key in vars(batch)['_store'].keys():
        if key.startswith('_'):
            continue
        print(f'{key}: {batch[key].shape}')
for batch in loader:
    pprint_pyg_obj(batch)
    break

x: torch.Size([126, 32])
edge_index: torch.Size([2, 212])
focus_node: torch.Size([16])
edge_type: torch.Size([212])
correct_edge_choices: torch.Size([53])
correct_edge_choices_batch: torch.Size([53])
correct_edge_choices_ptr: torch.Size([17])
num_correct_edge_choices: torch.Size([16])
stop_node_label: torch.Size([16])
valid_edge_choices: torch.Size([53, 2])
valid_edge_choices_batch: torch.Size([53])
valid_edge_choices_ptr: torch.Size([17])
correct_edge_types: torch.Size([9, 3])
correct_edge_types_batch: torch.Size([9])
correct_edge_types_ptr: torch.Size([17])
partial_node_categorical_features: torch.Size([126])
correct_attachment_point_choice: torch.Size([0])
correct_attachment_point_choice_batch: torch.Size([0])
correct_attachment_point_choice_ptr: torch.Size([17])
valid_attachment_point_choices: torch.Size([0])
valid_attachment_point_choices_batch: torch.Size([0])
valid_attachment_point_choices_ptr: torch.Size([17])
correct_node_type_choices: torch.Size([1112])
correct_node_type_choi

In [76]:
# Get some information out from the dataset:
next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
atom_type_nums = [
    next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
    for type_idx in range(dataset.num_node_types)
]

Counter({'C': 473,
         'N': 89,
         'C1=CC=CC=C1': 92,
         'O': 105,
         'S': 10,
         'None': 100,
         'CCN(C)C': 1,
         'NC=O': 8,
         'C1=CC2=C(CCC2)S1': 2,
         'Cl': 16,
         'C1CCNCC1': 12,
         'C1=CC=C2N=CC=CC2=C1': 2,
         'ClC(Cl)Cl': 1,
         'C1=CNN=C1': 5,
         'N=CO': 4,
         'C1=CN2N=CC=C2N=C1': 2,
         'OC(F)F': 2,
         'C1OC2CNC1C2': 1,
         'C1=CSC=N1': 3,
         'O=[N+][O-]': 7,
         'C1CC1': 8,
         'COC=O': 1,
         'C1=CC=NC=C1': 12,
         'FC(F)F': 4,
         'O=C(O)CS': 1,
         'C1=NC=NC2=C1N=CN2': 1,
         'CCNC(C)=O': 1,
         'C1CCNC1': 9,
         'CNC(=O)CS': 1,
         'C1=COC=C1': 2,
         'CCCO': 2,
         'NC(N)=O': 1,
         'CN(C)C(N)=O': 1,
         'C1NC2C3C4CC5C3C1C1C5C4C21': 1,
         'C1=NC=C2CCCCC2=N1': 2,
         'C1CCC2CCCC2C1': 1,
         'C1COCCN1': 6,
         'C1CCC2C(C1)CCC1C3CCCC3CCC21': 1,
         'C1CCOCC1': 4,
        