In [1]:
import pandas as pd

# CSV listing the processed per-step .pt files of the train split
# (the `processed_file_paths.csv` that MolerDataset writes next to them).
PROCESSED_FILE_PATHS_CSV = '/data/ongh0068/l1000/pyg_output_playground/train/processed_file_paths.csv'
df = pd.read_csv(PROCESSED_FILE_PATHS_CSV)

In [2]:
import os

MOLER_REFERENCE_DIR = '/data/ongh0068/l1000/moler_reference'
# Switch the working directory to the MoLeR reference checkout; presumably so the
# `molecule_generation` package imported in the next cell resolves — confirm.
os.chdir(MOLER_REFERENCE_DIR)

In [3]:
from torch_geometric.data import Dataset
import torch
import numpy as np
import os
import gzip
import pickle
from molecule_generation.chem.motif_utils import get_motif_type_to_node_type_index_map
from torch_geometric.data import Data

def get_motif_type_to_node_type_index_map(
    motif_vocabulary, num_atom_types
):
    """
    Map each motif to its shifted node-type index.

    Motif node types are numbered after the ``num_atom_types`` atom types, so a
    motif whose vocabulary index is ``k`` maps to ``num_atom_types + k``.

    NOTE: this redefinition shadows the function of the same name imported from
    ``molecule_generation.chem.motif_utils`` in this cell.

    Args:
        motif_vocabulary: object exposing a ``vocabulary`` dict of motif -> index.
        num_atom_types: number of plain atom node types.

    Returns:
        Dict mapping motif -> shifted node-type index.
    """
    shifted_indices = {}
    for motif_name, motif_index in motif_vocabulary.vocabulary.items():
        shifted_indices[motif_name] = num_atom_types + motif_index
    return shifted_indices


