In [None]:
import pandas as pd
import plotly.express as px
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.colors import DEFAULT_PLOTLY_COLORS as colors
from scipy import interpolate

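### Reproduce Fig. 3

Per-round mean (line) ± one standard deviation (shaded band) of median fitness, diversity, d_init and d_high for `ours`, `adalead` and `pex` on the AAV and GFP medium/hard tasks, aggregated from `summary/{alg}/{task}_total.csv`.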

In [None]:
# Fig. 3: one row per task (protein + difficulty), one column per metric.
# Each algorithm is drawn as its per-round mean with a ±1 std shaded band.
all_levels = ['AAV_medium', 'AAV_hard', 'GFP_medium', 'GFP_hard']
algs = ['ours', 'adalead', 'pex']
# CSV column plotted in each subplot column; order matches `column_titles` below.
metrics = ['median fitness', 'diversity', 'novelty', 'high']

fig = make_subplots(rows=4, cols=4, vertical_spacing=0.05, horizontal_spacing=0.02,
                    row_titles=[' '.join(l.split('_')) for l in all_levels],
                    column_titles=['Fitness', 'Diversity', 'd_init', 'd_high'])
for row, protein_level in enumerate(all_levels, start=1):
    for alg, color in zip(algs, colors):
        data = pd.read_csv(f'summary/{alg}/{protein_level}_total.csv')
        by_round = data.groupby('round')
        # numeric_only=True: with pandas >= 2.0 a non-numeric column would make
        # mean()/std() raise TypeError; such columns are never plotted anyway.
        mean = by_round.mean(numeric_only=True).reset_index()
        std = by_round.std(numeric_only=True).reset_index()
        # Turn 'rgb(r, g, b)' into 'rgba(r, g, b,0.2)' for the translucent band
        # (assumes `color` is an rgb() string). Depends only on the algorithm.
        band_color = color.replace('rgb', 'rgba').replace(')', ',0.2)')

        for col, metric in enumerate(metrics, start=1):
            fig.add_trace(go.Scatter(x=mean['round'], y=mean[metric],
                                     marker=dict(size=8),
                                     line=dict(color=color, width=4), showlegend=False),
                          row=row, col=col)
            # Invisible upper edge of the band. The next trace fills down to it
            # via fill='tonexty', so the two traces must stay adjacent and in this order.
            fig.add_trace(go.Scatter(x=mean['round'], y=mean[metric] + std[metric],
                                     mode='lines', line=dict(width=0),
                                     hoverinfo='skip', showlegend=False),
                          row=row, col=col)
            fig.add_trace(go.Scatter(x=mean['round'], y=mean[metric] - std[metric],
                                     mode='lines', fill='tonexty', fillcolor=band_color,
                                     line=dict(width=0),
                                     hoverinfo='skip', showlegend=False),
                          row=row, col=col)

# Data-less traces whose only purpose is one legend entry per algorithm.
for alg, color in zip(algs, colors):
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines',
                             line=dict(color=color, width=6), name=alg))

fig.update_layout(legend=dict(
    orientation="h",
    entrywidth=120,
    yanchor="bottom",
    y=-0.15,
    xanchor="left",
    font=dict(size=30),
    x=0.33
))
fig.update_xaxes(showline=True, linewidth=3, linecolor='black', gridcolor='lightgray')
fig.update_yaxes(showline=True, linewidth=3, linecolor='black', gridcolor='lightgray')
fig.update_layout(margin=dict(b=40, r=10, l=10, t=30),
                  height=800, width=1600, plot_bgcolor='rgba(0,0,0,0)', font=dict(size=20))
# Row/column titles are layout annotations; enlarge them all at once.
fig.update_annotations(font_size=25)

fig.show()

In [None]:
# Save the Fig. 3 figure built in the previous cell (PNG, scale factor 1.2).
fig.write_image(file='rounds.png', scale=1.2)

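### Reproduce Fig. 5

Ablation runs (`ablation` vs. `ablation_m3`, 5 seeds each): per-round evaluation fitness (mean ± std over seeds), plus the seed-averaged step mutation and episode length interpolated onto a common step grid up to 4,000 steps.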

In [None]:
# Fig. 5: ablation runs, one row per task. Column 1 is per-round evaluation
# fitness (mean ± std over seeds); columns 2-3 are the seed-averaged
# step mutation / episode length against training step.
all_levels = ['AAV_medium', 'AAV_hard', 'GFP_medium', 'GFP_hard']
fig = make_subplots(rows=4, cols=3, vertical_spacing=0.06, horizontal_spacing=0.06,
                    row_titles=[' '.join(l.split('_')) for l in all_levels],
                    column_titles=['Fitness', 'Step Mutation', 'Episode Length'])
max_x = 4_000                       # last `_step` kept for columns 2-3
common_x = np.arange(0, max_x, 50)  # shared grid the per-seed curves are resampled onto
N_SEEDS = 5                         # files *_0_*.csv ... *_4_*.csv per (variant, task)
N_ROUNDS = 15                       # leading rows of each *_train.csv that are plotted
# Legend label per ablation variant; colours follow this order in both the
# data traces and the legend traces below.
# NOTE(review): 'ablation' is labelled 'm_step=3' and 'ablation_m3' is labelled
# 'None' (mapping kept from the original notebook) — confirm it is not swapped.
ablations = {'ablation': 'm_step=3', 'ablation_m3': 'None'}

for row, protein_level in enumerate(all_levels, start=1):
    for alg, color in zip(ablations, colors):
        # Translucent band colour; assumes `color` is an 'rgb(r, g, b)' string.
        band_color = color.replace('rgb', 'rgba').replace(')', ',0.2)')

        evals, trains = [], []
        for seed in range(N_SEEDS):
            eval_df = pd.read_csv(f'summary/{alg}/{protein_level}_{seed}_eval.csv')
            # Steps until the next logged row. Computed before the max_x cut, so
            # kept rows take their value from the following (possibly dropped)
            # row; only the file's final row ends up NaN.
            eval_df['Episode Length'] = eval_df['_step'].diff().shift(-1)
            evals.append(eval_df[eval_df['_step'] <= max_x])

            train_df = pd.read_csv(f'summary/{alg}/{protein_level}_{seed}_train.csv')
            trains.append(train_df.iloc[:N_ROUNDS]['eval/fitness'].tolist())

        # Rows = seeds, columns = rounds. Truncate to the shortest run so a seed
        # with fewer than N_ROUNDS logged rounds cannot produce a ragged array.
        n_rounds = min(len(t) for t in trains)
        fitness = np.array([t[:n_rounds] for t in trains])
        rounds = list(range(1, n_rounds + 1))
        mean_y = fitness.mean(axis=0)
        # NOTE(review): np.std defaults to ddof=0 (population std), whereas Fig. 3
        # uses pandas .std() (ddof=1). Left unchanged so this figure is unchanged.
        std_y = fitness.std(axis=0)

        fig.add_trace(go.Scatter(x=rounds, y=mean_y,
                                 marker=dict(size=8),
                                 line=dict(color=color, width=3), showlegend=False),
                      row=row, col=1)
        # Invisible upper edge; the next trace fills down to it ('tonexty'),
        # so these two traces must stay adjacent and in this order.
        fig.add_trace(go.Scatter(x=rounds, y=mean_y + std_y,
                                 mode='lines', line=dict(width=0),
                                 hoverinfo='skip', showlegend=False),
                      row=row, col=1)
        fig.add_trace(go.Scatter(x=rounds, y=mean_y - std_y,
                                 fill='tonexty', mode='lines', fillcolor=band_color,
                                 line=dict(width=0),
                                 hoverinfo='skip', showlegend=False),
                      row=row, col=1)

        for col, metric in enumerate(['Step Mutation', 'Episode Length'], start=2):
            # Resample each seed onto common_x (linear, extrapolating beyond each
            # seed's first/last logged step) so seeds can be averaged pointwise.
            resampled = np.array([
                interpolate.interp1d(df['_step'], df[metric],
                                     bounds_error=False, fill_value="extrapolate")(common_x)
                for df in evals
            ])
            mean_y = resampled.mean(axis=0)
            # The first grid point (step 0) is not drawn.
            fig.add_trace(go.Scatter(x=common_x[1:], y=mean_y[1:],
                                     line=dict(color=color, width=3), showlegend=False),
                          row=row, col=col)

# Data-less traces whose only purpose is one legend entry per variant.
for color, name in zip(colors, ablations.values()):
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines',
                             line=dict(color=color), name=name))

fig.update_layout(legend=dict(
    orientation="h",
    entrywidth=70,
    yanchor="bottom",
    y=-0.2,
    xanchor="left",
    x=0
))

fig.update_xaxes(title='Round', row=4, col=1)
fig.update_xaxes(title='Step', row=4, col=2)
fig.update_xaxes(title='Step', row=4, col=3)

# Same fixed y-range in every row for columns 2 and 3.
fig.update_yaxes(range=[0, 5.5], col=2)
fig.update_yaxes(range=[0, 7], col=3)

fig.update_xaxes(showline=True, linewidth=3, linecolor='black', gridcolor='lightgray')
fig.update_yaxes(showline=True, linewidth=3, linecolor='black', gridcolor='lightgray')
fig.update_layout(margin=dict(b=40, r=40, l=10, t=40),
                  height=700, width=1000, plot_bgcolor='rgba(0,0,0,0)', font=dict(size=20))
# Row/column titles are layout annotations; restyle them all at once.
fig.update_annotations(font_size=23, font_color='black')
fig.show()

In [None]:
# Save the Fig. 5 figure built in the previous cell (PNG, scale factor 1.2).
fig.write_image(file='mstep.png', scale=1.2)