In [1]:
import deepchem as dc
import tensorflow as tf
import sklearn as sk
import numpy as np
import pandas as pd

print("TensorFlow version: " + tf.__version__)
print("DeepChe, version: " + dc.__version__)

metric_labels=['mean_squared_error','pearson_r2_score',
               'mae_score', 'rmse']


metric1 = dc.metrics.Metric(dc.metrics.mean_squared_error)
metric2 = dc.metrics.Metric(dc.metrics.pearson_r2_score)
metric3 = dc.metrics.Metric(dc.metrics.mae_score)
metrics = [metric1, metric2, metric3]
metric_selector = 2 #which metric to use for callback

TensorFlow version: 2.7.0
DeepChe, version: 2.5.0


In [2]:
def get_them_metrics(
        model,
        datasets,
        metrics,
        metric_labels,
        transformers=[],
):
    """calculates metrics for a run
    model: trained model
    # datasets: tuple of datasets
    # metrics: list of metric objects
    # metric labels: sensible labels"""
    out = []
    for dataset in datasets:
        if transformers == []:
            egg = model.evaluate(
                dataset,
                metrics)
        else:
            egg = model.evaluate(
                dataset,
                metrics,
                transformers=transformers)
        for metric_label in metric_labels:
            if metric_label == 'rmse':
                out.append(np.sqrt(egg['mean_squared_error']))
            else:
                out.append(egg[metric_label])
    return out

## Multitask regressor

In [3]:
patience=3

tasks, datasets, transformers = dc.molnet.load_qm7(
    shard_size=2000,
    featurizer=dc.feat.CoulombMatrix
    (max_atoms=23),
    splitter='stratified',
    move_mean=False)

callback = tf.keras.callbacks.EarlyStopping(monitor='mae', patience=3)

# the datasets object is already split into the train, validation and test dataset 
train_dataset, valid_dataset, test_dataset = datasets

fit_transformers = [dc.trans.CoulombFitTransformer(train_dataset)]

# this loads in a general purpose regression model
model = dc.models.MultitaskFitTransformRegressor(
    n_tasks = len(test_dataset.tasks), # size of y, we have one output task here: finding toxicity
    n_features = [23,23],
    fit_transformers = fit_transformers # number of input features, i.e. the length of the ECFPs
)

# this sets up a callback on the validation
callback = dc.models.ValidationCallback(
            valid_dataset,
            patience,
            metrics[metric_selector])
# fit da model
model.fit(train_dataset, nb_epoch=100, callbacks=callback)



Step 3 validation: mae_score=1.48199
Step 6 validation: mae_score=1.31214
Step 9 validation: mae_score=1.44903
Step 12 validation: mae_score=1.0117
Step 15 validation: mae_score=1.12861
Step 18 validation: mae_score=1.00427
Step 21 validation: mae_score=1.06031
Step 24 validation: mae_score=1.08445
Step 27 validation: mae_score=0.87349
Step 30 validation: mae_score=0.799417
Step 33 validation: mae_score=0.969591
Step 36 validation: mae_score=0.923726
Step 39 validation: mae_score=0.891932
Step 42 validation: mae_score=0.907178
Step 45 validation: mae_score=0.905902
Step 48 validation: mae_score=0.748173
Step 51 validation: mae_score=0.83944
Step 54 validation: mae_score=0.836942
Step 57 validation: mae_score=0.997459
Step 60 validation: mae_score=0.99094
Step 63 validation: mae_score=0.858154
Step 66 validation: mae_score=0.946173
Step 69 validation: mae_score=0.853643
Step 72 validation: mae_score=0.917921
Step 75 validation: mae_score=0.845895
Step 78 validation: mae_score=0.80882
St

Step 624 validation: mae_score=0.757577
Step 627 validation: mae_score=0.708173
Step 630 validation: mae_score=0.732887
Step 633 validation: mae_score=0.725647
Step 636 validation: mae_score=0.73466
Step 639 validation: mae_score=0.717544
Step 642 validation: mae_score=0.700257
Step 645 validation: mae_score=0.708122
Step 648 validation: mae_score=0.698481
Step 651 validation: mae_score=0.715312
Step 654 validation: mae_score=0.755562
Step 657 validation: mae_score=0.706285
Step 660 validation: mae_score=0.720626
Step 663 validation: mae_score=0.731926
Step 666 validation: mae_score=0.723882
Step 669 validation: mae_score=0.74538
Step 672 validation: mae_score=0.738285
Step 675 validation: mae_score=0.717751
Step 678 validation: mae_score=0.772369
Step 681 validation: mae_score=0.713132
Step 684 validation: mae_score=0.736991
Step 687 validation: mae_score=0.715888
Step 690 validation: mae_score=0.728352
Step 693 validation: mae_score=0.715347
Step 696 validation: mae_score=0.7544
Step

Step 1236 validation: mae_score=0.770768
Step 1239 validation: mae_score=0.747758
Step 1242 validation: mae_score=0.718233
Step 1245 validation: mae_score=0.719364
Step 1248 validation: mae_score=0.72598
Step 1251 validation: mae_score=0.720495
Step 1254 validation: mae_score=0.734274
Step 1257 validation: mae_score=0.716687
Step 1260 validation: mae_score=0.757884
Step 1263 validation: mae_score=0.776074
Step 1266 validation: mae_score=0.770508
Step 1269 validation: mae_score=0.739525
Step 1272 validation: mae_score=0.7548
Step 1275 validation: mae_score=0.914955
Step 1278 validation: mae_score=0.789103
Step 1281 validation: mae_score=0.720367
Step 1284 validation: mae_score=0.778232
Step 1287 validation: mae_score=0.848883
Step 1290 validation: mae_score=0.78938
Step 1293 validation: mae_score=0.781387
Step 1296 validation: mae_score=0.787452
Step 1299 validation: mae_score=0.735243
Step 1302 validation: mae_score=0.757503
Step 1305 validation: mae_score=0.718375
Step 1308 validation

