Skip to content

Commit

Permalink
DeepVariant supports the CSI format for VCF indexing. The CSI format will
Browse files Browse the repository at this point in the history
be applied automatically if any contig of the reference genome is longer than
512M. For reference genomes whose contigs are all shorter than 512M, the VCF
index will be in tabix format by default.

PiperOrigin-RevId: 261926352
  • Loading branch information
Genomics team in Google Brain authored and Copybara-Service committed Aug 6, 2019
1 parent 1b813bb commit 667d09e
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 6 deletions.
38 changes: 35 additions & 3 deletions deepvariant/postprocess_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,37 @@ def _get_stats_paths(input_vcf):
return template.format('per_record'), template.format('summary')


def _decide_to_use_csi(contigs):
"""Return True if CSI index is to be used over tabix index format.
If the length of any reference chromosomes exceeds 512M
(here we use 5e8 to keep a safety margin), we will choose csi
as the index format. Otherwise we use tbi as default.
Args:
contigs: list of contigs.
Returns:
A boolean variable indicating if the csi format is to be used or not.
"""
max_chrom_length = max([c.n_bases for c in contigs])
return max_chrom_length > 5e8


def build_index(vcf_file, csi=False):
  """Indexes a VCF file in either tabix or CSI format.

  Args:
    vcf_file: string. Path to the VCF file to be indexed.
    csi: bool. If true, a CSI index is built (with min_shift=14); otherwise
      a tabix index is built.
  """
  if not csi:
    # Default path: plain tabix index.
    tabix.build_index(vcf_file)
    return
  tabix.build_csi_index(vcf_file, min_shift=14)


def main(argv=()):
with errors.clean_commandline_error_exit():
if len(argv) > 1:
Expand Down Expand Up @@ -925,6 +956,7 @@ def main(argv=()):
sample_name = _extract_single_sample_name(record)
header = dv_vcf_constants.deepvariant_header(
contigs=contigs, sample_names=[sample_name])
use_csi = _decide_to_use_csi(contigs)
with tempfile.NamedTemporaryFile() as temp:
start_time = time.time()
postprocess_variants_lib.process_single_sites_tfrecords(
Expand All @@ -950,7 +982,7 @@ def main(argv=()):
output_vcf_path=FLAGS.outfile,
header=header)
if FLAGS.outfile.endswith('.gz'):
tabix.build_index(FLAGS.outfile)
build_index(FLAGS.outfile, use_csi)
logging.info('VCF creation took %s minutes',
(time.time() - start_time) / 60)
else:
Expand All @@ -969,9 +1001,9 @@ def main(argv=()):
variant_generator, nonvariant_generator, lessthanfn, fasta_reader,
vcf_writer, gvcf_writer)
if FLAGS.outfile.endswith('.gz'):
tabix.build_index(FLAGS.outfile)
build_index(FLAGS.outfile, use_csi)
if FLAGS.gvcf_outfile.endswith('.gz'):
tabix.build_index(FLAGS.gvcf_outfile)
build_index(FLAGS.gvcf_outfile, use_csi)
logging.info('Finished writing VCF and gVCF in %s minutes.',
(time.time() - start_time) / 60)
if FLAGS.create_vcf_stats:
Expand Down
16 changes: 16 additions & 0 deletions deepvariant/postprocess_variants_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import gzip
import io
import itertools
import os
import shutil
import sys


Expand Down Expand Up @@ -313,6 +315,20 @@ def _read_contents(path, decompress=False):
self.assertTrue(tf.gfile.Exists(FLAGS.outfile + '.tbi'))
self.assertTrue(tf.gfile.Exists(FLAGS.gvcf_outfile + '.tbi'))

