In [None]:
import os

# All data paths below ("dataset/...", "notebooks/carvings") are relative to the project
# root, which is presumably one level above the directory the notebook is launched from —
# TODO confirm. Only step up when "dataset" is not already visible, so re-running this
# cell does not keep climbing the directory tree.
if not os.path.isdir("dataset"):
    os.chdir("..")
os.listdir()

In [2]:
import json

# Bounds for exponents passed to math.exp. math.exp raises OverflowError slightly above
# ~709.78, so 709 is a safe upper cap. On the negative side math.exp never raises: inputs
# below roughly -745 simply return 0.0, so MIN_EXPONENT only keeps results strictly positive.
MAX_EXPONENT = 709
MIN_EXPONENT = -709

with open("dataset/full_large.json", encoding="utf-8") as f:
    raw_data = json.load(f)

# Keep only blends where both molecule fields are non-empty (truthy).
full_data = [d for d in raw_data if d["mol1"] and d["mol2"]]
len(full_data), full_data[0]

(266758,
 {'mol1': 'CC1=CC2C(C2(C)C)CC1C(=O)C',
  'mol2': 'CCCC(CC)O',
  'blend_notes': ['herbal']})

In [3]:
import graph.utils
import collections

all_edges = set()
all_nodes = set()
edge_to_notes = {}
node_to_edges = collections.defaultdict(set)

# Register every blend as an edge keyed by graph.utils.sort(mol1, mol2).
for record in full_data:
    first, second = record["mol1"], record["mol2"]
    edge_key = graph.utils.sort(first, second)
    all_edges.add(edge_key)
    # NOTE(review): if the same pair occurs more than once, only the notes of the
    # last occurrence are kept (plain dict assignment overwrites).
    edge_to_notes[edge_key] = graph.utils.canonize(record["blend_notes"])

    for molecule in (first, second):
        all_nodes.add(molecule)
        node_to_edges[molecule].add(edge_key)

# Adjacency view: molecule -> set of molecules it has been blended with.
full_graph = collections.defaultdict(set)
for left, right in all_edges:
    full_graph[left].add(right)
    full_graph[right].add(left)


In [4]:
import random

def build_edges(all_nodes):
    """Return the edges of ``full_graph`` whose two endpoints both lie in ``all_nodes``.

    Each kept edge is keyed with ``graph.utils.sort(node, neighbour)`` (presumably an
    order-independent pair), so an edge reached from either endpoint is stored once.

    :param all_nodes: set of molecules to induce the subgraph on
    :return: set of edge keys
    """
    return {
        graph.utils.sort(node, neighbour)
        for node in all_nodes
        for neighbour in full_graph[node]
        if neighbour in all_nodes
    }

def random_split_carving(train_fraction, seed=None):
    """Randomly split the molecules into train/test and return each side's induced edges.

    Nodes are partitioned, not edges: ``int(len(all_nodes) * train_fraction)`` molecules
    go to train and the rest to test. Only edges whose two molecules fall on the same
    side are returned, so every blend that crosses the split is discarded.

    :param train_fraction: fraction of molecules assigned to train
    :param seed: optional seed for a private ``random.Random`` so a split can be
                 reproduced; ``None`` (default) draws from the global ``random`` module
    :return: ``(train_edges, test_edges)``, two disjoint sets of edge keys
    """
    rng = random if seed is None else random.Random(seed)
    # Sorting first makes the draw depend only on the RNG state, not on set iteration
    # order (which varies with string hash randomization).
    train_nodes = set(rng.sample(sorted(all_nodes), int(len(all_nodes) * train_fraction)))
    test_nodes = all_nodes - train_nodes

    train_edges = build_edges(train_nodes)
    test_edges = build_edges(test_nodes)

    # Disjoint node sets imply disjoint induced edge sets; keep these as sanity checks.
    assert train_nodes.isdisjoint(test_nodes)
    assert train_edges.isdisjoint(test_edges)

    return train_edges, test_edges

In [5]:
import tqdm
import math
# Directory where save_carving writes carving JSON files (relative to the current working directory).
data_path = "notebooks/carvings"

