Skip to content

Commit

Permalink
Fix: typing in pontos/nvd (#673)
Browse files Browse the repository at this point in the history
* Fix: typing in pontos/nvd

* Fix: typing

---------

Co-authored-by: Tom <tom.ricciuti@greenbone.net>
  • Loading branch information
Tacire and Tom committed Jun 22, 2023
1 parent c293684 commit a1a0f73
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 20 deletions.
10 changes: 6 additions & 4 deletions pontos/nvd/api.py
Expand Up @@ -19,7 +19,7 @@
from abc import ABC
from datetime import datetime, timezone
from types import TracebackType
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Type, Union

from httpx import AsyncClient, Response, Timeout

Expand All @@ -30,7 +30,7 @@
DEFAULT_TIMEOUT_CONFIG = Timeout(DEFAULT_TIMEOUT) # three minutes

Headers = Dict[str, str]
Params = Dict[str, str]
Params = Dict[str, Union[str, int]]

__all__ = (
"convert_camel_case",
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(
self._client = AsyncClient(http2=True, timeout=timeout)

if rate_limit:
self._rate_limit = 50 if token else 5
self._rate_limit: Optional[int] = 50 if token else 5
else:
self._rate_limit = None

Expand Down Expand Up @@ -170,4 +170,6 @@ async def __aexit__(
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
return await self._client.__aexit__(exc_type, exc_value, traceback)
return await self._client.__aexit__( # type: ignore
exc_type, exc_value, traceback
)
3 changes: 2 additions & 1 deletion pontos/nvd/cpe/__init__.py
Expand Up @@ -17,6 +17,7 @@

import asyncio
from argparse import ArgumentParser, Namespace
from typing import Callable

from pontos.nvd.cpe.api import CPEApi

Expand Down Expand Up @@ -64,7 +65,7 @@ def cpes_main() -> None:
main(parser, query_cpes)


def main(parser: ArgumentParser, func: callable) -> None:
def main(parser: ArgumentParser, func: Callable) -> None:
try:
args = parser.parse_args()
asyncio.run(func(args))
Expand Down
37 changes: 31 additions & 6 deletions pontos/nvd/cpe/api.py
Expand Up @@ -17,7 +17,17 @@


from datetime import datetime
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from types import TracebackType
from typing import (
Any,
AsyncIterator,
Dict,
Iterable,
List,
Optional,
Type,
Union,
)

from httpx import Timeout

Expand Down Expand Up @@ -153,7 +163,7 @@ async def cpes(
"""
total_results = None

params = {}
params: Dict[str, Union[str, int]] = {}
if last_modified_start_date:
params["lastModStartDate"] = format_date(last_modified_start_date)
if not last_modified_end_date:
Expand Down Expand Up @@ -191,11 +201,26 @@ async def cpes(
object_hook=convert_camel_case
)

results_per_page: int = data["results_per_page"]
total_results: int = data["total_results"]
products = data.get("products", [])
results_per_page: int = data["results_per_page"] # type: ignore
total_results: int = data["total_results"] # type: ignore
products: Iterable = data.get("products", []) # type: ignore

for product in products:
yield CPE.from_dict(product["cpe"])

start_index += results_per_page
if results_per_page is not None:
start_index += results_per_page

async def __aenter__(self) -> "CPEApi":
await super().__aenter__()
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
return await super().__aexit__( # type: ignore
exc_type, exc_value, traceback
)
3 changes: 2 additions & 1 deletion pontos/nvd/cve/__init__.py
Expand Up @@ -17,6 +17,7 @@

import asyncio
from argparse import ArgumentParser, Namespace
from typing import Callable

from pontos.nvd.cve.api import *

Expand Down Expand Up @@ -77,7 +78,7 @@ def cve_main() -> None:
main(parser, query_cve)


def main(parser: ArgumentParser, func: callable) -> None:
def main(parser: ArgumentParser, func: Callable) -> None:
try:
args = parser.parse_args()
asyncio.run(func(args))
Expand Down
44 changes: 36 additions & 8 deletions pontos/nvd/cve/api.py
Expand Up @@ -16,14 +16,25 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from datetime import datetime
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from types import TracebackType
from typing import (
Any,
AsyncIterator,
Dict,
Iterable,
List,
Optional,
Type,
Union,
)

from httpx import Timeout

from pontos.errors import PontosError
from pontos.nvd.api import (
DEFAULT_TIMEOUT_CONFIG,
NVDApi,
Params,
convert_camel_case,
format_date,
now,
Expand Down Expand Up @@ -162,9 +173,9 @@ async def cves(
async for cve in api.cves(keywords=["Mac OS X", "kernel"]):
print(cve.id)
"""
total_results = None
total_results: Optional[int] = None

params = {}
params: Params = {}
if last_modified_start_date:
params["lastModStartDate"] = format_date(last_modified_start_date)
if not last_modified_end_date:
Expand Down Expand Up @@ -219,7 +230,7 @@ async def cves(
if has_oval:
params["hasOval"] = ""

start_index = 0
start_index: int = 0
results_per_page = None

while total_results is None or start_index < total_results:
Expand All @@ -235,14 +246,17 @@ async def cves(
object_hook=convert_camel_case
)

results_per_page: int = data["results_per_page"]
total_results: int = data["total_results"]
vulnerabilities = data.get("vulnerabilities", [])
total_results = data["total_results"] # type: ignore
results_per_page: int = data["results_per_page"] # type: ignore
vulnerabilities: Iterable = data.get( # type: ignore
"vulnerabilities", []
)

for vulnerability in vulnerabilities:
yield CVE.from_dict(vulnerability["cve"])

start_index += results_per_page
if results_per_page is not None:
start_index += results_per_page

async def cve(self, cve_id: str) -> CVE:
"""
Expand Down Expand Up @@ -280,3 +294,17 @@ async def cve(self, cve_id: str) -> CVE:

vulnerability = vulnerabilities[0]
return CVE.from_dict(vulnerability["cve"])

async def __aenter__(self) -> "CVEApi":
await super().__aenter__()
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
return await super().__aexit__( # type: ignore
exc_type, exc_value, traceback
)

0 comments on commit a1a0f73

Please sign in to comment.