In [1]:
from collections import OrderedDict
import gzip
import numpy as np
def load_sequences_from_bedfile(seqfile):
    """Load sequences and their annotations from a gzipped, whitespace-delimited file.

    Every non-blank line is split on whitespace and must contain at least two
    fields: field 0 is the sequence name and field 1 is the sequence. If a
    third field is present it is stored as the annotation for that name
    (presumably the embedded motifs, judging by the variable name); otherwise
    the annotation is the empty string. Blank lines are skipped.

    Parameters
    ----------
    seqfile : str
        Path to the gzip-compressed input file.

    Returns
    -------
    tuple (numpy.ndarray, OrderedDict)
        The sequences in file order, and a mapping from sequence name to its
        annotation string. A repeated name overwrites the earlier mapping
        entry, while the sequence array still keeps one item per line.

    Raises
    ------
    IndexError
        If a non-blank line has fewer than two fields.
    """
    seqs = []
    seqname_to_motifs = OrderedDict()
    # The context manager closes the gzip handle even when a line fails to
    # decode or parse, which the previous manual close() did not guarantee.
    with gzip.open(seqfile, "rb") as fp:
        print("#Loading " + seqfile + " ...")
        for raw_line in fp:
            fields = raw_line.decode('utf8').split()
            if not fields:
                # Blank line (e.g. a trailing newline): nothing to record.
                continue
            seqs.append(fields[1])
            seqname_to_motifs[fields[0]] = fields[2] if len(fields) > 2 else ""
    print("#Loaded " + str(len(seqs)) + " sequences from " + seqfile)
    return np.array(seqs), seqname_to_motifs

In [3]:
# Paths to the gzipped positive and negative test sequence files.
data_filename_test_positive = "/users/eprakash/git/interpret-benchmark/data/dnase_positives/common_scripts/H1/sequences/test_sim_positives.txt.gz"
data_filename_test_negative = "/users/eprakash/git/interpret-benchmark/data/dnase_positives/common_scripts/H1/sequences/test_sim_negatives.txt.gz"

# Load positives first, then negatives; each call prints its own progress lines.
(pos_seqs, pos_seqname_to_motifs) = load_sequences_from_bedfile(seqfile=data_filename_test_positive)
(neg_seqs, neg_seqname_to_motifs) = load_sequences_from_bedfile(seqfile=data_filename_test_negative)
# Keep only as many negatives as there are positives, then place the
# positives ahead of the negatives in the combined array.
neg_seqs = neg_seqs[:pos_seqs.shape[0]]
seqs = np.concatenate([pos_seqs, neg_seqs])

#There were technically more negatives than positives in both the initial training and test sets.
#However, the momma dragonn TwoStream data loader makes sure to only use as many negative sequences
#as positive sequences if negatives_to_positives_ratio=1 (which it was). Thus, only the first len(pos_seqs)
#neg_seqs were used in training/testing.

#Loading /users/eprakash/git/interpret-benchmark/data/dnase_positives/common_scripts/H1/sequences/test_sim_positives.txt.gz ...
#Loaded 17312 sequences from /users/eprakash/git/interpret-benchmark/data/dnase_positives/common_scripts/H1/sequences/test_sim_positives.txt.gz
#Loading /users/eprakash/git/interpret-benchmark/data/dnase_positives/common_scripts/H1/sequences/test_sim_negatives.txt.gz ...
#Loaded 182795 sequences from /users/eprakash/git/interpret-benchmark/data/dnase_positives/common_scripts/H1/sequences/test_sim_negatives.txt.gz


In [4]:
def one_hot_encode_along_channel_axis(sequence):
    """Return a new int8 array of shape (len(sequence), 4) one-hot encoding ``sequence``.

    Rows are sequence positions and the four columns are the channels filled
    in by ``seq_to_one_hot_fill_in_array`` with ``one_hot_axis=1``.
    """
    encoded = np.zeros(shape=(len(sequence), 4), dtype=np.int8)
    # The helper writes into `encoded` in place and returns nothing.
    seq_to_one_hot_fill_in_array(
        zeros_array=encoded,
        sequence=sequence,
        one_hot_axis=1,
    )
    return encoded

