# In [17]:
import torch
import time
import numpy as np

# In [35]:
class ModelTrain():
    """Training loops for the Seq2Seq, GAN and VAE imputation models.

    Each loop reads ``input_size``, ``batch_size``, ``device`` and ``column``
    from ``config`` (the GAN loop also reads ``lag_size``) and fetches batches
    through ``helper.gen_real_batch`` / ``helper.gen_z_input``.  During the
    final epoch, time step 0 of every batch is collected and returned so the
    caller can compare real and generated values after training.
    """

    def __init__(self, config):
        # config: object exposing the attributes listed in the class docstring.
        self.config = config

    def _empty(self, width):
        """Return an empty ``(0, width)`` float32 tensor on ``config.device`` (concatenation seed)."""
        return torch.empty((0, width), dtype=torch.float32, device=self.config.device)

    def train_Seq2Seq(self, model, optimizer_model, loss_function, real_train, missing_train, mask_train,
                      step_per_epoch, helper, epochs=250):
        """Train a sequence-to-sequence imputation model.

        Args:
            model: module mapping a ``z_input`` batch to a tensor indexed like ``real_data``.
            optimizer_model: optimizer over ``model``'s parameters.
            loss_function: criterion called as ``loss_function(generated, real)``.
            real_train, missing_train, mask_train: datasets handed to ``helper``.
            step_per_epoch: number of batches per epoch (must be > 0).
            helper: object providing ``gen_real_batch`` and ``gen_z_input``.
            epochs: number of training epochs.

        Returns:
            ``(real_dataset, gen_dataset, loss_evol, mask)``: time step 0 of every
            batch from the final epoch (generated values detached from the graph),
            and the mean batch loss of each epoch.
        """
        cfg = self.config
        col = cfg.column
        # Set training mode explicitly: the ModelTest methods call model.eval() and leave it there.
        model.train()

        loss_evol = []
        real_rows = [self._empty(cfg.input_size)]
        gen_rows = [self._empty(cfg.input_size)]
        mask_rows = [self._empty(cfg.input_size)]
        start = time.time()
        for i in range(epochs):
            last_epoch = i == epochs - 1
            epoch_loss = 0.0

            for step in range(step_per_epoch):
                real_data, _ = helper.gen_real_batch(cfg.batch_size, step, real_train)
                z_input, mask_input = helper.gen_z_input(cfg.batch_size, step, missing_train, mask_train)

                model.zero_grad()
                generated_samples = model(z_input)
                gen_col = generated_samples[:, :, col]

                # Quadratic penalty on negative values in the target column (zero elsewhere).
                penalty = torch.mean(torch.where(gen_col < 0, gen_col ** 2,
                                                 torch.tensor(0.0, device=generated_samples.device)))
                L2 = torch.sum((real_data[:, :, col] - gen_col) ** 2)
                g_loss = loss_function(generated_samples, real_data) + L2 + penalty
                g_loss.backward()
                optimizer_model.step()
                epoch_loss += g_loss.item()

                if last_epoch:
                    real_rows.append(real_data[:, 0, :])
                    # detach(): otherwise every collected batch keeps its autograd graph alive.
                    gen_rows.append(generated_samples[:, 0, :].detach())
                    mask_rows.append(mask_input[:, 0, :])

            epoch_loss /= step_per_epoch  # average loss per batch
            loss_evol.append(epoch_loss)

            print(f"epoch: {i}, Train loss: {epoch_loss:.7f}")

        end = time.time()
        print(f"Time: {(end - start) / 60} minutes")

        # One concatenation at the end instead of one per batch (avoids quadratic copying).
        real_dataset = torch.cat(real_rows, dim=0)
        gen_dataset = torch.cat(gen_rows, dim=0)
        mask = torch.cat(mask_rows, dim=0)
        return real_dataset, gen_dataset, loss_evol, mask

    def train_Gan(self, generator, discriminator, optimizer_discriminator, optimizer_generator, loss_function,
                  loss_function_MSE, real_train, missing_train, mask_train, step_per_epoch, helper, epochs=301):
        """Adversarially train a generator against a critic, plus an MSE term on ``config.column``.

        Args:
            generator, discriminator: modules; both are fed the first 7 feature columns.
            optimizer_discriminator, optimizer_generator: optimizers for the two modules.
            loss_function: unused; kept for interface compatibility (the BCE formulation is disabled).
            loss_function_MSE: criterion called as ``loss_function_MSE(real_col, generated_col)``.
            real_train, missing_train, mask_train: datasets handed to ``helper``.
            step_per_epoch: number of batches per epoch.
            helper: object providing ``gen_real_batch`` and ``gen_z_input``.
            epochs: number of training epochs.

        Returns:
            ``(real_dataset, gen_dataset, mask, errors_generator, errors_discriminator)``:
            time step 0 of every batch from the final epoch (generated values detached),
            and the per-batch generator / critic losses over all epochs.
        """
        cfg = self.config
        col = cfg.column
        # NOTE(review): hard-coded model width; presumably equals config.input_size — confirm.
        n_feat = 7
        generator.train()
        discriminator.train()

        errors_discriminator = []
        errors_generator = []
        real_rows = [self._empty(cfg.input_size + 1)]
        gen_rows = [self._empty(cfg.input_size)]
        mask_rows = [self._empty(cfg.input_size + 1)]
        start = time.time()

        for i in range(epochs):
            last_epoch = i == epochs - 1
            for step in range(step_per_epoch):
                real_data, _ = helper.gen_real_batch(cfg.batch_size, step, real_train)
                z_input, mask_input = helper.gen_z_input(cfg.batch_size, step, missing_train, mask_train)

                random_noise = torch.as_tensor(
                    np.random.randn(cfg.batch_size, cfg.lag_size, cfg.input_size + 1),
                    dtype=torch.float32, device=cfg.device)
                # Out-of-place add: if gen_z_input returns a view into missing_train, an in-place
                # `+=` would write the noise back into the training data, accumulating every epoch.
                z_input = z_input + random_noise
                gen_input = z_input[:, :, 0:n_feat]

                # --- Critic step ---
                fake_data = generator(gen_input)
                discriminator.zero_grad()
                output_discriminator_real = discriminator(real_data[:, :, 0:n_feat])
                # detach(): the critic update does not need gradients through the generator.
                output_discriminator_fake = discriminator(fake_data.detach())
                # Critic loss: raise scores on real data, lower them on generated data.
                # No weight clipping or gradient penalty is applied.
                d_loss = -torch.sum(output_discriminator_real) + torch.sum(output_discriminator_fake)
                d_loss.backward()
                optimizer_discriminator.step()

                # --- Generator step ---
                generator.zero_grad()
                generated_samples = generator(gen_input)
                output_discriminator_generated = discriminator(generated_samples)
                L2 = loss_function_MSE(real_data[:, :, col], generated_samples[:, :, col])
                # The critic scores real data high, so the generator maximises the critic's score
                # on its samples: the adversarial term is subtracted.
                g_loss = L2 - torch.sum(output_discriminator_generated)
                g_loss.backward()
                optimizer_generator.step()
                errors_discriminator.append(d_loss.item())
                errors_generator.append(g_loss.item())

                if last_epoch:
                    real_rows.append(real_data[:, 0, :])
                    gen_rows.append(generated_samples[:, 0, :].detach())
                    mask_rows.append(mask_input[:, 0, :])

                if step == step_per_epoch - 1:
                    print(f"Epoch: {i} Loss D.: {d_loss.item()} Loss G.: {g_loss.item()}")

        end = time.time()
        print(f"Time: {(end - start) / 60} minutes")

        real_dataset = torch.cat(real_rows, dim=0)
        gen_dataset = torch.cat(gen_rows, dim=0)
        mask = torch.cat(mask_rows, dim=0)
        return real_dataset, gen_dataset, mask, errors_generator, errors_discriminator

    def train_Vae(self, model, optimizer_model, loss_function, real_train, missing_train, mask_train,
                  step_per_epoch, helper, epochs=300):
        """Train a VAE-style model returning ``(recon_x, mean, logvar)``.

        The loss is ``loss_function`` on ``config.column`` + a weighted KL term
        + the summed squared error on ``config.column``.

        Args:
            model: module called as ``model(z_input) -> (recon_x, mean, logvar)``.
            optimizer_model: optimizer over ``model``'s parameters.
            loss_function: reconstruction criterion called as ``loss_function(recon_col, real_col)``.
            real_train, missing_train, mask_train: datasets handed to ``helper``.
            step_per_epoch: number of batches per epoch (must be > 0).
            helper: object providing ``gen_real_batch`` and ``gen_z_input``.
            epochs: number of training epochs.

        Returns:
            ``(real_dataset, gen_dataset, loss_evol, mask)`` as in ``train_Seq2Seq``.
        """
        cfg = self.config
        col = cfg.column
        model.train()

        loss_evol = []
        real_rows = [self._empty(cfg.input_size)]
        gen_rows = [self._empty(cfg.input_size)]
        mask_rows = [self._empty(cfg.input_size)]
        start = time.time()
        for i in range(epochs):
            last_epoch = i == epochs - 1
            epoch_loss = 0.0

            for step in range(step_per_epoch):
                real_data, _ = helper.gen_real_batch(cfg.batch_size, step, real_train)
                z_input, mask_input = helper.gen_z_input(cfg.batch_size, step, missing_train, mask_train)

                model.zero_grad()
                recon_x, mean, logvar = model(z_input)
                real_col = real_data[:, :, col]
                recon_col = recon_x[:, :, col]

                L2 = torch.sum((real_col - recon_col) ** 2)
                BCE_loss = loss_function(recon_col, real_col)
                # Closed-form KL(N(mean, exp(logvar)) || N(0, I)), assuming logvar is the
                # log-variance, scaled by a fixed weight of 0.00025.
                KLD_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())) * 0.00025
                g_loss = BCE_loss + KLD_loss + L2
                g_loss.backward()
                optimizer_model.step()
                epoch_loss += g_loss.item()

                if last_epoch:
                    real_rows.append(real_data[:, 0, :])
                    gen_rows.append(recon_x[:, 0, :].detach())
                    mask_rows.append(mask_input[:, 0, :])

            epoch_loss /= step_per_epoch  # average loss per batch
            loss_evol.append(epoch_loss)

            print(f"epoch: {i}, Train loss: {epoch_loss:.7f}")

        end = time.time()
        print(f"Time: {(end - start) / 60} minutes")

        real_dataset = torch.cat(real_rows, dim=0)
        gen_dataset = torch.cat(gen_rows, dim=0)
        mask = torch.cat(mask_rows, dim=0)
        return real_dataset, gen_dataset, loss_evol, mask

