In [1]:
import numpy as np
import torch
from data import ceasar_shift, convert_data
from torch import nn
from torch.utils.data.sampler import SubsetRandomSampler
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from model import Generator, Discriminator
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

## Init parameters

In [2]:
# Notebook-wide configuration: RNG seeds, compute device and hyper-parameters.

# Seed both random number generators used below (NumPy and PyTorch).
np.random.seed(1)
torch.manual_seed(1)

# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Optimisation settings; `lr` is the base rate that the optimizers scale.
batch_size, num_epochs = 256, 5
lr = 1e-4

# Offset handed to `ceasar_shift` when building the encrypted corpus.
shift = 10

## Load data


In [None]:
# Build the clear-text corpus and its Caesar-shifted counterpart.
np_data = convert_data()
np_crypted_data = ceasar_shift(np_data, shift)

# Both corpora become float tensors of shape (N, 1, 216, 27).
# NOTE(review): 216 looks like the sequence length and 27 the one-hot alphabet
# size (later cells take argmax over the last axis) -- confirm against data.py.
tensor_clear_text = torch.from_numpy(np_data).float().view(-1, 1, 216, 27)
tensor_crypted_data = torch.from_numpy(np_crypted_data).float().view(-1, 1, 216, 27)

# Random 60 / 20 / 20 split of the clear-text samples.
num_train = len(tensor_clear_text)
indices_clear = list(range(num_train))
np.random.shuffle(indices_clear)
train_split_clear = int(np.floor(0.6 * num_train))
test_split_clear = int(np.floor(0.8 * num_train))

# Independent random 60 / 20 / 20 split of the crypted samples. The length is
# taken from the crypted tensor itself rather than from the clear one.
num_train = len(tensor_crypted_data)
indices_crypted = list(range(num_train))
np.random.shuffle(indices_crypted)
train_split_crypted = int(np.floor(0.6 * num_train))
test_split_crypted = int(np.floor(0.8 * num_train))

# Apply the shuffled indices. Previously the permutations were computed but the
# tensors were sliced contiguously, so the split was not random at all.
# Explicit long tensors keep the indexing unambiguous (and valid when empty).
perm_clear = torch.tensor(indices_clear, dtype=torch.long)
perm_crypted = torch.tensor(indices_crypted, dtype=torch.long)

clear_txt_train = tensor_clear_text[perm_clear[:train_split_clear]]
clear_txt_test = tensor_clear_text[perm_clear[train_split_clear:test_split_clear]]
clear_txt_valid = tensor_clear_text[perm_clear[test_split_clear:]]

crypted_txt_train = tensor_crypted_data[perm_crypted[:train_split_crypted]]
crypted_txt_test = tensor_crypted_data[perm_crypted[train_split_crypted:test_split_crypted]]
crypted_txt_valid = tensor_crypted_data[perm_crypted[test_split_crypted:]]

## Create data loaders

In [None]:
# Training loaders reshuffle on every epoch so the networks do not see the same
# batches in the same order each time. The two domains are shuffled
# independently, which is fine here: the training losses in this notebook never
# compare a clear batch against a crypted batch element-wise.
train_clear_loader = DataLoader(clear_txt_train, batch_size=batch_size, shuffle=True)
train_crypted_loader = DataLoader(crypted_txt_train, batch_size=batch_size, shuffle=True)

# Evaluation loaders keep a fixed order so results are comparable between runs.
test_clear_loader = DataLoader(clear_txt_test, batch_size=batch_size)
valid_clear_loader = DataLoader(clear_txt_valid, batch_size=batch_size)

test_crypted_loader = DataLoader(crypted_txt_test, batch_size=batch_size)
valid_crypted_loader = DataLoader(crypted_txt_valid, batch_size=batch_size)


## Init models

In [None]:
# Two generators, one per translation direction. They are created in the order
# crypted_gen, clear_gen so seeded weight initialisation stays reproducible.
#   crypted_gen: clear   -> crypted
#   clear_gen:   crypted -> clear
crypted_gen, clear_gen = (Generator().to(device) for _ in range(2))

# One discriminator per domain, created in the order crypted_discr, clear_discr.
crypted_discr, clear_discr = (Discriminator().to(device) for _ in range(2))

# Generators train with 50x the base learning rate, discriminators with half of it.
gen_lr = lr * 50
discr_lr = lr * 0.5

