# Design 

1. All model-related loss computations and helper functions live in torch modules (`torch.nn.Module` subclasses), and every loss method is named `compute_loss`.
2. The MLPDecoder will house all the MLPs involved in the decoder and will have one `compute_decoder_loss` method that calls each of the individual `compute_loss` methods.

It will also have a `decode` method that performs decoding at inference time (TODO).

3. The FullGraphEncoder and the PartialGraphEncoder will each be in their own torch modules

4. Finally, the lightning module will have 3 things: 
- FullGraphEncoder
- PartialGraphEncoder (part of decoder)
- MLPDecoder

After passing the input through the initial FullGraphEncoder: if we are working with a VAE, we will extract p and q to compute the KL-divergence loss; otherwise we will do the other model-specific steps (e.g. diffusion).

A `params` dictionary will be passed to the lightning module, and each torch module will be constructed inside it from the relevant entry by unpacking that entry as keyword arguments (e.g. `GraphEncoder(**params['full_graph_encoder'])`).

Node-type class weights will be instantiated in the lightning module and passed to the decoder.

# TODO
1. fix the incrementing in the original graph edge index (DONE)
2. Work on first node prediction 
3. Investigate node_type_predictor_class_loss_weight_factor

In [10]:
# `params` collects the constructor keyword arguments of every model component,
# keyed by component name; the cells below add one entry per component.
params = {}

In [160]:
%load_ext autoreload
%autoreload 2

from dataset import MolerDataset, MolerData
from utils import pprint_pyg_obj
from torch_geometric.loader import DataLoader


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [161]:
# Training split of the trace dataset. The recorded output of this cell shows
# that construction ran the processing step, emitting one `.pt` path per
# (molecule, generation step) under the output folder.
dataset = MolerDataset(
    root='/data/ongh0068',
    raw_moler_trace_dataset_parent_folder='/data/ongh0068/l1000/trace_playground',
    output_pyg_trace_dataset_parent_folder='/data/ongh0068/l1000/pyg_output_playground',
    split='train',
)

Processing...


file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_0.pt
Processing 0, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_1.pt
Processing 0, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_2.pt
Processing 0, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_3.pt
Processing 0, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_4.pt
Processing 0, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_5.pt
Processing 0, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_6.pt
Processing 0, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_7.pt
Processing 0, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_0_step_8.pt
Processing 0, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train

Processing 1, step 46
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_1_step_47.pt
Processing 1, step 47
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_0.pt
Processing 2, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_1.pt
Processing 2, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_2.pt
Processing 2, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_3.pt
Processing 2, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_4.pt
Processing 2, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_5.pt
Processing 2, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_6.pt
Processing 2, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_2_step_7.pt
Processing 2, step 7
file_path /data/ongh0068/l1000/pyg_outpu

Processing 4, step 16
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_4_step_17.pt
Processing 4, step 17
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_4_step_18.pt
Processing 4, step 18
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_4_step_19.pt
Processing 4, step 19
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_4_step_20.pt
Processing 4, step 20
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_0.pt
Processing 5, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_1.pt
Processing 5, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_2.pt
Processing 5, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_3.pt
Processing 5, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_5_step_4.pt
Processing 5, step 4
file_path /data/ongh0068/l1000/pyg

Processing 9, step 22
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_9_step_23.pt
Processing 9, step 23
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_9_step_24.pt
Processing 9, step 24
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_10_step_0.pt
Processing 10, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_10_step_1.pt
Processing 10, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_10_step_2.pt
Processing 10, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_10_step_3.pt
Processing 10, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_10_step_4.pt
Processing 10, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_10_step_5.pt
Processing 10, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_10_step_6.pt
Processing 10, step 6
file_path /data/ongh0068

Processing 12, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_12_step_7.pt
Processing 12, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_12_step_8.pt
Processing 12, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_12_step_9.pt
Processing 12, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_12_step_10.pt
Processing 12, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_12_step_11.pt
Processing 12, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_12_step_12.pt
Processing 12, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_12_step_13.pt
Processing 12, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_12_step_14.pt
Processing 12, step 14
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_13_step_0.pt
Processing 13, step 0
file_path /dat

Processing 15, step 17
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_18.pt
Processing 15, step 18
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_19.pt
Processing 15, step 19
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_20.pt
Processing 15, step 20
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_21.pt
Processing 15, step 21
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_22.pt
Processing 15, step 22
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_23.pt
Processing 15, step 23
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_24.pt
Processing 15, step 24
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_25.pt
Processing 15, step 25
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_15_step_26.pt
Processing 15, step 26
file_

Processing 20, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_3.pt
Processing 20, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_4.pt
Processing 20, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_5.pt
Processing 20, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_6.pt
Processing 20, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_7.pt
Processing 20, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_8.pt
Processing 20, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_9.pt
Processing 20, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_10.pt
Processing 20, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_20_step_11.pt
Processing 20, step 11
file_path /data/ongh

Processing 23, step 17
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_23_step_18.pt
Processing 23, step 18
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_23_step_19.pt
Processing 23, step 19
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_23_step_20.pt
Processing 23, step 20
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_23_step_21.pt
Processing 23, step 21
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_0.pt
Processing 24, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_1.pt
Processing 24, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_2.pt
Processing 24, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_3.pt
Processing 24, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_24_step_4.pt
Processing 24, step 4
file_path /data

Processing 28, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_2.pt
Processing 28, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_3.pt
Processing 28, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_4.pt
Processing 28, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_5.pt
Processing 28, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_6.pt
Processing 28, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_7.pt
Processing 28, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_8.pt
Processing 28, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_9.pt
Processing 28, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_28_step_10.pt
Processing 28, step 10
file_path /data/ongh00

