# Create validation dataset 197k

### import modules and load local files

In [25]:
import gzip
import importlib
import kipoiseq
from kipoiseq import Interval
import pyfaidx
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import pickle
import sys
import tensorflow as tf

In [3]:
datadir = "../../../../../data/FED"
#datadir = sys.argv[1]
outputdir = os.path.join(datadir, "hd5")
fasta_file = os.path.join(datadir, "hg38.fa")
human_sequences = os.path.join(datadir, "data_human_sequences.bed")
pyfaidx.Faidx(fasta_file)

Faidx("../../../../../data/FED/hg38.fa")

In [21]:
from collections import Mapping
import json
import functools


def _reduced_shape(shape, axis):
    if axis is None:
        return tf.TensorShape([])
    return tf.TensorShape([d for i, d in enumerate(shape) if i not in axis])

def organism_path(organism, prefix):
    return os.path.join(prefix, organism)

#
def get_dataset(organism, subset, prefix, num_threads=8):

    metadata = get_metadata(organism, prefix)

    dataset = tf.data.TFRecordDataset(tfrecord_files(organism, subset, prefix = prefix),
                                        compression_type='ZLIB',
                                        num_parallel_reads=num_threads)

    dataset = dataset.map(functools.partial(deserialize, metadata=metadata),
                            num_parallel_calls=num_threads)
    return dataset


def get_metadata(organism, prefix):
  # Keys:
  # num_targets, train_seqs, valid_seqs, test_seqs, seq_length,
  # pool_width, crop_bp, target_length
    path = os.path.join(organism_path(organism, prefix), 'statistics.json')
    with tf.io.gfile.GFile(path, 'r') as f:
        return json.load(f)


def tfrecord_files(organism, subset, prefix):
  # Sort the values by int(*).
  return sorted(tf.io.gfile.glob(os.path.join(
      organism_path(organism, prefix), 'tfrecords', f'{subset}-*.tfr'
  )), key=lambda x: int(x.split('-')[-1].split('.')[0]))


def deserialize(serialized_example, metadata):
    """Deserialize bytes stored in TFRecordFile."""
    feature_map = {
          'sequence': tf.io.FixedLenFeature([], tf.string),
          'target': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_example(serialized_example, feature_map)
    sequence = tf.io.decode_raw(example['sequence'], tf.bool)
    sequence = tf.reshape(sequence, (metadata['seq_length'], 4))
    sequence = tf.cast(sequence, tf.float32)

    target = tf.io.decode_raw(example['target'], tf.float16)
    target = tf.reshape(target,
                          (metadata['target_length'], metadata['num_targets']))
    target = tf.cast(target, tf.float32)

    return {'sequence': sequence,
              'target': target}



class FastaStringExtractor:

    def __init__(self, fasta_file):
        self.fasta = pyfaidx.Fasta(fasta_file)
        self._chromosome_sizes = {k: len(v) for k, v in self.fasta.items()}

    def extract(self, interval: Interval, **kwargs) -> str:
        # Truncate interval if it extends beyond the chromosome lengths.
        chromosome_length = self._chromosome_sizes[interval.chrom]
        trimmed_interval = Interval(interval.chrom,
                                    max(interval.start, 0),
                                    min(interval.end, chromosome_length),
                                    )
        # pyfaidx wants a 1-based interval
        sequence = str(self.fasta.get_seq(trimmed_interval.chrom,
                                          trimmed_interval.start + 1,
                                          trimmed_interval.stop).seq).upper()
        # Fill truncated values with N's.
        pad_upstream = 'N' * max(-interval.start, 0)
        pad_downstream = 'N' * max(interval.end - chromosome_length, 0)
        return pad_upstream + sequence + pad_downstream

    def close(self):
        return self.fasta.close()



In [22]:
fasta_extractor = FastaStringExtractor(fasta_file)

In [5]:
# Create dictionary for search (can be improved! quite slow)
#human_validation_dict = {}
#for interval in interval_list: 
#    sequence = one_hot_encode(fasta_extractor.extract(interval))
#    human_validation_dict[interval] = sequence
    
# -------- save
enformer_dict_file = os.path.join(outputdir,'00_enformer_dict_seqs.h5')
#with open(enformer_dict_file, 'wb') as config_dictionary_file:
#    pickle.dump(human_validation_dict, config_dictionary_file)
    
# -------- read 
with open(enformer_dict_file, 'rb') as config_dictionary_file:
    human_validation_dict = pickle.load(config_dictionary_file)
    

In [6]:
print("Number of sequences in dictionary")
print(len(human_validation_dict.keys()))

Number of sequences in dictionary
2213


In [7]:
def get_interval_from_sequence(sequence, human_validation_dict=human_validation_dict): 
    for interval, ref_sequence in human_validation_dict.items():
        if np.allclose(sequence,ref_sequence):
            return(interval)

In [48]:
# Create dataset with older sequences 
human_dataset = get_dataset('human', 'valid', "/home/luisasantus/Desktop/crg_cluster/data/FED/basenji/")

In [43]:
200%100

0

In [49]:
num_training_examples = 0
num_validation_examples = 0

for example in human_dataset:
    num_training_examples += 1
    if(num_training_examples % 100 == 0 ):
        print(num_training_examples)

100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200


In [34]:
num_training_examples

2214

In [47]:
## Create new dataset
dataset_197k = []
NEW_SEQUENCE_LENGTH = 196_608
max_steps = 2

for i, batch in tqdm(enumerate(human_dataset)):
    batch_197k = {}
    # 1 from the sequence 131k get the sequence 197k
    interval_test = get_interval_from_sequence(batch["sequence"])
    sequence_197k = one_hot_encode(fasta_extractor.extract(interval_test.resize(NEW_SEQUENCE_LENGTH)))
    batch_197k["sequence"] = tf.constant(sequence_197k[np.newaxis])
    
    # add same real targets
    batch_197k["target"] = batch["target"]
    dataset_197k.append(batch_197k)
    if max_steps is not None and i > max_steps:
        break

file = os.path.join(outputdir,'new_dataset_197k_valid.h5')

# Step 2
with open(file, 'wb') as config_dictionary_file:
    pickle.dump(dataset_197k, config_dictionary_file)


0it [00:00, ?it/s][A


NameError: name 'one_hot_encode' is not defined