Before starting:

*  Go to Runtime menu and change runtime type to T4 GPU
*  The first pip install cell requires a runtime restart after it finishes running.
*  You must set the HF_TOKEN Colab secret to download models: use the key icon in the left-side toolbar.

In [None]:
# You will be asked to restart the session after this install.
%%capture
!pip install git+https://github.com/google-research/mseb

In [None]:
%%capture
!pip install scann
!pip install openai-whisper

In [None]:
import mseb
from mseb.datasets import parquet as parquet_datasets
try:
  # scann import throws because __init__ tries to enable unused tf ops.
  from mseb.encoders import hf_llm_encoder
except Exception:  # pylint: disable=broad-except
  # but the deps are actually loaded, will succeed second try.
  # Catch Exception rather than using a bare `except:` so that
  # KeyboardInterrupt and SystemExit are not swallowed by the retry.
  from mseb.encoders import hf_llm_encoder

from pprint import pprint
from IPython.display import Audio
import json

from absl import flags
FLAGS = flags.FLAGS
# Parse flags explicitly since there is no command line in a notebook. The
# first list element stands in for the program name (argv[0]); the remaining
# entries are parsed as flags.
FLAGS(["colab",
       "--dataset_basepath=https://storage.googleapis.com/mseb_asru_tutorial",
       "--task_cache_basepath=/tmp/cache"])
# Fail fast if the dataset location was not picked up.
assert FLAGS.dataset_basepath

In [None]:
# Set huggingface read token.
import os
from google.colab import userdata
os.environ['HF_TOKEN'] = userdata.get("HF_TOKEN")

In [None]:
# Check for GPU
import torch
assert torch.cuda.device_count() > 0
import gc

# Tasks

MSEB has six classes of task:
*  classification
*  clustering
*  reasoning
*  reranking
*  retrieval
*  segmentation

For demonstration purposes, we have sampled small subsets of the evaluation datasets to use in Colab and made them available on Google Cloud Storage (GCS). We'll override the datasets of a few tasks here to point to the pre-sampled demo data.

We will look at two instances of tasks for this demo:
*  **Intent Classification** (SpeechMassive)
*  **Passage Retrieval** (SimpleVoiceQuestions)

In [None]:
# Intent classification (SpeechMassive).
from mseb.tasks.classifications.intent import speech_massive

# Override with the demo data.
class SpeechMassiveFrFrIntentClassification(speech_massive.SpeechMassiveFrFrIntentClassification):
  """SpeechMassive fr-FR intent classification reading the demo parquet file."""

  def _get_dataset(self):
    """Builds the demo ParquetDataset, requesting sample_n=2."""
    demo_dataset = parquet_datasets.ParquetDataset(
        sample_n=2,
        filename="SpeechMassiveDataset_task_language_fr-FR.parquet",
        task_name="SpeechMassiveFrFrIntentClassification",
        dataset_name="speech_massive",
    )
    return demo_dataset

pprint(SpeechMassiveFrFrIntentClassification.metadata)

In [None]:
# Inspect sounds used by the task which will be passed through the encoder.
task = SpeechMassiveFrFrIntentClassification()
sound = next(task.sounds())
pprint(sound)
Audio(sound.waveform, rate=sound.context.sample_rate)

In [None]:
# Evaluation examples.
# NOTE(review): `task` here is the SpeechMassive intent classification task,
# but the sub-task name is the one used for the passage retrieval task in a
# later cell -- this looks like a copy/paste; confirm the intended sub-task.
example = next(task.examples('passage_retrieval_in_lang'))
pprint(example)

In [None]:
# @title Passage retrieval task (SimpleVoiceQuestions).
#
# We've hacked up a colab version of the task that avoids using scann for
# index building and uses a (very) small cached dataset.
#
# You probably don't need to get into this implementation.
from mseb.tasks.retrievals.passage_in_lang import svq as passage_retrieval_svq
from mseb.datasets.parquet import ParquetDataset
from mseb.evaluators import retrieval_evaluator
from mseb import types as mseb_types

class SVQEnUsPassageInLangRetrieval(passage_retrieval_svq.SVQEnUsPassageInLangRetrieval):
  """Colab variant of the SVQ en-US in-language passage retrieval task.

  Reads a small cached parquet file and configures the evaluator without a
  searcher, so no retrieval index is built.

  Args:
    transcript_key: Column whose value is stored as each sound's context text.
    context_key: Column holding the cached passages, one JSON object per line
      with "id" and "text" fields (as parsed by `documents`). If falsy,
      `sounds` yields the dataset sounds without wrapping them.
  """

  def __init__(self, transcript_key="whisper_transcript",
               context_key="top100_passages_gemini_embedding_whisper_transcript"):
    super().__init__()
    self._transcript_key = transcript_key
    self._context_key = context_key
    self._ds = None
    # Build the dataset up front; later calls reuse the cached instance.
    self._get_dataset()

  def _get_dataset(self):
    """Returns the demo dataset, constructing it on the first call."""
    # Compare against None explicitly: the cache is "unset" only before the
    # first load, regardless of how the dataset object evaluates as a boolean.
    if self._ds is None:
      self._ds = ParquetDataset(
          dataset_name="svq",
          task_name="passage_retrieval_in_lang",
          filename="SimpleVoiceQuestionsDataset_passage_retrieval_in_lang_locale_en_us.parquet",
          id_key="utt_id",
          sample_n=2)
    return self._ds

  def sounds(self):
    """Yields one sound per dataset row whose locale matches this task.

    The transcript column is copied into `sound.context.text`. When a context
    key is configured, the sound is wrapped in a SoundWithTitleAndContext that
    carries the cached passages as `context_text`.
    """
    svq_dataset = self._get_dataset()
    for example in svq_dataset.get_task_data(
        "passage_retrieval_in_lang",
        dtype={
            'locale': str,
            'utt_id': str,
            self._transcript_key: str,
        },
    ).itertuples():
      if example.locale == self.locale:
        sound = svq_dataset.get_sound(example._asdict())
        sound.context.text = getattr(example, self._transcript_key)
        if self._context_key:
          sound = mseb_types.SoundWithTitleAndContext(
            waveform=sound.waveform,
            context_text=getattr(example, self._context_key),
            context=sound.context,
          )
        yield sound

  def documents(self):
    """Yields every cached passage as a Text whose context id is the passage id."""
    task_data = self._get_dataset().get_task_data("passage_retrieval_in_lang")
    for passages_jsonl in task_data[self._context_key]:
      # Each cell holds JSON lines; blank lines are skipped. Distinct names are
      # used for the cell and the parsed passage instead of reusing one
      # variable for both loop levels.
      passages = [
          json.loads(line)
          for line in passages_jsonl.splitlines()
          if line.strip()
      ]
      for passage in passages:
        yield mseb_types.Text(
            text=passage["text"],
            context=mseb_types.TextContextParams(id=passage["id"]))

  def setup(
      self, runner=None
  ):
    """Creates the evaluator with no searcher; `runner` is accepted but unused."""
    self._evaluator = retrieval_evaluator.RetrievalEvaluator(
        searcher=None, id_by_index_id={}
    )

pprint(SVQEnUsPassageInLangRetrieval.metadata)

In [None]:
# Example sound.
task = SVQEnUsPassageInLangRetrieval()
sound = next(task.sounds())
pprint(sound)
Audio(sound.waveform, rate=sound.context.sample_rate)

In [None]:
# The examples evaluated by the class.
example = next(task.examples('passage_retrieval_in_lang'))
pprint(example)

In [None]:
# Passage corpus.
corpus = task.documents()
pprint(next(corpus))

# Encoders

In [None]:
# A very simple Text -> TextEmbedding encoder implementation.
from mseb import encoder as encoder_lib
from mseb import types as mseb_types
import codecs
import numpy as np

class Rot13Encoder(encoder_lib.MultiModalEncoder):
  """Toy Text -> TextEmbedding encoder: the "embedding" is the ROT13 of the text."""

  # Class-level default so `_encode` can read the codec even when `_setup` has
  # not been run on the instance.
  _codec = 'rot13'

  def _setup(self):
    """Selects the codec used by `_encode`."""
    self._codec = 'rot13'

  def _check_input_types(self, batch):
    """Raises ValueError unless every batch item is a Text."""
    if not all(isinstance(x, mseb_types.Text) for x in batch):
      raise ValueError('Batch must be all Text input.')

  def _encode(self, batch):
    """Returns one TextEmbedding per input, spanning the whole text."""
    return [
        mseb_types.TextEmbedding(
            # Pass the codec explicitly: codecs.encode defaults to 'utf-8',
            # which would return the UTF-8 bytes instead of applying ROT13.
            embedding=np.array(codecs.encode(x.text, self._codec)),
            spans=np.array([[0, len(x.text)]]),
            context=x.context)
        for x in batch
    ]

