In [1]:
import pandas as pd
df = pd.read_csv('/data/ongh0068/l1000/pyg_output_playground/train/processed_file_paths.csv')

In [2]:
import os
os.chdir('/data/ongh0068/l1000/moler_reference')

In [3]:
from torch_geometric.data import Dataset
import torch
import numpy as np
import os
import gzip
import pickle
from molecule_generation.chem.motif_utils import get_motif_type_to_node_type_index_map
from torch_geometric.data import Data

def get_motif_type_to_node_type_index_map(
    motif_vocabulary, num_atom_types
):
    """Helper to construct a mapping from motif type to shifted node type."""

    return {
        motif: num_atom_types + motif_type
        for motif, motif_type in motif_vocabulary.vocabulary.items()
    }


class MolerDataset(Dataset):
    def __init__(
        self, 
        root, 
        raw_moler_trace_dataset_parent_folder, # absolute path 
        output_pyg_trace_dataset_parent_folder, # absolute path
        split = 'train',
        transform=None, 
        pre_transform=None, 
    ):
        self._processed_file_paths = None
        self._transform = transform 
        self._pre_transform = pre_transform
        self._raw_moler_trace_dataset_parent_folder = raw_moler_trace_dataset_parent_folder
        self._output_pyg_trace_dataset_parent_folder = output_pyg_trace_dataset_parent_folder
        self._split = split
        self.load_metadata()
        super().__init__(root, transform, pre_transform)
        

    @property
    def raw_file_names(self):
        """
        Raw generation trace files output from the preprocess function of the cli. These are zipped pickle
        files. This is the actual file name without the parent folder.
        """
        raw_pkl_file_folders = [folder for folder in os.listdir(self._raw_moler_trace_dataset_parent_folder) if folder.startswith(self._split)]

        assert len(raw_pkl_file_folders) > 0, f'{self._raw_moler_trace_dataset_parent_folder} does not contain {self._split} files.'
        
        raw_generation_trace_files = []
        for folder in raw_pkl_file_folders:
            for pkl_file in os.listdir(os.path.join(self._raw_moler_trace_dataset_parent_folder, folder)):
                raw_generation_trace_files.append(os.path.join(self._raw_moler_trace_dataset_parent_folder, folder, pkl_file))
        return raw_generation_trace_files

    @property
    def processed_file_names(self):
        """Processed generation trace objects that are stored as .pt files"""
        processed_file_paths_folder = os.path.join(self._output_pyg_trace_dataset_parent_folder, self._split)
        if self._processed_file_paths is not None:
            return self._processed_file_paths
        
        if not os.path.exists(processed_file_paths_folder):
            os.mkdir(processed_file_paths_folder)
            
        # After processing, the file paths will be saved in the csv file 
        processed_file_paths_csv = os.path.join(processed_file_paths_folder, 'processed_file_paths.csv')

        if not os.path.exists(processed_file_paths_csv):
            is_csv_generated = self._generate_processed_file_paths_csv(processed_file_paths_folder, processed_file_paths_csv)
            if not is_csv_generated:
                return []
        self._processed_file_paths = pd.read_csv(processed_file_paths_csv)['file_names'].tolist()
        return self._processed_file_paths 

    def _generate_processed_file_paths_csv(self, processed_file_paths_folder, processed_file_paths_csv):
        
        file_paths = [os.path.join(processed_file_paths_folder, file_path) for file_path in os.listdir(processed_file_paths_folder)]
        if len(file_paths) > 0:
            df = pd.DataFrame(file_paths, columns = ['file_names'])
            df.to_csv(processed_file_paths_csv, index = False)
            return True
        else:
            print('No processed files found!')

    @property
    def processed_file_names_size(self):
        return len(self.processed_file_names)
    
    @property 
    def metadata(self):
        return self._metadata

    @property
    def node_type_index_to_string(self):
        return self._node_type_index_to_string
    
    @property 
    def num_node_types(self):
        return len(self.node_type_index_to_string)

    def node_type_to_index(self, node_type):
        return self._atom_type_featuriser.type_name_to_index(node_type)

    def node_types_to_indices(self, node_types):
        """Convert list of string representations into list of integer indices."""
        return [self.node_type_to_index(node_type) for node_type in node_types]

    def node_types_to_multi_hot(self, node_types):
        """Convert between string representation to multi hot encoding of correct node types.

        Note: implemented here for backwards compatibility only.
        """
        correct_indices = self.node_types_to_indices(node_types)
        multihot = np.zeros(shape=(self.num_node_types,), dtype=np.float32)
        for idx in correct_indices:
            multihot[idx] = 1.0
        return multihot
    
    def node_type_to_index(self, node_type):
        motif_node_type_index = self._motif_to_node_type_index.get(node_type)

        if motif_node_type_index is not None:
            return motif_node_type_index
        else:
            return self._atom_type_featuriser.type_name_to_index(node_type)
    
    def load_metadata(self):
        metadata_file_path = os.path.join(self._raw_moler_trace_dataset_parent_folder, 'metadata.pkl.gz')
        
        with gzip.open(metadata_file_path, 'rb') as f:
             self._metadata = pickle.load(f)
        
        self._atom_type_featuriser = next(
            featuriser
            for featuriser in self._metadata["feature_extractors"]
            if featuriser.name == "AtomType"
        )
        
        self._node_type_index_to_string = self._atom_type_featuriser.index_to_atom_type_map.copy()
        self._motif_vocabulary = self.metadata.get("motif_vocabulary")

        if self._motif_vocabulary is not None:
            self._motif_to_node_type_index = get_motif_type_to_node_type_index_map(
                motif_vocabulary=self._motif_vocabulary,
                num_atom_types=len(self._node_type_index_to_string),
            )

            for motif, node_type in self._motif_to_node_type_index.items():
                self._node_type_index_to_string[node_type] = motif
        else:
            self._motif_to_node_type_index = {}
        

    def process(self):
        """Convert raw generation traces into individual .pt files for each of the trace steps."""
        # only call process if it was not called before
        if self.processed_file_names_size > 0:
            pass
        else:
            
            for pkl_file_path in self.raw_file_names:
                generation_steps = self._convert_data_shard_to_list_of_trace_steps(pkl_file_path)
                
                for molecule_idx, molecule_gen_steps in generation_steps:
                    
                    for step_idx, step in enumerate(molecule_gen_steps):
                        file_name = f'{pkl_file_path.split("/")[-1].split(".")[0]}_mol_{molecule_idx}_step_{step_idx}.pt'
                        file_path = os.path.join(self._output_pyg_trace_dataset_parent_folder, self._split, file_name)
                        torch.save(step, file_path)
                        print(f'Processing {molecule_idx}, step {step_idx}')
            print(f'{self.processed_file_names_size} files generated')         
            

    def _convert_data_shard_to_list_of_trace_steps(self, pkl_file_path):
        # TODO: multiprocessing to speed this up
        generation_steps = []
        
        with gzip.open(pkl_file_path, 'rb') as f:
            molecules = pickle.load(f)
            for molecule_idx, molecule in enumerate(molecules): 
                generation_steps += [(molecule_idx, self._extract_generation_steps(molecule))]
        
        return generation_steps
            
    def _extract_generation_steps(self, molecule):
        """Packages each generation step of each molecule into a pyg Data object."""
        molecule_gen_steps = []
        molecule_property_values = dict(molecule.graph_property_values)
        for gen_step in molecule:
            gen_step_features = {}
            gen_step_features['x'] = gen_step.partial_node_features
            gen_step_features['focus_node'] = gen_step.focus_node

            # have an edge type attribute to tell apart each of the 3 bond types
            edge_indexes = []
            edge_types= []
            for i, adj_list in enumerate(gen_step.partial_adjacency_lists):
                if len(adj_list) != 0:
                    edge_index = torch.tensor(adj_list).T
                    edge_indexes += [edge_index]
                    edge_types += [i]*len(adj_list)

            gen_step_features['edge_index'] = torch.cat(edge_indexes, 1) if len(edge_indexes) > 0 else torch.tensor(edge_indexes)
            gen_step_features['edge_type'] = torch.tensor(edge_types)
            gen_step_features['correct_edge_choices'] = gen_step.correct_edge_choices

            num_correct_edge_choices = np.sum(gen_step.correct_edge_choices)
            gen_step_features['num_correct_edge_choices'] = num_correct_edge_choices
            gen_step_features['stop_node_label'] = int(num_correct_edge_choices == 0)
            gen_step_features['valid_edge_choices'] = gen_step.valid_edge_choices
            gen_step_features["correct_edge_types"] = gen_step.correct_edge_types
            gen_step_features["partial_node_categorical_features"] = gen_step.partial_node_categorical_features
            if gen_step.correct_attachment_point_choice is not None:
                gen_step_features["correct_attachment_point_choice"] = list(gen_step.valid_attachment_point_choices).index(gen_step.correct_attachment_point_choice)
            else:
                gen_step_features["correct_attachment_point_choice"] = []
            gen_step_features["valid_attachment_point_choices"] = gen_step.valid_attachment_point_choices

            # And finally, the correct node type choices. Here, we have an empty list of
            # correct choices for all steps where we didn't choose a node, so we skip that:
            if gen_step.correct_node_type_choices is not None:
                gen_step_features["correct_node_type_choices"] = self.node_types_to_multi_hot(gen_step.correct_node_type_choices)
            else:
                gen_step_features["correct_node_type_choices"] = []
            gen_step_features['correct_first_node_type_choices'] = self.node_types_to_multi_hot(molecule.correct_first_node_type_choices)

            # Add graph_property_values
            gen_step_features = {**gen_step_features, **molecule_property_values}
        
            molecule_gen_steps += [gen_step_features]
        molecule_gen_steps = self._to_tensor_moler(molecule_gen_steps)
        return [Data(**step) for step in molecule_gen_steps]
    
    def _to_tensor_moler(self, molecule_gen_steps):
        for i in range(len(molecule_gen_steps)):
            for k,v in molecule_gen_steps[i].items():
                molecule_gen_steps[i][k] = torch.tensor(molecule_gen_steps[i][k])
        return molecule_gen_steps
        
    def len(self):
        return self.processed_file_names_size

    def get(self, idx):
        file_path = self.processed_file_names[idx]
        data = torch.load(file_path)
        return data

