In [1]:
def read_file(path):
    """Parse an annotation file into parallel per-sample lists.

    Every non-blank line is split on ``' : '``. The first field is taken as
    the sample id, the last field (stripped) as the sequence, and the
    second-to-last field as a whitespace-separated list of site tokens. Each
    token is one leading character followed by a 1-based position in the
    sequence (e.g. ``A12``); only the numeric part is used.

    Blank lines are skipped, and a site field with no tokens yields an
    all-zero target instead of raising.

    Args:
        path: Path of the text file to read.

    Returns:
        A tuple ``(num_samples, sequences, targets, pdb_ids)`` where
        ``targets[i]`` is a list of 0/1 flags, one per character of
        ``sequences[i]``, set to 1 at the listed positions.

    Raises:
        ValueError: If a site token has a non-integer suffix.
        IndexError: If a non-blank line contains no ``' : '`` separator.
    """
    sequences, targets, pdb_ids = [], [], []
    with open(path) as f:
        for line in f:
            # A trailing newline at the end of the file must not count as a sample.
            if not line.strip():
                continue

            fields = line.split(' : ')
            sequence = fields[-1].strip()

            # split() with no argument tolerates repeated or trailing blanks;
            # a set makes the per-residue membership test constant time.
            marked = {int(token[1:]) for token in fields[-2].split()}
            target = [1 if position in marked else 0
                      for position in range(1, len(sequence) + 1)]

            sequences.append(sequence)
            targets.append(target)
            pdb_ids.append(fields[0])
    return len(sequences), sequences, targets, pdb_ids


# Parse the training annotations once and keep the parallel lists at notebook scope.
num_samples, sequences, targets, pdb_ids = read_file('lib/train.txt')

In [14]:
import pylcs
import os
from torchdrug import utils

def parse_pdb(file_path, chain_id):
    """Extract backbone atom coordinates and the sequence of one chain.

    Only ``ATOM`` records whose chain column (index 21) equals ``chain_id``
    and whose atom name is one of N, CA, C, O are used. A new residue starts
    whenever the (residue number, insertion code) pair changes, so residues
    such as 52 and 52A are kept apart. Reading stops at the first ``ENDMDL``
    seen after at least one residue of the chain has been collected, so later
    models of a multi-model file do not repeat the chain.

    Args:
        file_path: Path to a fixed-column PDB text file.
        chain_id: Single-character chain identifier to keep.

    Returns:
        A tuple ``(positions, sequence)``. ``positions`` has one entry per
        residue, each a 4-item list ordered N, CA, C, O holding ``[x, y, z]``
        floats, or ``None`` for an atom with no record. ``sequence`` is the
        one-letter string of the same residues, with ``'X'`` for residue
        names outside the 20-entry table below.
    """
    three_to_one_letter = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E',
        'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
        'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N',
        'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S',
        'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
    }

    atom_order = {'N': 0, 'CA': 1, 'C': 2, 'O': 3}

    positions = []
    letters = []  # joined once at the end instead of growing a string
    prev_residue_key = None
    with open(file_path, 'r') as pdb_file:
        for line in pdb_file:
            # Later models would append the same chain a second time.
            if line.startswith('ENDMDL') and positions:
                break
            if not line.startswith('ATOM'):
                continue
            if line[21].strip() != chain_id:
                continue

            atom_name = line[12:16].strip()
            if atom_name not in atom_order:
                continue

            # Column 26 is the insertion code; the number alone is not unique.
            residue_key = (int(line[22:26].strip()), line[26])
            if residue_key != prev_residue_key:
                residue_name = line[17:20].strip()
                letters.append(three_to_one_letter.get(residue_name, 'X'))
                positions.append([None, None, None, None])
                prev_residue_key = residue_key

            x = float(line[30:38].strip())
            y = float(line[38:46].strip())
            z = float(line[46:54].strip())
            positions[-1][atom_order[atom_name]] = [x, y, z]
    return positions, ''.join(letters)

