## Actividad en clase

Vamos a usar el Graph Transformer para trabajar en esta actividad.

- Corra el graph transformer y vea que puede reproducir el ejemplo de la clase.
- Use el dataset ppbr_az que viene en ADMET_Group.
- Reentrene el graph transformer. Pruebe variantes con distinto número de cabezales y capas para ver el efecto que esto tiene en el MAE. Haga 3 variantes, al menos, y corra 100 epochs.
- Vea la última lámina de la clase y dígame si le ganó al paper.
- Cuando termine, me avisa para entregarle una **L (logrado)**.
- Recuerde que las L otorgan un bono en la nota final de la asignatura.

***Tiene hasta el final de la clase.***

In [1]:
from torch_geometric.data import Data
from gt_pyg.gt_pyg.nn.gt_conv import GTConv

# Standard
import logging
import importlib

# Third party
import numpy as np
import rdkit
from rdkit import RDLogger
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.loader import DataLoader
import torchmetrics
from torchmetrics import MeanAbsoluteError

import gt_pyg.gt_pyg

from gt_pyg.gt_pyg.data.utils import (
    get_tensor_data, 
    get_node_dim, 
    get_edge_dim, 
    get_train_valid_test_data
)
from gt_pyg.gt_pyg.nn.model import GraphTransformerNet

  _torch_pytree._register_pytree_node(
  warn(f"Failed to load image Python extension: {e}")
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
# Silence RDKit: only messages at CRITICAL level are let through.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Configure the root logger to emit INFO and above.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# List the benchmark names that TDC reports for the ADMET group,
# numbered from 1, one per line.
from tdc import utils
names = utils.retrieve_benchmark_names('ADMET_Group')
numbered_lines = (f"{position}. {endpoint}" for position, endpoint in enumerate(names, start=1))
output = "\n".join(numbered_lines)
print("Available endpoints:\n\n" + output)

Available endpoints:

1. caco2_wang
2. hia_hou
3. pgp_broccatelli
4. bioavailability_ma
5. lipophilicity_astrazeneca
6. solubility_aqsoldb
7. bbb_martins
8. ppbr_az
9. vdss_lombardo
10. cyp2d6_veith
11. cyp3a4_veith
12. cyp2c9_veith
13. cyp2d6_substrate_carbonmangels
14. cyp3a4_substrate_carbonmangels
15. cyp2c9_substrate_carbonmangels
16. half_life_obach
17. clearance_microsome_az
18. clearance_hepatocyte_az
19. herg
20. ames
21. dili
22. ld50_zhu


In [3]:
# Value passed as ``pe_dim`` when building the tensor datasets below.
PE_DIM = 6

# Fetch the train/validation/test splits for the ppbr_az endpoint and
# convert each split (Drug column + Y targets) into a tensor dataset.
(tr, va, te) = get_train_valid_test_data('ppbr_az', min_num_atoms=0)
tr_dataset = get_tensor_data(tr.Drug.to_list(), tr.Y.to_list(), pe_dim=PE_DIM)
va_dataset = get_tensor_data(va.Drug.to_list(), va.Y.to_list(), pe_dim=PE_DIM)
te_dataset = get_tensor_data(te.Drug.to_list(), te.Y.to_list(), pe_dim=PE_DIM)
NODE_DIM = get_node_dim()
EDGE_DIM = get_edge_dim()

print(f'Number of training examples: {len(tr_dataset)}')
print(f'Number of validation examples: {len(va_dataset)}')
print(f'Number of test examples: {len(te_dataset)}')

# Shuffle the training set so each epoch sees differently composed
# mini-batches; the loader's default is a fixed order. The evaluation
# loaders keep the fixed order since ordering does not matter there.
train_loader = DataLoader(tr_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(va_dataset, batch_size=512)
test_loader = DataLoader(te_dataset, batch_size=512)

Found local copy...
Loading...
Done!


Processing data: 0it [00:00, ?it/s]

Processing data: 0it [00:00, ?it/s]

Processing data: 0it [00:00, ?it/s]

Processing data: 0it [00:00, ?it/s]

Processing data: 0it [00:00, ?it/s]

Number of training examples: 1130
Number of validation examples: 161
Number of test examples: 323


In [4]:
# Show the node/edge feature sizes returned by get_node_dim() / get_edge_dim();
# they are later passed to the model as node_dim_in / edge_dim_in.
print(f'Node dim: {NODE_DIM}')
print(f'Edge dim: {EDGE_DIM}')

Node dim: 76
Edge dim: 10


In [5]:
def train(epoch, loss_func):
    """Run one optimisation pass over ``train_loader`` and return the training MAE.

    Uses the notebook-level globals ``model``, ``optimizer``, ``device``,
    ``train_loader`` and ``PE_DIM``.

    Args:
        epoch: Number of the current epoch. The body does not use it; it is
            kept so existing callers keep working.
        loss_func: Callable invoked as ``loss_func(prediction, target)`` whose
            result is backpropagated.

    Returns:
        The result of ``MeanAbsoluteError.compute()`` accumulated over every
        batch of the pass (predictions are taken in training mode).
    """
    model.train()
    train_mae = MeanAbsoluteError().to(device)

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        # Randomly flip the sign of the eigenvectors: one +/-1 factor per
        # positional-encoding column, broadcast over all rows of the batch.
        signs = 2 * torch.randint(low=0, high=2, size=(1, PE_DIM), device=device).float() - 1.0
        batch_pe = data.pe * signs
        (out, _) = model(data.x, data.edge_index, data.edge_attr, batch_pe, data.batch, zero_var=False)

        # Flatten instead of squeeze(): squeeze() would also remove the batch
        # dimension when a batch contains a single graph, leaving a 0-dim
        # prediction that no longer lines up with the targets.
        # NOTE(review): assumes data.y is 1-D with one target per graph - confirm.
        pred = out.reshape(-1)
        loss = loss_func(pred, data.y)
        loss.backward()
        optimizer.step()

        train_mae.update(pred, data.y)

    return train_mae.compute()


@torch.no_grad()
def test(loader):
    """Evaluate ``model`` on every batch of ``loader`` and return the MAE.

    Uses the notebook-level globals ``model``, ``device`` and ``PE_DIM``.
    Gradients are disabled and the model is switched to eval mode.

    NOTE: the positional encodings get a random sign flip here as well, so
    two evaluations of the same model can return slightly different values.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``pe``, ``batch`` and ``y``.

    Returns:
        The result of ``MeanAbsoluteError.compute()`` over all batches.
    """
    model.eval()
    test_mae = MeanAbsoluteError().to(device)

    for data in loader:
        data = data.to(device)
        # Randomly flip the sign of the eigenvectors: one +/-1 factor per
        # positional-encoding column, broadcast over all rows of the batch.
        signs = 2 * torch.randint(low=0, high=2, size=(1, PE_DIM), device=device).float() - 1.0
        batch_pe = data.pe * signs
        (out, _) = model(data.x, data.edge_index, data.edge_attr, batch_pe, data.batch)

        # Flatten instead of squeeze() so a batch holding a single graph
        # still yields a 1-D prediction.
        # NOTE(review): assumes data.y is 1-D with one target per graph - confirm.
        test_mae.update(out.reshape(-1), data.y)

    return test_mae.compute()

# Training objective: L1 (absolute error) loss averaged over the batch,
# i.e. the same quantity that train()/test() report as MAE.
train_loss = nn.L1Loss(reduction='mean')

## Modelo 1, el mismo de clases

In [6]:
# Import the package by name instead of ``from torch_geometric import *``:
# only ``torch_geometric.compile`` is needed, and the wildcard form pulls
# every exported name (including one called ``compile``) into the notebook.
import torch_geometric

import torch._dynamo
# NOTE(review): presumably makes Dynamo fall back to eager execution when
# compilation fails instead of raising - confirm against the PyTorch docs.
torch._dynamo.config.suppress_errors = True

# Use the GPU when one is available; otherwise run on the CPU so the
# notebook still executes on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model 1: 4 graph-transformer layers, 8 attention heads, hidden size 128.
model = GraphTransformerNet(node_dim_in=NODE_DIM,
                            edge_dim_in=EDGE_DIM,
                            pe_in_dim=PE_DIM,
                            num_gt_layers=4,
                            hidden_dim=128,
                            num_heads=8,
                            norm='bn',
                            gate=False,
                            qkv_bias=False,
                            gt_aggregators=['sum'],
                            aggregators=['sum'],
                            dropout=0.1,
                            act='relu').to(device)

# Compilation is only attempted on PyTorch major version 2 or newer.
if int(torch.__version__.split('.')[0]) >= 2:
    model = torch_geometric.compile(model)

# AdamW with the learning rate halved after 10 epochs without improvement
# of the monitored value, never going below 1e-5.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.00001)

print(model)
print(f"Number of params: {model.num_parameters()//1000} k")

OptimizedModule(
  (_orig_mod): GraphTransformerNet(
    (node_emb): Linear(in_features=76, out_features=128, bias=False)
    (edge_emb): Linear(in_features=10, out_features=128, bias=False)
    (pe_emb): Linear(in_features=6, out_features=128, bias=False)
    (gt_layers): ModuleList(
      (0-3): 4 x GTConv(128, 128, heads=8, aggrs: sum, qkv_bias: False, gate: False)
    )
    (global_pool): MultiAggregation([
      SumAggregation(),
    ], mode=cat)
    (mu_mlp): MLP(
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=1, bias=True)
      )
    )
    (log_var_mlp): MLP(
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=1, bias=True)
      )
    )
  )
)
Number of params: 708 k


