In [1]:
import os
import numpy as np
import math
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
import torch
img_save_path = './conditionalGAN/images/'
os.makedirs(img_save_path, exist_ok=True)

n_epochs = 200 # number of epochs of training
batch_size = 64 # size of the batches
lr = 0.0002 # learning rate
beta1 = 0.5 # decay of the first order mmomentum of gradient
beta2 = 0.999 # decay of second order momentum of gradient
latent_dim = 100 # dim of the latent space
n_classes = 10 # number of classes for dataset
image_size = 28 # size of each image dimension
channels = 1 # number of image channels
sample_interval = 200 # interval between image sampling
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal(m.weight,0.0,0.02)
    elif classname.find('BatchNorm2d')!=-1:
        torch.nn.init.normal(m.weight,1.0,0.02)
        torch.nn.init.constant(m.bias,0.0)
        
def idx2onehot(idx, n):
    assert torch.max(idx).item() < n and idx.dim() == 1
    idx2dim = idx.view(-1,1) # change from 1-dim tensor to 2-dim tensor
    onehot = torch.zeros(idx2dim.size(0),n).scatter_(1,idx2dim,1)

    return onehot


    
class Generator(nn.Module):
    # initializers
    def __init__(self):
        super().__init__()
        self.fc1_1 = nn.Linear(latent_dim,256)
        self.fc1_1_bn = nn.BatchNorm1d(256)
        self.fc1_2 = nn.Linear(10, 256)
        self.fc1_2_bn = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(512,512)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512,1024)
        self.fc3_bn = nn.BatchNorm1d(1024)
        self.fc4 = nn.Linear(1024,image_size**2)
    
    # forward method
    def forward(self, z, y):
        x1 = F.relu(self.fc1_1_bn(self.fc1_1(z)))
        x2 = F.relu(self.fc1_2_bn(self.fc1_2(y)))
        x = torch.cat([x1, x2], dim = 1)
        x = F.relu(self.fc2_bn(self.fc2(x)))
        x = F.relu(self.fc3_bn(self.fc3(x)))
        x = torch.tanh(self.fc4(x))
        return x
    
class Discriminator(nn.Module):
    # initializers
    def __init__(self):
        super().__init__()
        self.fc1_1 = nn.Linear(image_size**2, 1024)
        self.fc1_2 = nn.Linear(10, 1024)
        self.fc2 = nn.Linear(2048, 512)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512,256)
        self.fc3_bn = nn.BatchNorm1d(256)
        self.fc4 = nn.Linear(256, 1)
        
        # forward method
    def forward(self, input, y):
        x1 = F.leaky_relu(self.fc1_1(input), 0.2)
        x2 = F.leaky_relu(self.fc1_2(y),0.2)
        x = torch.cat([x1, x2], dim = 1)
        x = F.leaky_relu(self.fc2_bn(self.fc2(x)),0.2)
        x = F.leaky_relu(self.fc3_bn(self.fc3(x)),0.2)
        x = torch.sigmoid(self.fc4(x))
        return x
        
# Loss function
loss = torch.nn.BCELoss()

# Initialize Generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Move model to the corresponding device
generator.to(device)
discriminator.to(device)

# Initialize weights
#generator.apply(weights_init_normal)
#discriminator.apply(weights_init_normal)

data_save_path = './conditionalGAN/data/'
os.makedirs(data_save_path, exist_ok = True)

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*channels,[0.5]*channels)
])
dataset = datasets.MNIST(root = data_save_path, train = True, download = True, transform=transform)


dataloader = DataLoader(dataset,batch_size = batch_size,shuffle=True,drop_last = True)
print("the data is OK")


0.1%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./conditionalGAN/data/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ./conditionalGAN/data/MNIST/raw/train-images-idx3-ubyte.gz to ./conditionalGAN/data/MNIST/raw


28.4%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./conditionalGAN/data/MNIST/raw/train-labels-idx1-ubyte.gz


0.5%5%

