In [29]:
import os

import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torch.optim import Adam
from torchvision.utils import make_grid


In [30]:
# Compute device: the default CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [31]:
class Config:
    """Hyperparameters for the DCGAN run, exposed as class attributes."""

    # --- optimisation ---
    lr: float = 0.0002        # Adam learning rate
    beta1: float = 0.5        # Adam beta1
    batch_size: int = 32
    max_epoch: int = 10       # set to 1 when debugging

    # --- data / architecture ---
    nz: int = 100             # noise (latent) dimension
    image_size: int = 64
    image_size2: int = 64     # NOTE(review): not referenced by any visible cell
    nc: int = 3               # number of image channels
    ngf: int = 64             # base channel width of the generator
    ndf: int = 64             # base channel width of the discriminator

    # --- runtime ---
    workers: int = 2          # number of data-loading workers
    gpu: bool = True          # when True, models/tensors are moved to `device`


opt = Config()

In [32]:
# Output directory for generated samples; created if it does not exist yet.
# (Previously `sample_dir` was never defined, so this cell raised NameError on a fresh kernel.)
# NOTE(review): no visible cell writes into this directory yet — presumably intended for save_image output.
sample_dir = 'samples'
# exist_ok=True avoids the check-then-create race and makes the cell safely re-runnable.
os.makedirs(sample_dir, exist_ok=True)


