### Introduction

This is MetaABaCo: A Variational Autoencoder Batch Correction algorithm. It uses the Encoder to extract only biological relevant information from the data while keeping apart batch variability from it. This is achieved by using a batch discriminator for adversarial training. The batch would try to correcly classify the batch label from the encoder output (latent space) while the encoder would try to confuse the discriminator by leaving appart as much batch information as possible. The decoder would work as a data recovery step, where it would reconstruct the data from the latent space. The prior distribution used is the Mixture of Gaussians (MoG) that can be useful for the latent space to retain the biological significance at all times. The biological significance is also assured to be retained by training a classifier that is able to distinguish the cofounding variable labels from the samples on the latent space. This is an attempt to make the original ABaCo algorithm, which is based on a simple autoencoder, into a probabilistic approach that tries to resemble the data on distributions divergence rather than MSE losses. 

### Libraries

In [None]:
#Essentials
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal, kl_divergence, NegativeBinomial, Bernoulli, Categorical
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, Subset, ConcatDataset, TensorDataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from umap import UMAP

#User libraries
from BatchEffectDataLoader import DataPreprocess, DataTransform, ABaCoDataLoader
from BatchEffectCorrection import correctCombat
from BatchEffectPlots import plotPCA, plotPCoA
from BatchEffectMetrics import kBET, iLISI, cLISI, ARI, ASW
from ABaCo import ABaCo, BatchDiscriminator, TissueClassifier, DataDiscriminator
from MetaABaCo import NormalPrior, NormalEncoder, ZINBDecoder

### Variational Autoencoder class

In [None]:
# Encoder class for the ABaCo model - MoG parameters are encoded
class ABaCoEncoder(nn.Module):
    def __init__(self, input_size, batch_size, d_z,
                 hl1 = 1028, hl2 = 512, hl3 = 256):
        super().__init__()
        self.d_z = d_z
        # self.n_comp = n_comp
        self.input_size = input_size
        self.batch_size = batch_size
        self.encode = nn.Sequential(
            nn.Linear(input_size + batch_size, hl1),
            nn.LeakyReLU(),
            nn.Linear(hl1, hl2),
            nn.LeakyReLU(),
            nn.Linear(hl2, hl3),
            nn.LeakyReLU()
        )
        ### FOR MOG PRIOR ###
        # self.fc_pi = nn.Linear(hl3, n_comp)
        # self.fc_mu = nn.Linear(hl3, d_z * n_comp)
        # self.fc_logvar = nn.Linear(hl3, d_z * n_comp)

        ### FOR NORMAL PRIOR ###
        self.fc_mu = nn.Linear(hl3, d_z)
        self.fc_logvar = nn.Linear(hl3, d_z)
    
    def forward(self, x):
        encoded = self.encode(x)

        ### FOR MOG PRIOR ###
        # mu = self.fc_mu(encoded).view(-1, self.d_z, self.n_comp)
        # logvar = self.fc_logvar(encoded).view(-1, self.d_z, self.n_comp)
        # pi = nn.functional.softmax(self.fc_pi(encoded), dim=-1).clamp(min=1e-8)
        # Re-normalize to ensure the probabilities sum to 1:
        # pi = pi / pi.sum(dim=-1, keepdim=True)
        # return mu, logvar, pi

        ### FOR NORMAL PRIOR ###
        mu = self.fc_mu(encoded)
        logvar = self.fc_logvar(encoded)

        return mu, logvar

