<div style="margin-bottom: 1em;">
  <a href="https://colab.research.google.com/github/google-deepmind/jax_privacy/blob/main/examples/dp_sgd_keras_gemma3_lora_finetuning_samsum.ipynb" target="_blank">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" />
  </a>
  <a href="https://github.com/google-deepmind/jax_privacy/blob/main/examples/dp_sgd_keras_gemma3_lora_finetuning_samsum.ipynb" target="_blank" style="margin-left: 10px;">
    <img src="https://img.shields.io/badge/GitHub-view--source-black?logo=github" />
  </a>
</div>

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tutorial: DP-SGD LoRA fine-tuning of Gemma3 in Keras on the SAMSum dataset

**Copyright 2025 DeepMind Technologies Limited.**

Welcome to JAX Privacy for Keras! In this tutorial you will learn how to fine-tune the [Gemma3 LLM](https://www.kaggle.com/models/keras/gemma3) with LoRA in a differentially private (DP) way using the [DP-SGD algorithm](https://medium.com/pytorch/differential-privacy-series-part-1-dp-sgd-algorithm-explained-12512c3959a3). We will fine-tune the model on the [SAMSum dataset](https://huggingface.co/datasets/Samsung/samsum).

To perform the full fine-tuning and reproduce the results, we recommend using A100 GPUs, ideally several (e.g. 8 or 16), to speed up training. You can obtain them in [Google Colab](https://colab.research.google.com/) or in [Google Cloud Vertex AI](https://cloud.google.com/vertex-ai). For Gemma3 4B, the model we fine-tune in this tutorial, you need at least [16GB of GPU memory](https://ai.google.dev/gemma/docs/core#sizes). If you don't have that much, you can use the smaller Gemma3 1B model or enable mixed precision (see below).

The following links might be helpful as complementary material:

* [Gemma3 Overview](https://ai.google.dev/gemma/docs/core#sizes): Good introduction explaining what Gemma3 is.
* [Fine-tune Gemma in Keras using LoRA](https://ai.google.dev/gemma/docs/core/lora_tuning): Gemma3 fine-tuning without DP and on a different dataset. Our notebook is very similar to this one.
* [KerasHub: Get started with Gemma 3](https://www.kaggle.com/code/abheesht75/kerashub-get-started-with-gemma-3): KerasHub tutorial on how to make predictions with the Gemma3 model (including images).
* [Distributed tuning with Gemma using Keras](https://ai.google.dev/gemma/docs/core/distributed_tuning): Gemma3 fine-tuning with model distribution, useful if you want to fine-tune the Gemma3 12B or 27B versions. In our example we only distribute the data; the model is not distributed.

## Install and import dependencies

In [None]:
%%capture

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"
!pip uninstall -y -q keras-hub
!pip install -q -U keras-hub
!pip install rouge-score
!pip install tqdm
!pip install ipywidgets

!pip install dp_accounting jaxtyping drjax
!pip install jax_privacy==1.0.0

In [None]:
import os

os.environ["KERAS_HOME"] = os.getcwd()  # Make Keras use the current working directory, which has enough space.
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"  # Avoid memory fragmentation on JAX backend.

import keras
import keras_hub
import tensorflow as tf
import tensorflow_datasets as tfds
import tqdm

# Jax Privacy deps
from jax_privacy.keras import keras_api

## Login to Kaggle

Logging in to Kaggle is necessary to download the Gemma3 model. You might also have to accept the model's terms of use on Kaggle.

See more information [here](https://ai.google.dev/gemma/docs/core/distributed_tuning#kaggle_credentials).

In [None]:
import kagglehub

kagglehub.login()

# If you are using Colab, you can alternatively set KAGGLE_USERNAME and KAGGLE_KEY
# values in user data, and then uncomment and run the following code:
#
# from google.colab import userdata
#
# os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
# os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
#
# You use userdata to keep the Kaggle API key safe. Alternatively, you can
# hardcode the values but it is not recommended due to security risks of
# leaking the API key.


# If you're not using Colab, set the env vars as appropriate for your system.
# For example, to set the env vars on Linux you can run in terminal:
# ```
# export KAGGLE_USERNAME="your_username"
# export KAGGLE_KEY="your_key"
# ```

## Hyper-parameters setup

We are going to fine-tune the Gemma3 model with 4 billion parameters using full 32-bit weights. This model takes 16GB and fits into an A100 GPU with 40GB of memory. [Here](https://ai.google.dev/gemma/docs/core#sizes) you can find other model options and their memory requirements. For example, you can fit the 16-bit Gemma3 12B into a 40GB A100 GPU.

The maximum sequence length is 512 tokens. This is enough for the SAMSum dataset: about 95% of examples are shorter than that. Making the length too large slows down inference.

The batch size should be a multiple of the number of GPUs you are going to use. This is because we use data parallelism: the examples in each batch are evenly distributed over all available GPUs and processed in parallel. We don't use model parallelism (i.e. splitting the model weights across multiple GPUs); however, if you are going to use a larger model (e.g. 27B), you will likely need to distribute it.

Gradient accumulation is important for DP training: DP works better with a larger effective batch size. The effective batch size is the total number of examples the model processes before the optimizer makes a step; it is equal to `gradient_accumulation_steps * batch_size`. We can't make the physical batch size that large because it won't fit into memory. You can make the test batch size bigger (usually twice as large) to speed up evaluation.
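
As a quick sanity check of the two constraints above, here is a minimal sketch in plain Python. The batch size and accumulation steps mirror the setup cell of this notebook; the device count of 8 and both helper functions are hypothetical and only serve as an illustration.

```python
def per_device_batch_size(batch_size: int, num_devices: int) -> int:
    """Returns how many examples of one batch land on each device."""
    if batch_size % num_devices != 0:
        raise ValueError(
            f"batch_size={batch_size} is not a multiple of num_devices={num_devices}"
        )
    return batch_size // num_devices


def effective_batch_size(batch_size: int, accumulation_steps: int) -> int:
    """Returns the number of examples processed per optimizer step."""
    return batch_size * accumulation_steps


# 16 and 64 come from the setup cell; 8 devices is an assumed example value.
print(per_device_batch_size(batch_size=16, num_devices=8))
print(effective_batch_size(batch_size=16, accumulation_steps=64))
```

With these values each device receives 2 examples per batch and the optimizer makes one step per 1024 examples.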

You can set `USE_MIXED_PRECISION` to True if your model does not fit into GPU memory. This will be the case, for example, for an A100 40GB GPU and the Gemma3 12B model with 32-bit weights. Note that the results in this notebook were obtained without mixed precision, and the given hyper-parameters might not work well with mixed precision training, so you might need to adjust them.

By switching `USE_DP` on or off you can play around and compare DP and non-DP fine-tuning.

Epsilon, delta and clipping norm are the main DP parameters. You can read about their meaning in DP literature, e.g. [here](https://medium.com/pytorch/differential-privacy-series-part-1-dp-sgd-algorithm-explained-12512c3959a3).
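
The setup cell of this notebook picks delta below `1/n^1.1`, where `n` is the number of training examples. The sketch below only reproduces that arithmetic; the helper name `delta_upper_bound` is hypothetical, and the rule of thumb itself is taken from the comment in the setup cell, not from a library.

```python
def delta_upper_bound(num_examples: int, exponent: float = 1.1) -> float:
    """Returns 1 / n^exponent, the bound that delta should stay below."""
    return 1.0 / num_examples**exponent


num_train_examples = 14732  # SAMSum training set size, as stated in the setup cell.
delta = 2e-5  # The DELTA value used in this notebook.

bound = delta_upper_bound(num_train_examples)
print(f"1/n^1.1 = {bound:.2e}, delta = {delta:.2e}")
print("delta is below the bound:", delta < bound)
```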

Set `TEST_RUN` to True if you don't have a high-performance GPU right now but still want to run the whole notebook to check that it works for you and play with it a little bit. However, you won't be able to fully fine-tune the model.

In [None]:
GEMMA3_MODEL_TYPE = "gemma3_instruct_4b_text"
SEQUENCE_LENGTH = 512
TEST_DS_SEQUENCE_LENGTH = 512
EPOCHS = 3
BATCH_SIZE = 16 # Should be multiple of the number of GPUs you have.
GRADIANT_ACCUMULATION_STEPS = 64 # i.e. effective batch size is 16 * 64 = 1024
TEST_DS_BATCH_SIZE = 16
LORA_RANK = 32
LEARNING_RATE = 0.003
SEED = 0

# Use bfloat16 (i.e. 16-bit float) weights or not. Not all GPUs support bfloat16 (e.g. V100 does not support it, A100 does).
USE_MIXED_PRECISION = False

USE_DP = True
# DP-SGD parameters.
EPSILON = 4.0
DELTA = 2e-5  # chosen as a value smaller than 1/n^1.1 ~ 2.6e-5 where n = 14732 is number of examples in the training set.
CLIPPING_NORM = 0.001

# TEST_RUN executes on small data and small model, useful just to check that
# the code runs successfully in your environment.
TEST_RUN = False
if TEST_RUN:
  GEMMA3_MODEL_TYPE = "gemma3_instruct_1b"
  SEQUENCE_LENGTH = 256
  TEST_DS_SEQUENCE_LENGTH = 256
  MAX_TRAIN_SIZE = 3000
  LORA_RANK = 4
  USE_MIXED_PRECISION = False # 1b model is small and likely to fit into any GPU.

## Training data

### Download the train and validation datasets

Each example in the SAMSum dataset is a triple: example id, dialogue and its summary.

In [None]:
%%capture

manual_dir = '/root/tensorflow_datasets/downloads/manual'
os.makedirs(manual_dir, exist_ok=True)

!wget -O corpus.7z https://arxiv.org/src/1911.12237v2/anc/corpus.7z
!sudo apt-get install -y p7zip-full
!7z x corpus.7z -ocorpus_extracted -y
!mv corpus_extracted/* /root/tensorflow_datasets/downloads/manual

In [None]:
SOURCE_TRAIN_DS, SOURCE_VALIDATION_DS = tfds.load('samsum', split=['train', 'validation'])

if TEST_RUN:
  SOURCE_TRAIN_DS = SOURCE_TRAIN_DS.take(MAX_TRAIN_SIZE)
  SOURCE_VALIDATION_DS = SOURCE_VALIDATION_DS.take(MAX_TRAIN_SIZE)

Let's take a look at an entry in validation dataset.

In [None]:
SOURCE_EXAMPLE_DS = SOURCE_VALIDATION_DS.take(1).batch(1, drop_remainder=True)
SOURCE_EXAMPLE = SOURCE_EXAMPLE_DS.as_numpy_iterator().next()
for key, val in SOURCE_EXAMPLE.items():
  decoded_val = val[0].decode('utf-8')
  print(f'{key}:\n"{decoded_val}"\n')

### Pre-process the data to the expected format

Gemma3 expects the training examples in the format `{"prompts": list[str], "responses": list[str]}`, where `prompts[i]` and `responses[i]` are the i-th prompt and its expected response.

We construct the prompt for each example in the following way:
```
Summarize the following dialogue:
{dialogue}
Summary:
```


We add a prefix and a suffix to the dialogue to make the prompt more self-explanatory.

The response is just the summary, without any prefix or suffix.

With this input format, the model is trained as follows: given a prompt in the format above, it has to generate the following text:
```
Summarize the following dialogue:
{dialogue}
Summary:
{summary}
```

In Python terms, it has to generate `prompt + summary`.
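
The same construction can be sketched with plain Python strings, which is convenient for checking the format outside of a `tf.data` pipeline. The helper names and the sample dialogue below are hypothetical; the prefix and suffix match the prompt template above.

```python
PROMPT_PREFIX = "Summarize the following dialogue:\n"
PROMPT_SUFFIX = "\nSummary:\n"


def build_prompt(dialogue: str) -> str:
    """Wraps a dialogue into the prompt template used in this notebook."""
    return PROMPT_PREFIX + dialogue + PROMPT_SUFFIX


def build_training_text(dialogue: str, summary: str) -> str:
    """Returns the full text the model learns to generate: prompt + summary."""
    return build_prompt(dialogue) + summary


# Made-up example, not taken from SAMSum.
sample_dialogue = "Ann: Are you coming tonight?\nBob: Yes, at 8."
sample_summary = "Bob will come tonight at 8."
print(build_training_text(sample_dialogue, sample_summary))
```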

In [None]:
def source_to_gemma3_format(dialogue_dict):
  return {
      "prompts": tf.strings.join(["Summarize the following dialogue:\n", dialogue_dict["dialogue"], "\nSummary:\n"]),
      "responses": dialogue_dict["summary"]
  }

In [None]:
TRAIN_DS = SOURCE_TRAIN_DS.map(source_to_gemma3_format)
VALIDATION_DS = SOURCE_VALIDATION_DS.map(source_to_gemma3_format)

Let's take a look at the input to our model.

In [None]:
EXAMPLE_DS = VALIDATION_DS.take(1).batch(1, drop_remainder=True)
EXAMPLE = EXAMPLE_DS.as_numpy_iterator().next()
for key, val in EXAMPLE.items():
  decoded_val = val[0].decode('utf-8')
  print(f'{key}:\n"{decoded_val}"\n')

### Determine training size

We need to determine the training set size because it directly impacts the total number of optimization steps. The number of optimization steps is precisely determined by the interplay of the training set size, the number of epochs, the batch size, and the gradient accumulation steps.

Knowing the exact number of optimization steps beforehand is essential for calibrated noise generation in DP-SGD. During each optimization step, noise is added to ensure a specific protection level (defined by epsilon and delta). To accurately calibrate this noise for the desired privacy guarantees, we must know precisely how many times the noise will be generated throughout the training process.
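
To make the interplay concrete, the sketch below computes the step counts for the values used in this notebook. The helper `count_training_steps` is hypothetical. It assumes that incomplete batches are dropped and that one optimizer step happens after every full group of accumulated batches; how JAX Privacy rounds internally is not shown here.

```python
def count_training_steps(
    train_size: int, batch_size: int, epochs: int, accumulation_steps: int
) -> tuple[int, int]:
    """Returns (number of processed batches, number of optimizer steps)."""
    batches_per_epoch = train_size // batch_size  # Incomplete batches are dropped.
    total_batches = epochs * batches_per_epoch
    optimizer_steps = total_batches // accumulation_steps  # Assumed rounding.
    return total_batches, optimizer_steps


total_batches, optimizer_steps = count_training_steps(
    train_size=14732, batch_size=16, epochs=3, accumulation_steps=64
)
print(f"{total_batches=} {optimizer_steps=}")
```

With 14732 training examples this gives 2760 batches and, under the stated assumption, 43 optimizer steps.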

In [None]:
# Train size is important for DP-SGD.
TRAIN_SIZE = int(TRAIN_DS.cardinality().numpy())
print(f'Train size: {TRAIN_SIZE}')
VALIDATION_SIZE = int(VALIDATION_DS.cardinality().numpy())
print(f'Validation size: {VALIDATION_SIZE}')

TRAIN_DS = TRAIN_DS.shuffle(buffer_size=2048).batch(BATCH_SIZE, drop_remainder=True)
VALIDATION_DS = VALIDATION_DS.batch(BATCH_SIZE, drop_remainder=True)

### Setup data parallelism

Parallelize training across all available devices by splitting the data along the batch dimension.

In [None]:
DATA_PARALLEL = keras.distribution.DataParallel()
# You can see over how many GPUs the data will be distributed.
print(DATA_PARALLEL)
keras.distribution.set_distribution(DATA_PARALLEL)

## Gemma3 model setup

### Load the model

If `USE_MIXED_PRECISION` is True, the model will be loaded with 16-bit weights.

It is important that Gemma3 preprocessor is `Gemma3CausalLMPreprocessor` because it does masking and padding properly (assuming the input is in the `{prompts, responses}` format).

In [None]:
MODEL_WEIGHTS_DTYPE = None # use default dtype
if USE_MIXED_PRECISION:
  print("Using mixed precision")
  keras.mixed_precision.set_global_policy('mixed_bfloat16')
  MODEL_WEIGHTS_DTYPE = "bfloat16"

gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(GEMMA3_MODEL_TYPE,
                                                       dtype=MODEL_WEIGHTS_DTYPE)

assert isinstance(gemma_lm.preprocessor, keras_hub.models.Gemma3CausalLMPreprocessor)
gemma_lm.preprocessor.sequence_length = SEQUENCE_LENGTH
gemma_lm.summary()

### Validation example inference before fine-tuning

Let's see what the model outputs before we fine-tune it.

In [None]:
def make_validation_example_inference():
  return gemma_lm.generate(EXAMPLE['prompts'])[0]

def show_validation_example_inference():
  print(make_validation_example_inference())
  print(f"\nCorrect summary:\n{EXAMPLE['responses'][0].decode('utf-8')}")

In [None]:
show_validation_example_inference()

Not too bad, but we can see artifacts that fine-tuning can help us get rid of: for example, the model always outputs `<end_of_turn>`.

### Enable LoRA

Notice how the number of trainable params decreases significantly. This happens because the LoRA rank defines the size of the trainable matrices, and our rank is quite small. This is how LoRA makes fine-tuning of LLMs a reality! You can learn more about LoRA [here](https://ai.google.dev/gemma/docs/core/lora_tuning#configure_lora_tuning).
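
To get a feeling for the savings, recall that LoRA replaces the update of a `d_in x d_out` weight matrix with two low-rank factors of shapes `d_in x rank` and `rank x d_out`. The sketch below compares the parameter counts for a single matrix; the 2048 x 2048 dimensions are hypothetical and are not the actual Gemma3 layer sizes.

```python
def full_matrix_params(d_in: int, d_out: int) -> int:
    """Number of parameters in a dense d_in x d_out weight matrix."""
    return d_in * d_out


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Number of trainable parameters in the two LoRA factors A and B."""
    return rank * (d_in + d_out)


d_in, d_out, rank = 2048, 2048, 32  # Hypothetical layer size; rank as in LORA_RANK.
full = full_matrix_params(d_in, d_out)
lora = lora_params(d_in, d_out, rank)
print(f"full: {full}, LoRA: {lora}, ratio: {lora / full:.4f}")
```

For this assumed layer, the LoRA factors hold about 3% of the parameters of the full matrix.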

In [None]:
gemma_lm.backbone.enable_lora(rank=LORA_RANK)
gemma_lm.summary()

### Enable DP

The most interesting part of this notebook!

If you want to do DP-SGD training, you have to create a `DPKerasConfig` that provides JAX Privacy with the essential parameters for DP training. There is nothing surprising in the config values: we've already defined everything in the preceding cells. Notice, for example, how we use the `TRAIN_SIZE` we determined earlier to calculate the number of train steps. Also note that `gradient_accumulation_steps` is provided as a separate parameter. JAX Privacy takes it into account to calculate the real number of optimization steps during training.

JAX Privacy will throw an exception if we exceed the specified number of train steps. If we perform fewer train steps than specified, we will not consume the whole (eps, delta)-DP budget, which means that more noise was added than necessary. Therefore always keep these parameters tight, and rerun the training if you realize that you don't need that many training steps.

Once the params config is created, we have to call `make_private`, providing the model and the params. It returns an updated model whose further training, for the pre-defined number of training steps, will be differentially private. You don't have to do anything more.

During these calls the `noise multiplier` is calculated. It controls the standard deviation of the Gaussian noise that is added to the accumulated gradient at each optimization step (in standard DP-SGD, the standard deviation is the noise multiplier times the clipping norm). Its value is printed to STDOUT.
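
For intuition, here is a minimal sketch of what one textbook DP-SGD gradient computation looks like: clip each per-example gradient, sum the clipped gradients, add Gaussian noise, and average. It is written in plain Python on lists and is not the JAX Privacy implementation. The function names are hypothetical, and the noise scale `noise_multiplier * clipping_norm` follows the standard DP-SGD formulation.

```python
import math
import random


def clip_gradient(grad: list[float], clipping_norm: float) -> list[float]:
    """Rescales grad so that its L2 norm is at most clipping_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, clipping_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]


def dp_sgd_gradient(
    per_example_grads: list[list[float]],
    clipping_norm: float,
    noise_multiplier: float,
    rng: random.Random,
) -> list[float]:
    """Returns the noisy average of the clipped per-example gradients."""
    clipped = [clip_gradient(g, clipping_norm) for g in per_example_grads]
    dim = len(clipped[0])
    summed = [sum(g[i] for g in clipped) for i in range(dim)]
    # Noise is added once to the sum, with stddev tied to the clipping norm.
    stddev = noise_multiplier * clipping_norm
    noisy = [s + rng.gauss(0.0, stddev) for s in summed]
    return [v / len(per_example_grads) for v in noisy]


# Toy gradients for two examples; the values are made up.
grads = [[3.0, 4.0], [0.0, 0.5]]
print(dp_sgd_gradient(grads, clipping_norm=1.0, noise_multiplier=1.0,
                      rng=random.Random(0)))
```

Clipping bounds the influence of any single example, which is what allows a fixed amount of noise to give a privacy guarantee.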

In [None]:
if USE_DP:
  params = keras_api.DPKerasConfig(
        epsilon=EPSILON,
        delta=DELTA,
        clipping_norm=CLIPPING_NORM,
        batch_size=BATCH_SIZE,
        train_steps=EPOCHS * (TRAIN_SIZE // BATCH_SIZE),
        train_size=TRAIN_SIZE,
        gradient_accumulation_steps=GRADIANT_ACCUMULATION_STEPS,
        seed=SEED,
  )
  gemma_lm = keras_api.make_private(gemma_lm, params)
  print(
      "DP training:"
      f" {CLIPPING_NORM=} {EPOCHS=} {BATCH_SIZE=}"
  )
else:
  print("Non-DP training")

### Prepare the model for training

Create optimizer, providing learning rate and accumulation steps. Then compile the model for training.

In [None]:
optimizer = keras.optimizers.Adam(
    learning_rate=LEARNING_RATE,
    gradient_accumulation_steps=GRADIANT_ACCUMULATION_STEPS,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

## Let's DP fine-tune Gemma3!

In [None]:
gemma_lm.fit(x=TRAIN_DS,
             epochs=EPOCHS,
             validation_data=VALIDATION_DS)

**IMPORTANT**: You can't call `fit` anymore because you've already performed the maximum allowed number of training steps. If you try to do more model optimization (training) steps, an exception will be thrown because otherwise you will exceed your declared (eps, delta)-DP budget.

### Validation example inference after fine-tuning

Let's see what the model outputs now, after we fine-tuned it.

In [None]:
show_validation_example_inference()

Now it performs better. For example, there is no `<end_of_turn>` in the output.

## Calculate performance metrics on the test dataset

Let's calculate the F1 scores of the ROUGE metrics on the test dataset, which our model hasn't seen yet. We will use these metrics to compare different setups.

It is interesting to see performance of:

* the non-fine-tuned model (baseline)
* non-DP fine-tuned model
* DP fine-tuned model

### Prepare the model for testing

You might want to increase the maximum sequence length for testing so that all test examples fit (2048 is more than enough). We use 512: not all examples fit into this length, but the majority (~95%) do, so it is fine.

Note that the length is measured in Gemma3 tokens and not in English words. So, to evaluate how many examples will be truncated, you need to use `gemma_lm.preprocessor.tokenizer.tokenize(str)`.
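
A check like this can be sketched as a small helper that takes any tokenizer function. The helper `fraction_within_limit` is hypothetical, and the whitespace tokenizer below is only a stand-in so that the example runs without a model. In the notebook you would pass the Gemma3 tokenizer instead, and the counts would differ because the preprocessor also adds special tokens.

```python
from typing import Callable, Sequence


def fraction_within_limit(
    texts: Sequence[str],
    tokenize: Callable[[str], Sequence],
    max_length: int,
) -> float:
    """Returns the fraction of texts whose token count is at most max_length."""
    if not texts:
        raise ValueError("texts must not be empty")
    fitting = sum(1 for text in texts if len(tokenize(text)) <= max_length)
    return fitting / len(texts)


# Stand-in tokenizer and made-up texts, for illustration only.
texts = ["a b c", "a b c d e", "a"]
print(fraction_within_limit(texts, tokenize=str.split, max_length=3))
```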

In [None]:
gemma_lm.preprocessor.sequence_length = TEST_DS_SEQUENCE_LENGTH

### Load and pre-process the test dataset

We follow the same process as for the train and validation datasets.

In [None]:
SOURCE_TEST_DS = tfds.load("samsum", split="test")
TEST_DS = SOURCE_TEST_DS.map(source_to_gemma3_format)
TEST_SIZE = int(TEST_DS.cardinality().numpy())
print(f'Test size: {TEST_SIZE}')
TEST_DS = TEST_DS.batch(TEST_DS_BATCH_SIZE)

### List the evaluation metrics

We will calculate ROUGE_1, ROUGE_2 and ROUGE_L metrics. See [this guide](https://medium.com/nlplanet/two-minutes-nlp-learn-the-rouge-metric-by-examples-f179cc285499) if you don't know what these metrics mean.

In [None]:
METRIC_FNS = {
  'rouge_1': keras_hub.metrics.RougeN(order=1),
  'rouge_2': keras_hub.metrics.RougeN(order=2),
  'rouge_l': keras_hub.metrics.RougeL(),
}

### Evaluation code

For each test example, we feed only the prompt, without the response, to the model. The model outputs text that contains the prompt followed by the generated summary, so we have to remove the prompt from it. We then pass the generated summary and the expected summary to the ROUGE metrics, which compare the two strings.

In [None]:
def calculate_common_prefix(str1, str2):
    i = 0
    while i < len(str1) and i < len(str2) and str1[i] == str2[i]:
        i += 1
    return i

def strip_prompts_from_outputs(prompts: list[str], generated_outputs: list[str]) -> list[str]:
    stripped_outputs = []

    for prompt, full_output in zip(prompts, generated_outputs):
        # The output starts with the prompt, so strip their longest common prefix.
        common_prefix = calculate_common_prefix(prompt, full_output)
        stripped_outputs.append(full_output[common_prefix:])

    return stripped_outputs


def eval_batch(batch):
  prompts = [p.decode("utf-8") for p in batch["prompts"].numpy()]
  # Important: do not feed responses to the model, supply only prompts.
  output_batch = gemma_lm.generate(prompts)
  output_text = strip_prompts_from_outputs(prompts, output_batch)
  target_text = [s.decode('utf-8') for s in batch['responses'].numpy()]

  for metric_fn in METRIC_FNS.values():
    metric_fn.update_state(target_text, output_text)

### Let's calculate the metrics!

For each ROUGE metric we can calculate precision and recall. To summarize both in one number, we take the F1 score.
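
To illustrate how precision, recall and F1 relate for ROUGE_1, here is a simplified sketch that counts overlapping unigrams after lowercasing and whitespace splitting. It is not the `keras_hub` implementation, whose tokenization may differ, and the function name `rouge_1_f1` is hypothetical.

```python
from collections import Counter


def rouge_1_f1(reference: str, candidate: str) -> float:
    """Returns the unigram-overlap F1 score between two strings."""
    ref_counts = Counter(reference.lower().split())
    cand_counts = Counter(candidate.lower().split())
    # Each word counts at most as often as it appears in both strings.
    overlap = sum((ref_counts & cand_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


# Made-up strings: 2 shared words, precision 2/4, recall 2/3.
print(rouge_1_f1("the cat sat", "the cat ran fast"))
```

The F1 score is the harmonic mean of precision and recall, so it is high only when both are high.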

In [None]:
for batch in tqdm.tqdm(TEST_DS):
  eval_batch(batch)

RESULT = {k: m.result()['f1_score'] for k, m in METRIC_FNS.items()}
print(RESULT)

## Results

For the instruction-tuned Gemma3 4B (`gemma3_instruct_4b_text`) you can expect the following F1 scores on the test dataset:

| Experiment | ROUGE_1 | ROUGE_2 | ROUGE_L |
|---|---|---|---|
| Baseline (no fine-tuning) | 0.341 | 0.127 | 0.263 |
| Non-DP fine-tuning (i.e. `USE_DP=False`) | 0.512 | 0.273 | 0.433 |
| DP fine-tuning (i.e. `USE_DP=True`) | 0.487 | 0.251 | 0.412 |

The rest of the hyper-parameters were the same as in the setup cell at the beginning of the notebook, i.e.:

```
Variable                      Type        Data/Info
---------------------------------------------------
BATCH_SIZE                    int         16
CLIPPING_NORM                 float       0.001
DELTA                         float       2e-05
EPOCHS                        int         3
EPSILON                       float       4.0
GEMMA3_MODEL_TYPE             str         gemma3_instruct_4b_text
GRADIANT_ACCUMULATION_STEPS   int         64
LEARNING_RATE                 float       0.003
LORA_RANK                     int         32
SEED                          int         0
SEQUENCE_LENGTH               int         512
TEST_DS_BATCH_SIZE            int         16
TEST_DS_SEQUENCE_LENGTH       int         512
TEST_RUN                      bool        False
USE_MIXED_PRECISION           bool        False
```

These results were obtained on 16 A100 40GB GPUs.