In [None]:
import itertools
import os
import time

from Bio import SeqIO
import h5py
import numpy as np

In [None]:
# Input FASTA files: one for the prokaryote class and one for the phage class.
# NOTE(review): these are absolute, machine-specific paths; edit them before
# running the notebook on another machine.
input_prokaryote_fasta_fp = '/home/jklynch/host/project/viral-learning/data/500_ArcFake_training_set.fasta'
input_phage_fasta_fp = '/home/jklynch/host/project/viral-learning/data/500_ArcPhage_training_set.fasta'

In [None]:
# Destination HDF5 file for the generated training images.
# NOTE(review): absolute, machine-specific path; edit before running elsewhere.
output_training_h5_fp = '/home/jklynch/host/project/viral-learning/data/500_phage_prok_cnn_training.h5'

In [None]:
def count_fasta_sequences(fasta_fp):
    """
    Count the records in a FASTA file.

    A record is counted for every line that begins with '>'; all other
    lines are ignored.

    fasta_fp is the path of the FASTA file, opened in text mode.

    Returns the number of header lines found (0 for an empty file).
    """
    with open(fasta_fp, 'rt') as fasta_file:
        return sum(1 for fasta_line in fasta_file if fasta_line.startswith('>'))

In [None]:
def first_fasta_sequence_length(fasta_fp):
    """
    Return the length of the first sequence in a FASTA file.

    Only the first record produced by SeqIO.parse is examined.

    Returns None when the file yields no records.
    """
    first_record = next(iter(SeqIO.parse(fasta_fp, "fasta")), None)
    if first_record is None:
        return None
    return len(first_record.seq)

In [None]:
def create_dataset(h5_file, dset_name, sample_count, sequence_length, im_height):
    """
    Create a float64, gzip-compressed dataset for a stack of sequence images.

    Each image has shape (im_height, sequence_length - im_height + 1, 4).
    The dataset is created with sample_count images, one image per chunk,
    and an unlimited first axis (maxshape None).

    Returns whatever h5_file.create_dataset returns.
    """
    im_width = sequence_length - im_height + 1
    image_shape = (im_height, im_width, 4)
    dataset_options = dict(
        name=dset_name,
        shape=(sample_count,) + image_shape,
        maxshape=(None,) + image_shape,
        dtype=np.float64,
        chunks=(1,) + image_shape,
        compression='gzip',
        compression_opts=9)
    return h5_file.create_dataset(**dataset_options)

In [None]:
def get_start_stop(N, M):
    """
    Build the table of window boundaries for turning a sequence into an image.

    N is sequence length
    M is image height (number of rows / windows)

    Row i of the result is (start, stop) = (i, N-M+i), where stop is the index
    of the last base shown for that row in the illustration below.

    N=13
    M=4
    1 2 3 4 5 6 7 8 9 10 11 12 13
    a c g t a c g t a c  g  t  a        start stop
                                        i     N-M+i
    a1 c2 g3 t4 a5 c6 g7  t8  a9  c10   0     9
    c2 g3 t4 a5 c6 g7 t8  a9  c10 g11   1     10
    g3 t4 a5 c6 g7 t8 a9  c10 g11 t12   2     11
    t4 a5 c6 g7 t8 a9 c10 g11 t12 a13   3     12

    Returns an (M, 2) int32 array. Nothing is printed; the table can be large
    (one row per image row), so callers that want to inspect it should print
    the return value themselves.
    """
    row_offsets = np.arange(M, dtype=np.int32)
    start_stop = np.empty((M, 2), dtype=np.int32)
    start_stop[:, 0] = row_offsets
    start_stop[:, 1] = (N - M) + row_offsets
    return start_stop

In [None]:
def get_2D_sequence(seq, start_stop_indices):
    """
    Cut a sequence into overlapping windows, one per (start, stop) pair.

    Both start and stop are inclusive 0-based indices, so each window holds
    stop - start + 1 characters.

    seq=acgtacgtacgta
    start_stop_indices=((0,9),(1,10),(2,11),(3,12))

    seq_2d=(
        (acgtacgtac),
        (cgtacgtacg),
        (gtacgtacgt),
        (tacgtacgta)
    )

    Returns a tuple of tuples of single characters, in the order of
    start_stop_indices.
    """
    # stop is inclusive, hence the +1: a plain seq[start:stop] slice would
    # drop the last base of every window.
    return tuple(
        tuple(seq[start:stop + 1])
        for start, stop in start_stop_indices)

