In [1]:
import os
# Work from the parent of the directory the kernel started in; later cells use
# paths relative to it (e.g. "dataset/aroma_large.json").
# The start directory is remembered so that re-running this cell always lands
# in the same place instead of climbing one more level on every execution.
if "_KERNEL_START_DIR" not in globals():
    _KERNEL_START_DIR = os.getcwd()
os.chdir(os.path.join(_KERNEL_START_DIR, os.pardir))
os.getcwd()

'/Users/laurasisson/odor-pair'

In [2]:
import json
# Load the pair dataset. Each record is a dict with two molecule strings
# ("mol1", "mol2") and a list of "blend_notes" (see the sample shown below).
# JSON is read as UTF-8 explicitly so the result does not depend on the
# platform's default text encoding.
with open("dataset/aroma_large.json", "r", encoding="utf-8") as f:
    data = json.load(f)
len(data), data[0]

(266721,
 {'mol1': 'CCCC(CC)O',
  'mol2': 'CC1=CC2C(C2(C)C)CC1C(=O)C',
  'blend_notes': ['herbal']})

In [3]:
import collections
import graph.utils

# Lookup structures derived from the raw records:
#   all_edges     - every molecule pair, as returned by graph.utils.sort
#                   (presumably a canonical ordering of the pair - confirm)
#   all_nodes     - every molecule that appears in at least one pair
#   all_notes     - every blend note seen in the dataset
#   edge_to_notes - edge -> blend notes of the last record seen for that edge
#   node_to_edges - molecule -> set of edges it takes part in
all_edges, all_nodes, all_notes = set(), set(), set()
edge_to_notes = {}
node_to_edges = collections.defaultdict(set)

for d in data:
    edge = graph.utils.sort(d["mol1"], d["mol2"])
    notes = d["blend_notes"]

    all_edges.add(edge)
    edge_to_notes[edge] = notes
    all_notes.update(notes)

    # Register the edge under both of its endpoints.
    for endpoint in (d["mol1"], d["mol2"]):
        all_nodes.add(endpoint)
        node_to_edges[endpoint].add(edge)


# Sanity checks on the edge set: no molecule is paired with itself, and no
# pair is stored in both orientations.
for n1, n2 in all_edges:
    # Checked first so a self-loop is reported as such (a self-loop is also
    # trivially its own "reverse" edge).
    assert n1 != n2, f"self-loop on {n1!r}"
    if (n2, n1) in all_edges:
        # Show the offending pair before the assertion aborts the cell.
        print(n1, n2)
    assert (n2, n1) not in all_edges, f"edge stored in both orientations: {(n1, n2)!r}"

# Undirected adjacency map: each molecule -> the set of molecules it is paired with.
full_graph = collections.defaultdict(set)
for n1, n2 in all_edges:
    # Record the edge in both directions.
    for src, dst in ((n1, n2), (n2, n1)):
        full_graph[src].add(dst)

# Share of nodes assigned to each side of the split.
train_fraction = .8
test_fraction = .2

# Compare with a tolerance: sums of binary floats are not guaranteed to equal
# 1 exactly, so an exact `==` could reject a perfectly valid pair of fractions.
assert abs(train_fraction + test_fraction - 1) < 1e-9, "train/test fractions must sum to 1"
len(all_notes)

130

In [4]:
import random

def build_edges(subgraph_nodes):
    """Return the edges of ``full_graph`` whose two endpoints are both in ``subgraph_nodes``.

    Args:
        subgraph_nodes: Iterable of nodes (a set or a list both work).

    Returns:
        A set of edges, each in the form produced by ``graph.utils.sort``.
        Nodes that have no entry in ``full_graph`` contribute no edges.
    """
    # Materialise once as a set: the membership test below runs once per
    # neighbour, and callers may pass a list (e.g. the output of random.sample).
    members = set(subgraph_nodes)
    subgraph_edges = set()
    for node in members:
        # full_graph is a defaultdict; .get avoids inserting an empty entry
        # for a node that is not in the graph.
        for other in full_graph.get(node, ()):
            if other in members:
                subgraph_edges.add(graph.utils.sort(node, other))
    return subgraph_edges
# Quick smoke test: edges among 10 randomly drawn nodes (the result varies from run to run).
build_edges(random.sample(list(all_nodes), k=10))

{('CCCCCCCCCCCCOC(=O)CCC', 'CC(C)COC(=O)CCCCCCCCC=C')}

In [None]:
from tqdm.notebook import tqdm

