Skip to content

Commit

Permalink
[compiler] WriteValue Stage Locally (#12798)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ehigham committed Mar 20, 2023
1 parent 8e76d42 commit c1d1492
Show file tree
Hide file tree
Showing 11 changed files with 33 additions and 26 deletions.
2 changes: 0 additions & 2 deletions hail/python/test/hail/linalg/test_linalg.py
Expand Up @@ -1156,8 +1156,6 @@ def test_write_overwrite(self):
bm2.write(path, overwrite=True)
self._assert_eq(BlockMatrix.read(path), bm2)

@fails_service_backend()
@fails_local_backend()
def test_stage_locally(self):
nd = np.arange(0, 80, dtype=float).reshape(8, 10)
with hl.TemporaryDirectory(ensure_exists=False) as bm_uri:
Expand Down
11 changes: 6 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/BlockMatrixWriter.scala
Expand Up @@ -48,14 +48,15 @@ case class BlockMatrixNativeWriter(
def loweredTyp: Type = TVoid

override def lower(ctx: ExecuteContext, s: BlockMatrixStage2, evalCtx: IRBuilder, eltR: TypeWithRequiredness): IR = {
if (stageLocally)
throw new LowererUnsupportedOperation(s"stageLocally not supported in BlockMatrixWrite lowering")
val etype = EBlockMatrixNDArray(EType.fromTypeAndAnalysis(s.typ.elementType, eltR), encodeRowMajor = forceRowMajor, required = true)
val spec = TypedCodecSpec(etype, TNDArray(s.typ.elementType, Nat(2)), BlockMatrix.bufferSpec)

val paths = s.collectBlocks(evalCtx, "block_matrix_native_writer") { (ctx, idx, block) =>
val filepath = strConcat(s"$path/parts/part-", idx, UUID4())
WriteValue(block, filepath, spec)
val paths = s.collectBlocks(evalCtx, "block_matrix_native_writer") { (_, idx, block) =>
val suffix = strConcat("parts/part-", idx, UUID4())
val filepath = strConcat(s"$path/", suffix)
WriteValue(block, filepath, spec,
if (stageLocally) Some(strConcat(s"${ctx.localTmpdir}/", suffix)) else None
)
}
RelationalWriter.scoped(path, overwrite, None)(WriteMetadata(paths, BlockMatrixNativeMetadataWriter(path, stageLocally, s.typ)))
}
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Children.scala
Expand Up @@ -247,7 +247,7 @@ object Children {
case WritePartition(stream, ctx, _) => Array(stream, ctx)
case WriteMetadata(writeAnnotations, _) => Array(writeAnnotations)
case ReadValue(path, _, _) => Array(path)
case WriteValue(value, path, spec) => Array(value, path)
case WriteValue(value, path, _, staged) => Array(value, path) ++ staged.toArray[IR]
case LiftMeOut(child) => Array(child)
}
}
9 changes: 6 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/Copy.scala
Expand Up @@ -400,9 +400,12 @@ object Copy {
case ReadValue(path, spec, requestedType) =>
assert(newChildren.length == 1)
ReadValue(newChildren(0).asInstanceOf[IR], spec, requestedType)
case WriteValue(value, path, spec) =>
assert(newChildren.length == 2)
WriteValue(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], spec)
case WriteValue(_, _, spec, _) =>
assert(newChildren.length == 2 || newChildren.length == 3)
val value = newChildren(0).asInstanceOf[IR]
val path = newChildren(1).asInstanceOf[IR]
val stage = if (newChildren.length == 3) Some(newChildren(2).asInstanceOf[IR]) else None
WriteValue(value, path, spec, stage)
case LiftMeOut(_) =>
LiftMeOut(newChildren(0).asInstanceOf[IR])
}
Expand Down
15 changes: 10 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
@@ -1,6 +1,5 @@
package is.hail.expr.ir

import is.hail.HailContext
import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend.{BackendContext, ExecuteContext, HailTaskContext}
Expand All @@ -9,16 +8,16 @@ import is.hail.expr.ir.analyses.{ComputeMethodSplits, ControlFlowPreventsSplit,
import is.hail.expr.ir.lowering.TableStageDependency
import is.hail.expr.ir.ndarrays.EmitNDArray
import is.hail.expr.ir.streams.{EmitStream, StreamProducer, StreamUtils}
import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer, TypedCodecSpec}
import is.hail.io.fs.FS
import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer, TypedCodecSpec}
import is.hail.linalg.{BLAS, LAPACK, LinalgCodeUtils}
import is.hail.types.physical._
import is.hail.types.physical.stypes._
import is.hail.types.physical.stypes.concrete._
import is.hail.types.physical.stypes.interfaces._
import is.hail.types.physical.stypes.primitives._
import is.hail.types.virtual._
import is.hail.types.{RIterable, TypeWithRequiredness, VirtualTypeWithReq, tcoerce}
import is.hail.types.{TypeWithRequiredness, VirtualTypeWithReq, tcoerce}
import is.hail.utils._
import is.hail.variant.ReferenceGenome

Expand Down Expand Up @@ -2274,13 +2273,19 @@ class Emit[C](
decoded
}

