In [251]:
import numpy as np
import h5py
import keras
from collections import OrderedDict

def _corefiles_for_celltype(celltype, with_variedrefs):
    """Build the table of input-file paths for one cell type.

    All paths follow one naming template that differs only in the cell-type
    name, so they are generated here rather than written out per cell type.

    Args:
        celltype: cell-type name as it appears in the directory and file
            names (e.g. 'A549').
        with_variedrefs: when True, also include the 'variedrefs_h5file'
            entry; code that consumes the table checks for the presence of
            that key before using it.

    Returns:
        dict mapping a file role (e.g. 'ism_scores_npy_file') to its path.
    """
    data_dir = "/users/eprakash/projects/benchmarking/newdata/" + celltype
    results_dir = data_dir + "/models/deepseabeluga/results"
    corefiles = {
        'ism_scores_npy_file':
            results_dir + "/" + celltype + ".deepseabeluga.ISM.scores.5Ksubsample.npy",
        'ism_scores_seqnames_file':
            results_dir + "/ISM_deepseabeluga_" + celltype + "_positive.labels.txt",
        'backpropscores_h5file':
            results_dir + "/" + celltype + ".deepseabeluga.scores.5Ksubsample.h5",
        'positive_sequences_file':
            data_dir + "/" + celltype + ".summits.400bp.implanted.5Ksubsample.bed.gz",
        'seqnames_used_for_ig_h5file':
            results_dir + "/top1kposlabels_" + celltype + "_1kb",
        'motifmatches_file':
            data_dir + "/" + celltype + ".motif.matches.txt",
    }
    if with_variedrefs:
        corefiles['variedrefs_h5file'] = (
            results_dir + "/" + celltype
            + ".deepseabeluga.scores.variedrefs.5Ksubsample.h5")
    return corefiles


# Maps cell type -> {file role -> path}.
# No varied-reference score file is listed for HepG2.
celltype_to_corefiles = {
    'A549': _corefiles_for_celltype('A549', with_variedrefs=True),
    'HepG2': _corefiles_for_celltype('HepG2', with_variedrefs=False),
    'H1ESC': _corefiles_for_celltype('H1ESC', with_variedrefs=True),
}


In [252]:
import h5py
import numpy as np
import gzip
from collections import namedtuple, defaultdict
import sys
from sklearn.metrics import roc_auc_score, average_precision_score

# One motif hit inside one sequence. `start`/`end` are 0-indexed,
# end-exclusive coordinates within the sequence (see load_motif_matches).
MotifMatch = namedtuple("MotifMatch",
                        ["motifname", "seqname", "start", "end",
                         "strand", "hitstrength", "matchstring"])


def load_motif_matches(motif_match_file, seqnames_to_include):
    """Read a tab-separated motif-match file into per-sequence hit lists.

    Each non-blank row must have exactly seven tab-separated columns:
    motif name, sequence name, start, end, strand, hit strength and the
    matched string. Start/end in the file are 1-indexed inclusive; they are
    converted to 0-indexed, end-exclusive coordinates. Only the second
    "-"-separated field of the motif-name column is kept as `motifname`.

    Args:
        motif_match_file: path to the motif-match text file.
        seqnames_to_include: iterable of sequence names; rows for any other
            sequence are ignored.

    Returns:
        dict mapping seqname -> list of MotifMatch, in file order. Sequences
        without any hit do not appear as keys.

    Raises:
        ValueError: if a non-blank row does not have exactly seven columns,
            or a numeric column cannot be parsed.
        IndexError: if the motif-name column contains no "-".
    """
    # Set membership keeps the per-row filter O(1) instead of scanning a list.
    wanted_seqnames = set(seqnames_to_include)
    seqname_to_motifmatches = {}
    # `with` guarantees the file is closed even if a row fails to parse.
    with open(motif_match_file) as motif_fh:
        for row in motif_fh:
            row = row.rstrip()
            if not row:
                continue  # tolerate blank lines (e.g. a trailing newline)
            (motifname, seqname, homerstart, homerend,
             strand, hitstrength, matchstring) = row.split("\t")
            if seqname not in wanted_seqnames:
                continue
            motifmatch = MotifMatch(
                motifname=motifname.split("-")[1],
                seqname=seqname,
                start=int(homerstart)-1, #1-indexed-inclusive to 0-indexed-inclusive
                end=int(homerend), #1-indexed-inclusive to 0-indexed-exclusive
                strand=strand,
                hitstrength=float(hitstrength),
                matchstring=matchstring)
            seqname_to_motifmatches.setdefault(seqname, []).append(motifmatch)
    return seqname_to_motifmatches


def onehot_encode(seqs):
    """One-hot encode DNA strings over the channel order A, C, G, T.

    Letters are upper-cased before lookup; 'N' becomes an all-zero row.

    Args:
        seqs: iterable of sequence strings.

    Returns:
        numpy array built from one list of 4-element rows per sequence.

    Raises:
        KeyError: for any letter other than A, C, G, T or N.
    """
    base_to_row = {
        'A': [1,0,0,0],
        'C': [0,1,0,0],
        'G': [0,0,1,0],
        'T': [0,0,0,1],
        'N': [0,0,0,0],
    }
    encoded = []
    for seq in seqs:
        encoded.append([base_to_row[base] for base in seq.upper()])
    return np.array(encoded)


def get_indices_of_subset(superset_seqnames, subset_seqnames):
    """Locate each subset name inside the superset.

    Args:
        superset_seqnames: sequence of names to search in. If a name occurs
            more than once, its last position is the one reported.
        subset_seqnames: names to look up, in the desired output order.

    Returns:
        list of integer positions into `superset_seqnames`, one per entry of
        `subset_seqnames`.

    Raises:
        KeyError: if a subset name is absent from the superset.
    """
    position_of = {}
    for position, name in enumerate(superset_seqnames):
        position_of[name] = position
    return [position_of[name] for name in subset_seqnames]


def load_posseqs(corefiles, pos_idx_ordering):
    """Load selected sequences from the gzipped positives file.

    The file is read as UTF-8, tab-separated lines; the second column of
    each line is taken as the sequence string.

    Args:
        corefiles: dict providing the 'positive_sequences_file' path.
        pos_idx_ordering: row indices (0-based line numbers) to keep, in the
            order they should appear in the output.

    Returns:
        (onehot, posseqs): the one-hot encoding produced by `onehot_encode`
        and the list of raw sequence strings, both in `pos_idx_ordering`
        order.
    """
    # `with` closes the gzip handle once every line has been consumed.
    with gzip.open(corefiles['positive_sequences_file']) as positives_fh:
        all_posseqs = [line.decode("utf-8").rstrip().split("\t")[1]
                       for line in positives_fh]
    posseqs = [all_posseqs[idx] for idx in pos_idx_ordering]
    return onehot_encode(posseqs), posseqs


def load_ism_scores(method_to_scores, corefiles, onehot, ism_idx_ordering):
    """Load ISM scores and store their per-position totals under 'ism'.

    Args:
        method_to_scores: dict updated in place; receives the key 'ism'.
        corefiles: dict providing the 'ism_scores_npy_file' path.
        onehot: one-hot encoding of the sequences, aligned with the rows
            selected by `ism_idx_ordering`.
        ism_idx_ordering: row indices to pick from the saved array.

    Raises:
        AssertionError: if any score lies outside the positions marked by
            `onehot`.
    """
    selected_scores = np.load(corefiles['ism_scores_npy_file'])[ism_idx_ordering]
    position_totals = np.sum(selected_scores, axis=-1)
    #sanity check: masking by the one-hot sequence must not change the
    # per-position totals, i.e. only the observed base carries a score
    masked_totals = np.sum(selected_scores*onehot, axis=-1)
    assert np.max(np.abs(masked_totals-position_totals))==0.0
    method_to_scores['ism'] = position_totals