Step 1839 validation: mae_score=0.72879
Step 1842 validation: mae_score=0.769876
Step 1845 validation: mae_score=0.75602
Step 1848 validation: mae_score=0.714605
Step 1851 validation: mae_score=0.772004
Step 1854 validation: mae_score=0.740063
Step 1857 validation: mae_score=0.715893
Step 1860 validation: mae_score=0.801248
Step 1863 validation: mae_score=0.78811
Step 1866 validation: mae_score=0.732128
Step 1869 validation: mae_score=0.769925
Step 1872 validation: mae_score=0.717041
Step 1875 validation: mae_score=0.776082
Step 1878 validation: mae_score=0.742429
Step 1881 validation: mae_score=0.7615
Step 1884 validation: mae_score=0.744324
Step 1887 validation: mae_score=0.738608
Step 1890 validation: mae_score=0.741974
Step 1893 validation: mae_score=0.755048
Step 1896 validation: mae_score=0.734525
Step 1899 validation: mae_score=0.719666
Step 1902 validation: mae_score=0.692098
Step 1905 validation: mae_score=0.722257
Step 1908 validation: mae_score=0.732681
Step 1911 validation:

Step 2442 validation: mae_score=0.703115
Step 2445 validation: mae_score=0.755935
Step 2448 validation: mae_score=0.77249
Step 2451 validation: mae_score=0.751134
Step 2454 validation: mae_score=0.747964
Step 2457 validation: mae_score=0.810976
Step 2460 validation: mae_score=0.794242
Step 2463 validation: mae_score=0.770671
Step 2466 validation: mae_score=0.814789
Step 2469 validation: mae_score=0.770378
Step 2472 validation: mae_score=0.776303
Step 2475 validation: mae_score=0.844703
Step 2478 validation: mae_score=0.789013
Step 2481 validation: mae_score=0.799173
Step 2484 validation: mae_score=0.742201
Step 2487 validation: mae_score=0.707324
Step 2490 validation: mae_score=0.724654
Step 2493 validation: mae_score=0.832594
Step 2496 validation: mae_score=0.803125
Step 2499 validation: mae_score=0.811989
Step 2502 validation: mae_score=0.719231
Step 2505 validation: mae_score=0.763929
Step 2508 validation: mae_score=0.857818
Step 2511 validation: mae_score=0.757608
Step 2514 validat

Step 3045 validation: mae_score=0.766378
Step 3048 validation: mae_score=0.749532
Step 3051 validation: mae_score=0.72908
Step 3054 validation: mae_score=0.723677
Step 3057 validation: mae_score=0.772727
Step 3060 validation: mae_score=0.726287
Step 3063 validation: mae_score=0.741789
Step 3066 validation: mae_score=0.714673
Step 3069 validation: mae_score=0.718794
Step 3072 validation: mae_score=0.705432
Step 3075 validation: mae_score=0.723196
Step 3078 validation: mae_score=0.724059
Step 3081 validation: mae_score=0.735171
Step 3084 validation: mae_score=0.732138
Step 3087 validation: mae_score=0.722215
Step 3090 validation: mae_score=0.70404
Step 3093 validation: mae_score=0.715294
Step 3096 validation: mae_score=0.715106
Step 3099 validation: mae_score=0.736046
Step 3102 validation: mae_score=0.724631
Step 3105 validation: mae_score=0.718812
Step 3108 validation: mae_score=0.699172
Step 3111 validation: mae_score=0.719716
Step 3114 validation: mae_score=0.760676
Step 3117 validati

Step 3648 validation: mae_score=0.782539
Step 3651 validation: mae_score=0.737933
Step 3654 validation: mae_score=0.751853
Step 3657 validation: mae_score=0.703098
Step 3660 validation: mae_score=0.753918
Step 3663 validation: mae_score=0.760644
Step 3666 validation: mae_score=0.713784
Step 3669 validation: mae_score=0.711831
Step 3672 validation: mae_score=0.760479
Step 3675 validation: mae_score=0.716213
Step 3678 validation: mae_score=0.733387
Step 3681 validation: mae_score=0.72438
Step 3684 validation: mae_score=0.709372
Step 3687 validation: mae_score=0.75971
Step 3690 validation: mae_score=0.723542
Step 3693 validation: mae_score=0.727233
Step 3696 validation: mae_score=0.75462
Step 3699 validation: mae_score=0.741268
Step 3702 validation: mae_score=0.748053
Step 3705 validation: mae_score=0.706492
Step 3708 validation: mae_score=0.72795
Step 3711 validation: mae_score=0.722423
Step 3714 validation: mae_score=0.753931
Step 3717 validation: mae_score=0.80771
Step 3720 validation:

