In [62]:
import torch.nn as nn
import torch
from collections.abc import Iterable
import numpy as np

In [None]:
class MaskedLinearDecoder(nn.Module):
    """Linear decoder for scVI with hard mask on its regression weights."""
    def __init__(
        self,
        n_input: int,
        n_output: int,
        mask: torch.Tensor,
    ):
        
        super().__init__()
        # 1) keep mask as a buffer (not a parameter)
        #    shape must be [out_features, in_features] == [n_output, n_input]
        self.register_buffer("mask", mask)

        # 2) build your normal 1-layer FCLayers that outputs 2*n_output units
        self.linear = nn.Linear(n_input, n_output)

        # 4) zero out masked positions at init
        with torch.no_grad():
            print(self.linear.weight.shape, self.mask.shape)
            self.linear.weight.mul_(self.mask)

    def forward(self, z):
        # 5) re-apply mask on every forward under no_grad
        with torch.no_grad():
            self.linear.weight.mul_(self.mask)

        # 6) proceed as before
        return self.linear(z)

mask = torch.tensor(np.random.choice([0,1], size=(30,10)))

decoder = MaskedLinearDecoder(
    10,
    30,
    mask
)

z = torch.randn((100,10))
out = decoder(z)
out.shape

torch.Size([30, 10]) torch.Size([30, 10])


torch.Size([100, 30])

In [127]:
class VelocityDecoder(nn.Module):
    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_hidden: int = 128,
    ):

        super().__init__()
        self.shared_decoder = nn.Sequential(
            nn.Linear(n_input, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
        )

        self.gene_velocity_decoder = nn.Sequential(
                nn.Linear(n_hidden, n_output)
        )

        self.gp_velocity_decoder = nn.Sequential(
                nn.Linear(n_hidden, n_input)
        )

    def forward(self, z):
        # Parameters for latent distribution
        h = self.shared_decoder(z)
        velocity = self.gene_velocity_decoder(h)
        velocity_gp = self.gp_velocity_decoder(h)

        return velocity, velocity_gp
    
velo_decoder = VelocityDecoder(
    n_input=10,
    n_output=30,
    n_hidden=128
)

velo, velo_gp = velo_decoder(z)

In [None]:
class Encoder(nn.Module):
    def __init__(self, n_input: int, n_latent: int, n_hidden: int = 128):
        super().__init__()
        # shared encoder MLP
        self.encoder = nn.Sequential(
            nn.Linear(n_input, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
        )
        
        # project to mean and log-variance
        self.mean_layer   = nn.Linear(n_hidden, n_latent)
        self.logvar_layer = nn.Linear(n_hidden, n_latent)

    def forward(self, x: torch.Tensor):
        h      = self.encoder(x)
        mean   = self.mean_layer(h)
        logvar = self.logvar_layer(h)
        z      = self.reparametrize(mean, logvar)
        return z, mean, logvar

    @staticmethod
    def reparametrize(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

encoder = Encoder(
    n_input=200,
    n_latent=10,
    n_hidden=128
)

x = torch.randn((20, 200))
z, mean, logvar = encoder(x)

In [142]:
z

tensor([[ 2.7815e+00, -1.2500e+00,  3.4379e-01, -8.9446e-01,  1.4596e+00,
         -5.7749e-01,  1.1130e+00,  1.1371e+00, -1.0372e+00, -2.4681e-01],
        [-1.5438e+00, -2.0263e+00,  6.4457e-01,  4.0716e-01,  1.8892e+00,
         -1.1677e+00, -8.4082e-01,  2.1628e+00,  1.1040e+00, -4.8483e-01],
        [ 8.7712e-01, -6.0255e-01,  1.6068e+00,  1.2103e+00,  6.9378e-01,
          8.8754e-01,  2.5591e-01, -1.2884e+00, -1.3222e+00,  1.4341e-01],
        [-7.3171e-01, -1.3477e-02, -1.0907e+00, -2.6450e-01, -1.2891e+00,
         -1.3188e+00, -6.9773e-01,  2.3941e-01,  1.3722e+00,  4.5603e-01],
        [ 1.0138e+00, -3.3183e+00,  8.0623e-01, -1.1777e-01, -4.7822e-02,
         -1.2980e+00, -7.1306e-01, -9.8975e-01, -1.1985e-01, -2.8009e+00],
        [-1.0001e+00, -1.1881e+00,  5.6144e-01,  9.4248e-02, -1.7554e+00,
         -3.3904e-01,  6.5655e-01, -1.7898e-01,  6.7672e-01,  1.2773e+00],
        [ 1.1776e+00, -4.8968e-01, -5.6751e-01, -1.0875e+00, -9.8670e-01,
         -2.2252e-01,  2.6571e-0