feat: Add SpeakerEmotionInference transformer for generating SSML to augment TTS requests (#1691)

* feat: Add TextToSpeechSSMLGenerator transformer for generating SSML to augment TTS requests. Adding support for SSML to the TTS endpoint.

* Fixing style issues

* More style issues

* ok it finally passes scalastyle

Co-authored-by: Brendan Walsh <brwals@outlook.com>
BrendanWalsh and brwals committed Nov 18, 2022
1 parent 0b96cc5 commit aeb2ff7
Showing 15 changed files with 359 additions and 36 deletions.
@@ -444,7 +444,7 @@ object GenerateThumbnails extends ComplexParamsReadable[GenerateThumbnails] with
class GenerateThumbnails(override val uid: String)
extends CognitiveServicesBase(uid) with HasImageInput
with HasWidth with HasHeight with HasSmartCropping
with HasInternalJsonOutputParser with HasCognitiveServiceInput with HasSetLocation with BasicLogging
with HasCognitiveServiceInput with HasSetLocation with BasicLogging
with HasSetLinkedService {
logClass()

@@ -454,8 +454,6 @@ class GenerateThumbnails(override val uid: String)
new CustomOutputParser().setUDF({ r: HTTPResponseData => r.entity.map(_.content).orNull })
}

override def responseDataType: DataType = BinaryType

def urlPath: String = "/vision/v2.0/generateThumbnail"
}

@@ -0,0 +1,134 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.cognitive

import com.microsoft.azure.synapse.ml.logging.BasicLogging
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.stages.Lambda
import org.apache.http.client.methods.HttpRequestBase
import org.apache.http.entity.{AbstractHttpEntity, StringEntity}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, NamespaceInjections, PipelineModel, Transformer}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{DataType, StringType, StructType}
import spray.json.DefaultJsonProtocol.StringJsonFormat

object SpeakerEmotionInference extends ComplexParamsReadable[SpeakerEmotionInference] with Serializable

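/** Calls the speech service's text-to-speech endpoint with the textanalytics-json output format
  * to infer a role and style for each quoted utterance in the input text, then splices the
  * results back into the text as SSML mstts:express-as tags (see formatSSML below). */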
class SpeakerEmotionInference(override val uid: String)
extends CognitiveServicesBase(uid)
with HasLocaleCol with HasVoiceNameCol with HasTextCol with HasSetLocation
with HasCognitiveServiceInput with HasInternalJsonOutputParser with BasicLogging {
logClass()

def this() = this(Identifiable.randomUID(classOf[SpeakerEmotionInference].getSimpleName))

setDefault(
locale -> Left("en-US"),
voiceName -> Left("en-US-JennyNeural"),
text -> Left(this.uid + "_text"))

def urlPath: String = "cognitiveservices/v1"

override protected def responseDataType: DataType = SpeakerEmotionInferenceResponse.schema

protected val additionalHeaders: Map[String, String] = Map[String, String](
("X-Microsoft-OutputFormat", "textanalytics-json"),
("Content-Type", "application/ssml+xml"))

override protected def inputFunc(schema: StructType): Row => Option[HttpRequestBase] = super.inputFunc(schema)
.andThen(r => r.map(r => {
additionalHeaders.foreach(header => r.setHeader(header._1, header._2))
r
}))

override def setLocation(v: String): this.type = {
val domain = getLocationDomain(v)
setUrl(s"https://$v.tts.speech.microsoft.$domain/cognitiveservices/v1")
}

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { row =>
val body: String =
s"<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xmlns:mstts='https://www.w3.org/2001/mstts'" +
s" xml:lang='en-US'><voice name='Microsoft Server Speech Text to Speech Voice (en-US, JennyNeural)'>" +
s"<mstts:task name ='RoleStyle'/>${getValue(row, text)}</voice></speak>"
Some(new StringEntity(body))
}

private[cognitive] def formatSSML(content: String,
lang: String,
voice: String,
response: SpeakerEmotionInferenceResponse): String = {
// Create a sequence containing all of the non-speech text (text outside of quotes)
// Then zip that with the sequence of all speech text, wrapping speech text in an express-as tag
val speechBounds = response.Conversations.unzip(c => (c.Begin, c.End))
val nonSpeechEnds = speechBounds._1 ++ Seq(content.length)
val nonSpeechBegins = Seq(0) ++ speechBounds._2
val nonSpeechText = (nonSpeechBegins).zip(nonSpeechEnds).map(pair => content.substring(pair._1, pair._2))
val speechText = response.Conversations.map(c => {
s"<mstts:express-as role='${c.Role}' style='${c.Style}'>${c.Content}</mstts:express-as>"
}) ++ Seq("")
val innerText = nonSpeechText.zip(speechText).map(pair => pair._1 + pair._2).reduce(_ + _)

"<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xmlns:mstts='https://www.w3.org/2001/mstts' " +
s"xml:lang='${lang}'><voice name='${voice}'>${innerText}</voice></speak>\n"
}

protected override def getInternalTransformer(schema: StructType): PipelineModel = {
val internalTransformer = super.getInternalTransformer(schema)
NamespaceInjections.pipelineModel(stages = Array[Transformer](
internalTransformer,
new Lambda().setTransform(ds => {
val converter = SpeakerEmotionInferenceResponse.makeFromRowConverter
val newSchema = schema.add(getErrorCol, SpeakerEmotionInferenceError.schema).add(getOutputCol, StringType)
ds.toDF().map(row => {
val ssml = formatSSML(
getValue(row, text),
getValue(row, locale),
getValue(row, voiceName),
converter(row.getAs[Row](row.fieldIndex(getOutputCol)))
)
new GenericRowWithSchema((row.toSeq.dropRight(1) ++ Seq(ssml)).toArray, newSchema): Row
})(RowEncoder({
newSchema
}))
})
))
}
}

trait HasLocaleCol extends HasServiceParams {
val locale = new ServiceParam[String](this,
"locale",
s"The locale of the input text",
isRequired = true)

def setLocale(v: String): this.type = setScalarParam(locale, v)

def setLocaleCol(v: String): this.type = setVectorParam(locale, v)
}

trait HasVoiceNameCol extends HasServiceParams {
val voiceName = new ServiceParam[String](this,
"voiceName",
s"The name of the voice used for synthesis",
isRequired = true)

def setVoiceName(v: String): this.type = setScalarParam(voiceName, v)

def setVoiceNameCol(v: String): this.type = setVectorParam(voiceName, v)
}

trait HasTextCol extends HasServiceParams {
val text = new ServiceParam[String](this,
"text",
s"The text to annotate with inferred emotion",
isRequired = true)

def setText(v: String): this.type = setScalarParam(text, v)

def setTextCol(v: String): this.type = setVectorParam(text, v)
}
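The splicing that formatSSML performs above can be illustrated with a small standalone sketch. It mirrors the same begin/end bookkeeping; the local Conversation case class copies the fields of SSMLConversation from this change, while the sample text, span offsets, and the "female"/"cheerful" role and style values are invented for the example rather than taken from the service.

// Standalone sketch (not part of the commit) mirroring the splicing done by formatSSML.
object FormatSSMLSketch extends App {
  // Local mirror of SSMLConversation, so the sketch has no dependencies
  case class Conversation(Begin: Int, End: Int, Content: String, Role: String, Style: String)

  def splice(content: String, conversations: Seq[Conversation]): String = {
    // Non-speech text sits between the recognized spans; pair each piece with the
    // express-as-wrapped speech span that follows it.
    val (begins, ends) = conversations.map(c => (c.Begin, c.End)).unzip
    val nonSpeech = (Seq(0) ++ ends).zip(begins ++ Seq(content.length))
      .map { case (b, e) => content.substring(b, e) }
    val speech = conversations.map(c =>
      s"<mstts:express-as role='${c.Role}' style='${c.Style}'>${c.Content}</mstts:express-as>") ++ Seq("")
    nonSpeech.zip(speech).map { case (plain, tagged) => plain + tagged }.mkString
  }

  val text = "\"I'm so excited,\" she said."
  val spans = Seq(Conversation(0, 17, "\"I'm so excited,\"", "female", "cheerful"))
  println(splice(text, spans))
  // <mstts:express-as role='female' style='cheerful'>"I'm so excited,"</mstts:express-as> she said.
}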
@@ -59,7 +59,6 @@ object SpeechFormat extends DefaultJsonProtocol {
jsonFormat9(TranscriptionResponse.apply)
implicit val TranscriptionParticipantFormat: RootJsonFormat[TranscriptionParticipant] =
jsonFormat3(TranscriptionParticipant.apply)

}

object SpeechSynthesisError extends SparkBindings[SpeechSynthesisError] {
@@ -69,3 +68,19 @@ object SpeechSynthesisError extends SparkBindings[SpeechSynthesisError] {
}

case class SpeechSynthesisError(errorCode: String, errorDetails: String, errorReason: String)

object SpeakerEmotionInferenceError extends SparkBindings[SpeakerEmotionInferenceError]

case class SpeakerEmotionInferenceError(errorCode: String, errorDetails: String)

case class SSMLConversation(Begin: Int,
End: Int,
Content: String,
Role: String,
Style: String)

object SSMLConversation extends SparkBindings[SSMLConversation]

case class SpeakerEmotionInferenceResponse(IsValid: Boolean, Conversations: Seq[SSMLConversation])

object SpeakerEmotionInferenceResponse extends SparkBindings[SpeakerEmotionInferenceResponse]
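As a quick illustration of the SparkBindings pattern these new response types follow (schema backs responseDataType, and the row converters are used in getInternalTransformer and TextToSpeech.transform), a hedged round-trip sketch; the field values are made up and the snippet assumes synapseml is on the classpath.

// Hedged sketch of the SparkBindings helpers the new response types rely on; values are illustrative.
import com.microsoft.azure.synapse.ml.cognitive.{SSMLConversation, SpeakerEmotionInferenceResponse}

object SpeakerEmotionBindingsSketch extends App {
  // The Spark schema that backs SpeakerEmotionInference.responseDataType
  println(SpeakerEmotionInferenceResponse.schema.treeString)

  val response = SpeakerEmotionInferenceResponse(
    IsValid = true,
    Conversations = Seq(SSMLConversation(0, 17, "\"I'm so excited,\"", "female", "cheerful")))

  // Round-trip through Spark rows with the same converters used elsewhere in this change
  // (makeFromRowConverter in getInternalTransformer, makeToRowConverter in TextToSpeech.transform)
  val toRow = SpeakerEmotionInferenceResponse.makeToRowConverter
  val fromRow = SpeakerEmotionInferenceResponse.makeFromRowConverter
  assert(fromRow(toRow(response)) == response)
}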
@@ -7,8 +7,10 @@ import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
import com.microsoft.azure.synapse.ml.io.http.{HasErrorCol, HasURL}
import com.microsoft.azure.synapse.ml.logging.BasicLogging
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.cognitiveservices.speech.{SpeechConfig, SpeechSynthesisCancellationDetails,
SpeechSynthesisOutputFormat, SpeechSynthesizer}
import com.microsoft.cognitiveservices.speech.{
SpeechConfig, SpeechSynthesisCancellationDetails,
SpeechSynthesisOutputFormat, SpeechSynthesisResult, SpeechSynthesizer
}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.{IOUtils => HUtils}
import org.apache.spark.ml.param.{Param, ParamMap}
@@ -62,7 +64,6 @@ class TextToSpeech(override val uid: String)
s"The locale of the input text",
isRequired = true)


def setLocale(v: String): this.type = setScalarParam(locale, v)

def setLocaleCol(v: String): this.type = setVectorParam(locale, v)
@@ -106,6 +107,20 @@ class TextToSpeech(override val uid: String)

def getOutputFileCol: String = $(outputFileCol)

val useSSML = new ServiceParam[Boolean](this,
"useSSML",
s"whether to interpret the provided text input as SSML (Speech Synthesis Markup Language). " +
"The default value is false.",
isRequired = false)

def setUseSSML(v: Boolean): this.type = setScalarParam(useSSML, v)

def setUseSSMLCol(v: String): this.type = setVectorParam(useSSML, v)

def speechGenerator(synth: SpeechSynthesizer, shouldUseSSML: Boolean, txt: String): SpeechSynthesisResult = {
if (shouldUseSSML) synth.SpeakSsml(txt) else synth.SpeakText(txt)
}

override def transform(dataset: Dataset[_]): DataFrame = {
val hconf = new SerializableConfiguration(dataset.sparkSession.sparkContext.hadoopConfiguration)
val toRow = SpeechSynthesisError.makeToRowConverter
@@ -116,8 +131,9 @@
getValueOpt(row, outputFormat).foreach(format =>
config.setSpeechSynthesisOutputFormat(SpeechSynthesisOutputFormat.valueOf(format)))

val (errorOpt, data) = using(new SpeechSynthesizer(config, null)) { synth => //scalastyle:ignore null
val res = synth.SpeakText(getValue(row, text))
val (errorOpt, data) = using(new SpeechSynthesizer(config, null)) { synth => //scalastyle:ignore null
val res = speechGenerator(synth, getValueOpt(row, useSSML)
.getOrElse(false), getValueOpt(row, text).getOrElse(""))
val error = if (res.getReason.name() == "SynthesizingAudioCompleted") {
None
} else {
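A usage sketch of the new useSSML flag end-to-end, chaining SpeakerEmotionInference into TextToSpeech. This is not part of the commit: the key, region, column names, voice, and output path are placeholders, and setters that do not appear in this diff (setSubscriptionKey, setOutputCol, setErrorCol, and TextToSpeech's setLocation, setTextCol, setVoiceName, setOutputFileCol) are assumed to come from the shared cognitive-services base traits.

// Usage sketch (not part of this commit). Keys, region, column names, voice, and the
// output path are placeholders; setters not visible in this diff are assumptions.
import com.microsoft.azure.synapse.ml.cognitive.{SpeakerEmotionInference, TextToSpeech}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

object SSMLSynthesisSketch extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  import spark.implicits._

  val df = Seq("\"Nice to meet you,\" she said warmly.").toDF("text")

  // Infer roles/styles for quoted speech and emit SSML into the "ssml" column
  val emotion = new SpeakerEmotionInference()
    .setSubscriptionKey(sys.env("SPEECH_KEY")) // placeholder key
    .setLocation("eastus")                     // placeholder region
    .setTextCol("text")
    .setOutputCol("ssml")
    .setErrorCol("emotionErrors")

  // Synthesize the SSML directly, which the new useSSML flag enables
  val tts = new TextToSpeech()
    .setSubscriptionKey(sys.env("SPEECH_KEY"))
    .setLocation("eastus")
    .setLocale("en-US")                        // required params, even with SSML input
    .setVoiceName("en-US-JennyNeural")
    .setUseSSML(true)
    .setTextCol("ssml")
    .setOutputFileCol("audioFile")
    .setErrorCol("ttsErrors")

  // Each row needs a destination path in the outputFileCol column
  val input = emotion.transform(df).withColumn("audioFile", lit("/tmp/tts-output.mp3"))
  tts.transform(input).show(truncate = false)
}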
@@ -3,7 +3,6 @@

package com.microsoft.azure.synapse.ml.cognitive.split1

import com.microsoft.azure.synapse.ml.Secrets
import com.microsoft.azure.synapse.ml.cognitive._
import com.microsoft.azure.synapse.ml.core.spark.FluentAPI._
import com.microsoft.azure.synapse.ml.core.test.base.{Flaky, TestBase}
@@ -14,10 +13,6 @@ import org.apache.spark.sql.functions.{col, typedLit}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.scalactic.Equality

trait CognitiveKey {
lazy val cognitiveKey: String = sys.env.getOrElse("COGNITIVE_API_KEY", Secrets.CognitiveApiKey)
}

trait OCRUtils extends TestBase {

import spark.implicits._
@@ -5,7 +5,6 @@ package com.microsoft.azure.synapse.ml.cognitive.split2

import com.microsoft.azure.synapse.ml.Secrets
import com.microsoft.azure.synapse.ml.cognitive._
import com.microsoft.azure.synapse.ml.cognitive.split1.CognitiveKey
import com.microsoft.azure.synapse.ml.core.test.base.TestBase
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
import com.microsoft.azure.synapse.ml.io.http.RESTHelpers._
