In [1]:
import sys
sys.path.append("..")

import numpy as np
from Bio import SeqIO
import os
import json
import tensorflow as tf

from tf_tcmc.tcmc.tcmc import TCMCProbability
import utilities.onehot_tuple_encoder as ote
from utilities import database_reader
from utilities import msa_converter

## Declare the dataset location and the wanted splits, then load them

In [2]:
# --- Input location -----------------------------------------------------
# Directory and file basename handed to database_reader.get_datasets below.
folder = '../msa/'
basename = 'augustus_flies'

# Tree file(s) for the clade(s) (presumably Newick, judging by the .nwk
# extension); the number of leaves is derived from them.
forest = ['../clades/flies.nwk']
num_leaves = database_reader.num_leaves(forest)

# --- Alphabet -----------------------------------------------------------
# An entry spans three characters when codons are used, otherwise one.
used_codons = True
if used_codons:
    entry_length = 3
else:
    entry_length = 1
# Four symbols per position, hence 4 ** entry_length possible entries.
alphabet_size = 4 ** entry_length

# --- Split specifications -----------------------------------------------
# NOTE(review): interweave_models / repeat_models presumably control the
# mixing proportions and repetition of the two models in the training
# stream -- confirm against database_reader.DatasetSplitSpecification.
train_split = database_reader.DatasetSplitSpecification(
    name='train',
    wanted_models=[0, 1],
    interweave_models=[0.62, 0.38],
    repeat_models=[False, True],
)

# Validation and test use the same specification apart from the name.
val_split, test_split = (
    database_reader.DatasetSplitSpecification(name=split_name, wanted_models=[0, 1])
    for split_name in ('val', 'test')
)

wanted_splits = [train_split, val_split, test_split]

# --- Load ---------------------------------------------------------------
# Gather the dict of splits, keyed by split name.
# NOTE(review): seed=None presumably leaves the shuffling unseeded --
# consider a fixed seed if reproducible runs are required.
datasets = database_reader.get_datasets(
    folder,
    basename,
    wanted_splits,
    num_leaves=num_leaves,
    alphabet_size=alphabet_size,
    seed=None,
    buffer_size=1000,
)
print(datasets)

{'train': <ShuffleDataset shapes: ((), (), (), (None, None, None)), types: (tf.int64, tf.int64, tf.int64, tf.float64)>, 'val': {0: <ParallelMapDataset shapes: ((), (), (), (None, None, None)), types: (tf.int64, tf.int64, tf.int64, tf.float64)>, 1: <ParallelMapDataset shapes: ((), (), (), (None, None, None)), types: (tf.int64, tf.int64, tf.int64, tf.float64)>}, 'test': {0: <ParallelMapDataset shapes: ((), (), (), (None, None, None)), types: (tf.int64, tf.int64, tf.int64, tf.float64)>, 1: <ParallelMapDataset shapes: ((), (), (), (None, None, None)), types: (tf.int64, tf.int64, tf.int64, tf.float64)>}}


In [3]:
# Debugging aid: pull a single entry from the training split and show it.
for t in datasets[train_split.name].take(1):
    # Each entry is a 4-tuple; unpack it into named fields.
    (model,
     clade_id,
     sequence_length,
     sequence_onehot) = t

    print(f'model: {model}')
    print(f'clade_id: {clade_id}')
    print(f'sequence_length: {sequence_length}')
    print(f'sequence_onehot.shape: {sequence_onehot.shape}')

    # Swap the first two axes to restore the layout the decoder expects
    # (the original shape of the entry before it was stored).
    S = tf.transpose(sequence_onehot, perm=[1, 0, 2])

    # Decode the one-hot tensor back to text and show the leading columns.
    dec = ote.OnehotTupleEncoder.decode_tfrecord_entry(
        S.numpy(),
        tuple_length=entry_length,
    )
    print(f'first (up to) 8 alignment columns of decoded reshaped sequence: \n{dec[:,:8]}')

model: 1
clade_id: 0
sequence_length: 129
sequence_onehot.shape: (129, 12, 64)
first (up to) 8 alignment columns of decoded reshaped sequence: 
[['---' '---' '---' '---' '---' '---' '---' '---']
 ['---' '---' '---' '---' '---' '---' '---' '---']
 ['---' '---' '---' '---' '---' '---' '---' '---']
 ['atg' 'tgt' 'tcc' 'aaa' 'cta' 'act' 'ctt' 'ttc']
 ['atg' 'tgt' 'tcc' 'aaa' 'ctc' 'acc' 'ctt' 'ttc']
 ['---' '---' '---' '---' '---' '---' '---' '---']
 ['---' '---' '---' '---' '---' '---' '---' '---']
 ['atg' 'tgt' 'tcc' 'aaa' 'ctc' 'gtt' 'ctt' 'ttc']
 ['---' '---' '---' '---' '---' '---' '---' '---']
 ['---' '---' '---' '---' '---' '---' '---' '---']
 ['---' '---' '---' '---' '---' '---' '---' '---']
 ['---' '---' '---' '---' '---' '---' '---' '---']]
