In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))
from lib.pipeline import Pipeline
import torch

GPU = 1

def make_pipeline(weight_file):
    """Build an 'lm-gearnet' Pipeline on 'atpbind3d' and optionally load weights.

    Args:
        weight_file: path of a checkpoint passed to ``torch.load``, or ``None``
            to keep the freshly constructed weights.

    Returns:
        The constructed ``Pipeline``. When ``weight_file`` is given, its state
        dict has been loaded into ``pipeline.task`` with ``strict=False``, so
        missing or unexpected keys are ignored rather than raising.
    """
    built = Pipeline(
        model='lm-gearnet',
        dataset='atpbind3d',
        gpus=[GPU],
        model_kwargs=dict(
            gpu=GPU,
            gearnet_hidden_dim_size=512,
            gearnet_hidden_dim_count=4,
            bert_freeze=False,
            bert_freeze_layer_count=29,
        ),
        optimizer_kwargs=dict(lr=5e-4),
        task_kwargs=dict(use_rus=True, rus_seed=0),
        bce_weight=1,
        batch_size=4,
    )
    # Nothing to restore: hand back the freshly initialised pipeline.
    if weight_file is None:
        return built

    # Map the checkpoint tensors straight onto the GPU the pipeline runs on.
    checkpoint = torch.load(weight_file, map_location=f'cuda:{GPU}')
    built.task.load_state_dict(checkpoint, strict=False)
    return built

# Checkpoint file per seed index (the index is also embedded in the file name).
# The names agree with the first two entries of `ensembled_task_file`.
trained_rus_model_file = {
    0: 'rus_pipeline_0_0.59580.pth',
    1: 'rus_pipeline_1_0.59290.pth',
}
# Two separately constructed pipelines, each initialised from the same checkpoint file.
# NOTE(review): both calls load the identical file - confirm this is intended.
p1 = make_pipeline(weight_file='rus_pipeline_0_0.432289.pth')
p2 = make_pipeline(weight_file='rus_pipeline_0_0.432289.pth')

get dataset atpbind3d
Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


In [19]:
from torchdrug import core, models
from lib.tasks import NodePropertyPrediction
from lib.custom_models import LMGearNetModel
from torchdrug import transforms, data, core, layers, tasks, metrics, utils, models
from torchdrug.layers import functional, geometry
from torchdrug.core import Registry as R
import torch
from torch.utils import data as torch_data
from torch.nn import functional as F
import contextlib
import logging
import numpy as np
from functools import cache