2022-12-13 11:52:06.296405: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
dataset = MolerDataset(
    root = '/data/ongh0068', 
    raw_moler_trace_dataset_parent_folder = '/data/ongh0068/l1000/trace_playground',
    output_pyg_trace_dataset_parent_folder = '/data/ongh0068/l1000/pyg_output_playground',
    split = 'train',
)

In [5]:
dataset

MolerDataset(2384)

In [6]:
from torch_geometric.loader import DataLoader
loader = DataLoader(dataset, batch_size=16, shuffle=False, follow_batch = [
    'correct_edge_choices',
    'correct_edge_types',
    'valid_edge_choices',
    'valid_attachment_point_choices',
    'correct_attachment_point_choice',
    'correct_node_type_choices'
])

In [7]:
def pprint_pyg_obj(batch):
    for key in vars(batch)['_store'].keys():
        if key.startswith('_'):
            continue
        print(f'{key}: {batch[key].shape}')
for batch in loader:
    pprint_pyg_obj(batch)
    break

x: torch.Size([126, 32])
edge_index: torch.Size([2, 212])
focus_node: torch.Size([16])
edge_type: torch.Size([212])
correct_edge_choices: torch.Size([53])
correct_edge_choices_batch: torch.Size([53])
correct_edge_choices_ptr: torch.Size([17])
num_correct_edge_choices: torch.Size([16])
stop_node_label: torch.Size([16])
valid_edge_choices: torch.Size([53, 2])
valid_edge_choices_batch: torch.Size([53])
valid_edge_choices_ptr: torch.Size([17])
correct_edge_types: torch.Size([9, 3])
correct_edge_types_batch: torch.Size([9])
correct_edge_types_ptr: torch.Size([17])
partial_node_categorical_features: torch.Size([126])
correct_attachment_point_choice: torch.Size([0])
correct_attachment_point_choice_batch: torch.Size([0])
correct_attachment_point_choice_ptr: torch.Size([17])
valid_attachment_point_choices: torch.Size([0])
valid_attachment_point_choices_batch: torch.Size([0])
valid_attachment_point_choices_ptr: torch.Size([17])
correct_node_type_choices: torch.Size([1112])
correct_node_type_choi

