In [None]:
import gzip
import numpy as np
import logging

from collections import deque


logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)


# Longest fragment observed in one sample, obtained with:
# zcat EE87920.hg38.frag.gz | awk -F'\t' '{print $3 - $2}' | sort -nr | head -n 1 -> 262
# NOTE(review): measured on EE87920, while TEST_DATA below points to a different file - confirm 262 still holds.
MATRIX_ROWS = int(262 * 1.01)  # ~1% headroom on top of the longest fragment (evaluates to 264)
# Width of the window placed around each DHS midpoint (one matrix column per position).
MATRIX_COLUMNS = 2000
# Distance from the DHS midpoint to either edge of the window.
MATRIX_COLUMNS_HALF = MATRIX_COLUMNS // 2


# Gzip-compressed, tab-separated fragment file (chromosome, start, end, ...).
TEST_DATA = "../../data/sorted/breast_sorted.hg38.frag.gz"
# Tab-separated DHS intervals (chromosome, start, end).
DHS_DATA = "../../data/sorted/Lymphoid_DHS_sorted.bed"


def read_dhs_to_memory(path=None, min_distance=None):
    """Load DHS midpoints from a tab-separated BED file into a deque.

    A site is kept only when its midpoint lies more than ``min_distance``
    after the previously kept midpoint on the same chromosome, so the windows
    later built around the kept midpoints never overlap.

    Args:
        path: BED file to read. Defaults to the module-level ``DHS_DATA``.
        min_distance: minimal distance between two kept midpoints on one
            chromosome. Defaults to the module-level ``MATRIX_COLUMNS``.

    Returns:
        ``(sites, count)`` where ``sites`` is a deque of ``(midpoint, chromosome)``
        tuples in file order and ``count`` is ``len(sites)``.

    Raises:
        ValueError: if a data line has fewer than three columns or its
            coordinates are not integers.
    """
    if path is None:
        path = DHS_DATA
    if min_distance is None:
        min_distance = MATRIX_COLUMNS

    sites = deque()
    # last kept midpoint on the current chromosome; reset whenever the chromosome changes
    last_midpoint, curr_chr = float('-inf'), None

    with open(path, 'rt') as f:
        for line in f:
            stripped = line.strip()
            # tolerate blank lines and BED header/comment lines instead of crashing on them
            if not stripped or stripped.startswith(('#', 'track', 'browser')):
                continue

            # only the first three columns are needed; BED files may carry more
            fields = stripped.split('\t')
            if len(fields) < 3:
                raise ValueError(f'Expected at least 3 tab-separated columns, got: {line!r}')
            chrom, start, end = fields[0], int(fields[1]), int(fields[2])

            if chrom != curr_chr:
                last_midpoint, curr_chr = float('-inf'), chrom

            midpoint = (end + start) // 2

            # too close to the previously kept site -> windows would overlap -> skip
            if midpoint - last_midpoint <= min_distance:
                continue

            sites.append((midpoint, chrom))
            last_midpoint = midpoint

    return sites, len(sites)
    

def get_curr_dhs() -> tuple:
    """Consume the next DHS site from ``DHS_sites`` and return its window.

    Returns:
        ``(window_start, window_end, chromosome)`` where the window extends
        ``MATRIX_COLUMNS_HALF`` to each side of the popped midpoint, or
        ``(None, None, None)`` once ``DHS_sites`` is empty.
    """
    if DHS_sites:
        midpoint, chrom = DHS_sites.popleft()
        half_width = MATRIX_COLUMNS_HALF
        return (midpoint - half_width, midpoint + half_width, chrom)

    # nothing left to consume
    return None, None, None

def parse_fragment(line: str) -> tuple:
    """Split one tab-separated fragment line into ``(chromosome, start, end)``.

    Only the first three columns are used; start and end are returned as ints.
    Raises ``ValueError`` when fewer than three columns are present or the
    coordinates are not integers.
    """
    chrom, *coordinates = line.strip().split('\t')[:3]
    first_base, last_base = map(int, coordinates)
    return chrom, first_base, last_base

DHS_sites, initial_DHS_length = read_dhs_to_memory()

