In [1]:
###########################################
## Jax Demo
## Goal - fool around with Jax AutoDiff (function composition/transformation) library
##
## Author: Chris Meaney
## Date: Feb 2021
###########################################

In [2]:
## Import JAX dependencies
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, hessian
from jax import random

## Plotting
import matplotlib.pyplot as plt
import seaborn as sns

## Pandas data wrangling, summarization, etc.
import pandas as pd

## Old numpy - for comparing speed/flexibility JAX approach vs. existing NumPy/Scipy capabilities
import numpy as onp

In [3]:
## Set seed
key = random.PRNGKey(912834)



In [4]:
key

DeviceArray([     0, 912834], dtype=uint32)

In [115]:
##
## Note: JAX arrays are immutable by nature, need special syntax to alter constructed arrays
##
x = jnp.array([1,2,3,4])
x.at[2].set(10)

DeviceArray([ 1,  2, 10,  4], dtype=int32)

In [5]:
######################################
##
## Data Manipulation Tools
##
######################################

In [6]:
## Create data from "array like" cpecify vectors (or matrices) of given shape
x = jnp.array([1,2,3]).reshape(3,1)
y = jnp.array([8,9,0]).reshape(3,1)

In [7]:
## Create data from arange
jnp.arange(10)

DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [8]:
## Create data from linspace
jnp.linspace(start=0, stop=10, num=11, endpoint=True)

DeviceArray([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.], dtype=float32)

In [9]:
## Repeat some value
jnp.repeat(a=jnp.array([0]), repeats=10)

DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)

In [10]:
## Repeat some value
jnp.repeat(a=jnp.array([0,1,2]), repeats=3)

