Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix issues with form recognizer parsing and form ontology learner #1506

Merged
merged 1 commit into from
May 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ class FormOntologyLearner(override val uid: String) extends Estimator[FormOntolo
def this() = this(Identifiable.randomUID("FormOntologyLearner"))

private[ml] def extractOntology(fromRow: Row => AnalyzeResponse)(r: Row): StructType = {
val fieldResults = fromRow(r.getStruct(0)).analyzeResult.documentResults.get.head.fields
val fieldResults = fromRow(r.getStruct(0)).analyzeResult.documentResults
.getOrElse(throw new IllegalArgumentException("A row does not have a `analyzeResult.documentResults` field," +
" please filter these out before using the FormOntologyLearner"))
.head.fields
new StructType(fieldResults
.mapValues(_.toFieldResultRecursive.toSimplifiedDataType)
.map({ case (name, dt) => StructField(name, dt) }).toArray)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import spray.json._

abstract class FormRecognizerBase(override val uid: String) extends CognitiveServicesBaseNoHandler(uid)
with HasCognitiveServiceInput with HasInternalJsonOutputParser with BasicAsyncReply
with HasImageInput with HasSetLocation with HasSetLinkedService {
with HasImageInput with HasSetLocation with HasSetLinkedService with HasModelVersion {

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
r =>
Expand Down Expand Up @@ -306,21 +306,8 @@ class GetCustomModel(override val uid: String) extends CognitiveServicesBase(uid

def urlPath: String = "formrecognizer/v2.1/custom/models"

override protected def prepareUrl: Row => String = {
val urlParams: Array[ServiceParam[Any]] =
getUrlParams.asInstanceOf[Array[ServiceParam[Any]]];
// This semicolon is needed to avoid argument confusion
{ row: Row =>
val base = getUrl + s"/${getValue(row, modelId)}"
val appended = if (!urlParams.isEmpty) {
"?" + URLEncodingUtils.format(urlParams.flatMap(p =>
getValueOpt(row, p).map(v => p.name -> p.toValueString(v))
).toMap)
} else {
""
}
base + appended
}
override protected def prepareUrlRoot: Row => String = { row =>
getUrl + s"/${getValue(row, modelId)}"
}

override protected def prepareMethod(): HttpRequestBase = new HttpGet()
Expand All @@ -347,21 +334,8 @@ class AnalyzeCustomModel(override val uid: String) extends FormRecognizerBase(ui

def urlPath: String = "formrecognizer/v2.1/custom/models"

override protected def prepareUrl: Row => String = {
val urlParams: Array[ServiceParam[Any]] =
getUrlParams.asInstanceOf[Array[ServiceParam[Any]]];
// This semicolon is needed to avoid argument confusion
{ row: Row =>
val base = getUrl + s"/${getValue(row, modelId)}/analyze"
val appended = if (!urlParams.isEmpty) {
"?" + URLEncodingUtils.format(urlParams.flatMap(p =>
getValueOpt(row, p).map(v => p.name -> p.toValueString(v))
).toMap)
} else {
""
}
base + appended
}
override protected def prepareUrlRoot: Row => String = {row =>
getUrl + s"/${getValue(row, modelId)}/analyze"
}

override protected def responseDataType: DataType = AnalyzeResponse.schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,19 @@ trait HasPrebuiltModelID extends HasServiceParams {
def getPrebuiltModelIdCol: String = getVectorParam(prebuiltModelId)
}

object AnalyzeDocument extends ComplexParamsReadable[AnalyzeDocument] {
// Different versions might have different results so make sure tests pass before updating
val DefaultAPIVersion = "2022-01-30-preview"
}
object AnalyzeDocument extends ComplexParamsReadable[AnalyzeDocument]

class AnalyzeDocument(override val uid: String) extends CognitiveServicesBaseNoHandler(uid)
with HasCognitiveServiceInput with HasInternalJsonOutputParser with BasicAsyncReply
with HasPrebuiltModelID with HasPages with HasLocale
with HasPrebuiltModelID with HasPages with HasLocale with HasAPIVersion
with HasImageInput with HasSetLocation with BasicLogging {
logClass()

setDefault(apiVersion -> Left("2022-01-30-preview"))

def this() = this(Identifiable.randomUID("AnalyzeDocument"))

def urlPath: String = "formrecognizer/documentModels/"
def urlPath: String = "formrecognizer/documentModels"

val stringIndexType = new ServiceParam[String](this, "stringIndexType", "Method used to " +
"compute string offset and length.", {
Expand Down Expand Up @@ -75,21 +74,8 @@ class AnalyzeDocument(override val uid: String) extends CognitiveServicesBaseNoH
"Payload needs to contain image bytes or url. This code should not run"))
}

override protected def prepareUrl: Row => String = {
val urlParams: Array[ServiceParam[Any]] =
getUrlParams.asInstanceOf[Array[ServiceParam[Any]]];
// This semicolon is needed to avoid argument confusion
{ row: Row =>
val base = getUrl + s"${getValue(row, prebuiltModelId)}:analyze?api-version=${AnalyzeDocument.DefaultAPIVersion}"
val appended = if (!urlParams.isEmpty) {
"&" + URLEncodingUtils.format(urlParams.flatMap(p =>
getValueOpt(row, p).map(v => p.name -> p.toValueString(v))
).toMap)
} else {
""
}
base + appended
}
override protected def prepareUrlRoot: Row => String = { row =>
getUrl + s"/${getValue(row, prebuiltModelId)}:analyze"
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ case class AnalyzeResultV3(apiVersion: String,

case class PageResultV3(pageNumber: Int,
angle: Double,
width: Int,
height: Int,
width: Double,
height: Double,
unit: String,
spans: Option[Seq[FormSpan]],
words: Option[Seq[FormWord]],
Expand All @@ -37,19 +37,19 @@ case class PageResultV3(pageNumber: Int,

case class FormSpan(offset: Int, length: Int)

case class FormWord(content: String, boundingBox: Option[Seq[Int]], confidence: Double, span: FormSpan)
case class FormWord(content: String, boundingBox: Option[Seq[Double]], confidence: Double, span: FormSpan)

case class FormSelectionMark(state: String, boundingBox: Option[Seq[Int]], confidence: Double, span: FormSpan)
case class FormSelectionMark(state: String, boundingBox: Option[Seq[Double]], confidence: Double, span: FormSpan)

case class FormLine(content: String, boundingBox: Option[Seq[Int]], spans: Option[Seq[FormSpan]])
case class FormLine(content: String, boundingBox: Option[Seq[Double]], spans: Option[Seq[FormSpan]])

case class TableResultV3(rowCount: Int,
columnCount: Int,
boundingRegions: Option[Seq[BoundingRegion]],
spans: Option[Seq[FormSpan]],
cells: Option[Seq[FormCell]])

case class BoundingRegion(pageNumber: Int, boundingBox: Option[Seq[Int]])
case class BoundingRegion(pageNumber: Int, boundingBox: Option[Seq[Double]])

case class FormCell(kind: String,
rowIndex: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package com.microsoft.azure.synapse.ml.cognitive.split1

import com.microsoft.azure.synapse.ml.cognitive._
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{EstimatorFuzzing, TestObject}
import org.apache.spark.SparkException
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructType}
Expand All @@ -31,8 +32,30 @@ class FormOntologyLearnerSuite extends EstimatorFuzzing[FormOntologyLearner] wit
"https://mmlsparkdemo.blob.core.windows.net/ignite2021/forms/2009/Invoice12241.pdf"
).toDF("url")

// Single-row DataFrame with one column named "url" holding the location of the
// tables1.pdf sample document.
lazy val tableUrlDF: DataFrame = Seq(
"https://mmlspark.blob.core.windows.net/datasets/FormRecognizer/tables1.pdf"
).toDF("url")

// Result of applying analyzeInvoices to urlDF; .cache() is called on the resulting DataFrame.
lazy val df: DataFrame = analyzeInvoices.transform(urlDF).cache()

// Runs AnalyzeLayout over tableUrlDF, writing its output to the "layout" column, then
// fits and applies a FormOntologyLearner on that column. The fit/transform step is
// expected to raise a SparkException.
test("Yields a reasonable error message when input rows dont contain documentResults") {
  val layoutAnalyzer = new AnalyzeLayout()
    .setSubscriptionKey(cognitiveKey)
    .setLocation("eastus")
    .setImageUrlCol("url")
    .setOutputCol("layout")
    .setConcurrency(5)
  val analyzedDf = layoutAnalyzer.transform(tableUrlDF)

  assertThrows[SparkException] {
    // Construction stays inside the asserted block so the same code is covered as before.
    val learner = new FormOntologyLearner()
      .setInputCol("layout")
      .setOutputCol("unified_ontology")
    val fittedModel = learner.fit(analyzedDf)
    fittedModel.transform(analyzedDf)
  }
}

test("Basic Usage") {

val targetSchema = new StructType()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ trait FormRecognizerUtils extends TestBase with CognitiveKey with Flaky {

// NOTE(review): createTestDataframe and baseUrl are defined outside this view; the comments
// below only restate the arguments passed here — confirm the resulting schema against the helper.

// Built from "id1.jpg" with returnBytes = true.
lazy val bytesDF5: DataFrame = createTestDataframe(baseUrl, Seq("id1.jpg"), returnBytes = true)

// Built from "tables1.pdf" with returnBytes = false.
lazy val imageDf6: DataFrame = createTestDataframe(baseUrl, Seq("tables1.pdf"), returnBytes = false)

// Built from "layout2.pdf" with returnBytes = false.
lazy val pdfDf1: DataFrame = createTestDataframe(baseUrl, Seq("layout2.pdf"), returnBytes = false)

// Built from two files, "invoice1.pdf" and "invoice3.pdf", with returnBytes = false.
lazy val pdfDf2: DataFrame = createTestDataframe(baseUrl, Seq("invoice1.pdf", "invoice3.pdf"), returnBytes = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ class AnalyzeDocumentSuite extends TransformerFuzzing[AnalyzeDocument] with Form
super.assertDFEq(prep(df1), prep(df2))(eq)
}

// Analyzes imageDf6 with the "prebuilt-layout" model id, converts every "result" row
// into a response object, and checks that the first page of each response carries a
// non-negative page number.
test("basic usage with tables") {
  val toResponse = AnalyzeDocumentResponse.makeFromRowConverter
  val layoutModel = analyzeDocument
    .setPrebuiltModelId("prebuilt-layout")
    .setImageUrlCol("source")
  // All rows are converted before any assertion runs, matching the original ordering.
  val responses = layoutModel
    .transform(imageDf6)
    .collect()
    .map(row => toResponse(row.getAs[Row]("result")))
  responses.foreach { response =>
    val firstPage = response.analyzeResult.pages.get.head
    assert(firstPage.pageNumber >= 0)
  }
}

def analyzeDocument: AnalyzeDocument = new AnalyzeDocument()
.setSubscriptionKey(cognitiveKey)
.setLocation("eastus")
Expand Down Expand Up @@ -128,7 +139,7 @@ class AnalyzeDocumentSuite extends TransformerFuzzing[AnalyzeDocument] with Form
resultAssert(result, "WA WASHINGTON\n20 1234567XX1101\nDRIVER LICENSE\nFEDERAL LIMITS APPLY\n" +
"4d LIC#WDLABCD456DG 9CLASS\nDONORS\n1 TALBOT\n2 LIAM R.\n3 DOB 01/06/1958\n",
"Address,CountryRegion,DateOfBirth,DateOfExpiration,DocumentNumber," +
"Endorsements,FirstName,LastName,Region,Restrictions,Sex")
"Endorsements,FirstName,LastName,Region,Restrictions,Sex")
}
}

Expand Down