# HPC in Python with TensorFlow

In [1]:
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

Depending on the CPU used and the TensorFlow binaries, we will see a number of messages (outside of the Jupyter Notebook), including the following:
`tensorflow/core/platform/cpu_feature_guard.cc:143] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA`

What could they mean?

AVX stands for Advanced Vector Extensions; these are instructions that the CPU can perform. They are specialized for vector operations (remember? SIMD, CPU instruction set, etc.).

Why do they appear?

The code that we are using was not compiled with these flags enabled. This means that TensorFlow assumes that the CPU does not support these instructions and uses non-optimized ones instead. The reason is that this allows the binary (= compiled code) to also run on a CPU that does not support them, while we only lose some speed.
(Yes, technically TensorFlow can be faster when compiled natively on your computer, but that takes time and effort.)

In [2]:
from jax import config
config.update("jax_enable_x64", True)

In [3]:
import numpy as np
import numba
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import torch

Let's start with a simple comparison of Numpy, an AOT compiled library, versus pure Python

In [4]:
size1 = 100000
list1 = [np.random.uniform() for _ in range(size1)]
list2 = [np.random.uniform() for _ in range(size1)]
list_zeros = [0] * size1

ar1 = np.array(list1)
ar2 = np.random.uniform(size=size1)  # way more efficient!
ar_zeros = np.zeros_like(ar1) # quite useful function the *_like -> like the object
# we could also create the below, in general better:
# ar_empty = np.empty(shape=size1, dtype=np.float64)

In [5]:
%%timeit
for i in range(size1):
    list_zeros[i] = list1[i] + list2[i]

4.45 ms ± 96.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
%%timeit
ar1 + ar2

58.5 μs ± 961 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


( _playground_ : we can also try assignments here or similar)
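As a starting point for the playground, here is a minimal sketch of two ways to write the sum into an existing array. The variable name `ar_out` is introduced only for this illustration; the other names mirror the cells above.

```python
import numpy as np

size1 = 100_000
ar1 = np.random.uniform(size=size1)
ar2 = np.random.uniform(size=size1)
ar_zeros = np.zeros_like(ar1)

# Slice assignment: ar1 + ar2 creates a temporary array that is then
# copied into the existing buffer of ar_zeros.
ar_zeros[:] = ar1 + ar2

# The `out` argument lets Numpy write the result directly into a
# preallocated array, without the temporary.
ar_out = np.empty_like(ar1)
np.add(ar1, ar2, out=ar_out)
```

Both variants can be timed with `%%timeit` in the same way as the cells above.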

### Fast and slow

Before we go deeper into the topic, we can draw two conclusions:
- slow Python is not "slow": it is still blazingly fast on an absolute scale, e.g. if you need to loop over a few hundred points, it's still nothing. But it can add up!
- Numpy is about a factor of 75 faster in the run above (and more readable!)

=> there is _no reason_ to ever add (numerical) arrays with a for loop in Python (except for numba jit)

As mentioned, TensorFlow is basically Numpy. Let's check that out

In [7]:
rnd1_np = np.random.uniform(size=10, low=0, high=10)
rnd1_np  # adding a return value on the last line without assigning it prints the value

array([2.17546514, 8.92779882, 4.33083793, 4.96530809, 6.80797735,
       8.06737767, 3.36438422, 0.35439775, 4.55640887, 2.92539731])

In [8]:
rnd1 = tnp.random.uniform(size=(10,),
                         low=0,
                         high=10)
rnd2 = tf.random.uniform(shape=(10,),  # notice the "shape" argument: it's more picky than Numpy
                         minval=0,
                         maxval=10,
                         dtype=tf.float64)

In [9]:
rnd1

<tf.Tensor: shape=(10,), dtype=float64, numpy=
array([8.94494362, 8.76141472, 1.52177648, 0.18231205, 3.20450439,
       4.4144011 , 8.12311613, 2.13679637, 4.84699734, 8.51972251])>

This is in fact a wrapped Numpy array and can be converted to an array explicitly.

In [10]:
type(rnd1.numpy()), type(np.asarray(rnd1))

(numpy.ndarray, numpy.ndarray)

Other operations act as we would expect.

In [11]:
rnd1 + 10

<tf.Tensor: shape=(10,), dtype=float64, numpy=
array([18.94494362, 18.76141472, 11.52177648, 10.18231205, 13.20450439,
       14.4144011 , 18.12311613, 12.13679637, 14.84699734, 18.51972251])>

... and it converts itself (often) to Numpy when needed.

In [12]:
np.sqrt(rnd1)

array([2.99080986, 2.9599687 , 1.23360305, 0.42698015, 1.79011295,
       2.10104762, 2.85010809, 1.4617785 , 2.20158973, 2.91885637])

We can slice it...

In [13]:
rnd1[1:3]

<tf.Tensor: shape=(2,), dtype=float64, numpy=array([8.76141472, 1.52177648])>

...expand it....

In [14]:
rnd1[None, :, None]

<tf.Tensor: shape=(1, 10, 1), dtype=float64, numpy=
array([[[8.94494362],
        [8.76141472],
        [1.52177648],
        [0.18231205],
        [3.20450439],
        [4.4144011 ],
        [8.12311613],
        [2.13679637],
        [4.84699734],
        [8.51972251]]])>

...and broadcast with the known (maybe slightly stricter) rules

In [15]:
matrix1 = rnd1[None, :] * rnd1[:, None]
matrix1

<tf.Tensor: shape=(10, 10), dtype=float64, numpy=
array([[8.00120164e+01, 7.83703608e+01, 1.36122048e+01, 1.63077097e+00,
        2.86641111e+01, 3.94865689e+01, 7.26608158e+01, 1.91135231e+01,
        4.33561180e+01, 7.62084375e+01],
       [7.83703608e+01, 7.67623879e+01, 1.33329148e+01, 1.59731144e+00,
        2.80759919e+01, 3.86763988e+01, 7.11699892e+01, 1.87213592e+01,
        4.24665539e+01, 7.46448222e+01],
       [1.36122048e+01, 1.33329148e+01, 2.31580365e+00, 2.77438182e-01,
        4.87653941e+00, 6.71773176e+00, 1.23615671e+01, 3.25172646e+00,
        7.37604655e+00, 1.29651133e+01],
       [1.63077097e+00, 1.59731144e+00, 2.77438182e-01, 3.32376818e-02,
        5.84219749e-01, 8.04798492e-01, 1.48094191e+00, 3.89563717e-01,
        8.83665998e-01, 1.55324803e+00],
       [2.86641111e+01, 2.80759919e+01, 4.87653941e+00, 5.84219749e-01,
        1.02688484e+01, 1.41459677e+01, 2.60305613e+01, 6.84737336e+00,
        1.55322243e+01, 2.73014882e+01],
       [3.94865689e+01, 3

## Exercise 1

*We will have exercises of this type throughout the notebook*

Can you do the same with jax? Start with the following arrays below.

Anything that surprises you?

In [16]:
jrnd1 = jnp.asarray(rnd1)
jrnd2 = jnp.asarray(rnd2)
type(ar2)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


numpy.ndarray

In [17]:
torchrnd1 = torch.tensor(rnd1)

TypeError: Scalar tensor has no `len()`

In [18]:
torchrnd1 = torch.tensor(jrnd1)

TypeError: len() of unsized object

In [19]:
torchrnd1 = torch.tensor(np.array(rnd1))

In [20]:
tnp.array(torchrnd1)

TypeError: Cannot interpret 'tensor([8.9449, 8.7614, 1.5218, 0.1823, 3.2045, 4.4144, 8.1231, 2.1368, 4.8470,
        8.5197], dtype=torch.float64)' as a data type

In [21]:
jnp.array(torchrnd1)

Array([8.94494362, 8.76141472, 1.52177648, 0.18231205, 3.20450439,
       4.4144011 , 8.12311613, 2.13679637, 4.84699734, 8.51972251],      dtype=float64)

In [22]:
type(np.asarray(torchrnd1))

numpy.ndarray

### Array API

These tensors implement (partially) the array API, a specification for communicating the data of an array-like object. `__array__` is one of the (fallback) methods, but there are more.

Do not call these methods directly in applications; instead, use `np.asarray`, which will figure out the most efficient way to retrieve the underlying data storage.
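To make the mechanism concrete, here is a minimal sketch of an object that exposes its data through `__array__`. The class `Wrapped` is a hypothetical toy container invented for this illustration; it is not how TensorFlow, JAX or torch implement the method.

```python
import numpy as np


class Wrapped:
    """Hypothetical toy container, only for illustration."""

    def __init__(self, values):
        self._values = np.array(values, dtype=np.float64)

    def __array__(self, dtype=None, copy=None):
        # Numpy calls this as a fallback when it needs an ndarray.
        if dtype is None:
            return self._values
        return self._values.astype(dtype)


wrapped = Wrapped([1.0, 2.0, 3.0])
as_array = np.asarray(wrapped)   # conversion through the protocol
sqrt_values = np.sqrt(wrapped)   # Numpy functions convert implicitly
```

This is the reason why `np.sqrt(rnd1)` above returned a plain Numpy array.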

In [23]:
torchrnd1.__array__(), jrnd1.__array__(), rnd1.__array__()

(array([8.94494362, 8.76141472, 1.52177648, 0.18231205, 3.20450439,
        4.4144011 , 8.12311613, 2.13679637, 4.84699734, 8.51972251]),
 array([8.94494362, 8.76141472, 1.52177648, 0.18231205, 3.20450439,
        4.4144011 , 8.12311613, 2.13679637, 4.84699734, 8.51972251]),
 array([8.94494362, 8.76141472, 1.52177648, 0.18231205, 3.20450439,
        4.4144011 , 8.12311613, 2.13679637, 4.84699734, 8.51972251]))

## Equivalent operations

Many operations that exist in Numpy also exist in JAX & Friends, sometimes with a different name.

The concept, however, is exactly the same: we have higher-level objects such as Tensors (or arrays) and call operations on them with arguments. This is (theoretically) a "strong limitation" of what we can do. However, since we do math, there is only a limited set of operations we need, and in practice this suffices for 98% of the cases.

Therefore we won't dive too deep into the operations that TensorFlow/JAX/torch/Numpy offer, but it is suggested to read the API docs of [TensorFlow](https://www.tensorflow.org/versions), [JAX](https://jax.readthedocs.io/en/latest/jax.numpy.html) or [torch](https://pytorch.org/docs/stable/torch.html); many are self-explanatory. It can be surprising that there is also some support for more exotic elements such as [RaggedTensors and operations](https://www.tensorflow.org/api_docs/python/tf/ragged?) and [SparseTensors and operations](https://www.tensorflow.org/api_docs/python/tf/sparse?) in TensorFlow or a (partial) [SciPy substitute](https://jax.readthedocs.io/en/latest/jax.scipy.html).

Mostly, the differences and the terminology will be introduced.

In [24]:
tf.sqrt(rnd1), tnp.sqrt(rnd1)

(<tf.Tensor: shape=(10,), dtype=float64, numpy=
 array([2.99080986, 2.9599687 , 1.23360305, 0.42698015, 1.79011295,
        2.10104762, 2.85010809, 1.4617785 , 2.20158973, 2.91885637])>,
 <tf.Tensor: shape=(10,), dtype=float64, numpy=
 array([2.99080986, 2.9599687 , 1.23360305, 0.42698015, 1.79011295,
        2.10104762, 2.85010809, 1.4617785 , 2.20158973, 2.91885637])>)

In [25]:
tf.reduce_sum(matrix1, axis=0), tnp.sum(matrix1, axis=0)  # with the axis argument to specify over which to reduce

(<tf.Tensor: shape=(10,), dtype=float64, numpy=
 array([453.11492741, 443.81809024,  77.08708604,   9.23519617,
        162.32732542, 223.61583442, 411.48444641, 108.2415244 ,
        245.5294233 , 431.57493327])>,
 <tf.Tensor: shape=(10,), dtype=float64, numpy=
 array([453.11492741, 443.81809024,  77.08708604,   9.23519617,
        162.32732542, 223.61583442, 411.48444641, 108.2415244 ,
        245.5294233 , 431.57493327])>)

### DTypes

TensorFlow is pickier about dtypes than Numpy and does not automatically cast dtypes. That's why we often get a dtype error. Solution: add an `x = tf.cast(x, dtype=tf.float64)` (or whatever dtype we want) to cast to the right dtype.

One noticeable difference: TensorFlow and JAX use float32 as the default for all operations. Neural networks function quite well with that (sometimes even with float16), but for (most) scientific use cases we want to use float64. So yes, [currently](https://github.com/tensorflow/tensorflow/issues/26033), we have to define this in (too) many places.
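Why float64 matters for scientific work can be illustrated with plain Numpy scalars. This is only a sketch of the precision difference between the two dtypes; the variable names are chosen for this example.

```python
import numpy as np

small = 1e-8
one32 = np.float32(1.0)
one64 = np.float64(1.0)

# In float32 the small contribution is below the resolution and is lost...
lost_in_float32 = (one32 + np.float32(small)) == one32

# ...while float64 still resolves it.
kept_in_float64 = (one64 + np.float64(small)) != one64

# Machine epsilon: the spacing between 1.0 and the next larger number.
eps32 = np.finfo(np.float32).eps
eps64 = np.finfo(np.float64).eps
print(lost_in_float32, kept_in_float64, eps32, eps64)
```

In a likelihood that sums many small contributions, such lost digits can accumulate.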

## What we can't do: assignments

The idea of JAX & friends revolves around building an abstract representation of the mathematical operations, sometimes referred to as a graph$^{1)}$, inside a JITted function. This has one profound implication, namely that we cannot make an _assignment_ to a Tensor, because it is a node in a graph. The logic just does not work (exception: `tf.Variable`). This does not mean that JAX & friends would not perform in-place operations _behind the scenes_ - they very well do if it is safe to do so. Since JAX & friends know the whole graph with all dependencies, this can be figured out. See also the [JAX docs about assignments](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#array-updates-x-at-idx-set-y).

Even in eager mode, without JIT compilation, where assignments could work (as for Numpy arrays), they are forbidden for consistency (one of the great plus points of TensorFlow).

_1) If you're familiar with TensorFlow 1, this statement will strike you as pretty obvious; in TensorFlow 2, JAX & friends, this is luckily more hidden._

In [26]:
rnd1_np[5] = 42

In [27]:
try:
    rnd1[5] = 42
except TypeError as error:
    print(error)

'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
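JAX arrays behave the same way, and JAX offers the functional `.at[...].set(...)` syntax described in the JAX docs linked above. A minimal sketch follows; the array `jrnd` is created here only for illustration.

```python
import jax.numpy as jnp

jrnd = jnp.arange(10.0)

# Direct item assignment is rejected, just as for the TensorFlow tensor.
try:
    jrnd[5] = 42.0
    assignment_failed = False
except TypeError as error:
    assignment_failed = True
    print(error)

# The functional update returns a new array and leaves the original untouched.
updated = jrnd.at[5].set(42.0)
print(jrnd[5], updated[5])
```

Inside a JIT-compiled function, such an update can be turned into an in-place operation behind the scenes when that is safe.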


### Speed comparison

Let's do the same calculation as with Numpy. The result should be comparable: both are AOT compiled libraries specialized on numerical, vectorized operations.

In [28]:
rnd1_big = tf.random.uniform(shape=(size1,),  # notice the "shape" argument: it's more picky than Numpy
                         minval=0,
                         maxval=10,
                         dtype=tf.float64)
rnd2_big = tf.random.uniform(shape=(size1,),
                         minval=0,
                         maxval=10,
                         dtype=tf.float64)

In [128]:
jrnd1_big = jnp.asarray(rnd1_big)

In [29]:
%%timeit
rnd1_big + rnd2_big

105 μs ± 3.84 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [30]:
%%timeit  # using numpy, same as before
ar1 + ar2

37.9 μs ± 197 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


That is the same order of magnitude as Numpy. Let's compare with smaller arrays.

In [31]:
rnd1_np = np.asarray(rnd1)
rnd2_np = np.asarray(rnd2)

In [32]:
rnd1_np

array([8.94494362, 8.76141472, 1.52177648, 0.18231205, 3.20450439,
       4.4144011 , 8.12311613, 2.13679637, 4.84699734, 8.51972251])

In [33]:
%%timeit
rnd1_np + rnd2_np

385 ns ± 6.85 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [34]:
%%timeit
rnd1 + rnd2

43 μs ± 311 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### TensorFlow is slow?

We now see a significant difference in the runtime! This is because TensorFlow has a larger overhead than Numpy. As seen before, this is barely noticeable for larger arrays, but for very small calculations it is visible.

There is more overhead because TensorFlow tries to be "smarter" about many things than Numpy and does not simply directly execute the computation.

The cost is a slowdown on very small operations but a better scaling and improved performance with larger arrays and more complicated calculations.

In [35]:
# relative speeds may differ, depending on the hardware used.
# size_big = 10  # numpy faster
size_big = 20000  # sameish
# size_big = 100000  # TF faster
# size_big = 1000000  # TF faster
# size_big = 10000000  # TF faster
# size_big = 100000000  # TF faster

In [36]:
%%timeit
tf.random.uniform(shape=(size_big,), dtype=tf.float64) + tf.random.uniform(shape=(size_big,), dtype=tf.float64)

519 μs ± 21.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [37]:
%%timeit
np.random.uniform(size=(size_big,)) + np.random.uniform(size=(size_big,))

290 μs ± 14 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Computing kernels

In general, TensorFlow is stricter than Numpy and JAX about the input arguments it requires: it does less automatic dtype casting and asks more explicitly for shapes. For example, integers don't work in the logarithm. However, the resulting error message illustrates the kernel dispatch system of TensorFlow very well, so let's do it!

In [38]:
try:
    tf.math.log(5)
except tf.errors.InvalidArgumentError as error:  # raised because int32 is not an allowed dtype
    print(error)

InvalidArgumentError: Value for attr 'T' of int32 is not in the list of allowed values: bfloat16, half, float, double, complex64, complex128
	; NodeDef: {{node Log}}; Op<name=Log; signature=x:T -> y:T; attr=T:type,allowed=[DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128]> [Op:Log] name: 

In [39]:
torch.log(5)  # torch is a bit... unfriendly sometimes

TypeError: log(): argument 'input' (position 1) must be Tensor, not int

What we see here: TensorFlow looks for a kernel of the `Log` operation that supports the given dtype and does not find one. Kernels come in different classifications:
- GPU: normal GPU kernel
- CPU: normal CPU kernel
- XLA: [Accelerated Linear Algebra](https://www.tensorflow.org/xla) is a high-level compiler that can fuse operations, which would result in single calls to a fused kernel. JAX JIT is built around XLA.

## Just-in-time compilation

Let's see the JIT in action. For that, we use the example from the slides and start modifying it.

In [40]:
def add_log(x, y):
    print('running Python')
    tf.print("running compiled code")
    x_sq = tnp.log(x)
    y_sq = tnp.log(y)
    return x_sq + y_sq

As seen before, we can use it like Python. To know when the actual Python code is executed, we inserted a `print` and a `tf.print` (the JAX counterpart is `jax.debug.print`). The latter is a TensorFlow operation and is therefore expected to be called every time we compute something.

In [41]:
add_log(4., 5.)

running Python
running compiled code


<tf.Tensor: shape=(), dtype=float64, numpy=2.995732273553991>

In [42]:
add_log(42., 52.)

running Python
running compiled code


<tf.Tensor: shape=(), dtype=float64, numpy=7.688913336864796>

As we see, both the Python and the TensorFlow operations execute. Now we can do the same with a decorator. Note that so far we entered pure Python numbers, not Tensors. Since we ran in eager mode, this did not matter.

In [43]:
@tf.function
def add_log_tf(x, y):
    print('running Python')
    tf.print("running TensorFlow")
    x_sq = tf.math.log(x)
    y_sq = tf.math.log(y)
    return x_sq + y_sq

In [44]:
add_log_tf(1., 2.)

running Python
running TensorFlow


<tf.Tensor: shape=(), dtype=float32, numpy=0.6931472>

In [45]:
add_log_tf(11., 21.)  # again with different numbers

running Python
running TensorFlow


<tf.Tensor: shape=(), dtype=float32, numpy=5.442418>

As we see, Python is still run: 11. is not equal to 1., and TensorFlow does not convert Python numbers to Tensors but treats them as static values, so the function is traced again for every new value. Let's use it the right way, with Tensors.

In [46]:
add_log_tf(tf.constant(1.), tf.constant(2.))  # first compilation

running Python
running TensorFlow


<tf.Tensor: shape=(), dtype=float32, numpy=0.6931472>

In [47]:
add_log_tf(tf.constant(11.), tf.constant(22.))

running TensorFlow


<tf.Tensor: shape=(), dtype=float32, numpy=5.488938>

Now only the TensorFlow operations get executed! Everything else became static. We can illustrate this with a more extreme example:

In [48]:
@tf.function(autograph=False)
def add_rnd(x):
    print('running Python')
    tf.print("running TensorFlow")
    rnd_np = np.random.uniform()
    rnd_tf = tf.random.uniform(shape=())
    return x * rnd_np, x * rnd_tf

In [49]:
add_rnd(tf.constant(1.))

running Python
running TensorFlow


(<tf.Tensor: shape=(), dtype=float32, numpy=0.5509533>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5707332>)

The first time, the numpy code was executed as well, no difference so far. However, running it a second time, only the TensorFlow parts can change

In [50]:
add_rnd(tf.constant(1.))

running TensorFlow


(<tf.Tensor: shape=(), dtype=float32, numpy=0.5509533>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6005988>)

In [51]:
add_rnd(tf.constant(2.))

running TensorFlow


(<tf.Tensor: shape=(), dtype=float32, numpy=1.1019067>,
 <tf.Tensor: shape=(), dtype=float32, numpy=1.6360176>)

We now see clearly: TensorFlow executes the function but _only cares about the TensorFlow operations_ , everything else is regarded as static. This can be a large pitfall! If we executed this function _without_ the decorator, we would get a different result, since Numpy would also sample a new random number every time.
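The same pitfall exists with `jax.jit`. The following is a minimal sketch under the assumption that both calls use inputs of the same shape and dtype; the names `scale_rnd` and `trace_log` are made up for this illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp

trace_log = []  # filled by plain Python code, i.e. only while tracing


@jax.jit
def scale_rnd(x):
    trace_log.append("running Python")
    rnd_np = np.random.uniform()  # sampled once, then baked in as a constant
    return x * rnd_np


first = scale_rnd(jnp.asarray(1.0))
second = scale_rnd(jnp.asarray(2.0))  # same shape and dtype: no retracing

print(len(trace_log))            # number of times the Python body ran
print(first, second)             # second is exactly twice the first
```

The Numpy random number is drawn during tracing only, so every later call reuses the same value.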

## Using XLA

So far, `tf.function` created a computational graph that allowed inputs to have various shapes without recompilation. Instead of this representation, another one, `XLA` (Accelerated Linear Algebra), is also available. It is a stricter subset of what the graph representation allows (i.e. no dynamic shapes, no `tf.print`!) but also more performant. Technically, it lowers to LLVM IR and performs optimizations at this level.

To enable it in TF, we can use the switch `jit_compile=True`

In [52]:
@tf.function(autograph=False, jit_compile=True)
def add_rnd(x):
    print('running Python')
    # tf.print("running TensorFlow")  # not available in XLA!
    rnd_np = np.random.uniform()
    rnd_tf = tf.random.uniform(shape=())
    return x * rnd_np, x * rnd_tf

In [53]:
add_rnd(tf.constant(2.))

running Python




(<tf.Tensor: shape=(), dtype=float32, numpy=0.452369>,
 <tf.Tensor: shape=(), dtype=float32, numpy=1.1837895>)

## JAX JIT

JAX is built around `XLA` and lowers everything to this representation. In comparison to TF, it is more explicit about when to specialize on an argument.

The most important options specify the static arguments (i.e. those that are used to create a new specialization of the function): `static_argnums` and `static_argnames` give the position or the name of the argument, respectively.

In [54]:
@jax.jit
def squaref(x):
    return x ** 2

In [55]:
from functools import partial

In [56]:
@partial(jax.jit, static_argnames=['subtract'])
def square_or_subtract(x, subtract):
    if subtract:
        return x - 2
    else:
        return x ** 2

In [57]:
squaref(4.)

Array(16., dtype=float64, weak_type=True)

In [58]:
square_or_subtract(4., True)

Array(2., dtype=float64, weak_type=True)

In [59]:
square_or_subtract(4., False)

Array(16., dtype=float64, weak_type=True)
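For contrast, here is a sketch of what happens when the flag is not marked as static, together with `jax.lax.cond` as one possible alternative that keeps the flag dynamic. The function names with the suffixes `_traced` and `_cond` are invented for this example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def square_or_subtract_traced(x, subtract):
    # Without static_argnames, `subtract` is an abstract tracer here,
    # so Python cannot decide which branch to take.
    if subtract:
        return x - 2
    return x ** 2


try:
    square_or_subtract_traced(jnp.asarray(4.0), True)
    branching_failed = False
except TypeError:  # JAX's tracer-to-bool errors derive from TypeError
    branching_failed = True


@jax.jit
def square_or_subtract_cond(x, subtract):
    # Both branches become part of the compiled function;
    # one of them is selected at run time.
    return jax.lax.cond(subtract, lambda v: v - 2, lambda v: v ** 2, x)


subtracted = square_or_subtract_cond(jnp.asarray(4.0), True)
squared = square_or_subtract_cond(jnp.asarray(4.0), False)
```

With `static_argnames`, each value of the flag gets its own compiled specialization; with `jax.lax.cond`, a single compiled function handles both values.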

To pass an argument 

### Large functions

That being said, we can build graphs that require thousands of lines of Python code to stick them together correctly. Function calls in function calls etc are all possible.

### Shapes

Tensors have a shape, similar to Numpy arrays. But Tensors have two kinds of shapes, a static and a dynamic shape. The static shape is what can be inferred _before_ executing the computation, while the dynamic shape is only inferred during the execution of the code. The latter typically arises with random variables and masking or cuts.

We can access the static shape with `Tensor.shape`

If the shape is known inside a graph, this will be the same. If the shape is unknown, the unknown axis will be None.

Note that unknown shapes are not supported in `XLA` (and therefore not at all in JAX).

In [126]:
@tf.function
def func_shape_tf(x):
    print(f"static shape: {x.shape}")  # static shape
    tf.print('dynamic shape ',tf.shape(x))  # dynamic shape
    x = x[x>3.5]
    print(f"static shape cuts applied: {x.shape}")  # static shape
    tf.print('dynamic shape cuts applied',tf.shape(x))  # dynamic shape

In [61]:
func_shape_tf(rnd1)

static shape: (10,)
static shape cuts applied: (None,)
dynamic shape  [10]
dynamic shape cuts applied [6]
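Since a cut such as `x[x > 3.5]` produces a data-dependent shape, it is not available inside `jax.jit`. A common workaround, sketched here with a made-up function `sum_above`, is to keep the full static shape and neutralize the rejected entries with `jnp.where`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_above(x, threshold):
    mask = x > threshold
    # The output of jnp.where has the same static shape as x;
    # entries that fail the cut contribute zero to the sum.
    total = jnp.sum(jnp.where(mask, x, 0.0))
    count = jnp.sum(mask)
    return total, count


values = jnp.asarray([1.0, 4.0, 5.0, 2.0])
total, count = sum_above(values, 3.5)
print(total, count)
```

The number of surviving entries is still available as a value (`count`), just not as a shape.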


We can access the axes by indexing

In [62]:
rnd3 = rnd1[None, :] * rnd1[:, None]
rnd3

<tf.Tensor: shape=(10, 10), dtype=float64, numpy=
array([[8.00120164e+01, 7.83703608e+01, 1.36122048e+01, 1.63077097e+00,
        2.86641111e+01, 3.94865689e+01, 7.26608158e+01, 1.91135231e+01,
        4.33561180e+01, 7.62084375e+01],
       [7.83703608e+01, 7.67623879e+01, 1.33329148e+01, 1.59731144e+00,
        2.80759919e+01, 3.86763988e+01, 7.11699892e+01, 1.87213592e+01,
        4.24665539e+01, 7.46448222e+01],
       [1.36122048e+01, 1.33329148e+01, 2.31580365e+00, 2.77438182e-01,
        4.87653941e+00, 6.71773176e+00, 1.23615671e+01, 3.25172646e+00,
        7.37604655e+00, 1.29651133e+01],
       [1.63077097e+00, 1.59731144e+00, 2.77438182e-01, 3.32376818e-02,
        5.84219749e-01, 8.04798492e-01, 1.48094191e+00, 3.89563717e-01,
        8.83665998e-01, 1.55324803e+00],
       [2.86641111e+01, 2.80759919e+01, 4.87653941e+00, 5.84219749e-01,
        1.02688484e+01, 1.41459677e+01, 2.60305613e+01, 6.84737336e+00,
        1.55322243e+01, 2.73014882e+01],
       [3.94865689e+01, 3

In [63]:
tf.shape(rnd3)

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([10, 10], dtype=int32)>

In [64]:
rnd3.shape[1]

10

## Variables

Stateful variables pose many problems for performance and make functions impure: the same input can give different outputs! JAX therefore omits variables completely.

TensorFlow offers the possibility to have stateful objects inside a compiled graph (which e.g. is not possible with Numba). The most commonly used one is the `tf.Variable`. Technically, variables are captured automatically when the function is compiled and belong to it.

In [65]:
var1 = tf.Variable(1.)

In [66]:
@tf.function(autograph=False)
def scale_by_var(x):
    print('running Python')
    tf.print("running TensorFlow")
    return x * var1

In [67]:
scale_by_var(tf.constant(1.))

running Python
running TensorFlow


<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

In [68]:
scale_by_var(tf.constant(2.))

running TensorFlow


<tf.Tensor: shape=(), dtype=float32, numpy=2.0>

In [69]:
var1.assign(42.)
scale_by_var(tf.constant(1.))

running TensorFlow


<tf.Tensor: shape=(), dtype=float32, numpy=42.0>

As we see, the output changed. This is of course especially useful in the context of model fitting libraries, be it likelihoods or neural networks.

In [70]:
def add_rnd(x):
    print('running Python')
    tf.print("running TensorFlow")
    rnd_np = np.random.uniform()
    rnd_tf = tf.random.uniform(shape=())
    return x * rnd_np, x * rnd_tf

In [71]:
add_rnd(tf.constant(1.))

running Python
running TensorFlow


(<tf.Tensor: shape=(), dtype=float32, numpy=0.3851656>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.8145858>)

In [72]:
add_rnd(tf.constant(2.))

running Python
running TensorFlow


(<tf.Tensor: shape=(), dtype=float32, numpy=0.09448713>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.79835534>)

This means that Numpy is fully usable in eager mode, but not when the function is decorated.

In [73]:
def try_np_sqrt(x):
    return np.sqrt(x)

In [74]:
try_np_sqrt(tf.constant(5.))

2.236068

In [75]:
try_np_sqrt_tf = tf.function(try_np_sqrt, autograph=False)  # equivalent to decorator

In [76]:
try:
    try_np_sqrt_tf(tf.constant(5.))
except NotImplementedError as error:
    print(error)

Cannot convert a symbolic tf.Tensor (x:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.


As we see, Numpy complains in graph mode, since it cannot handle the symbolic Tensor.

Having the `tf.function` decorator means that we can't use any Python dynamicity. The following works nicely when not decorated but fails when decorated:

In [77]:
def greater_python(x, y):
    if x > y:
        return True
    else:
        return False

In [78]:
greater_python(tf.constant(1.), tf.constant(2.))

False

This works in eager mode, but it will fail with the graph decorator.

In [79]:
greater_python_tf = tf.function(greater_python, autograph=False)

In [80]:
try:
    greater_python_tf(tf.constant(1.), tf.constant(2.))
except Exception as error:
    print(error)

Using a symbolic `tf.Tensor` as a Python `bool` is not allowed. You can attempt the following resolutions to the problem: If you are running in Graph mode, use Eager execution mode or decorate this function with @tf.function. If you are using AutoGraph, you can try decorating this function with @tf.function. If that does not work, then you may be using an unsupported feature or your source code may not be visible to AutoGraph. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code for more information.


The error message hints at something: while this does not work now (Python does not yet know the value of the Tensors, so it can't decide whether the comparison will evaluate to True or False), there is the possibility of "autograph": it automatically converts (a subset of) Python to TensorFlow: while loops, for loops through Tensors and conditionals. However, this is usually less effective and more error-prone than explicitly using the `tf.*` functions. Let's try it!

In [81]:
greater_python_tf_autograph = tf.function(greater_python, autograph=True)

In [82]:
greater_python_tf_autograph(tf.constant(1.), tf.constant(2.))

<tf.Tensor: shape=(), dtype=bool, numpy=False>

This now works seamlessly! But we can never be sure that it will.
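The explicit `tf.*` route mentioned above can be sketched with `tf.cond`, which places both branches in the graph and selects one at run time. The function name `greater_tf` is chosen for this example; for this particular comparison, `tf.greater(x, y)` alone would of course also do.

```python
import tensorflow as tf


@tf.function(autograph=False)
def greater_tf(x, y):
    # The condition is a symbolic Tensor; tf.cond defers the decision
    # to the execution of the graph instead of asking Python.
    return tf.cond(x > y,
                   lambda: tf.constant(True),
                   lambda: tf.constant(False))


smaller_first = greater_tf(tf.constant(1.), tf.constant(2.))
larger_first = greater_tf(tf.constant(3.), tf.constant(2.))
print(smaller_first.numpy(), larger_first.numpy())
```

This works without autograph because no Python `if` ever has to evaluate a symbolic Tensor.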

We can also look explicitly at the code that autograph generates.

In [83]:
code = tf.autograph.to_code(greater_python)
print(code)

def tf__greater_python(x, y):
    with ag__.FunctionScope('greater_python', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal retval_, do_return
            do_return, retval_ = vars_

        def if_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = True
            except:
                do_return = False
                raise

        def else_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = False
            except:
                do_return = False
                raise
        ag__.if_stmt(ag__.ld(x) > ag__.ld(y), if_body, else_body, get_state, set_state, ('do_r

## Performance

In the end, this is what matters. And a comparison would be nice. Let's do that and see how Numpy and TensorFlow compare.

In [84]:
nevents = 10000000
data_tf = tf.random.uniform(shape=(nevents,), dtype=tf.float64)
data_np = np.random.uniform(size=(nevents,))

In [85]:
def calc_np(x):
    x_init = x
    i = 42.
    x = np.sqrt(np.abs(x_init * (i + 1.)))
    x = np.cos(x - 0.3)
    x = np.power(x, i + 1)
    x = np.sinh(x + 0.4)
    x = x ** 2
    x = x / np.mean(x)
    x = np.abs(x)
    logx = np.log(x)
    x = np.mean(logx)
    
    x1 = np.sqrt(np.abs(x_init * (i + 1.)))
    x1 = np.cos(x1 - 0.3)
    x1 = np.power(x1, i + 1)
    x1 = np.sinh(x1 + 0.4)
    x1 = x1 ** 2
    x1 = x1 / np.mean(x1)
    x1 = np.abs(x1)
    logx = np.log(x1)
    x1 = np.mean(logx)
    
    x2 = np.sqrt(np.abs(x_init * (i + 1.)))
    x2 = np.cos(x2 - 0.3)
    x2 = np.power(x2, i + 1)
    x2 = np.sinh(x2 + 0.4)
    x2 = x2 ** 2
    x2 = x2 / np.mean(x2)
    x2 = np.abs(x2)
    logx = np.log(x2)
    x2 = np.mean(logx)
    return x + x1 + x2

calc_np_numba = numba.njit(parallel=True)(calc_np)

In [86]:
def calc_tf(x):
    x_init = x
    i = 42.
    x = tf.sqrt(tf.abs(x_init * (tf.cast(i, dtype=tf.float64) + 1.)))
    x = tf.cos(x - 0.3)
    x = tf.pow(x, tf.cast(i + 1, tf.float64))
    x = tf.sinh(x + 0.4)
    x = x ** 2
    x = x / tf.reduce_mean(x)
    x = tf.abs(x)
    x = tf.reduce_mean(tf.math.log(x))
    
    x1 = tf.sqrt(tf.abs(x_init * (tf.cast(i, dtype=tf.float64) + 1.)))
    x1 = tf.cos(x1 - 0.3)
    x1 = tf.pow(x1, tf.cast(i + 1, tf.float64))
    x1 = tf.sinh(x1 + 0.4)
    x1 = x1 ** 2
    x1 = x1 / tf.reduce_mean(x1)
    x1 = tf.abs(x1)
    x1 = tf.reduce_mean(tf.math.log(x1))  # mean of the log, as in calc_np
    
    x2 = tf.sqrt(tf.abs(x_init * (tf.cast(i, dtype=tf.float64) + 1.)))
    x2 = tf.cos(x2 - 0.3)
    x2 = tf.pow(x2, tf.cast(i + 1, tf.float64))
    x2 = tf.sinh(x2 + 0.4)
    x2 = x2 ** 2
    x2 = x2 / tf.reduce_mean(x2)
    x2 = tf.abs(x2)
    x2 = tf.reduce_mean(tf.math.log(x2))  # mean of the log, as in calc_np
    
    return x + x1 + x2

calc_tf_func = tf.function(calc_tf, autograph=False)

In [87]:
%%timeit -n1 -r1  # compile time, just for curiosity
calc_tf_func(data_tf)

283 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [88]:
%%timeit -n1 -r1  # compile time, just for curiosity
calc_np_numba(data_np)

2.24 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [89]:
%timeit calc_np(data_np)  # not compiled

2.09 s ± 26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [90]:
%timeit calc_tf(data_tf)  # not compiled

1.53 s ± 137 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [91]:
%%timeit -n1 -r7
calc_np_numba(data_np)

410 ms ± 71.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [92]:
%%timeit -n1 -r7
calc_tf_func(data_tf)

260 ms ± 15.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


We can now play around with these numbers. Depending on the size (we can go up to 10 million) and the parallelizability of the problem, the numbers differ.

In general:
- Numpy is faster for small arrays
- TensorFlow is faster for larger arrays and well parallelizable computations. Due to the larger dispatching overhead in eager mode, it is significantly slower for very small (1-10) sample sizes.

=> there is no free lunch

Note: this did not run on a GPU; TensorFlow would use one automatically if available.

In [93]:
def calc_tf2(x, n):
    sum_init = tf.zeros_like(x)
    for i in range(1, n + 1):
        x = tf.sqrt(tf.abs(x * (tf.cast(i, dtype=tf.float64) + 1.)))
        x = tf.cos(x - 0.3)
        x = tf.pow(x, tf.cast(i + 1, tf.float64))
        x = tf.sinh(x + 0.4)
        x = x ** 2
        x = x / tf.reduce_mean(x, axis=None)
        x = tf.abs(x)
        x = x - tf.reduce_mean(tf.math.log(x, name="Jonas_log"), name="Jonas_mean")  # name for ops, see later ;)
        sum_init += x
    return sum_init

calc_tf_func2 = tf.function(calc_tf2, autograph=False)

@numba.njit(parallel=True)  # njit is equal to jit(nopython=True), meaning "compile everything or raise error"
def calc_numba2(x, n):
    sum_init = np.zeros_like(x)
    for i in range(1, n + 1):
        x = np.sqrt(np.abs(x * (i + 1.)))
        x = np.cos(x - 0.3)
        x = np.power(x, i + 1)
        x = np.sinh(x + 0.4)
        x = x ** 2
        x = x / np.mean(x)
        x = np.abs(x)
        x = x - np.mean(np.log(x))
        sum_init += x
    return sum_init

In [94]:
%%timeit -n1 -r1  #compile
calc_numba2(rnd1_big.numpy(), 1)

788 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [95]:
calc_numba2(rnd1_big.numpy(), 1)

array([0.62229937, 0.53256975, 0.49222939, ..., 1.66407892, 0.64902505,
       0.46597241])

In [96]:
%%timeit -n1 -r1  #compile
calc_tf_func2(rnd1_big, 1)

21.5 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [97]:
calc_tf_func2(rnd1_big, 1)

<tf.Tensor: shape=(100000,), dtype=float64, numpy=
array([0.62229937, 0.53256975, 0.49222939, ..., 1.66407892, 0.64902505,
       0.46597241])>

In [98]:
%%timeit
calc_numba2(rnd1_big.numpy(), 1)

1.94 ms ± 76.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [99]:
%%timeit
calc_tf_func2(rnd1_big, 1)

1.45 ms ± 136 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [100]:
calc_tf_func2(rnd1_big, 10)

<tf.Tensor: shape=(100000,), dtype=float64, numpy=
array([16.28470874, 17.46429155, 21.23315363, ..., 16.17488668,
       15.97030245, 18.62912863])>

In [101]:
calc_numba2(rnd1_big.numpy(), 10)

array([16.28470877, 17.46429158, 21.23315366, ..., 16.1748867 ,
       15.97030248, 18.62912866])

In [102]:
%%timeit
calc_numba2(rnd1_big.numpy(), 10)

18.6 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [103]:
%%timeit
calc_tf_func2(rnd1_big, 10)

37 ms ± 906 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Exercise

Add JAX code to it, and run it once with and once without compilation.

- How does it compare?
- Rerun the TensorFlow code but use `jit_compile=True` and compare

## Control flow

JAX & friends are independent of the Python control flow; they have their own functions for it, mainly:
- `while_loop()`: a while loop taking a body and a condition function
- `cond`: if-like
- `case` and `switch_case` (TF) or `switch` (JAX): if/elif statements
- `where` (which is inherently vectorized)

JAX control flow is [documented here](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators)
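As a rough illustration of the JAX side of this list, the following sketch combines `jax.lax.cond`, `jax.lax.switch` and `jnp.where` in one jitted function. The function `control_flow_demo` and its branches are made up for this example.

```python
import jax
import jax.numpy as jnp

@jax.jit
def control_flow_demo(x, index):
    # cond: if-like, the predicate is evaluated at run time
    halved_or_doubled = jax.lax.cond(
        x > 1.0,
        lambda v: v / 2.0,  # used if the predicate is True
        lambda v: v * 2.0,  # used otherwise
        x,
    )
    # switch: pick one of several branches by integer index (if/elif-like)
    branches = [lambda v: v + 1.0, lambda v: v - 1.0, lambda v: v * 10.0]
    switched = jax.lax.switch(index, branches, x)
    # where: element-wise selection, vectorized by construction
    arr = jnp.arange(5.0)
    clipped = jnp.where(arr > x, x, arr)
    return halved_or_doubled, switched, clipped

a, b, c = control_flow_demo(3.0, 2)
print(a, b, c)
```

Since `x` and `index` are traced values inside `jax.jit`, a plain Python `if x > 1.0:` would not work here; the `jax.lax` functions take its place.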

In [104]:
def true_fn():
    return tnp.array(1.)

def false_fn():
    return tnp.array(0.)

var1 = tnp.array(111.)
var2 = tnp.array(42.)
value = tf.cond(var1 > var2, true_fn=true_fn, false_fn=false_fn)

In [105]:
value

<tf.Tensor: shape=(), dtype=float64, numpy=1.0>

### While loops

We can create while loops to express a repetitive task.

In [106]:
def cond(x, y):
    return x > y

def body(x, y):
    return x / 2, y + 1

x, y = tf.while_loop(cond=cond,
                     body=body,
                     loop_vars=[100., 1.])

In [107]:
x, y

(<tf.Tensor: shape=(), dtype=float32, numpy=3.125>,
 <tf.Tensor: shape=(), dtype=float32, numpy=6.0>)
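For comparison, a minimal sketch of the same loop in JAX, assuming `jax.lax.while_loop` with the loop state packed into a single tuple. The names `cond_fun`, `body_fun`, `x_final` and `y_final` are chosen for this example only.

```python
import jax
import jax.numpy as jnp

def cond_fun(state):
    # keep looping while x is larger than y
    x, y = state
    return x > y

def body_fun(state):
    # one iteration: halve x, increment y
    x, y = state
    return x / 2, y + 1

init_state = (jnp.array(100.), jnp.array(1.))
x_final, y_final = jax.lax.while_loop(cond_fun, body_fun, init_state)
print(x_final, y_final)
```

Unlike `tf.while_loop`, the JAX version passes the whole state as one argument, and the body has to return a state with the same structure and dtypes as its input.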

### Map a function

We can also map a function on each element. While this is not very efficient, it allows for high flexibility.

A map is like a (parallel) for loop. More powerful (especially in JAX) is a vectorized function (similar to `np.vectorize`); in [JAX this is `jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap).
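To make the idea of vectorizing a function more concrete, here is a small sketch in which a function written for a single row is applied to a whole batch with `jax.vmap`. The function `weighted_sum` and the arrays `batch` and `weights` are invented for this illustration.

```python
import jax
import jax.numpy as jnp

def weighted_sum(row, weights):
    # written for a single 1-D row
    return jnp.dot(row, weights)

batch = jnp.arange(12.0).reshape(4, 3)  # 4 rows of length 3
weights = jnp.array([1.0, 0.5, 0.25])

# map over axis 0 of `batch`, share `weights` across all rows
batched_weighted_sum = jax.vmap(weighted_sum, in_axes=(0, None))
result = batched_weighted_sum(batch, weights)
print(result.shape)
```

The `in_axes` argument tells `jax.vmap` which axis of each input to map over; `None` means the argument is passed unchanged to every call.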

In [139]:
jax.lax.map(jnp.sin, jrnd1_big)

Array([0.32697651, 0.91132368, 0.77641278, ..., 0.98328917, 0.2543407 ,
       0.9074935 ], dtype=float64)

In [148]:
tf.map_fn(tf.math.sin, rnd1_big)  # This is basically a for-loop!

<tf.Tensor: shape=(100000,), dtype=float64, numpy=
array([0.32697651, 0.91132368, 0.77641278, ..., 0.98328917, 0.2543407 ,
       0.9074935 ])>

In [131]:
%%timeit -n1 -r1
tf.map_fn(tf.math.sin, rnd1_big)

29.8 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [137]:
%%timeit -n1 -r1
jax.lax.map(jnp.sin, jrnd1_big)  # jax always compiles, can use batch size!

32.5 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [146]:
%%timeit -n1 -r1
tf.vectorized_map(tnp.sin, rnd1_big)  # can greatly speedup things sometimes

16.8 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [143]:
vsin = jax.vmap(jnp.sin)

In [145]:
%%timeit -n1 -r1
vsin(jrnd1_big)

1.4 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [147]:
@tf.function
def do_map(func, tensor):
    return tf.map_fn(func, tensor)

do_map(tf.math.sin, rnd1_big)

<tf.Tensor: shape=(100000,), dtype=float64, numpy=
array([0.32697651, 0.91132368, 0.77641278, ..., 0.98328917, 0.2543407 ,
       0.9074935 ])>

In [111]:
@tf.function
def do_map_vec(func, tensor):
    return tf.vectorized_map(func, tensor)

do_map_vec(tf.math.sin, rnd1_big)

<tf.Tensor: shape=(100000,), dtype=float64, numpy=
array([0.32697651, 0.91132368, 0.77641278, ..., 0.98328917, 0.2543407 ,
       0.9074935 ])>

In [149]:
from functools import partial

@partial(jax.jit, static_argnames=['func'])
def do_map_vec_jax(func, tensor):
    vec_func = jax.vmap(func)
    return vec_func(tensor)

do_map_vec_jax(jnp.sin, jrnd1_big)

Array([0.32697651, 0.91132368, 0.77641278, ..., 0.98328917, 0.2543407 ,
       0.9074935 ], dtype=float64)

In [112]:
%%timeit
do_map(tnp.sin, rnd1_big)

793 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [113]:
%%timeit
do_map_vec(tnp.sin, rnd1_big)

557 μs ± 14.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [150]:
%%timeit
do_map_vec_jax(jnp.sin, jrnd1_big)

510 μs ± 18.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [151]:
%%timeit
tnp.sin(rnd1_big)

214 μs ± 15.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [152]:
%%timeit
jnp.sin(jrnd1_big)

494 μs ± 26.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


As we can see, the generic mapping is surely not optimal. However, it "always" works. `vectorized_map`, on the other hand, gives a huge speedup and performs nearly as well as the native function! While this works nicely in this case, its applications are limited and depend heavily on the use case; more complicated examples can easily result in a longer runtime and huge memory consumption. Caution is therefore advised when using this function.

## Gradients

TensorFlow (and PyTorch) allow us to compute gradients automatically using a gradient tape.

In [115]:
x = tnp.array(2.)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = x ** 3
y

<tf.Tensor: shape=(), dtype=float64, numpy=8.0>

In [116]:
grad = tape.gradient(y, x)
grad

<tf.Tensor: shape=(), dtype=float64, numpy=12.0>
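The explicit `tape.watch` call is needed because the tape only tracks trainable `tf.Variable` objects automatically. A small sketch of the difference, with `x_var` and `x_const` as illustrative names:

```python
import tensorflow as tf

x_var = tf.Variable(2.0, dtype=tf.float64)    # trainable variables are watched automatically
x_const = tf.constant(2.0, dtype=tf.float64)  # plain tensors are not watched

with tf.GradientTape() as tape:
    y = x_var ** 3 + x_const ** 2

# gradient w.r.t. the variable is available, w.r.t. the unwatched tensor it is None
grad_var, grad_const = tape.gradient(y, [x_var, x_const])
print(grad_var, grad_const)
```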

JAX has a slightly different approach: it creates a gradient function.

In [117]:
grad = jax.grad(lambda x: x ** 3)

In [118]:
grad(jnp.array(2.))

Array(12., dtype=float64, weak_type=True)
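Since the gradient is itself a function, it can be configured like one. The following sketch uses `jax.value_and_grad` with `argnums` to get the function value and the gradients with respect to two arguments in one call. The function `loss` is a made-up example.

```python
import jax

def loss(w, b):
    # toy function of two scalar parameters
    return (w * 3.0 + b) ** 2

# differentiate w.r.t. both positional arguments and also return the value
value_and_grads = jax.value_and_grad(loss, argnums=(0, 1))
value, (dw, db) = value_and_grads(1.0, 0.5)
print(value, dw, db)
```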

## Exercise

Try to get higher derivatives. Have a look at [the JAX guide on derivatives](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#gradients)

- Can you get the second derivative?

This allows us to do many things with gradients, e.g. solve differential equations.

## Behind the scenes: computational graph

We talked about the computational graph back and forth, but _where is it_?

The graph (in TensorFlow) can be retrieved from a function that was already traced.

In [119]:
concrete_func = calc_tf_func2.get_concrete_function(rnd1, 2)
concrete_func

<ConcreteFunction (x: TensorSpec(shape=(10,), dtype=tf.float64, name=None), n: Literal[2]) -> TensorSpec(shape=(10,), dtype=tf.float64, name=None) at 0x7CD7DB5C00D0>

In [120]:
graph = concrete_func.graph
graph

<tensorflow.python.framework.func_graph.FuncGraph at 0x7cd7b214ff40>

In [121]:
ops = graph.get_operations()
ops

[<tf.Operation 'x' type=Placeholder>,
 <tf.Operation 'zeros_like' type=Const>,
 <tf.Operation 'Cast/x' type=Const>,
 <tf.Operation 'Cast' type=Cast>,
 <tf.Operation 'add/y' type=Const>,
 <tf.Operation 'add' type=AddV2>,
 <tf.Operation 'mul' type=Mul>,
 <tf.Operation 'Abs' type=Abs>,
 <tf.Operation 'Sqrt' type=Sqrt>,
 <tf.Operation 'sub/y' type=Const>,
 <tf.Operation 'sub' type=Sub>,
 <tf.Operation 'Cos' type=Cos>,
 <tf.Operation 'Cast_1/x' type=Const>,
 <tf.Operation 'Cast_1' type=Cast>,
 <tf.Operation 'Pow' type=Pow>,
 <tf.Operation 'add_1/y' type=Const>,
 <tf.Operation 'add_1' type=AddV2>,
 <tf.Operation 'Sinh' type=Sinh>,
 <tf.Operation 'pow_1/y' type=Const>,
 <tf.Operation 'pow_1' type=Pow>,
 <tf.Operation 'Const' type=Const>,
 <tf.Operation 'Mean' type=Mean>,
 <tf.Operation 'truediv' type=RealDiv>,
 <tf.Operation 'Abs_1' type=Abs>,
 <tf.Operation 'Jonas_log' type=Log>,
 <tf.Operation 'Const_1' type=Const>,
 <tf.Operation 'Jonas_mean' type=Mean>,
 <tf.Operation 'sub_1' type=Sub>,
 

In [122]:
log_op = ops[-6]
log_op

<tf.Operation 'Jonas_log_1' type=Log>

In [123]:
log_op.outputs

[<tf.Tensor 'Jonas_log_1:0' shape=(10,) dtype=float64>]

In [124]:
op_inputs_mean = ops[-4].inputs
op_inputs_mean

(<tf.Tensor 'Jonas_log_1:0' shape=(10,) dtype=float64>,
 <tf.Tensor 'Const_3:0' shape=(1,) dtype=int32>)

In [125]:
log_op.outputs[0] is op_inputs_mean[0]

True

The output of the log operation is the input to the mean operation! We can just walk along the graph here. TensorFlow graphs are no magic: they are simple objects that store their inputs, their outputs and their operation. That's it!
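Walking along the graph can also be done in a loop. The following sketch traces a tiny function and prints, for every operation, its type and the names of its input and output tensors. The function `small_func` and the op names `my_log` and `my_mean` are assumptions made for this example.

```python
import tensorflow as tf

def small_func(x):
    y = tf.math.log(x, name="my_log")
    return tf.reduce_mean(y, name="my_mean")

small_tf_func = tf.function(small_func, autograph=False)
# trace the function for a given input signature without running it
concrete = small_tf_func.get_concrete_function(
    tf.TensorSpec(shape=(10,), dtype=tf.float64)
)

op_types = []
for op in concrete.graph.get_operations():
    input_names = [inp.name for inp in op.inputs]
    output_names = [out.name for out in op.outputs]
    op_types.append(op.type)
    print(f"{op.name}: type={op.type}, inputs={input_names}, outputs={output_names}")
```

The output names of one operation reappear as input names of the next one, which is exactly the connection checked by hand above.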

### jaxpr: JAX Expressions

The graph "equivalent" in JAX is a jaxpr (JAX expression): a representation of a function together with its inputs.

In [154]:
# helper
def examine_jaxpr(closed_jaxpr):
  jaxpr = closed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)

In [155]:
def foo(x):
  return x + 1
    
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))

foo
=====
invars: [Var(id=137265412096960):int64[]]
outvars: [Var(id=137265412102336):int64[]]
constvars: []
equation: [Var(id=137265412096960):int64[], 1] add [Var(id=137265412102336):int64[]] {}

jaxpr: { lambda ; a:i64[]. let b:i64[] = add a 1 in (b,) }


In [156]:
def bar(w, b, x):
  return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))

bar
=====
invars: [Var(id=137265208912768):float64[5,10], Var(id=137265208912896):float64[5], Var(id=137265208914624):float64[10]]
outvars: [Var(id=137265208912192):float64[5], Var(id=137265208914624):float64[10]]
constvars: []
equation: [Var(id=137265208912768):float64[5,10], Var(id=137265208914624):float64[10]] dot_general [Var(id=137265208914816):float64[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float64')}
equation: [Var(id=137265208914816):float64[5], Var(id=137265208912896):float64[5]] add [Var(id=137265208916096):float64[5]] {}
equation: [1.0] broadcast_in_dim [Var(id=137265208915392):float64[5]] {'shape': (5,), 'broadcast_dimensions': ()}
equation: [Var(id=137265208916096):float64[5], Var(id=137265208915392):float64[5]] add [Var(id=137265208912192):float64[5]] {}

jaxpr: { lambda ; a:f64[5,10] b:f64[5] c:f64[10]. let
    d:f64