#Decoder class for the ABaCo model - ZINB is used, so the output would be the parameters of the distribution
class ABaCoDecoder(nn.Module):
    def __init__(self, output_size, d_z, n_comp,
                 hl1 = 256, hl2 = 512, hl3 = 1028):
        super().__init__()
        self.d_z = d_z
        self.n_comp = n_comp
        self.output_size = output_size
        self.decode = nn.Sequential(
            nn.Linear(d_z, hl1),
            nn.LeakyReLU(),
            nn.Linear(hl1, hl2),
            nn.LeakyReLU(),
            nn.Linear(hl2, hl3),
            nn.LeakyReLU()
        )
        self.fc_mu = nn.Linear(hl3, output_size)
        self.fc_logtheta = nn.Linear(hl3, output_size)
        self.fc_pi = nn.Linear(hl3, output_size)
    
    # Reparameterization trick to sample z ~ MoG(z | mu, logvar, pi)
    def reparameterize(self, mu, logvar):

        ### FOR MOG PRIOR ###

        #   mu and logvar are of shape (batch_size, d_z, n_comp)
        #   pi is of shape (batch_size, n_comp) and is already softmaxed.
        # batch_size = mu.size(0)
        
        # # Sample one component index per sample (shape: (batch_size, 1))
        # comp_idx = torch.multinomial(pi, num_samples=1).squeeze(1)
        
        # # Create batch indices
        # batch_idx = torch.arange(batch_size)
        
        # # Select the corresponding mu and logvar for each sample's chosen component.
        # chosen_mu = mu[batch_idx, :, comp_idx]       # shape: (batch_size, d_z)
        # chosen_logvar = logvar[batch_idx, :, comp_idx] # shape: (batch_size, d_z)
        
        # # Reparameterize
        # eps = torch.randn_like(chosen_mu)
        # z = chosen_mu + eps * torch.exp(0.5 * chosen_logvar)

        ### FOR GAUSSIAN NORMAL PRIOR ###
        std = torch.exp(0.5 * logvar)
        eps = torch.radn_like(std)
        z = mu + eps*std
    
        return z

    
    def forward(self, mu, logvar):

        """
        Outputs the parameters of the ZINB distribution: x ~ ZINB(x | mu, theta, pi)
            mu_nb: mean of the NB distribution parameter
            theta: total counts in the NB distribution
            pi_nb: probability of zero-inflation 
        """

        ### FOR MOG PRIOR ###
        # z = self.reparameterize(mu, logvar, pi)

        ### FOR GAUSSIAN NORMAL PRIOR ###
        z = self.reparameterize(mu, logvar)

        decoded = self.decode(z)

        mu_nb = nn.functional.softplus(self.fc_mu(decoded)) # softplus ensure positive mu
        mu_nb = torch.clamp(mu_nb, min=1e-8, max=1e6)

        theta = nn.functional.softplus(self.fc_logtheta(decoded)) # Log-dispersion of NB always positive
        theta = torch.clamp(theta, min=1e-8, max=1e6)
        
        pi_nb = torch.sigmoid(self.fc_pi(decoded)) # Probability of zero-inflation: sigmoid to assure values in [0,1]

        return z, mu_nb, theta, pi_nb

### Batch discriminator class

In [12]:
class BatchDiscriminator(nn.Module):
    def __init__(self,
                 input_size,
                 batch_size,
                 tissue_size,
                 hl1_size = 128,
                 hl2_size = 64
                 ):
        super().__init__()
        self.input_size = input_size
        self.hl1_size = hl1_size
        self.hl2_size = hl2_size
        self.batch_size = batch_size
        self.tissue_size = tissue_size
        self.ffnn = nn.Sequential(
            nn.Linear(input_size + tissue_size, hl1_size),
            nn.ReLU(),
            nn.Linear(hl1_size, hl2_size),
            nn.ReLU(),
            nn.Linear(hl2_size, batch_size)
        )

    def forward(self, x):
        x = x.view(-1, self.input_size + self.tissue_size)
        y = self.ffnn(x)
        return y

### Reconstruction loss for ELBO

OTU data have zero-inflation and over-dispersion, which means that the variance is by far greater than the mean. The Negative Binomial distribution accounts for that and it's assume to be a good fit to model the OTU data. This is also used in the scDREAMER to account for zero-inflation in single cell RNA seq data. We are going to try to apply a variation of it called Zero-Inflated Negative Binomial distribution (ZINB), and we are going to try model it with non-transformed OTU data. The NB probability density function is defined as follows:

$$
NB(x | \mu, \theta) = \frac{\Gamma(x + \theta)}{\Gamma(\theta)\Gamma(x + 1)} \left( \frac{\theta}{\theta + \mu}\right)^\theta \left( \frac{\mu}{\theta + \mu}\right)^x
$$

Which the log probability density function would look as following:

$$
\mathrm{log} NB(x|\mu, \theta) = \mathrm{log} (\Gamma(x + \theta)) - \mathrm{log}(\Gamma(\theta)) - \mathrm{log}(\Gamma(x + 1)) + \theta ( \mathrm{log} (\theta) - \mathrm{log} (\theta + \mu)) + x ( \mathrm{log} (\mu) - \mathrm{log}(\theta + \mu))
$$

Only if x > 0. If x = 0:

$$
\mathrm{log} NB(0|\mu, \theta) = \theta ( \mathrm{log} (\theta) - \mathrm{log} (\theta + \mu))
$$


This would be important to define because the ZINB probability distribution is defined by the following:

$$
ZINB(x | \mu, \theta, \pi) =        \begin{cases} 
                                        \pi + (1 - \pi) NB(x | \mu, \theta), \quad x > 0 \\
                                        (1 - \pi) NB(x | \mu, \theta), \quad \quad \quad x = 0
                                    \end{cases}
$$

In [13]:
# Define the negative log-likelihood of the zero-inflated negative binomial distribution
def zinb_nll(x, mu, theta, pi, eps = 1e-8):
    
    #First: define the NB for x > 0
    nb_case = (
        torch.lgamma(x + theta)
        - torch.lgamma(theta + eps)
        - torch.lgamma(x + 1)
        + theta * (torch.lgamma(theta + eps) - torch.log(theta + mu + eps))
        + x * (torch.log(mu + eps) - torch.log(mu + theta + eps))
    )

    #Second: define the NB for x = 0
    nb_zero = theta * (torch.log(theta + eps) - torch.log(theta + mu + eps))

    #Third: introduce ZINB log-likelihood
    ll = torch.where(
        x < eps,
        torch.log(pi + (1.0 - pi) * torch.exp(nb_zero) + eps),
        torch.log(1.0 - pi + eps) + nb_case
    )

    #Fourth: negative log-likelihood summed for every observation
    nll = -torch.sum(ll, dim = -1) #nll per sample
    return torch.mean(nll) #average from batch

### KL divergence for MoG

Because we are using a MoG model as our prior is a bit tricky to actually compute the KL-divergence of the posterior to the prior. First, let's define the prior distribution:

In [14]:
class MoGPrior(nn.Module):
    def __init__(self, d_z, n_comp, multiplier = 1.0):
        super().__init__()
        self.d_z = d_z
        self.n_comp = n_comp
        self.multiplier = multiplier

        # Params
        self.means = nn.Parameter(torch.randn(n_comp, self.d_z)*self.multiplier)
        self.logvars = nn.Parameter(torch.randn(n_comp, self.d_z))

        #Mix weights
        self.w = nn.Parameter(torch.zeros(n_comp, 1, 1))
    
    def get_params(self):
        return self.means, self.logvars
    
    def sample(self, batch_size):
        # mu, logvar
        means, logvars = self.get_params()

        # mixing probabilities
        w = nn.functional.softmax(self.w, dim=0)
        w = w.squeeze()

        # pick components
        indexes = torch.multinomial(w, batch_size, replacement=True)

        # means and logvars
        eps = torch.randn(batch_size, self.d_z)
        for i in range(batch_size):
            indx = indexes[i]
            if i == 0:
                z = means[[indx]] + eps[[i]] * torch.exp(logvars[[indx]])
            else:
                z = torch.cat((z, means[[indx]] + eps[[i]] * torch.exp(logvars[[indx]])), 0)
        
        return z
    
    def log_normal_diag(self, z, means, logvars):
        # var
        vars = torch.exp(logvars)

        # difference between z and means
        diff = z - means

        # log probability for each gaussian component
        # log p(x) = -0.5 * sum(log(2*pi) + log(var) + (x - mu)^2 / var)
        log_p = -0.5 * (torch.sum(diff ** 2 / vars, dim = -1) + torch.sum(logvars, dim = -1))
        log_p -= 0.5 * z.size(-1) * torch.log(torch.tensor(2 * torch.pi, device = z.device, dtype = z.dtype))

        return log_p

    
    def log_prob(self, z):
        # mu, logvar
        means, logvars = self.get_params()

        # mixing probabilities
        w = nn.functional.softmax(self.w, dim = 0)
        
        # log-MoG
        z = z.unsqueeze(0)
        means = means.unsqueeze(1)
        logvars = logvars.unsqueeze(1)

        log_p = self.log_normal_diag(z, means, logvars) + torch.log(w)
        log_prob = torch.logsumexp(log_p, dim = 0, keepdim=False)

        return log_prob

