diff --git a/_delphi_utils_python/DEVELOP.md b/_delphi_utils_python/DEVELOP.md
index 2407e29a8..20d41166c 100644
--- a/_delphi_utils_python/DEVELOP.md
+++ b/_delphi_utils_python/DEVELOP.md
@@ -54,3 +54,5 @@ When you are finished, the virtual environment can be deactivated and
 deactivate
 rm -r env
 ```
+## Releasing the module
+If you have made enough changes to warrant updating [the PyPI project](https://pypi.org/project/delphi-utils/), note that a release is currently published as part of merging from `main` to `prod`.
diff --git a/nchs_mortality/delphi_nchs_mortality/constants.py b/nchs_mortality/delphi_nchs_mortality/constants.py
index 800444e58..2bdd78419 100644
--- a/nchs_mortality/delphi_nchs_mortality/constants.py
+++ b/nchs_mortality/delphi_nchs_mortality/constants.py
@@ -25,8 +25,3 @@
     "prop"
 ]
 INCIDENCE_BASE = 100000
-
-# this is necessary as a delimiter in the f-string expressions we use to
-# construct detailed error reports
-# (https://www.python.org/dev/peps/pep-0498/#escape-sequences)
-NEWLINE = "\n"
diff --git a/nchs_mortality/delphi_nchs_mortality/pull.py b/nchs_mortality/delphi_nchs_mortality/pull.py
index 18bbfd59a..aef964168 100644
--- a/nchs_mortality/delphi_nchs_mortality/pull.py
+++ b/nchs_mortality/delphi_nchs_mortality/pull.py
@@ -5,11 +5,11 @@
 import numpy as np
 import pandas as pd
+from delphi_utils.geomap import GeoMapper
 from sodapy import Socrata
 
-from delphi_utils.geomap import GeoMapper
+from .constants import METRICS, RENAME
 
-from .constants import METRICS, RENAME, NEWLINE
 
 
 def standardize_columns(df):
     """Rename columns to comply with a standard set.
@@ -85,16 +85,15 @@ def pull_nchs_mortality_data(socrata_token: str, test_file: Optional[str] = None
     try:
         df = df.astype(type_dict)
     except KeyError as exc:
-        raise ValueError(f"""
+        raise ValueError(
+            f"""
 Expected column(s) missed, The dataset schema may have changed. Please investigate and amend the code.
-Columns needed: -{NEWLINE.join(type_dict.keys())} - -Columns available: -{NEWLINE.join(df.columns)} -""") from exc +expected={''.join(type_dict.keys())} +received={''.join(df.columns)} +""" + ) from exc df = df[keep_columns + ["timestamp", "state"]].set_index("timestamp") diff --git a/nwss_wastewater/delphi_nwss/constants.py b/nwss_wastewater/delphi_nwss/constants.py index 5e8eb2aeb..648e44708 100644 --- a/nwss_wastewater/delphi_nwss/constants.py +++ b/nwss_wastewater/delphi_nwss/constants.py @@ -12,18 +12,33 @@ SIGNALS = ["pcr_conc_smoothed"] METRIC_SIGNALS = ["detect_prop_15d", "percentile", "ptc_15d"] -METRIC_DATES = ["date_start", "date_end"] -SAMPLE_SITE_NAMES = { - "wwtp_jurisdiction": "category", - "wwtp_id": int, - "reporting_jurisdiction": "category", - "sample_location": "category", - "county_names": "category", - "county_fips": "category", - "population_served": float, - "sampling_prior": bool, - "sample_location_specify": float, +PROVIDER_NORMS = { + "CDC_VERILY": ("flow-population", "microbial"), + "NWSS": ("flow-population", "microbial"), + "WWS": ("microbial",), } -SIG_DIGITS = 7 -NEWLINE = "\n" +SIG_DIGITS = 4 + +TYPE_DICT = {key: float for key in SIGNALS} +TYPE_DICT.update({"timestamp": "datetime64[ns]"}) +TYPE_DICT_METRIC = {key: float for key in METRIC_SIGNALS} +TYPE_DICT_METRIC.update({key: "datetime64[ns]" for key in ["date_start", "date_end"]}) +# Sample site names +TYPE_DICT_METRIC.update( + { + "wwtp_jurisdiction": "category", + "wwtp_id": int, + "reporting_jurisdiction": "category", + "sample_location": "category", + "county_names": "category", + "county_fips": "category", + "population_served": float, + "sampling_prior": bool, + "sample_location_specify": float, + } +) + +SOURCE_URL = "data.cdc.gov" +CONCENTRATION_TABLE_ID = "g653-rqe2" +METRIC_TABLE_ID = "2ew6-ywp6" diff --git a/nwss_wastewater/delphi_nwss/pull.py b/nwss_wastewater/delphi_nwss/pull.py index f4b781e12..d9fd9921c 100644 --- a/nwss_wastewater/delphi_nwss/pull.py +++ b/nwss_wastewater/delphi_nwss/pull.py @@ -6,12 +6,15 @@ from sodapy import Socrata from .constants import ( - SIGNALS, METRIC_SIGNALS, - METRIC_DATES, - SAMPLE_SITE_NAMES, + PROVIDER_NORMS, SIG_DIGITS, - NEWLINE, + SIGNALS, + TYPE_DICT, + TYPE_DICT_METRIC, + SOURCE_URL, + CONCENTRATION_TABLE_ID, + METRIC_TABLE_ID, ) @@ -34,47 +37,86 @@ def sig_digit_round(value, n_digits): return result -def construct_typedicts(): - """Create the type conversion dictionary for both dataframes.""" - # basic type conversion - type_dict = {key: float for key in SIGNALS} - type_dict["timestamp"] = "datetime64[ns]" - # metric type conversion - signals_dict_metric = {key: float for key in METRIC_SIGNALS} - metric_dates_dict = {key: "datetime64[ns]" for key in METRIC_DATES} - type_dict_metric = {**metric_dates_dict, **signals_dict_metric, **SAMPLE_SITE_NAMES} - return type_dict, type_dict_metric - - -def warn_string(df, type_dict): - """Format the warning string.""" - return f""" +def convert_df_type(df, type_dict, logger): + """Convert types and warn if there are unexpected columns.""" + try: + df = df.astype(type_dict) + except KeyError as exc: + raise KeyError( + f""" Expected column(s) missed, The dataset schema may have changed. Please investigate and amend the code. 
-Columns needed:
-{NEWLINE.join(sorted(type_dict.keys()))}
-
-Columns available:
-{NEWLINE.join(sorted(df.columns))}
+expected={''.join(sorted(type_dict.keys()))}
+received={''.join(sorted(df.columns))}
 """
+        ) from exc
+    if new_columns := set(df.columns) - set(type_dict.keys()):
+        logger.info("New columns found in NWSS dataset.", new_columns=new_columns)
+    return df
+
+
+def reformat(df, df_metric):
+    """Combine df_metric and df.
-def add_population(df, df_metric):
-    """Add the population column from df_metric to df, and rename some columns."""
+
+    Move population and METRIC_SIGNALS columns from df_metric to df, and rename
+    `date_end` to `timestamp`.
+    """
     # drop unused columns from df_metric
-    df_population = df_metric.loc[:, ["key_plot_id", "date_start", "population_served"]]
+    df_metric_core = df_metric.loc[
+        :, ["key_plot_id", "date_end", "population_served", *METRIC_SIGNALS]
+    ]
     # get matching keys
-    df_population = df_population.rename(columns={"date_start": "timestamp"})
-    df_population = df_population.set_index(["key_plot_id", "timestamp"])
+    df_metric_core = df_metric_core.rename(columns={"date_end": "timestamp"})
+    df_metric_core = df_metric_core.set_index(["key_plot_id", "timestamp"])
     df = df.set_index(["key_plot_id", "timestamp"])
+    df = df.sort_index()
-    df = df.join(df_population)
+    df = df.join(df_metric_core)
     df = df.reset_index()
     return df
 
 
-def pull_nwss_data(socrata_token: str):
+def add_identifier_columns(df):
+    """Parse `key_plot_id` to create several key columns.
+
+    `key_plot_id` is an underscore-delimited string containing, among other
+    things, the provider, the two-letter state abbreviation, and the wwtp_id
+    (e.g. "CDC_VERILY_al_<wwtp_id>").
+    We extract `state` and `provider` from it, and combine `provider` with the
+    `normalization` column to form `signal_name`, which identifies the signal.
+    """
+    df = df.copy()
+    # two word characters surrounded by underscores; for example, it matches
+    # "_al_" and returns just the two letters "al"
+    df["state"] = df.key_plot_id.str.extract(r"_(\w\w)_")
+    # anything followed by the state pattern described just above;
+    # for example, "CDC_VERILY_al" pulls out "CDC_VERILY"
+    df["provider"] = df.key_plot_id.str.extract(r"(.*)_[a-z]{2}_")
+    df["signal_name"] = df.provider + "_" + df.normalization
+    return df
+
+
+def check_expected_signals(df):
+    """Make sure that there aren't any new signals that we need to add."""
+    # TODO: compare this against the existing column name checker
+    # TODO: add a note about how to handle this error
+    unique_provider_norms = (
+        df[["provider", "normalization"]]
+        .drop_duplicates()
+        .sort_values(["provider", "normalization"])
+        .reset_index(drop=True)
+    )
+    for provider, normalization in zip(
+        unique_provider_norms["provider"], unique_provider_norms["normalization"]
+    ):
+        if normalization not in PROVIDER_NORMS.get(provider, ()):
+            raise ValueError(
+                "There are new providers and/or norms. "
+                f"The full new set is\n{unique_provider_norms}"
+            )
+
+
+def pull_nwss_data(socrata_token: str, logger):
     """Pull the latest NWSS Wastewater data, and conforms it into a dataset.
 
     The output dataset has:
@@ -87,48 +129,56 @@
     ----------
     socrata_token: str
         My App Token for pulling the NWSS data (could be the same as the nchs data)
-    test_file: Optional[str]
-        When not null, name of file from which to read test data
+    logger: the structured logger
 
     Returns
     -------
     pd.DataFrame
         Dataframe as described above.
""" - # concentration key types - type_dict, type_dict_metric = construct_typedicts() - # Pull data from Socrata API - client = Socrata("data.cdc.gov", socrata_token) - results_concentration = client.get("g653-rqe2", limit=10**10) - results_metric = client.get("2ew6-ywp6", limit=10**10) + client = Socrata(SOURCE_URL, token) + results_concentration = client.get(CONCENTRATION_TABLE_ID, limit=10**10) + results_metric = client.get(METRIC_TABLE_ID, limit=10**10) df_metric = pd.DataFrame.from_records(results_metric) df_concentration = pd.DataFrame.from_records(results_concentration) df_concentration = df_concentration.rename(columns={"date": "timestamp"}) - try: - df_concentration = df_concentration.astype(type_dict) - except KeyError as exc: - raise ValueError(warn_string(df_concentration, type_dict)) from exc + # Schema checks. + df_concentration = convert_df_type(df_concentration, TYPE_DICT, logger) + df_metric = convert_df_type(df_metric, TYPE_DICT_METRIC, logger) - try: - df_metric = df_metric.astype(type_dict_metric) - except KeyError as exc: - raise ValueError(warn_string(df_metric, type_dict_metric)) from exc + # Drop sites without a normalization scheme. + df = df_concentration[~df_concentration["normalization"].isna()] - # pull 2 letter state labels out of the key_plot_id labels - df_concentration["state"] = df_concentration.key_plot_id.str.extract(r"_(\w\w)_") + # Pull 2 letter state labels out of the key_plot_id labels. + df = add_identifier_columns(df) + # move population and metric signals over to df + df = reformat(df, df_metric) # round out some of the numeric noise that comes from smoothing - df_concentration[SIGNALS[0]] = sig_digit_round( - df_concentration[SIGNALS[0]], SIG_DIGITS - ) - - df_concentration = add_population(df_concentration, df_metric) - # if there are population NA's, assume the previous value is accurate (most - # likely introduced by dates only present in one and not the other; even - # otherwise, best to assume some value rather than break the data) - df_concentration.population_served = df_concentration.population_served.ffill() - - keep_columns = ["timestamp", "state", "population_served"] - return df_concentration[SIGNALS + keep_columns] + for signal in [*SIGNALS, *METRIC_SIGNALS]: + df[signal] = sig_digit_round(df[signal], SIG_DIGITS) + + # For each location, fill missing population values with a previous + # population value. + # Missing population values seem to be introduced by dates present in only + # one of the two (concentration and metric) datastes. This `ffill` approach + # assumes that the population on a previous date is still accurate. However, + # population served by a given sewershed can and does change over time. The + # effect is presumably minimal since contiguous dates with missing + # population should be limited in length such that incorrect + # population values are quickly corrected. + df.population_served = df.population_served.groupby(by = ["key_plot_id"]).ffill() + check_expected_signals(df) + + keep_columns = [ + *SIGNALS, + *METRIC_SIGNALS, + "timestamp", + "state", + "population_served", + "normalization", + "provider", + ] + return df[keep_columns] diff --git a/nwss_wastewater/delphi_nwss/run.py b/nwss_wastewater/delphi_nwss/run.py index 378849ba5..d236bca6a 100644 --- a/nwss_wastewater/delphi_nwss/run.py +++ b/nwss_wastewater/delphi_nwss/run.py @@ -2,7 +2,7 @@ """Functions to call when running the function. 
This module should contain a function called `run_module`, that is executed -when the module is run with `python -m MODULE_NAME`. `run_module`'s lone argument should be a +when the module is run with `python -m delphi_nwss`. `run_module`'s lone argument should be a nested dictionary of parameters loaded from the params.json file. We expect the `params` to have the following structure: - "common": @@ -16,75 +16,31 @@ `delphi_utils.add_prefix()` - "test_file" (optional): str, name of file from which to read test data - "socrata_token": str, authentication for upstream data pull - - "archive" (optional): if provided, output will be archived with S3 - - "aws_credentials": Dict[str, str], AWS login credentials (see S3 documentation) - - "bucket_name: str, name of S3 bucket to read/write - - "cache_dir": str, directory of locally cached data """ + import time from datetime import datetime import numpy as np -import pandas as pd -from delphi_utils import S3ArchiveDiffer, get_structured_logger, create_export_csv +from delphi_utils import ( + GeoMapper, + get_structured_logger, + create_export_csv, +) from delphi_utils.nancodes import add_default_nancodes -from .constants import GEOS, SIGNALS +from .constants import GEOS, METRIC_SIGNALS, PROVIDER_NORMS, SIGNALS from .pull import pull_nwss_data -def sum_all_nan(x): - """Return a normal sum unless everything is NaN, then return that.""" - all_nan = np.isnan(x).all() - if all_nan: - return np.nan - return np.nansum(x) - - -def generate_weights(df, column_aggregating="pcr_conc_smoothed"): - """ - Weigh column_aggregating by population. - - generate the relevant population amounts, and create a weighted but - unnormalized column, derived from `column_aggregating` - """ - # set the weight of places with na's to zero - df[f"relevant_pop_{column_aggregating}"] = ( - df["population_served"] * df[column_aggregating].notna() - ) - # generate the weighted version - df[f"weighted_{column_aggregating}"] = ( - df[column_aggregating] * df[f"relevant_pop_{column_aggregating}"] - ) - return df - - -def weighted_state_sum(df: pd.DataFrame, geo: str, sensor: str): - """Sum sensor, weighted by population for non NA's, grouped by state.""" - agg_df = df.groupby(["timestamp", geo]).agg( - {f"relevant_pop_{sensor}": "sum", f"weighted_{sensor}": sum_all_nan} - ) - agg_df["val"] = agg_df[f"weighted_{sensor}"] / agg_df[f"relevant_pop_{sensor}"] - agg_df = agg_df.reset_index() - agg_df = agg_df.rename(columns={"state": "geo_id"}) - return agg_df - - -def weighted_nation_sum(df: pd.DataFrame, sensor: str): - """Sum sensor, weighted by population for non NA's.""" - agg_df = df.groupby("timestamp").agg( - {f"relevant_pop_{sensor}": "sum", f"weighted_{sensor}": sum_all_nan} - ) - agg_df["val"] = agg_df[f"weighted_{sensor}"] / agg_df[f"relevant_pop_{sensor}"] - agg_df = agg_df.reset_index() - agg_df["geo_id"] = "us" - return agg_df - - def add_needed_columns(df, col_names=None): """Short util to add expected columns not found in the dataset.""" if col_names is None: col_names = ["se", "sample_size"] + else: + assert "geo_value" not in col_names + assert "time_value" not in col_names + assert "value" not in col_names for col_name in col_names: df[col_name] = np.nan @@ -125,39 +81,49 @@ def run_module(params): ) export_dir = params["common"]["export_dir"] socrata_token = params["indicator"]["socrata_token"] - if "archive" in params: - daily_arch_diff = S3ArchiveDiffer( - params["archive"]["cache_dir"], - export_dir, - params["archive"]["bucket_name"], - "nchs_mortality", - 
params["archive"]["aws_credentials"], - ) - daily_arch_diff.update_cache() - run_stats = [] ## build the base version of the signal at the most detailed geo level you can get. ## compute stuff here or farm out to another function or file - df_pull = pull_nwss_data(socrata_token) - ## aggregate - for sensor in SIGNALS: - df = df_pull.copy() - # add weighed column - df = generate_weights(df, sensor) - - for geo in GEOS: - logger.info("Generating signal and exporting to CSV", metric=sensor) - if geo == "nation": - agg_df = weighted_nation_sum(df, sensor) - else: - agg_df = weighted_state_sum(df, geo, sensor) - # add se, sample_size, and na codes - agg_df = add_needed_columns(agg_df) - # actual export - dates = create_export_csv( - agg_df, geo_res=geo, export_dir=export_dir, sensor=sensor - ) - if len(dates) > 0: - run_stats.append((max(dates), len(dates))) + df_pull = pull_nwss_data(socrata_token, logger) + geomapper = GeoMapper() + # iterate over the providers and the normalizations that they specifically provide + for provider, normalizations in PROVIDER_NORMS.items(): + for normalization in normalizations: + # copy by only taking the relevant subsection + df_prov_norm = df_pull[ + (df_pull.provider == provider) + & (df_pull.normalization == normalization) + ] + df_prov_norm = df_prov_norm.drop(["provider", "normalization"], axis=1) + for sensor in [*SIGNALS, *METRIC_SIGNALS]: + full_sensor_name = sensor + "_" + provider + "_" + normalization + for geo in GEOS: + logger.info( + "Generating signal and exporting to CSV", + metric=full_sensor_name, + ) + if geo == "nation": + df_prov_norm["nation"] = "us" + agg_df = geomapper.aggregate_by_weighted_sum( + df_prov_norm, + geo, + sensor, + "timestamp", + "population_served", + ) + agg_df = agg_df.rename( + columns={geo: "geo_id", f"weighted_{sensor}": "val"} + ) + # add se, sample_size, and na codes + agg_df = add_needed_columns(agg_df) + # actual export + dates = create_export_csv( + agg_df, + geo_res=geo, + export_dir=export_dir, + sensor=full_sensor_name, + ) + if len(dates) > 0: + run_stats.append((max(dates), len(dates))) ## log this indicator run logging(start_time, run_stats, logger) diff --git a/nwss_wastewater/tests/test_pull.py b/nwss_wastewater/tests/test_pull.py index 8a2edbd23..273f6e311 100644 --- a/nwss_wastewater/tests/test_pull.py +++ b/nwss_wastewater/tests/test_pull.py @@ -1,20 +1,12 @@ -from datetime import datetime, date -import json -from unittest.mock import patch -import tempfile -import os -import time -from datetime import datetime - import pandas as pd import pandas.api.types as ptypes from delphi_nwss.pull import ( - construct_typedicts, + add_identifier_columns, sig_digit_round, - add_population, - warn_string, + reformat, ) +from delphi_nwss.constants import TYPE_DICT, TYPE_DICT_METRIC import numpy as np @@ -29,32 +21,10 @@ def test_sig_digit(): ).all() -def test_column_type_dicts(): - type_dict, type_dict_metric = construct_typedicts() - assert type_dict == {"pcr_conc_smoothed": float, "timestamp": "datetime64[ns]"} - assert type_dict_metric == { - "date_start": "datetime64[ns]", - "date_end": "datetime64[ns]", - "detect_prop_15d": float, - "percentile": float, - "ptc_15d": float, - "wwtp_jurisdiction": "category", - "wwtp_id": int, - "reporting_jurisdiction": "category", - "sample_location": "category", - "county_names": "category", - "county_fips": "category", - "population_served": float, - "sampling_prior": bool, - "sample_location_specify": float, - } - - def test_column_conversions_concentration(): - 
type_dict, type_dict_metric = construct_typedicts() df = pd.read_csv("test_data/conc_data.csv", index_col=0) df = df.rename(columns={"date": "timestamp"}) - converted = df.astype(type_dict) + converted = df.astype(TYPE_DICT) assert all( converted.columns == pd.Index(["key_plot_id", "timestamp", "pcr_conc_smoothed", "normalization"]) @@ -64,9 +34,8 @@ def test_column_conversions_concentration(): def test_column_conversions_metric(): - type_dict, type_dict_metric = construct_typedicts() df = pd.read_csv("test_data/metric_data.csv", index_col=0) - converted = df.astype(type_dict_metric) + converted = df.astype(TYPE_DICT_METRIC) assert all( converted.columns == pd.Index( @@ -112,16 +81,14 @@ def test_column_conversions_metric(): def test_formatting(): - type_dict, type_dict_metric = construct_typedicts() df_metric = pd.read_csv("test_data/metric_data.csv", index_col=0) - df_metric = df_metric.astype(type_dict_metric) + df_metric = df_metric.astype(TYPE_DICT_METRIC) - type_dict, type_dict_metric = construct_typedicts() df = pd.read_csv("test_data/conc_data.csv", index_col=0) df = df.rename(columns={"date": "timestamp"}) - df = df.astype(type_dict) + df = df.astype(TYPE_DICT) - df_formatted = add_population(df, df_metric) + df_formatted = reformat(df, df_metric) assert all( df_formatted.columns @@ -132,6 +99,28 @@ def test_formatting(): "pcr_conc_smoothed", "normalization", "population_served", + "detect_prop_15d", + "percentile", + "ptc_15d", ] ) ) + + +def test_identifier_colnames(): + test_df = pd.read_csv("test_data/conc_data.csv", index_col=0) + test_df = add_identifier_columns(test_df) + assert all(test_df.state.unique() == ["ak", "tn"]) + assert all(test_df.provider.unique() == ["CDC_BIOBOT", "WWS"]) + # the only cases where the signal name is wrong is when normalization isn't defined + assert all( + (test_df.signal_name == test_df.provider + "_" + test_df.normalization) + | (test_df.normalization.isna()) + ) + assert all( + ( + test_df.signal_name.unique() + == ["CDC_BIOBOT_flow-population", np.nan, "WWS_microbial"] + ) + | (pd.isna(test_df.signal_name.unique())) + ) diff --git a/nwss_wastewater/tests/test_run.py b/nwss_wastewater/tests/test_run.py index 218e1f8d0..dc5740140 100644 --- a/nwss_wastewater/tests/test_run.py +++ b/nwss_wastewater/tests/test_run.py @@ -1,127 +1,23 @@ -from datetime import datetime, date -import json -from unittest.mock import patch -import tempfile -import os -import time -from datetime import datetime - import numpy as np import pandas as pd from pandas.testing import assert_frame_equal -from delphi_utils import S3ArchiveDiffer, get_structured_logger, create_export_csv, Nans - -from delphi_nwss.constants import GEOS, SIGNALS -from delphi_nwss.run import ( - generate_weights, - sum_all_nan, - weighted_state_sum, - weighted_nation_sum, -) - - -def test_sum_all_nan(): - """Check that sum_all_nan returns NaN iff everything is a NaN""" - no_nans = np.array([3, 5]) - assert sum_all_nan(no_nans) == 8 - partial_nan = np.array([np.nan, 3, 5]) - assert np.isclose(sum_all_nan([np.nan, 3, 5]), 8) - oops_all_nans = np.array([np.nan, np.nan]) - assert np.isnan(oops_all_nans).all() - - -def test_weight_generation(): - dataFrame = pd.DataFrame( - { - "a": [1, 2, 3, 4, np.nan], - "b": [5, 6, 7, 8, 9], - "population_served": [10, 5, 8, 1, 3], - } - ) - weighted = generate_weights(dataFrame, column_aggregating="a") - weighted_by_hand = pd.DataFrame( - { - "a": [1, 2, 3, 4, np.nan], - "b": [5, 6, 7, 8, 9], - "population_served": [10, 5, 8, 1, 3], - "relevant_pop_a": [10, 5, 8, 
1, 0], - "weighted_a": [10.0, 2 * 5.0, 3 * 8, 4.0 * 1, np.nan * 0], - } - ) - assert_frame_equal(weighted, weighted_by_hand) - # operations are in-place - assert_frame_equal(weighted, dataFrame) +from delphi_nwss.run import add_needed_columns -def test_weighted_state_sum(): - dataFrame = pd.DataFrame( - { - "state": [ - "al", - "al", - "ca", - "ca", - "nd", - ], - "timestamp": np.zeros(5), - "a": [1, 2, 3, 4, 12], - "b": [5, 6, 7, np.nan, np.nan], - "population_served": [10, 5, 8, 1, 3], - } - ) - weighted = generate_weights(dataFrame, column_aggregating="b") - agg = weighted_state_sum(weighted, "state", "b") - expected_agg = pd.DataFrame( - { - "timestamp": np.zeros(3), - "geo_id": ["al", "ca", "nd"], - "relevant_pop_b": [10 + 5, 8 + 0, 0], - "weighted_b": [5 * 10 + 6 * 5, 7 * 8 + 0, np.nan], - "val": [80 / 15, 56 / 8, np.nan], - } - ) - assert_frame_equal(agg, expected_agg) - - weighted = generate_weights(dataFrame, column_aggregating="a") - agg_a = weighted_state_sum(weighted, "state", "a") - expected_agg_a = pd.DataFrame( - { - "timestamp": np.zeros(3), - "geo_id": ["al", "ca", "nd"], - "relevant_pop_a": [10 + 5, 8 + 1, 3], - "weighted_a": [1 * 10 + 2 * 5, 3 * 8 + 1 * 4, 12 * 3], - "val": [20 / 15, 28 / 9, 36 / 3], - } - ) - assert_frame_equal(agg_a, expected_agg_a) - - -def test_weighted_nation_sum(): - dataFrame = pd.DataFrame( - { - "state": [ - "al", - "al", - "ca", - "ca", - "nd", - ], - "timestamp": np.hstack((np.zeros(3), np.ones(2))), - "a": [1, 2, 3, 4, 12], - "b": [5, 6, 7, np.nan, np.nan], - "population_served": [10, 5, 8, 1, 3], - } - ) - weighted = generate_weights(dataFrame, column_aggregating="a") - agg = weighted_nation_sum(weighted, "a") - expected_agg = pd.DataFrame( +def test_adding_cols(): + df = pd.DataFrame({"val": [0.0, np.nan], "timestamp": np.zeros(2)}) + modified = add_needed_columns(df) + modified + expected_df = pd.DataFrame( { - "timestamp": [0.0, 1], - "relevant_pop_a": [10 + 5 + 8, 1 + 3], - "weighted_a": [1 * 10 + 2 * 5 + 3 * 8, 1 * 4 + 3 * 12], - "val": [44 / 23, 40 / 4], - "geo_id": ["us", "us"], + "val": [0.0, np.nan], + "timestamp": np.zeros(2), + "se": [np.nan, np.nan], + "sample_size": [np.nan, np.nan], + "missing_val": [0, 5], + "missing_se": [1, 1], + "missing_sample_size": [1, 1], } ) - assert_frame_equal(agg, expected_agg) + assert_frame_equal(modified, expected_df)