# Create validation dataset 197k

### import modules and load local files

In [218]:
import tensorflow as tf
import tensorflow_hub as hub
import joblib
import gzip
import kipoiseq
from kipoiseq import Interval
import pyfaidx
import statistics
import pandas as pd
import numpy as np

import os
from tqdm import tqdm
import importlib.util
import inspect
from typing import Any, Callable, Dict, Optional, Text, Union, Iterable
import attention_module
import sonnet as snt
import pickle


In [219]:
transform_path = 'gs://dm-enformer/models/enformer.finetuned.SAD.robustscaler-PCA500-robustscaler.transform.pkl'
model_path = 'https://tfhub.dev/deepmind/enformer/1'
datadir = "../../../../data/FED"
outputdir = os.path.join(datadir, "hd5")
fasta_file = os.path.join(datadir, "hg38.fa")
human_sequences = os.path.join(datadir, "data_human_sequences.bed")
pyfaidx.Faidx(fasta_file)

Faidx("../../../../data/FED/hg38.fa")

In [220]:
# import utils.py as module
spec_utils = importlib.util.spec_from_file_location("enformer", os.path.join(os.getcwd() ,"utils.py"))
utils = importlib.util.module_from_spec(spec_utils)
spec_utils.loader.exec_module(utils)
from utils import * 

In [221]:
model = Enformer(model_path)

In [222]:
# import enformer.py as module
spec = importlib.util.spec_from_file_location("enformer", os.path.join(os.getcwd() ,"enformer.py"))
enformer = importlib.util.module_from_spec(spec)
spec.loader.exec_module(enformer)
from enformer import * 

In [223]:
fasta_extractor = FastaStringExtractor(fasta_file)

In [224]:
df = pd.read_csv(human_sequences, memory_map=True, header=None, index_col=False, delimiter="\t")
# keep only validation intervals 
validation_intervals= df[df[3]=="valid"]
#validation_intervals = validation_intervals.head()
# create list with interval
interval_list = list()
validation_intervals.apply(lambda row : interval_list.append(kipoiseq.Interval(row[0],row[1], row[2])), axis = 1)

34021    None
34022    None
34023    None
34024    None
34025    None
         ... 
36229    None
36230    None
36231    None
36232    None
36233    None
Length: 2213, dtype: object

In [225]:
validation_intervals

Unnamed: 0,0,1,2,3
34021,chr6,165740202,165871274,valid
34022,chrX,55044496,55175568,valid
34023,chrX,84489673,84620745,valid
34024,chrX,26382093,26513165,valid
34025,chr7,2304644,2435716,valid
...,...,...,...,...
36229,chrX,16977595,17108667,valid
36230,chr20,45038994,45170066,valid
36231,chrX,24547069,24678141,valid
36232,chr2,235793611,235924683,valid


In [226]:
# Create dictionary for search (can be improved! quite slow)
#human_validation_dict = {}
#for interval in interval_list: 
#    sequence = one_hot_encode(fasta_extractor.extract(interval))
#    human_validation_dict[interval] = sequence
    
# -------- save
enformer_dict_file = os.path.join(outputdir,'00_enformer_dict_seqs.h5')
#with open(enformer_dict_file, 'wb') as config_dictionary_file:
#    pickle.dump(human_validation_dict, config_dictionary_file)
    
# -------- read 
with open(enformer_dict_file, 'rb') as config_dictionary_file:
    human_validation_dict = pickle.load(config_dictionary_file)
    

In [227]:
print("Number of sequences in dictionary")
print(len(human_validation_dict.keys()))

Number of sequences in dictionary
2213


In [228]:
def get_interval_from_sequence(sequence, human_validation_dict=human_validation_dict): 
    for interval, ref_sequence in human_validation_dict.items():
        if np.allclose(sequence,ref_sequence):
            return(interval)

In [229]:
# Create dataset with older sequences 
human_dataset = get_dataset('human', 'valid').batch(1).repeat()

In [230]:
## Create new dataset
dataset_197k = []
NEW_SEQUENCE_LENGTH = 196_608
max_steps = 2

for i, batch in tqdm(enumerate(human_dataset)):
    batch_197k = {}
    # 1 from the sequence 131k get the sequence 197k
    interval_test = get_interval_from_sequence(batch["sequence"])
    sequence_197k = one_hot_encode(fasta_extractor.extract(interval_test.resize(NEW_SEQUENCE_LENGTH)))
    batch_197k["sequence"] = tf.constant(sequence_197k[np.newaxis])
    
    # add same real targets
    batch_197k["target"] = batch["target"]
    dataset_197k.append(batch_197k)
    if max_steps is not None and i > max_steps:
        break

file = os.path.join(outputdir,'new_dataset_197k_valid.h5')

# Step 2
with open(file, 'wb') as config_dictionary_file:
    pickle.dump(dataset_197k, config_dictionary_file)

3it [00:09,  3.16s/it]


In [231]:
dataset_197k[1]["sequence"]

<tf.Tensor: shape=(1, 196608, 4), dtype=float32, numpy=
array([[[0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        ...,
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.]]], dtype=float32)>