In [29]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.optim import Adam
from torch.optim import RMSprop
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.utils import save_image

In [2]:
# Device configuration: use the CUDA device when PyTorch can see one,
# otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [3]:
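'''
Reference: https://zhuanlan.zhihu.com/p/25071913

How WGAN modifies DCGAN:
1. Remove the sigmoid from the discriminator's last layer (classification -> regression):
   the critic outputs an unbounded score, not a binary-classification probability.
2. Do not take the log in the loss (the loss estimates the Wasserstein distance).
3. Clip every critic weight to the range [-c, c] (a crude way to keep the critic
   Lipschitz-continuous, which the Wasserstein-distance estimate requires).
4. Do not use a momentum-based optimizer (such as Adam); use RMSProp or SGD instead.

Why the original GAN is hard to train:
- Mode collapse -> caused by the asymmetry of the KL divergence.
- Training instability -> caused by the conflict between the KL divergence and the JS divergence.
'''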

class Config:
    """Hyper-parameters for the WGAN experiment; a single instance is exposed as `opt`."""
    lr = 0.00005        # RMSprop learning rate, shared by G and D
    nz = 100            # dimension of the generator's input noise vector
    image_size = 64     # images are resized to this size; the G/D stacks below produce/consume 64x64
    image_size2 = 64    # NOTE(review): not used in the cells shown
    nc = 3              # number of image channels (RGB)
    ngf = 64            # base number of generator feature maps
    ndf = 64            # base number of discriminator (critic) feature maps
    beta1 = 0.5         # Adam beta1; unused here because WGAN trains with RMSprop
    batch_size = 32
    max_epoch = 50      # set to 1 when debugging
    workers = 2         # DataLoader worker processes
    gpu = True          # move models/tensors to `device` (which is the CPU if CUDA is unavailable)
    clamp_num = 0.01    # WGAN weight-clipping bound c: critic *weights* (not gradients) are clamped to [-c, c]


opt = Config()

In [4]:
# Data preprocessing: resize to opt.image_size, convert to a tensor, then
# normalise each of the 3 channels as x -> (x - 0.5) / 0.5.  The display code
# later undoes this with `* 0.5 + 0.5`.
transform = transforms.Compose(
    [
        transforms.Resize(opt.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [8]:
# CIFAR-10 stored under ./data/ (downloaded on first run), with the
# preprocessing pipeline above applied to every image.  `train` is left at
# torchvision's default.
dataset = torchvision.datasets.CIFAR10(
    root='./data/',
    download=True,
    transform=transform,
)


Files already downloaded and verified


In [10]:
# Shuffled mini-batches of opt.batch_size images, loaded by opt.workers
# background processes.  drop_last keeps its default (False), so the final
# batch of each epoch may be smaller than opt.batch_size.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.workers,
)

In [15]:
# DCGAN generator: noise of shape (N, opt.nz, 1, 1) -> images of shape
# (N, opt.nc, 64, 64).  The first transposed conv projects 1x1 -> 4x4; every
# following stride-2 transposed conv doubles the spatial size while halving
# the channel count.  Tanh keeps the output in [-1, 1], the range of the
# normalised training images.
G = nn.Sequential(
    # 1x1 -> 4x4
    nn.ConvTranspose2d(opt.nz, opt.ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(opt.ngf * 8),
    nn.ReLU(inplace=True),

    # 4x4 -> 8x8
    nn.ConvTranspose2d(opt.ngf * 8, opt.ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(opt.ngf * 4),
    nn.ReLU(inplace=True),

    # 8x8 -> 16x16
    nn.ConvTranspose2d(opt.ngf * 4, opt.ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(opt.ngf * 2),
    nn.ReLU(inplace=True),

    # 16x16 -> 32x32
    nn.ConvTranspose2d(opt.ngf * 2, opt.ngf, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(opt.ngf),
    nn.ReLU(inplace=True),

    # 32x32 -> 64x64, down to opt.nc image channels
    nn.ConvTranspose2d(opt.ngf, opt.nc, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),
)





In [None]:
# DCGAN-style critic (the WGAN "discriminator"): for 64x64 inputs
# (opt.image_size) each stride-2 conv halves the spatial size
# (64 -> 32 -> 16 -> 8 -> 4), and the final 4x4 valid conv reduces that to a
# single unbounded score per image, shape (N, 1, 1, 1).
D = nn.Sequential(
    # 64x64 -> 32x32 (no BatchNorm on the first layer)
    nn.Conv2d(opt.nc, opt.ndf, kernel_size=4, stride=2, padding=1, bias=False),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),

    # 32x32 -> 16x16
    nn.Conv2d(opt.ndf, opt.ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(opt.ndf * 2),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),

    # 16x16 -> 8x8
    nn.Conv2d(opt.ndf * 2, opt.ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(opt.ndf * 4),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),

    # 8x8 -> 4x4
    nn.Conv2d(opt.ndf * 4, opt.ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(opt.ndf * 8),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),

    # 4x4 -> 1x1 score
    nn.Conv2d(opt.ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
    # Modification 1: no Sigmoid -- the critic regresses a score instead of
    # producing a real/fake probability.
)


In [None]:
def weight_init(m):
    """DCGAN-style weight initialisation, meant to be passed to ``Module.apply``.

    The original author notes that this initialisation is important for WGAN.

    * Conv / ConvTranspose layers: weight ~ N(0, 0.02).
    * Normalisation layers: weight ~ N(1, 0.02) and bias = 0.
    * Norm layers without affine parameters (``weight``/``bias`` is None) and
      all other module types are left untouched.

    Args:
        m: a single ``nn.Module``; only its own parameters are modified.
    """
    class_name = m.__class__.__name__
    if 'Conv' in class_name:
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif 'Norm' in class_name:
        # Non-affine norm layers have weight/bias set to None; skip them
        # instead of crashing on `None.data`.
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, mean=1.0, std=0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.zeros_(m.bias)

# Re-initialise both networks in place.  The trailing semicolon suppresses the
# notebook echo of G.apply's return value (instead of printing a blank line).
D.apply(weight_init)
G.apply(weight_init);

In [31]:
# Move the networks to the target device *before* constructing their
# optimizers, as the torch.optim documentation recommends.
if opt.gpu:
    D.to(device)
    G.to(device)

# Modification 4: RMSprop instead of a momentum-based optimizer such as Adam.
optimizerD = RMSprop(D.parameters(), lr=opt.lr)
optimizerG = RMSprop(G.parameters(), lr=opt.lr)

# Modification 2: no log in the loss, so no BCELoss criterion is needed.

# Fixed noise batch, reused after every epoch so the generated samples can be
# compared across training.
fix_noise = torch.randn(opt.batch_size, opt.nz, 1, 1)
if opt.gpu:
    fix_noise = fix_noise.to(device)

In [32]:
# WGAN training loop.
# Sign convention (identical to the former backward(one)/backward(mone) form):
#   critic loss    = mean D(real) - mean D(fake)  -> real scores pushed down, fake scores up
#   generator loss = mean D(G(z))                 -> fake scores pushed down, toward real ones
# Scalar losses replace backward() with a hard-coded ones(32, 1, 1, 1) gradient.
# That gradient crashed on any batch whose size is not 32, e.g. the partial final
# batch the DataLoader yields (CIFAR-10: 50000 % 32 == 16).  RMSprop normalises the
# gradient scale, so mean vs. sum does not change the effective step size.
N_CRITIC = 5  # critic updates per generator update
train_device = device if opt.gpu else torch.device('cpu')

for epoch in range(opt.max_epoch):
    for ii, (real, _) in enumerate(dataloader):
        real = real.to(train_device)
        batch = real.size(0)

        # ----- train the critic D -----
        D.zero_grad()
        noise = torch.randn(batch, opt.nz, 1, 1, device=train_device)
        fake_pic = G(noise).detach()  # no gradient flows into G during the critic step
        errD = D(real).mean() - D(fake_pic).mean()
        errD.backward()
        optimizerD.step()

        # Modification 3: clip the critic's weights to [-c, c] right after its
        # update (WGAN Algorithm 1), so the generator step always sees a clipped critic.
        with torch.no_grad():
            for parm in D.parameters():
                parm.clamp_(-opt.clamp_num, opt.clamp_num)

        # ----- train the generator G, once every N_CRITIC critic steps -----
        # (a stronger critic gives the generator a more useful signal)
        if (ii + 1) % N_CRITIC == 0:
            G.zero_grad()
            noise = torch.randn(batch, opt.nz, 1, 1, device=train_device)
            errG = D(G(noise)).mean()
            errG.backward()  # also fills D's .grad; cleared by D.zero_grad() next iteration
            optimizerG.step()

    # Show fixed-noise samples at the end of every epoch; `* 0.5 + 0.5` undoes
    # the (x - 0.5) / 0.5 input normalisation before plotting.
    with torch.no_grad():
        fake_u = G(fix_noise)
    imgs = make_grid(fake_u * 0.5 + 0.5).cpu()  # CHW
    plt.imshow(imgs.permute(1, 2, 0).numpy())  # HWC
    plt.show()

KeyboardInterrupt: 

In [None]:
# Draw 64 fresh samples from the generator (make_grid lays them out 8 per row
# by default).  The noise is created on whatever device G's parameters live on,
# so this cell works with or without CUDA.  Wrapping a CUDA tensor in the legacy
# torch.Tensor(...) constructor, as before, raises a TypeError.
g_device = next(G.parameters()).device
with torch.no_grad():
    noise = torch.randn(64, opt.nz, 1, 1, device=g_device)
    fake_u = G(noise)
imgs = make_grid(fake_u * 0.5 + 0.5).cpu()  # CHW; undo the [-1, 1] normalisation
plt.figure(figsize=(5, 5))
plt.imshow(imgs.permute(1, 2, 0).numpy())  # HWC
plt.show()