In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as mlt
import seaborn as sp
from torch.autograd import Variable
from torch import autograd
from datetime import datetime
import matplotlib.pyplot as plt
import argparse
from datetime import timedelta
import torch.autograd.functional as F

In [1]:
class ModelHelper():
    """Batch-slicing, label and gradient-penalty helpers for generator/discriminator training.

    The helper reads ``device``, ``lag_size`` and ``input_size`` from the
    ``config`` object it is constructed with.
    """

    def __init__(self, config):
        """Keep a reference to ``config``; its attributes are read on demand."""
        self.config = config

    @staticmethod
    def _batch_slice(batch_size, step):
        """Return the slice selecting the rows of mini-batch number ``step``."""
        return slice(step * batch_size, (step + 1) * batch_size)

    def gen_label(self, size, is_real=True, noise_ratio=0.1):
        """Build a constant label tensor of shape ``(size, config.lag_size, 1)``.

        Args:
            size: number of rows (first dimension) of the label tensor.
            is_real: truthy -> tensor of ones, falsy -> tensor of zeros.
            noise_ratio: accepted for interface compatibility; it is not used
                by this implementation.

        Returns:
            The label tensor, placed on ``config.device``.
        """
        fill_value = 1.0 if is_real else 0.0
        return torch.full(
            (size, self.config.lag_size, 1),
            fill_value,
            device=self.config.device,
        )

    def gen_z_input(self, batch_size, step, dset, dset_mask):
        """Return ``[dset_batch, mask_batch]`` for mini-batch ``step``.

        Both ``dset`` and ``dset_mask`` are sliced with the same row range.
        """
        rows = self._batch_slice(batch_size, step)
        return [dset[rows], dset_mask[rows]]

    def gen_fake_batch(self, generator, batch_size, step, dset, dset_mask):
        """Run ``generator.predict`` on mini-batch ``step`` and pair it with zero labels.

        Returns:
            ``(fake_dset, fake_label)`` where ``fake_dset`` is whatever
            ``generator.predict`` returns for ``[dset_batch, mask_batch]``.
        """
        z = self.gen_z_input(batch_size, step, dset, dset_mask)
        fake_dset = generator.predict(z)
        # Size the labels from the rows actually sliced: the last batch of a
        # dataset can hold fewer than ``batch_size`` rows.
        fake_label = self.gen_label(len(z[0]), is_real=False)
        return fake_dset, fake_label

    def gen_real_batch(self, batch_size, step, dset):
        """Slice mini-batch ``step`` out of ``dset`` and pair it with one labels.

        Returns:
            ``(real_dset, real_label)``; the label has as many rows as the
            slice, which may be fewer than ``batch_size`` for the last batch.
        """
        real_dset = dset[self._batch_slice(batch_size, step)]
        real_label = self.gen_label(len(real_dset), is_real=True)
        return real_dset, real_label

    def gen_random_batch(self, batch_size, step, dset):
        """Return the rows of ``dset`` belonging to mini-batch ``step``."""
        return dset[self._batch_slice(batch_size, step)]

    def calculate_gradient_penalty(self, discriminator, real_data, fake_data, penalty_weight=10):
        """Compute a gradient penalty on random interpolations of real and fake data.

        The penalty is ``penalty_weight * mean((||grad||_2 - 1) ** 2)``, where
        the gradient of the discriminator output is taken with respect to the
        interpolated input and the norm is computed per sample over all
        remaining dimensions.

        Args:
            discriminator: callable applied to the interpolated tensor.
            real_data: tensor whose first dimension is the batch dimension;
                must broadcast against ``(batch, lag_size, input_size)``.
            fake_data: tensor broadcastable with ``real_data``.
            penalty_weight: multiplier applied to the mean penalty
                (defaults to the previously hard-coded ``10``).

        Returns:
            A scalar tensor that remains differentiable (``create_graph=True``).
        """
        # Use the batch that was actually passed in rather than
        # ``config.batch_size`` so a shorter final batch does not break the
        # broadcast or the reshape below.
        n_samples = real_data.size(0)

        # Sampled on the CPU and then moved, so the CPU random stream is used
        # regardless of the target device.
        # NOTE(review): the mixing coefficient is drawn per element, not once
        # per sample -- confirm this is intended.
        eta = torch.rand(
            n_samples,
            self.config.lag_size,
            self.config.input_size,
            dtype=torch.float32,
        ).to(self.config.device)

        interpolated = eta * real_data + (1 - eta) * fake_data

        # Make the interpolation a fresh leaf that tracks gradients; this cuts
        # it off from whatever graph produced ``real_data``/``fake_data``.
        interpolated = interpolated.detach().requires_grad_(True)

        # discriminator output for the interpolated examples
        prob_interpolated = discriminator(interpolated)

        # gradients of the discriminator output with respect to the examples
        gradients = autograd.grad(
            outputs=prob_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(prob_interpolated),
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients = gradients.reshape(gradients.size(0), -1)

        # The small epsilon keeps sqrt differentiable when the gradient is zero.
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
        return ((gradients_norm - 1) ** 2).mean() * penalty_weight
