In [5]:
from sklearn.metrics import average_precision_score, roc_auc_score
import wandb

from datasets import load_dataset
import evaluate
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer, IntervalStrategy

import pandas as pd
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from tqdm.auto import tqdm

In [6]:
def tokenize_function(examples):
    return tokenizer(examples["smiles"], padding="max_length", truncation=True, max_length=300)

In [11]:
pretrained_path

'results/db_no_agree_no_dups/NCATS/seyonec/PubChem10M_SMILES_BPE_450k/checkpoint-550/'

In [14]:
split_type = 'db_no_agree_no_dups'
for dataset_name, checkpoint in [('DrugBank', 600), ('ChEMBL', 550), ('NCATS', 350)]:
    pretrained_path = f'results/{split_type}/{dataset_name}/seyonec/PubChem10M_SMILES_BPE_450k/checkpoint-{checkpoint}/'
    
        
    tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
    model = AutoModelForSequenceClassification.from_pretrained(pretrained_path, num_labels=2,
                                                           id2label={0: 'Not Withdrawn', 1:'Withdrawn'},
                                                           label2id={'Not Withdrawn': 0, 'Withdrawn': 1})
    
    dataset = load_dataset('csv', data_files={'train': f'split/{split_type}/{dataset_name}/train2.csv',
                                          'validation': f'split/{split_type}/{dataset_name}/val.csv',
                                          'test': f'split/{split_type}/{dataset_name}/test.csv',})
    dataset = dataset.rename_column('withdrawn_class', 'labels').\
            remove_columns(['Unnamed: 0', 'index', 'length', 'inchikey', 'groups', 'source']).\
            with_format('torch')
    dataset = dataset.map(tokenize_function, batched=True)
    
    preds = []
    for row in tqdm(dataset['test']):
        output = torch.softmax(model(row['input_ids'][None, ...]).logits, -1)
        preds.append((row['name'], round(output[:, 1].item(), 4), row['labels'].item()))
            
    train = pd.read_csv(f'split/{split_type}/{dataset_name}/train.csv')
    test = pd.read_csv(f'split/{split_type}/{dataset_name}/test.csv')
    
    # all the drugs in the test that occur in the train by name
    no_agree = test[test.name.isin(train.name)]
    # all the drugs that are with label=1 (i.e. in the train with 0)
    pos_only = no_agree[no_agree['withdrawn_class'] == 1]
    names = pos_only.name.tolist()
    
    # filter predictions to those only which we not agree on
    no_agree_preds = sorted(list(filter(lambda x: x[0] in names, preds)), key=lambda x: x[1], reverse=True)
    print(sorted(no_agree_preds, key=lambda x: x[1], reverse=True))

    correct = list(filter(lambda x: x[1] >= 0.5, no_agree_preds))
    
    print(f'Count: {len(no_agree_preds)}')
    print(f'Correct: {len(correct)}')
    print(f'Accuracy: {len(correct) / len(no_agree_preds)}')

