In [1]:
import numpy as np
import torch
from torch import nn
from torch.utils.data.sampler import SubsetRandomSampler
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from model import Generator, Discriminator
import matplotlib.pyplot as plt

## Init parameters

In [2]:
# Init parameters

# Seed the global torch and numpy RNGs so repeated runs draw the same random numbers.
torch.manual_seed(1)
np.random.seed(1)

# Training hyper-parameters.
batch_size = 128
num_epochs = 10
lr = 0.0001  # learning rate handed to every Adam optimizer in this notebook

# Run on the GPU when one is available, otherwise on the CPU.
# The former unconditional `device = "cpu"` reassignment was removed: it turned the
# CUDA detection into dead code and pinned every run to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Load data


In [3]:
# Load the clear-text corpus and its Caesar-shifted counterpart from disk.
np_data = np.load("brown_corpus.npy")
np_crypted_data = np.load("brown_corpus_ceasar_shift.npy")

# Convert to float tensors shaped (samples, 1, 216, 27).
# NOTE(review): presumably 216 positions x 27 one-hot symbols per sample — confirm
# against the script that produced the .npy files.
tensor_clear_text = torch.from_numpy(np_data).float().view(-1, 1, 216, 27)
tensor_crypted_data = torch.from_numpy(np_crypted_data).float().view(-1, 1, 216, 27)

num_train = len(tensor_clear_text)
# The same permutation is applied to both tensors below, so they must be equally long.
assert len(tensor_crypted_data) == num_train, (
    "clear and crypted corpora must contain the same number of samples"
)

# Shuffle the sample indices, then cut them 60% / 20% / 20%.
indices = list(range(num_train))
np.random.shuffle(indices)
train_split = int(np.floor(0.6 * num_train))
test_split = int(np.floor(0.8 * num_train))

# Previously `indices` was shuffled but never used, so the splits were plain contiguous
# slices of the corpus. Index with the shuffled permutation instead; using the same
# indices for both tensors keeps sample k of the clear and crypted splits aligned.
shuffled = torch.tensor(indices, dtype=torch.long)
train_idx = shuffled[:train_split]
test_idx = shuffled[train_split:test_split]
valid_idx = shuffled[test_split:]

clear_txt_train = tensor_clear_text[train_idx]
clear_txt_test = tensor_clear_text[test_idx]
clear_txt_valid = tensor_clear_text[valid_idx]

crypted_txt_train = tensor_crypted_data[train_idx]
crypted_txt_test = tensor_crypted_data[test_idx]
crypted_txt_valid = tensor_crypted_data[valid_idx]

## Create data loaders

In [4]:
# One DataLoader per split, all using the notebook-wide batch size and the
# DataLoader defaults for everything else.

# Clear-text splits.
train_clear_loader, test_clear_loader, valid_clear_loader = [
    DataLoader(split, batch_size=batch_size)
    for split in (clear_txt_train, clear_txt_test, clear_txt_valid)
]

# Crypted-text splits.
train_crypted_loader, test_crypted_loader, valid_crypted_loader = [
    DataLoader(split, batch_size=batch_size)
    for split in (crypted_txt_train, crypted_txt_test, crypted_txt_valid)
]


## Init models

In [5]:
# Two generators, one per translation direction, and one discriminator per domain.
crypted_gen = Generator().to(device)  # converts clear to crypted
clear_gen = Generator().to(device)  # converts crypted to clear
crypted_discr = Discriminator().to(device)
clear_discr = Discriminator().to(device)

# Setup Adam optimizers for both generators
optimizer_crypted_gen = optim.Adam(crypted_gen.parameters(), lr=lr)
optimizer_clear_gen = optim.Adam(clear_gen.parameters(), lr=lr)

# Setup Adam optimizers for both discriminators
optimizer_crypted_discr = optim.Adam(crypted_discr.parameters(), lr=lr)
# Fixed: this optimizer was built from clear_gen.parameters(), so the clear
# discriminator was never updated and clear_gen was stepped by two optimizers.
optimizer_clear_discr = optim.Adam(clear_discr.parameters(), lr=lr)

# BCE scores discriminator outputs against real/fake labels; cross-entropy scores
# the cycle reconstructions against the original symbols.
BCE = nn.BCELoss()
cross_entropy = nn.CrossEntropyLoss()

## Init loop variables

In [6]:
# Metric containers; every name is bound to its own, initially empty, list.

# Loss histories for the two generators and the two discriminators.
crypted_gen_loss_hist, clear_gen_loss_hist = [], []
crypted_discr_loss_hist, clear_discr_loss_hist = [], []

