In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import logging
import os

from data_prep import Password as P
from model import Generator, Discriminator
from training_helper import *
from config import *

# Root logger (getLogger() with no name); the module-level logging.debug/info
# calls in later cells go through it, so its level gates their output.
logger = logging.getLogger()
logger.setLevel(level=logging.DEBUG)

In [2]:
# custom weights initialization called on G and D
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0, 0.001)

In [None]:
# Build the generator and discriminator (generator first) on `device`.
# NOTE(review): `weights_init` is defined above but never applied; call
# g.apply(weights_init) / d.apply(weights_init) if the custom init is intended.
g, d = (
    Generator(GEN_HIDDEN_SIZE, GEN_NEURON_SIZE).to(device),
    Discriminator(DISC_HIDDEN_SIZE, DISC_NEURON_SIZE).to(device),
)

In [None]:
# One RMSprop optimizer per network, both with lr=0.0002 (opt_g for g, opt_d for d).
# Previously tried alternative: torch.optim.Adam(..., lr=0.0002, betas=(0.5, 0.999)).
opt_g, opt_d = (torch.optim.RMSprop(net.parameters(), lr=0.0002) for net in (g, d))

In [None]:
# Password dataset helper from data_prep.
p = P()
# Batch source consumed with next(batch_gen) in the training cells, so this is
# presumably a generator/iterator attribute rather than a method — TODO confirm.
batch_gen = p.string_gen

In [None]:
logger.setLevel(logging.DEBUG)

# Resume from CHECKPOINT_PATH when it exists; set to False to train from scratch.
TRAIN_FROM_CKPT = True

if os.path.isfile(CHECKPOINT_PATH) and TRAIN_FROM_CKPT:
    # torch.load unpickles the file: only load checkpoints you wrote yourself.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location = device)
    g.load_state_dict(checkpoint['gen_state_dict'])
    d.load_state_dict(checkpoint['disc_state_dict'])
    g = g.to(device)
    d = d.to(device)
    opt_g.load_state_dict(checkpoint['gen_optimizer_state_dict'])
    opt_d.load_state_dict(checkpoint['disc_optimizer_state_dict'])
    # The checkpoint is written at the *start* of iteration `iter` (before it
    # runs), so resuming re-runs exactly that iteration.
    start_len = checkpoint['seq_len']
    start_iter = checkpoint['iter']
    # Older checkpoints stored loss tensors (with autograd history attached);
    # normalise everything to plain floats.
    disc_loss = [float(v) for v in checkpoint['disc_loss']]
    gen_loss = [float(v) for v in checkpoint['gen_loss']]
else:
    start_len = 1
    start_iter = 1
    disc_loss = []
    gen_loss = []

