50 changes: 32 additions & 18 deletions gcp_variant_transforms/options/variant_transform_options.py
@@ -17,8 +17,6 @@
import argparse # pylint: disable=unused-import
import re

from apache_beam.io import filesystem
from apache_beam.io import filesystems
from apache_beam.io.gcp.internal.clients import bigquery
from apitools.base.py import exceptions
from oauth2client.client import GoogleCredentials
@@ -53,9 +51,16 @@ class VcfReadOptions(VariantTransformsOptions):

def add_arguments(self, parser):
"""Adds all options of this transform to parser."""
parser.add_argument('--input_pattern',
required=True,
help='Input pattern for VCF files to process.')
parser.add_argument(
'--input_pattern',
help=('Input pattern for VCF files to process. Either this or the '
'--input_file flag must be provided, exclusively.'))
parser.add_argument(
'--input_file',
help=('File that contains the list of VCF file names to input. Either '
'this or the --input_pattern flag must be provided, exclusively. '
'Note that using --input_file rather than --input_pattern is slower '
'for inputs that contain fewer than 50k files.'))
parser.add_argument(
'--allow_malformed_records',
type='bool', default=False, nargs='?', const=True,
@@ -113,16 +118,7 @@ def validate(self, parsed_args):
raise ValueError('Both --infer_headers and --representative_header_file '
'are passed! Please double check and choose at most one '
'of them.')
try:
# Gets at most one pattern match result of type `filesystems.MatchResult`.
first_match = filesystems.FileSystems.match(
[parsed_args.input_pattern], [1])[0]
if not first_match.metadata_list:
raise ValueError('Input pattern {} did not match any files.'.format(
parsed_args.input_pattern))
except filesystem.BeamIOError:
raise ValueError('Invalid or inaccessible input pattern {}.'.format(
parsed_args.input_pattern))
_validate_inputs(parsed_args)


class AvroWriteOptions(VariantTransformsOptions):
@@ -477,9 +473,16 @@ class PreprocessOptions(VariantTransformsOptions):

def add_arguments(self, parser):
# type: (argparse.ArgumentParser) -> None
parser.add_argument('--input_pattern',
required=True,
help='Input pattern for VCF files to process.')
parser.add_argument(
'--input_pattern',
help=('Input pattern for VCF files to process. Either this or the '
'--input_file flag must be provided, exclusively.'))
parser.add_argument(
'--input_file',
help=('File that contains the list of VCF file names to input. Either '
'this or the --input_pattern flag must be provided, exclusively. '
'Note that using --input_file rather than --input_pattern is slower '
'for inputs that contain fewer than 50k files.'))
parser.add_argument(
'--report_all_conflicts',
type='bool', default=False, nargs='?', const=True,
@@ -501,6 +504,10 @@ def add_arguments(self, parser):
'generated if unspecified. Otherwise, please provide a local '
'path if run locally, or a cloud path if run on Dataflow.'))

def validate(self, parsed_args):
_validate_inputs(parsed_args)


class PartitionOptions(VariantTransformsOptions):
"""Options for partitioning Variant records."""

@@ -583,3 +590,10 @@ def add_arguments(self, parser):
'be the same as the BigQuery table, but it requires all '
'extracted variants to have the same call name ordering (usually '
'true for tables from single VCF file import).'))


def _validate_inputs(parsed_args):
if ((parsed_args.input_pattern and parsed_args.input_file) or
(not parsed_args.input_pattern and not parsed_args.input_file)):
raise ValueError('Exactly one of input_pattern and input_file has to be '
'provided.')
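
For illustration, a minimal sketch (not part of the diff) of how _validate_inputs behaves for the three possible flag combinations; the argparse.Namespace values below are hypothetical:

import argparse

# Exactly one input source set: validation passes silently.
ok = argparse.Namespace(input_pattern='gs://bucket/*.vcf', input_file=None)
_validate_inputs(ok)

# Neither flag set, or both flags set: rejected with the same error.
neither = argparse.Namespace(input_pattern=None, input_file=None)
both = argparse.Namespace(input_pattern='gs://bucket/*.vcf',
                          input_file='gs://bucket/vcf_list.txt')
for bad in (neither, both):
    try:
        _validate_inputs(bad)
    except ValueError as e:
        print(e)  # Exactly one of input_pattern and input_file has to be provided.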
55 changes: 50 additions & 5 deletions gcp_variant_transforms/options/variant_transform_options_test.py
@@ -26,6 +26,7 @@
from apitools.base.py import exceptions

from gcp_variant_transforms.options import variant_transform_options
from gcp_variant_transforms.testing import temp_dir


def make_args(options, args):
@@ -47,20 +48,37 @@ def _make_args(self, args):
# type: (List[str]) -> argparse.Namespace
return make_args(self._options, args)

def test_failure_for_conflicting_flags(self):
def test_no_inputs(self):
args = self._make_args([])
self.assertRaises(ValueError, self._options.validate, args)

def test_failure_for_conflicting_flags_inputs(self):
args = self._make_args(['--input_pattern', '*',
'--input_file', 'asd'])
self.assertRaises(ValueError, self._options.validate, args)

def test_failure_for_conflicting_flags_headers(self):
args = self._make_args(['--input_pattern', '*',
'--infer_headers',
'--representative_header_file', 'gs://some_file'])
self.assertRaises(ValueError, self._options.validate, args)

def test_failure_for_conflicting_flags_no_errors(self):
def test_failure_for_conflicting_flags_no_errors_with_pattern_input(self):
args = self._make_args(['--input_pattern', '*',
'--representative_header_file', 'gs://some_file'])
self._options.validate(args)

def test_failure_for_invalid_input_pattern(self):
args = self._make_args(['--input_pattern', 'nonexistent_file.vcf'])
self.assertRaises(ValueError, self._options.validate, args)
def test_failure_for_conflicting_flags_no_errors_with_file_input(self):
lines = ['./gcp_variant_transforms/testing/data/vcf/valid-4.0.vcf\n',
'./gcp_variant_transforms/testing/data/vcf/valid-4.0.vcf\n',
'./gcp_variant_transforms/testing/data/vcf/valid-4.0.vcf\n']
with temp_dir.TempDir() as tempdir:
filename = tempdir.create_temp_file(lines=lines)
args = self._make_args([
'--input_file',
filename,
'--representative_header_file', 'gs://some_file'])
self._options.validate(args)


class BigQueryWriteOptionsTest(unittest.TestCase):
@@ -151,3 +169,30 @@ def test_failure_for_invalid_vep_cache(self):
'--vep_image_uri', 'AN_IMAGE',
'--vep_cache_path', 'VEP_CACHE'])
self.assertRaises(ValueError, self._options.validate, args)


class PreprocessOptionsTest(unittest.TestCase):
"""Tests cases for the PreprocessOptions class."""

def setUp(self):
self._options = variant_transform_options.PreprocessOptions()

def _make_args(self, args):
# type: (List[str]) -> argparse.Namespace
return make_args(self._options, args)

def test_failure_for_conflicting_flags_inputs(self):
args = self._make_args(['--input_pattern', '*',
'--report_path', 'some_path',
'--input_file', 'asd'])
self.assertRaises(ValueError, self._options.validate, args)

def test_failure_for_conflicting_flags_no_errors(self):
args = self._make_args(['--input_pattern', '*',
'--report_path', 'some_path'])
self._options.validate(args)

def test_failure_for_conflicting_flags_no_errors_with_pattern_input(self):
args = self._make_args(['--input_pattern', '*',
'--report_path', 'some_path'])
self._options.validate(args)
63 changes: 54 additions & 9 deletions gcp_variant_transforms/pipeline_common.py
@@ -26,6 +26,7 @@

import apache_beam as beam
from apache_beam import pvalue # pylint: disable=unused-import
from apache_beam.io import filesystem
from apache_beam.io import filesystems
from apache_beam.options import pipeline_options
from apache_beam.runners.direct import direct_runner
@@ -66,18 +67,61 @@ def parse_args(argv, command_line_options):
for transform_options in options:
transform_options.validate(known_args)
_raise_error_on_invalid_flags(pipeline_args)
known_args.all_patterns = _get_all_patterns(
known_args.input_pattern, known_args.input_file)
return known_args, pipeline_args


def get_pipeline_mode(input_pattern, optimize_for_large_inputs=False):
# type: (str, bool) -> int
def _get_all_patterns(input_pattern, input_file):
# type: (str, str) -> List[str]
patterns = [input_pattern] if input_pattern else _get_file_names(input_file)

# Validate inputs.
try:
# Gets at most 1 pattern match result of type `filesystems.MatchResult`.
matches = filesystems.FileSystems.match(patterns, [1] * len(patterns))
for match in matches:
if not match.metadata_list:
if input_file:
raise ValueError(
'Input pattern {} from {} did not match any files.'.format(
match.pattern, input_file))
else:
raise ValueError(
'Input pattern {} did not match any files.'.format(match.pattern))
except filesystem.BeamIOError:
if input_file:
raise ValueError(
'Some patterns in {} are invalid or inaccessible.'.format(
input_file))
else:
raise ValueError('Invalid or inaccessible input pattern {}.'.format(
input_pattern))
return patterns
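
As a sketch of the two failure modes surfaced above (bucket paths hypothetical):

# A direct --input_pattern that matches nothing:
#   _get_all_patterns('gs://bucket/missing/*.vcf', None)
#   -> ValueError: Input pattern gs://bucket/missing/*.vcf did not match any files.
#
# A pattern listed inside an --input_file that matches nothing:
#   _get_all_patterns(None, 'gs://bucket/vcf_list.txt')
#   -> ValueError: Input pattern <pattern> from gs://bucket/vcf_list.txt did not match any files.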


def _get_file_names(input_file):
# type: (str) -> List[str]
"""Reads the input file and extracts list of patterns out of it."""
if not filesystems.FileSystems.exists(input_file):
raise ValueError('Input file {} doesn\'t exist.'.format(input_file))
with filesystems.FileSystems.open(input_file) as f:
contents = map(str.strip, f.readlines())
if not contents:
raise ValueError('Input file {} is empty.'.format(input_file))
return contents
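
As a sketch of the expected --input_file contents (names hypothetical), one pattern or file name per line:

# Contents of gs://my-bucket/vcf_list.txt:
#   gs://my-bucket/batch1/*.vcf
#   gs://my-bucket/batch2/sample-007.vcf
#
# _get_file_names('gs://my-bucket/vcf_list.txt') strips trailing newlines
# and returns:
#   ['gs://my-bucket/batch1/*.vcf', 'gs://my-bucket/batch2/sample-007.vcf']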


def get_pipeline_mode(all_patterns, optimize_for_large_inputs=False):
# type: (List[str], bool) -> int
"""Returns the mode the pipeline should operate in based on input size."""
if optimize_for_large_inputs:
if optimize_for_large_inputs or len(all_patterns) > 1:
return PipelineModes.LARGE

match_results = filesystems.FileSystems.match([input_pattern])
match_results = filesystems.FileSystems.match(all_patterns)
if not match_results:
raise ValueError('No files matched input_pattern: {}'.format(input_pattern))
raise ValueError(
'No files matched input_pattern: {}'.format(all_patterns[0]))

total_files = len(match_results[0].metadata_list)
if total_files > _LARGE_DATA_THRESHOLD:
@@ -87,15 +131,16 @@ def get_pipeline_mode(input_pattern, optimize_for_large_inputs=False):
return PipelineModes.SMALL
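
Summarizing the selection logic, an illustrative sketch (the threshold constants and the MEDIUM branch are assumptions, since they are folded out of this diff):

# get_pipeline_mode(['gs://b/a/*.vcf', 'gs://b/b/*.vcf'])  # >1 pattern
#   -> PipelineModes.LARGE
# get_pipeline_mode(['gs://b/*.vcf'], optimize_for_large_inputs=True)
#   -> PipelineModes.LARGE
# get_pipeline_mode(['gs://b/*.vcf'])  # single pattern: sized by file count
#   -> LARGE / MEDIUM / SMALL depending on _LARGE_DATA_THRESHOLD etc.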


def read_headers(pipeline, pipeline_mode, input_pattern):
# type: (beam.Pipeline, int, str) -> pvalue.PCollection
def read_headers(pipeline, pipeline_mode, all_patterns):
# type: (beam.Pipeline, int, List[str]) -> pvalue.PCollection
"""Creates an initial PCollection by reading the VCF file headers."""
if pipeline_mode == PipelineModes.LARGE:
headers = (pipeline
| beam.Create([input_pattern])
| beam.Create(all_patterns)
| vcf_header_io.ReadAllVcfHeaders())
else:
headers = pipeline | vcf_header_io.ReadVcfHeaders(input_pattern)
headers = pipeline | vcf_header_io.ReadVcfHeaders(all_patterns[0])

Reviewer comment:

Is this safe? What if all_patterns is empty? Are you just counting on this not being called unless there is at least one pattern in the list?

Collaborator (author) reply:

It should be. We had this discussion beforehand about whether to validate in multiple places or only at the beginning, and decided it isn't worth copying unreachable code all over the pipeline as long as the validation at the start is done right. So yes, I'm counting on all_patterns having been verified by this point, with every pattern in all_patterns matching at least one file in the filesystem.
return headers
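
Putting the pieces together, a minimal end-to-end sketch of the new multi-pattern flow (assuming a Beam pipeline object p, and that known_args also carries an optimize_for_large_inputs flag):

# parse_args populates known_args.all_patterns from either
# --input_pattern or --input_file, after validating the inputs.
known_args, pipeline_args = parse_args(argv, command_line_options)
mode = get_pipeline_mode(known_args.all_patterns,
                         known_args.optimize_for_large_inputs)
headers = read_headers(p, mode, known_args.all_patterns)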

