### Libraries

In [1]:
#Essentials
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, Subset, ConcatDataset, TensorDataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from umap import UMAP

#User libraries
from BatchEffectDataLoader import DataPreprocess, DataTransform
from BatchEffectPlots import plotPCA
from ABaCo import ABaCo, BatchDiscriminator, TissueClassifier

>> clustergrammer2 backend version 0.18.0


### Initialize ABaCo model

In [2]:
model = ABaCo()

### DataLoader

In [3]:
#Function to One-Hot encoding
def one_hot_encoding(labels):
    # Dictionary of batch labels
    alphabet = labels.unique()
    label_to_int = {label: i for i, label in enumerate(alphabet)}

    # Initialize the one-hot encoded matrix
    one_hot = np.zeros((len(labels), len(alphabet)), dtype=int)

    # Fill the matrix
    for i, label in enumerate(labels):
        if label in label_to_int:
            one_hot[i, label_to_int[label]] = 1

    return torch.tensor(one_hot), alphabet.tolist()

#Function for classes to int
def class_to_int(labels):
    # Dictionary of batch labels
    alphabet = labels.unique()
    label_to_int = {label: i for i, label in enumerate(alphabet)}

    # Initialize the empty array
    classes = np.zeros(len(labels), dtype=int)

    # Fill the matrix
    for i, label in enumerate(labels):
        classes[i] = label_to_int[label]
    
    return torch.tensor(classes)

In [4]:
#Load and preprocess data
path = "data/dataset_sponge.csv"
data = DataTransform(DataPreprocess(path))

#Convert data to tensor (structure: tensor([otus], [batch]))
otu_data = data.select_dtypes(include = "number")
otu_tensor = torch.tensor(otu_data.values, dtype = torch.float32)

#Extract labels and convert to one hot encoding matrix
data_batch = data["batch"]
data_tissue = data["tissue"]
ohe_batch, labels_batch = one_hot_encoding(data_batch)
ohe_tissue, labels_tissue = one_hot_encoding(data_tissue)

# dataloader = DataLoader(TensorDataset(otu_tensor), batch_size = 32) #this should be the correct way to define it
batch_size = 32

otu_dataloader = DataLoader(otu_tensor, batch_size = batch_size)
batch_dataloader = DataLoader(ohe_batch, batch_size = batch_size)
tissue_dataloader = DataLoader(ohe_tissue, batch_size = batch_size)

#Defining DataLoader for otus + batch information
otu_batch_tensor = torch.concat((otu_tensor, ohe_batch), 1)
otu_batch_dataloader = DataLoader(otu_batch_tensor, batch_size = batch_size)

#Defining DataLoader for otus + tissue information, also including batch as label for discriminator training
otu_tissue_tensor = torch.concat((otu_tensor, ohe_tissue), 1)
otu_tissue_dataloader = DataLoader(TensorDataset(otu_tissue_tensor, class_to_int(data_batch)), batch_size = batch_size)

#Defining DataLoader for otus including tissue as label for classifier training
otu_tissue_class_dataloader = DataLoader(TensorDataset(otu_tensor, class_to_int(data_tissue)), batch_size = batch_size)

#Defining DataLoader for otus including + batch information, also including tissue as label for discriminator training
abaco_dataloader = DataLoader(TensorDataset(otu_batch_tensor, class_to_int(data_tissue)), batch_size = batch_size)

### Train model

In [13]:
#Setting up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs=1000
print(f"Using {device}")

#Defining autoencoder and optimizer
d_z = 10

batch_discriminator = BatchDiscriminator().to(device)
latent_classifier = TissueClassifier(input_size=d_z)
output_classifier = TissueClassifier(input_size=24)

#Training loop
model.train_model(batch_model=batch_discriminator,
            latent_class_model=latent_classifier,
            out_class_model=output_classifier,
            train_loader=abaco_dataloader,
            tissue_ohe=ohe_tissue,
            num_epochs=num_epochs,
            w_recon=0.5,
            w_adver=1.0,
            w_disc=1.0,
            w_latent=1.0,
            w_output=1.0,
            device=device,
            model_name="ABC"
            )

