In [1]:
import pandas as pd
df = pd.read_csv('/data/ongh0068/l1000/pyg_output_playground/train/processed_file_paths.csv')

In [5]:
import os
os.chdir('/data/ongh0068/l1000/moler_reference')

In [6]:
from torch_geometric.data import Data
class MolerData(Data):
    def __init__(
        self, 
        x = None,
        edge_index = None,
        edge_attr = None, 
        y = None, 
        pos = None,
        original_graph_edge_index=None, 
        original_graph_x=None, 
        **kwargs
        
    ):
        super().__init__(x, edge_index, edge_attr, y, pos, **kwargs)
        self.original_graph_edge_index = original_graph_edge_index
        self.original_graph_x = original_graph_x
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'original_graph_edge_index':
            return self.original_graph_x.size(0)
        else:
            
            return super().__inc__(key, value, *args, **kwargs)

In [7]:
from torch_geometric.data import Dataset, Data
import os
import pandas as pd
import numpy as np
from molecule_generation.chem.motif_utils import get_motif_type_to_node_type_index_map
import torch
import gzip
import pickle
def get_motif_type_to_node_type_index_map(
    motif_vocabulary, num_atom_types
):
    """Helper to construct a mapping from motif type to shifted node type."""

    return {
        motif: num_atom_types + motif_type
        for motif, motif_type in motif_vocabulary.vocabulary.items()
    }


class MolerDataset(Dataset):
    def __init__(
        self, 
        root, 
        raw_moler_trace_dataset_parent_folder, # absolute path 
        output_pyg_trace_dataset_parent_folder, # absolute path
        split = 'train',
        transform=None, 
        pre_transform=None, 
    ):
        self._processed_file_paths = None
        self._transform = transform 
        self._pre_transform = pre_transform
        self._raw_moler_trace_dataset_parent_folder = raw_moler_trace_dataset_parent_folder
        self._output_pyg_trace_dataset_parent_folder = output_pyg_trace_dataset_parent_folder
        self._split = split
        self.load_metadata()

        # create the directory for the processed data if it doesn't exist
        processed_file_paths_folder = os.path.join(self._output_pyg_trace_dataset_parent_folder, self._split)
        if not os.path.exists(processed_file_paths_folder):
            os.mkdir(processed_file_paths_folder)
        # try to read in the csv with the processed file paths
        processed_file_paths_csv = os.path.join(processed_file_paths_folder, 'processed_file_paths.csv')
        if os.path.exists(processed_file_paths_csv):
            self._processed_file_paths = pd.read_csv(processed_file_paths_csv)['file_names'].tolist()
        
        super().__init__(root, transform, pre_transform)
        self._processed_file_paths = pd.read_csv(processed_file_paths_csv)['file_names'].tolist()

    @property
    def raw_file_names(self):
        """
        Raw generation trace files output from the preprocess function of the cli. These are zipped pickle
        files. This is the actual file name without the parent folder.
        """
        raw_pkl_file_folders = [folder for folder in os.listdir(self._raw_moler_trace_dataset_parent_folder) if folder.startswith(self._split)]

        assert len(raw_pkl_file_folders) > 0, f'{self._raw_moler_trace_dataset_parent_folder} does not contain {self._split} files.'
        
        raw_generation_trace_files = []
        for folder in raw_pkl_file_folders:
            for pkl_file in os.listdir(os.path.join(self._raw_moler_trace_dataset_parent_folder, folder)):
                raw_generation_trace_files.append(os.path.join(self._raw_moler_trace_dataset_parent_folder, folder, pkl_file))
        return raw_generation_trace_files

    @property
    def processed_file_names(self):
        """Processed generation trace objects that are stored as .pt files"""
        if self._processed_file_paths is not None:
            return self._processed_file_paths
        else:
            return []


    
    @property
    def processed_file_names_size(self):
        return len(self.processed_file_names)
    
    @property 
    def metadata(self):
        return self._metadata

    @property
    def node_type_index_to_string(self):
        return self._node_type_index_to_string
    
    @property 
    def num_node_types(self):
        return len(self.node_type_index_to_string)

    def node_type_to_index(self, node_type):
        return self._atom_type_featuriser.type_name_to_index(node_type)

    def node_types_to_indices(self, node_types):
        """Convert list of string representations into list of integer indices."""
        return [self.node_type_to_index(node_type) for node_type in node_types]

    def node_types_to_multi_hot(self, node_types):
        """Convert between string representation to multi hot encoding of correct node types.

        Note: implemented here for backwards compatibility only.
        """
        correct_indices = self.node_types_to_indices(node_types)
        multihot = np.zeros(shape=(self.num_node_types,), dtype=np.float32)
        for idx in correct_indices:
            multihot[idx] = 1.0
        return multihot
    
    def node_type_to_index(self, node_type):
        motif_node_type_index = self._motif_to_node_type_index.get(node_type)

        if motif_node_type_index is not None:
            return motif_node_type_index
        else:
            return self._atom_type_featuriser.type_name_to_index(node_type)
    
    def load_metadata(self):
        metadata_file_path = os.path.join(self._raw_moler_trace_dataset_parent_folder, 'metadata.pkl.gz')
        
        with gzip.open(metadata_file_path, 'rb') as f:
             self._metadata = pickle.load(f)
        
        self._atom_type_featuriser = next(
            featuriser
            for featuriser in self._metadata["feature_extractors"]
            if featuriser.name == "AtomType"
        )
        
        self._node_type_index_to_string = self._atom_type_featuriser.index_to_atom_type_map.copy()
        self._motif_vocabulary = self.metadata.get("motif_vocabulary")

        if self._motif_vocabulary is not None:
            self._motif_to_node_type_index = get_motif_type_to_node_type_index_map(
                motif_vocabulary=self._motif_vocabulary,
                num_atom_types=len(self._node_type_index_to_string),
            )

            for motif, node_type in self._motif_to_node_type_index.items():
                self._node_type_index_to_string[node_type] = motif
        else:
            self._motif_to_node_type_index = {}

    def generate_preprocessed_file_paths_csv(self, preprocessed_file_paths_folder):
        file_paths = [os.path.join(preprocessed_file_paths_folder, file_path) for file_path in os.listdir(preprocessed_file_paths_folder)]
        df = pd.DataFrame(file_paths, columns = ['file_names'])
        processed_file_paths_csv = os.path.join(preprocessed_file_paths_folder, 'processed_file_paths.csv')
        df.to_csv(processed_file_paths_csv, index = False)

    def process(self):
        """Convert raw generation traces into individual .pt files for each of the trace steps."""
        # only call process if it was not called before
        if self.processed_file_names_size > 0:
            pass
        else:
            self.load_metadata()
            for pkl_file_path in self.raw_file_names:
                generation_steps = self._convert_data_shard_to_list_of_trace_steps(pkl_file_path)
                
                for molecule_idx, molecule_gen_steps in generation_steps:
                    for step_idx, step in enumerate(molecule_gen_steps):
                        file_name = f'{pkl_file_path.split("/")[-1].split(".")[0]}_mol_{molecule_idx}_step_{step_idx}.pt'
                        file_path = os.path.join(self._output_pyg_trace_dataset_parent_folder, self._split, file_name)
                        print('file_path', file_path)
                        torch.save(step, file_path)
                        print(f'Processing {molecule_idx}, step {step_idx}')
                        
            self.generate_preprocessed_file_paths_csv(os.path.join(self._output_pyg_trace_dataset_parent_folder, self._split))

    def _convert_data_shard_to_list_of_trace_steps(self, pkl_file_path):
        # TODO: multiprocessing to speed this up
        generation_steps = []
        
        with gzip.open(pkl_file_path, 'rb') as f:
            molecules = pickle.load(f)
            for molecule_idx, molecule in enumerate(molecules): 
                
                generation_steps += [(molecule_idx, self._extract_generation_steps(molecule))]

        return generation_steps
            
    def _extract_generation_steps(self, molecule):
        molecule_gen_steps = []
        molecule_property_values = {k: [v] for k,v in molecule.graph_property_values.items()}
        for gen_step in molecule:
            gen_step_features = {}
            
            gen_step_features['original_graph_x'] = molecule.node_features
            # have an edge type attribute to tell apart each of the 3 bond types
            edge_indexes = []
            edge_types= []
            for i, adj_list in enumerate(molecule.adjacency_lists):
                if len(adj_list) != 0:
                    edge_index = adj_list.T
                    edge_indexes += [edge_index]
                    edge_types += [i]*len(adj_list)

            
            gen_step_features['original_graph_edge_index'] =  np.concatenate(edge_indexes, 1) if len(edge_indexes) > 0 else np.array(edge_indexes)
            gen_step_features['original_graph_edge_type'] = np.array(edge_types)
            gen_step_features['original_graph_node_categorical_features'] = molecule.node_categorical_features
            
            
            gen_step_features['x'] = gen_step.partial_node_features
            gen_step_features['focus_node'] = [gen_step.focus_node]

            # have an edge type attribute to tell apart each of the 3 bond types
            edge_indexes = []
            edge_types= []
            for i, adj_list in enumerate(gen_step.partial_adjacency_lists):
                if len(adj_list) != 0:
                    edge_index = adj_list.T
                    edge_indexes += [edge_index]
                    edge_types += [i]*len(adj_list)

            gen_step_features['edge_index'] = np.concatenate(edge_indexes, 1) if len(edge_indexes) > 0 else np.array(edge_indexes)
            gen_step_features['edge_type'] = np.array(edge_types)
            gen_step_features['edge_features'] = np.array(gen_step.edge_features)
            gen_step_features['correct_edge_choices'] = gen_step.correct_edge_choices
            
            num_correct_edge_choices = np.sum(gen_step.correct_edge_choices)
            gen_step_features['num_correct_edge_choices'] = [num_correct_edge_choices]
            gen_step_features['stop_node_label'] = [int(num_correct_edge_choices == 0)]
            gen_step_features['valid_edge_choices'] = gen_step.valid_edge_choices
            gen_step_features['valid_edge_types'] = gen_step.valid_edge_types

            gen_step_features["correct_edge_types"] = gen_step.correct_edge_types
            gen_step_features["partial_node_categorical_features"] = gen_step.partial_node_categorical_features
            if gen_step.correct_attachment_point_choice is not None:
                gen_step_features["correct_attachment_point_choice"] = [list(gen_step.valid_attachment_point_choices).index(gen_step.correct_attachment_point_choice)]
            else:
                gen_step_features["correct_attachment_point_choice"] = []
            gen_step_features["valid_attachment_point_choices"] = gen_step.valid_attachment_point_choices

            # And finally, the correct node type choices. Here, we have an empty list of
            # correct choices for all steps where we didn't choose a node, so we skip that:
            if gen_step.correct_node_type_choices is not None:
                gen_step_features["correct_node_type_choices"] = self.node_types_to_multi_hot(gen_step.correct_node_type_choices)
            else:
                gen_step_features["correct_node_type_choices"] = []
            gen_step_features['correct_first_node_type_choices'] = self.node_types_to_multi_hot(molecule.correct_first_node_type_choices)
            
            # Add graph_property_values
            gen_step_features = {**gen_step_features, **molecule_property_values}
            molecule_gen_steps += [gen_step_features]
            
        molecule_gen_steps = self._to_tensor_moler(molecule_gen_steps)

        return [MolerData(**step) for step in molecule_gen_steps]
    
    def _to_tensor_moler(self, molecule_gen_steps):
        for i in range(len(molecule_gen_steps)):
            for k,v in molecule_gen_steps[i].items():
                molecule_gen_steps[i][k] = torch.tensor(molecule_gen_steps[i][k])
        return molecule_gen_steps
    
    def len(self):
        return self.processed_file_names_size

    def get(self, idx):
        file_path = self.processed_file_names[idx]
        data = torch.load(file_path)
        return data