Step 4251 validation: mae_score=0.766065
Step 4254 validation: mae_score=0.74184
Step 4257 validation: mae_score=0.730282
Step 4260 validation: mae_score=0.770252
Step 4263 validation: mae_score=0.778923
Step 4266 validation: mae_score=0.74397
Step 4269 validation: mae_score=0.79954
Step 4272 validation: mae_score=0.879528
Step 4275 validation: mae_score=0.758702
Step 4278 validation: mae_score=0.750523
Step 4281 validation: mae_score=0.724198
Step 4284 validation: mae_score=0.771065
Step 4287 validation: mae_score=0.770755
Step 4290 validation: mae_score=0.781799
Step 4293 validation: mae_score=0.7666
Step 4296 validation: mae_score=0.7141
Step 4299 validation: mae_score=0.734771
Step 4302 validation: mae_score=0.771511
Step 4305 validation: mae_score=0.798024
Step 4308 validation: mae_score=0.76921
Step 4311 validation: mae_score=0.739255
Step 4314 validation: mae_score=0.755829
Step 4317 validation: mae_score=0.731414
Step 4320 validation: mae_score=0.777151
Step 4323 validation: ma

Step 4854 validation: mae_score=0.724208
Step 4857 validation: mae_score=0.751881
Step 4860 validation: mae_score=0.785695
Step 4863 validation: mae_score=0.761498
Step 4866 validation: mae_score=0.714749
Step 4869 validation: mae_score=0.759326
Step 4872 validation: mae_score=0.762407
Step 4875 validation: mae_score=0.720631
Step 4878 validation: mae_score=0.753281
Step 4881 validation: mae_score=0.727526
Step 4884 validation: mae_score=0.740345
Step 4887 validation: mae_score=0.791747
Step 4890 validation: mae_score=0.746376
Step 4893 validation: mae_score=0.82912
Step 4896 validation: mae_score=0.840245
Step 4899 validation: mae_score=0.828228
Step 4902 validation: mae_score=0.722193
Step 4905 validation: mae_score=0.747996
Step 4908 validation: mae_score=0.716906
Step 4911 validation: mae_score=0.703408
Step 4914 validation: mae_score=0.733758
Step 4917 validation: mae_score=0.773733
Step 4920 validation: mae_score=0.747208
Step 4923 validation: mae_score=0.770625
Step 4926 validat

Step 5457 validation: mae_score=0.93979
Step 5460 validation: mae_score=0.760369
Step 5463 validation: mae_score=0.735177
Step 5466 validation: mae_score=0.872352
Step 5469 validation: mae_score=0.841157
Step 5472 validation: mae_score=0.788981
Step 5475 validation: mae_score=0.743748
Step 5478 validation: mae_score=0.781149
Step 5481 validation: mae_score=0.745234
Step 5484 validation: mae_score=0.764823
Step 5487 validation: mae_score=0.730441
Step 5490 validation: mae_score=0.758586
Step 5493 validation: mae_score=0.72408
Step 5496 validation: mae_score=0.736338
Step 5499 validation: mae_score=0.764953
Step 5502 validation: mae_score=0.734584
Step 5505 validation: mae_score=0.752582
Step 5508 validation: mae_score=0.808013
Step 5511 validation: mae_score=0.824834
Step 5514 validation: mae_score=0.739387
Step 5517 validation: mae_score=0.784643
Step 5520 validation: mae_score=0.771383
Step 5523 validation: mae_score=0.770015
Step 5526 validation: mae_score=0.783116
Step 5529 validati

Step 6060 validation: mae_score=0.751508
Step 6063 validation: mae_score=0.738849
Step 6066 validation: mae_score=0.73381
Step 6069 validation: mae_score=0.729495
Step 6072 validation: mae_score=0.710961
Step 6075 validation: mae_score=0.761663
Step 6078 validation: mae_score=0.720787
Step 6081 validation: mae_score=0.695571
Step 6084 validation: mae_score=0.720909
Step 6087 validation: mae_score=0.745974
Step 6090 validation: mae_score=0.762356
Step 6093 validation: mae_score=0.732089
Step 6096 validation: mae_score=0.7365
Step 6099 validation: mae_score=0.788328
Step 6102 validation: mae_score=0.747573
Step 6105 validation: mae_score=0.74432
Step 6108 validation: mae_score=0.731936
Step 6111 validation: mae_score=0.721483
Step 6114 validation: mae_score=0.740371
Step 6117 validation: mae_score=0.728246
Step 6120 validation: mae_score=0.755644
Step 6123 validation: mae_score=0.776285
Step 6126 validation: mae_score=0.713477
Step 6129 validation: mae_score=0.731756
Step 6132 validation

Step 6663 validation: mae_score=0.780591
Step 6666 validation: mae_score=0.771035
Step 6669 validation: mae_score=0.759076
Step 6672 validation: mae_score=0.7315
Step 6675 validation: mae_score=0.748998
Step 6678 validation: mae_score=0.757146
Step 6681 validation: mae_score=0.789739
Step 6684 validation: mae_score=0.819096
Step 6687 validation: mae_score=0.955336
Step 6690 validation: mae_score=0.762557
Step 6693 validation: mae_score=0.801932
Step 6696 validation: mae_score=0.905434
Step 6699 validation: mae_score=0.738577
Step 6702 validation: mae_score=0.761995
Step 6705 validation: mae_score=0.808748
Step 6708 validation: mae_score=0.812437
Step 6711 validation: mae_score=0.752859
Step 6714 validation: mae_score=0.737822
Step 6717 validation: mae_score=0.737255
Step 6720 validation: mae_score=0.762952
Step 6723 validation: mae_score=0.788909
Step 6726 validation: mae_score=0.78892
Step 6729 validation: mae_score=0.83405
Step 6732 validation: mae_score=0.849386
Step 6735 validation

