diff --git a/examples/nlp/img/multimodal_entailment/multimodal_entailment_14_0.png b/examples/nlp/img/multimodal_entailment/multimodal_entailment_14_0.png
index c8c0ae2d3b..57dd715cc5 100644
Binary files a/examples/nlp/img/multimodal_entailment/multimodal_entailment_14_0.png and b/examples/nlp/img/multimodal_entailment/multimodal_entailment_14_0.png differ
diff --git a/examples/nlp/img/multimodal_entailment/multimodal_entailment_14_2.png b/examples/nlp/img/multimodal_entailment/multimodal_entailment_14_2.png
index 4037a062f2..5ca596cf4a 100644
Binary files a/examples/nlp/img/multimodal_entailment/multimodal_entailment_14_2.png and b/examples/nlp/img/multimodal_entailment/multimodal_entailment_14_2.png differ
diff --git a/examples/nlp/ipynb/multimodal_entailment.ipynb b/examples/nlp/ipynb/multimodal_entailment.ipynb
index 06427b342a..124a3c81c7 100644
--- a/examples/nlp/ipynb/multimodal_entailment.ipynb
+++ b/examples/nlp/ipynb/multimodal_entailment.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Sayak Paul](https://twitter.com/RisingSayak)
\n",
"**Date created:** 2021/08/08
\n",
- "**Last modified:** 2021/08/15
\n",
+ "**Last modified:** 2025/01/03
\n",
"**Description:** Training a multimodal model for predicting entailment."
]
},
@@ -83,12 +83,18 @@
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import numpy as np\n",
+ "import random\n",
+ "import math\n",
+ "from skimage.io import imread\n",
+ "from skimage.transform import resize\n",
+ "from PIL import Image\n",
"import os\n",
"\n",
- "import tensorflow as tf\n",
- "import tensorflow_hub as hub\n",
- "import tensorflow_text as text\n",
- "from tensorflow import keras"
+ "os.environ[\"KERAS_BACKEND\"] = \"jax\" # or tensorflow, or torch\n",
+ "\n",
+ "import keras\n",
+ "import keras_hub\n",
+ "from keras.utils import PyDataset"
]
},
{
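One note on the import cell in this hunk: the `KERAS_BACKEND` environment variable is only read once, when `keras` is first imported, so it must be set beforehand. A minimal sketch of the pattern (not part of the notebook):

```python
import os

# Must be set before the first `import keras`; otherwise the default
# backend is used for the whole process.
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow", or "torch"

import keras

# Confirm which backend is active.
print(keras.backend.backend())  # -> "jax"
```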
@@ -164,7 +170,9 @@
"source": [
"df = pd.read_csv(\n",
" \"https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv\"\n",
- ")\n",
+ ").iloc[\n",
+ " 0:1000\n",
+ "] # Resources conservation since these are examples and not SOTA\n",
"df.sample(10)"
]
},
@@ -265,10 +273,10 @@
" print(f\"Label: {label}\")\n",
"\n",
"\n",
- "random_idx = np.random.choice(len(df))\n",
+ "random_idx = random.choice(range(len(df)))\n",
"visualize(random_idx)\n",
"\n",
- "random_idx = np.random.choice(len(df))\n",
+ "random_idx = random.choice(range(len(df)))\n",
"visualize(random_idx)"
]
},
@@ -335,42 +343,24 @@
"source": [
"## Data input pipeline\n",
"\n",
- "TensorFlow Hub provides\n",
- "[variety of BERT family of models](https://www.tensorflow.org/text/tutorials/bert_glue#loading_models_from_tensorflow_hub).\n",
+ "Keras Hub provides\n",
+ "[variety of BERT family of models](https://keras.io/keras_hub/presets/).\n",
"Each of those models comes with a\n",
"corresponding preprocessing layer. You can learn more about these models and their\n",
"preprocessing layers from\n",
- "[this resource](https://www.tensorflow.org/text/tutorials/bert_glue#loading_models_from_tensorflow_hub).\n",
+ "[this resource](https://www.kaggle.com/models/keras/bert/keras/bert_base_en_uncased/2).\n",
"\n",
- "To keep the runtime of this example relatively short, we will use a smaller variant of\n",
+ "To keep the runtime of this example relatively short, we will use a base_unacased variant of\n",
"the original BERT model."
]
},
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "# Define TF Hub paths to the BERT encoder and its preprocessor\n",
- "bert_model_path = (\n",
- " \"https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1\"\n",
- ")\n",
- "bert_preprocess_path = \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\""
- ]
- },
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
- "Our text preprocessing code mostly comes from\n",
- "[this tutorial](https://www.tensorflow.org/text/tutorials/bert_glue).\n",
- "You are highly encouraged to check out the tutorial to learn more about the input\n",
- "preprocessing."
+ "text preprocessing using KerasHub"
]
},
{
@@ -381,48 +371,10 @@
},
"outputs": [],
"source": [
- "\n",
- "def make_bert_preprocessing_model(sentence_features, seq_length=128):\n",
- " \"\"\"Returns Model mapping string features to BERT inputs.\n",
- "\n",
- " Args:\n",
- " sentence_features: A list with the names of string-valued features.\n",
- " seq_length: An integer that defines the sequence length of BERT inputs.\n",
- "\n",
- " Returns:\n",
- " A Keras Model that can be called on a list or dict of string Tensors\n",
- " (with the order or names, resp., given by sentence_features) and\n",
- " returns a dict of tensors for input to BERT.\n",
- " \"\"\"\n",
- "\n",
- " input_segments = [\n",
- " tf.keras.layers.Input(shape=(), dtype=tf.string, name=ft)\n",
- " for ft in sentence_features\n",
- " ]\n",
- "\n",
- " # Tokenize the text to word pieces.\n",
- " bert_preprocess = hub.load(bert_preprocess_path)\n",
- " tokenizer = hub.KerasLayer(bert_preprocess.tokenize, name=\"tokenizer\")\n",
- " segments = [tokenizer(s) for s in input_segments]\n",
- "\n",
- " # Optional: Trim segments in a smart way to fit seq_length.\n",
- " # Simple cases (like this example) can skip this step and let\n",
- " # the next step apply a default truncation to approximately equal lengths.\n",
- " truncated_segments = segments\n",
- "\n",
- " # Pack inputs. The details (start/end token ids, dict of output tensors)\n",
- " # are model-dependent, so this gets loaded from the SavedModel.\n",
- " packer = hub.KerasLayer(\n",
- " bert_preprocess.bert_pack_inputs,\n",
- " arguments=dict(seq_length=seq_length),\n",
- " name=\"packer\",\n",
- " )\n",
- " model_inputs = packer(truncated_segments)\n",
- " return keras.Model(input_segments, model_inputs)\n",
- "\n",
- "\n",
- "bert_preprocess_model = make_bert_preprocessing_model([\"text_1\", \"text_2\"])\n",
- "keras.utils.plot_model(bert_preprocess_model, show_shapes=True, show_dtype=True)"
+ "text_preprocessor = keras_hub.models.BertTextClassifierPreprocessor.from_preset(\n",
+ " \"bert_base_en_uncased\",\n",
+ " sequence_length=128,\n",
+ ")"
]
},
{
@@ -442,22 +394,22 @@
},
"outputs": [],
"source": [
- "idx = np.random.choice(len(train_df))\n",
+ "idx = random.choice(range(len(train_df)))\n",
"row = train_df.iloc[idx]\n",
"sample_text_1, sample_text_2 = row[\"text_1\"], row[\"text_2\"]\n",
"print(f\"Text 1: {sample_text_1}\")\n",
"print(f\"Text 2: {sample_text_2}\")\n",
"\n",
- "test_text = [np.array([sample_text_1]), np.array([sample_text_2])]\n",
- "text_preprocessed = bert_preprocess_model(test_text)\n",
+ "test_text = [sample_text_1, sample_text_2]\n",
+ "text_preprocessed = text_preprocessor(test_text)\n",
"\n",
"print(\"Keys : \", list(text_preprocessed.keys()))\n",
- "print(\"Shape Word Ids : \", text_preprocessed[\"input_word_ids\"].shape)\n",
- "print(\"Word Ids : \", text_preprocessed[\"input_word_ids\"][0, :16])\n",
- "print(\"Shape Mask : \", text_preprocessed[\"input_mask\"].shape)\n",
- "print(\"Input Mask : \", text_preprocessed[\"input_mask\"][0, :16])\n",
- "print(\"Shape Type Ids : \", text_preprocessed[\"input_type_ids\"].shape)\n",
- "print(\"Type Ids : \", text_preprocessed[\"input_type_ids\"][0, :16])\n",
+ "print(\"Shape Token Ids : \", text_preprocessed[\"token_ids\"].shape)\n",
+ "print(\"Token Ids : \", text_preprocessed[\"token_ids\"][0, :16])\n",
+ "print(\" Shape Padding Mask : \", text_preprocessed[\"padding_mask\"].shape)\n",
+ "print(\"Padding Mask : \", text_preprocessed[\"padding_mask\"][0, :16])\n",
+ "print(\"Shape Segment Ids : \", text_preprocessed[\"segment_ids\"].shape)\n",
+ "print(\"Segment Ids : \", text_preprocessed[\"segment_ids\"][0, :16])\n",
""
]
},
@@ -488,10 +440,11 @@
"\n",
"def dataframe_to_dataset(dataframe):\n",
" columns = [\"image_1_path\", \"image_2_path\", \"text_1\", \"text_2\", \"label_idx\"]\n",
- " dataframe = dataframe[columns].copy()\n",
- " labels = dataframe.pop(\"label_idx\")\n",
- " ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))\n",
- " ds = ds.shuffle(buffer_size=len(dataframe))\n",
+ " ds = UnifiedPyDataset(\n",
+ " dataframe,\n",
+ " batch_size=32,\n",
+ " workers=4,\n",
+ " )\n",
" return ds\n",
""
]
@@ -513,35 +466,16 @@
},
"outputs": [],
"source": [
- "resize = (128, 128)\n",
- "bert_input_features = [\"input_word_ids\", \"input_type_ids\", \"input_mask\"]\n",
- "\n",
- "\n",
- "def preprocess_image(image_path):\n",
- " extension = tf.strings.split(image_path)[-1]\n",
- "\n",
- " image = tf.io.read_file(image_path)\n",
- " if extension == b\"jpg\":\n",
- " image = tf.image.decode_jpeg(image, 3)\n",
- " else:\n",
- " image = tf.image.decode_png(image, 3)\n",
- " image = tf.image.resize(image, resize)\n",
- " return image\n",
+ "bert_input_features = [\"padding_mask\", \"segment_ids\", \"token_ids\"]\n",
"\n",
"\n",
"def preprocess_text(text_1, text_2):\n",
- " text_1 = tf.convert_to_tensor([text_1])\n",
- " text_2 = tf.convert_to_tensor([text_2])\n",
- " output = bert_preprocess_model([text_1, text_2])\n",
- " output = {feature: tf.squeeze(output[feature]) for feature in bert_input_features}\n",
+ " output = text_preprocessor([text_1, text_2])\n",
+ " output = {\n",
+ " feature: keras.ops.reshape(output[feature], [-1])\n",
+ " for feature in bert_input_features\n",
+ " }\n",
" return output\n",
- "\n",
- "\n",
- "def preprocess_text_and_image(sample):\n",
- " image_1 = preprocess_image(sample[\"image_1_path\"])\n",
- " image_2 = preprocess_image(sample[\"image_2_path\"])\n",
- " text = preprocess_text(sample[\"text_1\"], sample[\"text_2\"])\n",
- " return {\"image_1\": image_1, \"image_2\": image_2, \"text\": text}\n",
""
]
},
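A sanity-check sketch (not in the notebook) of what `preprocess_text` produces. Assuming the preprocessor maps a two-string list to a batch of two 128-token sequences, the `reshape` above flattens each `(2, 128)` feature into a single `(256,)` vector, which is why the model inputs later declare `shape=(256,)`:

```python
# Hypothetical inspection of one tweet pair.
features = preprocess_text("first tweet", "second tweet")
for name, value in features.items():
    # Expect each feature flattened from (2, 128) to (256,).
    print(name, value.shape)
```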
@@ -551,7 +485,7 @@
"colab_type": "text"
},
"source": [
- "### Create the final datasets"
+ "### Create the final datasets, method adapted from PyDataset doc string."
]
},
{
@@ -562,23 +496,168 @@
},
"outputs": [],
"source": [
- "batch_size = 32\n",
- "auto = tf.data.AUTOTUNE\n",
"\n",
+ "class UnifiedPyDataset(PyDataset):\n",
+ " \"\"\"A Keras-compatible dataset that processes a DataFrame for TensorFlow, JAX, and PyTorch.\"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " df,\n",
+ " batch_size=32,\n",
+ " workers=4,\n",
+ " use_multiprocessing=False,\n",
+ " max_queue_size=10,\n",
+ " **kwargs,\n",
+ " ):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " df: pandas DataFrame with data\n",
+ " batch_size: Batch size for dataset\n",
+ " workers: Number of workers to use for parallel loading (Keras)\n",
+ " use_multiprocessing: Whether to use multiprocessing\n",
+ " max_queue_size: Maximum size of the data queue for parallel loading\n",
+ " \"\"\"\n",
+ " super().__init__(**kwargs)\n",
+ " self.dataframe = df\n",
+ " columns = [\"image_1_path\", \"image_2_path\", \"text_1\", \"text_2\"]\n",
+ "\n",
+ " # image files\n",
+ " self.image_x_1 = self.dataframe[\"image_1_path\"]\n",
+ " self.image_x_2 = self.dataframe[\"image_1_path\"]\n",
+ " self.image_y = self.dataframe[\"label_idx\"]\n",
+ "\n",
+ " # text files\n",
+ " self.text_x_1 = self.dataframe[\"text_1\"]\n",
+ " self.text_x_2 = self.dataframe[\"text_2\"]\n",
+ " self.text_y = self.dataframe[\"label_idx\"]\n",
+ "\n",
+ " # general\n",
+ " self.batch_size = batch_size\n",
+ " self.workers = workers\n",
+ " self.use_multiprocessing = use_multiprocessing\n",
+ " self.max_queue_size = max_queue_size\n",
+ "\n",
+ " def __getitem__(self, index):\n",
+ " \"\"\"\n",
+ " Fetches a batch of data from the dataset at the given index.\n",
+ " \"\"\"\n",
+ "\n",
+ " # Return x, y for batch idx.\n",
+ " low = index * self.batch_size\n",
+ " # Cap upper bound at array length; the last batch may be smaller\n",
+ " # if the total number of items is not a multiple of batch size.\n",
+ "\n",
+ " high_image_1 = min(low + self.batch_size, len(self.image_x_1))\n",
+ " high_image_2 = min(low + self.batch_size, len(self.image_x_2))\n",
+ "\n",
+ " high_text_1 = min(low + self.batch_size, len(self.text_x_1))\n",
+ " high_text_2 = min(low + self.batch_size, len(self.text_x_1))\n",
+ "\n",
+ " # images files\n",
+ " batch_image_x_1 = self.image_x_1[low:high_image_1]\n",
+ " batch_image_y_1 = self.image_y[low:high_image_1]\n",
+ "\n",
+ " batch_image_x_2 = self.image_x_2[low:high_image_2]\n",
+ " batch_image_y_2 = self.image_y[low:high_image_2]\n",
+ "\n",
+ " # text files\n",
+ " batch_text_x_1 = self.text_x_1[low:high_text_1]\n",
+ " batch_text_y_1 = self.text_y[low:high_text_1]\n",
+ "\n",
+ " batch_text_x_2 = self.text_x_2[low:high_text_2]\n",
+ " batch_text_y_2 = self.text_y[low:high_text_2]\n",
+ "\n",
+ " # image number 1 inputs\n",
+ " image_1 = [\n",
+ " resize(imread(file_name), (128, 128)) for file_name in batch_image_x_1\n",
+ " ]\n",
+ " image_1 = [\n",
+ " ( # exeperienced some shapes which were different from others.\n",
+ " np.array(Image.fromarray((img.astype(np.uint8))).convert(\"RGB\"))\n",
+ " if img.shape[2] == 4\n",
+ " else img\n",
+ " )\n",
+ " for img in image_1\n",
+ " ]\n",
+ " image_1 = np.array(image_1)\n",
+ "\n",
+ " # Both text inputs to the model, return a dict for inputs to BertBackbone\n",
+ " text = {\n",
+ " key: np.array(\n",
+ " [\n",
+ " d[key]\n",
+ " for d in [\n",
+ " preprocess_text(file_path1, file_path2)\n",
+ " for file_path1, file_path2 in zip(\n",
+ " batch_text_x_1, batch_text_x_2\n",
+ " )\n",
+ " ]\n",
+ " ]\n",
+ " )\n",
+ " for key in [\"padding_mask\", \"token_ids\", \"segment_ids\"]\n",
+ " }\n",
+ "\n",
+ " # Image number 2 model inputs\n",
+ " image_2 = [\n",
+ " resize(imread(file_name), (128, 128)) for file_name in batch_image_x_2\n",
+ " ]\n",
+ " image_2 = [\n",
+ " ( # exeperienced some shapes which were different from others\n",
+ " np.array(Image.fromarray((img.astype(np.uint8))).convert(\"RGB\"))\n",
+ " if img.shape[2] == 4\n",
+ " else img\n",
+ " )\n",
+ " for img in image_2\n",
+ " ]\n",
+ " # Stack the list comprehension to an nd.array\n",
+ " image_2 = np.array(image_2)\n",
+ "\n",
+ " return (\n",
+ " {\n",
+ " \"image_1\": image_1,\n",
+ " \"image_2\": image_2,\n",
+ " \"padding_mask\": text[\"padding_mask\"],\n",
+ " \"segment_ids\": text[\"segment_ids\"],\n",
+ " \"token_ids\": text[\"token_ids\"],\n",
+ " },\n",
+ " # Target lables\n",
+ " np.array(batch_image_y_1),\n",
+ " )\n",
+ "\n",
+ " def __len__(self):\n",
+ " \"\"\"\n",
+ " Returns the number of batches in the dataset.\n",
+ " \"\"\"\n",
+ " return math.ceil(len(self.dataframe) / self.batch_size)\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Create train, validation and test datasets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
"\n",
- "def prepare_dataset(dataframe, training=True):\n",
+ "def prepare_dataset(dataframe):\n",
" ds = dataframe_to_dataset(dataframe)\n",
- " if training:\n",
- " ds = ds.shuffle(len(train_df))\n",
- " ds = ds.map(lambda x, y: (preprocess_text_and_image(x), y)).cache()\n",
- " ds = ds.batch(batch_size).prefetch(auto)\n",
" return ds\n",
"\n",
"\n",
"train_ds = prepare_dataset(train_df)\n",
- "validation_ds = prepare_dataset(val_df, False)\n",
- "test_ds = prepare_dataset(test_df, False)\n",
- ""
+ "validation_ds = prepare_dataset(val_df)\n",
+ "test_ds = prepare_dataset(test_df)"
]
},
{
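The `UnifiedPyDataset` class in this hunk follows the `keras.utils.PyDataset` contract: `__len__` returns the number of batches per epoch and `__getitem__(index)` returns one batch. A minimal self-contained sketch of that contract (`ArrayDataset` is a hypothetical name, not part of the example):

```python
import math

import numpy as np
from keras.utils import PyDataset


class ArrayDataset(PyDataset):
    """Minimal PyDataset serving (x, y) batches from two arrays."""

    def __init__(self, x, y, batch_size=32, **kwargs):
        super().__init__(**kwargs)
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        # Number of batches per epoch; the last one may be short.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, index):
        low = index * self.batch_size
        high = min(low + self.batch_size, len(self.x))
        return self.x[low:high], self.y[low:high]


ds = ArrayDataset(np.zeros((100, 8)), np.zeros((100,)), batch_size=32)
print(len(ds))  # 4
print(ds[3][0].shape)  # (4, 8): the last batch is smaller
```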
@@ -639,7 +718,7 @@
"):\n",
" projected_embeddings = keras.layers.Dense(units=projection_dims)(embeddings)\n",
" for _ in range(num_projection_layers):\n",
- " x = tf.nn.gelu(projected_embeddings)\n",
+ " x = keras.ops.nn.gelu(projected_embeddings)\n",
" x = keras.layers.Dense(projection_dims)(x)\n",
" x = keras.layers.Dropout(dropout_rate)(x)\n",
" x = keras.layers.Add()([projected_embeddings, x])\n",
@@ -721,15 +800,18 @@
"def create_text_encoder(\n",
" num_projection_layers, projection_dims, dropout_rate, trainable=False\n",
"):\n",
- " # Load the pre-trained BERT model to be used as the base encoder.\n",
- " bert = hub.KerasLayer(bert_model_path, name=\"bert\",)\n",
+ " # Load the pre-trained BERT BackBone using KerasHub.\n",
+ " bert = keras_hub.models.BertBackbone.from_preset(\n",
+ " \"bert_base_en_uncased\", num_classes=3\n",
+ " )\n",
+ "\n",
" # Set the trainability of the base encoder.\n",
" bert.trainable = trainable\n",
"\n",
" # Receive the text as inputs.\n",
- " bert_input_features = [\"input_type_ids\", \"input_mask\", \"input_word_ids\"]\n",
+ " bert_input_features = [\"padding_mask\", \"segment_ids\", \"token_ids\"]\n",
" inputs = {\n",
- " feature: keras.Input(shape=(128,), dtype=tf.int32, name=feature)\n",
+ " feature: keras.Input(shape=(256,), dtype=\"int32\", name=feature)\n",
" for feature in bert_input_features\n",
" }\n",
"\n",
@@ -775,12 +857,12 @@
" image_2 = keras.Input(shape=(128, 128, 3), name=\"image_2\")\n",
"\n",
" # Receive the text as inputs.\n",
- " bert_input_features = [\"input_type_ids\", \"input_mask\", \"input_word_ids\"]\n",
+ " bert_input_features = [\"padding_mask\", \"segment_ids\", \"token_ids\"]\n",
" text_inputs = {\n",
- " feature: keras.Input(shape=(128,), dtype=tf.int32, name=feature)\n",
+ " feature: keras.Input(shape=(256,), dtype=\"int32\", name=feature)\n",
" for feature in bert_input_features\n",
" }\n",
- "\n",
+ " text_inputs = list(text_inputs.values())\n",
" # Create the encoders.\n",
" vision_encoder = create_vision_encoder(\n",
" num_projection_layers, projection_dims, dropout_rate, vision_trainable\n",
@@ -796,7 +878,7 @@
" # Concatenate the projections and pass through the classification layer.\n",
" concatenated = keras.layers.Concatenate()([vision_projections, text_projections])\n",
" outputs = keras.layers.Dense(3, activation=\"softmax\")(concatenated)\n",
- " return keras.Model([image_1, image_2, text_inputs], outputs)\n",
+ " return keras.Model([image_1, image_2, *text_inputs], outputs)\n",
"\n",
"\n",
"multimodal_model = create_multimodal_model()\n",
@@ -833,10 +915,10 @@
"outputs": [],
"source": [
"multimodal_model.compile(\n",
- " optimizer=\"adam\", loss=\"sparse_categorical_crossentropy\", metrics=\"accuracy\"\n",
+ " optimizer=\"adam\", loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"]\n",
")\n",
"\n",
- "history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=10)"
+ "history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=1)"
]
},
{
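After training, the held-out split can be scored through the same `PyDataset` pipeline. A sketch (with `metrics=["accuracy"]`, `evaluate` returns the loss followed by the accuracy):

```python
# Hypothetical evaluation on the test split built earlier.
loss, accuracy = multimodal_model.evaluate(test_ds)
print(f"Test accuracy: {accuracy:.3f}")
```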
@@ -960,7 +1042,7 @@
"[Recognizing Multimodal Entailment](https://multimodal-entailment.github.io/)\n",
"tutorial provides a comprehensive overview.\n",
"\n",
- "You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/multimodal-entailment) ",
+ "You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/multimodal-entailment)\n",
"and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/multimodal_entailment)"
]
}
diff --git a/examples/nlp/md/multimodal_entailment.md b/examples/nlp/md/multimodal_entailment.md
index 05304531b2..6ff8bca22e 100644
--- a/examples/nlp/md/multimodal_entailment.md
+++ b/examples/nlp/md/multimodal_entailment.md
@@ -2,7 +2,7 @@
**Author:** [Sayak Paul](https://twitter.com/RisingSayak)
**Date created:** 2021/08/08
-**Last modified:** 2021/08/15
+**Last modified:** 2025/01/03
**Description:** Training a multimodal model for predicting entailment.
@@ -46,6 +46,14 @@ using the following command:
!pip install -q tensorflow_text
```
+
+