# Notebook to understand the scVI model
    Developed by: Christian Eger
    Würzburg Institute for Systems Immunology, Faculty of Medicine, Julius-Maximilian-Universität Würzburg
    Created: 240328
    Latest version: 240408

## Module imports

In [24]:
import scvi
import scanpy as sc
from torch import nn
import torch

## Data loading

In [2]:
adata = sc.read_h5ad(
    '../data/Marburg_cell_states_locked_scANVI_ctl230901.raw.h5ad'
)
adata

AnnData object with n_obs × n_vars = 97573 × 27208
    obs: 'sex', 'age', 'ethnicity', 'PaCO2', 'donor', 'infection', 'disease', 'SMK', 'illumina_stimunr', 'bd_rhapsody', 'n_genes', 'doublet_scores', 'predicted_doublets', 'batch', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt2', 'n_counts', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', 'condition', 'sample_group', 'IAV_score', 'group', 'Viral_score', 'cell_type', 'cell_states', 'leiden', 'cell_compartment', 'seed_labels', '_scvi_batch', '_scvi_labels', 'C_scANVI'
    var: 'mt', 'ribo'
    obsm: 'X_scANVI', 'X_scVI', 'X_umap'

In [3]:
adata.layers['counts'] = adata.X.copy()

In [4]:
sc.pp.highly_variable_genes(
    adata=adata,
    n_top_genes=3000,
    layer='counts',
    flavor='seurat_v3',
    batch_key='batch'
)

## scVI model preparation

In [5]:
scvi.model.SCVI.setup_anndata(
    adata=adata,
    layer='counts',
    batch_key='donor',
)

In [6]:
model = scvi.model.SCVI(
    adata=adata,
    n_latent=50,
    n_hidden=3,
)

In [7]:
model



### scVI model

In [8]:
model.module

