Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hail] memory-efficient scan #6345

Merged
merged 22 commits into from
Jun 27, 2019
Merged

Conversation

danking
Copy link
Contributor

@danking danking commented Jun 13, 2019

This adds SpillingCollectIterator which avoids holding more than 1000 aggregation results in memory at one time. We could do something that listens for GC events and spills data if there's high memory pressure. That seems a bit error prone and hard.

The number of results kept in memory is a flag on the HailContext. In C++ we can design a system that is aware of its memory usage and adjusts memory allocated to scans accordingly.

Implementation Notes

I had to add two new file operations to FS and HadoopFS because I need seekable file input streams. When we add non-hadoop FS's we'll need to address the interface issue.

When we overflow our in-memory buffer, we spill to a disk file. We use O(n_partitions / mem_limit) files. We stream through the files to scanLeft, to compute the globally valid scan state per partition. The stream writes its results to another file which must be on a cluster-visible file system (we use HailContext.getTemporaryFile). Finally, each partition reads that file and seeks to its scan state.

I somewhat better solution would be to eagerly scan as results come in. I leave that as future work.

Timings

Master 0.2.14-4da055db5a7b

In [1]: %%time  
   ...:  
   ...: import hail as hl 
   ...: ht = hl.utils.range_table(10000, n_partitions=10000) 
   ...: ht = ht.annotate(rank = hl.scan.count())._force_count()                                                                                                                                             
CPU times: user 1.45 s, sys: 333 ms, total: 1.78 s
Wall time: 24.6 s
In [3]: %%time  
   ...:  
   ...: import hail as hl 
   ...: ht = hl.utils.range_table(1000000, n_partitions=1000) 
   ...: ht = ht.annotate(rank = hl.scan.count())._force_count()                                                                                                                                             
CPU times: user 6.23 ms, sys: 1.96 ms, total: 8.19 ms
Wall time: 1.33 s

This branch

In [1]: %%time  
   ...:  
   ...: import hail as hl 
   ...: ht = hl.utils.range_table(10000, n_partitions=10000) 
   ...: ht = ht.annotate(rank = hl.scan.count())._force_count()                                                                                                                                                                                                                 
CPU times: user 1.36 s, sys: 297 ms, total: 1.66 s
Wall time: 27.3 s

In [2]: %%time  
   ...:  
   ...: import hail as hl 
   ...: ht = hl.utils.range_table(1000000, n_partitions=1000) 
   ...: ht = ht.annotate(rank = hl.scan.count())._force_count()                                                                                                                                                                                                                 
CPU times: user 4.55 ms, sys: 1.47 ms, total: 6.02 ms
Wall time: 1.38 s

@danking
Copy link
Contributor Author

danking commented Jun 13, 2019

cc: @akotlar same request regarding the FS stuff.

@tpoterba ok finally right

@akotlar
Copy link
Contributor

akotlar commented Jun 13, 2019

Looks about right to me.

scanAggsPerPartition.foreach { x =>
partitionIndices(i) = os.getPos
i += 1
val oos = new ObjectOutputStream(os)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we lift the oos up outside the loop, and flush once at the end?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, you need to flush inside to make the position correct. Could you add a comment to that effect?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it was pretty weird, I should verify but I also had problems getting the right positions when re-using the same OOS

var i = 0
val scanAggsPerPartitionFile = hc.getTemporaryFile()
HailContext.get.sFS.writeFileNoCompression(scanAggsPerPartitionFile) { os =>
scanAggsPerPartition.foreach { x =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use zipWithIndex instead of the i += 1? seems a little clearer.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You also don't need numPartitions + 1 scan intermediates, right? (you don't have to do the last one)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. It adds like three lines of code to save not many bytes so I removed it. I can use zipWithIndex

if (codec != null)
codec.createOutputStream(os)
else
os
}

private def createNoCompresion(filename: String): FSDataOutputStream = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo

size += a.length
if (size > sizeLimit) {
val file = hc.getTemporaryFile()
fs.writeFileNoCompression(file) { os =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment about object output stream -- lift if possible add comment

}
}

class SpillingCollectIterator[T: ClassTag] private (rdd: RDD[T], sizeLimit: Int) extends Iterator[T] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs unit tests -- can use property-based testing to generate random arrays, partitioning, size limits, then compare array.iterator with SpillingCollectIterator(sc.parallelize(array, nPartitions), sizeLimit)

@danking
Copy link
Contributor Author

danking commented Jun 19, 2019

For posterity this is what goes wrong if I don't create a fresh OOS for each object:


Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2158)
	at org.apache.spark.rdd.RDD$$anonfun$fold$1.apply(RDD.scala:1098)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
	at org.apache.spark.rdd.RDD.fold(RDD.scala:1092)
	at is.hail.rvd.RVD.count(RVD.scala:660)
	at is.hail.methods.ForceCountTable.execute(ForceCount.scala:11)
	at is.hail.expr.ir.Interpret$.apply(Interpret.scala:771)
	at is.hail.expr.ir.Interpret$.apply(Interpret.scala:88)
	at is.hail.expr.ir.Interpret$.apply(Interpret.scala:59)
	at is.hail.expr.ir.InterpretNonCompilable$$anonfun$7.apply(InterpretNonCompilable.scala:19)
	at is.hail.expr.ir.InterpretNonCompilable$$anonfun$7.apply(InterpretNonCompilable.scala:19)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
	at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
	at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
	at is.hail.expr.ir.InterpretNonCompilable$.apply(InterpretNonCompilable.scala:19)
	at is.hail.expr.ir.CompileAndEvaluate$$anonfun$2.apply(CompileAndEvaluate.scala:36)
	at is.hail.expr.ir.CompileAndEvaluate$$anonfun$2.apply(CompileAndEvaluate.scala:36)
	at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:20)
	at is.hail.expr.ir.CompileAndEvaluate$.apply(CompileAndEvaluate.scala:36)
	at is.hail.backend.Backend.execute(Backend.scala:61)
	at is.hail.backend.Backend.executeJSON(Backend.scala:67)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)

