In [1]:
import os
import sys
import warnings

import pandas as pd
import torch

from lib.pipeline import Pipeline


GPU = 1  # CUDA device index: used for training and for mapping checkpoint tensors

# LM-GearNet model on the `atpbind3d` dataset, optimised with "adamw".
pipeline = Pipeline(
    model='lm-gearnet',
    dataset='atpbind3d',
    gpus=[GPU],
    model_kwargs=dict(
        gpu=GPU,
        gearnet_hidden_dim_size=512,
        gearnet_hidden_dim_count=4,
        bert_freeze=False,
        # NOTE(review): may be ignored while bert_freeze is False — confirm in lib.pipeline.
        bert_freeze_layer_count=28,
    ),
    optimizer_kwargs=dict(lr=2e-4, weight_decay=0.0001),
    # use_rus / undersample_rate presumably control random undersampling of the
    # training data — TODO confirm against lib.pipeline.
    task_kwargs=dict(use_rus=True, rus_seed=0, undersample_rate=0.05),
    batch_size=8,
    optimizer="adamw",
)

# Warm-start the GearNet sub-module from a saved state dict, placing its
# tensors on the configured GPU. The cell displays load_state_dict's result.
state_dict = torch.load(
    'ResidueType_lmg_4_512_0.57268.pth',
    map_location=f'cuda:{GPU}',
)
pipeline.model.gearnet.load_state_dict(state_dict)



get dataset atpbind3d
Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


<All keys matched successfully>