class MolerDataset(Dataset):
    """
    PyG dataset over MoLeR generation traces, with one sample per generation step.

    Raw inputs live under ``raw_moler_trace_dataset_parent_folder``. They consist of
    a ``metadata.pkl.gz`` file plus one or more folders whose names start with the
    split name; each folder holds gzipped pickle shards output by MoLeR's
    preprocessing. ``process`` expands every generation step of every molecule
    into a ``Data`` object saved as a ``.pt`` file under
    ``output_pyg_trace_dataset_parent_folder/<split>``. The list of those files
    is cached in ``processed_file_paths.csv`` in the same folder.
    """

    def __init__(
        self,
        root,
        raw_moler_trace_dataset_parent_folder,  # absolute path
        output_pyg_trace_dataset_parent_folder,  # absolute path
        split='train',
        transform=None,
        pre_transform=None,
    ):
        self._processed_file_paths = None
        self._transform = transform
        self._pre_transform = pre_transform
        self._raw_moler_trace_dataset_parent_folder = raw_moler_trace_dataset_parent_folder
        self._output_pyg_trace_dataset_parent_folder = output_pyg_trace_dataset_parent_folder
        self._split = split
        # Load metadata before the base constructor runs: PyG's Dataset
        # constructor triggers `process`, which needs the node-type featuriser.
        self.load_metadata()
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """
        Absolute paths of the raw generation-trace shards for this split.

        These are the zipped pickle files output by the preprocess function of
        the CLI. They are collected from every folder in the raw parent folder
        whose name starts with the split name.
        """
        raw_pkl_file_folders = [
            folder
            for folder in os.listdir(self._raw_moler_trace_dataset_parent_folder)
            if folder.startswith(self._split)
        ]

        assert len(raw_pkl_file_folders) > 0, f'{self._raw_moler_trace_dataset_parent_folder} does not contain {self._split} files.'

        raw_generation_trace_files = []
        for folder in raw_pkl_file_folders:
            folder_path = os.path.join(self._raw_moler_trace_dataset_parent_folder, folder)
            for pkl_file in os.listdir(folder_path):
                raw_generation_trace_files.append(os.path.join(folder_path, pkl_file))
        return raw_generation_trace_files

    @property
    def processed_file_names(self):
        """
        Absolute paths of the processed per-step ``.pt`` files.

        The list is read from ``processed_file_paths.csv``, which is generated on
        first access if it is missing, and is then cached on the instance.
        Returns ``[]`` while no processed files exist, so the dataset is treated
        as not yet processed.
        """
        if self._processed_file_paths is not None:
            return self._processed_file_paths

        processed_file_paths_folder = os.path.join(self._output_pyg_trace_dataset_parent_folder, self._split)
        # makedirs (rather than mkdir) so that a missing parent output folder
        # does not raise.
        os.makedirs(processed_file_paths_folder, exist_ok=True)

        # After processing, the file paths will be saved in the csv file
        processed_file_paths_csv = os.path.join(processed_file_paths_folder, 'processed_file_paths.csv')

        if not os.path.exists(processed_file_paths_csv):
            is_csv_generated = self._generate_processed_file_paths_csv(processed_file_paths_folder, processed_file_paths_csv)
            if not is_csv_generated:
                return []
        self._processed_file_paths = pd.read_csv(processed_file_paths_csv)['file_names'].tolist()
        return self._processed_file_paths

    def _generate_processed_file_paths_csv(self, processed_file_paths_folder, processed_file_paths_csv):
        """
        Write the ``.pt`` files found in ``processed_file_paths_folder`` to a CSV.

        Returns:
            True if the CSV was written; False if no processed files exist yet.
        """
        # Only .pt files are samples, so any other file in the folder is skipped.
        # Sorting makes the dataset order deterministic, because os.listdir
        # returns entries in arbitrary order.
        file_paths = [
            os.path.join(processed_file_paths_folder, file_name)
            for file_name in sorted(os.listdir(processed_file_paths_folder))
            if file_name.endswith('.pt')
        ]
        if not file_paths:
            print('No processed files found!')
            return False
        df = pd.DataFrame(file_paths, columns=['file_names'])
        df.to_csv(processed_file_paths_csv, index=False)
        return True

    @property
    def processed_file_names_size(self):
        """Number of processed per-step files."""
        return len(self.processed_file_names)

    @property
    def metadata(self):
        """Metadata dict loaded from ``metadata.pkl.gz``."""
        return self._metadata

    @property
    def node_type_index_to_string(self):
        """Node-type index -> name map, covering atom types followed by motifs."""
        return self._node_type_index_to_string

    @property
    def num_node_types(self):
        """Total number of node types (atoms plus motifs)."""
        return len(self.node_type_index_to_string)

    def node_types_to_indices(self, node_types):
        """Convert list of string representations into list of integer indices."""
        return [self.node_type_to_index(node_type) for node_type in node_types]

    def node_types_to_multi_hot(self, node_types):
        """Convert between string representation to multi hot encoding of correct node types.

        Note: implemented here for backwards compatibility only.
        """
        correct_indices = self.node_types_to_indices(node_types)
        multihot = np.zeros(shape=(self.num_node_types,), dtype=np.float32)
        for idx in correct_indices:
            multihot[idx] = 1.0
        return multihot

    def node_type_to_index(self, node_type):
        """
        Map a node-type name to its index.

        Motif names resolve to their shifted motif index. Every other name is
        resolved by the AtomType featuriser.
        """
        motif_node_type_index = self._motif_to_node_type_index.get(node_type)

        if motif_node_type_index is not None:
            return motif_node_type_index
        return self._atom_type_featuriser.type_name_to_index(node_type)

    def load_metadata(self):
        """
        Load ``metadata.pkl.gz`` and build the node-type lookup tables.

        Sets the AtomType featuriser, the index -> node-type-name map (extended
        with motifs when a motif vocabulary is present), and the motif ->
        node-type-index map.
        """
        metadata_file_path = os.path.join(self._raw_moler_trace_dataset_parent_folder, 'metadata.pkl.gz')

        # NOTE(review): pickle executes arbitrary code on load; only use trusted
        # metadata files.
        with gzip.open(metadata_file_path, 'rb') as f:
            self._metadata = pickle.load(f)

        self._atom_type_featuriser = next(
            featuriser
            for featuriser in self._metadata["feature_extractors"]
            if featuriser.name == "AtomType"
        )

        self._node_type_index_to_string = self._atom_type_featuriser.index_to_atom_type_map.copy()
        self._motif_vocabulary = self.metadata.get("motif_vocabulary")

        if self._motif_vocabulary is not None:
            self._motif_to_node_type_index = get_motif_type_to_node_type_index_map(
                motif_vocabulary=self._motif_vocabulary,
                num_atom_types=len(self._node_type_index_to_string),
            )

            for motif, node_type in self._motif_to_node_type_index.items():
                self._node_type_index_to_string[node_type] = motif
        else:
            self._motif_to_node_type_index = {}

    def process(self):
        """Convert raw generation traces into individual .pt files for each of the trace steps."""
        # Only process if no processed files exist yet (e.g. from an earlier run).
        if self.processed_file_names_size > 0:
            return

        # This folder already exists: processed_file_names_size creates it.
        output_folder = os.path.join(self._output_pyg_trace_dataset_parent_folder, self._split)
        for pkl_file_path in self.raw_file_names:
            shard_name = os.path.basename(pkl_file_path).split(".")[0]
            generation_steps = self._convert_data_shard_to_list_of_trace_steps(pkl_file_path)

            for molecule_idx, molecule_gen_steps in generation_steps:
                for step_idx, step in enumerate(molecule_gen_steps):
                    file_name = f'{shard_name}_mol_{molecule_idx}_step_{step_idx}.pt'
                    torch.save(step, os.path.join(output_folder, file_name))
                    print(f'Processing {molecule_idx}, step {step_idx}')
        # This call generates processed_file_paths.csv now that files exist.
        print(f'{self.processed_file_names_size} files generated')

    def _convert_data_shard_to_list_of_trace_steps(self, pkl_file_path):
        """
        Load one gzipped pickle shard and extract the generation steps of each molecule.

        Returns:
            List of ``(molecule_idx, [Data, ...])`` pairs, with one pair per
            molecule in the shard.
        """
        # TODO: multiprocessing to speed this up
        generation_steps = []

        # NOTE(review): pickle executes arbitrary code on load; only use trusted shards.
        with gzip.open(pkl_file_path, 'rb') as f:
            molecules = pickle.load(f)
            for molecule_idx, molecule in enumerate(molecules):
                generation_steps.append((molecule_idx, self._extract_generation_steps(molecule)))

        return generation_steps

    def _extract_generation_steps(self, molecule):
        """Packages each generation step of each molecule into a pyg Data object."""
        molecule_gen_steps = []
        molecule_property_values = dict(molecule.graph_property_values)
        # Identical for every step of the molecule, so compute it only once.
        correct_first_node_type_choices = self.node_types_to_multi_hot(molecule.correct_first_node_type_choices)
        for gen_step in molecule:
            gen_step_features = {}
            gen_step_features['x'] = gen_step.partial_node_features
            gen_step_features['focus_node'] = gen_step.focus_node

            # have an edge type attribute to tell apart each of the 3 bond types
            edge_indexes = []
            edge_types = []
            for bond_type_idx, adj_list in enumerate(gen_step.partial_adjacency_lists):
                if len(adj_list) != 0:
                    edge_indexes.append(torch.tensor(adj_list).T)
                    edge_types += [bond_type_idx] * len(adj_list)

            if edge_indexes:
                gen_step_features['edge_index'] = torch.cat(edge_indexes, 1)
            else:
                # Edge-less steps need an integer edge_index of shape [2, 0].
                # torch.tensor([]) would give a float tensor of shape [0] and
                # promote the batched edge_index to float.
                gen_step_features['edge_index'] = torch.empty((2, 0), dtype=torch.long)
            # Explicit long dtype so an empty list does not become a float tensor.
            gen_step_features['edge_type'] = torch.tensor(edge_types, dtype=torch.long)
            gen_step_features['correct_edge_choices'] = gen_step.correct_edge_choices

            num_correct_edge_choices = np.sum(gen_step.correct_edge_choices)
            gen_step_features['num_correct_edge_choices'] = num_correct_edge_choices
            gen_step_features['stop_node_label'] = int(num_correct_edge_choices == 0)
            gen_step_features['valid_edge_choices'] = gen_step.valid_edge_choices
            gen_step_features["correct_edge_types"] = gen_step.correct_edge_types
            gen_step_features["partial_node_categorical_features"] = gen_step.partial_node_categorical_features
            if gen_step.correct_attachment_point_choice is not None:
                # Store the position of the correct choice among the valid choices.
                gen_step_features["correct_attachment_point_choice"] = list(gen_step.valid_attachment_point_choices).index(gen_step.correct_attachment_point_choice)
            else:
                gen_step_features["correct_attachment_point_choice"] = []
            gen_step_features["valid_attachment_point_choices"] = gen_step.valid_attachment_point_choices

            # And finally, the correct node type choices. Here, we have an empty list of
            # correct choices for all steps where we didn't choose a node, so we skip that:
            if gen_step.correct_node_type_choices is not None:
                gen_step_features["correct_node_type_choices"] = self.node_types_to_multi_hot(gen_step.correct_node_type_choices)
            else:
                gen_step_features["correct_node_type_choices"] = []
            gen_step_features['correct_first_node_type_choices'] = correct_first_node_type_choices

            # Add graph_property_values
            gen_step_features = {**gen_step_features, **molecule_property_values}

            molecule_gen_steps.append(gen_step_features)
        molecule_gen_steps = self._to_tensor_moler(molecule_gen_steps)
        return [Data(**step) for step in molecule_gen_steps]

    def _to_tensor_moler(self, molecule_gen_steps):
        """Convert every feature value of every step dict to a torch tensor, in place."""
        for step_features in molecule_gen_steps:
            for key, value in step_features.items():
                # Existing tensors are kept as-is. Re-wrapping them with
                # torch.tensor() only copies them and emits a UserWarning.
                if not torch.is_tensor(value):
                    step_features[key] = torch.tensor(value)
        return molecule_gen_steps

    def len(self):
        """Number of samples, i.e. processed generation steps."""
        return self.processed_file_names_size

    def get(self, idx):
        """Load the ``idx``-th processed generation step from disk."""
        file_path = self.processed_file_names[idx]
        # NOTE(review): newer PyTorch releases default torch.load to
        # weights_only=True, which may reject pickled Data objects. Confirm
        # against the installed version.
        data = torch.load(file_path)
        return data