# Adam optimizers for both generators
optimizer_crypted_gen = optim.Adam(crypted_gen.parameters(), lr=gen_lr)
optimizer_clear_gen = optim.Adam(clear_gen.parameters(), lr=gen_lr)

# Adam optimizers for both discriminators
optimizer_crypted_discr = optim.Adam(crypted_discr.parameters(), lr=discr_lr)
optimizer_clear_discr = optim.Adam(clear_discr.parameters(), lr=discr_lr)

# Loss functions: cross-entropy for the reconstruction terms and binary
# cross-entropy for the real/fake terms.
cross_entropy = nn.CrossEntropyLoss()
BCE = nn.BCELoss()

## Init loop variables

In [None]:
# Per-iteration training curves. Each name is bound to its own empty list,
# which the training loop appends to once per batch.
(
    crypted_gen_loss_hist,
    clear_gen_loss_hist,
    crypted_discr_loss_hist,
    clear_discr_loss_hist,
    test_decrypted_accuracy_hist,
    test_encrypted_accuracy_hist,
) = ([] for _ in range(6))

In [None]:
# Reference answers for accuracy monitoring, computed analytically with
# `ceasar_shift` from the *test* split, plus a device transfer of the *valid*
# split. Any comparison against `test_decrypted` / `test_encrypted` must feed
# the generators the same split these labels were derived from.
with torch.no_grad():
    # Expected clear characters: undo the shift on the crypted test samples.
    test_crypt_np = crypted_txt_test.detach().view(-1, 216, 27).numpy()
    test_decrypted_np = ceasar_shift(test_crypt_np, -shift)
    test_decrypted_np_char = np.argmax(test_decrypted_np, axis=2).reshape(-1)

    # Expected crypted characters: apply the shift to the clear test samples.
    test_clear_np = clear_txt_test.detach().view(-1, 216, 27).numpy()
    test_encrypted_np = ceasar_shift(test_clear_np, shift)
    test_encrypted_np_char = np.argmax(test_encrypted_np, axis=2).reshape(-1)

    # Flat index tensors on the compute device, one entry per character position.
    test_decrypted = torch.from_numpy(test_decrypted_np_char).to(device)
    test_encrypted = torch.from_numpy(test_encrypted_np_char).to(device)

    # Move the validation tensors to the compute device as float tensors.
    crypted_txt_valid = crypted_txt_valid.float().to(device)
    clear_txt_valid = clear_txt_valid.float().to(device)

## Train loop

In [None]:
# NOTE(review): anomaly detection is a debugging aid that adds overhead to
# every backward pass; consider switching it off once training is stable.
torch.autograd.set_detect_anomaly(True)


def _cycle_loss(first_gen, second_gen, real_text):
    """Return the reconstruction loss of a round trip through both generators.

    `real_text` is passed through `first_gen`, then `second_gen`. The result is
    reshaped to (-1, 216, 27), transposed to put the 27-wide axis second, and
    scored with `cross_entropy` against the argmax of `real_text` over axis 3.
    """
    reconstruct = second_gen(first_gen(real_text))
    return cross_entropy(
        reconstruct.view(-1, 216, 27).transpose(1, 2),
        torch.argmax(real_text, 3).view(-1, 216),
    )


def _char_accuracy(generator, inputs, expected_chars):
    """Return the fraction of positions where argmax(generator(inputs)) matches.

    `expected_chars` is a flat tensor of class indices, one per position.
    The result is a 0-dim float tensor.
    """
    generated = generator(inputs).view(-1, 216, 27).detach()
    generated_chars = torch.argmax(generated, 2).view(-1)
    # .float().mean() is a true division regardless of the installed torch version.
    return (expected_chars == generated_chars).float().mean()


# Inputs for the per-iteration accuracy check. They come from the *test* split
# because the reference labels `test_decrypted` / `test_encrypted` were derived
# from that split; evaluating any other split against them is meaningless.
test_crypted_inputs = crypted_txt_test.float().to(device)
test_clear_inputs = clear_txt_test.float().to(device)