In [4]:
class ExponentialLR(torch.optim.lr_scheduler._LRScheduler):
    """Decay the learning rate of each parameter group by ``gamma`` every step.

    Mirrors ``torch.optim.lr_scheduler.ExponentialLR``. Absent other changes to
    the optimizer's learning rates, after ``t`` scheduler steps each group's
    learning rate is ``base_lr * gamma ** t``. When ``last_epoch == -1`` the
    optimizer's current ``lr`` is used as the initial learning rate.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning-rate decay per step.
            ``0.5 ** (1 / k)`` halves the learning rate every ``k`` steps.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, forwarded to the base scheduler, which
            prints a message to stdout for each update. Default: ``False``.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
        self.gamma = gamma
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            # Recent PyTorch releases deprecate ``verbose`` and warn whenever it
            # is passed explicitly (even as False), so only forward it when set.
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the per-group learning rates for the current step.

        This is the chainable form: it multiplies each group's *current* lr by
        ``gamma``. Calling it directly (outside ``step()``) emits a
        ``UserWarning`` pointing at ``get_last_lr()``.
        """
        # Requires the module-level ``import warnings``; without it a direct
        # get_lr() call would raise NameError instead of warning.
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            # Initial step performed by the base class: keep the starting lr.
            return [group['lr'] for group in self.optimizer.param_groups]
        return [group['lr'] * self.gamma
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Return ``base_lr * gamma ** last_epoch`` for every parameter group."""
        return [base_lr * self.gamma ** self.last_epoch
                for base_lr in self.base_lrs]



In [None]:
# Install an exponential LR decay on the solver: 0.925**9 ≈ 0.50, so the
# learning rate roughly halves every 9 scheduler steps.
pipeline.solver.scheduler = ExponentialLR(pipeline.solver.optimizer, gamma=0.925)

In [9]:
# `patience` presumably = evaluations without validation improvement before
# stopping — TODO confirm in lib.pipeline. The returned metrics list is displayed.
pipeline.train_until_fit(patience=20)

{'sensitivity': 0.8915, 'specificity': 0.7183, 'accuracy': 0.7272, 'precision': 0.1474, 'mcc': 0.2913, 'micro_auroc': 0.9078, 'train_bce': 0.0419, 'valid_bce': 0.0314, 'valid_mcc': 0.3499}
{'sensitivity': 0.8453, 'specificity': 0.8469, 'accuracy': 0.8469, 'precision': 0.2317, 'mcc': 0.3919, 'micro_auroc': 0.9262, 'train_bce': 0.0263, 'valid_bce': 0.029, 'valid_mcc': 0.413}
{'sensitivity': 0.7002, 'specificity': 0.9172, 'accuracy': 0.906, 'precision': 0.3161, 'mcc': 0.4293, 'micro_auroc': 0.9217, 'train_bce': 0.0209, 'valid_bce': 0.0291, 'valid_mcc': 0.5106}
{'sensitivity': 0.8006, 'specificity': 0.886, 'accuracy': 0.8815, 'precision': 0.2772, 'mcc': 0.4266, 'micro_auroc': 0.9195, 'train_bce': 0.0157, 'valid_bce': 0.0255, 'valid_mcc': 0.473}
{'sensitivity': 0.5439, 'specificity': 0.9788, 'accuracy': 0.9563, 'precision': 0.5839, 'mcc': 0.5406, 'micro_auroc': 0.9136, 'train_bce': 0.0112, 'valid_bce': 0.0513, 'valid_mcc': 0.55}
{'sensitivity': 0.5805, 'specificity': 0.9575, 'accuracy': 0.9

KeyboardInterrupt: 

In [5]:
def make_pipeline(lr, gamma, gpu=None,
                  checkpoint_path='ResidueType_lmg_4_512_0.57268.pth',
                  seed=None):
    """Build an LM-GearNet pipeline on ``atpbind3d`` with exponential LR decay.

    The GearNet sub-module is warm-started from ``checkpoint_path``, and the
    solver's scheduler is replaced with ``ExponentialLR(gamma=gamma)``.

    Args:
        lr (float): Initial learning rate passed to the "adamw" optimizer.
        gamma (float): Per-step multiplicative LR decay. ``0.5 ** (1 / k)``
            halves the learning rate after ``k`` scheduler steps.
        gpu (int, optional): CUDA device index. ``None`` (default) uses the
            notebook-level ``GPU`` constant, read at call time.
        checkpoint_path (str): File holding the GearNet state dict to load.
        seed (int, optional): If given, ``torch.manual_seed(seed)`` is called
            before the pipeline is built so that torch-side randomness starts
            from a fixed state. GPU kernels may still be non-deterministic.

    Returns:
        Pipeline: the configured pipeline, ready for ``train_until_fit``.
    """
    if gpu is None:
        gpu = GPU
    if seed is not None:
        torch.manual_seed(seed)

    pipeline = Pipeline(
        model='lm-gearnet',
        dataset='atpbind3d',
        gpus=[gpu],
        model_kwargs={
            'gpu': gpu,
            'gearnet_hidden_dim_size': 512,
            'gearnet_hidden_dim_count': 4,
            'bert_freeze': False,
            'bert_freeze_layer_count': 28,
        },
        optimizer_kwargs={
            'lr': lr,
            'weight_decay': 0.0001,
        },
        task_kwargs={
            'use_rus': True,
            'rus_seed': 0,
            'undersample_rate': 0.05,
        },
        batch_size=8,
        optimizer="adamw",
    )

    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (torch >= 1.13 supports weights_only=True for state dicts).
    state_dict = torch.load(checkpoint_path, map_location=f'cuda:{gpu}')
    pipeline.model.gearnet.load_state_dict(state_dict)

    pipeline.solver.scheduler = ExponentialLR(
        gamma=gamma,
        optimizer=pipeline.solver.optimizer,
    )

    return pipeline

pipeline = make_pipeline(gamma=0.917, lr=5e-4)  # 0.917**8 ≈ 0.50: lr halves every ~8 scheduler steps


In [11]:
# Train the current pipeline (patience 10); the returned metrics list is displayed.
pipeline.train_until_fit(patience=10)

{'sensitivity': 0.8501, 'specificity': 0.8051, 'accuracy': 0.8075, 'precision': 0.1924, 'mcc': 0.3457, 'micro_auroc': 0.9061, 'train_bce': 0.0439, 'valid_bce': 0.0274, 'valid_mcc': 0.3868}
{'sensitivity': 0.4577, 'specificity': 0.9793, 'accuracy': 0.9523, 'precision': 0.5467, 'mcc': 0.4755, 'micro_auroc': 0.8979, 'train_bce': 0.0289, 'valid_bce': 0.0505, 'valid_mcc': 0.5193}
{'sensitivity': 0.8724, 'specificity': 0.8082, 'accuracy': 0.8115, 'precision': 0.199, 'mcc': 0.36, 'micro_auroc': 0.9253, 'train_bce': 0.0227, 'valid_bce': 0.0281, 'valid_mcc': 0.4118}
{'sensitivity': 0.6746, 'specificity': 0.9449, 'accuracy': 0.9309, 'precision': 0.4006, 'mcc': 0.4865, 'micro_auroc': 0.9301, 'train_bce': 0.0209, 'valid_bce': 0.0295, 'valid_mcc': 0.4875}
{'sensitivity': 0.7097, 'specificity': 0.9402, 'accuracy': 0.9282, 'precision': 0.3931, 'mcc': 0.4947, 'micro_auroc': 0.9243, 'train_bce': 0.0138, 'valid_bce': 0.0386, 'valid_mcc': 0.5455}
{'sensitivity': 0.5981, 'specificity': 0.9684, 'accuracy':

[{'sensitivity': 0.8501,
  'specificity': 0.8051,
  'accuracy': 0.8075,
  'precision': 0.1924,
  'mcc': 0.3457,
  'micro_auroc': 0.9061,
  'train_bce': 0.0439,
  'valid_bce': 0.0274,
  'valid_mcc': 0.3868},
 {'sensitivity': 0.4577,
  'specificity': 0.9793,
  'accuracy': 0.9523,
  'precision': 0.5467,
  'mcc': 0.4755,
  'micro_auroc': 0.8979,
  'train_bce': 0.0289,
  'valid_bce': 0.0505,
  'valid_mcc': 0.5193},
 {'sensitivity': 0.8724,
  'specificity': 0.8082,
  'accuracy': 0.8115,
  'precision': 0.199,
  'mcc': 0.36,
  'micro_auroc': 0.9253,
  'train_bce': 0.0227,
  'valid_bce': 0.0281,
  'valid_mcc': 0.4118},
 {'sensitivity': 0.6746,
  'specificity': 0.9449,
  'accuracy': 0.9309,
  'precision': 0.4006,
  'mcc': 0.4865,
  'micro_auroc': 0.9301,
  'train_bce': 0.0209,
  'valid_bce': 0.0295,
  'valid_mcc': 0.4875},
 {'sensitivity': 0.7097,
  'specificity': 0.9402,
  'accuracy': 0.9282,
  'precision': 0.3931,
  'mcc': 0.4947,
  'micro_auroc': 0.9243,
  'train_bce': 0.0138,
  'valid_bce': 

In [6]:
# Train the current pipeline (patience 10); the returned metrics list is displayed.
pipeline.train_until_fit(patience=10)

{'sensitivity': 0.874, 'specificity': 0.7258, 'accuracy': 0.7335, 'precision': 0.1483, 'mcc': 0.2887, 'micro_auroc': 0.8953, 'train_bce': 0.0457, 'valid_bce': 0.0315, 'valid_mcc': 0.3598}
{'sensitivity': 0.4833, 'specificity': 0.975, 'accuracy': 0.9495, 'precision': 0.5136, 'mcc': 0.4716, 'micro_auroc': 0.9103, 'train_bce': 0.0305, 'valid_bce': 0.0477, 'valid_mcc': 0.4958}
{'sensitivity': 0.7911, 'specificity': 0.885, 'accuracy': 0.8801, 'precision': 0.2731, 'mcc': 0.4196, 'micro_auroc': 0.9208, 'train_bce': 0.0221, 'valid_bce': 0.027, 'valid_mcc': 0.4661}
{'sensitivity': 0.5104, 'specificity': 0.9647, 'accuracy': 0.9412, 'precision': 0.4414, 'mcc': 0.4437, 'micro_auroc': 0.8925, 'train_bce': 0.0215, 'valid_bce': 0.0531, 'valid_mcc': 0.5022}
{'sensitivity': 0.7018, 'specificity': 0.956, 'accuracy': 0.9428, 'precision': 0.4656, 'mcc': 0.5434, 'micro_auroc': 0.929, 'train_bce': 0.0131, 'valid_bce': 0.0374, 'valid_mcc': 0.5415}
{'sensitivity': 0.4833, 'specificity': 0.9818, 'accuracy': 0.

[{'sensitivity': 0.874,
  'specificity': 0.7258,
  'accuracy': 0.7335,
  'precision': 0.1483,
  'mcc': 0.2887,
  'micro_auroc': 0.8953,
  'train_bce': 0.0457,
  'valid_bce': 0.0315,
  'valid_mcc': 0.3598},
 {'sensitivity': 0.4833,
  'specificity': 0.975,
  'accuracy': 0.9495,
  'precision': 0.5136,
  'mcc': 0.4716,
  'micro_auroc': 0.9103,
  'train_bce': 0.0305,
  'valid_bce': 0.0477,
  'valid_mcc': 0.4958},
 {'sensitivity': 0.7911,
  'specificity': 0.885,
  'accuracy': 0.8801,
  'precision': 0.2731,
  'mcc': 0.4196,
  'micro_auroc': 0.9208,
  'train_bce': 0.0221,
  'valid_bce': 0.027,
  'valid_mcc': 0.4661},
 {'sensitivity': 0.5104,
  'specificity': 0.9647,
  'accuracy': 0.9412,
  'precision': 0.4414,
  'mcc': 0.4437,
  'micro_auroc': 0.8925,
  'train_bce': 0.0215,
  'valid_bce': 0.0531,
  'valid_mcc': 0.5022},
 {'sensitivity': 0.7018,
  'specificity': 0.956,
  'accuracy': 0.9428,
  'precision': 0.4656,
  'mcc': 0.5434,
  'micro_auroc': 0.929,
  'train_bce': 0.0131,
  'valid_bce': 0.0

In [7]:
# 0.906**7 ≈ 0.50: the learning rate roughly halves every 7 scheduler steps.
pipeline = make_pipeline(gamma=0.906, lr=1e-3)
pipeline.train_until_fit(patience=10)

{'sensitivity': 0.6778, 'specificity': 0.8371, 'accuracy': 0.8288, 'precision': 0.1852, 'mcc': 0.2911, 'micro_auroc': 0.8432, 'train_bce': 0.0506, 'valid_bce': 0.0372, 'valid_mcc': 0.3244}
{'sensitivity': 0.2759, 'specificity': 0.9906, 'accuracy': 0.9536, 'precision': 0.6157, 'mcc': 0.3922, 'micro_auroc': 0.8711, 'train_bce': 0.0334, 'valid_bce': 0.0754, 'valid_mcc': 0.4393}
{'sensitivity': 0.5407, 'specificity': 0.9598, 'accuracy': 0.938, 'precision': 0.4232, 'mcc': 0.4461, 'micro_auroc': 0.8952, 'train_bce': 0.0285, 'valid_bce': 0.0525, 'valid_mcc': 0.4696}
{'sensitivity': 0.866, 'specificity': 0.7508, 'accuracy': 0.7567, 'precision': 0.1595, 'mcc': 0.304, 'micro_auroc': 0.8964, 'train_bce': 0.0269, 'valid_bce': 0.0281, 'valid_mcc': 0.3747}
{'sensitivity': 0.6013, 'specificity': 0.9565, 'accuracy': 0.9381, 'precision': 0.4304, 'mcc': 0.4771, 'micro_auroc': 0.9143, 'train_bce': 0.0182, 'valid_bce': 0.045, 'valid_mcc': 0.5039}
{'sensitivity': 0.7384, 'specificity': 0.9092, 'accuracy': 

[{'sensitivity': 0.6778,
  'specificity': 0.8371,
  'accuracy': 0.8288,
  'precision': 0.1852,
  'mcc': 0.2911,
  'micro_auroc': 0.8432,
  'train_bce': 0.0506,
  'valid_bce': 0.0372,
  'valid_mcc': 0.3244},
 {'sensitivity': 0.2759,
  'specificity': 0.9906,
  'accuracy': 0.9536,
  'precision': 0.6157,
  'mcc': 0.3922,
  'micro_auroc': 0.8711,
  'train_bce': 0.0334,
  'valid_bce': 0.0754,
  'valid_mcc': 0.4393},
 {'sensitivity': 0.5407,
  'specificity': 0.9598,
  'accuracy': 0.938,
  'precision': 0.4232,
  'mcc': 0.4461,
  'micro_auroc': 0.8952,
  'train_bce': 0.0285,
  'valid_bce': 0.0525,
  'valid_mcc': 0.4696},
 {'sensitivity': 0.866,
  'specificity': 0.7508,
  'accuracy': 0.7567,
  'precision': 0.1595,
  'mcc': 0.304,
  'micro_auroc': 0.8964,
  'train_bce': 0.0269,
  'valid_bce': 0.0281,
  'valid_mcc': 0.3747},
 {'sensitivity': 0.6013,
  'specificity': 0.9565,
  'accuracy': 0.9381,
  'precision': 0.4304,
  'mcc': 0.4771,
  'micro_auroc': 0.9143,
  'train_bce': 0.0182,
  'valid_bce': 

In [8]:
# 0.926**9 ≈ 0.50: the learning rate roughly halves every 9 scheduler steps.
pipeline = make_pipeline(gamma=0.926, lr=5e-4)
pipeline.train_until_fit(patience=10)

{'sensitivity': 0.7464, 'specificity': 0.8268, 'accuracy': 0.8226, 'precision': 0.1906, 'mcc': 0.3159, 'micro_auroc': 0.8799, 'train_bce': 0.0443, 'valid_bce': 0.0314, 'valid_mcc': 0.3519}
{'sensitivity': 0.5359, 'specificity': 0.9593, 'accuracy': 0.9374, 'precision': 0.4184, 'mcc': 0.441, 'micro_auroc': 0.898, 'train_bce': 0.0301, 'valid_bce': 0.0402, 'valid_mcc': 0.4548}
{'sensitivity': 0.8357, 'specificity': 0.8505, 'accuracy': 0.8497, 'precision': 0.2339, 'mcc': 0.3916, 'micro_auroc': 0.9248, 'train_bce': 0.0217, 'valid_bce': 0.0277, 'valid_mcc': 0.4373}
{'sensitivity': 0.638, 'specificity': 0.9376, 'accuracy': 0.9221, 'precision': 0.3584, 'mcc': 0.4409, 'micro_auroc': 0.9115, 'train_bce': 0.02, 'valid_bce': 0.0357, 'valid_mcc': 0.527}
{'sensitivity': 0.638, 'specificity': 0.9602, 'accuracy': 0.9435, 'precision': 0.4667, 'mcc': 0.5168, 'micro_auroc': 0.9238, 'train_bce': 0.0132, 'valid_bce': 0.0404, 'valid_mcc': 0.5781}
{'sensitivity': 0.5949, 'specificity': 0.9671, 'accuracy': 0.9

[{'sensitivity': 0.7464,
  'specificity': 0.8268,
  'accuracy': 0.8226,
  'precision': 0.1906,
  'mcc': 0.3159,
  'micro_auroc': 0.8799,
  'train_bce': 0.0443,
  'valid_bce': 0.0314,
  'valid_mcc': 0.3519},
 {'sensitivity': 0.5359,
  'specificity': 0.9593,
  'accuracy': 0.9374,
  'precision': 0.4184,
  'mcc': 0.441,
  'micro_auroc': 0.898,
  'train_bce': 0.0301,
  'valid_bce': 0.0402,
  'valid_mcc': 0.4548},
 {'sensitivity': 0.8357,
  'specificity': 0.8505,
  'accuracy': 0.8497,
  'precision': 0.2339,
  'mcc': 0.3916,
  'micro_auroc': 0.9248,
  'train_bce': 0.0217,
  'valid_bce': 0.0277,
  'valid_mcc': 0.4373},
 {'sensitivity': 0.638,
  'specificity': 0.9376,
  'accuracy': 0.9221,
  'precision': 0.3584,
  'mcc': 0.4409,
  'micro_auroc': 0.9115,
  'train_bce': 0.02,
  'valid_bce': 0.0357,
  'valid_mcc': 0.527},
 {'sensitivity': 0.638,
  'specificity': 0.9602,
  'accuracy': 0.9435,
  'precision': 0.4667,
  'mcc': 0.5168,
  'micro_auroc': 0.9238,
  'train_bce': 0.0132,
  'valid_bce': 0.04