In [1]:
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
import h5py
import simdna.synthetic as synthetic
from sklearn.metrics import (
    roc_auc_score, average_precision_score,
    roc_curve, precision_recall_curve)
import re
from collections import OrderedDict, defaultdict
def load_motif_matches(motif_match_file, doprint=False):
        """
        Loads a homer motif match file into an ordered dictionary with key as se
quence name
        and value as list of dictionaries each containing the keys - motif, sequ
ence,
        begin (0-indexed inclusive begin index of motif), end (0-indexed exclusi
ve end index),
        strand (+ or -), seqval. Each dictionary
        represents one motif match on that sequence
        """
        motif_matches = OrderedDict()
        fp = open(motif_match_file, "r")
        if doprint:
                print("#Loading " + motif_match_file + " ...")
        numlines = 0
        for line in fp:
                match = re.match("((\w|\-)+)\s+((\w|\:|\-)+)\s+(\d+)\s+(\d+)\s+(\+|\-)\s+.+\s+(\w+)$", line)
                if match:
                        numlines = numlines + 1
                        motif = match.group(1)
                        sequence = match.group(3)
                        begin = int(match.group(5))
                        end = int(match.group(6))
                        strand = match.group(7)
                        seqval = match.group(8)
                        entry = dict()
                        entry['motif'] = motif
                        entry['sequence'] = sequence
                        entry['begin'] = begin-1 # Homer motif match file is 1 indexed, convert to 0
                        entry['end'] = end # Homer motif match file is 1 indexed AND inclusive, convert to 0 and exclusive
                        entry['strand'] = strand
                        entry['seqval'] = seqval
                        if sequence not in motif_matches:
                                motif_matches[sequence] = list()
                        motif_matches[sequence].append(entry)
        fp.close()
        if doprint:
                print("#Loaded " + str(numlines) + " motif matches in " + str(len(motif_matches.keys())) + " sequences")
        return motif_matches

In [2]:
def rename(label):
    match=re.match('.*_(chr.*)$',label)
    if match:
        return match.group(1)
    else:
        return ""

In [3]:
def get_relevant_labels_in_order_of_scores(labels, motif_matches):
    relevant_labels_list=[]
    relevant_indices_list=[]
    sequence_index=0
    positive_labels=[]
    for label in motif_matches.keys():
        positive_labels.append(label)
    positive_labels_set = set(positive_labels)
    for sequence_label in labels:
        if sequence_label in positive_labels_set:
            if sequence_index==731:
                print(sequence_index)
            relevant_indices_list.append(sequence_index)
            relevant_labels_list.append(sequence_label)
        sequence_index=sequence_index+1
    print (len(relevant_indices_list))
    print (len(relevant_labels_list))
    return relevant_indices_list, relevant_labels_list

In [4]:
def get_relevant_scores(relevant_indices_list, scores, seq_len=400):
    relevant_scores=np.zeros((len(relevant_indices_list),seq_len))
    index=0
    for scores_index in relevant_indices_list:
        relevant_scores[index]=scores[scores_index]
        index=index+1
    return relevant_scores

In [5]:
def load_labels(labels_file):
    f = open(labels_file, "r")
    l = f.read().splitlines()
    print ("Read " + str(len(l)) + " labels from " + str(labels_file))
    f.close()
    return l

In [6]:
scores_file='/users/eprakash/projects/benchmarking/newdata/A549/models/deepseabeluga/results/A549.deepseabeluga.ISM.scores.5Ksubsample.npy'
initial_labels=load_labels("/users/eprakash/projects/benchmarking/newdata/A549/models/deepseabeluga/results/ISM_deepseabeluga_A549_positive.labels.txt")
size=len(initial_labels)
labels=np.empty(size,dtype=object)
#fpointer = open("/users/eprakash/projects/benchmarking/newdata/K562/K562_pos_labels_truncated.txt", "w")
for index in range(0, size):
    labels[index]=rename(initial_labels[index])
    #fpointer.write(str(labels[index]) + "\n")
