In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))
from lib.pipeline import Pipeline
from lib.disable_logger import DisableLogger
import torch
# CUDA device index, handed both to the Pipeline (`gpus`) and to the model (`gpu`).
GPU = 1

pipeline = Pipeline(
    model='lm-gearnet',
    dataset='atpbind3d',
    task='mean-ensemble',
    gpus=[GPU],
    model_kwargs=dict(
        gpu=GPU,
        gearnet_hidden_dim_size=512,
        gearnet_hidden_dim_count=4,
        bert_freeze=False,
        # NOTE(review): presumably ignored while bert_freeze is False — confirm in the model.
        bert_freeze_layer_count=29,
    ),
    # lr / bce_weight look like training settings; this notebook only calls evaluate().
    optimizer_kwargs=dict(lr=5e-4),
    # No checkpoints yet: every evaluation cell below assigns
    # `pipeline.task.state_dict_files` before calling `pipeline.evaluate()`.
    task_kwargs=dict(state_dict_files=[]),
    bce_weight=1,
    batch_size=24,
)

# Hand unoccupied cached blocks from PyTorch's CUDA caching allocator back to the driver.
torch.cuda.empty_cache()

get dataset atpbind3d
Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


In [3]:
# Ensemble of the 10 "rus_5" checkpoints, scored by the 'mean-ensemble' task configured
# in the setup cell, using the default threshold of `pipeline.evaluate()`.
# NOTE(review): the trailing number in each filename looks like a per-checkpoint score
# recorded when it was saved — confirm what metric it is.
state_dict_files_5 = [
    'rus_5_0_0.6151.pth',
    'rus_5_1_0.6221.pth',
    'rus_5_2_0.6193.pth',
    'rus_5_3_0.6266.pth',
    'rus_5_4_0.6052.pth',
    'rus_5_5_0.6085.pth',
    'rus_5_6_0.5986.pth',
    'rus_5_7_0.6108.pth',
    'rus_5_8_0.6046.pth',
    'rus_5_9_0.6080.pth',
]

# Point the existing pipeline's task at this ensemble and evaluate it inside
# DisableLogger() (presumably to silence logging while checkpoints load — confirm).
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5
    res = pipeline.evaluate()
res

{'sensitivity': 0.5422647595405579,
 'specificity': 0.9911141991615295,
 'accuracy': 0.9678671956062317,
 'precision': 0.7692307829856873,
 'mcc': 0.630212175762319,
 'micro_auroc': 0.9402540922164917}

In [2]:
# Same 10-checkpoint ensemble as `state_dict_files_5` (defined in the cell above); it is
# reused rather than re-typed so the two lists cannot drift apart. Here the decision
# threshold is chosen automatically (`threshold='auto'`), and `verbose=True` prints the
# chosen threshold.
# NOTE(review): the printed `best_mcc` is higher than the returned `mcc`, so the threshold
# is presumably tuned on a different split (validation?) than the one reported — confirm
# in `Pipeline.evaluate`. The negative threshold suggests it is applied to logits.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5
    res = pipeline.evaluate(threshold='auto', verbose=True)
res

threshold: {'best_mcc': 0.6556262597321062, 'best_threshold': -0.6999999999999997}



{'sensitivity': 0.5725677609443665,
 'specificity': 0.9893718957901001,
 'accuracy': 0.967784583568573,
 'precision': 0.7463617324829102,
 'mcc': 0.637538591910511,
 'micro_auroc': 0.9402540922164917}

In [7]:
# Ensemble of the 10 "rus_10" checkpoints, evaluated with the default threshold.
# NOTE(review): three names below ('rus_10_0_0.59580', 'rus_10_1_0.59290',
# 'rus_10_4_0.59780') carry five decimals where every other checkpoint has four —
# verify they match the filenames on disk.
state_dict_files_10 = [
    'rus_10_0_0.59580.pth',
    'rus_10_1_0.59290.pth',
    'rus_10_2_0.6494.pth',
    'rus_10_3_0.6114.pth',
    'rus_10_4_0.59780.pth',
    'rus_10_5_0.6173.pth',
    'rus_10_6_0.6152.pth',
    'rus_10_7_0.6091.pth',
    'rus_10_8_0.5989.pth',
    'rus_10_9_0.5892.pth',
]


# Swap in this ensemble and evaluate with DisableLogger() active.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_10
    res = pipeline.evaluate()

res

{'sensitivity': 0.5677831172943115,
 'specificity': 0.9899817109107971,
 'accuracy': 0.968114972114563,
 'precision': 0.7558386325836182,
 'mcc': 0.6392095483443275,
 'micro_auroc': 0.9433506727218628}