def load_ig_scores(method_to_scores, corefiles, ig_idx_ordering):
    """Load the multi-reference integrated-gradients scores.

    Args:
        method_to_scores: dict updated in place; receives the key
            'ig10_multiref10'.
        corefiles: dict providing the 'backpropscores_h5file' path.
        ig_idx_ordering: row indices to pick from the stored dataset.
    """
    # The context manager closes the HDF5 file; `[:]` has already copied the
    # whole dataset into a NumPy array by then, so row selection happens in
    # memory.
    with h5py.File(corefiles['backpropscores_h5file'], "r") as scores_h5:
        all_ig_scores = scores_h5['integrated_gradients10_multiref_10'][:]
    method_to_scores['ig10_multiref10'] = np.array(all_ig_scores[ig_idx_ordering])


def load_nonig_scores(method_to_scores, corefiles, nonig_idx_ordering):
    """Load the gradient-times-input and DeepLIFT multi-reference scores.

    Args:
        method_to_scores: dict updated in place; receives the keys
            'gradtimesinp', 'deeplift-RS_multiref10' and
            'deeplift-RC_multiref10'.
        corefiles: dict providing the 'backpropscores_h5file' path.
        nonig_idx_ordering: row indices to pick from each stored dataset.
    """
    # (key written to method_to_scores, dataset name inside the HDF5 file)
    method_and_dataset = (
        ('gradtimesinp', 'grad_times_inp'),
        ('deeplift-RS_multiref10', 'rescale_all_layers_multiref_10'),
        ('deeplift-RC_multiref10', 'rescale_conv_revealcancel_fc_multiref_10'),
    )
    # The context manager closes the HDF5 file; each dataset is fully read
    # with `[:]` before rows are selected, so nothing references the file
    # after the block exits.
    with h5py.File(corefiles['backpropscores_h5file'], "r") as scores_h5:
        for method, dataset in method_and_dataset:
            method_to_scores[method] = scores_h5[dataset][:][nonig_idx_ordering]


def load_variedrefs_scores(method_to_scores, corefiles, variedrefs_idx_ordering):
    """Load the scores computed against alternative references.

    Args:
        method_to_scores: dict updated in place; receives the keys
            'ig10_zeroref', 'ig10_gcref', 'deeplift-RS_zeroref' and
            'deeplift-RS_gcref'.
        corefiles: dict providing the 'variedrefs_h5file' path.
        variedrefs_idx_ordering: row indices to pick from each stored
            dataset.
    """
    # (key written to method_to_scores, dataset name inside the HDF5 file)
    method_and_dataset = (
        ('ig10_zeroref', 'integrated_gradients10_all_zeros_ref'),
        ('ig10_gcref', 'integrated_gradients10_avg_gc_ref'),
        ('deeplift-RS_zeroref', 'rescale_all_layers_all_zeros_ref'),
        ('deeplift-RS_gcref', 'rescale_all_layers_avg_gc_ref'),
    )
    # The context manager closes the HDF5 file; each dataset is fully read
    # with `[:]` before rows are selected, so nothing references the file
    # after the block exits.
    with h5py.File(corefiles['variedrefs_h5file'], "r") as scores_h5:
        for method, dataset in method_and_dataset:
            method_to_scores[method] = scores_h5[dataset][:][variedrefs_idx_ordering]


def get_sum_scores_in_window(scores, windowlen):
    """Sum `scores` over every length-`windowlen` window along the last axis.

    Column j of the result is the sum of scores[:, j:j+windowlen]; the sums
    are obtained as differences of a zero-prefixed cumulative sum.

    Args:
        scores: 2-D numpy array (rows x positions).
        windowlen: window length in positions.

    Returns:
        array of shape (rows, positions - windowlen + 1).

    Raises:
        AssertionError: if `scores` is not 2-D or the result does not have
            the expected shape.
    """
    assert scores.ndim==2
    num_rows, num_positions = scores.shape
    running_totals = np.cumsum(scores, axis=-1)
    #prepend a zero column so that a window starting at position 0 can be
    # expressed as a difference too
    zero_column = np.zeros((num_rows, 1), dtype=running_totals.dtype)
    prefix_sums = np.concatenate([zero_column, running_totals], axis=1)
    assert prefix_sums.shape==(num_rows, num_positions+1)
    window_sums = prefix_sums[:,windowlen:]-prefix_sums[:,:-windowlen]
    assert window_sums.shape==(num_rows, num_positions-(windowlen-1))
    return window_sums


