# EK and RSK visualizations

## Utility code

### Disable logging

In [None]:
import logging, sys
# Silence all logging output for the rest of the notebook: disable()
# suppresses every record at or below the given level, and sys.maxsize is
# used as a level above anything a logger would emit.
logging.disable(sys.maxsize)

### Loading the metadata.json

In [None]:
import json

def load_metadata(metadata_path):
    """
    Load the measurement metadata file and group its entries by vendor.

    :param metadata_path: path to a JSON file holding an object with an
        "entries" mapping
    :return: dict mapping each vendor to the list of its entries (entries
        without a "vendor" key are reported and skipped); an empty dict when
        the file cannot be read or parsed, or when it holds no entries
    """
    failure_message = "report_create: retrieving metadata was unsuccessful"
    try:
        with open(metadata_path, "r") as f:
            metadata = json.load(f)
        entries = list(metadata["entries"].values())
    except (OSError, ValueError, KeyError, TypeError, AttributeError):
        # OSError: file missing/unreadable. ValueError: invalid JSON or
        # undecodable bytes. The remaining ones: the document does not have
        # the expected {"entries": {...}} shape.
        print(failure_message)
        return {}

    if not entries:
        print(failure_message)
        return {}

    # We now group entries by vendor
    grouped = {}
    for entry in entries:
        vendor = entry.get("vendor")
        if vendor is None:
            print(f"report_create: entry {entry} does not contain vendor")
            continue
        grouped.setdefault(vendor, []).append(entry)
    return grouped

# Vendor -> list of entries; an empty dict when ./metadata.json could not
# be loaded (load_metadata prints a message in that case).
metadata = load_metadata('./metadata.json')

### Loading the EKs and RSKs from TPM_PCR measurements

In [None]:
# Load TPM PCR metadata
# Later cells look this mapping up by firmware name ("<manufacturer>
# <firmware version>") and read hex strings from each value's 'EK' and
# 'RSK' lists.
with open('tpm-pcr-metadata.json') as f:
    tpm_pcr_metadata = json.load(f)

### Loading the TPM cryptographic properties profiles

In [None]:
from algtestprocess.modules.data.tpm.manager import TPMProfileManager
from algtestprocess.modules.data.tpm.enums import CryptoPropResultCategory as cat

# Collect one cryptographic-properties profile per measurement path listed
# in the metadata, across all vendors.
cryptoprops = []
for tpms in metadata.values():
    for entry in tpms:
        for measurement_path in entry['measurement paths']:
            man = TPMProfileManager(measurement_path)
            cpps = man.cryptoprops
            # Measurements for which the manager yields no profile are skipped.
            if cpps is None:
                continue
            cryptoprops.append(cpps)

### Merging profiles that share the same firmware

In [None]:
def firmware_name(profile):
    """Return "<manufacturer> <firmware_version>" for the given profile."""
    return "{} {}".format(profile.manufacturer, profile.firmware_version)

def group_by_vendor(cpps):
    """
    Group profiles by their ``manufacturer`` attribute.

    Returns a dict mapping each manufacturer to the list of its profiles;
    both the keys and the lists keep the order of first appearance.
    """
    by_manufacturer = {}
    for profile in cpps:
        by_manufacturer.setdefault(profile.manufacturer, []).append(profile)
    return by_manufacturer

def aggregate_by_firmware(cpps):
    """
    Merge profiles that share a firmware name into a single profile each.

    Profiles with the same ``firmware_name`` are combined with ``+`` in the
    order they are encountered.

    :param cpps: iterable of profiles; ``None`` items are reported and skipped
    :return: list of merged profiles, one per distinct firmware name, in
        order of first appearance
    """
    aggregated = {}
    # Iterate the argument, not the notebook-level `cryptoprops` list.
    for cpp in cpps:
        if cpp is None:
            # A missing profile has no attributes to report, so the message
            # cannot name it.
            print("Cannot merge a profile, as it is None")
            continue

        fwn = firmware_name(cpp)
        if fwn in aggregated:
            aggregated[fwn] = aggregated[fwn] + cpp
        else:
            aggregated[fwn] = cpp
    return list(aggregated.values())
    
