In [None]:
import torch
import torch.nn as nn

class VAEEncoder(nn.Module):
    """
    Feed-forward trunk plus two linear heads for a VAE encoder.

    The trunk (``self.net``) is a stack of Linear -> ReLU pairs; the heads
    ``self.fc_mu`` and ``self.fc_logvar`` map the trunk's final width to the
    latent dimension.
    """

    def __init__(self, input_dim: int, latent_dim: int, hidden_dims=None):
        """
        :param input_dim:  Width of the input features (number of genes, i.e.
                           columns, in the TF x gene matrix).
        :param latent_dim: Dimensionality of the latent space (z).
        :param hidden_dims: Iterable of hidden layer sizes; ``None`` selects
                            the example default ``[512, 256]``.
        """
        super().__init__()

        # Full chain of layer widths: input first, then each hidden size.
        widths = [input_dim, *([512, 256] if hidden_dims is None else hidden_dims)]

        # One Linear + ReLU pair per consecutive (fan_in, fan_out) step.
        trunk = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            trunk.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        self.net = nn.Sequential(*trunk)

        # Heads for mu and log_var read from the last trunk width (which is
        # input_dim itself when there are no hidden layers).
        trunk_out = widths[-1]
        self.fc_mu = nn.Linear(trunk_out, latent_dim)
        self.fc_logvar = nn.Linear(trunk_out, latent_dim)


In [None]:

    def forward(self, x):
        """
        Run the trunk network, then both output heads on its result.

        :param x: [batch_size, input_dim] tensor
        :return: tuple (mu, log_var)
                 mu      -> [batch_size, latent_dim]
                 log_var -> [batch_size, latent_dim]
        """
        # Both heads read the same trunk output; fc_mu is evaluated first.
        hidden = self.net(x)
        return self.fc_mu(hidden), self.fc_logvar(hidden)