2022-12-15 12:44:15.657311: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [8]:
dataset = MolerDataset(
    root = '/data/ongh0068', 
    raw_moler_trace_dataset_parent_folder = '/data/ongh0068/l1000/trace_playground',
    output_pyg_trace_dataset_parent_folder = '/data/ongh0068/l1000/pyg_output_playground',
    split = 'train',
)

Processing...


file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_0.pt
Processing 0, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_1.pt
Processing 0, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_2.pt
Processing 0, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_3.pt
Processing 0, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_4.pt
Processing 0, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_5.pt
Processing 0, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_6.pt
Processing 0, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_7.pt
Processing 0, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_8.pt
Processing 0, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train

Processing 2, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_4.pt
Processing 2, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_5.pt
Processing 2, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_6.pt
Processing 2, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_7.pt
Processing 2, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_8.pt
Processing 2, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_9.pt
Processing 2, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_10.pt
Processing 2, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_11.pt
Processing 2, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_12.pt
Processing 2, step 12
file_path /data/ongh0068/l1000/pyg_ou

Processing 5, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_5.pt
Processing 5, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_6.pt
Processing 5, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_7.pt
Processing 5, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_8.pt
Processing 5, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_9.pt
Processing 5, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_10.pt
Processing 5, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_11.pt
Processing 5, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_12.pt
Processing 5, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_13.pt
Processing 5, step 13
file_path /data/ongh0068/l1000/pyg_

Processing 10, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_10_step_12.pt
Processing 10, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_10_step_13.pt
Processing 10, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_10_step_14.pt
Processing 10, step 14
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_10_step_15.pt
Processing 10, step 15
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_10_step_16.pt
Processing 10, step 16
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_11_step_0.pt
Processing 11, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_11_step_1.pt
Processing 11, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_11_step_2.pt
Processing 11, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_11_step_3.pt
Processing 11, step 3
file_path /da

Processing 13, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_13_step_7.pt
Processing 13, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_13_step_8.pt
Processing 13, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_13_step_9.pt
Processing 13, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_13_step_10.pt
Processing 13, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_13_step_11.pt
Processing 13, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_13_step_12.pt
Processing 13, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_13_step_13.pt
Processing 13, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_13_step_14.pt
Processing 13, step 14
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_13_step_15.pt
Processing 13, step 15
file_path /d

Processing 15, step 22
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_23.pt
Processing 15, step 23
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_24.pt
Processing 15, step 24
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_25.pt
Processing 15, step 25
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_26.pt
Processing 15, step 26
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_27.pt
Processing 15, step 27
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_28.pt
Processing 15, step 28
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_29.pt
Processing 15, step 29
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_30.pt
Processing 15, step 30
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_31.pt
Processing 15, step 31
file_

Processing 20, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_6.pt
Processing 20, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_7.pt
Processing 20, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_8.pt
Processing 20, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_9.pt
Processing 20, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_10.pt
Processing 20, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_11.pt
Processing 20, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_12.pt
Processing 20, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_13.pt
Processing 20, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_14.pt
Processing 20, step 14
file_path /dat

Processing 24, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_12.pt
Processing 24, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_13.pt
Processing 24, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_14.pt
Processing 24, step 14
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_15.pt
Processing 24, step 15
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_16.pt
Processing 24, step 16
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_17.pt
Processing 24, step 17
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_18.pt
Processing 24, step 18
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_19.pt
Processing 24, step 19
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_20.pt
Processing 24, step 20
file_

Processing 28, step 16
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_17.pt
Processing 28, step 17
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_18.pt
Processing 28, step 18
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_19.pt
Processing 28, step 19
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_20.pt
Processing 28, step 20
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_21.pt
Processing 28, step 21
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_22.pt
Processing 28, step 22
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_23.pt
Processing 28, step 23
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_29_step_0.pt
Processing 29, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_30_step_0.pt
Processing 30, step 0
file_path

Processing 31, step 26
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_27.pt
Processing 31, step 27
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_28.pt
Processing 31, step 28
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_29.pt
Processing 31, step 29
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_30.pt
Processing 31, step 30
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_31.pt
Processing 31, step 31
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_32.pt
Processing 31, step 32
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_33.pt
Processing 31, step 33
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_32_step_0.pt
Processing 32, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_32_step_1.pt
Processing 32, step 1
file_path

Processing 37, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_37_step_1.pt
Processing 37, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_37_step_2.pt
Processing 37, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_37_step_3.pt
Processing 37, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_37_step_4.pt
Processing 37, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_37_step_5.pt
Processing 37, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_37_step_6.pt
Processing 37, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_37_step_7.pt
Processing 37, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_37_step_8.pt
Processing 37, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_37_step_9.pt
Processing 37, step 9
file_path /data/ongh0068

Processing 42, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_42_step_9.pt
Processing 42, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_42_step_10.pt
Processing 42, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_42_step_11.pt
Processing 42, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_42_step_12.pt
Processing 42, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_42_step_13.pt
Processing 42, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_42_step_14.pt
Processing 42, step 14
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_42_step_15.pt
Processing 42, step 15
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_42_step_16.pt
Processing 42, step 16
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_42_step_17.pt
Processing 42, step 17
file_pat

Processing 46, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_46_step_4.pt
Processing 46, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_46_step_5.pt
Processing 46, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_46_step_6.pt
Processing 46, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_46_step_7.pt
Processing 46, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_46_step_8.pt
Processing 46, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_46_step_9.pt
Processing 46, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_46_step_10.pt
Processing 46, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_46_step_11.pt
Processing 46, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_46_step_12.pt
Processing 46, step 12
file_path /data/on