rot13_encoder = Rot13Encoder()
text = mseb_types.Text(
    text="hello world",
    context=mseb_types.TextContextParams(id='test000'))
pprint(rot13_encoder.encode([text]))

In [None]:
import whisper
from mseb.encoders import whisper_encoder

In [None]:
# An embedding could be Sound -> SoundEmbedding(text) (ASR)
from IPython.display import display

whisper_asr_medium = whisper_encoder.SpeechToTextEncoder("medium")
whisper_asr_medium.setup()

task = SVQEnUsPassageInLangRetrieval()
sound = next(task.sounds())
pprint(sound)
encoded, = whisper_asr_medium.encode([sound])
pprint(encoded)
# A bare Audio(...) is only rendered when it is the last expression of a cell;
# display it explicitly because the cleanup statements follow it.
display(Audio(sound.waveform, rate=sound.context.sample_rate))
# Drop the model reference and run garbage collection.
del whisper_asr_medium
gc.collect()

In [None]:
# An embedding could be Sound -> SoundEmbedding(fixed size vector)
# (pooled audio encoder output rather than an ASR transcript).
from IPython.display import display

whisper_pooled_medium = whisper_encoder.PooledAudioEncoder("medium")
whisper_pooled_medium.setup()

task = SVQEnUsPassageInLangRetrieval()
sound = next(task.sounds())
pprint(sound)
encoded, = whisper_pooled_medium.encode([sound])
pprint(encoded)
# A bare Audio(...) is only rendered when it is the last expression of a cell;
# display it explicitly because the cleanup statements follow it.
display(Audio(sound.waveform, rate=sound.context.sample_rate))
# Drop the model reference and run garbage collection.
del whisper_pooled_medium
gc.collect()

In [None]:
from mseb.encoders import gecko_encoder
from mseb.encoders import prompt_registry

gemma_intent_classification = hf_llm_encoder.HFLLMWithTitleAndContextEncoder(
  model_path="google/gemma-3n-E2B-it",
  prompt=prompt_registry.get_prompt_metadata("intent_classification").load(),
)

In [None]:
# Prompt is part of the encoding process for a particular use case.
prompt = prompt_registry.get_prompt_metadata("intent_classification").load()
pprint(prompt.GetPromptTemplate())

In [None]:
# Run encoder: gemma3n prompted to produce the class labels for intent classification.
gemma_intent_classification.setup()
encoded = gemma_intent_classification.encode([sound])
pprint(encoded)

# Running task/encoder benchmarks.

The mechanism for running the encoder across a task is held in the Runner.
For this colab, the runner will be the DirectRunner, which simply iterates over the data in Python and executes the encoder locally. A BeamRunner is supplied
for running in distributed settings, and the Runner interface can be implemented
by an end user to customize how to distribute work in their local environment.

In [None]:
from mseb.runner import DirectRunner

# The runner allows bulk encoding, it produces a mapping from instance ids to results.
runner = DirectRunner(encoder=gemma_intent_classification)
task = SpeechMassiveFrFrIntentClassification()
results_cache = runner.run(task.sounds())
print("\nSound ids:", results_cache.keys())
key = list(results_cache.keys())[0]
print(f"Result[{key}]:")
pprint(results_cache[key])

In [None]:
# Run a task. This is how the run_task script runs a benchmark.

from mseb.runner import DirectRunner
from mseb.leaderboard import run_benchmark

def run_task(encoder_name, encoder, task):
  """Runs `task` against `encoder` with a DirectRunner.

  Args:
    encoder_name: Name handed to `run_benchmark` to identify the encoder.
    encoder: Encoder instance executed by the runner.
    task: Task instance to set up and evaluate.

  Returns:
    The value returned by `run_benchmark` for this encoder/task pair.
  """
  runner = DirectRunner(encoder=encoder)

  # Setup runs global task pre-processing, for instance, running encoder over
  # corpus for retrieval and building index. Not all tasks will have a setup
  # step.
  task.setup(runner=runner)

  # Run benchmark uses the runner to encode all the task sounds and then calls
  # the task evaluator. Pass the name, not the encoder object, as encoder_name.
  return run_benchmark(encoder_name=encoder_name, runner=runner, task=task)

result = run_task(
    encoder_name="gemma_intent_classification",
    encoder=gemma_intent_classification,
    task=SpeechMassiveFrFrIntentClassification())
pprint(result)


In [None]:
# @title Cached retrieval encoder - for fast colab demo
from typing import Callable, Sequence, final
from mseb import encoder as encoder_lib
from mseb.encoders import converter as converter_lib
from mseb.encoders import prompt as prompt_lib
from mseb.encoders import retrieval_encoder
from mseb import types as mseb_types


# Replace RetrievalEncoder._setup with a no-op for this demo. Return None
# explicitly instead of referencing `_`, which is only defined by the
# interactive shell and would raise NameError anywhere else.
retrieval_encoder.RetrievalEncoder._setup = lambda self: None

class CachedRetrievalEncoder(converter_lib.Converter):
  """Returns retrieval results already cached on the input sound.

  The input's `context_text` is split on newlines and the first `top_k` lines
  are re-joined as the result text.

  Args:
    for_rag: If True, emit TextWithTitleAndContext objects; otherwise emit
      TextPrediction objects.
    top_k: Maximum number of cached lines kept per input.
  """

  def __init__(self, for_rag: bool, top_k:int = 10):
    super().__init__()
    self._top_k = top_k
    self._for_rag = for_rag

  @final
  def _check_input_types(
      self,
      batch: Sequence[mseb_types.MultiModalObject] | Sequence[mseb_types.TextEmbedding],
  ):
    """Raises ValueError if any item is not a SoundWithTitleAndContext."""
    for item in batch:
      if not isinstance(item, mseb_types.SoundWithTitleAndContext):
        raise ValueError(
            'CachedRetrievalEncoder only supports a batch of all'
            ' SoundWithTitleAndContext inputs.'
        )

  @final
  def _encode(
      self, batch: Sequence[mseb_types.MultiModalObject]
  ) -> Sequence[mseb_types.TextPrediction]:
    """Converts each sound's cached lines into the configured output type."""
    results = []
    for item in batch:
      assert isinstance(item, mseb_types.SoundWithTitleAndContext)
      cached_lines = item.context_text.split('\n')
      joined = '\n'.join(cached_lines[:self._top_k])
      if not self._for_rag:
        # Plain retrieval: the top lines are the prediction itself.
        results.append(
            mseb_types.TextPrediction(
                prediction=joined,
                context=mseb_types.PredictionContextParams(
                    id=item.context.id,
                    debug_text=item.context.debug_text,
                ),
            )
        )
        continue
      # RAG: keep the top lines as text and carry the full cached context.
      results.append(
          mseb_types.TextWithTitleAndContext(
              text=joined,
              context=mseb_types.TextContextParams(
                  id=item.context.id,
                  text=item.context.text,
              ),
              context_text=item.context_text,
          )
      )
    return results


def CachedRagHFLLMWithTitleAndContextTranscriptTruthEncoder(
    model_path: str,
    top_k: int = 10,
    normalizer: Callable[[str], str] | None = None,
    prompt: prompt_lib.Prompt = prompt_lib.RetrievalPrompt(),
) -> encoder_lib.CascadeEncoder:
  """Builds a cascade: cached retrieval -> HF LLM -> text prediction.

  Args:
    model_path: Model path passed to the HF LLM encoder.
    top_k: Number of cached retrieval lines kept by the first stage.
    normalizer: Optional normalizer passed to the HF LLM encoder.
    prompt: Prompt passed to the HF LLM encoder.

  Returns:
    A CascadeEncoder chaining the three stages in order.
  """
  return encoder_lib.CascadeEncoder(
      encoders=[
          # Forward top_k so the caller's value is actually applied.
          CachedRetrievalEncoder(for_rag=True, top_k=top_k),
          hf_llm_encoder.HFLLMEncoder(
              model_path=model_path, normalizer=normalizer, prompt=prompt
          ),
          converter_lib.TextEmbeddingToTextPredictionConverter(),
      ]
  )

