In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import torch.utils.data as data
import numpy as np
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
def generate_synthetic_data(generator, num_samples=10):
    z = torch.randn(num_samples, 100, 1, 1).to(device)
    fake_images = generator(z)
    return fake_images

In [6]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the images to range [-1, 1]
])

trainset = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
trainloader = data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)


RuntimeError: Dataset not found or corrupted. You can use download=True to download it

In [None]:
resnet_model = models.resnet18(pretrained=False)
resnet_model.load_state_dict(torch.load('resnet18_cifar10.pth'))  # Load the trained ResNet18 model
resnet_model.to(device)
resnet_model.eval()

In [None]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

In [None]:
generator = Generator().to(device)

In [None]:
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [None]:
num_epochs = 50
for epoch in range(num_epochs):
    for i, data in enumerate(trainloader, 0):
        real_images, _ = data
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # Create labels for real and fake images
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # Step 1: Train the generator
        generator.zero_grad()
        fake_images = generate_synthetic_data(generator, batch_size)
        outputs = resnet_model(fake_images)
        g_loss = criterion(outputs, real_labels)

        g_loss.backward()
        optimizer_G.step()

        if i % 200 == 0:
            print(f"[Epoch {epoch + 1}/{num_epochs}] Batch {i}/{len(trainloader)} "
                  f"Generator Loss: {g_loss.item():.4f}")

print("Finished Training GAN")

In [None]:
num_samples_to_generate = 10
generated_images = generate_synthetic_data(generator, num_samples_to_generate)

# Display the generated images
grid = torchvision.utils.make_grid(generated_images.cpu().detach(), nrow=int(np.sqrt(num_samples_to_generate)))
plt.figure(figsize=(8, 8))
plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.axis('off')
plt.show()