In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
from keras.models import Model
from keras.layers import Input, Concatenate, Reshape, Dense, Conv1D, Flatten
from keras.optimizers import Adam
from tqdm import tqdm

In [None]:
# widths (in bits) of the plaintext, key and ciphertext vectors
p_len = 16
k_len = 16
c_len = 16

# optimiser step size, epoch budget and early-exit loss threshold
learning_rate = 0.0008
epochs = 20
loss_threshold = 0.1

# number of random examples per epoch and the resulting number of batch steps
batch_size = 256
samples = 1 << 16
batches = samples // batch_size

In [None]:
# Symbolic Keras inputs shared by the model definitions in this notebook.
# plaintext vector of length p_len (fed to Alice)
p_input = Input(shape=(p_len,), name="plaintext")
# public key of length k_len (fed to Alice and Eve)
pbk_input = Input(shape=(k_len,), name="public_key")
# private key of length k_len (fed to Bob and the key generator)
pvk_input = Input(shape=(k_len,), name="private_key")
# ciphertext vector of length c_len (fed to Bob and Eve)
c_input = Input(shape=(c_len,), name="ciphertext")

In [None]:
# Alice: maps (plaintext, public key) to a ciphertext vector in [-1, 1] (tanh output).
alice_inputs = [p_input, pbk_input]
x = Concatenate(axis=1, name="alice_concatenate")(alice_inputs)
# The stride-2 convolution below halves the sequence length, so the dense layer
# is made twice as wide as the ciphertext; Alice then emits exactly c_len values
# whatever c_len is configured to (previously hard-coded to 32).
x = Dense(units=2 * c_len, activation="relu", name="alice_dense")(x)
x = Reshape(target_shape=(2 * c_len, 1), name="alice_reshape")(x)
x = Conv1D(filters=2, kernel_size=4, strides=1, padding="same", activation="relu", name="alice_conv1d_1")(x)
x = Conv1D(filters=4, kernel_size=2, strides=2, padding="same", activation="relu", name="alice_conv1d_2")(x)
x = Conv1D(filters=4, kernel_size=1, strides=1, padding="same", activation="relu", name="alice_conv1d_3")(x)
x = Conv1D(filters=1, kernel_size=1, strides=1, padding="same", activation="tanh", name="alice_conv1d_4")(x)
alice_outputs = Flatten(name="alice_flatten")(x)
alice = Model(inputs=alice_inputs, outputs=alice_outputs, name="alice")
alice.compile()

In [None]:
# Bob: maps (ciphertext, private key) to a plaintext estimate in [-1, 1] (tanh output).
bob_inputs = [c_input, pvk_input]
x = Concatenate(axis=1, name="bob_concatenate")(bob_inputs)
# The stride-2 convolution below halves the sequence length, so the dense layer
# is made twice as wide as the plaintext; Bob then emits exactly p_len values
# whatever p_len is configured to (previously hard-coded to 32).
x = Dense(units=2 * p_len, activation="relu", name="bob_dense")(x)
x = Reshape(target_shape=(2 * p_len, 1), name="bob_reshape")(x)
x = Conv1D(filters=2, kernel_size=4, strides=1, padding="same", activation="relu", name="bob_conv1d_1")(x)
x = Conv1D(filters=4, kernel_size=2, strides=2, padding="same", activation="relu", name="bob_conv1d_2")(x)
x = Conv1D(filters=4, kernel_size=1, strides=1, padding="same", activation="relu", name="bob_conv1d_3")(x)
x = Conv1D(filters=1, kernel_size=1, strides=1, padding="same", activation="tanh", name="bob_conv1d_4")(x)
bob_outputs = Flatten(name="bob_flatten")(x)
bob = Model(inputs=bob_inputs, outputs=bob_outputs, name="bob")
bob.compile()

