In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import networkx as nx
import toolz as tz
from eden.util import configure_logging
import logging
# Root logger (getLogger() called without a name); helpers below switch its
# verbosity with eden's configure_logging (1 by default, 2 while loading data).
logger = logging.getLogger()
configure_logging(logger, verbosity=1)
# NOTE(review): duplicate of the pyplot import above; harmless.
import matplotlib.pyplot as plt
from IPython.core.display import HTML
# Widen the notebook container to 95% and center PNG outputs; as the cell's last
# expression, the HTML object is rendered by Jupyter.
HTML('<style>.container { width:95% !important; }</style><style>.output_png {display: table-cell;text-align: center;vertical-align: middle;}</style>')

In [None]:
import warnings
# Silence every warning for the rest of the session.
# NOTE(review): this also hides numerical RuntimeWarnings (e.g. 0/0 -> nan)
# raised by the scoring helpers below.
warnings.filterwarnings('ignore')

In [None]:
from toolz import curry, pipe
from eden_chem.io.pubchem import download
from eden_chem.io.rdkitutils import sdf_to_nx

# PubChem downloaders with the `active` flag pre-bound (actives vs. inactives of an assay).
download_active = curry(download)(active=True)
download_inactive = curry(download)(active=False)


def get_pos_graphs(assay_id):
    """Return the active compounds of ``assay_id`` as a list of graphs (via ``sdf_to_nx``)."""
    active_sdf = download_active(assay_id)
    return list(sdf_to_nx(active_sdf))


def get_neg_graphs(assay_id):
    """Return the inactive compounds of ``assay_id`` as a list of graphs (via ``sdf_to_nx``)."""
    inactive_sdf = download_inactive(assay_id)
    return list(sdf_to_nx(inactive_sdf))

from GraphOptimizer.load_utils import pre_process, _random_sample
from eden_chem.io.pubchem import get_assay_description

#assay_ids = ['624466','492992','463230','651741','743219','588350','492952','624249','463213','2631','651610']

def load_PUBCHEM_data(assay_id, max_size=20):
    """Download and pre-process the active and inactive compounds of a PubChem assay.

    Logging verbosity is raised to 2 while loading and restored to 1 afterwards,
    also when downloading or pre-processing raises.

    Parameters
    ----------
    assay_id : str
        PubChem assay identifier, e.g. '743219'.
    max_size : int
        Forwarded to ``pre_process`` as ``max_size``; ``initial_max_size`` passed
        alongside it is never smaller than this value.

    Returns
    -------
    (pos_graphs, neg_graphs)
        Results of ``pre_process`` for the active and inactive compounds.
    """
    configure_logging(logger, verbosity=2)
    try:
        logger.debug('_'*80)
        logger.debug('Dataset %s info:'%assay_id)
        desc = get_assay_description(assay_id)
        logger.debug('\n%s'%desc)
        # extract pos and neg graphs
        all_pos_graphs = get_pos_graphs(assay_id)
        all_neg_graphs = get_neg_graphs(assay_id)
        # remove too large and too small graphs and outliers;
        # initial_max_size must not be smaller than the requested max_size
        initial_max_size = max(2000, max_size)
        args = dict(initial_max_size=initial_max_size, fraction_to_remove=.1,
                    n_neighbors_for_outliers=9, remove_similar=False, max_size=max_size)
        logger.debug('\nPositive graphs')
        pos_graphs = pre_process(all_pos_graphs, **args)
        logger.debug('\nNegative graphs')
        neg_graphs = pre_process(all_neg_graphs, **args)
        logger.debug('-'*80)
    finally:
        # always restore the notebook's default verbosity, even on failure
        configure_logging(logger, verbosity=1)
    return pos_graphs, neg_graphs

from rdkit.Chem.Draw import SimilarityMaps
from eden_chem.io.rdkitutils import nx_to_rdkit
import matplotlib.pyplot as plt
from IPython.core.display import display
from eden_chem.display.rdkitutils import nx_to_image

def display_mol(graph, title, part_importance_estimator):
    """Plot a molecule with its per-atom importance drawn as an RDKit similarity map.

    ``part_importance_estimator.predict(graph)`` must return a pair
    ``(node_score_dict, edge_score_dict)``; only the node scores are used.
    """
    node_score_dict, _ = part_importance_estimator.predict(graph)
    # one weight per node, in graph.nodes() order
    # NOTE(review): assumes nx_to_rdkit creates atoms in the same order -- TODO confirm
    weights = [node_score_dict[u] for u in graph.nodes()]
    mol = nx_to_rdkit(graph)
    SimilarityMaps.GetSimilarityMapFromWeights(
        mol, weights, size=(250, 250), alpha=0.075, contourLines=1, sigma=.03)
    plt.title(title)
    plt.show()

def draw_mols(graphs, titles=None, num=None, n_graphs_per_line=7):
    """Display graphs as molecule images, falling back to a plain graph drawing.

    Parameters
    ----------
    graphs : list of graphs
    titles : list, optional
        One title per graph; defaults to the graph's position in ``graphs``.
    num : int, optional
        If given, only the first ``num`` graphs (and titles) are drawn.
    n_graphs_per_line : int
        Number of drawings per row (used by both the image and the fallback).

    Side effect: sets ``g.graph['id'] = str(title)`` on every drawn graph.
    """
    # local import: the fallback needs eden's drawing helpers, which the notebook
    # otherwise only imports in a later cell
    from eden.display import draw_graph_set, map_labels_to_colors

    if titles is None:
        titles = [str(i) for i in range(len(graphs))]
    if num is not None:
        gs = graphs[:num]
        titles = titles[:num]
    else:
        gs = graphs
    for g, t in zip(gs, titles):
        g.graph['id'] = str(t)
    try:
        img = nx_to_image(gs, n_graphs_per_line=n_graphs_per_line, titles=titles)
        display(img)
    except Exception as e:
        # Molecule rendering failed (e.g. graphs that are not valid molecules):
        # report why and draw them as generic labelled graphs instead.
        logger.warning('nx_to_image failed (%s); drawing plain graphs instead', e)
        colors = map_labels_to_colors(gs)
        args = dict(layout='kk', colormap='Set1', vmin=0, vmax=1, vertex_size=80, edge_label=None,
                    vertex_color_dict=colors, vertex_color='-label-', vertex_label=None,
                    ignore_for_layout='nesting')
        draw_graph_set(gs, n_graphs_per_line=n_graphs_per_line, size=7, **args)
        

