Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add type hints in version_scanner.py #1581

Merged
merged 5 commits into from
Mar 10, 2022
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
75 changes: 43 additions & 32 deletions cve_bin_tool/version_scanner.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: GPL-3.0-or-later
from __future__ import annotations

import json
import os
import subprocess
import sys
from logging import Logger
from re import MULTILINE, compile, search
from typing import List
from typing import Callable, Generator, Iterator, Tuple, Union

import defusedxml.ElementTree as ET

from cve_bin_tool.checkers import Checker
from cve_bin_tool.cvedb import CVEDB
from cve_bin_tool.egg_updater import IS_DEVELOP, update_egg
from cve_bin_tool.error_handler import ErrorMode
from cve_bin_tool.extractor import Extractor
from cve_bin_tool.extractor import Extractor, TempDirExtractorContext
from cve_bin_tool.file import is_binary
from cve_bin_tool.log import LOGGER
from cve_bin_tool.strings import Strings
Expand All @@ -33,15 +36,16 @@ class VersionScanner:
""" "Scans files for CVEs using CVE checkers"""

CHECKER_ENTRYPOINT = "cve_bin_tool.checker"
infogen = Generator[Tuple[Union[ProductInfo, None], str], None, None]
rhythmrx9 marked this conversation as resolved.
Show resolved Hide resolved
rhythmrx9 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
should_extract=False,
exclude_folders=[],
checkers=None,
logger=None,
error_mode=ErrorMode.TruncTrace,
score=0,
should_extract: bool = False,
exclude_folders: list[str] = [],
checkers: dict[str, type[Checker]] | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it type[Checker] and not just Checker? 🤔

Copy link
Contributor Author

@rhythmrx9 rhythmrx9 Feb 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because its the Checker subclass itself (uninstantiated) of type CheckerMetaClass.

for (dummy_checker_name, checker) in self.checkers.items():
checker = checker()

