In [1]:
from __future__ import print_function, division
%matplotlib inline

try:
    reload  # Python 2.7
except NameError:
    try:
        from importlib import reload  # Python 3.4+
    except ImportError:
        from imp import reload  # Python 3.0 - 3.3

In [2]:
import numpy as np
import modisco

onehot_data = np.load('extracted_onehot.npy')[:,:]
hypothetical_data = np.load('extracted_hypothetical_scores.npy')[:,:]
meannorm_hypothetical_data = (hypothetical_data
                              -np.mean(hypothetical_data, axis=-1)[:,:,None])
contrib_data = np.load('extracted_contrib_scores.npy')[:,:]
#perposimp_hypmeannorm = np.sum(meannorm_hypothetical_data*onehot_data,axis=-1)
perposimp = np.sum(contrib_data*onehot_data,axis=-1)

flanksize=20
contrib_scores_track = modisco.core.DataTrack(name="contrib_scores",
                                 fwd_tracks=contrib_data,
                                 rev_tracks=contrib_data[:,::-1, ::-1])
hypcontrib_scores_track = modisco.core.DataTrack(name="hypcontrib_scores",
                                    fwd_tracks=hypothetical_data,
                                    rev_tracks=hypothetical_data[:,::-1, ::-1])
meannorm_hypcontrib_scores_track = modisco.core.DataTrack(name="meannorm_hypcontrib_scores",
                                    fwd_tracks=meannorm_hypothetical_data,
                                    rev_tracks=meannorm_hypothetical_data[:,::-1, ::-1])
onehot_track = modisco.core.DataTrack(
                 name="onehot", fwd_tracks=onehot_data,
                 rev_tracks=onehot_data[:,::-1, ::-1])
track_set = modisco.core.DataTrackSet(data_tracks=[
                contrib_scores_track, hypcontrib_scores_track,
                meannorm_hypcontrib_scores_track, onehot_track
               ])

coords = []
for example_idx in range(len(onehot_data)):
    #figure out the 6bp window with highest imp
    padded_cumsum = np.array([0]+list(np.cumsum(perposimp[example_idx][50:50+41])))
    sliding_window_imp = padded_cumsum[7:] - padded_cumsum[:-7]
    start = np.argmax(sliding_window_imp)+50
    coords.append(modisco.core.Coordinate(
                    example_idx=example_idx, start=start-17,
                    end=start+24, is_revcomp=False))

all_seqlets = [x for x in track_set.create_seqlets(coords=coords, flanks=30)
               if np.sum(x["contrib_scores"].corefwd) > 0]




In [3]:
n_seqlets_in_subsample = 5000 #len(all_seqlets)#40000
seqlets_subsample = [all_seqlets[i] for i in np.random.RandomState(1).choice(
                     np.arange(len(all_seqlets)), size=n_seqlets_in_subsample,replace=False)]

## Initial test using just correlation

In [4]:
#The organization of featvec ends up being as follows:
# featvec has first dim of (len(seqlets)*(max_shift*2 + 1)*2); the *2 at end is for revcomp
# second dim of (len(score_track_names*4)*(2*max_shift + corelen))
# corelen is 21
def get_feature_vectors(seqlets, scoretracknames, max_shift):
    possible_shifts = list(range(-max_shift, max_shift+1))
    feature_vectors = []
    for seqlet in seqlets:
        for is_revcomp in [False, True]:
            #to minimize memory overhead, would be good to store the different views as slices
            # into a single array, rather than as completely separate arrays.
            tracks = [seqlet[scoretrackname].get_core_with_flank(left=max_shift, right=max_shift,
                                                                 is_revcomp=is_revcomp)
                      for scoretrackname in scoretracknames]
            concatenated_tracks = np.concatenate(tracks, axis=-1)
            assert len(concatenated_tracks.shape)==2
            corelen = concatenated_tracks.shape[0]-2*max_shift
            for shift in possible_shifts:
                slice_start = max_shift+shift
                feature_vectors.append(concatenated_tracks[slice_start:slice_start+corelen])
    return np.array(feature_vectors).reshape((len(feature_vectors),-1))