In [33]:
# Data preprocessing pipeline:
#   1. Resize to opt.image_size (an int size rescales the shorter edge).
#   2. ToTensor: image -> float CHW tensor in [0, 1].
#   3. Normalize all 3 channels with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5,
#      mapping [0, 1] -> [-1, 1] to match the generator's Tanh output range.
transform = transforms.Compose(
    [
        transforms.Resize(opt.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)


In [41]:
# CIFAR-10 training split, stored under ./data/ (downloaded if not already present)
# and passed through the preprocessing pipeline in `transform`.
CIFAR10 = torchvision.datasets.CIFAR10(
    root='./data/',
    train=True,
    download=True,
    transform=transform,
)


Files already downloaded and verified


In [42]:
# Data loader: shuffled mini-batches of (image, class) pairs from the CIFAR-10 training set.
# Uses the batch size and worker count from `opt` (the bare name `batch_size` was undefined).
# NOTE(review): if worker processes hang inside the notebook (seen on some Windows/macOS
# setups), set opt.workers = 0 to load in the main process.
data_loader = torch.utils.data.DataLoader(
    dataset=CIFAR10,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.workers,
)

In [43]:
# Generator: maps noise of shape (N, opt.nz, 1, 1) to images of shape (N, opt.nc, 64, 64)
# with values in [-1, 1] (final Tanh). Every layer is a 4x4 transposed conv:
#   spatial 1 -> 4 (stride 1, pad 0), then 4 -> 8 -> 16 -> 32 -> 64 (stride 2, pad 1);
#   channels nz -> ngf*8 -> ngf*4 -> ngf*2 -> ngf -> nc.
G = nn.Sequential(
    # 1x1 -> 4x4
    nn.ConvTranspose2d(in_channels=opt.nz, out_channels=opt.ngf * 8,
                       kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(num_features=opt.ngf * 8),
    nn.ReLU(inplace=True),

    # 4x4 -> 8x8
    nn.ConvTranspose2d(in_channels=opt.ngf * 8, out_channels=opt.ngf * 4,
                       kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(num_features=opt.ngf * 4),
    nn.ReLU(inplace=True),

    # 8x8 -> 16x16
    nn.ConvTranspose2d(in_channels=opt.ngf * 4, out_channels=opt.ngf * 2,
                       kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(num_features=opt.ngf * 2),
    nn.ReLU(inplace=True),

    # 16x16 -> 32x32
    nn.ConvTranspose2d(in_channels=opt.ngf * 2, out_channels=opt.ngf,
                       kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(num_features=opt.ngf),
    nn.ReLU(inplace=True),

    # 32x32 -> 64x64, project to image channels and squash to [-1, 1]
    nn.ConvTranspose2d(in_channels=opt.ngf, out_channels=opt.nc,
                       kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),
)


In [None]:
# Discriminator: maps images (N, opt.nc, H, W) to a probability in (0, 1) via Sigmoid.
# For 64x64 inputs (G's output size) the 4x4 convs shrink the spatial size
#   64 -> 32 -> 16 -> 8 -> 4 (stride 2, pad 1), and the last 4x4 valid conv gives 1x1,
# so the output shape is (N, 1, 1, 1). The first layer has no BatchNorm.
D = nn.Sequential(
    # 64x64 -> 32x32
    nn.Conv2d(in_channels=opt.nc, out_channels=opt.ndf,
              kernel_size=4, stride=2, padding=1, bias=False),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),

    # 32x32 -> 16x16
    nn.Conv2d(in_channels=opt.ndf, out_channels=opt.ndf * 2,
              kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(num_features=opt.ndf * 2),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),

    # 16x16 -> 8x8
    nn.Conv2d(in_channels=opt.ndf * 2, out_channels=opt.ndf * 4,
              kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(num_features=opt.ndf * 4),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),

    # 8x8 -> 4x4
    nn.Conv2d(in_channels=opt.ndf * 4, out_channels=opt.ndf * 8,
              kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(num_features=opt.ndf * 8),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),

    # 4x4 -> 1x1 single logit, then probability
    nn.Conv2d(in_channels=opt.ndf * 8, out_channels=1,
              kernel_size=4, stride=1, padding=0, bias=False),
    nn.Sigmoid(),
)

In [44]:
# Move models (and the fixed noise) to the compute device FIRST, then build the optimizers.
# PyTorch recommends placing a model on its device before constructing its optimizer.
# (The original referenced undefined names `netd` / `netg`; the models are `D` and `G`.)
if opt.gpu:
    D.to(device)
    G.to(device)

# Adam for both networks, beta1 from the config and beta2 = 0.999.
optimizerD = Adam(D.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = Adam(G.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

# Binary cross-entropy on D's Sigmoid output (targets: 1 = real, 0 = fake).
criterion = nn.BCELoss()

# Fixed noise batch of shape (opt.batch_size, opt.nz, 1, 1), reused for every sample grid
# so that generator progress is comparable across epochs.
fix_noise = torch.randn(opt.batch_size, opt.nz, 1, 1)
if opt.gpu:
    fix_noise = fix_noise.to(device)
    criterion.to(device)  # no parameters to move; kept for symmetry with the models

In [46]:
# DCGAN training loop. Each iteration performs
#   1. one discriminator step on a real batch (target 1) plus a detached fake batch (target 0);
#   2. one generator step on fresh noise, with target 1 (G tries to make D call its fakes real).
# Every other epoch, a grid generated from `fix_noise` is displayed.
# (The original referenced undefined `netd` / `netg` / `plt`; the models are `D` and `G`.)
for epoch in range(opt.max_epoch):
    for ii, (real_imgs, _) in enumerate(data_loader):
        n = real_imgs.size(0)  # the last batch may be smaller than opt.batch_size
        label = torch.ones(n)
        noise = torch.randn(n, opt.nz, 1, 1)

        if opt.gpu:
            real_imgs = real_imgs.to(device)
            label = label.to(device)
            noise = noise.to(device)

        # ----- train D -----
        D.zero_grad()
        # real images -> target 1; view(-1) turns (N, 1, 1, 1) into (N,) even when N == 1
        output = D(real_imgs).view(-1)
        error_real = criterion(output, label)
        error_real.backward()
        D_x = output.mean().item()

        # fake images -> target 0; detach so D's loss does not backprop into G
        fake_pic = G(noise).detach()
        output2 = D(fake_pic).view(-1)
        label.fill_(0)
        error_fake = criterion(output2, label)
        error_fake.backward()
        D_x2 = output2.mean().item()

        error_D = error_real + error_fake
        optimizerD.step()

        # ----- train G -----
        G.zero_grad()
        label.fill_(1)  # G is rewarded when D labels its fakes as real
        noise = torch.randn(n, opt.nz, 1, 1, device=noise.device)  # fresh noise for the G step
        fake_pic = G(noise)
        output = D(fake_pic).view(-1)
        error_G = criterion(output, label)
        error_G.backward()
        optimizerG.step()
        D_G_z2 = output.mean().item()

        if ii % 500 == 0:
            # D(G(z)) is reported twice: before the D update (D_x2) and after it (D_G_z2).
            print(f'{ii}/{epoch} lossD:{error_D.item():.4f} lossG:{error_G.item():.4f} '
                  f'D(x):{D_x:.4f} D(G(z)):{D_x2:.4f}/{D_G_z2:.4f}')

    if epoch % 2 == 0:
        # Sampling needs no gradients. G stays in train mode, as in the original,
        # so BatchNorm uses batch statistics for this grid.
        with torch.no_grad():
            fake_u = G(fix_noise)
        imgs = make_grid(fake_u * 0.5 + 0.5).cpu()  # undo Normalize: [-1, 1] -> [0, 1]; CHW
        plt.imshow(imgs.permute(1, 2, 0).numpy())   # HWC for matplotlib
        plt.axis('off')
        plt.title(f'epoch {epoch}')
        plt.show()

KeyboardInterrupt: 