# IIC-3641 GML UC

In [1]:
!python3 -m pip list

[0mPackage                            Version
---------------------------------- -----------------------------------------
absl-py                            1.1.0
accelerate                         0.26.1
aggdraw                            1.3.16
aiohttp                            3.8.4
aiosignal                          1.2.0
alabaster                          0.7.12
ann-visualizer                     2.5
annotated-types                    0.6.0
antlr4-python3-runtime             4.9.3
anyio                              3.7.0
AnyQt                              0.2.0
anytree                            2.8.0
appdirs                            1.4.4
appnope                            0.1.3
apturl                             0.5.2
argon2-cffi                        21.3.0
argon2-cffi-bindings               21.2.0
array-record                       0.4.0
arrow                              1.3.0
astor                              0.8.1
asttokens                          2.4.1
astunparse  

Orange3-Text                       1.11.0
Orange3-Textable                   3.1.11
Orange3-Timeseries                 0.5.1
Orange3-WorldHappiness             0.1.8
outcome                            1.2.0
overrides                          7.7.0
Owlready2                          0.38
packaging                          23.1
pandas                             1.3.5
pandas-datareader                  0.10.0
pandocfilters                      1.5.0
paramiko                           2.6.0
parse                              1.19.1
parsel                             1.8.1
parso                              0.8.3
pathspec                           0.11.1
pathtools                          0.1.2
pathy                              0.10.1
patsy                              0.5.2
pdfminer3k                         1.3.4
peft                               0.7.2.dev0
pexpect                            4.8.0
pgmpy                              0.1.25
pickleshare             

In [2]:
import torch
print(torch.__version__)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines where CUDA is not available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

2.4.1+cu118


Para obtener el código de gt-pyg: `git clone https://github.com/pgniewko/gt-pyg.git` (ver en el repo el archivo `environment.yml` para las dependencias)

## En este ejemplo usaremos PyTorch Geometric (`torch_geometric`)

In [3]:
from torch_geometric.data import Data

  _torch_pytree._register_pytree_node(


## El layer de GT lo importamos desde el proyecto gt_pyg

In [4]:
from gt_pyg.gt_pyg.nn.gt_conv import GTConv

## Aquí importamos el RDKit

In [5]:
# Standard
import logging
import importlib

# Third party
import numpy as np
import rdkit
from rdkit import RDLogger
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.loader import DataLoader

## Usamos TorchMetrics para trabajar con MAE

In [6]:
import torchmetrics
from torchmetrics import MeanAbsoluteError

  warn(f"Failed to load image Python extension: {e}")
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [7]:
import gt_pyg.gt_pyg

## Vamos a usar el GT con edge features

In [8]:
from gt_pyg.gt_pyg.data.utils import (
    get_tensor_data, 
    get_node_dim, 
    get_edge_dim, 
    get_train_valid_test_data
)
from gt_pyg.gt_pyg.nn.model import GraphTransformerNet

In [9]:
# Silence every RDKit message below the CRITICAL level.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Seed torch's random number generator so that everything drawn from it later
# (e.g. the random eigenvector sign flips done with torch.randint) starts from
# the same state on every run.
SEED = 192837465
torch.manual_seed(SEED)

# Configure the root logger to let INFO-level messages through.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Report the versions of the main libraries used by this notebook.
print(f'Numpy version: {np.__version__}')
print(f'Rdkit version: {rdkit.__version__}')
print(f'Torch version: {torch.__version__}')
print(f'TorchMetrics version: {torchmetrics.__version__}')

Numpy version: 1.21.6
Rdkit version: 2023.03.1
Torch version: 2.4.1+cu118
TorchMetrics version: 0.11.4


## Desde tdc importamos el benchmark de ADMET

In [10]:
from tdc import utils
# Fetch the endpoint names of the 'ADMET_Group' benchmark and print them as a
# numbered list that starts at 1.
names = utils.retrieve_benchmark_names('ADMET_Group')
output = "\n".join(
    f"{index}. {name}"
    for index, name in enumerate(names, start=1)
)
print("Available endpoints:\n\n", output, sep="")

Available endpoints:

1. caco2_wang
2. hia_hou
3. pgp_broccatelli
4. bioavailability_ma
5. lipophilicity_astrazeneca
6. solubility_aqsoldb
7. bbb_martins
8. ppbr_az
9. vdss_lombardo
10. cyp2d6_veith
11. cyp3a4_veith
12. cyp2c9_veith
13. cyp2d6_substrate_carbonmangels
14. cyp3a4_substrate_carbonmangels
15. cyp2c9_substrate_carbonmangels
16. half_life_obach
17. clearance_microsome_az
18. clearance_hepatocyte_az
19. herg
20. ames
21. dili
22. ld50_zhu


## Vamos a trabajar con moléculas generadas por el laboratorio AstraZeneca (endpoint `lipophilicity_astrazeneca`)

In [11]:
PE_DIM = 6 
(tr, va, te) = get_train_valid_test_data('lipophilicity_astrazeneca', min_num_atoms=0)
tr_dataset = get_tensor_data(tr.Drug.to_list(), tr.Y.to_list(), pe_dim=PE_DIM)
va_dataset = get_tensor_data(va.Drug.to_list(), va.Y.to_list(), pe_dim=PE_DIM)
te_dataset = get_tensor_data(te.Drug.to_list(), te.Y.to_list(), pe_dim=PE_DIM)
NODE_DIM = get_node_dim()
EDGE_DIM = get_edge_dim()

print(f'Number of training examples: {len(tr_dataset)}')
print(f'Number of validation examples: {len(va_dataset)}')
print(f'Number of test examples: {len(te_dataset)}')

train_loader = DataLoader(tr_dataset, batch_size=64)
val_loader = DataLoader(va_dataset, batch_size=512)
test_loader = DataLoader(te_dataset, batch_size=512)

Found local copy...
Loading...
Done!


Processing data: 0it [00:00, ?it/s]

Processing data: 0it [00:00, ?it/s]

Processing data: 0it [00:00, ?it/s]

Processing data: 0it [00:00, ?it/s]

Processing data: 0it [00:00, ?it/s]

Number of training examples: 2940
Number of validation examples: 420
Number of test examples: 840


In [12]:
# Show the node and edge feature sizes, one per line.
print(f'Node dim: {NODE_DIM}', f'Edge dim: {EDGE_DIM}', sep='\n')

Node dim: 76
Edge dim: 10


## La función de pérdida es L1 (diferencia absoluta) alineada con el error MAE

In [13]:
def train(epoch, loss_func):
    """Run one optimisation pass over ``train_loader`` and return the training MAE.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader``,
    ``device`` and ``PE_DIM``.

    Args:
        epoch: Index of the current epoch. Not used inside the function; kept
            so existing callers keep working.
        loss_func: Callable invoked as ``loss_func(prediction, target)`` whose
            result is back-propagated.

    Returns:
        The result of ``MeanAbsoluteError.compute()`` accumulated over every
        batch of the pass.
    """
    model.train()
    train_mae = MeanAbsoluteError().to(device)

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        # Randomly flip the sign of the eigenvectors: one +1/-1 factor per
        # positional-encoding column, broadcast over all rows of data.pe.
        sign_flip = 2 * torch.randint(low=0, high=2, size=(1, PE_DIM), device=device).float() - 1.0
        batch_pe = data.pe * sign_flip
        (out, _) = model(data.x, data.edge_index, data.edge_attr, batch_pe, data.batch, zero_var=False)

        # Drop only the trailing dimension: a bare squeeze() would also remove
        # the batch dimension when a batch holds a single graph.
        # Assumes the model returns one value per graph, shape (num_graphs, 1)
        # -- TODO confirm against GraphTransformerNet.
        pred = out.squeeze(-1)
        loss = loss_func(pred, data.y)
        loss.backward()
        optimizer.step()

        train_mae.update(pred, data.y)

    return train_mae.compute()


@torch.no_grad()
def test(loader):
    """Evaluate the notebook-level ``model`` on ``loader`` and return its MAE.

    Gradient tracking is disabled by the decorator and the model is switched
    to evaluation mode. Uses the notebook-level ``model``, ``device`` and
    ``PE_DIM``.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``pe``, ``batch`` and ``y``.

    Returns:
        The result of ``MeanAbsoluteError.compute()`` accumulated over every
        batch of ``loader``.
    """
    model.eval()
    test_mae = MeanAbsoluteError().to(device)

    for data in loader:
        data = data.to(device)
        # Randomly flip the sign of the eigenvectors (one factor per column).
        # NOTE(review): the flip is random at evaluation time too, so two calls
        # on the same loader may differ unless the RNG is seeded.
        sign_flip = 2 * torch.randint(low=0, high=2, size=(1, PE_DIM), device=device).float() - 1.0
        batch_pe = data.pe * sign_flip
        (out, _) = model(data.x, data.edge_index, data.edge_attr, batch_pe, data.batch)

        # Drop only the trailing dimension so a single-graph batch keeps its
        # batch axis. Assumes output shape (num_graphs, 1) -- TODO confirm.
        test_mae.update(out.squeeze(-1), data.y)

    return test_mae.compute()

# Training objective: mean absolute difference between prediction and target,
# i.e. the same quantity that the MAE metric reports.
train_loss = torch.nn.L1Loss(reduction='mean')

## Aquí dimensionamos al GT. Son 4 capas y 8 heads. 

In [14]:
# Import the package explicitly rather than with a wildcard, which would dump
# its public names into the notebook namespace and shadow the builtin
# `compile`.
import torch_geometric

import torch._dynamo
# Ask TorchDynamo to suppress compilation errors rather than raise them.
torch._dynamo.config.suppress_errors = True

# Prefer the GPU, but fall back to the CPU when CUDA is not available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Graph Transformer with 4 GT layers, 8 attention heads and a hidden size of
# 128; 'sum' is used both inside the layers and for the global pooling.
model = GraphTransformerNet(node_dim_in=NODE_DIM,
                            edge_dim_in=EDGE_DIM,
                            pe_in_dim=PE_DIM,
                            num_gt_layers=4,
                            hidden_dim=128,
                            num_heads=8,
                            norm='bn',
                            gate=False,
                            qkv_bias=False,
                            gt_aggregators=['sum'],
                            aggregators=['sum'],
                            dropout=0.1,
                            act='relu').to(device)

# Compilation is only attempted on PyTorch 2.x or newer.
if int(torch.__version__.split('.')[0]) >= 2:
    model = torch_geometric.compile(model)

# AdamW with a plateau scheduler: the learning rate is multiplied by 0.5 when
# the monitored value has not decreased for 10 epochs, never going below 1e-5.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.00001)

print(model)
print(f"Number of params: {model.num_parameters()//1000} k")

OptimizedModule(
  (_orig_mod): GraphTransformerNet(
    (node_emb): Linear(in_features=76, out_features=128, bias=False)
    (edge_emb): Linear(in_features=10, out_features=128, bias=False)
    (pe_emb): Linear(in_features=6, out_features=128, bias=False)
    (gt_layers): ModuleList(
      (0-3): 4 x GTConv(128, 128, heads=8, aggrs: sum, qkv_bias: False, gate: False)
    )
    (global_pool): MultiAggregation([
      SumAggregation(),
    ], mode=cat)
    (mu_mlp): MLP(
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=1, bias=True)
      )
    )
    (log_var_mlp): MLP(
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=1, bias=True)
      )
    )
  )
)
Number of params: 708 k


