In [5]:
# from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import softmax
from model_utils import GenericMLP
from utils import unsorted_segment_softmax
import torch
class WeightedSumGraphRepresentation(torch.nn.Module):
    """Multi-head weighted-sum readout from node features to graph features.

    Every node receives one scalar weight per head from a scoring MLP and a
    representation vector from a transformation MLP.  The representation is
    split evenly across heads, each slice is scaled by its head's weight, and
    the weighted node representations are summed per graph.
    """

    def __init__(
        self,
        input_feature_dim,
        graph_representation_size,
        num_heads,
        weighting_fun = "softmax",  # One of {"softmax", "sigmoid"}
        scoring_mlp_layers = [128],
        scoring_mlp_activation_fun = "ReLU",
        scoring_mlp_use_biases: bool = False,
        scoring_mlp_dropout_rate = 0.1,
        transformation_mlp_layers = [128],
        transformation_mlp_activation_fun = "ReLU",
        transformation_mlp_use_biases = False,
        transformation_mlp_dropout_rate = 0.2,
    ):
        """Build the scoring and transformation MLPs.

        Args:
            input_feature_dim: Size of each input node feature vector.
            graph_representation_size: Size of the output graph vector; must
                be divisible by ``num_heads``.
            num_heads: Number of weighting heads.
            weighting_fun: ``"softmax"`` or ``"sigmoid"`` (case-insensitive).
            scoring_mlp_dropout_rate: Dropout probability of the scoring MLP.
            transformation_mlp_dropout_rate: Dropout probability of the
                transformation MLP.
            scoring_mlp_layers, scoring_mlp_activation_fun,
            scoring_mlp_use_biases, transformation_mlp_layers,
            transformation_mlp_activation_fun, transformation_mlp_use_biases:
                Accepted for interface compatibility but currently ignored;
                both MLPs use one hidden layer of width 256 with
                ``"leaky_relu"``, and the transformation output always goes
                through a ReLU.

        Raises:
            AssertionError: If ``weighting_fun`` is not a supported value.
            ValueError: If ``num_heads`` is not positive or does not divide
                ``graph_representation_size``.
        """
        # Must run before any submodule is assigned, otherwise
        # torch.nn.Module refuses the attribute assignment.
        super().__init__()

        if num_heads <= 0 or graph_representation_size % num_heads != 0:
            raise ValueError(
                f"graph_representation_size ({graph_representation_size}) must be "
                f"divisible by a positive num_heads ({num_heads})"
            )

        self._num_heads = num_heads
        self._graph_representation_size = graph_representation_size
        self._weighting_fun = weighting_fun.lower()
        assert self._weighting_fun in ['softmax', 'sigmoid'], (
            f"unsupported weighting_fun: {weighting_fun!r}"
        )
        self._transformation_mlp_activation_fun = torch.nn.ReLU()
        self._scoring_mlp = GenericMLP(
            input_feature_dim = input_feature_dim,
            output_size = self._num_heads, # one score for each head
            hidden_layer_feature_dim=256,
            num_hidden_layers=1,
            activation_layer_type="leaky_relu",
            dropout_prob=scoring_mlp_dropout_rate,
        )
        self._transformation_mlp = GenericMLP(
            input_feature_dim = input_feature_dim,
            output_size = self._graph_representation_size, # split across heads in forward()
            hidden_layer_feature_dim=256,
            num_hidden_layers=1,
            activation_layer_type="leaky_relu",
            dropout_prob=transformation_mlp_dropout_rate,
        )

    def forward(
        self,
        x,
        batch = None,
    ):
        """Aggregate node features into one vector per graph.

        Args:
            x: Node features, shape ``[V, input_feature_dim]``.
            batch: Integer (long) tensor of shape ``[V]`` giving the graph
                index of each node.  ``None`` treats all nodes as belonging
                to a single graph.

        Returns:
            Tensor of shape ``[G, graph_representation_size]`` where ``G`` is
            ``batch.max() + 1`` (0 when there are no nodes).
        """
        if batch is None:
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)

        # (1) compute weights for each node/head pair:
        scores = self._scoring_mlp(x)  # Shape [V, H]
        if self._weighting_fun == 'sigmoid':
            weights = torch.sigmoid(scores)
        elif self._weighting_fun == 'softmax':
            # Normalise each head independently over the nodes of each graph.
            # NOTE(review): assumes unsorted_segment_softmax returns shape [V]
            # for [V] logits -- confirm against utils.
            weights_per_head = [
                unsorted_segment_softmax(
                    logits = scores[:, head_idx],
                    segment_ids = batch,
                )
                for head_idx in range(self._num_heads)
            ]
            # stack (not cat): cat on 1-D tensors would yield [V * H].
            weights = torch.stack(weights_per_head, dim = -1)  # Shape [V, H]
        else:
            raise NotImplementedError()

        # (2) compute representations for each node/head pair:
        node_reprs = self._transformation_mlp_activation_fun(self._transformation_mlp(x))
        # Shape [V, GD]
        head_dim = self._graph_representation_size // self._num_heads
        node_reprs = node_reprs.reshape(-1, self._num_heads, head_dim)  # Shape [V, H, GD//H]

        # (3) weight representations and aggregate by graph:
        weighted_node_reprs = weights.unsqueeze(-1) * node_reprs  # Shape [V, H, GD//H]
        weighted_node_reprs = weighted_node_reprs.reshape(-1, self._graph_representation_size)
        # Shape [V, GD]

        # Segment sum in plain torch: add each node's row into its graph's row.
        num_graphs = int(batch.max()) + 1 if batch.numel() > 0 else 0
        graph_reprs = weighted_node_reprs.new_zeros(
            (num_graphs, self._graph_representation_size)
        ).index_add(0, batch, weighted_node_reprs)  # Shape [G, GD]
        return graph_reprs