def get_parsed_data_from_pdb_id(orig_sequence, pdb_id):
    """Return backbone coordinates aligned to ``orig_sequence`` for one chain.

    The first four characters of ``pdb_id`` select the PDB entry and the
    fifth selects the chain. The entry is read from ``./pdb/<entry>.pdb`` if
    that file exists; otherwise it is downloaded from files.rcsb.org into
    ``./pdb`` (the directory is created when missing). The chain is parsed
    with ``parse_pdb`` and reduced to the residues that
    ``pylcs.lcs_sequence_idx`` matches to ``orig_sequence``.

    Args:
        orig_sequence: Reference one-letter sequence the result must spell.
        pdb_id: Entry id (4 characters) immediately followed by the chain id.

    Returns:
        The ``parse_pdb`` position entries selected by the alignment, one per
        character of ``orig_sequence``.

    Raises:
        AssertionError: If some character of ``orig_sequence`` has no match
            in the parsed chain, or the matched residues do not spell
            ``orig_sequence``.
        IndexError: If ``pdb_id`` has fewer than five characters.
    """
    entry = pdb_id[:4]
    chain_id = pdb_id[4]

    pdb_dir = "./pdb"
    file_path = "./pdb/%s.pdb" % entry
    if not os.path.exists(file_path):
        # The download helper is handed a directory; make sure it exists so
        # the first run on a fresh checkout does not fail.
        os.makedirs(pdb_dir, exist_ok=True)
        file_path = utils.download("https://files.rcsb.org/download/%s.pdb" % entry, pdb_dir)

    positions, sequence = parse_pdb(file_path, chain_id)

    # presumably one index into `sequence` per character of `orig_sequence`,
    # with -1 for unmatched characters — confirm against the pylcs docs.
    lcs_indices = pylcs.lcs_sequence_idx(orig_sequence, sequence)
    # Guard before indexing: positions[-1] would silently pick the last residue.
    assert -1 not in lcs_indices, (
        "%s: reference sequence is not fully covered by chain %s" % (pdb_id, chain_id))

    filtered_positions = [positions[i] for i in lcs_indices]
    filtered_sequence = ''.join(sequence[i] for i in lcs_indices)
    assert filtered_sequence == orig_sequence, (
        "%s: aligned residues do not reproduce the reference sequence" % pdb_id)
    return filtered_positions


In [46]:
from ipywidgets import IntProgress
from IPython.display import display
import json

def load_all_in_file(filename):
    """Build one structure record per sample in ``filename`` and save them.

    The file is parsed with ``read_file``; for every (sequence, id) pair the
    aligned coordinates are fetched via ``get_parsed_data_from_pdb_id``. A
    notebook progress bar is advanced after each sample. The full list is
    written as JSON to ``filename + '.json'``.

    Args:
        filename: Path of the annotation file to process.

    Returns:
        A list of dicts with keys ``'name'``, ``'seq'`` and ``'coords'``.
    """
    total, seqs, _, names = read_file(filename)

    progress = IntProgress(min=0, max=total)
    display(progress)

    records = []
    for done, (seq, name) in enumerate(zip(seqs, names), start=1):
        backbone = get_parsed_data_from_pdb_id(seq, name)
        records.append({'name': name, 'seq': seq, 'coords': backbone})
        # Same effect as incrementing by one: the bar starts at zero.
        progress.value = done

    with open(filename + '.json', 'w') as out:
        json.dump(records, out)

    return records

In [47]:
# Build the per-chain records for the training file; as a side effect this
# also writes them to 'lib/train.txt.json'.
result = load_all_in_file('lib/train.txt')

IntProgress(value=0, max=388)

In [24]:
import gvp.data

# `result` is the list of {'name', 'seq', 'coords'} dicts returned by
# load_all_in_file above.
# NOTE(review): this needs a CUDA device, and the 'coords' layout expected by
# ProteinGraphDataset should be confirmed against the gvp documentation.
dataset = gvp.data.ProteinGraphDataset(result, device='cuda')

In [18]:
# Show the first graph so the feature shapes can be checked by eye.
dataset[0]

Data(x=[566, 3], edge_index=[2, 16980], seq=[566], name='3EPSA', node_s=[566, 6], node_v=[566, 3, 3], edge_s=[16980, 32], edge_v=[16980, 1, 3], mask=[566])

In [49]:
import gvp.models

# Input sizes mirror the dataset item printed earlier in this notebook:
# node_s=[566, 6] / node_v=[566, 3, 3] and edge_s=[16980, 32] / edge_v=[16980, 1, 3].
node_in_dim = (6,3)
edge_in_dim = (32,1)
# Hidden sizes; presumably (scalar channels, vector channels) like the input
# pairs — confirm against the gvp documentation.
node_h_dim = (100, 16)
edge_h_dim = (32, 1)

# NOTE(review): the model is moved to 'cuda', so a GPU is required here.
cpd_model = gvp.models.CPDModel(node_in_dim, node_h_dim, 
                        edge_in_dim, edge_h_dim).to('cuda')

In [50]:
# Take one graph and pack its node and edge features as (scalar, vector) pairs.
protein = dataset[3]
h_V = (protein.node_s, protein.node_v)
h_E = (protein.edge_s, protein.edge_v) 
# Draw a single sequence sample for this structure.
# NOTE(review): no random seed is set in the visible cells, so this result
# presumably changes between runs — confirm and seed if reproducibility matters.
sample = cpd_model.sample(h_V, protein.edge_index, h_E, n_samples=1)

In [52]:
# Fraction of positions where the sampled value equals protein.seq.
# NOTE(review): no training step appears in the visible cells, so this is
# presumably the score of a freshly initialised model — confirm.
sample.eq(protein.seq).float().mean().cpu().numpy()

array(0.04158416, dtype=float32)