#Sanity check
#fpointer.close()
print (labels.shape)
#print("Labels")
#print(labels)

# TAKE THE ABSOLUTE VALUE OF THE ISM SCORE!"
original_ism_scores=np.abs(np.load(scores_file))
print(original_ism_scores.shape)
original_ism_scores=np.sum(original_ism_scores, axis=2)
print(original_ism_scores.shape)

seq_len=original_ism_scores.shape[1]
motif_matches=load_motif_matches('/users/eprakash/projects/benchmarking/newdata/A549/A549.motif.matches.txt', True)
initial_seq_ids_of_interest = motif_matches.keys()
#print("Initial seq ids of interest")
#print(initial_seq_ids_of_interest)
seq_ids_of_interest=initial_seq_ids_of_interest
#seq_ids_of_interest=[]
#for id in initial_seq_ids_of_interest:
#    seq_ids_of_interest.append(rename(id))
print(len(seq_ids_of_interest))

Read 5000 labels from /users/eprakash/projects/benchmarking/newdata/A549/models/deepseabeluga/results/ISM_deepseabeluga_A549_positive.labels.txt
(5000,)
(5000, 400, 4)
(5000, 400)
#Loading /users/eprakash/projects/benchmarking/newdata/A549/A549.motif.matches.txt ...
#Loaded 4599731 motif matches in 143217 sequences
143217


In [7]:
motif_keyz = set(motif_matches.keys())
print(len(motif_keyz))
labels_keyz = set(labels)
print(len(labels_keyz))
print((motif_keyz - labels_keyz))
print((labels_keyz - motif_keyz))

