In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np
from scipy.linalg import sqrtm
import torchvision.models as models
import torch.nn.functional as F

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the Generator model
class Generator(nn.Module):
    """Conditional DCGAN generator.

    A learned label embedding (one ``num_classes``-wide vector per class) is
    appended to the noise vector along the channel axis, and the result is
    upsampled by a stack of transposed convolutions. For noise with 1x1
    spatial extent the output is ``nc`` x 64 x 64, squashed to [-1, 1] by Tanh.

    Args:
        nz: Number of channels in the latent noise tensor.
        nc: Number of channels in the generated image.
        ngf: Base width of the generator feature maps.
        num_classes: Number of distinct class labels.
    """

    def __init__(self, nz, nc, ngf, num_classes):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # (in_channels, out_channels, stride, padding) for each upsampling stage.
        stage_plan = [
            (nz + num_classes, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf, 2, 1),
        ]
        layers = []
        for c_in, c_out, stride, padding in stage_plan:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Final projection to image channels.
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate images from ``noise`` conditioned on integer ``labels``."""
        # Give the embedding two trailing singleton axes so it can be
        # concatenated with the noise along the channel dimension.
        condition = self.label_emb(labels)[:, :, None, None]
        return self.main(torch.cat([noise, condition], dim=1))

# Define the Discriminator model
class Discriminator(nn.Module):
    """Conditional DCGAN discriminator.

    The label embedding is broadcast over the image's spatial extent and
    stacked onto the image channels before a strided-convolution stack ending
    in a Sigmoid. For 64x64 inputs the output has shape (batch, 1, 1, 1).

    Args:
        nc: Number of channels in the input image.
        ndf: Base width of the discriminator feature maps.
        num_classes: Number of distinct class labels.
    """

    def __init__(self, nc, ndf, num_classes):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # First stage has no batch norm; the remaining ones do.
        layers = [
            nn.Conv2d(nc + num_classes, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for factor in (1, 2, 4):
            layers.append(nn.Conv2d(ndf * factor, ndf * factor * 2, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(ndf * factor * 2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Collapse the remaining 4x4 map into a single real/fake probability.
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Return the real/fake probability for ``img`` given integer ``labels``."""
        height, width = img.size(2), img.size(3)
        # Tile the per-sample embedding over every pixel position.
        label_planes = self.label_emb(labels)[:, :, None, None].expand(-1, -1, height, width)
        return self.main(torch.cat([img, label_planes], dim=1))

# Function to train the cGAN
def train_cgan(generator, discriminator, dataloader, num_epochs=25, lr=0.0002, nz=100, num_classes=10):
    """Train a conditional GAN with the standard BCE objective.

    Each batch performs one discriminator update (real batch, then a detached
    fake batch) followed by one generator update. After every epoch the losses
    of the *last* batch are printed.

    Args:
        generator: Module called as ``generator(noise, labels)``.
        discriminator: Module called as ``discriminator(images, labels)``.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        num_epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate used for both networks.
        nz: Number of channels in the latent noise tensor.
        num_classes: Accepted for interface symmetry; not used in the body.
    """
    bce = nn.BCELoss()
    adam_betas = (0.5, 0.999)
    optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)
    optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)

    for epoch in range(num_epochs):
        for real_images, labels in dataloader:
            n = real_images.size(0)
            real_images, labels = real_images.to(device), labels.to(device)

            # ---- Discriminator step ----
            discriminator.zero_grad()
            target_real = torch.ones(n, dtype=torch.float, device=device)
            target_fake = torch.zeros(n, dtype=torch.float, device=device)

            d_loss_real = bce(discriminator(real_images, labels).view(-1), target_real)
            d_loss_real.backward()

            z = torch.randn(n, nz, 1, 1, device=device)
            fake_images = generator(z, labels)
            # Detach so this backward pass does not reach the generator.
            d_loss_fake = bce(discriminator(fake_images.detach(), labels).view(-1), target_fake)
            d_loss_fake.backward()
            optimizerD.step()

            # ---- Generator step ----
            generator.zero_grad()
            # The generator wants its fakes to be scored as real.
            g_loss = bce(discriminator(fake_images, labels).view(-1), target_real)
            g_loss.backward()
            optimizerG.step()

        print(f'Epoch [{epoch+1}/{num_epochs}] Loss D: {d_loss_real + d_loss_fake}, Loss G: {g_loss}')

