In [None]:
#|default_exp autoencoders
#|export

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
#|export

class VariationalAutoencoder(nn.Module):
    '''
    Variational autoencoder for data compression and reconstruction.

    The encoder maps inputs through two SELU-activated linear layers to the
    mean and log-variance of a diagonal Gaussian latent distribution; the
    decoder maps a latent vector back to input space. Output features selected
    by ``sigmoid_mask`` are squashed with a sigmoid, all other features are
    returned unchanged (linear output).

    Note: ``forward`` decodes the latent *mean* directly and does not sample,
    so the forward pass is deterministic apart from decoder dropout in
    training mode. Use ``encode`` + ``reparameterise`` + ``decode`` to sample
    explicitly.
    '''

    def __init__(
        self,
        input_size: int,
        sigmoid_mask: torch.Tensor,
        *,
        hidden_size_1: int = 64,
        hidden_size_2: int = 32,
        latent_size: int = 16,
        dropout: float = 0.1,
    ):
        '''
        Build the encoder, latent heads and decoder.

        Args:
            input_size (int): Number of input (and reconstructed) features.
            sigmoid_mask (torch.Tensor): Mask selecting which output features
                get a sigmoid. It must broadcast against the last dimension of
                the reconstruction (typically 1-D of length ``input_size``).
                Any dtype is accepted; truthy entries are converted to ``True``.
            hidden_size_1 (int): Width of the outer hidden layer (default: 64).
            hidden_size_2 (int): Width of the inner hidden layer (default: 32).
            latent_size (int): Dimension of the latent space (default: 16).
            dropout (float): Dropout probability used in the decoder
                (default: 0.1).
        '''
        super().__init__()

        # Store key model dimensions
        self.input_size = input_size
        self.latent_size = latent_size

        # Encoder architecture
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size_1),
            nn.SELU(),
            nn.Linear(hidden_size_1, hidden_size_2),
            nn.SELU(),
        )

        # Latent space parameters
        self.fc_mean = nn.Linear(hidden_size_2, latent_size)  # Mean of latent distribution
        self.fc_log_variance = nn.Linear(hidden_size_2, latent_size)  # Log variance of latent distribution

        # Decoder architecture
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, hidden_size_2),
            nn.SELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size_2, hidden_size_1),
            nn.SELU(),
            nn.Linear(hidden_size_1, input_size),
        )

        # torch.where needs a bool condition, so convert once here (this also
        # lets callers pass 0/1 int or float masks). The leading dim of size 1
        # broadcasts over the batch; as a buffer it follows .to(device).
        mask = torch.as_tensor(sigmoid_mask, dtype=torch.bool)
        self.register_buffer('sigmoid_mask', mask.unsqueeze(0))

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        '''
        Encode input data into latent space parameters.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is ``input_size``.

        Returns:
            tuple: (mean, log_variance) of the latent distribution, each with
            last dimension ``latent_size``.
        '''
        hidden = self.encoder(x)
        mean = self.fc_mean(hidden)
        log_variance = self.fc_log_variance(hidden)
        return mean, log_variance

    def reparameterise(self, mean: torch.Tensor, log_variance: torch.Tensor) -> torch.Tensor:
        '''
        Sample from N(mean, exp(log_variance)) via the reparameterisation trick,
        keeping the sample differentiable w.r.t. ``mean`` and ``log_variance``.

        Args:
            mean (torch.Tensor): Mean of the latent distribution.
            log_variance (torch.Tensor): Log variance of the latent distribution
                (unconstrained; exp() already guarantees a positive std).

        Returns:
            torch.Tensor: Sampled point from the latent distribution, same
            shape as ``mean``.
        '''
        # std = sqrt(var) = exp(0.5 * log var). No extra positivity transform is
        # applied: feeding log-variance through softplus would force std >= 1.
        std = torch.exp(0.5 * log_variance)
        eps = torch.randn_like(std)  # Random noise from standard normal
        return mean + eps * std

    def decode(self, latent_vector: torch.Tensor) -> torch.Tensor:
        '''
        Decode a latent representation back to (raw, pre-sigmoid) input space.

        Args:
            latent_vector (torch.Tensor): Latent tensor whose last dimension is
                ``latent_size``.

        Returns:
            torch.Tensor: Raw reconstruction with last dimension ``input_size``
            (the sigmoid mask is applied in ``forward``, not here).
        '''
        return self.decoder(latent_vector)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        '''
        Forward pass through the VAE.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is ``input_size``.

        Returns:
            tuple: (reconstruction, mean). ``reconstruction`` has last dimension
            ``input_size`` with the sigmoid applied to masked features;
            ``mean`` is the latent mean. The log-variance is not returned --
            call ``encode`` if it is needed (e.g. for a KL term).
        '''
        mean, _ = self.encode(x)
        # Decode the mean directly instead of a reparameterised sample, so the
        # reconstruction is deterministic for a given input (outside dropout).
        raw_reconstruction = self.decode(mean)

        reconstruction = torch.where(
            self.sigmoid_mask,  # [1, input_size], broadcast over the batch
            torch.sigmoid(raw_reconstruction),
            raw_reconstruction,
        )

        return reconstruction, mean