143217
5000
set(['chr16:4517753-4518153', 'chr2:218964041-218964441', 'chr2:85772302-85772702', 'chr8:60396882-60397282', 'chr8:67339504-67339904', 'chr15:92904195-92904595', 'chr3:171055997-171056397', 'chr12:111767048-111767448', 'chr20:32465526-32465926', 'chr1:214505073-214505473', 'chr5:116137816-116138216', 'chr19:2168810-2169210', 'chrX:40580357-40580757', 'chr4:107862661-107863061', 'chr16:11550745-11551145', 'chr15:40929062-40929462', 'chr12:55728294-55728694', 'chr7:115117339-115117739', 'chr22:39152902-39153302', 'chr1:32336115-32336515', 'chr18:25996374-25996774', 'chr7:134368711-134369111', 'chr6:165556462-165556862', 'chr3:81676728-81677128', 'chr1:173330371-173330771', 'chr6:27214940-27215340', 'chr10:86400643-86401043', 'chr1:230113669-230114069', 'chr4:96293771-96294171', 'chr12:60920816-60921216', 'chr7:51315637-51316037', 'chr8:39001234-39001634', 'chr1:21022543-21022943', 'chr12:119804318-119804718', 'chrX:48696813-48697213', 'chr20:382137-382537', 'chr1:202090961-2

In [8]:
seq_ids_of_interest_set = set(seq_ids_of_interest)
print(len(seq_ids_of_interest))
print(len(seq_ids_of_interest_set))
relevant_indices_list, relevant_labels_list=get_relevant_labels_in_order_of_scores(labels, motif_matches)
ism_scores=get_relevant_scores(relevant_indices_list, original_ism_scores, seq_len)
#reveal_cancel_scores=get_relevant_scores(relevant_indices_list, original_reveal_cancel_scores, seq_len)
#guided_backprop_scores=get_relevant_scores(relevant_indices_list, original_guided_backprop_scores, seq_len)
print (ism_scores.shape)

143217
143217
731
5000
5000
(5000, 400)


In [9]:
print(len(seq_ids_of_interest_set))
relevant_labels_set = set(relevant_labels_list)
print(len(relevant_labels_set))
print(seq_ids_of_interest_set - relevant_labels_set)
print(relevant_labels_set - seq_ids_of_interest_set)
seq_ids_of_interest_set = relevant_labels_set
removables = list()
for i in range(len(seq_ids_of_interest)):
    if seq_ids_of_interest[i] not in relevant_labels_set:
        removables.append(i)
removables.reverse()
for i in removables:
    del seq_ids_of_interest[i]
print(len(seq_ids_of_interest))

143217
5000
set(['chr16:4517753-4518153', 'chr2:218964041-218964441', 'chr2:85772302-85772702', 'chr8:60396882-60397282', 'chr8:67339504-67339904', 'chr15:92904195-92904595', 'chr3:171055997-171056397', 'chr12:111767048-111767448', 'chr20:32465526-32465926', 'chr1:214505073-214505473', 'chr5:116137816-116138216', 'chr19:2168810-2169210', 'chrX:40580357-40580757', 'chr4:107862661-107863061', 'chr16:11550745-11551145', 'chr15:40929062-40929462', 'chr12:55728294-55728694', 'chr7:115117339-115117739', 'chr22:39152902-39153302', 'chr1:32336115-32336515', 'chr18:25996374-25996774', 'chr7:134368711-134369111', 'chr6:165556462-165556862', 'chr3:81676728-81677128', 'chr1:173330371-173330771', 'chr6:27214940-27215340', 'chr10:86400643-86401043', 'chr1:230113669-230114069', 'chr4:96293771-96294171', 'chr12:60920816-60921216', 'chr7:51315637-51316037', 'chr8:39001234-39001634', 'chr1:21022543-21022943', 'chr12:119804318-119804718', 'chrX:48696813-48697213', 'chr20:382137-382537', 'chr1:202090961-2

In [10]:
import re
method_to_saved_scores = OrderedDict([
    ('ism', ism_scores)
])
method_to_seq_id_to_scores = {}
for method_name in method_to_saved_scores:
    scores = method_to_saved_scores[method_name]
    if (method_name=='integrated_gradients'):
        assert(len(ig_seq_ids_of_interest)==len(scores))
        seq_id_to_scores = dict(zip(ig_seq_ids_of_interest,scores))
    else:
        print(len(seq_ids_of_interest))
        print(len(scores))
        assert(len(seq_ids_of_interest)==len(scores))
        seq_id_to_scores = dict(zip(relevant_labels_list,scores))
    method_to_seq_id_to_scores[method_name] = seq_id_to_scores


#covered_positions marks a position as a 1 if it overlaps
# an embedded motif
seq_id_to_covered_positions = {}
motif_id_to_hit_locations = defaultdict(list)
motif_id_to_motif_length = {}
#For each motif, we need a mapping from the motif
# to the positives locations
for seq_id in motif_matches:
    original_seq_id=seq_id
    #seq_id=rename(seq_id)
    if (method_name=='integrated_gradients'):
        if (seq_id in ig_seq_ids_of_interest):
            embedded_positions = np.zeros(seq_len)
            for embedding in motif_matches[original_seq_id]:
                the_seq=embedding['sequence']
                motif_start_loc = embedding['begin'] 
                motif_end_loc = embedding['end']
                motif_len = motif_end_loc-motif_start_loc
                embedded_positions[motif_start_loc:motif_end_loc] = 1.0
                motif_name = re.match('(\d+)-(\w+)-(\d+)',embedding['motif']).group(2)
                if (motif_name in motif_id_to_motif_length):
                    assert (motif_id_to_motif_length[motif_name]==motif_len)
                else:
                    motif_id_to_motif_length[motif_name] = motif_len
                motif_id_to_hit_locations[motif_name].append(
                    (seq_id,motif_start_loc))
            seq_id_to_covered_positions[seq_id] = embedded_positions 
    else:
        if (seq_id in seq_ids_of_interest_set):
            embedded_positions = np.zeros(seq_len)
            for embedding in motif_matches[original_seq_id]:
                the_seq=embedding['sequence']
                motif_start_loc = embedding['begin'] 
                motif_end_loc = embedding['end']
                motif_len = motif_end_loc-motif_start_loc
                embedded_positions[motif_start_loc:motif_end_loc] = 1.0
                motif_name = re.match('(\d+)-(\w+)-(\d+)',embedding['motif']).group(2)
                if (motif_name in motif_id_to_motif_length):
                    assert (motif_id_to_motif_length[motif_name]==motif_len)
                else:
                    motif_id_to_motif_length[motif_name] = motif_len
                motif_id_to_hit_locations[motif_name].append(
                    (seq_id,motif_start_loc))
        #print(embedded_positions)
            seq_id_to_covered_positions[seq_id] = embedded_positions 

total_motif_bases = 0
total_bases = 0
for (seqname, arr) in seq_id_to_covered_positions.items():
    total_motif_bases = total_motif_bases + np.sum(arr)
    total_bases = total_bases +  len(arr)
print("Motif positions: " + str(total_motif_bases) + ", total positions: " + str(total_bases))

#find windows of a given length that do not overlap any motif
motif_len_to_negatives = defaultdict(list)                            
for motif_len in set(motif_id_to_motif_length.values()):
    for seq_id,covered_positions in seq_id_to_covered_positions.items():
        cumsum = np.array([0]+list(np.cumsum(covered_positions)))
        window_sums = cumsum[motif_len:]-cumsum[0:-motif_len]
        null_windows = [(seq_id,x) for x in
                        np.nonzero(window_sums==0)[0]]
        motif_len_to_negatives[motif_len].extend(null_windows)

5000
5000
Motif positions: 825682.0, total positions: 2000000


In [11]:
motif_id_to_pos_locs={}
for motif_id in motif_id_to_hit_locations:
    motif_len = motif_id_to_motif_length[motif_id]
    num_pos_locs = len(motif_id_to_hit_locations[motif_id])
    #num_neg_locs = len(motif_len_to_negatives[motif_len])
    motif_id_to_pos_locs.update({motif_id:float(num_pos_locs)})
print(len(motif_id_to_pos_locs))
top_5_motif_ids=sorted(motif_id_to_pos_locs, key=lambda x: motif_id_to_pos_locs[x])
#['TGAATGAATGAA', 'GGGGCGGGGC', 'AGAGGAAGTG', 'AAAAAAAAAAAA', 'AAAAAAAAAA', 'CTTCCTCT', 'GCCACTGCAC', 'AAAAAAAA', 'GGGTGGGG'] 
#top_5_motif_ids=sorted(motif_id_to_hit_locations, key=lambda x: len(motif_id_to_hit_locations[x]))[-5:]
#rescale_all_scores=np.random.rand(rescale_all_scores.shape[0], rescale_all_scores.shape[1])

75


In [12]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

font = {'family' : 'normal',
        'weight' : 'bold',
        'size'   : 15}

matplotlib.rc('font', **font)

method_name_to_plot_style = {
    'grad_times_input': {'color':'green'},
    'rescale_all': {'color':'orange'},
    'rescale_conv': {'color':'purple'},
    'reveal_cancel': {'color': 'red'},
    'guided_backprop': {'color': 'yellow'},
    'integrated_gradients': {'color': 'blue'}
}

method_name_nicename = {
    'grad_times_input': 'grad_times_input',
    'rescale_all': 'rescale_all',
    'rescale_conv': 'rescale_conv',
    'reveal_cancel': 'reveal_cancel',
    'guided_backprop': 'guided_backprop',
    'integrated_gradients': 'integrated_gradients'
}
motif_info={}
print("Number of motifs is " + str(len(motif_id_to_motif_length.keys())))
for motif_id in sorted(top_5_motif_ids):
    
    for methods_to_plot in [['ism']]:
    #[['grad_times_input', 'rescale_all', 'rescale_conv', 'reveal_cancel', 'guided_backprop', 'integrated_gradients']]:
        
        #print("On motif",motif_id)
        motif_len = motif_id_to_motif_length[motif_id]
        pos_locs = motif_id_to_hit_locations[motif_id]
        neg_locs = motif_len_to_negatives[motif_len]
        #print("num pos",len(pos_locs))
        #print("num neg",len(neg_locs))
        all_locs = list(pos_locs)+list(neg_locs)
        #print(len(motif_matches.items()))
        loc_labels = [1 for x in pos_locs]+[0 for x in neg_locs]

        #f,axarr = plt.subplots(1,2, figsize=(20,5))
        #plt.subplot(122)
        #plt.title(motif_id+' auROC Curves')
        #plt.plot([0, 1], [0, 1], color='black', linestyle='--')
        #plt.xlim([0, 1])
        #plt.ylim([0, 1])
        #plt.ylabel('True Positive Rate')
        #plt.xlabel('False Positive Rate')

        pos_frac = float(len(pos_locs))/(len(pos_locs)+len(neg_locs))
        #plt.subplot(121)
        #plt.title(motif_id+' auPRC Curves')
        #plt.plot([0, 1], [pos_frac, pos_frac], color='black', linestyle='--')
        #plt.xlim([0, 1])
        #plt.ylim([0, 1])
        #plt.ylabel('Precision')
        #plt.xlabel('Recall') 
    
        for method_name in methods_to_plot:
            #print(method_name)
            seq_id_to_windowsums = {}
            seq_id_to_scores=method_to_seq_id_to_scores[method_name]
            for seq_id,scores in seq_id_to_scores.items():
                #cumsum = np.array([0]+list(np.cumsum(np.sum(scores,axis=-1))))
                cumsum = np.array([0]+list(np.cumsum(scores)))
                windowsums = cumsum[motif_len:]-cumsum[:-motif_len]
                seq_id_to_windowsums[seq_id] = windowsums
            loc_scores = [seq_id_to_windowsums[seq_id][pos]
                          for (seq_id, pos) in all_locs]
            auroc = roc_auc_score(y_true=loc_labels,
                                  y_score=loc_scores)
            auprc = average_precision_score(y_true=loc_labels,
                                            y_score=loc_scores)
            #print("auROC: " + str(auroc))
            #print("auPRC: " + str(auprc))
            #print("Hits in pos set: " + str(motif_id_to_pos_locs[motif_id]))
            motif_info.update({motif_id:[auroc,auprc,motif_id_to_pos_locs[motif_id]]})
            """plt.subplot(122)
            fpr, tpr, rocthresholds = roc_curve(
                loc_labels, loc_scores, pos_label=1)
            plt.plot(fpr, tpr,
                     label = method_name_nicename[method_name]+(' AUC = %0.2f' % (100*auroc))+"%",
                     linewidth=2,
                     **method_name_to_plot_style[method_name])
            plt.legend(loc = 'lower right', fontsize='xx-small')

            plt.subplot(121)
            precision, recall, thresholds = precision_recall_curve(
                loc_labels, loc_scores, pos_label=1)
            plt.plot(precision, recall,
                     label = method_name_nicename[method_name]+(' AUC = %0.2f' % (100*auprc))+"%",
                     linewidth=2,
                     **method_name_to_plot_style[method_name])
            plt.legend(loc = 'top right', fontsize='xx-small')

        plt.show()"""

sorted_keys=(sorted(motif_info, key=lambda x: motif_info[x][1]))
for key in sorted_keys[::-1]:
    print(key+": "+str(motif_info[key]))

Number of motifs is 75
CCACYAGRKGGC: [0.9673349535676157, 0.5928244336691737, 1109.0]
CACTAGRGGG: [0.9423573897882155, 0.5124860616364718, 891.0]
ATGACTCA: [0.9604881808090586, 0.4193363364006677, 1682.0]
ATGACTCATC: [0.9620973581649963, 0.38476765414164704, 1631.0]
NNATGASTCATN: [0.9642795089437416, 0.3477527203700382, 1535.0]
TGASTCAB: [0.9160305318845535, 0.33308061558106095, 2566.0]
CCRCTAGGKG: [0.8860380826113686, 0.30047120895261625, 492.0]
CCCCTAGTGGCC: [0.8764853390833651, 0.24667310561093098, 463.0]
STTAVTCABH: [0.8370863368932051, 0.22212986934994597, 3061.0]
CYAGGGGGCGCT: [0.8261289006555468, 0.14216789899299512, 439.0]
CAGTCATK: [0.7196793998913074, 0.10222402032109734, 7031.0]
CTAGCGGC: [0.7113443775900095, 0.1005217699470613, 2093.0]
AGAGGGCGCT: [0.5714221638218304, 0.04576471731568443, 3508.0]
SCYYTARR: [0.5846211977737927, 0.022476745568608486, 5576.0]
TTAATGATTAAC: [0.8256820977357283, 0.022312412229617832, 243.0]
CGCGCBCT: [0.5008574939684869, 0.021344055191996662, 33

# Sanity Check

#### Extract base pair sequences

In [None]:
shuffled_motifs={'SAGATAAV', 'NVCTTTYACA', 'NTRTCTAGCK', 'NNATGASTCATN', 'DCTGTRAAARVN', 'STTCYDGGAA', 'WGAAAAGCAG', 'CGCGGCGC', 'CAGCCAAANV', 'RCCAATCG', 'NBNTTATCTG', 'GGCGSCGSCG', 'CMYCTRGTGG', 'SNAAACMGCH', 'GGGGGAGGGGSN', 'AGGAAACG', 'CAGCCAMA', 'GNTCTCGCGAGA', 'AAAAGCHGNCTN', 'GCCGCTAG', 'AACAGCAG', 'VAGRYGGCGSCN', 'GCCCCGCCCC', 'NBTGTYTAGCTG', 'ACCAGAAG', 'YGTTTCCHNN', 'HRAATGGAAT', 'CCTTCCCC', 'TGTCTAGC', 'GCTTTGCAAA', 'AAAGCCDC', 'AAACCAGA', 'AAAGSCNNNN', 'CCCCNCCC', 'YGTTTACA', 'RMTGACWG', 'RKCWGTAAAR', 'YTTCTGGTTT', 'GCTCCTCC', 'ATGACWDCRN', 'GSGGGGGSGGSG', 'ATGACTCA', 'GCCMYCTRGTGG', 'ACTTCCTB', 'NWGATAANVNNN', 'NNGCGCVKSCGC', 'AGATAGGRNNNN', 'RCTTCCTBYYNN', 'CCGCGCGGCG', 'CTTTCACA', 'NVMGCNVGCG', 'BYTGCTGTTW', 'GTGTGTGT', 'GATGAGTCAT', 'RCCAATSVSNNN', 'GTGACDTC', 'HGTSACTTVD', 'GCTCCWCCCG', 'CSGCGGCG', 'NHGCAGAAAARA', 'NTSTGGCTKBHD', 'CGCCCTCT', 'CTTTCMCAGAAG', 'NDGTCACRTGAC', 'TTTACARWCCYT', 'RGAGGAAGYG', 'NNDGHAGTCACT', 'GWGTGTGTGT', 'NNNARACAGCAR', 'AGCGCGCG', 'NCMVCTCCCYCN', 'NGBCCCGCGVGV', 'GWGTGTGTGTGT', 'DGCTGTTW', 'NGGGYGGRGCSR'}
original_motifs={'AAAGCSDC', 'HGGCCCCGCCCC', 'ACCAATCG', 'NNDACNGCNS', 'TTGGWGAACCTT', 'CCAATCAG', 'CCCCTCCCCCNC', 'ACTTCCTSYTBN', 'CAGCCAAA', 'AGGGTTTGTAAA', 'ASMMAAACAS', 'STTATCWG', 'GGCGCGGGCGCG', 'GAGTGTGTGTGT', 'YTGTCAGB', 'RTGACGTCAYCS', 'TAAAAGCAGGCT', 'CMYCTAGYGG', 'ACTACRNYTCCC', 'GGCGGCGGCG', 'HKCGCGCG', 'CGCAAGATTTAT', 'TGBAAACG', 'AGTGACCTCTAG', 'VTCATGTGAC', 'TGTGAAAG', 'GCCMCGCC', 'DGTGACGTCA', 'CGCCGCCG', 'CACYAGRGGG', 'TGTCTAGC', 'GCWGHMAAAAMV', 'BYTSTGGTTT', 'CGCGCNNGCG', 'ATTATACGCTAA', 'RMTCTCGCGAGA', 'RRCCAATCRG', 'CTAGCGGC', 'CWGATAAGANNN', 'TBTRKCTAGCTV', 'GGGHGGAGCC', 'CTCCGCCC', 'GCCCCGCCCC', 'VTTCYNGGAA', 'ATGACTCA', 'TCGAATGGAATC', 'CYCTGTMAAA', 'KGCKGTTT', 'RASMGGAAGT', 'CTGCGCATGCGC', 'CCGCGCGGCG', 'DACAGCWG', 'NGVTGASTCATC', 'GATGAGTCAT', 'AAGATGGCGGCS', 'RCCAATCAGMDB', 'TGGTCTAGCGGT', 'CTTATCTSNN', 'GRGGRAGT', 'CCCCCCCC', 'GCCMCCTAGTGG', 'CCMCKCCCMC', 'CGCATGCG', 'CGCCCTCT', 'TKAGCATGCT', 'WGATAAVS', 'KVKCGCGVGA', 'RYCAYRTGRYHN', 'NBCTTATCTS', 'GTGACTWC', 'GTRTCTAGCT', 'CTTCCGGT', 'SGWTYGTRAA', 'AGCGCCCCCT'}
print ("Intersects:", shuffled_motifs.intersection(original_motifs))

In [None]:
import gzip
import re
from collections import OrderedDict
def load_sequences_from_bedfile(seqfile):
    seqs = OrderedDict()
    fp = gzip.open(seqfile, "rb")
    print("#Loading " + seqfile + " ...")
    for line in fp:
        (label, sequence)=line.split()
        seqs[label]=sequence
    fp.close()
    print("#Loaded " + str(len(seqs.keys())) + " sequences from " + seqfile)
    return seqs

In [None]:
def rename(label):
    match=re.match('.*_(chr.*)$',label)
    if match:
        return match.group(1)
    else:
        return ""

In [None]:
data_filename_positive = "/users/eprakash/projects/benchmarking/newdata/SPI1_in_K562/SPI1.pos.summits.implanted.bed.gz"
data_filename_negative = "/users/eprakash/projects/benchmarking/newdata/SPI1_in_K562/SPI1.neg.summits.implanted.bed.gz"
labeled_sequences = load_sequences_from_bedfile(data_filename_positive)
neg_seqs = load_sequences_from_bedfile(data_filename_negative)
labeled_sequences.update(neg_seqs)
sequences={}
for (label,seq) in labeled_sequences.items():
    sequences.update({rename(label):seq})

#### Select random 5 sequences for motif

In [None]:
from random import *
motif='CTTCCTTYCT'
motif_seqs=[value[0] for value in motif_id_to_hit_locations[motif]]
random_indices=[randint(0,len(motif_seqs)) for i in range(0,5)]
print (random_indices)
selected_seqs=[motif_seqs[i] for i in random_indices]
print (selected_seqs)

#### Onehot encode

In [None]:
def one_hot_encode_along_channel_axis(sequence):
    to_return = np.zeros((len(sequence),4), dtype=np.int8)
    seq_to_one_hot_fill_in_array(zeros_array=to_return,
                                 sequence=sequence, one_hot_axis=1)
    return to_return

def seq_to_one_hot_fill_in_array(zeros_array, sequence, one_hot_axis):
    assert one_hot_axis==0 or one_hot_axis==1
    if (one_hot_axis==0):
        assert zeros_array.shape[1] == len(sequence)
    elif (one_hot_axis==1): 
        assert zeros_array.shape[0] == len(sequence)
    #will mutate zeros_array
    for (i,char) in enumerate(sequence):
        if (char=="A" or char=="a"):
            char_idx = 0
        elif (char=="C" or char=="c"):
            char_idx = 1
        elif (char=="G" or char=="g"):
            char_idx = 2
        elif (char=="T" or char=="t"):
            char_idx = 3
        elif (char=="N" or char=="n"):
            continue #leave that pos as all 0's
        else:
            raise RuntimeError("Unsupported character: "+str(char))
        if (one_hot_axis==0):
            zeros_array[char_idx,i] = 1
        elif (one_hot_axis==1):
            zeros_array[i,char_idx] = 1
#sequences = sequences[-5000:]            
onehot_data = np.array([one_hot_encode_along_channel_axis(sequences[select_seq]) for select_seq in selected_seqs])

In [None]:
print (onehot_data.shape)

#### Get relevant scores

In [None]:
selected_relevant_indices_list=[list(labels).index(seq) for seq in selected_seqs]
selected_rescale_conv_scores=get_relevant_scores(selected_relevant_indices_list, original_rescale_conv_scores, seq_len)
print (selected_rescale_conv_scores.shape)

In [None]:
def split_sequence_by_motif_name(motif_name, list_of_dicts):
    list_with_motif = list()
    list_without_motif = list()
    for one_motif_occurrence in list_of_dicts:
        full_motif = one_motif_occurrence['motif']
        essential_motif = re.match('(\d+)-(\w+)-(\d+)',full_motif).group(2)
        if essential_motif == motif_name:
            list_with_motif.append(one_motif_occurrence)
        else:
            list_without_motif.append(one_motif_occurrence)
    return (list_with_motif, list_without_motif)

#### Visualization

In [None]:
%matplotlib inline
from deeplift.visualization import viz_sequence
for i in range(0,len(selected_seqs)):
    scores=selected_grad_times_input_scores[i]
    seq_id=selected_seqs[i]
    print(seq_id)
    (list_with_motif, list_without_motif)=split_sequence_by_motif_name(motif, motif_matches[seq_id])
    viz_sequence.plot_weights(onehot_data[i]*scores[:, None], subticks_frequency=10, highlight={'blue':[(seq_info['begin'],seq_info['end']) for seq_info in list_with_motif]})
    viz_sequence.plot_weights(onehot_data[i]*scores[:, None], subticks_frequency=10, highlight={'red':[(seq_info['begin'],seq_info['end']) for seq_info in list_without_motif]})

In [None]:
pwm = np.array([[0.036,0.170,0.033,0.761],[0.001,0.635,0.125,0.239],[0.031,0.128,0.004,0.837],[0.001,0.676,0.166,0.157],
                [0.058,0.113,0.040,0.789],[0.007,0.676,0.207,0.110],[0.007,0.125,0.001,0.867],[0.001,0.660,0.206,0.133],
                [0.005,0.186,0.001,0.808],[0.001,0.607,0.241,0.151],[0.054,0.231,0.006,0.709],[0.025,0.540,0.217,0.218]])
print (pwm.shape)

In [None]:
%matplotlib inline
from deeplift.visualization import viz_sequence
viz_sequence.plot_weights(pwm, subticks_frequency=10)