for epoch in range(num_epochs):

    dataloader_iterator = iter(train_crypted_loader)

    for i, clear in enumerate(train_clear_loader):
        # The end of each iteration switches to eval mode, so switch back here.
        crypted_gen.train()
        clear_gen.train()
        crypted_discr.train()
        clear_discr.train()
        try:
            crypted = next(dataloader_iterator)
        except StopIteration:
            # Restart the *crypted* loader (restarting the clear loader here
            # would feed clear text in as "crypted" samples).
            dataloader_iterator = iter(train_crypted_loader)
            crypted = next(dataloader_iterator)

        real_clear_text = clear.to(device)
        real_crypted_text = crypted.to(device)
        fake_clear_text = clear_gen(real_crypted_text)
        fake_crypted_text = crypted_gen(real_clear_text)

        # Labels are built with ones_like / zeros_like so they always match the
        # shape of the prediction they are compared with, even when the clear
        # and crypted batches have different sizes.

        ## train clear discr on true data
        clear_discr.zero_grad()
        pred_real_clear = clear_discr(real_clear_text)
        batch_clear_d_true_loss = BCE(pred_real_clear, torch.ones_like(pred_real_clear))
        batch_clear_d_true_loss.backward()

        ## train clear discr on fake data
        # detach(): the discriminator update must not push gradients into clear_gen.
        pred_fake_clear = clear_discr(fake_clear_text.detach())
        batch_clear_d_fake_loss = BCE(pred_fake_clear, torch.zeros_like(pred_fake_clear))
        batch_clear_d_fake_loss.backward()
        error_d_clear = batch_clear_d_fake_loss + batch_clear_d_true_loss
        clear_discr_loss_hist.append(error_d_clear.item())
        optimizer_clear_discr.step()

        ## train crypted discr on true data
        crypted_discr.zero_grad()
        pred_real_crypted = crypted_discr(real_crypted_text)
        batch_crypted_d_true_loss = BCE(pred_real_crypted, torch.ones_like(pred_real_crypted))
        batch_crypted_d_true_loss.backward()

        ## train crypted discr on fake data
        pred_fake_crypted = crypted_discr(fake_crypted_text.detach())
        batch_crypted_d_fake_loss = BCE(pred_fake_crypted, torch.zeros_like(pred_fake_crypted))
        batch_crypted_d_fake_loss.backward()
        error_d_crypted = batch_crypted_d_fake_loss + batch_crypted_d_true_loss
        crypted_discr_loss_hist.append(error_d_crypted.item())
        optimizer_crypted_discr.step()

        ## train clear gen
        # Adversarial term (fool clear_discr) plus both cycle-consistency terms.
        # Each term is back-propagated separately; only clear_gen is stepped.
        clear_gen.zero_grad()
        crypted_gen.zero_grad()
        gen_fake_clear = clear_discr(fake_clear_text)
        fake_gen_clear_loss = BCE(gen_fake_clear, torch.ones_like(gen_fake_clear))
        fake_gen_clear_loss.backward()

        fake_crypted_reconstruct_loss = _cycle_loss(clear_gen, crypted_gen, real_crypted_text)
        fake_crypted_reconstruct_loss.backward()

        fake_clear_reconstruct_loss = _cycle_loss(crypted_gen, clear_gen, real_clear_text)
        fake_clear_reconstruct_loss.backward()

        batch_clear_gen_loss = fake_gen_clear_loss + fake_crypted_reconstruct_loss + fake_clear_reconstruct_loss
        optimizer_clear_gen.step()
        clear_gen_loss_hist.append(batch_clear_gen_loss.item())

        ## train crypted gen
        # Same recipe for the other direction. The cycle terms are recomputed
        # because clear_gen's weights have just changed.
        crypted_gen.zero_grad()
        clear_gen.zero_grad()
        gen_fake_crypted = crypted_discr(fake_crypted_text)
        fake_gen_crypted_loss = BCE(gen_fake_crypted, torch.ones_like(gen_fake_crypted))
        fake_gen_crypted_loss.backward()

        fake_crypted_reconstruct_loss = _cycle_loss(clear_gen, crypted_gen, real_crypted_text)
        fake_crypted_reconstruct_loss.backward()

        fake_clear_reconstruct_loss = _cycle_loss(crypted_gen, clear_gen, real_clear_text)
        fake_clear_reconstruct_loss.backward()

        batch_crypted_gen_loss = fake_gen_crypted_loss + fake_crypted_reconstruct_loss + fake_clear_reconstruct_loss
        optimizer_crypted_gen.step()
        crypted_gen_loss_hist.append(batch_crypted_gen_loss.item())

        # Live progress report (replaces the previous iteration's output).
        clear_output()
        display("epoch : " + str(epoch))
        display("iteration : " + str(i))
        display("train loss crypt gen : " + str(crypted_gen_loss_hist[-1]))
        display("train loss clear gen : " + str(clear_gen_loss_hist[-1]))
        display("train loss crypt disc : " + str(crypted_discr_loss_hist[-1]))
        display("train loss clear disc : " + str(clear_discr_loss_hist[-1]))

        # Per-iteration accuracy on the test split.
        crypted_gen.eval()
        clear_gen.eval()
        crypted_discr.eval()
        clear_discr.eval()
        with torch.no_grad():
            test_decrypted_accuracy = _char_accuracy(clear_gen, test_crypted_inputs, test_decrypted)
            display("decrypting test accuracy : " +str(test_decrypted_accuracy.item()))
            test_decrypted_accuracy_hist.append(test_decrypted_accuracy.to("cpu").item())

            test_encrypted_accuracy = _char_accuracy(crypted_gen, test_clear_inputs, test_encrypted)
            display("encrypting test accuracy : " +str(test_encrypted_accuracy.item()))
            test_encrypted_accuracy_hist.append(test_encrypted_accuracy.to("cpu").item())


