Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 40 additions & 8 deletions gcp_variant_transforms/beam_io/vcfio.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from apache_beam.transforms import PTransform
from apache_beam.transforms.display import DisplayDataItem

from gcp_variant_transforms.libs import vcf_header_parser

__all__ = ['ReadFromVcf', 'ReadAllFromVcf', 'Variant', 'VariantCall',
'VariantInfo', 'MalformedVcfRecord']

Expand All @@ -46,7 +48,11 @@
# - 'A': one value per alternate allele.
# - 'G': one value for each possible genotype.
# - 'R': one value for each possible allele (including the reference).
VariantInfo = namedtuple('VariantInfo', ['data', 'field_count'])
# `annotation_names` is only filled for the annotation field and it is the
# list of annotation names extracted from the description part of the annotation
# field metadata in the VCF header.
VariantInfo = namedtuple('VariantInfo',
['data', 'field_count', 'annotation_names'])
# Stores data about failed VCF record reads. `line` is the text line that
# caused the failed read and `file_name` is the name of the file that the read
# failed in.
Expand Down Expand Up @@ -652,14 +658,16 @@ def __init__(self,
compression_type=CompressionTypes.AUTO,
buffer_size=DEFAULT_VCF_READ_BUFFER_SIZE,
validate=True,
allow_malformed_records=False):
allow_malformed_records=False,
annotation_field=None):
super(_VcfSource, self).__init__(file_pattern,
compression_type=compression_type,
validate=validate)

self._compression_type = compression_type
self._buffer_size = buffer_size
self._allow_malformed_records = allow_malformed_records
self._annotation_field = annotation_field

def read_records(self, file_name, range_tracker):
record_iterator = _VcfSource._VcfRecordIterator(
Expand All @@ -668,6 +676,7 @@ def read_records(self, file_name, range_tracker):
self._pattern,
self._compression_type,
self._allow_malformed_records,
annotation_field=self._annotation_field,
buffer_size=self._buffer_size,
skip_header_lines=0)

Expand All @@ -684,11 +693,13 @@ def __init__(self,
file_pattern,
compression_type,
allow_malformed_records,
annotation_field=None,
**kwargs):
self._header_lines = []
self._last_record = None
self._file_name = file_name
self._allow_malformed_records = allow_malformed_records
self._annotation_field = annotation_field

text_source = _TextSource(
file_pattern,
Expand Down Expand Up @@ -795,10 +806,19 @@ def _get_variant_info(self, record, infos):
for k, v in record.INFO.iteritems():
if k != END_INFO_KEY:
field_count = None
annotation_names = None
if k in infos:
field_count = self._get_field_count_as_string(infos[k].num)
info[k] = VariantInfo(data=v, field_count=field_count)

if k == self._annotation_field:
annotation_names = vcf_header_parser.extract_annotation_names(
infos[k].desc)
# TODO(bashir2): The reason we keep annotation_names with each variant
# is to do better merging, e.g., when some variants from two VCF files
# have different annotations. This merging logic needs to be
# implemented though.
info[k] = VariantInfo(data=v,
field_count=field_count,
annotation_names=annotation_names)
return info

def _get_field_count_as_string(self, field_count):
Expand Down Expand Up @@ -880,6 +900,7 @@ def __init__(
compression_type=CompressionTypes.AUTO,
validate=True,
allow_malformed_records=False,
annotation_field=None,
**kwargs):
"""Initialize the :class:`ReadFromVcf` transform.

Expand All @@ -892,23 +913,29 @@ def __init__(
underlying file_path's extension will be used to detect the compression.
validate (bool): flag to verify that the files exist during the pipeline
creation time.
annotation_field (str): If set, it is the field which will be treated as
annotation field, i.e., the description from header is split and copied
into the `VariantInfo.annotation_names` field of each variant.
"""
super(ReadFromVcf, self).__init__(**kwargs)
self._source = _VcfSource(
file_pattern,
compression_type,
validate=validate,
allow_malformed_records=allow_malformed_records)
allow_malformed_records=allow_malformed_records,
annotation_field=annotation_field)

