Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] Support zstd compression in BGEN files #12576

Merged
merged 2 commits into from Jan 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions hail/build.gradle
Expand Up @@ -233,6 +233,8 @@ dependencies {
bundled group: 'org.freemarker', name: 'freemarker', version: '2.3.31'

bundled 'com.kohlschutter.junixsocket:junixsocket-core:2.6.1'

bundled 'com.github.luben:zstd-jni:1.4.8-1'
}

task(checkSettings) doLast {
Expand Down
11 changes: 7 additions & 4 deletions hail/python/hail/ir/matrix_writer.py
Expand Up @@ -107,21 +107,24 @@ def __eq__(self, other):


class MatrixBGENWriter(MatrixWriter):
    """Writer configuration for exporting a matrix table as a BGEN file.

    Holds the output path, the export layout (``export_type``), and the
    compression codec, and renders them as an escaped JSON payload for
    the backend.
    """

    @typecheck_method(path=str, export_type=ExportType.checker, compression_codec=str)
    def __init__(self, path, export_type, compression_codec):
        self.path = path
        self.export_type = export_type
        self.compression_codec = compression_codec

    def render(self):
        # Serialize the writer configuration; key names must match what
        # the backend expects.
        config = {
            'name': 'MatrixBGENWriter',
            'path': self.path,
            'exportType': self.export_type,
            'compressionCodec': self.compression_codec,
        }
        return escape_str(json.dumps(config))

    def __eq__(self, other):
        if not isinstance(other, MatrixBGENWriter):
            return False
        return (self.path, self.export_type, self.compression_codec) == \
            (other.path, other.export_type, other.compression_codec)


class MatrixPLINKWriter(MatrixWriter):
Expand Down
10 changes: 7 additions & 3 deletions hail/python/hail/methods/impex.py
Expand Up @@ -176,8 +176,9 @@ def export_gen(dataset, output, precision=4, gp=None, id1=None, id2=None,
gp=nullable(expr_array(expr_float64)),
varid=nullable(expr_str),
rsid=nullable(expr_str),
parallel=nullable(ir.ExportType.checker))
def export_bgen(mt, output, gp=None, varid=None, rsid=None, parallel=None):
parallel=nullable(ir.ExportType.checker),
compression_codec=enumeration('zlib', 'zstd'))
def export_bgen(mt, output, gp=None, varid=None, rsid=None, parallel=None, compression_codec='zlib'):
"""Export a :class:`.MatrixTable` as a BGEN 1.2 file with 8
bits per probability. Also writes a SAMPLE file.

Expand Down Expand Up @@ -206,6 +207,8 @@ def export_bgen(mt, output, gp=None, varid=None, rsid=None, parallel=None):
per partition), each with its own header. If
``'separate_header'``, write a file for each partition,
without header, and a header file for the combined dataset.
compression_codec : str, optional
Compression codec. One of 'zlib', 'zstd'.
"""
require_row_key_variant(mt, 'export_bgen')
require_col_key_str(mt, 'export_bgen')
Expand Down Expand Up @@ -244,7 +247,8 @@ def export_bgen(mt, output, gp=None, varid=None, rsid=None, parallel=None):

Env.backend().execute(ir.MatrixWrite(mt._mir, ir.MatrixBGENWriter(
output,
parallel)))
parallel,
compression_codec)))


@typecheck(dataset=MatrixTable,
Expand Down
14 changes: 14 additions & 0 deletions hail/python/test/hail/methods/test_impex.py
Expand Up @@ -1504,6 +1504,20 @@ def test_export_bgen(self):
sample_file=tmp + '.sample')
assert bgen._same(bgen2)

@fails_service_backend()
@fails_local_backend()
def test_export_bgen_zstd(self):
    # Round-trip check: import the example BGEN, re-export it with the
    # zstd codec, index and re-import it, and verify the data matches.
    original = hl.import_bgen(resource('example.8bits.bgen'),
                              entry_fields=['GP'],
                              sample_file=resource('example.sample'))
    out = new_temp_file("zstd")
    hl.export_bgen(original, out, compression_codec='zstd')
    hl.index_bgen(out + '.bgen')
    reimported = hl.import_bgen(out + '.bgen',
                                entry_fields=['GP'],
                                sample_file=out + '.sample')
    assert original._same(reimported)

