# ML-DSA FIA Demo: Attacker
**See the figures below for a graphical representation of the [Loop Abort Attack](https://eprint.iacr.org/2016/449.pdf) and [Loop Abort Strikes Back Attack](https://tches.iacr.org/index.php/TCHES/article/view/11170/10609).**

**The relevant parameters are listed in the last figure.**

<img src="img/loop_abort_attack.png" alt="Loop Abort Attack" width="800" />

<img src="img/loop_abort_attack_strikes_back.png" alt="Loop Abort Attack Strikes Back" width="800" />

<img src="img/parameters.png" alt="Parameters" width="800" />

## Imports

In [1]:
from attacker import Attacker

from helper import DILITHIUM, FAULTY_SIG_FILE_PATH, FAULTY_SIGS_FILE_PATH, info

import numpy as np

from pulp import LpAffineExpression, LpBinary, LpConstraint, LpConstraintLE, LpConstraintGE, LpInteger, LpMaximize, \
    LpProblem, LpSolverDefault, LpStatusOptimal, lpSum, LpVariable

from tqdm import tqdm

# Silence the console output of PuLP's default solver, which is the solver
# picked up by `model.solve()` when it is called without an explicit solver
LpSolverDefault.msg = False

## Verify

In [2]:
# Create the attacker helper and load the public key; the forgery cells
# further down read it back as `attacker.pk`
attacker = Attacker()
attacker.load_pk()

### Verify a valid signature

In [3]:
# Load a message/signature pair and run it through the attacker's verifier
msg, sig = attacker.load_sig()
attacker.verify(msg, sig)

[32mverify(msg="Message 1", sig="2bd16b43ea..."): True[0m


True

### Verify an invalid signature

In [4]:
# Replace the message but keep the signature loaded in the previous cell, so
# the pair no longer matches
msg = b'Message X'
attacker.verify(msg, sig)

[31mverify(msg="Message X", sig="2bd16b43ea..."): False[0m


False

## Step 1: Recover the secret key vector `s1`
<span style="color:green">**Steps 1-A and 1-B achieve the same goal. You can solve either or both of them to continue to Step 2.**</span>

### Step 1-A

#### Load a faulty signature that has all coefficients of all polynomials in `y` set to `0`

In [5]:
# Load the faulty signature (every coefficient of `y` is zero, per the
# loader's name) and check whether it passes verification
msg, sig = attacker.load_faulty_sig_y_all_coefficients_zero()
attacker.verify(msg, sig)

[32mverify(msg="Have fun solving this", sig="b080da3109..."): True[0m


True

#### Recover `s1` from a single faulty signature `z = y + c * s1` where `y = 0`
<span style="color:red">**Fix the code!**</span>

In [6]:
def recover_s1_from_single_faulty_signature(faulty_sig: bytes):
    """Recover the secret vector `s1` from one signature whose mask `y` is zero.

    A signature satisfies z = y + c * s1. With y = 0 this collapses to
    z = c * s1, so s1 = c^-1 * z, where `c` is rebuilt from the challenge seed
    stored in the signature.

    Args:
        faulty_sig: Packed signature bytes that were produced with y = 0.

    Returns:
        The vector `z` with every polynomial multiplied by the inverse of `c`,
        i.e. the recovered `s1`.
    """
    # Only the challenge seed and z are needed; the hint is ignored
    c_tilde, z, _ = DILITHIUM._unpack_sig(faulty_sig)
    c = DILITHIUM._sample_in_ball(c_tilde)

    # NOTE(review): assumes `attacker.poly_inverse` returns the inverse in the
    # same representation as `c` and `z` as unpacked here -- confirm against
    # the helper's implementation.
    c_inv = attacker.poly_inverse(c)

    # z = c * s1  =>  s1 = c^-1 * z, applied to each polynomial of the vector
    return z.scale(c_inv)

#### Run `recover_s1_from_single_faulty_signature` and check the hash of the recovered `s1` against the hash of the expected `s1`

In [7]:
# Store the recovered vector on the attacker, then let it compare the hash of
# that vector with the hash of the expected `s1`
attacker.s1 = recover_s1_from_single_faulty_signature(sig)
attacker.check_hash_of_recovered_s1()

[32mCheck hash of recovered s1: True[0m


True

### Step 1-B

#### Load multiple faulty signatures that have some coefficients of the polynomials in `y` set to `0`

In [8]:
# Load the batch of faulty signatures and verify all of them without printing
# each individual result.
# NOTE(review): `result` is not read again in this notebook; the assignment
# presumably only keeps the cell from echoing the return value.
msgs, sigs = attacker.load_faulty_sigs_y_some_coefficients_zero()
result = attacker.verify_sigs(msgs, sigs, display_result=False)

Verifying 10 signatures: 100%|██████████| 10/10 [00:00<00:00, 119.72it/s]

[32mSuccessfully verified 10 signatures![0m


#### Recover `s1` from multiple signatures `z = y + c * s1` where each `y` has some polynomial coefficients set to `0`
<span style="color:red">**Fix the code!**</span>

In [9]:
def recover_s1_from_multiple_faulty_signatures(faulty_sigs: list[bytes]):
    """Recover the secret vector `s1` one polynomial at a time.

    Each of the `DILITHIUM.l` polynomials is recovered independently by
    `recover_single_poly`, wrapped as a ring element and stacked into a
    column vector.

    Args:
        faulty_sigs: Packed faulty signatures used to build the ILP instances.

    Returns:
        The recovered `s1` as built by `DILITHIUM.M` from one
        single-polynomial row per index.

    Raises:
        RuntimeError: If no solution is found for one of the polynomials.
    """
    rows = []
    for poly_idx in range(DILITHIUM.l):
        info(f'Recovering polynomial {poly_idx}')
        solution = recover_single_poly(faulty_sigs, poly_idx)
        if solution is None:
            # Fail with a clear message instead of passing None on to the
            # ring constructor
            raise RuntimeError(f'No solution found for polynomial {poly_idx}')
        rows.append([DILITHIUM.R(solution)])
    return DILITHIUM.M(rows)

def recover_single_poly(faulty_sigs: list[bytes], poly_idx: int = 0) -> 'list[int] | None':
    """Recover the coefficients of one polynomial of `s1` by solving an ILP.

    Every coefficient of `z` with |z| <= beta is treated as a candidate
    equation z_m = C_m * s. A binary variable x_m states whether equation m
    is taken to hold; a pair of big-M constraints enforces the equation when
    x_m = 1 and relaxes it by K when x_m = 0. The objective maximizes the
    number of equations that hold.

    Args:
        faulty_sigs: Packed faulty signatures to take candidate equations from.
        poly_idx: Index of the polynomial of `z` / `s1` to work on.

    Returns:
        The integer coefficients of the recovered polynomial, or None when
        the solver does not report an optimal solution.
    """
    # The paper suggests setting K to 2 * beta + gamma_1, but gamma_1 is
    # sufficient here
    K = DILITHIUM.gamma_1
    eta = DILITHIUM.eta
    beta = DILITHIUM.beta

    model = LpProblem(name='dilithium-ilp', sense=LpMaximize)
    # One bounded integer variable per secret coefficient
    num_secrets = DILITHIUM.n
    s = [LpVariable(f's_{i}', -eta, eta, LpInteger) for i in range(num_secrets)]

    eq_idx = 0
    x = []
    for sig in tqdm(faulty_sigs, desc='Constructing ILP from faulty sigs'):
        # Rebuild the challenge polynomial c from the signature
        c_tilde, z, _ = DILITHIUM._unpack_sig(sig)
        c = DILITHIUM._sample_in_ball(c_tilde)
        c = np.asarray(c.coeffs)[::-1]

        z_target = z[poly_idx][0].coeffs
        for j, z_val in enumerate(z_target):
            # Coefficients above beta are not used as candidate equations
            if abs(z_val) > beta:
                continue

            # Row j of the multiplication-by-c matrix: the reversed
            # coefficients rotated by j + 1, with the wrapped part negated
            row = np.hstack([c[-(j + 1):], -c[: -(j + 1)]])
            # lhs is C_m * s - z_m, i.e. the sign is flipped with respect to
            # z_m - C_m * s; the constraint pair below is symmetric, so this
            # does not affect the result
            lhs = LpAffineExpression(zip(s, row), -z_val)
            var = LpVariable(f'x_{eq_idx}', cat=LpBinary)
            rhs = K * (1 - var)

            # Constraint 1: lhs <= K * (1 - x_m)   <=>  lhs - rhs <= 0
            c1 = LpConstraint(lhs - rhs, LpConstraintLE)
            # Constraint 2: lhs >= -K * (1 - x_m)  <=>  lhs + rhs >= 0
            c2 = LpConstraint(lhs + rhs, LpConstraintGE)

            model.addConstraint(c1, f'eqn {eq_idx} constraint 1')
            model.addConstraint(c2, f'eqn {eq_idx} constraint 2')

            # Add the var for this equation to the overall list
            x.append(var)
            eq_idx += 1

    # Objective - maximize number of satisfied equations
    model += lpSum(x)

    print('Solving ILP ...')
    model.solve()
    print('Done')

    if model.status != LpStatusOptimal:
        return None

    # The solver reports floats; round so that e.g. 0.9999999 becomes 1
    # instead of being truncated to 0
    return [int(round(var.value())) for var in s]

#### Run `recover_s1_from_multiple_faulty_signatures` and check the hash of the recovered `s1` against the hash of the expected `s1`

In [10]:
# Store the recovered vector on the attacker, then let it compare the hash of
# that vector with the hash of the expected `s1`
attacker.s1 = recover_s1_from_multiple_faulty_signatures(sigs)
attacker.check_hash_of_recovered_s1()

[34mRecovering polynomial 0[0m


Constructing ILP from faulty sigs: 100%|██████████| 10/10 [00:00<00:00, 76.03it/s]

Solving ILP ...


Done
[34mRecovering polynomial 1[0m


Constructing ILP from faulty sigs: 100%|██████████| 10/10 [00:00<00:00, 56.47it/s]

Solving ILP ...


Done
[34mRecovering polynomial 2[0m


Constructing ILP from faulty sigs: 100%|██████████| 10/10 [00:00<00:00, 70.02it/s]

Solving ILP ...


Done
[34mRecovering polynomial 3[0m


Constructing ILP from faulty sigs: 100%|██████████| 10/10 [00:00<00:00, 72.52it/s]

Solving ILP ...


Done
[32mCheck hash of recovered s1: True[0m


True

## Step 2: Forge a signature
<span style="color:green">**Steps 2-A and 2-B achieve the same goal. You can solve either or both of them to continue to Step 3.**</span>

### Step 2-A: Forge a signature using the algorithm presented at CHES 2018
<span style="color:red">**Fix the code!**</span>

In [11]:
def forge_signature_ches2018(msg: bytes) -> bytes:
    """Forge a signature for `msg` from the public key and the recovered `s1`.

    The signing loop is followed without `s2` and `t0`: the hint input
    w_0 - c * s_2 + c * t_0 is rewritten as w_0 + c * u with
    u = A * s_1 - t_1 * 2 ** d = t_0 - s_2, which only needs `s1` and public
    values. Candidates are retried until one verifies.

    Args:
        msg: Message to sign.

    Returns:
        Packed signature bytes that `DILITHIUM.verify` accepts for `msg`
        under `attacker.pk`.
    """
    pk = attacker.pk
    s1 = attacker.s1
    m = msg

    # unpack the public key
    rho, t1 = DILITHIUM._unpack_pk(pk)

    # Generate matrix A ∈ R^(kxl)
    A = DILITHIUM._expandA(rho, is_ntt=True)

    # Compute hash of the public key
    tr = DILITHIUM._h(pk, 32)

    # Set seeds and nonce (kappa); the mask seed is derived from mu alone
    mu = DILITHIUM._h(tr + m, 64)
    kappa = 0
    rho_prime = DILITHIUM._h(mu, 64)

    # Precompute the NTT representation of s1 once; it is reused for both
    # u and z
    s1_hat = s1.copy_to_ntt()

    # Compute u = A * s_1 - t_1 * 2 ** d (= t_0 - s_2)
    u = (A @ s1_hat).from_ntt() - t1.scale(1 << DILITHIUM.d)

    alpha = DILITHIUM.gamma_2 << 1
    while True:
        y = DILITHIUM._expandMask(rho_prime, kappa)
        y_hat = y.copy_to_ntt()

        # Advance the nonce so that a retry uses a fresh mask
        kappa += DILITHIUM.l

        w = (A @ y_hat).from_ntt()

        # Extract out both the high and low bits
        w1, w0 = w.decompose(alpha)

        # Create challenge polynomial
        w1_bytes = w1.bit_pack_w(DILITHIUM.gamma_2)
        c_tilde = DILITHIUM._h(mu + w1_bytes, 32)
        c = DILITHIUM._sample_in_ball(c_tilde)

        # Keep c as it is and take a separate copy in NTT form
        c_hat = c.copy_to_ntt()

        z = y + s1_hat.scale(c_hat).from_ntt()

        if z.check_norm_bound(DILITHIUM.gamma_1 - DILITHIUM.beta):
            continue

        # Hint input x = w_0 - c * s_2 + c * t_0 = w_0 + c * u. Both `u` and
        # `c` are outside the NTT domain at this point.
        x = w0 + u.scale(c)

        x.reduce_coefficents()
        h = DILITHIUM._make_hint(x, w1, alpha)

        sig_bytes = DILITHIUM._pack_sig(c_tilde, z, h)

        # The remaining rejection checks of the signing loop are not
        # replicated; try the next mask when the candidate does not verify
        if not DILITHIUM.verify(pk, m, sig_bytes):
            continue

        return sig_bytes

#### Run `forge_signature_ches2018` and verify the generated signature

In [12]:
# Forge a signature for a new message and run it through the attacker's verifier
msg = b'This is fun'
forged_sig = forge_signature_ches2018(msg)
attacker.verify(msg, forged_sig)

[32mverify(msg="This is fun", sig="a7a4e83417..."): True[0m


True

### Step 2-B: Forge a signature using the algorithm presented at AsiaCCS 2019
<span style="color:red">**Fix the code!**</span>

In [13]:
def forge_signature_asiaccs2019(msg: bytes) -> bytes:
    """Forge a signature for `msg` from the public key and the recovered `s1`.

    Instead of deriving the hint from secret values, the value
    w_approx' = A * z - c * t_1 * 2 ** d is computed from public data and `z`.
    The hint is then set at every position where UseHint with an all-zero
    hint disagrees with the high bits w_1 that the challenge was derived
    from.

    Args:
        msg: Message to sign.

    Returns:
        Packed signature bytes built from the challenge seed, `z` and the
        constructed hint.
    """
    pk = attacker.pk
    s1 = attacker.s1
    m = msg

    # unpack the public key
    rho, t1 = DILITHIUM._unpack_pk(pk)

    # Generate matrix A ∈ R^(kxl)
    A = DILITHIUM._expandA(rho, is_ntt=True)

    # Compute hash of the public key
    tr = DILITHIUM._h(pk, 32)

    # Set seeds and nonce (kappa); the mask seed is derived from mu alone
    mu = DILITHIUM._h(tr + m, 64)
    kappa = 0
    rho_prime = DILITHIUM._h(mu, 64)

    # Precompute NTT representation
    s1_hat = s1.copy_to_ntt()

    # t_1 * 2 ** d, converted to the NTT domain
    t1_prime = t1.scale(1 << DILITHIUM.d)
    t1_prime = t1_prime.to_ntt()

    alpha = DILITHIUM.gamma_2 << 1
    while True:
        y = DILITHIUM._expandMask(rho_prime, kappa)
        y_hat = y.copy_to_ntt()

        # Advance the nonce so that a retry uses a fresh mask
        kappa += DILITHIUM.l

        w = (A @ y_hat).from_ntt()

        # Extract out both the high and low bits
        w1, w0 = w.decompose(alpha)

        # Create challenge polynomial
        w1_bytes = w1.bit_pack_w(DILITHIUM.gamma_2)
        c_tilde = DILITHIUM._h(mu + w1_bytes, 32)
        c = DILITHIUM._sample_in_ball(c_tilde)

        # Store c in NTT form
        c.to_ntt()

        z = y + s1_hat.scale(c).from_ntt()

        if z.check_norm_bound(DILITHIUM.gamma_1 - DILITHIUM.beta):
            continue

        z = z.to_ntt()

        # Start from an all-zero hint
        matrix = [[DILITHIUM.R([0 for _ in range(DILITHIUM.n)])] for _ in range(DILITHIUM.k)]
        h = DILITHIUM.M(matrix)

        # w_approx' = A * z - c * t_1 * 2 ** d; `z`, `c` and `t1_prime` were
        # all converted with `to_ntt` above
        wa_prime = (A @ z) - t1_prime.scale(c)

        wa_prime.from_ntt()

        # Set a hint bit wherever the high bits obtained with the empty hint
        # differ from w_1
        w1_prime = DILITHIUM._use_hint(h, wa_prime, alpha)
        for i in range(DILITHIUM.k):
            for j in range(DILITHIUM.n):
                if w1_prime[i][0].coeffs[j] != w1[i][0].coeffs[j]:
                    h[i][0].coeffs[j] = 1

        # Retry with the next mask unless the final hint reproduces w_1 and
        # stays within the allowed hint weight
        w1_prime = DILITHIUM._use_hint(h, wa_prime, alpha)
        if w1_prime != w1 or DILITHIUM._sum_hint(h) > DILITHIUM.omega:
            continue

        z = z.from_ntt().from_montgomery()

        return DILITHIUM._pack_sig(c_tilde, z, h)

#### Run `forge_signature_asiaccs2019` and verify the generated signature

In [14]:
# Forge a signature for a new message and run it through the attacker's verifier
msg = b'This is fun'
forged_sig = forge_signature_asiaccs2019(msg)
attacker.verify(msg, forged_sig)

[32mverify(msg="This is fun", sig="a7a4e83417..."): True[0m


True