# Dilithium Samplers Reference Implementation

In [1]:
import os
import binascii

### Parameters

In [2]:
# Ring parameters: polynomials live in Z_q[X]/(X^N + 1).
DILITHIUM_Q = 8380417 # 2**23 - 2**13 + 1
DILITHIUM_N = 256
DILITHIUM_LOGN = 8    # log2(DILITHIUM_N)
DILITHIUM_ROOT_OF_UNITY = 1753

## Category 5 parameters:
DILITHIUM_ETA = 2
# rejbounded rejects 4-bit candidates >= DILITHIUM_ETA_BOUND. The bound must
# match ETA: for ETA = 2 the accepted range 0..14 reduces evenly mod 5, and for
# ETA = 4 the accepted range is 0..8. Deriving it keeps the two in sync; an
# unsupported ETA raises KeyError here rather than sampling silently wrong.
DILITHIUM_ETA_BOUND = {2: 15, 4: 9}[DILITHIUM_ETA]
DILITHIUM_TAU = 60        # number of +-1 coefficients produced by sample_in_ball
DILITHIUM_GAMMA1 = 2**19  # expand_mask packs 20 bits/coeff for 2**19, 18 for 2**17

### SHAKE

In [3]:
from hashlib import shake_128, shake_256

class Shake():
    """Buffered byte-stream reader over a hashlib SHAKE XOF.

    ``absorb(data)`` starts a new output stream for ``data``; successive
    ``read(n)`` calls then return consecutive, non-overlapping chunks of that
    stream. Output is fetched in whole blocks of ``block_length`` bytes and
    buffered in ``read_data``.
    """
    def __init__(self, algorithm, block_length):
        self.algorithm    = algorithm     # hashlib constructor, e.g. shake_128
        self.block_length = block_length  # bytes per fetched block
        self.read_blocks  = 0             # blocks fetched so far for the current input
        self.read_data    = b""           # fetched but not yet consumed bytes
        self.xof          = None          # set by absorb()

    def absorb(self, input_bytes):
        """Discard any buffered output and start a new stream from ``input_bytes``."""
        self.read_blocks  = 0
        self.read_data    = b""
        self.xof = self.algorithm(input_bytes)

    def digest(self, input_bytes, length):
        """One-shot helper: return ``length`` XOF bytes for ``input_bytes``.

        Independent of (and does not modify) the streaming state.
        """
        return self.algorithm(input_bytes).digest(length)

    def get_n_blocks(self, n):
        """Fetch the next ``n`` blocks of the stream into ``read_data``.

        hashlib SHAKE objects only expose ``digest(length)``, which always
        returns the stream from its start, so the whole prefix is recomputed
        and only the last ``n`` blocks are kept. Replaces ``read_data``.
        """
        byte_count = self.block_length * (self.read_blocks + n)
        xof_data   = self.xof.digest(byte_count)
        self.read_blocks += n
        self.read_data = xof_data[-self.block_length*n:]

    def read(self, n):
        """Return the next ``n`` bytes of the current stream.

        Bytes still buffered from an earlier fetch are kept and placed in
        front of the newly fetched blocks, so the stream stays contiguous
        whatever sizes the caller reads in.
        """
        if n > len(self.read_data):
            leftover = self.read_data
            # Over-fetch (5*n blocks, >= n bytes) to amortise the prefix
            # recomputation done by get_n_blocks.
            self.get_n_blocks(5*n)
            self.read_data = leftover + self.read_data
        send = self.read_data[:n]
        self.read_data = self.read_data[n:]
        return send

Shake128 = Shake(shake_128, 168)
Shake256 = Shake(shake_256, 136)

### RejBounded Ref Model