In [15]:
# Model-selection bookkeeping: no epoch has been selected yet, and both the best
# validation loss and its associated test MAE start at +infinity.
best_epoch = 0
best_validation_loss = test_set_mae = np.inf

## El loop de entrenamiento informa sobre MAE en Val y Test

In [16]:
# Train for 50 epochs. After each epoch the model is evaluated on the
# validation and test loaders; the test MAE that is finally reported is the one
# measured at the epoch with the lowest validation MAE.
for epoch in range(1, 51):
    # train()/test() return the metric's compute() result; convert it to a
    # plain float so the stored best values and the final report are numbers
    # rather than tensor objects.
    tr_loss = float(train(epoch, loss_func=train_loss))
    va_loss = float(test(val_loader))
    te_loss = float(test(test_loader))
    # The scheduler monitors the validation MAE.
    scheduler.step(va_loss)
    print(f'Epoch: {epoch:02d}, Loss: {tr_loss:.4f}, Val: {va_loss:.4f}, '
          f'Test: {te_loss:.4f}')
    if va_loss < best_validation_loss:
        best_epoch = epoch
        best_validation_loss = va_loss
        test_set_mae = te_loss

print("\nModel's performance on the test set\n"
        "===================================\n"
        f'MAE={test_set_mae}\n'
        f'Epoch={best_epoch}')