Uncaught exception in compile_worker subprocess
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_worker/__main__.py", line 38, in main
    pre_fork_setup()
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/async_compile.py", line 62, in pre_fork_setup
    from triton.compiler.compiler import triton_key
ImportError: cannot import name 'triton_key' from 'triton.compiler.compiler' (/usr/local/lib/python3.8/dist-packages/triton/compiler/compiler.py)


In [12]:
# Train the current ``model`` and report the test MAE of the epoch with the
# lowest validation MAE (the epoch is selected on validation, not on test).
NUM_EPOCHS = 100

best_epoch = 0
best_validation_loss = np.inf
test_set_mae = np.inf

for epoch in range(1, NUM_EPOCHS + 1):
    # Convert the metrics to plain Python floats so the bookkeeping below
    # and the final report do not carry device tensors around.
    tr_loss = float(train(epoch, loss_func=train_loss))
    va_loss = float(test(val_loader))
    te_loss = float(test(test_loader))

    # The scheduler lowers the learning rate when validation MAE plateaus.
    scheduler.step(va_loss)
    print(f'Epoch: {epoch:02d}, Loss: {tr_loss:.4f}, Val: {va_loss:.4f}, '
          f'Test: {te_loss:.4f}')

    if va_loss < best_validation_loss:
        best_epoch = epoch
        best_validation_loss = va_loss
        test_set_mae = te_loss

