
[hail] memory-efficient scan #6345

Merged
merged 22 commits into hail-is:master from danking:memmory-efficient-scan2 on Jun 27, 2019

Conversation

@danking (Collaborator) commented Jun 13, 2019

This adds SpillingCollectIterator, which avoids holding more than 1000 aggregation results in memory at one time. We could do something that listens for GC events and spills data if there's high memory pressure, but that seems a bit error-prone and hard.

The number of results kept in memory is a flag on the HailContext. In C++ we can design a system that is aware of its memory usage and adjusts memory allocated to scans accordingly.

Implementation Notes

I had to add two new file operations to FS and HadoopFS because I need seekable file input streams. When we add non-Hadoop FSs, we'll need to address the interface issue.

When we overflow our in-memory buffer, we spill to a disk file, using O(n_partitions / mem_limit) files in total. We then stream through those files with a scanLeft to compute the globally valid scan state for each partition. The stream writes its results to another file, which must be on a cluster-visible file system (we use HailContext.getTemporaryFile). Finally, each partition reads that file and seeks to its scan state.
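
A minimal sketch of the offset-recording write and seek-based read described above. It is written against the raw Hadoop FileSystem API rather than Hail's FS/HadoopFS wrappers, so the method and variable names here are stand-ins, not the PR's exact code:

import java.io.{ObjectInputStream, ObjectOutputStream}
import org.apache.hadoop.fs.{FileSystem, Path}

// Sketch only: the PR goes through Hail's FS/HadoopFS wrappers (e.g.
// writeFileNoCompression); plain Hadoop FileSystem stands in here.
object ScanStateSpillSketch {
  // Write one scan state per partition, recording the byte offset where each record starts.
  def writeStates(fs: FileSystem, file: Path, states: Iterator[AnyRef]): Array[Long] = {
    val offsets = scala.collection.mutable.ArrayBuffer.empty[Long]
    val os = fs.create(file)
    try {
      states.foreach { state =>
        offsets += os.getPos // where this partition's record begins
        // a fresh ObjectOutputStream per record keeps each record independently readable
        val oos = new ObjectOutputStream(os)
        oos.writeObject(state)
        oos.flush() // flush so the next getPos reflects the true end of this record
      }
    } finally os.close()
    offsets.toArray
  }

  // On each partition: seek straight to that partition's record and deserialize it.
  def readState(fs: FileSystem, file: Path, offset: Long): AnyRef = {
    val is = fs.open(file)
    try {
      is.seek(offset)
      new ObjectInputStream(is).readObject()
    } finally is.close()
  }
}

With this shape, a task only needs the file path and its own offset, so it deserializes exactly one scan state instead of the whole collection.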

A somewhat better solution would be to eagerly scan as results come in. I leave that as future work.

Timings

Master 0.2.14-4da055db5a7b

In [1]: %%time
   ...:
   ...: import hail as hl
   ...: ht = hl.utils.range_table(10000, n_partitions=10000)
   ...: ht = ht.annotate(rank = hl.scan.count())._force_count()
CPU times: user 1.45 s, sys: 333 ms, total: 1.78 s
Wall time: 24.6 s

In [3]: %%time
   ...:
   ...: import hail as hl
   ...: ht = hl.utils.range_table(1000000, n_partitions=1000)
   ...: ht = ht.annotate(rank = hl.scan.count())._force_count()
CPU times: user 6.23 ms, sys: 1.96 ms, total: 8.19 ms
Wall time: 1.33 s

This branch

In [1]: %%time
   ...:
   ...: import hail as hl
   ...: ht = hl.utils.range_table(10000, n_partitions=10000)
   ...: ht = ht.annotate(rank = hl.scan.count())._force_count()
CPU times: user 1.36 s, sys: 297 ms, total: 1.66 s
Wall time: 27.3 s

In [2]: %%time
   ...:
   ...: import hail as hl
   ...: ht = hl.utils.range_table(1000000, n_partitions=1000)
   ...: ht = ht.annotate(rank = hl.scan.count())._force_count()
CPU times: user 4.55 ms, sys: 1.47 ms, total: 6.02 ms
Wall time: 1.38 s
@danking (Author, Collaborator) commented Jun 13, 2019

cc: @akotlar same request regarding the FS stuff.

@tpoterba ok finally right

@akotlar (Collaborator) commented Jun 13, 2019

Looks about right to me.

scanAggsPerPartition.foreach { x =>
  partitionIndices(i) = os.getPos
  i += 1
  val oos = new ObjectOutputStream(os)
@tpoterba (Collaborator) Jun 18, 2019:

can we lift the oos up outside the loop, and flush once at the end?

@tpoterba (Collaborator) Jun 18, 2019:

oh, you need to flush inside to make the position correct. Could you add a comment to that effect?

@danking (Author, Collaborator) Jun 18, 2019:

Yeah, it was pretty weird. I should verify, but I also had problems getting the right positions when re-using the same OOS.

var i = 0
val scanAggsPerPartitionFile = hc.getTemporaryFile()
HailContext.get.sFS.writeFileNoCompression(scanAggsPerPartitionFile) { os =>
  scanAggsPerPartition.foreach { x =>

@tpoterba (Collaborator) Jun 18, 2019:

use zipWithIndex instead of the i += 1? seems a little clearer.

@tpoterba (Collaborator) Jun 18, 2019:

You also don't need numPartitions + 1 scan intermediates, right? (you don't have to do the last one)

@danking (Author, Collaborator) Jun 18, 2019:

Correct. It adds like three lines of code to save not many bytes, so I removed it. I can use zipWithIndex.
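
A sketch of how the hunk above might look with both suggestions folded in: zipWithIndex replacing the manual counter, plus the comment about flushing asked for earlier. The writeObject/flush lines are inferred from the discussion, not quoted from the PR:

HailContext.get.sFS.writeFileNoCompression(scanAggsPerPartitionFile) { os =>
  scanAggsPerPartition.zipWithIndex.foreach { case (x, i) =>
    partitionIndices(i) = os.getPos
    // a fresh ObjectOutputStream per object, flushed immediately, so getPos points at a
    // self-contained record that a later seek + ObjectInputStream can deserialize
    val oos = new ObjectOutputStream(os)
    oos.writeObject(x)
    oos.flush()
  }
}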

  if (codec != null)
    codec.createOutputStream(os)
  else
    os
}

private def createNoCompresion(filename: String): FSDataOutputStream = {

@tpoterba (Collaborator) Jun 18, 2019:

typo

size += a.length
if (size > sizeLimit) {
  val file = hc.getTemporaryFile()
  fs.writeFileNoCompression(file) { os =>

@tpoterba (Collaborator) Jun 18, 2019:

same comment about the object output stream -- lift if possible, and add a comment

}
}

class SpillingCollectIterator[T: ClassTag] private (rdd: RDD[T], sizeLimit: Int) extends Iterator[T] {

@tpoterba (Collaborator) Jun 18, 2019:

This needs unit tests -- can use property-based testing to generate random arrays, partitioning, size limits, then compare array.iterator with SpillingCollectIterator(sc.parallelize(array, nPartitions), sizeLimit)
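
A sketch of that property test using ScalaCheck. It assumes an in-scope SparkContext `sc`, a ScalaCheck harness, and a companion `SpillingCollectIterator.apply(rdd, sizeLimit)` matching the private constructor above; none of those are confirmed by this thread beyond the review comment itself:

import org.apache.spark.SparkContext
import org.scalacheck.{Gen, Prop, Properties}

// import of SpillingCollectIterator omitted; its package isn't shown in this thread
class SpillingCollectIteratorProps(sc: SparkContext)
    extends Properties("SpillingCollectIterator") {

  property("matches plain array iteration") = Prop.forAll(
    Gen.listOf(Gen.choose(Int.MinValue, Int.MaxValue)), // random data
    Gen.choose(1, 16),                                  // random partitioning
    Gen.choose(1, 64)                                   // random in-memory size limit
  ) { (xs, nPartitions, sizeLimit) =>
    val array = xs.toArray
    val spilled = SpillingCollectIterator(sc.parallelize(array, nPartitions), sizeLimit)
    // spilling should change only where the elements live, never their values or order
    spilled.toSeq == array.toSeq
  }
}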

@danking (Author, Collaborator) commented Jun 19, 2019

For posterity, this is what goes wrong if I don't create a fresh OOS for each object:


Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2158)
	at org.apache.spark.rdd.RDD$$anonfun$fold$1.apply(RDD.scala:1098)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
	at org.apache.spark.rdd.RDD.fold(RDD.scala:1092)
	at is.hail.rvd.RVD.count(RVD.scala:660)
	at is.hail.methods.ForceCountTable.execute(ForceCount.scala:11)
	at is.hail.expr.ir.Interpret$.apply(Interpret.scala:771)
	at is.hail.expr.ir.Interpret$.apply(Interpret.scala:88)
	at is.hail.expr.ir.Interpret$.apply(Interpret.scala:59)
	at is.hail.expr.ir.InterpretNonCompilable$$anonfun$7.apply(InterpretNonCompilable.scala:19)
	at is.hail.expr.ir.InterpretNonCompilable$$anonfun$7.apply(InterpretNonCompilable.scala:19)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
	at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
	at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
	at is.hail.expr.ir.InterpretNonCompilable$.apply(InterpretNonCompilable.scala:19)
	at is.hail.expr.ir.CompileAndEvaluate$$anonfun$2.apply(CompileAndEvaluate.scala:36)
	at is.hail.expr.ir.CompileAndEvaluate$$anonfun$2.apply(CompileAndEvaluate.scala:36)
	at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:20)
	at is.hail.expr.ir.CompileAndEvaluate$.apply(CompileAndEvaluate.scala:36)
	at is.hail.backend.Backend.execute(Backend.scala:61)
	at is.hail.backend.Backend.executeJSON(Backend.scala:67)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)

java.io.StreamCorruptedException: invalid stream header: 7571007E
	at java.io.ObjectInputStream.readStreamHeader(ObjectInputStream.java:866)
	at java.io.ObjectInputStream.<init>(ObjectInputStream.java:358)
	at is.hail.expr.ir.TableMapRows$$anonfun$41$$anonfun$42.apply(TableIR.scala:892)
	at is.hail.expr.ir.TableMapRows$$anonfun$41$$anonfun$42.apply(TableIR.scala:890)
	at is.hail.utils.package$.using(package.scala:597)
	at is.hail.io.fs.HadoopFS.readFileNoCompression(HadoopFS.scala:407)
	at is.hail.expr.ir.TableMapRows$$anonfun$41.apply(TableIR.scala:890)
	at is.hail.expr.ir.TableMapRows$$anonfun$41.apply(TableIR.scala:889)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitionsWithIndexAndValue$1$$anonfun$apply$34.apply(ContextRDD.scala:434)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitionsWithIndexAndValue$1$$anonfun$apply$34.apply(ContextRDD.scala:434)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitions$1$$anonfun$apply$28$$anonfun$apply$29.apply(ContextRDD.scala:405)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitions$1$$anonfun$apply$28$$anonfun$apply$29.apply(ContextRDD.scala:405)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:435)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:441)
	at scala.collection.Iterator$class.foreach(Iterator.scala:891)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
	at is.hail.rvd.RVD$$anonfun$count$2.apply(RVD.scala:655)
	at is.hail.rvd.RVD$$anonfun$count$2.apply(RVD.scala:653)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitions$1$$anonfun$apply$28.apply(ContextRDD.scala:405)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitions$1$$anonfun$apply$28.apply(ContextRDD.scala:405)
	at is.hail.sparkextras.ContextRDD$$anonfun$run$1$$anonfun$apply$8.apply(ContextRDD.scala:192)
	at is.hail.sparkextras.ContextRDD$$anonfun$run$1$$anonfun$apply$8.apply(ContextRDD.scala:192)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:435)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:441)
	at scala.collection.Iterator$class.foreach(Iterator.scala:891)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
	at scala.collection.TraversableOnce$class.foldLeft(TraversableOnce.scala:157)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1334)
	at scala.collection.TraversableOnce$class.fold(TraversableOnce.scala:212)
	at scala.collection.AbstractIterator.fold(Iterator.scala:1334)
	at org.apache.spark.rdd.RDD$$anonfun$fold$1$$anonfun$20.apply(RDD.scala:1096)
	at org.apache.spark.rdd.RDD$$anonfun$fold$1$$anonfun$20.apply(RDD.scala:1096)
	at org.apache.spark.SparkContext$$anonfun$36.apply(SparkContext.scala:2157)
	at org.apache.spark.SparkContext$$anonfun$36.apply(SparkContext.scala:2157)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:403)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:409)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
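
The failure above is easy to reproduce locally with nothing but java.io: with a single shared ObjectOutputStream, only the first record is preceded by the object-stream header, so an ObjectInputStream opened at a later record's offset reads arbitrary bytes where it expects that header and throws StreamCorruptedException. A minimal sketch (plain local files, no Hail or Hadoop involved):

import java.io._

object SharedOosRepro {
  def main(args: Array[String]): Unit = {
    val f = File.createTempFile("scan-states", ".bin")
    val fos = new FileOutputStream(f)
    val oos = new ObjectOutputStream(fos) // one shared OOS for all records
    oos.writeObject("state-0")
    oos.flush()
    val secondOffset = fos.getChannel.position() // offset of the second record
    oos.writeObject("state-1")
    oos.close()

    val fis = new FileInputStream(f)
    fis.getChannel.position(secondOffset) // seek to the second record
    try {
      // fails: no stream header at this offset, unlike with a fresh OOS per record
      new ObjectInputStream(fis).readObject()
    } catch {
      case e: StreamCorruptedException => println(s"as expected: $e")
    } finally fis.close()
  }
}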

@danking danking force-pushed the danking:memmory-efficient-scan2 branch from e5eeb9d to ee9fa94 Jun 19, 2019

@danking (Author, Collaborator) commented Jun 19, 2019

@tpoterba all comments addressed

danking added some commits Jun 25, 2019

@danking danking merged commit 2907766 into hail-is:master Jun 27, 2019

1 check passed: ci-test (success)