In [None]:
retrieval_gemini_embedding_encoder = CachedRetrievalEncoder(for_rag=False)

result = run_task(
    encoder_name="retrieval_gemini_embedding_encoder",
    encoder=retrieval_gemini_embedding_encoder,
    task=SVQEnUsPassageInLangRetrieval())
pprint(result)

## Span Reasoning

In [None]:
# @title Span reasoning (SimpleVoiceQuestions).
from typing import Sequence
from mseb.tasks.reasonings.span_in_lang import svq as span_reasoning_svq
from mseb.datasets.parquet import ParquetDataset
from mseb import types as mseb_types


def parse_as_list_of_str(s):
  """Extracts the quoted items from the string form of a list of strings.

  The input must be wrapped in square brackets. Newlines are treated as spaces.
  An item starts at a single or double quote and ends at the next occurrence
  of the same quote character; anything between items (commas, spaces) is
  skipped. Escaped quotes are not interpreted. An unterminated item runs to the
  end of the input.

  Args:
    s: Text such as "['a', 'b']" or '["a" "b"]'.

  Returns:
    The list of item strings, in order of appearance.

  Raises:
    ValueError: If `s` is not wrapped in '[' and ']'.
  """
  if len(s) < 2 or s[0] != '[' or s[-1] != ']':
    raise ValueError(f'Expected a string wrapped in [ and ], got: {s!r}')
  body = s[1:-1].replace('\n', ' ')
  items = []
  pos = 0
  end = len(body)
  # Scan with an index and str.find so the work is linear in the input size
  # rather than re-slicing the remaining text for every character.
  while pos < end:
    quote = body[pos]
    if quote != '"' and quote != "'":
      pos += 1
      continue
    closing = body.find(quote, pos + 1)
    if closing == -1:
      # No closing quote: the rest of the text is the final item.
      items.append(body[pos + 1:])
      break
    items.append(body[pos + 1:closing])
    pos = closing + 1
  return items


class SVQEnUsSpanInLangReasoning(span_reasoning_svq.SVQEnUsSpanInLangReasoning):
  """Colab variant of SVQ en-US span reasoning backed by a cached parquet file.

  Args:
    transcript_key: Column whose value is stored as each sound's context text.
    context_key: Column supplying the context text attached to each sound.
  """

  def __init__(self, transcript_key="whisper_transcript",
               context_key="passage_text"):
    super().__init__()
    self._transcript_key = transcript_key
    self._context_key = context_key
    self._ds = None
    # Build the dataset up front; later calls reuse the cached instance.
    self._get_dataset()

  def _get_dataset(self):
    """Returns the demo dataset, constructing it on the first call."""
    # Compare against None explicitly: the cache is "unset" only before the
    # first load, regardless of how the dataset object evaluates as a boolean.
    if self._ds is None:
      self._ds = ParquetDataset(
          dataset_name="svq",
          task_name="span_reasoning_in_lang",
          filename="SimpleVoiceQuestionsDataset_span_reasoning_in_lang_locale_en_us.parquet",
          id_key="utt_id",
          # sample_n=2,
      )
    return self._ds

  def sounds(self):
    """Yields a SoundWithTitleAndContext per row matching this task's locale.

    The page title becomes `title_text`, the context column becomes
    `context_text`, and the transcript column is copied into
    `sound.context.text`.
    """
    svq_dataset = self._get_dataset()
    # NOTE(review): the transcript column is read below but is not listed in
    # this dtype mapping -- confirm that is intended.
    for example in svq_dataset.get_task_data(
        'span_reasoning_in_lang',
        dtype={
            'locale': str,
            'utt_id': str,
            'page_title': str,
            self._context_key: str,
        },
    ).itertuples():
      if example.locale == self.locale:
        sound = svq_dataset.get_sound(example._asdict())
        sound.context.text = getattr(example, self._transcript_key)
        yield mseb_types.SoundWithTitleAndContext(
            waveform=sound.waveform,
            title_text=example.page_title,
            context_text=getattr(example, self._context_key),
            context=sound.context,
        )

  def examples(self, sub_task: str):
    """Yields a ReasoningSpans per row of `sub_task` matching this locale.

    Each carries the utterance id, the reference span, and the candidate spans
    parsed from the row's `spans` string.
    """
    svq_dataset = self._get_dataset()
    for example in svq_dataset.get_task_data(
        sub_task,
        dtype={
            'locale': str,
            'utt_id': str,
            'span': str,
            'spans': Sequence[str],
        },
    ).itertuples():
      if example.locale == self.locale:
        yield span_reasoning_svq.reasoning_evaluator.ReasoningSpans(
            sound_id=example.utt_id,
            reference_answer=example.span,
            texts=tuple(parse_as_list_of_str(example.spans)),
        )

  def span_lists(self):
    """Yields, per matching row, its candidate spans as a list of Text objects.

    Each Text uses the span string both as its text and as its context id.
    """
    svq_dataset = self._get_dataset()
    for example in svq_dataset.get_task_data(
        'span_reasoning_in_lang', dtype={'locale': str, 'spans': Sequence[str]}
    ).itertuples():
      if example.locale == self.locale:
        yield [
            mseb_types.Text(
                text=span,
                context=mseb_types.TextContextParams(id=span),
            )
            for span in parse_as_list_of_str(example.spans)
        ]


