In [27]:
from sklearn.metrics import average_precision_score, roc_auc_score
import wandb

from datasets import load_dataset
import evaluate
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer, IntervalStrategy

import pandas as pd
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F


In [28]:
auc = evaluate.load("roc_auc")
accuracy = evaluate.load("accuracy")
metric = evaluate.load("accuracy")
f1 = evaluate.load("f1")
precison = evaluate.load("precision")
recall = evaluate.load("recall")

In [74]:
split_type = 'db_no_agree_no_dups'
dataset_name = 'DrugBank'
pretrained_path = "seyonec/PubChem10M_SMILES_BPE_450k"

In [75]:
dataset = load_dataset('csv', data_files={'train': f'split/{split_type}/{dataset_name}/train2.csv',
                                          'validation': f'split/{split_type}/{dataset_name}/val.csv',
                                          'test': f'split/{split_type}/{dataset_name}/test.csv',})

Using custom data configuration default-74e83418e4d2c9a0
Found cached dataset csv (/home/eyal/.cache/huggingface/datasets/csv/default-74e83418e4d2c9a0/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/3 [00:00<?, ?it/s]

In [76]:
dataset

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0', 'index', 'smiles', 'length', 'inchikey', 'name', 'groups', 'withdrawn_class', 'source'],
        num_rows: 3198
    })
    validation: Dataset({
        features: ['Unnamed: 0', 'index', 'smiles', 'length', 'inchikey', 'name', 'groups', 'withdrawn_class', 'source'],
        num_rows: 800
    })
    test: Dataset({
        features: ['Unnamed: 0', 'index', 'smiles', 'length', 'inchikey', 'name', 'groups', 'withdrawn_class', 'source'],
        num_rows: 2431
    })
})

In [77]:
dataset = dataset.rename_column('withdrawn_class', 'labels').\
            remove_columns(['Unnamed: 0', 'index', 'length', 'inchikey', 'groups', 'source']).\
            with_format('torch')

In [78]:
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
model = AutoModelForSequenceClassification.from_pretrained(pretrained_path, num_labels=2,
                                                           id2label={0: 'Not Withdrawn', 1:'Withdrawn'},
                                                           label2id={'Not Withdrawn': 0, 'Withdrawn': 1})

loading configuration file config.json from cache at /home/eyal/.cache/huggingface/hub/models--seyonec--PubChem10M_SMILES_BPE_450k/snapshots/c18fccd09b3326bf2d4633412c256d7db872156d/config.json
Model config RobertaConfig {
  "_name_or_path": "seyonec/PubChem10M_SMILES_BPE_450k",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.25.1",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 52000
}

loading file vocab.json from cache at /home/eyal/.cache/huggingface/hub/models

In [79]:
def tokenize_function(examples):
    return tokenizer(examples["smiles"], padding="max_length", truncation=True, max_length=300)

In [80]:
dataset = dataset.map(tokenize_function, batched=True)

  0%|          | 0/4 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

In [61]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy_score = accuracy.compute(predictions=predictions, references=labels)
    auc_score = auc.compute(prediction_scores=logits[:, 1], references=labels)
    f1_score = f1.compute(predictions=predictions, references=labels)
    aupr = average_precision_score(y_score=logits[:, 1], y_true=labels)
    precision_score = precison.compute(predictions=predictions, references=labels)
    recall_score = recall.compute(predictions=predictions, references=labels)
    return {**f1_score , **{'PR-AUC': aupr}, **accuracy_score, **auc_score, **precision_score, **recall_score}

