In [1]:
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers import modes, Cipher
from cryptography.hazmat.backends import default_backend
from os import urandom
import sys

In [2]:
BLOCKSIZE = 16  # AES block size in bytes; the cipher code below works in 16-byte blocks


def xor(x, y):
    """Return the bytewise XOR of two equal-length byte strings.

    Args:
        x: first operand (bytes-like).
        y: second operand (bytes-like), same length as ``x``.

    Returns:
        ``bytes`` of length ``len(x)`` where each byte is ``x[i] ^ y[i]``.

    Raises:
        ValueError: if the operands differ in length. Mismatched lengths would
            otherwise either overflow or silently zero-extend the shorter one.
    """
    if len(x) != len(y):
        raise ValueError(f"xor operands differ in length: {len(x)} != {len(y)}")
    return bytes(a ^ b for a, b in zip(x, y))

In [3]:
def unpad(m, blocksize=16):
    """Strip PKCS#7 padding from ``m`` and decode the rest as UTF-8.

    Args:
        m: padded plaintext (``bytes`` or ``bytearray``).
        blocksize: cipher block size in bytes; a valid pad length lies in
            ``1..blocksize`` and may not exceed ``len(m)``.

    Returns:
        The unpadded plaintext as a UTF-8 decoded ``str``.

    Raises:
        AssertionError: "Wrong padding" when ``m`` is empty, the pad length is
            out of range, or the pad bytes are inconsistent. It is raised
            explicitly, not via ``assert``, so the check survives ``python -O``.
        UnicodeDecodeError: if the unpadded bytes are not valid UTF-8.
    """
    # Every failure uses the same message so callers cannot tell which
    # check failed.
    if not m:
        raise AssertionError("Wrong padding")
    pad_len = m[-1]
    if not 1 <= pad_len <= min(blocksize, len(m)):
        raise AssertionError("Wrong padding")
    if any(byte != pad_len for byte in m[-pad_len:]):
        raise AssertionError("Wrong padding")
    return m[:-pad_len].decode("utf-8")

In [4]:
def pad(pt, blocksize=16):
    """Apply PKCS#7 padding so ``len(pt)`` becomes a multiple of ``blocksize``.

    PKCS#7 always appends between 1 and ``blocksize`` bytes, each equal to the
    number of bytes appended. A full block of padding is added when the input
    is already aligned, so ``unpad`` can always strip it unambiguously.

    Args:
        pt: plaintext. A ``bytearray`` is extended in place; for ``bytes``,
            use the return value, since ``bytes`` is immutable.
        blocksize: block size in bytes (1..255, since the pad length must fit
            in one byte). Defaults to 16, the AES block size.

    Returns:
        The padded value. For a ``bytearray`` this is the same object, mutated.

    Raises:
        ValueError: if ``blocksize`` is outside 1..255.
    """
    if not 1 <= blocksize <= 255:
        raise ValueError("blocksize must be between 1 and 255 for PKCS#7")
    # Distance to the next block boundary: a full block when already aligned.
    pad_len = blocksize - len(pt) % blocksize
    # `+=` on a bytearray extends it in place, so callers that ignore the
    # return value still see the padded buffer.
    pt += bytes([pad_len]) * pad_len
    return pt

In [5]:
def _aes_ecb(key):
    """Build a raw AES-ECB ``Cipher`` for ``key``; CBC chaining is done by hand elsewhere."""
    return Cipher(AES(key), modes.ECB(), backend=default_backend())


def AES_DECRYPT(key):
    """Return the ``update`` method of a fresh AES-ECB decryptor for ``key``.

    The returned callable maps ciphertext bytes to plaintext bytes. In this
    notebook it is fed exactly one 16-byte block at a time.
    """
    decryptor = _aes_ecb(key).decryptor()
    return decryptor.update


def AES_ENCRYPT(key):
    """Return the ``update`` method of a fresh AES-ECB encryptor for ``key``.

    The returned callable maps plaintext bytes to ciphertext bytes. In this
    notebook it is fed exactly one 16-byte block at a time.
    """
    encryptor = _aes_ecb(key).encryptor()
    return encryptor.update

    
