In [None]:
import pandas as pd
import json
import re
from rapidfuzz import process, fuzz
import numpy as np
from typing import Tuple, List, Optional
import itertools
from tqdm.auto import tqdm
import ast
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
import glob
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import requests
import tempfile
import sys
import zipfile

sys.path.append('./..')
import sec_certs.helpers as helpers

# Hook tqdm into pandas; `progress_map` is used further down in this notebook.
tqdm.pandas()
# NOTE(review): newer matplotlib releases may no longer ship a style called 'seaborn' - confirm against the installed version.
plt.style.use('seaborn')
# Display limits for pandas output: column width and number of rows shown.
pd.set_option("max_colwidth", 100)
pd.set_option("max_rows", 100)

## Functions for CVE and CPE preprocessing

In [None]:
def download_cve_data(output_dir: str, start_year=2002, end_year=2021):
    """
    Download the yearly CVE JSON feed archives from nist.gov and extract them into `output_dir`.

    The archives are downloaded into a temporary directory that is discarded afterwards;
    only the extracted content ends up in `output_dir`.

    :param output_dir: directory that receives the extracted files; created (including parents) if missing
    :param start_year: first feed year to fetch (inclusive)
    :param end_year: last feed year to fetch (inclusive)
    """
    output_dir = Path(output_dir)
    # `Path.exists` is a method: testing the bare attribute is always truthy, so the
    # directory would never be created. `exist_ok=True` makes the call idempotent.
    output_dir.mkdir(parents=True, exist_ok=True)

    base_url = 'https://nvd.nist.gov/feeds/json/cve/1.1/nvdcve-1.1-'
    urls = [base_url + str(x) + '.json.zip' for x in range(start_year, end_year + 1)]

    print(f'Identified {len(urls)} CVE files to fetch from nist.gov. Downloading them into {output_dir}', flush=True)
    if not urls:
        # Empty year range: nothing to fetch, so skip the download step entirely.
        return

    with tempfile.TemporaryDirectory() as tmp_dir:
        # `.stem` drops exactly the trailing '.zip' suffix; `str.rstrip('.zip')` strips a
        # *set of characters* and only produced the right name by coincidence.
        outpaths = [Path(tmp_dir) / Path(x).stem for x in urls]
        # presumably `download_parallel` yields one (<something>, status_code) pair per input,
        # in input order - TODO confirm against sec_certs.helpers
        responses = list(zip(*helpers.download_parallel(list(zip(urls, outpaths)), num_threads=8)))[1]

        for o, u, r in zip(outpaths, urls, responses):
            if r == 200:
                with zipfile.ZipFile(o, 'r') as zip_handle:
                    zip_handle.extractall(output_dir)
            else:
                print(f'Failed to download from {u}, got status code {r}')

