In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
import os

# Device config: run on CUDA when PyTorch can see a GPU, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [8]:
# Parameters
latent_dim = 100  # length of each synthetic "neural" input vector fed to the generator
image_size = 28 * 28  # flattened pixel count of one 28x28 image; NOTE(review): appears unused in the cells shown
batch_size = 64
num_epochs = 200  # number of training iterations; each one draws a single fresh batch
random_seed = 42  # fixed RNG seed so Restart & Run All reproduces the same run
save_dir = './generated_images'  # output directory for generated-sample grids
real_images = './real_images'  # output directory (a path, not a tensor) for "real"-sample grids

# Seed PyTorch's RNGs (CPU and all CUDA devices) before any weights or samples are drawn.
torch.manual_seed(random_seed)

os.makedirs(save_dir, exist_ok=True)
os.makedirs(real_images, exist_ok=True)


In [9]:
# Create synthetic neural data
def generate_synthetic_neural_data(batch_size, dim=latent_dim):
    """Sample a batch of standard-normal vectors to use as generator input.

    Args:
        batch_size: number of vectors (rows) to draw.
        dim: length of each vector. Defaults to the value of ``latent_dim`` at
            the time this cell ran, because Python binds defaults at definition time.

    Returns:
        A float tensor of shape ``(batch_size, dim)`` located on ``device``.
    """
    # Allocate directly on the target device instead of sampling on the CPU
    # and then copying the result over with .to(device) on every call.
    return torch.randn(batch_size, dim, device=device)

In [10]:
# Generator
class Generator(nn.Module):
    """MLP generator that maps a latent vector to an image-shaped tensor.

    Layers, in order: Linear(input_dim->256) -> LeakyReLU(0.2), then two
    Linear -> BatchNorm1d -> LeakyReLU(0.2) stages (256->512, 512->1024), and
    finally Linear(1024->prod(img_shape)) -> Tanh, so every output lies in [-1, 1].

    Args:
        input_dim: length of each latent input vector.
        img_shape: per-sample output shape, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, input_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape
        n_pixels = int(torch.tensor(img_shape).prod())

        def hidden_stage(n_in, n_hidden, batch_norm):
            # One hidden stage: Linear, optional BatchNorm1d, then LeakyReLU.
            parts = [nn.Linear(n_in, n_hidden)]
            if batch_norm:
                parts.append(nn.BatchNorm1d(n_hidden))
            parts.append(nn.LeakyReLU(0.2, inplace=True))
            return parts

        # Stages are built left to right, so layer creation order (and therefore
        # weight-initialisation RNG usage and Sequential indices) is unchanged.
        self.model = nn.Sequential(
            *hidden_stage(input_dim, 256, batch_norm=False),
            *hidden_stage(256, 512, batch_norm=True),
            *hidden_stage(512, 1024, batch_norm=True),
            nn.Linear(1024, n_pixels),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map latents of shape (batch, input_dim) to a (batch, *img_shape) tensor."""
        flat_imgs = self.model(z)
        return flat_imgs.view(z.shape[0], *self.img_shape)
    
# Discriminator
class Discriminator(nn.Module):
    """MLP discriminator scoring images as real (target 1) or generated (target 0).

    Layers, in order: flatten -> Linear(prod(img_shape)->512) -> LeakyReLU(0.2)
    -> Linear(512->256) -> LeakyReLU(0.2) -> Linear(256->1) -> Sigmoid.

    Args:
        img_shape: per-sample input shape, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),
            # Without a nonlinearity here, the first two Linear layers would
            # compose into a single linear map, making the 512-unit layer useless.
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return a (batch, 1) tensor of probabilities in [0, 1] for a batch of images."""
        # flatten() keeps the batch dim and, unlike view(), also accepts non-contiguous input.
        flat = img.flatten(start_dim=1)
        return self.model(flat)

In [None]:
# Build both networks for single-channel 28x28 images and move them to `device`.
# The generator is constructed first, matching the original weight-init RNG order.
img_shape = (1, 28, 28)
generator = Generator(latent_dim, img_shape).to(device)
discriminator = Discriminator(img_shape).to(device)

# Binary cross-entropy on the discriminator's sigmoid outputs; both networks use
# Adam with the same settings (lr=2e-4, beta1=0.5, beta2=0.999).
adversarial_loss = nn.BCELoss(reduction='mean')
optimizer_G, optimizer_D = [
    optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for net in (generator, discriminator)
]

# Training loop.
# Each "epoch" here is a single optimization step on one freshly sampled batch;
# there is no dataset to iterate over.

# Target labels never change between steps, so build them once, directly on `device`.
valid = torch.ones(batch_size, 1, device=device)
fake = torch.zeros(batch_size, 1, device=device)

for epoch in range(num_epochs):
    # "Real" samples are standard-normal noise shaped like images.
    # NOTE(review): the generator ends in Tanh and can only emit values in [-1, 1],
    # while N(0, 1) samples often fall outside that range, so the discriminator can
    # always separate the two. Presumably a placeholder for real image data — confirm.
    real_imgs = torch.randn(batch_size, *img_shape, device=device)
    z = generate_synthetic_neural_data(batch_size)

    # --- Train Generator: push D(G(z)) towards the "valid" label ---
    optimizer_G.zero_grad()
    gen_imgs = generator(z)
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)
    g_loss.backward()
    optimizer_G.step()

    # --- Train Discriminator: real -> valid, generated -> fake ---
    # zero_grad() also discards the gradients g_loss.backward() left on D's parameters.
    optimizer_D.zero_grad()
    real_loss = adversarial_loss(discriminator(real_imgs), valid)
    # detach() so the discriminator loss does not backpropagate into the generator.
    fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    optimizer_D.step()

    # Log and save 5x5 sample grids every 10 steps, and always on the final step.
    if epoch % 10 == 0 or epoch == num_epochs - 1:
        print(f"[Epoch {epoch}/{num_epochs}] D loss: {d_loss.item():.4f} | G loss: {g_loss.item():.4f}")
        save_image(gen_imgs.detach()[:25], f"{save_dir}/epoch_{epoch}.png", nrow=5, normalize=True)
        save_image(real_imgs[:25], f"{real_images}/epoch_{epoch}_real.png", nrow=5, normalize=True)
