In [None]:
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from Bio.PDB import *
import multiprocessing as mp
import h5py
import logging
from math import floor
from tqdm.notebook import tqdm

In [None]:
# Biopython warns about (among other things) chains that are discontinuous.
# Those warnings are noise for this pipeline, so every BiopythonWarning
# (including its subclasses) is silenced for the rest of the session.
import warnings
from Bio import BiopythonWarning

warnings.filterwarnings('ignore', category=BiopythonWarning)

In [None]:
# Send ERROR-and-above records (generate_maps logs per-file failures at ERROR)
# to data_generation.log. Note: basicConfig does nothing if the root logger
# already has handlers configured.
logging.basicConfig(
    level=logging.ERROR,
    filename='data_generation.log',
)

In [None]:
# Global Variables
# ----------------
# Input directories holding the .pdb files. The data root can be overridden
# with the PROTEIN_GAN_DATA_DIR environment variable instead of editing this
# cell. Relative values are anchored to the current working directory, so the
# resulting paths are always absolute.
import os

DATA_DIR = pathlib.Path(
    os.environ.get("PROTEIN_GAN_DATA_DIR", "/home/collin/protein_gan/data")
).expanduser()
if not DATA_DIR.is_absolute():
    DATA_DIR = pathlib.Path.cwd() / DATA_DIR
test_path = DATA_DIR / "test"
train_path = DATA_DIR / "train"