Using cpu
Epoch 1/1000 | Dis. Train Loss: 0.7295 | Adv. Train Loss: 0.7295 | Tri. Train Loss: 0.7295
Epoch 2/1000 | Dis. Train Loss: 0.6810 | Adv. Train Loss: 0.6810 | Tri. Train Loss: 0.6810
Epoch 3/1000 | Dis. Train Loss: 0.6618 | Adv. Train Loss: 0.6618 | Tri. Train Loss: 0.6618
Epoch 4/1000 | Dis. Train Loss: 0.6410 | Adv. Train Loss: 0.6410 | Tri. Train Loss: 0.6410
Epoch 5/1000 | Dis. Train Loss: 0.6299 | Adv. Train Loss: 0.6299 | Tri. Train Loss: 0.6299
Epoch 6/1000 | Dis. Train Loss: 0.6217 | Adv. Train Loss: 0.6217 | Tri. Train Loss: 0.6217
Epoch 7/1000 | Dis. Train Loss: 0.6081 | Adv. Train Loss: 0.6081 | Tri. Train Loss: 0.6081
Epoch 8/1000 | Dis. Train Loss: 0.5940 | Adv. Train Loss: 0.5940 | Tri. Train Loss: 0.5940
Epoch 9/1000 | Dis. Train Loss: 0.5847 | Adv. Train Loss: 0.5847 | Tri. Train Loss: 0.5847
Epoch 10/1000 | Dis. Train Loss: 0.5761 | Adv. Train Loss: 0.5761 | Tri. Train Loss: 0.5761
Epoch 11/1000 | Dis. Train Loss: 0.5682 | Adv. Train Loss: 0.5682 | Tri. Train 


reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.



Epoch 26/1000 | Dis. Train Loss: 0.6040 | Adv. Train Loss: 0.6040 | Tri. Train Loss: 0.6040
Epoch 27/1000 | Dis. Train Loss: 0.6023 | Adv. Train Loss: 0.6023 | Tri. Train Loss: 0.6023
Epoch 28/1000 | Dis. Train Loss: 0.6030 | Adv. Train Loss: 0.6030 | Tri. Train Loss: 0.6030
Epoch 29/1000 | Dis. Train Loss: 0.6048 | Adv. Train Loss: 0.6048 | Tri. Train Loss: 0.6048
Epoch 30/1000 | Dis. Train Loss: 0.6054 | Adv. Train Loss: 0.6054 | Tri. Train Loss: 0.6054
Epoch 31/1000 | Dis. Train Loss: 0.6193 | Adv. Train Loss: 0.6193 | Tri. Train Loss: 0.6193
Epoch 32/1000 | Dis. Train Loss: 0.6350 | Adv. Train Loss: 0.6350 | Tri. Train Loss: 0.6350
Epoch 33/1000 | Dis. Train Loss: 0.6523 | Adv. Train Loss: 0.6523 | Tri. Train Loss: 0.6523
Epoch 34/1000 | Dis. Train Loss: 0.6671 | Adv. Train Loss: 0.6671 | Tri. Train Loss: 0.6671
Epoch 35/1000 | Dis. Train Loss: 0.6758 | Adv. Train Loss: 0.6758 | Tri. Train Loss: 0.6758
Epoch 36/1000 | Dis. Train Loss: 0.6778 | Adv. Train Loss: 0.6778 | Tri. Train L

([0.7294992208480835,
  0.6809887886047363,
  0.6617615222930908,
  0.641046404838562,
  0.6298626065254211,
  0.6217373013496399,
  0.6080752611160278,
  0.5940495729446411,
  0.5846865177154541,
  0.5761377215385437,
  0.5681780576705933,
  0.5615565776824951,
  0.5550568699836731,
  0.5479275584220886,
  0.5424478650093079,
  0.541265606880188,
  0.5460706353187561,
  0.5599347949028015,
  0.5758172273635864,
  0.5875226259231567,
  0.5960022211074829,
  0.602798581123352,
  0.6058396100997925,
  0.6054744124412537,
  0.6054798364639282,
  0.6039519309997559,
  0.6022631525993347,
  0.6029950976371765,
  0.6047565937042236,
  0.6054046154022217,
  0.6193276047706604,
  0.635047435760498,
  0.6522709727287292,
  0.6671214699745178,
  0.6758490800857544,
  0.677844762802124,
  0.670743465423584,
  0.6545021533966064,
  0.6244097948074341,
  0.5949820876121521,
  0.5723056793212891,
  0.5565605759620667,
  0.5494374632835388,
  0.5527268052101135,
  0.5618952512741089,
  0.581953287124

### Create DataFrame with batch-corrected output

In [14]:
batch_corrected = []
for x, _ in abaco_dataloader:

    batch_corrected.append(model(x).tolist())

batch_corrected = np.array(batch_corrected)  # Convert list to NumPy array
batch_corrected = batch_corrected.reshape(-1, batch_corrected.shape[-1])

corrected_pd = pd.concat([pd.DataFrame(batch_corrected, index = otu_data.index, columns = otu_data.columns),
                          data_batch,
                          data_tissue,
                          data["sample"]],
                          axis=1)

In [15]:
plotPCA(data)

In [16]:
plotPCA(corrected_pd)