In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import os
from torch.utils.tensorboard import SummaryWriter

In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

class Discriminator(nn.Module):
    """Conditional convolutional discriminator (critic-style, outputs logits).

    The class label is embedded into a vector of ``img_size * img_size``
    values, reshaped into a single extra image plane and concatenated to the
    input image along the channel axis before the convolutional stack.

    Args:
        channels_img: Number of channels of the input images.
        features_d: Base width of the convolutional stack.
        img_size: Height/width of the (square) input images.
        num_classes: Number of distinct labels the embedding supports.
    """

    def __init__(self, channels_img, features_d, img_size, num_classes):
        super().__init__()
        self.img_size = img_size
        # One learnable "label image" per class.
        self.embed = nn.Embedding(num_classes, img_size * img_size)

        layers = [
            # +1 input channel for the label plane appended in forward().
            nn.Conv2d(
                channels_img + 1,
                features_d,
                kernel_size=4,
                stride=2,
                padding=1,
            ),
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            # No activation here: the network returns raw scores.
            # With img_size=32 the three stride-2 convs give a 4x4 map, which
            # this 4x4 un-padded conv collapses to 1x1.
            nn.Conv2d(features_d * 4, 1, kernel_size=4, stride=1, padding=0),
        ]
        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a Conv2d -> InstanceNorm2d -> LeakyReLU(0.2) stage."""
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,  # the affine InstanceNorm that follows has its own bias
        )
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

    def forward(self, x, labels):
        """Score images ``x`` conditioned on integer class ``labels``."""
        side = self.img_size
        label_plane = self.embed(labels).view(labels.shape[0], 1, side, side)
        conditioned = torch.cat([x, label_plane], dim=1)
        return self.disc(conditioned)

class Generator(nn.Module):
    """Conditional transposed-convolution generator.

    The class label is embedded into a vector with the same length as the
    noise vector; both are reshaped to ``(N, channels_noise, 1, 1)`` and
    concatenated along the channel axis, then upsampled by the
    transposed-convolution stack. The final ``Tanh`` keeps outputs in
    ``[-1, 1]``.

    Args:
        channels_noise: Length of the latent noise vector.
        channels_img: Number of channels of the generated images.
        features_g: Base width of the transposed-convolution stack.
        num_classes: Number of distinct labels the embedding supports.
        img_size: Stored on the instance; the layer stack itself does not
            read it (the stack maps 1x1 -> 4 -> 8 -> 16 -> 32).
    """

    def __init__(self, channels_noise, channels_img, features_g, num_classes, img_size):
        super().__init__()
        self.img_size = img_size
        self.noise_dim = channels_noise
        self.embed = nn.Embedding(num_classes, channels_noise)

        stages = [
            # Input has 2 * channels_noise channels: noise + label embedding.
            self._block(channels_noise * 2, features_g * 8, 4, 1, 0),
            self._block(features_g * 8, features_g * 4, 4, 2, 1),
            self._block(features_g * 4, features_g * 2, 4, 2, 1),
            nn.ConvTranspose2d(
                features_g * 2,
                channels_img,
                kernel_size=4,
                stride=2,
                padding=1,
            ),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*stages)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
        upconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,  # BatchNorm right after provides the shift term
        )
        return nn.Sequential(upconv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, noise, labels):
        """Generate images from ``noise`` conditioned on integer ``labels``."""
        n = noise.shape[0]
        column_shape = (n, self.noise_dim, 1, 1)

        z = noise.view(column_shape)
        label_vec = self.embed(labels).view(column_shape)

        return self.net(torch.cat([z, label_vec], dim=1))


# TensorBoard output location. Two writers are used so that generator-side
# and real/discriminator-side data are stored under separate sub-directories.
log_dir = "./logs"
# exist_ok=True replaces the exists()-then-makedirs() pair, which could fail
# if the directory is created between the check and the call.
os.makedirs(log_dir, exist_ok=True)
writer_fake = SummaryWriter(os.path.join(log_dir, "fake"))
writer_real = SummaryWriter(os.path.join(log_dir, "real"))
# Global counter used as the x-axis for scalar logging; train_cgan advances it
# each time it logs the losses.
step = 0

def train_cgan(dataset_name="MNIST", num_epochs=50, batch_size=128, lr=2e-4, img_size=32, num_classes=10):
    """Train the conditional GAN and save the weights of both networks.

    Args:
        dataset_name: ``"MNIST"`` (1-channel) or ``"CIFAR10"`` (3-channel).
        num_epochs: Number of passes over the training split.
        batch_size: Mini-batch size given to the ``DataLoader``.
        lr: Learning rate used by both Adam optimizers.
        img_size: Images are resized to ``img_size x img_size``.
        num_classes: Number of label classes for both embeddings.

    Raises:
        ValueError: If ``dataset_name`` is neither ``"MNIST"`` nor ``"CIFAR10"``.

    Side effects:
        Downloads the dataset to ``./data``, prints progress, logs scalars and
        image grids through the module-level ``writer_fake`` / ``writer_real``,
        advances the module-level ``step`` counter, and writes
        ``generator.pth`` and ``discriminator.pth`` to the working directory.
    """
    global step

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    is_mnist = dataset_name == "MNIST"

    # Per-channel normalisation to [-1, 1], matching the generator's Tanh range.
    if is_mnist:
        normalize = transforms.Normalize((0.5,), (0.5,))
    else:
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ])

    if dataset_name not in ("MNIST", "CIFAR10"):
        raise ValueError("Dataset not supported.")
    dataset_cls = torchvision.datasets.MNIST if is_mnist else torchvision.datasets.CIFAR10
    dataset = dataset_cls(root="./data", train=True, transform=transform, download=True)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    channels_img = 1 if is_mnist else 3
    channels_noise = 100
    features_g = 64
    features_d = 64

    # Generator is built before the discriminator (keeps RNG-driven init order).
    gen = Generator(channels_noise, channels_img, features_g, num_classes, img_size).to(device)
    disc = Discriminator(channels_img, features_d, img_size, num_classes).to(device)

    adam_betas = (0.5, 0.999)
    opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=adam_betas)
    opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=adam_betas)

    # The discriminator returns raw logits, hence the with-logits BCE.
    criterion = nn.BCEWithLogitsLoss()

    # Fixed latent/label pairs so the per-epoch sample grids are comparable.
    fixed_noise = torch.randn(64, channels_noise).to(device)
    fixed_labels = torch.randint(0, num_classes, (64,)).to(device)

    for epoch in range(num_epochs):
        for batch_idx, (real, labels) in enumerate(dataloader):
            real = real.to(device)
            labels = labels.to(device)
            # The last batch of an epoch may be smaller than `batch_size`.
            n_samples = real.shape[0]

            z = torch.randn(n_samples, channels_noise).to(device)
            fake = gen(z, labels)

            # ---- Discriminator update: real -> 1, fake -> 0 ----
            disc_real = disc(real, labels).reshape(-1)
            loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))

            # detach(): this step must not back-propagate into the generator.
            disc_fake = disc(fake.detach(), labels).reshape(-1)
            loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

            loss_disc = (loss_disc_real + loss_disc_fake) / 2
            disc.zero_grad()
            loss_disc.backward()
            opt_disc.step()

            # ---- Generator update: make the discriminator output 1 on fakes ----
            output = disc(fake, labels).reshape(-1)
            loss_gen = criterion(output, torch.ones_like(output))
            gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

            # Only report every 100th batch.
            if batch_idx % 100 != 0:
                continue

            print(f"Epoch [{epoch+1}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} "
                  f"Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")

            writer_fake.add_scalar("Loss/Generator", loss_gen.item(), global_step=step)
            writer_real.add_scalar("Loss/Discriminator", loss_disc.item(), global_step=step)
            step += 1

        # Epoch summary uses the losses of the last batch processed.
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")

        with torch.no_grad():
            samples = gen(fixed_noise, fixed_labels)
            img_grid_fake = torchvision.utils.make_grid(samples, normalize=True)
            # Real grid: up to 64 images from the last batch of this epoch.
            img_grid_real = torchvision.utils.make_grid(real[:64], normalize=True)

            writer_fake.add_image(f"{dataset_name} Fake", img_grid_fake, global_step=epoch)
            writer_real.add_image(f"{dataset_name} Real", img_grid_real, global_step=epoch)

    torch.save(gen.state_dict(), "generator.pth")
    torch.save(disc.state_dict(), "discriminator.pth")
    print("Training complete. Models saved.")

# Run training. The writers are closed in `finally` so they are still closed
# when training is interrupted or raises part-way through.
try:
    train_cgan(dataset_name="MNIST", num_epochs=50)
finally:
    writer_fake.close()
    writer_real.close()

Using device: cuda
Epoch [1/50] Batch 0/469 Loss D: 0.7319, Loss G: 0.8717
Epoch [1/50] Batch 100/469 Loss D: 0.8035, Loss G: 0.7991
Epoch [1/50] Batch 200/469 Loss D: 0.3755, Loss G: 2.4367
Epoch [1/50] Batch 300/469 Loss D: 0.3698, Loss G: 1.9489
Epoch [1/50] Batch 400/469 Loss D: 0.3353, Loss G: 2.1972
Epoch [1/50] Loss D: 0.3810, Loss G: 2.5593
Epoch [2/50] Batch 0/469 Loss D: 0.4465, Loss G: 1.9349
Epoch [2/50] Batch 100/469 Loss D: 0.3155, Loss G: 2.1350
Epoch [2/50] Batch 200/469 Loss D: 0.2694, Loss G: 2.6079
Epoch [2/50] Batch 300/469 Loss D: 0.2323, Loss G: 3.1206
Epoch [2/50] Batch 400/469 Loss D: 0.2958, Loss G: 3.0966
Epoch [2/50] Loss D: 0.2581, Loss G: 2.8493
Epoch [3/50] Batch 0/469 Loss D: 0.2479, Loss G: 2.1854
Epoch [3/50] Batch 100/469 Loss D: 0.3388, Loss G: 2.7881
Epoch [3/50] Batch 200/469 Loss D: 0.6343, Loss G: 1.7116
Epoch [3/50] Batch 300/469 Loss D: 0.3305, Loss G: 1.9741
Epoch [3/50] Batch 400/469 Loss D: 0.2134, Loss G: 2.7528
Epoch [3/50] Loss D: 0.1832, 