## S-DES (Simplified Data Encryption Standard)

In [300]:
from typing import Literal

### Constants

In [301]:
# All tables use 1-indexed bit positions (consumed by `permutate`).

# Key schedule: 10-bit permutation, then 10 -> 8 bit compression.
P10 = [3, 5, 2, 7, 4, 10, 1, 9, 8, 6]
P8 = [6, 3, 7, 4, 8, 5, 10, 9]

# Initial permutation; the final permutation FP is its inverse, so that
# applying IP and then FP returns the original bit order.
IP = [2, 6, 3, 1, 4, 8, 5, 7]
FP = [IP.index(position) + 1 for position in range(1, len(IP) + 1)]

# Round function tables: 4 -> 8 bit expansion/permutation, the two 4x4
# substitution boxes, and the final 4-bit permutation.
EP = [4, 1, 2, 3, 2, 3, 4, 1]
S0 = [
	[1, 0, 3, 2],
	[3, 2, 1, 0],
	[0, 2, 1, 3],
	[3, 1, 3, 2],
]
S1 = [
	[0, 1, 2, 3],
	[2, 0, 1, 3],
	[3, 0, 1, 0],
	[2, 1, 0, 3],
]
P4 = [2, 4, 3, 1]


### Helper functions

In [302]:
def bin_to_dec(x: str):
	"""Interpret *x* as a base-2 numeral and return its integer value.

	Raises ValueError (from ``int``) when *x* is not a valid binary literal.
	"""
	value = int(x, base=2)
	return value

def dec_to_bin(x: int, width: int = 0) -> str:
	"""Return the binary digits of non-negative *x*, without the '0b' prefix.

	Args:
		x: non-negative integer to convert.
		width: minimum number of digits; shorter results are left-padded
			with '0'. Longer results are never truncated. The default of 0
			means no padding.

	Raises:
		ValueError: if *x* is negative. ``bin(-5)[2:]`` would otherwise
			yield the malformed string 'b101'.
	"""
	if x < 0:
		raise ValueError("dec_to_bin expects a non-negative integer, got {}".format(x))
	return bin(x)[2:].zfill(width)

In [304]:
def left_circular_shift(x: str, shifts:int = 1):
	"""Rotate *x* left by *shifts* positions.

	Shifts larger than ``len(x)`` wrap around, and a negative shift rotates
	right. An empty string is returned unchanged; without the guard, the
	modulo by ``len(x)`` raises ZeroDivisionError.
	"""
	if not x:
		return x
	shifts %= len(x)
	return x[shifts:] + x[:shifts]

In [305]:
def permutate(key: str, perm: list):
	"""Return the characters of *key* picked by the 1-indexed positions in *perm*.

	The output has ``len(perm)`` characters. A position may repeat, as in
	the expansion table, and characters of *key* that are never named are
	dropped. An over-long *key* is therefore not detected here, so callers
	must validate input lengths themselves.

	Raises:
		IndexError: if a position lies outside ``1..len(key)``. A position
			of 0 or below would otherwise silently wrap to the end of *key*
			through negative indexing.
	"""
	size = len(key)
	picked = []
	for pos in perm:
		if not 1 <= pos <= size:
			raise IndexError(
				"permutation position {} out of range for input of length {}".format(pos, size)
			)
		picked.append(key[pos - 1])
	return ''.join(picked)

In [306]:
def split_str(key: str):
	"""Split *key* into (left, right) halves.

	For odd lengths the right half receives the extra character.
	"""
	mid = len(key) // 2
	left = key[:mid]
	right = key[mid:]
	return left, right

In [307]:
def xor(x: str, y: str):
	"""Return the bitwise XOR of two equal-length binary strings.

	Each output character is '0' where the inputs agree and '1' where they
	differ.

	Raises:
		ValueError: if the operands differ in length. Previously a longer
			*y* was silently truncated and a shorter one raised an opaque
			IndexError.
	"""
	if len(x) != len(y):
		raise ValueError("xor operands differ in length: {} vs {}".format(len(x), len(y)))
	return ''.join('0' if a == b else '1' for a, b in zip(x, y))

### Algorithm

In [308]:
def gen_subkeys(key: str) -> tuple[str, str]:
	"""Derive the two round subkeys (K1, K2) from the 10-bit key.

	Steps:
	1. Apply P10 and split the result into two 5-bit halves.
	2. Rotate each half left by 1 and apply P8 to get K1.
	3. Rotate the already-rotated halves left by a further 2 and apply P8
	   to get K2.
	"""
	left, right = split_str(permutate(key, P10))

	left = left_circular_shift(left, 1)
	right = left_circular_shift(right, 1)
	k1 = permutate(left + right, P8)

	left = left_circular_shift(left, 2)
	right = left_circular_shift(right, 2)
	k2 = permutate(left + right, P8)

	return k1, k2

In [309]:
def s_box(text: str, s: list[list[int]]):
	"""Look up a 4-bit *text* in substitution box *s* and return the entry in binary.

	The outer bits (1st and 4th) select the row and the inner bits (2nd and
	3rd) select the column. The entry is left-padded with '0' to at least
	2 characters.
	"""
	row_bits = text[0] + text[3]
	col_bits = text[1] + text[2]
	row = bin_to_dec(row_bits)
	col = bin_to_dec(col_bits)
	return dec_to_bin(s[row][col]).rjust(2, '0')