# Aggregate train / test losses and validation accuracy.
train_loss, test_loss = [], []
valid_accuracy = []

## Train loop

In [7]:
def _constant_labels(batch, value):
    """Return a (len(batch), 1) float tensor on `device` filled with `value`."""
    return torch.full((len(batch), 1), value, dtype=torch.float, device=device)


def _reconstruction_loss(reconstructed, original):
    """Cross-entropy between a cycle-reconstructed batch and the batch it started from.

    `reconstructed` is reshaped to (N, 27, 216) so that the 27-wide axis is the class
    axis nn.CrossEntropyLoss expects; `original` is reduced to class indices of shape
    (N, 216) by an argmax over its last axis.
    """
    logits = reconstructed.view(-1, 216, 27).transpose(1, 2)
    targets = torch.argmax(original, 3).view(-1, 216)
    return cross_entropy(logits, targets)


def _train_discriminator(discr, optimizer, real_batch, fake_batch):
    """Run one discriminator update and return the summed real + fake BCE loss.

    The fake batch is detached so no gradient flows back into the generator
    that produced it.
    """
    discr.zero_grad()

    # Real samples should be classified as 1.
    real_loss = BCE(discr(real_batch), _constant_labels(real_batch, 1))
    real_loss.backward()

    # Generated samples should be classified as 0.
    fake_loss = BCE(discr(fake_batch.detach()), _constant_labels(fake_batch, 0))
    fake_loss.backward()

    optimizer.step()
    return fake_loss + real_loss


def _train_generator(gen, optimizer, discr, fake_batch, real_clear, real_crypted):
    """Run one generator update and return its total loss.

    The loss is the adversarial term (make `discr` label `fake_batch` as real) plus
    both cycle-reconstruction terms (crypted -> clear -> crypted and
    clear -> crypted -> clear). Only `gen` is stepped by `optimizer`.
    """
    gen.zero_grad()

    adversarial_loss = BCE(discr(fake_batch), _constant_labels(fake_batch, 1))
    crypted_cycle_loss = _reconstruction_loss(
        crypted_gen(clear_gen(real_crypted)), real_crypted
    )
    clear_cycle_loss = _reconstruction_loss(
        clear_gen(crypted_gen(real_clear)), real_clear
    )

    total_loss = adversarial_loss + crypted_cycle_loss + clear_cycle_loss
    total_loss.backward()
    optimizer.step()
    return total_loss


crypted_gen.train()
clear_gen.train()
crypted_discr.train()
clear_discr.train()
for epoch in range(num_epochs):
    dataloader_iterator = iter(train_crypted_loader)

    for clear in train_clear_loader:

        try:
            crypted = next(dataloader_iterator)
        except StopIteration:
            # Crypted loader ran out first: restart it. (This used to restart from
            # train_clear_loader, which would have fed clear text as "crypted" data.)
            dataloader_iterator = iter(train_crypted_loader)
            crypted = next(dataloader_iterator)

        real_clear_text = clear.to(device)
        real_crypted_text = crypted.to(device)
        fake_clear_text = clear_gen(real_crypted_text)
        fake_crypted_text = crypted_gen(real_clear_text)

        ## train clear discr on true and fake data
        error_d_clear = _train_discriminator(
            clear_discr, optimizer_clear_discr, real_clear_text, fake_clear_text
        )
        clear_discr_loss_hist.append(error_d_clear.item())

        ## train clear gen
        batch_clear_gen_loss = _train_generator(
            clear_gen, optimizer_clear_gen, clear_discr,
            fake_clear_text, real_clear_text, real_crypted_text,
        )
        clear_gen_loss_hist.append(batch_clear_gen_loss.item())

        ## train crypted discr on true and fake data
        error_d_crypted = _train_discriminator(
            crypted_discr, optimizer_crypted_discr, real_crypted_text, fake_crypted_text
        )
        crypted_discr_loss_hist.append(error_d_crypted.item())

        ## train crypted gen
        # Fixed: the fake crypted text is now judged by crypted_discr; it used to be
        # scored by clear_discr, the discriminator of the wrong domain.
        batch_crypted_gen_loss = _train_generator(
            crypted_gen, optimizer_crypted_gen, crypted_discr,
            fake_crypted_text, real_clear_text, real_crypted_text,
        )
        crypted_gen_loss_hist.append(batch_crypted_gen_loss.item())

        # NOTE(review): this `break` and the one below stop training after the very
        # first batch of the first epoch, so `num_epochs` has no effect. They look
        # like smoke-test leftovers; remove both to run the full training.
        break
    break