In [11]:
training_args = TrainingArguments(
    output_dir=f"./results/{split_type}/{dataset_name}/{pretrained_path}",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    weight_decay=0.01,
    evaluation_strategy=IntervalStrategy.STEPS,
    save_strategy=IntervalStrategy.STEPS,
    report_to='wandb',
    run_name=f'{pretrained_path} {split_type} {dataset_name}',
    logging_steps=50,
    save_steps=50,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [14]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset={'Validation': dataset["validation"], 'Test': dataset["test"]},
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
trainer.train()

In [81]:
pretrained_path = 'results/db_no_agree_no_dups/DrugBank/seyonec/PubChem10M_SMILES_BPE_450k/checkpoint-550/'

In [82]:
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
model = AutoModelForSequenceClassification.from_pretrained(pretrained_path, num_labels=2,
                                                           id2label={0: 'Not Withdrawn', 1:'Withdrawn'},
                                                           label2id={'Not Withdrawn': 0, 'Withdrawn': 1})

loading file vocab.json
loading file merges.txt
loading file tokenizer.json
loading file added_tokens.json
loading file special_tokens_map.json
loading file tokenizer_config.json
loading configuration file results/db_no_agree_no_dups/DrugBank/seyonec/PubChem10M_SMILES_BPE_450k/checkpoint-550/config.json
Model config RobertaConfig {
  "_name_or_path": "results/db_no_agree_no_dups/DrugBank/seyonec/PubChem10M_SMILES_BPE_450k/checkpoint-550/",
  "architectures": [
    "RobertaForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "Not Withdrawn",
    "1": "Withdrawn"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "Not Withdrawn": 0,
    "Withdrawn": 1
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "r

In [83]:
from tqdm.auto import tqdm

In [84]:
preds = []
for row in tqdm(dataset['test']):
    output = torch.softmax(model(row['input_ids'][None, ...]).logits, -1)
    preds.append((row['name'], round(output[:, 1].item(), 4), row['labels'].item()))

  0%|          | 0/2431 [00:00<?, ?it/s]

In [85]:
preds

[('fondaparinux', 0.1192, 0),
 ('somatostatin', 0.4787, 0),
 ('degarelix', 0.6071, 0),
 ('afamelanotide', 0.2219, 0),
 ('vancomycin', 0.3895, 0),
 ('daptomycin', 0.1647, 0),
 ('vasopressin', 0.9756, 0),
 ('plecanatide', 0.0728, 0),
 ('hydroxocobalamin', 0.2348, 0),
 ('mecobalamin', 0.2074, 0),
 ('cyanocobalamin co-60', 0.0597, 1),
 ('lutetium lu 177 dotatate', 0.1153, 0),
 ('dotatate gallium ga-68', 0.1199, 0),
 ('ganirelix', 0.307, 0),
 ('copper oxodotreotide cu-64', 0.1143, 0),
 ('bacitracin', 0.0997, 0),
 ('histrelin', 0.1626, 0),
 ('edotreotide gallium ga-68', 0.088, 0),
 ('micafungin', 0.6008, 0),
 ('podophyllin', 0.9335, 0),
 ('linaclotide', 0.0999, 0),
 ('cetrorelix', 0.2081, 0),
 ('dotatate', 0.1093, 0),
 ('triptorelin', 0.2097, 0),
 ('nafarelin', 0.3783, 0),
 ('ceruletide', 0.3317, 0),
 ('anidulafungin', 0.1104, 0),
 ('bleomycin', 0.0963, 0),
 ('pentetreotide', 0.1237, 0),
 ('sucralfate', 0.2423, 0),
 ('dactinomycin', 0.2565, 0),
 ('depreotide', 0.2154, 1),
 ('deslanoside', 0.

In [86]:
sorted(preds, key=lambda x: x[1], reverse=True)

[('deferasirox', 0.9935, 0),
 ('oxyphenisatin acetate', 0.9932, 1),
 ('bentiromide', 0.9931, 1),
 ('bitolterol', 0.9924, 1),
 ('bisoxatin', 0.9923, 0),
 ('bifonazole', 0.9922, 0),
 ('clomifene', 0.9919, 0),
 ('ifenprodil', 0.9917, 1),
 ('brilliant green cation', 0.9916, 0),
 ('bisacodyl', 0.9916, 0),
 ('telmisartan', 0.9915, 0),
 ('gentian violet cation', 0.9913, 0),
 ('ethyl biscoumacetate', 0.9912, 1),
 ('brilliant blue g', 0.9911, 0),
 ('candesartan cilexetil', 0.9911, 0),
 ('imatinib', 0.9911, 0),
 ('patent blue', 0.991, 0),
 ('isosulfan blue', 0.9907, 0),
 ('prenylamine', 0.9906, 1),
 ('oxyphenisatine', 0.9905, 1),
 ('olmesartan', 0.9902, 0),
 ('fendiline', 0.99, 1),
 ('cinnarizine', 0.9898, 0),
 ('phenolphthalein', 0.9898, 1),
 ('boscalid', 0.9898, 0),
 ('dabigatran', 0.9897, 0),
 ('pyrvinium', 0.9897, 0),
 ('toremifene', 0.9896, 0),
 ('conivaptan', 0.9894, 0),
 ('fenofibric acid', 0.9892, 0),
 ('clotrimazole', 0.9891, 0),
 ('etoricoxib', 0.9888, 0),
 ('naftifine', 0.9888, 0),
 (

In [89]:
names = ['acetarsol', 'acetohexamide', 'alosetron', 'amlexanox', 'benzbromarone', 'boceprevir', 'carmofur',
         'cefadroxil', 'chlophedianol', 'cianidanol', 'clioquinol', 'clobutinol', 'dexrazoxane', 'eflornithine',
         'eprazinone', 'etifoxine', 'fendiline', 'floctafenine', 'flubendazole', 'formestane', 'gemeprost',
         'halcinonide', 'haloprogin', 'hetacillin', 'hexachlorophene', 'hexoprenaline', 'hydroflumethiazide',
         'ifenprodil', 'lithium hydroxide', 'medrogestone', 'melphalan flufenamide', 'methyclothiazide',
         'nefazodone', 'oxeladin', 'ranitidine', 'sertindole', 'testosterone propionate', 'thalidomide',
         'thioridazine', 'tolcapone', 'viloxazine', 'zotepine']

In [100]:
sorted(list(filter(lambda x: x[0] in names, preds)), key=lambda x: x[1], reverse=True)

[('ifenprodil', 0.9917, 1),
 ('fendiline', 0.99, 1),
 ('flubendazole', 0.9815, 1),
 ('nefazodone', 0.98, 1),
 ('etifoxine', 0.9766, 1),
 ('cefadroxil', 0.9753, 1),
 ('eprazinone', 0.9701, 1),
 ('tolcapone', 0.9628, 1),
 ('floctafenine', 0.9626, 1),
 ('benzbromarone', 0.9569, 1),
 ('hetacillin', 0.9506, 1),
 ('amlexanox', 0.9479, 1),
 ('thioridazine', 0.9466, 1),
 ('alosetron', 0.9412, 1),
 ('sertindole', 0.9377, 1),
 ('oxeladin', 0.9231, 1),
 ('zotepine', 0.9179, 1),
 ('hexoprenaline', 0.9153, 1),
 ('cianidanol', 0.8934, 1),
 ('acetohexamide', 0.8633, 1),
 ('clobutinol', 0.8561, 1),
 ('thalidomide', 0.8289, 1),
 ('acetarsol', 0.8265, 1),
 ('ranitidine', 0.7294, 1),
 ('hexachlorophene', 0.6978, 1),
 ('viloxazine', 0.6906, 1),
 ('melphalan flufenamide', 0.6849, 1),
 ('clioquinol', 0.6352, 1),
 ('methyclothiazide', 0.5087, 1),
 ('dexrazoxane', 0.4578, 1),
 ('haloprogin', 0.4337, 1),
 ('testosterone propionate', 0.4104, 1),
 ('lithium hydroxide', 0.3991, 1),
 ('hydroflumethiazide', 0.3711,

In [95]:
(42 -13) / 42

0.6904761904761905

In [15]:
predictions = trainer.predict(dataset["test"])

The following columns in the test set don't have a corresponding argument in `RobertaForSequenceClassification.forward` and have been ignored: smiles. If smiles are not expected by `RobertaForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 3601
  Batch size = 4
You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [16]:
print(predictions.predictions.shape, predictions.label_ids.shape)

(3601, 2) (3601,)


In [17]:
preds = torch.softmax(torch.from_numpy(predictions.predictions), -1)[:, 1]

In [24]:
sorted(preds, reverse=True)

[tensor(0.9910),
 tensor(0.9909),
 tensor(0.9905),
 tensor(0.9904),
 tensor(0.9899),
 tensor(0.9898),
 tensor(0.9894),
 tensor(0.9888),
 tensor(0.9887),
 tensor(0.9878),
 tensor(0.9878),
 tensor(0.9877),
 tensor(0.9872),
 tensor(0.9870),
 tensor(0.9869),
 tensor(0.9866),
 tensor(0.9851),
 tensor(0.9851),
 tensor(0.9841),
 tensor(0.9836),
 tensor(0.9834),
 tensor(0.9831),
 tensor(0.9829),
 tensor(0.9823),
 tensor(0.9822),
 tensor(0.9822),
 tensor(0.9815),
 tensor(0.9800),
 tensor(0.9800),
 tensor(0.9792),
 tensor(0.9791),
 tensor(0.9788),
 tensor(0.9786),
 tensor(0.9778),
 tensor(0.9772),
 tensor(0.9765),
 tensor(0.9753),
 tensor(0.9751),
 tensor(0.9743),
 tensor(0.9726),
 tensor(0.9721),
 tensor(0.9719),
 tensor(0.9711),
 tensor(0.9687),
 tensor(0.9686),
 tensor(0.9680),
 tensor(0.9679),
 tensor(0.9677),
 tensor(0.9677),
 tensor(0.9674),
 tensor(0.9670),
 tensor(0.9668),
 tensor(0.9665),
 tensor(0.9655),
 tensor(0.9653),
 tensor(0.9640),
 tensor(0.9629),
 tensor(0.9620),
 tensor(0.9620

In [18]:
torch.argsort(preds, descending=True)[:10]

tensor([1674, 1923, 1075, 1141, 1176,  934, 2204,  813, 2105, 1870])

In [22]:
preds[[1674, 1923]]

tensor([0.9910, 0.9909])

In [25]:
model(dataset['test'][[1674, 1923]]['input_ids'].cuda())

SequenceClassifierOutput(loss=None, logits=tensor([[-2.0995,  1.9418],
        [-1.8883,  1.8032]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [26]:
predictions.predictions[[1674, 1923]]

array([[-2.4405417,  2.263262 ],
       [-2.4196503,  2.2723093]], dtype=float32)

In [23]:
torch.softmax(model(dataset['test'][[1674, 1923]]['input_ids'].cuda()).logits, -1)[:, 1]

tensor([0.9827, 0.9757], device='cuda:0', grad_fn=<SelectBackward0>)

In [76]:
dataset['test'][666]

{'smiles': 'CCN1CCN(CC2=CC=C(NC3=NC=C(F)C(=N3)C3=CC(F)=C4N=C(C)N(C(C)C)C4=C3)N=C2)CC1',
 'labels': tensor(0),
 'input_ids': tensor([  0, 289,  21, 289,  12, 262,  22,  33, 262,  33,  39,  12, 270,  23,
          33, 270,  33,  39,  12,  42,  13,  39, 263,  50,  23,  13,  39,  23,
          33, 262,  12,  42, 287,  39,  24,  50,  33,  39,  12,  39,  13,  50,
          12,  39,  12,  39,  13,  39,  13,  39,  24,  33,  39,  23,  13,  50,
          33,  39,  22,  13, 262,  21,   2,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
        

In [52]:
preds[[1674, 1923, 1075, 1141, 1176,  934, 2204,  813, 2105, 1870]]

tensor([0.9910, 0.9909, 0.9905, 0.9904, 0.9899, 0.9898, 0.9894, 0.9888, 0.9887,
        0.9878])

In [54]:
predictions.label_ids[[1674, 1923, 1075, 1141, 1176,  934, 2204,  813, 2105, 1870]]

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [45]:
sorted(, reverse=True)

[tensor(0.9910),
 tensor(0.9909),
 tensor(0.9905),
 tensor(0.9904),
 tensor(0.9899),
 tensor(0.9898),
 tensor(0.9894),
 tensor(0.9888),
 tensor(0.9887),
 tensor(0.9878),
 tensor(0.9878),
 tensor(0.9877),
 tensor(0.9872),
 tensor(0.9870),
 tensor(0.9869),
 tensor(0.9866),
 tensor(0.9851),
 tensor(0.9851),
 tensor(0.9841),
 tensor(0.9836),
 tensor(0.9834),
 tensor(0.9831),
 tensor(0.9829),
 tensor(0.9823),
 tensor(0.9822),
 tensor(0.9822),
 tensor(0.9815),
 tensor(0.9800),
 tensor(0.9800),
 tensor(0.9792),
 tensor(0.9791),
 tensor(0.9788),
 tensor(0.9786),
 tensor(0.9778),
 tensor(0.9772),
 tensor(0.9765),
 tensor(0.9753),
 tensor(0.9751),
 tensor(0.9743),
 tensor(0.9726),
 tensor(0.9721),
 tensor(0.9719),
 tensor(0.9711),
 tensor(0.9687),
 tensor(0.9686),
 tensor(0.9680),
 tensor(0.9679),
 tensor(0.9677),
 tensor(0.9677),
 tensor(0.9674),
 tensor(0.9670),
 tensor(0.9668),
 tensor(0.9665),
 tensor(0.9655),
 tensor(0.9653),
 tensor(0.9640),
 tensor(0.9629),
 tensor(0.9620),
 tensor(0.9620