In [1]:
%load_ext autoreload
%autoreload 2

from model import BaseModel
from dataset import MolerDataset, MolerData
from utils import pprint_pyg_obj
from torch_geometric.loader import DataLoader
import torch
from model_utils import get_params



# Build the 'valid_0' split of the MoLeR trace dataset from the paths below.
dataset = MolerDataset(
    root='/data/ongh0068',
    raw_moler_trace_dataset_parent_folder='/data/ongh0068/guacamol/trace_dir',
    output_pyg_trace_dataset_parent_folder='/data/ongh0068/l1000/already_batched',
    split='valid_0',
)

# Attributes passed as `follow_batch`, so the loader also produces
# `<name>_batch` vectors for them.
follow_batch_keys = [
    'correct_edge_choices',
    'correct_edge_types',
    'valid_edge_choices',
    'valid_attachment_point_choices',
    'correct_attachment_point_choice',
    'correct_node_type_choices',
    'original_graph_x',
    'correct_first_node_type_choices',
]
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    follow_batch=follow_batch_keys,
)

# Keep only the first batch the loader yields; add `batch.cuda()` inside the
# loop to move it to the GPU.
for batch in loader:
    break

2023-01-16 14:42:30.509204: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
from model_utils import AggrLayerType

params = get_params(dataset)

# Switch both graph encoders to the MoLeR aggregation layer, each with its own
# total head count.
for encoder_name, total_heads in (
    ('full_graph_encoder', 32),
    ('partial_graph_encoder', 16),
):
    encoder_params = params[encoder_name]
    encoder_params['aggr_layer_type'] = 'MoLeRAggregation'
    encoder_params['total_num_moler_aggr_heads'] = total_heads

In [3]:
# Display the parameter dict to check the aggregation overrides set above.
params

