# Design 

1. All model-related loss computations and helper functions live in `torch.nn.Module` subclasses, and the loss methods will all be called `compute_loss`.
2. The MLPDecoder will house all the MLPs involved in the decoder and will have one `compute_decoder_loss` method that calls all the individual `compute_loss` methods.

It will also have a `decode` method that performs decoding at inference time (TODO).

3. The FullGraphEncoder and the PartialGraphEncoder will each be in their own torch modules

4. Finally, the lightning module will have 3 things: 
- FullGraphEncoder
- PartialGraphEncoder (part of decoder)
- MLPDecoder

After passing through the initial FullGraphEncoder, if we are working with a VAE, we will extract p and q to compute the KL divergence loss; otherwise we will do the other model-specific steps (e.g. diffusion).

`params` dictionary will be passed to the lightning module and each torch module will be constructed within it using the relevant parameters by destructuring the dictionary 

node type class weights will be instantiated in the lightning module and passed to the decoder

# TODO
1. fix the incrementing in the original graph edge index (DONE)
2. Work on first node prediction 
3. Investigate node_type_predictor_class_loss_weight_factor

In [10]:
# `params` houses all relevant model instantiation parameters.
# Later cells add one entry per component (e.g. params['full_graph_encoder']),
# so each module can be constructed with `Module(**params[<key>])`.
params = {}

In [1]:
%load_ext autoreload
%autoreload 2

from dataset import MolerDataset, MolerData
from utils import pprint_pyg_obj
from torch_geometric.loader import DataLoader


In [2]:
# Constructor arguments for the training split; kept in a dict so the
# locations can be inspected or tweaked before the dataset is built.
dataset_kwargs = dict(
    root='/data/ongh0068',
    raw_moler_trace_dataset_parent_folder='/data/ongh0068/l1000/trace_playground',
    output_pyg_trace_dataset_parent_folder='/data/ongh0068/l1000/pyg_output_playground',
    split='train',
)
dataset = MolerDataset(**dataset_kwargs)

2022-12-16 10:12:52.696562: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Attributes listed in `follow_batch` get companion `<name>_batch` vectors in
# every collated batch; later cells read e.g.
# `batch.correct_node_type_choices_batch` and `batch.original_graph_x_batch`.
attributes_to_follow = [
    'correct_edge_choices',
    'correct_edge_types',
    'valid_edge_choices',
    'valid_attachment_point_choices',
    'correct_attachment_point_choice',
    'correct_node_type_choices',
    'original_graph_x',
]
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
    follow_batch=attributes_to_follow,
)

In [4]:
# Grab a single batch to experiment with in the cells below. `next(iter(...))`
# states the intent directly and fails right here (StopIteration) if the
# loader is empty, instead of leaving `batch` undefined for later cells.
batch = next(iter(loader))

# FullGraphEncoder

In [5]:
from model_utils import GenericGraphEncoder
import torch

In [6]:
class GraphEncoder(torch.nn.Module):
    """Encode full molecules into graph-level representations.

    Each node's categorical type (atom or motif) is embedded, the embedding is
    appended to the node's feature vector, and the result is passed through a
    `GenericGraphEncoder`.
    """

    def __init__(
        self,
        input_feature_dim,
        atom_or_motif_vocab_size,
        motif_embedding_size = 64,
        hidden_layer_feature_dim=64,
        num_layers=12,
        layer_type="RGATConv",
        use_intermediate_gnn_results=True,
    ):
        """Create the type embedding and the underlying graph encoder.

        Args:
            input_feature_dim: size of the per-node feature vectors.
            atom_or_motif_vocab_size: number of rows in the type embedding.
            motif_embedding_size: size of each type embedding vector.
            hidden_layer_feature_dim, num_layers, layer_type,
            use_intermediate_gnn_results: accepted but currently unused.
        """
        super().__init__()
        # NOTE(review): hidden_layer_feature_dim, num_layers, layer_type and
        # use_intermediate_gnn_results are never forwarded to
        # GenericGraphEncoder, which therefore runs with its own defaults.
        # Confirm whether they should be passed through.
        self._embed = torch.nn.Embedding(atom_or_motif_vocab_size, motif_embedding_size)
        # The GNN consumes node features concatenated with the type embedding.
        gnn_input_dim = motif_embedding_size + input_feature_dim
        self._model = GenericGraphEncoder(input_feature_dim=gnn_input_dim)

    def forward(self, original_graph_node_categorical_features, node_features, edge_index, edge_type, batch_index):
        """Return the graph-level output of the underlying encoder.

        The categorical features are used as indices into the type embedding;
        `edge_index` is cast to int64 before it is handed to the encoder. The
        encoder's second return value is discarded.
        """
        type_embeddings = self._embed(original_graph_node_categorical_features)
        gnn_inputs = torch.cat([node_features, type_embeddings], dim=-1)
        graph_representations, _ = self._model(
            gnn_inputs,
            edge_index.long(),
            edge_type,
            batch_index,
        )
        return graph_representations

In [11]:
# Constructor arguments for the full-graph encoder, derived from the sample
# batch (node feature width) and the dataset (node-type vocabulary size).
params['full_graph_encoder'] = {
    'input_feature_dim': batch.x.shape[-1],
    'atom_or_motif_vocab_size': len(dataset.node_type_index_to_string),
}

# Build the encoder once, from `params` only. The previous cell also built a
# second, identical encoder with the same arguments spelled out by hand and
# immediately overwrote it.
full_graph_encoder = GraphEncoder(**params['full_graph_encoder'])

In [12]:
# Encode the original (complete) molecules of the batch. Node features are
# cast to float; `original_graph_x_batch` is the batch vector produced by
# listing 'original_graph_x' in the loader's `follow_batch`.
input_molecule_representations = full_graph_encoder(
    original_graph_node_categorical_features=batch.original_graph_node_categorical_features,
    node_features=batch.original_graph_x.float(),
    edge_index=batch.original_graph_edge_index,
    edge_type=batch.original_graph_edge_type,
    batch_index=batch.original_graph_x_batch,
)

# PartialGraphEncoder

In [13]:
# Constructor arguments for the partial-graph encoder (the GNN half of the
# decoder); the input width is the node feature width of the sample batch.
params['partial_graph_encoder'] = {
    'input_feature_dim': batch.x.shape[-1],
}

# Build the encoder once, from `params` only. The previous cell constructed an
# identical encoder by hand first and then discarded it.
partial_graph_encoder = GenericGraphEncoder(**params['partial_graph_encoder'])

In [14]:
# Encode the partial graphs: the encoder returns a pair, unpacked here as
# graph-level and node-level representations. The edge index is cast to int64.
# (The `representions` spelling is kept because later cells use that name.)
partial_graph_representions, node_representations = partial_graph_encoder(
    batch.x,
    batch.edge_index.long(),
    batch.edge_type,
    batch.batch,
)

In [15]:
# Sanity check: inspect the shape of the node-level encoder output.
node_representations.shape

torch.Size([193, 832])

# _mean_log_var_mlp

In [16]:
from model_utils import GenericMLP
latent_dim = 512

# The MLP emits the mean and the log-variance side by side, hence an output
# twice as wide as the latent space.
params['mean_log_var_mlp'] = dict(
    input_feature_dim=input_molecule_representations.shape[-1],
    output_size=2 * latent_dim,
)

mean_log_var_mlp = GenericMLP(**params['mean_log_var_mlp'])

In [17]:
# Project the graph-level encodings to 2 * latent_dim values per molecule;
# the next cell splits them into mean and log-variance halves.
mean_and_log_var = mean_log_var_mlp(input_molecule_representations)

In [18]:
# First `latent_dim` columns are the posterior mean, the rest its log-variance.
mu, log_var = mean_and_log_var[:, :latent_dim], mean_and_log_var[:, latent_dim:]

# sigma = exp(log(sigma^2) / 2)
std = (log_var / 2).exp()

# q is the approximate posterior, p the standard-normal prior of the same
# shape; both are kept so a KL term can be computed from them later.
q = torch.distributions.Normal(mu, std)
p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))

# Reparameterised sample, so gradients flow back into mu and std.
z = q.rsample()

In [19]:
# Sanity check: one latent vector of width `latent_dim` per row of
# `mean_and_log_var`.
z.shape

torch.Size([16, 512])

# Decoder