Processing 48, step 25
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_48_step_26.pt
Processing 48, step 26
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_48_step_27.pt
Processing 48, step 27
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_48_step_28.pt
Processing 48, step 28
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_48_step_29.pt
Processing 48, step 29
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_48_step_30.pt
Processing 48, step 30
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_48_step_31.pt
Processing 48, step 31
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_48_step_32.pt
Processing 48, step 32
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_48_step_33.pt
Processing 48, step 33
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_48_step_34.pt
Processing 48, step 34
file_

Processing 52, step 21
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_52_step_22.pt
Processing 52, step 22
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_52_step_23.pt
Processing 52, step 23
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_52_step_24.pt
Processing 52, step 24
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_52_step_25.pt
Processing 52, step 25
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_52_step_26.pt
Processing 52, step 26
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_52_step_27.pt
Processing 52, step 27
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_52_step_28.pt
Processing 52, step 28
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_52_step_29.pt
Processing 52, step 29
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_52_step_30.pt
Processing 52, step 30
file_

Processing 55, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_55_step_5.pt
Processing 55, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_55_step_6.pt
Processing 55, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_55_step_7.pt
Processing 55, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_55_step_8.pt
Processing 55, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_55_step_9.pt
Processing 55, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_55_step_10.pt
Processing 55, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_55_step_11.pt
Processing 55, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_55_step_12.pt
Processing 55, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_55_step_13.pt
Processing 55, step 13
file_path /data/

Processing 57, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_57_step_13.pt
Processing 57, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_57_step_14.pt
Processing 57, step 14
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_57_step_15.pt
Processing 57, step 15
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_57_step_16.pt
Processing 57, step 16
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_57_step_17.pt
Processing 57, step 17
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_57_step_18.pt
Processing 57, step 18
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_57_step_19.pt
Processing 57, step 19
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_57_step_20.pt
Processing 57, step 20
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_57_step_21.pt
Processing 57, step 21
file_

Processing 60, step 24
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_60_step_25.pt
Processing 60, step 25
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_60_step_26.pt
Processing 60, step 26
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_61_step_0.pt
Processing 61, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_61_step_1.pt
Processing 61, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_61_step_2.pt
Processing 61, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_61_step_3.pt
Processing 61, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_61_step_4.pt
Processing 61, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_61_step_5.pt
Processing 61, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_61_step_6.pt
Processing 61, step 6
file_path /data/ong

Processing 63, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_63_step_1.pt
Processing 63, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_63_step_2.pt
Processing 63, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_63_step_3.pt
Processing 63, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_63_step_4.pt
Processing 63, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_63_step_5.pt
Processing 63, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_63_step_6.pt
Processing 63, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_63_step_7.pt
Processing 63, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_63_step_8.pt
Processing 63, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_63_step_9.pt
Processing 63, step 9
file_path /data/ongh0068

Processing 68, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_69_step_0.pt
Processing 69, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_69_step_1.pt
Processing 69, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_69_step_2.pt
Processing 69, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_69_step_3.pt
Processing 69, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_69_step_4.pt
Processing 69, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_69_step_5.pt
Processing 69, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_69_step_6.pt
Processing 69, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_69_step_7.pt
Processing 69, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_69_step_8.pt
Processing 69, step 8
file_path /data/ongh0068

Processing 72, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_72_step_7.pt
Processing 72, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_72_step_8.pt
Processing 72, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_72_step_9.pt
Processing 72, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_72_step_10.pt
Processing 72, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_72_step_11.pt
Processing 72, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_72_step_12.pt
Processing 72, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_72_step_13.pt
Processing 72, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_72_step_14.pt
Processing 72, step 14
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_72_step_15.pt
Processing 72, step 15
file_path /d

Processing 74, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_74_step_9.pt
Processing 74, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_74_step_10.pt
Processing 74, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_74_step_11.pt
Processing 74, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_74_step_12.pt
Processing 74, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_74_step_13.pt
Processing 74, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_74_step_14.pt
Processing 74, step 14
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_74_step_15.pt
Processing 74, step 15
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_74_step_16.pt
Processing 74, step 16
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_74_step_17.pt
Processing 74, step 17
file_pat

Processing 76, step 39
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_40.pt
Processing 76, step 40
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_41.pt
Processing 76, step 41
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_42.pt
Processing 76, step 42
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_43.pt
Processing 76, step 43
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_44.pt
Processing 76, step 44
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_77_step_0.pt
Processing 77, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_77_step_1.pt
Processing 77, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_77_step_2.pt
Processing 77, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_77_step_3.pt
Processing 77, step 3
file_path /da

Processing 81, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_81_step_9.pt
Processing 81, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_81_step_10.pt
Processing 81, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_81_step_11.pt
Processing 81, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_81_step_12.pt
Processing 81, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_81_step_13.pt
Processing 81, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_81_step_14.pt
Processing 81, step 14
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_81_step_15.pt
Processing 81, step 15
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_81_step_16.pt
Processing 81, step 16
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_81_step_17.pt
Processing 81, step 17
file_pat

Processing 86, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_86_step_7.pt
Processing 86, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_86_step_8.pt
Processing 86, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_86_step_9.pt
Processing 86, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_86_step_10.pt
Processing 86, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_86_step_11.pt
Processing 86, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_86_step_12.pt
Processing 86, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_86_step_13.pt
Processing 86, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_86_step_14.pt
Processing 86, step 14
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_86_step_15.pt
Processing 86, step 15
file_path /d

Processing 89, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_89_step_10.pt
Processing 89, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_89_step_11.pt
Processing 89, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_89_step_12.pt
Processing 89, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_89_step_13.pt
Processing 89, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_89_step_14.pt
Processing 89, step 14
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_89_step_15.pt
Processing 89, step 15
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_90_step_0.pt
Processing 90, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_90_step_1.pt
Processing 90, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_90_step_2.pt
Processing 90, step 2
file_path /d

Processing 93, step 30
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_93_step_31.pt
Processing 93, step 31
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_93_step_32.pt
Processing 93, step 32
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_93_step_33.pt
Processing 93, step 33
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_93_step_34.pt
Processing 93, step 34
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_94_step_0.pt
Processing 94, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_94_step_1.pt
Processing 94, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_94_step_2.pt
Processing 94, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_94_step_3.pt
Processing 94, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_94_step_4.pt
Processing 94, step 4
file_path /data

Processing 98, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_3.pt
Processing 98, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_4.pt
Processing 98, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_5.pt
Processing 98, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_6.pt
Processing 98, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_7.pt
Processing 98, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_8.pt
Processing 98, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_9.pt
Processing 98, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_10.pt
Processing 98, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_11.pt
Processing 98, step 11
file_path /data/ongh

Done!


In [9]:
dataset

MolerDataset(2384)

In [13]:
from torch_geometric.loader import DataLoader
loader = DataLoader(dataset, batch_size=16, shuffle=False, follow_batch = [
    'correct_edge_choices',
    'correct_edge_types',
    'valid_edge_choices',
    'valid_attachment_point_choices',
    'correct_attachment_point_choice',
    'correct_node_type_choices',
    'original_graph_x'
])

In [11]:
def pprint_pyg_obj(batch):
    for key in vars(batch)['_store'].keys():
        if key.startswith('_'):
            continue
        print(f'{key}: {batch[key].shape}')
for batch in loader:
    pprint_pyg_obj(batch)
    break