VAE(
  (z_encoder): Encoder(
    (encoder): FCLayers(
      (fc_layers): Sequential(
        (Layer 0): Sequential(
          (0): Linear(in_features=27208, out_features=3, bias=True)
          (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): None
          (3): ReLU()
          (4): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (mean_encoder): Linear(in_features=3, out_features=50, bias=True)
    (var_encoder): Linear(in_features=3, out_features=50, bias=True)
  )
  (l_encoder): Encoder(
    (encoder): FCLayers(
      (fc_layers): Sequential(
        (Layer 0): Sequential(
          (0): Linear(in_features=27208, out_features=3, bias=True)
          (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): None
          (3): ReLU()
          (4): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (mean_encoder): Linear(in_features=3, out_features=1, bias=True)
    (var_enco

## scVI class rebuilding

### FC Layer Class

#### scVI Model

In [9]:
model.module.z_encoder.encoder.fc_layers

Sequential(
  (Layer 0): Sequential(
    (0): Linear(in_features=27208, out_features=3, bias=True)
    (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
    (2): None
    (3): ReLU()
    (4): Dropout(p=0.1, inplace=False)
  )
)

#### rebuilt model

In [42]:
class FCLayers(nn.Module):

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.fc_layers = nn.Sequential(
            nn.Linear(
                in_features=in_features,
                out_features=out_features,
                bias=True,
            ),
            nn.BatchNorm1d(
                3,
                eps=0.001,
                momentum=0.01,
                affine=True,
                track_running_stats=True
            ),
            None,
            nn.ReLU(),
            nn.Dropout(
                p=0.1,
                inplace=False,
            )
        )
    
    def forward(self, x):
        x = self.fc_layers(x) 
        return x

In [44]:
FCLayers(
    in_features=27208,
    out_features=3,
)

FCLayers(
  (fc_layers): Sequential(
    (0): Linear(in_features=27208, out_features=3, bias=True)
    (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
    (2): None
    (3): ReLU()
    (4): Dropout(p=0.1, inplace=False)
  )
)

### Encoder Class

#### scVI model

In [21]:
model.module.z_encoder

Encoder(
  (encoder): FCLayers(
    (fc_layers): Sequential(
      (Layer 0): Sequential(
        (0): Linear(in_features=27208, out_features=3, bias=True)
        (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): None
        (3): ReLU()
        (4): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (mean_encoder): Linear(in_features=3, out_features=50, bias=True)
  (var_encoder): Linear(in_features=3, out_features=50, bias=True)
)

In [28]:
model.module.l_encoder

Encoder(
  (encoder): FCLayers(
    (fc_layers): Sequential(
      (Layer 0): Sequential(
        (0): Linear(in_features=27208, out_features=3, bias=True)
        (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): None
        (3): ReLU()
        (4): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (mean_encoder): Linear(in_features=3, out_features=1, bias=True)
  (var_encoder): Linear(in_features=3, out_features=1, bias=True)
)

### rebuilt model

In [78]:
from torch.distributions import Normal


class Encoder(nn.Module):

    def __init__(
            self,
            fc_in,
            fc_out,
            n_hidden,
            n_output,
    ):
        super().__init__()
        self.fc_in = fc_in
        self.fc_out = fc_out
        self.encoder = FCLayers(
            in_features=fc_in,
            out_features=fc_out,
        )
        self.mean_encoder = nn.Linear(
            in_features=n_hidden,
            out_features=n_output,
        )
        self.var_encoder = nn.Linear(
            in_features=n_hidden,
            out_features=n_output,
        )


    def forward(self, x):
        q = self.encoder(x)
        q_m = self.mean_encoder(q)
        q_v = torch.exp(self.var_encoder(q)) + 1e-4
        dist = Normal(q_m, q_v.sqrt())
        latent = dist.rsample()
        return q_m, q_v, latent

In [79]:
Encoder(
    fc_in=27208,
    fc_out=3,
    n_hidden=3,
    n_output=50,
)

Encoder(
  (encoder): FCLayers(
    (fc_layers): Sequential(
      (0): Linear(in_features=27208, out_features=3, bias=True)
      (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): None
      (3): ReLU()
      (4): Dropout(p=0.1, inplace=False)
    )
  )
  (mean_encoder): Linear(in_features=3, out_features=50, bias=True)
  (var_encoder): Linear(in_features=3, out_features=50, bias=True)
)

In [80]:
Encoder(
    fc_in=27208,
    fc_out=3,
    n_hidden=3,
    n_output=1,
)

Encoder(
  (encoder): FCLayers(
    (fc_layers): Sequential(
      (0): Linear(in_features=27208, out_features=3, bias=True)
      (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): None
      (3): ReLU()
      (4): Dropout(p=0.1, inplace=False)
    )
  )
  (mean_encoder): Linear(in_features=3, out_features=1, bias=True)
  (var_encoder): Linear(in_features=3, out_features=1, bias=True)
)

### DecoderSCVI class

#### scVI model

In [15]:
model.module.decoder

DecoderSCVI(
  (px_decoder): FCLayers(
    (fc_layers): Sequential(
      (Layer 0): Sequential(
        (0): Linear(in_features=62, out_features=3, bias=True)
        (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): None
        (3): ReLU()
        (4): None
      )
    )
  )
  (px_scale_decoder): Sequential(
    (0): Linear(in_features=3, out_features=27208, bias=True)
    (1): Softmax(dim=-1)
  )
  (px_r_decoder): Linear(in_features=3, out_features=27208, bias=True)
  (px_dropout_decoder): Linear(in_features=3, out_features=27208, bias=True)
)

### rebuilt model

In [82]:
class DecoderSCVI(nn.Module):

    def __init__(
            self,
            fc_in,
            fc_out,
            decoder_in,
            decoder_out,
    ):
        super().__init__()
        self.fc_in = fc_in
        self.fc_out = fc_out
        self.decoder_in = decoder_in
        self.decoder_out = decoder_out
        self.px_decoder = FCLayers(
            in_features=fc_in,
            out_features=fc_out,
        )
        self.px_scale_decoder = nn.Sequential(
            nn.Linear(
                in_features=decoder_in,
                out_features=decoder_out,
            ),
            nn.Softmax(dim=1)
                    )
        self.px_r_decoder = nn.Linear(
            in_features=decoder_in,
            out_features=decoder_out,
        )
        self.px_dropout_decoder = nn.Linear(
            in_features=decoder_in,
            out_features=decoder_out,
        )

    def forward(self, x):
        pass

In [83]:
DecoderSCVI(
    fc_in=62,
    fc_out=3,
    decoder_in=3,
    decoder_out=27208,
)

DecoderSCVI(
  (px_decoder): FCLayers(
    (fc_layers): Sequential(
      (0): Linear(in_features=62, out_features=3, bias=True)
      (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): None
      (3): ReLU()
      (4): Dropout(p=0.1, inplace=False)
    )
  )
  (px_scale_decoder): Sequential(
    (0): Linear(in_features=3, out_features=27208, bias=True)
    (1): Softmax(dim=1)
  )
  (px_r_decoder): Linear(in_features=3, out_features=27208, bias=True)
  (px_dropout_decoder): Linear(in_features=3, out_features=27208, bias=True)
)

### VAE Class

#### scVI model

In [18]:
model.module

VAE(
  (z_encoder): Encoder(
    (encoder): FCLayers(
      (fc_layers): Sequential(
        (Layer 0): Sequential(
          (0): Linear(in_features=27208, out_features=3, bias=True)
          (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): None
          (3): ReLU()
          (4): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (mean_encoder): Linear(in_features=3, out_features=50, bias=True)
    (var_encoder): Linear(in_features=3, out_features=50, bias=True)
  )
  (l_encoder): Encoder(
    (encoder): FCLayers(
      (fc_layers): Sequential(
        (Layer 0): Sequential(
          (0): Linear(in_features=27208, out_features=3, bias=True)
          (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): None
          (3): ReLU()
          (4): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (mean_encoder): Linear(in_features=3, out_features=1, bias=True)
    (var_enco

#### rebuilt model

In [69]:
class VAE(nn.Module):
    def __init__(
            self,
            z_enc_fc_in,
            z_enc_fc_out,
            z_n_hidden,
            z_n_output,
            l_enc_fc_in,
            l_enc_fc_out,
            l_n_hidden,
            l_n_output,
            decoder_fc_in,
            decoder_fc_out,
            decoder_in,
            decoder_out,
    ):
        self.z_enc_fc_in = z_enc_fc_in
        self.z_enc_fc_out = z_enc_fc_out
        self.z_enc_fc_bias = True
        self.z_n_hidden = z_n_hidden
        self.z_n_output = z_n_output
        self.l_enc_fc_in = l_enc_fc_in
        self.l_enc_fc_out = l_enc_fc_out
        self.l_enc_fc_bias = True
        self.l_n_hidden = l_n_hidden
        self.l_n_output = l_n_output
        self.decoder_fc_in = decoder_fc_in
        self.decoder_fc_out = decoder_fc_out
        self.decoder_fc_bias = True
        self.decoder_in = decoder_in
        self.decoder_out = decoder_out
        self.decoder_bias = True
        super().__init__()
        self.z_encoder = Encoder(
            fc_in=z_enc_fc_in,
            fc_out=z_enc_fc_out,
            n_hidden=z_n_hidden,
            n_output=z_n_output,
        )
        self.l_encoder = Encoder(
            fc_in=l_enc_fc_in,
            fc_out=l_enc_fc_out,
            n_hidden=l_n_hidden,
            n_output=l_n_output,
        )
        self.decoder = DecoderSCVI(
            fc_in=decoder_fc_in,
            fc_out=decoder_fc_out,
            decoder_in=decoder_in,
            decoder_out=decoder_out,
        )

    def forward(self, x):
        pass


In [84]:
VAE(
    z_enc_fc_in=27208,
    z_enc_fc_out=3,
    z_n_hidden=3,
    z_n_output=50,
    l_enc_fc_in=27208,
    l_enc_fc_out=3,
    l_n_hidden=3, 
    l_n_output=1,
    decoder_fc_in=62,
    decoder_fc_out=3,
    decoder_in=3,
    decoder_out=27208,
)

VAE(
  (z_encoder): Encoder(
    (encoder): FCLayers(
      (fc_layers): Sequential(
        (0): Linear(in_features=27208, out_features=3, bias=True)
        (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): None
        (3): ReLU()
        (4): Dropout(p=0.1, inplace=False)
      )
    )
    (mean_encoder): Linear(in_features=3, out_features=50, bias=True)
    (var_encoder): Linear(in_features=3, out_features=50, bias=True)
  )
  (l_encoder): Encoder(
    (encoder): FCLayers(
      (fc_layers): Sequential(
        (0): Linear(in_features=27208, out_features=3, bias=True)
        (1): BatchNorm1d(3, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): None
        (3): ReLU()
        (4): Dropout(p=0.1, inplace=False)
      )
    )
    (mean_encoder): Linear(in_features=3, out_features=1, bias=True)
    (var_encoder): Linear(in_features=3, out_features=1, bias=True)
  )
  (decoder): DecoderSCVI(
    (px_decoder):