Step 7266 validation: mae_score=0.731888
Step 7269 validation: mae_score=0.727222
Step 7272 validation: mae_score=0.707182
Step 7275 validation: mae_score=0.731427
Step 7278 validation: mae_score=0.716204
Step 7281 validation: mae_score=0.713257
Step 7284 validation: mae_score=0.758683
Step 7287 validation: mae_score=0.774321
Step 7290 validation: mae_score=0.73822
Step 7293 validation: mae_score=0.717451
Step 7296 validation: mae_score=0.78669
Step 7299 validation: mae_score=0.722222
Step 7302 validation: mae_score=0.72098
Step 7305 validation: mae_score=0.734349
Step 7308 validation: mae_score=0.758462
Step 7311 validation: mae_score=0.73702
Step 7314 validation: mae_score=0.751497
Step 7317 validation: mae_score=0.780387
Step 7320 validation: mae_score=0.771637
Step 7323 validation: mae_score=0.764439
Step 7326 validation: mae_score=0.718727
Step 7329 validation: mae_score=0.745089
Step 7332 validation: mae_score=0.773306
Step 7335 validation: mae_score=0.746466
Step 7338 validation

Step 7869 validation: mae_score=0.722885
Step 7872 validation: mae_score=0.687994
Step 7875 validation: mae_score=0.76461
Step 7878 validation: mae_score=0.718446
Step 7881 validation: mae_score=0.745541
Step 7884 validation: mae_score=0.789364
Step 7887 validation: mae_score=0.734419
Step 7890 validation: mae_score=0.712763
Step 7893 validation: mae_score=0.817942
Step 7896 validation: mae_score=0.768704
Step 7899 validation: mae_score=0.738154
Step 7902 validation: mae_score=0.765112
Step 7905 validation: mae_score=0.729825
Step 7908 validation: mae_score=0.71559
Step 7911 validation: mae_score=0.779688
Step 7914 validation: mae_score=0.706536
Step 7917 validation: mae_score=0.773201
Step 7920 validation: mae_score=0.760548
Step 7923 validation: mae_score=0.785001
Step 7926 validation: mae_score=0.740707
Step 7929 validation: mae_score=0.720203
Step 7932 validation: mae_score=0.753536
Step 7935 validation: mae_score=0.710302
Step 7938 validation: mae_score=0.71033
Step 7941 validatio

Step 8472 validation: mae_score=0.727059
Step 8475 validation: mae_score=0.820972
Step 8478 validation: mae_score=0.739374
Step 8481 validation: mae_score=0.764992
Step 8484 validation: mae_score=0.727697
Step 8487 validation: mae_score=0.758001
Step 8490 validation: mae_score=0.887204
Step 8493 validation: mae_score=0.795112
Step 8496 validation: mae_score=0.712342
Step 8499 validation: mae_score=0.750753
Step 8502 validation: mae_score=0.744303
Step 8505 validation: mae_score=0.754882
Step 8508 validation: mae_score=0.727751
Step 8511 validation: mae_score=0.737379
Step 8514 validation: mae_score=0.735242
Step 8517 validation: mae_score=0.711361
Step 8520 validation: mae_score=0.720344
Step 8523 validation: mae_score=0.753957
Step 8526 validation: mae_score=0.934964
Step 8529 validation: mae_score=0.752667
Step 8532 validation: mae_score=0.717625
Step 8535 validation: mae_score=0.757987
Step 8538 validation: mae_score=0.715423
Step 8541 validation: mae_score=0.72877
Step 8544 validat

Step 9075 validation: mae_score=0.765647
Step 9078 validation: mae_score=0.790267
Step 9081 validation: mae_score=0.745392
Step 9084 validation: mae_score=0.783379
Step 9087 validation: mae_score=0.792945
Step 9090 validation: mae_score=0.73647
Step 9093 validation: mae_score=0.715308
Step 9096 validation: mae_score=0.741164
Step 9099 validation: mae_score=0.76946
Step 9102 validation: mae_score=0.763961
Step 9105 validation: mae_score=0.743401
Step 9108 validation: mae_score=0.752212
Step 9111 validation: mae_score=0.76737
Step 9114 validation: mae_score=0.738829
Step 9117 validation: mae_score=0.745552
Step 9120 validation: mae_score=0.750861
Step 9123 validation: mae_score=0.739547
Step 9126 validation: mae_score=0.745571
Step 9129 validation: mae_score=0.748542
Step 9132 validation: mae_score=0.759506
Step 9135 validation: mae_score=0.758721
Step 9138 validation: mae_score=0.800951
Step 9141 validation: mae_score=0.749461
Step 9144 validation: mae_score=0.822758
Step 9147 validatio

