In [None]:
MAX_COUNT = 100

In [2]:
import os
os.chdir("Data")
!ls

[34mChallenge[m[m           leaderboard.h5      odor_pair_full.json train.h5
dataset.h5          molecules.csv       test.h5


In [3]:
import json
with open("odor_pair_full.json") as f:
    dataset = json.load(f)

len(dataset),dataset[0]

(166814,
 {'mol1': 'CCCCC/C=C/C(=O)OC',
  'mol1_notes': ['violet',
   'sweet',
   'oily',
   'melon',
   'pear',
   'hairy',
   'costus',
   'fruity',
   'violet leaf',
   'waxy',
   'fresh',
   'green'],
  'mol2': 'CCCCCOC(=O)CCC',
  'mol2_notes': ['cherry',
   'sweet',
   'pineapple',
   'fruity',
   'banana',
   'tropical'],
  'blend_notes': ['animal', 'fruity', 'waxy']})

In [4]:
import collections
import tqdm

notes_to_pairs = collections.defaultdict(list)
for d in tqdm.tqdm(dataset):
    notes_to_pairs[frozenset(d["blend_notes"])].append((d["mol1"],d["mol2"]))

len(notes_to_pairs), next(iter(notes_to_pairs.items()))

100%|████████████████████████████| 166814/166814 [00:00<00:00, 1766460.80it/s]


(2921,
 (frozenset({'animal', 'fruity', 'waxy'}),
  [('CCCCC/C=C/C(=O)OC', 'CCCCCOC(=O)CCC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCOC(=O)C(C)CC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCOC(=O)CCC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCCCCCCC(=O)OCC')]))

In [5]:
import random

positives = dict()
for n, pairs in tqdm.tqdm(notes_to_pairs.items()):
    for p1 in pairs:
        vals = list(random.sample(pairs,MAX_COUNT)) if len(pairs) > MAX_COUNT else pairs
        positives[p1] = [p2 for p2 in vals if p1 != p2]
len(positives), next(iter(positives.items()))

100%|████████████████████████████████████| 2921/2921 [00:03<00:00, 781.78it/s]


(166814,
 (('CCCCC/C=C/C(=O)OC', 'CCCCCOC(=O)CCC'),
  [('CCCCC/C=C/C(=O)OC', 'CCCCCCCCOC(=O)C(C)CC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCOC(=O)CCC'),
   ('CCCCC/C=C/C(=O)OC', 'CCCCCCCCCCCCCC(=O)OCC')]))

In [6]:
missing = len([p for p, ps in positives.items() if len(ps) < MAX_COUNT])
missing, missing/len(dataset)

(31983, 0.19172851199539606)

In [7]:
mol_sets = collections.defaultdict(list)
for i,d in enumerate(dataset):
    mol_sets[d["mol1"]].append(i)
    mol_sets[d["mol2"]].append(i)

In [8]:
pairs = []
pair_to_notes = dict()
for d in dataset:
    pair = (d["mol1"],d["mol2"])
    pairs.append(pair)
    pair_to_notes[pair] = set(d["blend_notes"])

In [9]:
notes_sets = dict()
for i,d in enumerate(dataset):
    notes_sets[i] = set(d["blend_notes"])

In [10]:
# Hard negatives are anchor/negatives that share a molecule
# but do not have any notes in common.
negatives = collections.defaultdict(list)
for (mol,idcs) in tqdm.tqdm(mol_sets.items()):
    for i in idcs:
        p1 = notes_sets[i]
        for j in idcs:
            p2 = notes_sets[j]
            if bool(p1 & p2):
                continue
            negatives[pairs[i]].append(pairs[j])
len(negatives)

100%|████████████████████████████████████| 2971/2971 [00:13<00:00, 225.25it/s]


166788

In [11]:
missing = len([p for p, ns in negatives.items() if len(ns) < MAX_COUNT])
missing, missing/len(dataset)

(94025, 0.5636517318690277)

In [12]:
# Supplement with random negatives if we have less than the required
# number of hard negatives.
for p, ns in tqdm.tqdm(negatives.items()):
    notes = pair_to_notes[pair]
    while len(ns) < MAX_COUNT:
        other = random.choice(dataset)
        other_pair = (other["mol1"],other["mol2"])
        other_notes = pair_to_notes[other_pair]
        if bool(notes & other_notes):
            continue
        ns.append(other_pair)
    negatives[p] = list(random.sample(pairs,MAX_COUNT))

missing = len([p for p, ns in negatives.items() if len(ns) != MAX_COUNT])
len(negatives), missing, missing/len(dataset)      

100%|██████████████████████████████| 166788/166788 [00:11<00:00, 14989.48it/s]


(166788, 0, 0.0)

In [13]:
next(iter(negatives.items()))

(('CCCCC/C=C/C(=O)OC', 'CCCCCOC(=O)CCC'),
 [('C1=CC=C2C(=C1)C=CN2', 'CC(CCCC(C)(C)O)CC=O'),
  ('CC(C)CC(=O)O', 'CCCCOC(=O)C(C)O'),
  ('CCCCCOC(=O)/C(=C\\C)/C', 'OC2C(CCC2C1CCCC1)C3CCCC3'),
  ('CC/C=C/CC1=C(C(=O)CC1)C', 'CCCCCCOC(=O)C(C)(C)C'),
  ('CCCCCC=CC(=O)OCC', 'COC(=O)CCC(=O)OC'),
  ('CCCCCCC1CCC(=O)O1', 'CCOC(=O)C(C)O'),
  ('C/C=C(\\C)/C(=O)OCCC1=CC=CC=C1', 'CC(=O)CO'),
  ('C1=CC=C(C=C1)CCC=O', 'CCCCCOC(=O)CCC'),
  ('CC=CCCCCCCCC=O', 'O=C(OC)CC1C(C(=O)CC1)CCCCC'),
  ('CC(C1CCCC(C1)(C)C)OC(C)(C)COC(=O)C2CC2', 'CC1(CCCCCCCCCCC1)OC'),
  ('CC(=CCC/C(=C/CC1CCCC1=O)/C)C', 'CCCCCCC(=O)OCCC1=CC=CC=C1'),
  ('CC1=C(C(=O)CO1)O', 'CCC1(CC(=O)C(=C(C1C)C)C)C'),
  ('CCCCCCCC(=O)OCC1=CC=CC=C1', 'CCCCCCCCCCCCOC(=O)CCC'),
  ('CC(CCC1=CC=CC=C1)CCO', 'CCCCC[C@H]1[C@@H](CCC1=O)C(=O)OC'),
  ('CCCCCCOC(=O)CCCCC', 'CCCCCOC(=O)CC(C)C'),
  ('CCCCCC(OCC)OCC', 'CCOC(=O)CC(C)C'),
  ('C/C(=C\\C=C)/CC(C)(C=C)O', 'CCOC(=O)CC1(OCC(O1)C)C'),
  ('C1CCCCCC(=O)OCC/C=C/CCCC1', 'CC(=CCC/C(=C/CO)/C)C'),
  ('CC(=CC(=O)

In [14]:
all_smiles = set()
for d in dataset:
    all_smiles.add(d["mol1"])
    all_smiles.add(d["mol2"])
len(all_smiles),next(iter(all_smiles))

(2971, 'CC1=CC[C@@H]2[C@@H](C1)[C@@H](CC=C2C)C(C)C')

In [15]:
from ogb.utils import smiles2graph

def to_torch(graph):
    tensor_keys = ["edge_index", 'edge_feat', 'node_feat']
    for key in tensor_keys:
        graph[key] = torch.from_numpy(graph[key])
    return Data(x=graph["node_feat"].float(),edge_attr=graph["edge_feat"],edge_index=graph["edge_index"])

errored = 0
graph_data = dict()
for smiles in tqdm.tqdm(all_smiles):
    try:
        graph_data[smiles] = to_torch(smiles2graph(smiles))
    except AttributeError as e:
        errored += 1
errored, len(graph_data), next(iter(graph_data.items()))

  0%|                                                | 0/2971 [00:00<?, ?it/s]


NameError: name 'torch' is not defined

In [None]:
import numpy as np
from torch_geometric.data import InMemoryDataset, download_url, Data
import torch
from torch_geometric.loader import DataLoader
import data

pair_to_data = dict()
for d in tqdm.tqdm(dataset):
    if not d["mol1"] in graph_data or not d["mol2"] in graph_data:
        continue
    pair = (d["mol1"],d["mol2"])
    g1 = graph_data[d["mol1"]]
    g2 = graph_data[d["mol2"]]
    pair_to_data[pair] = data.combine_graphs([g1,g2])
len(pair_to_data), next(iter(pair_to_data.items()))

In [None]:
valid_pairs = set(pair_to_data.keys()).intersection(set(positives.keys())).intersection(set(negatives.keys()))
len(valid_pairs), len(valid_pairs)/len(dataset)

In [None]:
import json
import h5py
import base64

def encode_smiles(smiles):
    return base64.urlsafe_b64encode(smiles.encode()).decode()

with h5py.File('dataset.h5', 'w') as f:
    for (pair, data) in tqdm.tqdm(pair_to_data.items()):
        group = f.create_group(encode_smiles(json.dumps(pair)))
        group.create_dataset("pair",data=pairs[i])
        group.create_dataset("positives",data=positives[pair])
        group.create_dataset("negatives",data=negatives[pair])
        graph_group = group.create_group("graph")
        for k,v in data.items():
            graph_group.create_dataset(k,data=v.numpy())

In [None]:
for p in random.sample(sorted(pair_to_data.keys()),100):
    print(p,json.dumps(p))