# In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

# class Generator(nn.Module):
#     def __init__(self, config):
#         super(Generator, self).__init__()
#         self.input_size = config.input_size
#         self.hidden_size = config.hidden_size
#         self.output_size = config.output_size

#         self.lstm1 = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, batch_first=True)
#         self.batchNorm1 = nn.BatchNorm1d(self.hidden_size)

#         self.linear1 = nn.Linear(self.hidden_size, self.hidden_size)
#         self.batchNorm2 = nn.BatchNorm1d(self.hidden_size)
        
#         self.lstm2 = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, batch_first=True)
#         self.batchNorm3 = nn.BatchNorm1d(self.hidden_size)

#         self.linear2 = nn.Linear(self.hidden_size, self.output_size)
#         self.activation = nn.ReLU()

#     def forward(self, x, z):
#         output, (h, c) = self.lstm1(x)
#         h = torch.transpose(h, 1, 2)
#         h = self.batchNorm1(h)
#         h = torch.transpose(h, 1, 2)

#         h = self.linear1(h)
#         # h = torch.transpose(h, 1, 2)
#         # h = self.batchNorm2(h)
#         # h = torch.transpose(h, 1, 2)

#         output, _ = self.lstm2(z, (h, c))
#         output = torch.transpose(output, 1, 2)
#         output = self.batchNorm3(output)
#         output = torch.transpose(output, 1, 2)

#         output = self.linear2(output)
        
#         return self.activation(output)
class Generator(nn.Module):
    """GRU encoder/decoder that maps an input sequence to an output sequence.

    Two stacked GRUs encode the input. A linear layer then projects every
    encoder time step to a ``latent_dim``-wide code, where
    ``latent_dim = 2 * hidden_size``. Two stacked GRUs followed by a linear
    head decode that code back to ``output_size`` features per time step.

    Args:
        config: object exposing ``input_size``, ``hidden_size`` and
            ``output_size`` attributes.
    """

    def __init__(self, config):
        super(Generator, self).__init__()
        self.input_size = config.input_size  # features per time step
        self.hidden_size = config.hidden_size
        self.latent_dim = config.hidden_size * 2
        self.output_size = config.output_size

        # Encoder: two stacked GRUs, then a per-step projection to the latent width.
        # NOTE: submodule names/order are part of the state_dict contract; keep them.
        self.encoder_gru = nn.GRU(input_size=self.input_size, hidden_size=self.hidden_size, batch_first=True)
        self.encoder_gru2 = nn.GRU(input_size=self.hidden_size, hidden_size=self.hidden_size, batch_first=True)
        self.fc1 = nn.Linear(self.hidden_size, self.latent_dim)

        # Decoder: two stacked GRUs over the latent sequence, then a per-step output head.
        self.decoder_gru = nn.GRU(input_size=self.latent_dim, hidden_size=self.hidden_size, batch_first=True)
        self.decoder_gru2 = nn.GRU(input_size=self.hidden_size, hidden_size=self.hidden_size, batch_first=True)
        self.fc2 = nn.Linear(self.hidden_size, self.output_size)

        # NOTE(review): kept so external code referencing ``.activation`` still
        # works, but forward() does NOT apply it -- the output is unbounded.
        # Confirm whether a ReLU on the output was intended.
        self.activation = nn.ReLU()

    def forward(self, x):
        """Encode ``x`` per time step into a latent sequence and decode it back.

        Args:
            x: input sequence batch; presumably (batch, seq_len, input_size)
                because the GRUs use ``batch_first=True`` -- confirm with caller.

        Returns:
            Tensor of shape (batch, seq_len, output_size) for 3-D input.
            No output activation is applied.
        """
        # Encoder final hidden states are discarded; only per-step outputs are used.
        encoded, _ = self.encoder_gru(x)
        encoded, _ = self.encoder_gru2(encoded)

        # Latent code for every time step (not just the last one):
        # shape (batch, seq_len, latent_dim).
        z = self.fc1(encoded)

        # The decoder is not seeded with the encoder state; it starts from
        # PyTorch's default all-zeros hidden state.
        decoded, _ = self.decoder_gru(z)
        decoded, _ = self.decoder_gru2(decoded)

        # Per-step projection to the output feature size.
        return self.fc2(decoded)



# In [6]:
class Discriminator(nn.Module):
    """Recurrent critic that emits one tanh-squashed score per time step.

    ``gru1`` is actually an ``nn.LSTM``. The attribute name is kept so that
    existing state_dict keys (``gru1.*``) stay loadable. ``batchNorm1`` and
    ``flatten`` are constructed but not used by forward().

    Args:
        config: object exposing ``input_size``, ``hidden_size``,
            ``output_size`` and ``lag_size`` attributes.
    """

    def __init__(self, config):
        super().__init__()
        self.input_size, self.hidden_size = config.input_size, config.hidden_size
        self.output_size, self.lag_size = config.output_size, config.lag_size

        # Submodule registration order is deliberate: it fixes parameters()
        # ordering, state_dict layout, and RNG consumption at init.
        self.gru1 = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            batch_first=True,
        )
        self.batchNorm1 = nn.BatchNorm1d(self.hidden_size)  # not used in forward()
        self.linear1 = nn.Linear(self.hidden_size, 1)
        self.activation = nn.Tanh()
        self.flatten = nn.Flatten()  # not used in forward()

    def forward(self, x):
        """Score every time step of ``x``.

        Args:
            x: sequence batch; presumably (batch, seq_len, input_size) given
                ``batch_first=True`` -- confirm against the caller.

        Returns:
            Tensor whose trailing dimension is 1, with values in [-1, 1].
        """
        # LSTM returns (output, (h_n, c_n)); only the per-step output is scored.
        hidden_seq = self.gru1(x)[0]
        scores = self.linear1(hidden_seq)
        return self.activation(scores)