x: torch.Size([193, 32])
edge_index: torch.Size([2, 370])
original_graph_edge_type: torch.Size([1024])
original_graph_node_categorical_features: torch.Size([464])
focus_node: torch.Size([16])
edge_type: torch.Size([370])
edge_features: torch.Size([86, 3])
correct_edge_choices: torch.Size([86])
correct_edge_choices_batch: torch.Size([86])
correct_edge_choices_ptr: torch.Size([17])
num_correct_edge_choices: torch.Size([16])
stop_node_label: torch.Size([16])
valid_edge_choices: torch.Size([86, 2])
valid_edge_choices_batch: torch.Size([86])
valid_edge_choices_ptr: torch.Size([17])
valid_edge_types: torch.Size([9, 3])
correct_edge_types: torch.Size([9, 3])
correct_edge_types_batch: torch.Size([9])
correct_edge_types_ptr: torch.Size([17])
partial_node_categorical_features: torch.Size([193])
correct_attachment_point_choice: torch.Size([0])
correct_attachment_point_choice_batch: torch.Size([0])
correct_attachment_point_choice_ptr: torch.Size([17])
valid_attachment_point_choices: torch.Size([0]

In [12]:
batch.original_graph_edge_index

tensor([[  0,   1,   2,  ..., 460, 462, 463],
        [  1,   2,   4,  ..., 459, 461, 461]], dtype=torch.int32)

In [10]:
# Get some information out from the dataset:
next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
atom_type_nums = [
    next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
    for type_idx in range(dataset.num_node_types)
]

In [11]:
batch.partial_node_categorical_features

tensor([ 16,  16,  16,  16,  16,  16,  16,  16, 129,  16,  16,  16,  16, 129,
        129, 129, 129, 129,  16,  16,  16,  16, 129, 129, 129, 129, 129, 129,
         16,  16,  16,  16, 129, 129, 129, 129, 129, 129,  16,  16,  16,  16,
        129, 129, 129, 129, 129, 129,  16,  16,  16,  16, 129, 129, 129, 129,
        129, 129, 131,  16,  16,  16,  16, 129, 129, 129, 129, 129, 129, 131,
         16,  16,  16,  16, 129, 129, 129, 129, 129, 129, 131, 129,  16,  16,
         16,  16, 129, 129, 129, 129, 129, 129, 131, 129,  16,  16,  16,  16,
        129, 129, 129, 129, 129, 129, 131, 129,   0,   0,   0,   0,   0,   0,
         16,  16,  16,  16, 129, 129, 129, 129, 129, 129, 131, 129,   0,   0,
          0,   0,   0,   0,  16,  16,  16,  16, 129,  16,  16,  16,  16, 129,
        129, 129, 129, 129, 129, 131, 129,   0,   0,   0,   0,   0,   0, 129,
         16,  16,  16,  16, 129, 129, 129, 129, 129, 129, 131, 129,   0,   0,
          0,   0,   0,   0, 129,  16,  16,  16,  16, 129, 129, 1

# Encoder
Depending on the representation of the molecular graph (whether bond type is represented as edge type or as edge attributes), we have to change the function signature of the forward method

1. Try out both
2. tune number of heads
3. Look into softmax aggregation (hyperparams) + also implement sigmoidaggregation as a separate aggregation layer
4. Set number of heads as a hyperparam
5. Add leakyrelu() + layer norm

In [12]:
vocab_size = len(dataset.node_type_index_to_string)
embedding_size = 64
for batch in loader:
    pprint_pyg_obj(batch)
    break
    
embed = torch.nn.Embedding(vocab_size, embedding_size)
motif_embeddings = embed(batch.partial_node_categorical_features)
batch.x = torch.cat((batch.x, motif_embeddings), axis = -1)

from torch_geometric.nn import RGATConv
resize_layer = RGATConv(
in_channels = batch.x.shape[-1], 
out_channels = 64, 
num_relations = 3
)


batch.x = resize_layer(batch.x, batch.edge_index.long(), batch.edge_type)


x: torch.Size([193, 32])
edge_index: torch.Size([2, 370])
original_graph_x: torch.Size([464, 32])
original_graph_edge_index: torch.Size([2, 1024])
original_graph_edge_type: torch.Size([1024])
original_graph_node_categorical_features: torch.Size([464])
focus_node: torch.Size([16])
edge_type: torch.Size([370])
edge_features: torch.Size([86, 3])
correct_edge_choices: torch.Size([86])
correct_edge_choices_batch: torch.Size([86])
correct_edge_choices_ptr: torch.Size([17])
num_correct_edge_choices: torch.Size([16])
stop_node_label: torch.Size([16])
valid_edge_choices: torch.Size([86, 2])
valid_edge_choices_batch: torch.Size([86])
valid_edge_choices_ptr: torch.Size([17])
valid_edge_types: torch.Size([9, 3])
correct_edge_types: torch.Size([9, 3])
correct_edge_types_batch: torch.Size([9])
correct_edge_types_ptr: torch.Size([17])
partial_node_categorical_features: torch.Size([193])
correct_attachment_point_choice: torch.Size([0])
correct_attachment_point_choice_batch: torch.Size([0])
correct_att

In [13]:
from torch_geometric.nn import RGATConv
from torch_geometric.nn import aggr
class GraphEncoder(torch.nn.Module):
    """For constructing graph level embedding during encoder step"""
    def __init__(
        self, 
        input_feature_dim,
        node_feature_dim = 64,
        num_layers = 12,
        use_intermediate_gnn_results = True,
    ):
        super(GraphEncoder, self).__init__()
        self._first_layer = RGATConv(
            in_channels = input_feature_dim, 
            out_channels = node_feature_dim, 
            num_relations = 3
        )
        
        self._encoder_layers = torch.nn.ModuleList(
            [RGATConv(
                in_channels = node_feature_dim, 
                out_channels = node_feature_dim, 
                num_relations = 3
                ) for _ in range(num_layers)
            ]
        )
        self._softmax_aggr = aggr.SoftmaxAggregation(learn=True)
        self._use_intermediate_gnn_results = use_intermediate_gnn_results
    def forward(self, node_features, edge_index, edge_type, batch_index):
        gnn_results = []
        gnn_results += [self._first_layer(node_features, edge_index.long(), edge_type)]
        print(gnn_results[-1].shape)
        for i, layer in enumerate(self._encoder_layers):
            gnn_results += [layer(gnn_results[-1], edge_index.long(), edge_type)]
        
        if self._use_intermediate_gnn_results:
            x = torch.cat(gnn_results, axis = -1)
            graph_representations = self._softmax_aggr(x, batch_index)
        
        else:
            graph_representations = self._softmax_aggr(gnn_results[-1], batch_index)
        node_representations = torch.cat(gnn_results, axis= -1)
        return graph_representations, node_representations

In [14]:
model = GraphEncoder(input_feature_dim = batch.x.shape[-1])
input_molecule_representations, _ = model(batch.x, batch.edge_index.long(), batch.edge_type, batch.batch)

torch.Size([193, 64])


In [15]:
input_molecule_representations.shape

torch.Size([16, 832])

In [16]:
latent_dim = 512

In [17]:
class MeanAndLogVarMLP(torch.nn.Module):
    """For constructing graph level embedding during encoder step"""
    def __init__(
        self, 
        input_dim = 768,
        latent_dim = 512,
        num_layers = 1,
    ):
        super(MeanAndLogVarMLP, self).__init__()
        self._mean_log_var_mlp = torch.nn.Linear(input_dim, 2*latent_dim)

    def forward(self, x):
        x = self._mean_log_var_mlp(x)
        
        return x




In [18]:
mean_log_var_mlp = MeanAndLogVarMLP(input_dim = input_molecule_representations.shape[-1])
x = mean_log_var_mlp(input_molecule_representations)

In [19]:
x.shape

# this represents the aggregated graph representations after passing through 
# the mean and log var mlp => we need to sample from this distribution
# for each of the partial graphs. There are 16 partial graphs, so 
# we need to get a separate sample for each of them


torch.Size([16, 1024])

In [20]:
mu = x[:, : latent_dim]  # Shape: [V, MD]
log_var = x[:, latent_dim :]  # Shape: [V, MD]

# result_representations: shape [num_partial_graphs, latent_repr_dim]
std = torch.exp(log_var / 2)
p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
q = torch.distributions.Normal(mu, std)
z = q.rsample()

In [21]:
z.shape

torch.Size([16, 512])

# Decoder
 Every MLP should condition on the latent representation (currently computed from the partial graph itself which is wrong, should be from the original full molecular graph)

## PickAtomOrMotif + Node type selection loss

In [22]:
decoder_gnn = GraphEncoder(input_feature_dim = batch.x.shape[-1])
partial_graph_representions, node_representations = model(batch.x, batch.edge_index.long(), batch.edge_type, batch.batch)

torch.Size([193, 64])


In [23]:
partial_graph_representions.shape

torch.Size([16, 832])

In [24]:
batch.x.shape[-1]

64

In [25]:
pprint_pyg_obj(batch)


x: torch.Size([193, 64])
edge_index: torch.Size([2, 370])
original_graph_x: torch.Size([464, 32])
original_graph_edge_index: torch.Size([2, 1024])
original_graph_edge_type: torch.Size([1024])
original_graph_node_categorical_features: torch.Size([464])
focus_node: torch.Size([16])
edge_type: torch.Size([370])
edge_features: torch.Size([86, 3])
correct_edge_choices: torch.Size([86])
correct_edge_choices_batch: torch.Size([86])
correct_edge_choices_ptr: torch.Size([17])
num_correct_edge_choices: torch.Size([16])
stop_node_label: torch.Size([16])
valid_edge_choices: torch.Size([86, 2])
valid_edge_choices_batch: torch.Size([86])
valid_edge_choices_ptr: torch.Size([17])
valid_edge_types: torch.Size([9, 3])
correct_edge_types: torch.Size([9, 3])
correct_edge_types_batch: torch.Size([9])
correct_edge_types_ptr: torch.Size([17])
partial_node_categorical_features: torch.Size([193])
correct_attachment_point_choice: torch.Size([0])
correct_attachment_point_choice_batch: torch.Size([0])
correct_att

In [26]:
batch.correct_node_type_choices_batch.unique() # graphs requiring node choices

tensor([ 0,  2,  5,  7,  9, 11, 12, 14])

In [27]:
batch.correct_node_type_choices_ptr

tensor([   0,  139,  139,  278,  278,  278,  417,  417,  556,  556,  695,  695,
         834,  973,  973, 1112, 1112])

In [28]:
graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique()

In [29]:
start_idx_graphs_requiring_node_choices = batch.correct_node_type_choices_ptr[graphs_requiring_node_choices]

In [30]:
start_idx_graphs_requiring_node_choices

tensor([  0, 139, 278, 417, 556, 695, 834, 973])

In [31]:
node_type_multihot_labels = []
for i in range(len(batch.correct_node_type_choices_ptr)-1):
    start_idx = batch.correct_node_type_choices_ptr[i]
    end_idx = batch.correct_node_type_choices_ptr[i+1] 
    if end_idx - start_idx == 0:
        continue
    node_selection_labels = batch.correct_node_type_choices[start_idx: end_idx]
    node_type_multihot_labels += [node_selection_labels]
    
node_type_multihot_labels = torch.stack(node_type_multihot_labels, axis = 0)

In [32]:
node_type_multihot_labels

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.]])

