In [1]:
import os

import pandas as pd
import torch
from torchdrug import utils, data

from lib.pipeline import Pipeline

# CUDA device index used both in the pipeline configuration and for the
# `cuda:{GPU}` batch transfers in the functions below.
GPU = 1

def create_single_pred_dataframe(pipeline, dataset):
    """Collect per-residue targets and predictions for every protein in ``dataset``.

    The dataset is iterated one protein at a time (batch size 1, no shuffling),
    so ``protein_index`` is the position of the protein in the iteration order.

    Args:
        pipeline: Object exposing ``task`` with ``eval()``, ``target(batch)``
            (indexable by ``'label'``) and ``predict(batch)``.
        dataset: Dataset accepted by ``torchdrug.data.DataLoader``.

    Returns:
        pandas.DataFrame with columns ``protein_index``, ``residue_index``,
        ``target`` and ``pred`` (predictions rounded to 5 decimals). The row
        index restarts at 0 for each protein. An empty DataFrame is returned
        when the dataset yields no batches.

    Raises:
        AssertionError: If the flattened label and prediction lengths differ.
    """
    pipeline.task.eval()
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)

    # Collect one frame per protein and concatenate once at the end; calling
    # pd.concat inside the loop re-copies all previous rows on every iteration.
    frames = []

    # Inference only: no_grad avoids building autograd graphs for each batch.
    with torch.no_grad():
        for protein_index, batch in enumerate(loader):
            batch = utils.cuda(batch, device=f'cuda:{GPU}')
            label = pipeline.task.target(batch)['label'].flatten()
            pred = pipeline.task.predict(batch).flatten()
            assert len(label) == len(pred)

            frames.append(pd.DataFrame({
                'protein_index': protein_index,
                'residue_index': list(range(len(label))),
                'target': label.tolist(),
                'pred': [round(value, 5) for value in pred.tolist()],
            }))

    if not frames:
        # pd.concat raises on an empty list; keep the "empty frame" result.
        return pd.DataFrame()
    return pd.concat(frames)

def adaboost_iter(iter_num, masks=None, drop_fraction=0.1, pred_dir='preds'):
    """Run one boosting-style round: train, update negative masks, save predictions.

    Steps:
      1. Build a fresh ``Pipeline`` and train it with ``train_until_fit``.
      2. Reload the state dict returned by training and switch to eval mode.
      3. Score every training residue. Among residues whose label is 0 and
         whose mask entry is still truthy, set the mask to ``False`` for the
         ``drop_fraction`` with the lowest predicted scores.
      4. Write validation and test predictions to CSV files in ``pred_dir``.

    NOTE(review): in this function ``masks`` is only read and updated; it is
    never handed to the pipeline, so it does not influence training here.
    Confirm whether it is meant to be applied to the training set.

    Args:
        iter_num: Round number, used (zero-padded to 2 digits) in the CSV names.
        masks: Per-protein boolean masks indexed as ``masks[protein][residue]``.
            When falsy (``None`` or empty), all-True masks are created from
            each training sample's ``num_residue``. A supplied list is
            modified in place.
        drop_fraction: Fraction (0..1) of the currently unmasked negative
            residues to mask out in this round.
        pred_dir: Directory for the prediction CSVs; created if missing.

    Returns:
        The updated ``masks`` list.

    Raises:
        ValueError: If ``drop_fraction`` is outside ``[0, 1]``.
        AssertionError: If label and prediction lengths differ for a protein.
    """
    if not 0 <= drop_fraction <= 1:
        raise ValueError(f'drop_fraction must be in [0, 1], got {drop_fraction}')

    # Create the output directory before training so a missing folder cannot
    # throw away a finished training run at the final to_csv step.
    os.makedirs(pred_dir, exist_ok=True)

    # initialize new pipeline
    print('Initializing new pipeline')
    model_kwargs = {
        'gpu': GPU,
        'lm_type': 'esm-t33',
        'gearnet_hidden_dim_size': 512,
        'gearnet_hidden_dim_count': 4,
    }
    pipeline = Pipeline(
        model='lm-gearnet',
        dataset='atpbind3d',
        gpus=[GPU],
        model_kwargs=model_kwargs,
        optimizer_kwargs={
            'lr': 1e-3,
        },
        batch_size=4,
    )
    pipeline.model.freeze_lm(freeze_all=False, freeze_layer_count=30)

    print('Training..')
    # The training record is not needed here; only the state dict is used.
    _, state_dict = pipeline.train_until_fit(
        patience=5, return_state_dict=True, use_dynamic_threshold=False
    )

    print('Train Done')
    # batch_size=1 and shuffle=False keep the enumerate() index aligned with
    # the iteration order of pipeline.train_set used to build the masks.
    train_dataloader = data.DataLoader(pipeline.train_set, batch_size=1, shuffle=False)

    # load the best model
    print('Loading best model')
    pipeline.task.load_state_dict(state_dict)
    pipeline.task.eval()

    # Get the prediction of all residues with negative labels
    print('Getting prediciton for negative labels')
    if not masks:
        masks = [
            torch.ones(train_data['graph'].num_residue.item()).bool()
            for train_data in pipeline.train_set
        ]

    # Each entry is (pred, protein_index, residue_index).
    negatives = []
    with torch.no_grad():  # inference only
        for protein_index, batch in enumerate(train_dataloader):
            batch = utils.cuda(batch, device=f'cuda:{GPU}')
            # Convert to Python lists once per protein instead of reading
            # device tensors element by element inside the residue loop.
            labels = pipeline.task.target(batch)['label'].flatten().tolist()
            preds = pipeline.task.predict(batch).flatten().tolist()
            assert len(labels) == len(preds)

            protein_mask = masks[protein_index]
            for residue_index, (label, pred) in enumerate(zip(labels, preds)):
                if label == 0 and protein_mask[residue_index]:
                    negatives.append((pred, protein_index, residue_index))

    # Ascending by prediction only (stable sort keeps encounter order on ties):
    # the lowest-scored negatives come first and are the ones masked out.
    negatives.sort(key=lambda entry: entry[0])
    num_to_drop = int(len(negatives) * drop_fraction)
    for _, protein_index, residue_index in negatives[:num_to_drop]:
        masks[protein_index][residue_index] = False

    # save prediction of current round
    print('Saving prediction')
    df_valid = create_single_pred_dataframe(pipeline, pipeline.valid_set)
    df_valid.to_csv(
        os.path.join(pred_dir, f'adaboost_{iter_num:02d}_valid.csv'), index=False
    )

    df_test = create_single_pred_dataframe(pipeline, pipeline.test_set)
    df_test.to_csv(
        os.path.join(pred_dir, f'adaboost_{iter_num:02d}_test.csv'), index=False
    )

    return masks
    
    

In [2]:
# Run boosting round 1; with masks=None, adaboost_iter starts from all-True
# per-protein masks and returns them after updating.
masks = adaboost_iter(iter_num=1, masks=None)

Initializing new pipeline
load model lm-gearnet, kwargs: {'gpu': 1, 'lm_type': 'esm-t33', 'gearnet_hidden_dim_size': 512, 'gearnet_hidden_dim_count': 4}
get dataset atpbind3d
Initialize RUS: None
train samples: 302, valid samples: 76, test samples: 41
Training..
0m48s {'sensitivity': 0.2695, 'specificity': 0.9966, 'accuracy': 0.9589, 'precision': 0.8125, 'mcc': 0.4539, 'micro_auroc': 0.9095, 'train_bce': 0.1841, 'valid_bce': 0.1192, 'valid_mcc': 0.4189}
0m46s {'sensitivity': 0.5981, 'specificity': 0.986, 'accuracy': 0.9659, 'precision': 0.6996, 'mcc': 0.6292, 'micro_auroc': 0.9487, 'train_bce': 0.0893, 'valid_bce': 0.0942, 'valid_mcc': 0.5938}
0m46s {'sensitivity': 0.4338, 'specificity': 0.9963, 'accuracy': 0.9671, 'precision': 0.8635, 'mcc': 0.5987, 'micro_auroc': 0.9233, 'train_bce': 0.0522, 'valid_bce': 0.1243, 'valid_mcc': 0.584}
0m46s {'sensitivity': 0.4689, 'specificity': 0.9933, 'accuracy': 0.9661, 'precision': 0.7925, 'mcc': 0.5943, 'micro_auroc': 0.92, 'train_bce': 0.0324, 'va

KeyboardInterrupt: 