In [None]:
def read_file(path):
    """Parse an annotation file with one ``pdb_id : targets : sequence`` record per line.

    The middle field is a whitespace-separated list of tokens such as ``K12``;
    the first character of each token is dropped and the remainder is read as
    a 1-based position in the sequence.

    Args:
        path: Path of the text file to read.

    Returns:
        A tuple ``(num_samples, sequences, targets, pdb_ids)`` where
        ``num_samples`` is the number of parsed records, ``sequences`` is a
        list of strings, ``targets`` is a list of 0/1 lists (one flag per
        sequence position, 1 where the position was listed) and ``pdb_ids``
        is the list of first fields.

    Raises:
        ValueError: If the position part of a token is not an integer.
        IndexError: If a non-blank line contains no `` : `` separator.
    """
    sequences, targets, pdb_ids = [], [], []
    with open(path) as f:
        for line in f:
            # Blank lines (e.g. a trailing newline at the end of the file)
            # carry no record; skipping them avoids an IndexError below and
            # keeps num_samples equal to the number of returned records.
            if not line.strip():
                continue

            fields = line.split(' : ')
            sequence = fields[-1].strip()
            # A set gives O(1) membership tests; split() without arguments
            # also tolerates repeated spaces between tokens.
            positions = {int(token[1:]) for token in fields[-2].split()}

            sequences.append(sequence)
            targets.append([1 if position in positions else 0
                            for position in range(1, len(sequence) + 1)])
            pdb_ids.append(fields[0])
    return len(sequences), sequences, targets, pdb_ids


# Parse the training annotations once; later cells index into `sequences`
# and `pdb_ids` by sample position.
num_samples, sequences, targets, pdb_ids = read_file('lib/train.txt')


In [None]:
import torchdrug as td
from torchdrug import data, utils
from Bio.PDB import PDBParser
from Bio.PDB import Select, PDBIO
import pylcs
import torch
import os

class ChainSelect(Select):
    """Bio.PDB ``Select`` subclass that accepts only the chain with a given id."""

    def __init__(self, chain):
        # Id of the single chain that should pass the filter.
        self.chain = chain

    def accept_chain(self, chain):
        """Return 1 when ``chain`` has the wanted id, otherwise 0."""
        return 1 if chain.get_id() == self.chain else 0

def make_protein(sequence, pdb_id, pdb_dir="./pdb"):
    """Build a single-chain ``data.Protein`` restricted to residues matched by ``sequence``.

    The first four characters of ``pdb_id`` select the PDB entry and the fifth
    character selects the chain. The full entry is read from ``pdb_dir``
    (downloaded from files.rcsb.org when absent), the chosen chain is written
    to ``<pdb_dir>/<pdb_id>.pdb``, and that file is loaded with torchdrug.

    Args:
        sequence: Reference residue string that is aligned against the chain.
        pdb_id: Entry id followed by the chain id; at least five characters.
        pdb_dir: Directory holding the downloaded and the per-chain PDB files.
            It is created when missing. Defaults to ``"./pdb"``.

    Returns:
        The protein after ``subresidue`` with a mask that is True at every
        non-negative index returned by ``pylcs.lcs_sequence_idx``.

    Raises:
        ValueError: If ``pdb_id`` is too short to contain a chain id.
    """
    if len(pdb_id) < 5:
        raise ValueError(
            "pdb_id must be a 4-character entry id followed by a chain id, got %r" % (pdb_id,)
        )
    entry_id, chain_id = pdb_id[:4], pdb_id[4]

    # The download helper writes into an existing directory, so make sure it
    # is there on a fresh checkout.
    os.makedirs(pdb_dir, exist_ok=True)
    entry_file = os.path.join(pdb_dir, "%s.pdb" % entry_id)
    if not os.path.exists(entry_file):
        entry_file = utils.download("https://files.rcsb.org/download/%s.pdb" % entry_id, pdb_dir)

    # Parse the whole entry, then release the handle before writing the
    # per-chain file.
    with open(entry_file, 'r') as handle:
        structure = PDBParser().get_structure('?', handle)

    chain_file = os.path.join(pdb_dir, "%s.pdb" % pdb_id)
    pdbio = PDBIO()
    pdbio.set_structure(structure)
    pdbio.save(chain_file, ChainSelect(chain_id))

    protein = data.Protein.from_pdb(chain_file, atom_feature="position", bond_feature=None, residue_feature="symbol")

    # One entry per character of `sequence`; -1 marks characters with no
    # counterpart in the chain.
    lcs = pylcs.lcs_sequence_idx(sequence, protein.to_sequence().replace('.', ''))
    if -1 in lcs:
        print('warning: -1 in lcs. pdb_id: %s' % pdb_id)
    mask = torch.zeros(protein.num_residue, dtype=torch.bool, device=protein.device)
    mask[[i for i in lcs if i != -1]] = True
    protein = protein.subresidue(mask)
    return protein