def parse_all_cves(cve_dir: str, output_path: str) -> None:
    """
    Extract the relevant fields from every `*.json` CVE file in `cve_dir` and merge them into one JSON file.

    For each CVE item the output holds its id, an impact summary (base score, severity,
    exploitability score, impact score; CVSS v3 is preferred over v2) and the list of
    CPE URIs that are flagged as vulnerable.

    :param cve_dir: directory with the CVE JSON files (string or path-like)
    :param output_path: path of the merged JSON file to write
    """
    def get_relevant_info_from_file(input_path: Path) -> List[Dict]:
        """Load one CVE file and return the extracted record of every item in 'CVE_Items'."""
        # Decode as UTF-8 explicitly instead of relying on the platform's default encoding.
        with input_path.open('r', encoding='utf-8') as handle:
            data = json.load(handle)
        return [get_relevant_info_from_cve(cve) for cve in data['CVE_Items']]

    def get_relevant_info_from_cve(cve: Dict) -> Dict:
        """Build the output record (id, impact, vulnerable CPEs) for a single CVE item."""
        cve_id = cve['cve']['CVE_data_meta']['ID']
        impact = get_impact_from_cve(cve)
        affected_cpes = get_affected_cpes_from_cve(cve)
        return {'cve_id': cve_id, 'impact': impact, 'vulnerable_cpes': affected_cpes}

    def get_impact_from_cve(cve: Dict) -> Dict:
        """Return the impact summary; all values stay None when no metric is available."""
        result = {'base_score': None, 'severity': None, 'exploitabilityScore': None, 'impactScore': None}
        impact = cve['impact']
        if not impact:
            return result

        if 'baseMetricV3' in impact:
            metric = impact['baseMetricV3']
            result['base_score'] = metric['cvssV3']['baseScore']
            # In the v3 metric the severity is nested inside the 'cvssV3' object ...
            result['severity'] = metric['cvssV3']['baseSeverity']
        elif 'baseMetricV2' in impact:
            metric = impact['baseMetricV2']
            result['base_score'] = metric['cvssV2']['baseScore']
            # ... whereas the v2 metric carries it directly on the metric object.
            result['severity'] = metric['severity']
        else:
            return result

        result['exploitabilityScore'] = metric['exploitabilityScore']
        result['impactScore'] = metric['impactScore']
        return result

    def get_affected_cpes_from_cve(cve: Dict) -> List[str]:
        """Collect the vulnerable CPE URIs from all configuration nodes of a CVE item."""
        affected_cpes = []
        for node in cve['configurations']['nodes']:
            affected_cpes.extend(get_affected_cpes_from_node(node))
        return affected_cpes

    def get_affected_cpes_from_node(node: Dict) -> List[str]:
        """Recursively collect vulnerable CPE URIs: children first, then the node's own matches."""
        cpe_uris = []
        for child in node.get('children', []):
            cpe_uris += get_affected_cpes_from_node(child)
        for match in node.get('cpe_match', []):
            if match['vulnerable']:
                cpe_uris.append(match['cpe23Uri'])
        return cpe_uris

    # Going through `Path` accepts path-like input; sorting makes the order of the merged
    # output deterministic (glob returns files in arbitrary order).
    json_files = sorted(glob.glob(str(Path(cve_dir) / '*.json')))
    print(f'Identified {len(json_files)} CVE files. Extracting relevant data and merging them into {output_path}', flush=True)

    all_cve_data = []
    for filepath in tqdm(json_files):
        all_cve_data.extend(get_relevant_info_from_file(Path(filepath)))

    with open(output_path, 'w', encoding='utf-8') as handle:
        json.dump(all_cve_data, handle, indent=4)


def get_cpe_uri_to_title_dict(input_xml_filepath: str, output_filepath: str):
    """
    Build a `cpe_uri -> title` mapping from a CPE dictionary XML file and store it as JSON.

    For every `cpe-item` element the text of its first `title` child is stored under the
    `name` attribute of its `cpe23-item` child.

    :param input_xml_filepath: path of the CPE dictionary XML file to read
    :param output_filepath: path of the JSON file to write
    """
    print(f'Extracting dictionary cpe_uri:cpe_title from {input_xml_filepath} to {output_filepath}')

    # Fully qualified (namespace-prefixed) tag names used in the XML file.
    item_tag = '{http://cpe.mitre.org/dictionary/2.0}cpe-item'
    title_tag = '{http://cpe.mitre.org/dictionary/2.0}title'
    cpe23_tag = '{http://scap.nist.gov/schema/cpe-extension/2.3}cpe23-item'

    uri_to_title = {}
    xml_root = ET.parse(input_xml_filepath).getroot()
    for item in xml_root.iterfind(item_tag):
        title_text = item.find(title_tag).text
        uri = item.find(cpe23_tag).attrib['name']
        uri_to_title[uri] = title_text

    serialized = json.dumps(uri_to_title, indent=4)
    with open(output_filepath, 'w') as handle:
        handle.write(serialized)

## Preprocessing realization and path specification

File paths for the rest of this notebook are specified here. In addition, the three functions called below will:
    
1. Download all CVE datafiles
2. Extract relevant CVE information from all files and merge it into single file
3. Create a dictionary of `cpe_uri: cpe title`, which will come in handy later

In [None]:
# NOTE(review): every path below is absolute and machine-specific; adjust before running elsewhere.
# Directory receiving the extracted CVE files, and the merged CVE file built from them.
CVE_FOLDER_PATH = '/Users/adam/phd/projects/certificates/cpe_matching/new/cves'
CVE_MERGED_FILEPATH = '/Users/adam/phd/projects/certificates/cpe_matching/new/cve_data.json'

# CPE dictionary: JSON output (cpe_uri -> title) and the XML file it is extracted from.
CPE_DICTIONARY_PATH = '/Users/adam/phd/projects/certificates/cpe_matching/new/cpe_dictionary.json'
CPE_XML_PATH = '/Users/adam/phd/projects/certificates/cpe_matching/official-cpe-dictionary_v2.3.xml'

# Inputs of the matching part of the notebook: externally supplied matches and the certificate dataset.
PETR_ONE_TO_ONE_MATCH_JSON = '/Users/adam/Downloads/certs_to_cpe_single_match.json'
CERTIFICATE_DATASET_CSV = '/Users/adam/phd/projects/certificates/cpe_matching/new/cc_full_dataset.csv'