def display_ktop_mols(graphs, oracle_func, n_max=6):
    """Draw, as molecules, the distinct scores among the ``n_max`` best graphs.

    Graphs are scored once with ``oracle_func``. Among the ``n_max`` highest
    scores (in increasing order), runs of equal scores are collapsed to a single
    drawing of the run's first graph, titled ``'<score> x <run length>'``.
    Does nothing when ``graphs`` is empty.
    """
    from itertools import groupby

    if len(graphs) == 0:
        # nothing to rank; the old implementation crashed with UnboundLocalError here
        return
    scores = [oracle_func(g) for g in graphs]
    # indices of the n_max highest scores, sorted by increasing score
    top_ids = np.argsort(scores)[-n_max:]
    distinct_graphs, titles = [], []
    for score, run in groupby(top_ids, key=lambda idx: scores[idx]):
        run = list(run)
        distinct_graphs.append(graphs[run[0]])
        titles.append('%.6f x %d' % (score, len(run)))
    draw_mols(distinct_graphs, titles=titles, n_graphs_per_line=6)

In [None]:
from collections import defaultdict
from IPython.core.display import display
from eden.display import draw_graph, draw_graph_set

def select_unique(codes, fragments):
    """Keep the first fragment for each distinct code and count every code.

    Returns
    -------
    (unique_codes, unique_fragments, code_counts)
        Distinct codes in first-seen order, the first fragment paired with each,
        and a ``defaultdict(int)`` mapping code -> number of occurrences.
    """
    unique_codes, unique_fragments = [], []
    code_counts = defaultdict(int)
    for code, fragment in zip(codes, fragments):
        # a zero count means this code has not been seen yet
        if code_counts[code] == 0:
            unique_codes.append(code)
            unique_fragments.append(fragment)
        code_counts[code] += 1
    return unique_codes, unique_fragments, code_counts

def show_decomposition_graphs(graphs, decompose_funcs, preprocessors=None, show_labels=None, show_all=True):
    """Draw each graph followed by the distinct fragments of its decomposition.

    The encoder built by ``make_encoder(decompose_funcs, ...)`` returns parallel
    lists of codes and fragments for a graph; each distinct code is drawn once,
    titled with the code and its number of occurrences.

    Parameters
    ----------
    graphs : list of graphs
    decompose_funcs :
        Decomposition passed to ``make_encoder``.
    preprocessors : optional
        Passed to ``make_encoder``.
    show_labels : bool, optional
        Additionally draw the first 40 fragments with node/edge labels.
    show_all : bool
        ``True`` draws every distinct fragment; otherwise only the first 40.
    """
    _feature_size, bitmask = set_feature_size(nbits=14)
    encoding_func = make_encoder(decompose_funcs, preprocessors=preprocessors, bitmask=bitmask, seed=1)

    from eden.display import map_labels_to_colors
    colors = map_labels_to_colors(graphs)

    per_line = 10
    max_shown = per_line * 4  # truncated views show at most 4 rows
    fragment_size = 2
    draw_args = dict(layout='spring', colormap='Set1', vmin=0, vmax=1, vertex_size=80, edge_label=None,
                     vertex_color_dict=colors, vertex_color='-label-', vertex_label=None,
                     ignore_for_layout='nesting')

    for graph in graphs:
        print('_'*80)
        codes, fragments = encoding_func(graph)
        unique_codes, unique_fragments, code_counts = select_unique(codes, fragments)
        titles = ['%d   #%d' % (code, code_counts[code]) for code in unique_codes]
        for fragment, title in zip(unique_fragments, titles):
            fragment.graph['id'] = title
        print('%d unique components in %d fragments' % (len(unique_codes), len(codes)))
        if not unique_fragments:
            print('No fragments')
            continue
        # the decomposed graph itself, drawn larger than its fragments
        draw_graph_set([graph], n_graphs_per_line=per_line, size=7, **draw_args)
        shown = unique_fragments if show_all is True else unique_fragments[:max_shown]
        draw_graph_set(shown, n_graphs_per_line=per_line, size=fragment_size, **draw_args)
        if show_labels:
            draw_graph_set(unique_fragments[:max_shown], n_graphs_per_line=per_line,
                           vertex_size=80, vertex_label='label', edge_label='label',
                           size=fragment_size, ignore_for_layout='nesting')
        

In [None]:
from eden.display import draw_graph, draw_graph_set, map_labels_to_colors
def plot(graphs):
    """Draw ``graphs`` five per row with nodes coloured by label.

    The drawing size is ``log(n) + 1`` where ``n`` is the node count of the first
    graph (at least 1). Does nothing for an empty list.
    """
    if len(graphs) == 0:
        return
    # guard against log(0) = -inf for an empty first graph
    size = np.log(max(len(graphs[0]), 1)) + 1
    colors = map_labels_to_colors(graphs)
    kwargs = dict(colormap='Set1', vmin=0, vmax=1, vertex_size=80, edge_label=None, vertex_color_dict=colors,
                  vertex_color='_label_', vertex_label=None, ignore_for_layout='nesting', layout='spring')
    draw_graph_set(graphs, n_graphs_per_line=5, size=size, **kwargs)

In [None]:
# create random graphs
# give positive negative class to those with cycles larger than threshold
# or with at least 2 cycles larger than threshold that share an edge 

import random