2022-12-09 12:03:06.465911: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# Build (and, on first run, process) the train split of the playground traces.
dataset_kwargs = dict(
    root='/data/ongh0068',
    raw_moler_trace_dataset_parent_folder='/data/ongh0068/l1000/trace_playground',
    output_pyg_trace_dataset_parent_folder='/data/ongh0068/l1000/pyg_output_playground',
    split='train',
)
dataset = MolerDataset(**dataset_kwargs)

In [5]:
# Display the dataset; its repr includes the number of processed trace steps.
dataset

MolerDataset(2384)

In [7]:
from torch_geometric.loader import DataLoader

# Attributes with a per-graph variable number of entries. For each one, PyG adds
# `<attr>_batch` and `<attr>_ptr` vectors so entries can be mapped back to
# their graph after collation.
FOLLOW_BATCH_KEYS = [
    'correct_edge_choices',
    'correct_edge_types',
    'valid_edge_choices',
    'valid_attachment_point_choices',
    'correct_attachment_point_choice',
    'correct_node_type_choices',
]
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
    follow_batch=FOLLOW_BATCH_KEYS,
)

In [8]:
def pprint_pyg_obj(batch):
    """
    Print ``<name>: <shape>`` for every public attribute stored on a PyG object.

    Reads the object's internal ``_store`` mapping and skips keys starting with
    an underscore. Every remaining value must expose ``.shape``.
    """
    stored = vars(batch)['_store']
    public_names = (name for name in stored.keys() if not name.startswith('_'))
    for name in public_names:
        print(f'{name}: {batch[name].shape}')