def _read_h5_labels(h5_path):
    """Return the entries of an HDF5 file's 'labels' dataset decoded as UTF-8."""
    # Opened read-only and closed as soon as the labels are in memory.
    with h5py.File(h5_path, "r") as h5file:
        return [x.decode("utf-8") for x in h5file['labels'][:]]


def get_scores_for_common_sequences(corefiles):
    """Load every method's scores for the sequences shared by all score files.

    The sequence names of each score source are intersected, sorted, and
    used to pull rows from every source in one common order.

    Args:
        corefiles: dict of input-file paths for one cell type. The
            'variedrefs_h5file' entry is optional; the alternative-reference
            scores are only loaded when it is present.

    Returns:
        (method_to_scores, onehot_posseqs, posseqs, seqnames):
        dict of method name -> score array, the one-hot encoded sequences,
        the raw sequence strings, and the common sequence names with the
        'dinuc_shuffled_motifs_implanted_' text removed. All four follow the
        same (sorted common-name) order.

    Raises:
        AssertionError: if the intersection is smaller than the smallest
            individual name list, i.e. some source has names the others lack.
        KeyError: if a common name is missing from the positives file.
    """
    # NOTE: the `corefiles` argument is used as given. It was previously
    # overwritten from the globals `celltype_to_corefiles[celltype]`, which
    # made the parameter meaningless and tied the function to notebook state.

    #######
    #Load all the seqnames
    with gzip.open(corefiles['positive_sequences_file']) as positives_fh:
        positives_seqnames = [x.decode("utf-8").rstrip().split("\t")[0]
                              for x in positives_fh]
    ig_seqnames = _read_h5_labels(corefiles['seqnames_used_for_ig_h5file'])
    nonig_backprop_seqnames = _read_h5_labels(corefiles['backpropscores_h5file'])
    with open(corefiles['ism_scores_seqnames_file']) as ism_names_fh:
        ism_seqnames = [x.rstrip() for x in ism_names_fh]
    all_seqnames = [ism_seqnames, nonig_backprop_seqnames, ig_seqnames]
    has_variedrefs = 'variedrefs_h5file' in corefiles
    if has_variedrefs:
        # Row order must come from the varied-refs file itself, because the
        # resulting indices are applied to that file's datasets. (This used
        # to read the labels of 'backpropscores_h5file'.)
        # Assumes the varied-refs file stores its names under 'labels' like
        # the backprop-scores file does -- TODO confirm.
        variedrefs_seqnames = _read_h5_labels(corefiles['variedrefs_h5file'])
        all_seqnames.append(variedrefs_seqnames)

    ########
    #Figure out the common seqnames
    common_seqnames = set(all_seqnames[0])
    for seqnames in all_seqnames[1:]:
        common_seqnames = common_seqnames.intersection(set(seqnames))
    common_seqnames = sorted(list(common_seqnames))
    assert len(common_seqnames) == min(len(x) for x in all_seqnames),\
            (len(common_seqnames), [len(x) for x in all_seqnames])
    print("Number of common seqnames:",len(common_seqnames))

    ########
    #Figure out the mapping from sequence to indices for the common seqnames
    positives_idx_ordering = get_indices_of_subset(superset_seqnames=positives_seqnames,
                                                   subset_seqnames=common_seqnames)
    ism_idx_ordering = get_indices_of_subset(superset_seqnames=ism_seqnames,
                                             subset_seqnames=common_seqnames)
    ig_idx_ordering = get_indices_of_subset(superset_seqnames=ig_seqnames,
                                            subset_seqnames=common_seqnames)
    nonig_idx_ordering = get_indices_of_subset(superset_seqnames=nonig_backprop_seqnames,
                                               subset_seqnames=common_seqnames)
    if has_variedrefs:
        variedrefs_idx_ordering = get_indices_of_subset(superset_seqnames=variedrefs_seqnames,
                                                        subset_seqnames=common_seqnames)

    ########
    #Load the data using the idx ordering
    onehot_posseqs, posseqs = load_posseqs(corefiles=corefiles,
                                           pos_idx_ordering=positives_idx_ordering)
    method_to_scores = {}
    load_ism_scores(method_to_scores=method_to_scores,
                    corefiles=corefiles, onehot=onehot_posseqs,
                    ism_idx_ordering=ism_idx_ordering)
    load_ig_scores(method_to_scores=method_to_scores,
                   corefiles=corefiles,
                   ig_idx_ordering=ig_idx_ordering)
    load_nonig_scores(method_to_scores=method_to_scores,
                      corefiles=corefiles,
                      nonig_idx_ordering=nonig_idx_ordering)
    if has_variedrefs:
        load_variedrefs_scores(method_to_scores=method_to_scores,
                               corefiles=corefiles,
                               variedrefs_idx_ordering=variedrefs_idx_ordering)

    #strip away the 'dinuc_shuffled_motifs_implanted_' from the front
    seqnames = [x.replace("dinuc_shuffled_motifs_implanted_", "") for x in common_seqnames]

    return method_to_scores, onehot_posseqs, posseqs, seqnames