Processing 31, step 20
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_21.pt
Processing 31, step 21
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_22.pt
Processing 31, step 22
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_23.pt
Processing 31, step 23
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_24.pt
Processing 31, step 24
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_25.pt
Processing 31, step 25
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_26.pt
Processing 31, step 26
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_27.pt
Processing 31, step 27
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_28.pt
Processing 31, step 28
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_31_step_29.pt
Processing 31, step 29
file_

Processing 35, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_35_step_11.pt
Processing 35, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_35_step_12.pt
Processing 35, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_35_step_13.pt
Processing 35, step 13
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_35_step_14.pt
Processing 35, step 14
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_36_step_0.pt
Processing 36, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_36_step_1.pt
Processing 36, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_36_step_2.pt
Processing 36, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_36_step_3.pt
Processing 36, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_36_step_4.pt
Processing 36, step 4
file_path /data

Processing 39, step 37
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_39_step_38.pt
Processing 39, step 38
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_40_step_0.pt
Processing 40, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_40_step_1.pt
Processing 40, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_40_step_2.pt
Processing 40, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_40_step_3.pt
Processing 40, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_41_step_0.pt
Processing 41, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_41_step_1.pt
Processing 41, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_41_step_2.pt
Processing 41, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_41_step_3.pt
Processing 41, step 3
file_path /data/ongh0

Processing 44, step 26
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_44_step_27.pt
Processing 44, step 27
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_44_step_28.pt
Processing 44, step 28
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_44_step_29.pt
Processing 44, step 29
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_44_step_30.pt
Processing 44, step 30
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_44_step_31.pt
Processing 44, step 31
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_44_step_32.pt
Processing 44, step 32
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_44_step_33.pt
Processing 44, step 33
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_44_step_34.pt
Processing 44, step 34
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_44_step_35.pt
Processing 44, step 35
file_

Processing 47, step 34
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_47_step_35.pt
Processing 47, step 35
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_47_step_36.pt
Processing 47, step 36
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_47_step_37.pt
Processing 47, step 37
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_47_step_38.pt
Processing 47, step 38
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_47_step_39.pt
Processing 47, step 39
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_47_step_40.pt
Processing 47, step 40
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_47_step_41.pt
Processing 47, step 41
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_47_step_42.pt
Processing 47, step 42
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_47_step_43.pt
Processing 47, step 43
file_

Processing 51, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_51_step_7.pt
Processing 51, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_51_step_8.pt
Processing 51, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_51_step_9.pt
Processing 51, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_51_step_10.pt
Processing 51, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_51_step_11.pt
Processing 51, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_51_step_12.pt
Processing 51, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_52_step_0.pt
Processing 52, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_52_step_1.pt
Processing 52, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_52_step_2.pt
Processing 52, step 2
file_path /data/on

Processing 53, step 39
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_53_step_40.pt
Processing 53, step 40
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_53_step_41.pt
Processing 53, step 41
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_53_step_42.pt
Processing 53, step 42
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_53_step_43.pt
Processing 53, step 43
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_53_step_44.pt
Processing 53, step 44
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_53_step_45.pt
Processing 53, step 45
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_53_step_46.pt
Processing 53, step 46
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_53_step_47.pt
Processing 53, step 47
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_54_step_0.pt
Processing 54, step 0
file_pa

Processing 56, step 29
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_56_step_30.pt
Processing 56, step 30
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_56_step_31.pt
Processing 56, step 31
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_56_step_32.pt
Processing 56, step 32
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_56_step_33.pt
Processing 56, step 33
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_56_step_34.pt
Processing 56, step 34
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_56_step_35.pt
Processing 56, step 35
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_56_step_36.pt
Processing 56, step 36
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_56_step_37.pt
Processing 56, step 37
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_56_step_38.pt
Processing 56, step 38
file_

Processing 59, step 24
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_59_step_25.pt
Processing 59, step 25
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_59_step_26.pt
Processing 59, step 26
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_60_step_0.pt
Processing 60, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_60_step_1.pt
Processing 60, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_60_step_2.pt
Processing 60, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_60_step_3.pt
Processing 60, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_60_step_4.pt
Processing 60, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_60_step_5.pt
Processing 60, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_60_step_6.pt
Processing 60, step 6
file_path /data/ong

Processing 62, step 20
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_62_step_21.pt
Processing 62, step 21
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_62_step_22.pt
Processing 62, step 22
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_62_step_23.pt
Processing 62, step 23
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_62_step_24.pt
Processing 62, step 24
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_62_step_25.pt
Processing 62, step 25
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_62_step_26.pt
Processing 62, step 26
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_62_step_27.pt
Processing 62, step 27
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_62_step_28.pt
Processing 62, step 28
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_62_step_29.pt
Processing 62, step 29
file_

Processing 67, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_67_step_4.pt
Processing 67, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_67_step_5.pt
Processing 67, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_67_step_6.pt
Processing 67, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_67_step_7.pt
Processing 67, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_67_step_8.pt
Processing 67, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_67_step_9.pt
Processing 67, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_67_step_10.pt
Processing 67, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_67_step_11.pt
Processing 67, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_67_step_12.pt
Processing 67, step 12
file_path /data/on

