In [115]:
import json
import os
import numpy as np
import torch
import biotite.structure as struc

from tqdm import tqdm
from biotite.structure.io.pdb import PDBFile

In [3]:
# Dataset splits, keyed by 'train' / 'validation' / 'test'.
with open('../data/gvp-data/chain_set_splits.json') as split_file:
    splits = json.load(split_file)

In [4]:
# Number of entries in each split, in (train, validation, test) order.
tuple(len(splits[part]) for part in ('train', 'validation', 'test'))

(18024, 608, 1120)

In [119]:
# One-letter -> three-letter residue codes, used to fill the PDB res_name column.
# The 20 standard amino acids first, in the same insertion order as before.
oneletter2threeletter = dict(zip(
    'ARNDCQEGHILKMFPSTWYV',
    'ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE LEU LYS MET PHE PRO SER THR TRP TYR VAL'.split(),
))
# Non-standard and ambiguous codes.
oneletter2threeletter.update({
    'U': 'SEC',
    'X': 'SEC',  # Hack: give unknown residues a real residue name to avoid backbone filtering.
    'O': 'PYL',
    'B': 'ASP',  # Also a hack: the ambiguous Asx code should properly be "ASX".
})

In [120]:
# Convert every record of chain_set.jsonl into a backbone-only (N, CA, C, O) PDB file.
# Records whose PDB already exists are skipped, so re-running this cell resumes an
# interrupted conversion instead of redoing it.
GVP_PDB_DIR = '../data/gvp_pdb'
os.makedirs(GVP_PDB_DIR, exist_ok=True)  # the writes below fail if the directory is missing

with open('../data/gvp-data/chain_set.jsonl') as chain_file:
    # Iterate lazily rather than via readlines(): the whole JSONL never sits in memory.
    for line in tqdm(chain_file):
        chainset = json.loads(line)

        name = chainset['name'].replace('.', '_')
        out_path = f'{GVP_PDB_DIR}/{name}.pdb'
        if os.path.exists(out_path):
            continue

        # Shape (L, 4, 3): per residue, the N, CA, C, O coordinates.
        # dtype=float makes any JSON nulls (if present) NaN, so the finiteness mask
        # below drops them instead of np.isfinite failing on an object array.
        coordinates = np.stack(
            [np.asarray(chainset['coords'][atom], dtype=float)
             for atom in ('N', 'CA', 'C', 'O')],
            axis=1,
        )
        # Keep only residues whose four backbone atoms all have finite coordinates.
        mask = np.isfinite(coordinates).all(axis=(1, 2))
        coordinates = coordinates[mask].reshape(-1, 3)
        seq = [c for c, keep in zip(chainset['seq'], mask) if keep]

        num_residues = len(coordinates) // 4
        if num_residues == 0:
            # Nothing resolved in this chain; don't write an empty PDB.
            continue
        assert len(seq) == num_residues, f'{name}: sequence/coordinate length mismatch'

        structure = struc.AtomArray(len(coordinates))
        structure.coord = coordinates
        structure.chain_id = ['A'] * len(coordinates)
        structure.atom_name = ['N', 'CA', 'C', 'O'] * num_residues
        # Fill the element column explicitly instead of leaving it at its default
        # (empty); many readers otherwise have to infer it from the atom name.
        structure.element = ['N', 'C', 'C', 'O'] * num_residues
        res_names = np.repeat([oneletter2threeletter[c] for c in seq], 4)
        structure.res_name = res_names
        # NOTE: residues are renumbered 1..num_residues over the *kept* residues, so
        # gaps left by masked-out residues are not visible in res_id.
        structure.res_id = np.repeat(np.arange(1, num_residues + 1), 4)

        pdb = PDBFile()
        pdb.set_structure(structure)
        pdb.write(out_path)

21668it [11:57, 30.19it/s]


In [6]:
# Fields of one JSONL record; `chainset` is whatever record the conversion loop last read.
chainset.keys()

dict_keys(['seq', 'coords', 'num_chains', 'name', 'CATH'])

In [7]:
# Sanity check on the last-read record: the sequence and each backbone coordinate
# list have the same length (the loop pairs them position by position).
all(len(chainset['coords'][atom]) == len(chainset['seq']) for atom in ('N', 'C', 'CA', 'O'))

True

In [91]:
# Raw one-letter sequence of the last-read record. Its 'X' characters are written
# with the 'X' entry of oneletter2threeletter, but only where backbone coords are finite.
chainset['seq']

'XNEGDAAKGEKEFNKCKACHMIQAPDGTDIKGGKTGPNLYGVVGRKIASEEGFKYGEGILEVAEKNPDLTWTEANLIEYVTDPKPLVKKMTDDKGAKTKMTFKMGKNQADVVAFLAQDDPDAXXXXXXXXXXXXX'

In [94]:


# Rebuild one backbone-only structure and write it to test.pdb for manual inspection.
# NOTE(review): this cell relies on `coordinates` (flattened (4 * L, 3)) and `seq` left
# over from the conversion loop, presumably from the last chain it converted (skipped
# chains do not update them). It fails on a fresh kernel unless that cell ran first.
num_residues = len(coordinates) // 4
assert len(seq) == num_residues, 'stale kernel state: `seq` and `coordinates` disagree'

structure = struc.AtomArray(len(coordinates))
structure.coord = coordinates
structure.chain_id = ['A'] * len(coordinates)
structure.atom_name = ['N', 'CA', 'C', 'O'] * num_residues
# Explicit element column (N, C, C, O for the four backbone atoms) instead of the default empty value.
structure.element = ['N', 'C', 'C', 'O'] * num_residues
res_names = [oneletter2threeletter[c] for c in seq for _ in range(4)]
structure.res_name = res_names
structure.res_id = np.repeat(np.arange(1, num_residues + 1), 4)

pdb = PDBFile()
pdb.set_structure(structure)
pdb.write('test.pdb')

In [103]:
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code; only
# load test.pt from a trusted source. Nothing visible in this notebook creates it.
t = torch.load('test.pt')