# Design 

1. All model-related loss computations and helper functions live in `torch.nn.Module` subclasses, and each loss method follows the `compute_*_loss` naming convention.
2. The MLPDecoder will house all the MLPs involved in the decoder and will have a single `compute_decoder_loss` method that calls all the individual loss methods.

It will also have a `decode` method that performs decoding at inference time (TODO).

3. The FullGraphEncoder and the PartialGraphEncoder will each be in their own torch modules

4. Finally, the lightning module will have 3 things: 
- FullGraphEncoder
- PartialGraphEncoder (part of decoder)
- MLPDecoder

After the input passes through the FullGraphEncoder: if we are working with a VAE, we extract the distributions p and q to compute the KL divergence loss; otherwise we do the other model-specific processing (e.g. diffusion).

A `params` dictionary will be passed to the lightning module, and each torch module will be constructed inside it from its own sub-dictionary of parameters by unpacking that sub-dictionary (e.g. `Module(**params['module_name'])`).

Node type class weights will be instantiated in the lightning module and passed to the decoder.

# TODO
1. fix the incrementing in the original graph edge index (DONE)
2. Work on first node prediction 
3. Investigate node_type_predictor_class_loss_weight_factor

In [1]:
# params houses all relevant model instantiation parameters
params = {}

In [3]:
%load_ext autoreload
%autoreload 2

from dataset import MolerDataset, MolerData
from utils import pprint_pyg_obj
from torch_geometric.loader import DataLoader


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [41]:
# Load the `train_0` split. The paths below are specific to the machine this
# notebook was written on and must be adjusted when running elsewhere.
dataset = MolerDataset(
    root="/data/ongh0068",
    raw_moler_trace_dataset_parent_folder="/data/ongh0068/guacamol/trace_dir",
    output_pyg_trace_dataset_parent_folder="/data/ongh0068/l1000/already_batched",
    split="train_0",
)

In [42]:
# `follow_batch` asks the PyG loader to also emit a `<name>_batch` assignment
# vector for each listed attribute; later cells read e.g.
# `batch.correct_node_type_choices_batch` and `batch.original_graph_x_batch`.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    follow_batch=[
        "correct_edge_choices",
        "correct_edge_types",
        "valid_edge_choices",
        "valid_attachment_point_choices",
        "correct_attachment_point_choice",
        "correct_node_type_choices",
        "original_graph_x",
        "correct_first_node_type_choices",
    ],
)

In [43]:
# Grab only the first batch; it is the sample used by the rest of the notebook.
batch = next(iter(loader))

# FullGraphEncoder

In [None]:
from model_utils import GenericGraphEncoder
import torch

In [None]:
class GraphEncoder(torch.nn.Module):
    """Produce graph-level representations of the input molecules.

    The categorical feature ids are looked up in a learned embedding table and
    the resulting vectors are concatenated onto the dense node features; the
    combined features are then passed to a ``GenericGraphEncoder``.

    NOTE(review): ``hidden_layer_feature_dim``, ``num_layers``, ``layer_type``
    and ``use_intermediate_gnn_results`` are accepted by ``__init__`` but are
    never forwarded to ``GenericGraphEncoder``, so they currently have no effect.
    """

    def __init__(
        self,
        input_feature_dim,
        atom_or_motif_vocab_size,
        motif_embedding_size=64,
        hidden_layer_feature_dim=64,
        num_layers=12,
        layer_type="RGATConv",
        use_intermediate_gnn_results=True,
    ):
        super().__init__()
        # The wrapped encoder sees the dense features plus the embedding vector.
        encoder_input_dim = motif_embedding_size + input_feature_dim
        self._embed = torch.nn.Embedding(atom_or_motif_vocab_size, motif_embedding_size)
        self._model = GenericGraphEncoder(input_feature_dim=encoder_input_dim)

    def forward(self, original_graph_node_categorical_features, node_features, edge_index, edge_type, batch_index):
        """Embed, concatenate and encode; return only the graph-level output.

        ``edge_index`` is cast to ``long`` before use. The wrapped encoder
        returns a pair, of which the second element is discarded here.
        """
        embedded = self._embed(original_graph_node_categorical_features)
        combined = torch.cat([node_features, embedded], dim=-1)
        graph_level, _ = self._model(combined, edge_index.long(), edge_type, batch_index)
        return graph_level

In [None]:
# Keep the encoder's constructor arguments in `params` so the exact same
# configuration can be handed to the lightning module later on.
params['full_graph_encoder'] = {
    'input_feature_dim': batch.x.shape[-1],
    'atom_or_motif_vocab_size': len(dataset.node_type_index_to_string),
}