In [8]:
# Get some information out from the dataset:
next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
atom_type_nums = [
    next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
    for type_idx in range(dataset.num_node_types)
]

In [9]:
batch.partial_node_categorical_features

tensor([ 16,  16,  16,  16,  16,  16,  16,  16, 129,  16,  16,  16,  16, 129,
         16,  16,  16,  16, 129, 129,  16,  16,  16,  16, 129, 129,  16,  16,
         16,  16, 129, 129, 129,  16,  16,  16,  16, 129, 129, 129,  16,  16,
         16,  16, 129, 129, 129, 129,  16,  16,  16,  16, 129, 129, 129, 129,
         16,  16,  16,  16, 129, 129, 129, 129, 129,  16,  16,  16,  16, 129,
        129, 129, 129, 129,  16,  16,  16,  16, 129, 129, 129, 129, 129, 129,
         16,  16,  16,  16, 129, 129, 129, 129, 129, 129,  16,  16,  16,  16,
        129, 129, 129, 129, 129, 129,  16,  16,  16,  16, 129, 129, 129, 129,
        129, 129, 131,  16,  16,  16,  16, 129, 129, 129, 129, 129, 129, 131])

# Encoder
Depending on the representation of the molecular graph (whether bond type is represented as edge type or as edge attributes), we have to change the function signature of the forward method

1. Try out both
2. tune number of heads
3. Look into softmax aggregation (hyperparams) + also implement sigmoidaggregation as a separate aggregation layer
4. Set number of heads as a hyperparam
5. Add leakyrelu() + layer norm

In [10]:
vocab_size = len(dataset.node_type_index_to_string)
embedding_size = 64
for batch in loader:
    pprint_pyg_obj(batch)
    break
    
embed = torch.nn.Embedding(vocab_size, embedding_size)
motif_embeddings = embed(batch.partial_node_categorical_features)
batch.x = torch.cat((batch.x, motif_embeddings), axis = -1)

from torch_geometric.nn import RGATConv
resize_layer = RGATConv(
in_channels = batch.x.shape[-1], 
out_channels = 64, 
num_relations = 3
)


batch.x = resize_layer(batch.x, batch.edge_index.long(), batch.edge_type)


x: torch.Size([126, 32])
edge_index: torch.Size([2, 212])
focus_node: torch.Size([16])
edge_type: torch.Size([212])
correct_edge_choices: torch.Size([53])
correct_edge_choices_batch: torch.Size([53])
correct_edge_choices_ptr: torch.Size([17])
num_correct_edge_choices: torch.Size([16])
stop_node_label: torch.Size([16])
valid_edge_choices: torch.Size([53, 2])
valid_edge_choices_batch: torch.Size([53])
valid_edge_choices_ptr: torch.Size([17])
correct_edge_types: torch.Size([9, 3])
correct_edge_types_batch: torch.Size([9])
correct_edge_types_ptr: torch.Size([17])
partial_node_categorical_features: torch.Size([126])
correct_attachment_point_choice: torch.Size([0])
correct_attachment_point_choice_batch: torch.Size([0])
correct_attachment_point_choice_ptr: torch.Size([17])
valid_attachment_point_choices: torch.Size([0])
valid_attachment_point_choices_batch: torch.Size([0])
valid_attachment_point_choices_ptr: torch.Size([17])
correct_node_type_choices: torch.Size([1112])
correct_node_type_choi