# Group the DHS windows per chromosome. The fragment stream is then matched
# against the windows of its own chromosome only. Walking both files with a
# single shared cursor drained every remaining DHS site as soon as the cursor
# had moved to the next chromosome while fragments of the previous one were
# still arriving (or when a fragment chromosome had no DHS site at all).
dhs_windows_by_chr = {}
while True:
    window_start, window_end, window_chr = get_curr_dhs()
    if window_end is None:
        break
    dhs_windows_by_chr.setdefault(window_chr, deque()).append((window_start, window_end))

# number of windows that have not been passed by the fragment stream yet
remaining_windows = sum(len(windows) for windows in dhs_windows_by_chr.values())

# result[fragment_length, midpoint offset inside the DHS window] -> fragment count
result = np.zeros((MATRIX_ROWS, MATRIX_COLUMNS))
with gzip.open(TEST_DATA, 'rt') as f:
    for line in f:
        # `chrom` instead of `chr`: do not shadow the builtin at notebook scope
        chrom, start, end = parse_fragment(line)
        fragment_midpoint, fragment_length = (start + end) // 2, end - start

        # if the fragment is too long skip and log it for now
        if fragment_length >= MATRIX_ROWS:
            logger.warning(f'Skipped fragment due to too high length:\nstart:{start}\nend:{end}')
            continue

        # fragments on chromosomes without (remaining) DHS windows cannot be counted
        chr_windows = dhs_windows_by_chr.get(chrom)
        if not chr_windows:
            continue

        # drop windows that end before this fragment's midpoint; fragments are
        # sorted by position, so those windows can never be hit again
        while chr_windows and fragment_midpoint > chr_windows[0][1]:
            chr_windows.popleft()
            remaining_windows -= 1

        # stop reading once every window has been passed
        if remaining_windows == 0:
            logger.warning('No more DHS sites')
            break

        if not chr_windows:
            continue

        curr_dhs_start = chr_windows[0][0]
        rel_midpoint = fragment_midpoint - curr_dhs_start

        # midpoints left of the window give a negative offset and are ignored
        if 0 <= rel_midpoint < MATRIX_COLUMNS:
            result[fragment_length, rel_midpoint] += 1
        
        
# saving result
# np.save('../../data/test/EE87920__Lymphoid_DHS_sorted.npy', result)


In [None]:
import numpy as np

# opening numpy arrays
# with open('../../data/test/EE87920__Lymphoid_DHS_sorted.npy', 'rb') as f:
#     a = np.load(f)
    
# a, a.shape, a.sum()

In [None]:
import matplotlib.pyplot as plt

# Sum over all midpoint columns -> number of counted fragments per fragment length (row index).
fragment_lengths = result.sum(axis=1)

# fig = plt.figure(figsize=(8, 4))
# x axis: fragment length in bp (row index of `result`), y axis: fragment count
plt.plot(np.arange(len(fragment_lengths)), fragment_lengths)
plt.xlabel("Fragment lengths")
plt.ylabel("Count")
plt.title("Fragment lengths distribution")
# plt.show()
# fig.savefig('temp.png', dpi=300)


In [None]:
# length vs fragments' relative midpoint
def calculate_coverage(result: np.ndarray, max_position: int = None) -> np.ndarray:
    """Compute per-position coverage from a (fragment length x midpoint) count matrix.

    Every non-zero cell ``result[length, midpoint]`` stands for ``count``
    fragments. The matrix is filled with ``midpoint = (start + end) // 2`` and
    ``length = end - start``, so a fragment is rebuilt as the half-open span
    ``[midpoint - length // 2, midpoint - length // 2 + length)``. Using
    ``midpoint + length // 2`` as the end would drop the last base of every
    odd-length fragment.

    Args:
        result: 2-D array indexed as ``[fragment_length, relative_midpoint]``.
        max_position: length of the returned coverage array; spans are clipped
            to ``[0, max_position)``. Defaults to ``result.shape[1]``.

    Returns:
        1-D float array of length ``max_position`` with the number of
        fragments covering each position.
    """
    if max_position is None:
        max_position = result.shape[1]

    coverage = np.zeros(max_position)

    # visit only the populated cells instead of scanning the whole matrix in Python
    length_idx, midpoint_idx = np.nonzero(result > 0)
    for fragment_length, rel_midpoint in zip(length_idx.tolist(), midpoint_idx.tolist()):
        raw_start = rel_midpoint - fragment_length // 2

        # make sure we stay in our boundaries
        start_pos = max(0, raw_start)
        end_pos = min(max_position, raw_start + fragment_length)

        if start_pos < end_pos:
            coverage[start_pos:end_pos] += result[fragment_length, rel_midpoint]

    return coverage