Step 9678 validation: mae_score=0.733112
Step 9681 validation: mae_score=0.789951
Step 9684 validation: mae_score=0.743199
Step 9687 validation: mae_score=0.731453
Step 9690 validation: mae_score=0.735908
Step 9693 validation: mae_score=0.711508
Step 9696 validation: mae_score=0.709466
Step 9699 validation: mae_score=0.761759
Step 9702 validation: mae_score=0.719524
Step 9705 validation: mae_score=0.725137
Step 9708 validation: mae_score=0.70736
Step 9711 validation: mae_score=0.717571
Step 9714 validation: mae_score=0.761615
Step 9717 validation: mae_score=0.745517
Step 9720 validation: mae_score=0.76802
Step 9723 validation: mae_score=0.770977
Step 9726 validation: mae_score=0.792324
Step 9729 validation: mae_score=0.773801
Step 9732 validation: mae_score=0.755716
Step 9735 validation: mae_score=0.869665
Step 9738 validation: mae_score=0.738292
Step 9741 validation: mae_score=0.766797
Step 9744 validation: mae_score=0.797713
Step 9747 validation: mae_score=0.726517
Step 9750 validati

Step 10275 validation: mae_score=0.731741
Step 10278 validation: mae_score=0.708217
Step 10281 validation: mae_score=0.708428
Step 10284 validation: mae_score=0.766562
Step 10287 validation: mae_score=0.742049
Step 10290 validation: mae_score=0.717666
Step 10293 validation: mae_score=0.749845
Step 10296 validation: mae_score=0.719498
Step 10299 validation: mae_score=0.734802
Step 10302 validation: mae_score=0.713921
Step 10305 validation: mae_score=0.727168
Step 10308 validation: mae_score=0.710834
Step 10311 validation: mae_score=0.736002
Step 10314 validation: mae_score=0.879012
Step 10317 validation: mae_score=0.757719
Step 10320 validation: mae_score=0.762422
Step 10323 validation: mae_score=0.762966
Step 10326 validation: mae_score=0.768213
Step 10329 validation: mae_score=0.792774
Step 10332 validation: mae_score=0.737761
Step 10335 validation: mae_score=0.710115
Step 10338 validation: mae_score=0.729479
Step 10341 validation: mae_score=0.785796
Step 10344 validation: mae_score=0

Step 10863 validation: mae_score=0.739788
Step 10866 validation: mae_score=0.713971
Step 10869 validation: mae_score=0.732882
Step 10872 validation: mae_score=0.74391
Step 10875 validation: mae_score=0.756377
Step 10878 validation: mae_score=0.726585
Step 10881 validation: mae_score=0.763456
Step 10884 validation: mae_score=0.805756
Step 10887 validation: mae_score=0.746351
Step 10890 validation: mae_score=0.748749
Step 10893 validation: mae_score=0.763669
Step 10896 validation: mae_score=0.729195
Step 10899 validation: mae_score=0.761481
Step 10902 validation: mae_score=0.73991
Step 10905 validation: mae_score=0.763248
Step 10908 validation: mae_score=0.727531
Step 10911 validation: mae_score=0.735985
Step 10914 validation: mae_score=0.719355
Step 10917 validation: mae_score=0.730856
Step 10920 validation: mae_score=0.739128
Step 10923 validation: mae_score=0.778341
Step 10926 validation: mae_score=0.732435
Step 10929 validation: mae_score=0.740648
Step 10932 validation: mae_score=0.7

1.002245635986328

In [4]:
# little function to calc metrics on this data
out=get_them_metrics(
            model,
            datasets,
            metrics,
            metric_labels,
            transformers)
# makes a nice dataframe
pd_out = pd.DataFrame([out], columns=['tr_mse', 'tr_r2', 'tr_mae', 'tr_rmse',
                                        'val_mse', 'val_r2', 'val_mae', 'val_rmse',
                                        'te_mse', 'te_r2', 'te_mae', 'te_rmse'])
print(pd_out)

         tr_mse   tr_r2      tr_mae     tr_rmse       val_mse    val_r2  \
0  42155.950355  0.2159  164.936179  205.319143  44061.564081  0.146461   

      val_mae    val_rmse        te_mse    te_r2      te_mae    te_rmse  
0  167.858829  209.908466  42246.346131  0.19779  163.180392  205.53916  


## DTNN

In [5]:
patience=3

# This loads the data without shuffling or splitting
tasks, datasets, transformers = dc.molnet.load_qm7(
    shard_size=2000,
    featurizer=dc.feat.CoulombMatrix(max_atoms=23),
    splitter='stratified')

callback = tf.keras.callbacks.EarlyStopping(monitor='mae', patience=3)

# the datasets object is already split into the train, validation and test dataset 
train_dataset, valid_dataset, test_dataset = datasets

fit_transformers = [dc.trans.CoulombFitTransformer(train_dataset)]

# this loads in a general purpose regression model
model = dc.models.DTNNModel(
    n_tasks = len(test_dataset.tasks) # number of input features, i.e. the length of the ECFPs
)

# this sets up a callback on the validation
callback = dc.models.ValidationCallback(
            valid_dataset,
            patience,
            metrics[metric_selector])
# fit da model
model.fit(train_dataset, nb_epoch=100, callbacks=callback)



  "shape. This may consume a large amount of memory." % value)
  "shape. This may consume a large amount of memory." % value)