In [106]:
from decoder import MLPDecoder

## PickAtomOrMotif

In [44]:
from molecule_generation.utils.training_utils import get_class_balancing_weights


# Class-balancing weights for next-node-type prediction, computed from how
# often each node type (plus the extra "None" entry) occurs in the dataset's
# "train_next_node_type_distribution" metadata.
next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
class_weight_factor = params.get("node_type_predictor_class_loss_weight_factor", 1.0)

if not (0.0 <= class_weight_factor <= 1.0):
    # The message names the parameter that is actually read above.
    raise ValueError(
        f"Node class loss weight node_type_predictor_class_loss_weight_factor must be in [0,1], but is {class_weight_factor}!"
    )
if class_weight_factor > 0:
    atom_type_nums = [
        next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
        for type_idx in range(dataset.num_node_types)
    ]
    # The "None" entry goes last, after all real node types.
    atom_type_nums.append(next_node_type_distribution["None"])

    class_weights = get_class_balancing_weights(
        class_counts=atom_type_nums, class_weight_factor=class_weight_factor
    )
else:
    # A factor of 0 disables class weighting.
    class_weights = None

# `torch.tensor(None)` raises, so keep None when weighting is disabled.
# NOTE(review): confirm that MLPDecoder accepts None for this entry.
params['node_type_loss_weights'] = (
    None if class_weights is None else torch.tensor(class_weights)
)

In [45]:
from model_utils import GenericMLP
# The node-type selector sees the latent sample concatenated with the
# partial-graph representation.
params['node_type_selector'] = {
    'input_feature_dim': z.shape[-1] + partial_graph_representions.shape[-1],
    # One output per node type plus one extra, matching the class counts
    # (all node types followed by a "None" entry) used for the loss weights.
    'output_size': dataset.num_node_types + 1,
}

# Ids of the graphs in this batch that carry node-type labels.
graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique()

# Build the MLP once, from `params`. The previous cell first built a throwaway
# MLP with `output_size = dataset.num_node_types` (missing the extra class),
# which disagreed with `params` and was overwritten straight away.
node_type_selector = GenericMLP(**params['node_type_selector'])

### node loss computation in the forward method

In [91]:
# Build the decoder from the parameter dictionary assembled in the cells above.
decoder = MLPDecoder(params)

In [92]:
# Only graphs that carry node-type labels take part in node-type prediction.
graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique()

node_logits = decoder.pick_node_type(
    z,
    partial_graph_representions,
    graphs_requiring_node_choices=graphs_requiring_node_choices,
)

In [93]:
# `correct_node_type_choices_ptr` holds the slice boundaries of each graph's
# labels inside the flat `correct_node_type_choices` tensor. Walk consecutive
# (start, end) pairs, skip graphs with an empty slice, and stack the remaining
# label vectors into one tensor.
label_ptr = batch.correct_node_type_choices_ptr
node_type_multihot_labels = torch.stack(
    [
        batch.correct_node_type_choices[lo:hi]
        for lo, hi in zip(label_ptr[:-1], label_ptr[1:])
        if hi != lo
    ],
    dim=0,
)

In [94]:
# Node-type selection loss from the decoder's logits and the stacked
# per-graph label vectors built in the previous cell.
node_type_selection_loss = decoder.compute_node_type_selection_loss(
    node_logits,
    node_type_multihot_labels
)

# PickEdge

In [95]:

# Shape of the learned "no more edges" candidate: one row whose width is that
# of a node representation concatenated with an edge-feature vector.
params['no_more_edges_repr'] = (1, node_representations.shape[-1] + batch.edge_features.shape[-1])

# NOTE(review): 3331 is hard-coded. With the 832-wide node representations
# printed above it equals 4 * 832 + 3, presumably the concatenation of molecule,
# partial-graph, focus-node and target-node representations plus edge features.
# TODO: confirm against MLPDecoder.pick_edge and derive it from tensor shapes.
params['edge_candidate_scorer'] = {
    'input_feature_dim': 3331,
    'output_size': 1,
}

# NOTE(review): output_size 3 is presumably the number of edge types - confirm.
params['edge_type_selector'] = {
    'input_feature_dim': 3331,
    'output_size': 3,
}