def make_instance(length=20, alphabet_size=3, frac=.3, start_char=97):
    """Return a shuffled string of ``length`` characters over ``alphabet_size`` letters.

    Letters are the consecutive code points starting at ``start_char`` (97 == 'a').
    Each of letters 1..alphabet_size-1 appears ``int(length*frac/(alphabet_size-1))``
    times; letter 0 fills the remaining positions. The order is randomised with
    ``random.shuffle``.

    With ``alphabet_size <= 1`` the result is ``length`` copies of letter 0
    (this used to raise ZeroDivisionError).
    """
    n_other_letters = alphabet_size - 1
    n_frac = int(length * frac / n_other_letters) if n_other_letters > 0 else 0
    chars = [chr(start_char)] * (length - n_frac * n_other_letters)
    for i in range(1, alphabet_size):
        chars.extend(chr(start_char + i) * n_frac)
    random.shuffle(chars)
    return ''.join(chars)

from toolz import curry 


@curry
def random_path_graph(n):
    """Path graph on ``n`` nodes (deterministic, despite the name)."""
    path = nx.path_graph(n)
    return path

@curry
def random_tree_graph(n):
    """Random tree on ``n`` nodes, drawn with ``nx.random_tree``."""
    tree = nx.random_tree(n)
    return tree

@curry
def random_regular_graph(d, n):
    """Random ``d``-regular graph on ``n`` nodes."""
    regular = nx.random_regular_graph(d, n)
    return regular

@curry
def random_degree_seq(n, dmax):
    """Random graph whose expected degrees run linearly (truncated to int) from 1 to ``dmax``.

    The result may be disconnected; callers resample until it is connected.
    """
    expected_degrees = np.linspace(1, dmax, n).astype(int)
    graph = nx.expected_degree_graph(expected_degrees)
    return graph

@curry
def random_dense_graph(n, m):
    """Largest connected component of a uniform random graph with ``n`` nodes and ``m`` edges.

    The component is returned as an independent (mutable) graph with nodes
    relabelled 0..k-1. Returning the raw ``nx.subgraph`` view kept the original
    node ids, so ``make_graph``'s positional labels could skip nodes.
    """
    # a graph is chosen uniformly at random from the set of all graphs with n nodes and m edges
    g = nx.dense_gnm_random_graph(n, m)
    max_cc = max(nx.connected_components(g), key=len)
    # copy() detaches from the frozen subgraph view; relabel to consecutive integers
    return nx.convert_node_labels_to_integers(g.subgraph(max_cc).copy())

@curry
def make_graph(graph_generator, alphabet_size, frac):
    """Label the nodes of a graph randomly and its edges with '1', in place.

    Parameters
    ----------
    graph_generator : networkx graph
        Despite the name this is an already-built graph; it is modified in place
        and returned.
    alphabet_size, frac :
        Forwarded to ``make_instance`` to draw one label character per node.
    """
    G = graph_generator
    labels = make_instance(length=len(G), alphabet_size=alphabet_size, frac=frac)
    # pair labels with the real node ids so graphs whose nodes are not 0..n-1
    # (e.g. a connected component of a larger graph) get a label on every node
    dict_labels = {u: str(l) for u, l in zip(G.nodes(), labels)}
    nx.set_node_attributes(G, dict_labels, 'label')
    nx.set_edge_attributes(G, '1', 'label')
    return G

In [None]:
def make_sequence_data(target_graph, n_instances, diversity):
    """Build ``n_instances`` labelled path graphs from scrambles of ``target_graph``'s labels.

    The node labels of ``target_graph`` (in node iteration order) are scrambled
    ``diversity`` times: each round picks a random cut ``j`` in [1, n] and sets
    ``seq = reversed(seq[j:]) + seq[:j]``. The scrambled sequence labels the nodes
    of a path graph on ``n`` nodes; every edge gets label '1'.
    """
    base_labels = [target_graph.nodes[u]['label'] for u in target_graph.nodes()]
    n_nodes = len(target_graph)
    graphs = []
    for _instance in range(n_instances):
        seq = list(base_labels)
        for _round in range(diversity):
            cut = random.randint(1, n_nodes)
            seq = seq[cut:][::-1] + seq[:cut]
        path = nx.path_graph(n_nodes)
        node_labels = {idx: str(label) for idx, label in enumerate(seq)}
        nx.set_node_attributes(path, node_labels, 'label')
        nx.set_edge_attributes(path, '1', 'label')
        graphs.append(path.copy())
    return graphs

In [None]:
from ego.setup import *
from ego.vectorize import vectorize as ego_vectorize
import time
from sklearn.neighbors import NearestNeighbors
    
def ego_oracle_setup(target_graph, df=None, preproc=None):
    """Return an oracle giving the cosine similarity of a graph to ``target_graph``.

    Graphs are embedded with ``ego_vectorize`` using decomposition ``df`` and
    preprocessors ``preproc``. ``oracle_func(g)`` returns
    ``<v_g, v_t> / sqrt(<v_g, v_g> * <v_t, v_t>)``. When either vector has zero
    norm the similarity is undefined and 0.0 is returned instead of nan.
    """
    target_graph_vec = ego_vectorize([target_graph], decomposition_funcs=df, preprocessors=preproc)
    target_vec_t = target_graph_vec.T
    target_norm = target_graph_vec.dot(target_vec_t).A[0, 0]

    def oracle_func(g):
        g_vec = ego_vectorize([g], decomposition_funcs=df, preprocessors=preproc)
        g_norm = g_vec.dot(g_vec.T).A[0, 0]
        scale_factor = np.sqrt(g_norm * target_norm)
        if scale_factor == 0:
            # warnings are globally silenced in this notebook, so 0/0 would
            # silently become nan and propagate into the geometric mean
            return 0.0
        return g_vec.dot(target_vec_t).A[0, 0] / scale_factor
    return oracle_func