Processing 70, step 16
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_70_step_17.pt
Processing 70, step 17
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_71_step_0.pt
Processing 71, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_71_step_1.pt
Processing 71, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_71_step_2.pt
Processing 71, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_71_step_3.pt
Processing 71, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_71_step_4.pt
Processing 71, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_71_step_5.pt
Processing 71, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_71_step_6.pt
Processing 71, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_71_step_7.pt
Processing 71, step 7
file_path /data/ongh0

Processing 73, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_73_step_4.pt
Processing 73, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_73_step_5.pt
Processing 73, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_73_step_6.pt
Processing 73, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_73_step_7.pt
Processing 73, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_73_step_8.pt
Processing 73, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_73_step_9.pt
Processing 73, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_73_step_10.pt
Processing 73, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_73_step_11.pt
Processing 73, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_73_step_12.pt
Processing 73, step 12
file_path /data/on

Processing 76, step 23
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_24.pt
Processing 76, step 24
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_25.pt
Processing 76, step 25
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_26.pt
Processing 76, step 26
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_27.pt
Processing 76, step 27
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_28.pt
Processing 76, step 28
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_29.pt
Processing 76, step 29
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_30.pt
Processing 76, step 30
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_31.pt
Processing 76, step 31
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_76_step_32.pt
Processing 76, step 32
file_

Processing 79, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_79_step_1.pt
Processing 79, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_79_step_2.pt
Processing 79, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_79_step_3.pt
Processing 79, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_79_step_4.pt
Processing 79, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_79_step_5.pt
Processing 79, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_79_step_6.pt
Processing 79, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_79_step_7.pt
Processing 79, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_79_step_8.pt
Processing 79, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_79_step_9.pt
Processing 79, step 9
file_path /data/ongh0068

Processing 84, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_84_step_5.pt
Processing 84, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_84_step_6.pt
Processing 84, step 6
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_84_step_7.pt
Processing 84, step 7
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_84_step_8.pt
Processing 84, step 8
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_84_step_9.pt
Processing 84, step 9
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_84_step_10.pt
Processing 84, step 10
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_84_step_11.pt
Processing 84, step 11
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_84_step_12.pt
Processing 84, step 12
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_84_step_13.pt
Processing 84, step 13
file_path /data/

Processing 86, step 52
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_86_step_53.pt
Processing 86, step 53
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_86_step_54.pt
Processing 86, step 54
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_87_step_0.pt
Processing 87, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_87_step_1.pt
Processing 87, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_87_step_2.pt
Processing 87, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_87_step_3.pt
Processing 87, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_87_step_4.pt
Processing 87, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_87_step_5.pt
Processing 87, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_87_step_6.pt
Processing 87, step 6
file_path /data/ong

Processing 91, step 24
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_91_step_25.pt
Processing 91, step 25
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_91_step_26.pt
Processing 91, step 26
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_92_step_0.pt
Processing 92, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_92_step_1.pt
Processing 92, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_92_step_2.pt
Processing 92, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_92_step_3.pt
Processing 92, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_92_step_4.pt
Processing 92, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_92_step_5.pt
Processing 92, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_92_step_6.pt
Processing 92, step 6
file_path /data/ong

Processing 94, step 28
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_94_step_29.pt
Processing 94, step 29
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_94_step_30.pt
Processing 94, step 30
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_95_step_0.pt
Processing 95, step 0
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_95_step_1.pt
Processing 95, step 1
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_95_step_2.pt
Processing 95, step 2
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_95_step_3.pt
Processing 95, step 3
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_95_step_4.pt
Processing 95, step 4
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_95_step_5.pt
Processing 95, step 5
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_95_step_6.pt
Processing 95, step 6
file_path /data/ong

Processing 98, step 29
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_30.pt
Processing 98, step 30
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_31.pt
Processing 98, step 31
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_32.pt
Processing 98, step 32
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_33.pt
Processing 98, step 33
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_34.pt
Processing 98, step 34
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_35.pt
Processing 98, step 35
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_36.pt
Processing 98, step 36
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_37.pt
Processing 98, step 37
file_path /data/ongh0068/l1000/pyg_output_playground/train/train_0_mol_98_step_38.pt
Processing 98, step 38
file_

Done!


In [162]:
# Unshuffled mini-batches of 16 samples. Every attribute listed in
# `follow_batch` additionally gets a `<name>_batch` assignment vector on the
# collated batch (e.g. `original_graph_x_batch`, used when encoding below).
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
    follow_batch=[
        'correct_edge_choices',
        'correct_edge_types',
        'valid_edge_choices',
        'valid_attachment_point_choices',
        'correct_attachment_point_choice',
        'correct_node_type_choices',
        'original_graph_x',
    ],
)

In [4]:
# Pull only the first batch from the loader; `batch` stays bound after the
# loop and serves as the sample input for the cells below.
for batch in loader:
    break

# FullGraphEncoder

In [5]:
from model_utils import GenericGraphEncoder
import torch

