In [1]:
# .. ... .... .. .....

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd.variable import Variable
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings

In [3]:
# Set random seed for reproducibility
torch.manual_seed(123)
warnings.filterwarnings('ignore')

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set hyperparameters
latent_dim = 100
input_dim = 3 * 32 * 32
output_dim = input_dim

# Training loop
num_epochs = 20
batch_size = 64
sample_interval = 10

In [4]:
# Prepare the CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Define the Generator
class Generator(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim

        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        return x

In [6]:
# Define the Discriminator
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.fc(x)
        return x


In [7]:
# Define the Autoencoder
class Autoencoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(Autoencoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim),
            nn.ReLU()
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, input_dim),
            nn.Tanh()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

In [8]:
# Initialize the models
generator = Generator(latent_dim, output_dim).to(device)
discriminator = Discriminator(input_dim).to(device)
autoencoder = Autoencoder(input_dim, latent_dim).to(device)

In [9]:
# Define loss functions
adversarial_loss = nn.BCELoss()
autoencoder_loss = nn.MSELoss()

In [10]:
# Define optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
autoencoder_optimizer = optim.Adam(autoencoder.parameters(), lr=0.0002)

In [12]:
pbar = tqdm(total = num_epochs, desc='Training', unit='batch')
for epoch in range(num_epochs):
    for batch_idx, (real_images, _) in enumerate(train_dataloader):
        # Adversarial ground truths
        real_labels = torch.ones((real_images.size(0), 1)).to(device)
        fake_labels = torch.zeros((batch_size, 1)).to(device)

        # Train the discriminator
        discriminator_optimizer.zero_grad()
        real_images = real_images.view(-1, input_dim).to(device)
        real_outputs = discriminator(real_images)
        real_loss = adversarial_loss(real_outputs, real_labels)

        sample_latent = prepare_sample_dataset(batch_size)
        fake_images = generator(sample_latent)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = adversarial_loss(fake_outputs, fake_labels)

        discriminator_loss = (real_loss + fake_loss) / 2
        discriminator_loss.backward()
        discriminator_optimizer.step()

        # Train the generator
        generator_optimizer.zero_grad()
        fake_outputs = discriminator(fake_images)
        generator_labels = torch.ones_like(fake_outputs)  # Generate labels for the generator

        generator_loss = adversarial_loss(fake_outputs, generator_labels)
        generator_loss.backward(retain_graph=True)  # Set retain_graph=True
        generator_optimizer.step()
        
    

    # Sample and save images
    if (epoch + 1) % sample_interval == 0:
        sample_latent = prepare_sample_dataset(25)
        generated_images = generator(sample_latent).detach()
        generated_images = generated_images.view(-1, 3, 32, 32).cpu().numpy()

        fig, axes = plt.subplots(5, 5, figsize=(10, 10))
        for i, ax in enumerate(axes.flat):
            ax.imshow(generated_images[i].transpose((1, 2, 0)), cmap="gray")
            ax.axis("off")

        plt.savefig(f"generated_images_epoch_{epoch + 1}.png")
        plt.close()
    pbar.set_postfix({'Discriminator Loss': discriminator_loss.item()})
    pbar.update(1)

Training:  45%|██▎  | 9/20 [01:23<01:42,  9.35s/batch, Discriminator Loss=0.241]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range f