In [46]:
from torch_geometric.nn import RGATConv
from torch_geometric.nn import aggr
class GraphEncoder(torch.nn.Module):
    """For constructing graph level embedding during encoder step"""
    def __init__(
        self, 
        input_feature_dim,
        node_feature_dim = 64,
        num_layers = 12,
        use_intermediate_gnn_results = True,
    ):
        super(GraphEncoder, self).__init__()
        self._first_layer = RGATConv(
            in_channels = input_feature_dim, 
            out_channels = node_feature_dim, 
            num_relations = 3
        )
        
        self._encoder_layers = torch.nn.ModuleList(
            [RGATConv(
                in_channels = node_feature_dim, 
                out_channels = node_feature_dim, 
                num_relations = 3
                ) for _ in range(num_layers)
            ]
        )
        self._softmax_aggr = aggr.SoftmaxAggregation(learn=True)
        self._use_intermediate_gnn_results = use_intermediate_gnn_results
    def forward(self, node_features, edge_index, edge_type, batch_index):
        gnn_results = []
        gnn_results += [self._first_layer(node_features, edge_index.long(), edge_type)]
        print(gnn_results[-1].shape)
        for i, layer in enumerate(self._encoder_layers):
            gnn_results += [layer(gnn_results[-1], edge_index.long(), edge_type)]
        
        if self._use_intermediate_gnn_results:
            x = torch.cat(gnn_results, axis = -1)
            graph_representations = self._softmax_aggr(x, batch_index)
        
        else:
            graph_representations = self._softmax_aggr(gnn_results[-1], batch_index)
        node_representations = torch.cat(gnn_results, axis= -1)
        return graph_representations, node_representations

In [47]:
model = GraphEncoder(input_feature_dim = batch.x.shape[-1])
input_molecule_representations, _ = model(batch.x, batch.edge_index.long(), batch.edge_type, batch.batch)

torch.Size([126, 64])


In [13]:
input_molecule_representations.shape

torch.Size([16, 832])

In [14]:
latent_dim = 512

In [15]:
class MeanAndLogVarMLP(torch.nn.Module):
    """For constructing graph level embedding during encoder step"""
    def __init__(
        self, 
        input_dim = 768,
        latent_dim = 512,
        num_layers = 1,
    ):
        super(MeanAndLogVarMLP, self).__init__()
        self._mean_log_var_mlp = torch.nn.Linear(input_dim, 2*latent_dim)

    def forward(self, x):
        x = self._mean_log_var_mlp(x)
        
        return x




In [17]:
mean_log_var_mlp = MeanAndLogVarMLP(input_dim = input_molecule_representations.shape[-1])
x = mean_log_var_mlp(input_molecule_representations)

In [18]:
x.shape

# this represents the aggregated graph representations after passing through 
# the mean and log var mlp => we need to sample from this distribution
# for each of the partial graphs. There are 16 partial graphs, so 
# we need to get a separate sample for each of them


torch.Size([16, 1024])

In [19]:
mu = x[:, : latent_dim]  # Shape: [V, MD]
log_var = x[:, latent_dim :]  # Shape: [V, MD]

# result_representations: shape [num_partial_graphs, latent_repr_dim]
std = torch.exp(log_var / 2)
p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
q = torch.distributions.Normal(mu, std)
z = q.rsample()

In [20]:
z.shape

torch.Size([16, 512])

# Decoder
 Every MLP should condition on the latent representation (currently computed from the partial graph itself which is wrong, should be from the original full molecular graph)

## PickAtomOrMotif + Node type selection loss

In [48]:
decoder_gnn = GraphEncoder(input_feature_dim = batch.x.shape[-1])
partial_graph_representions, node_representations = model(batch.x, batch.edge_index.long(), batch.edge_type, batch.batch)

torch.Size([126, 64])


In [33]:
partial_graph_representions.shape

torch.Size([16, 832])

In [34]:
pprint_pyg_obj(batch)


x: torch.Size([126, 64])
edge_index: torch.Size([2, 212])
focus_node: torch.Size([16])
edge_type: torch.Size([212])
correct_edge_choices: torch.Size([53])
correct_edge_choices_batch: torch.Size([53])
correct_edge_choices_ptr: torch.Size([17])
num_correct_edge_choices: torch.Size([16])
stop_node_label: torch.Size([16])
valid_edge_choices: torch.Size([53, 2])
valid_edge_choices_batch: torch.Size([53])
valid_edge_choices_ptr: torch.Size([17])
correct_edge_types: torch.Size([9, 3])
correct_edge_types_batch: torch.Size([9])
correct_edge_types_ptr: torch.Size([17])
partial_node_categorical_features: torch.Size([126])
correct_attachment_point_choice: torch.Size([0])
correct_attachment_point_choice_batch: torch.Size([0])
correct_attachment_point_choice_ptr: torch.Size([17])
valid_attachment_point_choices: torch.Size([0])
valid_attachment_point_choices_batch: torch.Size([0])
valid_attachment_point_choices_ptr: torch.Size([17])
correct_node_type_choices: torch.Size([1112])
correct_node_type_choi

In [35]:
batch.correct_node_type_choices_batch.unique() # graphs requiring node choices

tensor([ 0,  2,  4,  6,  8, 10, 13, 15])

In [36]:
batch.correct_node_type_choices_ptr

tensor([   0,  139,  139,  278,  278,  417,  417,  556,  556,  695,  695,  834,
         834,  834,  973,  973, 1112])

In [38]:
graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique()

In [39]:
start_idx_graphs_requiring_node_choices = batch.correct_node_type_choices_ptr[graphs_requiring_node_choices]

In [40]:
start_idx_graphs_requiring_node_choices

tensor([  0, 139, 278, 417, 556, 695, 834, 973])