In [None]:
def _contiguous_segments(indices):
    """Collapse ``indices`` into ``(first, last)`` runs of consecutive values; -1 never joins a run."""
    segments = []
    start = 0
    while start < len(indices):
        end = start + 1
        # -1 is the "no match" sentinel, so it must not be merged with a
        # following 0 even though 0 == -1 + 1.
        while (end < len(indices)
               and indices[end - 1] != -1
               and indices[end] == indices[end - 1] + 1):
            end += 1
        segments.append((indices[start], indices[end - 1]))
        start = end
    return segments


def make_and_validate(sequence, pdb_id, replace=True):
    """Build the protein for ``pdb_id`` and report how ``sequence`` maps onto it.

    Prints the protein sequence length with and without ``'.'`` characters,
    then aligns ``sequence`` against the stripped protein sequence.

    Args:
        sequence: Reference residue string.
        pdb_id: Identifier forwarded to ``make_protein``.
        replace: Currently unused; kept so existing calls keep working.

    Returns:
        A list of ``(first, last)`` tuples, one per run of consecutive indices
        in the alignment result. Each unmatched position (``-1``) is reported
        as its own ``(-1, -1)`` tuple.
    """
    protein = make_protein(sequence, pdb_id)
    full_sequence = protein.to_sequence()
    stripped_sequence = full_sequence.replace('.', '')
    print('unstripped protein length: %s' % len(full_sequence))
    print('stripped protein length: %s' % len(stripped_sequence))
    lcs = pylcs.lcs_sequence_idx(sequence, stripped_sequence)
    return _contiguous_segments(lcs)

# Spot-check the alignment segments for the sample at index 3.
make_and_validate(sequences[3], pdb_ids[3])


In [None]:
from collections import Counter
# `protein` is otherwise first assigned in a later cell, so build it here to
# keep this cell runnable on a fresh kernel. Sample 3 mirrors the validation
# call in the previous cell — NOTE(review): confirm this is the sample you
# want to inspect.
protein = make_protein(sequences[3], pdb_ids[3])
# Tally how many entries of `protein.chain_id` carry each chain id.
Counter([i.item() for i in protein.chain_id])

In [None]:
def append_to_log(message, log_path='example.log'):
    """Append ``message`` followed by a newline to a log file.

    Args:
        message: Text to record; must be a ``str``.
        log_path: File to append to. Defaults to ``'example.log'`` in the
            current working directory.
    """
    with open(log_path, 'a') as log_file:
        log_file.write(message + '\n')

