diff --git a/keras_nlp/models/opt/opt_backbone.py b/keras_nlp/models/opt/opt_backbone.py
index 870077021b..4d25b9c040 100644
--- a/keras_nlp/models/opt/opt_backbone.py
+++ b/keras_nlp/models/opt/opt_backbone.py
@@ -18,6 +18,8 @@
 
 import tensorflow as tf
 from tensorflow import keras
+from tensorflow.experimental import dtensor
+from tensorflow.experimental.dtensor import Layout
 
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.layers.token_and_position_embedding import (
@@ -172,3 +174,71 @@ def token_embedding(self):
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
+
+    @classmethod
+    def create_layout_map(cls, mesh):
+        """Create a DTensor layout map for an OPTBackbone.
+
+        Given a DTensor mesh describing a list of devices, this method returns a
+        DTensor layout map for creating a `keras_nlp.models.OPTBackbone`
+        instance. This mapping describes how to distribute all model weights
+        across multiple devices. For an overview of DTensor concepts, see
+        [this guide](https://www.tensorflow.org/guide/dtensor_overview).
+
+        Args:
+            mesh: A 2D `tf.experimental.dtensor.Mesh` describing the arrangement
+                of devices for running distributed computation. The
+                first dimension in the mesh is expected to be for data parallel
+                distribution, and the second for model parallel distribution.
+
+        Returns:
+            A `tf.keras.dtensor.experimental.LayoutMap` which contains the
+            proper layout to weights mapping for the model parallel setting.
+
+        Examples:
+        ```python
+        keras.backend.experimental.enable_tf_random_generator()
+        keras.utils.set_random_seed(1337)
+
+        # Update both dimensions below for a multi-device setting.
+        mesh = tf.experimental.dtensor.create_mesh([("batch", 1), ("model", 1)])
+        layout_map = keras_nlp.models.OPTBackbone.create_layout_map(mesh)
+
+        with layout_map.scope():
+            model = keras_nlp.models.OPTBackbone.from_preset("opt_125m_en")
+        ```
+        """
+        # We assert the mesh is 2D, and assume the first mesh dim is for data
+        # parallel and the second dim is for model parallel.
+        mesh_shape = mesh.shape()
+        if len(mesh_shape) != 2:
+            raise ValueError(
+                f"Expect to create layout based on 2D mesh, received {mesh}"
+            )
+        _, model_dim = mesh.dim_names
+        unshard_dim = dtensor.UNSHARDED
+
+        layout_map = keras.dtensor.experimental.LayoutMap(mesh=mesh)
+        # Embedding sharding
+        layout_map[r".*embeddings"] = Layout([unshard_dim, model_dim], mesh)
+
+        # Transformer block sharding
+        layout_map[r".*_(query|key|value)_dense.kernel"] = Layout(
+            [unshard_dim, unshard_dim, model_dim], mesh
+        )
+        layout_map[r".*_(query|key|value)_dense.bias"] = Layout(
+            [model_dim, unshard_dim], mesh
+        )
+        layout_map[r".*_feedforward_intermediate_dense.kernel"] = Layout(
+            [unshard_dim, model_dim], mesh
+        )
+        layout_map[r".*_feedforward_intermediate_dense.bias"] = Layout(
+            [model_dim], mesh
+        )
+        layout_map[r".*_feedforward_output_dense.kernel"] = Layout(
+            [model_dim, unshard_dim], mesh
+        )
+        layout_map[r".*_feedforward_output_dense.bias"] = Layout(
+            [unshard_dim], mesh
+        )
+        return layout_map
diff --git a/keras_nlp/models/opt/opt_backbone_test.py b/keras_nlp/models/opt/opt_backbone_test.py
index 85e80d3350..a26e8d6a0a 100644
--- a/keras_nlp/models/opt/opt_backbone_test.py
+++ b/keras_nlp/models/opt/opt_backbone_test.py
@@ -25,6 +25,10 @@
 
 class OPTTest(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
+        # For DTensor.
+        keras.backend.experimental.enable_tf_random_generator()
+        keras.utils.set_random_seed(1337)
+
         self.backbone = OPTBackbone(
             vocabulary_size=10,
             num_layers=2,
@@ -91,6 +95,25 @@ def test_saved_model(self, save_format, filename):
         restored_output = restored_model(self.input_batch)
         self.assertAllClose(model_output, restored_output)
 
+    def test_create_layout_map(self):
+        # Using DTensor enables the mlir bridge as a side effect. Eventually
+        # this will be default, but for now we have compile errors with the
+        # bridge elsewhere and must disable. See
+        # https://github.com/keras-team/keras-nlp/issues/1001
+        # Registered as a cleanup so the bridge is disabled even if the
+        # body of this test raises.
+        self.addCleanup(tf.config.experimental.disable_mlir_bridge)
+        mesh = tf.experimental.dtensor.create_mesh([("batch", 1), ("model", 1)])
+        with OPTBackbone.create_layout_map(mesh).scope():
+            OPTBackbone(
+                vocabulary_size=10,
+                num_layers=2,
+                num_heads=2,
+                hidden_dim=2,
+                intermediate_dim=4,
+                max_sequence_length=5,
+            )
+
 
 @pytest.mark.tpu
 @pytest.mark.usefixtures("tpu_test_class")
diff --git a/keras_nlp/models/opt/opt_causal_lm.py b/keras_nlp/models/opt/opt_causal_lm.py
index ce3f4e2f1d..607297474b 100644
--- a/keras_nlp/models/opt/opt_causal_lm.py
+++ b/keras_nlp/models/opt/opt_causal_lm.py
@@ -450,3 +450,39 @@ def preprocess(x):
         if outputs.dtype == tf.string:
             return tensor_to_string_list(outputs)
         return outputs.numpy()
+
+    @classmethod
+    def create_layout_map(cls, mesh):
+        """Create a DTensor layout map for an OPTCausalLM.
+
+        Given a DTensor mesh describing a list of devices, this method returns a
+        DTensor layout map for creating a `keras_nlp.models.OPTCausalLM`
+        instance. This mapping describes how to distribute all model weights
+        across multiple devices. For an overview of DTensor concepts, see
+        [this guide](https://www.tensorflow.org/guide/dtensor_overview).
+
+        Args:
+            mesh: A 2D `tf.experimental.dtensor.Mesh` describing the arrangement
+                of devices for running distributed computation. The
+                first dimension in the mesh is expected to be for data parallel
+                distribution, and the second for model parallel distribution.
+
+        Returns:
+            A `tf.keras.dtensor.experimental.LayoutMap` which contains the
+            proper layout to weights mapping for the model parallel setting.
+
+        Examples:
+        ```python
+        keras.backend.experimental.enable_tf_random_generator()
+        keras.utils.set_random_seed(1337)
+
+        # Update both dimensions below for a multi-device setting.
+        mesh = tf.experimental.dtensor.create_mesh([("batch", 1), ("model", 1)])
+        layout_map = keras_nlp.models.OPTCausalLM.create_layout_map(mesh)
+
+        with layout_map.scope():
+            opt_lm = keras_nlp.models.OPTCausalLM.from_preset("opt_125m_en")
+        ```
+        """
+        # As this task has no new variables, we just re-use the backbone method.
+        return cls.backbone_cls.create_layout_map(mesh)
diff --git a/keras_nlp/models/opt/opt_causal_lm_test.py b/keras_nlp/models/opt/opt_causal_lm_test.py
index 47a5bf7a02..b995128f29 100644
--- a/keras_nlp/models/opt/opt_causal_lm_test.py
+++ b/keras_nlp/models/opt/opt_causal_lm_test.py
@@ -30,6 +30,10 @@
 
 class OPTCausalLMTest(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
+        # For DTensor.
+        keras.backend.experimental.enable_tf_random_generator()
+        keras.utils.set_random_seed(1337)
+
         self.vocab = {
             "": 0,
             "": 1,
@@ -153,3 +157,15 @@ def test_saved_model(self, save_format, filename):
         keras.utils.set_random_seed(42)
         restored_output = restored_model.predict(self.raw_batch)
         self.assertAllClose(model_output, restored_output)
+
+    def test_create_layout_map(self):
+        # Using DTensor enables the mlir bridge as a side effect. Eventually
+        # this will be default, but for now we have compile errors with the
+        # bridge elsewhere and must disable. See
+        # https://github.com/keras-team/keras-nlp/issues/1001
+        # Registered as a cleanup so the bridge is disabled even if the
+        # body of this test raises.
+        self.addCleanup(tf.config.experimental.disable_mlir_bridge)
+        mesh = tf.experimental.dtensor.create_mesh([("batch", 1), ("model", 1)])
+        with OPTCausalLM.create_layout_map(mesh).scope():
+            OPTCausalLM(backbone=self.backbone)