def get_motifmatches_and_nullwindows_mask(corefiles, seqnames, seqs=None):
    """Collect motif hits per sequence and mask the windows free of any motif.

    Args:
        corefiles: dict providing the 'motifmatches_file' path.
        seqnames: sequence names, in the row order of the score arrays.
        seqs: sequence strings aligned with `seqnames`. Optional for
            backward compatibility: when omitted, the notebook-level global
            `seqs` is used, which is what this function always read before
            the parameter existed. Prefer passing it explicitly.

    Returns:
        (motifmatches_in_seqs, motifname_to_hitlocations,
         motifname_to_motiflen, motiflen_to_nullwindowsmask):
        - list (one entry per sequence) of MotifMatch lists;
        - dict motif name -> list of (sequence index, start) hit locations;
        - dict motif name -> motif length;
        - dict motif length -> boolean array marking windows of that length
          that overlap no motif hit (the negatives).

    Raises:
        NameError: if `seqs` is omitted and no global `seqs` exists.
        AssertionError: if `seqs` and `seqnames` differ in length, or a
            match string disagrees with the sequence at its coordinates.
    """
    if seqs is None:
        try:
            seqs = globals()['seqs']
        except KeyError:
            raise NameError("seqs was not passed and no global 'seqs' is defined")
    # zip() below would silently truncate on a length mismatch.
    assert len(seqs)==len(seqnames), (len(seqs), len(seqnames))

    print("Reading in motif file")
    sys.stdout.flush()
    # A set keeps the per-row name filter in the loader cheap.
    seqname_to_motifmatches = load_motif_matches(motif_match_file=corefiles['motifmatches_file'],
                                                 seqnames_to_include=set(seqnames))
    print("Read motif file")
    sys.stdout.flush()
    # Sequences with no hit are absent from the loader's dict; treat them as
    # having an empty hit list rather than failing with KeyError.
    motifmatches_in_seqs = [seqname_to_motifmatches.get(x, []) for x in seqnames]

    #Get locations of each motif
    #Also get a mask for locations obscured by the motifs
    #covered_positions has a 1 if there is a motif at the
    # position and 0 otherwise
    covered_positions = []
    motifname_to_hitlocations = defaultdict(list)
    motifname_to_motiflen = {}
    for seqidx,(motifmatches, seq) in enumerate(zip(motifmatches_in_seqs, seqs)):
        covered_positions_entry = np.zeros(len(seq))
        for motifmatch in motifmatches:
            #sanity check
            assert motifmatch.matchstring == seq[motifmatch.start:motifmatch.end].upper(), (
                        motifmatch.matchstring, seq[motifmatch.start:motifmatch.end])
            covered_positions_entry[motifmatch.start:motifmatch.end] = 1
            motifname_to_hitlocations[motifmatch.motifname].append(
                (seqidx, motifmatch.start))
            # NOTE(review): the motif length is taken to be the length of the
            # motif NAME. That is only right if names are consensus strings
            # with one character per position -- confirm; otherwise
            # (end - start) would be the safer source.
            motiflen = len(motifmatch.motifname)
            if motifmatch.motifname in motifname_to_motiflen:
                assert motiflen==motifname_to_motiflen[motifmatch.motifname]
            else:
                motifname_to_motiflen[motifmatch.motifname] = motiflen
        covered_positions.append(covered_positions_entry)
    covered_positions = np.array(covered_positions)
    assert len(covered_positions)==len(seqs)

    #get a mapping from motiflen to windows with all zeros, i.e. the negatives
    motiflens = sorted(set(motifname_to_motiflen.values()))
    motiflen_to_nullwindowsmask = {}
    for motiflen in motiflens:
        coveredposition_windowsums = get_sum_scores_in_window(
            scores=covered_positions, windowlen=motiflen)
        nullwindowsmask = (coveredposition_windowsums==0.0)
        motiflen_to_nullwindowsmask[motiflen] = nullwindowsmask
        print("Number of null windows for length",motiflen, np.sum(nullwindowsmask))

    return (motifmatches_in_seqs, motifname_to_hitlocations,
            motifname_to_motiflen, motiflen_to_nullwindowsmask)