# All tpms in one list
tpms = cryptoprops
# All tpms grouped by vendor
tpms_grouped_by_vendor = group_by_vendor(tpms)
# All tpms aggregated by the tpm name
tpms_aggregated = aggregate_by_firmware(cryptoprops)
# All tpms aggregated by the tpm name and grouped by vendor
tpms_aggregated_grouped_by_vendor = group_by_vendor(tpms_aggregated)

### Profile sorting utilities

In [None]:
import re
from math import inf

def tpm_sorted(profiles, device_name):
    """
    Sort profiles by manufacturer alphabetically, then by firmware version
    numerically, then by the optional trailing ``[index]``.

    Device names are expected to look like ``<manufacturer> <a.b.c>`` with
    an optional `` [<n>]`` suffix. Names that do not look like that are
    printed as a warning before sorting; sorting is still attempted.

    :param profiles: iterable of profiles to sort
    :param device_name: callable returning the device name of a profile
    :return: new list with the profiles in sorted order
    :raises ValueError: if a device name has no whitespace to split on or
        its version contains a non-numeric component
    """
    # Materialise once: the profiles are walked for the format check and
    # again for sorting.
    profiles = list(profiles)

    # The format check used to be re.match with
    #   (\s*.+)+\s\s*\d+(\.\d+)*(\s[\[]\d+[\]])?
    # whose nested quantifiers backtrack exponentially on names that do not
    # match. Every part after the first digit is optional, so that match
    # succeeds exactly when a non-newline character is followed by
    # whitespace and then a digit; this search accepts the same names
    # without the blow-up.
    malformed = [
        name
        for name in map(device_name, profiles)
        if re.search(r"[^\n]\s+\d", name) is None
    ]
    if malformed:
        print("These device names do not match format")
        print(malformed)

    def key_f(profile):
        head, tail = device_name(profile).rsplit(maxsplit=1)

        # Names without an "[n]" suffix sort after indexed ones that share
        # the same manufacturer and version.
        idx = inf
        if re.match(r"\[\d+\]", tail):
            idx = int(tail.replace("[", "").replace("]", ""))
            manufacturer, firmware = head.rsplit(maxsplit=1)
        else:
            manufacturer, firmware = head, tail

        # Empty components (e.g. from "1..2") are ignored.
        numbers = [int(part) for part in firmware.split(".") if part]

        return [manufacturer] + numbers + [idx]

    return sorted(profiles, key=key_f)

### Firmware name list to identifier

In [None]:
def firmwarelist2id(firmware_versions):
    """
    Condense a list of dotted firmware versions into a short label.

    Versions are ordered numerically and bucketed by major number. A bucket
    with one version is shown verbatim; a bucket with several is shown as
    "<major>.<minor>.X", or "<major>.<first minor>-<last minor>.X" when the
    first and last minors differ. Buckets are separated by a space, and a
    newline is put in front of a bucket once the first line of the label is
    longer than 12 characters.
    """
    def numeric_key(version):
        return [int(part) for part in version.split('.')]

    # Bucket the numerically ordered versions by their major number
    by_major = {}
    for version in sorted(firmware_versions, key=numeric_key):
        by_major.setdefault(version.split('.')[0], []).append(version)

    # Assemble the label bucket by bucket
    label = ""
    for major, bucket in by_major.items():
        first_line = label.split('\n')[0]
        if len(first_line) > 12:
            label += "\n"
        if label != "":
            label += " "

        if len(bucket) == 1:
            label += bucket[0]
            continue

        low = bucket[0].split('.')[1]
        high = bucket[-1].split('.')[1]
        label += f"{major}.{low}.X" if low == high else f"{major}.{low}-{high}.X"

    # STM specific edit which should not be here: this range goes last
    stm_range = '1.258-769.X'
    if stm_range in label:
        label = label.replace(stm_range, '') + ' ' + stm_range

    return label

### Finding improbable EKs

For some machines we have a sample of 1000 or more keys generated on-chip together with the machine's EK/RSK. We flag an EK as improbable when its most significant byte does not appear among the most significant bytes of the keys generated on-chip.

