In [1]:
import torch
import numpy as np

# local files
from src.util.data_handling.data_loader import save_as_pickle, load_dataset
from src.util.data_handling.string_generator import str_seq_to_num_seq, ALPHABETS
from src.data.edit_distance import cross_distance_matrix_threads

# Main

In [2]:
def load_fasta(source_sequences):
    """Parse a FASTA file into an ``{id: sequence}`` dictionary.

    Every line starting with ``>`` opens a new record; its identifier is the
    rest of that line with surrounding whitespace removed. All following
    non-blank lines up to the next header are concatenated to form the
    record's sequence. Sequences wrapped over several lines and blank lines
    are therefore handled correctly. A strict "header line / sequence line"
    file parses exactly as before.

    Parameters
    ----------
    source_sequences : str or path-like
        Path to the FASTA file, decoded as UTF-8 (an optional BOM is ignored).

    Returns
    -------
    id_to_str_seq : dict[str, str]
        Sequence id -> sequence string, in file order. If an id occurs more
        than once, the last record's sequence is kept.
    length : int
        Length of the longest sequence read (0 for an empty file).

    Raises
    ------
    ValueError
        If sequence data appears before the first ``>`` header.
    """
    records = []  # [(id, [sequence line, ...]), ...] in file order
    # stream line by line instead of readlines(): FASTA dumps can be very large
    with open(source_sequences, encoding='utf-8-sig') as f:
        for line_number, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue  # blank lines must not break header/sequence pairing
            if line.startswith('>'):
                records.append((line[1:].strip(), []))
            elif records:
                records[-1][1].append(line)
            else:
                raise ValueError('{}:{}: sequence data found before the first FASTA header'
                                 .format(source_sequences, line_number))

    id_to_str_seq = {}
    length = 0
    for id_, chunks in records:
        seq = ''.join(chunks)
        id_to_str_seq[id_] = seq
        # max over every record, including ones later overwritten by a duplicate id
        length = max(length, len(seq))

    return id_to_str_seq, length

In [3]:
def split_ids(id_to_str_seq, split_to_size):
    """Partition the ids of ``id_to_str_seq`` into consecutive, non-overlapping splits.

    Ids are taken in the mapping's iteration order. Each split receives the next
    ``size`` of them, following the order of ``split_to_size``. A split gets
    fewer ids (possibly none) when the ids run out.

    Parameters
    ----------
    id_to_str_seq : dict
        Mapping whose keys are the ids to distribute.
    split_to_size : dict[str, int]
        Split name -> number of ids to assign to it.

    Returns
    -------
    dict[str, list]
        Split name -> list of ids, with keys in the order of ``split_to_size``.
    """
    ordered_ids = list(id_to_str_seq)
    result = {}
    offset = 0

    for split_name in split_to_size:
        end = offset + split_to_size[split_name]
        result[split_name] = ordered_ids[offset:end]
        offset = end

    return result

In [4]:
def get_sequences_distances(str_seqs, length, alphabet, n_thread):
    """Encode ``str_seqs`` numerically and compute their pairwise edit distances.

    Parameters
    ----------
    str_seqs : list[str]
        Raw sequences. The same list is used as both sides of the
        cross-distance computation.
    length : int
        Passed as ``length=`` to ``str_seq_to_num_seq`` for every sequence.
    alphabet :
        Passed as ``alphabet=`` to ``str_seq_to_num_seq``.
    n_thread : int
        Thread count forwarded to ``cross_distance_matrix_threads``.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        The distance matrix cast to float32, and the per-sequence encodings
        stacked into an int64 tensor.
    """
    pairwise = cross_distance_matrix_threads(str_seqs, str_seqs, n_thread)
    encoded = [str_seq_to_num_seq(seq, length=length, alphabet=alphabet) for seq in str_seqs]

    # Build a single numpy array first. If the helpers return numpy arrays, this
    # avoids torch's slow per-element path for lists of arrays; values are unchanged.
    distances_tensor = torch.tensor(np.asarray(pairwise), dtype=torch.float32)
    sequences_tensor = torch.tensor(np.asarray(encoded), dtype=torch.long)

    return distances_tensor, sequences_tensor

