# Design 

1. All model-related loss computations and helper functions live in `torch.nn.Module` subclasses, and the loss methods are all named `compute_..._loss`.
2. The MLPDecoder will house all the MLPs involved in the decoder and will have one `compute_decoder_loss` method that calls all the individual `compute_..._loss` methods.

It will also have a `decode` method that performs decoding at inference time (TODO).

3. The FullGraphEncoder and the PartialGraphEncoder will each be in their own torch modules

4. Finally, the lightning module will have 3 things: 
- FullGraphEncoder
- PartialGraphEncoder (part of decoder)
- MLPDecoder

After passing through the initial FullGraphEncoder, if we are working with a VAE, we extract the prior `p` and the approximate posterior `q` to compute the KL-divergence loss; otherwise we do the other model-specific work (e.g. diffusion).

A `params` dictionary will be passed to the Lightning module, and each torch module will be constructed inside it by unpacking the relevant sub-dictionary (e.g. `GraphEncoder(**params['full_graph_encoder'])`).

Node-type class weights will be instantiated in the Lightning module and passed to the decoder.

# TODO
1. fix the incrementing in the original graph edge index (DONE)
2. Work on first node prediction 
3. Investigate node_type_predictor_class_loss_weight_factor

In [1]:
# params houses all relevant model instantiation parameters
params = {}

In [4]:
%load_ext autoreload
%autoreload 2

from dataset import MolerDataset, MolerData
from utils import pprint_pyg_obj
from torch_geometric.loader import DataLoader


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
dataset = MolerDataset(
    root = '/data/ongh0068', 
    raw_moler_trace_dataset_parent_folder = '/data/ongh0068/l1000/trace_playground',
    output_pyg_trace_dataset_parent_folder = '/data/ongh0068/l1000/pyg_output_playground',
    split = 'train',
)

In [6]:
# Attributes named here get companion per-item vectors on every collated
# batch (this notebook later reads e.g. `batch.original_graph_x_batch` and
# `batch.correct_node_type_choices_ptr`).
follow_batch_keys = [
    'correct_edge_choices',
    'correct_edge_types',
    'valid_edge_choices',
    'valid_attachment_point_choices',
    'correct_attachment_point_choice',
    'correct_node_type_choices',
    'original_graph_x',
]

# Deterministic order (no shuffling) so the example batch is reproducible.
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
    follow_batch=follow_batch_keys,
)

In [7]:
# Grab the first mini-batch as the fixed example used by the rest of the
# notebook. `next(iter(...))` raises StopIteration right here if the loader is
# empty, instead of silently leaving `batch` undefined until a later cell.
batch = next(iter(loader))

# FullGraphEncoder

In [6]:
from model_utils import GenericGraphEncoder
import torch

In [7]:
class GraphEncoder(torch.nn.Module):
    """Encode whole molecules into graph-level representations.

    Each node's categorical atom/motif id is looked up in a learned embedding
    table, the embedding is appended to the node's dense feature vector, and
    the combined features are passed through a ``GenericGraphEncoder``.

    NOTE(review): ``hidden_layer_feature_dim``, ``num_layers``, ``layer_type``
    and ``use_intermediate_gnn_results`` are accepted by the constructor but
    are not forwarded to ``GenericGraphEncoder``, so that class's own defaults
    apply -- confirm whether this is intended.
    """

    def __init__(
        self,
        input_feature_dim,
        atom_or_motif_vocab_size,
        motif_embedding_size=64,
        hidden_layer_feature_dim=64,
        num_layers=12,
        layer_type="RGATConv",
        use_intermediate_gnn_results=True,
    ):
        super().__init__()
        # The GNN sees the dense node features plus the motif embedding.
        gnn_input_dim = motif_embedding_size + input_feature_dim
        self._embed = torch.nn.Embedding(
            num_embeddings=atom_or_motif_vocab_size,
            embedding_dim=motif_embedding_size,
        )
        self._model = GenericGraphEncoder(input_feature_dim=gnn_input_dim)

    def forward(
        self,
        original_graph_node_categorical_features,
        node_features,
        edge_index,
        edge_type,
        batch_index,
    ):
        """Return the graph-level output of the wrapped encoder.

        The wrapped encoder returns a pair; only its first element is
        returned here and the second is discarded.
        """
        embedded_types = self._embed(original_graph_node_categorical_features)
        gnn_inputs = torch.cat([node_features, embedded_types], dim=-1)
        graph_level, _unused = self._model(
            gnn_inputs,
            edge_index.long(),
            edge_type,
            batch_index,
        )
        return graph_level

