# Cryptanalysis of TS-Hash

In [1]:
from sage.all import *
import itertools
import random

from binteger import Bin

from tshash import TSHash

## Extended form

In [2]:
# Sanity check on 10 random 128-bit instances: a random 300-bit message must
# survive the extend -> compress round trip, and the four evaluation routes
# must all yield the same digest.
for i in range(10):
    T = TSHash(n=128)
    print(T.p0_raw)
    msg = Bin.random(300)

    # Round trip: compress() must invert extend(). _last_pos is read right
    # after extend() and must hold the same value once compress() has run.
    emsg = T.extend(msg)
    lp = T._last_pos
    assert T.compress(emsg) == msg
    assert T._last_pos == lp

    # h1: calc_raw on the plain message.
    # h2: calc on the plain message, mapped through fromF.
    # h3/h4: calc_ext / calc_ext_fast on the extended message, mapped through fromF.
    h1 = T.calc_raw(msg)
    h2 = T.fromF(T.calc(msg))
    h3 = T.fromF(T.calc_ext(emsg))
    h4 = T.fromF(T.calc_ext_fast(emsg))
    assert h1 == h2 == h3 == h4

10100001101110001000110000010000010110000110000101010011110011011100110011010110010011001100100010111000001110101001011101111101
10011011001010001100111110110000000000010111101010010010100000110111100010101001001100100110101110111111000110010000111010110101
10000111101101111011000100111111110111011010110000100000101001001000110111100000111001111101011001101000001110110011110000001111
11010010111001011110101111100110000110010101001010010101001010011100111110000110111100011110110010010100011111011010010110110000
10001100111111000011011011010111111101010111001101010001000100101011001100100011110111011100011111010101001110000100010011011000
11010010101001000010111111101010100110111100100100010110110011010111111111011101001000110001110010101110010010111000010010000001
11010100011000110001111011100000000110000110111001111100110100010010011110111010000001000111010000110101010100010111000001111100
1010101110101110101110000101011011110001101111001110101001000010101100110001101010100110111111100

## Algebraic expression

In [3]:
# Verify the closed-form expansion of the extended hash: for every prefix
# length N < 14 the sum below (one term per subset I of positions 1..N) must
# equal T.calc_ext on the first N extended-message bits.
for _ in range(10):
    T = TSHash(n=128)
    print(T.p0_raw)

    # emsg is padded with None at index 0 so that message bits are 1-indexed.
    emsg = [None] + [randrange(2) for _ in range(50)]
    for N in range(14):
        res = 0
        # presumably Bin.iter(N) enumerates all N-bit vectors; each one's
        # support is turned into a sorted list of 1-based positions.
        for I in Bin.iter(N):
            I = sorted(i+1 for i in I.support)
            if not I:
                # empty subset: contribution of the initial state alone
                res += T.s0 * T.g**N
            else:
                # the monomial  prod(mi)
                cur = prod(emsg[i] for i in I)

                # field constant
                cur *= T.h * T.g**(N - max(I))

                # starting control bit
                cur *= (T.alpha * T.s0 * T.g**(min(I)-1)).trace()

                # monomial steps control bits (one per pair of consecutive positions)
                for ij1, ij in zip(I, I[1:]):
                    cur *= (T.alpha * T.h * T.g**(ij-ij1-1)).trace()

                res += cur
        assert res == T.calc_ext(emsg[1:1+N])

11000001001000001111001010100011110001110010010011100010011110110100100011101111011110011100000000010111110000000001100111110111
11100111000100110010010001110111001000101100010011100011000110011101001000110011110001100011000010001100010100001111100101110110
10010000010101000100100111100100100111100100001001000010011010001100011010000001110111101010111101000010110100100011010100000000
11000011001110110100111100001011010001100100100111001110001100101110001010100011011111110110110010111000100011110101101000011100
11110010011111101010011010010001100101110110111001100011100010000100011011000000011111101011010110111011110000110111001100111101
11010010111010011110101110010110110011100001110011111111000100101000001011100110010000110000100110010010100010001010010101000001
11001101001101001000011100011011110110001100101001001001010000111011100000110100001011100010000011011001101110110001000100110010
1111111110001011101111000001111110011001001111001110010100100010000001000011000000001010010101111

# ANF monomials

In [4]:
from cry.sbox2 import SBox2

# Number of free input bits of the experiment.
nb = 10

print("expected", 1.5**nb)

# Number of distinct ANF monomials observed for each random instance.
nmonos = []

