##### Copyright 2025 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Audio Finetune with Hugging Face Transformers

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_3n]Audio_Finetune_with_HF.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

Gemma 3n is a generative AI model optimized for use in everyday devices, such as phones, laptops, and tablets. This tutorial demonstrates how to fine-tune the Gemma 3n model on audio inputs for text generation. Using the ["Cat Meow Classification" dataset](https://www.kaggle.com/datasets/andrewmvd/cat-meow-classification/), we'll train the model to identify a cat's breed, sex, and the context of its meow. The Transformers Python library provides an API for accessing pre-trained generative AI models, including Gemma. For more information, see the [Transformers](https://huggingface.co/docs/transformers/en/index) documentation.

## Setup

Before starting this tutorial, complete the following steps:

* Get access to Gemma by logging into [Hugging Face](https://huggingface.co/google/gemma-3n-E4b-it) and selecting **Acknowledge license** for a Gemma model.
* Select a Colab runtime with sufficient resources for
  the Gemma model size you want to run. [Learn more](https://ai.google.dev/gemma/docs/core#sizes).
* Generate a Hugging Face [Access Token](https://huggingface.co/docs/hub/en/security-tokens#how-to-manage-user-access-token) and use it to log in from Colab.

This notebook will run on an NVIDIA L4 or A100 GPU.

In [None]:
# Log in to the Hugging Face Hub
from huggingface_hub import notebook_login
notebook_login()

### Install Python packages

Install the Hugging Face libraries needed to load and fine-tune the Gemma model.

In [None]:
# Install a transformers version that supports Gemma 3n (>= 4.53)
!pip install -U "transformers>=4.53.0" timm trl peft
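To confirm the upgrade took effect, a small check like the one below compares the installed `transformers` version against the 4.53.0 minimum from the comment above. It is a sketch that uses only the standard library (`importlib.metadata`); the helper names `release_tuple` and `meets_minimum` are hypothetical, and for full PEP 440 ordering the `packaging` library's `Version` class is the more robust choice.

```python
from importlib.metadata import PackageNotFoundError, version


def release_tuple(v: str, width: int = 3) -> tuple:
    """Leading numeric release parts, e.g. '4.53.0.dev0' -> (4, 53, 0).

    Simplification: pre-release/dev suffixes are ignored, so '4.53.0.dev0'
    compares equal to '4.53.0' (PEP 440 would order it lower).
    """
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # e.g. 'dev0': stop at the first non-numeric part
        parts.append(int(digits))
        if not piece.isdigit():
            break  # e.g. '0rc1': keep the 0, drop the suffix
    parts += [0] * (width - len(parts))  # '4.53' -> (4, 53, 0)
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = "4.53.0") -> bool:
    """True if the installed release is at least the minimum release."""
    return release_tuple(installed) >= release_tuple(minimum)


try:
    installed = version("transformers")
    print(f"transformers {installed} meets >= 4.53.0: {meets_minimum(installed)}")
except PackageNotFoundError:
    print("transformers is not installed")
```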

## Define formatting helper functions

Create a helper function that runs text generation on a chat message, prints the reply, and reports the elapsed time.

In [None]:
import time

GEMMA_PATH = "google/gemma-3n-E4B-it" #@param ["google/gemma-3n-E2B-it", "google/gemma-3n-E4B-it"]

tick_start = 0

def tick():
    global tick_start
    tick_start = time.time()

def tock():
    print(f"TOTAL TIME ELAPSED: {time.time() - tick_start:.2f}s")

def text_gen(model, processor, message, max_tokens=256):
  tick()

  input_ids = processor.apply_chat_template(
    message,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
  )
  input_len = input_ids["input_ids"].shape[-1]
  input_ids = input_ids.to(model.device, dtype=model.dtype)
  outputs = model.generate(
    **input_ids,
    max_new_tokens=max_tokens,
    disable_compile=True
  )
  text = processor.batch_decode(
    outputs[:, input_len:],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True
  )
  print('-'*80)
  print(text[0])
  print('-'*80)

  tock()
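The `tick()`/`tock()` pair above shares a global start time, which is fine in a notebook but easy to misuse if calls nest. As an optional alternative, a context manager keeps the start time local to the block it measures. The sketch below is hypothetical (the name `timed` is not part of the notebook) and prints in the same format as `tock()`.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label="TOTAL TIME ELAPSED"):
    """Time the enclosed block and print it in the same format as tock()."""
    result = {}
    start = time.perf_counter()  # monotonic, high-resolution clock
    try:
        yield result
    finally:
        # Runs even if the block raises, so the timing is always reported.
        result["elapsed"] = time.perf_counter() - start
        print(f"{label}: {result['elapsed']:.2f}s")


# Usage: wrap any slow call, for example the model.generate step in text_gen.
with timed() as t:
    time.sleep(0.05)
```

Inside `text_gen`, you could replace the `tick()` and `tock()` calls with a single `with timed():` block around the generation and decoding code.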


## Load Model

In [None]:
from transformers import AutoModelForImageTextToText, AutoProcessor

processor = AutoProcessor.from_pretrained(GEMMA_PATH)
model = AutoModelForImageTextToText.from_pretrained(GEMMA_PATH, torch_dtype="auto", device_map="auto")

print(f"Device: {model.device}")
print(f"DType: {model.dtype}")


Device: cuda:0
DType: torch.bfloat16


## Load Dataset

Here's [the dataset](https://www.kaggle.com/datasets/andrewmvd/cat-meow-classification/) for Cat Meow Classification. Although it's intended for classification, we'll use it to fine-tune a language model as a fun learning exercise.

The files in this dataset follow the naming convention shown below.

```
Naming convention for files -> C_NNNNN_BB_SS_OOOOO_RXX, where:

C = emission context (values: B = brushing; F = waiting for food; I = isolation in an unfamiliar environment);
NNNNN = cat’s unique ID;
BB = breed (values: MC = Maine Coon; EU = European Shorthair);
SS = sex (values: FI = female, intact; FN = female, neutered; MI = male, intact; MN = male, neutered);
OOOOO = cat owner’s unique ID;
R = recording session (values: 1, 2 or 3);
XX = vocalization counter (values: 01..99)
```

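Using this convention, the following cell parses each filename into context, breed, and sex labels and pairs every audio clip with a short conversation: the user asks "Describe this audio." and the assistant replies with the labels.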

In [None]:
import kagglehub
download_path = kagglehub.dataset_download("andrewmvd/cat-meow-classification")
dataset_path = download_path + "/dataset/dataset"
extra_path = download_path + "/extras/sequences"
print("Path to dataset files:", dataset_path)

train = []

import os
for file in sorted(os.listdir(dataset_path)):
  context = "unknown"
  if file[:1] == "B":
    context = "brushing"
  elif file[:1] == "F":
    context = "waiting for food"
  elif file[:1] == "I":
    context = "isolation in an unfamiliar environment"

  breed = "unknown"
  if file[8:10] == "MC":
    breed = "Maine Coon"
  elif file[8:10] == "EU":
    breed = "European Shorthair"

  sex = "unknown"
  if file[11:13] == "FI":
    sex = "female, intact"
  elif file[11:13] == "FN":
    sex = "female, neutered"
  elif file[11:13] == "MI":
    sex = "male, intact"
  elif file[11:13] == "MN":
    sex = "male, neutered"

  message = [
      {"role": "user", "content": [
        {"type": "audio", "audio": os.path.join(dataset_path, file)},
        {"type": "text", "text": "Describe this audio."},
      ]},
      {"role": "assistant", "content": [
        {"type": "text", "text": f"context: {context}\nbreed: {breed}\nsex: {sex}"}
      ]}
  ]

  train.append(message)

import random
random.shuffle(train)

print(len(train))
print(train[0])

Path to dataset files: /kaggle/input/cat-meow-classification/dataset/dataset
440
[{'role': 'user', 'content': [{'type': 'audio', 'audio': '/kaggle/input/cat-meow-classification/dataset/dataset/B_NUL01_MC_MI_SIM01_301.wav'}, {'type': 'text', 'text': 'Describe this audio.'}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': 'context: brushing\nbreed: Maine Coon\nsex: male, intact'}]}]
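The loop above slices fixed character positions, which works because every field in the convention has a fixed width. A regular expression makes the same convention explicit and can reject files that don't match instead of labeling them "unknown". For example, the files under `extras/sequences` end in a suffix such as `1SEQ1` rather than `RXX`. The sketch below is hypothetical (the names `FILENAME_RE` and `parse_filename` are not part of the notebook). It also extracts the cat ID, owner ID, session, and counter, which could be useful if you wanted to, say, keep all recordings of one cat in the same split.

```python
import re

CONTEXTS = {"B": "brushing", "F": "waiting for food",
            "I": "isolation in an unfamiliar environment"}
BREEDS = {"MC": "Maine Coon", "EU": "European Shorthair"}
SEXES = {"FI": "female, intact", "FN": "female, neutered",
         "MI": "male, intact", "MN": "male, neutered"}

# C_NNNNN_BB_SS_OOOOO_RXX.wav, one named group per field of the convention.
FILENAME_RE = re.compile(
    r"^(?P<context>[BFI])_(?P<cat_id>[A-Z0-9]{5})_(?P<breed>[A-Z]{2})_"
    r"(?P<sex>[A-Z]{2})_(?P<owner_id>[A-Z0-9]{5})_"
    r"(?P<session>[123])(?P<counter>\d{2})\.wav$"
)


def parse_filename(name):
    """Return a label dict for a dataset filename, or None if it doesn't match."""
    m = FILENAME_RE.match(name)
    if m is None:
        return None
    return {
        "context": CONTEXTS[m["context"]],  # regex only allows B, F, or I
        "breed": BREEDS.get(m["breed"], "unknown"),
        "sex": SEXES.get(m["sex"], "unknown"),
        "cat_id": m["cat_id"],
        "owner_id": m["owner_id"],
        "session": int(m["session"]),
        "counter": int(m["counter"]),
    }


print(parse_filename("B_NUL01_MC_MI_SIM01_301.wav"))
print(parse_filename("B_CAN01_EU_FN_GIA01_1SEQ1.wav"))  # extras/sequences naming -> None
```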


## Before Finetune

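The output below shows that the base model does not recognize these sounds as cat meows, let alone the context behind them; instead, it describes them as human-like sounds such as a repeated "ch", an "ugh", and sighing.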

In [None]:
from IPython.display import Audio, display

files = [
    train[0][0]["content"][0]["audio"],
    train[1][0]["content"][0]["audio"],
    train[2][0]["content"][0]["audio"],
]

for file in files:
  prompt = [{
    "role": "user",
    "content": [
      {"type": "audio", "audio": file},
      {"type": "text", "text": "Describe this audio."},
    ]
  }]
  print(prompt)
  display(Audio(file))

  text_gen(model, processor, prompt)


[{'role': 'user', 'content': [{'type': 'audio', 'audio': '/kaggle/input/cat-meow-classification/dataset/dataset/B_NUL01_MC_MI_SIM01_301.wav'}, {'type': 'text', 'text': 'Describe this audio.'}]}]


--------------------------------------------------------------------------------
The audio consists of a repeated sound, "ch". It's a single, clear pronunciation of the letter "ch" being spoken repeatedly, creating a somewhat repetitive and possibly rhythmic sound.
--------------------------------------------------------------------------------
TOTAL TIME ELAPSED: 23.21s
[{'role': 'user', 'content': [{'type': 'audio', 'audio': '/kaggle/input/cat-meow-classification/dataset/dataset/I_CAN01_EU_FN_GIA01_305.wav'}, {'type': 'text', 'text': 'Describe this audio.'}]}]


--------------------------------------------------------------------------------
The audio consists of a repetitive, drawn-out "ugh" sound, repeated many times in quick succession. It has a slightly exasperated or frustrated tone.
--------------------------------------------------------------------------------
TOTAL TIME ELAPSED: 5.28s
[{'role': 'user', 'content': [{'type': 'audio', 'audio': '/kaggle/input/cat-meow-classification/dataset/dataset/B_CAN01_EU_FN_GIA01_105.wav'}, {'type': 'text', 'text': 'Describe this audio.'}]}]


--------------------------------------------------------------------------------
The audio consists of a repeated sound of someone sighing or groaning. The sound is a short, breathy expulsion of air, often associated with tiredness, frustration, or disappointment. It's a single, consistent sound throughout the duration of the audio.
--------------------------------------------------------------------------------
TOTAL TIME ELAPSED: 8.38s


## LoRA Fine-tuning

> NOTE: To keep the training time to about 10 minutes on an A100 GPU, we train for only one epoch with a high learning rate. For better results, we recommend training for more epochs with a lower learning rate.
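The `LoraConfig` in the next cell uses `r=16` and `lora_alpha=16`. In LoRA, each adapted weight matrix `W` (shape d_out × d_in) stays frozen, and the layer adds a low-rank update `(lora_alpha / r) · B·A`, where `A` is r × d_in and `B` is d_out × r. With these values, the scaling factor is 16 / 16 = 1.0. PEFT's default initialization sets `B` to zeros, so training starts from the base model's behavior. The pure-Python sketch below illustrates this arithmetic; the 2048 × 2048 layer size is a made-up example, not a Gemma 3n dimension.

```python
def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, lora_alpha, r):
    """y = W x + (lora_alpha / r) * B (A x); W is frozen, A and B are trained."""
    scaling = lora_alpha / r
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + scaling * u for b, u in zip(base, update)]


def lora_trainable_params(d_in, d_out, r):
    """A has r * d_in entries and B has d_out * r entries."""
    return r * (d_in + d_out)


# With B initialized to zeros, the adapted layer matches the frozen layer.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.5, -0.5]]        # r = 1, d_in = 2
B_zero = [[0.0], [0.0]]  # d_out = 2, r = 1
print(lora_forward(W, A, B_zero, [3.0, 4.0], lora_alpha=16, r=1))

# Hypothetical 2048 x 2048 linear layer with r = 16, as in the next cell's config.
full = 2048 * 2048
lora = lora_trainable_params(2048, 2048, r=16)
print(f"{lora} trainable vs {full} frozen ({lora / full:.2%})")
```

Because the adapter's size grows with `r` rather than with the full weight matrix, only a small fraction of each layer's parameters is trained.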

In [None]:
import datasets
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

import torch

# Use bfloat16 on GPUs that support it (compute capability 8.0 or higher, such as A100 and L4); otherwise fall back to float16
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

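def collate_fn(examples):
    # per_device_train_batch_size is 1, so each batch holds one conversation.
    # The conversation already ends with the assistant's reply, so no
    # generation prompt is appended after it.
    input_ids = processor.apply_chat_template(
        examples[0],
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    input_ids = input_ids.to(model.device, dtype=model.dtype)

    # Labels start as a copy of input_ids; padding and multimodal placeholder
    # tokens are then set to -100 so the loss ignores them.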
    labels = input_ids["input_ids"].clone()

    # Use Gemma3n specific token masking
    labels[labels == processor.tokenizer.pad_token_id] = -100
    if hasattr(processor.tokenizer, 'image_token_id'):
        labels[labels == processor.tokenizer.image_token_id] = -100
    if hasattr(processor.tokenizer, 'audio_token_id'):
        labels[labels == processor.tokenizer.audio_token_id] = -100
    if hasattr(processor.tokenizer, 'boi_token_id'):
        labels[labels == processor.tokenizer.boi_token_id] = -100
    if hasattr(processor.tokenizer, 'eoi_token_id'):
        labels[labels == processor.tokenizer.eoi_token_id] = -100

    input_ids["labels"] = labels
    return input_ids

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

training_args = SFTConfig(
    output_dir="gemma-3n-E2B-it-meow",
    num_train_epochs=1,
    per_device_train_batch_size=1,   # collate_fn handles one conversation per batch
    gradient_checkpointing=False,    # Caching is incompatible with gradient checkpointing in Gemma3nTextDecoderLayer.
    logging_steps=len(train) // 4,   # log (and trigger MyCallback) four times per epoch
    save_steps=len(train) // 4,      # save a checkpoint at the same interval
    learning_rate=5e-04,
    fp16=(torch_dtype == torch.float16),   # use float16 precision
    bf16=(torch_dtype == torch.bfloat16),  # use bfloat16 precision
    lr_scheduler_type="constant",
    report_to="none",
    dataloader_pin_memory=False,
    dataset_kwargs={"skip_prepare_dataset": True},  # keep raw conversations; collate_fn tokenizes them
)

from transformers import TrainerCallback

class MyCallback(TrainerCallback):
    "A callback that runs a quick generation check every time the trainer logs"
    def __init__(self, evaluate):
        self.evaluate = evaluate # evaluate function

    def on_log(self, args, state, control, **kwargs):
        # Evaluate the model using text generation
        print(f"Step {state.global_step} finished. Running evaluation:")
        self.evaluate()

# Pick three random clips from extras/sequences (not used for training) to evaluate
sound_list = os.listdir(extra_path)
random.shuffle(sound_list)
evaluate_sound = [
    os.path.join(extra_path, sound_list[0]),
    os.path.join(extra_path, sound_list[1]),
    os.path.join(extra_path, sound_list[2]),
]
def evaluate():
    for file in evaluate_sound:
        print(file)
        text_gen(model, processor, [{"role": "user",
            "content": [
                {"type": "audio", "audio": file},
                {"type": "text", "text": "Describe this audio."},
            ]
        }])

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train,
    processing_class=processor.tokenizer,
    peft_config=peft_config,
    callbacks=[MyCallback(evaluate)]
)
trainer.train()


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


| Step | Training Loss |
|-----:|--------------:|
| 110 | 0.5634 |
| 220 | 0.2454 |
| 330 | 0.2185 |
| 440 | 0.4189 |


Step 110 finished. Running evaluation:
/kaggle/input/cat-meow-classification/extras/sequences/B_CAN01_EU_FN_GIA01_1SEQ1.wav
--------------------------------------------------------------------------------
```
Describe the audio.
```
--------------------------------------------------------------------------------
TOTAL TIME ELAPSED: 2.81s
/kaggle/input/cat-meow-classification/extras/sequences/I_BLE01_EU_FN_DEL01_2SEQ2.wav
--------------------------------------------------------------------------------
```
Describe this audio.
```

**Isolation in an unfamiliar landscape**

This audio evokes a sense of isolation in an unfamiliar environment. The combination of a female vocalist, brushing, and isolation in a natural setting creates a unique and intimate experience.
--------------------------------------------------------------------------------
TOTAL TIME ELAPSED: 14.78s
/kaggle/input/cat-meow-classification/extras/sequences/I_NUL01_MC_MI_SIM01_2SEQ1.wav
-----------------------------------

TrainOutput(global_step=440, training_loss=0.3615472056648948, metrics={'train_runtime': 629.5369, 'train_samples_per_second': 0.699, 'train_steps_per_second': 0.699, 'total_flos': 3012695247064800.0, 'train_loss': 0.3615472056648948})
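`collate_fn` builds the labels by copying `input_ids` and setting padding and multimodal placeholder tokens to -100. PyTorch's cross-entropy loss ignores that value by default (`ignore_index=-100`), so those positions contribute nothing to the loss. All other tokens, including the user prompt, are still trained on. A common alternative, often called completion-only loss, also masks everything before the assistant's reply so that the model is trained only on the answer. The pure-Python sketch below shows both steps. The token IDs are made up, and `response_start` (the index where the reply begins) is an assumption: with real data, you would have to locate it in the tokenized chat template yourself.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_labels(input_ids, masked_token_ids, response_start=None):
    """Copy input_ids into labels and mask positions the loss should skip.

    masked_token_ids: IDs to ignore everywhere (e.g. pad or audio placeholders).
    response_start: optional index of the first reply token; everything before
        it (the prompt) is masked too, giving a completion-only loss.
    """
    labels = [IGNORE_INDEX if tok in masked_token_ids else tok for tok in input_ids]
    if response_start is not None:
        for i in range(min(response_start, len(labels))):
            labels[i] = IGNORE_INDEX
    return labels


# Made-up sequence: 2 = BOS, 9 = audio placeholder, 0 = pad.
ids = [2, 9, 9, 41, 42, 77, 78, 0]
print(build_labels(ids, {0, 9}))                    # -> [2, -100, -100, 41, 42, 77, 78, -100]
print(build_labels(ids, {0, 9}, response_start=5))  # -> [-100, -100, -100, -100, -100, 77, 78, -100]
```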

## After Finetune

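After fine-tuning, the model answers in the `context`/`breed`/`sex` format it was trained on. The labels are not always right, which is unsurprising after a short, single-epoch run. To check a prediction, compare it with the codes in its filename: for example, `I_NUL01_MC_MI_SIM01_2SEQ1.wav` is a male, intact Maine Coon recorded in isolation.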
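Because each filename encodes the cat's context, breed, and sex, you can score replies programmatically instead of judging them by eye. The helpers below are a hypothetical sketch (not part of the notebook). They reuse the character positions from the dataset-building loop, which the `extras/sequences` filenames share for their leading fields, and they assume the model replies in the `key: value` format it was trained on. The sample replies are copied from the output of the next cell.

```python
CONTEXTS = {"B": "brushing", "F": "waiting for food",
            "I": "isolation in an unfamiliar environment"}
BREEDS = {"MC": "Maine Coon", "EU": "European Shorthair"}
SEXES = {"FI": "female, intact", "FN": "female, neutered",
         "MI": "male, intact", "MN": "male, neutered"}
FIELDS = ("context", "breed", "sex")


def expected_labels(filename):
    """Ground truth from the same character positions the dataset loop slices."""
    return {
        "context": CONTEXTS.get(filename[:1], "unknown"),
        "breed": BREEDS.get(filename[8:10], "unknown"),
        "sex": SEXES.get(filename[11:13], "unknown"),
    }


def parse_prediction(text):
    """Turn 'key: value' lines from the model's reply into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, sep, value = line.partition(":")  # split on the first colon only
        if sep:
            fields[key.strip().lower()] = value.strip()
    return fields


def score(prediction_text, filename):
    """Per-field correctness (True/False) of one reply."""
    predicted = parse_prediction(prediction_text)
    truth = expected_labels(filename)
    return {f: predicted.get(f) == truth[f] for f in FIELDS}


def accuracy(results):
    """Fraction of correct answers per field over a list of score() results."""
    return {f: sum(r[f] for r in results) / len(results) for f in FIELDS}


reply_1 = "context: brushing\nbreed: European Shorthair\nsex: male, neutered"
reply_2 = "context: waiting for food\nbreed: Maine Coon\nsex: male, neutered"
results = [
    score(reply_1, "I_NUL01_MC_MI_SIM01_2SEQ1.wav"),
    score(reply_2, "F_IND01_EU_FN_ELI01_1SEQ1.wav"),
]
print(results)
print(accuracy(results))  # -> {'context': 0.5, 'breed': 0.0, 'sex': 0.0}
```

Running the same scoring over all files in `extras/sequences` would give a rough per-field accuracy for the fine-tuned model.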

In [None]:
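from transformers import AutoModelForImageTextToText, AutoProcessor

# Load the fine-tuned checkpoint. The directory holds only the LoRA adapter;
# with peft installed, Transformers reads its adapter config, loads the base
# Gemma 3n model, and applies the adapter on top.
model = AutoModelForImageTextToText.from_pretrained(
  "gemma-3n-E2B-it-meow/checkpoint-440",
  device_map="auto",
  torch_dtype="auto",
)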
processor = AutoProcessor.from_pretrained(GEMMA_PATH)

file_list = os.listdir(extra_path)
random.shuffle(file_list)

for file in file_list[:5]:
  prompt = [{
    "role": "user",
    "content": [
      {"type": "audio", "audio": os.path.join(extra_path, file)},
      {"type": "text", "text": "Describe this audio."},
    ]
  }]
  print(prompt)
  text_gen(model, processor, prompt)



[{'role': 'user', 'content': [{'type': 'audio', 'audio': '/kaggle/input/cat-meow-classification/extras/sequences/I_NUL01_MC_MI_SIM01_2SEQ1.wav'}, {'type': 'text', 'text': 'Describe this audio.'}]}]
--------------------------------------------------------------------------------
context: brushing
breed: European Shorthair
sex: male, neutered
--------------------------------------------------------------------------------
TOTAL TIME ELAPSED: 4.40s
[{'role': 'user', 'content': [{'type': 'audio', 'audio': '/kaggle/input/cat-meow-classification/extras/sequences/F_IND01_EU_FN_ELI01_1SEQ1.wav'}, {'type': 'text', 'text': 'Describe this audio.'}]}]
--------------------------------------------------------------------------------
context: waiting for food
breed: Maine Coon
sex: male, neutered
--------------------------------------------------------------------------------
TOTAL TIME ELAPSED: 4.37s
[{'role': 'user', 'content': [{'type': 'audio', 'audio': '/kaggle/input/cat-meow-classification/extr

## Next steps

Build and explore more with Gemma models:

* [Fine-tune Gemma for text tasks using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_text_finetune_qlora)
* [Fine-tune Gemma for vision tasks using Hugging Face Transformers](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora)
* [Perform distributed fine-tuning and inference on Gemma models](https://ai.google.dev/gemma/docs/core/distributed_tuning)
* [Use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma)
* [Fine-tune Gemma using Keras and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb)