refactor: add type hints in version_scanner.py (#1581)
* part of #1539
rhythmrx9 committed Mar 10, 2022
1 parent ca08460 commit 69f489d
Showing 2 changed files with 58 additions and 42 deletions.
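
The new annotations rely on the from __future__ import annotations line added at the top of version_scanner.py: with that import, annotations are stored as strings and never evaluated at runtime, so the PEP 585/604 spellings used throughout the diff (list[str], dict[str, type[Checker]], str | None) remain valid on the older interpreters the tool still supports (the module still branches on sys.version_info >= (3, 8)). A minimal sketch of the effect, separate from the commit itself; the function name and arguments are invented for illustration:

from __future__ import annotations


def find_version(name: str, cache: dict[str, list[str]] | None = None) -> str | None:
    # Without the future import, dict[str, list[str]] raises
    # "TypeError: 'type' object is not subscriptable" on Python 3.8,
    # and the "str | None" union is a TypeError before Python 3.10.
    if cache and name in cache:
        return cache[name][0]
    return None
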
5 changes: 5 additions & 0 deletions cve_bin_tool/util.py
@@ -65,6 +65,11 @@ class ProductInfo(NamedTuple):
version: str


class ScanInfo(NamedTuple):
product_info: ProductInfo
file_path: str


class VersionInfo(NamedTuple):
start_including: str
start_excluding: str
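
ScanInfo gives a name and a type to the (product_info, file_path) pairs that the scanner's generators previously yielded as bare tuples. Since NamedTuple instances are still tuples, positional unpacking keeps working for existing callers. A hypothetical usage sketch, not part of the diff; the ProductInfo field order (vendor, product, version) is taken from the constructor calls in version_scanner.py and the values are placeholders:

from cve_bin_tool.util import ProductInfo, ScanInfo

# Placeholder values; any vendor/product/version triple and path would do.
info = ScanInfo(ProductInfo("gnu", "glibc", "2.33"), "/usr/lib/libc.so.6")

# Field access for new, type-checked code ...
print(info.product_info.product, info.file_path)

# ... and positional unpacking still behaves like the old bare tuple.
product_info, file_path = info
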
95 changes: 53 additions & 42 deletions cve_bin_tool/version_scanner.py
@@ -1,23 +1,26 @@
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: GPL-3.0-or-later
from __future__ import annotations

import json
import os
import subprocess
import sys
from logging import Logger
from re import MULTILINE, compile, search
from typing import List
from typing import Iterator

import defusedxml.ElementTree as ET

from cve_bin_tool.checkers import Checker
from cve_bin_tool.cvedb import CVEDB
from cve_bin_tool.egg_updater import IS_DEVELOP, update_egg
from cve_bin_tool.error_handler import ErrorMode
from cve_bin_tool.extractor import Extractor
from cve_bin_tool.extractor import Extractor, TempDirExtractorContext
from cve_bin_tool.file import is_binary
from cve_bin_tool.log import LOGGER
from cve_bin_tool.strings import Strings
from cve_bin_tool.util import DirWalk, ProductInfo, inpath
from cve_bin_tool.util import DirWalk, ProductInfo, ScanInfo, inpath

if sys.version_info >= (3, 8):
from importlib import metadata as importlib_metadata
@@ -36,12 +39,12 @@ class VersionScanner:

def __init__(
self,
should_extract=False,
exclude_folders=[],
checkers=None,
logger=None,
error_mode=ErrorMode.TruncTrace,
score=0,
should_extract: bool = False,
exclude_folders: list[str] = [],
checkers: dict[str, type[Checker]] | None = None,
logger: Logger | None = None,
error_mode: ErrorMode = ErrorMode.TruncTrace,
score: int = 0,
):
self.logger = logger or LOGGER.getChild(self.__class__.__name__)
# Update egg if installed in development mode
@@ -62,13 +65,13 @@ def __init__(
)
).walk
self.should_extract = should_extract
self.file_stack = []
self.file_stack: list[str] = []
self.error_mode = error_mode
self.cve_db = CVEDB()
# self.logger.info("Checkers loaded: %s" % (", ".join(self.checkers.keys())))

@classmethod
def load_checkers(cls):
def load_checkers(cls) -> dict[str, type[Checker]]:
"""Loads CVE checkers"""
checkers = dict(
map(
@@ -79,12 +82,12 @@ def load_checkers(cls):
return checkers

@classmethod
def available_checkers(cls):
def available_checkers(cls) -> list[str]:
checkers = importlib_metadata.entry_points()[cls.CHECKER_ENTRYPOINT]
checker_list = [item.name for item in checkers]
return checker_list

def remove_skiplist(self, skips):
def remove_skiplist(self, skips: list[str]) -> None:
# Take out any checkers that are on the skip list
# (string of comma-delimited checker names)
skiplist = skips
@@ -95,20 +98,21 @@ def remove_skiplist(self, skips):
else:
self.logger.error(f"Checker {skipme} is not a valid checker name")

def print_checkers(self):
def print_checkers(self) -> None:
self.logger.info(f'Checkers: {", ".join(self.checkers.keys())}')

def number_of_checkers(self):
def number_of_checkers(self) -> int:
return len(self.checkers)

def is_executable(self, filename):
def is_executable(self, filename: str) -> tuple[bool, str | None]:
"""check if file is an ELF binary file"""

output = None
output: str | None = None
if inpath("file"):
# use system file if available (for performance reasons)
output = subprocess.check_output(["file", filename])
output = output.decode(sys.stdout.encoding)
output = subprocess.check_output(["file", filename]).decode(
sys.stdout.encoding
)

if "cannot open" in output:
self.logger.warning(f"Unopenable file {filename} cannot be scanned")
@@ -133,7 +137,7 @@ def is_executable(self, filename):

return True, output

def parse_strings(self, filename):
def parse_strings(self, filename: str) -> str:
"""parse binary file's strings"""

if inpath("strings"):
@@ -145,7 +149,7 @@ def parse_strings(self, filename):
lines = s.parse()
return lines

def scan_file(self, filename):
def scan_file(self, filename: str) -> Iterator[ScanInfo]:
"""Scans a file to see if it contains any of the target libraries,
and whether any of those contain CVEs"""

@@ -185,7 +189,9 @@ def scan_file(self, filename):

yield from self.run_checkers(filename, lines)

def find_java_vendor(self, product, version):
def find_java_vendor(
self, product: str, version: str
) -> tuple[ProductInfo, str] | tuple[None, None]:
"""Find vendor for Java product"""
vendor_package_pair = self.cve_db.get_vendor_product_pairs(product)
# If no match, try alternative product name.
@@ -205,7 +211,7 @@ def find_java_vendor(self, product, version):
return ProductInfo(vendor, product, version), file_path
return None, None

def run_java_checker(self, filename: str) -> None:
def run_java_checker(self, filename: str) -> Iterator[ScanInfo]:
"""Process maven pom.xml file and extract product and dependency details"""
tree = ET.parse(filename)
# Find root element
@@ -231,7 +237,7 @@ def run_java_checker(self, filename: str) -> None:
if product is not None and version is not None:
product_info, file_path = self.find_java_vendor(product, version)
if file_path is not None:
yield product_info, file_path
yield ScanInfo(product_info, file_path)

# Scan for any dependencies referenced in file
dependencies = root.find(schema + "dependencies")
@@ -249,16 +255,16 @@ def run_java_checker(self, filename: str) -> None:
product.text, version
)
if file_path is not None:
yield product_info, file_path
yield ScanInfo(product_info, file_path)

self.logger.debug(f"Done scanning file: {filename}")

def find_js_vendor(self, product: str, version: str) -> List[List[str]]:
def find_js_vendor(self, product: str, version: str) -> list[ScanInfo] | None:
"""Find vendor for Javascript product"""
if version == "*":
return None
vendor_package_pair = self.cve_db.get_vendor_product_pairs(product)
vendorlist: List[List[str]] = []
vendorlist: list[ScanInfo] = []
if vendor_package_pair != []:
# To handle multiple vendors, return all combinations of product/vendor mappings
for v in vendor_package_pair:
@@ -268,20 +274,21 @@ def find_js_vendor(self, product: str, version: str) -> List[List[str]]:
if "^" in version:
version = version[1:]
self.logger.debug(f"{file_path} {product} {version} by {vendor}")
vendorlist.append([ProductInfo(vendor, product, version), file_path])
vendorlist.append(
ScanInfo(ProductInfo(vendor, product, version), file_path)
)
return vendorlist if len(vendorlist) > 0 else None
return None

def run_js_checker(self, filename: str) -> None:
def run_js_checker(self, filename: str) -> Iterator[ScanInfo]:
"""Process package-lock.json file and extract product and dependency details"""
fh = open(filename)
data = json.load(fh)
product = data["name"]
version = data["version"]
vendor = self.find_js_vendor(product, version)
if vendor is not None:
for v in vendor:
yield v[0], v[1] # product_info, file_path
yield from vendor
# Now process dependencies
for i in data["dependencies"]:
# To handle @actions/<product>: lines, extract product name from line
@@ -299,20 +306,20 @@ def run_js_checker(self, filename: str) -> None:
version = data["dependencies"][i]
vendor = self.find_js_vendor(product, version)
if vendor is not None:
for v in vendor:
yield v[0], v[1] # product_info, file_path
yield from vendor
if "requires" in data["dependencies"][i]:
for r in data["dependencies"][i]["requires"]:
# To handle @actions/<product>: lines, extract product name from line
product = r.split("/")[1] if "/" in r else r
version = data["dependencies"][i]["requires"][r]
vendor = self.find_js_vendor(product, version)
if vendor is not None:
for v in vendor:
yield v[0], v[1] # product_info, file_path
yield from vendor
self.logger.debug(f"Done scanning file: {filename}")

def run_python_package_checkers(self, filename, lines):
def run_python_package_checkers(
self, filename: str, lines: str
) -> Iterator[ScanInfo]:
"""
This generator runs only for python packages.
There are no actual checkers.
@@ -331,15 +338,15 @@ def run_python_package_checkers(self, filename, lines):

self.logger.info(f"{file_path} is {product} {version}")

yield ProductInfo(vendor, product, version), file_path
yield ScanInfo(ProductInfo(vendor, product, version), file_path)

# There are packages with a METADATA file in them containing different data from what the tool expects
except AttributeError:
self.logger.debug(f"{filename} is an invalid METADATA/PKG-INFO")

self.logger.debug(f"Done scanning file: {filename}")

def run_checkers(self, filename, lines):
def run_checkers(self, filename: str, lines: str) -> Iterator[ScanInfo]:
# tko
for (dummy_checker_name, checker) in self.checkers.items():
checker = checker()
@@ -370,12 +377,14 @@ def run_checkers(self, filename, lines):
f'{file_path} {result["is_or_contains"]} {dummy_checker_name} {version}'
)
for vendor, product in checker.VENDOR_PRODUCT:
yield ProductInfo(vendor, product, version), file_path
yield ScanInfo(
ProductInfo(vendor, product, version), file_path
)

self.logger.debug(f"Done scanning file: {filename}")

@staticmethod
def clean_file_path(filepath):
def clean_file_path(filepath: str) -> str:
"""Returns a cleaner filepath by removing temp path from filepath"""

# we'll receive a filepath similar to
@@ -387,7 +396,9 @@ def clean_file_path(filepath):
start_point = filepath.find("extracted") + 9
return filepath[start_point:]

def scan_and_or_extract_file(self, ectx, filepath):
def scan_and_or_extract_file(
self, ectx: TempDirExtractorContext, filepath: str
) -> Iterator[ScanInfo]:
"""Runs extraction if possible and desired otherwise scans."""
# Scan the file
yield from self.scan_file(filepath)
@@ -404,7 +415,7 @@ def scan_and_or_extract_file(self, ectx, filepath):
yield from self.scan_and_or_extract_file(ectx, filename)
self.file_stack.pop()

def recursive_scan(self, scan_path):
def recursive_scan(self, scan_path: str) -> Iterator[ScanInfo]:
with Extractor(logger=self.logger, error_mode=self.error_mode) as ectx:
if os.path.isdir(scan_path):
for filepath in self.walker([scan_path]):
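
With these changes the scanner's public generators are self-describing: scan_file, run_checkers, and recursive_scan all advertise Iterator[ScanInfo]. A hypothetical caller sketch under that assumption; the scan path and output format below are illustrative and not taken from the commit:

from cve_bin_tool.version_scanner import VersionScanner

scanner = VersionScanner(should_extract=True)
for scan_info in scanner.recursive_scan("/path/to/scan"):  # placeholder path
    vendor, product, version = scan_info.product_info
    print(f"{scan_info.file_path}: {vendor} {product} {version}")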
