In [2]:
# LM-Gearnet Pretrained: finetune it
from lib.pipeline import Pipeline
import torch

# Index of the CUDA device used for the whole notebook. Both the pipeline and
# the model are configured from this one constant so they cannot drift apart.
GPU = 0
pipeline = Pipeline(
    model='lm-gearnet',
    dataset='atpbind3d',
    gpus=[GPU],
    model_kwargs={
        'gpu': GPU,  # was a hard-coded 0; keep in sync with `gpus` above
        'gearnet_hidden_dim_size': 512,
        'gearnet_hidden_dim_count': 4,
        # NOTE(review): the interaction between these two flags is defined in
        # lib.pipeline / the model code, not here — confirm that a layer count
        # is still honoured when 'bert_freeze' is False.
        'bert_freeze': False,
        'bert_freeze_layer_count': 29,
    }
)


# Deserialize the pretrained GearNet weights onto the CPU first so loading does
# not depend on the device the checkpoint was saved from; load_state_dict then
# copies the tensors into the already-placed model parameters.
state_dict = torch.load('ResidueType_lmg_4_512_0.57268.pth', map_location='cpu')
# Last expression: displays the key-matching report of load_state_dict.
pipeline.model.gearnet.load_state_dict(state_dict)


get dataset atpbind3d
Split num:  [337, 41, 41]
train samples: 337, valid samples: 41, test samples: 41


<All keys matched successfully>

In [135]:
# Fine-tune for one additional epoch. The recorded log below reports
# "Epoch 4", so this cell appears to have been executed several times;
# a fresh Restart & Run All would only train a single epoch.
pipeline.train(num_epoch=1)
    

01:25:02   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:25:02   Epoch 4 begin
01:26:47   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:26:47   Epoch 4 end
01:26:47   duration: 1.27 hours
01:26:47   speed: 0.07 batch / sec
01:26:47   ETA: 0.00 secs
01:26:47   max GPU memory: 6690.6 MiB
01:26:47   ------------------------------
01:26:47   average binary cross entropy: 0.0236077


In [137]:
# Snapshot the full model's parameters and write them to disk.
finetuned_state = pipeline.model.state_dict()
torch.save(finetuned_state, 'lm_gearnet_finetuned_demo.pth')

In [138]:
# Read the fine-tuned checkpoint from disk, then restore it into the model.
# The call is kept as the last expression so its key-matching report is shown.
finetuned_checkpoint = torch.load('lm_gearnet_finetuned_demo.pth')
pipeline.model.load_state_dict(finetuned_checkpoint)

<All keys matched successfully>

In [136]:
# Run the pipeline's evaluation and display the returned metrics
# as the cell output.
test_metrics = pipeline.evaluate()
test_metrics

01:27:20   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
01:27:20   Evaluate on test
01:27:30   ------------------------------
01:27:30   mcc: 0.522607
01:27:30   micro_auroc: 0.917182


{'micro_auroc': 0.9171818494796753, 'mcc': 0.5226066587138549}

In [167]:
from torchdrug import data, utils
# https://www.rcsb.org/structure/4K6R
DEMO_PDB = '4K6RA'
DEMO_ID = 5
RESIDUE_OFFSET = 6
# Cut-off applied to the raw model output when calling a residue "binding".
# NOTE(review): presumably `pred` holds logits, in which case -2 corresponds to
# a probability of about 0.12 — confirm against the task definition.
PRED_THRESHOLD = -2

# Wrap the single demo sample in a loader so it is collated into a batch the
# same way as during training, then move it to the configured GPU.
# NOTE(review): the sample is taken from the *training* split.
dataloader = data.DataLoader(
    [pipeline.train_set[DEMO_ID]], batch_size=1, shuffle=False)
batch = utils.cuda(next(iter(dataloader)), device=torch.device('cuda:{}'.format(GPU)))
# Pure inference: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    pred, target = pipeline.task.predict_and_target(batch)

