In [3]:
import sys
import os

sys.path.insert(0, os.path.abspath('..'))

import pandas as pd
from lib.pipeline import Pipeline
import torch
from torchdrug import utils, data
from lib.lr_scheduler import ExponentialLR

# CUDA device index used for the model and for mapping the checkpoint tensors.
GPU = 0

# Pretrained GearNet weights loaded into every fresh pipeline. The filename suggests
# 4 layers of width 512, matching the gearnet_* model_kwargs below — keep them in sync,
# or load_state_dict will reject the checkpoint.
GEARNET_CHECKPOINT = '../ResidueType_lmg_4_512_0.57268.pth'


def make_pipeline(noise_rate, gpu=None, checkpoint_path=GEARNET_CHECKPOINT):
    """Build an LM-GearNet pipeline on ``atpbind3d`` with residue-level RUS.

    The GearNet sub-module is initialised from ``checkpoint_path``. The solver's
    scheduler is replaced by an ``ExponentialLR`` with ``gamma = 0.5 ** (1 / 12)``,
    so the learning rate is halved every 12 scheduler steps.

    Args:
        noise_rate: Forwarded to ``Pipeline`` as ``rus_kwargs['rus_noise_rate']``.
        gpu: CUDA device index. ``None`` (default) reads the module-level ``GPU``
            at call time.
        checkpoint_path: Path of the state dict loaded into
            ``pipeline.model.gearnet``.

    Returns:
        The configured ``Pipeline`` with the pretrained GearNet weights and the
        exponential LR scheduler attached.
    """
    if gpu is None:
        # Resolve lazily so changing GPU in the config cell takes effect without
        # re-running this definition cell.
        gpu = GPU

    pipeline = Pipeline(
        model='lm-gearnet',
        dataset='atpbind3d',
        gpus=[gpu],
        model_kwargs={
            'gpu': gpu,
            'gearnet_hidden_dim_size': 512,
            'gearnet_hidden_dim_count': 4,
            'bert_freeze': False,
            # NOTE(review): presumably ignored while bert_freeze is False — confirm in lib.pipeline.
            'bert_freeze_layer_count': 28,
        },
        optimizer_kwargs={
            'lr': 4e-4,
        },
        rus_kwargs={
            'rus_seed': 0,
            'rus_rate': 0.05,
            'rus_by': 'residue',
            'rus_noise_rate': noise_rate,
        },
        batch_size=8,
        optimizer='adam',
    )

    # torch.load unpickles the file, which can execute arbitrary code: only load
    # checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location=f'cuda:{gpu}')
    pipeline.model.gearnet.load_state_dict(state_dict)

    # gamma = 0.5 ** (1/12) halves the LR every 12 scheduler steps (presumably one step
    # per epoch — confirm how the solver steps its scheduler).
    pipeline.solver.scheduler = ExponentialLR(
        gamma=0.5 ** (1 / 12),
        optimizer=pipeline.solver.optimizer,
    )
    return pipeline

In [10]:
import pandas as pd
# Noise rates swept in this experiment (passed as rus_noise_rate to make_pipeline).
NOISE_RATES = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
RESULTS_CSV = 'rus_noise_rate.csv'
# Early-stopping patience for train_until_fit; also selects which epoch is reported.
patience = 5

# Collect one plain dict per run and build the frame from the list, instead of
# pd.concat-ing onto a growing DataFrame (quadratic, and warns on the empty seed frame).
records = []
for noise_rate in NOISE_RATES:
    print('noise_rate:', noise_rate)
    pipeline = make_pipeline(noise_rate)
    train_record = pipeline.train_until_fit(patience=patience)
    # Report the epoch `patience` entries before the last one — presumably the best
    # epoch before early stopping triggered; confirm against Pipeline.train_until_fit.
    records.append({
        'noise_rate': noise_rate,
        **train_record[-1 - patience],
        'num_epoch': len(train_record),
    })
    # Rewrite the CSV after every run so partial results survive an interrupted sweep.
    df = pd.DataFrame(records)
    df.to_csv(RESULTS_CSV, index=False)

noise_rate: 0
Initialize RUS: seed 0, rate 0.05, by residue
train samples: 302, valid samples: 76, test samples: 41
{'sensitivity': 0.8549, 'specificity': 0.7894, 'accuracy': 0.7928, 'precision': 0.1815, 'mcc': 0.3325, 'micro_auroc': 0.9121, 'train_bce': 0.4754, 'valid_bce': 0.3986, 'valid_mcc': 0.393}
{'sensitivity': 0.5917, 'specificity': 0.962, 'accuracy': 0.9428, 'precision': 0.4597, 'mcc': 0.492, 'micro_auroc': 0.9126, 'train_bce': 0.2527, 'valid_bce': 0.1548, 'valid_mcc': 0.4392}
{'sensitivity': 0.5167, 'specificity': 0.9572, 'accuracy': 0.9344, 'precision': 0.3975, 'mcc': 0.4192, 'micro_auroc': 0.8885, 'train_bce': 0.1255, 'valid_bce': 0.2087, 'valid_mcc': 0.4017}
{'sensitivity': 0.8166, 'specificity': 0.8385, 'accuracy': 0.8374, 'precision': 0.2164, 'mcc': 0.3661, 'micro_auroc': 0.9113, 'train_bce': 0.0625, 'valid_bce': 0.4753, 'valid_mcc': 0.4044}
{'sensitivity': 0.764, 'specificity': 0.8994, 'accuracy': 0.8924, 'precision': 0.2931, 'mcc': 0.4302, 'micro_auroc': 0.9184, 'train