In [8]:
import os
import pickle
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from protacloader import PROTACSet, collater
from prepare_data import GraphData


In [9]:

def valid(model, test_loader, device, verbose=True):
    """Run ``model`` over ``test_loader`` and report classification accuracy.

    Each batch is a dict produced by the loader's collate function. The four
    graph entries and ``'smiles'`` are moved to ``device`` before the forward
    pass; ``'smiles_length'`` is passed through unchanged.

    Args:
        model: Network called as ``model(ligase_ligand, ligase_pocket,
            target_ligand, target_pocket, smiles, smiles_length)``. Its output
            is treated as per-class scores with classes along dim 1. Its
            parameters must live on ``device``, since the inputs are moved there.
        test_loader: Iterable of batch dicts with keys ``'name'``, ``'label'``
            (a tensor), ``'ligase_ligand'``, ``'ligase_pocket'``,
            ``'target_ligand'``, ``'target_pocket'``, ``'smiles'`` and
            ``'smiles_length'``.
        device: ``torch.device`` the model inputs are moved to.
        verbose: If True (default), print ``name, true labels, predicted
            labels`` for every batch and then the final accuracy.

    Returns:
        float: ``accuracy_score`` of the predicted classes against ``'label'``.
    """
    model.eval()
    pred = []
    true = []
    with torch.no_grad():
        for data_sample in test_loader:
            y = data_sample['label'].tolist()
            outputs = model(data_sample['ligase_ligand'].to(device),
                            data_sample['ligase_pocket'].to(device),
                            data_sample['target_ligand'].to(device),
                            data_sample['target_pocket'].to(device),
                            data_sample['smiles'].to(device),
                            # NOTE(review): intentionally not moved to `device`;
                            # presumably consumed as CPU lengths (e.g. for
                            # sequence packing) — confirm against the model.
                            data_sample['smiles_length'],)
            # Predicted class = index of the highest score along the class dim
            # (same result as torch.max(outputs, 1)[1]).
            pred_y = outputs.argmax(dim=1).cpu().tolist()
            true += y
            pred += pred_y
            if verbose:
                print(data_sample['name'], y, pred_y)
    accuracy = accuracy_score(true, pred)
    if verbose:
        print(accuracy)
    # Return the metric so callers can log/compare it instead of only reading stdout.
    return accuracy

In [10]:
# Sample identifiers, in the order the test set will iterate them.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('name.pkl', 'rb') as names_file:
    name_list = pickle.load(names_file)

# The four graph datasets share one root directory; built in this order.
ligase_ligand, ligase_pocket, target_ligand, target_pocket = (
    GraphData(dataset_name, root='data')
    for dataset_name in ("ligase_ligand", "ligase_pocket", "target_ligand", "target_pocket")
)

# SMILES encodings and labels are stored next to the target_pocket processed data.
with open(os.path.join(target_pocket.processed_dir, "smiles.pkl"), "rb") as smiles_file:
    smiles = pickle.load(smiles_file)
label = torch.load(os.path.join(target_pocket.processed_dir, "label.pt"))

In [11]:
# Assemble the evaluation dataset from the per-sample names and the loaded sources.
test_set = PROTACSet(
    name_list,
    ligase_ligand,
    ligase_pocket,
    target_ligand,
    target_pocket,
    smiles,
    label,
)

# Evaluation loader: fixed order (no shuffling) and drop_last=False so a trailing
# partial batch is still scored if batch_size is ever raised above 1 —
# drop_last=True would silently exclude test samples from the accuracy.
testloader = DataLoader(test_set, batch_size=1, shuffle=False,
                        collate_fn=collater, drop_last=False)

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the checkpoint directly onto `device`. The evaluation loop moves every
# input batch to `device`, so a model pinned to CPU (the previous
# map_location='cpu') fails with a device-mismatch error whenever CUDA is available.
# NOTE(review): weights_only=False unpickles arbitrary objects (required for a
# fully pickled model object) and can execute code — only load trusted checkpoints.
model = torch.load('model/test.pt', map_location=device, weights_only=False)

valid(model, testloader, device)

['SIAIS208034'] [1] [1]
['SIAIS208040'] [0] [1]
['SIAIS208037'] [1] [1]
['SIAIS208038'] [1] [1]
['SIAIS208017'] [1] [1]
['SIAIS208033'] [1] [1]
['SIAIS208045'] [0] [1]
['SIAIS208036'] [1] [1]
['SIAIS208039'] [0] [1]
['SIAIS208020'] [1] [1]
['SIAIS208041'] [1] [1]
['SIAIS208032'] [0] [1]
['SIAIS208031'] [0] [1]
['SIAIS208019'] [1] [1]
['SIAIS208035'] [1] [1]
['SIAIS208018'] [1] [1]
0.6875


In [13]:
# Inspect one raw dataset sample (here the third from the end) to see its fields:
# the sample name, the four graph objects, the integer 'smiles' sequence and the label.
test_set[-3]

{'name': 'SIAIS208019',
 'ligase_ligand': Data(x=[32], edge_index=[2, 68], edge_attr=[68]),
 'ligase_pocket': Data(x=[156], edge_index=[2, 316], edge_attr=[316]),
 'target_ligand': Data(x=[28], edge_index=[2, 60], edge_attr=[60]),
 'target_pocket': Data(x=[194], edge_index=[2, 368], edge_attr=[368]),
 'smiles': [1, 1, 4, 1, 1, 4, 1, 1, 4, 1, 1, 4, 1, 1, 1, 3, 4],
 'label': 1}