# Why Φ<sub>ML</sub> has Precision Management

Having control over the floating-point (FP) precision is essential for many scientific applications.
For example, some linear systems of equations are only solvable with FP64, even if the desired tolerance lies within FP32 territory.
To accommodate these requirements, Φ<sub>ML</sub> provides custom precision management tools that differ from those of the common machine learning libraries.
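
As a rough illustration of the linear-system claim, the following NumPy-only sketch solves an ill-conditioned system (an 8×8 Hilbert matrix, chosen here purely as an assumption for demonstration) with the data stored in FP32 and in FP64. The helper `max_error` is hypothetical, and NumPy may internally compute the solve in higher precision, so the FP32 case mainly shows the effect of rounding the inputs.

```python
import numpy as np

n = 8
# Hilbert matrix H[i, j] = 1 / (i + j + 1), a classic ill-conditioned matrix
idx = np.arange(n)
H = 1.0 / (idx[:, None] + idx[None, :] + 1.0)
x_true = np.ones(n)
b = H @ x_true  # right-hand side whose exact solution is all ones


def max_error(dtype):
    """Round H and b to `dtype`, solve, and return the largest absolute error."""
    x = np.linalg.solve(H.astype(dtype), b.astype(dtype))
    return float(np.max(np.abs(x.astype(np.float64) - x_true)))


err32 = max_error(np.float32)
err64 = max_error(np.float64)
print(f"FP32 error: {err32:.1e}, FP64 error: {err64:.1e}")
```

A tolerance of `1e-4` is well within what FP32 can represent, yet in this sketch only the FP64 version should reach it.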

In [2]:
%%capture
import numpy as np
import torch
import tensorflow as tf
import jax
from jax import numpy as jnp
from phiml import math

## Precision in ML libraries

First, let's look at the behavior of the backend libraries that Φ<sub>ML</sub> supports.

**Tensor creation**:
Consider creating a float tensor from primitive floats. Can you guess what the data type will be for `tensor(1.)` (or the analogous operations) in NumPy, PyTorch, TensorFlow and Jax?

In [3]:
print(f"NumPy: {np.array(1.).dtype}\nPyTorch: {torch.tensor(1.).dtype}\nTensorFlow: {tf.constant(1.).dtype}\nJax: {jnp.array(1.).dtype}")

NumPy: float64
PyTorch: torch.float32
TensorFlow: <dtype: 'float32'>
Jax: float32


If you guessed `float64` for NumPy, `float32` for PyTorch and TensorFlow, and *depends on the configuration* for Jax, you are correct!
Yes, Jax disables FP64 by default! Let's repeat that with FP64 enabled.

In [4]:
jax.config.update("jax_enable_x64", True)
print(f"Jax: {jnp.array(1.).dtype}")

Jax: float64


Now, Jax behaves like NumPy! Or does it...?

**Combining different precisions**:
What do you think will happen in each of the base libraries if we sum a FP64 and FP32 tensor?
Let's try it!

In [5]:
(np.array(1., dtype=np.float32) + np.array(1., dtype=np.float64)).dtype  # NumPy

dtype('float64')

In [6]:
(torch.tensor(1., dtype=torch.float32) + torch.tensor(1., dtype=torch.float64)).dtype  # PyTorch

torch.float64

NumPy and PyTorch automatically upgrade to the highest precision.
However, unlike NumPy, PyTorch does not upgrade its `dtype` when adding a primitive `float`.

In [7]:
(np.array(1., dtype=np.float32) + 1.).dtype  # NumPy

dtype('float64')

In [8]:
(torch.tensor(1., dtype=torch.float32) + 1.).dtype  # PyTorch

torch.float32

Let's look at TensorFlow and Jax next.

In [9]:
try:
    (tf.constant(1., dtype=tf.float32) + tf.constant(1., dtype=tf.float64)).dtype  # TensorFlow
except tf.errors.InvalidArgumentError as err:
    print(err)

cannot compute AddV2 as input #1(zero-based) was expected to be a float tensor but is a double tensor [Op:AddV2]


TensorFlow outright refuses to mix different precisions and requires manual casting.
The exception is a primitive Python `float`, even though it is also FP64. In that case, TensorFlow keeps the `dtype` of the tensor.

In [11]:
(tf.constant(1., dtype=tf.float32) + 1.).dtype  # TensorFlow

tf.float32

At first glance, Jax seems to upgrade the different precisions like NumPy.

In [12]:
(jnp.array(1., dtype=jnp.float32) + jnp.array(1., dtype=jnp.float64)).dtype  # Jax

dtype('float64')

Let's modify the expression a bit.

In [13]:
t64 = jnp.array(1.)
print(t64.dtype)
(jnp.array(1., dtype=jnp.float32) + t64).dtype

float64


dtype('float32')

Here we also add a `float64` to a `float32` tensor, but the result is now `float32`.
Jax remembers that we did not explicitly specify the type of the `t64` tensor and treats it as weakly typed, so it does not force an upgrade.

Also, Jax does not upgrade the precision when adding a `float`.

In [15]:
(jnp.array(1., dtype=jnp.float32) + 1.).dtype  # Jax

dtype('float32')

**Converting integer tensors**:
Let's look at the behavior when combining a `float32` and an `int` tensor in the different libraries. Can you guess what the result type will be?

In [16]:
(np.array(1., dtype=np.float32) + np.array(1)).dtype  # NumPy

dtype('float64')

In [17]:
(torch.tensor(1., dtype=torch.float32) + torch.tensor(1)).dtype  # PyTorch

torch.float32

In [18]:
try:
    (tf.constant(1., dtype=tf.float32) + tf.constant(1)).dtype  # TensorFlow
except tf.errors.InvalidArgumentError as err:
    print(err)

cannot compute AddV2 as input #1(zero-based) was expected to be a float tensor but is a int32 tensor [Op:AddV2]


In [19]:
(jnp.array(1., dtype=jnp.float32) + jnp.array(1)).dtype  # Jax

dtype('float32')

We see that NumPy upgrades to 64 bit while PyTorch and Jax keep 32. Like before, TensorFlow refuses to combine different types.
However, TensorFlow can perform the operation when a primitive `int` is added instead.

In [20]:
(tf.constant(1., dtype=tf.float32) + 1).dtype  # TensorFlow

tf.float32

### Observations

We have seen that there is no consistent type handling between the four libraries. In fact, no two libraries behave the same.

* NumPy defaults to `float64` and upgrades when combining tensors and primitives, including `int`.
* PyTorch defaults to `float32` and upgrades only for float tensors, not primitives or integer tensors.
* Jax defaults to the precision specified by its configuration and uses involved upgrading rules that take into account whether the initial precision was set or inferred.
* TensorFlow defaults to `float32` but requires all tensors to have the same precision, except for Python primitives.

| Library    | `f32+f64` | `f32` + primitive `f64` | `f32+i32` |
|------------|-----------|-------------------------|-----------|
| NumPy      | `f64`     | `f64`                   | `f64`     |
| PyTorch    | `f64`     | `f32`                   | `f32`     |
| TensorFlow | Error     | `f32`                   | Error     |
| Jax        | Depends   | `f32`                   | `f32`     |

These inconsistencies indicate that there is not one obvious correct way to handle precision with the data type system these libraries employ, i.e. where the output `dtype` is determined solely by the input types.
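
For NumPy, the tensor-tensor columns of the table can be checked without creating any data by querying `np.result_type`. The sketch below covers only NumPy's dtype-to-dtype rules, since the handling of Python primitives differs between NumPy versions. It also shows a subtlety the table does not capture: the result of mixing floats and integers depends on the integer width.

```python
import numpy as np

# Promotion between dtypes only; no array values are involved.
print(np.result_type(np.float32, np.float64))  # float64
print(np.result_type(np.float32, np.int32))    # float64
print(np.result_type(np.float32, np.int16))    # float32, since int16 fits into float32
```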

## Precision Management in Φ<sub>ML</sub>

In Φ<sub>ML</sub> the operation / output precision is independent of the inputs. Instead, it can be set globally or by context.
The default precision is FP32.

**Tensor creation:** Let's create a tensor like above. Can you guess the resulting `dtype`?

In [48]:
math.tensor(1.).dtype

float32

Since we have not changed the precision, Φ<sub>ML</sub> creates an FP32 tensor.

**Combining different precisions**:
Can you guess what will happen if we add a `float32` and `float64` tensor?

In [49]:
(math.ones(dtype=(float, 32)) + math.ones(dtype=(float, 64))).dtype

float32

The precision is still set to `float32` so that's what we get.
Of course this also applies to adding Python primitives or `int` tensors.

In [50]:
(math.ones(dtype=(float, 32)) + math.ones(dtype=int)).dtype

float32

If we want to use FP64, we can either set the global precision or execute the code within a precision context.
The following line sets the global precision to 64 bit.

In [47]:
math.set_global_precision(64)

Executing the above cells now yields `float64` in all cases.
Likewise, the precision can be set to 16 bit. In that case we get `float16` even when adding a `float32` and `float64` tensor.

As you can see, this system is much simpler and more predictable than the alternatives.
It also makes writing code much easier. Upgrading a script that was written for FP32 to FP64 is as simple as setting the global precision, and executing parts of your code with a different precision is as simple as embedding it into a precision block (see example below).
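
To make the idea of input-independent precision more concrete, here is a toy NumPy-only sketch of a precision setting with context overrides. It is not how Φ<sub>ML</sub> is implemented; the names `active_precision`, `toy_precision` and `toy_add` are hypothetical and exist only for this illustration.

```python
import contextlib
import numpy as np

# Stack of active precisions; the bottom entry plays the role of the global default.
_PRECISION_STACK = [32]
_FLOAT_TYPES = {16: np.float16, 32: np.float32, 64: np.float64}


def active_precision() -> int:
    """Return the precision (in bits) that operations should currently use."""
    return _PRECISION_STACK[-1]


@contextlib.contextmanager
def toy_precision(bits: int):
    """Temporarily override the precision inside a `with` block."""
    _PRECISION_STACK.append(bits)
    try:
        yield
    finally:
        _PRECISION_STACK.pop()


def toy_add(a, b):
    """Add two arrays; the output dtype follows the active precision, not the inputs."""
    dtype = _FLOAT_TYPES[active_precision()]
    return np.add(np.asarray(a, dtype=dtype), np.asarray(b, dtype=dtype))


a = np.ones(3, dtype=np.float32)
b = np.ones(3, dtype=np.float64)
print(toy_add(a, b).dtype)      # float32
with toy_precision(64):
    print(toy_add(a, b).dtype)  # float64
print(toy_add(a, b).dtype)      # float32
```

The point of the design is that the caller, not the operands, decides the working precision, so a stray `float64` input cannot silently upgrade later results.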

## An Example of Mixing Precisions

Let's look at a simple application where we want to run operations with both FP32 and FP64, specifically iterating the map `35 (1-cos(x))^2`. The operation `1-cos` is much more sensitive to rounding errors than multiplication, so we wish to compute it using FP64.
The expected values after 5 iterations are: 0.2659 (FP64), 0.2663 (FP32), 0.2657 (mixed).
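
The sensitivity of `1-cos` comes from subtracting two nearly equal numbers. The following NumPy sketch exaggerates the effect by using a small input, `x = 1e-4`, which is an assumption made for illustration and not a value from the iteration above. The reference uses the algebraically equivalent form `2 sin(x/2)^2`, which avoids the subtraction; `rel_error` is a hypothetical helper.

```python
import numpy as np

x = 1e-4
# Algebraically equal to 1 - cos(x), but without the cancelling subtraction
reference = 2.0 * np.sin(x / 2.0) ** 2


def rel_error(dtype):
    """Relative error of 1 - cos(x) when evaluated in the given precision."""
    value = dtype(1) - np.cos(dtype(x))
    return abs(float(value) - reference) / reference


err32 = rel_error(np.float32)
err64 = rel_error(np.float64)
print(f"relative error FP32: {err32:.1e}, FP64: {err64:.1e}")
```

In FP32, `cos(x)` is so close to 1 that the difference retains essentially no correct digits, while FP64 still resolves it well. A multiplication, by contrast, only adds a relative error on the order of the machine epsilon.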

Here's the Φ<sub>ML</sub> code. We use a `precision` context to execute the inner part with FP64.

In [74]:
math.set_global_precision(32)  # reset precision to 32 bit
x = math.tensor(.5)
for i in range(5):
    with math.precision(64):
        x = 1 - math.cos(x)
    x = x ** 2 * 35
x

0.265725

Next, let's implement this using PyTorch. Here we need to manually convert `x` between FP32 and FP64.

In [73]:
x = torch.tensor(.5)
for i in range(5):
    x = x.double()
    x = 1 - torch.cos(x)
    x = x.float()
    x = x ** 2 * 35
x

tensor(0.2657)

These conversions seem relatively tame here, but imagine we had a bunch of variables to keep track of!
Making sure they all have the correct precision can be a time sink, especially since a single variable with an unnecessarily high precision can upgrade all following intermediate results.
The danger of this going unnoticed is why TensorFlow and Jax have taken the extreme measures of banning operations with mixed inputs and disabling FP64 by default, respectively.

## Further Reading

[Data types in Φ<sub>ML</sub>](Data_Types.html)

[🌐 **Φ<sub>ML</sub>**](https://github.com/tum-pbs/PhiML)
&nbsp; • &nbsp; [📖 **Documentation**](https://tum-pbs.github.io/PhiML/)
&nbsp; • &nbsp; [🔗 **API**](https://tum-pbs.github.io/PhiML/phiml)
&nbsp; • &nbsp; [**▶ Videos**]()
&nbsp; • &nbsp; [<img src="images/colab_logo_small.png" height=4>](https://colab.research.google.com/github/tum-pbs/PhiML/blob/main/docs/Examples.ipynb) [**Examples**](https://tum-pbs.github.io/PhiML/Examples.html)