In [1]:
import torch
import pandas as pd
import numpy as np
from torch_geometric.data import Data
from scipy.spatial import distance
from EmbedDataset import get_file

import biographs as bg
from Bio.PDB import PDBParser
from Bio.SeqUtils import seq1
# Single Biopython PDB parser instance, reused for every structure file parsed below.
parser = PDBParser()
# Prefer the GPU when PyTorch reports CUDA as available; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [9]:
from LigandGNNV2 import LigandGNNV2
from EmbedDataset import LigandBinaryDataset

# NOTE(review): `ds` is not referenced by any other cell visible here; it is
# kept in case constructing the dataset has side effects — confirm before removing.
ds = LigandBinaryDataset('./data2')

# Rebuild the network with the same constructor arguments as before
# (presumably hidden width and layer count — TODO confirm against LigandGNNV2).
model = LigandGNNV2(128, 37).to(device)
# map_location remaps the checkpoint's tensors onto `device`, so weights saved
# from a CUDA session can still be restored when only a CPU is available.
model.load_state_dict(torch.load('./models/CompModel.pt', map_location=device))
# eval() puts the module in evaluation mode and returns it, so the cell
# displays the model summary as its output.
model.eval()

LigandGNNV2(
  (node_encoder): Linear(1070, 128, bias=True)
  (layers): ModuleList(
    (0): DeepGCNLayer(block=res+)
    (1): DeepGCNLayer(block=res+)
    (2): DeepGCNLayer(block=res+)
    (3): DeepGCNLayer(block=res+)
    (4): DeepGCNLayer(block=res+)
    (5): DeepGCNLayer(block=res+)
    (6): DeepGCNLayer(block=res+)
    (7): DeepGCNLayer(block=res+)
    (8): DeepGCNLayer(block=res+)
    (9): DeepGCNLayer(block=res+)
    (10): DeepGCNLayer(block=res+)
    (11): DeepGCNLayer(block=res+)
    (12): DeepGCNLayer(block=res+)
    (13): DeepGCNLayer(block=res+)
    (14): DeepGCNLayer(block=res+)
    (15): DeepGCNLayer(block=res+)
    (16): DeepGCNLayer(block=res+)
    (17): DeepGCNLayer(block=res+)
    (18): DeepGCNLayer(block=res+)
    (19): DeepGCNLayer(block=res+)
    (20): DeepGCNLayer(block=res+)
    (21): DeepGCNLayer(block=res+)
    (22): DeepGCNLayer(block=res+)
    (23): DeepGCNLayer(block=res+)
    (24): DeepGCNLayer(block=res+)
    (25): DeepGCNLayer(block=res+)
    (26): DeepGC

In [3]:
# Per-residue test table. The 'Unnamed: 0' column (presumably an index saved
# with the CSV — verify) is discarded right after loading.
test_df = (
    pd.read_csv('./data/af2_dataset_testset_unlabeled.csv')
    .drop(columns='Unnamed: 0')
)
# Table read from HDF5; reference_embedding looks up its 'entry' and
# 'embeddings' columns.
df_test_grouped = pd.read_hdf('./data/data_test.h5')

In [4]:
def reference_embedding(row):
    """Return the embedding stored for one row's position within its protein.

    The module-level ``df_test_grouped`` frame is filtered to the rows whose
    'entry' equals ``row['entry']``; the 'embeddings' value of the first match
    is then indexed with ``row['entry_index']``.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) exposing the keys
            'entry' and 'entry_index'.

    Returns:
        The element at position ``row['entry_index']`` of the matched
        'embeddings' value.

    Raises:
        IndexError: If no row of ``df_test_grouped`` matches ``row['entry']``.
    """
    same_protein = df_test_grouped['entry'] == row['entry']
    per_residue = df_test_grouped.loc[same_protein, 'embeddings'].to_numpy()[0]
    return per_residue[row['entry_index']]

In [5]:
# Give every residue row its own embedding by calling reference_embedding row by row.
test_df['embeddings'] = test_df.apply(
    reference_embedding,
    axis='columns',
)

In [6]:
# Names of all columns whose dtype is bool.
bool_cols = test_df.dtypes.loc[lambda dtypes: dtypes == bool].index.tolist()
# Recast those flags as integers so that they can be used as numeric values.
test_df[bool_cols] = test_df[bool_cols].astype(int)
test_df

Unnamed: 0,annotation_sequence,feat_A,feat_C,feat_D,feat_E,feat_F,feat_G,feat_H,feat_I,feat_K,...,feat_DSSP_10,feat_DSSP_11,feat_DSSP_12,feat_DSSP_13,coord_X,coord_Y,coord_Z,entry,entry_index,embeddings
0,M,0,0,0,0,0,0,0,0,0,...,0,0.0,0,0.0,33.116001,37.023998,38.417000,QCR1_HUMAN,0,"[tensor(-0.3469), tensor(-0.0918), tensor(-0.0..."
1,A,1,0,0,0,0,0,0,0,0,...,2,-0.0,0,0.0,35.849998,34.841000,40.185001,QCR1_HUMAN,1,"[tensor(-0.1007), tensor(0.1798), tensor(0.297..."
2,A,1,0,0,0,0,0,0,0,0,...,0,0.0,2,-0.0,37.087002,31.719999,40.547001,QCR1_HUMAN,2,"[tensor(0.1528), tensor(0.2166), tensor(0.2359..."
3,S,0,0,0,0,0,0,0,0,0,...,0,0.0,-2,-0.0,38.095001,28.951000,42.321999,QCR1_HUMAN,3,"[tensor(0.1930), tensor(0.2558), tensor(-0.291..."
4,V,0,0,0,0,0,0,0,0,0,...,0,0.0,0,0.0,41.435001,27.417000,43.703999,QCR1_HUMAN,4,"[tensor(0.1481), tensor(0.0761), tensor(-0.263..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
107619,L,0,0,0,0,0,0,0,0,0,...,-3,-0.3,-3,-0.0,47.813999,7.569000,-27.368999,PDE7A_HUMAN,474,"[tensor(-0.0071), tensor(-0.1955), tensor(-0.1..."
107620,P,0,0,0,0,0,0,0,0,0,...,0,0.0,-3,-0.0,50.228001,8.068000,-30.333000,PDE7A_HUMAN,475,"[tensor(-0.3959), tensor(-0.2057), tensor(-0.0..."
107621,Q,0,0,0,0,0,0,0,0,0,...,0,0.0,0,0.0,51.507999,4.896000,-31.959999,PDE7A_HUMAN,476,"[tensor(-0.0042), tensor(-0.6211), tensor(-0.0..."
107622,E,0,0,0,1,0,0,0,0,0,...,0,0.0,0,0.0,54.845001,6.372000,-33.125000,PDE7A_HUMAN,477,"[tensor(-0.0075), tensor(-0.2074), tensor(-0.0..."


In [20]:
# Build one residue graph per test protein, run the trained model on it, and
# collect the rounded per-residue predictions (one array per protein) in `results`.
results = []

# Proteins whose edges are derived from the coordinate columns instead of the
# structure file.
manual_list = ['CENPE_HUMAN']
# Columns left out of the numeric node-feature matrix.
drop_cols = ['annotation_sequence', 'annotation_atomrec', 'entry', 'embeddings']
coord_cols = ['coord_X', 'coord_Y', 'coord_Z']
# Manual edge search: residue i is compared with residues i+1 .. i+SEQUENCE_WINDOW-1.
SEQUENCE_WINDOW = 50
# Two residues are linked when their Euclidean distance is at most this value
# (same units as the coord_* columns; presumably angstroms — TODO confirm).
DISTANCE_CUTOFF = 6


def _to_edge_index(pairs):
    """Turn a list of [source, target] index pairs into a (2, E) long tensor.

    An empty list yields a (2, 0) tensor instead of failing later on indexing.
    """
    pair_array = np.asarray(pairs, dtype=np.int64).reshape(-1, 2)
    return torch.tensor(pair_array.T, dtype=torch.long)


# Evaluation mode only needs to be set once, not on every iteration.
model.eval()

for entry in test_df['entry'].unique():
    # Sort once and use the sorted frame everywhere, so that features,
    # embeddings and coordinates all describe the same residue in the same row.
    # (Previously only the feature matrix was sorted by entry_index.)
    group = test_df[test_df['entry'] == entry].sort_values(by='entry_index')
    n_residues = len(group)

    if entry in manual_list:
        print('Manually looking for edges...')
        # Pull the coordinates out once instead of two row lookups per pair.
        coords = group[coord_cols].to_numpy(dtype=float)
        edges = _to_edge_index([
            [i, j]
            for i in range(n_residues)
            for j in range(i + 1, min(i + SEQUENCE_WINDOW, n_residues))
            if distance.euclidean(coords[i], coords[j]) <= DISTANCE_CUTOFF
        ])
    else:
        file = get_file(entry)
        structure = parser.get_structure(1, file)
        # One-letter sequence of chain 'A' as found in the structure file.
        p1 = {
            chain.id: seq1(''.join(residue.resname for residue in chain))
            for chain in structure.get_chains()
        }['A']
        # Sequence of the same protein according to the test table.
        p2 = "".join(group['annotation_sequence'].values)

        molecule = bg.Pmolecule(file)
        network = molecule.network()

        # Drop structure residues that lie beyond the end of the table's sequence.
        for i in range(len(p2) + 1, len(p1) + 1):
            node_to_remove = 'A' + str(i)
            network.remove_node(node_to_remove)

        # Node labels look like 'A<number>' with numbering starting at 1 (the
        # removal loop above relies on the same convention), so always shift
        # to 0-based row positions. The old check `edges[0][0] == 1` skipped
        # the shift whenever the first listed edge did not start at residue 1.
        edges = _to_edge_index([
            [int(source[1:]) - 1, int(target[1:]) - 1]
            for source, target in network.edges
        ])

    # Fail early with a readable message rather than inside the model.
    if edges.numel() > 0 and (int(edges.min()) < 0 or int(edges.max()) >= n_residues):
        raise ValueError(
            f'{entry}: edge indices fall outside the {n_residues} residue rows'
        )

    # Node features: the remaining numeric columns followed by the embedding.
    features = torch.tensor(
        group.drop(drop_cols, axis=1).to_numpy(), dtype=torch.float32
    )
    embeddings = torch.stack(group['embeddings'].tolist())
    x = torch.cat([features, embeddings], dim=1).float()

    graph = Data(x=x, edge_index=edges)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        out = torch.sigmoid(model(graph.to(device))).round().cpu().numpy()
    results.append(out)

Manually looking for edges...


In [56]:
# Concatenate the per-protein prediction arrays into one flat 1-D array,
# keeping the order in which the proteins were processed.
flattened = np.asarray(
    [prediction for protein_output in results for prediction in protein_output]
).ravel()

In [60]:
# Single-column submission table: the flat predictions cast to booleans.
submission = pd.DataFrame({'y_Ligand': flattened.astype(bool)})
submission

Unnamed: 0,y_Ligand
0,False
1,False
2,False
3,False
4,False
...,...
107619,False
107620,False
107621,False
107622,False


In [61]:
# Write the predictions to disk; the row index is kept as the first CSV column.
submission.to_csv('submission.csv', index=True)