# Speech to Text in Snowflake
This notebook walks you through the following steps:
* Define a custom model that wraps multiple Whisper models
* Register the model in Snowflake's model registry
* Deploy the model as an inference service using Snowpark Container Services
* Test the deployed inference service
* View Service Logs

## Install additional libraries

In [None]:
!pip install soundfile --quiet

## Create Connection

In [None]:
# Import python packages
import streamlit as st
import pandas as pd
import warnings
# Ignore FutureWarning and UserWarning messages for the rest of the notebook
# (presumably to keep the cell output readable).
warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", UserWarning)

# Obtain the Snowpark session handle that the later cells use for the
# registry and for retrieving service logs.
from snowflake.snowpark.context import get_active_session
session = get_active_session()

## Define Model
We'll host multiple Whisper models of different sizes on the same GPU.

In [None]:
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import torch

def load_pipeline(model_id):
    """Build an automatic-speech-recognition pipeline for one checkpoint.

    Loads the seq2seq speech model and its processor for ``model_id`` and
    combines them into a ``transformers`` pipeline. When CUDA is available
    the model is moved to device 0 and uses float16; otherwise it runs on
    the CPU in float32.

    Args:
        model_id: Checkpoint identifier handed to ``from_pretrained``
            (this notebook passes ``openai/whisper-*`` ids).

    Returns:
        The object returned by ``transformers.pipeline``.
    """
    # Pick device and precision together so the two can never disagree.
    if torch.cuda.is_available():
        target_device, precision = 0, torch.float16
    else:
        target_device, precision = "cpu", torch.float32

    speech_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=precision,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    speech_model = speech_model.to(target_device)

    # The processor bundles the tokenizer and the feature extractor that the
    # pipeline needs alongside the model.
    processor = AutoProcessor.from_pretrained(model_id)

    return pipeline(
        "automatic-speech-recognition",
        model=speech_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=precision,
        device=target_device,
    )
    
#model_pipeline = load_pipeline('openai/whisper-large-v3-turbo')

# Load every Whisper size once up front; the dict is keyed by the size suffix
# of the model id ('openai/whisper-<size>').
pipelines = {
    size: load_pipeline(f'openai/whisper-{size}')
    for size in ('tiny', 'base', 'small', 'medium', 'large-v3-turbo')
}

## Create Custom Model

In [None]:
from snowflake.ml.model import custom_model
import soundfile as sf
from scipy.signal import resample
import numpy as np
import io
import logging

class SpeechToTextModel(custom_model.CustomModel):
    """Custom model that transcribes audio with one of several pipelines.

    The pipelines are looked up by name in the ``ModelContext`` given to the
    constructor, so one registered model can serve several Whisper sizes.
    """

    def __init__(self, context: custom_model.ModelContext) -> None:
        """Store the model context and configure a DEBUG-level stream logger.

        Args:
            context: Model context whose entries are the speech-to-text
                pipelines, keyed by model name.
        """
        super().__init__(context)
        warnings.simplefilter("ignore", FutureWarning)

        # Set up a logger. Loggers are shared by name, so existing handlers
        # are cleared first to avoid duplicate lines when the class is
        # instantiated more than once.
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(logging.DEBUG)
        self.logger.handlers.clear()

        handler = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)

    def run_pipeline(self, audio, model):
        """Transcribe a single audio payload with the selected pipeline.

        Args:
            audio: Encoded audio file content readable by ``soundfile``,
                either as raw bytes or as ASCII hex-encoded bytes.
            model: Name of the pipeline in the model context. Hyphens are
                treated as underscores, so ``'large-v3-turbo'`` and
                ``'large_v3_turbo'`` select the same pipeline.

        Returns:
            The ``'text'`` entry of the pipeline result.
        """
        self.logger.debug("Starting speech to text processing ...")

        # The payload may arrive hex-encoded. Try to decode it and otherwise
        # keep it as raw bytes: AttributeError covers inputs without
        # .decode(), ValueError covers non-ASCII or non-hex content
        # (UnicodeDecodeError is a ValueError subclass). Anything else is a
        # real error and is no longer swallowed.
        try:
            audio = bytes.fromhex(audio.decode("ascii"))
        except (AttributeError, ValueError):
            pass

        # Load with soundfile
        audio_data, sample_rate = sf.read(io.BytesIO(audio))

        # Convert multi-channel audio to mono by averaging the channels
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Resample to 16000 Hz if needed
        target_rate = 16000
        if sample_rate != target_rate:
            num_samples = int(len(audio_data) * target_rate / sample_rate)
            audio_data = resample(audio_data, num_samples)
            sample_rate = target_rate

        # Context entries in this notebook are registered as keyword
        # arguments, so their names cannot contain hyphens. Accept the
        # hyphenated model names (e.g. 'large-v3-turbo') as well.
        pipeline_key = model.replace("-", "_")
        result = self.context[pipeline_key](
            {"array": audio_data.astype(np.float32), "sampling_rate": sample_rate}
        )
        text = result['text']
        self.logger.debug("Finished speech to text processing ...")
        return text

    @custom_model.inference_api
    def transform(self, audio_df: pd.DataFrame) -> pd.DataFrame:
        """Transcribe every row of ``audio_df``.

        Args:
            audio_df: Frame with an ``AUDIO_INPUT`` column (audio payload)
                and a ``MODEL`` column (pipeline name) per row.

        Returns:
            A frame with a single ``TRANSCRIPTION`` column that shares the
            index of ``audio_df``.
        """
        # Iterate the two columns directly instead of using
        # DataFrame.apply(axis=1): for an empty frame apply() does not yield
        # a Series, which breaks the construction of the result frame.
        transcriptions = [
            self.run_pipeline(audio, model)
            for audio, model in zip(audio_df['AUDIO_INPUT'], audio_df['MODEL'])
        ]
        return pd.DataFrame({'TRANSCRIPTION': transcriptions}, index=audio_df.index)

