In [1]:
# Load IPython's autoreload extension and select mode 2, so that edited modules are
# re-imported automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
from typing import Any, Optional, Sequence

import numpy as np 
import torch
import torch.nn as nn
import wandb
from huggingmolecules import MatModel, MatFeaturizer
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger  # newline 1
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics import Metric
from pytorch_lightning.metrics.functional.classification import auroc
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam

sys.path.insert(0, os.path.abspath('..'))

# The following import works only from the source code directory:
from experiments.src import TrainingModule, get_data_loaders

In [None]:
import glob
import pandas as pd

class AUROC(Metric):
    """Area under the ROC curve computed over all batches seen since the last reset.

    Each call to ``update`` stores the batch's predictions and targets; ``compute``
    concatenates everything stored so far and delegates to the functional ``auroc``.

    Args:
        sample_weight: Optional per-sample weights, passed unchanged to ``auroc``.
        compute_on_step: Forwarded to ``Metric``. Kept ``False`` by default because
            ``True`` likely crashes if not every batch contains all classes.
        dist_sync_on_step: Forwarded to ``Metric`` unchanged.
        process_group: Forwarded to ``Metric`` unchanged.
    """

    def __init__(
            self,
            sample_weight: Optional[Sequence] = None,
            compute_on_step: bool = False,  # True likely crashes if not every batch contains all classes
            dist_sync_on_step: bool = False,
            process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
        )

        self.sample_weight = sample_weight
        # List-valued states: one tensor is appended per ``update`` call.
        self.add_state("all_preds", default=[])
        self.add_state("all_target", default=[])

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Store one batch of predictions and targets for the final computation."""
        # Detach before storing: the metric never back-propagates, and keeping
        # graph-attached tensors would hold every batch's autograd graph in memory
        # until the state is reset.
        self.all_preds.append(preds.detach())
        self.all_target.append(target.detach())

    def compute(self):
        """Return the AUROC over every batch stored since the last reset."""
        # Concatenate along the first dimension and drop singleton dimensions
        # before handing the tensors to the functional implementation.
        preds_tensor = torch.cat(self.all_preds).squeeze()
        target_tensor = torch.cat(self.all_target).squeeze()
        return auroc(preds_tensor, target_tensor, sample_weight=self.sample_weight)

# Names of the datasets to run; the training loop iterates over them in this order.
datasets = [
    'bbbp',
    'clintox',
    'tox21',
    'hiv',
]
# printf-style template of a dataset's root folder; "%s" is filled with the dataset name.
data_dir_pattern = "datasets/mymoleculenet/%s/"


In [None]:
# Fine-tune the pre-trained MAT model on every split of every dataset, evaluate the
# best checkpoint of each split on its test set, and log per-split and aggregate
# ROC-AUC / average-precision scores to Weights & Biases.
for dataset in datasets:
    dataset_dir = data_dir_pattern % dataset
    # Every split lives in an entry whose name starts with a digit; the splits are
    # addressed as 0 .. n_splits-1 below. `glob.glob` replaces the undocumented
    # `glob.glob1`; the directory part is escaped so it is matched literally.
    n_splits = len(glob.glob(os.path.join(glob.escape(dataset_dir), "[0-9]*")))
    wandb_logger = WandbLogger(name='mat', mode='online', entity='dfstransformer', project='moleculenet10-baselines',
                               config={'dataset':dataset}, settings=wandb.Settings(start_method='fork'))
    rocs = []
    prcs = []
    for rep in range(n_splits):
        # Build and load the pre-trained model and the appropriate featurizer:
        model = MatModel.from_pretrained('mat_masking_20M')
        featurizer = MatFeaturizer.from_pretrained('mat_masking_20M')

        # Build the pytorch lightning training module:
        pl_module = TrainingModule(model,
                                   loss_fn=BCEWithLogitsLoss(),
                                   metric_cls=AUROC,
                                   optimizer=Adam(model.parameters(), lr=3e-6))
        pl_module.cuda()

        # Read the three CSV files of this split; each provides a "smiles" and a
        # "target" column.
        split_dir = os.path.join(dataset_dir, "%d" % rep)
        trainset = pd.read_csv(os.path.join(split_dir, "train.csv"))
        validset = pd.read_csv(os.path.join(split_dir, "valid.csv"))
        testset = pd.read_csv(os.path.join(split_dir, "test.csv"))
        train_X, train_y = trainset["smiles"].to_numpy(), trainset["target"].to_numpy()
        valid_X, valid_y = validset["smiles"].to_numpy(), validset["target"].to_numpy()
        test_X, test_y = testset["smiles"].to_numpy(), testset["target"].to_numpy()

        # Featurize the molecules and wrap them in data loaders; only the training
        # loader is shuffled so that test predictions stay aligned with `test_y`.
        train_data = featurizer.encode_smiles_list(train_X, train_y)
        valid_data = featurizer.encode_smiles_list(valid_X, valid_y)
        test_data = featurizer.encode_smiles_list(test_X, test_y)
        train_loader = featurizer.get_data_loader(train_data, batch_size=32, shuffle=True, num_workers=12)
        valid_loader = featurizer.get_data_loader(valid_data, batch_size=32, shuffle=False, num_workers=12)
        test_loader = featurizer.get_data_loader(test_data, batch_size=32, shuffle=False, num_workers=12)

        # Build the pytorch lightning trainer and fine-tune the module on the train dataset.
        # The checkpoint callback keeps the epoch with the highest "valid_auroc".
        checkpoint_callback = ModelCheckpoint(monitor="valid_auroc", mode="max")

        trainer = Trainer(max_epochs=25, gpus=[0], logger=wandb_logger, callbacks=[checkpoint_callback])
        trainer.fit(pl_module, train_dataloader=train_loader, val_dataloaders=[valid_loader])

        # Restore the best checkpoint into a fresh model for CPU evaluation.
        # `map_location='cpu'` avoids materialising the checkpoint on the GPU first.
        checkpoint = torch.load(checkpoint_callback.best_model_path, map_location='cpu')
        # The first 6 characters of every key are dropped -- presumably the "model."
        # prefix added by TrainingModule (TODO confirm). Keys that do not match the
        # bare model are ignored because of strict=False.
        best_weights = {key[6:]: value for key, value in checkpoint['state_dict'].items()}
        m = MatModel.from_pretrained('mat_masking_20M')
        m.load_state_dict(best_weights, strict=False)
        m = m.eval().cpu()

        # Inference only: no autograd graph is needed for the test predictions.
        with torch.no_grad():
            preds = [m(d).numpy() for d in test_loader]
        preds = np.concatenate(preds, axis=0)[:, 0]
        rocs += [roc_auc_score(test_y, preds)]
        prcs += [average_precision_score(test_y, preds)]
        wandb.log({'roc_test':rocs[-1], 'prc_test':prcs[-1]})

    if rocs:
        # 95% normal-approximation interval half-width: 1.96 * std / sqrt(n), where n
        # is the number of splits actually evaluated (previously hard-coded to 10).
        n_evaluated = len(rocs)
        wandb.log({'mean_roc_test':np.mean(rocs), 
                   'std_roc_test':np.std(rocs),
                   'ci_roc_test':1.96*np.std(rocs)/np.sqrt(n_evaluated),
                   'mean_prc_test':np.mean(prcs), 
                   'std_prc_test':np.std(prcs),
                   'ci_prc_test':1.96*np.std(prcs)/np.sqrt(n_evaluated)})

    # Close this dataset's run so the next WandbLogger starts a new run carrying its
    # own `config` instead of reusing the still-active one. No-op when no run exists.
    wandb.finish()