# MolPAL Figure Notebook

The first step in recreating most figures from the main text is to import everything and load all of the data. To do that, go down to [this cell](#RUN-ALL-CELLS-ABOVE-ME) and run all of the cells above it.

In [None]:
import plotly.graph_objects as go
import plotly.io as pio
import plotly.express as px
from plotly.subplots import make_subplots

import pandas as pd

# Every figure built in this notebook uses plotly's white template.
pio.templates.default = 'plotly_white'

# Surrogate-model keys used by the per-model result dictionaries below.
MODELS = ['rf', 'nn', 'mpn']

# Acquisition-metric keys and the display label for each one.
METRICS = ['greedy', 'ucb', 'thompson', 'ei', 'pi']
METRIC_NAMES = dict(greedy='greedy', ucb='UCB', thompson='TS', ei='EI', pi='PI')

# NOTE(review): how SPLITS is consumed is not visible in this section.
SPLITS = [0.4, 0.2, 0.1]

# Plot styling: line-dash patterns, marker symbols, and the qualitative
# color palettes (plotly's 'Plotly' set for metrics, 'D3' set for models).
DASHES = ['dash', 'dot', 'dashdot']
MARKERS = ['circle', 'square', 'diamond']
METRIC_COLORS, MODEL_COLORS = px.colors.qualitative.Plotly, px.colors.qualitative.D3

In [None]:
# Random-acquisition baseline for the 10k library. It is keyed by quantity
# ('avg', 'scores', 'smis') in the same shape as the per-metric entries of the
# E10k model-result dictionaries. Each value is a pair of 6-element lists,
# presumably (mean, spread) over repeated runs per iteration -- TODO confirm.
E10k_random = dict(
    avg=(
        [80.023, 85.934, 88.358, 89.827, 90.707, 91.408],
        [1.149, 0.459, 0.431, 0.508, 0.270, 0.337],
    ),
    scores=(
        [1.600, 2.800, 3.200, 4.200, 4.800, 5.600],
        [1.356, 1.470, 1.600, 1.327, 0.748, 0.800],
    ),
    smis=(
        [1.600, 2.800, 3.200, 3.800, 4.400, 5.000],
        [1.356, 1.470, 1.600, 1.600, 1.020, 0.894],
    ),
)

In [None]:
#top-100
# Per-model results for the "online" runs on the 10k library. The tag above
# presumably means the metrics are computed against the top 100 -- confirm.
# Layout: {model: {quantity: {metric: (values, spreads)}}}
#   model:    'mpn', 'nn', 'rf' (the entries of MODELS)
#   quantity: 'avg', 'scores', 'smis'
#   metric:   'ei', 'greedy', 'pi', 'thompson', 'ucb' (the entries of METRICS)
# Each tuple holds two 6-element lists, presumably the mean and spread over
# repeated runs at successive iterations -- TODO confirm.
E10k_online = {
 'mpn': {'avg': {'ei': ([80.166, 92.926, 94.476, 96.067, 97.065, 97.775],
                        [0.961, 0.076, 0.291, 0.587, 0.266, 0.265]),
                 'greedy': ([79.909, 92.176, 94.513, 96.101, 96.941, 97.682],
                            [0.532, 1.091, 0.242, 0.196, 0.395, 0.501]),
                 'pi': ([79.961, 91.992, 94.141, 95.639, 96.529, 97.398],
                        [0.565, 0.900, 0.460, 0.346, 0.370, 0.286]),
                 'thompson': ([79.865, 91.644, 93.073, 95.726, 96.185, 97.442],
                              [0.702, 0.865, 1.170, 0.917, 0.886, 0.334]),
                 'ucb': ([79.799, 91.754, 94.093, 95.993, 97.055, 97.825],
                         [0.585, 0.809, 0.600, 0.372, 0.451, 0.471])},
         'scores': {'ei': ([1.800, 15.800, 20.400, 29.000, 36.400, 45.400],
                           [0.748, 0.748, 1.356, 3.286, 2.417, 5.004]),
                    'greedy': ([1.600, 14.800, 20.400, 28.600, 36.400, 44.600],
                               [1.625, 2.638, 1.020, 2.154, 5.276, 6.946]),
                    'pi': ([1.200, 13.000, 18.200, 24.000, 30.200, 39.400],
                           [0.748, 1.414, 2.400, 2.280, 3.600, 3.262]),
                    'thompson': ([0.800, 10.800, 13.400, 25.800, 28.800, 41.400],
                                 [0.748, 3.868, 5.783, 7.250, 7.859, 4.673]),
                    'ucb': ([1.000, 12.200, 18.400, 27.400, 36.400, 45.800],
                            [0.632, 2.400, 1.625, 3.323, 5.886, 7.111])},
         'smis': {'ei': ([1.800, 13.400, 17.200, 24.800, 31.000, 38.800],
                         [0.748, 1.020, 0.980, 3.059, 2.449, 3.429]),
                  'greedy': ([1.200, 12.800, 17.000, 23.600, 30.000, 37.200],
                             [1.166, 2.135, 1.414, 2.332, 4.050, 6.013]),
                  'pi': ([1.000, 10.800, 14.200, 19.400, 24.600, 33.200],
                         [0.632, 0.748, 2.040, 2.245, 3.200, 2.926]),
                  'thompson': ([0.800, 10.200, 12.600, 23.200, 25.800, 37.000],
                               [0.748, 4.354, 6.280, 7.222, 7.935, 5.215]),
                  'ucb': ([0.800, 10.800, 15.400, 22.600, 30.400, 38.800],
                          [0.748, 2.040, 2.059, 3.137, 5.122, 6.274])}},
 'nn': {'avg': {'ei': ([79.708, 92.806, 93.385, 94.906, 95.683, 96.094],
                       [0.518, 0.432, 1.239, 1.007, 1.026, 1.119]),
                'greedy': ([79.683, 92.938, 96.221, 97.473, 97.949, 98.487],
                           [0.471, 1.073, 0.634, 0.505, 0.479, 0.305]),
                'pi': ([79.865, 92.822, 94.919, 95.898, 96.368, 96.931],
                       [0.552, 1.015, 0.771, 0.693, 0.698, 0.920]),
                'thompson': ([79.690, 92.336, 95.169, 96.796, 97.634, 98.162],
                             [0.686, 0.635, 0.434, 0.208, 0.157, 0.154]),
                'ucb': ([79.760, 92.450, 94.948, 96.099, 97.022, 97.587],
                        [0.818, 0.992, 0.977, 1.145, 1.097, 0.821])},
        'scores': {'ei': ([1.600, 12.800, 15.200, 20.000, 24.000, 28.000],
                          [1.020, 3.250, 6.431, 5.727, 5.292, 6.693]),
                   'greedy': ([0.800, 15.600, 28.600, 40.200, 46.400, 55.400],
                              [0.980, 3.878, 6.312, 7.705, 7.761, 6.184]),
                   'pi': ([1.600, 15.400, 23.200, 29.000, 32.000, 37.600],
                          [0.490, 4.224, 3.124, 4.427, 4.195, 6.711]),
                   'thompson': ([1.800, 13.400, 22.600, 34.200, 42.800, 51.200],
                                [0.400, 1.497, 3.262, 2.040, 1.600, 3.868]),
                   'ucb': ([0.600, 14.200, 22.600, 29.000, 37.000, 43.400],
                           [0.800, 3.250, 6.406, 10.450, 11.610, 11.586])},
        'smis': {'ei': ([1.600, 11.600, 14.000, 18.000, 21.400, 24.600],
                        [1.020, 2.417, 6.164, 5.550, 5.535, 6.216]),
                 'greedy': ([0.800, 14.600, 26.000, 36.000, 41.200, 49.800],
                            [0.980, 3.262, 6.573, 8.025, 8.518, 6.369]),
                 'pi': ([1.400, 14.000, 20.600, 25.000, 27.600, 33.000],
                        [0.800, 3.950, 2.577, 3.578, 3.007, 5.404]),
                 'thompson': ([1.400, 12.200, 20.600, 30.400, 38.400, 45.800],
                              [0.490, 2.227, 4.030, 3.072, 2.800, 3.544]),
                 'ucb': ([0.400, 13.200, 20.600, 26.200, 33.000, 38.400],
                         [0.800, 3.487, 6.530, 9.826, 10.526, 10.111])}},
 'rf': {'avg': {'ei': ([79.725, 89.860, 93.306, 95.026, 95.960, 96.533],
                       [0.912, 0.849, 0.676, 0.506, 0.717, 0.491]),
                'greedy': ([79.731, 92.816, 95.461, 96.767, 97.446, 97.864],
                           [0.582, 0.569, 0.566, 0.514, 0.318, 0.192]),
                'pi': ([79.865, 91.435, 94.335, 95.585, 96.235, 97.057],
                       [1.177, 0.483, 0.543, 0.426, 0.489, 0.585]),
                'thompson': ([79.035, 87.929, 90.754, 92.526, 93.712, 94.575],
                             [0.479, 0.571, 0.673, 0.381, 0.360, 0.313]),
                'ucb': ([79.166, 89.515, 92.760, 94.679, 95.759, 96.457],
                        [0.734, 1.759, 0.769, 0.858, 0.556, 0.506])},
        'scores': {'ei': ([1.200, 7.400, 16.000, 22.000, 27.600, 32.200],
                          [0.748, 2.154, 5.550, 6.542, 9.178, 7.250]),
                   'greedy': ([2.200, 15.600, 25.400, 34.000, 41.400, 46.200],
                              [1.166, 2.417, 4.964, 5.899, 4.363, 2.135]),
                   'pi': ([1.200, 10.600, 19.600, 25.200, 29.400, 36.800],
                          [0.748, 2.871, 5.004, 4.707, 4.673, 6.274]),
                   'thompson': ([1.400, 4.400, 7.400, 10.800, 14.400, 17.400],
                                [0.800, 2.728, 3.929, 3.655, 3.007, 3.200]),
                   'ucb': ([0.600, 7.400, 13.400, 19.600, 25.200, 30.200],
                           [0.800, 3.555, 4.716, 5.122, 5.776, 5.776])},
        'smis': {'ei': ([1.200, 6.200, 13.200, 18.200, 22.800, 27.000],
                        [0.748, 2.482, 4.707, 5.600, 7.547, 5.762]),
                 'greedy': ([2.200, 13.800, 22.800, 30.400, 37.000, 40.800],
                            [1.166, 2.482, 4.833, 6.829, 5.215, 3.655]),
                 'pi': ([0.800, 8.800, 16.800, 21.800, 25.200, 31.400],
                        [0.748, 2.227, 4.354, 4.167, 4.792, 5.389]),
                 'thompson': ([1.200, 3.600, 6.400, 9.000, 12.600, 15.200],
                              [0.748, 2.577, 3.555, 3.162, 2.245, 3.250]),
                 'ucb': ([0.600, 6.600, 11.200, 16.800, 22.000, 26.800],
                         [0.800, 3.200, 4.261, 5.671, 5.727, 5.307])}}}

In [None]:
# Per-model results for the "retrain" runs on the 10k library. There is no
# top-k tag here; presumably the same top-100 convention as the online
# 10k results -- confirm.
# Layout: {model: {quantity: {metric: (values, spreads)}}}
#   model:    'mpn', 'nn', 'rf' (the entries of MODELS)
#   quantity: 'avg', 'scores', 'smis'
#   metric:   'ei', 'greedy', 'pi', 'thompson', 'ucb' (the entries of METRICS)
# Each tuple holds two 6-element lists, presumably the mean and spread over
# repeated runs at successive iterations -- TODO confirm.
E10k_retrain = {
 'mpn': {'avg': {'ei': ([79.427, 91.615, 95.612, 97.373, 98.203, 98.795],
                        [0.597, 0.933, 0.983, 0.450, 0.361, 0.190]),
                 'greedy': ([79.901, 92.609, 95.734, 97.523, 98.477, 98.942],
                            [0.509, 0.377, 0.642, 0.351, 0.163, 0.118]),
                 'pi': ([79.950, 92.379, 95.689, 97.448, 98.381, 98.758],
                        [0.472, 0.832, 0.387, 0.164, 0.204, 0.161]),
                 'thompson': ([79.232, 91.442, 95.066, 96.817, 97.858, 98.584],
                              [0.770, 0.843, 0.427, 0.377, 0.346, 0.199]),
                 'ucb': ([78.905, 91.669, 95.471, 97.332, 98.168, 98.800],
                         [0.394, 1.047, 0.334, 0.379, 0.261, 0.166])},
         'scores': {'ei': ([1.800, 13.400, 28.000, 42.400, 54.000, 64.400],
                           [0.748, 4.224, 6.261, 5.389, 5.865, 3.826]),
                    'greedy': ([1.600, 15.200, 28.400, 44.400, 58.400, 67.000],
                               [0.490, 2.638, 4.630, 5.238, 3.611, 3.033]),
                    'pi': ([1.200, 14.600, 25.800, 41.600, 56.400, 63.000],
                           [0.748, 2.498, 2.638, 1.625, 3.878, 3.162]),
                    'thompson': ([2.000, 13.200, 23.200, 35.400, 47.800, 60.000],
                                 [0.632, 2.482, 2.315, 4.454, 5.636, 4.099]),
                    'ucb': ([0.200, 12.400, 26.000, 42.600, 53.200, 64.400],
                            [0.400, 3.774, 2.280, 5.004, 4.750, 4.079])},
         'smis': {'ei': ([1.400, 12.200, 24.600, 37.200, 47.000, 57.000],
                         [0.490, 3.544, 5.463, 4.020, 5.020, 3.578]),
                  'greedy': ([1.200, 13.000, 24.000, 38.200, 51.000, 59.600],
                             [0.400, 2.191, 3.347, 4.020, 2.280, 2.332]),
                  'pi': ([1.000, 12.800, 21.600, 35.600, 48.400, 55.200],
                         [0.894, 2.482, 3.611, 2.154, 3.007, 3.311]),
                  'thompson': ([1.800, 12.000, 20.200, 30.800, 41.600, 52.200],
                               [0.748, 2.280, 1.166, 3.311, 5.426, 4.622]),
                  'ucb': ([0.000, 9.800, 21.400, 36.600, 45.800, 57.000],
                          [0.000, 3.970, 2.577, 4.454, 4.214, 3.162])}},
 'nn': {'avg': {'ei': ([79.700, 92.692, 96.330, 97.667, 98.071, 98.423],
                       [0.930, 0.798, 0.382, 0.248, 0.367, 0.416]),
                'greedy': ([79.093, 93.636, 96.678, 97.796, 98.419, 98.965],
                           [0.429, 1.159, 0.579, 0.397, 0.312, 0.202]),
                'pi': ([79.865, 92.806, 96.053, 97.434, 97.965, 98.553],
                       [0.857, 0.777, 0.372, 0.163, 0.299, 0.150]),
                'thompson': ([79.822, 92.437, 95.821, 97.351, 98.222, 98.729],
                             [0.481, 0.826, 0.443, 0.149, 0.213, 0.192]),
                'ucb': ([79.988, 92.452, 95.892, 97.411, 98.160, 98.588],
                        [0.944, 0.696, 0.182, 0.172, 0.166, 0.164])},
        'scores': {'ei': ([0.800, 14.400, 30.800, 44.200, 50.400, 56.000],
                          [0.748, 3.720, 4.578, 4.354, 6.621, 7.457]),
                   'greedy': ([1.400, 20.600, 34.400, 46.800, 55.200, 66.800],
                              [1.744, 4.176, 5.678, 5.564, 5.946, 5.418]),
                   'pi': ([1.200, 13.800, 28.600, 41.600, 47.800, 57.800],
                          [1.166, 1.720, 3.262, 2.871, 4.956, 2.400]),
                   'thompson': ([1.200, 12.800, 25.400, 39.000, 52.800, 61.400],
                                [0.748, 1.600, 1.356, 1.095, 3.868, 3.878]),
                   'ucb': ([1.600, 12.200, 27.400, 41.200, 50.600, 58.000],
                           [0.800, 2.561, 1.020, 3.059, 3.666, 3.521])},
        'smis': {'ei': ([0.600, 13.200, 27.600, 39.600, 45.000, 49.800],
                        [0.490, 4.445, 4.317, 3.878, 6.325, 6.853]),
                 'greedy': ([1.000, 18.600, 30.600, 40.800, 48.600, 59.200],
                            [1.265, 3.878, 5.571, 5.810, 6.468, 6.145]),
                 'pi': ([1.200, 13.000, 27.000, 38.200, 43.200, 51.600],
                        [1.166, 1.789, 2.530, 2.135, 3.919, 2.332]),
                 'thompson': ([1.200, 11.400, 22.400, 34.600, 46.200, 54.600],
                              [0.748, 2.059, 1.744, 1.625, 2.926, 3.441]),
                 'ucb': ([1.600, 11.000, 25.000, 36.600, 44.400, 51.200],
                         [0.800, 2.608, 1.414, 2.577, 3.441, 3.370])}},
 'rf': {'avg': {'ei': ([79.180, 90.721, 94.074, 95.639, 96.664, 97.158],
                       [0.514, 1.350, 0.811, 0.610, 0.741, 0.764]),
                'greedy': ([79.981, 92.636, 95.469, 96.864, 97.558, 98.206],
                           [0.781, 1.316, 0.556, 0.397, 0.359, 0.306]),
                'pi': ([79.884, 90.338, 94.403, 96.457, 97.274, 97.816],
                       [1.071, 2.804, 1.136, 0.295, 0.278, 0.245]),
                'thompson': ([80.180, 88.761, 91.976, 93.493, 95.002, 95.972],
                             [0.413, 0.503, 0.370, 0.367, 0.352, 0.339]),
                'ucb': ([79.762, 90.365, 94.018, 95.751, 96.968, 97.580],
                        [0.629, 0.323, 0.340, 0.271, 0.212, 0.248])},
        'scores': {'ei': ([0.600, 9.400, 18.200, 26.000, 33.600, 39.400],
                          [0.800, 2.871, 4.354, 5.404, 8.163, 9.478]),
                   'greedy': ([0.600, 13.600, 24.400, 34.200, 41.400, 51.600],
                              [0.800, 4.499, 3.072, 4.069, 5.987, 5.851]),
                   'pi': ([1.000, 10.800, 21.400, 34.000, 40.200, 47.600],
                          [1.095, 4.622, 3.826, 2.098, 2.926, 4.176]),
                   'thompson': ([1.600, 4.800, 9.600, 13.000, 20.000, 27.600],
                                [0.490, 1.600, 2.154, 1.789, 2.098, 1.855]),
                   'ucb': ([0.400, 6.400, 15.400, 25.600, 36.400, 43.200],
                           [0.490, 2.154, 1.960, 2.728, 3.382, 3.370])},
        'smis': {'ei': ([0.600, 8.000, 15.400, 22.600, 28.800, 33.800],
                        [0.800, 2.280, 3.611, 5.352, 7.547, 9.108]),
                 'greedy': ([0.600, 10.800, 21.000, 30.200, 36.200, 44.800],
                            [0.800, 4.354, 3.347, 4.707, 5.492, 5.776]),
                 'pi': ([1.000, 8.800, 18.600, 29.600, 35.400, 41.400],
                        [1.095, 3.655, 3.499, 3.072, 4.030, 3.323]),
                 'thompson': ([1.400, 4.200, 8.600, 11.200, 16.800, 22.600],
                              [0.490, 1.470, 1.855, 1.600, 2.315, 2.653]),
                 'ucb': ([0.400, 5.600, 13.400, 22.000, 31.600, 37.200],
                         [0.490, 2.059, 2.059, 2.898, 3.720, 3.059])}}}

In [None]:
# Random-acquisition baseline for the 50k library. It is keyed by quantity
# ('avg', 'scores', 'smis') in the same shape as the per-metric entries of the
# E50k model-result dictionaries. Each value is a pair of 6-element lists,
# presumably (mean, spread) over repeated runs per iteration -- TODO confirm.
E50k_random = dict(
    avg=(
        [78.964, 85.635, 88.049, 89.505, 90.516, 91.362],
        [0.316, 0.266, 0.252, 0.286, 0.146, 0.193],
    ),
    scores=(
        [0.281, 2.040, 3.160, 4.280, 5.400, 6.600],
        [0.273, 0.344, 0.585, 1.136, 1.012, 1.095],
    ),
    smis=(
        [0.720, 1.960, 3.040, 4.000, 5.040, 6.080],
        [0.271, 0.367, 0.625, 1.152, 1.155, 1.170],
    ),
)

In [None]:
#top-500
# Per-model results for the "online" runs on the 50k library. The tag above
# presumably means the metrics are computed against the top 500 -- confirm.
# Layout: {model: {quantity: {metric: (values, spreads)}}}
#   model:    'mpn', 'nn', 'rf' (the entries of MODELS)
#   quantity: 'avg', 'scores', 'smis'
#   metric:   'ei', 'greedy', 'pi', 'thompson', 'ucb' (the entries of METRICS)
# Each tuple holds two 6-element lists, presumably the mean and spread over
# repeated runs at successive iterations -- TODO confirm.
E50k_online = {
 'mpn': {'avg': {'ei': ([78.817, 94.143, 96.742, 98.010, 98.571, 98.953],
                        [0.413, 0.394, 0.163, 0.047, 0.133, 0.111]),
                 'greedy': ([78.456, 94.235, 96.030, 97.700, 98.245, 98.698],
                            [0.294, 0.341, 0.712, 0.261, 0.403, 0.318]),
                 'pi': ([78.845, 94.161, 95.903, 97.610, 98.083, 98.640],
                        [0.327, 0.248, 0.883, 0.442, 0.592, 0.441]),
                 'thompson': ([78.816, 92.628, 96.029, 97.598, 98.485, 98.935],
                              [0.293, 0.606, 0.133, 0.225, 0.131, 0.093]),
                 'ucb': ([78.851, 94.464, 96.884, 98.020, 98.596, 99.000],
                         [0.160, 0.147, 0.223, 0.071, 0.239, 0.152])},
         'scores': {'ei': ([0.602, 22.600, 35.480, 47.880, 56.200, 63.320],
                           [0.380, 3.642, 1.904, 1.070, 2.245, 2.594]),
                    'greedy': ([0.399, 23.520, 30.960, 43.760, 51.080, 58.400],
                               [0.670, 1.787, 3.563, 3.008, 5.344, 4.943]),
                    'pi': ([0.523, 23.560, 30.600, 42.520, 49.000, 57.680],
                           [0.241, 1.183, 4.445, 5.326, 8.494, 8.073]),
                    'thompson': ([0.640, 15.040, 28.440, 41.680, 53.680, 62.200],
                                 [0.665, 2.378, 1.722, 3.597, 2.830, 2.671]),
                    'ucb': ([0.603, 23.760, 36.600, 47.920, 56.360, 64.640],
                            [0.524, 1.286, 2.280, 1.737, 4.421, 3.483])},
         'smis': {'ei': ([0.960, 21.240, 33.080, 44.760, 52.480, 59.160],
                         [0.320, 3.171, 1.719, 1.038, 2.215, 2.490]),
                  'greedy': ([0.960, 22.040, 28.800, 40.920, 47.880, 54.840],
                             [0.265, 1.839, 3.119, 2.600, 4.999, 4.560]),
                  'pi': ([1.160, 22.440, 28.760, 40.000, 46.080, 54.280],
                         [0.150, 1.203, 3.636, 4.912, 8.029, 7.616]),
                  'thompson': ([1.000, 14.200, 26.800, 39.280, 50.640, 58.440],
                               [0.379, 2.184, 1.431, 3.300, 2.590, 2.292]),
                  'ucb': ([1.040, 22.520, 34.120, 44.840, 52.800, 60.640],
                          [0.463, 1.348, 2.160, 1.856, 4.280, 3.522])}},
 'nn': {'avg': {'ei': ([79.020, 94.143, 94.995, 96.462, 96.894, 97.103],
                       [0.453, 0.454, 0.919, 1.155, 1.002, 1.051]),
                'greedy': ([78.994, 94.823, 97.227, 98.365, 98.790, 99.088],
                           [0.211, 0.256, 0.201, 0.145, 0.139, 0.104]),
                'pi': ([79.079, 94.154, 95.026, 96.327, 97.063, 97.254],
                       [0.503, 0.407, 1.123, 1.298, 0.874, 0.699]),
                'thompson': ([78.598, 93.866, 96.408, 97.508, 98.188, 98.560],
                             [0.423, 0.437, 0.395, 0.361, 0.347, 0.167]),
                'ucb': ([78.524, 94.303, 95.975, 97.316, 97.856, 98.291],
                        [0.354, 0.673, 0.718, 0.935, 0.731, 0.541])},
        'scores': {'ei': ([0.724, 22.920, 25.840, 34.800, 37.800, 39.480],
                          [0.414, 3.154, 5.107, 8.663, 7.474, 9.225]),
                   'greedy': ([0.801, 26.640, 40.320, 53.360, 60.320, 66.520],
                              [0.632, 2.061, 2.081, 2.010, 2.415, 2.211]),
                   'pi': ([0.721, 22.200, 26.320, 33.320, 37.800, 38.840],
                          [0.808, 1.802, 5.792, 10.131, 7.915, 7.244]),
                   'thompson': ([0.682, 21.760, 33.120, 42.240, 50.520, 55.920],
                                [0.517, 1.903, 3.428, 4.240, 5.683, 3.491]),
                   'ucb': ([0.520, 23.880, 31.080, 41.920, 47.200, 52.720],
                           [0.924, 4.287, 4.960, 9.946, 10.347, 10.090])},
        'smis': {'ei': ([1.320, 21.480, 24.280, 32.720, 35.520, 37.120],
                        [0.240, 3.040, 4.857, 8.176, 6.946, 8.646]),
                 'greedy': ([1.040, 25.160, 38.080, 50.160, 56.920, 62.840],
                            [0.367, 2.122, 1.862, 2.141, 2.506, 2.041]),
                 'pi': ([1.320, 21.240, 25.040, 31.520, 35.760, 36.680],
                        [0.449, 1.525, 5.432, 9.441, 7.587, 6.818]),
                 # NOTE(review): this 'random' entry appears only here (not in
                 # any other model/quantity) and is identical to
                 # E50k_random['smis']. 'random' is not in METRICS or
                 # METRIC_NAMES, so code that iterates over these keys will see
                 # an extra metric. Presumably a leftover -- confirm before
                 # relying on it.
                 'random': ([0.720, 1.960, 3.040, 4.000, 5.040, 6.080],
                            [0.271, 0.367, 0.625, 1.152, 1.155, 1.170]),
                 'thompson': ([0.960, 20.480, 31.120, 39.560, 47.440, 52.600],
                              [0.463, 1.809, 3.367, 4.005, 5.431, 3.272]),
                 'ucb': ([1.080, 22.760, 29.520, 39.720, 44.760, 49.920],
                         [0.652, 4.007, 4.475, 9.394, 9.687, 9.301])}},
 'rf': {'avg': {'ei': ([78.786, 91.711, 94.464, 95.952, 96.705, 97.329],
                       [0.208, 0.925, 0.368, 0.417, 0.550, 0.249]),
                'greedy': ([78.628, 93.700, 96.734, 97.921, 98.378, 98.746],
                           [0.490, 0.382, 0.173, 0.104, 0.080, 0.045]),
                'pi': ([79.068, 92.409, 95.560, 96.727, 97.501, 97.997],
                       [0.308, 0.663, 0.267, 0.290, 0.236, 0.132]),
                'thompson': ([78.810, 89.425, 92.663, 94.463, 95.571, 96.296],
                             [0.419, 0.779, 0.385, 0.297, 0.146, 0.111]),
                'ucb': ([78.854, 92.510, 94.801, 96.293, 97.159, 97.724],
                        [0.243, 0.240, 0.684, 0.372, 0.318, 0.147])},
        'scores': {'ei': ([0.401, 11.880, 20.080, 27.560, 33.080, 38.560],
                          [0.553, 2.827, 1.107, 3.563, 5.367, 2.684]),
                   'greedy': ([0.201, 20.600, 36.040, 47.560, 53.800, 60.360],
                              [0.285, 1.757, 1.176, 1.587, 1.431, 1.155]),
                   'pi': ([0.601, 14.800, 26.440, 33.360, 40.680, 46.440],
                          [0.612, 2.985, 1.582, 2.877, 2.933, 1.830]),
                   'thompson': ([0.202, 6.760, 12.320, 18.360, 23.560, 27.960],
                                [0.285, 2.037, 1.532, 1.556, 1.359, 1.235]),
                   'ucb': ([0.640, 14.640, 21.680, 30.120, 37.800, 43.480],
                           [0.690, 1.106, 3.465, 3.329, 3.277, 1.792])},
        'smis': {'ei': ([0.920, 11.520, 19.440, 26.720, 31.920, 37.120],
                        [0.449, 2.982, 1.428, 3.704, 5.280, 2.667]),
                 'greedy': ([0.760, 19.240, 33.440, 44.400, 50.200, 56.400],
                            [0.388, 1.675, 0.991, 1.486, 1.246, 0.938]),
                 'pi': ([1.200, 14.320, 25.640, 32.040, 39.160, 44.560],
                        [0.219, 2.468, 1.405, 2.658, 2.658, 1.394]),
                 'thompson': ([1.000, 6.440, 11.800, 17.600, 22.440, 26.680],
                              [0.358, 1.899, 1.585, 1.507, 1.353, 1.306]),
                 'ucb': ([0.920, 13.960, 20.680, 28.800, 36.040, 41.360],
                         [0.271, 0.933, 3.461, 3.277, 3.340, 1.932])}}}

In [None]:
# Per-model results for the "retrain" runs on the 50k library. There is no
# top-k tag here; presumably the same top-500 convention as the online
# 50k results -- confirm.
# Layout: {model: {quantity: {metric: (values, spreads)}}}
#   model:    'mpn', 'nn', 'rf' (the entries of MODELS)
#   quantity: 'avg', 'scores', 'smis'
#   metric:   'ei', 'greedy', 'pi', 'thompson', 'ucb' (the entries of METRICS)
# Each tuple holds two 6-element lists, presumably the mean and spread over
# repeated runs at successive iterations -- TODO confirm.
E50k_retrain = {
 'mpn': {'avg': {'ei': ([78.787, 94.248, 97.409, 98.469, 98.924, 99.122],
                        [0.298, 0.187, 0.076, 0.072, 0.067, 0.032]),
                 'greedy': ([78.867, 94.544, 97.597, 98.590, 99.088, 99.340],
                            [0.202, 0.299, 0.124, 0.061, 0.048, 0.030]),
                 'pi': ([78.864, 94.373, 97.496, 98.527, 98.972, 99.170],
                        [0.331, 0.311, 0.200, 0.091, 0.067, 0.074]),
                 'thompson': ([78.847, 92.568, 96.386, 98.134, 98.850, 99.250],
                              [0.325, 0.637, 0.354, 0.136, 0.099, 0.047]),
                 'ucb': ([78.492, 94.490, 97.507, 98.585, 99.068, 99.329],
                         [0.411, 0.124, 0.129, 0.051, 0.022, 0.024])},
         'scores': {'ei': ([0.399, 23.240, 41.280, 54.040, 62.160, 67.200],
                           [0.964, 1.311, 1.230, 1.666, 1.359, 0.938]),
                    'greedy': ([0.921, 24.920, 43.360, 56.800, 66.440, 72.920],
                               [0.413, 1.657, 1.091, 1.706, 1.255, 1.269]),
                    'pi': ([0.724, 23.480, 41.720, 55.240, 63.720, 69.160],
                           [0.349, 3.090, 2.197, 1.617, 1.163, 1.689]),
                    'thompson': ([0.120, 14.000, 30.600, 48.760, 60.520, 70.360],
                                 [0.519, 3.067, 2.896, 1.541, 1.378, 0.686]),
                    'ucb': ([0.240, 25.000, 41.920, 56.080, 65.520, 72.720],
                            [0.625, 0.522, 1.690, 0.952, 0.816, 0.500])},
         'smis': {'ei': ([0.920, 22.080, 38.560, 50.400, 58.080, 62.880],
                         [0.640, 1.348, 1.335, 1.565, 1.078, 0.900]),
                  'greedy': ([0.960, 23.640, 40.640, 53.360, 62.640, 68.840],
                             [0.480, 1.439, 1.091, 1.428, 1.155, 1.031]),
                  'pi': ([1.400, 22.360, 39.320, 52.000, 59.920, 64.920],
                         [0.219, 2.888, 1.874, 1.649, 1.269, 1.505]),
                  'thompson': ([0.760, 13.160, 28.440, 45.720, 56.840, 66.040],
                               [0.320, 2.951, 2.877, 1.613, 1.222, 0.686]),
                  'ucb': ([0.920, 23.680, 39.240, 52.560, 61.560, 68.400],
                          [0.574, 0.412, 1.347, 0.916, 0.674, 0.438])}},
 'nn': {'avg': {'ei': ([78.627, 94.222, 97.266, 98.122, 98.761, 99.076],
                       [0.344, 0.329, 0.160, 0.385, 0.081, 0.118]),
                'greedy': ([78.708, 94.617, 97.765, 98.687, 99.129, 99.388],
                           [0.156, 0.297, 0.097, 0.074, 0.066, 0.048]),
                'pi': ([78.537, 94.298, 97.360, 98.191, 98.762, 99.077],
                       [0.168, 0.460, 0.255, 0.531, 0.287, 0.160]),
                'thompson': ([79.186, 93.843, 97.374, 98.559, 99.075, 99.350],
                             [0.345, 0.349, 0.213, 0.103, 0.048, 0.072]),
                'ucb': ([78.919, 94.173, 97.583, 98.633, 99.114, 99.385],
                        [0.236, 0.248, 0.186, 0.045, 0.071, 0.040])},
        'scores': {'ei': ([0.321, 22.200, 40.200, 50.080, 59.920, 66.080],
                          [0.350, 1.939, 1.152, 4.045, 1.478, 2.963]),
                   'greedy': ([0.520, 25.720, 45.600, 58.640, 67.920, 74.760],
                              [0.450, 1.746, 1.437, 1.546, 1.760, 1.127]),
                   'pi': ([0.200, 23.920, 41.960, 52.160, 60.640, 67.200],
                          [0.509, 2.664, 2.330, 7.040, 5.278, 4.002]),
                   'thompson': ([0.961, 21.560, 41.400, 56.320, 66.200, 73.360],
                                [0.599, 2.289, 2.527, 2.088, 1.233, 2.257]),
                   'ucb': ([0.682, 23.120, 44.480, 58.080, 67.120, 74.440],
                           [0.515, 2.042, 1.714, 0.917, 1.929, 1.371])},
        'smis': {'ei': ([0.760, 21.040, 38.120, 47.240, 56.560, 62.200],
                        [0.196, 2.156, 1.237, 4.035, 1.641, 2.912]),
                 'greedy': ([0.800, 24.560, 43.240, 55.360, 63.960, 70.120],
                            [0.253, 1.699, 1.335, 1.439, 1.718, 1.107]),
                 'pi': ([0.840, 22.640, 39.800, 49.360, 57.160, 63.080],
                        [0.265, 2.381, 1.850, 6.306, 4.739, 3.530]),
                 'thompson': ([1.200, 20.560, 39.560, 53.000, 62.200, 68.920],
                              [0.456, 2.214, 2.313, 1.720, 1.403, 2.286]),
                 'ucb': ([0.920, 21.640, 41.520, 54.280, 62.920, 70.040],
                         [0.483, 2.221, 1.866, 1.204, 1.709, 1.189])}},
 'rf': {'avg': {'ei': ([78.750, 91.025, 94.630, 96.294, 97.066, 97.625],
                       [0.194, 1.128, 0.984, 0.699, 0.419, 0.193]),
                'greedy': ([78.896, 93.759, 96.954, 97.977, 98.428, 98.744],
                           [0.356, 0.460, 0.248, 0.173, 0.168, 0.154]),
                'pi': ([78.797, 92.304, 95.889, 96.980, 97.623, 97.924],
                       [0.245, 0.950, 0.263, 0.108, 0.136, 0.151]),
                'thompson': ([78.623, 89.782, 93.667, 95.524, 96.717, 97.487],
                             [0.529, 0.448, 0.268, 0.245, 0.202, 0.226]),
                'ucb': ([78.800, 92.206, 95.625, 96.862, 97.713, 98.160],
                        [0.498, 1.071, 0.428, 0.155, 0.121, 0.106])},
        'scores': {'ei': ([0.439, 11.320, 21.000, 30.320, 36.160, 41.880],
                          [0.837, 3.900, 5.547, 5.814, 4.869, 2.685]),
                   'greedy': ([0.762, 21.200, 37.640, 47.600, 53.600, 59.120],
                              [0.389, 2.605, 2.949, 2.245, 2.577, 2.892]),
                   'pi': ([0.400, 14.720, 28.120, 34.960, 41.760, 45.480],
                          [0.782, 3.971, 1.870, 1.127, 1.704, 2.355]),
                   'thompson': ([0.841, 7.240, 15.960, 24.320, 32.440, 39.840],
                                [0.571, 1.317, 1.405, 2.328, 2.330, 2.949]),
                   'ucb': ([0.760, 14.680, 27.240, 35.160, 43.360, 49.040],
                           [1.340, 4.135, 2.772, 1.359, 1.235, 1.394])},
        'smis': {'ei': ([0.840, 11.040, 20.400, 29.040, 34.600, 40.120],
                        [0.496, 3.908, 5.684, 5.946, 5.055, 2.700]),
                 'greedy': ([1.080, 20.080, 35.440, 44.520, 50.160, 55.080],
                            [0.392, 2.551, 3.018, 2.560, 2.877, 2.968]),
                 'pi': ([0.880, 14.080, 26.960, 33.560, 40.040, 43.360],
                        [0.574, 3.715, 1.822, 1.286, 1.592, 2.189]),
                 'thompson': ([1.160, 6.840, 15.040, 23.000, 30.760, 37.640],
                              [0.196, 1.305, 1.607, 2.356, 2.368, 2.932]),
                 'ucb': ([1.120, 14.080, 26.200, 33.840, 41.760, 46.880],
                         [0.826, 4.015, 2.647, 1.286, 0.814, 1.287])}}}

In [None]:
# Random-acquisition baseline for the HTS_004 dataset. It is keyed by quantity
# ('avg', 'scores', 'smis'), like the other *_random baselines. Each value is
# a pair of 6-element lists, presumably (mean, spread) over repeated runs per
# iteration -- TODO confirm.
HTS_004_random = dict(
    avg=(
        [83.817, 86.459, 87.815, 88.779, 89.488, 90.094],
        [0.137, 0.061, 0.116, 0.112, 0.145, 0.147],
    ),
    scores=(
        [0.440, 0.880, 1.300, 1.800, 2.220, 2.620],
        [0.215, 0.172, 0.219, 0.261, 0.232, 0.147],
    ),
    smis=(
        [0.420, 0.860, 1.260, 1.700, 2.080, 2.420],
        [0.214, 0.174, 0.265, 0.290, 0.172, 0.075],
    ),
)

In [None]:
#top-1000
HTS_004_online = {
 'mpn': {'avg': {'ei': ([83.684, 99.075, 99.483, 99.628, 99.679, 99.712],
                        [0.098, 0.046, 0.029, 0.024, 0.039, 0.049]),
                 'greedy': ([83.934, 99.126, 99.582, 99.758, 99.830, 99.871],
                            [0.101, 0.073, 0.030, 0.028, 0.024, 0.016]),
                 'pi': ([83.815, 99.116, 99.525, 99.620, 99.669, 99.701],
                        [0.247, 0.151, 0.052, 0.054, 0.041, 0.037]),
                 'thompson': ([83.902, 97.755, 99.391, 99.715, 99.822, 99.884],
                              [0.106, 0.210, 0.066, 0.024, 0.007, 0.017]),
                 'ucb': ([83.781, 99.205, 99.678, 99.819, 99.874, 99.913],
                         [0.060, 0.098, 0.037, 0.034, 0.018, 0.013])},
         'scores': {'ei': ([0.500, 68.580, 81.040, 86.680, 88.880, 90.200],
                           [0.276, 1.269, 1.454, 1.026, 1.452, 1.716]),
                    'greedy': ([0.480, 70.540, 86.540, 92.200, 94.140, 95.240],
                               [0.172, 1.752, 0.983, 0.603, 0.595, 0.413]),
                    'pi': ([0.480, 70.440, 83.440, 87.620, 89.540, 90.540],
                           [0.183, 3.158, 1.860, 2.040, 1.389, 1.106]),
                    'thompson': ([0.500, 42.840, 79.520, 91.260, 94.020, 95.780],
                                 [0.179, 3.456, 2.577, 0.768, 0.117, 0.306]),
                    'ucb': ([0.540, 72.380, 89.720, 93.960, 95.460, 96.680],
                            [0.162, 2.153, 1.081, 0.831, 0.535, 0.387])},
         'smis': {'ei': ([0.440, 65.120, 77.060, 82.380, 84.420, 85.720],
                         [0.206, 1.315, 1.204, 0.924, 1.388, 1.683]),
                  'greedy': ([0.420, 67.000, 82.040, 88.560, 91.480, 93.200],
                             [0.133, 1.664, 1.063, 1.080, 0.861, 0.540]),
                  'pi': ([0.420, 66.800, 79.340, 83.380, 85.300, 86.600],
                         [0.160, 3.082, 1.640, 2.004, 1.463, 1.337]),
                  'thompson': ([0.500, 40.700, 75.560, 87.400, 91.480, 93.720],
                               [0.179, 3.130, 2.613, 1.099, 0.515, 0.546]),
                  'ucb': ([0.480, 68.780, 85.220, 90.840, 93.000, 94.500],
                          [0.194, 2.245, 1.237, 1.191, 0.687, 0.352])}},
 'nn': {'avg': {'ei': ([83.784, 97.804, 97.898, 98.614, 98.821, 98.874],
                       [0.276, 0.058, 0.067, 0.193, 0.267, 0.237]),
                'greedy': ([83.707, 97.845, 99.122, 99.599, 99.701, 99.786],
                           [0.062, 0.256, 0.172, 0.067, 0.046, 0.033]),
                'pi': ([83.759, 97.862, 97.896, 98.340, 98.519, 98.583],
                       [0.134, 0.068, 0.080, 0.233, 0.510, 0.552]),
                'thompson': ([83.789, 96.998, 98.501, 98.907, 99.180, 99.275],
                             [0.145, 0.110, 0.619, 0.388, 0.398, 0.399]),
                'ucb': ([83.789, 97.953, 98.888, 99.354, 99.533, 99.693],
                        [0.053, 0.167, 0.415, 0.191, 0.096, 0.032])},
        'scores': {'ei': ([0.540, 43.420, 44.920, 58.800, 63.500, 64.940],
                          [0.196, 0.808, 1.521, 4.155, 6.591, 5.569]),
                   'greedy': ([0.400, 45.760, 72.200, 87.480, 90.860, 93.020],
                              [0.063, 4.216, 4.731, 2.044, 1.188, 0.842]),
                   'pi': ([0.320, 44.920, 45.500, 53.880, 58.420, 60.280],
                          [0.172, 1.455, 1.784, 4.895, 12.256, 14.546]),
                   'thompson': ([0.440, 32.600, 57.960, 66.640, 74.600, 77.500],
                                [0.233, 1.558, 11.872, 10.615, 11.430, 11.338]),
                   'ucb': ([0.380, 46.520, 65.280, 78.360, 84.400, 90.100],
                           [0.223, 2.727, 9.609, 6.412, 3.852, 1.056])},
        'smis': {'ei': ([0.540, 41.320, 42.820, 55.640, 60.220, 61.520],
                        [0.196, 0.631, 1.359, 3.723, 6.355, 5.538]),
                 'greedy': ([0.380, 43.340, 68.300, 82.860, 86.860, 89.760],
                            [0.075, 3.975, 4.689, 2.258, 1.513, 0.954]),
                 'pi': ([0.300, 42.620, 43.160, 50.800, 55.200, 57.040],
                        [0.190, 1.134, 1.604, 4.689, 11.611, 13.848]),
                 'thompson': ([0.400, 31.020, 54.960, 63.260, 70.680, 73.380],
                              [0.253, 1.373, 11.216, 9.827, 10.458, 10.548]),
                 'ucb': ([0.360, 44.160, 62.320, 74.260, 80.180, 86.080],
                         [0.185, 2.548, 9.264, 6.075, 3.677, 1.607])}},
 'rf': {'avg': {'ei': ([83.809, 95.565, 96.119, 97.158, 97.618, 97.927],
                       [0.109, 0.474, 0.711, 0.613, 0.324, 0.098]),
                'greedy': ([83.846, 97.655, 98.781, 99.146, 99.346, 99.455],
                           [0.152, 0.139, 0.107, 0.141, 0.083, 0.050]),
                'pi': ([83.904, 96.265, 96.694, 97.153, 97.593, 97.795],
                       [0.134, 0.646, 0.646, 0.574, 0.484, 0.410]),
                'thompson': ([83.830, 95.328, 97.106, 98.056, 98.502, 98.694],
                             [0.182, 0.240, 0.131, 0.110, 0.115, 0.071]),
                'ucb': ([83.776, 96.838, 97.310, 98.029, 98.443, 98.589],
                        [0.195, 0.375, 0.540, 0.348, 0.115, 0.137])},
        'scores': {'ei': ([0.460, 21.540, 26.260, 35.580, 40.640, 44.980],
                          [0.102, 3.571, 6.112, 6.739, 4.589, 1.759]),
                   'greedy': ([0.420, 40.640, 61.000, 70.380, 76.240, 80.620],
                              [0.160, 1.796, 2.791, 3.816, 3.090, 2.268]),
                   'pi': ([0.260, 26.920, 30.640, 35.260, 40.520, 43.180],
                          [0.102, 5.416, 6.246, 6.260, 5.662, 5.568]),
                   'thompson': ([0.380, 19.380, 35.080, 47.660, 55.380, 59.980],
                                [0.117, 1.408, 1.174, 2.461, 2.475, 2.025]),
                   'ucb': ([0.380, 31.440, 37.160, 46.460, 53.540, 56.360],
                           [0.204, 3.910, 6.637, 6.309, 2.153, 2.579])},
        'smis': {'ei': ([0.440, 20.780, 25.380, 34.300, 39.300, 43.460],
                        [0.102, 3.498, 6.019, 6.710, 4.732, 1.911]),
                 'greedy': ([0.400, 38.700, 58.260, 67.140, 72.520, 76.480],
                            [0.179, 1.628, 2.743, 3.781, 2.815, 2.060]),
                 'pi': ([0.240, 26.020, 29.660, 34.100, 39.180, 41.760],
                        [0.102, 5.202, 6.021, 6.005, 5.364, 5.324]),
                 'thompson': ([0.340, 18.540, 33.600, 45.480, 52.860, 57.100],
                              [0.185, 1.476, 1.170, 2.307, 2.331, 1.833]),
                 'ucb': ([0.360, 30.280, 35.760, 44.580, 51.260, 53.980],
                         [0.174, 3.869, 6.344, 5.977, 1.918, 2.374])}}}

In [None]:
HTS_004_retrain = {
 'mpn': {'avg': {'ei': ([83.904, 99.143, 99.633, 99.791, 99.855, 99.891],
                        [0.087, 0.113, 0.050, 0.046, 0.025, 0.017]),
                 'greedy': ([83.858, 99.129, 99.688, 99.862, 99.921, 99.941],
                            [0.239, 0.107, 0.023, 0.008, 0.008, 0.005]),
                 'pi': ([83.788, 99.074, 99.670, 99.807, 99.864, 99.904],
                        [0.164, 0.210, 0.020, 0.018, 0.019, 0.017]),
                 'thompson': ([83.788, 97.784, 99.410, 99.730, 99.869, 99.920],
                              [0.070, 0.356, 0.032, 0.018, 0.022, 0.013]),
                 'ucb': ([83.813, 99.217, 99.757, 99.885, 99.926, 99.944],
                         [0.065, 0.065, 0.035, 0.012, 0.013, 0.013])},
         'scores': {'ei': ([0.460, 70.260, 86.960, 92.900, 94.840, 95.820],
                           [0.206, 2.702, 1.803, 1.431, 0.768, 0.614]),
                    'greedy': ([0.460, 70.840, 90.480, 94.960, 96.700, 97.680],
                               [0.120, 2.351, 0.652, 0.242, 0.290, 0.232]),
                    'pi': ([0.380, 68.900, 88.720, 93.300, 94.920, 96.160],
                           [0.172, 4.281, 1.144, 0.569, 0.546, 0.575]),
                    'thompson': ([0.320, 43.020, 80.680, 91.660, 95.160, 96.740],
                                 [0.147, 5.601, 1.434, 0.516, 0.496, 0.338]),
                    'ucb': ([0.460, 72.160, 92.120, 95.660, 97.220, 98.080],
                            [0.174, 1.204, 1.034, 0.436, 0.462, 0.421])},
         'smis': {'ei': ([0.440, 67.000, 82.920, 89.540, 92.300, 93.660],
                         [0.206, 2.452, 1.589, 2.083, 1.105, 0.671]),
                  'greedy': ([0.400, 67.300, 86.140, 92.300, 94.380, 94.760],
                             [0.089, 2.348, 0.736, 0.303, 0.098, 0.136]),
                  'pi': ([0.380, 65.440, 84.200, 89.720, 92.220, 93.940],
                         [0.172, 4.109, 0.851, 0.995, 0.933, 0.709]),
                  'thompson': ([0.300, 40.760, 76.340, 87.980, 92.780, 94.500],
                               [0.126, 5.166, 1.493, 0.760, 0.730, 0.126]),
                  'ucb': ([0.440, 68.560, 88.100, 93.240, 94.620, 94.880],
                          [0.185, 1.244, 1.114, 0.233, 0.160, 0.172])}},
 'nn': {'avg': {'ei': ([83.720, 97.690, 98.626, 98.796, 99.139, 99.327],
                       [0.097, 0.089, 0.049, 0.056, 0.028, 0.048]),
                'greedy': ([83.851, 97.667, 99.174, 99.598, 99.813, 99.894],
                           [0.081, 0.207, 0.080, 0.016, 0.002, 0.002]),
                'pi': ([83.843, 97.654, 98.284, 98.661, 98.961, 99.165],
                       [0.115, 0.176, 0.220, 0.196, 0.124, 0.071]),
                'thompson': ([83.851, 97.349, 99.056, 99.580, 99.763, 99.847],
                             [0.035, 0.193, 0.106, 0.047, 0.007, 0.011]),
                'ucb': ([83.838, 98.043, 99.259, 99.618, 99.741, 99.843],
                        [0.117, 0.138, 0.040, 0.029, 0.038, 0.017])},
        'scores': {'ei': ([0.460, 42.160, 59.000, 62.700, 70.767, 76.700],
                          [0.102, 1.428, 0.712, 0.993, 1.112, 2.304]),
                   'greedy': ([0.320, 42.180, 74.175, 88.200, 93.550, 95.700],
                              [0.194, 3.384, 1.875, 0.374, 0.050, 0.000]),
                   'pi': ([0.420, 41.600, 52.325, 60.100, 67.100, 71.875],
                          [0.160, 2.642, 4.471, 4.566, 3.237, 1.906]),
                   'thompson': ([0.520, 36.840, 68.780, 86.575, 92.467, 94.533],
                                [0.172, 2.825, 2.908, 2.200, 0.464, 0.249]),
                   'ucb': ([0.420, 48.100, 75.240, 88.460, 91.780, 94.380],
                           [0.075, 2.396, 1.109, 0.905, 0.983, 0.471])},
        'smis': {'ei': ([0.420, 39.960, 56.067, 59.600, 67.333, 72.900],
                        [0.117, 1.363, 0.713, 0.993, 0.988, 2.223]),
                 'greedy': ([0.320, 40.180, 70.225, 83.375, 90.350, 93.150],
                            [0.194, 3.020, 1.895, 0.630, 0.150, 0.050]),
                 'pi': ([0.340, 39.500, 49.850, 57.350, 64.000, 68.500],
                        [0.120, 2.551, 4.275, 4.542, 3.163, 1.826]),
                 'thompson': ([0.500, 35.140, 65.340, 81.925, 88.500, 91.567],
                              [0.179, 2.932, 2.851, 2.013, 0.616, 0.287]),
                 'ucb': ([0.400, 45.700, 71.260, 83.560, 87.920, 91.480],
                         [0.089, 2.293, 1.007, 0.950, 1.451, 0.783])}},
 'rf': {'avg': {'ei': ([83.961, 96.632, 97.269, 97.540, 97.751, 97.905],
                       [0.090, 0.565, 0.475, 0.363, 0.272, 0.263]),
                'greedy': ([83.898, 97.481, 98.662, 99.215, 99.414, 99.532],
                           [0.066, 0.332, 0.078, 0.060, 0.047, 0.024]),
                'pi': ([83.836, 96.361, 97.050, 97.435, 97.645, 97.798],
                       [0.143, 0.725, 0.379, 0.182, 0.192, 0.184]),
                'thompson': ([83.842, 95.147, 97.833, 98.700, 99.055, 99.256],
                             [0.145, 0.439, 0.108, 0.034, 0.041, 0.035]),
                'ucb': ([83.876, 96.993, 98.350, 98.723, 98.884, 99.033],
                        [0.197, 0.393, 0.171, 0.158, 0.137, 0.100])},
        'scores': {'ei': ([0.680, 30.020, 36.180, 39.660, 42.420, 44.840],
                          [0.172, 5.605, 5.639, 4.786, 3.899, 3.979]),
                   'greedy': ([0.380, 38.380, 58.060, 72.000, 79.420, 84.300],
                              [0.117, 4.407, 2.170, 1.601, 1.607, 1.108]),
                   'pi': ([0.300, 28.920, 34.640, 38.860, 41.440, 43.500],
                          [0.141, 5.284, 3.733, 2.217, 2.391, 2.601]),
                   'thompson': ([0.420, 18.860, 43.260, 59.740, 68.300, 74.080],
                                [0.117, 3.526, 1.904, 0.796, 0.974, 1.040]),
                   'ucb': ([0.300, 33.220, 52.220, 60.080, 64.060, 68.160],
                           [0.167, 4.175, 3.466, 4.154, 3.682, 2.738])},
        'smis': {'ei': ([0.640, 29.040, 34.940, 38.300, 40.940, 43.220],
                        [0.162, 5.434, 5.494, 4.627, 3.815, 3.860]),
                 'greedy': ([0.340, 36.700, 55.400, 68.520, 75.340, 79.800],
                            [0.102, 4.283, 1.869, 1.474, 1.553, 0.867]),
                 'pi': ([0.280, 28.000, 33.660, 37.700, 40.200, 42.200],
                        [0.147, 5.164, 3.697, 2.136, 2.393, 2.668]),
                 'thompson': ([0.380, 18.160, 41.340, 56.840, 64.920, 70.340],
                              [0.160, 3.615, 1.930, 0.907, 1.148, 1.222]),
                 'ucb': ([0.300, 32.120, 50.380, 57.680, 61.380, 65.160],
                         [0.167, 4.023, 3.323, 3.999, 3.490, 2.628])}}}

In [None]:
HTS_002_random = {
    'avg': ([80.697, 83.715, 85.348, 86.373, 87.173, 87.751],
            [0.234, 0.153, 0.163, 0.179, 0.175, 0.144]),
    'scores': ([0.120, 0.340, 0.660, 0.840, 1.140, 1.340],
               [0.075, 0.102, 0.102, 0.185, 0.301, 0.383]),
    'smis': ([0.100, 0.320, 0.640, 0.800, 1.100, 1.300],
             [0.063, 0.098, 0.120, 0.155, 0.261, 0.341]),
}

In [None]:
HTS_002_online = {
 'mpn': {'avg': {'ei': ([80.955, 98.023, 99.010, 99.327, 99.417, 99.484],
                        [0.165, 0.166, 0.087, 0.103, 0.082, 0.074]),
                 'greedy': ([80.828, 97.961, 98.929, 99.364, 99.522, 99.619],
                            [0.091, 0.206, 0.060, 0.052, 0.028, 0.043]),
                 'pi': ([80.777, 98.000, 98.989, 99.277, 99.376, 99.413],
                        [0.195, 0.243, 0.094, 0.084, 0.087, 0.084]),
                 'thompson': ([80.899, 95.969, 98.664, 99.335, 99.527, 99.646],
                              [0.265, 0.255, 0.138, 0.047, 0.026, 0.013]),
                 'ucb': ([80.732, 98.118, 99.030, 99.444, 99.560, 99.665],
                         [0.228, 0.117, 0.106, 0.069, 0.044, 0.016])},
         'scores': {'ei': ([0.220, 45.800, 65.220, 74.560, 77.780, 80.520],
                           [0.133, 2.781, 2.374, 3.102, 3.062, 3.081]),
                    'greedy': ([0.120, 46.280, 65.700, 79.000, 85.040, 88.740],
                               [0.075, 3.178, 1.431, 1.519, 0.723, 1.306]),
                    'pi': ([0.180, 46.160, 65.620, 72.700, 76.200, 77.760],
                           [0.117, 3.992, 2.035, 2.373, 2.768, 2.826]),
                    'thompson': ([0.320, 25.060, 59.180, 77.300, 84.840, 89.280],
                                 [0.183, 1.810, 2.112, 1.119, 0.869, 0.299]),
                    'ucb': ([0.180, 48.420, 67.040, 80.720, 85.740, 89.500],
                            [0.075, 1.901, 2.284, 2.248, 1.493, 0.341])},
         'smis': {'ei': ([0.220, 44.060, 62.600, 71.340, 74.280, 76.800],
                         [0.133, 2.619, 2.224, 2.895, 2.948, 2.887]),
                  'greedy': ([0.120, 44.340, 62.600, 75.060, 80.640, 84.120],
                             [0.075, 2.889, 1.178, 1.311, 0.695, 1.477]),
                  'pi': ([0.180, 44.440, 62.860, 69.620, 72.900, 74.300],
                         [0.117, 3.811, 1.946, 2.176, 2.651, 2.653]),
                  'thompson': ([0.300, 23.960, 56.420, 73.500, 80.520, 84.860],
                               [0.167, 1.765, 2.007, 1.071, 0.652, 0.393]),
                  'ucb': ([0.160, 46.380, 64.040, 76.760, 81.420, 85.080],
                          [0.080, 1.604, 2.095, 2.201, 1.379, 0.564])}},
 'nn': {'avg': {'ei': ([80.669, 96.461, 96.763, 97.207, 97.426, 97.494],
                       [0.070, 0.091, 0.430, 0.890, 0.753, 0.795]),
                'greedy': ([80.792, 96.592, 98.399, 99.082, 99.370, 99.473],
                           [0.190, 0.522, 0.277, 0.109, 0.033, 0.028]),
                'pi': ([80.799, 96.474, 96.690, 97.178, 97.315, 97.458],
                       [0.139, 0.161, 0.302, 0.651, 0.558, 0.484]),
                'thompson': ([80.914, 95.463, 97.716, 98.516, 98.967, 99.189],
                             [0.144, 0.079, 0.758, 0.390, 0.253, 0.145]),
                'ucb': ([80.666, 96.664, 98.169, 98.655, 99.033, 99.132],
                        [0.158, 0.253, 0.628, 0.581, 0.325, 0.380])},
        'scores': {'ei': ([0.160, 28.560, 31.840, 37.980, 39.820, 40.940],
                          [0.049, 1.122, 4.955, 12.077, 10.884, 11.675]),
                   'greedy': ([0.200, 29.720, 54.240, 69.540, 78.200, 82.200],
                              [0.110, 5.704, 4.350, 2.277, 0.901, 0.762]),
                   'pi': ([0.180, 28.420, 30.740, 36.460, 37.720, 39.320],
                          [0.117, 1.342, 3.409, 8.817, 7.923, 7.133]),
                   'thompson': ([0.280, 20.600, 43.320, 56.540, 66.480, 72.480],
                                [0.183, 0.648, 9.902, 7.893, 6.285, 4.160]),
                   'ucb': ([0.220, 30.360, 50.260, 60.000, 67.740, 70.720],
                           [0.075, 3.484, 8.661, 9.605, 6.738, 8.550])},
        'smis': {'ei': ([0.140, 27.100, 30.360, 36.080, 37.840, 38.920],
                        [0.049, 0.914, 4.894, 11.509, 10.336, 11.124]),
                 'greedy': ([0.200, 28.140, 51.720, 66.340, 74.380, 78.060],
                            [0.110, 5.412, 4.293, 2.104, 0.877, 0.605]),
                 'pi': ([0.180, 27.020, 29.280, 34.700, 35.860, 37.380],
                        [0.117, 1.335, 3.323, 8.474, 7.672, 6.940]),
                 'thompson': ([0.240, 19.620, 41.460, 54.040, 63.460, 68.920],
                              [0.196, 0.435, 9.593, 7.510, 5.951, 3.854]),
                 'ucb': ([0.220, 29.020, 48.540, 57.800, 64.980, 67.860],
                         [0.075, 3.482, 8.569, 9.421, 6.631, 8.332])}},
 'rf': {'avg': {'ei': ([80.879, 94.098, 94.503, 95.673, 96.228, 96.462],
                       [0.133, 0.810, 0.549, 0.501, 0.638, 0.748]),
                'greedy': ([80.919, 96.169, 97.852, 98.565, 98.848, 99.012],
                           [0.170, 0.563, 0.377, 0.019, 0.090, 0.111]),
                'pi': ([80.847, 94.119, 94.628, 95.738, 96.478, 96.793],
                       [0.101, 1.029, 0.985, 0.918, 0.629, 0.600]),
                'thompson': ([80.707, 92.112, 94.474, 96.130, 97.011, 97.386],
                             [0.157, 0.195, 0.783, 0.684, 0.432, 0.312]),
                'ucb': ([80.739, 93.947, 95.344, 96.841, 97.265, 97.938],
                        [0.177, 3.115, 1.934, 0.972, 0.678, 0.092])},
        'scores': {'ei': ([0.220, 15.940, 17.560, 24.080, 27.960, 30.020],
                          [0.147, 4.237, 3.593, 3.756, 5.407, 7.074]),
                   'greedy': ([0.180, 26.780, 45.000, 56.200, 62.780, 66.920],
                              [0.117, 4.782, 5.383, 0.438, 2.155, 2.403]),
                   'pi': ([0.120, 15.260, 17.980, 23.900, 29.360, 32.300],
                          [0.098, 5.091, 4.810, 5.677, 4.980, 5.495]),
                   'thompson': ([0.200, 8.080, 17.400, 26.960, 34.800, 38.460],
                                [0.089, 0.770, 4.023, 5.549, 4.827, 4.257]),
                   'ucb': ([0.280, 18.680, 23.980, 33.360, 37.420, 45.760],
                           [0.133, 8.614, 9.668, 9.174, 7.951, 1.551])},
        'smis': {'ei': ([0.220, 15.420, 16.980, 23.240, 26.960, 28.960],
                        [0.147, 4.304, 3.711, 3.821, 5.478, 7.107]),
                 'greedy': ([0.180, 25.540, 43.240, 53.740, 60.080, 64.000],
                            [0.117, 4.558, 5.205, 0.546, 2.218, 2.370]),
                 'pi': ([0.100, 14.600, 17.320, 22.940, 28.260, 31.100],
                        [0.089, 5.004, 4.667, 5.295, 4.738, 5.187]),
                 'thompson': ([0.180, 7.680, 16.700, 25.820, 33.420, 36.940],
                              [0.075, 0.747, 3.888, 5.348, 4.574, 3.999]),
                 'ucb': ([0.260, 18.040, 23.200, 32.240, 36.080, 44.080],
                         [0.136, 8.520, 9.573, 9.059, 7.832, 1.427])}}}

In [None]:
HTS_002_retrain = {
 'mpn': {'avg': {'ei': ([80.939, 97.785, 99.102, 99.521, 99.655, 99.744],
                        [0.252, 0.273, 0.059, 0.021, 0.026, 0.014]),
                 'greedy': ([80.812, 98.144, 99.252, 99.566, 99.717, 99.814],
                            [0.130, 0.032, 0.057, 0.063, 0.043, 0.035]),
                 'pi': ([80.854, 97.914, 99.211, 99.550, 99.675, 99.759],
                        [0.180, 0.191, 0.092, 0.028, 0.019, 0.027]),
                 'thompson': ([80.785, 95.529, 98.555, 99.278, 99.550, 99.714],
                              [0.166, 0.962, 0.163, 0.033, 0.028, 0.014]),
                 'ucb': ([80.823, 98.160, 99.313, 99.634, 99.761, 99.836],
                         [0.134, 0.192, 0.037, 0.017, 0.014, 0.015])},
         'scores': {'ei': ([0.200, 43.380, 68.420, 82.060, 88.060, 91.680],
                           [0.063, 4.464, 1.388, 0.977, 1.076, 0.662]),
                    'greedy': ([0.220, 49.280, 73.940, 86.280, 91.180, 93.700],
                               [0.040, 0.571, 1.630, 1.885, 1.105, 0.892]),
                    'pi': ([0.220, 44.900, 70.720, 82.900, 88.280, 92.100],
                           [0.075, 2.575, 2.507, 1.495, 1.034, 0.707]),
                    'thompson': ([0.160, 23.300, 56.840, 75.140, 85.600, 90.760],
                                 [0.102, 5.548, 2.957, 1.125, 0.844, 0.418]),
                    'ucb': ([0.220, 49.160, 74.980, 87.680, 92.200, 94.260],
                            [0.075, 3.230, 1.254, 0.556, 0.456, 0.393])},
         'smis': {'ei': ([0.180, 41.700, 65.360, 78.340, 83.760, 87.380],
                         [0.075, 4.346, 1.392, 1.001, 1.177, 0.655]),
                  'greedy': ([0.220, 47.240, 70.320, 81.760, 87.240, 90.720],
                             [0.040, 0.615, 1.620, 1.940, 1.439, 0.950]),
                  'pi': ([0.200, 43.300, 67.920, 79.220, 84.160, 87.960],
                         [0.089, 2.377, 2.411, 1.399, 1.009, 0.931]),
                  'thompson': ([0.160, 22.220, 54.300, 71.340, 81.060, 86.740],
                               [0.102, 5.284, 2.669, 1.237, 0.723, 0.554]),
                  'ucb': ([0.220, 47.120, 71.520, 83.360, 88.520, 91.560],
                          [0.075, 3.191, 1.238, 0.320, 0.504, 0.476])}},
 'nn': {'avg': {'ei': ([80.988, 96.399, 97.699, 98.080, 98.414, 98.537],
                       [0.103, 0.202, 0.226, 0.238, 0.238, 0.228]),
                'greedy': ([80.730, 96.734, 98.570, 99.132, 99.447, 99.627],
                           [0.120, 0.258, 0.110, 0.084, 0.025, 0.027]),
                'pi': ([80.920, 96.366, 97.548, 98.034, 98.411, 98.674],
                       [0.124, 0.229, 0.280, 0.160, 0.178, 0.129]),
                'thompson': ([80.778, 95.606, 98.015, 98.892, 99.359, 99.532],
                             [0.194, 0.215, 0.102, 0.066, 0.024, 0.026]),
                'ucb': ([81.034, 96.737, 98.624, 99.203, 99.463, 99.586],
                        [0.134, 0.208, 0.121, 0.076, 0.057, 0.015])},
        'scores': {'ei': ([0.260, 27.560, 42.180, 48.120, 54.140, 56.640],
                          [0.196, 2.001, 3.159, 3.610, 4.268, 4.340]),
                   'greedy': ([0.300, 31.020, 58.380, 72.180, 82.600, 88.800],
                              [0.228, 2.805, 2.262, 2.009, 0.607, 0.751]),
                   'pi': ([0.260, 27.920, 40.300, 47.460, 53.560, 59.120],
                          [0.162, 2.212, 3.423, 2.451, 2.899, 3.115]),
                   'thompson': ([0.240, 21.580, 47.140, 65.300, 78.420, 85.040],
                                [0.150, 1.777, 1.617, 1.962, 1.333, 0.926]),
                   'ucb': ([0.200, 31.200, 58.640, 72.900, 81.940, 86.740],
                           [0.126, 2.427, 2.127, 1.874, 1.975, 0.543])},
        'smis': {'ei': ([0.260, 26.340, 40.380, 46.080, 51.700, 54.020],
                        [0.196, 2.150, 3.338, 3.633, 4.026, 4.074]),
                 'greedy': ([0.280, 29.400, 55.460, 68.440, 78.020, 83.920],
                            [0.223, 2.630, 2.310, 1.732, 0.511, 0.776]),
                 'pi': ([0.220, 26.560, 38.420, 45.300, 51.240, 56.600],
                        [0.172, 2.133, 3.101, 2.195, 2.681, 3.027]),
                 'thompson': ([0.240, 20.500, 44.940, 61.840, 74.200, 80.360],
                              [0.150, 1.872, 1.416, 2.154, 1.252, 0.887]),
                 'ucb': ([0.200, 29.720, 55.980, 69.480, 77.880, 82.140],
                         [0.126, 2.404, 2.113, 1.712, 1.766, 0.561])}},
 'rf': {'avg': {'ei': ([80.728, 94.054, 95.671, 96.353, 96.601, 96.865],
                       [0.150, 0.688, 0.552, 0.349, 0.376, 0.281]),
                'greedy': ([80.845, 95.389, 97.640, 98.471, 98.971, 99.231],
                           [0.146, 0.679, 0.287, 0.184, 0.104, 0.075]),
                'pi': ([80.847, 94.176, 95.481, 96.188, 96.347, 96.536],
                       [0.097, 0.527, 0.776, 0.634, 0.584, 0.522]),
                'thompson': ([80.946, 92.597, 96.079, 97.651, 98.230, 98.599],
                             [0.177, 0.375, 0.300, 0.059, 0.063, 0.048]),
                'ucb': ([80.848, 94.674, 97.064, 97.754, 98.019, 98.253],
                        [0.110, 0.521, 0.137, 0.092, 0.143, 0.151])},
        'scores': {'ei': ([0.160, 15.440, 23.760, 28.480, 30.620, 32.560],
                          [0.120, 3.469, 5.234, 3.417, 3.804, 3.064]),
                   'greedy': ([0.180, 20.940, 40.440, 54.280, 65.620, 72.280],
                              [0.075, 4.852, 3.657, 3.396, 2.386, 1.912]),
                   'pi': ([0.220, 15.740, 21.960, 26.760, 27.580, 29.280],
                          [0.075, 2.024, 4.126, 5.073, 5.028, 5.150]),
                   'thompson': ([0.340, 9.360, 26.020, 41.640, 50.580, 57.480],
                                [0.102, 1.286, 2.322, 0.731, 1.237, 1.353]),
                   'ucb': ([0.240, 18.800, 34.760, 43.080, 46.980, 50.980],
                           [0.102, 3.536, 1.499, 1.248, 2.698, 2.870])},
        'smis': {'ei': ([0.140, 14.740, 22.800, 27.360, 29.380, 31.280],
                        [0.102, 3.495, 5.246, 3.402, 3.771, 3.024]),
                 'greedy': ([0.180, 20.180, 38.800, 51.920, 62.720, 69.000],
                            [0.075, 4.710, 3.538, 3.276, 2.415, 1.889]),
                 'pi': ([0.220, 15.160, 21.240, 25.920, 26.740, 28.340],
                        [0.075, 1.712, 3.859, 4.882, 4.843, 4.987]),
                 'thompson': ([0.340, 8.940, 25.080, 40.000, 48.440, 54.840],
                              [0.102, 1.157, 2.203, 0.994, 1.375, 1.488]),
                 'ucb': ([0.220, 18.260, 33.560, 41.420, 45.140, 48.860],
                         [0.098, 3.729, 1.661, 1.280, 2.818, 2.892])}}}

In [None]:
HTS_001_random = {
    'avg': ([77.065, 80.761, 82.523, 83.792, 84.702, 85.363],
            [0.203, 0.176, 0.119, 0.111, 0.111, 0.108]),
    'scores': ([0.040, 0.120, 0.200, 0.300, 0.440, 0.600],
               [0.080, 0.147, 0.110, 0.200, 0.206, 0.219]),
    'smis': ([0.040, 0.120, 0.180, 0.240, 0.380, 0.520],
             [0.080, 0.147, 0.117, 0.136, 0.172, 0.172]),
}

In [None]:
HTS_001_online = {'mpn': {'avg': {'ei': ([77.011, 95.719, 97.870, 98.487, 98.804, 98.951],
                        [0.322, 0.385, 0.107, 0.105, 0.049, 0.061]),
                 'greedy': ([77.119, 95.799, 97.538, 98.648, 98.920, 99.168],
                            [0.170, 0.325, 0.241, 0.134, 0.181, 0.114]),
                 'pi': ([76.971, 96.051, 97.963, 98.569, 98.829, 98.972],
                        [0.173, 0.363, 0.151, 0.075, 0.086, 0.092]),
                 'thompson': ([76.989, 93.091, 97.319, 98.531, 98.990, 99.241],
                              [0.127, 0.342, 0.215, 0.149, 0.094, 0.034]),
                 'ucb': ([77.228, 96.095, 98.084, 98.855, 99.093, 99.327],
                         [0.071, 0.250, 0.192, 0.058, 0.073, 0.053])},
         'scores': {'ei': ([0.140, 26.500, 44.200, 53.640, 59.940, 63.800],
                           [0.049, 2.754, 1.472, 1.879, 1.237, 1.374]),
                    'greedy': ([0.080, 26.580, 40.760, 58.380, 65.260, 71.900],
                               [0.075, 2.431, 2.600, 2.795, 4.731, 3.341]),
                    'pi': ([0.060, 29.180, 45.760, 55.240, 60.740, 64.340],
                           [0.120, 2.161, 2.019, 1.120, 1.909, 1.768]),
                    'thompson': ([0.020, 12.600, 38.020, 56.140, 66.640, 73.340],
                                 [0.040, 1.568, 2.121, 2.667, 2.043, 0.960]),
                    'ucb': ([0.080, 28.940, 47.700, 62.240, 68.740, 75.780],
                            [0.075, 1.714, 2.711, 1.479, 1.995, 2.118])},
         'smis': {'ei': ([0.140, 25.800, 42.580, 51.540, 57.620, 61.320],
                         [0.049, 2.725, 1.422, 2.058, 1.359, 1.541]),
                  'greedy': ([0.080, 25.700, 39.260, 55.860, 62.260, 68.400],
                             [0.075, 2.347, 2.418, 2.617, 4.408, 3.455]),
                  'pi': ([0.060, 28.320, 44.080, 53.240, 58.500, 61.920],
                         [0.120, 2.130, 2.060, 1.098, 1.806, 1.775]),
                  'thompson': ([0.020, 12.280, 36.500, 53.720, 63.700, 70.040],
                               [0.040, 1.607, 2.013, 2.552, 2.131, 1.227]),
                  'ucb': ([0.080, 28.260, 45.980, 59.680, 65.900, 72.360],
                          [0.075, 1.664, 2.627, 1.393, 1.925, 1.804])}},
 'nn': {'avg': {'ei': ([77.223, 94.523, 94.856, 95.303, 95.696, 96.225],
                       [0.242, 0.313, 0.811, 1.265, 1.045, 0.978]),
                'greedy': ([77.295, 94.600, 97.088, 98.146, 98.702, 98.956],
                           [0.293, 0.494, 0.283, 0.147, 0.079, 0.087]),
                'pi': ([77.224, 94.173, 95.544, 97.069, 97.277, 97.387],
                       [0.142, 0.196, 0.636, 0.340, 0.179, 0.299]),
                'thompson': ([77.048, 93.544, 95.970, 96.949, 97.639, 97.898],
                             [0.180, 0.305, 1.197, 1.009, 0.531, 0.619]),
                'ucb': ([76.843, 94.479, 96.181, 96.965, 97.413, 97.903],
                        [0.188, 0.436, 1.259, 1.119, 0.994, 0.624])},
        'scores': {'ei': ([0.060, 17.320, 19.520, 22.480, 23.940, 27.400],
                          [0.049, 2.359, 5.685, 9.806, 9.066, 9.594]),
                   'greedy': ([0.160, 17.740, 35.400, 49.220, 59.540, 65.840],
                              [0.080, 3.028, 3.051, 2.116, 1.714, 2.226]),
                   'pi': ([0.180, 16.100, 23.960, 35.240, 37.340, 38.800],
                          [0.133, 0.853, 3.759, 3.433, 2.030, 3.569]),
                   'thompson': ([0.180, 13.580, 27.520, 35.360, 42.140, 46.100],
                                [0.117, 1.250, 6.867, 8.682, 7.335, 9.103]),
                   'ucb': ([0.040, 17.700, 29.480, 35.940, 40.740, 45.780],
                           [0.049, 2.784, 8.671, 10.293, 11.121, 8.874])},
        'smis': {'ei': ([0.060, 16.520, 18.680, 21.540, 22.880, 26.180],
                        [0.049, 2.301, 5.470, 9.528, 8.871, 9.425]),
                 'greedy': ([0.140, 16.780, 33.860, 47.220, 57.020, 63.180],
                            [0.080, 2.806, 2.958, 2.167, 1.577, 1.951]),
                 'pi': ([0.160, 15.300, 22.980, 33.860, 35.880, 37.320],
                        [0.136, 0.927, 3.645, 3.262, 1.790, 3.337]),
                 'thompson': ([0.160, 12.780, 26.400, 33.860, 40.500, 44.300],
                              [0.102, 1.329, 6.647, 8.350, 7.025, 8.794]),
                 'ucb': ([0.040, 16.940, 28.360, 34.480, 39.140, 43.860],
                         [0.049, 2.750, 8.468, 10.064, 10.847, 8.647])}},
 'rf': {'avg': {'ei': ([77.188, 90.755, 91.591, 93.206, 94.276, 94.899],
                       [0.121, 1.514, 1.641, 2.137, 1.553, 1.293]),
                'greedy': ([77.047, 93.235, 95.601, 96.695, 97.141, 97.598],
                           [0.115, 0.876, 0.358, 0.323, 0.308, 0.488]),
                'pi': ([77.085, 91.782, 93.112, 94.547, 95.315, 96.054],
                       [0.181, 0.882, 0.868, 0.658, 1.139, 0.852]),
                'thompson': ([77.190, 88.720, 91.614, 93.629, 94.869, 95.634],
                             [0.179, 0.460, 0.959, 0.822, 0.369, 0.329]),
                'ucb': ([77.003, 91.568, 93.413, 94.253, 95.221, 95.837],
                        [0.185, 1.298, 1.391, 1.428, 1.016, 0.906])},
        'scores': {'ei': ([0.140, 7.760, 9.700, 14.340, 17.880, 20.200],
                          [0.102, 3.388, 4.380, 6.545, 5.657, 5.511]),
                   'greedy': ([0.020, 13.200, 24.240, 31.340, 35.380, 40.980],
                              [0.040, 3.514, 3.463, 3.787, 3.727, 6.469]),
                   'pi': ([0.140, 10.020, 13.680, 18.880, 22.940, 27.120],
                          [0.136, 2.641, 2.846, 3.387, 6.511, 5.721]),
                   'thompson': ([0.160, 3.860, 9.680, 15.380, 20.380, 24.120],
                                [0.120, 0.786, 2.691, 3.413, 1.957, 1.629]),
                   'ucb': ([0.080, 10.000, 15.740, 18.500, 22.580, 26.240],
                           [0.040, 4.964, 5.188, 5.964, 6.033, 6.776])},
        'smis': {'ei': ([0.140, 7.440, 9.340, 13.920, 17.380, 19.580],
                        [0.102, 3.335, 4.309, 6.387, 5.450, 5.247]),
                 'greedy': ([0.020, 12.500, 23.360, 30.120, 34.100, 39.480],
                            [0.040, 3.416, 3.539, 3.868, 3.597, 6.155]),
                 'pi': ([0.140, 9.560, 13.140, 18.300, 22.320, 26.400],
                        [0.136, 2.546, 2.773, 3.372, 6.471, 5.628]),
                 'thompson': ([0.120, 3.680, 9.260, 14.800, 19.560, 23.120],
                              [0.098, 0.763, 2.515, 3.247, 1.847, 1.522]),
                 'ucb': ([0.080, 9.580, 15.140, 17.820, 21.860, 25.420],
                         [0.040, 4.803, 5.095, 5.837, 5.854, 6.569])}}}

In [None]:
HTS_001_retrain = {
 'mpn': {'avg': {'ei': ([77.087, 95.947, 97.954, 98.876, 99.213, 99.423],
                        [0.138, 0.431, 0.192, 0.055, 0.064, 0.026]),
                 'greedy': ([77.241, 96.071, 98.266, 98.916, 99.257, 99.470],
                            [0.117, 0.270, 0.115, 0.062, 0.022, 0.030]),
                 'pi': ([77.132, 95.993, 98.123, 98.830, 99.187, 99.407],
                        [0.158, 0.593, 0.093, 0.051, 0.041, 0.016]),
                 'thompson': ([76.906, 93.051, 97.070, 98.357, 98.923, 99.271],
                              [0.194, 0.523, 0.401, 0.177, 0.134, 0.035]),
                 'ucb': ([77.163, 95.855, 98.319, 99.052, 99.375, 99.536],
                         [0.179, 0.444, 0.129, 0.049, 0.031, 0.025])},
         'scores': {'ei': ([0.100, 28.080, 45.440, 61.580, 70.240, 77.000],
                           [0.063, 3.226, 2.664, 1.326, 1.360, 1.246]),
                    'greedy': ([0.120, 29.440, 51.540, 65.240, 74.660, 82.220],
                               [0.075, 1.623, 1.798, 1.372, 0.852, 0.928]),
                    'pi': ([0.120, 28.540, 47.440, 60.660, 69.780, 76.100],
                           [0.075, 4.280, 1.728, 1.469, 1.423, 0.867]),
                    'thompson': ([0.100, 12.580, 35.340, 52.980, 65.440, 74.800],
                                 [0.000, 2.104, 4.061, 3.383, 3.447, 1.862]),
                    'ucb': ([0.200, 27.260, 51.220, 66.920, 77.160, 83.980],
                            [0.167, 3.342, 2.420, 1.382, 1.084, 0.801])},
         'smis': {'ei': ([0.100, 27.240, 43.800, 59.240, 67.580, 73.880],
                         [0.063, 3.013, 2.598, 1.385, 1.367, 1.182]),
                  'greedy': ([0.120, 28.480, 49.380, 62.140, 70.940, 78.100],
                             [0.075, 1.476, 1.761, 1.255, 0.905, 0.829]),
                  'pi': ([0.120, 27.640, 45.600, 58.220, 66.940, 73.000],
                         [0.075, 4.101, 1.575, 1.595, 1.407, 0.746]),
                  'thompson': ([0.100, 12.200, 34.040, 50.680, 62.460, 71.200],
                               [0.000, 2.066, 3.568, 3.149, 3.231, 1.793]),
                  'ucb': ([0.200, 26.720, 49.320, 64.240, 73.680, 79.920],
                          [0.167, 3.161, 2.278, 1.322, 0.999, 0.708])}},
 'nn': {'avg': {'ei': ([77.132, 94.567, 96.391, 96.959, 97.412, 97.737],
                       [0.276, 0.295, 0.588, 0.619, 0.420, 0.287]),
                'greedy': ([76.881, 94.686, 97.377, 98.306, 98.775, 99.075],
                           [0.335, 0.513, 0.195, 0.132, 0.078, 0.085]),
                'pi': ([76.955, 94.528, 96.593, 97.227, 97.551, 97.935],
                       [0.157, 0.156, 0.491, 0.153, 0.163, 0.169]),
                'thompson': ([76.912, 93.766, 97.011, 98.079, 98.719, 99.053],
                             [0.189, 0.335, 0.196, 0.082, 0.054, 0.027]),
                'ucb': ([77.051, 95.030, 97.167, 98.245, 98.781, 99.044],
                        [0.209, 0.296, 0.363, 0.099, 0.070, 0.047])},
        'scores': {'ei': ([0.140, 17.920, 29.200, 34.320, 39.260, 43.300],
                          [0.150, 1.560, 4.996, 5.578, 4.445, 3.823]),
                   'greedy': ([0.160, 17.960, 39.060, 53.020, 63.020, 70.500],
                              [0.150, 3.215, 2.786, 2.316, 1.766, 1.785]),
                   'pi': ([0.100, 18.040, 31.480, 37.220, 41.060, 46.260],
                          [0.089, 1.017, 3.046, 1.653, 1.966, 2.276]),
                   'thompson': ([0.140, 14.440, 34.160, 48.080, 60.020, 67.960],
                                [0.102, 1.960, 2.741, 1.105, 1.120, 0.794]),
                   'ucb': ([0.120, 20.640, 36.300, 50.760, 61.460, 68.040],
                           [0.075, 2.246, 3.511, 1.652, 1.669, 0.911])},
        'smis': {'ei': ([0.120, 17.120, 27.960, 32.940, 37.680, 41.480],
                        [0.160, 1.548, 4.763, 5.262, 4.086, 3.531]),
                 'greedy': ([0.160, 16.860, 36.880, 50.400, 59.940, 66.880],
                            [0.150, 3.140, 2.595, 2.140, 1.638, 1.623]),
                 'pi': ([0.100, 17.060, 30.020, 35.460, 39.220, 44.180],
                        [0.089, 1.013, 2.908, 1.661, 1.873, 2.180]),
                 'thompson': ([0.140, 13.760, 32.640, 45.980, 57.320, 64.760],
                              [0.102, 1.927, 2.587, 1.053, 1.048, 0.750]),
                 'ucb': ([0.120, 19.660, 34.660, 48.620, 58.780, 64.840],
                         [0.075, 2.207, 3.383, 1.787, 1.845, 1.169])}},
 'rf': {'avg': {'ei': ([76.999, 89.700, 93.547, 94.633, 95.232, 95.616],
                       [0.209, 1.906, 0.796, 0.517, 0.566, 0.369]),
                'greedy': ([77.075, 93.313, 96.596, 97.451, 97.976, 98.540],
                           [0.239, 0.623, 0.171, 0.171, 0.068, 0.244]),
                'pi': ([77.111, 92.042, 94.651, 95.478, 95.927, 96.274],
                       [0.183, 1.109, 0.335, 0.548, 0.319, 0.367]),
                'thompson': ([77.073, 88.705, 93.618, 95.832, 96.847, 97.429],
                             [0.249, 0.226, 0.415, 0.222, 0.142, 0.176]),
                'ucb': ([77.273, 89.118, 94.814, 96.101, 96.935, 97.155],
                        [0.189, 4.484, 0.751, 0.431, 0.381, 0.419])},
        'scores': {'ei': ([0.060, 4.540, 14.020, 18.120, 20.760, 22.600],
                          [0.080, 2.620, 3.558, 3.170, 3.693, 2.744]),
                   'greedy': ([0.160, 14.100, 30.740, 39.000, 45.860, 55.800],
                              [0.102, 2.609, 1.515, 2.352, 1.701, 4.894]),
                   'pi': ([0.120, 10.600, 18.620, 22.720, 25.340, 27.920],
                          [0.098, 4.363, 2.107, 3.421, 2.072, 2.697]),
                   'thompson': ([0.100, 3.760, 15.140, 25.200, 32.660, 38.760],
                                [0.063, 0.882, 2.148, 1.858, 1.674, 2.483]),
                   'ucb': ([0.100, 6.900, 20.820, 27.400, 34.000, 36.160],
                           [0.089, 3.401, 3.022, 3.666, 3.795, 4.247])},
        'smis': {'ei': ([0.060, 4.380, 13.540, 17.480, 20.020, 21.860],
                        [0.080, 2.529, 3.391, 3.112, 3.570, 2.603]),
                 'greedy': ([0.120, 13.560, 29.400, 37.300, 43.800, 53.420],
                            [0.117, 2.372, 1.432, 2.249, 1.390, 4.640]),
                 'pi': ([0.120, 10.260, 18.000, 22.000, 24.540, 27.100],
                        [0.098, 4.414, 2.183, 3.362, 2.088, 2.739]),
                 'thompson': ([0.080, 3.700, 14.680, 24.440, 31.440, 37.100],
                              [0.075, 0.888, 2.134, 1.868, 1.765, 2.552]),
                 'ucb': ([0.100, 6.660, 20.060, 26.360, 32.820, 34.940],
                         [0.089, 3.313, 2.904, 3.588, 3.798, 4.206])}}}

In [None]:
# top-1000, replicate of 3
HTS_02_004 = {
 'mpn': {'avg': {'greedy': ([89.441, 99.326],
                            [0.082, 0.078])},
         'scores': {'greedy': ([2.067, 77.367],
                               [0.685, 2.145])},
         'smis': {'greedy': ([2.000, 73.500],
                             [0.638, 2.040])}},
 'nn': {'avg': {'greedy': ([89.414, 98.477],
                           [0.050, 0.023])},
        'scores': {'greedy': ([2.200, 56.400],
                              [0.356, 0.245])},
        'smis': {'greedy': ([2.100, 53.500],
                            [0.356, 0.245])}},
 'rf': {'avg': {'greedy': ([89.403, 97.857],
                           [0.139, 0.230])},
        'scores': {'greedy': ([1.633, 42.867],
                              [0.492, 3.229])},
        'smis': {'greedy': ([1.467, 40.833],
                            [0.419, 3.065])}}}

In [None]:
# top-1000, replicate of 3
HTS_004_02 = {
 'mpn': {'avg': {'greedy': ([83.876, 99.854],
                            [0.113, 0.032])},
         'scores': {'greedy': ([0.367, 94.933],
                               [0.047, 0.818])},
         'smis': {'greedy': ([0.367, 92.400],
                             [0.047, 0.726])}},
 'nn': {'avg': {'greedy': ([83.905, 99.506],
                           [0.179, 0.100])},
        'scores': {'greedy': ([0.567, 87.300],
                              [0.262, 1.980])},
        'smis': {'greedy': ([0.533, 82.700],
                            [0.249, 2.276])}},
 'rf': {'avg': {'greedy': ([83.805, 99.344], 
                           [0.116, 0.081])},
        'scores': {'greedy': ([0.467, 76.567],
                              [0.170, 2.788])},
        'smis': {'greedy': ([0.433, 72.833],
                            [0.189, 2.370])}}}

In [None]:
AmpC_001_random = {
 'avg': ([48.793, 54.653, 58.036, 60.576, 62.639, 64.436],
         [0.052, 0.040, 0.046, 0.030, 0.049, 0.053]),
 'scores': ([0.113, 0.221, 0.315, 0.422, 0.528, 0.632],
            [0.010, 0.007, 0.010, 0.021, 0.013, 0.006]),
 'smis': ([0.113, 0.221, 0.314, 0.421, 0.527, 0.631],
          [0.009, 0.006, 0.010, 0.020, 0.012, 0.005])}

In [None]:
AmpC_001_online = {
 'mpn': {'avg': {'ei': ([48.719, 87.471, 88.509, 91.554, 92.119, 94.148],
                        [0.002, 0.643, 0.330, 0.249, 0.382, 0.318]),
                 'greedy': ([48.792, 89.567, 91.003, 94.531, 95.176, 96.269],
                            [0.006, 0.571, 0.746, 0.300, 0.143, 0.332]),
                 'pi': ([48.727, 87.548, 89.254, 92.218, 93.118, 94.500],
                        [0.018, 0.123, 0.148, 0.118, 0.173, 0.130]),
                 'thompson': ([48.754, 87.180, 92.068, 95.072, 96.336, 97.212],
                              [0.020, 0.233, 0.640, 0.239, 0.294, 0.264]),
                 'ucb': ([48.725, 87.942, 90.297, 93.999, 94.498, 95.017],
                         [0.025, 0.092, 0.689, 0.265, 0.265, 0.765])},
         'scores': {'ei': ([0.092, 6.909, 8.430, 13.688, 15.743, 23.308],
                           [0.004, 0.955, 0.830, 0.942, 1.629, 1.730]),
                    'greedy': ([0.086, 9.327, 13.062, 24.671, 28.444, 36.536],
                               [0.002, 1.317, 1.752, 1.603, 0.864, 2.844]),
                    'pi': ([0.109, 6.942, 9.574, 15.919, 19.329, 25.243],
                           [0.003, 0.846, 0.922, 0.961, 0.897, 0.809]),
                    'thompson': ([0.099, 6.130, 15.966, 28.510, 37.591, 45.871],
                                 [0.003, 0.032, 1.956, 1.280, 2.453, 2.891]),
                    'ucb': ([0.103, 9.069, 13.482, 23.994, 26.445, 29.092],
                            [0.003, 0.273, 1.798, 1.620, 1.731, 4.300])},
         'smis': {'ei': ([0.091, 6.904, 8.424, 13.680, 15.734, 23.291],
                         [0.005, 0.952, 0.828, 0.942, 1.628, 1.727]),
                  'greedy': ([0.086, 9.322, 13.056, 24.656, 28.429, 36.514],
                             [0.002, 1.314, 1.750, 1.604, 0.863, 2.840]),
                  'pi': ([0.109, 6.937, 9.568, 15.902, 19.309, 25.219],
                         [0.003, 0.847, 0.922, 0.962, 0.899, 0.811]),
                  'thompson': ([0.099, 6.125, 15.956, 28.488, 37.563, 45.836],
                               [0.003, 0.033, 1.954, 1.278, 2.453, 2.892]),
                  'ucb': ([0.102, 9.064, 13.476, 23.978, 26.427, 29.071],
                          [0.002, 0.272, 1.796, 1.614, 1.725, 4.291])}},
 'nn': {'avg': {'ei': ([48.770, 87.837, 87.842, 88.916, 88.999, 89.010],
                       [0.010, 0.282, 0.278, 1.292, 1.393, 1.392]),
                'greedy': ([48.711, 88.540, 90.187, 91.955, 92.945, 93.277],
                           [0.044, 0.238, 0.740, 0.126, 0.220, 0.666]),
                'pi': ([48.777, 87.822, 88.418, 89.057, 89.068, 89.245],
                       [0.018, 0.279, 0.999, 1.865, 1.862, 1.873]),
                'thompson': ([48.770, 88.125, 88.815, 91.440, 92.197, 93.158],
                             [0.025, 0.318, 0.160, 0.337, 0.347, 0.238]),
                'ucb': ([48.725, 88.218, 88.356, 89.263, 89.616, 90.164],
                        [0.015, 0.055, 0.102, 0.940, 1.246, 1.783])},
        'scores': {'ei': ([0.102, 6.172, 6.177, 7.699, 7.841, 7.852],
                          [0.012, 0.202, 0.198, 1.974, 2.167, 2.173]),
                   'greedy': ([0.097, 7.398, 10.118, 13.457, 16.550, 17.945],
                              [0.004, 0.143, 1.560, 0.504, 0.795, 2.718]),
                   'pi': ([0.097, 6.469, 7.346, 8.454, 8.461, 8.625],
                          [0.003, 0.367, 1.454, 2.993, 2.994, 3.159]),
                   'thompson': ([0.103, 7.253, 8.081, 12.188, 13.921, 16.940],
                                [0.020, 0.225, 0.507, 0.818, 0.737, 0.894]),
                   'ucb': ([0.095, 7.289, 7.387, 8.283, 8.860, 10.020],
                           [0.009, 0.169, 0.203, 1.057, 1.594, 2.744])},
        'smis': {'ei': ([0.102, 6.161, 6.165, 7.687, 7.829, 7.839],
                        [0.012, 0.205, 0.200, 1.971, 2.163, 2.169]),
                 'greedy': ([0.097, 7.389, 10.107, 13.443, 16.531, 17.924],
                            [0.004, 0.141, 1.557, 0.503, 0.797, 2.718]),
                 'pi': ([0.097, 6.462, 7.338, 8.446, 8.453, 8.617],
                        [0.003, 0.366, 1.455, 2.995, 2.996, 3.160]),
                 'thompson': ([0.103, 7.247, 8.074, 12.175, 13.907, 16.925],
                              [0.020, 0.224, 0.507, 0.816, 0.736, 0.893]),
                 'ucb': ([0.095, 7.280, 7.378, 8.274, 8.851, 10.011],
                         [0.009, 0.168, 0.202, 1.056, 1.593, 2.743])}},
 'rf': {'avg': {'ei': ([48.786, 85.084, 85.670, 88.656, 88.782, 89.748],
                       [0.011, 0.538, 0.456, 0.673, 0.601, 1.216]),
                'greedy': ([48.744, 88.730, 89.659, 91.901, 92.037, 93.013],
                           [0.029, 0.066, 0.132, 0.099, 0.084, 0.090]),
                'pi': ([48.722, 84.991, 85.261, 88.422, 88.644, 89.244],
                       [0.080, 1.046, 1.083, 0.462, 0.651, 0.588]),
                'thompson': ([48.796, 84.519, 84.777, 89.450, 89.491, 91.249],
                             [0.036, 0.371, 0.297, 0.198, 0.187, 0.103]),
                'ucb': ([48.721, 85.789, 86.115, 89.014, 89.457, 90.494],
                        [0.050, 0.519, 0.559, 0.449, 0.326, 0.957])},
        'scores': {'ei': ([0.087, 5.282, 5.557, 7.711, 7.851, 9.399],
                          [0.023, 0.556, 0.540, 0.558, 0.599, 1.775]),
                   'greedy': ([0.095, 6.779, 8.435, 12.539, 12.965, 15.893],
                              [0.003, 0.386, 0.425, 0.305, 0.123, 0.443]),
                   'pi': ([0.104, 4.763, 4.971, 7.375, 7.655, 8.226],
                          [0.009, 0.461, 0.541, 1.141, 1.342, 1.090]),
                   'thompson': ([0.100, 3.874, 4.013, 7.945, 8.029, 11.051],
                                [0.006, 0.044, 0.111, 0.114, 0.114, 0.118]),
                   'ucb': ([0.099, 4.741, 4.980, 7.754, 8.429, 10.671],
                           [0.014, 0.809, 0.793, 0.495, 0.273, 1.594])},
        'smis': {'ei': ([0.087, 5.281, 5.556, 7.709, 7.849, 9.395],
                        [0.023, 0.556, 0.539, 0.558, 0.599, 1.772]),
                 'greedy': ([0.095, 6.770, 8.425, 12.526, 12.951, 15.879],
                            [0.003, 0.385, 0.426, 0.304, 0.123, 0.443]),
                 'pi': ([0.104, 4.761, 4.970, 7.373, 7.653, 8.223],
                        [0.009, 0.460, 0.540, 1.140, 1.340, 1.088]),
                 'thompson': ([0.100, 3.870, 4.009, 7.939, 8.023, 11.044],
                              [0.006, 0.048, 0.114, 0.111, 0.112, 0.121]),
                 'ucb': ([0.099, 4.740, 4.978, 7.751, 8.425, 10.667],
                         [0.014, 0.810, 0.795, 0.495, 0.272, 1.593])}}}

In [None]:
AmpC_001_retrain = {
 'mpn': {'avg': {'ei': ([48.753, 87.369, 90.782, 93.038, 94.460, 95.385],
                        [0.012, 1.004, 0.514, 0.435, 0.298, 0.295]),
                 'greedy': ([48.716, 89.005, 94.439, 96.243, 97.120, 97.677],
                            [0.027, 0.259, 0.021, 0.098, 0.105, 0.041]),
                 'pi': ([48.784, 87.347, 90.939, 93.136, 94.490, 95.487],
                        [0.003, 0.334, 0.207, 0.329, 0.259, 0.175]),
                 'thompson': ([48.726, 87.683, 92.966, 95.363, 96.707, 97.533],
                              [0.041, 0.234, 0.089, 0.091, 0.091, 0.032]),
                 'ucb': ([48.757, 87.136, 92.169, 94.834, 96.397, 97.362],
                         [0.038, 0.888, 0.365, 0.079, 0.064, 0.050])},
         'scores': {'ei': ([0.095, 7.989, 13.075, 19.487, 25.430, 30.499],
                           [0.002, 1.286, 1.255, 1.548, 1.359, 1.795]),
                    'greedy': ([0.099, 8.230, 23.930, 35.933, 44.867, 51.994],
                               [0.009, 0.625, 0.276, 0.934, 1.238, 0.507]),
                    'pi': ([0.095, 6.857, 12.847, 19.564, 25.386, 31.132],
                           [0.010, 0.452, 0.949, 1.724, 1.701, 1.333]),
                    'thompson': ([0.101, 7.165, 18.356, 30.160, 40.716, 49.434],
                                 [0.013, 0.290, 0.332, 0.630, 0.785, 0.314]),
                    'ucb': ([0.085, 7.342, 16.827, 27.923, 38.273, 47.144],
                            [0.009, 1.458, 1.426, 0.439, 0.537, 0.490])},
         'smis': {'ei': ([0.095, 7.983, 13.065, 19.472, 25.411, 30.475],
                         [0.002, 1.287, 1.255, 1.548, 1.359, 1.796]),
                  'greedy': ([0.099, 8.227, 23.917, 35.911, 44.836, 51.958],
                             [0.009, 0.624, 0.276, 0.937, 1.240, 0.507]),
                  'pi': ([0.095, 6.851, 12.839, 19.549, 25.366, 31.106],
                         [0.010, 0.452, 0.950, 1.723, 1.700, 1.333]),
                  'thompson': ([0.101, 7.161, 18.345, 30.141, 40.687, 49.397],
                               [0.013, 0.288, 0.330, 0.628, 0.788, 0.315]),
                  'ucb': ([0.085, 7.335, 16.814, 27.901, 38.247, 47.111],
                          [0.009, 1.456, 1.427, 0.442, 0.531, 0.488])}},
 'nn': {'avg': {'ei': ([48.719, 88.003, 90.455, 91.357, 91.942, 92.414],
                       [0.008, 0.321, 0.568, 0.628, 0.080, 0.303]),
                'greedy': ([48.722, 88.826, 92.886, 94.420, 95.372, 95.999],
                           [0.022, 0.095, 0.150, 0.064, 0.051, 0.035]),
                'pi': ([48.753, 88.109, 90.675, 91.120, 91.700, 92.722],
                       [0.041, 0.371, 0.337, 0.227, 0.451, 0.428]),
                'thompson': ([48.799, 87.707, 92.071, 93.928, 95.026, 95.731],
                             [0.053, 0.443, 0.219, 0.094, 0.058, 0.100]),
                'ucb': ([48.766, 87.963, 92.267, 94.083, 95.107, 95.800],
                        [0.057, 0.190, 0.030, 0.071, 0.103, 0.070])},
        'scores': {'ei': ([0.099, 6.889, 10.001, 11.855, 13.142, 14.499],
                          [0.006, 0.431, 0.979, 1.409, 0.288, 0.828]),
                   'greedy': ([0.094, 8.027, 16.143, 22.403, 28.315, 33.251],
                              [0.018, 0.277, 0.541, 0.337, 0.447, 0.329]),
                   'pi': ([0.111, 6.864, 10.086, 10.975, 12.335, 15.207],
                          [0.008, 0.526, 0.777, 0.659, 1.031, 1.421]),
                   'thompson': ([0.116, 6.352, 13.411, 19.985, 25.988, 31.043],
                                [0.003, 0.428, 0.697, 0.540, 0.403, 0.783]),
                   'ucb': ([0.101, 6.774, 14.020, 20.736, 26.458, 31.547],
                           [0.012, 0.350, 0.081, 0.351, 0.762, 0.608])},
        'smis': {'ei': ([0.099, 6.881, 9.990, 11.841, 13.127, 14.482],
                        [0.006, 0.431, 0.978, 1.408, 0.291, 0.824]),
                 'greedy': ([0.094, 8.017, 16.127, 22.386, 28.293, 33.225],
                            [0.018, 0.274, 0.539, 0.334, 0.447, 0.331]),
                 'pi': ([0.111, 6.855, 10.075, 10.963, 12.323, 15.192],
                        [0.008, 0.528, 0.779, 0.660, 1.030, 1.419]),
                 'thompson': ([0.115, 6.346, 13.398, 19.966, 25.967, 31.017],
                              [0.003, 0.428, 0.694, 0.537, 0.401, 0.781]),
                 'ucb': ([0.101, 6.766, 14.006, 20.714, 26.433, 31.520],
                         [0.012, 0.350, 0.080, 0.350, 0.762, 0.609])}},
 'rf': {'avg': {'ei': ([48.737, 84.648, 87.105, 88.320, 89.255, 90.028],
                       [0.028, 0.555, 0.317, 0.715, 0.924, 1.030]),
                'greedy': ([48.778, 88.889, 91.863, 93.255, 94.062, 94.764],
                           [0.042, 0.197, 0.331, 0.252, 0.366, 0.346]),
                'pi': ([48.765, 85.709, 88.188, 88.990, 89.615, 90.443],
                       [0.028, 0.410, 0.435, 0.412, 0.487, 0.539]),
                'thompson': ([48.759, 84.805, 89.590, 91.976, 93.251, 94.076],
                             [0.044, 0.533, 0.682, 0.783, 0.532, 0.381]),
                'ucb': ([48.767, 86.557, 88.683, 89.745, 91.040, 92.007],
                        [0.016, 0.804, 0.541, 0.052, 0.596, 0.421])},
        'scores': {'ei': ([0.103, 4.907, 6.023, 7.016, 8.475, 9.755],
                          [0.005, 0.783, 0.828, 0.830, 1.949, 2.338]),
                   'greedy': ([0.094, 6.791, 12.151, 16.664, 20.207, 24.003],
                              [0.016, 0.564, 1.272, 1.320, 1.936, 2.233]),
                   'pi': ([0.101, 5.143, 6.630, 7.634, 8.439, 9.752],
                          [0.010, 0.269, 0.691, 0.731, 1.026, 1.264]),
                   'thompson': ([0.106, 4.057, 7.879, 12.683, 16.671, 20.072],
                                [0.014, 0.436, 1.566, 2.461, 2.255, 2.048]),
                   'ucb': ([0.096, 5.465, 7.327, 8.596, 11.537, 13.808],
                           [0.013, 1.562, 1.404, 1.016, 1.049, 0.959])},
        'smis': {'ei': ([0.103, 4.906, 6.021, 7.012, 8.471, 9.751],
                        [0.005, 0.784, 0.828, 0.830, 1.949, 2.338]),
                 'greedy': ([0.094, 6.783, 12.137, 16.647, 20.187, 23.981],
                            [0.016, 0.563, 1.272, 1.320, 1.938, 2.233]),
                 'pi': ([0.101, 5.141, 6.629, 7.631, 8.437, 9.749],
                        [0.010, 0.269, 0.691, 0.731, 1.026, 1.263]),
                 'thompson': ([0.106, 4.055, 7.875, 12.678, 16.661, 20.060],
                              [0.014, 0.436, 1.565, 2.461, 2.252, 2.048]),
                 'ucb': ([0.095, 5.463, 7.323, 8.591, 11.530, 13.801],
                         [0.012, 1.559, 1.402, 1.015, 1.047, 0.957])}}}

In [None]:
AmpC_002_random = {
 'avg': ([54.602, 60.453, 64.333, 67.379, 69.976, 72.228],
         [0.062, 0.051, 0.084, 0.103, 0.101, 0.096]),
 'scores': ([0.209, 0.398, 0.595, 0.805, 1.010, 1.218],
            [0.008, 0.011, 0.031, 0.028, 0.027, 0.048]),
 'smis': ([0.209, 0.398, 0.595, 0.805, 1.009, 1.217],
          [0.008, 0.011, 0.031, 0.028, 0.028, 0.048])}

In [None]:
AmpC_002_online = {
 'mpn': {'avg': {'ei': ([54.571, 91.517, 92.589, 95.318, 95.959, 96.977],
                        [0.041, 0.372, 0.245, 0.192, 0.237, 0.232]),
                 'greedy': ([54.624, 92.892, 95.324, 97.289, 97.400, 97.414],
                            [0.024, 0.276, 0.531, 0.346, 0.247, 0.242]),
                 'pi': ([54.621, 91.464, 92.734, 95.337, 95.989, 96.911],
                        [0.019, 0.109, 0.187, 0.029, 0.044, 0.079]),
                 'thompson': ([54.582, 91.630, 95.726, 97.507, 98.157, 98.460],
                              [0.042, 0.388, 0.139, 0.065, 0.057, 0.186]),
                 'ucb': ([54.576, 91.625, 92.986, 95.907, 96.452, 96.783],
                         [0.071, 0.358, 0.579, 0.282, 0.450, 0.706])},
         'scores': {'ei': ([0.207, 14.079, 17.473, 29.541, 34.071, 42.694],
                           [0.022, 0.995, 0.818, 1.125, 1.759, 2.409]),
                    'greedy': ([0.185, 16.261, 29.485, 46.891, 48.132, 48.314],
                               [0.016, 1.357, 3.478, 4.064, 2.999, 2.952]),
                    'pi': ([0.196, 13.624, 17.801, 29.745, 34.260, 41.929],
                           [0.012, 0.132, 0.973, 0.335, 0.446, 0.709]),
                    'thompson': ([0.195, 13.073, 32.226, 49.138, 58.151, 63.145],
                                 [0.011, 1.157, 1.111, 0.905, 0.913, 3.135]),
                    'ucb': ([0.207, 13.601, 18.778, 33.919, 38.360, 41.698],
                            [0.015, 1.149, 2.469, 2.298, 3.826, 6.658])},
         'smis': {'ei': ([0.207, 14.072, 17.464, 29.521, 34.045, 42.663],
                         [0.022, 0.991, 0.812, 1.123, 1.756, 2.409]),
                  'greedy': ([0.185, 16.247, 29.464, 46.855, 48.095, 48.277],
                             [0.016, 1.357, 3.473, 4.059, 2.994, 2.948]),
                  'pi': ([0.196, 13.612, 17.787, 29.722, 34.232, 41.896],
                         [0.012, 0.132, 0.971, 0.334, 0.444, 0.712]),
                  'thompson': ([0.195, 13.063, 32.199, 49.105, 58.113, 63.101],
                               [0.011, 1.156, 1.109, 0.903, 0.916, 3.130]),
                  'ucb': ([0.207, 13.593, 18.764, 33.895, 38.335, 41.671],
                          [0.015, 1.148, 2.471, 2.299, 3.826, 6.657])}},
 'nn': {'avg': {'ei': ([54.597, 91.333, 91.343, 91.408, 91.465, 91.483],
                       [0.032, 0.236, 0.237, 0.270, 0.235, 0.222]),
                'greedy': ([54.630, 91.824, 92.551, 93.746, 93.954, 94.689],
                           [0.013, 0.144, 0.374, 1.182, 1.125, 1.092]),
                'pi': ([54.567, 91.526, 91.530, 92.158, 92.220, 92.256],
                       [0.070, 0.037, 0.039, 0.911, 0.871, 0.876]),
                'thompson': ([54.615, 91.061, 91.435, 92.282, 92.527, 92.828],
                             [0.035, 0.311, 0.363, 1.476, 1.707, 2.123]),
                'ucb': ([54.575, 91.296, 91.623, 92.667, 92.889, 93.286],
                        [0.030, 0.140, 0.043, 0.815, 0.836, 1.139])},
        'scores': {'ei': ([0.193, 11.553, 11.569, 11.633, 11.676, 11.685],
                          [0.028, 0.392, 0.390, 0.415, 0.388, 0.385]),
                   'greedy': ([0.195, 13.037, 15.312, 20.443, 21.219, 24.866],
                              [0.003, 0.184, 1.149, 4.744, 4.866, 6.017]),
                   'pi': ([0.225, 12.030, 12.042, 14.131, 14.196, 14.237],
                          [0.017, 0.082, 0.087, 3.018, 2.983, 3.019]),
                   'thompson': ([0.200, 10.845, 11.717, 14.773, 15.933, 18.031],
                                [0.013, 1.072, 1.135, 5.271, 6.819, 9.773]),
                   'ucb': ([0.198, 11.369, 12.125, 15.184, 15.862, 17.622],
                           [0.008, 0.441, 0.132, 2.578, 2.672, 4.159])},
        'smis': {'ei': ([0.193, 11.540, 11.556, 11.619, 11.663, 11.672],
                        [0.028, 0.394, 0.392, 0.417, 0.390, 0.387]),
                 'greedy': ([0.195, 13.026, 15.300, 20.423, 21.199, 24.844],
                            [0.003, 0.185, 1.148, 4.737, 4.859, 6.009]),
                 'pi': ([0.225, 12.018, 12.030, 14.118, 14.183, 14.224],
                        [0.017, 0.080, 0.085, 3.014, 2.978, 3.014]),
                 'thompson': ([0.200, 10.833, 11.705, 14.758, 15.918, 18.014],
                              [0.013, 1.069, 1.133, 5.267, 6.814, 9.765]),
                 'ucb': ([0.198, 11.359, 12.113, 15.172, 15.849, 17.607],
                         [0.008, 0.438, 0.131, 2.577, 2.670, 4.154])}},
 'rf': {'avg': {'ei': ([54.629, 89.144, 89.227, 91.656, 92.087, 93.031],
                       [0.087, 0.188, 0.232, 0.699, 0.685, 0.658]),
                'greedy': ([54.536, 92.144, 92.650, 94.400, 94.558, 95.669],
                           [0.057, 0.098, 0.241, 0.259, 0.301, 0.237]),
                'pi': ([54.610, 89.187, 89.371, 91.292, 92.151, 92.964],
                       [0.031, 0.329, 0.436, 0.998, 0.086, 0.628]),
                'thompson': ([54.625, 90.027, 90.498, 93.554, 93.597, 94.755],
                             [0.013, 0.180, 0.230, 0.066, 0.101, 0.543]),
                'ucb': ([54.594, 89.654, 89.835, 92.342, 92.389, 93.227],
                        [0.025, 0.172, 0.339, 0.288, 0.291, 0.459])},
        'scores': {'ei': ([0.199, 8.636, 8.759, 13.615, 14.649, 17.320],
                          [0.007, 0.196, 0.230, 1.773, 2.046, 2.277]),
                   'greedy': ([0.221, 12.919, 14.668, 21.807, 22.691, 30.355],
                              [0.021, 0.395, 0.860, 1.676, 1.968, 1.965]),
                   'pi': ([0.183, 8.001, 8.319, 12.207, 14.028, 16.668],
                          [0.008, 0.421, 0.661, 2.141, 0.210, 2.331]),
                   'thompson': ([0.217, 8.603, 9.431, 18.118, 18.305, 24.219],
                                [0.021, 0.256, 0.409, 0.211, 0.281, 3.282]),
                   'ucb': ([0.187, 7.604, 8.030, 13.806, 13.936, 16.763],
                           [0.016, 0.107, 0.448, 0.763, 0.787, 1.779])},
        'smis': {'ei': ([0.199, 8.634, 8.757, 13.610, 14.643, 17.314],
                        [0.007, 0.197, 0.232, 1.770, 2.043, 2.273]),
                 'greedy': ([0.221, 12.905, 14.651, 21.789, 22.673, 30.333],
                            [0.020, 0.396, 0.861, 1.678, 1.969, 1.966]),
                 'pi': ([0.183, 7.999, 8.317, 12.203, 14.024, 16.661],
                        [0.008, 0.422, 0.662, 2.142, 0.211, 2.330]),
                 'thompson': ([0.217, 8.597, 9.425, 18.104, 18.291, 24.201],
                              [0.021, 0.257, 0.412, 0.211, 0.281, 3.279]),
                 'ucb': ([0.187, 7.601, 8.027, 13.798, 13.928, 16.754],
                         [0.016, 0.108, 0.449, 0.759, 0.782, 1.777])}}}

In [None]:
AmpC_002_retrain = {
 'mpn': {'avg': {'ei': ([54.595, 91.112, 94.196, 95.839, 96.951, 97.716],
                        [0.015, 0.096, 0.148, 0.097, 0.041, 0.023]),
                 'greedy': ([54.542, 92.580, 96.611, 97.855, 98.455, 98.614],
                            [0.022, 0.211, 0.183, 0.147, 0.126, 0.094]),
                 'pi': ([54.575, 91.188, 94.499, 96.142, 97.266, 97.842],
                        [0.056, 0.330, 0.192, 0.149, 0.094, 0.044]),
                 'thompson': ([54.585, 91.780, 95.884, 97.551, 98.062, 98.678],
                              [0.030, 0.259, 0.116, 0.100, 0.313, 0.171]),
                 'ucb': ([54.634, 91.079, 95.379, 97.595, 98.680, 99.161],
                         [0.003, 0.306, 0.054, 0.040, 0.056, 0.029])},
         'scores': {'ei': ([0.196, 12.596, 23.377, 33.123, 42.395, 50.805],
                           [0.003, 0.451, 0.905, 0.935, 0.500, 0.292]),
                    'greedy': ([0.194, 14.833, 39.277, 54.836, 64.313, 67.093],
                               [0.022, 1.144, 1.899, 1.859, 1.703, 2.146]),
                    'pi': ([0.191, 12.791, 25.151, 35.516, 45.620, 52.303],
                           [0.008, 1.069, 1.172, 1.250, 0.999, 0.619]),
                    'thompson': ([0.193, 13.231, 33.189, 49.799, 57.181, 67.449],
                                 [0.022, 0.905, 0.913, 1.186, 4.715, 3.324]),
                    'ucb': ([0.203, 12.353, 30.471, 49.814, 65.689, 75.505],
                            [0.019, 0.623, 0.232, 0.507, 1.090, 0.711])},
         'smis': {'ei': ([0.196, 12.587, 23.359, 33.095, 42.361, 50.765],
                         [0.003, 0.449, 0.904, 0.933, 0.504, 0.292]),
                  'greedy': ([0.194, 14.821, 39.251, 54.798, 64.266, 67.044],
                             [0.022, 1.141, 1.895, 1.859, 1.702, 2.144]),
                  'pi': ([0.191, 12.784, 25.135, 35.494, 45.590, 52.271],
                         [0.008, 1.066, 1.170, 1.248, 0.994, 0.620]),
                  'thompson': ([0.193, 13.223, 33.170, 49.768, 57.143, 67.404],
                               [0.022, 0.904, 0.913, 1.184, 4.711, 3.322]),
                  'ucb': ([0.203, 12.346, 30.455, 49.784, 65.649, 75.457],
                          [0.019, 0.621, 0.229, 0.505, 1.088, 0.708])}},
 'nn': {'avg': {'ei': ([54.557, 91.470, 92.868, 93.674, 93.902, 94.746],
                       [0.029, 0.212, 0.625, 0.595, 0.427, 0.165]),
                'greedy': ([54.622, 92.035, 95.098, 96.409, 97.247, 97.721],
                           [0.037, 0.345, 0.115, 0.022, 0.036, 0.033]),
                'pi': ([54.648, 91.621, 93.386, 93.751, 94.064, 94.418],
                       [0.015, 0.156, 0.155, 0.265, 0.171, 0.194]),
                'thompson': ([54.606, 91.416, 94.613, 96.145, 96.919, 97.534],
                             [0.032, 0.100, 0.030, 0.027, 0.044, 0.068]),
                'ucb': ([54.567, 91.665, 94.862, 96.216, 96.919, 97.516],
                        [0.042, 0.259, 0.124, 0.157, 0.087, 0.040])},
        'scores': {'ei': ([0.216, 11.909, 16.061, 19.054, 19.995, 24.179],
                          [0.009, 0.636, 2.289, 2.487, 1.937, 0.969]),
                   'greedy': ([0.219, 13.545, 26.499, 36.979, 46.306, 52.826],
                              [0.016, 0.915, 0.752, 0.183, 0.474, 0.503]),
                   'pi': ([0.195, 12.300, 17.680, 19.214, 20.612, 22.334],
                          [0.014, 0.671, 0.699, 1.299, 0.980, 1.148]),
                   'thompson': ([0.189, 11.516, 23.353, 34.457, 42.294, 50.149],
                                [0.017, 0.328, 0.232, 0.239, 0.545, 1.021]),
                   'ucb': ([0.203, 12.597, 24.945, 35.081, 42.222, 49.799],
                           [0.003, 0.800, 0.771, 1.469, 0.957, 0.546])},
        'smis': {'ei': ([0.216, 11.897, 16.047, 19.039, 19.980, 24.161],
                        [0.009, 0.636, 2.291, 2.487, 1.938, 0.969]),
                 'greedy': ([0.218, 13.531, 26.478, 36.950, 46.269, 52.785],
                            [0.016, 0.911, 0.748, 0.182, 0.471, 0.500]),
                 'pi': ([0.195, 12.287, 17.665, 19.197, 20.595, 22.315],
                        [0.014, 0.674, 0.699, 1.297, 0.978, 1.144]),
                 'thompson': ([0.189, 11.504, 23.332, 34.425, 42.261, 50.109],
                              [0.017, 0.328, 0.228, 0.240, 0.541, 1.017]),
                 'ucb': ([0.203, 12.585, 24.922, 35.051, 42.186, 49.758],
                         [0.003, 0.798, 0.771, 1.466, 0.958, 0.547])}},
 'rf': {'avg': {'ei': ([54.588, 88.782, 89.924, 90.379, 91.597, 92.442],
                       [0.036, 0.406, 0.261, 0.159, 0.561, 0.853]),
                'greedy': ([54.618, 91.876, 94.532, 95.662, 96.569, 97.190],
                           [0.012, 0.364, 0.329, 0.316, 0.138, 0.138]),
                'pi': ([54.620, 89.807, 90.733, 91.349, 92.081, 92.829],
                       [0.020, 0.089, 0.547, 0.558, 0.706, 0.442]),
                'thompson': ([54.596, 89.959, 93.019, 94.819, 95.835, 96.798],
                             [0.020, 0.137, 0.343, 0.137, 0.172, 0.166]),
                'ucb': ([54.620, 89.696, 91.205, 91.982, 93.484, 94.807],
                        [0.053, 0.386, 0.455, 0.333, 0.220, 0.401])},
        'scores': {'ei': ([0.172, 6.993, 8.669, 9.440, 12.229, 14.629],
                          [0.011, 0.448, 0.659, 0.424, 1.458, 2.664]),
                   'greedy': ([0.195, 12.132, 22.873, 30.537, 38.577, 45.543],
                              [0.019, 1.504, 2.149, 2.490, 1.419, 1.806]),
                   'pi': ([0.201, 9.471, 11.030, 12.143, 13.815, 15.959],
                          [0.028, 0.737, 1.457, 1.455, 2.215, 1.649]),
                   'thompson': ([0.214, 8.539, 15.739, 23.995, 31.353, 40.843],
                                [0.017, 0.296, 0.983, 0.805, 1.430, 1.907]),
                   'ucb': ([0.207, 8.055, 10.563, 12.635, 17.811, 24.393],
                           [0.010, 0.935, 1.243, 0.885, 0.712, 1.955])},
        'smis': {'ei': ([0.172, 6.992, 8.665, 9.437, 12.224, 14.623],
                        [0.011, 0.449, 0.660, 0.424, 1.455, 2.661]),
                 'greedy': ([0.195, 12.122, 22.853, 30.512, 38.545, 45.505],
                            [0.019, 1.502, 2.147, 2.488, 1.422, 1.806]),
                 'pi': ([0.201, 9.468, 11.027, 12.139, 13.809, 15.952],
                        [0.027, 0.739, 1.456, 1.455, 2.213, 1.644]),
                 'thompson': ([0.214, 8.535, 15.729, 23.977, 31.329, 40.811],
                              [0.017, 0.293, 0.980, 0.804, 1.429, 1.902]),
                 'ucb': ([0.206, 8.052, 10.559, 12.631, 17.800, 24.379],
                         [0.009, 0.935, 1.242, 0.885, 0.710, 1.950])}}}

In [None]:
AmpC_004_random = {
 'avg': ([60.527, 67.424, 72.200, 75.909, 78.821, 81.033],
         [0.061, 0.060, 0.040, 0.074, 0.052, 0.040]),
 'scores': ([0.411, 0.787, 1.189, 1.585, 1.977, 2.381],
            [0.040, 0.046, 0.050, 0.057, 0.056, 0.076]),
 'smis': ([0.410, 0.785, 1.187, 1.583, 1.975, 2.379],
          [0.040, 0.046, 0.052, 0.059, 0.058, 0.079])}

In [None]:
AmpC_004_online = {
 'mpn': {'avg': {'ei': ([60.466, 94.494, 95.478, 97.411, 97.540, 98.413],
                        [0.058, 0.611, 0.476, 0.293, 0.231, 0.122]),
                 'greedy': ([60.481, 95.846, 97.267, 98.591, 98.695, 98.950],
                            [0.056, 0.194, 0.060, 0.050, 0.113, 0.286]),
                 'pi': ([60.478, 94.360, 95.272, 97.608, 97.815, 98.320],
                        [0.037, 0.090, 0.034, 0.024, 0.172, 0.509]),
                 'thompson': ([60.486, 95.013, 97.539, 98.664, 98.871, 98.989],
                              [0.032, 0.126, 0.199, 0.101, 0.193, 0.254]),
                 'ucb': ([60.448, 94.918, 96.968, 98.743, 99.017, 99.256],
                         [0.082, 0.060, 0.347, 0.133, 0.105, 0.246])},
         'scores': {'ei': ([0.391, 24.933, 30.864, 47.396, 48.764, 60.919],
                           [0.040, 3.515, 3.297, 3.405, 2.794, 2.176]),
                    'greedy': ([0.393, 32.399, 46.498, 66.301, 68.261, 73.729],
                               [0.007, 1.697, 0.804, 0.962, 1.985, 5.727]),
                    'pi': ([0.389, 23.393, 28.959, 49.654, 52.258, 60.230],
                           [0.012, 0.470, 0.172, 0.282, 2.309, 7.773]),
                    'thompson': ([0.390, 26.394, 49.765, 67.285, 71.391, 73.989],
                                 [0.025, 0.841, 2.535, 1.924, 3.871, 5.190]),
                    'ucb': ([0.405, 26.710, 43.082, 66.695, 71.859, 77.763],
                            [0.018, 0.439, 3.640, 2.436, 2.228, 5.817])},
         'smis': {'ei': ([0.391, 24.915, 30.841, 47.360, 48.727, 60.874],
                         [0.040, 3.509, 3.290, 3.399, 2.789, 2.172]),
                  'greedy': ([0.393, 32.385, 46.469, 66.251, 68.211, 73.673],
                             [0.007, 1.696, 0.800, 0.960, 1.982, 5.721]),
                  'pi': ([0.389, 23.381, 28.941, 49.618, 52.218, 60.187],
                         [0.012, 0.472, 0.173, 0.277, 2.307, 7.765]),
                  'thompson': ([0.390, 26.376, 49.730, 67.237, 71.338, 73.935],
                               [0.025, 0.836, 2.523, 1.917, 3.860, 5.180]),
                  'ucb': ([0.405, 26.693, 43.055, 66.657, 71.814, 77.713],
                          [0.018, 0.435, 3.636, 2.436, 2.227, 5.811])}},
 'nn': {'avg': {'ei': ([60.459, 94.279, 94.486, 95.705, 95.774, 96.059],
                       [0.031, 0.020, 0.308, 1.028, 1.088, 1.378]),
                'greedy': ([60.455, 94.584, 95.124, 96.749, 96.939, 97.589],
                           [0.054, 0.132, 0.127, 0.173, 0.215, 0.380]),
                'pi': ([60.572, 94.205, 94.217, 94.245, 94.270, 94.337],
                       [0.051, 0.282, 0.292, 0.280, 0.294, 0.304]),
                'thompson': ([60.470, 94.164, 94.781, 95.515, 95.593, 95.640],
                             [0.053, 0.167, 0.693, 1.114, 1.153, 1.115]),
                'ucb': ([60.532, 94.416, 94.513, 95.161, 95.212, 95.409],
                        [0.017, 0.013, 0.107, 1.006, 1.005, 1.164])},
        'scores': {'ei': ([0.394, 21.761, 22.985, 32.114, 32.817, 36.260],
                          [0.038, 0.064, 1.733, 7.701, 8.459, 12.559]),
                   'greedy': ([0.376, 23.454, 26.665, 40.370, 42.419, 50.717],
                              [0.009, 0.919, 0.873, 1.761, 2.376, 5.224]),
                   'pi': ([0.395, 21.414, 21.457, 21.534, 21.627, 21.930],
                          [0.025, 1.495, 1.523, 1.486, 1.580, 1.642]),
                   'thompson': ([0.417, 21.323, 24.937, 30.672, 31.360, 31.609],
                                [0.045, 0.965, 4.338, 8.148, 8.770, 8.641]),
                   'ucb': ([0.376, 22.499, 22.993, 27.906, 28.235, 30.047],
                           [0.013, 0.144, 0.693, 7.609, 7.716, 9.718])},
        'smis': {'ei': ([0.393, 21.740, 22.963, 32.085, 32.789, 36.228],
                        [0.038, 0.064, 1.733, 7.697, 8.456, 12.551]),
                 'greedy': ([0.376, 23.437, 26.645, 40.333, 42.380, 50.675],
                            [0.009, 0.921, 0.877, 1.759, 2.371, 5.221]),
                 'pi': ([0.395, 21.392, 21.435, 21.512, 21.605, 21.908],
                        [0.025, 1.495, 1.522, 1.486, 1.579, 1.641]),
                 'thompson': ([0.417, 21.299, 24.911, 30.643, 31.331, 31.580],
                              [0.045, 0.963, 4.336, 8.144, 8.766, 8.637]),
                 'ucb': ([0.375, 22.476, 22.970, 27.879, 28.207, 30.019],
                         [0.012, 0.144, 0.692, 7.603, 7.708, 9.710])}},
 'rf': {'avg': {'ei': ([60.522, 91.419, 91.608, 94.121, 94.694, 95.045],
                       [0.013, 0.353, 0.430, 1.535, 1.196, 1.123]),
                'greedy': ([60.495, 94.490, 94.722, 96.652, 96.762, 97.702],
                           [0.028, 0.309, 0.301, 0.134, 0.183, 0.255]),
                'pi': ([60.461, 91.558, 91.650, 93.211, 94.143, 94.814],
                       [0.079, 0.397, 0.365, 1.029, 0.194, 0.506]),
                'thompson': ([60.517, 93.728, 94.099, 96.371, 96.453, 97.595],
                             [0.103, 0.062, 0.119, 0.085, 0.088, 0.109]),
                'ucb': ([60.528, 92.941, 93.029, 95.785, 95.928, 96.291],
                        [0.043, 0.289, 0.316, 0.254, 0.188, 0.323])},
        'scores': {'ei': ([0.389, 11.376, 11.890, 22.675, 25.065, 27.063],
                          [0.029, 1.654, 1.852, 7.726, 7.844, 7.968]),
                   'greedy': ([0.365, 22.169, 23.573, 39.183, 40.344, 52.307],
                              [0.012, 1.888, 1.904, 1.502, 2.045, 3.572]),
                   'pi': ([0.390, 12.247, 12.473, 17.640, 20.956, 24.781],
                          [0.014, 1.124, 1.062, 3.321, 0.913, 2.875]),
                   'thompson': ([0.411, 18.621, 20.454, 36.401, 37.205, 50.696],
                                [0.005, 0.339, 0.633, 0.725, 0.858, 1.473]),
                   'ucb': ([0.418, 15.949, 16.311, 31.344, 32.459, 35.755],
                           [0.007, 1.212, 1.337, 2.073, 1.601, 3.087])},
        'smis': {'ei': ([0.387, 11.371, 11.885, 22.662, 25.050, 27.047],
                        [0.030, 1.651, 1.848, 7.717, 7.835, 7.958]),
                 'greedy': ([0.365, 22.149, 23.554, 39.155, 40.315, 52.269],
                            [0.012, 1.888, 1.904, 1.503, 2.048, 3.573]),
                 'pi': ([0.390, 12.241, 12.467, 17.631, 20.945, 24.765],
                        [0.014, 1.124, 1.062, 3.320, 0.911, 2.872]),
                 'thompson': ([0.411, 18.605, 20.437, 36.369, 37.173, 50.657],
                              [0.005, 0.340, 0.634, 0.725, 0.858, 1.470]),
                 'ucb': ([0.417, 15.946, 16.307, 31.324, 32.438, 35.732],
                         [0.008, 1.210, 1.335, 2.066, 1.594, 3.083])}}}

In [None]:
AmpC_004_retrain = {
 'mpn': {'avg': {'ei': ([60.526, 94.018, 96.738, 98.250, 98.729, 99.173],
                        [0.058, 0.329, 0.228, 0.095, 0.435, 0.239]),
                 'greedy': ([60.508, 95.474, 98.223, 99.013, 99.293, 99.556],
                            [0.007, 0.140, 0.138, 0.095, 0.134, 0.089]),
                 'pi': ([60.443, 94.290, 97.140, 98.334, 98.953, 99.307],
                        [0.054, 0.174, 0.100, 0.060, 0.064, 0.035]),
                 'thompson': ([60.500, 94.875, 97.803, 98.777, 99.285, 99.565],
                              [0.055, 0.145, 0.034, 0.048, 0.012, 0.009]),
                 'ucb': ([60.556, 95.143, 98.512, 99.461, 99.721, 99.828],
                         [0.029, 0.015, 0.052, 0.026, 0.006, 0.004])},
         'scores': {'ei': ([0.394, 22.079, 40.393, 58.185, 66.873, 75.440],
                           [0.033, 1.931, 2.213, 1.545, 7.860, 5.359]),
                    'greedy': ([0.385, 29.185, 60.558, 75.195, 81.395, 87.895],
                               [0.021, 1.212, 2.087, 1.613, 3.184, 2.317]),
                    'pi': ([0.378, 23.500, 44.471, 59.473, 70.519, 78.268],
                           [0.017, 1.006, 0.973, 1.059, 1.501, 0.863]),
                    'thompson': ([0.386, 25.513, 53.571, 69.967, 81.050, 87.947],
                                 [0.030, 1.179, 0.514, 0.938, 0.238, 0.201]),
                    'ucb': ([0.392, 28.151, 62.949, 83.239, 91.105, 94.764],
                            [0.030, 0.321, 0.929, 0.701, 0.132, 0.118])},
         'smis': {'ei': ([0.394, 22.063, 40.365, 58.145, 66.830, 75.390],
                         [0.033, 1.925, 2.207, 1.545, 7.858, 5.355]),
                  'greedy': ([0.384, 29.165, 60.514, 75.141, 81.341, 87.851],
                             [0.022, 1.213, 2.087, 1.613, 3.184, 2.327]),
                  'pi': ([0.378, 23.490, 44.446, 59.433, 70.473, 78.219],
                         [0.017, 1.010, 0.972, 1.057, 1.499, 0.861]),
                  'thompson': ([0.386, 25.493, 53.534, 69.916, 80.993, 87.906],
                               [0.030, 1.178, 0.515, 0.936, 0.234, 0.199]),
                  'ucb': ([0.391, 28.129, 62.905, 83.183, 91.061, 94.739],
                          [0.029, 0.323, 0.926, 0.700, 0.129, 0.117])}},
 'nn': {'avg': {'ei': ([60.506, 94.133, 95.036, 95.889, 96.540, 96.903],
                       [0.079, 0.237, 0.376, 0.061, 0.315, 0.162]),
                'greedy': ([60.496, 94.842, 97.097, 98.104, 98.633, 98.942],
                           [0.024, 0.109, 0.037, 0.025, 0.058, 0.076]),
                'pi': ([60.496, 94.195, 95.179, 96.362, 96.669, 97.075],
                       [0.029, 0.103, 0.020, 0.208, 0.164, 0.167]),
                'thompson': ([60.475, 94.145, 96.768, 97.878, 98.488, 98.923],
                             [0.083, 0.106, 0.039, 0.025, 0.068, 0.055]),
                'ucb': ([60.437, 94.247, 96.720, 97.910, 98.566, 98.646],
                        [0.038, 0.404, 0.146, 0.030, 0.015, 0.074])},
        'scores': {'ei': ([0.399, 20.869, 25.993, 32.073, 38.081, 41.799],
                          [0.011, 1.220, 2.474, 0.531, 3.171, 1.754]),
                   'greedy': ([0.395, 25.126, 44.356, 58.696, 68.289, 74.691],
                              [0.021, 0.780, 0.432, 0.337, 0.905, 1.441]),
                   'pi': ([0.415, 21.327, 26.951, 36.413, 39.425, 43.917],
                          [0.027, 0.470, 0.118, 2.067, 1.648, 2.065]),
                   'thompson': ([0.416, 20.933, 40.499, 54.969, 65.277, 73.772],
                                [0.021, 0.554, 0.404, 0.510, 1.314, 1.230]),
                   'ucb': ([0.395, 21.653, 40.113, 55.651, 66.853, 68.367],
                           [0.017, 2.239, 1.661, 0.520, 0.254, 1.449])},
        'smis': {'ei': ([0.399, 20.844, 25.964, 32.041, 38.043, 41.758],
                        [0.011, 1.220, 2.470, 0.530, 3.167, 1.751]),
                 'greedy': ([0.395, 25.100, 44.319, 58.648, 68.234, 74.630],
                            [0.021, 0.779, 0.434, 0.333, 0.901, 1.440]),
                 'pi': ([0.415, 21.309, 26.930, 36.387, 39.396, 43.883],
                        [0.027, 0.470, 0.119, 2.067, 1.647, 2.067]),
                 'thompson': ([0.416, 20.913, 40.463, 54.925, 65.225, 73.713],
                              [0.021, 0.559, 0.405, 0.507, 1.313, 1.229]),
                 'ucb': ([0.395, 21.632, 40.079, 55.607, 66.797, 68.311],
                         [0.017, 2.235, 1.662, 0.520, 0.253, 1.448])}},
 'rf': {'avg': {'ei': ([60.494, 91.279, 92.061, 92.995, 94.211, 95.468],
                       [0.056, 0.185, 0.220, 0.097, 0.165, 0.592]),
                'greedy': ([60.527, 94.457, 96.494, 97.626, 98.329, 98.789],
                           [0.039, 0.119, 0.239, 0.156, 0.179, 0.130]),
                'pi': ([60.491, 91.360, 91.996, 93.007, 93.592, 95.032],
                       [0.051, 0.082, 0.313, 0.594, 0.698, 0.710]),
                'thompson': ([60.466, 93.764, 96.031, 97.454, 98.158, 98.779],
                             [0.031, 0.105, 0.090, 0.068, 0.155, 0.100]),
                'ucb': ([60.457, 93.274, 94.043, 95.781, 96.885, 97.465],
                        [0.029, 0.374, 0.430, 0.500, 0.651, 0.582])},
        'scores': {'ei': ([0.384, 11.603, 13.287, 16.153, 21.417, 29.141],
                          [0.029, 0.825, 0.730, 0.341, 0.636, 4.427]),
                   'greedy': ([0.375, 21.977, 37.721, 51.615, 62.747, 71.376],
                              [0.015, 0.850, 2.496, 2.075, 2.715, 2.077]),
                   'pi': ([0.373, 11.414, 13.066, 16.087, 18.525, 26.371],
                          [0.021, 0.338, 1.070, 2.637, 3.363, 4.693]),
                   'thompson': ([0.437, 19.010, 33.323, 49.008, 59.905, 71.694],
                                [0.031, 0.531, 0.814, 0.941, 2.611, 1.943]),
                   'ucb': ([0.373, 17.559, 20.793, 31.487, 42.086, 49.165],
                           [0.027, 1.650, 2.073, 4.080, 7.028, 7.737])},
        'smis': {'ei': ([0.384, 11.599, 13.283, 16.147, 21.407, 29.125],
                        [0.029, 0.824, 0.730, 0.339, 0.635, 4.422]),
                 'greedy': ([0.374, 21.957, 37.689, 51.573, 62.699, 71.321],
                            [0.014, 0.850, 2.494, 2.076, 2.715, 2.075]),
                 'pi': ([0.372, 11.407, 13.059, 16.077, 18.514, 26.357],
                        [0.022, 0.339, 1.071, 2.637, 3.364, 4.693]),
                 'thompson': ([0.437, 18.996, 33.295, 48.973, 59.863, 71.644],
                              [0.031, 0.530, 0.812, 0.939, 2.609, 1.941]),
                 'ucb': ([0.373, 17.549, 20.780, 31.467, 42.059, 49.133],
                         [0.027, 1.649, 2.071, 4.074, 7.022, 7.728])}}}

In [None]:
all_results = {
    '10k': {
        1.0: {'online': E10k_online,
              'retrain': E10k_retrain,
              'random': E10k_random},
        'size': 10560,
        'topk': 100,
        'y_min': 75
    },
    '50k': {
        1.0: {'online': E50k_online,
              'retrain': E50k_retrain,
              'random': E50k_random},
        'size': 50240,
        'topk': 500,
        'y_min': 75
    },
    'HTS': {
        0.1: {'online': HTS_001_online,
              'retrain': HTS_001_retrain,
              'random': HTS_001_random},
        0.2: {'online': HTS_002_online,
              'retrain': HTS_002_retrain,
              'random': HTS_002_random},
        0.4: {'online': HTS_004_online,
              'retrain': HTS_004_retrain,
              'random': HTS_004_random},
        'size': 2141514,
        'topk': 1000,
        'y_min': 75
    },
    'AmpC': {
        0.1: {'online': AmpC_001_online,
              'retrain': AmpC_001_retrain,
              'random': AmpC_001_random},
        0.2: {'online': AmpC_002_online,
              'retrain': AmpC_002_retrain,
              'random': AmpC_002_random},
        0.4: {'online': AmpC_004_online,
              'retrain': AmpC_004_retrain,
              'random': AmpC_004_random},
        'size': 99459561,
        'topk': 50000,
        'y_min': 45
    }
}

In [None]:
# 0.1/0.1 split, top-1k
HTS_convergence = {
 'mpn': {'average': [76.890, 95.170, 98.299, 98.927, 99.275, 99.415, 99.539, 99.629],
         'smis': [0.0, 22.4, 50.0, 62.9, 72.4, 77.1, 81.4, 84.5],
         'scores': [0.0, 22.8, 51.8, 65.4, 76.0, 81.1, 85.9, 88.7]},
 'nn': {'average': [76.911, 94.357, 97.226, 98.204, 98.810, 99.018, 99.329, 99.445, 99.530],
        'smis': [0.0, 14.9, 34.6, 48.3, 60.4, 65.8, 73.9, 77.8, 80.6],
        'scores': [0.0, 15.7, 36.3, 50.7, 63.2, 69.3, 78.0, 82.2, 85.4]},
 'rf': {'average': [77.008, 92.383, 96.318, 97.508, 98.300, 98.766, 98.946, 99.154, 99.264, 99.320],
        'smis': [0.0, 9.0, 27.0, 37.6, 48.6, 57.9, 62.4, 67.2, 70.0, 71.9],
        'scores': [0.0, 9.2, 27.5, 38.8, 50.4, 60.4, 65.1, 70.7, 73.8, 75.8]}}

In [None]:
HTS_common_smis_by_iter_online = {
    0.1: {'mpn': {'ei': [0, 287, 410, 475, 542, 557],
                  'greedy': [0, 314, 359, 567, 591, 658],
                  'pi': [0, 291, 458, 500, 533, 542],
                  'thompson': [0, 8, 121, 365, 492, 616],
                  'ucb': [0, 320, 448, 572, 589, 690]},
          'nn': {'ei': [0, 76, 69, 54, 57, 77],
                 'greedy': [0, 202, 314, 464, 569, 597],
                 'pi': [0, 58, 105, 289, 309, 312],
                 'thompson': [0, 7, 68, 99, 167, 182],
                 'ucb': [0, 76, 98, 127, 127, 180]},
          'rf': {'ei': [0, 0, 0, 0, 0, 3],
                 'greedy': [0, 20, 41, 109, 160, 187],
                 'pi': [0, 0, 3, 21, 85, 98],
                 'thompson': [0, 0, 0, 7, 57, 109],
                 'ucb': [0, 10, 87, 82, 82, 85]}},
    0.2: {'mpn': {'ei': [0, 404, 551, 636, 640, 651],
                 'greedy': [0, 362, 577, 710, 727, 748],
                 'pi': [0, 393, 569, 646, 651, 644],
                 'thompson': [0, 15, 290, 551, 643, 722],
                 'ucb': [0, 385, 583, 701, 708, 746]},
          'nn': {'ei': [0, 145, 146, 142, 154, 162],
                 'greedy': [0, 320, 490, 647, 733, 767],
                 'pi': [0, 151, 154, 147, 159, 171],
                 'thompson': [0, 17, 123, 291, 456, 552],
                 'ucb': [0, 220, 384, 376, 495, 502]},
          'rf': {'ei': [0, 1, 2, 9, 12, 13],
                 'greedy': [0, 137, 219, 485, 529, 566],
                 'pi': [0, 1, 6, 20, 48, 58],
                 'thompson': [0, 0, 5, 40, 186, 241],
                 'ucb': [0, 0, 25, 68, 100, 300]}},
    0.4: {'mpn': {'ei': [0, 567, 594, 654, 689, 713],
                  'greedy': [0, 649, 709, 831, 888, 918],
                  'pi': [0, 583, 622, 680, 715, 730],
                  'thompson': [0, 74, 505, 745, 854, 912],
                  'ucb': [0, 646, 749, 845, 900, 930]},
          'nn': {'ei': [0, 223, 251, 364, 417, 446],
                 'greedy': [0, 519, 719, 759, 817, 866],
                 'pi': [0, 245, 251, 338, 291, 287],
                 'thompson': [0, 66, 206, 320, 401, 429],
                 'ucb': [0, 385, 460, 595, 671, 772]},
          'rf': {'ei': [0, 3, 9, 65, 118, 210],
                 'greedy': [0, 432, 538, 628, 666, 658],
                 'pi': [0, 75, 103, 135, 189, 210],
                 'thompson': [0, 2, 119, 273, 375, 384],
                 'ucb': [0, 23, 50, 187, 381, 420]}}
}

HTS_common_smis_by_iter_retrain = {
    0.1: {'mpn': {'ei': [0, 253, 446, 574, 626, 689],
                  'greedy': [0, 303, 576, 676, 758, 791],
                  'pi': [0, 298, 512, 616, 656, 715],
                  'thompson': [0, 6, 115, 283, 440, 601],
                  'ucb': [0, 296, 493, 660, 743, 773]},
          'nn': {'ei': [0, 74, 150, 204, 313, 366],
                 'greedy': [0, 201, 478, 668, 734, 779],
                 'pi': [0, 80, 237, 331, 367, 394],
                 'thompson': [0, 10, 180, 383, 562, 649],
                 'ucb': [0, 138, 368, 539, 678, 679]},
          'rf': {'ei': [0, 0, 0, 4, 8, 23],
                 'greedy': [0, 5, 51, 184, 269, 452],
                 'pi': [0, 0, 6, 74, 103, 116],
                 'thompson': [0, 0, 34, 209, 267, 348],
                 'ucb': [0, 0, 119, 193, 293, 309]}},
    0.2: {'mpn': {'ei': [0, 263, 570, 675, 704, 776],
                  'greedy': [0, 498, 708, 755, 820, 894],
                  'pi': [0, 327, 622, 688, 737, 797],
                  'thompson': [0, 10, 279, 549, 643, 757],
                  'ucb': [0, 475, 711, 762, 843, 894]},
          'nn': {'ei': [0, 140, 301, 361, 435, 452],
                 'greedy': [0, 434, 704, 815, 828, 871],
                 'pi': [0, 132, 275, 353, 428, 453],
                 'thompson': [0, 26, 276, 516, 682, 748],
                 'ucb': [0, 243, 567, 733, 753, 800]},
          'rf': {'ei': [0, 1, 23, 97, 106, 137],
                 'greedy': [0, 42, 269, 387, 604, 724],
                 'pi': [0, 7, 25, 77, 86, 106],
                 'thompson': [0, 0, 34, 303, 362, 478],
                 'ucb': [0, 8, 252, 429, 478, 510]}},
    0.4: {'mpn': {'ei': [0, 587, 667, 795, 867, 902],
                  'greedy': [0, 618, 804, 911, 953, 963],
                  'pi': [0, 534, 709, 816, 871, 915],
                  'thompson': [0, 74, 531, 752, 877, 919],
                  'ucb': [0, 632, 816, 925, 943, 935]},
          'nn': {'ei': [0, 238, 569, 575, 675, 715],
                 'greedy': [0, 570, 827, 869, 951, 978],
                 'pi': [0, 268, 425, 473, 538, 622],
                 'thompson': [0, 84, 474, 722, 882, 939],
                 'ucb': [0, 417, 709, 795, 855, 926]},
          'rf': {'ei': [0, 27, 46, 95, 186, 226],
                 'greedy': [0, 189, 441, 733, 729, 795],
                 'pi': [0, 46, 101, 165, 264, 331],
                 'thompson': [0, 4, 179, 399, 556, 634],
                 'ucb': [0, 81, 312, 438, 478, 554]}}
}

In [None]:
E10k_union_smis_retrain = {
 'mpn': {'ei': [519, 861, 1140, 1311, 1471, 1600],
         'greedy': [514, 772, 1042, 1249, 1403, 1578],
         'pi': [508, 800, 1069, 1236, 1393, 1536],
         'thompson': [512, 895, 1195, 1435, 1648, 1822],
         'ucb': [506, 828, 1100, 1305, 1478, 1640]},
 'nn': {'ei': [519, 923, 1186, 1452, 1765, 2070],
        'greedy': [515, 910, 1172, 1395, 1570, 1732],
        'pi': [519, 926, 1234, 1492, 1788, 2001],
        'thompson': [513, 918, 1218, 1456, 1627, 1775],
        'ucb': [515, 926, 1179, 1412, 1583, 1744]},
 'rf': {'ei': [510, 985, 1360, 1685, 1984, 2241],
        'greedy': [516, 941, 1281, 1520, 1742, 1919],
        'pi': [510, 968, 1319, 1568, 1820, 2075],
        'thompson': [514, 999, 1427, 1835, 2172, 2507],
        'ucb': [519, 996, 1370, 1657, 1868, 2095]}
}

E50k_union_smis_retrain = {
 'mpn': {'ei': [2442, 3777, 4680, 5500, 6329, 7119],
         'greedy': [2438, 3686, 4585, 5461, 6223, 7032],
         'pi': [2437, 3729, 4654, 5465, 6267, 7161],
         'thompson': [2434, 4344, 5622, 6411, 7247, 7870],
         'ucb': [2440, 3668, 4637, 5411, 6258, 7023]},
 'nn': {'ei': [2443, 4015, 5104, 6414, 7571, 8679],
        'greedy': [2437, 3961, 4769, 5431, 6022, 6544],
        'pi': [2436, 3959, 5150, 6560, 7648, 8761],
        'thompson': [2462, 4227, 5170, 5818, 6437, 6959],
        'ucb': [2434, 4053, 4925, 5580, 6164, 6696]},
 'rf': {'ei': [2451, 4602, 6181, 7481, 8732, 9893],
        'greedy': [2445, 4131, 4934, 5642, 6428, 7106],
        'pi': [2438, 4386, 5628, 6950, 8240, 9483],
        'thompson': [2445, 4687, 6584, 8239, 9587, 10836],
        'ucb': [2430, 4345, 5578, 6950, 8034, 9116]}
}

In [None]:
HTS_intersection_smis_online = {
    0.1: {'mpn': {'ei': [10686, 15707, 20790, 26448, 32863, 38821],
                  'greedy': [10682, 15360, 21985, 26185, 32611, 37807],
                  'pi': [10691, 15775, 21199, 26991, 32398, 38013],
                  'thompson': [10689, 19789, 25916, 30591, 35279, 39820],
                  'ucb': [10682, 15718, 21350, 25570, 32154, 36446]},
         'nn': {'ei': [10686, 17636, 27806, 35279, 44661, 52823],
                 'greedy': [10693, 15984, 21439, 26856, 31478, 36426],
                 'pi': [10686, 17622, 26732, 32948, 42245, 48641],
                 'thompson': [10691, 19028, 26577, 33758, 40633, 48483],
                 'ucb': [10683, 17515, 25787, 33519, 41550, 48461]},
         'rf': {'ei': [10684, 20463, 30709, 40230, 49149, 57727],
                 'greedy': [10687, 18885, 26634, 32087, 40330, 46737],
                 'pi': [10692, 20342, 30102, 38600, 47056, 55368],
                 'thompson': [10681, 20919, 30827, 40026, 49158, 57738],
                 'ucb': [10689, 19597, 28477, 37797, 46996, 55803]}},
    0.2: {'mpn': {'ei': [21330, 30905, 40827, 51900, 65390, 76970],
                  'greedy': [21331, 31115, 43322, 51150, 62549, 70862],
                  'pi': [21347, 31009, 40982, 52101, 63631, 75153],
                  'thompson': [21332, 37797, 47820, 55846, 64280, 72898],
                  'ucb': [21345, 30789, 42967, 51034, 62016, 71688]},
         'nn': {'ei': [21339, 33855, 54021, 71242, 85249, 97256],
                 'greedy': [21350, 30195, 39989, 49518, 57577, 66290],
                 'pi': [21325, 33492, 53654, 70341, 88444, 97081],
                 'thompson': [21336, 36534, 48142, 59066, 68633, 78866],
                 'ucb': [21348, 32184, 44552, 56964, 68075, 82882]},
         'rf': {'ei': [21334, 39097, 59278, 76728, 93197, 112255],
                'greedy': [21342, 33917, 48247, 59901, 74441, 87288],
                'pi': [21343, 39641, 59651, 76862, 93724, 110352],
                'thompson': [21322, 41135, 60442, 76552, 93343, 109832],
                'ucb': [21330, 38829, 56877, 71586, 89522, 104610]}},
    0.4: {'mpn': {'ei': [42472, 60886, 85271, 107016, 128315, 150345],
                  'greedy': [42497, 59653, 81235, 97256, 115706, 130478],
                  'pi': [42483, 60472, 84103, 108413, 130376, 150972],
                  'thompson': [42508, 71672, 90374, 105287, 120655, 133771],
                  'ucb': [42490, 60021, 79660, 96576, 116595, 132428]},
          'nn': {'ei': [42522, 65769, 105401, 125242, 159336, 171939],
                 'greedy': [42489, 57089, 72717, 89684, 106228, 123014],
                 'pi': [42496, 65302, 105648, 131060, 158766, 172462],
                 'thompson': [42519, 68998, 92833, 119993, 144478, 171245],
                 'ucb': [42494, 61335, 89349, 107954, 136428, 150832]},
          'rf': {'ei': [42508, 75895, 114197, 142269, 174207, 205535],
                 'greedy': [42505, 63177, 89032, 109855, 137817, 162966],
                 'pi': [42512, 72402, 105881, 136261, 170318, 200344],
                 'thompson': [42473, 77295, 113569, 141028, 170529, 199158],
                 'ucb': [42535, 70928, 107662, 134741, 169029, 200558]}}
}

# Per-iteration molecule counts for the HTS library with full model retraining,
# keyed as [split (% of the library per batch)][model][metric] -> a list of 6
# values, one per iteration (plotted at iterations 0-5 by gen_HTS_smis_figure).
# NOTE(review): the iteration-0 values (~5x one batch, e.g. ~42.5k at the 0.4
# split) look like union sizes rather than intersection counts, and they lie far
# outside the [0, 1000] y-range gen_HTS_smis_figure uses for
# mode='intersection' -- confirm this data is labelled correctly.
HTS_intersection_smis_retrain = {
    0.1: {'mpn': {'ei': [10688, 15720, 20010, 23635, 28248, 32577],
                  'greedy': [10691, 15694, 19327, 23105, 26720, 29849],
                  'pi': [10691, 15754, 19792, 23738, 27741, 31631],
                  'thompson': [10697, 19679, 25714, 30499, 34974, 38645],
                  'ucb': [10686, 15626, 19579, 23216, 26540, 29878]},
          'nn': {'ei': [10695, 17595, 24872, 33185, 41013, 49236],
                 'greedy': [10690, 16108, 19271, 21993, 24378, 27369],
                 'pi': [10686, 17580, 24365, 32492, 40787, 48178],
                 'thompson': [10675, 18890, 23187, 26574, 29027, 31806],
                 'ucb': [10686, 17165, 21434, 24970, 27540, 31772]},
          'rf': {'ei': [10681, 21005, 30020, 38782, 47135, 55790],
                 'greedy': [10686, 19151, 24234, 28787, 31887, 35929],
                 'pi': [10690, 20325, 27737, 35464, 44363, 53039],
                 'thompson': [10688, 20988, 29592, 37123, 44302, 50846],
                 'ucb': [10692, 20641, 27850, 35624, 41939, 50688]}},
    0.2: {'mpn': {'ei': [21341, 32616, 40030, 48217, 57254, 65806],
                  'greedy': [21353, 30415, 37947, 44542, 50942, 57041],
                  'pi': [21338, 32203, 39701, 47931, 56763, 64843],
                  'thompson': [21343, 38634, 48668, 56850, 64222, 71360],
                  'ucb': [21344, 30433, 37861, 44686, 51380, 57881]},
          'nn': {'ei': [21330, 33790, 48847, 66135, 81600, 97863],
                 'greedy': [21352, 29598, 34802, 40432, 45217, 50382],
                 'pi': [21336, 33700, 49855, 66010, 81079, 95444],
                 'thompson': [21325, 36307, 43935, 49119, 53282, 58230],
                 'ucb': [21348, 32032, 38783, 45160, 52772, 59558]},
          'rf': {'ei': [21336, 39771, 57288, 73902, 91191, 107190],
                 'greedy': [21334, 35859, 45486, 52240, 56539, 62690],
                 'pi': [21352, 38455, 55276, 71233, 88634, 106287],
                 'thompson': [21338, 40888, 56793, 69270, 80600, 90675],
                 'ucb': [21325, 37870, 51109, 64817, 78571, 91339]}},
    0.4: {'mpn': {'ei': [42510, 60711, 79150, 96571, 112905, 129702],
                  'greedy': [42490, 60267, 73209, 85490, 97573, 108971],
                  'pi': [42495, 61570, 78156, 95570, 113005, 127987],
                  'thompson': [42496, 72017, 87698, 102635, 117048, 130359],
                  'ucb': [42499, 59736, 73507, 86779, 99818, 111995]},
          'nn': {'ei': [42500, 65546, 98218, 128223, 152412, 178024],
                 'greedy': [42500, 57106, 66708, 75909, 85261, 95448],
                 'pi': [42462, 64925, 97658, 128572, 153130, 181438],
                 'thompson': [42490, 67841, 80125, 88336, 97228, 105883],
                 'ucb': [42477, 61066, 74259, 84028, 101318, 111424]},
         'rf': {'ei': [42480, 71782, 102182, 135475, 166549, 197109],
                'greedy': [42500, 64980, 82142, 94578, 106366, 116845],
                'pi': [42510, 74537, 106604, 136851, 166823, 195612],
                'thompson': [42520, 77848, 101661, 121834, 140229, 158613],
                'ucb': [42491, 71171, 95932, 120656, 145471, 166648]}}
}

In [None]:
# Lookup table for the union/intersection figures: counts of molecules explored
# (union) or shared (intersection) among runs. Indexed as
# smis_results[library][mode] for the 10k/50k libraries and as
# smis_results['HTS'][mode][training] for HTS.
# NOTE(review): the 10k and 50k entries only provide 'union' data, so
# gen_10k50k_smis_figure(mode='intersection') will raise a KeyError.
smis_results ={
    '10k': {
        'union': E10k_union_smis_retrain,
    },
    '50k': {
        'union': E50k_union_smis_retrain,
    },
    'HTS' : {
        'intersection': {
            'online': HTS_intersection_smis_online,
            'retrain': HTS_intersection_smis_retrain
        },
        'union': {
            'online': HTS_union_smis_online,
            'retrain': HTS_union_smis_retrain
        }
    }
}

In [None]:
# One-shot HTS results keyed by the fraction of the library explored at the
# first point, in percent. generate_one_shot_figure plots key 2 at 2% and 2.4%
# of the library, and key 0.4 at 0.4% and 2.4%.
one_shot_results = {
    2: HTS_02_004,
    0.4: HTS_004_02
}

#### RUN ALL CELLS ABOVE ME

Now we're ready to make some figures and present the data more neatly.

## Writing CSVs of the full data

These two functions write formatted CSVs of the full data loaded above.

The first function writes the final results for a given library and batch size, including both online and full-retraining results for the top-k scores, top-k SMILES, and top-k average metrics.

In [None]:
def write_final_results_csv(library, batch_size):
    """Write the final-iteration results of one experiment to a CSV file.

    For each training mode ('retrain', 'online'), each model in MODELS and each
    acquisition metric in METRICS, the last entry of the top-k scores, top-k
    SMILES and top-k average traces is formatted as "mean (s.d.)".

    Parameters
    ----------
    library : str
        key into ``all_results`` (e.g. '10k', '50k', 'HTS', 'AmpC')
    batch_size : float
        batch-size key under ``all_results[library]`` (e.g. 1.0, 0.4)

    Returns
    -------
    pd.DataFrame
        the formatted table indexed by (Training, Model, Metric). The same
        table, index included, is written to
        ``f'{library}_{batch_size}_final_results.csv'``.
    """
    # display labels for metrics that are not simply upper-cased
    metric_labels = {'greedy': 'Greedy', 'thompson': 'TS'}

    lib_results = all_results[library][batch_size]
    rows = []
    for training in ['retrain', 'online']:
        results = lib_results[training]
        for model in MODELS:
            for metric in METRICS:
                # each entry is (means, standard deviations) per iteration
                scores = results[model]['scores'][metric]
                smis = results[model]['smis'][metric]
                avg = results[model]['avg'][metric]
                rows.append({
                    'Training': training,
                    'Model': model.upper(),
                    'Metric': metric_labels.get(metric, metric.upper()),
                    # raw strings: '\p' is an invalid escape sequence in a
                    # plain string literal (SyntaxWarning on Python 3.12+)
                    r'Scores ($\pm$ s.d.)': f'{scores[0][-1]:0.1f} ({scores[1][-1]:0.1f})',
                    r'SMILES ($\pm$ s.d.)': f'{smis[0][-1]:0.1f} ({smis[1][-1]:0.1f})',
                    r'Average ($\pm$ s.d.)': f'{avg[0][-1]:0.2f} ({avg[1][-1]:0.2f})'
                })
    df = pd.DataFrame(rows).set_index(['Training', 'Model', 'Metric'])
    df.to_csv(f'{library}_{batch_size}_final_results.csv')
    return df

The second function will write the results of each iteration for a given library, batch size, and scoring metric (including both online and full retraining results).

In [None]:
def write_results_csv_by_iter(library, batch_size, score_mode):
    """Write the per-iteration results of one experiment and score mode to CSV.

    Parameters
    ----------
    library : str
        key into ``all_results`` (e.g. '10k', '50k', 'HTS', 'AmpC')
    batch_size : float
        batch-size key under ``all_results[library]`` (e.g. 1.0, 0.4)
    score_mode : str
        which results to tabulate: 'scores', 'smis' or 'avg'

    Returns
    -------
    pd.DataFrame
        one row per (Training, Model, Metric), with one "mean (s.d.)" column
        per iteration. The table, index included, is written to
        ``f'{library}_{batch_size}_{score_mode}_results_full.csv'``.
    """
    # display labels for metrics that are not simply upper-cased
    metric_labels = {'greedy': 'Greedy', 'thompson': 'TS'}

    lib_results = all_results[library][batch_size]
    rows = []
    for training in ['retrain', 'online']:
        results = lib_results[training]
        for model in MODELS:
            for metric in METRICS:
                means = results[model][score_mode][metric][0]
                sds = results[model][score_mode][metric][1]
                row = {
                    'Training': training,
                    'Model': model.upper(),
                    'Metric': metric_labels.get(metric, metric.upper()),
                }
                for i, (mean, sd) in enumerate(zip(means, sds)):
                    # raw f-string: '\p' is an invalid escape sequence otherwise
                    row[rf'iter_{i} ($\pm$ s.d.)'] = f'{mean:0.1f} ({sd:0.1f})'
                rows.append(row)
    df = pd.DataFrame(rows).set_index(['Training', 'Model', 'Metric'])
    # keep the index: it holds the Training/Model/Metric labels, which
    # index=False would silently drop from the CSV
    df.to_csv(f'{library}_{batch_size}_{score_mode}_results_full.csv')
    return df

Running the cell below will generate CSVs of both the full data of each experiment (one per score metric) and the abbreviated (final) results of each experiment. If you only want one or the other, comment out the appropriate lines.

In [None]:
# (library, batch size) pairs to export: the 10k/50k libraries at a batch size
# of 1.0, HTS/AmpC at batch sizes 0.4, 0.2 and 0.1
csv_jobs = [(library, 1.0) for library in ['10k', '50k']]
csv_jobs += [(library, batch_size)
             for library in ['HTS', 'AmpC']
             for batch_size in [0.4, 0.2, 0.1]]

for library, batch_size in csv_jobs:
    # abbreviated (final-iteration) table
    write_final_results_csv(library, batch_size)
    # full per-iteration tables, one per score mode
    for score_mode in ['scores', 'smis', 'avg']:
        write_results_csv_by_iter(library, batch_size, score_mode)

## Generating the main text figures

### Recreating the 10k and 50k figures
The function below generates, for a given library and split, a three-panel figure. Each panel corresponds to one surrogate model architecture and contains a trace for each acquisition metric, plus a random baseline. This function was used to generate the 10k and 50k figures in the main text. By default, it shows only the results of full model retraining. If `si_fig` is set to `True`, it instead shows the results of online model training at full opacity, along with faded traces of full model retraining.

In [None]:
def gen_metrics_figure(library, split, si_fig=False):
    """Build a one-row figure with one panel per surrogate model in MODELS.

    Each panel plots the percentage of top-k scores found against the number of
    ligands explored, with one trace per acquisition metric in METRICS plus a
    random-acquisition baseline.

    Parameters
    ----------
    library : str
        key into ``all_results`` (e.g. '10k', '50k', 'HTS', 'AmpC')
    split : float
        batch size as a percentage of the library (e.g. 1., 0.4)
    si_fig : bool, default=False
        if False, plot only the full-retraining results. If True, plot the
        online-training results at full opacity over faded full-retraining
        traces (drawn without error bars or legend entries).

    Returns
    -------
    plotly.graph_objects.Figure
    """
    score = 'scores'

    lib_results = all_results[library]
    retrain_results = lib_results[split]['retrain']
    online_results = lib_results[split]['online']
    random_results = lib_results[split]['random']
    size = lib_results['size']
    topk = lib_results['topk']

    results_series = make_subplots(
        rows=1, cols=len(MODELS),
        shared_xaxes=True, shared_yaxes=True,
        x_title='Number of Ligands Explored', y_title=f'Percentage of Top-{topk} Scores Found',
        subplot_titles=[model.upper() for model in MODELS]
    )

    # number of ligands explored after each of the 6 iterations
    # (split is a percentage of the library)
    xs = [int(size * split/100 * n) for n in range(1, 7)]

    for i, model in enumerate(MODELS):
        col = i + 1
        # legend entries are contributed by the first panel only
        first_panel = model == MODELS[0]
        for j, metric in enumerate(METRICS):
            # full retrain trace
            ys, y_sds = retrain_results[model][score][metric]
            results_series.add_trace(go.Scatter(
                x=xs, y=ys, opacity=0.33 if si_fig else 1.,
                error_y=dict(type='data', array=y_sds, visible=not si_fig),
                marker=dict(color=METRIC_COLORS[j]),
                mode='lines+markers', name=METRIC_NAMES[metric],
                legendgroup=metric, showlegend=first_panel and not si_fig
            ), row=1, col=col)

            # online train trace (SI figures only)
            if si_fig:
                ys, y_sds = online_results[model][score][metric]
                results_series.add_trace(go.Scatter(
                    x=xs, y=ys,
                    error_y=dict(type='data', array=y_sds, visible=True),
                    marker=dict(color=METRIC_COLORS[j]),
                    mode='lines+markers', name=METRIC_NAMES[metric],
                    legendgroup=metric, showlegend=first_panel
                ), row=1, col=col)

        ys, y_sds = random_results[score]
        results_series.add_trace(go.Scatter(
            x=xs, y=ys,
            error_y=dict(type='data', array=y_sds, visible=True),
            marker=dict(color='slategray'),
            mode='lines+markers',
            name='random', legendgroup='random',
            showlegend=first_panel,
        ), row=1, col=col)

        results_series.update_xaxes(row=1, col=col,
                                    rangemode='tozero', nticks=10)

    border_params = dict(
        showgrid=True, zeroline=False, showticklabels=True,
        visible=True, mirror=True, linewidth=2,
        linecolor='black'
    )
    results_series.update_xaxes(**border_params, tickangle=-35)
    results_series.update_yaxes(**border_params)
    results_series.update_traces(
        marker=dict(symbol='circle', line_width=2, size=7.5),
    )

    annotations = results_series['layout']['annotations']
    for annotation in annotations:
        annotation['font'] = dict(color='black', size=20, family='sans-serif')
    # presumably the shared x-axis title (make_subplots adds x_title, then
    # y_title, after the subplot titles): push it below the angled tick labels
    annotations[-2]['yanchor'] = 'bottom'
    annotations[-2]['y'] = -0.15

    HEIGHT = 500
    results_series.update_layout(
        legend_title_text='Metric',
        height=HEIGHT, width=HEIGHT/1.5*3,
        font=dict(color='black', size=18, family='sans-serif'),
    )
    return results_series

def gen_10k50k_figure(library, si_fig):
    """Return the per-model metrics figure for the 10k or 50k library.

    These libraries are plotted at a batch size of 1% (``split=1.``). Pass
    ``si_fig=True`` for the SI variant of the figure.
    """
    one_percent = 1.
    return gen_metrics_figure(library, one_percent, si_fig=si_fig)

The below cell will generate the 10k and 50k figures from the main text. Setting the second argument to `True` in the function calls will produce the 10k and 50k SI figures.

In [None]:
# main-text figures (pass True instead of False for the SI variants)
for library in ('10k', '50k'):
    fig = gen_10k50k_figure(library, False)
    fig.show()

### Recreating the HTS and AmpC figures

The HTS and AmpC figures use a different design. Each panel corresponds to a given initialization/exploration batch size and contains the results of all three models using a greedy acquisition metric. The function below generates those figures; to run it, go to the cell after it.

In [None]:
def gen_HTSAmpC_figure_main(library):
    """Build the main-text HTS/AmpC figure: one panel per batch size in SPLITS.

    Each panel shows the percentage of top-k scores found by every model in
    MODELS using greedy acquisition with full model retraining, alongside the
    random-acquisition baseline.

    Parameters
    ----------
    library : str
        key into ``all_results`` (e.g. 'HTS' or 'AmpC')

    Returns
    -------
    plotly.graph_objects.Figure
    """
    score = 'scores'
    split_results = all_results[library]
    size = split_results['size']
    topk = split_results['topk']

    results_series = make_subplots(
        rows=1, cols=len(SPLITS), shared_yaxes=True,
        x_title='Number of Ligands Explored', y_title=f'Percentage of Top-{topk} Scores Found',
        subplot_titles=[f'{split}%' for split in SPLITS])

    for i, split in enumerate(SPLITS):
        col = i + 1
        # only full-retraining results are shown in the main-text figure
        retrain_results = split_results[split]['retrain']
        random_results = split_results[split]['random']
        # number of ligands explored after each of the 6 iterations
        # (split is a percentage of the library)
        xs = [int(size * split/100 * n) for n in range(1, 7)]
        # a single panel supplies the legend entries
        show_legend = split == SPLITS[-1]

        ys, y_sds = random_results[score]
        random_trace = go.Scatter(
            x=xs, y=ys,
            error_y=dict(type='data', array=y_sds, visible=True),
            mode='lines+markers',
            marker=dict(symbol='circle'),
            line=dict(color='slategray'),
            name='random', legendgroup='random',
            showlegend=show_legend,
        )
        results_series.add_trace(random_trace, row=1, col=col)

        for j, model in enumerate(MODELS):
            ys, y_sds = retrain_results[model][score]['greedy']
            split_trace = go.Scatter(
                x=xs, y=ys,
                error_y=dict(type='data', array=y_sds, visible=True),
                mode='lines+markers',
                marker=dict(symbol=MARKERS[j], color=MODEL_COLORS[j]),
                name=model.upper(), legendgroup=model,
                showlegend=show_legend
            )
            results_series.add_trace(split_trace, row=1, col=col)

        results_series.update_xaxes(row=1, col=col,
                                    rangemode='tozero', nticks=10)

    border_params = dict(
        showgrid=True, zeroline=False, showticklabels=True,
        visible=True, mirror=True, linewidth=2,
        linecolor='black'
    )
    results_series.update_xaxes(**border_params, tickangle=-35)
    results_series.update_yaxes(**border_params)
    results_series.update_traces(marker=dict(line_width=2, size=7.5))

    annotations = results_series['layout']['annotations']
    for annotation in annotations:
        annotation['font'] = dict(color='black', size=20, family='sans-serif')
    # presumably the shared x-axis title (make_subplots adds x_title, then
    # y_title, after the subplot titles): push it below the angled tick labels
    annotations[-2]['yanchor'] = 'bottom'
    annotations[-2]['y'] = -0.15

    HEIGHT = 500
    results_series.update_layout(
        legend_title_text='Model', legend_traceorder='reversed',
        height=HEIGHT, width=HEIGHT/1.5*3,
        font=dict(color='black', size=20, family='sans-serif')
    )

    return results_series

This cell will generate the HTS and AmpC figures from the main text.

In [None]:
# main-text AmpC and HTS figures
for library in ('AmpC', 'HTS'):
    fig = gen_HTSAmpC_figure_main(library)
    fig.show()

### Recreating the Single-iteration and Convergence Figures

The two cells below contain the functions that generate these figures, and the third cell creates the figures.

In [None]:
def generate_one_shot_figure():
    """Compare two-point ("one-shot") HTS results with the active-learning run.

    For every model in MODELS, three greedy-acquisition traces of the
    percentage of top-1000 scores found are drawn:

    - ``one_shot_results[2]`` at 2% and 2.4% of the library explored (dashed),
    - ``one_shot_results[0.4]`` at 0.4% and 2.4% of the library explored (solid),
    - ``HTS_004_retrain``, six iterations of 0.4% each (faded, no error bars).

    Only the first trace of each model gets a legend entry.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    size = 2141514  # HTS library size (hard-coded)
    fig = go.Figure()

    for i, model in enumerate(MODELS):
        # first point at 2% of the library explored, second at 2.4%
        xs = [size*0.02, size*0.024]
        ys, y_sds = one_shot_results[2][model]['scores']['greedy']
        fig.add_trace(go.Scatter(
            x=xs, y=ys,
            error_y=dict(type='data', array=y_sds, visible=True),
            line=dict(dash='dash'),
            marker=dict(symbol=MARKERS[i], color=MODEL_COLORS[i]),
            mode='lines+markers',
            name=model.upper(), legendgroup=model
        ))

        # first point at 0.4% of the library explored, second at 2.4%
        xs = [size*0.004, size*0.024]
        ys, y_sds = one_shot_results[0.4][model]['scores']['greedy']
        fig.add_trace(go.Scatter(
            x=xs, y=ys,
            error_y=dict(type='data', array=y_sds, visible=True),
            line=dict(dash='solid'),
            marker=dict(symbol=MARKERS[i], color=MODEL_COLORS[i]),
            mode='lines+markers', showlegend=False,
            name=model.upper(), legendgroup=model
        ))

        # active-learning trace: six iterations of 0.4% of the library each
        xs = [int(size * 0.4/100 * n) for n in range(1, 7)]
        ys, _ = HTS_004_retrain[model]['scores']['greedy']
        fig.add_trace(go.Scatter(
            x=xs, y=ys, opacity=0.5,
            line=dict(dash='solid'),
            marker=dict(symbol=MARKERS[i], color=MODEL_COLORS[i]),
            mode='lines+markers', showlegend=False,
            name=model.upper(), legendgroup=model
        ))

    border_params = dict(
        showgrid=True, zeroline=False, showticklabels=True,
        visible=True, mirror=True, linewidth=2,
        linecolor='black'
    )
    tick_font = dict(color='black', size=18, family='sans-serif')
    fig.update_yaxes(title_text='Percentage of Top-1000 Scores Found',
                     tickfont=tick_font, **border_params)
    fig.update_xaxes(title_text='Number of Ligands Explored',
                     rangemode='tozero', nticks=10, tickangle=-35,
                     tickfont=tick_font, **border_params)
    fig.update_traces(mode='lines+markers', marker_line_width=2, marker_size=7.5)

    HEIGHT = 500
    fig.update_layout(
        legend_title_text='Model', legend_traceorder='reversed',
        height=HEIGHT, width=HEIGHT/1.25,
        font=dict(color='black', size=16, family='sans-serif')
    )

    return fig

In [None]:
def generate_convergence_figure(score):
    """Plot the HTS convergence series, one line per surrogate model.

    Parameters
    ----------
    score : str
        which ``HTS_convergence[model]`` series to plot (e.g. 'scores'). It is
        also capitalized into the y-axis title.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    hts_size = 2141514
    tick_font = dict(color='black', size=18, family='sans-serif')
    axis_style = dict(
        showgrid=True, zeroline=False, showticklabels=True,
        visible=True, mirror=True, linewidth=2, linecolor='black',
    )

    fig = go.Figure()
    fig.update_layout(legend_title_text='Model')

    for idx, model in enumerate(MODELS):
        series = HTS_convergence[model][score]
        # the k-th point sits at k * 0.1% of the library explored
        explored = [int(hts_size * 0.001 * step) for step in range(1, len(series) + 1)]
        fig.add_trace(go.Scatter(
            x=explored, y=series,
            mode='lines+markers',
            marker=dict(symbol=MARKERS[idx], color=MODEL_COLORS[idx]),
            name=model.upper(),
        ))

    fig.update_yaxes(
        title_text=f'Percentage of Top-1000 {score.capitalize()} Found',
        tickfont=tick_font, **axis_style,
    )
    fig.update_xaxes(
        title_text='Number of Ligands Explored',
        rangemode='tozero', tickangle=-35, nticks=11,
        tickfont=tick_font, **axis_style,
    )
    fig.update_traces(mode='lines+markers', marker_line_width=2, marker_size=7.5)

    fig_height = 500
    fig.update_layout(
        legend_title_text='Model', legend_traceorder='reversed',
        height=fig_height, width=fig_height / 1.25,
        font=dict(color='black', size=16, family='sans-serif'),
    )
    return fig

In [None]:
# single-iteration (one-shot) comparison
generate_one_shot_figure().show()

# convergence of the top-1000 scores
fig = generate_convergence_figure(score='scores')
fig.show()

## Generating SI Figures

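The cell below generates the 10k, 50k, HTS, and AmpC SI figures. The 10k and 50k figures are displayed, while the HTS and AmpC figures are saved as PDFs under `figures/SI/`.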

In [None]:
import os

# SI variants of the 10k/50k figures: the second argument must be True to get
# online training plotted over faded full-retraining traces (False would just
# reproduce the main-text figures)
fig = gen_10k50k_figure('10k', True)
fig.show()
fig = gen_10k50k_figure('50k', True)
fig.show()

# write_image cannot create missing directories, so make sure the output exists
os.makedirs('figures/SI', exist_ok=True)
for split in [0.4, 0.2, 0.1]:
    # fractional digits of the split, e.g. 0.4 -> '4' -> 'AmpC_004_...'
    digit = str(split).split('.')[1]
    fig = gen_metrics_figure('AmpC', split, True)
    fig.write_image(f'figures/SI/AmpC_00{digit}_model_by_metric_online_with_retrain_faded.pdf')
    fig = gen_metrics_figure('HTS', split, True)
    fig.write_image(f'figures/SI/HTS_00{digit}_model_by_metric_online_with_retrain_faded.pdf')

### Union Plots

The two functions below were used to generate the union plots in the SI. The mode can also be set to `intersection`, although those figures were not used in the paper. Any metric may also be specified, but only greedy data was shown in the paper.

In [None]:
def _add_smis_panel(fig, lib_results, size, col, mode, metric, showlegend,
                    intersection_range=None):
    """Draw one library's per-model molecule counts into column ``col`` of ``fig``."""
    xs = list(range(6))
    for j, model in enumerate(MODELS):
        fig.add_trace(go.Scatter(
            x=xs, y=lib_results[model][metric],
            mode='lines+markers',
            marker=dict(symbol=MARKERS[j], color=MODEL_COLORS[j]),
            name=model.upper(), legendgroup=model,
            showlegend=showlegend
        ), row=1, col=col)

    if mode == 'union':
        # black reference lines: 5*(x+1) and (5+x) batches of 1% of the
        # library at iteration x
        ys_upper = [int(5*(x+1)*1/100*size) for x in xs]
        fig.add_trace(go.Scatter(
            x=xs, y=ys_upper,
            mode='lines', line=dict(color='black'),
            name='upper_bound', showlegend=False
        ), row=1, col=col)

        ys_lower = [int((5+x)*1/100*size) for x in xs]
        fig.add_trace(go.Scatter(
            x=xs, y=ys_lower,
            mode='lines', line=dict(color='black'),
            name='lower_bound', showlegend=False
        ), row=1, col=col)

        fig.update_yaxes(row=1, col=col, range=[0, ys_upper[-1]])
    elif mode == 'intersection' and intersection_range is not None:
        fig.update_yaxes(row=1, col=col, range=intersection_range)
    fig.update_xaxes(row=1, col=col, range=[0, 5], dtick=1)


def gen_10k50k_smis_figure(mode='union', metric='greedy'):
    """Plot per-iteration molecule counts for the 10k and 50k libraries.

    Parameters
    ----------
    mode : str, default='union'
        'union' (total molecules explored among runs) or 'intersection'
        (molecules shared among runs). ``smis_results['10k']`` and
        ``smis_results['50k']`` must contain this key, otherwise a KeyError
        is raised.
    metric : str, default='greedy'
        acquisition metric whose results are plotted

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        if ``mode`` is neither 'union' nor 'intersection'
    """
    if mode == 'union':
        y_title = 'Total Number of Molecules Explored Among Runs'
    elif mode == 'intersection':
        y_title = 'Number of Molecules Shared Among Runs'
    else:
        raise ValueError('Unsupported mode!')

    results_series = make_subplots(
        rows=1, cols=2,
        x_title='Iteration', y_title=y_title,
        subplot_titles=['10k', '50k'])

    # 10k panel (library of 10560 molecules) carries the legend entries
    _add_smis_panel(results_series, smis_results['10k'][mode], 10560, 1,
                    mode, metric, showlegend=True, intersection_range=[0, 100])
    # 50k panel (library of 50240 molecules); intersection mode keeps autorange
    _add_smis_panel(results_series, smis_results['50k'][mode], 50240, 2,
                    mode, metric, showlegend=False)

    border_params = dict(
        showgrid=True, zeroline=False, showticklabels=True,
        visible=True, mirror=True, linewidth=2,
        linecolor='black'
    )
    results_series.update_xaxes(**border_params)
    results_series.update_yaxes(**border_params)
    results_series.update_traces(marker=dict(line_width=2, size=7.5))

    annotations = results_series['layout']['annotations']
    for annotation in annotations:
        annotation['font'] = dict(color='black', size=20, family='sans-serif')
    # presumably the shared x_title and y_title annotations (make_subplots adds
    # them last, in that order): move the x title down and the y title left
    annotations[-2]['yanchor'] = 'bottom'
    annotations[-2]['y'] = -0.15
    annotations[-1].x = -0.02

    HEIGHT = 500
    results_series.update_layout(
        legend_title_text='Model', legend_traceorder='reversed',
        height=HEIGHT, width=HEIGHT/1.5*2,
        font=dict(color='black', size=20, family='sans-serif')
    )

    return results_series

In [None]:
def gen_HTS_smis_figure(mode='union', training='retrain', metric='greedy'):
    """Plot per-iteration HTS molecule counts, one panel per batch size in SPLITS.

    Parameters
    ----------
    mode : str, default='union'
        'union' (total molecules explored among runs, drawn with black
        reference lines) or 'intersection' (molecules shared among runs)
    training : str, default='retrain'
        'retrain' or 'online'
    metric : str, default='greedy'
        acquisition metric whose results are plotted

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        if ``mode`` is neither 'union' nor 'intersection'
    """
    # exact HTS library size, the same value the other HTS figures use. The
    # previous approximation (2.1E6) put the union upper bound below the data
    # (e.g. 42000 vs 42510 molecules at iteration 0 of the 0.4 split).
    size = 2141514
    results = smis_results['HTS'][mode][training]
    if mode == 'union':
        y_title = 'Total Number of Molecules Explored Among Runs'
    elif mode == 'intersection':
        y_title = 'Number of Molecules Shared Among Runs'
    else:
        raise ValueError('Unsupported mode!')

    results_series = make_subplots(
        rows=1, cols=len(SPLITS), shared_yaxes=True,
        x_title='Iteration', y_title=y_title,
        subplot_titles=[f'{split}%' for split in SPLITS])

    xs = list(range(6))
    for i, split in enumerate(SPLITS):
        col = i + 1
        split_results = results[split]
        # add actual data; a single panel supplies the legend entries
        for j, model in enumerate(MODELS):
            split_trace = go.Scatter(
                x=xs, y=split_results[model][metric],
                mode='lines+markers',
                marker=dict(symbol=MARKERS[j], color=MODEL_COLORS[j]),
                name=model.upper(), legendgroup=model,
                showlegend=split == SPLITS[-1]
            )
            results_series.add_trace(split_trace, row=1, col=col)

        # add reference lines: 5*(x+1) and (5+x) batches of `split`% of the
        # library at iteration x
        if mode == 'union':
            ys_upper = [int(5*(x+1)*split/100*size) for x in xs]
            results_series.add_trace(go.Scatter(
                x=xs, y=ys_upper,
                mode='lines', line=dict(color='black'),
                name='upper_bound', showlegend=False
            ), row=1, col=col)

            ys_lower = [int((5+x)*split/100*size) for x in xs]
            results_series.add_trace(go.Scatter(
                x=xs, y=ys_lower,
                mode='lines', line=dict(color='black'),
                name='lower_bound', showlegend=False
            ), row=1, col=col)

            results_series.update_yaxes(row=1, col=col, range=[0, ys_upper[-1]])
        elif mode == 'intersection':
            results_series.update_yaxes(row=1, col=col, range=[0, 1000])

        results_series.update_xaxes(row=1, col=col, range=[0, 5], dtick=1)

    border_params = dict(
        showgrid=True, zeroline=False, showticklabels=True,
        visible=True, mirror=True, linewidth=2,
        linecolor='black'
    )
    results_series.update_xaxes(**border_params)
    results_series.update_yaxes(**border_params)
    results_series.update_traces(marker=dict(line_width=2, size=7.5))

    annotations = results_series['layout']['annotations']
    for annotation in annotations:
        annotation['font'] = dict(color='black', size=20, family='sans-serif')
    # presumably the shared x_title and y_title annotations (make_subplots adds
    # them last, in that order): move the x title down and the y title left
    annotations[-2]['yanchor'] = 'bottom'
    annotations[-2]['y'] = -0.15
    annotations[-1].x = -0.02

    HEIGHT = 500
    results_series.update_layout(
        legend_title_text='Model', legend_traceorder='reversed',
        height=HEIGHT, width=HEIGHT/1.5*3,
        font=dict(color='black', size=20, family='sans-serif')
    )

    return results_series

Run the cell below to generate the plots.

In [None]:
# union plots for the 10k/50k libraries
fig = gen_10k50k_smis_figure(mode='union', metric='greedy')
fig.show()

# union plots for HTS with full model retraining
fig = gen_HTS_smis_figure(mode='union', training='retrain', metric='greedy')
fig.show()