Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: retry if NVD API Key is invalid #1574

Merged
merged 4 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ on:
env:
ACTIONS: 1
LONG_TESTS: 0
nvd_api_key: ${{ secrets.NVD_API_KEY }}

jobs:
docs:
Expand Down
4 changes: 4 additions & 0 deletions cve_bin_tool/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,10 @@ def main(argv=None):
if not args["nvd_api_key"] and os.getenv("nvd_api_key"):
args["nvd_api_key"] = os.getenv("nvd_api_key")

# Also try the uppercase env variable, in case people prefer those
if not args["nvd_api_key"] and os.getenv("NVD_API_KEY"):
args["nvd_api_key"] = os.getenv("NVD_API_KEY")

# If you're not using an NVD key, let you know how to get one
if not args["nvd_api_key"] and not args["offline"]:
LOGGER.info("Not using an NVD API key. Your access may be rate limited by NVD.")
Expand Down
6 changes: 6 additions & 0 deletions cve_bin_tool/error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ class NVDServiceError(Exception):
"""


class NVDKeyError(Exception):
"""
Raised if the NVD API key is invalid.
"""


class SHAMismatch(Exception):
"""
Raised if the sha of a file in the cache was not what it should be.
Expand Down
28 changes: 27 additions & 1 deletion cve_bin_tool/nvd_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from rich.progress import Progress, track

from cve_bin_tool.async_utils import RateLimiter
from cve_bin_tool.error_handler import ErrorMode, NVDServiceError
from cve_bin_tool.error_handler import ErrorMode, NVDKeyError, NVDServiceError
from cve_bin_tool.log import LOGGER

FEED = "https://services.nvd.nist.gov/rest/json/cves/1.0"
Expand Down Expand Up @@ -102,6 +102,9 @@ async def get_nvd_params(
self.logger.debug("Fetching metadata from NVD...")
cve_count = await self.nvd_count_metadata(self.session)

if "apiKey" in self.params:
await self.validate_nvd_api()

if time_of_last_update:
# Fetch all the updated CVE entries from the modified date. Subtracting 2-minute offset for updating cve entries
self.params["modStartDate"] = self.convert_date_to_nvd_date(
Expand All @@ -125,6 +128,28 @@ async def get_nvd_params(
self.total_results = cve_count["Total"] - cve_count["Rejected"]
self.logger.info(f"Adding {self.total_results} CVE entries")

async def validate_nvd_api(self):
"""
Validate NVD API
"""
param_dict = self.params.copy()
param_dict["startIndex"] = 0
param_dict["resultsPerPage"] = 1
try:
self.logger.debug("Validating NVD API...")
async with await self.session.get(
self.feed, params=param_dict, raise_for_status=True
) as response:
data = await response.json()
if data.get("error", False):
self.logger.error(f"NVD API error: {data['error']}")
raise NVDKeyError(self.params["apiKey"])
except NVDKeyError:
# If the API key provided is invalid, delete from params
# list and try the request again.
self.logger.error("unset api key, retrying")
del self.params["apiKey"]

async def load_nvd_request(self, start_index):
"""Get single NVD request and update year_wise_data list which contains list of all CVEs"""

Expand All @@ -141,6 +166,7 @@ async def load_nvd_request(self, start_index):
) as response:
if response.status == 200:
fetched_data = await response.json()

if start_index == 0:
# Update total results in case there is discrepancy between NVD dashboard and API
self.total_results = fetched_data["totalResults"]
Expand Down