In [33]:
from torch.nn import Linear
class PickAtomOrMotifMLP(torch.nn.Module):
    """
    For choosing an atom/motif out of the motif vocabulary
    Notes:
    Softmax layer at the end with 
    
    """
    def __init__(self, num_node_types, latent_vector_dim, hidden_channels=64):
        super(PickAtomOrMotifMLP, self).__init__()
        self.hidden_layer1 = Linear(latent_vector_dim, hidden_channels)
        self.hidden_layer2 = Linear(hidden_channels, num_node_types + 1) # add 1 for <END OF GENERATION TOKEN>

        
    def forward(self, original_and_calculated_graph_representations):
        print(original_and_calculated_graph_representations.shape)
        x = self.hidden_layer1(original_and_calculated_graph_representations)
        x = self.hidden_layer2(x)
        return x

In [34]:
node_type_selector = PickAtomOrMotifMLP(
    num_node_types = dataset.num_node_types,
    latent_vector_dim = z.shape[-1] + partial_graph_representions.shape[-1]
)

In [35]:
# result_representations: shape [PG, MD]
# number of partial graphs, molecule representation dimension

def pick_node_type(
    input_molecule_representations,
    graph_representations,
    graphs_requiring_node_choices
):
    
    relevant_graph_representations = input_molecule_representations[graphs_requiring_node_choices]
    relevant_input_molecule_representations = graph_representations[graphs_requiring_node_choices]
    original_and_calculated_graph_representations = torch.cat((relevant_graph_representations,relevant_input_molecule_representations), axis = -1)
    print(original_and_calculated_graph_representations.shape)
    return node_type_selector(original_and_calculated_graph_representations)

output = pick_node_type(z, partial_graph_representions, graphs_requiring_node_choices)


torch.Size([8, 1344])
torch.Size([8, 1344])


In [36]:
def safe_divide_loss(loss, num_choices):
    """Divide `loss` by `num_choices`, but guard against `num_choices` being 0."""
    return loss / max(num_choices, 1.0) 
SMALL_NUMBER, BIG_NUMBER = 1e-7, 1e7
def compute_neglogprob_for_multihot_objective(
    logprobs,
    multihot_labels,
    per_decision_num_correct_choices,
):
    # Normalise by number of correct choices and mask out entries for wrong decisions:
    return -(
        (logprobs + torch.log(per_decision_num_correct_choices + SMALL_NUMBER))
        * multihot_labels
        / (per_decision_num_correct_choices + SMALL_NUMBER)
    )

def compute_node_type_selection_loss(
    node_type_logits,
    node_type_multihot_labels
):
    per_node_decision_logprobs = torch.nn.functional.log_softmax(node_type_logits, dim = -1)
    # Shape: [NTP, NT + 1]
    
    # number of correct choices for each of the partial graphs that require node choices
    per_node_decision_num_correct_choices = torch.sum(node_type_multihot_labels, keepdim = True, axis = -1)
    # Shape [NTP, 1]
    
    per_correct_node_decision_normalised_neglogprob = compute_neglogprob_for_multihot_objective(
        logprobs = per_node_decision_logprobs[:, :-1], # separate out the no node prediction
        multihot_labels = node_type_multihot_labels,
        per_decision_num_correct_choices = per_node_decision_num_correct_choices,
    ) # Shape [NTP, NT]
    
    no_node_decision_correct =(per_node_decision_num_correct_choices == 0.0).sum()  # Shape [NTP]
    per_correct_no_node_decision_neglogprob = -(
        per_node_decision_logprobs[:, -1]
        * torch.squeeze(no_node_decision_correct).type(torch.FloatTensor)
    )  # Shape [NTP]

#     if self._node_type_loss_weights is not None:
#         per_correct_node_decision_normalised_neglogprob *= self._node_type_loss_weights[:-1]
#         per_correct_no_node_decision_neglogprob *= self._node_type_loss_weights[-1]
    
    # Loss is the sum of the masked (no) node decisions, averaged over number of decisions made:
    total_node_type_loss = torch.sum(
        per_correct_node_decision_normalised_neglogprob
    ) + torch.sum(per_correct_no_node_decision_neglogprob)
    node_type_loss = safe_divide_loss(
        total_node_type_loss, node_type_multihot_labels.shape[0]
    )

    return node_type_loss

In [37]:
node_type_selection_loss = compute_node_type_selection_loss(
    node_type_logits = output,
    node_type_multihot_labels = node_type_multihot_labels
)

In [38]:
node_type_selection_loss.backward()


## Pick edge and edge type + Edge candidate loss

1. Find out what stop node refers to

In [39]:
def pick_edge(
    input_molecule_representations,
    graph_representations,
    node_representations,
    graph_to_focus_node_map,
    node_to_graph_map,
    candidate_edge_targets,
    candidate_edge_features,
    graphs_requiring_node_choices
):
    # given candidate edges in the partial graphs, 
    #compute logits for the likelihood of adding candidates as well as 
    # logits for the type of edge if it is picked

IndentationError: expected an indented block (2135581705.py, line 13)

In [None]:
pprint_pyg_obj(batch)

In [None]:
focus_node_idx_in_batch = batch.ptr[:-1] + batch.focus_node # offset each focus node idx by the starting idx of the molecule

In [None]:
node_representations[focus_node_idx_in_batch].shape # extract focus node node level repr

In [None]:
focus_node_representations = node_representations[focus_node_idx_in_batch]

In [None]:
graph_and_focus_node_representations = torch.cat(
    (input_molecule_representations, partial_graph_representions, focus_node_representations),
    axis = -1
)

# Explanation: at each step, there is a focus node, which is the node we are 
# focusing on right now in terms of adding another edge to it. When adding a new
# edge, the edge can be between the focus node and a variety of other nodes.
# This is likely based on valency, and in reality, it is possible that none of the
# edge choices are correct (when that generation step is a node addition step)
# and not an edge addition step. Regardless, we still want to consider the candidates
#"target" refers to the node at the other end of the candidate edge



In [None]:
batch.valid_edge_choices[:, 1] # the 2nd element in each edge is the target since
# the 1st element is the focus node at that step

In [None]:
candidate_edge_target_node_idx = batch.valid_edge_choices[:, 1] + batch.ptr[batch.valid_edge_choices_batch]  # idx offset for the edge choices

In [None]:
partial_graph_idx_per_edge_candidate = batch.batch[candidate_edge_target_node_idx]

In [None]:
graph_and_focus_node_representations = torch.cat(
    (input_molecule_representations, partial_graph_representions, focus_node_representations),
    axis = -1
)

In [None]:
graph_and_focus_node_representations_per_edge_candidate = graph_and_focus_node_representations[partial_graph_idx_per_edge_candidate]

In [None]:
edge_candidate_target_node_representations = node_representations[candidate_edge_target_node_idx]

In [None]:
candidate_edge_features = batch.edge_features

In [None]:
distance_truncation = 10
candidate_edge_features[:, 0]

In [None]:
# The zeroth element of edge_features is the graph distance. We need to look that up
# in the distance embeddings:
truncated_distances = torch.minimum(
    candidate_edge_features[:, 0],
    torch.ones(len(candidate_edge_features)) * (distance_truncation - 1),
)  # shape: [CE]

In [None]:
truncated_distances = truncated_distances.type(torch.LongTensor)

In [None]:
# since we want to truncate the distance, we should have an embedding layer for it
distance_embedding_layer = torch.nn.Embedding(distance_truncation, 1)

distance_embedding = distance_embedding_layer(truncated_distances)

