In [None]:
MAX_COUNT = 100

In [2]:
import os
os.chdir("Data")
!ls

In [3]:
import json
with open("odor_pair_full.json") as f:
    dataset = json.load(f)

len(dataset),dataset[0]

(166814,
 {'mol1': 'CCCCC/C=C/C(=O)OC',
  'mol1_notes': ['violet',
   'sweet',
   'oily',
   'melon',
   'pear',
   'hairy',
   'costus',
   'fruity',
   'violet leaf',
   'waxy',
   'fresh',
   'green'],
  'mol2': 'CCCCCOC(=O)CCC',
  'mol2_notes': ['cherry',
   'sweet',
   'pineapple',
   'fruity',
   'banana',
   'tropical'],
  'blend_notes': ['animal', 'fruity', 'waxy']})

In [4]:
import collections
import tqdm

notes_to_pairs = collections.defaultdict(list)
for d in tqdm.tqdm(dataset):
    notes_to_pairs[frozenset(d["blend_notes"])].append((d["mol1"],d["mol2"]))

len(notes_to_pairs), next(iter(notes_to_pairs.items()))

100%|████████████████████████████| 166814/166814 [00:00<00:00, 1778376.96it/s]


(2921,
 (frozenset({'animal', 'fruity', 'waxy'}),
  [('CCCCC/C=C/C(=O)OC', 'CCCCCOC(=O)CCC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCOC(=O)C(C)CC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCOC(=O)CCC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCCCCCCC(=O)OCC')]))

In [5]:
import random

positives = dict()
for n, pairs in tqdm.tqdm(notes_to_pairs.items()):
    for p1 in pairs:
        vals = list(random.sample(pairs,MAX_COUNT)) if len(pairs) > MAX_COUNT else pairs
        positives[p1] = [p2 for p2 in vals if p1 != p2]
len(positives), next(iter(positives.items()))

100%|████████████████████████████████████| 2921/2921 [00:03<00:00, 795.89it/s]


(166814,
 (('CCCCC/C=C/C(=O)OC', 'CCCCCOC(=O)CCC'),
  [('CCCCC/C=C/C(=O)OC', 'CCCCCCCCOC(=O)C(C)CC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCOC(=O)CCC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCCCCCCC(=O)OCC')]))

In [6]:
missing = len([p for p, ps in positives.items() if len(ps) < MAX_COUNT])
missing, missing/len(dataset)

(31827, 0.19079333868859927)

In [7]:
mol_sets = collections.defaultdict(list)
for i,d in enumerate(dataset):
    mol_sets[d["mol1"]].append(i)
    mol_sets[d["mol2"]].append(i)

In [8]:
pairs = []
pair_to_notes = dict()
for d in dataset:
    pair = (d["mol1"],d["mol2"])
    pairs.append(pair)
    pair_to_notes[pair] = set(d["blend_notes"])

In [9]:
notes_sets = dict()
for i,d in enumerate(dataset):
    notes_sets[i] = set(d["blend_notes"])

In [10]:
# Hard negatives are anchor/negatives that share a molecule
# but do not have any notes in common.
negatives = collections.defaultdict(list)
for (mol,idcs) in tqdm.tqdm(mol_sets.items()):
    for i in idcs:
        p1 = notes_sets[i]
        for j in idcs:
            p2 = notes_sets[j]
            if bool(p1 & p2):
                continue
            negatives[pairs[i]].append(pairs[j])
len(negatives)

100%|████████████████████████████████████| 2971/2971 [00:12<00:00, 238.52it/s]


166788

In [11]:
missing = len([p for p, ns in negatives.items() if len(ns) < MAX_COUNT])
missing, missing/len(dataset)

(94025, 0.5636517318690277)

In [12]:
# Supplement with random negatives if we have less than the required
# number of hard negatives.
for p, ns in tqdm.tqdm(negatives.items()):
    notes = pair_to_notes[pair]
    while len(ns) < MAX_COUNT:
        other = random.choice(dataset)
        other_pair = (other["mol1"],other["mol2"])
        other_notes = pair_to_notes[other_pair]
        if bool(notes & other_notes):
            continue
        ns.append(other_pair)
    negatives[p] = list(random.sample(pairs,MAX_COUNT))

missing = len([p for p, ns in negatives.items() if len(ns) != MAX_COUNT])
len(negatives), missing, missing/len(dataset)      

100%|██████████████████████████████| 166788/166788 [00:11<00:00, 15034.80it/s]


(166788, 0, 0.0)

In [13]:
next(iter(negatives.items()))