In [None]:
# Eve: maps (ciphertext, public key) to a plaintext estimate in [-1, 1] (tanh output).
eve_inputs = [c_input, pbk_input]
x = Concatenate(axis=1, name="eve_concatenate")(eve_inputs)
# The stride-2 convolution below halves the sequence length, so the dense layer
# is made twice as wide as the plaintext; Eve then emits exactly p_len values
# whatever p_len is configured to (previously hard-coded to 32).
x = Dense(units=2 * p_len, activation="relu", name="eve_dense")(x)
x = Reshape(target_shape=(2 * p_len, 1), name="eve_reshape")(x)
x = Conv1D(filters=2, kernel_size=4, strides=1, padding="same", activation="relu", name="eve_conv1d_1")(x)
x = Conv1D(filters=4, kernel_size=2, strides=2, padding="same", activation="relu", name="eve_conv1d_2")(x)
x = Conv1D(filters=4, kernel_size=1, strides=1, padding="same", activation="relu", name="eve_conv1d_3")(x)
x = Conv1D(filters=1, kernel_size=1, strides=1, padding="same", activation="tanh", name="eve_conv1d_4")(x)
eve_outputs = Flatten(name="eve_flatten")(x)
eve = Model(inputs=eve_inputs, outputs=eve_outputs, name="eve")
eve.compile()

In [None]:
# Key generator: maps a private key to a public key in [-1, 1] (tanh output).
key_gen_inputs = pvk_input
# The stride-2 convolution below halves the sequence length, so the dense layer
# is made twice as wide as the key; the generated public key then has exactly
# k_len values whatever k_len is configured to (previously hard-coded to 32).
x = Dense(units=2 * k_len, activation="relu", name="key_gen_dense")(key_gen_inputs)
x = Reshape(target_shape=(2 * k_len, 1), name="key_gen_reshape")(x)
x = Conv1D(filters=2, kernel_size=4, strides=1, padding="same", activation="relu", name="key_gen_conv1d_1")(x)
x = Conv1D(filters=4, kernel_size=2, strides=2, padding="same", activation="relu", name="key_gen_conv1d_2")(x)
x = Conv1D(filters=4, kernel_size=1, strides=1, padding="same", activation="relu", name="key_gen_conv1d_3")(x)
x = Conv1D(filters=1, kernel_size=1, strides=1, padding="same", activation="tanh", name="key_gen_conv1d_4")(x)
key_gen_outputs = Flatten(name="key_gen_flatten")(x)
key_gen = Model(inputs=key_gen_inputs, outputs=key_gen_outputs, name="key_gen")
key_gen.compile()

In [None]:
# l1 distance metric for loss
def l1_distance(a, b):
  """Return the batch-mean L1 distance between `a` and `b`.

  Both inputs are first mapped from the [-1, 1] range to [0, 1]; absolute
  differences are summed over the last axis and then averaged over the rest.
  """
  a_unit = (a + 1) / 2
  b_unit = (b + 1) / 2
  per_sample = tf.reduce_sum(tf.abs(a_unit - b_unit), axis=-1)
  return tf.reduce_mean(per_sample)


# create training batch
def create_batch():
  """Draw one random batch of plaintexts and keys with entries in {-1, 1}.

  Returns:
    (p_batch, k_batch) with shapes (batch_size, p_len) and (batch_size, k_len).
  """
  bit_values = [-1, 1]
  plaintexts = np.random.choice(bit_values, size=(batch_size, p_len))
  keys = np.random.choice(bit_values, size=(batch_size, k_len))
  return plaintexts, keys

In [None]:
# single forward pass for symbolic links
key_gen_output = key_gen(pvk_input)
alice_output = alice([p_input, key_gen_output])
bob_output = bob([alice_output, pvk_input])
eve_output = eve([alice_output, key_gen_output])
# loss and metric functions
tn_eve_loss = l1_distance(p_input, eve_output)
# Alice/Bob loss: Bob's reconstruction error plus a penalty that is zero when
# Eve's error equals p_len / 2
tn_alice_bob_loss = l1_distance(p_input, bob_output) + tf.square(p_len / 2 - tn_eve_loss) / ((p_len / 2) ** 2)
tn_alice_bob_metric = l1_distance(p_input, bob_output)
tn_eve_metric = l1_distance(p_input, eve_output)
# create auxiliary training networks
tn_alice_bob = Model(inputs=[pvk_input, p_input], outputs=bob_output, name="tn_alice_bob")
tn_alice_bob.add_loss(tn_alice_bob_loss)
tn_alice_bob.add_metric(tn_alice_bob_metric, name="l1_distance")
tn_eve = Model(inputs=[pvk_input, p_input], outputs=eve_output, name="tn_eve")
tn_eve.add_loss(tn_eve_loss)
tn_eve.add_metric(tn_eve_metric, name="l1_distance")
# Keras records each layer's `trainable` flag when a model is compiled, so the
# flags are set explicitly before every compile. This keeps the cell correct
# when it is re-run in the same kernel and makes each optimiser touch only its
# own side of the game.
# Alice/Bob step: update key_gen, alice and bob. Eve is held fixed because her
# layers feed the added loss and may therefore be tracked by this model
# (NOTE(review): depends on the Keras version - confirm).
key_gen.trainable = True
alice.trainable = True
bob.trainable = True
eve.trainable = False
tn_alice_bob.compile(Adam(learning_rate=learning_rate))
# Eve step: update eve only; the key generator and Alice are held fixed
key_gen.trainable = False
alice.trainable = False
eve.trainable = True
tn_eve.compile(Adam(learning_rate=learning_rate))

