# Design 

1. All model-related loss computations and helper functions live in torch modules (`torch.nn.Module` subclasses), and the loss methods will all be called `compute_loss`.
2. The MLPDecoder will house all the MLPs involved in the decoder and will have one `compute_decoder_loss` method that calls all the individual `compute_loss` methods.

It will also have a `decode` method that will do the decoding at inference time (TODO).

3. The FullGraphEncoder and the PartialGraphEncoder will each be in their own torch modules

4. Finally, the lightning module will have 3 things: 
- FullGraphEncoder
- PartialGraphEncoder (part of decoder)
- MLPDecoder

After passing through the initial FullGraphEncoder, if we are working with a VAE, we will extract p and q for computing the KL divergence loss; otherwise we will do the other model-specific steps, such as diffusion.

`params` dictionary will be passed to the lightning module and each torch module will be constructed within it using the relevant parameters by destructuring the dictionary 

node type class weights will be instantiated in the lightning module and passed to the decoder

# TODO
1. fix the incrementing in the original graph edge index (DONE)
2. Work on first node prediction 
3. Investigate node_type_predictor_class_loss_weight_factor

In [117]:
# `params` houses all relevant model instantiation parameters. Later cells add
# entries such as params['full_graph_encoder'], which is unpacked as keyword
# arguments when the corresponding module is constructed.
params = {}

In [109]:
%load_ext autoreload
%autoreload 2

from dataset import MolerDataset, MolerData
from utils import pprint_pyg_obj
from torch_geometric.loader import DataLoader


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [110]:
# Load the 'train' split of the MoLeR trace dataset.
# NOTE: the paths below are absolute and machine-specific; adjust them before
# running this cell in another environment.
dataset = MolerDataset(
    root = '/data/ongh0068', 
    raw_moler_trace_dataset_parent_folder = '/data/ongh0068/l1000/trace_playground',
    output_pyg_trace_dataset_parent_folder = '/data/ongh0068/l1000/pyg_output_playground',
    split = 'train',
)

In [111]:
# `follow_batch` makes the loader additionally emit `<name>_batch` (and
# `<name>_ptr`) tensors for each listed attribute, mapping its rows back to the
# example they came from. Later cells rely on e.g. `original_graph_x_batch`,
# `correct_node_type_choices_batch` and `correct_node_type_choices_ptr`.
loader = DataLoader(dataset, batch_size=16, shuffle=False, follow_batch = [
    'correct_edge_choices',
    'correct_edge_types',
    'valid_edge_choices',
    'valid_attachment_point_choices',
    'correct_attachment_point_choice',
    'correct_node_type_choices',
    'original_graph_x'
])

In [112]:
# Pull just the first mini-batch out of the loader; the cells below use it as
# a fixed example to prototype against.
for batch in loader:
    break

# FullGraphEncoder

In [119]:
from model_utils import GenericGraphEncoder
import torch

In [120]:
class GraphEncoder(torch.nn.Module):
    """Encode whole molecules into graph-level representations.

    The categorical (atom/motif) id of every node is embedded, the embedding
    is appended to the node's dense feature vector, and the combined features
    are passed through a ``GenericGraphEncoder``.
    """

    def __init__(
        self,
        input_feature_dim,
        atom_or_motif_vocab_size,
        motif_embedding_size = 64,
        hidden_layer_feature_dim=64,
        num_layers=12,
        layer_type="RGATConv",
        use_intermediate_gnn_results=True,
    ):
        # NOTE(review): hidden_layer_feature_dim, num_layers, layer_type and
        # use_intermediate_gnn_results are accepted here but are not forwarded
        # to GenericGraphEncoder, which therefore runs with its own defaults.
        super().__init__()
        self._embed = torch.nn.Embedding(
            num_embeddings=atom_or_motif_vocab_size,
            embedding_dim=motif_embedding_size,
        )
        # The GNN consumes the dense features and the embedding side by side.
        self._model = GenericGraphEncoder(
            input_feature_dim=motif_embedding_size + input_feature_dim,
        )

    def forward(self, original_graph_node_categorical_features, node_features, edge_index, edge_type, batch_index):
        """Return the first (graph-level) output of the wrapped encoder."""
        embedded_node_types = self._embed(original_graph_node_categorical_features)
        # Dense features come first, the embedding second.
        gnn_input = torch.cat([node_features, embedded_node_types], dim=-1)
        encoder_outputs = self._model(gnn_input, edge_index.long(), edge_type, batch_index)
        graph_level_representations, _ = encoder_outputs
        return graph_level_representations

