In [4]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))
from lib.pipeline import Pipeline
from lib.disable_logger import DisableLogger
import torch

# Single GPU index, passed both to the Pipeline (`gpus`) and to the model (`gpu`).
GPU = 1

# One Pipeline reused by every cell below: a 'mean-ensemble' task over the
# 'lm-gearnet' model on the 'atpbind3d' dataset. `state_dict_files` starts out
# empty; each evaluation cell assigns its own checkpoint list to
# `pipeline.task.state_dict_files` before calling `pipeline.evaluate()`.
pipeline = Pipeline(
    model='lm-gearnet',
    dataset='atpbind3d',
    task='mean-ensemble',
    gpus=[GPU],
    model_kwargs=dict(
        gpu=GPU,
        gearnet_hidden_dim_size=512,
        gearnet_hidden_dim_count=4,
        bert_freeze=False,
        # NOTE(review): with bert_freeze=False this layer count may have no
        # effect — confirm how lib.pipeline / the model consume it.
        bert_freeze_layer_count=29,
    ),
    optimizer_kwargs=dict(lr=5e-4),
    task_kwargs=dict(state_dict_files=[]),
    bce_weight=1,
    batch_size=16,
)

In [5]:
# Checkpoint group "rus_5" (ten files). The names look like
# rus_<group>_<run>_<score>.pth — presumably `rus` = random under-sampling and
# <score> a validation metric; TODO confirm against the training script.
state_dict_files_5 = [
    'rus_5_0_0.6151.pth',
    'rus_5_1_0.6221.pth',
    'rus_5_2_0.6193.pth',
    'rus_5_3_0.6266.pth',
    'rus_5_4_0.6052.pth',
    'rus_5_5_0.6085.pth',
    'rus_5_6_0.5986.pth',
    'rus_5_7_0.6108.pth',
    'rus_5_8_0.6046.pth',
    'rus_5_9_0.6080.pth',
]

# Evaluate the mean ensemble of this group. The metrics are kept under a
# descriptive name (rather than print()-ed and discarded) so later cells can
# compare groups, and the bare last expression lets Jupyter render the dict.
# DisableLogger presumably silences logging during evaluation — see lib.disable_logger.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5
    metrics_5 = pipeline.evaluate()
metrics_5

{'sensitivity': 0.5725677609443665, 'specificity': 0.9893718957901001, 'accuracy': 0.967784583568573, 'precision': 0.7463617324829102, 'mcc': 0.637538591910511, 'micro_auroc': 0.9402536749839783}


In [4]:
# Checkpoint group "rus_10" (ten files). File names are kept verbatim: some
# score suffixes carry five decimals (e.g. 0.59580), others four.
state_dict_files_10 = [
    'rus_10_0_0.59580.pth',
    'rus_10_1_0.59290.pth',
    'rus_10_2_0.6494.pth',
    'rus_10_3_0.6114.pth',
    'rus_10_4_0.59780.pth',
    'rus_10_5_0.6173.pth',
    'rus_10_6_0.6152.pth',
    'rus_10_7_0.6091.pth',
    'rus_10_8_0.5989.pth',
    'rus_10_9_0.5892.pth',
]

# Evaluate the mean ensemble of this group. Bind the metrics so they can be
# compared with other groups, and display them as the cell's last expression.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_10
    metrics_10 = pipeline.evaluate()
metrics_10

{'sensitivity': 0.5693780183792114, 'specificity': 0.9895461201667786, 'accuracy': 0.967784583568573, 'precision': 0.74842768907547, 'mcc': 0.6366610392623603, 'micro_auroc': 0.9433506727218628}


In [5]:
# Checkpoint group "rus_15" (ten files).
state_dict_files_15 = [
    'rus_15_0_0.5936.pth',
    'rus_15_1_0.6274.pth',
    'rus_15_2_0.6004.pth',
    'rus_15_3_0.6200.pth',
    'rus_15_4_0.6215.pth',
    'rus_15_5_0.6085.pth',
    'rus_15_6_0.6096.pth',
    'rus_15_7_0.6098.pth',
    'rus_15_8_0.6186.pth',
    'rus_15_9_0.6188.pth',
]

# Evaluate the mean ensemble of this group. Bind the metrics so they can be
# compared with other groups, and display them as the cell's last expression.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_15
    metrics_15 = pipeline.evaluate()
metrics_15

{'sensitivity': 0.5534290075302124, 'specificity': 0.9899817109107971, 'accuracy': 0.9673715233802795, 'precision': 0.7510822415351868, 'mcc': 0.6285530296658867, 'micro_auroc': 0.9351375102996826}


In [9]:
# Mean ensemble over the rus_5 and rus_10 groups together (20 checkpoints).
with DisableLogger():
    pipeline.task.state_dict_files = [*state_dict_files_5, *state_dict_files_10]
    res = pipeline.evaluate()
res


{'sensitivity': 0.6028708219528198,
 'specificity': 0.988152265548706,
 'accuracy': 0.9681975841522217,
 'precision': 0.7354085445404053,
 'mcc': 0.6495752989607508,
 'micro_auroc': 0.9434273838996887}

In [10]:
# Mean ensemble over all three groups rus_5, rus_10 and rus_15 (30 checkpoints).
with DisableLogger():
    pipeline.task.state_dict_files = [
        *state_dict_files_5,
        *state_dict_files_10,
        *state_dict_files_15,
    ]
    res = pipeline.evaluate()
res

{'sensitivity': 0.5693780183792114,
 'specificity': 0.9905915260314941,
 'accuracy': 0.9687758088111877,
 'precision': 0.7677419185638428,
 'mcc': 0.6456966813369992,
 'micro_auroc': 0.9421370625495911}

In [3]:
# Extended "rus_5" group: the first ten entries are the same files as
# state_dict_files_5, followed by two additional runs (10 and 11), 12 in total.
# The list is kept literal so this cell does not depend on an earlier cell.
state_dict_files_5_new = [
    'rus_5_0_0.6151.pth',
    'rus_5_1_0.6221.pth',
    'rus_5_2_0.6193.pth',
    'rus_5_3_0.6266.pth',
    'rus_5_4_0.6052.pth',
    'rus_5_5_0.6085.pth',
    'rus_5_6_0.5986.pth',
    'rus_5_7_0.6108.pth',
    'rus_5_8_0.6046.pth',
    'rus_5_9_0.6080.pth',
    'rus_5_10_0.6089.pth',
    'rus_5_11_0.6174.pth',
]

# Evaluate the mean ensemble of the extended group. Bind the metrics so they can
# be compared with the 10-model rus_5 result, and display them as the last expression.
with DisableLogger():
    pipeline.task.state_dict_files = state_dict_files_5_new
    metrics_5_new = pipeline.evaluate()
metrics_5_new

{'sensitivity': 0.5917065143585205, 'specificity': 0.9880651831626892, 'accuracy': 0.9675367474555969, 'precision': 0.7303149700164795, 'mcc': 0.6407954270209901, 'micro_auroc': 0.9395251274108887}
