In [13]:
import sys
import os
from argparse import Namespace
import deepchem as dc
from rdkit import Chem
from sklearn import preprocessing
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import roc_auc_score, mean_absolute_error, mean_squared_error, log_loss, accuracy_score, f1_score
from ogb.lsc import PCQM4Mv2Dataset
import pickle
import torch
sys.path.append('../codes/')
from utils import *

sys.path.append('../codes/models/GNN/')
from fingerprint_model import get_fingerprint_GNN


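## Linear Probing

Probe whether GNN embeddings linearly encode simple RDKit ring/substructure descriptors. Each selected descriptor is clipped to a presence/absence label. A linear classifier is then fit on the embeddings of a 100k-molecule PCQM4Mv2 training subset and evaluated on the validation split.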



In [26]:
# Configuration: dataset root directory, seed values, and the GNN backbones to probe.
root = '../dataset'
seeds = list(range(10, 40, 10))  # -> [10, 20, 30]
gnns = ['gcn']

In [None]:
# Disabled alternative to the random sampling cell below: load fixed, pre-computed
# train/valid id subsets ('train_ids_100k', 'valid_train_ids_100k') from pickle files.
# NOTE(review): pickle.load can execute arbitrary code -- only enable this on trusted files.
# import pickle 
# with open('../pcqmv2_subset_dataset.pkl', 'rb') as handle:
#     ids = pickle.load(handle)
# train_ids = ids['train_ids_100k']

# with open('../pcqmv2_valid_subset_dataset.pickle', 'rb') as handle:
#     ids_valid = pickle.load(handle)
# valid_ids = ids_valid['valid_train_ids_100k']


In [6]:
# Load PCQM4Mv2 in SMILES-only mode (downloaded and extracted under `root` on first use),
# its index splits, and the raw CSV whose rows those split indices select.
dataset = PCQM4Mv2Dataset(root=root, only_smiles=True)
splits = dataset.get_idx_split()
raw_csv = os.path.join(root, 'pcqm4m-v2/raw/data.csv.gz')
all_data = pd.read_csv(raw_csv)


Downloading https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m-v2.zip


Downloaded 0.06 GB: 100%|████████████████████████████████████████████████████████████████████| 60/60 [00:02<00:00, 23.79it/s]


Extracting ../dataset/pcqm4m-v2.zip


In [12]:
# Sanity check: size of the validation split that is later used as the probing test set.
# Reads `splits` (assigned in the dataset-loading cell above) instead of `valid_ids`,
# which is only assigned in a later cell and would raise NameError on a fresh kernel.
len(splits['valid'])

73545

In [18]:
# Draw a random subset of the training split for probing; the full validation split
# is the held-out test set. A dedicated seeded generator makes the subset reproducible
# across kernel restarts without touching torch's global RNG state.
n_train_subset = 100_000
subset_seed = seeds[0]
subset_rng = torch.Generator().manual_seed(subset_seed)
train_perm = torch.randperm(len(splits['train']), generator=subset_rng)
train_ids = splits['train'][train_perm[:n_train_subset]]
valid_ids = splits['valid']

In [19]:
# SMILES strings for the sampled training rows and for the validation rows (the test set).
smiles_train, smiles_test = (
    all_data.iloc[row_ids]['smiles'].values.tolist()
    for row_ids in (train_ids, valid_ids)
)


In [20]:

# Compute descriptors for both splits with utils.compute_properties. The descriptor-name
# lists are kept for inspection; only the property frames are used below.
desc_names_train, properties_train = compute_properties(smiles_train)
desc_names_test, properties_test = compute_properties(smiles_test)

# Descriptors to probe. NOTE(review): assumes these are non-negative integer counts,
# so clipping at 1 yields presence (1) / absence (0) labels -- confirm in utils.
selected_prop = ['NumAromaticCarbocycles',
                 'NumAromaticRings',
                 'NumSaturatedRings',
                 'fr_aniline',
                 'fr_benzene',
                 'fr_bicyclic',
                 'fr_ketone',
                 'fr_methoxy',
                 'fr_para_hydroxylation',
                 'fr_pyridine']


def _presence_labels(properties, columns):
    """Return ``properties[columns]`` as a fresh array with every value > 1 set to 1.

    Args:
        properties: frame returned by ``compute_properties``.
        columns: descriptor column names to extract, in output-column order.

    Returns:
        A copy of the selected values (the input frame is not modified). NaNs are
        left as-is because ``NaN > 1`` is False.
    """
    labels = properties[columns].values.copy()
    labels[labels > 1] = 1
    return labels


y_train = _presence_labels(properties_train, selected_prop)
y_test = _presence_labels(properties_test, selected_prop)


________________________________________________________________________________
[Memory] Calling utils.compute_properties...
compute_properties([ 'O=CO[C@H](N(Cc1ccccc1)C)C',
  'Nc1ccc(cc1)Oc1cnc(nc1)C',
  'SCC[C@H](CC1CCCC1)CS',
  'COc1c2CCCCc2cc(c1C)C',
  'CC[C@@H]1COC(=O)N1C(C)C',
  'OC(=O)CC(=C)c1ccc2c(c1)CCCN2C',
  'CNC[C@@H](CSCC1CCCC1)C',
  'OC[C@H]1NC[C@@H]2[N@@](C1)CCOC2',
  'O[C@@H]1CC[C@H](C1)c1ccc2c(c1)OCCO2',
  'CCc1nc(C)nc(c1)[C@H]([C@H](N)C)CC',
  'NN1N(N)CN=C1O',
  'CCOC(=S)S[C@@H](CC)C',
  'N#Cc1cc(C=O)ccc1n1cnc(c1)C',
  'OCC[C@@H](CNCc1cccc(c1)C(=N)O)C',
  'CN(C(=O)CN[C@@H]1CCCc2c1ccs2)C',
  'CC1CCN(CC1)c1n(C)nc(c1C(=O)O)C',
  'O=N(=O)c1c(C)nn(c1C)CCC(F)(F)F',
  'C1=CN2C(=CCN=C2)NN1',
  'OCC(Cc1ccccc1)(CO)Cl',
  'COC(=O)c1ccc2c(c1)/C(=N\\O)/CC2(C)C',
  'CC([C@H](/C(=N\\c1ccc(cc1N)F)/O)C)C',
  'CC#CC1=CC=C(C1)N(=O)=...)


100%|███████████████████████████████████████████████████████████████████████████████| 100000/100000 [09:43<00:00, 171.48it/s]


_____________________________________________compute_properties - 588.9s, 9.8min
________________________________________________________________________________
[Memory] Calling utils.compute_properties...
compute_properties([ 'COc1ccccc1N[C@H](/C(=N\\C(=N)O)/O)C',
  'COc1ccccc1N[C@H](/C(=N\\C(=N)O)/O)C',
  'CC(/N=C(\\N/N=C/1\\C[C@H]2[C@@H]1CC=C2)/S)C',
  'CC(/N=C(\\N/N=C/1\\C[C@H]2[C@@H]1CC=C2)/S)C',
  'C/N=C(\\c1cc2c(s1)ccc(c2)F)/O',
  'C/N=C(\\CSc1ccccc1N(=O)=O)/O',
  'C[C@@H](/C(=N\\C)/O)Sc1ccc(cc1)N(=O)=O',
  'CNC(=O)c1sc2c(c1)c(F)ccc2',
  'C/N=C(\\Cc1cc(ccc1OC)C(=O)C)/O',
  'O=C(O[C@@H](/C(=N\\C1CC1)/O)C)Cn1cnnn1',
  'CSc1nnnn1c1ccc(cc1)O',
  'C/N=C(\\COC(=O)Cc1ccccc1OC)/O',
  'COC(=O)c1ccc(c(c1)OC)OCC(=N)O',
  'COC(=O)c1cccc(c1)OCC(=N)O',
  'O/C(=N/C(=N)O)/COc1cccc(c1)C(=O)C',
  'C/N=C(\\COc1cccc(c1)C(=O)C)/O',
  'C/C(=N\\c1ccc(cc1)S[C@H](C(=N)O)C)/O',
  'OC(=N)COc1ccc(cc1)N1CCCC1=O',
  'COc1ccc(cc1)SC1=N...)


100%|█████████████████████████████████████████████████████████████████████████████████| 73545/73545 [07:14<00:00, 169.23it/s]


_____________________________________________compute_properties - 438.8s, 7.3min


In [None]:
# Placeholder only: nothing below reads it (the probing loop builds its own per-GNN paths).
checkpoint_path = ""

In [27]:
# Linear probing: for each GNN, embed train/test SMILES with get_fingerprint_GNN using the
# checkpoint under ../../evalgnn/Prob_LS/checkpoints/, then fit one linear classifier per
# selected property on the last element of the returned embeddings and collect test metrics.
res_all_mean = {}
res_all_std = {}  # NOTE(review): never filled; kept so any later cell referencing it still works
for gnn in gnns:
    gnn_checkpoint = f'../../evalgnn/Prob_LS/checkpoints/{gnn}_virtual/checkpoint.pt'
    embd_train = get_fingerprint_GNN(smiles_train, f'{gnn}-virtual',
                                     checkpoint=gnn_checkpoint, batch_size=512)
    embd_test = get_fingerprint_GNN(smiles_test, f'{gnn}-virtual',
                                    checkpoint=gnn_checkpoint, batch_size=512)

    res = []
    probed_props = []  # names of the properties actually probed, aligned with `res`
    for n_label, prop in enumerate(selected_prop):  # y_* columns follow selected_prop order
        labels_train = y_train[:, n_label]
        labels_test = y_test[:, n_label]
        # `x == x` is False only for NaN, so this counts distinct non-missing classes.
        n_classes_train = len(np.unique(labels_train[labels_train == labels_train]))
        n_classes_test = len(np.unique(labels_test[labels_test == labels_test]))
        # A single-class label cannot be probed (classifier fit / ROC-AUC are undefined).
        if n_classes_train < 2 or n_classes_test < 2:
            continue
        res.append(linear_probing(embedding_train=embd_train[-1], y_train=labels_train,
                                  embeding_test=embd_test[-1], y_test=labels_test,
                                  task='classification', scale=True))
        probed_props.append(prop)

    # Index by the probed names only: using all of selected_prop would raise a length
    # mismatch (or mislabel columns) whenever a property was skipped above.
    con = pd.DataFrame(res, index=probed_props).T
    res_all_mean[gnn] = con



../../evalgnn/Prob_LS/checkpoints/gcn_virtual/checkpoint.pt


Iteration: 100%|███████████████████████████████████████████████████████████████████████████| 196/196 [01:09<00:00,  2.80it/s]


../../evalgnn/Prob_LS/checkpoints/gcn_virtual/checkpoint.pt


Iteration: 100%|███████████████████████████████████████████████████████████████████████████| 144/144 [00:51<00:00,  2.79it/s]


In [28]:
# Stack the per-GNN metric tables into one frame with a named (gnn, metric) row index,
# so the table shows readable level names instead of unnamed index levels.
pd.concat(res_all_mean, names=['gnn', 'metric'])

Unnamed: 0,Unnamed: 1,NumAromaticCarbocycles,NumAromaticRings,NumSaturatedRings,fr_aniline,fr_benzene,fr_bicyclic,fr_ketone,fr_methoxy,fr_para_hydroxylation,fr_pyridine
gcn,log_loss,0.187415,0.109151,0.422762,0.275579,0.187033,0.355493,0.113781,0.288564,0.226182,0.173817
gcn,accuracy,0.935903,0.967816,0.840683,0.900659,0.93608,0.860711,0.967041,0.894595,0.910069,0.937236
gcn,f1_score,0.935867,0.964757,0.792264,0.76568,0.936042,0.793409,0.906618,0.638962,0.710575,0.732988
gcn,roc_auc,0.980808,0.992162,0.877461,0.896192,0.980878,0.897034,0.98158,0.821035,0.908016,0.904066
