In [2]:
%load_ext Cython

In [57]:
%%cython
from libc.math cimport abs, sqrt, ceil, floor, log, log2, acos, cos
from numpy cimport ndarray
import numpy as np
from numpy import ones, zeros, int32, float32, uint8, fromstring
from numpy import sort, empty, array, arange, concatenate, searchsorted
from numpy import minimum, maximum,divide, std, mean
from numpy import min as nmin
from numpy import max as nmax
import h5py as h

def pairRegionsIntersection(ndarray[int, ndim=2] pairs,
                            ndarray[int, ndim=2] regions,
                            exclude=False, allow_partial=False,
                            bool_out=False):
    '''
    Find which pairs [a,b] (a<b) fall inside, or outside, a set of regions [c,d].
    This is used to include or exclude certain regions of the genome.

    Arguments:

    - pairs: (N,2) int32 array where each row details a pair.
    - regions: (M,2) int32 array where each row details a region. Regions may be
               unsorted and may overlap one another.
    - exclude: Boolean. If true, test the pairs against the gaps *between* the regions
               rather than against the regions themselves.
    - allow_partial: Boolean. 'Partial' applies to the pair:
                     If allow_partial == False then
                         - the entire pair lies inside one region if exclude == False
                         - the entire pair lies inside one gap (i.e. it overlaps no
                           region) if exclude == True
                     If allow_partial == True then
                         - at least part of the pair overlaps a region if exclude == False
                         - at least part of the pair overlaps a gap if exclude == True
    - bool_out: Boolean. If true, return 1 as soon as any pair satisfies the condition
                and 0 if none does, instead of returning indices.

    Returns:

    - indices: int32 array with the indices (rows of `pairs`) satisfying the condition,
               in increasing order; or the int 0/1 when bool_out is true.

    Overlap is tested with strict inequalities, so a pair that only shares an end point
    with a region does not count as overlapping it.
    '''
    cdef int i, k, lo, hi, start, end, covered_to, hit
    cdef int exc = int(exclude)
    cdef int partial = int(allow_partial)
    cdef int bool_o = int(bool_out)
    cdef int n_pairs = len(pairs)
    cdef int n_regions = len(regions)
    cdef int n_regs = n_regions
    cdef int n_hits = 0
    cdef ndarray[int, ndim=1] order
    cdef ndarray[int, ndim=2] regs = regions
    cdef ndarray[int, ndim=1] indices_out = empty(n_pairs, int32)

    # No pairs: nothing can match (also avoids min/max of an empty array).
    if n_pairs == 0:
        if bool_o:
            return 0
        return indices_out[:0]

    if exc:
        # Sentinels strictly outside every coordinate, so the outermost gaps are
        # open-ended for all practical purposes.
        if n_regions > 0:
            lo = minimum(nmin(pairs), nmin(regions)) - 1
            hi = maximum(nmax(pairs), nmax(regions)) + 1
        else:
            lo = nmin(pairs) - 1
            hi = nmax(pairs) + 1

        # Build the complement of the regions: walk them by increasing start while
        # tracking how far the union of the regions seen so far extends, and emit
        # a gap whenever the next region starts beyond that point.
        order = array(regions[:,0].argsort(), int32)
        regs = empty((n_regions+1,2), int32)
        n_regs = 0
        covered_to = lo
        for k in range(n_regions):
            start = regions[order[k],0]
            end = regions[order[k],1]
            if start > covered_to:
                regs[n_regs,0] = covered_to
                regs[n_regs,1] = start
                n_regs += 1
            if end > covered_to:
                covered_to = end
        if hi > covered_to:
            regs[n_regs,0] = covered_to
            regs[n_regs,1] = hi
            n_regs += 1

    for i in range(n_pairs):
        hit = 0
        for k in range(n_regs):
            if (regs[k,1] > pairs[i,0]) and (pairs[i,1] > regs[k,0]):
                # Pair is at least partially overlapping with the region/gap;
                # without allow_partial it must also be entirely contained in it.
                if partial or ((regs[k,1] >= pairs[i,1]) and (pairs[i,0] >= regs[k,0])):
                    hit = 1
                    break
        if hit:
            if bool_o:
                return 1
            indices_out[n_hits] = i
            n_hits += 1

    if bool_o:
        return 0

    return indices_out[:n_hits]