# Inspect only the first mini-batch: print the shape of each stored tensor.
for batch in loader:
    pprint_pyg_obj(batch)
    # Stop after one batch; `batch` stays bound to it for the following cells.
    break

x: torch.Size([126, 32])
edge_index: torch.Size([2, 212])
focus_node: torch.Size([16])
edge_type: torch.Size([212])
correct_edge_choices: torch.Size([53])
correct_edge_choices_batch: torch.Size([53])
correct_edge_choices_ptr: torch.Size([17])
num_correct_edge_choices: torch.Size([16])
stop_node_label: torch.Size([16])
valid_edge_choices: torch.Size([53, 2])
valid_edge_choices_batch: torch.Size([53])
valid_edge_choices_ptr: torch.Size([17])
correct_edge_types: torch.Size([9, 3])
correct_edge_types_batch: torch.Size([9])
correct_edge_types_ptr: torch.Size([17])
partial_node_categorical_features: torch.Size([126])
correct_attachment_point_choice: torch.Size([0])
correct_attachment_point_choice_batch: torch.Size([0])
correct_attachment_point_choice_ptr: torch.Size([17])
valid_attachment_point_choices: torch.Size([0])
valid_attachment_point_choices_batch: torch.Size([0])
valid_attachment_point_choices_ptr: torch.Size([17])
correct_node_type_choices: torch.Size([1112])
correct_node_type_choi