# Build the encoder once, from `params` only. The arguments used to be spelled
# out a second time in a throwaway construction that was overwritten at once.
full_graph_encoder = GraphEncoder(**params['full_graph_encoder'])

In [None]:
input_molecule_representations = full_graph_encoder(
    batch.original_graph_node_categorical_features, 
    batch.original_graph_x.float(),
    batch.original_graph_edge_index,
    batch.original_graph_edge_type,
    batch_index = batch.original_graph_x_batch,
)

# PartialGraphEncoder

In [44]:
from encoder import PartialGraphEncoder
from model_utils import get_params
# All module constructor arguments come from one dataset-derived dictionary;
# the partial graph encoder unpacks its own sub-dictionary.
params = get_params(dataset)

partial_graph_encoder = PartialGraphEncoder(**params["partial_graph_encoder"])

In [45]:
params

{'full_graph_encoder': {'input_feature_dim': 59,
  'atom_or_motif_vocab_size': 166},
 'partial_graph_encoder': {'input_feature_dim': 59,
  'atom_or_motif_vocab_size': 166},
 'mean_log_var_mlp': {'input_feature_dim': 832, 'output_size': 1024},
 'decoder': {'node_type_selector': {'input_feature_dim': 1344,
   'output_size': 167},
  'node_type_loss_weights': tensor([10.0000,  0.1000,  3.6015,  0.1000,  0.1000,  0.4439,  0.7549,  0.4416,
          10.0000,  2.7939,  3.3916, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.000

In [47]:
dataset._node_type_index_to_string

{0: 'UNK',
 1: 'C',
 2: 'Br',
 3: 'N',
 4: 'O',
 5: 'S',
 6: 'Cl',
 7: 'F',
 8: 'P',
 9: 'N+',
 10: 'O-',
 11: 'I',
 12: 'N-',
 13: 'S+',
 14: 'B',
 15: 'Si',
 16: 'Cl+',
 17: 'C-',
 18: 'P+',
 19: 'B-',
 20: 'Se',
 21: 'O+',
 22: 'S-',
 23: 'I+',
 24: 'C+',
 25: 'F+',
 26: 'Se+',
 27: 'P-',
 28: 'I++',
 29: 'F-',
 30: 'Si-',
 31: 'Cl-',
 32: 'Se-',
 33: 'Cl+++',
 34: 'I+++',
 35: 'Br-',
 36: 'Cl++',
 37: 'Br++',
 38: 'C1=CC=CC=C1',
 39: 'C1=CC=NC=C1',
 40: 'NC=O',
 41: 'C1CCNCC1',
 42: 'C1CNCCN1',
 43: 'FC(F)F',
 44: 'C1CCNC1',
 45: 'O=CO',
 46: 'C1=CNN=C1',
 47: 'O=[N+][O-]',
 48: 'CCC',
 49: 'C1=CN=CN=C1',
 50: 'C1CCCCC1',
 51: 'C1COCCN1',
 52: 'C1=CSC=C1',
 53: 'C1=CC=C2NC=CC2=C1',
 54: 'CCO',
 55: 'C1=COC=C1',
 56: 'C1CC1',
 57: 'O=S=O',
 58: 'C1=CSC=N1',
 59: 'CNC=O',
 60: 'CC=O',
 61: 'CC(N)=O',
 62: 'N[SH](=O)=O',
 63: 'C1=CNC=N1',
 64: 'COC=O',
 65: 'C1=CC=C2C=CC=CC2=C1',
 66: 'CC(=O)O',
 67: 'C1=CC=C2NC=NC2=C1',
 68: 'C1CCOC1',
 69: 'CCOC=O',
 70: 'CC(C)C',
 71: 'CNC',
 72: '

In [48]:
atom_type_featuriser.index_to_atom_type_map

dataset.metadata.get("motif_vocabulary")

MotifVocabulary(vocabulary={'C1=CC=CC=C1': 0, 'C1=CC=NC=C1': 1, 'NC=O': 2, 'C1CCNCC1': 3, 'C1CNCCN1': 4, 'FC(F)F': 5, 'C1CCNC1': 6, 'O=CO': 7, 'C1=CNN=C1': 8, 'O=[N+][O-]': 9, 'CCC': 10, 'C1=CN=CN=C1': 11, 'C1CCCCC1': 12, 'C1COCCN1': 13, 'C1=CSC=C1': 14, 'C1=CC=C2NC=CC2=C1': 15, 'CCO': 16, 'C1=COC=C1': 17, 'C1CC1': 18, 'O=S=O': 19, 'C1=CSC=N1': 20, 'CNC=O': 21, 'CC=O': 22, 'CC(N)=O': 23, 'N[SH](=O)=O': 24, 'C1=CNC=N1': 25, 'COC=O': 26, 'C1=CC=C2C=CC=CC2=C1': 27, 'CC(=O)O': 28, 'C1=CC=C2NC=NC2=C1': 29, 'C1CCOC1': 30, 'CCOC=O': 31, 'CC(C)C': 32, 'CNC': 33, 'C1CCOCC1': 34, 'C1CCCC1': 35, 'C1=CON=C1': 36, 'C1=CC=C2N=CC=CC2=C1': 37, 'C1=CNC=C1': 38, 'C1=CN=C2C=CC=CC2=C1': 39, 'CCNC=O': 40, 'CCCC': 41, 'C1=CC=C2OCOC2=C1': 42, 'NC(N)=O': 43, 'CCN': 44, 'C[SH](=O)=O': 45, 'C1=NN=CN1': 46, 'C1=CNCNC1': 47, 'C1=CNN=N1': 48, 'C1=CC=C2NCCC2=C1': 49, 'C1=CC=C2SC=NC2=C1': 50, 'CCCO': 51, 'NC(=O)CS': 52, 'C1=NC=NO1': 53, 'C1=COC=N1': 54, 'C1CSCN1': 55, 'C1=CN=CC=N1': 56, 'C1=NC=NN1': 57, 'C1CNC1': 58

In [49]:
partial_graph_representions, node_representations = partial_graph_encoder(
    partial_graph_node_categorical_features = batch.partial_node_categorical_features,
    node_features = batch.x,
    edge_index = batch.edge_index.long(), 
    edge_features = batch.partial_graph_edge_features.int(), 
    graph_to_focus_node_map = batch.focus_node,
    candidate_attachment_points = batch.valid_attachment_point_choices,
    batch_index = batch.batch
)

In [50]:
node_representations.shape

torch.Size([16659, 832])

In [52]:
embedded_categorical_features = partial_graph_encoder._embed(batch.partial_node_categorical_features)
embedded_categorical_features.shape

torch.Size([16659, 64])

In [56]:
import torch
# Append the embedded categorical features to the raw node features along the
# feature (last) dimension.
node_features = batch.x
initial_node_features = torch.cat(
    (node_features, embedded_categorical_features),
    dim=-1,
)
initial_node_features.shape

torch.Size([16659, 123])

In [59]:
graph_to_focus_node_map = batch.focus_node
graph_to_focus_node_map.shape

torch.Size([1000])

In [60]:
candidate_attachment_points = batch.valid_attachment_point_choices
candidate_attachment_points

tensor([  428.,   429.,   771.,   772.,   775.,   776.,   777.,   778.,   810.,
          812.,   937.,   940.,   941.,  1061.,  1064.,  1065.,  1131.,  1134.,
         1135.,  1836.,  1837.,  1838.,  1839.,  2279.,  2280.,  2281.,  2282.,
         3000.,  3002.,  3007.,  3057.,  3059.,  3080.,  3082.,  3083.,  3084.,
         3256.,  3257.,  3258.,  3259.,  3260.,  4097.,  4098.,  4101.,  4268.,
         4270.,  4274.,  4275.,  4320.,  4321.,  4322.,  4378.,  4379.,  4380.,
         4381.,  4382.,  4846.,  4847.,  4903.,  4904.,  4905.,  4907.,  4983.,
         4985.,  5019.,  5020.,  5021.,  5023.,  5024.,  5025.,  5026.,  5250.,
         5253.,  5254.,  5255.,  5256.,  5258.,  5927.,  5928.,  5930.,  6867.,
         6870.,  7297.,  7298.,  7301.,  7323.,  7324.,  7794.,  7795.,  7796.,
         7797.,  7798.,  7857.,  7858.,  7859.,  7861.,  7862.,  7863.,  7864.,
         7896.,  7898.,  8271.,  8272.,  8273.,  8275.,  8276.,  8277.,  8279.,
         8450.,  8451.,  8452.,  8453., 

In [61]:
nodes_to_set_in_focus_bit = torch.cat(
    [graph_to_focus_node_map, candidate_attachment_points], axis=0
)
nodes_to_set_in_focus_bit

tensor([1.6000e+01, 1.7000e+01, 1.9000e+01,  ..., 1.6538e+04, 1.6539e+04,
        1.6541e+04], dtype=torch.float64)

In [68]:
nodes_to_set_in_focus_bit.int()

tensor([   16,    17,    19,  ..., 16538, 16539, 16541], dtype=torch.int32)

In [77]:
torch.all(node_is_in_focus_bit[nodes_to_set_in_focus_bit.long()]>=1)

tensor(True)

In [78]:
node_is_in_focus_bit

tensor([[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [1.]])

In [65]:
torch.ones(nodes_to_set_in_focus_bit.shape[0], 1)

tensor([[1.],
        [1.],
        [1.],
        ...,
        [1.],
        [1.],
        [1.]])

In [76]:
# Start from an all-zero column with one row per row of `node_features`.
node_is_in_focus_bit_zeros = torch.zeros(node_features.shape[0], 1)
print(node_is_in_focus_bit_zeros.shape)
# `index_add_` is an in-place op that returns the tensor it was called on, so
# after this statement `node_is_in_focus_bit_zeros` is no longer all zeros and
# both names refer to the same tensor. It adds 1 at every listed index, so an
# index listed k times ends up holding k rather than 1; a later cell caps the
# values with `torch.minimum(..., torch.ones(1))`.
node_is_in_focus_bit = node_is_in_focus_bit_zeros.index_add_(
    dim = 0, 
    index = nodes_to_set_in_focus_bit.int(), 
    source = torch.ones(nodes_to_set_in_focus_bit.shape[0], 1)
)

torch.Size([16659, 1])


In [None]:
node_is_in_focus_bit = torch.minimum(node_is_in_focus_bit, torch.ones(1))

In [None]:
initial_node_features = torch.cat([initial_node_features, node_is_in_focus_bit], axis=-1)


In [None]:
initial_node_features.shape

In [None]:
# NOTE(review): this positional four-argument call does not match the keyword
# call to `partial_graph_encoder` made earlier in this notebook, which passes
# seven keyword arguments (categorical features, node features, edge index,
# edge features, focus-node map, candidate attachment points, batch index).
# It looks like a leftover from an older signature — confirm before running.
partial_graph_representions, node_representations = partial_graph_encoder(initial_node_features, batch.edge_index.long(), batch.edge_type, batch.batch)

In [None]:
node_representations.shape

# _mean_log_var_mlp

In [None]:
from model_utils import GenericMLP
latent_dim = 512

# The MLP output is later split into a mean half and a log-variance half,
# hence an output width of twice the latent dimension.
params['mean_log_var_mlp'] = dict(
    input_feature_dim=input_molecule_representations.shape[-1],
    output_size=latent_dim * 2,
)

mean_log_var_mlp = GenericMLP(**params['mean_log_var_mlp'])

In [None]:
mean_and_log_var = mean_log_var_mlp(input_molecule_representations)

In [None]:
# The first `latent_dim` columns hold the mean, the remaining columns the
# log-variance.
mu = mean_and_log_var[:, :latent_dim]
log_var = mean_and_log_var[:, latent_dim:]

# log_var = log(sigma^2), so sigma = exp(log_var / 2).
std = (log_var / 2).exp()

# p is a zero-mean, unit-scale Normal; q is the Normal parameterised by the
# MLP output.
p = torch.distributions.Normal(loc=torch.zeros_like(mu), scale=torch.ones_like(std))
q = torch.distributions.Normal(loc=mu, scale=std)

# rsample() draws a reparameterised sample, keeping it differentiable w.r.t.
# mu and std.
z = q.rsample()

In [None]:
z.shape

# Decoder

In [None]:
from decoder import MLPDecoder

## PickAtomOrMotif

In [None]:
from molecule_generation.utils.training_utils import get_class_balancing_weights


# Per-class loss weights for node type prediction, derived from how often each
# node type occurs as the "next node" in the training traces.
next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
class_weight_factor = params.get("node_type_predictor_class_loss_weight_factor", 1.0)

if not (0.0 <= class_weight_factor <= 1.0):
    # The message names the key actually read from `params` above.
    raise ValueError(
        "Node class loss weight node_type_predictor_class_loss_weight_factor "
        f"must be in [0,1], but is {class_weight_factor}!"
    )

if class_weight_factor > 0:
    atom_type_nums = [
        next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
        for type_idx in range(dataset.num_node_types)
    ]
    # The extra, last class is the "None" entry of the distribution.
    atom_type_nums.append(next_node_type_distribution["None"])

    class_weights = get_class_balancing_weights(
        class_counts=atom_type_nums, class_weight_factor=class_weight_factor
    )
else:
    class_weights = None

# `torch.tensor(None)` raises, so a factor of 0 (weighting disabled) must be
# stored as None instead of being converted to a tensor.
params['node_type_loss_weights'] = (
    torch.tensor(class_weights) if class_weights is not None else None
)

In [None]:
from model_utils import GenericMLP
# The selector scores `num_node_types + 1` classes; the class weights built for
# this head likewise append one extra "None" entry after the real node types.
params['node_type_selector'] = {
    'input_feature_dim': z.shape[-1] + partial_graph_representions.shape[-1],
    'output_size': dataset.num_node_types + 1,
}

graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique()

# Build the selector once, from `params`. A previous hand-written construction
# used `output_size=dataset.num_node_types` (missing the + 1) and was
# overwritten immediately, so it has been dropped.
node_type_selector = GenericMLP(**params['node_type_selector'])

### node loss computation in the forward method

In [None]:
decoder = MLPDecoder(params)

In [None]:
node_logits = decoder.pick_node_type(
    z,
    partial_graph_representions,
    graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique()
)

In [None]:
# `unique()` collapses repeated entries of the ptr vector, so this counts the
# distinct entries minus one. Presumably ptr holds cumulative offsets, in which
# case this is the number of graphs with a non-empty label slice — TODO confirm.
num_correct_node_type_choices = batch.correct_node_type_choices_ptr.unique().shape[-1] -1
# Reshape the flat label vector into one row per such graph; `view(n, -1)` only
# succeeds when the total number of labels is divisible by n.
node_type_multihot_labels = batch.correct_node_type_choices.view(num_correct_node_type_choices, -1)

In [None]:
# node_type_multihot_labels = []
# for i in range(len(batch.correct_node_type_choices_ptr)-1):
#     start_idx = batch.correct_node_type_choices_ptr[i]
#     end_idx = batch.correct_node_type_choices_ptr[i+1] 
#     if end_idx - start_idx == 0:
#         continue
#     node_selection_labels = batch.correct_node_type_choices[start_idx: end_idx]
#     node_type_multihot_labels += [node_selection_labels]
    
# node_type_multihot_labels = torch.stack(node_type_multihot_labels, axis = 0)

In [None]:
node_type_selection_loss = decoder.compute_node_type_selection_loss(
    node_logits,
    node_type_multihot_labels
)

# PickEdge

In [None]:

# Shape of the learned "no more edges" pseudo-candidate: a single row whose
# width is the node representation width plus the edge feature width.
params['no_more_edges_repr'] = (1, node_representations.shape[-1] + batch.edge_features.shape[-1])

# NOTE(review): 3011 is hard-coded; presumably it is the width of the
# concatenated edge-candidate representation — confirm and derive it from the
# tensor shapes instead of a literal.
params['edge_candidate_scorer'] = {
    'input_feature_dim': 3011,
    'output_size': 1,
}

params['edge_type_selector'] = {
    'input_feature_dim': 3011,
    'output_size': 3,
}

# `torch.FloatTensor(*shape)` only allocates memory and leaves its contents
# uninitialised (arbitrary values, possibly NaN/inf), which is not a valid
# starting point for a learnable parameter. Use an explicit random
# initialisation instead.
_no_more_edges_representation = torch.nn.Parameter(
    torch.rand(*params['no_more_edges_repr']), requires_grad=True
)
_edge_candidate_scorer = GenericMLP(**params['edge_candidate_scorer'])
_edge_type_selector = GenericMLP(**params['edge_type_selector'])

In [None]:
 torch.nn.Parameter(torch.rand(*params['no_more_edges_repr']), requires_grad = True)

In [None]:
from decoder import MLPDecoder
decoder = MLPDecoder(params)
edge_candidate_logits, edge_type_logits = decoder.pick_edge(
    z,
    partial_graph_representions,
    node_representations,
    num_graphs_in_batch = len(batch.ptr) - 1,
    graph_to_focus_node_map= batch.focus_node,
    node_to_graph_map=batch.batch,
    candidate_edge_targets= batch.valid_edge_choices[:, 1].long(),
    candidate_edge_features= batch.edge_features
)
decoder.compute_edge_candidate_selection_loss(
    num_graphs_in_batch= len(batch.ptr)-1,
    node_to_graph_map=batch.batch,
    candidate_edge_targets= batch.valid_edge_choices[:, 1].long(),
    edge_candidate_logits = edge_candidate_logits, # as is
    per_graph_num_correct_edge_choices= batch.num_correct_edge_choices,
    edge_candidate_correctness_labels = batch.correct_edge_choices,
    no_edge_selected_labels = batch.stop_node_label,
)

In [None]:
decoder.compute_edge_type_selection_loss(
    batch.valid_edge_types,  # batch.valid_edge_types
    edge_type_logits,
    batch.correct_edge_choices,  # batch.correct_edge_choices
    edge_type_onehot_labels = batch.correct_edge_types,  # batch.correct_edge_types
)

In [None]:
# These keyword arguments are the ones taken by the edge *candidate* selection
# loss. The edge *type* loss is called elsewhere in this notebook with a
# different argument set (valid edge types, edge type logits, correct choices,
# one-hot labels), so the method name here was a copy-paste slip.
decoder.compute_edge_candidate_selection_loss(
    num_graphs_in_batch=len(batch.ptr) - 1,
    node_to_graph_map=batch.batch,
    candidate_edge_targets=batch.valid_edge_choices[:, 1].long(),
    edge_candidate_logits=edge_candidate_logits,
    per_graph_num_correct_edge_choices=batch.num_correct_edge_choices,
    edge_candidate_correctness_labels=batch.correct_edge_choices,
    no_edge_selected_labels=batch.stop_node_label,
)

## Pick attachment point

In [None]:
# Keep every batch that has at least one correct attachment point choice.
# The emptiness check uses an explicit length comparison, as in the original
# loop, rather than relying on the truthiness of the attribute.
# Note that this walks the whole loader and holds all matching batches.
tmp = []
for batch2 in loader:
    if len(batch2.correct_attachment_point_choice) <= 0:
        continue
    tmp.append(batch2)

In [None]:
sample_idx = 1

tmp[sample_idx].correct_attachment_point_choice, tmp[sample_idx].valid_attachment_point_choices, tmp[sample_idx].valid_attachment_point_choices_batch

In [None]:
batch2 = tmp[sample_idx]

In [None]:
params['attachment_point_selector'] = {
    'input_feature_dim': 2176,
    'output_size': 1
}
_attachment_point_selector = GenericMLP(**params['attachment_point_selector'])
def pick_attachment_point(
    input_molecule_representations,
    partial_graph_representions,
    node_representations,
    node_to_graph_map,
    candidate_attachment_points,
):
    """Return one logit per candidate attachment point.

    Each candidate is scored from the concatenation of
    ``input_molecule_representations`` and ``partial_graph_representions`` for
    the graph the candidate belongs to, followed by the candidate's own row of
    ``node_representations``. The scoring is done by the notebook-level
    ``_attachment_point_selector``.

    Args:
        input_molecule_representations: per-graph rows, concatenated first.
        partial_graph_representions: per-graph rows, concatenated second.
        node_representations: per-node rows, indexed by candidate.
        node_to_graph_map: maps a node index to its graph index.
        candidate_attachment_points: node indices of the candidates; used
            directly as an index, so it must be an integer tensor.

    Returns:
        The selector output with its last dimension squeezed away.
    """
    # Graph index of every candidate node.
    graph_of_candidate = node_to_graph_map[candidate_attachment_points]

    # Per-graph conditioning vector shared by all candidates of that graph.
    per_graph_context = torch.cat(
        (input_molecule_representations, partial_graph_representions),
        dim=-1,
    )

    # Per-candidate input: graph-level context followed by the candidate's
    # own node representation.
    scorer_input = torch.cat(
        (
            per_graph_context[graph_of_candidate],
            node_representations[candidate_attachment_points],
        ),
        dim=-1,
    )
    print(scorer_input.shape)

    logits = _attachment_point_selector(scorer_input)
    return logits.squeeze(dim=-1)

In [None]:
# Re-run both encoders on `batch2`, the batch selected for the attachment
# point experiments.
input_molecule_representations = full_graph_encoder(
    batch2.original_graph_node_categorical_features,
    batch2.original_graph_x.float(),
    batch2.original_graph_edge_index,
    batch2.original_graph_edge_type,
    batch_index=batch2.original_graph_x_batch,
)

# Call the partial graph encoder with the same keyword arguments as the call
# on `batch` earlier in this notebook. The previous positional four-argument
# call predates that signature and left the categorical features and the
# focus-node / attachment-point inputs unset.
partial_graph_representions, node_representations = partial_graph_encoder(
    partial_graph_node_categorical_features=batch2.partial_node_categorical_features,
    node_features=batch2.x,
    edge_index=batch2.edge_index.long(),
    edge_features=batch2.partial_graph_edge_features.int(),
    graph_to_focus_node_map=batch2.focus_node,
    candidate_attachment_points=batch2.valid_attachment_point_choices,
    batch_index=batch2.batch,
)

In [None]:
batch2.valid_attachment_point_choices_batch

In [None]:

attachment_point_selection_logits = pick_attachment_point(
    z, # as is
    partial_graph_representions, # partial_graph_representions
    node_representations, #as is
    node_to_graph_map= batch2.batch,
    candidate_attachment_points = batch2.valid_attachment_point_choices.long()
)

In [None]:
def compute_attachment_point_selection_loss(
    attachment_point_selection_logits,
    attachment_point_candidate_to_graph_map,
    attachment_point_correct_choices,
):
    """Return the attachment point selection loss.

    The logits are turned into log-probabilities by
    ``traced_unsorted_segment_log_softmax``, using the candidate-to-graph map
    as segment ids. The negated log-probabilities at the correct choices are
    summed, and the sum and the number of correct choices are handed to
    ``safe_divide_loss``.

    Args:
        attachment_point_selection_logits: one logit per candidate.
        attachment_point_candidate_to_graph_map: segment id per candidate.
        attachment_point_correct_choices: indices into the candidate list
            selecting the correct candidates.
    """
    # The multiplication by 1.0 is kept from the original code; presumably a
    # placeholder for a loss weight — confirm.
    candidate_logprobs = (
        traced_unsorted_segment_log_softmax(
            logits=attachment_point_selection_logits,
            segment_ids=attachment_point_candidate_to_graph_map,
        )
        * 1.0
    )

    correct_choice_neglogprobs = -candidate_logprobs[attachment_point_correct_choices]

    return safe_divide_loss(
        correct_choice_neglogprobs.sum(),
        correct_choice_neglogprobs.shape[0],
    )

In [None]:
# compute_attachment_point_selection_loss(
#     attachment_point_selection_logits =  attachment_point_selection_logits,
#     attachment_point_candidate_to_graph_map = batch2.valid_attachment_point_choices_batch.long(),
#     attachment_point_correct_choices = batch2.correct_attachment_point_choice.long()
# )

In [None]:
decoder

In [None]:
params

In [None]:
from decoder import MLPDecoder


decoder = MLPDecoder(params)
attachment_point_selection_logits = decoder.pick_attachment_point(
    z, # as is
    partial_graph_representions, # partial_graph_representions
    node_representations, #as is
    node_to_graph_map= batch2.batch,
    candidate_attachment_points = batch2.valid_attachment_point_choices.long()
)
decoder.compute_attachment_point_selection_loss(
    attachment_point_selection_logits =  attachment_point_selection_logits,
    attachment_point_candidate_to_graph_map = batch2.valid_attachment_point_choices_batch.long(),
    attachment_point_correct_choices = batch2.correct_attachment_point_choice.long()
)

In [None]:

# Run the full decoder forward pass: one call returns the node type logits,
# edge candidate logits, edge type logits and attachment point logits.
# The sampled latent `z` is passed in the `input_molecule_representations` slot.
node_logits,edge_candidate_logits,edge_type_logits,attachment_point_selection_logits = decoder(
    input_molecule_representations = z,
    graph_representations = partial_graph_representions,
    graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique(),
    # edge selection
    node_representations = node_representations,
    num_graphs_in_batch = len(batch.ptr) -1,
    graph_to_focus_node_map =batch.focus_node,
    node_to_graph_map = batch.batch,
    candidate_edge_targets = batch.valid_edge_choices[:, 1].long(),
    candidate_edge_features = batch.edge_features,
    # attachment selection
    candidate_attachment_points = batch.valid_attachment_point_choices.long(),
)


# Combine the individual decoder losses into one value.
# NOTE(review): `edge_type_logits` is produced above but is not passed to
# `compute_decoder_loss`, and no edge type labels are passed either — confirm
# whether the edge type loss is intentionally excluded from the total.
loss = decoder.compute_decoder_loss(
    node_type_logits = node_logits,
    node_type_multihot_labels = node_type_multihot_labels,
    num_graphs_in_batch= len(batch.ptr)-1,
    node_to_graph_map=batch.batch,
    candidate_edge_targets= batch.valid_edge_choices[:, 1].long(),
    edge_candidate_logits = edge_candidate_logits, # as is
    per_graph_num_correct_edge_choices= batch.num_correct_edge_choices,
    edge_candidate_correctness_labels = batch.correct_edge_choices,
    no_edge_selected_labels = batch.stop_node_label,
    attachment_point_selection_logits =  attachment_point_selection_logits,
    attachment_point_candidate_to_graph_map = batch.valid_attachment_point_choices_batch.long(),
    attachment_point_correct_choices = batch.correct_attachment_point_choice.long()
)

In [None]:
loss

# pick first node type


In [None]:
batch.correct_first_node_type_choices_batch

In [None]:
from model_utils import GenericMLP
params['first_node_type_selector'] = dict(
    input_feature_dim=partial_graph_representions.shape[-1],
    # No "+ 1" here: a molecule cannot start with "no node", so there is no
    # extra class on top of the real node types.
    output_size=dataset.num_node_types,
)

first_node_type_selector = GenericMLP(**params['first_node_type_selector'])

In [None]:
params

In [None]:
def pick_first_node_type(partial_graph_representions):
    """Return the first-node-type logits for the given graph representations.

    Thin wrapper that applies the notebook-level ``first_node_type_selector``
    to its argument and returns the result unchanged.
    """
    logits = first_node_type_selector(partial_graph_representions)
    return logits

In [None]:
first_node_type_logits = pick_first_node_type(partial_graph_representions)

In [None]:
def compute_first_node_type_selection_loss(
    first_node_type_logits,
    first_node_type_multihot_labels,
):
    """Return the first-node-type selection loss.

    The logits are log-softmaxed over the last dimension and passed, together
    with the multi-hot labels and the per-row count of correct choices, to
    ``compute_neglogprob_for_multihot_objective``. The summed result and the
    number of label rows are then handed to ``safe_divide_loss``.

    Args:
        first_node_type_logits: logits with the classes on the last dimension.
        first_node_type_multihot_labels: multi-hot labels, one row per decision.
    """
    num_decisions = first_node_type_multihot_labels.shape[0]

    logprobs = torch.nn.functional.log_softmax(first_node_type_logits, dim=-1)
    # Number of correct classes per row, kept as a column for broadcasting.
    num_correct_per_row = first_node_type_multihot_labels.sum(dim=-1, keepdim=True)

    normalised_neglogprob = compute_neglogprob_for_multihot_objective(
        logprobs=logprobs,
        multihot_labels=first_node_type_multihot_labels,
        per_decision_num_correct_choices=num_correct_per_row,
    )
    # NOTE(review): class weighting is not applied here; the original cell only
    # had a commented-out draft of it.

    return safe_divide_loss(torch.sum(normalised_neglogprob), num_decisions)

In [None]:
first_node_type_multihot_labels = batch.correct_first_node_type_choices.view(len(batch.ptr) -1, -1)

compute_first_node_type_selection_loss(
    first_node_type_logits,
    first_node_type_multihot_labels,
)

In [None]:
# Gather the decoder-related entries of `params` under a single 'decoder' key,
# preserving their order.
params['decoder'] = {
    key: params[key]
    for key in (
        'node_type_selector',
        'node_type_loss_weights',
        'no_more_edges_repr',
        'edge_candidate_scorer',
        'edge_type_selector',
        'attachment_point_selector',
    )
}

params['latent_sample_strategy'] = 'per_graph'
params['latent_repr_size'] = 512

In [None]:
from model_utils import get_params

In [None]:
%load_ext autoreload
%autoreload 2
from model import BaseModel

# `get_params` is called with the dataset, as in the earlier cell that builds
# the partial graph encoder.
params = get_params(dataset)
model = BaseModel(params, dataset)

# One forward pass and loss computation on the sample batch as a smoke test.
moler_output = model._run_step(batch)

loss = model.compute_loss(moler_output, batch)

loss

In [None]:
model


In [None]:
model.step(batch)

In [None]:
from pytorch_lightning import Trainer
from torch_geometric.data.lightning_datamodule import LightningDataset
# datamodule = LightningDataset(dataset)
trainer = Trainer(overfit_batches=1)
trainer.fit(model, loader, loader)
# trainer.fit(model, datamodule)

# LightningModule + Vae MLP

1. Implement the KL divergence loss as part of the lightning module
2. Investigate where node_type_predictor_class_loss_weight_factor is supposed to come from, otherwise, default to 1