Skip to content

Commit

Permalink
Change threshold in filter_by_name to 0.0...1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
gutzbenj committed Apr 25, 2024
1 parent 5c48ca6 commit c8540ac
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Expand Up @@ -6,6 +6,7 @@ Development

- Adjust column specs for DWD Observation station listing
- Maintain order during deduplication
- Change threshold in `filter_by_name` to 0.0...1.0

0.81.0 (09.04.2024)
*******************
Expand Down
4 changes: 2 additions & 2 deletions tests/ui/cli/test_cli_warming_stripes.py
Expand Up @@ -60,9 +60,9 @@ def test_warming_stripes_start_year_ge_end_year():
@pytest.mark.remote
def test_warming_stripes_wrong_name_threshold():
runner = CliRunner()
result = runner.invoke(cli, "warming_stripes --station 1048 --name_threshold 101")
result = runner.invoke(cli, "warming_stripes --station 1048 --name_threshold 1.01")
assert result.exit_code == 1
assert "Error: name_threshold must be more than 0 and less than or equal to 100" in result.stdout
assert "Error: name_threshold must be between 0.0 and 1.0" in result.stdout


@pytest.mark.remote
Expand Down
6 changes: 2 additions & 4 deletions tests/ui/test_restapi.py
Expand Up @@ -570,13 +570,11 @@ def test_warming_stripes_wrong_name_threshold(client):
"/api/warming_stripes",
params={
"name": "Dresden-Klotzsche",
"name_threshold": 101,
"name_threshold": 1.01,
},
)
assert response.status_code == 400
assert response.json() == {
"detail": "Query argument 'name_threshold' must be more than 0 and less than or equal to 100"
}
assert response.json() == {"detail": "Query argument 'name_threshold' must be between 0.0 and 1.0"}


@pytest.mark.remote
Expand Down
12 changes: 6 additions & 6 deletions wetterdienst/core/timeseries/request.py
Expand Up @@ -668,30 +668,30 @@ def filter_by_station_id(self, station_id: str | tuple[str, ...] | list[str]) ->
stations_filter=StationsFilter.BY_STATION_ID,
)

def filter_by_name(self, name: str, rank: int = 1, threshold: int = 90) -> StationsResult:
def filter_by_name(self, name: str, rank: int = 1, threshold: float = 0.9) -> StationsResult:
"""
Method to filter stations_result for station name using string comparison.
:param name: name of looked up station
:param rank: number of stations requested
:param threshold: threshold for string match 0...100
:param threshold: threshold for string match 0.0...1.0
:return: df with matched station
"""
rank = int(rank)
if rank <= 0:
raise ValueError("'rank' has to be at least 1.")

threshold = int(threshold)
if threshold < 0:
raise ValueError("threshold must be ge 0")
threshold = float(threshold)
if threshold < 0 or threshold > 1:
raise ValueError("threshold must be between 0.0 and 1.0")

df = self.all().df

station_match = process.extract(
query=name,
choices=df[Columns.NAME.value],
scorer=fuzz.token_set_ratio,
score_cutoff=threshold,
score_cutoff=threshold * 100,
)

if station_match:
Expand Down
4 changes: 2 additions & 2 deletions wetterdienst/ui/cli.py
Expand Up @@ -1137,7 +1137,7 @@ def radar(
@cloup.option("--name", type=click.STRING)
@cloup.option("--start_year", type=click.INT)
@cloup.option("--end_year", type=click.INT)
@cloup.option("--name_threshold", type=click.INT, default=80)
@cloup.option("--name_threshold", type=click.FLOAT, default=0.90)
@cloup.option("--show_title", type=click.BOOL, default=True)
@cloup.option("--show_years", type=click.BOOL, default=True)
@cloup.option("--show_data_availability", type=click.BOOL, default=True)
Expand All @@ -1154,7 +1154,7 @@ def warming_stripes(
name: str,
start_year: int,
end_year: int,
name_threshold: int,
name_threshold: float,
show_title: bool,
show_years: bool,
show_data_availability: bool,
Expand Down
8 changes: 4 additions & 4 deletions wetterdienst/ui/core.py
Expand Up @@ -433,7 +433,7 @@ def _plot_warming_stripes(
name: str | None = None,
start_year: int | None = None,
end_year: int | None = None,
name_threshold: int = 80,
name_threshold: float = 0.9,
show_title: bool = True,
show_years: bool = True,
show_data_availability: bool = True,
Expand All @@ -446,8 +446,8 @@ def _plot_warming_stripes(
if start_year and end_year:
if start_year >= end_year:
raise ValueError("start_year must be less than end_year")
if name_threshold <= 0 or name_threshold > 100:
raise ValueError("name_threshold must be more than 0 and less than or equal to 100")
if name_threshold < 0 or name_threshold > 1:
raise ValueError("name_threshold must be between 0.0 and 1.0")
if dpi <= 0:
raise ValueError("dpi must be more than 0")

Expand Down Expand Up @@ -545,7 +545,7 @@ def _thread_safe_plot_warming_stripes(
name: str | None = None,
start_year: int | None = None,
end_year: int | None = None,
name_threshold: int = 80,
name_threshold: float = 0.9,
show_title: bool = True,
show_years: bool = True,
show_data_availability: bool = True,
Expand Down
6 changes: 3 additions & 3 deletions wetterdienst/ui/restapi.py
Expand Up @@ -594,7 +594,7 @@ def warming_stripes(
name: Annotated[Optional[str], Query()] = None,
start_year: Annotated[Optional[int], Query()] = None,
end_year: Annotated[Optional[int], Query()] = None,
name_threshold: Annotated[Optional[int], Query()] = 80,
name_threshold: Annotated[Optional[float], Query()] = 0.9,
show_title: Annotated[bool, Query()] = True,
show_years: Annotated[bool, Query()] = True,
show_data_availability: Annotated[bool, Query()] = True,
Expand All @@ -620,10 +620,10 @@ def warming_stripes(
status_code=400,
detail="Query argument 'start_year' must be less than 'end_year'",
)
if name_threshold <= 0 or name_threshold > 100:
if name_threshold < 0 or name_threshold > 1:
raise HTTPException(
status_code=400,
detail="Query argument 'name_threshold' must be more than 0 and less than or equal to 100",
detail="Query argument 'name_threshold' must be between 0.0 and 1.0",
)
if fmt not in ["png", "jpg", "svg", "pdf"]:
raise HTTPException(
Expand Down

0 comments on commit c8540ac

Please sign in to comment.