
Commit

Add one passed, fill-up test fails
dmittov committed Aug 23, 2020
1 parent 33cb8d2 commit 35fd7ba
Showing 3 changed files with 139 additions and 51 deletions.
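For orientation before the diffs: the aggregate this commit fleshes out is exposed as collect_sample and is meant to keep a fixed-size, equal-probability sample per group. The snippet below is only a usage sketch assembled from calls that appear in CollectSampleSpec further down; it assumes an existing SparkSession named spark and a DataFrame df with id and name columns, and is not itself part of the commit.

import spark.implicits._

// Keep at most 2 names per id; each row enters the reservoir with weight 1.0.
val sampled = df
  .groupBy($"id")
  .agg(CollectLimit.collect_sample($"name", 2).as("names"))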
5 changes: 2 additions & 3 deletions build.sbt
@@ -4,10 +4,8 @@ organization := "io.github.dmittov"
 
 version := "0.2"
 
-scalaVersion := "2.11.12"
+scalaVersion := "2.12.10"
 
-// EMR-5.30.0
-// https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-release-5x.html
 val sparkVersion = "2.4.5"
 
 resolvers += "MavenRepository" at "https://mvnrepository.com/"
@@ -17,6 +15,7 @@ libraryDependencies ++= Seq(
   "org.apache.spark" %% "spark-core" % sparkVersion % "provided" withSources(),
   "org.apache.spark" %% "spark-sql" % sparkVersion % "provided" withSources(),
   "org.apache.spark" %% "spark-catalyst" % sparkVersion % "provided" withSources(),
+  // FIXME: no spark-testing-base for 2.4.6 yet
   "com.holdenkarau" %% "spark-testing-base" % "2.4.5_0.14.0" % "test",
   "org.apache.spark" %% "spark-hive" % sparkVersion % "test"
 )
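The build stays on Spark 2.4.5 even though the Scala version moves to 2.12.10, because spark-testing-base has no artifact for 2.4.6 yet (hence the FIXME above and the pinned "2.4.5_0.14.0" version). Once a matching release exists, the test dependency could be derived from sparkVersion instead of being hard-coded; the line below is a suggestion, not part of the commit:

// Hypothetical follow-up, assuming an upstream spark-testing-base release for the new Spark version:
"com.holdenkarau" %% "spark-testing-base" % s"${sparkVersion}_0.14.0" % "test"

The second changed file (its path is not shown in this capture) holds the Reservoir and CollectSample implementation.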
@@ -4,50 +4,83 @@ import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
 
 import org.apache.commons.io.output.ByteArrayOutputStream
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, ImperativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types.{ArrayType, DataType}
 
 import scala.collection.mutable
+import scala.util.Random
 
 case class WeightedItem[T](
-    body: T,
-    weight: Double
+    body: T,
+    weight: Double,
+    priority: Double
 )
 
+object WeightedItem {
+  implicit def orderingByPriority[A <: WeightedItem[_]]: Ordering[A] =
+    Ordering.by(item => item.priority)
+}
+
+/**
+ * A-ExpJ reservoir implementation with a merge option.
+ * @param buffer reservoir PriorityQueue; dequeue extracts the max element, so priority is
+ *               set to -log(weightKey) of the original algorithm
+ * @param weightsToSkip remaining total weight to skip before the next replacement (the exponential jump)
+ * @param passedWeight total weight of the items seen so far
+ * @param limit reservoir size
+ * @tparam T item type
+ */
 case class Reservoir[T](
-    buffer: mutable.ArrayBuffer[WeightedItem[T]],
-    var weight: Double,
+    buffer: mutable.PriorityQueue[WeightedItem[T]],
+    weightsToSkip: Double,
+    passedWeight: Double,
     limit: Int) {
 
   def isFull: Boolean = buffer.length == limit
 
-  def add(item: WeightedItem[T]): Reservoir[T] = {
+  def add(item: T, weight: Double): Reservoir[T] = {
     if (isFull) {
-      // FIXME: replace stub with implementation
-      replace(item, 0)
-    } else append(item)
+      if (weight >= weightsToSkip) {
+        replace(item, weight)
+      } else discard(weight)
+    } else append(item, weight)
   }
 
   def merge(other: Reservoir[T]): Reservoir[T] = {
-    // FIXME: replace stub with implementation
-    this.copy(weight = weight + other.weight)
+    copy(
+      passedWeight = passedWeight + other.passedWeight,
+      weightsToSkip = weightsToSkip + other.weightsToSkip
+    )
   }
 
-  private[this] def append(item: WeightedItem[T]): Reservoir[T] = {
-    copy(buffer = buffer :+ item, weight = weight + item.weight)
+  private[this] def discard(weight: Double): Reservoir[T] = {
+    copy(passedWeight = passedWeight + weight, weightsToSkip = weightsToSkip - weight)
   }
 
-  private[this] def replace(item: WeightedItem[T], position: Int): Reservoir[T] = {
-    this.buffer(position) = item
-    copy(weight = weight + item.weight)
+  private[this] def append(item: T, weight: Double): Reservoir[T] = {
+    val negPriority = math.log(Random.nextDouble()) / weight
+    buffer += WeightedItem(item, weight, -negPriority)
+    val weightsToSkip = if (buffer.length == limit) math.log(Random.nextDouble) / (-buffer.head.priority) else 0.0
+    copy(passedWeight = passedWeight + weight, weightsToSkip = weightsToSkip)
   }
+
+  private[this] def replace(item: T, weight: Double): Reservoir[T] = {
+    val offset = math.exp(-buffer.head.priority * weight)
+    val negPriority = math.log((1 - offset) * Random.nextDouble + offset) / weight
+    buffer.dequeue()
+    buffer += WeightedItem(item, weight, -negPriority)
+    copy(
+      passedWeight = passedWeight + weight,
+      weightsToSkip = math.log(Random.nextDouble) / (-buffer.head.priority)
+    )
+  }
 }

 @ExpressionDescription(
   usage = "_FUNC_(expr) - Collects and returns a list of <size> non-unique elements " +
     "out of origin list with equal probabilities.")
 private case class CollectSample(
     child: Expression,
     limit: Int,
@@ -65,13 +98,13 @@ private case class CollectSample(
   override def prettyName: String = "collect_sample"
 
   override def createAggregationBuffer(): Reservoir[Any] =
-    Reservoir(mutable.ArrayBuffer.empty, 0.0, limit)
+    Reservoir(mutable.PriorityQueue.empty, 0.0, 0.0, limit)
 
   override def update(reservoir: Reservoir[Any], input: InternalRow): Reservoir[Any] = {
     val value = child.eval(input)
     if (value != null) {
-      val item = WeightedItem(body = InternalRow.copyValue(value), weight = 1.0)
-      reservoir.add(item)
+      val item = InternalRow.copyValue(value)
+      reservoir.add(item, 1.0)
     } else reservoir
   }
 
@@ -87,7 +120,14 @@ private case class CollectSample(
       else (rightReservoir, leftReservoir)
     }
   }
-    incompleteReservoir.buffer.foldLeft(reservoir) { (reservoir, item) => reservoir.add(item) }
+    val itemWeight = reservoir.passedWeight / reservoir.buffer.length
+    reservoir.buffer.foldLeft(incompleteReservoir) {
+      (reservoir, item) => reservoir.add(item.body, itemWeight)
+    }
+    val resultReservoir = incompleteReservoir.copy(
+      weightsToSkip = incompleteReservoir.weightsToSkip + reservoir.weightsToSkip
+    )
+    resultReservoir
    }
  }

@@ -115,5 +155,4 @@ private case class CollectSample(
   override def nullable: Boolean = true
 
   override def dataType: DataType = ArrayType(child.dataType)
-
 }
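For readers who want the algorithm behind the Reservoir class in one place: the commit follows A-ExpJ weighted reservoir sampling (Efraimidis and Spirakis), where each item gets a key u^(1/w) for u uniform in (0, 1), the k largest keys are kept, and runs of low-weight items are skipped in one exponential jump instead of drawing a key per item. The sketch below is written for this page, not taken from the repository; the names AExpJSampler, offer and resetSkip are invented, and it keeps the raw key where the commit stores -log(key) as priority in a max-PriorityQueue (the doc comment above explains that choice). The skip field plays the role of weightsToSkip.

import scala.collection.mutable
import scala.util.Random

// Minimal, Spark-independent A-ExpJ sketch. Edge cases such as u == 0.0 are ignored.
final case class Keyed[T](item: T, key: Double)

final class AExpJSampler[T](k: Int) {
  // Reversed ordering turns PriorityQueue into a min-heap on the key:
  // heap.head is always the smallest key, i.e. the replacement candidate.
  private val heap =
    mutable.PriorityQueue.empty[Keyed[T]](Ordering.by[Keyed[T], Double](_.key).reverse)
  private var skip = 0.0 // remaining weight to skip before the next replacement
  private val rng = new Random()

  def offer(item: T, weight: Double): Unit = {
    require(weight > 0.0)
    if (heap.size < k) {
      heap.enqueue(Keyed(item, math.pow(rng.nextDouble(), 1.0 / weight)))
      if (heap.size == k) resetSkip()
    } else if (weight < skip) {
      skip -= weight // the whole item falls inside the jump: discard it
    } else {
      val tw = math.pow(heap.head.key, weight)    // lower bound T_w^w for the new key
      val r2 = tw + (1.0 - tw) * rng.nextDouble() // uniform in [T_w^w, 1)
      heap.dequeue()                              // drop the current minimum key
      heap.enqueue(Keyed(item, math.pow(r2, 1.0 / weight)))
      resetSkip()
    }
  }

  // Exponential jump: X_w = log(r) / log(T_w), the total weight to pass over.
  private def resetSkip(): Unit =
    skip = math.log(rng.nextDouble()) / math.log(heap.head.key)

  def sample: Seq[T] = heap.toSeq.map(_.item)
}

Working with -log(key), as the commit does, maps the k largest keys onto the k smallest priorities so a single max-queue suffices, and it also avoids the key u^(1/w) underflowing to zero for very small weights.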
@@ -4,12 +4,60 @@ import com.holdenkarau.spark.testing.{DataFrameSuiteBase, RDDComparisons, SharedSparkContext}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.{explode, struct, udf}
 import org.scalatest.{FunSpec, Matchers}
+import org.apache.spark.sql.functions._
 
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.util.Random
 
 class CollectSampleSpec extends FunSpec with Matchers with DataFrameSuiteBase
   with RDDComparisons with SharedSparkContext {
 
+  import spark.implicits._
+
+  describe("Reservoir tests") {
+
+    it("add one test") {
+      val testsCount = 10000
+      val testResults = 1 to testsCount map (_ => {
+        val initReservoir = Reservoir[Int](
+          mutable.PriorityQueue[WeightedItem[Int]](), 0.0, 0.0, 1
+        ).add(0, 100.0)
+        val updatedReservoir = initReservoir.add(1, 100.0)
+        if (updatedReservoir.buffer.map(_.body).head == 1) 1.0d else 0.0
+      })
+      val successRate = testResults.sum / testResults.length
+      val diff = math.abs(successRate - 0.5)
+      // 3-sigma chance to fail
+      assert(diff < 3.0 / 200)
+    }
+
+    it("fillup test") {
+      val testsCount = 10000
+      val testSize = 100
+      val restResults = 1 to testsCount map (_ => {
+        val order = Random.nextInt(testSize + 1)
+        val testReservoir = (0 to testSize).foldLeft[Reservoir[Int]](
+          Reservoir[Int](mutable.PriorityQueue[WeightedItem[Int]](), 0.0, 0.0, 2)
+        ) {
+          (reservoir: Reservoir[Int], position: Int) => {
+            if (position == order) {
+              reservoir.add(1, testSize.toDouble)
+            } else {
+              reservoir.add(0, 1.0)
+            }
+          }
+        }
+        val survivedItems: Set[Int] = testReservoir.buffer.map(_.body).toSet
+        if (survivedItems contains 1) 1.0d else 0.0d
+
+      })
+      val successRate = restResults.sum / restResults.length
+      print(successRate)
+      assert(successRate > 0.45 && successRate < 0.55)
+    }
+  }
+
   describe("collect_sample tests") {
     it("check result cardinality") {
       val df = sc.parallelize(DataStubs.Primarchs).toDF("loyal", "id", "name")
@@ -27,33 +75,35 @@ class CollectSampleSpec extends FunSpec with Matchers with DataFrameSuiteBase
       compareRDD(limited.select($"loyal", size($"sample").as("sz")).rdd, correct) should be(None)
     }
 
     it("check simple datatype") {
       val df = sc.parallelize(List(1, 2, 3, 4, 5, 6, 7, 8, 9, 0)).toDF("num")
       val size = udf { x: Seq[Row] => x.size }
       val limited = df.agg(CollectLimit.collect_sample($"num", 4).as("lst"))
       val correct = sc.parallelize(List(Row(4)))
       compareRDD(limited.select(size($"lst").as("cnt")).rdd, correct) should be(None)
     }
 
     it("incomplete buckets") {
       val df = sc.parallelize(DataStubs.Primarchs).toDF("loyal", "id", "name")
       val limited = df.
         groupBy($"id").
         agg(
           CollectLimit.collect_sample($"name", 2).as("names")
         )
       val result = limited.select($"id", explode($"names").as("name")).rdd
       val correct = df.select($"id", $"name").rdd
       compareRDD(result, correct) should be(None)
     }
 
-    it("null test") {
-      val df = sc.parallelize(
-        List(Some(1), None, Some(5), None, Some(2), None, Some(3), None)
-      ).toDF("num")
-      val limited = df.agg(CollectLimit.collect_sample($"num", 2).as("nums"))
-      val correct = sc.parallelize(List(Row(List(1, 5))))
-      compareRDD(limited.rdd, correct) should be(None)
-    }
+    it("null test") {
+      val df = sc.parallelize(
+        List(Some(1), None, Some(5), None, None, None)
+      ).toDF("num")
+      val limited = df.
+        agg(CollectLimit.collect_sample($"num", 2).as("nums")).
+        select(sort_array($"nums").as("nums"))
+      val correct = sc.parallelize(List(Row(List(1, 5))))
+      compareRDD(limited.rdd, correct) should be(None)
+    }
   }
 }
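A note on the statistical assertions above, to make the thresholds easy to audit. Each test estimates a success rate from 10000 independent trials, so the standard error of that rate is sqrt(p * (1 - p) / n); for p = 0.5 and n = 10000 that is 0.005, and the bound in "add one test" is exactly three of those: 3 * 0.005 = 0.015 = 3.0 / 200, i.e. roughly a 0.3% chance of a false failure. "fillup test" uses the much looser 0.45..0.55 window around the same nominal 0.5 rate, and per the commit message it is the assertion that currently fails. A quick standalone check (not part of the commit):

val n = 10000
val p = 0.5
val sigma = math.sqrt(p * (1 - p) / n) // 0.005
val threeSigma = 3 * sigma             // 0.015, i.e. 3.0 / 200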