# Print the MAE as a formatted number rather than a raw tensor repr.
print("\nModel's performance on the test set\n"
      "===================================\n"
      f'MAE={test_set_mae:.4f}\n'
      f'Epoch={best_epoch}')

Epoch: 01, Loss: 10.1665, Val: 12.6608, Test: 13.6024
Epoch: 02, Loss: 10.0647, Val: 9.9974, Test: 10.4031
Epoch: 03, Loss: 9.9231, Val: 10.6948, Test: 10.2096
Epoch: 04, Loss: 10.0852, Val: 22.0640, Test: 21.2211
Epoch: 05, Loss: 10.1501, Val: 13.9527, Test: 13.0613
Epoch: 06, Loss: 10.3019, Val: 9.8290, Test: 9.7904
Epoch: 07, Loss: 10.4457, Val: 11.0555, Test: 10.3235
Epoch: 08, Loss: 9.8910, Val: 11.5018, Test: 11.9727
Epoch: 09, Loss: 9.7831, Val: 10.0029, Test: 9.8427
Epoch: 10, Loss: 9.7450, Val: 8.9749, Test: 9.7999
Epoch: 11, Loss: 9.5890, Val: 10.3776, Test: 9.8586
Epoch: 12, Loss: 9.5558, Val: 9.6941, Test: 10.4032
Epoch: 13, Loss: 9.4649, Val: 10.4144, Test: 11.3883
Epoch: 14, Loss: 9.3803, Val: 9.6984, Test: 9.8036
Epoch: 15, Loss: 9.4742, Val: 8.7416, Test: 9.3590
Epoch: 16, Loss: 9.4356, Val: 12.3145, Test: 12.6562
Epoch: 17, Loss: 9.4620, Val: 8.9924, Test: 9.3180
Epoch: 18, Loss: 9.2845, Val: 10.7053, Test: 10.4935
Epoch: 19, Loss: 9.1825, Val: 9.4052, Test: 9.5097
Epo

## Modelo 2, bajo los cabezales a la mitad

In [13]:
# Model 2: 4 graph-transformer layers with 4 attention heads, hidden size 128.
model = GraphTransformerNet(node_dim_in=NODE_DIM,
                            edge_dim_in=EDGE_DIM,
                            pe_in_dim=PE_DIM,
                            num_gt_layers=4,
                            hidden_dim=128,
                            num_heads=4,
                            norm='bn',
                            gate=False,
                            qkv_bias=False,
                            gt_aggregators=['sum'],
                            aggregators=['sum'],
                            dropout=0.1,
                            act='relu').to(device)

# Compilation is only attempted on PyTorch major version 2 or newer.
if int(torch.__version__.split('.')[0]) >= 2:
    model = torch_geometric.compile(model)

# Fresh optimizer and scheduler bound to the new model's parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.00001)

print(model)
print(f"Number of params: {model.num_parameters()//1000} k")

# Train and report the test MAE of the epoch with the lowest validation MAE
# (the epoch is selected on validation, not on test).
NUM_EPOCHS = 100