def protein_to_coordinates(protein):
    """Collect the N, CA, C and O atom coordinates of each residue.

    Args:
        protein: Object exposing ``num_residue``, ``residue2atom``,
            ``atom_name`` and ``node_position`` (e.g. the value returned by
            ``make_protein``).

    Returns:
        A list with one ``[N, CA, C, O]`` entry per complete residue, where
        each atom is a list of coordinates rounded to 5 decimals. Residues
        lacking any of the four atoms are logged via ``append_to_log`` and
        left out, so the list can be shorter than ``protein.num_residue``.
    """
    backbone_atoms = ('N', 'CA', 'C', 'O')
    coords = []

    for residue_index in range(protein.num_residue):
        atom_ids = protein.residue2atom(residue_index).sort()[0]
        atom_positions = {}
        for atom, position in zip(protein.atom_name[atom_ids].tolist(),
                                  protein.node_position[atom_ids].tolist()):
            atom_name = data.Protein.id2atom_name[atom]
            if atom_name in backbone_atoms:
                atom_positions[atom_name] = [round(value, 5) for value in position]
        try:
            coords.append([atom_positions[name] for name in backbone_atoms])
        except KeyError:
            # Only a missing backbone atom is expected here; any other
            # failure should surface instead of being silently logged.
            append_to_log('error: missing atom. atom at: %s' % residue_index)
    return coords

In [None]:
# Smoke-test both steps on the first training record.
protein = make_protein(sequences[0], pdb_ids[0])
coords = protein_to_coordinates(protein)

In [None]:
from ipywidgets import IntProgress
from IPython.display import display
import json

def load_all_in_file(filename):
    """Convert every record of an annotation file to coordinates and save them as JSON.

    Each record is passed through ``make_protein`` and
    ``protein_to_coordinates`` while a notebook progress bar is updated. The
    collected entries are written to ``filename + '.json'``.

    Args:
        filename: Path of the annotation file understood by ``read_file``.

    Returns:
        A list of dicts with the keys ``'name'`` (the pdb id), ``'seq'`` (the
        sequence from the file) and ``'coords'``. Records whose protein could
        not be built are logged and omitted.
    """
    num_samples, sequences, _, pdb_ids = read_file(filename)
    result = []
    progress = IntProgress(min=0, max=num_samples)
    display(progress)
    for sequence, pdb_id in zip(sequences, pdb_ids):
        append_to_log('processing %s' % pdb_id)
        try:
            protein = make_protein(sequence, pdb_id)
        except Exception as e:
            append_to_log('error while creating protein: %s' % e)
            # Skip the record: falling through would reuse the protein of the
            # previous iteration and store its coordinates under this pdb_id.
            progress.value += 1
            continue
        coords = protein_to_coordinates(protein)
        result.append({
            'name': pdb_id,
            'seq': sequence,
            'coords': coords,
        })
        progress.value += 1

    with open(filename + '.json', 'w') as o:
        json.dump(result, o)

    return result
        

In [None]:
# Convert every record of the training file and write 'lib/train.txt.json'.
# NOTE(review): as the last expression of the cell, the returned list is also
# echoed as cell output — consider suppressing it if the output is large.
load_all_in_file('lib/train.txt')

In [None]:
# Reload the JSON file that load_all_in_file wrote next to 'lib/train.txt'.
with open('lib/train.txt.json', 'r') as f:
    train_data = json.load(f)

# Show the first record as a quick sanity check.
train_data[0]


In [None]:
import gvp.data

# Wrap two of the reloaded records (indices 1 and 2) in a GVP graph dataset.
dataset = gvp.data.ProteinGraphDataset(train_data[1:3])

In [None]:
# Display the dataset object's repr.
dataset

In [None]:
import gvp.models
# Dimension settings handed to CPDModel below.
# NOTE(review): presumably each pair is (scalar channels, vector channels) as
# in the GVP reference code — confirm against the gvp package.
node_in_dim = (6, 3)
edge_in_dim = (32, 1)
node_h_dim = [20, 20]
edge_h_dim = [20, 20]

# Instantiate the model with the input and hidden dimensions defined above.
cpd_model = gvp.models.CPDModel(node_in_dim, node_h_dim, 
                        edge_in_dim, edge_h_dim)

In [None]:
# Display the model's repr; the variable is named `cpd_model` (singular).
cpd_model