In [36]:
#---imports---
import tensorflow as tf
import numpy as np
import os
import scanpy as sc
from sklearn.preprocessing import OneHotEncoder


In [37]:
#---Custom Training Loop---
# Define the autoencoder adversarial training function
@tf.function #overriding the normal training function
def train_autoencoder_adversarial(gene_expression, batch_labels, autoencoder, discriminator, optimizer):
    """
    Train the autoencoder to fool the discriminator into classifying reconstructions as target classes.
    
    Args:
        gene_expression: Input gene_expression data
        batch_labels: One-hot encoded target batch labels to fool the discriminator
        encoder: Pretrained encoder model
        decoder: Pretrained decoder model
        discriminator: Pretrained discriminator model (frozen during this training)
        optimizer: Optimizer for the autoencoder
    """
    with tf.GradientTape() as tape:
        # CAll autoencoder to get reconstructed gene expression
        reconstructed_gene_expression = autoencoder(gene_expression)
        
        # Get discriminator output for reconstructed gene expression
        disc_output = discriminator(reconstructed_gene_expression)
        
        # Adversarial loss - make discriminator classify reconstructions as target labels
        adversarial_loss = tf.keras.losses.CategoricalCrossentropy()(batch_labels, disc_output)
    
    # Get gradients and update autoencoder weights only
    gradients = tape.gradient(adversarial_loss, autoencoder.trainable_variables)
    optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))
    
    return adversarial_loss

# Create a custom training loop
def adversarial_training(dataset, epochs, autoencoder, discriminator):
    # Freeze the discriminator weights
    discriminator.trainable = False
    
    # Optimizer for the autoencoder
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
    
    # Training loop
    for epoch in range(epochs):
        adv_loss_avg = tf.keras.metrics.Mean()

        for batch in dataset:
            gene_expression, batch_labels = batch
            
            # Target same class as original
            target_labels = batch_labels
            
            # Train autoencoder
            adv_loss = train_autoencoder_adversarial(
                gene_expression, target_labels, autoencoder, discriminator, optimizer
            )
            
            # Update adv loss
            adv_loss_avg.update_state(adv_loss)

        
        # Print epoch results
        print(f"Epoch {epoch+1}/{epochs}")
        print(
              f"Adversarial Loss: {adv_loss_avg.result():.4f}, " 
            )

In [38]:
#---load data and Neuronal Networks---
#set base path to load data: goes back one directory and then into the data
base_path = os.path.join('..', 'data')
#Name of data set
dataset_name = 'large_atac_gene_activity'
# read dataset into an anndata object:  Category - Cells of the brain
inPath = os.path.join(base_path, f"{dataset_name}.h5ad")
adata = sc.read(inPath)

#set base path to load data: goes back one directory and then into the data
base_path = os.path.join('..', 'src', 'models', 'saved_models')
#Name of autoencoder
dataset_name = 'autoencoder_mselossfunction'
# load autoencoder
inPath = os.path.join(base_path, f"{dataset_name}.keras")
autoencoder = tf.keras.models.load_model(inPath)

#Name of discriminator
dataset_name = 'discriminator_pretrained'
# load discriminator
inPath = os.path.join(base_path, f"{dataset_name}.keras")
discriminator = tf.keras.models.load_model(inPath)

In [39]:
#---Prepare Data---
#ADATA->NUMPY
GENE_EXPRESSION = adata.X.toarray()

#One-hot encoded Batches
encoder = OneHotEncoder(sparse_output=False)  # `sparse=False` returns a dense array
BATCH_LABELS = encoder.fit_transform(adata.obs[['batchname_all']])

#Combine in a Tensorflow dataset
train_dataset = tf.data.Dataset.from_tensor_slices((GENE_EXPRESSION, BATCH_LABELS))
#Create batches
batch_size = 30
train_dataset = train_dataset.batch(batch_size)


In [40]:
#---Training---
# Run the training
epochs = 10
adversarial_training(train_dataset, epochs, autoencoder, discriminator)

Epoch 1/10
Adversarial Loss: 0.8067, 
Epoch 2/10
Adversarial Loss: 0.9264, 
Epoch 3/10
Adversarial Loss: 0.9540, 
Epoch 4/10
Adversarial Loss: 0.8953, 
Epoch 5/10
Adversarial Loss: 0.8297, 
Epoch 6/10
Adversarial Loss: 0.7527, 
Epoch 7/10
Adversarial Loss: 0.6886, 
Epoch 8/10
Adversarial Loss: 0.6872, 
Epoch 9/10
Adversarial Loss: 0.6182, 
Epoch 10/10
Adversarial Loss: 0.5905, 


In [41]:
CORRECTED_GENE_EXPRESSION = autoencoder(GENE_EXPRESSION) #test new autoencoder

In [42]:
 #test the discriminator
print("Accuracy Gene expression Discrimination:")
score = discriminator.evaluate(GENE_EXPRESSION, BATCH_LABELS, verbose=2)
print("Accuracy Correccted gene expression Discrimination:")
score = discriminator.evaluate(CORRECTED_GENE_EXPRESSION, BATCH_LABELS, verbose=2)

Accuracy Gene expression Discrimination:
2651/2651 - 3s - 965us/step - accuracy: 0.3211 - loss: 1.7133
Accuracy Correccted gene expression Discrimination:
2651/2651 - 2s - 822us/step - accuracy: 0.0961 - loss: 6.1056
