Skip to content

Commit

Permalink
feat: [+] #105 improve raiseError with location
Browse files Browse the repository at this point in the history
partially resolves #105
  • Loading branch information
eruizalo committed Oct 1, 2022
1 parent c0b4ee0 commit 0cfcbf2
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
package doric
package syntax

import doric.sem.Location
import org.apache.spark.sql.{functions => f}

private[syntax] trait StringColumns31 {

/**
* Throws an exception with the provided error message.
*
* @throws java.lang.RuntimeException with the error message
* @group String Type
* @see [[doric.syntax.StringColumns31.StringOperationsSyntax31.raiseError]]
*/
def raiseError(str: String)(implicit l: Location): NullColumn =
str.lit.raiseError

implicit class StringOperationsSyntax31(s: DoricColumn[String]) {

/**
Expand All @@ -20,6 +31,13 @@ private[syntax] trait StringColumns31 {
* @group String Type
* @see [[org.apache.spark.sql.functions.raise_error]]
*/
def raiseError: NullColumn = s.elem.map(f.raise_error).toDC
def raiseError(implicit l: Location): NullColumn =
concat(
s,
"\n at ".lit,
l.fileName.value.lit,
":".lit,
l.lineNumber.value.toString.lit
).elem.map(f.raise_error).toDC
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package doric
package syntax

import org.scalatest.EitherValues
import org.scalatest.{Assertion, EitherValues}
import org.scalatest.matchers.should.Matchers

import org.apache.spark.sql.{functions => f}
import org.apache.spark.sql.types.NullType

Expand All @@ -15,7 +14,16 @@ class StringColumns31Spec
describe("raiseError doric function") {
import spark.implicits._

val df = List("this is an error").toDF("errorMsg")
lazy val errorMsg = "this is an error"
lazy val df = List(errorMsg).toDF("errorMsg")

def validateExceptions(
doricExc: RuntimeException,
sparkExc: RuntimeException
): Assertion = {
doricExc.getMessage should fullyMatch regex
s"""${sparkExc.getMessage}\n( )*at ${this.getClass.getSimpleName}.scala:(\\d)+"""
}

it("should work as spark raise_error function") {
import java.lang.{RuntimeException => exception}
Expand All @@ -30,7 +38,23 @@ class StringColumns31Spec
df.select(f.raise_error(f.col("errorMsg"))).collect()
}

doricErr.getMessage shouldBe sparkErr.getMessage
validateExceptions(doricErr, sparkErr)
}

it("should be available for strings") {
import java.lang.{RuntimeException => exception}

val doricErr = intercept[exception] {
val res = df.select(raiseError(errorMsg))

res.schema.head.dataType shouldBe NullType
res.collect()
}
val sparkErr = intercept[exception] {
df.select(f.raise_error(f.col("errorMsg"))).collect()
}

validateExceptions(doricErr, sparkErr)
}
}

Expand Down

0 comments on commit 0cfcbf2

Please sign in to comment.