# Hyperparameter Search

This notebook searches for the best set of hyperparameters for the model. We demonstrate with a neuroblastoma dataset (`PNB`; adrenal gland tissue), see [original source](https://scpca.alexslemonade.org/projects/SCPCP000004). The data is preprocessed in this [notebook](PNB_preprocessing.ipynb). 

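Configurations are compared on three criteria: (1) how closely the reconstruction matches the input in PCA space, (2) the maximum and minimum values of the input versus the reconstructed data, and (3) data transformation, i.e. reconstructing each cell with the opposite label (single-cell ↔ single-nucleus). 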


In [9]:
# import the dependencies
import os
import sys
from importlib import reload

# Make the local package importable. These paths must be on sys.path *before*
# `single_translator_VAE` is imported; otherwise, on a fresh kernel, the import
# only succeeds if the package happens to be installed. The insert order (and
# therefore the final search precedence) is the same as before.
sys.path.insert(1, "../../")
sys.path.insert(1, "../")
sys.path.insert(1, "../../../../../")

# general imports
import scanpy as sc
import torch
import torch.nn
from sklearn.preprocessing import StandardScaler

# Images, plots, display, and visualization
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

import ray  # NOTE(review): not used in the visible cells

import single_translator_VAE as sv

# Re-executing this cell picks up edits to the package's top-level module
# (importlib.reload does not reload its submodules).
reload(sv)

## Setting up AnnData with scVI

In [None]:
res_name: str = "PNB"  # dataset identifier; selects ../../data/<res_name>/ and the sc_sn_<res_name>_*.h5ad files

In [10]:
# Load the training split written by the preprocessing notebook.
# `path` keeps a trailing "/" and is reused by the test-split cell.
path = "/".join([os.getcwd(), "..", "..", "data", res_name, ""])
adata_path = os.path.join(path, "sc_sn_{}_train.h5ad".format(res_name))
adata_train = sc.read_h5ad(adata_path)

  utils.warn_names_duplicates("obs")


In [11]:
# Load the held-out test split from the same data directory.
adata_path = os.path.join(path, "sc_sn_%s_test.h5ad" % res_name)
adata_test = sc.read_h5ad(adata_path)

  utils.warn_names_duplicates("obs")


In [12]:
# Unique modalities, batches, and patients in each split (train first, then test),
# one value per line with a tab-only line separating the two splits.
print(
    adata_train.obs["data_type"].unique(),
    adata_train.obs["batch"].unique(),
    adata_train.obs["patient_id"].unique(),
    "\t",
    adata_test.obs["data_type"].unique(),
    adata_test.obs["batch"].unique(),
    adata_test.obs["patient_id"].unique(),
    sep="\n",
)

['single_nucleus']
Categories (1, object): ['single_nucleus']
['SCPCS000702', 'SCPCS000101', 'SCPCS000112', 'SCPCS000122', 'SCPCS000113', ..., 'SCPCS000690', 'SCPCS000696', 'SCPCS000688', 'SCPCS000689', 'SCPCS000116']
Length: 12
Categories (12, object): ['SCPCS000101', 'SCPCS000112', 'SCPCS000113', 'SCPCS000114', ..., 'SCPCS000690', 'SCPCS000696', 'SCPCS000699', 'SCPCS000702']
['SJNBL013763', 'SJNBL046148', 'SJNBL012407', 'SJNBL030339', 'SJNBL015724', 'SJNBL063820']
Categories (6, object): ['SJNBL012407', 'SJNBL013763', 'SJNBL015724', 'SJNBL030339', 'SJNBL046148', 'SJNBL063820']
	
['single_nucleus', 'single_cell']
Categories (2, object): ['single_cell', 'single_nucleus']
['SCPCS000697', 'SCPCS000687', 'SCPCS000110', 'SCPCS000111', 'SCPCS000108', 'SCPCS000103', 'SCPCS000109']
Categories (7, object): ['SCPCS000103', 'SCPCS000108', 'SCPCS000109', 'SCPCS000110', 'SCPCS000111', 'SCPCS000687', 'SCPCS000697']
['SJNBL046', 'SJNBL108', 'SJNBL031246', 'SJNBL066155', 'SJNBL031802']
Categories (5,

## Processing

In [13]:
# Apply the same use_obs call to both splits. Judging by the displayed obs table,
# this fills a "labels_key" column that mirrors "data_type".
adata_train, adata_test = (
    sv.vae.VAEModel.use_obs(adata=split, adata_obs=["data_type"], labels_key="labels_key")
    for split in (adata_train, adata_test)
)
adata_train.obs

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


Unnamed: 0,barcodes,sum,detected,subsets_mito_sum,subsets_mito_detected,subsets_mito_percent,total,prob_compromised,scpca_filter,sizeFactor,...,tissue_ontology_term_id,assay_ontology_term_id,suspension_type,is_primary_data,batch,patient_id,data_type,cell_types,TrainTest,labels_key
AGATCGTAGACGGTCA,AGATCGTAGACGGTCA,5022.953385,3088,14.500000,6,0.288675,5022.953385,0.018830,Keep,1.184083,...,UBERON:0002369,EFO:0009922,nucleus,False,SCPCS000114,SJNBL015724,single_nucleus,neuron,Train,single_nucleus
CTGTGGGCATGACGGA,CTGTGGGCATGACGGA,8764.000007,4539,40.000000,12,0.456413,8764.000007,0.014789,Keep,2.124256,...,UBERON:0002369,EFO:0009922,nucleus,False,SCPCS000114,SJNBL015724,single_nucleus,neuron,Train,single_nucleus
ATATCCTGTTATGGTC,ATATCCTGTTATGGTC,4837.999997,2688,51.214286,10,1.058584,4837.999997,0.031049,Keep,1.181775,...,UBERON:0002369,EFO:0009922,nucleus,False,SCPCS000112,SJNBL012407,single_nucleus,neuron,Train,single_nucleus
TCATCATAGTCGCGAA,TCATCATAGTCGCGAA,550.000001,520,10.000000,5,1.818182,550.000001,0.119369,Keep,0.680737,...,UBERON:0002369,EFO:0009922,nucleus,False,SCPCS000702,SJNBL013763,single_nucleus,neuron,Train,single_nucleus
TCTGCCAGTGCGTGCT,TCTGCCAGTGCGTGCT,7583.000006,3444,67.615920,12,0.891678,7583.000006,0.033250,Keep,0.696001,...,UBERON:0002369,EFO:0009922,nucleus,False,SCPCS000688,SJNBL012407,single_nucleus,neuron,Train,single_nucleus
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GCATGATAGGACCCAA,GCATGATAGGACCCAA,27183.952780,8006,85.136364,13,0.313186,27183.952780,0.064938,Keep,2.746487,...,UBERON:0002369,EFO:0009922,nucleus,False,SCPCS000688,SJNBL012407,single_nucleus,neuron,Train,single_nucleus
AATAGAGCACTGCATA,AATAGAGCACTGCATA,212.000000,214,3.000000,2,1.415094,212.000000,0.008823,Keep,0.582060,...,UBERON:0002369,EFO:0009922,nucleus,False,SCPCS000116,SJNBL063820,single_nucleus,fibroblast,Train,single_nucleus
CTTGAGACACGCGTCA,CTTGAGACACGCGTCA,2892.000001,1801,221.760000,12,7.668050,2892.000001,0.322007,Keep,0.470878,...,UBERON:0002369,EFO:0009922,nucleus,False,SCPCS000113,SJNBL013763,single_nucleus,neuron,Train,single_nucleus
CTGAGGCCACACCTAA,CTGAGGCCACACCTAA,4105.000002,2450,55.833334,11,1.360130,4105.000002,0.051058,Keep,0.740121,...,UBERON:0002369,EFO:0009922,nucleus,False,SCPCS000113,SJNBL013763,single_nucleus,neuron,Train,single_nucleus


In [14]:
# Register both splits with identical fields: batch from obs["batch"], labels
# from obs["labels_key"], and the "counts" layer (presumably raw counts).
# Each call returns (manager, adata).
(adata_manager, adata_train), (adata_manager_test, adata_test) = (
    sv.vae.VAEModel.setup_anndata(
        adata=split, batch_key="batch", labels_key="labels_key", layer="counts"
    )
    for split in (adata_train, adata_test)
)
print(adata_train.uns)
adata_manager.view_registry()

  self.validate_field(adata)


OrderedDict([('_scvi_uuid', '0c33d3d9-5de9-4c6c-82c8-767075982edc'), ('_scvi_manager_uuid', '7ece9ef3-fb5f-457c-9b3b-45d7d7b7ba7e')])


  self.validate_field(adata)


In [7]:
# Labelled summary of modalities, batches, and patients per split.
# NOTE(review): in the saved output the training split contains only
# "single_nucleus" cells, so the model never sees the "single_cell" label
# during training; treat sn->sc transformations with caution.
print(
    "Training data:",
    adata_train.obs["data_type"].unique(),
    adata_train.obs["batch"].unique(),
    adata_train.obs["patient_id"].unique(),
    "\t",
    "Testing data:",
    adata_test.obs["data_type"].unique(),
    adata_test.obs["batch"].unique(),
    adata_test.obs["patient_id"].unique(),
    sep="\n",
)

Training data:
['single_nucleus']
Categories (1, object): ['single_nucleus']
['SCPCS000114', 'SCPCS000113', 'SCPCS000702', 'SCPCS000688', 'SCPCS000112', ..., 'SCPCS000116', 'SCPCS000101', 'SCPCS000122', 'SCPCS000690', 'SCPCS000696']
Length: 12
Categories (12, object): ['SCPCS000101', 'SCPCS000112', 'SCPCS000113', 'SCPCS000114', ..., 'SCPCS000690', 'SCPCS000696', 'SCPCS000699', 'SCPCS000702']
['SJNBL015724', 'SJNBL013763', 'SJNBL012407', 'SJNBL063820', 'SJNBL046148', 'SJNBL030339']
Categories (6, object): ['SJNBL012407', 'SJNBL013763', 'SJNBL015724', 'SJNBL030339', 'SJNBL046148', 'SJNBL063820']
	
Testing data:
['single_cell', 'single_nucleus']
Categories (2, object): ['single_cell', 'single_nucleus']
['SCPCS000109', 'SCPCS000697', 'SCPCS000111', 'SCPCS000110', 'SCPCS000108', 'SCPCS000103', 'SCPCS000687']
Categories (7, object): ['SCPCS000103', 'SCPCS000108', 'SCPCS000109', 'SCPCS000110', 'SCPCS000111', 'SCPCS000687', 'SCPCS000697']
['SJNBL031802', 'SJNBL046', 'SJNBL108', 'SJNBL031246', 

## Initial Training and Evaluating

1. Model configuration 1
This configuration uses a large latent space (`n_latent=700`) with a moderate KL weight (`kl_weight=2.5`).
It does not use the observed library size (`use_observed_lib_size=False`) and aims to balance reconstruction and latent-space regularization.

In [8]:
# Seed torch's RNG so this configuration is reproducible and comparable with the
# others. NOTE(review): assumes VAEModel draws its randomness (initialisation,
# dropout, shuffling) from torch; confirm it does not also rely on numpy's RNG.
torch.manual_seed(42)

model = sv.vae.VAEModel(
    adata=adata_train,
    max_epochs=300,
    n_latent=700,
    lr=0.001,
    batch_size=300,
    n_layers=3,
    n_hidden=768,
    use_batch_norm=False,
    use_observed_lib_size=False,
    dispersion="gene-label",
    dropout_rate=0.4,
    weight_decay=0.002,
    recon_weight=0.5,
    kl_weight=2.5,
    encode_batch=False,
)

# Train, then plot the summed, reconstruction, and KL loss curves.
# NOTE(review): the saved run of this cell took ~145 s/epoch and was interrupted at epoch 10/300.
train_loss, val_loss, complete_val_loss, complete_train_loss, train_losses, val_losses = model.train(max_epochs=300)
for loss_name in ("summed", "reconstruction", "kl"):
    sv.pl.loss_plots(loss_name, complete_train_loss, complete_val_loss, train_losses, val_losses)

input_tensor, labels_tensor, batch_tensor = sv.pp.prepare_data_and_labels(
    adata_test, label_map={"single_cell": 0, "single_nucleus": 1}
)

# Select each modality with the same column everywhere, so the reference
# matrices and the transformed data cover exactly the same cells.
is_sn = adata_test.obs["labels_key"] == "single_nucleus"
is_sc = adata_test.obs["labels_key"] == "single_cell"
original_sn = adata_test[is_sn].X.toarray()
original_sc = adata_test[is_sc].X.toarray()

# PCA of input / latent / reconstruction: default colouring, then by cell type and by batch.
input_data, latent_space, reconstructed_data = sv.pl.results_PCA(
    model, adata_test, input_tensor, labels_tensor, batch_tensor
)
for color_key in ("cell_types", "batch"):
    input_data, latent_space, reconstructed_data = sv.pl.results_PCA(
        model, adata_test, input_tensor, labels_tensor, batch_tensor, labels_key=color_key
    )

# Compare the value ranges of the input and the reconstruction.
print(input_data.max())
print(reconstructed_data.max())
print(input_data.min())
print(reconstructed_data.min())

# Transformations: decode each modality with its own and with the opposite label.
# A fresh view is passed to every call in case transform() modifies its input.
sn_to_sc = sv.tl.transform(adata_test[is_sn], model, target_label="single_cell")
sn_to_sn = sv.tl.transform(adata_test[is_sn], model, target_label="single_nucleus")
sc_to_sn = sv.tl.transform(adata_test[is_sc], model, target_label="single_nucleus")
sc_to_sc = sv.tl.transform(adata_test[is_sc], model, target_label="single_cell")

sv.pl.transformed_pca_data(original_sn, original_sc, sn_to_sn, sc_to_sc, sc_to_sn, sn_to_sc)

[34mINFO    [0m Single Translator VAE: VAE model has been initialized                                                     


Epoch 10/300 - Train Loss: 1365.6237, Val Loss: 1331.3332:   3%|▎         | 10/300 [24:16<11:43:57, 145.65s/it]


KeyboardInterrupt: 

2. Model configuration 2
This configuration uses the observed library size (`use_observed_lib_size=True`) and a slightly higher KL weight (`kl_weight=2.8`).
It uses a 750-dimensional latent space, a slightly lower dropout rate (`dropout_rate=0.3`), and a slightly higher weight decay (`weight_decay=0.0025`).

In [None]:
# Seed torch's RNG so this configuration is reproducible and comparable with the
# others. NOTE(review): assumes VAEModel draws its randomness (initialisation,
# dropout, shuffling) from torch; confirm it does not also rely on numpy's RNG.
torch.manual_seed(42)

model = sv.vae.VAEModel(
    adata=adata_train,
    max_epochs=300,
    n_latent=750,
    lr=0.001,
    batch_size=300,
    n_layers=3,
    n_hidden=768,
    use_batch_norm=False,
    use_observed_lib_size=True,
    dispersion="gene-label",
    dropout_rate=0.3,
    weight_decay=0.0025,
    recon_weight=0.5,
    kl_weight=2.8,
    encode_batch=False,
)

# Train, then plot the summed, reconstruction, and KL loss curves.
train_loss, val_loss, complete_val_loss, complete_train_loss, train_losses, val_losses = model.train(max_epochs=300)
for loss_name in ("summed", "reconstruction", "kl"):
    sv.pl.loss_plots(loss_name, complete_train_loss, complete_val_loss, train_losses, val_losses)

input_tensor, labels_tensor, batch_tensor = sv.pp.prepare_data_and_labels(
    adata_test, label_map={"single_cell": 0, "single_nucleus": 1}
)

# Select each modality with the same column everywhere, so the reference
# matrices and the transformed data cover exactly the same cells.
is_sn = adata_test.obs["labels_key"] == "single_nucleus"
is_sc = adata_test.obs["labels_key"] == "single_cell"
original_sn = adata_test[is_sn].X.toarray()
original_sc = adata_test[is_sc].X.toarray()

# PCA of input / latent / reconstruction: default colouring, then by cell type and by batch.
input_data, latent_space, reconstructed_data = sv.pl.results_PCA(
    model, adata_test, input_tensor, labels_tensor, batch_tensor
)
for color_key in ("cell_types", "batch"):
    input_data, latent_space, reconstructed_data = sv.pl.results_PCA(
        model, adata_test, input_tensor, labels_tensor, batch_tensor, labels_key=color_key
    )

# Compare the value ranges of the input and the reconstruction.
print(input_data.max())
print(reconstructed_data.max())
print(input_data.min())
print(reconstructed_data.min())

# Transformations: decode each modality with its own and with the opposite label.
# A fresh view is passed to every call in case transform() modifies its input.
sn_to_sc = sv.tl.transform(adata_test[is_sn], model, target_label="single_cell")
sn_to_sn = sv.tl.transform(adata_test[is_sn], model, target_label="single_nucleus")
sc_to_sn = sv.tl.transform(adata_test[is_sc], model, target_label="single_nucleus")
sc_to_sc = sv.tl.transform(adata_test[is_sc], model, target_label="single_cell")

sv.pl.transformed_pca_data(original_sn, original_sc, sn_to_sn, sc_to_sc, sc_to_sn, sn_to_sc)

3. Model configuration 3
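A slightly smaller latent space (`n_latent=650`) with stronger regularization through a higher weight decay (`weight_decay=0.003`); the KL weight stays at `kl_weight=2.5` and the dropout rate is `0.3`.
This configuration does not use the observed library size (`use_observed_lib_size=False`).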

In [None]:
# Seed torch's RNG so this configuration is reproducible and comparable with the
# others. NOTE(review): assumes VAEModel draws its randomness (initialisation,
# dropout, shuffling) from torch; confirm it does not also rely on numpy's RNG.
torch.manual_seed(42)

model = sv.vae.VAEModel(
    adata=adata_train,
    max_epochs=300,
    n_latent=650,
    lr=0.001,
    batch_size=300,
    n_layers=3,
    n_hidden=768,
    use_batch_norm=False,
    use_observed_lib_size=False,
    dispersion="gene-label",
    dropout_rate=0.3,
    weight_decay=0.003,
    recon_weight=0.5,
    kl_weight=2.5,
    encode_batch=False,
)

# Train, then plot the summed, reconstruction, and KL loss curves.
train_loss, val_loss, complete_val_loss, complete_train_loss, train_losses, val_losses = model.train(max_epochs=300)
for loss_name in ("summed", "reconstruction", "kl"):
    sv.pl.loss_plots(loss_name, complete_train_loss, complete_val_loss, train_losses, val_losses)

input_tensor, labels_tensor, batch_tensor = sv.pp.prepare_data_and_labels(
    adata_test, label_map={"single_cell": 0, "single_nucleus": 1}
)

# Select each modality with the same column everywhere, so the reference
# matrices and the transformed data cover exactly the same cells.
is_sn = adata_test.obs["labels_key"] == "single_nucleus"
is_sc = adata_test.obs["labels_key"] == "single_cell"
original_sn = adata_test[is_sn].X.toarray()
original_sc = adata_test[is_sc].X.toarray()

# PCA of input / latent / reconstruction: default colouring, then by cell type and by batch.
input_data, latent_space, reconstructed_data = sv.pl.results_PCA(
    model, adata_test, input_tensor, labels_tensor, batch_tensor
)
for color_key in ("cell_types", "batch"):
    input_data, latent_space, reconstructed_data = sv.pl.results_PCA(
        model, adata_test, input_tensor, labels_tensor, batch_tensor, labels_key=color_key
    )

# Compare the value ranges of the input and the reconstruction.
print(input_data.max())
print(reconstructed_data.max())
print(input_data.min())
print(reconstructed_data.min())

# Transformations: decode each modality with its own and with the opposite label.
# A fresh view is passed to every call in case transform() modifies its input.
sn_to_sc = sv.tl.transform(adata_test[is_sn], model, target_label="single_cell")
sn_to_sn = sv.tl.transform(adata_test[is_sn], model, target_label="single_nucleus")
sc_to_sn = sv.tl.transform(adata_test[is_sc], model, target_label="single_nucleus")
sc_to_sc = sv.tl.transform(adata_test[is_sc], model, target_label="single_cell")

sv.pl.transformed_pca_data(original_sn, original_sc, sn_to_sn, sc_to_sc, sc_to_sn, sn_to_sc)

4. Model configuration 4
This configuration keeps the latent space of configuration 1 (`n_latent=700`) but uses a larger hidden layer size (`n_hidden=800`) and a higher KL weight (`kl_weight=3`), aiming to capture more complex patterns.
Batch normalization is not used (`use_batch_norm=False`).

In [None]:
# Seed torch's RNG so this configuration is reproducible and comparable with the
# others. NOTE(review): assumes VAEModel draws its randomness (initialisation,
# dropout, shuffling) from torch; confirm it does not also rely on numpy's RNG.
torch.manual_seed(42)

model = sv.vae.VAEModel(
    adata=adata_train,
    max_epochs=300,
    n_latent=700,
    lr=0.001,
    batch_size=300,
    n_layers=3,
    n_hidden=800,
    use_batch_norm=False,
    use_observed_lib_size=False,
    dispersion="gene-label",
    dropout_rate=0.4,
    weight_decay=0.002,
    recon_weight=0.5,
    kl_weight=3,
    encode_batch=False,
)

# Train, then plot the summed, reconstruction, and KL loss curves.
train_loss, val_loss, complete_val_loss, complete_train_loss, train_losses, val_losses = model.train(max_epochs=300)
for loss_name in ("summed", "reconstruction", "kl"):
    sv.pl.loss_plots(loss_name, complete_train_loss, complete_val_loss, train_losses, val_losses)

input_tensor, labels_tensor, batch_tensor = sv.pp.prepare_data_and_labels(
    adata_test, label_map={"single_cell": 0, "single_nucleus": 1}
)

# Select each modality with the same column everywhere, so the reference
# matrices and the transformed data cover exactly the same cells.
is_sn = adata_test.obs["labels_key"] == "single_nucleus"
is_sc = adata_test.obs["labels_key"] == "single_cell"
original_sn = adata_test[is_sn].X.toarray()
original_sc = adata_test[is_sc].X.toarray()

# PCA of input / latent / reconstruction: default colouring, then by cell type and by batch.
input_data, latent_space, reconstructed_data = sv.pl.results_PCA(
    model, adata_test, input_tensor, labels_tensor, batch_tensor
)
for color_key in ("cell_types", "batch"):
    input_data, latent_space, reconstructed_data = sv.pl.results_PCA(
        model, adata_test, input_tensor, labels_tensor, batch_tensor, labels_key=color_key
    )

# Compare the value ranges of the input and the reconstruction.
print(input_data.max())
print(reconstructed_data.max())
print(input_data.min())
print(reconstructed_data.min())

# Transformations: decode each modality with its own and with the opposite label.
# A fresh view is passed to every call in case transform() modifies its input.
sn_to_sc = sv.tl.transform(adata_test[is_sn], model, target_label="single_cell")
sn_to_sn = sv.tl.transform(adata_test[is_sn], model, target_label="single_nucleus")
sc_to_sn = sv.tl.transform(adata_test[is_sc], model, target_label="single_nucleus")
sc_to_sc = sv.tl.transform(adata_test[is_sc], model, target_label="single_cell")

sv.pl.transformed_pca_data(original_sn, original_sc, sn_to_sn, sc_to_sc, sc_to_sn, sn_to_sc)

5. Model configuration 5
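Similar to configuration 2 (same `n_latent=750`, observed library size used, `kl_weight=2.8`) but with a larger hidden layer size (`n_hidden=850`), per-cell dispersion (`dispersion="gene-cell"`), a higher dropout rate (`dropout_rate=0.4`), and a lower weight decay (`weight_decay=0.002`).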

In [None]:
# Seed torch's RNG so this configuration is reproducible and comparable with the
# others. NOTE(review): assumes VAEModel draws its randomness (initialisation,
# dropout, shuffling) from torch; confirm it does not also rely on numpy's RNG.
torch.manual_seed(42)

model = sv.vae.VAEModel(
    adata=adata_train,
    max_epochs=300,
    n_latent=750,
    lr=0.001,
    batch_size=300,
    n_layers=3,
    n_hidden=850,
    use_batch_norm=False,
    use_observed_lib_size=True,
    dispersion="gene-cell",
    dropout_rate=0.4,
    weight_decay=0.002,
    recon_weight=0.5,
    kl_weight=2.8,
    encode_batch=False,
)

# Train, then plot the summed, reconstruction, and KL loss curves.
train_loss, val_loss, complete_val_loss, complete_train_loss, train_losses, val_losses = model.train(max_epochs=300)
for loss_name in ("summed", "reconstruction", "kl"):
    sv.pl.loss_plots(loss_name, complete_train_loss, complete_val_loss, train_losses, val_losses)

input_tensor, labels_tensor, batch_tensor = sv.pp.prepare_data_and_labels(
    adata_test, label_map={"single_cell": 0, "single_nucleus": 1}
)

# Select each modality with the same column everywhere, so the reference
# matrices and the transformed data cover exactly the same cells.
is_sn = adata_test.obs["labels_key"] == "single_nucleus"
is_sc = adata_test.obs["labels_key"] == "single_cell"
original_sn = adata_test[is_sn].X.toarray()
original_sc = adata_test[is_sc].X.toarray()

# PCA of input / latent / reconstruction: default colouring, then by cell type and by batch.
input_data, latent_space, reconstructed_data = sv.pl.results_PCA(
    model, adata_test, input_tensor, labels_tensor, batch_tensor
)
for color_key in ("cell_types", "batch"):
    input_data, latent_space, reconstructed_data = sv.pl.results_PCA(
        model, adata_test, input_tensor, labels_tensor, batch_tensor, labels_key=color_key
    )

# Compare the value ranges of the input and the reconstruction.
print(input_data.max())
print(reconstructed_data.max())
print(input_data.min())
print(reconstructed_data.min())

# Transformations: decode each modality with its own and with the opposite label.
# A fresh view is passed to every call in case transform() modifies its input.
sn_to_sc = sv.tl.transform(adata_test[is_sn], model, target_label="single_cell")
sn_to_sn = sv.tl.transform(adata_test[is_sn], model, target_label="single_nucleus")
sc_to_sn = sv.tl.transform(adata_test[is_sc], model, target_label="single_nucleus")
sc_to_sc = sv.tl.transform(adata_test[is_sc], model, target_label="single_cell")

sv.pl.transformed_pca_data(original_sn, original_sc, sn_to_sn, sc_to_sc, sc_to_sn, sn_to_sc)

6. Model configuration 6
The most complex model, with the largest `n_latent` (800) and `n_hidden` (850) and a larger batch size (`batch_size=350`), aimed at capturing the most intricate data structures.
It tests the robustness of the model with a high KL weight (`kl_weight=3`) and without the observed library size (`use_observed_lib_size=False`).

In [None]:
# Seed torch's RNG so this configuration is reproducible and comparable with the
# others. NOTE(review): assumes VAEModel draws its randomness (initialisation,
# dropout, shuffling) from torch; confirm it does not also rely on numpy's RNG.
torch.manual_seed(42)

model = sv.vae.VAEModel(
    adata=adata_train,
    max_epochs=300,
    n_latent=800,
    lr=0.001,
    batch_size=350,
    n_layers=3,
    n_hidden=850,
    use_batch_norm=False,
    use_observed_lib_size=False,
    dispersion="gene-label",
    dropout_rate=0.4,
    weight_decay=0.002,
    recon_weight=0.5,
    kl_weight=3,
    encode_batch=False,
)

# Train, then plot the summed, reconstruction, and KL loss curves.
train_loss, val_loss, complete_val_loss, complete_train_loss, train_losses, val_losses = model.train(max_epochs=300)
for loss_name in ("summed", "reconstruction", "kl"):
    sv.pl.loss_plots(loss_name, complete_train_loss, complete_val_loss, train_losses, val_losses)

input_tensor, labels_tensor, batch_tensor = sv.pp.prepare_data_and_labels(
    adata_test, label_map={"single_cell": 0, "single_nucleus": 1}
)

# Select each modality with the same column everywhere, so the reference
# matrices and the transformed data cover exactly the same cells.
is_sn = adata_test.obs["labels_key"] == "single_nucleus"
is_sc = adata_test.obs["labels_key"] == "single_cell"
original_sn = adata_test[is_sn].X.toarray()
original_sc = adata_test[is_sc].X.toarray()

# PCA of input / latent / reconstruction: default colouring, then by cell type and by batch.
input_data, latent_space, reconstructed_data = sv.pl.results_PCA(
    model, adata_test, input_tensor, labels_tensor, batch_tensor
)
for color_key in ("cell_types", "batch"):
    input_data, latent_space, reconstructed_data = sv.pl.results_PCA(
        model, adata_test, input_tensor, labels_tensor, batch_tensor, labels_key=color_key
    )

# Compare the value ranges of the input and the reconstruction.
print(input_data.max())
print(reconstructed_data.max())
print(input_data.min())
print(reconstructed_data.min())

# Transformations: decode each modality with its own and with the opposite label.
# A fresh view is passed to every call in case transform() modifies its input.
sn_to_sc = sv.tl.transform(adata_test[is_sn], model, target_label="single_cell")
sn_to_sn = sv.tl.transform(adata_test[is_sn], model, target_label="single_nucleus")
sc_to_sn = sv.tl.transform(adata_test[is_sc], model, target_label="single_nucleus")
sc_to_sc = sv.tl.transform(adata_test[is_sc], model, target_label="single_cell")

sv.pl.transformed_pca_data(original_sn, original_sc, sn_to_sn, sc_to_sc, sc_to_sn, sn_to_sc)

## Fine Tuning 

In [None]:
def results_TSNE(
    model,
    adata,
    input_tensor,
    labels_tensor,
    batch_tensor: torch.Tensor,
    labels_key: str = "labels_key",
    random_state: int = 42,
):
    """
    Plot t-SNE embeddings of the input data, the latent space, and the reconstruction.

    The input and the reconstruction live in the same feature space, so they
    are embedded together in a single t-SNE run and their panels are directly
    comparable. The latent space has a different dimensionality and is embedded
    on its own. t-SNE is non-parametric (scikit-learn's implementation has no
    ``transform``), so every panel comes from ``fit_transform``.

    Parameters
    ----------
    model
        VAE model to use. Must provide ``eval()`` and a ``module`` with
        ``inference`` and ``generative`` methods.
    adata
        AnnData whose ``X`` (dense or sparse) is plotted as the input data and
        whose ``obs[labels_key]`` colours the points.
    input_tensor
        Tensor to encode and decode for reconstruction.
    labels_tensor
        Tensor for labels to be used for encoding and decoding.
    batch_tensor
        Tensor for batches to be used for encoding and decoding.
    labels_key
        Column name string in adata.obs to get labels from. It is converted to
        ``category`` dtype in place.
    random_state
        Seed passed to t-SNE so the embeddings are reproducible.

    Returns
    -------
    original_data
        Unscaled input data (``adata.X``) as a dense numpy array.
    latent_space
        Unscaled latent space (``z``) as a numpy array.
    reconstructed_data
        Unscaled reconstruction (``px_rate``) as a numpy array.

    Raises
    ------
    AssertionError
        If ``labels_key`` is not a column of ``adata.obs``.
    """
    import numpy as np
    from sklearn.manifold import TSNE

    # Ensure the labels_key column is of type 'category' (mutates adata.obs).
    if labels_key not in adata.obs:
        raise AssertionError("labels_key used not found in adata.obs")
    adata.obs[labels_key] = adata.obs[labels_key].astype("category")

    # Encode and decode in evaluation mode without tracking gradients.
    model.eval()
    with torch.no_grad():
        outputs = model.module.inference(input_tensor, labels_tensor, batch_index=batch_tensor)
        px_rate = model.module.generative(
            outputs["z"], outputs["library"], labels_tensor, batch_index=batch_tensor
        )["px_rate"]

    # Move to CPU before converting so this also works for models on a GPU.
    latent_space_ = outputs["z"].detach().cpu().numpy()
    reconstructed_data_ = px_rate.detach().cpu().numpy()
    # adata.X may be a scipy sparse matrix or an already-dense array.
    original_data_ = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)

    # Scale input and reconstruction with one scaler (fit on the input) so they
    # share a common scale; the latent space gets its own scaler.
    feature_scaler = StandardScaler()
    original_scaled = feature_scaler.fit_transform(original_data_)
    reconstructed_scaled = feature_scaler.transform(reconstructed_data_)
    latent_scaled = StandardScaler().fit_transform(latent_space_)

    def _embed(data):
        # scikit-learn requires perplexity < n_samples; cap the usual default of 30.
        perplexity = min(30.0, data.shape[0] - 1.0)
        return TSNE(
            n_components=2, perplexity=perplexity, random_state=random_state, n_jobs=-1
        ).fit_transform(data)

    # Joint embedding of input + reconstruction, split back into the two halves.
    n_cells = original_scaled.shape[0]
    joint = _embed(np.vstack([original_scaled, reconstructed_scaled]))
    tsne_input, tsne_reconstructed = joint[:n_cells], joint[n_cells:]
    tsne_latent = _embed(latent_scaled)

    # Set up colors based on labels
    color_map = {dtype: plt.cm.tab20(i % 20) for i, dtype in enumerate(adata.obs[labels_key].cat.categories)}
    colors = [color_map[label] for label in adata.obs[labels_key]]

    # Plotting
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    axes[0].scatter(tsne_input[:, 0], tsne_input[:, 1], c=colors, alpha=0.5)
    axes[0].set_title("t-SNE of Input Data", fontweight="bold", fontsize=12)
    axes[1].scatter(tsne_latent[:, 0], tsne_latent[:, 1], c=colors, alpha=0.5)
    axes[1].set_title("t-SNE of Latent Space", fontweight="bold", fontsize=12)
    axes[2].scatter(tsne_reconstructed[:, 0], tsne_reconstructed[:, 1], c=colors, alpha=0.5)
    axes[2].set_title("t-SNE of Reconstructed Data", fontweight="bold", fontsize=12)

    for ax in axes:
        ax.set_xlabel("t-SNE1")
        ax.set_ylabel("t-SNE2")

    # Create a legend
    handles = [Patch(color=color, label=label) for label, color in color_map.items()]
    fig.legend(handles=handles, loc="center left", bbox_to_anchor=(1, 0.5))

    plt.tight_layout()
    plt.show()

    return original_data_, latent_space_, reconstructed_data_

Based on the above results, we see that **TODO:** summarize which configuration performed best (this conclusion was left unfinished).