best_epoch = 0
best_validation_loss = np.inf
test_set_mae = np.inf
for epoch in range(1, NUM_EPOCHS + 1):
    # Convert the metrics to plain Python floats so the bookkeeping below
    # and the final report do not carry device tensors around.
    tr_loss = float(train(epoch, loss_func=train_loss))
    va_loss = float(test(val_loader))
    te_loss = float(test(test_loader))

    # The scheduler lowers the learning rate when validation MAE plateaus.
    scheduler.step(va_loss)
    print(f'Epoch: {epoch:02d}, Loss: {tr_loss:.4f}, Val: {va_loss:.4f}, '
          f'Test: {te_loss:.4f}')

    if va_loss < best_validation_loss:
        best_epoch = epoch
        best_validation_loss = va_loss
        test_set_mae = te_loss

# Print the MAE as a formatted number rather than a raw tensor repr.
print("\nModel's performance on the test set\n"
      "===================================\n"
      f'MAE={test_set_mae:.4f}\n'
      f'Epoch={best_epoch}')

OptimizedModule(
  (_orig_mod): GraphTransformerNet(
    (node_emb): Linear(in_features=76, out_features=128, bias=False)
    (edge_emb): Linear(in_features=10, out_features=128, bias=False)
    (pe_emb): Linear(in_features=6, out_features=128, bias=False)
    (gt_layers): ModuleList(
      (0-3): 4 x GTConv(128, 128, heads=4, aggrs: sum, qkv_bias: False, gate: False)
    )
    (global_pool): MultiAggregation([
      SumAggregation(),
    ], mode=cat)
    (mu_mlp): MLP(
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=1, bias=True)
      )
    )
    (log_var_mlp): MLP(
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=1, bias=True)
      )
    )
  )
)
Number of params: 708 k
Epoch: 01, Loss: 56.0192, Val: 91.8635, Test: 86.3211
Epoch: 02, Loss: 27.1171, Val: 32.7959,

## Modelo 3, igual al modelo 1 pero con la mitad de capas

In [14]:
# Model 3: 2 graph-transformer layers with 8 attention heads, hidden size 128.
model = GraphTransformerNet(node_dim_in=NODE_DIM,
                            edge_dim_in=EDGE_DIM,
                            pe_in_dim=PE_DIM,
                            num_gt_layers=2,
                            hidden_dim=128,
                            num_heads=8,
                            norm='bn',
                            gate=False,
                            qkv_bias=False,
                            gt_aggregators=['sum'],
                            aggregators=['sum'],
                            dropout=0.1,
                            act='relu').to(device)

# Compilation is only attempted on PyTorch major version 2 or newer.
if int(torch.__version__.split('.')[0]) >= 2:
    model = torch_geometric.compile(model)

# Fresh optimizer and scheduler bound to the new model's parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.00001)

print(model)
print(f"Number of params: {model.num_parameters()//1000} k")

# Train and report the test MAE of the epoch with the lowest validation MAE
# (the epoch is selected on validation, not on test).
NUM_EPOCHS = 100

best_epoch = 0
best_validation_loss = np.inf
test_set_mae = np.inf
for epoch in range(1, NUM_EPOCHS + 1):
    # Convert the metrics to plain Python floats so the bookkeeping below
    # and the final report do not carry device tensors around.
    tr_loss = float(train(epoch, loss_func=train_loss))
    va_loss = float(test(val_loader))
    te_loss = float(test(test_loader))

    # The scheduler lowers the learning rate when validation MAE plateaus.
    scheduler.step(va_loss)
    print(f'Epoch: {epoch:02d}, Loss: {tr_loss:.4f}, Val: {va_loss:.4f}, '
          f'Test: {te_loss:.4f}')

    if va_loss < best_validation_loss:
        best_epoch = epoch
        best_validation_loss = va_loss
        test_set_mae = te_loss

# Print the MAE as a formatted number rather than a raw tensor repr.
print("\nModel's performance on the test set\n"
      "===================================\n"
      f'MAE={test_set_mae:.4f}\n'
      f'Epoch={best_epoch}')

OptimizedModule(
  (_orig_mod): GraphTransformerNet(
    (node_emb): Linear(in_features=76, out_features=128, bias=False)
    (edge_emb): Linear(in_features=10, out_features=128, bias=False)
    (pe_emb): Linear(in_features=6, out_features=128, bias=False)
    (gt_layers): ModuleList(
      (0-1): 2 x GTConv(128, 128, heads=8, aggrs: sum, qkv_bias: False, gate: False)
    )
    (global_pool): MultiAggregation([
      SumAggregation(),
    ], mode=cat)
    (mu_mlp): MLP(
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=1, bias=True)
      )
    )
    (log_var_mlp): MLP(
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=1, bias=True)
      )
    )
  )
)
Number of params: 376 k
Epoch: 01, Loss: 58.6166, Val: 82.5431, Test: 68.1310
Epoch: 02, Loss: 35.8914, Val: 38.2512,