The KL-divergence is defined as follows:

$$
\mathrm{KL}[q(z | x) || p(z)] = \mathbb{E}_{z \sim q(z|x)}[\mathrm{log}q(z | x) - \mathrm{log}p(z)]
$$

We have already define how to calculate $\mathrm{log} p(z)$ from the prior's function log_prob, so we are going to focus on calculating $\mathrm{log} q(z | x)$:

In [15]:
def compute_log_qzx(z, mu, logvar, pi):

    # size of parameters: (batch_size, latent dimension, number of components)
    batch_size, d_z, n_comp = mu.shape
    
    # expand z variable for computing the density
    z = z.unsqueeze(2)

    # Log density function
    log_norm = -0.5 * (
        d_z * torch.log(torch.tensor(2 * torch.pi, device = z.device, dtype = z.dtype)) +
        torch.sum(logvar, dim=1) +
        torch.sum((z - mu)**2 / torch.exp(logvar), dim = 1)
    )

    # Add the log mixing weights (add small offset to avoid log(0))
    log_comp = torch.log(torch.tensor(pi + 1e-10, device = z.device, dtype = z.dtype)) + log_norm 

    return torch.logsumexp(log_comp, dim=1)

### KL-divergence for Gaussian Normal prior

Just as a proof of concept, we define a Normal prior to check that the computation of the KL divergence and latent space variables are done as expected (had some issues when using the MoG prior). If so, the next step would be to correct the MoG prior definitions and KL-divergence computation (it is recommended to use kl_divergence() from torch.distributions).

In [None]:
class NormalPrior(nn.Module):
    def __init__(self, d_z):
        self.d_z = d_z
        self.mean = nn.Parameter(torch.zeros(self.d_z), requires_grad=False)
        self.std = nn.Parameter(torch.ones(self.d_z), requires_grad=False)

    def forward(self):
        
        return torch.distributions.Independent(Normal(loc=self.mean, scale=self.std), 1)

### Sample ZINB with learned parameters

It is really hard to sample directly from the ZINB distribution, so we are going to use what's called the Gamma-Poisson trick:

We start with a variable $\lambda$ that follows a Gamma distribution given the parameters shape $\theta$ and scale $\beta = \mu / \theta$:
$$
\lambda \sim \mathrm{Gamma}(\theta, \beta)
$$

Where the pdf would be:
$$
p(\lambda) = \frac{1}{\Gamma(\theta) \beta^\theta} \lambda^{\theta - 1} e^{-\lambda / \beta}
$$
$$
p(\lambda) = \frac{1}{\Gamma(\theta)} \left(\frac{\theta}{\mu} \right)^\theta \lambda ^ {\theta - 1} e ^{-\lambda \theta / \mu}
$$

Now we assume that our data $x$ given the value $\lambda$ follows a Poisson distribution, so that:

$$
x | \lambda \sim \mathrm{Poisson} (\lambda)
$$

Where the pdf would be:
$$
p(x | \lambda) = \frac{\lambda^x e^-\lambda}{x!}
$$

We can marginalize the probability of seeing our data $x$ by integrating over all possible values of $\lambda$:
$$
p(x) = \int_0^\infty p(x | \lambda) p(\lambda) d\lambda
$$

Which substituting for our define pdf, we get:
$$
p(x) = \frac{1}{x!\Gamma(\theta)} \left(\frac{\theta}{\mu} \right)^\theta \int_0^\infty \lambda ^ {\theta - 1} e ^{-\lambda \theta / \mu} \lambda^x e^{-\lambda}
$$
$$
p(x) = \frac{1}{x!\Gamma(\theta)} \left(\frac{\theta}{\mu} \right)^\theta \int_0^\infty \lambda ^ {x + \theta - 1} e ^{-\lambda \left( 1 + \frac{\theta}{\mu} \right)} 
$$

Integrating we get:
$$
p(x) = \frac{1}{x!\Gamma(\theta)} \left(\frac{\theta}{\mu} \right)^\theta \Gamma(x + \theta) \left( 1 + \frac{\theta}{\mu}\right)^{-(x + \theta)}
$$

$$
p(x) = \frac{\Gamma(x + \theta)}{x!\Gamma(\theta)} \left(\frac{\theta}{\mu} \right)^\theta \left( 1 + \frac{\theta}{\mu}\right)^{-(x + \theta)}
$$

$$
p(x) = \frac{\Gamma(x + \theta)}{x!\Gamma(\theta)} \left(\frac{\theta}{\mu} \right)^\theta \left(\frac{\mu}{\theta + \mu}\right)^{x + \theta}
$$

$$
p(x) = \frac{\Gamma(x + \theta)}{x!\Gamma(\theta)} \left(\frac{\theta}{\theta + \mu} \right)^\theta \left(\frac{\mu}{\theta + \mu}\right)^{x} = NB(x)
$$

So, for the case where $\pi$ > 0 on the ZINB distribution:
$$
x \sim \mathrm{Poisson}(\lambda), \quad \lambda \sim \Gamma(\theta, \mu / \theta)
$$

In [16]:
def sample_zinb(mu, theta, pi, eps=1e-8):
    # generate random number from [0, 1) as threshold for dropout 
    drop_rand = torch.rand_like(pi)
    drop = drop_rand < pi # True means probability of zero value is higher than threshold, so x = 0

    # sample from NB through Gamma-Poisson trick
    gamma_theta = theta + eps
    gamma_beta = mu / gamma_theta
    gamma_sample = torch._standard_gamma(gamma_theta) * gamma_beta
    nb_sample = torch.poisson(gamma_sample)

    # ZINB with drop being True
    zinb_sample = torch.where(drop, torch.zeros_like(nb_sample), nb_sample)
    return zinb_sample

### Training loop

