## Using DTensors with Keras

In [None]:
from datetime import datetime

now = datetime.now().strftime("%a, %d %b %Y %H:%M:%S")
print("Created date: Wed, 17 Apr 2024 04:35:29")
print(f"Modified date: {now}")

### 1. Overview

What is DTensor?

DTensor (`tf.experimental.dtensor`) is part of TensorFlow 2.9.0 and later. It is an extension to TensorFlow for synchronous distributed computing.

### 2. Setup

In [None]:
!pip install --quiet --upgrade tensorflow-datasets

In [None]:
import tensorflow as tf 
import tensorflow_datasets as tfds
from tensorflow.experimental import dtensor

print(f"tensorflow version: {tf.__version__}")
print(f"tensorflow_datasets version: {tfds.__version__}")

``` python
(function) def list_physical_devices(device_type: Any | None = None) -> list[PhysicalDevice]

```

Return a list of physical devices visible to the host runtime.

Physical devices are hardware devices present on the host machine. By default all discovered CPU and GPU devices are considered visible.

This API allows querying the physical hardware resources prior to runtime initialization, which gives you an opportunity to call any additional configuration APIs. This is in contrast to `tf.config.list_logical_devices`, which triggers runtime initialization in order to list the configured devices.

The following example lists the number of visible GPUs on the host.
``` python
>>> physical_devices = tf.config.list_physical_devices('GPU')
>>> print("Num GPUs:", len(physical_devices))
Num GPUs: ...
``` 

However, the number of GPUs available to the runtime may change during runtime initialization due to marking certain devices as not visible or configuring multiple logical devices.

In [None]:
# list_physical_devices
tf.config.list_physical_devices('CPU')

``` python 
(function) def set_logical_device_configuration(
    device: Any,
    logical_devices: Any
) -> None
```

Set the logical device configuration for a `tf.config.PhysicalDevice`.

A visible `tf.config.PhysicalDevice` will by default have a single `tf.config.LogicalDevice` associated with it once the runtime is initialized. Specifying a list of `tf.config.LogicalDeviceConfiguration` objects allows multiple devices to be created on the same `tf.config.PhysicalDevice`.

Logical device configurations can be modified by calling this function as long as the runtime is uninitialized. After the runtime is initialized, calling this function raises a `RuntimeError`.

The following example splits the CPU into 2 logical devices:

``` python
>>> physical_devices = tf.config.list_physical_devices('CPU')
>>> assert len(physical_devices) == 1, "No CPUs found"
>>> # Specify 2 virtual CPUs. Note currently memory limit is not supported.
>>> try:
...   tf.config.set_logical_device_configuration(
...     physical_devices[0],
...     [tf.config.LogicalDeviceConfiguration(),
...      tf.config.LogicalDeviceConfiguration()])
...   logical_devices = tf.config.list_logical_devices('CPU')
...   assert len(logical_devices) == 2
...
...   tf.config.set_logical_device_configuration(
...     physical_devices[0],
...     [tf.config.LogicalDeviceConfiguration(),
...      tf.config.LogicalDeviceConfiguration(),
...      tf.config.LogicalDeviceConfiguration(),
...      tf.config.LogicalDeviceConfiguration()])
... except:
...   # Cannot modify logical devices once initialized.
...   pass
```

The following example splits the GPU into 2 logical devices with 100 MB each:

``` python
>>> physical_devices = tf.config.list_physical_devices('GPU')
>>> try:
...   tf.config.set_logical_device_configuration(
...     physical_devices[0],
...     [tf.config.LogicalDeviceConfiguration(memory_limit=100),
...      tf.config.LogicalDeviceConfiguration(memory_limit=100)])
...
...   logical_devices = tf.config.list_logical_devices('GPU')
...   assert len(logical_devices) == len(physical_devices) + 1
...
...   tf.config.set_logical_device_configuration(
...     physical_devices[0],
...     [tf.config.LogicalDeviceConfiguration(memory_limit=10),
...      tf.config.LogicalDeviceConfiguration(memory_limit=10)])
... except:
...   # Invalid device or cannot modify logical devices once initialized.
...   pass
```
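
The rule described above (logical devices can only be configured before the runtime is initialized) can be checked without initializing the runtime. The sketch below is not part of the original notebook and assumes a fresh Python process. It reads the configuration back with `tf.config.get_logical_device_configuration`, which, as far as I can tell, only consults the stored configuration and does not initialize the runtime the way `tf.config.list_logical_devices` does.

``` python
import tensorflow as tf

# first (and usually only) physical CPU device
cpu = tf.config.list_physical_devices("CPU")[0]

try:
    # request 2 logical devices on the physical CPU
    tf.config.set_logical_device_configuration(
        cpu, [tf.config.LogicalDeviceConfiguration() for _ in range(2)])
except RuntimeError as error:
    # raised when the runtime was already initialized in this process
    print(f"Configuration rejected: {error}")

# returns the configured list, or None if nothing was configured
configs = tf.config.get_logical_device_configuration(cpu)
print(f"Configured logical CPUs: {None if configs is None else len(configs)}")
```

Catching `RuntimeError` specifically, rather than using a bare `except`, keeps unrelated failures such as an invalid device visible.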

In [None]:
# check the number of physical CPUs
physical_devices = tf.config.list_physical_devices("CPU")
assert len(physical_devices) == 1, "Expected exactly one physical CPU"

# specify 2 virtual CPUs
try:
    tf.config.set_logical_device_configuration(
        physical_devices[0],
        [
            tf.config.LogicalDeviceConfiguration(),
            tf.config.LogicalDeviceConfiguration()
        ])

    # listing logical devices initializes the runtime
    logical_devices = tf.config.list_logical_devices("CPU")
    assert len(logical_devices) == 2

    # this second call is expected to fail: the runtime is already initialized
    tf.config.set_logical_device_configuration(
        physical_devices[0],
        [
            tf.config.LogicalDeviceConfiguration(),
            tf.config.LogicalDeviceConfiguration(),
            tf.config.LogicalDeviceConfiguration(),
            tf.config.LogicalDeviceConfiguration()
        ])

except Exception as error:
    # cannot modify logical devices once initialized
    print(f"Caught this error: {error}")

# query again outside the try block so the name is always defined
logical_devices = tf.config.list_logical_devices("CPU")
print(f"Number of logical CPUs: {len(logical_devices)}")

In [None]:
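def configure_virtual_cpus(num_cpus: int):
    """Split the first physical CPU into `num_cpus` logical devices."""
    physical_devices = tf.config.list_physical_devices('CPU')

    # must run before the runtime is initialized,
    # otherwise TensorFlow raises a RuntimeError
    tf.config.set_logical_device_configuration(
        physical_devices[0],
        [tf.config.LogicalDeviceConfiguration() for _ in range(num_cpus)])

# configure virtual CPUs; restart the kernel first if an earlier cell
# already listed logical devices, because that initializes the runtime
try:
    configure_virtual_cpus(4)
except RuntimeError as error:
    print(f"Could not reconfigure virtual CPUs: {error}")

# get a list of logical devices
virtual_logical_devices = tf.config.list_logical_devices("CPU")

print(f"Num of virtual CPUs: {len(virtual_logical_devices)}")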
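
Once virtual CPUs exist, they can serve as the devices of a DTensor mesh. The following is a minimal sketch, not taken from the original notebook: the one-dimensional layout and the dimension name `"batch"` are illustrative assumptions, and the mesh is sized from whatever logical CPUs the runtime actually exposes, so it also works when the request for 4 virtual CPUs was rejected.

``` python
import tensorflow as tf
from tensorflow.experimental import dtensor

try:
    # ask for 4 virtual CPUs (only possible before runtime initialization)
    cpu = tf.config.list_physical_devices("CPU")[0]
    tf.config.set_logical_device_configuration(
        cpu, [tf.config.LogicalDeviceConfiguration() for _ in range(4)])
except RuntimeError:
    # runtime already initialized: keep the devices that already exist
    pass

# size the mesh from the logical CPUs that are really available
num_cpus = len(tf.config.list_logical_devices("CPU"))
devices = [f"CPU:{i}" for i in range(num_cpus)]

# one-dimensional mesh; "batch" is a hypothetical dimension name
mesh = dtensor.create_mesh([("batch", num_cpus)], devices=devices)
print(mesh)
```

Deriving `num_cpus` from `tf.config.list_logical_devices` avoids a mismatch between the mesh size and the devices the runtime ended up with.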
