<a href="https://colab.research.google.com/github/efitzgerald763/snRNAseq_ssGSEA_DE/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the generator
class Generator(nn.Module):
    """Synthesize samples from control inputs and a latent code.

    A control sample ``x`` is encoded to ``latent_dim`` features, modulated
    element-wise by the latent code ``z``, and decoded to ``output_dim``
    features.

    Args:
        input_dim: Number of features in a control sample ``x``.
        latent_dim: Size of the latent code ``z``; also the width of the
            encoder output, so that ``encoder(x) * z`` is well-defined.
        output_dim: Number of features in a synthesized sample.
    """

    def __init__(self, input_dim, latent_dim, output_dim):
        super().__init__()
        # The encoder must end at latent_dim: its output is multiplied
        # element-wise by z (latent_dim wide) and then fed to the decoder,
        # whose first layer expects latent_dim inputs. A fixed width such as
        # 128 would only work when latent_dim happens to equal it.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x, z):
        """Return synthesized samples for control inputs ``x`` and codes ``z``.

        Args:
            x: Control samples; last dimension must be ``input_dim``.
            z: Latent codes; must broadcast against the encoder output
                ``(..., latent_dim)``, e.g. shape ``(batch, latent_dim)``.

        Returns:
            Tensor whose last dimension is ``output_dim``.
        """
        x_encoded = self.encoder(x)
        combined = x_encoded * z
        return self.decoder(combined)

# Define the discriminator
class Discriminator(nn.Module):
    """Score samples with a probability of being real.

    An MLP ``input_dim -> 128 -> 64 -> 1`` with LeakyReLU(0.2) hidden
    activations and a final Sigmoid, so scores lie in [0, 1] and can be
    passed straight to ``nn.BCELoss``.

    Args:
        input_dim: Number of features per sample.
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = [input_dim, 128, 64]
        layers = []
        # Hidden blocks: Linear followed by LeakyReLU for each consecutive
        # pair of widths.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        # Output head: a single probability per sample.
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return real-vs-fake scores for ``x`` (last dimension of size 1)."""
        return self.model(x)

# Define the inverse mapping function (mapping disease data to latent space)
class InverseFunction(nn.Module):
    """Map samples back to latent codes.

    A two-layer MLP ``input_dim -> 128 -> latent_dim`` with a ReLU in
    between. The output is unbounded (no final activation).

    Args:
        input_dim: Number of features per sample.
        latent_dim: Size of the predicted latent code.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        hidden_width = 128
        layers = [
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, latent_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the predicted latent code for ``x`` (last dim ``latent_dim``)."""
        z_hat = self.model(x)
        return z_hat

# Adversarial loss: the discriminator ends in a Sigmoid, so BCE on
# probabilities is the matching criterion.
adversarial_loss = nn.BCELoss()
# Criterion for recovering z from synthesized samples; built once rather than
# once per iteration.
reconstruction_criterion = nn.MSELoss()

# Instantiate the models
input_dim = 15000  # Adjust based on gene expression data
latent_dim = 10    # Number of patterns
output_dim = input_dim
batch_size = 32
num_epochs = 10000
log_every = 1000

# Create instances of the models
generator = Generator(input_dim, latent_dim, output_dim)
discriminator = Discriminator(output_dim)
inverse_function = InverseFunction(output_dim, latent_dim)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)
optimizer_I = optim.Adam(inverse_function.parameters(), lr=0.0002)

# BCE target labels are the same every iteration, so allocate them once.
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

# Training loop (simplified version)
for epoch in range(num_epochs):
    # Sample control data x (generator input), real disease data y (the
    # discriminator's "real" class) and latent variable z.
    # NOTE(review): x_real and y_real are random placeholders; replace them
    # with minibatches drawn from the actual control / disease cohorts.
    x_real = torch.randn(batch_size, input_dim)
    y_real = torch.randn(batch_size, output_dim)
    z = torch.rand(batch_size, latent_dim)

    # Generate synthesized disease data
    y_synth = generator(x_real, z)

    # Discriminator training: real disease data vs. synthesized disease data.
    # Using the control batch as the "real" class instead would push the
    # generator toward reproducing controls rather than disease samples.
    optimizer_D.zero_grad()
    real_loss = adversarial_loss(discriminator(y_real), real_labels)
    fake_loss = adversarial_loss(discriminator(y_synth.detach()), fake_labels)
    d_loss = real_loss + fake_loss
    d_loss.backward()
    optimizer_D.step()

    # Generator and inverse function training. This backward pass also
    # writes gradients into the discriminator's parameters; they are cleared
    # by optimizer_D.zero_grad() before the next discriminator update.
    optimizer_G.zero_grad()
    optimizer_I.zero_grad()
    g_loss = adversarial_loss(discriminator(y_synth), real_labels)
    reconstruction_loss = reconstruction_criterion(inverse_function(y_synth), z)
    total_loss = g_loss + reconstruction_loss
    total_loss.backward()
    optimizer_G.step()
    optimizer_I.step()

    if epoch % log_every == 0:
        print(f"Epoch {epoch} | D Loss: {d_loss.item()} | G Loss: {g_loss.item()} | Reconstruction Loss: {reconstruction_loss.item()}")

print("Training complete.")
