In [1]:
import re
import glob
import hashlib
from tqdm import tqdm, trange

from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad

In [41]:
EMPTY = np.array([False] * (312 * 64), dtype=bool)
state = None

def add(v1, v2):
    return [x ^^ y for x, y in zip(v1, v2)]

def _lshift_vec(vec, k):
    return [EMPTY] * k + list(vec)[:-k]

def _rshift_vec(vec, k):
    return list(vec)[k:] + [EMPTY] * k

def _mask_mask(vec, mask):
    res = []
    for i in range(64):
        k = (mask >> i) & 1
        if k:
            res.append(vec[i])
        else:
            res.append(EMPTY)
    return res

def _mask_const(term, const):
    res = []
    for i in range(64):
        if (const >> i) & 1:
            res.append(term)
        else:
            res.append(EMPTY)
    return res

def _process_vec(y):
    y = add(y, _mask_mask(_rshift_vec(y, 29), 0x5555555555555555))
    y = add(y, _mask_mask(_lshift_vec(y, 17), 0x71d67fffeda60000))
    y = add(y, _mask_mask(_lshift_vec(y, 37), 0xfff7eee000000000))
    y = add(y, _rshift_vec(y, 43))
    return y

In [3]:
const = 0xb5026f5aa96619e9

In [19]:
def twist():
    global state
    # print("[+] twist()", flush=True)
    for i in range(312):
        # print("[*]", state[i], state[(i + 1) % 312], flush=True)
        # y = _mask_mask(state[(i + 1) % 312], lower) + _mask_mask(state[i], upper)
        state[i] = state[(i + 1) % 312][1:31] + state[i][31:] + [EMPTY]
        state[i] = add(state[i], state[(i + 156) % 312])
        state[i] = add(state[i], _mask_const(state[(i + 1) % 312][0], const))

In [5]:
print("Loading data", flush=True)
with open("res", "r") as fin:
    bits = list(map(int, fin.readline().strip()))
print("bits:", bits[:100], flush=True)

Loading data
bits: [1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1]


In [13]:
bstart = False
dt = {}
for i, c in enumerate(bits):
    if not bstart:
        if c == 0:
            dt[i] = 0
    else:
        if c == 1:
            dt[i] = 1
    bstart = bool(int(c))

In [None]:
import copy
import numpy as np
from tqdm import trange
from time import process_time

print("Initialising variables", flush=True)
state = []
for i in range(312):
    tmp = []
    for j in range(64):
        _tmp = copy.deepcopy(EMPTY)
        _tmp[i * 64 + j] = True
        tmp.append(_tmp)
    state.append(tmp)

coef_mat = Matrix(GF(2), len(dt) * 2, 312 * 64)
rhs = vector(GF(2), len(dt) * 2)

print(f'{len(dt) = }')

print("Calculating coefficients", flush=True)
idx = 0

cnt = 0
start_time = process_time()

for i in range(len(bits)):
    if idx == 0:
        twist()

    if i in dt or i == 0:
        y = state[idx]
        y = _process_vec(y)
        msb1, msb2 = y[63], y[62]
        
        if i == 0:
            print([j for j in range(312 * 64) if msb1[j]])
            print([j for j in range(312 * 64) if msb2[j]])
            idx = (idx + 1) % 312
            continue
    
        # extract coefficients from msb1
        rhs[cnt * 2] = dt[i]
        for j in range(312 * 64):
            if msb1[j]:
                coef_mat[cnt * 2, j] = 1
        
        # extract coefficients from msb2
        rhs[cnt * 2 + 1] = dt[i]
        for j in range(312 * 64):
            if msb2[j]:
                coef_mat[cnt * 2 + 1, j] = 1
        
        cnt += 1
        if cnt % 100 == 0:
            print("[*] cnt =", cnt, "/", len(dt), ", rank =", coef_mat.rank(), ", took", process_time() - start_time, "seconds")

    idx = (idx + 1) % 312

print("Done!")

Initialising variables
len(dt) = 11962
Calculating coefficients
[56, 64, 74, 91, 9993, 10010, 10039, 10047]
[46, 63, 64, 90, 10009, 10029, 10046]
[*] cnt = 100 / 11962 , rank = 200 , took 2.858980999999858 seconds
[*] cnt = 200 / 11962 , rank = 400 , took 5.621541999999863 seconds
[*] cnt = 300 / 11962 , rank = 600 , took 7.548411000000215 seconds
[*] cnt = 400 / 11962 , rank = 800 , took 10.392158999999992 seconds
[*] cnt = 500 / 11962 , rank = 1000 , took 12.304619000000002 seconds
[*] cnt = 600 / 11962 , rank = 1200 , took 15.064030000000002 seconds
[*] cnt = 700 / 11962 , rank = 1400 , took 17.884378000000197 seconds
[*] cnt = 800 / 11962 , rank = 1600 , took 19.707887000000028 seconds
[*] cnt = 900 / 11962 , rank = 1800 , took 22.400462999999945 seconds
[*] cnt = 1000 / 11962 , rank = 2000 , took 24.253646000000117 seconds
[*] cnt = 1100 / 11962 , rank = 2200 , took 26.97199899999987 seconds
[*] cnt = 1200 / 11962 , rank = 2400 , took 29.871228999999857 seconds
[*] cnt = 1300 / 11

In [None]:
# print("TESTING")
# correct_recover = ''.join(map(str, rhs[64:1000]))
# our_recover = ''.join(map(str, (coef_mat * correct_coef)[64:1000]))
# print("Correct:", correct_recover)
# print(" We got:", our_recover)
# assert correct_recover == our_recover

In [None]:
print("Started solving equations...", flush=True)

res = coef_mat.solve_right(rhs)
print("solution:", res, flush=True)

ker = coef_mat.right_kernel()
print("kernel:", ker, flush=True)

In [None]:
print(coef_mat.dimensions())

In [None]:
basis = ker.basis()
for vec in basis:
    s = ''.join(map(str, vec))
    print(s[:20], s[-20:], s.count('1'))

In [None]:
required = set()
for vec in res_mat:
    for idx in range(312 * 64):
        if vec[idx] == 1:
            required.add(idx)

In [None]:
for vec in res_mat:
    print([idx for idx in range(312 * 64) if vec[idx] == 1])

In [None]:
basis = list(basis)
basis = [vec for vec in basis if len(set(idx for idx in range(len(vec)) if vec[idx]) & required) > 0]
for vec in basis:
    s = ''.join(map(str, vec))
    print(s[:20], s[-20:], s.count('1'))

In [None]:
def AES_decrypt(key, iv, ct):
    key = hashlib.md5(key).digest()
    iv = hashlib.md5(iv).digest()
    cipher = AES.new(key=key, iv=iv, mode=AES.MODE_CBC)
    msg = cipher.decrypt(ct)
    try:
        return unpad(msg, 16).decode()
    except ValueError:
        return

for _null in span(basis):
    _sol = res + _null
    _recovered = res_mat * _sol
    key = ''.join(map(str, _recovered))[::-1]
    key = int(key, 2).to_bytes(len(key) // 8, 'big')
    print(AES_decrypt(key, iv, ct))