In [13]:
import os
import numpy as np
import math
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
import torch
img_save_path = './conditionalGAN/images/'
os.makedirs(img_save_path, exist_ok=True)

# ---- Hyper-parameters ----
n_epochs = 200  # number of epochs of training
batch_size = 64  # size of the batches
lr = 0.0002  # learning rate
beta1 = 0.5  # Adam: decay rate of the first-order moment (momentum) of the gradient
beta2 = 0.999  # Adam: decay rate of the second-order moment of the gradient
latent_dim = 100  # dim of the latent space
n_classes = 10  # number of classes for dataset
image_size = 28  # size of each image dimension
channels = 1  # number of image channels
sample_interval = 200  # number of batches between progress log lines (images are saved once per epoch)
seed = 0  # RNG seed so that Restart & Run All reproduces the same run
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the RNGs used by the notebook (torch.manual_seed also seeds CUDA devices).
torch.manual_seed(seed)
np.random.seed(seed)


def weights_init_normal(m):
    """Re-initialise a single module in place; meant for ``model.apply(weights_init_normal)``.

    * Modules whose class name contains ``'Conv'``: weight ~ N(0, 0.02).
    * Modules whose class name contains ``'BatchNorm2d'``: weight ~ N(1, 0.02), bias = 0.
    * Every other module (e.g. ``Linear``, ``BatchNorm1d``) is left untouched.

    Args:
        m: the ``nn.Module`` being visited.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        # In-place ``normal_`` replaces the deprecated ``torch.nn.init.normal`` alias.
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        # BatchNorm layers built with affine=False have weight/bias set to None.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.0)
        
def idx2onehot(idx, n):
    """Convert a 1-D tensor of class indices into a float one-hot matrix.

    Args:
        idx: 1-D int64 tensor of class indices, each in ``[0, n)`` (``scatter_``
            requires an int64 index tensor).
        n: number of classes, i.e. the width of each one-hot row.

    Returns:
        Float tensor of shape ``(len(idx), n)`` on the same device as ``idx``,
        holding a single 1 in row ``i`` at column ``idx[i]``.

    Raises:
        AssertionError: if ``idx`` is not 1-D or holds an index outside ``[0, n)``.
    """
    # Check rank first; the range check is only meaningful for a 1-D tensor.
    assert idx.dim() == 1, 'idx must be a 1-D tensor'
    # min()/max() raise on an empty tensor, so only range-check when there is data.
    if idx.numel() > 0:
        assert 0 <= idx.min().item() and idx.max().item() < n, 'class index out of range [0, n)'
    # Allocate on idx's device so scatter_ also works for CUDA index tensors.
    onehot = torch.zeros(idx.size(0), n, device=idx.device)
    onehot.scatter_(1, idx.view(-1, 1), 1)
    return onehot


    
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, one-hot label) to a flattened image.

    The noise and the label are each projected to 256 features
    (Linear -> BatchNorm1d -> ReLU), concatenated to 512 features, passed through
    Linear/BatchNorm1d/ReLU stages of 512 and 1024 units, and finally through a
    Linear layer squashed by ``tanh`` so every output value lies in [-1, 1].

    Args:
        z_dim: length of the noise vector; defaults to the notebook's ``latent_dim``.
        num_classes: width of the one-hot label vector; defaults to ``n_classes``.
        out_dim: number of output values per sample; defaults to
            ``channels * image_size ** 2`` (one value per pixel and channel).
    """

    # initializers
    def __init__(self, z_dim=None, num_classes=None, out_dim=None):
        super().__init__()
        # Fall back to the notebook config so ``Generator()`` behaves as before.
        if z_dim is None:
            z_dim = latent_dim
        if num_classes is None:
            # Must match the width produced by idx2onehot(..., n_classes).
            num_classes = n_classes
        if out_dim is None:
            out_dim = channels * image_size ** 2
        # Layer attribute names are unchanged so existing state_dict checkpoints still load.
        self.fc1_1 = nn.Linear(z_dim, 256)
        self.fc1_1_bn = nn.BatchNorm1d(256)
        self.fc1_2 = nn.Linear(num_classes, 256)
        self.fc1_2_bn = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(512, 512)  # 512 = 256 (noise branch) + 256 (label branch)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc3_bn = nn.BatchNorm1d(1024)
        self.fc4 = nn.Linear(1024, out_dim)

    # forward method
    def forward(self, z, y):
        """Generate a batch of flattened images.

        Args:
            z: noise batch of shape ``(B, z_dim)``.
            y: one-hot labels of shape ``(B, num_classes)``.

        Returns:
            Tensor of shape ``(B, out_dim)`` with values in [-1, 1].
        """
        x1 = F.relu(self.fc1_1_bn(self.fc1_1(z)))  # noise branch -> (B, 256)
        x2 = F.relu(self.fc1_2_bn(self.fc1_2(y)))  # label branch -> (B, 256)
        x = torch.cat([x1, x2], dim=1)  # (B, 512)
        x = F.relu(self.fc2_bn(self.fc2(x)))
        x = F.relu(self.fc3_bn(self.fc3(x)))
        return torch.tanh(self.fc4(x))
    