In [8]:
# Record the constructor arguments in `params` so the same dictionary can
# later be handed to the Lightning module.
params['full_graph_encoder'] = {
    'input_feature_dim': batch.x.shape[-1],
    'atom_or_motif_vocab_size': len(dataset.node_type_index_to_string)
}

# Build the encoder once, from `params` only. A second, hand-written
# construction with the same arguments was redundant (its result was
# immediately overwritten) and could drift out of sync with `params`.
full_graph_encoder = GraphEncoder(**params['full_graph_encoder'])

In [9]:
input_molecule_representations = full_graph_encoder(
    batch.original_graph_node_categorical_features, 
    batch.original_graph_x.float(),
    batch.original_graph_edge_index,
    batch.original_graph_edge_type,
    batch_index = batch.original_graph_x_batch,
)

# PartialGraphEncoder

In [10]:
# Record the constructor arguments in `params` so the same dictionary can
# later be handed to the Lightning module.
params['partial_graph_encoder'] = {
    'input_feature_dim': batch.x.shape[-1],
}

# Build the encoder once, from `params` only. A second, hand-written
# construction with the same arguments was redundant (its result was
# immediately overwritten) and could drift out of sync with `params`.
partial_graph_encoder = GenericGraphEncoder(**params['partial_graph_encoder'])

In [11]:
partial_graph_representions, node_representations = partial_graph_encoder(batch.x, batch.edge_index.long(), batch.edge_type, batch.batch)

In [12]:
node_representations.shape

torch.Size([193, 832])

# _mean_log_var_mlp

In [13]:
from model_utils import GenericMLP
latent_dim = 512
params['mean_log_var_mlp'] = {
    'input_feature_dim': input_molecule_representations.shape[-1],
    'output_size': latent_dim * 2
}


mean_log_var_mlp = GenericMLP(**params['mean_log_var_mlp'])

In [14]:
mean_and_log_var = mean_log_var_mlp(input_molecule_representations)

In [15]:
# The MLP output holds two halves side by side along the last axis: the first
# `latent_dim` columns are read as the mean, the remainder as the log-variance.
# There is one row per input molecule (z comes out as [16, 512] for the
# 16-graph example batch).
mu, log_var = mean_and_log_var[:, :latent_dim], mean_and_log_var[:, latent_dim:]

# sigma = exp(log(sigma^2) / 2)
std = (0.5 * log_var).exp()

# q is the distribution parameterised by the encoder output; p is a
# standard normal of the same shape. Both are kept for the KL term.
q = torch.distributions.Normal(loc=mu, scale=std)
p = torch.distributions.Normal(
    loc=torch.zeros_like(mu),
    scale=torch.ones_like(std),
)

# rsample() draws a reparameterised sample, so gradients reach mu and std.
z = q.rsample()

In [16]:
z.shape

torch.Size([16, 512])

# Decoder

In [17]:
from decoder import MLPDecoder

## PickAtomOrMotif

In [18]:
from molecule_generation.utils.training_utils import get_class_balancing_weights


# Counts of each "next node type" seen in the training traces.
next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
class_weight_factor = params.get("node_type_predictor_class_loss_weight_factor", 1.0)

if not (0.0 <= class_weight_factor <= 1.0):
    # The message names the same key that is read from `params` above.
    raise ValueError(
        "Node class loss weight node_type_predictor_class_loss_weight_factor "
        f"must be in [0,1], but is {class_weight_factor}!"
    )

