In [1]:
import numpy as np
import glob
import sys, os
import pickle as pkl
import argparse 
import torch
import esm
import gc

from torch_geometric.data import Data, Batch
from utils import protein_graph
from Bio.PDB.PDBParser import PDBParser
from tqdm import tqdm 
from joblib import Parallel, delayed

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
restype_1to3 = {
    'A': 'ALA',
    'R': 'ARG',
    'N': 'ASN',
    'D': 'ASP',
    'C': 'CYS',
    'Q': 'GLN',
    'E': 'GLU',
    'G': 'GLY',
    'H': 'HIS',
    'I': 'ILE',
    'L': 'LEU',
    'K': 'LYS',
    'M': 'MET',
    'F': 'PHE',
    'P': 'PRO',
    'S': 'SER',
    'T': 'THR',
    'W': 'TRP',
    'Y': 'TYR',
    'V': 'VAL',
}

In [4]:
def get_dist_seq(pdb_id):
    restype_3to1 = {v: k for k, v in restype_1to3.items()}

    pdb = "pdb/{}".format(pdb_id)
    parser = PDBParser()

    struct = parser.get_structure("x", pdb)
    model = struct[0]
    chain_id = list(model.child_dict.keys())[0]
    chain = model[chain_id]
    Ca_array = []
    sequence = ''
    seq_idx_list = list(chain.child_dict.keys())
    seq_len = seq_idx_list[-1][1] - seq_idx_list[0][1] + 1

    for idx in range(seq_idx_list[0][1], seq_idx_list[-1][1]+1):
        try:
            Ca_array.append(chain[(' ', idx, ' ')]['CA'].get_coord())
        except:
            Ca_array.append([np.nan, np.nan, np.nan])
        try:
            sequence += restype_3to1[chain[(' ', idx, ' ')].get_resname()]
        except:
            sequence += 'X'

    Ca_array = np.array(Ca_array)

    resi_num = Ca_array.shape[0]
    G = np.dot(Ca_array, Ca_array.T)
    H = np.tile(np.diag(G), (resi_num,1))
    dismap = (H + H.T - 2*G)**0.5
    return dismap, sequence

In [5]:
# def get_graph(distmap, sequence):
#     batch_converter = alphabet.get_batch_converter()
    
#     batch_labels, batch_strs, batch_tokens = batch_converter([('tmp', sequence)])
#     batch_tokens = batch_tokens.to(device)
#     with torch.no_grad():
#         results = esm_model(batch_tokens, repr_layers=[33], return_contacts=True)
#         token_representations = results["representations"][33][0].cpu().numpy().astype(np.float16)
#         esm_embed = token_representations[1:len(sequence)+1]

#     row, col = np.where(distmap <= 10)
#     edge = [row, col]
#     graph = protein_graph(sequence, edge, esm_embed)
#     return graph

In [6]:
pdbs = !ls pdb/
pdbs[:5]

['A0A024RBG1.pdb',
 'A0A087WT00.pdb',
 'A0A087WWM6.pdb',
 'A0A087WY85.pdb',
 'A0A087WZG4.pdb']

In [7]:
%%time
results = Parallel(n_jobs=-1)(delayed(get_dist_seq)(pdb_id) for pdb_id in (pdbs))

# 結果を分けてリストに格納
sequences = [result[1] for result in results]
dists = [result[0] for result in results]

CPU times: user 10.7 s, sys: 7.88 s, total: 18.5 s
Wall time: 1min 7s


In [8]:
test = [(i, j) for i, j in enumerate(sequences)]
test[:1]

[(0,
  'MMKFKPNQTRTYDREGFKKRAACLCFRSEQEDEVLLVSSSRYPDQWIVPGGGMEPEEEPGGAAVREVYEEAGVKGKLGRLLGIFEQNQDRKHRTYVYVLTVTEILEDWEDSVNIGRKREWFKVEDAIKVLQCHKPVHAEYLEKLKLGCSPANGNSTVPSLPDNNALFVTAAQTSGLPSSVR')]

In [9]:
def get_emb(data, ):
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model = model.to(device)
    model.eval()
    
    batch_converter = alphabet.get_batch_converter()
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
    batch_tokens = batch_tokens.to(device)

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33])
        token_representations = results["representations"][33]

    del batch_tokens, batch_converter, results, model
    gc.collect()
    torch.cuda.empty_cache()

    sequence_representations = []
    for i, tokens_len in enumerate(batch_lens):
        seq_rep = token_representations[i, 1 : tokens_len - 1].detach().cpu().numpy()
        sequence_representations.append(seq_rep)

    del token_representations, batch_lens
    gc.collect()
    torch.cuda.empty_cache()

    return sequence_representations

In [10]:
def get_graph(distmap, sequence, emb):
    row, col = np.where(distmap <= 10)
    edge = [row, col]
    graph = protein_graph(sequence, edge, emb)
    return graph

In [11]:
import time

In [12]:
graphs = []
batch_size = 20

# 例として、test, dists, sequences, alphabet, model, device が定義されていると仮定
for i in tqdm(range(0, len(test), batch_size)):
    tmp_seq_num = test[i:i + batch_size]
    tmp_dists = dists[i:i + batch_size]
    tmp_seq = sequences[i:i + batch_size]

    embs = get_emb(tmp_seq_num)
    tmp = []
    for dist, seq, emb in zip(tmp_dists, tmp_seq, embs):
        tmp.append(get_graph(dist, seq, emb))
    
    graphs.extend(tmp)

  edge_index = torch.LongTensor(edge_index)
100%|██████████| 482/482 [1:38:37<00:00, 12.28s/it]


In [13]:
graphs

[Data(x=[181, 1280], edge_index=[2, 2901], native_x=[181]),
 Data(x=[195, 1280], edge_index=[2, 2651], native_x=[195]),
 Data(x=[600, 1280], edge_index=[2, 10542], native_x=[600]),
 Data(x=[148, 1280], edge_index=[2, 2606], native_x=[148]),
 Data(x=[1173, 1280], edge_index=[2, 13407], native_x=[1173]),
 Data(x=[139, 1280], edge_index=[2, 2355], native_x=[139]),
 Data(x=[272, 1280], edge_index=[2, 3936], native_x=[272]),
 Data(x=[156, 1280], edge_index=[2, 1366], native_x=[156]),
 Data(x=[164, 1280], edge_index=[2, 1856], native_x=[164]),
 Data(x=[341, 1280], edge_index=[2, 4331], native_x=[341]),
 Data(x=[748, 1280], edge_index=[2, 13056], native_x=[748]),
 Data(x=[155, 1280], edge_index=[2, 2007], native_x=[155]),
 Data(x=[179, 1280], edge_index=[2, 2959], native_x=[179]),
 Data(x=[794, 1280], edge_index=[2, 12020], native_x=[794]),
 Data(x=[797, 1280], edge_index=[2, 13471], native_x=[797]),
 Data(x=[147, 1280], edge_index=[2, 2255], native_x=[147]),
 Data(x=[30, 1280], edge_index=[2

In [22]:
for i, j in (zip([i.split('.')[0] for i in pdbs], graphs)):
    torch.save(j, 'cmap/{}.pt'.format(i))