In [None]:
# firmware name -> list of EK most significant bytes considered improbable
improbable_ek = {}
for cpp in tpms:
    # MSB of the first EK stored in the profile; only ek[0] is inspected.
    # NOTE(review): the `>> 8` assumes a 2-byte prefix - confirm.
    ek_msb = None
    if (result := cpp.results.get(cat.EK_RSA)) is not None:
        if (ek := result.data) is not None:
            ek_prefix, _ = ek[0]
            ek_msb = ek_prefix >> 8

    # Top 8 bits of every RSA-2048 modulus generated on-chip (rows lacking
    # p, q or n are dropped first).
    n_msb = None
    if (result := cpp.results.get(cat.RSA_2048)) is not None:
        if (rsa_df := result.data) is not None:
            rsa_df = rsa_df.dropna(subset=["p", "q", "n"])
            n_msb = [x >> x.bit_length() - 8 for x in list(map(lambda x: int(x, 16), rsa_df["n"]))]

    fwn = firmware_name(cpp)
    
    # Flag the EK when its MSB was never seen among the generated keys.
    # NOTE(review): no minimum sample size is enforced here, although the
    # accompanying text talks about samples of 1000 keys - confirm intent.
    if ek_msb is not None and n_msb is not None:
        if ek_msb not in n_msb:
            improbable_ek.setdefault(fwn, [])
            improbable_ek[fwn].append(ek_msb)      

### Add ROCA vulnerable TPMs as improbable EKs


In [None]:
# Regular expressions (applied with re.match) for the firmware names that
# are treated as ROCA vulnerable: IFX 7.40.*, 7.61.* and 5.61.*.
roca_blacklist = [
    r'IFX.*7\.40\..*',
    r'IFX.*7\.61\..*',
    r'IFX.*5\.61\..*',
]

In [None]:

# Mark the first EK of every profile whose firmware name matches the ROCA
# blacklist as improbable.
for cpp in tpms:
    for entry in roca_blacklist:
        fwn = firmware_name(cpp)
        
        if re.match(entry, fwn):
            #print(fwn)
                
            ek_msb = None
            if (result := cpp.results.get(cat.EK_RSA)) is not None:
                if (ek := result.data) is not None:
                    ek_prefix, _ = ek[0]
                    ek_msb = ek_prefix >> 8

            if ek_msb is not None:
                print(f"Marking value {ek_msb} from {fwn} as improbable EK")
                improbable_ek.setdefault(fwn, [])
                improbable_ek[fwn].append(ek_msb)

# Do the same for every EK recorded in the TPM PCR metadata under a
# blacklisted firmware name.
for entry in roca_blacklist:
    for fwn, value in tpm_pcr_metadata.items():
        if re.match(entry, fwn):
            print(fwn)
            improbable_ek.setdefault(fwn, [])
            improbable_ek[fwn] += [int(x, 16) >> 8 for x in value['EK']]
# A bare expression that is not the last statement of the cell has no effect.
improbable_ek
# NOTE(review): raises KeyError when 'NTC 7.2.3.1' was not flagged above;
# the reason for excluding this firmware is not visible here - confirm.
del improbable_ek['NTC 7.2.3.1']

### Bihistogram visualization

In [None]:
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import pandas as pd
import numpy as np
import seaborn as sns
import json
from copy import copy
import re
import matplotlib.gridspec as gridspec

# Global matplotlib style for every figure below: text rendered through
# LaTeX (text.usetex), black ticks/labels/edges and a Computer Modern serif
# font.
plt.rcParams.update({
    "text.usetex": True,
    "ytick.color" : "black",
    "xtick.color" : "black",
    "axes.labelcolor" : "black",
    "axes.edgecolor" : "black",
    "font.family" : "serif",
    "font.serif" : ["Computer Modern Serif"]
})