In [None]:
def split_matrix(matrix, res=None):
    """Yield the non-overlapping ``res x res`` blocks along the main diagonal.

    Parameters
    ----------
    matrix : numpy.ndarray
        2-D array, expected to be square (e.g. a residue distance matrix).
    res : int, optional
        Block edge length. Defaults to the notebook-global ``RES`` so callers
        that set ``RES`` before calling keep working unchanged.

    Yields
    ------
    numpy.ndarray
        ``matrix[k*res:(k+1)*res, k*res:(k+1)*res]`` for every ``k`` whose
        block fits entirely inside ``matrix``. Trailing rows/columns that do
        not fill a whole block are dropped; nothing is yielded when
        ``len(matrix) < res``.
    """
    if res is None:
        res = RES
    # Every complete diagonal block, k = 0 .. len(matrix) // res - 1
    # (including the last one, and the single block when len == res).
    for k in range(len(matrix) // res):
        start = k * res
        yield matrix[start:start + res, start:start + res]

In [None]:
def calc_dist_matrix(residues):
    """Return the pairwise C-alpha distance matrix for a list of residues.

    Parameters
    ----------
    residues : sequence
        Residue objects indexable by the atom name ``"CA"``. The returned atom
        must expose its Cartesian position as ``.coord`` (as Biopython
        ``Atom`` objects do).

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(len(residues), len(residues))``. Entry
        ``[i, j]`` is the Euclidean distance between the CA atoms of residues
        ``i`` and ``j``.
    """
    # np.float no longer exists in NumPy >= 1.24, so use np.float64 explicitly.
    # reshape(-1, 3) keeps an empty input as shape (0, 3) instead of (0,).
    coords = np.asarray(
        [residue["CA"].coord for residue in residues], dtype=np.float64
    ).reshape(-1, 3)
    squared = np.zeros((len(coords), len(coords)), dtype=np.float64)
    # Accumulate squared differences one axis at a time. This avoids an n^2
    # Python loop and the (n, n, 3) temporary of full broadcasting.
    for axis_values in coords.T:
        delta = axis_values[:, np.newaxis] - axis_values[np.newaxis, :]
        squared += delta * delta
    return np.sqrt(squared)

In [None]:
def generate_maps(files):
    """Build diagonal C-alpha distance-map chunks for chain A of one PDB file.

    Parameters
    ----------
    files : str or os.PathLike
        Path to a single ``.pdb`` file (despite the plural name).

    Returns
    -------
    list of numpy.ndarray or None
        The chunks produced by ``split_matrix`` from the CA-CA distance matrix
        of the first model's chain ``'A'``. Only residues that have a ``CA``
        atom are used. Returns ``None`` when the structure has no models, or
        when parsing/processing fails; in that case the error is logged.
    """
    parser = PDBParser()
    try:
        structure = parser.get_structure('X', files)
        # Only the first model in the structure is processed.
        model = next(iter(structure), None)
        if model is None:
            return None
        residues = Selection.unfold_entities(model['A'], 'R')
        ca_residues = [residue for residue in residues if 'CA' in residue]
        distance_matrix = calc_dist_matrix(ca_residues)
        return list(split_matrix(distance_matrix))
    except (ValueError, TypeError, KeyError) as err:
        # KeyError covers files without a chain 'A'. Logging and returning None
        # (instead of raising) keeps one bad file from aborting the whole
        # Pool.imap run in main().
        logging.error(f'{type(err).__name__} file :{files}, Error is:{err}')
        return None


In [None]:
def main(files, desc):
    """Generate distance-map chunks for every ``*.pdb`` file in a directory, in parallel.

    Parameters
    ----------
    files : str or os.PathLike
        Directory containing the input ``.pdb`` files.
    desc : str
        Label shown on the tqdm progress bar.

    Returns
    -------
    list
        One entry per ``.pdb`` file, in sorted path order: whatever
        ``generate_maps`` returned for that file.
    """
    # Sort so the dataset order is reproducible; glob order depends on the filesystem.
    pdbs = sorted(pathlib.Path(files).glob("*.pdb"))
    # The context manager tears the pool down even if a worker raises.
    # On success the pool is closed and joined gracefully, as before.
    with mp.Pool(maxtasksperchild=3) as pool:
        results = list(tqdm(pool.imap(generate_maps, pdbs), total=len(pdbs), desc=desc))
        pool.close()
        pool.join()
    return results

In [None]:
def test(result, index=5):
    """Plot a quick visual sanity check of generated maps.

    Draws a histogram of ``len(entry)`` for every entry in ``result``. Then
    draws a heatmap, with a colour bar, of ``result[index]``.

    Parameters
    ----------
    result : sequence
        Entries that support ``len()``. ``result[index]`` must be something
        ``imshow`` can render, presumably a single 2-D distance-map chunk.
    index : int, optional
        Which entry to render as a heatmap (default 5).

    Raises
    ------
    IndexError
        Propagated from ``result[index]`` when ``result`` is too short.
    """
    lengths = [len(entry) for entry in result]
    fig, ax = plt.subplots()
    ax.hist(lengths)
    ax.set(title="Entry lengths", xlabel="len(entry)", ylabel="count")
    plt.show()

    fig, ax = plt.subplots()
    image = ax.imshow(result[index], cmap='viridis')
    fig.colorbar(image, ax=ax)
    ax.set_title(f"Entry {index}")
    plt.show()

In [None]:
RES = 16  # chunk size; read as a global by split_matrix
test_16_per_file = main(test_path, "test_16")
# Flatten the per-file chunk lists, skipping files that failed (None) or were too short ([]).
test_16 = [chunk for file_chunks in test_16_per_file if file_chunks for chunk in file_chunks]
with h5py.File('dataset.hdf5', 'a') as f:
    # Mode 'a' keeps datasets from earlier runs, and create_dataset raises if the
    # name already exists. Drop the stale copy so the cell can be re-run.
    if 'test_16' in f:
        del f['test_16']
    f.create_dataset('test_16', data=test_16, compression="gzip")

In [None]:
RES = 16  # chunk size; read as a global by split_matrix
train_16_per_file = main(train_path, "train_16")
# Flatten the per-file chunk lists, skipping files that failed (None) or were too short ([]).
train_16 = [chunk for file_chunks in train_16_per_file if file_chunks for chunk in file_chunks]
with h5py.File('dataset.hdf5', 'a') as f:
    # Mode 'a' keeps datasets from earlier runs, and create_dataset raises if the
    # name already exists. Drop the stale copy so the cell can be re-run.
    if 'train_16' in f:
        del f['train_16']
    f.create_dataset('train_16', data=train_16, compression="gzip")

In [None]:
RES = 64  # chunk size; read as a global by split_matrix
test_64_per_file = main(test_path, "test_64")
# Flatten the per-file chunk lists, skipping files that failed (None) or were too short ([]).
test_64 = [chunk for file_chunks in test_64_per_file if file_chunks for chunk in file_chunks]
with h5py.File('dataset.hdf5', 'a') as f:
    # Mode 'a' keeps datasets from earlier runs, and create_dataset raises if the
    # name already exists. Drop the stale copy so the cell can be re-run.
    if 'test_64' in f:
        del f['test_64']
    f.create_dataset('test_64', data=test_64, compression="gzip")

In [None]:
RES = 64  # chunk size; read as a global by split_matrix
train_64_per_file = main(train_path, "train_64")
# Flatten the per-file chunk lists, skipping files that failed (None) or were too short ([]).
train_64 = [chunk for file_chunks in train_64_per_file if file_chunks for chunk in file_chunks]
with h5py.File('dataset.hdf5', 'a') as f:
    # Mode 'a' keeps datasets from earlier runs, and create_dataset raises if the
    # name already exists. Drop the stale copy so the cell can be re-run.
    if 'train_64' in f:
        del f['train_64']
    f.create_dataset('train_64', data=train_64, compression="gzip")

In [None]:
RES = 128  # chunk size; read as a global by split_matrix
test_128_per_file = main(test_path, "test_128")
# Flatten the per-file chunk lists, skipping files that failed (None) or were too short ([]).
test_128 = [chunk for file_chunks in test_128_per_file if file_chunks for chunk in file_chunks]
with h5py.File('dataset.hdf5', 'a') as f:
    # Mode 'a' keeps datasets from earlier runs, and create_dataset raises if the
    # name already exists. Drop the stale copy so the cell can be re-run.
    if 'test_128' in f:
        del f['test_128']
    f.create_dataset('test_128', data=test_128, compression="gzip")

In [None]:
RES = 128  # chunk size; read as a global by split_matrix
train_128_per_file = main(train_path, "train_128")
# Flatten the per-file chunk lists, skipping files that failed (None) or were too short ([]).
train_128 = [chunk for file_chunks in train_128_per_file if file_chunks for chunk in file_chunks]
with h5py.File('dataset.hdf5', 'a') as f:
    # Mode 'a' keeps datasets from earlier runs, and create_dataset raises if the
    # name already exists. Drop the stale copy so the cell can be re-run.
    if 'train_128' in f:
        del f['train_128']
    f.create_dataset('train_128', data=train_128, compression="gzip")