In [None]:
import json

def load_metadata(metadata_path):
    """Load the algtest metadata file and group its entries by vendor.

    The file is expected to be a JSON object whose ``"entries"`` member is a
    mapping of entry objects, each carrying a ``"vendor"`` key.

    Args:
        metadata_path: path to the metadata JSON file.

    Returns:
        dict mapping vendor -> list of entries (in file order). An empty dict
        is returned, and a message printed, when the file cannot be read or
        parsed, has an unexpected structure, or contains no entries. Entries
        without a vendor are reported and skipped.
    """
    try:
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        entries = list(metadata["entries"].values()) if metadata else []
    except (OSError, ValueError, KeyError, TypeError, AttributeError):
        # Missing/unreadable file, invalid JSON (JSONDecodeError is a
        # ValueError), or a document without a mapping under "entries".
        entries = []

    if not entries:
        print("load_metadata: retrieving metadata was unsuccessful")
        return {}

    # We now group entries by vendor
    grouped = {}
    for entry in entries:
        vendor = entry.get("vendor")
        if vendor is None:
            print(f"load_metadata: entry {entry} does not contain vendor")
            continue
        grouped.setdefault(vendor, []).append(entry)
    return grouped

# Directory holding the algtest-pyprocess outputs read below.
DATA_DIR = "./algtest-pyprocess"

metadata = load_metadata(f"{DATA_DIR}/metadata.json")

# JSON text is UTF-8 (RFC 8259); read it as such rather than relying on the
# platform's default locale encoding.
with open(f"{DATA_DIR}/tpm-pcr-metadata.json", "r", encoding="utf-8") as f:
    tpm_pcr_metadata = json.load(f)

In [None]:
import logging, sys
# Suppress all logging for the rest of the kernel session: logging.disable(level)
# drops records at `level` and below, and sys.maxsize is above every level.
logging.disable(sys.maxsize)

In [None]:
from algtestprocess.modules.data.tpm.manager import TPMProfileManager
from algtestprocess.modules.data.tpm.enums import CryptoPropResultCategory as cat

def firmware_name(profile):
    """Return the display name ``"<manufacturer> <firmware_version>"`` of a profile."""
    return "{} {}".format(profile.manufacturer, profile.firmware_version)

def group_by_vendor(cpps):
    """Bucket objects by their ``manufacturer`` attribute.

    Returns a dict mapping manufacturer -> list of the objects with that
    manufacturer, preserving input order both for keys and within buckets.
    """
    buckets = {}
    for item in cpps:
        buckets.setdefault(item.manufacturer, []).append(item)
    return buckets

def aggregate_by_firmware(cpps):
    """Merge cryptoprops objects that share the same firmware name.

    Args:
        cpps: iterable of cryptoprops objects; ``None`` entries are reported
            and skipped.

    Returns:
        A list with one object per distinct ``firmware_name(cpp)``, in order
        of first appearance. Objects with the same name are combined with
        ``+``; a name seen only once maps to the input object itself.
    """
    aggregated = {}
    for cpp in cpps:
        if cpp is None:
            # There is no name to report for a missing entry, only skip it.
            print("Cannot merge a cryptoprops entry, as it is None")
            continue

        fwn = firmware_name(cpp)
        if fwn in aggregated:
            aggregated[fwn] = aggregated[fwn] + cpp
        else:
            aggregated[fwn] = cpp
    return list(aggregated.values())

def aggregate_by_implementation(cpps, groups):
    """Merge cryptoprops objects into implementation groups defined by regexes.

    Args:
        cpps: iterable of cryptoprops objects.
        groups: list of groups; each group is a list of regex patterns
            matched with ``re.match`` (anchored at the start) against
            ``firmware_name(cpp)``.

    Returns:
        A list with one slot per group: the ``+``-combination of every cpp
        whose firmware name matches any pattern of that group, or ``None``
        when nothing matched. A cpp matching several groups is merged into
        each of them.
    """
    # Local import: at module level ``re`` is only imported in a later cell.
    import re

    compiled_groups = [[re.compile(pattern) for pattern in group] for group in groups]

    aggregated = [None] * len(compiled_groups)
    for cpp in cpps:
        fwn = firmware_name(cpp)
        for i, patterns in enumerate(compiled_groups):
            if not any(pattern.match(fwn) for pattern in patterns):
                continue
            # ``a + b`` (as in aggregate_by_firmware) rather than ``+=``, so the
            # first matching object - shared with the caller's list - is not
            # updated through ``__iadd__`` should the class define one.
            aggregated[i] = cpp if aggregated[i] is None else aggregated[i] + cpp
    return aggregated