In [5]:
def edit_distance_approximation_data(split_to_str_seqs, n_thread, alphabet, length, max_sequences=None):
    """Build the edit-distance-approximation (EDA) dataset for every split.

    For each split, the sequences are numerically encoded and their all-pairs
    edit distances are computed with ``get_sequences_distances``.

    Parameters
    ----------
    split_to_str_seqs : dict[str, list[str]]
        Split name (e.g. 'train', 'val', 'test') -> raw string sequences.
    n_thread : int
        Thread count forwarded to the distance computation.
    alphabet :
        Alphabet forwarded to ``get_sequences_distances`` for encoding.
    length : int
        Encoding length forwarded to ``get_sequences_distances``.
    max_sequences : int or None, optional
        If given, only the first ``max_sequences`` sequences of each split are
        used, which is handy for quick debug runs. ``None`` (default) uses every
        sequence, so each split keeps the size chosen upstream.

    Returns
    -------
    sequences : dict[str, torch.Tensor]
        Split name -> encoded sequences.
    distances : dict[str, torch.Tensor]
        Split name -> pairwise distance matrix.
    """
    sequences = {}
    distances = {}

    for split, str_seqs in split_to_str_seqs.items():

        # Subsampling is opt-in: an unconditional cap silently shrinks every split.
        if max_sequences is not None:
            str_seqs = str_seqs[:max_sequences]

        distances_matrix, sequences_matrix = get_sequences_distances(str_seqs, length, alphabet, n_thread)
        print('Shapes: {} distances {} {} sequences {}\n'.format(split, distances_matrix.shape, split, sequences_matrix.shape))

        sequences[split] = sequences_matrix
        distances[split] = distances_matrix

    return sequences, distances

In [6]:
def closest_string_retrieval_data(split_to_str_seqs, n_thread, alphabet, length):
    """Build the closest-string-retrieval (CSR) dataset.

    Each query is labelled with the index of the reference at minimum edit
    distance. Queries whose minimum is reached by more than one reference are
    discarded, so every kept label is unambiguous.

    Parameters
    ----------
    split_to_str_seqs : dict[str, list[str]]
        Must contain the keys 'ref' and 'query'.
    n_thread : int
        Thread count forwarded to ``cross_distance_matrix_threads``.
    alphabet, length :
        Forwarded to ``str_seq_to_num_seq`` for numerical encoding.

    Returns
    -------
    references : torch.Tensor (int64)
        Encodings of all references.
    queries : torch.Tensor (int64)
        Encodings of the queries that have a unique nearest reference.
    labels : torch.Tensor (float32)
        For each kept query, the index of its nearest reference, stored as float.
    """
    ref_strings = split_to_str_seqs['ref']
    query_strings = split_to_str_seqs['query']
    n_queries = len(query_strings)

    def encode(strings):
        # numerical encoding with the shared length/alphabet settings
        return [str_seq_to_num_seq(s, length=length, alphabet=alphabet) for s in strings]

    encoded_refs = encode(ref_strings)
    encoded_queries = encode(query_strings)

    # Rows correspond to references and columns to queries, so all reductions
    # below run over axis 0 (across references, once per query).
    dist = cross_distance_matrix_threads(ref_strings, query_strings, n_thread)
    per_query_min = np.min(dist, axis=0, keepdims=True)

    # Count references within 0.5 of the minimum; a query is kept only when that
    # count is exactly 1. The margin assumes integer distances — TODO confirm.
    n_tied = np.sum((per_query_min + 0.5 > dist).astype(float), axis=0)
    has_unique_nearest = n_tied == 1

    # NOTE(review): the mask has exactly n_queries entries, so [:n_queries] never drops anything.
    nearest_ref = np.argmin(dist, axis=0)[has_unique_nearest][:n_queries]
    kept_queries = np.asarray(encoded_queries)[has_unique_nearest][:n_queries]

    references = torch.from_numpy(np.asarray(encoded_refs)).long()
    queries = torch.from_numpy(kept_queries).long()
    labels = torch.from_numpy(nearest_ref).float()
    print('Shapes: References {} Queries {} Labels {}'.format(references.shape, queries.shape, labels.shape))

    return references, queries, labels

