In [1]:
import tensorflow as tf
import numpy as np
from readData_from_TFRec import widen_seq, parse_dataset, create_protein_batches
from utils import expand_dim, calc_pairwise_distances, to_distogram, load_npy_binary, output_to_distancemaps
from utils import pad_mask, pad_primary, pad_tertiary, masked_categorical_cross_entropy
import glob
import math
import tensorflow.keras.backend as K
from readData_from_TFRec import parse_tfexample

In [2]:
def generator_transform(primary, evolutionary, tertiary, tertiary_mask, stride, padding_value=-1, minimum_bin_val=2, 
                        maximum_bin_val=22, num_bins=64):
    """Pad one protein record to a multiple of ``stride`` and build model samples.

    The record is padded at the end of every axis, the tertiary structure is
    replaced by its pairwise-distance representation, and the result is either
    cut into crops (sequence at least ``stride`` long) or returned whole.

    Args:
        primary: tensor whose first dimension is the sequence length.
        evolutionary: accepted for signature compatibility; not used here.
        tertiary: input handed to ``calc_pairwise_distances``.
        tertiary_mask: mask tensor; padded with zeros.
        stride: crop size; coerced to ``int``.
        padding_value: fill value for ``primary`` and the distance tensor;
            coerced to ``float`` and then cast to each tensor's dtype.
        minimum_bin_val, maximum_bin_val, num_bins: forwarded to
            ``to_distogram`` as ``min_val``, ``max_val`` and ``num_bins``.

    Returns:
        A list of ``(primary, distogram, mask)`` tuples: the (modified) list
        returned by ``create_protein_batches`` when at least one crop fits,
        otherwise a single-element list holding the whole padded record.
    """
    # Coerce up front: callers may hand these over as floats or other scalars.
    stride = int(stride)
    padding_value = float(padding_value)
    minimum_bin_val = float(minimum_bin_val)
    maximum_bin_val = float(maximum_bin_val)
    num_bins = int(num_bins)

    seq_len = primary.shape[0]

    # Number of crops along one axis: full strides, plus one more for a
    # partial tail -- but only when at least one full stride fits.
    crops_per_seq = seq_len // stride
    has_partial_crop = seq_len % stride > 0
    if crops_per_seq > 0 and has_partial_crop:
        crops_per_seq += 1
    total_crops = crops_per_seq ** 2

    # Elements missing to reach the next multiple of ``stride``.
    num_padding = math.ceil(seq_len / stride) * stride - seq_len
    # (before, after) padding for a single axis: nothing in front, the rest at the end.
    trailing_pad = tf.constant([[0, num_padding]])

    def _pad_every_axis(tensor, fill):
        # Repeat the single-axis spec once per dimension of ``tensor``.
        per_axis = tf.repeat(trailing_pad, tf.rank(tensor).numpy(), axis=0)
        return tf.pad(tensor, per_axis, constant_values=fill)

    # Primary: pad, widen (presumably to a 2D representation -- see widen_seq),
    # then cast to the Keras float type for the model.
    primary = _pad_every_axis(primary, tf.cast(padding_value, primary.dtype))
    primary = K.cast_to_floatx(widen_seq(primary))

    # Tertiary: pairwise distances first, then the same trailing padding.
    tertiary = calc_pairwise_distances(tertiary)
    tertiary = _pad_every_axis(tertiary, tf.cast(padding_value, tertiary.dtype))

    # Mask: padded positions are filled with 0.
    tertiary_mask = _pad_every_axis(tertiary_mask, 0)

    bin_kwargs = dict(min_val=minimum_bin_val, max_val=maximum_bin_val, num_bins=num_bins)

    if not total_crops > 0:
        # Sequence shorter than one stride: bin the whole padded record.
        whole = to_distogram(tertiary, **bin_kwargs)
        whole = tf.convert_to_tensor(whole, dtype=whole.dtype)
        return [(primary, K.cast_to_floatx(whole), tertiary_mask)]

    # Crop, then replace each crop's distance entry by its distogram in place.
    batches = create_protein_batches(primary, tertiary, tertiary_mask, stride)
    for idx, crop in enumerate(batches):
        binned = to_distogram(crop[1], **bin_kwargs)
        binned = tf.convert_to_tensor(binned, dtype=tertiary.dtype)
        batches[idx] = (crop[0], K.cast_to_floatx(binned), crop[2])
    return batches

