# In [None]:
import pandas as pd
import random
import glob
import yaml

# Load the pipeline settings. The encoding is given explicitly so that parsing
# does not depend on the platform's default locale encoding.
with open('config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
    
def merge_intervals(accessible_sites):
    """Collapse overlapping or touching ``(start, end)`` intervals.

    The input is sorted first, so it may be given in any order. Two
    intervals are fused when the later one starts at or before the end of
    the earlier one.

    Returns a list of ``(start, end)`` tuples in ascending order, or an
    empty list when no intervals are given.
    """
    if not accessible_sites:
        return []

    ordered = sorted(accessible_sites)
    cur_start, cur_end = ordered[0]
    result = []
    for start, end in ordered[1:]:
        if start > cur_end:
            # Gap found: the running interval is complete.
            result.append((cur_start, cur_end))
            cur_start, cur_end = start, end
        elif end > cur_end:
            # Overlap (or touch) that extends the running interval.
            cur_end = end
    result.append((cur_start, cur_end))
    return result


def subtract_interval(all_sites, accessible_site):
    """Remove one interval from every interval in a list.

    ``all_sites`` is an iterable of ``(start, end)`` pairs and
    ``accessible_site`` is the ``(start, end)`` span to cut out. Intervals
    that do not intersect the cut are kept unchanged; intersecting ones are
    replaced by whatever remains on the left and/or right of the cut.

    Returns a new list of ``(start, end)`` tuples in the input order.
    """
    cut_start, cut_end = accessible_site
    remaining = []
    for left, right in all_sites:
        intersects = right > cut_start and left < cut_end
        if not intersects:
            remaining.append((left, right))
            continue
        # Keep the part sticking out on each side of the cut, if any.
        if left < cut_start:
            remaining.append((left, cut_start))
        if right > cut_end:
            remaining.append((cut_end, right))
    return remaining


# Dedicated generator, seeded once: results are reproducible from run to run,
# yet successive draws are independent. Re-seeding on every call would restart
# the sequence each time and also clobber the global ``random`` state.
_NEGATIVE_SITE_RNG = random.Random(42)


def sample_inaccessible_site(sites):
    """Draw one position uniformly at random from a set of intervals.

    ``sites`` is a list of ``(start, end)`` pairs treated as half-open
    ranges; every integer position covered by them is equally likely.

    Returns the sampled position, or ``None`` when the intervals cover no
    positions at all (empty list or zero total length).
    """
    total = sum(b - a for a, b in sites)
    if total <= 0:
        return None

    # Pick an offset into the concatenation of all intervals, then walk the
    # intervals to translate that offset back into a coordinate.
    r = _NEGATIVE_SITE_RNG.randrange(total)
    acc = 0
    for a, b in sites:
        if acc + (b - a) > r:
            return a + (r - acc)
        acc += b - a
    return None

def generate_negative_dhs_df(
    df,
    window_size,
):
    """Sample negative sites away from the DHS sites in ``df``.

    For each chromosome, a window of ``window_size // 2`` on either side of
    every DHS midpoint is masked out, and as many positions as there are
    DHS sites are then sampled from what is left. After each draw the same
    half-window around the sampled position is masked as well, so later
    draws cannot land inside it. Sampling for a chromosome stops early if
    no unmasked positions remain.

    Parameters:
        df: DataFrame with at least the columns ``chr``, ``end`` and ``mid``.
        window_size: full window width; half of it is masked on each side.

    Returns:
        DataFrame with columns ``chr``, ``start``, ``end`` holding one small
        interval ``(m - 1, m + 1)`` per sampled position ``m``, with the
        start clamped at 0.
    """
    # The chromosome length is approximated by the largest DHS end
    # coordinate observed on that chromosome.
    chr_sizes = (
        df.groupby('chr')['end']
        .max()
        .to_dict()
    )
    neg_rows = []
    window_half = window_size // 2

    for chromosome, sites in df.groupby('chr'):
        chr_len = chr_sizes[chromosome]
        mids = sites['mid'].to_numpy()
        n_sites_needed = len(mids)

        # Masked regions around every DHS midpoint, clipped to the chromosome.
        accessible_sites = merge_intervals(
            [
                (max(0, m - window_half), min(chr_len, m + window_half))
                for m in mids
            ]
        )

        # Start from the whole chromosome and carve out the masked regions.
        all_sites = [(0, chr_len)]
        for s, e in accessible_sites:
            all_sites = subtract_interval(all_sites, (s, e))

        for i in range(n_sites_needed):
            if i % 5000 == 0:
                print(f'{chromosome}: {i}/{n_sites_needed}')

            m = sample_inaccessible_site(all_sites)
            if m is None:
                break

            # save as a tiny DHS interval (BED-like); the start is clamped
            # because m == 0 would otherwise yield a negative coordinate
            neg_rows.append((chromosome, max(0, m - 1), m + 1))

            # remove this window so later draws cannot fall inside it
            used_inaccessible_site = (max(0, m - window_half), min(chr_len, m + window_half))
            all_sites = subtract_interval(all_sites, used_inaccessible_site)

    neg_df = pd.DataFrame(neg_rows, columns=['chr', 'start', 'end'])
    return neg_df

# Write a "<name>_negative.bed" file next to every DHS BED file found.
# NOTE(review): the pattern is built by plain concatenation, so
# config['input_dhs_dir'] presumably ends with a path separator - confirm.
dhs_fnames = glob.glob(f"{config['input_dhs_dir']}*.bed")

for fname in dhs_fnames:
    # Outputs land in the input directory; skip those left by an earlier run
    # so they are not processed as if they were DHS inputs.
    if fname.endswith('_negative.bed'):
        continue

    # Only the first three BED columns are needed. Restricting the read to
    # them keeps files with extra columns from being mis-parsed.
    dhs_df = pd.read_csv(
        fname,
        sep='\t',
        names=['chr', 'start', 'end'],
        usecols=[0, 1, 2],
    )
    dhs_df['mid'] = (dhs_df['start'] + dhs_df['end']) // 2

    new_path = fname.rsplit('.', 1)
    new_path[0] = f'{new_path[0]}_negative'

    output_file = '.'.join(new_path)
    # NOTE(review): config['matrix_columns'] is passed as the window size -
    # confirm this is the intended setting.
    negative_dhs_df = generate_negative_dhs_df(dhs_df, config['matrix_columns'])

    negative_dhs_df.to_csv(output_file, sep='\t', header=False, index=False)