def compute_motif_scores(method_to_scores, motifname_to_hitlocations,
                         motifname_to_motiflen, motiflen_to_nullwindowsmask):
    """Score how well each method separates motif hits from motif-free windows.

    For every motif and method, the window sum of the method's scores at
    each hit location is a positive example and the window sums at all
    motif-free windows of the same length are the negatives; auROC and
    auPRC are computed over that combined set.

    Args:
        method_to_scores: dict method name -> 2-D score array
            (sequences x positions).
        motifname_to_hitlocations: dict motif name -> list of
            (sequence index, start) tuples.
        motifname_to_motiflen: dict motif name -> motif length.
        motiflen_to_nullwindowsmask: dict motif length -> boolean array of
            the motif-free windows of that length.

    Returns:
        (motifname_to_method_to_hitscores, motifname_to_method_to_auroc,
         motifname_to_method_to_auprc, motifname_to_numhits,
         motifname_to_baselineauprc). The baseline auPRC is the fraction of
        positives among all scored examples, hits / (hits + null windows),
        i.e. the expected average precision of a random ranking.
    """
    motiflen_to_motifnames = defaultdict(list)
    for motifname, motiflen in motifname_to_motiflen.items():
        motiflen_to_motifnames[motiflen].append(motifname)
    motifname_to_method_to_hitscores = defaultdict(dict)
    motifname_to_method_to_auroc = defaultdict(dict)
    motifname_to_method_to_auprc = defaultdict(dict)
    motifname_to_numhits = {}
    motifname_to_baselineauprc = {}
    for motiflen in sorted(motiflen_to_motifnames.keys()):
        print("Doing motifs of length",motiflen)
        sys.stdout.flush()
        nullwindowsmask = motiflen_to_nullwindowsmask[motiflen]
        num_nullwindows = int(np.sum(nullwindowsmask))
        # Hit counts and baselines do not depend on the scoring method, so
        # they are computed once per motif rather than once per method.
        for motifname in motiflen_to_motifnames[motiflen]:
            numhits = len(motifname_to_hitlocations[motifname])
            motifname_to_numhits[motifname] = numhits
            # Prevalence of positives in the evaluated set; the denominator
            # must include the hits themselves, not only the null windows.
            motifname_to_baselineauprc[motifname] = (
                float(numhits)/(numhits + num_nullwindows))
        for method in sorted(method_to_scores.keys()):
            print("Method",method)
            sys.stdout.flush()
            window_scores = get_sum_scores_in_window(
                             scores=method_to_scores[method], windowlen=motiflen)
            assert nullwindowsmask.shape==window_scores.shape
            nullwindowscores = window_scores[nullwindowsmask]
            for motifname in motiflen_to_motifnames[motiflen]:
                hitlocations = motifname_to_hitlocations[motifname]
                # zip(*...) turns [(row, col), ...] into (rows, cols) so all
                # hit windows are gathered with one fancy-index lookup.
                hitscores = [float(x) for x in
                             window_scores[tuple(zip(*hitlocations))]]
                motifname_to_method_to_hitscores[motifname][method] = hitscores
                y_true = np.concatenate([np.ones(len(hitscores), dtype=int),
                                         np.zeros(len(nullwindowscores), dtype=int)])
                y_score = np.concatenate([np.asarray(hitscores, dtype=float),
                                          nullwindowscores])
                auroc = roc_auc_score(y_true=y_true, y_score=y_score)
                auprc = average_precision_score(y_true=y_true, y_score=y_score)
                motifname_to_method_to_auroc[motifname][method] = auroc
                motifname_to_method_to_auprc[motifname][method] = auprc
    return (motifname_to_method_to_hitscores,
            motifname_to_method_to_auroc, motifname_to_method_to_auprc,
            motifname_to_numhits, motifname_to_baselineauprc)


