Reference implementation: https://github.com/s-chh/Pytorch-cGAN-conditional-GAN/blob/main/cGAN.py


In [None]:
from torch import optim
import os
import torchvision.utils as vutils
import numpy as np
from torchvision import datasets
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import SVHN 
import shutil

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard log directory. It is wiped at start-up so every run begins with
# empty curves instead of appending to the previous run's event files.
GAN_LOGS = os.path.join(os.getcwd(), "tboard_logs", "gan")
if os.path.exists(GAN_LOGS):
    shutil.rmtree(GAN_LOGS)
# After the (optional) removal above the directory never exists here, so it is
# always (re)created.
os.makedirs(GAN_LOGS, exist_ok=True)

writer = SummaryWriter(GAN_LOGS)

In [None]:
# Arguments (hyper-parameters shared by the cells below)
BATCH_SIZE = 256                 # images per training batch
Z_DIM = 10                       # length of the generator's noise vector
LABEL_EMBED_SIZE = 5             # width of the generator's label embedding
NUM_CLASSES = 10                 # number of label classes
IMGS_TO_DISPLAY_PER_CLASS = 20   # samples per class in the visualisation grid
LOAD_MODEL = False               # restore gen.pkl / dis.pkl before training

CHANNELS = 3                     # image channels (RGB)
EPOCHS = 10                      # passes over the training set


In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Data loader
# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to
# [-1, 1], the same range as the generator's tanh output. The single-element
# mean/std broadcasts over all channels.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# download=False: the SVHN training files must already be present under data/train.
dataset = SVHN(root='data/train', split='train', transform=transform, download=False)

# drop_last=True guarantees every batch holds exactly BATCH_SIZE images.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)

In [None]:
# Output folders: generated sample grids and (optional) model checkpoints.
# Both are relative to the current working directory.
samples_path = 'samples'
model_path = 'models'
os.makedirs(samples_path, exist_ok=True)
os.makedirs(model_path, exist_ok=True)

In [None]:
def generate_imgs(generator, z, fixed_label, epoch=0):
    """Generate a grid of class-conditioned samples, log it and save it to disk.

    Args:
        generator: conditional generator, called as ``generator(z, fixed_label)``.
        z: batch of latent noise vectors fed to the generator.
        fixed_label: integer class labels, one per row of ``z``.
        epoch: TensorBoard step for the image and suffix of the PNG filename.

    The grid is written to TensorBoard under the tag ``'images'`` and saved as
    ``sample_<epoch>.png`` inside ``samples_path``. The generator's previous
    train/eval mode is restored before returning.
    """
    was_training = generator.training
    generator.eval()
    try:
        # Inference only: skip building an autograd graph to save memory.
        with torch.no_grad():
            fake_imgs = generator(z, fixed_label)
    finally:
        generator.train(was_training)

    # The generator ends in tanh, so outputs lie in [-1, 1]; shift to [0, 1].
    fake_imgs = (fake_imgs + 1) / 2
    # nrow = samples per class: when labels come in runs of that length,
    # each grid row shows a single class.
    grid = vutils.make_grid(fake_imgs, normalize=False, nrow=IMGS_TO_DISPLAY_PER_CLASS)
    writer.add_image('images', grid, global_step=epoch)
    vutils.save_image(grid, os.path.join(samples_path, f'sample_{epoch}.png'))

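#### Implement a conditional DCGAN model (label conditioning as in cGAN, https://arxiv.org/abs/1411.1784; convolutional architecture following DCGAN, https://arxiv.org/abs/1511.06434)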

In [None]:
# Networks
def conv_block(c_in, c_out, k_size=4, stride=2, pad=1, use_bn=True, transpose=False):
    """Return a (transposed) convolution, optionally followed by BatchNorm2d.

    Args:
        c_in, c_out: input / output channel counts.
        k_size, stride, pad: convolution kernel size, stride and padding.
        use_bn: append ``nn.BatchNorm2d(c_out)``; the convolution then has no bias.
        transpose: use ``nn.ConvTranspose2d`` (upsampling) instead of ``nn.Conv2d``.

    Returns:
        ``nn.Sequential`` holding the convolution and, if requested, the BatchNorm.
    """
    conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
    # A conv bias is redundant in front of BatchNorm, which re-centres every channel.
    layers = [conv_cls(c_in, c_out, k_size, stride, pad, bias=not use_bn)]
    if use_bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)


class Generator(nn.Module):
    """Conditional generator mapping (noise, class label) to an image in [-1, 1].

    The label is embedded with ``nn.Embedding`` and concatenated to the noise
    along the channel axis. Four transposed-convolution blocks then upsample the
    1x1 input to 4x4, 8x8, 16x16 and finally 32x32 with ``channels`` channels.
    """

    def __init__(self, z_dim=10, num_classes=10, label_embed_size=5, channels=3, conv_dim=64):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, label_embed_size)
        self.tconv1 = conv_block(z_dim + label_embed_size, conv_dim * 4, pad=0, transpose=True)
        self.tconv2 = conv_block(conv_dim * 4, conv_dim * 2, transpose=True)
        self.tconv3 = conv_block(conv_dim * 2, conv_dim, transpose=True)
        self.tconv4 = conv_block(conv_dim, channels, transpose=True, use_bn=False)
        self._init_weights()

    def _init_weights(self):
        """Draw conv weights from N(0, 0.02) and set BatchNorm to weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(module.weight, 0.0, 0.02)
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x, label):
        # Treat the noise vector and the label embedding as 1x1 feature maps.
        noise = x.reshape(x.shape[0], -1, 1, 1)
        embed = self.label_embedding(label)
        embed = embed.reshape(embed.shape[0], -1, 1, 1)
        out = torch.cat((noise, embed), dim=1)
        for block in (self.tconv1, self.tconv2, self.tconv3):
            out = F.relu(block(out))
        return torch.tanh(self.tconv4(out))


class Discriminator(nn.Module):
    """Conditional discriminator mapping (image, class label) to P(pair is real).

    The label is embedded into an ``image_size x image_size`` map and stacked
    onto the image as an extra input channel. Three stride-2 conv blocks reduce
    32x32 -> 16x16 -> 8x8 -> 4x4, and a final 4x4 valid convolution yields one
    score per sample. The score is passed through a sigmoid, matching ``nn.BCELoss``.
    """

    def __init__(self, num_classes=10, channels=3, conv_dim=64):
        super(Discriminator, self).__init__()
        # Fixed: the conv stack below only reduces a 32x32 input to 1x1.
        self.image_size = 32
        self.label_embedding = nn.Embedding(num_classes, self.image_size * self.image_size)
        self.conv1 = conv_block(channels + 1, conv_dim, use_bn=False)
        self.conv2 = conv_block(conv_dim, conv_dim * 2)
        self.conv3 = conv_block(conv_dim * 2, conv_dim * 4)
        self.conv4 = conv_block(conv_dim * 4, 1, k_size=4, stride=1, pad=0, use_bn=False)

        # Conv weights ~ N(0, 0.02); BatchNorm starts as identity (weight 1, bias 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0.0, 0.02)

            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, label):
        """Return a ``(batch,)`` tensor of real/fake probabilities in (0, 1)."""
        alpha = 0.2  # LeakyReLU negative slope
        label_embed = self.label_embedding(label)
        label_embed = label_embed.reshape([label_embed.shape[0], 1, self.image_size, self.image_size])
        x = torch.cat((x, label_embed), dim=1)
        x = F.leaky_relu(self.conv1(x), alpha)
        x = F.leaky_relu(self.conv2(x), alpha)
        x = F.leaky_relu(self.conv3(x), alpha)
        x = torch.sigmoid(self.conv4(x))
        # (B, 1, 1, 1) -> (B,). Unlike squeeze(), this keeps the batch dimension
        # when B == 1, so the output always matches a (B,) BCELoss target.
        return x.reshape(x.shape[0])


In [None]:
# Build the conditional generator / discriminator pair from the module-level
# hyper-parameters (generator first, then discriminator).
gen = Generator(
    channels=CHANNELS,
    label_embed_size=LABEL_EMBED_SIZE,
    num_classes=NUM_CLASSES,
    z_dim=Z_DIM,
)
dis = Discriminator(channels=CHANNELS, num_classes=NUM_CLASSES)

In [None]:
# Load previous model
if LOAD_MODEL:
    # The models are still on the CPU here; map_location lets checkpoints that
    # were saved from a GPU be restored on a CPU-only machine as well.
    # NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
    gen.load_state_dict(torch.load(os.path.join(model_path, 'gen.pkl'), map_location='cpu'))
    dis.load_state_dict(torch.load(os.path.join(model_path, 'dis.pkl'), map_location='cpu'))

# GPU Compatibility: move the networks before constructing their optimizers,
# as recommended by the torch.optim documentation.
gen, dis = gen.to(device), dis.to(device)

# Define Optimizers
g_opt = optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=2e-5)
d_opt = optim.Adam(dis.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=2e-5)

# Loss functions
loss_fn = nn.BCELoss()

# Fixed noise and labels for visualisation: IMGS_TO_DISPLAY_PER_CLASS samples
# per class, labels grouped as [0, 0, ..., 1, 1, ...] so each grid row is one class.
fixed_z = torch.randn(IMGS_TO_DISPLAY_PER_CLASS * NUM_CLASSES, Z_DIM).to(device)
fixed_label = torch.repeat_interleave(torch.arange(0, NUM_CLASSES), IMGS_TO_DISPLAY_PER_CLASS).to(device)

# BCE targets (1 = real, 0 = fake), sized for a full batch.
real_label = torch.ones(BATCH_SIZE, device=device)
fake_label = torch.zeros(BATCH_SIZE, device=device)

max_iter = len(data_loader)

# Training


In [None]:
def _log_losses(step, d_loss, d_loss_real, d_loss_fake, g_loss):
    """Write one iteration's GAN losses (plain floats) to TensorBoard at ``step``."""
    writer.add_scalar('Loss/Generator Loss', g_loss, global_step=step)
    # Total discriminator loss: the mean of its real-image and fake-image terms.
    writer.add_scalar('Loss/Discriminator Loss', d_loss, global_step=step)
    writer.add_scalars('Loss/Discriminator Losses', {
        "Real Images Loss": d_loss_real,
        "Fake Images Loss": d_loss_fake,
    }, global_step=step)
    writer.add_scalars('Comb_Loss/Losses', {
        'Discriminator': d_loss,
        'Generator': g_loss,
    }, step)


def train(EPOCHS, data_loader, generator, dis, *, print_every=50, sample_every=5):
    """Train the conditional GAN with alternating discriminator/generator steps.

    Args:
        EPOCHS: number of passes over ``data_loader``.
        data_loader: iterable of ``(images, labels)`` batches with ``len()``.
        generator: conditional generator, called as ``generator(z, labels)``.
        dis: conditional discriminator, called as ``dis(images, labels)``.
        print_every: print the losses every this many iterations of an epoch.
        sample_every: call ``generate_imgs`` every this many epochs.

    Relies on the module-level ``g_opt``, ``d_opt``, ``loss_fn``, ``writer``,
    ``device``, ``Z_DIM``, ``fixed_z`` and ``fixed_label``.
    """
    iters_per_epoch = len(data_loader)
    global_step = 0

    for epoch in range(EPOCHS):
        generator.train()
        dis.train()

        for i, (x_real, x_label) in enumerate(data_loader):
            n = x_real.shape[0]
            z_fake = torch.randn(n, Z_DIM).to(device)
            x_real = x_real.to(device)
            x_label = x_label.to(device)
            # BCE targets sized to this batch, so a short final batch also works.
            real_target = torch.ones(n, device=device)
            fake_target = torch.zeros(n, device=device)

            # Generate fake data conditioned on the real batch's labels.
            x_fake = generator(z_fake, x_label)

            # Train Discriminator: fakes -> 0, reals -> 1. detach() stops this
            # loss from back-propagating into the generator.
            d_loss_fake = loss_fn(dis(x_fake.detach(), x_label), fake_target)
            d_loss_real = loss_fn(dis(x_real, x_label), real_target)
            d_loss = (d_loss_fake + d_loss_real) / 2

            d_opt.zero_grad()
            d_loss.backward()
            d_opt.step()

            # Train Generator: make the freshly updated discriminator call fakes real.
            g_loss = loss_fn(dis(x_fake, x_label), real_target)

            g_opt.zero_grad()
            g_loss.backward()
            g_opt.step()

            d_val, g_val = d_loss.item(), g_loss.item()
            _log_losses(global_step, d_val, d_loss_real.item(), d_loss_fake.item(), g_val)
            global_step += 1

            if i % print_every == 0:
                print(f"Epoch: {epoch + 1}/{EPOCHS}"
                      f"\titer: {i}/{iters_per_epoch}"
                      f"\ttotal_iters: {global_step}"
                      f"\td_loss:{round(d_val, 4)}"
                      f"\tg_loss:{round(g_val, 4)}")

        if (epoch + 1) % sample_every == 0:
            # Checkpointing is disabled; uncomment to persist weights for LOAD_MODEL:
            # torch.save(generator.state_dict(), os.path.join(model_path, 'gen.pkl'))
            # torch.save(dis.state_dict(), os.path.join(model_path, 'dis.pkl'))
            generate_imgs(generator, fixed_z, fixed_label, epoch)



#### Train the model for conditional generation on the SVHN dataset

In [None]:
# Run the full training loop on the SVHN loader with the models built above.
train(EPOCHS, data_loader, gen, dis)

#### Show the model's ability to generate samples conditioned on a given class label

In [None]:

# Final conditional samples after training. Use EPOCHS as the step so the image
# follows the in-training snapshots in TensorBoard instead of overwriting sample_1.png.
generate_imgs(gen, fixed_z, fixed_label, epoch=EPOCHS)

In [1]:
%load_ext tensorboard
%tensorboard --logdir=tboard_logs/gan