In [2]:
import os
import sys
from pathlib import Path

# A notebook kernel does not define __file__, so it is set by hand here to make
# the path arithmetic below work.
# NOTE(review): this is an absolute, machine-specific path - adjust it per checkout.
__file__ = '/home/beangoben/Downloads/pom-mix-main/scripts_pom/make_embeddings.py'

# parents[0] is the directory holding the script; parents[1] is the directory
# one level above it, used as the base for every other path in this notebook.
script_path = Path(__file__)
script_dir = script_path.parents[0]
base_dir = script_path.parents[1]

# Make <base_dir>/src importable (presumably where the `pom` and `dataloader`
# packages imported below live - confirm against the repository layout).
sys.path.append(str(base_dir / 'src/'))

import json
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch_geometric  as pyg
import tqdm
from ml_collections import ConfigDict
import rdkit
import rdkit.Chem
import pom.data
import pom.gnn.graphnets
from dataloader.representations import graph_utils

/home/beangoben/Downloads/pom-mix-main/datasets


# Load pretrained model

In [3]:
# Directory containing the pretrained checkpoint: `hparams.json` (constructor
# keyword arguments) and `gnn_embedder.pt` (state dict).
pom_path = base_dir / "scripts_pom/gs-lf_models/pretrained_pom"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Running on: {device}')

# Read the hyperparameters inside a context manager so the file handle is
# closed deterministically instead of being left to the garbage collector.
with open(pom_path / 'hparams.json', 'r') as hparams_file:
    hp_gnn = ConfigDict(json.load(hparams_file))

# Rebuild the network with the stored hyperparameters, then restore its weights.
embedder = pom.gnn.graphnets.GraphNets(node_dim=graph_utils.NODE_DIM, edge_dim=graph_utils.EDGE_DIM, **hp_gnn)
# NOTE(review): torch.load unpickles the checkpoint file; only load checkpoints
# from a trusted source.
embedder.load_state_dict(torch.load(pom_path / 'gnn_embedder.pt', map_location=device))
embedder = embedder.to(device)
# Switch to evaluation mode (the model contains Dropout layers). eval() returns
# the module itself, so leaving it as the last expression displays the model.
embedder.eval()

Running on: cpu


GraphNets(
  (layers): ModuleList(
    (0): MetaLayer(
      edge_model=EdgeFiLMModel(
      (gamma): Sequential(
        (0): Linear(371, 14, bias=True)
      )
      (gamma_act): Sigmoid()
      (beta): Sequential(
        (0): Linear(371, 14, bias=True)
      )
    ),
      node_model=NodeAttnModel(
      (self_attn): GAT(85, 85, num_layers=1)
      (output_mlp): Sequential(
        (0): Linear(85, 320, bias=True)
        (1): Dropout(p=0.0, inplace=False)
        (2): SELU()
        (3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (4): Linear(320, 85, bias=True)
      )
      (dropout_layer): Dropout(p=0.1, inplace=False)
      (norm1): LayerNorm((85,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((85,), eps=1e-05, elementwise_affine=True)
    ),
      global_model=GlobalPNAModel(
      (pool): MultiAggregation([
        MeanAggregation(),
        StdAggregation(),
        MaxAggregation(),
        MinAggregation(),
      ], mode=cat)
      (global

# Load external dataset

In [4]:
# Load the external dataset; the path is relative to the current working directory.
df = pd.read_csv('dream_full.csv')

# Distinct SMILES strings, kept in order of first appearance. Each one is
# embedded once, however many rows of `df` refer to it.
smi = df['CanonicalSMILES'].drop_duplicates().tolist()

print(df.columns)
df

Index(['CID', 'OdorName', 'CAS', 'CanonicalSMILES', 'MolecularWeight',
       'Odor dilution', 'Subject', 'CAN OR CAN'T SMELL',
       'KNOW OR DON'T KNOW THE SMELL', 'THE ODOR IS:',
       'HOW STRONG IS THE SMELL?', 'HOW PLEASANT IS THE SMELL?',
       'HOW FAMILIAR IS THE SMELL?', 'EDIBLE', 'BAKERY', 'SWEET', 'FRUIT',
       'FISH', 'GARLIC', 'SPICES', 'COLD', 'SOUR', 'BURNT', 'ACID', 'WARM',
       'MUSKY', 'SWEATY', 'AMMONIA/URINOUS', 'DECAYED', 'WOOD', 'GRASS',
       'FLOWER', 'CHEMICAL'],
      dtype='object')


Unnamed: 0,CID,OdorName,CAS,CanonicalSMILES,MolecularWeight,Odor dilution,Subject,CAN OR CAN'T SMELL,KNOW OR DON'T KNOW THE SMELL,THE ODOR IS:,...,ACID,WARM,MUSKY,SWEATY,AMMONIA/URINOUS,DECAYED,WOOD,GRASS,FLOWER,CHEMICAL
0,16741,2-Phenylethyl isothiocyanate,2257-09-2,C1=CC=C(C=C1)CCN=C=S,163.24,"1/1,000",1,True,False,,...,0.00,0.00,0.00,0.0,0.00,0.00,0.00,0.00,0.00,0.00
1,16741,2-Phenylethyl isothiocyanate,2257-09-2,C1=CC=C(C=C1)CCN=C=S,163.24,"1/100,000",1,True,False,,...,0.00,0.00,0.00,0.0,0.00,0.00,0.00,0.00,0.00,0.00
2,16741,2-Phenylethyl isothiocyanate,2257-09-2,C1=CC=C(C=C1)CCN=C=S,163.24,"1/1,000",2,True,False,,...,0.00,0.00,0.00,0.0,0.00,0.00,0.00,0.02,0.04,0.24
3,16741,2-Phenylethyl isothiocyanate,2257-09-2,C1=CC=C(C=C1)CCN=C=S,163.24,"1/100,000",2,True,False,,...,0.00,0.00,0.00,0.0,0.00,0.00,0.00,0.00,0.00,0.00
4,16741,2-Phenylethyl isothiocyanate,2257-09-2,C1=CC=C(C=C1)CCN=C=S,163.24,"1/1,000",3,True,False,,...,0.17,0.00,0.00,0.0,0.00,0.64,0.00,0.00,0.00,0.00
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
48897,5315892,Cinnamyl alcohol,104-54-1,C1=CC=C(C=C1)C=CCO,134.17,"1/100,000",47,True,False,,...,0.00,0.00,0.00,0.0,0.00,0.00,0.00,0.00,0.16,0.00
48898,5315892,Cinnamyl alcohol,104-54-1,C1=CC=C(C=C1)C=CCO,134.17,"1/1,000",48,True,False,,...,0.00,0.01,0.00,0.0,0.00,0.00,0.00,0.00,0.00,0.00
48899,5315892,Cinnamyl alcohol,104-54-1,C1=CC=C(C=C1)C=CCO,134.17,"1/100,000",48,True,False,,...,0.00,0.00,0.00,0.0,0.00,0.00,0.31,0.00,0.00,0.00
48900,5315892,Cinnamyl alcohol,104-54-1,C1=CC=C(C=C1)C=CCO,134.17,"1/1,000",49,True,False,,...,0.00,0.00,0.15,0.0,0.44,0.04,0.00,0.00,0.00,0.00


# Create embeddings

In [5]:
# Convert every unique SMILES string into a graph object.
graphs = [graph_utils.from_smiles(s) for s in smi]

# GraphDataset is given one target per graph; only embeddings are needed here,
# so constant 0.0 placeholders are passed.
graph_dataset = pom.data.GraphDataset(graphs, [0.0] * len(smi))

# A single batch holding the whole dataset, in the original order
# (shuffle=False), so the rows of `emb` line up with the entries of `smi`.
# `dataset` keeps referring to the loader, as before.
dataset = pyg.loader.DataLoader(graph_dataset, batch_size=len(graph_dataset), shuffle=False)
x, y = next(iter(dataset))

# The model was moved to `device`; the batch has to live on the same device or
# the forward pass fails as soon as CUDA is available.
# (assumes `x` is a torch_geometric batch object exposing `.to()` - TODO confirm)
x = x.to(device)

# Pure inference: skip autograd bookkeeping, and call the module itself rather
# than `.forward` so that registered hooks run.
with torch.no_grad():
    emb = embedder(x)
emb.shape

torch.Size([474, 196])

In [7]:
# Display the first row of `emb` as a quick sanity check of the computed values.
emb[0]

tensor([ 0.1975,  1.4617, -1.4921, -2.3036, -0.4989, -0.4508, -1.9159,  2.7414,
         2.4319,  1.0697, -2.9949, -1.2417,  4.7734, -2.2365, -0.6493,  6.2112,
        -1.8757, -3.8271,  1.1387,  2.9507,  4.5033,  2.8877, -2.6729, -5.7986,
         2.1855,  1.1956,  2.9457,  1.0308,  3.4870,  7.4326,  0.5472, -0.3892,
        -4.1249,  1.1154,  0.8305, -2.0996, -0.5006,  0.7424, -2.7926, -1.1849,
         5.1397, -0.0919,  2.5937, -0.1401,  1.6056, -1.8571, -0.6214,  4.9353,
        -0.2975, -3.7349,  1.4572,  4.2110,  2.1373,  0.1294, -1.9961, -4.2239,
         0.3521, -0.7383, -1.0554, -0.3211,  3.0618, -4.5569, -4.4790,  0.5087,
        -0.2293, -0.3873, -0.6165, -1.6357, -0.1500, -2.3282,  2.9495,  0.0407,
         3.0001, -1.3781,  1.1923, -1.6610, -3.5916, -1.0789,  1.7773, -3.3694,
        -2.8828, -2.4767, -2.2561,  2.3635, -1.2070, -0.9065, -2.4478, -1.0420,
        -6.3652,  0.3141, -3.0892, -1.1920,  4.7504,  0.5549,  1.6700, -3.7833,
        -2.2224,  3.7712,  6.4208,  0.33