In [1]:
import numpy as np
import torch
from data import ceasar_shift, convert_data
from torch import nn
from torch.utils.data.sampler import SubsetRandomSampler
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from model import Generator, Discriminator
import matplotlib.pyplot as plt

## Init parameters

In [2]:
# init parameters
torch.manual_seed(1)
np.random.seed(1)
batch_size = 256
num_epochs = 25
lr = 0.0001
device = "cuda" if torch.cuda.is_available() else "cpu"
shift = 10

## Load data


In [3]:
np_data = convert_data()
np_crypted_data = ceasar_shift(np_data, shift)

tensor_clear_text = torch.from_numpy(np_data)
tensor_crypted_data = torch.from_numpy(np_crypted_data)

tensor_clear_text = tensor_clear_text.float().view(-1, 1, 216, 27)
tensor_crypted_data = tensor_crypted_data.float().view(-1, 1, 216, 27)

num_train = len(tensor_clear_text)
indices_clear = list(range(num_train))
np.random.shuffle(indices_clear)
train_split_clear = int(np.floor(0.6 * num_train))
test_split_clear = int(np.floor(0.8 * num_train))


num_train = len(tensor_clear_text)
indices_crypted = list(range(num_train))
np.random.shuffle(indices_crypted)
train_split_crypted = int(np.floor(0.6 * num_train))
test_split_crypted = int(np.floor(0.8 * num_train))

clear_txt_train = tensor_clear_text[:train_split_clear]
clear_txt_test = tensor_clear_text[train_split_clear:test_split_clear]
clear_txt_valid = tensor_clear_text[test_split_clear:]

crypted_txt_train = tensor_crypted_data[:train_split_crypted]
crypted_txt_test = tensor_crypted_data[train_split_crypted:test_split_crypted]
crypted_txt_valid = tensor_crypted_data[test_split_crypted:]

(32635, 216, 27) int32
(32635, 216, 27) int32


## Create data loaders

In [4]:
train_clear_loader = DataLoader(clear_txt_train, batch_size=batch_size)
test_clear_loader = DataLoader(clear_txt_test, batch_size=batch_size)
valid_clear_loader = DataLoader(clear_txt_valid, batch_size=batch_size)

train_crypted_loader = DataLoader(crypted_txt_train, batch_size=batch_size)
test_crypted_loader = DataLoader(crypted_txt_test, batch_size=batch_size)
valid_crypted_loader = DataLoader(crypted_txt_valid, batch_size=batch_size)


## Init models

In [5]:
crypted_gen = Generator().to(device) # converts clear to crypted
clear_gen = Generator().to(device) # converts crypted to clear
crypted_discr = Discriminator().to(device)
clear_discr = Discriminator().to(device)

# Setup Adam optimizers for both generators
optimizer_crypted_gen = optim.Adam(crypted_gen.parameters(), lr=lr*22)
optimizer_clear_gen = optim.Adam(clear_gen.parameters(), lr=lr*22)

# Setup Adam optimizers for both discriminators
optimizer_crypted_discr= optim.Adam(crypted_discr.parameters(), lr=lr*0.1)
optimizer_clear_discr = optim.Adam(clear_discr.parameters(), lr=lr*0.1)

BCE = nn.BCELoss()
cross_entropy = nn.CrossEntropyLoss()

## init loop variables

In [6]:
crypted_gen_loss_hist = []
clear_gen_loss_hist = []
crypted_discr_loss_hist = []
clear_discr_loss_hist = []
train_loss = []
test_loss = []
valid_accuracy = []

# train loop