In [85]:
# Training loop for the model
def train_MetaABaCo(
        encoder: ABaCoEncoder,
        decoder: ABaCoDecoder,
        discriminator: BatchDiscriminator,
        mog_prior: nn.Module,
        train_loader,
        ohe_exp_loader,
        num_epochs,
        batch_size,
        device,
        lr_encode = 1e-5,
        lr_decode = 1e-5,
        lr_batch = 1e-5,
        lr_adver = 1e-5,
        kl_beta = 1.0,
        w_batch = 1.0,
        w_adv = 1.0
):
    encoder_optim = torch.optim.Adam(encoder.parameters(), lr = lr_encode, weight_decay=1e-5)
    decoder_optim = torch.optim.Adam(decoder.parameters(), lr = lr_decode, weight_decay=1e-5)
    batch_optim = torch.optim.Adam(discriminator.parameters(), lr = lr_batch, weight_decay=1e-5)
    adversarial_optim = torch.optim.Adam(encoder.parameters(), lr = lr_adver, weight_decay=1e-5)

    # Define loss functions for adversarial training
    disc_criterion = nn.CrossEntropyLoss()

    train_losses = []

    for epoch in range(num_epochs):
        # Training
        encoder.train()
        decoder.train()

        train_loss = 0

        for (x, _, _), (ohe_exp) in zip(train_loader, ohe_exp_loader):
            #Forward pass to encoder
            x = x.to(device)
            x = x.to(torch.float64)

            ohe_exp = ohe_exp.to(torch.float64)
            
            ohe_batch = x[:, -batch_size:]
            
            mu, logvar, pi = encoder(x)

            #Forward pass to decoder
            z, mu_nb, theta, pi_zinb = decoder(mu, logvar, pi)
            
            #Forward pass to batch discriminator
            pred_batch = discriminator(torch.concat((z.detach(), ohe_exp), 1))

            #CE loss for discriminator
            disc_loss = w_batch*disc_criterion(pred_batch, ohe_batch)

            #Backpropagation - Batch discriminator
            batch_optim.zero_grad()
            disc_loss.backward()
            batch_optim.step()

            #Forward pass to batch discriminator a second time 
            pred_adv = discriminator(torch.concat((z, ohe_exp), 1))
            
            #Adversarial loss - Negative Cross Entropy loss
            adv_loss = -w_adv*torch.mean(torch.sum(ohe_batch * torch.log(pred_adv), dim=1))

            #Backpropagation - Encoder
            encoder_optim.zero_grad()
            adv_loss.backward()
            encoder_optim.step()

            #Detach encoder and decoder outputs and recalculate
            x = x.detach()
            mu, logvar, pi = encoder(x)
            z, mu_nb, theta, pi_zinb = decoder(mu, logvar, pi)

            #Reconstructed OTU data
            # recon_data = sample_zinb(mu_nb, logtheta, pi_zinb)

            #Reconstruction loss
            recon_loss = zinb_nll(x[:,:-batch_size], mu_nb, theta, pi_zinb)
            #print(f"Reconstruction loss: {recon_loss}")

            #Compute prior log probability
            log_pz = mog_prior.log_prob(z)

            #Compute log density under posterior q(z | x) 
            log_qzx = compute_log_qzx(z, mu, logvar, pi)

            #KL divergence
            kl_div = kl_beta * torch.mean(log_qzx - log_pz)
            #print(f"KL-divergence: {kl_div}")

            # Total loss - ELBO
            loss = recon_loss + kl_div
            #print(f"Total loss: {loss}")

            # Backward pass - ELBO
            encoder_optim.zero_grad()
            decoder_optim.zero_grad()
            loss.backward()
            encoder_optim.step()
            decoder_optim.step()

            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        train_losses.append(train_loss)
        print(f"Epoch {epoch + 1}/{num_epochs},  Recon. Loss: {recon_loss:.4f}   |   KL-Div: {kl_div:.4f}    |    Loss: {train_loss:.4f}")
    
    return train_losses


### The data - MGnify DTU study

In [102]:
file = "metadataset_w_study_dataset_all_biomes_merged_abund_tables_genus.csv"
path = f"data/MGnify/datasets/{file}"

batch_label = "instrument_platform"
exp_label = "biomes"
drop_cols = ["experiment_type", "study_id", "centre_name", "index"]

raw_data = DataPreprocess(path, factors = ["sample", batch_label, exp_label]).dropna().reset_index(drop=True) #drop samples without meta info
pre_data = raw_data[raw_data["centre_name"]=="DTU-GE"].reset_index(drop=True)
new_data = pre_data.drop(drop_cols, axis=1)
new_data[exp_label] = new_data[exp_label].str.replace("root:Engineered:", "", regex=False) #remove redundant label

#exclude batches that can't be corrected
data = new_data
data = new_data[new_data[exp_label] != "Wastewater"]