In [121]:
# Constructor keyword arguments for the full-graph encoder, stored in `params`
# so the module can be rebuilt from the dictionary alone.
params['full_graph_encoder'] = {
    'input_feature_dim': batch.x.shape[-1],
    'atom_or_motif_vocab_size': len(dataset.node_type_index_to_string)
}

# Direct construction with the same values, spelled out for reference ...
full_graph_encoder = GraphEncoder(
    input_feature_dim = batch.x.shape[-1],
    atom_or_motif_vocab_size = len(dataset.node_type_index_to_string)
)

# ... and the equivalent construction from `params`. This rebinds the name, so
# the instance created just above is discarded.
full_graph_encoder = GraphEncoder(**params['full_graph_encoder'])

In [131]:
# Run the full-graph encoder on the batch's `original_graph_*` tensors.
# `original_graph_x_batch` (emitted because 'original_graph_x' is listed in
# the loader's follow_batch) is passed as the node-to-graph assignment.
input_molecule_representations = full_graph_encoder(
    batch.original_graph_node_categorical_features, 
    batch.original_graph_x.float(),
    batch.original_graph_edge_index,
    batch.original_graph_edge_type,
    batch_index = batch.original_graph_x_batch,
)

# PartialGraphEncoder

In [126]:
# Constructor keyword arguments for the partial-graph encoder.
params['partial_graph_encoder'] = {
    'input_feature_dim': batch.x.shape[-1],
}

# Direct construction, spelled out for reference ...
partial_graph_encoder = GenericGraphEncoder(
    input_feature_dim = batch.x.shape[-1],
)

# ... and the equivalent construction from `params`; this rebinds the name and
# discards the instance created just above.
partial_graph_encoder = GenericGraphEncoder(**params['partial_graph_encoder'])

In [127]:
# The encoder output is unpacked into a graph-level and a node-level tensor.
# NOTE(review): `partial_graph_representions` is misspelt ("representations");
# later cells use the same spelling, so rename all occurrences together.
partial_graph_representions, node_representations = partial_graph_encoder(batch.x, batch.edge_index.long(), batch.edge_type, batch.batch)

In [128]:
node_representations.shape

torch.Size([193, 832])

# _mean_log_var_mlp

In [140]:
from model_utils import GenericMLP
# Size of the latent space. The MLP emits 2 * latent_dim values per molecule,
# which the sampling cell splits into a mean half and a log-variance half.
latent_dim = 512
params['mean_log_var_mlp'] = {
    'input_feature_dim': input_molecule_representations.shape[-1],
    'output_size': latent_dim * 2
}


mean_log_var_mlp = GenericMLP(**params['mean_log_var_mlp'])

In [141]:
mean_and_log_var = mean_log_var_mlp(input_molecule_representations)

In [142]:
# Split the MLP output column-wise: the first `latent_dim` columns are the
# posterior mean, the remaining columns the posterior log-variance.
mu, log_var = torch.split(
    mean_and_log_var,
    [latent_dim, mean_and_log_var.shape[1] - latent_dim],
    dim=1,
)  # Shapes: [V, MD] each

# log(sigma^2) / 2 = log(sigma), so exponentiating gives the standard deviation
std = (log_var / 2).exp()