In [None]:

# One figure per recorded curve. Every history gets one entry per training
# iteration, so they all share the same x axis. Titles and axis labels make the
# six figures distinguishable when the notebook is skimmed.
training_curves = (
    ("Crypted generator loss", "loss", crypted_gen_loss_hist),
    ("Clear generator loss", "loss", clear_gen_loss_hist),
    ("Crypted discriminator loss", "loss", crypted_discr_loss_hist),
    ("Clear discriminator loss", "loss", clear_discr_loss_hist),
    ("Decrypting test accuracy", "accuracy", test_decrypted_accuracy_hist),
    ("Encrypting test accuracy", "accuracy", test_encrypted_accuracy_hist),
)

for curve_title, curve_ylabel, curve_values in training_curves:
    plt.figure()
    plt.plot(curve_values)
    plt.title(curve_title)
    plt.xlabel("training iteration")
    plt.ylabel(curve_ylabel)
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close()

In [None]:
# Final decryption check on the validation split, run on the CPU.
# All four networks are put in eval mode first.
for network in (crypted_gen, clear_gen, crypted_discr, clear_discr):
    network.eval()

with torch.no_grad():
    clear_gen = clear_gen.to("cpu")

    # Reference answer: undo the Caesar shift analytically.
    val_crypt_np = crypted_txt_valid.to("cpu").view(-1,216,27).detach().numpy()
    decrypted_np = ceasar_shift(val_crypt_np, -shift)

    # Model answer: run the crypted -> clear generator on the same samples.
    generator_output = clear_gen(crypted_txt_valid.to("cpu"))
    decrypted_gen = generator_output.view(-1,216,27).detach().numpy()

    # Compare the argmax class at every position and report the match rate.
    decrypted_np_char = np.argmax(decrypted_np, axis=2).reshape(-1)
    decrypted_gen_char = np.argmax(decrypted_gen, axis=2).reshape(-1)
    matches = decrypted_np_char == decrypted_gen_char
    print(matches.mean())

In [None]:
# Final encryption check on the validation split, run on the CPU.
with torch.no_grad():
    crypted_gen = crypted_gen.to("cpu")

    # Reference answer: apply the Caesar shift analytically.
    val_clear_np = clear_txt_valid.to("cpu").view(-1,216,27).detach().numpy()
    encrypted_np = ceasar_shift(val_clear_np, shift)

    # Model answer: run the clear -> crypted generator on the same samples.
    generator_output = crypted_gen(clear_txt_valid.to("cpu"))
    encrypted_gen = generator_output.view(-1,216,27).detach().numpy()

    # Compare the argmax class at every position and report the match rate.
    encrypted_np_char = np.argmax(encrypted_np, axis=2).reshape(-1)
    encrypted_gen_char = np.argmax(encrypted_gen, axis=2).reshape(-1)
    matches = encrypted_np_char == encrypted_gen_char
    print(matches.mean())