edge_candidate_representation = torch.cat(
    (
        graph_and_focus_node_representations_per_edge_candidate, 
        edge_candidate_target_node_representations,
        distance_embedding,
        candidate_edge_features[:, 1:],
    ),
    axis = -1
)


In [None]:
edge_candidate_representation.shape

In [None]:
graph_and_focus_node_representations.shape[0]

In [None]:
 torch.nn.Parameter(torch.FloatTensor(1,node_representations.shape[-1] + batch.edge_features.shape[-1]), requires_grad = True).shape

In [None]:
_no_more_edges_representation = torch.nn.Parameter(torch.FloatTensor(1,node_representations.shape[-1] + batch.edge_features.shape[-1]), requires_grad = True)
# Calculate the stop node features as well.
num_graphs_in_batch= graph_and_focus_node_representations.shape[0]
stop_edge_selection_representation = torch.cat(
    [
        graph_and_focus_node_representations,
        torch.tile(
            _no_more_edges_representation,
            dims=(num_graphs_in_batch, 1),
        ),
    ],
    axis=-1,
)  # shape: [PG, MD + PD + 2 * VD*(num_layers+1) + FD]

edge_candidate_and_stop_features = torch.cat(
    [edge_candidate_representation, stop_edge_selection_representation], axis=0
)  # shape: [CE + PG, MD + PD + 2 * VD*(num_layers+1) + FD]


from torch.nn import Linear
class EdgeCandidateScorer(torch.nn.Module):
    """
    For choosing an atom/motif out of the motif vocabulary
    Notes:
    Softmax layer at the end with 
    
    """
    def __init__(self, latent_vector_dim, output_size = 1, hidden_channels=64):
        super(EdgeCandidateScorer, self).__init__()
        self.hidden_layer1 = Linear(latent_vector_dim, hidden_channels)
        self.hidden_layer2 = Linear(hidden_channels, output_size) # add 1 for <END OF GENERATION TOKEN>

        
    def forward(self, original_and_calculated_graph_representations):
        x = self.hidden_layer1(original_and_calculated_graph_representations)
        x = self.hidden_layer2(x)
        return x

class EdgeTypeSelector(torch.nn.Module):
    """
    For choosing an atom/motif out of the motif vocabulary
    Notes:
    Softmax layer at the end with 
    
    """
    def __init__(self, latent_vector_dim,num_edge_types = 3, hidden_channels=64):
        super(EdgeTypeSelector, self).__init__()
        self.hidden_layer1 = Linear(latent_vector_dim, hidden_channels)
        self.hidden_layer2 = Linear(hidden_channels, num_edge_types) # add 1 for <END OF GENERATION TOKEN>

        
    def forward(self, original_and_calculated_graph_representations):
        x = self.hidden_layer1(original_and_calculated_graph_representations)
        x = self.hidden_layer2(x)
        return x

_edge_candidate_scorer = EdgeCandidateScorer(latent_vector_dim = edge_candidate_and_stop_features.shape[-1])
_edge_type_selector = EdgeTypeSelector(latent_vector_dim = edge_candidate_and_stop_features.shape[-1])
edge_candidate_logits = torch.squeeze(
    _edge_candidate_scorer(edge_candidate_and_stop_features),
    axis=-1,
)  # shape: [CE + PG]
edge_type_logits = _edge_type_selector(
    edge_candidate_representation
)  # shape: [CE, ET]

In [None]:
edge_candidate_logits.shape, edge_type_logits.shape

In [None]:
batch.correct_edge_choices.shape

In [None]:
batch.correct_edge_choices

In [None]:
batch.correct_edge_choices_batch

In [None]:
edge_candidate_to_graph_map = batch.batch[candidate_edge_target_node_idx]
# add the end bond labels to the end 
edge_candidate_to_graph_map = torch.cat((edge_candidate_to_graph_map, torch.arange(0, num_graphs_in_batch)))


In [None]:
from torch_geometric.utils import scatter 
def traced_unsorted_segment_log_softmax(
    logits, #edge_candidate_logits
    segment_ids, #edge_candidate_to_graph_map
    num_segments, # num_graphs_in_batch
):
    
    max_per_segment = scatter(logits, segment_ids, reduce = 'max')
    scattered_maxes = max_per_segment[segment_ids]
    recentered_scores = logits - scattered_maxes
    exped_recentered_scores = torch.exp(recentered_scores)

    per_segment_sums = scatter(exped_recentered_scores, segment_ids, reduce = 'sum')
    per_segment_normalization_consts = torch.log(per_segment_sums)

    log_probs = recentered_scores - per_segment_normalization_consts[segment_ids]
    return log_probs

In [None]:
edge_candidate_logprobs = traced_unsorted_segment_log_softmax(
    logits = edge_candidate_logits,
    segment_ids = edge_candidate_to_graph_map,
    num_segments = num_graphs_in_batch
)

In [None]:
def compute_edge_candidate_selection_loss(
    num_graphs_in_batch, # correct_node_type_choices
    node_to_graph_map, #batch.batch
    candidate_edge_targets, # batch_features["valid_edge_choices"][:, 1]
    edge_candidate_logits, # as is
    per_graph_num_correct_edge_choices, # batch.num_correct_edge_choices
    edge_candidate_correctness_labels, # correct edge choices
    no_edge_selected_labels # stop node label
):
    
    # First, we construct full labels for all edge decisions, which are the concat of
    # edge candidate logits and the logits for choosing no edge:
    edge_correctness_labels = torch.cat(
        [edge_candidate_correctness_labels, no_edge_selected_labels.float()],
        axis=0,
    )  # Shape: [CE + PG]

    # To compute a softmax over all candidate edges (and the "no edge" choice) corresponding
    # to the same graph, we first need to build the map from each logit to the corresponding
    # graph id. Then, we can do an unsorted_segment_softmax using that map:
    edge_candidate_to_graph_map = batch.batch[candidate_edge_target_node_idx]
    # add the end bond labels to the end 
    edge_candidate_to_graph_map = torch.cat((edge_candidate_to_graph_map, torch.arange(0, num_graphs_in_batch)))

    edge_candidate_logprobs = traced_unsorted_segment_log_softmax(
        logits=edge_candidate_logits,
        segment_ids=edge_candidate_to_graph_map,
        num_segments=num_graphs_in_batch,
    )  # Shape: [CE + PG]

    # Compute the edge loss with the multihot objective.
    # For a single graph with three valid choices (+ stop node) of which two are correct,
    # we may have the following:
    #  edge_candidate_logprobs = log([0.05, 0.5, 0.4, 0.05])
    #  per_graph_num_correct_edge_choices = [2]
    #  edge_candidate_correctness_labels = [0.0, 1.0, 1.0]
    #  edge_correctness_labels = [0.0, 1.0, 1.0, 0.0]
    # To get the loss, we simply look at the things in edge_candidate_logprobs that correspond
    # to correct entries.
    # However, to account for the _multi_hot nature, we scale up each entry of
    # edge_candidate_logprobs by the number of correct choices, i.e., consider the
    # correct entries of
    #  log([0.05, 0.5, 0.4, 0.05]) + log([2, 2, 2, 2]) = log([0.1, 1.0, 0.8, 0.1])
    # In this form, we want to have each correct entry to be as near possible to 1.
    # Finally, we normalise loss contributions to by-graph, by dividing the crossentropy
    # loss by the number of correct choices (i.e., in the example above, this results in
    # a loss of -((log(1.0) + log(0.8)) / 2) = 0.11...).

    # Note: per_graph_num_correct_edge_choices does not include the choice of an edge to
    # the stop node, so can be zero.
    per_graph_num_correct_edge_choices = torch.max(
        per_graph_num_correct_edge_choices, torch.ones(per_graph_num_correct_edge_choices.shape)
    )  # Shape: [PG]


    per_edge_candidate_num_correct_choices = per_graph_num_correct_edge_choices[edge_candidate_to_graph_map]
    # Shape: [CE]
    per_correct_edge_neglogprob = -(
        (edge_candidate_logprobs + torch.log(per_edge_candidate_num_correct_choices))
        * edge_correctness_labels
        / per_edge_candidate_num_correct_choices
    )  # Shape: [CE]
    
    # Normalise by number of graphs for which we made edge selection decisions:
    edge_loss = safe_divide_loss(
        torch.sum(per_correct_edge_neglogprob), num_graphs_in_batch
    )
    print(edge_correctness_labels, per_edge_candidate_num_correct_choices)

    return edge_loss