@fails_service_backend()
@fails_local_backend()
def test_export_bgen_parallel(self):
Expand Down
36 changes: 24 additions & 12 deletions hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala
Expand Up @@ -9,10 +9,11 @@ import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage}
import is.hail.expr.ir.streams.StreamProducer
import is.hail.expr.{JSONAnnotationImpex, Nat}
import is.hail.io._
import is.hail.io.bgen.BgenSettings
import is.hail.io.fs.FS
import is.hail.io.gen.{BgenWriter, ExportGen}
import is.hail.io.index.StagedIndexWriter
import is.hail.io.plink.{ExportPlink, BitPacker}
import is.hail.io.plink.{BitPacker, ExportPlink}
import is.hail.io.vcf.{ExportVCF, TabixVCF}
import is.hail.linalg.BlockMatrix
import is.hail.rvd.{IndexSpec, RVDPartitioner, RVDSpecMaker}
Expand All @@ -23,8 +24,8 @@ import is.hail.types.physical.stypes.primitives._
import is.hail.types.physical.{PBooleanRequired, PCanonicalBaseStruct, PCanonicalString, PCanonicalStruct, PInt64, PStruct, PType}
import is.hail.types.virtual._
import is.hail.types._
import is.hail.types.physical.stypes.concrete.{SJavaString, SJavaArrayString, SJavaArrayStringValue, SStackStruct}
import is.hail.types.physical.stypes.interfaces.{SIndexableValue, SBaseStructValue, SStringValue}
import is.hail.types.physical.stypes.concrete.{SJavaArrayString, SJavaArrayStringValue, SJavaString, SStackStruct}
import is.hail.types.physical.stypes.interfaces.{SBaseStructValue, SIndexableValue, SStringValue}
import is.hail.types.physical.stypes.primitives.{SBooleanValue, SInt64Value}
import is.hail.utils._
import is.hail.utils.richUtils.ByteTrackingOutputStream
Expand Down Expand Up @@ -951,7 +952,8 @@ final class GenSampleWriter extends SimplePartitionWriter {

case class MatrixBGENWriter(
path: String,
exportType: String
exportType: String,
compressionCodec: String
) extends MatrixWriter {
def apply(ctx: ExecuteContext, mv: MatrixValue): Unit = {
val tv = mv.toTableValue
Expand All @@ -972,8 +974,13 @@ case class MatrixBGENWriter(
else
path + ".bgen"

assert(compressionCodec == "zlib" || compressionCodec == "zstd")
val writeHeader = exportType == ExportType.PARALLEL_HEADER_IN_SHARD
val partWriter = BGENPartitionWriter(tm, entriesFieldName, writeHeader)
val compressionInt = compressionCodec match {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any way to make this some kind of enum, or if that's too much, named constants?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

case "zlib" => BgenSettings.ZLIB_COMPRESSION
case "zstd" => BgenSettings.ZSTD_COMPRESSION
}
val partWriter = BGENPartitionWriter(tm, entriesFieldName, writeHeader, compressionInt)

ts.mapContexts { oldCtx =>
val d = digitsNeeded(ts.numPartitions)
Expand All @@ -996,13 +1003,13 @@ case class MatrixBGENWriter(
WritePartition(rows, ctx, partWriter)
}{ (results, globals) =>
val ctx = MakeStruct(FastSeq("cols" -> GetField(globals, colsFieldName), "results" -> results))
val commit = BGENExportFinalizer(tm, path, exportType)
val commit = BGENExportFinalizer(tm, path, exportType, compressionInt)
Begin(FastIndexedSeq(WriteMetadata(ctx, commit)))
}
}
}

case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeHeader: Boolean) extends PartitionWriter {
case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeHeader: Boolean, compression: Int) extends PartitionWriter {
val ctxType: Type = TStruct("cols" -> TArray(typ.colType), "numVariants" -> TInt64, "partFile" -> TString)
override def returnType: TStruct = TStruct("partFile" -> TString, "numVariants" -> TInt64, "dropped" -> TInt64)
def unionTypeRequiredness(r: TypeWithRequiredness, ctxType: TypeWithRequiredness, streamType: RIterable): Unit = {
Expand All @@ -1027,7 +1034,7 @@ case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeH
cb += (sampleIds(i) = s.loadString(cb))
}
val numVariants = ctx.loadField(cb, "numVariants").get(cb).asInt64.value
val header = Code.invokeScalaObject2[Array[String], Long, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants)
val header = Code.invokeScalaObject3[Array[String], Long, Int, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants, compression)
cb += os.invoke[Array[Byte], Unit]("write", header)
}

Expand Down Expand Up @@ -1145,7 +1152,12 @@ case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeH
// end emitGPData

val uncompLen = cb.memoize(uncompBuf.invoke[Int]("size"))
val compLen = cb.memoize(Code.invokeScalaObject2[ByteArrayBuilder, Array[Byte], Int](CompressionUtils.getClass, "compress", buf, uncompBuf.invoke[Array[Byte]]("result")))

val compMethod = compression match {
case 1 => "compressZlib"
case 2 => "compressZstd"
}
val compLen = cb.memoize(Code.invokeScalaObject2[ByteArrayBuilder, Array[Byte], Int](CompressionUtils.getClass, compMethod, buf, uncompBuf.invoke[Array[Byte]]("result")))

updateIntToBytesLE(cb, buf, cb.memoize(compLen + 4), gtDataBlockStart)
updateIntToBytesLE(cb, buf, uncompLen, cb.memoize(gtDataBlockStart + 4))
Expand All @@ -1154,7 +1166,7 @@ case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeH
}
}

