In [1]:
import sys
import os

# Move up to the repository root exactly once. The relative "Data/..." paths
# used later in this notebook resolve from there; the guard keeps a re-run of
# this cell from climbing yet another directory level.
if not os.path.isdir("Data"):
    os.chdir("..")
os.getcwd()

'/Users/laurasisson/dream'

In [2]:
import json

# Read the annealed dataset JSON (path is relative to the current working directory).
with open("Data/annealed_70_30.json") as json_file:
    annealed = json.load(json_file)

# Show the top-level keys of the loaded dictionary.
annealed.keys()

dict_keys(['train', 'test', 'covered_notes'])

In [3]:
import pairdata

# Convert a single training record to inspect what pairdata.convert produces.
first_train_record = annealed["train"][0]
pairdata.convert(first_train_record)

{'mol1': 'CCC\\C=C/CCC(=O)OCC',
 'mol2': 'CCC(=O)OCCC(C)C',
 'blend_notes': ['fruity']}

In [4]:
# Run pairdata.convert over every training record, then report how many we got.
converted = list(map(pairdata.convert, annealed["train"]))
len(converted)

216961

In [5]:
# Display the first converted training pair.
converted[0]

{'mol1': 'CCC\\C=C/CCC(=O)OCC',
 'mol2': 'CCC(=O)OCCC(C)C',
 'blend_notes': ['fruity']}

In [6]:
# Gather every distinct SMILES string that appears on either side of a pair.
all_smiles = {pair[side] for pair in converted for side in ("mol1", "mol2")}

In [7]:
from ogb.utils import smiles2graph

# Build one torch graph per unique SMILES string. Strings whose conversion
# raises AttributeError are reported, counted, and left out of graph_data.
graph_data = {}
errored = 0
for smiles in all_smiles:
    try:
        mol_graph = smiles2graph(smiles)
        graph_data[smiles] = pairdata.to_torch(mol_graph)
    except AttributeError as err:
        print(err)
        errored += 1

# Number of SMILES strings that could not be converted.
errored

[17:35:20] SMILES Parse Error: syntax error while parsing: InChI=1/C7H8S/c1-6-4-2-3-5-7(6)8/h2-5,8H,1H3
[17:35:20] SMILES Parse Error: Failed parsing SMILES 'InChI=1/C7H8S/c1-6-4-2-3-5-7(6)8/h2-5,8H,1H3' for input: 'InChI=1/C7H8S/c1-6-4-2-3-5-7(6)8/h2-5,8H,1H3'
[17:35:20] SMILES Parse Error: syntax error while parsing: (C)C1=CN=CC(=N1)OC.CC(C)C1=CN=C(C=N1)OC.CC(C)C1=NC=CN=C1OC
[17:35:20] SMILES Parse Error: Failed parsing SMILES '(C)C1=CN=CC(=N1)OC.CC(C)C1=CN=C(C=N1)OC.CC(C)C1=NC=CN=C1OC' for input: '(C)C1=CN=CC(=N1)OC.CC(C)C1=CN=C(C=N1)OC.CC(C)C1=NC=CN=C1OC'


'NoneType' object has no attribute 'GetAtoms'
'NoneType' object has no attribute 'GetAtoms'


2

In [14]:
import torch
import torch_geometric as tg
import data


def combine_graphs_old(graphs):
    """Merge a list of molecule graphs into a single ``data.BlendData``.

    Reference (slow) implementation: the graphs are collated by building a
    ``tg.loader.DataLoader`` whose batch size equals the number of graphs and
    taking its first (and only) batch.

    Args:
        graphs: Non-empty list of per-molecule graph objects accepted by
            ``tg.loader.DataLoader``.

    Returns:
        A ``data.BlendData`` holding the concatenated ``x``, ``edge_attr`` and
        ``edge_index`` of the batch, plus ``mol_batch`` (molecule index for
        each atom) and ``blend_batch`` (blend index for each molecule).
    """
    # `torch` is imported in this cell so the function does not rely on a
    # later cell having imported it before the first call.
    combined_batch = next(iter(tg.loader.DataLoader(graphs, batch_size=len(graphs))))
    # Index of the molecule, for each atom
    mol_batch = combined_batch.batch
    # Index of the blend, for each molecule (increment during batch).
    # All zeros here: every molecule passed in belongs to the same blend.
    blend_batch = torch.zeros(len(graphs), dtype=torch.long)
    return data.BlendData(
        x=combined_batch.x,
        edge_attr=combined_batch.edge_attr,
        edge_index=combined_batch.edge_index,
        mol_batch=mol_batch,
        blend_batch=blend_batch,
    )

In [15]:
import torch
import tqdm

# Check that data.combine_graphs matches combine_graphs_old tensor-for-tensor
# on every training pair whose two molecules both have an entry in graph_data.
for pair in tqdm.tqdm(converted):
    if pair["mol1"] in graph_data and pair["mol2"] in graph_data:
        g1 = graph_data[pair["mol1"]]
        g2 = graph_data[pair["mol2"]]

        graphs = combine_graphs_old([g1, g2])
        graphs_fast = data.combine_graphs([g1, g2])
        # Compare every field exposed by the reference result.
        for field in graphs.keys():
            assert torch.equal(graphs[field], graphs_fast[field])

100%|█████████████████████████████████| 216961/216961 [00:54<00:00, 4005.01it/s]


In [None]:
import time

# Time the reference implementation over every pair with both graphs available.
start = time.perf_counter()

for pair in tqdm.tqdm(converted):
    if pair["mol1"] in graph_data and pair["mol2"] in graph_data:
        g1 = graph_data[pair["mol1"]]
        g2 = graph_data[pair["mol2"]]

        graphs = combine_graphs_old([g1, g2])

print(f"Old method elapsed = {time.perf_counter() - start:.2f}s")

In [None]:
import time

# Time data.combine_graphs over the same pairs for comparison with the old method.
start = time.perf_counter()

for pair in tqdm.tqdm(converted):
    if pair["mol1"] in graph_data and pair["mol2"] in graph_data:
        g1 = graph_data[pair["mol1"]]
        g2 = graph_data[pair["mol2"]]

        graphs = data.combine_graphs([g1, g2])

print(f"New method elapsed = {time.perf_counter() - start:.2f}s")