def compute_pqn_bytes(df):
    """Return the most significant bytes of the primes p, q and the modulus n.

    ``df`` must provide hex-string columns ``"p"`` and ``"n"``; rows missing
    either value are ignored. The dataset has no q column, so q is recovered
    as ``n // p``.

    Returns:
        Tuple ``(p_byte, q_byte, n_byte)`` of equal-length lists of ints, each
        value being the highest (possibly partial) byte of the number.

    Raises:
        ValueError: if no row has both ``p`` and ``n``.
    """
    complete = df.dropna(subset=["p", "n"])
    if complete.index.size < 1:
        raise ValueError("visualized dataframe must not be empty")

    def top_byte(value):
        # Keep only the bits of the leading byte: width % 8 bits when the
        # width is not byte-aligned, otherwise a full 8 bits.
        width = value.bit_length()
        kept = width % 8 or 8
        return value >> (width - kept)

    moduli = [int(str(v), 16) for v in complete.n]
    primes = [int(str(v), 16) for v in complete.p]
    cofactors = [m // prime for m, prime in zip(moduli, primes)]

    return (
        [top_byte(v) for v in primes],
        [top_byte(v) for v in cofactors],
        [top_byte(v) for v in moduli],
    )


# Parse every measurement directory listed in the metadata; measurements for
# which TPMProfileManager yields no cryptoprops are skipped.
cryptoprops = []
for vendor_entries in metadata.values():
    for entry in vendor_entries:
        for measurement_path in entry['measurement paths']:
            profile_cryptoprops = TPMProfileManager(measurement_path).cryptoprops
            if profile_cryptoprops is not None:
                cryptoprops.append(profile_cryptoprops)

# Every measured TPM, one cryptoprops object per measurement
tpms = cryptoprops
# ... bucketed by manufacturer
tpms_grouped_by_vendor = group_by_vendor(tpms)
# Measurements merged per firmware name ("<manufacturer> <firmware_version>")
tpms_aggregated = aggregate_by_firmware(cryptoprops)
# ... and those merged results bucketed by manufacturer
tpms_aggregated_grouped_by_vendor = group_by_vendor(tpms_aggregated)

In [None]:
import re

# Firmware-name patterns (checked with ``re.match``) whose EKs are flagged by
# get_improbable_ek and drawn in red by bihist.
# NOTE(review): ``\.*`` means "zero or more dots", so e.g. ``IFX.*7\.40\.*``
# also matches aggregated labels such as ``IFX 7.40-61.X``; confirm this
# looseness is intended before tightening the patterns to ``\..*``.
roca_blacklist = [r"IFX.*7\.40\.*", r"IFX.*7\.61\.*", r"IFX.*5\.61\.*"]

def _first_ek_msb(cpp):
    """Return the top byte of the first RSA EK prefix recorded for ``cpp``, or None.

    None is returned when there is no EK_RSA result, its data is None, or
    the data is empty.
    """
    result = cpp.results.get(cat.EK_RSA)
    if result is None or result.data is None or len(result.data) == 0:
        return None
    ek_prefix, _ = result.data[0]
    # NOTE(review): assumes 16-bit prefixes, so ``>> 8`` keeps the top byte.
    return ek_prefix >> 8


def _rsa2048_modulus_msbs(cpp):
    """Return the set of most significant bytes of all RSA-2048 moduli of ``cpp``.

    Rows missing ``p`` or ``n`` are ignored; an empty set is returned when
    there is no RSA-2048 data.
    """
    result = cpp.results.get(cat.RSA_2048)
    if result is None or result.data is None:
        return set()
    moduli = (int(x, 16) for x in result.data.dropna(subset=["p", "n"])["n"])
    return {n >> (n.bit_length() - 8) for n in moduli}


def get_improbable_ek(tpms, pcr_metadata=None):
    """Collect EK most-significant bytes that look implausible per firmware.

    An EK MSB is flagged for a firmware name when:
      1. the MSB of its first EK prefix never occurs among the MSBs of its
         RSA-2048 moduli (skipped when there are no moduli to compare to), or
      2. the firmware name matches a pattern in ``roca_blacklist``.
    In addition, every EK prefix listed in ``pcr_metadata`` under a
    blacklisted firmware name is flagged.

    Args:
        tpms: iterable of cryptoprops objects.
        pcr_metadata: mapping firmware name -> {"EK": [hex prefix, ...], ...};
            defaults to the notebook-level ``tpm_pcr_metadata``.

    Returns:
        dict mapping firmware name -> list of flagged EK MSBs (a value can
        repeat when it is flagged by more than one rule).
    """
    if pcr_metadata is None:
        pcr_metadata = tpm_pcr_metadata

    def is_blacklisted(name):
        return any(re.match(pattern, name) for pattern in roca_blacklist)

    improbable_ek = {}

    # Rule 1: EK MSB absent from the RSA-2048 modulus MSBs.
    for cpp in tpms:
        ek_msb = _first_ek_msb(cpp)
        n_msbs = _rsa2048_modulus_msbs(cpp)
        # Without any reference moduli nothing can be judged improbable.
        if ek_msb is not None and n_msbs and ek_msb not in n_msbs:
            improbable_ek.setdefault(firmware_name(cpp), []).append(ek_msb)

    # Rule 2: blacklisted firmware (flagged once, however many patterns match).
    for cpp in tpms:
        fwn = firmware_name(cpp)
        if not is_blacklisted(fwn):
            continue
        ek_msb = _first_ek_msb(cpp)
        if ek_msb is not None:
            improbable_ek.setdefault(fwn, []).append(ek_msb)

    # EK prefixes from TPM PCR measurements of blacklisted firmware.
    for fwn, value in pcr_metadata.items():
        if is_blacklisted(fwn):
            improbable_ek.setdefault(fwn, []).extend(int(x, 16) >> 8 for x in value['EK'])

    return improbable_ek

# Firmware name -> EK most-significant bytes flagged by get_improbable_ek.
improbable_ek = get_improbable_ek(tpms=tpms)

In [None]:
import matplotlib.pyplot as plt
from algtestprocess.modules.visualization.heatmap import Heatmap
import matplotlib.lines as mlines
import pandas as pd
import matplotlib.gridspec as gridspec
import re
import numpy as np
import seaborn as sns
import json
from copy import copy
import matplotlib.gridspec as gridspec

# Global figure style: black axes, ticks and labels; serif text rendered by
# matplotlib itself (no external LaTeX). "Computer Modern Serif" must be
# installed, otherwise matplotlib warns and falls back to another serif font.
plt.rcParams.update(
    {
        "text.usetex": False,
        "font.family": "serif",
        "font.serif": ["Computer Modern Serif"],
        "axes.edgecolor": "black",
        "axes.labelcolor": "black",
        "xtick.color": "black",
        "ytick.color": "black",
    }
)

def firmwarelist2id(firmware_versions):
    """Build a compact label for a list of dotted firmware versions.

    Versions are sorted numerically and grouped by major version. A major
    with a single version contributes that full version; otherwise it
    contributes ``"<major>.<minor>.X"`` or ``"<major>.<first>-<last>.X"``
    (first/last minor of the sorted group). Parts are joined with spaces,
    and a newline is inserted before a part once the label's first line is
    longer than 12 characters.
    """
    def numeric_key(version):
        return [int(part) for part in version.split('.')]

    # Group same-major versions, majors in ascending numeric order.
    by_major = {}
    for version in sorted(firmware_versions, key=numeric_key):
        by_major.setdefault(version.split('.')[0], []).append(version)

    label = ""
    for major, members in by_major.items():
        # NOTE(review): only the FIRST line is measured, so once it exceeds 12
        # characters every later major starts a new line; confirm that [0]
        # rather than [-1] is intended.
        if len(label.split('\n')[0]) > 12:
            label += "\n"
        if label:
            label += " "

        if len(members) == 1:
            label += members[0]
            continue

        first_minor = members[0].split('.')[1]
        last_minor = members[-1].split('.')[1]
        minor_span = first_minor if first_minor == last_minor else f"{first_minor}-{last_minor}"
        label += f"{major}.{minor_span}.X"

    # STM specific edit: move this range to the end with one extra leading space.
    stm_range = '1.258-769.X'
    if stm_range in label:
        label = label.replace(stm_range, '') + ' ' + stm_range

    return label

def _scatter_key_msbs(ax, key_msbs, y, size, base_color, n_msb, force_flag=False):
    """Draw the key MSBs ``key_msbs`` as circles on row ``y`` of ``ax``.

    Each distinct value is drawn once per occurrence at the same spot, so
    repeated values show up as a more opaque circle. A value is drawn in red
    (``#d81b60``) when ``force_flag`` is set, when it lies outside
    ``[min(n_msb), max(n_msb)]``, or when more than 1000 moduli were sampled
    and the value never occurs among them; otherwise ``base_color`` is used.
    """
    n_set = set(n_msb)
    n_lo, n_hi = (min(n_msb), max(n_msb)) if n_msb else (None, None)
    well_sampled = len(n_msb) > 1000

    values, counts = np.unique(key_msbs, return_counts=True)
    for value, count in zip(values, counts):
        out_of_range = n_lo is not None and (value < n_lo or value > n_hi)
        unseen = well_sampled and int(value) not in n_set
        if force_flag or out_of_range or unseen:
            color, alpha = '#d81b60', 0.65
        else:
            color, alpha = base_color, 0.40
        ax.scatter([value] * count, [y] * count, s=size, marker='o', c=color, alpha=alpha)


def bihist(rsa_df, tpm_name, ek=None, rsk=None, fig=None, ek_hist=False, improbable_eks=None):
    """Draw a modulus-MSB histogram with EK/SRK MSB markers underneath.

    The top axis is a density histogram of the most significant byte of every
    modulus ``n`` in ``rsa_df`` over 128..255. The bottom axis shares its x
    axis and shows one circle per EK MSB (row "EK") and per SRK MSB (row
    "SRK"), drawn in red when implausible given the moduli. All EKs of a
    firmware matching ``roca_blacklist`` are drawn in red.

    Args:
        rsa_df: DataFrame with a hex-string column ``"n"`` (missing values
            must already be dropped).
        tpm_name: label under the plot; also matched against ``roca_blacklist``.
        ek: EK MSBs (ints); ``None`` is treated as no EKs.
        rsk: SRK MSBs (ints); ``None`` is treated as no SRKs.
        fig: figure or subfigure to draw into; a new figure when ``None``.
        ek_hist: additionally overlay a histogram of ``ek`` on a twin axis.
        improbable_eks: optional mapping tpm name -> list of EK MSBs to
            overdraw in dark red on the EK row.
    """
    ek = [] if ek is None else ek
    rsk = [] if rsk is None else rsk
    if fig is None:
        fig = plt.figure()

    gs = gridspec.GridSpec(2, 1, height_ratios=(0.75, 0.25), wspace=0, hspace=0)

    # Top panel: modulus MSB histogram.
    top_hist_ax = fig.add_subplot(gs[0])
    top_hist_ax.set_xlim(128, 255)
    top_hist_ax.set_xticks([128, 255])
    top_hist_ax.locator_params(axis='y', nbins=3)
    top_hist_ax.tick_params(labelbottom=False)
    top_hist_ax.tick_params(axis='y', which='major', labelsize=7)
    top_hist_ax.tick_params(axis='y', which='minor', labelsize=7)
    top_hist_ax.set_facecolor('#ffffff')

    # Bottom panel: EK / SRK rows sharing the histogram's x axis.
    circles_ax = fig.add_subplot(gs[1])
    circles_ax.set_yticks([])
    circles_ax.sharex(top_hist_ax)
    circles_ax.set_ylim(0, 1)
    circles_ax.set_xlabel(tpm_name, fontsize=9, color="black", fontfamily="serif")
    circles_ax.set_facecolor('#ffffff')
    circles_ax.tick_params(top=False, bottom=False, left=False, right=False,
                           labelleft=True, labelbottom=True)

    n_msb = [n >> (n.bit_length() - 8) for n in (int(x, 16) for x in rsa_df["n"])]

    bins = list(range(128, 257, 1))
    top_hist_ax.hist(n_msb, bins=bins, histtype="stepfilled", color="#000000", ec="none", density=True, alpha=0.6)

    if ek_hist:
        top_ek_hist_ax = top_hist_ax.twinx()
        sns.histplot(ek, alpha=0.2, bins=127 // 10, ax=top_ek_hist_ax, color="#1e88e5")
        top_ek_hist_ax.axes.get_yaxis().set_visible(False)

    circles_ax.set_yticks([0.25, 0.75], ["EK", "SRK"], fontsize=7)

    blacklisted = any(re.match(pattern, tpm_name) for pattern in roca_blacklist)
    _scatter_key_msbs(circles_ax, ek, 0.25, 75, '#1e88e5', n_msb, force_flag=blacklisted)

    # Overdraw EKs flagged elsewhere (e.g. by get_improbable_ek) in dark red.
    if improbable_eks and (flagged := improbable_eks.get(tpm_name)) is not None:
        circles_ax.scatter(flagged, [0.25] * len(flagged), s=75, marker='o', c='#BA0700', alpha=0.5)

    _scatter_key_msbs(circles_ax, rsk, 0.75, 100, '#04ab8f', n_msb)

    title = f"RSA Keys: {len(n_msb)}"
    if len(ek) > 0:
        title += f", EKs: {len(ek)}"
    if len(rsk) > 0:
        title += f", SRKs: {len(rsk)}"
    top_hist_ax.set_title(title, fontsize=7)

def cpps2fig_bihist(cpps, fig, tpm_pcr_metadata=None, skip_without_keys=False, ek_hist=False, alg=None):
    """Draw a ``bihist`` panel for one (possibly aggregated) cryptoprops object.

    Args:
        cpps: cryptoprops object, or ``None``. For an aggregated object
            ``firmware_version`` is a list of versions.
        fig: figure or subfigure to draw into.
        tpm_pcr_metadata: optional mapping "<manufacturer> <firmware>" ->
            {"EK": [hex prefix, ...], "RSK": [hex prefix, ...]} from TPM PCR
            measurements.
        skip_without_keys: return ``False`` without drawing when neither EKs
            nor SRKs are known.
        ek_hist: forwarded to ``bihist``.
        alg: result category whose RSA dataframe to plot; ``None`` merges the
            RSA-1024/2048/3072 results.

    Returns:
        ``None`` when the required results are missing; ``False`` when there
        is nothing to draw (no complete RSA key, or no keys while
        ``skip_without_keys``); ``True`` once the panel has been drawn.
    """
    if cpps is None:
        return None

    # A figure needs the RSA dataframe.
    if alg is None:
        # No specific algorithm: merge every available RSA result.
        frames = []
        for rsa_alg in (cat.RSA_1024, cat.RSA_2048, cat.RSA_3072):
            rsa_result = cpps.results.get(rsa_alg)
            if rsa_result is not None and rsa_result.data is not None:
                frames.append(rsa_result.data)
        if not frames:
            return None
        rsa_df = pd.concat(frames)
    else:
        rsa_result = cpps.results.get(alg)
        if rsa_result is None or rsa_result.data is None:
            return None
        rsa_df = rsa_result.data

    rsa_df = rsa_df.dropna(subset=["p", "n"])
    if len(rsa_df.index) < 1:
        return False

    aggregated = isinstance(cpps.firmware_version, list)

    tpm_name = f"{cpps.manufacturer}"
    if isinstance(cpps.firmware_version, str):
        tpm_name += f" {cpps.firmware_version}"
    elif aggregated:
        tpm_name += f" {firmwarelist2id(cpps.firmware_version)}"

    # PCR metadata entries to use: each distinct firmware of an aggregated
    # result exactly once, otherwise the TPM name itself. Consulting an entry
    # twice would double-count its SRKs.
    if aggregated:
        metadata_names = [f"{cpps.manufacturer} {fw}" for fw in dict.fromkeys(cpps.firmware_version)]
    else:
        metadata_names = [tpm_name]
    pcr_entries = []
    for name in metadata_names:
        if tpm_pcr_metadata and (entry := tpm_pcr_metadata.get(name)) is not None:
            pcr_entries.append(entry)

    # EK prefixes collected by tpm2-algtest: one per distinct
    # (prefix, suffix) pair, so a prefix may appear more than once.
    ek_prefixes = []
    if (ek_rsa_result := cpps.results.get(cat.EK_RSA)) is not None:
        if (ek_rsa_result_data := ek_rsa_result.data) is not None:
            ek_prefixes = [prefix for prefix, _ in set(ek_rsa_result_data)]

    # Add the PCR-measurement EK prefixes that are not already present.
    for entry in pcr_entries:
        for prefix_string in entry['EK']:
            if (prefix := int(prefix_string, 16)) not in ek_prefixes:
                ek_prefixes.append(prefix)

    # Keep the most significant byte of each prefix (drop the low 8 bits).
    ek = [prefix >> 8 for prefix in ek_prefixes]
    rsk = [int(x, 16) >> 8 for entry in pcr_entries for x in entry['RSK']]

    if skip_without_keys and not ek and not rsk:
        return False

    bihist(rsa_df, tpm_name, ek=ek, rsk=rsk, fig=fig, ek_hist=ek_hist)
    return True

def cpps2fig_heatmap(cpps, fig, tpm_pcr_metadata=None, skip_without_keys=False, ek_hist=False, alg=None, min_keys=50):
    """Draw a p/q/n most-significant-byte ``Heatmap`` for one cryptoprops object.

    ``tpm_pcr_metadata``, ``skip_without_keys`` and ``ek_hist`` are unused
    here. They exist so that create_multiplot can call this function and
    cpps2fig_bihist with the same keyword arguments.

    Args:
        cpps: cryptoprops object, or ``None``.
        fig: figure or subfigure to draw into.
        alg: result category whose RSA dataframe to plot; ``None`` merges the
            RSA-1024 and RSA-2048 results.
        min_keys: minimum number of rows with both ``p`` and ``n`` needed to
            draw anything (default 50).

    Returns:
        ``None`` when results are missing or there are fewer than ``min_keys``
        complete keys; otherwise the value returned by ``Heatmap.build()``.
    """
    if cpps is None:
        return None

    result = None
    if alg is not None:
        result = cpps.results.get(alg)
        if result is None or result.data is None:
            return None

    tpm_name = f"{cpps.manufacturer}"
    if isinstance(cpps.firmware_version, list):
        tpm_name += f" {firmwarelist2id(cpps.firmware_version)}"
    else:
        tpm_name += f" {cpps.firmware_version}"

    if alg is None:
        # Merge the available RSA results.
        # NOTE(review): unlike cpps2fig_bihist, RSA-3072 is not merged here;
        # confirm this is intended.
        frames = []
        for rsa_alg in (cat.RSA_1024, cat.RSA_2048):
            rsa_result = cpps.results.get(rsa_alg)
            if rsa_result is not None and rsa_result.data is not None:
                frames.append(rsa_result.data)
        if not frames:
            return None
        df = pd.concat(frames)
    else:
        df = result.data

    if len(df.dropna(subset=["p", "n"])) < min_keys:
        return None

    return Heatmap(df, tpm_name, pqnf=compute_pqn_bytes, fig=fig, legend=False, ticks=False,
                   parts=['title', 'heatmap', 'text'], part_height_ratios=[1, 0.1],
                   text_font_size=9, title_font_size=5, label_values=False).build()

def create_multiplot(tpms, nrows, ncols, figsize=(8.3, 11.7), figtype='heatmap', subfig_count=None, skip_without_keys=False, ek_hist=None, alg=None, tpm_pcr_metadata=None):
    """Lay out one panel per cryptoprops object on an ``nrows`` x ``ncols`` grid.

    Panels are filled row by row. An entry whose renderer draws nothing
    (returns ``None`` or ``False``) is reported and does not consume a grid
    slot; the next entry is placed in the same slot.

    Args:
        tpms: iterable of cryptoprops objects (``None`` entries, e.g. empty
            groups from aggregate_by_implementation, are reported and skipped).
        nrows, ncols: grid shape.
        figsize: figure size in inches.
        figtype: ``'heatmap'`` uses cpps2fig_heatmap, anything else uses
            cpps2fig_bihist.
        subfig_count: maximum number of panels; defaults to, and is capped
            at, ``nrows * ncols``.
        skip_without_keys: forwarded to the renderer.
        ek_hist: optional list of per-entry ``ek_hist`` flags, indexed like
            ``tpms``.
        alg: forwarded to the renderer.
        tpm_pcr_metadata: forwarded to the renderer.

    Returns:
        The created matplotlib figure.
    """
    # Constrained layout keeps the panels' axes from overlapping.
    fig = plt.figure(layout='constrained', figsize=figsize, dpi=800)
    grid = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)

    capacity = nrows * ncols
    max_panels = capacity if subfig_count is None else min(subfig_count, capacity)
    render = cpps2fig_heatmap if figtype == 'heatmap' else cpps2fig_bihist

    placed = 0
    for i, cpps in enumerate(tpms):
        if placed >= max_panels:
            break

        row, col = divmod(placed, ncols)
        subfig = fig.add_subfigure(grid[row, col])
        subfig.set_facecolor("none")

        hist = ek_hist[i] if ek_hist and i < len(ek_hist) else None
        outcome = render(cpps, subfig, alg=alg, tpm_pcr_metadata=tpm_pcr_metadata,
                         skip_without_keys=skip_without_keys, ek_hist=hist)
        if outcome is None or outcome is False:
            label = getattr(cpps, "device_name", f"entry {i}")
            print(f"Plot for {label} failed")
            continue
        placed += 1

    return fig

In [None]:
# INTC firmware-name patterns (re.match), one inner list per implementation
# group, i.e. per panel of the INTC figures below.
groups = [
    # group 0
    [r'INTC 2\.0\..*', r'INTC 10\.0\..*', r'INTC 11\.[0-8]\..*', r'INTC 30[2-3]\.12\.0\.0'],
    # group 1
    [r'INTC 40[1-2]\.1\..*'],
    # group 2
    [r'INTC 403\.1\.0\.0'],
    # group 3
    [r'INTC 500\..*', r'INTC 600\..*'],
]
intc_aggre = aggregate_by_implementation(tpms_aggregated, groups)

In [None]:
# IFX firmware-name patterns (re.match), one inner list per implementation
# group. Group 0 covers the same firmware families as roca_blacklist.
groups = [
    # group 0
    [r'IFX 5\.61\..*', r'IFX 7\.40\..*', r'IFX 7\.61\..*'],
    # group 1
    [r'IFX 7\.6[2-3]\..*', r'IFX 5\.63\.13\.6400'],
    # group 2
    [r'IFX 7\.8[3-5]\..*'],
]
ifx_aggre = aggregate_by_implementation(tpms_aggregated, groups)

In [None]:
# One group (and panel) per remaining manufacturer: AMD, NTC, STM, MSFT.
groups = [[r'AMD .*'], [r'NTC .*'], [r'STM .*'], [r'MSFT .*']]
ansm_aggre = aggregate_by_implementation(tpms_aggregated, groups)

In [None]:
# Heatmaps (default figtype) of the most significant bytes of p, q and n:
# one panel per INTC group, laid out in a single row of 4.
f = create_multiplot(intc_aggre, 1, 4, figsize=(8.3, 2))

In [None]:
# Export the figure created in the previous cell.
f.savefig("intc.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Heatmaps, one panel per IFX group.
f=create_multiplot(ifx_aggre, 1, 4, figsize=(8.3, 2))

In [None]:
# Export the figure created in the previous cell.
f.savefig("ifx.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Heatmaps, one panel each for AMD, NTC, STM and MSFT.
f=create_multiplot(ansm_aggre, 1, 4, figsize=(8.3, 2))

In [None]:
# Export the figure created in the previous cell.
f.savefig("amd-ntc-stm-msft.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Modulus-MSB histograms with EK/SRK markers (EKs/SRKs taken from the
# algtest results and tpm_pcr_metadata), one panel per INTC group.
f=create_multiplot(intc_aggre, 1, 4, figsize=(8.3, 2), figtype='bihist', tpm_pcr_metadata=tpm_pcr_metadata)

In [None]:
# Export the figure created in the previous cell.
f.savefig("intc-ek-rsk.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Modulus-MSB histograms with EK/SRK markers, one panel per IFX group.
f=create_multiplot(ifx_aggre, 1, 4, figsize=(8.3, 2), figtype='bihist', tpm_pcr_metadata=tpm_pcr_metadata)

In [None]:
# Export the figure created in the previous cell.
f.savefig("ifx-ek-rsk.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Modulus-MSB histograms with EK/SRK markers for AMD, NTC, STM and MSFT.
f=create_multiplot(ansm_aggre, 1, 4, figsize=(8.3, 2), figtype='bihist', tpm_pcr_metadata=tpm_pcr_metadata)

In [None]:
# Export the figure created in the previous cell.
f.savefig("amd-ntc-stm-msft-ek-rsk.pdf", dpi=300, bbox_inches="tight")