Step 3 validation: mae_score=1.50871
Step 6 validation: mae_score=1.04493
Step 9 validation: mae_score=1.16019
Step 12 validation: mae_score=0.824027
Step 15 validation: mae_score=0.989748
Step 18 validation: mae_score=0.803524
Step 21 validation: mae_score=0.854378
Step 24 validation: mae_score=0.851334
Step 27 validation: mae_score=0.776245
Step 30 validation: mae_score=0.786026
Step 33 validation: mae_score=0.764884
Step 36 validation: mae_score=0.778969
Step 39 validation: mae_score=0.778876
Step 42 validation: mae_score=0.755708
Step 45 validation: mae_score=0.751947
Step 48 validation: mae_score=0.754305
Step 51 validation: mae_score=0.75575
Step 54 validation: mae_score=0.749723
Step 57 validation: mae_score=0.747802
Step 60 validation: mae_score=0.749095
Step 63 validation: mae_score=0.749054
Step 66 validation: mae_score=0.749381
Step 69 validation: mae_score=0.748069
Step 72 validation: mae_score=0.74268
Step 75 validation: mae_score=0.741122
Step 78 validation: mae_score=0.7

Step 624 validation: mae_score=0.708815
Step 627 validation: mae_score=0.709289
Step 630 validation: mae_score=0.711828
Step 633 validation: mae_score=0.717042
Step 636 validation: mae_score=0.705424
Step 639 validation: mae_score=0.707046
Step 642 validation: mae_score=0.729086
Step 645 validation: mae_score=0.699956
Step 648 validation: mae_score=0.69669
Step 651 validation: mae_score=0.723748
Step 654 validation: mae_score=0.702335
Step 657 validation: mae_score=0.698974
Step 660 validation: mae_score=0.719593
Step 663 validation: mae_score=0.695229
Step 666 validation: mae_score=0.710256
Step 669 validation: mae_score=0.715575
Step 672 validation: mae_score=0.702594
Step 675 validation: mae_score=0.711823
Step 678 validation: mae_score=0.701322
Step 681 validation: mae_score=0.697568
Step 684 validation: mae_score=0.698638
Step 687 validation: mae_score=0.698668
Step 690 validation: mae_score=0.71688
Step 693 validation: mae_score=0.703433
Step 696 validation: mae_score=0.699625
St

Step 1236 validation: mae_score=0.70342
Step 1239 validation: mae_score=0.715111
Step 1242 validation: mae_score=0.722376
Step 1245 validation: mae_score=0.708685
Step 1248 validation: mae_score=0.708251
Step 1251 validation: mae_score=0.713925
Step 1254 validation: mae_score=0.730018
Step 1257 validation: mae_score=0.716761
Step 1260 validation: mae_score=0.722208
Step 1263 validation: mae_score=0.733893
Step 1266 validation: mae_score=0.719387
Step 1269 validation: mae_score=0.721206
Step 1272 validation: mae_score=0.719704
Step 1275 validation: mae_score=0.710249
Step 1278 validation: mae_score=0.707378
Step 1281 validation: mae_score=0.717346
Step 1284 validation: mae_score=0.715714
Step 1287 validation: mae_score=0.71954
Step 1290 validation: mae_score=0.70584
Step 1293 validation: mae_score=0.718299
Step 1296 validation: mae_score=0.738463
Step 1299 validation: mae_score=0.719704
Step 1302 validation: mae_score=0.720554
Step 1305 validation: mae_score=0.725196
Step 1308 validatio

Step 1839 validation: mae_score=0.72355
Step 1842 validation: mae_score=0.742868
Step 1845 validation: mae_score=0.760933
Step 1848 validation: mae_score=0.73711
Step 1851 validation: mae_score=0.734375
Step 1854 validation: mae_score=0.732105
Step 1857 validation: mae_score=0.747157
Step 1860 validation: mae_score=0.731025
Step 1863 validation: mae_score=0.72041
Step 1866 validation: mae_score=0.735682
Step 1869 validation: mae_score=0.718322
Step 1872 validation: mae_score=0.734353
Step 1875 validation: mae_score=0.736169
Step 1878 validation: mae_score=0.71314
Step 1881 validation: mae_score=0.713949
Step 1884 validation: mae_score=0.724115
Step 1887 validation: mae_score=0.731685
Step 1890 validation: mae_score=0.731158
Step 1893 validation: mae_score=0.730415
Step 1896 validation: mae_score=0.723333
Step 1899 validation: mae_score=0.719025
Step 1902 validation: mae_score=0.721906
Step 1905 validation: mae_score=0.726077
Step 1908 validation: mae_score=0.736752
Step 1911 validation

Step 2442 validation: mae_score=0.743052
Step 2445 validation: mae_score=0.741106
Step 2448 validation: mae_score=0.740374
Step 2451 validation: mae_score=0.742144
Step 2454 validation: mae_score=0.743835
Step 2457 validation: mae_score=0.746278
Step 2460 validation: mae_score=0.7416
Step 2463 validation: mae_score=0.745934
Step 2466 validation: mae_score=0.748329
Step 2469 validation: mae_score=0.750305
Step 2472 validation: mae_score=0.740829
Step 2475 validation: mae_score=0.735178
Step 2478 validation: mae_score=0.748416
Step 2481 validation: mae_score=0.737779
Step 2484 validation: mae_score=0.746499
Step 2487 validation: mae_score=0.752255
Step 2490 validation: mae_score=0.759091
Step 2493 validation: mae_score=0.749551
Step 2496 validation: mae_score=0.744673
Step 2499 validation: mae_score=0.742809
Step 2502 validation: mae_score=0.755404
Step 2505 validation: mae_score=0.755038
Step 2508 validation: mae_score=0.749736
Step 2511 validation: mae_score=0.752302
Step 2514 validati