In [3]:
def parse_dataset_test(file_paths, parameters):
    """Lazily yield one ``(primary, tertiary, tertiary_mask)`` tuple per record.

    Every record of the TFRecord files in ``file_paths`` is parsed with
    ``parse_tfexample`` and yielded untransformed; the evolutionary component
    is dropped. Because this is a generator, records are read only as they
    are consumed.

    ``parameters`` may be a dict or a NumPy array of ``[name, value]`` rows
    with byte-string names. An array is decoded into a ``{str: float}`` dict,
    which is currently not used any further: the ``generator_transform`` call
    that would consume ``stride``, ``padding_value``, ``minimum_bin_val``,
    ``maximum_bin_val`` and ``num_bins`` is disabled in this test variant.
    """
    if isinstance(parameters, np.ndarray):
        parameters = {
            row[0].decode('ascii'): float(row[1])
            for row in parameters.tolist()
        }

    for record in tf.data.TFRecordDataset(file_paths):
        primary, evolutionary, tertiary, ter_mask = parse_tfexample(record)
        # One sample per record, passed through as parsed.
        yield (primary, tertiary, ter_mask)

In [4]:
# TFRecord files of the CASP7 training split ("100" subset directory).
path = glob.glob("../proteinnet/data/casp7/training/100/*")

# Cropping / distogram settings handed to the dataset generator.
params = dict(
    stride=64,
    padding_value=-1,
    minimum_bin_val=2,
    maximum_bin_val=22,
    num_bins=64,
)

# Flatten the settings into [name, value] rows so they can be shipped as an array.
items = [[name, value] for name, value in params.items()]
path


['../proteinnet/data/casp7/training/100/13',
 '../proteinnet/data/casp7/training/100/16',
 '../proteinnet/data/casp7/training/100/41',
 '../proteinnet/data/casp7/training/100/9',
 '../proteinnet/data/casp7/training/100/6',
 '../proteinnet/data/casp7/training/100/115',
 '../proteinnet/data/casp7/training/100/40',
 '../proteinnet/data/casp7/training/100/104',
 '../proteinnet/data/casp7/training/100/33',
 '../proteinnet/data/casp7/training/100/22',
 '../proteinnet/data/casp7/training/100/93',
 '../proteinnet/data/casp7/training/100/54',
 '../proteinnet/data/casp7/training/100/61',
 '../proteinnet/data/casp7/training/100/105',
 '../proteinnet/data/casp7/training/100/109',
 '../proteinnet/data/casp7/training/100/92',
 '../proteinnet/data/casp7/training/100/38',
 '../proteinnet/data/casp7/training/100/2',
 '../proteinnet/data/casp7/training/100/4',
 '../proteinnet/data/casp7/training/100/112',
 '../proteinnet/data/casp7/training/100/24',
 '../proteinnet/data/casp7/training/100/116',
 '../pro

In [7]:
# Wrap the record generator in a tf.data pipeline. `args` are forwarded to
# parse_dataset_test; each element is a (primary, tertiary, mask) triple of
# float32 tensors declared as rank 1, rank 2 and rank 2 with unknown sizes.
ds_counter = tf.data.Dataset.from_generator(
    parse_dataset_test,
    args=[path, np.array(items)],
    output_types=(tf.float32, tf.float32, tf.float32),
    output_shapes=((None,), (None, None), (None, None)),
)

In [10]:
# Exhaust the dataset once to count how many samples it produces.
count = sum(1 for _ in ds_counter)
count

34557