In [6]:
class GraphEncoder(torch.nn.Module):
    """Encode whole molecules into one graph-level vector each.

    A learned embedding of every node's categorical (atom / motif) id is
    concatenated onto that node's feature vector, and the result is passed
    through a ``GenericGraphEncoder``. Only the first of the encoder's two
    outputs is returned.

    NOTE(review): ``hidden_layer_feature_dim``, ``num_layers``, ``layer_type``
    and ``use_intermediate_gnn_results`` are accepted but never forwarded to
    ``GenericGraphEncoder``, so they currently have no effect.
    """

    def __init__(
        self,
        input_feature_dim,
        atom_or_motif_vocab_size,
        motif_embedding_size = 64,
        hidden_layer_feature_dim=64,
        num_layers=12,
        layer_type="RGATConv",
        use_intermediate_gnn_results=True,
    ):
        super().__init__()
        # The GNN sees the raw node features plus the appended id embedding.
        gnn_input_dim = motif_embedding_size + input_feature_dim
        self._embed = torch.nn.Embedding(atom_or_motif_vocab_size, motif_embedding_size)
        self._model = GenericGraphEncoder(input_feature_dim = gnn_input_dim)

    def forward(self, original_graph_node_categorical_features, node_features, edge_index, edge_type, batch_index):
        """Return the graph-level representation of every molecule in the batch."""
        id_embeddings = self._embed(original_graph_node_categorical_features)
        gnn_inputs = torch.cat([node_features, id_embeddings], dim=-1)
        graph_level, _ = self._model(gnn_inputs, edge_index.long(), edge_type, batch_index)
        return graph_level

In [11]:
# Record the constructor arguments in `params` so the exact same values can be
# handed to the lightning module later.
params['full_graph_encoder'] = {
    'input_feature_dim': batch.x.shape[-1],
    'atom_or_motif_vocab_size': len(dataset.node_type_index_to_string)
}

# Build the encoder once, directly from `params`. The previous explicit
# construction repeated the same arguments and was immediately overwritten.
full_graph_encoder = GraphEncoder(**params['full_graph_encoder'])

In [12]:
# Encode the original graphs of the batch. `original_graph_x_batch` (created by
# `follow_batch` in the loader) is passed as the node-to-graph index.
input_molecule_representations = full_graph_encoder(
    batch.original_graph_node_categorical_features, 
    batch.original_graph_x.float(),
    batch.original_graph_edge_index,
    batch.original_graph_edge_type,
    batch_index = batch.original_graph_x_batch,
)

# PartialGraphEncoder

In [13]:
# Record the constructor arguments in `params` so the lightning module can
# rebuild the same encoder.
params['partial_graph_encoder'] = {
    'input_feature_dim': batch.x.shape[-1],
}

# Build the encoder once, directly from `params`. The previous explicit
# construction repeated the same argument and was immediately overwritten.
partial_graph_encoder = GenericGraphEncoder(**params['partial_graph_encoder'])

In [14]:
partial_graph_representions, node_representations = partial_graph_encoder(batch.x, batch.edge_index.long(), batch.edge_type, batch.batch)

In [15]:
node_representations.shape

torch.Size([193, 832])

# _mean_log_var_mlp

In [16]:
from model_utils import GenericMLP
latent_dim = 512
# The MLP maps each graph-level representation to 2 * latent_dim values; the cells
# below slice the first half as the mean and the second half as the log-variance.
params['mean_log_var_mlp'] = {
    'input_feature_dim': input_molecule_representations.shape[-1],
    'output_size': latent_dim * 2
}


mean_log_var_mlp = GenericMLP(**params['mean_log_var_mlp'])

In [17]:
mean_and_log_var = mean_log_var_mlp(input_molecule_representations)

In [18]:
# Split the MLP output into its mean and log-variance halves.
mu = mean_and_log_var[:, : latent_dim]  # Shape: [num_graphs_in_batch, latent_dim]
log_var = mean_and_log_var[:, latent_dim :]  # Shape: [num_graphs_in_batch, latent_dim]

# p is a standard normal with the same shape as mu; q is Normal(mu, std).
# z is drawn from q with rsample(), so gradients flow back into mu and std.
std = torch.exp(log_var / 2)
p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
q = torch.distributions.Normal(mu, std)
z = q.rsample()  # Shape: [num_graphs_in_batch, latent_dim] ([16, 512] for this batch)

In [19]:
z.shape

torch.Size([16, 512])

# Decoder

In [106]:
from decoder import MLPDecoder

## PickAtomOrMotif

In [44]:
from molecule_generation.utils.training_utils import get_class_balancing_weights


# Per-class weights for the next-node-type loss, derived from how often each node
# type occurs as the "next node" in the training traces.
next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
class_weight_factor = params.get("node_type_predictor_class_loss_weight_factor", 1.0)

if not (0.0 <= class_weight_factor <= 1.0):
    # The message names the key that is actually read from `params`.
    raise ValueError(
        f"Node class loss weight node_type_predictor_class_loss_weight_factor must be in [0,1], but is {class_weight_factor}!"
    )
if class_weight_factor > 0:
    atom_type_nums = [
        next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
        for type_idx in range(dataset.num_node_types)
    ]
    # The "None" entry is appended as one extra class after the regular node types.
    atom_type_nums.append(next_node_type_distribution["None"])

    class_weights = get_class_balancing_weights(
        class_counts=atom_type_nums, class_weight_factor=class_weight_factor
    )
else:
    class_weights = None

# `torch.tensor(None)` raises, so only convert when weights were actually computed.
# NOTE(review): confirm that MLPDecoder accepts `None` here as "no class weighting".
params['node_type_loss_weights'] = (
    torch.tensor(class_weights) if class_weights is not None else None
)

In [45]:
from model_utils import GenericMLP
# The selector scores `num_node_types + 1` classes: one extra output on top of the
# regular node types, matching the extra "None" class in the loss weights.
params['node_type_selector'] = {
    'input_feature_dim':  z.shape[-1] + partial_graph_representions.shape[-1], 
    'output_size': dataset.num_node_types + 1
}


# Ids of the graphs in the batch that carry node-type labels.
graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique()

# Build the MLP once, from `params`. The previous explicit construction used
# `output_size = dataset.num_node_types` (missing the extra class) and was
# immediately overwritten.
node_type_selector = GenericMLP(**params['node_type_selector'])

