# Pandas vs Numba vs JAX

In [11]:
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
from numba import njit

In [2]:
np.set_printoptions(edgeitems=30, linewidth=10000)
pd.options.display.max_rows = 200
pd.options.display.max_columns = 80
pd.options.display.max_colwidth = 100

In [3]:
n = 10000000
df = pd.DataFrame(dict(
    a = np.random.normal(1, 1, size=n),
    b = np.random.normal(0, 1, size=n),
    k = np.random.choice([0, 1, 2, 3, 4, 5], size=n),
))
df['k'] = df['k'].astype('category')
df.tail(2)

Unnamed: 0,a,b,k
9999998,1.915283,-0.08164,1
9999999,0.665971,0.401865,5


# Tests

## Vector Mean

In [6]:
%timeit _ = df['a'].mean()

35.6 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
%timeit _ = np.mean(df['a'].values)

3.92 ms ± 86.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
a = jnp.asarray(df['a'].values)
%timeit _ = jnp.mean(a)



1.99 ms ± 74.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [62]:
# @jax.jit
# def jax_mean(arr):
#     s = 0.
#     l = len(arr)
#     for i in range(l):
#         s += arr[i]
#     return s / l

# a = jnp.asarray(df['a'].values)
# %timeit _ = jax_mean(a)

In [15]:
@njit('f8(f8[:])')
def numba_mean(arr):
    s = 0.
    l = len(arr)
    for i in range(l):
        s += arr[i]
    return s / l

%timeit _ = numba_mean(df['a'].values)

9.81 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Vector Cumsum

In [18]:
%timeit _ = df['a'].cumsum()

66.2 ms ± 4.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [19]:
%timeit _ = np.cumsum(df['a'].values)

28.8 ms ± 748 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [21]:
a = jnp.asarray(df['a'].values)
%timeit _ = jnp.cumsum(a)

74.8 ms ± 740 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
@njit('f8[:](f8[:])')
def numba_cumsum(arr):
    l = len(arr)
    res = np.empty(l)
    res[0] = arr[0]
    for i in range(1, l):
        res[i] = arr[i] + res[i-1]
    return res

%timeit _ = numba_cumsum(df['a'].values)

32.8 ms ± 815 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Vector Dot

In [37]:
%timeit _ = df['a'] @ df['b']

5.15 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [38]:
%timeit _ = df['a'].values @ df['b'].values

5.48 ms ± 345 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [46]:
a = jnp.asarray(df['a'].values)
b = jnp.asarray(df['b'].values)
%timeit _ = jnp.dot(a, b)

11 ms ± 187 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [44]:
@njit('f8(f8[:],f8[:])')
def numba_dot(arr1, arr2):
    s = 0.
    l = len(arr1)
    for i in range(l):
        s += arr1[i] * arr2[i]
    return s

%timeit _ = numba_dot(df['a'].values, df['b'].values)

11.6 ms ± 181 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Vector Exp

In [31]:
%timeit _ = np.exp(df['a'])

64.5 ms ± 902 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [26]:
%timeit _ = np.exp(df['a'].values)

65.5 ms ± 787 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [27]:
a = jnp.asarray(df['a'].values)
%timeit _ = jnp.exp(a)

6.58 ms ± 302 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Vector Log

In [34]:
%timeit _ = np.log(df['a']+5)

64.2 ms ± 1.92 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [35]:
a = jnp.asarray(df['a'].values)
%timeit _ = jnp.log(a+5)

10.7 ms ± 396 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Piecewise

In [84]:
def foo(x):
    if x < -1.:
        return x**3
    elif x < 0.1:
        return np.cos(x)
    elif x < 1.:
        return np.log(x)
    else:
        return x+10
    
%timeit _ = df['a'].transform(foo)

4.92 s ± 118 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [85]:
@np.vectorize
def foo(x):
    if x < -1.:
        return x**3
    elif x < 0.1:
        return np.cos(x)
    elif x < 1.:
        return np.log(x)
    else:
        return x+10
    
%timeit _ = foo(df['a'].values)

4.5 s ± 77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [86]:
def foo(x):
    r = np.where(
        x<-1., x**3,
        np.where(
            x<0.1, np.cos(x),
            np.where(
                x<1., np.log(x),
                x+10
            )
        )
    )
    return r
    
%timeit _ = foo(df['a'].values)

  x<1., np.log(x),


498 ms ± 13.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [87]:
def foo(x):
    r = jnp.where(
        x<-1., x**3,
        jnp.where(
            x<0.1, jnp.cos(x),
            jnp.where(
                x<1., jnp.log(x),
                x+10
            )
        )
    )
    return r

a = jnp.asarray(df['a'].values)
%timeit _ = foo(a)

66.2 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [88]:
@njit
def _numba_aux(x):
    if x < -1.:
        return x**3
    elif x < 0.1:
        return np.cos(x)
    elif x < 1.:
        return np.log(x)
    else:
        return x+10

@njit('f8[:](f8[:])')
def numba_foo(arr):
    l = len(arr)
    res = np.empty(l)
    for i in range(1, l):
        res[i] = _numba_aux(arr[i])
    return res

%timeit _ = numba_foo(df['a'].values)

121 ms ± 2.89 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


# Summary

In [89]:
summary = {
    'vector mean': {'pandas': 36., 'numpy': 4., 'jax numpy': 2., 'numba jit': 10.},
    'vector cumsum': {'pandas': 66., 'numpy': 29., 'jax numpy': 75., 'numba jit': 33.},
    'vector dot': {'pandas': 5., 'numpy': 5., 'jax numpy': 11., 'numba jit': 12.},
    'vector exp': {'numpy': 65., 'jax numpy': 7.},
    'vector log': {'numpy': 64., 'jax numpy': 11.},
    'conditional': {'pandas': 4900., 'numpy': 500., 'jax numpy': 66., 'numba jit': 121.},
}
pd.DataFrame(summary)

Unnamed: 0,vector mean,vector cumsum,vector dot,vector exp,vector log,conditional
pandas,36.0,66.0,5.0,,,4900.0
numpy,4.0,29.0,5.0,65.0,64.0,500.0
jax numpy,2.0,75.0,11.0,7.0,11.0,66.0
numba jit,10.0,33.0,12.0,,,121.0