for seq_len in range(start_len, MAX_LEN + 1):
    logging.info("---------- Adversarial Training with Seq Len %d, Batch Size %d ----------\n" 
                 % (seq_len, BATCH_SIZE))
    
    for i in range(start_iter, ITERS_PER_SEQ_LEN + 1):

        if i % SAVE_CHECKPOINTS_EVERY == 0:
            torch.save({
                'seq_len': seq_len,
                'gen_state_dict': g.state_dict(),
                'disc_state_dict': d.state_dict(),
                'gen_optimizer_state_dict': opt_g.state_dict(),
                'disc_optimizer_state_dict': opt_d.state_dict(),
                'iter': i,
                'gen_loss': gen_loss,
                'disc_loss': disc_loss
                }, CHECKPOINT_PATH)
            logging.info("  *** Model Saved ***\n")
        
        logging.debug("----------------- %d / %d -----------------\n" % (i, ITERS_PER_SEQ_LEN))

        # ---- Critic (discriminator) updates with the WGAN-GP loss ----
        logging.debug("Training discriminator...\n")
        
        d.requiresGrad()
        for j in range(CRITIC_ITERS):
            # Presumably disabled because cuDNN RNN kernels do not support the
            # double backward that the gradient penalty (create_graph=True) needs.
            with torch.backends.cudnn.flags(enabled=False):
                # Zero on every critic step. Zeroing only once per outer
                # iteration let gradients accumulate across CRITIC_ITERS steps.
                d.zero_grad()
                L = 0
                
                data = next(batch_gen)
                pred = g(data, seq_len)
                real, fake = get_train_dis(data, pred, seq_len)
                interpolate = get_interpolate(real, fake)

                # Genuine
                disc_real = d(real, seq_len)
                loss_real = -disc_real.mean()
                logging.debug("real loss: "+str(loss_real.item()))
                L += loss_real

                # Fake
                disc_fake = d(fake, seq_len)
                loss_fake = disc_fake.mean()
                logging.debug("fake loss: "+str(loss_fake.item()))
                L += loss_fake

                # Gradient penalty, evaluated on a fresh leaf copy of the
                # interpolates (what the deprecated torch.autograd.Variable
                # wrapper produced).
                interpolate = interpolate.detach().requires_grad_(True)
                disc_interpolate = d(interpolate, seq_len)
                gradients = torch.autograd.grad(
                        outputs=disc_interpolate,
                        inputs=interpolate,
                        grad_outputs=torch.ones_like(disc_interpolate),
                        create_graph=True,
                        retain_graph=True,
                    )[0]
                # NOTE(review): the norm is taken over dim 2 only. If interpolate
                # is (batch, seq, vocab), this penalises per-position norms, not
                # one norm per sample. Confirm this is intended.
                loss_gp = ((gradients.norm(2, dim=2) - 1) ** 2).mean() * LAMBDA
                logging.debug("grad loss: "+str(loss_gp.item()))
                L += loss_gp

                L.backward(retain_graph=False)
                opt_d.step()
                
                logging.debug("Critic Iter " + str(j+1) + " Loss: " + str(L.item()) + "\n")


        logging.debug("Done training discriminator.\n")    

        # ---- Generator updates ----
        logging.debug("Training generator...")

        d.requiresNoGrad()

        for j in range(GEN_ITERS):
            # Clear g's gradients before each step. The critic loss above was
            # backpropagated through g's forward pass (unless get_train_dis
            # detaches it) and would otherwise leak into this update, as would
            # the previous generator step's gradients.
            g.zero_grad()
            data = next(batch_gen)
            pred = g(data, seq_len)
            fake = get_train_gen(data, pred, seq_len)
            loss_gen = -d(fake, seq_len).mean()
            logging.debug("Gen Iter " + str(j+1) + " Loss: "+str(loss_gen.item()))
            loss_gen.backward(retain_graph=False)
            opt_g.step()

        logging.debug("Done training generator.\n")

        # Record plain floats. Appending the loss tensors themselves kept
        # references to their autograd graphs alive for the whole run and
        # pickled them into every checkpoint.
        if i % SAVE_CHECKPOINTS_EVERY == 0:
            disc_loss.append(L.item())
            gen_loss.append(loss_gen.item())
    
    # Only the first (resumed) sequence length starts mid-way; later ones start at 1.
    start_iter = 1
    

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

%matplotlib inline

logger.setLevel(logging.INFO)

# Plot the adversarial-training loss curves (`disc_loss`, `gen_loss`) from the
# training cell above. This cell previously plotted `loss_trend`, which is only
# created by the generator pre-training cell further down. It therefore raised
# NameError on a fresh kernel (Restart & Run All).
# float() accepts both plain floats and 1-element loss tensors.
fig, ax = plt.subplots()
ax.plot([float(v) for v in disc_loss], label='discriminator')
ax.plot([float(v) for v in gen_loss], label='generator')
ax.set(title='Adversarial training loss', xlabel='recorded checkpoint', ylabel='loss')
ax.legend()
# fig.savefig('image/rnn-rnn-loss.png', dpi=400)
plt.show()

