# References

This notebook contains an analysis of the references between Common Criteria certificates.

In [None]:
import warnings
from pathlib import Path
from typing import Iterable

import matplotlib.pyplot as plt
import networkx as nx
import networkx.algorithms.community as nx_comm
import numpy as np
import pandas as pd
import seaborn as sns
from notebooks.fixed_sankey_plot import sankey
from tqdm import tqdm

from sec_certs.dataset.cc import CCDataset

# Suppress user warnings
warnings.filterwarnings("ignore", category=UserWarning)

%matplotlib inline

# matplotlib.use("pgf")
sns.set_theme(style="white")
plt.rcParams["axes.linewidth"] = 0.5
plt.rcParams["legend.fontsize"] = 6.5
plt.rcParams["xtick.labelsize"] = 8
plt.rcParams["ytick.labelsize"] = 8
plt.rcParams["ytick.left"] = True
plt.rcParams["ytick.major.size"] = 5
plt.rcParams["ytick.major.width"] = 0.5
plt.rcParams["ytick.major.pad"] = 0
plt.rcParams["xtick.bottom"] = True
plt.rcParams["xtick.major.size"] = 5
plt.rcParams["xtick.major.width"] = 0.5
plt.rcParams["xtick.major.pad"] = 0
# plt.rcParams["pgf.texsystem"] = "pdflatex"
# plt.rcParams["font.family"] = "serif"
# plt.rcParams["text.usetex"] = True
# plt.rcParams["pgf.rcfonts"] = False
plt.rcParams["axes.titlesize"] = 8
plt.rcParams["legend.handletextpad"] = 0.3
plt.rcParams["lines.markersize"] = 4
plt.rcParams["savefig.pad_inches"] = 0.01
sns.set_palette("deep")

# plt.style.use("seaborn-whitegrid")
# sns.set_palette("deep")
# sns.set_context("notebook")  # Set to "paper" for use in paper :)

# plt.rcParams['figure.figsize'] = (10, 6)

# Directory (relative to the current working directory) that the plotting cells below save their figures into.
RESULTS_DIR = Path("./results/references")
# Created eagerly, including missing parents, so that later `plt.savefig` calls do not fail on a missing folder.
RESULTS_DIR.mkdir(exist_ok=True, parents=True)
# Value of the `category` column that the analyses below treat as "smartcards"; every other value counts as "others".
SMARTCARD_CATEGORY = "ICs, Smart Cards and Smart Card-Related Devices and Systems"


## Common processing functions

In [None]:
def len_if_exists(x) -> int:
    """
    Return ``len(x)`` for a present value and 0 for a missing one (as judged by ``pd.notnull``).
    """
    if pd.notnull(x):
        return len(x)
    return 0


