diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Triangulate.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Triangulate.scala index 39f384bd4..a5c873d83 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Triangulate.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Triangulate.scala @@ -5,6 +5,7 @@ import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.geometry.linestring.MosaicLineString import com.databricks.labs.mosaic.core.geometry.multilinestring.MosaicMultiLineString import com.databricks.labs.mosaic.core.geometry.multipoint.MosaicMultiPoint +import com.databricks.labs.mosaic.core.geometry.point.MosaicPoint import com.databricks.labs.mosaic.core.types.model.GeometryTypeEnum.{MULTIPOINT, POINT} import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.geometry.base.VectorExpression @@ -49,10 +50,29 @@ case class ST_Triangulate ( .eval(input) .asInstanceOf[ArrayData] .toObjectArray(firstElementType) - .map(geometryAPI.geometry(_, firstElementType)) + .map({ + obj => + val g = geometryAPI.geometry(obj, firstElementType) + g.getGeometryType.toUpperCase(Locale.ROOT) match { + case "POINT" => g.asInstanceOf[MosaicPoint] + case _ => throw new UnsupportedOperationException("ST_Triangulate requires Point geometry as masspoints input") + } + }) val multiPointGeom = geometryAPI.fromSeq(pointsGeom, MULTIPOINT).asInstanceOf[MosaicMultiPoint] - val linesGeom = inputLinesArray.eval(input).asInstanceOf[ArrayData].toObjectArray(secondElementType).map(geometryAPI.geometry(_, secondElementType).asInstanceOf[MosaicLineString]) + val linesGeom = + inputLinesArray + .eval(input) + .asInstanceOf[ArrayData] + .toObjectArray(secondElementType) + .map({ + obj => + val g = geometryAPI.geometry(obj, secondElementType) + g.getGeometryType.toUpperCase(Locale.ROOT) match { + case "LINESTRING" => g.asInstanceOf[MosaicLineString] + case _ => throw new UnsupportedOperationException("ST_Triangulate requires LINESTRING geometry as breakline input") + } + }) val triangles = multiPointGeom.triangulate(linesGeom, inputTolerance.eval(input).asInstanceOf[Double]) diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_TriangulateBehaviours.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_TriangulateBehaviours.scala index 1a2108a8f..21d53c27f 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_TriangulateBehaviours.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_TriangulateBehaviours.scala @@ -34,4 +34,30 @@ trait ST_TriangulateBehaviours extends QueryTest { } + def conformingTriangulateBehavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + + val mc = mosaicContext + import mc.functions._ + val sc = spark + import sc.implicits._ + mc.register(spark) + + val pointsPath = "src/test/resources/binary/elevation/sd46_dtm_point.shp" + val linesPath = "src/test/resources/binary/elevation/sd46_dtm_breakline.shp" + val points = MosaicContext.read.option("asWKB", "true").format("multi_read_ogr").load(pointsPath) + val breaklines = MosaicContext.read.option("asWKB", "true").format("multi_read_ogr").load(linesPath) + val linesDf = breaklines + .where(st_geometrytype($"geom_0") === "LINESTRING") + .groupBy() + .agg(collect_list($"geom_0").as("breaklines")) + val result = points + .groupBy() + .agg(collect_list($"geom_0").as("masspoints")) + .crossJoin(linesDf) + .withColumn("mesh", st_triangulate($"masspoints", $"breaklines", lit(0.01))) + .drop($"masspoints", $"breaklines") + noException should be thrownBy result.collect() + + } + } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_TriangulateTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_TriangulateTest.scala index ec5f2e1e3..797a787cd 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_TriangulateTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_TriangulateTest.scala @@ -7,5 +7,6 @@ import org.apache.spark.sql.test.SharedSparkSession class ST_TriangulateTest extends QueryTest with SharedSparkSession with ST_TriangulateBehaviours { test("Testing ST_Triangulate (H3, JTS) to produce unconstrained triangulation") { simpleTriangulateBehavior(H3IndexSystem, JTS)} + test("Testing ST_Triangulate (H3, JTS) to produce conforming triangulation") { conformingTriangulateBehavior(H3IndexSystem, JTS)} }