In [28]:
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

from data_loading import sine_data_generation
from utils import random_generator
from data_loading import MinMaxScaler

from torch.utils.data import DataLoader


from utils import extract_time


Define Class for Module Construction

In [3]:
class Time_GAN_module(nn.Module):
    """
    Class from which a module of the Time GAN Architecture can be constructed, 
    consisting of a n_layer stacked RNN layers and a fully connected layer
    
    input_size = dim of data (depending if module operates on latent or non-latent space)
    """
    def __init__(self, input_size, output_size, hidden_dim, n_layers, activation=torch.sigmoid):
        super(Time_GAN_module, self).__init__()

        # Parameters
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.sigma = activation

        #Defining the layers
        # RNN Layer
        self.rnn = nn.GRU(input_size, hidden_dim, n_layers, batch_first=True)   
        # Fully connected layer
        self.fc = nn.Linear(hidden_dim, output_size)
        
    def forward(self, x):
    
            batch_size = x.size(0)

            # Initializing hidden state for first input using method defined below
            hidden = self.init_hidden(batch_size)

            # Passing in the input and hidden state into the model and obtaining outputs
            out, hidden = self.rnn(x, hidden)
        
            # Reshaping the outputs such that it can be fit into the fully connected layer
            out = out.contiguous().view(-1, self.hidden_dim)
            out = self.fc(out)
            
            if self.sigma == nn.Identity:
                idendity = nn.Identity()
                return idendity(out)
                
            out = self.sigma(out)
            
            # HIDDEN STATES WERDEN IN DER PAPER IMPLEMENTIERUNG AUCH COMPUTED, ALLERDINGS NICHT BENUTZT?
            
            return out, hidden
    
    def init_hidden(self, batch_size):
        # This method generates the first hidden state of zeros which we'll use in the forward pass
        # We'll send the tensor holding the hidden state to the device we specified earlier as well
        hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim)
        return hidden

Parameters

In [6]:
input_size = 5
output_size = 20
hidden_dim = 20
n_layers = 3
gamma = 1

no, seq_len, dim = 12800, 24, 5 

batch_size = 128
epoch = 100

Data Generation

In [9]:
data = sine_data_generation(no, seq_len, dim)
data = MinMaxScaler(data)
data = torch.Tensor(data)
data.shape

torch.Size([12800, 24, 5])

Create Modules

In [10]:
Embedder = Time_GAN_module(input_size=dim, output_size=hidden_dim, hidden_dim=hidden_dim, n_layers=n_layers)
Embedder

Time_GAN_module(
  (rnn): GRU(5, 20, num_layers=3, batch_first=True)
  (fc): Linear(in_features=20, out_features=20, bias=True)
)

In [12]:
Recovery = Time_GAN_module(input_size=hidden_dim, output_size=dim, hidden_dim=hidden_dim, n_layers=n_layers)
Recovery

Time_GAN_module(
  (rnn): GRU(20, 20, num_layers=3, batch_first=True)
  (fc): Linear(in_features=20, out_features=5, bias=True)
)

In [13]:
Generator = Time_GAN_module(input_size=dim, output_size=hidden_dim, hidden_dim=hidden_dim, n_layers=n_layers)
Generator

Time_GAN_module(
  (rnn): GRU(5, 20, num_layers=3, batch_first=True)
  (fc): Linear(in_features=20, out_features=20, bias=True)
)

In [15]:
Supervisor = Time_GAN_module(input_size=hidden_dim, output_size=hidden_dim, hidden_dim=hidden_dim, n_layers=n_layers-1)
Supervisor

Time_GAN_module(
  (rnn): GRU(20, 20, num_layers=2, batch_first=True)
  (fc): Linear(in_features=20, out_features=20, bias=True)
)

In [17]:
Discriminator = Time_GAN_module(input_size=hidden_dim, output_size=1, hidden_dim=hidden_dim, n_layers=n_layers, 
                               activation=nn.Identity)
Discriminator

Time_GAN_module(
  (rnn): GRU(20, 20, num_layers=3, batch_first=True)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)

Create Optimizers

In [18]:
embedder_optimizer = optim.Adam(Embedder.parameters(), lr=0.001)
recovery_optimizer = optim.Adam(Recovery.parameters(), lr=0.001)
supervisor_optimizer = optim.Adam(Recovery.parameters(), lr=0.001)
discriminator_optimizer = optim.Adam(Discriminator.parameters(), lr=0.001)
generator_optimizer = optim.Adam(Generator.parameters(), lr=0.001)

Data Loader

In [20]:
loader = DataLoader(data, batch_size, shuffle=True)

