### Introduction

This is MetaABaCo: a Variational Autoencoder Batch Correction algorithm. It uses the encoder to extract only biologically relevant information from the data while keeping batch variability apart from it. This is achieved by using a batch discriminator for adversarial training: the discriminator tries to correctly classify the batch label from the encoder output (latent space), while the encoder tries to confuse the discriminator by leaving out as much batch information as possible. The decoder works as a data recovery step, reconstructing the data from the latent space. The prior distribution is a Mixture of Gaussians (MoG), which can help the latent space retain biological significance at all times. Biological significance is further preserved by training a classifier that distinguishes the confounding variable labels of the samples in the latent space. This is an attempt to turn the original ABaCo algorithm, which is based on a simple autoencoder, into a probabilistic approach that fits the data through divergences between distributions rather than MSE losses.

### Libraries

In [5]:
#Essentials
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal, kl_divergence, NegativeBinomial, Bernoulli, Categorical
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, Subset, ConcatDataset, TensorDataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from umap import UMAP

#User libraries
from BatchEffectDataLoader import DataPreprocess, DataTransform, ABaCoDataLoader
from BatchEffectCorrection import correctCombat
from BatchEffectPlots import plotPCA
from BatchEffectMetrics import kBET, iLISI, cLISI, ARI, ASW
from ABaCo import ABaCo, BatchDiscriminator, TissueClassifier, DataDiscriminator

### Variational Autoencoder class

In [6]:
# Encoder class for the ABaCo model
class ABaCoEncoder(nn.Module):
    """Encoder that maps an input vector to mixture-of-Gaussians parameters.

    The network is a three-layer ReLU MLP followed by three linear heads that
    produce, for each of ``n_comp`` mixture components, a mean and a
    log-variance of dimension ``d_z``, plus the mixture weights.

    Args:
        input_size: Number of input features.
        batch_size: Number of extra input columns appended to the features;
            the first linear layer expects ``input_size + batch_size`` columns.
        d_z: Dimensionality of the latent space.
        n_comp: Number of mixture components.
        hl1, hl2, hl3: Widths of the three hidden layers.
    """

    def __init__(self, input_size, batch_size, d_z, n_comp,
                 hl1 = 1028, hl2 = 512, hl3 = 256):
        super().__init__()
        self.input_size = input_size
        self.batch_size = batch_size
        self.d_z = d_z
        self.n_comp = n_comp

        # Build the Linear -> ReLU stack from consecutive layer widths.
        widths = [input_size + batch_size, hl1, hl2, hl3]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        self.encode = nn.Sequential(*stack)

        # One head per family of mixture parameters.
        latent_width = d_z * n_comp
        self.fc_mu = nn.Linear(hl3, latent_width)
        self.fc_logvar = nn.Linear(hl3, latent_width)
        self.fc_pi = nn.Linear(hl3, n_comp)

    def forward(self, x):
        """Return ``(mu, logvar, pi)`` for the input ``x``.

        ``mu`` and ``logvar`` are reshaped to ``(-1, d_z, n_comp)`` (component
        axis last); ``pi`` has ``n_comp`` columns and is softmax-normalized
        along its last axis.
        """
        hidden = self.encode(x)
        component_shape = (-1, self.d_z, self.n_comp)

        mu = self.fc_mu(hidden).view(*component_shape)
        logvar = self.fc_logvar(hidden).view(*component_shape)

        weight_logits = self.fc_pi(hidden)
        pi = weight_logits.softmax(dim=-1)
        return mu, logvar, pi

#Decoder class for the ABaCo model - ZINB is used, so the output would be the parameters of the distribution
class ABaCoDecoder(nn.Module):
    """Decoder that maps a latent sample to ZINB parameters.

    A latent vector ``z`` is drawn from the mixture-of-Gaussians parameters
    via the reparameterization trick, passed through a three-layer ReLU MLP,
    and projected by three heads into the parameters of a zero-inflated
    negative binomial distribution for each output feature.

    Args:
        output_size: Number of output features.
        d_z: Dimensionality of the latent space.
        n_comp: Number of mixture components.
        hl1, hl2, hl3: Widths of the three hidden layers.
    """

    def __init__(self, output_size, d_z, n_comp,
                 hl1 = 256, hl2 = 512, hl3 = 1028):
        super().__init__()
        self.d_z = d_z
        self.n_comp = n_comp
        self.output_size = output_size
        self.decode = nn.Sequential(
            nn.Linear(d_z, hl1),
            nn.ReLU(),
            nn.Linear(hl1, hl2),
            nn.ReLU(),
            nn.Linear(hl2, hl3),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(hl3, output_size)
        self.fc_logtheta = nn.Linear(hl3, output_size)
        self.fc_pi = nn.Linear(hl3, output_size)

    # Reparameterization trick to sample z ~ MoG(z | mu, logvar, pi)
    def reparameterize(self, mu, logvar, pi):
        """Draw a differentiable latent sample from the mixture parameters.

        Each component is sampled as ``mu + eps * std`` with ``eps ~ N(0, I)``
        and the per-component samples are combined as a ``pi``-weighted sum.

        Args:
            mu: Component means, shape ``(B, d_z, n_comp)`` -- component axis
                last, matching ``ABaCoEncoder.forward``'s
                ``view(-1, d_z, n_comp)``.
            logvar: Component log-variances, same shape as ``mu``.
            pi: Mixture weights, shape ``(B, n_comp)``.

        Returns:
            Tensor of shape ``(B, d_z)``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        component_samples = mu + eps * std

        # Broadcast the weights over the latent axis, then collapse the
        # component axis so the result fits the first decoder layer (d_z).
        weights = pi.unsqueeze(1)
        return torch.sum(weights * component_samples, dim=-1)

    def forward(self, mu, logvar, pi):
        """Sample ``z`` and decode it into ZINB parameters.

        Returns:
            Tuple ``(mu_nb, logtheta, pi_nb)``: the strictly positive mean,
            the log-dispersion, and the zero-inflation probability in
            ``(0, 1)``, each with ``output_size`` columns.
        """
        z = self.reparameterize(mu, logvar, pi)
        decoded = self.decode(z)
        mu_nb = torch.exp(self.fc_mu(decoded)) # Exp to ensure positive mu
        logtheta = self.fc_logtheta(decoded) # Log-dispersion of NB
        pi_nb = torch.sigmoid(self.fc_pi(decoded)) # Sigmoid keeps it a probability

        return mu_nb, logtheta, pi_nb

### Reconstruction loss for ELBO

OTU data are zero-inflated and over-dispersed, meaning that the variance is far greater than the mean. The Negative Binomial distribution accounts for over-dispersion and is assumed to be a good fit for OTU data. It is also used in scDREAMER to account for zero-inflation in single-cell RNA-seq data. We are going to apply a variation of it, the Zero-Inflated Negative Binomial distribution (ZINB), and try to model non-transformed OTU data with it.

In [None]:
# Define the negative log-likelihood of the zero-inflated negative binomial distribution
def zinb_nll(x, mu, logtheta, pi):
    """Return the summed negative log-likelihood of ``x`` under a ZINB.

    The zero-inflated negative binomial density is

        p(x) = pi * 1[x == 0] + (1 - pi) * NB(x | mu, theta)

    so a zero can come either from the inflation component or from the
    negative binomial itself, while a positive count can only come from the
    negative binomial.

    Args:
        x: Observed counts (non-negative integer values).
        mu: Mean of the negative binomial, positive, broadcastable to ``x``.
        logtheta: Log of the negative binomial dispersion ``theta``.
        pi: Zero-inflation probability in ``[0, 1]``.

    Returns:
        Scalar tensor: the negative log-likelihood summed over all elements.
    """
    tiny = 1e-8       # floor for mu so log(mu) stays finite
    pi_margin = 1e-6  # keeps pi representably inside (0, 1) in float32

    # torch's NegativeBinomial takes (total_count, probs | logits) and has
    # mean total_count * exp(logits); a mean of mu with dispersion theta is
    # therefore total_count = theta, logits = log(mu) - log(theta).
    theta = torch.exp(logtheta)
    nb_logits = torch.log(mu.clamp_min(tiny)) - logtheta
    nb_log_prob = NegativeBinomial(total_count=theta, logits=nb_logits).log_prob(x)

    pi = pi.clamp(pi_margin, 1.0 - pi_margin)
    log_pi = torch.log(pi)
    log_not_pi = torch.log1p(-pi)

    # Work in log space: log(pi + (1 - pi) * NB(0)) via logaddexp.
    ll_zero = torch.logaddexp(log_pi, log_not_pi + nb_log_prob)
    ll_positive = log_not_pi + nb_log_prob
    ll = torch.where(x == 0, ll_zero, ll_positive)

    return -torch.sum(ll)

### KL divergence for MoG

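Because we are using a MoG model as our prior, it is a bit tricky to actually compute the KL divergence of the posterior to the prior. For that, let's start by stating the definition of the KL divergence:

$$D_{KL}\big(q(z \mid x) \,\|\, p(z)\big) = \mathbb{E}_{q(z \mid x)}\big[\log q(z \mid x) - \log p(z)\big]$$

When $p(z)$ is a mixture of Gaussians this expectation has no closed form, so it is usually estimated by Monte Carlo sampling from $q(z \mid x)$.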

### Reconstructed data from ZINB distribution

After calculating the parameters of the ZINB, we are going to reconstruct the OTU data using it. 

In [7]:
def sample_from_zinb(mu, logtheta, pi):
    """Draw one sample per element from a zero-inflated negative binomial.

    Args:
        mu: Mean of the negative binomial, positive.
        logtheta: Log of the negative binomial dispersion ``theta``.
        pi: Zero-inflation probability in ``[0, 1]``.

    Returns:
        Tensor of sampled counts: an element is forced to zero with
        probability ``pi`` and is otherwise a negative binomial draw.
    """
    tiny = 1e-8  # floor for mu so log(mu) stays finite

    # torch's NegativeBinomial takes (total_count, probs | logits) and has
    # mean total_count * exp(logits); a mean of mu with dispersion theta is
    # therefore total_count = theta, logits = log(mu) - log(theta).
    theta = torch.exp(logtheta)
    nb_logits = torch.log(mu.clamp_min(tiny)) - logtheta
    nb_dist = NegativeBinomial(total_count=theta, logits=nb_logits)
    zinb_dist = Bernoulli(pi)

    #1st sample from Bernoulli: 1 marks a structural (inflated) zero
    bernoulli_sample = zinb_dist.sample()

    #2nd sample from NB
    nb_sample = nb_dist.sample()

    #3rd apply ZI to NB: structural zeros mask out the NB draw
    recon_data = (1 - bernoulli_sample) * nb_sample

    return recon_data

### Training loop

In [None]:
# Training loop for the model
def train_MetaABaCo(
        encoder: ABaCoEncoder,
        decoder: ABaCoDecoder,
        train_loader,
        num_epochs,
        device,
        lr_encode = 1e-5,
        lr_decode = 1e-5
):
    encoder_optim = torch.optim.Adam(encoder.parameters(), lr = lr_encode, weight_decay=1e-5)
    decoder_optim = torch.optim.Adam(decoder.parameters(), lr = lr_decode, weight_decay=1e-5)

    train_losses = []

    for epoch in range(num_epochs):
        # Training
        encoder.train()
        decoder.train()

        train_loss = 0

        for x, _, _ in train_loader:
            #Forward pass to encoder
            x = x.to(device)
            mu, logvar, pi = encoder(x)

            #Forward pass to decoder
            mu_nb, logtheta, pi_zinb = decoder(mu, logvar, pi)

            #Reconstructed OTU data
            recon_data = sample_from_zinb(mu_nb, logtheta, pi_zinb)

            #Reconstruction loss
            recon_loss = zinb_nll(x, mu_nb, logtheta, pi_zinb)
            