def bihist(rsa_df, tpm_name, ek=None, rsk=None, fig=None, ek_hist=False):
    """
    Draw one TPM's "bihistogram" into ``fig``.

    Top axes: density histogram of the most significant byte (MSB) of every
    modulus in ``rsa_df["n"]``. Bottom axes: one circle per key MSB, EKs on
    the "EK" row (y=0.25) and RSKs on the "SRK" row (y=0.75). A circle is
    drawn in red when its value lies outside the range of the observed
    modulus MSBs, or when more than 1000 moduli were observed and none of
    them has that MSB; EK circles are also red when ``tpm_name`` matches an
    entry of the module-level ``roca_blacklist``. Values stored in the
    module-level ``improbable_ek[tpm_name]`` are overdrawn on the EK row.

    :param rsa_df: table whose "n" column holds the moduli as hex strings
    :param tpm_name: label shown under the plot; also the key used for the
        ``improbable_ek`` lookup and the ``roca_blacklist`` match
    :param ek: EK MSB values; ``None`` is treated as "no EKs"
    :param rsk: RSK MSB values; ``None`` is treated as "no RSKs"
    :param fig: figure or subfigure to draw into (required)
    :param ek_hist: when true and ``ek`` was passed, overlay a histogram of
        the EK values on the top axes
    :raises ValueError: if ``fig`` is None
    """
    if fig is None:
        raise ValueError("bihist: a figure or subfigure to draw into is required")

    # Remember whether EKs were passed at all before normalising: the EK
    # histogram overlay is only drawn for an explicitly given sequence.
    ek_given = ek is not None
    ek = [] if ek is None else ek
    rsk = [] if rsk is None else rsk

    gs = gridspec.GridSpec(2, 1, height_ratios=(0.75, 0.25), wspace=0, hspace=0)

    top_hist_ax = fig.add_subplot(gs[0])
    top_hist_ax.set_xlim(128, 255)
    top_hist_ax.set_xticks([128, 255])
    top_hist_ax.locator_params(axis='y', nbins=3)
    top_hist_ax.tick_params(labelbottom=False)
    top_hist_ax.tick_params(axis='y', which='major', labelsize=7)
    top_hist_ax.tick_params(axis='y', which='minor', labelsize=7)
    top_hist_ax.set_facecolor('#ffffff')

    circles_ax = fig.add_subplot(gs[1])
    circles_ax.set_yticks([])
    circles_ax.sharex(top_hist_ax)
    circles_ax.set_ylim(0, 1)
    circles_ax.set_xlabel(
        tpm_name,
        fontsize=9,
        color="black",
        fontfamily="serif")
    circles_ax.set_facecolor('#ffffff')
    # Hide all tick marks but keep the labels. (An earlier
    # tick_params(left='off', right='off') call was dropped: it passed
    # truthy strings and was overridden by this call anyway.)
    circles_ax.tick_params(top=False,
                           bottom=False,
                           left=False,
                           right=False,
                           labelleft=True,
                           labelbottom=True)

    # Top 8 bits of every modulus
    n_msb = [n >> n.bit_length() - 8 for n in (int(x, 16) for x in rsa_df["n"])]

    bins = list(range(128, 257, 1))
    top_hist_ax.hist(n_msb, bins=bins, histtype="stepfilled", color="#000000", ec="none", density=True, alpha=0.6)

    if ek_hist and ek_given:
        top_ek_hist_ax = top_hist_ax.twinx()
        sns.histplot(
            ek,
            alpha=0.2,
            bins=127//10,
            ax=top_ek_hist_ax,
            color="#1e88e5"
        )
        top_ek_hist_ax.axes.get_yaxis().set_visible(False)

    # Everything the per-circle checks need is computed once, instead of
    # once per circle.
    n_msb_seen = set(n_msb)
    n_msb_lo = min(n_msb) if n_msb else None
    n_msb_hi = max(n_msb) if n_msb else None
    large_sample = len(n_msb) > 1000
    on_roca_blacklist = any(re.match(pattern, tpm_name) for pattern in roca_blacklist)

    def is_improbable(value):
        # Without reference moduli nothing can be judged.
        if not n_msb:
            return False
        if value < n_msb_lo or value > n_msb_hi:
            return True
        return large_sample and value not in n_msb_seen

    circles_ax.set_yticks([0.25, 0.75], ["EK", "SRK"], fontsize=7)

    # EK row: one stack of circles per distinct MSB value
    for value, count in zip(*np.unique(ek, return_counts=True)):
        count = int(count)

        # Debug output for the 5.61 firmware line, kept as it was.
        if '5.61' in tpm_name:
            print(tpm_name, ek)

        color = '#1e88e5'
        alpha = 0.40
        if is_improbable(value) or on_roca_blacklist:
            print(f"Marking ek red for {tpm_name}.")
            color = '#d81b60'
            alpha = 0.65

        circles_ax.scatter([value] * count, [0.25] * count, s=75, marker='o', c=color, alpha=alpha)

    # Lastly print the improbable EKs as red
    improbable_eks = improbable_ek.get(tpm_name)
    if improbable_eks is not None:
        circles_ax.scatter(improbable_eks, len(improbable_eks) * [0.25], s=75, marker='o', c='#BA0700', alpha=0.5)

    # SRK row: same layout for the RSKs (no blacklist check here)
    for value, count in zip(*np.unique(rsk, return_counts=True)):
        count = int(count)

        color = '#04ab8f'
        alpha = 0.40
        if is_improbable(value):
            color = '#d81b60'
            alpha = 0.65

        circles_ax.scatter([value] * count, [0.75] * count, s=100, marker='o', c=color, alpha=alpha)

    title = f"RSA Keys: {len(n_msb)}"
    if len(ek) > 0:
        title += f", EKs: {len(ek)}"
    if len(rsk) > 0:
        title += f", SRKs: {len(rsk)}"

    top_hist_ax.set_title(title, fontsize=7)
    
    
