Testing ipyrad functions for speed improvements with numba

In [1]:
import numpy as np
from collections import Counter
from numba import jit

In [2]:
data = [list("AAA-TTTT"),
        list("AAATTTTT"),
        list("AAA-TTTT"),
        list("AAA-TTTN"),
        list("AAA-TTTT"),
        list("AAT-TTCT"),
        list("AAT-TTCT"),
        list("A-T-TTNT"),
        list("AAA-TTCT"),
        list("AAT-TTCT") 
       ]
consens = "AAWNTTYT"
stack = [Counter(i) for i in np.array(data).T]

In [3]:
arr = np.array(data)[:,:]
coldepths = arr.shape[0]
print coldepths

ndepths = np.sum(arr=='N', axis=0)
print ndepths

idepths = np.sum(arr=='-', axis=0)
print idepths

10
[0 0 0 0 0 0 1 1]
[0 1 0 9 0 0 0 0]


In [4]:
np.array([np.sum(arr==x, axis=0) for x in list("CATG")]).T

array([[ 0, 10,  0,  0],
       [ 0,  9,  0,  0],
       [ 0,  6,  4,  0],
       [ 0,  0,  1,  0],
       [ 0,  0, 10,  0],
       [ 0,  0, 10,  0],
       [ 4,  0,  5,  0],
       [ 0,  0,  9,  0]])

In [5]:
import ipyrad as ip
import numpy as np
import gzip
import itertools

data1 = ip.load.load_assembly("test_rad/data1.assembly")
sample = data1.samples["1A_0"]
subsample = 500

  loading Assembly: data1 [test_rad/data1.assembly]
  New Assembly: data1


In [6]:
data = data1

clusters = gzip.open(sample.files.clusters)
pairdealer = itertools.izip(*[iter(clusters)]*2)
## array will be (nclusters, readlen, 4)
if "pair" in data.paramsdict["datatype"]:
    readlen = 2*data._hackersonly["max_fragment_length"]
else:
    readlen = data._hackersonly["max_fragment_length"]
dims = (int(sample.stats.clusters_hidepth), readlen, 4)
stacked = np.zeros(dims, dtype=np.int16)

In [7]:

def clustdealer(pairdealer, optim):
    """ return optim clusters given iterators, and whether it got all or not"""
    ccnt = 0
    chunk = []
    while ccnt < optim:
        ## try refreshing taker, else quit
        try:
            taker = itertools.takewhile(lambda x: x[0] != "//\n", pairdealer)
            oneclust = ["".join(taker.next())]
        except StopIteration:
            #LOGGER.debug('last chunk %s', chunk)
            return 1, chunk

        ## load one cluster
        while 1:
            try: 
                oneclust.append("".join(taker.next()))
            except StopIteration:
                break
        chunk.append("".join(oneclust))
        ccnt += 1
    return 0, chunk


In [8]:
def stackarray(data, sample):
    """ makes a list of lists of reads at each site """
    ## get clusters file
    clusters = gzip.open(sample.files.clusters)
    pairdealer = itertools.izip(*[iter(clusters)]*2)

    ## array will be (nclusters, readlen, 4)
    if "pair" in data.paramsdict["datatype"]:
        readlen = 2*data._hackersonly["max_fragment_length"]
    else:
        readlen = data._hackersonly["max_fragment_length"]
    dims = (int(sample.stats.clusters_hidepth), readlen, 4)
    stacked = np.zeros(dims, dtype=np.int16)

    ## don't use sequence edges / restriction overhangs
    cutlens = [None, None]
    for cidx, cut in enumerate(data.paramsdict["restriction_overhang"]):
        if cut:
            cutlens[cidx] = len(cut)
    try:
        cutlens[1] = -1*cutlens[1]
    except TypeError:
        pass
    #LOGGER.info(cutlens)

    ## fill stacked
    done = 0
    nclust = 0
    while not done:
        try:
            done, chunk = clustdealer(pairdealer, 1)
        except IndexError:
            raise IPyradError("clustfile formatting error in %s", chunk)
        if chunk:
            piece = chunk[0].strip().split("\n")
            names = piece[0::2]
            seqs = piece[1::2]
            ## pull replicate read info from seqs
            reps = [int(sname.split(";")[-2][5:]) for sname in names]
            sseqs = [list(seq) for seq in seqs]
            arrayed = np.concatenate(
                      [[seq]*rep for seq, rep in zip(sseqs, reps)])
            ## enforce minimum depth for estimates
            if arrayed.shape[0] >= data.paramsdict["mindepth_statistical"]:
                ## remove edge columns
                arrayed = arrayed[:, cutlens[0]:cutlens[1]]
                ## remove cols that are pair separator
                arrayed = arrayed[~np.any(arrayed == "n", axis=1)]
                ## convert - to N
                arrayed[arrayed == "-"] = "N"
                ## remove cols that are all Ns
                arrayed = arrayed[~np.any(arrayed == "n", axis=1)]                
                ## store in stacked dict
                catg = np.array(\
                    [np.sum(arrayed == i, axis=0) for i in list("CATG")], 
                    dtype='int16').T
                stacked[nclust, :catg.shape[0], :] = catg
                nclust += 1
    return stacked