In [41]:
node_type_multihot_labels = []
for i in range(len(batch.correct_node_type_choices_ptr)-1):
    start_idx = batch.correct_node_type_choices_ptr[i]
    end_idx = batch.correct_node_type_choices_ptr[i+1] 
    if end_idx - start_idx == 0:
        continue
    node_selection_labels = batch.correct_node_type_choices[start_idx: end_idx]
    node_type_multihot_labels += [node_selection_labels]
    
node_type_multihot_labels = torch.stack(node_type_multihot_labels, axis = 0)

In [42]:
node_type_multihot_labels

tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.]])

In [93]:
from torch.nn import Linear
class PickAtomOrMotifMLP(torch.nn.Module):
    """
    For choosing an atom/motif out of the motif vocabulary
    Notes:
    Softmax layer at the end with 
    
    """
    def __init__(self, num_node_types, latent_vector_dim, hidden_channels=64):
        super(PickAtomOrMotifMLP, self).__init__()
        self.hidden_layer1 = Linear(latent_vector_dim, hidden_channels)
        self.hidden_layer2 = Linear(hidden_channels, num_node_types + 1) # add 1 for <END OF GENERATION TOKEN>

        
    def forward(self, original_and_calculated_graph_representations):
        print(original_and_calculated_graph_representations.shape)
        x = self.hidden_layer1(original_and_calculated_graph_representations)
        x = self.hidden_layer2(x)
        return x

In [116]:
node_type_selector = PickAtomOrMotifMLP(
    num_node_types = dataset.num_node_types,
    latent_vector_dim = z.shape[-1] + partial_graph_representions.shape[-1]
)

In [117]:
# result_representations: shape [PG, MD]
# number of partial graphs, molecule representation dimension

def pick_node_type(
    input_molecule_representations,
    graph_representations,
    graphs_requiring_node_choices
):
    
    relevant_graph_representations = input_molecule_representations[graphs_requiring_node_choices]
    relevant_input_molecule_representations = graph_representations[graphs_requiring_node_choices]
    original_and_calculated_graph_representations = torch.cat((relevant_graph_representations,relevant_input_molecule_representations), axis = -1)
    print(original_and_calculated_graph_representations.shape)
    return node_type_selector(original_and_calculated_graph_representations)

output = pick_node_type(z, partial_graph_representions, graphs_requiring_node_choices)


torch.Size([8, 1344])
torch.Size([8, 1344])


In [118]:
output.shape

torch.Size([8, 140])

In [171]:
def safe_divide_loss(loss, num_choices):
    """Divide `loss` by `num_choices`, but guard against `num_choices` being 0."""
    return loss / max(num_choices, 1.0) 
SMALL_NUMBER, BIG_NUMBER = 1e-7, 1e7
def compute_neglogprob_for_multihot_objective(
    logprobs,
    multihot_labels,
    per_decision_num_correct_choices,
):
    # Normalise by number of correct choices and mask out entries for wrong decisions:
    return -(
        (logprobs + torch.log(per_decision_num_correct_choices + SMALL_NUMBER))
        * multihot_labels
        / (per_decision_num_correct_choices + SMALL_NUMBER)
    )

def compute_node_type_selection_loss(
    node_type_logits,
    node_type_multihot_labels
):
    per_node_decision_logprobs = torch.nn.functional.log_softmax(node_type_logits, dim = -1)
    # Shape: [NTP, NT + 1]
    
    # number of correct choices for each of the partial graphs that require node choices
    per_node_decision_num_correct_choices = torch.sum(node_type_multihot_labels, keepdim = True, axis = -1)
    # Shape [NTP, 1]
    
    per_correct_node_decision_normalised_neglogprob = compute_neglogprob_for_multihot_objective(
        logprobs = per_node_decision_logprobs[:, :-1], # separate out the no node prediction
        multihot_labels = node_type_multihot_labels,
        per_decision_num_correct_choices = per_node_decision_num_correct_choices,
    ) # Shape [NTP, NT]
    
    no_node_decision_correct =(per_node_decision_num_correct_choices == 0.0).sum()  # Shape [NTP]
    per_correct_no_node_decision_neglogprob = -(
        per_node_decision_logprobs[:, -1]
        * torch.squeeze(no_node_decision_correct).type(torch.FloatTensor)
    )  # Shape [NTP]

#     if self._node_type_loss_weights is not None:
#         per_correct_node_decision_normalised_neglogprob *= self._node_type_loss_weights[:-1]
#         per_correct_no_node_decision_neglogprob *= self._node_type_loss_weights[-1]
    
    # Loss is the sum of the masked (no) node decisions, averaged over number of decisions made:
    total_node_type_loss = torch.sum(
        per_correct_node_decision_normalised_neglogprob
    ) + torch.sum(per_correct_no_node_decision_neglogprob)
    node_type_loss = safe_divide_loss(
        total_node_type_loss, node_type_multihot_labels.shape[0]
    )

    return node_type_loss

In [173]:
node_type_selection_loss = compute_node_type_selection_loss(
    node_type_logits = output,
    node_type_multihot_labels = node_type_multihot_labels
)

In [175]:
node_type_selection_loss.backward()


## Pick edge and edge type + Edge candidate loss

1. Find out what stop node refers to

In [176]:
def pick_edge(
    input_molecule_representations,
    graph_representations,
    node_representations,
    graph_to_focus_node_map,
    node_to_graph_map,
    candidate_edge_targets,
    candidate_edge_features,
    graphs_requiring_node_choices
):
    # given candidate edges in the partial graphs, 
    #compute logits for the likelihood of adding candidates as well as 
    # logits for the type of edge if it is picked