Step 3045 validation: mae_score=0.771851
Step 3048 validation: mae_score=0.773509
Step 3051 validation: mae_score=0.764389
Step 3054 validation: mae_score=0.784115
Step 3057 validation: mae_score=0.782016
Step 3060 validation: mae_score=0.753064
Step 3063 validation: mae_score=0.747841
Step 3066 validation: mae_score=0.760918
Step 3069 validation: mae_score=0.742577
Step 3072 validation: mae_score=0.736362
Step 3075 validation: mae_score=0.743279
Step 3078 validation: mae_score=0.739459
Step 3081 validation: mae_score=0.744897
Step 3084 validation: mae_score=0.752997
Step 3087 validation: mae_score=0.756301
Step 3090 validation: mae_score=0.754779
Step 3093 validation: mae_score=0.753902
Step 3096 validation: mae_score=0.767769
Step 3099 validation: mae_score=0.758635
Step 3102 validation: mae_score=0.765651
Step 3105 validation: mae_score=0.810395
Step 3108 validation: mae_score=0.770734
Step 3111 validation: mae_score=0.770389
Step 3114 validation: mae_score=0.76253
Step 3117 validat

Step 3648 validation: mae_score=0.796911
Step 3651 validation: mae_score=0.771982
Step 3654 validation: mae_score=0.76852
Step 3657 validation: mae_score=0.784085
Step 3660 validation: mae_score=0.776166
Step 3663 validation: mae_score=0.764661
Step 3666 validation: mae_score=0.759044
Step 3669 validation: mae_score=0.765461
Step 3672 validation: mae_score=0.764799
Step 3675 validation: mae_score=0.750167
Step 3678 validation: mae_score=0.770369
Step 3681 validation: mae_score=0.766253
Step 3684 validation: mae_score=0.7709
Step 3687 validation: mae_score=0.769059
Step 3690 validation: mae_score=0.755848
Step 3693 validation: mae_score=0.756922
Step 3696 validation: mae_score=0.751767
Step 3699 validation: mae_score=0.765863
Step 3702 validation: mae_score=0.784541
Step 3705 validation: mae_score=0.77275
Step 3708 validation: mae_score=0.772478
Step 3711 validation: mae_score=0.766212
Step 3714 validation: mae_score=0.758783
Step 3717 validation: mae_score=0.765342
Step 3720 validation

Step 4251 validation: mae_score=0.770146
Step 4254 validation: mae_score=0.770013
Step 4257 validation: mae_score=0.774506
Step 4260 validation: mae_score=0.767606
Step 4263 validation: mae_score=0.776964
Step 4266 validation: mae_score=0.776767
Step 4269 validation: mae_score=0.779542
Step 4272 validation: mae_score=0.781965
Step 4275 validation: mae_score=0.782909
Step 4278 validation: mae_score=0.787936
Step 4281 validation: mae_score=0.785952
Step 4284 validation: mae_score=0.789742
Step 4287 validation: mae_score=0.794017
Step 4290 validation: mae_score=0.812276
Step 4293 validation: mae_score=0.785596
Step 4296 validation: mae_score=0.769281
Step 4299 validation: mae_score=0.766242
Step 4302 validation: mae_score=0.762733
Step 4305 validation: mae_score=0.766472
Step 4308 validation: mae_score=0.779774
Step 4311 validation: mae_score=0.788038
Step 4314 validation: mae_score=0.791793
Step 4317 validation: mae_score=0.787917
Step 4320 validation: mae_score=0.78634
Step 4323 validat

Step 4854 validation: mae_score=0.777066
Step 4857 validation: mae_score=0.781418
Step 4860 validation: mae_score=0.784753
Step 4863 validation: mae_score=0.788958
Step 4866 validation: mae_score=0.800736
Step 4869 validation: mae_score=0.789595
Step 4872 validation: mae_score=0.79065
Step 4875 validation: mae_score=0.784636
Step 4878 validation: mae_score=0.77769
Step 4881 validation: mae_score=0.789324
Step 4884 validation: mae_score=0.793409
Step 4887 validation: mae_score=0.78828
Step 4890 validation: mae_score=0.797806
Step 4893 validation: mae_score=0.812447
Step 4896 validation: mae_score=0.808927
Step 4899 validation: mae_score=0.808524
Step 4902 validation: mae_score=0.789034
Step 4905 validation: mae_score=0.79062
Step 4908 validation: mae_score=0.798306
Step 4911 validation: mae_score=0.793489
Step 4914 validation: mae_score=0.790572
Step 4917 validation: mae_score=0.7914
Step 4920 validation: mae_score=0.790956
Step 4923 validation: mae_score=0.797633
Step 4926 validation: 

Step 5457 validation: mae_score=0.804353
Step 5460 validation: mae_score=0.795312
Step 5463 validation: mae_score=0.789276
Step 5466 validation: mae_score=0.790251
Step 5469 validation: mae_score=0.801258
Step 5472 validation: mae_score=0.799493
Step 5475 validation: mae_score=0.81062
Step 5478 validation: mae_score=0.800783
Step 5481 validation: mae_score=0.80336
Step 5484 validation: mae_score=0.803415
Step 5487 validation: mae_score=0.80675
Step 5490 validation: mae_score=0.797925
Step 5493 validation: mae_score=0.799604
Step 5496 validation: mae_score=0.80387
Step 5499 validation: mae_score=0.796615


0.4357027053833008

In [6]:
# little function to calc metrics on this data
out=get_them_metrics(
            model,
            datasets,
            metrics,
            metric_labels,
            transformers)