(('CCCCC/C=C/C(=O)OC', 'CCCCCOC(=O)CCC'),
 [('CCC1=C(C(=O)C(O1)C)O', 'CCCCCC(=O)C(C)C(=O)C'),
  ('CCCCCCOC(=O)/C=C/C', 'CCCCCOC(=O)C(C)CC'),
  ('CCCCCCCOC(=O)C(C)CC', 'CCCCCOC(=O)C(C)CC'),
  ('CC1=CCC(CC1)C(C)(C)S', 'CCCC(CC(=O)OCC)O'),
  ('C/C/1=C/CCC(=C)C2CC(C2CC1)(C)C',
   'C[C@@H]1CC[C@@H]2[C@@]13C[C@H](C2(C)C)C(=C(C3)C(=O)C)C'),
  ('CC=CC=CCOC(=O)C(C)C', 'CCC(=O)OCC(C)(C)OC(C)C1CCCC(C1)(C)C'),
  ('CC(=CCCC(C)(C=C)OC(=O)C)C', 'CCCCC\\C=C/C=C/C(=O)OCC'),
  ('CCCC(=O)OC(C)(C)CC1=CC=CC=C1', 'CCCCC(=O)OCCC(C)CCC=C(C)C'),
  ('CC1CC(CC(C1)(C)C)O', 'CC1CCC(C(C1)OC(=O)C)C(C)C'),
  ('C1CC(=O)OC2=CC=CC=C21', 'CC1=CC=C(C=C1)C(C)(C)O'),
  ('CCCCCC1CCCC1=O', 'CCOC(=O)C=CC1=CC=CC=C1'),
  ('CC1CCCC(C1CCC(C)OC(=O)C)(C)C', 'O=C(C)O[C@@H]1CC[C@@H](CC1)C(C)(C)C'),
  ('CC1=CC=C(C=C1)OC', 'CCCCCCCCCCCC(=O)OCC'),
  ('CC(CC1=CC2=C(C=C1)OCO2)C=O', 'CCCOC(=O)C=CC=CCCCCC'),
  ('CC(=O)OCC1=CC2=C(C=C1)OCO2', 'CCC(=O)C1=CC=CC=C1'),
  ('CC(C)CC(=O)OCCC1=CC=CC=C1', 'CCOC(=O)CC(C)C'),
  ('CC(=CCCC(=CCCC(=O)C)C)C'

In [14]:
all_smiles = set()
for d in dataset:
    all_smiles.add(d["mol1"])
    all_smiles.add(d["mol2"])
len(all_smiles),next(iter(all_smiles))

(2971, 'C1=CC=C(C=C1)COC=O')

In [15]:
from ogb.utils import smiles2graph
from torch_geometric.data import InMemoryDataset, download_url, Data
import torch
from torch_geometric.loader import DataLoader

def to_torch(graph):
    tensor_keys = ["edge_index", 'edge_feat', 'node_feat']
    for key in tensor_keys:
        graph[key] = torch.from_numpy(graph[key])
    return Data(x=graph["node_feat"].float(),edge_attr=graph["edge_feat"],edge_index=graph["edge_index"])

errored = 0
graph_data = dict()
for smiles in tqdm.tqdm(all_smiles):
    try:
        graph_data[smiles] = to_torch(smiles2graph(smiles))
    except AttributeError as e:
        errored += 1
errored, len(graph_data), next(iter(graph_data.items()))

 13%|████▌                               | 374/2971 [00:00<00:02, 1242.35it/s][20:18:07] SMILES Parse Error: syntax error while parsing: InChI=1/C7H8S/c1-6-4-2-3-5-7(6)8/h2-5,8H,1H3
[20:18:07] SMILES Parse Error: Failed parsing SMILES 'InChI=1/C7H8S/c1-6-4-2-3-5-7(6)8/h2-5,8H,1H3' for input: 'InChI=1/C7H8S/c1-6-4-2-3-5-7(6)8/h2-5,8H,1H3'
 76%|██████████████████████████▌        | 2255/2971 [00:01<00:00, 1242.45it/s][20:18:09] SMILES Parse Error: syntax error while parsing: (C)C1=CN=CC(=N1)OC.CC(C)C1=CN=C(C=N1)OC.CC(C)C1=NC=CN=C1OC
[20:18:09] SMILES Parse Error: Failed parsing SMILES '(C)C1=CN=CC(=N1)OC.CC(C)C1=CN=C(C=N1)OC.CC(C)C1=NC=CN=C1OC' for input: '(C)C1=CN=CC(=N1)OC.CC(C)C1=CN=C(C=N1)OC.CC(C)C1=NC=CN=C1OC'
 84%|█████████████████████████████▌     | 2510/2971 [00:02<00:00, 1254.59it/s][20:18:09] Can't kekulize mol.  Unkekulized atoms: 3 4 5 6 8
100%|███████████████████████████████████| 2971/2971 [00:02<00:00, 1239.35it/s]


(3,
 2968,
 ('C1=CC=C(C=C1)COC=O',
  Data(x=[10, 9], edge_index=[2, 20], edge_attr=[20, 3])))

In [16]:
import numpy as np
import data

pair_to_data = dict()
for d in tqdm.tqdm(dataset):
    if not d["mol1"] in graph_data or not d["mol2"] in graph_data:
        continue
    pair = (d["mol1"],d["mol2"])
    g1 = graph_data[d["mol1"]]
    g2 = graph_data[d["mol2"]]
    pair_to_data[pair] = data.combine_graphs([g1,g2])
len(pair_to_data), next(iter(pair_to_data.items()))

100%|███████████████████████████████| 166814/166814 [00:41<00:00, 4004.13it/s]


(166733,
 (('CCCCC/C=C/C(=O)OC', 'CCCCCOC(=O)CCC'),
  BlendData(x=[22, 9], edge_index=[2, 40], edge_attr=[40, 3], mol_batch=[22], blend_batch=[2])))

In [17]:
valid_pairs = set(pair_to_data.keys()).intersection(set(positives.keys())).intersection(set(negatives.keys()))
len(valid_pairs), len(valid_pairs)/len(dataset)

(166707, 0.9993585670267483)

In [18]:
import json
import h5py
import base64

def encode_smiles(smiles):
    return base64.urlsafe_b64encode(smiles.encode()).decode()

with h5py.File('dataset.h5', 'w') as f:
    for (pair, data) in tqdm.tqdm(pair_to_data.items()):
        group = f.create_group(encode_smiles(json.dumps(pair)))
        group.create_dataset("pair",data=pairs[i])
        group.create_dataset("positives",data=positives[pair])
        group.create_dataset("negatives",data=negatives[pair])
        graph_group = group.create_group("graph")
        for k,v in data.items():
            graph_group.create_dataset(k,data=v.numpy())

100%|████████████████████████████████| 166733/166733 [03:05<00:00, 899.79it/s]


In [19]:
for p in random.sample(sorted(pair_to_data.keys()),100):
    print(p,json.dumps(p))

('CC(C)C[C@H](O)C(=O)OCC', 'CCOC(=O)C1C(O1)C2=CC=CC=C2') ["CC(C)C[C@H](O)C(=O)OCC", "CCOC(=O)C1C(O1)C2=CC=CC=C2"]
('CC(=O)CCSC', 'OC(C1OC(\\C=C)(CC1)C)(C)C') ["CC(=O)CCSC", "OC(C1OC(\\C=C)(CC1)C)(C)C"]
('CC/C=C/CC(=O)OC', 'CC/C=C\\CCOCC(=C)C') ["CC/C=C/CC(=O)OC", "CC/C=C\\CCOCC(=C)C"]
('COC(=O)OC1CCC\\C=C/CC1', 'O=C(OCC/1=C/C[C@@H]2C[C@H]\\1C2(C)C)C') ["COC(=O)OC1CCC\\C=C/CC1", "O=C(OCC/1=C/C[C@@H]2C[C@H]\\1C2(C)C)C"]
('CC(=CCOC=O)C', 'CCC=C(C)C(=O)OCC') ["CC(=CCOC=O)C", "CCC=C(C)C(=O)OCC"]
('CC1=CSC(=N1)C(C)C', 'CCCCCCCCCC(=O)OCCC') ["CC1=CSC(=N1)C(C)C", "CCCCCCCCCC(=O)OCCC"]
('CCC(=O)OC(C)(C)C1CCC(=CC1)C', 'COC1=CC=CC=C1C=C') ["CCC(=O)OC(C)(C)C1CCC(=CC1)C", "COC1=CC=CC=C1C=C"]
('C1=CC=C(C=C1)CCOC=O', 'CC1C=CCCC1(C)C=O') ["C1=CC=C(C=C1)CCOC=O", "CC1C=CCCC1(C)C=O"]
('CCC(=O)OC(C)(C)C1CCC(=CC1)C', 'CCC/C=C/C(=O)OCC') ["CCC(=O)OC(C)(C)C1CCC(=CC1)C", "CCC/C=C/C(=O)OCC"]
('CC(=O)OCC1CCCO1', 'CCC1(CC(=O)C(=C(C1C)C)C)C') ["CC(=O)OCC1CCCO1", "CCC1(CC(=O)C(=C(C1C)C)C)C"]
('O(C(OCC=C(C)CC\\C=C(