# Coverage over the full DHS window (one value per matrix column).
coverage = calculate_coverage(result, MATRIX_COLUMNS)
# Display the shape as a quick sanity check.
coverage.shape

In [None]:
# Coverage profile across the window; x is the column index of `result`.
plt.plot(np.arange(len(coverage)), coverage)
plt.xlabel("Relative midpoint positions")
plt.ylabel("Coverage")
plt.title("Relative midpoint positions VS Coverage")

In [None]:
# Width of the sliding window used by calculate_lwps (half of it on each side of a position).
LWPS_WINDOW_SIZE = 120
# Inclusive bounds of the fragment-length filter applied in calculate_lwps.
LWPS_LOWER_THRESHOLD = 120
LWPS_UPPER_THRESHOLD = 180
# Number of scored positions: a margin of LWPS_UPPER_THRESHOLD is left out at both ends of the window.
NUM_POSITIONS = MATRIX_COLUMNS - 2 * LWPS_UPPER_THRESHOLD

def calculate_lwps(result: np.ndarray, window_size=LWPS_WINDOW_SIZE) -> np.ndarray:
    """Compute the long-fragment windowed protection score (L-WPS).

    For every scored position a window of ``window_size`` is centred on it and
    the score is: (fragments spanning the whole window) minus (fragment
    endpoints falling inside the window).

    Only fragments whose length lies in
    ``[LWPS_LOWER_THRESHOLD, LWPS_UPPER_THRESHOLD]`` contribute. The previous
    implementation had this test inverted and scored every fragment *except*
    the 120-180 bp ones.

    Args:
        result: 2-D count matrix indexed as ``[fragment_length, relative_midpoint]``.
        window_size: width of the sliding window.

    Returns:
        1-D float array of length ``NUM_POSITIONS``; entry ``i`` is the score at
        window position ``i + LWPS_UPPER_THRESHOLD``.
    """
    lwps = np.zeros(NUM_POSITIONS)

    # populated cells restricted to the L-WPS fragment-length range
    length_idx, midpoint_idx = np.nonzero(result > 0)
    in_length_range = (length_idx >= LWPS_LOWER_THRESHOLD) & (length_idx <= LWPS_UPPER_THRESHOLD)
    length_idx, midpoint_idx = length_idx[in_length_range], midpoint_idx[in_length_range]

    counts = result[length_idx, midpoint_idx]
    half_lengths = length_idx // 2
    frag_starts = midpoint_idx - half_lengths
    # NOTE(review): for odd lengths this end is one base short of start + length,
    # so the end convention (inclusive vs exclusive) differs by parity - confirm intent.
    frag_ends = midpoint_idx + half_lengths

    half_window = window_size // 2

    # sliding window -> calculating lwps for each position 180, 181, ..., 1819
    for pos_idx in range(NUM_POSITIONS):
        # matrix indexing starts from 0, scored positions start at LWPS_UPPER_THRESHOLD
        pos = pos_idx + LWPS_UPPER_THRESHOLD

        if pos_idx % 100 == 0 or pos_idx == NUM_POSITIONS - 1:
            progress = pos_idx / NUM_POSITIONS * 100
            logger.info(f"Progress: {progress:.1f}%")

        window_start = pos - half_window
        window_end = pos + half_window

        # fragments that start at/before the window and end at/after it
        spanning_count = counts[(frag_starts <= window_start) & (frag_ends >= window_end)].sum()
        # fragment starts and fragment ends located inside the window (each endpoint counted)
        starts_inside = counts[(frag_starts >= window_start) & (frag_starts <= window_end)].sum()
        ends_inside = counts[(frag_ends >= window_start) & (frag_ends <= window_end)].sum()

        lwps[pos_idx] = spanning_count - starts_inside - ends_inside

    return lwps

