diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/BingImageSearch.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/BingImageSearch.scala
index 776b84524d..60e2aa00a3 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/BingImageSearch.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/BingImageSearch.scala
@@ -3,27 +3,25 @@
package com.microsoft.ml.spark.cognitive
-import java.net.URL
import com.microsoft.ml.spark.core.utils.AsyncUtils
import com.microsoft.ml.spark.logging.BasicLogging
import com.microsoft.ml.spark.stages.Lambda
import org.apache.commons.io.IOUtils
import org.apache.http.client.methods.{HttpGet, HttpRequestBase}
import org.apache.http.entity.AbstractHttpEntity
-import org.apache.spark.binary.ConfUtils
import org.apache.spark.injections.UDFUtils
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.param.ServiceParam
import org.apache.spark.ml.util._
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.functions.{col, explode, udf}
+import org.apache.spark.sql.functions.{col, explode}
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Row}
+import spray.json.DefaultJsonProtocol._
+import java.net.URL
import scala.concurrent.duration.Duration
import scala.concurrent.{ExecutionContext, Future}
-import spray.json.DefaultJsonProtocol._
object BingImageSearch extends ComplexParamsReadable[BingImageSearch] with Serializable {
@@ -68,13 +66,15 @@ object BingImageSearch extends ComplexParamsReadable[BingImageSearch] with Seria
class BingImageSearch(override val uid: String)
extends CognitiveServicesBase(uid)
- with HasCognitiveServiceInput with HasInternalJsonOutputParser with BasicLogging {
+ with HasCognitiveServiceInput with HasInternalJsonOutputParser with BasicLogging with HasSetLinkedService {
logClass()
override protected lazy val pyInternalWrapper = true
def this() = this(Identifiable.randomUID("BingImageSearch"))
+ def urlPath: String = "/v7.0/images/search"
+
setDefault(url -> "https://api.bing.microsoft.com/v7.0/images/search")
override def prepareMethod(): HttpRequestBase = new HttpGet()
diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/CognitiveServiceBase.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/CognitiveServiceBase.scala
index 26f9efaa40..8cacdf8292 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/CognitiveServiceBase.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/CognitiveServiceBase.scala
@@ -241,6 +241,19 @@ trait HasSetLinkedService extends Wrappable with HasURL with HasSubscriptionKey
}
}
+trait HasSetLinkedServiceUsingLocation extends HasSetLinkedService with HasSetLocation {
+ // Variant of HasSetLinkedService for services that are configured by
+ // (location, key) rather than (serviceName, key): resolves both values from
+ // the named Synapse linked service and applies them via setLocation /
+ // setSubscriptionKey. Mirrors the block previously inlined in SpeechToTextSDK.
+ override def setLinkedService(v: String): this.type = {
+ // mssparkutils is only present on Synapse clusters, so it is looked up
+ // reflectively instead of being referenced at compile time.
+ val classPath = "mssparkutils.cognitiveService"
+ val linkedServiceClass = ScalaClassLoader(getClass.getClassLoader).tryToLoadClass(classPath)
+ // NOTE(review): the .get calls below throw NoSuchElementException when the
+ // class is absent (i.e. outside Synapse) — confirm that failure mode is intended.
+ val locationMethod = linkedServiceClass.get.getMethod("getLocation", v.getClass)
+ val keyMethod = linkedServiceClass.get.getMethod("getKey", v.getClass)
+ val location = locationMethod.invoke(linkedServiceClass.get, v).toString
+ val key = keyMethod.invoke(linkedServiceClass.get, v).toString
+ setLocation(location)
+ setSubscriptionKey(key)
+ }
+}
+
trait HasSetLocation extends Wrappable with HasURL with HasUrlPath {
override def pyAdditionalMethods: String = super.pyAdditionalMethods + {
"""
@@ -277,6 +290,12 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform
assert(badColumns.isEmpty,
s"Could not find dynamic columns: $badColumns in columns: ${schema.fieldNames.toSet}")
+ val missingRequiredParams = this.getRequiredParams.filter {
+ p => this.get(p).isEmpty && this.getDefault(p).isEmpty
+ }
+ assert(missingRequiredParams.isEmpty,
+ s"Missing required params: ${missingRequiredParams.map(s => s.name).mkString("(", ", ", ")")}")
+
val dynamicParamCols = getVectorParamMap.values.toList.map(col) match {
case Nil => Seq(lit(false).alias("placeholder"))
case l => l
diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/ComputerVision.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/ComputerVision.scala
index 74f3a546ff..238242825a 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/ComputerVision.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/ComputerVision.scala
@@ -312,8 +312,8 @@ class RecognizeText(override val uid: String)
"printed text recognition is performed. If 'Handwritten' is specified," +
" handwriting recognition is performed",
{
- case Left(_) => true
- case Right(s) => Set("Printed", "Handwritten")(s)
+ case Left(s) => Set("Printed", "Handwritten")(s)
+ case Right(_) => true
}, isURLParam = true)
def getMode: String = getScalarParam(mode)
@@ -361,8 +361,8 @@ class ReadImage(override val uid: String)
" so only provide a language code if you would like to force the documented" +
" to be processed as that specific language.",
{
- case Left(_) => true
- case Right(s) => Set("en", "nl", "fr", "de", "it", "pt", "es")(s)
+ case Left(s) => Set("en", "nl", "fr", "de", "it", "pt", "es")(s)
+ case Right(_) => true
}, isURLParam = true)
def setLanguage(v: String): this.type = setScalarParam(language, v)
diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/DocumentTranslator.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/DocumentTranslator.scala
index c3a23bd3e6..76ede8bb5a 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/DocumentTranslator.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/DocumentTranslator.scala
@@ -20,6 +20,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
import spray.json._
import java.net.URI
+import scala.reflect.internal.util.ScalaClassLoader
trait DocumentTranslatorAsyncReply extends BasicAsyncReply {
@@ -137,6 +138,17 @@ class DocumentTranslator(override val uid: String) extends CognitiveServicesBase
))).toJson.compactPrint, ContentType.APPLICATION_JSON))
}
+ // Configures this translator from a Synapse linked service: resolves the
+ // service *name* and key reflectively from mssparkutils (Synapse-only) and
+ // applies them. The name (not a location) is needed because setServiceName
+ // below builds a per-resource endpoint from it.
+ override def setLinkedService(v: String): this.type = {
+ val classPath = "mssparkutils.cognitiveService"
+ val linkedServiceClass = ScalaClassLoader(getClass.getClassLoader).tryToLoadClass(classPath)
+ // NOTE(review): .get throws NoSuchElementException when mssparkutils is not
+ // on the classpath (outside Synapse) — confirm this is acceptable.
+ val nameMethod = linkedServiceClass.get.getMethod("getName", v.getClass)
+ val keyMethod = linkedServiceClass.get.getMethod("getKey", v.getClass)
+ val name = nameMethod.invoke(linkedServiceClass.get, v).toString
+ val key = keyMethod.invoke(linkedServiceClass.get, v).toString
+ setServiceName(name)
+ setSubscriptionKey(key)
+ }
+
override def setServiceName(v: String): this.type = {
super.setServiceName(v)
setUrl(s"https://$getServiceName.cognitiveservices.azure.com/" + urlPath)
diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/FormRecognizer.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/FormRecognizer.scala
index cfc82209c2..4adf873660 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/FormRecognizer.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/FormRecognizer.scala
@@ -71,8 +71,8 @@ trait HasModelID extends HasServiceParams {
trait HasLocale extends HasServiceParams {
val locale = new ServiceParam[String](this, "locale", "Locale of the receipt. Supported" +
" locales: en-AU, en-CA, en-GB, en-IN, en-US.", {
- case Left(_) => true
- case Right(s) => Set("en-AU", "en-CA", "en-GB", "en-IN", "en-US")(s)
+ case Left(s) => Set("en-AU", "en-CA", "en-GB", "en-IN", "en-US")(s)
+ case Right(_) => true
}, isURLParam = true)
def setLocale(v: String): this.type = setScalarParam(locale, v)
@@ -258,7 +258,7 @@ object ListCustomModels extends ComplexParamsReadable[ListCustomModels]
class ListCustomModels(override val uid: String) extends CognitiveServicesBase(uid)
with HasCognitiveServiceInput with HasInternalJsonOutputParser
- with HasSetLocation with BasicLogging {
+ with HasSetLocation with HasSetLinkedService with BasicLogging {
logClass()
def this() = this(Identifiable.randomUID("ListCustomModels"))
@@ -283,7 +283,7 @@ object GetCustomModel extends ComplexParamsReadable[GetCustomModel]
class GetCustomModel(override val uid: String) extends CognitiveServicesBase(uid)
with HasCognitiveServiceInput with HasInternalJsonOutputParser
- with HasSetLocation with BasicLogging with HasModelID {
+ with HasSetLocation with HasSetLinkedService with BasicLogging with HasModelID {
logClass()
def this() = this(Identifiable.randomUID("GetCustomModel"))
diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/SpeechToText.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/SpeechToText.scala
index b52c36583c..52706fba99 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/SpeechToText.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/SpeechToText.scala
@@ -21,7 +21,7 @@ object SpeechToText extends ComplexParamsReadable[SpeechToText] with Serializabl
class SpeechToText(override val uid: String) extends CognitiveServicesBase(uid)
with HasCognitiveServiceInput with HasInternalJsonOutputParser with HasSetLocation with BasicLogging
- with HasSetLinkedService {
+ with HasSetLinkedServiceUsingLocation {
logClass()
def this() = this(Identifiable.randomUID("SpeechToText"))
diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/SpeechToTextSDK.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/SpeechToTextSDK.scala
index 8e261f5bfe..af86ef1f9d 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/SpeechToTextSDK.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/SpeechToTextSDK.scala
@@ -77,7 +77,7 @@ private[ml] class BlockingQueueIterator[T](lbq: LinkedBlockingQueue[Option[T]],
abstract class SpeechSDKBase extends Transformer
with HasSetLocation with HasServiceParams
with HasOutputCol with HasURL with HasSubscriptionKey with ComplexParamsWritable with BasicLogging
- with HasSetLinkedService {
+ with HasSetLinkedServiceUsingLocation {
type ResponseType <: SharedSpeechFields
@@ -198,16 +198,6 @@ abstract class SpeechSDKBase extends Transformer
def urlPath: String = "/sts/v1.0/issuetoken"
- override def setLinkedService(v: String): this.type = {
- val classPath = "mssparkutils.cognitiveService"
- val linkedServiceClass = ScalaClassLoader(getClass.getClassLoader).tryToLoadClass(classPath)
- val locationMethod = linkedServiceClass.get.getMethod("getLocation", v.getClass)
- val keyMethod = linkedServiceClass.get.getMethod("getKey", v.getClass)
- val location = locationMethod.invoke(linkedServiceClass.get, v).toString
- val key = keyMethod.invoke(linkedServiceClass.get, v).toString
- setLocation(location)
- setSubscriptionKey(key)
- }
setDefault(language -> Left("en-us"))
setDefault(profanity -> Left("Masked"))
diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalytics.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalytics.scala
index c85f7780f8..ddbda40e49 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalytics.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalytics.scala
@@ -95,6 +95,12 @@ abstract class TextAnalyticsBase(override val uid: String) extends CognitiveServ
override protected def getInternalTransformer(schema: StructType): PipelineModel = {
val dynamicParamColName = DatasetExtensions.findUnusedColumnName("dynamic", schema)
+ val missingRequiredParams = this.getRequiredParams.filter {
+ p => this.get(p).isEmpty && this.getDefault(p).isEmpty
+ }
+ assert(missingRequiredParams.isEmpty,
+ s"Missing required params: ${missingRequiredParams.map(s => s.name).mkString("(", ", ", ")")}")
+
def reshapeToArray(parameterName: String): Option[(Transformer, String, String)] = {
val reshapedColName = DatasetExtensions.findUnusedColumnName(parameterName, schema)
getVectorParamMap.get(parameterName).flatMap {
@@ -296,6 +302,18 @@ class NER(override val uid: String) extends TextAnalyticsBase(uid) with BasicLog
def urlPath: String = "/text/analytics/v3.0/entities/recognition/general"
}
+object PII extends ComplexParamsReadable[PII]
+
+// Transformer for the Text Analytics PII entity recognition endpoint.
+// Note it targets the v3.1 API, unlike the sibling NER class above (v3.0).
+class PII(override val uid: String) extends TextAnalyticsBase(uid) with BasicLogging {
+ logClass()
+
+ def this() = this(Identifiable.randomUID("PII"))
+
+ // Responses are parsed with the PII-specific schema (PIIResponseV3).
+ override def responseDataType: StructType = PIIResponseV3.schema
+
+ def urlPath: String = "/text/analytics/v3.1/entities/recognition/pii"
+}
+
object LanguageDetector extends ComplexParamsReadable[LanguageDetector]
class LanguageDetector(override val uid: String)
diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalyticsSDK.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalyticsSDK.scala
index d84d0dc299..71acdaf669 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalyticsSDK.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalyticsSDK.scala
@@ -30,12 +30,33 @@ import scala.concurrent.duration.Duration
import scala.concurrent.{ExecutionContext, Future}
trait HasOptions extends HasServiceParams {
- val options = new Param[TextAnalyticsRequestOptionsV4](
- this, name = "options", "text analytics request options")
+ val modelVersion = new Param[String](
+ this, name = "modelVersion", "modelVersion option")
- def getOptions: Option[TextAnalyticsRequestOptionsV4] = get(options)
+ def getModelVersion: Option[String] = get(modelVersion)
- def setOptions(v: TextAnalyticsRequestOptionsV4): this.type = set(options, v)
+ def setModelVersion(v: String): this.type = set(modelVersion, v)
+
+ val includeStatistics = new Param[Boolean](
+ this, name = "includeStatistics", "includeStatistics option")
+
+ def getIncludeStatistics: Option[Boolean] = get(includeStatistics)
+
+ def setIncludeStatistics(v: Boolean): this.type = set(includeStatistics, v)
+
+ val disableServiceLogs = new Param[Boolean](
+ this, name = "disableServiceLogs", "disableServiceLogs option")
+
+ def getDisableServiceLogs: Option[Boolean] = get(disableServiceLogs)
+
+ def setDisableServiceLogs(v: Boolean): this.type = set(disableServiceLogs, v)
+
+ val includeOpinionMining = new Param[Boolean](
+ this, name = "includeOpinionMining", "includeOpinionMining option")
+
+ def getIncludeOpinionMining: Option[Boolean] = get(includeOpinionMining)
+
+ def setIncludeOpinionMining(v: Boolean): this.type = set(includeOpinionMining, v)
}
@@ -49,7 +70,6 @@ abstract class TextAnalyticsSDKBase[T]()
val responseBinding: SparkBindings[TAResponseV4[T]]
def invokeTextAnalytics(client: TextAnalyticsClient,
- options: Option[TextAnalyticsRequestOptions],
text: Seq[String],
lang: Seq[String]): TAResponseV4[T]
@@ -68,14 +88,9 @@ abstract class TextAnalyticsSDKBase[T]()
.map(ct => Duration.fromNanos((ct * math.pow(10, 9)).toLong)) //scalastyle:ignore magic.number
.getOrElse(Duration.Inf)
- val requestOptions = get(options) match {
- case Some(o) => Some(toSDK(o))
- case None => None
- }
-
val futures = rows.map { row =>
Future {
- val results = invokeTextAnalytics(client, requestOptions, getValue(row, text), getValue(row, language))
+ val results = invokeTextAnalytics(client, getValue(row, text), getValue(row, language))
Row.fromSeq(row.toSeq ++ Seq(toRow(results))) // Adding a new column
}(ExecutionContext.global)
}
@@ -153,14 +168,18 @@ class LanguageDetectionV4(override val uid: String)
override val responseBinding: SparkBindings[TAResponseV4[DetectedLanguageV4]] = DetectLanguageResponseV4
override def invokeTextAnalytics(client: TextAnalyticsClient,
- options: Option[TextAnalyticsRequestOptions],
input: Seq[String],
hints: Seq[String]): TAResponseV4[DetectedLanguageV4] = {
val documents = (input, hints, input.indices).zipped.map { (doc, hint, i) =>
new DetectLanguageInput(i.toString, doc, hint)
}.asJava
- val response = client.detectLanguageBatchWithResponse(documents, options.orNull, Context.NONE).getValue
+ val options = new TextAnalyticsRequestOptions()
+ .setModelVersion(getModelVersion.getOrElse("latest"))
+ .setIncludeStatistics(getIncludeStatistics.getOrElse(false))
+ .setServiceLogsDisabled(getDisableServiceLogs.getOrElse(false))
+
+ val response = client.detectLanguageBatchWithResponse(documents, options, Context.NONE).getValue
toResponse(response.asScala, response.getModelVersion)
}
}
@@ -176,14 +195,17 @@ class KeyphraseExtractionV4(override val uid: String)
override val responseBinding: SparkBindings[TAResponseV4[KeyphraseV4]] = KeyPhraseResponseV4
override def invokeTextAnalytics(client: TextAnalyticsClient,
- options: Option[TextAnalyticsRequestOptions],
input: Seq[String],
lang: Seq[String]): TAResponseV4[KeyphraseV4] = {
val documents = (input, lang, lang.indices).zipped.map { (doc, lang, i) =>
new TextDocumentInput(i.toString, doc).setLanguage(lang)
}.asJava
+ val options = new TextAnalyticsRequestOptions()
+ .setModelVersion(getModelVersion.getOrElse("latest"))
+ .setIncludeStatistics(getIncludeStatistics.getOrElse(false))
+ .setServiceLogsDisabled(getDisableServiceLogs.getOrElse(false))
- val response = client.extractKeyPhrasesBatchWithResponse(documents, options.orNull, Context.NONE).getValue
+ val response = client.extractKeyPhrasesBatchWithResponse(documents, options, Context.NONE).getValue
toResponse(response.asScala, response.getModelVersion)
}
}
@@ -194,19 +216,25 @@ class TextSentimentV4(override val uid: String)
extends TextAnalyticsSDKBase[SentimentScoredDocumentV4]() {
logClass()
- def this() = this(Identifiable.randomUID("KeyphraseExtractionV4"))
+ def this() = this(Identifiable.randomUID("TextSentimentV4"))
override val responseBinding: SparkBindings[TAResponseV4[SentimentScoredDocumentV4]] = SentimentResponseV4
override def invokeTextAnalytics(client: TextAnalyticsClient,
- options: Option[TextAnalyticsRequestOptions],
input: Seq[String],
lang: Seq[String]): TAResponseV4[SentimentScoredDocumentV4] = {
+
val documents = (input, lang, lang.indices).zipped.map { (doc, lang, i) =>
new TextDocumentInput(i.toString, doc).setLanguage(lang)
}.asJava
- val response = client.analyzeSentimentBatchWithResponse(documents, options.orNull, Context.NONE).getValue
+ val options = new AnalyzeSentimentOptions()
+ .setModelVersion(getModelVersion.getOrElse("latest"))
+ .setIncludeStatistics(getIncludeStatistics.getOrElse(false))
+ .setServiceLogsDisabled(getDisableServiceLogs.getOrElse(false))
+ .setIncludeOpinionMining(getIncludeOpinionMining.getOrElse(true))
+
+ val response = client.analyzeSentimentBatchWithResponse(documents, options, Context.NONE).getValue
toResponse(response.asScala, response.getModelVersion)
}
}
@@ -221,14 +249,18 @@ class PIIV4(override val uid: String) extends TextAnalyticsSDKBase[PIIEntityColl
override val responseBinding: SparkBindings[TAResponseV4[PIIEntityCollectionV4]] = PIIResponseV4
override def invokeTextAnalytics(client: TextAnalyticsClient,
- options: Option[TextAnalyticsRequestOptions],
input: Seq[String],
lang: Seq[String]): TAResponseV4[PIIEntityCollectionV4] = {
val documents = (input, lang, lang.indices).zipped.map { (doc, lang, i) =>
new TextDocumentInput(i.toString, doc).setLanguage(lang)
}.asJava
- val response = client.recognizePiiEntitiesBatchWithResponse(documents, null, Context.NONE).getValue
+ val options = new RecognizePiiEntitiesOptions()
+ .setModelVersion(getModelVersion.getOrElse("latest"))
+ .setIncludeStatistics(getIncludeStatistics.getOrElse(false))
+ .setServiceLogsDisabled(getDisableServiceLogs.getOrElse(false))
+
+ val response = client.recognizePiiEntitiesBatchWithResponse(documents, options, Context.NONE).getValue
toResponse(response.asScala, response.getModelVersion)
}
}
@@ -243,14 +275,17 @@ class HealthcareV4(override val uid: String) extends TextAnalyticsSDKBase[Health
override val responseBinding: SparkBindings[TAResponseV4[HealthEntitiesResultV4]] = HealthcareResponseV4
override def invokeTextAnalytics(client: TextAnalyticsClient,
- options: Option[TextAnalyticsRequestOptions],
input: Seq[String],
lang: Seq[String]): TAResponseV4[HealthEntitiesResultV4] = {
val documents = (input, lang, lang.indices).zipped.map { (doc, lang, i) =>
new TextDocumentInput(i.toString, doc).setLanguage(lang)
}.asJava
+ val options = new AnalyzeHealthcareEntitiesOptions()
+ .setModelVersion(getModelVersion.getOrElse("latest"))
+ .setIncludeStatistics(getIncludeStatistics.getOrElse(false))
+ .setServiceLogsDisabled(getDisableServiceLogs.getOrElse(false))
- val poller = client.beginAnalyzeHealthcareEntities(documents, null, Context.NONE)
+ val poller = client.beginAnalyzeHealthcareEntities(documents, options, Context.NONE)
poller.waitForCompletion()
val pagedResults = poller.getFinalResult.asScala
@@ -268,14 +303,18 @@ class EntityLinkingV4(override val uid: String) extends TextAnalyticsSDKBase[Lin
override val responseBinding: SparkBindings[TAResponseV4[LinkedEntityCollectionV4]] = LinkedEntityResponseV4
override def invokeTextAnalytics(client: TextAnalyticsClient,
- options: Option[TextAnalyticsRequestOptions],
input: Seq[String],
lang: Seq[String]): TAResponseV4[LinkedEntityCollectionV4] = {
val documents = (input, lang, lang.indices).zipped.map { (doc, lang, i) =>
new TextDocumentInput(i.toString, doc).setLanguage(lang)
}.asJava
- val response = client.recognizeLinkedEntitiesBatchWithResponse(documents, options.orNull, Context.NONE).getValue
+ val options = new RecognizeLinkedEntitiesOptions()
+ .setModelVersion(getModelVersion.getOrElse("latest"))
+ .setIncludeStatistics(getIncludeStatistics.getOrElse(false))
+ .setServiceLogsDisabled(getDisableServiceLogs.getOrElse(false))
+
+ val response = client.recognizeLinkedEntitiesBatchWithResponse(documents, options, Context.NONE).getValue
toResponse(response.asScala, response.getModelVersion)
}
}
@@ -290,14 +329,18 @@ class NERV4(override val uid: String) extends TextAnalyticsSDKBase[NERCollection
override val responseBinding: SparkBindings[TAResponseV4[NERCollectionV4]] = NERResponseV4
override def invokeTextAnalytics(client: TextAnalyticsClient,
- options: Option[TextAnalyticsRequestOptions],
input: Seq[String],
lang: Seq[String]): TAResponseV4[NERCollectionV4] = {
val documents = (input, lang, lang.indices).zipped.map { (doc, lang, i) =>
new TextDocumentInput(i.toString, doc).setLanguage(lang)
}.asJava
- val response = client.recognizeEntitiesBatchWithResponse(documents, options.orNull, Context.NONE).getValue
+ val options = new RecognizeEntitiesOptions()
+ .setModelVersion(getModelVersion.getOrElse("latest"))
+ .setIncludeStatistics(getIncludeStatistics.getOrElse(false))
+ .setServiceLogsDisabled(getDisableServiceLogs.getOrElse(false))
+
+ val response = client.recognizeEntitiesBatchWithResponse(documents, options, Context.NONE).getValue
toResponse(response.asScala, response.getModelVersion)
}
}
diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalyticsSDKSchemasV4.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalyticsSDKSchemasV4.scala
index 9e17dc0efd..286af23123 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalyticsSDKSchemasV4.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalyticsSDKSchemasV4.scala
@@ -39,10 +39,6 @@ case class TAWarningV4(warningCode: String, message: String)
case class TextDocumentInputs(id: String, text: String)
-case class TextAnalyticsRequestOptionsV4(modelVersion: String,
- includeStatistics: Boolean,
- disableServiceLogs: Boolean)
-
case class KeyphraseV4(keyPhrases: Seq[String], warnings: Seq[TAWarningV4])
case class SentimentConfidenceScoreV4(negative: Double, neutral: Double, positive: Double)
@@ -55,11 +51,11 @@ case class SentimentScoredDocumentV4(sentiment: String,
case class SentimentSentenceV4(text: String,
sentiment: String,
confidenceScores: SentimentConfidenceScoreV4,
- opinion: Option[Seq[OpinionV4]],
+ opinions: Option[Seq[OpinionV4]],
offset: Int,
length: Int)
-case class OpinionV4(target: TargetV4, assessment: Seq[AssessmentV4])
+case class OpinionV4(target: TargetV4, assessments: Seq[AssessmentV4])
case class TargetV4(text: String,
sentiment: String,
@@ -336,13 +332,6 @@ object SDKConverters {
entity.getEntities.getWarnings.asScala.toSeq.map(fromSDK))
}
- def toSDK(textAnalyticsRequestOptionsV4: TextAnalyticsRequestOptionsV4): TextAnalyticsRequestOptions = {
- new TextAnalyticsRequestOptions()
- .setModelVersion(textAnalyticsRequestOptionsV4.modelVersion)
- .setIncludeStatistics(textAnalyticsRequestOptionsV4.includeStatistics)
- .setServiceLogsDisabled(textAnalyticsRequestOptionsV4.disableServiceLogs)
- }
-
def unpackResult[T <: TextAnalyticsResult, U](result: T)(implicit converter: T => U):
(Option[TAErrorV4], Option[DocumentStatistics], Option[U]) = {
if (result.isError) {
diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalyticsSchemas.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalyticsSchemas.scala
index e37b41ee62..5c0bc3a464 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalyticsSchemas.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextAnalyticsSchemas.scala
@@ -105,6 +105,22 @@ case class NEREntityV3(text: String,
length: Integer,
confidenceScore: Double)
+// NER Pii Schemas
+
+// Spark bindings for responses from the PII entity recognition endpoint,
+// consumed by the PII transformer's responseDataType.
+object PIIResponseV3 extends SparkBindings[TAResponse[PIIDocV3]]
+
+// One analyzed document in a PII response.
+case class PIIDocV3(id: String,
+ entities: Seq[PIIEntityV3],
+ warnings: Seq[TAWarning],
+ statistics: Option[DocumentStatistics])
+
+// A single recognized PII entity.
+// NOTE(review): the defaulted `subcategory` precedes required fields, so it
+// cannot be omitted in positional construction — consider moving it last.
+case class PIIEntityV3(text: String,
+ category: String,
+ subcategory: Option[String] = None,
+ offset: Integer,
+ length: Integer,
+ confidenceScore: Double)
+
+
// KeyPhrase Schemas
object KeyPhraseResponseV3 extends SparkBindings[TAResponse[KeyPhraseScoreV3]]
diff --git a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextTranslator.scala b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextTranslator.scala
index b6d2623b6f..2eb1f6adf6 100644
--- a/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextTranslator.scala
+++ b/cognitive/src/main/scala/com/microsoft/ml/spark/cognitive/TextTranslator.scala
@@ -3,14 +3,19 @@
package com.microsoft.ml.spark.cognitive
+import com.microsoft.ml.spark.core.schema.DatasetExtensions
+import com.microsoft.ml.spark.io.http.SimpleHTTPTransformer
import com.microsoft.ml.spark.logging.BasicLogging
-import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpRequestBase}
+import com.microsoft.ml.spark.stages.{DropColumns, Lambda, UDFTransformer}
+import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpPost, HttpRequestBase}
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
-import org.apache.spark.ml.ComplexParamsReadable
+import org.apache.spark.injections.UDFUtils
+import org.apache.spark.ml.{ComplexParamsReadable, NamespaceInjections, PipelineModel, Transformer}
import org.apache.spark.ml.param.ServiceParam
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Row
-import org.apache.spark.sql.types.{ArrayType, DataType, StructType}
+import org.apache.spark.sql.functions.{array, col, lit, struct}
+import org.apache.spark.sql.types.{ArrayType, DataType, StringType, StructType}
import spray.json.DefaultJsonProtocol._
import spray.json._
@@ -38,6 +43,8 @@ trait HasTextInput extends HasServiceParams {
def setText(v: Seq[String]): this.type = setScalarParam(text, v)
+ def setText(v: String): this.type = setScalarParam(text, Seq(v))
+
def getTextCol: String = getVectorParam(text)
def setTextCol(v: String): this.type = setVectorParam(text, v)
@@ -72,79 +79,124 @@ trait HasToLanguage extends HasServiceParams {
def getToLanguageCol: String = getVectorParam(toLanguage)
}
-trait TextAsOnlyEntity extends HasTextInput with HasCognitiveServiceInput {
-
- override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
- r =>
- Some(new StringEntity(
- getValueOpt(r, text)
- .map(x => x.map(y => Map("Text" -> y))).toJson.compactPrint, ContentType.APPLICATION_JSON))
- }
-}
-
-abstract class TextTranslatorBase(override val uid: String) extends CognitiveServicesBase(uid)
- with HasInternalJsonOutputParser with HasCognitiveServiceInput with HasSubscriptionRegion
- with HasSetLocation with HasSetLinkedService {
-
- protected val subscriptionRegionHeaderName = "Ocp-Apim-Subscription-Region"
-
- override protected def contentType: Row => String = { _ => "application/json; charset=UTF-8" }
+trait TextAsOnlyEntity extends HasTextInput with HasCognitiveServiceInput with HasSubscriptionRegion {
override protected def inputFunc(schema: StructType): Row => Option[HttpRequestBase] = {
- val rowToUrl = prepareUrl
- val rowToEntity = prepareEntity;
{ row: Row =>
if (shouldSkip(row)) {
None
+ } else if (getValue(row, text).forall(Option(_).isEmpty)) {
+ None
} else {
- val req = prepareMethod()
- req.setURI(new URI(rowToUrl(row)))
- getValueOpt(row, subscriptionKey).foreach(
- req.setHeader(subscriptionKeyHeaderName, _))
- getValueOpt(row, subscriptionRegion).foreach(
- req.setHeader(subscriptionRegionHeaderName, _)
- )
- req.setHeader("Content-Type", contentType(row))
-
- req match {
- case er: HttpEntityEnclosingRequestBase =>
- rowToEntity(row).foreach(er.setEntity)
- case _ =>
+ val urlParams: Array[ServiceParam[Any]] =
+ getUrlParams.asInstanceOf[Array[ServiceParam[Any]]]
+
+ val texts = getValue(row, text)
+
+ val base = getUrl + "?api-version=3.0"
+ val appended = if (!urlParams.isEmpty) {
+ "&" + URLEncodingUtils.format(urlParams.flatMap(p =>
+ getValueOpt(row, p).map {
+ val pName = p.name match {
+ case "fromLanguage" => "from"
+ case "toLanguage" => "to"
+ case s => s
+ }
+ v => pName -> p.toValueString(v)
+ }
+ ).toMap)
+ } else {
+ ""
}
- Some(req)
+
+ val post = new HttpPost(base + appended)
+ getValueOpt(row, subscriptionKey).foreach(post.setHeader("Ocp-Apim-Subscription-Key", _))
+ getValueOpt(row, subscriptionRegion).foreach(post.setHeader("Ocp-Apim-Subscription-Region", _))
+ post.setHeader("Content-Type", "application/json; charset=UTF-8")
+
+ val json = texts.map(s => Map("Text" -> s)).toJson.compactPrint
+ post.setEntity(new StringEntity(json, "UTF-8"))
+ Some(post)
}
}
}
- override protected def prepareUrl: Row => String = {
- val urlParams: Array[ServiceParam[Any]] =
- getUrlParams.asInstanceOf[Array[ServiceParam[Any]]];
+ override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { _ => None }
+}
+
+abstract class TextTranslatorBase(override val uid: String) extends CognitiveServicesBase(uid)
+ with HasInternalJsonOutputParser with HasSubscriptionRegion
+ with HasSetLocation with HasSetLinkedServiceUsingLocation {
+
- // This semicolon is needed to avoid argument confusion
- def replaceName(s: String): String = {
- if (s == "fromLanguage") {
- "from"
- } else if (s == "toLanguage") {
- "to"
- } else {
- s
+ // For each parameter name whose bound vector column is a plain StringType,
+ // produce a stage that wraps the column in a single-element array, returning
+ // (stage, originalColName, temporaryArrayColName) triples for later cleanup.
+ protected def reshapeColumns(schema: StructType, parameterNames: Seq[String])
+ : Seq[(Transformer, String, String)] = {
+
+ // Returns None when the param is unbound or its column is not StringType
+ // (already array-shaped columns are passed through untouched).
+ def reshapeToArray(parameterName: String): Option[(Transformer, String, String)] = {
+ val reshapedColName = DatasetExtensions.findUnusedColumnName(parameterName, schema)
+ getVectorParamMap.get(parameterName).flatMap {
+ case c if schema(c).dataType == StringType =>
+ Some((Lambda(_.withColumn(reshapedColName, array(col(getVectorParam(parameterName))))),
+ getVectorParam(parameterName),
+ reshapedColName))
+ case _ => None
+ }
+ }
- { row: Row =>
- val base = getUrl + "?api-version=3.0"
- val appended = if (!urlParams.isEmpty) {
- "&" + URLEncodingUtils.format(urlParams.flatMap(p =>
- getValueOpt(row, p).map {
- v => replaceName(p.name) -> p.toValueString(v)
- }
- ).toMap)
- } else {
- ""
+
+ parameterNames.flatMap(x => reshapeToArray(x))
+ }
+
+ // noinspection ScalaStyle
+ // Builds the internal pipeline: validate required params, reshape any string
+ // columns named in parameterNames into single-element arrays, pack the vector
+ // params into one struct column, run the HTTP transformer, then drop the
+ // temporary columns.
+ protected def customGetInternalTransformer(schema: StructType,
+ parameterNames: Seq[String]): PipelineModel = {
+ val dynamicParamColName = DatasetExtensions.findUnusedColumnName("dynamic", schema)
+
+ // Fail fast when a required param has neither a set value nor a default.
+ val missingRequiredParams = this.getRequiredParams.filter {
+ p => this.get(p).isEmpty && this.getDefault(p).isEmpty
+ }
+ assert(missingRequiredParams.isEmpty,
+ s"Missing required params: ${missingRequiredParams.map(s => s.name).mkString("(", ", ", ")")}")
+
+ val reshapeCols = reshapeColumns(schema, parameterNames)
+
+ // originalCol -> temporaryArrayCol, used to read renamed columns and to
+ // drop the temporaries at the end of the pipeline.
+ val newColumnMapping = reshapeCols.map {
+ case (_, oldCol, newCol) => (oldCol, newCol)
+ }.toMap
+
+ // NOTE(review): in the `case 0` branch getVectorParamMap is empty, so the
+ // inner match can only ever hit Nil; `case l => l` looks unreachable.
+ val columnsToGroup = getVectorParamMap.values.size match {
+ case 0 => getVectorParamMap.values.toList.map(col) match {
+ case Nil => Seq(lit(false).alias("placeholder"))
+ case l => l
}
- base + appended
+ case _ => getVectorParamMap.map { case (_, oldCol) =>
+ val newCol = newColumnMapping.getOrElse(oldCol, oldCol)
+ col(newCol).alias(oldCol)
+ }.toSeq
}
+
+ val stages = reshapeCols.map(_._1).toArray ++ Array(
+ Lambda(_.withColumn(
+ dynamicParamColName,
+ struct(columnsToGroup: _*))),
+ new SimpleHTTPTransformer()
+ .setInputCol(dynamicParamColName)
+ .setOutputCol(getOutputCol)
+ .setInputParser(getInternalInputParser(schema))
+ .setOutputParser(getInternalOutputParser(schema))
+ .setHandler(getHandler)
+ .setConcurrency(getConcurrency)
+ .setConcurrentTimeout(get(concurrentTimeout))
+ .setErrorCol(getErrorCol),
+ // Drop the struct column and every temporary reshaped column.
+ new DropColumns().setCols(Array(
+ dynamicParamColName) ++ newColumnMapping.values.toArray.asInstanceOf[Array[String]])
+ )
+
+ NamespaceInjections.pipelineModel(stages)
+ }
+ override protected def getInternalTransformer(schema: StructType): PipelineModel =
+ customGetInternalTransformer(schema, Seq("text"))
+
override def setLocation(v: String): this.type = {
setSubscriptionRegion(v)
setUrl("https://api.cognitive.microsofttranslator.com/" + urlPath)
@@ -162,6 +214,51 @@ class Translate(override val uid: String) extends TextTranslatorBase(uid)
def urlPath: String = "translate"
+ override protected def inputFunc(schema: StructType): Row => Option[HttpRequestBase] = {
+ { row: Row =>
+ if (shouldSkip(row)) {
+ None
+ } else if (getValue(row, text).forall(Option(_).isEmpty)) {
+ None
+ } else if (getValue(row, toLanguage).forall(Option(_).isEmpty)) {
+ None
+ } else {
+ val urlParams: Array[ServiceParam[Any]] =
+ getUrlParams.asInstanceOf[Array[ServiceParam[Any]]]
+
+ val texts = getValue(row, text)
+
+ val base = getUrl + "?api-version=3.0"
+ val appended = if (!urlParams.isEmpty) {
+ "&" + URLEncodingUtils.format(urlParams.flatMap(p =>
+ getValueOpt(row, p).map {
+ val pName = p.name match {
+ case "fromLanguage" => "from"
+ case "toLanguage" => "to"
+ case s => s
+ }
+ v => pName -> p.toValueString(v)
+ }
+ ).toMap)
+ } else {
+ ""
+ }
+
+ val post = new HttpPost(base + appended)
+ getValueOpt(row, subscriptionKey).foreach(post.setHeader("Ocp-Apim-Subscription-Key", _))
+ getValueOpt(row, subscriptionRegion).foreach(post.setHeader("Ocp-Apim-Subscription-Region", _))
+ post.setHeader("Content-Type", "application/json; charset=UTF-8")
+
+ val json = texts.map(s => Map("Text" -> s)).toJson.compactPrint
+ post.setEntity(new StringEntity(json, "UTF-8"))
+ Some(post)
+ }
+ }
+ }
+
+ override protected def getInternalTransformer(schema: StructType): PipelineModel =
+ customGetInternalTransformer(schema, Seq("text", "toLanguage"))
+
val toLanguage = new ServiceParam[Seq[String]](this, "toLanguage", "Specifies the language of the output" +
" text. The target language must be one of the supported languages included in the translation scope." +
" For example, use to=de to translate to German. It's possible to translate to multiple languages simultaneously" +
@@ -171,6 +268,8 @@ class Translate(override val uid: String) extends TextTranslatorBase(uid)
def setToLanguage(v: Seq[String]): this.type = setScalarParam(toLanguage, v)
+ def setToLanguage(v: String): this.type = setScalarParam(toLanguage, Seq(v))
+
def setToLanguageCol(v: String): this.type = setVectorParam(toLanguage, v)
val fromLanguage = new ServiceParam[String](this, "fromLanguage", "Specifies the language of the input" +
@@ -186,8 +285,8 @@ class Translate(override val uid: String) extends TextTranslatorBase(uid)
val textType = new ServiceParam[String](this, "textType", "Defines whether the text being" +
" translated is plain text or HTML text. Any HTML needs to be a well-formed, complete element. Possible values" +
" are: plain (default) or html.", {
- case Left(_) => true
- case Right(s) => Set("plain", "html")(s)
+ case Left(s) => Set("plain", "html")(s)
+ case Right(_) => true
}, isURLParam = true)
def setTextType(v: String): this.type = setScalarParam(textType, v)
@@ -206,8 +305,8 @@ class Translate(override val uid: String) extends TextTranslatorBase(uid)
val profanityAction = new ServiceParam[String](this, "profanityAction", "Specifies how" +
" profanities should be treated in translations. Possible values are: NoAction (default), Marked or Deleted. ",
{
- case Left(_) => true
- case Right(s) => Set("NoAction", "Marked", "Deleted")(s)
+ case Left(s) => Set("NoAction", "Marked", "Deleted")(s)
+ case Right(_) => true
}, isURLParam = true)
def setProfanityAction(v: String): this.type = setScalarParam(profanityAction, v)
@@ -216,8 +315,8 @@ class Translate(override val uid: String) extends TextTranslatorBase(uid)
val profanityMarker = new ServiceParam[String](this, "profanityMarker", "Specifies how" +
" profanities should be marked in translations. Possible values are: Asterisk (default) or Tag.", {
- case Left(_) => true
- case Right(s) => Set("Asterisk", "Tag")(s)
+ case Left(s) => Set("Asterisk", "Tag")(s)
+ case Right(_) => true
}, isURLParam = true)
def setProfanityMarker(v: String): this.type = setScalarParam(profanityMarker, v)
@@ -378,6 +477,8 @@ trait HasTextAndTranslationInput extends HasServiceParams {
def setTextAndTranslation(v: Seq[(String, String)]): this.type = setScalarParam(textAndTranslation, v)
+ def setTextAndTranslation(v: (String, String)): this.type = setScalarParam(textAndTranslation, Seq(v))
+
def getTextAndTranslationCol: String = getVectorParam(textAndTranslation)
def setTextAndTranslationCol(v: String): this.type = setVectorParam(textAndTranslation, v)
@@ -387,20 +488,66 @@ trait HasTextAndTranslationInput extends HasServiceParams {
object DictionaryExamples extends ComplexParamsReadable[DictionaryExamples]
class DictionaryExamples(override val uid: String) extends TextTranslatorBase(uid)
- with HasTextAndTranslationInput with HasFromLanguage with HasToLanguage with BasicLogging {
+ with HasTextAndTranslationInput with HasFromLanguage with HasToLanguage
+ with HasCognitiveServiceInput with BasicLogging {
logClass()
def this() = this(Identifiable.randomUID("DictionaryExamples"))
def urlPath: String = "dictionary/examples"
- override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
- r =>
- Some(new StringEntity(
- getValue(r, textAndTranslation).asInstanceOf[Seq[Row]]
- .map(x => Map("Text" -> x.getString(0), "Translation" -> x.getString(1)))
- .toJson.compactPrint, ContentType.APPLICATION_JSON))
+ override protected def inputFunc(schema: StructType): Row => Option[HttpRequestBase] = {
+ { row: Row =>
+ if (shouldSkip(row)) {
+ None
+ } else {
+ val urlParams: Array[ServiceParam[Any]] =
+ getUrlParams.asInstanceOf[Array[ServiceParam[Any]]]
+
+ val textAndTranslations = getValue(row, textAndTranslation)
+ if (textAndTranslations.isEmpty)
+ None
+ else {
+
+ val base = getUrl + "?api-version=3.0"
+ val appended = if (!urlParams.isEmpty) {
+ "&" + URLEncodingUtils.format(urlParams.flatMap(p =>
+ getValueOpt(row, p).map {
+ val pName = p.name match {
+ case "fromLanguage" => "from"
+ case "toLanguage" => "to"
+ case s => s
+ }
+ v => pName -> p.toValueString(v)
+ }
+ ).toMap)
+ } else {
+ ""
+ }
+
+ val post = new HttpPost(base + appended)
+ getValueOpt(row, subscriptionKey).foreach(post.setHeader("Ocp-Apim-Subscription-Key", _))
+ getValueOpt(row, subscriptionRegion).foreach(post.setHeader("Ocp-Apim-Subscription-Region", _))
+ post.setHeader("Content-Type", "application/json; charset=UTF-8")
+
+ val json = textAndTranslations.head.getClass.getTypeName match {
+ case "scala.Tuple2" => textAndTranslations.map(
+ t => Map("Text" -> t._1, "Translation" -> t._2)).toJson.compactPrint
+ case _ => textAndTranslations.asInstanceOf[Seq[Row]].map(
+ s => Map("Text" -> s.getString(0), "Translation" -> s.getString(1))).toJson.compactPrint
+ }
+
+ post.setEntity(new StringEntity(json, "UTF-8"))
+ Some(post)
+ }
+ }
+ }
}
+ override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { _ => None }
+
+ override protected def getInternalTransformer(schema: StructType): PipelineModel =
+ customGetInternalTransformer(schema, Seq("textAndTranslation"))
+
override def responseDataType: DataType = ArrayType(DictionaryExamplesResponse.schema)
}
diff --git a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/AnamolyDetectionSuite.scala b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/AnamolyDetectionSuite.scala
index 9da24d479a..d3d3b556f5 100644
--- a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/AnamolyDetectionSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/AnamolyDetectionSuite.scala
@@ -9,7 +9,7 @@ import com.microsoft.ml.spark.core.test.base.TestBase
import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.functions.{col, collect_list, lit, struct}
+import org.apache.spark.sql.functions.{col, collect_list, lit, sort_array, struct}
trait AnomalyKey {
lazy val anomalyKey = sys.env.getOrElse("ANOMALY_API_KEY", Secrets.AnomalyApiKey)
@@ -39,7 +39,7 @@ trait AnomalyDetectorSuiteBase extends TestBase with AnomalyKey {
.withColumn("group", lit(1))
.withColumn("inputs", struct(col("timestamp"), col("value")))
.groupBy(col("group"))
- .agg(collect_list(col("inputs")).alias("inputs"))
+ .agg(sort_array(collect_list(col("inputs"))).alias("inputs"))
lazy val df2: DataFrame = Seq(
("2000-01-24T08:46:00Z", 826.0),
@@ -61,7 +61,7 @@ trait AnomalyDetectorSuiteBase extends TestBase with AnomalyKey {
.withColumn("group", lit(1))
.withColumn("inputs", struct(col("timestamp"), col("value")))
.groupBy(col("group"))
- .agg(collect_list(col("inputs")).alias("inputs"))
+ .agg(sort_array(collect_list(col("inputs"))).alias("inputs"))
}
@@ -93,6 +93,20 @@ class DetectLastAnomalySuite extends TransformerFuzzing[DetectLastAnomaly] with
assert(result.isAnomaly)
}
+ test("Throw errors if required fields not set") {
+ val caught = intercept[AssertionError] {
+ new DetectLastAnomaly()
+ .setSubscriptionKey(anomalyKey)
+ .setLocation("westus2")
+ .setOutputCol("anomalies")
+ .setErrorCol("errors")
+ .transform(df).collect()
+ }
+ assert(caught.getMessage.contains("Missing required params"))
+ assert(caught.getMessage.contains("granularity"))
+ assert(caught.getMessage.contains("series"))
+ }
+
override def testObjects(): Seq[TestObject[DetectLastAnomaly]] =
Seq(new TestObject(ad, df))
@@ -117,6 +131,19 @@ class DetectAnomaliesSuite extends TransformerFuzzing[DetectAnomalies] with Anom
assert(result.isAnomaly.count({b => b}) == 2)
}
+ test("Throw errors if required fields not set") {
+ val caught = intercept[AssertionError] {
+ new DetectAnomalies()
+ .setSubscriptionKey(anomalyKey)
+ .setLocation("westus2")
+ .setOutputCol("anomalies")
+ .transform(df).collect()
+ }
+ assert(caught.getMessage.contains("Missing required params"))
+ assert(caught.getMessage.contains("granularity"))
+ assert(caught.getMessage.contains("series"))
+ }
+
override def testObjects(): Seq[TestObject[DetectAnomalies]] =
Seq(new TestObject(ad, df))
@@ -181,6 +208,19 @@ class SimpleDetectAnomaliesSuite extends TransformerFuzzing[SimpleDetectAnomalie
.show(truncate=false)
}
+ test("Throw errors if required fields not set") {
+ val caught = intercept[AssertionError] {
+ new SimpleDetectAnomalies()
+ .setSubscriptionKey(anomalyKey)
+ .setLocation("westus2")
+ .setOutputCol("anomalies")
+ .setGroupbyCol("group")
+ .transform(sdf).collect()
+ }
+ assert(caught.getMessage.contains("Missing required params"))
+ assert(caught.getMessage.contains("granularity"))
+ }
+
//TODO Nulls, different cardinalities
override def testObjects(): Seq[TestObject[SimpleDetectAnomalies]] =
diff --git a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/ComputerVisionSuite.scala b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/ComputerVisionSuite.scala
index c0f3076565..acaf07db6c 100644
--- a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/ComputerVisionSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/ComputerVisionSuite.scala
@@ -493,6 +493,11 @@ class DescribeImageSuite extends TransformerFuzzing[DescribeImage]
assert(tags("person") && tags("glasses"))
}
+ override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
+ super.assertDFEq(df1.select("descriptions.description.tags", "descriptions.description.captions.text"),
+ df2.select("descriptions.description.tags", "descriptions.description.captions.text"))(eq)
+ }
+
override def testObjects(): Seq[TestObject[DescribeImage]] =
Seq(new TestObject(t, df))
diff --git a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/FaceSuite.scala b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/FaceSuite.scala
index c4e28aeea6..35dd109531 100644
--- a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/FaceSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/FaceSuite.scala
@@ -98,6 +98,18 @@ class FindSimilarFaceSuite extends TransformerFuzzing[FindSimilarFace] with Cogn
assert(numMatches === List(1, 2, 2))
}
+ test("Throw errors if required fields not set") {
+ val caught = intercept[AssertionError] {
+ new FindSimilarFace()
+ .setSubscriptionKey(cognitiveKey)
+ .setLocation("eastus")
+ .setOutputCol("similar")
+ .transform(faceIdDF).collect()
+ }
+ assert(caught.getMessage.contains("Missing required params"))
+ assert(caught.getMessage.contains("faceId"))
+ }
+
override def testObjects(): Seq[TestObject[FindSimilarFace]] =
Seq(new TestObject(findSimilar, faceIdDF))
@@ -147,6 +159,18 @@ class GroupFacesSuite extends TransformerFuzzing[GroupFaces] with CognitiveKey {
assert(numMatches === List(2, 2, 2))
}
+ test("Throw errors if required fields not set") {
+ val caught = intercept[AssertionError] {
+ new GroupFaces()
+ .setSubscriptionKey(cognitiveKey)
+ .setLocation("eastus")
+ .setOutputCol("grouping")
+ .transform(faceIdDF).collect()
+ }
+ assert(caught.getMessage.contains("Missing required params"))
+ assert(caught.getMessage.contains("faceIds"))
+ }
+
override def testObjects(): Seq[TestObject[GroupFaces]] =
Seq(new TestObject(group, faceIdDF))
@@ -229,6 +253,19 @@ class IdentifyFacesSuite extends TransformerFuzzing[IdentifyFaces] with Cognitiv
assert(matches === List(satyaId, bradId, bradId))
}
+ test("Throw errors if required fields not set") {
+ val caught = intercept[AssertionError] {
+ new IdentifyFaces()
+ .setSubscriptionKey(cognitiveKey)
+ .setLocation("eastus")
+ .setPersonGroupId(pgId)
+ .setOutputCol("identified_faces")
+ .transform(df).collect()
+ }
+ assert(caught.getMessage.contains("Missing required params"))
+ assert(caught.getMessage.contains("faceIds"))
+ }
+
override def testObjects(): Seq[TestObject[IdentifyFaces]] = Seq(new TestObject(id, df))
override def reader: MLReadable[_] = IdentifyFaces
diff --git a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/FormRecognizerSuite.scala b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/FormRecognizerSuite.scala
index 4b97b43fee..6814ef2c37 100644
--- a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/FormRecognizerSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/FormRecognizerSuite.scala
@@ -535,6 +535,18 @@ class GetCustomModelSuite extends TransformerFuzzing[GetCustomModel]
""""SALESPERSON","SERVICE ADDRESS:","SHIP TO:","SHIPPED VIA","TERMS","TOTAL","UNIT PRICE"]}}""").stripMargin)
}
+ test("Throw errors if required fields not set") {
+ val caught = intercept[AssertionError] {
+ new GetCustomModel()
+ .setSubscriptionKey(cognitiveKey).setLocation("eastus")
+ .setIncludeKeys(true)
+ .setOutputCol("model")
+ .transform(pathDf).collect()
+ }
+ assert(caught.getMessage.contains("Missing required params"))
+ assert(caught.getMessage.contains("modelId"))
+ }
+
override def testObjects(): Seq[TestObject[GetCustomModel]] =
Seq(new TestObject(getCustomModel, pathDf))
@@ -586,6 +598,17 @@ class AnalyzeCustomModelSuite extends TransformerFuzzing[AnalyzeCustomModel]
assert(results.head.getString(2) === "")
}
+ test("Throw errors if required fields not set") {
+ val caught = intercept[AssertionError] {
+ new AnalyzeCustomModel()
+ .setSubscriptionKey(cognitiveKey).setLocation("eastus")
+ .setImageUrlCol("source").setOutputCol("form")
+ .transform(imageDf4).collect()
+ }
+ assert(caught.getMessage.contains("Missing required params"))
+ assert(caught.getMessage.contains("modelId"))
+ }
+
override def testObjects(): Seq[TestObject[AnalyzeCustomModel]] =
Seq(new TestObject(analyzeCustomModel, imageDf4))
diff --git a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/ImageSearchSuite.scala b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/ImageSearchSuite.scala
index d599b67a69..77829f15c5 100644
--- a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/ImageSearchSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/ImageSearchSuite.scala
@@ -103,6 +103,20 @@ class ImageSearchSuite extends TransformerFuzzing[BingImageSearch]
assert(ddf.collect().head.getAs[Row]("images") != null)
}
+ test("Throw errors if required fields not set") {
+ val caught = intercept[AssertionError] {
+ new BingImageSearch()
+ .setSubscriptionKey(searchKey)
+ .setOffsetCol("offsets")
+ .setCount(10)
+ .setImageType("photo")
+ .setOutputCol("images")
+ .transform(requestParameters).collect()
+ }
+ assert(caught.getMessage.contains("Missing required params"))
+ assert(caught.getMessage.contains("q"))
+ }
+
override lazy val dfEq: Equality[DataFrame] = new Equality[DataFrame] {
def areEqual(a: DataFrame, b: Any): Boolean =
(a.schema === b.asInstanceOf[DataFrame].schema) &&
diff --git a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/TextAnalyticsSDKSuite.scala b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/TextAnalyticsSDKSuite.scala
index cb41bd5c68..39d625bbf3 100644
--- a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/TextAnalyticsSDKSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/TextAnalyticsSDKSuite.scala
@@ -164,10 +164,9 @@ class LanguageDetectionSuiteV4 extends TestBase with DataFrameEquality with Text
test("Language Detection - Overriding request options and including statistics") {
val replies = getDetector
- .setOptions(TextAnalyticsRequestOptionsV4(
- modelVersion = "latest",
- includeStatistics = true,
- disableServiceLogs = false))
+ .setModelVersion("latest")
+ .setIncludeStatistics(true)
+ .setDisableServiceLogs(false)
.transform(df)
.select("output.statistics")
.collect()
@@ -180,10 +179,9 @@ class LanguageDetectionSuiteV4 extends TestBase with DataFrameEquality with Text
val caught =
intercept[SparkException] {
getDetector
- .setOptions(TextAnalyticsRequestOptionsV4(
- modelVersion = "oopsie doopsie",
- includeStatistics = false,
- disableServiceLogs = false))
+ .setModelVersion("invalid model")
+ .setIncludeStatistics(true)
+ .setDisableServiceLogs(false)
.transform(df)
.collect()
}
@@ -193,10 +191,9 @@ class LanguageDetectionSuiteV4 extends TestBase with DataFrameEquality with Text
test("Language Detection - Disable logs") {
val replies = getDetector
- .setOptions(TextAnalyticsRequestOptionsV4(
- modelVersion = "latest",
- includeStatistics = false,
- disableServiceLogs = true))
+ .setModelVersion("latest")
+ .setIncludeStatistics(false)
+ .setDisableServiceLogs(true)
.transform(df)
.select("output.result")
.collect()
@@ -244,6 +241,27 @@ class SentimentAnalysisSuiteV4 extends TestBase with DataFrameEquality with Text
.setTextCol("text")
.setOutputCol("output")
+ test("Sentiment Analysis - Include Opinion Mining") {
+ val replies = getDetector
+ .setIncludeOpinionMining(true)
+ .transform(batchedDF)
+ .select("output")
+ .collect()
+ assert(replies(0).schema(0).name == "output")
+ df.printSchema()
+ df.show()
+ val fromRow = SentimentResponseV4.makeFromRowConverter
+
+ val outResponse = fromRow(replies(0).getAs[GenericRowWithSchema]("output"))
+
+ val opinions = outResponse.result.head.get.sentences.head.opinions
+
+ assert(opinions != null)
+
+ assert(opinions.get.head.target.text == "rain")
+ assert(opinions.get.head.target.sentiment == "negative")
+ }
+
test("Sentiment Analysis - Output Assertion") {
val replies = getDetector.transform(batchedDF)
.select("output")
@@ -313,6 +331,17 @@ class SentimentAnalysisSuiteV4 extends TestBase with DataFrameEquality with Text
assert(codes(0).get(0).toString == "InvalidDocument")
}
+  test("Sentiment Analysis - Invalid Document Error") {
+ val replies = getDetector.transform(invalidDocDf)
+ .select("output.error.errorMessage", "output.error.errorCode")
+ .collect()
+ val errors = replies.map(row => row.getList(0))
+ val codes = replies.map(row => row.getList(1))
+
+ assert(errors(0).get(0).toString == "Document text is empty.")
+ assert(codes(0).get(0).toString == "InvalidDocument")
+ }
+
test("Sentiment Analysis - Assert Confidence Score") {
val replies = getDetector.transform(batchedDF)
.select("output")
@@ -531,8 +560,6 @@ class HealthcareSuiteV4 extends TestBase with DataFrameEquality with TextKey {
("en", "6-drops of Vitamin B-12 every evening")
).toDF("lang", "text")
- val options: TextAnalyticsRequestOptionsV4 = new TextAnalyticsRequestOptionsV4("", true, false)
-
lazy val extractor: HealthcareV4 = new HealthcareV4()
.setSubscriptionKey(textKey)
.setLocation("eastus")
diff --git a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/TextAnalyticsSuite.scala b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/TextAnalyticsSuite.scala
index 23c84df185..c700910ee9 100644
--- a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/TextAnalyticsSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/TextAnalyticsSuite.scala
@@ -310,7 +310,7 @@ class KeyPhraseExtractorV3Suite extends TransformerFuzzing[KeyPhraseExtractor] w
println(results)
assert(results(0).getSeq[String](0).toSet === Set("Hello world", "input text"))
- assert(results(2).getSeq[String](0).toSet === Set("mucho tráfico", "carretera", "ayer"))
+ assert(results(2).getSeq[String](0).toSet === Set("mucho tráfico", "día", "carretera", "ayer"))
}
override def testObjects(): Seq[TestObject[KeyPhraseExtractor]] =
@@ -396,3 +396,44 @@ class NERSuiteV3 extends TransformerFuzzing[NER] with TextKey {
override def reader: MLReadable[_] = NER
}
+
+class PIISuiteV3 extends TransformerFuzzing[PII] with TextKey {
+ import spark.implicits._
+
+ lazy val df: DataFrame = Seq(
+ ("1", "en", "My SSN is 859-98-0987"),
+ ("2", "en",
+ "Your ABA number - 111000025 - is the first 9 digits in the lower left hand corner of your personal check."),
+ ("3", "en", "Is 998.214.865-68 your Brazilian CPF number?")
+ ).toDF("id", "language", "text")
+
+ lazy val n: PII = new PII()
+ .setSubscriptionKey(textKey)
+ .setLocation("eastus")
+ .setLanguage("en")
+ .setOutputCol("response")
+
+ test("Basic Usage") {
+ val results = n.transform(df)
+ val matches = results.withColumn("match",
+ col("response")
+ .getItem(0)
+ .getItem("entities")
+ .getItem(0))
+ .select("match")
+
+ val testRow = matches.collect().head(0).asInstanceOf[GenericRowWithSchema]
+
+ assert(testRow.getAs[String]("text") === "859-98-0987")
+ assert(testRow.getAs[Int]("offset") === 10)
+ assert(testRow.getAs[Int]("length") === 11)
+ assert(testRow.getAs[Double]("confidenceScore") > 0.6)
+ assert(testRow.getAs[String]("category") === "USSocialSecurityNumber")
+
+ }
+
+ override def testObjects(): Seq[TestObject[PII]] =
+ Seq(new TestObject[PII](n, df))
+
+ override def reader: MLReadable[_] = PII
+}
diff --git a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/TranslatorSuite.scala b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/TranslatorSuite.scala
index 63f6c8fdae..29e780ac7f 100644
--- a/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/TranslatorSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/ml/spark/cognitive/split1/TranslatorSuite.scala
@@ -10,7 +10,6 @@ import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, flatten}
-import org.scalactic.Equality
trait TranslatorKey {
lazy val translatorKey: String = sys.env.getOrElse("TRANSLATOR_KEY", Secrets.TranslatorKey)
@@ -24,7 +23,7 @@ trait TranslatorUtils extends TestBase {
lazy val textDf1: DataFrame = Seq(List("Hello, what is your name?")).toDF("text")
- lazy val textDf2: DataFrame = Seq(List("Hello, what is your name?", "Bye")).toDF("text")
+ lazy val textDf2: DataFrame = Seq(List("Hello, what is your name?", "Bye")).toDF("text")
lazy val textDf3: DataFrame = Seq(List("This is bullshit.")).toDF("text")
@@ -34,6 +33,12 @@ trait TranslatorUtils extends TestBase {
lazy val textDf5: DataFrame = Seq(List("The word