
Commit 0629ebc

Add dtensor API for GPT2

1 parent 902e5bf commit 0629ebc

File tree

- keras_nlp/models/gpt2/gpt2_backbone.py
- keras_nlp/models/gpt2/gpt2_backbone_test.py
- keras_nlp/models/gpt2/gpt2_causal_lm.py
- keras_nlp/models/gpt2/gpt2_causal_lm_test.py

4 files changed: +141 −0 lines changed

keras_nlp/models/gpt2/gpt2_backbone.py

Lines changed: 70 additions & 0 deletions
````diff
@@ -18,6 +18,8 @@
 
 import tensorflow as tf
 from tensorflow import keras
+from tensorflow.experimental import dtensor
+from tensorflow.experimental.dtensor import Layout
 
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.layers.position_embedding import PositionEmbedding
@@ -194,3 +196,71 @@ def token_embedding(self):
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
+
+    @classmethod
+    def create_layout_map(cls, mesh):
+        """Create a DTensor layout map for a GPT2Backbone.
+
+        Given a DTensor mesh describing a list of devices, this method returns a
+        DTensor layout map for creating a `keras_nlp.models.GPT2Backbone`
+        instance. This mapping describes how to distribute all model weights
+        across multiple devices. For an overview of DTensor concepts, see
+        [this guide](https://www.tensorflow.org/guide/dtensor_overview).
+
+        Args:
+            mesh: A 2D `tf.experimental.dtensor.Mesh` describing the arrangement
+                of devices for running distributed computation. The
+                first dimension in the mesh is expected to be for data parallel
+                distribution, and the second for model parallel distribution.
+
+        Returns:
+            A `tf.keras.dtensor.experimental.LayoutMap` which contains the
+            proper layout to weights mapping for the model parallel setting.
+
+        Examples:
+        ```python
+        keras.backend.experimental.enable_tf_random_generator()
+        keras.utils.set_random_seed(1337)
+
+        # Update both dimensions below for a multi-device setting.
+        mesh = dtensor.create_mesh([("batch", 1), ("model", 1)])
+        layout_map = keras_nlp.models.GPT2Backbone.create_layout_map(mesh)
+
+        with layout_map.scope():
+            model = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en")
+        ```
+        """
+        # We assert the mesh is 2D, and assume the first mesh dim is for data
+        # parallel and the second dim is for model parallel.
+        mesh_shape = mesh.shape()
+        if len(mesh_shape) != 2:
+            raise ValueError(
+                f"Expect to create layout based on 2D mesh, received {mesh}"
+            )
+        _, model_dim = mesh.dim_names
+        unshard_dim = dtensor.UNSHARDED
+
+        layout_map = keras.dtensor.experimental.LayoutMap(mesh=mesh)
+        # Embedding sharding
+        layout_map[r".*embeddings"] = Layout([unshard_dim, model_dim], mesh)
+
+        # Transformer block sharding
+        layout_map[r".*_(query|key|value)_dense.kernel"] = Layout(
+            [unshard_dim, unshard_dim, model_dim], mesh
+        )
+        layout_map[r".*_(query|key|value)_dense.bias"] = Layout(
+            [model_dim, unshard_dim], mesh
+        )
+        layout_map[r".*_feedforward_intermediate_dense.kernel"] = Layout(
+            [unshard_dim, model_dim], mesh
+        )
+        layout_map[r".*_feedforward_intermediate_dense.bias"] = Layout(
+            [model_dim], mesh
+        )
+        layout_map[r".*_feedforward_output_dense.kernel"] = Layout(
+            [model_dim, unshard_dim], mesh
+        )
+        layout_map[r".*_feedforward_output_dense.bias"] = Layout(
+            [unshard_dim], mesh
+        )
+        return layout_map
````
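For context: each `Layout` must have the same rank as the weight its regex key matches (a DTensor requirement), which is why the attention `query`/`key`/`value` kernels get three axes while their biases get two. Below is a minimal, hedged sketch of the new API on a single host; the preset name comes from the docstring above, and `dtensor.fetch_layout` is the public TF helper for inspecting a variable's layout. On a multi-device host you would enlarge the mesh dimensions instead.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.experimental import dtensor

import keras_nlp

# DTensor needs the TF random generator enabled before variables are created
# (the same setup the test changes below add to setUp()).
keras.backend.experimental.enable_tf_random_generator()
keras.utils.set_random_seed(1337)

# A 1x1 mesh runs everything on one device; e.g. [("batch", 2), ("model", 2)]
# would shard across four devices instead.
mesh = dtensor.create_mesh([("batch", 1), ("model", 1)])
layout_map = keras_nlp.models.GPT2Backbone.create_layout_map(mesh)

with layout_map.scope():
    backbone = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en")

# Variables whose paths match the regex keys above now carry DTensor layouts;
# fetch_layout() shows which axes are sharded on the "model" mesh dimension.
for variable in backbone.weights[:4]:
    print(variable.name, dtensor.fetch_layout(variable))
```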

keras_nlp/models/gpt2/gpt2_backbone_test.py

Lines changed: 21 additions & 0 deletions
````diff
@@ -25,6 +25,10 @@
 
 class GPT2Test(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
+        # For DTensor.
+        keras.backend.experimental.enable_tf_random_generator()
+        keras.utils.set_random_seed(1337)
+
         self.backbone = GPT2Backbone(
             vocabulary_size=10,
             num_layers=2,
@@ -91,6 +95,23 @@ def test_saved_model(self, save_format, filename):
         restored_output = restored_model(self.input_batch)
         self.assertAllClose(model_output, restored_output)
 
+    def test_create_layout_map(self):
+        mesh = tf.experimental.dtensor.create_mesh([("batch", 1), ("model", 1)])
+        with GPT2Backbone.create_layout_map(mesh).scope():
+            GPT2Backbone(
+                vocabulary_size=10,
+                num_layers=2,
+                num_heads=2,
+                hidden_dim=2,
+                intermediate_dim=4,
+                max_sequence_length=5,
+            )
+        # Using DTensor enables the mlir bridge as a side effect. Eventually
+        # this will be default, but for now we have compile errors with the
+        # bridge elsewhere and must disable. See
+        # https://github.com/keras-team/keras-nlp/issues/1001
+        tf.config.experimental.disable_mlir_bridge()
+
 
 @pytest.mark.tpu
 @pytest.mark.usefixtures("tpu_test_class")
````
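The test above only builds a degenerate 1x1 mesh. A hedged sketch, not part of this commit: one way to exercise a genuinely sharded `("model", 2)` dimension on a CPU-only host is to split the physical CPU into logical devices before TensorFlow initializes (this reuses the test file's `GPT2Backbone` import and small model config).

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Split the single physical CPU into two logical devices. This must run
# before any other TensorFlow op initializes the devices.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 2
)

# A 1x2 mesh: no data parallelism, weights sharded two ways on "model".
mesh = dtensor.create_mesh([("batch", 1), ("model", 2)], device_type="CPU")
with GPT2Backbone.create_layout_map(mesh).scope():
    GPT2Backbone(
        vocabulary_size=10,
        num_layers=2,
        num_heads=2,
        hidden_dim=2,
        intermediate_dim=4,
        max_sequence_length=5,
    )
```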

keras_nlp/models/gpt2/gpt2_causal_lm.py

Lines changed: 36 additions & 0 deletions
````diff
@@ -456,3 +456,39 @@ def preprocess(x):
         if outputs.dtype == tf.string:
             return tensor_to_string_list(outputs)
         return outputs.numpy()
+
+    @classmethod
+    def create_layout_map(cls, mesh):
+        """Create a DTensor layout map for a GPT2CausalLM.
+
+        Given a DTensor mesh describing a list of devices, this method returns a
+        DTensor layout map for creating a `keras_nlp.models.GPT2CausalLM`
+        instance. This mapping describes how to distribute all model weights
+        across multiple devices. For an overview of DTensor concepts, see
+        [this guide](https://www.tensorflow.org/guide/dtensor_overview).
+
+        Args:
+            mesh: A 2D `tf.experimental.dtensor.Mesh` describing the arrangement
+                of devices for running distributed computation. The
+                first dimension in the mesh is expected to be for data parallel
+                distribution, and the second for model parallel distribution.
+
+        Returns:
+            A `tf.keras.dtensor.experimental.LayoutMap` which contains the
+            proper layout to weights mapping for the model parallel setting.
+
+        Examples:
+        ```python
+        keras.backend.experimental.enable_tf_random_generator()
+        keras.utils.set_random_seed(1337)
+
+        # Update both dimensions below for a multi-device setting.
+        mesh = tf.experimental.dtensor.create_mesh([("batch", 1), ("model", 1)])
+        layout_map = keras_nlp.models.GPT2CausalLM.create_layout_map(mesh)
+
+        with layout_map.scope():
+            gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+        ```
+        """
+        # As this task has no new variables, we just re-use the backbone method.
+        return cls.backbone_cls.create_layout_map(mesh)
````
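Since the task delegates to the backbone's layout map, model-parallel fine-tuning follows the same pattern. A hedged sketch, assuming the task-level `fit` on raw strings via the preset's attached preprocessor; the corpus and hyperparameters here are illustrative, not from this commit.

```python
import tensorflow as tf
from tensorflow import keras

import keras_nlp

# Same RNG setup DTensor requires for variable creation.
keras.backend.experimental.enable_tf_random_generator()
keras.utils.set_random_seed(1337)

mesh = tf.experimental.dtensor.create_mesh([("batch", 1), ("model", 1)])
layout_map = keras_nlp.models.GPT2CausalLM.create_layout_map(mesh)

with layout_map.scope():
    gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

gpt2_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(2e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Illustrative stand-ins for a real fine-tuning corpus; the attached
# preprocessor tokenizes raw strings directly.
features = ["a quick brown fox.", "the earth is round."]
gpt2_lm.fit(x=features, batch_size=2, epochs=1)
```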

keras_nlp/models/gpt2/gpt2_causal_lm_test.py

Lines changed: 14 additions & 0 deletions
````diff
@@ -30,6 +30,10 @@
 
 class GPT2CausalLMTest(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
+        # For DTensor.
+        keras.backend.experimental.enable_tf_random_generator()
+        keras.utils.set_random_seed(1337)
+
         self.vocab = {
             "!": 0,
             "air": 1,
@@ -147,3 +151,13 @@ def test_saved_model(self, save_format, filename):
         keras.utils.set_random_seed(42)
         restored_output = restored_model.predict(self.raw_batch)
         self.assertAllClose(model_output, restored_output)
+
+    def test_create_layout_map(self):
+        mesh = tf.experimental.dtensor.create_mesh([("batch", 1), ("model", 1)])
+        with GPT2CausalLM.create_layout_map(mesh).scope():
+            GPT2CausalLM(backbone=self.backbone)
+        # Using DTensor enables the mlir bridge as a side effect. Eventually
+        # this will be default, but for now we have compile errors with the
+        # bridge elsewhere and must disable. See
+        # https://github.com/keras-team/keras-nlp/issues/1001
+        tf.config.experimental.disable_mlir_bridge()
````
