# Text to Speech in Snowflake
This notebook walks you through the following steps:
* Define a custom model for Facebook's TTS models from their [Massive Multilingual Speech project](https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/).
* Register the model in Snowflake's model registry
* Deploy the model as an inference service using Snowpark Container Services
* Test the deployed inference service
* View the service logs

## Create Connection

In [None]:
# Import python packages
import warnings

import pandas as pd
import streamlit as st

# Silence FutureWarning / UserWarning noise for the rest of the notebook.
warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", UserWarning)

from snowflake.snowpark.context import get_active_session

# Grab the notebook's active Snowpark session; later cells use it for SQL and the registry.
session = get_active_session()

## Define Model
We load one MMS-TTS model per language and host all of them on a single GPU, since each of these models is fairly small.

In [None]:
from transformers import AutoTokenizer, VitsModel, pipeline
import torch

def load_pipeline(model_id):
    """Load a Hugging Face VITS (MMS-TTS) model and wrap it in a text-to-speech pipeline.

    Args:
        model_id: Hugging Face model id, e.g. ``'facebook/mms-tts-eng'``.

    Returns:
        A ``transformers`` ``"text-to-speech"`` pipeline whose model sits on CUDA
        device 0 when a GPU is available, otherwise on the CPU.
    """
    use_cuda = torch.cuda.is_available()
    device = 0 if use_cuda else "cpu"
    # float32 on both CPU and GPU. The previous conditional picked float32 in
    # both branches, so state it directly instead of via a no-op ternary.
    torch_dtype = torch.float32

    model = VitsModel.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    return pipeline(
        "text-to-speech",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch_dtype,
        device=device,
    )

lang_codes = ['eng','deu','fra','nld','hin','kor','pol','por','rus','spa','swe']

# One TTS pipeline per language, keyed by its three-letter language code.
pipelines = {
    code: load_pipeline(f'facebook/mms-tts-{code}')
    for code in lang_codes
}

## Create Custom Model

In [None]:
from snowflake.ml.model import custom_model
import numpy as np
import scipy
import io
import logging
import base64

class TextToSpeechModel(custom_model.CustomModel):
    """Snowflake custom model that synthesizes speech in several languages.

    The model context must hold one text-to-speech pipeline per language under the
    key ``pipeline_<lang_code>`` (e.g. ``pipeline_eng``). ``transform`` returns each
    utterance as a base64-encoded 16-bit PCM WAV file.
    """

    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        warnings.simplefilter("ignore", FutureWarning)

        # Dedicated DEBUG-level logger. Handlers are cleared first so that
        # re-instantiating the model (e.g. re-running the cell) does not stack
        # duplicate handlers on the same named logger.
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(logging.DEBUG)
        self.logger.handlers.clear()

        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        self.logger.addHandler(handler)

    def run_pipeline(self, text, lang_code):
        """Synthesize ``text`` with the pipeline for ``lang_code``.

        Args:
            text: The text to speak.
            lang_code: Language code (case-insensitive), e.g. ``'eng'``. It is
                looked up in the context as ``pipeline_<lang_code>``.

        Returns:
            A UTF-8 base64 string containing a 16-bit PCM WAV file.
        """
        # Explicit submodule import: a bare ``import scipy`` does not guarantee
        # that ``scipy.io.wavfile`` is loaded on every SciPy version.
        from scipy.io import wavfile

        lang_code = lang_code.lower()
        self.logger.debug("Starting text to speech processing language=%s...", lang_code)
        output = self.context[f'pipeline_{lang_code}'](text)
        sampling_rate = output["sampling_rate"]

        # Convert up front so the shape check below also works on non-ndarray audio.
        waveform = np.asarray(output["audio"], dtype=np.float32)
        if waveform.ndim == 2:
            # Downmix to mono by averaging axis 0 (treated as the channel axis).
            waveform = waveform.mean(axis=0)
        # Keep samples in [-1.0, 1.0] so the int16 scaling below cannot overflow.
        waveform = np.clip(waveform, -1.0, 1.0)
        waveform_int16 = (waveform * 32767).astype(np.int16)

        buffer = io.BytesIO()
        wavfile.write(buffer, rate=sampling_rate, data=waveform_int16)
        # base64 makes the binary WAV payload JSON serializable.
        audio_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        self.logger.debug("Finished text to speech processing language=%s...", lang_code)
        return audio_base64

    @custom_model.inference_api
    def transform(self, text_df: pd.DataFrame) -> pd.DataFrame:
        """Synthesize speech for every row of ``text_df``.

        Args:
            text_df: DataFrame with ``TEXT_INPUT`` and ``LANG_CODE`` columns.

        Returns:
            DataFrame with a single ``TEXT_TO_SPEECH_RESULT`` column of base64 WAV
            strings, indexed like ``text_df``. Empty input yields an empty result.
        """
        # Iterate column values directly instead of DataFrame.apply(axis=1): on an
        # empty frame apply returns a DataFrame rather than a Series, which does not
        # fit the single-column result built below.
        audio_results = [
            self.run_pipeline(text, lang_code)
            for text, lang_code in zip(text_df['TEXT_INPUT'], text_df['LANG_CODE'])
        ]
        return pd.DataFrame({'TEXT_TO_SPEECH_RESULT': audio_results}, index=text_df.index)