In [9]:

def tablestack(rstack):
    """ makes a count dict of each unique array element """
    ## goes by 10% at a time to minimize memory overhead. Is possible it skips
    ## the last chunk, but this shouldn't matter.
    table = Counter()
    for i in xrange(0, rstack.shape[0], rstack.shape[0]//10):
        tmp = Counter([j.tostring() for j in rstack[i:i+rstack.shape[0]//10]])
        table.update(tmp)
    return table



In [108]:

def frequencies(stacked):
    """ return frequency counts """
    totals = stacked.sum(axis=1)
    totals = totals.sum(axis=0)
    freqs = totals/np.float32(totals.sum())
    return freqs


In [109]:
stacked = stackarray(data, sample)
bfreqs = frequencies(stacked)
rstack = stacked.reshape(stacked.shape[0]*stacked.shape[1],
                             stacked.shape[2])
tstack = tablestack(rstack)

In [12]:
stacks = np.array([np.fromstring(i, dtype=np.int16) \
                       for i in tstack.iterkeys()])

In [13]:
dropme = np.zeros(4, dtype=np.int16).tostring()

In [14]:
ustacks = np.array([np.fromstring(i, dtype=np.int16) \
                        for i in tstack.iterkeys()])
counts = np.array(tstack.values())

In [19]:
import scipy.optimize
import scipy.stats
import numba

In [78]:
startp = np.array([0.01, 0.001], dtype=np.float16)

In [110]:
bfreqs

array([ 0.24844044,  0.24991964,  0.2499051 ,  0.25173481])

In [54]:
scipy.stats.binom.pmf(10, 100., 0.1)

0.13186534682448681

In [159]:
?numba.jit('f2', nopython=True)

In [163]:
ustacks

array([[23,  1,  0,  0],
       [ 0, 16,  1,  0],
       [ 0,  0,  9, 11],
       ..., 
       [24,  0,  1,  0],
       [ 0,  0, 16,  1],
       [ 0,  0, 28,  0]], dtype=int16)

In [164]:
@jit(['float32[:,:](float32, float32, int16[:,:])'])
def jlikelihood1(errors, bfreqs, ustacks):
    """probability homozygous"""
    ## make sure base_frequencies are in the right order
    #print uniqstackl.sum()-uniqstack, uniqstackl.sum(), 0.001
    totals = np.array([ustacks.sum(axis=1)]*4).T
    prob = scipy.stats.binom.pmf(totals-ustacks, totals, errors)
    return np.sum(bfreqs*prob, axis=1)


def likelihood1(errors, bfreqs, ustacks):
    """probability homozygous"""
    ## make sure base_frequencies are in the right order
    #print uniqstackl.sum()-uniqstack, uniqstackl.sum(), 0.001
    totals = np.array([ustacks.sum(axis=1)]*4).T
    prob = scipy.stats.binom.pmf(totals-ustacks, totals, errors)
    return np.sum(bfreqs*prob, axis=1)

In [253]:
@numba.jit(['float32[:](float32, float32[:], int16[:,:])'])
def jlikelihood2(errors, bfreqs, ustacks):
    """probability of heterozygous"""
    returns = np.zeros(len(ustacks), dtype=np.float32)
    for idx, ustack in enumerate(ustacks):
        spair = np.array(list(itertools.combinations(ustack, 2)))
        bpair = np.array(list(itertools.combinations(bfreqs, 2)))
        one = 2.*bpair.prod(axis=1)
        tot = ustack.sum()
        atwo = tot - spair[:,0] - spair[:,1]
        two = scipy.stats.binom.pmf(atwo, tot, (2.*errors)/3.)
        three = scipy.stats.binom.pmf(\
                    spair[:,0], spair.sum(axis=1), 0.5)
        four = 1.-np.sum(bfreqs**2)
        returns[idx] = np.sum(one*two*(three/four))
    return np.array(returns)


def likelihood2(errors, bfreqs, ustacks):
    """probability of heterozygous"""
    returns = np.zeros([len(ustacks)])
    for idx, ustack in enumerate(ustacks):
        spair = np.array(list(itertools.combinations(ustack, 2)))
        bpair = np.array(list(itertools.combinations(bfreqs, 2)))
        one = 2.*bpair.prod(axis=1)
        tot = ustack.sum()
        atwo = tot - spair[:,0] - spair[:,1]
        two = scipy.stats.binom.pmf(atwo, tot, (2.*errors)/3.)
        three = scipy.stats.binom.pmf(\
                    spair[:,0], spair.sum(axis=1), 0.5)
        four = 1.-np.sum(bfreqs**2)
        returns[idx] = np.sum(one*two*(three/four))
    return np.array(returns)

In [243]:
spair = np.array(list(itertools.combinations(ustacks[0], 2)))
print spair
atwo = tot - spair[:,0] - spair[:,1]
print atwo

[[23  1]
 [23  0]
 [23  0]
 [ 1  0]
 [ 1  0]
 [ 0  0]]
[ 0  1  1 23 23 24]


In [239]:
print ustacks[0]

[23  1  0  0]


In [242]:
scipy.stats.binom.pmf(atwo, tot, (2.*0.001)/3.)
scipy.stats.binom.pmf(spair[:,0], spair.sum(axis=1), 0.5)

array([  1.43051344e-06,   1.19209453e-07,   1.19209453e-07,
         5.00000000e-01,   5.00000000e-01,   1.00000000e+00])

In [197]:
tot = ustacks[0].sum()
tot
atwo = tot - np.array([i[0] for i in sp])
atwo

array([], dtype=float64)

In [213]:
tot - np.array(list(sp))

array([[23,  1],
       [23,  0],
       [23,  0],
       [ 1,  0],
       [ 1,  0],
       [ 0,  0]], dtype=int16)

In [248]:
%%timeit 
likelihood2(0.001, bfreqs, ustacks)

10 loops, best of 3: 118 ms per loop


In [256]:
%%timeit 
jlikelihood2(0.001, bfreqs, ustacks)

10 loops, best of 3: 115 ms per loop


In [170]:
%%timeit 
likelihood2(0.001, bfreqs, ustacks)

10 loops, best of 3: 124 ms per loop


In [165]:
%%timeit 
jlikelihood1(0.001, bfreqs, ustacks)

The slowest run took 4.88 times longer than the fastest. This could mean that an intermediate result is being cached 
1000 loops, best of 3: 665 µs per loop


In [166]:
%%timeit 
likelihood1(0.001, bfreqs, ustacks)

1000 loops, best of 3: 596 µs per loop


### HOly Cow
using jit is so much faster! Just need to rewrite code to fill empty arrays instead of appending to lists and we should be able to use jit just fine for steps 1,2,4,5,7.

In [46]:
%%timeit
fillarr(arr, list("abcdefgh"))

The slowest run took 6.70 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 2.99 µs per loop


In [47]:
%%timeit
jfillarr(arr, list("abcdefgh"))

The slowest run took 778.38 times longer than the fastest. This could mean that an intermediate result is being cached 
1000 loops, best of 3: 209 µs per loop


In [1]:
arr

NameError: name 'arr' is not defined

In [2]:
## pure python
def findbcode(cut, longbar, read1):
    search = read1[1][:longbar+len(cut)]
    countcuts = search.count(cut)
    if countcuts == 1:
        barcode = search.split(cut, 1)[0]
    elif countcuts == 2:
        barcode = search.rsplit(cut, 2)[0]
    else:
        barcode = ""
    return barcode

In [42]:
## jit version
@jit
def jfindbarcode(cut, longbar, read1):
    search = read1[1][:longbar+len(cut)]
    countcuts = search.count(cut)
    if countcuts == 1:
        barcode = search.split(cut, 1)[0]
    elif countcuts == 2:
        barcode = search.rsplit(cut, 2)[0]
    else:
        barcode = ""
    return barcode

In [47]:
cut = "TGCAG"
longbar = 6
read1 = ['fakeread','AAACCCTGCAGAAAAAAAAAAAAAAAAA']
nread1 = np.array(['fakeread','AAACCCTGCAGAAAAAAAAAAAAAAAAA'])


In [48]:
%%timeit 
findbcode(cut, longbar, read1)

The slowest run took 5.48 times longer than the fastest. This could mean that an intermediate result is being cached 
1000000 loops, best of 3: 914 ns per loop


In [59]:
%%timeit
findbcode(cut, longbar, nread1)

The slowest run took 11.24 times longer than the fastest. This could mean that an intermediate result is being cached 
1000000 loops, best of 3: 1.08 µs per loop


In [51]:
%%timeit 
jfindbarcode(cut, longbar, nread1)

10000 loops, best of 3: 53.5 µs per loop


In [37]:
read1[1][:longbar+6].count("TGCAG")

1

In [259]:
def frequencies(stacked):
    """ return frequency counts """
    totals = stacked.sum(axis=1)
    totals = totals.sum(axis=0)
    freqs = totals/np.float32(totals.sum())
    return freqs


@jit(['float32[:,:](float32, float32, int16[:,:])'])
def jlikelihood1(errors, bfreqs, ustacks):
    """Probability homozygous. All numpy and no loop so there was 
    no numba improvement to speed when tested. """
    ## make sure base_frequencies are in the right order
    #print uniqstackl.sum()-uniqstack, uniqstackl.sum(), 0.001
    totals = np.array([ustacks.sum(axis=1)]*4).T
    prob = scipy.stats.binom.pmf(totals-ustacks, totals, errors)
    return np.sum(bfreqs*prob, axis=1)

def likelihood1(errors, bfreqs, ustacks):
    """Probability homozygous. All numpy and no loop so there was 
    no numba improvement to speed when tested. """
    ## make sure base_frequencies are in the right order
    #print uniqstackl.sum()-uniqstack, uniqstackl.sum(), 0.001
    totals = np.array([ustacks.sum(axis=1)]*4).T
    prob = scipy.stats.binom.pmf(totals-ustacks, totals, errors)
    return np.sum(bfreqs*prob, axis=1)


@jit(['float32[:](float32, float32[:], int16[:,:])'])
def jlikelihood2(errors, bfreqs, ustacks):
    """probability of heterozygous. Very minimal speedup w/ numba."""
    returns = np.zeros(len(ustacks), dtype=np.float32)
    for idx, ustack in enumerate(ustacks):
        spair = np.array(list(itertools.combinations(ustack, 2)))
        bpair = np.array(list(itertools.combinations(bfreqs, 2)))
        one = 2.*bpair.prod(axis=1)
        tot = ustack.sum()
        atwo = tot - spair[:, 0] - spair[:, 1]
        two = scipy.stats.binom.pmf(atwo, tot, (2.*errors)/3.)
        three = scipy.stats.binom.pmf(\
                    spair[:, 0], spair.sum(axis=1), 0.5)
        four = 1.-np.sum(bfreqs**2)
        returns[idx] = np.sum(one*two*(three/four))
    return np.array(returns)


def likelihood2(errors, bfreqs, ustacks):
    """probability of heterozygous"""
    returns = np.zeros([len(ustacks)])
    for idx, ustack in enumerate(ustacks):
        spair = np.array(list(itertools.combinations(ustack, 2)))
        bpair = np.array(list(itertools.combinations(bfreqs, 2)))
        one = 2.*bpair.prod(axis=1)
        tot = ustack.sum()
        atwo = tot - spair[:, 0] - spair[:, 1]
        two = scipy.stats.binom.pmf(atwo, tot, (2.*errors)/3.)
        three = scipy.stats.binom.pmf(\
                    spair[:, 0], spair.sum(axis=1), 0.5)
        four = 1.-np.sum(bfreqs**2)
        returns[idx] = np.sum(one*two*(three/four))
    return np.array(returns)




def get_diploid_lik(pstart, bfreqs, ustacks, counts):
    """ Log likelihood score given values [H,E] """
    hetero, errors = pstart
    if (hetero <= 0.) or (errors <= 0.):
        score = np.exp(100)
    else:
        ## get likelihood for all sites
        lik1 = (1.-hetero)*likelihood1(errors, bfreqs, ustacks)
        lik2 = (hetero)*likelihood2(errors, bfreqs, ustacks)
        liks = lik1+lik2
        logliks = np.log(liks[liks > 0])*counts[liks > 0]
        score = -logliks.sum()
    return score


@jit
def j_diploid_lik(pstart, bfreqs, ustacks, counts):
    """ Log likelihood score given values [H,E]. """
    hetero, errors = pstart
    ## tell it to score terribly if scores are negative
    if (hetero <= 0.) or (errors <= 0.):
        score = np.exp(100)
    else:
        ## get likelihood for all sites
        lik1 = (1.-hetero)*jlikelihood1(errors, bfreqs, ustacks)
        lik2 = (hetero)*jlikelihood2(errors, bfreqs, ustacks)
        liks = lik1+lik2
        logliks = np.log(liks[liks > 0])*counts[liks > 0]
        score = -logliks.sum()
    return score


In [261]:
pstart = np.array([0.01, 0.001], dtype=np.float32)

In [266]:
%%timeit
func = get_diploid_lik
hetero, errors = scipy.optimize.fmin(func, pstart,
                                    (bfreqs, ustacks, counts), 
                                    disp=False,
                                    full_output=False)

1 loops, best of 3: 7.62 s per loop


In [267]:
%%timeit
func = j_diploid_lik
hetero, errors = scipy.optimize.fmin(func, pstart,
                                    (bfreqs, ustacks, counts), 
                                    disp=False,
                                    full_output=False)

1 loops, best of 3: 7.72 s per loop