# L-WPS profile with the default window size (LWPS_WINDOW_SIZE).
lwps = calculate_lwps(result)

In [None]:
# x values start at LWPS_UPPER_THRESHOLD because lwps[0] belongs to that window position.
plt.plot(np.arange(LWPS_UPPER_THRESHOLD, NUM_POSITIONS + LWPS_UPPER_THRESHOLD), lwps)
plt.xlabel("Relative midpoint positions")
plt.ylabel("L-WPS score")
plt.title("Relative midpoint positions VS L-WPS score")

In [None]:
# Fragment-length counts restricted to LWPS_LOWER_THRESHOLD..LWPS_UPPER_THRESHOLD (both inclusive).
filtered_fragments_lengths = fragment_lengths[LWPS_LOWER_THRESHOLD:LWPS_UPPER_THRESHOLD + 1]
plt.plot(np.arange(LWPS_LOWER_THRESHOLD, LWPS_UPPER_THRESHOLD + 1), filtered_fragments_lengths)
plt.xlabel("Fragment lengths")
plt.ylabel("Count")
plt.title("Fragment lengths distribution")

In [None]:
def calculate(matrix, x=0.999, endpoint_window=10, window_size=20):
    """Run the full FDI computation on a (fragment length x midpoint) count matrix.

    Steps: expand the matrix into individual reads, compute per-position
    coverage, compute the endpoint dispersion matrix, then combine both in
    non-overlapping windows.

    Args:
        matrix: 2-D count matrix indexed as ``[fragment_length, relative_midpoint]``.
        x: base of the dispersion term ``x ** local_density``.
        endpoint_window: half-width of the neighbourhood used for the local
            endpoint density.
        window_size: width of the non-overlapping windows the score is reported for.

    Returns:
        The dict produced by ``calculate_windowed_fdi`` (keys ``positions``,
        ``fdi_scores``, ``coverage_std``, ``endpoint_dispersion``).
    """
    logger.info(f"Calculating FDI with x={x}, endpoint_window={endpoint_window}, window_size={window_size}")

    # convert matrix to reads format
    reads = matrix_to_reads(matrix)
    logger.info(f"Converted matrix to {len(reads)} reads")

    # calculate coverage array
    coverage = calculate_coverage(matrix)

    # calculate endpoint dispersion matrix
    dispersion_matrix = calculate_endpoint_dispersion(reads, matrix.shape[1], x, endpoint_window)

    # calculate FDI in non-overlapping sliding windows
    fdi_results = calculate_windowed_fdi(coverage, dispersion_matrix, window_size)

    return fdi_results

# # TODO: same logic as in LWPSStatistic
def matrix_to_reads(matrix):
    """Expand a (fragment length x midpoint) count matrix into a list of reads.

    Each cell with a positive count contributes ``int(count)`` separate dicts
    with the keys ``'start'``, ``'end'`` (midpoint -/+ half the length) and
    ``'count'`` (the cell value). Reads are emitted row by row, i.e. ordered by
    fragment length and then by midpoint.

    NOTE(review): ``'end'`` is ``midpoint + length // 2``, which for odd lengths
    is one less than ``start + length`` - confirm which end convention is wanted.
    """
    n_lengths, n_midpoints = matrix.shape[0], matrix.shape[1]
    reads = []

    for length in range(n_lengths):
        half_length = length // 2
        # a length filter (e.g. on the 120-180 bp range) could be applied here

        for midpoint in range(n_midpoints):
            n_fragments = matrix[length, midpoint]
            if not n_fragments > 0:
                continue

            left, right = midpoint - half_length, midpoint + half_length
            reads.extend(
                {'start': left, 'end': right, 'count': n_fragments}
                for _ in range(int(n_fragments))
            )

    return reads