@parameterized.parameters(False, True)
def test_build_index(self, use_csi):
  # Index a private copy of the golden compressed VCF so that each
  # parameterization works on its own file.
  tmp_vcf = os.path.join(absltest.get_default_test_tmpdir(),
                         'call_test_id_%s.vcf.gz' % (use_csi))
  shutil.copy(testdata.GOLDEN_POSTPROCESS_OUTPUT_COMPRESSED, tmp_vcf)
  postprocess_variants.build_index(tmp_vcf, use_csi)

  # Exactly one index flavour must exist: the one selected by use_csi.
  if use_csi:
    present_ext, absent_ext = '.csi', '.tbi'
  else:
    present_ext, absent_ext = '.tbi', '.csi'
  self.assertFalse(tf.gfile.Exists(tmp_vcf + absent_ext))
  self.assertTrue(tf.gfile.Exists(tmp_vcf + present_ext))

@flagsaver.FlagSaver
def test_reading_sharded_input_with_empty_shards_does_not_crash(self):
valid_variants = tfrecord.read_tfrecords(
Expand Down
2 changes: 2 additions & 0 deletions deepvariant/testdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def init():
'golden.postprocess_single_site_input.tfrecord')
GOLDEN_POSTPROCESS_OUTPUT = deepvariant_testdata(
'golden.postprocess_single_site_output.vcf')
GOLDEN_POSTPROCESS_OUTPUT_COMPRESSED = deepvariant_testdata(
'golden.postprocess_single_site_output.vcf.gz')
GOLDEN_POSTPROCESS_GVCF_INPUT = deepvariant_testdata(
'golden.postprocess_gvcf_input.tfrecord')
GOLDEN_POSTPROCESS_GVCF_OUTPUT = deepvariant_testdata(
Expand Down
Binary file not shown.
1 change: 1 addition & 0 deletions third_party/nucleus/io/python/tabix_indexer.clif
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ from "third_party/nucleus/vendor/statusor_clif_converters.h" import *
from "third_party/nucleus/io/tabix_indexer.h":
namespace `nucleus`:
def `TbxIndexBuild` as tbx_index_build(path: str) -> Status
def `CSIIndexBuild` as csi_index_build(path: str, min_shift:int) -> Status
5 changes: 5 additions & 0 deletions third_party/nucleus/io/tabix.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,8 @@
def build_index(path):
  """Builds a tabix index for VCF at the specified path.

  Args:
    path: path of the VCF file to index; forwarded unchanged to
      tabix_indexer.tbx_index_build.
  """
  tabix_indexer.tbx_index_build(path)


def build_csi_index(path, min_shift):
  """Builds a csi index for VCF at the specified path.

  Args:
    path: path of the VCF file to index.
    min_shift: int, forwarded unchanged to tabix_indexer.csi_index_build.
      NOTE(review): presumably htslib's CSI `min_shift` (log2 of the smallest
      index interval size) -- confirm against the htslib documentation.
  """
  tabix_indexer.csi_index_build(path, min_shift)
10 changes: 10 additions & 0 deletions third_party/nucleus/io/tabix_indexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,14 @@ tf::Status TbxIndexBuild(const string& path) {
return tf::Status::OK();
}

tf::Status CSIIndexBuild(string path, int min_shift) {
  // Passing a non-zero min_shift makes the indexer produce an index file in
  // CSI format.
  const int rc = tbx_index_build_x(path, min_shift, &tbx_conf_vcf);
  if (rc >= 0) {
    return tf::Status::OK();
  }
  // Negative return code: log it together with the path, then report the
  // failure to the caller.
  LOG(WARNING) << "Return code: " << rc << "\nFile path: " << path;
  return tf::errors::Internal("Failure to write CSI index.");
}

} // namespace nucleus
2 changes: 1 addition & 1 deletion third_party/nucleus/io/tabix_indexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace nucleus {

// Builds a tabix index for bgzipped VCF at the specified path.
tensorflow::Status TbxIndexBuild(const string& path);

// Builds a CSI-format index for the VCF at the specified path. min_shift is
// expected to be non-zero, which is what selects the CSI format.
tensorflow::Status CSIIndexBuild(string path, int min_shift);
} // namespace nucleus

#endif // THIRD_PARTY_NUCLEUS_IO_TABIX_INDEXER_H_
24 changes: 24 additions & 0 deletions third_party/nucleus/io/tabix_indexer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,28 @@ TEST(TabixIndexerTest, IndexBuildsCorrectly) {
EXPECT_THAT(reader->Query(MakeRange("chr3", 14318, 14319)), IsOK());
}

