In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """LSTM-based variational autoencoder for sequences.

    The encoder LSTM summarises an input sequence into its final hidden
    state. That state is projected to the mean and log-variance of a
    diagonal Gaussian posterior. A latent vector sampled from that posterior:

    * is projected by ``fc3`` to form the decoder's initial hidden and cell
      state, and
    * is repeated ``lag_size`` times as the decoder LSTM's input sequence.

    The decoder output is projected back to ``input_size`` features per step.

    Args:
        config: object exposing ``input_size``, ``hidden_size`` and
            ``lag_size``. A ``batch_size`` attribute is no longer required by
            ``forward``; the batch size is taken from the input tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.lag_size = config.lag_size
        # Latent dimensionality is fixed at twice the LSTM hidden size.
        self.latent = config.hidden_size * 2
        self.config = config

        # Encoder (creation order preserved so seeded initialisation matches).
        self.lstm1 = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, batch_first=True)
        self.fc1 = nn.Linear(self.hidden_size, self.latent)  # hidden -> posterior mean
        self.fc2 = nn.Linear(self.hidden_size, self.latent)  # hidden -> posterior log-variance

        # Decoder
        self.fc3 = nn.Linear(self.latent, self.hidden_size)  # latent -> initial decoder state
        self.lstm2 = nn.LSTM(input_size=self.latent, hidden_size=self.hidden_size, batch_first=True)
        self.fc4 = nn.Linear(self.hidden_size, self.input_size)  # decoder hidden -> features
        # Not used by forward(); kept so external references to them still work.
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def encode(self, x):
        """Encode a batch of sequences into Gaussian posterior parameters.

        Args:
            x: tensor of shape ``(batch, seq_len, input_size)``
                (``lstm1`` is ``batch_first``).

        Returns:
            ``(mean, logvar)``, each of shape ``(1, batch, latent)``. The
            leading axis is the single LSTM layer of the final hidden state.
        """
        _, (h_n, _) = self.lstm1(x)
        mean = self.fc1(h_n)
        logvar = self.fc2(h_n)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Sample ``mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, z, hidden):
        """Run the decoder LSTM and project each step back to input space.

        Args:
            z: decoder input of shape ``(batch, lag_size, latent)``.
            hidden: ``(h0, c0)`` tuple, each of shape ``(1, batch, hidden_size)``.

        Returns:
            Tensor of shape ``(batch, lag_size, input_size)``.
        """
        recon_x, _ = self.lstm2(z, hidden)
        return self.fc4(recon_x)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode ``lag_size`` steps.

        Args:
            x: tensor of shape ``(batch, seq_len, input_size)``. Any batch
                size is accepted, including a short final batch.

        Returns:
            ``(recon_x, mean, logvar)``:

            * ``recon_x`` has shape ``(batch, lag_size, input_size)``.
            * ``mean`` and ``logvar`` have shape ``(1, batch, latent)``.
        """
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)  # (1, batch, latent)
        h0 = self.fc3(z).contiguous()  # (1, batch, hidden_size)

        # Put batch first, then repeat each sample's OWN latent vector along
        # the time axis. The previous repeat-on-dim-1 + view interleaved the
        # latents of different samples, so sample i at step t received
        # sample (i * lag + t) % batch's latent.
        z_seq = z.transpose(0, 1).repeat(1, self.lag_size, 1)  # (batch, lag, latent)

        # Design choice: the decoder's initial cell state equals its
        # initial hidden state.
        recon_x = self.decode(z_seq, (h0, h0))
        return recon_x, mean, logvar