In [253]:
import json
celltypes = ['HepG2', 'H1ESC', 'A549', 'GM12878']

# For each cell type: load the scores of every method on the common
# sequences, locate motif hits and motif-free windows, compute per-motif
# auROC/auPRC, and write everything to <celltype>_motifscoring_results.json.
for celltype in celltypes:
    if celltype not in celltype_to_corefiles:
        # No input files are configured for this cell type; skip it instead
        # of aborting the loop with a KeyError.
        print("\n\nSKIPPING",celltype,"- no entry in celltype_to_corefiles")
        continue
    print("\n\nON",celltype)
    #get the scores for the different methods for those common sequences
    corefiles = celltype_to_corefiles[celltype]
    # NOTE: `seqs` must keep this name at notebook level;
    # get_motifmatches_and_nullwindows_mask reads it as a global.
    method_to_scores, onehot_seqs, seqs, seqnames =\
        get_scores_for_common_sequences(corefiles=corefiles)
    (motifmatches_in_seqs, motifname_to_hitlocations,
     motifname_to_motiflen, motiflen_to_nullwindowsmask) =\
        get_motifmatches_and_nullwindows_mask(corefiles=corefiles,
                                              seqnames=seqnames)
    (motifname_to_method_to_hitscores,
     motifname_to_method_to_auroc,
     motifname_to_method_to_auprc,
     motifname_to_numhits,
     motifname_to_baselineauprc) = compute_motif_scores(
        method_to_scores=method_to_scores,
        motifname_to_hitlocations=motifname_to_hitlocations,
        motifname_to_motiflen=motifname_to_motiflen,
        motiflen_to_nullwindowsmask=motiflen_to_nullwindowsmask)

    #save things to json
    # `with` flushes and closes the file before the next iteration starts.
    with open(celltype+"_motifscoring_results.json",'w') as results_outf:
        json.dump(
            {'motifname_to_hitlocations': motifname_to_hitlocations,
             'motifname_to_method_to_hitscores': motifname_to_method_to_hitscores,
             'motifname_to_method_to_auroc': motifname_to_method_to_auroc,
             'motifname_to_method_to_auprc': motifname_to_method_to_auprc,
             'motifname_to_numhits': motifname_to_numhits,
             'motifname_to_baselineauprc': motifname_to_baselineauprc},
            results_outf)



ON HepG2
Number of common seqnames: 1000
Reading in motif file
Read motif file
Number of null windows for length 8 150451
Number of null windows for length 10 133560
Number of null windows for length 12 118663
Doing motifs of length 8
Method deeplift-RC_multiref10
Method deeplift-RS_multiref10
Method gradtimesinp
Method ig10_multiref10
Method ism
Doing motifs of length 10
Method deeplift-RC_multiref10
Method deeplift-RS_multiref10
Method gradtimesinp
Method ig10_multiref10
Method ism
Doing motifs of length 12
Method deeplift-RC_multiref10
Method deeplift-RS_multiref10
Method gradtimesinp
Method ig10_multiref10
Method ism