DeviceArray([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=int32)

In [11]:
## Repeat some value
## Note: axis={0,1} now stacks this in vector/matrix...
jnp.repeat(a=jnp.array([0,1,2]).reshape(3,1), repeats=3, axis=1)

DeviceArray([[0, 0, 0],
             [1, 1, 1],
             [2, 2, 2]], dtype=int32)

In [12]:
## jax.random sub-package (random integers)
random.randint(key, minval=0, maxval=10, shape=(10,))

DeviceArray([8, 1, 3, 6, 5, 8, 7, 5, 4, 5], dtype=int32)

In [13]:
## jax.random sub-package (random normal)
x = random.normal(key, shape=(10000,))
x = x*2 + 10
pd.Series(x).describe()

count    10000.000000
mean         9.985111
std          2.010155
min          1.957885
25%          8.645895
50%         10.003785
75%         11.340447
max         19.328236
dtype: float64

In [14]:
## Read/load array like data (possibly pickled numpy arrays) directly in jax.array structure
##
## jax.numpy.load

In [15]:
##
## Read data in from some friendly "pandas-like" tool; and coerce the dataFrame -> array -> jax.array 
##
## df = pd.read_csv()
## df_np = df.to_numpy
## df_jax = jnp.array(df_np)
##

In [16]:
## Note can also/save or export data with similar functionality
##
## jax.numpy.save
## jax.numpy.savez

In [17]:
######################################
##
## Data Manipulation Tools
##
######################################

In [18]:
## Concatenate the vectors
x = jnp.array([1,2,3]).reshape(3,1)
y = jnp.array([8,9,0]).reshape(3,1)

A = jnp.concatenate((x,y), axis=1)
A

DeviceArray([[1, 8],
             [2, 9],
             [3, 0]], dtype=int32)

In [19]:
## Stacking vectors (like cbind in R)
jnp.hstack((x,y))

DeviceArray([[1, 8],
             [2, 9],
             [3, 0]], dtype=int32)

In [20]:
## Stacking matrices (hstack is like R cbind)
A = random.uniform(key, shape=(6,)).reshape(2,3)
B = random.uniform(key, shape=(6,)).reshape(2,3)
jnp.hstack((A,B))

DeviceArray([[0.54742837, 0.1428895 , 0.4743966 , 0.54742837, 0.1428895 ,
              0.4743966 ],
             [0.67102444, 0.6460316 , 0.75969887, 0.67102444, 0.6460316 ,
              0.75969887]], dtype=float32)

In [21]:
## Stacking m(atrices (vstack is like R rbind)
jnp.vstack((A,B))

DeviceArray([[0.54742837, 0.1428895 , 0.4743966 ],
             [0.67102444, 0.6460316 , 0.75969887],
             [0.54742837, 0.1428895 , 0.4743966 ],
             [0.67102444, 0.6460316 , 0.75969887]], dtype=float32)

In [22]:
## Concatenation (axis=0 ==> rows)
jnp.concatenate((A,B), axis=0)

DeviceArray([[0.54742837, 0.1428895 , 0.4743966 ],
             [0.67102444, 0.6460316 , 0.75969887],
             [0.54742837, 0.1428895 , 0.4743966 ],
             [0.67102444, 0.6460316 , 0.75969887]], dtype=float32)

In [23]:
## Concatenation (axis=1 ==> cols)
jnp.concatenate((A,B), axis=1)

DeviceArray([[0.54742837, 0.1428895 , 0.4743966 , 0.54742837, 0.1428895 ,
              0.4743966 ],
             [0.67102444, 0.6460316 , 0.75969887, 0.67102444, 0.6460316 ,
              0.75969887]], dtype=float32)

In [24]:
## Unravel a matrix into a vector
## *** WARNING *** order='C' implements C-style (row-major) order of the unravel operations
jnp.ravel(A, order='C')

DeviceArray([0.54742837, 0.1428895 , 0.4743966 , 0.67102444, 0.6460316 ,
             0.75969887], dtype=float32)

In [25]:
## Unravel a matrix into a vector
## *** WARNING *** order='F' implements Fortran-style (col-major) order of the unravel operations
jnp.ravel(A, order='F')

DeviceArray([0.54742837, 0.67102444, 0.1428895 , 0.6460316 , 0.4743966 ,
             0.75969887], dtype=float32)

In [26]:
## Unravel a matrix into a vector
## Note order='K' and order='A' are "not implemented" errors

In [27]:
## Slicing/indexing a vector
x = jnp.arange(10)
x[2:7]

DeviceArray([2, 3, 4, 5, 6], dtype=int32)

In [28]:
## Slicing/indexing a vector
x = jnp.arange(10)
x[7:]

DeviceArray([7, 8, 9], dtype=int32)

In [29]:
## Slicing/indexing a vector
x = jnp.arange(10)
x[:7]

DeviceArray([0, 1, 2, 3, 4, 5, 6], dtype=int32)

In [30]:
## Sorting
x = jnp.array([1,4,2])
x.sort()

DeviceArray([1, 2, 4], dtype=int32)

In [31]:
########################################
## Inspection type functions
########################################

In [32]:
x = jnp.array([2,3,4,4]).reshape(2,2)

In [33]:
## Shape/dimension of object
x.shape

(2, 2)

In [34]:
## Size (number elements)
x.size

4

In [35]:
## "type" of object (float, int, etc.)
x.dtype

dtype('int32')

In [36]:
#######################################
##
## Mathematical/numerics functions in numpy/scipy 
##
#######################################

In [37]:
## Square root
jnp.sqrt(4)

DeviceArray(2., dtype=float32)

In [38]:
## Square X^2
jnp.square(9)

DeviceArray(81, dtype=int32)

In [39]:
## Sequence of integers
x = jnp.arange(10)
x

DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [40]:
## Sum of vector
x.sum()

DeviceArray(45, dtype=int32)

In [41]:
## Cumsum of vector
x.cumsum()

DeviceArray([ 0,  1,  3,  6, 10, 15, 21, 28, 36, 45], dtype=int32)

In [42]:
## Cumprod of vector
(x+1).cumprod()

DeviceArray([      1,       2,       6,      24,     120,     720,
                5040,   40320,  362880, 3628800], dtype=int32)

In [43]:
## Discrete differences
xdiff = jnp.diff(a=x)
xdiff

DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)

In [44]:
## Absolute values
x = jnp.array([-2,-1,0,1,2])
jnp.abs(x)

DeviceArray([2, 1, 0, 1, 2], dtype=int32)

In [45]:
## Max 
jnp.max(x)

DeviceArray(2, dtype=int32)

In [46]:
## Min
jnp.min(x)

DeviceArray(-2, dtype=int32)

In [47]:
## Positional index where max occurs
jnp.argmax(x)

DeviceArray(4, dtype=int32)

In [48]:
## Positional index where min occurs
jnp.argmin(x)

DeviceArray(0, dtype=int32)

