Fixes projection over multiple columns of a nested object
mjakubowski84 committed Nov 12, 2023
1 parent 5df32c3 commit a021d7a
Showing 5 changed files with 44 additions and 18 deletions.
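
For context, the bug addressed here shows up when a generic projection selects more than one column under the same nested group. Below is a minimal sketch of such a read, using parquet4s's projectedGeneric API in the same way as the new test in this commit; the case-class names and the file path are illustrative, not taken from the commit.

```scala
import com.github.mjakubowski84.parquet4s.{Col, ParquetReader, Path}

// Illustrative projected view: "nested.b" and "nested.c" are two different
// sub-columns of the same nested group in the source file.
case class Elem(x: Int, y: String)
case class View(b: List[Elem], c: Boolean)

val records = ParquetReader
  .projectedGeneric(
    Col("nested.b").as[List[Elem]],
    Col("nested.c").as[Boolean]
  )
  .read(Path("data/complex.parquet")) // hypothetical input file
  .as[View]

// Before this fix, the second column projected from the nested object could be
// lost or misplaced in the resulting record; afterwards both fields come back.
try records.foreach(println)
finally records.close()
```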
5 changes: 4 additions & 1 deletion build.sbt
@@ -83,7 +83,10 @@ lazy val core = (projectMatrix in file("core"))
   .jvmPlatform(
     scalaVersions = Seq(twoTwelve, twoThirteen),
     settings = Def.settings(
-      libraryDependencies ++= sparkDeps :+ ("com.chuusai" %% "shapeless" % shapelessVersion)
+      libraryDependencies ++= sparkDeps :+ ("com.chuusai" %% "shapeless" % shapelessVersion),
+      excludeDependencies ++= Seq(
+        ExclusionRule("org.apache.logging.log4j", "log4j-slf4j2-impl")
+      )
     )
   )
   .jvmPlatform(
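
The new excludeDependencies entry drops the transitive log4j-slf4j2-impl binding (presumably pulled in by the Spark dependencies) so it cannot compete with the SLF4J/logback setup pinned in DependecyVersions.scala. An alternative, equivalent way to express this in sbt is a per-dependency exclusion; the following is a sketch only, under the assumption that spark-sql is the module bringing the binding in.

```scala
// Sketch only: the same exclusion scoped to a single dependency instead of the
// whole module. "sparkVersion" and the spark-sql module name are assumptions.
libraryDependencies += ("org.apache.spark" %% "spark-sql" % sparkVersion)
  .exclude("org.apache.logging.log4j", "log4j-slf4j2-impl")
```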
@@ -147,10 +147,11 @@ class ProjectionItSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll

     val expectedNestedRecord = RowParquetRecord(
       "b" -> ListParquetRecord(
-        RowParquetRecord("x" -> 1.value),
-        RowParquetRecord("x" -> 2.value),
-        RowParquetRecord("x" -> 3.value)
-      )
+        RowParquetRecord("x" -> 1.value, "y" -> "a".value),
+        RowParquetRecord("x" -> 2.value, "y" -> "b".value),
+        RowParquetRecord("x" -> 3.value, "y" -> "c".value)
+      ),
+      "c" -> true.value
     )

     try records.toList should be(
@@ -165,4 +166,19 @@
     finally records.close()
   }

+  it should "allow to project complex fields for different subfields" in {
+    val records = ParquetReader
+      .projectedGeneric(
+        Col("nested.b").as[List[FullElem]],
+        Col("nested.c").as[Boolean]
+      )
+      .read(complexFilePath)
+      .as[FullNested]
+
+    val expectedRecord = FullNested(b = List(FullElem(1, "a"), FullElem(2, "b"), FullElem(3, "c")), c = true)
+
+    try records.toList should be(List(expectedRecord))
+    finally records.close()
+  }
+
 }
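
The new test refers to FullElem, FullNested, and complexFilePath, which are defined elsewhere in the spec and are not part of this hunk. Judging from the expected records above, the fixtures presumably look roughly like this (a sketch, not the spec's actual definitions):

```scala
// Presumed shape of the fixtures used by the new test (inferred from the
// expected values above; the real definitions live outside this diff).
case class FullElem(x: Int, y: String)
case class FullNested(b: List[FullElem], c: Boolean)
// complexFilePath would then point at a file whose records carry a "nested"
// group equivalent to FullNested(b = List(FullElem(1, "a"), ...), c = true).
```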
@@ -152,17 +152,18 @@ abstract private class RowParquetRecordConverter(schema: GroupType)
 private class RootRowParquetRecordConverter(schema: GroupType, columnProjections: Seq[ColumnProjection])
     extends RowParquetRecordConverter(schema) {

+  private lazy val emptyProjectionRow = RowParquetRecord.emptyWithSchema(columnProjections.map(cp => cp.alias.getOrElse(cp.columnPath.elements.last)))
+
   override def end(): Unit =
-    columnProjections.foreach { case ColumnProjection(columnPath, ordinal, aliasOpt) =>
-      record.get(columnPath) match {
-        case Some(value) if columnPath.elements.length > 1 =>
-          record = record.updated(ordinal, aliasOpt.getOrElse(columnPath.elements.last), value)
-        case Some(_) =>
-          aliasOpt.foreach { alias =>
-            record = record.rename(ordinal, alias)
-          }
-        case None =>
-          throw new IllegalArgumentException(s"""Invalid column projection: "$columnPath".""")
+    if (columnProjections.nonEmpty) {
+      record = columnProjections.foldLeft(emptyProjectionRow) {
+        case (newRecord, ColumnProjection(columnPath, _, aliasOpt)) =>
+          record.get(columnPath) match {
+            case Some(value) =>
+              newRecord.add(aliasOpt.getOrElse(columnPath.elements.last), value)
+            case None =>
+              throw new IllegalArgumentException(s"""Invalid column projection: "$columnPath".""")
+          }
       }
     }

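
The essence of the fix: instead of renaming or updating fields of the record that was just materialized, end() now folds the requested projections over a fresh empty row and copies each resolved value into it, so two projections rooted in the same nested column land as two independent output fields. Here is a self-contained toy model of that strategy, with plain Maps standing in for parquet4s's record types; the names are illustrative, not the library's code.

```scala
// Toy model of the new end() strategy: fold the projections over an empty
// output record instead of mutating the record that was read.
object ProjectionFoldSketch extends App {
  type Record = Map[String, Any]

  final case class Projection(path: List[String], alias: Option[String])

  // Walks a dotted path ("nested" -> "b") down through nested records.
  def resolve(record: Record, path: List[String]): Option[Any] =
    path match {
      case Nil         => None
      case last :: Nil => record.get(last)
      case head :: rest =>
        record.get(head) match {
          case Some(nested: Record @unchecked) => resolve(nested, rest)
          case _                               => None
        }
    }

  // Mirrors the shape of the new converter code: start from an empty row and
  // add one field per projection, named by the alias or the last path element.
  def project(record: Record, projections: Seq[Projection]): Record =
    projections.foldLeft(Map.empty[String, Any]) { (out, p) =>
      resolve(record, p.path) match {
        case Some(value) => out + (p.alias.getOrElse(p.path.last) -> value)
        case None =>
          throw new IllegalArgumentException(s"Invalid column projection: ${p.path.mkString(".")}")
      }
    }

  val read: Record = Map("nested" -> Map("b" -> List(1, 2, 3), "c" -> true))
  println(project(read, Seq(Projection(List("nested", "b"), None), Projection(List("nested", "c"), None))))
  // Map(b -> List(1, 2, 3), c -> true): both sub-columns of "nested" survive.
}
```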
@@ -1,7 +1,7 @@
 package com.github.mjakubowski84.parquet4s

 import com.github.mjakubowski84.parquet4s.SchemaDef.Meta
-import org.apache.parquet.schema.{Types, *}
+import org.apache.parquet.schema.*
 import org.apache.parquet.schema.LogicalTypeAnnotation.{
   DateLogicalTypeAnnotation,
   DecimalLogicalTypeAnnotation,
@@ -14,6 +14,7 @@ import org.apache.parquet.schema.Type.Repetition

 import scala.annotation.nowarn
 import scala.reflect.ClassTag
+import scala.jdk.CollectionConverters.*

 object Message {

@@ -22,8 +23,9 @@ object Message {
   def apply(name: Option[String], fields: Type*): MessageType =
     Types.buildMessage().addFields(fields*).named(name.getOrElse(DefaultName))

-  /** Merges the fields before creating a schema. Merge is done by unifying types of columns that are define in a
-    * projection more than once. Type of the first mentioned column is chosen for each duplicate.
+  /** Merges the fields before creating a schema. Merge is done by unifying types of columns that are defined in a
+    * projection more than once. The first mentioned column of primitive type is chosen in the case of duplicates.
+    * Union of member fields is executed in the case of complex types.
     * @param fields
     *   fields to be merged and then used for defining the schema
     * @return
@@ -37,6 +39,10 @@
       .foldLeft(Map.empty[String, Type], Vector.empty[Type]) { case ((register, merged), tpe) =>
         val fieldName = tpe.getName
         register.get(fieldName) match {
+          case Some(group: GroupType) if !tpe.isPrimitive() =>
+            val newMemberFields = mergeFields(group.getFields().asScala.toSeq ++ tpe.asGroupType().getFields().asScala)
+            val mergedGroup = group.withNewFields(newMemberFields.asJava)
+            register.updated(fieldName, mergedGroup) -> (merged.filterNot(_ == group) :+ mergedGroup)
           case Some(firstSeen) =>
             register -> (merged :+ firstSeen)
           case None =>
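
The schema side of the fix: when mergeFields meets a duplicate field name, primitive fields keep their first-seen definition as before, while group (complex) fields are now merged by recursively unioning their member fields, so Col("nested.b") and Col("nested.c") end up in a single "nested" group with both members. Below is a simplified, self-contained model of that rule, with plain case classes instead of parquet-mr's Type API; it is not the library's actual code.

```scala
// Simplified model of the updated merge rule in Message.mergeFields.
object MergeFieldsSketch extends App {
  sealed trait Field { def name: String }
  final case class Primitive(name: String, tpe: String) extends Field
  final case class Group(name: String, fields: Vector[Field]) extends Field

  def mergeFields(fields: Seq[Field]): Vector[Field] =
    fields
      .foldLeft((Map.empty[String, Field], Vector.empty[Field])) { case ((register, merged), field) =>
        (register.get(field.name), field) match {
          case (Some(group: Group), g: Group) =>
            // Complex duplicate: union the member fields of both groups.
            val mergedGroup = Group(group.name, mergeFields(group.fields ++ g.fields))
            register.updated(field.name, mergedGroup) -> (merged.filterNot(_ == group) :+ mergedGroup)
          case (Some(firstSeen), _) =>
            // Primitive duplicate: the first-seen definition fills the slot.
            register -> (merged :+ firstSeen)
          case (None, _) =>
            register.updated(field.name, field) -> (merged :+ field)
        }
      }
      ._2

  // Projecting "nested.b" and "nested.c" produces two "nested" groups,
  // which are unioned into a single group carrying both members.
  println(
    mergeFields(
      Seq(
        Group("nested", Vector(Primitive("b", "INT32"))),
        Group("nested", Vector(Primitive("c", "BOOLEAN")))
      )
    )
  )
  // Vector(Group(nested,Vector(Primitive(b,INT32), Primitive(c,BOOLEAN))))
}
```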
2 changes: 1 addition & 1 deletion project/DependecyVersions.scala
@@ -6,7 +6,7 @@ object DependecyVersions {
   val slf4jVersion = "2.0.9"
   val logbackVersion = "1.3.11" // stick to 1.3.x for JDK-8 compatibility
   val akkaVersion = "2.6.21" // non-licensed version
-  val fs2Version = "3.9.2"
+  val fs2Version = "3.9.3"
   val catsEffectVersion = "3.5.2"
   val scalaCollectionCompatVersion = "2.11.0"
   val scalatestVersion = "3.2.17"