# Set the model context that includes the model pipelines. Each loaded pipeline is
# exposed as `pipeline_<code>`, the key TextToSpeechModel.run_pipeline looks up.
# The context is derived from `pipelines`, so adding a language to `lang_codes`
# needs no edit here.
mc = custom_model.ModelContext(
    **{f'pipeline_{code}': lang_pipeline for code, lang_pipeline in pipelines.items()}
)
text_to_speech_model = TextToSpeechModel(context=mc)

## Test Model

In [None]:
text = [
    ['It is so awesome to have text to speech capabilities inside Snowflake!','eng'],
    ['Es ist so cool Text in Sprache in Snowflake umwandeln zu können!','deu'],
    ["C'est tellement génial d'avoir des fonctionnalités de synthèse vocale directement dans Snowflake!",'fra']
]
input_df = pd.DataFrame(text, columns=['TEXT_INPUT','LANG_CODE'])

# Exercise the custom model locally, before registering or deploying it.
output_df = text_to_speech_model.transform(input_df)
st.dataframe(output_df)

# Decode and play the third sample (the French sentence).
audio_bytes = output_df['TEXT_TO_SPEECH_RESULT'].iloc[2]
decoded_audio = base64.b64decode(audio_bytes)
st.audio(decoded_audio)

In [None]:
# Listen to results: show each input sentence next to its synthesized audio.
results_df = pd.concat([input_df, output_df], axis=1)
for ix, row in results_df.iterrows():
    audio_bytes = base64.b64decode(row['TEXT_TO_SPEECH_RESULT'])
    with st.chat_message('ai'):
        st.markdown(f"## Lang Code: {row['LANG_CODE']}")
        st.write(row['TEXT_INPUT'])
        st.audio(audio_bytes)

## Register Model

In [None]:
-- Fully qualified so the schema lands in the database the Registry below points at,
-- regardless of the session's current database.
CREATE SCHEMA IF NOT EXISTS AUDIO_INTERFACING_DEMO.MODEL_REGISTRY;

In [None]:
from snowflake.ml.model.model_signature import infer_signature
from snowflake.ml.registry import Registry

# Model registry in AUDIO_INTERFACING_DEMO.MODEL_REGISTRY.
reg = Registry(
    session=session,
    database_name="AUDIO_INTERFACING_DEMO",
    schema_name="MODEL_REGISTRY",
)

# Derive the transform() signature from the local test run's input and output frames.
model_signature = infer_signature(input_data=input_df, output_data=output_df)
print(model_signature)