class Discriminator(nn.Module):
    """Conditional MLP discriminator scoring (flattened image, one-hot label) pairs.

    The image and the label are each projected to 1024 features with
    LeakyReLU(0.2), concatenated, reduced through 512- and 256-unit
    Linear/BatchNorm1d/LeakyReLU stages, and mapped to one sigmoid output: the
    estimated probability that the pair is real (the training loop uses target 1
    for real images and 0 for generated ones).

    Args:
        num_classes: width of the one-hot label vector; defaults to ``n_classes``.
        in_dim: length of a flattened input image; defaults to
            ``channels * image_size ** 2``.
    """

    # initializers
    def __init__(self, num_classes=None, in_dim=None):
        super().__init__()
        # Fall back to the notebook config so ``Discriminator()`` behaves as before.
        if num_classes is None:
            num_classes = n_classes
        if in_dim is None:
            in_dim = channels * image_size ** 2
        # Layer attribute names are unchanged so existing state_dict checkpoints still load.
        self.fc1_1 = nn.Linear(in_dim, 1024)
        self.fc1_2 = nn.Linear(num_classes, 1024)
        self.fc2 = nn.Linear(2048, 512)  # 2048 = 1024 (image branch) + 1024 (label branch)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 256)
        self.fc3_bn = nn.BatchNorm1d(256)
        self.fc4 = nn.Linear(256, 1)

    # forward method
    def forward(self, input, y):
        """Score a batch of image/label pairs.

        Args:
            input: flattened images of shape ``(B, in_dim)``.
            y: one-hot labels of shape ``(B, num_classes)``.

        Returns:
            Tensor of shape ``(B, 1)`` with values in [0, 1].
        """
        x1 = F.leaky_relu(self.fc1_1(input), 0.2)  # image branch -> (B, 1024)
        x2 = F.leaky_relu(self.fc1_2(y), 0.2)  # label branch -> (B, 1024)
        x = torch.cat([x1, x2], dim=1)  # (B, 2048)
        x = F.leaky_relu(self.fc2_bn(self.fc2(x)), 0.2)
        x = F.leaky_relu(self.fc3_bn(self.fc3(x)), 0.2)
        return torch.sigmoid(self.fc4(x))
        
# Binary cross-entropy criterion shared by the discriminator and generator updates
# (both networks' scores come out of a sigmoid).
loss = nn.BCELoss()

# Build both networks and move them to the selected device (Module.to returns the module itself).
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Optional custom initialisation, currently disabled. Note that weights_init_normal only
# matches 'Conv*' and 'BatchNorm2d' layers, so on these Linear/BatchNorm1d models it
# would leave every parameter unchanged.
#generator.apply(weights_init_normal)
#discriminator.apply(weights_init_normal)

# Where MNIST is downloaded / cached.
data_save_path = './conditionalGAN/data/'
os.makedirs(data_save_path, exist_ok=True)

# Resize -> tensor scaled to [0, 1] -> normalised to [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * channels, [0.5] * channels),
    ]
)
dataset = datasets.MNIST(root=data_save_path, train=True, download=True, transform=transform)

# drop_last=True discards the final partial batch, so every batch has exactly batch_size samples.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
print("the data is OK")


