Before starting:

*  Go to Runtime menu and change runtime type to T4 GPU
*  The first pip install cell will require restart of runtime after running.
*  You must have HF_TOKEN colab secret set to download models: Use the key icon in the left side toolbar.

In [None]:
# You will be asked to restart the session after this install.
%%capture
!pip install git+https://github.com/google-research/mseb

In [None]:
%%capture
# Extra packages installed on top of mseb; %%capture hides the pip output.
# NOTE(review): on PyPI the `whisper` project is not OpenAI's speech model
# package (that one is published as `openai-whisper`) — confirm which package
# is actually intended here.
!pip install scann
!pip install whisper

In [None]:
import mseb
from mseb.datasets import parquet as parquet_datasets
try:
  # scann import throws because __init__ tries to enable unused tf ops.
  from mseb.encoders import hf_llm_encoder
except Exception:  # pylint: disable=broad-except
  # but the deps are actually loaded, will succeed second try.
  # Catch Exception rather than using a bare `except:` so that
  # KeyboardInterrupt / SystemExit still propagate while the import is running.
  from mseb.encoders import hf_llm_encoder

from pprint import pprint
from IPython.display import Audio
import json

from absl import flags
FLAGS = flags.FLAGS
# There is no command line in a notebook, so parse an explicit argv list.
# The first element stands in for the program name; the rest are flag values.
FLAGS(["colab",
       "--dataset_basepath=https://storage.googleapis.com/mseb_asru_tutorial",
       "--task_cache_basepath=/tmp/cache"])
# Sanity check that flag parsing took effect.
assert FLAGS.dataset_basepath

In [None]:
# Set huggingface read token.
import os
from google.colab import userdata
# Copy the HF_TOKEN Colab secret into this process's environment variables.
os.environ.update({'HF_TOKEN': userdata.get("HF_TOKEN")})

In [None]:
# Check for GPU
import torch
# Fail early with an actionable message when no CUDA device is visible.
assert torch.cuda.device_count() > 0, (
    "No CUDA GPU detected. Go to the Runtime menu, change the runtime type to "
    "T4 GPU, then re-run the notebook.")

# Tasks

MSEB has six classes of task:
*  classification
*  clustering
*  reasoning
*  reranking
*  retrieval
*  segmentation

For demonstration purposes, we have sampled small subsets of the evaluation datasets to use in Colab and made them available on GCS. We'll override the datasets of a few tasks here to point to the pre-sampled demo data.

We will look at two instances of tasks for this demo:
*  **Intent Classification** (SpeechMassive)
*  **Passage Retrieval** (SimpleVoiceQuestions)

In [None]:
# Intent classification (SpeechMassive).
from mseb.tasks.classifications.intent import speech_massive


# Override with the demo data.
class SpeechMassiveFrFrIntentClassification(
    speech_massive.SpeechMassiveFrFrIntentClassification):
  """fr-FR SpeechMassive intent classification task reading the demo parquet file."""

  def _get_dataset(self):
    """Returns a new ParquetDataset for the fr-FR demo file (sample_n=2)."""
    dataset_kwargs = {
        "dataset_name": "speech_massive",
        "task_name": "SpeechMassiveFrFrIntentClassification",
        "filename": "SpeechMassiveDataset_task_language_fr-FR.parquet",
        "sample_n": 2,
    }
    return parquet_datasets.ParquetDataset(**dataset_kwargs)


# Show the task metadata exposed on the class.
pprint(SpeechMassiveFrFrIntentClassification.metadata)

In [None]:
# Inspect sounds used by the task which will be passed through the encoder.
# `task` and `sound` are notebook globals; later cells rebind them.
task = SpeechMassiveFrFrIntentClassification()
# Take only the first item from the task's sound iterator.
sound = next(task.sounds())
pprint(sound)
# Last expression of the cell, so the notebook renders it as an audio player.
Audio(sound.waveform, rate=sound.context.sample_rate)

In [None]:
# Evaluation examples.
# NOTE(review): `task` is the intent classification task at this point, yet the
# sub-task name passed is 'passage_retrieval_in_lang', which looks copied from
# the retrieval cell — confirm the sub-task name this task expects.
example = next(iter(task.examples('passage_retrieval_in_lang')))
pprint(example)

In [None]:
# Passage retrieval (SimpleVoiceQuestions).
from mseb.tasks.retrievals.passage_in_lang import svq as passage_retrieval_svq
from mseb.datasets.parquet import ParquetDataset
from mseb import types as mseb_types

