In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
# device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 超参数
latent_size = 64
hidden_size = 256
image_size = 784
num_epoches = 10
batch_size = 100
sample_dir = 'samples'

if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [3]:
# image process
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5),   # 3 for RGB channels
                                     std=(0.5))])

mnist = torchvision.datasets.MNIST('./mnist',train=True,transform=transform,download=True)

In [4]:
view_data = mnist.train_data.type(torch.FloatTensor)/255.



In [5]:
view_data.size()

torch.Size([60000, 28, 28])

In [6]:
data_loader = torch.utils.data.DataLoader(dataset = view_data,
                                         batch_size = batch_size,
                                         shuffle = True)

In [7]:
data_loader

<torch.utils.data.dataloader.DataLoader at 0x7f5b5432e358>

In [8]:
# 判别模型
D = nn.Sequential(
    nn.Linear(image_size,hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size,hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size,1),
    nn.Sigmoid())

# 生成模型
G = nn.Sequential(
    nn.Linear(latent_size,hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size,hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size,image_size),
    nn.Tanh())

In [9]:
# device setting
D = D.to(device)
G = G.to(device)

# binary cross entropy 
criterion = nn.BCELoss()

d_optimizer = torch.optim.Adam(D.parameters(),lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(),lr=0.0002)

def denorm(x):
    out = (x + 1)/2
    return out.clamp(0,1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

In [None]:
# 训练
total_step = len(data_loader)
for epoch in range(num_epoches):
    for i,images in enumerate(data_loader):
        # (batch_size , 784)
        images = images.reshape(batch_size,-1).to(device)
        #print(images)
        # create the labels which are later used as input for the bce loss
        real_labels = torch.ones(batch_size,1).to(device)
        fake_labels = torch.zeros(batch_size,1).to(device)
        
        # =========================================== #
        #            train the discriminator          # 
        # =========================================== #
        
        # compute bce_loss using real images where bce_loss(x,y): - y * log(D(x)) - (1-y) * log(1-D(x))
        # second term of the loss is always zero since real_labels == 1
        outputs = D(images)
        #print('sssssss')
        d_loss_real = criterion(outputs,real_labels)
        read_score = outputs
        
        # compute bce_loss using faking images
        # first term of the loss is always zero since real_labels == 1
        z = torch.randn(batch_size,latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs,fake_labels)
        fake_score = outputs
        
        # backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # ============================================= #
        #               train the generator             #
        # ============================================= #
        
        # compute loss with fake images
        z = torch.randn(batch_size,latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        # we train G to minimize log(1-D(G(z))) == max log(D(G(z))) == min -log(D(G(z)))
        # use real labels ,the second term of the loss is zero
        g_loss = criterion(outputs,real_labels)
        
        # backward and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i + 1) % 200 == 0:
            print('Epoch %d , step %d , d_loss: %.4f, g_loss: %.4f D(x): %.2f,D(G(z)): %.2f'%(epoch,i,d_loss.item(),g_loss.item(),read_score.mean().item(),fake_score.mean().item()))
        # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

# Save the model checkpoints 
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

Epoch 0 , step 199 , d_loss: 0.1324, g_loss: 4.0410 D(x): 0.93,D(G(z)): 0.05
Epoch 0 , step 399 , d_loss: 0.3361, g_loss: 5.0757 D(x): 0.86,D(G(z)): 0.13
Epoch 0 , step 599 , d_loss: 0.5057, g_loss: 2.8876 D(x): 0.79,D(G(z)): 0.20
Epoch 1 , step 199 , d_loss: 0.2159, g_loss: 3.4477 D(x): 0.89,D(G(z)): 0.08
Epoch 1 , step 399 , d_loss: 0.8302, g_loss: 2.4512 D(x): 0.70,D(G(z)): 0.22
Epoch 1 , step 599 , d_loss: 1.4564, g_loss: 3.0170 D(x): 0.64,D(G(z)): 0.33
Epoch 2 , step 199 , d_loss: 0.3846, g_loss: 3.2894 D(x): 0.87,D(G(z)): 0.19
Epoch 2 , step 399 , d_loss: 0.3747, g_loss: 2.7805 D(x): 0.87,D(G(z)): 0.16
Epoch 2 , step 599 , d_loss: 0.8808, g_loss: 1.8891 D(x): 0.71,D(G(z)): 0.33
Epoch 3 , step 199 , d_loss: 1.1903, g_loss: 1.9403 D(x): 0.64,D(G(z)): 0.32
Epoch 3 , step 399 , d_loss: 0.2817, g_loss: 2.3429 D(x): 0.92,D(G(z)): 0.14
Epoch 3 , step 599 , d_loss: 0.0807, g_loss: 3.5722 D(x): 0.98,D(G(z)): 0.05
Epoch 4 , step 199 , d_loss: 0.1840, g_loss: 4.9962 D(x): 0.95,D(G(z)): 0.06