### node loss computation in the forward method

In [91]:
decoder = MLPDecoder(params)

In [92]:
# Node-type logits, restricted to the graphs that have node-type labels
# (the unique graph ids found in `correct_node_type_choices_batch`).
node_logits = decoder.pick_node_type(
    z,
    partial_graph_representions,
    graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique()
)

In [267]:
# `_ptr` holds the start/end offsets of each graph's slice of labels. A graph without
# labels repeats the previous offset, so `unique()` drops it; the number of unique
# offsets minus one is therefore the number of graphs that do have labels.
num_correct_node_type_choices = batch.correct_node_type_choices_ptr.unique().shape[-1] -1
# One row of labels per labelled graph.
# NOTE(review): `view` assumes every labelled graph contributes the same number of
# entries — confirm against the dataset.
node_type_multihot_labels = batch.correct_node_type_choices.view(num_correct_node_type_choices, -1)

In [93]:
# node_type_multihot_labels = []
# for i in range(len(batch.correct_node_type_choices_ptr)-1):
#     start_idx = batch.correct_node_type_choices_ptr[i]
#     end_idx = batch.correct_node_type_choices_ptr[i+1] 
#     if end_idx - start_idx == 0:
#         continue
#     node_selection_labels = batch.correct_node_type_choices[start_idx: end_idx]
#     node_type_multihot_labels += [node_selection_labels]
    
# node_type_multihot_labels = torch.stack(node_type_multihot_labels, axis = 0)

In [268]:
# Loss of the node-type logits against the per-graph label rows built above.
node_type_selection_loss = decoder.compute_node_type_selection_loss(
    node_logits,
    node_type_multihot_labels
)

# PickEdge

In [226]:

# Shape of the learned "no more edges" row: one vector as wide as a node
# representation concatenated with the edge features.
params['no_more_edges_repr'] = (1,node_representations.shape[-1] + batch.edge_features.shape[-1])
# NOTE(review): 3011 is hard-coded. It is presumably 512 (z) + 832 (partial graph) +
# 2 * 832 (two node representations) + 3 (edge features) — confirm against MLPDecoder.
params['edge_candidate_scorer'] = {
    'input_feature_dim': 3011,
    'output_size': 1
}

params['edge_type_selector'] = {
    'input_feature_dim': 3011,
    'output_size': 3
}


# `torch.FloatTensor(*shape)` returns uninitialised memory, which may contain
# NaN/inf and differs between runs. Allocate the tensor and initialise it
# explicitly before wrapping it in a trainable parameter.
_no_more_edges_representation = torch.nn.Parameter(
    torch.nn.init.xavier_uniform_(torch.empty(*params['no_more_edges_repr'])),
    requires_grad = True,
)
_edge_candidate_scorer = GenericMLP(**params['edge_candidate_scorer'])
_edge_type_selector = GenericMLP(**params['edge_type_selector'])

In [150]:
from decoder import MLPDecoder
decoder = MLPDecoder(params)
# Score every candidate edge and predict its type.
# NOTE(review): this call passes `input_molecule_representations`, whereas the other
# decoder calls in this notebook pass the latent sample `z` in that position — confirm
# which one is intended.
edge_candidate_logits, edge_type_logits = decoder.pick_edge(
    input_molecule_representations,
    partial_graph_representions,
    node_representations,
    # `ptr` has one more entry than there are graphs in the batch.
    num_graphs_in_batch = len(batch.ptr) - 1,
    graph_to_focus_node_map= batch.focus_node,
    node_to_graph_map=batch.batch,
    # Second column of `valid_edge_choices` is passed as the candidate target node.
    candidate_edge_targets= batch.valid_edge_choices[:, 1].long(),
    candidate_edge_features= batch.edge_features
)
# Edge-candidate selection loss for the logits computed above.
decoder.compute_edge_candidate_selection_loss(
    num_graphs_in_batch= len(batch.ptr)-1,
    node_to_graph_map=batch.batch,
    candidate_edge_targets= batch.valid_edge_choices[:, 1].long(),
    edge_candidate_logits = edge_candidate_logits, # as is
    per_graph_num_correct_edge_choices= batch.num_correct_edge_choices,
    edge_candidate_correctness_labels = batch.correct_edge_choices,
    no_edge_selected_labels = batch.stop_node_label,
)

tensor(1.5850, grad_fn=<DivBackward0>)

## pick attachment point

In [163]:
# Collect only the batches that contain at least one attachment-point label.
tmp = []
for batch2 in loader:
    if len(batch2.correct_attachment_point_choice) == 0:
        continue
    tmp.append(batch2)

In [164]:
sample_idx = 1

tmp[sample_idx].correct_attachment_point_choice, tmp[sample_idx].valid_attachment_point_choices, tmp[sample_idx].valid_attachment_point_choices_batch

(tensor([0., 4.]),
 tensor([133., 135., 140., 141., 200., 203., 204., 205.], dtype=torch.float64),
 tensor([ 8,  8,  8,  8, 13, 13, 13, 13]))

In [171]:
batch2 = tmp[sample_idx]