for _ in range(100):
    T = TSHash(n=32)

    # y collects one digest per nb-bit input (presumably Bin.iter(nb)
    # enumerates all of them in order, giving a truth table for SBox2).
    y = []
    for x in Bin.iter(nb):
        x = Bin(x.tuple[::-1])
        x = x.list[::-1]
        # Spread the input bits over the extended message: the first popped
        # bit, then i zeroes followed by the next popped bit, for i = 1, 2, ...
        xx = [x.pop()]
        for i in range(1, 1000):
            if not x:
                break
            xx += [0] * i
            xx += [x.pop()]

        hh = T.fromF(T.calc_ext(xx))
        y.append(hh)
    #print("".join(map(str, xx)))

    #print("degs", SBox2(y).degrees())
    # Union of the monomials appearing in the ANF of any output coordinate.
    monos = set()
    for a in SBox2(y, m=32).anfs():
        monos |= set(a.monomials())
    #print("monos", len(monos))
    #print()

    # for a in SBox2(y).anfs()[:25]:
    #     print(a)
    nmonos.append(len(monos))
    # Running average and (upper) median over the instances seen so far.
    print("avg", "%.1f" % (sum(nmonos)/len(nmonos)), "median", sorted(nmonos)[len(nmonos)//2])

expected 57.6650390625


See https://github.com/sagemath/sage/issues/32709 for details.
  for a in SBox2(y, m=32).anfs():
See https://github.com/sagemath/sage/issues/32709 for details.
  return type(self)(mobius(self.tuple()), m=self.m)
See https://github.com/sagemath/sage/issues/32709 for details.
  return type(self)(tt, m=1)


avg 74.0 median 74
avg 115.0 median 156
avg 104.0 median 82
avg 98.0 median 82
avg 100.8 median 82
avg 90.8 median 82
avg 90.7 median 82
avg 113.5 median 90
avg 105.6 median 82
avg 102.0 median 82
avg 97.6 median 80
avg 97.7 median 82
avg 92.1 median 80
avg 88.1 median 80
avg 87.2 median 75
avg 88.1 median 80
avg 86.9 median 75
avg 84.4 median 75
avg 84.1 median 75
avg 82.2 median 75
avg 79.5 median 74
avg 76.6 median 74
avg 75.9 median 70
avg 77.0 median 74
avg 76.3 median 70
avg 75.4 median 70
avg 76.1 median 70
avg 74.6 median 70
avg 73.7 median 67
avg 73.2 median 67
avg 76.4 median 67
avg 75.4 median 67
avg 74.0 median 60
avg 73.6 median 63
avg 72.4 median 60
avg 75.6 median 63
avg 75.2 median 60
avg 74.0 median 60
avg 74.1 median 60
avg 73.3 median 60
avg 72.8 median 60
avg 72.1 median 60
avg 71.4 median 59
avg 71.0 median 59
avg 71.3 median 59
avg 71.4 median 60
avg 72.5 median 60
avg 72.2 median 60
avg 71.8 median 60
avg 73.7 median 60
avg 72.8 median 60
avg 73.2 median 60
avg 7

# High-probability differentials

## Last bit

In [5]:
# Flipping the last message bit (the least significant bit of the Bin value)
# must always change the raw digest by exactly T.h_raw.
for _ in range(100):
    T = TSHash(n=128)
    m1 = Bin.random(100)
    m2 = m1 ^ 1
    digest1 = T.calc_raw(m1)
    digest2 = T.calc_raw(m2)
    assert digest1 == digest2 ^ T.h_raw

## 2nd-to-last

In [6]:
# Empirical probability of the differential obtained by flipping the
# second-to-last message bit, over 10 instances with trace(alpha*h) == 0.
probs = []
for _ in range(10):
    # Resample instances until the trace condition holds.
    T = TSHash(n=128)
    while not ((T.alpha * T.h).trace() == 0):
        T = TSHash(n=128)

    ngood = 0
    for t in range(1000):
        m1 = Bin.random(100)
        m2 = m1 ^ 2
        # Count the trials where the digest difference equals h*g.
        if T.calc_raw(m1) == T.calc_raw(m2) ^ T.fromF(T.h * T.g):
            ngood += 1
    prob = ngood/1000
    probs.append(prob)
    print(prob)
    assert 0.4 < prob < 0.6
print("Prob (%):", *[f"{prob*100:.1f}" for prob in probs])
print("Avg (%):", f"{sum(probs)/len(probs)*100:.1f}")

0.498
0.482
0.483
0.503
0.513
0.521
0.521
0.509
0.517
0.514
Prob (%): 49.8 48.2 48.3 50.3 51.3 52.1 52.1 50.9 51.7 51.4
Avg (%): 50.6


## 3rd-to-last

In [7]:
# Empirical probability of the differential obtained by flipping the
# third-to-last message bit, over 10 instances for which both
# trace(alpha*h) and trace(alpha*h*g) are zero.
probs = []
for _ in range(10):
    # Resample instances until both trace conditions hold.
    T = TSHash(n=128)
    while not ((T.alpha * T.h).trace() == (T.alpha * T.h * T.g).trace() == 0):
        T = TSHash(n=128)

    #print(T.h_raw)
    ngood = 0
    for t in range(1000):
        m1 = Bin.random(100)
        m2 = m1 ^ 4
        # Count the trials where the digest difference equals h*g^2.
        if T.calc_raw(m1) == T.calc_raw(m2) ^ T.fromF(T.h * T.g**2):
            ngood += 1
    prob = ngood/1000
    probs.append(prob)
    print(prob)
print("Prob (%):", *[f"{prob*100:.1f}" for prob in probs])
print("Avg (%):", f"{sum(probs)/len(probs)*100:.1f}")

0.244
0.269
0.255
0.28
0.233
0.285
0.271
0.235
0.266
0.249
Prob (%): 24.4 26.9 25.5 28.0 23.3 28.5 27.1 23.5 26.6 24.9
Avg (%): 25.9


# Generalized birthday

## Target parameters

In [8]:
# Instance under attack: 80-bit state, fixed seed so the run is reproducible.
n = 80
T = TSHash(n=n, seed=100)

# Target digest: the byte 0xaa repeated n/8 times. 0xaa already has its top
# bit set, so the OR below does not change this particular pattern; it is
# kept to enforce the constraint for other choices.
a = b"\xaa" * (n//8)
a = (Bin(a).int | 2**(n-1))  # the MSB has to be equal to 1 for the prev. state to be controlled
target = T.toF(a)

# (target - h)/g undoes one `v*g + h` step (the update used in
# message_fit_check); that previous state must have its lowest bit set.
assert T.fromF((target-T.h)/T.g) & 1 == 1

print("target", T.fromF(target).hex)
# Shape of the equation to solve, shown for three exponents:
#c0*g^(e1+e2+e3) + h*(g^(e2+e3) + g^e3 + 1) = target

target aaaaaaaaaaaaaaaaaaaa


## Attack parameters

In [9]:
levels = 4  # depth
nchunks = 2**levels
# Bits cancelled per merge level: ceil(n / levels).
bits = (n + levels - 1) // levels
# Size of each list: twice 2^(n/levels).
chunk_size = int(2**(n / levels) * 2)

print("Attack parameters:")
print("  number of monomials:", 2**levels, "= # merges = # final candidates to check")
print("  chunk size:", chunk_size, "= cost of 1 merge operation =~ message length")
print("  bits:", bits, "filter per merge")
print()

# emax is the total exponent applied to the initial state s0.
emax = chunk_size
# Rearranged target equation: the powers of g selected by the attack must sum
# to (target - s0*g^emax)/h - 1. int_tar is that value in raw integer form,
# which is what merge() XORs against.
tar = (target - T.s0 * T.g**emax ) / T.h - 1
int_tar = T.fromF(tar)

# NOTE(review): `math` is never imported explicitly in this notebook; it is
# presumably provided by `from sage.all import *` — confirm, or add `import math`.
print("chunk size", chunk_size, math.log(chunk_size, 2))

Attack parameters:
  number of monomials: 16 = # merges = # final candidates to check
  chunk size: 2097152 = cost of 1 merge operation =~ message length
  bits: 20 filter per merge

chunk size 2097152 21.0


## Functions

In [10]:
def message_fit(T, emax, gpoly, add_emax=True):
    """Rebuild the message from the iterations at which P1 must be added.

    gpoly holds exponents; each exponent e designates iteration emax - e.
    The designated iterations are visited in increasing order (emax itself
    is appended last unless add_emax is false).

    Returns the message as a Bin, or None (after printing a diagnostic) when
    some designated iteration cannot be hit exactly.
    """
    stops = [emax - exp for exp in sorted(gpoly, reverse=True)]
    if add_emax:
        stops.append(emax)

    state = T.s0_raw
    step = 0
    out = []
    for stop in stops:
        while step < stop:
            # While the low bit is clear the state is only shifted; each
            # shift counts as one iteration and emits no message bit.
            while state & 1 == 0:
                state >>= 1
                step += 1
            # Low bit set: shift, add P0, count the iteration and record a
            # 0 message bit for it.
            state >>= 1
            state ^= T.p0_raw
            step += 1
            out.append(0)

        if step != stop:
            # The walk jumped past the designated iteration.
            print("failed:", step, ">", stop)
            return None

        # The bit recorded last sits on the designated iteration: make it a
        # 1 and swap the P0 just added for P1.
        out[-1] = 1
        state ^= T.p0_raw ^ T.p1_raw
    return Bin(out)

In [11]:
# Index (within the sorted iteration list) of the step at which the most
# recent failed message_fit_check call gave up; None until a check fails.
EFAIL = None
def message_fit_check(T, emax, gpoly, debug=0):
    """Quickly check the possibility of a given list of iterations where MASK1 was added,
    without reconstructing the message.

    gpoly holds exponents; each exponent e designates iteration emax - e, and
    emax itself is always appended as the final iteration. The state is
    advanced in the field (starting from T.s0) directly to the iteration just
    before each designated one, where T.tr of the state must equal 1.

    Returns True when every designated iteration passes, False otherwise.
    On failure the global EFAIL is set to the index of the failing iteration.
    With debug set, the iteration list and the failing state are printed.
    """
    global EFAIL

    gpoly = sorted(gpoly, reverse=True)
    iters = [emax-e for e in gpoly]
    iters.append(emax)
    if debug:
        print(iters)
    v = T.s0
    prev = 0
    for eii, ei in enumerate(iters):
        eadd = ei - prev
        prev = ei
        # Jump over the eadd-1 iterations preceding the designated one.
        v = v*T.g**(eadd-1)
        if T.tr(v) != 1:
            EFAIL = eii
            if debug:
                # fromF is a method of T; the bare name is not defined here.
                print("failed at:", eii, "itr=", ei, "val", T.fromF(v))
            return False
        # Designated iteration itself: multiply by g and add h.
        v = v * T.g + T.h
    return True

In [12]:
def merge(chunk1, chunk2, mask, target=0, limit=float("inf"), debug=0):
    """Merge procedure for the Wagner's method.

    Both chunks are iterables of (value, exponents) pairs, where value is an
    integer and exponents is a tuple. Every pair (aa, ea) from chunk2 is
    combined with every pair (bb, eb) from chunk1 such that
    (aa ^ bb) & mask == target & mask, producing (aa ^ bb, ea + eb).

    Pairs with aa == bb are skipped. When chunk1 and chunk2 are the same
    object, only aa < bb is kept so each unordered pair appears once.

    The result is shuffled and truncated to at most `limit` entries; the
    default (infinity) and None both mean "no truncation".
    With debug set, the 10 most common masked keys of each chunk are printed.
    """
    # Index chunk1 by the masked part of its values.
    tab = {}
    for bb, eb in chunk1:
        tab.setdefault(bb & mask, []).append((bb, eb))

    if debug:
        # Imported locally: the notebook never imports Counter explicitly.
        from collections import Counter
        for label, pairs in (("distrib chunk1", chunk1), ("distrib chunk2", chunk2)):
            cnt = Counter(val & mask for val, _ in pairs)
            print(label, cnt.most_common(10))

    self_merge = chunk1 is chunk2
    ab = []
    for aa, ea in chunk2:
        key = (aa ^ target) & mask
        for bb, eb in tab.get(key, ()):
            if aa != bb and (not self_merge or aa < bb):
                ab.append((aa ^ bb, ea + eb))
    shuffle(ab)

    # A float (including the default infinity) is not a valid slice bound,
    # so the "unlimited" case has to be handled before slicing.
    if limit is None or limit == float("inf"):
        return ab
    return ab[:int(limit)]

## Attack

### Base list generation

In [13]:
# Base list: one entry (raw value, (e,)) per exponent e in [0, chunk_size).
# The assertion at the end confirms that val tracks the raw form of g**e.
val = Bin.unit(0, n).int
e = 0
chunk = []
mask = Bin.full(bits).resize(n).int
for j in range(chunk_size):
    chunk.append((val, (e,)))
    # One step: shift right, adding p0 when the dropped bit was 1.
    if val & 1:
        val >>= 1
        val ^= T.p0_raw.int
    else:
        val >>= 1
    e += 1
chunk = chunk[2*n:]  # avoid basic zeroes (they are good but need handling)
shuffle(chunk)
assert T.toF(val) == T.g**chunk_size

### Merges

In [14]:
# z: list merged towards 0 on the masked bits; t: list merged towards int_tar.
# Both start as the same base list object, so at level 0 merge() sees
# `chunk1 is chunk2` and keeps each unordered pair only once.
z = chunk
t = chunk
for i in range(levels):
    print("LEVEL", i, "->", i+1, "/", levels)
    print(len(z), len(t), "->")
    # The mask grows by `bits` bits per level
    # (presumably the low bits*(i+1) bits of an n-bit word — confirm Bin.full/resize).
    mask = Bin.full(bits*(i+1)).resize(n).int
    t = merge(t, z, mask, target=int_tar, limit=chunk_size)
    # z is not needed after the final level, so its last self-merge is skipped.
    if i != levels-1:
        z = merge(z, z, mask, target=0, limit=chunk_size)
    print(len(z), len(t))

LEVEL 0 -> 1 / 4
2096992 2096992 ->
2097152 2096813
LEVEL 1 -> 2 / 4
2097152 2096813 ->
2095001 2097152
LEVEL 2 -> 3 / 4
2095001 2097152 ->
2092542 2097152
LEVEL 3 -> 4 / 4
2092542 2097152 ->
2092542 2097152


### Candidate testing

In [15]:
# Scan the final candidates for one whose exponent list can be realised by an
# actual message. `isol` is left pointing at the accepted candidate for the
# cells below.
for isol, (res, ids) in enumerate(t):
    if message_fit_check(T, emax, ids):
        print()
        print("good!", isol)
        print()
        break
else:
    # Without this, `isol` would be undefined (empty list) or point at a
    # rejected candidate, and the following cells would fail confusingly.
    raise RuntimeError("no candidate passed message_fit_check; re-run the merges")


good! 44808



In [16]:
# Rebuild the actual message from the exponent tuple of the accepted candidate.
msg = message_fit(T, emax, t[isol][1])
len(msg)

1049351

In [17]:
# Positions of the 1 bits: 2**levels exponents from the merges plus the final emax.
msg.support

(44810,
 167928,
 267593,
 283893,
 372291,
 380101,
 420588,
 431687,
 543158,
 611284,
 779121,
 815516,
 836280,
 858023,
 865300,
 1047853,
 1049350)

In [18]:
# Support index i carries weight 2**(n-1-i), i.e. index 0 is the most significant bit.
assert msg == sum(2**(msg.n-1-i) for i in msg.support)

In [19]:
# Digest of the recovered message; it should display the target printed earlier.
T.calc_raw(msg).hex

'aaaaaaaaaaaaaaaaaaaa'

# Linearization

In [20]:
from linearization import Linearization

In [21]:
# Fresh unseeded 80-bit instance; this rebinds the global T used by the cells above.
T = TSHash(n=80)
L = Linearization(T)
# Sets up the linearized prefix with N=10 (later read back as L.prefix_N).
L.linearize_prefix(N=10)

found s 255
found l 8344


Optional: display ANF (requires module `cry` in sage):

In [22]:
from cry.sbox2 import SBox2

# Tabulate the prefix map: one extended message per input x0.
y = []
for x0 in Bin.iter(L.prefix_N):
    emsg = L.eval_prefix_raw(x0)
    # msg is only needed for the printed length.
    # Gotcha: this rebinds the global `msg` that held the recovered preimage.
    msg = L.T.compress(emsg)
    if x0 < 10:
        print(x0, len(emsg), "->", len(msg))
    y.append(emsg)

# ANF of every output coordinate of the tabulated map.
anfs = SBox2(y).anfs()

for a in anfs[:10]:
    print(a)
print("...")

# All monomials occurring in any coordinate.
monos = set()
for a in anfs:
    monos |= set(a.monomials())

sorted(monos)

0000000000 80 -> 41
0000000001 80 -> 38
0000000010 80 -> 42
0000000011 80 -> 39
0000000100 80 -> 35
0000000101 80 -> 48
0000000110 80 -> 47
0000000111 80 -> 36
0000001000 80 -> 43
0000001001 80 -> 43
1
x1 + x4 + x5 + x7 + x8 + x9 + 1
x0 + x2 + x7
x2 + x3 + x4 + x9
x0 + x2 + x4 + x6 + x7 + 1
x3 + x4 + x6 + x7 + x8 + 1
x2 + x3 + x4 + x6 + x7 + x8 + x9
x4 + x5 + x6
x0 + x1 + x4 + x5 + x6 + x7 + 1
x1 + x2 + x3 + x5 + x6 + x9 + 1
...


See https://github.com/sagemath/sage/issues/32709 for details.
  anfs = SBox2(y).anfs()


[1, x9, x8, x7, x6, x5, x4, x3, x2, x1, x0]