def seq_to_one_hot_fill_in_array(zeros_array, sequence, one_hot_axis):
    """Write a one-hot encoding of ``sequence`` into ``zeros_array`` in place.

    A/C/G/T (upper or lower case) set channel index 0/1/2/3 to 1 at the
    corresponding position; N/n leaves that position untouched. Nothing is
    returned.

    Parameters:
        zeros_array: 2-D array indexed as [channel, position] when
            ``one_hot_axis`` is 0 and as [position, channel] when it is 1.
            Its size along the position axis must equal ``len(sequence)``.
        sequence: iterable of single characters to encode.
        one_hot_axis: 0 or 1; the axis of ``zeros_array`` holding the channels.

    Raises:
        AssertionError: if ``one_hot_axis`` is not 0 or 1, or the position
            axis length does not match ``len(sequence)``.
        RuntimeError: if ``sequence`` contains a character other than
            A, C, G, T or N (either case).
    """
    assert one_hot_axis in (0, 1)
    # The positions live on whichever axis is not the channel axis.
    position_axis = 1 if one_hot_axis == 0 else 0
    assert zeros_array.shape[position_axis] == len(sequence)

    channel_of_base = {
        "A": 0, "a": 0,
        "C": 1, "c": 1,
        "G": 2, "g": 2,
        "T": 3, "t": 3,
    }
    for position, base in enumerate(sequence):
        channel = channel_of_base.get(base)
        if channel is None:
            if base in ("N", "n"):
                # Unknown base: keep this position as all zeros.
                continue
            raise RuntimeError("Unsupported character: "+str(base))
        if one_hot_axis == 0:
            zeros_array[channel, position] = 1
        else:
            zeros_array[position, channel] = 1
# One-hot encode every sequence in `seqs` and collect the per-sequence arrays
# into a single array (a regular 3-D array only if all sequences have the
# same length -- TODO confirm for this dataset).
onehot_data = np.array(list(map(one_hot_encode_along_channel_axis, seqs)))

In [10]:
import keras
from keras.models import model_from_json
import tensorflow as tf

# Saved model artifacts: the weights (HDF5) and the architecture (JSON).
model_weights = "/users/eprakash/git/interpret-benchmark/data/dnase_positives/momma_dragonn_config/no_preinitialization/dense/H1/temp/model_files/record_1_model_BHYtD_modelWeights.h5"
model_json = "/users/eprakash/git/interpret-benchmark/data/dnase_positives/momma_dragonn_config/no_preinitialization/dense/H1/temp/model_files/record_1_model_BHYtD_modelJson.json"

# Rebuild the architecture from its JSON description, then load the weights.
# The with-block closes the JSON file handle as soon as it has been read,
# instead of leaving it open until garbage collection.
with open(model_json) as model_json_file:
    model = model_from_json(model_json_file.read())
model.load_weights(model_weights)

In [11]:
# Run the loaded model on the one-hot encoded sequences (positives first,
# then negatives, following the order of `seqs`).
# NOTE(review): presumably returns one score per sequence -- confirm the
# output shape against the model definition.
preds = model.predict(onehot_data)

In [14]:
from sklearn.metrics import confusion_matrix
# Ground-truth labels in the same order as `seqs`: a 1 for every positive
# sequence followed by a 0 for every (retained) negative sequence.
labels = np.concatenate([np.full(pos_seqs.shape, 1), np.full(neg_seqs.shape, 0)])
# Threshold the model outputs at 0.5 and tabulate them against the labels.
print(confusion_matrix(y_true=labels, y_pred=(preds > 0.5)))

[[10677  6635]
 [ 8977  8335]]