# p: standard-normal prior; q: approximate posterior for this batch
p = torch.distributions.Normal(loc=torch.zeros_like(mu), scale=torch.ones_like(std))
q = torch.distributions.Normal(loc=mu, scale=std)

# Reparameterised sample, so gradients can flow back into mu and std
z = q.rsample()

In [143]:
z.shape

torch.Size([16, 512])

# Decoder

In [171]:
from utils import safe_divide_loss, compute_neglogprob_for_multihot_objective
from model_utils import GenericMLP


class MLPDecoder(torch.nn.Module):
    """Houses the decoder MLPs and the loss computations that go with them.

    Only the node-type head is currently wired into ``forward`` and
    ``compute_decoder_loss``. The edge heads are built in ``__init__`` and the
    edge-candidate loss is available as a method; attachment point selection
    is still to be added.
    """

    def __init__(
        self,
        params, # nested dictionary of parameters for each MLP
    ):
        """Build the decoder heads.

        Args:
            params: nested dictionary with the entries ``node_type_selector``,
                ``edge_candidate_scorer`` and ``edge_type_selector`` (keyword
                arguments for ``GenericMLP``), ``node_type_loss_weights``
                (per-class loss weights, or ``None`` to disable weighting) and
                ``no_more_edges_repr`` (shape of the learned "no more edges"
                representation).
        """
        super().__init__()
        # Node selection
        self._node_type_selector = GenericMLP(**params['node_type_selector'])
        self._node_type_loss_weights = params['node_type_loss_weights']

        # Edge selection
        # torch.FloatTensor(*shape) returns uninitialised memory (possibly
        # NaN/inf); draw from N(0, 1) instead, as torch.nn.Embedding does.
        self._no_more_edges_representation = torch.nn.Parameter(
            torch.randn(*params['no_more_edges_repr']), requires_grad=True
        )
        self._edge_candidate_scorer = GenericMLP(**params['edge_candidate_scorer'])
        self._edge_type_selector = GenericMLP(**params['edge_type_selector'])

        # Attachment Point Selection (TODO)

    def pick_node_type(
        self,
        input_molecule_representations,
        graph_representations,
        graphs_requiring_node_choices
    ):
        """Compute node-type logits for the graphs that need a node decision.

        Args:
            input_molecule_representations: per-graph representation of the
                input molecule (the notebook passes the latent sample ``z``).
            graph_representations: per-graph representation of the partial
                graph.
            graphs_requiring_node_choices: index selecting the rows (graphs)
                for which a node type has to be picked.

        Returns:
            Output of the node-type MLP for the selected graphs.
        """
        relevant_input_molecule_representations = input_molecule_representations[graphs_requiring_node_choices]
        relevant_graph_representations = graph_representations[graphs_requiring_node_choices]
        # Input-molecule representation first, partial-graph representation second.
        original_and_calculated_graph_representations = torch.cat(
            (relevant_input_molecule_representations, relevant_graph_representations),
            dim=-1,
        )
        node_type_logits = self._node_type_selector(original_and_calculated_graph_representations)
        return node_type_logits

    def compute_node_type_selection_loss(
        self,
        node_type_logits,
        node_type_multihot_labels
    ):
        """Compute the (optionally class-weighted) node-type selection loss.

        Args:
            node_type_logits: logits of shape [NTP, NT + 1]; the last column
                is the "no further node" choice.
            node_type_multihot_labels: multi-hot labels of shape [NTP, NT]; an
                all-zero row means "no further node" is the correct decision.

        Returns:
            Summed negative log-probability of the correct decisions, divided
            by the number of decisions via ``safe_divide_loss``.
        """
        per_node_decision_logprobs = torch.nn.functional.log_softmax(node_type_logits, dim=-1)
        # Shape: [NTP, NT + 1]

        # number of correct choices for each of the partial graphs that require node choices
        per_node_decision_num_correct_choices = torch.sum(node_type_multihot_labels, dim=-1, keepdim=True)
        # Shape [NTP, 1]

        per_correct_node_decision_normalised_neglogprob = compute_neglogprob_for_multihot_objective(
            logprobs = per_node_decision_logprobs[:, :-1], # separate out the no node prediction
            multihot_labels = node_type_multihot_labels,
            per_decision_num_correct_choices = per_node_decision_num_correct_choices,
        ) # Shape [NTP, NT]

        # Per-decision mask: True where no node type is correct. The mask must
        # stay per-row; summing it into one count would scale every row's
        # "no node" log-probability by that count.
        no_node_decision_correct = (per_node_decision_num_correct_choices == 0.0).squeeze(-1)  # Shape [NTP]
        per_correct_no_node_decision_neglogprob = -(
            per_node_decision_logprobs[:, -1]
            # .to(dtype) keeps the mask on the logits' device
            * no_node_decision_correct.to(per_node_decision_logprobs.dtype)
        )  # Shape [NTP]

        if self._node_type_loss_weights is not None:
            # Align dtype/device with the logits so the weights also work on GPU.
            class_weights = torch.as_tensor(
                self._node_type_loss_weights,
                dtype=per_node_decision_logprobs.dtype,
                device=per_node_decision_logprobs.device,
            )
            per_correct_node_decision_normalised_neglogprob = (
                per_correct_node_decision_normalised_neglogprob * class_weights[:-1]
            )
            per_correct_no_node_decision_neglogprob = (
                per_correct_no_node_decision_neglogprob * class_weights[-1]
            )

        # Loss is the sum of the masked (no) node decisions, averaged over number of decisions made:
        total_node_type_loss = torch.sum(
            per_correct_node_decision_normalised_neglogprob
        ) + torch.sum(per_correct_no_node_decision_neglogprob)
        node_type_loss = safe_divide_loss(
            total_node_type_loss, node_type_multihot_labels.shape[0]
        )

        return node_type_loss

    @staticmethod
    def _segment_log_softmax(logits, segment_ids, num_segments):
        """Log-softmax over the entries of ``logits`` that share a segment id.

        Args:
            logits: 1-D tensor of scores.
            segment_ids: 1-D integer tensor of the same length, giving the
                segment (graph) each score belongs to.
            num_segments: total number of segments.

        Returns:
            1-D tensor of log-probabilities, normalised within each segment.
        """
        segment_ids = segment_ids.long()
        # Per-segment maximum, used only as a numerical-stability shift. It is
        # detached because shifting by a constant does not change the result.
        segment_max = torch.full(
            (num_segments,), float("-inf"), dtype=logits.dtype, device=logits.device
        ).scatter_reduce(0, segment_ids, logits.detach(), reduce="amax", include_self=True)
        shifted_logits = logits - segment_max[segment_ids]
        segment_sum_exp = torch.zeros(
            num_segments, dtype=logits.dtype, device=logits.device
        ).index_add(0, segment_ids, shifted_logits.exp())
        return shifted_logits - torch.log(segment_sum_exp[segment_ids])

    def compute_edge_candidate_selection_loss(
        self,
        num_graphs_in_batch, # number of partial graphs in the batch (int)
        node_to_graph_map, #batch.batch
        candidate_edge_targets, # batch_features["valid_edge_choices"][:, 1]
        edge_candidate_logits, # as is
        per_graph_num_correct_edge_choices, # batch.num_correct_edge_choices
        edge_candidate_correctness_labels, # correct edge choices
        no_edge_selected_labels # stop node label
    ):
        """Compute the edge-candidate selection loss with a multi-hot objective.

        Args:
            num_graphs_in_batch: number of partial graphs (PG) in the batch.
            node_to_graph_map: for each node, the index of its graph.
            candidate_edge_targets: target node index of each candidate edge,
                shape [CE].
            edge_candidate_logits: 1-D logits of the CE candidate edges
                followed by one "no edge" logit per graph, shape [CE + PG].
            per_graph_num_correct_edge_choices: number of correct candidate
                edges per graph, shape [PG]; may be zero.
            edge_candidate_correctness_labels: 0/1 label per candidate edge,
                shape [CE].
            no_edge_selected_labels: per-graph label saying whether "no edge"
                is the correct choice, shape [PG].

        Returns:
            Summed negative log-probability of the correct choices, divided by
            the number of graphs via ``safe_divide_loss``.
        """
        # First, we construct full labels for all edge decisions, which are the concat of
        # edge candidate logits and the logits for choosing no edge:
        edge_correctness_labels = torch.cat(
            [edge_candidate_correctness_labels, no_edge_selected_labels.float()],
            dim=0,
        )  # Shape: [CE + PG]

        # To compute a softmax over all candidate edges (and the "no edge" choice) corresponding
        # to the same graph, we first need to build the map from each logit to the corresponding
        # graph id. Then, we can do a segment-wise log-softmax using that map:
        edge_candidate_to_graph_map = node_to_graph_map[candidate_edge_targets.long()]
        # add the end bond labels to the end
        no_edge_choice_to_graph_map = torch.arange(
            num_graphs_in_batch,
            dtype=edge_candidate_to_graph_map.dtype,
            device=edge_candidate_to_graph_map.device,
        )
        edge_candidate_to_graph_map = torch.cat(
            (edge_candidate_to_graph_map, no_edge_choice_to_graph_map)
        )

        edge_candidate_logprobs = self._segment_log_softmax(
            logits=edge_candidate_logits,
            segment_ids=edge_candidate_to_graph_map,
            num_segments=num_graphs_in_batch,
        )  # Shape: [CE + PG]

        # Compute the edge loss with the multihot objective.
        # For a single graph with three valid choices (+ stop node) of which two are correct,
        # we may have the following:
        #  edge_candidate_logprobs = log([0.05, 0.5, 0.4, 0.05])
        #  per_graph_num_correct_edge_choices = [2]
        #  edge_candidate_correctness_labels = [0.0, 1.0, 1.0]
        #  edge_correctness_labels = [0.0, 1.0, 1.0, 0.0]
        # To get the loss, we simply look at the things in edge_candidate_logprobs that correspond
        # to correct entries.
        # However, to account for the _multi_hot nature, we scale up each entry of
        # edge_candidate_logprobs by the number of correct choices, i.e., consider the
        # correct entries of
        #  log([0.05, 0.5, 0.4, 0.05]) + log([2, 2, 2, 2]) = log([0.1, 1.0, 0.8, 0.1])
        # In this form, we want to have each correct entry to be as near possible to 1.
        # Finally, we normalise loss contributions to by-graph, by dividing the crossentropy
        # loss by the number of correct choices (i.e., in the example above, this results in
        # a loss of -((log(1.0) + log(0.8)) / 2) = 0.11...).

        # Note: per_graph_num_correct_edge_choices does not include the choice of an edge to
        # the stop node, so can be zero. Clamp to at least 1 (device-agnostic).
        per_graph_num_correct_edge_choices = torch.clamp(
            per_graph_num_correct_edge_choices, min=1
        ).to(edge_candidate_logprobs.dtype)  # Shape: [PG]

        per_edge_candidate_num_correct_choices = per_graph_num_correct_edge_choices[edge_candidate_to_graph_map]
        # Shape: [CE + PG]
        per_correct_edge_neglogprob = -(
            (edge_candidate_logprobs + torch.log(per_edge_candidate_num_correct_choices))
            * edge_correctness_labels
            / per_edge_candidate_num_correct_choices
        )  # Shape: [CE + PG]

        # Normalise by number of graphs for which we made edge selection decisions:
        edge_loss = safe_divide_loss(
            torch.sum(per_correct_edge_neglogprob), num_graphs_in_batch
        )

        return edge_loss

    def compute_decoder_loss(
        self,
        node_type_logits,
        node_type_multihot_labels
    ):
        """Combine the individual decoder losses (node-type loss only, so far)."""
        # Compute node selection loss
        node_selection_loss = self.compute_node_type_selection_loss(
            node_type_logits,
            node_type_multihot_labels
        )

        # Compute edge selection loss (TODO)

        # Compute attachment point selection loss (TODO)

        # Weighted sum of the losses and return it for backpropagation in
        # the lightning module
        return node_selection_loss

    def forward(
        self,
        input_molecule_representations,
        graph_representations,
        graphs_requiring_node_choices
    ):
        """Compute the decoder logits (node-type logits only, so far)."""
        # Compute node logits
        node_logits = self.pick_node_type(
            input_molecule_representations,
            graph_representations,
            graphs_requiring_node_choices
        )

        # Compute edge logits (TODO)

        # Compute attachment point logits (TODO)

        # return all logits
        return node_logits