Extracting ./conditionalGAN/data/MNIST/raw/train-labels-idx1-ubyte.gz to ./conditionalGAN/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./conditionalGAN/data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting ./conditionalGAN/data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./conditionalGAN/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./conditionalGAN/data/MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting ./conditionalGAN/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./conditionalGAN/data/MNIST/raw
Processing...
Done!
the data is OK


In [2]:
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(),lr = lr, betas = (beta1, beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr = lr, betas = (beta1, beta2))
def reset_grad():
    optimizer_D.zero_grad()
    optimizer_G.zero_grad()

In [5]:
def sample_image(n_row, n_col,epoch):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = torch.randn(n_row*n_col, latent_dim).to(device)
    # Get labels ranging from 0 to n_classes for n rows
    labels = torch.Tensor([num for _ in range(n_col) for num in range(n_row)]).type(torch.LongTensor)
    labels = idx2onehot(labels,n_classes).to(device)
    gen_imgs = generator(z, labels).view(z.size(0),channels,image_size,image_size) # reshape the output of the generator
    save_image(gen_imgs.data, os.path.join(img_save_path,'{}.png'.format(epoch)),nrow=n_row, normalize=True)
    

In [6]:
total_step = len(dataloader)

for epoch in range(n_epochs):
    for i, (images, labels) in enumerate(dataloader):
        images = images.view(images.size(0), -1).to(device)
        
        real_labels = torch.ones(batch_size,1).to(device)
        fake_labels = torch.zeros(batch_size,1).to(device)
        
        # Configure labels
        
        labels_onehot = idx2onehot(labels,n_classes).to(device)

        
        # =============================================
        #
        #  Training Discriminator 
        #
        # =============================================
        
        outputs = discriminator(images,labels_onehot)
        d_loss_real = loss(outputs, real_labels)
        real_score = outputs # just for the purpose of tracking training progress
        
        z = torch.randn(batch_size, latent_dim).to(device)
        label_z = idx2onehot(torch.randint(n_classes,(batch_size,)),n_classes).to(device)
        
        fake_images = generator(z,label_z).detach()
        outputs = discriminator(fake_images,label_z)
        d_loss_fake = loss(outputs, fake_labels)
        fake_score = outputs # just for the purpose of tracking training progress
        
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad() # reset stored gradients
        d_loss.backward()
        optimizer_D.step()
        
        
        # =============================================
        #
        #  Training Generator
        #
        # ============================================
        
        z = torch.randn(batch_size, latent_dim).to(device)
        label_z = idx2onehot(torch.randint(n_classes,(batch_size,)),n_classes).to(device)

        fake_images = generator(z,label_z)
        outputs = discriminator(fake_images,label_z)
        g_loss = loss(outputs, real_labels)
        reset_grad()
        g_loss.backward()
        optimizer_G.step()
        
        if (i+1) % sample_interval == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss:{:.2f}, D(G(z)): {:.2f}'.format(epoch, n_epochs, i+1, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()))
          
    sample_image(n_row=10,n_col=10,epoch=epoch)

Epoch [0/200], Step [200/937], d_loss: 0.9844, g_loss:0.83, D(G(z)): 0.60
Epoch [0/200], Step [400/937], d_loss: 1.0356, g_loss:1.09, D(G(z)): 0.63
Epoch [0/200], Step [600/937], d_loss: 1.3081, g_loss:0.89, D(G(z)): 0.51
Epoch [0/200], Step [800/937], d_loss: 1.0311, g_loss:1.09, D(G(z)): 0.63
Epoch [1/200], Step [200/937], d_loss: 1.2091, g_loss:0.91, D(G(z)): 0.56
Epoch [1/200], Step [400/937], d_loss: 1.0780, g_loss:1.02, D(G(z)): 0.59
Epoch [1/200], Step [600/937], d_loss: 1.1574, g_loss:1.04, D(G(z)): 0.54
Epoch [1/200], Step [800/937], d_loss: 1.1214, g_loss:1.16, D(G(z)): 0.63
Epoch [2/200], Step [200/937], d_loss: 1.0272, g_loss:1.29, D(G(z)): 0.55
Epoch [2/200], Step [400/937], d_loss: 1.0769, g_loss:1.13, D(G(z)): 0.66
Epoch [2/200], Step [600/937], d_loss: 0.8803, g_loss:0.99, D(G(z)): 0.55
Epoch [2/200], Step [800/937], d_loss: 0.6323, g_loss:1.69, D(G(z)): 0.76
Epoch [3/200], Step [200/937], d_loss: 0.8940, g_loss:1.19, D(G(z)): 0.68
Epoch [3/200], Step [400/937], d_loss:

Epoch [27/200], Step [600/937], d_loss: 0.9508, g_loss:1.02, D(G(z)): 0.49
Epoch [27/200], Step [800/937], d_loss: 0.3123, g_loss:3.36, D(G(z)): 0.98
Epoch [28/200], Step [200/937], d_loss: 0.2040, g_loss:2.84, D(G(z)): 0.87
Epoch [28/200], Step [400/937], d_loss: 0.3517, g_loss:2.64, D(G(z)): 0.78
Epoch [28/200], Step [600/937], d_loss: 0.2281, g_loss:2.54, D(G(z)): 0.89
Epoch [28/200], Step [800/937], d_loss: 0.1636, g_loss:1.99, D(G(z)): 0.92
Epoch [29/200], Step [200/937], d_loss: 0.2990, g_loss:4.29, D(G(z)): 0.80
Epoch [29/200], Step [400/937], d_loss: 0.2568, g_loss:3.28, D(G(z)): 0.86
Epoch [29/200], Step [600/937], d_loss: 0.3882, g_loss:1.99, D(G(z)): 0.82
Epoch [29/200], Step [800/937], d_loss: 0.3465, g_loss:1.73, D(G(z)): 0.84
Epoch [30/200], Step [200/937], d_loss: 0.7885, g_loss:3.39, D(G(z)): 0.84
Epoch [30/200], Step [400/937], d_loss: 0.3023, g_loss:1.99, D(G(z)): 0.87
Epoch [30/200], Step [600/937], d_loss: 0.4530, g_loss:2.07, D(G(z)): 0.74
Epoch [30/200], Step [800

Epoch [55/200], Step [200/937], d_loss: 0.3065, g_loss:2.91, D(G(z)): 0.83
Epoch [55/200], Step [400/937], d_loss: 0.1530, g_loss:3.11, D(G(z)): 0.99
Epoch [55/200], Step [600/937], d_loss: 0.3681, g_loss:1.56, D(G(z)): 0.92
Epoch [55/200], Step [800/937], d_loss: 0.0863, g_loss:4.33, D(G(z)): 0.96
Epoch [56/200], Step [200/937], d_loss: 0.2002, g_loss:2.70, D(G(z)): 0.99
Epoch [56/200], Step [400/937], d_loss: 0.2169, g_loss:3.72, D(G(z)): 0.84
Epoch [56/200], Step [600/937], d_loss: 0.1877, g_loss:4.50, D(G(z)): 0.96
Epoch [56/200], Step [800/937], d_loss: 0.1102, g_loss:2.92, D(G(z)): 0.95
Epoch [57/200], Step [200/937], d_loss: 0.4326, g_loss:3.01, D(G(z)): 0.78
Epoch [57/200], Step [400/937], d_loss: 0.1715, g_loss:2.97, D(G(z)): 0.96
Epoch [57/200], Step [600/937], d_loss: 0.8025, g_loss:2.08, D(G(z)): 0.53
Epoch [57/200], Step [800/937], d_loss: 0.5527, g_loss:4.33, D(G(z)): 0.67
Epoch [58/200], Step [200/937], d_loss: 0.6614, g_loss:3.83, D(G(z)): 0.99
Epoch [58/200], Step [400

Epoch [82/200], Step [600/937], d_loss: 0.1526, g_loss:1.73, D(G(z)): 0.92
Epoch [82/200], Step [800/937], d_loss: 0.2671, g_loss:3.82, D(G(z)): 0.98
Epoch [83/200], Step [200/937], d_loss: 0.0902, g_loss:3.95, D(G(z)): 0.97
Epoch [83/200], Step [400/937], d_loss: 0.1436, g_loss:5.37, D(G(z)): 0.97
Epoch [83/200], Step [600/937], d_loss: 0.2091, g_loss:4.10, D(G(z)): 0.88
Epoch [83/200], Step [800/937], d_loss: 0.2348, g_loss:5.60, D(G(z)): 0.81
Epoch [84/200], Step [200/937], d_loss: 0.1017, g_loss:4.67, D(G(z)): 0.99
Epoch [84/200], Step [400/937], d_loss: 0.2101, g_loss:5.02, D(G(z)): 0.85
Epoch [84/200], Step [600/937], d_loss: 0.2481, g_loss:0.85, D(G(z)): 0.98
Epoch [84/200], Step [800/937], d_loss: 0.3090, g_loss:3.90, D(G(z)): 0.85
Epoch [85/200], Step [200/937], d_loss: 0.0672, g_loss:3.39, D(G(z)): 0.99
Epoch [85/200], Step [400/937], d_loss: 0.3984, g_loss:1.37, D(G(z)): 0.72
Epoch [85/200], Step [600/937], d_loss: 0.2992, g_loss:3.96, D(G(z)): 0.86
Epoch [85/200], Step [800

Epoch [109/200], Step [800/937], d_loss: 0.0513, g_loss:1.47, D(G(z)): 0.99
Epoch [110/200], Step [200/937], d_loss: 0.1444, g_loss:3.83, D(G(z)): 0.91
Epoch [110/200], Step [400/937], d_loss: 0.3338, g_loss:1.82, D(G(z)): 0.79
Epoch [110/200], Step [600/937], d_loss: 0.0250, g_loss:0.91, D(G(z)): 1.00
Epoch [110/200], Step [800/937], d_loss: 0.1905, g_loss:3.78, D(G(z)): 0.94
Epoch [111/200], Step [200/937], d_loss: 0.4336, g_loss:1.96, D(G(z)): 0.73
Epoch [111/200], Step [400/937], d_loss: 0.1015, g_loss:4.18, D(G(z)): 0.99
Epoch [111/200], Step [600/937], d_loss: 0.6363, g_loss:3.37, D(G(z)): 0.57
Epoch [111/200], Step [800/937], d_loss: 0.5388, g_loss:5.30, D(G(z)): 0.67
Epoch [112/200], Step [200/937], d_loss: 0.6934, g_loss:3.44, D(G(z)): 0.98
Epoch [112/200], Step [400/937], d_loss: 0.2016, g_loss:4.35, D(G(z)): 0.86
Epoch [112/200], Step [600/937], d_loss: 0.0635, g_loss:1.20, D(G(z)): 0.98
Epoch [112/200], Step [800/937], d_loss: 0.1185, g_loss:2.84, D(G(z)): 0.97
Epoch [113/2

Epoch [136/200], Step [800/937], d_loss: 0.1912, g_loss:3.36, D(G(z)): 0.99
Epoch [137/200], Step [200/937], d_loss: 0.4976, g_loss:3.81, D(G(z)): 0.98
Epoch [137/200], Step [400/937], d_loss: 0.3566, g_loss:4.06, D(G(z)): 0.94
Epoch [137/200], Step [600/937], d_loss: 0.1523, g_loss:4.22, D(G(z)): 0.97
Epoch [137/200], Step [800/937], d_loss: 0.5663, g_loss:3.14, D(G(z)): 0.62
Epoch [138/200], Step [200/937], d_loss: 0.6896, g_loss:3.83, D(G(z)): 0.99
Epoch [138/200], Step [400/937], d_loss: 0.2906, g_loss:0.89, D(G(z)): 0.88
Epoch [138/200], Step [600/937], d_loss: 0.9930, g_loss:2.37, D(G(z)): 0.77
Epoch [138/200], Step [800/937], d_loss: 0.0882, g_loss:4.49, D(G(z)): 0.92
Epoch [139/200], Step [200/937], d_loss: 0.1629, g_loss:4.76, D(G(z)): 0.98
Epoch [139/200], Step [400/937], d_loss: 0.1189, g_loss:2.30, D(G(z)): 0.99
Epoch [139/200], Step [600/937], d_loss: 1.1152, g_loss:3.23, D(G(z)): 0.85
Epoch [139/200], Step [800/937], d_loss: 0.1511, g_loss:2.46, D(G(z)): 0.90
Epoch [140/2

Epoch [163/200], Step [800/937], d_loss: 0.7937, g_loss:4.44, D(G(z)): 0.53
Epoch [164/200], Step [200/937], d_loss: 0.2349, g_loss:3.83, D(G(z)): 0.82
Epoch [164/200], Step [400/937], d_loss: 0.3610, g_loss:3.00, D(G(z)): 0.78
Epoch [164/200], Step [600/937], d_loss: 1.3135, g_loss:4.40, D(G(z)): 0.91
Epoch [164/200], Step [800/937], d_loss: 1.2875, g_loss:4.18, D(G(z)): 0.97
Epoch [165/200], Step [200/937], d_loss: 0.1446, g_loss:5.03, D(G(z)): 0.93
Epoch [165/200], Step [400/937], d_loss: 0.1927, g_loss:1.70, D(G(z)): 0.85
Epoch [165/200], Step [600/937], d_loss: 0.2511, g_loss:1.50, D(G(z)): 1.00
Epoch [165/200], Step [800/937], d_loss: 0.1936, g_loss:2.58, D(G(z)): 0.86
Epoch [166/200], Step [200/937], d_loss: 0.1458, g_loss:6.83, D(G(z)): 0.95
Epoch [166/200], Step [400/937], d_loss: 0.0800, g_loss:3.99, D(G(z)): 0.97
Epoch [166/200], Step [600/937], d_loss: 0.3335, g_loss:3.92, D(G(z)): 0.90
Epoch [166/200], Step [800/937], d_loss: 0.9031, g_loss:5.28, D(G(z)): 0.66
Epoch [167/2

Epoch [190/200], Step [800/937], d_loss: 0.0713, g_loss:1.57, D(G(z)): 0.96
Epoch [191/200], Step [200/937], d_loss: 5.4366, g_loss:6.02, D(G(z)): 0.01
Epoch [191/200], Step [400/937], d_loss: 0.2487, g_loss:4.40, D(G(z)): 0.82
Epoch [191/200], Step [600/937], d_loss: 0.7838, g_loss:1.83, D(G(z)): 0.55
Epoch [191/200], Step [800/937], d_loss: 0.3976, g_loss:1.47, D(G(z)): 0.72
Epoch [192/200], Step [200/937], d_loss: 0.2986, g_loss:4.00, D(G(z)): 0.98
Epoch [192/200], Step [400/937], d_loss: 0.7279, g_loss:4.50, D(G(z)): 0.72
Epoch [192/200], Step [600/937], d_loss: 0.1717, g_loss:1.84, D(G(z)): 0.97
Epoch [192/200], Step [800/937], d_loss: 0.3222, g_loss:6.38, D(G(z)): 0.88
Epoch [193/200], Step [200/937], d_loss: 0.2626, g_loss:7.28, D(G(z)): 0.80
Epoch [193/200], Step [400/937], d_loss: 0.1668, g_loss:4.97, D(G(z)): 0.88
Epoch [193/200], Step [600/937], d_loss: 0.0824, g_loss:2.21, D(G(z)): 0.95
Epoch [193/200], Step [800/937], d_loss: 0.3519, g_loss:3.35, D(G(z)): 1.00
Epoch [194/2

In [None]:
torch.randint(10,[3,2])