# Running total of EKs that went into the figures; cpps2fig adds to it.
EK_COUNT = 0

# Label overrides for Pluton TPMs. Each tuple is
# (manufacturer, version, first, last, add_man): when both strings occur in
# a TPM label, the version is rewritten with " (Pluton)" appended, a newline
# before it if `first`, a newline after it if `last`, and the manufacturer
# repeated in front of it if `add_man`.
plutons = [
    ('AMD', '6.24.0.6', True, False, True),
    ('MSFT', '6.3.1.603', False, True, False)
]

def cpps2fig(cpps, fig, tpm_pcr_metadata, skip_without_keys=False, ek_hist=False, alg=None):
    """
    Plot one profile into ``fig`` using ``bihist``.

    The RSA key table comes from ``cpps.results[alg]``, or from the RSA
    1024/2048/3072 results concatenated when ``alg`` is None. EKs come from
    the profile's EK_RSA result plus the 'EK' entries of
    ``tpm_pcr_metadata``; RSKs come from its 'RSK' entries. The metadata is
    looked up by "<manufacturer> <firmware version>" for every firmware
    version of the profile. The number of EKs used is added to the
    module-level ``EK_COUNT``.

    :param cpps: profile to plot, or None
    :param fig: figure or subfigure handed on to ``bihist``
    :param tpm_pcr_metadata: mapping firmware name -> {'EK': [...], 'RSK': [...]}
        with hex-string values
    :param skip_without_keys: skip profiles that have neither EKs nor RSKs
    :param ek_hist: forwarded to ``bihist``
    :param alg: result category to take the RSA keys from, or None for all
    :return: True when a plot was drawn; False when there were no usable RSA
        rows or the profile was skipped for lack of keys; None when the
        profile or the requested RSA data is missing
    """
    global EK_COUNT

    if cpps is None:
        return None

    # To be able to produce a figure we need an rsa dataframe
    if alg is not None:
        result = cpps.results.get(alg)
        if result is None:
            return None
        tables = [result.data]
    else:
        # Special case: alg is None means we merge the rsa dataframes
        found = (
            cpps.results.get(rsa_alg)
            for rsa_alg in (cat.RSA_1024, cat.RSA_2048, cat.RSA_3072)
        )
        tables = [result.data for result in found if result is not None]
    tables = [table for table in tables if table is not None]
    if not tables:
        # No RSA data at all: report it like any other missing result
        # instead of failing an assertion.
        return None

    rsa_df = pd.concat(tables).dropna(subset=["p", "q", "n"])
    if len(rsa_df.index) < 1:
        return False

    # The display label, and separately the plain firmware names used as
    # keys into tpm_pcr_metadata. The label must not be used for lookups: it
    # is condensed for aggregated profiles and rewritten for Plutons below.
    manufacturer = cpps.manufacturer
    firmware = cpps.firmware_version
    if isinstance(firmware, list):
        tpm_name = f"{manufacturer} {firmwarelist2id(firmware)}"
        firmware_names = [f"{manufacturer} {fw}" for fw in firmware]
    elif isinstance(firmware, str):
        tpm_name = f"{manufacturer} {firmware}"
        firmware_names = [tpm_name]
    else:
        tpm_name = f"{manufacturer}"
        firmware_names = [tpm_name]

    # Temporary solution to label plutons
    for man, ver, first, last, add_man in plutons:
        if man in tpm_name and ver in tpm_name:
            tpm_name = tpm_name.replace(ver, ('\n' if first else '') +(f"{man} " if add_man else '')+ ver + ('\n' if last else '')+ " (Pluton)")

    pcr_entries = [
        entry
        for name in firmware_names
        if (entry := tpm_pcr_metadata.get(name)) is not None
    ]

    # EKs collected from tpm2-algtest: one prefix per unique
    # (prefix, suffix) pair ...
    ek = []
    ek_rsa_result = cpps.results.get(cat.EK_RSA)
    if ek_rsa_result is not None and ek_rsa_result.data is not None:
        ek = [prefix for prefix, _ in set(ek_rsa_result.data)]

    # ... plus the prefixes from the TPM PCR measurements not seen yet.
    seen_prefixes = set(ek)
    for entry in pcr_entries:
        for prefix_string in entry['EK']:
            prefix = int(prefix_string, 16)
            if prefix not in seen_prefixes:
                seen_prefixes.add(prefix)
                ek.append(prefix)

    # We now just convert all the EKs to MSB
    # NOTE(review): `>> 8` assumes 2-byte prefixes - confirm.
    ek = [x >> 8 for x in ek]
    EK_COUNT += len(ek)

    # Each firmware name is looked up exactly once, so RSKs are not counted
    # twice when the label happens to equal a firmware name.
    rsk = [int(x, 16) >> 8 for entry in pcr_entries for x in entry['RSK']]

    if skip_without_keys and not ek and not rsk:
        return False

    bihist(rsa_df, tpm_name, ek=ek, rsk=rsk, fig=fig, ek_hist=ek_hist)
    return True

