## Using DTensors with Keras

In [1]:
from datetime import datetime

now = datetime.now().strftime("%a, %d %b %Y %H:%M:%S")
print("Created date: Wed, 17 Apr 2024 04:35:29")
print(f"Modified date: {now}")

Created date: Wed, 17 Apr 2024 04:35:29
Modified date: Thu, 18 Apr 2024 04:31:09


### 1. Overview

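What is DTensor?

DTensor (`tf.experimental.dtensor`) has been part of TensorFlow since version 2.9.0. It is an extension to TensorFlow for synchronous distributed computing. You write a program against global tensors, and DTensor distributes their data across devices according to a `Mesh` (a logical grid of devices) and a `Layout` (how each tensor axis is sharded or replicated over the mesh dimensions).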
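As a minimal sketch of the idea, assume a CPU-only host where two virtual CPUs stand in for accelerators. The mesh dimension name `"x"` is arbitrary. A DTensor is a global tensor paired with a `Layout` over a `Mesh`, and `dtensor.unpack` exposes the per-device components. With a replicated layout, every device holds the full tensor:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Assumption: no accelerators, so split the physical CPU into 2 virtual CPUs.
# This must run before the TensorFlow runtime is initialized.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration() for _ in range(2)])

# A 1-D mesh over 2 devices; "x" is an arbitrary mesh dimension name.
mesh = dtensor.create_mesh([("x", 2)], devices=["CPU:0", "CPU:1"])

# Replicated layout for a rank-1 tensor: every device holds the full tensor.
layout = dtensor.Layout.replicated(mesh, rank=1)

# Pack one (identical) local component per device into one global DTensor.
local_components = [tf.constant([1.0, 2.0, 3.0]) for _ in range(2)]
replicated = dtensor.pack(local_components, layout)

print("Global shape:", replicated.shape)
print("Layout:", dtensor.fetch_layout(replicated))
for component in dtensor.unpack(replicated):
    print(component.device, component.numpy())
```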

### 2. Setup

In [2]:
# !pip install --quiet --upgrade tensorflow-datasets

In [3]:
import tensorflow as tf 
import tensorflow_datasets as tfds
from tensorflow.experimental import dtensor

print(f"tensorflow version: {tf.__version__}")
print(f"tensorflow_datasets version: {tfds.__version__}")



tensorflow version: 2.16.1
tensorflow_datasets version: 4.9.4


``` python
tf.config.list_physical_devices(device_type=None) -> list[PhysicalDevice]
```

Return a list of physical devices visible to the host runtime.

Physical devices are hardware devices present on the host machine. By default all discovered CPU and GPU devices are considered visible.

This API allows querying the physical hardware resources prior to runtime initialization, which gives an opportunity to call any additional configuration APIs. This is in contrast to `tf.config.list_logical_devices`, which triggers runtime initialization in order to list the configured devices.

The following example lists the number of visible GPUs on the host.
``` python
>>> physical_devices = tf.config.list_physical_devices('GPU')
>>> print("Num GPUs:", len(physical_devices))
Num GPUs: ...
```

However, the number of GPUs available to the runtime may change during runtime initialization due to marking certain devices as not visible or configuring multiple logical devices.
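You can observe the same effect on a CPU-only machine. This minimal sketch assumes nothing has initialized the runtime yet in the process. It splits the first physical CPU into three logical CPUs. On a typical host, TensorFlow still reports a single physical CPU, while the runtime exposes three logical ones:

```python
import tensorflow as tf

# Physical devices can be listed without initializing the runtime.
physical_cpus = tf.config.list_physical_devices("CPU")

# Split the first physical CPU into 3 logical CPUs (allowed only before init).
tf.config.set_logical_device_configuration(
    physical_cpus[0],
    [tf.config.LogicalDeviceConfiguration() for _ in range(3)])

# Listing logical devices initializes the runtime with that configuration.
logical_cpus = tf.config.list_logical_devices("CPU")

print("Physical CPUs:", len(physical_cpus))
print("Logical CPUs:", len(logical_cpus))
```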

In [4]:
# list_physical_devices
tf.config.list_physical_devices('CPU')

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]

``` python
tf.config.set_logical_device_configuration(device, logical_devices) -> None
```

Set the logical device configuration for a `tf.config.PhysicalDevice`.

A visible `tf.config.PhysicalDevice` will by default have a single `tf.config.LogicalDevice` associated with it once the runtime is initialized. Specifying a list of `tf.config.LogicalDeviceConfiguration` objects allows multiple devices to be created on the same `tf.config.PhysicalDevice`.

Logical device configurations can be modified by calling this function as long as the runtime is uninitialized. After the runtime is initialized, calling this function raises a `RuntimeError`.

The following example splits the CPU into 2 logical devices:

``` python
>>> physical_devices = tf.config.list_physical_devices('CPU')
>>> assert len(physical_devices) == 1, "No CPUs found"
>>> # Specify 2 virtual CPUs. Note currently memory limit is not supported.
>>> try:
...   tf.config.set_logical_device_configuration(
...     physical_devices[0],
...     [tf.config.LogicalDeviceConfiguration(),
...      tf.config.LogicalDeviceConfiguration()])
...   logical_devices = tf.config.list_logical_devices('CPU')
...   assert len(logical_devices) == 2
...
...   tf.config.set_logical_device_configuration(
...     physical_devices[0],
...     [tf.config.LogicalDeviceConfiguration(),
...      tf.config.LogicalDeviceConfiguration(),
...      tf.config.LogicalDeviceConfiguration(),
...      tf.config.LogicalDeviceConfiguration()])
... except:
...   # Cannot modify logical devices once initialized.
...   pass
```

The following example splits the GPU into 2 logical devices with 100 MB each:

``` python
>>> physical_devices = tf.config.list_physical_devices('GPU')
>>> try:
...   tf.config.set_logical_device_configuration(
...     physical_devices[0],
...     [tf.config.LogicalDeviceConfiguration(memory_limit=100),
...      tf.config.LogicalDeviceConfiguration(memory_limit=100)])
...
...   logical_devices = tf.config.list_logical_devices('GPU')
...   assert len(logical_devices) == len(physical_devices) + 1
...
...   tf.config.set_logical_device_configuration(
...     physical_devices[0],
...     [tf.config.LogicalDeviceConfiguration(memory_limit=10),
...      tf.config.LogicalDeviceConfiguration(memory_limit=10)])
... except:
...   # Invalid device or cannot modify logical devices once initialized.
...   pass
```

In [5]:
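def configure_virtual_cpus(num_cpus: int) -> None:
    """Split the first physical CPU into `num_cpus` logical (virtual) CPUs.

    Must be called before the TensorFlow runtime is initialized.
    """
    physical_devices = tf.config.list_physical_devices('CPU')

    # Split one physical device into `num_cpus` logical devices
    tf.config.set_logical_device_configuration(
        physical_devices[0],
        [tf.config.LogicalDeviceConfiguration()] * num_cpus)

# Configure 4 virtual CPUs
configure_virtual_cpus(4)

# Get the list of logical CPU devices (this initializes the runtime)
virtual_logical_devices = tf.config.list_logical_devices("CPU")

print(f"Num of virtual CPUs: {len(virtual_logical_devices)}")

# Device names for the DTensor mesh, one per virtual CPU
devices = [f'CPU:{i}' for i in range(len(virtual_logical_devices))]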


Num of virtual CPUs: 4
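`set_logical_device_configuration` raises a `RuntimeError` once the runtime is initialized. As a result, calling `configure_virtual_cpus` with a different `num_cpus` later in the same session fails. The sketch below is a re-run-safe variant that catches that error and reports whether the new configuration took effect. The name `try_configure_virtual_cpus` is hypothetical, not a TensorFlow API.

```python
import tensorflow as tf

def try_configure_virtual_cpus(num_cpus: int) -> bool:
    """Hypothetical helper: split the first physical CPU into `num_cpus`
    logical CPUs. Returns False if the runtime rejected the change."""
    physical_devices = tf.config.list_physical_devices("CPU")
    try:
        tf.config.set_logical_device_configuration(
            physical_devices[0],
            [tf.config.LogicalDeviceConfiguration() for _ in range(num_cpus)])
        return True
    except RuntimeError:
        # Logical devices cannot be modified after runtime initialization.
        return False

first = try_configure_virtual_cpus(4)   # before initialization: accepted
tf.config.list_logical_devices("CPU")   # initializes the runtime
second = try_configure_virtual_cpus(2)  # after initialization: rejected
print(first, second, len(tf.config.list_logical_devices("CPU")))
```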


### 3. Deterministic pseudo-random number generators


The `DTensor API` requires every running client to use the same random seed, so that weight initialization is deterministic across clients.
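One way to check this in a single process is to reseed, build the same layer twice, and compare the freshly initialized kernels. The sketch below uses `keras.utils.set_random_seed`. The helper `initial_kernel` and the layer sizes are illustrative, not part of any API. With the same seed, the two kernels should be identical; a different seed should give different values.

```python
import numpy as np
import keras

def initial_kernel(seed: int) -> np.ndarray:
    """Illustrative helper: seed all RNGs, then build a Dense layer and
    return its freshly initialized kernel as a NumPy array."""
    # Seeds Python's `random`, NumPy, and the backend (e.g. TensorFlow) RNGs.
    keras.utils.set_random_seed(seed)
    layer = keras.layers.Dense(4)
    layer.build((None, 3))  # kernel shape: (input_dim=3, units=4)
    return layer.get_weights()[0]

a = initial_kernel(1337)
b = initial_kernel(1337)
c = initial_kernel(42)
print(np.array_equal(a, b))  # same seed: identical initial weights
print(np.array_equal(a, c))  # different seed: different initial weights
```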

In [14]:
import keras
print(f"Keras version: {keras.__version__}")
print(f"TensorFlow version: {tf.__version__}")

# `tf.keras.backend.experimental.enable_tf_random_generator()` is not available
# in Keras 3 (used with TensorFlow 2.16.1 here); calling it raises:
# AttributeError: module 'keras.backend' has no attribute 'experimental'

# Make TensorFlow ops deterministic, then seed the Python, NumPy and TensorFlow RNGs
tf.config.experimental.enable_op_determinism()
tf.keras.utils.set_random_seed(1337)

Keras version: 3.2.1
TensorFlow version: 2.16.1


### 4. Creating a Data Parallel Mesh

`Data Parallel` training is a commonly used parallel training scheme; `tf.distribute.MirroredStrategy` is one example.

With `DTensor`, a `Data Parallel` training loop uses a `Mesh` that consists of a single 'batch' dimension, where each device runs a replica of the model that receives a shard from the global batch.

In [16]:
mesh = dtensor.create_mesh([("batch", 4)], devices=devices)
mesh

Mesh.from_string(|batch=4|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1,/job:localhost/replica:0/task:0/device:CPU:2,/job:localhost/replica:0/task:0/device:CPU:3)
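The following minimal, self-contained sketch shows how each device receives a shard of the global batch. It assumes the same 4-virtual-CPU setup as above and repeats it so the block runs on its own in a fresh process. It also assumes a toy batch of 8 examples with 2 features. Axis 0 is sharded over the `"batch"` mesh dimension with `dtensor.Layout.batch_sharded`, one component per device is packed with `dtensor.pack`, and the shards are inspected with `dtensor.unpack`:

```python
import numpy as np
import tensorflow as tf
from tensorflow.experimental import dtensor

# Assumption: a CPU-only host split into 4 virtual CPUs, as in the setup above.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration() for _ in range(4)])

mesh = dtensor.create_mesh([("batch", 4)], devices=[f"CPU:{i}" for i in range(4)])

# Shard axis 0 of a rank-2 tensor over the "batch" mesh dimension;
# axis 1 (features) stays unsharded.
layout = dtensor.Layout.batch_sharded(mesh, batch_dim="batch", rank=2)

# Toy global batch: 8 examples x 2 features.
global_batch = tf.reshape(tf.range(16, dtype=tf.float32), (8, 2))

# One local component per device, in mesh device order: 8 / 4 = 2 rows each.
components = tf.split(global_batch, num_or_size_splits=4, axis=0)
sharded_batch = dtensor.pack(components, layout)

print("Global shape:", sharded_batch.shape)
for shard in dtensor.unpack(sharded_batch):
    print(shard.device, shard.shape)
```

In a real training loop, the components would come from the input pipeline rather than from `tf.split` on a host tensor. The split here only makes the layout visible.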

### References

1. https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism