Skip to content

Commit

Permalink
[query] Support zstd compression in BGEN files (hail-is#12576)
Browse files Browse the repository at this point in the history
* [query] Support zstd compression in BGEN files

CHANGELOG: `hl.import_bgen` and `hl.export_bgen` now support compression with Zstd.

* address comments
  • Loading branch information
tpoterba authored and danking committed Jan 30, 2023
1 parent cf1af97 commit 2d428d6
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 46 deletions.
2 changes: 2 additions & 0 deletions hail/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ dependencies {
bundled group: 'org.freemarker', name: 'freemarker', version: '2.3.31'

bundled 'com.kohlschutter.junixsocket:junixsocket-core:2.6.1'

bundled 'com.github.luben:zstd-jni:1.4.8-1'
}

task(checkSettings) doLast {
Expand Down
11 changes: 7 additions & 4 deletions hail/python/hail/ir/matrix_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,21 +107,24 @@ def __eq__(self, other):


class MatrixBGENWriter(MatrixWriter):
    """Writer configuration for exporting a matrix table as a BGEN 1.2 file.

    Parameters
    ----------
    path : str
        Output path prefix for the .bgen (and .sample) file(s).
    export_type : str
        One of the ``ExportType`` values controlling whether output is
        concatenated or written as parallel shards.
    compression_codec : str
        Genotype-block compression codec; expected to be ``'zlib'`` or
        ``'zstd'`` (validated by the caller / backend).
    """

    @typecheck_method(path=str, export_type=ExportType.checker, compression_codec=str)
    def __init__(self, path, export_type, compression_codec):
        self.path = path
        self.export_type = export_type
        self.compression_codec = compression_codec

    def render(self):
        # Serialized JSON payload consumed by the JVM backend's
        # MatrixBGENWriter; field names must match the Scala case class.
        writer = {'name': 'MatrixBGENWriter',
                  'path': self.path,
                  'exportType': self.export_type,
                  'compressionCodec': self.compression_codec}
        return escape_str(json.dumps(writer))

    def __eq__(self, other):
        return isinstance(other, MatrixBGENWriter) and \
            other.path == self.path and \
            other.export_type == self.export_type and \
            other.compression_codec == self.compression_codec


class MatrixPLINKWriter(MatrixWriter):
Expand Down
10 changes: 7 additions & 3 deletions hail/python/hail/methods/impex.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,9 @@ def export_gen(dataset, output, precision=4, gp=None, id1=None, id2=None,
gp=nullable(expr_array(expr_float64)),
varid=nullable(expr_str),
rsid=nullable(expr_str),
parallel=nullable(ir.ExportType.checker))
def export_bgen(mt, output, gp=None, varid=None, rsid=None, parallel=None):
parallel=nullable(ir.ExportType.checker),
compression_codec=enumeration('zlib', 'zstd'))
def export_bgen(mt, output, gp=None, varid=None, rsid=None, parallel=None, compression_codec='zlib'):
"""Export MatrixTable as :class:`.MatrixTable` as BGEN 1.2 file with 8
bits of per probability. Also writes SAMPLE file.
Expand Down Expand Up @@ -206,6 +207,8 @@ def export_bgen(mt, output, gp=None, varid=None, rsid=None, parallel=None):
per partition), each with its own header. If
``'separate_header'``, write a file for each partition,
without header, and a header file for the combined dataset.
compression_codec : str, optional
Compression codec. One of 'zlib', 'zstd'.
"""
require_row_key_variant(mt, 'export_bgen')
require_col_key_str(mt, 'export_bgen')
Expand Down Expand Up @@ -244,7 +247,8 @@ def export_bgen(mt, output, gp=None, varid=None, rsid=None, parallel=None):

Env.backend().execute(ir.MatrixWrite(mt._mir, ir.MatrixBGENWriter(
output,
parallel)))
parallel,
compression_codec)))


@typecheck(dataset=MatrixTable,
Expand Down
14 changes: 14 additions & 0 deletions hail/python/test/hail/methods/test_impex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,20 @@ def test_export_bgen(self):
sample_file=tmp + '.sample')
assert bgen._same(bgen2)

@fails_service_backend()
@fails_local_backend()
def test_export_bgen_zstd(self):
    # Round-trip check: import the example BGEN, re-export it with the
    # zstd codec, index and re-import the result, then compare datasets.
    original = hl.import_bgen(resource('example.8bits.bgen'),
                              entry_fields=['GP'],
                              sample_file=resource('example.sample'))
    out_prefix = new_temp_file("zstd")
    hl.export_bgen(original, out_prefix, compression_codec='zstd')
    hl.index_bgen(out_prefix + '.bgen')
    reimported = hl.import_bgen(out_prefix + '.bgen',
                                entry_fields=['GP'],
                                sample_file=out_prefix + '.sample')
    assert original._same(reimported)