#plot PCA of data to visualize it
plotPCoA(data, method="aitchison", sample_label="sample", batch_label=batch_label, experiment_label=exp_label)

### Prepare data for MetaABaCo

In [104]:
#Setting up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
print(f"Using {device}")

otu_dataloader, ohe_batch, ohe_biome, otu_data, otu_batch, otu_biome = ABaCoDataLoader(data, 
                                                                                       device = device, 
                                                                                       batch_label=batch_label, 
                                                                                       exp_label=exp_label, 
                                                                                       batch_size = 145, 
                                                                                       total_size = 3326, 
                                                                                       total_batch=2)

Using cuda


### Train MetaABaCo

In [134]:
num_epochs = 500

d_z = 16
encoder_model = ABaCoEncoder(d_z = d_z, input_size=3326, batch_size=2, n_comp=1).to(device)
encoder_model = encoder_model.double()
decoder_model = ABaCoDecoder(d_z = d_z, n_comp = 1, output_size=3326).to(device)
decoder_model = decoder_model.double()
batch_model = BatchDiscriminator(input_size = d_z, batch_size = 2, tissue_size = 1).to(device)
batch_model = batch_model.double()
prior = MoGPrior(d_z = d_z, n_comp = 1).to(device)

# Train
mabaco_loss = train_MetaABaCo(encoder = encoder_model,
                              decoder = decoder_model,
                              discriminator=batch_model,
                              mog_prior = prior,
                              train_loader=otu_dataloader,
                              ohe_exp_loader=ohe_biome,
                              num_epochs=num_epochs,
                              batch_size=2,
                              device=device,
                              lr_encode=1e-5,
                              lr_decode=1e-5,
                              kl_beta=1,
                              w_batch=1,
                              w_adv=1)


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Epoch 1/500,  Recon. Loss: -303069237148590.5625   |   KL-Div: 802526731757718189046462525283336737656821832774761330134357073495945551436750234000348478472262925974845305084390560031920730323439486202030465008056086064741359691436013163330404352.0000    |    Loss: 802526731757718189046462525283336737656821832774761330134357073495945551436750234000348478472262925974845305084390560031920730323439486202030465008056086064741359691436013163330404352.0000
Epoch 2/500,  Recon. Loss: -300909179754422.2500   |   KL-Div: 962266887939979960553885106432964106090291882801410667249821045717724688105477491161940837468476227840144992546990793735410555473182845441601131408037324613144385693761618145208958976.0000    |    Loss: 962266887939979960553885106432964106090291882801410667249821045717724688105477491161940837468476227840144992546990793735410555473182845441601131408037324613144385693761618145208958976.0000
Epoch 3/500,  Recon. Loss: -303822138050609.8125   |   KL-Div: 648901285258554874867365956

In [147]:
prior.get_params()

