Merge pull request #1294: Read a subset of metadata columns
victorlin committed Feb 8, 2024 · 2 parents 9b31ad8 + b56f699 · commit 8678ae9
Showing 15 changed files with 546 additions and 117 deletions.
9 changes: 9 additions & 0 deletions CHANGES.md
@@ -2,12 +2,21 @@

## __NEXT__

### Features

* filter: Added a new option `--query-columns` that allows specifying what columns are used in `--query` along with the expected data types. If unspecified, automatic detection of columns and types is attempted. [#1294][] (@victorlin)
* `augur.io.read_metadata`: A new optional `columns` argument allows specifying a subset of columns to load; a short usage sketch follows this list. The default behavior still loads all columns, so this is not a breaking change. [#1294][] (@victorlin)
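A minimal sketch of the new `columns` argument (the file name and column names below are placeholders, not from this PR); the CLI counterpart pairs the two new pieces, e.g. `augur filter --query "coverage > 0.95" --query-columns coverage:float`:

```python
from augur.io.metadata import read_metadata

# Load only the columns needed downstream; the ID column is assumed to be
# included automatically when not listed (assumption, not verified here).
metadata = read_metadata("metadata.tsv", columns=["date", "region"])
print(metadata.columns.tolist())
```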

### Bug Fixes

* filter: The order of rows in `--output-metadata` and `--output-strains` now reflects the order in the original `--metadata`. [#1294][] (@victorlin)
* filter, frequencies, refine: Performance improvements to reading the input metadata file. [#1294][] (@victorlin)
  * For filter, this comes with increased writing times for `--output-metadata` and `--output-strains`. However, total I/O time still decreased during testing of this change.
* filter: Updated the help text of `--include` and `--include-where` to explicitly state that this can add strains that are missing an entry from `--sequences`. [#1389][] (@victorlin)
* filter: Fixed the summary messages to properly reflect force-inclusion of strains that are missing an entry from `--sequences`. [#1389][] (@victorlin)
* filter: Updated wording of summary messages. [#1389][] (@victorlin)

[#1294]: https://github.com/nextstrain/augur/pull/1294
[#1389]: https://github.com/nextstrain/augur/pull/1389

## 24.1.0 (30 January 2024)
6 changes: 6 additions & 0 deletions augur/filter/__init__.py
@@ -2,6 +2,7 @@
Filter and subsample a sequence set.
"""
from augur.dates import numeric_date_type, SUPPORTED_DATE_HELP_TEXT
from augur.filter.io import ACCEPTED_TYPES, column_type_pair
from augur.io.metadata import DEFAULT_DELIMITERS, DEFAULT_ID_COLUMNS, METADATA_DATE_COLUMN
from augur.types import EmptyOutputReportingMethod
from . import constants
@@ -28,6 +29,11 @@ def register_arguments(parser):
Uses Pandas Dataframe querying, see https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#indexing-query for syntax.
(e.g., --query "country == 'Colombia'" or --query "(country == 'USA' & (division == 'Washington'))")"""
)
metadata_filter_group.add_argument('--query-columns', type=column_type_pair, nargs="+", help=f"""
Use alongside --query to specify columns and data types in the format 'column:type', where type is one of ({','.join(ACCEPTED_TYPES)}).
Automatic type inference will be attempted on all unspecified columns used in the query.
Example: region:str coverage:float.
""")
metadata_filter_group.add_argument('--min-date', type=numeric_date_type, help=f"minimal cutoff for date, the cutoff date is inclusive; may be specified as: {SUPPORTED_DATE_HELP_TEXT}")
metadata_filter_group.add_argument('--max-date', type=numeric_date_type, help=f"maximal cutoff for date, the cutoff date is inclusive; may be specified as: {SUPPORTED_DATE_HELP_TEXT}")
metadata_filter_group.add_argument('--exclude-ambiguous-dates-by', choices=['any', 'day', 'month', 'year'],
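The `column_type_pair` helper imported above lives in `augur/filter/io.py`, which is not part of this excerpt. A hypothetical sketch of such an argparse `type=` callable, assuming it returns a `(column, type)` tuple (the downstream code builds a dict from these pairs) and that `ACCEPTED_TYPES` holds names like `int`, `float`, `bool`, and `str`; the real definitions may differ:

```python
from argparse import ArgumentTypeError

# Assumed set of accepted type names; the real ACCEPTED_TYPES may differ.
ACCEPTED_TYPES = {"int", "float", "bool", "str"}

def column_type_pair(argument: str):
    """Hypothetical sketch: parse a 'column:type' string into a (column, type) tuple."""
    try:
        column, dtype = argument.rsplit(":", 1)
    except ValueError:
        raise ArgumentTypeError(f"Column data types must be in the format 'column:type'. Got {argument!r}.")
    if dtype not in ACCEPTED_TYPES:
        raise ArgumentTypeError(f"Unsupported type {dtype!r}. Accepted types: {', '.join(sorted(ACCEPTED_TYPES))}.")
    return (column, dtype)
```

With `nargs="+"`, argparse then collects a list of these pairs, which `construct_filters` converts into the `column_types` dict passed to `filter_by_query` (see the `include_exclude_rules.py` changes below).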
81 changes: 19 additions & 62 deletions augur/filter/_run.py
@@ -15,13 +15,13 @@
DELIMITER as SEQUENCE_INDEX_DELIMITER,
)
from augur.io.file import open_file
from augur.io.metadata import InvalidDelimiter, read_metadata
from augur.io.metadata import InvalidDelimiter, Metadata, read_metadata
from augur.io.sequences import read_sequences, write_sequences
from augur.io.print import print_err
from augur.io.vcf import is_vcf as filename_is_vcf, write_vcf
from augur.types import EmptyOutputReportingMethod
from . import include_exclude_rules
from .io import cleanup_outputs, read_priority_scores
from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs
from .include_exclude_rules import apply_filters, construct_filters
from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, create_queues_by_group, get_groups_for_subsampling

@@ -133,16 +133,6 @@ def run(args):
random_generator = np.random.default_rng(args.subsample_seed)
priorities = defaultdict(random_generator.random)

# Setup metadata output. We track whether any records have been written to
# disk yet through the following variables, to control whether we write the
# metadata's header and open a new file for writing.
metadata_header = True
metadata_mode = "w"

# Setup strain output.
if args.output_strains:
output_strains = open(args.output_strains, "w")

# Setup logging.
output_log_writer = None
if args.output_log:
@@ -168,19 +158,23 @@ def run(args):
filter_counts = defaultdict(int)

try:
metadata_reader = read_metadata(
args.metadata,
delimiters=args.metadata_delimiters,
id_columns=args.metadata_id_columns,
chunk_size=args.metadata_chunk_size,
dtype="string",
)
metadata_object = Metadata(args.metadata, args.metadata_delimiters, args.metadata_id_columns)
except InvalidDelimiter:
raise AugurError(
f"Could not determine the delimiter of {args.metadata!r}. "
f"Valid delimiters are: {args.metadata_delimiters!r}. "
"This can be changed with --metadata-delimiters."
)
useful_metadata_columns = get_useful_metadata_columns(args, metadata_object.id_column, metadata_object.columns)

metadata_reader = read_metadata(
args.metadata,
delimiters=[metadata_object.delimiter],
columns=useful_metadata_columns,
id_columns=[metadata_object.id_column],
chunk_size=args.metadata_chunk_size,
dtype="string",
)
for metadata in metadata_reader:
duplicate_strains = (
set(metadata.index[metadata.index.duplicated()]) |
@@ -263,30 +257,6 @@ def run(args):
priorities[strain],
)

# Always write out strains that are force-included. Additionally, if
# we are not grouping, write out metadata and strains that passed
# filters so far.
force_included_strains_to_write = distinct_force_included_strains
if not group_by:
force_included_strains_to_write = force_included_strains_to_write | seq_keep

if args.output_metadata:
# TODO: wrap logic to write metadata into its own function
metadata.loc[list(force_included_strains_to_write)].to_csv(
args.output_metadata,
sep="\t",
header=metadata_header,
mode=metadata_mode,
)
metadata_header = False
metadata_mode = "a"

if args.output_strains:
# TODO: Output strains will no longer be ordered. This is a
# small breaking change.
for strain in force_included_strains_to_write:
output_strains.write(f"{strain}\n")

# In the worst case, we need to calculate sequences per group from the
# requested maximum number of sequences and the number of sequences per
# group. Then, we need to make a second pass through the metadata to find
@@ -323,6 +293,7 @@ def run(args):
metadata_reader = read_metadata(
args.metadata,
delimiters=args.metadata_delimiters,
columns=useful_metadata_columns,
id_columns=args.metadata_id_columns,
chunk_size=args.metadata_chunk_size,
dtype="string",
@@ -367,23 +338,6 @@ def run(args):
# Construct a data frame of records to simplify metadata output.
records.append(record)

if args.output_strains:
# TODO: Output strains will no longer be ordered. This is a
# small breaking change.
output_strains.write(f"{record.name}\n")

# Write records to metadata output, if requested.
if args.output_metadata and len(records) > 0:
records = pd.DataFrame(records)
records.to_csv(
args.output_metadata,
sep="\t",
header=metadata_header,
mode=metadata_mode,
)
metadata_header = False
metadata_mode = "a"

# Count and optionally log strains that were not included due to
# subsampling.
strains_filtered_by_subsampling = valid_strains - subsampled_strains
@@ -442,14 +396,17 @@ def run(args):
# Update the set of available sequence strains.
sequence_strains = observed_sequence_strains

if args.output_metadata or args.output_strains:
write_metadata_based_outputs(args.metadata, args.metadata_delimiters,
args.metadata_id_columns, args.output_metadata,
args.output_strains, valid_strains)

# Calculate the number of strains that don't exist in either metadata or
# sequences.
num_excluded_by_lack_of_metadata = 0
if sequence_strains:
num_excluded_by_lack_of_metadata = len(sequence_strains - metadata_strains)

if args.output_strains:
output_strains.close()

# Calculate the number of strains passed and filtered.
total_strains_passed = len(valid_strains)
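The chunked metadata and strain writes that were removed above are replaced by a single post-filtering call to `write_metadata_based_outputs`, which is defined in `augur/filter/io.py` and not shown in this diff. A plausible sketch, under the assumption that it re-reads the metadata once and writes only the rows in `valid_strains`, which is how the original row order is preserved (see CHANGES.md); this is not the actual implementation:

```python
import os
from augur.io.metadata import read_metadata

def write_metadata_based_outputs(input_metadata_path, delimiters, id_columns,
                                 output_metadata_path, output_strains_path,
                                 ids_to_write):
    """Hypothetical sketch: re-read the metadata and emit the kept rows
    (and strain names) in their original file order."""
    header = True
    with open(output_metadata_path or os.devnull, "w") as out_metadata, \
         open(output_strains_path or os.devnull, "w") as out_strains:
        for chunk in read_metadata(input_metadata_path, delimiters=delimiters,
                                   id_columns=id_columns, chunk_size=10000,
                                   dtype="string"):
            kept = chunk.loc[chunk.index.isin(ids_to_write)]
            if output_metadata_path:
                kept.to_csv(out_metadata, sep="\t", header=header)
                header = False
            if output_strains_path:
                out_strains.writelines(f"{strain}\n" for strain in kept.index)
```

This trades a second pass over the metadata (and the extra writing time noted in CHANGES.md) for output that follows the input order.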
101 changes: 78 additions & 23 deletions augur/filter/include_exclude_rules.py
@@ -1,9 +1,10 @@
import ast
import json
import operator
import re
import numpy as np
import pandas as pd
from typing import Any, Callable, Dict, List, Set, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

from augur.dates import is_date_ambiguous, get_numerical_dates
from augur.errors import AugurError
@@ -78,7 +79,7 @@ def filter_by_exclude(metadata, exclude_file) -> FilterFunctionReturn:
return set(metadata.index.values) - excluded_strains


def _parse_filter_query(query):
def parse_filter_query(query):
"""Parse an augur filter-style query and return the corresponding column,
operator, and value for the query.
@@ -98,9 +99,9 @@ def _parse_filter_query(query):
Examples
--------
>>> _parse_filter_query("property=value")
>>> parse_filter_query("property=value")
('property', <built-in function eq>, 'value')
>>> _parse_filter_query("property!=value")
>>> parse_filter_query("property!=value")
('property', <built-in function ne>, 'value')
"""
@@ -143,7 +144,7 @@ def filter_by_exclude_where(metadata, exclude_where) -> FilterFunctionReturn:
['strain1', 'strain2']
"""
column, op, value = _parse_filter_query(exclude_where)
column, op, value = parse_filter_query(exclude_where)
if column in metadata.columns:
# Apply a test operator (equality or inequality) to values from the
# column in the given query. This produces an array of boolean values we
@@ -164,7 +165,7 @@ def filter_by_exclude_where(metadata, exclude_where) -> FilterFunctionReturn:
return filtered


def filter_by_query(metadata: pd.DataFrame, query: str) -> FilterFunctionReturn:
def filter_by_query(metadata: pd.DataFrame, query: str, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn:
"""Filter metadata in the given pandas DataFrame with a query string and return
the strain names that pass the filter.
@@ -174,6 +175,8 @@ def filter_by_query(metadata: pd.DataFrame, query: str) -> FilterFunctionReturn:
Metadata indexed by strain name
query : str
Query string for the dataframe.
column_types : dict, optional
    Dict mapping column names to data types.
Examples
--------
@@ -187,22 +190,42 @@ def filter_by_query(metadata: pd.DataFrame, query: str) -> FilterFunctionReturn:
# Create a copy to prevent modification of the original DataFrame.
metadata_copy = metadata.copy()

# Support numeric comparisons in query strings.
#
# The built-in data type inference when loading the DataFrame does not
if column_types is None:
column_types = {}

# Set columns for type conversion.
variables = extract_variables(query)
if variables is not None:
columns = variables.intersection(metadata_copy.columns)
else:
# Column extraction failed. Apply type conversion to all columns.
columns = metadata_copy.columns

# If a type is not explicitly provided, try converting the column to numeric.
# This should cover most use cases, since one common problem is that the
# built-in data type inference when loading the DataFrame does not
# support nullable numeric columns, so numeric comparisons won't work on
# those columns. pd.to_numeric does proper conversion on those columns, and
# will not make any changes to columns with other values.
#
# TODO: Parse the query string and apply conversion only to columns used for
# numeric comparison. Pandas does not expose the API used to parse the query
# string internally, so this is non-trivial and requires a bit of
# reverse-engineering. Commit 2ead5b3e3306dc1100b49eb774287496018122d9 got
# halfway there but had issues so it was reverted.
#
# TODO: Try boolean conversion?
for column in metadata_copy.columns:
metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='ignore')
# those columns. pd.to_numeric does proper conversion on those columns,
# and will not make any changes to columns with other values.
for column in columns:
column_types.setdefault(column, 'numeric')

# Convert data types before applying the query.
for column, dtype in column_types.items():
if dtype == 'numeric':
metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='ignore')
elif dtype == 'int':
try:
metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='raise', downcast='integer')
except ValueError as e:
raise AugurError(f"Failed to convert value in column {column!r} to int. {e}")
elif dtype == 'float':
try:
metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='raise', downcast='float')
except ValueError as e:
raise AugurError(f"Failed to convert value in column {column!r} to float. {e}")
elif dtype == 'str':
metadata_copy[column] = metadata_copy[column].astype('str', errors='ignore')

try:
return set(metadata_copy.query(query).index.values)
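The comments above explain why columns loaded as strings must be converted before numeric comparisons. A minimal, self-contained illustration of the conversion step (column names and values are invented):

```python
import pandas as pd

# Strings stand in for metadata loaded as text; "coverage > 0.95" would
# compare str to float and raise without conversion.
metadata = pd.DataFrame(
    {"coverage": ["0.91", "0.99", "0.40"], "region": ["Asia", "Europe", "Africa"]},
    index=["strainA", "strainB", "strainC"],
)

# pd.to_numeric converts the numeric-looking column; with errors="ignore",
# a non-numeric column such as "region" is returned unchanged.
for column in metadata.columns:
    metadata[column] = pd.to_numeric(metadata[column], errors="ignore")

print(set(metadata.query("coverage > 0.95").index))  # {'strainB'}
```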
@@ -492,7 +515,7 @@ def force_include_where(metadata, include_where) -> FilterFunctionReturn:
set()
"""
column, op, value = _parse_filter_query(include_where)
column, op, value = parse_filter_query(include_where)

if column in metadata.columns:
# Apply a test operator (equality or inequality) to values from the
@@ -578,9 +601,13 @@ def construct_filters(args, sequence_index) -> Tuple[List[FilterOption], List[Fi

# Exclude strains by metadata, using pandas querying.
if args.query:
kwargs = {"query": args.query}
if args.query_columns:
kwargs["column_types"] = {column: dtype for column, dtype in args.query_columns}

exclude_by.append((
filter_by_query,
{"query": args.query}
kwargs
))

# Filter by ambiguous dates.
@@ -820,3 +847,31 @@ def _filter_kwargs_to_str(kwargs: FilterFunctionKwargs):
kwarg_list.append((key, value))

return json.dumps(kwarg_list)


def extract_variables(pandas_query: str):
"""Try extracting all variable names used in a pandas query string.
If successful, return the variable names as a set. Otherwise, nothing is returned.
Examples
--------
>>> extract_variables("var1 == 'value'")
{'var1'}
>>> sorted(extract_variables("var1 == 'value' & var2 == 10"))
['var1', 'var2']
>>> extract_variables("var1.str.startswith('prefix')")
{'var1'}
>>> extract_variables("this query is invalid")
"""
# Since Pandas' query grammar should be a subset of Python's, which uses the
# ast stdlib under the hood, we can try to parse queries with that as well.
# Errors may arise from invalid query syntax or any Pandas syntax not
# covered by Python (unlikely, but I'm not sure). In those cases, don't
# return anything.
try:
return set(node.id
for node in ast.walk(ast.parse(pandas_query))
if isinstance(node, ast.Name))
except:
return None
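
Putting the pieces together, a short sketch of how the parsed `--query-columns` pairs reach `filter_by_query` (strain names and values are invented; the import path follows this diff):

```python
import pandas as pd
from augur.filter.include_exclude_rules import filter_by_query

metadata = pd.DataFrame(
    {"region": ["Asia", "Europe"], "coverage": ["0.91", "0.99"]},
    index=["strainA", "strainB"],
).astype("string")

# Roughly equivalent to:
#   --query "region == 'Europe' & coverage > 0.95" --query-columns region:str coverage:float
passed = filter_by_query(
    metadata,
    query="region == 'Europe' & coverage > 0.95",
    column_types={"region": "str", "coverage": "float"},
)
print(passed)  # expected: {'strainB'}
```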
