In [1]:
import os
import numpy as np
import torch
from data import ceasar_shift, convert_data
from torch import nn
from torch.utils.data.sampler import SubsetRandomSampler
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from models.modelV2 import GeneratorV2, DiscriminatorV2
from models.ResNet import resnet18, resnet34, resnet50, resnet101, resnet152
from models.model import Generator, Discriminator
from models.cyphergan_models import GeneratorV3, DiscriminatorV3
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import wandb
import datetime
from torch.autograd import Variable
from torch import autograd

In [2]:
# Keep W&B quiet; uncomment WANDB_MODE to log locally without syncing.
os.environ["WANDB_SILENT"] = "true"
# os.environ["WANDB_MODE"] = "offline"

# Seed both RNGs this notebook relies on: torch (e.g. weight initialisation)
# and numpy (the index shuffle that defines the train/test/valid splits).
torch.manual_seed(1)
np.random.seed(1)

# Hyper-parameters. They are handed to wandb.init and read back by the cells
# below through wandb.config[...].
# NOTE(review): "lr_discr", "reg", "discriminator_step", "generator_step" and
# "lambda_term" are not read by any cell of this notebook (GAN leftovers?).
wandb.config = dict(
    version=1,
    batch_size=128,
    train_split=0.8,      # fraction of shuffled samples before the train/test cut
    test_split=0.9,       # fraction of shuffled samples before the test/valid cut
    num_epochs=15,
    lr_gen=0.0002,
    lr_discr=0.0005,
    beta1=0.9,
    beta2=0.999,
    device="cuda" if torch.cuda.is_available() else "cpu",
    shift=10,             # key passed to ceasar_shift to build the ciphertext
    reg=0.1,
    instance_size=100,    # characters per sample (fixed_len for convert_data)
    dictionary_size=27,   # length of the last (symbol) axis of each sample
    discriminator_step=2,
    generator_step=1,
    lambda_term=10,
)
run = wandb.init(project="Research_project_IS_test", entity="davidvicente", name="Cross entropy cypher generator", config=wandb.config)


def _to_grid(array):
    """Wrap a numpy array as a float tensor shaped (N, 1, instance_size, dictionary_size).

    The recorded output of this cell shows the arrays arrive as int32 with shape
    (N, instance_size, dictionary_size); the extra axis of size 1 is added by the view.
    """
    return torch.from_numpy(array).float().view(
        -1, 1, wandb.config["instance_size"], wandb.config["dictionary_size"]
    )


## Create data
np_data = convert_data(fixed_len=wandb.config["instance_size"])
np_crypted_data = ceasar_shift(np_data, wandb.config["shift"])

tensor_clear_text = _to_grid(np_data)
tensor_crypted_data = _to_grid(np_crypted_data)

(24206, 100, 27) int32
(24206, 100, 27) int32


In [3]:
# Shuffle the sample indices once (in place, via the numpy global RNG seeded
# above), then cut them at the configured fractions: [0, train) -> train,
# [train, test) -> test, [test, end) -> validation. Clear and crypted tensors
# are indexed with the SAME index lists, so row k of each split stays a pair.
num_train = len(tensor_clear_text)
indices = list(range(num_train))
np.random.shuffle(indices)

train_split = int(np.floor(wandb.config["train_split"] * num_train))
test_split = int(np.floor(wandb.config["test_split"] * num_train))

split_indices = (
    indices[:train_split],            # train
    indices[train_split:test_split],  # test
    indices[test_split:],             # validation
)

clear_txt_train, clear_txt_test, clear_txt_valid = (
    tensor_clear_text[idx] for idx in split_indices
)
crypted_txt_train, crypted_txt_test, crypted_txt_valid = (
    tensor_crypted_data[idx] for idx in split_indices
)

In [4]:
# The training loop walks these two loaders in lock-step and treats batch k of
# each as matching (clear, crypted) pairs. Keep them un-shuffled (DataLoader's
# default) so the pairing produced by the index split is preserved.
loader_batch_size = wandb.config["batch_size"]
train_clear_loader = DataLoader(clear_txt_train, batch_size=loader_batch_size)
train_crypted_loader = DataLoader(crypted_txt_train, batch_size=loader_batch_size)

In [5]:
# Two generators with identical architecture, trained independently:
#   crypted_gen : clear text   -> crypted text
#   clear_gen   : crypted text -> clear text
# Construction order is kept (crypted_gen first) so torch's seeded RNG is
# consumed in the same order.
_gen_dims = (wandb.config["instance_size"], wandb.config["dictionary_size"])
crypted_gen = Generator(*_gen_dims).to(wandb.config["device"])
clear_gen = Generator(*_gen_dims).to(wandb.config["device"])

# Register both models with W&B (log="all", log_freq=1000, graph logging on).
for _watch_idx, _model in ((1, crypted_gen), (2, clear_gen)):
    wandb.watch(_model, log="all", log_freq=1000, log_graph=True, idx=_watch_idx)


def _make_adam(model):
    """Adam over `model`'s parameters with the generator learning rate and betas from the config."""
    return optim.Adam(
        model.parameters(),
        lr=wandb.config["lr_gen"],
        betas=(wandb.config["beta1"], wandb.config["beta2"]),
    )


optimizer_crypted_gen = _make_adam(crypted_gen)
optimizer_clear_gen = _make_adam(clear_gen)

# Only cross_entropy is used by the training cell; BCE and mse are kept for
# experimentation.
BCE = nn.BCELoss()
cross_entropy = nn.CrossEntropyLoss()
mse = nn.MSELoss()

In [6]:

# Anomaly detection makes autograd much slower and, being global, stays on for
# every later cell. Keep it off unless a NaN/inf has to be traced back.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

