In [1]:
import random
import sys
import collections
from Crypto.Util.number import bytes_to_long, long_to_bytes

In [2]:
def xor(a_, b_):
    ''' byte-wise XOR of two byte sequences; the result is as long as the shorter input '''
    combined = bytearray()
    for left, right in zip(a_, b_):
        combined.append(left ^ right)
    return bytes(combined)

Reference for the Mersenne Twister (MT19937) tempering/untempering helpers in the next cell: [tejainece's MT19937 gist](https://gist.github.com/tejainece/4dd9fb65645745edbcef1cac18f7e695)

In [3]:
class Params:
    # Standard MT19937 (32-bit Mersenne Twister) constants. The one-letter
    # names are the conventional notation from the Mersenne Twister
    # literature; they are kept so the code can be checked against it.

    # Generator shape: word size, degree of recurrence (state length),
    # middle-term offset, and the bit that splits a word into upper/lower parts.
    w, n, m, r = 32, 624, 397, 31
    a = 0x9908b0df      # bottom row of the twist matrix A

    # Tempering: four shift amounts and two AND-masks.
    u, s, t, l = 11, 7, 15, 18
    b, c = 0x9d2c5680, 0xefc60000

def undo_xor_rshift(x, shift, width=None):
    ''' reverses the operation x ^= (x >> shift)

    The forward step leaves the top `shift` bits unchanged, so the original
    value is x ^ (x >> shift) ^ (x >> 2*shift) ^ ... over every multiple of
    `shift` below the word width.

    x:      the transformed value, i.e. y ^ (y >> shift) for some width-bit y
    shift:  positive shift amount used by the forward step
    width:  word size in bits; defaults to Params.w (pass 64 for MT19937-64)
    returns the recovered y
    raises ValueError if shift is not positive
    '''
    if shift <= 0:
        # x ^= x >> 0 zeroes x and a negative shift is not a valid forward
        # step, so neither case has an inverse.
        raise ValueError("shift must be positive")
    if width is None:
        width = Params.w
    result = x
    for shift_amount in range(shift, width, shift):
        result ^= (x >> shift_amount)
    return result

def undo_xor_lshiftmask(x, shift, mask, width=None):
    ''' reverses the operation x ^= ((x << shift) & mask)

    The forward step leaves the low `shift` bits unchanged. Each pass uses an
    already-correct window of `shift` bits to repair the next window up, so
    width // shift passes cover the whole word.

    x:      the transformed value
    shift:  positive shift amount used by the forward step
    mask:   AND-mask used by the forward step
    width:  word size in bits; defaults to Params.w (pass 64 for MT19937-64)
    returns the recovered original value
    raises ValueError if shift is not positive
    '''
    if shift <= 0:
        raise ValueError("shift must be positive")
    if width is None:
        width = Params.w
    window = (1 << shift) - 1
    for _ in range(width // shift):
        # bits inside `window` are already correct; push them up to fix the next window
        x ^= (((window & x) << shift) & mask)
        window <<= shift
    return x

def temper(x):
    ''' applies MT19937 output tempering to a raw state word: a right-shift XOR,
    two masked left-shift XORs, then a final right-shift XOR (constants from Params) '''
    y = x ^ (x >> Params.u)
    y = y ^ ((y << Params.s) & Params.b)
    y = y ^ ((y << Params.t) & Params.c)
    return y ^ (y >> Params.l)

def untemper(x):
    ''' inverse of temper(): recovers the raw state word from a tempered output '''
    # temper() applies >>u, <<s&b, <<t&c, >>l in that order, so the innermost
    # call below undoes the last step (>>l) and the outermost undoes the first (>>u).
    return undo_xor_rshift(
        undo_xor_lshiftmask(
            undo_xor_lshiftmask(
                undo_xor_rshift(x, Params.l),
                Params.t, Params.c,
            ),
            Params.s, Params.b,
        ),
        Params.u,
    )

def upper(x):
    ''' keep only the top (w - r) bits of x, i.e. bits r .. w-1; everything else is cleared '''
    high_mask = ((1 << (Params.w - Params.r)) - 1) << Params.r
    return x & high_mask

def lower(x):
    ''' keep only the bottom r bits of x, i.e. bits 0 .. r-1 '''
    low_mask = ~(-1 << Params.r)   # == (1 << r) - 1
    return x & low_mask

def timesA(x):
    ''' multiplies x by the twist matrix A: shift right by one, and XOR in
    Params.a when the bit shifted out (the low bit of x) was set '''
    twist = Params.a if x & 1 else 0
    return (x >> 1) ^ twist

In [4]:
# Recover the generator state. Each of the first Params.n (624) keystream words
# is ciphertext XOR known plaintext. Untempering a word gives the raw MT19937
# state word, and a full window of Params.n state words drives the recurrence below.
keys = []
seen = collections.deque(maxlen=Params.n)   # sliding window of the last n state words
with open("flag.jpg.enc", 'rb') as c, open("flag.jpg.partial", 'rb') as p:
    for word_idx in range(Params.n):
        ct = c.read(4)
        pt = p.read(4)
        # xor() silently truncates to the shorter input, so a short read here
        # would produce a wrong key word and corrupt the whole recovered state.
        if len(ct) < 4 or len(pt) < 4:
            raise ValueError(
                f"need {Params.n} full 4-byte words of ciphertext and known "
                f"plaintext to clone the generator; ran out at word {word_idx}"
            )
        key = bytes_to_long(xor(ct, pt))
        keys.append(key)
        seen.append(untemper(key))
    # Predict one key word per remaining ciphertext chunk, including a final
    # chunk shorter than 4 bytes, using the MT19937 recurrence
    #   x[k+n] = x[k+m] ^ timesA(upper(x[k]) | lower(x[k+1]))
    # With the deque holding x[k] .. x[k+n-1], seen[-n] is x[k].
    while c.read(4):
        next_val = seen[-Params.n + Params.m] ^ timesA(upper(seen[-Params.n]) | lower(seen[-Params.n + 1]))
        seen.append(next_val)
        keys.append(temper(next_val))

In [5]:
# Decrypt the whole file with the recovered keystream, one 4-byte key word per chunk.
idx = 0
with open("flag.jpg.enc", 'rb') as c, open("flag.jpg", "wb") as d:
    while True:
        ct = c.read(4)
        if not ct:
            # Stop only at EOF. The old `len(ct) < 4` test dropped a trailing
            # 1-3 byte chunk, which can cut off the end of the JPEG.
            break
        # For a short final chunk, xor() uses only the leading bytes of the key
        # word. NOTE(review): this assumes the encryptor truncated its key the
        # same way; confirm against the encryption script.
        pt = xor(ct, long_to_bytes(keys[idx], 4))
        d.write(pt)
        idx += 1