the data is OK


In [14]:
# Optimizers: one Adam instance per network, both using the configured lr and betas.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))


def reset_grad():
    """Clear the accumulated gradients held by both optimizers (discriminator first, then generator)."""
    for opt in (optimizer_D, optimizer_G):
        opt.zero_grad()

In [15]:
def sample_image(n_row, n_col, epoch):
    """Save an ``n_col`` x ``n_row`` grid of generated digits to ``<img_save_path>/<epoch>.png``.

    Each grid row holds ``n_row`` images conditioned on classes ``0 .. n_row-1``
    (wrapped modulo ``n_classes``), so every column shows a single class.

    Args:
        n_row: images per grid row, which is also the number of distinct labels per row.
        n_col: number of grid rows.
        epoch: used as the output file name.
    """
    # Sampling only: no autograd graph is needed.
    with torch.no_grad():
        z = torch.randn(n_row * n_col, latent_dim).to(device)
        # [0, 1, ..., n_row-1] repeated n_col times; the modulo keeps labels valid
        # when n_row > n_classes instead of tripping idx2onehot's range check.
        class_ids = torch.arange(n_row).repeat(n_col) % n_classes
        labels = idx2onehot(class_ids, n_classes).to(device)
        # NOTE(review): nothing in this notebook calls generator.eval(), so BatchNorm
        # here normalises with this sample batch's statistics (and still updates its
        # running stats). Consider eval()/train() around sampling if that matters.
        gen_imgs = generator(z, labels).view(z.size(0), channels, image_size, image_size)
    save_image(gen_imgs, os.path.join(img_save_path, '{}.png'.format(epoch)), nrow=n_row, normalize=True)
    

In [None]:
total_step = len(dataloader)

for epoch in range(n_epochs):
    for i, (images, labels) in enumerate(dataloader):
        # Flatten (B, C, H, W) -> (B, C*H*W) for the MLP discriminator.
        images = images.view(images.size(0), -1).to(device)
        # Size targets and noise from the actual batch rather than the configured
        # batch_size, so a short final batch (drop_last=False) cannot cause a shape mismatch.
        cur_batch = images.size(0)

        # BCE targets: 1 = "real", 0 = "fake"; allocated directly on the device.
        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # Condition on the ground-truth digit classes.
        labels_onehot = idx2onehot(labels, n_classes).to(device)

        # =============================================
        #
        #  Training Discriminator
        #
        # =============================================

        # Real images paired with their true labels should be scored as real.
        outputs = discriminator(images, labels_onehot)
        d_loss_real = loss(outputs, real_labels)
        real_score = outputs  # just for the purpose of tracking training progress

        # Generated images with randomly drawn labels should be scored as fake.
        z = torch.randn(cur_batch, latent_dim, device=device)
        label_z = idx2onehot(torch.randint(n_classes, (cur_batch,)), n_classes).to(device)

        # detach(): the discriminator update must not back-propagate into the generator.
        fake_images = generator(z, label_z).detach()
        outputs = discriminator(fake_images, label_z)
        d_loss_fake = loss(outputs, fake_labels)
        fake_score = outputs  # just for the purpose of tracking training progress

        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()  # reset stored gradients
        d_loss.backward()
        optimizer_D.step()

        # =============================================
        #
        #  Training Generator
        #
        # =============================================

        # Fresh fakes; the generator is rewarded when D scores them as real.
        z = torch.randn(cur_batch, latent_dim, device=device)
        label_z = idx2onehot(torch.randint(n_classes, (cur_batch,)), n_classes).to(device)

        fake_images = generator(z, label_z)
        outputs = discriminator(fake_images, label_z)
        g_loss = loss(outputs, real_labels)
        reset_grad()
        # This backward also fills gradients on D's parameters; only optimizer_G steps,
        # and reset_grad() clears them before the next discriminator update.
        g_loss.backward()
        optimizer_G.step()

        if (i+1) % sample_interval == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss:{:.2f}, D(x):{:.2f}, D(G(z)): {:.2f}'.format(epoch, n_epochs, i+1, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()))

    # Save one grid of generated samples per epoch.
    sample_image(n_row=10, n_col=10, epoch=epoch)