# TODO: same logic as in visualize_matrix.py
def calculate_coverage(matrix: np.ndarray, max_position: int = None) -> np.ndarray:
    """Compute per-position coverage from a (fragment length x midpoint) count matrix.

    Every non-zero cell ``matrix[length, midpoint]`` stands for ``count``
    fragments covering the half-open span
    ``[midpoint - length // 2, midpoint - length // 2 + length)``. Taking
    ``midpoint + length // 2`` as the end would lose the last base of every
    odd-length fragment.

    Args:
        matrix: 2-D array indexed as ``[fragment_length, relative_midpoint]``.
        max_position: optional length of the returned array; spans are clipped
            to ``[0, max_position)``. Defaults to ``matrix.shape[1]``. Accepted
            so that two-argument calls written for the earlier definition of
            this function keep working after this definition replaces it.

    Returns:
        1-D float array with the number of fragments covering each position.
    """
    matrix_columns = matrix.shape[1] if max_position is None else max_position

    coverage = np.zeros(matrix_columns)

    # visit only the populated cells instead of scanning the whole matrix in Python
    length_idx, midpoint_idx = np.nonzero(matrix > 0)
    for fragment_length, rel_midpoint in zip(length_idx.tolist(), midpoint_idx.tolist()):
        raw_start = rel_midpoint - fragment_length // 2

        # make sure we stay in our boundaries
        start_pos = max(0, raw_start)
        end_pos = min(matrix_columns, raw_start + fragment_length)

        if start_pos < end_pos:
            coverage[start_pos:end_pos] += matrix[fragment_length, rel_midpoint]

    return coverage


def calculate_endpoint_dispersion(reads, matrix_columns, x, endpoint_window):
    """Build the per-position endpoint count / dispersion matrix.

    Both endpoints of every read are clamped to
    ``[endpoint_window, matrix_columns - endpoint_window - 1]`` and counted.
    Each endpoint at position ``p`` then contributes ``x ** (D(p) - 1)`` to the
    dispersion of ``p``, where ``D(p)`` is the number of endpoints within
    ``endpoint_window`` positions of ``p``. All endpoints at one position share
    the same ``D(p)``, so the per-position total is ``count[p] * x ** (D(p) - 1)``
    and is computed in one vectorised step instead of one Python iteration per read.

    Args:
        reads: sequence of dicts with integer ``'start'`` and ``'end'`` entries
            (other keys are ignored).
        matrix_columns: number of positions.
        x: base of the dispersion term.
        endpoint_window: half-width of the local density neighbourhood.

    Returns:
        Array of shape ``(matrix_columns, 2)``: column 0 holds endpoint counts,
        column 1 the summed dispersion values.
    """
    # dispersion_matrix[:, 0] = endpoint counts, dispersion_matrix[:, 1] = dispersion values
    dispersion_matrix = np.zeros((matrix_columns, 2))
    if len(reads) == 0:
        return dispersion_matrix

    # collect all endpoints and move them into the valid range
    # (max is applied last, matching max(window, min(pos, upper)))
    upper_bound = matrix_columns - endpoint_window - 1
    endpoints = np.array(
        [(read['start'], read['end']) for read in reads], dtype=np.int64
    ).ravel()
    endpoints = np.maximum(endpoint_window, np.minimum(endpoints, upper_bound))

    # count endpoints at each position (np.add.at accumulates repeated indices)
    endpoint_counts = np.zeros(matrix_columns)
    np.add.at(endpoint_counts, endpoints, 1)
    dispersion_matrix[:, 0] = endpoint_counts

    # window sums via prefix sums: sum(counts[lo:hi]) == prefix[hi] - prefix[lo]
    prefix = np.concatenate(([0.0], np.cumsum(endpoint_counts)))
    occupied = np.flatnonzero(endpoint_counts)
    window_lo = occupied - endpoint_window
    window_hi = np.minimum(occupied + endpoint_window + 1, matrix_columns)

    # local density excludes the endpoint itself, hence the -1
    local_density = prefix[window_hi] - prefix[window_lo] - 1
    dispersion_matrix[occupied, 1] = endpoint_counts[occupied] * x ** local_density

    return dispersion_matrix