class SVQEnUsPassageInLangRetrieval(passage_retrieval_svq.SVQEnUsPassageInLangRetrieval):
  """en-US SVQ in-language passage retrieval task reading the demo parquet file.

  The dataset object is created once and cached on the instance, and the
  document corpus is taken from the pre-computed top-passage column of that
  dataset.
  """

  def __init__(self):
    super().__init__()
    # Cache slot for the demo dataset; filled by the _get_dataset() call below.
    self._ds = None
    self._get_dataset()

  def _get_dataset(self):
    """Returns the cached demo ParquetDataset, creating it on first use."""
    # Compare against None explicitly so the cache check does not depend on
    # the truthiness of the dataset object.
    if self._ds is None:
      self._ds = ParquetDataset(
          dataset_name="svq",
          task_name="passage_retrieval_in_lang",
          filename="SimpleVoiceQuestionsDataset_passage_retrieval_in_lang_locale_en_us.parquet",
          sample_n=2)
    return self._ds

  # Get the corpus from pre-computed top passage sets.
  def documents(self):
    """Yields one mseb_types.Text per passage in the top-passage column.

    Each cell of the column is parsed as newline-separated JSON objects; blank
    lines are skipped. Every object must provide "text" and "id" keys.
    Relies on `json` being imported by the earlier setup cell.
    """
    task_data = self._get_dataset().get_task_data("passage_retrieval_in_lang")
    top_passages = task_data["top100_passages_gemini_embedding_whisper_transcript"]
    for passages_jsonl in top_passages:
      for line in passages_jsonl.splitlines():
        if not line.strip():
          continue
        passage = json.loads(line)
        yield mseb_types.Text(
            text=passage["text"],
            context=mseb_types.TextContextParams(
                title="no title", id=passage["id"]))

# Show the task metadata exposed on the class.
pprint(SVQEnUsPassageInLangRetrieval.metadata)

In [None]:
# Example sound.
# Rebinds the notebook globals `task` and `sound` to the SVQ retrieval task.
task = SVQEnUsPassageInLangRetrieval()
# Take only the first item from the task's sound iterator.
sound = next(task.sounds())
pprint(sound)
# Last expression of the cell, so the notebook renders it as an audio player.
Audio(sound.waveform, rate=sound.context.sample_rate)

In [None]:
# The examples evaluated by the class.
# Print only the first example of the 'passage_retrieval_in_lang' sub-task.
example = next(task.examples('passage_retrieval_in_lang'))
pprint(example)

In [None]:
# Passage corpus.
# documents() is a generator, so only the first passage is materialized here.
corpus = task.documents()
pprint(next(corpus))

# Encoders

In [None]:
# Encoder imports and definitions.
from mseb.encoders import raw_encoder
from mseb.encoders import gecko_encoder
from mseb.encoders import prompt_registry

# Raw-audio encoder configured with the spectrogram transform and mean pooling.
# NOTE(review): frame_length / frame_step are presumably in milliseconds —
# confirm against RawEncoder.
spectrogram_encoder = raw_encoder.RawEncoder(
  frame_length=25,
  frame_step=10,
  transform_fn=raw_encoder.spectrogram_transform,
  pooling="mean",
)

# LLM-based encoder using the "intent_classification" prompt from the prompt
# registry; the model is addressed by its Hugging Face model path.
gemma_intent_classification = hf_llm_encoder.HFLLMWithTitleAndContextEncoder(
  model_path="google/gemma-3n-E2B-it",
  prompt=prompt_registry.get_prompt_metadata("intent_classification").load(),
)

In [None]:
# Prompt is part of the encoding process for a particular use case.
# Load the same "intent_classification" prompt the Gemma encoder was built
# with and print its template.
prompt = prompt_registry.get_prompt_metadata("intent_classification").load()
pprint(prompt.GetPromptTemplate())

In [None]:
# Run encoder: gemma3n prompted to produce the class labels of this task.
# setup() is called before encode(); presumably it loads the model — confirm.
gemma_intent_classification.setup()
# NOTE(review): when cells run in order, `sound` was last bound in the SVQ
# retrieval cell, not the SpeechMassive one — confirm this is the intended input.
encoded = gemma_intent_classification.encode([sound])
pprint(encoded)

# Running task/encoder benchmarks.

The mechanism for running the encoder across a task is held in the Runner.
For this colab, the runner will be the DirectRunner, which simply iterates over the data in Python and executes the encoder locally. A BeamRunner is supplied
for running in distributed settings, and the Runner interface can be implemented
by an end user to customize how to distribute work in their local environment.

In [None]:
# Run a task. This is how the run_task script runs a benchmark.

from mseb.runner import DirectRunner
from mseb.leaderboard import run_benchmark

def run_task(encoder_name, encoder, task):
  """Runs one task with one encoder through a local DirectRunner.

  Args:
    encoder_name: Name of the encoder, forwarded to run_benchmark.
    encoder: Encoder instance handed to the DirectRunner.
    task: Task instance to set up and evaluate.

  Returns:
    Whatever run_benchmark returns for this encoder/task pair.
  """
  runner = DirectRunner(encoder=encoder, batch_size=1)

  # Setup runs global task pre-processing, for instance, running encoder over
  # corpus for retrieval and building index. Not all tasks will have a setup
  # step.
  task.setup(runner=runner)

  # Run benchmark uses the runner to encode all the task sounds and then calls
  # the task evaluator. Pass the encoder's name here, not the encoder object.
  return run_benchmark(encoder_name=encoder_name, runner=runner, task=task)

result = run_task(
    encoder_name="gemma_intent_classification",
    encoder=gemma_intent_classification,
    task=SpeechMassiveFrFrIntentClassification())
pprint(result)