# 1-based residue positions of the labelled and the predicted binding residues.
target_index = [i + 1 for i, item in enumerate(target['label']) if item.item() == 1]
predict_index = [i + 1 for i, item in enumerate(pred) if item.item() > PRED_THRESHOLD]
print('binding site:', target_index)
print('graph: ', batch['graph'])
print('Sequence: ', batch['graph'].to_sequence()[0])
print('prediction: ', predict_index)

print('-------')
# Sets give O(1) membership tests; the result lists keep their original order.
target_lookup = set(target_index)
predict_lookup = set(predict_index)
true_positive_index = [item for item in predict_index if item in target_lookup]
false_negative_index = [item for item in target_index if item not in predict_lookup]
false_positive_index = [item for item in predict_index if item not in target_lookup]
print('true positive: ', true_positive_index)
print('false negative: ', false_negative_index)
print('false positive: ', false_positive_index)

binding site: [9, 10, 11, 12, 13, 14, 52, 53, 78, 80, 81, 82, 83, 84, 87, 202]
graph:  PackedProtein(batch_size=1, num_atoms=[1400], num_bonds=[2798], num_residues=[350], device='cuda:0')
Sequence:  DTAVLVLAAGPGTRMRSDTPKVLHTLAGRSMLSHVLHAIAKLAPQRLIVVLGHDHQRIAPLVGELADTLGRTIDVALQDRPLGTGHAVLCGLSALPDDYAGNVVVTSGDTPLLDADTLADLIATHRAVSAAVTVLTTTLDDPFGYGRILRTQDHEVMAIVEQTDATPSQREIREVNAGVYAFDIAALRSALSRLSSNNAQQELYLTDVIAILRSDGQTVHASHVDDSALVAGVNNRVQLAELASELNRRVVAAHQLAGVTVVDPATTWIDVDVTIGRDTVIHPGTQLLGRTQIGGRCVVGPDTTLTDVAVGDGASVVRTHGSSSSIGDGAAVGPFTYLRPGTALGADGKL
prediction:  [10, 11, 13, 52, 53, 80, 81, 82, 83, 84, 87, 202]
-------
true positive:  [10, 11, 13, 52, 53, 80, 81, 82, 83, 84, 87, 202]
false negative:  [9, 12, 14, 78]
false positive:  []


In [168]:
from rdkit import Chem
import nglview
from random import random
mol = Chem.MolFromPDBFile(f'data/pdb/{DEMO_PDB}.pdb')
if mol is None:
    # RDKit signals a parse failure by returning None instead of raising.
    raise ValueError(f'RDKit could not parse data/pdb/{DEMO_PDB}.pdb')


def _to_pdb_resno(indices):
    """Shift 1-based sequence positions to the numbering used for colouring.

    NOTE(review): RESIDUE_OFFSET is presumably the PDB residue number of the
    first residue in the chain — confirm against the PDB file.
    """
    return [i + RESIDUE_OFFSET - 1 for i in indices]


# Colour scheme evaluated by NGL for every atom. The lists contain *residue*
# numbers, so they are matched against atom.resno; atom.serial is the atom
# serial number and would colour unrelated parts of the chain.
# green = true positive, red = false negative, blue = false positive,
# grey = everything else.
js_function = """
this.atomColor = function (atom) {
    if (%s.includes(atom.resno)) { // true positive
        return 0x00FF00
    } else if (%s.includes(atom.resno)) { // false negative
        return 0xFF0000
    } else if (%s.includes(atom.resno)) { // false positive
        return 0x0000FF
    } else {
        return 0x808080
    }
}
""" % (_to_pdb_resno(true_positive_index),
        _to_pdb_resno(false_negative_index),
        _to_pdb_resno(false_positive_index))

# Register the scheme under a fresh random name on every run of the cell.
scheme_name_rnd = "awesome-" + str(random())
nglview.color.ColormakerRegistry.add_scheme_func(scheme_name_rnd, js_function)
view = nglview.show_rdkit(mol, default_representation=False)
view.center()
view.add_cartoon(color=scheme_name_rnd)
view


NGLWidget()