def calculate_windowed_fdi(coverage, dispersion_matrix, window_size):
    """Combine coverage and endpoint dispersion into per-window FDI scores.

    The positions are cut into ``len(coverage) // window_size`` consecutive,
    non-overlapping windows (a trailing partial window is ignored). For each
    window the score is the standard deviation of the coverage multiplied by
    the average dispersion per endpoint (dispersion sum divided by endpoint
    count, with a count of 0 replaced by 1).

    Returns:
        dict with ``'positions'`` (list of ``(start, end)`` tuples) and the
        arrays ``'fdi_scores'``, ``'coverage_std'`` and ``'endpoint_dispersion'``.
    """
    n_windows = len(coverage) // window_size
    last_window = n_windows - 1

    windows, scores, stds, mean_dispersions = [], [], [], []
    for idx in range(n_windows):
        lo = idx * window_size
        hi = lo + window_size

        # report progress every 10th window and on the final one
        if idx % 10 == 0 or idx == last_window:
            percent = 100 if n_windows <= 1 else round(idx / last_window * 100)
            logger.info(f"FDI calculation progress: {percent}%")

        # spread of the coverage inside the window
        window_std = np.std(coverage[lo:hi])

        # endpoints inside the window; fall back to 1 to avoid dividing by zero
        n_endpoints = np.sum(dispersion_matrix[lo:hi, 0])
        divisor = 1 if n_endpoints == 0 else n_endpoints

        # average dispersion per endpoint
        mean_dispersion = np.sum(dispersion_matrix[lo:hi, 1]) / divisor

        windows.append((lo, hi))
        scores.append(window_std * mean_dispersion)
        stds.append(window_std)
        mean_dispersions.append(mean_dispersion)

    return {
        'positions': windows,
        'fdi_scores': np.array(scores),
        'coverage_std': np.array(stds),
        'endpoint_dispersion': np.array(mean_dispersions)
    }

In [None]:
# Run the FDI pipeline on the fragment length x midpoint matrix.
fdi = calculate(result)

In [None]:
# Display the result dict (keys: positions, fdi_scores, coverage_std, endpoint_dispersion).
fdi

In [None]:
# One x value per window: the centre of its (start, end) span.
window_centers = [(start + end) / 2 for start, end in fdi['positions']]

plt.figure(figsize=(12, 6))
plt.plot(window_centers, fdi['fdi_scores'], 'b-', marker='o', markersize=4, linewidth=1)

# Mark the DHS site (1000 == MATRIX_COLUMNS_HALF, the centre of the window)
plt.axvline(x=1000, color='red', linestyle='--', linewidth=2, label='DHS site at 1000')

# Use log scale for y-axis since values are very small
# NOTE: windows with a score of exactly 0 cannot be drawn on a log axis.
plt.yscale('log')

