Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional NVD API Improvements #937

Merged
merged 4 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pontos/nvd/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,14 @@ async def _next_iterator(self) -> "NVDResults":
return self

def __repr__(self) -> str:
return f'<{self.__class__.__name__} url="{self._url}" total_results={self._total_results} start_index={self._start_index} current_index={self._current_index} results_per_page={self._results_per_page}>'
return (
f"<{self.__class__.__name__} "
f'url="{self._url}" '
f"total_results={self._total_results} "
f"start_index={self._start_index} "
f"current_index={self._current_index} "
f"results_per_page={self._results_per_page}>"
)


class NVDApi(ABC):
Expand Down
23 changes: 20 additions & 3 deletions pontos/nvd/cpe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from argparse import ArgumentParser, Namespace
from typing import Callable

import httpx

from pontos.nvd.cpe.api import CPEApi

__all__ = ("CPEApi",)
Expand All @@ -32,9 +34,13 @@ async def query_cpe(args: Namespace) -> None:

async def query_cpes(args: Namespace) -> None:
async with CPEApi(token=args.token) as api:
async for cpe in api.cpes(
keywords=args.keywords, cpe_match_string=args.cpe_match_string
):
response = api.cpes(
keywords=args.keywords,
cpe_match_string=args.cpe_match_string,
request_results=args.number,
start_index=args.start,
)
async for cpe in response:
print(cpe)


Expand All @@ -61,6 +67,15 @@ def cpes_main() -> None:
help="Search for CPEs containing the keyword in their titles and "
"references.",
)
parser.add_argument(
"--number", "-n", metavar="N", help="Request only N CPEs", type=int
)
parser.add_argument(
"--start",
"-s",
help="Index of the first CPE to request.",
type=int,
)

main(parser, query_cpes)