logger: Logger | None = None,
error_mode: ErrorMode = ErrorMode.TruncTrace,
score: int = 0,
):
self.logger = logger or LOGGER.getChild(self.__class__.__name__)
# Update egg if installed in development mode
Expand All @@ -55,20 +59,20 @@ def __init__(
self.total_scanned_files = 0
self.exclude_folders = exclude_folders + [".git"]

self.walker = DirWalk(
self.walker: Callable[[list[str]], Iterator[str]] = DirWalk(
rhythmrx9 marked this conversation as resolved.
Show resolved Hide resolved
folder_exclude_pattern=";".join(
exclude if exclude.endswith("*") else exclude + "*"
for exclude in exclude_folders
)
).walk
self.should_extract = should_extract
self.file_stack = []
self.file_stack: list[str] = []
self.error_mode = error_mode
self.cve_db = CVEDB()
# self.logger.info("Checkers loaded: %s" % (", ".join(self.checkers.keys())))

@classmethod
def load_checkers(cls):
def load_checkers(cls) -> dict[str, type[Checker]]:
"""Loads CVE checkers"""
checkers = dict(
map(
Expand All @@ -79,12 +83,12 @@ def load_checkers(cls):
return checkers

@classmethod
def available_checkers(cls):
def available_checkers(cls) -> list[str]:
checkers = importlib_metadata.entry_points()[cls.CHECKER_ENTRYPOINT]
checker_list = [item.name for item in checkers]
return checker_list

def remove_skiplist(self, skips):
def remove_skiplist(self, skips: list[str]) -> None:
# Take out any checkers that are on the skip list
# (string of comma-delimited checker names)
skiplist = skips
Expand All @@ -95,20 +99,21 @@ def remove_skiplist(self, skips):
else:
self.logger.error(f"Checker {skipme} is not a valid checker name")

def print_checkers(self):
def print_checkers(self) -> None:
self.logger.info(f'Checkers: {", ".join(self.checkers.keys())}')

def number_of_checkers(self):
def number_of_checkers(self) -> int:
return len(self.checkers)

def is_executable(self, filename):
def is_executable(self, filename: str) -> tuple[bool, str | None]:
"""check if file is an ELF binary file"""

output = None
output: str | None = None
if inpath("file"):
# use system file if available (for performance reasons)
output = subprocess.check_output(["file", filename])
output = output.decode(sys.stdout.encoding)
output = subprocess.check_output(["file", filename]).decode(
sys.stdout.encoding
)

if "cannot open" in output:
self.logger.warning(f"Unopenable file {filename} cannot be scanned")
Expand All @@ -133,7 +138,7 @@ def is_executable(self, filename):

return True, output

def parse_strings(self, filename):
def parse_strings(self, filename: str) -> str:
"""parse binary file's strings"""

if inpath("strings"):
Expand All @@ -145,7 +150,7 @@ def parse_strings(self, filename):
lines = s.parse()
return lines

def scan_file(self, filename):
def scan_file(self, filename: str) -> infogen:
"""Scans a file to see if it contains any of the target libraries,
and whether any of those contain CVEs"""

Expand Down Expand Up @@ -185,7 +190,9 @@ def scan_file(self, filename):

yield from self.run_checkers(filename, lines)

def find_java_vendor(self, product, version):
def find_java_vendor(
self, product: str, version: str
) -> tuple[ProductInfo, str] | tuple[None, None]:
"""Find vendor for Java product"""
vendor_package_pair = self.cve_db.get_vendor_product_pairs(product)
# If no match, try alternative product name.
Expand All @@ -205,7 +212,7 @@ def find_java_vendor(self, product, version):
return ProductInfo(vendor, product, version), file_path
return None, None

def run_java_checker(self, filename: str) -> None:
def run_java_checker(self, filename: str) -> infogen:
"""Process maven pom.xml file and extract product and dependency details"""
tree = ET.parse(filename)
# Find root element
Expand Down Expand Up @@ -253,12 +260,14 @@ def run_java_checker(self, filename: str) -> None:

self.logger.debug(f"Done scanning file: {filename}")

def find_js_vendor(self, product: str, version: str) -> List[List[str]]:
def find_js_vendor(
self, product: str, version: str
) -> list[tuple[ProductInfo, str]] | None:
"""Find vendor for Javascript product"""
if version == "*":
return None
vendor_package_pair = self.cve_db.get_vendor_product_pairs(product)
vendorlist: List[List[str]] = []
vendorlist: list[tuple[ProductInfo, str]] = []
if vendor_package_pair != []:
# To handle multiple vendors, return all combinations of product/vendor mappings
for v in vendor_package_pair:
Expand All @@ -268,11 +277,11 @@ def find_js_vendor(self, product: str, version: str) -> List[List[str]]:
if "^" in version:
version = version[1:]
self.logger.debug(f"{file_path} {product} {version} by {vendor}")
vendorlist.append([ProductInfo(vendor, product, version), file_path])
vendorlist.append((ProductInfo(vendor, product, version), file_path))
return vendorlist if len(vendorlist) > 0 else None
return None

def run_js_checker(self, filename: str) -> None:
def run_js_checker(self, filename: str) -> infogen:
"""Process package-lock.json file and extract product and dependency details"""
fh = open(filename)
data = json.load(fh)
Expand Down Expand Up @@ -312,7 +321,7 @@ def run_js_checker(self, filename: str) -> None:
yield v[0], v[1] # product_info, file_path
self.logger.debug(f"Done scanning file: {filename}")

def run_python_package_checkers(self, filename, lines):
def run_python_package_checkers(self, filename: str, lines: str) -> infogen:
"""
This generator runs only for python packages.
There are no actual checkers.
Expand All @@ -339,7 +348,7 @@ def run_python_package_checkers(self, filename, lines):

self.logger.debug(f"Done scanning file: {filename}")

def run_checkers(self, filename, lines):
def run_checkers(self, filename: str, lines: str) -> infogen:
# tko
for (dummy_checker_name, checker) in self.checkers.items():
checker = checker()
Expand Down Expand Up @@ -375,7 +384,7 @@ def run_checkers(self, filename, lines):
self.logger.debug(f"Done scanning file: {filename}")

@staticmethod
def clean_file_path(filepath):
def clean_file_path(filepath: str) -> str:
"""Returns a cleaner filepath by removing temp path from filepath"""

# we'll recieve a filepath similar to
Expand All @@ -387,7 +396,9 @@ def clean_file_path(filepath):
start_point = filepath.find("extracted") + 9
return filepath[start_point:]

def scan_and_or_extract_file(self, ectx, filepath):
def scan_and_or_extract_file(
self, ectx: TempDirExtractorContext, filepath: str
) -> infogen:
"""Runs extraction if possible and desired otherwise scans."""
# Scan the file
yield from self.scan_file(filepath)
Expand All @@ -404,7 +415,7 @@ def scan_and_or_extract_file(self, ectx, filepath):
yield from self.scan_and_or_extract_file(ectx, filename)
self.file_stack.pop()

def recursive_scan(self, scan_path):
def recursive_scan(self, scan_path: str) -> infogen:
with Extractor(logger=self.logger, error_mode=self.error_mode) as ectx:
if os.path.isdir(scan_path):
for filepath in self.walker([scan_path]):
Expand Down