In [1]:
import itertools
from tqdm import tqdm
from binascii import hexlify, unhexlify
import Crypto.Random.random as random
from Crypto.Util.number import bytes_to_long, long_to_bytes

In [3]:
# Master key; gen_key() renders it as a 48-character bit string and expands it
# into the round keys.
key = 0x31c7112c5238
# S-boxes and P-box; both are (re)filled by gen_box().
sbox = []
pbox = []
# Round keys; filled by gen_key().
keys = []
# Bit-selection table used by gen_key(): character j of each round's key string
# is taken from position pc_key[j] of the previous round's key string.
pc_key = [2, 13, 16, 37, 34, 32, 21, 29, 15, 25, 44, 42, 18, 35, 5, 38, 39, 12, 30, 11, 7, 20,
          17, 22, 14, 10, 26, 1, 33, 46, 45, 6, 40, 41, 43, 24, 9, 47, 4, 0, 19, 28, 27, 3, 31, 36, 8, 23]

In [157]:
def gen_box():
    """Randomly regenerate the global ``sbox`` and ``pbox`` tables.

    ``sbox`` becomes a list of 8 tables of 64 entries each; every table is
    four consecutive shuffles of the values 0..15.  ``pbox`` becomes a random
    ordering of 0..31 in which at least 24 positions hold a value lying
    outside their own aligned group of four.
    """
    global sbox, pbox
    sbox = []
    for _box in range(8):
        table = []
        row = list(range(16))
        for _row in range(4):
            # The same list is reshuffled in place; extend() copies its
            # current contents into the table.
            random.shuffle(row)
            table.extend(row)
        sbox.append(table)
    pbox = list(range(32))
    while True:
        random.shuffle(pbox)
        # Count positions whose value falls outside their own 4-wide group.
        escaped = sum(
            1
            for group in range(0, 32, 4)
            for pos in range(group, group + 4)
            if not group <= pbox[pos] < group + 4
        )
        if escaped >= 24:
            break

def gen_key(key):
    """Derive the six round keys from ``key`` into the global ``keys`` list.

    ``key`` is rendered as a binary string left-padded with zeros to 48
    characters.  For each of the six rounds the string is reordered through
    ``pc_key`` (character j comes from position ``pc_key[j]``) and the result,
    read as an integer, is stored as that round's key.

    The schedule is rebuilt from scratch on every call, so re-running the
    cell or switching to another key replaces the previous round keys instead
    of appending after them.  Nothing is returned.
    """
    global keys
    # Rebind to a fresh list: enc_block() only reads keys[0..5], so appending
    # to an already populated list would leave the old schedule in effect.
    keys = []
    key_bin = bin(key)[2:].rjust(48, '0')
    for _ in range(6):
        key_bin = ''.join(key_bin[pc_key[j]] for j in range(48))
        keys.append(int(key_bin, 2))

In [53]:
def s(x, i):
    """Look up the 6-bit value ``x`` in S-box number ``i``.

    The outermost two bits of ``x`` (bit 5 and bit 0) select the row and the
    four middle bits select the column of the 4x16 table stored flat in
    ``sbox[i]``.
    """
    outer = ((x & 0b100000) >> 4) + (x & 1)
    inner = (x & 0b011110) >> 1
    return sbox[i][outer * 16 + inner]

def p(x):
    """Reorder the bits of ``x`` through the global ``pbox``.

    ``x`` is written as a binary string zero-padded to 32 characters; output
    character ``n`` is input character ``pbox[n]``.  Returns the result as an
    integer.
    """
    src = bin(x)[2:].rjust(32, '0')
    moved = ''.join(src[pbox[pos]] for pos in range(32))
    return int(moved, 2)

def e(x):
    """Expand a 32-bit value to 48 bits.

    The input is written as a binary string zero-padded to 32 characters.
    Eight 6-character windows are taken from it, each starting one position
    before a 4-bit group (wrapping around modulo 32), and concatenated.
    Returns the 48-bit result as an integer.
    """
    bits = bin(x)[2:].rjust(32, '0')
    windows = []
    # Window starts: -1, 3, 7, ..., 27 (the first one wraps to the last bit).
    for start in range(-1, 31, 4):
        windows.append(''.join(bits[pos % 32] for pos in range(start, start + 6)))
    return int(''.join(windows), 2)

def F(x, k):
    """Round function: expand ``x``, mix in ``k``, substitute, then permute.

    ``e(x) ^ k`` is written as a 48-character binary string and cut into
    eight 6-bit chunks; chunk ``n`` goes through S-box ``n`` and the 4-bit
    results are concatenated and passed through ``p``.
    """
    mixed = bin(e(x) ^ k)[2:].rjust(48, '0')
    nibbles = []
    for box in range(8):
        chunk = int(mixed[6 * box:6 * box + 6], 2)
        nibbles.append(format(s(chunk, box), '04b'))
    return p(int(''.join(nibbles), 2))

def enc_block(x):
    """Encrypt one 64-bit block with six Feistel rounds.

    The block is split into two 32-bit halves, run through six rounds keyed
    by ``keys[0]`` .. ``keys[5]``, and reassembled with the final left half in
    the low 32 bits and the final right half in the high 32 bits.
    """
    bits = bin(x)[2:].rjust(64, '0')
    left = int(bits[:32], 2)
    right = int(bits[32:], 2)
    for rnd in range(6):
        left, right = right, left ^ F(right, keys[rnd])
    return (left + (right << 32)) & 0xFFFFFFFFFFFFFFFF

def enc(pt):
    """Encrypt ``pt`` block by block and return the ciphertext bytes.

    The plaintext is padded with zero bytes up to a multiple of 8, then each
    8-byte block is encrypted independently with ``enc_block`` and written
    back as exactly 8 bytes.
    """
    pt += b'\x00' * (-len(pt) % 8)
    blocks = []
    for offset in range(0, len(pt), 8):
        value = enc_block(bytes_to_long(pt[offset:offset + 8]))
        # long_to_bytes drops leading zero bytes; restore the fixed width.
        blocks.append(long_to_bytes(value).rjust(8, b'\x00'))
    return b''.join(blocks)

In [54]:
# Differential distribution
def gen_dif_dist():
    """Rebuild the global ``dif_dist`` difference tables for all 8 S-boxes.

    ``dif_dist[i]`` maps ``(input_difference, output_difference)`` to the
    number of ordered input pairs ``(x, x_ast)`` of S-box ``i`` that produce
    it.  Every ``(a, b)`` with ``a, b`` in 0..63 is present as a key, starting
    at 0.
    """
    global dif_dist
    dif_dist = []
    pairs = list(itertools.product(range(64), repeat=2))
    for box in range(8):
        table = dict.fromkeys(pairs, 0)
        for x, x_ast in pairs:
            in_dif = (x ^ x_ast) & 0b111111
            out_dif = (s(x, box) ^ s(x_ast, box)) & 0b1111
            table[(in_dif, out_dif)] += 1
        dif_dist.append(table)

In [14]:
# 2-round iterative differential characteristic spanning S-boxes i..j (sbox[i]~sbox[j])
def find_path(pre, i, j):
    """Search S-boxes ``i`` .. ``j`` for the best chain of input differences.

    For S-box ``i`` only input differences are considered that
      * give output difference 0,
      * have their top two bits equal to the low two bits of ``pre`` (the
        difference chosen for the previous S-box),
      * have their low two bits clear when ``i`` is the last S-box (``i == j``),
      * have a table count that is not a multiple of 64.
    The score of a chain is the product of the table counts along it.

    Returns ``(path, score)`` for the best chain whose score exceeds 1, where
    ``path`` lists one input difference per S-box, or ``(None, 0)`` when there
    is none.
    """
    global dif_dist
    tail_pro = 1
    best_pro = tail_pro
    best_path, tail = None, []
    for (x_dif, y_dif), count in dif_dist[i].items():
        if y_dif != 0:
            continue
        if ((x_dif & 0b110000) >> 4) != (pre & 0b000011):
            continue
        if i == j and (x_dif & 0b000011) != 0:
            continue
        if count % 64 == 0:
            continue
        if i < j:
            tail, tail_pro = find_path(x_dif, i + 1, j)
        # A failed recursion yields a score of 0, so its None path is never
        # concatenated below.
        if count * tail_pro > best_pro:
            best_pro = count * tail_pro
            best_path = [x_dif] + tail
    if not best_path:
        return None, 0
    return best_path, best_pro

In [15]:
def inv_e(x_in):
    """Collapse a 48-bit expanded value back to 32 bits.

    The input is written as a binary string zero-padded to 48 characters and
    cut into eight 6-character groups; the middle four characters of every
    group are kept and concatenated.  Returns the result as an integer.
    """
    bits = bin(x_in)[2:].rjust(48, '0')
    middles = [bits[start + 1:start + 5] for start in range(0, 48, 6)]
    return int(''.join(middles), 2)

