In [15]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn

import olympus
from olympus.datasets import Dataset
from olympus.emulators import Emulator
from olympus.planners import Planner
from olympus.models import BayesNeuralNet, NeuralNet
from olympus.objects import ParameterVector

## Initialize test case dataset - `suzuki_iii`

In [3]:
# load suzuki_iii dataset
dataset = Dataset(kind='suzuki_iii')
print(dataset.data.shape)
dataset.data.head()

(88, 6)


Unnamed: 0,ligand,res_time,temperature,catalyst_loading,yield,turnover
0,L3,60.0,30.0,2.513,0.2,0.1
1,L1,600.0,30.0,2.494,0.2,0.1
2,L0,60.0,30.0,0.51,0.2,0.3
3,L4,600.0,30.0,2.511,0.2,0.1
4,L5,60.0,30.0,0.499,0.2,0.3


In [4]:
dataset.target_names

['yield', 'turnover']

In [5]:
dataset.feature_names

['ligand', 'res_time', 'temperature', 'catalyst_loading']

In [6]:
dataset.features

Unnamed: 0,ligand,res_time,temperature,catalyst_loading
0,L3,60.0,30.0,2.513
1,L1,600.0,30.0,2.494
2,L0,60.0,30.0,0.51
3,L4,600.0,30.0,2.511
4,L5,60.0,30.0,0.499
...,...,...,...,...
83,L1,600.0,110.0,1.587
84,L2,600.0,110.0,1.787
85,L2,60.0,40.2,0.503
86,L2,60.0,40.1,0.503


In [7]:
model = BayesNeuralNet()
emulator = Emulator(dataset='suzuki_iii', model=model)
print(emulator)

<Emulator (Dataset(kind=suzuki_iii), model=
--> batch_size:    20
--> es_patience:   100
--> hidden_act:    leaky_relu
--> hidden_depth:  3
--> hidden_nodes:  48
--> kind:          BayesNeuralNet
--> learning_rate: 0.001
--> max_epochs:    100000
--> out_act:       linear
--> pred_int:      100
--> reg:           0.001
--> scope:         model)>


In [8]:
emulator.cross_validate()