Embedder Training

In [23]:
print('Start Embedding Network Training')

for e in range(epoch): 
    for batch_index, X in enumerate(loader):
        
        MSE_loss = nn.MSELoss()
        
        H, _ = Embedder(X.float())
        H = torch.reshape(H, (batch_size, seq_len, hidden_dim))

        X_tilde, _ = Recovery(H)
        X_tilde = torch.reshape(X_tilde, (batch_size, seq_len, dim))

        E_loss0 = 10 * torch.sqrt(MSE_loss(X, X_tilde))  

        Embedder.zero_grad()
        Recovery.zero_grad()

        E_loss0.backward(retain_graph=True)

        embedder_optimizer.step()
        recovery_optimizer.step()

        if e in range(1,epoch) and batch_index == 0:
            print('step: '+ str(e) + '/' + str(epoch) + ', e_loss: ' + str(np.sqrt(E_loss0.detach().numpy())))

print('Finish Embedding Network Training')

Start Embedding Network Training
step: 1/100, e_loss: 1.2113627
step: 2/100, e_loss: 1.1118678
step: 3/100, e_loss: 1.0656068
step: 4/100, e_loss: 1.0669308
step: 5/100, e_loss: 0.99132323
step: 6/100, e_loss: 1.0290759
step: 7/100, e_loss: 1.05712
step: 8/100, e_loss: 0.9705321
step: 9/100, e_loss: 0.95293725
step: 10/100, e_loss: 0.917935
step: 11/100, e_loss: 0.7142837
step: 12/100, e_loss: 0.632312
step: 13/100, e_loss: 0.57735693
step: 14/100, e_loss: 0.56844777
step: 15/100, e_loss: 0.5494405
step: 16/100, e_loss: 0.51849407
step: 17/100, e_loss: 0.5033815
step: 18/100, e_loss: 0.48689562
step: 19/100, e_loss: 0.49090096
step: 20/100, e_loss: 0.46919236
step: 21/100, e_loss: 0.4618927
step: 22/100, e_loss: 0.45426735
step: 23/100, e_loss: 0.45190057
step: 24/100, e_loss: 0.44268012
step: 25/100, e_loss: 0.42174116
step: 26/100, e_loss: 0.40710518
step: 27/100, e_loss: 0.41366136
step: 28/100, e_loss: 0.40219304
step: 29/100, e_loss: 0.39501354
step: 30/100, e_loss: 0.3933027
step

Training with supervised Loss

In [24]:
print('Start Training with Supervised Loss Only')

for e in range(epoch): 
    for batch_index, X in enumerate(loader):

        H, _ = Embedder(X.float())
        H = torch.reshape(H, (batch_size, seq_len, hidden_dim))

        H_hat_supervise, _ = Supervisor(H)
        H_hat_supervise = torch.reshape(H_hat_supervise, (batch_size, seq_len, hidden_dim))  

        G_loss_S = MSE_loss(H[:,1:,:], H_hat_supervise[:,:-1,:])


        Embedder.zero_grad()
        Supervisor.zero_grad()

        G_loss_S.backward(retain_graph=True)

        embedder_optimizer.step()
        supervisor_optimizer.step()

        if e in range(1,epoch) and batch_index == 0:
            print('step: '+ str(e) + '/' + str(epoch) + ', s_loss: ' + str(np.sqrt(G_loss_S.detach().numpy())))

print('Finish Training with Supervised Loss Only')

Start Training with Supervised Loss Only
step: 1/100, s_loss: 0.09964683
step: 2/100, s_loss: 0.09478158
step: 3/100, s_loss: 0.092493
step: 4/100, s_loss: 0.08582403
step: 5/100, s_loss: 0.083220385
step: 6/100, s_loss: 0.07979374
step: 7/100, s_loss: 0.074106425
step: 8/100, s_loss: 0.07030248
step: 9/100, s_loss: 0.072493605
step: 10/100, s_loss: 0.06692988
step: 11/100, s_loss: 0.06455012
step: 12/100, s_loss: 0.06289997
step: 13/100, s_loss: 0.059171397
step: 14/100, s_loss: 0.0561611
step: 15/100, s_loss: 0.053076327
step: 16/100, s_loss: 0.05145452
step: 17/100, s_loss: 0.047989562
step: 18/100, s_loss: 0.044613775
step: 19/100, s_loss: 0.042740956
step: 20/100, s_loss: 0.03997947
step: 21/100, s_loss: 0.038406666
step: 22/100, s_loss: 0.038452793
step: 23/100, s_loss: 0.035605773
step: 24/100, s_loss: 0.03442791
step: 25/100, s_loss: 0.032908514
step: 26/100, s_loss: 0.03171904
step: 27/100, s_loss: 0.029951302
step: 28/100, s_loss: 0.028834099
step: 29/100, s_loss: 0.028288601

