Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add type annotations for pandas and networkx #896

Merged
merged 14 commits into from
Apr 3, 2024
Merged
53 changes: 41 additions & 12 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ pytest-unordered = "^0.5.2"
viztracer = "^0.16.1"
ruff = ">=0.1.12,<0.4.0"
pytest-sugar = ">=0.9.7,<1.1.0"
pandas-stubs = "^2.2.1.240316"
types-networkx = "^3.2.1.20240331"


[tool.poetry.scripts]
Expand Down Expand Up @@ -130,7 +132,6 @@ pythonpath = [".", "src", "test"]


[tool.mypy]
ignore_missing_imports = true # TODO: deactivate this
show_column_numbers = true
strict = true
exclude = [
Expand All @@ -149,6 +150,11 @@ exclude = [
]


[[tool.mypy.overrides]]
module = ["jsonpath_ng.*", "viztracer.*"] # For these packages, no type stubs are available yet
ignore_missing_imports = true


[tool.ruff]
line-length = 120
target-version = "py311"
Expand Down
23 changes: 13 additions & 10 deletions src/dsp_tools/commands/excel2json/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings
from typing import Any
from typing import Optional
from typing import cast

import jsonpath_ng.ext
import jsonschema
Expand Down Expand Up @@ -83,7 +84,7 @@ def excel2properties(
return props, True


def _read_check_property_df(excelfile: str) -> pd.DataFrame | None:
def _read_check_property_df(excelfile: str) -> pd.DataFrame:
sheets_df_dict = read_and_clean_all_sheets(excelfile=excelfile)
if len(sheets_df_dict) != 1:
msg = MoreThanOneSheetProblem("properties.xlsx", list(sheets_df_dict.keys())).execute_error_protocol()
Expand Down Expand Up @@ -166,7 +167,7 @@ def _add_optional_columns(df: pd.DataFrame) -> pd.DataFrame:
in_df_cols = set(df.columns)
if not optional_col_set.issubset(in_df_cols):
additional_col = list(optional_col_set.difference(in_df_cols))
additional_df = pd.DataFrame(columns=additional_col, index=df.index, data=pd.NA)
additional_df = pd.DataFrame(columns=additional_col, index=df.index, data=None)
jnussbaum marked this conversation as resolved.
Show resolved Hide resolved
df = pd.concat(objs=[df, additional_df], axis=1)
return df

Expand All @@ -181,13 +182,13 @@ def _check_missing_values_in_row(df: pd.DataFrame) -> None | list[MissingValuesI
if missing_gui_attributes is not None:
missing_dict.update(missing_gui_attributes)
if missing_dict:
missing_dict = get_wrong_row_numbers(wrong_row_dict=missing_dict, true_remains=True)
return [MissingValuesInRowProblem(column=col, row_numbers=row_nums) for col, row_nums in missing_dict.items()]
missing_dict_int = get_wrong_row_numbers(wrong_row_dict=missing_dict, true_remains=True)
jnussbaum marked this conversation as resolved.
Show resolved Hide resolved
return [MissingValuesInRowProblem(col, row_nums) for col, row_nums in missing_dict_int.items()]
else:
return None


def _check_compliance_gui_attributes(df: pd.DataFrame) -> dict[str, pd.Series] | None:
def _check_compliance_gui_attributes(df: pd.DataFrame) -> dict[str, pd.Series[bool]] | None:
mandatory_attributes = ["Spinbox", "List"]
mandatory_check = col_must_or_not_empty_based_on_other_col(
df=df,
Expand All @@ -208,15 +209,17 @@ def _check_compliance_gui_attributes(df: pd.DataFrame) -> dict[str, pd.Series] |
case None, None:
return None
case pd.Series(), pd.Series():
final_series = pd.Series(np.logical_or(mandatory_check, no_attribute_check)) # type: ignore[arg-type]
mandatory_check = cast("pd.Series[bool]", mandatory_check)
no_attribute_check = cast("pd.Series[bool]", no_attribute_check)
final_series: pd.Series[bool] = pd.Series(np.logical_or(mandatory_check, no_attribute_check))
case pd.Series(), None:
final_series = mandatory_check
final_series = cast("pd.Series[bool]", mandatory_check)
case None, pd.Series:
final_series = no_attribute_check
final_series = cast("pd.Series[bool]", no_attribute_check)
jnussbaum marked this conversation as resolved.
Show resolved Hide resolved
return {"gui_attributes": final_series}


def _row2prop(df_row: pd.Series, row_num: int, excelfile: str) -> dict[str, Any]:
def _row2prop(df_row: pd.Series[Any], row_num: int, excelfile: str) -> dict[str, Any]:
_property = {x: df_row[x] for x in mandatory_properties} | {
"labels": get_labels(df_row=df_row),
"super": [s.strip() for s in df_row["super"].split(",")],
Expand All @@ -239,7 +242,7 @@ def _row2prop(df_row: pd.Series, row_num: int, excelfile: str) -> dict[str, Any]


def _get_gui_attribute(
df_row: pd.Series,
df_row: pd.Series[Any],
row_num: int,
) -> dict[str, int | str | float] | InvalidExcelContentProblem | None:
if pd.isnull(df_row["gui_attributes"]):
Expand Down
6 changes: 4 additions & 2 deletions src/dsp_tools/commands/excel2json/resources.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import importlib.resources
import json
import warnings
Expand Down Expand Up @@ -96,7 +98,7 @@ def _find_validation_problem(


def _row2resource(
class_info_row: pd.Series,
class_info_row: pd.Series[Any],
class_df_with_cardinalities: pd.DataFrame,
) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -270,7 +272,7 @@ def _validate_excel_file(classes_df: pd.DataFrame, df_dict: dict[str, pd.DataFra
f"Please use {[f'label_{lang}' for lang in languages]}"
)
problems: list[Problem] = []
if missing_super_rows := [int(index) + 2 for index, row in classes_df.iterrows() if not check_notna(row["super"])]:
if missing_super_rows := [int(str(i)) + 2 for i, row in classes_df.iterrows() if not check_notna(row["super"])]:
jnussbaum marked this conversation as resolved.
Show resolved Hide resolved
problems.append(MissingValuesInRowProblem(column="super", row_numbers=missing_super_rows))
if duplicate_check := check_column_for_duplicate(classes_df, "name"):
problems.append(duplicate_check)
Expand Down
19 changes: 11 additions & 8 deletions src/dsp_tools/commands/excel2json/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Any
from unittest import mock

import numpy as np
Expand Down Expand Up @@ -117,7 +118,7 @@ def check_column_for_duplicate(df: pd.DataFrame, to_check_column: str) -> None |
return None


def check_required_values(df: pd.DataFrame, required_values_columns: list[str]) -> dict[str, pd.Series]:
def check_required_values(df: pd.DataFrame, required_values_columns: list[str]) -> dict[str, pd.Series[bool]]:
"""
If there are any empty cells in the column, it adds the column name and a boolean pd.Series to the dictionary.
If there are no empty cells, then it is not included in the dictionary.
Expand Down Expand Up @@ -155,7 +156,9 @@ def turn_bool_array_into_index_numbers(series: pd.Series[bool], true_remains: bo
return list(series[series].index)


def get_wrong_row_numbers(wrong_row_dict: dict[str, pd.Series], true_remains: bool = True) -> dict[str, list[int]]:
def get_wrong_row_numbers(
wrong_row_dict: dict[str, pd.Series[bool]], true_remains: bool = True
) -> dict[str, list[int]]:
"""
From the boolean pd.Series the index numbers of the True values are extracted.
The resulting list is the new value of the dictionary.
Expand All @@ -168,13 +171,13 @@ def get_wrong_row_numbers(wrong_row_dict: dict[str, pd.Series], true_remains: bo
Returns:
Dictionary with the column name as key and the row number as a list.
"""
wrong_row_dict = {
wrong_row_dict_int = {
jnussbaum marked this conversation as resolved.
Show resolved Hide resolved
k: turn_bool_array_into_index_numbers(series=v, true_remains=true_remains) for k, v in wrong_row_dict.items()
}
return {k: [x + 2 for x in v] for k, v in wrong_row_dict.items()}
return {k: [x + 2 for x in v] for k, v in wrong_row_dict_int.items()}


def get_labels(df_row: pd.Series) -> dict[str, str]:
def get_labels(df_row: pd.Series[Any]) -> dict[str, str]:
"""
This function takes a pd.Series which has "label_[language tag]" in the index.
If the value of the index is not pd.NA, the language tag and the value are added to a dictionary.
Expand All @@ -190,7 +193,7 @@ def get_labels(df_row: pd.Series) -> dict[str, str]:
return {lang: df_row[f"label_{lang}"] for lang in languages if df_row[f"label_{lang}"] is not pd.NA}


def get_comments(df_row: pd.Series) -> dict[str, str] | None:
def get_comments(df_row: pd.Series[Any]) -> dict[str, str] | None:
"""
This function takes a pd.Series which has "comment_[language tag]" in the index.
If the value of the index is not pd.NA, the language tag and the value are added to a dictionary.
Expand All @@ -207,7 +210,7 @@ def get_comments(df_row: pd.Series) -> dict[str, str] | None:
return comments or None


def find_one_full_cell_in_cols(df: pd.DataFrame, required_columns: list[str]) -> pd.Series | None:
def find_one_full_cell_in_cols(df: pd.DataFrame, required_columns: list[str]) -> pd.Series[bool] | None:
"""
This function takes a pd.DataFrame and a list of column names where at least one cell must have a value per row.
A pd.Series with boolean values is returned, True if any rows do not have a value in at least one column
Expand All @@ -234,7 +237,7 @@ def col_must_or_not_empty_based_on_other_col(
substring_colname: str,
check_empty_colname: str,
must_have_value: bool,
) -> pd.Series | None:
) -> pd.Series[bool] | None:
"""
It is presumed that the column "substring_colname" has no empty cells.
Based on the string content of the individual rows, which is specified in the "substring_list",
Expand Down
13 changes: 7 additions & 6 deletions src/dsp_tools/commands/excel2xml/excel2xml_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import warnings
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Optional
from typing import Union
Expand Down Expand Up @@ -166,7 +167,7 @@ def _convert_rows_to_xml(

def _append_bitstream_to_resource(
resource: etree._Element,
row: pd.Series,
row: pd.Series[Any],
row_number: int,
) -> etree._Element:
"""
Expand Down Expand Up @@ -210,7 +211,7 @@ def _append_bitstream_to_resource(

def _convert_resource_row_to_xml(
row_number: int,
row: pd.Series,
row: pd.Series[Any],
) -> etree._Element:
"""
Convert a resource-row to an XML resource element.
Expand Down Expand Up @@ -291,7 +292,7 @@ def _convert_resource_row_to_xml(


def _get_prop_function(
row: pd.Series,
row: pd.Series[Any],
resource_id: str,
) -> Callable[..., etree._Element]:
"""
Expand Down Expand Up @@ -328,7 +329,7 @@ def _get_prop_function(


def _convert_row_to_property_elements(
row: pd.Series,
row: pd.Series[Any],
max_num_of_props: int,
row_number: int,
resource_id: str,
Expand Down Expand Up @@ -399,7 +400,7 @@ def _convert_row_to_property_elements(

def _convert_property_row_to_xml(
row_number: int,
row: pd.Series,
row: pd.Series[Any],
max_num_of_props: int,
resource_id: str,
) -> etree._Element:
Expand Down Expand Up @@ -443,7 +444,7 @@ def _convert_property_row_to_xml(

def _create_property(
make_prop_function: Callable[..., etree._Element],
row: pd.Series,
row: pd.Series[Any],
property_elements: list[PropertyElement],
resource_id: str,
) -> etree._Element:
Expand Down
8 changes: 5 additions & 3 deletions src/dsp_tools/commands/project/create/project_validate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import importlib.resources
import json
from pathlib import Path
Expand Down Expand Up @@ -420,8 +422,8 @@ def _extract_cardinalities_from_project(
return cardinalities, dependencies


def _make_cardinality_dependency_graph(dependencies: dict[str, dict[str, list[str]]]) -> nx.MultiDiGraph:
graph = nx.MultiDiGraph()
def _make_cardinality_dependency_graph(dependencies: dict[str, dict[str, list[str]]]) -> nx.MultiDiGraph[Any]:
graph: nx.MultiDiGraph[Any] = nx.MultiDiGraph()
for start, cards in dependencies.items():
for edge, targets in cards.items():
for target in targets:
Expand All @@ -430,7 +432,7 @@ def _make_cardinality_dependency_graph(dependencies: dict[str, dict[str, list[st


def _find_circles_with_min_one_cardinality(
graph: nx.MultiDiGraph, cardinalities: dict[str, dict[str, str]], dependencies: dict[str, dict[str, list[str]]]
graph: nx.MultiDiGraph[Any], cardinalities: dict[str, dict[str, str]], dependencies: dict[str, dict[str, list[str]]]
) -> set[tuple[str, str]]:
errors: set[tuple[str, str]] = set()
circles = list(nx.algorithms.cycles.simple_cycles(graph))
Expand Down