In [49]:
## Count zeroes
jnp.count_nonzero(x)

DeviceArray(4, dtype=int32)

In [50]:
## Trig functions
jnp.sin(10)

DeviceArray(-0.5440211, dtype=float32)

In [51]:
## Trig functions
jnp.cos(10)

DeviceArray(-0.8390715, dtype=float32)

In [52]:
## Exponential function
jnp.exp(1)

DeviceArray(2.7182817, dtype=float32)

In [53]:
## Logarithmic function
jnp.log(1)

DeviceArray(0., dtype=float32)

In [54]:
############################################
##
## Boolean type evaluations
##
############################################

In [55]:
## Logical comparison (equality)
jnp.array([2]) == jnp.array([2])

DeviceArray([ True], dtype=bool)

In [56]:
## Logical comparison (equality)
jnp.array([2]) != jnp.array([2])

DeviceArray([False], dtype=bool)

In [57]:
## Greater than
jnp.array([3]) >= jnp.array([1])

DeviceArray([ True], dtype=bool)

In [58]:
## Less than
jnp.array([1]) < jnp.array([0])

DeviceArray([False], dtype=bool)

In [59]:
## Broadcast boolean logic over vectors
x = jnp.arange(10)
x < 4

DeviceArray([ True,  True,  True,  True, False, False, False, False,
             False, False], dtype=bool)

In [60]:
## Check equality of vectors
x = jnp.array([1,2,3])
y = jnp.array([1,2,3])
x == y

DeviceArray([ True,  True,  True], dtype=bool)

In [61]:
## Equality elements row wise
jnp.equal(x,y)

DeviceArray([ True,  True,  True], dtype=bool)

In [62]:
## All.equal
jnp.alltrue(jnp.equal(x,y))

DeviceArray(True, dtype=bool)

In [63]:
## Boolean test (is in set)
a = jnp.array([1,2,3,4])
jnp.isin(element=a, test_elements=jnp.array([2]))

DeviceArray([False,  True, False, False], dtype=bool)

In [64]:
## Boolean test: (where condition met)
x = jnp.arange(10)
x[jnp.where(x>4)]

DeviceArray([5, 6, 7, 8, 9], dtype=int32)

In [65]:
## set difference
a = jnp.array([1,2,3,4])
b = jnp.array([2])
jnp.setdiff1d(a,b)

DeviceArray([1, 3, 4], dtype=int32)

In [66]:
## intersection
a = jnp.array([1,2,3,4])
b = jnp.array([2])
jnp.intersect1d(a,b)

DeviceArray([2], dtype=int32)

In [67]:
## Note: union1d function "not yet implemented" error

#jnp.union1d(a, jnp.array([10000]))

In [68]:
#################################################################
##
## Statistics functions
##
#################################################################

In [69]:
## Random normal vector
x = random.normal(key, shape=(1000,))
x = 2*x + 10
pd.Series(x).describe()

count    1000.000000
mean        9.998937
std         2.019339
min         3.788454
25%         8.573010
50%        10.021622
75%        11.303478
max        17.007282
dtype: float64

In [70]:
## Vector of ones
y = jnp.ones(1000)
pd.Series(y).value_counts()

1.0    1000
dtype: int64

In [71]:
## Add to vectors
z = x + y
pd.Series(z).describe()

count    1000.000000
mean       10.998937
std         2.019339
min         4.788454
25%         9.573010
50%        11.021622
75%        12.303478
max        18.007282
dtype: float64

In [72]:
## Multiply a scalar times a vector (broadcasting)
z = 10*x
pd.Series(z).describe()

count    1000.000000
mean       99.989357
std        20.193392
min        37.884541
25%        85.730106
50%       100.216217
75%       113.034779
max       170.072815
dtype: float64

In [73]:
## Mean
jnp.mean(z)

DeviceArray(99.98935, dtype=float32)

In [74]:
## Standard deviation
jnp.std(z)

DeviceArray(20.183294, dtype=float32)

In [75]:
## Variance
jnp.var(z)

DeviceArray(407.36536, dtype=float32)

In [76]:
## Percentiles
jnp.percentile(z, q=jnp.array([0,25,50,75,100]))

DeviceArray([ 37.88454 ,  85.7301  , 100.21622 , 113.034775, 170.07281 ],            dtype=float32)