# Run the three preprocessing steps: download the CVE files, merge them, build the CPE title dictionary.
download_cve_data(CVE_FOLDER_PATH)
parse_all_cves(CVE_FOLDER_PATH, CVE_MERGED_FILEPATH)
get_cpe_uri_to_title_dict(CPE_XML_PATH, CPE_DICTIONARY_PATH)

## Main functions

### CPE dictionary building

In [None]:
def get_cpe_vendor(cpe_record):
    """Return the fourth colon-separated field of `cpe_record` with underscores turned into spaces."""
    fields = cpe_record.split(':')
    return fields[3].replace('_', ' ')

def get_cpe_product(cpe_record):
    """Return the fifth colon-separated field of `cpe_record` with underscores turned into spaces."""
    product_field = cpe_record.split(':')[4]
    return product_field.replace('_', ' ')

def get_cpe_version(cpe_record):
    """Return the sixth colon-separated field of `cpe_record`, unmodified."""
    fields = cpe_record.split(':')
    version_field = fields[5]
    return version_field

# Externally supplied matches; every key is cut at its first '.pdf' occurrence.
with open(PETR_ONE_TO_ONE_MATCH_JSON, 'r') as handle:
    petrs_matches = {name.split('.pdf')[0]: cpe for name, cpe in json.load(handle).items()}

with open(CPE_DICTIONARY_PATH, 'r') as handle:
    cpe_data = json.load(handle)

# (vendor, product, version) triplet of every known CPE URI, plus the reverse lookup.
cpe_uri_to_triplet = {
    uri: (get_cpe_vendor(uri), get_cpe_product(uri), get_cpe_version(uri))
    for uri in cpe_data
}
cpe_triplets = list(cpe_uri_to_triplet.values())
cpe_triplet_to_uri = {triplet: uri for uri, triplet in cpe_uri_to_triplet.items()}

# Grouped views over the triplets:
#   cpe_vendor_dict:            vendor            -> list of (vendor, product, version)
#   cpe_vendor_to_version_dict: vendor            -> list of versions
#   cpe_full_dict:              (vendor, version) -> list of products
cpe_vendor_dict = {}
cpe_vendor_to_version_dict = {}
cpe_full_dict = {}
for vendor, product, version in cpe_triplets:
    cpe_vendor_dict.setdefault(vendor, []).append((vendor, product, version))
    cpe_vendor_to_version_dict.setdefault(vendor, []).append(version)
    cpe_full_dict.setdefault((vendor, version), []).append(product)

# Merged CVE records and a quick cve_id -> base score lookup.
with open(CVE_MERGED_FILEPATH, 'r') as handle:
    cve_dataset = json.load(handle)
vuln_score_mapping = {entry['cve_id']: entry['impact']['base_score'] for entry in cve_dataset}

def get_cve_ids_for_cpe_uri(cpe_uri):
    """
    Return the ids of all CVEs in `cve_dataset` that list `cpe_uri` among their vulnerable CPEs.

    :param cpe_uri: CPE URI to look up; anything that is not a string yields None
    :return: non-empty list of CVE ids, or None when nothing matches
    """
    if not isinstance(cpe_uri, str):
        return None
    matching_ids = [entry['cve_id'] for entry in cve_dataset if cpe_uri in entry['vulnerable_cpes']]
    return matching_ids or None

## Actual functions for CPE<->Certificate matching

In [None]:
def parse_cert_version(crt_name):
    """
    Extract a version string from a certificate name.

    Three patterns are tried in order and the first one that matches wins:
    a dotted number (at least one dot), 'version' followed by a number, and
    'v' followed by a number. Only the numeric part of the hit is returned.

    :param crt_name: certificate name to search
    :return: the numeric version string, or '-' when no pattern matches
    """
    # TODO: E.g. Huawei with version V100R005C30SPC300 gets parsed as 300
    # TODO: Enhance version capabilities
    number_core = r'(\d{1,5})(\.\d{1,5})'
    patterns = (
        number_core + '+',                      # bare dotted number, e.g. "1.2.3"
        r'(\bversion)\s*' + number_core + '*',  # "version 5", "version 5.1"
        r'\bv\s*' + number_core + '*',          # "v2", "v 2.0"
    )

    for pattern in patterns:
        hit = re.search(pattern, crt_name, re.IGNORECASE)
        if hit is not None:
            # Drop any "version"/"v" prefix and keep the number only.
            return re.search(number_core + r'*', hit.group(), re.IGNORECASE).group()
    return '-'