# This task is meant to be used for inference only, so overriding `predict` should be sufficient:
# engine.evaluate only relies on task.predict_and_target and task.evaluate.
class MeanEnsembleNodePropertyPrediction(NodePropertyPrediction):
    """Node-property task whose prediction is the mean over several checkpoints.

    Each call to ``predict`` loads every checkpoint listed in
    ``ensembled_task_file`` into this task in turn, runs the parent class's
    ``predict`` and averages the results.

    Args:
        model: forwarded to ``NodePropertyPrediction.__init__``.
        ensembled_task_file: mapping ``rus_seed -> checkpoint path``. The key is
            assigned to ``self.rus_seed`` before the matching checkpoint is used.
        *args, **kwargs: forwarded unchanged to the parent constructor.
    """
    def __init__(self, model, ensembled_task_file, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
        self.ensembled_task_file = ensembled_task_file
    
    def predict(self, batch, all_loss=None, metric=None):
        """Return the element-wise mean of the ensemble members' predictions.

        Side effects: the task's weights, ``use_rus`` and ``rus_seed`` are left
        at the values of the last ensemble member once this call returns.

        Raises:
            ValueError: if ``ensembled_task_file`` is empty (or ``None``), since
                there is nothing to average.
        """
        if not self.ensembled_task_file:
            raise ValueError('ensembled_task_file must contain at least one checkpoint')

        prediction_sum = None
        for rus_seed, model_file in self.ensembled_task_file.items():
            # Every checkpoint is re-read from disk on each call, which is costly
            # when predicting many batches.
            # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
            self.load_state_dict(torch.load(model_file, map_location=self.device), strict=False)
            self.use_rus = True
            self.rus_seed = rus_seed
            current_prediction = super().predict(batch, all_loss, metric)
            if prediction_sum is None:
                prediction_sum = current_prediction
            else:
                # Out-of-place add so the first member's output tensor is never mutated.
                prediction_sum = prediction_sum + current_prediction

        return prediction_sum / len(self.ensembled_task_file)



Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# model = LMGearNetModel(
#     gpu=GPU,
#     gearnet_hidden_dim_size=512,
#     gearnet_hidden_dim_count=4,
#     bert_freeze=False,
#     bert_freeze_layer_count=29,
# )

# edge_layers = [
#     geometry.SpatialEdge(radius=10, min_distance=5),
#     geometry.KNNEdge(k=10, min_distance=5),
#     geometry.SequentialEdge(max_distance=2),
# ]
    
# graph_construction_model = layers.GraphConstruction(
#     node_layers=[geometry.AlphaCarbonNode()],
#     edge_layers=edge_layers,
#     edge_feature="gearnet"
# )
# me_task = MeanEnsembleNodePropertyPrediction(
#     model=model, 
#     ensembled_task_file={
#         0: 'rus_pipeline_0_0.59580.pth',
#         1: 'rus_pipeline_1_0.59290.pth',
#         2: 'rus_pipeline_2_0.55640.pth',
#     },
#     graph_construction_model=graph_construction_model,
#     normalization=False,
#     num_mlp_layer=2,
#     metric=("sensitivity", "specificity", "accuracy", "precision", "mcc", "micro_auroc",),
#     bce_weight=torch.tensor([1], device=torch.device(f'cuda:{GPU}')),
# )

In [74]:
# Checkpoints averaged by MeanEnsembleNodePropertyPrediction. The position in the
# list becomes the dict key (0..5), which that task uses as `rus_seed` when the
# corresponding checkpoint is loaded.
ensembled_task_file = dict(enumerate([
    'rus_pipeline_0_0.59580.pth',
    'rus_pipeline_1_0.59290.pth',
    'rus_pipeline_2_0.55640.pth',
    'rus_pipeline_3_0.56080.pth',
    'rus_pipeline_4_0.59780.pth',
    'rus_pipeline_5_0.54520.pth',
]))

In [75]:
from torchdrug import transforms, data, core, layers, tasks, metrics, utils, models
from torchdrug.layers import functional, geometry
from torchdrug.core import Registry as R
import torch
from torch.utils import data as torch_data
from torch.nn import functional as F
import contextlib
import logging
import numpy as np
from functools import cache

from lib.tasks import NodePropertyPrediction
from lib.datasets import ATPBind, ATPBind3D
from lib.bert import BertWrapModel
from lib.custom_models import GearNetWrapModel, LMGearNetModel
from lib.utils import dict_tensor_to_num, round_dict
from lib.pipeline import get_dataset
class DisableLogger():
    """Context manager that suppresses all logging up to CRITICAL while active.

    On exit the ``logging.disable`` level that was in force on entry is
    restored, so nested use (or a level set by the caller beforehand) is not
    reset to ``NOTSET``. Exceptions raised inside the block propagate.
    """
    def __init__(self):
        # Stack of saved levels, so one instance can be entered more than once.
        self._previous_levels = []

    def __enter__(self):
        # logging.disable() stores its level on the root manager; remember it.
        self._previous_levels.append(logging.root.manager.disable)
        logging.disable(logging.CRITICAL)
        return self

    def __exit__(self, exit_type, exit_value, exit_traceback):
        logging.disable(self._previous_levels.pop())


# Metric names handed to the task as `metric=` when CustomPipeline builds it.
METRICS_USING = (
    "sensitivity",
    "specificity",
    "accuracy",
    "precision",
    "mcc",
    "micro_auroc",
)
class CustomPipeline:
    """Wires a model, a dataset split, an ensemble task and a ``core.Engine`` together.

    The task is always a ``MeanEnsembleNodePropertyPrediction``. Its checkpoint
    mapping comes from ``task_kwargs['ensembled_task_file']`` when present and
    otherwise from the module-level ``ensembled_task_file``.

    Attributes set by ``__init__``: ``gpus``, ``model``, ``train_set``,
    ``valid_set``, ``test_set``, ``task``, ``solver`` and ``verbose``.
    """
    possible_models = ['bert', 'gearnet', 'lm-gearnet', 'cnn']
    possible_datasets = ['atpbind', 'atpbind3d', 'atpbind3d-minimal']
    threshold = 0
    
    def __init__(self, 
                 model,
                 dataset,
                 gpus,
                 model_kwargs={},
                 optimizer_kwargs={},
                 task_kwargs={},
                 graph_knn_k=10,
                 graph_spatial_radius=10.0,
                 graph_sequential_max_distance=2,
                 batch_size=1,
                 bce_weight=1,
                 verbose=False,
                 ):
        """Construct the model, datasets, task, optimizer and engine.

        Args:
            model: one of ``possible_models``.
            dataset: one of ``possible_datasets``.
            gpus: list of GPU indices; ``gpus[0]`` hosts the ``bce_weight``
                tensor and evaluation batches, and the list is passed to the engine.
            model_kwargs: extra keyword arguments for the model constructor.
            optimizer_kwargs: keyword arguments for ``torch.optim.Adam``.
            task_kwargs: extra keyword arguments for the task constructor.
            graph_knn_k: ``k`` of the KNN edge layer (3D datasets only).
            graph_spatial_radius: radius of the spatial edge layer (3D datasets only).
            graph_sequential_max_distance: ``max_distance`` of the sequential edge
                layer; also forwarded to the 'gearnet' and 'lm-gearnet' models.
            batch_size: batch size given to the engine.
            bce_weight: scalar wrapped into a one-element tensor for the task.
            verbose: when false, ``train_until_fit`` silences logging.

        Raises:
            ValueError: if ``model`` or ``dataset`` is not a supported name.
        """
        # The dict defaults are only read (unpacked or copied), never mutated.
        self.gpus = gpus

        if model not in self.possible_models:
            raise ValueError('Model must be one of {}'.format(self.possible_models))
    
        if dataset not in self.possible_datasets:
            raise ValueError('Dataset must be one of {}'.format(self.possible_datasets))
           
        with DisableLogger():     
            if model == 'bert':
                self.model = BertWrapModel(**model_kwargs)
            elif model == 'gearnet':
                self.model = GearNetWrapModel(graph_sequential_max_distance=graph_sequential_max_distance, **model_kwargs)
            elif model == 'lm-gearnet':
                self.model = LMGearNetModel(graph_sequential_max_distance=graph_sequential_max_distance, **model_kwargs)
            elif model == 'cnn':
                self.model = models.ProteinCNN(**model_kwargs)
        
        self.train_set, self.valid_set, self.test_set = get_dataset(dataset)

        # Both dataset branches need the checkpoint mapping. Work on a copy so the
        # caller's dict is untouched, and let an explicit task_kwargs entry win
        # over the module-level default.
        task_options = dict(task_kwargs)
        if 'ensembled_task_file' not in task_options:
            task_options['ensembled_task_file'] = ensembled_task_file
        device = torch.device(f'cuda:{self.gpus[0]}')
        
        if dataset == 'atpbind':
            self.task = MeanEnsembleNodePropertyPrediction(
                self.model, 
                normalization=False,
                num_mlp_layer=2,
                metric=METRICS_USING,
                bce_weight=torch.tensor([bce_weight], device=device),
                **task_options,
            )
        elif dataset == 'atpbind3d' or dataset == 'atpbind3d-minimal':
            edge_layers = [
                geometry.SpatialEdge(radius=graph_spatial_radius, min_distance=5),
                geometry.KNNEdge(k=graph_knn_k, min_distance=5),
                geometry.SequentialEdge(max_distance=graph_sequential_max_distance),
            ]
                
            graph_construction_model = layers.GraphConstruction(
                node_layers=[geometry.AlphaCarbonNode()],
                edge_layers=edge_layers,
                edge_feature="gearnet"
            )
            self.task = MeanEnsembleNodePropertyPrediction(
                self.model,
                graph_construction_model=graph_construction_model,
                normalization=False,
                num_mlp_layer=2,
                metric=METRICS_USING,
                bce_weight=torch.tensor([bce_weight], device=device),
                **task_options,
            )
        
        optimizer = torch.optim.Adam(self.model.parameters(), **optimizer_kwargs)
        with DisableLogger():
            self.solver = core.Engine(self.task,
                                        self.train_set,
                                        self.valid_set,
                                        self.test_set,
                                        optimizer,
                                        batch_size=batch_size,
                                        log_interval=1000000000,
                                        gpus=gpus
            )
        
        self.verbose = verbose
        
    def train(self, num_epoch):
        """Run ``num_epoch`` training epochs through the engine and return its result."""
        return self.solver.train(num_epoch=num_epoch)
    
    def train_until_fit(self, patience=1):
        """Train one epoch at a time until the validation MCC stops improving.

        After each epoch the test metrics, the last-epoch training BCE, the
        validation BCE and the best validation MCC are collected, rounded to 4
        places, printed and appended to the record. Training stops once the
        epoch with the best ``valid_mcc`` is more than ``patience - 1`` epochs old.

        Returns:
            The list of per-epoch result dicts.
        """
        from itertools import count
        train_record = []
        for epoch in count(start=1):
            cm = contextlib.nullcontext() if self.verbose else DisableLogger()
            with cm:
                self.train(num_epoch=1)
                cur_result = self.evaluate()
                cur_result['train_bce'] = self.get_last_bce()
                cur_result['valid_bce'] = self.calculate_valid_loss()
                cur_result['valid_mcc'] = self.calculate_best_mcc_and_threshold(
                    threshold_set='valid'
                )['best_mcc']
                cur_result = round_dict(cur_result, 4)
                train_record.append(cur_result)
                print(cur_result)
                max_mcc_index = np.argmax([record['valid_mcc'] for record in train_record])
                if max_mcc_index < len(train_record) - patience:
                    break
        return train_record
        

    def get_last_bce(self):
        """Mean of the 'binary cross entropy' records logged during the last epoch.

        The epoch is delimited by the last two entries of ``meter.epoch2batch``.
        """
        from statistics import mean
        meter = self.solver.meter
        index = slice(meter.epoch2batch[-2], meter.epoch2batch[-1])
        bce_records = meter.records['binary cross entropy'][index]
        return mean(bce_records)
    
    def calculate_valid_loss(self):
        """Mean 'binary cross entropy' of the task over the validation set.

        Runs the task in eval mode, without gradients, one sample per batch.
        """
        from statistics import mean
        dataloader = data.DataLoader(self.valid_set, batch_size=1, shuffle=False)
        model = self.task

        model.eval()

        # Named `bce_values` so the imported torchdrug `metrics` module is not shadowed.
        bce_values = []
        with torch.no_grad():
            for batch in dataloader:
                batch = utils.cuda(batch, device=f'cuda:{self.gpus[0]}')
                loss, metric = model(batch)
                bce_values.append(metric['binary cross entropy'].item())
        
        return mean(bce_values)


    def calculate_best_mcc_and_threshold(self, threshold_set='valid'):
        """Scan decision thresholds and return the one giving the highest MCC.

        Predictions are gathered over the validation set when ``threshold_set``
        is ``'valid'`` and over the test set otherwise, then scored at 41 evenly
        spaced thresholds from -3 to 1.

        Returns:
            ``{'best_mcc': ..., 'best_threshold': ...}``; ties resolve to the
            lowest threshold (first maximum found by ``np.argmax``).
        """
        dataloader = data.DataLoader(
            self.valid_set if threshold_set == 'valid' else self.test_set,
            batch_size=1,
            shuffle=False
        )

        preds = []
        targets = []
        thresholds = np.linspace(-3, 1, num=41)
        # Switch the whole task to eval mode (as calculate_valid_loss does), not
        # just the backbone; the explicit model call is kept in case the task
        # does not register the model as a submodule.
        self.task.eval()
        self.model.eval()
        with torch.no_grad():
            for batch in dataloader:
                batch = utils.cuda(batch, device=torch.device(f'cuda:{self.gpus[0]}'))
                pred, target = self.task.predict_and_target(batch)
                preds.append(pred)
                targets.append(target)
        
        pred = utils.cat(preds)
        target = utils.cat(targets)

        mcc_values = [
            self.task.evaluate(pred, target, threshold)['mcc']
            for threshold in thresholds
        ]

        max_mcc_idx = np.argmax(mcc_values)
        
        return {
            'best_mcc': mcc_values[max_mcc_idx],
            'best_threshold': thresholds[max_mcc_idx]
        }


    def evaluate(self, threshold_set='valid', verbose=False):
        """Pick the best threshold on ``threshold_set`` and evaluate on the test split.

        The chosen threshold is stored on ``self.task.threshold`` before the
        engine's ``evaluate("test")`` is run; its metrics are returned after
        conversion by ``dict_tensor_to_num``.
        """
        mcc_and_threshold = self.calculate_best_mcc_and_threshold(threshold_set)
        if verbose:
            print(f'threshold: {mcc_and_threshold}\n')
        self.task.threshold = mcc_and_threshold['best_threshold']
        return dict_tensor_to_num(self.solver.evaluate("test"))


In [76]:
# Build the ensemble pipeline (LM-GearNet on atpbind3d) on the configured GPU.
pipeline = CustomPipeline(
    model='lm-gearnet',
    dataset='atpbind3d',
    gpus=[GPU],
    model_kwargs=dict(
        gpu=GPU,
        gearnet_hidden_dim_size=512,
        gearnet_hidden_dim_count=4,
        bert_freeze=False,
        bert_freeze_layer_count=29,
    ),
    optimizer_kwargs=dict(lr=5e-4),
    task_kwargs=dict(use_rus=True, rus_seed=0),
    bce_weight=1,
    batch_size=4,
)

# Last expression of the cell, so the returned metric dict is displayed as output.
pipeline.evaluate()

17:37:28   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
17:37:28   Evaluate on test
17:37:50   ------------------------------
17:37:50   accuracy: 0.966876
17:37:50   mcc: 0.630999
17:37:50   micro_auroc: 0.942937
17:37:50   precision: 0.726908
17:37:50   sensitivity: 0.577352
17:37:50   specificity: 0.988152


{'sensitivity': 0.5773524641990662,
 'specificity': 0.988152265548706,
 'accuracy': 0.9668759107589722,
 'precision': 0.7269076108932495,
 'mcc': 0.6309990197220169,
 'micro_auroc': 0.942936897277832}