pprint(SVQEnUsSpanInLangReasoning.metadata)

In [None]:
# The examples evaluated by the class.
task = SVQEnUsSpanInLangReasoning()
example = next(task.examples('span_reasoning_in_lang'))
pprint(example)

corpus = task.span_lists()
pprint(next(corpus))

In [None]:
import functools
from mseb import encoder as encoder_lib
from mseb.encoders import gemini_embedding_encoder
from mseb.runner import DirectRunner
from mseb.leaderboard import run_benchmark
from mseb.tasks import reasoning
import importlib

import jaxtyping
import numpy as np
from mseb.evaluators import reasoning_evaluator


!mkdir -p /tmp/cache/reasonings/svq_en_us_span_reasoning_in_lang
!wget https://storage.googleapis.com/mseb_asru_tutorial/svq_en_us_span_reasoning_in_lang_embeddings-00000-of-00001 -O /tmp/cache/reasonings/svq_en_us_span_reasoning_in_lang/embeddings-00000-of-00001
!wget https://storage.googleapis.com/mseb_asru_tutorial/SVQEnUsSpanInLangReasoningColab.gemini_embedding_with_title_and_context_transcript_truth_or_gemini_embedding_embeddings-00000-of-00001 -O /tmp/embeddings-00000-of-00001

# absl treats the first list element as the program name (argv[0]) and only
# parses the entries after it, so a placeholder name must precede the flag for
# the override to take effect.
flags.FLAGS(['colab', '--reasoning_no_answer_threshold=0.8'])


class DummyEncoder(encoder_lib.MultiModalEncoder):
  """Placeholder encoder: every hook does nothing and returns None."""

  def _setup(self):
    """No-op setup."""
    return None

  def _check_input_types(self, batch: Sequence[mseb_types.MultiModalObject]):
    """Accepts any batch without checking it."""
    return None

  def _encode(
      self, batch: Sequence[mseb_types.MultiModalObject]
  ) -> Sequence[mseb_types.MultiModalObject]:
    """Produces no encodings; always returns None."""
    return None


def run_task(encoder_name, encoder, task):
  """Runs `task` against `encoder` with a DirectRunner writing under /tmp.

  Args:
    encoder_name: Name handed to `run_benchmark` to identify the encoder.
    encoder: Encoder instance given to the runner.
    task: Task instance to set up and evaluate.

  Returns:
    The value returned by `run_benchmark` for this encoder/task pair.
  """
  # NOTE(review): output_path='/tmp' presumably lets the runner pick up the
  # embeddings file downloaded to /tmp earlier in this cell -- confirm.
  runner = DirectRunner(encoder=encoder, output_path='/tmp')
  task.setup()
  # Pass the name, not the encoder object, as encoder_name.
  return run_benchmark(encoder_name=encoder_name, runner=runner, task=task)

result = run_task(
    encoder_name="dummy",
    encoder=DummyEncoder(),
    task=SVQEnUsSpanInLangReasoning())
pprint(result)