DATA :  <class 'numpy.ndarray'>
DATA INNER :  <class 'numpy.ndarray'>
DATA :  <class 'numpy.ndarray'>
DATA INNER :  <class 'numpy.ndarray'>
[0;37m[INFO] >>> Training model on fold #0...
[0m

  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0         -2.995         53.213         -2.916         58.139 *
[0m[0;37m[INFO]             100          0.441         19.667         -0.053         29.392 *
[0m[0;37m[INFO]             200          0.461         18.793          0.221         25.300 *
[0m[0;37m[INFO]             300          0.546         17.379          0.127         26.833
[0m[0;37m[INFO]             400          0.591         16.701          0.197         25.421
[0m[0;37m[INFO]             500          0.621         15.884          0.178         26.014
[0m[0;37m[INFO]             600          0.597         17.011          0.105         27.134
[0m[0;37m[INFO]             700          0.602         16.736          0.204         25.638
[0m[0;37m[INFO]             800          0.674         14.485          0.224         25.198 *
[0m[0;37m[INFO]             900          0.666    

[0m[0;37m[INFO]            8400          0.848          9.995          0.845         10.837
[0m[0;37m[INFO]            8500          0.847          9.986          0.873          9.806
[0m[0;37m[INFO]            8600          0.851          9.357          0.832         11.349
[0m[0;37m[INFO]            8700          0.847         10.013          0.851         10.659
[0m[0;37m[INFO]            8800          0.831          9.963          0.840         11.103
[0m[0;37m[INFO]            8900          0.849         10.117          0.867          9.855
[0m[0;37m[INFO]            9000          0.835         10.907          0.878          9.356
[0m[0;37m[INFO]            9100          0.844         10.569          0.903          8.725 *
[0m[0;37m[INFO]            9200          0.876          9.040          0.826         11.522
[0m[0;37m[INFO]            9300          0.867          9.578          0.882          9.612
[0m[0;37m[INFO]            9400          0.858          

[0m[0;37m[INFO]           17200          0.994          1.819          0.936          7.154
[0m[0;37m[INFO]           17300          0.994          2.238          0.956          6.291
[0m[0;37m[INFO]           17400          0.997          1.356          0.968          5.298
[0m[0;37m[INFO]           17500          0.997          1.423          0.954          6.300
[0m[0;37m[INFO]           17600          0.992          2.087          0.959          6.011
[0m[0;37m[INFO]           17700          0.983          3.576          0.956          6.201
[0m[0;37m[INFO]           17800          0.997          1.312          0.964          5.508
[0m[0;37m[INFO]           17900          0.997          1.188          0.949          6.280
[0m[0;37m[INFO]           18000          0.998          1.209          0.952          6.187
[0m[0;37m[INFO]           18100          0.996          1.640          0.941          7.070
[0m[0;37m[INFO]           18200          0.996          1.

[0m[0;37m[INFO]           26000          0.999          0.709          0.960          5.591
[0m[0;37m[INFO]           26100          0.999          1.070          0.943          6.800
[0m[0;37m[INFO]           26200          0.998          0.966          0.954          6.132
[0m[0;37m[INFO]           26300          0.998          1.085          0.948          6.329
[0m[0;37m[INFO]           26400          0.999          0.876          0.942          6.875
[0m[0;37m[INFO]           26500          0.999          0.681          0.940          6.850
[0m[0;37m[INFO]           26600          0.999          0.869          0.931          7.481
[0m[0;37m[INFO]           26700          0.999          0.702          0.937          7.002
[0m[0;37m[INFO]           26800          0.999          0.652          0.936          7.119
[0m[0;37m[INFO]           26900          0.996          1.762          0.922          8.072
[0m[0;37m[INFO]           27000          0.999          0.

  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0         -2.499         56.829         -2.188         60.103 *
[0m[0;37m[INFO]             100          0.306         24.871          0.381         25.933 *
[0m[0;37m[INFO]             200          0.407         22.561          0.354         26.282
[0m[0;37m[INFO]             300          0.346         23.781          0.317         27.187
[0m[0;37m[INFO]             400          0.354         23.767          0.311         27.249
[0m[0;37m[INFO]             500          0.372         23.555          0.409         25.146 *
[0m[0;37m[INFO]             600          0.461         22.004          0.387         25.624
[0m[0;37m[INFO]             700          0.478         21.201          0.408         25.272
[0m[0;37m[INFO]             800          0.459         21.853          0.368         26.117
[0m[0;37m[INFO]             900          0.555      

[0m[0;37m[INFO]            8400          0.867         11.879          0.775         15.473
[0m[0;37m[INFO]            8500          0.870         11.722          0.784         15.141 *
[0m[0;37m[INFO]            8600          0.881         11.287          0.779         15.288
[0m[0;37m[INFO]            8700          0.868         11.998          0.767         15.785
[0m[0;37m[INFO]            8800          0.885         11.238          0.762         15.968
[0m[0;37m[INFO]            8900          0.864         11.866          0.797         14.553 *
[0m[0;37m[INFO]            9000          0.883         11.106          0.793         14.767
[0m[0;37m[INFO]            9100          0.892         10.886          0.768         15.638
[0m[0;37m[INFO]            9200          0.888         10.914          0.781         15.179
[0m[0;37m[INFO]            9300          0.881         11.026          0.793         14.633
[0m[0;37m[INFO]            9400          0.914        

[0m[0;37m[INFO]           17100          0.990          2.917          0.884         11.694
[0m[0;37m[INFO]           17200          0.985          3.395          0.891         11.366
[0m[0;37m[INFO]           17300          0.993          2.527          0.874         12.152
[0m[0;37m[INFO]           17400          0.986          3.358          0.891         11.396
[0m[0;37m[INFO]           17500          0.991          2.959          0.889         11.301
[0m[0;37m[INFO]           17600          0.994          2.221          0.856         12.937
[0m[0;37m[INFO]           17700          0.993          2.469          0.867         12.449
[0m[0;37m[INFO]           17800          0.993          2.488          0.880         11.887
[0m[0;37m[INFO]           17900          0.991          2.747          0.881         11.803
[0m[0;37m[INFO]           18000          0.990          3.015          0.874         12.034
[0m[0;37m[INFO]           18100          0.985          3.

[0m[0;37m[INFO]           26000          0.997          1.777          0.877         12.107
[0m[0;37m[INFO]           26100          0.995          2.131          0.860         12.879
[0m[0;37m[INFO]           26200          0.982          3.745          0.893         11.329
[0m[0;37m[INFO] Training completed in 47.21 seconds.

[0mDATA :  <class 'numpy.ndarray'>
DATA INNER :  <class 'numpy.ndarray'>
DATA :  <class 'numpy.ndarray'>
DATA INNER :  <class 'numpy.ndarray'>
[0;37m[INFO] >>> Training model on fold #2...
[0m

  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0         -4.893         62.211         -1.482         53.359 *
[0m[0;37m[INFO]             100          0.198         22.762          0.252         28.016 *
[0m[0;37m[INFO]             200          0.248         21.994          0.159         29.920
[0m[0;37m[INFO]             300          0.241         22.143          0.239         27.888 *
[0m[0;37m[INFO]             400          0.114         23.925          0.213         28.884
[0m[0;37m[INFO]             500          0.268         21.736          0.333         26.842 *
[0m[0;37m[INFO]             600          0.073         24.459          0.409         25.719 *
[0m[0;37m[INFO]             700          0.175         23.068          0.424         25.349 *
[0m[0;37m[INFO]             800          0.202         22.704          0.302         27.408
[0m[0;37m[INFO]             900          0.265

[0m[0;37m[INFO]            8400          0.844          9.882          0.873         12.380 *
[0m[0;37m[INFO]            8500          0.824         10.539          0.871         12.993
[0m[0;37m[INFO]            8600          0.880          8.743          0.832         14.178
[0m[0;37m[INFO]            8700          0.872          8.989          0.846         13.819
[0m[0;37m[INFO]            8800          0.870          9.042          0.870         12.837
[0m[0;37m[INFO]            8900          0.879          8.749          0.812         14.724
[0m[0;37m[INFO]            9000          0.870          9.075          0.830         14.672
[0m[0;37m[INFO]            9100          0.885          8.537          0.828         14.416
[0m[0;37m[INFO]            9200          0.865          9.211          0.869         13.167
[0m[0;37m[INFO]            9300          0.893          8.230          0.837         13.895
[0m[0;37m[INFO]            9400          0.885          

[0m[0;37m[INFO]           17300          0.978          3.789          0.741         18.635
[0m[0;37m[INFO]           17400          0.983          3.319          0.737         18.635
[0m[0;37m[INFO]           17500          0.978          3.693          0.760         17.851
[0m[0;37m[INFO]           17600          0.970          4.359          0.758         18.014
[0m[0;37m[INFO]           17700          0.972          4.178          0.785         16.938
[0m[0;37m[INFO]           17800          0.975          4.024          0.697         19.568
[0m[0;37m[INFO]           17900          0.982          3.380          0.752         18.199
[0m[0;37m[INFO]           18000          0.984          3.204          0.720         18.940
[0m[0;37m[INFO]           18100          0.984          3.204          0.752         17.978
[0m[0;37m[INFO]           18200          0.981          3.518          0.756         17.805
[0m[0;37m[INFO]           18300          0.986          2.

  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0         -2.169         55.036         -1.452         50.051 *
[0m[0;37m[INFO]             100          0.414         22.838          0.447         23.909 *
[0m[0;37m[INFO]             200          0.506         20.972          0.395         25.039
[0m[0;37m[INFO]             300          0.525         20.305          0.418         24.671
[0m[0;37m[INFO]             400          0.598         18.757          0.373         25.373
[0m[0;37m[INFO]             500          0.491         21.306          0.438         23.793 *
[0m[0;37m[INFO]             600          0.584         19.072          0.426         24.396
[0m[0;37m[INFO]             700          0.623         18.225          0.402         24.655
[0m[0;37m[INFO]             800          0.619         18.078          0.424         24.213
[0m[0;37m[INFO]             900          0.626      

[0m[0;37m[INFO]            8500          0.923          8.057          0.721         17.556 *
[0m[0;37m[INFO]            8600          0.927          7.849          0.720         17.529 *
[0m[0;37m[INFO]            8700          0.935          7.642          0.665         19.404
[0m[0;37m[INFO]            8800          0.935          7.483          0.690         18.667
[0m[0;37m[INFO]            8900          0.928          7.913          0.711         18.014
[0m[0;37m[INFO]            9000          0.941          7.241          0.729         17.377 *
[0m[0;37m[INFO]            9100          0.943          6.789          0.747         16.939 *
[0m[0;37m[INFO]            9200          0.928          7.862          0.739         16.965
[0m[0;37m[INFO]            9300          0.917          8.490          0.761         16.215 *
[0m[0;37m[INFO]            9400          0.933          7.573          0.758         16.288
[0m[0;37m[INFO]            9500          0.951  

[0m[0;37m[INFO]           17200          0.991          2.667          0.809         14.749
[0m[0;37m[INFO]           17300          0.991          2.924          0.800         15.211
[0m[0;37m[INFO]           17400          0.991          2.771          0.799         15.075
[0m[0;37m[INFO]           17500          0.992          2.721          0.789         15.557
[0m[0;37m[INFO]           17600          0.984          3.697          0.782         15.698
[0m[0;37m[INFO]           17700          0.993          2.584          0.800         15.104
[0m[0;37m[INFO]           17800          0.991          2.790          0.800         15.039
[0m[0;37m[INFO]           17900          0.993          2.575          0.806         14.867
[0m[0;37m[INFO]           18000          0.992          2.572          0.810         14.669
[0m[0;37m[INFO]           18100          0.989          3.132          0.795         15.301
[0m[0;37m[INFO]           18200          0.993          2.

[0m[0;37m[INFO]           26000          0.994          2.272          0.819         14.489
[0m[0;37m[INFO]           26100          0.992          2.556          0.797         15.232
[0m[0;37m[INFO]           26200          0.992          2.607          0.811         14.764
[0m[0;37m[INFO]           26300          0.994          2.228          0.810         14.805
[0m[0;37m[INFO]           26400          0.995          2.228          0.805         15.012
[0m[0;37m[INFO]           26500          0.993          2.635          0.785         15.780
[0m[0;37m[INFO]           26600          0.993          2.498          0.808         14.944
[0m[0;37m[INFO]           26700          0.994          2.329          0.820         14.455
[0m[0;37m[INFO]           26800          0.995          2.202          0.819         14.520
[0m[0;37m[INFO]           26900          0.996          2.007          0.813         14.746
[0m[0;37m[INFO]           27000          0.991          2.

[0m[0;37m[INFO]           34800          0.994          2.460          0.820         14.480
[0m[0;37m[INFO]           34900          0.992          2.572          0.807         15.005
[0m[0;37m[INFO] Training completed in 57.39 seconds.

[0mDATA :  <class 'numpy.ndarray'>
DATA INNER :  <class 'numpy.ndarray'>
DATA :  <class 'numpy.ndarray'>
DATA INNER :  <class 'numpy.ndarray'>
[0;37m[INFO] >>> Training model on fold #4...
[0m

  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0         -2.545         56.393         -2.758         55.471 *
[0m[0;37m[INFO]             100          0.190         26.876          0.117         26.100 *
[0m[0;37m[INFO]             200          0.197         26.915          0.239         24.087 *
[0m[0;37m[INFO]             300          0.258         25.948          0.205         24.706
[0m[0;37m[INFO]             400          0.161         27.374          0.224         24.583
[0m[0;37m[INFO]             500          0.216         26.543          0.210         24.785
[0m[0;37m[INFO]             600          0.300         25.168          0.212         24.560
[0m[0;37m[INFO]             700          0.329         24.867          0.308         23.246 *
[0m[0;37m[INFO]             800          0.322         24.945          0.320         22.838 *
[0m[0;37m[INFO]             900          0.342  

[0m[0;37m[INFO]            8500          0.924          9.261          0.913          7.777
[0m[0;37m[INFO]            8600          0.895         10.702          0.907          7.715
[0m[0;37m[INFO]            8700          0.909          9.875          0.916          7.468 *
[0m[0;37m[INFO]            8800          0.930          8.939          0.889          8.556
[0m[0;37m[INFO]            8900          0.916          9.709          0.922          7.097 *
[0m[0;37m[INFO]            9000          0.928          8.940          0.902          7.910
[0m[0;37m[INFO]            9100          0.938          8.445          0.919          7.347
[0m[0;37m[INFO]            9200          0.935          8.602          0.918          7.259
[0m[0;37m[INFO]            9300          0.924          9.159          0.919          7.318
[0m[0;37m[INFO]            9400          0.922          9.040          0.893          8.509
[0m[0;37m[INFO]            9500          0.933        

[0m[0;37m[INFO]           17200          0.992          2.655          0.959          5.661
[0m[0;37m[INFO]           17300          0.995          2.187          0.955          5.895
[0m[0;37m[INFO]           17400          0.994          2.253          0.953          6.128
[0m[0;37m[INFO]           17500          0.995          2.105          0.956          5.978
[0m[0;37m[INFO]           17600          0.992          2.498          0.953          6.297
[0m[0;37m[INFO]           17700          0.994          2.306          0.956          5.829
[0m[0;37m[INFO]           17800          0.992          2.470          0.957          5.975
[0m[0;37m[INFO]           17900          0.996          2.005          0.958          5.692
[0m[0;37m[INFO]           18000          0.987          3.223          0.949          6.450
[0m[0;37m[INFO]           18100          0.984          3.272          0.962          5.633
[0m[0;37m[INFO]           18200          0.985          3.

{'train_r2': array([0.99546708, 0.98626813, 0.84429688, 0.993336  , 0.98758204]),
 'validate_r2': array([0.97832669, 0.8953732 , 0.87272429, 0.83281492, 0.97090376]),
 'train_rmsd': array([1.45801866, 3.67640813, 9.88245871, 2.54093677, 3.08764875]),
 'validate_rmsd': array([ 4.31648744, 11.01203601, 12.37996314, 13.92500955,  4.73852879])}

In [9]:
emulator.train(retrain=True)

DATA :  <class 'numpy.ndarray'>
DATA INNER :  <class 'numpy.ndarray'>
DATA :  <class 'numpy.ndarray'>
DATA INNER :  <class 'numpy.ndarray'>
[0;37m[INFO] >>> Training model on 80% of the dataset, testing on 20%...
[0m

  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0         -1.983         49.560         -2.101         56.297 *
[0m[0;37m[INFO]             100          0.346         23.555          0.377         25.050 *
[0m[0;37m[INFO]             200          0.373         23.291          0.245         27.679
[0m[0;37m[INFO]             300          0.380         22.958          0.408         24.457 *
[0m[0;37m[INFO]             400          0.408         21.991          0.413         24.169 *
[0m[0;37m[INFO]             500          0.402         22.513          0.401         24.579
[0m[0;37m[INFO]             600          0.402         22.338          0.437         23.890 *
[0m[0;37m[INFO]             700          0.390         22.269          0.452         23.351 *
[0m[0;37m[INFO]             800          0.422         21.523          0.542         21.373 *
[0m[0;37m[INFO]             900          0.4

[0m[0;37m[INFO]            8500          0.850         10.065          0.740         15.037
[0m[0;37m[INFO]            8600          0.879          8.933          0.766         14.660
[0m[0;37m[INFO]            8700          0.875          9.346          0.764         14.372
[0m[0;37m[INFO]            8800          0.846          9.983          0.788         13.709
[0m[0;37m[INFO]            8900          0.877          9.028          0.778         13.871
[0m[0;37m[INFO]            9000          0.876          9.259          0.775         14.448
[0m[0;37m[INFO]            9100          0.868          8.998          0.815         12.929 *
[0m[0;37m[INFO]            9200          0.841         10.252          0.741         15.135
[0m[0;37m[INFO]            9300          0.908          7.879          0.802         13.269
[0m[0;37m[INFO]            9400          0.903          8.004          0.778         14.390
[0m[0;37m[INFO]            9500          0.915          

[0m[0;37m[INFO]           17200          0.991          2.589          0.908          9.688
[0m[0;37m[INFO]           17300          0.990          2.954          0.920          9.081 *
[0m[0;37m[INFO]           17400          0.990          2.783          0.909          9.787
[0m[0;37m[INFO]           17500          0.975          4.080          0.924          9.063 *
[0m[0;37m[INFO]           17600          0.993          2.348          0.907          9.925
[0m[0;37m[INFO]           17700          0.989          2.745          0.911          9.744
[0m[0;37m[INFO]           17800          0.986          3.352          0.904          9.974
[0m[0;37m[INFO]           17900          0.991          2.734          0.916          9.327
[0m[0;37m[INFO]           18000          0.993          2.384          0.910          9.609
[0m[0;37m[INFO]           18100          0.988          3.063          0.911          9.711
[0m[0;37m[INFO]           18200          0.990        

[0m[0;37m[INFO]           26100          0.993          2.359          0.897         10.074
[0m[0;37m[INFO]           26200          0.993          2.449          0.916          9.378
[0m[0;37m[INFO]           26300          0.996          1.873          0.907          9.810
[0m[0;37m[INFO]           26400          0.991          2.369          0.909          9.927
[0m[0;37m[INFO]           26500          0.995          2.049          0.909          9.781
[0m[0;37m[INFO]           26600          0.993          2.315          0.912          9.642
[0m[0;37m[INFO]           26700          0.991          2.499          0.907          9.987
[0m[0;37m[INFO]           26800          0.988          2.822          0.910          9.875
[0m[0;37m[INFO]           26900          0.993          2.575          0.918          9.164
[0m[0;37m[INFO]           27000          0.991          2.773          0.909          9.364
[0m[0;37m[INFO]           27100          0.989          2.

{'train_r2': 0.9868713167542207,
 'test_r2': 0.9229099554333615,
 'train_rmsd': 3.305444941981536,
 'test_rmsd': 8.989096993333217}

In [32]:
params = [['L2', 60.0, 65.3, 2.512], ['L5', 510.0, 100.0, 0.87]]

params = [
    ParameterVector().from_dict({'liagnd': 'L2', 'res_time':  60.0, 'temperature': 65.3, 'catalyst_loading': 2.512}),
    ParameterVector().from_dict({'liagnd': 'L5', 'res_time':  510.0, 'temperature': 100.0, 'catalyst_loading': 0.87})
]

In [33]:
measurement = emulator.run(params, return_paramvector=False)

print('measurement : ', measurement)
type(measurement)

processed features :  [['L2', 60.0, 65.3, 2.512], ['L5', 510.0, 100.0, 0.87]]
measurement :  [[23.56794739  9.48339176]
 [-1.97675288 -0.66644633]]


numpy.ndarray

Categorical (name='ligand', num_opts: 8, options=['L0', 'L1', 'L2', 'L3', 'L4', 'L5', 'L6', 'L7'], descriptors=[None, None, None, None, None, None, None, None])
Continuous (name='res_time', low=60.0, high=600.0, is_periodic=False)
Continuous (name='temperature', low=30.0, high=110.0, is_periodic=False)
Continuous (name='catalyst_loading', low=0.5, high=2.5, is_periodic=False)


In [17]:
pvec = ParameterVector().from_dict({'liagnd': 'L2', 'res_time':  60.0, 'temperature': 65.3, 'catalyst_loading': 2.512})

In [18]:
pvec.to_array()

array(['L2', '60.0', '65.3', '2.512'], dtype='<U32')

In [19]:
pvec.to_list()

['L2', 60.0, 65.3, 2.512]