IndentationError: expected an indented block (2967395225.py, line 7)

In [25]:
pprint_pyg_obj(batch)

x: torch.Size([126, 64])
edge_index: torch.Size([2, 212])
focus_node: torch.Size([16])
edge_type: torch.Size([212])
correct_edge_choices: torch.Size([53])
correct_edge_choices_batch: torch.Size([53])
correct_edge_choices_ptr: torch.Size([17])
num_correct_edge_choices: torch.Size([16])
stop_node_label: torch.Size([16])
valid_edge_choices: torch.Size([53, 2])
valid_edge_choices_batch: torch.Size([53])
valid_edge_choices_ptr: torch.Size([17])
correct_edge_types: torch.Size([9, 3])
correct_edge_types_batch: torch.Size([9])
correct_edge_types_ptr: torch.Size([17])
partial_node_categorical_features: torch.Size([126])
correct_attachment_point_choice: torch.Size([0])
correct_attachment_point_choice_batch: torch.Size([0])
correct_attachment_point_choice_ptr: torch.Size([17])
valid_attachment_point_choices: torch.Size([0])
valid_attachment_point_choices_batch: torch.Size([0])
valid_attachment_point_choices_ptr: torch.Size([17])
correct_node_type_choices: torch.Size([1112])
correct_node_type_choi

In [128]:
focus_node_idx_in_batch = batch.ptr[:-1] + batch.focus_node # offset each focus node idx by the starting idx of the molecule

In [130]:
focus_node_idx_in_batch

tensor([  0,   8,  13,  19,  25,  32,  39,  47,  55,  64,  73,  83,  93, 103,
        114, 125])

In [59]:
node_representations[focus_node_idx_in_batch].shape # extract focus node node level repr

torch.Size([16, 832])

In [60]:
focus_node_representations = node_representations[focus_node_idx]

In [None]:
graph_and_focus_node_representations = torch.cat(
    (input_molecule_representations, partial_graph_representions, focus_node_representations),
    axis = -1
)

# Explanation: at each step, there is a focus node, which is the node we are 
# focusing on right now in terms of adding another edge to it. When adding a new
# edge, the edge can be between the focus node and a variety of other nodes.
# This is likely based on valency, and in reality, it is possible that none of the
# edge choices are correct (when that generation step is a node addition step)
# and not an edge addition step. Regardless, we still want to consider the candidates
#"target" refers to the node at the other end of the candidate edge
candidate_edge_targets = batch.valid_edge_choices[:, 1]

valid_target_to_graph_map = 

In [112]:
batch.valid_edge_choices[:, 1] # the 2nd element in each edge is the target since
# the 1st element is the focus node at that step

tensor([0, 2, 0, 0, 4, 0, 0, 4, 5, 0, 4, 0, 4, 5, 6, 0, 4, 5, 0, 4, 5, 6, 7, 0,
        4, 5, 6, 0, 4, 5, 6, 7, 8, 0, 5, 6, 7, 8, 0, 5, 6, 7, 0, 5, 6, 7, 8, 9,
        0, 5, 6, 7, 9], dtype=torch.int32)

In [131]:
candidate_edge_target_node_idx = batch.valid_edge_choices[:, 1] + batch.ptr[batch.valid_edge_choices_batch]  # idx offset for the edge choices

In [134]:
partial_graph_idx_per_edge_candidate = batch.batch[candidate_edge_target_node_idx]

In [137]:
partial_graph_idx_per_edge_candidate

tensor([ 1,  1,  2,  3,  3,  4,  5,  5,  5,  6,  6,  7,  7,  7,  7,  8,  8,  8,
         9,  9,  9,  9,  9, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 12, 12, 12,
        12, 12, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15])

In [138]:
graph_and_focus_node_representations = torch.cat(
    (input_molecule_representations, partial_graph_representions, focus_node_representations),
    axis = -1
)

In [143]:
graph_and_focus_node_representations_per_edge_candidate = graph_and_focus_node_representations[partial_graph_idx_per_edge_candidate].shape

In [144]:
candidate_edge_target_node_idx

tensor([  4,   6,   9,  14,  18,  20,  26,  30,  31,  33,  37,  40,  44,  45,
         46,  48,  52,  53,  56,  60,  61,  62,  63,  65,  69,  70,  71,  74,
         78,  79,  80,  81,  82,  84,  89,  90,  91,  92,  94,  99, 100, 101,
        104, 109, 110, 111, 112, 113, 115, 120, 121, 122, 124])

In [145]:
edge_candidate_target_node_representations = node_representations[candidate_edge_target_node_idx]

In [147]:
edge_candidate_target_node_representations.shape

torch.Size([53, 832])

In [149]:
pprint_pyg_obj(batch)

x: torch.Size([126, 64])
edge_index: torch.Size([2, 212])
focus_node: torch.Size([16])
edge_type: torch.Size([212])
correct_edge_choices: torch.Size([53])
correct_edge_choices_batch: torch.Size([53])
correct_edge_choices_ptr: torch.Size([17])
num_correct_edge_choices: torch.Size([16])
stop_node_label: torch.Size([16])
valid_edge_choices: torch.Size([53, 2])
valid_edge_choices_batch: torch.Size([53])
valid_edge_choices_ptr: torch.Size([17])
correct_edge_types: torch.Size([9, 3])
correct_edge_types_batch: torch.Size([9])
correct_edge_types_ptr: torch.Size([17])
partial_node_categorical_features: torch.Size([126])
correct_attachment_point_choice: torch.Size([0])
correct_attachment_point_choice_batch: torch.Size([0])
correct_attachment_point_choice_ptr: torch.Size([17])
valid_attachment_point_choices: torch.Size([0])
valid_attachment_point_choices_batch: torch.Size([0])
valid_attachment_point_choices_ptr: torch.Size([17])
correct_node_type_choices: torch.Size([1112])
correct_node_type_choi