def encrypt_cbc(pt_string, k_string, blocksize=16):
    """Encrypt a text string with AES in CBC mode under a fresh random IV.

    Args:
        pt_string: plaintext ``str``; UTF-8 encoded and then PKCS#7 padded.
        k_string: AES key as a hex string.
        blocksize: block size in bytes. AES works on 16-byte blocks, so only
            ``BLOCKSIZE`` (16) is accepted; the parameter is kept for
            signature compatibility.

    Returns:
        Hex string of ``IV || C_1 || ... || C_n``. The IV is the first block,
        as ``decrypt_cbc`` expects.

    Raises:
        ValueError: if ``blocksize`` is not the AES block size, or if padding
            did not produce a block-aligned plaintext. Without this check a
            trailing partial block would be silently dropped.
    """
    if blocksize != BLOCKSIZE:
        raise ValueError("AES-CBC requires a 16-byte block size")

    pt, k = bytearray(pt_string, 'utf-8'), bytes.fromhex(k_string)
    pad(pt)  # extends the bytearray in place
    if len(pt) % blocksize:
        raise ValueError("padded plaintext is not a multiple of the block size")

    aes_encrypt = AES_ENCRYPT(k)
    prev = urandom(blocksize)  # random IV, emitted as the first ciphertext block
    ct = bytearray(prev)
    # C_i = E_k(P_i XOR C_{i-1}), with C_0 = IV.
    for start in range(0, len(pt), blocksize):
        prev = aes_encrypt(xor(pt[start:start + blocksize], prev))
        ct += prev

    return ct.hex()

In [6]:
def decrypt_cbc(ct_string, k_string):
    """Decrypt a hex AES-CBC ciphertext whose first block is the IV.

    Args:
        ct_string: hex string of ``IV || C_1 || ... || C_n``.
        k_string: AES key as a hex string.

    Returns:
        The plaintext ``str`` after PKCS#7 unpadding and UTF-8 decoding
        (via ``unpad``).

    Raises:
        ValueError: if the ciphertext is not block-aligned or does not hold
            the IV plus at least one block.
        AssertionError: "Wrong padding" from ``unpad`` on invalid padding.
            NOTE(review): exposing this distinction to untrusted callers
            creates a padding oracle.
    """
    ct, k = bytes.fromhex(ct_string), bytes.fromhex(k_string)
    if len(ct) % BLOCKSIZE or len(ct) < 2 * BLOCKSIZE:
        raise ValueError(
            "ciphertext must be block-aligned and contain an IV plus at least one block"
        )

    aes_decrypt = AES_DECRYPT(k)
    m = bytearray()
    # P_i = D_k(C_i) XOR C_{i-1}, with C_0 = IV (the first block).
    for start in range(BLOCKSIZE, len(ct), BLOCKSIZE):
        prev_block = ct[start - BLOCKSIZE:start]
        block = ct[start:start + BLOCKSIZE]
        m += xor(prev_block, aes_decrypt(block))

    return unpad(m)

In [7]:
# Fixed AES key (hex) used to decrypt the reference ciphertexts below.
# NOTE(review): hardcoded key — acceptable only if these are public test vectors.
k = "140b41b22a29beb4061bda66b6747e14"
# Reference CBC ciphertexts (hex). The first 16 bytes act as the IV, since
# decrypt_cbc chains block 1 against ct[0:16].
ct1 = "4ca00ff4c898d61e1edbf1800618fb2828a226d160dad07883d04e008a7897ee2e4b7465d5290d0c0e6c6822236e1daafb94ffe0c5da05d9476be028ad7c1d81"
ct2 = "5b68629feb8606f9a6667670b75b38a5b4832d0f26e1ab7da33249de7d4afc48e713ac646ace36e872ad5fb8a512428a6e21364b0c374df45503473c5242a253"
# Expected plaintexts of ct1 / ct2 under k.
pt1 = "Basic CBC mode encryption needs padding."
pt2 = "Our implementation uses rand. IV"

In [8]:
# Decrypt both reference ciphertexts with the shared key and show the results.
m1, m2 = decrypt_cbc(ct1, k), decrypt_cbc(ct2, k)

print(m1, m2, sep="\n")

Basic CBC mode encryption needs padding.
Our implementation uses rand. IV


In [9]:
# Re-encrypt the known plaintexts (c1, c2) and the decrypted ones (cx, cy);
# every encrypt_cbc call draws its own random IV.
c1, c2 = encrypt_cbc(pt1, k), encrypt_cbc(pt2, k)
cx, cy = encrypt_cbc(m1, k), encrypt_cbc(m2, k)

# Spot-check that two of the new ciphertexts decrypt back to readable text.
print(decrypt_cbc(c1, k), decrypt_cbc(cy, k), sep="\n")

Basic CBC mode encryption needs padding.
Our implementation uses rand. IV


In [10]:
# Decrypting the reference ciphertexts yields the known plaintexts.
assert m1 == pt1
assert m2 == pt2
# Fresh ciphertexts round-trip to the known plaintexts. Only checking that
# two decryptions agree with each other would miss a bug that corrupts both
# the same way.
assert decrypt_cbc(c1, k) == decrypt_cbc(cx, k) == pt1
assert decrypt_cbc(c2, k) == decrypt_cbc(cy, k) == pt2
# A random IV per call means the same plaintext encrypts to different
# ciphertexts. A collision has probability about 2**-128.
assert c1 != cx and c2 != cy