From 8ecb41463536d9792b269f229c30c71f1fdc869b Mon Sep 17 00:00:00 2001 From: Tural Neymanov Date: Tue, 15 Jan 2019 10:55:49 -0500 Subject: [PATCH 1/4] Parse out the reusable logic from infer_headers and variant_to_bigquery modules into libs directory. --- .../libs/annotation/annotation_parser.py | 17 + gcp_variant_transforms/libs/bigquery_util.py | 84 ++++ .../libs/bigquery_util_test.py | 327 ++++++++++++++++ .../libs/infer_headers_util.py | 360 ++++++++++++++++++ .../libs/infer_headers_util_test.py | 355 +++++++++++++++++ .../libs/processed_variant.py | 5 +- .../transforms/infer_headers.py | 347 +---------------- .../transforms/infer_headers_test.py | 136 +------ .../transforms/variant_to_bigquery.py | 88 +---- .../transforms/variant_to_bigquery_test.py | 323 ---------------- 10 files changed, 1158 insertions(+), 884 deletions(-) create mode 100644 gcp_variant_transforms/libs/infer_headers_util.py create mode 100644 gcp_variant_transforms/libs/infer_headers_util_test.py diff --git a/gcp_variant_transforms/libs/annotation/annotation_parser.py b/gcp_variant_transforms/libs/annotation/annotation_parser.py index f526539cb..b33eef130 100644 --- a/gcp_variant_transforms/libs/annotation/annotation_parser.py +++ b/gcp_variant_transforms/libs/annotation/annotation_parser.py @@ -43,6 +43,9 @@ _BREAKEND_ALT_RE = (re.compile( r'^(?P.*([\[\]]).*):(?P.*)([\[\]]).*$')) +# Filled with annotation field and name data, then used as a header ID. +_BASE_ANNOTATION_TYPE_KEY = '{}_{}_TYPE' + class AnnotationParserException(Exception): pass @@ -420,3 +423,17 @@ def reconstruct_annotation_description(annotation_names): returns 'Format: Allele|Consequence|IMPACT|SYMBOL|Gene'. """ return ' '.join(['Format:', '|'.join(annotation_names)]) + + +def get_inferred_annotation_type_header_key(annot_field, name): + # type: (str, str) -> str + """Creates ID values for annotation type info headers. + + Args: + annot_field: field name representing annotation field (e.g. 'CSQ'). 
+ name: annotation data field names (e.g. 'IMPACT'). + + Returns: + Info ID value (e.g. CSQ_IMPACT_TYPE). + """ + return _BASE_ANNOTATION_TYPE_KEY.format(annot_field, name) diff --git a/gcp_variant_transforms/libs/bigquery_util.py b/gcp_variant_transforms/libs/bigquery_util.py index 5277f4ac9..c7d76db9a 100644 --- a/gcp_variant_transforms/libs/bigquery_util.py +++ b/gcp_variant_transforms/libs/bigquery_util.py @@ -14,10 +14,14 @@ """Constants and simple utility functions related to BigQuery.""" +import exceptions import enum import re from typing import List, Tuple, Union # pylint: disable=unused-import +from apache_beam.io.gcp.internal.clients import bigquery +from apitools.base.py import exceptions +from oauth2client.client import GoogleCredentials from vcf import parser from gcp_variant_transforms.beam_io import vcf_header_io @@ -191,3 +195,83 @@ def get_avro_type_from_bigquery_type_mode(bigquery_type, bigquery_mode): return [avro_type, AvroConstants.NULL] else: return avro_type + +def update_bigquery_schema_on_append(schema_fields, output_table): + # type: (List[bigquery.TableFieldSchema], str) -> None + # If the table does not exist, there is no need to update the schema. + # TODO (yifangchen): Move the logic into validate(). 
+ output_table_re_match = re.match( + r'^((?P<project>.+):)(?P<dataset>\w+)\.(?P<table>[\w\$]+)$', + output_table) + credentials = GoogleCredentials.get_application_default().create_scoped( + ['https://www.googleapis.com/auth/bigquery']) + client = bigquery.BigqueryV2(credentials=credentials) + try: + project_id = output_table_re_match.group('project') + dataset_id = output_table_re_match.group('dataset') + table_id = output_table_re_match.group('table') + existing_table = client.tables.Get(bigquery.BigqueryTablesGetRequest( + projectId=project_id, + datasetId=dataset_id, + tableId=table_id)) + except exceptions.HttpError: + return + + new_schema = bigquery.TableSchema() + new_schema.fields = get_merged_field_schemas(existing_table.schema.fields, + schema_fields) + existing_table.schema = new_schema + try: + client.tables.Update(bigquery.BigqueryTablesUpdateRequest( + projectId=project_id, + datasetId=dataset_id, + table=existing_table, + tableId=table_id)) + except exceptions.HttpError as e: + raise RuntimeError('BigQuery schema update failed: %s' % str(e)) + + +def get_merged_field_schemas( + field_schemas_1, # type: List[bigquery.TableFieldSchema] + field_schemas_2 # type: List[bigquery.TableFieldSchema] + ): + # type: (...) -> List[bigquery.TableFieldSchema] + """Merges the `field_schemas_1` and `field_schemas_2`. + + Args: + field_schemas_1: A list of `TableFieldSchema`. + field_schemas_2: A list of `TableFieldSchema`. + Returns: + A new schema with new fields from `field_schemas_2` appended to + `field_schemas_1`. + Raises: + ValueError: If there are fields with the same name, but different modes or + different types. 
+ """ + existing_fields = {} # type: Dict[str, bigquery.TableFieldSchema] + merged_field_schemas = [] # type: List[bigquery.TableFieldSchema] + for field_schema in field_schemas_1: + existing_fields.update({field_schema.name: field_schema}) + merged_field_schemas.append(field_schema) + + for field_schema in field_schemas_2: + if field_schema.name not in existing_fields.keys(): + merged_field_schemas.append(field_schema) + else: + existing_field_schema = existing_fields.get(field_schema.name) + if field_schema.mode != existing_field_schema.mode: + raise ValueError( + 'The mode of field {} is not compatible. The original mode is {}, ' + 'and the new mode is {}.'.format(field_schema.name, + existing_field_schema.mode, + field_schema.mode)) + if field_schema.type != existing_field_schema.type: + raise ValueError( + 'The type of field {} is not compatible. The original type is {}, ' + 'and the new type is {}.'.format(field_schema.name, + existing_field_schema.type, + field_schema.type)) + if field_schema.type == TableFieldConstants.TYPE_RECORD: + existing_field_schema.fields = get_merged_field_schemas( + existing_field_schema.fields, field_schema.fields) + return merged_field_schemas diff --git a/gcp_variant_transforms/libs/bigquery_util_test.py b/gcp_variant_transforms/libs/bigquery_util_test.py index 4d95ce69c..cc235c100 100644 --- a/gcp_variant_transforms/libs/bigquery_util_test.py +++ b/gcp_variant_transforms/libs/bigquery_util_test.py @@ -16,8 +16,11 @@ import unittest +from apache_beam.io.gcp.internal.clients import bigquery from gcp_variant_transforms.beam_io import vcf_header_io from gcp_variant_transforms.libs import bigquery_util +from gcp_variant_transforms.libs.bigquery_util import ColumnKeyConstants +from gcp_variant_transforms.libs.bigquery_util import TableFieldConstants class BigqueryUtilTest(unittest.TestCase): @@ -87,3 +90,327 @@ def test_get_vcf_num_from_bigquery_schema(self): bigquery_util.get_vcf_num_from_bigquery_schema( bigquery_mode=None, 
bigquery_type=bigquery_util.TableFieldConstants.TYPE_BOOLEAN)) + + + + def test_merge_field_schemas_no_same_id(self): + field_schemas_1 = [ + bigquery.TableFieldSchema( + name='II', + type=TableFieldConstants.TYPE_INTEGER, + mode=TableFieldConstants.MODE_NULLABLE, + description='INFO foo desc'), + bigquery.TableFieldSchema( + name='IFR', + type=TableFieldConstants.TYPE_FLOAT, + mode=TableFieldConstants.MODE_REPEATED, + description='INFO foo desc') + ] + field_schemas_2 = [ + bigquery.TableFieldSchema( + name='AB', + type=TableFieldConstants.TYPE_FLOAT, + mode=TableFieldConstants.MODE_NULLABLE, + description='INFO foo desc') + ] + merged_field_schemas = bigquery_util.get_merged_field_schemas( + field_schemas_1, field_schemas_2) + expected_merged_field_schemas = [ + bigquery.TableFieldSchema( + name='II', + type=TableFieldConstants.TYPE_INTEGER, + mode=TableFieldConstants.MODE_NULLABLE, + description='INFO foo desc'), + bigquery.TableFieldSchema( + name='IFR', + type=TableFieldConstants.TYPE_FLOAT, + mode=TableFieldConstants.MODE_REPEATED, + description='INFO foo desc'), + bigquery.TableFieldSchema( + name='AB', + type=TableFieldConstants.TYPE_FLOAT, + mode=TableFieldConstants.MODE_NULLABLE, + description='INFO foo desc') + ] + self.assertEqual(merged_field_schemas, expected_merged_field_schemas) + + def test_merge_field_schemas_same_id_no_conflicts(self): + field_schemas_1 = [ + bigquery.TableFieldSchema( + name='II', + type=TableFieldConstants.TYPE_INTEGER, + mode=TableFieldConstants.MODE_NULLABLE, + description='INFO foo desc'), + bigquery.TableFieldSchema( + name='IFR', + type=TableFieldConstants.TYPE_FLOAT, + mode=TableFieldConstants.MODE_REPEATED, + description='INFO foo desc') + ] + field_schemas_2 = [ + bigquery.TableFieldSchema( + name='II', + type=TableFieldConstants.TYPE_INTEGER, + mode=TableFieldConstants.MODE_NULLABLE, + description='INFO foo desc'), + bigquery.TableFieldSchema( + name='AB', + type=TableFieldConstants.TYPE_FLOAT, + 
mode=TableFieldConstants.MODE_NULLABLE, + description='INFO foo desc') + ] + merged_field_schemas = bigquery_util.get_merged_field_schemas( + field_schemas_1, field_schemas_2) + expected_merged_field_schemas = [ + bigquery.TableFieldSchema( + name='II', + type=TableFieldConstants.TYPE_INTEGER, + mode=TableFieldConstants.MODE_NULLABLE, + description='INFO foo desc'), + bigquery.TableFieldSchema( + name='IFR', + type=TableFieldConstants.TYPE_FLOAT, + mode=TableFieldConstants.MODE_REPEATED, + description='INFO foo desc'), + bigquery.TableFieldSchema( + name='AB', + type=TableFieldConstants.TYPE_FLOAT, + mode=TableFieldConstants.MODE_NULLABLE, + description='INFO foo desc') + ] + self.assertEqual(merged_field_schemas, expected_merged_field_schemas) + + def test_merge_field_schemas_conflict_mode(self): + field_schemas_1 = [ + bigquery.TableFieldSchema( + name='II', + type=TableFieldConstants.TYPE_INTEGER, + mode=TableFieldConstants.MODE_NULLABLE, + description='INFO foo desc') + ] + field_schemas_2 = [ + bigquery.TableFieldSchema( + name='II', + type=TableFieldConstants.TYPE_INTEGER, + mode=TableFieldConstants.MODE_REPEATED, + description='INFO foo desc') + ] + self.assertRaises(ValueError, bigquery_util.get_merged_field_schemas, + field_schemas_1, field_schemas_2) + + def test_merge_field_schemas_conflict_type(self): + field_schemas_1 = [ + bigquery.TableFieldSchema( + name='II', + type=TableFieldConstants.TYPE_INTEGER, + mode=TableFieldConstants.MODE_NULLABLE, + description='INFO foo desc') + ] + field_schemas_2 = [ + bigquery.TableFieldSchema( + name='II', + type=TableFieldConstants.TYPE_FLOAT, + mode=TableFieldConstants.MODE_NULLABLE, + description='INFO foo desc') + ] + self.assertRaises(ValueError, bigquery_util.get_merged_field_schemas, + field_schemas_1, field_schemas_2) + + def test_merge_field_schemas_conflict_record_fields(self): + call_record_1 = bigquery.TableFieldSchema( + name=ColumnKeyConstants.CALLS, + type=TableFieldConstants.TYPE_RECORD, + 
mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + call_record_1.fields.append(bigquery.TableFieldSchema( + name='FB', + type=TableFieldConstants.TYPE_BOOLEAN, + mode=TableFieldConstants.MODE_NULLABLE, + description='FORMAT foo desc')) + field_schemas_1 = [call_record_1] + + call_record_2 = bigquery.TableFieldSchema( + name=ColumnKeyConstants.CALLS, + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + call_record_2.fields.append(bigquery.TableFieldSchema( + name='FB', + type=TableFieldConstants.TYPE_INTEGER, + mode=TableFieldConstants.MODE_NULLABLE, + description='FORMAT foo desc')) + field_schemas_2 = [call_record_2] + self.assertRaises(ValueError, bigquery_util.get_merged_field_schemas, + field_schemas_1, field_schemas_2) + + def test_merge_field_schemas_same_record(self): + call_record_1 = bigquery.TableFieldSchema( + name=ColumnKeyConstants.CALLS, + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + call_record_1.fields.append(bigquery.TableFieldSchema( + name='FB', + type=TableFieldConstants.TYPE_BOOLEAN, + mode=TableFieldConstants.MODE_NULLABLE, + description='FORMAT foo desc')) + + field_schemas_1 = [call_record_1] + field_schemas_2 = [call_record_1] + + expected_merged_field_schemas = [call_record_1] + self.assertEqual( + bigquery_util.get_merged_field_schemas(field_schemas_1, + field_schemas_2), + expected_merged_field_schemas) + + def test_merge_field_schemas_merge_record_fields(self): + call_record_1 = bigquery.TableFieldSchema( + name=ColumnKeyConstants.CALLS, + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + call_record_1.fields.append(bigquery.TableFieldSchema( + name='FB', + type=TableFieldConstants.TYPE_BOOLEAN, + mode=TableFieldConstants.MODE_NULLABLE, + description='FORMAT foo desc')) + + 
field_schemas_1 = [call_record_1] + + call_record_2 = bigquery.TableFieldSchema( + name=ColumnKeyConstants.CALLS, + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + call_record_2.fields.append(bigquery.TableFieldSchema( + name='GQ', + type=TableFieldConstants.TYPE_INTEGER, + mode=TableFieldConstants.MODE_NULLABLE, + description='FORMAT foo desc')) + field_schemas_2 = [call_record_2] + + call_record_3 = bigquery.TableFieldSchema( + name=ColumnKeyConstants.CALLS, + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + call_record_3.fields.append(bigquery.TableFieldSchema( + name='FB', + type=TableFieldConstants.TYPE_BOOLEAN, + mode=TableFieldConstants.MODE_NULLABLE, + description='FORMAT foo desc')) + call_record_3.fields.append(bigquery.TableFieldSchema( + name='GQ', + type=TableFieldConstants.TYPE_INTEGER, + mode=TableFieldConstants.MODE_NULLABLE, + description='FORMAT foo desc')) + + expected_merged_field_schemas = [call_record_3] + self.assertEqual( + bigquery_util.get_merged_field_schemas(field_schemas_1, + field_schemas_2), + expected_merged_field_schemas) + + def test_merge_field_schemas_conflict_inner_record_fields(self): + record_1 = bigquery.TableFieldSchema( + name=ColumnKeyConstants.CALLS, + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + inner_record_1 = bigquery.TableFieldSchema( + name='inner record', + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + inner_record_1.fields.append(bigquery.TableFieldSchema( + name='FB', + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='FORMAT foo desc')) + record_1.fields.append(inner_record_1) + field_schemas_1 = [record_1] + + record_2 = bigquery.TableFieldSchema( + 
name=ColumnKeyConstants.CALLS, + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + inner_record_2 = bigquery.TableFieldSchema( + name='inner record', + type=TableFieldConstants.TYPE_INTEGER, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + inner_record_2.fields.append(bigquery.TableFieldSchema( + name='FB', + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='FORMAT foo desc')) + record_2.fields.append(inner_record_2) + field_schemas_2 = [record_2] + self.assertRaises(ValueError, bigquery_util.get_merged_field_schemas, + field_schemas_1, field_schemas_2) + + def test_merge_field_schemas_merge_inner_record_fields(self): + record_1 = bigquery.TableFieldSchema( + name=ColumnKeyConstants.CALLS, + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + inner_record_1 = bigquery.TableFieldSchema( + name='inner record', + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + inner_record_1.fields.append(bigquery.TableFieldSchema( + name='FB', + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='FORMAT foo desc')) + record_1.fields.append(inner_record_1) + field_schemas_1 = [record_1] + + record_2 = bigquery.TableFieldSchema( + name=ColumnKeyConstants.CALLS, + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + inner_record_2 = bigquery.TableFieldSchema( + name='inner record', + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + inner_record_2.fields.append(bigquery.TableFieldSchema( + name='AB', + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + 
description='FORMAT foo desc')) + record_2.fields.append(inner_record_2) + field_schemas_2 = [record_2] + + merged_record = bigquery.TableFieldSchema( + name=ColumnKeyConstants.CALLS, + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + merged_inner_record = bigquery.TableFieldSchema( + name='inner record', + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='One record for each call.') + merged_inner_record.fields.append(bigquery.TableFieldSchema( + name='FB', + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='FORMAT foo desc')) + merged_inner_record.fields.append(bigquery.TableFieldSchema( + name='AB', + type=TableFieldConstants.TYPE_RECORD, + mode=TableFieldConstants.MODE_REPEATED, + description='FORMAT foo desc')) + merged_record.fields.append(merged_inner_record) + expected_merged_field_schemas = [merged_record] + self.assertEqual( + bigquery_util.get_merged_field_schemas(field_schemas_1, + field_schemas_2), + expected_merged_field_schemas) diff --git a/gcp_variant_transforms/libs/infer_headers_util.py b/gcp_variant_transforms/libs/infer_headers_util.py new file mode 100644 index 000000000..b393456dd --- /dev/null +++ b/gcp_variant_transforms/libs/infer_headers_util.py @@ -0,0 +1,360 @@ +# Copyright 2019 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""A Helper module for Header Inference operations.""" + +from __future__ import absolute_import + +import logging +from typing import Any, Dict, List, Optional, Union # pylint: disable=unused-import + +from vcf.parser import _Format as Format +from vcf.parser import _Info as Info +from vcf.parser import field_counts + +from gcp_variant_transforms.beam_io import vcf_header_io +from gcp_variant_transforms.beam_io import vcfio # pylint: disable=unused-import +from gcp_variant_transforms.libs.annotation import annotation_parser +from gcp_variant_transforms.libs import vcf_field_conflict_resolver + +_FIELD_COUNT_ALTERNATE_ALLELE = 'A' + +# Alias for the header key/type constants to make referencing easier. +_HeaderKeyConstants = vcf_header_io.VcfParserHeaderKeyConstants +_HeaderTypeConstants = vcf_header_io.VcfHeaderFieldTypeConstants + +def _get_field_count(field_value): + # type: (Union[List, bool, int, str]) -> Optional[int] + """ + Args: + field_value: value for the field returned by PyVCF. E.g. [0.33, 0.66] is a + field value for Allele frequency (AF) field. + """ + if isinstance(field_value, list): + return field_counts['.'] + elif isinstance(field_value, bool): + return 0 + else: + return 1 + +def _get_field_type(field_value): + """ + Args: + field_value (list, bool, integer, or string): value for the field + returned by PyVCF. E.g. [0.33, 0.66] is a field value for Allele + frequency (AF) field. 
+ """ + if isinstance(field_value, list): + return (_get_field_type(field_value[0]) if field_value else + vcf_header_io.VcfHeaderFieldTypeConstants.STRING) + + if isinstance(field_value, bool): + return vcf_header_io.VcfHeaderFieldTypeConstants.FLAG + elif isinstance(field_value, int): + return vcf_header_io.VcfHeaderFieldTypeConstants.INTEGER + elif isinstance(field_value, float): + return vcf_header_io.VcfHeaderFieldTypeConstants.FLOAT + elif _can_cast_to(field_value, int): + return vcf_header_io.VcfHeaderFieldTypeConstants.INTEGER + elif _can_cast_to(field_value, float): + return vcf_header_io.VcfHeaderFieldTypeConstants.FLOAT + else: + return vcf_header_io.VcfHeaderFieldTypeConstants.STRING + +def _can_cast_to(value, cast_type): + """Returns True if `value` can be cast to type `cast_type`.""" + try: + _ = cast_type(value) + return True + except ValueError: + return False + +def _get_corrected_type(defined_type, value): + # type: (str, Any) -> str + """Returns the corrected type according to `defined_type` and `value`. + + It handles one special case PyVCF cannot handle, i.e., the defined type is + `Integer`, but the provided value is float. In this case, correct the type + to be `Float`. + + Note that if `value` is a float instance but with an integer value + (e.g. 2.0), the type will stay the same as `defined_type`. + """ + if defined_type == _HeaderTypeConstants.INTEGER: + if isinstance(value, float) and not value.is_integer(): + return _HeaderTypeConstants.FLOAT + if isinstance(value, list): + for item in value: + corrected_type = _get_corrected_type(defined_type, item) + if corrected_type != defined_type: + return corrected_type + return defined_type + +def _infer_mismatched_info_field(field_key, # type: str + field_value, # type: Any + defined_header, # type: Dict + num_alternate_bases # type: int + ): + # type: (...) -> Optional[Info] + """Returns corrected info if there are mismatches. 
+ + Two mismatches are handled: + - Defined num is `A`, but the provided values do not have the same + cardinality as the alternate bases. Correct the num to be `None`. + - Defined type is `Integer`, but the provided value is float. Correct the + type to be `Float`. + Args: + field_key: the info field key. + field_value: the value of the field key given in the variant. + defined_header: The definition of `field_key` in the header. + num_alternate_bases: number of the alternate bases. + Returns: + Corrected info definition if there are mismatches. + """ + corrected_num = defined_header.get(_HeaderKeyConstants.NUM) + if (corrected_num == field_counts[_FIELD_COUNT_ALTERNATE_ALLELE] and + len(field_value) != num_alternate_bases): + corrected_num = field_counts['.'] + + corrected_type = _get_corrected_type( + defined_header.get(_HeaderKeyConstants.TYPE), field_value) + + if (corrected_type != defined_header.get(_HeaderKeyConstants.TYPE) or + corrected_num != defined_header.get(_HeaderKeyConstants.NUM)): + return Info(field_key, + corrected_num, + corrected_type, + defined_header.get(_HeaderKeyConstants.DESC), + defined_header.get(_HeaderKeyConstants.SOURCE), + defined_header.get(_HeaderKeyConstants.VERSION)) + return None + +def _infer_mismatched_format_field(field_key, # type: str + field_value, # type: Any + defined_header # type: Dict + ): + # type: (...) -> Optional[Format] + """Returns corrected format if there are mismatches. + + One type of mismatches is handled: + - Defined type is `Integer`, but the provided value is float. Correct the + type to be `Float`. + Args: + field_key: the format field key. + field_value: the value of the field key given in the variant. + defined_header: The definition of `field_key` in the header. + Returns: + Corrected format definition if there are mismatches. 
+ """ + corrected_type = _get_corrected_type( + defined_header.get(_HeaderKeyConstants.TYPE), field_value) + if corrected_type != defined_header.get(_HeaderKeyConstants.TYPE): + return Format(field_key, + defined_header.get(_HeaderKeyConstants.NUM), + corrected_type, + defined_header.get(_HeaderKeyConstants.DESC)) + return None + +def _infer_standard_info_fields(variant, infos, defined_headers): + # type: (vcfio.Variant, Dict[str, Info], vcf_header_io.VcfHeader) -> None + """Updates `infos` with inferred info fields. + + Two types of info fields are inferred: + - The info fields are undefined in the headers. + - The info fields' definitions provided by the header does not match the + field value. + Args: + variant: variant object + infos: dict of (info_key, `Info`) for any info field in + `variant` that is not defined in the header or the definition mismatches + the field values. + defined_headers: header fields defined in header section of VCF files. + """ + for info_field_key, info_field_value in variant.info.iteritems(): + if not defined_headers or info_field_key not in defined_headers.infos: + if info_field_key in infos: + raise ValueError( + 'Duplicate INFO field "{}" in variant "{}"'.format( + info_field_key, variant)) + logging.warning('Undefined INFO field "%s" in variant "%s"', + info_field_key, str(variant)) + infos[info_field_key] = Info(info_field_key, + _get_field_count(info_field_value), + _get_field_type(info_field_value), + '', # NO_DESCRIPTION + '', # UNKNOWN_SOURCE + '') # UNKNOWN_VERSION + else: + defined_header = defined_headers.infos.get(info_field_key) + corrected_info = _infer_mismatched_info_field( + info_field_key, info_field_value, + defined_header, len(variant.alternate_bases)) + if corrected_info: + logging.warning( + 'Incorrect INFO field "%s". 
Defined as "type=%s,num=%s", ' + 'got "%s", in variant "%s"', + info_field_key, defined_header.get(_HeaderKeyConstants.TYPE), + str(defined_header.get(_HeaderKeyConstants.NUM)), + str(info_field_value), str(variant)) + infos[info_field_key] = corrected_info + +def _infer_annotation_type_info_fields(variant, + infos, + defined_headers, + annotation_fields_to_infer + ): + # type: (vcfio.Variant, Dict[str, Info], vcf_header_io.VcfHeader, List[str]) -> None + """Updates `infos` with inferred annotation type info fields. + + All annotation headers in each annotation field are converted to Info header + lines where the new ID corresponds to the given annotation field and header, + and the new TYPE corresponds to the inferred type of the original header. Since + each variant potentially contains multiple values for each annotation + header, a small 'merge' of value types is performed before VcfHeader + creation for each variant. + Args: + variant: variant object + infos: dict of (info_key, `Info`) for any info field in + `variant` that is not defined in the header or the definition mismatches + the field values. + defined_headers: header fields defined in header section of VCF files. + annotation_fields_to_infer: list of info fields treated as annotation + fields (e.g. ['CSQ', 'CSQ_VT']). 
+ """ + + def _check_annotation_lists_lengths(names, values): + lengths = set(len(v) for v in values) + lengths.add(len(names)) + if len(lengths) != 1: + error = ('Annotation lists have inconsistent lengths: {}.\nnames={}\n' + 'values={}').format(lengths, names, values) + raise ValueError(error) + + resolver = vcf_field_conflict_resolver.FieldConflictResolver( + resolve_always=True) + for field in annotation_fields_to_infer: + if field not in variant.info: + continue + annotation_names = annotation_parser.extract_annotation_names( + defined_headers.infos[field][_HeaderKeyConstants.DESC]) + # First element (ALT) is ignored, since its type is hard-coded as string + annotation_values = [annotation_parser.extract_annotation_list_with_alt( + annotation)[1:] for annotation in variant.info[field]] + _check_annotation_lists_lengths(annotation_names, annotation_values) + annotation_values = zip(*annotation_values) + for name, values in zip(annotation_names, annotation_values): + variant_merged_type = None + for v in values: + if not v: + continue + variant_merged_type = resolver.resolve_attribute_conflict( + _HeaderKeyConstants.TYPE, + variant_merged_type, + _get_field_type(v)) + if variant_merged_type == _HeaderTypeConstants.STRING: + break + key_id = annotation_parser.get_inferred_annotation_type_header_key( + field, name) + infos[key_id] = Info(key_id, + 1, # field count + variant_merged_type, + ('Inferred type field for annotation {}.'.format( + name)), + '', # UNKNOWN_SOURCE + '') # UNKNOWN_VERSION + +def infer_info_fields( + variant, + defined_headers, + infer_headers=False, # type: bool + annotation_fields_to_infer=None # type: Optional[List[str]] + ): + """Returns inferred info fields. + + Up to three types of info fields are inferred: + + if `infer_headers` is True: + - The info fields are undefined in the headers. + - The info fields' definitions provided by the header does not match the + field value. 
+ if `annotation_fields_to_infer` is non-empty: + - Fields containing type information of corresponding annotation Info + fields. + + Args: + variant: variant object + defined_headers: header fields defined in header section of VCF files. + infer_headers: If true, header fields are inferred from variant data. + annotation_fields_to_infer: list of info fields treated as annotation + fields (e.g. ['CSQ', 'CSQ_VT']). + Returns: + infos: dict of (info_key, `Info`) for any info field in + `variant` that is not defined in the header or the definition mismatches + the field values. + """ + infos = {} + if infer_headers: + _infer_standard_info_fields(variant, infos, defined_headers) + if annotation_fields_to_infer: + _infer_annotation_type_info_fields( + variant, infos, defined_headers, annotation_fields_to_infer) + return infos + +def infer_format_fields(variant, defined_headers): + # type: (vcfio.Variant, vcf_header_io.VcfHeader) -> Dict[str, Format] + """Returns inferred format fields. + + Two types of format fields are inferred: + - The format fields are undefined in the headers. + - The format definition provided by the headers does not match the field + values. + Args: + variant: variant object + defined_headers: header fields defined in header section of VCF files. + Returns: + A dict of (format_key, `Format`) for any format key in + `variant` that is not defined in the header or the definition mismatches + the field values. + """ + formats = {} + for call in variant.calls: + for format_key, format_value in call.info.iteritems(): + if not defined_headers or format_key not in defined_headers.formats: + if format_key in formats: + raise ValueError( + 'Duplicate FORMAT field "{}" in variant "{}"'.format( + format_key, variant)) + logging.warning('Undefined FORMAT field "%s" in variant "%s"', + format_key, str(variant)) + formats[format_key] = Format(format_key, + _get_field_count(format_value), + _get_field_type(format_value), + '') # NO_DESCRIPTION + # No point in proceeding. 
All other calls have the same FORMAT. + break + for call in variant.calls: + for format_key, format_value in call.info.iteritems(): + if defined_headers and format_key in defined_headers.formats: + defined_header = defined_headers.formats.get(format_key) + corrected_format = _infer_mismatched_format_field( + format_key, format_value, defined_header) + if corrected_format: + logging.warning( + 'Incorrect FORMAT field "%s". Defined as "type=%s,num=%s", ' + 'got "%s" in variant "%s"', + format_key, defined_header.get(_HeaderKeyConstants.TYPE), + str(defined_header.get(_HeaderKeyConstants.NUM)), + str(format_value), str(variant)) + formats[format_key] = corrected_format + return formats diff --git a/gcp_variant_transforms/libs/infer_headers_util_test.py b/gcp_variant_transforms/libs/infer_headers_util_test.py new file mode 100644 index 000000000..dbcada040 --- /dev/null +++ b/gcp_variant_transforms/libs/infer_headers_util_test.py @@ -0,0 +1,355 @@ +# Copyright 2019 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for infer_headers_util module.""" + +from __future__ import absolute_import + +from collections import OrderedDict +import unittest + + +from vcf.parser import _Format as Format +from vcf.parser import _Info as Info +from vcf.parser import field_counts + +from gcp_variant_transforms.beam_io import vcf_header_io +from gcp_variant_transforms.beam_io import vcfio +from gcp_variant_transforms.libs import infer_headers_util + + +class InferHeaderUtilTest(unittest.TestCase): + """Test case for `InferHeaderFields` DoFn.""" + + def _get_sample_variant_1(self): + variant = vcfio.Variant( + reference_name='chr19', start=11, end=12, reference_bases='C', + alternate_bases=['A', 'TT'], names=['rs1', 'rs2'], quality=2, + filters=['PASS'], + info={'IS': 'some data', 'ISI': '1', 'ISF': '1.0', + 'IF': 1.0, 'IB': True, 'IA': [1, 2]}, + calls=[vcfio.VariantCall( + name='Sample1', genotype=[0, 1], phaseset='*', + info={'FI': 20, 'FU': [10.0, 20.0]})] + ) + return variant + + def _get_sample_variant_info_ia_cardinality_mismatch(self): + variant = vcfio.Variant( + reference_name='chr19', start=11, end=12, reference_bases='C', + alternate_bases=['A', 'TT'], names=['rs1', 'rs2'], quality=2, + filters=['PASS'], + info={'IS': 'some data', + 'ISI': '1', + 'ISF': '1.0', + 'IF': 1.0, + 'IB': True, + 'IA': [0.1]}, + calls=[vcfio.VariantCall( + name='Sample1', genotype=[0, 1], phaseset='*', + info={'FI': 20, 'FU': [10.0, 20.0]})] + ) + return variant + + def _get_sample_variant_info_ia_float_in_list(self): + variant = vcfio.Variant( + reference_name='chr19', start=11, end=12, reference_bases='C', + alternate_bases=['A', 'TT'], names=['rs1', 'rs2'], quality=2, + filters=['PASS'], + info={'IS': 'some data', + 'ISI': '1', + 'ISF': '1.0', + 'IF': 1.0, + 'IB': True, + 'IA': [1, 0.2]}, + calls=[vcfio.VariantCall( + name='Sample1', genotype=[0, 1], phaseset='*', + info={'FI': 20, 'FU': [10.0, 20.0]})] + ) + return variant + + def _get_sample_variant_info_ia_float_2_0_in_list(self): + 
variant = vcfio.Variant( + reference_name='chr19', start=11, end=12, reference_bases='C', + alternate_bases=['A', 'TT'], names=['rs1', 'rs2'], quality=2, + filters=['PASS'], + info={'IS': 'some data', + 'ISI': '1', + 'ISF': '1.0', + 'IF': 1.0, + 'IB': True, + 'IA': [1, 2.0]}, + calls=[vcfio.VariantCall( + name='Sample1', genotype=[0, 1], phaseset='*', + info={'FI': 20, 'FU': [10.0, 20.0]})] + ) + return variant + + def _get_sample_variant_format_fi_float_value(self): + variant = vcfio.Variant( + reference_name='chr19', start=11, end=12, reference_bases='C', + alternate_bases=['A', 'TT'], names=['rs1', 'rs2'], quality=2, + filters=['PASS'], + info={'IS': 'some data', + 'ISI': '1', + 'ISF': '1.0', + 'IF': 1.0, + 'IB': True, + 'IA': [0.1, 0.2]}, + calls=[vcfio.VariantCall( + name='Sample1', genotype=[0, 1], phaseset='*', + info={'FI': 20.1, 'FU': [10.0, 20.0]})] + ) + return variant + + def test_infer_mismatched_info_field_no_mismatches(self): + variant = self._get_sample_variant_info_ia_float_2_0_in_list() + infos = {'IS': Info('IS', 1, 'String', '', '', ''), + 'ISI': Info('ISI', 1, 'Integer', '', '', ''), + 'ISF': Info('ISF', 1, 'Float', '', '', ''), + 'IF': Info('IF', 1, 'Float', '', '', ''), + 'IB': Info('IB', 0, 'Flag', '', '', ''), + 'IA': Info('IA', 'A', 'Integer', '', '', '')} + corrected_info = infer_headers_util._infer_mismatched_info_field( + 'IA', variant.info.get('IA'), + vcf_header_io.VcfHeader(infos=infos).infos.get('IA'), + len(variant.alternate_bases)) + self.assertEqual(None, corrected_info) + + def test_infer_mismatched_info_field_correct_num(self): + variant = self._get_sample_variant_info_ia_cardinality_mismatch() + infos = {'IS': Info('IS', 1, 'String', '', '', ''), + 'ISI': Info('ISI', 1, 'Integer', '', '', ''), + 'ISF': Info('ISF', 1, 'Float', '', '', ''), + 'IF': Info('IF', 1, 'Float', '', '', ''), + 'IB': Info('IB', 0, 'Flag', '', '', ''), + 'IA': Info('IA', -1, 'Float', '', '', '')} + corrected_info = 
infer_headers_util._infer_mismatched_info_field( + 'IA', variant.info.get('IA'), + vcf_header_io.VcfHeader(infos=infos).infos.get('IA'), + len(variant.alternate_bases)) + expected = Info('IA', None, 'Float', '', '', '') + self.assertEqual(expected, corrected_info) + + def test_infer_mismatched_info_field_correct_type(self): + variant = self._get_sample_variant_info_ia_cardinality_mismatch() + infos = {'IS': Info('IS', 1, 'String', '', '', ''), + 'ISI': Info('ISI', 1, 'Integer', '', '', ''), + 'ISF': Info('ISF', 1, 'Float', '', '', ''), + 'IF': Info('IF', 1, 'Float', '', '', ''), + 'IB': Info('IB', 0, 'Flag', '', '', ''), + 'IA': Info('IA', None, 'Integer', '', '', '')} + corrected_info = infer_headers_util._infer_mismatched_info_field( + 'IA', variant.info.get('IA'), + vcf_header_io.VcfHeader(infos=infos).infos.get('IA'), + len(variant.alternate_bases) + ) + expected = Info('IA', None, 'Float', '', '', '') + self.assertEqual(expected, corrected_info) + + def test_infer_mismatched_info_field_correct_type_list(self): + variant = self._get_sample_variant_info_ia_float_in_list() + infos = {'IS': Info('IS', 1, 'String', '', '', ''), + 'ISI': Info('ISI', 1, 'Integer', '', '', ''), + 'ISF': Info('ISF', 1, 'Float', '', '', ''), + 'IF': Info('IF', 1, 'Float', '', '', ''), + 'IB': Info('IB', 0, 'Flag', '', '', ''), + 'IA': Info('IA', None, 'Integer', '', '', '')} + corrected_info = infer_headers_util._infer_mismatched_info_field( + 'IA', variant.info.get('IA'), + vcf_header_io.VcfHeader(infos=infos).infos.get('IA'), + len(variant.alternate_bases) + ) + expected = Info('IA', None, 'Float', '', '', '') + self.assertEqual(expected, corrected_info) + + def test_infer_info_fields_no_conflicts(self): + variant = self._get_sample_variant_1() + infos = {'IS': Info('IS', 1, 'String', '', '', ''), + 'ISI': Info('ISI', 1, 'Integer', '', '', ''), + 'ISF': Info('ISF', 1, 'Float', '', '', ''), + 'IF': Info('IF', 1, 'Float', '', '', ''), + 'IB': Info('IB', 0, 'Flag', '', '', ''), + 'IA': 
Info('IA', -1, 'Float', '', '', '')} + inferred_infos = infer_headers_util.infer_info_fields( + variant, vcf_header_io.VcfHeader(infos=infos), infer_headers=True) + self.assertEqual({}, inferred_infos) + + def test_infer_info_fields_combined_conflicts(self): + variant = self._get_sample_variant_info_ia_cardinality_mismatch() + infos = {'IS': Info('IS', 1, 'String', '', '', ''), + 'ISI': Info('ISI', 1, 'Integer', '', '', ''), + 'ISF': Info('ISF', 1, 'Float', '', '', ''), + 'IB': Info('IB', 0, 'Flag', '', '', ''), + 'IA': Info('IA', -1, 'Integer', '', '', '')} + inferred_infos = infer_headers_util.infer_info_fields( + variant, vcf_header_io.VcfHeader(infos=infos), infer_headers=True) + expected_infos = {'IF': Info('IF', 1, 'Float', '', '', ''), + 'IA': Info('IA', None, 'Float', '', '', '')} + self.assertEqual(expected_infos, inferred_infos) + + def test_infer_mismatched_format_field(self): + variant = self._get_sample_variant_format_fi_float_value() + formats = OrderedDict([ + ('FS', Format('FS', 1, 'String', 'desc')), + ('FI', Format('FI', 2, 'Integer', 'desc')), + ('FU', Format('FU', field_counts['.'], 'Float', 'desc')), + ('GT', Format('GT', 2, 'Integer', 'Special GT key')), + ('PS', Format('PS', 1, 'Integer', 'Special PS key'))]) + corrected_format = infer_headers_util._infer_mismatched_format_field( + 'FI', variant.calls[0].info.get('FI'), + vcf_header_io.VcfHeader(formats=formats).formats.get('FI')) + expected_formats = Format('FI', 2, 'Float', 'desc') + self.assertEqual(expected_formats, corrected_format) + + def test_infer_format_fields_no_conflicts(self): + variant = self._get_sample_variant_1() + formats = OrderedDict([ + ('FS', Format('FS', 1, 'String', 'desc')), + ('FI', Format('FI', 2, 'Integer', 'desc')), + ('FU', Format('FU', field_counts['.'], 'Float', 'desc')), + ('GT', Format('GT', 2, 'Integer', 'Special GT key')), + ('PS', Format('PS', 1, 'Integer', 'Special PS key'))]) + header = infer_headers_util.infer_format_fields( + variant, 
vcf_header_io.VcfHeader(formats=formats)) + self.assertEqual({}, header) + + def test_infer_format_fields_combined_conflicts(self): + variant = self._get_sample_variant_format_fi_float_value() + formats = OrderedDict([ + ('FS', Format('FS', 1, 'String', 'desc')), + ('FI', Format('FI', 2, 'Integer', 'desc')), + ('GT', Format('GT', 2, 'Integer', 'Special GT key')), + ('PS', Format('PS', 1, 'Integer', 'Special PS key'))]) + inferred_formats = infer_headers_util.infer_format_fields( + variant, vcf_header_io.VcfHeader(formats=formats)) + expected_formats = {'FI': Format('FI', 2, 'Float', 'desc'), + 'FU': Format('FU', field_counts['.'], 'Float', '')} + self.assertEqual(expected_formats, inferred_formats) + + def _get_annotation_infos(self): + return OrderedDict([ + ('CSQ', Info( + 'CSQ', + field_counts['.'], + 'String', + 'Annotations from VEP. Format: Allele|Gene|Position|Score', + 'src', + 'v')), + ('IS', Info('I1', 1, 'String', 'desc', 'src', 'v')), + ('ISI', Info('ISI', 1, 'Int', 'desc', 'src', 'v')), + ('ISF', Info('ISF', 1, 'Float', 'desc', 'src', 'v')), + ('IF', Info('IF', 1, 'Float', 'desc', 'src', 'v')), + ('IB', Info('I1', 1, 'Flag', 'desc', 'src', 'v')), + ('IA', Info('IA', field_counts['A'], 'Integer', 'desc', 'src', 'v'))]) + + def _get_inferred_info(self, field, annotation, info_type): + return Info( + id='{0}_{1}_TYPE'.format(field, annotation), + num=1, + type=info_type, + desc='Inferred type field for annotation {0}.'.format(annotation), + source='', + version='') + + def test_infer_annotation_empty_info(self): + anno_fields = ['CSQ'] + infos = self._get_annotation_infos() + variant = self._get_sample_variant_1() + variant.info['CSQ'] = [] + inferred_infos = infer_headers_util.infer_info_fields( + variant, vcf_header_io.VcfHeader(infos=infos), False, anno_fields) + self.assertEqual({}, inferred_infos) + + def test_infer_annotation_types_no_conflicts(self): + anno_fields = ['CSQ'] + infos = self._get_annotation_infos() + variant = 
self._get_sample_variant_1() + variant.info['CSQ'] = ['A|GENE1|100|1.2', 'TT|GENE1|101|1.3'] + inferred_infos = infer_headers_util.infer_info_fields( + variant, vcf_header_io.VcfHeader(infos=infos), True, anno_fields) + expected_infos = { + 'CSQ_Gene_TYPE': self._get_inferred_info('CSQ', 'Gene', 'String'), + 'CSQ_Position_TYPE': + self._get_inferred_info('CSQ', 'Position', 'Integer'), + 'CSQ_Score_TYPE': self._get_inferred_info('CSQ', 'Score', 'Float') + } + self.assertDictEqual(expected_infos, inferred_infos) + + def test_infer_annotation_types_with_type_conflicts(self): + anno_fields = ['CSQ'] + infos = self._get_annotation_infos() + variant = self._get_sample_variant_1() + variant.info['CSQ'] = ['A|1|100|1.2', + 'A|2|101|1.3', + 'A|1.2|start|0', + 'TT|1.3|end|7'] + inferred_infos = infer_headers_util.infer_info_fields( + variant, vcf_header_io.VcfHeader(infos=infos), False, anno_fields) + + expected_infos = { + 'CSQ_Gene_TYPE': self._get_inferred_info('CSQ', 'Gene', 'Float'), + 'CSQ_Position_TYPE': + self._get_inferred_info('CSQ', 'Position', 'String'), + 'CSQ_Score_TYPE': self._get_inferred_info('CSQ', 'Score', 'Float') + } + self.assertDictEqual(expected_infos, inferred_infos) + + def test_infer_annotation_types_with_missing(self): + anno_fields = ['CSQ'] + infos = self._get_annotation_infos() + variant = self._get_sample_variant_1() + variant.info['CSQ'] = ['A||100|', + 'A||101|1.3', + 'A|||1.4', + 'TT|||'] + inferred_infos = infer_headers_util.infer_info_fields( + variant, vcf_header_io.VcfHeader(infos=infos), False, anno_fields) + expected_infos = { + 'CSQ_Gene_TYPE': self._get_inferred_info('CSQ', 'Gene', None), + 'CSQ_Position_TYPE': + self._get_inferred_info('CSQ', 'Position', 'Integer'), + 'CSQ_Score_TYPE': self._get_inferred_info('CSQ', 'Score', 'Float') + } + self.assertDictEqual(expected_infos, inferred_infos) + + def test_infer_annotation_types_with_multiple_annotation_fields(self): + anno_fields = ['CSQ', 'CSQ_VT'] + infos = 
self._get_annotation_infos() + infos['CSQ_VT'] = Info( + 'CSQ_VT', + -1, + 'String', + 'Annotations from VEP. Format: Allele|Gene|Position|Score', + 'source', + 'v') + variant = self._get_sample_variant_1() + variant.info['CSQ_VT'] = ['A|1|100|1.2', + 'A|2|101|1.3'] + variant.info['CSQ'] = ['A|1|100|1.2', + 'A|2|101|1.3'] + inferred_infos = infer_headers_util.infer_info_fields( + variant, vcf_header_io.VcfHeader(infos=infos), False, anno_fields) + expected_infos = { + 'CSQ_Gene_TYPE': self._get_inferred_info('CSQ', 'Gene', 'Integer'), + 'CSQ_Position_TYPE': + self._get_inferred_info('CSQ', 'Position', 'Integer'), + 'CSQ_Score_TYPE': self._get_inferred_info('CSQ', 'Score', 'Float'), + 'CSQ_VT_Gene_TYPE': + self._get_inferred_info('CSQ_VT', 'Gene', 'Integer'), + 'CSQ_VT_Position_TYPE': + self._get_inferred_info('CSQ_VT', 'Position', 'Integer'), + 'CSQ_VT_Score_TYPE': self._get_inferred_info('CSQ_VT', 'Score', 'Float') + } + self.assertDictEqual(expected_infos, inferred_infos) diff --git a/gcp_variant_transforms/libs/processed_variant.py b/gcp_variant_transforms/libs/processed_variant.py index e9412a493..035eb7527 100644 --- a/gcp_variant_transforms/libs/processed_variant.py +++ b/gcp_variant_transforms/libs/processed_variant.py @@ -38,7 +38,6 @@ from gcp_variant_transforms.libs import bigquery_sanitizer from gcp_variant_transforms.libs.annotation import annotation_parser from gcp_variant_transforms.libs.annotation.vep import descriptions -from gcp_variant_transforms.transforms import infer_headers _FIELD_COUNT_ALTERNATE_ALLELE = 'A' @@ -354,7 +353,7 @@ def _gen_annotation_name_key_pairs(self, annot_field): annotation_names = annotation_parser.extract_annotation_names( self._header_fields.infos[annot_field][_HeaderKeyConstants.DESC]) for name in annotation_names: - type_key = infer_headers.get_inferred_annotation_type_header_key( + type_key = annotation_parser.get_inferred_annotation_type_header_key( annot_field, name) yield name, type_key @@ -476,7 +475,7 @@ def 
add_annotation_data(self, proc_var, annotation_field_name, data): for name, value in annotation_map.iteritems(): if name == annotation_parser.ANNOTATION_ALT: continue - type_key = infer_headers.get_inferred_annotation_type_header_key( + type_key = annotation_parser.get_inferred_annotation_type_header_key( annotation_field_name, name) vcf_type = self._vcf_type_from_annotation_header( annotation_field_name, type_key) diff --git a/gcp_variant_transforms/transforms/infer_headers.py b/gcp_variant_transforms/transforms/infer_headers.py index 06df0f916..c171d4cf7 100644 --- a/gcp_variant_transforms/transforms/infer_headers.py +++ b/gcp_variant_transforms/transforms/infer_headers.py @@ -1,4 +1,4 @@ -# Copyright 2018 Google Inc. All Rights Reserved. +# Copyright 2019 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,45 +16,23 @@ from __future__ import absolute_import -import logging -from typing import Any, Dict, Iterable, List, Optional, Union # pylint: disable=unused-import +from typing import Iterable, List, Optional # pylint: disable=unused-import import apache_beam as beam -from vcf.parser import _Format as Format -from vcf.parser import _Info as Info -from vcf.parser import field_counts from gcp_variant_transforms.beam_io import vcf_header_io from gcp_variant_transforms.beam_io import vcfio # pylint: disable=unused-import from gcp_variant_transforms.transforms import merge_headers -from gcp_variant_transforms.libs.annotation import annotation_parser -from gcp_variant_transforms.libs import vcf_field_conflict_resolver +from gcp_variant_transforms.libs import infer_headers_util _FIELD_COUNT_ALTERNATE_ALLELE = 'A' -# Filled with annotation field and name data, then used as a header ID. -_BASE_ANNOTATION_TYPE_KEY = '{}_{}_TYPE' - # Alias for the header key/type constants to make referencing easier. 
_HeaderKeyConstants = vcf_header_io.VcfParserHeaderKeyConstants _HeaderTypeConstants = vcf_header_io.VcfHeaderFieldTypeConstants -def get_inferred_annotation_type_header_key(annot_field, name): - # type: (str, str) -> str - """Creates ID values for annotation type info headers. - - Args: - annot_field: field name representing annotation field (e.g. 'CSQ'). - name: annotation data field names (e.g. 'IMPACT'). - - Returns: - Info ID value (e.g. CSQ_IMPACT_TYPE). - """ - return _BASE_ANNOTATION_TYPE_KEY.format(annot_field, name) - - class _InferHeaderFields(beam.DoFn): """Infers header fields from `Variant` records. @@ -85,317 +63,6 @@ def __init__( self._annotation_fields_to_infer = annotation_fields_to_infer self._infer_headers = infer_headers - def _get_field_count(self, field_value): - # type: (Union[List, bool, int, str]) -> Optional[int] - """ - Args: - field_value: value for the field returned by PyVCF. E.g. [0.33, 0.66] is a - field value for Allele frequency (AF) field. - """ - if isinstance(field_value, list): - return field_counts['.'] - elif isinstance(field_value, bool): - return 0 - else: - return 1 - - def _get_field_type(self, field_value): - """ - Args: - field_value (list, bool, integer, or string): value for the field - returned by PyVCF. E.g. [0.33, 0.66] is a field value for Allele - frequency (AF) field. 
- """ - if isinstance(field_value, list): - return (self._get_field_type(field_value[0]) if field_value else - vcf_header_io.VcfHeaderFieldTypeConstants.STRING) - - if isinstance(field_value, bool): - return vcf_header_io.VcfHeaderFieldTypeConstants.FLAG - elif isinstance(field_value, int): - return vcf_header_io.VcfHeaderFieldTypeConstants.INTEGER - elif isinstance(field_value, float): - return vcf_header_io.VcfHeaderFieldTypeConstants.FLOAT - elif self._can_cast_to(field_value, int): - return vcf_header_io.VcfHeaderFieldTypeConstants.INTEGER - elif self._can_cast_to(field_value, float): - return vcf_header_io.VcfHeaderFieldTypeConstants.FLOAT - else: - return vcf_header_io.VcfHeaderFieldTypeConstants.STRING - - def _can_cast_to(self, value, cast_type): - """Returns true if `value` can be casted to type `type`""" - try: - _ = cast_type(value) - return True - except ValueError: - return False - - def _get_corrected_type(self, defined_type, value): - # type: (str, Any) -> str - """Returns the corrected type according to `defined_type` and `value`. - - It handles one special case PyVCF cannot handle, i.e., the defined type is - `Integer`, but the provided value is float. In this case, correct the type - to be `Float`. - - Note that if `value` is a float instance but with an integer value - (e.g. 2.0), the type will stay the same as `defined_type`. - """ - if defined_type == _HeaderTypeConstants.INTEGER: - if isinstance(value, float) and not value.is_integer(): - return _HeaderTypeConstants.FLOAT - if isinstance(value, list): - for item in value: - corrected_type = self._get_corrected_type(defined_type, item) - if corrected_type != defined_type: - return corrected_type - return defined_type - - def _infer_mismatched_info_field(self, - field_key, # type: str - field_value, # type: Any - defined_header, # type: Dict - num_alternate_bases # type: int - ): - # type: (...) -> Optional[Info] - """Returns corrected info if there are mismatches. 
- - Two mismatches are handled: - - Defined num is `A`, but the provided values do not have the same - cardinality as the alternate bases. Correct the num to be `None`. - - Defined type is `Integer`, but the provided value is float. Correct the - type to be `Float`. - Args: - field_key: the info field key. - field_value: the value of the field key given in the variant. - defined_header: The definition of `field_key` in the header. - num_alternate_bases: number of the alternate bases. - Returns: - Corrected info definition if there are mismatches. - """ - corrected_num = defined_header.get(_HeaderKeyConstants.NUM) - if (corrected_num == field_counts[_FIELD_COUNT_ALTERNATE_ALLELE] and - len(field_value) != num_alternate_bases): - corrected_num = field_counts['.'] - - corrected_type = self._get_corrected_type( - defined_header.get(_HeaderKeyConstants.TYPE), field_value) - - if (corrected_type != defined_header.get(_HeaderKeyConstants.TYPE) or - corrected_num != defined_header.get(_HeaderKeyConstants.NUM)): - return Info(field_key, - corrected_num, - corrected_type, - defined_header.get(_HeaderKeyConstants.DESC), - defined_header.get(_HeaderKeyConstants.SOURCE), - defined_header.get(_HeaderKeyConstants.VERSION)) - return None - - def _infer_mismatched_format_field(self, - field_key, # type: str - field_value, # type: Any - defined_header # type: Dict - ): - # type: (...) -> Optional[Format] - """Returns corrected format if there are mismatches. - - One type of mismatches is handled: - - Defined type is `Integer`, but the provided value is float. Correct the - type to be `Float`. - Args: - field_key: the format field key. - field_value: the value of the field key given in the variant. - defined_header: The definition of `field_key` in the header. - Returns: - Corrected format definition if there are mismatches. 
- """ - corrected_type = self._get_corrected_type( - defined_header.get(_HeaderKeyConstants.TYPE), field_value) - if corrected_type != defined_header.get(_HeaderKeyConstants.TYPE): - return Format(field_key, - defined_header.get(_HeaderKeyConstants.NUM), - corrected_type, - defined_header.get(_HeaderKeyConstants.DESC)) - return None - - def _infer_standard_info_fields(self, variant, infos, defined_headers): - # type: (vcfio.Variant, Dict[str, Info], vcf_header_io.VcfHeader) -> None - """Updates `infos` with inferred info fields. - - Two types of info fields are inferred: - - The info fields are undefined in the headers. - - The info fields' definitions provided by the header does not match the - field value. - Args: - variant: variant object - infos: dict of (info_key, `Info`) for any info field in - `variant` that is not defined in the header or the definition mismatches - the field values. - defined_headers: header fields defined in header section of VCF files. - """ - for info_field_key, info_field_value in variant.info.iteritems(): - if not defined_headers or info_field_key not in defined_headers.infos: - if info_field_key in infos: - raise ValueError( - 'Duplicate INFO field "{}" in variant "{}"'.format( - info_field_key, variant)) - logging.warning('Undefined INFO field "%s" in variant "%s"', - info_field_key, str(variant)) - infos[info_field_key] = Info(info_field_key, - self._get_field_count(info_field_value), - self._get_field_type(info_field_value), - '', # NO_DESCRIPTION - '', # UNKNOWN_SOURCE - '') # UNKNOWN_VERSION - else: - defined_header = defined_headers.infos.get(info_field_key) - corrected_info = self._infer_mismatched_info_field( - info_field_key, info_field_value, - defined_header, len(variant.alternate_bases)) - if corrected_info: - logging.warning( - 'Incorrect INFO field "%s". 
Defined as "type=%s,num=%s", ' - 'got "%s", in variant "%s"', - info_field_key, defined_header.get(_HeaderKeyConstants.TYPE), - str(defined_header.get(_HeaderKeyConstants.NUM)), - str(info_field_value), str(variant)) - infos[info_field_key] = corrected_info - - def _infer_annotation_type_info_fields(self, variant, infos, defined_headers): - # type: (vcfio.Variant, Dict[str, Info], vcf_header_io.VcfHeader) -> None - """Updates `infos` with inferred annotation type info fields. - - All annotation headers in each annotation field are converted to Info header - lines where the new ID corresponds to the given annotation field and header, - and the new TYPE corresponds to inferred type of the original header. Since - each variant potentially contains multiple values for each annotation - header, a small 'merge' of value types is performed before VcfHeader - creation for each variant. - Args: - variant: variant object - infos: dict of (info_key, `Info`) for any info field in - `variant` that is not defined in the header or the definition mismatches - the field values. - defined_headers: header fields defined in header section of VCF files. 
- """ - - def _check_annotation_lists_lengths(names, values): - lengths = set(len(v) for v in values) - lengths.add(len(names)) - if len(lengths) != 1: - error = ('Annotation lists have inconsistent lengths: {}.\nnames={}\n' - 'values={}').format(lengths, names, values) - raise ValueError(error) - - resolver = vcf_field_conflict_resolver.FieldConflictResolver( - resolve_always=True) - for field in self._annotation_fields_to_infer: - if field not in variant.info: - continue - annotation_names = annotation_parser.extract_annotation_names( - defined_headers.infos[field][_HeaderKeyConstants.DESC]) - # First element (ALT) is ignored, since its type is hard-coded as string - annotation_values = [annotation_parser.extract_annotation_list_with_alt( - annotation)[1:] for annotation in variant.info[field]] - _check_annotation_lists_lengths(annotation_names, annotation_values) - annotation_values = zip(*annotation_values) - for name, values in zip(annotation_names, annotation_values): - variant_merged_type = None - for v in values: - if not v: - continue - variant_merged_type = resolver.resolve_attribute_conflict( - _HeaderKeyConstants.TYPE, - variant_merged_type, - self._get_field_type(v)) - if variant_merged_type == _HeaderTypeConstants.STRING: - break - key_id = get_inferred_annotation_type_header_key(field, name) - infos[key_id] = Info(key_id, - 1, # field count - variant_merged_type, - ('Inferred type field for annotation {}.'.format( - name)), - '', # UNKNOWN_SOURCE - '') # UNKNOWN_VERSION - - def _infer_info_fields(self, variant, defined_headers): - """Returns inferred info fields. - - Up to three types of info fields are inferred: - - if `infer_headers` is True: - - The info fields are undefined in the headers. - - The info fields' definitions provided by the header does not match the - field value. - if `infer_annotation_types` is True: - - Fields containing type information of corresponding annotation Info - fields. 
- - Args: - variant: variant object - defined_headers: header fields defined in header section of VCF files. - Returns: - infos: dict of (info_key, `Info`) for any info field in - `variant` that is not defined in the header or the definition mismatches - the field values. - """ - infos = {} - if self._infer_headers: - self._infer_standard_info_fields(variant, infos, defined_headers) - if self._annotation_fields_to_infer: - self._infer_annotation_type_info_fields(variant, infos, defined_headers) - return infos - - def _infer_format_fields(self, variant, defined_headers): - # type: (vcfio.Variant, vcf_header_io.VcfHeader) -> Dict[str, Format] - """Returns inferred format fields. - - Two types of format fields are inferred: - - The format fields are undefined in the headers. - - The format definition provided by the headers does not match the field - values. - Args: - variant: variant object - defined_headers: header fields defined in header section of VCF files. - Returns: - A dict of (format_key, `Format`) for any format key in - `variant` that is not defined in the header or the definition mismatches - the field values. - """ - formats = {} - for call in variant.calls: - for format_key, format_value in call.info.iteritems(): - if not defined_headers or format_key not in defined_headers.formats: - if format_key in formats: - raise ValueError( - 'Duplicate FORMAT field "{}" in variant "{}"'.format( - format_key, variant)) - logging.warning('Undefined FORMAT field "%s" in variant "%s"', - format_key, str(variant)) - formats[format_key] = Format(format_key, - self._get_field_count(format_value), - self._get_field_type(format_value), - '') # NO_DESCRIPTION - # No point in proceeding. All other calls have the same FORMAT. 
- break - for call in variant.calls: - for format_key, format_value in call.info.iteritems(): - if defined_headers and format_key in defined_headers.formats: - defined_header = defined_headers.formats.get(format_key) - corrected_format = self._infer_mismatched_format_field( - format_key, format_value, defined_header) - if corrected_format: - logging.warning( - 'Incorrect FORMAT field "%s". Defined as "type=%s,num=%s", ' - 'got "%s" in variant "%s"', - format_key, defined_header.get(_HeaderKeyConstants.TYPE), - str(defined_header.get(_HeaderKeyConstants.NUM)), - str(format_value), str(variant)) - formats[format_key] = corrected_format - return formats - def process(self, variant, # type: vcfio.Variant defined_headers # type: vcf_header_io.VcfHeader @@ -405,10 +72,14 @@ def process(self, Args: defined_headers: header fields defined in header section of VCF files. """ - infos = self._infer_info_fields(variant, defined_headers) + infos = infer_headers_util.infer_info_fields( + variant, + defined_headers, + self._infer_headers, + self._annotation_fields_to_infer) formats = {} if self._infer_headers: - formats = self._infer_format_fields(variant, defined_headers) + formats = infer_headers_util.infer_format_fields(variant, defined_headers) yield vcf_header_io.VcfHeader(infos=infos, formats=formats) diff --git a/gcp_variant_transforms/transforms/infer_headers_test.py b/gcp_variant_transforms/transforms/infer_headers_test.py index 217da4fc1..72c26ec81 100644 --- a/gcp_variant_transforms/transforms/infer_headers_test.py +++ b/gcp_variant_transforms/transforms/infer_headers_test.py @@ -1,4 +1,4 @@ -# Copyright 2018 Google Inc. All Rights Reserved. +# Copyright 2019 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -34,6 +34,7 @@ from gcp_variant_transforms.transforms import infer_headers + class InferHeaderFieldsTest(unittest.TestCase): """Test case for `InferHeaderFields` DoFn.""" @@ -255,139 +256,6 @@ def test_defined_fields_filtered_two_variants(self): asserts.header_fields_equal_ignore_order([expected])) p.run() - def test_infer_mismatched_info_field_no_mismatches(self): - variant = self._get_sample_variant_info_ia_float_2_0_in_list() - infos = {'IS': Info('IS', 1, 'String', '', '', ''), - 'ISI': Info('ISI', 1, 'Integer', '', '', ''), - 'ISF': Info('ISF', 1, 'Float', '', '', ''), - 'IF': Info('IF', 1, 'Float', '', '', ''), - 'IB': Info('IB', 0, 'Flag', '', '', ''), - 'IA': Info('IA', 'A', 'Integer', '', '', '')} - infer_header_fields = infer_headers._InferHeaderFields(infer_headers=True) - corrected_info = infer_header_fields._infer_mismatched_info_field( - 'IA', variant.info.get('IA'), - vcf_header_io.VcfHeader(infos=infos).infos.get('IA'), - len(variant.alternate_bases)) - self.assertEqual(None, corrected_info) - - def test_infer_mismatched_info_field_correct_num(self): - variant = self._get_sample_variant_info_ia_cardinality_mismatch() - infos = {'IS': Info('IS', 1, 'String', '', '', ''), - 'ISI': Info('ISI', 1, 'Integer', '', '', ''), - 'ISF': Info('ISF', 1, 'Float', '', '', ''), - 'IF': Info('IF', 1, 'Float', '', '', ''), - 'IB': Info('IB', 0, 'Flag', '', '', ''), - 'IA': Info('IA', -1, 'Float', '', '', '')} - infer_header_fields = infer_headers._InferHeaderFields(infer_headers=True) - corrected_info = infer_header_fields._infer_mismatched_info_field( - 'IA', variant.info.get('IA'), - vcf_header_io.VcfHeader(infos=infos).infos.get('IA'), - len(variant.alternate_bases)) - expected = Info('IA', None, 'Float', '', '', '') - self.assertEqual(expected, corrected_info) - - def test_infer_mismatched_info_field_correct_type(self): - variant = self._get_sample_variant_info_ia_cardinality_mismatch() - infos = {'IS': Info('IS', 1, 'String', '', '', ''), - 'ISI': Info('ISI', 
1, 'Integer', '', '', ''), - 'ISF': Info('ISF', 1, 'Float', '', '', ''), - 'IF': Info('IF', 1, 'Float', '', '', ''), - 'IB': Info('IB', 0, 'Flag', '', '', ''), - 'IA': Info('IA', None, 'Integer', '', '', '')} - infer_header_fields = infer_headers._InferHeaderFields(infer_headers=True) - corrected_info = infer_header_fields._infer_mismatched_info_field( - 'IA', variant.info.get('IA'), - vcf_header_io.VcfHeader(infos=infos).infos.get('IA'), - len(variant.alternate_bases) - ) - expected = Info('IA', None, 'Float', '', '', '') - self.assertEqual(expected, corrected_info) - - def test_infer_mismatched_info_field_correct_type_list(self): - variant = self._get_sample_variant_info_ia_float_in_list() - infos = {'IS': Info('IS', 1, 'String', '', '', ''), - 'ISI': Info('ISI', 1, 'Integer', '', '', ''), - 'ISF': Info('ISF', 1, 'Float', '', '', ''), - 'IF': Info('IF', 1, 'Float', '', '', ''), - 'IB': Info('IB', 0, 'Flag', '', '', ''), - 'IA': Info('IA', None, 'Integer', '', '', '')} - infer_header_fields = infer_headers._InferHeaderFields(infer_headers=True) - corrected_info = infer_header_fields._infer_mismatched_info_field( - 'IA', variant.info.get('IA'), - vcf_header_io.VcfHeader(infos=infos).infos.get('IA'), - len(variant.alternate_bases) - ) - expected = Info('IA', None, 'Float', '', '', '') - self.assertEqual(expected, corrected_info) - - def test_infer_info_fields_no_conflicts(self): - variant = self._get_sample_variant_1() - infos = {'IS': Info('IS', 1, 'String', '', '', ''), - 'ISI': Info('ISI', 1, 'Integer', '', '', ''), - 'ISF': Info('ISF', 1, 'Float', '', '', ''), - 'IF': Info('IF', 1, 'Float', '', '', ''), - 'IB': Info('IB', 0, 'Flag', '', '', ''), - 'IA': Info('IA', -1, 'Float', '', '', '')} - infer_header_fields = infer_headers._InferHeaderFields(infer_headers=True) - inferred_infos = infer_header_fields._infer_info_fields( - variant, vcf_header_io.VcfHeader(infos=infos)) - self.assertEqual({}, inferred_infos) - - def 
test_infer_info_fields_combined_conflicts(self): - variant = self._get_sample_variant_info_ia_cardinality_mismatch() - infos = {'IS': Info('IS', 1, 'String', '', '', ''), - 'ISI': Info('ISI', 1, 'Integer', '', '', ''), - 'ISF': Info('ISF', 1, 'Float', '', '', ''), - 'IB': Info('IB', 0, 'Flag', '', '', ''), - 'IA': Info('IA', -1, 'Integer', '', '', '')} - infer_header_fields = infer_headers._InferHeaderFields(infer_headers=True) - inferred_infos = infer_header_fields._infer_info_fields( - variant, vcf_header_io.VcfHeader(infos=infos)) - expected_infos = {'IF': Info('IF', 1, 'Float', '', '', ''), - 'IA': Info('IA', None, 'Float', '', '', '')} - self.assertEqual(expected_infos, inferred_infos) - - def test_infer_mismatched_format_field(self): - variant = self._get_sample_variant_format_fi_float_value() - formats = OrderedDict([ - ('FS', Format('FS', 1, 'String', 'desc')), - ('FI', Format('FI', 2, 'Integer', 'desc')), - ('FU', Format('FU', field_counts['.'], 'Float', 'desc')), - ('GT', Format('GT', 2, 'Integer', 'Special GT key')), - ('PS', Format('PS', 1, 'Integer', 'Special PS key'))]) - infer_header_fields = infer_headers._InferHeaderFields(infer_headers=True) - corrected_format = infer_header_fields._infer_mismatched_format_field( - 'FI', variant.calls[0].info.get('FI'), - vcf_header_io.VcfHeader(formats=formats).formats.get('FI')) - expected_formats = Format('FI', 2, 'Float', 'desc') - self.assertEqual(expected_formats, corrected_format) - - def test_infer_format_fields_no_conflicts(self): - variant = self._get_sample_variant_1() - formats = OrderedDict([ - ('FS', Format('FS', 1, 'String', 'desc')), - ('FI', Format('FI', 2, 'Integer', 'desc')), - ('FU', Format('FU', field_counts['.'], 'Float', 'desc')), - ('GT', Format('GT', 2, 'Integer', 'Special GT key')), - ('PS', Format('PS', 1, 'Integer', 'Special PS key'))]) - infer_header_fields = infer_headers._InferHeaderFields(infer_headers=True) - header = infer_header_fields._infer_format_fields( - variant, 
vcf_header_io.VcfHeader(formats=formats)) - self.assertEqual({}, header) - - def test_infer_format_fields_combined_conflicts(self): - variant = self._get_sample_variant_format_fi_float_value() - formats = OrderedDict([ - ('FS', Format('FS', 1, 'String', 'desc')), - ('FI', Format('FI', 2, 'Integer', 'desc')), - ('GT', Format('GT', 2, 'Integer', 'Special GT key')), - ('PS', Format('PS', 1, 'Integer', 'Special PS key'))]) - infer_header_fields = infer_headers._InferHeaderFields(infer_headers=True) - inferred_formats = infer_header_fields._infer_format_fields( - variant, vcf_header_io.VcfHeader(formats=formats)) - expected_formats = {'FI': Format('FI', 2, 'Float', 'desc'), - 'FU': Format('FU', field_counts['.'], 'Float', '')} - self.assertEqual(expected_formats, inferred_formats) def test_pipeline(self): infos = {'IS': Info('IS', 1, 'String', '', '', ''), diff --git a/gcp_variant_transforms/transforms/variant_to_bigquery.py b/gcp_variant_transforms/transforms/variant_to_bigquery.py index a1ae5e3cc..2e80cbda8 100644 --- a/gcp_variant_transforms/transforms/variant_to_bigquery.py +++ b/gcp_variant_transforms/transforms/variant_to_bigquery.py @@ -16,15 +16,10 @@ from __future__ import absolute_import -import exceptions import random -import re from typing import Dict, List # pylint: disable=unused-import import apache_beam as beam -from apache_beam.io.gcp.internal.clients import bigquery -from apitools.base.py import exceptions -from oauth2client.client import GoogleCredentials from gcp_variant_transforms.beam_io import vcf_header_io # pylint: disable=unused-import from gcp_variant_transforms.libs import bigquery_schema_descriptor @@ -132,7 +127,8 @@ def __init__( self._omit_empty_sample_calls = omit_empty_sample_calls self._num_bigquery_write_shards = num_bigquery_write_shards if update_schema_on_append: - self._update_bigquery_schema_on_append() + bigquery_util.update_bigquery_schema_on_append(self._schema.fields, + self._output_table) def expand(self, pcoll): bq_rows = 
pcoll | 'ConvertToBigQueryTableRow' >> beam.ParDo( @@ -172,83 +168,3 @@ def expand(self, pcoll): beam.io.BigQueryDisposition.WRITE_APPEND if self._append else beam.io.BigQueryDisposition.WRITE_TRUNCATE)))) - - def _update_bigquery_schema_on_append(self): - # type: (bool) -> None - # if table does not exist, do not need to update the schema. - # TODO (yifangchen): Move the logic into validate(). - output_table_re_match = re.match( - r'^((?P.+):)(?P\w+)\.(?P
[\w\$]+)$', - self._output_table) - credentials = GoogleCredentials.get_application_default().create_scoped( - ['https://www.googleapis.com/auth/bigquery']) - client = bigquery.BigqueryV2(credentials=credentials) - try: - project_id = output_table_re_match.group('project') - dataset_id = output_table_re_match.group('dataset') - table_id = output_table_re_match.group('table') - existing_table = client.tables.Get(bigquery.BigqueryTablesGetRequest( - projectId=project_id, - datasetId=dataset_id, - tableId=table_id)) - except exceptions.HttpError: - return - - new_schema = bigquery.TableSchema() - new_schema.fields = _get_merged_field_schemas(existing_table.schema.fields, - self._schema.fields) - existing_table.schema = new_schema - try: - client.tables.Update(bigquery.BigqueryTablesUpdateRequest( - projectId=project_id, - datasetId=dataset_id, - table=existing_table, - tableId=table_id)) - except exceptions.HttpError as e: - raise RuntimeError('BigQuery schema update failed: %s' % str(e)) - - -def _get_merged_field_schemas( - field_schemas_1, # type: List[bigquery.TableFieldSchema] - field_schemas_2 # type: List[bigquery.TableFieldSchema] - ): - # type: (...) -> List[bigquery.TableFieldSchema] - """Merges the `field_schemas_1` and `field_schemas_2`. - - Args: - field_schemas_1: A list of `TableFieldSchema`. - field_schemas_2: A list of `TableFieldSchema`. - Returns: - A new schema with new fields from `field_schemas_2` appended to - `field_schemas_1`. - Raises: - ValueError: If there are fields with the same name, but different modes or - different types. 
- """ - existing_fields = {} # type: Dict[str, bigquery.TableFieldSchema] - merged_field_schemas = [] # type: List[bigquery.TableFieldSchema] - for field_schema in field_schemas_1: - existing_fields.update({field_schema.name: field_schema}) - merged_field_schemas.append(field_schema) - - for field_schema in field_schemas_2: - if field_schema.name not in existing_fields.keys(): - merged_field_schemas.append(field_schema) - else: - existing_field_schema = existing_fields.get(field_schema.name) - if field_schema.mode != existing_field_schema.mode: - raise ValueError( - 'The mode of field {} is not compatible. The original mode is {}, ' - 'and the new mode is {}.'.format(field_schema.name, - existing_field_schema.mode, - field_schema.mode)) - if field_schema.type != existing_field_schema.type: - raise ValueError( - 'The type of field {} is not compatible. The original type is {}, ' - 'and the new type is {}.'.format(field_schema.name, - existing_field_schema.type, - field_schema.type)) - if field_schema.type == bigquery_util.TableFieldConstants.TYPE_RECORD: - existing_field_schema.fields = _get_merged_field_schemas( - existing_field_schema.fields, field_schema.fields) - return merged_field_schemas diff --git a/gcp_variant_transforms/transforms/variant_to_bigquery_test.py b/gcp_variant_transforms/transforms/variant_to_bigquery_test.py index 17d106918..965154abb 100644 --- a/gcp_variant_transforms/transforms/variant_to_bigquery_test.py +++ b/gcp_variant_transforms/transforms/variant_to_bigquery_test.py @@ -33,7 +33,6 @@ from gcp_variant_transforms.libs.bigquery_util import ColumnKeyConstants from gcp_variant_transforms.libs.bigquery_util import TableFieldConstants from gcp_variant_transforms.testing import vcf_header_util -from gcp_variant_transforms.transforms import variant_to_bigquery from gcp_variant_transforms.transforms.variant_to_bigquery import ConvertVariantToRow @@ -277,325 +276,3 @@ def test_convert_variant_to_bigquery_row_allow_incompatible_recoreds(self): 
self._row_generator, allow_incompatible_records=True))) assert_that(bigquery_rows, equal_to([row])) pipeline.run() - - def test_merge_field_schemas_no_same_id(self): - field_schemas_1 = [ - bigquery.TableFieldSchema( - name='II', - type=TableFieldConstants.TYPE_INTEGER, - mode=TableFieldConstants.MODE_NULLABLE, - description='INFO foo desc'), - bigquery.TableFieldSchema( - name='IFR', - type=TableFieldConstants.TYPE_FLOAT, - mode=TableFieldConstants.MODE_REPEATED, - description='INFO foo desc') - ] - field_schemas_2 = [ - bigquery.TableFieldSchema( - name='AB', - type=TableFieldConstants.TYPE_FLOAT, - mode=TableFieldConstants.MODE_NULLABLE, - description='INFO foo desc') - ] - merged_field_schemas = variant_to_bigquery._get_merged_field_schemas( - field_schemas_1, field_schemas_2) - expected_merged_field_schemas = [ - bigquery.TableFieldSchema( - name='II', - type=TableFieldConstants.TYPE_INTEGER, - mode=TableFieldConstants.MODE_NULLABLE, - description='INFO foo desc'), - bigquery.TableFieldSchema( - name='IFR', - type=TableFieldConstants.TYPE_FLOAT, - mode=TableFieldConstants.MODE_REPEATED, - description='INFO foo desc'), - bigquery.TableFieldSchema( - name='AB', - type=TableFieldConstants.TYPE_FLOAT, - mode=TableFieldConstants.MODE_NULLABLE, - description='INFO foo desc') - ] - self.assertEqual(merged_field_schemas, expected_merged_field_schemas) - - def test_merge_field_schemas_same_id_no_conflicts(self): - field_schemas_1 = [ - bigquery.TableFieldSchema( - name='II', - type=TableFieldConstants.TYPE_INTEGER, - mode=TableFieldConstants.MODE_NULLABLE, - description='INFO foo desc'), - bigquery.TableFieldSchema( - name='IFR', - type=TableFieldConstants.TYPE_FLOAT, - mode=TableFieldConstants.MODE_REPEATED, - description='INFO foo desc') - ] - field_schemas_2 = [ - bigquery.TableFieldSchema( - name='II', - type=TableFieldConstants.TYPE_INTEGER, - mode=TableFieldConstants.MODE_NULLABLE, - description='INFO foo desc'), - bigquery.TableFieldSchema( - name='AB', - 
type=TableFieldConstants.TYPE_FLOAT, - mode=TableFieldConstants.MODE_NULLABLE, - description='INFO foo desc') - ] - merged_field_schemas = variant_to_bigquery._get_merged_field_schemas( - field_schemas_1, field_schemas_2) - expected_merged_field_schemas = [ - bigquery.TableFieldSchema( - name='II', - type=TableFieldConstants.TYPE_INTEGER, - mode=TableFieldConstants.MODE_NULLABLE, - description='INFO foo desc'), - bigquery.TableFieldSchema( - name='IFR', - type=TableFieldConstants.TYPE_FLOAT, - mode=TableFieldConstants.MODE_REPEATED, - description='INFO foo desc'), - bigquery.TableFieldSchema( - name='AB', - type=TableFieldConstants.TYPE_FLOAT, - mode=TableFieldConstants.MODE_NULLABLE, - description='INFO foo desc') - ] - self.assertEqual(merged_field_schemas, expected_merged_field_schemas) - - def test_merge_field_schemas_conflict_mode(self): - field_schemas_1 = [ - bigquery.TableFieldSchema( - name='II', - type=TableFieldConstants.TYPE_INTEGER, - mode=TableFieldConstants.MODE_NULLABLE, - description='INFO foo desc') - ] - field_schemas_2 = [ - bigquery.TableFieldSchema( - name='II', - type=TableFieldConstants.TYPE_INTEGER, - mode=TableFieldConstants.MODE_REPEATED, - description='INFO foo desc') - ] - self.assertRaises(ValueError, variant_to_bigquery._get_merged_field_schemas, - field_schemas_1, field_schemas_2) - - def test_merge_field_schemas_conflict_type(self): - field_schemas_1 = [ - bigquery.TableFieldSchema( - name='II', - type=TableFieldConstants.TYPE_INTEGER, - mode=TableFieldConstants.MODE_NULLABLE, - description='INFO foo desc') - ] - field_schemas_2 = [ - bigquery.TableFieldSchema( - name='II', - type=TableFieldConstants.TYPE_FLOAT, - mode=TableFieldConstants.MODE_NULLABLE, - description='INFO foo desc') - ] - self.assertRaises(ValueError, variant_to_bigquery._get_merged_field_schemas, - field_schemas_1, field_schemas_2) - - def test_merge_field_schemas_conflict_record_fields(self): - call_record_1 = bigquery.TableFieldSchema( - 
name=ColumnKeyConstants.CALLS, - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - call_record_1.fields.append(bigquery.TableFieldSchema( - name='FB', - type=TableFieldConstants.TYPE_BOOLEAN, - mode=TableFieldConstants.MODE_NULLABLE, - description='FORMAT foo desc')) - field_schemas_1 = [call_record_1] - - call_record_2 = bigquery.TableFieldSchema( - name=ColumnKeyConstants.CALLS, - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - call_record_2.fields.append(bigquery.TableFieldSchema( - name='FB', - type=TableFieldConstants.TYPE_INTEGER, - mode=TableFieldConstants.MODE_NULLABLE, - description='FORMAT foo desc')) - field_schemas_2 = [call_record_2] - self.assertRaises(ValueError, variant_to_bigquery._get_merged_field_schemas, - field_schemas_1, field_schemas_2) - - def test_merge_field_schemas_same_record(self): - call_record_1 = bigquery.TableFieldSchema( - name=ColumnKeyConstants.CALLS, - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - call_record_1.fields.append(bigquery.TableFieldSchema( - name='FB', - type=TableFieldConstants.TYPE_BOOLEAN, - mode=TableFieldConstants.MODE_NULLABLE, - description='FORMAT foo desc')) - - field_schemas_1 = [call_record_1] - field_schemas_2 = [call_record_1] - - expected_merged_field_schemas = [call_record_1] - self.assertEqual( - variant_to_bigquery._get_merged_field_schemas(field_schemas_1, - field_schemas_2), - expected_merged_field_schemas) - - def test_merge_field_schemas_merge_record_fields(self): - call_record_1 = bigquery.TableFieldSchema( - name=ColumnKeyConstants.CALLS, - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - call_record_1.fields.append(bigquery.TableFieldSchema( - name='FB', - 
type=TableFieldConstants.TYPE_BOOLEAN, - mode=TableFieldConstants.MODE_NULLABLE, - description='FORMAT foo desc')) - - field_schemas_1 = [call_record_1] - - call_record_2 = bigquery.TableFieldSchema( - name=ColumnKeyConstants.CALLS, - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - call_record_2.fields.append(bigquery.TableFieldSchema( - name='GQ', - type=TableFieldConstants.TYPE_INTEGER, - mode=TableFieldConstants.MODE_NULLABLE, - description='FORMAT foo desc')) - field_schemas_2 = [call_record_2] - - call_record_3 = bigquery.TableFieldSchema( - name=ColumnKeyConstants.CALLS, - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - call_record_3.fields.append(bigquery.TableFieldSchema( - name='FB', - type=TableFieldConstants.TYPE_BOOLEAN, - mode=TableFieldConstants.MODE_NULLABLE, - description='FORMAT foo desc')) - call_record_3.fields.append(bigquery.TableFieldSchema( - name='GQ', - type=TableFieldConstants.TYPE_INTEGER, - mode=TableFieldConstants.MODE_NULLABLE, - description='FORMAT foo desc')) - - expected_merged_field_schemas = [call_record_3] - self.assertEqual( - variant_to_bigquery._get_merged_field_schemas(field_schemas_1, - field_schemas_2), - expected_merged_field_schemas) - - def test_merge_field_schemas_conflict_inner_record_fields(self): - record_1 = bigquery.TableFieldSchema( - name=ColumnKeyConstants.CALLS, - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - inner_record_1 = bigquery.TableFieldSchema( - name='inner record', - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - inner_record_1.fields.append(bigquery.TableFieldSchema( - name='FB', - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='FORMAT foo desc')) - 
record_1.fields.append(inner_record_1) - field_schemas_1 = [record_1] - - record_2 = bigquery.TableFieldSchema( - name=ColumnKeyConstants.CALLS, - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - inner_record_2 = bigquery.TableFieldSchema( - name='inner record', - type=TableFieldConstants.TYPE_INTEGER, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - inner_record_2.fields.append(bigquery.TableFieldSchema( - name='FB', - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='FORMAT foo desc')) - record_2.fields.append(inner_record_2) - field_schemas_2 = [record_2] - self.assertRaises(ValueError, variant_to_bigquery._get_merged_field_schemas, - field_schemas_1, field_schemas_2) - - def test_merge_field_schemas_merge_inner_record_fields(self): - record_1 = bigquery.TableFieldSchema( - name=ColumnKeyConstants.CALLS, - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - inner_record_1 = bigquery.TableFieldSchema( - name='inner record', - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - inner_record_1.fields.append(bigquery.TableFieldSchema( - name='FB', - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='FORMAT foo desc')) - record_1.fields.append(inner_record_1) - field_schemas_1 = [record_1] - - record_2 = bigquery.TableFieldSchema( - name=ColumnKeyConstants.CALLS, - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - inner_record_2 = bigquery.TableFieldSchema( - name='inner record', - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - 
inner_record_2.fields.append(bigquery.TableFieldSchema( - name='AB', - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='FORMAT foo desc')) - record_2.fields.append(inner_record_2) - field_schemas_2 = [record_2] - - merged_record = bigquery.TableFieldSchema( - name=ColumnKeyConstants.CALLS, - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - merged_inner_record = bigquery.TableFieldSchema( - name='inner record', - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='One record for each call.') - merged_inner_record.fields.append(bigquery.TableFieldSchema( - name='FB', - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='FORMAT foo desc')) - merged_inner_record.fields.append(bigquery.TableFieldSchema( - name='AB', - type=TableFieldConstants.TYPE_RECORD, - mode=TableFieldConstants.MODE_REPEATED, - description='FORMAT foo desc')) - merged_record.fields.append(merged_inner_record) - expected_merged_field_schemas = [merged_record] - self.assertEqual( - variant_to_bigquery._get_merged_field_schemas(field_schemas_1, - field_schemas_2), - expected_merged_field_schemas) From aa369be07cd740d3dd9fe4b19433ba667d453f58 Mon Sep 17 00:00:00 2001 From: Tural Neymanov Date: Mon, 28 Jan 2019 13:12:50 -0500 Subject: [PATCH 2/4] Applied the requested changes. 
--- .../libs/annotation/annotation_parser.py | 17 -- gcp_variant_transforms/libs/bigquery_util.py | 15 +- .../libs/bigquery_util_test.py | 25 +-- .../libs/infer_headers_util.py | 210 ++++++++++-------- .../libs/processed_variant.py | 5 +- 5 files changed, 140 insertions(+), 132 deletions(-) diff --git a/gcp_variant_transforms/libs/annotation/annotation_parser.py b/gcp_variant_transforms/libs/annotation/annotation_parser.py index b33eef130..f526539cb 100644 --- a/gcp_variant_transforms/libs/annotation/annotation_parser.py +++ b/gcp_variant_transforms/libs/annotation/annotation_parser.py @@ -43,9 +43,6 @@ _BREAKEND_ALT_RE = (re.compile( r'^(?P.*([\[\]]).*):(?P.*)([\[\]]).*$')) -# Filled with annotation field and name data, then used as a header ID. -_BASE_ANNOTATION_TYPE_KEY = '{}_{}_TYPE' - class AnnotationParserException(Exception): pass @@ -423,17 +420,3 @@ def reconstruct_annotation_description(annotation_names): returns 'Format: Allele|Consequence|IMPACT|SYMBOL|Gene'. """ return ' '.join(['Format:', '|'.join(annotation_names)]) - - -def get_inferred_annotation_type_header_key(annot_field, name): - # type: (str, str) -> str - """Creates ID values for annotation type info headers. - - Args: - annot_field: field name representing annotation field (e.g. 'CSQ'). - name: annotation data field names (e.g. 'IMPACT'). - - Returns: - Info ID value (e.g. CSQ_IMPACT_TYPE). 
- """ - return _BASE_ANNOTATION_TYPE_KEY.format(annot_field, name) diff --git a/gcp_variant_transforms/libs/bigquery_util.py b/gcp_variant_transforms/libs/bigquery_util.py index c7d76db9a..b83bef5d7 100644 --- a/gcp_variant_transforms/libs/bigquery_util.py +++ b/gcp_variant_transforms/libs/bigquery_util.py @@ -14,8 +14,8 @@ """Constants and simple utility functions related to BigQuery.""" -import exceptions import enum +import exceptions import re from typing import List, Tuple, Union # pylint: disable=unused-import @@ -197,9 +197,10 @@ def get_avro_type_from_bigquery_type_mode(bigquery_type, bigquery_mode): return avro_type def update_bigquery_schema_on_append(schema_fields, output_table): - # type: (bool) -> None + # type: (bool, str) -> None # if table does not exist, do not need to update the schema. # TODO (yifangchen): Move the logic into validate(). + """Update BQ schema by combining existing one with a new one, if possible.""" output_table_re_match = re.match( r'^((?P.+):)(?P\w+)\.(?P
[\w\$]+)$', output_table) @@ -218,8 +219,8 @@ def update_bigquery_schema_on_append(schema_fields, output_table): return new_schema = bigquery.TableSchema() - new_schema.fields = get_merged_field_schemas(existing_table.schema.fields, - schema_fields) + new_schema.fields = _get_merged_field_schemas(existing_table.schema.fields, + schema_fields) existing_table.schema = new_schema try: client.tables.Update(bigquery.BigqueryTablesUpdateRequest( @@ -231,7 +232,7 @@ def update_bigquery_schema_on_append(schema_fields, output_table): raise RuntimeError('BigQuery schema update failed: %s' % str(e)) -def get_merged_field_schemas( +def _get_merged_field_schemas( field_schemas_1, # type: List[bigquery.TableFieldSchema] field_schemas_2 # type: List[bigquery.TableFieldSchema] ): @@ -241,9 +242,11 @@ def get_merged_field_schemas( Args: field_schemas_1: A list of `TableFieldSchema`. field_schemas_2: A list of `TableFieldSchema`. + Returns: A new schema with new fields from `field_schemas_2` appended to `field_schemas_1`. + Raises: ValueError: If there are fields with the same name, but different modes or different types. 
@@ -272,6 +275,6 @@ def get_merged_field_schemas( existing_field_schema.type, field_schema.type)) if field_schema.type == TableFieldConstants.TYPE_RECORD: - existing_field_schema.fields = get_merged_field_schemas( + existing_field_schema.fields = _get_merged_field_schemas( existing_field_schema.fields, field_schema.fields) return merged_field_schemas diff --git a/gcp_variant_transforms/libs/bigquery_util_test.py b/gcp_variant_transforms/libs/bigquery_util_test.py index cc235c100..59d37f52b 100644 --- a/gcp_variant_transforms/libs/bigquery_util_test.py +++ b/gcp_variant_transforms/libs/bigquery_util_test.py @@ -92,7 +92,6 @@ def test_get_vcf_num_from_bigquery_schema(self): bigquery_type=bigquery_util.TableFieldConstants.TYPE_BOOLEAN)) - def test_merge_field_schemas_no_same_id(self): field_schemas_1 = [ bigquery.TableFieldSchema( @@ -113,7 +112,7 @@ def test_merge_field_schemas_no_same_id(self): mode=TableFieldConstants.MODE_NULLABLE, description='INFO foo desc') ] - merged_field_schemas = bigquery_util.get_merged_field_schemas( + merged_field_schemas = bigquery_util._get_merged_field_schemas( field_schemas_1, field_schemas_2) expected_merged_field_schemas = [ bigquery.TableFieldSchema( @@ -159,7 +158,7 @@ def test_merge_field_schemas_same_id_no_conflicts(self): mode=TableFieldConstants.MODE_NULLABLE, description='INFO foo desc') ] - merged_field_schemas = bigquery_util.get_merged_field_schemas( + merged_field_schemas = bigquery_util._get_merged_field_schemas( field_schemas_1, field_schemas_2) expected_merged_field_schemas = [ bigquery.TableFieldSchema( @@ -195,7 +194,7 @@ def test_merge_field_schemas_conflict_mode(self): mode=TableFieldConstants.MODE_REPEATED, description='INFO foo desc') ] - self.assertRaises(ValueError, bigquery_util.get_merged_field_schemas, + self.assertRaises(ValueError, bigquery_util._get_merged_field_schemas, field_schemas_1, field_schemas_2) def test_merge_field_schemas_conflict_type(self): @@ -213,7 +212,7 @@ def 
test_merge_field_schemas_conflict_type(self): mode=TableFieldConstants.MODE_NULLABLE, description='INFO foo desc') ] - self.assertRaises(ValueError, bigquery_util.get_merged_field_schemas, + self.assertRaises(ValueError, bigquery_util._get_merged_field_schemas, field_schemas_1, field_schemas_2) def test_merge_field_schemas_conflict_record_fields(self): @@ -240,7 +239,7 @@ def test_merge_field_schemas_conflict_record_fields(self): mode=TableFieldConstants.MODE_NULLABLE, description='FORMAT foo desc')) field_schemas_2 = [call_record_2] - self.assertRaises(ValueError, bigquery_util.get_merged_field_schemas, + self.assertRaises(ValueError, bigquery_util._get_merged_field_schemas, field_schemas_1, field_schemas_2) def test_merge_field_schemas_same_record(self): @@ -260,8 +259,8 @@ def test_merge_field_schemas_same_record(self): expected_merged_field_schemas = [call_record_1] self.assertEqual( - bigquery_util.get_merged_field_schemas(field_schemas_1, - field_schemas_2), + bigquery_util._get_merged_field_schemas(field_schemas_1, + field_schemas_2), expected_merged_field_schemas) def test_merge_field_schemas_merge_record_fields(self): @@ -308,8 +307,8 @@ def test_merge_field_schemas_merge_record_fields(self): expected_merged_field_schemas = [call_record_3] self.assertEqual( - bigquery_util.get_merged_field_schemas(field_schemas_1, - field_schemas_2), + bigquery_util._get_merged_field_schemas(field_schemas_1, + field_schemas_2), expected_merged_field_schemas) def test_merge_field_schemas_conflict_inner_record_fields(self): @@ -348,7 +347,7 @@ def test_merge_field_schemas_conflict_inner_record_fields(self): description='FORMAT foo desc')) record_2.fields.append(inner_record_2) field_schemas_2 = [record_2] - self.assertRaises(ValueError, bigquery_util.get_merged_field_schemas, + self.assertRaises(ValueError, bigquery_util._get_merged_field_schemas, field_schemas_1, field_schemas_2) def test_merge_field_schemas_merge_inner_record_fields(self): @@ -411,6 +410,6 @@ def 
test_merge_field_schemas_merge_inner_record_fields(self): merged_record.fields.append(merged_inner_record) expected_merged_field_schemas = [merged_record] self.assertEqual( - bigquery_util.get_merged_field_schemas(field_schemas_1, - field_schemas_2), + bigquery_util._get_merged_field_schemas(field_schemas_1, + field_schemas_2), expected_merged_field_schemas) diff --git a/gcp_variant_transforms/libs/infer_headers_util.py b/gcp_variant_transforms/libs/infer_headers_util.py index b393456dd..19cd4d4b2 100644 --- a/gcp_variant_transforms/libs/infer_headers_util.py +++ b/gcp_variant_transforms/libs/infer_headers_util.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""A Helper module for Header Inference operations.""" +"""A Helper module for header inference operations.""" from __future__ import absolute_import @@ -25,8 +25,8 @@ from gcp_variant_transforms.beam_io import vcf_header_io from gcp_variant_transforms.beam_io import vcfio # pylint: disable=unused-import -from gcp_variant_transforms.libs.annotation import annotation_parser from gcp_variant_transforms.libs import vcf_field_conflict_resolver +from gcp_variant_transforms.libs.annotation import annotation_parser _FIELD_COUNT_ALTERNATE_ALLELE = 'A' @@ -34,6 +34,110 @@ _HeaderKeyConstants = vcf_header_io.VcfParserHeaderKeyConstants _HeaderTypeConstants = vcf_header_io.VcfHeaderFieldTypeConstants +# Filled with annotation field and name data, then used as a header ID. +_BASE_ANNOTATION_TYPE_KEY = '{}_{}_TYPE' + +def get_inferred_annotation_type_header_key(annot_field, name): + # type: (str, str) -> str + """Creates ID values for annotation type info headers. + + Args: + annot_field: field name representing annotation field (e.g. 'CSQ'). + name: annotation data field names (e.g. 'IMPACT'). + + Returns: + Info ID value (e.g. CSQ_IMPACT_TYPE). 
+ """ + return _BASE_ANNOTATION_TYPE_KEY.format(annot_field, name) + +def infer_info_fields( + variant, + defined_headers, + infer_headers=False, # type: bool + annotation_fields_to_infer=None # type: Optional[List[str]] + ): + """Returns inferred info fields. + + Up to three types of info fields are inferred: + + if `infer_headers` is True: + - The info fields are undefined in the headers. + - The info fields' definitions provided by the header does not match the + field value. + if `infer_annotation_types` is True: + - Fields containing type information of corresponding annotation Info + fields. + + Args: + variant: variant object + defined_headers: header fields defined in header section of VCF files. + infer_headers: If true, header fields are inferred from variant data. + annotation_fields_to_infer: list of info fields treated as annotation + fields (e.g. ['CSQ', 'CSQ_VT']). + + Returns: + infos: dict of (info_key, `Info`) for any info field in + `variant` that is not defined in the header or the definition mismatches + the field values. + """ + infos = {} + if infer_headers: + _infer_non_annotation_info_fields(variant, infos, defined_headers) + if annotation_fields_to_infer: + _infer_annotation_type_info_fields( + variant, infos, defined_headers, annotation_fields_to_infer) + return infos + +def infer_format_fields(variant, defined_headers): + # type: (vcfio.Variant, vcf_header_io.VcfHeader) -> Dict[str, Format] + """Returns inferred format fields. + + Two types of format fields are inferred: + - The format fields are undefined in the headers. + - The format definition provided by the headers does not match the field + values. + + Args: + variant: variant object + defined_headers: header fields defined in header section of VCF files. + + Returns: + A dict of (format_key, `Format`) for any format key in + `variant` that is not defined in the header or the definition mismatches + the field values. 
+ """ + formats = {} + for call in variant.calls: + for format_key, format_value in call.info.iteritems(): + if not defined_headers or format_key not in defined_headers.formats: + if format_key in formats: + raise ValueError( + 'Duplicate FORMAT field "{}" in variant "{}"'.format( + format_key, variant)) + logging.warning('Undefined FORMAT field "%s" in variant "%s"', + format_key, str(variant)) + formats[format_key] = Format(format_key, + _get_field_count(format_value), + _get_field_type(format_value), + '') # NO_DESCRIPTION + # No point in proceeding. All other calls have the same FORMAT. + break + for call in variant.calls: + for format_key, format_value in call.info.iteritems(): + if defined_headers and format_key in defined_headers.formats: + defined_header = defined_headers.formats.get(format_key) + corrected_format = _infer_mismatched_format_field( + format_key, format_value, defined_header) + if corrected_format: + logging.warning( + 'Incorrect FORMAT field "%s". Defined as "type=%s,num=%s", ' + 'got "%s" in variant "%s"', + format_key, defined_header.get(_HeaderKeyConstants.TYPE), + str(defined_header.get(_HeaderKeyConstants.NUM)), + str(format_value), str(variant)) + formats[format_key] = corrected_format + return formats + def _get_field_count(field_value): # type: (Union[List, bool, int, str]) -> Optional[int] """ @@ -114,11 +218,13 @@ def _infer_mismatched_info_field(field_key, # type: str cardinality as the alternate bases. Correct the num to be `None`. - Defined type is `Integer`, but the provided value is float. Correct the type to be `Float`. + Args: field_key: the info field key. field_value: the value of the field key given in the variant. defined_header: The definition of `field_key` in the header. num_alternate_bases: number of the alternate bases. + Returns: Corrected info definition if there are mismatches. 
""" @@ -140,20 +246,19 @@ def _infer_mismatched_info_field(field_key, # type: str defined_header.get(_HeaderKeyConstants.VERSION)) return None -def _infer_mismatched_format_field(field_key, # type: str - field_value, # type: Any - defined_header # type: Dict - ): - # type: (...) -> Optional[Format] +def _infer_mismatched_format_field(field_key, field_value, defined_header): + # type: (str, Any, Dict) -> Optional[Format] """Returns corrected format if there are mismatches. One type of mismatches is handled: - Defined type is `Integer`, but the provided value is float. Correct the type to be `Float`. + Args: field_key: the format field key. field_value: the value of the field key given in the variant. defined_header: The definition of `field_key` in the header. + Returns: Corrected format definition if there are mismatches. """ @@ -166,7 +271,7 @@ def _infer_mismatched_format_field(field_key, # type: str defined_header.get(_HeaderKeyConstants.DESC)) return None -def _infer_standard_info_fields(variant, infos, defined_headers): +def _infer_non_annotation_info_fields(variant, infos, defined_headers): # type: (vcfio.Variant, Dict[str, Info], vcf_header_io.VcfHeader) -> None """Updates `infos` with inferred info fields. @@ -174,6 +279,7 @@ def _infer_standard_info_fields(variant, infos, defined_headers): - The info fields are undefined in the headers. - The info fields' definitions provided by the header does not match the field value. + Args: variant: variant object infos: dict of (info_key, `Info`) for any info field in @@ -223,6 +329,7 @@ def _infer_annotation_type_info_fields(variant, each variant potentially contains multiple values for each annotation header, a small 'merge' of value types is performed before VcfHeader creation for each variant. 
+ Args: variant: variant object infos: dict of (info_key, `Info`) for any info field in @@ -264,7 +371,7 @@ def _check_annotation_lists_lengths(names, values): _get_field_type(v)) if variant_merged_type == _HeaderTypeConstants.STRING: break - key_id = annotation_parser.get_inferred_annotation_type_header_key( + key_id = get_inferred_annotation_type_header_key( field, name) infos[key_id] = Info(key_id, 1, # field count @@ -273,88 +380,3 @@ def _check_annotation_lists_lengths(names, values): name)), '', # UNKNOWN_SOURCE '') # UNKNOWN_VERSION - -def infer_info_fields( - variant, - defined_headers, - infer_headers=False, # type: bool - annotation_fields_to_infer=None # type: Optional[List[str]] - ): - """Returns inferred info fields. - - Up to three types of info fields are inferred: - - if `infer_headers` is True: - - The info fields are undefined in the headers. - - The info fields' definitions provided by the header does not match the - field value. - if `infer_annotation_types` is True: - - Fields containing type information of corresponding annotation Info - fields. - - Args: - variant: variant object - defined_headers: header fields defined in header section of VCF files. - infer_headers: If true, header fields are inferred from variant data. - annotation_fields_to_infer: list of info fields treated as annotation - fields (e.g. ['CSQ', 'CSQ_VT']). - Returns: - infos: dict of (info_key, `Info`) for any info field in - `variant` that is not defined in the header or the definition mismatches - the field values. - """ - infos = {} - if infer_headers: - _infer_standard_info_fields(variant, infos, defined_headers) - if annotation_fields_to_infer: - _infer_annotation_type_info_fields( - variant, infos, defined_headers, annotation_fields_to_infer) - return infos - -def infer_format_fields(variant, defined_headers): - # type: (vcfio.Variant, vcf_header_io.VcfHeader) -> Dict[str, Format] - """Returns inferred format fields. 
- - Two types of format fields are inferred: - - The format fields are undefined in the headers. - - The format definition provided by the headers does not match the field - values. - Args: - variant: variant object - defined_headers: header fields defined in header section of VCF files. - Returns: - A dict of (format_key, `Format`) for any format key in - `variant` that is not defined in the header or the definition mismatches - the field values. - """ - formats = {} - for call in variant.calls: - for format_key, format_value in call.info.iteritems(): - if not defined_headers or format_key not in defined_headers.formats: - if format_key in formats: - raise ValueError( - 'Duplicate FORMAT field "{}" in variant "{}"'.format( - format_key, variant)) - logging.warning('Undefined FORMAT field "%s" in variant "%s"', - format_key, str(variant)) - formats[format_key] = Format(format_key, - _get_field_count(format_value), - _get_field_type(format_value), - '') # NO_DESCRIPTION - # No point in proceeding. All other calls have the same FORMAT. - break - for call in variant.calls: - for format_key, format_value in call.info.iteritems(): - if defined_headers and format_key in defined_headers.formats: - defined_header = defined_headers.formats.get(format_key) - corrected_format = _infer_mismatched_format_field( - format_key, format_value, defined_header) - if corrected_format: - logging.warning( - 'Incorrect FORMAT field "%s". 
Defined as "type=%s,num=%s", ' - 'got "%s" in variant "%s"', - format_key, defined_header.get(_HeaderKeyConstants.TYPE), - str(defined_header.get(_HeaderKeyConstants.NUM)), - str(format_value), str(variant)) - formats[format_key] = corrected_format - return formats diff --git a/gcp_variant_transforms/libs/processed_variant.py b/gcp_variant_transforms/libs/processed_variant.py index 035eb7527..196d322ed 100644 --- a/gcp_variant_transforms/libs/processed_variant.py +++ b/gcp_variant_transforms/libs/processed_variant.py @@ -36,6 +36,7 @@ from gcp_variant_transforms.libs import metrics_util from gcp_variant_transforms.libs import bigquery_util from gcp_variant_transforms.libs import bigquery_sanitizer +from gcp_variant_transforms.libs import infer_headers_util from gcp_variant_transforms.libs.annotation import annotation_parser from gcp_variant_transforms.libs.annotation.vep import descriptions @@ -353,7 +354,7 @@ def _gen_annotation_name_key_pairs(self, annot_field): annotation_names = annotation_parser.extract_annotation_names( self._header_fields.infos[annot_field][_HeaderKeyConstants.DESC]) for name in annotation_names: - type_key = annotation_parser.get_inferred_annotation_type_header_key( + type_key = infer_headers_util.get_inferred_annotation_type_header_key( annot_field, name) yield name, type_key @@ -475,7 +476,7 @@ def add_annotation_data(self, proc_var, annotation_field_name, data): for name, value in annotation_map.iteritems(): if name == annotation_parser.ANNOTATION_ALT: continue - type_key = annotation_parser.get_inferred_annotation_type_header_key( + type_key = infer_headers_util.get_inferred_annotation_type_header_key( annotation_field_name, name) vcf_type = self._vcf_type_from_annotation_header( annotation_field_name, type_key) From 3fed34165c33f2b3a5cace55ca48671d26e93e52 Mon Sep 17 00:00:00 2001 From: Tural Neymanov Date: Mon, 4 Feb 2019 11:22:36 -0500 Subject: [PATCH 3/4] Addressed 2nd iteration of comments. 
--- gcp_variant_transforms/libs/bigquery_util.py | 2 +- gcp_variant_transforms/libs/infer_headers_util.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gcp_variant_transforms/libs/bigquery_util.py b/gcp_variant_transforms/libs/bigquery_util.py index b83bef5d7..9eb5264e6 100644 --- a/gcp_variant_transforms/libs/bigquery_util.py +++ b/gcp_variant_transforms/libs/bigquery_util.py @@ -197,7 +197,7 @@ def get_avro_type_from_bigquery_type_mode(bigquery_type, bigquery_mode): return avro_type def update_bigquery_schema_on_append(schema_fields, output_table): - # type: (bool, str) -> None + # type: (List[bigquery.TableFieldSchema], str) -> None # if table does not exist, do not need to update the schema. # TODO (yifangchen): Move the logic into validate(). """Update BQ schema by combining existing one with a new one, if possible.""" diff --git a/gcp_variant_transforms/libs/infer_headers_util.py b/gcp_variant_transforms/libs/infer_headers_util.py index 19cd4d4b2..8e37f971a 100644 --- a/gcp_variant_transforms/libs/infer_headers_util.py +++ b/gcp_variant_transforms/libs/infer_headers_util.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""A Helper module for header inference operations.""" +"""A helper module for header inference operations.""" from __future__ import absolute_import @@ -51,8 +51,8 @@ def get_inferred_annotation_type_header_key(annot_field, name): return _BASE_ANNOTATION_TYPE_KEY.format(annot_field, name) def infer_info_fields( - variant, - defined_headers, + variant, # type: vcfio.Variant + defined_headers, # type: vcf_header_io.VcfHeader infer_headers=False, # type: bool annotation_fields_to_infer=None # type: Optional[List[str]] ): From c6f611c84231b9590c4ff69ca11dbef478294d10 Mon Sep 17 00:00:00 2001 From: Tural Neymanov Date: Tue, 5 Feb 2019 16:53:12 -0500 Subject: [PATCH 4/4] Modified docstring for update_bigquery_schema_on_append, as per review comment. --- gcp_variant_transforms/libs/bigquery_util.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/gcp_variant_transforms/libs/bigquery_util.py b/gcp_variant_transforms/libs/bigquery_util.py index 9eb5264e6..de70ac6e1 100644 --- a/gcp_variant_transforms/libs/bigquery_util.py +++ b/gcp_variant_transforms/libs/bigquery_util.py @@ -198,9 +198,11 @@ def get_avro_type_from_bigquery_type_mode(bigquery_type, bigquery_mode): def update_bigquery_schema_on_append(schema_fields, output_table): # type: (List[bigquery.TableFieldSchema], str) -> None - # if table does not exist, do not need to update the schema. - # TODO (yifangchen): Move the logic into validate(). - """Update BQ schema by combining existing one with a new one, if possible.""" + """Update BQ schema by combining existing one with a new one, if possible. + + If table does not exist, do not need to update the schema. + TODO (yifangchen): Move the logic into validate(). + """ output_table_re_match = re.match( r'^((?P<project>.+):)(?P<dataset>\w+)\.(?P<table>[\w\$]+)$', output_table)