device = wandb.config["device"]
instance_size = wandb.config["instance_size"]
dictionary_size = wandb.config["dictionary_size"]


def sequence_cross_entropy(output, one_hot_target):
    """Per-character cross-entropy between a generator output and its target text.

    `output` is viewed as (batch, instance_size, dictionary_size) and transposed
    to (batch, dictionary_size, instance_size): the class-second layout that
    nn.CrossEntropyLoss expects. The target's class indices come from an argmax
    over axis 3 (the symbol axis of the (batch, 1, instance_size,
    dictionary_size) data tensors).

    NOTE(review): CrossEntropyLoss applies log-softmax itself, so this assumes
    Generator's default forward returns unnormalised scores — confirm in
    models/model.py.
    """
    scores = output.view(-1, instance_size, dictionary_size).transpose(1, 2)
    target = torch.argmax(one_hot_target, 3).view(-1, instance_size)
    return cross_entropy(scores, target)


# Supervised pairing: both training sets are slices of the same index list and
# the loaders do not shuffle, so batch k of each loader holds matching samples.
# Walk them in lock-step. The previous StopIteration fallback re-created the
# iterator from the *clear* loader, which would have fed plaintext in as the
# ciphertext target.
assert len(train_clear_loader.dataset) == len(train_crypted_loader.dataset), \
    "clear and crypted training sets must have the same size to stay paired"

for epoch in range(wandb.config["num_epochs"]):
    # Nothing inside the epoch switches the models to eval(), so once per epoch suffices.
    crypted_gen.train()
    clear_gen.train()

    for clear, crypted in zip(train_clear_loader, train_crypted_loader):
        real_clear_text = clear.to(device)
        real_crypted_text = crypted.to(device)

        # clear -> crypted
        optimizer_crypted_gen.zero_grad()
        fake_crypted = crypted_gen(real_clear_text)
        crypted_loss = sequence_cross_entropy(fake_crypted, real_crypted_text)
        crypted_loss.backward()
        optimizer_crypted_gen.step()

        # crypted -> clear
        optimizer_clear_gen.zero_grad()
        fake_clear = clear_gen(real_crypted_text)
        clear_loss = sequence_cross_entropy(fake_clear, real_clear_text)
        clear_loss.backward()
        optimizer_clear_gen.step()

        # One committed W&B step per batch, with plain floats rather than tensors.
        wandb.log({
            "epoch": epoch,
            "clear loss": clear_loss.item(),
            "crypted loss": crypted_loss.item(),
        })

KeyboardInterrupt: 

In [None]:
# Put both generators in inference mode before measuring accuracy. This only
# changes behaviour if their architectures contain mode-dependent layers
# (dropout, batch norm), which cannot be seen from this notebook.
for generator in (crypted_gen, clear_gen):
    generator.eval()

In [None]:

# Decryption accuracy on the validation split. The reference plaintext is
# recovered by applying the inverse shift (ceasar_shift with -shift) to the
# validation ciphertext. It is compared character by character (argmax over
# the symbol axis) with clear_gen's output. clear_gen is moved to the CPU and
# left there, because the validation tensors live on the CPU.
seq_shape = (-1, wandb.config["instance_size"], wandb.config["dictionary_size"])

with torch.no_grad():
    clear_gen = clear_gen.to("cpu")

    reference_onehot = ceasar_shift(
        crypted_txt_valid.view(*seq_shape).detach().numpy(),
        -wandb.config["shift"],
    )
    generated_onehot = clear_gen(crypted_txt_valid).view(*seq_shape).detach().numpy()

    reference_chars = np.argmax(reference_onehot, axis=2).ravel()
    generated_chars = np.argmax(generated_onehot, axis=2).ravel()
    print(np.mean(reference_chars == generated_chars))

In [None]:
# Encryption accuracy on the validation split: crypted_gen's per-character
# prediction (argmax over the symbol axis) versus ceasar_shift applied to the
# same plaintext. crypted_gen is moved to the CPU and left there, where the
# validation tensors live.
with torch.no_grad():
    crypted_gen = crypted_gen.to("cpu")
    n_chars, n_symbols = wandb.config["instance_size"], wandb.config["dictionary_size"]

    plain_np = clear_txt_valid.view(-1, n_chars, n_symbols).detach().numpy()
    expected_onehot = ceasar_shift(plain_np, wandb.config["shift"])
    predicted_onehot = crypted_gen(clear_txt_valid).view(-1, n_chars, n_symbols).detach().numpy()

    expected_chars = np.argmax(expected_onehot, axis=2).reshape(-1)
    predicted_chars = np.argmax(predicted_onehot, axis=2).reshape(-1)
    matches = expected_chars == predicted_chars
    print(matches.mean())

In [None]:
import torch.nn.functional as F

# Round-trip accuracy on the validation split: clear -> crypted_gen -> clear_gen,
# compared character by character (argmax over axis 3) with the original text.
# Inference only, so no autograd graph is built. Inputs follow each model's
# current device rather than assuming an earlier cell moved both to the CPU.
with torch.no_grad():
    crypted_device = next(crypted_gen.parameters()).device
    clear_device = next(clear_gen.parameters()).device
    # logits=False is forwarded to Generator as in the original cell; its exact
    # meaning lives in models/model.py.
    intermediate = crypted_gen(clear_txt_valid.to(crypted_device), logits=False)
    new_clear = clear_gen(intermediate.to(clear_device)).cpu()

arg_clear = torch.argmax(clear_txt_valid, dim=3).view(-1)
arg_new_clear = torch.argmax(new_clear, dim=3).view(-1)
round_trip_accuracy = (arg_new_clear == arg_clear).float().mean().item()
print(round_trip_accuracy)
