In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA
from sklearn.utils import shuffle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Load the "Digits" dataset
digits = load_digits()
X = digits.data
y = digits.target

# Preprocess the data
scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)

# Reduce dimensionality using PCA
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)

# Convert the data to PyTorch tensors
X_tensor = torch.tensor(X_pca, dtype=torch.float32)

# Define the Generator network
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, X.shape[1]),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

# Define the Discriminator network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(X.shape[1], 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# Define the GAN
class GAN(nn.Module):
    def __init__(self, generator, discriminator):
        super(GAN, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

    def forward(self, z):
        fake_images = self.generator(z)
        fake_preds = self.discriminator(fake_images)
        return fake_images, fake_preds

# Set random seed for reproducibility
torch.manual_seed(42)

# Define hyperparameters
batch_size = 64
epochs = 200
lr = 0.0002
latent_dim = 100

# Create instances of Generator, Discriminator, and GAN
generator = Generator()
discriminator = Discriminator()
gan = GAN(generator, discriminator)

# Define loss function and optimizers
criterion = nn.BCELoss()
gen_optimizer = optim.Adam(generator.parameters(), lr=lr)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

# Training loop
for epoch in range(epochs):
    for batch_idx, real_images in enumerate(DataLoader(X_tensor, batch_size=batch_size, shuffle=True)):
        batch_size = real_images.size(0)
        
        # Train discriminator
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
        
        disc_optimizer.zero_grad()
        
        real_preds = discriminator(real_images)
        real_loss = criterion(real_preds, real_labels)
        
        z = torch.randn(batch_size, latent_dim)
        fake_images = generator(z)
        fake_preds = discriminator(fake_images.detach())
        fake_loss = criterion(fake_preds, fake_labels)
        
        disc_loss = real_loss + fake_loss
        disc_loss.backward()
        disc_optimizer.step()
        
        # Train generator
        gen_optimizer.zero_grad()
        
        z = torch.randn(batch_size, latent_dim)
        fake_images, fake_preds = gan(z)
        gen_loss = criterion(fake_preds, real_labels)
        
        gen_loss.backward()
        gen_optimizer.step()
        
    if (epoch+1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{epochs}] Loss_D: {disc_loss.item():.4f} Loss_G: {gen_loss.item():.4f}")

# Generate synthetic images
num_samples = 1000
z = torch.randn(num_samples, latent_dim)
fake_images, _ = gan(z)
fake_images = fake_images.detach().numpy()

# Plot the synthetic images
plt.scatter(fake_images[:, 0], fake_images[:, 1], c='blue', label='Synthetic')
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y, cmap='tab10', label='Real')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.title('GAN Generated Synthetic Images')
plt.legend()
plt.show()


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x2 and 64x256)