In [None]:
def compute_edge_type_selection_loss(

)

    edge_loss = self.compute_edge_candidate_selection_loss(
        num_graphs_in_batch=batch_features["num_partial_graphs_in_batch"],
        node_to_graph_map=batch_features["node_to_partial_graph_map"],
        candidate_edge_targets=batch_features["valid_edge_choices"][:, 1],
        edge_candidate_logits=task_output.edge_candidate_logits,
        per_graph_num_correct_edge_choices=batch_labels["num_correct_edge_choices"],
        edge_candidate_correctness_labels=batch_labels["correct_edge_choices"],
        no_edge_selected_labels=batch_labels["stop_node_label"],
    )

Shape abbreviations used throughout the model:
- PG = number of _p_artial _g_raphs
- PV = number of _p_artial graph _v_ertices
- PD = size of _p_artial graph _representation _d_imension
- VD = GNN _v_ertex representation _d_imension
- MD = _m_olecule representation _d_imension
- EFD = _e_dge _f_eature _d_imension
- NTP = number of partial graphs requiring a _n_ode _t_ype _p_ick
- NT = number of _n_ode _t_ypes
- CE = number of _c_andidate _e_dges
- CCE = number of _c_orrect _c_andidate _e_dges
- ET = number of _e_dge _t_ypes
- CA = number of _c_andidate _a_ttachment points
- AP = number of _a_ttachment _p_oint choices

# TODO

1. Implement Sampling strategy
2. Look into why first node prediction is treated separately

Current design plan: all model specific things are in the BaseModel class, 
All loss computation and what not will be in the lightning module 
This decouples the training code, optimizers and schedulers from the model itself, which arguably makes it cleaner when doing hyperparameter tuning and also during model loading





Note: Original implementation of the GNN adds each molecule and their corresponding generation steps sequentially into the batch, so what happens is that a molecule's sequential generation steps are seen by the GNN during training in a sequential manner. In this implementation, if we use shuffle = False, we can replicate this behaviour in our dataloader, but if we use shuffle = True, then we can't. 

Possible additions to allow for molecule wise shuffling => add additional column to CSV to allow for a molecule's generation steps to be treated as one set of training samples => then shuffle this instead of the actual generation steps (try this if the training is not converging)

This has additional implications for the latent space sampling strategies.

In [None]:
from dataclasses import dataclass

@dataclass
class MoLeROutput:
    """Class for keeping track of output from MoLeR model."""
    node_type_logits
    edge_candidate_logits
    edge_type_logits
    attachment_point_selection_logits
    p
    q

# Follow implementation here: https://github.com/Lightning-AI/lightning-bolts/blob/master/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py

In [None]:
from molecule_generation.utils.training_utils import get_class_balancing_weights
from pytorch_lightning import LightningModule, Trainer, seed_everything



