# Audio Classification with Google AI Edge LiteRT
In this notebook you will use the Google AI Edge LiteRT API to classify audio.


## Preparation
The first thing you will need to do is install the necessary dependencies for this sample.




In [None]:
%pip install -q gradio

The next step is to download the YAMNet Audio Classification model from [Kaggle Models](https://www.kaggle.com/models/google/yamnet/tfLite).

YAMNet is a deep net that predicts 521 audio event classes from the AudioSet-YouTube corpus it was trained on. It employs the Mobilenet_v1 depthwise-separable convolution architecture.
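A depthwise-separable convolution replaces one standard k x k convolution with a per-channel k x k depthwise convolution followed by a 1 x 1 pointwise convolution that mixes channels. The arithmetic below is a minimal sketch of why this saves weights. The layer sizes (3 x 3 kernel, 64 to 128 channels) are illustrative assumptions, not YAMNet's actual configuration, and biases are ignored.

```python
def standard_conv_params(k, c_in, c_out):
  """Weights in a k x k convolution mapping c_in to c_out channels (biases ignored)."""
  return k * k * c_in * c_out


def depthwise_separable_params(k, c_in, c_out):
  """Weights in a k x k depthwise conv plus a 1 x 1 pointwise conv (biases ignored)."""
  depthwise = k * k * c_in   # one k x k filter per input channel
  pointwise = c_in * c_out   # 1 x 1 convolution that mixes channels
  return depthwise + pointwise


# Illustrative layer sizes, NOT taken from YAMNet's real architecture.
k, c_in, c_out = 3, 64, 128
standard = standard_conv_params(k, c_in, c_out)
separable = depthwise_separable_params(k, c_in, c_out)
print(standard, separable, round(standard / separable, 1))  # 73728 8768 8.4
```

For this layer the separable version needs roughly 8.4x fewer weights, which is the kind of saving that makes MobileNet-style models practical on edge devices.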

In [None]:
import pathlib
import kagglehub

# Download latest version
path = kagglehub.model_download("google/yamnet/tfLite/classification-tflite")
print("Path to model files:", path)

MODEL_PATH = str(next(pathlib.Path(path).rglob('*.tflite')))

Optionally, you can upload your own model (.tflite). If you want to do so, uncomment and run the cell below.


In [None]:
# from google.colab import files
# uploaded = files.upload()

# for filename in uploaded:
#   content = uploaded[filename]
#   with open(filename, 'wb') as f:
#     f.write(content)

# MODEL_PATH = list(uploaded.keys())[0]

# print('Uploaded model:', MODEL_PATH)

### Install and import libraries

In [None]:
%pip install -q ai-edge-litert-nightly

In [None]:
import tensorflow as tf
import numpy as np
import zipfile
import scipy.signal  # import the submodule explicitly for scipy.signal.resample

from ai_edge_litert.interpreter import Interpreter
from IPython.display import Audio
from scipy.io import wavfile

### Read the associated files from the model
A TensorFlow Lite model with metadata packs its associated files into a zip archive, so the model file itself can be opened with common zip tools. For example, you can open the downloaded model (**1.tflite**) as a zip file and extract its label list:


In [None]:
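# Open the model as a zip archive and read the packed label list.
with zipfile.ZipFile(MODEL_PATH) as model_zip:
  with model_zip.open('yamnet_label_list.txt') as labels_file:
    labels = [l.decode('utf-8').strip() for l in labels_file.readlines()]
print(len(labels))  # Should print 521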
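To see this structure without the real model, the sketch below builds a stand-in file: placeholder bytes (NOT a real TFLite flatbuffer) followed by an appended zip archive. Python's `zipfile` appends a new archive when opened in mode `'a'` on a file that is not already a zip, and it can read archives that have data in front of them. The helper `read_labels` and the three label names are hypothetical, chosen only for illustration.

```python
import os
import tempfile
import zipfile


def read_labels(model_path, label_file='yamnet_label_list.txt'):
  """Hypothetical helper: return the labels packed inside a model's zip archive."""
  with zipfile.ZipFile(model_path) as archive:
    with archive.open(label_file) as f:
      return [line.decode('utf-8').strip() for line in f]


with tempfile.TemporaryDirectory() as tmp:
  fake_model = os.path.join(tmp, 'fake_model.tflite')
  with open(fake_model, 'wb') as f:
    f.write(b'placeholder flatbuffer bytes')  # stand-in for the model body
  # Mode 'a' on a non-zip file appends a fresh zip archive to its end.
  with zipfile.ZipFile(fake_model, 'a') as archive:
    archive.writestr('yamnet_label_list.txt', 'Speech\nWhistling\nCat\n')

  with zipfile.ZipFile(fake_model) as archive:
    print(archive.namelist())  # ['yamnet_label_list.txt']
  demo_labels = read_labels(fake_model)
  print(demo_labels)  # ['Speech', 'Whistling', 'Cat']
```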

### Performing Audio Classification

Now that you have the necessary dependencies, it's time to start classifying some audio! While there are a variety of ways to retrieve audio clips, this example will download .wav files of someone whistling and a cat meowing.

Load the model with the `Interpreter`:

In [None]:
interpreter = Interpreter(model_path=MODEL_PATH)

Next, read the model's input and output details and take the tensor indices from their `index` keys: `waveform_input_index` for the input waveform and `scores_output_index` for the class scores.

In [None]:
input_details = interpreter.get_input_details()
waveform_input_index = input_details[0]['index']
output_details = interpreter.get_output_details()
scores_output_index = output_details[0]['index']

In [None]:
input_shape = input_details[0]['shape']
input_shape

Add a helper that checks whether the loaded audio has the sample rate the model expects (16 kHz) and resamples it if it doesn't; audio at a different rate would distort the model's results.

In [None]:
def ensure_sample_rate(original_sample_rate, waveform,
                       desired_sample_rate=16000):
  """Resample waveform if required."""
  if original_sample_rate != desired_sample_rate:
    desired_length = int(round(float(len(waveform)) /
                               original_sample_rate * desired_sample_rate))
    waveform = scipy.signal.resample(waveform, desired_length)
  return desired_sample_rate, waveform
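As a quick sanity check of the resampling step, the sketch below generates a 1-second 440 Hz test tone at 44.1 kHz (an assumed input rate, common for consumer audio), resamples it to 16 kHz with the same length formula as `ensure_sample_rate`, and confirms that the dominant frequency is unchanged. The tone and rates are illustrative choices, not properties of the sample files.

```python
import numpy as np
import scipy.signal

orig_sr, target_sr = 44100, 16000
t = np.arange(orig_sr) / orig_sr        # 1 second of sample timestamps
tone = np.sin(2 * np.pi * 440.0 * t)     # 440 Hz test tone

# Same length formula as ensure_sample_rate: scale the sample count by the rate ratio.
new_len = int(round(len(tone) / orig_sr * target_sr))
resampled = scipy.signal.resample(tone, new_len)

# The strongest FFT bin should still be at 440 Hz after resampling.
spectrum = np.abs(np.fft.rfft(resampled))
freqs = np.fft.rfftfreq(new_len, d=1 / target_sr)
print(new_len, freqs[spectrum.argmax()])
```

`scipy.signal.resample` works in the frequency domain, so it is exact for periodic signals like this tone; for long recordings, `scipy.signal.resample_poly` is a common polyphase alternative.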

Next, download two sample WAV files and listen to one of them. If you already have a WAV file of your own, upload it to Colab and use it instead.

In [None]:
!curl -O https://storage.googleapis.com/audioset/speech_whistling2.wav
!curl -O https://storage.googleapis.com/audioset/miaow_16k.wav

In [None]:
# @title Choose an audio file
wav_file_name = "speech_whistling2.wav" # @param ["miaow_16k.wav", "speech_whistling2.wav"]
sample_rate, wav_data = wavfile.read(wav_file_name)
sample_rate, wav_data = ensure_sample_rate(sample_rate, wav_data)

# Show some basic information about the audio.
duration = len(wav_data)/sample_rate
print(f'Sample rate: {sample_rate} Hz')
print(f'Total duration: {duration:.2f}s')
print(f'Size of the input: {len(wav_data)}')

# Listening to the wav file.
Audio(wav_data, rate=sample_rate)

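`wav_data` needs to be normalized to values in [-1.0, 1.0], as stated in the model's documentation. For 16-bit PCM audio such as these sample files, dividing by the largest `int16` value does the scaling.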




In [None]:
waveform = wav_data / tf.int16.max
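Dividing by `tf.int16.max` (32767, the same value as NumPy's `np.iinfo(np.int16).max`) maps 16-bit samples into roughly [-1.0, 1.0]. The sketch below uses a few hand-picked sample values to show the mapping, including the edge case that the most negative `int16` value, -32768, lands just below -1.0.

```python
import numpy as np

# Hand-picked 16-bit PCM sample values covering the full range.
pcm = np.array([-32768, -16384, 0, 16384, 32767], dtype=np.int16)
scaled = pcm / np.iinfo(np.int16).max  # same divisor as tf.int16.max
print(scaled)
```

The tiny overshoot at -32768 is harmless here; dividing by 32768 instead would keep every value inside [-1.0, 1.0).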

The downloaded model has a fixed input window of 15600 samples (0.975 s at 16 kHz).

For a given audio file, you'll have to split it into windows of that size. The last window will usually be shorter, so it is padded with zeros.

In [None]:
# Split the audio
INPUT_SIZE = 15600
splitted_audio_data = tf.signal.frame(waveform, INPUT_SIZE, INPUT_SIZE, pad_end=True, pad_value=0)
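With `frame_step` equal to `frame_length` and `pad_end=True`, `tf.signal.frame` produces non-overlapping windows and zero-pads the last one. The NumPy sketch below mirrors that behavior to make the arithmetic explicit: a clip of N samples yields ceil(N / 15600) windows. `frame_with_padding` is a hypothetical helper, and the 40000-sample all-ones signal is a stand-in for a 2.5 s clip at 16 kHz.

```python
import numpy as np

INPUT_SIZE = 15600  # model window length in samples


def frame_with_padding(signal, frame_length):
  """Hypothetical helper: split into non-overlapping frames, zero-padding the end."""
  num_frames = -(-len(signal) // frame_length)  # ceiling division
  padded = np.zeros(num_frames * frame_length, dtype=signal.dtype)
  padded[:len(signal)] = signal
  return padded.reshape(num_frames, frame_length)


signal = np.ones(40000, dtype=np.float32)  # stand-in for a 2.5 s clip at 16 kHz
frames = frame_with_padding(signal, INPUT_SIZE)
print(frames.shape)         # (3, 15600)
print(INPUT_SIZE / 16000)   # window length in seconds: 0.975
```

Here the third frame holds the last 8800 real samples followed by 6800 zeros.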

Loop over the split audio frames, run the model on each one, and save the scores from every run.

In [None]:
# Every frame has the same length, so resize the input and allocate tensors once.
interpreter.resize_tensor_input(waveform_input_index, [INPUT_SIZE], strict=True)
interpreter.allocate_tensors()

results = []
for data in splitted_audio_data:
  waveform_data = data.numpy().astype('float32')
  # Run the model on this frame and keep its class scores.
  interpreter.set_tensor(waveform_input_index, waveform_data)
  interpreter.invoke()
  scores = interpreter.get_tensor(scores_output_index)
  results.append(scores)
  print(scores.shape)  # Should print (1, 521)
  top_class_index = scores.argmax()
  inferred_class = labels[top_class_index]
  print(inferred_class)

Finally, average the per-frame scores to get a single prediction for the whole clip.

In [None]:
results_np = np.array(results)
mean_results = results_np.mean(axis=0)
result_index = mean_results.argmax()
print(f'The main sound is: {labels[result_index]}')
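The sketch below walks through the averaging step with made-up scores for three frames and a hypothetical three-class label set, so the shapes are easy to follow. Stacking the per-frame `(1, num_classes)` outputs gives a `(frames, 1, num_classes)` array, and averaging over axis 0 leaves one score per class. It also shows how to list the top-k classes instead of only the best one.

```python
import numpy as np

labels = ['Speech', 'Whistling', 'Cat']  # hypothetical 3-class label set
# Made-up per-frame scores, each shaped (1, num_classes) like the model output.
results = [np.array([[0.6, 0.3, 0.1]]),
           np.array([[0.2, 0.7, 0.1]]),
           np.array([[0.1, 0.8, 0.1]])]

results_np = np.array(results)          # shape (3, 1, 3): frames x batch x classes
mean_results = results_np.mean(axis=0)  # shape (1, 3): one averaged score per class
print(results_np.shape, mean_results.shape)
print(labels[mean_results.argmax()])    # Whistling

# Top-2 classes by averaged score, highest first.
top2 = np.argsort(mean_results[0])[::-1][:2]
print([labels[i] for i in top2])        # ['Whistling', 'Speech']
```

The first frame alone would have picked "Speech"; averaging lets the later, more confident frames decide the overall label.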

### (Optional) Real-time Audio Classification using Gradio

Here you'll use Gradio to classify microphone audio in real time, combining all the LiteRT steps from above into a single function.

In [None]:
import gradio as gr


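def classify(audio):
  """Classify a mono waveform sampled at 16 kHz with values in [-1.0, 1.0]."""
  results = []

  # Split the audio into fixed-size windows, zero-padding the last one.
  splitted_audio_data = tf.signal.frame(audio, INPUT_SIZE, INPUT_SIZE, pad_end=True, pad_value=0)

  for data in splitted_audio_data:
    audio_data = data.numpy().astype('float32')
    # Run the model on this window and keep its class scores.
    interpreter.resize_tensor_input(waveform_input_index, [audio_data.size], strict=True)
    interpreter.allocate_tensors()
    interpreter.set_tensor(waveform_input_index, audio_data)
    interpreter.invoke()
    scores = interpreter.get_tensor(scores_output_index)
    results.append(scores)

  # Average the per-window scores and pick the highest-scoring class.
  results_np = np.array(results)
  mean_results = results_np.mean(axis=0)
  result_index = mean_results.argmax()
  return f'The main sound is: {labels[result_index]}'


def inference(stream, new_chunk):
  sample_rate, data = new_chunk
  data = data.astype(np.float32)
  # Down-mix stereo microphone input to mono.
  if data.ndim > 1:
    data = data.mean(axis=1)
  sample_rate, data = ensure_sample_rate(sample_rate, data)
  # Peak-normalize to [-1.0, 1.0]; skip silent chunks to avoid dividing by zero.
  peak = np.max(np.abs(data))
  if peak > 0:
    data = data / peak

  # Keep everything recorded so far in the Gradio state.
  if stream is not None:
    stream = np.concatenate([stream, data])
  else:
    stream = data

  # Only the newest chunk is classified.
  return stream, classify(data)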

# Gradio parameters
title = "YAMNet"
description = "An audio event classifier trained on the AudioSet dataset to predict audio events from the AudioSet ontology."

with gr.Blocks(
    title=title,
    theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue),
) as demo:
  with gr.Row(equal_height=False):
    with gr.Column(scale=5, elem_id="audio_classification"):
      gr.Interface(
          inference,
          ["state", gr.Audio(sources=["microphone"], streaming=True)],
          ["state", "text"],
          description=description,
          live=True,
      )

demo.queue().launch()
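In `inference`, the `stream` state concatenates every microphone chunk and is never trimmed, so a long session keeps growing in memory. If you want to keep some recent context (for example, to classify the last few seconds instead of only the newest chunk), a bounded buffer like the sketch below is one option. `append_bounded`, `MAX_SECONDS`, and the 10-second cap are hypothetical choices, not part of Gradio or LiteRT.

```python
import numpy as np

MAX_SECONDS = 10       # hypothetical cap on how much audio to keep
SAMPLE_RATE = 16000    # rate of the chunks after ensure_sample_rate


def append_bounded(stream, chunk, max_samples=MAX_SECONDS * SAMPLE_RATE):
  """Append chunk to the running buffer, keeping only the newest max_samples."""
  stream = chunk if stream is None else np.concatenate([stream, chunk])
  return stream[-max_samples:]


# Usage: 15 one-second chunks still leave only the last 10 seconds in memory.
buffer = None
for _ in range(15):
  buffer = append_bounded(buffer, np.ones(SAMPLE_RATE, dtype=np.float32))
print(len(buffer))  # 160000

# A tiny case makes the trimming visible.
print(append_bounded(np.arange(5), np.arange(5, 8), max_samples=4))  # [4 5 6 7]
```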