W1008 10:20:48.094755 140382441187136 torch/_dynamo/variables/tensor.py:715] [1/0] Graph break from `Tensor.item()`, consider setting:
W1008 10:20:48.094755 140382441187136 torch/_dynamo/variables/tensor.py:715] [1/0]     torch._dynamo.config.capture_scalar_outputs = True
W1008 10:20:48.094755 140382441187136 torch/_dynamo/variables/tensor.py:715] [1/0] or:
W1008 10:20:48.094755 140382441187136 torch/_dynamo/variables/tensor.py:715] [1/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W1008 10:20:48.094755 140382441187136 torch/_dynamo/variables/tensor.py:715] [1/0] to include these operations in the captured graph.
W1008 10:20:48.094755 140382441187136 torch/_dynamo/variables/tensor.py:715] [1/0] 
W1008 10:20:48.369405 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT forward /home/marcelo/Dropbox/Codes_STGO/IIC-3641/Codes 2024/gt_pyg/gt_pyg/nn/gt_conv.py line 143 
W1008 10:20:48.369405 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:48.369405 

Uncaught exception in compile_worker subprocess
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_worker/__main__.py", line 38, in main
    pre_fork_setup()
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/async_compile.py", line 62, in pre_fork_setup
    from triton.compiler.compiler import triton_key