case class BGENExportFinalizer(typ: MatrixType, path: String, exportType: String) extends MetadataWriter {
case class BGENExportFinalizer(typ: MatrixType, path: String, exportType: String, compression: Int) extends MetadataWriter {
def annotationType: Type = TStruct("cols" -> TArray(typ.colType), "results" -> TArray(TStruct("partFile" -> TString, "numVariants" -> TInt64, "dropped" -> TInt64)))

def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region]): Unit = {
Expand Down Expand Up @@ -1195,14 +1207,14 @@ case class BGENExportFinalizer(typ: MatrixType, path: String, exportType: String

if (exportType == ExportType.PARALLEL_SEPARATE_HEADER) {
val os = cb.memoize(cb.emb.create(const(path + ".bgen").concat("/header")))
val header = Code.invokeScalaObject2[Array[String], Long, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants)
val header = Code.invokeScalaObject3[Array[String], Long, Int, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants, compression)
cb += os.invoke[Array[Byte], Unit]("write", header)
cb += os.invoke[Unit]("close")
}

if (exportType == ExportType.CONCATENATED) {
val os = cb.memoize(cb.emb.create(const(path + ".bgen")))
val header = Code.invokeScalaObject2[Array[String], Long, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants)
val header = Code.invokeScalaObject3[Array[String], Long, Int, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants, compression)
cb += os.invoke[Array[Byte], Unit]("write", header)

annotations.loadField(cb, "results").get(cb).asIndexable.forEachDefined(cb) { (cb, i, res) =>
Expand Down
Expand Up @@ -684,6 +684,10 @@ final class ByteArrayBuilder(initialCapacity: Int = 16) {
size_ = n
}

/** Set the logical size directly, without the validation performed by
  * `setSize`.
  *
  * NOTE(review): the caller is responsible for keeping `n` within the
  * backing array's current capacity — presumably 0 <= n <= b.length;
  * no bounds check is performed here. Confirm all call sites uphold this.
  */
def setSizeUnchecked(n: Int) {
size_ = n
}

def apply(i: Int): Byte = {
require(i >= 0 && i < size)
b(i)
Expand Down
9 changes: 4 additions & 5 deletions hail/src/main/scala/is/hail/io/bgen/BgenRDD.scala
Expand Up @@ -23,6 +23,10 @@ import org.apache.spark.{OneToOneDependency, Partition, TaskContext}
import scala.language.reflectiveCalls

object BgenSettings {
// Compression codec identifiers for the BGEN genotype data blocks.
// These values occupy bits 0-1 of the BGEN header flags field: they are
// read via `flags & 3` when loading, and OR'd into the flags word when
// writing the header block.
val UNCOMPRESSED = 0x0
val ZLIB_COMPRESSION = 0x1
val ZSTD_COMPRESSION = 0x2

def indexKeyType(rg: Option[ReferenceGenome]): TStruct = TStruct(
"locus" -> rg.map(TLocus(_)).getOrElse(TLocus.representation),
"alleles" -> TArray(TString))
Expand Down Expand Up @@ -140,11 +144,6 @@ object BgenRDD {

ContextRDD(new BgenRDD(ctx.fsBc, f, indexBuilder, partitions, settings, keys))
}

private[bgen] def decompress(
input: Array[Byte],
uncompressedSize: Int
): Array[Byte] = is.hail.utils.decompress(input, uncompressedSize)
}

private class BgenRDD(
Expand Down
25 changes: 17 additions & 8 deletions hail/src/main/scala/is/hail/io/bgen/BgenRDDPartitions.scala
Expand Up @@ -19,7 +19,7 @@ import org.apache.spark.Partition
trait BgenPartition extends Partition {
def path: String

def compressed: Boolean
def compression: Int // 0 uncompressed, 1 zlib, 2 zstd

def skipInvalidLoci: Boolean

Expand All @@ -38,7 +38,7 @@ private case class LoadBgenPartition(
path: String,
indexPath: String,
filterPartition: Partition,
compressed: Boolean,
compression: Int,
skipInvalidLoci: Boolean,
contigRecoding: Map[String, String],
partitionIndex: Int,
Expand Down Expand Up @@ -149,7 +149,7 @@ object BgenRDDPartitions extends Logging {
file.path,
file.indexPath,
filterPartition = null,
file.header.compressed,
file.header.compression,
file.skipInvalidLoci,
file.contigRecoding,
partitionIndex,
Expand Down Expand Up @@ -425,15 +425,24 @@ object CompileDecoder {
cb.define(LnoOp)
}

cb.ifx(cp.invoke[Boolean]("compressed"), {
val compression = cb.memoize(cp.invoke[Int]("compression"))
cb.ifx(compression ceq BgenSettings.UNCOMPRESSED, {
cb.assign(data, cbfis.invoke[Int, Array[Byte]]("readBytes", dataSize))
}, {
cb.assign(uncompressedSize, cbfis.invoke[Int]("readInt"))
cb.assign(input, cbfis.invoke[Int, Array[Byte]]("readBytes", dataSize - 4))
cb.assign(data, Code.invokeScalaObject2[Array[Byte], Int, Array[Byte]](
BgenRDD.getClass, "decompress", input, uncompressedSize))
}, {
cb.assign(data, cbfis.invoke[Int, Array[Byte]]("readBytes", dataSize))
cb.ifx(compression ceq BgenSettings.ZLIB_COMPRESSION, {
cb.assign(data,
Code.invokeScalaObject2[Array[Byte], Int, Array[Byte]](
CompressionUtils.getClass, "decompressZlib", input, uncompressedSize))
}, {
// zstd
cb.assign(data,Code.invokeScalaObject2[Array[Byte], Int, Array[Byte]](
CompressionUtils.getClass, "decompressZstd", input, uncompressedSize))
})
})


cb.assign(reader, Code.newInstance[ByteArrayReader, Array[Byte]](data))
cb.assign(nRow, reader.invoke[Int]("readInt"))
cb.ifx(nRow.cne(settings.nSamples), cb._fatal(
Expand Down
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/io/bgen/IndexBgen.scala
Expand Up @@ -16,7 +16,7 @@ import org.apache.spark.{Partition, TaskContext}

private case class IndexBgenPartition(
path: String,
compressed: Boolean,
compression: Int,
skipInvalidLoci: Boolean,
contigRecoding: Map[String, String],
startByteOffset: Long,
Expand Down Expand Up @@ -83,7 +83,7 @@ object IndexBgen {
val partitions: Array[Partition] = headers.zipWithIndex.map { case (f, i) =>
IndexBgenPartition(
f.path,
f.compressed,
f.compression,
skipInvalidLoci,
recoding,
f.dataStart,
Expand Down
10 changes: 4 additions & 6 deletions hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala
Expand Up @@ -23,7 +23,7 @@ import org.json4s.{DefaultFormats, Formats, JObject, JValue}
import scala.io.Source

case class BgenHeader(
compressed: Boolean,
compression: Int, // 0 uncompressed, 1 zlib, 2 zstd
nSamples: Int,
nVariants: Int,
headerLength: Int,
Expand Down Expand Up @@ -113,18 +113,16 @@ object LoadBgen {
val flags = is.readInt()
val compressType = flags & 3

if (compressType != 0 && compressType != 1)
fatal(s"Hail only supports zlib compression.")

val isCompressed = compressType != 0
if (compressType != 0 && compressType != 1 && compressType != 2)
fatal(s"Hail only supports zlib or zstd compression.")

val version = (flags >>> 2) & 0xf
if (version != 2)
fatal(s"Hail supports BGEN version 1.2, got version 1.$version")

val hasIds = (flags >> 31 & 1) != 0
BgenHeader(
isCompressed,
compressType,
nSamples,
nVariants,
headerLength,
Expand Down
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/io/gen/ExportBGEN.scala
Expand Up @@ -53,13 +53,13 @@ object BgenWriter {
bb(pos + 3) = ((i >>> 24) & 0xff).toByte
}

def headerBlock(sampleIds: Array[String], nVariants: Long): Array[Byte] = {
def headerBlock(sampleIds: Array[String], nVariants: Long, compression: Int): Array[Byte] = {
val bb = new ByteArrayBuilder()
val nSamples = sampleIds.length
assert(nVariants < (1L << 32))

val magicNumbers = Array("b", "g", "e", "n").flatMap(_.getBytes)
val flags = 0x01 | (0x02 << 2) | (0x01 << 31)
val flags = compression | (0x02 << 2) | (0x01 << 31)
val headerLength = 20

intToBytesLE(bb, 0) // placeholder for offset
Expand Down