def calculate_decay_constant(initial_temp, final_temp, steps):
    """
    Calculate the per-step multiplicative decay for a geometric cooling schedule.

    The returned constant ``d`` satisfies ``initial_temp * d**steps == final_temp``
    (up to floating-point rounding), i.e. ``d = (final_temp / initial_temp) ** (1 / steps)``.

    :param initial_temp: Initial temperature, must be > 0 (e.g., 50)
    :param final_temp: Final temperature, must be > 0 (e.g., 0.001)
    :param steps: Number of steps, must be > 0 (e.g., 100)
    :return: Decay constant (< 1 when final_temp < initial_temp)
    :raises ValueError: if any argument is not strictly positive
    """
    # Without these checks steps == 0 surfaces as a bare ZeroDivisionError, and
    # non-positive temperatures either hit a math domain error or (when both are
    # negative) silently produce a schedule of negative temperatures.
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps!r}")
    if initial_temp <= 0 or final_temp <= 0:
        raise ValueError(
            f"temperatures must be positive, got initial_temp={initial_temp!r}, final_temp={final_temp!r}"
        )
    # Computed in log space; mathematically equal to (final_temp / initial_temp) ** (1 / steps).
    return math.exp(math.log(final_temp / initial_temp) / steps)

def make_dataset(edges):
    """Expand edge keys back into blend records.

    :param edges: iterable of ``(mol1, mol2)`` pairs that are keys of ``edge_to_notes``
    :return: list of ``{"mol1", "mol2", "blend_notes"}`` dicts, one per edge, in the
             iteration order of ``edges``
    """
    records = []
    for edge in edges:
        first, second = edge
        records.append({"mol1": first, "mol2": second, "blend_notes": edge_to_notes[(first, second)]})
    return records

def save_carving(fname,train_edges,test_edges,covered):
    """Write a train/test carving to ``os.path.join(data_path, fname)`` as JSON.

    :param fname: output file name, joined onto ``data_path``
    :param train_edges: edge keys of the train split (expanded via ``make_dataset``)
    :param test_edges: edge keys of the test split (expanded via ``make_dataset``)
    :param covered: JSON-serialisable collection stored under ``"covered_notes"``
    """
    result = {
        "train": make_dataset(train_edges),
        "test": make_dataset(test_edges),
        "covered_notes": covered,
    }
    # The output directory may not exist in a fresh checkout; create it rather than
    # failing with FileNotFoundError after a possibly long annealing run.
    os.makedirs(data_path, exist_ok=True)
    with open(os.path.join(data_path, fname), "w") as f:
        json.dump(result, f)