In [26]:
epoch = 2

In [58]:
random_data = random_generator(batch_size=batch_size, z_dim=dim, 
                                       T_mb=extract_time(data)[0], max_seq_len=extract_time(data)[1])

In [91]:
loader = DataLoader(data, batch_size, shuffle=True)

random_loader = DataLoader(random_data, batch_size, shuffle=True)

binary_cross_entropy_loss = nn.BCEWithLogitsLoss()

MSE_loss = nn.MSELoss()



In [99]:
print('Start Joint Training')

for e in range(epoch): 

    for batch_index, X in enumerate(loader):
        
        random_data = random_generator(batch_size=batch_size, z_dim=dim, 
                                       T_mb=extract_time(data)[0], max_seq_len=extract_time(data)[1])
        
        
        # Generator Training 
        ## Train Generator
        z = torch.tensor(random_data)
        z = z.float()
        
        e_hat, _ = Generator(z)
        e_hat = torch.reshape(e_hat, (batch_size, seq_len, hidden_dim))
        
        H_hat, _ = Supervisor(e_hat)
        H_hat = torch.reshape(H_hat, (batch_size, seq_len, hidden_dim))
        
        Y_fake = Discriminator(H_hat)
        Y_fake = torch.reshape(Y_fake, (batch_size, seq_len, 1))
        
        x_hat, _ = Recovery(H_hat)
        x_hat = torch.reshape(x_hat, (batch_size, seq_len, dim))
        
        
        Generator.zero_grad()
        Supervisor.zero_grad()
        Discriminator.zero_grad()
        Recovery.zero_grad()
        
        G_loss_U = binary_cross_entropy_loss(torch.ones_like(Y_fake), Y_fake)
        
        
        
        G_loss_V1 = torch.mean(torch.abs((torch.std(x_hat, [0], unbiased = False)) + 1e-6 - (torch.std(X, [0]) + 1e-6)))
        G_loss_V2 = torch.mean(torch.abs((torch.mean(x_hat, [0]) - (torch.mean(X, [0])))))
        G_loss_V = G_loss_V1 + G_loss_V2
        
 
        G_loss_U.backward(retain_graph=True)
        G_loss_V.backward()


        generator_optimizer.step()
        supervisor_optimizer.step()
        discriminator_optimizer.step()
        
        ## Train Embedder
        
        MSE_loss = nn.MSELoss()
        
        H, _ = Embedder(X.float())
        H = torch.reshape(H, (batch_size, seq_len, hidden_dim))

        X_tilde, _ = Recovery(H)
        X_tilde = torch.reshape(X_tilde, (batch_size, seq_len, dim))

        E_loss0 = 10 * torch.sqrt(MSE_loss(X, X_tilde))  
        
        H_hat_supervise, _ = Supervisor(H)
        H_hat_supervise = torch.reshape(H_hat_supervise, (batch_size, seq_len, hidden_dim))  

        G_loss_S = MSE_loss(H[:,1:,:], H_hat_supervise[:,:-1,:])
        E_loss = E_loss0  + 0.1 * G_loss_S
        
        G_loss_S.backward(retain_graph=True)
        E_loss.backward()
        
        Embedder.zero_grad()
        Recovery.zero_grad()
        Supervisor.zero_grad()
        
        embedder_optimizer.step()
        recovery_optimizer.step()
        supervisor_optimizer.step()
        
        
        
        # Train Discriminator 
        #....
        
        
        
        #if e in range(1,epoch) and batch_index == 0:
        print('step: '+ str(e) + '/' + str(epoch) + ', G_loss_U: ' + str(G_loss_U.detach().numpy()) + ', G_loss_S: ' + 
             str(G_loss_S.detach().numpy()) + ', E_loss_t0: ' + str(np.sqrt(E_loss0.detach().numpy()))
             )
        
        
 



    


        

print('Finish Joint Training')

Start Joint Training
step: 0/2, G_loss_U: nan, G_loss_S: 1.18187945e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.0969117e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.147849e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.1307436e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.1593437e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.0761217e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.0991596e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.1246363e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.1470942e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.1571507e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.0728089e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.0862017e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.0789463e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.16448e-05, E_loss_t0: nan
step: 0/2, G_loss_U: nan, G_loss_S: 1.1714253e-05, E_loss

KeyboardInterrupt: 