def create_multiplot(tpms, nrows, ncols, figsize=(8.3, 11.7), subfig_count=None, skip_without_keys=False, ek_hist=[], alg=None):
    """
    Lay out one ``cpps2fig`` plot per profile on an nrows x ncols grid.

    Cells are filled row by row. A profile whose plot could not be produced
    is reported and does not consume a cell.

    :param tpms: iterable of profiles; items may be None
    :param nrows: number of grid rows
    :param ncols: number of grid columns
    :param figsize: figure size in inches
    :param subfig_count: maximum number of plots; defaults to, and is capped
        at, ``nrows * ncols``
    :param skip_without_keys: forwarded to ``cpps2fig``
    :param ek_hist: per-profile flags (indexed like ``tpms``) enabling the EK
        histogram overlay; missing positions count as False. The default
        list is only read, never modified.
    :param alg: forwarded to ``cpps2fig``
    :return: the created figure
    """
    # Creating the figure with a constrained layout to avoid axes overlapping
    fig = plt.figure(layout='constrained', figsize=figsize, dpi=800)
    grid = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig, hspace=0.1)

    # Never index past the grid, whatever subfig_count was requested.
    capacity = ncols * nrows
    subfig_count = capacity if subfig_count is None else min(subfig_count, capacity)

    drawn = 0
    for i, cpps in enumerate(tpms):
        if drawn >= subfig_count:
            break

        # The next free cell follows directly from the number of plots
        # drawn so far (row-major order).
        row, col = divmod(drawn, ncols)
        subfig = fig.add_subfigure(grid[row, col])
        subfig.set_facecolor("none")

        hist = ek_hist[i] if i < len(ek_hist) else False

        if cpps2fig(cpps, subfig, tpm_pcr_metadata, skip_without_keys=skip_without_keys, ek_hist=hist, alg=alg):
            drawn += 1
        else:
            # aggregate_by_implementation yields None for groups without a
            # match, so there may be no profile to take a name from.
            name = None if cpps is None else cpps.device_name
            print(f"Plot for {name} failed")
    return fig

## Aggregated by firmware TPM visualization 

In [None]:
# Now plot only those with an EK
EK_COUNT = 0
#create_multiplot(tpms_aggregated, 4, 8, figsize=(22.14, 8.3), skip_without_keys=True)