def multi_pairRegionsIntersection(ndarray[int, ndim=3] pairs,
                                  ndarray[int, ndim=3] regions,
                                  exclude=False, allow_partial=False, indices = False,
                                  verbose=True):
    '''
    Given groups of pairs of the form [a,b] where a<b and groups of regions of the form [c,d],
    test every pair group against every region group with pairRegionsIntersection and report
    which combinations overlap.

    Arguments:

    - pairs: (G,P,2) int32 array; pairs[g] holds the pairs of group g, padded to P rows.
    - regions: (H,R,2) int32 array; regions[h] holds the regions of group h, padded to R rows.
    - exclude: Boolean. Passed on to pairRegionsIntersection.
    - allow_partial: Boolean. Passed on to pairRegionsIntersection.
    - indices: Accepted for compatibility; currently unused.
    - verbose: Boolean. If true (the default) print a trace of every comparison.

    Padding: within a group, the first row whose first column equals the smallest value
    found anywhere in the whole array marks the end of that group's data. The padding value
    must therefore be smaller than every real coordinate (e.g. -1).
    NOTE: if an array contains no padding at all, its smallest real start coordinate is
    mistaken for padding and the group holding it is truncated at that row.

    Returns:

    - out: (K,2) int32 array, K <= G*H, where each row is [pair group index, region group index]
           for a combination that overlaps. This is essentially COO format.
    '''
    cdef int i, j, k, n_p, n_r
    cdef int overlap
    cdef int exc = int(exclude)
    cdef int chatty = int(verbose)
    cdef int nl = 0
    cdef int npairgroups = len(pairs)
    cdef int nreggroups = len(regions)
    cdef int maxnp = pairs.shape[1]
    cdef int maxnr = regions.shape[1]
    cdef int minpairs = 0
    cdef int minregions = 0
    cdef ndarray[int, ndim=1] reg_len = empty(nreggroups, int32)
    cdef ndarray[int, ndim=2] out = zeros((npairgroups*nreggroups,2), int32)

    # Without any pair rows or any region groups there is nothing to report
    # (and min() of an empty array would raise).
    if npairgroups == 0 or nreggroups == 0 or maxnp == 0:
        return out[:0,:]

    minpairs = nmin(pairs)
    if maxnr > 0:
        minregions = nmin(regions)

    # The unpadded length of a region group does not depend on the pair group,
    # so work it out once up front.
    for j in range(nreggroups):
        reg_len[j] = maxnr
        for k in range(maxnr):
            if regions[j,k,0] == minregions:
                reg_len[j] = k
                break

    for i in range(npairgroups):
        if chatty:
            print("#################################")
            print("pair {}".format(i))
            print("Unedited pairs:")
            print(pairs[i,:,:])
        n_p = maxnp
        for k in range(maxnp):
            if pairs[i,k,0] == minpairs:
                n_p = k
                break
        for j in range(nreggroups):
            n_r = reg_len[j]
            if chatty:
                print("#################################")
                print("region {}".format(j))
                print("Unedited regions:")
                print(regions[j,:,:])
                print("******************************")
                print("Edited pairs:")
                print(pairs[i,:n_p,:])
                print("Edited regions:")
                print(regions[j,:n_r,:])
            if n_p == 0:
                # A group with no pairs cannot overlap anything.
                overlap = 0
            elif n_r == 0:
                # No regions: every pair lies outside them, none inside.
                overlap = exc
            else:
                overlap = pairRegionsIntersection(pairs[i,:n_p,:],
                                                  regions[j,:n_r,:],
                                                  allow_partial = allow_partial,
                                                  exclude = exclude,
                                                  bool_out = True)
            if chatty:
                print("Output: {}".format(overlap))
            if overlap == 1:
                out[nl,0] = i
                out[nl,1] = j
                nl +=1

    return out[:nl,:]
                



In [60]:
import numpy as np

# Three pair groups of four rows each; [-1, -1] rows are padding.
pairs = np.array(
    [
        [[10, 50], [51, 52], [55, 65], [70, 100]],
        [[5, 20], [21, 22], [-1, -1], [-1, -1]],
        [[110, 130], [-1, -1], [-1, -1], [-1, -1]],
    ],
    dtype=np.int32,
)

# Four region groups of five rows each; [-1, -1] rows are padding.
regions = np.array(
    [
        [[0, 5], [7, 9], [-1, -1], [-1, -1], [-1, -1]],
        [[2, 7], [10, 22], [7, 17], [21, 100], [105, 151]],
        [[10, 20], [21, 22], [34, 47], [51, 100], [101, 130]],
        [[10, 20], [30, 40], [50, 60], [-1, -1], [-1, -1]],
    ],
    dtype=np.int32,
)

In [62]:
# Run the group-vs-group overlap test on the example arrays with the default
# options (exclude=False, allow_partial=False); the cell shows the returned array.
multi_pairRegionsIntersection(pairs, regions)

#################################
pair 0
Unedited pairs:
[[ 10  50]
 [ 51  52]
 [ 55  65]
 [ 70 100]]
#################################
region 0
Unedited regions:
[[ 0  5]
 [ 7  9]
 [-1 -1]
 [-1 -1]
 [-1 -1]]
******************************
Edited pairs:
[[ 10  50]
 [ 51  52]
 [ 55  65]
 [ 70 100]]
Edited regions:
[[0 5]
 [7 9]]
Output: 0
#################################
region 1
Unedited regions:
[[  2   7]
 [ 10  22]
 [  7  17]
 [ 21 100]
 [105 151]]
******************************
Edited pairs:
[[ 10  50]
 [ 51  52]
 [ 55  65]
 [ 70 100]]
Edited regions:
[[  2   7]
 [ 10  22]
 [  7  17]
 [ 21 100]
 [105 151]]
Output: 1
#################################
region 2
Unedited regions:
[[ 10  20]
 [ 21  22]
 [ 34  47]
 [ 51 100]
 [101 130]]
******************************
Edited pairs:
[[ 10  50]
 [ 51  52]
 [ 55  65]
 [ 70 100]]
Edited regions:
[[ 10  20]
 [ 21  22]
 [ 34  47]
 [ 51 100]
 [101 130]]
Output: 1
#################################
region 3
Unedited regions:
[[10 20]
 [30 40]
 [5

array([[0, 1],
       [0, 2],
       [0, 3],
       [1, 1],
       [1, 2],
       [2, 1],
       [2, 2]], dtype=int32)