In [16]:
import gzip
import numpy as np
import logging

from collections import deque


logger = logging.getLogger(__name__)


# Longest fragment found in the full fragment file, measured with:
# zcat EE87920.hg38.frag.gz | awk -F'\t' '{print $3 - $2}' | sort -nr | head -n 1 -> 262
MAX_OBSERVED_FRAGMENT_LENGTH = 262
# Safety margin applied on top of the longest observed fragment (1.5 = 50% extra rows)
MATRIX_ROWS_HEADROOM = 1.5
MATRIX_ROWS = int(MAX_OBSERVED_FRAGMENT_LENGTH * MATRIX_ROWS_HEADROOM)
MATRIX_COLUMNS = 2000
MATRIX_COLUMNS_HALF = MATRIX_COLUMNS // 2


# Gzipped test inputs, relative to the notebook's working directory
TEST_DATA = "../../raw_data/test/EE87920_10lines.hg38.frag.gz"
TEST_DHS_DATA = "../../raw_data/test/DHS_test.gz"


def read_dhs_to_memory(path=None):
    """Load DHS midpoints from a gzipped, tab-separated ``start<TAB>end`` file.

    Args:
        path: Path of the gzip file to read. When omitted, ``TEST_DHS_DATA``
            is used, which matches the original no-argument behaviour.

    Returns:
        A ``(sites, count)`` tuple. ``sites`` is a deque holding the integer
        midpoint ``(start + end) // 2`` of every record in file order, and
        ``count`` is the number of midpoints loaded.

    Raises:
        ValueError: If a non-blank line does not contain exactly two
            tab-separated integers.
    """
    if path is None:
        path = TEST_DHS_DATA

    sites = deque()
    with gzip.open(path, 'rt') as f:
        for line in f:
            # A blank line carries no site; skip it instead of failing on the unpack below
            if not line.strip():
                continue
            start, end = line.split('\t')
            sites.append((int(start) + int(end)) // 2)
    return sites, len(sites)
    

def get_curr_dhs() -> tuple:
    """Pop the next DHS midpoint off the global queue and describe its window.

    Returns:
        ``(midpoint, window_start, window_end, index)``. The window extends
        ``MATRIX_COLUMNS_HALF`` to either side of the midpoint, and ``index``
        is ``initial_DHS_length`` minus the number of sites still queued,
        minus one. When the queue is empty, four ``None`` values are returned.
    """
    if DHS_sites:
        midpoint = DHS_sites.popleft()
        window_start = midpoint - MATRIX_COLUMNS_HALF
        window_end = midpoint + MATRIX_COLUMNS_HALF
        # Computed after the pop, so the remaining count no longer includes this site
        site_index = initial_DHS_length - len(DHS_sites) - 1
        return midpoint, window_start, window_end, site_index

    return None, None, None, None

def parse_fragment(line: str) -> tuple:
    """Extract the integer start and end coordinates from one fragment record.

    Args:
        line: A tab-separated record whose second and third fields hold the
            start and end coordinates.

    Returns:
        ``(start, end)`` as integers.

    Raises:
        ValueError: If the record lacks a second or third field, or either
            one is not an integer.
    """
    fields = line.strip().split('\t')
    begin, finish = (int(value) for value in fields[1:3])
    return begin, finish

# Load every DHS midpoint up front; get_curr_dhs() consumes this queue one site at a time.
DHS_sites, initial_DHS_length = read_dhs_to_memory()

# One count matrix per DHS site: rows are indexed by fragment length,
# columns by position inside the window around the site's midpoint.
result = np.zeros((len(DHS_sites), MATRIX_ROWS, MATRIX_COLUMNS))
curr_dhs, curr_dhs_start, curr_dhs_end, curr_dhs_index = get_curr_dhs()

# NOTE(review): only start/end coordinates are compared and the DHS cursor only
# moves forward, so this presumably expects both inputs to be coordinate-sorted
# and to cover a single chromosome - confirm before running on full data.
with gzip.open(TEST_DATA, 'rt') as f:
    for line in f:
        start, end = parse_fragment(line)
        fragment_length = end - start

        # if the fragment is too long skip and log it for now
        if fragment_length >= MATRIX_ROWS:
            logger.warning(f'Skipped fragment due to too high length:\nstart:{start}\nend:{end}')
            continue

        # Advance to the first DHS window that ends after this fragment.
        # Compare against None explicitly: a midpoint of 0 is a valid site and
        # must not be mistaken for "queue exhausted".
        while curr_dhs is not None and end >= curr_dhs_end:
            curr_dhs, curr_dhs_start, curr_dhs_end, curr_dhs_index = get_curr_dhs()

        # Nothing left to assign fragments to; report it once and stop reading.
        if curr_dhs is None:
            logger.warning('No more DHS sites')
            break

        # Only fragments lying entirely inside the current window are counted.
        if start >= curr_dhs_start and end <= curr_dhs_end:
            rel_start = start - curr_dhs_start
            rel_end = end - curr_dhs_start

            # take care boundaries so we ain't updating nonexistent rows or columns
            rel_start = max(0, rel_start)
            rel_end = min(MATRIX_COLUMNS - 1, rel_end)

            if rel_start < rel_end:
                # NOTE(review): the end column is included, so a fragment of
                # length L increments L + 1 columns - confirm this is intended.
                result[curr_dhs_index, fragment_length, rel_start:rel_end+1] += 1


# saving result
np.save('../../data/test/EE87920_10lines__DHS_test.npy', result)


In [9]:
import numpy as np

# Reload the saved matrix from disk and show the total of all its entries
a = np.load('../../data/test/EE87920_10lines__DHS_test.npy')

a.sum()

3168.0