def map_petrs_match(report_link):
    """
    Look up the externally supplied CPE match for a certificate by its report link.

    A key of `petrs_matches` matches when it (with spaces replaced by '%20') occurs in
    `report_link`; the first matching key wins.

    NOTE(review): the result is ordered (vendor, version, product), unlike the
    (vendor, product, version) triplets used elsewhere in this notebook - confirm this is intended.

    :param report_link: link to the certification report; non-string values (e.g. NaN for a
        missing link) yield None, consistent with the other lookup helpers
    :return: (vendor, version, product) tuple, or None when nothing matches
    """
    if not isinstance(report_link, str):
        # `in` on a float NaN would raise TypeError; treat missing links as "no match".
        return None

    for report_name, matched_cpe in petrs_matches.items():
        if report_name.replace(' ', '%20') in report_link:
            # presumably the stored identifiers have one leading field fewer than the
            # `get_cpe_*` helpers index into, hence the extra prefix field - TODO confirm
            base_string = 'hotfix:' + matched_cpe
            return get_cpe_vendor(base_string), get_cpe_version(base_string), get_cpe_product(base_string)
    return None

def get_matching_vendors(vendor_name: str) -> Optional[List[str]]:
    """
    Find the keys of `cpe_vendor_dict` that correspond to a vendor name.

    Names containing ' / ' are split and each part is resolved separately; the union of
    the results is returned (possibly an empty list). Otherwise the lower-cased name, its
    part before the first space and its part before the first comma are looked up.

    :param vendor_name: vendor name; anything that is not a string yields None
    :return: list of matching vendor keys, or None when a plain name matches nothing
    """
    if not isinstance(vendor_name, str):
        return None

    if ' / ' in vendor_name:
        per_part = (get_matching_vendors(part) for part in vendor_name.split(' / '))
        non_empty = [found for found in per_part if found]
        return list(set(itertools.chain.from_iterable(non_empty)))

    lowered = vendor_name.lower()
    known_vendors = cpe_vendor_dict.keys()
    hits = set()

    if lowered in known_vendors:
        hits.add(lowered)
    for separator in (' ', ','):
        if separator in lowered:
            head = lowered.split(separator)[0]
            if head in known_vendors:
                hits.add(head)

    return list(hits) if hits else None

def get_matching_versions(my_version: str, candidates: List[str]):
    """
    Select the candidate versions that are compatible with `my_version`.

    A candidate is kept when it starts with `my_version`, or when `my_version` starts with
    the candidate and the candidate contains a dotted number. Duplicates are removed; the
    order of the returned list is not defined.
    """
    dotted_number = re.compile(r'(\d{1,5})(\.\d{1,5})')
    selected = set()
    for candidate in candidates:
        is_dotted_prefix = my_version.startswith(candidate) and dotted_number.search(candidate)
        if is_dotted_prefix or candidate.startswith(my_version):
            selected.add(candidate)
    return list(selected)

def get_best_match(cert_name: str, list_of_pairs: List[Tuple[str, str]]):
    """
    Pick the product whose name is most similar to `cert_name`.

    Every product stored in `cpe_full_dict` under one of the (vendor, version) pairs is
    scored with `fuzz.partial_ratio`; the first product reaching the highest score wins.

    :param cert_name: certificate name to compare against
    :param list_of_pairs: (vendor, version) pairs, each a key of `cpe_full_dict`
    :return: (score, (vendor, product, version)); (0, (None, None, None)) when nothing scores above 0
    """
    # TODO: If equal matches, this kind of returns random match
    top_score = 0
    top_triplet = (None, None, None)
    if not list_of_pairs:
        return top_score, top_triplet

    scored = (
        (fuzz.partial_ratio(cert_name, product), vendor, product, version)
        for vendor, version in list_of_pairs
        for product in cpe_full_dict[(vendor, version)]
    )
    for score, vendor, product, version in scored:
        # Strict comparison: on ties the earlier candidate is kept.
        if score > top_score:
            top_score = score
            top_triplet = (vendor, product, version)
    return top_score, top_triplet