featvec = get_feature_vectors(seqlets=seqlets_subsample,
                                  scoretracknames=["contrib_scores", "hypcontrib_scores"],
                                  max_shift=10)

In [None]:
#get a correlation-based similarity using pynndescent
from pynndescent import NNDescent
import time
n_neighbors = 20
start = time.time()
nnd = NNDescent(data=featvec,
                n_neighbors=n_neighbors,
                metric="correlation",
                metric_kwds={},
                random_state=1234,
                max_candidates=60, #value used in UMAP: https://github.com/lmcinnes/umap/blob/9f66cafdef9c666082b2da188c2ae9bff60bc763/umap/umap_.py#L297
                verbose=True,
                tree_init=True,
                n_jobs=4)
knn_indices, knn_dists = nnd.neighbor_graph
print("Took:",time.time()-start,"s")

In [95]:
from joblib import Parallel, delayed


def remap_knn_result_single_ex(knn_indices_for_views_of_ex,
                               knn_dists_for_views_of_ex,
                               views_per_ex,
                               topn_neighb_to_keep):
    flat_knn_indices_for_views_of_ex = knn_indices_for_views_of_ex.ravel()
    flat_knn_dists_for_views_of_ex = knn_dists_for_views_of_ex.ravel()
    
    dists_argsort = np.argsort(flat_knn_dists_for_views_of_ex)
    remapped_neighbor_indices = (flat_knn_indices_for_views_of_ex/views_per_ex).astype("int")
    
    topn_remapped_neighbors = []
    topn_remapped_neighbor_dists = []
    topn_remapped_neighbors_set = set()
    
    for dist_argidx in dists_argsort:
        dist = flat_knn_dists_for_views_of_ex[dist_argidx]
        remapped_neighbor = remapped_neighbor_indices[dist_argidx]
        if remapped_neighbor not in topn_remapped_neighbors_set:
            topn_remapped_neighbors_set.add(remapped_neighbor)
            topn_remapped_neighbors.append(remapped_neighbor)
            topn_remapped_neighbor_dists.append(dist)
        if (len(topn_remapped_neighbors_set)==topn_neighb_to_keep):
            break
    assert len(topn_remapped_neighbors)==topn_neighb_to_keep, len(topn_remapped_neighbors)
    
    return (np.array(topn_remapped_neighbors), np.array(topn_remapped_neighbor_dists))
    

#for all the different 'views' of each individual seqlet (shifts & revcomp),
# go back and figure out what the nearest neighbor pairs for individual
# seqlets would be
def remap_knn_result(knn_indices, knn_dists, orig_num_examples, topn_neighb_to_keep, n_jobs):
    views_per_ex = knn_indices.shape[0]/orig_num_examples
    assert int(views_per_ex)==views_per_ex
    views_per_ex = int(views_per_ex)
    
    #this collation can be done in parallel
    remapped_indices_and_dists = Parallel(n_jobs=n_jobs)(
        delayed(remap_knn_result_single_ex)(knn_indices[i*views_per_ex:(i+1)*views_per_ex],
                                            knn_dists[i*views_per_ex:(i+1)*views_per_ex],
                                            views_per_ex,
                                            topn_neighb_to_keep)
        for i in range(orig_num_examples))
    
    remapped_indices = np.stack([x[0] for x in remapped_indices_and_dists])
    remapped_dists = np.stack([x[1] for x in remapped_indices_and_dists])
    
    return remapped_indices, remapped_dists

remapped_knn_indices, remapped_knn_dists = remap_knn_result(
    knn_indices=knn_indices, knn_dists=knn_dists,
    orig_num_examples=len(seqlets_subsample),
    topn_neighb_to_keep=n_neighbors,
    n_jobs=1)


AssertionError: 7

In [67]:
#let's investigate concordance between remapped_indices and the contin-jacc approach

In [68]:
continjacc_affmat = np.load("affmat_recenter_41bpseqlets_0.75.npy")

