In [1]:
import torch.nn as nn
import torch
from collections.abc import Iterable
import numpy as np

In [None]:
class MaskedLinearDecoder(nn.Module):
    """Linear decoder for scVI with hard mask on its regression weights.

    The layer computes ``z @ (weight * mask).T + bias``. The mask is applied
    functionally in ``forward`` rather than by overwriting the parameter in
    place, so:

    * the module can be called several times before a single ``backward()``
      (an in-place edit of the weight between forward passes would invalidate
      tensors autograd has saved for the backward pass);
    * gradients at masked positions are exactly zero.

    Parameters
    ----------
    n_input
        Size of the last dimension of the input ``z``.
    n_output
        Number of output units.
    mask
        Tensor broadcastable to the weight shape ``[n_output, n_input]``.
        Intended to be binary: a zero entry disables the corresponding
        connection, a one keeps it.
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        mask: torch.Tensor,
    ):
        super().__init__()
        # Keep the mask as a buffer (not a parameter): it is part of the
        # module's state but is never updated by an optimizer.
        self.register_buffer("mask", mask)

        # Single linear layer; its weight has shape [n_output, n_input].
        self.linear = nn.Linear(n_input, n_output)

        # Zero out masked positions once at init so the stored parameter is
        # consistent with the mask. The in-place multiply also rejects masks
        # that cannot be broadcast to the weight shape.
        with torch.no_grad():
            self.linear.weight.mul_(self.mask)

    def forward(self, z):
        """Apply the masked linear map to ``z`` and return the result."""
        weight = self.linear.weight
        # Cast the mask so integer/bool/other-precision masks cannot change
        # the dtype of the effective weight.
        masked_weight = weight * self.mask.to(weight.dtype)
        return nn.functional.linear(z, masked_weight, self.linear.bias)

# Demo mask with random 0/1 entries, shaped [n_output, n_input] == [30, 10].
mask = torch.tensor(
    np.random.choice([0, 1], size=(30, 10)),
)

torch.Size([30, 10]) torch.Size([30, 10])


torch.Size([100, 30])

In [None]:
class VelocityDecoder(nn.Module):
    """Two-headed MLP decoder.

    A shared trunk (Linear -> BatchNorm1d -> ReLU, applied twice) feeds two
    independent linear heads:

    * ``gene_velocity_decoder`` with ``n_output`` units;
    * ``gp_velocity_decoder`` with ``n_input`` units, i.e. the same width as
      the input ``z``.

    ``forward`` returns the two head outputs as ``(velocity, velocity_gp)``.
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_hidden: int = 128,
    ):
        super().__init__()

        # Build the trunk stage by stage: the first stage maps n_input ->
        # n_hidden, the second maps n_hidden -> n_hidden.
        trunk_layers = []
        for in_features in (n_input, n_hidden):
            trunk_layers += [
                nn.Linear(in_features, n_hidden),
                nn.BatchNorm1d(n_hidden),
                nn.ReLU(),
            ]
        self.shared_decoder = nn.Sequential(*trunk_layers)

        # Each head is a single linear layer wrapped in a Sequential.
        self.gene_velocity_decoder = nn.Sequential(nn.Linear(n_hidden, n_output))
        self.gp_velocity_decoder = nn.Sequential(nn.Linear(n_hidden, n_input))

    def forward(self, z):
        """Run the shared trunk once and return both head outputs."""
        hidden = self.shared_decoder(z)
        return self.gene_velocity_decoder(hidden), self.gp_velocity_decoder(hidden)

In [None]:
class Encoder(nn.Module):
    """MLP encoder that outputs a sampled latent plus its Gaussian parameters.

    A two-stage trunk (Linear -> BatchNorm1d -> ReLU, applied twice) is
    followed by two linear projections giving the mean and the log-variance.
    ``forward`` returns ``(z, mean, logvar)`` where ``z`` is drawn with the
    reparametrization trick.
    """

    def __init__(self, n_input: int, n_latent: int, n_hidden: int = 128):
        super().__init__()

        # Trunk: n_input -> n_hidden, then n_hidden -> n_hidden.
        stages = []
        for width_in in (n_input, n_hidden):
            stages.extend(
                (
                    nn.Linear(width_in, n_hidden),
                    nn.BatchNorm1d(n_hidden),
                    nn.ReLU(),
                )
            )
        self.encoder = nn.Sequential(*stages)

        # Heads projecting the hidden representation to mean / log-variance.
        self.mean_layer = nn.Linear(n_hidden, n_latent)
        self.logvar_layer = nn.Linear(n_hidden, n_latent)

    def forward(self, x: torch.Tensor):
        """Encode ``x`` and return ``(z, mean, logvar)``."""
        hidden = self.encoder(x)
        mean = self.mean_layer(hidden)
        logvar = self.logvar_layer(hidden)
        return self.reparametrize(mean, logvar), mean, logvar

    @staticmethod
    def reparametrize(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Return ``mean + eps * std`` with ``std = exp(0.5 * logvar)`` and
        ``eps`` drawn from a standard normal of the same shape as ``std``."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mean + noise * std


In [7]:
# Smoke test: encode a random batch of 20 samples with 200 features each.
encoder = Encoder(n_input=200, n_latent=10, n_hidden=128)

x = torch.randn(20, 200)
z, mean, logvar = encoder(x)

In [8]:

# Smoke test: decode 100 random 10-dimensional latents into 30 outputs
# using the masked linear decoder.
decoder = MaskedLinearDecoder(n_input=10, n_output=30, mask=mask)

z = torch.randn(100, 10)
out = decoder(z)
out.shape

torch.Size([30, 10]) torch.Size([30, 10])


torch.Size([100, 30])

In [9]:
# Smoke test: run the velocity decoder on the latent batch `z` from the
# previous cell.
velo_decoder = VelocityDecoder(n_input=10, n_output=30, n_hidden=128)

velo, velo_gp = velo_decoder(z)