# About this example

This example shows how you can run [OpenAI Whisper](https://github.com/openai/whisper) to perform speech-to-text while keeping user data private.

By using [BlindAI](https://github.com/mithril-security/blindai), users can send their data to the AI for analysis without having to fear privacy leaks.

[Whisper](https://openai.com/blog/whisper/) is a Transformer model developed by OpenAI for speech-to-text. You can learn more about it in the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf).

This tutorial involves three steps:
- Export the Whisper model to an ONNX file. This step is optional, as we provide a pre-exported model.
- Upload the model inside BlindAI.
- Query the model.

# Installing dependencies

First, we will install the packages needed to run this sample.

## Installing Whisper

The commands below will install the Python packages needed to use Whisper models and evaluate the transcription results.

In [None]:
! pip install git+https://github.com/openai/whisper.git
! pip install jiwer

## Installing BlindAI

Install the latest version of BlindAI.

In [None]:
!pip install blindai

# (Optional) Preparing the model

Here we will use OpenAI Whisper. The goal of this step is to get a Whisper model for speech-to-text inside an ONNX file, as BlindAI can only serve ONNX models.

Because the Whisper model only outputs information about the next most likely token, and not a whole sentence, we need to create a kind of meta model that calls it repeatedly to generate a sequence of tokens.

This is why we will create a `MetaModel` that leverages a `Whisper` model to output a sequence of tokens, and then export it.
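Conceptually, the meta model wraps the single-step model in an autoregressive loop: feed in the tokens generated so far, suppress disallowed tokens, append the most likely next token, and stop at the end-of-transcript (EOT) token or after `sample_len` steps. Below is a minimal pure-Python sketch of greedy decoding under those assumptions; `toy_logits`, the vocabulary size and all token IDs are hypothetical stand-ins, not Whisper's.

```python
import math
from typing import Callable, List, Sequence

EOT = 0  # hypothetical end-of-transcript token ID


def greedy_decode(
    next_logits: Callable[[List[int]], List[float]],
    initial_tokens: Sequence[int],
    sample_len: int,
    suppress: Sequence[int] = (),
) -> List[int]:
    """Greedy autoregressive decoding: repeatedly append the highest-scoring token."""
    tokens = list(initial_tokens)
    for _ in range(sample_len):
        logits = list(next_logits(tokens))  # scores for the next token only
        for t in suppress:                  # logit filter: forbid these tokens
            logits[t] = -math.inf
        next_token = max(range(len(logits)), key=logits.__getitem__)
        tokens.append(next_token)
        if next_token == EOT:               # the sequence is complete
            break
    return tokens


def toy_logits(tokens: List[int]) -> List[float]:
    """Hypothetical 'model': prefers token len(tokens) + 1, then EOT once 5 tokens exist."""
    vocab_size = 10
    logits = [0.0] * vocab_size
    target = EOT if len(tokens) >= 5 else len(tokens) + 1
    logits[target] = 1.0
    return logits


print(greedy_decode(toy_logits, initial_tokens=[9], sample_len=20))  # [9, 2, 3, 4, 5, 0]
```

Whisper's own `GreedyDecoder` also supports sampling at a temperature above 0 and accumulates the summed log-probabilities that are later used to rank candidates.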

We detail the process below, but you can go directly to the [Deployment on BlindAI](#deployment-on-blindai) section: a pre-exported model is downloaded there, so there is no need to generate the ONNX file for Whisper speech-to-text yourself.

The first step is to import the dependencies, load some LibriSpeech test data, and load the model.

In [None]:
import os
import numpy as np

try:
    import tensorflow  # required in Colab to avoid protobuf compatibility issues
except ImportError:
    pass

import torch
import pandas as pd
import whisper
import torchaudio

from tqdm.notebook import tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
class LibriSpeech(torch.utils.data.Dataset):
    """
    A simple class to wrap LibriSpeech and trim/pad the audio to 30 seconds.
    It will drop the last few seconds of a very small portion of the utterances.
    """
    def __init__(self, split="test-clean", device=DEVICE):
        self.dataset = torchaudio.datasets.LIBRISPEECH(
            root=os.path.expanduser("~/.cache"),
            url=split,
            download=True,
        )
        self.device = device

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        audio, sample_rate, text, _, _, _ = self.dataset[item]
        assert sample_rate == 16000
        audio = whisper.pad_or_trim(audio.flatten()).to(self.device)
        mel = whisper.log_mel_spectrogram(audio)
        
        return (mel, text)
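Whisper processes audio in fixed 30-second windows, which at LibriSpeech's 16 kHz sample rate is 480,000 samples; this is why the dataset wrapper pads or trims every utterance. The sketch below mimics what `whisper.pad_or_trim` does to a 1-D signal. `pad_or_trim_list` is a hypothetical helper working on plain Python lists, whereas the real function operates on NumPy arrays or PyTorch tensors.

```python
from typing import List

SAMPLE_RATE = 16_000                    # Hz, the rate Whisper expects
CHUNK_LENGTH = 30                       # seconds per window
N_SAMPLES = SAMPLE_RATE * CHUNK_LENGTH  # 480,000 samples


def pad_or_trim_list(samples: List[float], length: int = N_SAMPLES) -> List[float]:
    """Hypothetical list-based analogue of whisper.pad_or_trim for 1-D audio."""
    if len(samples) >= length:
        return samples[:length]                       # drop audio past 30 s
    return samples + [0.0] * (length - len(samples))  # zero-pad short clips


short_clip = [0.5] * (SAMPLE_RATE * 12)  # 12 s of (fake) audio
long_clip = [0.5] * (SAMPLE_RATE * 35)   # 35 s: the last 5 s are dropped
print(len(pad_or_trim_list(short_clip)), len(pad_or_trim_list(long_clip)))  # 480000 480000
```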

In [None]:
dataset = LibriSpeech("test-clean")
loader = torch.utils.data.DataLoader(dataset, batch_size=16)

In [None]:
model = whisper.load_model("tiny.en")
print(
    f"Model is {'multilingual' if model.is_multilingual else 'English-only'} "
    f"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters."
)

Here we define the `MetaModel` PyTorch class mentioned earlier, which contains all the logic needed to perform speech-to-text with token generation.

The code is inspired by the `DecodingTask` class from [Whisper](https://github.com/openai/whisper/blob/main/whisper/decoding.py).

We do this because BlindAI only supports ONNX models, so we need to package our inference logic in ONNX format before sending it to a secure enclave.

In [None]:
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical

from whisper.audio import CHUNK_LENGTH
from whisper.tokenizer import Tokenizer, get_tokenizer
from whisper.utils import compression_ratio

if TYPE_CHECKING:
    from whisper.model import Whisper

from whisper.decoding import *

import torch.nn as nn

class MetaModel(nn.Module):
    inference: Inference
    sequence_ranker: SequenceRanker
    decoder: TokenDecoder
    logit_filters: List[LogitFilter]

    def __init__(self, model: "Whisper", options: DecodingOptions):
        super().__init__()  # the class is MetaModel; NNDecodingTask is undefined here
        self.model = model

        language = options.language or "en"
        tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)

        self.n_group: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # inference: implements the forward pass through the decoder, including kv caching
        self.inference = PyTorchInference(model, len(self.initial_tokens))

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

        # decoder: implements how to select the next tokens, given the autoregressive distribution
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(
                options.beam_size, tokenizer.eot, self.inference, options.patience
            )
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

        # logit filters: applies various rules to suppress or penalize certain tokens
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
            self.logit_filters.append(
                ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
            )

    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

        return options

    def _get_initial_tokens(self) -> Tuple[int]:
        tokens = list(self.sot_sequence)
        prefix = self.options.prefix
        prompt = self.options.prompt

        if prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
            )
            if self.sample_len is not None:
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
            )
            tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens

        return tuple(tokens)

    def _get_suppress_tokens(self) -> Tuple[int]:
        suppress_tokens = self.options.suppress_tokens

        if isinstance(suppress_tokens, str):
            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]

        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        elif suppress_tokens is None or len(suppress_tokens) == 0:
            suppress_tokens = []  # interpret empty string as an empty list
        else:
            assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

        suppress_tokens.extend(
            [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
        )
        if self.tokenizer.no_speech is not None:
            # no-speech probability is collected separately
            suppress_tokens.append(self.tokenizer.no_speech)

        return tuple(sorted(set(suppress_tokens)))

    def _get_audio_features(self, mel: Tensor):
        if self.options.fp16:
            mel = mel.half()

        if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
            # encoded audio features are given; skip audio encoding
            audio_features = mel
        else:
            audio_features = self.model.encoder(mel)

        if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
            raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")

        return audio_features

    def _detect_language(self, audio_features: Tensor, tokens: Tensor):
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

        return languages, lang_probs

    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
        assert audio_features.shape[0] == tokens.shape[0]
        n_batch = tokens.shape[0]
        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                # apply the logit filters, e.g. for suppressing or applying penalty to
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs

    def forward(self, mel: Tensor):
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel.shape[0]

        audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        # detect language if requested, overwriting the language token
        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id":
            return [
                DecodingResult(audio_features=features, language=language, language_probs=probs)
                for features, language, probs in zip(audio_features, languages, language_probs)
            ]

        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
        audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]
        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        return tokens, sum_logprobs

Now that we have a `MetaModel` available, we can export it. Because `torch.onnx.export` traces the model behind the scenes, we need to provide an example input for it to run a prediction on.
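Tracing runs the model once on the example input and records the tensor operations that actually execute; plain Python control flow is not recorded. In `_main_loop`, the `for` loop is therefore unrolled, and the early `break` is resolved for the example audio only. The toy sketch below (a hypothetical illustration, not PyTorch's tracer) shows the effect: the recorded op list has whatever length the example input produced.

```python
from typing import Callable, List


def decode_loop(steps_until_eot: int, ops: List[str], sample_len: int = 20) -> None:
    """Stand-in for _main_loop: one 'op' per decoding step, with an early exit."""
    for i in range(sample_len):
        ops.append(f"decoder_step_{i}")
        if i + 1 >= steps_until_eot:  # data-dependent branch, like `if completed: break`
            break


def toy_trace(fn: Callable[[int, List[str]], None], example_input: int) -> List[str]:
    """Toy tracer (hypothetical): run fn once and keep the ops it executed."""
    recorded: List[str] = []
    fn(example_input, recorded)
    return recorded


graph = toy_trace(decode_loop, example_input=7)
print(len(graph))  # 7: the loop was unrolled for this example input only
```

In this toy, replaying `graph` would always execute 7 steps, whatever a new input needs. With the real exporter, PyTorch typically emits a `TracerWarning` when a tensor is converted to a Python boolean, a hint that such a branch has been frozen, so it is worth validating the exported ONNX file on several clips.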

In [None]:
mels, text = next(iter(loader))
mel = mels[0].unsqueeze(0)
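For `tiny.en`, the example input `mel` should have shape `(1, 80, 3000)`: 80 mel bins, and one frame every 160 samples (10 ms) across the 30-second window. The encoder's second convolution has stride 2, which halves the time axis to `n_audio_ctx = 1500` positions; that is where the "usually 0.02 seconds" timestamp precision in `MetaModel.__init__` comes from. A quick check of the arithmetic, with the constants hard-coded as assumptions mirroring Whisper's `whisper.audio` module:

```python
SAMPLE_RATE = 16_000  # samples per second
HOP_LENGTH = 160      # samples between successive STFT frames (10 ms)
CHUNK_LENGTH = 30     # seconds per Whisper window
N_MELS = 80           # mel bins used by tiny.en
CONV_STRIDE = 2       # the encoder's second convolution downsamples time by 2

n_samples = SAMPLE_RATE * CHUNK_LENGTH  # 480,000 samples
n_frames = n_samples // HOP_LENGTH      # 3,000 mel frames
n_audio_ctx = n_frames // CONV_STRIDE   # 1,500 encoder positions
precision = CHUNK_LENGTH / n_audio_ctx  # seconds per encoder position

print((1, N_MELS, n_frames), n_audio_ctx, precision)  # (1, 80, 3000) 1500 0.02
```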

Now we just need to export this PyTorch model to ONNX before feeding it to BlindAI.

**Caution**: the code below may take a while; it took 20 minutes on a GCP n1-standard-4 VM.

You can uncomment the cell to run it yourself, but it's easier to just pull the pre-exported model we provide.

In [None]:
# sample_len = 20

# options = whisper.DecodingOptions(language="en", without_timestamps=True,
#                                   fp16=False, sample_len=sample_len)

# metamodel = MetaModel(model, options)

# model_name = f"whisper_tiny_en_{sample_len}_tokens.onnx"
# torch.onnx.export(metamodel, mel, model_name,
#                   export_params=True, opset_version=12,
#                   operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

# Deployment on BlindAI

If you did not generate the ONNX model in the previous section, you can download our pre-exported one with:

In [None]:
!wget --quiet --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1wqg1F0UkEdm3KB7n1BjfRLHnzKU2-G5S' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1wqg1F0UkEdm3KB7n1BjfRLHnzKU2-G5S" -O whisper_tiny_en_20_tokens.onnx && rm -rf /tmp/cookies.txt
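The nested `wget` is designed to get past Google Drive's download warning for large files: it first fetches the warning page, extracts the `confirm=` token with `sed`, then requests the file again with that token. The same extraction in Python, on a made-up HTML snippet (the page content is hypothetical; the regex is the one from the `sed` expression):

```python
import re
from typing import List

CONFIRM_RE = re.compile(r"confirm=([0-9A-Za-z_]+)")


def extract_confirm_tokens(html: str) -> List[str]:
    """Mimic the sed expression in the wget command: one token per matching line."""
    tokens = []
    for line in html.splitlines():
        found = CONFIRM_RE.findall(line)
        if found:
            tokens.append(found[-1])  # sed's greedy leading ".*" keeps the last match on a line
    return tokens


page = '<a href="/uc?export=download&amp;confirm=t3_Xy&amp;id=abc">Download anyway</a>'
print(extract_confirm_tokens(page))  # ['t3_Xy']
```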

Now we can upload the model to BlindAI Cloud. To do so, make sure you have an API key.

You can get one on the [Mithril Cloud](https://cloud.mithrilsecurity.io/).

You might get an error if the name you want to use is already taken, as models are uniquely identified by their `model_id`. We will implement namespaces soon to avoid that. Meanwhile, you will have to choose a unique ID. The example below uploads your model under a unique name:

In [None]:
import blindai
import uuid

api_key = "YOUR_API_KEY" # Enter your API key here
model_id = "whisper-" + str(uuid.uuid4())

# Upload the ONNX file under the chosen model ID
with blindai.Connection(api_key=api_key) as client:
    response = client.upload_model("whisper_tiny_en_20_tokens.onnx", model_id=model_id)

Your model should now be loaded inside a secure enclave managed by BlindAI Cloud! All that remains is to send it data to be analyzed securely.

# Sending data for confidential prediction

Now it's time to check that it works live!

We will prepare some input and send it to the model running inside BlindAI's secure enclave.

The ONNX model we uploaded only outputs raw tokens, so we need to post-process its output to display text.

We will do this with a `PostProcessingTask` class, once again inspired by the `DecodingTask` class from Whisper.
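The core of the post-processing is simple: each output sequence starts with the initial prompt tokens (`sample_begin` of them) and may continue past the end-of-transcript (EOT) token, so it is cut between the first sampled token and the first EOT before being detokenized. A list-based sketch with hypothetical token IDs; like Whisper's `GreedyDecoder.finalize`, it appends one EOT so the search always succeeds.

```python
from typing import List


def strip_prompt_and_eot(tokens: List[int], sample_begin: int, eot: int) -> List[int]:
    """Keep only the sampled tokens: from sample_begin up to (excluding) the first EOT."""
    padded = tokens + [eot]                # guarantee at least one EOT, like finalize()
    end = padded.index(eot, sample_begin)  # first EOT after the prompt tokens
    return padded[sample_begin:end]


EOT = 99         # hypothetical end-of-transcript ID
prompt = [1, 2]  # hypothetical initial tokens, so sample_begin = 2
output = prompt + [11, 12, 13, EOT, EOT]
print(strip_prompt_and_eot(output, sample_begin=len(prompt), eot=EOT))  # [11, 12, 13]
```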

In [None]:
from whisper.decoding import *

class PostProcessingTask:
    sequence_ranker: SequenceRanker
    decoder: TokenDecoder

    def __init__(self, n_text_ctx, is_multilingual, options: DecodingOptions):

        language = options.language or "en"
        tokenizer = get_tokenizer(is_multilingual, language=language, task=options.task)
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)

        self.n_group: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = n_text_ctx
        self.sample_len: int = options.sample_len or n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
        
        self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

        return options

    def _get_initial_tokens(self) -> Tuple[int]:
        tokens = list(self.sot_sequence)
        prefix = self.options.prefix
        prompt = self.options.prompt

        if prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
            )
            if self.sample_len is not None:
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
            )
            tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens

        return tuple(tokens)

    @torch.no_grad()
    def run(self, tokens: Tensor, sum_logprobs) -> List[DecodingResult]:
        tokenizer: Tokenizer = self.tokenizer
        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[Tensor]] = [
            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
        ]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

        return texts
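When several candidates exist per audio clip (beam search or best-of-n sampling), `sequence_ranker.rank` picks one per group. The sketch below reproduces, as we understand it, the scoring of Whisper's `MaximumLikelihoodRanker`: the summed log-probability divided by the sequence length, or by the Google NMT length penalty `((5 + length) / 6) ** alpha` when `length_penalty` is set. With the greedy decoder used here, each group holds a single candidate, so the ranking is trivial. The log-probabilities below are made up.

```python
from typing import List, Optional


def rank(
    lengths: List[List[int]],
    sum_logprobs: List[List[float]],
    length_penalty: Optional[float] = None,
) -> List[int]:
    """Return, for each audio clip, the index of its best-scoring candidate."""

    def score(logprob: float, length: int) -> float:
        if length_penalty is None:
            penalty = length                                # average log-prob per token
        else:
            penalty = ((5 + length) / 6) ** length_penalty  # Google NMT length penalty
        return logprob / penalty

    best = []
    for group_lengths, group_logprobs in zip(lengths, sum_logprobs):
        scores = [score(p, n) for p, n in zip(group_logprobs, group_lengths)]
        best.append(max(range(len(scores)), key=scores.__getitem__))
    return best


# One clip, two candidates: a short one and a longer, less likely one overall.
print(rank([[4, 10]], [[-4.0, -6.0]]))       # [1]: -0.6 per token beats -1.0
print(rank([[4, 10]], [[-4.0, -6.0]], 0.0))  # [0]: alpha = 0 disables length normalization
```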

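We will now set up the post-processing with the same decoding options as the exported model (English, no timestamps, 20 tokens).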

In [None]:
# predict without timestamps for short-form transcription
sample_len = 20

options = whisper.DecodingOptions(language="en", without_timestamps=True, sample_len=sample_len)

n_text_ctx = model.dims.n_text_ctx or 448
is_multilingual = model.is_multilingual or False

postprocess = PostProcessingTask(n_text_ctx, is_multilingual, options)

In [None]:
import blindai

with blindai.Connection(api_key=api_key) as client:
    # Send data to the model
    prediction = client.predict(model_id, mels[3].unsqueeze(0))

In [None]:
prediction.inference_time

In [None]:
tokens, sum_logprobs = prediction.output
tokens = tokens.as_torch().unsqueeze(0).unsqueeze(0)
sum_logprobs = sum_logprobs.as_torch().unsqueeze(0)

texts = postprocess.run(tokens, sum_logprobs)
texts

Et voilà! We have applied a state-of-the-art speech-to-text model without ever exposing the data in the clear to the people operating the service!

If you liked this example, feel free to drop a star on our [GitHub](https://github.com/mithril-security/blindai) and chat with us on our [Discord](https://discord.gg/TxEHagpWd4)!