# In [37]:
class ModelTest():
    """Evaluation loops for the Seq2Seq, GAN and VAE imputation models.

    Every method puts ``model`` in eval mode (and leaves it there), runs
    ``step_per_epoch`` batches under ``torch.no_grad()`` and returns
    ``(real_dataset, gen_dataset, test_loss, mask)``: time step 0 of every
    batch for the real data, the model output and the mask, plus one loss
    value per batch.
    """

    def __init__(self, config):
        # config must expose input_size, batch_size, device and column.
        self.config = config

    def _evaluate(self, model, real_data_test, missing_data_test, mask_test, step_per_epoch, helper,
                  real_width, gen_width, step_fn):
        """Run the shared no-grad evaluation loop.

        Args:
            real_width: column count of the collected real-data and mask rows.
            gen_width: column count of the collected model-output rows.
            step_fn: callable ``(real_data, z_input) -> (output, loss)`` for one batch.
        """
        cfg = self.config
        model.eval()
        test_loss = []

        def seed(width):
            # Empty (0, width) tensor so an empty run still returns correctly shaped results.
            return torch.empty((0, width), dtype=torch.float32, device=cfg.device)

        real_rows = [seed(real_width)]
        gen_rows = [seed(gen_width)]
        mask_rows = [seed(real_width)]
        with torch.no_grad():
            for step in range(step_per_epoch):
                real_data, _ = helper.gen_real_batch(cfg.batch_size, step, real_data_test)
                z_input, mask_input = helper.gen_z_input(cfg.batch_size, step, missing_data_test, mask_test)

                output, loss = step_fn(real_data, z_input)
                test_loss.append(loss.item())

                real_rows.append(real_data[:, 0, :])
                gen_rows.append(output[:, 0, :])
                mask_rows.append(mask_input[:, 0, :])

            # A single concatenation instead of one per batch (avoids quadratic copying).
            return (torch.cat(real_rows, dim=0), torch.cat(gen_rows, dim=0), test_loss,
                    torch.cat(mask_rows, dim=0))

    def test_model(self, model, real_data_test, missing_data_test, mask_test, loss_function, step_per_epoch, helper):
        """Evaluate a Seq2Seq model.

        Per-batch loss = ``loss_function(output, real)`` on all features plus the
        summed squared error on ``config.column``.
        """
        col = self.config.column

        def step_fn(real_data, z_input):
            imputed_results = model(z_input)
            L2 = torch.sum((real_data[:, :, col] - imputed_results[:, :, col]) ** 2)
            return imputed_results, loss_function(imputed_results, real_data) + L2

        width = self.config.input_size
        return self._evaluate(model, real_data_test, missing_data_test, mask_test, step_per_epoch, helper,
                              width, width, step_fn)

    def test_gan(self, model, real_data_test, missing_data_test, mask_test, loss_function, step_per_epoch, helper):
        """Evaluate a GAN generator fed the first 7 input columns.

        Real data and mask rows are ``input_size + 1`` wide; generator rows are
        ``input_size`` wide.  Per-batch loss = ``loss_function`` on the first 7
        columns plus the summed squared error on ``config.column``.
        """
        col = self.config.column
        # NOTE(review): hard-coded generator width; presumably equals config.input_size — confirm.
        n_feat = 7

        def step_fn(real_data, z_input):
            imputed_results = model(z_input[:, :, 0:n_feat])
            L2 = torch.sum((real_data[:, :, col] - imputed_results[:, :, col]) ** 2)
            loss = loss_function(imputed_results[:, :, 0:n_feat], real_data[:, :, 0:n_feat]) + L2
            return imputed_results, loss

        width = self.config.input_size
        return self._evaluate(model, real_data_test, missing_data_test, mask_test, step_per_epoch, helper,
                              width + 1, width, step_fn)

    def test_Vae(self, model, real_data_test, missing_data_test, mask_test, loss_function, step_per_epoch, helper):
        """Evaluate a VAE-style model returning ``(recon_x, mean, logvar)``.

        Per-batch loss = ``loss_function(recon_x, real)`` + weighted KL term +
        summed squared error on ``config.column``.
        """
        col = self.config.column

        def step_fn(real_data, z_input):
            recon_x, mean, logvar = model(z_input)
            L2 = torch.sum((real_data[:, :, col] - recon_x[:, :, col]) ** 2)
            # NOTE(review): the criterion covers all features here, while ModelTrain.train_Vae
            # applies it to config.column only — confirm which is intended.
            BCE_loss = loss_function(recon_x, real_data)
            # Closed-form KL(N(mean, exp(logvar)) || N(0, I)), assuming logvar is the
            # log-variance, scaled by a fixed weight of 0.00025.
            KLD_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())) * 0.00025
            return recon_x, BCE_loss + KLD_loss + L2

        width = self.config.input_size
        return self._evaluate(model, real_data_test, missing_data_test, mask_test, step_per_epoch, helper,
                              width, width, step_fn)
        