def compute_reference_numbers(df__: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of the frame extended with one count column per reference column.

    For each of ``refs``, ``trans_refs``, ``in_refs`` and ``in_trans_refs`` the matching
    ``n_*`` column holds the size of the stored collection (0 where the value is missing).
    """
    counted = df__.copy()
    for source_column in ("refs", "trans_refs", "in_refs", "in_trans_refs"):
        counted[f"n_{source_column}"] = counted[source_column].map(len_if_exists)
    return counted


def preprocess_cc_df(cc_df: pd.DataFrame) -> pd.DataFrame:
    """
    Prepare the CC dataset frame for the analyses in this notebook.

    Steps: keep only rows with a ``cert_id``, shorten the names of the four reference columns,
    flag validity periods longer than five years (``longer_than_5_years``) and cap their
    ``not_valid_after`` at ``not_valid_before`` + 5 * 365 days, and finally keep only the
    first row for each ``cert_id``.
    """
    five_years = pd.Timedelta(days=5 * 365)
    short_names = {
        "directly_referencing": "refs",
        "indirectly_referencing": "trans_refs",
        "directly_referenced_by": "in_refs",
        "indirectly_referenced_by": "in_trans_refs",
    }

    with_id = cc_df.loc[cc_df.cert_id.notnull()].copy().rename(columns=short_names)

    # Both derived columns are computed from the uncapped validity dates.
    too_long = with_id.not_valid_after - with_id.not_valid_before > five_years
    capped_end = with_id.not_valid_after.where(~too_long, with_id.not_valid_before + five_years)

    capped = with_id.assign(longer_than_5_years=too_long, not_valid_after=capped_end)
    # TODO: Investigate high number of duplicates and resolve
    return capped.drop_duplicates(subset=["cert_id"], keep="first")


def compute_references(cc_df: pd.DataFrame, graph: nx.DiGraph, label: str | Iterable[str]) -> pd.DataFrame:
    """
    Limits the columns with references to a given label (or a collection of labels).

    Only the edges of ``graph`` whose ``reference_label`` attribute is one of the requested labels
    are kept. For every row, the columns ``refs``, ``trans_refs``, ``in_refs`` and ``in_trans_refs``
    are then recomputed from that edge subgraph as the sets of successors, descendants,
    predecessors and ancestors of the row's ``cert_id``. Rows whose ``cert_id`` is not a node of
    the subgraph get ``np.nan`` in all four columns.

    :param cc_df: Frame of certificates with a ``cert_id`` column.
    :param graph: Directed reference graph with a ``reference_label`` edge attribute.
    :param label: A single label or an iterable of labels to keep.
    :return: Copy of ``cc_df`` with the four reference columns replaced.
    """
    # `str` is itself an Iterable, so it has to be special-cased: otherwise a single label would
    # not be wrapped and `in` below would perform a substring test against it.
    labels = [label] if isinstance(label, str) else list(label)
    sub_edges = [(u, v) for u, v, d in graph.edges(data=True) if d.get("reference_label") in labels]
    subgraph = graph.edge_subgraph(sub_edges)

    def collect(finder):
        """Wrap a node -> iterable-of-nodes lookup so that it yields a set, or NaN for unknown nodes."""

        def for_cert(cert_id):
            return set(finder(cert_id)) if cert_id in subgraph else np.nan

        return for_cert

    return cc_df.assign(
        refs=cc_df.cert_id.map(collect(subgraph.successors)),
        trans_refs=cc_df.cert_id.map(collect(lambda cert_id: nx.descendants(subgraph, cert_id))),
        in_refs=cc_df.cert_id.map(collect(subgraph.predecessors)),
        in_trans_refs=cc_df.cert_id.map(collect(lambda cert_id: nx.ancestors(subgraph, cert_id))),
    )


def preprocess_refs_df(csv_path: str | Path, cc_df: pd.DataFrame) -> pd.DataFrame:
    return (
        pd.read_csv(csv_path)
        .pipe(lambda df_: df_.loc[df_.dgst.isin(cc_df.index)])
        .assign(cert_id=lambda df_: df_.dgst.map(cc_df.cert_id.to_dict()))
    )


def get_reference_graph_from_refs_df(refs_df: pd.DataFrame, target: str = "reference") -> nx.DiGraph:
    """
    Build the labeled, directed reference graph from the frame of references.

    Each row becomes an edge from ``cert_id`` to the value in the ``target`` column; the row's
    ``reference_label`` is stored as an edge attribute of the same name.

    :param refs_df: Frame with ``cert_id``, ``reference_label`` and the ``target`` column.
    :param target: Name of the column holding the referenced certificate. Defaults to
        ``"reference"``; pass e.g. ``"canonical_reference_keyword"`` to build the graph from a
        different column.
    :return: Directed graph of references.
    """
    return nx.from_pandas_edgelist(
        refs_df,
        source="cert_id",
        target=target,
        create_using=nx.DiGraph,
        edge_attr=["reference_label"],
    )


## Load data and compute reference graph

In [None]:
# NOTE(review): both input paths below are absolute paths on one developer's machine; they must be
# adjusted (or moved into a configurable location) before this notebook can run anywhere else.
dset = CCDataset.from_json("/Users/adam/phd/projects/certificates/sec-certs/dataset/cc_final_run_may_23/dataset.json")
cc_df = preprocess_cc_df(dset.to_pandas())
# Reference predictions, restricted to certificates that survived `preprocess_cc_df`.
refs_df = preprocess_refs_df("/Users/adam/Downloads/predictions.csv", cc_df)
unique_labels = refs_df.reference_label.unique().tolist()

# Load labeled reference graph as networkx directed graph
# NOTE(review): the graph is built inline with target="canonical_reference_keyword", whereas the
# helper `get_reference_graph_from_refs_df` uses target="reference" — confirm which column is intended.
graph = nx.from_pandas_edgelist(
    refs_df,
    source="cert_id",
    target="canonical_reference_keyword",
    edge_attr="reference_label",
    create_using=nx.DiGraph,
)

# Recompute the reference columns from the graph (all labels kept) and add the `n_*` count columns.
cc_df = compute_reference_numbers(compute_references(cc_df, graph, unique_labels))


## Notes on the structure of the analytical cells

In [None]:
# Understand which columns I need and limit myself to those columns
# Every analytical cell should be isolated in a function that takes a single input: The dataframe of certificates to work on.
#     - The number of those references is computed in the function itself
#     - Each analytical method should have some tests at the end
#     - If some LaTeX output accompanies the computation, the function should return it as a string
#     - Those are stored in a dictionary that keeps expanding


### Count numbers of reference-rich certificates

In [None]:
def compute_basic_reference_graph_stats(df__: pd.DataFrame, graph: nx.DiGraph) -> dict[str, str]:
    """
    Print how many certificates reference some other certificate and plot the distribution of
    the reference counts.

    The counts are printed separately for smartcards, for all other categories and in total,
    each with its share of the respective population. The boxen plot of ``n_refs``,
    ``n_trans_refs``, ``n_in_refs`` and ``n_in_trans_refs`` is saved to
    ``RESULTS_DIR / "boxen_plot_references.pdf"``.

    :param df__: Frame of certificates with the reference columns and a ``category`` column.
    :param graph: Reference graph; currently unused, kept for interface stability.
    :return: Empty dictionary (placeholder for LaTeX commands).
    """

    def percentage(part: int, total: int) -> float:
        # An empty population (e.g. a subset without any smartcard) must not raise ZeroDivisionError.
        return 100 * part / total if total else 0.0

    # A certificate "has references" only if its reference set is non-empty: nodes of the graph
    # that are merely referenced carry an empty (but non-null) set in `refs`.
    df = df__.copy().pipe(compute_reference_numbers).assign(has_refs=lambda df_: df_.n_refs > 0)

    is_smartcard = df.category == SMARTCARD_CATEGORY
    is_other = df.category != SMARTCARD_CATEGORY
    n_ref_smartcards = int((df.has_refs & is_smartcard).sum())
    n_ref_others = int((df.has_refs & is_other).sum())
    n_ref_total = n_ref_smartcards + n_ref_others

    print(
        f"Number of smartcard certificates that reference some other certificate: {n_ref_smartcards} ({percentage(n_ref_smartcards, int(is_smartcard.sum())):.2f}%)"
    )
    print(
        f"Number of non-smartcard certificates that reference some other certificate: {n_ref_others} ({percentage(n_ref_others, int(is_other.sum())):.2f}%)"
    )
    print(f"Total number of referencing certificates: {n_ref_total} ({percentage(n_ref_total, df.shape[0]):.2f}%)")

    df_melted = df[["n_refs", "n_trans_refs", "n_in_refs", "n_in_trans_refs"]].melt()
    # The incoming variants are `n_in_refs` and `n_in_trans_refs`; the former `endswith("by")`
    # test referred to the pre-rename column names and never matched.
    df_melted["incoming"] = df_melted.variable.str.startswith("n_in_")
    sns.catplot(data=df_melted, kind="boxen", x="variable", y="value", col="variable", sharex=False, sharey=False)
    plt.savefig(RESULTS_DIR / "boxen_plot_references.pdf", bbox_inches="tight")

    plt.show()

    return {}


compute_basic_reference_graph_stats(cc_df, graph)


## Evolution of certificate reach for top-10 certificates

In [None]:
# TODO: Check that it actually works, the data on small subset was fairly weird
# TODO: Work only on sub-component references?
def compute_certs_top_reach(df__: pd.DataFrame) -> dict:
    """
    Plot how the reach of the ten most-referenced certificates evolves in time.

    The ten certificates with the highest ``n_in_trans_refs`` are selected. For each of them and
    for each day between the earliest and the latest ``not_valid_before``, the reach is the number
    of certificates valid on that day that transitively reference the certificate. The plot is
    saved to ``RESULTS_DIR / "lineplot_top_certificate_reach.pdf"``.

    :param df__: Frame of certificates with reference columns, their counts and validity dates.
    :return: Empty dictionary (placeholder for LaTeX commands).
    """

    def find_reach_over_time(df_: pd.DataFrame, cert_id: str, date_range: pd.DatetimeIndex) -> pd.Series:
        # Certificates that reach `cert_id` are those that list it among their own transitive
        # references. (Filtering on `in_trans_refs` would instead select the certificates that
        # `cert_id` itself references.)
        references_cert = df_.trans_refs.map(lambda x: bool(pd.notnull(x) and cert_id in x)).astype(bool)
        referencing = df_.loc[references_cert]
        dct = {
            date: referencing.loc[
                (date >= referencing.not_valid_before) & (date <= referencing.not_valid_after)
            ].shape[0]
            for date in date_range
        }
        return pd.Series(dct, name=cert_id)

    df = df__.copy()
    top_10_certs = df.sort_values(by="n_in_trans_refs", ascending=False).head(10)
    print(top_10_certs[["cert_id", "n_in_trans_refs"]])

    date_range = pd.date_range(df.not_valid_before.min(), df.not_valid_before.max())
    data = [find_reach_over_time(df, x, date_range) for x in tqdm(top_10_certs.cert_id.tolist())]
    df_reach_evolution_melted = (
        pd.concat(data, axis=1)
        .rename_axis("date")
        .reset_index()
        .melt(id_vars="date", var_name="certificate", value_name="reach")
    )

    g = sns.lineplot(data=df_reach_evolution_melted, x="date", y="reach", hue="certificate")
    g.set(title="Reach of top-10 certificates in time", xlabel="Time", ylabel="Certificate reach")
    plt.savefig(RESULTS_DIR / "lineplot_top_certificate_reach.pdf", bbox_inches="tight")
    plt.show()

    return {}


compute_certs_top_reach(cc_df)


## Average number of references & certificate reach over time

In [None]:
def compute_avg_references(df__: pd.DataFrame, variable: str, date_range: pd.DatetimeIndex) -> dict:
    """
    For every date in ``date_range``, average ``variable`` over the certificates valid on that date.

    A certificate is valid on a date when ``not_valid_before <= date <= not_valid_after``.
    Returns a dictionary mapping each date to that mean.
    """
    certs = df__.copy()
    averages = {}
    for day in tqdm(date_range):
        valid_on_day = (certs.not_valid_before <= day) & (certs.not_valid_after >= day)
        averages[day] = certs.loc[valid_on_day, variable].mean()
    return averages


def compute_avg_references_over_time(df__: pd.DataFrame) -> dict:
    """
    Plot the daily average of direct and transitive reference counts, split into smartcards
    and all other categories.

    The figure is saved to ``RESULTS_DIR / "lineplot_avg_n_references.pdf"``; an empty dictionary
    is returned.
    """
    certs = df__.copy()
    date_range = pd.date_range(certs.not_valid_before.min(), certs.not_valid_before.max())
    smartcards = certs.loc[certs.category == SMARTCARD_CATEGORY]
    others = certs.loc[certs.category != SMARTCARD_CATEGORY]

    # Evaluated smartcards first, then the other categories.
    averages = {
        "smartcard references": compute_avg_references(smartcards, "n_refs", date_range),
        "smartcard transitive references": compute_avg_references(smartcards, "n_trans_refs", date_range),
        "other references": compute_avg_references(others, "n_refs", date_range),
        "other transitive references": compute_avg_references(others, "n_trans_refs", date_range),
    }
    # Column order of the plotted frame (and thus the order of the legend).
    column_order = [
        "smartcard references",
        "other references",
        "smartcard transitive references",
        "other transitive references",
    ]

    wide = pd.concat([pd.Series(averages[name], name=name) for name in column_order], axis=1)
    df_avg_num_refs_melted = (
        wide.rename_axis("date").reset_index().melt(id_vars=["date"], var_name="category", value_name="n_references")
    )

    axis = sns.lineplot(data=df_avg_num_refs_melted, x="date", y="n_references", hue="category")
    axis.set(title="Average number of references in certificates", xlabel="Time", ylabel="Number of references")
    plt.savefig(RESULTS_DIR / "lineplot_avg_n_references.pdf", bbox_inches="tight")
    plt.show()

    return {}


def compute_avg_reach_over_time(df__: pd.DataFrame) -> dict:
    """
    Plot the daily average certificate reach (``n_in_trans_refs``) for smartcards and for all
    other categories.

    The figure is saved to ``RESULTS_DIR / "lineplot_avg_reach.pdf"``.

    :param df__: Frame of certificates with ``category``, ``n_in_trans_refs`` and validity dates.
    :return: Empty dictionary (placeholder for LaTeX commands).
    """
    df = df__.copy()
    date_range = pd.date_range(df.not_valid_before.min(), df.not_valid_before.max())
    reach_smartcards = compute_avg_references(df.loc[df.category == SMARTCARD_CATEGORY], "n_in_trans_refs", date_range)
    reach_others = compute_avg_references(df.loc[df.category != SMARTCARD_CATEGORY], "n_in_trans_refs", date_range)

    df_avg_reach_melted = (
        pd.concat(
            [
                pd.Series(reach_smartcards, name="smartcard reach"),
                pd.Series(reach_others, name="other reach"),
            ],
            axis=1,
        )
        .rename_axis("date")
        .reset_index()
        .melt(id_vars=["date"], var_name="category", value_name="n_references")
    )

    g = sns.lineplot(data=df_avg_reach_melted, x="date", y="n_references", hue="category")
    g.set(
        title="Average certificate reach over time",
        xlabel="Time",
        ylabel="Number of (transitively) referencing certificates",
    )
    # Dedicated file name: "lineplot_avg_n_references.pdf" is the output of
    # `compute_avg_references_over_time` and used to be silently overwritten here.
    plt.savefig(RESULTS_DIR / "lineplot_avg_reach.pdf", bbox_inches="tight")
    plt.show()

    return {}


compute_avg_references_over_time(cc_df)
compute_avg_reach_over_time(cc_df)


## Number of active vs. number of reference-rich certificates in time

In [None]:
def compute_number_of_active_vs_ref_rich_certs_over_time(df__: pd.DataFrame) -> dict:
    """
    Plot, per day, the number of valid certificates and the number of valid certificates with at
    least one direct reference, separately for smartcards and for the other categories.

    The figure is saved into ``RESULTS_DIR``; an empty dictionary is returned.
    """
    certs = df__.copy()
    date_range = pd.date_range(certs.not_valid_before.min(), certs.not_valid_before.max())

    series_names = (
        "active other categories",
        "active smartcards",
        "ref. rich other categories",
        "ref. rich smartcards",
    )
    daily_counts = {name: {} for name in series_names}

    for day in tqdm(date_range):
        valid = certs.loc[(certs.not_valid_before <= day) & (certs.not_valid_after >= day)]
        is_smartcard = valid.category == SMARTCARD_CATEGORY
        is_other = valid.category != SMARTCARD_CATEGORY
        has_refs = valid.n_refs > 0

        daily_counts["active other categories"][day] = valid.loc[is_other].shape[0]
        daily_counts["active smartcards"][day] = valid.loc[is_smartcard].shape[0]
        daily_counts["ref. rich other categories"][day] = valid.loc[is_other & has_refs].shape[0]
        daily_counts["ref. rich smartcards"][day] = valid.loc[is_smartcard & has_refs].shape[0]

    wide = pd.concat([pd.Series(daily_counts[name], name=name) for name in series_names], axis=1)
    df_active_vs_ref_rich_melted = (
        wide.rename_axis("date")
        .reset_index()
        .melt(id_vars=["date"], var_name="category", value_name="number of certificates")
    )

    axis = sns.lineplot(data=df_active_vs_ref_rich_melted, x="date", y="number of certificates", hue="category")
    axis.set(
        title="Number of active certificates vs. reference-rich certificates in time",
        xlabel="Time",
        ylabel="Number of certificates",
    )
    plt.savefig(RESULTS_DIR / "lienplot_n_active_certs_vs_n_references.pdf", bbox_inches="tight")
    plt.show()
    return {}


def compute_summary_active_vs_ref_rich_over_time(df__: pd.DataFrame) -> dict:
    """
    Plot, per day, the number of valid certificates that are reference-rich, referenced or isolated.

    Two figures are produced: absolute numbers (``lineplot_references_summary.pdf``) and the same
    series as a ratio of the valid certificates on that day (``lineplot_reference_ratio.pdf``).
    "Reference-rich" means ``n_refs > 0``, "referenced" means ``n_in_refs > 0`` and "isolated"
    means both counts are zero.

    :param df__: Frame of certificates with ``n_refs``, ``n_in_refs`` and validity dates.
    :return: Empty dictionary (placeholder for LaTeX commands).
    """
    df = df__.copy()
    date_range = pd.date_range(df.not_valid_before.min(), df.not_valid_before.max())

    dct_active_all = {}
    dct_reference_rich_all = {}
    dct_referenced_all = {}
    dct_isolated_all = {}

    for date in tqdm(date_range):
        active_certs = df.loc[(date >= df.not_valid_before) & (date <= df.not_valid_after)]
        dct_active_all[date] = active_certs.shape[0]
        dct_isolated_all[date] = active_certs.loc[(active_certs.n_refs == 0) & (active_certs.n_in_refs == 0)].shape[0]
        dct_reference_rich_all[date] = active_certs.loc[active_certs.n_refs > 0].shape[0]
        dct_referenced_all[date] = active_certs.loc[active_certs.n_in_refs > 0].shape[0]

    df_summary_references = (
        pd.concat(
            [
                pd.Series(dct_active_all, name="active certificates"),
                pd.Series(dct_reference_rich_all, name="ref. rich certificates"),
                pd.Series(dct_referenced_all, name="referenced certificates"),
                pd.Series(dct_isolated_all, name="isolated certificates"),
            ],
            axis=1,
        )
        .rename_axis("date")
        .reset_index()
    )

    df_summary_references_melted = df_summary_references.melt(
        id_vars=["date"], var_name="category", value_name="number of certificates"
    )
    g = sns.lineplot(
        data=df_summary_references_melted, x="date", y="number of certificates", hue="category", errorbar=None
    )
    g.set(
        title="Number of active certificates vs. reference-rich vs. referenced certificates in time",
        xlabel="Time",
        ylabel="Number of certificates",
    )
    plt.savefig(RESULTS_DIR / "lineplot_references_summary.pdf", bbox_inches="tight")
    plt.show()

    # Express the three series as a share of the certificates valid on the given day.
    ratio_columns = ["ref. rich certificates", "referenced certificates", "isolated certificates"]
    df_ratios = df_summary_references.copy()
    df_ratios[ratio_columns] = df_ratios[ratio_columns].div(df_ratios["active certificates"], axis=0)
    df_ratios = df_ratios.drop(columns=["active certificates"])
    df_ratios_melted = df_ratios.melt(id_vars=["date"], var_name="category", value_name="ratio of certificates")

    g = sns.lineplot(data=df_ratios_melted, x="date", y="ratio of certificates", hue="category", errorbar=None)
    g.set(
        title="ratio of reference-rich vs. referenced vs. isolated certificates in time",
        xlabel="Time",
        # This plot shows ratios; the label used to be copied from the absolute plot above.
        ylabel="Ratio of certificates",
    )
    plt.savefig(RESULTS_DIR / "lineplot_reference_ratio.pdf", bbox_inches="tight")
    plt.show()

    return {}


compute_number_of_active_vs_ref_rich_certs_over_time(cc_df)
compute_summary_active_vs_ref_rich_over_time(cc_df)


## Number of active certificates that reference some archived certificate in time

In [None]:
def compute_certs_referencing_archived_ones(df__: pd.DataFrame) -> dict:
    """
    Plot, per day, the number of valid certificates whose references point only to certificates
    that are not valid on that day.

    A certificate is counted when it has at least one (direct, resp. transitive) reference and none
    of those references is the ``cert_id`` of a certificate valid on the given day. The counts are
    split into smartcards and all other categories. The figure is saved into ``RESULTS_DIR``.

    NOTE(review): the plot title speaks about referencing *some* archived certificate, while the
    condition requires that *no* reference is active — confirm which of the two is intended.

    :param df__: Frame of certificates with ``refs``, ``trans_refs``, ``category`` and validity dates.
    :return: Empty dictionary (placeholder for LaTeX commands).
    """
    df = df__.copy()
    date_range = pd.date_range(df.not_valid_before.min(), df.not_valid_before.max())

    dct_direct_others = {}
    dct_direct_smartcards = {}
    dct_transitive_others = {}
    dct_transitive_smartcards = {}

    for date in tqdm(date_range):
        active_certs = df.loc[(date >= df.not_valid_before) & (date <= df.not_valid_after)].copy()
        active_certs_cert_ids = set(active_certs["cert_id"].tolist())

        def has_no_active_reference(references) -> bool:
            # An empty reference set must not count: `not set().intersection(...)` is True, which
            # used to flag certificates that do not reference anything at all.
            if pd.isnull(references) or not references:
                return False
            return not references.intersection(active_certs_cert_ids)

        active_certs["no_intersection"] = active_certs.refs.map(has_no_active_reference).astype(bool)
        active_certs["no_transitive_intersection"] = active_certs.trans_refs.map(has_no_active_reference).astype(
            bool
        )
        dct_direct_others[date] = active_certs.loc[
            (active_certs.no_intersection) & (active_certs.category != SMARTCARD_CATEGORY)
        ].shape[0]
        dct_transitive_others[date] = active_certs.loc[
            (active_certs.no_transitive_intersection) & (active_certs.category != SMARTCARD_CATEGORY)
        ].shape[0]
        dct_direct_smartcards[date] = active_certs.loc[
            (active_certs.no_intersection) & (active_certs.category == SMARTCARD_CATEGORY)
        ].shape[0]
        dct_transitive_smartcards[date] = active_certs.loc[
            (active_certs.no_transitive_intersection) & (active_certs.category == SMARTCARD_CATEGORY)
        ].shape[0]

    df_refs_to_archived_melted = (
        pd.concat(
            [
                pd.Series(dct_direct_others, name="direct reference others"),
                pd.Series(dct_transitive_others, name="transitive reference others"),
                pd.Series(dct_direct_smartcards, name="direct reference smartcards"),
                pd.Series(dct_transitive_smartcards, name="transitive reference smartcards"),
            ],
            axis=1,
        )
        .rename_axis("date")
        .reset_index()
        .melt(id_vars=["date"], var_name="reference type", value_name="number of certificates")
    )

    g = sns.lineplot(data=df_refs_to_archived_melted, x="date", y="number of certificates", hue="reference type")
    g.set(
        title="Number of active certificates that reference some archived certificate",
        xlabel="Time",
        ylabel="Number of certificates",
    )
    plt.savefig(RESULTS_DIR / "lienplot_active_certs_referencing_archived.pdf", bbox_inches="tight")
    plt.show()

    return {}


compute_certs_referencing_archived_ones(cc_df)


## Certificates referencing vulnerable certificates in time

In [None]:
def compute_certs_referencing_vulnerable_over_time(df__: pd.DataFrame) -> dict:
    """
    Plot, per day, the number of valid certificates that reference some vulnerable certificate.

    A certificate is vulnerable when its ``related_cves`` is not null. Direct references are read
    from ``refs``, transitive ones from ``trans_refs``; the counts are split into smartcards and
    all other categories. The figure is saved into ``RESULTS_DIR``.

    :param df__: Frame of certificates with ``refs``, ``trans_refs``, ``related_cves``,
        ``category`` and validity dates.
    :return: Empty dictionary (placeholder for LaTeX commands).
    """
    df = df__.copy()
    date_range = pd.date_range(df.not_valid_before.min(), df.not_valid_before.max())
    vulnerable_cert_ids = set(df.loc[df.related_cves.notnull()].cert_id.tolist())

    def references_vulnerable_cert(references) -> bool:
        return False if pd.isnull(references) else bool(references.intersection(vulnerable_cert_ids))

    # Whether a certificate references a vulnerable one does not depend on the date, so the flags
    # are computed once instead of once per day inside the loop.
    direct = df.refs.map(references_vulnerable_cert).astype(bool)
    transitive = df.trans_refs.map(references_vulnerable_cert).astype(bool)
    is_smartcard = df.category == SMARTCARD_CATEGORY
    is_other = df.category != SMARTCARD_CATEGORY

    # Insertion order defines the column order of the plotted frame.
    masks = {
        "direct references others": direct & is_other,
        "transitive references others": transitive & is_other,
        "direct references smartcards": direct & is_smartcard,
        "transitive references smartcards": transitive & is_smartcard,
    }
    counts = {name: {} for name in masks}

    for date in tqdm(date_range):
        active = (date >= df.not_valid_before) & (date <= df.not_valid_after)
        for name, mask in masks.items():
            counts[name][date] = int((active & mask).sum())

    df_references_vuln_melted = (
        pd.concat([pd.Series(per_date, name=name) for name, per_date in counts.items()], axis=1)
        .rename_axis("date")
        .reset_index()
        .melt(id_vars=["date"], var_name="reference type", value_name="number of certificates")
    )

    g = sns.lineplot(data=df_references_vuln_melted, x="date", y="number of certificates", hue="reference type")
    g.set(
        title="Number of active certificates that reference some vulnerable certificate",
        xlabel="Time",
        ylabel="Number of certificates",
    )
    plt.savefig(RESULTS_DIR / "lienplot_active_certs_referencing_vulnerable.pdf", bbox_inches="tight")
    plt.show()
    return {}


compute_certs_referencing_vulnerable_over_time(cc_df)


### Plot direct references per category (count plot)

In [None]:
def plot_direct_refs_per_category(df__: pd.DataFrame) -> dict:
    """
    Show two count plots of certificates per category: one split by whether the certificate has
    outgoing direct references (``n_refs > 0``), one by whether it has incoming ones
    (``n_in_refs > 0``).

    :param df__: Frame of certificates with ``category``, ``n_refs`` and ``n_in_refs``.
    :return: Empty dictionary (placeholder for LaTeX commands).
    """
    df = df__.copy().assign(
        has_outgoing_direct_references=lambda df_: df_.n_refs > 0,
        has_incoming_direct_references=lambda df_: df_.n_in_refs > 0,
    )
    figure, axes = plt.subplots(1, 2)
    figure.set_size_inches(16, 10)
    figure.set_tight_layout(True)

    col_to_depict = ["has_outgoing_direct_references", "has_incoming_direct_references"]

    for index, col in enumerate(col_to_depict):
        readable_col = " ".join(col.split("_"))
        countplot = sns.countplot(data=df, x="category", hue=col, ax=axes[index])
        countplot.set(
            xlabel="Category",
            # The bars count certificates; the former label "Outgoing direct references" was also
            # shown on the panel about incoming references.
            ylabel="Number of certificates",
            title=f"Countplot of {readable_col}",
        )
        countplot.tick_params(axis="x", rotation=90)
        countplot.legend(title=readable_col, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)

    plt.show()

    return {}


plot_direct_refs_per_category(cc_df)


### Plot direct references per category (Sankey diagram)

In [None]:
def plot_sankey_refs_categories(df__: pd.DataFrame) -> dict:
    """
    Draw a Sankey diagram of direct references between certificate categories.

    Every direct reference contributes one flow from the category of the referencing certificate
    to the category of the referenced one. References to certificates that are not present in
    ``df__`` (or that have no category) are left out.

    :param df__: Frame of certificates with ``cert_id``, ``category`` and ``refs``.
    :return: Empty dictionary (placeholder for LaTeX commands).
    """
    df = df__.copy()

    cert_id_to_category_mapping = dict(zip(df.cert_id, df.category))

    exploded = df.loc[:, ["category", "refs"]].explode("refs")
    # `Series.map` with a dict yields NaN both for missing references and for referenced ids that
    # are absent from the frame; direct indexing raised KeyError for the latter.
    exploded["ref_category"] = exploded.refs.map(cert_id_to_category_mapping)
    exploded = exploded.loc[exploded.ref_category.notnull()]

    all_categories = set(exploded.category.unique()) | set(exploded.ref_category.unique())
    colors = list(sns.color_palette("hls", len(all_categories), as_cmap=False).as_hex())
    color_dict = dict(zip(all_categories, colors))

    figure, axes = plt.subplots(1, 1)
    figure.set_size_inches(24, 10)
    figure.set_tight_layout(True)

    sankey(
        exploded.category,
        exploded.ref_category,
        colorDict=color_dict,
        leftLabels=list(exploded.category.unique()),
        rightLabels=list(exploded.ref_category.unique()),
        fontsize=12,
        ax=axes,
    )

    plt.show()

    return {}


plot_sankey_refs_categories(cc_df)


### Plot direct references per scheme (count plot)

In [None]:
def plot_refs_per_scheme(df__: pd.DataFrame) -> dict:
    """
    Show two count plots of certificates per scheme: one split by whether the certificate has
    outgoing direct references (``n_refs > 0``), one by whether it has incoming ones
    (``n_in_refs > 0``).

    :param df__: Frame of certificates with ``scheme``, ``n_refs`` and ``n_in_refs``.
    :return: Empty dictionary (placeholder for LaTeX commands).
    """
    df = df__.copy().assign(
        has_outgoing_direct_references=lambda df_: df_.n_refs > 0,
        has_incoming_direct_references=lambda df_: df_.n_in_refs > 0,
    )
    figure, axes = plt.subplots(1, 2)
    figure.set_size_inches(14, 4)
    figure.set_tight_layout(True)

    col_to_depict = ["has_outgoing_direct_references", "has_incoming_direct_references"]

    for index, col in enumerate(col_to_depict):
        readable_col = " ".join(col.split("_"))
        countplot = sns.countplot(data=df, x="scheme", hue=col, ax=axes[index])
        countplot.set(
            # The x axis shows schemes; the labels used to be copied from the per-category plot.
            xlabel="Scheme",
            ylabel="Number of certificates",
            title=f"Countplot of {readable_col}",
        )
        countplot.tick_params(axis="x", rotation=90)
        countplot.legend(title=readable_col, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)

    plt.show()

    return {}


plot_refs_per_scheme(cc_df)


### Number of certificates referencing archived certificates (count plot)

In [None]:
def countplot_certs_referencing_archived(df__: pd.DataFrame) -> dict:
    """
    Count and plot the certificates that directly reference some archived certificate.

    A certificate is archived when its ``status`` equals ``"archived"``. The number of certificates
    with at least one direct reference (``refs``) to an archived certificate is printed, and count
    plots of that flag per category and per scheme are shown.

    :param df__: Frame of certificates with ``cert_id``, ``status``, ``refs``, ``category`` and ``scheme``.
    :return: Empty dictionary (placeholder for LaTeX commands).
    """
    df = df__.copy()
    archived_cert_ids = set(df.loc[((df.cert_id.notnull()) & (df.status == "archived")), "cert_id"].tolist())

    def references_archived_cert(references) -> bool:
        if pd.isnull(references):
            return False
        # The former `any([x in cert_ids] for x in references)` tested one-element lists, which
        # are always truthy, so every non-empty collection was flagged.
        return not archived_cert_ids.isdisjoint(references)

    # Outgoing references (`refs`) are the ones a certificate makes; `in_refs` would instead flag
    # certificates that are referenced *by* an archived certificate.
    df["references_archived_cert"] = df.refs.map(references_archived_cert).astype(bool)

    # TODO: We should limit on the number of certificates that referenced an archived certificate at some point where they were active as well.
    print(
        f"Number of certificates that reference some archived certificate: {df.loc[df.references_archived_cert].shape[0]}"
    )

    col_to_depict = ["category", "scheme"]

    figure, axes = plt.subplots(1, 2)
    figure.set_size_inches(14, 8)
    figure.set_tight_layout(True)

    for index, col in enumerate(col_to_depict):
        countplot = sns.countplot(data=df, x=col, hue="references_archived_cert", ax=axes[index])
        countplot.set(
            xlabel=col,
            ylabel="Number of certificates",
            title="Countplot of certificates that reference some archived certificate",
        )
        countplot.tick_params(axis="x", rotation=90)
        countplot.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)

    plt.show()

    return {}


countplot_certs_referencing_archived(cc_df)


### Count scheme references (Sankey diagram)

In [None]:
def plot_sankey_refs_schemes(df__: pd.DataFrame) -> dict:
    """
    Draw a Sankey diagram of direct references between certification schemes.

    Every direct reference whose target has a known scheme contributes one flow from the scheme
    of the referencing certificate to the scheme of the referenced one. The figure is saved as
    ``scheme_references.pdf`` and ``scheme_references.pgf`` into ``RESULTS_DIR``.
    """
    scheme_of_cert = dict(zip(df__.cert_id, df__.scheme))

    # One row per (referencing scheme, referenced certificate) pair.
    certs = df__.copy()
    referencing = certs.loc[certs.refs.notnull(), ["scheme", "refs"]]
    edges = referencing.explode("refs")
    edges = edges.assign(ref_scheme=edges.refs.map(scheme_of_cert))
    edges = edges.loc[edges.ref_scheme.notnull()]

    all_schemes = set(edges.scheme.unique()) | set(edges.ref_scheme.unique())
    palette = sns.color_palette("hls", len(all_schemes), as_cmap=False).as_hex()
    color_dict = dict(zip(all_schemes, list(palette)))

    figure, axes = plt.subplots(1, 1)
    figure.set_size_inches(4, 4)
    figure.set_tight_layout(True)

    sankey(
        edges.scheme,
        edges.ref_scheme,
        colorDict=color_dict,
        leftLabels=list(edges.scheme.unique()),
        rightLabels=list(edges.ref_scheme.unique()),
        fontsize=7,
        ax=axes,
    )

    for extension in ("pdf", "pgf"):
        figure.savefig(str(RESULTS_DIR / f"scheme_references.{extension}"), bbox_inches="tight")
    plt.show()

    return {}


plot_sankey_refs_schemes(cc_df)


## Reference network visualization

In [None]:
# Print:
# - How many references in reports
# - How many references in targets
# - How many references in total
# -


### Combined references

### Certificate overview
Enter the certificate you are interested in below and see its reference graph component.

In [None]:
# Certificate whose neighbourhood in the reference graph is inspected below.
cert_id = "ANSSI-CC-2014/07"

# Find the weakly connected component containing the certificate. The loop variable `component`
# is deliberately left bound after `break`; the `else` branch runs only if no component matched.
for component in nx.weakly_connected_components(graph):
    if cert_id in component:
        break
else:
    raise ValueError(f"Certificate with id {cert_id} not found in graph.")

# Read-only view of `graph` restricted to the nodes of that component.
view = nx.subgraph_view(graph, lambda node: node in component)
print(f"Certificate with id {cert_id}:")
print(f" - is in a component with {len(view.nodes)} certificates and {len(view.edges)} references.")
print(f" - references {list(view[cert_id].keys())}")
print(f" - is referenced by {list(view.predecessors(cert_id))}")
# Look up the dataset entry with this id to print a link; `cert` stays bound after `break`.
for cert in dset:
    if cert.heuristics.cert_id == cert_id:
        break
else:
    raise ValueError(f"Certificate with id {cert_id} not found in dataset.")
print(f" - its page is at https://seccerts.org/cc/{cert.dgst}/")


In [None]:
# `planar_layout` raises for components that are not planar, so fall back to a seeded
# (i.e. reproducible) spring layout in that case instead of failing the cell.
try:
    layout = nx.planar_layout(view)
except nx.NetworkXException:
    layout = nx.spring_layout(view, seed=0)
nx.draw(view, pos=layout, with_labels=True)


## Some graph metrics
See:
- <https://dataground.io/2021/09/29/simple-graph-metrics-networkx-for-beginners/>
- <https://theslaps.medium.com/centrality-metrics-via-networkx-python-e13e60ba2740>
- <https://www.geeksforgeeks.org/network-centrality-measures-in-a-graph-using-networkx-python/>


In [None]:
# Global metrics of the whole reference graph.
print(f"Density = {nx.density(graph)}")
print(f"Transitivity = {nx.transitivity(graph)}")

# Each centrality below is computed for every node, sorted in descending order and the 20 highest
# values are printed.
print("Degree centrality <Popularity> (top 20):")
degree_centrality_vals = list(nx.degree_centrality(graph).items())
degree_centrality_vals.sort(key=lambda pair: pair[1], reverse=True)
for pair in degree_centrality_vals[:20]:
    print(f"\t{pair[0]} = {pair[1]}")

# NOTE(review): eigenvector centrality is computed iteratively and may fail to converge on some
# graphs — confirm it converges for the reference graph or raise `max_iter`.
print("Eigenvector centrality <Influence> (top 20):")
eigenvector_centrality_vals = list(nx.eigenvector_centrality(graph).items())
eigenvector_centrality_vals.sort(key=lambda pair: pair[1], reverse=True)
for pair in eigenvector_centrality_vals[:20]:
    print(f"\t{pair[0]} = {pair[1]}")

print("Closeness centrality <Centralness> (top 20):")
closeness_centrality_vals = list(nx.closeness_centrality(graph).items())
closeness_centrality_vals.sort(key=lambda pair: pair[1], reverse=True)
for pair in closeness_centrality_vals[:20]:
    print(f"\t{pair[0]} = {pair[1]}")

print("Betweenness centrality <Bridge> (top 20):")
betweenness_centrality_vals = list(nx.betweenness_centrality(graph).items())
betweenness_centrality_vals.sort(key=lambda pair: pair[1], reverse=True)
for pair in betweenness_centrality_vals[:20]:
    print(f"\t{pair[0]} = {pair[1]}")

# Sizes of the weakly connected components with more than one node, largest first.
component_lengths = list(filter(lambda comp_len: comp_len > 1, map(len, nx.weakly_connected_components(graph))))
component_lengths.sort(reverse=True)
# print(component_lengths)
print(f"Number of weakly connected subgraphs: {len(component_lengths)}")
print(f"Size of the largest weakly connected subgraphs: {component_lengths[:10]}")

# Community detection (greedy modularity maximisation) on the largest weakly connected component.
big_boy = graph.subgraph(max(nx.weakly_connected_components(graph), key=len))
communities = list(nx_comm.greedy_modularity_communities(big_boy))
print(len(communities))

# Print the members of every community in sorted order.
for com in communities:
    for i in sorted(com):
        print(f"\t{i}")


## LaTeX commands

In [None]:
# TODO: These are old commands that belonged to first paper. Replace them with code that can produce the similar numbers.
# print(f"\\newcommand{{\\numCcAllDirectReferencing}}{{{df.has_outgoing_direct_references.sum()}}}")
# print(f"\\newcommand{{\\numCcAllNotDirectReferencing}}{{{len(df) - df.has_outgoing_direct_references.sum()}}}")
# print(f"\\newcommand{{\\numCcWithIdDirectReferencing}}{{{df_id_rich.has_outgoing_direct_references.sum()}}}")
# print(f"\\newcommand{{\\numCcWithIdNotDirectReferencing}}{{{len(df_id_rich) - df_id_rich.has_outgoing_direct_references.sum()}}}")
# print(f"\\newcommand{{\\numCCActiveDirectReferencing}}{{{df_id_rich.loc[df_id_rich.status == 'active'].has_outgoing_direct_references.sum()}}}")

# print("")
# print(f"\\newcommand{{\\numCCDirectRefsSameCategory}}{{{(exploded_with_refs.category == exploded_with_refs.ref_category).sum()}}}")
# print(f"\\newcommand{{\\numCCDirectRefsOtherCategory}}{{{(exploded_with_refs.category != exploded_with_refs.ref_category).sum()}}}")
# print(f"\\newcommand{{\\numCCDirectRefs}}{{{len(exploded_with_refs)}}}")
# print(f"\\newcommand{{\\numCCDirectRefsFromSmartcards}}{{{(exploded_with_refs.category == 'ICs, Smart Cards and Smart Card-Related Devices and Systems').sum()}}}")

# print("")
# print(f"\\newcommand{{\\numCCUSReferencing}}{{{len(df_id_rich.loc[(df_id_rich.scheme == 'US') & (df_id_rich.directly_referencing.notnull())])}}}")
# print(f"\\newcommand{{\\numCCUS}}{{{len(df_id_rich.loc[(df_id_rich.scheme == 'US')])}}}")
