In [12]:
import json

# Load the fold file. Each "train" entry displayed below carries an "edge"
# (a pair of SMILES strings) and the pair's "blend_notes".
# JSON text is UTF-8 by specification, so the encoding is pinned here instead
# of being left to the platform's locale default.
with open("../dataset/single_fold.json", encoding="utf-8") as f:
    data = json.load(f)
data["train"][0]

{'edge': ['CCCCCCCCCCCC(OC)OC', 'CCC/C=C\\CO'], 'blend_notes': ['green']}

In [18]:
# Load the full pair dataset. Each record displayed below carries both
# molecules ("mol1"/"mol2"), their individual note lists, and "blend_notes".
# JSON text is UTF-8 by specification, so the encoding is pinned here instead
# of being left to the platform's locale default.
with open("../dataset/full.json", encoding="utf-8") as f:
    full_data = json.load(f)
full_data[0]

{'mol1': 'CCCCC/C=C/C(=O)OC',
 'mol1_notes': ['violet',
  'sweet',
  'oily',
  'melon',
  'pear',
  'hairy',
  'costus',
  'fruity',
  'violet leaf',
  'waxy',
  'fresh',
  'green'],
 'mol2': 'CCCCCOC(=O)CCC',
 'mol2_notes': ['cherry',
  'sweet',
  'pineapple',
  'fruity',
  'banana',
  'tropical'],
 'blend_notes': ['animal', 'fruity', 'waxy']}

In [20]:
# Collect the distinct SMILES strings that appear in any edge of each split.
train_set = {smiles for pair in data["train"] for smiles in pair["edge"]}
test_set = {smiles for pair in data["test"] for smiles in pair["edge"]}


In [21]:
# Map each molecule's SMILES string to its own note list.
# NOTE(review): if a molecule shows up in several records, the notes from the
# last record seen win -- presumably they are identical everywhere; confirm.
mol_notes = {}
for record in full_data:
    for smiles_key, notes_key in (("mol1", "mol1_notes"), ("mol2", "mol2_notes")):
        mol_notes[record[smiles_key]] = record[notes_key]


In [37]:
# Label vocabulary for single-molecule targets. A note's position in this
# list is the index it occupies in the multi-hot target vector, so the order
# must not change once datasets have been saved.
single_notes = [
    'aldehydic', 'alliaceous', 'amber', 'animal', 'anise', 'aromatic',
    'balsamic', 'berry', 'bitter', 'burnt', 'buttery', 'camphoreous',
    'caramellic', 'cheesy', 'chocolate', 'citrus', 'clean', 'cocoa',
    'coconut', 'coffee', 'cooling', 'coumarinic', 'creamy', 'dairy',
    'earthy', 'ethereal', 'fatty', 'fermented', 'floral', 'fresh',
    'fruity', 'green', 'herbal', 'honey', 'meaty', 'medicinal',
    'melon', 'minty', 'musk', 'musty', 'nutty', 'oily',
    'onion', 'orris', 'phenolic', 'powdery', 'roasted', 'rummy',
    'soapy', 'solvent', 'sour', 'spicy', 'sulfurous', 'sweet',
    'tropical', 'vanilla', 'vegetable', 'waxy', 'winey', 'woody',
]
len(single_notes)

60

In [54]:
import tqdm
from ogb.utils import smiles2graph
import torch
import torch_geometric as pyg

def multi_hot(notes, all_notes):
    """Encode a collection of notes as a 0/1 vector over ``all_notes``.

    Args:
        notes: Iterable of note labels. Labels absent from ``all_notes`` are
            ignored, and repeated labels count once.
        all_notes: Ordered sequence of known labels; the position of a label
            in this sequence is its index in the returned vector.

    Returns:
        An int64 tensor of shape ``(len(all_notes),)`` holding 1 at the index
        of every recognised note and 0 elsewhere.

    Raises:
        AttributeError: If none of ``notes`` is present in ``all_notes``.
            Callers rely on this exception type to skip such molecules.
    """
    # Build the label -> index lookup once instead of scanning the list per
    # note. setdefault keeps the first position of a label, like list.index.
    index_of = {}
    for position, label in enumerate(all_notes):
        index_of.setdefault(label, position)

    # A set collapses repeated notes, so every entry of the result stays
    # binary (summing one-hot rows would yield counts above 1).
    hit_positions = sorted({index_of[n] for n in notes if n in index_of})
    if not hit_positions:
        # Occurs when the notes in the pair were removed due to infrequency.
        raise AttributeError("Found no valid notes.")

    encoding = torch.zeros(len(all_notes), dtype=torch.long)
    encoding[torch.tensor(hit_positions)] = 1
    return encoding
    