def anneal_better_coverage(fname, train_fraction, steps = 10000, initial_temp = 50, final_temp = 1e-3, seed = None):
    """Carve the molecule graph into a train/test split by simulated annealing, then save it.

    Molecules (nodes) are first split at random; an edge (blend) is kept only when both
    of its molecules land on the same side, so edges crossing the split are dropped.
    Each step proposes swapping one random train molecule with one random test molecule:

    * the move is rejected if it lowers the number of notes present on BOTH sides;
    * otherwise it is accepted if the total number of kept edges does not drop, and with
      probability ``exp(delta / temperature)`` if it loses ``-delta`` edges.

    The temperature decays geometrically from ``initial_temp`` towards ``final_temp``
    over ``steps`` steps.

    :param fname: file name of the carving, written via ``save_carving``
    :param train_fraction: fraction of molecules initially assigned to train
    :param steps: number of proposed swaps
    :param initial_temp: starting temperature
    :param final_temp: temperature reached after ``steps`` steps
    :param seed: optional seed for a private ``random.Random``; ``None`` (default)
                 draws from the global ``random`` module
    :raises ValueError: if the initial split leaves either side without molecules
    """
    def build_covered(edges):
        # Multiset of blend notes over the given edges.
        covered = collections.Counter()
        for e in edges:
            covered.update(edge_to_notes[e])
        return covered

    def get_data(nodes):
        # Induced edges of `nodes` together with their note counts.
        edges = build_edges(nodes)
        return nodes, edges, build_covered(edges)

    def get_relevant_edges(new_nodes,n):
        # Edges incident to n whose two endpoints are both in new_nodes.
        edges = node_to_edges[n]
        return {(n1,n2) for (n1,n2) in edges if n1 in new_nodes and n2 in new_nodes}

    def update_data(nodes,edges,covered,to_add,to_remove):
        # Incrementally apply "replace to_remove by to_add" to one side of the split.
        # New objects are returned and the inputs are left untouched, so a rejected
        # move can simply be discarded.
        remove_edges = get_relevant_edges(nodes,to_remove)
        new_nodes = {n for n in nodes if n!=to_remove}
        new_nodes.add(to_add)
        add_edges = get_relevant_edges(new_nodes,to_add)

        lost = build_covered(remove_edges)
        gained = build_covered(add_edges)

        new_edges = edges.union(add_edges).difference(remove_edges)
        # Counter `-` and `+` drop non-positive counts, so keys() remain the notes
        # still present on this side.
        new_covered = covered - lost + gained
        return new_nodes, new_edges, new_covered

    def count_shared(train_counter, test_counter):
        # Number of notes present on both sides of the split.
        return len(set(train_counter.keys()).intersection(set(test_counter.keys())))

    rng = random if seed is None else random.Random(seed)

    test_fraction = 1 - train_fraction
    train_nodes = set(rng.sample(sorted(all_nodes),int(len(all_nodes)*train_fraction)))
    test_nodes = all_nodes.difference(train_nodes)
    if not train_nodes or not test_nodes:
        raise ValueError(
            f"train_fraction={train_fraction} leaves one side of the split empty "
            f"({len(train_nodes)} train / {len(test_nodes)} test molecules)"
        )

    train_nodes, train_edges, train_covered = get_data(train_nodes)
    test_nodes, test_edges, test_covered = get_data(test_nodes)

    # Indexable mirrors of the node sets so a random member can be drawn in O(1);
    # random.choice(list(s)) would rebuild an O(n) list on every step. They are kept
    # in sync with train_nodes/test_nodes by swapping entries on each accepted move.
    train_pool = list(train_nodes)
    test_pool = list(test_nodes)

    temperature = initial_temp
    decay = calculate_decay_constant(initial_temp, final_temp, steps)
    print(f"Using decay of {decay:.4f} from t = {initial_temp:.3f} -> {final_temp:.3f} over {steps:,} steps")
    skipped = 0
    hits = 0
    # Cached shared-note count for the current state; refreshed only on accepted moves.
    current_covered = count_shared(train_covered, test_covered)

    print(f"train/test: {train_fraction}/{test_fraction}:, steps: {steps}. temperature: {temperature}. decay:{decay}.")
    with tqdm.tqdm(total=steps, smoothing=0) as pbar:
        for _ in range(steps):
            fraction = (len(train_edges)+len(test_edges))/len(all_edges)
            temperature *= decay

            pbar.update(1)
            pbar.set_postfix({"Covered":current_covered,"Fraction":fraction,"Temperature":temperature})

            train_idx = rng.randrange(len(train_pool))
            test_idx = rng.randrange(len(test_pool))
            x1, x2 = train_pool[train_idx], test_pool[test_idx]
            new_train_nodes, new_train_edges, new_train_covered = update_data(train_nodes,train_edges,train_covered,x2,x1)
            new_test_nodes, new_test_edges, new_test_covered = update_data(test_nodes,test_edges,test_covered,x1,x2)

            new_covered = count_shared(new_train_covered, new_test_covered)
            # Hard constraint: never reduce the number of notes covered by both sides.
            if new_covered < current_covered:
                skipped += 1
                continue

            # Change in the total number of kept edges caused by the swap.
            delta = (len(new_train_edges) - len(train_edges)) + (len(new_test_edges) - len(test_edges))
            # Metropolis acceptance for edge-losing moves. delta / temperature is strictly
            # negative here, so math.exp cannot overflow (it just underflows towards 0.0).
            if delta < 0 and rng.random() > math.exp(delta / temperature):
                skipped += 1
                continue

            train_nodes, train_edges, train_covered = new_train_nodes, new_train_edges, new_train_covered
            test_nodes, test_edges, test_covered = new_test_nodes, new_test_edges, new_test_covered
            train_pool[train_idx], test_pool[test_idx] = x2, x1
            current_covered = new_covered
            hits += 1

    train_covered_set = {k for k,v in train_covered.most_common() if v>0}
    test_covered_set = {k for k,v in test_covered.most_common() if v>0}
    covered = list(train_covered_set.intersection(test_covered_set))
    print(f"Skipped: {skipped}, Hits: {hits}, Train Edges: {len(train_edges)}, Test Edges: {len(test_edges)}, Covered: {len(covered)}")
    save_carving(fname,train_edges,test_edges,covered)
            
anneal_better_coverage("example.json", 0.5, steps=1000)

Using decay of 0.9892 from t = 50.000 -> 0.001 over 1,000 steps
train/test: 0.5/0.5:, steps: 1000. temperature: 50. decay:0.9892385449788692.


100%|█| 1000/1000 [00:07<00:00, 130.26it/s, Covered=96, Fraction=0.531, Temperat


Skipped: 515, Hits: 485, Train Edges: 57338, Test Edges: 84436, Covered: 96