# Create datasets
def create_datasets(imbalance_ratios, batch_size=64):
    """Build the CIFAR-10 loaders used for training and FID evaluation.

    The CIFAR-10 training set (resized to 64) is split 80/20 into train and
    validation parts. An imbalanced variant of the train part keeps, for each
    class in ``imbalance_ratios``, the leading ``ratio`` fraction of that
    class's samples; classes absent from the mapping are left out entirely.
    The CIFAR-10 test set is loaded resized to 299.

    Args:
        imbalance_ratios: Mapping of class id to the fraction of samples kept.
        batch_size: Batch size for every loader.

    Returns:
        Tuple ``(balanced_train, imbalanced_train, val, test)`` of DataLoaders.
    """
    def build_transform(size):
        return transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    cifar10_train = datasets.CIFAR10(root="./data", train=True, download=True, transform=build_transform(64))
    cifar10_test = datasets.CIFAR10(root="./data", train=False, download=True, transform=build_transform(299))

    n_total = len(cifar10_train)
    n_train = int(0.8 * n_total)
    train_dataset, val_dataset = random_split(cifar10_train, [n_train, n_total - n_train])

    # Labels in train-split order, so positions found below index train_dataset directly.
    split_labels = np.array([cifar10_train.targets[idx] for idx in train_dataset.indices])
    kept_positions = []
    for class_id, ratio in imbalance_ratios.items():
        class_positions = np.where(split_labels == class_id)[0]
        kept_positions.extend(class_positions[:int(len(class_positions) * ratio)])

    imbalanced_dataset = Subset(train_dataset, kept_positions)

    def make_loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4, pin_memory=True)

    return (
        make_loader(train_dataset, True),
        make_loader(imbalanced_dataset, True),
        make_loader(val_dataset, False),
        make_loader(cifar10_test, False),
    )

# Calculate FID score between two sets of features
def calculate_fid(real_features, fake_features, eps=1e-6):
    """Compute the Frechet distance between two sets of feature vectors.

    Each set is summarised by its mean and covariance, and the distance is
    ``||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 * sqrt(S_r @ S_f))``.

    Args:
        real_features: Array-like of shape (n_samples, n_features).
        fake_features: Array-like of shape (m_samples, n_features).
        eps: Diagonal offset added to both covariances when the matrix square
            root of their product is not finite.

    Returns:
        The Frechet distance as a NumPy floating-point scalar.
    """
    real_features = np.asarray(real_features)
    fake_features = np.asarray(fake_features)

    mu_real = real_features.mean(axis=0)
    mu_fake = fake_features.mean(axis=0)
    # np.cov collapses to a 0-d array when there is a single feature column;
    # keep the covariances 2-D so the matrix product and sqrtm stay valid.
    sigma_real = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma_fake = np.atleast_2d(np.cov(fake_features, rowvar=False))

    covmean = sqrtm(sigma_real @ sigma_fake)
    if not np.isfinite(covmean).all():
        # A (near-)singular product can make sqrtm blow up; retry with a small
        # ridge on both covariances instead of returning NaN/inf.
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean = sqrtm((sigma_real + offset) @ (sigma_fake + offset))
    if np.iscomplexobj(covmean):
        # Numerical error can leave a tiny imaginary part; only the real part is meaningful.
        covmean = covmean.real

    fid = np.sum((mu_real - mu_fake) ** 2) + np.trace(sigma_real + sigma_fake - 2 * covmean)
    return fid

# Extract features class-by-class
def extract_class_features(loader, model, class_id, generator=None, nz=100):
    """Collect ``model`` features for a single class.

    Without ``generator`` the features come from the real images in ``loader``
    whose label equals ``class_id``. With ``generator`` the loader is used only
    to decide how many samples to draw per batch: that many images are
    generated for ``class_id`` and fed to ``model`` instead.

    Args:
        loader: Iterable yielding ``(inputs, labels)`` batches.
        model: Feature extractor; it is switched to eval mode.
        class_id: Label value to select.
        generator: Optional module called as ``generator(noise, labels)``.
        nz: Number of channels in the latent noise tensor fed to ``generator``.

    Returns:
        NumPy array with all per-batch feature arrays concatenated on axis 0.

    Raises:
        ValueError: If ``loader`` contains no sample labelled ``class_id``.
    """
    model.eval()
    # NOTE(review): the generator's train/eval mode is left as the caller set it;
    # with BatchNorm in train mode its output depends on the batch composition.
    features = []
    with torch.no_grad():
        for inputs, labels in loader:
            mask = labels == class_id  # Filter inputs by class
            count = int(mask.sum())
            if count == 0:
                continue

            if generator is not None:
                # Labels arrive from the loader on the CPU; they must be moved to
                # the same device as the noise before being fed to the generator.
                noise = torch.randn(count, nz, 1, 1, device=device)
                class_labels = labels[mask].to(device)
                inputs = generator(noise, class_labels)
            else:
                inputs = inputs[mask].to(device)

            # Resize inputs to 299x299 for Inception-v3
            inputs = F.interpolate(inputs, size=(299, 299), mode='bilinear', align_corners=False)
            outputs = model(inputs)  # Extract features
            features.append(outputs.cpu().numpy())

    if not features:
        raise ValueError(f"No samples with label {class_id} found in loader")
    return np.concatenate(features, axis=0)