import random
def oracle_setup(target_graph, random_noise=0.05, include_structural_similarity=True):
    """Return a synthetic oracle scoring graphs by similarity to ``target_graph``.

    The score is the geometric mean of four similarities:
      * size: ``max(0, 1 - |len(g) - len(target)| / len(target))``;
      * structural: cosine similarity on cycle/non-cycle and radius-2
        neighborhood parts after ``preprocess_abstract_label(node_label='C', edge_label='1')``
        (presumably abstracting labels away -- TODO confirm);
      * compositional: cosine similarity on node and edge parts;
      * compositional + structural: cosine similarity on length-2 paths and neighborhoods.
    Uniform noise in ``[0, random_noise)`` is added to the score.

    ``oracle_func(g)`` returns the noisy score; ``oracle_func(g, explain=True)`` returns
    ``(tot_score, score, size, structural, composition, comp_and_struct, noise)``.

    NOTE(review): ``include_structural_similarity`` is accepted but not used.
    """
    # explicit import: `import scipy as sp` alone does not guarantee that the
    # scipy.stats submodule is loaded on older SciPy versions
    from scipy.stats import gmean

    df = do_decompose(decompose_cycles_and_non_cycles, decompose_neighborhood(radius=2))
    preproc = preprocess_abstract_label(node_label='C', edge_label='1')
    structural_oracle_func = ego_oracle_setup(target_graph, df, preproc)

    df = do_decompose(decompose_nodes_and_edges)
    compositional_oracle_func = ego_oracle_setup(target_graph, df, None)

    df = do_decompose(decompose_path(length=2), decompose_neighborhood)
    comp_and_struct_oracle_func = ego_oracle_setup(target_graph, df, None)

    target_size = len(target_graph)

    def oracle_func(g, explain=False):
        g_size = len(g)
        size_similarity = max(0, 1 - abs(g_size - target_size)/float(target_size))
        structural_similarity = structural_oracle_func(g)
        composition_similarity = compositional_oracle_func(g)
        comp_and_struct_similarity = comp_and_struct_oracle_func(g)
        score = gmean([size_similarity, structural_similarity, composition_similarity, comp_and_struct_similarity])
        noise = random.random()*random_noise
        tot_score = score + noise
        if explain:
            return tot_score, score, size_similarity, structural_similarity, composition_similarity, comp_and_struct_similarity, noise
        return tot_score

    return oracle_func

In [None]:
def select_k_best(graphs, oracle_func, num=100):
    """Return the ``num`` highest-scoring graphs under ``oracle_func``, best first.

    The sort is stable, so graphs with equal scores keep their input order.
    """
    ranked = sorted(graphs, key=oracle_func, reverse=True)
    return ranked[:num]

---

In [None]:
from scipy.interpolate import Rbf
from eden.display.graph_layout import KKEmbedder

def visualize_importance(g, part_importance_estimator, title=''):
    """Draw ``g`` on top of a heat map of its per-node importance.

    Nodes are relabelled to integers and laid out with ``KKEmbedder``; the values
    from ``part_importance_estimator.node_importance`` are interpolated with a
    Gaussian RBF over a mesh (50 steps across the x-range) and drawn as filled
    contours, followed by the graph and importance-coloured node markers.
    Draws into the current matplotlib axes; does not call ``plt.show()``.
    """
    g = nx.convert_node_labels_to_integers(g)

    node_imp_dict = part_importance_estimator.node_importance(g)

    colors = map_labels_to_colors([g])
    node_color = [colors.get(attrs.get('label', '.'), 0) for _, attrs in g.nodes(data=True)]
    pos = KKEmbedder().transform(g)
    node_ids = sorted(pos.keys())
    X = np.array([pos[u] for u in node_ids])
    y = np.array([node_imp_dict[u] for u in node_ids])
    a, b = X.T

    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    h = (x_max - x_min)/50  # step size in the mesh
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

    rbf = Rbf(a, b, y, epsilon=h*2, function='gaussian')
    Z = rbf(xx, yy).reshape(xx.shape)
    plt.contourf(xx, yy, Z, 50, cmap='hot')

    nodes = nx.draw_networkx_nodes(g, pos, node_size=90, linewidths=2, node_color=node_color, cmap='Set1')
    nodes.set_edgecolor('k')
    nx.draw_networkx_edges(g, pos, width=2, edge_color='grey')

    plt.scatter(a, b, c=y, s=300, cmap='hot')
    plt.title(title)
    plt.axis('off')
    
def paired_visualize_importance(g1, g2, part_importance_estimator, title1='', title2=''):
    """Show the importance maps of two graphs side by side in a 15x7 figure."""
    plt.figure(figsize=(15, 7))
    panels = [(g1, title1), (g2, title2)]
    for column, (graph, panel_title) in enumerate(panels, start=1):
        plt.subplot(1, 2, column)
        visualize_importance(graph, part_importance_estimator, title=panel_title)
    plt.tight_layout()
    plt.show()
    

In [None]:
def plot_grammar(grammar, n_max=9):
    """Print ``grammar`` and draw each interface group: the interface plus up to ``n_max`` cores.

    ``grammar.get()`` must return a sized sequence of groups, each a sized
    iterable of ``(interface, cip)`` pairs. Groups without a non-empty interface
    are skipped; the first non-empty interface represents the group.
    """
    print(grammar)
    groups = grammar.get()
    n_groups = len(groups)
    for position, pairs in enumerate(groups, start=1):
        cores = [cip for interface, cip in pairs if len(cip) > 0]
        interfaces = [interface for interface, cip in pairs if len(interface) > 0]
        if not interfaces:
            continue
        print('-'*80)
        print('interface %d/%d: interface size:%d  #cores:%d' % (position, n_groups, len(interfaces[0]), len(pairs)))
        draw_graph_set([interfaces[0]] + cores[:n_max], n_graphs_per_line=10, vertex_size=90, size=4, layout='spring')

In [None]:
from scipy.ndimage import gaussian_filter1d
from scipy import interpolate
from eden.display import draw_graph, draw_graph_set, map_labels_to_colors

    
def draw_graphs(graphs, titles, n_graphs_per_line=6):
    """Draw ``graphs`` in a grid with ``layout='KK'``, titled by ``titles``.

    Side effect: sets ``g.graph['id'] = str(title)`` on the caller's graphs
    (``graphs[:]`` is only a shallow copy of the list). The drawing size is
    ``log(n) + 1`` for the node count ``n`` of the first graph (at least 1).
    Does nothing for an empty list.
    """
    if len(graphs) == 0:
        return
    # guard against log(0) = -inf for an empty first graph
    size = np.log(max(len(graphs[0]), 1)) + 1
    colors = map_labels_to_colors(graphs)
    gs = graphs[:]
    for g, t in zip(gs, titles):
        g.graph['id'] = str(t)
    kwargs = dict(colormap='Set1', vertex_size=80, edge_label=None, vertex_color_dict=colors,
                  vertex_color='_label_', vertex_label=None, layout='KK')
    draw_graph_set(gs, n_graphs_per_line=n_graphs_per_line, size=size, **kwargs)

    