# Build the model context from the loaded pipelines. Context entries are passed
# as keyword arguments, so every key has to be a valid Python identifier: the
# hyphens in 'large-v3-turbo' become underscores ('large_v3_turbo').
mc = custom_model.ModelContext(
    **{size.replace('-', '_'): whisper for size, whisper in pipelines.items()}
)

# A context holding a single pipeline would instead be created with
#   custom_model.ModelContext(model_pipeline=model_pipeline)
speech_to_text_model = SpeechToTextModel(context=mc)

## Test Model

In [None]:
# Read the sample recording as raw bytes
with open('harvard.wav', 'rb') as f:
    audio_input = f.read()

# One request row: the audio payload plus the name of the pipeline to run
input_df = pd.DataFrame([[audio_input,'small']], columns=['AUDIO_INPUT','MODEL'])
# Call the model object directly in the notebook, before it is registered
output_df = speech_to_text_model.transform(input_df)

# Read results: show each transcription together with an audio player for its input
for ix, row in pd.concat([input_df, output_df], axis=1).iterrows():
    with st.chat_message('ai'):
        st.write(row['TRANSCRIPTION'])
        st.audio(row['AUDIO_INPUT'])

## Register Model

In [None]:
-- Schema that holds the registered model.
-- NOTE(review): the name is unqualified, so the schema is presumably created in
-- the session's current database; the registry cell expects it in
-- AUDIO_INTERFACING_DEMO. Confirm that this is the current database.
CREATE SCHEMA IF NOT EXISTS MODEL_REGISTRY;

In [None]:
from snowflake.ml.registry import Registry
from snowflake.ml.model.model_signature import infer_signature

# Open the model registry located in AUDIO_INTERFACING_DEMO.MODEL_REGISTRY
reg = Registry(session=session, database_name="AUDIO_INTERFACING_DEMO", schema_name="MODEL_REGISTRY")

# Derive the model signature from the sample input and output frames of the local test
model_signature = infer_signature(input_data=input_df, output_data=output_df)
print(model_signature)

In [None]:
# Log the custom model as version MULTIPLE of SPEECH_TO_TEXT, exposing
# transform() with the signature inferred from the sample frames.
# NOTE(review): the model code also imports scipy, numpy and pandas, which are
# not listed in pip_requirements; presumably they come with the base
# environment. Confirm before relying on it.
model_ref = reg.log_model(
    model_name="SPEECH_TO_TEXT",
    version_name="MULTIPLE",    
    model=speech_to_text_model,
    pip_requirements=['torch','soundfile','transformers==4.55.3'],
    signatures={"transform": model_signature},
    options={"use_gpu": True, "cuda_version": "11.8"},
    comment="openai/whisper [tiny, base, small, medium, large-v3-turbo]"
)

## Create Inference Service

In [None]:
# Create the inference service for the logged model on the compute pool
# AUDIO_INTERFACE_GPU_POOL, requesting one GPU and enabling ingress.
inference_service = model_ref.create_service(
    service_name="AUDIO_INTERFACING_DEMO.PUBLIC.SPEECH_TO_TEXT",
    service_compute_pool="AUDIO_INTERFACE_GPU_POOL",
    ingress_enabled=True,
    gpu_requests='1'
)

In [None]:
# List the services associated with the logged model (shown as the cell output)
model_ref.list_services()

## Test Inference Service

In [None]:
# Send the same sample frame to the deployed service instead of the local model object
output_df = model_ref.run(
    input_df,
    function_name="transform",
    service_name="AUDIO_INTERFACING_DEMO.PUBLIC.SPEECH_TO_TEXT"
)

# Read results: show each transcription together with an audio player for its input
for ix, row in pd.concat([input_df, output_df], axis=1).iterrows():
    with st.chat_message('ai'):
        st.write(row['TRANSCRIPTION'])
        st.audio(row['AUDIO_INPUT'])

## View Logs

In [None]:
# Fetch the logs of the inference service; the extra arguments select
# instance '0' and the container named 'model-inference'.
logs = session.call('SYSTEM$GET_SERVICE_LOGS', 'AUDIO_INTERFACING_DEMO.PUBLIC.SPEECH_TO_TEXT', '0', 'model-inference')
# The result is a single string; print it line by line
for line in logs.split('\n'):
    print(line)

## END