@fails_service_backend()
@fails_local_backend()
def test_export_bgen_parallel(self):
Expand Down
36 changes: 24 additions & 12 deletions hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage}
import is.hail.expr.ir.streams.StreamProducer
import is.hail.expr.{JSONAnnotationImpex, Nat}
import is.hail.io._
import is.hail.io.bgen.BgenSettings
import is.hail.io.fs.FS
import is.hail.io.gen.{BgenWriter, ExportGen}
import is.hail.io.index.StagedIndexWriter
import is.hail.io.plink.{ExportPlink, BitPacker}
import is.hail.io.plink.{BitPacker, ExportPlink}
import is.hail.io.vcf.{ExportVCF, TabixVCF}
import is.hail.linalg.BlockMatrix
import is.hail.rvd.{IndexSpec, RVDPartitioner, RVDSpecMaker}
Expand All @@ -23,8 +24,8 @@ import is.hail.types.physical.stypes.primitives._
import is.hail.types.physical.{PBooleanRequired, PCanonicalBaseStruct, PCanonicalString, PCanonicalStruct, PInt64, PStruct, PType}
import is.hail.types.virtual._
import is.hail.types._
import is.hail.types.physical.stypes.concrete.{SJavaString, SJavaArrayString, SJavaArrayStringValue, SStackStruct}
import is.hail.types.physical.stypes.interfaces.{SIndexableValue, SBaseStructValue, SStringValue}
import is.hail.types.physical.stypes.concrete.{SJavaArrayString, SJavaArrayStringValue, SJavaString, SStackStruct}
import is.hail.types.physical.stypes.interfaces.{SBaseStructValue, SIndexableValue, SStringValue}
import is.hail.types.physical.stypes.primitives.{SBooleanValue, SInt64Value}
import is.hail.utils._
import is.hail.utils.richUtils.ByteTrackingOutputStream
Expand Down Expand Up @@ -951,7 +952,8 @@ final class GenSampleWriter extends SimplePartitionWriter {

case class MatrixBGENWriter(
path: String,
exportType: String
exportType: String,
compressionCodec: String
) extends MatrixWriter {
def apply(ctx: ExecuteContext, mv: MatrixValue): Unit = {
val tv = mv.toTableValue
Expand All @@ -972,8 +974,13 @@ case class MatrixBGENWriter(
else
path + ".bgen"

assert(compressionCodec == "zlib" || compressionCodec == "zstd")
val writeHeader = exportType == ExportType.PARALLEL_HEADER_IN_SHARD
val partWriter = BGENPartitionWriter(tm, entriesFieldName, writeHeader)
val compressionInt = compressionCodec match {
case "zlib" => BgenSettings.ZLIB_COMPRESSION
case "zstd" => BgenSettings.ZSTD_COMPRESSION
}
val partWriter = BGENPartitionWriter(tm, entriesFieldName, writeHeader, compressionInt)

ts.mapContexts { oldCtx =>
val d = digitsNeeded(ts.numPartitions)
Expand All @@ -996,13 +1003,13 @@ case class MatrixBGENWriter(
WritePartition(rows, ctx, partWriter)
}{ (results, globals) =>
val ctx = MakeStruct(FastSeq("cols" -> GetField(globals, colsFieldName), "results" -> results))
val commit = BGENExportFinalizer(tm, path, exportType)
val commit = BGENExportFinalizer(tm, path, exportType, compressionInt)
Begin(FastIndexedSeq(WriteMetadata(ctx, commit)))
}
}
}

case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeHeader: Boolean) extends PartitionWriter {
case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeHeader: Boolean, compression: Int) extends PartitionWriter {
val ctxType: Type = TStruct("cols" -> TArray(typ.colType), "numVariants" -> TInt64, "partFile" -> TString)
override def returnType: TStruct = TStruct("partFile" -> TString, "numVariants" -> TInt64, "dropped" -> TInt64)
def unionTypeRequiredness(r: TypeWithRequiredness, ctxType: TypeWithRequiredness, streamType: RIterable): Unit = {
Expand All @@ -1027,7 +1034,7 @@ case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeH
cb += (sampleIds(i) = s.loadString(cb))
}
val numVariants = ctx.loadField(cb, "numVariants").get(cb).asInt64.value
val header = Code.invokeScalaObject2[Array[String], Long, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants)
val header = Code.invokeScalaObject3[Array[String], Long, Int, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants, compression)
cb += os.invoke[Array[Byte], Unit]("write", header)
}

Expand Down Expand Up @@ -1145,7 +1152,12 @@ case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeH
// end emitGPData

val uncompLen = cb.memoize(uncompBuf.invoke[Int]("size"))
val compLen = cb.memoize(Code.invokeScalaObject2[ByteArrayBuilder, Array[Byte], Int](CompressionUtils.getClass, "compress", buf, uncompBuf.invoke[Array[Byte]]("result")))

val compMethod = compression match {
case 1 => "compressZlib"
case 2 => "compressZstd"
}
val compLen = cb.memoize(Code.invokeScalaObject2[ByteArrayBuilder, Array[Byte], Int](CompressionUtils.getClass, compMethod, buf, uncompBuf.invoke[Array[Byte]]("result")))

updateIntToBytesLE(cb, buf, cb.memoize(compLen + 4), gtDataBlockStart)
updateIntToBytesLE(cb, buf, uncompLen, cb.memoize(gtDataBlockStart + 4))
Expand All @@ -1154,7 +1166,7 @@ case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeH
}
}