# Compare FID scores per class
def compare_fid_scores(test_loader, generator_balanced, generator_imbalanced):
    """Report per-class FID for the balanced and imbalanced generators.

    For each of the 10 class ids, features of real test images and of images
    from both generators are extracted with a pretrained Inception-v3 whose
    final FC layer is replaced by identity, then compared with
    ``calculate_fid``.

    Args:
        test_loader: Loader of real ``(images, labels)`` batches.
        generator_balanced: Generator trained on the balanced data.
        generator_imbalanced: Generator trained on the imbalanced data.

    Returns:
        pandas DataFrame with one row per class; "Delta FID" is the imbalanced
        score minus the balanced score. The frame is also printed.
    """
    # Inception-v3 with the classifier head removed serves as the feature extractor.
    inception = models.inception_v3(pretrained=True, transform_input=False).to(device)
    inception.fc = torch.nn.Identity()

    rows = []
    for class_id in range(10):  # CIFAR-10 class ids
        print(f"Processing class {class_id}...")

        real_features = extract_class_features(test_loader, inception, class_id)
        # Balanced generator first, then imbalanced.
        fake_features = [
            extract_class_features(test_loader, inception, class_id, gen)
            for gen in (generator_balanced, generator_imbalanced)
        ]
        fid_balanced, fid_imbalanced = [
            calculate_fid(real_features, fake) for fake in fake_features
        ]

        row = {}
        row["Class"] = class_id
        row["FID (Balanced)"] = fid_balanced
        row["FID (Imbalanced)"] = fid_imbalanced
        row["Delta FID"] = fid_imbalanced - fid_balanced
        rows.append(row)

    # Tabulate and show the results.
    import pandas as pd
    df = pd.DataFrame(rows)
    print(df)

    return df

In [None]:

# ---- Hyperparameters ----
nz = 100          # channels of the latent vector fed to the generator
nc = 3            # image channels (RGB)
ngf = ndf = 64    # base feature-map width of generator / discriminator
num_classes = 10
num_epochs = 25
lr = 0.0002

# Fraction of each class kept in the imbalanced training set.
imbalance_ratios = {
    0: 0.01, 1: 0.01, 2: 0.02, 3: 0.05, 4: 0.4,
    5: 0.5, 6: 0.6, 7: 0.7, 8: 0.8, 9: 0.9,
}

# ---- Data ----
(train_loader_balanced,
 train_loader_imbalanced,
 val_loader,
 test_loader) = create_datasets(imbalance_ratios)

# ---- Models: one generator/discriminator pair per training regime ----
generator_balanced = Generator(nz, nc, ngf, num_classes).to(device)
discriminator_balanced = Discriminator(nc, ndf, num_classes).to(device)

generator_imbalanced = Generator(nz, nc, ngf, num_classes).to(device)
discriminator_imbalanced = Discriminator(nc, ndf, num_classes).to(device)

# ---- Training ----
print("Training balanced cGAN...")
train_cgan(
    generator=generator_balanced,
    discriminator=discriminator_balanced,
    dataloader=train_loader_balanced,
    num_epochs=num_epochs,
    lr=lr,
    nz=nz,
    num_classes=num_classes,
)

print("Training imbalanced cGAN...")
train_cgan(
    generator=generator_imbalanced,
    discriminator=discriminator_imbalanced,
    dataloader=train_loader_imbalanced,
    num_epochs=num_epochs,
    lr=lr,
    nz=nz,
    num_classes=num_classes,
)

# ---- Evaluation ----
print("Comparing FID scores...")
compare_fid_scores(test_loader, generator_balanced, generator_imbalanced)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:05<00:00, 30.8MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified




Training balanced cGAN...


KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np
from scipy.linalg import sqrtm
import torchvision.models as models
import torch.nn.functional as F

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Define the Generator model
class Generator(nn.Module):
    """Conditional DCGAN generator.

    A learned label embedding (one ``num_classes``-wide vector per class) is
    appended to the noise vector along the channel axis, and the result is
    upsampled by a stack of transposed convolutions. For noise with 1x1
    spatial extent the output is ``nc`` x 64 x 64, squashed to [-1, 1] by Tanh.

    Args:
        nz: Number of channels in the latent noise tensor.
        nc: Number of channels in the generated image.
        ngf: Base width of the generator feature maps.
        num_classes: Number of distinct class labels.
    """

    def __init__(self, nz, nc, ngf, num_classes):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # (in_channels, out_channels, stride, padding) for each upsampling stage.
        stage_plan = [
            (nz + num_classes, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf, 2, 1),
        ]
        layers = []
        for c_in, c_out, stride, padding in stage_plan:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Final projection to image channels.
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate images from ``noise`` conditioned on integer ``labels``."""
        # Give the embedding two trailing singleton axes so it can be
        # concatenated with the noise along the channel dimension.
        condition = self.label_emb(labels)[:, :, None, None]
        return self.main(torch.cat([noise, condition], dim=1))

# Define the Discriminator model
class Discriminator(nn.Module):
    """Conditional DCGAN discriminator.

    The label embedding is broadcast over the image's spatial extent and
    stacked onto the image channels before a strided-convolution stack ending
    in a Sigmoid. For 64x64 inputs the output has shape (batch, 1, 1, 1).

    Args:
        nc: Number of channels in the input image.
        ndf: Base width of the discriminator feature maps.
        num_classes: Number of distinct class labels.
    """

    def __init__(self, nc, ndf, num_classes):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # First stage has no batch norm; the remaining ones do.
        layers = [
            nn.Conv2d(nc + num_classes, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for factor in (1, 2, 4):
            layers.append(nn.Conv2d(ndf * factor, ndf * factor * 2, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(ndf * factor * 2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Collapse the remaining 4x4 map into a single real/fake probability.
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Return the real/fake probability for ``img`` given integer ``labels``."""
        height, width = img.size(2), img.size(3)
        # Tile the per-sample embedding over every pixel position.
        label_planes = self.label_emb(labels)[:, :, None, None].expand(-1, -1, height, width)
        return self.main(torch.cat([img, label_planes], dim=1))

# Function to train the cGAN
def train_cgan(generator, discriminator, dataloader, num_epochs=25, lr=0.0002, nz=100, num_classes=10):
    """Train a conditional GAN with the standard BCE objective.

    Each batch performs one discriminator update (real batch, then a detached
    fake batch) followed by one generator update. After every epoch the losses
    of the *last* batch are printed.

    Args:
        generator: Module called as ``generator(noise, labels)``.
        discriminator: Module called as ``discriminator(images, labels)``.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        num_epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate used for both networks.
        nz: Number of channels in the latent noise tensor.
        num_classes: Accepted for interface symmetry; not used in the body.
    """
    bce = nn.BCELoss()
    adam_betas = (0.5, 0.999)
    optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)
    optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)

    for epoch in range(num_epochs):
        for real_images, labels in dataloader:
            n = real_images.size(0)
            real_images, labels = real_images.to(device), labels.to(device)

            # ---- Discriminator step ----
            discriminator.zero_grad()
            target_real = torch.ones(n, dtype=torch.float, device=device)
            target_fake = torch.zeros(n, dtype=torch.float, device=device)

            d_loss_real = bce(discriminator(real_images, labels).view(-1), target_real)
            d_loss_real.backward()

            z = torch.randn(n, nz, 1, 1, device=device)
            fake_images = generator(z, labels)
            # Detach so this backward pass does not reach the generator.
            d_loss_fake = bce(discriminator(fake_images.detach(), labels).view(-1), target_fake)
            d_loss_fake.backward()
            optimizerD.step()

            # ---- Generator step ----
            generator.zero_grad()
            # The generator wants its fakes to be scored as real.
            g_loss = bce(discriminator(fake_images, labels).view(-1), target_real)
            g_loss.backward()
            optimizerG.step()

        print(f'Epoch [{epoch+1}/{num_epochs}] Loss D: {d_loss_real + d_loss_fake}, Loss G: {g_loss}')

