### Libraries

In [1]:
#Essentials
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, Subset, ConcatDataset, TensorDataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from umap import UMAP

#User libraries
from BatchEffectDataLoader import DataPreprocess, DataTransform, ABaCoDataLoader
from BatchEffectPlots import plotPCA
from ABaCo import ABaCo, BatchDiscriminator, TissueClassifier

>> clustergrammer2 backend version 0.18.0


### Initialize ABaCo model

In [15]:
# Seed the global torch and NumPy RNGs *before* any network is built so that weight
# initialisation (and anything else drawing from these global RNGs later in the run)
# is reproducible under Restart Kernel -> Run All. Re-running this cell re-seeds,
# so it stays idempotent.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

model = ABaCo()

### Load data and build the DataLoader

In [14]:
# Read the sponge OTU table, run the preprocessing step, then the transform step.
path = "data/dataset_sponge.csv"
sponge_preprocessed = DataPreprocess(path)
data = DataTransform(sponge_preprocessed)

# Build the training DataLoader and unpack the encodings / metadata returned alongside it.
(
    abaco_dataloader,
    ohe_batch,
    ohe_tissue,
    otu_data,
    data_batch,
    data_tissue,
) = ABaCoDataLoader(data)

### Train model

In [16]:
#Setting up device
# GPU selection is disabled on purpose; uncomment the next line to use CUDA when available.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
num_epochs = 1000
print(f"Using {device}")

#Defining auxiliary networks
# Width of the latent space fed to the latent-space tissue classifier.
# NOTE(review): assumes this matches the latent size of the default ABaCo() — confirm.
d_z = 10

# The output classifier's input width must equal the model's output width. The
# corrected-output cell builds a DataFrame from the model output with
# columns=otu_data.columns, so that width is otu_data.shape[1]; derive it rather
# than hard-coding a number that silently breaks on a different dataset.
# NOTE(review): assumes out_class_model consumes the model's full output — confirm in ABaCo.
n_features = otu_data.shape[1]

# Put every auxiliary network on the same device. Previously only the discriminator
# was moved, which fails with a device mismatch as soon as `device` is a GPU.
batch_discriminator = BatchDiscriminator().to(device)
latent_classifier = TissueClassifier(input_size=d_z).to(device)
output_classifier = TissueClassifier(input_size=n_features).to(device)

#Training loop
# train_model returns three values; they are not used in this notebook.
# NOTE(review): in the saved log the Dis./Adv./Tri. losses are identical every epoch,
# which suggests train_model prints one value three times — verify.
_, _, _ = model.train_model(
    batch_model=batch_discriminator,
    latent_class_model=latent_classifier,
    out_class_model=output_classifier,
    train_loader=abaco_dataloader,
    tissue_ohe=ohe_tissue,
    num_epochs=num_epochs,
    w_recon=0.5,
    w_adver=1.0,
    w_disc=1.0,
    w_latent=1.0,
    w_output=1.0,
    device=device,
    model_name="ABC",
)

Using cpu
Epoch 1/1000 | Dis. Train Loss: 0.7064 | Adv. Train Loss: 0.7064 | Tri. Train Loss: 0.7064
Epoch 2/1000 | Dis. Train Loss: 0.6963 | Adv. Train Loss: 0.6963 | Tri. Train Loss: 0.6963
Epoch 3/1000 | Dis. Train Loss: 0.6915 | Adv. Train Loss: 0.6915 | Tri. Train Loss: 0.6915
Epoch 4/1000 | Dis. Train Loss: 0.6880 | Adv. Train Loss: 0.6880 | Tri. Train Loss: 0.6880
Epoch 5/1000 | Dis. Train Loss: 0.6841 | Adv. Train Loss: 0.6841 | Tri. Train Loss: 0.6841
Epoch 6/1000 | Dis. Train Loss: 0.6804 | Adv. Train Loss: 0.6804 | Tri. Train Loss: 0.6804
Epoch 7/1000 | Dis. Train Loss: 0.6741 | Adv. Train Loss: 0.6741 | Tri. Train Loss: 0.6741
Epoch 8/1000 | Dis. Train Loss: 0.6694 | Adv. Train Loss: 0.6694 | Tri. Train Loss: 0.6694
Epoch 9/1000 | Dis. Train Loss: 0.6680 | Adv. Train Loss: 0.6680 | Tri. Train Loss: 0.6680
Epoch 10/1000 | Dis. Train Loss: 0.6720 | Adv. Train Loss: 0.6720 | Tri. Train Loss: 0.6720
Epoch 11/1000 | Dis. Train Loss: 0.6739 | Adv. Train Loss: 0.6739 | Tri. Train 


reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.



Epoch 23/1000 | Dis. Train Loss: 0.5984 | Adv. Train Loss: 0.5984 | Tri. Train Loss: 0.5984
Epoch 24/1000 | Dis. Train Loss: 0.6149 | Adv. Train Loss: 0.6149 | Tri. Train Loss: 0.6149
Epoch 25/1000 | Dis. Train Loss: 0.6335 | Adv. Train Loss: 0.6335 | Tri. Train Loss: 0.6335
Epoch 26/1000 | Dis. Train Loss: 0.6571 | Adv. Train Loss: 0.6571 | Tri. Train Loss: 0.6571
Epoch 27/1000 | Dis. Train Loss: 0.6983 | Adv. Train Loss: 0.6983 | Tri. Train Loss: 0.6983
Epoch 28/1000 | Dis. Train Loss: 0.7600 | Adv. Train Loss: 0.7600 | Tri. Train Loss: 0.7600
Epoch 29/1000 | Dis. Train Loss: 0.8180 | Adv. Train Loss: 0.8180 | Tri. Train Loss: 0.8180
Epoch 30/1000 | Dis. Train Loss: 0.8594 | Adv. Train Loss: 0.8594 | Tri. Train Loss: 0.8594
Epoch 31/1000 | Dis. Train Loss: 0.8806 | Adv. Train Loss: 0.8806 | Tri. Train Loss: 0.8806
Epoch 32/1000 | Dis. Train Loss: 0.8865 | Adv. Train Loss: 0.8865 | Tri. Train Loss: 0.8865
Epoch 33/1000 | Dis. Train Loss: 0.8730 | Adv. Train Loss: 0.8730 | Tri. Train L

### Create DataFrame with batch-corrected output

In [17]:
# Iterate the same dataset in a fixed, unshuffled order so that row i of the model
# output lines up with row i of otu_data. The training loader may shuffle, which
# would silently attach corrected rows to the wrong sample index.
inference_loader = DataLoader(
    abaco_dataloader.dataset,
    batch_size=abaco_dataloader.batch_size or 1,
    shuffle=False,
)

# Inference: no autograd graph, and eval() so dropout/BatchNorm (if any) behave
# deterministically. The previous train/eval mode is restored afterwards so that
# re-running the training cell is unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        corrected_chunks = [model(x).cpu() for x, _ in inference_loader]
finally:
    model.train(was_training)

# Concatenate along the sample axis. Unlike np.array over a list of per-batch lists,
# this also works when the last batch is smaller than the others. Cast to float64
# to keep the dtype the earlier list-based conversion produced.
batch_corrected = torch.cat(corrected_chunks, dim=0).numpy().astype(np.float64)
assert batch_corrected.shape[0] == len(otu_data), "expected one corrected row per sample"

corrected_pd = pd.concat(
    [
        pd.DataFrame(batch_corrected, index=otu_data.index, columns=otu_data.columns),
        data_batch,
        data_tissue,
        data["sample"],
    ],
    axis=1,
)

In [18]:
# PCA view of the transformed but *uncorrected* data — the "before" picture of batch effects.
plotPCA(data)

In [19]:
# Same PCA view on the batch-corrected table, to compare against the uncorrected plot above.
plotPCA(corrected_pd)