### Libraries

In [2]:
#Essentials
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, Subset, ConcatDataset, TensorDataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from umap import UMAP

#User libraries
from BatchEffectDataLoader import DataPreprocess, DataTransform, ABaCoDataLoader
from BatchEffectCorrection import correctCombat
from BatchEffectPlots import plotPCA
from BatchEffectMetrics import kBET, iLISI, cLISI, ARI, ASW
from ABaCo import ABaCo, BatchDiscriminator, TissueClassifier

>> clustergrammer2 backend version 0.18.0


### Data creation (one-off merge of abundance table and metadata; kept commented out)

In [8]:
# NOTE: this whole cell is commented out, so nothing here runs. It documents how
# the merged "metadataset_w_study_*" CSV was built: read an abundance table, read
# the MGnify analysis metadata, right-join them on the sample ID, and save.
# NOTE(review): `file` below points at the phylum-level table, while the filtering
# cell reads the genus-level merged file — presumably this was re-run with the
# genus table; confirm before relying on it.
#data
# file = "dataset_all_biomes_merged_abund_tables_phylum.csv"
# data = pd.read_csv(f"data/MGnify/preprocessed_data/{file}")

# #metadata (keep only the columns needed downstream; rename the run ID to "sample")
# meta_data = pd.read_csv("data/MGnify/raw_data/Mgnify_analyses_wwt_shot_metag_assembly.csv")
# meta_data = meta_data[["assembly_run_id", "experiment_type", "instrument_platform", "biomes", "study_id", "centre_name"]]
# meta_data.rename(columns={"assembly_run_id":"sample"}, inplace=True)

# #merge data based on sample ID (how="right" keeps every row of the abundance table)
# data_merged = pd.merge(meta_data, data, on="sample", how="right")
# data_merged = data_merged.drop_duplicates()
# reset_index() without drop=True adds an "index" column to the saved CSV;
# the filtering cell lists "index" in drop_cols.
# data_merged = data_merged.reset_index()

# #save file
# data_merged.to_csv(f"data/MGnify/datasets/metadataset_w_study_{file}", index=False)

### Filtering data for one centre

In [47]:
# Merged abundance + metadata table (genus level) and where to find it
file = "metadataset_w_study_dataset_all_biomes_merged_abund_tables_genus.csv"
path = f"data/MGnify/datasets/{file}"

# Column holding the batch factor, column holding the experimental factor,
# and the metadata columns that are removed before transformation
batch_label = "instrument_platform"
exp_label = "biomes"
drop_cols = ["experiment_type", "study_id", "centre_name", "index"]

# Load the table, discard rows containing missing values (samples without
# meta info), and renumber the rows
raw_data = (
    DataPreprocess(path, factors=["sample", batch_label, exp_label])
    .dropna()
    .reset_index(drop=True)
)

# Keep only the samples submitted by the DTU-GE centre, then remove the
# metadata columns listed in drop_cols
pre_data = (
    raw_data.loc[raw_data["centre_name"] == "DTU-GE"]
    .reset_index(drop=True)
    .drop(columns=drop_cols)
)

# Transform the abundance columns; the factor columns are passed so the helper
# can tell them apart from the features
data = DataTransform(
    pre_data,
    factors=["sample", batch_label, exp_label],
    count=True,
)

# Strip the common "root:Engineered:" prefix from the biome labels
data[exp_label] = data[exp_label].str.replace("root:Engineered:", "", regex=False)

# PCA of the transformed data, before any batch correction
plotPCA(
    data,
    sample_label="sample",
    batch_label=batch_label,
    experiment_label=exp_label,
)

In [48]:
# Show the transformed frame (sample, batch and biome columns followed by the OTU features)
data