In [None]:
compute_edge_candidate_selection_loss(
    num_graphs_in_batch = len(batch.ptr) -1, # correct_node_type_choices
    node_to_graph_map = batch.batch, #batch.batch
    candidate_edge_targets = batch.valid_edge_choices[:, 1], # batch_features["valid_edge_choices"][:, 1]
    edge_candidate_logits = edge_candidate_logits, # as is
    per_graph_num_correct_edge_choices = batch.num_correct_edge_choices, # batch.num_correct_edge_choices
    edge_candidate_correctness_labels = batch.correct_edge_choices, # correct edge choices
    no_edge_selected_labels = batch.stop_node_label.float() # stop node label
)

## compute_edge_type_selection_loss 

In [None]:
correct_target_indices = batch.correct_edge_choices != 0
edge_type_logits_for_correct_edges = edge_type_logits[correct_target_indices]

In [None]:
batch.correct_edge_choices.shape

In [None]:
BIG_NUMBER = 1e7


In [None]:
valid_edge_types = batch.valid_edge_types
edge_type_onehot_labels = batch.correct_edge_types


# The `valid_edge_types` tensor is equal to 1 when the edge is valid (it may be invalid due
# to valency constraints), 0 otherwise.
# We want to multiply the selection probabilities by this mask. Because the logits are in
# log space, we instead subtract a large value from the logits wherever this mask is zero.
scaled_edge_mask = (1-valid_edge_types.float()) * BIG_NUMBER # Shape: [CCE, ET]

In [None]:
masked_edge_type_logits = (
    edge_type_logits_for_correct_edges - scaled_edge_mask
)  # Shape: [CCE, ET]

In [None]:
torch.nn.functional.cross_entropy(edge_type_onehot_labels.float(), masked_edge_type_logits)

In [None]:
edge_type_loss= torch.nn.CrossEntropyLoss(reduction = 'none')(edge_type_onehot_labels.float(), masked_edge_type_logits)

In [None]:
# Normalise by the number of edges for which we needed to pick a type:
# instead of mean, we must use safe divide because the batch can have zero edges 
# requring edge types
edge_type_loss = safe_divide_loss(
    torch.sum(edge_type_loss), len(edge_type_loss)
)

In [None]:
edge_type_loss

In [None]:
batch

# pick attachement point + compute pick attachement point loss

In [None]:
def pick_attachement_point(
    input_molecule_representations, # as is
    graph_representations, # partial_graph_representions
    node_representations, #as is
    node_to_graph_map, # batch.batch
    candidate_attachment_points, # 
):
    

In [None]:
tmp = []
for batch in loader:
    if len(batch.correct_attachment_point_choice) > 0 :
        tmp.append(batch)
        

In [None]:
batch = tmp[0]

In [None]:
batch.valid_attachment_point_choices_ptr

In [None]:
batch.valid_attachment_point_choices_batch

In [None]:
batch.valid_attachment_point_choices_ptr.shape, batch.valid_attachment_point_choices.shape

In [None]:
model = GraphEncoder(input_feature_dim = batch.x.shape[-1])
input_molecule_representations, _ = model(batch.x, batch.edge_index, batch.edge_type.int(), batch.batch)


decoder_gnn = GraphEncoder(input_feature_dim = batch.x.shape[-1])
partial_graph_representions, node_representations = model(batch.x, batch.edge_index, batch.edge_type.int(), batch.batch)

In [None]:
original_and_calculated_graph_representations = torch.cat(
    [input_molecule_representations, partial_graph_representions],
    axis=-1,
)  # Shape: [PG, MD + PD]

In [None]:
candidate_attachment_pt_corresp_partial_graph_idx = batch.ptr[batch.valid_attachment_point_choices_batch]
candidate_attachment_pt_node_idx = (attachment_pt_corresp_partial_graph_idx + batch.valid_attachment_point_choices).long()

In [None]:
# Extract the partial graphs that have attachment point node candidates.
partial_graphs_for_attachment_point_choices = batch.valid_attachment_point_choices_batch

# Shape: [CA]


# To score an attachment point, we condition on the representations of input and partial
# graphs, along with the representation of the attachment point candidate in question.
attachment_point_representations = torch.cat(
    [
        original_and_calculated_graph_representations[partial_graphs_for_attachment_point_choices],
        node_representations[attachment_pt_node_idx],
    ],
    axis=-1,
)  # Shape: [CA, MD + PD + VD*(num_layers+1)]



class AttachmentPointScorer(torch.nn.Module):
    """
    For choosing an atom/motif out of the motif vocabulary
    Notes:
    Softmax layer at the end with 
    
    """
    def __init__(self, latent_vector_dim, output_size = 1, hidden_channels=64):
        super(AttachmentPointScorer, self).__init__()
        self.hidden_layer1 = Linear(latent_vector_dim, hidden_channels)
        self.hidden_layer2 = Linear(hidden_channels, output_size) # add 1 for <END OF GENERATION TOKEN>

        
    def forward(self, original_and_calculated_graph_representations):
        x = self.hidden_layer1(original_and_calculated_graph_representations)
        x = self.hidden_layer2(x)
        return x







_attachment_point_selector = AttachmentPointScorer(latent_vector_dim = attachment_point_representations.shape[-1])
attachment_point_selection_logits = torch.squeeze(_attachment_point_selector(attachment_point_representations), axis = -1)
# Shape: [CA]


In [None]:
attachment_point_selection_logits

In [None]:
def compute_attachment_point_selection_loss(
    num_graphs_in_batch, # len(batch.ptr)- 1
    node_to_graph_map, # batch.batch
    attachment_point_selection_logits, # as is
    attachment_point_candidate_choices, # valid_attachment_point_choices
    attachment_point_correct_choices, # correct_attachment_point_choices
)

In [None]:
attachment_point_candidate_to_graph_map = batch.valid_attachment_point_choices_batch
# Shape: [CA]

# Compute log softmax of the logits within each partial graph.
attachment_point_candidate_logprobs = (
    traced_unsorted_segment_log_softmax(
        logits=attachment_point_selection_logits,
        segment_ids=attachment_point_candidate_to_graph_map,
        num_segments=num_graphs_in_batch,
    )
    * 1.0
)  # Shape: [CA]

In [None]:
attachment_point_candidate_logprobs

In [None]:
partial_graph_idx_requiring_attachement_pts, lengths_of_attachement_point_choices_per_graph = batch.valid_attachment_point_choices_batch.unique(return_counts = True)
attachement_pt_corrrect_choices_offset = lengths_of_attachement_point_choices_per_graph[:-1]


attachment_point_correct_choices = batch.correct_attachment_point_choice
# offset the correct choices from the 2nd idx onwards due to batching
attachment_point_correct_choices[1:]  = attachement_pt_corrrect_choices_offset

attachment_point_correct_choice_neglogprobs = -attachment_point_candidate_logprobs[(attachment_point_correct_choices).long()]
 # Shape: [AP]

In [None]:
attachment_point_correct_choice_neglogprobs

In [None]:
attachment_point_selection_loss = safe_divide_loss(
    (attachment_point_correct_choice_neglogprobs).sum(),
    attachment_point_correct_choice_neglogprobs.shape[0],
)

In [None]:
attachment_point_selection_loss

## TODO:
1. (DONE) Investigate why focus_node, graph properties and attachment points are all empty

    edge_loss = self.compute_edge_candidate_selection_loss(
        num_graphs_in_batch=batch_features["num_partial_graphs_in_batch"],
        node_to_graph_map=batch_features["node_to_partial_graph_map"],
        candidate_edge_targets=batch_features["valid_edge_choices"][:, 1],
        edge_candidate_logits=task_output.edge_candidate_logits,
        per_graph_num_correct_edge_choices=batch_labels["num_correct_edge_choices"],
        edge_candidate_correctness_labels=batch_labels["correct_edge_choices"],
        no_edge_selected_labels=batch_labels["stop_node_label"],
    )

Shape abbreviations used throughout the model:
- PG = number of _p_artial _g_raphs
- PV = number of _p_artial graph _v_ertices
- PD = size of _p_artial graph _representation _d_imension
- VD = GNN _v_ertex representation _d_imension
- MD = _m_olecule representation _d_imension
- EFD = _e_dge _f_eature _d_imension
- NTP = number of partial graphs requiring a _n_ode _t_ype _p_ick
- NT = number of _n_ode _t_ypes
- CE = number of _c_andidate _e_dges
- CCE = number of _c_orrect _c_andidate _e_dges
- ET = number of _e_dge _t_ypes
- CA = number of _c_andidate _a_ttachment points
- AP = number of _a_ttachment _p_oint choices

# TODO

1. Implement Sampling strategy
2. Look into why first node prediction is treated separately
3. Add dropout layers
4. investigate why the losses are all so huge

Current design plan: all model specific things are in the BaseModel class, 
All loss computation and what not will be in the lightning module 
This decouples the training code, optimizers and schedulers from the model itself, which arguably makes it cleaner when doing hyperparameter tuning and also during model loading