In [78]:
def get_nn_ranking_for_affmat(affmat):
    nn_ranking = np.zeros_like(affmat).astype("int")
    for row_idx,row in enumerate(affmat):
        row_argsort = np.argsort(-row)
        nn_ranking[row_idx,row_argsort] = np.arange(len(row_argsort))
    return nn_ranking

continjacc_nn_ranking = get_nn_ranking_for_affmat(continjacc_affmat)

In [84]:
continjacc_nn_ranking_of_pnndesc_neighbs = np.array([
    continjacc_nn_ranking[row_idx, pynndesc_knn_indices]
    for (row_idx, pynndesc_knn_indices) in enumerate(remapped_knn_indices)
])

In [79]:
continjacc_nn_ranking

array([[   0, 2389, 3745, ..., 1680,   43, 2801],
       [2862,    0, 2150, ..., 3014,  806, 1856],
       [3849, 1564,    0, ..., 3956, 1733, 2922],
       ...,
       [ 668, 1343, 2492, ...,    0,  477,  853],
       [ 642,  329, 1698, ..., 1857,    0, 2757],
       [3842, 1169, 3523, ..., 3145, 3823,    0]])

In [80]:
remapped_knn_indices

array([array([   0, 1227, 4663, 2808, 1314, 2732,   75,  550,  982, 3994, 1804,
       1660, 2032,  400, 4837, 3554, 4264, 1523, 2388,  218]),
       array([   1,  500, 2546, 2248, 1973, 4636, 4169,  295,  217, 4620, 3718,
        353, 1186, 1114, 3056,  238, 2423, 1861, 2055, 4021]),
       array([   2, 1444, 1773, 1545, 4868, 2208, 1359,   37, 3379, 3372, 4125,
       1762, 2195, 4220, 3216, 3903, 4609, 4903, 1810,  902]),
       ...,
       array([4997, 1239, 2776, 4420,  511,  774, 2705, 1233, 3269, 4275, 4283,
       1275, 2934,  251, 2885, 1918, 1769, 2445, 2179, 2036]),
       array([4998, 2756, 1009, 2028, 4582, 4301, 2993,  520, 2356, 3925, 3653,
       2861, 2966,  659, 3022, 1971, 1028, 1899, 1986, 4307]),
       array([4999, 2672, 4652,  174, 1927, 1486, 3369, 1852, 1820, 3640, 1647,
        944, 3130, 4491, 4205, 4526, 1723, 3086, 2540, 3142])],
      dtype=object)

In [83]:
continjacc_nn_ranking[0, remapped_knn_indices[0]]

array([   0,  541, 4282, 3871, 1982,   16, 3966,  486, 4786, 1533, 3741,
       1840, 2233, 2694, 2894, 2573,  628, 3091, 1740, 2507])

In [88]:
from matplotlib import pyplot as plt

plt.hist(continjacc_nn_ranking_of_pnndesc_neighbs.ravel(), bins=20)
plt.show()

KeyboardInterrupt: 

Error in callback <function flush_figures at 0x11acaba70> (for post_execute):


KeyboardInterrupt: 

In [91]:
np.array(continjacc_nn_ranking_of_pnndesc_neighbs)

array([array([   0,  541, 4282, 3871, 1982,   16, 3966,  486, 4786, 1533, 3741,
       1840, 2233, 2694, 2894, 2573,  628, 3091, 1740, 2507]),
       array([  0, 261,   1,  22,   6,   7,  78,   4,  13,  11,  47,  10,   2,
       290,  17,  26,  39,  50,  68,  52]),
       array([   0,    1,    5,   12,    4,   15,   39,   94,   20,    6, 2309,
         43,    2, 2888, 4460,  143, 4071,  120, 4127,  177]),
       ...,
       array([  0,   1,   2,   4,   3,  37,   8,  13,  58,  20,  76,   6,  33,
        19,  11, 280,   5,  17,  18,  51]),
       array([   0,    1,    4,    2,    3,    6, 1983,   13,    5,    8,    7,
       1326,   11,   34, 2718, 1099,   28,   29,    9, 1276]),
       array([  0,  17,  53,  11,   7,   2, 151, 291,  14,  12,  22,   1, 152,
        94,   8, 486,  16,  56, 806, 148])], dtype=object)