Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions infrahub_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def _build_ip_prefix_allocation_query(
if prefix_length:
input_data["prefix_length"] = prefix_length
if member_type:
if member_type not in ("prefix", "address"):
if member_type not in {"prefix", "address"}:
raise ValueError("member_type possible values are 'prefix' or 'address'")
input_data["member_type"] = member_type
if prefix_type:
Expand Down Expand Up @@ -956,7 +956,7 @@ async def execute_graphql(
try:
resp = await self._post(url=url, payload=payload, headers=headers, timeout=timeout)

if raise_for_error in (None, True):
if raise_for_error in {None, True}:
resp.raise_for_status()

retry = False
Expand All @@ -970,7 +970,7 @@ async def execute_graphql(
self.log.error(f"Unable to connect to {self.address} .. ")
raise
except httpx.HTTPStatusError as exc:
if exc.response.status_code in [401, 403]:
if exc.response.status_code in {401, 403}:
response = decode_json(response=exc.response)
errors = response.get("errors", [])
messages = [error.get("message") for error in errors]
Expand Down Expand Up @@ -1208,7 +1208,7 @@ async def query_gql_query(
timeout=timeout or self.default_timeout,
)

if raise_for_error in (None, True):
if raise_for_error in {None, True}:
resp.raise_for_status()

return decode_json(response=resp)
Expand Down Expand Up @@ -1817,7 +1817,7 @@ def execute_graphql(
try:
resp = self._post(url=url, payload=payload, headers=headers, timeout=timeout)

if raise_for_error in (None, True):
if raise_for_error in {None, True}:
resp.raise_for_status()

retry = False
Expand All @@ -1831,7 +1831,7 @@ def execute_graphql(
self.log.error(f"Unable to connect to {self.address} .. ")
raise
except httpx.HTTPStatusError as exc:
if exc.response.status_code in [401, 403]:
if exc.response.status_code in {401, 403}:
response = decode_json(response=exc.response)
errors = response.get("errors", [])
messages = [error.get("message") for error in errors]
Expand Down Expand Up @@ -2446,7 +2446,7 @@ def query_gql_query(
timeout=timeout or self.default_timeout,
)

if raise_for_error in (None, True):
if raise_for_error in {None, True}:
resp.raise_for_status()

return decode_json(response=resp)
Expand Down
4 changes: 2 additions & 2 deletions infrahub_sdk/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ async def generate_query_data_node(

if (
rel_schema.cardinality == RelationshipCardinality.MANY # type: ignore[union-attr]
and rel_schema.kind not in [RelationshipKind.ATTRIBUTE, RelationshipKind.PARENT] # type: ignore[union-attr]
and rel_schema.kind not in {RelationshipKind.ATTRIBUTE, RelationshipKind.PARENT} # type: ignore[union-attr]
and not (include and rel_name in include)
):
continue
Expand Down Expand Up @@ -1364,7 +1364,7 @@ def generate_query_data_node(

if (
rel_schema.cardinality == RelationshipCardinality.MANY # type: ignore[union-attr]
and rel_schema.kind not in [RelationshipKind.ATTRIBUTE, RelationshipKind.PARENT] # type: ignore[union-attr]
and rel_schema.kind not in {RelationshipKind.ATTRIBUTE, RelationshipKind.PARENT} # type: ignore[union-attr]
and not (include and rel_name in include)
):
continue
Expand Down
8 changes: 4 additions & 4 deletions infrahub_sdk/object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def get(self, identifier: str, tracker: str | None = None) -> str:
self.client.log.error(f"Unable to connect to {self.client.address} .. ")
raise
except httpx.HTTPStatusError as exc:
if exc.response.status_code in [401, 403]:
if exc.response.status_code in {401, 403}:
response = exc.response.json()
errors = response.get("errors")
messages = [error.get("message") for error in errors]
Expand All @@ -54,7 +54,7 @@ async def upload(self, content: str, tracker: str | None = None) -> dict[str, st
self.client.log.error(f"Unable to connect to {self.client.address} .. ")
raise
except httpx.HTTPStatusError as exc:
if exc.response.status_code in [401, 403]:
if exc.response.status_code in {401, 403}:
response = exc.response.json()
errors = response.get("errors")
messages = [error.get("message") for error in errors]
Expand All @@ -81,7 +81,7 @@ def get(self, identifier: str, tracker: str | None = None) -> str:
self.client.log.error(f"Unable to connect to {self.client.address} .. ")
raise
except httpx.HTTPStatusError as exc:
if exc.response.status_code in [401, 403]:
if exc.response.status_code in {401, 403}:
response = exc.response.json()
errors = response.get("errors")
messages = [error.get("message") for error in errors]
Expand All @@ -102,7 +102,7 @@ def upload(self, content: str, tracker: str | None = None) -> dict[str, str]:
self.client.log.error(f"Unable to connect to {self.client.address} .. ")
raise
except httpx.HTTPStatusError as exc:
if exc.response.status_code in [401, 403]:
if exc.response.status_code in {401, 403}:
response = exc.response.json()
errors = response.get("errors")
messages = [error.get("message") for error in errors]
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/protocols_generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, schema: dict[str, MainSchemaTypesAll]) -> None:
if not e.startswith("__")
and not e.endswith("__")
and e
not in ("TYPE_CHECKING", "CoreNode", "Optional", "Protocol", "Union", "annotations", "runtime_checkable")
not in {"TYPE_CHECKING", "CoreNode", "Optional", "Protocol", "Union", "annotations", "runtime_checkable"}
]

self.sorted_generics = self._sort_and_filter_models(self.generics, filters=["CoreNode"] + self.base_protocols)
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/pytest_plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def parse_user_provided_data(path: Path | None) -> Any:

if suffix and suffix == "json":
return ujson.loads(text)
if suffix in ("yml", "yaml"):
if suffix in {"yml", "yaml"}:
return yaml.safe_load(text)

return text
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/pytest_plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def pytest_sessionstart(session: Session) -> None:


def pytest_collect_file(parent: Collector | Item, file_path: Path) -> InfrahubYamlFile | None:
if file_path.suffix in [".yml", ".yaml"] and file_path.name.startswith("test_"):
if file_path.suffix in {".yml", ".yaml"} and file_path.name.startswith("test_"):
return InfrahubYamlFile.from_parent(parent, path=file_path)
return None

Expand Down
4 changes: 2 additions & 2 deletions infrahub_sdk/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,12 @@ def _validate_load_schema_response(response: httpx.Response) -> SchemaLoadRespon
hash=status["hash"], previous_hash=status["previous_hash"], warnings=status.get("warnings") or []
)

if response.status_code in [
if response.status_code in {
httpx.codes.BAD_REQUEST,
httpx.codes.UNPROCESSABLE_ENTITY,
httpx.codes.UNAUTHORIZED,
httpx.codes.FORBIDDEN,
]:
}:
return SchemaLoadResponse(errors=response.json())

response.raise_for_status()
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/spec/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def is_valid(self) -> bool:

@property
def is_reference(self) -> bool:
return self.format in [RelationshipDataFormat.ONE_REF, RelationshipDataFormat.MANY_REF]
return self.format in {RelationshipDataFormat.ONE_REF, RelationshipDataFormat.MANY_REF}

def get_context(self, value: Any) -> dict:
"""Return a dict to insert to the context if the relationship is mandatory"""
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def str_to_bool(value: str) -> bool:
if isinstance(value, bool):
return value

if isinstance(value, int) and value in [0, 1]:
if isinstance(value, int) and value in {0, 1}:
return bool(value)

if not isinstance(value, str):
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ ignore = [
"PLR0913", # Too many arguments in function definition
"PLR0917", # Too many positional arguments
"PLR2004", # Magic value used in comparison
"PLR6201", # Use a `set` literal when testing for membership
"PLR6301", # Method could be a function, class method, or static method
"PLW0603", # Using the global statement to update `SETTINGS` is discouraged
"PLW1641", # Object does not implement `__hash__` method
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_infrahub_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def test_get_generic(self, client: InfrahubClient, base_dataset) -> None:
async def test_get_generic_fragment(self, client: InfrahubClient, base_dataset) -> None:
nodes = await client.all(kind=TESTING_ANIMAL, fragment=True)
assert len(nodes)
assert nodes[0].typename in [TESTING_DOG, TESTING_CAT]
assert nodes[0].typename in {TESTING_DOG, TESTING_CAT}
assert nodes[0].breed.value is not None

async def test_get_related_nodes(self, client: InfrahubClient, base_dataset, person_ethan) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/sdk/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
# type: ignore[attr-defined]

async_node_methods = [
method for method in dir(InfrahubNode) if not method.startswith("_") and method not in ("hfid", "hfid_str")
method for method in dir(InfrahubNode) if not method.startswith("_") and method not in {"hfid", "hfid_str"}
]
sync_node_methods = [
method for method in dir(InfrahubNodeSync) if not method.startswith("_") and method not in ("hfid", "hfid_str")
method for method in dir(InfrahubNodeSync) if not method.startswith("_") and method not in {"hfid", "hfid_str"}
]

client_types = ["standard", "sync"]
Expand Down