# Create datasets
def create_datasets(imbalance_ratios, batch_size=64):
    """Build the CIFAR-10 loaders used for training and FID evaluation.

    The CIFAR-10 training set (resized to 64) is split 80/20 into train and
    validation parts. An imbalanced variant of the train part keeps, for each
    class in ``imbalance_ratios``, the leading ``ratio`` fraction of that
    class's samples; classes absent from the mapping are left out entirely.
    The CIFAR-10 test set is loaded resized to 299.

    Args:
        imbalance_ratios: Mapping of class id to the fraction of samples kept.
        batch_size: Batch size for every loader.

    Returns:
        Tuple ``(balanced_train, imbalanced_train, val, test)`` of DataLoaders.
    """
    def build_transform(size):
        return transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    cifar10_train = datasets.CIFAR10(root="./data", train=True, download=True, transform=build_transform(64))
    cifar10_test = datasets.CIFAR10(root="./data", train=False, download=True, transform=build_transform(299))

    n_total = len(cifar10_train)
    n_train = int(0.8 * n_total)
    train_dataset, val_dataset = random_split(cifar10_train, [n_train, n_total - n_train])

    # Labels in train-split order, so positions found below index train_dataset directly.
    split_labels = np.array([cifar10_train.targets[idx] for idx in train_dataset.indices])
    kept_positions = []
    for class_id, ratio in imbalance_ratios.items():
        class_positions = np.where(split_labels == class_id)[0]
        kept_positions.extend(class_positions[:int(len(class_positions) * ratio)])

    imbalanced_dataset = Subset(train_dataset, kept_positions)

    def make_loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=True)

    return (
        make_loader(train_dataset, True),
        make_loader(imbalanced_dataset, True),
        make_loader(val_dataset, False),
        make_loader(cifar10_test, False),
    )

# Calculate FID score between two sets of features
def calculate_fid(real_features, fake_features, eps=1e-6):
    """Compute the Frechet distance between two sets of feature vectors.

    Each set is summarised by its mean and covariance, and the distance is
    ``||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 * sqrt(S_r @ S_f))``.

    Args:
        real_features: Array-like of shape (n_samples, n_features).
        fake_features: Array-like of shape (m_samples, n_features).
        eps: Diagonal offset added to both covariances when the matrix square
            root of their product is not finite.

    Returns:
        The Frechet distance as a NumPy floating-point scalar.
    """
    real_features = np.asarray(real_features)
    fake_features = np.asarray(fake_features)

    mu_real = real_features.mean(axis=0)
    mu_fake = fake_features.mean(axis=0)
    # np.cov collapses to a 0-d array when there is a single feature column;
    # keep the covariances 2-D so the matrix product and sqrtm stay valid.
    sigma_real = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma_fake = np.atleast_2d(np.cov(fake_features, rowvar=False))

    covmean = sqrtm(sigma_real @ sigma_fake)
    if not np.isfinite(covmean).all():
        # A (near-)singular product can make sqrtm blow up; retry with a small
        # ridge on both covariances instead of returning NaN/inf.
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean = sqrtm((sigma_real + offset) @ (sigma_fake + offset))
    if np.iscomplexobj(covmean):
        # Numerical error can leave a tiny imaginary part; only the real part is meaningful.
        covmean = covmean.real

    fid = np.sum((mu_real - mu_fake) ** 2) + np.trace(sigma_real + sigma_fake - 2 * covmean)
    return fid

# Extract features class-by-class
def extract_class_features(loader, model, class_id, generator=None, nz=100):
    """Collect ``model`` features for a single class.

    Without ``generator`` the features come from the real images in ``loader``
    whose label equals ``class_id``. With ``generator`` the loader is used only
    to decide how many samples to draw per batch: that many images are
    generated for ``class_id`` and fed to ``model`` instead.

    Args:
        loader: Iterable yielding ``(inputs, labels)`` batches.
        model: Feature extractor; it is switched to eval mode.
        class_id: Label value to select.
        generator: Optional module called as ``generator(noise, labels)``.
        nz: Number of channels in the latent noise tensor fed to ``generator``.

    Returns:
        NumPy array with all per-batch feature arrays concatenated on axis 0.

    Raises:
        ValueError: If ``loader`` contains no sample labelled ``class_id``.
    """
    model.eval()
    # NOTE(review): the generator's train/eval mode is left as the caller set it;
    # with BatchNorm in train mode its output depends on the batch composition.
    features = []
    with torch.no_grad():
        for inputs, labels in loader:
            mask = labels == class_id  # Filter inputs by class
            count = int(mask.sum())
            if count == 0:
                continue

            if generator is not None:
                # Labels arrive from the loader on the CPU; they must be moved to
                # the same device as the noise before being fed to the generator.
                noise = torch.randn(count, nz, 1, 1, device=device)
                class_labels = labels[mask].to(device)
                inputs = generator(noise, class_labels)
            else:
                inputs = inputs[mask].to(device)

            # Resize inputs to 299x299 for Inception-v3
            inputs = F.interpolate(inputs, size=(299, 299), mode='bilinear', align_corners=False)
            outputs = model(inputs)  # Extract features
            features.append(outputs.cpu().numpy())

    if not features:
        raise ValueError(f"No samples with label {class_id} found in loader")
    return np.concatenate(features, axis=0)

# Compare FID scores per class
def compare_fid_scores(test_loader, generator_balanced, generator_imbalanced):
    """Report per-class FID for the balanced and imbalanced generators.

    For each of the 10 class ids, features of real test images and of images
    from both generators are extracted with a pretrained Inception-v3 whose
    final FC layer is replaced by identity, then compared with
    ``calculate_fid``.

    Args:
        test_loader: Loader of real ``(images, labels)`` batches.
        generator_balanced: Generator trained on the balanced data.
        generator_imbalanced: Generator trained on the imbalanced data.

    Returns:
        pandas DataFrame with one row per class; "Delta FID" is the imbalanced
        score minus the balanced score. The frame is also printed.
    """
    # Inception-v3 with the classifier head removed serves as the feature extractor.
    inception = models.inception_v3(pretrained=True, transform_input=False).to(device)
    inception.fc = torch.nn.Identity()

    rows = []
    for class_id in range(10):  # CIFAR-10 class ids
        print(f"Processing class {class_id}...")

        real_features = extract_class_features(test_loader, inception, class_id)
        # Balanced generator first, then imbalanced.
        fake_features = [
            extract_class_features(test_loader, inception, class_id, gen)
            for gen in (generator_balanced, generator_imbalanced)
        ]
        fid_balanced, fid_imbalanced = [
            calculate_fid(real_features, fake) for fake in fake_features
        ]

        row = {}
        row["Class"] = class_id
        row["FID (Balanced)"] = fid_balanced
        row["FID (Imbalanced)"] = fid_imbalanced
        row["Delta FID"] = fid_imbalanced - fid_balanced
        rows.append(row)

    # Tabulate and show the results.
    import pandas as pd
    df = pd.DataFrame(rows)
    print(df)

    return df

# Main function to execute the training and evaluation
def main():
    """Train a balanced and an imbalanced cGAN on CIFAR-10 and compare per-class FID."""
    # Hyperparameters
    nz, nc = 100, 3      # latent channels, image channels (RGB)
    ngf = ndf = 64       # base feature-map width of generator / discriminator
    num_classes = 10
    num_epochs = 25
    lr = 0.0002

    # Fraction of each class kept in the imbalanced training set.
    imbalance_ratios = {
        0: 0.01, 1: 0.01, 2: 0.02, 3: 0.05, 4: 0.4,
        5: 0.5, 6: 0.6, 7: 0.7, 8: 0.8, 9: 0.9,
    }

    # The validation loader is returned by create_datasets but not used here.
    loaders = create_datasets(imbalance_ratios)
    train_loader_balanced, train_loader_imbalanced, _val_loader, test_loader = loaders

    # One generator/discriminator pair per training regime.
    generator_balanced = Generator(nz, nc, ngf, num_classes).to(device)
    discriminator_balanced = Discriminator(nc, ndf, num_classes).to(device)
    generator_imbalanced = Generator(nz, nc, ngf, num_classes).to(device)
    discriminator_imbalanced = Discriminator(nc, ndf, num_classes).to(device)

    training_runs = (
        ("Training balanced cGAN...", generator_balanced, discriminator_balanced, train_loader_balanced),
        ("Training imbalanced cGAN...", generator_imbalanced, discriminator_imbalanced, train_loader_imbalanced),
    )
    for banner, gen, disc, loader in training_runs:
        print(banner)
        train_cgan(gen, disc, loader, num_epochs, lr, nz, num_classes)

    # Per-class FID comparison of the two generators.
    print("Comparing FID scores...")
    compare_fid_scores(test_loader, generator_balanced, generator_imbalanced)