In [None]:
# EK_COUNT is accumulated by cpps2fig; it is 0 unless a plot ran since the reset.
print(f"Count of Endorsement keys used for visualization: {EK_COUNT}")

In [None]:
#create_multiplot(tpms_aggregated, 8, 8, figsize=(22.14, 16.6), skip_without_keys=False)

## Aggregated by vendor RSA algorithm implementation

In [None]:
def aggregate_by_implementation(cpps, groups):
    """
    Merge profiles into one profile per group of firmware-name patterns.

    :param cpps: iterable of profiles
    :param groups: list of groups, each a list of regular expressions that
        are matched (``re.match``) against a profile's firmware name
    :return: list parallel to ``groups``; element i is the ``+``-merge of
        all profiles matching any pattern of group i, or None when no
        profile matched. A profile matching several groups is merged into
        each of them.
    """
    compiled_groups = [[re.compile(pattern) for pattern in group] for group in groups]

    aggre = [None] * len(groups)
    for cpp in cpps:
        fwn = firmware_name(cpp)
        for i, group in enumerate(compiled_groups):
            if not any(pattern.match(fwn) for pattern in group):
                continue
            # Plain `+`, as in aggregate_by_firmware, rather than `+=`: the
            # first match stores the caller's own profile object, and an
            # in-place add could modify that shared object.
            aggre[i] = cpp if aggre[i] is None else aggre[i] + cpp
    return aggre

### INTC

In [None]:
# INTC firmware-name patterns; each inner list is merged into one plot.
groups = [
    [
        r'INTC 2\.0\..*',           # INTC 2.0.X
        r'INTC 10\.0\..*',          # INTC 10.0.X
        r'INTC 11\.[0-8]\..*',      # INTC 11.0-8.X
        r'INTC 30[2-3]\.12\.0\.0',  # INTC 302-303.12.0.0
    ],
    [
        r'INTC 40[1-2]\.1\..*',     # INTC 401-402.1.X
    ],
    [
        r'INTC 403\.1\.0\.0',       # INTC 403.1.0.0
    ],
    [
        r'INTC 500\..*',    # INTC 500.X
        r'INTC 600\..*'     # INTC 600.X
    ]
]

In [None]:
# One merged profile per INTC group (None where no firmware matched).
intc_aggre = aggregate_by_implementation(tpms_aggregated, groups)

In [None]:
# Single row with one plot per INTC group.
f = create_multiplot(intc_aggre, 1, 4, figsize=(8.3, 2))

In [None]:
# Write the INTC figure to a PDF in the working directory.
f.savefig("intc-ek-rsk.pdf", dpi=300, bbox_inches="tight")

### IFX

In [None]:
# IFX firmware-name patterns; each inner list is merged into one plot. The
# first group holds the same firmware lines as roca_blacklist.
groups = [
    [
        r'IFX 5\.61\..*',
        r'IFX 7\.40\..*',
        r'IFX 7\.61\..*'
    ],
    [
        r'IFX 7\.6[2-3]\..*',
        r'IFX 5\.63\.13\.6400'

    ],
    [
        r'IFX 7\.8[3-5]\..*' 
    ],
]
# Per-group switch for the EK histogram overlay; positions line up with
# `groups` (only the second group gets the overlay).
ek_hist=[
    False, True, False
]

In [None]:
# One merged profile per IFX group (None where no firmware matched).
ifx_aggre = aggregate_by_implementation(tpms_aggregated, groups)

In [None]:
# Single row of four cells; the three IFX groups fill the first three.
f = create_multiplot(ifx_aggre, 1, 4, figsize=(8.3, 2), ek_hist=ek_hist)

In [None]:
# Write the IFX figure to a PDF in the working directory.
f.savefig("ifx-ek-rsk.pdf", dpi=300, bbox_inches="tight")

### NTC

In [None]:
# NTC firmware-name pattern. The dots are escaped so that they only match
# literal dots of the version, like the patterns of the other vendors.
groups = [
    [
        r'NTC 7\.2\.3\.1'
    ]
]

In [None]:
# Single-element list: the merged NTC profile, or None when nothing matched.
ntc_aggre = aggregate_by_implementation(tpms_aggregated, groups)