In [5]:
# Ensemble of the 10 "rus_15" checkpoints, evaluated with the default threshold.
state_dict_files_15 = [
    'rus_15_0_0.5936.pth',
    'rus_15_1_0.6274.pth',
    'rus_15_2_0.6004.pth',
    'rus_15_3_0.6200.pth',
    'rus_15_4_0.6215.pth',
    'rus_15_5_0.6085.pth',
    'rus_15_6_0.6096.pth',
    'rus_15_7_0.6098.pth',
    'rus_15_8_0.6186.pth',
    'rus_15_9_0.6188.pth',
]

# Swap in this ensemble and evaluate with DisableLogger() active.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_15
    res = pipeline.evaluate()

res

{'sensitivity': 0.4864433705806732,
 'specificity': 0.9930307269096375,
 'accuracy': 0.9667932987213135,
 'precision': 0.7922077775001526,
 'mcc': 0.6055336512444388,
 'micro_auroc': 0.9351373910903931}

In [9]:
# 20-member ensemble: the rus_5 and rus_10 checkpoint lists concatenated,
# evaluated with the default threshold.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5 + state_dict_files_10
    res = pipeline.evaluate()
res


{'sensitivity': 0.5566188097000122,
 'specificity': 0.9911141991615295,
 'accuracy': 0.9686105847358704,
 'precision': 0.7738358974456787,
 'mcc': 0.6409295839955366,
 'micro_auroc': 0.9434270858764648}

In [11]:
# Same rus_5 + rus_10 ensemble, with an automatically chosen threshold; verbose=True
# prints the chosen threshold and its `best_mcc`.
# NOTE(review): `best_mcc` exceeds the returned `mcc`, so the threshold is presumably
# selected on a different split than the one reported — confirm in `Pipeline.evaluate`.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5 + state_dict_files_10
    res = pipeline.evaluate(threshold='auto', verbose=True)
res


threshold: {'best_mcc': 0.6614176215124834, 'best_threshold': -0.8999999999999999}



{'sensitivity': 0.6028708219528198,
 'specificity': 0.988152265548706,
 'accuracy': 0.9681975841522217,
 'precision': 0.7354085445404053,
 'mcc': 0.6495752989607508,
 'micro_auroc': 0.9434270858764648}

In [12]:
# 30-member ensemble: rus_5 + rus_10 + rus_15 checkpoints, default threshold.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5 + state_dict_files_10 + state_dict_files_15
    res = pipeline.evaluate()
res

{'sensitivity': 0.5311004519462585,
 'specificity': 0.9919853806495667,
 'accuracy': 0.968114972114563,
 'precision': 0.7835294008255005,
 'mcc': 0.6298313699130325,
 'micro_auroc': 0.9421367645263672}

In [13]:
# Same 30-member ensemble with an automatically chosen threshold (printed via verbose=True).
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5 + state_dict_files_10 + state_dict_files_15
    res = pipeline.evaluate(threshold='auto', verbose=True)
res

threshold: {'best_mcc': 0.6623866009432018, 'best_threshold': -0.6999999999999997}



{'sensitivity': 0.5693780183792114,
 'specificity': 0.9905915260314941,
 'accuracy': 0.9687758088111877,
 'precision': 0.7677419185638428,
 'mcc': 0.6456966813369992,
 'micro_auroc': 0.9421366453170776}

In [14]:
# 20-member rus_5 ensemble: the 10 checkpoints of `state_dict_files_5` plus 10 more
# (rus_5_10 .. rus_5_19). Built from `state_dict_files_5` instead of re-typing the first
# 10 names, so the shared prefix cannot silently diverge.
state_dict_files_5_new = state_dict_files_5 + [
    'rus_5_10_0.6089.pth',
    'rus_5_11_0.6174.pth',
    'rus_5_12_0.6239.pth',
    'rus_5_13_0.6337.pth',
    'rus_5_14_0.6144.pth',
    'rus_5_15_0.5990.pth',
    'rus_5_16_0.5824.pth',
    'rus_5_17_0.6237.pth',
    'rus_5_18_0.6112.pth',
    'rus_5_19_0.6090.pth',
]
# A checkpoint listed twice would silently get double weight in the ensemble.
assert len(set(state_dict_files_5_new)) == len(state_dict_files_5_new), 'duplicate checkpoint'

# Evaluate with the default threshold.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5_new
    res = pipeline.evaluate()

res

{'sensitivity': 0.5502392053604126,
 'specificity': 0.9917240142822266,
 'accuracy': 0.9688584208488464,
 'precision': 0.7840909361839294,
 'mcc': 0.6417536010596256,
 'micro_auroc': 0.9407066702842712}

