diff --git a/CveXplore/VERSION b/CveXplore/VERSION
index 5f0596e5..1b7f80e7 100644
--- a/CveXplore/VERSION
+++ b/CveXplore/VERSION
@@ -1 +1 @@
-0.3.24.dev12
\ No newline at end of file
+0.3.24.dev13
\ No newline at end of file
diff --git a/CveXplore/core/database_maintenance/main_updater.py b/CveXplore/core/database_maintenance/main_updater.py
index 6177cea0..f9645642 100644
--- a/CveXplore/core/database_maintenance/main_updater.py
+++ b/CveXplore/core/database_maintenance/main_updater.py
@@ -133,7 +133,9 @@ def update(self, update_source: str | list = None):
             f"Update Total duration: {timedelta(seconds=time.time() - start_time)}"
         )
 
-    def populate(self, populate_source: str | list = None):
+    def populate(
+        self, populate_source: str | list = None, limit: int = None, get_id: str = None
+    ):
         """
         Method used for updating the database
         """
@@ -173,7 +175,12 @@ def populate(self, populate_source: str | list = None):
                     x for x in self.sources if x["name"] == populate_source
                 ][0]
                 up = update_this_source["updater"]()
+                # limit and get_id are intended for testing; they restrict the entries
+                # fetched from the NIST CVE and CPE API endpoints, by count or by a single id
-                up.populate()
+                if populate_source in ("cpe", "cve"):
+                    up.populate(limit=limit, get_id=get_id)
+                else:
+                    up.update()
             except IndexError:
                 raise UpdateSourceNotFound(
                     f"Provided source: {populate_source} could not be found...."
diff --git a/CveXplore/core/database_maintenance/sources_process.py b/CveXplore/core/database_maintenance/sources_process.py
index 5984516f..a5819444 100644
--- a/CveXplore/core/database_maintenance/sources_process.py
+++ b/CveXplore/core/database_maintenance/sources_process.py
@@ -104,7 +104,9 @@ def process_the_item(self, item: dict = None):
 
         return cpe
 
-    def process_downloads(self, sites: list | None = None):
+    def process_downloads(
+        self, sites: list | None = None, limit: int = None, get_id: str = None
+    ):
         """
         Method to download and process files
         """
@@ -121,13 +123,20 @@ def process_downloads(self, sites: list | None = None):
 
         if self.do_process:
             if not self.is_update:
-                try:
-                    total_results = self.api_handler.get_count(
-                        self.api_handler.datasource.CPE
-                    )
-                except ApiMaxRetryError:
-                    # failed to get the count; set total_results to 0 and continue
-                    total_results = 0
+                if limit is None and get_id is None:
+                    try:
+                        total_results = self.api_handler.get_count(
+                            self.api_handler.datasource.CPE
+                        )
+                    except ApiMaxRetryError:
+                        # failed to get the count; set total_results to 0 and continue
+                        total_results = 0
+                else:
+                    if get_id is None:
+                        total_results = limit
+                    else:
+                        limit = 1
+                        total_results = 1
 
                 self.logger.info(f"Preparing to download {total_results} CPE entries")
 
@@ -137,7 +146,9 @@
                     position=0,
                     leave=True,
                 ) as pbar:
-                    for entry in self.api_handler.get_all_data(data_type="cpe"):
+                    for entry in self.api_handler.get_all_data(
+                        data_type="cpe", limit=limit, get_id=get_id
+                    ):
                         # do something here with the results...
                         for data_list in tqdm(
                             entry, desc=f"Processing batch", leave=False
@@ -261,7 +272,7 @@ def populate(self, **kwargs):
 
         self.dropCollection(self.feed_type.lower())
 
-        self.process_downloads()
+        self.process_downloads(**kwargs)
 
         self.database_indexer.create_indexes(collection=self.feed_type.lower())
diff --git a/CveXplore/core/nvd_nist/nvd_nist_api.py b/CveXplore/core/nvd_nist/nvd_nist_api.py
index 0a05c6d2..5e235f2a 100644
--- a/CveXplore/core/nvd_nist/nvd_nist_api.py
+++ b/CveXplore/core/nvd_nist/nvd_nist_api.py
@@ -261,9 +261,17 @@ def get_all_data(
         data_type: str,
         last_mod_start_date: datetime = None,
         last_mod_end_date: datetime = None,
+        limit: int = None,
+        get_id: str = None,
     ):
         resource = {}
 
+        if get_id is not None:
+            if data_type == "cve":
+                resource = {"cveId": get_id}
+            if data_type == "cpe":
+                resource = {"cpeNameId": get_id}
+
         if last_mod_start_date is not None and last_mod_end_date is not None:
             self.logger.debug(f"Getting all updated {data_type}s....")
             resource = self.check_date_range(
@@ -274,15 +282,20 @@
         else:
             self.logger.debug(f"Getting all {data_type}s...")
 
-        try:
-            data = self.get_count(
-                getattr(self.datasource, data_type.upper()),
-                last_mod_start_date=last_mod_start_date,
-                last_mod_end_date=last_mod_end_date,
-            )
-        except ApiMaxRetryError:
-            # failed to get the count; set data to 0 and continue
-            data = 0
+        if limit is None:
+            try:
+                data = self.get_count(
+                    getattr(self.datasource, data_type.upper()),
+                    last_mod_start_date=last_mod_start_date,
+                    last_mod_end_date=last_mod_end_date,
+                )
+            except ApiMaxRetryError:
+                # failed to get the count; set data to 0 and continue
+                data = 0
+        else:
+            data = limit
+            if get_id is None:
+                resource = {"resultsPerPage": limit}
 
         if isinstance(data, int):
             for each_data in ApiData(
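
For reviewers, a minimal usage sketch of the new testing hooks. This is not part of the patch: it assumes MainUpdater can be constructed bare as shown (the real constructor may require a configured datasource handle), and the cpeNameId value is made up.

# Sketch only: exercising populate() with the new limit/get_id parameters.
from CveXplore.core.database_maintenance.main_updater import MainUpdater

updater = MainUpdater()  # hypothetical; real construction may need a datasource

# Fetch only the first 50 CVE entries instead of the full NIST feed.
updater.populate(populate_source="cve", limit=50)

# Fetch a single CPE record by its cpeNameId (example value is invented);
# process_downloads() forces limit to 1 in this case.
updater.populate(populate_source="cpe", get_id="87316812-5F2C-4286-94FE-CC98B9EAEF53")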
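And a standalone sketch of the parameter precedence the hunks implement; resolve_fetch_plan is an illustrative name, not a function in the codebase. get_id takes priority and pins the fetch to one record, a bare limit becomes a resultsPerPage filter on the NIST API request, and with neither set the real count is fetched via get_count().

# Illustration only (not part of the patch) of how limit and get_id interact.
def resolve_fetch_plan(data_type: str, limit: int = None, get_id: str = None):
    """Return (resource filter, number of entries to fetch)."""
    if get_id is not None:
        # A single record is requested; the id key depends on the endpoint.
        key = "cveId" if data_type == "cve" else "cpeNameId"
        return {key: get_id}, 1
    if limit is not None:
        # No id given: cap the download at `limit` entries.
        return {"resultsPerPage": limit}, limit
    # Neither given: full download; -1 marks "ask the API for the real count".
    return {}, -1


assert resolve_fetch_plan("cve", get_id="CVE-2024-0001") == ({"cveId": "CVE-2024-0001"}, 1)
assert resolve_fetch_plan("cpe", limit=100) == ({"resultsPerPage": 100}, 100)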