In [231]:
# 2176 is hard-coded; it matches the last dimension of the scorer input printed by
# `pick_attachment_point` in the cell output further down (torch.Size([8, 2176])).
params['attachment_point_selector'] = {
    'input_feature_dim': 2176,
    'output_size': 1
}
# Module-level MLP used by the standalone `pick_attachment_point` function below.
_attachment_point_selector = GenericMLP(**params['attachment_point_selector'])
def pick_attachment_point(
    input_molecule_representations,  # per-graph input representation
    partial_graph_representions,  # per-graph partial-graph representation
    node_representations,  # per-node representation
    node_to_graph_map,  # batch.batch
    candidate_attachment_points,  # valid_attachment_point_choices (node indices)
):
    """Score every candidate attachment point.

    Each candidate is described by the representations of the graph it belongs
    to (input-molecule and partial-graph, concatenated) together with the
    candidate node's own representation. That vector is fed to the module-level
    ``_attachment_point_selector``.

    Returns:
        One logit per candidate attachment point. Shape: [CA]
    """
    # Per-graph conditioning vector. Shape: [PG, MD + PD]
    graph_context = torch.cat(
        (input_molecule_representations, partial_graph_representions),
        dim=-1,
    )

    # Graph id of every candidate. Shape: [CA]
    candidate_graph_ids = node_to_graph_map[candidate_attachment_points]

    # Graph context of the candidate's graph followed by the candidate node itself.
    scorer_inputs = torch.cat(
        (
            graph_context[candidate_graph_ids],
            node_representations[candidate_attachment_points],
        ),
        dim=-1,
    )  # Shape: [CA, MD + PD + VD*(num_layers+1)]
    print(scorer_inputs.shape)

    candidate_scores = _attachment_point_selector(scorer_inputs)
    # Drop the trailing size-1 output dimension.
    return candidate_scores.squeeze(dim=-1)

In [186]:
# Re-run both encoders on `batch2` (a batch that contains attachment-point labels),
# overwriting the representations computed earlier for `batch`.
input_molecule_representations = full_graph_encoder(
    batch2.original_graph_node_categorical_features, 
    batch2.original_graph_x.float(),
    batch2.original_graph_edge_index,
    batch2.original_graph_edge_type,
    batch_index = batch2.original_graph_x_batch,
)


# NOTE(review): `edge_type` is cast with `.int()` here but is passed uncast in the
# earlier call for `batch` — confirm which dtype the encoder expects.
partial_graph_representions, node_representations = partial_graph_encoder(batch2.x, batch2.edge_index.long(), batch2.edge_type.int(), batch2.batch)

In [190]:
batch2.valid_attachment_point_choices_batch

tensor([ 8,  8,  8,  8, 13, 13, 13, 13])

In [271]:

# The latent sample `z` is passed in the `input_molecule_representations` position.
# NOTE(review): `z` is only computed in the earlier cell for `batch`, not recomputed
# for `batch2` in the visible cells — confirm this pairing is intended.
attachment_point_selection_logits = pick_attachment_point(
    z, # as is
    partial_graph_representions, # partial_graph_representions
    node_representations, #as is
    node_to_graph_map= batch2.batch,
    candidate_attachment_points = batch2.valid_attachment_point_choices.long()
)

torch.Size([8, 2176])


In [199]:
def compute_attachment_point_selection_loss(
    attachment_point_selection_logits,  # one logit per candidate
    attachment_point_candidate_to_graph_map,  # e.g. batch2.valid_attachment_point_choices_batch.long()
    attachment_point_correct_choices,  # indices of the correct candidates
):
    """Mean negative log-probability of the correct attachment points.

    The logits are log-softmaxed within each graph (candidates are grouped by
    ``attachment_point_candidate_to_graph_map``). The log-probabilities of the
    correct candidates are then negated, summed and divided by their count via
    ``safe_divide_loss``.
    """
    per_candidate_logprobs = traced_unsorted_segment_log_softmax(
        logits=attachment_point_selection_logits,
        segment_ids=attachment_point_candidate_to_graph_map,
    )
    # Multiplying by 1.0 leaves the values unchanged; kept from the original.
    per_candidate_logprobs = per_candidate_logprobs * 1.0  # Shape: [CA]

    # Negative log-probability of each correct choice. Shape: [AP]
    correct_choice_nll = -per_candidate_logprobs[attachment_point_correct_choices]

    num_correct_choices = correct_choice_nll.shape[0]
    return safe_divide_loss(correct_choice_nll.sum(), num_correct_choices)

In [200]:
# Loss for `batch2`: candidates are grouped by graph through the `_batch` vector that
# `follow_batch` created for `valid_attachment_point_choices`.
compute_attachment_point_selection_loss(
    attachment_point_selection_logits =  attachment_point_selection_logits,
    attachment_point_candidate_to_graph_map = batch2.valid_attachment_point_choices_batch.long(),
    attachment_point_correct_choices = batch2.correct_attachment_point_choice.long()
)

tensor(1.3053, grad_fn=<DivBackward0>)

In [270]:
params