In [None]:
# One-hot channel encoding of the four unambiguous bases: the base at
# position k of 'ACGT' maps to a 4-element list with 1.0 at index k and 0.0
# elsewhere. Any other character has no entry, so a lookup raises KeyError.
nucleotide_to_channels = {
    base: [1.0 if channel == position else 0.0 for channel in range(4)]
    for position, base in enumerate('ACGT')}
    # Unused draft of a 5-channel encoding for other IUPAC codes:
    #'U':[0.00, 0.00, 0.00, 0.00, 0.00],
    #'N':[0.20, 0.20, 0.20, 0.20, 0.20],
    #'R':[0.50, 0.00, 0.50, 0.00, 0.00],
    #'M':[0.50, 0.50, 0.00, 0.00, 0.00], # A or C
    #'S':[0.00, 0.50, 0.50, 0.00, 0.00], # C or G
    #'K':[0.00, 0.00, 0.333, 0.333, 0.333], # G, T, or U
    #'W':[0.333, 0.00, 0.00, 0.333, 0.333], # A, T, or U
    #'Y':[0.00, 0.333, 0.00, 0.333, 0.333]} # C, T, or U

def translate_seq_to_training_input(seq, M, start_stop_indices, verbose=False):
    """
    Encode a sequence as an image of channel vectors.

    seq is the sequence (indexable, one character per base)
    M is image height
    start_stop_indices is passed through to get_2D_sequence to obtain one
        window of seq per image row
    verbose prints each window and its encoded row when true

    Returns a zero-initialised array of shape (M, len(seq) - M + 1, 4) in
    which cell [row, col, :] holds the nucleotide_to_channels entry for
    character col of window row; cells that no window character reaches
    stay zero.

    Raises KeyError for a character that has no entry in
    nucleotide_to_channels.
    """
    seq_len = len(seq)
    image = np.zeros((M, seq_len - M + 1, 4))
    windows = get_2D_sequence(seq, start_stop_indices=start_stop_indices)
    for row_index in range(len(windows)):
        window = windows[row_index]
        for col_index, base in enumerate(window):
            image[row_index, col_index, :] = nucleotide_to_channels[base]
        if verbose:
            print(window)
            print(image[row_index, :, :])

    return image



In [None]:
def get_images(fasta_fp, seq_length, im_height, im_limit):
    """
    Generate one image per usable record of a FASTA file.

    fasta_fp is the FASTA file to read
    seq_length is the required sequence length; records of any other length
        are reported and skipped
    im_height is the image height handed to get_start_stop and
        translate_seq_to_training_input
    im_limit is the maximum number of records READ from the file (None reads
        them all); skipped records count against it, so fewer than im_limit
        images may be produced

    Records for which translation raises KeyError (a character without an
    entry in nucleotide_to_channels) are reported and skipped.

    Yields the arrays returned by translate_seq_to_training_input, in file
    order. A timing line is printed after every 100 images yielded; the
    interval includes the time the consumer spends between images.
    """
    print('sequence length : {}'.format(seq_length))

    start_stop_indices = get_start_stop(seq_length, im_height)

    # image_count counts images yielded; r is the input record index.
    # They differ as soon as a record is skipped.
    image_count = 0
    t0 = time.time()
    records = itertools.islice(SeqIO.parse(fasta_fp, "fasta"), im_limit)
    for r, record in enumerate(records):
        if len(record.seq) != seq_length:
            print('{} record.seq length: {} != {}'.format(r, len(record.seq), seq_length))
            continue

        # Keep only the translation inside the try so that a KeyError raised
        # elsewhere is never mistaken for an ambiguous base.
        try:
            image = translate_seq_to_training_input(
                seq=str(record.seq),
                start_stop_indices=start_stop_indices,
                M=im_height)
        except KeyError:
            print('found a sequence with ambiguous base')
            continue

        image_count += 1
        yield image

        # Report once per 100 images actually produced (never on skipped
        # records), and report the elapsed time rather than the record index.
        if image_count % 100 == 0:
            print('finished {} images; last 100 took {:5.2f}s'.format(image_count, time.time() - t0))
            t0 = time.time()



