In [1]:
import matplotlib.pyplot as plt
import numpy as np
from numba import jit

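Pure Python: simulate a two-state (0/1) Markov chain. From state 0 the chain moves to 1 with probability $p$; from state 1 it moves to 0 with probability $q$. The long-run fraction of time spent in state 0 is therefore $q/(p+q) = 2/3$ for $p=0.1,\ q=0.2$.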

In [2]:
# Transition probabilities of the two-state chain simulated below:
#   state 0 -> state 1 with probability p
#   state 1 -> state 0 with probability q
p = 0.1
q = 0.2

In [3]:
def find_states(n):
    """Simulate ``n`` steps of a two-state (0/1) Markov chain.

    The chain starts in state 1. Transition probabilities are read from the
    module-level globals ``p`` and ``q``:

    * from state 0 it moves to state 1 with probability ``p``;
    * from state 1 it moves to state 0 with probability ``q``.

    Note: when this function is compiled with numba, global variables are
    treated as compile-time constants, so later changes to ``p``/``q`` are
    not seen by an already-compiled version.

    Parameters
    ----------
    n : int
        Number of states to generate (must be >= 0).

    Returns
    -------
    numpy.ndarray
        Integer array of length ``n`` with values in {0, 1}; empty when
        ``n == 0``.
    """
    states = np.empty(n, dtype=int)
    if n == 0:
        # Nothing to simulate; indexing states[0] would raise IndexError.
        return states
    states[0] = 1
    # Draw all uniforms up front; unif[0] is unused because states[0] is fixed.
    unif = np.random.uniform(0, 1, size=n)
    for i in range(1, n):
        if states[i - 1] == 0:
            # Leave state 0 (jump to 1) with probability p.
            states[i] = unif[i] < p
        else:
            # Stay in state 1 with probability 1 - q.
            states[i] = unif[i] > q
    return states

In [4]:
x=find_states(100000)
print(np.mean(x==0))
%timeit find_states(100000)

0.66738
10 loops, best of 3: 70.6 ms per loop


Numba: JIT-compile the same pure-Python function, unchanged.

In [5]:
# nopython=True makes numba raise an error if it cannot compile the whole
# function, instead of falling back to (much slower) object mode.
find_states_numba = jit(nopython=True)(find_states)

In [6]:
x=find_states_numba(100000)
print(np.mean(x==0))
%timeit find_states_numba(100000)

0.6637
1000 loops, best of 3: 1.34 ms per loop


In [7]:
# Load IPython's Cython extension; it provides the %%cython cell magic used below.
%load_ext Cython

In [8]:
%%cython
import numpy as np
from numpy cimport int_t, float_t

def find_states_cython(int n, double p=0.1, double q=0.2):
    """Simulate ``n`` steps of the same two-state chain as ``find_states``.

    The chain starts in state 1; from state 0 it moves to 1 with probability
    ``p``, from state 1 it moves to 0 with probability ``q``. Typed
    memoryviews and typed loop variables keep the loop in C.

    Parameters
    ----------
    n : int
        Number of states to generate (must be >= 0).
    p : float, default 0.1
        Probability of the 0 -> 1 transition (default matches the notebook's p).
    q : float, default 0.2
        Probability of the 1 -> 0 transition (default matches the notebook's q).

    Returns
    -------
    numpy.ndarray
        Integer array of length ``n`` with values in {0, 1}; empty when n == 0.
    """
    cdef Py_ssize_t t
    cdef int_t x0
    cdef int_t [:] x
    cdef float_t [:] U

    x_np = np.empty(n, dtype=int)
    if n == 0:
        # Nothing to simulate; writing x[0] would raise IndexError.
        return x_np
    unif_np = np.random.uniform(0, 1, size=n)
    x = x_np
    U = unif_np

    x[0] = 1
    for t in range(1, n):
        # Typed x0 avoids boxing each state into a Python object.
        x0 = x[t - 1]
        if x0 == 0:
            x[t] = U[t] < p
        else:
            x[t] = U[t] > q
    # x is a view onto x_np, so the filled array can be returned directly.
    return x_np

In [9]:
x=find_states_cython(int(100000))
print(np.mean(x==0))
%timeit find_states_cython(100000)

0.66177
100 loops, best of 3: 2.36 ms per loop