Unnamed: 0,sample,instrument_platform,biomes,OTU1,OTU2,OTU3,OTU4,OTU5,OTU6,OTU7,...,OTU3317,OTU3318,OTU3319,OTU3320,OTU3321,OTU3322,OTU3323,OTU3324,OTU3325,OTU3326
0,ERR2683114,Illumina NovaSeq 6000,Wastewater,3.952985,4.623490,2.146188,-0.251708,-0.251708,-0.251708,-0.251708,...,-0.251708,-0.251708,-0.251708,-0.251708,-0.251708,-0.251708,-0.251708,-0.251708,-0.251708,-0.251708
1,ERR2683115,Illumina NovaSeq 6000,Wastewater,1.608462,2.707074,-0.183297,-0.183297,-0.183297,-0.183297,-0.183297,...,-0.183297,-0.183297,-0.183297,-0.183297,-0.183297,-0.183297,-0.183297,-0.183297,-0.183297,-0.183297
2,ERR2683118,Illumina NovaSeq 6000,Wastewater,3.470410,6.700776,3.424948,-0.336252,-0.336252,-0.336252,-0.336252,...,-0.336252,-0.336252,-0.336252,-0.336252,-0.336252,-0.336252,-0.336252,-0.336252,-0.336252,-0.336252
3,ERR2683121,Illumina NovaSeq 6000,Wastewater,6.033057,3.072553,2.117042,-0.185543,-0.185543,-0.185543,-0.185543,...,-0.185543,-0.185543,-0.185543,-0.185543,-0.185543,-0.185543,-0.185543,-0.185543,-0.185543,-0.185543
4,ERR2683124,Illumina NovaSeq 6000,Wastewater,4.808130,3.003125,0.882862,-0.215751,0.882862,-0.215751,-0.215751,...,-0.215751,-0.215751,-0.215751,-0.215751,-0.215751,-0.215751,-0.215751,-0.215751,-0.215751,-0.215751
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
307,ERR1512992,Illumina MiSeq,Wastewater:Water and sludge,-0.078683,1.019930,-0.078683,-0.078683,-0.078683,-0.078683,-0.078683,...,-0.078683,-0.078683,-0.078683,-0.078683,-0.078683,-0.078683,-0.078683,-0.078683,-0.078683,-0.078683
308,ERR1512999,Illumina MiSeq,Wastewater:Water and sludge,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
309,ERR1513000,Illumina MiSeq,Wastewater:Water and sludge,-0.000208,-0.000208,-0.000208,-0.000208,-0.000208,-0.000208,-0.000208,...,-0.000208,-0.000208,-0.000208,-0.000208,-0.000208,-0.000208,-0.000208,-0.000208,-0.000208,-0.000208
310,ERR1513001,Illumina MiSeq,Wastewater:Water and sludge,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


### Prepare Data for ABaCo

In [52]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device}")

# Build the loader and companion objects consumed by ABaCo.
# total_size=3326 equals the number of OTU columns shown for `data` above.
# NOTE(review): batch_size=312 looks chosen so the whole data set fits in one
# mini-batch, and total_batch=3 presumably is the number of distinct
# instrument platforms — confirm both against ABaCoDataLoader.
(
    otu_dataloader,
    ohe_batch,
    ohe_biome,
    otu_data,
    otu_batch,
    otu_biome,
) = ABaCoDataLoader(
    data,
    device=device,
    batch_label=batch_label,
    exp_label=exp_label,
    batch_size=312,
    total_size=3326,
    total_batch=3,
)

Using cuda


### Training ABaCo

In [85]:
# Number of passes over the training loader
num_epochs=5000

# Size of the latent space; also the input size of the latent classifier
d_z = 128

# Autoencoder plus the three auxiliary networks, all moved to `device`.
# input_size=3326 and batch_size=3 repeat the total_size / total_batch values
# given to ABaCoDataLoader.
# NOTE(review): tissue_size=2 is presumably the number of biome classes — confirm.
model = ABaCo(
    d_z=d_z,
    input_size=3326,
    batch_size=3,
).to(device)
batch_discriminator = BatchDiscriminator(
    input_size=3326,
    batch_size=3,
    tissue_size=2,
).to(device)
latent_classifier = TissueClassifier(
    input_size=d_z,
    tissue_size=2,
).to(device)
output_classifier = TissueClassifier(
    input_size=3326,
    tissue_size=2,
).to(device)