java.io.StreamCorruptedException: invalid stream header: 7571007E
	at java.io.ObjectInputStream.readStreamHeader(ObjectInputStream.java:866)
	at java.io.ObjectInputStream.<init>(ObjectInputStream.java:358)
	at is.hail.expr.ir.TableMapRows$$anonfun$41$$anonfun$42.apply(TableIR.scala:892)
	at is.hail.expr.ir.TableMapRows$$anonfun$41$$anonfun$42.apply(TableIR.scala:890)
	at is.hail.utils.package$.using(package.scala:597)
	at is.hail.io.fs.HadoopFS.readFileNoCompression(HadoopFS.scala:407)
	at is.hail.expr.ir.TableMapRows$$anonfun$41.apply(TableIR.scala:890)
	at is.hail.expr.ir.TableMapRows$$anonfun$41.apply(TableIR.scala:889)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitionsWithIndexAndValue$1$$anonfun$apply$34.apply(ContextRDD.scala:434)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitionsWithIndexAndValue$1$$anonfun$apply$34.apply(ContextRDD.scala:434)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitions$1$$anonfun$apply$28$$anonfun$apply$29.apply(ContextRDD.scala:405)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitions$1$$anonfun$apply$28$$anonfun$apply$29.apply(ContextRDD.scala:405)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:435)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:441)
	at scala.collection.Iterator$class.foreach(Iterator.scala:891)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
	at is.hail.rvd.RVD$$anonfun$count$2.apply(RVD.scala:655)
	at is.hail.rvd.RVD$$anonfun$count$2.apply(RVD.scala:653)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitions$1$$anonfun$apply$28.apply(ContextRDD.scala:405)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitions$1$$anonfun$apply$28.apply(ContextRDD.scala:405)
	at is.hail.sparkextras.ContextRDD$$anonfun$run$1$$anonfun$apply$8.apply(ContextRDD.scala:192)
	at is.hail.sparkextras.ContextRDD$$anonfun$run$1$$anonfun$apply$8.apply(ContextRDD.scala:192)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:435)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:441)
	at scala.collection.Iterator$class.foreach(Iterator.scala:891)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
	at scala.collection.TraversableOnce$class.foldLeft(TraversableOnce.scala:157)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1334)
	at scala.collection.TraversableOnce$class.fold(TraversableOnce.scala:212)
	at scala.collection.AbstractIterator.fold(Iterator.scala:1334)
	at org.apache.spark.rdd.RDD$$anonfun$fold$1$$anonfun$20.apply(RDD.scala:1096)
	at org.apache.spark.rdd.RDD$$anonfun$fold$1$$anonfun$20.apply(RDD.scala:1096)
	at org.apache.spark.SparkContext$$anonfun$36.apply(SparkContext.scala:2157)
	at org.apache.spark.SparkContext$$anonfun$36.apply(SparkContext.scala:2157)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:403)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:409)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

@danking
Copy link
Contributor Author

danking commented Jun 19, 2019

@tpoterba all comments addressed

@danking danking merged commit 2907766 into hail-is:master Jun 27, 2019
@danking danking deleted the memmory-efficient-scan2 branch December 18, 2019 01:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants