# AI-driven Automated Discovery Tools for Synthetic Circuit Engineering

Authors | Affiliation | Published
--------|-------------|----------
[Mayalen Etcheverry](https://znah.net) | [INRIA, Flowers team](https://flowers.inria.fr/), [Poietis](https://poietis.com/) | September, 2023
[Clément Moulin-Frier](http://clement-moulin-frier.github.io/) | [INRIA, Flowers team](https://flowers.inria.fr/) |
[Pierre-Yves Oudeyer](http://www.pyoudeyer.com/) | [INRIA, Flowers team](https://flowers.inria.fr/) |
[Michael Levin](https://drmichaellevin.org/) | [The Levin Lab, Tufts University](https://drmichaellevin.org/)| <a href="https://colab.research.google.com/github/flowersteam/curious-exploration-of-grn-competencies/tree/main/notebooks/tuto2.ipynb" target="_blank" id="colablink" class="colab-root"><span id="reprotext">Reproduce in </span><span class="colab-span">Notebook</span></a>

[TOC]

## Introduction

!!! hint "TL;DR"
    This second tutorial accompanies our paper [Automated Discovery Tools Reveal Behavioral Competencies of Biological Networks](https://developmentalsystems.org/curious-exploration-of-grn-competencies/paper.html), and more particularly the last section "Reuse of the framework as an alternative strategy to gene circuit engineering".

### &#128221; How to follow this tutorial

* In [**Introduction**](#introduction), we detail the modelstep function used to simulate our synthetic gene regulatory network.
* In [**Part 1**](#part-1-curiosity-driven-search-for-the-discovery-of-diverse-oscillator-circuits), we give a step-by-step walkthrough of using curiosity-driven search algorithm for the discovery of diverse oscillator circuits.
* In [**Part 2**](#part-2-curiosity-search-as-an-alternative-to-pure-optimization-driven-search-strategy?), we compare with a pure (gradient-based) optimization search strategy as proposed in the literature.

### Setup

In [1]:
#@title [Notebook config]
nb_mode = "load" #@param ["run", "load"]
nb_save_outputs = True #@param {type:"boolean"}
nb_save_outputs = nb_save_outputs and nb_mode == "run"
nb_renderer = "html" #@param ["html", "img"]

In [2]:
#@title [Ignore warnings]
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

In [3]:
#@title [Install & imports]
# %pip install -q addict
# %pip install -q autodiscjax
# %pip install -q plotly

import jax
jax.config.update('jax_platform_name', 'cpu')

from addict import Dict
from autodiscjax import DictTree
from autodiscjax.experiment_pipelines import run_imgep_experiment, run_rs_experiment
from autodiscjax.modules import optimizers
import autodiscjax.modules.grnwrappers as grn
from autodiscjax.utils.create_modules import *
from autodiscjax.utils.misc import uniform
from autodiscjax.utils.timeseries import is_periodic
from copy import deepcopy
from fractions import Fraction
from functools import partial
import equinox as eqx
from IPython.display import display, HTML, Image
from jax import lax, jit, vmap, value_and_grad
import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu
from jaxtyping import Array
import matplotlib.colors as mcolors
from matplotlib.patches import Ellipse
import optax
import os
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly
import numpy as np
import requests

I0000 00:00:1695483851.057819    7085 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


In [4]:
#@title [Plot utils]

# color to visualize time from trajectory start point A (red) to trajectory end point B (cyan) 
n_points = 100 
c = [mcolors.hsv_to_rgb((step / (2*n_points), 1, 1)) for step in range(n_points)]
traj_cscale=[(x.item(), f'rgb({int(color[0]*255)}, {int(color[1]*255)}, {int(color[2]*255)})') for (x, color) in zip(jnp.linspace(0., 1., len(c)), c)]

# default colors
default_colors = ['rgb(204,121,167)', 
                  'rgb(0,114,178)',
                  'rgb(230,159,0)',  
                  'rgb(0,158,115)',
                  'rgb(127,127,127)',
                  'rgb(240,228,66)',
                  'rgb(148,103,189)',
                  'rgb(86,180,233)',
                  'rgb(213,94,0)',
                  'rgb(140,86,75)',
                  'rgb(214,39,40)',
                  'rgb(0,0,0)']
transparency = 0.6
default_colors_shade = ['rgba' + color[3:-1]  + ', ' + str(transparency) + ')' for color in default_colors]

default_dashes = ['solid', 'longdash', 'dot', 'dash', 'dashdot', 'longdashdot',
                 'solid', 'longdash', 'dot', 'dash', 'dashdot', 'longdashdot']


# plotly default layout
default_layout = Dict(
    font=Dict(
        size=10,
    ),
    title=Dict(font_size=10),
    
    xaxis=Dict(
        titlefont=Dict(size=10),
        tickfont=Dict(size=10),
        title_standoff=5,
        linecolor='rgba(0, 0, 0, .1)',
    ),
    
    yaxis=Dict(
        titlefont=Dict(size=10),
        tickfont=Dict(size=10),
        title_standoff=5,
        gridcolor='rgba(0, 0, 0, .1)',
        linecolor='rgba(0, 0, 0, .1)',
        #zerolinecolor='rgba(0, 0, 0, .1)'
    ),
    
    updatemenus=[],
    autosize=True,

    plot_bgcolor='rgba(0, 0, 0, 0)', 
    paper_bgcolor='rgba(0, 0, 0, 0)', 

    margin = Dict(
        l=20,
        r=20,
        b=20,
        t=20
        ),

    legend=Dict(
        xanchor='left',
        yanchor='top',
        y=1,
        x=1,
        font_size=10,
     ),     

    )

default_annotation_layout = Dict(
    font_size=10,
)

def make_html_fig(fig_idx, fig, width, height, title, size_unit="px", figtitle_fontsize="1em", title_fontsize="0.8em", full_html=False, include_plotlyjs=False, config={}):
    if isinstance(fig, go.Figure):
        # autosize fig
        fig.layout.autosize = True
        fig.layout.width = None
        fig.layout.height = None
        
        # convert to html
        default_config = {'displaylogo': False, 'modeBarButtonsToRemove': ['select', 'lasso2d', 'autoScale']}
        for k,v in config.items():
            default_config[k] = v
        html_fig = fig.to_html(config=default_config, full_html=full_html, include_plotlyjs=include_plotlyjs, default_width='100%', default_height='100%')
        html_fig = html_fig[:4]+ f' style="aspect-ratio: {str(Fraction(width,height))};"'+ html_fig[4:]# add aspect ratio
        
    elif isinstance(fig, str) and fig.split(".")[-1] in ["png","jpg","jpeg"]:
        html_fig = f'<div><img src="{fig}" alt="Figure {fig_idx}" style="aspect-ratio: {str(Fraction(width,height))}; width:100%;"></div>'

    # change div style and append title
    div_tag = f'<div id="figure-{fig_idx}" style="margin: 50px auto; max-width: {width}{size_unit};">'
    title_tag = f'<span style="font-size: {title_fontsize}; color: rgba(0, 0, 0, 0.6)">' + f'<b style="font-size: {figtitle_fontsize};">Figure {fig_idx}: </b>' + title + '</span>'
    html_fig = div_tag + html_fig +  title_tag + '</div>'
    
    return html_fig


def make_img_fig(fig_idx, fig, width, height, title, img_format="png", size_unit="px", title_fontsize="1em"):
    
    if isinstance(fig, go.Figure):
        # autosize fig
        fig.layout.autosize = True
        fig.layout.width = None
        fig.layout.height = None

        if size_unit != "px":
            raise NotImplementedError

        # convert to img
        img_fig = fig.to_image(width=width, height=height, format=img_format)
        
    elif isinstance(fig, str) and fig.split(".")[-1] in ["png","jpg","jpeg"]:
        img_fig = fig
        
    img_title = f'Figure {fig_idx}: ' + title
    
    return img_fig, img_title

if nb_renderer == "html":
    display(HTML('<script type="text/javascript" id="MathJax-script" async src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_SVG">'))
    display(HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/plotly.js/2.18.2/plotly.min.js" integrity="sha512-D52Rvz8mPwpAIIg9bRTFYiyy3GlBIE7kN8wGscKV+EgN8tJB7x7scvLlCBsK2KfYXQvPclyv1uY6E8+0HU+sfA==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>'))

In [5]:
#@title [set seed]
key = jrandom.PRNGKey(0)

### ModelStep function

When simulating *synthetic* gene regulatory network, we typically assume one family of ODE equations. Here we use the transcriptional gene circuit model with a simple model step defined as:
 
$\frac{d{y}_i}{dt}=\phi \left({\sum}_j{W}_{ij}{y}_j+{B}_i\right)-{k}_i{y}_i$

Here, we use $k_i=1, W_{ij}\in[-30,30], B_{i}\in[-10,10]$ and with these parameters species concentrations are constrained in $y\in[0,1]$

In [6]:
@jit
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

class SimpleModelStep(eqx.Module):
    def __init__(self, **kwargs):
        super().__init__()

    @jit
    def __call__(self, y, w, c, t, deltaT):

        n = len(y)
        W = c[:n * n].reshape((n, n))
        B = c[n * n:(n + 1) * n]
        y_new = y + deltaT * (sigmoid(W @ y + B ) - y)
        t_new = t + deltaT
        w_new = w

        return y_new, w_new, c, t_new

## Part 1: Curiosity-driven search for the discovery of diverse oscillator circuits

### Experiment Pipeline and Modules

#### System Rollout module 
Now that we have define the new ModelStep function, AutoDiscJax allows us to simulate system rollout (and applying different kind of interventions on it) in the same manner that we did for biological networks in the first tutorial.

Let's instantitate the system rollout module.

In [7]:
if nb_mode == "run":
    
    n = 3 #number of nodes
    deltaT = 0.01
    n_secs = 100
    n_steps = int(n_secs/deltaT)

    c = jnp.empty(((n + 1) * n, ))
    c_low = jnp.array([-30.]*n**2 + [-10.]*n)
    c_high = jnp.array([30.]*n**2 + [10.]*n)
    grn_step=SimpleModelStep()

    y0=jnp.empty(shape=(n,))
    y0_low = 0.
    y0_high = 1.

    w0 = jnp.array([])

    system_rollout = grn.GRNRollout(n_steps=n_steps, y0=y0, w0=w0, c=c, t0=0.0, deltaT=deltaT, grn_step=grn_step)

#### Random Intervention Generator
Let's now use intervention to (randomly) set the GRN's init state (y0) and kinematic parameters (c)

In [8]:
if nb_mode == "run":
    
    # Create an intervention generator and an intervention_fn modules to set the initial state and the kinematic parameters to random values
    random_intervention_generator_config = Dict()
    random_intervention_generator_config.intervention_type = "set_uniform"
    random_intervention_generator_config.controlled_intervals = [[0, deltaT/2.0]]

    intervention_params_tree = DictTree()
    intervention_params_low = DictTree()
    intervention_params_high = DictTree()
    for y_idx in range(len(y0)):
        intervention_params_tree.y[y_idx] = "placeholder"
        intervention_params_low.y[y_idx] = y0_low
        intervention_params_high.y[y_idx] = y0_high
    for c_idx in range(len(c)):
        intervention_params_tree.c[c_idx] = "placeholder"
        intervention_params_low.c[c_idx] = c_low[c_idx]
        intervention_params_high.c[c_idx] = c_high[c_idx]

    random_intervention_generator_config.out_treedef = jtu.tree_structure(intervention_params_tree)
    random_intervention_generator_config.out_shape = jtu.tree_map(lambda _: (len(random_intervention_generator_config.controlled_intervals),), intervention_params_tree)
    random_intervention_generator_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, intervention_params_tree)
    random_intervention_generator_config.low = intervention_params_low
    random_intervention_generator_config.high = intervention_params_high

    random_intervention_generator, intervention_fn = create_intervention_module(random_intervention_generator_config)

In [9]:
if nb_mode == "run":
    
    # example: generate a random set of intervention parameters between low and high
    key, subkey = jrandom.split(key)
    intervention_params, log_data = random_intervention_generator(subkey)
    
    # Run the system with the sample intervention
    key, subkey = jrandom.split(key)
    random_system_outputs, log_data = system_rollout(subkey, intervention_fn=intervention_fn, intervention_params=intervention_params)

In [10]:
#@title [Figure 1]
fig_idx = 1

if nb_mode == "run":
    
    fig = go.Figure(layout=default_layout)
    
    for y_idx in range(n):
        fig.add_trace(go.Scatter(x=random_system_outputs.ts, y=random_system_outputs.ys[y_idx], name=f"y{y_idx}",
                                 line=dict(color=default_colors[y_idx])))                          

    fig.update_xaxes(title_text="reaction time [sec]")
    fig.update_yaxes(title_text="gene expression level")
    
    # Serialize fig to json and save
    if nb_save_outputs:
        fig.write_json(f"figures/tuto2_fig_{fig_idx}.json")
        
            
elif nb_mode == "load":
    fig = plotly.io.read_json(f"figures/tuto2_fig_{fig_idx}.json")
    

# Display Fig
width, height = 600, 400
t = f"Simulation results of the mathematical modeling for random kinetic parameters (c) and initial gene expression levels (y0). "

if nb_renderer == "html":
    html_fig = make_html_fig(fig_idx, fig, width, height, t)
    display(HTML(html_fig))

elif nb_renderer == "img":
    img_fig, img_title = make_img_fig(fig_idx, fig, width, height, t)
    display(Image(img_fig))
    print(img_title)

#### Goal Embedding Encoder 
For the IMGEP goal space, we use the image space of the discrete fourier transform of the 1d-signal $y[n=0]$.

In [11]:
#@title [Figure 2]
fig_idx = 2

if nb_mode == "run":
    
    fig = make_subplots(rows=1, cols=2, subplot_titles=["Examples wave signal", "Fourier Descriptors"])
    fig.update_layout(default_layout)
    fig.update_annotations(default_annotation_layout)
    
    # Arbitrary functions made of sin and cos waves
    f = lambda t: 1+jnp.sin(5*2*jnp.pi*t)+jnp.cos(10*2*jnp.pi*t)
    ts = jnp.linspace(0,1,n_steps//2)
    y = vmap(f)(ts)
    y_descriptors = jnp.fft.rfft(y)
    reconstructed_y = jnp.fft.irfft(y_descriptors)

    fig.add_trace(go.Scatter(x=ts, y=y, name="y = 1 + sin(5 * 2&#960;t) + cos(10 * 2&#960;t)", showlegend=True,
                             line=dict(color=default_colors[2])),
                 row=1, col=1)
    fig.update_xaxes(title_text="t", row=1, col=1)
    fig.update_yaxes(title_text="y", row=1, col=1)
    
    fig.add_trace(go.Scatter(y=jnp.abs(y_descriptors), name="y fourier descriptors", showlegend=True,
                             line=dict(color=default_colors[3])),
                 row=1, col=2)
    fig.update_xaxes(title_text="frequency F", row=1, col=2)
    fig.update_yaxes(title_text="magnitude", row=1, col=2)
    
    fig.add_annotation(x=0, y=jnp.abs(y_descriptors)[0], text="F=0", ax=20, ay=30, row=1, col=2)
    fig.add_annotation(x=5, y=jnp.abs(y_descriptors)[5], text="F=5", ax=20, ay=-30, row=1, col=2)
    fig.add_annotation(x=10, y=jnp.abs(y_descriptors)[10], text="F=10", ax=60, ay=-30, row=1, col=2)
                          
    
    # Serialize fig to json and save
    if nb_save_outputs:
        fig.write_json(f"figures/tuto2_fig_{fig_idx}.json")
        
            
elif nb_mode == "load":
    fig = plotly.io.read_json(f"figures/tuto2_fig_{fig_idx}.json")
    

# Display Fig
width, height = 940, 350
t = f"Illustration of Fourier descriptors, which are here used as goal representation by the IMGEP."

if nb_renderer == "html":
    html_fig = make_html_fig(fig_idx, fig, width, height, t)
    display(HTML(html_fig))

elif nb_renderer == "img":
    img_fig, img_title = make_img_fig(fig_idx, fig, width, height, t)
    display(Image(img_fig))
    print(img_title)

In [12]:
if nb_mode == "run":
    observed_node_ids = [0]

    goal_embedding_encoder_config = Dict()
    goal_embedding_encoder_config.encoder_type = "filter"
    goal_embedding_tree = "placeholder"
    goal_embedding_encoder_config.out_treedef = jtu.tree_structure(goal_embedding_tree)
    goal_embedding_encoder_config.out_shape = jtu.tree_map(lambda _: (len(observed_node_ids)*(system_rollout.n_steps//2//2+1), ), goal_embedding_tree)
    goal_embedding_encoder_config.out_dtype = jtu.tree_map(lambda _: jnp.float32, goal_embedding_tree)
    goal_embedding_encoder_config.filter_fn = jtu.Partial(lambda system_outputs: jnp.fft.rfft(system_outputs.ys[observed_node_ids, -system_rollout.n_steps//2:]).flatten())


    goal_embedding_encoder = create_goal_embedding_encoder_module(goal_embedding_encoder_config)

In [13]:
if nb_mode == "run":
    
    # example: encode system outputs
    key, subkey = jrandom.split(key)
    reached_goal_embedding, log_data = goal_embedding_encoder(subkey, random_system_outputs)
    print(reached_goal_embedding.shape)

#### Goal-conditioned Achievement Loss
Distance in the goal space measures average difference in spectral amplitude.

In [14]:
if nb_mode == "run":
    
    goal_achievement_loss_config = Dict()
    goal_achievement_loss_config.loss_type = "custom" 
    goal_achievement_loss_config.loss_f = jtu.Partial(lambda reached_goal, target_goal: abs(reached_goal - target_goal).mean())
    goal_achievement_loss = create_goal_achievement_loss_module(goal_achievement_loss_config)

In [15]:
if nb_mode == "run":
    
    # example
    target_goal_embedding = y_descriptors
    key, subkey = jrandom.split(key)
    gc_loss, log_data = goal_achievement_loss(subkey, reached_goal_embedding, target_goal_embedding)
    print(gc_loss)

#### Goal Generator
For the goal generator, goal-conditionned intervention selector and optimizer we re-use the same simple variants that the one used in the first tutorial.

In [16]:
if nb_mode == "run":
    
    goal_generator_config = DictTree()
    goal_generator_config.out_treedef = goal_embedding_encoder_config.out_treedef
    goal_generator_config.out_shape = goal_embedding_encoder_config.out_shape
    goal_generator_config.out_dtype = goal_embedding_encoder_config.out_dtype
    goal_generator_config.low = None
    goal_generator_config.high = None
    goal_generator_config.generator_type = "hypercube"
    goal_generator_config.hypercube_scaling = 1.3

    goal_generator = create_goal_generator_module(goal_generator_config)

In [17]:
if nb_mode == "run":
    
    # example
    key, subkey = jrandom.split(key)
    next_target_goal_embedding, log_data = goal_generator(subkey, target_goal_embedding[jnp.newaxis], jnp.stack([reached_goal_embedding,target_goal_embedding]))
    print(next_target_goal_embedding.shape)

#### Goal-conditioned Intervention Selector

In [18]:
if nb_mode == "run":
    
    gc_intervention_selector_config = Dict()
    gc_intervention_selector_config.selector_type = "nearest_neighbor"
    gc_intervention_selector_config.loss_f = goal_achievement_loss.loss_f
    gc_intervention_selector_config.k = 1

    gc_intervention_selector = create_gc_intervention_selector_module(gc_intervention_selector_config)

In [19]:
if nb_mode == "run":
    
    # example
    key, subkey = jrandom.split(key)
    source_interventions_idx, log_data = gc_intervention_selector(subkey, next_target_goal_embedding, jnp.stack([reached_goal_embedding, target_goal_embedding]))
    print(source_interventions_idx)

#### Goal-conditioned Intevention Optimizer

In [20]:
if nb_mode == "run":
    
    gc_intervention_optimizer_config = Dict()
    gc_intervention_optimizer_config.out_treedef = random_intervention_generator.out_treedef
    gc_intervention_optimizer_config.out_shape = random_intervention_generator.out_shape
    gc_intervention_optimizer_config.out_dtype = random_intervention_generator.out_dtype
    gc_intervention_optimizer_config.low = random_intervention_generator.low
    gc_intervention_optimizer_config.high = random_intervention_generator.high
    gc_intervention_optimizer_config.optimizer_type = "EA"
    gc_intervention_optimizer_config.n_optim_steps = 1
    gc_intervention_optimizer_config.n_workers = 1
    gc_intervention_optimizer_config.init_noise_std = jtu.tree_map(lambda low, high: 0.1 * (high - low), 
                                                                   gc_intervention_optimizer_config.low, gc_intervention_optimizer_config.high)

    gc_intervention_optimizer = create_gc_intervention_optimizer_module(gc_intervention_optimizer_config)
    null_perturbation_generator, null_perturbation_fn = create_perturbation_module(Dict(perturbation_type="null"))
    null_rollout_statistics_encoder = create_rollout_statistics_encoder_module(Dict(statistics_type="null"))
    partial_gc_intervention_optimizer = jtu.Partial(gc_intervention_optimizer,
                                            perturbation_generator=null_perturbation_generator, perturbation_fn=null_perturbation_fn,
                                            intervention_fn=intervention_fn, system_rollout=system_rollout,
                                            goal_embedding_encoder=goal_embedding_encoder, goal_achievement_loss=goal_achievement_loss,
                                            rollout_statistics_encoder=null_rollout_statistics_encoder
                                            )

In [21]:
if nb_mode == "run":
    
    # example
    key, subkey = jrandom.split(key)
    optimized_intervention_params, log_data = partial_gc_intervention_optimizer(subkey, intervention_params, next_target_goal_embedding, reached_goal_embedding)
    print(jtu.tree_map(lambda node: node.shape, optimized_intervention_params))

#### Run Experiment Pipeline
Now that we have defined the IMGEP internal models, we can run the IMGEP experimental pipeline. As in the previous tutorial, we compare it with a random exploration strategy given the same experimental budget of experiments.
Here we define a total of N=5000 experiments, with a batch size of 100.

In [22]:
if nb_mode == "run":
    
    jax_platform_name = "cpu"
    seed = 0

    # Run IMGEP
    n_random_batches = 10 
    n_imgep_batches = 40
    batch_size = 100
    imgep_experiment_data_save_folder = "data/periodic_imgep_data"
    if not os.path.exists(os.path.join(imgep_experiment_data_save_folder, "history.pickle")):
        run_imgep_experiment(jax_platform_name, seed, n_random_batches, n_imgep_batches, batch_size,
                             imgep_experiment_data_save_folder,
                             random_intervention_generator, intervention_fn,
                             null_perturbation_generator, null_perturbation_fn,
                             system_rollout, null_rollout_statistics_encoder,
                             goal_generator, gc_intervention_selector, gc_intervention_optimizer,
                             goal_embedding_encoder, goal_achievement_loss,
                             out_sanity_check=False, save_modules=False, save_logs=False)

    # Run Random Search
    rs_experiment_data_save_folder = "data/periodic_rs_data"
    if not os.path.exists(os.path.join(rs_experiment_data_save_folder, "history.pickle")):
        run_rs_experiment(jax_platform_name, seed, n_random_batches+n_imgep_batches, batch_size, 
                          rs_experiment_data_save_folder,
                          random_intervention_generator, intervention_fn,
                          null_perturbation_generator, null_perturbation_fn,
                          system_rollout, null_rollout_statistics_encoder,
                          out_sanity_check=False, save_modules=False, save_logs=False)

### Analysis of the discoveries

In [23]:
if nb_mode == "run":
    
    imgep_experiment_history = DictTree.load(os.path.join(imgep_experiment_data_save_folder, "history.pickle"))
    imgep_reached_goals_embeddings = imgep_experiment_history.reached_goal_embedding_library
    print(imgep_reached_goals_embeddings.shape)
    
    rs_experiment_history = DictTree.load(os.path.join(rs_experiment_data_save_folder, "history.pickle")) 
    key, *subkeys = jrandom.split(key, num=len(imgep_reached_goals_embeddings)+ 1)
    rs_reached_goals_embeddings, _ = vmap(goal_embedding_encoder)(jnp.array(subkeys), rs_experiment_history.system_output_library)
    print(rs_reached_goals_embeddings.shape)

#### Number of discovered oscillators

In [24]:
if nb_mode == "run":
    
    rs_is_periodic_bool, rs_offset_vals, rs_ampl_vals, rs_freq_vals = is_periodic(rs_experiment_history.system_output_library.ys[:,0,:], jnp.r_[-system_rollout.n_steps//2:0], system_rollout.deltaT, 40)
    rs_is_periodic_ids = jnp.where(rs_is_periodic_bool)[0]
    print(f"Random Search has discovered {len(rs_is_periodic_ids)} oscillator circuits out of N={len(rs_is_periodic_bool)} trials")

    imgep_is_periodic_bool, imgep_offset_vals, imgep_ampl_vals, imgep_freq_vals = is_periodic(imgep_experiment_history.system_output_library.ys[:,0,:], jnp.r_[-system_rollout.n_steps//2:0], system_rollout.deltaT, 40)
    imgep_is_periodic_ids = jnp.where(imgep_is_periodic_bool)[0]
    print(f"Curiosity Search has discovered {len(imgep_is_periodic_ids)} oscillator circuits out of N={len(imgep_is_periodic_bool)} trials")

#### Diversity of discovered oscillators
Here, the analytic BC space is the space of (amplitude $A$, main frequency $\omega$, offset $b$) of the discovered oscillators, where $(A,\omega,b)$ are estimated by Autodiscjax `is_periodic` util function.
Diversity is measured with the QD-score, a binning-based metric where the BC space is discretized into a collection of bins and the diversity is quantified as the number of bins filled over the course of exploration.
We opt for a regular binning where each dimension of the BC space is discretized into equally sized bins, using 20 bins per dimension.
We do not use the threshold-coverage metric as in tutorial 1 as it is difficult to compute in n-dimensional spaces where $n\ge3$.

In [25]:
if nb_mode == "run":
    
    imgep_reached_goals_embeddings = jnp.stack([imgep_offset_vals.at[~imgep_is_periodic_bool].set(0.0), 
                                                imgep_ampl_vals.at[~imgep_is_periodic_bool].set(0.0), 
                                                imgep_freq_vals.at[~imgep_is_periodic_bool].set(0.0)], -1)
    rs_reached_goals_embeddings = jnp.stack([rs_offset_vals.at[~rs_is_periodic_bool].set(0.0), 
                                             rs_ampl_vals.at[~rs_is_periodic_bool].set(0.0), 
                                             rs_freq_vals.at[~rs_is_periodic_bool].set(0.0)], -1)

    analytic_bc_space_low = jnp.minimum(jnp.nanmin(imgep_reached_goals_embeddings, 0), jnp.nanmin(rs_reached_goals_embeddings, 0))
    analytic_bc_space_high = jnp.maximum(jnp.nanmax(imgep_reached_goals_embeddings, 0), jnp.nanmax(rs_reached_goals_embeddings, 0))
    analytic_bc_space_extent = jnp.stack([analytic_bc_space_low, analytic_bc_space_high]).transpose()

    def calc_analytic_bc_coverage_histograms(reached_goals_embeddings, analytic_bc_space_extent, n_bins=20, every_n_steps=1):
        def f(carry, goal_embedding):
            Hf = carry
            cur_Hf, _ = jnp.histogramdd(goal_embedding[jnp.newaxis], bins=n_bins,range=analytic_bc_space_extent)
            Hf = Hf + cur_Hf.transpose()
            return Hf, Hf

        final_coverage_histogram, coverage_histograms = lax.scan(f, jnp.zeros((n_bins, n_bins, n_bins), dtype=jnp.int32), reached_goals_embeddings[::every_n_steps])
        return coverage_histograms


    imgep_coverage_histograms = calc_analytic_bc_coverage_histograms(imgep_reached_goals_embeddings, analytic_bc_space_extent, n_bins=20)
    rs_coverage_histograms = calc_analytic_bc_coverage_histograms(rs_reached_goals_embeddings, analytic_bc_space_extent, n_bins=20)

In [26]:
#@title [Figure 3]
fig_idx = 3

if nb_mode == "run":
    
    row_titles=["RANDOM SEARCH", "CURIOSITY SEARCH"]
    fig = make_subplots(rows=3, cols=6, specs=[[{'type': 'surface', 'colspan': 2}, None, {'colspan': 2}, None, {'colspan': 2}, None],
                                               [{}, {}, {}, {}, {}, {}], 
                                               [{}, {}, {}, {}, {}, {}]], 
                        row_heights=[2.5,1,1], row_titles=row_titles,
                        subplot_titles=["<b>(a) discoveries in 3D analytic space</b><br> ", "<b>(b) discoveries distribution (in analytic space)</b><br> ", "<b>(c) diversity throughout exploration</b><br> "] + 
                        ["<b>(d) offset min</b><br> ", "<b>(e) offset max</b><br> ", "<b>(f) ampl min</b><br> ", "<b>(g) ampl max</b><br> ", "<b>(h) freq min</b><br> ", "<b>(i) freq max</b><br> "] + [""]*6,
                        horizontal_spacing=0.05,
                        vertical_spacing=0.15)
    fig.update_layout(default_layout, margin_l=40, margin_t=40)
    fig.update_annotations(default_annotation_layout)
    fig.for_each_annotation(lambda a:  a.update(x = -0.05, textangle=-90) if a.text in row_titles else())
    
    
    # Diversity curves
    fig.add_trace(go.Scatter(x=jnp.arange(len(rs_coverage_histograms)), y=(rs_coverage_histograms>0).sum(-1).sum(-1).sum(-1),
                             name="random search", legendgroup="random search", showlegend=False, line=dict(color=default_colors[0])),
                  row=1, col=5)
    fig.add_trace(go.Scatter(x=jnp.arange(len(imgep_coverage_histograms)), y=(imgep_coverage_histograms>0).sum(-1).sum(-1).sum(-1),
                             name="imgep", legendgroup="imgep", showlegend=False, line=dict(color=default_colors[1])),
                  row=1, col=5)
    fig.add_vline(x=5000, line_width=2, line_dash="dash", annotation_text="n=5000", annotation_position="top left", row=1, col=5)  
    fig.update_xaxes(default_layout.xaxis, title_text="n (# exploration runs)", title_standoff=3, side="bottom", row=1, col=5)
    fig.update_yaxes(default_layout.yaxis, title_text="Diversity", title_standoff=1, row=1, col=5)
    
    
    # Box Plots
    rs_data = np.concatenate([rs_offset_vals[rs_is_periodic_ids], rs_ampl_vals[rs_is_periodic_ids], rs_freq_vals[rs_is_periodic_ids]])
    fig.add_trace(go.Box(x=np.array(["offset"]*len(rs_is_periodic_ids) + ["amplitude"]*len(rs_is_periodic_ids) + ["main<br>frequency"]*len(rs_is_periodic_ids)), 
                         y=rs_data, name=f"random search<br>(<b>{len(rs_is_periodic_ids)}</b>/{len(rs_is_periodic_bool)} oscillators)", legendgroup="random search", 
                         line=dict(color=default_colors[0]), fillcolor=default_colors_shade[0],
                         boxpoints='all', pointpos=0., marker_size=3), row=1, col=3)

    imgep_data = np.concatenate([imgep_offset_vals[imgep_is_periodic_ids], imgep_ampl_vals[imgep_is_periodic_ids], imgep_freq_vals[imgep_is_periodic_ids]])
    fig.add_trace(go.Box(x=np.array(["offset"]*len(imgep_is_periodic_ids) + ["amplitude"]*len(imgep_is_periodic_ids) + ["main<br>frequency"]*len(imgep_is_periodic_ids)), 
                         y=imgep_data, name=f"curiosity search<br>(<b>{len(imgep_is_periodic_ids)}</b>/{len(imgep_is_periodic_bool)} oscillators)", legendgroup="imgep", 
                         line=dict(color=default_colors[1]), fillcolor=default_colors_shade[1],
                         boxpoints='all', pointpos=0., marker_size=3), row=1, col=3)
    
    fig.update_layout(boxmode='group')

    
    # 3D scatter plot
    fig.add_trace(go.Scatter3d(x=imgep_offset_vals[imgep_is_periodic_ids], 
                               y=imgep_ampl_vals[imgep_is_periodic_ids], 
                               z=imgep_freq_vals[imgep_is_periodic_ids], 
                               name="curiosity search", legendgroup="imgep", 
                               showlegend = False, mode="markers",
                               marker=dict(color=default_colors[1], size=3),
                        ), row=1, col=1)
    fig.add_trace(go.Scatter3d(x=rs_offset_vals[rs_is_periodic_ids], 
                           y=rs_ampl_vals[rs_is_periodic_ids], 
                           z=rs_freq_vals[rs_is_periodic_ids],
                           name="random search", legendgroup="random search", 
                           showlegend = False, mode="markers",
                           marker=dict(color=default_colors[0], size=3),
                    ), row=1, col=1)
    
    ## draw x,y and z axes
    fig.add_trace(go.Scatter3d(x=[0,1], y=[0,0], z=[0,0], showlegend=False, mode="lines", hoverinfo="skip", line_color="black"), row=1, col=1)
    fig.add_trace(go.Cone(x=[1], y=[0], z=[0], u=[0.2], v=[0], w=[0], hoverinfo="skip", sizemode="absolute", colorscale=[[0, "black"], [1, "black"]], showscale=False), row=1, col=1)
    fig.add_trace(go.Scatter3d(x=[0,0], y=[0,1], z=[0,0], showlegend=False, mode="lines", hoverinfo="skip", line_color="black"), row=1, col=1)
    fig.add_trace(go.Cone(x=[0], y=[1], z=[0], u=[0], v=[0.2], w=[0], hoverinfo="skip", sizemode="absolute", colorscale=[[0, "black"], [1, "black"]], showscale=False), row=1, col=1)
    fig.add_trace(go.Scatter3d(x=[0,0], y=[0,0], z=[0,1], showlegend=False, mode="lines", hoverinfo="skip", line_color="black"), row=1, col=1)
    fig.add_trace(go.Cone(x=[0], y=[0], z=[1], u=[0], v=[0], w=[0.2], hoverinfo="skip", sizemode="absolute", colorscale=[[0, "black"], [1, "black"]], showscale=False), row=1, col=1)
    
    fig.update_scenes(xaxis=dict(title="", autorange="reversed"), 
                      yaxis=dict(title=""),
                      zaxis=dict(title=""),
                      annotations=[dict(x=1.1, y=-0.1, z=-0.1, showarrow=False, text="offset"),
                                  dict(x=-0.1, y=1.1, z=-0.1, showarrow=False, text="amplitude"),
                                  dict(x=-0.1, y=-0.1, z=1.1, showarrow=False, text="main<br>frequency")],
                      camera=dict(eye=dict(x=1.5,y=-1.,z=0.2),),
                      row=1, col=1)
    
    
    # Individual examples
    for row_idx in range(2):
        row_legendgroup = ["random search", "imgep"][row_idx]
        row_data = [rs_data, imgep_data][row_idx]
        sub_data_size = len(row_data)//3
        oscillator_ids = [rs_is_periodic_ids, imgep_is_periodic_ids][row_idx]
        system_outputs_library = [rs_experiment_history.system_output_library, imgep_experiment_history.system_output_library][row_idx]
        for col_idx in range(6):
            col_data = row_data[(col_idx//2)*sub_data_size:((col_idx//2)+1)*sub_data_size]
            sample_idx = col_data.argmin() if col_idx % 2 ==0 else col_data.argmax()
            fig.add_trace(go.Scatter(x=system_outputs_library.ts[oscillator_ids[sample_idx]],
                                     y=system_outputs_library.ys[oscillator_ids[sample_idx], observed_node_ids[0]], 
                                     legendgroup=row_legendgroup, showlegend=False, mode="lines", line_color=default_colors[row_idx]), 
                          row=row_idx+2, col=col_idx+1)
            fig.update_yaxes(range=[0,1], row=row_idx+2, col=col_idx+1)

    
    
    # Serialize fig to json and save
    if nb_save_outputs:
        fig.write_json(f"figures/tuto2_fig_{fig_idx}.json")
        
            
elif nb_mode == "load":
    fig = plotly.io.read_json(f"figures/tuto2_fig_{fig_idx}.json")
    

# Display Fig
width, height = 940, 700
t = f"Characteristics of the discovered oscillators."

if nb_renderer == "html":
    html_fig = make_html_fig(fig_idx, fig, width, height, t)
    display(HTML(html_fig))

elif nb_renderer == "img":
    img_fig, img_title = make_img_fig(fig_idx, fig, width, height, t)
    display(Image(img_fig))
    print(img_title)

&#x1F449; We can see that, one again, curiosity search is much more efficient than random search in revealing a diversity of possible oscillator behaviors. Given the same experimental budget of 5000 model rollouts, random search was able to find only 42 configurations leading to periodic patterns whereas curiosity search was able to find 1167. Projecting the discoveries into the space of (amplitude, frequency, offset), we can see that curiosity search efficiently reveals and covers the reachable space (a-c), reaching hard-to-discover behaviors on the borders of the space (d-i).

## Part 2: Curiosity search as an alternative to pure optimization-driven search strategy?

In [Tom W. Hiscock's paper](https://link.springer.com/article/10.1186/s12859-019-2788-3), it is showcased how the use of gradient descent-based optimization can be helpful to design (synthetic) gene circuits with desired functionalities, and the example of optimizing the transcriptional gene circuit parameters to generate oscillations with desired (amplitude $A$, main frequency $\omega$) is considered.
In the paper, the loss function is defined as $C = \sum_t(y_i(t) - (A \cos(2\pi\omega t)+b))^2$ where $y_i$ is the observed node (here i=0).
Adam optimizer is then used with parameters $lr = 0.1, b1 = 0.02, b2 = 0.001$. Here we use the same parameters except for the learning rate that is chosen as $lr=1e-3$ (0.1 too big here).
Note that $b$ is not optimized in the original paper and considered fixed as $b=0$ but this leads to biologically not-admissible target with negative gene expression levels.
Here, we consider targets respecting the plausible gene expression levels $0 \le y \le 1$ in the gene circuit model. 
We define $A \in [0.1,05],b \in [A,1-A], w \in [0,1]$.

### Loss definition

In [27]:
def loss_pattern(ys, A, b, w):
    target_ys = A*jnp.cos(2*jnp.pi*w*random_system_outputs.ts)+b
    loss = jnp.sqrt(jnp.square(ys-target_ys).sum())
    return loss

### Optimization pipelines (global and local)

In [28]:
# Model Rollout
class ModelRollout(eqx.Module):
    deltaT: float
    y0: Array
    c: Array
    grn_step: SimpleModelStep
    
    def __init__(self, deltaT, y0, c, grn_step):
        super().__init__()
        self.deltaT = deltaT
        self.y0 = jnp.maximum(y0, 0.)
        self.c = c
        self.grn_step = grn_step
    
    @partial(jit, static_argnames=("n_steps",))
    def __call__(self, n_steps):
        def f(carry, x):
            y, w, c, t = carry
            return self.grn_step(y, w, c, t, self.deltaT), (y, w, t)
        (y, w, c, t), (ys, ws, ts) = lax.scan(f, (self.y0, jnp.array([]), self.c, 0.0), jnp.arange(n_steps))
        ys = jnp.moveaxis(ys, 0, -1)
        ws = jnp.moveaxis(ws, 0, -1)
        return ys, ws, ts
    

# Optax optimizer, loss function and update function
optim = optax.adam(1e-3, b1=0.02, b2=0.001) # Same optimizer params than in Hiscock et al.

@jit
def loss_fn(params, A, b, w): 
    """loss function"""
    y0, c = params
    model = ModelRollout(deltaT, y0, c, SimpleModelStep())
    ys, ws, ts = model(n_steps)
    loss = loss_pattern(ys[0], A, b, w)
    return loss

@jit
def make_step(params, A, b, w, opt_state):
    """update function"""
    loss, grads = value_and_grad(loss_fn)(params, A, b, w)
    updates, opt_state = optim.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return loss, params, opt_state

We consider two setups:
1. Giving gradient-descent a budget of N=5000 optimization steps (same number of model rollouts allowed than for curiosity search and random search) and starting from random init $y_0 \in [0,1]^{n}, W \in [-30,30]^{n \times n}, B \in[-10,10]^{n}$
2. Giving gradient-descent a budget of N=100 optimization steps and starting from the best discoveries made by the curiosity search and random search exploration strategies (small budget this time, for local refinement). 

In [29]:
if nb_mode == "run":
    # Generate RANDOM Target
    A = jrandom.uniform(subkey, minval=0.1, maxval=0.5)
    key, subkey = jrandom.split(key)
    w = jrandom.beta(subkey, a=2, b=8)
    key, subkey = jrandom.split(key)
    b = jrandom.uniform(subkey, minval=A, maxval=1-A)
        
    # Optax pipeline from RANDOM Init
    key, subkey = jrandom.split(key)
    y0_sgd = jrandom.uniform(subkey, shape=(n, ), minval=y0_low, maxval=y0_high)
    key, subkey = jrandom.split(key)
    c_sgd = jrandom.uniform(subkey, shape=(n**2+n, ), minval=c_low, maxval=c_high)
    
    model = ModelRollout(deltaT, y0_sgd, c_sgd, SimpleModelStep())
    ys_random, _, _ = model(n_steps)
    
    n_optim_steps = 5000
    opt_state = optim.init((y0_sgd, c_sgd))
    loss_sgd = []
    n_sgd_oscillators = 0
    for optim_step_idx in range(n_optim_steps):
        loss, (y0_sgd, c_sgd), opt_state = make_step((y0_sgd, c_sgd), A, b, w, opt_state)
        loss_sgd.append(loss)
        
        # check whether gradient descent passes through some oscillator behaviors
        model = ModelRollout(deltaT, y0_sgd, c_sgd, SimpleModelStep())
        ys_sgd, _, _ = model(n_steps)
        is_periodic_bool, _, _, _ = is_periodic(ys_sgd[0,:], jnp.r_[-system_rollout.n_steps//2:0], deltaT, 40)
        n_sgd_oscillators += int(is_periodic_bool)
        
    print(f"Gradient-descent optimization has discovered {n_sgd_oscillators} oscillator circuits out of N={n_optim_steps} trials")
    
    # Optax pipeline from closest init in IMGEP discoveries
    loss_imgep = vmap(loss_pattern, in_axes=(0,None,None,None))(imgep_experiment_history.system_output_library.ys[:,0], A, b, w)
    imgep_best_idx = imgep_is_periodic_ids[loss_imgep[imgep_is_periodic_ids].argmin()]
    y0_imgep = jnp.array([imgep_experiment_history.intervention_params_library.y[node_idx][imgep_best_idx, 0] for node_idx in range(n)])
    c_imgep = jnp.array([imgep_experiment_history.intervention_params_library.c[param_idx][imgep_best_idx, 0] for param_idx in range(n**2+n)])
    
    model = ModelRollout(deltaT, y0_imgep, c_imgep, SimpleModelStep())
    ys_imgep, _, _ = model(n_steps)
    
    n_optim_steps = 100
    opt_state = optim.init((y0_imgep, c_imgep))
    loss_imgep_sgd = []
    for optim_step_idx in range(n_optim_steps):
        loss, (y0_imgep, c_imgep), opt_state = make_step((y0_imgep, c_imgep), A, b, w, opt_state)
        loss_imgep_sgd.append(loss)
        
    model = ModelRollout(deltaT, y0_imgep, c_imgep, SimpleModelStep())
    ys_imgep_sgd, _, _ = model(n_steps)
    
    ## arrange loss prior optim for plotting
    cur_min = loss_imgep[0]
    for i, cur_loss in enumerate(loss_imgep):
        if cur_loss > cur_min:
            loss_imgep = loss_imgep.at[i].set(cur_min)
        else:
            cur_min = cur_loss
            
    # Optax pipeline from closest init in RS discoveries
    loss_rs = vmap(loss_pattern, in_axes=(0,None,None,None))(rs_experiment_history.system_output_library.ys[:,0], A, b, w)
    rs_best_idx = rs_is_periodic_ids[loss_rs[rs_is_periodic_ids].argmin()]
    y0_rs = jnp.array([rs_experiment_history.intervention_params_library.y[node_idx][rs_best_idx, 0] for node_idx in range(n)])
    c_rs = jnp.array([rs_experiment_history.intervention_params_library.c[param_idx][rs_best_idx, 0] for param_idx in range(n**2+n)])
    
    model = ModelRollout(deltaT, y0_rs, c_rs, SimpleModelStep())
    ys_rs, _, _ = model(n_steps)
    
    n_optim_steps = 100
    opt_state = optim.init((y0_rs, c_rs))
    loss_rs_sgd = []
    for optim_step_idx in range(n_optim_steps):
        loss, (y0_rs, c_rs), opt_state = make_step((y0_rs, c_rs), A, b, w, opt_state)
        loss_rs_sgd.append(loss)
        
    model = ModelRollout(deltaT, y0_rs, c_rs, SimpleModelStep())
    ys_rs_sgd, _, _ = model(n_steps)
    
    ## arrange loss prior optim for plotting
    cur_min = loss_rs[0]
    for i, cur_loss in enumerate(loss_rs):
        if cur_loss > cur_min:
            loss_rs = loss_rs.at[i].set(cur_min)
        else:
            cur_min = cur_loss

### Analysis of the discoveries

In [30]:
#@title [Figure 4]
fig_idx = 4

if nb_mode == "run":
    
    fig = make_subplots(rows=3, cols=12, specs=[[{'type': 'surface', 'colspan': 4}, None, None, None, {'colspan': 4}, None, None, None, {'colspan': 4}, None, None, None],
                                               [{'colspan': 3}, None, None, {'colspan': 3}, None, None, {'colspan': 3}, None, None, {'colspan': 3}, None, None], 
                                               [{'colspan': 3}, None, None, {'colspan': 3}, None, None, {'colspan': 3}, None, None, {'colspan': 3}, None, None]], 
                        row_heights=[2.5,1,1], horizontal_spacing=0.05, vertical_spacing=0.15,
                        subplot_titles=["<b>(a) discoveries in 3D analytic space</b><br> ", "<b>(b) discoveries distribution (in analytic space)</b><br> ", "<b>(c) diversity throughout exploration</b><br> ",
                                        "<b>(d) training losses</b><br> ", "<b>(e) IMGEP best discovery</b><br> ", "<b>(f) RS best discovery</b><br> ", "<b>(g) SGD best discovery</b><br> ",
                                        "<b>(h) IMGEP best + SGD local refinement</b><br> ", "", "<b>(i) RS best + SGD local refinement</b><br> ", ""])
    fig.update_layout(default_layout, margin_l=40, margin_t=40)
    fig.update_annotations(default_annotation_layout)    
    
    # (c) Diversity curves
    fig.add_trace(go.Scatter(x=jnp.arange(len(rs_coverage_histograms)), y=(rs_coverage_histograms>0).sum(-1).sum(-1).sum(-1),
                             legendgroup="RS", showlegend=False, line=dict(color=default_colors[0])),
                  row=1, col=9)
    fig.add_trace(go.Scatter(x=jnp.arange(len(imgep_coverage_histograms)), y=(imgep_coverage_histograms>0).sum(-1).sum(-1).sum(-1),
                             legendgroup="IMGEP", showlegend=False, line=dict(color=default_colors[1])),
                  row=1, col=9)
    fig.add_vline(x=5000, line_width=2, line_dash="dash", annotation_text="n=5000", annotation_position="top left", row=1, col=9)  
    fig.update_xaxes(default_layout.xaxis, title_text="n (# exploration runs)", title_standoff=3, side="bottom", row=1, col=9)
    fig.update_yaxes(default_layout.yaxis, title_text="Diversity", title_standoff=1, row=1, col=9)
    
    
    # (b) Box Plots
    rs_data = np.concatenate([rs_offset_vals[rs_is_periodic_ids], rs_ampl_vals[rs_is_periodic_ids], rs_freq_vals[rs_is_periodic_ids]])
    fig.add_trace(go.Box(x=np.array(["offset"]*len(rs_is_periodic_ids) + ["amplitude"]*len(rs_is_periodic_ids) + ["main<br>frequency"]*len(rs_is_periodic_ids)), 
                         y=rs_data, showlegend=False, legendgroup="RS",
                         line=dict(color=default_colors[0]), fillcolor=default_colors_shade[0],
                         boxpoints='all', pointpos=0., marker_size=3), row=1, col=5)

    imgep_data = np.concatenate([imgep_offset_vals[imgep_is_periodic_ids], imgep_ampl_vals[imgep_is_periodic_ids], imgep_freq_vals[imgep_is_periodic_ids]])
    fig.add_trace(go.Box(x=np.array(["offset"]*len(imgep_is_periodic_ids) + ["amplitude"]*len(imgep_is_periodic_ids) + ["main<br>frequency"]*len(imgep_is_periodic_ids)), 
                         y=imgep_data, showlegend=False, legendgroup="imgep",
                         line=dict(color=default_colors[1]), fillcolor=default_colors_shade[1],
                         boxpoints='all', pointpos=0., marker_size=3), row=1, col=5)
    
    fig.update_layout(boxmode='group')

    
    # (a) 3D scatter plot
    fig.add_trace(go.Scatter3d(x=imgep_offset_vals[imgep_is_periodic_ids], 
                               y=imgep_ampl_vals[imgep_is_periodic_ids], 
                               z=imgep_freq_vals[imgep_is_periodic_ids], 
                               legendgroup="IMGEP", 
                               showlegend = False, mode="markers",
                               marker=dict(color=default_colors[1], size=3),
                        ), row=1, col=1)
    fig.add_trace(go.Scatter3d(x=rs_offset_vals[rs_is_periodic_ids], 
                           y=rs_ampl_vals[rs_is_periodic_ids], 
                           z=rs_freq_vals[rs_is_periodic_ids],
                           legendgroup="RS",
                           showlegend = False, mode="markers",
                           marker=dict(color=default_colors[0], size=3),
                    ), row=1, col=1)
    
    ## draw x,y and z axes
    fig.add_trace(go.Scatter3d(x=[0,1], y=[0,0], z=[0,0], showlegend=False, mode="lines", hoverinfo="skip", line_color="black"), row=1, col=1)
    fig.add_trace(go.Cone(x=[1], y=[0], z=[0], u=[0.2], v=[0], w=[0], hoverinfo="skip", sizemode="absolute", colorscale=[[0, "black"], [1, "black"]], showscale=False), row=1, col=1)
    fig.add_trace(go.Scatter3d(x=[0,0], y=[0,1], z=[0,0], showlegend=False, mode="lines", hoverinfo="skip", line_color="black"), row=1, col=1)
    fig.add_trace(go.Cone(x=[0], y=[1], z=[0], u=[0], v=[0.2], w=[0], hoverinfo="skip", sizemode="absolute", colorscale=[[0, "black"], [1, "black"]], showscale=False), row=1, col=1)
    fig.add_trace(go.Scatter3d(x=[0,0], y=[0,0], z=[0,1], showlegend=False, mode="lines", hoverinfo="skip", line_color="black"), row=1, col=1)
    fig.add_trace(go.Cone(x=[0], y=[0], z=[1], u=[0], v=[0], w=[0.2], hoverinfo="skip", sizemode="absolute", colorscale=[[0, "black"], [1, "black"]], showscale=False), row=1, col=1)
    
    fig.update_scenes(xaxis=dict(title="", autorange="reversed"), 
                      yaxis=dict(title=""),
                      zaxis=dict(title=""),
                      annotations=[dict(x=1.1, y=-0.1, z=-0.1, showarrow=False, text="offset"),
                                  dict(x=-0.1, y=1.1, z=-0.1, showarrow=False, text="amplitude"),
                                  dict(x=-0.1, y=-0.1, z=1.1, showarrow=False, text="main<br>frequency")],
                      camera=dict(eye=dict(x=1.5,y=-1.,z=0.2),),
                      row=1, col=1)
    
    
    # (d) training losses
    fig.add_trace(go.Scatter(y=loss_imgep, mode="lines", name=f"curiosity search (IMGEP)<br>(<b>{len(imgep_is_periodic_ids)}</b>/{len(imgep_is_periodic_bool)} oscillators)", showlegend=True, legendgroup="IMGEP",
                             line_color=default_colors[1]), row=2, col=1)
    imgep_idx = jnp.where(loss_imgep == loss_imgep[-1])[0][0]+1
    fig.add_annotation(x=imgep_idx, y=loss_imgep[imgep_idx], text=f"<b>N={imgep_idx} <br> L={loss_imgep[imgep_idx]:.2f}</b>", ax=30, ay=-25, row=2, col=1)
    fig.add_trace(go.Scatter(y=loss_rs, mode="lines", name=f"random search (RS)<br>(<b>{len(rs_is_periodic_ids)}</b>/{len(rs_is_periodic_bool)} oscillators)", showlegend=True, legendgroup="RS",
                             line_color=default_colors[0]), row=2, col=1)
    rs_idx = jnp.where(loss_rs == loss_rs[-1])[0][0]+1
    fig.add_annotation(x=rs_idx, y=loss_rs[rs_idx], text=f"N={rs_idx} <br> L={loss_rs[rs_idx]:.2f}", ax=30, ay=-25, row=2, col=1)
    fig.add_trace(go.Scatter(y=loss_sgd, mode="lines", name=f"gradient descent (SGD)<br>(<b>{n_sgd_oscillators}</b>/{len(rs_is_periodic_bool)} oscillators)", showlegend=True, legendgroup="SGD",
                             line_color=default_colors[2]), row=2, col=1)
    fig.add_annotation(x=len(loss_sgd)-1, y=loss_sgd[-2], text=f"N=5000 <br> L={loss_sgd[-2]:.2f}", ax=-30, ay=-30, row=2, col=1)
    
    # (e) IMGEP best discovery 
    fig.add_trace(go.Scatter(x=random_system_outputs.ts, y=ys_imgep[0, :3001], showlegend=False, legendgroup="IMGEP",
                                 mode="lines", line_color=default_colors[1]), row=2, col=4)
    
    # (f) RS best discovery 
    fig.add_trace(go.Scatter(x=random_system_outputs.ts, y=ys_rs[0, :3001], showlegend=False, legendgroup="RS",
                                 mode="lines", line_color=default_colors[0]), row=2, col=7)
    
    # (g) SGD best discovery
    fig.add_trace(go.Scatter(x=random_system_outputs.ts, y=ys_sgd[0, :3001], name="gradient descent", showlegend=False, legendgroup="SGD",
                                 mode="lines", line_color=default_colors[2]), row=2, col=10)
    
    # (h) IMGEP best + SGD local refinement
    ## loss
    fig.add_trace(go.Scatter(y=loss_imgep_sgd, mode="lines", name="target", showlegend=False, legendgroup="IMGEP",
                             line_color=default_colors[1]), row=3, col=1)
    fig.add_annotation(x=0, y=loss_imgep_sgd[0], text=f"{loss_imgep_sgd[0]:.2f}", ax=30, ay=10, row=3, col=1)
    fig.add_annotation(x=len(loss_imgep_sgd)-1, y=loss_imgep_sgd[-2], text=f"<b>{loss_imgep_sgd[-2]:.2f}</b>", ax=-30, ay=-30, row=3, col=1)
    ## results
    fig.add_trace(go.Scatter(x=random_system_outputs.ts, y=ys_imgep_sgd[0, :3001], showlegend=False, legendgroup="IMGEP",
                                 mode="lines", line_color=default_colors[1]), row=3, col=4)
    
    # (i) RS best + SGD local refinement
    ## loss
    fig.add_trace(go.Scatter(y=loss_rs_sgd, mode="lines", showlegend=False, legendgroup="RS",
                             line_color=default_colors[0]), row=3, col=7)
    fig.add_annotation(x=0, y=loss_rs_sgd[0], text=f"{loss_rs_sgd[0]:.2f}", ax=10, ay=20, row=3, col=7)
    fig.add_annotation(x=len(loss_rs_sgd)-1, y=loss_rs_sgd[-2], text=f"{loss_rs_sgd[-2]:.2f}", ax=-10, ay=20, row=3, col=7)
    ## results
    fig.add_trace(go.Scatter(x=random_system_outputs.ts, y=ys_rs_sgd[0, :3001], showlegend=False, legendgroup="RS",
                                 mode="lines", line_color=default_colors[0]), row=3, col=10)
    
    # add target
    target_ys = A*jnp.cos(2*jnp.pi*w*random_system_outputs.ts)+b
    for row_idx in [2,3]:
        for col_idx in [4, 7, 10]:
            if row_idx == 3 and col_idx == 7:
                continue
            fig.add_trace(go.Scatter(x=random_system_outputs.ts, y=target_ys[:3001], name="target", showlegend=(row_idx==2 and col_idx==4),
                             legendgroup="target", mode="lines", line_color=default_colors[4]), row=row_idx, col=col_idx)
            fig.update_yaxes(range=[0,1.05], row=row_idx, col=col_idx)

    # Layout  
    fig.update_yaxes(gridcolor='rgba(0., 0., 0., 0.)')
    fig.update_yaxes(range=[13,24.9], row=3, col=1)
    fig.update_yaxes(range=[13,24.9], row=3, col=7)

    
    
    # Serialize fig to json and save
    if nb_save_outputs:
        fig.write_json(f"figures/tuto2_fig_{fig_idx}.json")
        
            
elif nb_mode == "load":
    fig = plotly.io.read_json(f"figures/tuto2_fig_{fig_idx}.json")
    

# Display Fig
width, height = 940, 700
t = f"Characteristics of the discovered oscillators."

if nb_renderer == "html":
    html_fig = make_html_fig(fig_idx, fig, width, height, t)
    display(HTML(html_fig))

elif nb_renderer == "img":
    img_fig, img_title = make_img_fig(fig_idx, fig, width, height, t)
    display(Image(img_fig))
    print(img_title)

Figure shows (a) the evolution of the training loss L throughout the N=5000 trials for the three exploration strategies, (b-c-d) the resulting  best discoveries (minimum loss $L$ as defined before) for each exploration strategies, and (e-f) the local training loss and resulting finetuning of the best discoveries with gradient descent. 

&#x1F449; We can see that gradient descent alone fails to discover an oscillator in this example, as it get trap in a strong local minima (constant signal with same average than the target oscillator). This shows the challenge of finding a proper loss and/or parameter initialization. However, it can be useful for local finetuning of "good-enough" discoveries, such as the ones discovered by the curiosity search and/or random search (though curiosity search allows to reach a better results in a  more efficient way with here N=1057+100, L=13.97 instead of N=2663+100, L=22.92).