In [None]:
crypted_gen.train()
clear_gen.train()
crypted_discr.train()
clear_discr.train()
torch.autograd.set_detect_anomaly(True)
for epoch in range(num_epochs):
    dataloader_iterator = iter(train_crypted_loader)

    for i, clear in enumerate(train_clear_loader):

        try:
            crypted = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(train_clear_loader)
            crypted = next(dataloader_iterator)

        real_clear_text = clear.to(device)
        real_crypted_text = crypted.to(device)
        fake_clear_text = clear_gen(real_crypted_text)
        fake_crypted_text = crypted_gen(real_clear_text)

        ## train clear discr on true data
        clear_discr.zero_grad()
        pred_real_clear = clear_discr(real_clear_text)
        true_labels = torch.full((len(real_clear_text),1), 1, dtype=torch.float, device=device)
        batch_clear_d_true_loss = BCE(pred_real_clear, true_labels)
        batch_clear_d_true_loss.backward()

        ## train clear Discr on fake data.py
        pred_fake_clear  = clear_discr(fake_clear_text.detach())
        fake_labels = torch.full((len(real_clear_text),1), 0, dtype=torch.float, device=device)
        batch_clear_d_fake_loss = BCE(pred_fake_clear, fake_labels)
        batch_clear_d_fake_loss.backward()
        error_d_clear = batch_clear_d_fake_loss + batch_clear_d_true_loss
        clear_discr_loss_hist.append(error_d_clear.item())
        optimizer_clear_discr.step()

        ## train clear gen
        clear_gen.zero_grad()
        gen_fake_clear = clear_discr(fake_clear_text)
        gen_labels = torch.full((len(fake_clear_text),1), 1, dtype=torch.float, device=device)
        fake_gen_clear_loss = BCE(gen_fake_clear, gen_labels)
        fake_gen_clear_loss.backward()

        fake_crypted_reconstruct = crypted_gen(clear_gen(real_crypted_text))
        fake_crypted_reconstruct_loss = cross_entropy(fake_crypted_reconstruct.view(-1, 216, 27).transpose(1,2), torch.argmax(real_crypted_text, 3).view(-1,216))
        fake_crypted_reconstruct_loss.backward()

        fake_clear_reconstruct =  clear_gen(crypted_gen(real_clear_text))
        fake_clear_reconstruct_loss = cross_entropy(fake_clear_reconstruct.view(-1, 216, 27).transpose(1,2), torch.argmax(real_clear_text, 3).view(-1,216))
        fake_clear_reconstruct_loss.backward()

        batch_clear_gen_loss = fake_gen_clear_loss + fake_clear_reconstruct_loss  +fake_crypted_reconstruct_loss
        # batch_clear_gen_loss.backward()
        optimizer_clear_gen.step()
        clear_gen_loss_hist.append(batch_clear_gen_loss.item())

        ## train crypted discr on true data
        crypted_discr.zero_grad()
        pred_real_crypted = crypted_discr(real_crypted_text)
        true_labels = torch.full((len(real_crypted_text),1), 1, dtype=torch.float, device=device)
        batch_crypted_d_true_loss = BCE(pred_real_crypted, true_labels)
        batch_crypted_d_true_loss.backward()

        ## train crypted Discr on fake data.py
        pred_fake_crypted  = crypted_discr(fake_crypted_text.detach())
        fake_labels = torch.full((len(fake_crypted_text),1), 0, dtype=torch.float, device=device)
        batch_crypted_d_fake_loss = BCE(pred_fake_crypted, fake_labels)
        batch_crypted_d_fake_loss.backward()
        error_d_crypted = batch_crypted_d_fake_loss + batch_crypted_d_true_loss
        crypted_discr_loss_hist.append(error_d_crypted.item())
        optimizer_crypted_discr.step()

        ## train crypted gen
        crypted_gen.zero_grad()
        gen_fake_crypted = crypted_discr(fake_crypted_text)
        gen_labels = torch.full((len(fake_crypted_text),1), 1, dtype=torch.float, device=device)
        fake_gen_crypted_loss = BCE(gen_fake_crypted, gen_labels)
        fake_gen_crypted_loss.backward()

        fake_crypted_reconstruct = crypted_gen(clear_gen(real_crypted_text))
        fake_crypted_reconstruct_loss = cross_entropy(fake_crypted_reconstruct.view(-1, 216, 27).transpose(1,2), torch.argmax(real_crypted_text, 3).view(-1,216))
        fake_crypted_reconstruct_loss.backward()

        fake_clear_reconstruct =  clear_gen(crypted_gen(real_clear_text))
        fake_clear_reconstruct_loss = cross_entropy(fake_clear_reconstruct.view(-1, 216, 27).transpose(1,2), torch.argmax(real_clear_text, 3).view(-1,216))
        fake_clear_reconstruct_loss.backward()

        batch_crypted_gen_loss = fake_gen_crypted_loss + fake_crypted_reconstruct_loss + fake_clear_reconstruct_loss
        # batch_crypted_gen_loss.backward()
        optimizer_crypted_gen.step()
        crypted_gen_loss_hist.append(batch_crypted_gen_loss.item())
        print(epoch, i,crypted_gen_loss_hist[-1],clear_gen_loss_hist[-1],crypted_discr_loss_hist[-1],clear_discr_loss_hist[-1])




0 0 7.314979076385498 7.245937347412109 1.3873960971832275 1.3874726295471191
0 1 7.313647270202637 7.243830680847168 1.3868637084960938 1.387528657913208
0 2 7.313785552978516 7.2435150146484375 1.3860876560211182 1.3871419429779053
0 3 7.313301086425781 7.243560791015625 1.385937213897705 1.3864659070968628
0 4 7.312480926513672 7.242727279663086 1.386350393295288 1.386480450630188
0 5 7.31325626373291 7.241246223449707 1.3857139348983765 1.3879791498184204
0 6 7.312309265136719 7.239325523376465 1.3857834339141846 1.3897314071655273
0 7 7.311402797698975 7.238444805145264 1.3858375549316406 1.3904461860656738
0 8 7.308910369873047 7.23708438873291 1.3863990306854248 1.3905563354492188
0 9 7.307299613952637 7.236809730529785 1.387345790863037 1.390254259109497
0 10 7.305594444274902 7.235737323760986 1.3878580331802368 1.3900516033172607
0 11 7.303835391998291 7.235523223876953 1.3885254859924316 1.3893636465072632
0 12 7.303345680236816 7.2359619140625 1.3881622552871704 1.388324975

In [None]:

plt.figure()
plt.plot(crypted_gen_loss_hist)
plt.show()
plt.close()
plt.figure()
plt.plot(clear_gen_loss_hist)
plt.show()
plt.close()
plt.figure()
plt.plot(crypted_discr_loss_hist)
plt.show()
plt.close()
plt.figure()
plt.plot(clear_discr_loss_hist)
plt.show()
plt.close()

In [None]:
crypted_gen.eval()
clear_gen.eval()
crypted_discr.eval()
clear_discr.eval()

from data import ceasar_shift

with torch.no_grad():
    clear_gen = clear_gen.to("cpu")
    val_crypt_np = crypted_txt_valid.view(-1,216,27).detach().numpy()
    decrypted_np = ceasar_shift(val_crypt_np, -10)
    decrypted_gen = clear_gen(crypted_txt_valid.float()).view(-1,216,27).detach().numpy()
    decrypted_np_char = np.argmax(decrypted_np, axis=2).reshape(-1)
    decrypted_gen_char = np.argmax(decrypted_gen, axis=2).reshape(-1)
    print((decrypted_np_char == decrypted_gen_char).mean())