Epoch [0/200], Step [200/937], d_loss: 1.1161, g_loss:0.94, D(x):0.57, D(G(z)): 0.42
Epoch [0/200], Step [400/937], d_loss: 1.1329, g_loss:0.96, D(x):0.56, D(G(z)): 0.42
Epoch [0/200], Step [600/937], d_loss: 1.1653, g_loss:0.91, D(x):0.57, D(G(z)): 0.44
Epoch [0/200], Step [800/937], d_loss: 0.9513, g_loss:0.93, D(x):0.61, D(G(z)): 0.36
Epoch [1/200], Step [200/937], d_loss: 1.0670, g_loss:0.92, D(x):0.57, D(G(z)): 0.38
Epoch [1/200], Step [400/937], d_loss: 1.1945, g_loss:0.74, D(x):0.54, D(G(z)): 0.42
Epoch [1/200], Step [600/937], d_loss: 1.0021, g_loss:0.98, D(x):0.63, D(G(z)): 0.40
Epoch [1/200], Step [800/937], d_loss: 1.2912, g_loss:1.10, D(x):0.49, D(G(z)): 0.40
Epoch [2/200], Step [200/937], d_loss: 1.3802, g_loss:0.97, D(x):0.52, D(G(z)): 0.49
Epoch [2/200], Step [400/937], d_loss: 1.0948, g_loss:1.23, D(x):0.60, D(G(z)): 0.42
Epoch [2/200], Step [600/937], d_loss: 1.2602, g_loss:0.82, D(x):0.53, D(G(z)): 0.45
Epoch [2/200], Step [800/937], d_loss: 1.0631, g_loss:1.08, D(x):

Epoch [24/200], Step [200/937], d_loss: 0.6601, g_loss:2.08, D(x):0.79, D(G(z)): 0.32
Epoch [24/200], Step [400/937], d_loss: 0.2969, g_loss:1.69, D(x):0.93, D(G(z)): 0.19
Epoch [24/200], Step [600/937], d_loss: 0.9531, g_loss:1.61, D(x):0.77, D(G(z)): 0.45
Epoch [24/200], Step [800/937], d_loss: 0.4717, g_loss:1.74, D(x):0.71, D(G(z)): 0.07
Epoch [25/200], Step [200/937], d_loss: 0.6275, g_loss:3.00, D(x):0.90, D(G(z)): 0.35
Epoch [25/200], Step [400/937], d_loss: 0.3126, g_loss:2.68, D(x):0.81, D(G(z)): 0.09
Epoch [25/200], Step [600/937], d_loss: 0.4172, g_loss:2.50, D(x):0.98, D(G(z)): 0.31
Epoch [25/200], Step [800/937], d_loss: 0.4015, g_loss:2.08, D(x):0.79, D(G(z)): 0.14
Epoch [26/200], Step [200/937], d_loss: 0.6422, g_loss:2.63, D(x):0.71, D(G(z)): 0.19
Epoch [26/200], Step [400/937], d_loss: 1.0365, g_loss:2.76, D(x):0.45, D(G(z)): 0.09
Epoch [26/200], Step [600/937], d_loss: 0.1716, g_loss:3.01, D(x):0.90, D(G(z)): 0.06
Epoch [26/200], Step [800/937], d_loss: 0.4209, g_loss

Epoch [48/200], Step [200/937], d_loss: 0.6749, g_loss:1.51, D(x):0.60, D(G(z)): 0.05
Epoch [48/200], Step [400/937], d_loss: 0.1546, g_loss:3.59, D(x):0.95, D(G(z)): 0.08
Epoch [48/200], Step [600/937], d_loss: 0.4066, g_loss:0.75, D(x):0.71, D(G(z)): 0.03
Epoch [48/200], Step [800/937], d_loss: 0.3501, g_loss:2.79, D(x):0.75, D(G(z)): 0.03
Epoch [49/200], Step [200/937], d_loss: 0.1499, g_loss:3.64, D(x):0.93, D(G(z)): 0.07
Epoch [49/200], Step [400/937], d_loss: 0.1279, g_loss:0.98, D(x):0.96, D(G(z)): 0.08
Epoch [49/200], Step [600/937], d_loss: 0.1978, g_loss:3.25, D(x):1.00, D(G(z)): 0.14
Epoch [49/200], Step [800/937], d_loss: 0.3990, g_loss:3.82, D(x):0.79, D(G(z)): 0.12
Epoch [50/200], Step [200/937], d_loss: 0.0388, g_loss:3.59, D(x):0.98, D(G(z)): 0.02
Epoch [50/200], Step [400/937], d_loss: 0.0726, g_loss:3.89, D(x):0.95, D(G(z)): 0.02
Epoch [50/200], Step [600/937], d_loss: 0.4622, g_loss:1.70, D(x):0.73, D(G(z)): 0.08
Epoch [50/200], Step [800/937], d_loss: 0.3503, g_loss

