From 2cc019092cde25b9c9bb6d6b770bf57d8ad7eaea Mon Sep 17 00:00:00 2001 From: Adel Benlagra Date: Mon, 13 Nov 2023 17:16:22 -0500 Subject: [PATCH] converted multi_label_classification example --- .../tensorflow/nlp/multi_label_classification.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/keras_io/tensorflow/nlp/multi_label_classification.py b/examples/keras_io/tensorflow/nlp/multi_label_classification.py index d6f43314dce..c4ac6014639 100644 --- a/examples/keras_io/tensorflow/nlp/multi_label_classification.py +++ b/examples/keras_io/tensorflow/nlp/multi_label_classification.py @@ -2,7 +2,7 @@ Title: Large-scale multi-label text classification Author: [Sayak Paul](https://twitter.com/RisingSayak), [Soumik Rakshit](https://github.com/soumik12345) Date created: 2020/09/25 -Last modified: 2020/12/23 +Last modified: 2023/11/13 (conversion to Keras 3 by @ben-ad) Description: Implementing a large-scale multi-label text classification model. Accelerator: GPU """ @@ -29,8 +29,8 @@ ## Imports """ -from tensorflow.keras import layers -from tensorflow import keras +from keras import layers +import keras import tensorflow as tf from sklearn.model_selection import train_test_split @@ -65,7 +65,7 @@ duplication. Here we notice that our initial dataset has got about 13k duplicate entries. 
""" -total_duplicate_titles = sum(arxiv_data["titles"].duplicated()) +total_duplicate_titles = arxiv_data["titles"].duplicated().sum() print(f"There are {total_duplicate_titles} duplicate titles.") """ @@ -144,7 +144,7 @@ """ terms = tf.ragged.constant(train_df["terms"].values) -lookup = tf.keras.layers.StringLookup(output_mode="multi_hot") +lookup = layers.StringLookup(output_mode="multi_hot") lookup.adapt(terms) vocab = lookup.get_vocabulary() @@ -371,7 +371,10 @@ def plot_result(item): An important feature of the [preprocessing layers provided by Keras](https://keras.io/guides/preprocessing_layers/) -is that they can be included inside a `tf.keras.Model`. We will export an inference model +is that they can be included inside a `keras.Model`. Note however that the [TextVectorization +layer](https://keras.io/keras_core/api/layers/preprocessing_layers/text/text_vectorization/) + uses TensorFlow internally. It cannot be used as part of a compiled computation graph +of a model with any backend other than TensorFlow. We will export an inference model by including the `text_vectorization` layer on top of `shallow_mlp_model`. This will allow our inference model to directly operate on raw strings.