# makes a nice dataframe
pd_out = pd.DataFrame([out], columns=['tr_mse', 'tr_r2', 'tr_mae', 'tr_rmse',
                                        'val_mse', 'val_r2', 'val_mae', 'val_rmse',
                                        'te_mse', 'te_r2', 'te_mae', 'te_rmse'])
print(pd_out)

         tr_mse     tr_r2      tr_mae     tr_rmse       val_mse    val_r2  \
0  19889.560397  0.606074  111.881211  141.030353  49244.905068  0.149077   

      val_mae    val_rmse        te_mse     te_r2      te_mae     te_rmse  
0  179.053167  221.911931  48458.332315  0.152918  173.510713  220.132534  


## Kernel ridge regression

In [7]:
train_dataset_X = [x.flatten() for x in train_dataset.X[0:5]]
train_dataset_y = [x.flatten() for x in train_dataset.y[0:5]]


In [8]:
from sklearn.kernel_ridge import KernelRidge
from sklearn.metrics import mean_absolute_error

tasks, datasets, transformers = dc.molnet.load_qm7(
    featurizer=dc.feat.CoulombMatrix(max_atoms=23), 
    splitter='stratified', 
    move_mean=False)

train_dataset, valid_dataset, test_dataset = datasets

train_dataset, valid_dataset, test_dataset = datasets

train_dataset_X = [x.flatten() for x in train_dataset.X]
train_dataset_y = [x.flatten() for x in train_dataset.y]

test_dataset_X = [x.flatten() for x in test_dataset.X]
test_dataset_y = [x.flatten() for x in test_dataset.y]

valid_dataset_X = [x.flatten() for x in valid_dataset.X]
valid_dataset_y = [x.flatten() for x in valid_dataset.y]

#train_dataset = 


sklearn_model = KernelRidge(kernel="rbf", alpha=5e-4, gamma=0.008)


#dc_model = dc.models.SklearnModel(sklearn_model)



# Fit trained model
sklearn_model.fit(train_dataset_X, train_dataset_y)

KernelRidge(alpha=0.0005, gamma=0.008, kernel='rbf')

In [9]:
a=mean_absolute_error(
    train_dataset_y,
    sklearn_model.predict(train_dataset_X))

b=mean_absolute_error(
    valid_dataset_y,
    sklearn_model.predict(valid_dataset_X))


c=mean_absolute_error(
    test_dataset_y,
    sklearn_model.predict(test_dataset_X))

print('Normalised values (I do not know how to unnormalise this data)')

print(f'train error sklearn {a}')
print(f'valid error sklearn {b}')
print(f'test error sklearn {c}')


Normalised values (I do not know how to unnormalise this data)
train error sklearn 0.0008054387649134457
valid error sklearn 0.6945489133221255
test error sklearn 0.6883037767813053


## I cannot get kernel ridge regression to work with Coulomb matrix

In [10]:
from sklearn.kernel_ridge import KernelRidge

tasks, datasets, transformers = dc.molnet.load_qm7(
    featurizer=dc.feat.CoulombMatrix(max_atoms=23), 
    splitter='stratified', 
    move_mean=False)

train_dataset, valid_dataset, test_dataset = datasets

train_dataset, valid_dataset, test_dataset = datasets

train_dataset, valid_dataset, test_dataset = datasets

train_dataset_X = [x.flatten() for x in train_dataset.X]
train_dataset_y = [x.flatten() for x in train_dataset.y]

test_dataset_X = [x.flatten() for x in test_dataset.X]
test_dataset_y = [x.flatten() for x in test_dataset.y]

valid_dataset_X = [x.flatten() for x in valid_dataset.X]
valid_dataset_y = [x.flatten() for x in valid_dataset.y]

#dataset = dc.data.DiskDataset.from_numpy(X)
train_dataset = dc.data.datasets.DiskDataset.from_numpy(
    X=train_dataset_X, 
    y=train_dataset_y)

train_dataset = dc.data.datasets.DiskDataset.from_numpy(
    X=valid_dataset_X, 
    y=valid_dataset_y)

train_dataset = dc.data.datasets.DiskDataset.from_numpy(
    X=test_dataset_X, 
    y=test_dataset_y)

def model_builder(model_dir):
  sklearn_model = KernelRidge(kernel="rbf", alpha=5e-4, gamma=0.008)
  return dc.models.SklearnModel(sklearn_model, model_dir)

dc_model = dc.models.SklearnModel(sklearn_model)

model = dc.models.SingletaskToMultitask(
    tasks, 
    model_builder)

# Fit trained model
model.fit(train_dataset)



In [11]:
model

SingletaskToMultitask(model_builder=<function model_builder at 0x000002408AA2AAF8>,
                      model_dir='C:\\Users\\ella_\\AppData\\Local\\Temp\\tmp8lmvntxq',
                      tasks=['u0_atom'])

In [12]:
# little function to calc metrics on this data
out=get_them_metrics(
            model,
            datasets,
            metrics,
            metric_labels,
            transformers)
# makes a nice dataframe
pd_out = pd.DataFrame([out], columns=['tr_mse', 'tr_r2', 'tr_mae', 'tr_rmse',
                                        'val_mse', 'val_r2', 'val_mae', 'val_rmse',
                                        'te_mse', 'te_r2', 'te_mae', 'te_rmse'])
print(pd_out)

ValueError: Found array with dim 3. Estimator expected <= 2.