def display_ktop_graphs(graphs, oracle_func, n_max=6):
    """Draw the distinct scores among the ``n_max`` best graphs.

    Graphs are scored once with ``oracle_func``. Among the ``n_max`` highest
    scores (in increasing order), runs of equal scores are collapsed to a single
    drawing of the run's first graph, titled ``'<score> x <run length>'``.
    Does nothing when ``graphs`` is empty.
    """
    from itertools import groupby

    if len(graphs) == 0:
        # nothing to rank; the old implementation crashed with UnboundLocalError here
        return
    scores = [oracle_func(g) for g in graphs]
    # indices of the n_max highest scores, sorted by increasing score
    top_ids = np.argsort(scores)[-n_max:]
    distinct_graphs, titles = [], []
    for score, run in groupby(top_ids, key=lambda idx: scores[idx]):
        run = list(run)
        distinct_graphs.append(graphs[run[0]])
        titles.append('%.6f x %d' % (score, len(run)))
    draw_graphs(distinct_graphs, titles=titles, n_graphs_per_line=6)
    
    
def display_graph_list(graphs, oracle_func, score_estimator, n_max, display_as_molecules=False):
    """Draw the ``n_max`` best graphs by oracle score, titled with true and predicted scores.

    Each graph is scored by ``oracle_func`` exactly once, so the title shows the
    same value used for ranking (the oracle may add random noise, and it was
    previously evaluated twice). Ties keep their input order.
    """
    scored = [(oracle_func(g), g) for g in graphs]
    top = sorted(scored, key=lambda pair: pair[0], reverse=True)[:n_max]
    gs = [g for _, g in top]
    titles = ['true:%.3f pred:%.3f  ' % (true_score, score_estimator.predict([g])[0])
              for true_score, g in top]
    if display_as_molecules:
        draw_mols(gs, titles=titles, n_graphs_per_line=9)
    else:
        draw_graphs(gs, titles=titles, n_graphs_per_line=9)

def smooth(x, y, sigma_fact=7):
    """Gaussian-smooth ``y`` and resample it on 200 points spanning ``x``.

    ``y`` is filtered with ``gaussian_filter1d`` using
    ``sigma = (max(x) - min(x)) / sigma_fact`` (applied in samples of ``y``); an
    interpolating spline through ``(x, smoothed y)`` is then evaluated on
    ``np.linspace(min(x), max(x), 200)``.

    Returns ``(xnew, ynew)``.
    """
    lo, hi = min(x), max(x)
    smoothed = gaussian_filter1d(y, (hi - lo) / sigma_fact)
    spline = interpolate.InterpolatedUnivariateSpline(x, smoothed)
    grid = np.linspace(lo, hi, 200)
    return grid, spline(grid)
    
def plot_status(estimated_mean_and_std_target, current_best, scores_list, num_oracle_queries, sigma_fact=7):
    """Plot the optimisation trajectory recorded by the monitor.

    Left axis (score), each series raw and smoothed with ``smooth(..., sigma_fact)``:
      * predicted mean for the target graph, with shaded bands of +/- std,
        std/10 and std/100;
      * median oracle score of each iteration's graphs;
      * best oracle score of each iteration's graphs.
    The y-range is clipped to [0, 1]. Right axis: number of oracle queries.
    """
    mean_std_rows = np.array(estimated_mean_and_std_target)
    target_means, target_stds = mean_std_rows.T
    steps = range(len(estimated_mean_and_std_target))

    fig = plt.figure(figsize=(17, 5))
    score_ax = fig.add_subplot(1, 1, 1)

    # predicted score of the target with nested uncertainty bands
    for shrink in (1, 10, 100):
        half_width = target_stds / shrink
        score_ax.fill_between(steps, target_means + half_width, target_means - half_width, alpha=.1, color='steelblue')
    score_ax.plot(target_means, linestyle='dashed')
    smooth_x, smooth_y = smooth(range(len(target_means)), target_means, sigma_fact)
    score_ax.plot(smooth_x, smooth_y, lw=5, color='steelblue', label='true target graph scored by predictor')

    # median oracle score of the generated graphs
    medians = [np.median(iteration_scores) for iteration_scores in scores_list]
    score_ax.plot(medians, color='darkorange', lw=1, linestyle='dotted')
    smooth_x, smooth_y = smooth(range(len(medians)), medians, sigma_fact)
    score_ax.plot(smooth_x, smooth_y, lw=3, linestyle='dashed', color='darkorange', label='median of generated graphs scored by oracle')

    # best oracle score per iteration
    score_ax.plot(current_best, color='darkorange', linestyle='dashed')
    smooth_x, smooth_y = smooth(range(len(current_best)), current_best, sigma_fact)
    score_ax.plot(smooth_x, smooth_y, lw=5, color='darkorange', label='current opt graph scored by oracle')
    score_ax.legend()
    lowest = min(min(medians), min(current_best))
    highest = max(max(medians), max(current_best))
    score_ax.set_ylim(max(0, lowest), min(1, highest))
    score_ax.set_xlabel('num iteration')
    score_ax.set_ylabel('score')
    score_ax.grid()

    queries_ax = score_ax.twinx()  # second y-axis sharing the iteration axis
    queries_ax.plot(num_oracle_queries, linestyle='dotted', color='gray', alpha=.5, label='# queries to oracle')
    queries_ax.set_ylabel('# queries')
    fig.tight_layout()
    plt.show()
    

