In [None]:
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from Bio.PDB import *
import multiprocessing as mp
import h5py
import logging
from math import floor
from tqdm.notebook import tqdm

In [None]:
# Biopython creates warnings for chains that are discontinuous. I recommend turning them off.
# Note: the filter below ignores every BiopythonWarning, not only the
# discontinuity one.
import warnings
from Bio import BiopythonWarning
warnings.simplefilter('ignore', BiopythonWarning)

In [None]:
# Route log records of level ERROR and above to the file data_generation.log
# (a relative path, so it is resolved against the current working directory).
logging.basicConfig(filename='data_generation.log',level=logging.ERROR)

In [None]:
# Global variables
# ----------------
# Directories that main() searches for "*.pdb" files when building the
# test and train datasets. Requires absolute paths.
# NOTE(review): these are machine-specific locations -- adjust them before
# running the notebook on another machine.
test_path = pathlib.Path("/home/collin/protein_gan/data/test/")
train_path = pathlib.Path("/home/collin/protein_gan/data/train/")

In [None]:
def split_matrix(matrix, res=None):
    """Yield consecutive ``res`` x ``res`` blocks along the diagonal of ``matrix``.

    The blocks do not overlap: block ``k`` covers rows and columns
    ``[res*k, res*(k+1))``. Trailing rows/columns that do not fill a whole
    block are dropped, and a matrix shorter than ``res`` yields nothing.

    Parameters
    ----------
    matrix : 2-D array supporting ``len()`` and ``matrix[a:b, c:d]`` slicing.
    res : int, optional
        Block edge length. When omitted, the module-level ``RES`` global is
        used, which keeps existing ``split_matrix(matrix)`` callers working.

    Yields
    ------
    The ``len(matrix) // res`` diagonal blocks, in order from the top-left.
    """
    if res is None:
        # Backward-compatible default: read the notebook-level RES global.
        res = RES
    # The previous range(1, floor(len/RES)) stopped one block early, so the
    # last full block (and the only block when len(matrix) == RES) was lost.
    block_count = len(matrix) // res
    for block in range(block_count):
        start = res * block
        stop = start + res
        yield matrix[start:stop, start:stop]

In [None]:
def calc_dist_matrix(residues):
    """Returns a matrix of distances between residues of the same chain.

    Parameters
    ----------
    residues : sequence of residue-like objects
        Each item must support ``residue["CA"]``; the two looked-up atoms are
        subtracted from each other to obtain the distance.

    Returns
    -------
    numpy.ndarray
        A ``(n, n)`` float64 array where entry ``[i, j]`` is
        ``residues[i]["CA"] - residues[j]["CA"]``. The diagonal is 0.

    Raises
    ------
    KeyError
        Propagated from ``residue["CA"]`` if a residue has no "CA" entry.
    """
    # Look every alpha-carbon up once instead of once per matrix cell.
    ca_atoms = [residue["CA"] for residue in residues]
    size = len(ca_atoms)
    # np.float was only an alias of the builtin float and has been removed
    # from NumPy; np.float64 is the dtype it always produced here.
    answer = np.zeros((size, size), dtype=np.float64)
    for row in range(size):
        # A distance is symmetric and zero on the diagonal, so only the upper
        # triangle is computed and then mirrored into the lower one.
        for col in range(row + 1, size):
            distance = ca_atoms[row] - ca_atoms[col]
            answer[row, col] = distance
            answer[col, row] = distance
    return answer

In [None]:
def generate_maps(files):
    """
    Generate the a-carbon distance maps for chain 'A' of a single PDB file.

    The file is parsed, the residues of chain 'A' in the first model that
    contain a "CA" atom are collected, their pairwise distance matrix is
    computed with calc_dist_matrix(), and that matrix is cut into blocks by
    split_matrix().

    Parameters
    ----------
    files : path of one PDB file (despite the plural name).

    Returns
    -------
    list or None
        The list of blocks produced by split_matrix() (possibly empty), or
        None when the structure has no models or when parsing/processing
        raised a KeyError, TypeError or ValueError. Those errors are logged
        via logging.error and not re-raised, so that one bad file does not
        abort a whole multiprocessing run.
    """
    parser = PDBParser()
    try:
        # Get the initial structure of the protein
        structure = parser.get_structure('X', files)
        for model in structure:
            # Only the first model is used: the return below leaves the loop.
            residues = Selection.unfold_entities(model['A'], 'R')
            ca_residues = [residue for residue in residues if 'CA' in residue]
            distance_matrix = calc_dist_matrix(ca_residues)
            return list(split_matrix(distance_matrix))
    except (KeyError, TypeError, ValueError) as err:
        # KeyError is presumably what model['A'] raises when the file has no
        # chain 'A' -- previously it escaped and crashed the caller.
        logging.error(f'{type(err).__name__} file :{files}, Error is:{err}')
    return None


In [None]:
def main(files, desc):
    """
    Generate the maps for every "*.pdb" file in a directory using a process pool.

    Parameters
    ----------
    files : pathlib.Path
        Directory that is searched (non-recursively) for "*.pdb" files.
    desc : str
        Label shown on the tqdm progress bar.

    Returns
    -------
    list
        All blocks returned by generate_maps() for all files, flattened into
        one list. Files for which generate_maps() returned None or an empty
        list contribute nothing.
    """
    # Sort so that the order of the output does not depend on the order in
    # which the filesystem happens to list the directory.
    pdbs = sorted(files.glob("*.pdb"))
    # The context manager guarantees the worker processes are cleaned up even
    # if a worker raises; every result is consumed before the block exits.
    with mp.Pool() as pool:
        per_file = list(tqdm(pool.imap(generate_maps, pdbs), total=len(pdbs), desc=desc))
    return [chunk for chunks in per_file if chunks for chunk in chunks]

In [None]:
def test_result(result, index=5):
    """
    Visual sanity check of a list of generated maps.

    Shows two figures: a histogram of ``len(x)`` for every item of ``result``,
    and an image (viridis colormap, with colorbar) of ``result[index]``.

    Parameters
    ----------
    result : sequence of 2-D arrays, e.g. the list returned by main().
    index : int, optional
        Which item to display as an image. Defaults to 5, the value that was
        previously hard-coded.

    Raises
    ------
    IndexError
        If ``result`` has no item at ``index``.
    """
    map_lengths = [len(item) for item in result]
    plt.hist(map_lengths)
    plt.title("Map edge lengths")
    plt.show()
    plt.imshow(result[index], cmap='viridis')
    plt.title(f"Map {index}")
    plt.colorbar()
    plt.show()

In [None]:
# Build dataset.hdf5 with one gzip-compressed dataset per (split, resolution)
# pair, named "<split>_<resolution>": test_16, train_16, test_64, train_64,
# test_128, train_128 (created in that order).
with h5py.File("dataset.hdf5", "w") as f:
    # RES is deliberately the loop variable: it is the module-level global
    # that split_matrix() reads, so it must hold the current resolution
    # before main() creates its worker pool.
    # NOTE(review): workers seeing this value presumably relies on the "fork"
    # start method -- confirm on the target platform.
    for RES in (16, 64, 128):
        for split_name, split_path in (("test", test_path), ("train", train_path)):
            _set = f"{split_name}_{RES}"
            f.create_dataset(_set, data=main(split_path, _set), compression="gzip")