In [76]:
# Get some information out from the dataset:
# Next-node-type frequencies from the training metadata, listed in node-type-index
# order (the same counts BaseModel feeds to get_class_balancing_weights).
next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
index_to_node_type = dataset.node_type_index_to_string
atom_type_nums = [
    next_node_type_distribution[index_to_node_type[type_idx]]
    for type_idx in range(dataset.num_node_types)
]

In [9]:
# Show the categorical node features of the first mini-batch (one entry per node).
batch.partial_node_categorical_features

tensor([ 16,  16,  16,  16,  16,  16,  16,  16, 129,  16,  16,  16,  16, 129,
         16,  16,  16,  16, 129, 129,  16,  16,  16,  16, 129, 129,  16,  16,
         16,  16, 129, 129, 129,  16,  16,  16,  16, 129, 129, 129,  16,  16,
         16,  16, 129, 129, 129, 129,  16,  16,  16,  16, 129, 129, 129, 129,
         16,  16,  16,  16, 129, 129, 129, 129, 129,  16,  16,  16,  16, 129,
        129, 129, 129, 129,  16,  16,  16,  16, 129, 129, 129, 129, 129, 129,
         16,  16,  16,  16, 129, 129, 129, 129, 129, 129,  16,  16,  16,  16,
        129, 129, 129, 129, 129, 129,  16,  16,  16,  16, 129, 129, 129, 129,
        129, 129, 131,  16,  16,  16,  16, 129, 129, 129, 129, 129, 129, 131])

In [None]:
from molecule_generation.utils.training_utils import get_class_balancing_weights