In [3]:
# Redefines `state_dict_files_5_new`. The list is identical to the earlier definition
# except for member 15, which now points at 'rus_5_15_0.6066.pth' instead of
# 'rus_5_15_0.5990.pth' (presumably a retrained checkpoint — confirm). The list is built
# from `state_dict_files_5` so that only the 10 additional names are spelled out here.
# NOTE(review): because the same name is rebound, every later cell that reads
# `state_dict_files_5_new` uses THIS list on Restart & Run All. Saved outputs of cells
# that appear to have run before this redefinition (the first
# `state_dict_files_5_new + state_dict_files_10` evaluations) will not reproduce exactly.
# Consider giving this list its own name.
state_dict_files_5_new = state_dict_files_5 + [
    'rus_5_10_0.6089.pth',
    'rus_5_11_0.6174.pth',
    'rus_5_12_0.6239.pth',
    'rus_5_13_0.6337.pth',
    'rus_5_14_0.6144.pth',
    'rus_5_15_0.6066.pth',
    'rus_5_16_0.5824.pth',
    'rus_5_17_0.6237.pth',
    'rus_5_18_0.6112.pth',
    'rus_5_19_0.6090.pth',
]
# A checkpoint listed twice would silently get double weight in the ensemble.
assert len(set(state_dict_files_5_new)) == len(state_dict_files_5_new), 'duplicate checkpoint'

# Evaluate with an automatically chosen threshold (printed via verbose=True).
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5_new
    res = pipeline.evaluate(threshold='auto', verbose=True)

res

threshold: {'best_mcc': 0.6553472253500954, 'best_threshold': -0.3999999999999999}



{'sensitivity': 0.5614035129547119,
 'specificity': 0.9902430772781372,
 'accuracy': 0.9680323600769043,
 'precision': 0.7586206793785095,
 'mcc': 0.6367572903013485,
 'micro_auroc': 0.9396817088127136}

In [5]:
# Default-threshold evaluation of whichever `state_dict_files_5_new` was bound most
# recently (top to bottom, that is the redefinition with 'rus_5_15_0.6066.pth').
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5_new
    res = pipeline.evaluate()
res

{'sensitivity': 0.5454545617103577,
 'specificity': 0.9916369318962097,
 'accuracy': 0.9685280323028564,
 'precision': 0.7808219194412231,
 'mcc': 0.6373804959992757,
 'micro_auroc': 0.9396817088127136}

In [16]:
# 30-member ensemble: `state_dict_files_5_new` + rus_10 checkpoints, default threshold.
# NOTE(review): the saved output below appears to come from the earlier
# `state_dict_files_5_new` (member 15 = 'rus_5_15_0.5990.pth'). It was executed before the
# list was redefined, so Restart & Run All will give slightly different numbers.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5_new + state_dict_files_10
    res = pipeline.evaluate()

res

{'sensitivity': 0.5502392053604126,
 'specificity': 0.9911141991615295,
 'accuracy': 0.9682801961898804,
 'precision': 0.7718120813369751,
 'mcc': 0.6361833356155675,
 'micro_auroc': 0.9427580833435059}

In [17]:
# Same 30-member ensemble with an automatically chosen threshold.
# NOTE(review): this cell's code is byte-identical to the next evaluation cell, yet their
# saved outputs differ (micro_auroc 0.94276 vs 0.94216). This one appears to have run
# before `state_dict_files_5_new` was redefined, so its output is stale. On Restart &
# Run All both cells produce the same result — consider deleting one.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5_new + state_dict_files_10
    res = pipeline.evaluate(threshold='auto', verbose=True)

res

threshold: {'best_mcc': 0.6592559265163181, 'best_threshold': -0.8999999999999999}



{'sensitivity': 0.59968101978302,
 'specificity': 0.9882394075393677,
 'accuracy': 0.968114972114563,
 'precision': 0.7358121275901794,
 'mcc': 0.6479753908165176,
 'micro_auroc': 0.9427580833435059}

In [8]:
# `state_dict_files_5_new` + rus_10 ensemble with an automatically chosen threshold.
# Its saved output is consistent with the redefined `state_dict_files_5_new`
# (member 15 = 'rus_5_15_0.6066.pth'), i.e. what Restart & Run All reproduces.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5_new + state_dict_files_10
    res = pipeline.evaluate(threshold='auto', verbose=True)

res

threshold: {'best_mcc': 0.6574639477195364, 'best_threshold': -0.7999999999999998}



{'sensitivity': 0.5885167717933655,
 'specificity': 0.9887620806694031,
 'accuracy': 0.9680323600769043,
 'precision': 0.740963876247406,
 'mcc': 0.6441367257609167,
 'micro_auroc': 0.9421606659889221}