In [33]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as mlt
import seaborn as sp
from torch.autograd import Variable
from torch import autograd
from datetime import datetime
import matplotlib.pyplot as plt
import argparse
from datetime import timedelta
import torch.autograd.functional as F

In [51]:
def gen_label(size, is_real=True, noise_ratio=0.1):
    """Build a column of constant discriminator targets.

    Args:
        size: Number of rows in the returned label tensor.
        is_real: If truthy every entry is 1, otherwise every entry is 0.
        noise_ratio: Accepted for interface compatibility; not used here.

    Returns:
        A ``(size, 1)`` tensor moved to the module-level ``device``.
    """
    # Pick the constructor once instead of branching around two assignments.
    constructor = torch.ones if is_real else torch.zeros
    return constructor(size, 1).to(device)

def gen_z_input(batch_size, step, dset, dset_mask):
    """Return the ``step``-th mini-batch of a dataset together with its mask.

    Args:
        batch_size: Number of leading-axis entries per batch.
        step: Zero-based batch index.
        dset: Sliceable container holding the inputs.
        dset_mask: Sliceable container holding the matching masks.

    Returns:
        A two-element list ``[dset_batch, mask_batch]``; both entries are
        taken with the same ``[step * batch_size, (step + 1) * batch_size)``
        window.
    """
    lo = step * batch_size
    window = slice(lo, lo + batch_size)
    return [dset[window], dset_mask[window]]


def gen_fake_batch(generator, batch_size, step, dset, dset_mask):
    """Produce one batch of generated samples with all-zero labels.

    Args:
        generator: Object exposing ``predict``; it receives the
            ``[data_batch, mask_batch]`` list built by ``gen_z_input``.
        batch_size: Number of entries per batch.
        step: Zero-based batch index.
        dset: Sliceable container holding the generator inputs.
        dset_mask: Sliceable container holding the matching masks.

    Returns:
        ``(fake_dset, fake_label)`` where ``fake_label`` comes from
        ``gen_label(batch_size, is_real=False)``.
    """
    generator_input = gen_z_input(batch_size, step, dset, dset_mask)
    return generator.predict(generator_input), gen_label(batch_size, is_real=False)


def gen_real_batch(batch_size, step, dset):
    """Slice the ``step``-th batch of real data and pair it with all-one labels.

    Args:
        batch_size: Number of entries per batch.
        step: Zero-based batch index.
        dset: Sliceable container holding the real samples.

    Returns:
        ``(real_dset, real_label)``. The label tensor is always built for
        ``batch_size`` rows, regardless of how many entries the slice holds.
    """
    start = step * batch_size
    batch = dset[start:start + batch_size]
    return batch, gen_label(batch_size, is_real=True)

def gen_random_batch(batch_size, step, dset):
    """Return the ``step``-th ``batch_size``-long slice of ``dset``.

    Args:
        batch_size: Number of entries per batch.
        step: Zero-based batch index.
        dset: Sliceable container to take the batch from.

    Returns:
        ``dset[step * batch_size : (step + 1) * batch_size]``.
    """
    first = step * batch_size
    return dset[first:first + batch_size]
    