if class_weight_factor > 0:
    # One count per known node type, in node-type-index order ...
    atom_type_nums = [
        next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
        for type_idx in range(dataset.num_node_types)
    ]
    # ... plus one trailing entry for the "None" key of the distribution.
    atom_type_nums.append(next_node_type_distribution["None"])

    class_weights = get_class_balancing_weights(
        class_counts=atom_type_nums, class_weight_factor=class_weight_factor
    )
    params['node_type_loss_weights'] = torch.tensor(class_weights)
else:
    # A factor of 0 disables class weighting. `torch.tensor(None)` would raise,
    # so store None directly.
    # NOTE(review): confirm MLPDecoder treats a None weight as "unweighted".
    class_weights = None
    params['node_type_loss_weights'] = None

In [19]:
from model_utils import GenericMLP
params['node_type_selector'] = {
    # The selector is fed the latent sample concatenated with the
    # partial-graph representation.
    'input_feature_dim': z.shape[-1] + partial_graph_representions.shape[-1],
    # One logit per node type plus one extra class, matching the extra "None"
    # entry appended to the class counts used for the loss weights.
    'output_size': dataset.num_node_types + 1
}

# Ids of the graphs in the batch that have node-type labels.
graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique()

# Build the MLP once, from `params` only. The previous hand-written
# construction used `output_size = dataset.num_node_types` (missing the +1),
# disagreeing with `params`, and was immediately overwritten anyway.
node_type_selector = GenericMLP(**params['node_type_selector'])

### node loss computation in the forward method

In [20]:
decoder = MLPDecoder(params)

KeyError: 'no_more_edges_repr'

In [None]:
node_logits = decoder.pick_node_type(
    z,
    partial_graph_representions,
    graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique()
)

In [None]:
num_correct_node_type_choices = batch.correct_node_type_choices_ptr.unique().shape[-1] -1
node_type_multihot_labels = batch.correct_node_type_choices.view(num_correct_node_type_choices, -1)

In [None]:
# node_type_multihot_labels = []
# for i in range(len(batch.correct_node_type_choices_ptr)-1):
#     start_idx = batch.correct_node_type_choices_ptr[i]
#     end_idx = batch.correct_node_type_choices_ptr[i+1] 
#     if end_idx - start_idx == 0:
#         continue
#     node_selection_labels = batch.correct_node_type_choices[start_idx: end_idx]
#     node_type_multihot_labels += [node_selection_labels]
    
# node_type_multihot_labels = torch.stack(node_type_multihot_labels, axis = 0)

In [71]:
node_type_selection_loss = decoder.compute_node_type_selection_loss(
    node_logits,
    node_type_multihot_labels
)

# PickEdge

In [21]:

# Shape of the learned "no more edges" representation: a single row as wide as
# a node representation plus the edge features.
params['no_more_edges_repr'] = (1, node_representations.shape[-1] + batch.edge_features.shape[-1])

# Both edge MLPs take the same input width, so name it once.
# NOTE(review): 3011 is hard-coded; presumably it is the latent sample +
# partial-graph + focus-node + target-node representations + edge features
# (512 + 3 * 832 + 3 for the shapes in this notebook) -- confirm against
# MLPDecoder.pick_edge and derive it from tensor shapes instead.
edge_mlp_input_dim = 3011
params['edge_candidate_scorer'] = {
    'input_feature_dim': edge_mlp_input_dim,
    'output_size': 1
}

params['edge_type_selector'] = {
    'input_feature_dim': edge_mlp_input_dim,
    'output_size': 3
}


# `torch.FloatTensor(*shape)` only allocates memory without initialising it,
# so the parameter could start with arbitrary values (including inf/NaN).
# Initialise it explicitly with uniform random values instead.
_no_more_edges_representation = torch.nn.Parameter(torch.rand(*params['no_more_edges_repr']), requires_grad = True)
_edge_candidate_scorer = GenericMLP(**params['edge_candidate_scorer'])
_edge_type_selector = GenericMLP(**params['edge_type_selector'])

In [73]:
 torch.nn.Parameter(torch.rand(*params['no_more_edges_repr']), requires_grad = True)