In [4]:
def rejbounded(seed, i):
    """Sample a polynomial with small coefficients by rejection on SHAKE256.

    The XOF input is ``seed`` followed by the nonce ``i`` encoded as 2
    little-endian bytes. Every XOF byte gives two 4-bit candidates, low nibble
    first. A candidate ``t`` is rejected when ``t >= DILITHIUM_ETA_BOUND``.
    Otherwise, when DILITHIUM_ETA == 2 it is first reduced mod 5, and it is
    stored as ``(DILITHIUM_ETA - t) mod DILITHIUM_Q``.

    Args:
        seed: byte string prefix of the XOF input.
        i: nonce, appended as 2 little-endian bytes.

    Returns:
        A list of DILITHIUM_N integers in [0, DILITHIUM_Q).
    """
    def accept(nibble):
        # Map an accepted nibble to its coefficient; None means "rejected".
        if nibble >= DILITHIUM_ETA_BOUND:
            return None
        if DILITHIUM_ETA == 2:
            nibble %= 5
        return (DILITHIUM_ETA - nibble) % DILITHIUM_Q

    Shake256.absorb(seed + int.to_bytes(i, 2, "little"))

    coeffs = []
    while len(coeffs) < DILITHIUM_N:
        byte = int.from_bytes(Shake256.read(1), "little")
        for nibble in (byte & 0x0F, byte >> 4):
            value = accept(nibble)
            if value is not None:
                coeffs.append(value)

    # Both nibbles of the final byte can be accepted, which overshoots by one
    # coefficient; only the first DILITHIUM_N are kept.
    return coeffs[:DILITHIUM_N]

### Rejection_q Ref Model

In [5]:
def rejection_q(seed, i, j):
    """Sample a polynomial with coefficients uniform in [0, DILITHIUM_Q).

    The XOF is SHAKE128 on ``seed`` followed by the two bytes ``j``, ``i``
    (in that order). Each candidate is 3 XOF bytes read little-endian with
    the top bit cleared (23 bits). It is accepted when it is below
    DILITHIUM_Q and retried otherwise.

    Args:
        seed: byte string prefix of the XOF input.
        i, j: indices in 0..255, appended as single bytes (j first).

    Returns:
        A list of DILITHIUM_N integers in [0, DILITHIUM_Q).
    """
    def next_coeff(stream):
        candidate = DILITHIUM_Q  # forces at least one draw
        while candidate >= DILITHIUM_Q:
            candidate = int.from_bytes(stream.read(3), "little") & 0x7FFFFF
        return candidate

    Shake128.absorb(seed + bytes([j, i]))
    return [next_coeff(Shake128) for _ in range(DILITHIUM_N)]

### SampleInBall Ref Model

In [6]:
def sample_in_ball(seed):
    """Build a polynomial with DILITHIUM_TAU coefficients equal to +-1 mod Q.

    The XOF is SHAKE256 on ``seed``. The first 8 XOF bytes form a
    little-endian sign word, consumed one bit at a time starting from the
    LSB. For each index ``idx`` from N - TAU to N - 1, a position
    ``pos <= idx`` is drawn one byte at a time by rejection. The value at
    ``pos`` moves to ``idx``, and ``pos`` receives +1 (sign bit 0) or -1
    (sign bit 1), reduced mod Q.

    Returns:
        A list of DILITHIUM_N integers, each 0, 1 or DILITHIUM_Q - 1.
    """
    Shake256.absorb(seed)
    signs = int.from_bytes(Shake256.read(8), "little")

    coeffs = [0] * DILITHIUM_N
    for idx in range(DILITHIUM_N - DILITHIUM_TAU, DILITHIUM_N):
        pos = int.from_bytes(Shake256.read(1), "little")
        while pos > idx:
            pos = int.from_bytes(Shake256.read(1), "little")
        coeffs[idx] = coeffs[pos]
        coeffs[pos] = (1 - 2 * (signs & 1)) % DILITHIUM_Q
        signs >>= 1
    return coeffs

### Expand_Mask Ref Model

In [7]:
def expand_mask(seed, i, kappa):
    """Unpack a mask polynomial from a one-shot SHAKE256 output.

    The XOF input is ``seed`` followed by ``kappa + i`` as 2 little-endian
    bytes. The output is read as one little-endian integer, and coefficient
    ``k`` is bits ``[k*w, (k+1)*w)`` of it. The width ``w`` is 18 when
    DILITHIUM_GAMMA1 == 2**17 and 20 otherwise. Each raw value ``r`` is
    returned as ``(DILITHIUM_GAMMA1 - r) mod DILITHIUM_Q``.

    Args:
        seed: byte string prefix of the XOF input.
        i, kappa: summed to form the 2-byte nonce.

    Returns:
        A list of DILITHIUM_N integers in [0, DILITHIUM_Q).
    """
    # (bits per coefficient, XOF bytes) -- 576 = 256*18/8, 640 = 256*20/8
    bit_count, total_bytes = (18, 576) if DILITHIUM_GAMMA1 == (1 << 17) else (20, 640)

    nonce = int.to_bytes(kappa + i, 2, "little")
    packed = int.from_bytes(Shake256.digest(seed + nonce, total_bytes), "little")
    mask = (1 << bit_count) - 1

    coeffs = []
    for k in range(DILITHIUM_N):
        raw = (packed >> (bit_count * k)) & mask
        coeffs.append((DILITHIUM_GAMMA1 - raw) % DILITHIUM_Q)
    return coeffs