def to_torch(smiles):
    """Turn a SMILES string into a graph dict whose array fields are tensors.

    The graph comes from ``smiles2graph``; its ``edge_index``, ``edge_feat``
    and ``node_feat`` entries are replaced in place by ``torch.tensor``
    copies. Any other entries are left as ``smiles2graph`` produced them.

    Args:
        smiles: SMILES string of a single molecule.

    Returns:
        The same graph mapping returned by ``smiles2graph``, with the three
        fields above converted.
    """
    mol_graph = smiles2graph(smiles)
    for field in ("edge_index", "edge_feat", "node_feat"):
        converted = torch.tensor(mol_graph[field])
        mol_graph[field] = converted
    return mol_graph
    
def make_data(mol_set):
    """Build one PyG ``Data`` graph per molecule that has a usable label.

    Args:
        mol_set: Iterable of SMILES strings. A ``set``/``frozenset`` is
            processed in sorted order; any other iterable keeps its own order.

    Returns:
        A list of ``pyg.data.Data`` objects with float node features ``x``,
        ``edge_attr``, ``edge_index``, a float multi-hot target ``y`` over
        ``single_notes``, and the source ``smiles`` string. Molecules that
        raise ``AttributeError`` during conversion are left out, so the list
        may be shorter than ``mol_set``.
    """
    # Iteration order of a set of strings changes between interpreter runs
    # (hash randomization). Sorting makes the dataset order, and therefore
    # the saved files, reproducible.
    if isinstance(mol_set, (set, frozenset)):
        mol_set = sorted(mol_set)

    graph_data = []
    for smiles in tqdm.tqdm(mol_set):
        try:
            graph = to_torch(smiles)
            graph["y"] = multi_hot(mol_notes[smiles], single_notes)
            gd = pyg.data.Data(
                x=graph["node_feat"].float(),
                edge_attr=graph["edge_feat"],
                edge_index=graph["edge_index"],
                y=graph["y"].float(),
                smiles=smiles,
            )
            graph_data.append(gd)
        except AttributeError:
            # multi_hot raises AttributeError when none of the molecule's
            # notes is in single_notes; such molecules get no target and are
            # skipped. NOTE(review): an AttributeError raised by any other
            # call in this try block skips the molecule silently as well.
            continue

    return graph_data

# Convert both splits (train first, then test) and show how many graphs
# each one produced.
train_data, test_data = (make_data(split) for split in (train_set, test_set))
tuple(len(graphs) for graphs in (train_data, test_data))

100%|███████████████████████████████████| 1471/1471 [00:01<00:00, 1035.90it/s]
100%|████████████████████████████████████| 1468/1468 [00:04<00:00, 360.30it/s]


(1446, 1434)

In [42]:
import os

# Move the working directory one level up from where the notebook started.
# A bare os.chdir("..") climbs one more level every time the cell is re-run;
# anchoring on the directory recorded at first execution makes the cell safe
# to run repeatedly within a kernel session.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, ".."))
os.getcwd()

'/Users/laurasisson/odor-pair'

In [53]:
import single.data

# Output file names for the two splits.
train_fname = "trainsingles.pt"
test_fname = "testsingles.pt"

# Hand each split to the project's save helper, train first.
# NOTE(review): where the files land is decided by single.data.save, which is
# not visible here -- confirm before relying on a location.
for split_graphs, split_fname in ((train_data, train_fname), (test_data, test_fname)):
    single.data.save(split_graphs, split_fname)