# Run the experiment only when this code is executed directly (its __name__ is
# "__main__"), not when it is imported as a module.
if __name__ == "__main__":
  main()

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified




Training balanced cGAN...
Epoch [1/25] Loss D: 1.122218132019043, Loss G: 1.7036888599395752
Epoch [2/25] Loss D: 0.6722404956817627, Loss G: 2.850545883178711
Epoch [3/25] Loss D: 0.3328244686126709, Loss G: 2.921764850616455
Epoch [4/25] Loss D: 0.43519553542137146, Loss G: 2.9261364936828613
Epoch [5/25] Loss D: 0.3724260926246643, Loss G: 2.7494711875915527
Epoch [6/25] Loss D: 0.03655776381492615, Loss G: 4.265316963195801
Epoch [7/25] Loss D: 0.07403385639190674, Loss G: 4.79141902923584
Epoch [8/25] Loss D: 0.0925154760479927, Loss G: 3.9526398181915283


KeyboardInterrupt: 

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils


In [None]:
def load_data(dataset_name='cifar10', dataroot='./data', image_size=64, batch_size=64):
    """Build a shuffled training DataLoader for the requested dataset.

    Args:
        dataset_name: Dataset identifier; only ``'cifar10'`` is supported.
        dataroot: Directory the dataset is downloaded to / read from.
        image_size: Size passed to ``transforms.Resize``.
        batch_size: Batch size of the returned loader.

    Returns:
        Tuple ``(dataloader, nc)`` where ``nc`` is the number of image channels.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    # Fail early with a clear message instead of hitting unbound locals below.
    if dataset_name != 'cifar10':
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; only 'cifar10' is available"
        )

    dataset = dset.CIFAR10(
        root=dataroot, download=True,
        transform=transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            # Scale each channel from [0, 1] to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    )
    nc = 3  # CIFAR-10 has 3 channels (RGB)

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=1
    )
    return dataloader, nc

In [None]:
class Generator(nn.Module):
    """Unconditional DCGAN-style generator.

    Upsamples a latent tensor with five transposed convolutions; for 1x1
    spatial noise the output is ``nc`` x 64 x 64 with values in [-1, 1].

    Args:
        ngpu: Accepted for interface compatibility; not used in the body.
        nz: Number of channels in the latent noise tensor.
        ngf: Base width of the generator feature maps.
        nc: Number of channels in the generated image.
    """

    def __init__(self, ngpu, nz, ngf, nc):
        super().__init__()
        widths = [ngf * 8, ngf * 4, ngf * 2, ngf]

        # Project the latent vector to a 4x4 map.
        layers = [
            nn.ConvTranspose2d(nz, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Each following stage doubles the resolution and halves the width.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Final upsampling to image channels.
        layers += [
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map the latent tensor ``input`` to generated images."""
        generated = self.main(input)
        return generated


In [None]:
class Critic(nn.Module):
    """WGAN critic: a strided-convolution stack with an unbounded score output.

    Three stride-2 convolutions reduce the resolution by 8x before a final
    4x4 convolution. For 32x32 inputs that last layer yields a single value
    per sample, but for 64x64 inputs it yields a 5x5 map, so the map is
    averaged per sample to guarantee exactly one score per input.

    Args:
        ngpu: Accepted for interface compatibility; not used in the body.
        ndf: Base width of the critic feature maps.
        nc: Number of channels in the input image.
    """

    def __init__(self, ngpu, ndf, nc):
        super(Critic, self).__init__()
        # NOTE(review): BatchNorm couples the samples of a batch, whereas the
        # WGAN-GP penalty is defined per sample; consider a per-sample norm.
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False)
        )

    def forward(self, input):
        """Return a 1-D tensor holding one critic score per input sample."""
        score_map = self.main(input)
        # Flatten whatever spatial extent remains and average it per sample,
        # so the result always has shape (batch,).
        return score_map.view(score_map.size(0), -1).mean(dim=1)


In [None]:
def gradient_penalty(critic, real_data, fake_data, device):
    """Compute the WGAN-GP gradient penalty term.

    Real and fake samples are mixed with a random per-sample weight, the
    critic is evaluated on the mixture, and the squared deviation of each
    sample's input-gradient L2 norm from 1 is averaged over the batch.

    Args:
        critic: Callable mapping a batch of samples to critic scores.
        real_data: Batch of real samples.
        fake_data: Batch of generated samples, same shape as ``real_data``.
        device: Device on which the mixing weights are created.

    Returns:
        Scalar tensor; differentiable with respect to the critic parameters.
    """
    n_samples = real_data.size(0)
    # One mixing weight per sample, broadcast over channels and pixels.
    mix = torch.rand(n_samples, 1, 1, 1, device=device)
    blended = (mix * real_data + (1 - mix) * fake_data).requires_grad_(True)

    scores = critic(blended)

    # create_graph=True keeps the gradient itself differentiable, so the
    # penalty can be backpropagated into the critic.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores, device=device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    per_sample_norm = grads.view(n_samples, -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()


In [None]:
# Model hyperparameters.
nz, ngf, ndf = 100, 64, 64  # latent size, generator width, critic width
ngpu = 1                    # number of GPUs

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Generator and critic for 3-channel images.
netG = Generator(ngpu, nz, ngf, 3).to(device)
netC = Critic(ngpu, ndf, 3).to(device)

# Both networks use Adam with lr=1e-4 and betas=(0.0, 0.9).
optimizerG = optim.Adam(params=netG.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizerC = optim.Adam(params=netC.parameters(), lr=0.0001, betas=(0.0, 0.9))


In [None]:
n_critic = 5  # Update critic 5 times per generator update
lambda_gp = 10  # Gradient penalty weight
num_epochs = 25  # Number of passes over the dataset (used by the loop and the log line)
# Build the dataloader and get the number of image channels
dataloader, nc = load_data('cifar10')

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        # ---- Train Critic ----
        netC.zero_grad()
        real_data = data[0].to(device)

        batch_size = real_data.size(0)
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        # The critic step never backpropagates into the generator, so skip
        # building the generator's autograd graph altogether.
        with torch.no_grad():
            fake_data = netG(noise)

        real_score = netC(real_data).mean()
        fake_score = netC(fake_data).mean()

        # Compute gradient penalty
        gp = gradient_penalty(netC, real_data, fake_data, device)

        # Critic loss: score fakes low, reals high, and keep gradient norms near 1
        lossC = fake_score - real_score + lambda_gp * gp
        lossC.backward()
        optimizerC.step()

        # ---- Train Generator every n_critic iterations ----
        # i == 0 always triggers an update, so lossG is defined before the
        # first log line below.
        if i % n_critic == 0:
            netG.zero_grad()
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake_data = netG(noise)
            lossG = -netC(fake_data).mean()
            lossG.backward()
            optimizerG.step()

        # Loss_G shown here is the value from the most recent generator update.
        if i % 100 == 0:
            print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] '
                  f'Loss_C: {lossC.item():.4f} Loss_G: {lossG.item():.4f}')


Files already downloaded and verified
[0/25][0/782] Loss_C: 206.8642 Loss_G: 0.3810
[0/25][100/782] Loss_C: -1.1295 Loss_G: 1.2748
[0/25][200/782] Loss_C: -0.7499 Loss_G: 0.7439
[0/25][300/782] Loss_C: -0.3188 Loss_G: 0.2283
[0/25][400/782] Loss_C: -0.3124 Loss_G: 0.6594
[0/25][500/782] Loss_C: -0.3546 Loss_G: 1.0025
[0/25][600/782] Loss_C: -0.2581 Loss_G: 1.0145
[0/25][700/782] Loss_C: -0.2458 Loss_G: 1.1796
[1/25][0/782] Loss_C: -0.2467 Loss_G: 1.2689
[1/25][100/782] Loss_C: -0.2634 Loss_G: 1.3794
[1/25][200/782] Loss_C: -0.3040 Loss_G: 1.4249
[1/25][300/782] Loss_C: -0.2925 Loss_G: 1.4357
[1/25][400/782] Loss_C: -0.3044 Loss_G: 1.4063
[1/25][500/782] Loss_C: -0.3014 Loss_G: 1.4268
[1/25][600/782] Loss_C: -0.2624 Loss_G: 1.3171
[1/25][700/782] Loss_C: 0.1467 Loss_G: 1.2153
[2/25][0/782] Loss_C: -0.2333 Loss_G: 1.1716
[2/25][100/782] Loss_C: -0.2600 Loss_G: 1.2413
[2/25][200/782] Loss_C: -0.2287 Loss_G: 1.0703
[2/25][300/782] Loss_C: -0.2103 Loss_G: 0.9555
[2/25][400/782] Loss_C: -0.2