case class BGENExportFinalizer(typ: MatrixType, path: String, exportType: String) extends MetadataWriter {
case class BGENExportFinalizer(typ: MatrixType, path: String, exportType: String, compression: Int) extends MetadataWriter {
def annotationType: Type = TStruct("cols" -> TArray(typ.colType), "results" -> TArray(TStruct("partFile" -> TString, "numVariants" -> TInt64, "dropped" -> TInt64)))

def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region]): Unit = {
Expand Down Expand Up @@ -1195,14 +1207,14 @@ case class BGENExportFinalizer(typ: MatrixType, path: String, exportType: String

if (exportType == ExportType.PARALLEL_SEPARATE_HEADER) {
val os = cb.memoize(cb.emb.create(const(path + ".bgen").concat("/header")))
val header = Code.invokeScalaObject2[Array[String], Long, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants)
val header = Code.invokeScalaObject3[Array[String], Long, Int, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants, compression)
cb += os.invoke[Array[Byte], Unit]("write", header)
cb += os.invoke[Unit]("close")
}

if (exportType == ExportType.CONCATENATED) {
val os = cb.memoize(cb.emb.create(const(path + ".bgen")))
val header = Code.invokeScalaObject2[Array[String], Long, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants)
val header = Code.invokeScalaObject3[Array[String], Long, Int, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants, compression)
cb += os.invoke[Array[Byte], Unit]("write", header)

annotations.loadField(cb, "results").get(cb).asIndexable.forEachDefined(cb) { (cb, i, res) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,10 @@ final class ByteArrayBuilder(initialCapacity: Int = 16) {
size_ = n
}

// Set the logical size directly without validation — presumably skips the
// bounds/capacity check of the checked setter above (not fully visible
// here); caller must guarantee `n` is within the backing array's capacity.
def setSizeUnchecked(n: Int) {
    size_ = n
  }

def apply(i: Int): Byte = {
require(i >= 0 && i < size)
b(i)
Expand Down
9 changes: 4 additions & 5 deletions hail/src/main/scala/is/hail/io/bgen/BgenRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ import org.apache.spark.{OneToOneDependency, Partition, TaskContext}
import scala.language.reflectiveCalls

object BgenSettings {
val UNCOMPRESSED = 0x0
val ZLIB_COMPRESSION = 0x1
val ZSTD_COMPRESSION = 0x2

def indexKeyType(rg: Option[ReferenceGenome]): TStruct = TStruct(
"locus" -> rg.map(TLocus(_)).getOrElse(TLocus.representation),
"alleles" -> TArray(TString))
Expand Down Expand Up @@ -140,11 +144,6 @@ object BgenRDD {

ContextRDD(new BgenRDD(ctx.fsBc, f, indexBuilder, partitions, settings, keys))
}

private[bgen] def decompress(
input: Array[Byte],
uncompressedSize: Int
): Array[Byte] = is.hail.utils.decompress(input, uncompressedSize)
}

private class BgenRDD(
Expand Down
25 changes: 17 additions & 8 deletions hail/src/main/scala/is/hail/io/bgen/BgenRDDPartitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import org.apache.spark.Partition
trait BgenPartition extends Partition {
def path: String

def compressed: Boolean
def compression: Int // 0 uncompressed, 1 zlib, 2 zstd

def skipInvalidLoci: Boolean

Expand All @@ -38,7 +38,7 @@ private case class LoadBgenPartition(
path: String,
indexPath: String,
filterPartition: Partition,
compressed: Boolean,
compression: Int,
skipInvalidLoci: Boolean,
contigRecoding: Map[String, String],
partitionIndex: Int,
Expand Down Expand Up @@ -149,7 +149,7 @@ object BgenRDDPartitions extends Logging {
file.path,
file.indexPath,
filterPartition = null,
file.header.compressed,
file.header.compression,
file.skipInvalidLoci,
file.contigRecoding,
partitionIndex,
Expand Down Expand Up @@ -425,15 +425,24 @@ object CompileDecoder {
cb.define(LnoOp)
}

cb.ifx(cp.invoke[Boolean]("compressed"), {
val compression = cb.memoize(cp.invoke[Int]("compression"))
cb.ifx(compression ceq BgenSettings.UNCOMPRESSED, {
cb.assign(data, cbfis.invoke[Int, Array[Byte]]("readBytes", dataSize))
}, {
cb.assign(uncompressedSize, cbfis.invoke[Int]("readInt"))
cb.assign(input, cbfis.invoke[Int, Array[Byte]]("readBytes", dataSize - 4))
cb.assign(data, Code.invokeScalaObject2[Array[Byte], Int, Array[Byte]](
BgenRDD.getClass, "decompress", input, uncompressedSize))
}, {
cb.assign(data, cbfis.invoke[Int, Array[Byte]]("readBytes", dataSize))
cb.ifx(compression ceq BgenSettings.ZLIB_COMPRESSION, {
cb.assign(data,
Code.invokeScalaObject2[Array[Byte], Int, Array[Byte]](
CompressionUtils.getClass, "decompressZlib", input, uncompressedSize))
}, {
// zstd
cb.assign(data,Code.invokeScalaObject2[Array[Byte], Int, Array[Byte]](
CompressionUtils.getClass, "decompressZstd", input, uncompressedSize))
})
})


cb.assign(reader, Code.newInstance[ByteArrayReader, Array[Byte]](data))
cb.assign(nRow, reader.invoke[Int]("readInt"))
cb.ifx(nRow.cne(settings.nSamples), cb._fatal(
Expand Down
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/io/bgen/IndexBgen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import org.apache.spark.{Partition, TaskContext}

private case class IndexBgenPartition(
path: String,
compressed: Boolean,
compression: Int,
skipInvalidLoci: Boolean,
contigRecoding: Map[String, String],
startByteOffset: Long,
Expand Down Expand Up @@ -83,7 +83,7 @@ object IndexBgen {
val partitions: Array[Partition] = headers.zipWithIndex.map { case (f, i) =>
IndexBgenPartition(
f.path,
f.compressed,
f.compression,
skipInvalidLoci,
recoding,
f.dataStart,
Expand Down
10 changes: 4 additions & 6 deletions hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.json4s.{DefaultFormats, Formats, JObject, JValue}
import scala.io.Source

case class BgenHeader(
compressed: Boolean,
compression: Int, // 0 uncompressed, 1 zlib, 2 zstd
nSamples: Int,
nVariants: Int,
headerLength: Int,
Expand Down Expand Up @@ -113,18 +113,16 @@ object LoadBgen {
val flags = is.readInt()
val compressType = flags & 3

if (compressType != 0 && compressType != 1)
fatal(s"Hail only supports zlib compression.")

val isCompressed = compressType != 0
if (compressType != 0 && compressType != 1 && compressType != 2)
fatal(s"Hail only supports zlib or zstd compression.")

val version = (flags >>> 2) & 0xf
if (version != 2)
fatal(s"Hail supports BGEN version 1.2, got version 1.$version")

val hasIds = (flags >> 31 & 1) != 0
BgenHeader(
isCompressed,
compressType,
nSamples,
nVariants,
headerLength,
Expand Down
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/io/gen/ExportBGEN.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ object BgenWriter {
bb(pos + 3) = ((i >>> 24) & 0xff).toByte
}

def headerBlock(sampleIds: Array[String], nVariants: Long): Array[Byte] = {
def headerBlock(sampleIds: Array[String], nVariants: Long, compression: Int): Array[Byte] = {
val bb = new ByteArrayBuilder()
val nSamples = sampleIds.length
assert(nVariants < (1L << 32))

val magicNumbers = Array("b", "g", "e", "n").flatMap(_.getBytes)
val flags = 0x01 | (0x02 << 2) | (0x01 << 31)
val flags = compression | (0x02 << 2) | (0x01 << 31)
val headerLength = 20

intToBytesLE(bb, 0) // placeholder for offset
Expand Down

0 comments on commit 2d428d6

Please sign in to comment.