In [6]:
from Crypto.Cipher import AES
from Crypto.Util.number import *
import json
import os
import string

import galois # https://github.com/mhostetter/galois/tree/a23bff69a59d24db372621721b5c1d8f1741b2aa
import z3     # https://github.com/Z3Prover/z3

import pwn

In [7]:
# Class from challenge.

# Characters the forged nonce may contain: all printable ASCII (0x20-0x7e).
ACCEPTABLE = "".join((string.ascii_letters, string.digits, string.punctuation, " "))

class GoodHash:
    """Challenge hash: AES-GCM under a fixed key, with the message used as the nonce.

    The message accumulated in ``buf`` is passed to GCM as its nonce; the
    digest is the encryption of 32 zero bytes followed by the authentication tag.
    """

    def __init__(self, v=b""):
        # The key never changes, so every instance shares the same GCM subkey.
        self.key = b"goodhashGOODHASH"
        self.buf = v

    def update(self, v):
        """Append more message bytes (they become part of the nonce)."""
        self.buf += v

    def digest(self):
        """Return ciphertext || tag from AES-GCM(key, nonce=buf) over 32 zero bytes."""
        plaintext = bytes(32)
        gcm = AES.new(self.key, AES.MODE_GCM, nonce=self.buf)
        ciphertext, tag = gcm.encrypt_and_digest(plaintext)
        return b"".join((ciphertext, tag))

    def hexdigest(self):
        """Hex-encoded form of ``digest()``."""
        raw = self.digest()
        return raw.hex()

In [38]:
# Useful functions.
def bit_rev(a, width=128):
    """Reverse the bit order of ``a`` within a ``width``-bit word.

    The notebook wraps every ``gf`` operation in ``bit_rev`` calls, converting
    between GCM's bit-reflected block encoding and the integer representation
    used by the galois field.

    Args:
        a: A non-negative integer, or anything ``int()`` accepts (such as a
            0-d galois/numpy array), that fits in ``width`` bits.
        width: Word size in bits. Defaults to 128, the GCM block size.

    Returns:
        The integer whose bit ``i`` equals bit ``width - 1 - i`` of ``a``.

    Raises:
        ValueError: If ``a`` is negative or does not fit in ``width`` bits.
            Previously such values were silently reversed over the wrong width.
    """
    value = int(a)
    if not 0 <= value < (1 << width):
        raise ValueError(f"{value:#x} does not fit in {width} bits")
    # Zero-pad to exactly `width` binary digits, then read them backwards.
    return int(format(value, f"0{width}b")[::-1], 2)

def simplified_ghash(h, txt, debug_print=False):
    """GHASH over ``txt`` alone (no additional authenticated data).

    This mirrors how GCM derives its pre-counter block from a nonce that is
    not 96 bits long. The nonce is zero-padded to a multiple of 16 bytes and
    absorbed block by block. A final length block is then absorbed.

    Args:
        h: The hash subkey H as a 128-bit integer (GCM bit order).
        txt: The bytes to hash.
        debug_print: If true, print the running state after each 16-byte block.

    Returns:
        The 128-bit GHASH result as an integer.

    Relies on the module-level field ``gf`` (GF(2^128)) and on ``bit_rev`` to
    translate between GCM's bit order and the field's integer representation.
    """
    # H is the same for every block; build the field element once.
    h_elem = gf(bit_rev(h))

    def mul_h(v):
        # One GHASH step: multiply by H in GF(2^128), staying in GCM bit order.
        # int() instead of `.base`: `.base` is numpy's view-owner attribute and
        # only happens to hold the value when galois returns a view.
        return bit_rev(int(h_elem * gf(bit_rev(v))))

    bit_len = len(txt) * 8
    padded = txt + bytes(-len(txt) % 16)
    x = 0
    for i in range(0, len(padded), 16):
        x = mul_h(x ^ bytes_to_long(padded[i:i + 16]))
        if debug_print:
            print(f"{i:2d}: {hex(x)}")
    # The length block is 0^64 || [len(txt) in bits]_64, i.e. just bit_len as an int.
    return mul_h(x ^ bit_len)
 

## Overview
GCM mode: Galois/Counter Mode.

Only the nonce changes (the key is fixed). Therefore the attack is on the IV generation.

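For a nonce whose length is not 96 bits, GCM derives the initial counter block as $J_0 = \mathrm{GHASH}_H(\mathrm{nonce} \parallel 0^{s+64} \parallel [\mathrm{len}(\mathrm{nonce})]_{64})$. In words, it hashes the zero-padded nonce followed by a length block; this is what `simplified_ghash` computes.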

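Therefore we need a collision in GHASH. Two nonces of the same length with equal GHASH give the same $J_0$, hence the same keystream and tag, and thus the same digest.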

I found this [blog](https://jhafranco.com/2013/05/31/aes-gcm-implementation-in-python/) to be a helpful reference.

In [8]:
# Set up the field that most operations work in: GF(2^128) with GCM's reduction
# polynomial x^128 + x^7 + x^2 + x + 1. It is spelled out explicitly rather than
# relying on irreducible_poly()'s search order to land on the same polynomial.
p = galois.Poly.Degrees([128, 7, 2, 1, 0])
gf = galois.GF(2**128, p)

In [25]:
# Key is constant, therefore H is constant: H = AES_K(0^128) is GHASH's subkey.
# Take the key from the challenge class so there is a single source of truth.
obj = AES.new(GoodHash().key, AES.MODE_ECB)
h = bytes_to_long(obj.encrypt(bytes(16)))

In [52]:
# GHASH the original nonce to get the IV that GCM uses for it, printing the
# running state after every block along the way.
nounce = b'{"token" "8ce334a2107af7e5c0e532f849ad1a23", "admin" false}'
iv = simplified_ghash(h, nounce, debug_print=True)
print("IV:", hex(iv))

 0: 0x415af00b3f895ec2a58423f1367d0958
16: 0x60a704b985285e30ba94d7c6048e916
32: 0x1e5bf91166e6de354e113956a35a7105
48: 0x5ac2e15cde928526b3ebd11cae75fccd
IV: 0x35cf410279885675e1c2b0ead53ae24f


## Generating a collision in the GHASH function

Easy parts:
* Keep the same length: replace 'false' with 'true ' (note the trailing space), so the final length block stays the same.
* We have free bytes: the second 16-byte block of the nonce, `nounce[16:32]`, lies inside the token, so it is entirely user-specified.

Annoying parts:
* The forged nonce may only contain acceptable characters (`ACCEPTABLE`).

### Hand-wavy method

To propagate an error backwards through GHASH, we divide by $H$ in $GF(2^{128})$, because each block step multiplies the state by $H$.

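Find changes to the low bits of the last token characters, `nounce[32:42]`, with the following property. Once the combined correction (`back_error_4` plus the back-propagated effect of those changes) is XORed into `nounce[16:32]`, every byte there has its top bit (0x80) clear and bit 0x20 set. That way the forged data stays (almost always) within the acceptable characters.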

To find such a solution, we first compute the effect on `nounce[16:32]` of each bit we are willing to change in `nounce[32:42]`. Then we use z3 to find a combination of these that, when XORed into `nounce[16:32]`, still results in acceptable characters.

In [70]:
bad_nounce = b'{"token" "8ce334a2107af7e5c0e532f849ad1a23", "admin" true }'
bad_iv = simplified_ghash(h, bad_nounce, debug_print=True)
print(bad_nounce)
print(f"IV: {hex(bad_iv)}")
print()


def _div_by_h(v):
    """Undo one GHASH step: return v / H in GF(2^128), in GCM bit order."""
    return bit_rev(int(gf(bit_rev(v)) / gf(bit_rev(h))))


# Each absorbed block is multiplied by H, so repeatedly dividing the IV
# difference by H walks the error back one block at a time.
error = bad_iv ^ iv
back_error_1 = _div_by_h(error)
back_error_2 = _div_by_h(back_error_1)
back_error_3 = _div_by_h(back_error_2)
back_error_4 = _div_by_h(back_error_3)

print(f"err_4: {hex(back_error_4)}")  # Could simply be XORed into nounce[16:32], but that gives unacceptable chars.
print(f"err_3: {hex(back_error_3)}")
print(f"err_2: {hex(back_error_2)}")  # Note how this is just the error introduced by 'false' -> 'true '.
print(f"err_1: {hex(back_error_1)}")
print()

# Basic POC ignoring acceptable chars: XOR back_error_4 into block 1 (nounce[16:32]).
# to_bytes(16) keeps the block 16 bytes long even if the XOR yields a leading zero
# byte; long_to_bytes would drop it and shift the rest of the nonce.
poc_block = (bytes_to_long(bad_nounce[16:32]) ^ back_error_4).to_bytes(16, "big")
poc_nounce = bad_nounce[:16] + poc_block + bad_nounce[32:]
poc_iv = simplified_ghash(h, poc_nounce, debug_print=True)
print(poc_nounce)
print(f"IV: {hex(poc_iv)} == {hex(iv)}")
assert poc_iv == iv

 0: 0x415af00b3f895ec2a58423f1367d0958
16: 0x60a704b985285e30ba94d7c6048e916
32: 0x1e5bf91166e6de354e113956a35a7105
48: 0xf618f65773c6ad8cb69180dab5da69ed
b'{"token" "8ce334a2107af7e5c0e532f849ad1a23", "admin" true }'
IV: 0x98ed09b3b153714aece0e31c1252bf2f

err_4: 0x3358ff80ee634e06d95d4f2e6811d06c
err_3: 0xc38828e8fb29e9cd9a786146f7d9d2ad
err_2: 0x1213191645000000000000
err_1: 0xacda170bad5428aa057a51c61baf9520

 0: 0x415af00b3f895ec2a58423f1367d0958
16: 0xc58258a3637b6c2e91d12c3a97913bbb
32: 0x1e5bf91166f4cd2c58543956a35a7105
48: 0x5ac2e15cde928526b3ebd11cae75fccd
b'{"token" "8ce334Rj\xce\xb0\xd9\x02(1\xbch,\x1e\r$\xe3^f849ad1a23", "admin" true }'
IV: 0x35cf410279885675e1c2b0ead53ae24f == 0x35cf410279885675e1c2b0ead53ae24f


In [136]:
nounce = b'{"token" "8ce334a2107af7e5c0e532f849ad1a23", "admin" false}'
org_hash = GoodHash(nounce).digest()
print(f"Original nounce: {nounce}")
print(f"Original hash: {org_hash.hex()}")

# For every candidate bit flip in nounce[32:42], record which 0x80 / 0x20 bits
# of block 1 (nounce[16:32]) it leaves set; the z3 solve combines these.
want_clear = []
want_set = []

# Block 1 after the naive POC correction: nounce[16:32] ^ back_error_4.
# The correction is XORed into nounce[16:32], so that is the block to track.
# (The earlier code used nounce[32:48], which only worked because the
# 0x80/0x20 bits of both blocks happen to agree for this token.)
# Every candidate carries one copy of this value, so it survives only an
# odd-sized XOR of candidates.
expected_error_4 = bytes_to_long(nounce[16:32]) ^ back_error_4

# Per-byte masks: 0x80 must end up clear, 0x20 must end up set.
mask_80 = bytes_to_long(bytes([0x80] * 16))
mask_20 = bytes_to_long(bytes([0x20] * 16))

# For the 10 characters that make up nounce[32:42] ...
for i in range(10):
    # ... allow changes in the low 5 bits of each char (this usually keeps it acceptable).
    for j in range(5):
        # Bit j of byte (i + 6) counted from the least-significant end of block 2,
        # i.e. a bit of nounce[41 - i]; bit_rev converts to the field's bit order.
        bit_mask = bit_rev(1 << ((i + 6) * 8 + j))
        # Compensating change this flip forces on block 1: one GHASH step back (divide by H).
        back_propaged_error = bit_rev(int(gf(bit_mask) / gf(bit_rev(h))))

        want_clear.append((back_propaged_error ^ expected_error_4) & mask_80)
        want_set.append((back_propaged_error ^ expected_error_4) & mask_20)



def xor_reduce(x, pos):
    """XOR together the 128-bit z3 terms in ``pos`` selected by ``x``.

    Args:
        x: Sequence of z3 ``Bool`` selectors, one per entry of ``pos``.
        pos: Sequence of 128-bit z3 bit-vector values.

    Returns:
        A 128-bit z3 expression equal to the XOR of every ``pos[i]`` whose
        ``x[i]`` is true (zero when none are selected).

    Raises:
        ValueError: If ``x`` and ``pos`` differ in length. Previously extra
            values were silently ignored, or an IndexError was raised.
    """
    if len(x) != len(pos):
        raise ValueError(f"got {len(x)} selectors for {len(pos)} values")
    zero = z3.BitVecVal(0, 128)
    acc = zero
    # zip avoids rebinding the parameter `x` as the loop variable.
    for selected, value in zip(x, pos):
        acc ^= z3.If(selected, value, zero)
    return acc

clear_data = [z3.BitVecVal(i, 128) for i in want_clear]
set_data = [z3.BitVecVal(i, 128) for i in want_set]
x = [z3.Bool('x%s' % i) for i in range(len(clear_data))]

s = z3.Solver()
# Each want_* entry carries one copy of expected_error_4. That copy survives the
# XOR only when an odd number of bits is chosen; an even count cancels it, and the
# constraints below would then describe the wrong target. Require odd parity,
# which also rules out the trivial all-zero solution.
odd_parity = z3.BoolVal(False)
for b in x:
    odd_parity = z3.Xor(odd_parity, b)
s.add(odd_parity)
# We want the 0x80 bit of every byte of nounce[16:32] cleared ...
s.add(xor_reduce(x, clear_data) == 0)
# ... and the 0x20 bit set, keeping each byte in 0x20-0x3f or 0x60-0x7f.
s.add(xor_reduce(x, set_data) == bytes_to_long(bytes([0x20] * 16)))
assert s.check() == z3.sat, "no combination of the allowed bit flips works"
model = s.model()

correction_3 = 0
for i, b in enumerate(x):
    # x[i] is bit (i % 5) of byte (i // 5 + 6), counted from the least-significant
    # end of block 2, matching the loops that built want_clear / want_set.
    if z3.is_true(model.eval(b, model_completion=True)):
        correction_3 |= 1 << (((i // 5) + 6) * 8 + (i % 5))
correction_4 = back_error_4 ^ bit_rev(int(gf(bit_rev(correction_3)) / gf(bit_rev(h))))
print(f"correction_4: {hex(correction_4)}")
print(f"correction_3: {hex(correction_3)}")


def _xor_block(block, mask):
    """XOR integer ``mask`` into ``block`` without changing its length
    (long_to_bytes would drop leading zero bytes)."""
    return (int.from_bytes(block, "big") ^ mask).to_bytes(len(block), "big")


fake_nounce = (nounce[:16]
               + _xor_block(nounce[16:32], correction_4)
               + _xor_block(nounce[32:48], correction_3)
               # Same 'false' -> 'true ' edit as bad_nounce (the err_2 difference).
               + nounce[48:].replace(b"false}", b"true }"))
fake_hash = GoodHash(fake_nounce).digest()
print(f"Fake nounce: {fake_nounce}")
print(f"Fake hash: {fake_hash.hex()}")

assert len(fake_nounce) == len(nounce)
assert org_hash == fake_hash
# Check the raw bytes so a non-ASCII byte fails this assert instead of raising in decode().
assert all(chr(c) in ACCEPTABLE for c in fake_nounce)

# Can still fail for other tokens: the constraints allow 0x7f (DEL) in
# nounce[16:32], and flipping low bits of 'a'-'f' in nounce[32:42] can also produce 0x7f.
print("PASS")

Original nounce: b'{"token" "8ce334a2107af7e5c0e532f849ad1a23", "admin" false}'
Original hash: 54bfb3cd3d2ec069c576a100c1b47970aea5ce244fa021feac97a14942360446eb26ccdef440ac01edc577593b5e160d
correction_4: 0xf115d030110450a114819194007485a
correction_3: 0x501091309100f030006000000000000
Fake nounce: b'{"token" "8ce334n#l36q#=t}z)%2{hc9=*ht>b25", "admin" true }'
Fake hash: 54bfb3cd3d2ec069c576a100c1b47970aea5ce244fa021feac97a14942360446eb26ccdef440ac01edc577593b5e160d
PASS
