In [1]:
import modisco
import h5py
import deepdish
import intervaltree
from collections import defaultdict, OrderedDict, Counter
from modisco.visualization import viz_sequence
import matplotlib

# Save PDF text as Type 42 (TrueType) fonts so labels stay editable/selectable in saved figures.
matplotlib.rcParams.update({'pdf.fonttype': 42})

from matplotlib import pyplot as plt
import numpy as np
import tqdm
import modisco
from modisco.visualization import viz_sequence
import tqdm
import pyBigWig
import pyfaidx

In [2]:
!pip freeze | grep modisco


modisco==0.5.16.0


In [3]:
import os


In [4]:
# Turn off HDF5 file locking before any HDF5 file is opened in this notebook.
# NOTE(review): presumably needed because the inputs live on a shared/network
# filesystem where HDF5 locking fails or conflicts with other readers — confirm.
os.environ.update(HDF5_USE_FILE_LOCKING="FALSE")


In [5]:
def revcomp(x):
    """Reverse-complement a (positions x channels) matrix.

    Assumes channel columns are ordered A, C, G, T. Under that order, reversing
    the rows flips the strand, and reversing the columns swaps A<->T and C<->G.
    For a NumPy array the result is a view of ``x``, not a copy.
    """
    return x[::-1, ::-1]

In [6]:
# Width of the window, centered on each peak summit, that modisco was run on for this data.
# Importance-score and region crops must use this same width to line up with the seqlets.
MODISCO_CROP_WIDTH: int = 500

In [7]:
# hg38 reference genome (GRCh38 no-alt analysis set), opened with pyfaidx for random access.
HG38_FASTA_PATH = "/users/surag/genomes/hg38/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"
hg38 = pyfaidx.Fasta(HG38_FASTA_PATH)

In [8]:
def get_regions(regions_file, crop_width):
    """Read a tab-separated peak file and return fixed-width windows centered on each summit.

    Each data line must have at least 10 tab-separated columns. Column 1 is the chromosome,
    column 2 the region start, and column 10 the summit offset relative to that start
    (narrowPeak-style layout — presumably; confirm for the input file). The summit is
    ``start + offset``. Blank lines and lines starting with ``#`` are skipped.

    Args:
        regions_file: path to the peak/regions file.
        crop_width: desired window width around the summit.

    Returns:
        List of ``(chrom, window_start, window_end)`` tuples, one per data line in file order.
        Each window spans ``summit - crop_width//2`` to ``summit + crop_width//2``. For an
        odd ``crop_width`` the window is therefore ``crop_width - 1`` wide.
    """
    half_width = crop_width // 2
    regions = []
    with open(regions_file) as f:
        for line in f:
            stripped = line.strip()
            # Skip blank lines (e.g. a trailing empty line) and comment lines; these
            # previously crashed the column indexing / int() parsing below.
            if not stripped or stripped.startswith('#'):
                continue
            fields = stripped.split('\t')
            # importance scores are centered at the summit = 2nd column + 10th column
            summit = int(fields[1]) + int(fields[9])
            regions.append((fields[0], summit - half_width, summit + half_width))
    return regions

In [9]:
def _parse_seqlet_descriptor(seqlet_str):
    """Parse one modisco seqlet descriptor into ``(region_idx, start, is_rc)``.

    The descriptor is a comma-separated list of ``key:value`` fields. Only three are read:
    field 0 (index into ``regions``/``imp_scores``), field 1 (seqlet start within the
    cropped window) and field 3 (reverse-complement flag, ``True``/``False``). Accepts the
    ``bytes`` stored in the HDF5 file or an already-decoded ``str``.

    The rc flag is matched literally instead of being passed to ``eval``, so text read from
    the file is never executed. Any other value raises ``ValueError``.
    """
    if isinstance(seqlet_str, bytes):
        seqlet_str = seqlet_str.decode('utf8')
    fields = seqlet_str.split(',')
    region_idx = int(fields[0].split(':')[1])
    start = int(fields[1].split(':')[1])
    rc_flag = fields[3].split(':')[1].strip()
    if rc_flag not in ('True', 'False'):
        raise ValueError(f"Unexpected rc flag {rc_flag!r} in seqlet descriptor {seqlet_str!r}")
    return region_idx, start, rc_flag == 'True'


