diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index d6fbf26e13..1c4686db25 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -91,7 +91,7 @@ class GPT2CausalLM(Task): "I don't listen to music while coding.", "But I watch youtube while coding!", ] - ds = tf.data.Dataset.from_tensor_slices(features) + ds = tf.data.Dataset.from_tensor_slices(features).batch(2) # Create a `GPT2CausalLM` and fit your data. gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset( @@ -100,7 +100,7 @@ class GPT2CausalLM(Task): gpt2_lm.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), ) - gpt2_lm.fit(ds, batch_size=2) + gpt2_lm.fit(ds) ``` Load a pretrained `GPT2CausalLM` with custom preprocessor, and predict on