Epoch [72/200], Step [200/937], d_loss: 1.2948, g_loss:3.75, D(x):0.75, D(G(z)): 0.57
Epoch [72/200], Step [400/937], d_loss: 0.0973, g_loss:1.03, D(x):0.93, D(G(z)): 0.02
Epoch [72/200], Step [600/937], d_loss: 1.6816, g_loss:4.64, D(x):0.25, D(G(z)): 0.01
Epoch [72/200], Step [800/937], d_loss: 0.2190, g_loss:1.55, D(x):0.94, D(G(z)): 0.13
Epoch [73/200], Step [200/937], d_loss: 0.2711, g_loss:2.83, D(x):0.91, D(G(z)): 0.13
Epoch [73/200], Step [400/937], d_loss: 0.3224, g_loss:5.18, D(x):0.79, D(G(z)): 0.06
Epoch [73/200], Step [600/937], d_loss: 0.2728, g_loss:6.02, D(x):0.79, D(G(z)): 0.01
Epoch [73/200], Step [800/937], d_loss: 0.0871, g_loss:2.00, D(x):0.97, D(G(z)): 0.05
Epoch [74/200], Step [200/937], d_loss: 1.6894, g_loss:2.50, D(x):0.26, D(G(z)): 0.03
Epoch [74/200], Step [400/937], d_loss: 0.6392, g_loss:3.39, D(x):0.62, D(G(z)): 0.04
Epoch [74/200], Step [600/937], d_loss: 0.1356, g_loss:4.15, D(x):1.00, D(G(z)): 0.11
Epoch [74/200], Step [800/937], d_loss: 0.1097, g_loss

Epoch [96/200], Step [200/937], d_loss: 0.2418, g_loss:3.54, D(x):0.96, D(G(z)): 0.16
Epoch [96/200], Step [400/937], d_loss: 0.0432, g_loss:6.81, D(x):0.99, D(G(z)): 0.04
Epoch [96/200], Step [600/937], d_loss: 0.3890, g_loss:0.65, D(x):0.75, D(G(z)): 0.04
Epoch [96/200], Step [800/937], d_loss: 0.3334, g_loss:3.72, D(x):0.75, D(G(z)): 0.01
Epoch [97/200], Step [200/937], d_loss: 0.5246, g_loss:3.63, D(x):0.68, D(G(z)): 0.03
Epoch [97/200], Step [400/937], d_loss: 0.0064, g_loss:7.47, D(x):0.99, D(G(z)): 0.00
Epoch [97/200], Step [600/937], d_loss: 0.0725, g_loss:2.39, D(x):0.99, D(G(z)): 0.06
Epoch [97/200], Step [800/937], d_loss: 0.6167, g_loss:1.38, D(x):0.62, D(G(z)): 0.03
Epoch [98/200], Step [200/937], d_loss: 0.2033, g_loss:5.67, D(x):0.96, D(G(z)): 0.13
Epoch [98/200], Step [400/937], d_loss: 0.3782, g_loss:4.89, D(x):0.86, D(G(z)): 0.18
Epoch [98/200], Step [600/937], d_loss: 0.0637, g_loss:4.76, D(x):0.98, D(G(z)): 0.04
Epoch [98/200], Step [800/937], d_loss: 0.1525, g_loss

In [4]:
import torch

# Scratch check: what Tensor rounding returns for a value just above 0.5.
print(torch.round(torch.tensor([0.51])))

tensor([1.])