Expand All @@ -71,3 +86,5 @@ def main(parser: ArgumentParser, func: Callable) -> None:
asyncio.run(func(args))
except KeyboardInterrupt:
pass
except httpx.HTTPStatusError as e:
print(f"HTTP Error {e.response.status_code}: {e.response.text}")
4 changes: 3 additions & 1 deletion pontos/nvd/cpe/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def cpes(
keywords: Optional[Union[List[str], str]] = None,
match_criteria_id: Optional[str] = None,
request_results: Optional[int] = None,
start_index: int = 0,
) -> NVDResults[CPE]:
"""
Get all CPEs for the provided arguments
Expand All @@ -159,6 +160,8 @@ def cpes(
string identified by its UUID.
request_results: Number of CPEs to download. Set to None (default)
to download all available CPEs.
start_index: Index of the first CPE to be returned. Useful only for
paginated requests that should not start at the first page.

Returns:
A NVDResponse for CPEs
Expand Down Expand Up @@ -202,7 +205,6 @@ def cpes(
if match_criteria_id:
params["matchCriteriaId"] = match_criteria_id

start_index = 0
results_per_page = (
request_results
if request_results and request_results < MAX_CPES_PER_PAGE
Expand Down
15 changes: 15 additions & 0 deletions pontos/nvd/cve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from argparse import ArgumentParser, Namespace
from typing import Callable

import httpx

from pontos.nvd.cve.api import CVEApi

__all__ = ("CVEApi",)
Expand All @@ -32,6 +34,8 @@ async def query_cves(args: Namespace) -> None:
cvss_v2_vector=args.cvss_v2_vector,
cvss_v3_vector=args.cvss_v3_vector,
source_identifier=args.source_identifier,
request_results=args.number,
start_index=args.start,
):
print(cve)

Expand Down Expand Up @@ -66,6 +70,15 @@ def cves_main() -> None:
help="Get all CVE information with the source identifier. For example: "
"cve@mitre.org",
)
parser.add_argument(
"--number", "-n", metavar="N", help="Request only N CVEs", type=int
)
parser.add_argument(
"--start",
"-s",
help="Index of the first CVE to request.",
type=int,
)

main(parser, query_cves)

Expand All @@ -84,3 +97,5 @@ def main(parser: ArgumentParser, func: Callable) -> None:
asyncio.run(func(args))
except KeyboardInterrupt:
pass
except httpx.HTTPStatusError as e:
print(f"HTTP Error {e.response.status_code}: {e.response.text}")
4 changes: 3 additions & 1 deletion pontos/nvd/cve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def cves(
has_kev: Optional[bool] = None,
has_oval: Optional[bool] = None,
request_results: Optional[int] = None,
start_index: int = 0,
) -> NVDResults[CVE]:
"""
Get all CVEs for the provided arguments
Expand Down Expand Up @@ -171,6 +172,8 @@ def cves(
transitioned to the Center for Internet Security (CIS).
request_results: Number of CVEs to download. Set to None (default)
to download all available CVEs.
start_index: Index of the first CVE to be returned. Useful only for
paginated requests that should not start at the first page.

Returns:
A NVDResponse for CVEs
Expand Down Expand Up @@ -249,7 +252,6 @@ def cves(
if has_oval:
params["hasOval"] = ""

start_index = 0
results_per_page = (
request_results
if request_results and request_results < MAX_CVES_PER_PAGE
Expand Down
14 changes: 13 additions & 1 deletion pontos/nvd/cve_changes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
async def query_changes(args: Namespace) -> None:
async with CVEChangesApi(token=args.token) as api:
async for cve in api.changes(
cve_id=args.cve_id, event_name=args.event_name
cve_id=args.cve_id,
event_name=args.event_name,
request_results=args.number,
start_index=args.start,
):
print(cve)

Expand All @@ -25,6 +28,15 @@ def parse_args() -> Namespace:
parser.add_argument(
"--event-name", help="Get all CVE associated with a specific event name"
)
parser.add_argument(
"--number", "-n", metavar="N", help="Request only N CPEs", type=int
)
parser.add_argument(
"--start",
"-s",
help="Index of the first CPE to request.",
type=int,
)
return parser.parse_args()


Expand Down
5 changes: 4 additions & 1 deletion pontos/nvd/cve_changes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def changes(
cve_id: Optional[str] = None,
event_name: Optional[Union[EventName, str]] = None,
request_results: Optional[int] = None,
start_index: int = 0,
) -> NVDResults[CVEChange]:
"""
Get all CVEs for the provided arguments
Expand All @@ -98,6 +99,9 @@ def changes(
event_name: Return all CVE changes with this event name.
request_results: Number of CVEs changes to download. Set to None
(default) to download all available CPEs.
start_index: Index of the first CVE change to be returned. Useful
only for paginated requests that should not start at the first
page.

Returns:
A NVDResponse for CVE changes
Expand Down Expand Up @@ -142,7 +146,6 @@ def changes(
if event_name:
params["eventName"] = event_name

start_index: int = 0
results_per_page = (
request_results
if request_results and request_results < MAX_CVE_CHANGES_PER_PAGE
Expand Down
83 changes: 83 additions & 0 deletions tests/nvd/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ async def test_items(self):
result = await anext(it)
self.assertEqual(result.value, 6)

with self.assertRaises(StopAsyncIteration):
await anext(it)

async def test_aiter(self):
response_mock = MagicMock(spec=Response)
response_mock.json.side_effect = [
Expand Down Expand Up @@ -238,6 +241,9 @@ async def test_aiter(self):
result = await anext(it)
self.assertEqual(result.value, 6)

with self.assertRaises(StopAsyncIteration):
await anext(it)

async def test_len(self):
response_mock = MagicMock(spec=Response)
response_mock.json.return_value = {
Expand Down Expand Up @@ -294,6 +300,9 @@ async def test_chunks(self):
results = await anext(it)
self.assertEqual([result.value for result in results], [4, 5, 6])

with self.assertRaises(StopAsyncIteration):
await anext(it)

async def test_json(self):
response_mock = MagicMock(spec=Response)
response_mock.json.side_effect = [
Expand Down Expand Up @@ -466,3 +475,77 @@ async def test_response_error(self):
"resultsPerPage": 3,
}
)

async def test_request_results_limit(self):
response_mock = MagicMock(spec=Response)
response_mock.json.side_effect = [
{
"values": [1, 2, 3, 4],
"total_results": 5,
"results_per_page": 4,
},
{
"values": [5],
"total_results": 5,
"results_per_page": 1,
},
]
api_mock = AsyncMock(spec=NVDApi)
api_mock._get.return_value = response_mock

nvd_results: NVDResults[Result] = NVDResults(
api_mock,
{},
result_func,
request_results=5,
)

json: dict[str, Any] = await nvd_results.json() # type: ignore
self.assertEqual(json["values"], [1, 2, 3, 4])
self.assertEqual(json["total_results"], 5)
self.assertEqual(json["results_per_page"], 4)

api_mock._get.assert_called_once_with(params={"startIndex": 0})
api_mock.reset_mock()

json: dict[str, Any] = await nvd_results.json() # type: ignore
self.assertEqual(json["values"], [5])
self.assertEqual(json["total_results"], 5)
self.assertEqual(json["results_per_page"], 1)

api_mock._get.assert_called_once_with(
params={"startIndex": 4, "resultsPerPage": 1}
)

async def test_repr(self):
response_mock = MagicMock(spec=Response)
response_mock.json.side_effect = [
{
"values": [1, 2, 3, 4],
"total_results": 5,
"results_per_page": 4,
},
{
"values": [5],
"total_results": 5,
"results_per_page": 1,
},
]
response_mock.url = "https://some.url&startIndex=0"
api_mock = AsyncMock(spec=NVDApi)
api_mock._get.return_value = response_mock

nvd_results: NVDResults[Result] = NVDResults(
api_mock,
{},
result_func,
)

await nvd_results

self.assertEqual(
repr(nvd_results),
'<NVDResults url="https://some.url&startIndex=0" '
"total_results=5 start_index=0 current_index=4 "
"results_per_page=None>",
)