class BaseModel(torch.nn.Module):
    """
    Work-in-progress PyTorch port of the MoLeR model.

    Combines a motif-aware embedding layer, a full-graph encoder, a partial-graph
    encoder and an MLP decoder. ``Embedding``, ``FullGraphEncoder`` and
    ``MLPDecoder`` are not defined in this notebook yet, so the constructor still
    needs them to be provided. NOTE(review): ``PartialGraphEncoder`` as defined
    in this notebook takes ``(num_node_features, hidden_channels)``, not a single
    params dict. Reconcile before use.
    """

    def __init__(self, params, dataset):
        super(BaseModel, self).__init__()
        self._init_params(params, dataset)
        self._motif_aware_embedding_layer = Embedding(params['motif_aware_embedding'])
        self._encoder_gnn = FullGraphEncoder(params['encoder_gnn'])
        self._decoder_gnn = PartialGraphEncoder(params['decoder_gnn'])
        self._decoder_mlp = MLPDecoder(params['decoderMLP'])

    def _init_params(self, params, dataset):
        """
        Initialise class weights for next-node prediction and placeholders for
        motif/node embeddings.

        Args:
            params: hyperparameter dict. Requires ``latent_repr_size``. Optionally
                contains ``node_type_predictor_class_loss_weight_factor``
                (default 0.0) and ``categorical_features_embedding_dim``.
            dataset: object exposing ``metadata``, ``node_type_index_to_string``,
                ``num_node_types`` and ``node_categorical_num_classes``.

        Raises:
            ValueError: if the class loss weight factor is outside [0, 1].
        """
        # Every other hyperparameter lookup goes through self._params, so it
        # must be set first.
        self._params = params
        self._latent_repr_dim = params["latent_repr_size"]
        # Get some information out from the dataset:
        next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
        class_weight_factor = self._params.get("node_type_predictor_class_loss_weight_factor", 0.0)

        if not (0.0 <= class_weight_factor <= 1.0):
            raise ValueError(
                "Node class loss weight node_type_predictor_class_loss_weight_factor "
                f"must be in [0,1], but is {class_weight_factor}!"
            )
        if class_weight_factor > 0:
            atom_type_nums = [
                next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
                for type_idx in range(dataset.num_node_types)
            ]
            # Extra class for "no next node" (end of generation).
            atom_type_nums.append(next_node_type_distribution["None"])

            self.class_weights = get_class_balancing_weights(
                class_counts=atom_type_nums, class_weight_factor=class_weight_factor
            )
        else:
            self.class_weights = None

        motif_vocabulary = dataset.metadata.get("motif_vocabulary")
        self._uses_motifs = motif_vocabulary is not None

        # NOTE(review): MolerDataset in this notebook does not define
        # `node_categorical_num_classes` yet; add it there before using this class.
        self._node_categorical_num_classes = dataset.node_categorical_num_classes

        if self.uses_categorical_features:
            if "categorical_features_embedding_dim" in self._params:
                # TODO: build the categorical-feature embedding layer.
                self._node_categorical_features_embedding = None

    @property
    def uses_motifs(self):
        """True when the dataset metadata provides a motif vocabulary."""
        return self._uses_motifs

    @property
    def uses_categorical_features(self):
        """True when the dataset reports a number of categorical node classes."""
        return self._node_categorical_num_classes is not None

    @property
    def decoder(self):
        """
        Decoder module used by ``forward``.

        NOTE(review): points at the MLP decoder, because ``forward`` expects four
        logit tensors from it. Confirm whether the partial-graph GNN should be
        wrapped in as well.
        """
        return self._decoder_mlp

    @property
    def encoder(self):
        """Full-graph encoder module."""
        return self._encoder_gnn

    @property
    def motif_aware_embedding_layer(self):
        """Embedding layer for atom and motif node types."""
        return self._motif_aware_embedding_layer

    @property
    def latent_dim(self):
        """Size of the latent representation (``params["latent_repr_size"]``)."""
        return self._latent_repr_dim

    def forward(self, x, edge_index, edge_attr, batch):
        """
        Run the encoder/decoder pipeline (work in progress).

        Returns:
            Tuple ``(node_type_logits, edge_logits, edge_type_logits,
            attachment_point_selection_logits)`` as produced by the decoder.
        """
        # TODO: obtain node embeddings via self.motif_aware_embedding_layer.

        # Forward pass through encoder. NOTE(review): forwards the raw graph
        # tensors, matching the GNN encoders defined in this notebook. Confirm
        # against FullGraphEncoder once it exists.
        latent_repr = self.encoder(x, edge_index, edge_attr, batch)

        # Forward pass through decoder: four distinct outputs, each under its own name.
        (
            node_type_logits,
            edge_logits,
            edge_type_logits,
            attachment_point_selection_logits,
        ) = self.decoder(latent_repr)

        # NOTE: loss computation will be done in lightning module
        return node_type_logits, edge_logits, edge_type_logits, attachment_point_selection_logits
#     def _compute_decoder_loss_and_metrics(
#         self, batch_features, task_output, batch_labels
#     ) -> Tuple[tf.Tensor, MoLeRDecoderMetrics]:

#         decoder_metrics = self.decoder.compute_metrics(
#             batch_features=batch_features, batch_labels=batch_labels, task_output=task_output
#         )

#         total_loss = (
#             self._params["node_classification_loss_weight"]
#             * decoder_metrics.node_classification_loss
#             + self._params["first_node_classification_loss_weight"]
#             * decoder_metrics.first_node_classification_loss
#             + self._params["edge_selection_loss_weight"] * decoder_metrics.edge_loss
#             + self._params["edge_type_loss_weight"] * decoder_metrics.edge_type_loss
#         )

#         if self.uses_motifs:
#             total_loss += (
#                 self._params["attachment_point_selection_weight"]
#                 * decoder_metrics.attachment_point_selection_loss
#             )

#         return total_loss, decoder_metrics

In [84]:
class GraphEncoder(torch.nn.Module):
    """
    Encoder-side GNN: three stacked ``GATConv`` layers with leaky-ReLU between them.

    Intended to build graph-level embeddings during the encoder step. The
    graph-level readout over ``batch`` is not implemented yet, so ``forward``
    currently returns per-node embeddings.
    """
    def __init__(self, num_node_features, hidden_channels):
        super(GraphEncoder, self).__init__()
        # torch.manual_seed(12345)
        # Use the constructor argument rather than the notebook-global `dataset`.
        self.conv1 = GATConv(num_node_features, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels)
        self.conv3 = GATConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index, edge_attr, batch):
        """
        Compute node embeddings for a (batched) graph.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity in COO format.
            edge_attr: edge features, passed through to each ``GATConv``.
            batch: node-to-graph assignment. Currently unused (see TODO below).

        Returns:
            Node embeddings produced by ``conv3``.
        """
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index, edge_attr)
        # torch.Tensor has no `leaky_relu` method; use the functional form.
        x = torch.nn.functional.leaky_relu(x)
        x = self.conv2(x, edge_index, edge_attr)
        x = torch.nn.functional.leaky_relu(x)
        x = self.conv3(x, edge_index, edge_attr)
        # TODO: 2. graph-level readout over `batch` (e.g. MoLeR's learned aggregation).
        return x


