[hail] memory-efficient scan #6345

Merged: 22 commits merged into hail-is:master on Jun 27, 2019

Conversation

@danking (Collaborator) commented Jun 13, 2019

This adds SpillingCollectIterator, which avoids holding more than 1000 aggregation results in memory at a time. We could instead listen for GC events and spill data under high memory pressure, but that seems error-prone and hard.

The number of results kept in memory is a flag on the HailContext. In C++ we can design a system that is aware of its memory usage and adjusts memory allocated to scans accordingly.

Implementation Notes

I had to add two new file operations to FS and HadoopFS because I need seekable file input streams. When we add non-Hadoop FSs, we'll need to address the interface issue.

When we overflow our in-memory buffer, we spill to a file on disk; in total we use O(n_partitions / mem_limit) files. We then stream through the files, scanLeft-ing to compute the globally valid scan state for each partition. The stream writes its results to another file, which must be on a cluster-visible file system (we use HailContext.getTemporaryFile). Finally, each partition reads that file and seeks to its scan state.
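
As a point of reference, here is a minimal sketch (not the PR's code) of the offset-index pattern described above: record os.getPos before serializing each partition's scan state, and have each partition seek to its own offset when reading. It assumes Hadoop-style FSDataOutputStream/FSDataInputStream and byte-array scan states; the actual Hail helper signatures differ.

import java.io.{ObjectInputStream, ObjectOutputStream}
import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream}

// Driver side: write one scan state per partition, remembering where each entry starts.
def writeScanStates(os: FSDataOutputStream, states: Iterator[Array[Byte]]): Array[Long] = {
  val offsets = scala.collection.mutable.ArrayBuffer.empty[Long]
  states.foreach { state =>
    offsets += os.getPos                 // offset of this partition's entry
    val oos = new ObjectOutputStream(os) // fresh OOS per entry; see the review thread below
    oos.writeObject(state)
    oos.flush()                          // flush so the next getPos is correct
  }
  offsets.toArray
}

// Worker side: each partition seeks to its own offset and deserializes only its state.
def readScanState(is: FSDataInputStream, offset: Long): Array[Byte] = {
  is.seek(offset)
  new ObjectInputStream(is).readObject().asInstanceOf[Array[Byte]]
}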

A somewhat better solution would be to eagerly scan as results come in; I leave that as future work.

Timings

Master 0.2.14-4da055db5a7b

In [1]: %%time  
   ...:  
   ...: import hail as hl 
   ...: ht = hl.utils.range_table(10000, n_partitions=10000) 
   ...: ht = ht.annotate(rank = hl.scan.count())._force_count()                                                                                                                                             
CPU times: user 1.45 s, sys: 333 ms, total: 1.78 s
Wall time: 24.6 s
In [3]: %%time  
   ...:  
   ...: import hail as hl 
   ...: ht = hl.utils.range_table(1000000, n_partitions=1000) 
   ...: ht = ht.annotate(rank = hl.scan.count())._force_count()                                                                                                                                             
CPU times: user 6.23 ms, sys: 1.96 ms, total: 8.19 ms
Wall time: 1.33 s

This branch

In [1]: %%time  
   ...:  
   ...: import hail as hl 
   ...: ht = hl.utils.range_table(10000, n_partitions=10000) 
   ...: ht = ht.annotate(rank = hl.scan.count())._force_count()                                                                                                                                                                                                                 
CPU times: user 1.36 s, sys: 297 ms, total: 1.66 s
Wall time: 27.3 s

In [2]: %%time  
   ...:  
   ...: import hail as hl 
   ...: ht = hl.utils.range_table(1000000, n_partitions=1000) 
   ...: ht = ht.annotate(rank = hl.scan.count())._force_count()                                                                                                                                                                                                                 
CPU times: user 4.55 ms, sys: 1.47 ms, total: 6.02 ms
Wall time: 1.38 s
@danking (Collaborator, Author) commented Jun 13, 2019

cc: @akotlar same request regarding the FS stuff.

@tpoterba ok finally right

@akotlar (Contributor) commented Jun 13, 2019

Looks about right to me.

scanAggsPerPartition.foreach { x =>
  partitionIndices(i) = os.getPos
  i += 1
  val oos = new ObjectOutputStream(os)
@tpoterba (Collaborator) commented Jun 18, 2019

can we lift the oos up outside the loop, and flush once at the end?

@tpoterba (Collaborator) commented Jun 18, 2019

oh, you need to flush inside to make the position correct. Could you add a comment to that effect?

@danking (Collaborator, Author) commented Jun 18, 2019

Yeah, it was pretty weird. I should verify, but I also had problems getting the right positions when reusing the same OOS.

var i = 0
val scanAggsPerPartitionFile = hc.getTemporaryFile()
HailContext.get.sFS.writeFileNoCompression(scanAggsPerPartitionFile) { os =>
  scanAggsPerPartition.foreach { x =>
@tpoterba (Collaborator) commented Jun 18, 2019

use zipWithIndex instead of the i += 1? seems a little clearer.

@tpoterba (Collaborator) commented Jun 18, 2019

You also don't need numPartitions + 1 scan intermediates, right? (you don't have to do the last one)

@danking (Collaborator, Author) commented Jun 18, 2019

Correct. It adds about three lines of code to save not many bytes, so I removed it. I can use zipWithIndex.
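
For reference, a sketch of the loop with zipWithIndex; the writeObject/flush body is assumed from the surrounding diff context rather than copied from it.

scanAggsPerPartition.zipWithIndex.foreach { case (x, i) =>
  partitionIndices(i) = os.getPos
  val oos = new ObjectOutputStream(os)
  oos.writeObject(x)
  oos.flush()
}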

  if (codec != null)
    codec.createOutputStream(os)
  else
    os
}

private def createNoCompresion(filename: String): FSDataOutputStream = {
@tpoterba (Collaborator) commented Jun 18, 2019

typo

size += a.length
if (size > sizeLimit) {
  val file = hc.getTemporaryFile()
  fs.writeFileNoCompression(file) { os =>
@tpoterba (Collaborator) commented Jun 18, 2019

Same comment about the ObjectOutputStream: lift it if possible, or add a comment.

}
}

class SpillingCollectIterator[T: ClassTag] private (rdd: RDD[T], sizeLimit: Int) extends Iterator[T] {
@tpoterba (Collaborator) commented Jun 18, 2019

This needs unit tests. You can use property-based testing: generate random arrays, partitionings, and size limits, then compare array.iterator with SpillingCollectIterator(sc.parallelize(array, nPartitions), sizeLimit).
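
A rough sketch of the suggested property test, assuming a ScalaCheck-style harness, an in-scope SparkContext named sc, and a companion apply with the signature used in this comment:

import org.scalacheck.{Gen, Prop}

val collectMatchesLocalIterator = Prop.forAll(
  Gen.listOf(Gen.choose(0, 1000)),  // random array contents
  Gen.choose(1, 20),                // random partitioning
  Gen.choose(1, 100)                // random size limit
) { (xs: List[Int], nPartitions: Int, sizeLimit: Int) =>
  val array = xs.toArray
  val spilled = SpillingCollectIterator(sc.parallelize(array, nPartitions), sizeLimit)
  spilled.toArray.sameElements(array)  // same elements in the same order as array.iterator
}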

@danking (Collaborator, Author) commented Jun 19, 2019

For posterity, this is what goes wrong if I don't create a fresh OOS for each object:


Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2158)
	at org.apache.spark.rdd.RDD$$anonfun$fold$1.apply(RDD.scala:1098)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
	at org.apache.spark.rdd.RDD.fold(RDD.scala:1092)
	at is.hail.rvd.RVD.count(RVD.scala:660)
	at is.hail.methods.ForceCountTable.execute(ForceCount.scala:11)
	at is.hail.expr.ir.Interpret$.apply(Interpret.scala:771)
	at is.hail.expr.ir.Interpret$.apply(Interpret.scala:88)
	at is.hail.expr.ir.Interpret$.apply(Interpret.scala:59)
	at is.hail.expr.ir.InterpretNonCompilable$$anonfun$7.apply(InterpretNonCompilable.scala:19)
	at is.hail.expr.ir.InterpretNonCompilable$$anonfun$7.apply(InterpretNonCompilable.scala:19)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
	at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
	at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
	at is.hail.expr.ir.InterpretNonCompilable$.apply(InterpretNonCompilable.scala:19)
	at is.hail.expr.ir.CompileAndEvaluate$$anonfun$2.apply(CompileAndEvaluate.scala:36)
	at is.hail.expr.ir.CompileAndEvaluate$$anonfun$2.apply(CompileAndEvaluate.scala:36)
	at is.hail.utils.ExecutionTimer.time(ExecutionTimer.scala:20)
	at is.hail.expr.ir.CompileAndEvaluate$.apply(CompileAndEvaluate.scala:36)
	at is.hail.backend.Backend.execute(Backend.scala:61)
	at is.hail.backend.Backend.executeJSON(Backend.scala:67)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)

java.io.StreamCorruptedException: invalid stream header: 7571007E
	at java.io.ObjectInputStream.readStreamHeader(ObjectInputStream.java:866)
	at java.io.ObjectInputStream.<init>(ObjectInputStream.java:358)
	at is.hail.expr.ir.TableMapRows$$anonfun$41$$anonfun$42.apply(TableIR.scala:892)
	at is.hail.expr.ir.TableMapRows$$anonfun$41$$anonfun$42.apply(TableIR.scala:890)
	at is.hail.utils.package$.using(package.scala:597)
	at is.hail.io.fs.HadoopFS.readFileNoCompression(HadoopFS.scala:407)
	at is.hail.expr.ir.TableMapRows$$anonfun$41.apply(TableIR.scala:890)
	at is.hail.expr.ir.TableMapRows$$anonfun$41.apply(TableIR.scala:889)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitionsWithIndexAndValue$1$$anonfun$apply$34.apply(ContextRDD.scala:434)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitionsWithIndexAndValue$1$$anonfun$apply$34.apply(ContextRDD.scala:434)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitions$1$$anonfun$apply$28$$anonfun$apply$29.apply(ContextRDD.scala:405)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitions$1$$anonfun$apply$28$$anonfun$apply$29.apply(ContextRDD.scala:405)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:435)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:441)
	at scala.collection.Iterator$class.foreach(Iterator.scala:891)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
	at is.hail.rvd.RVD$$anonfun$count$2.apply(RVD.scala:655)
	at is.hail.rvd.RVD$$anonfun$count$2.apply(RVD.scala:653)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitions$1$$anonfun$apply$28.apply(ContextRDD.scala:405)
	at is.hail.sparkextras.ContextRDD$$anonfun$cmapPartitions$1$$anonfun$apply$28.apply(ContextRDD.scala:405)
	at is.hail.sparkextras.ContextRDD$$anonfun$run$1$$anonfun$apply$8.apply(ContextRDD.scala:192)
	at is.hail.sparkextras.ContextRDD$$anonfun$run$1$$anonfun$apply$8.apply(ContextRDD.scala:192)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:435)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:441)
	at scala.collection.Iterator$class.foreach(Iterator.scala:891)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
	at scala.collection.TraversableOnce$class.foldLeft(TraversableOnce.scala:157)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1334)
	at scala.collection.TraversableOnce$class.fold(TraversableOnce.scala:212)
	at scala.collection.AbstractIterator.fold(Iterator.scala:1334)
	at org.apache.spark.rdd.RDD$$anonfun$fold$1$$anonfun$20.apply(RDD.scala:1096)
	at org.apache.spark.rdd.RDD$$anonfun$fold$1$$anonfun$20.apply(RDD.scala:1096)
	at org.apache.spark.SparkContext$$anonfun$36.apply(SparkContext.scala:2157)
	at org.apache.spark.SparkContext$$anonfun$36.apply(SparkContext.scala:2157)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:403)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:409)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
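
To illustrate the failure mode outside of Hail (a sketch with illustrative names): java.io.ObjectOutputStream writes its stream header (0xACED0005) only once, when it is constructed, so an ObjectInputStream opened at an offset recorded later in the file finds serialized object data rather than a header and throws exactly this "invalid stream header" StreamCorruptedException.

import java.io._

val f = File.createTempFile("oos-repro", ".bin")
val fos = new FileOutputStream(f)
val oos = new ObjectOutputStream(fos)        // stream header written here, once
oos.writeObject(Array.fill(4)("x")); oos.flush()
val secondOffset = fos.getChannel.position() // where the second record begins
oos.writeObject(Array.fill(4)("x")); oos.flush()
oos.close()

val fis = new FileInputStream(f)
fis.getChannel.position(secondOffset)        // "seek" to the second record
new ObjectInputStream(fis).readObject()      // throws StreamCorruptedException: invalid stream header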

@danking danking force-pushed the memmory-efficient-scan2 branch from e5eeb9d to ee9fa94 Jun 19, 2019
@danking (Collaborator, Author) commented Jun 19, 2019

@tpoterba all comments addressed

@danking danking merged commit 2907766 into hail-is:master Jun 27, 2019
1 check passed
@danking danking deleted the memmory-efficient-scan2 branch Dec 18, 2019