In [1]:
MAX_COUNT = 100

In [2]:
import os
os.chdir("Data")
!ls

dataset.h5          odor_pair_full.json


In [3]:
import json
with open("odor_pair_full.json") as f:
    dataset = json.load(f)

len(dataset),dataset[0]

(166814,
 {'mol1': 'CCCCC/C=C/C(=O)OC',
  'mol1_notes': ['violet',
   'sweet',
   'oily',
   'melon',
   'pear',
   'hairy',
   'costus',
   'fruity',
   'violet leaf',
   'waxy',
   'fresh',
   'green'],
  'mol2': 'CCCCCOC(=O)CCC',
  'mol2_notes': ['cherry',
   'sweet',
   'pineapple',
   'fruity',
   'banana',
   'tropical'],
  'blend_notes': ['animal', 'fruity', 'waxy']})

In [4]:
import collections
import tqdm

notes_to_pairs = collections.defaultdict(list)
for d in tqdm.tqdm(dataset):
    notes_to_pairs[frozenset(d["blend_notes"])].append((d["mol1"],d["mol2"]))

len(notes_to_pairs), next(iter(notes_to_pairs.items()))

100%|██████████████████████████████| 166814/166814 [00:00<00:00, 1930136.33it/s]


(2921,
 (frozenset({'animal', 'fruity', 'waxy'}),
  [('CCCCC/C=C/C(=O)OC', 'CCCCCOC(=O)CCC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCOC(=O)C(C)CC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCOC(=O)CCC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCCCCCCC(=O)OCC')]))

In [5]:
import random

positives = dict()
for n, pairs in tqdm.tqdm(notes_to_pairs.items()):
    for p1 in pairs:
        vals = list(random.sample(pairs,MAX_COUNT)) if len(pairs) > MAX_COUNT else pairs
        positives[p1] = [p2 for p2 in vals if p1 != p2]
len(positives), next(iter(positives.items()))

100%|██████████████████████████████████████| 2921/2921 [00:03<00:00, 790.40it/s]


(166814,
 (('CCCCC/C=C/C(=O)OC', 'CCCCCOC(=O)CCC'),
  [('CCCCC/C=C/C(=O)OC', 'CCCCCCCCOC(=O)C(C)CC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCOC(=O)CCC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCCCCCCC(=O)OCC')]))

In [6]:
missing = len([p for p, ps in positives.items() if len(ps) < MAX_COUNT])
missing, missing/len(dataset)

(31871, 0.19105710551872146)

In [7]:
mol_sets = collections.defaultdict(list)
for i,d in enumerate(dataset):
    mol_sets[d["mol1"]].append(i)
    mol_sets[d["mol2"]].append(i)

In [8]:
pairs = []
pair_to_notes = dict()
for d in dataset:
    pair = (d["mol1"],d["mol2"])
    pairs.append(pair)
    pair_to_notes[pair] = set(d["blend_notes"])

In [9]:
notes_sets = dict()
for i,d in enumerate(dataset):
    notes_sets[i] = set(d["blend_notes"])

In [10]:
# Hard negatives are anchor/negatives that share a molecule
# but do not have any notes in common.
negatives = collections.defaultdict(list)
for (mol,idcs) in tqdm.tqdm(mol_sets.items()):
    for i in idcs:
        p1 = notes_sets[i]
        for j in idcs:
            p2 = notes_sets[j]
            if bool(p1 & p2):
                continue
            negatives[pairs[i]].append(pairs[j])
len(negatives)

100%|██████████████████████████████████████| 2971/2971 [00:13<00:00, 215.94it/s]


166788

In [11]:
missing = len([p for p, ns in negatives.items() if len(ns) < MAX_COUNT])
missing, missing/len(dataset)

(94025, 0.5636517318690277)

In [12]:
# Supplement with random negatives if we have less than the required
# number of hard negatives.
for p, ns in tqdm.tqdm(negatives.items()):
    notes = pair_to_notes[pair]
    while len(ns) < MAX_COUNT:
        other = random.choice(dataset)
        other_pair = (other["mol1"],other["mol2"])
        other_notes = pair_to_notes[other_pair]
        if bool(notes & other_notes):
            continue
        ns.append(other_pair)
    negatives[p] = list(random.sample(pairs,MAX_COUNT))

missing = len([p for p, ns in negatives.items() if len(ns) != MAX_COUNT])
len(negatives), missing, missing/len(dataset)      

100%|████████████████████████████████| 166788/166788 [00:10<00:00, 15531.19it/s]


(166788, 0, 0.0)

In [13]:
next(iter(negatives.items()))