Parameter containing:
tensor([[0.9971, 0.6084, 0.7427, 0.7022, 0.1249, 0.1739, 0.7455, 0.7201, 0.0630,
         0.1107, 0.8570, 0.9180, 0.7902, 0.7420, 0.6279, 0.2161, 0.0617, 0.8683,
         0.9575, 0.5330, 0.3011, 0.6251, 0.8214, 0.1165, 0.2619, 0.5124, 0.0169,
         0.2141, 0.7995, 0.9968, 0.5250, 0.5041, 0.7605, 0.1113, 0.2157, 0.7756,
         0.2733, 0.7232, 0.3839, 0.9642, 0.6305, 0.5619, 0.9571, 0.1208, 0.3966,
         0.6586, 0.7224, 0.1828, 0.7047, 0.3445, 0.6042, 0.7228, 0.0972, 0.1349,
         0.5780, 0.6872, 0.6862, 0.1813, 0.3681, 0.0688, 0.2306, 0.6856, 0.0269,
         0.8810, 0.4770, 0.2878, 0.8051, 0.8592, 0.8275, 0.3095, 0.4563, 0.5700,
         0.7781, 0.4993, 0.2236, 0.2904, 0.9160, 0.6280, 0.4211, 0.4221, 0.9815,
         0.0770, 0.3525, 0.6074, 0.4632, 0.2368, 0.6255, 0.8887, 0.0804, 0.7062,
         0.3688, 0.0948, 0.6150, 0.8482, 0.0598, 0.6699, 0.9480, 0.7518, 0.1468,
         0.9098, 0.2061, 0.1497, 0.9135, 0.5522, 0.3129, 0.3556, 0.2689, 0.2434,
      

In [24]:
from decoder import MLPDecoder
decoder = MLPDecoder(params)
edge_candidate_logits, edge_type_logits = decoder.pick_edge(
    z,
    partial_graph_representions,
    node_representations,
    num_graphs_in_batch = len(batch.ptr) - 1,
    graph_to_focus_node_map= batch.focus_node,
    node_to_graph_map=batch.batch,
    candidate_edge_targets= batch.valid_edge_choices[:, 1].long(),
    candidate_edge_features= batch.edge_features
)
decoder.compute_edge_candidate_selection_loss(
    num_graphs_in_batch= len(batch.ptr)-1,
    node_to_graph_map=batch.batch,
    candidate_edge_targets= batch.valid_edge_choices[:, 1].long(),
    edge_candidate_logits = edge_candidate_logits, # as is
    per_graph_num_correct_edge_choices= batch.num_correct_edge_choices,
    edge_candidate_correctness_labels = batch.correct_edge_choices,
    no_edge_selected_labels = batch.stop_node_label,
)

tensor(1.5894, grad_fn=<DivBackward0>)

In [25]:
decoder.compute_edge_type_selection_loss(
    batch.valid_edge_types,  # batch.valid_edge_types
    edge_type_logits,
    batch.correct_edge_choices,  # batch.correct_edge_choices
    edge_type_onehot_labels = batch.correct_edge_types,  # batch.correct_edge_types
)

tensor([[-0.0608, -0.0029,  0.1712],
        [-0.1539, -0.0618,  0.1270],
        [-0.2074, -0.1262,  0.2669],
        [-0.1903, -0.2312,  0.2555],
        [-0.2037, -0.1158,  0.1301],
        [-0.1033, -0.0862,  0.2385],
        [ 0.1175, -0.0780,  0.0936],
        [-0.0079, -0.0690,  0.0794],
        [ 0.0171, -0.1130,  0.1455]], grad_fn=<IndexBackward0>)
tensor([[1, 0, 0],
        [0, 1, 0],
        [1, 0, 0],
        [0, 1, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 1, 0]], dtype=torch.int32)
tensor([[-6.0778e-02, -1.0000e+07, -1.0000e+07],
        [-1.5395e-01, -6.1793e-02,  1.2696e-01],
        [-2.0740e-01, -1.0000e+07, -1.0000e+07],
        [-1.9027e-01, -2.3120e-01,  2.5546e-01],
        [-2.0374e-01, -1.0000e+07, -1.0000e+07],
        [-1.0328e-01, -8.6219e-02, -1.0000e+07],
        [ 1.1754e-01, -1.0000e+07, -1.0000e+07],
        [-7.9075e-03, -1.0000e+07, -1.0000e+07],
        [ 1.7055e-02, -1.1305e-01, -1.0000e+07]], grad_fn

tensor(0.4333, grad_fn=<DivBackward0>)

In [76]:
# NOTE(review): these keyword arguments are the ones passed to
# `decoder.compute_edge_candidate_selection_loss` elsewhere in this notebook,
# whereas `compute_edge_type_selection_loss` is otherwise called with
# edge-type inputs (valid edge types, edge type logits, one-hot labels).
# This looks like the wrong method name from a copy/paste -- confirm against
# the MLPDecoder API.
decoder.compute_edge_type_selection_loss(
    num_graphs_in_batch= len(batch.ptr)-1,
    node_to_graph_map=batch.batch,
    candidate_edge_targets= batch.valid_edge_choices[:, 1].long(),
    edge_candidate_logits = edge_candidate_logits, # as is
    per_graph_num_correct_edge_choices= batch.num_correct_edge_choices,
    edge_candidate_correctness_labels = batch.correct_edge_choices,
    no_edge_selected_labels = batch.stop_node_label,
)

NameError: name 'edge_candidate_logits' is not defined

## Pick attachment point

In [77]:
# Collect every mini-batch that carries at least one attachment-point label,
# so one of them can be used as an example below.
tmp = []
for batch2 in loader:
    has_attachment_labels = len(batch2.correct_attachment_point_choice) > 0
    if not has_attachment_labels:
        continue
    tmp.append(batch2)

In [78]:
sample_idx = 1

tmp[sample_idx].correct_attachment_point_choice, tmp[sample_idx].valid_attachment_point_choices, tmp[sample_idx].valid_attachment_point_choices_batch

(tensor([0., 4.]),
 tensor([133., 135., 140., 141., 200., 203., 204., 205.], dtype=torch.float64),
 tensor([ 8,  8,  8,  8, 13, 13, 13, 13]))

In [79]:
batch2 = tmp[sample_idx]

In [23]:
params['attachment_point_selector'] = {
    'input_feature_dim': 2176,
    'output_size': 1
}
_attachment_point_selector = GenericMLP(**params['attachment_point_selector'])
def pick_attachment_point(
    input_molecule_representations,
    partial_graph_representions,
    node_representations,
    node_to_graph_map,
    candidate_attachment_points,
):
    """Score each candidate attachment point with the module-level MLP.

    Args:
        input_molecule_representations: per-graph representation of the
            input molecule; concatenated with the partial-graph
            representation along the last axis.
        partial_graph_representions: per-graph representation of the
            partial graph.
        node_representations: per-node representations, indexed by the
            candidate node ids.
        node_to_graph_map: for each node, the index of the graph it belongs
            to (the notebook passes ``batch.batch``).
        candidate_attachment_points: integer node ids of the candidates
            (the notebook passes ``valid_attachment_point_choices``).

    Returns:
        One logit per candidate (the trailing axis of the MLP output is
        squeezed away).

    Side effect: prints the shape of the MLP input (debugging aid).
    """
    # Which partial graph each candidate node belongs to.  Shape: [CA]
    graph_of_candidate = node_to_graph_map[candidate_attachment_points]

    # Per-graph context: input-molecule and partial-graph representations
    # side by side.  Shape: [PG, MD + PD]
    graph_context = torch.cat(
        (input_molecule_representations, partial_graph_representions),
        dim=-1,
    )

    # Each candidate is scored from its graph's context plus its own node
    # representation.
    scorer_inputs = torch.cat(
        (
            graph_context[graph_of_candidate],
            node_representations[candidate_attachment_points],
        ),
        dim=-1,
    )
    print(scorer_inputs.shape)

    raw_scores = _attachment_point_selector(scorer_inputs)
    return raw_scores.squeeze(dim=-1)

In [81]:
input_molecule_representations = full_graph_encoder(
    batch2.original_graph_node_categorical_features, 
    batch2.original_graph_x.float(),
    batch2.original_graph_edge_index,
    batch2.original_graph_edge_type,
    batch_index = batch2.original_graph_x_batch,
)


partial_graph_representions, node_representations = partial_graph_encoder(batch2.x, batch2.edge_index.long(), batch2.edge_type.int(), batch2.batch)

In [82]:
batch2.valid_attachment_point_choices_batch

tensor([ 8,  8,  8,  8, 13, 13, 13, 13])

In [83]:

attachment_point_selection_logits = pick_attachment_point(
    z, # as is
    partial_graph_representions, # partial_graph_representions
    node_representations, #as is
    node_to_graph_map= batch2.batch,
    candidate_attachment_points = batch2.valid_attachment_point_choices.long()
)

torch.Size([8, 2176])


In [84]:
def compute_attachment_point_selection_loss(
    attachment_point_selection_logits,
    attachment_point_candidate_to_graph_map,
    attachment_point_correct_choices,
):
    """Return the mean negative log-probability of the correct attachment points.

    Args:
        attachment_point_selection_logits: one logit per candidate.
        attachment_point_candidate_to_graph_map: segment id (partial-graph
            index) of each candidate; the notebook passes
            ``valid_attachment_point_choices_batch``.
        attachment_point_correct_choices: integer indices into the candidate
            axis selecting the correct candidates.

    NOTE(review): ``traced_unsorted_segment_log_softmax`` and
    ``safe_divide_loss`` are not defined or imported in the visible part of
    this notebook; their exact semantics are assumed from their names.
    """
    # Log-softmax of the logits within each partial graph.  Shape: [CA]
    per_candidate_logprobs = 1.0 * traced_unsorted_segment_log_softmax(
        logits=attachment_point_selection_logits,
        segment_ids=attachment_point_candidate_to_graph_map,
    )

    # Negative log-probability of each correct choice.  Shape: [AP]
    correct_nll = -per_candidate_logprobs[attachment_point_correct_choices]

    # Sum divided by the number of correct choices.
    return safe_divide_loss(correct_nll.sum(), correct_nll.shape[0])

In [85]:
# compute_attachment_point_selection_loss(
#     attachment_point_selection_logits =  attachment_point_selection_logits,
#     attachment_point_candidate_to_graph_map = batch2.valid_attachment_point_choices_batch.long(),
#     attachment_point_correct_choices = batch2.correct_attachment_point_choice.long()
# )

In [86]:
decoder

MLPDecoder(
  (_node_type_selector): GenericMLP(
    (_first_layer): Linear(in_features=1344, out_features=64, bias=True)
    (_hidden_layers): ModuleList(
      (0): LeakyReLU(negative_slope=0.01)
      (1): Dropout(p=0.2, inplace=False)
      (2): Linear(in_features=64, out_features=64, bias=True)
    )
    (_output_layer): Sequential(
      (0): LeakyReLU(negative_slope=0.01)
      (1): Dropout(p=0.2, inplace=False)
      (2): Linear(in_features=64, out_features=140, bias=True)
    )
  )
  (_edge_candidate_scorer): GenericMLP(
    (_first_layer): Linear(in_features=3011, out_features=64, bias=True)
    (_hidden_layers): ModuleList(
      (0): LeakyReLU(negative_slope=0.01)
      (1): Dropout(p=0.2, inplace=False)
      (2): Linear(in_features=64, out_features=64, bias=True)
    )
    (_output_layer): Sequential(
      (0): LeakyReLU(negative_slope=0.01)
      (1): Dropout(p=0.2, inplace=False)
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
  )
  (_edge_type_selec

In [40]:
params

{'full_graph_encoder': {'input_feature_dim': 32,
  'atom_or_motif_vocab_size': 139},
 'partial_graph_encoder': {'input_feature_dim': 32},
 'mean_log_var_mlp': {'input_feature_dim': 832, 'output_size': 1024},
 'node_type_loss_weights': tensor([10.0000,  0.1000,  0.1000,  0.1000,  0.7879,  0.4924,  0.6060, 10.0000,
          7.8786, 10.0000,  7.8786,  0.1000,  0.6565,  0.6565,  0.9848,  0.8754,
          0.8754,  1.1255,  0.9848,  1.3131,  1.5757,  1.9696,  1.5757,  1.9696,
          2.6262,  1.9696,  1.9696,  7.8786,  7.8786,  3.9393,  2.6262,  2.6262,
          2.6262,  2.6262,  3.9393,  7.8786,  7.8786,  7.8786,  3.9393,  7.8786,
         10.0000,  7.8786,  3.9393,  3.9393,  3.9393,  3.9393,  3.9393,  3.9393,
          3.9393,  3.9393,  3.9393,  3.9393,  3.9393,  3.9393,  7.8786,  7.8786,
         10.0000, 10.0000,  7.8786,  7.8786, 10.0000,  7.8786,  7.8786, 10.0000,
          7.8786,  7.8786, 10.0000,  7.8786, 10.0000,  7.8786,  7.8786, 10.0000,
          7.8786,  7.8786,  7.8786, 1

In [41]:
from decoder import MLPDecoder


decoder = MLPDecoder(params)
attachment_point_selection_logits = decoder.pick_attachment_point(
    z, # as is
    partial_graph_representions, # partial_graph_representions
    node_representations, #as is
    node_to_graph_map= batch2.batch,
    candidate_attachment_points = batch2.valid_attachment_point_choices.long()
)
decoder.compute_attachment_point_selection_loss(
    attachment_point_selection_logits =  attachment_point_selection_logits,
    attachment_point_candidate_to_graph_map = batch2.valid_attachment_point_choices_batch.long(),
    attachment_point_correct_choices = batch2.correct_attachment_point_choice.long()
)

tensor(1.4267, grad_fn=<DivBackward0>)

In [44]:

node_logits,edge_candidate_logits,edge_type_logits,attachment_point_selection_logits = decoder(
    input_molecule_representations = z,
    graph_representations = partial_graph_representions,
    graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique(),
    # edge selection
    node_representations = node_representations,
    num_graphs_in_batch = len(batch.ptr) -1,
    graph_to_focus_node_map =batch.focus_node,
    node_to_graph_map = batch.batch,
    candidate_edge_targets = batch.valid_edge_choices[:, 1].long(),
    candidate_edge_features = batch.edge_features,
    # attachment selection
    candidate_attachment_points = batch.valid_attachment_point_choices.long(),
)


loss = decoder.compute_decoder_loss(
    node_type_logits = node_logits,
    node_type_multihot_labels = node_type_multihot_labels,
    num_graphs_in_batch= len(batch.ptr)-1,
    node_to_graph_map=batch.batch,
    candidate_edge_targets= batch.valid_edge_choices[:, 1].long(),
    edge_candidate_logits = edge_candidate_logits, # as is
    per_graph_num_correct_edge_choices= batch.num_correct_edge_choices,
    edge_candidate_correctness_labels = batch.correct_edge_choices,
    no_edge_selected_labels = batch.stop_node_label,
    attachment_point_selection_logits =  attachment_point_selection_logits,
    attachment_point_candidate_to_graph_map = batch.valid_attachment_point_choices_batch.long(),
    attachment_point_correct_choices = batch.correct_attachment_point_choice.long()
)

In [45]:
loss

tensor(2.0684, grad_fn=<AddBackward0>)

In [26]:
# Bundle the decoder-related entries of `params` under a single 'decoder' key,
# preserving the order in which they are listed here.
decoder_param_keys = (
    'node_type_selector',
    'node_type_loss_weights',
    'no_more_edges_repr',
    'edge_candidate_scorer',
    'edge_type_selector',
    'attachment_point_selector',
)
params['decoder'] = {key: params[key] for key in decoder_param_keys}

# Latent-space settings.
params['latent_sample_strategy'] = 'per_graph'
params['latent_repr_size'] = 512

In [2]:
import torch
params = {'full_graph_encoder': {'input_feature_dim': 32,
  'atom_or_motif_vocab_size': 139},
 'partial_graph_encoder': {'input_feature_dim': 32},
 'mean_log_var_mlp': {'input_feature_dim': 832, 'output_size': 1024},
 'decoder': {'node_type_selector': {'input_feature_dim': 1344,
   'output_size': 140},
  'node_type_loss_weights': torch.tensor([10.0000,  0.1000,  0.1000,  0.1000,  0.7879,  0.4924,  0.6060, 10.0000,
           7.8786, 10.0000,  7.8786,  0.1000,  0.6565,  0.6565,  0.9848,  0.8754,
           0.8754,  1.1255,  0.9848,  1.3131,  1.5757,  1.9696,  1.5757,  1.9696,
           2.6262,  1.9696,  1.9696,  7.8786,  7.8786,  3.9393,  2.6262,  2.6262,
           2.6262,  2.6262,  3.9393,  7.8786,  7.8786,  7.8786,  3.9393,  7.8786,
          10.0000,  7.8786,  3.9393,  3.9393,  3.9393,  3.9393,  3.9393,  3.9393,
           3.9393,  3.9393,  3.9393,  3.9393,  3.9393,  3.9393,  7.8786,  7.8786,
          10.0000, 10.0000,  7.8786,  7.8786, 10.0000,  7.8786,  7.8786, 10.0000,
           7.8786,  7.8786, 10.0000,  7.8786, 10.0000,  7.8786,  7.8786, 10.0000,
           7.8786,  7.8786,  7.8786, 10.0000, 10.0000,  7.8786,  7.8786,  7.8786,
           7.8786,  7.8786, 10.0000, 10.0000, 10.0000, 10.0000,  7.8786, 10.0000,
          10.0000, 10.0000,  7.8786, 10.0000,  7.8786, 10.0000,  7.8786, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,  7.8786,
          10.0000,  7.8786,  7.8786,  7.8786,  7.8786, 10.0000,  7.8786, 10.0000,
          10.0000, 10.0000,  7.8786,  7.8786,  7.8786,  7.8786,  7.8786,  7.8786,
           7.8786,  7.8786,  7.8786,  7.8786,  7.8786,  7.8786,  7.8786,  7.8786,
           7.8786,  7.8786,  7.8786,  7.8786,  7.8786,  7.8786,  7.8786,  7.8786,
           7.8786,  7.8786,  7.8786,  0.1000]),
  'no_more_edges_repr': (1, 835),
  'edge_candidate_scorer': {'input_feature_dim': 3011, 'output_size': 1},
  'edge_type_selector': {'input_feature_dim': 3011, 'output_size': 3},
  'attachment_point_selector': {'input_feature_dim': 2176, 'output_size': 1}},
 'latent_sample_strategy': 'per_graph',
 'latent_repr_dim': 512,
 'latent_repr_size': 512}

In [23]:
%load_ext autoreload
%autoreload 2
from model import BaseModel
model = BaseModel(params, dataset)

moler_output = model._run_step(batch)

loss = model.compute_loss(moler_output, batch)

loss

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


tensor(2.4433, grad_fn=<AddBackward0>)

In [19]:
model.step(batch)

(tensor(2.4733, grad_fn=<AddBackward0>),
 {'decoder_loss': tensor(2.4473, grad_fn=<AddBackward0>),
  'kl': tensor(0.0259, grad_fn=<MeanBackward0>),
  'loss': tensor(2.4733, grad_fn=<AddBackward0>)})

In [29]:
from pytorch_lightning import Trainer
from torch_geometric.data.lightning_datamodule import LightningDataset
# datamodule = LightningDataset(dataset)
trainer = Trainer(overfit_batches=1)
trainer.fit(model, loader, loader)
# trainer.fit(model, datamodule)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

  | Name                   | Type                | Params
---------------------------------------------------------------
0 | _full_graph_encoder    | GraphEncoder        | 233 K 
1 | _partial_graph_encoder | GenericGraphEncoder | 212 K 
2 | _mean_log_var_mlp      | GenericMLP          | 124 K 
3 | _decoder               | MLPDecoder          | 637 K 
---------------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.834     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

# LightningModule + Vae MLP

1. Implement KL divergence loss as part of the lightning module
2. Investigate where `node_type_predictor_class_loss_weight_factor` is supposed to come from; if its source cannot be determined, default it to 1

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