import time
def make_monitor(target_graph, oracle_func, show_step=1, display_as_molecules=False):
    """Build the progress callback passed to ``optimize(..., monitor=...)``.

    ``oracle_func`` must accept ``explain=True`` and return the 7-tuple produced
    by ``oracle_setup``. Each call of the returned
    ``monitor(i, graphs, all_graphs, score_estimator)``:
      * records the estimator's mean/uncertainty for ``target_graph``, the oracle
        scores of ``graphs`` and ``len(all_graphs)`` (number of oracle queries);
      * prints the best score and its decomposition;
      * for ``i > 0`` and every ``show_step`` iterations, plots the trajectory
        (once more than 5 points exist), prints the true-vs-predicted correlation
        and elapsed wall time, and draws the best graphs.
    """
    estimated_mean_and_std_target = []
    current_best = []
    scores_list = []
    duration = []
    num_oracle_queries = []

    def monitor(i, graphs, all_graphs, score_estimator):
        num_oracle_queries.append(len(all_graphs))
        mu, sigma = score_estimator.predict([target_graph]), score_estimator.predict_uncertainty([target_graph])
        estimated_mean_and_std_target.append((mu[0], sigma[0]))

        true_scores = [oracle_func(g) for g in graphs]
        pred_scores = score_estimator.predict(graphs)

        scores_list.append(true_scores)
        best_score = max(true_scores)
        best_graph = graphs[np.argmax(true_scores)]
        print('< %.3f > best score after %d queries'%(best_score, len(all_graphs)))
        (_, score, size_similarity, structural_similarity, composition_similarity,
         comp_and_struct_similarity, _) = oracle_func(best_graph, explain=True)
        print('    score decomposition: %.3f = size:%.3f  structure:%.3f  composition:%.3f  comp_struct:%.3f'%(score, size_similarity, structural_similarity, composition_similarity, comp_and_struct_similarity))
        current_best.append(best_score)
        # time.clock() was removed in Python 3.8; perf_counter measures elapsed time
        duration.append(time.perf_counter())
        if i > 0 and (show_step == 1 or i % show_step == 0):
            if len(estimated_mean_and_std_target) > 5:
                plot_status(estimated_mean_and_std_target, current_best, scores_list, num_oracle_queries, 5)

            if len(duration) > 2:
                print('%d) corr coeff true vs preds: %.3f  runtime:%.1f mins' % (i+1, np.corrcoef(true_scores, pred_scores)[0, 1], (duration[-1]-duration[-2])/60))
            display_graph_list(graphs+[target_graph], oracle_func, score_estimator, n_max=9, display_as_molecules=display_as_molecules)

    return monitor

In [None]:
from ego.vectorize import hash_graph

def make_variants(target_graph):
    """Return the graphs obtained from ``target_graph`` by an edge swap followed by a node-label swap.

    ``NeighborhoodEdgeSwap(n_edges=2, n_neighbors=10)`` is applied to the target,
    then ``NeighborhoodNodeLabelSwap(n_nodes=1, n_neighbors=10)`` to every result.
    """
    from ego.optimization.neighborhood_edge_swap import NeighborhoodEdgeSwap
    from ego.optimization.neighborhood_node_label_swap import NeighborhoodNodeLabelSwap
    # Node deletions (ego.optimization.neighborhood_node_remove.NeighborhoodNodeRemove)
    # can be appended to `transformations` if wanted; they were built but never used.
    transformations = [
        NeighborhoodEdgeSwap(n_edges=2, n_neighbors=10),
        NeighborhoodNodeLabelSwap(n_nodes=1, n_neighbors=10),
    ]

    gs = [target_graph]
    for transformation in transformations:
        gs = [ng for g in gs for ng in transformation.neighbors(g)]
    return gs

def remove_duplicates(graphs):
    """Deduplicate graphs by their ``hash_graph`` value under a radius-2 neighborhood decomposition.

    For graphs sharing a hash the last one is kept, at the position where that
    hash first occurred.
    """
    decomposition = decompose_neighborhood(radius=2)
    by_hash = {}
    for graph in graphs:
        by_hash[hash_graph(graph, decomposition_funcs=decomposition)] = graph
    return list(by_hash.values())

def _make_artificial_base_graph(graph_type):
    """Return an unlabelled random graph of type 'path', 'tree', 'degree', 'regular' or 'dense'."""
    if graph_type == 'path':
        return random_path_graph(n=15)
    if graph_type == 'tree':
        return random_tree_graph(n=18)
    if graph_type == 'degree':
        n, dmax = 12, 4
        graph = random_degree_seq(n, dmax)
        # expected-degree graphs can be disconnected: resample until connected
        while not nx.is_connected(graph):
            graph = random_degree_seq(n, dmax)
        return graph
    if graph_type == 'regular':
        return random_regular_graph(d=3, n=14)
    if graph_type == 'dense':
        return random_dense_graph(n=15, m=38)
    raise ValueError("Unknown GRAPH_TYPE %r; expected 'path', 'tree', 'degree', 'regular' or 'dense'" % (graph_type,))


def build_artificial_experiment(GRAPH_TYPE, n_init_instances, n_domain_instances, alphabet_size, diversity, max_score_threshold):
    """Create a synthetic optimisation task around a random labelled target graph.

    The target is a random graph of ``GRAPH_TYPE`` labelled by ``make_graph``; the
    domain consists of its first ``n_domain_instances`` variants (``make_variants``)
    that score below ``max_score_threshold`` with a noise-free oracle, deduplicated.
    The initial sample is the best ``n_init_instances // 2`` domain graphs plus as
    many random ones from the rest.

    Returns ``(init_graphs, domain_graphs, oracle_func, target_graph)``.
    Raises ``ValueError`` for an unknown ``GRAPH_TYPE`` (previously an
    UnboundLocalError). NOTE(review): ``diversity`` is accepted but unused.
    """
    base_graph = _make_artificial_base_graph(GRAPH_TYPE)
    target_graph = make_graph(base_graph, alphabet_size=alphabet_size, frac=.5)
    domain_graphs = make_variants(target_graph)[:n_domain_instances]
    oracle_func = oracle_setup(target_graph, random_noise=0.0)
    # drop variants that are already too close to the target
    domain_graphs = [g for g in domain_graphs if oracle_func(g) < max_score_threshold]
    domain_graphs = remove_duplicates(domain_graphs)
    sorted_graphs = sorted(domain_graphs, key=lambda g: oracle_func(g), reverse=True)
    # half of the initial sample: top scorers; other half: random picks from the rest
    half_size = int(n_init_instances/2)
    rest_graphs = sorted_graphs[half_size:]
    random.shuffle(rest_graphs)
    init_graphs = sorted_graphs[:half_size] + rest_graphs[:half_size]

    return init_graphs, domain_graphs, oracle_func, target_graph