ImportError: cannot import name 'triton_key' from 'triton.compiler.compiler' (/usr/local/lib/python3.8/dist-packages/triton/compiler/compiler.py)
W1008 10:20:48.798929 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT propagate /tmp/gt_pyg.gt_pyg.nn.gt_conv_GTConv_propagate_tp56l8ft.py line 176 
W1008 10:20:48.798929 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:48.798929 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:48.798929 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packa

W1008 10:20:48.962696 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT collect /tmp/gt_pyg.gt_pyg.nn.gt_conv_GTConv_propagate_tp56l8ft.py line 37 
W1008 10:20:48.962696 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:48.962696 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:48.962696 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:48.962696 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:48.962696 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:48.962696 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:48.962696 140382441187136 torch/_dynamo/convert_frame.py:1009]   File 

W1008 10:20:49.027299 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT _index_select /home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py line 263 
W1008 10:20:49.027299 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:49.027299 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:49.027299 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:49.027299 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:49.027299 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:49.027299 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:49.027299 140382441187136 torc

W1008 10:20:49.189640 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT message /home/marcelo/Dropbox/Codes_STGO/IIC-3641/Codes 2024/gt_pyg/gt_pyg/nn/gt_conv.py line 189 
W1008 10:20:49.189640 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:49.189640 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:49.189640 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:49.189640 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:49.189640 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:49.189640 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:49.189640 140382441187136 torch/_dynamo/convert