In [None]:
# per-batch error histories collected during training (plotted afterwards)
bob_train_errors = []
eve_train_errors = []

epoch = 0
above_threshold = True
start_time = time.time()
with tqdm(total=epochs*batches, desc="Training", unit="batch") as pbar:
  while epoch < epochs and above_threshold:
    for batch in range(batches):
      # one Alice/Bob update, then measure loss and error on the same batch
      p_batch, pvk_batch = create_batch()
      tn_alice_bob.train_on_batch([pvk_batch, p_batch])
      alice_bob_loss, bob_error = tn_alice_bob.evaluate([pvk_batch, p_batch], verbose=0)
      bob_train_errors.append(bob_error)

      # train Eve for 2 batches
      for _ in range(2):
        p_batch, pvk_batch = create_batch()
        tn_eve.train_on_batch([pvk_batch, p_batch])
      # Eve's error is measured on the last batch she was trained on
      _, eve_error = tn_eve.evaluate([pvk_batch, p_batch], verbose=0)
      eve_train_errors.append(eve_error)

      # update progress bar
      pbar.set_postfix({"alice_bob_loss": alice_bob_loss, "bob_error": bob_error, "eve_error": eve_error})
      pbar.update()

      # exit if Alice and Bob loss is below threshold
      if alice_bob_loss < loss_threshold:
        print("Minimum loss threshold reached, exiting early")
        above_threshold = False
        break

    epoch += 1

# Format elapsed time as MM:SS. Minutes are computed directly so they keep
# counting past 59 instead of wrapping to 00 after one hour (as "%M" would).
elapsed_seconds = int(time.time() - start_time)
total_time = f"{elapsed_seconds // 60:02d}:{elapsed_seconds % 60:02d}"
print(f"Training finished ({total_time})")

In [None]:
# plot training errors
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(bob_train_errors, label="Bob")
ax.plot(eve_train_errors, label="Eve")
ax.set_title("Asymmetric model training errors")
ax.set_xlabel("Batches")
ax.set_ylabel(f"Bits wrong (of {p_len})")
# tick every half bit from 0 up to p_len / 2
ax.set_yticks(np.arange(0, (p_len / 2) + 0.5, 0.5))
ax.legend()
plt.show()

In [None]:
# # save models (uncomment to persist the trained networks)
# alice.save("models/asymmetric/alice.keras")
# bob.save("models/asymmetric/bob.keras")
# eve.save("models/asymmetric/eve.keras")
# key_gen.save("models/asymmetric/key_gen.keras")
# # load models (uncomment to restore previously saved networks)
# alice = keras.models.load_model("models/asymmetric/alice.keras")
# bob = keras.models.load_model("models/asymmetric/bob.keras")
# eve = keras.models.load_model("models/asymmetric/eve.keras")
# key_gen = keras.models.load_model("models/asymmetric/key_gen.keras")

In [None]:
# per-batch error histories collected during evaluation (plotted afterwards)
bob_eval_errors = []
eve_eval_errors = []

start_time = time.time()
with tqdm(total=batches, desc="Evaluation", unit="batch") as pbar:
  for batch in range(batches):
    # Bob and Eve are each scored on their own freshly drawn batch
    p_batch, pvk_batch = create_batch()
    alice_bob_loss, bob_error = tn_alice_bob.evaluate([pvk_batch, p_batch], verbose=0)
    bob_eval_errors.append(bob_error)

    p_batch, pvk_batch = create_batch()
    _, eve_error = tn_eve.evaluate([pvk_batch, p_batch], verbose=0)
    eve_eval_errors.append(eve_error)

    # update progress bar
    pbar.set_postfix({"alice_bob_loss": alice_bob_loss, "bob_error": bob_error, "eve_error": eve_error})
    pbar.update()