def calculate_gradient_penalty(discriminator, real_data, fake_data, gp_weight=None):
    """Gradient penalty on random interpolations between real and fake samples.

    For every sample a single coefficient ``eta ~ U[0, 1)`` is drawn and the
    point ``eta * real + (1 - eta) * fake`` is fed to ``discriminator``. The
    penalty is the mean of ``(||grad||_2 - 1) ** 2`` over the batch, scaled
    by ``gp_weight``.

    Args:
        discriminator: Callable mapping a batch to one score tensor.
        real_data: Batch of real samples; its first axis is the batch axis.
        fake_data: Batch of generated samples, broadcast-compatible with
            ``real_data``.
        gp_weight: Multiplier applied to the penalty. When ``None`` the
            module-level ``lambda_term`` is used.

    Returns:
        A scalar tensor that remains differentiable with respect to the
        discriminator parameters (``create_graph=True``).
    """
    if gp_weight is None:
        gp_weight = lambda_term

    # Batch size, dtype and device are read from the data so that a batch of
    # any size (e.g. a short final batch) works.
    n_samples = real_data.size(0)

    # One coefficient per sample, broadcast over all remaining axes, so each
    # interpolated point lies on the straight line between its real/fake pair.
    eta_shape = (n_samples,) + (1,) * (real_data.dim() - 1)
    eta = torch.rand(eta_shape, dtype=real_data.dtype, device=real_data.device)

    # Leaf tensor: gradients are taken with respect to the interpolated inputs.
    interpolated = (eta * real_data + (1 - eta) * fake_data).detach().requires_grad_(True)

    # Discriminator scores of the interpolated examples.
    prob_interpolated = discriminator(interpolated)

    # Gradients of the scores with respect to the interpolated examples.
    gradients = autograd.grad(
        outputs=prob_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(prob_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(n_samples, -1)

    # The small epsilon keeps sqrt differentiable when the gradient is zero.
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
    return ((gradients_norm - 1) ** 2).mean() * gp_weight

In [58]:
def compute_mmd(real_samples, fake_samples, sigma=1.0):
    """Kernel discrepancy between two sample sets.

    Both inputs are flattened to ``(-1, last_dim)`` rows. The kernel is
    ``exp(-d / (2 * sigma ** 2))`` with ``d`` the (non-squared) Euclidean
    distance returned by ``torch.cdist``. The result is
    ``mean(K_rr) + mean(K_ff) - 2 * mean(K_rf)``.

    Args:
        real_samples: Tensor whose last axis is the feature axis.
        fake_samples: Tensor with the same last-axis size; the number of
            rows may differ from ``real_samples``.
        sigma: Kernel bandwidth.

    Returns:
        A scalar tensor.
    """
    # reshape (not view) so non-contiguous inputs are accepted as well.
    real_flat = real_samples.reshape(-1, real_samples.size(-1))
    fake_flat = fake_samples.reshape(-1, fake_samples.size(-1))
    bandwidth = 2 * sigma ** 2

    def _mean_kernel(a, b):
        # Average kernel value over every (row of a, row of b) pair.
        return torch.exp(-torch.cdist(a, b, p=2) / bandwidth).mean()

    # Scalar means let the three terms combine for any pair of row counts;
    # for equal row counts this matches averaging per-column means.
    return (
        _mean_kernel(real_flat, real_flat)
        + _mean_kernel(fake_flat, fake_flat)
        - 2 * _mean_kernel(real_flat, fake_flat)
    )

In [8]:
def train_Gan(generator, discriminator, optimizer_discriminator, optimizer_generator, loss_function, loss_function_MSE, real_train, missing_train, mask_train, step_per_epoch, epoch=500):
    """Alternately train a discriminator and a generator.

    Each step takes the ``step``-th batch of ``real_train`` and
    ``missing_train``, adds clipped Gaussian noise to the generator input,
    updates the discriminator on real vs. generated samples, then updates the
    generator on the adversarial loss plus ``loss_function_MSE`` against the
    real batch.

    Relies on the module-level ``batch_size``, ``lag_size``, ``input_size``
    and ``device``.

    Args:
        generator: Callable module mapping the noisy input batch to samples.
        discriminator: Callable module scoring a batch of samples.
        optimizer_discriminator: Optimizer stepping the discriminator.
        optimizer_generator: Optimizer stepping the generator.
        loss_function: Adversarial criterion ``(scores, labels) -> loss``.
        loss_function_MSE: Reconstruction criterion applied as
            ``loss_function_MSE(real_data, generated_samples)``.
        real_train: Sliceable real samples.
        missing_train: Sliceable generator inputs.
        mask_train: Sliceable masks aligned with ``missing_train``.
        step_per_epoch: Number of batches consumed per epoch.
        epoch: Number of passes over the data.

    Returns:
        ``(real_dataset, gen_dataset, errors_generator, errors_discriminator)``
        where the two datasets hold the real and generated batches of the
        last epoch concatenated along axis 0, and the two lists hold the
        per-step loss values.
    """
    errors_discriminator = []
    errors_generator = []

    # Seeding the lists with an empty tensor keeps the (0, lag, input) result
    # when nothing is collected; batches are concatenated once at the end.
    empty = torch.empty((0, lag_size, input_size), dtype=torch.float32).to(device)
    real_batches = [empty]
    gen_batches = [empty]

    for i in range(epoch):
        for step in range(step_per_epoch):
            # Data for training the discriminator
            real_data, real_label = gen_real_batch(batch_size, step, real_train)
            z_input, _ = gen_z_input(batch_size, step, missing_train, mask_train)

            # Gaussian noise clipped to [0, 1] (negative draws become 0).
            random_data = np.random.normal(loc=0, scale=1, size=(batch_size, lag_size, input_size))
            random_data = np.clip(random_data, 0, 1)
            random_data = torch.tensor(random_data, dtype=torch.float32).to(device)

            # Out-of-place add: z_input is a slice of missing_train, and for
            # tensors/arrays an in-place `+=` would write the noise back into
            # the dataset and accumulate it across steps and epochs.
            z_input = z_input + random_data

            # Detached: the discriminator update must not backpropagate into
            # the generator, and the stored samples must not keep the graph.
            fake_data = generator(z_input).detach()
            fake_label = gen_label(batch_size, is_real=False)

            discriminator.zero_grad()
            output_discriminator_real = discriminator(real_data)
            output_discriminator_fake = discriminator(fake_data)
            d_loss = loss_function(output_discriminator_real, real_label) + loss_function(output_discriminator_fake, fake_label)
            d_loss.backward()
            optimizer_discriminator.step()

            # Training the generator: generated samples are labelled as real.
            gan_fake_label = gen_label(batch_size, is_real=True)
            generator.zero_grad()
            generated_samples = generator(z_input)
            output_discriminator_generated = discriminator(generated_samples)
            L2 = loss_function_MSE(real_data, generated_samples)
            g_loss = loss_function(output_discriminator_generated, gan_fake_label) + (L2)
            g_loss.backward()
            optimizer_generator.step()
            errors_discriminator.append(d_loss.item())
            errors_generator.append(g_loss.item())

            if i == epoch - 1:
                real_batches.append(real_data)
                gen_batches.append(fake_data)

            # Show loss (the value printed after "mmd:" is the L2 term above)
            if step == step_per_epoch - 1:
                print(f"Epoch: {i} Loss D.: {d_loss.item()} Loss G.: {g_loss.item()} mmd: {L2}")

    real_dataset = torch.cat(real_batches, dim=0)
    gen_dataset = torch.cat(gen_batches, dim=0)
    return real_dataset, gen_dataset, errors_generator, errors_discriminator

In [1]:
def train_Seq2Seq(model, optimizer_model, loss_function, real_train, missing_train, mask_train, step_per_epoch, epoch=500):
    """Train a model to reconstruct the real batch from the incomplete batch.

    Each step feeds the ``step``-th batch of ``missing_train`` to ``model``
    and minimises ``loss_function(output, real) + sum((real - output) ** 2)``.

    Relies on the module-level ``batch_size``, ``lag_size``, ``input_size``
    and ``device``.

    Args:
        model: Callable module mapping an input batch to a reconstruction.
        optimizer_model: Optimizer stepping ``model``.
        loss_function: Criterion called as ``loss_function(output, real)``.
        real_train: Sliceable target samples.
        missing_train: Sliceable model inputs.
        mask_train: Sliceable masks aligned with ``missing_train``.
        step_per_epoch: Number of batches consumed per epoch.
        epoch: Number of passes over the data.

    Returns:
        ``(real_dataset, gen_dataset, errors_generator, mask)``: the real
        batches, model outputs and masks of the last epoch concatenated along
        axis 0, plus the list of per-step loss values.
    """
    errors_generator = []

    # Seeding the lists with an empty tensor keeps the (0, lag, input) result
    # when nothing is collected; batches are concatenated once at the end.
    empty = torch.empty((0, lag_size, input_size), dtype=torch.float32).to(device)
    real_batches = [empty]
    gen_batches = [empty]
    mask_batches = [empty]

    for i in range(epoch):
        for step in range(step_per_epoch):
            # The labels returned alongside the real batch are not needed here.
            real_data, _ = gen_real_batch(batch_size, step, real_train)
            z_input, mask_input = gen_z_input(batch_size, step, missing_train, mask_train)

            model.zero_grad()
            generated_samples = model(z_input)

            # Summed squared error on top of the supplied criterion.
            L2 = torch.sum((real_data - generated_samples) ** 2)
            g_loss = loss_function(generated_samples, real_data) + L2
            g_loss.backward()
            optimizer_model.step()
            errors_generator.append(g_loss.item())

            if i == epoch - 1:
                real_batches.append(real_data)
                # Detached so the returned tensor does not keep every batch's
                # autograd graph alive.
                gen_batches.append(generated_samples.detach())
                mask_batches.append(mask_input)

            # Show loss
            if step == step_per_epoch - 1:
                print(f"Epoch: {i} Loss G.: {g_loss.item()}")

    real_dataset = torch.cat(real_batches, dim=0)
    gen_dataset = torch.cat(gen_batches, dim=0)
    mask = torch.cat(mask_batches, dim=0)
    return real_dataset, gen_dataset, errors_generator, mask

In [16]:
def train_ConvGan(generator, discriminator, optimizer_discriminator, optimizer_generator, loss_function, real_train, missing_train, mask_train, step_per_epoch, random_data, epoch=1000):
    """Alternately train a discriminator and a two-input generator.

    Each step calls ``generator(real_data, random_noise)``, updates the
    discriminator on real vs. generated samples, then updates the generator
    on the adversarial loss plus the summed squared error against the real
    batch.

    Relies on the module-level ``batch_size``, ``lag_size``, ``input_size``
    and ``device``.

    Args:
        generator: Callable module taking ``(real_data, random_noise)``.
        discriminator: Callable module scoring a batch of samples.
        optimizer_discriminator: Optimizer stepping the discriminator.
        optimizer_generator: Optimizer stepping the generator.
        loss_function: Adversarial criterion ``(scores, labels) -> loss``.
        real_train: Sliceable real samples.
        missing_train: Sliceable incomplete samples (only sliced here, see
            the review note in the loop).
        mask_train: Sliceable masks; the last epoch's batches are returned.
        step_per_epoch: Number of batches consumed per epoch.
        random_data: Sliceable pre-generated noise, batched like the data.
        epoch: Number of passes over the data.

    Returns:
        ``(real_dataset, gen_dataset, errors_generator, errors_discriminator,
        mask_dataset)``: the last epoch's real batches, generated batches and
        masks concatenated along axis 0, plus the per-step loss lists.
    """
    errors_discriminator = []
    errors_generator = []

    # Seeding the lists with an empty tensor keeps the (0, lag, input) result
    # when nothing is collected; batches are concatenated once at the end.
    empty = torch.empty((0, lag_size, input_size), dtype=torch.float32).to(device)
    real_batches = [empty]
    gen_batches = [empty]
    mask_batches = [empty]

    for i in range(epoch):
        for step in range(step_per_epoch):
            # Data for training the discriminator
            real_data, real_label = gen_real_batch(batch_size, step, real_train)
            # NOTE(review): the incomplete batch is sliced but never fed to the
            # generator, which receives real_data instead - confirm intended.
            _, mask_input = gen_z_input(batch_size, step, missing_train, mask_train)
            random_noise = gen_random_batch(batch_size, step, random_data)

            # Detached: the discriminator update must not backpropagate into
            # the generator, and the stored samples must not keep the graph.
            fake_data = generator(real_data, random_noise).detach()
            fake_label = gen_label(batch_size, is_real=False)

            discriminator.zero_grad()
            output_discriminator_real = discriminator(real_data)
            output_discriminator_fake = discriminator(fake_data)
            d_loss = loss_function(output_discriminator_real, real_label) + loss_function(output_discriminator_fake, fake_label)
            d_loss.backward()
            optimizer_discriminator.step()

            # Training the generator: generated samples are scored against
            # the "real" labels.
            generator.zero_grad()
            generated_samples = generator(real_data, random_noise)
            output_discriminator_generated = discriminator(generated_samples)
            L2 = torch.sum((real_data - generated_samples) ** 2)
            g_loss = loss_function(output_discriminator_generated, real_label) + L2
            g_loss.backward()
            optimizer_generator.step()
            errors_discriminator.append(d_loss.item())
            errors_generator.append(g_loss.item())

            if i == epoch - 1:
                real_batches.append(real_data)
                gen_batches.append(fake_data)
                mask_batches.append(mask_input)

            # Show loss
            if step == step_per_epoch - 1:
                print(f"Epoch: {i} Loss D.: {d_loss.item()} Loss G.: {g_loss.item()}")

    real_dataset = torch.cat(real_batches, dim=0)
    gen_dataset = torch.cat(gen_batches, dim=0)
    mask_dataset = torch.cat(mask_batches, dim=0)
    return real_dataset, gen_dataset, errors_generator, errors_discriminator, mask_dataset

In [10]:
def train_autoEncoder(model, optimizer_model, loss_function, real_train, missing_train, mask_train, step_per_epoch, epoch=500):
    """Train a model that returns ``(reconstruction, mean, logvar)``.

    Each step feeds the ``step``-th batch of ``missing_train`` to ``model``
    and minimises ``loss_function(recon, real)`` plus a KL-style term
    ``-0.5 * sum(1 + logvar - mean ** 2 - exp(logvar))`` scaled by 0.00025.

    Relies on the module-level ``batch_size``, ``lag_size``, ``input_size``
    and ``device``.

    Args:
        model: Callable module returning ``(recon_x, mean, logvar)``.
        optimizer_model: Optimizer stepping ``model``.
        loss_function: Reconstruction criterion called as
            ``loss_function(recon_x, real_data)``.
        real_train: Sliceable target samples.
        missing_train: Sliceable model inputs.
        mask_train: Sliceable masks aligned with ``missing_train``.
        step_per_epoch: Number of batches consumed per epoch.
        epoch: Number of passes over the data.

    Returns:
        ``(real_dataset, gen_dataset, errors_generator, mask)``: the real
        batches, reconstructions and masks of the last epoch concatenated
        along axis 0, plus the list of per-step loss values.
    """
    errors_generator = []

    # Seeding the lists with an empty tensor keeps the (0, lag, input) result
    # when nothing is collected; batches are concatenated once at the end.
    empty = torch.empty((0, lag_size, input_size), dtype=torch.float32).to(device)
    real_batches = [empty]
    gen_batches = [empty]
    mask_batches = [empty]

    for i in range(epoch):
        for step in range(step_per_epoch):
            # The labels returned alongside the real batch are not needed here.
            real_data, _ = gen_real_batch(batch_size, step, real_train)
            z_input, mask_input = gen_z_input(batch_size, step, missing_train, mask_train)

            model.zero_grad()
            recon_x, mean, logvar = model(z_input)

            recon_loss = loss_function(recon_x, real_data)
            # The inner sum is already a scalar, so the mean leaves it as is.
            KLD_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())) * 0.00025
            g_loss = recon_loss + KLD_loss
            g_loss.backward()
            optimizer_model.step()
            errors_generator.append(g_loss.item())

            if i == epoch - 1:
                real_batches.append(real_data)
                # Detached so the returned tensor does not keep every batch's
                # autograd graph alive.
                gen_batches.append(recon_x.detach())
                mask_batches.append(mask_input)

            # Show loss
            if step == step_per_epoch - 1:
                print(f"Epoch: {i} Loss G.: {g_loss.item()}")

    real_dataset = torch.cat(real_batches, dim=0)
    gen_dataset = torch.cat(gen_batches, dim=0)
    mask = torch.cat(mask_batches, dim=0)
    return real_dataset, gen_dataset, errors_generator, mask