ON H1ESC
Number of common seqnames: 1000
Reading in motif file
Read motif file
Number of null windows for length 8 129007
Number of null windows for length 10 113198
Number of null windows for length 12 99353
Doing motifs of length 8
Method deeplift-RC_multiref10
Method deeplift-RS_gcref
Method deeplift-RS_multiref10
Method deeplift-RS_zeroref
Method gradtimesinp
M

In [254]:
# Compress every .json file in the working directory (not only the results
# written by this notebook); -f overwrites any existing .json.gz.
!gzip -f *.json

In [281]:
import gzip
import json
import os

# Turn each cell type's gzipped JSON results into two tab-separated tables,
# <celltype>_motifaurocs.tsv and <celltype>_motifauprcs.tsv, with one row per
# motif (most hits first) and one column per method.
for celltype in celltypes:
    results_path = celltype+"_motifscoring_results.json.gz"
    if not os.path.exists(results_path):
        # Nothing was computed for this cell type; report and move on.
        print("Skipping",celltype,"- missing",results_path)
        continue
    with gzip.open(results_path) as results_inf:
        dicts = json.loads(results_inf.read())
    motifname_to_method_to_auroc = dicts['motifname_to_method_to_auroc']
    motifname_to_method_to_auprc = dicts['motifname_to_method_to_auprc']
    motifname_to_numhits = dicts['motifname_to_numhits']
    motifname_to_baselineauprc = dicts['motifname_to_baselineauprc']

    motifnames = [x[0] for x in sorted(motifname_to_numhits.items(), key=lambda x: -x[1])]
    # The column set is taken from the first motif; every motif is expected
    # to have been scored with the same methods.
    methods = sorted(motifname_to_method_to_auroc[motifnames[0]])
    # Both outputs are closed on exit, even if a row fails to format.
    with open(celltype+"_motifaurocs.tsv",'w') as auroc_outf, \
         open(celltype+"_motifauprcs.tsv",'w') as auprc_outf:
        auroc_outf.write("Motifname\tnumhits\t"+"\t".join(methods)+'\n')
        auprc_outf.write("Motifname\tnumhits\tbaselineauprc\t"+"\t".join(methods)+'\n')
        for motifname in motifnames:
            numhits = str(motifname_to_numhits[motifname])
            baselineauprc = motifname_to_baselineauprc[motifname]
            auroc_outf.write(motifname
                       +"\t"+numhits
                       +"\t"+"\t".join(str(motifname_to_method_to_auroc[motifname][method])
                                           for method in methods)+"\n")
            auprc_outf.write(motifname
                       +"\t"+numhits
                       +"\t"+str(baselineauprc)
                       +"\t"+"\t".join(str(motifname_to_method_to_auprc[motifname][method])
                                           for method in methods)+"\n")
    

In [283]:
# Preview the first 50 lines (header + top motifs by hit count) of the A549 auROC table.
!head -50 A549_motifaurocs.tsv

Motifname	numhits	deeplift-RC_multiref10	deeplift-RS_gcref	deeplift-RS_multiref10	deeplift-RS_zeroref	gradtimesinp	ig10_gcref	ig10_multiref10	ig10_zeroref	ism
TSGGAAMG	1775	0.7223725930133899	0.7098577461392888	0.7243529039077922	0.7099183299431487	0.6617983596131259	0.7006741837724241	0.6771241855431143	0.7012871666675561	0.7016636649019596
CAGTCATK	1649	0.8355876858472063	0.8046332755702159	0.8367236132898852	0.8056787620251955	0.7911343832263643	0.8167531282286272	0.824927417342391	0.8206458943219231	0.8214070887361671
AAAACGMG	1387	0.6960827287229601	0.7034973875909944	0.6977253688720721	0.7058338855188802	0.62364627914768	0.7014662304940966	0.6818079151211244	0.7039348084400191	0.7147709325422422
SCYYTARR	1287	0.7810942755144181	0.769323517961275	0.7828454537613081	0.7695276770363739	0.7239797566047006	0.7785331114921581	0.7747514998539694	0.7716690096051513	0.7752879603738971
VACWTTCC	1244	0.788982179387177	0.7570531033374485	0.7911435850301403	0.763792933598235	0.7316543360