(('CCCCC/C=C/C(=O)OC', 'CCCCCOC(=O)CCC'),
 [('CC(C)(CCCC(=C)C=C)O', 'CC(CCCC(C)(C)O)C=C'),
  ('CCCC1=NC=CN=C1OC', 'CCCCCC=CC=CC=C'),
  ('CC(C)C(=O)OCC1=CC=CC=C1', 'CC(C)C1CCC2CC(=O)CCC2C1'),
  ('CC1C2(O1)CC3CC2C4C3CC=C(C4)C', 'COC1=C(C=CC(=C1)CC=C)OC=O'),
  ('CC1CCC(CC1OC(=O)C)C(=C)C', 'CCCCCCCCCCOC(=O)C1=CC=CC=C1N'),
  ('CCOC(=O)C1CC2CC1C=C2', 'CC\\1C=C(CC/C1=C\\NC2=CC=CC=C2C(=O)OC)C'),
  ('C/C(=C\\CCC(=C)C=C)/CC/C=C(\\C)/C=O', 'CC(CCC1=CC=CC=C1)CC=O'),
  ('CC(=O)OC1=C(C=C(C=C1)C=O)OC', 'CC/C=C\\CC/C=C/C=O'),
  ('CCC(C)(C)C1CCCCC1OC(=O)C',
   'CCCC(=O)OC(C)(C)[C@@H]1CC[C@@H](C2=C(C1)[C@H](CC2)C)C'),
  ('CC(=CCOC(=O)C1=CC=CC=C1O)C', 'CC1=C[C@H]2[C@@H](CC1)C(=C)CC[C@@H]2C(C)C'),
  ('CCCCC/C=C/CCC(=O)OCC', 'CC\\C=C\\CCOC(=O)/C(=C\\CC)/C'),
  ('CC(C)CC(=O)OCCC1=CC=CC=C1', 'CCCCCC(C)O'),
  ('CC/C=C\\C/C=C\\CCO', 'CCCCC(=O)C'),
  ('CCCC(=O)OCC=CC=CC', 'CCCCCC(C)C(=O)O'),
  ('CC(CCC=C(C)C)CCOC=O', 'CCC=CCCC=CCO'),
  ('CC(=O)OCC1=CC2=C(C=C1)OCO2', 'CCOC(=O)C=CC1=CC=CC=C1'),
  ('CCC(C)(CC)O', 

In [14]:
all_smiles = set()
for d in dataset:
    all_smiles.add(d["mol1"])
    all_smiles.add(d["mol2"])
len(all_smiles),next(iter(all_smiles))

(2971, 'CCC(=O)OCC(C)(C)CC1=C(CCC1C(=C)C)C')

In [15]:
from ogb.utils import smiles2graph

errored = 0
graph_data = dict()
for smiles in tqdm.tqdm(all_smiles):
    try:
        graph_data[smiles] = smiles2graph(smiles)
    except AttributeError as e:
        errored += 1
errored, len(graph_data), next(iter(graph_data.items()))

 51%|███████████████████                  | 1528/2971 [00:01<00:01, 1269.12it/s][17:56:52] SMILES Parse Error: syntax error while parsing: (C)C1=CN=CC(=N1)OC.CC(C)C1=CN=C(C=N1)OC.CC(C)C1=NC=CN=C1OC
[17:56:52] SMILES Parse Error: Failed parsing SMILES '(C)C1=CN=CC(=N1)OC.CC(C)C1=CN=C(C=N1)OC.CC(C)C1=NC=CN=C1OC' for input: '(C)C1=CN=CC(=N1)OC.CC(C)C1=CN=C(C=N1)OC.CC(C)C1=NC=CN=C1OC'
 60%|██████████████████████▏              | 1784/2971 [00:01<00:00, 1272.08it/s][17:56:52] SMILES Parse Error: syntax error while parsing: InChI=1/C7H8S/c1-6-4-2-3-5-7(6)8/h2-5,8H,1H3
[17:56:52] SMILES Parse Error: Failed parsing SMILES 'InChI=1/C7H8S/c1-6-4-2-3-5-7(6)8/h2-5,8H,1H3' for input: 'InChI=1/C7H8S/c1-6-4-2-3-5-7(6)8/h2-5,8H,1H3'
 90%|█████████████████████████████████▍   | 2681/2971 [00:02<00:00, 1277.53it/s][17:56:53] Can't kekulize mol.  Unkekulized atoms: 3 4 5 6 8
100%|█████████████████████████████████████| 2971/2971 [00:02<00:00, 1268.83it/s]


(2968,
 ('CCC(=O)OCC(C)(C)CC1=C(CCC1C(=C)C)C',
  {'edge_index': array([[ 0,  1,  1,  2,  2,  3,  2,  4,  4,  5,  5,  6,  6,  7,  6,  8,
            6,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16,
           15, 17, 11, 18, 14, 10],
          [ 1,  0,  2,  1,  3,  2,  4,  2,  5,  4,  6,  5,  7,  6,  8,  6,
            9,  6, 10,  9, 11, 10, 12, 11, 13, 12, 14, 13, 15, 14, 16, 15,
           17, 15, 18, 11, 10, 14]]),
   'edge_feat': array([[0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [1, 0, 1],
          [1, 0, 1],
          [0, 0, 1],
          [0, 0, 1],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [1, 0, 0],
          [1, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
      

In [16]:
import numpy as np
from torch_geometric.data import InMemoryDataset, download_url, Data
import torch

def combine_graphs(g1,g2):
    combined_edge_index = np.concatenate([g1["edge_index"],g1["num_nodes"]+g2["edge_index"]],axis=1)
    combined_edge_feat = np.concatenate([g1["edge_feat"],g2["edge_feat"]],axis=0)
    combined_node_feat = np.concatenate([g1["node_feat"],g2["node_feat"]],axis=0)
    combined_num_nodes = g1["num_nodes"] + g2["num_nodes"]
    combined = {"edge_index":combined_edge_index,"edge_feat":combined_edge_feat,"node_feat":combined_node_feat,"num_nodes":combined_num_nodes}
    return combined

def to_torch(graph):
    tensor_keys = ["edge_index", 'edge_feat', 'node_feat']
    for key in tensor_keys:
        graph[key] = torch.tensor(graph[key])
    return graph

pair_to_data = dict()
for d in tqdm.tqdm(dataset):
    if not d["mol1"] in graph_data or not d["mol2"] in graph_data:
        continue
    pair = (d["mol1"],d["mol2"])
    combined = combine_graphs(graph_data[d["mol1"]],graph_data[d["mol1"]])
    graph = to_torch(combined)
    pair_to_data[pair] = Data(x=graph["node_feat"].float(),edge_attr=graph["edge_feat"],edge_index=graph["edge_index"])
len(pair_to_data), next(iter(pair_to_data.items()))

100%|████████████████████████████████| 166814/166814 [00:06<00:00, 25218.91it/s]


(166733,
 (('CCCCC/C=C/C(=O)OC', 'CCCCCOC(=O)CCC'),
  Data(x=[22, 9], edge_index=[2, 40], edge_attr=[40, 3])))

In [17]:
valid_pairs = set(pair_to_data.keys()).intersection(set(positives.keys())).intersection(set(negatives.keys()))
len(valid_pairs), len(valid_pairs)/len(dataset)

(166707, 0.9993585670267483)

In [24]:
import json
import h5py
import base64

def encode_smiles(smiles):
    return base64.urlsafe_b64encode(smiles.encode()).decode()

with h5py.File('dataset.h5', 'w') as f:
    for (pair, data) in tqdm.tqdm(pair_to_data.items()):
        group = f.create_group(encode_smiles(json.dumps(pair)))
        group.create_dataset("pair",data=pairs[i])
        group.create_dataset("positives",data=positives[pair])
        group.create_dataset("negatives",data=negatives[pair])
        graph_group = group.create_group("graph")
        for k,v in data.items():
            graph_group.create_dataset(k,data=v.numpy())

100%|██████████████████████████████████| 166733/166733 [03:00<00:00, 923.52it/s]


In [22]:
for p in random.sample(sorted(pair_to_data.keys()),100):
    print(p,json.dumps(p))

('CCCC(=O)OC/C=C(/C)\\CCC=C(C)C', 'CCCCCCCCCCCCCC(=O)OC') ["CCCC(=O)OC/C=C(/C)\\CCC=C(C)C", "CCCCCCCCCCCCCC(=O)OC"]
('CC(=O)CC(=O)OC', 'C[C@@H]1CC[C@H]([C@H]2[C@]13[C@@H]2C(=C)CC3)C(C)C') ["CC(=O)CC(=O)OC", "C[C@@H]1CC[C@H]([C@H]2[C@]13[C@@H]2C(=C)CC3)C(C)C"]
('C1=CC=C(C=C1)OC2=CC=CC=C2', 'CC(=CC/C=C(/C)\\C=C)C') ["C1=CC=C(C=C1)OC2=CC=CC=C2", "CC(=CC/C=C(/C)\\C=C)C"]
('CCC(=O)OCC=CC1=CC=CC=C1', 'CCCCCCCOC(=O)CCC') ["CCC(=O)OCC=CC1=CC=CC=C1", "CCCCCCCOC(=O)CCC"]
('CC(=CCCC(=CCCC(C)(C=C)O)C)C', 'CCCCCCCCCCCCCC=O') ["CC(=CCCC(=CCCC(C)(C=C)O)C)C", "CCCCCCCCCCCCCC=O"]
('C/C=C/C(=O)C1=C(CCCC1(C)C)C', 'CCCCC/C=C/C(=O)OC') ["C/C=C/C(=O)C1=C(CCCC1(C)C)C", "CCCCC/C=C/C(=O)OC"]
('CC1CC2=C(CC1(C)C(=O)C)C(CCC2)(C)C', 'CCCCCCOC(=O)C(=CC)C') ["CC1CC2=C(CC1(C)C(=O)C)C(CCC2)(C)C", "CCCCCCOC(=O)C(=CC)C"]
('C/C=C/C(=O)OC/C=C(\\C)/CCC=C(C)C', 'C1=CC=C(C=C1)CC(=O)CC2=CC=CC=C2') ["C/C=C/C(=O)OC/C=C(\\C)/CCC=C(C)C", "C1=CC=C(C=C1)CC(=O)CC2=CC=CC=C2"]
('CC(C)CC(=O)OC(C=C)C(C)CC=C(C)C', 'CCCCCCC(=O)OCC1=CC=CC=