W1008 10:20:49.317110 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT maybe_num_nodes /home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/utils/num_nodes.py line 12 
W1008 10:20:49.317110 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:49.317110 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:49.317110 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:49.317110 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:49.317110 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:49.317110 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:49.317110 140382441187136 torch/_dyna

W1008 10:20:49.517646 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT torch_dynamo_resume_in_softmax_at_74 /home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/utils/_softmax.py line 74 
W1008 10:20:49.517646 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:49.517646 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:49.517646 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:49.517646 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:49.517646 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:49.517646 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:49.517646 1403824

W1008 10:20:49.575239 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT scatter /home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/utils/_scatter.py line 15 
W1008 10:20:49.575239 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:49.575239 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:49.575239 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:49.575239 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:49.575239 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:49.575239 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:49.575239 140382441187136 torch/_dynamo/conver

W1008 10:20:49.746258 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT aggregate /home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py line 577 
W1008 10:20:49.746258 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:49.746258 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:49.746258 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:49.746258 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:49.746258 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:49.746258 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:49.746258 140382441187136 torch/_d

W1008 10:20:49.834987 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT wrapper /home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/experimental.py line 114 
W1008 10:20:49.834987 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:49.834987 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:49.834987 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:49.834987 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:49.834987 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:49.834987 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:49.834987 140382441187136 torch/_dynamo/convert

W1008 10:20:49.922983 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT __call__ /home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/nn/aggr/base.py line 101 
W1008 10:20:49.922983 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:49.922983 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:49.922983 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:49.922983 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:49.922983 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:49.922983 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:49.922983 140382441187136 torch/_dynamo/conver

W1008 10:20:50.001847 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT forward /home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/nn/aggr/multi.py line 152 
W1008 10:20:50.001847 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:50.001847 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:50.001847 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:50.001847 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:50.001847 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:50.001847 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:50.001847 140382441187136 torch/_dynamo/conver

W1008 10:20:50.077903 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT forward /home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/nn/aggr/basic.py line 19 
W1008 10:20:50.077903 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:50.077903 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:50.077903 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:50.077903 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:50.077903 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:50.077903 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:50.077903 140382441187136 torch/_dynamo/convert

W1008 10:20:50.149980 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT reduce /home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/nn/aggr/base.py line 173 
W1008 10:20:50.149980 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:50.149980 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:50.149980 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:50.149980 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:50.149980 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:50.149980 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:50.149980 140382441187136 torch/_dynamo/convert_

W1008 10:20:50.351862 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT forward /home/marcelo/Dropbox/Codes_STGO/IIC-3641/Codes 2024/gt_pyg/gt_pyg/nn/mlp.py line 52 
W1008 10:20:50.351862 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:50.351862 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:50.351862 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:50.351862 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:50.351862 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:50.351862 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:50.351862 140382441187136 torch/_dynamo/convert_fram

W1008 10:20:50.445326 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT forward /home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/nn/aggr/fused.py line 191 
W1008 10:20:50.445326 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:20:50.445326 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:20:50.445326 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:20:50.445326 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:20:50.445326 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:20:50.445326 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:20:50.445326 140382441187136 torch/_dynamo/conver

