In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from tqdm import tqdm
from inception_score import inception_score  # Make sure to have the inception_score.py file from the provided reference

# Set random seed for reproducibility
torch.manual_seed(42)

# Define the generator and discriminator architectures
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.model(input)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.model(input)


# Initialize the generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Define the loss function and optimizers
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Set up DataLoader for CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 64
dataset = datasets.CIFAR10(root='./data', download=True, train=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# Training the DCGAN
num_epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator.to(device)
discriminator.to(device)

for epoch in tqdm(range(num_epochs)):
    for i, data in enumerate(dataloader, 0):
        # Update discriminator with real data
        discriminator.zero_grad()
        real_images, _ = data
        real_images = real_images.to(device)
        # label = torch.full((batch_size,), 1.0, device=device)
        label = torch.full((batch_size, 1, 1, 1), 1.0, device=device)
        output = discriminator(real_images)
        # errD_real = criterion(output.view(-1), label)
        # errD_real = criterion(output.squeeze(), label)
        # errD_real = criterion(output, label)
        errD_real = criterion(output, label.expand_as(output))
        errD_real.backward()
        D_x = output.mean().item()

        # Update discriminator with fake data
        noise = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_images = generator(noise)
        label.fill_(0.0)
        output = discriminator(fake_images.detach())
        print(output)
        # errD_fake = criterion(output.view(-1), label)
        # errD_fake = criterion(output.squeeze(), label)
        # errD_fake = criterion(output, label)
        errD_fake = criterion(output, label.expand_as(output))
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizer_D.step()

        # Update generator
        generator.zero_grad()
        label.fill_(1.0)
        output = discriminator(fake_images)
        print("Discriminator: ", output)
        # errG = criterion(output.view(-1), label)
        errG = criterion(output, label.expand_as(output))
        print("errG: ", errG)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizer_G.step()
        
# Save the generator model
torch.save(generator.state_dict(), 'generator.pth')

# Generate 10 images from the learned distribution and save them
generator.eval()
for i in range(10):
    noise = torch.randn(1, 100, 1, 1, device=device)
    fake_image = generator(noise).detach().cpu().squeeze().numpy()
    fake_image = (fake_image + 1) / 2.0  # Rescale values to [0, 1]
    plt.imsave(f'generated_image_{i+1}.png', np.transpose(fake_image, (1, 2, 0)))

# Plot the Loss vs Epoch curve
plt.plot(range(num_epochs), errD.item(), label="Discriminator Loss")
plt.plot(range(num_epochs), errG.item(), label="Generator Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Calculate Inception Score on the generated images
fake_images_list = []
for i in range(500):  # Generate 500 images for a more accurate Inception Score
    noise = torch.randn(1, 100, 1, 1, device=device)
    fake_images_list.append(generator(noise).detach().cpu())
fake_images_tensor = torch.cat(fake_images_list, dim=0)
inception_mean, inception_std = inception_score(fake_images_tensor, cuda=True, batch_size=32, resize=True, splits=1)

print(f"Inception Score: {inception_mean} ± {inception_std}")

Files already downloaded and verified


  0%|          | 0/50 [00:00<?, ?it/s]

tensor([[[[0.5406]]],


        [[[0.4490]]],


        [[[0.4416]]],


        [[[0.5329]]],


        [[[0.4703]]],


        [[[0.6141]]],


        [[[0.4606]]],


        [[[0.3893]]],


        [[[0.3733]]],


        [[[0.3384]]],


        [[[0.5118]]],


        [[[0.4296]]],


        [[[0.3805]]],


        [[[0.3689]]],


        [[[0.4412]]],


        [[[0.5824]]],


        [[[0.5427]]],


        [[[0.5359]]],


        [[[0.6345]]],


        [[[0.3739]]],


        [[[0.4859]]],


        [[[0.4089]]],


        [[[0.4742]]],


        [[[0.5678]]],


        [[[0.4978]]],


        [[[0.5364]]],


        [[[0.4100]]],


        [[[0.5206]]],


        [[[0.5058]]],


        [[[0.4519]]],


        [[[0.3464]]],


        [[[0.5911]]],


        [[[0.5853]]],


        [[[0.4728]]],


        [[[0.5299]]],


        [[[0.4940]]],


        [[[0.3854]]],


        [[[0.3768]]],


        [[[0.5005]]],


        [[[0.5007]]],


        [[[0.3541]]],


        [[[0.359