# Motif Scan Importance Threshold

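Apply an importance-score-based threshold to motif scan hits. A hit is kept when its importance score is in the top 1% of those of randomly sampled windows of the same width. A hit's score is its importance summed over the length of the motif, with NaNs treated as 0 and negative contributions clipped to 0.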

This is done for the high-OSK cell state (the `cluster_idx11` model).

In [15]:
import pyBigWig
import numpy as np
import tqdm
import pandas as pd
from collections import defaultdict

In [1]:
from modisco.visualization import viz_sequence


In [3]:
import numpy as np

In [4]:
# Tab-separated file of the regions interpreted by the high-OSK (cluster_idx11) model; parsed by get_regions below.
HIGH_OSK_REG_PATH = "/users/surag/oak/projects/scATAC-reprog/bpnet/models/20210820_chrombpnet_lite/chrombpnet/cluster_idx11/interpret/interpreted_regions.bed" 

In [5]:
# bigWig of importance scores (counts.importance.bw) for the same model; windows of it are summed below.
HIGH_OSK_IMP_BW = "/users/surag/oak/projects/scATAC-reprog/bpnet/models/20210820_chrombpnet_lite/chrombpnet/cluster_idx11/interpret/bigwig/counts.importance.bw"

In [6]:
# Open the importance bigWig once; this handle is reused by every importance lookup in the notebook (it is never closed).
high_osk_imp_bw = pyBigWig.open(HIGH_OSK_IMP_BW)

In [2]:
def get_regions(regions_file, crop_width):
    """Read tab-separated regions and re-centre each one on its summit.

    Importance scores are computed centred at the summit, which is the region
    start (2nd column) plus the summit offset (10th column). Presumably the file
    is narrowPeak-like; only columns 1, 2 and 10 are used.

    Args:
        regions_file: path to the tab-separated regions file. Blank lines are skipped.
        crop_width: width (bp) of the window to build around each summit.

    Returns:
        List of ``(chrom, start, end)`` tuples with ``end - start == crop_width``.
        For even ``crop_width`` the window is ``summit +/- crop_width // 2``.
    """
    scored_regions = []
    with open(regions_file) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of failing on fields[9].
                continue
            fields = line.split('\t')
            summit = int(fields[1]) + int(fields[9])
            start = summit - crop_width // 2
            # Derive end from start so the window is exactly crop_width wide, including odd widths.
            scored_regions.append((fields[0], start, start + crop_width))

    return scored_regions

In [7]:
# Summit-centred windows for the regions with importance scores; 2000 is the output size of the model.
regions = get_regions(regions_file=HIGH_OSK_REG_PATH, crop_width=2000)

In [9]:
# Sanity check: the first (chrom, start, end) window.
regions[0]

('chr1', 9323, 11323)

In [27]:
def get_perc(bw, regs, width, perc, N=10000, seed=None):
    """Estimate a quantile of summed positive importance over random windows.

    Each of the ``N`` samples picks a region from ``regs`` uniformly at random.
    It then picks a start position uniformly such that the ``width``-bp window
    lies entirely inside that region. The window's importance values are read
    from ``bw``, NaNs are replaced by 0 and negative contributions are clipped
    to 0 (so they are ignored), and the values are summed.

    Args:
        bw: open bigWig handle exposing ``values(chrom, start, end)``.
        regs: sequence of regions ``(chrom, start, end, ...)`` with importance scores available.
        width: window width in bp; each sampled region must be at least this wide.
        perc: quantile to return, in [0, 1] (e.g. 0.99 for the 99th percentile).
        N: number of random windows to sample.
        seed: optional seed for reproducible sampling. If None, numpy's global
            random state is used.

    Returns:
        The ``perc`` quantile of the per-window summed importance values.

    Raises:
        ValueError: if ``regs`` is empty.
    """
    if len(regs) == 0:
        raise ValueError("regs must contain at least one region")
    rng = np.random if seed is None else np.random.RandomState(seed)

    vals = []
    for _ in tqdm.tqdm(range(N)):
        reg = regs[rng.randint(0, len(regs))]
        chrom, reg_start, reg_end = reg[0], reg[1], reg[2]
        # randint's upper bound is exclusive; +1 so the last valid start (reg_end - width),
        # i.e. a window ending exactly at reg_end, can also be drawn.
        start = rng.randint(reg_start, reg_end - width + 1)
        imp = np.nan_to_num(bw.values(chrom, start, start + width))
        vals.append(np.sum(np.maximum(imp, 0)))

    return np.quantile(vals, perc)


In [12]:
# Column names for the motif scan files, which are read without a header row using these names.
SCAN_SCHEMA = ["chr", "start", "end", "strand", "score", "seq"]

In [20]:
# Motifs whose raw scans (./motif_scans/scans/raw/<MOTIF>.bed) are importance-thresholded below.
MOTIFS = ["OCTSOX", "KLF", "SOX", "AP1"]

In [66]:
# Load the raw (unthresholded) scan hits of each motif into its own DataFrame.
raw_scans = {}
for motif in MOTIFS:
    raw_scans[motif] = pd.read_csv(
        "./motif_scans/scans/raw/{}.bed".format(motif),
        sep='\t',
        names=SCAN_SCHEMA,
    )

In [67]:
# Distinct hit widths (end - start) across motifs; one importance threshold is computed per width below.
# A scan is assumed to contain hits of a single width, so verify that over all rows instead of trusting row 0.
widths = set()
for motif_name in MOTIFS:
    motif_hit_widths = (raw_scans[motif_name]['end'] - raw_scans[motif_name]['start']).unique()
    assert len(motif_hit_widths) == 1, "{} hits have widths {}".format(motif_name, motif_hit_widths)
    widths.add(motif_hit_widths[0])
widths

{9, 15}

In [30]:
# For each motif width: 99th percentile of summed positive importance over 100000 random windows of that width.
imp_threshes = {
    w: get_perc(high_osk_imp_bw, regions, w, 0.99, N=100000)
    for w in widths
}

100%|██████████| 100000/100000 [00:54<00:00, 1824.85it/s]
100%|██████████| 100000/100000 [00:36<00:00, 2728.51it/s]


In [65]:
imp_threshes

{9: 0.24456437518819985, 15: 0.3719796596281215}

In [78]:
# Sum the importance (NaN -> 0, negatives clipped to 0) over every hit and store it as raw_scans[motif]['imp'].
# Then keep only the hits whose summed importance exceeds the threshold for that motif's width.
imp_thresholded_scans = dict()

for motif in MOTIFS:
    scan = raw_scans[motif]

    # One threshold is applied per motif, so every hit must share a single width
    # (checked over all rows rather than taken from the first one).
    hit_widths = (scan['end'] - scan['start']).unique()
    assert len(hit_widths) == 1, "{} hits have more than one width: {}".format(motif, hit_widths)
    w = hit_widths[0]

    # Zipping columns avoids building a Series per row (iterrows); int() passes plain ints to pyBigWig.
    imp_vals = []
    for chrom, start, end in tqdm.tqdm(zip(scan['chr'], scan['start'], scan['end']), total=scan.shape[0]):
        imp_vals.append(np.sum(np.maximum(np.nan_to_num(high_osk_imp_bw.values(chrom, int(start), int(end))), 0)))

    scan['imp'] = imp_vals  # scan is raw_scans[motif], so the column is added to raw_scans as before

    imp_thresholded_scans[motif] = scan[scan['imp'] > imp_threshes[w]]

100%|██████████| 436180/436180 [02:22<00:00, 3068.22it/s]
100%|██████████| 572850/572850 [03:06<00:00, 3065.58it/s]
100%|██████████| 301965/301965 [01:38<00:00, 3054.21it/s]
100%|██████████| 61725/61725 [00:20<00:00, 3007.67it/s]


In [81]:
# Per motif: number of hits kept after thresholding, and the fraction of raw hits this represents.
for name in MOTIFS:
    n_kept = imp_thresholded_scans[name].shape[0]
    n_total = raw_scans[name].shape[0]
    print(name, n_kept, n_kept / n_total)

OCTSOX 69617 0.159606125911321
KLF 126318 0.22050798638387012
SOX 18844 0.06240458331263557
AP1 29840 0.4834345889023896


Write the importance-thresholded hits to disk, one TSV per motif.

In [82]:
import os

# One TSV per motif (with header, without the DataFrame index).
out_dir = "./motif_scans/scans/importance_thresholded"
os.makedirs(out_dir, exist_ok=True)  # to_csv does not create missing directories
for motif in MOTIFS:
    imp_thresholded_scans[motif].to_csv(os.path.join(out_dir, "{}.tsv".format(motif)),
                                        sep='\t', index=False)

In [84]:
# Inspect the thresholded AP1 hits, which now carry the summed-importance 'imp' column.
imp_thresholded_scans["AP1"]

Unnamed: 0,chr,start,end,strand,score,seq,imp
1,chr17,33487649,33487658,-,6.974161,TTTGAGTCAC,0.368245
3,chr17,46474674,46474683,+,8.923917,GTGAGTCATG,0.577262
4,chr17,70057186,70057195,+,8.479893,ATGACTCAAC,0.323845
6,chr17,19642881,19642890,-,7.221008,GATTAGTCAC,0.311039
11,chr13,32056511,32056520,+,6.278644,ATGACACATT,0.274024
...,...,...,...,...,...,...,...
61715,chrY,15405628,15405637,+,8.495907,GTGAGTCAGG,0.456366
61716,chrY,7858109,7858118,-,9.486049,AATGAGTCAC,0.890239
61719,chrY,12833476,12833485,-,6.792998,GCTTAGTCAC,0.356936
61721,chrY,13218725,13218734,-,8.660520,GGTGACTCAC,0.534028


In [87]:
# Earlier OCTSOX scan (from the 20210520_spacing_in_data analysis), thresholded below for comparison.
# The path has no placeholders, so the previous `.format(motif)` was a no-op that only worked because
# `motif` leaked from an earlier loop; it raised NameError on a fresh kernel.
old_os_scan = pd.read_csv("../../analysis/20210520_spacing_in_data/scanning/tfmodisco_cluster_idx4_gc_neg_peak_set_8_10_11_13/high_OSK.not.fibr.1000.OCTSOX.bed", sep='\t', names=SCAN_SCHEMA)

In [89]:
# 99th-percentile summed importance for 14 bp random windows; used (rounded) as the threshold for the old scan.
# NOTE(review): presumably 14 bp is the old OCTSOX hit width — confirm.
get_perc(high_osk_imp_bw, regions, width=14, perc=0.99, N=10000)

100%|██████████| 10000/10000 [00:02<00:00, 4532.30it/s]


0.3464563995762729

In [91]:
# Summed importance (NaN -> 0, negatives clipped to 0) for each hit of the old OCTSOX scan.
# Zipping columns avoids per-row Series construction (iterrows); int() passes plain ints to pyBigWig.
imp_vals = []
for chrom, start, end in tqdm.tqdm(zip(old_os_scan['chr'], old_os_scan['start'], old_os_scan['end']),
                                   total=old_os_scan.shape[0]):
    imp_vals.append(np.sum(np.maximum(np.nan_to_num(high_osk_imp_bw.values(chrom, int(start), int(end))), 0)))

old_os_scan['imp'] = imp_vals

# Hard-coded from the get_perc output for 14 bp windows in the previous cell (0.3464...), rounded.
# NOTE(review): re-derive if regions, the bigWig, or the old scan's hit width change.
OLD_OCTSOX_IMP_THRESH = .346
imp_thresholded_old_scan = old_os_scan[old_os_scan['imp'] > OLD_OCTSOX_IMP_THRESH]

100%|██████████| 398249/398249 [02:10<00:00, 3059.88it/s]


In [92]:
# (number of old-scan OCTSOX hits passing the threshold, number of columns)
imp_thresholded_old_scan.shape

(48077, 7)

In [93]:
import os

# Write the thresholded old OCTSOX hits (with header, without the DataFrame index).
os.makedirs("./tmp", exist_ok=True)  # to_csv does not create missing directories
imp_thresholded_old_scan.to_csv("./tmp/OCTSOX.old.scan.agg.imp.thresholded.tsv",
                                sep='\t', index=False)