{'full_graph_encoder': {'input_feature_dim': 32,
  'atom_or_motif_vocab_size': 139},
 'partial_graph_encoder': {'input_feature_dim': 32},
 'mean_log_var_mlp': {'input_feature_dim': 832, 'output_size': 1024},
 'node_type_loss_weights': tensor([10.0000,  0.1000,  0.1000,  0.1000,  0.7879,  0.4924,  0.6060, 10.0000,
          7.8786, 10.0000,  7.8786,  0.1000,  0.6565,  0.6565,  0.9848,  0.8754,
          0.8754,  1.1255,  0.9848,  1.3131,  1.5757,  1.9696,  1.5757,  1.9696,
          2.6262,  1.9696,  1.9696,  7.8786,  7.8786,  3.9393,  2.6262,  2.6262,
          2.6262,  2.6262,  3.9393,  7.8786,  7.8786,  7.8786,  3.9393,  7.8786,
         10.0000,  7.8786,  3.9393,  3.9393,  3.9393,  3.9393,  3.9393,  3.9393,
          3.9393,  3.9393,  3.9393,  3.9393,  3.9393,  3.9393,  7.8786,  7.8786,
         10.0000, 10.0000,  7.8786,  7.8786, 10.0000,  7.8786,  7.8786, 10.0000,
          7.8786,  7.8786, 10.0000,  7.8786, 10.0000,  7.8786,  7.8786, 10.0000,
          7.8786,  7.8786,  7.8786, 1

In [235]:
from decoder import MLPDecoder
# Same attachment-point computation as the standalone functions above, now through
# the MLPDecoder methods (a freshly constructed decoder, so the loss value differs).
decoder = MLPDecoder(params)
attachment_point_selection_logits = decoder.pick_attachment_point(
    z, # as is
    partial_graph_representions, # partial_graph_representions
    node_representations, #as is
    node_to_graph_map= batch2.batch,
    candidate_attachment_points = batch2.valid_attachment_point_choices.long()
)
decoder.compute_attachment_point_selection_loss(
    attachment_point_selection_logits =  attachment_point_selection_logits,
    attachment_point_candidate_to_graph_map = batch2.valid_attachment_point_choices_batch.long(),
    attachment_point_correct_choices = batch2.correct_attachment_point_choice.long()
)

tensor(1.4311, grad_fn=<DivBackward0>)

In [241]:

# Full decoder pass: one call returns the node-type, edge-candidate, edge-type and
# attachment-point logits.
# NOTE(review): all batch-dependent arguments come from `batch`, while
# `partial_graph_representions` / `node_representations` were last recomputed for
# `batch2` in the cells above — re-run the encoders on `batch` before this cell.
node_logits,edge_candidate_logits,edge_type_logits,attachment_point_selection_logits = decoder(
    input_molecule_representations = z,
    graph_representations = partial_graph_representions,
    graphs_requiring_node_choices = batch.correct_node_type_choices_batch.unique(),
    # edge selection
    node_representations = node_representations,
    num_graphs_in_batch = len(batch.ptr) -1,
    graph_to_focus_node_map =batch.focus_node,
    node_to_graph_map = batch.batch,
    candidate_edge_targets = batch.valid_edge_choices[:, 1].long(),
    candidate_edge_features = batch.edge_features,
    # attachment selection
    candidate_attachment_points = batch.valid_attachment_point_choices.long(),
)


# Combined decoder loss over the three heads whose logits and labels are passed here
# (node type, edge candidate, attachment point).
loss = decoder.compute_decoder_loss(
    node_logits = node_logits,
    node_type_multihot_labels = node_type_multihot_labels,
    num_graphs_in_batch= len(batch.ptr)-1,
    node_to_graph_map=batch.batch,
    candidate_edge_targets= batch.valid_edge_choices[:, 1].long(),
    edge_candidate_logits = edge_candidate_logits, # as is
    per_graph_num_correct_edge_choices= batch.num_correct_edge_choices,
    edge_candidate_correctness_labels = batch.correct_edge_choices,
    no_edge_selected_labels = batch.stop_node_label,
    attachment_point_selection_logits =  attachment_point_selection_logits,
    attachment_point_candidate_to_graph_map = batch.valid_attachment_point_choices_batch.long(),
    attachment_point_correct_choices = batch.correct_attachment_point_choice.long()
)

In [242]:
loss

tensor(2.0889, grad_fn=<AddBackward0>)

# LightningModule + Vae MLP

1. Implement the KL divergence loss as part of the lightning module
2. Investigate where node_type_predictor_class_loss_weight_factor is supposed to come from, otherwise, default to 1

In [None]:
from molecule_generation.utils.training_utils import get_class_balancing_weights
from pytorch_lightning import LightningModule, Trainer, seed_everything



class BaseModel(LightningModule):
    def __init__(self, params, dataset):
        """Assemble the model from a nested ``params`` dictionary.

        ``params`` supplies the keyword arguments of the two graph encoders
        (``'full_graph_encoder'``, ``'partial_graph_encoder'``), the decoder
        configuration (``'decoder'``) and the latent-space settings
        (``'latent_sample_strategy'``, ``'latent_repr_size'``). ``dataset`` is
        only forwarded to ``_init_params``.
        """
        super().__init__()
        self._init_params(params, dataset)

        # Sub-modules, each built from its own entry of `params`.
        self._full_graph_encoder = GraphEncoder(**params['full_graph_encoder'])
        self._partial_graph_encoder = GenericGraphEncoder(**params['partial_graph_encoder'])
        self._decoder = MLPDecoder(params['decoder'])

        # Latent-space configuration.
        self._latent_sample_strategy = params['latent_sample_strategy']
        self._latent_repr_dim = params["latent_repr_size"]
        
        
               
    def _init_params(self, params, dataset):
        """
        Initialise class weights for next node prediction and placeholder for
        motif/node embeddings.

        Args:
            params: model configuration dictionary. Reads the optional keys
                ``node_type_predictor_class_loss_weight_factor`` (default 1.0)
                and ``categorical_features_embedding_dim``.
            dataset: dataset object providing ``metadata``,
                ``node_type_index_to_string``, ``num_node_types`` and
                ``node_categorical_num_classes``.

        Raises:
            ValueError: if the class weight factor lies outside [0, 1].
        """
        # The original body read `self._params`, which is never assigned anywhere;
        # store the argument so that attribute exists, and read from `params` directly.
        self._params = params

        # Get some information out from the dataset:
        next_node_type_distribution = dataset.metadata.get("train_next_node_type_distribution")
        class_weight_factor = params.get("node_type_predictor_class_loss_weight_factor", 1.0)

        if not (0.0 <= class_weight_factor <= 1.0):
            # The message names the key that is actually read from `params`.
            raise ValueError(
                f"Node class loss weight node_type_predictor_class_loss_weight_factor must be in [0,1], but is {class_weight_factor}!"
            )
        if class_weight_factor > 0:
            atom_type_nums = [
                next_node_type_distribution[dataset.node_type_index_to_string[type_idx]]
                for type_idx in range(dataset.num_node_types)
            ]
            # The "None" entry is appended as one extra class after the regular node types.
            atom_type_nums.append(next_node_type_distribution["None"])

            self.class_weights = get_class_balancing_weights(
                class_counts=atom_type_nums, class_weight_factor=class_weight_factor
            )
        else:
            # A factor of 0 means no class weights are computed.
            self.class_weights = None

        motif_vocabulary = dataset.metadata.get("motif_vocabulary")
        self._uses_motifs = motif_vocabulary is not None

        self._node_categorical_num_classes = dataset.node_categorical_num_classes

        # Placeholder only: the embedding layer itself is not created yet.
        if self.uses_categorical_features and "categorical_features_embedding_dim" in params:
            self._node_categorical_features_embedding = None
    @property
    def uses_motifs(self):
        """Whether the dataset metadata contained a ``motif_vocabulary`` entry (set in ``_init_params``)."""
        return self._uses_motifs

    @property
    def uses_categorical_features(self):
        """True when the dataset's ``node_categorical_num_classes`` (stored in ``_init_params``) is not ``None``."""
        return self._node_categorical_num_classes is not None

    @property
    def decoder(self):
        """The ``MLPDecoder`` instance created in ``__init__``."""
        return self._decoder

    @property
    def encoder(self):
        """The full-graph ``GraphEncoder`` created in ``__init__``."""
        # `__init__` stores the encoder as `_full_graph_encoder`; no `_encoder`
        # attribute is ever assigned, so the old lookup raised AttributeError.
        return self._full_graph_encoder
    
    @property
    def motif_aware_embedding_layer(self):
        """Return the motif-aware embedding layer.

        NOTE(review): ``_motif_aware_embedding_layer`` is not assigned anywhere in the
        visible part of this class, so this raises ``AttributeError`` unless it is set
        elsewhere — confirm.
        """
        return self._motif_aware_embedding_layer
    
    @property
    def latent_dim(self):
        """Size of the latent representation, read from ``params['latent_repr_size']`` in ``__init__``."""
        return self._latent_repr_dim
    
    def compute_initial_node_features(self, batch):
        """Placeholder for computing the initial node features of ``batch``.

        Not implemented yet; currently returns ``None``.
        """
        # TODO: compute embedding
        return None
        
    
    def sample_from_latent_repr(self, latent_repr):
        """Split ``latent_repr`` into mean / log-variance and draw a latent sample.

        Args:
            latent_repr: 2-D tensor whose first ``latent_dim`` columns are used as
                the mean and whose remaining columns are used as the log-variance.

        Returns:
            ``(p, q, z)`` as produced by ``sample``.
        """
        # perturb latent repr
        mu = latent_repr[:, : self.latent_dim]
        log_var = latent_repr[:, self.latent_dim :]

        # `self` was missing from the signature although the body relies on it.
        return self.sample(mu, log_var)
        
    def sample(self, mu, log_var):
        """Samples a different noise vector for each partial graph.

        Args:
            mu: mean of the distribution ``q``.
            log_var: log-variance of ``q``; the standard deviation is
                ``exp(log_var / 2)``.

        Returns:
            ``(p, q, z)`` where ``p`` is a standard normal with the same shape as
            ``mu``, ``q`` is ``Normal(mu, std)`` and ``z`` is a sample drawn from
            ``q`` with ``rsample()``.

        TODO: look into the other sampling strategies.
        """
        std = torch.exp(log_var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return p, q, z
    
    
    def forward(self, x, edge_index, edge_attr, batch, ??):
        """Draft forward pass: encode, sample a latent vector, decode.

        NOTE(review): this method is unfinished.
        - ``??`` in the signature is a placeholder and is not valid Python.
        - ``x``, ``edge_index`` and ``edge_attr`` are not used in the body.
        - The encoder is called with ``batch`` only and the decoder with the latent
          sample only, whereas the notebook cells above call both with several
          arguments — the call sites still need the full argument lists.
        - The encoder output is sliced into mean / log-variance directly; the
          notebook applies a separate ``mean_log_var_mlp`` first — confirm.
        - ``MoLeROutput`` is not defined or imported in the visible cells.
        """
        # Obtain node embeddings 
        
        # Forward pass through encoder
        latent_repr = self.encoder(batch)
        
        # Apply latent sampling strategy
        p, q, latent_repr = self.sample_from_latent_repr(latent_repr)
        
        # Forward pass through decoder
        node_type_logits, edge_candidate_logits, edge_type_logits, attachment_point_selection_logits = self.decoder(latent_repr)
        
        # NOTE: loss computation will be done in lightning module
        return MoLeROutput(
            node_type_logits = node_type_logits,
            edge_candidate_logits = edge_candidate_logits,
            edge_type_logits = edge_type_logits,
            attachment_point_selection_logits = attachment_point_selection_logits,
            p = p,
            q = q,
        )