class BaseModel(LightningModule):
    def __init__(self, params, dataset):
        """Params is a nested dictionary with the relevant parameters."""
        super(BaseModel, self).__init__()
        self._init_params(params, dataset)
        self._motif_aware_embedding_layer = Embedding(params['motif_aware_embedding'])
        self._encoder_gnn = FullGraphEncoder(params['encoder_gnn'])
        self._decoder_gnn = PartialGraphEncoder(params['decoder_gnn'])
        self._decoder_mlp = MLPDecoder(params['decoder_mlp'])
        
        # params for latent space
        self._latent_sample_strategy = params['latent_sample_strategy']
        self._latent_repr_dim = params["latent_repr_size"]
        
        
               
    def _init_params(self, params, dataset):
        """
        Initialise class weights for next node prediction and placefolder for
        motif/node embeddings.
        """
        
        # Get some information out from the dataset:
        next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
        class_weight_factor = self._params.get("node_type_predictor_class_loss_weight_factor", 0.0)
        
        if not (0.0 <= class_weight_factor <= 1.0):
            raise ValueError(
                f"Node class loss weight node_classifier_class_loss_weight_factor must be in [0,1], but is {class_weight_factor}!"
            )
        if class_weight_factor > 0:
            atom_type_nums = [
                next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
                for type_idx in range(dataset.num_node_types)
            ]
            atom_type_nums.append(next_node_type_distribution["None"])

            self.class_weights = get_class_balancing_weights(
                class_counts=atom_type_nums, class_weight_factor=class_weight_factor
            )
        else:
            self.class_weights = None
            
        motif_vocabulary = dataset.metadata.get("motif_vocabulary")
        self._uses_motifs = motif_vocabulary is not None

        self._node_categorical_num_classes = dataset.node_categorical_num_classes
        
        
        if self.uses_categorical_features:
            if "categorical_features_embedding_dim" in self._params:
                self._node_categorical_features_embedding = None
        
    @property
    def uses_motifs(self):
        return self._uses_motifs

    @property
    def uses_categorical_features(self):
        return self._node_categorical_num_classes is not None

    @property
    def decoder(self):
        return self._decoder

    @property
    def encoder(self):
        return self._encoder
    
    @property
    def motif_aware_embedding_layer(self):
        return self._motif_aware_embedding_layer
    
    @property
    def latent_dim(self):
        return self._latent_repr_dim
    
    def compute_initial_node_features(batch )
        # Compute embedding
        pass
        
    
    def sample_from_latent_repr(latent_repr):
        # perturb latent repr
        mu = latent_repr[:, : self.latent_dim]  # Shape: [V, MD]
        log_var = latent_repr[:, self.latent_dim :]  # Shape: [V, MD]

        # result_representations: shape [num_partial_graphs, latent_repr_dim]
        p, q, z = self.sample(mu, log_var)
        
        return p, q, z 
        
    def sample(self, mu, log_var)
        """Samples a different noise vector for each partial graph. 
        TODO: look into the other sampling strategies."""
        std = torch.exp(log_var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return p, q, z
    
    
    def forward(self, x, edge_index, edge_attr, batch, ??):
        # Obtain node embeddings 
        batch = self.compute_initial_node_features(batch)
        
        # Forward pass through encoder
        latent_repr = self.encoder(batch)
        
        # Apply latent sampling strategy
        p, q, latent_repr = self.sample_from_latent_repr(latent_repr)
        
        # Forward pass through decoder
        node_type_logits, edge_candidate_logits, edge_type_logits, attachment_point_selection_logits = self.decoder(latent_repr)
        
        # NOTE: loss computation will be done in lightning module
        return MoLeROutput(
            node_type_logits = node_type_logits,
            edge_candidate_logits = edge_candidate_logits,
            edge_type_logits = edge_type_logits,
            attachment_point_selection_logits = attachment_point_selection_logits,
            p = p,
            q = q,
        )
    
        
        
        
#         x = self.conv1(x, edge_index, edge_attr)
#         x = x.relu()
#         x = self.conv2(x, edge_index, edge_attr)
#         x = x.relu()
#         x = self.conv3(x, edge_index, edge_attr)
#         return x
#     def _compute_decoder_loss_and_metrics(
#         self, batch_features, task_output, batch_labels
#     ) -> Tuple[tf.Tensor, MoLeRDecoderMetrics]:

#         decoder_metrics = self.decoder.compute_metrics(
#             batch_features=batch_features, batch_labels=batch_labels, task_output=task_output
#         )

#         total_loss = (
#             self._params["node_classification_loss_weight"]
#             * decoder_metrics.node_classification_loss
#             + self._params["first_node_classification_loss_weight"]
#             * decoder_metrics.first_node_classification_loss
#             + self._params["edge_selection_loss_weight"] * decoder_metrics.edge_loss
#             + self._params["edge_type_loss_weight"] * decoder_metrics.edge_type_loss
#         )

#         if self.uses_motifs:
#             total_loss += (
#                 self._params["attachment_point_selection_weight"]
#                 * decoder_metrics.attachment_point_selection_loss
#             )

#         return total_loss, decoder_metrics

In [84]:
class GraphEncoder(torch.nn.Module):
    """For constructing graph level embedding during encoder step"""
    def __init__(self, num_node_features, hidden_channels):
        super(GraphEncoder, self).__init__()
        # torch.manual_seed(12345)
        self.conv1 = GATConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels)
        self.conv3 = GATConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index, edge_attr, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index, edge_attr)
        x = x.leaky_relu()
        x = self.conv2(x, edge_index, edge_attr)
        x = x.leaky_relu()
        x = self.conv3(x, edge_index, edge_attr)
        x = 
        return x


class PartialGraphEncoder(torch.nn.Module):
    """For constructing graph level embedding during decoder steps"""
    def __init__(self, num_node_features, hidden_channels):
        super(PartialGraphEncoder, self).__init__()
        # torch.manual_seed(12345)
        self.conv1 = GATConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels)
        self.conv3 = GATConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index, edge_attr, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index, edge_attr)
        x = x.relu()
        x = self.conv2(x, edge_index, edge_attr)
        x = x.relu()
        x = self.conv3(x, edge_index, edge_attr)
        return x
  

class PickAtomOrMotifMLP(torch.nn.Module):
    """
    For choosing an atom/motif out of the motif vocabulary
    Notes:
    Softmax layer at the end with 
    
    """
    def __init__(self, motif_vocabulary, latent_vector_dim, hidden_channels=256):
        super(PickAtomOrMotifMLP, self).__init__()
        self.hidden_layer1 = Linear(latent_vector_dim, hidden_channels)
        self.hidden_layer2 = Linear(hidden_channels, len(motif_vocabulary) + 1) # add 1 for <END OF GENERATION TOKEN>
        self.activation = torch.Softmax
        
    def forward(self, latent_vector, partial_graph_embedding):
        x = self.hidden_layer1(latent_vector)
        x = self.hidden_layer2(x)
        return self.activation(x)
        

class PickAttachmentMLP(torch.nn.Module):
    def __init__(self, motif_dictionary, hidden_channels):
        super(PickAttachmentMLP, self).__init__()
        pass
    def forward(self, latent_vector, partial_graph_embedding):
        pass

class PickBond(torch.nn.Module):
    def __init__(self, motif_dictionary, hidden_channels):
        super(PickAttachmentMLP, self).__init__()
        pass
    def forward(self, latent_vector, partial_graph_embedding):
        pass

# TODOs

1. Implement the learned aggregation function in Moler Paper by subclassing MessagePassing layer 
2. Change the init function of the encoder to take in the relevant hyperparameters from the `moler_vae` file
3. Figure out which of the latent sample strategies work best:
- pass through
- per graph
- per partial graph