In [None]:
# One plot for the NTC group.
f = create_multiplot(ntc_aggre, 1, 1, figsize=(2, 2))

### STM

In [None]:
# All STM firmwares are merged into a single plot.
groups = [
    [
        r'STM .*'
    ]
]

In [None]:
# Single-element list: the merged STM profile, or None when nothing matched.
stm_aggre = aggregate_by_implementation(tpms_aggregated, groups)

In [None]:
# One plot for STM; the returned figure is not kept in a variable.
create_multiplot(stm_aggre, 1, 1, figsize=(2, 2))

### AMD

In [None]:
# All AMD firmwares are merged into a single plot.
groups = [
    [
        r'AMD .*'
    ]
]

In [None]:
# Single-element list: the merged AMD profile, or None when nothing matched.
amd_aggre = aggregate_by_implementation(tpms_aggregated, groups)

In [None]:
# One plot for AMD; the returned figure is not kept in a variable.
create_multiplot(amd_aggre, 1, 1, figsize=(2, 2))

## Prepare histogram background for NTC

In [None]:
# One group per remaining vendor: AMD, NTC, STM and MSFT, each merged into
# a single plot.
groups = [
    [
        r'AMD .*'
    ],
    [
        r'NTC .*'
    ],
    [
        r'STM .*'
    ],
    [
        r'MSFT .*'
    ]
]
# Per-group switch for the EK histogram overlay; positions line up with
# `groups` (only NTC gets the overlay).
ek_hist_nsam = [
    False,
    True,
    False,
    False
]
nsam_aggre = aggregate_by_implementation(tpms_aggregated, groups)

In [None]:
# Single row with one plot per vendor group.
f=create_multiplot(nsam_aggre, 1, 4, figsize=(8.3, 2), ek_hist=ek_hist_nsam)

In [None]:
# Write the AMD/NTC/STM/MSFT figure to a PDF in the working directory.
f.savefig("amd-ntc-stm-msft-ek-rsk.pdf", dpi=300, bbox_inches="tight")

## IFX 5.63.13.6400 EK distribution


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Find the aggregated profile for firmware 5.63.13.6400 (if several match,
# the last one wins, as the loop does not stop at the first hit).
profile = None
for p in tpms_aggregated:
    if p.firmware_version == '5.63.13.6400':
        profile = p
assert profile is not None

# EK prefixes measured for this profile, one per unique (prefix, suffix) pair
# NOTE(review): fails if the profile has no EK_RSA result or its data is None.
ek = [prefix for prefix, _ in list(set(profile.results.get(cat.EK_RSA).data))]

# Add the prefixes from the TPM PCR metadata that are not present yet
for prefix_string in tpm_pcr_metadata['IFX 5.63.13.6400']['EK']:
    prefix = int(prefix_string, 16)
    if prefix not in ek:
        ek.append(prefix)

# Reduce each prefix to its most significant byte
ek = [x >> 8 for x in ek]
fix, ax = plt.subplots()

print(len(ek))
# Density histogram of the EK MSBs in 12 bins
sns.histplot(
    ek,
    bins=127//10,
    stat='density',
    ax=ax
).set(title='IFX 5.63.13.6400',xlabel='MSB values of modulus $n$')

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Merged NTC profile produced by the NTC aggregation cell
profile = ntc_aggre[0]

assert profile is not None

# EK prefixes measured for this profile, one per unique (prefix, suffix) pair
# NOTE(review): fails if the profile has no EK_RSA result or its data is None.
ek = [prefix for prefix, _ in list(set(profile.results.get(cat.EK_RSA).data))]

# Add the prefixes of every NTC firmware in the TPM PCR metadata that are
# not present yet
for key, entry in tpm_pcr_metadata.items():
    if re.match(r"^NTC .*", key):
        for prefix_string in entry['EK']:
            prefix = int(prefix_string, 16)
            if prefix not in ek:
                ek.append(prefix)

# Reduce each prefix to its most significant byte
ek = [x >> 8 for x in ek]
fix, ax = plt.subplots()

print(len(ek))

# Density histogram of the EK MSBs in 12 bins
sns.histplot(
    ek,
    bins=127//10,
    stat='density',
    ax=ax
).set(title='NTC 1.3.X 7.2.X',xlabel='MSB values of modulus $n$')