# Format elapsed time as MM:SS. Minutes are computed directly so they keep
# counting past 59 instead of wrapping to 00 after one hour (as "%M" would).
elapsed_seconds = int(time.time() - start_time)
total_time = f"{elapsed_seconds // 60:02d}:{elapsed_seconds % 60:02d}"
print(f"Evaluation finished ({total_time})")

In [None]:
# plot evaluation errors
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(bob_eval_errors, label="Bob")
ax.plot(eve_eval_errors, label="Eve")
ax.set_title("Asymmetric model evaluation errors")
ax.set_xlabel("Batches")
ax.set_ylabel(f"Bits wrong (of {p_len})")
# tick every half bit from 0 up to p_len / 2
ax.set_yticks(np.arange(0, (p_len / 2) + 0.5, 0.5))
ax.legend()
plt.show()

In [None]:
# convert text of single-byte characters to tensor
def text_to_tensor(text, p_len):
  """Encode `text` as a flat vector of -1/+1 bits, zero-padded to a multiple of `p_len`.

  Every character contributes exactly 8 bits (most significant bit first), with
  bit 0 mapped to -1 and bit 1 mapped to +1. Padding positions are filled with 0.

  Args:
    text: string whose characters all have code points in the range 0-255.
    p_len: block length; the returned vector's length is a multiple of it.

  Returns:
    (tensor, pad): the float vector and the number of padding entries appended.

  Raises:
    ValueError: if a character does not fit in 8 bits. Such characters would
      otherwise yield more than 8 bits and corrupt the fixed-width bit stream.
  """
  try:
    # latin-1 maps code points 0-255 to the identical single byte value
    raw = text.encode("latin-1")
  except UnicodeEncodeError as err:
    raise ValueError(
      f"text_to_tensor only supports characters with code points 0-255: {err}"
    ) from err

  binary = np.unpackbits(np.frombuffer(raw, dtype=np.uint8))
  # pad binary list to multiple of p_len
  pad = (p_len - len(binary) % p_len) % p_len
  # cast before the arithmetic: uint8 would wrap around when computing 0 * 2 - 1
  tensor = np.concatenate([binary.astype(np.float64) * 2 - 1, np.zeros(pad)])
  return tensor, pad


# convert tensor to text of 8-bit characters
def tensor_to_text(tensor, pad):
  """Decode a tensor of values in [-1, 1] back into text.

  Values are mapped to [0, 1] and rounded to bits; the last `pad` bits are
  ignored and every remaining group of 8 bits becomes one character.
  """
  bits = np.round((tensor + 1) / 2.0).astype("int").flatten()
  usable = len(bits) - pad

  chars = []
  for start in range(0, usable, 8):
    octet = bits[start: start + 8]
    # join the bits into a binary literal and interpret it as a code point
    chars.append(chr(int("".join(str(bit) for bit in octet), 2)))
  return "".join(chars)


# perform asymmetric encryption/decryption on text using trained models
def asymmetric_encryption(plaintext):
  """Run `plaintext` through key_gen, alice, bob and eve.

  The text is split into blocks of p_len bits, and a fresh random private key
  is drawn for every block.

  Returns:
    (ciphertext, plaintext_bob, plaintext_eve): Alice's output and the two
    reconstructions, each decoded back to text.
  """
  encoded, pad = text_to_tensor(plaintext, p_len)
  p_blocks = np.array(encoded).reshape(-1, p_len)
  pvk_blocks = np.random.choice([-1, 1], size=(len(p_blocks), k_len))

  pbk_blocks = key_gen(pvk_blocks)
  c_blocks = alice([p_blocks, pbk_blocks])
  bob_blocks = bob([c_blocks, pvk_blocks])
  eve_blocks = eve([c_blocks, pbk_blocks])

  # decode in the order: ciphertext, Bob's reconstruction, Eve's reconstruction
  return tuple(tensor_to_text(blocks, pad) for blocks in (c_blocks, bob_blocks, eve_blocks))

In [None]:
# encrypt a sample message and show what Bob and Eve each recover
plaintext = "Hello, World!"
ciphertext, plaintext_bob, plaintext_eve = asymmetric_encryption(plaintext)
print(
  f"Plaintext: {plaintext}",
  f"Ciphertext: {ciphertext}",
  f"Plaintext (Bob): {plaintext_bob}",
  f"Plaintext (Eve): {plaintext_eve}",
  sep="\n",
)