In [1]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch

In [None]:
#os.makedirs('images', exist_ok=True)
# parser is not used
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=400, help='interval between image sampling')
opt = parser.parse_args()
#print(opt)

In [2]:
cuda = True if torch.cuda.is_available() else False

In [3]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


### Genarator class 
-  no pooling layer 
-  added batch normalization

In [4]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = 32 // 4
        self.l1 = nn.Sequential(nn.Linear(100, 128*self.init_size**2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


### Discriminator class

In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [   nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                        nn.LeakyReLU(0.2, inplace=True),
                        nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(1, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = 32 // 2**4
        self.adv_layer = nn.Sequential( nn.Linear(128*ds_size**2, 1),
                                        nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

In [9]:
class DataLoad(object):
    def __init__(self):
        pass
    def load_data(self):
        batch_size = 100
        # Image processing
        transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
                                     std=(0.5, 0.5, 0.5))])

        # MNIST dataset
        mnist = torchvision.datasets.MNIST(root='E:/DataSet/MNIST',
                                   train=True,
                                   transform=transform,
                                   download=True)

        # Data loader
        data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)
                
        return data_loader

In [11]:
# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

In [12]:
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)


Discriminator(
  (model): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Dropout2d(p=0.25)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Dropout2d(p=0.25)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Dropout2d(p=0.25)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace)
    (13): Dropout2d(p=0.25)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_layer): Sequential(
    (0): Linear(in_features=512, out_features=1, bias=True)
    

In [13]:
obj=DataLoad()
dataloader=obj.load_data()

In [14]:
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

### Training

In [None]:
for epoch in range(200):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], 100))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, 200, i, len(dataloader),
                                                            d_loss.item(), g_loss.item()))

        batches_done = epoch * len(dataloader) + i
        if batches_done % 400 == 0:
            save_image(gen_imgs.data[:25], 'E:/DataSet/etc/dcgan/%d.png' % batches_done, nrow=5, normalize=True)