Using custom data configuration default-74e83418e4d2c9a0
Found cached dataset csv (/home/eyal/.cache/huggingface/datasets/csv/default-74e83418e4d2c9a0/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/eyal/.cache/huggingface/datasets/csv/default-74e83418e4d2c9a0/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-5e1041380be4dd2c.arrow
Loading cached processed dataset at /home/eyal/.cache/huggingface/datasets/csv/default-74e83418e4d2c9a0/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-71a8b67d9d9df64d.arrow
Loading cached processed dataset at /home/eyal/.cache/huggingface/datasets/csv/default-74e83418e4d2c9a0/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-b8e66c7c5950463c.arrow


  0%|          | 0/2431 [00:00<?, ?it/s]

[('ifenprodil', 0.9928, 1), ('fendiline', 0.9911, 1), ('flubendazole', 0.9839, 1), ('nefazodone', 0.9827, 1), ('etifoxine', 0.98, 1), ('cefadroxil', 0.9788, 1), ('eprazinone', 0.9745, 1), ('tolcapone', 0.9714, 1), ('floctafenine', 0.9689, 1), ('benzbromarone', 0.9649, 1), ('hetacillin', 0.9592, 1), ('amlexanox', 0.9567, 1), ('thioridazine', 0.9548, 1), ('alosetron', 0.9519, 1), ('sertindole', 0.948, 1), ('hexoprenaline', 0.9384, 1), ('oxeladin', 0.9374, 1), ('zotepine', 0.9302, 1), ('cianidanol', 0.9211, 1), ('clobutinol', 0.8894, 1), ('acetohexamide', 0.8874, 1), ('acetarsol', 0.865, 1), ('thalidomide', 0.8593, 1), ('ranitidine', 0.7762, 1), ('hexachlorophene', 0.7481, 1), ('melphalan flufenamide', 0.7366, 1), ('viloxazine', 0.7204, 1), ('clioquinol', 0.6854, 1), ('methyclothiazide', 0.5722, 1), ('dexrazoxane', 0.497, 1), ('haloprogin', 0.4826, 1), ('testosterone propionate', 0.452, 1), ('lithium hydroxide', 0.4355, 1), ('hydroflumethiazide', 0.4113, 1), ('medrogestone', 0.395, 1), ('

Using custom data configuration default-26ef806f20143910
Found cached dataset csv (/home/eyal/.cache/huggingface/datasets/csv/default-26ef806f20143910/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/eyal/.cache/huggingface/datasets/csv/default-26ef806f20143910/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-45bf09ea091143ac.arrow
Loading cached processed dataset at /home/eyal/.cache/huggingface/datasets/csv/default-26ef806f20143910/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-d9249b61164f4062.arrow
Loading cached processed dataset at /home/eyal/.cache/huggingface/datasets/csv/default-26ef806f20143910/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-f78f2780c730b598.arrow


  0%|          | 0/2565 [00:00<?, ?it/s]

[('betamethasone benzoate', 0.9872, 1), ('tubocurarine', 0.9841, 1), ('rescinnamine', 0.9829, 1), ('cianidanol', 0.9828, 1), ('metocurine', 0.9818, 1), ('deserpidine', 0.9808, 1), ('dexamethasone acetate', 0.9806, 1), ('phenprocoumon', 0.9779, 1), ('paramethasone acetate', 0.9772, 1), ('triamcinolone', 0.9763, 1), ('hetacillin', 0.9716, 1), ('deslanoside', 0.9705, 1), ('novobiocin', 0.9675, 1), ('trimetrexate', 0.9671, 1), ('acetyldigitoxin', 0.9668, 1), ('trimethaphan', 0.9645, 1), ('erythromycin estolate', 0.9641, 1), ('clometacin', 0.9638, 1), ('digitoxin', 0.9623, 1), ('doxacurium', 0.9613, 1), ('prednisolone tebutate', 0.9595, 1), ('masoprocol', 0.9573, 1), ('meprednisone', 0.9549, 1), ('cyclothiazide', 0.9545, 1), ('fluprednisolone', 0.9541, 1), ('vincamine', 0.948, 1), ('troleandomycin', 0.9461, 1), ('fenoterol', 0.9397, 1), ('mazindol', 0.9392, 1), ('sulfaphenazole', 0.9384, 1), ('protokylol', 0.9339, 1), ('fenclofenac', 0.9305, 1), ('aminoglutethimide', 0.9295, 1), ('isoethari

Using custom data configuration default-91940da3d7934178
Found cached dataset csv (/home/eyal/.cache/huggingface/datasets/csv/default-91940da3d7934178/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/eyal/.cache/huggingface/datasets/csv/default-91940da3d7934178/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-9783c284089d6969.arrow
Loading cached processed dataset at /home/eyal/.cache/huggingface/datasets/csv/default-91940da3d7934178/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-a2e28bc4b6f3595b.arrow
Loading cached processed dataset at /home/eyal/.cache/huggingface/datasets/csv/default-91940da3d7934178/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317/cache-3a2d51c462208001.arrow


  0%|          | 0/4013 [00:00<?, ?it/s]

[('trimetrexate', 0.9968, 1), ('protokylol', 0.9957, 1), ('niclosamide', 0.9956, 1), ('mazindol', 0.9954, 1), ('masoprocol', 0.995, 1), ('cyclacillin', 0.9948, 1), ('stanozolol', 0.9946, 1), ('procyclidine', 0.9941, 1), ('amodiaquine', 0.994, 1), ('pyrvinium', 0.9938, 1), ('thiabendazole', 0.9935, 1), ('phenylbutazone', 0.9934, 1), ('spectinomycin', 0.9934, 1), ('triamcinolone', 0.9932, 1), ('buclizine', 0.993, 1), ('phenprocoumon', 0.9929, 1), ('cycrimine', 0.9922, 1), ('trimethaphan', 0.9914, 1), ('phenindione', 0.9914, 1), ('metocurine', 0.991, 1), ('cloxacillin', 0.9909, 1), ('mezlocillin', 0.9907, 1), ('carbenicillin', 0.9904, 1), ('anisindione', 0.9898, 1), ('dapiprazole', 0.9894, 1), ('phenyl aminosalicylate', 0.9893, 1), ('deserpidine', 0.9887, 1), ('noscapine', 0.9883, 1), ('methdilazine', 0.988, 1), ('quinestrol', 0.9875, 1), ('fluprednisolone', 0.987, 1), ('sulfamethazine', 0.9868, 1), ('mesoridazine', 0.9865, 1), ('medrysone', 0.986, 1), ('sulfabenzamide', 0.9858, 1), ('ami