def match_cpe(vendor_name: str, cert_name: str, version: str):
    """
    Find the best (vendor, product, version) CPE triplet for a certificate.

    Vendors are resolved with `get_matching_vendors`, each vendor's known versions are
    filtered with `get_matching_versions`, and the resulting (vendor, version) candidates
    are ranked by `get_best_match`.

    :return: (score, triplet) as produced by `get_best_match`, or (None, None) when no vendor matches
    """
    vendors = get_matching_vendors(vendor_name)
    if not vendors:
        return None, None

    candidate_pairs = []
    for vendor in vendors:
        compatible_versions = get_matching_versions(version, cpe_vendor_to_version_dict[vendor])
        for compatible in compatible_versions:
            candidate_pairs.append([vendor, compatible])

    return get_best_match(cert_name, candidate_pairs)

In [None]:
# Ordinal encoding of the security levels; each "+" variant ranks directly above its base level.
sec_level_dict = {'EAL1': 0, 'EAL1+': 1, 'EAL2': 2, 'EAL2+': 3, 'EAL3': 4, 'EAL3+': 5, 'EAL4': 6, 'EAL4+': 7, 'EAL5': 8, 'EAL5+': 9, 'EAL6': 10, 'EAL6+': 11, 'EAL7': 12, 'EAL7+': 13}

# Load the semicolon-separated certificate dataset and index it by the `dgst` column.
df = pd.read_csv(CERTIFICATE_DATASET_CSV, sep=';')
df = df.set_index('dgst')

df.security_level = df.security_level.map(ast.literal_eval) # Since we have it in string representation, not needed when deserializing

# Highest encoded level per certificate; -1 for an empty value and for levels missing from `sec_level_dict`.
df['max_security_level'] = df.security_level.map(lambda x: max([sec_level_dict.get(y, -1) for y in x]) if x else -1)

# Version string parsed out of the certificate name ('-' when none is found).
df['version'] = df['name'].map(parse_cert_version)

df.not_valid_before = df.not_valid_before.apply(pd.to_datetime)
df.not_valid_after = df.not_valid_after.apply(pd.to_datetime)

# Two matching approaches: lookup in the externally supplied matches, and the fuzzy matching defined above.
df['petr_match'] = df.report_link.map(map_petrs_match)
df['adam_match'] = df.apply(lambda x: match_cpe(x['manufacturer'], x['name'], x['version']), axis=1)

# `match_cpe` returns (score, triplet); split it into two columns. Order matters: the score is
# read before `adam_match` is overwritten with the triplet.
df['match_score'] = df.adam_match.apply(lambda x: x[0])
df['adam_match'] = df.adam_match.apply(lambda x: x[1])

# True when the matched product name (second element of the triplet) is longer than 5 characters.
df['has_long_cpe_match'] = df.adam_match.apply(lambda x: len(x[1]) > 5 if x and x[1] else False)
# Translate the matched triplet back to its CPE URI.
df['matched_cpe_uri'] = df.adam_match.map(cpe_triplet_to_uri)

# # Filter only to relevant pieces
df = df.loc[df.has_long_cpe_match == True]
df = df.loc[df.match_score > 80]

# CVE ids whose vulnerable CPEs contain the matched URI (None when there are none), and their count.
df['related_cves'] = df.matched_cpe_uri.progress_map(get_cve_ids_for_cpe_uri)
df['n_related_cves'] = df.related_cves.apply(lambda x: len(x) if x else 0)

# df_vuln = df.loc[df.n_related_cves > 0 ]
# vulnerable_certs_dict = df_vuln.to_dict(orient='index')

In [None]:
# Preview the first rows of the filtered certificate dataframe.
df.head()

In [None]:
# One row per (certificate, related CVE) pair, with the index turned back into a regular column.
df_cves = df.explode('related_cves').reset_index()
# Attach the base score of each related CVE.
df_cves['cve_score'] = df_cves['related_cves'].map(vuln_score_mapping)
# Keep only the rows of certificates with fewer than 100 related CVEs.
df_cves = df_cves.loc[df_cves['n_related_cves'] < 100]

In [None]:
# Scatter plot of CVE score against the encoded security level, colored by the number of
# CVEs related to the certificate.
fig, ax = plt.subplots()
ax = df_cves.plot.scatter('max_security_level', 'cve_score', c='n_related_cves', colormap='viridis',
                         s = 40,
                         title='CVE score vs. security level of affected certificate. Color = number of CVEs related to certificate',
                         xlabel='Security level, EAL1=0, EAL7+=13',
                         ylabel='CVE severity score 1-10',
                         figsize=(12,10), ax=ax)
fig = ax.get_figure()
# NOTE(review): absolute, machine-specific output path - adjust before running elsewhere.
fig.savefig('/Users/adam/Downloads/scatter_plot.png', dpi=300)