Skip to content

Commit

Permalink
Merge branch 'master' into brwals/spark35
Browse files Browse the repository at this point in the history
  • Loading branch information
BrendanWalsh committed Mar 6, 2024
2 parents 948d3fe + b390fd4 commit 0510590
Show file tree
Hide file tree
Showing 21 changed files with 394 additions and 114 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/acknowledge-new-issues.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Comment to acknowledge issue
uses: peter-evans/create-or-update-comment@v3
uses: peter-evans/create-or-update-comment@v4
with:
issue-number: ${{ github.event.issue.number }}
body: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/acknowledge-new-prs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Comment to acknowledge PRs
uses: peter-evans/create-or-update-comment@v3
uses: peter-evans/create-or-update-comment@v4
with:
issue-number: ${{ github.event.pull_request.number }}
body: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/clean-acr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
- name: checkout repo content
uses: actions/checkout@v4 # checkout the repo
- name: setup python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Execute clean-acr.py
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/codeql.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:

# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
Expand All @@ -60,7 +60,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v2
uses: github/codeql-action/autobuild@v3

# ℹ️ Command-line programs to run using the OS shell.
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
Expand All @@ -73,6 +73,6 @@ jobs:
# ./location_of_script_within_repo/buildscript.sh

- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"
4 changes: 2 additions & 2 deletions .github/workflows/scorecards.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ jobs:
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
# format to the repository Actions tab.
- name: "Upload artifact"
uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # v3.1.3
uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1
with:
name: SARIF file
path: results.sarif
retention-days: 5

# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@807578363a7869ca324a79039e6db9c843e0e100 # v2.1.27
uses: github/codeql-action/upload-sarif@03e7845b7bfcd5e7fb63d1ae8c61b0e791134fab # v2.22.11
with:
sarif_file: results.sarif
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ package com.microsoft.azure.synapse.ml.services
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, TokenLibrary}
import com.microsoft.azure.synapse.ml.io.http._
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda}
import org.apache.http.NameValuePair
Expand Down Expand Up @@ -218,18 +220,6 @@ trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
| return self
|
|def _transform(self, dataset: DataFrame) -> DataFrame:
| if running_on_synapse_internal():
| try:
| from synapse.ml.fabric.token_utils import TokenUtils
| from synapse.ml.fabric.service_discovery import get_fabric_env_config
| fabric_env_config = get_fabric_env_config().fabric_env_config
| if self._java_obj.getInternalServiceType() != "openai":
| self._java_obj.setDefaultAADToken(TokenUtils().get_aad_token())
| else:
| self._java_obj.setDefaultCustomAuthHeader(TokenUtils().get_openai_auth_header())
| self.setDefaultInternalEndpoint(fabric_env_config.get_mlflow_workload_endpoint())
| except ModuleNotFoundError as e:
| pass
| return super()._transform(dataset)
|""".stripMargin
}
Expand Down Expand Up @@ -327,6 +317,16 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA

protected def contentType: Row => String = { _ => "application/json" }

/** Resolves the custom authentication header to send with a service request.
  *
  * Prefers an explicitly provided `CustomAuthHeader` value from the row; when
  * none is set and the code is running on Microsoft Fabric, falls back to the
  * default AAD auth header obtained from [[TokenLibrary]].
  *
  * @param row the input row carrying per-row service parameters
  * @return the auth header to attach, or `None` when neither source applies
  */
protected def getCustomAuthHeader(row: Row): Option[String] = {
  val providedCustomHeader = getValueOpt(row, CustomAuthHeader)
  // Fixed: stray space before `.isEmpty` in the original.
  if (providedCustomHeader.isEmpty && PlatformDetails.runningOnFabric()) {
    logInfo("Using Default AAD Token On Fabric")
    // Option(...) maps a possible null from the token library to None.
    Option(TokenLibrary.getAuthHeader)
  } else {
    providedCustomHeader
  }
}

protected def addHeaders(req: HttpRequestBase,
subscriptionKey: Option[String],
aadToken: Option[String],
Expand Down Expand Up @@ -364,7 +364,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getValueOpt(row, CustomAuthHeader))
getCustomAuthHeader(row))

req match {
case er: HttpEntityEnclosingRequestBase =>
Expand Down Expand Up @@ -501,7 +501,12 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform

setDefault(
outputCol -> (this.uid + "_output"),
errorCol -> (this.uid + "_error"))
errorCol -> (this.uid + "_error")
)

if(PlatformDetails.runningOnFabric()) {
setDefaultInternalEndpoint(FabricClient.MLWorkloadEndpointML)
}

protected def handlingFunc(client: CloseableHttpClient,
request: HTTPRequestData): HTTPResponseData
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.codegen.GenerationUtils
import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasAPIVersion, HasServiceParams}
import com.microsoft.azure.synapse.ml.fabric.OpenAITokenLibrary
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasAPIVersion,
HasCognitiveServiceInput, HasServiceParams}
import com.microsoft.azure.synapse.ml.param.ServiceParam
import org.apache.spark.sql.Row
import spray.json.DefaultJsonProtocol._
Expand Down Expand Up @@ -244,6 +247,19 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
}
}

/** Cognitive-service input handling specialized for OpenAI endpoints.
  *
  * Overrides the default auth-header resolution so that, when running on
  * Microsoft Fabric and no explicit `CustomAuthHeader` is provided, the
  * OpenAI-specific [[OpenAITokenLibrary]] supplies the header instead of the
  * generic AAD token used by the base trait.
  */
trait HasOpenAICognitiveServiceInput extends HasCognitiveServiceInput {
  override protected def getCustomAuthHeader(row: Row): Option[String] = {
    val providedCustomHeader = getValueOpt(row, CustomAuthHeader)
    if (providedCustomHeader.isEmpty && PlatformDetails.runningOnFabric()) {
      logInfo("Using Default OpenAI Token On Fabric")
      // Option(...) maps a possible null from the token library to None.
      Option(OpenAITokenLibrary.getAuthHeader)
    } else {
      providedCustomHeader
    }
  } // removed a stray blank line that preceded this closing brace
}

/** Common base for OpenAI service transformers.
  *
  * Sets the default request timeout to 360 seconds, overriding the value
  * inherited from [[CognitiveServicesBase]].
  */
abstract class OpenAIServicesBase(override val uid: String)
  extends CognitiveServicesBase(uid) {
  setDefault(timeout -> 360.0)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,24 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasCognitiveServiceInput,
HasInternalJsonOutputParser}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json._
import spray.json.DefaultJsonProtocol._
import spray.json._

import scala.language.existentials

/** Companion object providing `ComplexParamsReadable` deserialization for [[OpenAIChatCompletion]]. */
object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasCognitiveServiceInput
with HasOpenAITextParams with HasOpenAICognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase,
HasCognitiveServiceInput, HasInternalJsonOutputParser}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.util._
Expand All @@ -20,7 +19,7 @@ import scala.language.existentials
/** Companion object providing `ComplexParamsReadable` deserialization for [[OpenAICompletion]]. */
object OpenAICompletion extends ComplexParamsReadable[OpenAICompletion]

class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasPromptInputs with HasCognitiveServiceInput
with HasOpenAITextParams with HasPromptInputs with HasOpenAICognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,25 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasCognitiveServiceInput, HasServiceParams}
import com.microsoft.azure.synapse.ml.core.contracts.HasInputCol
import com.microsoft.azure.synapse.ml.io.http.JSONOutputParser
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.ServiceParam
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json.DefaultJsonProtocol._
import spray.json._
import org.apache.spark.ml.linalg.{Vector, Vectors}

import scala.language.existentials

/** Companion object providing `ComplexParamsReadable` deserialization for [[OpenAIEmbedding]]. */
object OpenAIEmbedding extends ComplexParamsReadable[OpenAIEmbedding]

class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAISharedParams with HasCognitiveServiceInput with SynapseMLLogging {
with HasOpenAISharedParams with HasOpenAICognitiveServiceInput with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

def this() = this(Identifiable.randomUID("OpenAIEmbedding"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,13 @@ private[ml] abstract class TextAnalyticsBaseNoBinding(uid: String)
} else {
import TAJSONFormat._
val post = new HttpPost(prepareUrl(row))
addHeaders(post, getValueOpt(row, subscriptionKey), getValueOpt(row, AADToken), contentType(row))
addHeaders(
post,
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row)
)
val json = TARequest(makeDocuments(row)).toJson.compactPrint
post.setEntity(new StringEntity(json, "UTF-8"))
Some(post)
Expand Down Expand Up @@ -648,7 +654,13 @@ class TextAnalyze(override val uid: String) extends TextAnalyticsBaseNoBinding(u
None
} else {
val post = new HttpPost(getUrl)
addHeaders(post, getValueOpt(row, subscriptionKey), getValueOpt(row, AADToken), contentType(row))
addHeaders(
post,
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row)
)
val tasks = TextAnalyzeTasks(
entityRecognitionTasks = getTaskHelper(getIncludeEntityRecognition, getEntityRecognitionParams),
entityLinkingTasks = getTaskHelper(getIncludeEntityLinking, getEntityLinkingParams),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,13 @@ trait TextAsOnlyEntity extends HasTextInput with HasCognitiveServiceInput with H
}

val post = new HttpPost(base + appended)
addHeaders(post, getValueOpt(row, subscriptionKey), getValueOpt(row, AADToken), contentType(row))
addHeaders(
post,
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row)
)
getValueOpt(row, subscriptionRegion).foreach(post.setHeader("Ocp-Apim-Subscription-Region", _))

val json = texts.map(s => Map("Text" -> s)).toJson.compactPrint
Expand Down Expand Up @@ -248,7 +254,13 @@ class Translate(override val uid: String) extends TextTranslatorBase(uid)
}

val post = new HttpPost(base + appended)
addHeaders(post, getValueOpt(row, subscriptionKey), getValueOpt(row, AADToken), contentType(row))
addHeaders(
post,
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row)
)
getValueOpt(row, subscriptionRegion).foreach(post.setHeader("Ocp-Apim-Subscription-Region", _))

val json = texts.map(s => Map("Text" -> s)).toJson.compactPrint
Expand Down Expand Up @@ -533,7 +545,13 @@ class DictionaryExamples(override val uid: String) extends TextTranslatorBase(ui
}

val post = new HttpPost(base + appended)
addHeaders(post, getValueOpt(row, subscriptionKey), getValueOpt(row, AADToken), contentType(row))
addHeaders(
post,
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row)
)
getValueOpt(row, subscriptionRegion).foreach(post.setHeader("Ocp-Apim-Subscription-Region", _))

val json = textAndTranslations.head.getClass.getTypeName match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ class SpeakerEmotionInferenceSuite extends TransformerFuzzing[SpeakerEmotionInfe
"<mstts:express-as role='female' style='calm'>\"This is an example of a sentence with unmatched quotes,\"" +
"</mstts:express-as> she said.\"</voice></speak>\n"))

lazy val df: DataFrame = testData.map(e => e._1).toSeq.toDF("text")
lazy val df: DataFrame = testData.keys.toSeq.toDF("text")

test("basic") {
val transformed = ssmlGenerator.transform(df)
transformed.show(truncate = false)
transformed.collect().map(row => {
val actual = testData.get(row.getString(0)).getOrElse("")
val actual = testData.getOrElse(row.getString(0), "")
val expected = row.getString(2)
assert(actual.equals(expected))
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,8 @@ class SpeechToTextSDKSuite extends TransformerFuzzing[SpeechToTextSDK] with Spee

test("SAS URL based access") {
val sasURL = "https://mmlspark.blob.core.windows.net/datasets/Speech/audio2.wav?sv=2019-12-12" +
"&st=2021-01-25T16%3A40%3A13Z&se=2024-01-26T16%3A40%3A00Z" +
"&sr=b&sp=r&sig=NpFm%2FJemAJOGIya1ykQ6f80YdvwpiAuJjnb2RVDtKro%3D"
"?sv=2021-10-04&st=2024-02-28T16%3A17%3A55Z&se=2026-03-30T15%3A33%3A00Z" +
"&sr=c&sp=rl&sig=5Oy6pEaF4hN3lj8uo6daLN%2F%2BiV9VD6XFNSy%2FZ8Upeeg%3D"

tryWithRetries(Array(100, 500)) { () => //For handling flaky build machines
val uriDf = Seq(Tuple1(sasURL))
Expand Down Expand Up @@ -429,8 +429,8 @@ class ConversationTranscriptionSuite extends TransformerFuzzing[ConversationTran

test("SAS URL based access") {
val sasURL = "https://mmlspark.blob.core.windows.net/datasets/Speech/audio2.wav" +
"?sv=2019-12-12&st=2021-01-25T16%3A40%3A13Z&se=2024-01-26T16%3A40%3A00Z&sr=b&sp=r" +
"&sig=NpFm%2FJemAJOGIya1ykQ6f80YdvwpiAuJjnb2RVDtKro%3D"
"?sv=2021-10-04&st=2024-02-28T16%3A17%3A55Z&se=2026-03-30T15%3A33%3A00Z" +
"&sr=c&sp=rl&sig=5Oy6pEaF4hN3lj8uo6daLN%2F%2BiV9VD6XFNSy%2FZ8Upeeg%3D"

tryWithRetries(Array(100, 500)) { () => //For handling flaky build machines
val uriDf = Seq(Tuple1(sasURL))
Expand Down
Loading

0 comments on commit 0510590

Please sign in to comment.