(Parameter containing:
 tensor([[-0.2649, -0.2112,  2.1581, -0.1379, -0.0099,  1.0127,  1.5908,  1.8094,
          -1.2034,  0.3217,  1.0955,  0.5312,  0.1682, -1.1402,  1.4156,  0.1248]],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([[-0.0463, -1.9337, -0.6377, -1.4303,  1.0016, -0.8430, -1.9405, -0.1447,
          -1.9474, -0.5752,  0.1468, -1.5211,  0.1208, -0.4031,  0.9079,  0.3954]],
        device='cuda:0', requires_grad=True))

### Reconstructed data

In [135]:
out_params = []

for x, _, _ in otu_dataloader:
    x = x.to(device)
    x = x.to(torch.float64)
    z_mu, z_logvar, z_pi = encoder_model(x)
    z, mu, theta, pi = decoder_model(z_mu, z_logvar, z_pi)
    out_params.append([mu, theta, pi])

In [136]:
recon = []
for mu, theta, pi in out_params:
    sample = sample_zinb(mu, theta, pi)
    recon.append(sample.tolist())


In [137]:
recon_pd = []
for batch in recon:
    for x in batch:
        recon_pd.append(x)

recon_pd = np.array(recon_pd)  # Convert list to NumPy array
recon_pd = recon_pd.reshape(-1, recon_pd.shape[-1])

recon_pd = pd.concat([pd.DataFrame(recon_pd, index = otu_data.index, columns = otu_data.columns),
                          otu_batch,
                          otu_biome,
                          data["sample"]],
                          axis=1)

plotPCoA(recon_pd, method="aitchison", sample_label="sample", batch_label=batch_label, experiment_label=exp_label)

In [138]:
recon_pd

Unnamed: 0,OTU1,OTU2,OTU3,OTU4,OTU5,OTU6,OTU7,OTU8,OTU9,OTU10,...,OTU3320,OTU3321,OTU3322,OTU3323,OTU3324,OTU3325,OTU3326,instrument_platform,biomes,sample
167,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Illumina HiSeq 3000,Wastewater:Water and sludge,ERR1713331
168,0.0,0.0,0.0,0.0,0.0,0.0,0.0,17689.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Illumina HiSeq 3000,Wastewater:Water and sludge,ERR1713332
169,0.0,0.0,0.0,0.0,0.0,0.0,0.0,62.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Illumina HiSeq 3000,Wastewater:Water and sludge,ERR1725942
170,1.0,0.0,0.0,0.0,1.0,0.0,0.0,2.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Illumina HiSeq 3000,Wastewater:Water and sludge,ERR1725946
171,0.0,0.0,0.0,0.0,0.0,0.0,0.0,235.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Illumina HiSeq 3000,Wastewater:Water and sludge,ERR1725948
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
307,0.0,0.0,3.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Illumina MiSeq,Wastewater:Water and sludge,ERR1512992
308,2.0,4.0,0.0,2.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,Illumina MiSeq,Wastewater:Water and sludge,ERR1512999
309,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,1.0,0.0,Illumina MiSeq,Wastewater:Water and sludge,ERR1513000
310,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Illumina MiSeq,Wastewater:Water and sludge,ERR1513001


In [146]:
combat_data = correctCombat(data, sample_label="sample", batch_label=batch_label, experiment_label=exp_label)
combat_data

Found 2 batches.
Adjusting for 0 covariate(s) or covariate level(s).
Standardizing Data across genes.
Fitting L/S model and finding priors.
Finding parametric adjustments.
Adjusting the Data



invalid value encountered in divide



Unnamed: 0,sample,instrument_platform,biomes,OTU1,OTU2,OTU3,OTU4,OTU5,OTU6,OTU7,...,OTU3317,OTU3318,OTU3319,OTU3320,OTU3321,OTU3322,OTU3323,OTU3324,OTU3325,OTU3326
167,ERR1713331,Illumina HiSeq 3000,Wastewater:Water and sludge,,,,,,,,...,,,,,,,,,,
168,ERR1713332,Illumina HiSeq 3000,Wastewater:Water and sludge,,,,,,,,...,,,,,,,,,,
169,ERR1725942,Illumina HiSeq 3000,Wastewater:Water and sludge,,,,,,,,...,,,,,,,,,,
170,ERR1725946,Illumina HiSeq 3000,Wastewater:Water and sludge,,,,,,,,...,,,,,,,,,,
171,ERR1725948,Illumina HiSeq 3000,Wastewater:Water and sludge,,,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
307,ERR1512992,Illumina MiSeq,Wastewater:Water and sludge,,,,,,,,...,,,,,,,,,,
308,ERR1512999,Illumina MiSeq,Wastewater:Water and sludge,,,,,,,,...,,,,,,,,,,
309,ERR1513000,Illumina MiSeq,Wastewater:Water and sludge,,,,,,,,...,,,,,,,,,,
310,ERR1513001,Illumina MiSeq,Wastewater:Water and sludge,,,,,,,,...,,,,,,,,,,