# `torch.FloatTensor(*shape)` returns uninitialised memory, so the trainable
# parameter could start with arbitrary values (including inf/NaN). Allocate it
# and initialise it explicitly instead.
# NOTE(review): Xavier/Glorot-uniform is an assumed choice; change it if the
# decoder expects a different initial distribution.
_no_more_edges_representation = torch.nn.Parameter(
    torch.nn.init.xavier_uniform_(torch.empty(*params['no_more_edges_repr'])),
    requires_grad=True,
)
_edge_candidate_scorer = GenericMLP(**params['edge_candidate_scorer'])
_edge_type_selector = GenericMLP(**params['edge_type_selector'])

In [150]:
from decoder import MLPDecoder
decoder = MLPDecoder(params)

# Arguments shared by edge scoring and the edge-selection loss.
num_graphs_in_batch = len(batch.ptr) - 1
# Column 1 of each valid edge choice is used as the candidate target node.
candidate_edge_targets = batch.valid_edge_choices[:, 1].long()

# NOTE(review): this call passes the deterministic encoder output
# (`input_molecule_representations`), whereas node-type selection uses the
# sampled latent `z` - confirm this is intended.
edge_candidate_logits, edge_type_logits = decoder.pick_edge(
    input_molecule_representations,
    partial_graph_representions,
    node_representations,
    num_graphs_in_batch=num_graphs_in_batch,
    graph_to_focus_node_map=batch.focus_node,
    node_to_graph_map=batch.batch,
    candidate_edge_targets=candidate_edge_targets,
    candidate_edge_features=batch.edge_features,
)

# Left as the cell's final expression so the notebook displays the loss.
decoder.compute_edge_candidate_selection_loss(
    num_graphs_in_batch=num_graphs_in_batch,
    node_to_graph_map=batch.batch,
    candidate_edge_targets=candidate_edge_targets,
    edge_candidate_logits=edge_candidate_logits,
    per_graph_num_correct_edge_choices=batch.num_correct_edge_choices,
    edge_candidate_correctness_labels=batch.correct_edge_choices,
    no_edge_selected_labels=batch.stop_node_label,
)

tensor(1.5850, grad_fn=<DivBackward0>)

# LightningModule + Vae MLP

1. Implement KL divergence loss as part of the lightning module

In [None]:
from molecule_generation.utils.training_utils import get_class_balancing_weights
from pytorch_lightning import LightningModule, Trainer, seed_everything



