import tensorflow as tf
from tensorflow import keras
+ from tensorflow.experimental import dtensor
+ from tensorflow.experimental.dtensor import Layout

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.position_embedding import PositionEmbedding
@@ -194,3 +196,71 @@ def token_embedding(self):
    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)
+     @classmethod
+     def create_layout_map(cls, mesh):
+         """Create a DTensor layout map for a GPT2Backbone.
+
+         Given a DTensor mesh describing a list of devices, this method
+         returns a DTensor layout map for creating a
+         `keras_nlp.models.GPT2Backbone` instance. This mapping describes
+         how to distribute all model weights across multiple devices. For an
+         overview of DTensor concepts, see
+         [this guide](https://www.tensorflow.org/guide/dtensor_overview).
+
+         Args:
+             mesh: A 2D `tf.experimental.dtensor.Mesh` describing the
+                 arrangement of devices for running distributed computation.
+                 The first dimension in the mesh is expected to be for data
+                 parallel distribution, and the second for model parallel
+                 distribution.
+
+         Returns:
+             A `tf.keras.dtensor.experimental.LayoutMap` which contains the
+             proper layout to weights mapping for the model parallel setting.
+
+         Examples:
+         ```python
+         keras.backend.experimental.enable_tf_random_generator()
+         keras.utils.set_random_seed(1337)
+
+         # Update both dimensions below for a multi-device setting.
+         mesh = dtensor.create_mesh([("batch", 1), ("model", 1)])
+         layout_map = keras_nlp.models.GPT2Backbone.create_layout_map(mesh)
+
+         with layout_map.scope():
+             model = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en")
+         ```
+         """
+         # We assert the mesh is 2D, and assume the first mesh dim is for
+         # data parallel and the second dim is for model parallel.
+         mesh_shape = mesh.shape()
+         if len(mesh_shape) != 2:
+             raise ValueError(
+                 f"Expected a 2D mesh to create the layout map, "
+                 f"received {mesh}"
+             )
+         _, model_dim = mesh.dim_names
+         unshard_dim = dtensor.UNSHARDED
+
+         layout_map = keras.dtensor.experimental.LayoutMap(mesh=mesh)
+         # Embedding sharding
+         layout_map[r".*embeddings"] = Layout([unshard_dim, model_dim], mesh)
+
+         # Transformer block sharding
+         layout_map[r".*_(query|key|value)_dense.kernel"] = Layout(
+             [unshard_dim, unshard_dim, model_dim], mesh
+         )
+         layout_map[r".*_(query|key|value)_dense.bias"] = Layout(
+             [model_dim, unshard_dim], mesh
+         )
+         layout_map[r".*_feedforward_intermediate_dense.kernel"] = Layout(
+             [unshard_dim, model_dim], mesh
+         )
+         layout_map[r".*_feedforward_intermediate_dense.bias"] = Layout(
+             [model_dim], mesh
+         )
+         layout_map[r".*_feedforward_output_dense.kernel"] = Layout(
+             [model_dim, unshard_dim], mesh
+         )
+         layout_map[r".*_feedforward_output_dense.bias"] = Layout(
+             [unshard_dim], mesh
+         )
+         return layout_map
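
As a usage note, here is a minimal sketch of how the new `create_layout_map` might be applied on an actual multi-device setup. The eight-GPU device list and the 2 x 4 ("batch" x "model") split below are illustrative assumptions, not part of this change:

```python
# A minimal sketch, assuming 8 local GPUs and a TensorFlow build with DTensor.
# The device list and the 2 x 4 ("batch" x "model") split are illustrative.
import tensorflow as tf
from tensorflow import keras
from tensorflow.experimental import dtensor

import keras_nlp

keras.backend.experimental.enable_tf_random_generator()
keras.utils.set_random_seed(1337)

# 2-way data parallel x 4-way model parallel mesh over the 8 devices.
devices = [f"GPU:{i}" for i in range(8)]
mesh = dtensor.create_mesh([("batch", 2), ("model", 4)], devices=devices)

layout_map = keras_nlp.models.GPT2Backbone.create_layout_map(mesh)
with layout_map.scope():
    # Weights created inside this scope pick up the layouts above, so the
    # model is sharded over the "model" mesh dimension at construction time.
    model = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en")
```

This mirrors the docstring example, with the mesh dimensions and device list updated for a multi-device run, which is the change the `# Update both dimensions below` comment calls out.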