In [1]:
%%capture
!pip install torch
!pip install torchvision
!pip install torchmetrics pytorch-fid
!pip install torchmetrics[image]
!pip install torch-fidelity

In [2]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch import nn
from tqdm import tqdm
from torchsummary import summary
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision import models

In [3]:
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
training_data = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

test_data = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform
)


train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 39.8MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
class FCCGenerator(nn.Module):
    def __init__(self, latent_dim=100):
        super(FCCGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 512),
            nn.ReLU(),
            nn.Linear(512, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU()  # Add ReLU after BN if used in paper
        )

        self.conv = nn.Sequential(
            nn.Unflatten(1, (256, 4, 4)),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        x = self.conv(x)
        return x


class FCCDiscriminator(nn.Module):
    def __init__(self):
        super(FCCDiscriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(4096, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(64, 16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )


    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        return x

In [5]:
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
generator = FCCGenerator()
discriminator = FCCDiscriminator()

#Define BCE Loss
bce_loss = nn.BCELoss()

#Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
generator.to(device)
discriminator.to(device)

FCCDiscriminator(
  (conv): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (fc): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=4096, out_features=512, bias=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Linear(in_features=512, out_features=64, bias=True)
    (4): LeakyReLU(negativ

In [7]:
# Parameters
num_epochs = 20
latent_dim = 100
fixed_noise = torch.randn(64, latent_dim).to(device)
# Initialize FID and Inception Score
fid = FrechetInceptionDistance(feature=2048).to(device) # Default Feature is 2048
inception = InceptionScore().to(device)

for epoch in range(num_epochs):
    epoch_progress = tqdm(train_dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)

    for batch_idx, (real_images, _) in enumerate(epoch_progress):
        # TRAIN DISCRIMINATOR
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # Label thật và giả
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # Tính loss với ảnh thật
        real_outputs = discriminator(real_images)

        d_loss_real = bce_loss(real_outputs, real_labels)

        # Sinh ảnh giả
        noise = torch.randn(batch_size, latent_dim).to(device)
        fake_images = generator(noise)

        # Tính loss với ảnh giả
        fake_outputs = discriminator(fake_images)
        d_loss_fake = bce_loss(fake_outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake

        # Cập nhật Discriminator
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # TRAIN GENERATOR
        # Sinh lại ảnh giả
        fake_images = generator(noise)
        # Tính loss của G
        fake_outputs = discriminator(fake_images)
        g_loss = bce_loss(fake_outputs, real_labels)

        # Cập nhật Generator
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # Cập nhật tqdm progress bar
        epoch_progress.set_postfix(D_Loss=f"{d_loss.item():.4f}", G_Loss=f"{g_loss.item():.4f}")


    tqdm.write(f"Epoch [{epoch+1}/{num_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f} ")

    # Save generated images for visualization

    with torch.no_grad():
            fake_images = generator(fixed_noise)
            grid = torchvision.utils.make_grid(fake_images, nrow=8, normalize=True)
            plt.figure(figsize=(8, 8))
            plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
            plt.title(f"Epoch {epoch+1}")
            plt.show()

Output hidden; open in https://colab.research.google.com to view.

In [8]:
num_epochs = 10
for epoch in range(num_epochs):
    with torch.no_grad():
        batch_size = 16
        num_samples = 1000
        for _ in range(num_samples // batch_size):
            noise_eval = torch.randn(batch_size, latent_dim).to(device)
            fake_images = generator(noise_eval)

            # Resize images to 299x299 for FID computation
            fake_images = nn.functional.interpolate(fake_images, size=(299, 299), mode='bicubic', align_corners=False)
            real_images_batch = real_images[:batch_size]
            real_images_batch=nn.functional.interpolate(real_images_batch, size=(299,299), mode = 'bicubic', align_corners= False)
            #fid update
            fid.update(real_images_batch.to(torch.uint8), real=True)
            fid.update(fake_images.to(torch.uint8), real=False)
            #inception update
            inception.update(fake_images.to(torch.uint8))

        fid_score = fid.compute()
        inception_score, inception_std = inception.compute()

        print(f"Epoch [{epoch+1}/{num_epochs}] | FID: {fid_score:.4f} | IS: {inception_score:.4f} (std: {inception_std:.4f})")

        fid.reset()
        inception.reset()

Epoch [1/10] | FID: 201.0068 | IS: 2.8464 (std: 0.1314)
Epoch [2/10] | FID: 252.3581 | IS: 3.0197 (std: 0.1623)
Epoch [3/10] | FID: 194.0525 | IS: 2.9004 (std: 0.1779)
Epoch [4/10] | FID: 234.7229 | IS: 3.0577 (std: 0.1499)
Epoch [5/10] | FID: 206.7721 | IS: 2.9620 (std: 0.1963)
Epoch [6/10] | FID: 218.5246 | IS: 2.9881 (std: 0.1234)
Epoch [7/10] | FID: 230.8002 | IS: 3.0003 (std: 0.2395)
Epoch [8/10] | FID: 232.1312 | IS: 3.0520 (std: 0.2398)
Epoch [9/10] | FID: 224.5375 | IS: 3.0355 (std: 0.1784)
Epoch [10/10] | FID: 201.8732 | IS: 2.8594 (std: 0.1735)