In [None]:
# Log the custom model (with its per-language pipelines in the context) to the registry.
# NOTE(review): no pip/conda dependencies are passed explicitly. Confirm the deployed
# environment provides everything run_pipeline imports (numpy, scipy).
model_ref = reg.log_model(
    model_name="TEXT_TO_SPEECH",
    version_name="MULTILANGUAGE",
    model=text_to_speech_model,
    signatures={"transform": model_signature},
    options={"use_gpu": True, "cuda_version": "11.8"},
    # Built from lang_codes (same text as before) so it cannot drift from the loaded models.
    comment="facebook/mms-tts-models [" + ",".join(f"'{code}'" for code in lang_codes) + "]",
)

## Create Inference Service

In [None]:
# model_ref is the snowflake.ml.model.ModelVersion returned by reg.log_model.
# Deploy it as a service on the GPU compute pool, with ingress enabled and one GPU requested.
inference_service = model_ref.create_service(
    service_name="AUDIO_INTERFACING_DEMO.PUBLIC.TEXT_TO_SPEECH",
    service_compute_pool="AUDIO_INTERFACE_GPU_POOL",
    gpu_requests='1',
    ingress_enabled=True,
)

In [None]:
# Show the services created from this model version (expected to include TEXT_TO_SPEECH).
model_ref.list_services()

## Test Inference Service

In [None]:
# Same sample sentences as the local test, this time sent to the deployed service.
text = [
    ['It is so awesome to have text to speech capabilities inside Snowflake!','eng'],
    ['Es ist so cool Text in Sprache in Snowflake umwandeln zu können!','deu'],
    ["C'est tellement génial d'avoir des fonctionnalités de synthèse vocale directement dans Snowflake!",'fra']
]
input_df = pd.DataFrame(text, columns=['TEXT_INPUT','LANG_CODE'])

output_df = model_ref.run(
    input_df,
    function_name="transform",
    service_name="AUDIO_INTERFACING_DEMO.PUBLIC.TEXT_TO_SPEECH",
)
st.dataframe(output_df)

# Listen to results: show each input sentence next to the service's audio.
combined_df = pd.concat([input_df, output_df], axis=1)
for ix, row in combined_df.iterrows():
    audio_bytes = base64.b64decode(row['TEXT_TO_SPEECH_RESULT'])
    with st.chat_message('ai'):
        st.markdown(f"## Lang Code: {row['LANG_CODE']}")
        st.write(row['TEXT_INPUT'])
        st.audio(audio_bytes)

## Inference with SQL

In [None]:
-- Call the deployed service's transform function from SQL, pull out the base64
-- WAV string, and decode it to BINARY for playback in the next cell.
SELECT
    AUDIO_INTERFACING_DEMO.PUBLIC.TEXT_TO_SPEECH!transform('Snowflake is awesome.','eng') AS MODEL_OUTPUT,
    MODEL_OUTPUT['TEXT_TO_SPEECH_RESULT']                                                  AS MODEL_OUTPUT_BASE64,
    BASE64_DECODE_BINARY(MODEL_OUTPUT_BASE64)                                              AS MODEL_OUTPUT_BINARY;

In [None]:
# Get the data from the former (SQL) cell and play it.
# NOTE(review): SQL_INTERFACE2 must be the Snowflake-notebook name of that SQL cell;
# cell names are not part of this export, so confirm it after renaming cells.
sql_result_df = SQL_INTERFACE2.to_pandas()
st.audio(sql_result_df['MODEL_OUTPUT_BINARY'].iloc[0])

## View Logs

In [None]:
# Fetch the service logs and print them one line at a time.
logs = session.call(
    'SYSTEM$GET_SERVICE_LOGS',
    'AUDIO_INTERFACING_DEMO.PUBLIC.TEXT_TO_SPEECH',  # service name
    '0',                                              # instance id
    'model-inference',                                # container name
)
print(*logs.split('\n'), sep='\n')

## END