W1008 10:20:52.947504 140382441187136 torch/_dynamo/convert_frame.py:762] [15/8] torch._dynamo hit config.cache_size_limit (8)
W1008 10:20:52.947504 140382441187136 torch/_dynamo/convert_frame.py:762] [15/8]    function: 'broadcast' (/home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/utils/_scatter.py:192)
W1008 10:20:52.947504 140382441187136 torch/_dynamo/convert_frame.py:762] [15/8]    last reason: GLOBAL_STATE changed: grad_mode 
W1008 10:20:52.947504 140382441187136 torch/_dynamo/convert_frame.py:762] [15/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1008 10:20:52.947504 140382441187136 torch/_dynamo/convert_frame.py:762] [15/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.


Epoch: 01, Loss: 1.7570, Val: 1.0077, Test: 1.0874
Epoch: 02, Loss: 1.0890, Val: 1.2647, Test: 1.3239
Epoch: 03, Loss: 1.0871, Val: 1.1578, Test: 1.2477
Epoch: 04, Loss: 1.0192, Val: 1.3667, Test: 1.4846
Epoch: 05, Loss: 1.0192, Val: 1.4004, Test: 1.5476
Epoch: 06, Loss: 1.0206, Val: 1.3003, Test: 1.4099
Epoch: 07, Loss: 0.9904, Val: 1.4159, Test: 1.5717
Epoch: 08, Loss: 0.9686, Val: 1.4100, Test: 1.5118
Epoch: 09, Loss: 0.9460, Val: 1.4998, Test: 1.6702
Epoch: 10, Loss: 0.9498, Val: 1.3408, Test: 1.5183
Epoch: 11, Loss: 0.9561, Val: 1.0133, Test: 1.0819
Epoch: 12, Loss: 0.9360, Val: 1.1705, Test: 1.2673
Epoch: 13, Loss: 0.9065, Val: 0.9536, Test: 1.0161
Epoch: 14, Loss: 0.8917, Val: 1.0989, Test: 1.1874
Epoch: 15, Loss: 0.8982, Val: 1.0818, Test: 1.1732
Epoch: 16, Loss: 0.8928, Val: 1.1501, Test: 1.2565
Epoch: 17, Loss: 0.8872, Val: 0.9683, Test: 1.0651
Epoch: 18, Loss: 0.8709, Val: 0.8999, Test: 0.9702
Epoch: 19, Loss: 0.8717, Val: 1.1634, Test: 1.3168
Epoch: 20, Loss: 0.8613, Val: 1

## Aquí una variante que cambia la función de activación (GELU) y usa más agregadores

In [17]:
# Second configuration: GELU activation, 'sum' + 'mean' aggregation inside the
# GT layers and four global pooling aggregators. `gate` and `qkv_bias` are left
# at the constructor's defaults.
model_kwargs = dict(
    node_dim_in=NODE_DIM,
    edge_dim_in=EDGE_DIM,
    pe_in_dim=PE_DIM,
    num_gt_layers=4,
    hidden_dim=128,
    num_heads=8,
    norm='bn',
    gt_aggregators=['sum', "mean"],
    aggregators=['sum', 'mean', 'max', 'std'],
    dropout=0.1,
    act='gelu',
)
model = GraphTransformerNet(**model_kwargs).to(device)

# Compilation is only attempted on PyTorch 2.x or newer.
if not int(torch.__version__.split('.')[0]) < 2:
    model = torch_geometric.compile(model)

# Fresh optimizer and plateau scheduler bound to the new model's parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    min_lr=1e-5,
)

print(model)
print(f"Number of params: {model.num_parameters()//1000} k")


# Reset the model-selection bookkeeping for this second model.
best_epoch = 0
best_validation_loss = np.inf
test_set_mae = np.inf

# Train for 50 epochs; the reported test MAE is the one measured at the epoch
# with the lowest validation MAE.
for epoch in range(1, 51):
    # train()/test() return the metric's compute() result; convert it to a
    # plain float so the stored best values and the final report are numbers
    # rather than tensor objects.
    tr_loss = float(train(epoch, loss_func=train_loss))
    va_loss = float(test(val_loader))
    te_loss = float(test(test_loader))
    # The scheduler monitors the validation MAE.
    scheduler.step(va_loss)
    print(f'Epoch: {epoch:02d}, Loss: {tr_loss:.4f}, Val: {va_loss:.4f}, '
          f'Test: {te_loss:.4f}')
    if va_loss < best_validation_loss:
        best_epoch = epoch
        best_validation_loss = va_loss
        test_set_mae = te_loss