In [None]:
# Pre-train the generator on password batches drawn from `batch_gen`,
# recording the loss every `print_every` iterations in `loss_trend`.
# NOTE(review): this calls g.pre_train(input_tensor, target_tensor) and expects
# (output, loss) back, while a later cell calls g.pre_train(p). One of the two
# call sites is presumably stale; confirm against model.Generator.
print_every = 100
loss_trend = []
logging.info("---------- Pre-training generator ----------")
for i in range(PRE_GEN_ITERS):
    pas = next(batch_gen)
    input_tensor, target_tensor = (
        P.passwordToInputTensor(pas).to(device),
        P.passwordToTargetTensor(pas).to(device),
    )
    output, loss = g.pre_train(input_tensor, target_tensor)

    if i % print_every != 0:
        continue
    loss_trend.append(loss)
    logging.debug("Iter: "+ str(i)+" Loss: "+str(loss))

In [None]:
# Scratch check of the LogSoftmax + NLLLoss pairing (the example from the
# PyTorch NLLLoss docs). Names carry an `nll_` prefix so this cell neither
# shadows the builtin `input()` nor clobbers the `loss` / `output` names that
# other cells in this notebook use.
nll_log_softmax = nn.LogSoftmax(dim=1)
nll_criterion = nn.NLLLoss()
# input is of size N x C = 3 x 5
nll_input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
nll_target = torch.tensor([1, 0, 4])
nll_output = nll_criterion(nll_log_softmax(nll_input), nll_target)

In [18]:
# Scratch: how a padded batch (2 sequences, 3 steps, 5 features; lengths 2 and 3)
# is packed and fed to a single-layer GRU.
gru = nn.GRU(5, 6, 1, batch_first=True, dropout=0).to(device)
padded = torch.ones(2, 3, 5).to(device)
print(padded.dim())
lengths = [2, 3]
packed = nn.utils.rnn.pack_padded_sequence(
    padded, lengths, batch_first=True, enforce_sorted=False
)
packed_data, batch_sizes, sorted_indices, unsorted_indices = packed
output = gru(packed)
batch_sizes, packed_data.dim()

3


(tensor([2, 2, 1]), 2)

In [6]:
# Restore only the generator's weights from the saved checkpoint.
# torch.load unpickles the file: only load checkpoints you wrote yourself.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
gen_state = checkpoint['gen_state_dict']
g.load_state_dict(gen_state)

<All keys matched successfully>

In [6]:
# Run the Generator's pre_train method (defined in model.py) on the dataset `p`.
# NOTE(review): the earlier pre-training cell calls g.pre_train(input_tensor,
# target_tensor); the two call signatures disagree. Confirm which matches model.py.
g.pre_train(p)

In [7]:
# Run the Discriminator's pre_train method (defined in model.py) with the
# dataset `p` and generator `g`.
# NOTE(review): the saved output of this cell is
# "AttributeError: 'float' object has no attribute 'backward'". A plain float is
# presumably being used as a loss inside d.pre_train. Fix it in model.py.
d.pre_train(p, g)

AttributeError: 'float' object has no attribute 'backward'

In [7]:
# Sample password strings from the generator and print the resulting list.
generated = g.generate_N(p)
print(generated)

['gcccccccccccccccccc', '0202020202020202020', '1202020202020202020', 'epes8', 'ia10202020202020202', 's8alalalalalalalala', 'piiiiiiiiiiiiiiiiii', 'bypepes8', '4iia020202020202020', 'piiamiamiamiamames8', 'es8', '1202020202020202020', 'bypepees8', '0020200200200200200', 'piiiames8', 'kia0000000000000000', '9720202020202020202', '1202020202020202020', '0202020202020202020', 'piiames8']


In [6]:
# Scratch: mean of a 2x2 float tensor created directly on `device` (expected 2.5).
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=device)
torch.mean(a)

tensor(2.5000, device='cuda:0')