class BaseModel(LightningModule):
    """Lightning module that wires together the encoder and decoder parts.

    Work-in-progress skeleton: it owns a motif-aware embedding layer, the
    encoder GNN, the decoder GNN and the decoder MLPs, and implements the VAE
    latent sampling step. Loss computation is meant to live in the Lightning
    hooks, not in `forward`.

    NOTE(review): `Embedding`, `FullGraphEncoder`, `PartialGraphEncoder` and
    `MoLeROutput` are neither defined nor imported in the visible part of this
    file; they must be provided before the class can be instantiated.
    """

    def __init__(self, params, dataset):
        """Build all sub-modules.

        Args:
            params: nested dictionary with the relevant parameters. Keys read
                here: 'motif_aware_embedding', 'encoder_gnn', 'decoder_gnn',
                'decoder_mlp', 'latent_sample_strategy', 'latent_repr_size'.
            dataset: dataset whose metadata is used in `_init_params`.
        """
        super().__init__()
        self._init_params(params, dataset)
        self._motif_aware_embedding_layer = Embedding(params['motif_aware_embedding'])
        self._encoder_gnn = FullGraphEncoder(params['encoder_gnn'])
        self._decoder_gnn = PartialGraphEncoder(params['decoder_gnn'])
        self._decoder_mlp = MLPDecoder(params['decoder_mlp'])

        # params for latent space
        self._latent_sample_strategy = params['latent_sample_strategy']
        self._latent_repr_dim = params["latent_repr_size"]

    def _init_params(self, params, dataset):
        """Initialise class weights for next node prediction and the
        placeholder for motif/node embeddings.

        Raises:
            ValueError: if 'node_type_predictor_class_loss_weight_factor' is
                outside [0, 1].
        """
        # Store the dictionary first: everything below reads `self._params`,
        # which was previously never assigned.
        self._params = params

        # Get some information out from the dataset:
        next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
        class_weight_factor = self._params.get("node_type_predictor_class_loss_weight_factor", 0.0)

        if not (0.0 <= class_weight_factor <= 1.0):
            raise ValueError(
                f"Node class loss weight node_type_predictor_class_loss_weight_factor must be in [0,1], but is {class_weight_factor}!"
            )
        if class_weight_factor > 0:
            atom_type_nums = [
                next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
                for type_idx in range(dataset.num_node_types)
            ]
            # The "None" entry goes last, after all real node types.
            atom_type_nums.append(next_node_type_distribution["None"])

            self.class_weights = get_class_balancing_weights(
                class_counts=atom_type_nums, class_weight_factor=class_weight_factor
            )
        else:
            self.class_weights = None

        motif_vocabulary = dataset.metadata.get("motif_vocabulary")
        self._uses_motifs = motif_vocabulary is not None

        self._node_categorical_num_classes = dataset.node_categorical_num_classes

        # Always define the attribute so later reads cannot raise
        # AttributeError.
        # TODO: build the actual embedding when `uses_categorical_features`
        # holds and "categorical_features_embedding_dim" is in the params.
        self._node_categorical_features_embedding = None

    @property
    def uses_motifs(self):
        """True when the dataset metadata contains a motif vocabulary."""
        return self._uses_motifs

    @property
    def uses_categorical_features(self):
        """True when the dataset reports a number of categorical node classes."""
        return self._node_categorical_num_classes is not None

    @property
    def decoder(self):
        """The decoder MLP stack (previously read an attribute that was never set)."""
        return self._decoder_mlp

    @property
    def encoder(self):
        """The full-graph encoder GNN (previously read an attribute that was never set)."""
        return self._encoder_gnn

    @property
    def motif_aware_embedding_layer(self):
        """The motif-aware embedding layer."""
        return self._motif_aware_embedding_layer

    @property
    def latent_dim(self):
        """Width of the latent space, taken from params['latent_repr_size']."""
        return self._latent_repr_dim

    def compute_initial_node_features(self, batch):
        """Compute the initial node embeddings for `batch` (not implemented yet)."""
        # Fail loudly instead of returning None, which `forward` would then
        # pass on to the encoder.
        raise NotImplementedError("compute_initial_node_features is not implemented yet")

    def sample_from_latent_repr(self, latent_repr):
        """Split `latent_repr` into mean / log-variance halves and sample.

        Returns:
            Tuple (p, q, z) as produced by `sample`.
        """
        # First `latent_dim` columns are the mean, the remaining ones the
        # log-variance.
        mu = latent_repr[:, : self.latent_dim]
        log_var = latent_repr[:, self.latent_dim :]

        p, q, z = self.sample(mu, log_var)

        return p, q, z

    def sample(self, mu, log_var):
        """Samples a different noise vector for each partial graph.
        TODO: look into the other sampling strategies.

        Returns:
            Tuple (p, q, z): standard-normal prior, posterior Normal(mu, std)
            and a reparameterised sample drawn from q.
        """
        std = torch.exp(log_var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return p, q, z

    def forward(self, x, edge_index, edge_attr, batch):
        """Run encoder, latent sampling and decoder for one batch.

        NOTE(review): `x`, `edge_index` and `edge_attr` are not used yet; the
        original signature ended in a `??` placeholder, so the final argument
        list is still to be decided.
        """
        # Obtain node embeddings
        batch = self.compute_initial_node_features(batch)

        # Forward pass through encoder
        latent_repr = self.encoder(batch)

        # Apply latent sampling strategy
        p, q, latent_repr = self.sample_from_latent_repr(latent_repr)

        # Forward pass through decoder
        node_type_logits, edge_candidate_logits, edge_type_logits, attachment_point_selection_logits = self.decoder(latent_repr)

        # NOTE: loss computation will be done in lightning module
        return MoLeROutput(
            node_type_logits=node_type_logits,
            edge_candidate_logits=edge_candidate_logits,
            edge_type_logits=edge_type_logits,
            attachment_point_selection_logits=attachment_point_selection_logits,
            p=p,
            q=q,
        )