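# InfoMax molecular embeddings

Compute 300-dimensional molecule embeddings for the train and test SMILES with the pre-trained `gin_supervised_infomax` GIN from DGL-LifeSci (node embeddings average-pooled per molecule) and save them as `.npy` files.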

In [1]:
# NOTE: import order in this cell is significant (root cause not identified) — do not reorder these imports.
import dgl
from collections import defaultdict
from dgl.nn.pytorch.glob import AvgPooling
from dgllife.model import load_pretrained
from dgllife.model.model_zoo import *
from dgllife.utils import mol_to_bigraph, PretrainAtomFeaturizer, PretrainBondFeaturizer

import numpy as np
import pandas as pd
import pickle
import os

from rdkit import Chem

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

def collate(gs):
    """DataLoader ``collate_fn``: merge a list of DGL graphs into one batched graph."""
    batched = dgl.batch(gs)
    return batched

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



In [2]:
# Record the DGL version used to produce these embeddings (reproducibility).
dgl.__version__

'2.2.1'

In [3]:
# Record the PyTorch version used to produce these embeddings (reproducibility).
torch.__version__

'2.3.0'

In [4]:
# Directory containing train.csv and test.csv
path_data = "./data"

# Output directory for the .npy embeddings (same folder as the inputs)
path_out = path_data

## Load Pre-trained Model

In [5]:
# Other pre-trained GIN variants: gin_supervised_contextpred, gin_supervised_edgepred, gin_supervised_masking
model = load_pretrained('gin_supervised_infomax').to('cpu')
# Inference mode (disables dropout, freezes batch-norm statistics); also displays the architecture
model.eval()

Downloading gin_supervised_infomax_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_infomax.pth...


gin_supervised_infomax_pre_trained.pth:   0%|          | 0.00/7.45M [00:00<?, ?B/s]

Pretrained model loaded


GIN(
  (dropout): Dropout(p=0.5, inplace=False)
  (node_embeddings): ModuleList(
    (0): Embedding(120, 300)
    (1): Embedding(3, 300)
  )
  (gnn_layers): ModuleList(
    (0-4): 5 x GINLayer(
      (mlp): Sequential(
        (0): Linear(in_features=300, out_features=600, bias=True)
        (1): ReLU()
        (2): Linear(in_features=600, out_features=300, bias=True)
      )
      (edge_embeddings): ModuleList(
        (0): Embedding(6, 300)
        (1): Embedding(3, 300)
      )
      (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)

## Make Train Dataset

In [6]:
train = pd.read_csv(os.path.join(path_data, "train.csv"))
# Model inputs (SMILES strings) and regression target (pIC50)
smiles, y_train = train['Smiles'], train['pIC50']

In [7]:
# Preview the training SMILES column
smiles

0       CN[C@@H](C)C(=O)N[C@H](C(=O)N1C[C@@H](NC(=O)CC...
1       CC(C)(O)[C@H](F)CN1Cc2cc(NC(=O)c3cnn4cccnc34)c...
2       CC(C)(O)[C@H](F)CN1Cc2cc(NC(=O)c3cnn4cccnc34)c...
3       CC(C)(O)[C@H](F)CN1Cc2cc(NC(=O)c3cnn4cccnc34)c...
4       COc1cc2c(OC[C@@H]3CCC(=O)N3)ncc(C#CCCCCCCCCCCC...
                              ...                        
1947        O=C(Nc1nc2cc[nH]cc-2n1)c1cccc([N+](=O)[O-])c1
1948                CCCCn1c(NC(=O)c2cccc(Cl)c2)nc2ccccc21
1949    O=C(Nc1nc2cc(F)c(F)cc2[nH]1)c1cccc([N+](=O)[O-...
1950    OC[C@H]1C[C@@H](Nc2nc(Nc3ccccc3)ncc2-c2nc3cccc...
1951                                 CC(C)Oc1ccccc1C(N)=O
Name: Smiles, Length: 1952, dtype: object

In [8]:
# Preview the pIC50 targets (taken from the same rows as `smiles`, so aligned by position)
y_train

0       10.66
1       10.59
2       10.11
3       10.09
4       10.00
        ...  
1947     4.52
1948     4.52
1949     4.52
1950     4.38
1951     4.26
Name: pIC50, Length: 1952, dtype: float64

In [9]:
def mol2graph(smiles):
    """Convert SMILES strings into DGL graphs for the dgllife pre-trained GIN.

    Each molecule becomes a bidirected graph with self-loops, featurized with
    ``PretrainAtomFeaturizer`` / ``PretrainBondFeaturizer`` and canonical atom
    order, matching the inputs the pre-trained GIN models consume.

    Parameters
    ----------
    smiles : iterable of str
        SMILES strings to convert.

    Returns
    -------
    list of DGLGraph
        One graph per SMILES that could be parsed and featurized. Entries that
        fail are skipped (best-effort), so the result can be SHORTER than
        ``smiles`` and then no longer lines up by position with per-row labels.
        A warning listing the skipped positions is emitted whenever that happens.
    """
    import warnings

    # Loop-invariant: build the featurizers once instead of once per molecule
    # (reusing them across molecules is presumably safe, as in dgllife's own
    # dataset helpers — verify if the featurizer classes change).
    node_featurizer = PretrainAtomFeaturizer()
    edge_featurizer = PretrainBondFeaturizer()

    graphs = []
    skipped = []
    for idx, smi in enumerate(smiles):
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:  # RDKit could not parse the SMILES
                skipped.append(idx)
                continue
            g = mol_to_bigraph(mol, add_self_loop=True,
                               node_featurizer=node_featurizer,
                               edge_featurizer=edge_featurizer,
                               canonical_atom_order=True)
        except Exception:
            # Keep the best-effort skip, but no longer swallow
            # KeyboardInterrupt/SystemExit like the previous bare `except:`.
            skipped.append(idx)
            continue
        graphs.append(g)

    if skipped:
        # Skipping silently would misalign graphs with labels/IDs downstream.
        warnings.warn(
            f"mol2graph: skipped {len(skipped)} of {len(skipped) + len(graphs)} molecules "
            f"(input positions {skipped}); returned graphs are no longer aligned with the input."
        )
    return graphs

In [10]:
graphs = mol2graph(smiles)
# mol2graph skips SMILES it cannot convert, and y_train is matched to the
# embeddings purely by position — any dropped molecule would silently
# misalign features and labels, so fail loudly instead.
assert len(graphs) == len(smiles), (
    f"{len(smiles) - len(graphs)} training SMILES could not be converted to graphs"
)

In [11]:
# Inspect two graphs: node features atomic_number/chirality_type, edge features bond_type/bond_direction_type
graphs[:2]

[Graph(num_nodes=72, num_edges=228,
       ndata_schemes={'atomic_number': Scheme(shape=(), dtype=torch.int64), 'chirality_type': Scheme(shape=(), dtype=torch.int64)}
       edata_schemes={'bond_type': Scheme(shape=(), dtype=torch.int64), 'bond_direction_type': Scheme(shape=(), dtype=torch.int64)}),
 Graph(num_nodes=39, num_edges=127,
       ndata_schemes={'atomic_number': Scheme(shape=(), dtype=torch.int64), 'chirality_type': Scheme(shape=(), dtype=torch.int64)}
       edata_schemes={'bond_type': Scheme(shape=(), dtype=torch.int64), 'bond_direction_type': Scheme(shape=(), dtype=torch.int64)})]

In [12]:
# Number of molecules successfully converted (compare with len(smiles))
len(graphs)

1952

In [13]:
def graph2infomax(graphs, batch_size=256, device='cpu', encoder=None):
    """Embed molecule graphs with the pre-trained GIN and average-pool per molecule.

    Parameters
    ----------
    graphs : list of DGLGraph
        Graphs from ``mol2graph``; they must carry the ``atomic_number`` and
        ``chirality_type`` node features and the ``bond_type`` and
        ``bond_direction_type`` edge features.
    batch_size : int, default 256
        Number of graphs per forward pass.
    device : str or torch.device, default 'cpu'
        Device on which inference runs.
    encoder : torch.nn.Module, optional
        GNN to use. Defaults to the notebook-level ``model`` loaded earlier.

    Returns
    -------
    numpy.ndarray
        One row per input graph, in input order (the loader does not shuffle).
        The number of columns equals the encoder's node-embedding size.

    Raises
    ------
    ValueError
        If ``graphs`` is empty.
    """
    if len(graphs) == 0:
        raise ValueError("graph2infomax: no graphs to embed")

    if encoder is None:
        encoder = model  # notebook-global pre-trained GIN
    encoder = encoder.to(device)
    # Inference only: force eval mode so dropout/batch-norm behave
    # deterministically even if the caller left the module in training mode.
    encoder.eval()

    data_loader = DataLoader(graphs, batch_size=batch_size, collate_fn=collate, shuffle=False)
    readout = AvgPooling()  # mean of node embeddings per graph

    mol_emb = []
    with torch.no_grad():
        for bg in data_loader:
            # Moving the batched graph also moves its node/edge feature tensors.
            bg = bg.to(device)
            nfeats = [bg.ndata.pop('atomic_number'), bg.ndata.pop('chirality_type')]
            efeats = [bg.edata.pop('bond_type'), bg.edata.pop('bond_direction_type')]
            node_repr = encoder(bg, nfeats, efeats)
            mol_emb.append(readout(bg, node_repr).cpu())
    return torch.cat(mol_emb, dim=0).numpy()

In [14]:
# One pooled embedding row per training graph, in the same order as `graphs`
mol_emb = graph2infomax(graphs=graphs)

In [15]:
# Wrap the embedding matrix in a DataFrame (integer column labels 0..dim-1)
fps_infomax = pd.DataFrame(data=mol_emb)

In [16]:
# Train embeddings: one row per molecule, one column per embedding dimension
fps_infomax

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
0,0.012840,-0.123612,0.019478,-0.095734,0.000431,-0.017829,-0.078715,0.065078,0.238945,-0.056917,...,0.080919,0.050735,0.000909,-0.073446,0.015369,0.014067,0.015880,-0.064515,0.082496,0.105162
1,-0.028148,0.031318,0.121894,-0.290171,-0.290834,-0.015077,-0.095040,0.004878,0.048621,0.133710,...,-0.100384,-0.025584,-0.079524,0.218776,-0.066931,0.012332,0.010995,0.027660,0.164179,0.148279
2,0.063364,0.027728,0.087096,-0.149783,-0.127030,-0.011410,-0.060086,0.001026,0.025006,-0.008106,...,0.058023,0.075595,-0.033144,0.090290,0.024366,0.014276,-0.034440,0.051902,0.171981,0.197170
3,0.097603,0.019553,0.065902,-0.077765,-0.124646,-0.000896,-0.086592,-0.010941,0.026417,0.031529,...,-0.026220,0.100599,0.001874,0.101770,-0.064440,0.014887,-0.046015,0.098437,0.181150,0.130759
4,-0.025283,-0.119529,-0.089081,0.063854,0.068520,-0.021033,0.000180,0.007412,0.110019,-0.083853,...,-0.023524,0.150811,-0.026152,0.049845,-0.093691,0.013429,-0.000075,0.029468,-0.076713,0.203222
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1947,0.211332,-0.387519,0.127945,0.058811,-0.044250,0.004553,0.094400,-0.007075,-0.183452,-0.039473,...,0.054417,0.208933,-0.016199,0.228975,-0.285116,0.016910,0.068193,-0.399502,-0.128271,0.094290
1948,-0.072240,0.061641,0.184552,0.157616,-0.021388,0.007599,0.259612,-0.026218,0.054832,0.050813,...,-0.181979,-0.243417,-0.101053,0.020782,-0.200466,0.017252,-0.070357,-0.092547,-0.211524,-0.179452
1949,0.022348,-0.209482,0.003446,0.228658,-0.086536,0.035788,-0.034413,-0.010757,-0.025031,-0.012417,...,-0.235875,0.048142,-0.007963,0.132116,0.036141,0.017842,-0.061312,0.017715,0.011352,0.038577
1950,-0.136114,-0.152204,0.095464,0.098265,-0.048130,-0.000218,0.041317,0.012223,0.144368,-0.172915,...,-0.246787,0.288097,-0.136169,0.052516,-0.006338,0.016241,0.021214,-0.008134,0.104631,0.192539


In [17]:
# Save the train embeddings as a plain 2-D array (pandas column labels are not stored)
np.save(
    os.path.join(path_out, "infomax300.train.npy"),
    fps_infomax.to_numpy(),
    allow_pickle=False,
)

## Make Test Dataset

In [18]:
# Test SMILES -> graphs, using the same conversion as for the train set
test = pd.read_csv(os.path.join(path_data, "test.csv"))
smiles_test = test.loc[:, 'Smiles']
graphs_test = mol2graph(smiles_test)

In [19]:
# Number of test molecules converted (compare with len(smiles_test))
len(graphs_test)

113

In [20]:
# Embed the test graphs exactly as the train graphs were embedded
mol_emb_test = graph2infomax(graphs=graphs_test)
fps_infomax_test = pd.DataFrame(data=mol_emb_test)
fps_infomax_test

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
0,0.126515,-0.070788,-0.125801,-0.101102,-0.511209,-0.032890,0.021499,-0.063034,-0.106494,-0.069529,...,0.033377,0.280455,-0.149791,0.075763,0.116040,0.014384,0.048331,-0.289128,-0.173506,0.010870
1,0.007390,-0.191097,0.201529,-0.104568,0.019750,0.009533,0.030039,-0.001289,0.039038,0.182121,...,-0.028664,-0.012639,-0.082434,0.156703,-0.022110,0.014307,0.071994,0.049064,-0.197774,-0.007805
2,0.117358,-0.079716,0.311567,-0.078423,-0.129427,0.011311,-0.029453,-0.068265,0.053194,0.196659,...,-0.209977,0.074811,-0.055765,0.046730,-0.061311,0.014415,0.035847,0.017876,-0.031516,0.417887
3,0.001371,-0.177162,0.069632,0.031912,0.153429,0.009424,-0.039644,0.020556,0.055848,0.070463,...,-0.067156,-0.086307,-0.097164,0.068690,0.008971,0.015352,-0.000318,0.042622,-0.135696,0.255778
4,-0.074349,-0.083011,0.082345,-0.000260,0.072066,0.006270,-0.003263,-0.017865,0.087652,0.081751,...,-0.100356,-0.087644,-0.176923,0.122760,-0.057769,0.014668,0.015914,0.068808,-0.091183,0.391046
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
108,0.016622,-0.210138,0.202166,-0.176142,0.079797,0.010415,0.066732,0.005604,0.018268,0.153305,...,0.077292,-0.009808,-0.095489,0.050751,0.135017,0.014747,0.051156,0.093848,-0.145246,0.199565
109,0.041986,-0.051799,-0.029769,-0.347117,-0.378878,-0.033246,-0.107170,-0.103580,0.035946,0.174398,...,0.179841,0.091508,-0.144260,0.234152,-0.161647,0.013604,0.037245,-0.268749,-0.141467,0.207426
110,0.099336,-0.044724,0.323531,-0.057606,0.082990,0.010310,0.033376,-0.054251,0.093466,0.147529,...,-0.187741,0.043302,-0.057870,0.100413,-0.083857,0.014305,0.055422,0.098623,-0.136232,0.330306
111,0.029937,-0.029792,0.226532,-0.049432,0.100409,0.020880,0.009613,0.032810,0.187503,0.164056,...,0.078971,0.022854,-0.037969,0.119911,-0.059762,0.015853,0.077687,0.112469,-0.165969,0.298141


In [21]:
# Save the test embeddings as a plain 2-D array (pandas column labels are not stored)
np.save(
    os.path.join(path_out, "infomax300.test.npy"),
    fps_infomax_test.to_numpy(),
    allow_pickle=False,
)