def expand(self, pvalue):
return pvalue.pipeline | Read(self._source)


def _create_vcf_source(
    file_pattern=None, compression_type=None, allow_malformed_records=None,
    annotation_field=None):
  """Creates a ``_VcfSource`` for reading files matching ``file_pattern``.

  Used by ``ReadAllFromVcf`` (via ``functools.partial``) to build a source
  per input file pattern.

  Args:
    file_pattern (str): The file path or pattern of VCF files to read.
    compression_type (str): Used to handle compressed input files.
    allow_malformed_records (bool): If true, malformed records from VCF files
      will be returned as ``MalformedVcfRecord`` instead of failing the
      pipeline.
    annotation_field (str): If set, the INFO field to be treated as the
      annotation field, i.e., its header description is split and copied into
      the ``VariantInfo.annotation_names`` field of each variant.
  """
  return _VcfSource(file_pattern=file_pattern,
                    compression_type=compression_type,
                    allow_malformed_records=allow_malformed_records,
                    annotation_field=annotation_field)


class ReadAllFromVcf(PTransform):
Expand All @@ -930,6 +957,7 @@ def __init__(
desired_bundle_size=DEFAULT_DESIRED_BUNDLE_SIZE,
compression_type=CompressionTypes.AUTO,
allow_malformed_records=False,
annotation_field=None,
**kwargs):
"""Initialize the :class:`ReadAllFromVcf` transform.

Expand All @@ -945,11 +973,15 @@ def __init__(
allow_malformed_records (bool): If true, malformed records from VCF files
will be returned as :class:`MalformedVcfRecord` instead of failing
the pipeline.
annotation_field (`str`): If set, that is the field which will be treated
as annotation field, i.e., the description from header is split and
copied into the `VariantInfo.annotation_names` field of each variant.
"""
super(ReadAllFromVcf, self).__init__(**kwargs)
source_from_file = partial(
_create_vcf_source, compression_type=compression_type,
allow_malformed_records=allow_malformed_records)
allow_malformed_records=allow_malformed_records,
annotation_field=annotation_field)
self._read_all_files = filebasedsource.ReadAllFiles(
True, # splittable
CompressionTypes.AUTO, desired_bundle_size,
Expand Down
48 changes: 32 additions & 16 deletions gcp_variant_transforms/beam_io/vcfio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ def _get_sample_variant_1():
reference_name='20', start=1233, end=1234, reference_bases='C',
alternate_bases=['A', 'T'], names=['rs123', 'rs2'], quality=50,
filters=['PASS'],
info={'AF': vcfio.VariantInfo(data=[0.5, 0.1], field_count='A'),
'NS': vcfio.VariantInfo(data=1, field_count='1')})
info={'AF': vcfio.VariantInfo(data=[0.5, 0.1], field_count='A',
annotation_names=None),
'NS': vcfio.VariantInfo(data=1, field_count='1',
annotation_names=None)})
variant.calls.append(
vcfio.VariantCall(name='Sample1', genotype=[0, 0], info={'GQ': 48}))
variant.calls.append(
Expand All @@ -103,7 +105,8 @@ def _get_sample_variant_2():
reference_name='19', start=122, end=125, reference_bases='GTC',
alternate_bases=[], names=['rs1234'], quality=40,
filters=['q10', 's50'],
info={'NS': vcfio.VariantInfo(data=2, field_count='1')})
info={'NS': vcfio.VariantInfo(data=2, field_count='1',
annotation_names=None)})
variant.calls.append(
vcfio.VariantCall(name='Sample1', genotype=[1, 0],
phaseset=vcfio.DEFAULT_PHASESET_VALUE,
Expand All @@ -127,7 +130,8 @@ def _get_sample_variant_3():
variant = vcfio.Variant(
reference_name='19', start=11, end=12, reference_bases='C',
alternate_bases=['<SYMBOLIC>'], quality=49, filters=['q10'],
info={'AF': vcfio.VariantInfo(data=[0.5], field_count='A')})
info={'AF': vcfio.VariantInfo(data=[0.5], field_count='A',
annotation_names=None)})
variant.calls.append(
vcfio.VariantCall(name='Sample1', genotype=[0, 1],
phaseset='1',
Expand Down Expand Up @@ -392,7 +396,8 @@ def test_no_samples(self):
expected_variant = Variant(
reference_name='19', start=122, end=123, reference_bases='G',
alternate_bases=['A'], filters=['PASS'],
info={'AF': VariantInfo(data=[0.2], field_count='A')})
info={'AF': VariantInfo(data=[0.2], field_count='A',
annotation_names=None)})
read_data = self._create_temp_file_and_read_records(
_SAMPLE_HEADER_LINES[:-1] + [header_line, record_line])
self.assertEqual(1, len(read_data))
Expand Down Expand Up @@ -423,19 +428,27 @@ def test_info_numbers_and_types(self):
variant_1 = Variant(
reference_name='19', start=1, end=2, reference_bases='A',
alternate_bases=['T', 'C'],
info={'HA': VariantInfo(data=['a1', 'a2'], field_count='A'),
'HG': VariantInfo(data=[1, 2, 3], field_count='G'),
'HR': VariantInfo(data=['a', 'b', 'c'], field_count='R'),
'HF': VariantInfo(data=True, field_count='0'),
'HU': VariantInfo(data=[0.1], field_count=None)})
info={'HA': VariantInfo(data=['a1', 'a2'], field_count='A',
annotation_names=None),
'HG': VariantInfo(data=[1, 2, 3], field_count='G',
annotation_names=None),
'HR': VariantInfo(data=['a', 'b', 'c'], field_count='R',
annotation_names=None),
'HF': VariantInfo(data=True, field_count='0',
annotation_names=None),
'HU': VariantInfo(data=[0.1], field_count=None,
annotation_names=None)})
variant_1.calls.append(VariantCall(name='Sample1', genotype=[1, 0]))
variant_1.calls.append(VariantCall(name='Sample2', genotype=[0, 1]))
variant_2 = Variant(
reference_name='19', start=123, end=124, reference_bases='A',
alternate_bases=['T'],
info={'HG': VariantInfo(data=[3, 4, 5], field_count='G'),
'HR': VariantInfo(data=['d', 'e'], field_count='R'),
'HU': VariantInfo(data=[1.1, 1.2], field_count=None)})
info={'HG': VariantInfo(data=[3, 4, 5], field_count='G',
annotation_names=None),
'HR': VariantInfo(data=['d', 'e'], field_count='R',
annotation_names=None),
'HU': VariantInfo(data=[1.1, 1.2], field_count=None,
annotation_names=None)})
variant_2.calls.append(VariantCall(name='Sample1', genotype=[0, 0]))
variant_2.calls.append(VariantCall(name='Sample2', genotype=[0, 1]))
read_data = self._create_temp_file_and_read_records(
Expand Down Expand Up @@ -755,9 +768,12 @@ def test_info_list(self):
def test_info_field_count(self):
  """Encoding a variant renders INFO values per their declared field counts."""
  coder = self._get_coder()
  variant = Variant()
  # `annotation_names` is None for ordinary INFO fields; it is only populated
  # for the designated annotation field.
  variant.info['NS'] = VariantInfo(data=3, field_count='1',
                                   annotation_names=None)
  variant.info['AF'] = VariantInfo(data=[0.333, 0.667], field_count='A',
                                   annotation_names=None)
  variant.info['DB'] = VariantInfo(data=True, field_count='0',
                                   annotation_names=None)
  expected = '. . . . . . . NS=3;AF=0.333,0.667;DB .\n'

  self._assert_variant_lines_equal(coder.encode(variant), expected)
Expand Down
84 changes: 78 additions & 6 deletions gcp_variant_transforms/libs/bigquery_vcf_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from apache_beam.io.gcp.internal.clients import bigquery
from gcp_variant_transforms.beam_io import vcfio
from gcp_variant_transforms.libs import vcf_header_parser


__all__ = ['generate_schema_from_header_fields', 'get_rows_from_variant',
Expand Down Expand Up @@ -82,8 +83,15 @@ class _TableFieldConstants(object):
_JSON_CONCATENATION_OVERHEAD_BYTES = 5


# TODO(bashir2): Using type identifiers like ``HeaderFields`` does not seem
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, the type identifiers have been on my mind for some time. I inherited it from Beam conventions, but have not actually set up automated documentation generation to verify that it finds the correct class.
I think the 'correct' way is something like :class:`HeaderFields`. In any case, I filed Issue #108 for this. You can remove the TODO from here as it's awkward to single out this particular field.

# to be picked up by tools, because they cannot resolve these type identifiers.
# We should either fix these or otherwise stop using the convention of using
# double ` when not recognized by tools.


def generate_schema_from_header_fields(header_fields, variant_merger=None,
split_alternate_allele_info_fields=True):
split_alternate_allele_info_fields=True,
annotation_field=None):
"""Returns a ``TableSchema`` for the BigQuery table storing variants.

Args:
Expand Down Expand Up @@ -141,6 +149,28 @@ def generate_schema_from_header_fields(header_fields, variant_merger=None,
type=_get_bigquery_type_from_vcf_type(field.type),
mode=_TableFieldConstants.MODE_NULLABLE,
description=_get_bigquery_sanitized_field(field.desc)))
if annotation_field:
annotation_names = []
for key, field in header_fields.infos.iteritems():
if key == annotation_field:
annotation_names = vcf_header_parser.extract_annotation_names(
field.desc)
break
if not annotation_names:
raise ValueError('Annotation field {} not found'.format(annotation_field))
annotation_record = bigquery.TableFieldSchema(
name=_get_bigquery_sanitized_field(annotation_field),
type=_TableFieldConstants.TYPE_RECORD,
mode=_TableFieldConstants.MODE_REPEATED,
description='List of annotations for this alternate.')
for annotation in annotation_names:
annotation_record.fields.append(bigquery.TableFieldSchema(
name=_get_bigquery_sanitized_field(annotation),
type=_TableFieldConstants.TYPE_STRING,
mode=_TableFieldConstants.MODE_NULLABLE,
# TODO(bashir2): Add descriptions of known annotations, e.g., from VEP
description=''))
alternate_bases_record.fields.append(annotation_record)
schema.fields.append(alternate_bases_record)

schema.fields.append(bigquery.TableFieldSchema(
Expand Down Expand Up @@ -202,7 +232,8 @@ def generate_schema_from_header_fields(header_fields, variant_merger=None,
# END info is already included by modifying the end_position.
if (key == vcfio.END_INFO_KEY or
(split_alternate_allele_info_fields and
field.num == vcf.parser.field_counts[_FIELD_COUNT_ALTERNATE_ALLELE])):
field.num == vcf.parser.field_counts[_FIELD_COUNT_ALTERNATE_ALLELE]) or
key == annotation_field):
continue
schema.fields.append(bigquery.TableFieldSchema(
name=_get_bigquery_sanitized_field_name(key),
Expand All @@ -217,7 +248,7 @@ def generate_schema_from_header_fields(header_fields, variant_merger=None,

# TODO: refactor this to use a class instead.
def get_rows_from_variant(variant, split_alternate_allele_info_fields=True,
omit_empty_sample_calls=False):
omit_empty_sample_calls=False, annotation_field=None):
"""Yields BigQuery rows according to the schema from the given variant.

There is a 10MB limit for each BigQuery row, which can be exceeded by having
Expand All @@ -232,6 +263,8 @@ def get_rows_from_variant(variant, split_alternate_allele_info_fields=True,
of the INFO fields.
omit_empty_sample_calls (bool): If true, samples that don't have a given
call will be omitted.
annotation_field (str): If provided, it is the name of the INFO field
that contains the annotation list.
Yields:
A dict representing a BigQuery row from the given variant. The row may have
a subset of the calls if it exceeds the maximum allowed BigQuery row size.
Expand All @@ -241,9 +274,11 @@ def get_rows_from_variant(variant, split_alternate_allele_info_fields=True,
# TODO: Add error checking here for cases where the schema defined
# by the headers does not match actual records.
base_row = _get_base_row_from_variant(
variant, split_alternate_allele_info_fields)
variant, split_alternate_allele_info_fields, annotation_field)
base_row_size_in_bytes = _get_json_object_size(base_row)
row_size_in_bytes = base_row_size_in_bytes
# TODO(bashir2): It seems that BigQueryWriter buffers 1000 rows and this
# can cause BigQuery API exceptions. We need to fix this!
row = copy.deepcopy(base_row) # Keep base_row intact.
for call in variant.calls:
call_record, empty = _get_call_record(call)
Expand Down Expand Up @@ -287,7 +322,34 @@ def _get_call_record(call):
return call_record, is_empty


def _get_base_row_from_variant(variant, split_alternate_allele_info_fields):
def _create_list_of_annotation_lists(alt, info):
  """Extracts list of annotations for an alternate.

  Args:
    alt (str): The alternate for which the annotation lists are extracted.
    info (``VariantInfo``): The data for the annotation INFO field.

  Returns:
    A list of dicts, one for each entry of ``info.data`` whose alternate
    matches ``alt``, mapping each name in ``info.annotation_names`` to its
    corresponding annotation value.

  Raises:
    ValueError: If an annotation entry does not have exactly one value per
      annotation name (plus the leading alternate).
  """
  annotation_record = []
  for data in info.data:
    annotation_list = vcf_header_parser.extract_annotation_list_with_alt(data)
    # The first element of `annotation_list` is the alternate itself; the
    # remaining elements line up one-to-one with `info.annotation_names`.
    if len(annotation_list) != len(info.annotation_names) + 1:
      # TODO(bashir2): This and several other annotation related checks should
      # be made "soft", i.e., handled gracefully. We will do this as part of the
      # bigger issue to make schema error checking more robust.
      raise ValueError('Number of annotations does not match header')
    # TODO(bashir2): The alternate allele format is not necessarily as simple
    # as being equal to an 'alt', so this needs to be fixed to handle all
    # possible formats.
    if annotation_list[0] == alt:
      annotation_record.append(
          dict(zip(info.annotation_names, annotation_list[1:])))
  return annotation_record


def _get_base_row_from_variant(variant, split_alternate_allele_info_fields,
annotation_field):
"""A helper method for ``get_rows_from_variant`` to get row without calls."""
row = {
ColumnKeyConstants.REFERENCE_NAME: variant.reference_name,
Expand Down Expand Up @@ -315,12 +377,22 @@ def _get_base_row_from_variant(variant, split_alternate_allele_info_fields):
info_key, variant))
alt_record[_get_bigquery_sanitized_field_name(info_key)] = (
_get_bigquery_sanitized_field(info.data[alt_index]))
if annotation_field:
for info_key, info in variant.info.iteritems():
if info_key == annotation_field:
if not info.annotation_names:
raise ValueError(
'Annotation list not found for field {}'.format(info_key))
alt_record[_get_bigquery_sanitized_field(annotation_field)] = (
_create_list_of_annotation_lists(alt, info))
break
row[ColumnKeyConstants.ALTERNATE_BASES].append(alt_record)
# Add info.
for key, info in variant.info.iteritems():
if (info.data is not None and
(not split_alternate_allele_info_fields or
info.field_count != _FIELD_COUNT_ALTERNATE_ALLELE)):
info.field_count != _FIELD_COUNT_ALTERNATE_ALLELE) and
key != annotation_field):
row[_get_bigquery_sanitized_field_name(key)] = (
_get_bigquery_sanitized_field(info.data))
# Set calls to empty for now (will be filled later).
Expand Down
Loading