In [134]:
import torch.nn as nn
import torch

class LSTM_VAE(nn.Module):
    def __init__(self, latent_features, num_samples, num_features):
        super(LSTM_VAE, self).__init__()

        self.latent_features = latent_features
        self.num_samples = num_samples
        self.num_features = num_features
        
        # We encode the data onto the latent space using two linear layers
        self.encoder = nn.LSTM(
            input_size=self.num_features,
            hidden_size=2*self.latent_features,
            num_layers=1,
            bias=True,
            batch_first=True,
            dropout=0,
            bidirectional=False
        )

        # The latent code must be decoded into the original image
        self.decoder = nn.LSTM(
            input_size=self.latent_features,
            hidden_size=self.num_features,
            num_layers=1,
            bias=True,
            batch_first=True,
            dropout=0,
            bidirectional=False
        )

        # Initialize hidden state and cell state as learnable parameters
        self.hidden = torch.zeros(1, 1, 2*self.latent_features)
        self.cell_state = torch.zeros(1, 1, 2*self.latent_features)
        
    def forward(self, x):
        outputs = {}
            
        hidden_encoded, _ = self.encoder(x, (self.hidden, self.cell_state))
                
        # Split encoder outputs into a mean and variance vector
        mu, log_var = torch.chunk(hidden_encoded, 2, dim=-1)

        # :- Reparametrisation trick
        # a sample from N(mu, sigma) is mu + sigma * epsilon
        # where epsilon ~ N(0, 1)

        # Don't propagate gradients through randomness
        with torch.no_grad():
            batch_size = mu.size(0)
            epsilon = torch.randn(
                batch_size, self.num_samples, self.latent_features)

        sigma = torch.exp(log_var/2)

        # We will need to unsqueeze to turn
        # (batch_size, latent_dim) -> (batch_size, 1, latent_dim)
        z = mu + epsilon * sigma        
        
        # Run through decoder
        x_hat = self.decoder(z)

        outputs["x_hat"] = x_hat
        outputs["z"] = z
        outputs["mu"] = mu
        outputs["log_var"] = log_var

        return outputs


In [135]:
from lvead.dataload.load_data import load_data

train, _ = load_data('ItalyPowerDemand', basepath='..')

train_x = torch.Tensor(train[:, 1:])

print(train_x.shape)


torch.Size([67, 24])


In [136]:
latent_features = 8
num_samples = train_x.shape[0]
num_features = train_x.shape[1]

train_x = train_x.unsqueeze(0)

net = LSTM_VAE(latent_features, num_samples, num_features)

output = net(train_x)

print(output['x_hat'][0].shape)

torch.Size([1, 67, 24])
