In [1]:
import numpy as np

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if th.cuda.is_available():
  device = th.device("cuda")
else:
  device = th.device("cpu")
print(f"Using {device} device")

Using cuda device


### Char to byte and vice versa

In [2]:
def char2byte(c):
  """Encode an ASCII character as its bits, most significant bit first.

  The character is encoded as ASCII, read as a big-endian unsigned integer and
  written in binary, left-padded with zeros up to 8 digits.

  Returns a 1-D `np.float32` array with one 0.0/1.0 entry per bit.
  """
  code = int.from_bytes(c.encode("ascii"), "big")
  bits = format(code, "b").zfill(8)
  return np.array(list(map(int, bits)), dtype=np.float32)

def byte2char(byte):
  """Decode a sequence of bits (most significant bit first) into ASCII text.

  Parameters
  ----------
  byte: iterable of scalars exposing `.item()`
      Bit values (0 or 1), e.g. the elements of a NumPy array or of a torch tensor.

  Returns
  -------
  `str`
      The decoded text. An all-zero pattern decodes to the NUL character
      rather than to an empty string.

  Raises
  ------
  ValueError
      If `byte` is empty or holds values that are not binary digits.
  UnicodeDecodeError
      If the bits encode a byte outside the ASCII range.
  """
  b = "0b" + "".join(str(int(bit.item())) for bit in byte)
  n = int(b, 2)
  # n == 0 has bit_length() == 0; still emit one byte so that an all-zero
  # pattern round-trips to "\x00" instead of silently disappearing.
  nb_bytes = max(1, (n.bit_length() + 7) // 8)
  return n.to_bytes(nb_bytes, "big").decode("ascii")

In [3]:
def create_dataset(symbols, nb, replace=True):
  """Draw `nb` rows of `symbols` at random and wrap them in a `TensorDataset`.

  Row indices come from `np.random.choice`, to which `replace` is forwarded,
  so `replace=False` yields distinct rows.
  """
  nb_symbols = symbols.shape[0]
  picked     = np.random.choice(nb_symbols, nb, replace=replace)
  samples    = th.from_numpy(symbols[picked])
  return TensorDataset(samples)

### Training functions

In [4]:
def get_accuracy(ann, symbols):
  """Return the fraction of rows of `symbols` that `ann` reproduces exactly.

  Parameters
  ----------
  ann: `nn.Module`
      Model called on the whole batch of symbols; its output is rounded before comparison.
  symbols: array-like
      One row per symbol; converted with `th.tensor` and moved to `device`.

  Returns
  -------
  `float`
      Number of rows whose rounded prediction matches on every entry, divided by `len(symbols)`.
  """
  th_symbols = th.tensor(symbols).to(device)
  # Pure evaluation: no gradient is needed, so skip the autograd bookkeeping.
  with th.no_grad():
    pred_symbols = ann(th_symbols).round()
  exact_rows = th.all(th_symbols == pred_symbols, dim=1)
  return (exact_rows.sum() / len(symbols)).item()

def get_test_loss(
    ann,
    loss_recons,
    loss_diff,
    gamma,
    chars,
    batch_size
):
  """
  Average the reconstruction and difference losses of `ann` over `chars`.

  The model is switched to evaluation mode for the duration of the call and
  put back in the mode it was in afterwards. No gradient is computed.

  Parameters
  ----------
  ann: `nn.Module`
      Model exposing `cipher` and `decipher` sub-modules.
  loss_recons: `[th.Tensor, th.Tensor] -> th.Tensor`
      Loss between the deciphered data and the original data.
  loss_diff: `[th.Tensor, th.Tensor] -> th.Tensor`
      Loss between the ciphered data and the original data.
  gamma: `float`
      Weight applied to `loss_diff`.
  chars: dataset
      Dataset whose items hold the input tensor at index 0.
  batch_size: `int`
      Batch size used to iterate over `chars`.

  Returns
  -------
  `(float, float)`
      Mean over the batches of the reconstruction loss and of the
      gamma-weighted difference loss.

  Raises
  ------
  ZeroDivisionError
      If `chars` yields no batch.
  """

  was_training = ann.training
  ann.eval()

  dataloader = DataLoader(chars, batch_size=batch_size, shuffle=True, pin_memory=True)
  avg_loss_recons = 0
  avg_loss_diff = 0
  nb_batch = 0
  # Evaluation only: do not build the autograd graph.
  with th.no_grad():
    for data in dataloader:
      nb_batch += 1

      data        = data[0].to(device)
      data_cipher = ann.cipher(data)
      data_pred   = ann.decipher(data_cipher)

      avg_loss_recons += loss_recons(data_pred, data).item()
      avg_loss_diff   += (gamma*loss_diff(data_cipher, data)).item()

  # Restore the caller's mode instead of unconditionally forcing training mode.
  ann.train(was_training)
  return avg_loss_recons / nb_batch, avg_loss_diff / nb_batch


def train(
    ann,
    loss_recons,
    loss_diff,
    acc_fun,
    optimizer,
    train_chars,
    valid_chars,
    test_chars,
    batch_size,
    nb_epochs=-1,
    gamma=0.2,
    verbose=False
):

  """
  Training loop

  Each batch minimizes `loss_recons(deciphered, original) - gamma*loss_diff(ciphered, original)`.

  Parameters
  ----------
  ann: `nn.Module`
      Model exposing `cipher` and `decipher` sub-modules.
  loss_recons: `[th.Tensor, th.Tensor] -> th.Tensor`
      Evaluates the difference between the original data and the decrypted ones. To be minimized.
  loss_diff: `[th.Tensor, th.Tensor] -> th.Tensor`
      Evaluates the difference between the original data and the encrypted ones. To be maximized.
  acc_fun: `[nn.Module] -> float`
      Evaluates the accuracy of the model on each possible symbol. Called once per epoch.
  optimizer:
      Optimizer stepped after every batch; its gradients are reset before each backward pass.
  train_chars, valid_chars, test_chars: datasets
      Datasets whose items hold the input tensor at index 0. The validation set is
      evaluated after every epoch, the test set once after training.
  batch_size: `int`
      Batch size used for the three datasets.
  nb_epochs: `int`
      Number of epochs after which training stops. Training also stops as soon as the
      accuracy is 1.0; with -1 this is the only stopping condition. Default -1.
  gamma: `float`
      Weight of `loss_diff` in the total loss. Default 0.2.
  verbose: `bool`
      If True, prints the losses and the accuracy after every epoch. Default False.

  Returns
  -------
  `((float, float), list, list)`
      The test `(recons, diff)` losses, then the per-epoch `(recons, diff)` losses
      on the training set and on the validation set.
  """

  ann.train()
  # The loader reshuffles each time it is iterated, so one instance serves every epoch.
  dataloader = DataLoader(train_chars, batch_size=batch_size, shuffle=True, pin_memory=True)
  train_losses = []
  valid_losses = []
  epoch = 0
  while True:
    if verbose:
      print(f"Epoch {epoch+1}")

    avg_loss_recons = 0
    avg_loss_diff = 0
    nb_batch = 0
    for data in dataloader:
      nb_batch += 1

      data        = data[0].to(device)
      data_cipher = ann.cipher(data)
      data_pred   = ann.decipher(data_cipher)

      l_recons  = loss_recons(data_pred, data)
      l_diff    = gamma*loss_diff(data_cipher, data)
      l         = l_recons - l_diff
      avg_loss_recons += l_recons.item()
      avg_loss_diff   += l_diff.item()

      # Without this reset, backward() adds to the gradients of all previous
      # batches and the optimizer steps along their running sum.
      optimizer.zero_grad()
      l.backward()
      optimizer.step()

    train_losses.append((avg_loss_recons / nb_batch, avg_loss_diff / nb_batch))
    val_loss_recons, val_loss_diff = get_test_loss(ann, loss_recons, loss_diff, gamma, valid_chars, batch_size)
    valid_losses.append((val_loss_recons, val_loss_diff))
    accuracy = acc_fun(ann)
    if verbose:
      print(f"Train loss recons {avg_loss_recons / nb_batch:.4f} diff {avg_loss_diff / nb_batch:.4f}")
      print(f"Valid loss recons {val_loss_recons:.4f} diff {val_loss_diff:.4f}")
      print(f"Accuracy: {accuracy:.4f}")

    epoch += 1
    if epoch == nb_epochs or accuracy == 1: break

  test_loss_recons, test_loss_diff = get_test_loss(ann, loss_recons, loss_diff, gamma, test_chars, batch_size)
  if verbose:
    print(f"Test loss recons {test_loss_recons:.4f} diff {test_loss_diff:.4f}")

  return (test_loss_recons, test_loss_diff), train_losses, valid_losses


### The cipher and decipher model

In [5]:
class Cipher(nn.Module):
  """Encoder: a chain of linear layers mapping 8 input features to `latent_dim` values.

  `arch` lists the output width of each intermediate `nn.Linear` layer. No
  activation is applied between or after the layers.
  """

  def __init__(self, arch, latent_dim):
    super().__init__()

    # Consecutive entries of `widths` are the (in, out) sizes of the intermediate layers.
    widths = [8, *arch]
    hidden = [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
    self.fcs = nn.Sequential(*hidden)

    self.last_fc = nn.Linear(widths[-1], latent_dim)

  def forward(self, inputs):
    """Return the ciphered representation of `inputs`."""
    return self.last_fc(self.fcs(inputs))
  
class Decipher(nn.Module):
  """Decoder: a chain of linear layers mapping `latent_dim` values back to 8 outputs.

  `arch` lists the output width of each intermediate `nn.Linear` layer. No
  activation is applied between the layers; a sigmoid is applied to the final
  output, so every returned value lies between 0 and 1.
  """

  def __init__(self, arch, latent_dim):
    super().__init__()

    # Consecutive entries of `widths` are the (in, out) sizes of the intermediate layers.
    widths = [latent_dim, *arch]
    hidden = [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
    self.fcs = nn.Sequential(*hidden)

    self.last_fc = nn.Linear(widths[-1], 8)

  def forward(self, inputs):
    """Return the 8 sigmoid outputs reconstructed from `inputs`."""
    logits = self.last_fc(self.fcs(inputs))
    return th.sigmoid(logits)

class ANNCrypto(nn.Module):
  """Pairs a `Cipher` with a `Decipher` that share the same latent dimension.

  `cipher_arch` and `decipher_arch` are forwarded as the `arch` argument of the
  respective sub-module.
  """

  def __init__(self, latent_dim, cipher_arch=[], decipher_arch=[]):
    super().__init__()

    # The default lists are only passed on, never modified here.
    self.cipher   = Cipher(arch=cipher_arch, latent_dim=latent_dim)
    self.decipher = Decipher(arch=decipher_arch, latent_dim=latent_dim)

  def forward(self, inputs):
    """Cipher `inputs`, then decipher the result."""
    latent = self.cipher(inputs)
    reconstructed = self.decipher(latent)
    return reconstructed

In [6]:
def encrypt(ann, text):
  """Cipher `text` with `ann.cipher`, one row per character.

  Each character is converted with `char2byte`; the stacked rows are moved to
  `device` and fed to `ann.cipher`, whose output is returned unchanged.
  """
  bit_rows = np.array(list(map(char2byte, text)))
  plain    = th.tensor(bit_rows).to(device)
  return ann.cipher(plain)

def decrypt(ann, cipher_byte):
  """Decipher `cipher_byte` with `ann.decipher` and return the decoded text.

  The deciphered values are rounded, then each row is turned back into text
  with `byte2char`; the pieces are concatenated in row order.
  """
  bit_rows = ann.decipher(cipher_byte).round()
  return "".join(byte2char(bits) for bits in bit_rows)

### Create datasets

In [7]:
# One row of bits per ASCII character with code 32 to 126.
symbols = np.array([char2byte(chr(code)) for code in range(32, 127)])
# Each split is its own draw, without replacement, of as many rows as `symbols` has.
train_dataset, valid_dataset, test_dataset = (
  create_dataset(symbols, len(symbols), replace=False) for _ in range(3)
)

### Create the model, the optimizer and the losses.

In [8]:
# Latent dimension 8, with one intermediate layer of width 10 on each side.
ann = ANNCrypto(latent_dim=8, cipher_arch=[10], decipher_arch=[10]).to(device)
optimizer = th.optim.Adam(ann.parameters(), lr=1e-4)
# Reconstruction term (deciphered vs. original) and difference term (ciphered vs. original).
loss_recons = nn.BCELoss()
loss_diff = nn.MSELoss()

In [9]:
# nb_epochs=-1 sets no epoch limit: training runs until the accuracy on `symbols` is 1.
_, _, _ = train(
  ann=ann,
  loss_recons=loss_recons,
  loss_diff=loss_diff,
  acc_fun=lambda model: get_accuracy(model, symbols),
  optimizer=optimizer,
  train_chars=train_dataset,
  valid_chars=valid_dataset,
  test_chars=test_dataset,
  batch_size=len(symbols),
  nb_epochs=-1,
  gamma=0.001,
  verbose=False,
)

### Example

In [10]:
# Round trip: cipher a sample sentence, decipher it, and check it comes back unchanged.
original_text  = "Try to decrypt me!"
cipher_byte    = encrypt(ann, original_text)
decrypted_text = decrypt(ann, cipher_byte)
# Fails the cell if the deciphered text differs from the input.
assert original_text == decrypted_text
# Last expression of the cell: displayed as the cell output.
decrypted_text

'Try to decrypt me!'