diff --git a/gcp_variant_transforms/bq_to_vcf.py b/gcp_variant_transforms/bq_to_vcf.py
index a6c08e1ed..90508115a 100644
--- a/gcp_variant_transforms/bq_to_vcf.py
+++ b/gcp_variant_transforms/bq_to_vcf.py
@@ -60,7 +60,7 @@
 from gcp_variant_transforms.beam_io import vcf_header_io
 from gcp_variant_transforms.beam_io import vcfio
 from gcp_variant_transforms.libs import bigquery_util
-from gcp_variant_transforms.libs import bigquery_vcf_schema_converter
+from gcp_variant_transforms.libs import schema_converter
 from gcp_variant_transforms.libs import genomic_region_parser
 from gcp_variant_transforms.libs import vcf_file_composer
 from gcp_variant_transforms.options import variant_transform_options
@@ -140,7 +140,7 @@ def _write_vcf_meta_info(input_table,
   # type: (str, str, bool) -> None
   """Writes the meta information generated from BigQuery schema."""
   header_fields = (
-      bigquery_vcf_schema_converter.generate_header_fields_from_schema(
+      schema_converter.generate_header_fields_from_schema(
           _get_schema(input_table), allow_incompatible_schema))
   write_header_fn = vcf_header_io.WriteVcfHeaderFn(representative_header_file)
   write_header_fn.process(header_fields, _VCF_VERSION_LINE)
diff --git a/gcp_variant_transforms/libs/bigquery_util.py b/gcp_variant_transforms/libs/bigquery_util.py
index 22e9e4486..5277f4ac9 100644
--- a/gcp_variant_transforms/libs/bigquery_util.py
+++ b/gcp_variant_transforms/libs/bigquery_util.py
@@ -16,7 +16,7 @@
 import enum
 import re

-from typing import Tuple  # pylint: disable=unused-import
+from typing import List, Tuple, Union  # pylint: disable=unused-import

 from vcf import parser
@@ -54,6 +54,17 @@ class TableFieldConstants(object):
   MODE_REPEATED = 'REPEATED'


+class AvroConstants(object):
+  """Constants that are relevant to the Avro schema."""
+  TYPE = 'type'
+  NAME = 'name'
+  FIELDS = 'fields'
+  ARRAY = 'array'
+  ITEMS = 'items'
+  RECORD = 'record'
+  NULL = 'null'
+
+
 class _SupportedTableFieldType(enum.Enum):
   """The supported BigQuery field types.
@@ -83,6 +94,16 @@ class _SupportedTableFieldType(enum.Enum):
     TableFieldConstants.TYPE_BOOLEAN: _VcfHeaderTypeConstants.FLAG
 }

+# A map to convert from BigQuery types to their equivalent Avro types.
+_BIG_QUERY_TYPE_TO_AVRO_TYPE_MAP = {
+    # This list is not exhaustive but covers all of the types we currently use.
+    TableFieldConstants.TYPE_INTEGER: 'long',
+    TableFieldConstants.TYPE_STRING: 'string',
+    TableFieldConstants.TYPE_FLOAT: 'double',
+    TableFieldConstants.TYPE_BOOLEAN: 'boolean',
+    TableFieldConstants.TYPE_RECORD: 'record'
+}
+
 # A map to convert from BigQuery types to Python types.
 _BIG_QUERY_TYPE_TO_PYTHON_TYPE_MAP = {
     TableFieldConstants.TYPE_INTEGER: int,
@@ -156,3 +177,17 @@ def get_vcf_num_from_bigquery_schema(bigquery_mode, bigquery_type):
 def get_supported_bigquery_schema_types():
   """Returns the supported BigQuery field types."""
   return [item.value for item in _SupportedTableFieldType]
+
+
+def get_avro_type_from_bigquery_type_mode(bigquery_type, bigquery_mode):
+  # type: (str, str) -> Union[str, List[str]]
+  if bigquery_type not in _BIG_QUERY_TYPE_TO_AVRO_TYPE_MAP:
+    raise ValueError('Unknown Avro equivalent for type {}'.format(
+        bigquery_type))
+  avro_type = _BIG_QUERY_TYPE_TO_AVRO_TYPE_MAP[bigquery_type]
+  if bigquery_mode == TableFieldConstants.MODE_NULLABLE:
+    # A nullable type in the Avro schema is represented by a union, which is
+    # equivalent to an array in JSON format.
+    return [avro_type, AvroConstants.NULL]
+  else:
+    return avro_type
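Note: a minimal sketch of the mapping the new helper implements, using only
names introduced in this hunk (illustrative, not part of the change):

    from gcp_variant_transforms.libs import bigquery_util

    # NULLABLE columns map to a JSON union that includes 'null'.
    assert bigquery_util.get_avro_type_from_bigquery_type_mode(
        bigquery_util.TableFieldConstants.TYPE_INTEGER,
        bigquery_util.TableFieldConstants.MODE_NULLABLE) == ['long', 'null']

    # Any other mode maps to the bare Avro type name.
    assert bigquery_util.get_avro_type_from_bigquery_type_mode(
        bigquery_util.TableFieldConstants.TYPE_STRING,
        bigquery_util.TableFieldConstants.MODE_REPEATED) == 'string'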
diff --git a/gcp_variant_transforms/libs/bigquery_vcf_data_converter.py b/gcp_variant_transforms/libs/bigquery_vcf_data_converter.py
index 4520aaccd..be3d8f0d3 100644
--- a/gcp_variant_transforms/libs/bigquery_vcf_data_converter.py
+++ b/gcp_variant_transforms/libs/bigquery_vcf_data_converter.py
@@ -129,7 +129,7 @@ def get_rows(self,
   def _get_call_record(
       self,
-      call,  # type: VariantCall
+      call,  # type: vcfio.VariantCall
       call_record_schema_descriptor,
       # type: bigquery_schema_descriptor.SchemaDescriptor
       allow_incompatible_records,  # type: bool
diff --git a/gcp_variant_transforms/libs/bigquery_vcf_schema_converter.py b/gcp_variant_transforms/libs/schema_converter.py
similarity index 76%
rename from gcp_variant_transforms/libs/bigquery_vcf_schema_converter.py
rename to gcp_variant_transforms/libs/schema_converter.py
index e6d77ca69..fd03e3c11 100644
--- a/gcp_variant_transforms/libs/bigquery_vcf_schema_converter.py
+++ b/gcp_variant_transforms/libs/schema_converter.py
@@ -12,24 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Handles the conversion between BigQuery schema and VCF header."""
+"""Handles the conversion between BigQuery/Avro schema and VCF header."""

 from __future__ import absolute_import

 from collections import OrderedDict
-from typing import Any, Dict, Union  # pylint: disable=unused-import
+import json
+import logging
+from typing import Dict, List, Union  # pylint: disable=unused-import

 from apache_beam.io.gcp.internal.clients import bigquery
+from apitools.base.protorpclite import messages  # pylint: disable=unused-import
 from vcf import parser

 from gcp_variant_transforms.beam_io import vcfio
 from gcp_variant_transforms.beam_io import vcf_header_io
-from gcp_variant_transforms.libs import bigquery_schema_descriptor  # pylint: disable=unused-import
 from gcp_variant_transforms.libs import bigquery_util
 from gcp_variant_transforms.libs import processed_variant  # pylint: disable=unused-import
 from gcp_variant_transforms.libs import bigquery_sanitizer
-from gcp_variant_transforms.libs import vcf_field_conflict_resolver  # pylint: disable=unused-import
 from gcp_variant_transforms.libs import vcf_reserved_fields
 from gcp_variant_transforms.libs.annotation import annotation_parser
 from gcp_variant_transforms.libs.variant_merge import variant_merge_strategy  # pylint: disable=unused-import
@@ -184,6 +185,144 @@ def generate_schema_from_header_fields(
   return schema


+def _convert_repeated_field_to_avro_array(field, fields_list):
+  # type: (messages.MessageField, List) -> Dict
+  """Converts a repeated field to an Avro array representation.
+
+  For example, the return value can be: {"type": "array", "items": "string"}
+  """
+  array_dict = {
+      bigquery_util.AvroConstants.TYPE: bigquery_util.AvroConstants.ARRAY
+  }
+  if field.fields:
+    array_dict[bigquery_util.AvroConstants.ITEMS] = {
+        bigquery_util.AvroConstants.TYPE: bigquery_util.AvroConstants.RECORD,
+        bigquery_util.AvroConstants.NAME: field.name,
+        bigquery_util.AvroConstants.FIELDS: fields_list
+    }
+  else:
+    array_dict[bigquery_util.AvroConstants.ITEMS] = {
+        bigquery_util.AvroConstants.NAME: field.name,
+        bigquery_util.AvroConstants.TYPE:
+            bigquery_util.get_avro_type_from_bigquery_type_mode(
+                field.type, field.mode)
+    }
+  # All repeated fields are nullable.
+  return [bigquery_util.AvroConstants.NULL, array_dict]
+
+
+def _convert_field_to_avro_dict(field):
+  # type: (messages.MessageField) -> Dict
+  field_dict = {}
+  fields_list = []
+  if field.fields:
+    fields_list = [
+        _convert_field_to_avro_dict(child_f) for child_f in field.fields]
+  if field.mode == bigquery_util.TableFieldConstants.MODE_REPEATED:
+    # TODO(bashir2): In this case both the name of the array and the
+    # individual records in the array are `field.name`. Make sure this is
+    # according to the Avro spec, then remove this TODO.
+    field_dict[bigquery_util.AvroConstants.NAME] = field.name
+    field_dict[bigquery_util.AvroConstants.TYPE] = (
+        _convert_repeated_field_to_avro_array(field, fields_list))
+  else:
+    field_dict[bigquery_util.AvroConstants.NAME] = field.name
+    field_dict[bigquery_util.AvroConstants.TYPE] = (
+        bigquery_util.get_avro_type_from_bigquery_type_mode(
+            field.type, field.mode))
+    if field.fields:
+      field_dict[bigquery_util.AvroConstants.FIELDS] = fields_list
+  return field_dict
+
+
+def _convert_schema_to_avro_dict(schema):
+  # type: (bigquery.TableSchema) -> Dict
+  fields_dict = {}
+  # TODO(bashir2): Check if we need `namespace` and `name` at the top level.
+  fields_dict[bigquery_util.AvroConstants.NAME] = 'TBD'
+  fields_dict[
+      bigquery_util.AvroConstants.TYPE] = bigquery_util.AvroConstants.RECORD
+  fields_dict[bigquery_util.AvroConstants.FIELDS] = [
+      _convert_field_to_avro_dict(f) for f in schema.fields]
+  return fields_dict
+
+
+def convert_table_schema_to_json_avro_schema(schema):
+  # type: (bigquery.TableSchema) -> str
+  """Returns the Avro equivalent of the given `schema` in JSON format.
+
+  For writing to Avro files, the only piece that differs is the schema. In
+  other words, the exact same `Dict` that represents a BigQuery row can be
+  written to an Avro file if the schema of that file is equivalent to the
+  BigQuery table schema. This function generates that equivalent Avro schema.
+
+  For details of the Avro schema spec, see:
+  https://avro.apache.org/docs/1.8.2/spec.html
+
+  For concrete examples relevant to our BigQuery schema, consider the
+  following three fields:
+
+  {
+    "fields": [
+      {
+        "type": ["string", "null"],
+        "name": "reference_name"
+      },
+      {
+        "type": ["int", "null"],
+        "name": "start_position"
+      },
+      {
+        "type": ["int", "null"],
+        "name": "end_position"
+      },
+      ...
+    ],
+    "type": "record",
+    "name": "TBD"
+  }
+
+  Note that the whole schema is represented as a `record` that has several
+  `fields`. In the above example, only the first three `fields` are shown.
+  A `NULLABLE` type in the BigQuery schema is equivalent to a `type` array
+  where `null` is one of the members.
+
+  `REPEATED` fields, especially `REPEATED` `RECORD` fields, are a little more
+  complex in the Avro schema format. Here is one example for
+  `alternate_bases`:
+
+  {
+    "type": [{
+      "items": {
+        "type": "record",
+        "name": "alternate_bases",
+        "fields": [
+          {
+            "type": ["string", "null"],
+            "name": "alt"
+          },
+          {
+            "type": ["float", "null"],
+            "name": "AF"
+          }
+        ]
+      },
+      "type": "array"
+    }, "null"],
+    "name": "alternate_bases"
+  }
+
+  Args:
+    schema: The BigQuery table schema that is generated from input VCFs.
+  """
+  if not isinstance(schema, bigquery.TableSchema):
+    raise ValueError(
+        'Expected an instance of bigquery.TableSchema, got {}'.format(
+            type(schema)))
+  schema_dict = _convert_schema_to_avro_dict(schema)
+  json_str = json.dumps(schema_dict)
+  logging.info('The Avro schema is: %s', json_str)
+  return json_str
+
+
 def generate_header_fields_from_schema(schema,
                                        allow_incompatible_schema=False):
   # type: (bigquery.TableSchema, bool) -> vcf_header_io.VcfHeader
   """Returns header fields converted from BigQuery schema.
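Note: a short usage sketch of the new converter; the one-field table schema
below is hypothetical, but any valid bigquery.TableSchema should work:

    import avro.schema
    from apache_beam.io.gcp.internal.clients import bigquery

    from gcp_variant_transforms.libs import bigquery_util
    from gcp_variant_transforms.libs import schema_converter

    # Build a one-column NULLABLE string schema.
    table_schema = bigquery.TableSchema()
    table_schema.fields.append(bigquery.TableFieldSchema(
        name='reference_name',
        type=bigquery_util.TableFieldConstants.TYPE_STRING,
        mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
        description='Reference name.'))

    json_schema = schema_converter.convert_table_schema_to_json_avro_schema(
        table_schema)
    # The returned JSON string parses with the avro library, which is what
    # the new tests below verify for every generated table schema.
    avro_schema = avro.schema.parse(json_schema)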
diff --git a/gcp_variant_transforms/libs/bigquery_vcf_schema_converter_test.py b/gcp_variant_transforms/libs/schema_converter_test.py
similarity index 83%
rename from gcp_variant_transforms/libs/bigquery_vcf_schema_converter_test.py
rename to gcp_variant_transforms/libs/schema_converter_test.py
index 353502a36..66d919253 100644
--- a/gcp_variant_transforms/libs/bigquery_vcf_schema_converter_test.py
+++ b/gcp_variant_transforms/libs/schema_converter_test.py
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Tests for `bigquery_vcf_schema_converter` module."""
+"""Tests for the `schema_converter` module."""

 from __future__ import absolute_import

 from collections import OrderedDict
+from typing import List, Union  # pylint: disable=unused-import
 import unittest

+import avro
 from apache_beam.io.gcp.internal.clients import bigquery
 from vcf import parser

@@ -26,7 +28,7 @@
 from gcp_variant_transforms.beam_io import vcf_header_io
 from gcp_variant_transforms.libs import bigquery_util
-from gcp_variant_transforms.libs import bigquery_vcf_schema_converter
+from gcp_variant_transforms.libs import schema_converter
 from gcp_variant_transforms.libs import processed_variant
 from gcp_variant_transforms.libs.bigquery_util import ColumnKeyConstants
 from gcp_variant_transforms.libs.bigquery_util import TableFieldConstants
@@ -50,6 +52,14 @@ def modify_bigquery_schema(self, schema, info_keys):
 class GenerateSchemaFromHeaderFieldsTest(unittest.TestCase):
   """Test cases for the ``generate_schema_from_header_fields`` function."""

+  def _validate_schema(self, expected_fields, actual_schema):
+    """Verifies that `actual_schema` has all the `expected_fields`.
+
+    This is called at the end of each test and can be overridden by child
+    classes to do more validations.
+    """
+    self.assertEqual(expected_fields, _get_fields_from_schema(actual_schema))
+
   def _generate_expected_fields(self, alt_fields=None, call_fields=None,
                                 info_fields=None):
     fields = [ColumnKeyConstants.REFERENCE_NAME,
@@ -77,14 +87,11 @@ def _generate_expected_fields(self, alt_fields=None, call_fields=None,
     fields.extend(info_fields or [])
     return fields

-  def _assert_fields_equal(self, expected_fields, actual_schema):
-    self.assertEqual(expected_fields, _get_fields_from_schema(actual_schema))
-
   def test_no_header_fields(self):
     header_fields = vcf_header_io.VcfHeader()
-    self._assert_fields_equal(
+    self._validate_schema(
         self._generate_expected_fields(),
-        bigquery_vcf_schema_converter.generate_schema_from_header_fields(
+        schema_converter.generate_schema_from_header_fields(
             header_fields,
             processed_variant.ProcessedVariantFactory(header_fields)))

@@ -101,21 +108,21 @@ def test_info_header_fields(self):
             Info('END', 1, 'Integer', 'Special END key', 'src', 'v'))])
     header_fields = vcf_header_io.VcfHeader(infos=infos)

-    self._assert_fields_equal(
+    self._validate_schema(
         self._generate_expected_fields(
             alt_fields=['IA', 'IA2'],
             info_fields=['I1', 'I2', 'IU', 'IG', 'I0']),
-        bigquery_vcf_schema_converter.generate_schema_from_header_fields(
+        schema_converter.generate_schema_from_header_fields(
             header_fields,
             processed_variant.ProcessedVariantFactory(header_fields)))

     # Test with split_alternate_allele_info_fields=False.
     actual_schema = (
-        bigquery_vcf_schema_converter.generate_schema_from_header_fields(
+        schema_converter.generate_schema_from_header_fields(
            header_fields,
            processed_variant.ProcessedVariantFactory(
                header_fields,
                split_alternate_allele_info_fields=False)))
-    self._assert_fields_equal(
+    self._validate_schema(
         self._generate_expected_fields(
             info_fields=['I1', 'I2', 'IA', 'IU', 'IG', 'I0', 'IA2']),
         actual_schema)
@@ -154,12 +161,12 @@ def test_info_and_format_header_fields(self):
         ('GT', Format('GT', 2, 'Integer', 'Special GT key')),
         ('PS', Format('PS', 1, 'Integer', 'Special PS key'))])
     header_fields = vcf_header_io.VcfHeader(infos=infos, formats=formats)
-    self._assert_fields_equal(
+    self._validate_schema(
         self._generate_expected_fields(
             alt_fields=['IA'],
             call_fields=['F1', 'F2', 'FU'],
             info_fields=['I1']),
-        bigquery_vcf_schema_converter.generate_schema_from_header_fields(
+        schema_converter.generate_schema_from_header_fields(
            header_fields,
            processed_variant.ProcessedVariantFactory(header_fields)))

@@ -175,13 +182,13 @@ def test_bigquery_field_name_sanitize(self):
         ('a^b', Format('a^b', 1, 'String', 'desc')),
         ('OK_format_09', Format('OK_format_09', 1, 'String', 'desc'))])
     header_fields = vcf_header_io.VcfHeader(infos=infos, formats=formats)
-    self._assert_fields_equal(
+    self._validate_schema(
         self._generate_expected_fields(
             alt_fields=['I_A'],
             call_fields=['a_b', 'OK_format_09'],
             info_fields=['field__', 'field__A', 'field_0a', 'A_B_C',
                          'OK_info_09']),
-        bigquery_vcf_schema_converter.generate_schema_from_header_fields(
+        schema_converter.generate_schema_from_header_fields(
            header_fields,
            processed_variant.ProcessedVariantFactory(header_fields)))

@@ -191,17 +198,36 @@ def test_variant_merger_modify_schema(self):
         ('IA', Info('IA', field_counts['A'], 'Integer', 'desc', 'src', 'v'))])
     formats = OrderedDict([('F1', Format('F1', 1, 'String', 'desc'))])
     header_fields = vcf_header_io.VcfHeader(infos=infos, formats=formats)
-    self._assert_fields_equal(
+    self._validate_schema(
         self._generate_expected_fields(
             alt_fields=['IA'],
             call_fields=['F1'],
             info_fields=['I1',
                          'ADDED_BY_MERGER']),
-        bigquery_vcf_schema_converter.generate_schema_from_header_fields(
+        schema_converter.generate_schema_from_header_fields(
            header_fields,
            processed_variant.ProcessedVariantFactory(header_fields),
            variant_merger=_DummyVariantMergeStrategy()))


+class ConvertTableSchemaToJsonAvroSchemaTest(
+    GenerateSchemaFromHeaderFieldsTest):
+  """Test cases for `convert_table_schema_to_json_avro_schema`.
+
+  This works by extending GenerateSchemaFromHeaderFieldsTest such that each
+  BigQuery table schema generated by the tests in that class is converted to
+  an Avro schema and verified in this class.
+  """
+
+  def _validate_schema(self, expected_fields, actual_schema):
+    super(ConvertTableSchemaToJsonAvroSchemaTest, self)._validate_schema(
+        expected_fields, actual_schema)
+    avro_schema = avro.schema.parse(
+        schema_converter.convert_table_schema_to_json_avro_schema(
+            actual_schema))
+    self.assertEqual(expected_fields,
+                     _get_fields_from_avro_type(avro_schema, ''))
+
+
 class GenerateHeaderFieldsFromSchemaTest(unittest.TestCase):
   """Test cases for the `generate_header_fields_from_schema` function."""

@@ -217,7 +243,7 @@ def test_add_info_fields_from_alternate_bases_reserved_field(self):
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description='bigquery desc'))
     infos_with_desc = OrderedDict()
-    bigquery_vcf_schema_converter._add_info_fields(
+    schema_converter._add_info_fields(
         alternate_bases_record_with_desc, infos_with_desc)
     expected_infos = OrderedDict([
         ('AF', Info('AF', field_counts['A'], 'Float', 'bigquery desc',
@@ -235,7 +261,7 @@
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description=''))
     infos_no_desc = OrderedDict()
-    bigquery_vcf_schema_converter._add_info_fields(
+    schema_converter._add_info_fields(
         alternate_bases_record_no_desc, infos_no_desc)
     expected_infos = OrderedDict([
         ('AF', Info('AF', field_counts['A'], 'Float',
@@ -256,11 +282,11 @@ def test_add_info_fields_from_alternate_bases_schema_compatibility(self):
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description='desc'))
     with self.assertRaises(ValueError):
-      bigquery_vcf_schema_converter._add_info_fields(schema_conflict_info,
-                                                     OrderedDict())
+      schema_converter._add_info_fields(schema_conflict_info,
+                                        OrderedDict())

     infos_allow_incompatible_schema = OrderedDict()
-    bigquery_vcf_schema_converter._add_info_fields(
+    schema_converter._add_info_fields(
         schema_conflict_info,
         infos_allow_incompatible_schema,
         allow_incompatible_schema=True)
@@ -280,7 +306,7 @@ def test_add_info_fields_from_alternate_bases_non_reserved_field(self):
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description='bigquery desc'))
     infos = OrderedDict()
-    bigquery_vcf_schema_converter._add_info_fields(
+    schema_converter._add_info_fields(
         alternate_bases_record, infos)
     expected_infos = OrderedDict([
         ('non_reserved', Info('non_reserved', field_counts['A'], 'Float',
@@ -294,7 +320,7 @@ def test_add_info_fields_reserved_field(self):
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description='bigquery desc')
     infos = OrderedDict()
-    bigquery_vcf_schema_converter._add_info_fields(field_with_desc, infos)
+    schema_converter._add_info_fields(field_with_desc, infos)
     expected_infos = OrderedDict([
         ('AA', Info('AA', 1, 'String', 'bigquery desc', None, None))])
     self.assertEqual(infos, expected_infos)
@@ -305,7 +331,7 @@ def test_add_info_fields_reserved_field(self):
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description='')
     infos = OrderedDict()
-    bigquery_vcf_schema_converter._add_info_fields(field_without_desc, infos)
+    schema_converter._add_info_fields(field_without_desc, infos)
     expected_infos = OrderedDict([
         ('AA', Info('AA', 1, 'String', 'Ancestral allele', None, None))])
     self.assertEqual(infos, expected_infos)

@@ -317,8 +343,8 @@ def test_add_info_fields_reserved_field_schema_compatibility(self):
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description='desc')
     with self.assertRaises(ValueError):
-      bigquery_vcf_schema_converter._add_info_fields(field_conflict_info_type,
-                                                     OrderedDict())
+      schema_converter._add_info_fields(field_conflict_info_type,
+                                        OrderedDict())

     field_conflict_info_format = bigquery.TableFieldSchema(
         name='AA',
@@ -326,11 +352,11 @@ def test_add_info_fields_reserved_field_schema_compatibility(self):
         mode=bigquery_util.TableFieldConstants.MODE_REPEATED,
         description='desc')
     with self.assertRaises(ValueError):
-      bigquery_vcf_schema_converter._add_info_fields(field_conflict_info_format,
-                                                     OrderedDict())
+      schema_converter._add_info_fields(field_conflict_info_format,
+                                        OrderedDict())

     info_allow_incompatible_schema = OrderedDict()
-    bigquery_vcf_schema_converter._add_info_fields(
+    schema_converter._add_info_fields(
         field_conflict_info_format,
         info_allow_incompatible_schema,
         allow_incompatible_schema=True)
@@ -345,7 +371,7 @@ def test_add_info_fields_non_reserved_field(self):
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description='')
     infos = OrderedDict()
-    bigquery_vcf_schema_converter._add_info_fields(non_reserved_field, infos)
+    schema_converter._add_info_fields(non_reserved_field, infos)
     expected_infos = OrderedDict([
         ('non_reserved_info', Info('non_reserved_info', 1, 'String', '',
                                    None, None))])
@@ -363,8 +389,8 @@ def test_add_format_fields_reserved_field(self):
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description='bigquery desc'))
     formats = OrderedDict()
-    bigquery_vcf_schema_converter._add_format_fields(calls_record_with_desc,
-                                                     formats)
+    schema_converter._add_format_fields(calls_record_with_desc,
+                                        formats)
     expected_formats = OrderedDict([
         ('GQ', Format('GQ', 1, 'Integer', 'bigquery desc'))])
     self.assertEqual(formats, expected_formats)
@@ -380,8 +406,8 @@ def test_add_format_fields_reserved_field(self):
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description=''))
     formats = OrderedDict()
-    bigquery_vcf_schema_converter._add_format_fields(calls_record_without_desc,
-                                                     formats)
+    schema_converter._add_format_fields(calls_record_without_desc,
+                                        formats)
     expected_formats = OrderedDict([
         ('GQ', Format('GQ', 1, 'Integer', 'Conditional genotype quality'))])
     self.assertEqual(formats, expected_formats)
@@ -400,11 +426,11 @@ def test_add_format_fields_reserved_field_schema_compatibility(self):
         description='desc'))
     schema_conflict_format.fields.append(calls_record)
     with self.assertRaises(ValueError):
-      bigquery_vcf_schema_converter.generate_header_fields_from_schema(
+      schema_converter.generate_header_fields_from_schema(
          schema_conflict_format)

     formats_allow_incompatible_schema = OrderedDict()
-    bigquery_vcf_schema_converter._add_format_fields(
+    schema_converter._add_format_fields(
         calls_record,
         formats_allow_incompatible_schema,
         allow_incompatible_schema=True)
@@ -424,7 +450,7 @@ def test_add_format_fields_non_reserved_field(self):
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description='bigquery desc'))
     formats = OrderedDict()
-    bigquery_vcf_schema_converter._add_format_fields(calls_record, formats)
+    schema_converter._add_format_fields(calls_record, formats)
     expected_formats = OrderedDict([
         ('non_reserved_format', Format('non_reserved_format', 1, 'Integer',
                                        'bigquery desc'))])
@@ -432,7 +458,7 @@
   def test_generate_header_fields_from_schema(self):
     sample_schema = bigquery_schema_util.get_sample_table_schema()
-    header = bigquery_vcf_schema_converter.generate_header_fields_from_schema(
+    header = schema_converter.generate_header_fields_from_schema(
         sample_schema)

     infos = OrderedDict([
@@ -450,7 +476,7 @@
   def test_generate_header_fields_from_schema_with_annotation(self):
     sample_schema = bigquery_schema_util.get_sample_table_schema(
         with_annotation_fields=True)
-    header = bigquery_vcf_schema_converter.generate_header_fields_from_schema(
+    header = schema_converter.generate_header_fields_from_schema(
         sample_schema)

     infos = OrderedDict([
@@ -474,7 +500,7 @@ def test_generate_header_fields_from_schema_date_type(self):
         type='Date',
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description='Column required by BigQuery partitioning logic.'))
-    header = bigquery_vcf_schema_converter.generate_header_fields_from_schema(
+    header = schema_converter.generate_header_fields_from_schema(
         schema)

     expected_header = vcf_header_io.VcfHeader(infos=OrderedDict(),
@@ -487,7 +513,7 @@ def test_generate_header_fields_from_schema_none_mode(self):
         name='field',
         type=bigquery_util.TableFieldConstants.TYPE_STRING,
         description='desc'))
-    header = bigquery_vcf_schema_converter.generate_header_fields_from_schema(
+    header = schema_converter.generate_header_fields_from_schema(
         schema_non_reserved_fields)
     infos = OrderedDict([
         ('field', Info('field', 1, 'String', 'desc', None, None))])
@@ -500,7 +526,7 @@ def test_generate_header_fields_from_schema_none_mode(self):
         name='AA',
         type=bigquery_util.TableFieldConstants.TYPE_STRING,
         description='desc'))
-    header = bigquery_vcf_schema_converter.generate_header_fields_from_schema(
+    header = schema_converter.generate_header_fields_from_schema(
         schema_reserved_fields)
     infos = OrderedDict([
         ('AA', Info('AA', 1, 'String', 'desc', None, None))])
@@ -516,10 +542,10 @@ def test_generate_header_fields_from_schema_schema_compatibility(self):
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description='desc'))
     with self.assertRaises(ValueError):
-      bigquery_vcf_schema_converter.generate_header_fields_from_schema(
+      schema_converter.generate_header_fields_from_schema(
          schema_conflict)

-    header = bigquery_vcf_schema_converter.generate_header_fields_from_schema(
+    header = schema_converter.generate_header_fields_from_schema(
        schema_conflict, allow_incompatible_schema=True)
     infos = OrderedDict([
@@ -535,7 +561,7 @@ def test_generate_header_fields_from_schema_invalid_description(self):
         type=bigquery_util.TableFieldConstants.TYPE_STRING,
         mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
         description='Desc\nThis is added intentionally.'))
-    header = bigquery_vcf_schema_converter.generate_header_fields_from_schema(
+    header = schema_converter.generate_header_fields_from_schema(
         schema)

     infos = OrderedDict([
@@ -560,21 +586,21 @@ def test_vcf_header_to_schema_to_vcf_header(self):
         ('FU', Format('FU', field_counts['.'], 'Float', 'desc'))])
     original_header = vcf_header_io.VcfHeader(infos=infos, formats=formats)
-    schema = bigquery_vcf_schema_converter.generate_schema_from_header_fields(
+    schema = schema_converter.generate_schema_from_header_fields(
         original_header,
         processed_variant.ProcessedVariantFactory(original_header))
     reconstructed_header = (
-        bigquery_vcf_schema_converter.generate_header_fields_from_schema(
+        schema_converter.generate_header_fields_from_schema(
            schema))
     self.assertEqual(original_header, reconstructed_header)

   def test_schema_to_vcf_header_to_schema(self):
     original_schema = bigquery_schema_util.get_sample_table_schema()
-    header = bigquery_vcf_schema_converter.generate_header_fields_from_schema(
+    header = schema_converter.generate_header_fields_from_schema(
         original_schema)
     reconstructed_schema = (
-        bigquery_vcf_schema_converter.generate_schema_from_header_fields(
+        schema_converter.generate_schema_from_header_fields(
            header,
            processed_variant.ProcessedVariantFactory(header)))
     self.assertEqual(_get_fields_from_schema(reconstructed_schema),
@@ -588,3 +614,30 @@ def _get_fields_from_schema(schema, prefix=''):
     if field.type == TableFieldConstants.TYPE_RECORD:
       fields.extend(_get_fields_from_schema(field, prefix=field.name + '.'))
   return fields
+
+
+def _get_fields_from_avro_type(field_or_schema, prefix):
+  # type: (Union[avro.schema.Field, avro.schema.Schema], str) -> List[str]
+  fields = []
+  if isinstance(field_or_schema, avro.schema.PrimitiveSchema):
+    return []
+  t = field_or_schema.type
+  if isinstance(t, avro.schema.UnionSchema):
+    for s in t.schemas:
+      fields.extend(_get_fields_from_avro_type(s, prefix))
+  if isinstance(field_or_schema, avro.schema.ArraySchema):
+    return _get_fields_from_avro_type(field_or_schema.items, prefix)
+  # We need to exclude the name for the case of a UnionSchema that has
+  # a RecordSchema as a type member. In this case, the name of the record
+  # appears twice in the Avro schema, once at the UnionSchema level and once
+  # at the child RecordSchema.
+  name = field_or_schema.name
+  if name and name not in fields and name != 'TBD':
+    fields.extend([prefix + field_or_schema.name])
+  if field_or_schema.get_prop('fields'):
+    child_prefix = prefix
+    if name != 'TBD':
+      child_prefix = prefix + field_or_schema.name + '.'
+    for f in field_or_schema.fields:
+      fields.extend(_get_fields_from_avro_type(f, child_prefix))
+  return fields
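Note: a small, self-contained sketch of the schema shapes the test helper
above walks; the JSON mirrors the `alternate_bases` example from the
converter's docstring (avro 1.8 Python 2 API, as imported in this test):

    import avro.schema

    json_str = '''
    {"type": "record", "name": "TBD", "fields": [
      {"name": "alternate_bases",
       "type": ["null", {"type": "array",
                         "items": {"type": "record",
                                   "name": "alternate_bases",
                                   "fields": [{"name": "alt",
                                               "type": ["string", "null"]}]}}]}
    ]}
    '''
    schema = avro.schema.parse(json_str)
    union = schema.fields[0].type  # A UnionSchema of 'null' and an array.
    array = [s for s in union.schemas
             if isinstance(s, avro.schema.ArraySchema)][0]
    # The record inside the array carries the same name as the field itself,
    # which is why the helper de-duplicates on `name`.
    assert array.items.name == 'alternate_bases'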
diff --git a/gcp_variant_transforms/options/variant_transform_options.py b/gcp_variant_transforms/options/variant_transform_options.py
index 1d306c125..16c26c288 100644
--- a/gcp_variant_transforms/options/variant_transform_options.py
+++ b/gcp_variant_transforms/options/variant_transform_options.py
@@ -125,13 +125,29 @@ def validate(self, parsed_args):
             parsed_args.input_pattern))


+class AvroWriteOptions(VariantTransformsOptions):
+  """Options for writing Variant records to Avro files."""
+
+  def add_arguments(self, parser):
+    # type: (argparse.ArgumentParser) -> None
+    parser.add_argument('--output_avro_path',
+                        default='',
+                        help='The output path to write Avro files under.')
+
+  def validate(self, parsed_args):
+    # type: (argparse.Namespace) -> None
+    if not parsed_args.output_table and not parsed_args.output_avro_path:
+      raise ValueError('At least one of --output_table or --output_avro_path '
+                       'options must be provided.')
+
+
 class BigQueryWriteOptions(VariantTransformsOptions):
   """Options for writing Variant records to BigQuery."""

   def add_arguments(self, parser):
     # type: (argparse.ArgumentParser) -> None
     parser.add_argument('--output_table',
-                        required=True,
+                        default='',
                         help='BigQuery table to store the results.')
     parser.add_argument(
         '--split_alternate_allele_info_fields',
@@ -177,6 +193,9 @@ def add_arguments(self, parser):

   def validate(self, parsed_args, client=None):
     # type: (argparse.Namespace, bigquery.BigqueryV2) -> None
+    if not parsed_args.output_table and parsed_args.output_avro_path:
+      # Writing to BigQuery is not requested; no more BigQuery checks needed.
+      return
     output_table_re_match = re.match(
         r'^((?P<project>.+):)(?P<dataset>\w+)\.(?P<table>[\w\$]+)$',
         parsed_args.output_table)
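Note: the intended flag interplay, sketched with plain argparse namespaces
standing in for Beam's parsed arguments (values are hypothetical; this
assumes the options classes can be instantiated directly, as in unit tests):

    import argparse

    from gcp_variant_transforms.options import variant_transform_options

    avro_opts = variant_transform_options.AvroWriteOptions()
    # Avro-only output passes validation: one of the two sinks is requested.
    avro_opts.validate(argparse.Namespace(
        output_table='', output_avro_path='gs://my-bucket/avro'))
    # Neither sink requested: raises ValueError.
    try:
      avro_opts.validate(argparse.Namespace(output_table='',
                                            output_avro_path=''))
    except ValueError:
      pass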
diff --git a/gcp_variant_transforms/transforms/variant_to_avro.py b/gcp_variant_transforms/transforms/variant_to_avro.py
new file mode 100644
index 000000000..410007530
--- /dev/null
+++ b/gcp_variant_transforms/transforms/variant_to_avro.py
@@ -0,0 +1,96 @@
+# Copyright 2018 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+import apache_beam as beam
+import avro
+
+from gcp_variant_transforms.beam_io import vcf_header_io  # pylint: disable=unused-import
+from gcp_variant_transforms.libs import bigquery_schema_descriptor
+from gcp_variant_transforms.libs import schema_converter
+from gcp_variant_transforms.libs import bigquery_vcf_data_converter
+from gcp_variant_transforms.libs import processed_variant
+from gcp_variant_transforms.libs import vcf_field_conflict_resolver
+from gcp_variant_transforms.libs.variant_merge import variant_merge_strategy  # pylint: disable=unused-import
+from gcp_variant_transforms.transforms import variant_to_bigquery
+
+
+# TODO(bashir2): Refactor the common parts of VariantToAvroFiles and
+# VariantToBigQuery into a class that is shared by both. It is mostly the
+# schema generation that differs (and of course the sink). There is also
+# some logic for updating the schema etc. that is not needed for the Avro
+# case.
+@beam.typehints.with_input_types(processed_variant.ProcessedVariant)
+class VariantToAvroFiles(beam.PTransform):
+  """Writes a PCollection of `ProcessedVariant` records to Avro files."""
+
+  def __init__(
+      self,
+      output_path,  # type: str
+      header_fields,  # type: vcf_header_io.VcfHeader
+      proc_var_factory,  # type: processed_variant.ProcessedVariantFactory
+      variant_merger=None,  # type: variant_merge_strategy.VariantMergeStrategy
+      allow_incompatible_records=False,  # type: bool
+      omit_empty_sample_calls=False,  # type: bool
+      null_numeric_value_replacement=None  # type: int
+  ):
+    # type: (...) -> None
+    """Initializes the transform.
+
+    Args:
+      output_path: The path under which output Avro files are generated.
+      header_fields: Representative header fields for all variants. This is
+        needed for dynamically generating the schema.
+      proc_var_factory: The factory class that knows how to convert Variant
+        instances to ProcessedVariant. As a side effect it also knows how to
+        modify the BigQuery schema based on the ProcessedVariants that it
+        generates. The latter functionality is what is needed here.
+      variant_merger: The strategy used for merging variants (if any). Some
+        strategies may change the schema, which is why this may be needed
+        here.
+      allow_incompatible_records: If true, field values are cast to the
+        BigQuery schema if there is a mismatch.
+      omit_empty_sample_calls: If true, samples that don't have a given call
+        will be omitted.
+      null_numeric_value_replacement: The value to use instead of null for
+        numeric (float/int/long) lists. For instance, [0, None, 1] will
+        become [0, `null_numeric_value_replacement`, 1]. If not set, the
+        value will be set to
+        bigquery_util._DEFAULT_NULL_NUMERIC_VALUE_REPLACEMENT.
+    """
+    self._output_path = output_path
+    self._proc_var_factory = proc_var_factory
+    table_schema = (
+        schema_converter.generate_schema_from_header_fields(
+            header_fields, proc_var_factory, variant_merger))
+    self._avro_schema = avro.schema.parse(
+        schema_converter.convert_table_schema_to_json_avro_schema(
+            table_schema))
+    self._bigquery_row_generator = (
+        bigquery_vcf_data_converter.BigQueryRowGenerator(
+            bigquery_schema_descriptor.SchemaDescriptor(table_schema),
+            vcf_field_conflict_resolver.FieldConflictResolver(
+                resolve_always=allow_incompatible_records),
+            null_numeric_value_replacement))
+
+    self._allow_incompatible_records = allow_incompatible_records
+    self._omit_empty_sample_calls = omit_empty_sample_calls
+
+  def expand(self, pcoll):
+    avro_records = pcoll | 'ConvertToAvroRecords' >> beam.ParDo(
+        variant_to_bigquery.ConvertVariantToRow(
+            self._bigquery_row_generator,
+            self._allow_incompatible_records,
+            self._omit_empty_sample_calls))
+    return (avro_records
+            | 'WriteToAvroFiles' >>
+            beam.io.WriteToAvro(self._output_path, self._avro_schema))
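Note: a self-contained miniature of what `expand` does at its last step --
writing dicts to Avro files under a path prefix (the toy schema and the
local path are hypothetical, and behavior assumes the Beam/avro versions
pinned by this project):

    import apache_beam as beam
    import avro.schema

    schema = avro.schema.parse(
        '{"type": "record", "name": "V", "fields": '
        '[{"name": "reference_name", "type": ["string", "null"]}]}')
    with beam.Pipeline() as p:
      _ = (p
           | beam.Create([{'reference_name': '19'}, {'reference_name': 'X'}])
           | 'WriteToAvroFiles' >> beam.io.WriteToAvro('/tmp/variants',
                                                       schema))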
diff --git a/gcp_variant_transforms/transforms/variant_to_bigquery.py b/gcp_variant_transforms/transforms/variant_to_bigquery.py
index 13418f33d..a1ae5e3cc 100644
--- a/gcp_variant_transforms/transforms/variant_to_bigquery.py
+++ b/gcp_variant_transforms/transforms/variant_to_bigquery.py
@@ -27,9 +27,9 @@
 from oauth2client.client import GoogleCredentials

 from gcp_variant_transforms.beam_io import vcf_header_io  # pylint: disable=unused-import
-from gcp_variant_transforms.libs import bigquery_schema_descriptor  # pylint: disable=unused-import
+from gcp_variant_transforms.libs import bigquery_schema_descriptor
 from gcp_variant_transforms.libs import bigquery_util
-from gcp_variant_transforms.libs import bigquery_vcf_schema_converter
+from gcp_variant_transforms.libs import schema_converter
 from gcp_variant_transforms.libs import bigquery_vcf_data_converter
 from gcp_variant_transforms.libs import processed_variant
 from gcp_variant_transforms.libs import vcf_field_conflict_resolver
@@ -44,7 +44,7 @@

 @beam.typehints.with_input_types(processed_variant.ProcessedVariant)
-class _ConvertToBigQueryTableRow(beam.DoFn):
+class ConvertVariantToRow(beam.DoFn):
   """Converts a ``Variant`` record to a BigQuery row."""

   def __init__(
@@ -54,7 +54,7 @@ def __init__(
       omit_empty_sample_calls=False  # type: bool
   ):
     # type: (...) -> None
-    super(_ConvertToBigQueryTableRow, self).__init__()
+    super(ConvertVariantToRow, self).__init__()
     self._allow_incompatible_records = allow_incompatible_records
     self._omit_empty_sample_calls = omit_empty_sample_calls
     self._bigquery_row_generator = row_generator
@@ -74,6 +74,8 @@ def __init__(
       header_fields,  # type: vcf_header_io.VcfHeader
       variant_merger=None,  # type: variant_merge_strategy.VariantMergeStrategy
       proc_var_factory=None,  # type: processed_variant.ProcessedVariantFactory
+      # TODO(bashir2): proc_var_factory is a required argument and if `None`
+      # is supplied this will fail in schema generation.
       append=False,  # type: bool
       update_schema_on_append=False,  # type: bool
       allow_incompatible_records=False,  # type: bool
@@ -115,7 +117,7 @@ def __init__(
     self._proc_var_factory = proc_var_factory
     self._append = append
     self._schema = (
-        bigquery_vcf_schema_converter.generate_schema_from_header_fields(
+        schema_converter.generate_schema_from_header_fields(
            self._header_fields, self._proc_var_factory, self._variant_merger))
     # The resolver makes an extra effort to resolve conflicts when the flag
     # allow_incompatible_records is set.
@@ -134,7 +136,7 @@

   def expand(self, pcoll):
     bq_rows = pcoll | 'ConvertToBigQueryTableRow' >> beam.ParDo(
-        _ConvertToBigQueryTableRow(
+        ConvertVariantToRow(
            self._bigquery_row_generator,
            self._allow_incompatible_records,
            self._omit_empty_sample_calls))
diff --git a/gcp_variant_transforms/transforms/variant_to_bigquery_test.py b/gcp_variant_transforms/transforms/variant_to_bigquery_test.py
index e76da37a9..17d106918 100644
--- a/gcp_variant_transforms/transforms/variant_to_bigquery_test.py
+++ b/gcp_variant_transforms/transforms/variant_to_bigquery_test.py
@@ -34,7 +34,7 @@
 from gcp_variant_transforms.libs.bigquery_util import TableFieldConstants
 from gcp_variant_transforms.testing import vcf_header_util
 from gcp_variant_transforms.transforms import variant_to_bigquery
-from gcp_variant_transforms.transforms.variant_to_bigquery import _ConvertToBigQueryTableRow as ConvertToBigQueryTableRow
+from gcp_variant_transforms.transforms.variant_to_bigquery import ConvertVariantToRow


 class ConvertToBigQueryTableRowTest(unittest.TestCase):
@@ -244,7 +244,7 @@ def test_convert_variant_to_bigquery_row(self):
     bigquery_rows = (
         pipeline
         | Create([proc_var_1, proc_var_2, proc_var_3])
-        | 'ConvertToRow' >> ParDo(ConvertToBigQueryTableRow(
+        | 'ConvertToRow' >> ParDo(ConvertVariantToRow(
            self._row_generator)))

     assert_that(bigquery_rows, equal_to([row_1, row_2, row_3]))
     pipeline.run()
@@ -258,7 +258,7 @@ def test_convert_variant_to_bigquery_row_omit_empty_calls(self):
     bigquery_rows = (
         pipeline
         | Create([proc_var])
-        | 'ConvertToRow' >> ParDo(ConvertToBigQueryTableRow(
+        | 'ConvertToRow' >> ParDo(ConvertVariantToRow(
            self._row_generator, omit_empty_sample_calls=True)))
     assert_that(bigquery_rows, equal_to([row]))
     pipeline.run()
@@ -273,7 +273,7 @@ def test_convert_variant_to_bigquery_row_allow_incompatible_recoreds(self):
     bigquery_rows = (
         pipeline
         | Create([proc_var])
-        | 'ConvertToRow' >> ParDo(ConvertToBigQueryTableRow(
+        | 'ConvertToRow' >> ParDo(ConvertVariantToRow(
            self._row_generator, allow_incompatible_records=True)))
     assert_that(bigquery_rows, equal_to([row]))
     pipeline.run()
diff --git a/gcp_variant_transforms/vcf_to_bq.py b/gcp_variant_transforms/vcf_to_bq.py
index e4943e03b..27e0345a2 100644
--- a/gcp_variant_transforms/vcf_to_bq.py
+++ b/gcp_variant_transforms/vcf_to_bq.py
@@ -60,10 +60,13 @@
 from gcp_variant_transforms.transforms import merge_headers
 from gcp_variant_transforms.transforms import merge_variants
 from gcp_variant_transforms.transforms import partition_variants
+from gcp_variant_transforms.transforms import variant_to_avro
 from gcp_variant_transforms.transforms import variant_to_bigquery


 _COMMAND_LINE_OPTIONS = [
     variant_transform_options.VcfReadOptions,
+    variant_transform_options.AvroWriteOptions,
     variant_transform_options.BigQueryWriteOptions,
     variant_transform_options.AnnotationOptions,
     variant_transform_options.FilterOptions,
@@ -261,24 +264,43 @@ def run(argv=None):
     variants = [variants | 'FlattenPartitions' >> beam.Flatten()]
     num_partitions = 1

-  for i in range(num_partitions):
-    table_suffix = ''
-    if partitioner and partitioner.get_partition_name(i):
-      table_suffix = '_' + partitioner.get_partition_name(i)
-    table_name = known_args.output_table + table_suffix
-    _ = (variants[i] | 'VariantToBigQuery' + table_suffix >>
-         variant_to_bigquery.VariantToBigQuery(
-             table_name,
-             header_fields,
-             variant_merger,
-             processed_variant_factory,
-             append=known_args.append,
-             update_schema_on_append=known_args.update_schema_on_append,
-             allow_incompatible_records=known_args.allow_incompatible_records,
-             omit_empty_sample_calls=known_args.omit_empty_sample_calls,
-             num_bigquery_write_shards=known_args.num_bigquery_write_shards,
-             null_numeric_value_replacement=(
-                 known_args.null_numeric_value_replacement)))
+  if known_args.output_table:
+    for i in range(num_partitions):
+      table_suffix = ''
+      if partitioner and partitioner.get_partition_name(i):
+        table_suffix = '_' + partitioner.get_partition_name(i)
+      table_name = known_args.output_table + table_suffix
+      _ = (variants[i] | 'VariantToBigQuery' + table_suffix >>
+           variant_to_bigquery.VariantToBigQuery(
+               table_name,
+               header_fields,
+               variant_merger,
+               processed_variant_factory,
+               append=known_args.append,
+               update_schema_on_append=known_args.update_schema_on_append,
+               allow_incompatible_records=known_args.allow_incompatible_records,
+               omit_empty_sample_calls=known_args.omit_empty_sample_calls,
+               num_bigquery_write_shards=known_args.num_bigquery_write_shards,
+               null_numeric_value_replacement=(
+                   known_args.null_numeric_value_replacement)))
+
+  if known_args.output_avro_path:
+    # TODO(bashir2): Add an integration test that writes Avro outputs, then
+    # imports them to BigQuery using the bq tool and verifies that the
+    # resulting table is identical to the directly written one.
+    _ = (
+        variants | 'FlattenToOnePCollection' >> beam.Flatten()
+        | 'VariantToAvro' >>
+        variant_to_avro.VariantToAvroFiles(
+            known_args.output_avro_path,
+            header_fields,
+            processed_variant_factory,
+            variant_merger=variant_merger,
+            allow_incompatible_records=known_args.allow_incompatible_records,
+            omit_empty_sample_calls=known_args.omit_empty_sample_calls,
+            null_numeric_value_replacement=(
+                known_args.null_numeric_value_replacement))
+    )

   result = pipeline.run()
   result.wait_until_finish()
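Note: a toy, runnable illustration of the new branching above -- the list of
partitioned PCollections is flattened back into one for the Avro sink, while
the BigQuery branch keeps writing per partition (unrelated to variant data):

    import apache_beam as beam

    with beam.Pipeline() as p:
      parts = [p | 'A' >> beam.Create([1, 2]), p | 'B' >> beam.Create([3])]
      # A list of PCollections can be piped into Flatten, as in vcf_to_bq.
      merged = parts | 'FlattenToOnePCollection' >> beam.Flatten()
      _ = merged | 'AvroBranch' >> beam.Map(lambda x: x)
      _ = parts[0] | 'BigQueryBranch' >> beam.Map(lambda x: x)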