### Test

In [8]:
# RejBounded test vector. seed_rnd is an alternative random seed; the fixed
# all-zero seed below is what produces the recorded output.
seed_rnd = os.urandom(64).hex()
seed_fixed = "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
seed = binascii.unhexlify(seed_fixed)
i = 0 # i < k+l = 15
coeffs = rejbounded(seed, i)  # use the nonce chosen above, not a hard-coded 0

# Print the coefficients in hex, 16 per row.
print("\n".join(" ".join(f"{x:6X}" for x in coeffs[row:row + 16]) for row in range(0, len(coeffs), 16)))

     0 7FDFFF 7FE000 7FE000      0 7FE000      0      2      0 7FE000      2 7FE000      2      1 7FE000      2
     1      1      1      0      0 7FE000      1      0      1      2      0      0      1 7FDFFF      2      0
     0      0 7FE000      1 7FDFFF      2 7FE000      1      0 7FDFFF      2 7FDFFF 7FE000 7FE000      0      1
     2 7FE000      2      0      1      0 7FDFFF 7FE000      1 7FDFFF      2 7FE000      2 7FDFFF      2 7FE000
     2 7FE000 7FDFFF      2 7FE000 7FDFFF      2 7FDFFF      2      0 7FDFFF      1      0      2      2      0
     2      1      1 7FDFFF 7FE000      2      1 7FDFFF      1      2      1      2      1 7FDFFF 7FE000      0
     1      1      1      1 7FDFFF      0 7FE000      2      1      0      1      2 7FE000      2 7FDFFF 7FE000
     2      2      0      2 7FE000      2 7FE000      0      0 7FDFFF      0 7FE000 7FDFFF 7FE000 7FE000      1
     1 7FDFFF      2      1      0      1      2      2      2 7FE000 7FE000      2 7FDFFF      0      0

In [9]:
# Rejection_q test vector (fixed all-zero 32-byte seed; seed_rnd is an unused alternative).
seed_rnd = os.urandom(32).hex()
seed_fixed = "0000000000000000000000000000000000000000000000000000000000000000"
seed = bytes.fromhex(seed_fixed)
i = 0 # i < k = 8
j = 0 # j < l = 7
coeffs = rejection_q(seed, i, j)

# Hex dump, 16 coefficients per line.
print(*(" ".join(f"{x:6X}" for x in coeffs[start:start + 16])
        for start in range(0, len(coeffs), 16)), sep="\n")

59DF49 3C9B80 4A0154 6ACCAB 75199A 2D48ED 19D957 729102  9B601 462A78 6D5EC7 10AAA9 2E342E 4D5FC7 556419 39C8B8
5F75BC 59F799 6A6A16  A1A97 47C0BC 49EF15 529E3D 46698B 11E056 37606B 55A075 32AA95 7A9541 652F69 45A182 2C86C6
5209F4 3102C5 75DC85  25F1F 245B55  939A4  2D95B 3D4747 46C617 1247F3 77F90D 4A5F3A 505508  99F99 1B517C 12CD3A
116B7C 4BE2D4 1DBDEA 450D85 48DEF0 420F06 7E2015 612643 250348 1FBF76 4C10A9 6713EA  7FB5D 73D309 138FF4 7038AC
64351F 5749BA 2D5CDC  4C1A1 597363 7E1C01 53629D 6870C8 52CDA2 72FA23 5FB6AE 4822AD 572105 38F1AF 5A8FA5 1B6112
690176   742C 3B437A 5947D9 1410BC 77125C 1707BF 7D4168 4CE295 180B07 14E13D 254BAA 2FDFAD 3077B1 50AE3A 244903
59814B 3B0A15 161B9F  85C13 3DFDA0 4BB71C 666398 712EBD 6C921A  34BC7 6AA929 2C0A2C 1F3480 7FB3DD 55E229 13AC54
179F34 7EF9D4 7F50C7 36AD97 6CD233  9FA8E 30D973 45591B 5753C7 1588A3 7860FD 5E957E 2C1D79  665EC 79D0BC 264E1D
543F36 18F808 755A6E 48193A 7EE4F2 50EE91 53F012 303B4A 54C826 101F3D 3F2968 287D64  E2188  CE625 1DED2C