In [16]:
def input_dif(path, left, right):
    """Build the 64-bit plaintext difference for a chain of S-box differences.

    Parameters
    ----------
    path : sequence of int
        One 6-bit input difference per S-box, for S-boxes ``left`` .. ``right``
        in order.  Only the first ``right - left + 1`` entries are used.
    left, right : int
        Indices (0..7) of the first and last active S-box.

    Returns
    -------
    int
        The 48-bit S-box input difference (active S-boxes taken from ``path``,
        all others zero), collapsed to 32 bits with ``inv_e`` and placed in the
        upper half of a 64-bit value.

    Raises
    ------
    IndexError
        If ``path`` has fewer than ``right - left + 1`` entries.
    """
    # Use the actual number of active S-boxes rather than a fixed two, so the
    # assembled string is always 8 * 6 = 48 bits, matching how filter_pair and
    # crack_part_key size their work from right - left + 1.
    span = right - left + 1
    active = ''.join(bin(path[i])[2:].rjust(6, '0') for i in range(span))
    x_in = '0' * left * 6 + active + '0' * (7 - right) * 6
    x = inv_e(int(x_in, 2))
    return (x << 32)

In [22]:
def filter_pair(pt_dif, left, right):
    """Encrypt plaintext pairs differing by ``pt_dif`` and keep the matching ones.

    Plaintexts are enumerated over three bit fields in the upper 32 bits of
    the block: ``4 * (right - left + 1)`` bits at ``i_shift`` plus two 3-bit
    fields at ``j_shift`` and ``k_shift``.  A pair is kept when the low 32
    bits of the ciphertext difference equal the high 32 bits of ``pt_dif``.

    Returns a list of ``(ct, ct_ast)`` ciphertext pairs.
    """
    wanted = hex(pt_dif)[2:].rjust(16, '0')[:8]
    span = right - left + 1
    i_shift = 60 - right * 4
    j_shift = (i_shift - 3) % 32 + 32
    k_shift = (i_shift + 4 * span) % 32 + 32
    kept = []
    for i in tqdm(range(2 ** (4 * span))):
        for j in range(8):
            for k in range(8):
                pt = (i << i_shift) + (j << j_shift) + (k << k_shift)
                ct = enc_block(pt)
                ct_ast = enc_block(pt ^ pt_dif)
                if hex(ct ^ ct_ast)[2:].rjust(16, '0')[8:] == wanted:
                    kept.append((ct, ct_ast))
    return kept

In [166]:
# Find S-/P-boxes that admit a sufficiently probable differential characteristic, then collect the filtered ciphertext pairs
def gen_features(left, right):
    """Pick boxes with a strong difference chain and collect ciphertext pairs.

    Random S-/P-boxes are regenerated until ``find_path`` over S-boxes
    ``left`` .. ``right`` reports a score of at least 192; every new best
    score is printed along the way.  The globals ``sbox``, ``pbox`` and
    ``dif_dist`` are left set to the boxes that produced the best score.

    Returns ``(cts, ratio)`` where ``cts`` is the output of ``filter_pair``
    for the chosen chain and ``ratio`` is
    ``64 ** (right - left + 1) / best_score``.
    """
    global sbox, pbox, dif_dist
    best_pro = 0
    while best_pro < 192:  # keep drawing boxes until the chain is strong enough
        gen_box()
        gen_dif_dist()
        candidate, pro = find_path(0b000000, left, right)
        if pro <= best_pro:
            continue
        print(pro)
        best_pro = pro
        best = (candidate, sbox, pbox, dif_dist)
    path, sbox, pbox, dif_dist = best
    print((path, best_pro))
    pt_dif = input_dif(path, left, right)
    cts = filter_pair(pt_dif, left, right)
    return cts, (64 ** (right - left + 1)) / best_pro

