
Spark_SQL_Macro_examples


Registering and using a macro

  • Register a macro by calling the registerMacro function. Its second argument must be a call to the udm macro.
  • Import org.apache.spark.sql.defineMacros._ so that the registerMacro and udm functions become implicitly available on your SparkSession.
import org.apache.spark.sql.defineMacros._

spark.registerMacro("intUDM", spark.udm({(i : Int) =>
  val b = Array(5, 6)
  val j = b(0)
  val k = new java.sql.Date(System.currentTimeMillis()).getTime
  i + j + k + Math.abs(j)
  })
)

Once registered, you can use the macro like a UDF in Spark SQL. The sql select intUDM(c_int) from sparktest.unit_test generates the following plan:

Project [((cast((c_int#2 + 5) as bigint) + 1613706648609) + cast(abs(5) as bigint)) AS ((CAST((c_int + 5) AS BIGINT) + 1613706648609) + CAST(abs(5) AS BIGINT))#3L]
+- SubqueryAlias spark_catalog.default.unit_test
   +- Relation[c_varchar2_40#0,c_number#1,c_int#2] parquet
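
A minimal sketch of invoking the registered macro from code (assuming the sparktest.unit_test table above exists and intUDM was registered on this SparkSession):

// Hypothetical usage sketch: run the macro through Spark SQL and inspect the plan.
val df = spark.sql("select intUDM(c_int) from sparktest.unit_test")
df.explain(true) // the plan shows the macro body inlined as Catalyst expressions
df.show()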

Basic examples

1. Macro:
     (i : Int) => i
   Catalyst expression:
     macroarg(0, IntegerType)

2. Macro:
     (i : java.lang.Integer) => i
   Catalyst expression:
     macroarg(0, IntegerType)

3. Macro:
     (i : Int) => i + 5
   Catalyst expression:
     (macroarg(0, IntegerType) + 5)

4. Macro:
     {(i : Int) =>
       val b = Array(5)
       val j = 5
       j
     }
   Catalyst expression:
     5

5. Macro:
     (i : Int) => org.apache.spark.SPARK_BRANCH.length + i
   Catalyst expression:
     (4 + macroarg(0, IntegerType))

6. Macro:
     {(i : Int) =>
       val b = Array(5, 6)
       val j = b(0)
       i + j + Math.abs(j)
     }
   Catalyst expression:
     ((macroarg(0, IntegerType) + 5) + abs(5))

7. Macro:
     {(i : Int) =>
       val b = Array(5)
       val j = 5
       j
     }
   Catalyst expression:
     5
  • We support all types for which Spark can infer an ExpressionEncoder.
  • See the Collection examples below for the expressions supported on Arrays.
  • In example 5, org.apache.spark.SPARK_BRANCH.length is evaluated at macro definition time (see the sketch after this list).
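
For instance, example 5 could be registered like this (a sketch; the macro name branchLenUDM is illustrative):

import org.apache.spark.sql.defineMacros._

// org.apache.spark.SPARK_BRANCH.length is evaluated when the macro is defined,
// so it appears as the literal 4 in the translated Catalyst expression.
spark.registerMacro("branchLenUDM", spark.udm({ (i: Int) =>
  org.apache.spark.SPARK_BRANCH.length + i
}))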

Scala ADTs

  • We support ADT (case class) construction and field access
  • At the macro call-site the case class must be in scope (a registration sketch follows the examples below), for example:
package macrotest {
  object ExampleStructs {
    case class Point(x: Int, y: Int)
  }
}

import macrotest.ExampleStructs.Point
1. Macro:
     {(p : Point) =>
       Point(1, 2)
     }
   Catalyst expression:
     named_struct(x, 1, y, 2)

2. Macro:
     {(p : Point) =>
       p.x + p.y
     }
   Catalyst expression:
     (
       macroarg(0, StructField(x,IntegerType,false), StructField(y,IntegerType,false)).x
       +
       macroarg(0, StructField(x,IntegerType,false), StructField(y,IntegerType,false)).y
     )

3. Macro:
     {(p : Point) =>
       Point(p.x + p.y, p.y)
     }
   Catalyst expression:
     named_struct(
       x,
       (macroarg(0, StructField(x,IntegerType,false), StructField(y,IntegerType,false)).x
        + macroarg(0, StructField(x,IntegerType,false), StructField(y,IntegerType,false)).y),
       y,
       macroarg(0, StructField(x,IntegerType,false), StructField(y,IntegerType,false)).y
     )
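
A sketch of registering ADT example 3 (the macro name pointUDM is illustrative); note that Point must be imported at the call-site:

import org.apache.spark.sql.defineMacros._
import macrotest.ExampleStructs.Point

// Translates to the named_struct expression shown in example 3 above.
spark.registerMacro("pointUDM", spark.udm({ (p: Point) =>
  Point(p.x + p.y, p.y)
}))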

Tuple examples

  • We support Tuple construction and field access (a registration sketch follows the examples below)
1. Macro:
     {(t : Tuple2[Int, Int]) =>
       (t._2, t._1)
     }
   Catalyst expression:
     named_struct(
       col1,
       macroarg(0, StructField(_1,IntegerType,false), StructField(_2,IntegerType,false))._2,
       col2,
       macroarg(0, StructField(_1,IntegerType,false), StructField(_2,IntegerType,false))._1
     )

2. Macro:
     {(t : Tuple2[Int, Int]) =>
       t._2 -> t._1
     }
   Catalyst expression:
     named_struct(
       col1,
       macroarg(0, StructField(_1,IntegerType,false), StructField(_2,IntegerType,false))._2,
       col2,
       macroarg(0, StructField(_1,IntegerType,false), StructField(_2,IntegerType,false))._1
     )

3. Macro:
     {(t : Tuple4[Float, Double, Int, Int]) =>
       (t._4 + t._3, t._4)
     }
   Catalyst expression:
     named_struct(
       col1,
       (macroarg(0, StructField(_1,FloatType,false), StructField(_2,DoubleType,false), StructField(_3,IntegerType,false), StructField(_4,IntegerType,false))._4
        + macroarg(0, StructField(_1,FloatType,false), StructField(_2,DoubleType,false), StructField(_3,IntegerType,false), StructField(_4,IntegerType,false))._3),
       col2,
       macroarg(0, StructField(_1,FloatType,false), StructField(_2,DoubleType,false), StructField(_3,IntegerType,false), StructField(_4,IntegerType,false))._4
     )
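
A sketch of registering tuple example 1 (the macro name tupleUDM is illustrative):

import org.apache.spark.sql.defineMacros._

// Swapping the tuple elements translates to the named_struct expression shown in example 1 above.
spark.registerMacro("tupleUDM", spark.udm({ (t: (Int, Int)) =>
  (t._2, t._1)
}))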

Collection examples

  • We support construction and entry access for Array and immutable Map (a registration sketch follows the examples below)
1. Macro:
     {(i : Int) =>
       val b = Array(5, i)
       val j = b(0)
       j + b(1)
     }
   Catalyst expression:
     (5 + macroarg(0, IntegerType))

2. Macro:
     {(i : Int) =>
       val b = Map(0 -> i, 1 -> (i + 1))
       val j = b(0)
       j + b(1)
     }
   Catalyst expression:
     (macroarg(0, IntegerType) + (macroarg(0, IntegerType) + 1))
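
A sketch of registering collection example 2 (the macro name mapUDM is illustrative):

import org.apache.spark.sql.defineMacros._

// The Map construction and lookups are collapsed away during translation,
// leaving only arithmetic on the macro argument.
spark.registerMacro("mapUDM", spark.udm({ (i: Int) =>
  val b = Map(0 -> i, 1 -> (i + 1))
  val j = b(0)
  j + b(1)
}))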

DateTime examples

  • Currently we support translation of the functions in Spark's DateTimeUtils module, except that the function parameters are Date and Timestamp instead of Int and Long. org.apache.spark.sql.sqlmacros.DateTimeUtils defines the supported functions with Date and Timestamp parameters.

Example

import java.sql.Date
import java.sql.Timestamp
import java.time.ZoneId
import java.time.Instant
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.sql.sqlmacros.DateTimeUtils._

{(dt : Date) =>
  val dtVal = dt
  val dtVal2 = new Date(System.currentTimeMillis())
  val tVal = new Timestamp(System.currentTimeMillis())
  val dVal3 = localDateToDays(java.time.LocalDate.of(2000, 1, 1))
  val t2 = instantToMicros(Instant.now())
  val t3 = stringToTimestamp("2000-01-01", ZoneId.systemDefault()).get
  val t4 = daysToMicros(dtVal, ZoneId.systemDefault())
  
  getDayInYear(dtVal) + getDayOfMonth(dtVal) + getDayOfWeek(dtVal2) +
  getHours(tVal, ZoneId.systemDefault) + getSeconds(t2, ZoneId.systemDefault) +
  getMinutes(t3, ZoneId.systemDefault()) +
  getDayInYear(dateAddMonths(dtVal, getMonth(dtVal2))) +
  getDayInYear(dVal3) +
  getHours(
    timestampAddInterval(t4, new CalendarInterval(1, 1, 1), ZoneId.systemDefault()),
    ZoneId.systemDefault) +
  getDayInYear(dateAddInterval(dtVal, new CalendarInterval(1, 1, 1L))) +
  monthsBetween(t2, t3, true, ZoneId.systemDefault()) +
  getDayOfMonth(getNextDateForDayOfWeek(dtVal2, "MO")) +
  getDayInYear(getLastDayOfMonth(dtVal2)) + getDayOfWeek(truncDate(dtVal, "week")) +
  getHours(toUTCTime(t3, ZoneId.systemDefault().toString), ZoneId.systemDefault())
}

is translated to:

((((((((((((((
  dayofyear(macroarg(0)) + 
  dayofmonth(macroarg(0))) + 
  dayofweek(CAST(timestamp_millis(1613603994973L) AS DATE))) + 
  hour(timestamp_millis(1613603995085L))) + 
  second(TIMESTAMP '2021-02-17 15:19:55.207')) + 
  minute(CAST('2000-01-01' AS TIMESTAMP))) + 
  dayofyear(add_months(macroarg(0), month(CAST(timestamp_millis(1613603994973L) AS DATE))))) + 
  dayofyear(make_date(2000, 1, 1))) + 
  hour(CAST(macroarg(0) AS TIMESTAMP) + INTERVAL '1 months 1 days 0.000001 seconds')) + 
  dayofyear(macroarg(0) + INTERVAL '1 months 1 days 0.000001 seconds')) + 
  months_between(TIMESTAMP '2021-02-17 15:19:55.207', CAST('2000-01-01' AS TIMESTAMP), true)) + 
  dayofmonth(next_day(CAST(timestamp_millis(1613603994973L) AS DATE), 'MO'))) + 
  dayofyear(last_day(CAST(timestamp_millis(1613603994973L) AS DATE)))) + 
  dayofweek(trunc(macroarg(0), 'week'))) + 
  hour(to_utc_timestamp(CAST('2000-01-01' AS TIMESTAMP), 'America/Los_Angeles')))

Predicates and Conditionals

We support translation of logical operators (AND, OR, NOT), comparison operators (>, >=, <, <=, ==, !=), string predicate functions (startsWith, endsWith, contains), if expressions, and case (match) statements.

Support for case statements is limited:

  • the case pattern must be of the form cq"$pat => $expr2", so no if guard in a case clause

  • the pattern must be a literal; for constructor patterns like (a, b) or Point(1, 2), the components must be literals

  • org.apache.spark.sql.sqlmacros.PredicateUtils provides an implicit class with the marker functions is_null, is_not_null, null_safe_eq(o : Any), in(a : Any*), and not_in(a : Any*) on any value. With import org.apache.spark.sql.sqlmacros.PredicateUtils._ in scope at the macro call-site you can write expressions like i.is_not_null and i.in(4, 5) (see Example 1 below).

    • Note that these functions exist only for translation. If the macro cannot be translated, the Scala function is registered with Spark as a regular UDF, and invoking that function at runtime will fail at these expressions.

Example 1

import org.apache.spark.sql.sqlmacros.PredicateUtils._
import macrotest.ExampleStructs.Point

{ (i: Int) =>
    val j = if (i > 7 && i < 20 && i.is_not_null) {
      i
    } else if (i == 6 || i.in(4, 5) ) {
      i + 1
    } else i + 2
    val k = i match {
      case 1 => i + 2
      case _ => i + 3
    }
    val l = (j, k) match {
      case (1, 2) => 1
      case (3, 4) => 2
      case _ => 3
    }
    val p = Point(k, l)
    val m = p match {
      case Point(1, 2) => 1
      case _ => 2
    }
    j + k + l + m
}

is translated to the expression tree:

((((
 IF((((macroarg(0) > 7) AND (macroarg(0) < 20)) AND (macroarg(0) IS NOT NULL)), 
   macroarg(0), 
   (IF(((macroarg(0) = 6) OR (macroarg(0) IN (4, 5))), 
     (macroarg(0) + 1), 
     (macroarg(0) + 2))))
 ) + 
 CASE 
   WHEN (macroarg(0) = 1) THEN (macroarg(0) + 2) 
   ELSE (macroarg(0) + 3) END) + 
 CASE 
   WHEN (named_struct(
          'col1', 
          (IF((((macroarg(0) > 7) AND (macroarg(0) < 20)) AND (macroarg(0) IS NOT NULL)), macroarg(0), 
            (IF(((macroarg(0) = 6) OR (macroarg(0) IN (4, 5))), 
            (macroarg(0) + 1), 
            (macroarg(0) + 2))))
          ), 
          'col2', 
          CASE WHEN (macroarg(0) = 1) THEN (macroarg(0) + 2) ELSE (macroarg(0) + 3) END
          ) = [1,2]) THEN 1 
   WHEN (named_struct(
            'col1', 
            (IF((((macroarg(0) > 7) AND (macroarg(0) < 20)) AND (macroarg(0) IS NOT NULL)), macroarg(0),
               (IF(((macroarg(0) = 6) OR (macroarg(0) IN (4, 5))),
                (macroarg(0) + 1),
                 (macroarg(0) + 2))))
            ),
            'col2',
             CASE WHEN (macroarg(0) = 1) THEN (macroarg(0) + 2) ELSE (macroarg(0) + 3) END
             ) = [3,4]) THEN 2 
   ELSE 3 END) +
 CASE WHEN (named_struct(
             'x',
              CASE WHEN (macroarg(0) = 1) THEN (macroarg(0) + 2) ELSE (macroarg(0) + 3) END,
             'y',
              CASE WHEN (named_struct('col1', 
                                      (IF((((macroarg(0) > 7) AND (macroarg(0) < 20))
                                              AND (macroarg(0) IS NOT NULL)),
                                              macroarg(0), 
                                              (IF(((macroarg(0) = 6) OR (macroarg(0) IN (4, 5))),
                                                 (macroarg(0) + 1), (macroarg(0) + 2))))
                                      ),
                                      'col2',
                                       CASE WHEN (macroarg(0) = 1) THEN (macroarg(0) + 2)
                                       ELSE (macroarg(0) + 3) END
                                       ) = [1,2]
                          ) THEN 1 
                    WHEN (named_struct('col1',
                                       (IF((((macroarg(0) > 7) AND (macroarg(0) < 20))
                                           AND (macroarg(0) IS NOT NULL)),
                                            macroarg(0), 
                                            (IF(((macroarg(0) = 6) OR (macroarg(0) IN (4, 5))),
                                               (macroarg(0) + 1), (macroarg(0) + 2))))
                                       ),
                                       'col2',
                                        CASE WHEN (macroarg(0) = 1) THEN (macroarg(0) + 2)
                                        ELSE (macroarg(0) + 3) END
                                       ) = [3,4]
                          ) THEN 2 
                    ELSE 3 END
              ) = [1,2]) THEN 1 
              ELSE 2 END
)

Example 2

{ (s: String) =>
    val i = if (s.endsWith("abc")) 1 else 0
    val j = if (s.contains("abc")) 1 else 0
    val k = if (s.is_not_null && s.not_in("abc")) 1 else 0
    i + j + k
}

is translated to the expression tree:

(((
  IF(endswith(macroarg(0), 'abc'), 1, 0)) + 
  (IF(contains(macroarg(0), 'abc'), 1, 0))) + 
  (IF(((macroarg(0) IS NOT NULL) AND (NOT (macroarg(0) IN ('abc')))), 1, 0))
)

Recursive use of macros

  • A macro definition may contain invocations of already registered macros
  • The syntax to call a registered macro is registered_macros.<macro_name>(args...)
import org.apache.spark.sql.defineMacros._
import org.apache.spark.sql.sqlmacros.registered_macros

spark.registerMacro("m2", spark.udm({(i : Int) =>
  val b = Array(5, 6)
  val j = b(0)
  val k = new java.sql.Date(System.currentTimeMillis()).getTime
  i + j + k + Math.abs(j)
})
)

spark.registerMacro("m3", spark.udm({(i : Int) =>
  val l : Int = registered_macros.m2(i)
  i + l
})
)

Then the sql select m3(c_int) from sparktest.unit_test has the following plan:

Project [(cast(c_int#3 as bigint) + ((cast((c_int#3 + 5) as bigint) + 1613707983401) + cast(abs(5) as bigint))) AS (CAST(c_int AS BIGINT) + ((CAST((c_int + 5) AS BIGINT) + 1613707983401) + CAST(abs(5) AS BIGINT)))#4L]
+- SubqueryAlias spark_catalog.default.unit_test
   +- Relation[c_varchar2_40#1,c_number#2,c_int#3] parquet

Optimized catalyst expressions

  • We collapse sparkexpr.GetMapValue, sparkexpr.GetStructField and sparkexpr.GetArrayItem expressions.
  • We also simplify Unwrap <- Wrap expression sub-trees for Option values.
1. Macro:
     {(p : Point) =>
       val p1 = Point(p.x, p.y)
       val a = Array(1)
       val m = Map(1 -> 2)
       p1.x + p1.y + a(0) + m(1)
     }
   Catalyst expression:
     (((macroarg(0, StructField(x,IntegerType,false), StructField(y,IntegerType,false)).x
        + macroarg(0, StructField(x,IntegerType,false), StructField(y,IntegerType,false)).y) + 1)
        + 2)