// Copies the sample VCF to a temporary .vcf.gz, builds a CSI index for it
// with min_shift 14, and checks that the ".csi" file was created.
TEST(CSIIndexerTest, IndexBuildsCorrectly) {
  string output_filename = MakeTempFile("test_samples.vcf.gz");
  string output_csi_index = output_filename + ".csi";

  // Reader over the checked-in sample VCF that serves as the input.
  std::unique_ptr<nucleus::VcfReader> reader = std::move(
      nucleus::VcfReader::FromFile(GetTestData(kVcfIndexSamplesFilename),
                                   nucleus::genomics::v1::VcfReaderOptions())
          .ValueOrDie());

  // Writer for the temporary output file, reusing the input's header.
  nucleus::genomics::v1::VcfWriterOptions writer_options;
  std::unique_ptr<VcfWriter> writer =
      std::move(nucleus::VcfWriter::ToFile(output_filename, reader->Header(),
                                           writer_options)
                    .ValueOrDie());

  // Copy every variant from the input into the output file.
  auto variants = nucleus::as_vector(reader->Iterate());
  for (const auto& v : variants) {
    TF_CHECK_OK(writer->Write(v));
  }

  EXPECT_THAT(CSIIndexBuild(output_filename, 14), IsOK());
  EXPECT_THAT(tensorflow::Env::Default()->FileExists(output_csi_index), IsOK());
  // NOTE(review): this query runs on `reader`, which was opened on the input
  // test data, not on the freshly indexed output file -- confirm that this is
  // the intended check.
  EXPECT_THAT(reader->Query(MakeRange("chr3", 14318, 14319)), IsOK());
}
} // namespace nucleus
25 changes: 23 additions & 2 deletions third_party/nucleus/io/tabix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def setUp(self):
self.output_file = test_utils.test_tmpfile('test_samples.vcf.gz')
shutil.copyfile(self.input_file, self.output_file)
self.tbx_index_file = self.output_file + '.tbi'
self.csi_index_file = self.output_file + '.csi'

def tearDown(self):
super(TabixTest, self).tearDown()
Expand All @@ -65,13 +66,23 @@ def tearDown(self):
os.remove(self.tbx_index_file)
except OSError:
pass
try:
os.remove(self.csi_index_file)
except OSError:
pass

def test_build_index(self):
def test_build_index_tbx(self):
  # tabix.build_index should create the '.tbi' file next to output_file;
  # the file must not exist beforehand.
  self.assertFalse(gfile.Exists(self.tbx_index_file))
  tabix.build_index(self.output_file)
  self.assertTrue(gfile.Exists(self.tbx_index_file))

def test_vcf_query(self):
def test_build_index_csi(self):
  # tabix.build_csi_index should create the '.csi' file next to output_file;
  # the file must not exist beforehand.
  min_shift = 14
  self.assertFalse(gfile.Exists(self.csi_index_file))
  tabix.build_csi_index(self.output_file, min_shift)
  self.assertTrue(gfile.Exists(self.csi_index_file))

def test_vcf_query_tbx(self):
tabix.build_index(self.output_file)
self.input_reader = vcf.VcfReader(self.input_file)
self.output_reader = vcf.VcfReader(self.output_file)
Expand All @@ -81,6 +92,16 @@ def test_vcf_query(self):
list(self.input_reader.query(range1)),
list(self.output_reader.query(range1)))

def test_vcf_query_csi(self):
  # After building a CSI index for the copied VCF, a range query on the copy
  # must return the same records as the same query on the input file.
  min_shift = 14
  tabix.build_csi_index(self.output_file, min_shift)
  self.input_reader = vcf.VcfReader(self.input_file)
  self.output_reader = vcf.VcfReader(self.output_file)

  range1 = ranges.parse_literal('chr3:100,000-500,000')
  self.assertEqual(
      list(self.input_reader.query(range1)),
      list(self.output_reader.query(range1)))

if __name__ == '__main__':
absltest.main()

0 comments on commit 667d09e

Please sign in to comment.