In [77]:
## Covariance
key, subkey = random.split(key)
x = random.normal(key, shape=(1000,1))
y = random.normal(subkey, shape=(1000,1))
xy = jnp.concatenate((x,y), axis=1)
#xy.shape
jnp.cov(xy.T)

DeviceArray([[ 0.9426041 , -0.02633217],
             [-0.02633217,  0.95755184]], dtype=float32)

In [78]:
## Correlation matrix
jnp.corrcoef(xy.T)
#onp.corrcoef(xy.T)

DeviceArray([[ 1.        , -0.02771666],
             [-0.02771666,  1.        ]], dtype=float32)

In [79]:
##################################################################
##
## Vector/matrix operations, linear algebra
##
##################################################################

In [80]:
A = random.randint(key, minval=0, maxval=2, shape=(3,3))
A

DeviceArray([[0, 0, 0],
             [0, 0, 0],
             [1, 1, 1]], dtype=int32)

In [81]:
b = jnp.array([0,1,2]).reshape(3,1)
b

DeviceArray([[0],
             [1],
             [2]], dtype=int32)

In [82]:
## Matrix multiplication
jnp.matmul(A, A)

DeviceArray([[0, 0, 0],
             [0, 0, 0],
             [1, 1, 1]], dtype=int32)

In [83]:
## Vector matrix multiplication
jnp.matmul(A,b)

DeviceArray([[0],
             [0],
             [3]], dtype=int32)

In [84]:
## Dot product (two vectors)
jnp.dot(b.T,b)

DeviceArray([[5]], dtype=int32)

In [85]:
## Dot product matrix and Vector
B = jnp.array([[0,1,0],[0,0,2]])
#B
jnp.dot(B,b)

DeviceArray([[1],
             [4]], dtype=int32)

In [86]:
# Diagonal matrix
A = jnp.diag(v=jnp.array([1,2,3]), k=0)
A

DeviceArray([[1, 0, 0],
             [0, 2, 0],
             [0, 0, 3]], dtype=int32)

In [87]:
## Identity matrix
B = jnp.eye(3)
B

DeviceArray([[1., 0., 0.],
             [0., 1., 0.],
             [0., 0., 1.]], dtype=float32)

In [88]:
## Matrix multiply again
jnp.matmul(A,B)

DeviceArray([[1., 0., 0.],
             [0., 2., 0.],
             [0., 0., 3.]], dtype=float32)

In [89]:
## Cross/outer product
b = jnp.array([0,1,2]).reshape(3,1)
b

DeviceArray([[0],
             [1],
             [2]], dtype=int32)

In [90]:
## Cross/outer product
b = jnp.array([0,1,2]).reshape(3,1)
b

DeviceArray([[0],
             [1],
             [2]], dtype=int32)

In [91]:
## Dot/inner product between two vectors again
b.T.dot(b)

DeviceArray([[5]], dtype=int32)

In [92]:
## Outer product
jnp.outer(b,b.T)

DeviceArray([[0, 0, 0],
             [0, 1, 2],
             [0, 2, 4]], dtype=int32)

In [93]:
## Matrix multiplication again
X = jnp.arange(6).reshape(2,3)
Y = jnp.diag(jnp.array([0,1,2]))
jnp.matmul(X,Y)

DeviceArray([[ 0,  1,  4],
             [ 0,  4, 10]], dtype=int32)

In [94]:
## Matrix addition
Z = Y + jnp.eye(3) + jnp.zeros(9).reshape(3,3)
Z

DeviceArray([[1., 0., 0.],
             [0., 2., 0.],
             [0., 0., 3.]], dtype=float32)

In [95]:
## Inverse of matrix
jnp.linalg.inv(Z)

DeviceArray([[1.        , 0.        , 0.        ],
             [0.        , 0.5       , 0.        ],
             [0.        , 0.        , 0.33333334]], dtype=float32)

In [96]:
## Determinant of matrix
jnp.linalg.det(Z)

DeviceArray(6., dtype=float32)

In [97]:
## Trace
jnp.trace(Z)

DeviceArray(6., dtype=float32)

In [98]:
## Eigenvalues
jnp.linalg.eigvals(Z)

DeviceArray([1.+0.j, 2.+0.j, 3.+0.j], dtype=complex64)

In [99]:
## Condition number
jnp.linalg.cond(Z)

DeviceArray(3., dtype=float32)