print("\nModel's performance on the test set\n"
        "===================================\n"
        f'MAE={test_set_mae}\n'
        f'Epoch={best_epoch}')

OptimizedModule(
  (_orig_mod): GraphTransformerNet(
    (node_emb): Linear(in_features=76, out_features=128, bias=False)
    (edge_emb): Linear(in_features=10, out_features=128, bias=False)
    (pe_emb): Linear(in_features=6, out_features=128, bias=False)
    (gt_layers): ModuleList(
      (0-3): 4 x GTConv(128, 128, heads=8, aggrs: sum,mean, qkv_bias: False, gate: False)
    )
    (global_pool): MultiAggregation([
      SumAggregation(),
      MeanAggregation(),
      MaxAggregation(),
      StdAggregation(),
    ], mode=cat)
    (mu_mlp): MLP(
      (mlp): Sequential(
        (0): Linear(in_features=512, out_features=128, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=128, out_features=1, bias=True)
      )
    )
    (log_var_mlp): MLP(
      (mlp): Sequential(
        (0): Linear(in_features=512, out_features=128, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=128, out_features=1, bias=True)
      )
    )
  )
)
Num

W1008 10:22:28.975227 140382441187136 torch/_dynamo/convert_frame.py:1009] WON'T CONVERT forward /home/marcelo/.local/lib/python3.8/site-packages/torch_geometric/nn/aggr/basic.py line 33 
W1008 10:22:28.975227 140382441187136 torch/_dynamo/convert_frame.py:1009] due to: 
W1008 10:22:28.975227 140382441187136 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
W1008 10:22:28.975227 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
W1008 10:22:28.975227 140382441187136 torch/_dynamo/convert_frame.py:1009]     result = self._inner_convert(
W1008 10:22:28.975227 140382441187136 torch/_dynamo/convert_frame.py:1009]   File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
W1008 10:22:28.975227 140382441187136 torch/_dynamo/convert_frame.py:1009]     return _compile(
W1008 10:22:28.975227 140382441187136 torch/_dynamo/convert

Epoch: 01, Loss: 1.1644, Val: 1.0854, Test: 1.1061
Epoch: 02, Loss: 0.9733, Val: 1.1846, Test: 1.2481
Epoch: 03, Loss: 0.9463, Val: 1.1183, Test: 1.2172
Epoch: 04, Loss: 0.9356, Val: 1.0477, Test: 1.0700
Epoch: 05, Loss: 0.8883, Val: 1.0379, Test: 1.0295
Epoch: 06, Loss: 0.8955, Val: 0.9915, Test: 1.0146
Epoch: 07, Loss: 0.8557, Val: 1.0464, Test: 1.1427
Epoch: 08, Loss: 0.8233, Val: 0.8565, Test: 0.8701
Epoch: 09, Loss: 0.7962, Val: 0.9420, Test: 0.9827
Epoch: 10, Loss: 0.7905, Val: 1.6679, Test: 1.7724
Epoch: 11, Loss: 0.7843, Val: 0.9867, Test: 0.9551
Epoch: 12, Loss: 0.7589, Val: 1.4355, Test: 1.5197
Epoch: 13, Loss: 0.7326, Val: 0.7995, Test: 0.8153
Epoch: 14, Loss: 0.7201, Val: 0.8180, Test: 0.8312
Epoch: 15, Loss: 0.6840, Val: 0.6959, Test: 0.7463
Epoch: 16, Loss: 0.6687, Val: 1.1551, Test: 1.2408
Epoch: 17, Loss: 0.6668, Val: 1.3269, Test: 1.4975
Epoch: 18, Loss: 0.6486, Val: 0.8942, Test: 1.0254
Epoch: 19, Loss: 0.6396, Val: 0.6696, Test: 0.7344
Epoch: 20, Loss: 0.6092, Val: 0