{'full_graph_encoder': {'input_feature_dim': 59,
  'atom_or_motif_vocab_size': 166,
  'aggr_layer_type': 'MoLeRAggregation',
  'total_num_moler_aggr_heads': 32},
 'partial_graph_encoder': {'input_feature_dim': 59,
  'atom_or_motif_vocab_size': 166,
  'aggr_layer_type': 'MoLeRAggregation',
  'total_num_moler_aggr_heads': 16},
 'mean_log_var_mlp': {'input_feature_dim': 832, 'output_size': 1024},
 'decoder': {'node_type_selector': {'input_feature_dim': 1344,
   'output_size': 167},
  'node_type_loss_weights': tensor([10.0000,  0.1000,  3.6015,  0.1000,  0.1000,  0.4439,  0.7549,  0.4416,
          10.0000,  2.7939,  3.3916, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000

In [4]:
# Instantiate the model from the parameter dict and the dataset.
model = BaseModel(params, dataset)

In [6]:
# Display the loaded batch (attribute names and tensor shapes).
batch


MolerDataBatch(x=[15852, 59], edge_index=[2, 31368], original_graph_edge_features=[61982], original_graph_node_categorical_features=[28616], focus_node=[1000], partial_graph_edge_features=[31368], edge_features=[7104, 3], correct_edge_choices=[7104], correct_edge_choices_batch=[7104], correct_edge_choices_ptr=[1001], num_correct_edge_choices=[1000], stop_node_label=[1000], valid_edge_choices=[7104, 2], valid_edge_choices_batch=[7104], valid_edge_choices_ptr=[1001], valid_edge_types=[510, 3], correct_edge_types=[510, 3], correct_edge_types_batch=[510], correct_edge_types_ptr=[1001], partial_node_categorical_features=[15852], correct_attachment_point_choice=[42], correct_attachment_point_choice_batch=[42], correct_attachment_point_choice_ptr=[1001], correct_node_type_choices=[482, 166], correct_node_type_choices_batch=[482], correct_node_type_choices_ptr=[1001], correct_first_node_type_choices=[1000, 166], correct_first_node_type_choices_batch=[1000], correct_first_node_type_choices_ptr=

In [5]:
# Run a single model step on the loaded batch and display the returned output.
# NOTE(review): `_run_step` is a private method; the printed shapes in the
# output look like debug prints from inside the model -- confirm and remove.
model._run_step(batch)

True
torch.Size([28616, 16])
torch.Size([28616, 16, 1]) torch.Size([28616, 16, 26])
torch.Size([28616, 16])
torch.Size([28616, 16, 1]) torch.Size([28616, 16, 26])
True
torch.Size([15852, 8])
torch.Size([15852, 8, 1]) torch.Size([15852, 8, 52])
torch.Size([15852, 8])
torch.Size([15852, 8, 1]) torch.Size([15852, 8, 52])


MoLeROutput(first_node_type_logits=tensor([[ 0.2009, -0.1930,  0.1892,  ..., -0.1715,  0.0229, -0.1759],
        [ 0.0147, -0.1725,  0.0655,  ..., -0.1513, -0.0655, -0.2060],
        [ 0.0849, -0.1444,  0.1801,  ..., -0.2022, -0.0208, -0.0594],
        ...,
        [-0.1937, -0.3173,  0.0169,  ..., -0.2271, -0.0100, -0.2113],
        [-0.0700, -0.1213,  0.0962,  ..., -0.1686, -0.0179, -0.1278],
        [ 0.0282, -0.2687,  0.0998,  ..., -0.0415, -0.1082, -0.0653]],
       grad_fn=<AddmmBackward0>), node_type_logits=tensor([[ 0.0409, -0.1150,  0.1528,  ..., -0.0375, -0.0429,  0.1271],
        [-0.0032, -0.2015,  0.0424,  ..., -0.0727, -0.0866,  0.0798],
        [-0.0857, -0.0029,  0.0542,  ..., -0.0162, -0.0931,  0.0746],
        ...,
        [-0.1051, -0.1175,  0.1213,  ...,  0.1465, -0.0736,  0.0625],
        [-0.0662, -0.1139,  0.0762,  ...,  0.0717, -0.0148,  0.0720],
        [ 0.0478, -0.0242,  0.0718,  ...,  0.0753, -0.0562,  0.0475]],
       grad_fn=<AddmmBackward0>), edge_candida

In [6]:
# Display the model's module tree.
model

BaseModel(
  (_full_graph_encoder): GraphEncoder(
    (_embed): Embedding(166, 64)
    (_model): GenericGraphEncoder(
      (_first_layer): FiLMConv(123, 64, num_relations=3)
      (_encoder_layers): ModuleList(
        (0): Sequential(
          (0): LayerNorm(64, mode=graph)
          (1): FiLMConv(64, 64, num_relations=3)
        )
        (1): Sequential(
          (0): LayerNorm(64, mode=graph)
          (1): FiLMConv(64, 64, num_relations=3)
        )
        (2): Sequential(
          (0): LayerNorm(64, mode=graph)
          (1): FiLMConv(64, 64, num_relations=3)
        )
        (3): Sequential(
          (0): LayerNorm(64, mode=graph)
          (1): FiLMConv(64, 64, num_relations=3)
        )
        (4): Sequential(
          (0): LayerNorm(64, mode=graph)
          (1): FiLMConv(64, 64, num_relations=3)
        )
        (5): Sequential(
          (0): LayerNorm(64, mode=graph)
          (1): FiLMConv(64, 64, num_relations=3)
        )
        (6): Sequential(
          (0)