A step-by-step walkthrough of what `utils_benchmark.test_model_performance()` does:

In [35]:
import json
import pandas as pd
import time

from rdkit import Chem
from transformers import BertTokenizer

import model_roberta
import utils_split
import utils_mol

In [2]:
# Biogen solubility dataset (columns: unnamed row id, logS, canonical_smiles).
BIOGEN_SOLUBILITY_URL = 'https://netknowledge.github.io/ADMET/datasets/solubility_Biogen.csv'
data_df = pd.read_csv(BIOGEN_SOLUBILITY_URL)

In [3]:
# Raw frame as loaded: an unnamed row-id column ('Unnamed: 0'), the logS target
# and the canonical SMILES strings.
data_df

Unnamed: 0,logS,canonical_smiles
0,-5.548020,CNc1cc(Nc2cccn(-c3ccccn3)c2=O)nn2c(C(=O)N[C@@H...
1,-5.071409,CCOc1cc2nn(CCC(C)(C)O)cc2cc1NC(=O)c1cccc(C(F)F)n1
2,-3.925969,CC(C)(Oc1ccc(-c2cnc(N)c(-c3ccc(Cl)cc3)c2)cc1)C...
3,-4.535280,CC#CC(=O)N[C@H]1CCCN(c2c(F)cc(C(N)=O)c3[nH]c(C...
4,-4.709963,C=CC(=O)N1CCC[C@@H](n2nc(-c3ccc(Oc4ccccc4)cc3)...
...,...,...
2168,-3.733424,Cc1cc(C)cc(C(=O)NCCCNc2ncccn2)c1
2169,-4.037319,CCc1noc(COc2c(C)ccnc2Cl)n1
2170,-4.912777,CC(C)(C)Cc1nnc(-c2cnc3onc(C4CCCC4)c3c2)o1
2171,-3.781930,Cc1nonc1C(=O)NCCc1c[nH]c2cccc(C3(O)CCOCC3)c12


In [5]:
# Parse each SMILES into an RDKit molecule. Chem.MolFromSmiles returns None for
# a string it cannot parse, so fail fast here instead of passing None values on
# to the scaffold-split and featurization steps.
data_df['mol'] = data_df['canonical_smiles'].apply(Chem.MolFromSmiles)
unparsed_smiles = data_df.loc[data_df['mol'].isna(), 'canonical_smiles']
assert unparsed_smiles.empty, (
    f'{len(unparsed_smiles)} SMILES failed to parse, e.g. {unparsed_smiles.head().tolist()}'
)

In [11]:
def get_split_index(mol_series, split_trial=5, split_outpath=None, valid_sizes=(.85, .15)):
    """Generate repeated scaffold splits for a series of molecules.

    Two families of splits are produced:

    * ``split_trial`` train/test splits from ``utils_split.scaffold_split`` with
      its default ``sizes``, using seeds ``0 .. split_trial - 1``.
    * For each of those, the train part is scaffold-split again into
      train/valid (for models with hyperparameter tuning) using the distinct
      seeds ``split_trial .. 2 * split_trial - 1`` and fractions
      ``valid_sizes``; the test part of the first split is appended unchanged.

    Args:
        mol_series: pandas Series of molecules. Splits refer to its index
            labels (the train part is re-selected with ``.loc``).
        split_trial: number of independent train/test splits.
        split_outpath: optional file path. When given, both split families
            are written there as JSON, keyed by seed (JSON turns the integer
            seeds into string keys).
        valid_sizes: ``sizes`` argument for the second (train/valid) split.
            Defaults to ``(.85, .15)``.

    Returns:
        ``(train_test_index, train_valid_test_index)``: a list of
        ``(train, test)`` pairs and a list of ``(train, valid, test)`` tuples,
        one entry per trial, in seed order.
    """
    tt_split_seeds = list(range(split_trial))
    train_test_index = [utils_split.scaffold_split(mol_series=mol_series, seed=seed)
                        for seed in tt_split_seeds]

    # Further split each train set for models with hyperparameter tuning, with
    # seeds that do not overlap the train/test ones.
    tvt_split_seeds = list(range(split_trial, split_trial * 2))
    train_valid_test_index = []
    for seed, (train_index, test_index) in zip(tvt_split_seeds, train_test_index):
        train_valid_index = utils_split.scaffold_split(
            mol_series=mol_series.loc[train_index], sizes=valid_sizes, seed=seed)
        # tuple() so the concatenation also works if the helper returns a list.
        train_valid_test_index.append(tuple(train_valid_index) + (test_index,))

    if split_outpath is not None:
        split_index = {
            'train_test_index': dict(zip(tt_split_seeds, train_test_index)),
            'train_valid_test_index': dict(zip(tvt_split_seeds, train_valid_test_index)),
        }
        with open(split_outpath, 'w') as fout:
            json.dump(split_index, fout)
    return train_test_index, train_valid_test_index

In [21]:
# Build the default five scaffold train/test splits (plus train/valid/test
# variants) over the parsed molecules and persist the indices to JSON.
train_test_index, train_valid_test_index = get_split_index(
    mol_series=data_df['mol'],
    split_outpath='solubility_Biogen_splitindex.json',
)

In [20]:
# Sanity check: every train/test split may only use existing row labels, and no
# row may appear in both train and test (that would leak test molecules into training).
valid_labels = set(data_df.index)
all(
    (set(train_idx) | set(test_idx)) <= valid_labels
    and set(train_idx).isdisjoint(test_idx)
    for train_idx, test_idx in train_test_index
)

True

In [25]:
# Featurize in place: adds the morgan_sentence_* columns (space-separated integer
# identifiers, one column per radius / ordering variant) directly to data_df;
# the call's result is not assigned.
utils_mol.append_morgan_sentence(data_df)

In [26]:
# Frame after parsing and featurization: the 'mol' column plus one
# morgan_sentence_* column per variant.
data_df

Unnamed: 0,logS,canonical_smiles,mol,morgan_sentence_r_0_s_0,morgan_sentence_r_1_s_0_radiusFirst,morgan_sentence_r_1_s_0_atomFirst,morgan_sentence_r_2_s_0_radiusFirst,morgan_sentence_r_2_s_0_atomFirst
0,-5.548020,CNc1cc(Nc2cccn(-c3ccccn3)c2=O)nn2c(C(=O)N[C@@H...,<rdkit.Chem.rdchem.Mol object at 0x7badd8be0dd0>,2246728737 847961216 3217380708 3218693969 321...,2246728737 3824063894 847961216 1965692378 321...,2246728737 847961216 3217380708 3218693969 321...,2246728737 3824063894 847961216 1965692378 311...,2246728737 847961216 3217380708 3218693969 321...
1,-5.071409,CCOc1cc2nn(CCC(C)(C)O)cc2cc1NC(=O)c1cccc(C(F)F)n1,<rdkit.Chem.rdchem.Mol object at 0x7badd8be0cf0>,2246728737 2245384272 864674487 3217380708 321...,2246728737 3542456614 2245384272 3994088662 86...,2246728737 2245384272 864674487 3217380708 321...,2246728737 3542456614 2245384272 3994088662 26...,2246728737 2245384272 864674487 3217380708 321...
2,-3.925969,CC(C)(Oc1ccc(-c2cnc(N)c(-c3ccc(Cl)cc3)c2)cc1)C...,<rdkit.Chem.rdchem.Mol object at 0x7badd8be0e40>,2246728737 2245277810 2246728737 864674487 321...,2246728737 3537123720 2245277810 2442433719 22...,2246728737 2245277810 2246728737 864674487 321...,2246728737 3537123720 2245277810 2442433719 41...,2246728737 2245277810 2246728737 864674487 321...
3,-4.535280,CC#CC(=O)N[C@H]1CCCN(c2c(F)cc(C(N)=O)c3[nH]c(C...,<rdkit.Chem.rdchem.Mol object at 0x7badd8be0eb0>,2246728737 2245900962 2245900962 2246699815 86...,2246728737 3545074552 2245900962 4291903839 22...,2246728737 2245900962 2245900962 2246699815 86...,2246728737 3545074552 2245900962 4291903839 27...,2246728737 2245900962 2245900962 2246699815 86...
4,-4.709963,C=CC(=O)N1CCC[C@@H](n2nc(-c3ccc(Oc4ccccc4)cc3)...,<rdkit.Chem.rdchem.Mol object at 0x7badd8be0f20>,2246997334 2246703798 2246699815 864942730 209...,2246997334 3696402029 2246703798 723026879 224...,2246997334 2246703798 2246699815 864942730 209...,2246997334 3696402029 2246703798 723026879 386...,2246997334 2246703798 2246699815 864942730 209...
...,...,...,...,...,...,...,...,...
2168,-3.733424,Cc1cc(C)cc(C(=O)NCCCNc2ncccn2)c1,<rdkit.Chem.rdchem.Mol object at 0x7badd8c34a50>,2246728737 3217380708 3218693969 3217380708 22...,2246728737 422715066 3217380708 3207567135 321...,2246728737 3217380708 3218693969 3217380708 22...,2246728737 422715066 3217380708 3207567135 255...,2246728737 3217380708 3218693969 3217380708 22...
2169,-4.037319,CCc1noc(COc2c(C)ccnc2Cl)n1,<rdkit.Chem.rdchem.Mol object at 0x7badd8c34ac0>,2246728737 2245384272 3217380708 2041434490 31...,2246728737 3542456614 2245384272 618671879 321...,2246728737 2245384272 3217380708 2041434490 31...,2246728737 3542456614 2245384272 618671879 370...,2246728737 2245384272 3217380708 2041434490 31...
2170,-4.912777,CC(C)(C)Cc1nnc(-c2cnc3onc(C4CCCC4)c3c2)o1,<rdkit.Chem.rdchem.Mol object at 0x7badd8c34b30>,2246728737 2245277810 2246728737 2246728737 22...,2246728737 3537123720 2245277810 1914229733 22...,2246728737 2245277810 2246728737 2246728737 22...,2246728737 3537123720 2245277810 1914229733 34...,2246728737 2245277810 2246728737 2246728737 22...
2171,-3.781930,Cc1nonc1C(=O)NCCc1c[nH]c2cccc(C3(O)CCOCC3)c12,<rdkit.Chem.rdchem.Mol object at 0x7badd8c34ba0>,2246728737 3217380708 2041434490 3189457552 20...,2246728737 422715066 3217380708 4033380444 204...,2246728737 3217380708 2041434490 3189457552 20...,2246728737 422715066 3217380708 4033380444 196...,2246728737 3217380708 2041434490 3189457552 20...


In [34]:
# Tokenizer for the pretrained MorganBERT checkpoint. run_roberta_like() reads this
# global and loads model weights from the same repo id, so keep the two in sync.
tokenizer = BertTokenizer.from_pretrained('Keylab/MorganBERT_r1_radius')

tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/173k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/94.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [36]:
def run_roberta_like(max_splits=1, num_epochs=30):
    """Fine-tune and evaluate MorganBERT (r=1, radius-first sentences) on logS.

    For each selected scaffold split in the notebook-level ``train_test_index``,
    calls ``model_roberta.run_roberta_like_train_test`` with the
    ``morgan_sentence_r_1_s_0_radiusFirst`` column as input text and ``logS``
    as the regression target, and records the elapsed time of that call.

    Relies on notebook globals from earlier cells: ``data_df``, ``tokenizer``
    and ``train_test_index``.

    Args:
        max_splits: how many splits, taken from the start of
            ``train_test_index``, to run. ``None`` runs all of them. The
            default of 1 evaluates only the first split (each split is a full
            fine-tuning run, so running all of them multiplies the runtime).
        num_epochs: number of training epochs passed to
            ``run_roberta_like_train_test``.

    Returns:
        A DataFrame with one row per evaluated split: the dict returned by
        ``run_roberta_like_train_test`` plus a ``total_runtime_sec`` column.

    Raises:
        ValueError: if ``max_splits`` is negative.
    """
    if max_splits is not None and max_splits < 0:
        raise ValueError('max_splits must be None or a non-negative int, got %r' % (max_splits,))

    model_path = 'Keylab/MorganBERT_r1_radius'
    # NOTE(review): every split writes to the same output_dir; confirm that later
    # splits overwriting earlier checkpoints is acceptable.
    output_dir = './tmp_models_%s/%s_%s' % ('MorganBERT', 'MorganBERT_r1_radius', 'Biogen-sol')
    feature_col = 'morgan_sentence_r_1_s_0_radiusFirst'
    label_col = 'logS'

    # Explicit cap on the number of splits (slicing with None keeps them all).
    splits = train_test_index[:max_splits]
    results = []
    for train_index, test_index in splits:
        # perf_counter is monotonic, unlike time.time(), so it is safe for durations.
        start = time.perf_counter()
        eval_result = model_roberta.run_roberta_like_train_test(
            model_path,
            tokenizer,
            data_df[feature_col].loc[train_index].tolist(),
            data_df[label_col].loc[train_index].tolist(),
            data_df[feature_col].loc[test_index].tolist(),
            data_df[label_col].loc[test_index].tolist(),
            'regression',
            output_dir,
            num_epochs,
        )
        eval_result['total_runtime_sec'] = time.perf_counter() - start
        results.append(eval_result)
    return pd.DataFrame(results)

In [37]:
# Fine-tunes and evaluates on the first scaffold split only; in the recorded run
# this took roughly 13 minutes (total_runtime_sec ~ 788).
model_perf = run_roberta_like()

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at Keylab/MorganBERT_r1_radius and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss
500,0.3461
1000,0.0809
1500,0.0482
2000,0.0343
2500,0.0249
3000,0.0195




In [38]:
# One row per evaluated split: the eval_* metrics returned by model_roberta
# (presumably computed on the held-out test indices passed to it) plus the
# measured total_runtime_sec.
model_perf

Unnamed: 0,eval_loss,eval_mse,eval_mae,eval_r2,eval_rmse,eval_pcc,eval_runtime,eval_samples_per_second,eval_steps_per_second,epoch,total_runtime_sec
0,0.315148,0.314924,0.398619,0.44044,0.561181,0.66427,2.2123,196.628,6.328,30.0,787.70544