## PickAtomOrMotif

In [186]:
from molecule_generation.utils.training_utils import get_class_balancing_weights


# Per-class loss weights for next-node-type prediction, derived from how often
# each node type (plus the "None" / stop choice) occurs in the training traces.
next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
class_weight_factor = params.get("node_type_predictor_class_loss_weight_factor", 1.0)

if not (0.0 <= class_weight_factor <= 1.0):
    raise ValueError(
        f"Node class loss weight node_type_predictor_class_loss_weight_factor must be in [0,1], but is {class_weight_factor}!"
    )
if class_weight_factor > 0:
    atom_type_nums = [
        next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
        for type_idx in range(dataset.num_node_types)
    ]
    # The "no further node" class goes last, matching the logits layout [NT + 1].
    atom_type_nums.append(next_node_type_distribution["None"])

    class_weights = get_class_balancing_weights(
        class_counts=atom_type_nums, class_weight_factor=class_weight_factor
    )
else:
    class_weights = None

# torch.tensor(None) raises, so only convert when weights were actually
# computed; the decoder treats None as "no class weighting".
params['node_type_loss_weights'] = (
    None if class_weights is None else torch.tensor(class_weights)
)

In [187]:
from model_utils import GenericMLP
# The node-type MLP takes the latent sample concatenated with the partial-graph
# representation; the extra output slot (+ 1) is the "no further node" class.
params['node_type_selector'] = {
    'input_feature_dim':  z.shape[-1] + partial_graph_representions.shape[-1], 
    'output_size': dataset.num_node_types + 1
}


