In [17]:
import torch
import torch.nn as nn
import torch.optim as optim

class ConvGenerator(nn.Module):
    """Two-stage LSTM generator: encode ``x``, then decode ``z`` from that state.

    ``x`` is run through ``lstm1``. Its final hidden state is batch-normalized
    (``batchNorm1``) and projected by ``linear1``. Together with ``lstm1``'s
    final cell state, it initializes ``lstm2``, which consumes ``z``. Every time
    step of ``lstm2``'s output is batch-normalized (``batchNorm3``), mapped to
    ``output_size`` features by ``linear2`` and passed through a ReLU, so the
    result is non-negative.

    Despite the name, the model contains no convolutional layers.

    Args:
        input_size: Number of features per time step of ``x``.
        hidden_size: Hidden size shared by both LSTMs.
        output_size: Number of features per time step of the output.
        noise_size: Number of features per time step of ``z``. Defaults to
            ``input_size``, which was previously the only supported width.
        norm_projected_hidden: If True, also apply ``batchNorm2`` to the
            projected hidden state before it seeds ``lstm2``. Defaults to False,
            matching the previous behavior where ``batchNorm2`` was unused.
    """

    # Class-level default so instances unpickled from before this flag existed
    # still take the original (un-normalized) path in ``forward``.
    norm_projected_hidden = False

    def __init__(self, input_size, hidden_size, output_size, *,
                 noise_size=None, norm_projected_hidden=False):
        super().__init__()
        if noise_size is None:
            noise_size = input_size
        self.norm_projected_hidden = norm_projected_hidden

        # Attribute names and creation order are kept stable. They define the
        # state_dict keys and the order in which seeded random init is drawn.
        self.lstm1 = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.batchNorm1 = nn.BatchNorm1d(hidden_size)

        self.linear1 = nn.Linear(hidden_size, hidden_size)
        # Always registered for checkpoint compatibility. It is only applied
        # when ``norm_projected_hidden`` is True.
        self.batchNorm2 = nn.BatchNorm1d(hidden_size)

        self.lstm2 = nn.LSTM(input_size=noise_size, hidden_size=hidden_size, batch_first=True)
        self.batchNorm3 = nn.BatchNorm1d(hidden_size)

        self.linear2 = nn.Linear(hidden_size, output_size)
        self.activation = nn.ReLU()

    @staticmethod
    def _norm_features(norm, t):
        """Apply ``norm`` (a BatchNorm1d) to the last dim of 3-D tensor ``t``.

        BatchNorm1d normalizes dim 1, so the last dim is swapped into place and
        then swapped back.
        """
        return norm(t.transpose(1, 2)).transpose(1, 2)

    def forward(self, x, z):
        """Generate a sequence from conditioning sequence ``x`` and input ``z``.

        Args:
            x: Batch-first input to ``lstm1``, shape (batch, seq_x, input_size).
            z: Batch-first input to ``lstm2``, shape (batch, seq_z, noise_size).
                It must share the batch size of ``x`` because ``lstm1``'s final
                state seeds ``lstm2``.

        Returns:
            Tensor of shape (batch, seq_z, output_size) with non-negative values.
        """
        _, (h, c) = self.lstm1(x)
        # h is (num_layers=1, batch, hidden). With hidden moved to dim 1, each
        # hidden unit is normalized across the batch.
        h = self._norm_features(self.batchNorm1, h)
        h = self.linear1(h)
        if self.norm_projected_hidden:
            h = self._norm_features(self.batchNorm2, h)

        output, _ = self.lstm2(z, (h, c))
        output = self._norm_features(self.batchNorm3, output)
        return self.activation(self.linear2(output))


In [16]:
class ConvDiscriminator(nn.Module):
    """LSTM discriminator that gives each input sequence a single score.

    The sequence is encoded by ``lstm1``, and every time step is batch-normalized
    per hidden unit (``batchNorm1``). The last time step is then mapped by
    ``linear1`` to one value and squashed by a sigmoid.

    Despite the name, the model contains no convolutional layers.

    Args:
        input_size: Number of features per time step of the input.
        hidden_size: Hidden size of the LSTM.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()

        # Attribute names and creation order define the state_dict keys and the
        # order in which seeded random init is drawn; keep them stable.
        self.lstm1 = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.batchNorm1 = nn.BatchNorm1d(hidden_size)
        self.linear1 = nn.Linear(hidden_size, 1)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Score each sequence in ``x``.

        Args:
            x: Batch-first tensor of shape (batch, seq_len, input_size).

        Returns:
            Tensor of shape (batch, 1) holding sigmoid outputs.
        """
        encoded, _ = self.lstm1(x)
        # BatchNorm1d normalizes dim 1, so hidden features are moved there and back.
        encoded = self.batchNorm1(encoded.transpose(1, 2)).transpose(1, 2)
        final_step = encoded[:, -1]
        logit = self.linear1(final_step)
        return self.activation(logit)