case WriteValue(value, path, spec) =>
case WriteValue(value, path, spec, stagingFile) =>
emitI(path).flatMap(cb) { case pv: SStringValue =>
emitI(value).map(cb) { v =>
val ob = cb.memoize[OutputBuffer](spec.buildCodeOutputBuffer(mb.createUnbuffered(pv.asString.loadString(cb))))
val s = stagingFile.map(emitI(_).get(cb).asString)
val ob = cb.memoize[OutputBuffer](spec.buildCodeOutputBuffer(mb.createUnbuffered(
s.getOrElse(pv).loadString(cb))
))
spec.encodedType.buildEncoder(v.st, cb.emb.ecb)
.apply(cb, v, ob)
cb += ob.invoke[Unit]("close")
s.foreach { stage =>
cb += mb.getFS.invoke[String, String, Boolean, Unit]("copy", stage.loadString(cb), pv.loadString(cb), const(true))
}
pv
}
}
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/IR.scala
Expand Up @@ -911,7 +911,7 @@ final case class WritePartition(value: IR, writeCtx: IR, writer: PartitionWriter
final case class WriteMetadata(writeAnnotations: IR, writer: MetadataWriter) extends IR

final case class ReadValue(path: IR, spec: AbstractTypedCodecSpec, requestedType: Type) extends IR
final case class WriteValue(value: IR, path: IR, spec: AbstractTypedCodecSpec) extends IR
final case class WriteValue(value: IR, path: IR, spec: AbstractTypedCodecSpec, stagingFile: Option[IR] = None) extends IR

class PrimitiveIR(val self: IR) extends AnyVal {
def +(other: IR): IR = {
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/InferType.scala
Expand Up @@ -284,7 +284,7 @@ object InferType {
case WritePartition(value, writeCtx, writer) => writer.returnType
case _: WriteMetadata => TVoid
case ReadValue(_, _, typ) => typ
case WriteValue(value, path, spec) => TString
case _: WriteValue => TString
case LiftMeOut(child) => child.typ
}
}
Expand Down
8 changes: 4 additions & 4 deletions hail/src/main/scala/is/hail/expr/ir/Parser.scala
Expand Up @@ -1537,10 +1537,10 @@ object IRParser {
case "WriteValue" =>
import AbstractRVDSpec.formats
val spec = JsonMethods.parse(string_literal(it)).extract[AbstractTypedCodecSpec]
for {
value <- ir_value_expr(env)(it)
path <- ir_value_expr(env)(it)
} yield WriteValue(value, path, spec)
ir_value_children(env)(it).map {
case Array(value, path) => WriteValue(value, path, spec)
case Array(value, path, stagingFile) => WriteValue(value, path, spec, Some(stagingFile))
}
case "LiftMeOut" => ir_value_expr(env)(it).map(LiftMeOut)
case "ReadPartition" =>
val rowType = tcoerce[TStruct](type_expr(it))
Expand Down
4 changes: 1 addition & 3 deletions hail/src/main/scala/is/hail/expr/ir/Pretty.scala
@@ -1,6 +1,5 @@
package is.hail.expr.ir

import is.hail.HailContext
import is.hail.backend.ExecuteContext
import is.hail.expr.JSONAnnotationImpex
import is.hail.expr.ir.Pretty.prettyBooleanLiteral
Expand All @@ -12,7 +11,6 @@ import is.hail.utils.prettyPrint._
import is.hail.utils.richUtils.RichIterable
import is.hail.utils.{space => _, _}
import org.json4s.DefaultFormats
import org.json4s.JsonAST.JString
import org.json4s.jackson.{JsonMethods, Serialization}

import scala.collection.mutable
Expand Down Expand Up @@ -432,7 +430,7 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int,
single(prettyStringLiteral(JsonMethods.compact(writer.toJValue), elide = elideLiterals))
case ReadValue(_, spec, reqType) =>
FastSeq(prettyStringLiteral(spec.toString), reqType.parsableString())
case WriteValue(_, _, spec) => single(prettyStringLiteral(spec.toString))
case WriteValue(_, _, spec, _) => single(prettyStringLiteral(spec.toString))
case MakeNDArray(_, _, _, errorId) => FastSeq(errorId.toString)

case _ => Iterable.empty
Expand Down
3 changes: 2 additions & 1 deletion hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala
Expand Up @@ -539,8 +539,9 @@ object TypeCheck {
case x@ReadValue(path, spec, requestedType) =>
assert(path.typ == TString)
assert(spec.encodedType.decodedPType(requestedType).virtualType == requestedType)
case x@WriteValue(value, path, spec) =>
case WriteValue(_, path, _, stagingFile) =>
assert(path.typ == TString)
assert(stagingFile.forall(_.typ == TString))
case LiftMeOut(_) =>
case Consume(_) =>
case TableMapRows(child, newRow) =>
Expand Down
1 change: 1 addition & 0 deletions hail/src/test/scala/is/hail/expr/ir/IRSuite.scala
Expand Up @@ -2828,6 +2828,7 @@ class IRSuite extends HailSuite {
RelationalWriter("path", overwrite = false, None)),
ReadValue(Str("foo"), TypedCodecSpec(PCanonicalStruct("foo" -> PInt32(), "bar" -> PCanonicalString()), BufferSpec.default), TStruct("foo" -> TInt32)),
WriteValue(I32(1), Str("foo"), TypedCodecSpec(PInt32(), BufferSpec.default)),
WriteValue(I32(1), Str("foo"), TypedCodecSpec(PInt32(), BufferSpec.default), Some(Str("/tmp/uid/part"))),
LiftMeOut(I32(1)),
RelationalLet("x", I32(0), I32(0)),
TailLoop("y", IndexedSeq("x" -> I32(0)), Recur("y", FastSeq(I32(4)), TInt32))
Expand Down

0 comments on commit c1d1492

Please sign in to comment.