# Ids of the partial graphs in this batch that have node-type labels attached
graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique()

# NOTE(review): this direct construction uses output_size = num_node_types,
# i.e. without the + 1 used in `params` above. It is overwritten by the next
# statement, so the mismatch has no effect on later cells.
node_type_selector = GenericMLP(
    input_feature_dim = z.shape[-1] + partial_graph_representions.shape[-1],
    output_size = dataset.num_node_types,
)

node_type_selector = GenericMLP(**params['node_type_selector'])

### node loss computation in the forward method

In [188]:
# NOTE(review): MLPDecoder.__init__ also reads params['no_more_edges_repr'],
# params['edge_candidate_scorer'] and params['edge_type_selector'], which no
# cell shown here sets; make sure they are defined before running this.
decoder = MLPDecoder(params)

In [189]:
# The latent sample `z` is passed as the input-molecule representation and the
# partial-graph encoder output as the graph representation; only the graphs
# that have node-type labels are scored.
node_logits = decoder.pick_node_type(
    z,
    partial_graph_representions,
    graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique()
)

In [190]:
# Cut the flat `correct_node_type_choices` tensor back into one label row per
# graph using the `_ptr` offsets, skipping graphs that contribute no labels.
choices_ptr = batch.correct_node_type_choices_ptr
per_graph_label_rows = []
for start_idx, end_idx in zip(choices_ptr[:-1], choices_ptr[1:]):
    if end_idx - start_idx == 0:
        # this graph has no node-type decision attached
        continue
    per_graph_label_rows.append(batch.correct_node_type_choices[start_idx:end_idx])

# One row per graph that requires a node choice
node_type_multihot_labels = torch.stack(per_graph_label_rows, dim=0)

In [191]:
# Score the logits against the per-graph multi-hot labels built above.
node_type_selection_loss = decoder.compute_node_type_selection_loss(
    node_logits,
    node_type_multihot_labels
)

tensor(0.4860, grad_fn=<DivBackward0>)

# PickEdge

In [None]:
def pick_edge(
    input_molecule_representations,
    graph_representations,
    node_representations,
    graph_to_focus_node_map,
    node_to_graph_map,
    candidate_edge_targets,
    candidate_edge_features,
    graphs_requiring_node_choices,
    partial_graph_start_node_idx, # batch.ptr[:-1]
    focus_node_idx_within_each_partial_graph # batch.focus_node
):
    """Placeholder for the edge-selection step of the decoder.

    Only the signature has been drafted so far. Per the inline notes,
    ``partial_graph_start_node_idx`` is meant to be ``batch.ptr[:-1]`` and
    ``focus_node_idx_within_each_partial_graph`` is meant to be
    ``batch.focus_node``.

    Raises:
        NotImplementedError: always, until the body is written.
    """
    # A def without a body is a SyntaxError; fail explicitly instead so the
    # file stays importable while this is being worked on.
    raise NotImplementedError("pick_edge is not implemented yet")