plt.xlabel('Genomic Position')
plt.ylabel('FDI Score (log scale)')
plt.title('FDI Scores Across Sliding Windows (DHS site at position 1000)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Total of all matrix cells, i.e. the number of fragments that were counted.
result.sum()

In [None]:
# Step-by-step version of the FDI computation, kept at notebook scope for inspection.
# NOTE(review): X is 0.99 here while calculate() uses 0.999 - confirm which value is intended.
X = 0.99
ENDPOINT_WINDOW = 10
WINDOW_SIZE = 20

# Expand every populated matrix cell into int(count) rows of [start, end, length].
reads = []
for fragment_length in range(result.shape[0]):
    for rel_midpoint in range(result.shape[1]):
        count = result[fragment_length, rel_midpoint]
        
        if count > 0:
            start_pos = rel_midpoint - fragment_length // 2
            # end is start + length here, whereas matrix_to_reads() uses midpoint + length // 2
            end_pos = start_pos + fragment_length

            for _ in range(int(count)):
                reads.append([start_pos, end_pos, fragment_length])

# one row per read, columns: start, end, fragment length
reads = np.array(reads)


In [None]:
# Peek at the first rows ([start, end, length]) and the overall shape.
reads[:5], reads.shape

In [None]:
# Recompute coverage with the calculate_coverage definition from the FDI cell
# (it shadows the earlier one); this overwrites the `coverage` variable computed above.
coverage = calculate_coverage(result)

In [None]:
# Peek at the first coverage values and the array shape.
coverage[:5], coverage.shape

In [None]:
# Column 0: endpoint counts per position, column 1: summed dispersion values per position.
density_matrix = np.zeros((MATRIX_COLUMNS, 2))

# First pass: Count endpoints at each position
# for read in np.concatenate([reads[:5], reads[-5:]], axis=0):
for read in reads:
    start_pos, end_pos, fragment_length = read

    # Adjust positions to valid range so the +/- ENDPOINT_WINDOW slice below stays inside the array
    adjusted_start = max(ENDPOINT_WINDOW, min(start_pos, MATRIX_COLUMNS - ENDPOINT_WINDOW - 1))
    adjusted_end = max(ENDPOINT_WINDOW, min(end_pos, MATRIX_COLUMNS - ENDPOINT_WINDOW - 1))
    
    # Count endpoints
    density_matrix[adjusted_start, 0] += 1
    density_matrix[adjusted_end, 0] += 1

# Second pass: Calculate dispersion values (needs the complete counts from the first pass)
# for read in np.concatenate([reads[:5], reads[-5:]], axis=0):
for read in reads:
    start_pos, end_pos, fragment_length = read
#     print(start_pos, end_pos)

    # Adjust positions to valid range
    adjusted_start = max(ENDPOINT_WINDOW, min(start_pos, MATRIX_COLUMNS - ENDPOINT_WINDOW - 1))
    adjusted_end = max(ENDPOINT_WINDOW, min(end_pos, MATRIX_COLUMNS - ENDPOINT_WINDOW - 1))
#     print(adjusted_start, adjusted_end)
    
    # endpoints within ENDPOINT_WINDOW of the start, not counting this endpoint itself
    local_density_start = np.sum(
        density_matrix[adjusted_start - ENDPOINT_WINDOW: adjusted_start + ENDPOINT_WINDOW + 1, 0]
    ) - 1
    density_matrix[adjusted_start, 1] += X ** local_density_start
#     print(local_density_start)
    # same for the end of the read
    local_density_end = np.sum(
        density_matrix[adjusted_end - ENDPOINT_WINDOW: adjusted_end + ENDPOINT_WINDOW + 1, 0]
    ) - 1
    density_matrix[adjusted_end, 1] += X ** local_density_end




# Sum over both columns, displayed as a quick sanity check.
density_matrix.sum()

In [None]:
# Dispersion values (column 1), leaving out 100 positions at each edge of the window.
start_idx, end_idx = 100, len(density_matrix) - 100
plt.plot(np.arange(start_idx, end_idx), density_matrix[start_idx:end_idx, 1])

In [None]:
# Endpoint counts (column 0), leaving out 100 positions at each edge of the window.
start_idx, end_idx = 100, len(density_matrix) - 100
plt.plot(np.arange(start_idx, end_idx), density_matrix[start_idx:end_idx, 0])

In [None]:
# Notebook-scope version of calculate_windowed_fdi(), operating on `coverage` and `density_matrix`.
matrix_columns = len(coverage)
# Non-overlapping windows; a trailing partial window is ignored.
num_windows = matrix_columns // WINDOW_SIZE

positions = []
fdi_scores = []
coverage_stds = []
endpoint_dispersions = []

for i in range(num_windows):
    start = i * WINDOW_SIZE
    end = start + WINDOW_SIZE
    
    # log progress on every 10th window and on the last one
    if not i % 10 or i == num_windows - 1:
        progress = round(i / (num_windows - 1) * 100) if num_windows > 1 else 100
        logger.info(f"FDI calculation progress: {progress}%")

    # calculate coverage standard deviation in window
    window_coverage = coverage[start:end]
    coverage_std = np.std(window_coverage)

    # calculate endpoint dispersion in window
    window_endpoint_counts = np.sum(density_matrix[start:end, 0])
    if window_endpoint_counts == 0:  # avoid 0 division
        window_endpoint_counts = 1

    # calculate average endpoint dispersion
    window_dispersion_sum = np.sum(density_matrix[start:end, 1])
    avg_endpoint_dispersion = window_dispersion_sum / window_endpoint_counts

    # calculate FDI score
    fdi_score = coverage_std * avg_endpoint_dispersion

    positions.append((start, end))
    fdi_scores.append(fdi_score)
    coverage_stds.append(coverage_std)
    endpoint_dispersions.append(avg_endpoint_dispersion)



In [None]:
# Display the per-window FDI scores collected by the loop above.
fdi_scores

In [None]:
fdi_results = np.zeros(matrix_columns)
for i, (start, end) in enumerate(positions):
    print(i, start, end)