In [310]:
def fk(l_in: str, r_in: str, subkey: str):
	"""Apply one S-DES round function f_K.

	Computes F(r_in, subkey) through expansion (EP), key mixing (XOR),
	substitution (S0/S1) and permutation (P4). The result is XORed into
	*l_in*. The right half is returned unchanged.

	Returns:
		(l_in XOR F(r_in, subkey), r_in)
	"""
	expanded = permutate(r_in, EP)  # 4 bits -> 8 bits (EP has 8 entries)
	keyed = xor(expanded, subkey)
	keyed_left, keyed_right = split_str(keyed)
	substituted = s_box(keyed_left, S0) + s_box(keyed_right, S1)
	mixer = permutate(substituted, P4)
	return xor(mixer, l_in), r_in

In [311]:
def x_cryption(input_text: str, key: str, process: Literal['encryption', 'decryption'], *, verbose: bool = True):
	"""Run S-DES over a single block in either direction.

	Encryption and decryption share one structure:
	IP -> fk(K1) -> swap halves -> fk(K2) -> FP.
	Decryption differs only in applying the subkeys in reverse order.

	Args:
		input_text: binary string block (IP reads 8 positions).
		key: binary key string (P10 reads 10 positions).
		process: 'encryption' or 'decryption'.
		verbose: when True (the default), print every intermediate value.
			Pass False to run silently.

	Returns:
		The output block as a binary string.

	Raises:
		ValueError: if *process* is not 'encryption' or 'decryption'. Any
			other value used to fall through silently to encryption.
	"""
	if process not in ('encryption', 'decryption'):
		raise ValueError("process must be 'encryption' or 'decryption', got {!r}".format(process))

	def log(*args):
		# Print only when tracing is enabled.
		if verbose:
			print(*args)

	log('\n', process)

	k1, k2 = gen_subkeys(key)
	log("k1:", k1, "\nk2:", k2)

	# Decryption is the same network with the subkey order reversed.
	if process == 'decryption':
		k1, k2 = k2, k1

	output_text = permutate(input_text, IP)
	log("IP:", output_text)

	left, right = split_str(output_text)

	left, right = fk(left, right, k1)
	log("fk1:", left, right)

	# Swap halves between the two rounds.
	left, right = right, left

	left, right = fk(left, right, k2)
	log("fk2:", left, right)

	output_text = permutate(left + right, FP)
	log("FP:", output_text)

	return output_text

In [312]:
def encryption(plaintext: str, key: str):
	"""Encrypt the binary *plaintext* block under the binary *key*."""
	ciphertext = x_cryption(plaintext, key, process='encryption')
	return ciphertext

def decryption(ciphertext: str, key: str):
	"""Decrypt the binary *ciphertext* block under the binary *key*."""
	plaintext = x_cryption(ciphertext, key, process='decryption')
	return plaintext

### I/O

In [None]:
# Optional interactive input helper.
def ip_validate(ip: str, l: int) -> str:
	"""Prompt for *ip* and return it as an *l*-character binary string.

	Key input (``ip == 'key'``):
	- Accepted as binary (e.g. ``1100011110``).
	- If it is not valid binary, it is accepted as decimal (e.g. ``798``)
	  and converted to binary.

	Any other input, such as the plaintext, must contain only '0' and '1'.
	Values shorter than *l* are left-padded with '0' and a notice is printed.

	Args:
		ip: name of the requested value; shown in the prompt and messages.
		l: required length in bits.

	Raises:
		ValueError: if the key is non-numeric or negative, if a non-key
			input contains characters other than '0'/'1', or if any input
			is longer than *l* bits.
	"""
	temp = input("Enter {}: ".format(ip)).strip()

	if ip == 'key':
		try:
			value = bin_to_dec(temp)  # fails with ValueError if non-binary
		except ValueError:
			print("Non-binary input detected, attempting conversion to binary")
			value = int(temp)  # decimal fallback; propagates ValueError if non-numeric
		if value < 0:
			# A negative key has no binary bit-string form here.
			raise ValueError("Input '{}' must be non-negative".format(ip))
		temp = dec_to_bin(value)
	elif any(ch not in '01' for ch in temp):
		raise ValueError("Input '{}' must contain only 0s and 1s".format(ip))

	# Check every input's length, not just the key's. The permutations ignore
	# positions beyond their tables, so an over-long plaintext would
	# otherwise be silently truncated.
	if len(temp) > l:
		raise ValueError("Input '{}' is bigger than expected".format(ip))

	if len(temp) < l:
		print("Input '{}' is smaller than expected; Padding with 0s".format(ip))

	return temp.rjust(l, '0')

In [313]:
# When True, use the hard-coded example key and plaintext below; set it to
# False to be prompted for both via ip_validate.
PREDEFINED = True

if PREDEFINED:
	key = '1100011110'
	plaintext = '00101000'
else:
	key = ip_validate("key", 10)
	plaintext = ip_validate("plaintext", 8)

print("key", key)
print("plaintext", plaintext)

ciphertext = encryption(plaintext, key)
print("Ciphertext:",ciphertext)

p = decryption(ciphertext, key)
print("Plaintext:", p)

# Round-trip sanity check. An explicit raise, unlike `assert`, stays active
# when Python runs with -O.
if p != plaintext:
	raise AssertionError("Round trip failed: decrypted {!r}, expected {!r}".format(p, plaintext))

key 1100011110
plaintext 00101000

 encryption
k1: 11101001 
k2: 10100111
IP: 00100010
fk1: 0011 0010
fk2: 0001 0011
FP: 10001010
Ciphertext: 10001010

 decryption
k1: 11101001 
k2: 10100111
IP: 00010011
fk1: 0010 0011
fk2: 0010 0010
FP: 00101000
Plaintext: 00101000