def target_quality(target_graph, graphs, max_score_threshold, min_score_threshold):
    """Fraction of ``graphs`` whose noise-free oracle score against ``target_graph``
    lies strictly between ``min_score_threshold`` and ``max_score_threshold``.

    The value is printed (without a newline) and returned; an empty ``graphs``
    yields 0.0 instead of a ZeroDivisionError.
    """
    if len(graphs) == 0:
        quality_score = 0.0
    else:
        oracle_func = oracle_setup(target_graph, random_noise=0.0)
        n_in_range = sum(1 for g in graphs if min_score_threshold < oracle_func(g) < max_score_threshold)
        quality_score = n_in_range/float(len(graphs))
    print('%.3f   '%(quality_score), end=" ")
    return quality_score

def build_chemical_experiment(assay_id, n_init_instances, n_domain_instances, max_score_threshold, n_targets):
    """Create an optimisation task from a PubChem assay.

    Active and inactive compounds are loaded (``max_size=n_domain_instances``),
    deduplicated and shuffled. Among the first ``n_targets`` graphs, the target is
    the one with the highest ``target_quality`` (fraction of domain graphs scoring
    between ``max_score_threshold / 1.5`` and ``max_score_threshold``). The domain
    keeps only graphs scoring below ``max_score_threshold``; the initial sample is
    the best ``n_init_instances // 2`` of them plus as many random ones from the rest.

    Returns ``(init_graphs, domain_graphs, oracle_func, target_graph)``.
    """
    pos_graphs, neg_graphs = load_PUBCHEM_data(assay_id, max_size=n_domain_instances)
    candidate_pool = remove_duplicates(pos_graphs + neg_graphs)
    random.shuffle(candidate_pool)

    def quality(candidate):
        return target_quality(candidate, candidate_pool, max_score_threshold, max_score_threshold/1.5)

    target_graph = max(candidate_pool[:n_targets], key=quality)
    oracle_func = oracle_setup(target_graph, random_noise=0.0)

    domain_graphs = [g for g in candidate_pool if oracle_func(g) < max_score_threshold]
    ranked = sorted(domain_graphs, key=lambda g: oracle_func(g), reverse=True)
    half_size = int(n_init_instances/2)
    remainder = ranked[half_size:]
    random.shuffle(remainder)
    init_graphs = ranked[:half_size] + remainder[:half_size]
    return init_graphs, domain_graphs, oracle_func, target_graph

def display_score_statistics(domain_graphs, oracle_func):
    """Show a 30-bin density histogram of ``oracle_func`` scores over ``domain_graphs``."""
    scores = np.array([oracle_func(g) for g in domain_graphs])
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(scores, 30, density=True, alpha=.3)
    ax.set_title('Scores')
    ax.set_xlabel('oracle score')
    ax.set_ylabel('density')
    ax.grid()
    plt.show()

---

# Experiments

In [None]:
%%time

EXPERIMENT_TYPE = 'CHEMICAL'  # 'ARTIFICIAL' or 'CHEMICAL'

if EXPERIMENT_TYPE == 'ARTIFICIAL':
    res = build_artificial_experiment(
        GRAPH_TYPE='tree', # path  tree  degree  regular  dense
        n_init_instances=10,
        n_domain_instances=100,
        alphabet_size=4,
        diversity=2,
        max_score_threshold=.8)
elif EXPERIMENT_TYPE == 'CHEMICAL':
    res = build_chemical_experiment(
        assay_id='743219',  # assay_ids = ['624466','492992','463230','651741','743219','588350','492952','624249','463213','2631','651610']
        n_init_instances=50,
        n_domain_instances=300,
        max_score_threshold=.8,
        n_targets=2)
else:
    # fail fast instead of a confusing NameError on `res` below
    raise ValueError("EXPERIMENT_TYPE must be 'ARTIFICIAL' or 'CHEMICAL', got %r" % (EXPERIMENT_TYPE,))

init_graphs, domain_graphs, oracle_func, target_graph = res
print('Generated %d graphs'%len(domain_graphs))
display_score_statistics(domain_graphs, oracle_func)

In [None]:
# decomposition definition
# Used by the grammar and decomposition-display cells below: node/edge parts,
# length-2 paths, neighborhoods (default radius and radius=2) and cycles.
# NOTE: the optimisation cell rebinds `decomposition_function` (adding radius=3).

#decomposition_function = do_decompose(decompose_nodes_and_edges, decompose_path(min_len=2, max_len=3), decompose_neighborhood, decompose_neighborhood(radius=2), decompose_cycles)
decomposition_function = do_decompose(decompose_nodes_and_edges, decompose_path(length=2), decompose_neighborhood, decompose_neighborhood(radius=2), decompose_cycles)

In [None]:
DISPLAY_GRAMMAR = False

# Optional: fit an lsgg_ego graph grammar on the initial sample and draw its productions.
if DISPLAY_GRAMMAR:
    from graphlearn.lsgg_ego import lsgg_ego
    grammar = lsgg_ego(
        decomposition_function=decomposition_function,
        thickness=1,
        filter_min_cip=1,
        filter_min_interface=2,
        nodelevel_radius_and_thickness=False,
    )
    grammar.fit(init_graphs)
    plot_grammar(grammar)

In [None]:
DISPLAY_DECOMPOSITION = False

# Optional: show the fragments `decomposition_function` extracts from the target graph.
if DISPLAY_DECOMPOSITION:
    print('This is how the target graph is decomposed:')
    show_decomposition_graphs(
        [target_graph],
        decompose_funcs=decomposition_function,
    )