[Epoch 0/200] [Batch 0/600] [D loss: 0.693282] [G loss: 0.678581]
[Epoch 0/200] [Batch 1/600] [D loss: 0.693218] [G loss: 0.679163]
[Epoch 0/200] [Batch 2/600] [D loss: 0.693152] [G loss: 0.679794]
[Epoch 0/200] [Batch 3/600] [D loss: 0.693105] [G loss: 0.680377]
[Epoch 0/200] [Batch 4/600] [D loss: 0.692938] [G loss: 0.680980]
[Epoch 0/200] [Batch 5/600] [D loss: 0.692880] [G loss: 0.681572]
[Epoch 0/200] [Batch 6/600] [D loss: 0.692813] [G loss: 0.682306]
[Epoch 0/200] [Batch 7/600] [D loss: 0.692641] [G loss: 0.682796]
[Epoch 0/200] [Batch 8/600] [D loss: 0.692510] [G loss: 0.683353]
[Epoch 0/200] [Batch 9/600] [D loss: 0.692403] [G loss: 0.683873]
[Epoch 0/200] [Batch 10/600] [D loss: 0.692168] [G loss: 0.684557]
[Epoch 0/200] [Batch 11/600] [D loss: 0.691865] [G loss: 0.685088]
[Epoch 0/200] [Batch 12/600] [D loss: 0.691762] [G loss: 0.685529]
[Epoch 0/200] [Batch 13/600] [D loss: 0.691438] [G loss: 0.685880]
[Epoch 0/200] [Batch 14/600] [D loss: 0.691101] [G loss: 0.686163]
[Epoc

[Epoch 0/200] [Batch 242/600] [D loss: 0.003182] [G loss: 6.253016]
[Epoch 0/200] [Batch 243/600] [D loss: 0.002791] [G loss: 6.330124]
[Epoch 0/200] [Batch 244/600] [D loss: 0.002594] [G loss: 6.226155]
[Epoch 0/200] [Batch 245/600] [D loss: 0.002746] [G loss: 6.256879]
[Epoch 0/200] [Batch 246/600] [D loss: 0.002882] [G loss: 6.232729]
[Epoch 0/200] [Batch 247/600] [D loss: 0.003653] [G loss: 6.452109]
[Epoch 0/200] [Batch 248/600] [D loss: 0.002937] [G loss: 6.470602]
[Epoch 0/200] [Batch 249/600] [D loss: 0.003109] [G loss: 6.499826]
[Epoch 0/200] [Batch 250/600] [D loss: 0.002817] [G loss: 6.385987]
[Epoch 0/200] [Batch 251/600] [D loss: 0.002815] [G loss: 6.418725]
[Epoch 0/200] [Batch 252/600] [D loss: 0.003766] [G loss: 6.478392]
[Epoch 0/200] [Batch 253/600] [D loss: 0.002572] [G loss: 6.461334]
[Epoch 0/200] [Batch 254/600] [D loss: 0.002012] [G loss: 6.422729]
[Epoch 0/200] [Batch 255/600] [D loss: 0.002766] [G loss: 6.334690]
[Epoch 0/200] [Batch 256/600] [D loss: 0.002778]

[Epoch 0/200] [Batch 482/600] [D loss: 0.000598] [G loss: 7.788920]
[Epoch 0/200] [Batch 483/600] [D loss: 0.000745] [G loss: 7.594642]
[Epoch 0/200] [Batch 484/600] [D loss: 0.000718] [G loss: 7.785437]
[Epoch 0/200] [Batch 485/600] [D loss: 0.000811] [G loss: 7.922599]
[Epoch 0/200] [Batch 486/600] [D loss: 0.000614] [G loss: 7.997054]
[Epoch 0/200] [Batch 487/600] [D loss: 0.000648] [G loss: 7.892864]
[Epoch 0/200] [Batch 488/600] [D loss: 0.000578] [G loss: 7.879824]
[Epoch 0/200] [Batch 489/600] [D loss: 0.000540] [G loss: 7.772927]
[Epoch 0/200] [Batch 490/600] [D loss: 0.000775] [G loss: 8.086200]
[Epoch 0/200] [Batch 491/600] [D loss: 0.000602] [G loss: 7.953145]
[Epoch 0/200] [Batch 492/600] [D loss: 0.001056] [G loss: 8.093983]
[Epoch 0/200] [Batch 493/600] [D loss: 0.000528] [G loss: 8.172709]
[Epoch 0/200] [Batch 494/600] [D loss: 0.000737] [G loss: 7.941971]
[Epoch 0/200] [Batch 495/600] [D loss: 0.000586] [G loss: 8.243049]
[Epoch 0/200] [Batch 496/600] [D loss: 0.000611]

[Epoch 1/200] [Batch 124/600] [D loss: 0.000283] [G loss: 8.422001]
[Epoch 1/200] [Batch 125/600] [D loss: 0.000291] [G loss: 8.666947]
[Epoch 1/200] [Batch 126/600] [D loss: 0.000260] [G loss: 8.711805]
[Epoch 1/200] [Batch 127/600] [D loss: 0.000282] [G loss: 8.763564]
[Epoch 1/200] [Batch 128/600] [D loss: 0.000323] [G loss: 8.957349]
[Epoch 1/200] [Batch 129/600] [D loss: 0.000244] [G loss: 8.723634]
[Epoch 1/200] [Batch 130/600] [D loss: 0.000249] [G loss: 8.753546]
[Epoch 1/200] [Batch 131/600] [D loss: 0.000338] [G loss: 8.498034]
[Epoch 1/200] [Batch 132/600] [D loss: 0.000387] [G loss: 8.780226]
[Epoch 1/200] [Batch 133/600] [D loss: 0.000355] [G loss: 8.690950]
[Epoch 1/200] [Batch 134/600] [D loss: 0.000250] [G loss: 8.713392]
[Epoch 1/200] [Batch 135/600] [D loss: 0.000337] [G loss: 8.937205]
[Epoch 1/200] [Batch 136/600] [D loss: 0.000254] [G loss: 8.613260]
[Epoch 1/200] [Batch 137/600] [D loss: 0.000259] [G loss: 8.693749]
[Epoch 1/200] [Batch 138/600] [D loss: 0.000236]

[Epoch 1/200] [Batch 364/600] [D loss: 0.000163] [G loss: 9.275195]
[Epoch 1/200] [Batch 365/600] [D loss: 0.000191] [G loss: 9.403320]
[Epoch 1/200] [Batch 366/600] [D loss: 0.000170] [G loss: 9.357193]
[Epoch 1/200] [Batch 367/600] [D loss: 0.000212] [G loss: 9.282433]
[Epoch 1/200] [Batch 368/600] [D loss: 0.000196] [G loss: 9.276418]
[Epoch 1/200] [Batch 369/600] [D loss: 0.000163] [G loss: 9.298134]
[Epoch 1/200] [Batch 370/600] [D loss: 0.000129] [G loss: 9.189229]
[Epoch 1/200] [Batch 371/600] [D loss: 0.000161] [G loss: 9.332725]
[Epoch 1/200] [Batch 372/600] [D loss: 0.000169] [G loss: 9.334762]
[Epoch 1/200] [Batch 373/600] [D loss: 0.000176] [G loss: 9.527009]
[Epoch 1/200] [Batch 374/600] [D loss: 0.000216] [G loss: 9.053519]
[Epoch 1/200] [Batch 375/600] [D loss: 0.000156] [G loss: 9.390020]
[Epoch 1/200] [Batch 376/600] [D loss: 0.000229] [G loss: 9.473807]
[Epoch 1/200] [Batch 377/600] [D loss: 0.000170] [G loss: 9.347260]
[Epoch 1/200] [Batch 378/600] [D loss: 0.000180]

[Epoch 2/200] [Batch 4/600] [D loss: 0.000179] [G loss: 9.948008]
[Epoch 2/200] [Batch 5/600] [D loss: 0.000102] [G loss: 9.767533]
[Epoch 2/200] [Batch 6/600] [D loss: 0.000104] [G loss: 9.743795]
[Epoch 2/200] [Batch 7/600] [D loss: 0.000103] [G loss: 9.738871]
[Epoch 2/200] [Batch 8/600] [D loss: 0.000091] [G loss: 9.736090]
[Epoch 2/200] [Batch 9/600] [D loss: 0.000144] [G loss: 9.407537]
[Epoch 2/200] [Batch 10/600] [D loss: 0.000093] [G loss: 9.842372]
[Epoch 2/200] [Batch 11/600] [D loss: 0.000134] [G loss: 9.785001]
[Epoch 2/200] [Batch 12/600] [D loss: 0.000092] [G loss: 9.678839]
[Epoch 2/200] [Batch 13/600] [D loss: 0.000122] [G loss: 9.971008]
[Epoch 2/200] [Batch 14/600] [D loss: 0.000140] [G loss: 9.686808]
[Epoch 2/200] [Batch 15/600] [D loss: 0.000095] [G loss: 9.890226]
[Epoch 2/200] [Batch 16/600] [D loss: 0.000102] [G loss: 9.925119]
[Epoch 2/200] [Batch 17/600] [D loss: 0.000101] [G loss: 9.753422]
[Epoch 2/200] [Batch 18/600] [D loss: 0.000117] [G loss: 9.854462]
[

[Epoch 2/200] [Batch 244/600] [D loss: 0.000092] [G loss: 10.149184]
[Epoch 2/200] [Batch 245/600] [D loss: 0.000073] [G loss: 10.070426]
[Epoch 2/200] [Batch 246/600] [D loss: 0.000095] [G loss: 10.293147]
[Epoch 2/200] [Batch 247/600] [D loss: 0.000069] [G loss: 10.126266]
[Epoch 2/200] [Batch 248/600] [D loss: 0.000057] [G loss: 10.054312]
[Epoch 2/200] [Batch 249/600] [D loss: 0.000069] [G loss: 10.090304]
[Epoch 2/200] [Batch 250/600] [D loss: 0.000067] [G loss: 10.087935]
[Epoch 2/200] [Batch 251/600] [D loss: 0.000076] [G loss: 10.337615]
[Epoch 2/200] [Batch 252/600] [D loss: 0.000080] [G loss: 10.229342]
[Epoch 2/200] [Batch 253/600] [D loss: 0.000072] [G loss: 10.083660]
[Epoch 2/200] [Batch 254/600] [D loss: 0.000072] [G loss: 10.320019]
[Epoch 2/200] [Batch 255/600] [D loss: 0.000077] [G loss: 10.210760]
[Epoch 2/200] [Batch 256/600] [D loss: 0.000064] [G loss: 10.111988]
[Epoch 2/200] [Batch 257/600] [D loss: 0.000068] [G loss: 9.916177]
[Epoch 2/200] [Batch 258/600] [D lo

[Epoch 2/200] [Batch 480/600] [D loss: 0.000069] [G loss: 10.514335]
[Epoch 2/200] [Batch 481/600] [D loss: 0.000045] [G loss: 10.533490]
[Epoch 2/200] [Batch 482/600] [D loss: 0.000057] [G loss: 10.552097]
[Epoch 2/200] [Batch 483/600] [D loss: 0.000051] [G loss: 10.603416]
[Epoch 2/200] [Batch 484/600] [D loss: 0.000074] [G loss: 10.296222]
[Epoch 2/200] [Batch 485/600] [D loss: 0.000055] [G loss: 10.656799]
[Epoch 2/200] [Batch 486/600] [D loss: 0.000055] [G loss: 10.632102]
[Epoch 2/200] [Batch 487/600] [D loss: 0.000064] [G loss: 10.545350]
[Epoch 2/200] [Batch 488/600] [D loss: 0.000052] [G loss: 10.526454]
[Epoch 2/200] [Batch 489/600] [D loss: 0.000068] [G loss: 10.568652]
[Epoch 2/200] [Batch 490/600] [D loss: 0.000058] [G loss: 10.583229]
[Epoch 2/200] [Batch 491/600] [D loss: 0.000063] [G loss: 10.420376]
[Epoch 2/200] [Batch 492/600] [D loss: 0.000049] [G loss: 10.622395]
[Epoch 2/200] [Batch 493/600] [D loss: 0.000051] [G loss: 10.746239]
[Epoch 2/200] [Batch 494/600] [D l

[Epoch 3/200] [Batch 118/600] [D loss: 0.000050] [G loss: 10.848480]
[Epoch 3/200] [Batch 119/600] [D loss: 0.000035] [G loss: 11.079137]
[Epoch 3/200] [Batch 120/600] [D loss: 0.000048] [G loss: 10.796770]
[Epoch 3/200] [Batch 121/600] [D loss: 0.000048] [G loss: 10.876620]
[Epoch 3/200] [Batch 122/600] [D loss: 0.000040] [G loss: 10.792026]
[Epoch 3/200] [Batch 123/600] [D loss: 0.000038] [G loss: 10.832501]
[Epoch 3/200] [Batch 124/600] [D loss: 0.000047] [G loss: 10.845831]
[Epoch 3/200] [Batch 125/600] [D loss: 0.000039] [G loss: 10.984188]
[Epoch 3/200] [Batch 126/600] [D loss: 0.000042] [G loss: 10.730351]
[Epoch 3/200] [Batch 127/600] [D loss: 0.000062] [G loss: 10.982428]
[Epoch 3/200] [Batch 128/600] [D loss: 0.000031] [G loss: 10.770059]
[Epoch 3/200] [Batch 129/600] [D loss: 0.000052] [G loss: 10.993185]
[Epoch 3/200] [Batch 130/600] [D loss: 0.000039] [G loss: 10.945296]
[Epoch 3/200] [Batch 131/600] [D loss: 0.000060] [G loss: 10.993749]
[Epoch 3/200] [Batch 132/600] [D l

[Epoch 3/200] [Batch 354/600] [D loss: 0.000029] [G loss: 11.063530]
[Epoch 3/200] [Batch 355/600] [D loss: 0.000040] [G loss: 11.105552]
[Epoch 3/200] [Batch 356/600] [D loss: 0.000043] [G loss: 11.230361]
[Epoch 3/200] [Batch 357/600] [D loss: 0.000028] [G loss: 11.306884]
[Epoch 3/200] [Batch 358/600] [D loss: 0.000036] [G loss: 11.359654]
[Epoch 3/200] [Batch 359/600] [D loss: 0.000026] [G loss: 11.179726]
[Epoch 3/200] [Batch 360/600] [D loss: 0.000029] [G loss: 10.985540]
[Epoch 3/200] [Batch 361/600] [D loss: 0.000033] [G loss: 10.920074]
[Epoch 3/200] [Batch 362/600] [D loss: 0.000027] [G loss: 11.145905]
[Epoch 3/200] [Batch 363/600] [D loss: 0.000035] [G loss: 11.363389]
[Epoch 3/200] [Batch 364/600] [D loss: 0.000028] [G loss: 11.215069]
[Epoch 3/200] [Batch 365/600] [D loss: 0.000041] [G loss: 11.027052]
[Epoch 3/200] [Batch 366/600] [D loss: 0.000042] [G loss: 11.178297]
[Epoch 3/200] [Batch 367/600] [D loss: 0.000026] [G loss: 10.897076]
[Epoch 3/200] [Batch 368/600] [D l