In [10]:
# SampleInBall test vector (fixed all-zero 32-byte seed; seed_rnd is an unused alternative).
seed_rnd = os.urandom(32).hex()
seed_fixed = "0000000000000000000000000000000000000000000000000000000000000000"
seed = bytes.fromhex(seed_fixed)
coeffs = sample_in_ball(seed)

# The count of nonzero coefficients is expected to equal DILITHIUM_TAU.
nonzero_count = sum(c != 0 for c in coeffs)
print("nonzero_count =", nonzero_count)


print(*(" ".join(f"{x:6X}" for x in coeffs[start:start + 16])
        for start in range(0, len(coeffs), 16)), sep="\n")

nonzero_count = 60
     0 7FE000      0      0 7FE000      1      0      0      0      0      1      1      0      0      0      0
     0      1      1      0      0      0      0      1      0 7FE000      0      0      0 7FE000      1      0
     0      0 7FE000      0      0      0 7FE000      0      0      0      0      0 7FE000      0      0      0
7FE000      0      0 7FE000      0      0      1 7FE000      0      0      0      1      0      0      0 7FE000
     0      0      0 7FE000      0      0      1      0      1 7FE000      0      0      0      0      1 7FE000
     1      0      0      0      0      0      1      1 7FE000      0      0      0      0      0      0      0
     0      0 7FE000      0      1      0      0      0      0      0      0      0      0      0      0      0
     0      1 7FE000      0      1      0      0      0      0      0      0      0      0      0      0      1
     0      1      0      0      0      0      0      1      1      0      0      0  

In [11]:
# Expand_Mask test vector (fixed all-zero 64-byte seed; seed_rnd is an unused alternative).
seed_rnd = os.urandom(64).hex()
seed_fixed = "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
seed = bytes.fromhex(seed_fixed)
i = 0
kappa = 0
coeffs = expand_mask(seed, i, kappa)

print(*(" ".join(f"{x:6X}" for x in coeffs[start:start + 16])
        for start in range(0, len(coeffs), 16)), sep="\n")

 57CB4 786F89 7CD75E  6E0F8  62D3A   D5EE  20641 7EE83F 7C5FB6 7A956A 7A8E39  59DF1 7DE4C3  71FCB 7D4CF9 7D4519
7D641A 7CC5DC 7EC5C3 7C39A6 7C2873 7F164B 7C1E07 7F4F77 7F3901 79632C 7F0175  6F64F 7D3FE5 7DE524   5939  16A2A
 28F06 788A42 788084 7EF79D 7EC990 7E7718  38EAE 7B9C95 7D7C9E 7AC838  372FA 7F6C9F   C273  72F3B  75581 780C30
 650E9  4A4AE 7B34B6 77F1FB  7BB5E 7C60CE 7F5CDB 7FC2E7 7BC6AE 7C7CCD  26A0E  2B936  427E8 7D3A2F 7A4036 7F1CB3
79F0EB   1D8D 7C4710  37C3F  73096 7A3CE6 7A01ED  41F44   3010 7FAD97  787F4 7A63C3  2A799 7988BE  6DC23 7AEC0F
 62FC5 7E61ED 7D37C9  7F481 7D07EC 79B0AB 7CD023 7EEBC9  495B5 7E5330  41EA8 7D468D  1EC25 780E2C  6E435 7ED7A5
 5E5AA 7D7A9D 78C509 7FAB0C 7C50C2  710AE  5B50C  22481 7900A4  26B35  5595F  33F93   F384  1C93D  56273 793D6A
 4A094 789899 7BB2E5 7FDD59 7D886A 7E2633  58572 7D5D03  58CB5  2FBC3  701C5 7EC79E 7B6853   2A33  5EE57  52229
7AABA1   1075 7EFC52 7D713A  3CC4A 785B84   332F  3CD25  35410 7A6066    36A 78CB6B  339AF  3075F 79642D