In [202]:
batch.ptr[:-1] + batch.focus_node 

tensor([  0,   8,  17,  27,  37,  47,  58,  69,  81,  93, 106, 124, 134, 153,
        172, 192])

In [198]:
batch.focus_node.shape

torch.Size([16])

In [204]:
batch.correct_edge_choices

tensor([0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])

In [205]:
batch.correct_edge_choices != 0

tensor([False,  True, False, False, False, False, False,  True, False, False,
        False,  True, False, False, False, False,  True, False, False, False,
        False, False, False, False, False,  True, False, False, False, False,
        False, False, False, False, False, False, False,  True, False, False,
        False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True])

# LightningModule + Vae MLP

1. Implement the KL divergence loss as part of the lightning module

In [None]:
from molecule_generation.utils.training_utils import get_class_balancing_weights
from pytorch_lightning import LightningModule, Trainer, seed_everything



class BaseModel(LightningModule):
    """Work-in-progress lightning module tying encoder, latent sampling and decoder together.

    NOTE(review): ``Embedding``, ``FullGraphEncoder``, ``PartialGraphEncoder``
    and ``MoLeROutput`` are referenced below but are neither defined nor
    imported in the visible part of this file.
    """

    def __init__(self, params, dataset):
        """Params is a nested dictionary with the relevant parameters."""
        super().__init__()
        self._init_params(params, dataset)
        self._motif_aware_embedding_layer = Embedding(params['motif_aware_embedding'])
        self._encoder_gnn = FullGraphEncoder(params['encoder_gnn'])
        self._decoder_gnn = PartialGraphEncoder(params['decoder_gnn'])
        self._decoder_mlp = MLPDecoder(params['decoder_mlp'])

        # params for latent space
        self._latent_sample_strategy = params['latent_sample_strategy']
        self._latent_repr_dim = params["latent_repr_size"]

    def _init_params(self, params, dataset):
        """
        Initialise class weights for next node prediction and placeholder for
        motif/node embeddings.
        """
        # The lookups below read self._params, so it has to be stored first.
        self._params = params

        # Get some information out from the dataset:
        next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
        class_weight_factor = self._params.get("node_type_predictor_class_loss_weight_factor", 0.0)

        if not (0.0 <= class_weight_factor <= 1.0):
            raise ValueError(
                f"Node class loss weight node_type_predictor_class_loss_weight_factor must be in [0,1], but is {class_weight_factor}!"
            )
        if class_weight_factor > 0:
            atom_type_nums = [
                next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
                for type_idx in range(dataset.num_node_types)
            ]
            # The "None" (no further node) class goes last.
            atom_type_nums.append(next_node_type_distribution["None"])

            self.class_weights = get_class_balancing_weights(
                class_counts=atom_type_nums, class_weight_factor=class_weight_factor
            )
        else:
            self.class_weights = None

        motif_vocabulary = dataset.metadata.get("motif_vocabulary")
        self._uses_motifs = motif_vocabulary is not None

        self._node_categorical_num_classes = dataset.node_categorical_num_classes

        if self.uses_categorical_features:
            if "categorical_features_embedding_dim" in self._params:
                # Placeholder; the embedding itself is not created yet.
                self._node_categorical_features_embedding = None

    @property
    def uses_motifs(self):
        """Whether the dataset metadata contains a motif vocabulary."""
        return self._uses_motifs

    @property
    def uses_categorical_features(self):
        """Whether the dataset reports a number of categorical node classes."""
        return self._node_categorical_num_classes is not None

    @property
    def decoder(self):
        """The decoder MLP module created in ``__init__``."""
        # __init__ stores the decoder as _decoder_mlp; no `_decoder` attribute exists.
        return self._decoder_mlp

    @property
    def encoder(self):
        """The encoder GNN module created in ``__init__``."""
        # __init__ stores the encoder as _encoder_gnn; no `_encoder` attribute exists.
        return self._encoder_gnn

    @property
    def motif_aware_embedding_layer(self):
        """The embedding layer created in ``__init__``."""
        return self._motif_aware_embedding_layer

    @property
    def latent_dim(self):
        """Dimensionality of the latent representation."""
        return self._latent_repr_dim

    def compute_initial_node_features(self, batch):
        """Compute the initial node features for ``batch``.

        TODO: compute the embedding. Until then the batch is passed through
        unchanged, because ``forward`` rebinds ``batch`` to the return value.
        """
        return batch

    def sample_from_latent_repr(self, latent_repr):
        """Split ``latent_repr`` into mean and log-variance and sample from it.

        Returns:
            Tuple ``(p, q, z)``: prior, approximate posterior and a sample.
        """
        # perturb latent repr
        mu = latent_repr[:, : self.latent_dim]  # Shape: [V, MD]
        log_var = latent_repr[:, self.latent_dim :]  # Shape: [V, MD]

        # result_representations: shape [num_partial_graphs, latent_repr_dim]
        p, q, z = self.sample(mu, log_var)

        return p, q, z

    def sample(self, mu, log_var):
        """Samples a different noise vector for each partial graph.
        TODO: look into the other sampling strategies."""
        std = torch.exp(log_var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return p, q, z

    def forward(self, x, edge_index, edge_attr, batch, **kwargs):
        """Run encoder, latent sampling and decoder on a batch.

        NOTE(review): the signature is provisional. The original had a ``??``
        placeholder after ``batch``; it is replaced by ``**kwargs`` so the
        class parses. Only ``batch`` is used so far.
        """
        # Obtain node embeddings
        batch = self.compute_initial_node_features(batch)

        # Forward pass through encoder
        latent_repr = self.encoder(batch)

        # Apply latent sampling strategy
        p, q, latent_repr = self.sample_from_latent_repr(latent_repr)

        # Forward pass through decoder
        # NOTE(review): the MLPDecoder defined earlier in this file takes three
        # arguments and returns only node-type logits, so this call and the
        # four-way unpacking need revisiting once the decoder is complete.
        node_type_logits, edge_candidate_logits, edge_type_logits, attachment_point_selection_logits = self.decoder(latent_repr)

        # NOTE: loss computation will be done in lightning module
        return MoLeROutput(
            node_type_logits = node_type_logits,
            edge_candidate_logits = edge_candidate_logits,
            edge_type_logits = edge_type_logits,
            attachment_point_selection_logits = attachment_point_selection_logits,
            p = p,
            q = q,
        )