In [7]:
def main(split_to_size, source_sequences, alphabet, n_thread, outdir, compute_eda_data=True, compute_csr_data=True):
    """Load a FASTA file, split its ids, and pickle the requested task datasets.

    The auxiliary data ``(id_to_str_seq, split_to_ids, alphabet, length)`` is
    always saved. The EDA dataset (splits 'train', 'val', 'test') and the CSR
    dataset (splits 'ref', 'query') are computed and saved only when their flag
    is set. Only the splits needed by enabled tasks have to be present in
    ``split_to_size``.

    Parameters
    ----------
    split_to_size : dict[str, int]
        Split name -> number of sequences, passed to ``split_ids``.
    source_sequences : str
        Path of the FASTA file.
    alphabet :
        Alphabet used for numerical encoding.
    n_thread : int
        Thread count for the distance computations.
    outdir : str
        Directory the pickles are written to.
    compute_eda_data, compute_csr_data : bool
        Whether to build and save the EDA / CSR datasets.

    Returns
    -------
    list[str]
        The three output paths, always in the order [auxiliary, EDA, CSR],
        even when a task was skipped and its file not written.
    """
    # 'auxillary' spelling is kept so existing files/loaders keep matching.
    filenames = ['{}/{}.pickle'.format(outdir, suffix) for suffix in ['auxillary_data', 'sequences_distances', 'closest_strings']]

    # load and split data
    id_to_str_seq, length = load_fasta(source_sequences)
    split_to_ids = split_ids(id_to_str_seq, split_to_size)
    save_as_pickle((id_to_str_seq, split_to_ids, alphabet, length), filenames[0])

    def _str_seqs_for(split_names):
        # map each requested split to its raw string sequences
        return {split: [id_to_str_seq[_id] for _id in split_to_ids[split]] for split in split_names}

    # Separate data by task: edit distance approximation (EDA) and closest string
    # retrieval (CSR). Each task's splits are only looked up when that task runs,
    # so a disabled task does not require its split names in split_to_size.
    if compute_eda_data:
        eda_split_to_str_seqs = _str_seqs_for(['train', 'val', 'test'])
        sequences, distances = edit_distance_approximation_data(eda_split_to_str_seqs, n_thread, alphabet, length)
        save_as_pickle((sequences, distances), filenames[1])
    if compute_csr_data:
        csr_split_to_str_seqs = _str_seqs_for(['ref', 'query'])
        references, queries, labels = closest_string_retrieval_data(csr_split_to_str_seqs, n_thread, alphabet, length)
        save_as_pickle((references, queries, labels), filenames[2])

    return filenames

In [8]:
# Input/output locations, relative to the kernel's working directory
source_sequences = '../data/raw/greengenes/gg_13_5.fasta'
outdir = '../data/interim/greengenes'

# Encoding and compute settings
alphabet = ALPHABETS['DNA']
n_thread = 5

# Number of sequence ids assigned to each split; ids are handed out consecutively in file order
split_to_size = {'train': 7000, 'val': 100, 'test': 150, 'ref': 50000, 'query': 500}

# With both task flags off, this run only writes the auxiliary (split/metadata) pickle.
filenames = main(
    split_to_size,
    source_sequences,
    alphabet,
    n_thread,
    outdir,
    compute_eda_data=False,
    compute_csr_data=False,
)