def fetch_data(modisco_hdf5, pattern_name, pattern_start, pattern_end,
               imp_scores, regions, insertions_bw_file,
               pred_w_bias_bw_file, pred_wo_bias_bw_file,
               modisco_crop_width, genome, rc_everything=False, expand_bw=100):
    """Collect per-seqlet scores, sequence and bigWig tracks for one modisco pattern.

    The function walks the seqlets of ``pattern_name`` (metacluster_0 only). For each
    seqlet it pulls out the window ``[pattern_start, pattern_end)`` of the pattern:
    projected SHAP scores, one-hot sequence, genomic coordinates and RC flag. It also
    reads observed insertions plus predicted profiles with and without bias from the
    bigWigs, over that window padded by ``expand_bw`` on both sides. Seqlets that
    modisco matched on the reverse strand are reverse-complemented, so all returned
    seqlets share the pattern's forward orientation.

    Args:
        modisco_hdf5: opened modisco results (e.g. an ``h5py.File``).
        pattern_name: key of the pattern under
            ``metacluster_idx_to_submetacluster_results/metacluster_0/seqlets_to_patterns_result/patterns``.
        pattern_start: start of the sub-window within the seqlet, in the pattern's forward
            orientation; must satisfy ``0 <= pattern_start < pattern_end``.
        pattern_end: end (exclusive) of that sub-window; must satisfy
            ``pattern_end <= seqlet length``.
        imp_scores: mapping with ``['shap']['seq']``, ``['projected_shap']['seq']`` and
            ``['raw']['seq']`` arrays indexed ``[region][channel, position]``. The assumption
            is that channels are A, C, G, T, as ``revcomp`` requires — TODO confirm.
        regions: sequence of ``(chrom, start, end)`` aligned with ``imp_scores``. The
            ``start`` must correspond to the first position of the central
            ``modisco_crop_width`` crop, presumably as produced by ``get_regions`` with the
            same width.
        insertions_bw_file, pred_w_bias_bw_file, pred_wo_bias_bw_file: bigWig paths.
        modisco_crop_width: width of the central crop of the scores that modisco was run on.
        genome: currently unused; kept for interface compatibility.
        rc_everything: if True, flip every returned seqlet and track to the opposite strand.
        expand_bw: padding (in positions) added on each side when reading bigWig tracks.

    Returns:
        ``(seqlet_coords, seqlet_is_rc, subcluster_idxs, seqlet_shaps, seqlet_one_hots,
        seqlet_insertions, seqlet_pred_w_bias, seqlet_pred_wo_bias)``:

        - ``seqlet_coords``: list of ``[chrom, start, end]`` for the pattern window.
        - ``seqlet_is_rc``: list of bools.
        - ``subcluster_idxs``: array built from ``pattern["subclusters"]``.
        - ``seqlet_shaps``, ``seqlet_one_hots``: arrays of shape
          ``(n_seqlets, pattern_end - pattern_start, n_channels)``.
        - The three track arrays: shape ``(n_seqlets, pattern_end - pattern_start + 2*expand_bw)``,
          with NaNs replaced by 0.

    NOTE(review): pyBigWig presumably raises if the padded window extends past a
    chromosome end — confirm for seqlets near chromosome boundaries.
    """
    assert len(regions) == imp_scores['shap']['seq'].shape[0], \
        "regions and importance scores must have one entry per region"

    # Central window of the importance scores that modisco was run on.
    imp_width = imp_scores['shap']['seq'].shape[-1]
    imp_crop_start = imp_width // 2 - modisco_crop_width // 2
    imp_crop_end = imp_width // 2 + modisco_crop_width // 2

    pattern = modisco_hdf5['metacluster_idx_to_submetacluster_results']['metacluster_0']['seqlets_to_patterns_result']['patterns'][pattern_name]
    seqlet_len = pattern["sequence"]["fwd"].shape[0]
    if not 0 <= pattern_start < pattern_end <= seqlet_len:
        raise ValueError(f"need 0 <= pattern_start < pattern_end <= {seqlet_len}, "
                         f"got pattern_start={pattern_start}, pattern_end={pattern_end}")

    subcluster_idxs = np.array(list(pattern["subclusters"]))
    # Read all seqlet descriptors once instead of one HDF5 access per loop iteration.
    seqlet_strs = pattern['seqlets_and_alnmts']['seqlets'][:]

    track_names = ('insertions', 'pred_w_bias', 'pred_wo_bias')
    seqlet_tracks = {name: [] for name in track_names}
    seqlet_coords = []
    seqlet_is_rc = []
    seqlet_shaps = []
    seqlet_one_hots = []

    bigwigs = {}
    try:
        bigwigs['insertions'] = pyBigWig.open(insertions_bw_file)
        bigwigs['pred_w_bias'] = pyBigWig.open(pred_w_bias_bw_file)
        bigwigs['pred_wo_bias'] = pyBigWig.open(pred_wo_bias_bw_file)

        for seqlet_str in tqdm.tqdm(seqlet_strs):
            idx, start, rc = _parse_seqlet_descriptor(seqlet_str)
            chrom, region_start = regions[idx][0], regions[idx][1]

            # (positions, channels) views of this region's central crop.
            cur_proj_shap_scores = imp_scores['projected_shap']['seq'][idx][:, imp_crop_start:imp_crop_end].transpose()
            cur_one_hot = imp_scores['raw']['seq'][idx][:, imp_crop_start:imp_crop_end].transpose()

            # [lo, hi) is the pattern sub-window in the region's (forward-strand) coordinates.
            # For an RC seqlet the offsets are given in the pattern's orientation, i.e. on the
            # reverse strand, so they are mirrored within the seqlet.
            if rc:
                lo = start + (seqlet_len - pattern_end)
                hi = start + (seqlet_len - pattern_start)
            else:
                lo = start + pattern_start
                hi = start + pattern_end

            start_coord = region_start + lo
            end_coord = region_start + hi
            shap_window = cur_proj_shap_scores[lo:hi]
            one_hot_window = cur_one_hot[lo:hi]
            if rc:
                shap_window = revcomp(shap_window)
                one_hot_window = revcomp(one_hot_window)

            seqlet_is_rc.append(rc)
            seqlet_coords.append([chrom, start_coord, end_coord])
            seqlet_shaps.append(shap_window)
            seqlet_one_hots.append(one_hot_window)

            for name in track_names:
                values = np.nan_to_num(bigwigs[name].values(chrom, start_coord - expand_bw, end_coord + expand_bw))
                # Reverse RC seqlets' tracks so they match the pattern orientation.
                seqlet_tracks[name].append(values[::-1] if rc else values)
    finally:
        # Close whichever bigWigs were opened, even if opening or reading failed midway.
        for bw in bigwigs.values():
            bw.close()

    seqlet_shaps = np.array(seqlet_shaps)
    seqlet_one_hots = np.array(seqlet_one_hots)
    seqlet_insertions = np.array(seqlet_tracks['insertions'])
    seqlet_pred_w_bias = np.array(seqlet_tracks['pred_w_bias'])
    seqlet_pred_wo_bias = np.array(seqlet_tracks['pred_wo_bias'])

    # Optionally flip every (already orientation-aligned) seqlet to the opposite strand.
    if rc_everything:
        seqlet_shaps = seqlet_shaps[:, ::-1, ::-1]
        seqlet_one_hots = seqlet_one_hots[:, ::-1, ::-1]
        seqlet_pred_wo_bias = seqlet_pred_wo_bias[:, ::-1]
        seqlet_pred_w_bias = seqlet_pred_w_bias[:, ::-1]
        seqlet_insertions = seqlet_insertions[:, ::-1]

    return seqlet_coords, seqlet_is_rc, subcluster_idxs, seqlet_shaps, seqlet_one_hots, seqlet_insertions, seqlet_pred_w_bias, seqlet_pred_wo_bias

In [10]:
# Inputs for the HIGH_OSK cluster (cluster_idx11) of the 20210820 chrombpnet_lite run.
HIGH_OSK_BASE_DIR = "/users/surag/oak/projects/scATAC-reprog/bpnet/models/20210820_chrombpnet_lite/chrombpnet/cluster_idx11"
HIGH_OSK_MODISCO_PATH = f"{HIGH_OSK_BASE_DIR}/modisco/modisco_results_allChroms_counts.hdf5"
HIGH_OSK_HDF5_PATH = f"{HIGH_OSK_BASE_DIR}/interpret/counts_scores.h5"
HIGH_OSK_REG_PATH = f"{HIGH_OSK_BASE_DIR}/interpret/interpreted_regions.bed"

In [None]:
# Open the HIGH_OSK modisco results read-only; the handle is not closed in this cell.
high_OSK_modisco = h5py.File(HIGH_OSK_MODISCO_PATH, mode='r')