In [None]:
def _write_images_dataset(h5_file, fasta_fp, seq_count, seq_length, im_height, im_limit):
    """Write the images of one FASTA file to a new dataset and trim it to the rows written."""
    dset = create_dataset(
        h5_file=h5_file,
        dset_name=os.path.basename(fasta_fp),
        sample_count=seq_count,
        sequence_length=seq_length,
        im_height=im_height)

    # Count the rows explicitly: the loop index of the last image is one less
    # than the number written, and is undefined when no image is produced.
    written = 0
    images = get_images(fasta_fp=fasta_fp, im_height=im_height, seq_length=seq_length, im_limit=im_limit)
    for seq_image in images:
        dset[written, :, :, :] = seq_image
        written += 1

    # Drop the unused tail (records that were skipped or beyond im_limit).
    dset.resize((written,) + tuple(dset.shape[1:]))
    return dset


def write_phage_prok_cnn_training_file(input_phage_fp, input_prok_fp, output_h5_fp, im_height, im_limit=None):
    """
    Convert a phage and a prokaryote FASTA file into one HDF5 training file.

    input_phage_fp, input_prok_fp are the FASTA inputs; each becomes a
        dataset named after the file's basename
    output_h5_fp is the HDF5 file to write; an existing file is replaced
    im_height is the image height
    im_limit is the maximum number of records read from each input
        (None reads them all)

    Each dataset ends up with exactly as many rows as images were written.

    Raises ValueError if either input has no sequences or if the first
    sequences of the two inputs differ in length.
    """
    phage_seq_count = count_fasta_sequences(fasta_fp=input_phage_fp)
    print('{} sequences in file "{}"'.format(phage_seq_count, input_phage_fp))

    prok_seq_count = count_fasta_sequences(fasta_fp=input_prok_fp)
    print('{} sequences in file "{}"'.format(prok_seq_count, input_prok_fp))

    phage_seq_length = first_fasta_sequence_length(fasta_fp=input_phage_fp)
    prok_seq_length = first_fasta_sequence_length(fasta_fp=input_prok_fp)

    # An empty input gives a length of None; fail clearly instead of with a
    # TypeError further down.
    if phage_seq_length is None or prok_seq_length is None:
        raise ValueError('phage or prokaryote file contains no sequences')

    # ValueError is a subclass of Exception, so existing handlers still work.
    if phage_seq_length != prok_seq_length:
        raise ValueError('phage and prokaryote sequence lengths are different')

    seq_length = phage_seq_length
    print('phage and prokaryote sequence length is {}'.format(seq_length))
    print('image height : {}'.format(im_height))
    print('image width  : {}'.format(seq_length - im_height + 1))

    # Only remove a previous output if there is one; an unconditional
    # os.remove fails on the very first run.
    if os.path.exists(output_h5_fp):
        os.remove(output_h5_fp)

    with h5py.File(output_h5_fp, 'w') as h5_file:
        for fasta_fp, seq_count in ((input_phage_fp, phage_seq_count), (input_prok_fp, prok_seq_count)):
            _write_images_dataset(
                h5_file=h5_file,
                fasta_fp=fasta_fp,
                seq_count=seq_count,
                seq_length=seq_length,
                im_height=im_height,
                im_limit=im_limit)

        


In [None]:
# Build the training file from the configured inputs with 100-row images.
# im_limit=200 caps the number of records read from each FASTA file.
write_phage_prok_cnn_training_file(
    input_phage_fp=input_phage_fasta_fp,
    input_prok_fp=input_prokaryote_fasta_fp,
    output_h5_fp=output_training_h5_fp,
    im_height=100,
    im_limit=200)