In [None]:
print('Best graphs in initial sample of %d' % len(init_graphs))
# molecules are rendered with RDKit only for the chemical experiment
(display_ktop_mols if EXPERIMENT_TYPE == 'CHEMICAL' else display_ktop_graphs)(
    init_graphs + [target_graph], oracle_func, n_max=6)

In [None]:
from ego.optimization.neighborhood_node_label_mutation import NeighborhoodNodeLabelMutation
# Sanity check: best-scoring neighbours of the target under node-label mutation (n_nodes=5).
nnlm = NeighborhoodNodeLabelMutation(n_nodes=5, n_neighbors=10)
nnlm.fit(domain_graphs, None)
gs = nnlm.neighbors(target_graph)
gs = remove_duplicates(gs)
# draw as molecules only in the chemical experiment (draw_mols needs RDKit-valid graphs)
if EXPERIMENT_TYPE == 'CHEMICAL':
    display_ktop_mols(gs+[target_graph], oracle_func, n_max=6)
else:
    display_ktop_graphs(gs+[target_graph], oracle_func, n_max=6)

In [None]:
from ego.optimization.neighborhood_edge_label_mutation import NeighborhoodEdgeLabelMutation
# Sanity check: best-scoring neighbours of the target under edge-label mutation (n_edges=5).
nnlm = NeighborhoodEdgeLabelMutation(n_edges=5, n_neighbors=10)
nnlm.fit(domain_graphs, None)
gs = nnlm.neighbors(target_graph)
gs = remove_duplicates(gs)
# draw as molecules only in the chemical experiment (draw_mols needs RDKit-valid graphs)
if EXPERIMENT_TYPE == 'CHEMICAL':
    display_ktop_mols(gs+[target_graph], oracle_func, n_max=6)
else:
    display_ktop_graphs(gs+[target_graph], oracle_func, n_max=6)

In [None]:
from ego.optimization.neighborhood_edge_remove import NeighborhoodEdgeRemove
# Sanity check: best-scoring neighbours of the target after removing edges (n_edges=1).
nnlm = NeighborhoodEdgeRemove(n_edges=1, n_neighbors=None)
nnlm.fit(domain_graphs, None)
gs = nnlm.neighbors(target_graph)
gs = remove_duplicates(gs)
# draw as molecules only in the chemical experiment (draw_mols needs RDKit-valid graphs)
if EXPERIMENT_TYPE == 'CHEMICAL':
    display_ktop_mols(gs+[target_graph], oracle_func, n_max=6)
else:
    display_ktop_graphs(gs+[target_graph], oracle_func, n_max=6)

In [None]:
%%time
from ego.optimization.optimize import optimizer_setup, optimize

# performance monitor; candidates are drawn as molecules only in the chemical experiment
monitor = make_monitor(target_graph, oracle_func, display_as_molecules=(EXPERIMENT_TYPE == 'CHEMICAL'))

from ego.setup import *
# Decomposition shared by the score estimator and both grammars; richer than the
# earlier one (adds radius-3 neighborhoods).
decomposition_function = do_decompose(
    decompose_nodes_and_edges,
    decompose_path(length=2),
    decompose_neighborhood,
    decompose_neighborhood(radius=2),
    decompose_neighborhood(radius=3),
    decompose_cycles)

decomposition_score_estimator = decomposition_function
decomposition_fixed_grammar = decomposition_function
decomposition_adaptive_grammar = decomposition_function

neighborhood_estimators, score_estimator = optimizer_setup(
    # surrogate score model
    decomposition_score_estimator=decomposition_score_estimator,
    use_UCB_estimator=False,
    use_RandomForest_estimator=True,
    use_Linear_estimator=False,
    use_EI_estimator=False,
    n_estimators=200,
    exploration_vs_exploitation=0,

    # local perturbation neighbourhoods
    use_edge_swapping=True,
    n_neighbors_edge_swapping=None, n_edge_swapping=1,

    use_node_label_swapping=True,
    n_neighbors_node_label_swapping=None, n_node_label_swapping=1,

    use_edge_label_swapping=False,
    n_neighbors_edge_label_swapping=None, n_edge_label_swapping=1,

    use_node_removal=False,
    n_neighbors_node_removal=None, n_node_removal=1,

    use_edge_removal=True,
    n_neighbors_edge_removal=None, n_edge_removal=1,

    use_node_label_mutation=True,
    n_neighbors_node_mutation=100, n_node_mutation=1,

    use_edge_label_mutation=False,
    n_neighbors_edge_mutation=100, n_edge_mutation=1,

    # grammar-based neighbourhoods
    use_fixed_grammar=False,
    n_neighbors_fixed_grammar=None,
    conservativeness_fixed_grammar=3,
    context_size_fixed_grammar=1,
    decomposition_fixed_grammar=decomposition_fixed_grammar,
    domain_graphs_fixed_grammar=domain_graphs,

    use_adaptive_grammar=True,
    n_neighbors_adaptive_grammar=None,
    conservativeness_adaptive_grammar=5,
    context_size_adaptive_grammar=1,
    part_size_adaptive_grammar=4,
    decomposition_adaptive_grammar=decomposition_adaptive_grammar)

graphs = optimize(
    init_graphs[:],  # copy so the initial sample is left untouched
    oracle_func,
    n_iter=100,
    n_queries_to_oracle_per_iter=200,
    frac_instances_to_remove_per_iter=.5,
    sample_size_to_perturb=4,
    n_steps_driven_by_estimator=1,
    sample_size_for_grammars=40,
    neighborhood_estimators=neighborhood_estimators,
    score_estimator=score_estimator,
    monitor=monitor)

In [None]:
print('Final: best synthesized graphs in set of size %d'%len(graphs))
# two rows of six drawings; molecules only in the chemical experiment
if EXPERIMENT_TYPE == 'CHEMICAL':
    display_ktop_mols(graphs+[target_graph], oracle_func, n_max=6*2)
else:
    display_ktop_graphs(graphs+[target_graph], oracle_func, n_max=6*2)

---