In [100]:
## Norm (Frobenius)
jnp.linalg.norm(Z, ord='fro')

DeviceArray(3.7416575, dtype=float32)

In [101]:
## Singular value decomposition of matrix
u, s, vt = jnp.linalg.svd(Z)

In [102]:
## Matrix rank
jnp.linalg.matrix_rank(Z)

DeviceArray(3, dtype=int32)

In [103]:
## Qr decomposition of matrix
q, r = jnp.linalg.qr(Z)
print([q,r])

[DeviceArray([[ 1.,  0.,  0.],
             [-0.,  1.,  0.],
             [-0., -0.,  1.]], dtype=float32), DeviceArray([[1., 0., 0.],
             [0., 2., 0.],
             [0., 0., 3.]], dtype=float32)]


In [104]:
## Cholesky decomposition of matrix
jnp.linalg.cholesky(Z)

DeviceArray([[1.       , 0.       , 0.       ],
             [0.       , 1.4142135, 0.       ],
             [0.       , 0.       , 1.7320508]], dtype=float32)

In [105]:
##
## Least squares example
##

## Dimensions of the problem
nrow = 10000
ncol = 3
beta = jnp.array([0,0,1,-1]).reshape(ncol+1,1)

## Create dataset/design-matrix
X = random.normal(key, shape=(nrow*ncol,)).reshape(nrow,ncol)
ones = jnp.repeat(1.0, nrow).reshape(nrow,1)
X = jnp.hstack((X, ones))
## Construct linear predictor (y-true)
y = jnp.matmul(X, beta)

## Using linear algebra to compute estimated parameters
XtX = jnp.matmul(X.T, X)
Xty = jnp.matmul(X.T, y)

## Learned vector of betas
beta_hat = jnp.matmul(jnp.linalg.inv(XtX), Xty)
print(beta)
print(beta_hat)

#jnp.isclose(beta, beta_hat)

[[ 0]
 [ 0]
 [ 1]
 [-1]]
[[-2.0122171e-08]
 [ 9.5216226e-09]
 [ 1.0000000e+00]
 [-1.0000001e+00]]


In [106]:
## Least squares solution
jnp.linalg.lstsq(a=X, b=y)

(DeviceArray([[ 1.5450608e-07],
              [ 4.4876984e-08],
              [ 1.0000001e+00],
              [-9.9999958e-01]], dtype=float32),
 DeviceArray([2.1880242e-09], dtype=float32),
 DeviceArray(4, dtype=int32),
 DeviceArray([100.82345,  99.9497 ,  99.60676,  98.52325], dtype=float32))

In [107]:
########################################################################
##
## Note: many of the same decompositions exist in jax.scipy.linalg
##
########################################################################

In [108]:
## E.g. SVD using SciPy
jax.scipy.linalg.svd(Z)

(DeviceArray([[0., 0., 1.],
              [0., 1., 0.],
              [1., 0., 0.]], dtype=float32),
 DeviceArray([3., 2., 1.], dtype=float32),
 DeviceArray([[0., 0., 1.],
              [0., 1., 0.],
              [1., 0., 0.]], dtype=float32))

In [109]:
#############################
## Scipy special functions
#############################

In [110]:
## gamma function
jax.scipy.special.gammaln(4)

DeviceArray(1.7917598, dtype=float32)

In [111]:
## Confirm the log gamma function...
jnp.log(jnp.prod(jnp.arange(3)+1))

DeviceArray(1.7917595, dtype=float32)

In [112]:
##
## Note: many other specifial functions exist
##
## digamma
## betainc
## logit 
## expit
## logsumexp
## efr
## erfinv
## etc. etc. etc.
##

In [113]:
from sinfo import sinfo
sinfo()

-----
jax         0.2.9
matplotlib  3.3.4
numpy       1.20.1
pandas      1.2.2
seaborn     0.11.1
sinfo       0.3.1
-----
IPython             7.20.0
jupyter_client      6.1.11
jupyter_core        4.7.1
jupyterlab          3.0.9
notebook            6.2.0
-----
Python 3.7.10 | packaged by conda-forge | (default, Feb 19 2021, 16:07:37) [GCC 9.3.0]
Linux-3.10.0-1127.el7.x86_64-x86_64-with-centos-7.8.2003-Core
79 logical CPU cores, x86_64
-----
Session information updated at 2021-02-26 15:13
