From d8a3227df5a3825d5342703c21a1ad29c931cd4d Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Tue, 2 May 2023 19:37:52 +0000
Subject: [PATCH] Add dtensor API for GPT2

---
 keras_nlp/models/gpt2/gpt2_backbone.py       | 70 ++++++++++++++++++++
 keras_nlp/models/gpt2/gpt2_backbone_test.py  | 21 ++++++
 keras_nlp/models/gpt2/gpt2_causal_lm.py      | 36 ++++++++++
 keras_nlp/models/gpt2/gpt2_causal_lm_test.py | 14 ++++
 4 files changed, 141 insertions(+)

diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py
index 6eae4b7439..9a611049eb 100644
--- a/keras_nlp/models/gpt2/gpt2_backbone.py
+++ b/keras_nlp/models/gpt2/gpt2_backbone.py
@@ -18,6 +18,8 @@
 import tensorflow as tf
 from tensorflow import keras
+from tensorflow.experimental import dtensor
+from tensorflow.experimental.dtensor import Layout
 
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.layers.position_embedding import PositionEmbedding
@@ -194,3 +196,71 @@ def token_embedding(self):
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
+
+    @classmethod
+    def create_layout_map(cls, mesh):
+        """Create a DTensor layout map for a GPT2Backbone.
+
+        Given a DTensor mesh describing a list of devices, this method
+        returns a DTensor layout map for creating a
+        `keras_nlp.models.GPT2Backbone` instance. This mapping describes how
+        to distribute all model weights across multiple devices. For an
+        overview of DTensor concepts, see
+        [this guide](https://www.tensorflow.org/guide/dtensor_overview).
+
+        Args:
+            mesh: A 2D `tf.experimental.dtensor.Mesh` describing the
+                arrangement of devices for running distributed computation.
+                The first dimension in the mesh is expected to be for data
+                parallel distribution, and the second for model parallel
+                distribution.
+
+        Returns:
+            A `tf.keras.dtensor.experimental.LayoutMap` which contains the
+            proper weight-to-layout mapping for the model parallel setting.
+
+        Examples:
+        ```python
+        keras.backend.experimental.enable_tf_random_generator()
+        keras.utils.set_random_seed(1337)
+
+        # Update both dimensions below for a multi-device setting.
+        mesh = dtensor.create_mesh([("batch", 1), ("model", 1)])
+        layout_map = keras_nlp.models.GPT2Backbone.create_layout_map(mesh)
+
+        with layout_map.scope():
+            model = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en")
+        ```
+        """
+        # We verify that the mesh is 2D, and assume the first mesh dim is for
+        # data parallelism and the second dim is for model parallelism.
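+        # For example, `create_mesh([("batch", 2), ("model", 4)])` over eight
+        # devices gives 2-way data parallelism and 4-way model parallelism.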
+        mesh_shape = mesh.shape()
+        if len(mesh_shape) != 2:
+            raise ValueError(
+                f"Expected a 2D mesh for creating the layout map, "
+                f"received {mesh}"
+            )
+        _, model_dim = mesh.dim_names
+        unshard_dim = dtensor.UNSHARDED
+
+        layout_map = keras.dtensor.experimental.LayoutMap(mesh=mesh)
+        # Embedding sharding.
+        layout_map[r".*embeddings"] = Layout([unshard_dim, model_dim], mesh)
+
+        # Transformer block sharding.
+        layout_map[r".*_(query|key|value)_dense.kernel"] = Layout(
+            [unshard_dim, unshard_dim, model_dim], mesh
+        )
+        layout_map[r".*_(query|key|value)_dense.bias"] = Layout(
+            [model_dim, unshard_dim], mesh
+        )
+        layout_map[r".*_feedforward_intermediate_dense.kernel"] = Layout(
+            [unshard_dim, model_dim], mesh
+        )
+        layout_map[r".*_feedforward_intermediate_dense.bias"] = Layout(
+            [model_dim], mesh
+        )
+        layout_map[r".*_feedforward_output_dense.kernel"] = Layout(
+            [model_dim, unshard_dim], mesh
+        )
+        layout_map[r".*_feedforward_output_dense.bias"] = Layout(
+            [unshard_dim], mesh
+        )
+        return layout_map
diff --git a/keras_nlp/models/gpt2/gpt2_backbone_test.py b/keras_nlp/models/gpt2/gpt2_backbone_test.py
index ce1232e08c..02cb6dcfdd 100644
--- a/keras_nlp/models/gpt2/gpt2_backbone_test.py
+++ b/keras_nlp/models/gpt2/gpt2_backbone_test.py
@@ -25,6 +25,10 @@
 
 class GPT2Test(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
+        # For DTensor.
+        keras.backend.experimental.enable_tf_random_generator()
+        keras.utils.set_random_seed(1337)
+
         self.backbone = GPT2Backbone(
             vocabulary_size=10,
             num_layers=2,
@@ -91,6 +95,23 @@ def test_saved_model(self, save_format, filename):
         restored_output = restored_model(self.input_batch)
         self.assertAllClose(model_output, restored_output)
 
+    def test_create_layout_map(self):
+        mesh = tf.experimental.dtensor.create_mesh([("batch", 1), ("model", 1)])
+        with GPT2Backbone.create_layout_map(mesh).scope():
+            GPT2Backbone(
+                vocabulary_size=10,
+                num_layers=2,
+                num_heads=2,
+                hidden_dim=2,
+                intermediate_dim=4,
+                max_sequence_length=5,
+            )
+        # Using DTensor enables the mlir bridge as a side effect. Eventually
+        # this will be the default, but for now we have compile errors with
+        # the bridge elsewhere and must disable it. See
+        # https://github.com/keras-team/keras-nlp/issues/1001
+        tf.config.experimental.disable_mlir_bridge()
+
 
 @pytest.mark.tpu
 @pytest.mark.usefixtures("tpu_test_class")
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py
index 1468de24d7..8889246ab2 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -456,3 +456,39 @@ def preprocess(x):
         if outputs.dtype == tf.string:
             return tensor_to_string_list(outputs)
         return outputs.numpy()
+
+    @classmethod
+    def create_layout_map(cls, mesh):
+        """Create a DTensor layout map for a GPT2CausalLM.
+
+        Given a DTensor mesh describing a list of devices, this method
+        returns a DTensor layout map for creating a
+        `keras_nlp.models.GPT2CausalLM` instance. This mapping describes how
+        to distribute all model weights across multiple devices. For an
+        overview of DTensor concepts, see
+        [this guide](https://www.tensorflow.org/guide/dtensor_overview).
+
+        Args:
+            mesh: A 2D `tf.experimental.dtensor.Mesh` describing the
+                arrangement of devices for running distributed computation.
+                The first dimension in the mesh is expected to be for data
+                parallel distribution, and the second for model parallel
+                distribution.
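+                For example, a `(2, 4)` mesh over eight devices would split
+                the batch across two replicas and shard the mapped weights
+                four ways.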
+
+        Returns:
+            A `tf.keras.dtensor.experimental.LayoutMap` which contains the
+            proper weight-to-layout mapping for the model parallel setting.
+
+        Examples:
+        ```python
+        keras.backend.experimental.enable_tf_random_generator()
+        keras.utils.set_random_seed(1337)
+
+        # Update both dimensions below for a multi-device setting.
+        mesh = tf.experimental.dtensor.create_mesh([("batch", 1), ("model", 1)])
+        layout_map = keras_nlp.models.GPT2CausalLM.create_layout_map(mesh)
+
+        with layout_map.scope():
+            gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+        ```
+        """
+        # As this task introduces no new variables, we can simply re-use the
+        # backbone layout map.
+        return cls.backbone_cls.create_layout_map(mesh)
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py
index 5f210a0ecc..eb5b017f3c 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py
@@ -30,6 +30,10 @@
 
 class GPT2CausalLMTest(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
+        # For DTensor.
+        keras.backend.experimental.enable_tf_random_generator()
+        keras.utils.set_random_seed(1337)
+
         self.vocab = {
             "!": 0,
             "air": 1,
@@ -147,3 +151,13 @@ def test_saved_model(self, save_format, filename):
         keras.utils.set_random_seed(42)
         restored_output = restored_model.predict(self.raw_batch)
         self.assertAllClose(model_output, restored_output)
+
+    def test_create_layout_map(self):
+        mesh = tf.experimental.dtensor.create_mesh([("batch", 1), ("model", 1)])
+        with GPT2CausalLM.create_layout_map(mesh).scope():
+            GPT2CausalLM(backbone=self.backbone)
+        # Using DTensor enables the mlir bridge as a side effect. Eventually
+        # this will be the default, but for now we have compile errors with
+        # the bridge elsewhere and must disable it. See
+        # https://github.com/keras-team/keras-nlp/issues/1001
+        tf.config.experimental.disable_mlir_bridge()
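
A minimal sketch of the new API on multiple devices (the two-GPU mesh and the `GPU:0`/`GPU:1` device names are assumptions for illustration; match the mesh dimensions and device list to the devices actually available):

```python
import keras_nlp
from tensorflow import keras
from tensorflow.experimental import dtensor

# Keras must use the TF random generator when running with DTensor.
keras.backend.experimental.enable_tf_random_generator()
keras.utils.set_random_seed(1337)

# 1-way data parallelism x 2-way model parallelism across two GPUs.
mesh = dtensor.create_mesh(
    [("batch", 1), ("model", 2)], devices=["GPU:0", "GPU:1"]
)
layout_map = keras_nlp.models.GPT2Backbone.create_layout_map(mesh)

# Variables created under the scope pick up the layouts registered above,
# sharding attention and feedforward kernels over the "model" dimension.
with layout_map.scope():
    model = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en")
```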