class PartialGraphEncoder(torch.nn.Module):
    """
    Decoder-side GNN over the partial graph: three stacked ``GATConv`` layers
    with ReLU between them.

    Intended to build graph-level embeddings during decoder steps. ``forward``
    currently returns per-node embeddings.
    """
    def __init__(self, num_node_features, hidden_channels):
        super(PartialGraphEncoder, self).__init__()
        # torch.manual_seed(12345)
        # Use the constructor argument rather than the notebook-global `dataset`.
        self.conv1 = GATConv(num_node_features, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels)
        self.conv3 = GATConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index, edge_attr, batch):
        """
        Compute node embeddings for a (batched) partial graph.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity in COO format.
            edge_attr: edge features, passed through to each ``GATConv``.
            batch: node-to-graph assignment. Currently unused.

        Returns:
            Node embeddings produced by ``conv3``.
        """
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index, edge_attr)
        x = x.relu()
        x = self.conv2(x, edge_index, edge_attr)
        x = x.relu()
        x = self.conv3(x, edge_index, edge_attr)
        return x
  

class PickAtomOrMotifMLP(torch.nn.Module):
    """
    MLP head for choosing the next atom/motif from the motif vocabulary.

    Produces a probability distribution over ``len(motif_vocabulary) + 1``
    classes; the extra class is the end-of-generation token. A softmax over the
    last dimension is applied at the end.
    """
    def __init__(self, motif_vocabulary, latent_vector_dim, hidden_channels=256):
        super(PickAtomOrMotifMLP, self).__init__()
        self.hidden_layer1 = torch.nn.Linear(latent_vector_dim, hidden_channels)
        # Nonlinearity between the two linear layers. Without it, the two
        # layers collapse into a single affine map.
        self.hidden_activation = torch.nn.ReLU()
        self.hidden_layer2 = torch.nn.Linear(hidden_channels, len(motif_vocabulary) + 1)  # add 1 for <END OF GENERATION TOKEN>
        # `torch.Softmax` does not exist. Use the module form with an explicit
        # class dimension.
        self.activation = torch.nn.Softmax(dim=-1)

    def forward(self, latent_vector, partial_graph_embedding):
        """
        Args:
            latent_vector: tensor whose last dimension is ``latent_vector_dim``.
            partial_graph_embedding: currently unused. NOTE(review): presumably
                meant to be combined with ``latent_vector``. TODO confirm.

        Returns:
            Class probabilities (vocabulary plus end token) along the last dimension.
        """
        x = self.hidden_layer1(latent_vector)
        x = self.hidden_activation(x)
        x = self.hidden_layer2(x)
        return self.activation(x)
        

class PickAttachmentMLP(torch.nn.Module):
    """Placeholder head for choosing a motif attachment point; not implemented yet."""

    def __init__(self, motif_dictionary, hidden_channels):
        super().__init__()
        # TODO: build the layers; `motif_dictionary` and `hidden_channels` are unused so far.

    def forward(self, latent_vector, partial_graph_embedding):
        """Stub: produces no output yet and returns ``None``."""
        return None

class PickBond(torch.nn.Module):
    """Placeholder head for choosing the bond to add; not implemented yet."""
    def __init__(self, motif_dictionary, hidden_channels):
        # super() must name this class. Naming PickAttachmentMLP raises
        # TypeError, because PickBond is not a subclass of it.
        super(PickBond, self).__init__()
        # TODO: build the layers; `motif_dictionary` and `hidden_channels` are unused so far.

    def forward(self, latent_vector, partial_graph_embedding):
        """Stub: produces no output yet and returns ``None``."""
        return None

# TODOs

1. Implement the learned aggregation function from the MoLeR paper by subclassing PyG's `MessagePassing` layer.
2. Change the encoder's `__init__` to take the relevant hyperparameters from the `moler_vae` file.