From 422edf28e7b611e3488565204836cfdaaa90b664 Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Tue, 11 Apr 2023 15:26:40 -0700
Subject: [PATCH 1/3] Standalone functions for generate pre/post processing

This decomposes generate in the way we discussed last week, with the goal
of leaving the top-level functionality untouched, but allowing a more
granular way to access the preprocessing, postprocessing, and inner dense
generation function.

Colab [HERE](https://colab.research.google.com/gist/mattdangerw/bb1ef01c1b67255def4a6ad9429de2df/split-preprocessing-demo.ipynb)

Other than moving things around in the refactor, there is one major change
we need to make here: the inner, compiled generate function must also
return a padding mask marking which token ids were updated. Without this
padding mask, the postprocessor would not know where to truncate output
before detokenization.

To accommodate this, I made the `generate_function` inputs and outputs a
dict with keys "token_ids" and "padding_mask". I actually find this fairly
intuitive; with this change, `generate_function` has the same inputs and
outputs as directly calling the model!

```python
generate_function = causal_lm.make_generate_function()
generate_function({
    "token_ids": [[1, 2, 3, 4, 0, 0, 0, 0]],
    "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
})
>>> {
    "token_ids": [[1, 2, 3, 4, 5, 6, 7, 8]],
    "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1]],
}
generate_function({
    "token_ids": [[1, 2, 3, 4, 0, 0, 0, 0]],
    "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
}, end_token_id=6)
>>> {
    "token_ids": [[1, 2, 3, 4, 5, 6, 0, 0]],
    "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
}
```
---
 keras_nlp/models/gpt2/gpt2_causal_lm.py       | 181 +++++++++++-------
 .../gpt2/gpt2_causal_lm_preprocessor.py       |  89 +++++----
 .../gpt2/gpt2_causal_lm_preprocessor_test.py  |  19 +-
 keras_nlp/models/gpt2/gpt2_causal_lm_test.py  |  15 +-
 keras_nlp/models/gpt2/gpt2_preprocessor.py    |  17 +-
 .../models/gpt2/gpt2_preprocessor_test.py     |   6 +-
 6 files changed, 195 insertions(+), 132 deletions(-)

diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py
index 8889246ab2..b738a969ae 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -29,7 +29,6 @@
 from keras_nlp.utils.keras_utils import is_xla_compatible
 from keras_nlp.utils.python_utils import classproperty
 from keras_nlp.utils.tf_utils import tensor_to_string_list
-from keras_nlp.utils.tf_utils import truncate_at_token


 @keras_nlp_export("keras_nlp.models.GPT2CausalLM")
@@ -49,7 +48,7 @@ class GPT2CausalLM(Task):
     default, `"top_k"` sampling will be used.

     This model can optionally be configured with a `preprocessor` layer, in
-    which case it will automatically apply preprocessing to raw inputs during
+    which case it will automatically apply preprocessing to string inputs during
     `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
     when creating the model with `from_preset()`.

@@ -306,28 +305,23 @@ def make_generate_function(self):

     def generate_step(
         self,
-        token_ids,
-        padding_mask,
+        inputs,
         end_token_id=None,
     ):
         """A compilable generation function for a single batch of inputs.

         This function represents the inner, XLA-compilable, generation function
-        for a single batch of inputs. It takes in a dense `tf.Tensor` of token
-        ids, and return a dense `tf.Tensor` of token ids, and includes no
-        preprocessing. This function is wrapped by the `generate()` method.
+        for a single batch of inputs.
Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. Args: - token_ids: A dense int Tensor, with shape - `(batch_size, max_length)`. The user provided token ids - padded to `max_length`. - padding_mask: A dense boolean Tensor, with the same shape as - `token_ids`. Positions that are True in the `padding_mask` - are assumed to be user input and never updated. + inputs: A dictionary with two batched tensor keys `"token_ids"` + and `"padding_mask"`. end_token_id: The id of the end token to stop on. If all sequences have produced a new `end_token_id`, generation will stop. """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] # Create and seed cache with a single forward pass. hidden_states, cache = self._build_cache(token_ids) # Compute the lengths of all user inputted tokens ids. @@ -352,7 +346,7 @@ def next(prompt, cache, index): cache, ) - return self._sampler( + token_ids = self._sampler( next=next, prompt=token_ids, cache=cache, @@ -362,6 +356,78 @@ def next(prompt, cache, index): hidden_states=hidden_states, ) + # Compute an output padding mask with the token ids we updated. + if end_token_id is not None: + # Build a mask of `end_token_id` locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = (token_ids == end_token_id) & (~padding_mask) + end_locations = tf.cast(end_locations, tf.int32) + # Use cumsum to get ones in all locations after end_locations. + overflow = tf.math.cumsum(end_locations, exclusive=True) + # Our padding mask is the inverse of these overflow locations. + padding_mask = ~tf.cast(overflow, tf.bool) + else: + # Without early stopping, all locations will have been updated. + padding_mask = tf.ones_like(token_ids, dtype=tf.bool) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def _normalize_generate_inputs( + self, + inputs, + ): + """Normalize user input to the generate function. + + This function coverts all inputs to tensors, adds a batch dimension if + necessary, and returns a iterable "dataset like" object (either an + actually dataset or a list with a single element). + """ + input_is_scalar = False + + if isinstance(inputs, tf.data.Dataset): + return inputs, input_is_scalar + + if isinstance(inputs, str) or isinstance(inputs, list): + inputs = tf.convert_to_tensor(inputs) + + if isinstance(inputs, tf.Tensor) and inputs.shape.rank == 0: + input_is_scalar = True + inputs = inputs[tf.newaxis] + + # We avoid coverting to a dataset purely for speed, for a single batch + # of input, creating a dataset would add significant overhead. + return [inputs], input_is_scalar + + def _normalize_generate_outputs( + self, + outputs, + input_is_scalar, + ): + """Normalize user output from the generate function. + + This function converts all output to numpy (for integer output), or + python strings (for string output). If a batch dimension was added to + the input, it is removed from the output (so generate can be string in, + string out). + """ + + def normalize(x): + x = tf.concat(x, axis=0) + x = tf.squeeze(x, 0) if input_is_scalar else x + is_string = x.dtype == tf.string + # Convert outputs to a friendly pythonic type. For numerical outputs + # that is numpy, for string outputs that is `list` and `str`. 
+ return tensor_to_string_list(x) if is_string else x.numpy() + + if isinstance(outputs[0], dict): + return { + "token_ids": normalize([x["token_ids"] for x in outputs]), + "padding_mask": normalize([x["padding_mask"] for x in outputs]), + } + return normalize([x for x in outputs]) + def generate( self, inputs, @@ -397,65 +463,43 @@ def generate( A string or string list if `preprocessor` is set, and a integer tensor of token IDs if `preprocessor is None`. """ - input_is_scalar = False - + # Setup our three main passes. + # 1. Optionally preprocessing strings to dense integer tensors. + # 2. Generate new tokens via a compiled function on dense tensors. + # 3. Optionally postprocess dense integer tensors back to string. + generate_function = self.make_generate_function() + end_token_id = None if self.preprocessor is not None: + end_token_id = self.preprocessor.tokenizer.end_token_id - def preprocess(x): - return self.preprocessor( - x, - sequence_length=max_length, - return_labels=False, - # We do not append an end token by default during - # generation, as generating directly in the same sequence is - # the most common workflow. If an end token directly after - # a prompt is desired, it can be added to the prompt string. - add_end_token=False, - ) - - if not isinstance(inputs, tf.data.Dataset): - inputs = tf.convert_to_tensor(inputs) - input_is_scalar = inputs.shape.rank == 0 - inputs = inputs[tf.newaxis] if input_is_scalar else inputs - # Wrap a list to avoid the overhead of converting to dataset. - inputs = [preprocess(inputs)] - else: + def preprocess(x): + return self.preprocessor.generate_preprocess( + x, sequence_length=max_length + ) + + def generate(x): + return generate_function(x, end_token_id=end_token_id) + + def postprocess(x): + return self.preprocessor.generate_postprocess(x) + + # Normalize inputs, apply our three passes, and normalize outputs. + inputs, input_is_scalar = self._normalize_generate_inputs(inputs) + + if self.preprocessor is not None: + if isinstance(inputs, tf.data.Dataset): inputs = inputs.map(preprocess, tf.data.AUTOTUNE) inputs = inputs.prefetch(tf.data.AUTOTUNE) - else: - if not isinstance(inputs, tf.data.Dataset): - # Wrap a list to avoid the overhead of converting to dataset. - inputs = [inputs] + else: + # Fast path for non-dataset, single-batch input. + inputs = [preprocess(x) for x in inputs] - generate_function = self.make_generate_function() - outputs = [] - for batch in inputs: - token_ids, padding_mask = batch["token_ids"], batch["padding_mask"] - # If `preprocessor` is attached, we can stop after end_token_id. - end_token_id = None - if self.preprocessor is not None: - end_token_id = self.preprocessor.tokenizer.end_token_id - # Run the compiled generate function. - output = generate_function(token_ids, padding_mask, end_token_id) - - if self.preprocessor is not None: - # Truncate to ragged by removing tokens after the first - # generated `end_token_id`. - output = truncate_at_token(output, end_token_id, padding_mask) - # Strip start token if added. - if self.preprocessor.add_start_token: - output = output[:, 1:] - # Detokenize. - output = self.preprocessor.tokenizer.detokenize(output) - outputs.append(output) - - outputs = tf.concat(outputs, axis=0) - outputs = tf.squeeze(outputs, 0) if input_is_scalar else outputs - # Convert outputs to a friendly pythonic type. For numerical outputs - # that is numpy, for string outputs that is `list` and `str`. 
- if outputs.dtype == tf.string: - return tensor_to_string_list(outputs) - return outputs.numpy() + outputs = [generate(x) for x in inputs] + + if self.preprocessor is not None: + outputs = [postprocess(x) for x in outputs] + + return self._normalize_generate_outputs(outputs, input_is_scalar) @classmethod def create_layout_map(cls, mesh): @@ -492,3 +536,4 @@ def create_layout_map(cls, mesh): """ # As this task has no new variables, we just re-use the backbone method. return cls.backbone_cls.create_layout_map(mesh) + return self._normalize_generate_outputs(outputs, input_is_scalar) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py index 0bb3e07572..362e1af0bb 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py @@ -14,10 +14,14 @@ """GPT2 Causal LM preprocessor layer.""" +import tensorflow as tf from absl import logging from keras_nlp.api_export import keras_nlp_export from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight @@ -25,12 +29,17 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor): """GPT2 Causal LM preprocessor. - This preprocessing layer is primarily meant to be used with + This preprocessing layer is meant for use with `keras_nlp.models.GPT2CausalLM`. By default, it will take in batches of strings, and return outputs in a `(x, y, sample_weight)` format, where the - `y` label is the next token id in the `x` sequence. For use with generation, - pass `return_labels=False`, in which case the output will simply be the - encoded string features. + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_nlp.models.GPT2CausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). Args: tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance. @@ -47,12 +56,6 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor): generates label weights. sequence_length: Pass to override the configured `sequence_length` of the layer. - add_start_token: Pass to override the configured value of - `add_start_token` on the layer. - add_end_token: Pass to override the configured value of - `add_end_token` on the layer. - return_labels: If `True`, the output `"token_ids"` will be offset by one - and returned as labels. If `False` only features will be returned. Examples: ```python @@ -95,9 +98,6 @@ def call( y=None, sample_weight=None, sequence_length=None, - add_start_token=None, - add_end_token=None, - return_labels=True, ): if y is not None or sample_weight is not None: logging.warning( @@ -106,25 +106,48 @@ def call( "or `sample_weight`. Your `y` and `sample_weight` will be " "ignored." ) - if return_labels: - # Tokenize with one extra token to account for the truncation below. - sequence_length = (sequence_length or self.sequence_length) + 1 - x = super().call( + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. 
+ token_ids, padding_mask = self.packer( x, - sequence_length=sequence_length, - add_start_token=add_start_token, - add_end_token=add_end_token, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False ) - if return_labels: - token_ids, padding_mask = x["token_ids"], x["padding_mask"] - # The last token does not have a next token, so we truncate it out. - x = { - "token_ids": token_ids[..., :-1], - "padding_mask": padding_mask[..., :-1], - } - # Target `y` will be the next token. - y = token_ids[..., 1:] - sample_weight = padding_mask[..., 1:] - return pack_x_y_sample_weight(x, y, sample_weight) - else: - return x + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess( + self, + x, + ): + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + # Strip any special tokens during detokenization (e.g. the start and + # end markers). In the future we could make this configurable. + padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id) + token_ids = tf.ragged.boolean_mask(token_ids, padding_mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py index 03a41c84c1..5d5912d9a8 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py @@ -107,16 +107,19 @@ def test_dataset(self): self.assertAllEqual(y, [[1, 3, 4, 2, 5, 6, 0, 0]] * 4) self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 0, 0]] * 4) - def test_call_overrides(self): + def test_generate_preprocess(self): input_data = "airplane at airport" - x, _, _ = self.preprocessor(input_data, add_start_token=False) - self.assertAllEqual(x["token_ids"], [1, 3, 4, 2, 5, 6, 0, 0]) - x, _, _ = self.preprocessor(input_data, add_end_token=False) + x = self.preprocessor.generate_preprocess(input_data) self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0]) - x, _, _ = self.preprocessor(input_data, sequence_length=4) - self.assertAllEqual(x["token_ids"], [6, 1, 3, 4]) - x = self.preprocessor(input_data, return_labels=False) - self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 6, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": tf.constant([6, 1, 3, 4, 2, 5, 0, 0]), + "padding_mask": tf.cast([1, 1, 1, 1, 1, 1, 0, 0], dtype="bool"), + } + x = self.preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "airplane at airport") def test_serialization(self): config = keras.utils.serialize_keras_object(self.preprocessor) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py index eb5b017f3c..38a0302c4b 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py @@ -98,8 
+98,8 @@ def test_fit_no_xla(self): def test_generate(self): # String input. - prompt = " airplane" - output = self.causal_lm.generate(" airplane") + prompt = " airplane at airport" + output = self.causal_lm.generate(" airplane at airport") self.assertTrue(prompt in output) # String tensor input. self.assertIsInstance(self.causal_lm.generate(self.raw_batch)[0], str) @@ -107,8 +107,15 @@ def test_generate(self): self.assertIsInstance(self.causal_lm.generate(self.raw_dataset)[0], str) # Int tensor input. self.causal_lm.preprocessor = None - self.assertDTypeEqual( - self.causal_lm.generate(self.preprocessed_batch), tf.int32 + outputs = self.causal_lm.generate(self.preprocessed_batch) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :5], + self.preprocessed_batch["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + self.preprocessed_batch["padding_mask"][:, :5], ) def test_generate_compilation(self): diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py index f309c090fe..30eb0bbd5e 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py @@ -68,10 +68,6 @@ class GPT2Preprocessor(Preprocessor): sample_weight: Any label weight data. Will be passed through unaltered. sequence_length: Pass to override the configured `sequence_length` of the layer. - add_start_token: Pass to override the configured value of - `add_start_token` on the layer. - add_end_token: Pass to override the configured value of - `add_end_token` on the layer. Examples: @@ -154,8 +150,6 @@ def call( y=None, sample_weight=None, sequence_length=None, - add_start_token=None, - add_end_token=None, ): x = convert_inputs_to_list_of_tensor_segments(x) if len(x) != 1: @@ -165,17 +159,12 @@ def call( "for a multi-segment classification task, please refer to " "classification models like BERT or RoBERTa." 
            )
-        if sequence_length is None:
-            sequence_length = self.sequence_length
-        if add_start_token is None:
-            add_start_token = self.add_start_token
-        if add_end_token is None:
-            add_end_token = self.add_end_token
+        sequence_length = sequence_length or self.sequence_length
         token_ids, padding_mask = self.packer(
             self.tokenizer(x[0]),
             sequence_length=sequence_length,
-            add_start_value=add_start_token,
-            add_end_value=add_end_token,
+            add_start_value=self.add_start_token,
+            add_end_value=self.add_end_token,
         )
         x = {
             "token_ids": token_ids,
diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py
index cf59654535..6f8f7d8629 100644
--- a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py
+++ b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py
@@ -97,12 +97,8 @@ def test_tokenize_labeled_dataset(self):
         self.assertAllEqual(x["token_ids"], [[6, 1, 3, 4, 2, 5, 6, 0]] * 4)
         self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4)

-    def test_call_overrides(self):
+    def test_sequence_length_override(self):
         input_data = "airplane at airport"
-        x = self.preprocessor(input_data, add_start_token=False)
-        self.assertAllEqual(x["token_ids"], [1, 3, 4, 2, 5, 6, 0, 0])
-        x = self.preprocessor(input_data, add_end_token=False)
-        self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0])
         x = self.preprocessor(input_data, sequence_length=4)
         self.assertAllEqual(x["token_ids"], [6, 1, 3, 6])

From 79d5babcb444b63aa974162fe3f5837e134478c1 Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Wed, 3 May 2023 15:16:19 +0000
Subject: [PATCH 2/3] More docstring updates

---
 keras_nlp/models/gpt2/gpt2_causal_lm.py            |  6 +++---
 .../models/gpt2/gpt2_causal_lm_preprocessor.py     | 17 +++++++++++++++++
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py
index b738a969ae..64c783a478 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -315,8 +315,8 @@ def generate_step(
         model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.

         Args:
-            inputs: A dictionary with two batched tensor keys `"token_ids"`
-                and `"padding_mask"`.
+            inputs: A dictionary with two keys `"token_ids"` and
+                `"padding_mask"` and batched tensor values.
             end_token_id: The id of the end token to stop on. If all
                 sequences have produced a new `end_token_id`, generation
                 will stop.
@@ -382,7 +382,7 @@ def _normalize_generate_inputs(

         This function coverts all inputs to tensors, adds a batch dimension if
         necessary, and returns a iterable "dataset like" object (either an
-        actually dataset or a list with a single element).
+        actual `tf.data.Dataset` or a list with a single batch element).
         """
         input_is_scalar = False

diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
index 362e1af0bb..b9f72e2f5a 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -131,6 +131,17 @@ def generate_preprocess(
         x,
         sequence_length=None,
     ):
+        """Convert strings to integer token input for generation.
+
+        Similar to calling the layer for training, this method takes in
+        strings or string tensors, tokenizes and packs the input, and computes
+        a padding mask that marks which positions hold real input tokens.
+
+        Unlike calling the layer for training, this method does not compute
+        labels and will never append a `tokenizer.end_token_id` to the end of
+        the sequence (as generation is expected to continue at the end of the
+        input prompt).
+        """
         x = convert_inputs_to_list_of_tensor_segments(x)[0]
         x = self.tokenizer(x)
         token_ids, padding_mask = self.packer(
@@ -145,6 +156,12 @@ def generate_postprocess(
         self,
         x,
     ):
+        """Convert integer token output to strings for generation.
+
+        This method reverses `generate_preprocess()` by first removing all
+        padding and start/end tokens, and then converting the integer
+        sequence back to a string.
+        """
         token_ids, padding_mask = x["token_ids"], x["padding_mask"]
         # Strip any special tokens during detokenization (e.g. the start and
         # end markers). In the future we could make this configurable.

From 23ccb365c980a20610804083f94f352191e55e29 Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Wed, 3 May 2023 15:43:29 +0000
Subject: [PATCH 3/3] Fix merge conflict

---
 keras_nlp/models/gpt2/gpt2_causal_lm.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py
index 64c783a478..d60dceb8f4 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -536,4 +536,3 @@ def create_layout_map(cls, mesh):
         """
         # As this task has no new variables, we just re-use the backbone method.
         return cls.backbone_cls.create_layout_map(mesh)
-        return self._normalize_generate_outputs(outputs, input_is_scalar)
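A minimal end-to-end sketch of the granular workflow these patches enable,
mirroring the linked Colab. The preset name, prompt, and sequence length below
are illustrative assumptions, not part of the diff:

```python
import keras_nlp

# Illustrative preset; any GPT2CausalLM preset with an attached preprocessor works.
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
preprocessor = causal_lm.preprocessor

# 1. Preprocess: strings -> dense {"token_ids", "padding_mask"} tensors.
inputs = preprocessor.generate_preprocess(
    ["airplane at airport"], sequence_length=64
)

# 2. Generate: call the compiled dense function directly on the dict.
generate_function = causal_lm.make_generate_function()
outputs = generate_function(
    inputs, end_token_id=preprocessor.tokenizer.end_token_id
)

# 3. Postprocess: dense tensors -> detokenized strings.
print(preprocessor.generate_postprocess(outputs))
```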