# Train; the three returned values are not used in this notebook.
# Each loss term has a weight (w_*) and its own learning rate (lr_*).
# NOTE(review): in the printed log the "Recon." and "Out." train losses are
# identical in every epoch — worth checking the logging inside train_model.
_, _, _ = model.train_model(
    batch_model=batch_discriminator,
    latent_class_model=latent_classifier,
    out_class_model=output_classifier,
    train_loader=otu_dataloader,
    ohe_exp_loader=ohe_biome,
    num_epochs=num_epochs,
    w_recon=1.0, lr_recon=1e-3,
    w_adver=10.0, lr_adver=1e-3,
    w_disc=10.0, lr_disc=1e-3,
    w_latent=1.0, lr_latent=1e-3,
    w_output=1.0, lr_output=1e-3,
    device=device,
    model_name="ABC",
)


reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.



Epoch 1/5000 | Dis. Train Loss: 11.1820 | Adv. Train Loss: 0.0001 | Recon. Train Loss: 0.6747 | Lat. Train Loss: 0.6947 | Out. Train Loss: 0.6747
Epoch 2/5000 | Dis. Train Loss: 11.0555 | Adv. Train Loss: 0.0067 | Recon. Train Loss: 0.7110 | Lat. Train Loss: 0.6322 | Out. Train Loss: 0.7110
Epoch 3/5000 | Dis. Train Loss: 11.2426 | Adv. Train Loss: 0.0048 | Recon. Train Loss: 0.6833 | Lat. Train Loss: 0.5572 | Out. Train Loss: 0.6833
Epoch 4/5000 | Dis. Train Loss: 10.6458 | Adv. Train Loss: 0.0490 | Recon. Train Loss: 0.6601 | Lat. Train Loss: 0.4731 | Out. Train Loss: 0.6601
Epoch 5/5000 | Dis. Train Loss: 10.3692 | Adv. Train Loss: 0.0851 | Recon. Train Loss: 0.6475 | Lat. Train Loss: 0.4070 | Out. Train Loss: 0.6475
Epoch 6/5000 | Dis. Train Loss: 9.7045 | Adv. Train Loss: 0.2319 | Recon. Train Loss: 0.6341 | Lat. Train Loss: 0.3361 | Out. Train Loss: 0.6341
Epoch 7/5000 | Dis. Train Loss: 8.7669 | Adv. Train Loss: 0.5519 | Recon. Train Loss: 0.6251 | Lat. Train Loss: 0.2661 | Out.

### Batch corrected data

In [86]:
# Rebuild the loader with the same settings used for training so the trained
# model can be applied to every sample.
# NOTE(review): the rows of the reconstruction are labelled with otu_data.index
# below, which assumes ABaCoDataLoader yields samples in the same order as
# otu_data (i.e. no shuffling) — confirm.
one_batch_data, _, _, _, _, _ = ABaCoDataLoader(data,
                                                device = device,
                                                batch_label=batch_label,
                                                exp_label=exp_label,
                                                batch_size = 312,
                                                total_size = 3326,
                                                total_batch=3)

# Inference only: switch layers such as dropout / batch-norm (if the model has
# any) to evaluation behaviour and skip building the autograd graph.
model.eval()
with torch.no_grad():
    corrected_chunks = [model(x).detach().cpu().numpy() for x, _, _ in one_batch_data]

# Stack the per-batch outputs row-wise. Unlike np.array(...).reshape(...), this
# also works when the last mini-batch is smaller than the others. The cast to
# float64 keeps the dtype produced by the previous tolist()-based conversion.
otu_batch_corrected = np.vstack(corrected_chunks).astype(np.float64)

# Reattach the batch label, biome label and sample ID to the corrected features
otu_corrected_pd = pd.concat([pd.DataFrame(otu_batch_corrected, index = otu_data.index, columns = otu_data.columns),
                          otu_batch,
                          otu_biome,
                          data["sample"]],
                          axis=1)

In [87]:
# PCA of the ABaCo reconstruction, using the same labels as the pre-correction plot
plotPCA(
    otu_corrected_pd,
    sample_label="sample",
    batch_label=batch_label,
    experiment_label=exp_label,
)

In [None]:
# NOTE(review): this cell has never been executed and calls correctCombat with no
# arguments — presumably an unfinished ComBat baseline; check the function's
# signature in BatchEffectCorrection and pass the data / labels it needs.
correctCombat()