In [181]:
def crack_part_key(cts, left, right, m):
    """Vote for the last-round key bits feeding S-boxes ``left`` .. ``right``.

    At most about ``m`` pairs from the middle of ``cts`` are used.  For every
    pair and every guess of the ``6 * (right - left + 1)`` key bits, the guess
    gets a vote when the difference of the two ``F`` outputs on the low
    ciphertext halves equals the difference of the high ciphertext halves.

    Returns the list of guesses (as integers) that share the highest vote
    count.
    """
    key_num = 2 ** (6 * (right - left + 1))
    shift = 42 - 6 * right
    low32 = (1 << 32) - 1
    votes = [0] * key_num
    if len(cts) > m:
        mid = len(cts) // 2
        cts = cts[mid - m // 2:mid + m // 2]
    for ct, ct_ast in tqdm(cts):
        high_dif = (ct >> 32) ^ (ct_ast >> 32)
        ctr, ctr_ast = ct & low32, ct_ast & low32
        for guess in range(key_num):
            pro_key = guess << shift
            if (F(ctr, pro_key) ^ F(ctr_ast, pro_key)) == high_dif:
                votes[guess] += 1
    top = max(votes)
    return [guess for guess, count in enumerate(votes) if count == top]

In [27]:
def crack_key():
    """Run the attack on each pair of adjacent S-boxes (0-1, 2-3, 4-5, 6-7).

    For every pair, ``gen_features`` supplies ciphertext pairs and a ratio;
    ``crack_part_key`` is then given a budget of ``40 * int(ratio)`` pairs.

    Returns a list with one list of candidate partial keys per S-box pair.
    """
    c = 40
    pro_key = []
    for left in range(0, 8, 2):
        right = left + 1
        cts, ratio = gen_features(left, right)
        pro_key.append(crack_part_key(cts, left, right, c * int(ratio)))
    return pro_key

In [29]:
# Expand the master key into the six round keys stored in the global `keys`.
gen_key(key)

In [182]:
# Run the attack; yields one list of candidate partial keys per pair of adjacent S-boxes.
pro_key = crack_key()

72
84
100
112
  1%|          | 3/256 [00:00<00:09, 28.10it/s]192
([11, 60], 192)
100%|██████████| 256/256 [00:09<00:00, 27.57it/s]
100%|██████████| 32/32 [00:12<00:00,  2.63it/s]
48
60
64
80
100
112
120
176
  1%|          | 3/256 [00:00<00:08, 28.37it/s]192
([11, 60], 192)
100%|██████████| 256/256 [00:09<00:00, 28.12it/s]
100%|██████████| 16/16 [00:05<00:00,  2.81it/s]
24
80
112
140
  1%|          | 3/256 [00:00<00:08, 29.48it/s]216
([7, 60], 216)
100%|██████████| 256/256 [00:08<00:00, 28.56it/s]
100%|██████████| 92/92 [00:32<00:00,  2.80it/s]
64
80
96
120
128
144
168
  1%|          | 3/256 [00:00<00:08, 28.91it/s]224
([3, 60], 224)
100%|██████████| 256/256 [00:09<00:00, 28.40it/s]
100%|██████████| 38/38 [00:13<00:00,  2.81it/s]


In [183]:
# Show the candidate partial keys recovered for each S-box pair.
print(pro_key)

[[150, 170, 598, 618], [156, 160, 604, 608], [665, 677, 857, 869], [342, 362, 406, 426]]


In [184]:
# Last round key as 12 hex digits, for comparison with the recovered candidates.
format(keys[-1], '012x')

'0aa25c2a51aa'

In [185]:
# Inverse of pc_key: inv_pc_key[i] is the position at which pc_key holds i.
inv_pc_key = list(map(pc_key.index, range(48)))

In [170]:
# sub_key is a 48-character binary string (48 bits)
def dec_block(y, sub_key):
    """Run six Feistel rounds on the 64-bit block ``y`` starting from ``sub_key``.

    ``sub_key`` is a 48-character binary string.  It keys the first round and
    is then reordered through ``inv_pc_key`` before each following round.
    The two final halves are reassembled with the left half in the low 32
    bits and the right half in the high 32 bits.
    """
    bits = bin(y)[2:].rjust(64, '0')
    left = int(bits[:32], 2)
    right = int(bits[32:], 2)
    for _ in range(6):
        left, right = right, left ^ F(right, int(sub_key, 2))
        # Step the key string back through the inverse bit selection.
        sub_key = ''.join(sub_key[inv_pc_key[pos]] for pos in range(48))
    return (left + (right << 32)) & 0xFFFFFFFFFFFFFFFF

In [171]:
def dec(ct, sub_key):
    """Decrypt ``ct`` block by block with ``dec_block`` and return the bytes.

    ``ct`` must be a multiple of 8 bytes long; ``sub_key`` is passed unchanged
    to ``dec_block`` for every block.  Each decrypted block is written back as
    exactly 8 bytes.
    """
    assert len(ct) % 8 == 0
    chunks = (ct[offset:offset + 8] for offset in range(0, len(ct), 8))
    return b''.join(
        long_to_bytes(dec_block(bytes_to_long(chunk), sub_key)).rjust(8, b'\x00')
        for chunk in chunks
    )

In [186]:
# Every combination of one candidate from each of the four S-box pairs.
sub_key = list(itertools.product(*(pro_key[n] for n in range(4))))

In [187]:
# Pack each 4-tuple of 12-bit candidates into one 48-character binary key string
# (first tuple entry in the most significant 12 bits), replacing the tuple in place.
for i, parts in enumerate(sub_key):
    sk = 0
    for j in range(4):
        sk += parts[j] << (36 - 12 * j)
    sub_key[i] = format(sk, '048b')


In [188]:
# Encrypt a known message, then try every candidate key until one decrypts it
# to something containing the expected marker.
flag = b'WMCTF{2_r0und_1T3r@t1v3_D1ffer3n7i4l_f34tur3!!!}'
ct = enc(flag)
for sk in tqdm(sub_key):
    pt = dec(ct, sk)
    if pt.find(b'WMCTF') < 0:
        continue
    print(pt)
    break

40%|████      | 103/256 [00:00<00:00, 467.31it/s]b'WMCTF{2_r0und_1T3r@t1v3_D1ffer3n7i4l_f34tur3!!!}'

