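## Inference with a pretrained GraphGPS (GPS) model

Load a pretrained GPS model, predict on an external evaluation set, and report regression metrics with a parity plot.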

In [35]:
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from scipy.stats import pearsonr
from sklearn.metrics import mean_absolute_error, mean_squared_error

from torch_geometric.graphgym.loader import create_dataset, DataLoader, load_pyg
from torch_geometric.graphgym.config import cfg, set_cfg, assert_cfg
from torch_geometric.graphgym.checkpoint import MODEL_STATE
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.utils.device import auto_select_device

import graphgps  # register modules
from graphgps.finetuning import load_pretrained_model_cfg, init_model_from_pretrained
from graphgps.loader.dataset.custom_datasets import AQSOL, OPERA, SDF, ZINC

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [46]:
# Reset the global GraphGym config to its defaults first, so re-running this
# cell does not carry over values from a previous run, then layer the
# inference YAML on top.
# TODO: hardcoded absolute paths; move them to a configurable location.
cfg_file = "/home/loschen/calc/GraphGPS/configs/GPS/molfile-inference.yaml"
set_cfg(cfg)
cfg.merge_from_file(cfg_file)

# Merge in settings from the pretrained run referenced by cfg.pretrained.
# NOTE(review): presumably returns the same (mutated) global cfg object that
# create_dataset()/create_model() read; confirm in graphgps.finetuning.
cfg = load_pretrained_model_cfg(cfg)

# Point the dataset at the evaluation file, overriding whatever the YAML set.
eval_data_path = "/NEWDATA/moldata/sampl6/testset_logp2.csv"
cfg.dataset.name = eval_data_path

cfg

CfgNode({'print': 'both', 'accelerator': 'auto', 'devices': None, 'out_dir': 'results', 'cfg_dest': 'config.yaml', 'custom_metrics': [], 'seed': 0, 'round': 5, 'tensorboard_each_run': False, 'tensorboard_agg': True, 'num_workers': 0, 'num_threads': 6, 'metric_best': 'mae', 'metric_agg': 'argmin', 'view_emb': False, 'gpu_mem': False, 'benchmark': False, 'share': CfgNode({'dim_in': 1, 'dim_out': 1, 'num_splits': 1}), 'dataset': CfgNode({'name': '/NEWDATA/moldata/sampl6/testset_logp2.csv', 'format': 'PyG-SDF', 'dir': './datasets', 'task': 'graph', 'task_type': 'regression', 'transductive': False, 'split': [0.8, 0.1, 0.1], 'shuffle_split': True, 'split_mode': 'random', 'encoder': True, 'encoder_name': 'db', 'encoder_bn': True, 'node_encoder': True, 'node_encoder_name': 'TypeDictNode+RWSE', 'node_encoder_bn': False, 'edge_encoder': True, 'edge_encoder_name': 'TypeDictEdge', 'edge_encoder_bn': False, 'encoder_dim': 128, 'edge_dim': 128, 'edge_train_mode': 'all', 'edge_message_ratio': 0.8, 'e

In [48]:
# Build the evaluation dataset described by cfg.dataset (set in the config cell).
# Alternative ways to construct an evaluation set directly:
#   SDF("./datasets/SAMPL6", name="/NEWDATA/moldata/sampl6/sampl6.sdf")
#   OPERA("./datasets/OPERA", name="OPERA_LogP")
dataset = create_dataset()

# Drop the train/val/test split-index attributes so the loader iterates over the
# whole dataset as plain graphs. NOTE(review): presumably these are dataset-level
# tensors that should not be collated per graph; confirm. Guarded with hasattr
# because not every dataset format attaches them.
for split_attr in ('train_graph_index', 'val_graph_index', 'test_graph_index'):
    if hasattr(dataset.data, split_attr):
        delattr(dataset.data, split_attr)

# Keep the original order (shuffle=False) so predictions line up with the input rows.
loader = DataLoader(dataset, batch_size=32, shuffle=False,
                    num_workers=cfg.num_workers, pin_memory=True)
loader


Processing...


    no                                             SMILES   
0    1  CC1=C(C2=C(N1C(=O)C3=CC=C(C=C3)Cl)C=CC(=C2)OC)...  \
1    2                   CC(C)NCC(COC1=CC=CC2=CC=CC=C21)O   
2    3                                     C1=CC=C(C=C1)O   
3    4           C1=CC=C(C(=C1)CC(=O)O)NC2=C(C=CC=C2Cl)Cl   
4    5                                C1=CC=C(C=C1)C(=O)O   
5    6                             C1=CC=C(C(=C1)C(=O)O)N   
6    7                        CN1C2=C(C(=O)N(C1=O)C)NC=N2   
7    8     C1=COC(=C1)CNC2=CC(=C(C=C2C(=O)O)S(=O)(=O)N)Cl   
8    9                                            CC(=S)N   
9   10  C1CN(CCC1N2C3=C(C=C(C=C3)Cl)NC2=O)CCCN4C5=CC=C...   
10  11                                 CC(=O)NC1=CC=CC=C1   
11  12                             C1=CC=C(C(=C1)C(=O)O)O   
12  13                                            CC(=O)O   
13  14                                    C1=CC=C(C=C1)CN   
14  15                              COC(=O)C1=CC=C(C=C1)O   
15  16                  

Processing train dataset: 100%|██████████| 53/53 [00:00<00:00, 11285.31it/s]
Done!
100%|██████████| 53/53 [00:00<00:00, 965.91it/s]


<torch_geometric.loader.dataloader.DataLoader at 0x7f35aa7da6e0>

In [49]:
# Choose the compute device (presumably resolves cfg.accelerator, which the
# prediction cell reads), then build the architecture from cfg and load the
# pretrained weights from cfg.pretrained.dir.
auto_select_device()
model = init_model_from_pretrained(
    create_model(),
    cfg.pretrained.dir,
    cfg.pretrained.freeze_main,
    cfg.pretrained.reset_prediction_head,
    seed=cfg.seed,
)

# Inference only: disable dropout and use batch-norm running statistics.
model.eval()

# Display the architecture and whether the prediction head was reset
# (if it was, the head presumably carries no pretrained weights).
(model, cfg.pretrained.reset_prediction_head)

(GraphGymModule(
   (model): GPSModel(
     (encoder): FeatureEncoder(
       (node_encoder): Concat2NodeEncoder(
         (encoder1): TypeDictNodeEncoder(
           (encoder): Embedding(28, 36)
         )
         (encoder2): RWSENodeEncoder(
           (raw_norm): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
           (pe_encoder): Linear(in_features=20, out_features=28, bias=True)
         )
       )
       (edge_encoder): TypeDictEdgeEncoder(
         (encoder): Embedding(4, 64)
       )
     )
     (layers): Sequential(
       (0): GPSLayer(
         summary: dim_h=64, local_gnn_type=GINE, global_model_type=Transformer, heads=4
         (local_model): GINEConv(nn=Sequential(
           (0): Linear(64, 64, bias=True)
           (1): ReLU()
           (2): Linear(64, 64, bias=True)
         ))
         (self_attn): MultiheadAttention(
           (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
         )
   

In [50]:
@torch.no_grad()
def predict(model, loader, device=None, head=None):
    """Run ``model`` over every batch in ``loader`` and collect predictions.

    Args:
        model: module whose forward returns ``(pred, true)``, or
            ``(pred, true, extra_stats)`` when ``head == 'inductive_edge'``.
        loader: iterable of batches exposing ``.to(device)``.
        device: device each batch is moved to; defaults to
            ``torch.device(cfg.accelerator)``.
        head: head type deciding how many values forward returns; defaults
            to ``cfg.gnn.head``.

    Returns:
        pandas.DataFrame with columns ``y_pred`` and ``y``: one row per
        flattened prediction/target element, in loader order. The frame is
        empty if ``loader`` yields no batches.
    """
    if device is None:
        device = torch.device(cfg.accelerator)
    if head is None:
        head = cfg.gnn.head

    # Never predict in training mode (dropout / batch-norm statistic updates).
    model.eval()

    predictions = []
    targets = []
    for batch in loader:
        batch = batch.to(device)
        if head == 'inductive_edge':
            # The extra stats are not needed for the prediction table.
            pred, true, _extra_stats = model(batch)
        else:
            pred, true = model(batch)
        # Blocking copy: a non_blocking device-to-host copy can be read before
        # the transfer has finished, which silently yields wrong values.
        predictions.append(pred.detach().cpu())
        targets.append(true.detach().cpu())

    if not predictions:
        # torch.cat would raise on an empty list.
        return pd.DataFrame(columns=['y_pred', 'y'])

    return pd.DataFrame({
        'y_pred': torch.cat(predictions, dim=0).reshape(-1).numpy(),
        'y': torch.cat(targets, dim=0).reshape(-1).numpy(),
    })

# Predict on the full evaluation set; each row pairs a prediction with its target.
df_pred = predict(model=model, loader=loader)
df_pred


Unnamed: 0,y_pred,y
0,3.797744,3.52
1,3.008636,3.48
2,1.442594,1.48
3,4.30298,4.51
4,1.927755,1.96
5,1.21767,1.26
6,-0.133645,0.0
7,1.263541,2.56
8,-0.084778,-0.26
9,3.466144,3.9


In [51]:
# Regression metrics on the evaluation set: MAE, RMSE, Pearson r and its square.
# NOTE: r**2 is the squared correlation. In general it is not the same quantity as
# the coefficient of determination (sklearn.metrics.r2_score).
mae = mean_absolute_error(df_pred.y, df_pred.y_pred)
rmse = np.sqrt(mean_squared_error(df_pred.y, df_pred.y_pred))
pearson_corr = pearsonr(df_pred.y, df_pred.y_pred)[0]
r_squared = pearson_corr ** 2
print(
    f"Mean Absolute Error (MAE): {mae:.2f} Root Mean Squared Error (RMSE): {rmse:.2f} "
    f"Pearson r: {pearson_corr:.2f} R^2: {r_squared:.2f} N:{len(df_pred)}"
)

# Parity plot: predictions on the x-axis, reference values on the y-axis.
fig = px.scatter(
    df_pred,
    x="y_pred",
    y="y",
    labels={"y_pred": "Predicted", "y": "Observed"},
    title="Predicted vs. observed",
)
fig.show()

Mean Absolute Error (MAE): 0.24 (RMSE): 0.39 R: 0.95 N:53