def anneal_better_coverage(lim=100):
    """Build a random train/test node split and refine it with greedy node swaps.

    The initial split puts ``train_fraction`` of ``all_nodes`` (chosen at
    random) on the train side and the rest on the test side.  Each side keeps
    only the edges whose two endpoints are both on that side, so edges that
    straddle the split belong to neither side.

    The loop then proposes ``lim`` swaps of one random train node for one
    random test node.  A swap is kept only if none of the following shrinks:
    the number of notes present on both sides, the number of distinct notes on
    the train side, and the number of distinct notes on the test side.
    Despite the name there is no temperature schedule; acceptance is greedy.

    Args:
        lim: Number of swap proposals to evaluate.

    Returns:
        ``(train_edges, test_edges, covered)``: the two edge sets and a list
        of the notes that appear on at least one edge of each side.

    Raises:
        IndexError: Raised by ``random.choice`` when a swap is proposed while
            one side of the split has no nodes.
    """
    def build_covered(edges):
        # Note -> number of the given edges that carry it.
        covered = collections.Counter()
        for e in edges:
            covered.update(edge_to_notes[e])
        return covered

    def get_data(nodes):
        edges = build_edges(nodes)
        return nodes, edges, build_covered(edges)

    def plan_swap(nodes, covered, to_add, to_remove):
        """Describe the effect of swapping ``to_remove`` out of ``nodes`` and ``to_add`` in.

        Nothing is mutated; returns ``(lost_edges, gained_edges, new_covered)``.
        """
        def stays(n):
            # Membership in the node set as it would be after the swap.
            return n == to_add or (n != to_remove and n in nodes)

        lost_edges = {
            (n1, n2) for (n1, n2) in node_to_edges[to_remove]
            if n1 in nodes and n2 in nodes
        }
        gained_edges = {
            (n1, n2) for (n1, n2) in node_to_edges[to_add]
            if stays(n1) and stays(n2)
        }
        # Counter arithmetic drops notes whose count falls to zero, so
        # len(counter) stays equal to the number of notes still present.
        new_covered = covered - build_covered(lost_edges) + build_covered(gained_edges)
        return lost_edges, gained_edges, new_covered

    # sorted() gives random.sample a sequence with a stable order.
    train_nodes = set(random.sample(sorted(all_nodes), int(len(all_nodes) * train_fraction)))
    test_nodes = all_nodes.difference(train_nodes)

    train_nodes, train_edges, train_covered = get_data(train_nodes)
    test_nodes, test_edges, test_covered = get_data(test_nodes)

    i = 0
    with tqdm(total=lim, disable=True) as pbar:
        while i < lim:
            i += 1
            shared_before = len(train_covered.keys() & test_covered.keys())

            # x1 leaves train for test, x2 leaves test for train.
            x1 = random.choice(list(train_nodes))
            x2 = random.choice(list(test_nodes))
            train_lost, train_gained, new_train_covered = plan_swap(train_nodes, train_covered, x2, x1)
            test_lost, test_gained, new_test_covered = plan_swap(test_nodes, test_covered, x1, x2)

            shared_after = len(new_train_covered.keys() & new_test_covered.keys())
            keeps_coverage = (
                shared_after >= shared_before
                and len(new_train_covered) >= len(train_covered)
                and len(new_test_covered) >= len(test_covered)
            )

            if keeps_coverage:
                # Apply the swap in place; only the edges touching x1/x2 change,
                # so rejected proposals never pay for copying the full sets.
                train_nodes.remove(x1)
                train_nodes.add(x2)
                train_edges -= train_lost
                train_edges |= train_gained
                train_covered = new_train_covered

                test_nodes.remove(x2)
                test_nodes.add(x1)
                test_edges -= test_lost
                test_edges |= test_gained
                test_covered = new_test_covered

            pbar.update(1)

    train_covered_set = {note for note, count in train_covered.items() if count > 0}
    test_covered_set = {note for note, count in test_covered.items() if count > 0}
    covered = list(train_covered_set & test_covered_set)
    return train_edges, test_edges, covered

# Random restarts: run the swap search repeatedly and keep the split that
# retains the most edges overall (train + test).
N_RESTARTS = 500
SWAPS_PER_RESTART = 10000

best_total = 0
# (train_edges, test_edges, covered) of the best restart so far; without this
# only the last run's split would still be available after the loop.
best_split = None
for i in range(N_RESTARTS):
    train_edges, test_edges, covered = anneal_better_coverage(SWAPS_PER_RESTART)
    total = len(train_edges) + len(test_edges)
    if best_split is None or total > best_total:
        best_total = total
        best_split = (train_edges, test_edges, covered)
    print(len(train_edges), len(test_edges), len(covered), best_total)

173564 9916 102 183480
168708 11138 101 183480
172783 10169 102 183480
169192 11227 102 183480
173722 9901 102 183623
172228 10310 102 183623
169667 10842 102 183623
175459 9612 102 185071
169959 10598 102 185071
