Skip to content

Commit

Permalink
[CveXplore-234] fixes #274; add the ability to download a specific…
Browse files Browse the repository at this point in the history
… CVE or CPE via ID, or simply limit the number of returned results
  • Loading branch information
P-T-I committed Apr 12, 2024
1 parent cd966d8 commit 8565420
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 22 deletions.
2 changes: 1 addition & 1 deletion CveXplore/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.24.dev12
0.3.24.dev13
11 changes: 9 additions & 2 deletions CveXplore/core/database_maintenance/main_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ def update(self, update_source: str | list = None):
f"Update Total duration: {timedelta(seconds=time.time() - start_time)}"
)

def populate(self, populate_source: str | list = None):
def populate(
self, populate_source: str | list = None, limit: int = None, get_id: str = None
):
"""
Method used for updating the database
"""
Expand Down Expand Up @@ -173,7 +175,12 @@ def populate(self, populate_source: str | list = None):
x for x in self.sources if x["name"] == populate_source
][0]
up = update_this_source["updater"]()
up.populate()
# this method could be used for testing purposes and is able to limit the amount of entries
# that are going to be fetched from the cves and cpe NIST API endpoints, either by count or by id
if populate_source == "cpe" or populate_source == "cve":
up.populate(limit=limit, get_id=get_id)
else:
up.update()
except IndexError:
raise UpdateSourceNotFound(
f"Provided source: {populate_source} could not be found...."
Expand Down
31 changes: 21 additions & 10 deletions CveXplore/core/database_maintenance/sources_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def process_the_item(self, item: dict = None):

return cpe

def process_downloads(self, sites: list | None = None):
def process_downloads(
self, sites: list | None = None, limit: int = None, get_id: str = None
):
"""
Method to download and process files
"""
Expand All @@ -121,13 +123,20 @@ def process_downloads(self, sites: list | None = None):

if self.do_process:
if not self.is_update:
try:
total_results = self.api_handler.get_count(
self.api_handler.datasource.CPE
)
except ApiMaxRetryError:
# failed to get the count; set total_results to 0 and continue
total_results = 0
if limit is None and get_id is None:
try:
total_results = self.api_handler.get_count(
self.api_handler.datasource.CPE
)
except ApiMaxRetryError:
# failed to get the count; set total_results to 0 and continue
total_results = 0
else:
if get_id is None:
total_results = limit
else:
limit = 1
total_results = 1

self.logger.info(f"Preparing to download {total_results} CPE entries")

Expand All @@ -137,7 +146,9 @@ def process_downloads(self, sites: list | None = None):
position=0,
leave=True,
) as pbar:
for entry in self.api_handler.get_all_data(data_type="cpe"):
for entry in self.api_handler.get_all_data(
data_type="cpe", limit=limit, get_id=get_id
):
# do something here with the results...
for data_list in tqdm(
entry, desc=f"Processing batch", leave=False
Expand Down Expand Up @@ -261,7 +272,7 @@ def populate(self, **kwargs):

self.dropCollection(self.feed_type.lower())

self.process_downloads()
self.process_downloads(**kwargs)

self.database_indexer.create_indexes(collection=self.feed_type.lower())

Expand Down
31 changes: 22 additions & 9 deletions CveXplore/core/nvd_nist/nvd_nist_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,17 @@ def get_all_data(
data_type: str,
last_mod_start_date: datetime = None,
last_mod_end_date: datetime = None,
limit: int = None,
get_id: str = None,
):
resource = {}

if get_id is not None:
if self.datasource.CVE:
resource = {"cveId": get_id}
if self.datasource.CPE:
resource = {"cpeNameId": get_id}

if last_mod_start_date is not None and last_mod_end_date is not None:
self.logger.debug(f"Getting all updated {data_type}s....")
resource = self.check_date_range(
Expand All @@ -274,15 +282,20 @@ def get_all_data(
else:
self.logger.debug(f"Getting all {data_type}s...")

try:
data = self.get_count(
getattr(self.datasource, data_type.upper()),
last_mod_start_date=last_mod_start_date,
last_mod_end_date=last_mod_end_date,
)
except ApiMaxRetryError:
# failed to get the count; set data to 0 and continue
data = 0
if limit is None:
try:
data = self.get_count(
getattr(self.datasource, data_type.upper()),
last_mod_start_date=last_mod_start_date,
last_mod_end_date=last_mod_end_date,
)
except ApiMaxRetryError:
# failed to get the count; set data to 0 and continue
data = 0
else:
data = limit
if get_id is None:
resource = {"resultsPerPage": limit}

if isinstance(data, int):
for each_data in ApiData(
Expand Down

0 comments on commit 8565420

Please sign in to comment.