Note: Original implementation of the GNN adds each molecule and their corresponding generation steps sequentially into the batch, so what happens is that a molecule's sequential generation steps are seen by the GNN during training in a sequential manner. In this implementation, if we use shuffle = False, we can replicate this behaviour in our dataloader, but if we use shuffle = True, then we can't. 

Possible additions to allow for molecule wise shuffling => add additional column to CSV to allow for a molecule's generation steps to be treated as one set of training samples => then shuffle this instead of the actual generation steps (try this if the training is not converging)

This has additional implications for the latent space sampling strategies.

In [None]:
from dataclasses import dataclass

@dataclass
class MoLeROutput:
    """Class for keeping track of output from MoLeR model."""
    node_type_logits
    edge_candidate_logits
    edge_type_logits
    attachment_point_selection_logits
    p
    q

# Follow implementation here: https://github.com/Lightning-AI/lightning-bolts/blob/master/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py

In [None]:
from molecule_generation.utils.training_utils import get_class_balancing_weights
from pytorch_lightning import LightningModule, Trainer, seed_everything



class BaseModel(LightningModule):
    def __init__(self, params, dataset):
        """Params is a nested dictionary with the relevant parameters."""
        super(BaseModel, self).__init__()
        self._init_params(params, dataset)
        self._motif_aware_embedding_layer = Embedding(params['motif_aware_embedding'])
        self._encoder_gnn = FullGraphEncoder(params['encoder_gnn'])
        self._decoder_gnn = PartialGraphEncoder(params['decoder_gnn'])
        self._decoder_mlp = MLPDecoder(params['decoder_mlp'])
        
        # params for latent space
        self._latent_sample_strategy = params['latent_sample_strategy']
        self._latent_repr_dim = params["latent_repr_size"]
        
        
               
    def _init_params(self, params, dataset):
        """
        Initialise class weights for next node prediction and placefolder for
        motif/node embeddings.
        """
        
        # Get some information out from the dataset:
        next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
        class_weight_factor = self._params.get("node_type_predictor_class_loss_weight_factor", 0.0)
        
        if not (0.0 <= class_weight_factor <= 1.0):
            raise ValueError(
                f"Node class loss weight node_classifier_class_loss_weight_factor must be in [0,1], but is {class_weight_factor}!"
            )
        if class_weight_factor > 0:
            atom_type_nums = [
                next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
                for type_idx in range(dataset.num_node_types)
            ]
            atom_type_nums.append(next_node_type_distribution["None"])

            self.class_weights = get_class_balancing_weights(
                class_counts=atom_type_nums, class_weight_factor=class_weight_factor
            )
        else:
            self.class_weights = None
            
        motif_vocabulary = dataset.metadata.get("motif_vocabulary")
        self._uses_motifs = motif_vocabulary is not None

        self._node_categorical_num_classes = dataset.node_categorical_num_classes
        
        
        if self.uses_categorical_features:
            if "categorical_features_embedding_dim" in self._params:
                self._node_categorical_features_embedding = None
        
    @property
    def uses_motifs(self):
        return self._uses_motifs

    @property
    def uses_categorical_features(self):
        return self._node_categorical_num_classes is not None

    @property
    def decoder(self):
        return self._decoder

    @property
    def encoder(self):
        return self._encoder
    
    @property
    def motif_aware_embedding_layer(self):
        return self._motif_aware_embedding_layer
    
    @property
    def latent_dim(self):
        return self._latent_repr_dim
    
    def compute_initial_node_features(batch )
        # Compute embedding
        pass
        
    
    def sample_from_latent_repr(latent_repr):
        # perturb latent repr
        mu = latent_repr[:, : self.latent_dim]  # Shape: [V, MD]
        log_var = latent_repr[:, self.latent_dim :]  # Shape: [V, MD]

        # result_representations: shape [num_partial_graphs, latent_repr_dim]
        p, q, z = self.sample(mu, log_var)
        
        return p, q, z 
        
    def sample(self, mu, log_var)
        """Samples a different noise vector for each partial graph. 
        TODO: look into the other sampling strategies."""
        std = torch.exp(log_var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return p, q, z
    
    
    def forward(self, x, edge_index, edge_attr, batch, ??):
        # Obtain node embeddings 
        batch = self.compute_initial_node_features(batch)
        
        # Forward pass through encoder
        latent_repr = self.encoder(batch)
        
        # Apply latent sampling strategy
        p, q, latent_repr = self.sample_from_latent_repr(latent_repr)
        
        # Forward pass through decoder
        node_type_logits, edge_candidate_logits, edge_type_logits, attachment_point_selection_logits = self.decoder(latent_repr)
        
        # NOTE: loss computation will be done in lightning module
        return MoLeROutput(
            node_type_logits = node_type_logits,
            edge_candidate_logits = edge_candidate_logits,
            edge_type_logits = edge_type_logits,
            attachment_point_selection_logits = attachment_point_selection_logits,
            p = p,
            q = q,
        )
    
        
        
        
#         x = self.conv1(x, edge_index, edge_attr)
#         x = x.relu()
#         x = self.conv2(x, edge_index, edge_attr)
#         x = x.relu()
#         x = self.conv3(x, edge_index, edge_attr)
#         return x
#     def _compute_decoder_loss_and_metrics(
#         self, batch_features, task_output, batch_labels
#     ) -> Tuple[tf.Tensor, MoLeRDecoderMetrics]:

#         decoder_metrics = self.decoder.compute_metrics(
#             batch_features=batch_features, batch_labels=batch_labels, task_output=task_output
#         )

#         total_loss = (
#             self._params["node_classification_loss_weight"]
#             * decoder_metrics.node_classification_loss
#             + self._params["first_node_classification_loss_weight"]
#             * decoder_metrics.first_node_classification_loss
#             + self._params["edge_selection_loss_weight"] * decoder_metrics.edge_loss
#             + self._params["edge_type_loss_weight"] * decoder_metrics.edge_type_loss
#         )

#         if self.uses_motifs:
#             total_loss += (
#                 self._params["attachment_point_selection_weight"]
#                 * decoder_metrics.attachment_point_selection_loss
#             )

#         return total_loss, decoder_metrics

In [84]:
class GraphEncoder(torch.nn.Module):
    """For constructing graph level embedding during encoder step"""
    def __init__(self, num_node_features, hidden_channels):
        super(GraphEncoder, self).__init__()
        # torch.manual_seed(12345)
        self.conv1 = GATConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels)
        self.conv3 = GATConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index, edge_attr, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index, edge_attr)
        x = x.leaky_relu()
        x = self.conv2(x, edge_index, edge_attr)
        x = x.leaky_relu()
        x = self.conv3(x, edge_index, edge_attr)
        x = 
        return x


class PartialGraphEncoder(torch.nn.Module):
    """For constructing graph level embedding during decoder steps"""
    def __init__(self, num_node_features, hidden_channels):
        super(PartialGraphEncoder, self).__init__()
        # torch.manual_seed(12345)
        self.conv1 = GATConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels)
        self.conv3 = GATConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index, edge_attr, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index, edge_attr)
        x = x.relu()
        x = self.conv2(x, edge_index, edge_attr)
        x = x.relu()
        x = self.conv3(x, edge_index, edge_attr)
        return x
  

class PickAtomOrMotifMLP(torch.nn.Module):
    """
    For choosing an atom/motif out of the motif vocabulary
    Notes:
    Softmax layer at the end with 
    
    """
    def __init__(self, motif_vocabulary, latent_vector_dim, hidden_channels=256):
        super(PickAtomOrMotifMLP, self).__init__()
        self.hidden_layer1 = Linear(latent_vector_dim, hidden_channels)
        self.hidden_layer2 = Linear(hidden_channels, len(motif_vocabulary) + 1) # add 1 for <END OF GENERATION TOKEN>
        self.activation = torch.Softmax
        
    def forward(self, latent_vector, partial_graph_embedding):
        x = self.hidden_layer1(latent_vector)
        x = self.hidden_layer2(x)
        return self.activation(x)
        

class PickAttachmentMLP(torch.nn.Module):
    def __init__(self, motif_dictionary, hidden_channels):
        super(PickAttachmentMLP, self).__init__()
        pass
    def forward(self, latent_vector, partial_graph_embedding):
        pass

class PickBond(torch.nn.Module):
    def __init__(self, motif_dictionary, hidden_channels):
        super(PickAttachmentMLP, self).__init__()
        pass
    def forward(self, latent_vector, partial_graph_embedding):
        pass

# TODOs

1. Implement the learned aggregation function in Moler Paper by subclassing MessagePassing layer 
2. Change the init function of the encoder to take in the relevant hyperparameters from the `moler_vae` file
3. Figure out which of the latent sample strategies work best:
- pass through
- per graph
- per partial graph