In [15]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
# Run counter: the training cell increments this on every execution and uses
# the value as the TensorBoard log sub-directory name.
log = 0

In [29]:
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to 784 output values.

    The network is three hidden stages of Linear -> BatchNorm1d ->
    LeakyReLU(0.2) with 128, 256 and 512 units, followed by a Linear
    projection to 784 values passed through Tanh.

    Args:
        latent: size of the latent input vector.
    """

    def __init__(self, latent):
        super().__init__()
        widths = (latent, 128, 256, 512)
        layers = []
        # Layers are created in network order so parameter initialisation
        # consumes the random stream in the same sequence as a literal listing.
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [nn.Linear(widths[-1], 784), nn.Tanh()]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the generated batch of shape (batch, 784) for latent input x."""
        return self.fc(x)

class Discriminator(nn.Module):
    """Fully connected critic scoring a 784-value input with a single scalar.

    Three hidden stages of Linear -> LeakyReLU(0.2) with 512, 256 and 128
    units, followed by a Linear layer producing one raw output per sample.
    """

    def __init__(self):
        super().__init__()
        widths = (784, 512, 256, 128)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Deliberately no Sigmoid: the output is left as an unbounded score.
        layers.append(nn.Linear(widths[-1], 1))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw score tensor of shape (batch, 1) for input x."""
        return self.fc(x)

In [30]:
# MNIST training split, read from ../dataset/ (download=False, so the files
# must already be there) and converted to tensors by ToTensor.
train_data = torchvision.datasets.MNIST(
    root='../dataset/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=False,
)
# Shuffled mini-batches of 128 samples.
train_iter = torch.utils.data.DataLoader(
    train_data,
    shuffle=True,
    batch_size=128,
)

In [31]:
# Training configuration.
latent = 64        # size of the latent vector fed to the generator
lr = 1e-4          # learning rate passed to both RMSprop optimizers
epochs = 20        # number of passes over the training loader
clip_value = 1e-2  # critic weights are clamped to [-clip_value, clip_value]
k = 1              # the generator is updated on every k-th batch

In [32]:
# Instantiate the generator first, then the critic.
G = Generator(latent)
D = Discriminator()
# No loss criterion object is created here. Each network gets its own RMSprop
# optimizer, both at the shared learning rate.
optimizer_G, optimizer_D = (
    torch.optim.RMSprop(net.parameters(), lr=lr) for net in (G, D)
)

In [None]:
# Each execution of this cell bumps the run counter so its logs land in a
# fresh TensorBoard directory (wgan/1, wgan/2, ...).
log = log + 1
# Fixed latent batch reused every epoch so the preview grids are comparable.
test = torch.normal(0, 1, (32, latent))

# The context manager closes the event file even when training is interrupted
# or raises part-way through.
with SummaryWriter(log_dir='wgan/'+str(log)) as writer:
    for epoch in range(epochs):
        avg_loss_g = 0.0
        avg_loss_d = 0.0
        num = 0    # generator updates performed this epoch
        num_d = 0  # critic updates performed this epoch
        for i, (real_img, _) in enumerate(train_iter):
            batch_size = real_img.shape[0]
            # Flatten each image to a vector for the fully connected critic.
            real_img = real_img.view(batch_size, -1)

            # ---- Train the critic ----
            d_real = D(real_img)

            z = torch.normal(0, 1, (batch_size, latent))
            # Detach: the critic update must not backpropagate into the
            # generator, so no graph is kept through G for this pass.
            fake_img = G(z).detach()
            d_fake = D(fake_img)

            optimizer_D.zero_grad()
            # Critic loss: mean score on fakes minus mean score on reals.
            d_loss = torch.mean(d_fake) - torch.mean(d_real)
            d_loss.backward()
            optimizer_D.step()

            # Weight clipping: keep every critic parameter inside
            # [-clip_value, clip_value] after each update.
            with torch.no_grad():
                for p in D.parameters():
                    p.clamp_(-clip_value, clip_value)

            avg_loss_d += d_loss.item()
            num_d += 1

            # ---- Train the generator (once every k critic steps) ----
            if (i+1) % k == 0:
                num += 1
                # Re-run G on the same latent batch, this time keeping the
                # graph so gradients reach the generator.
                fake_img = G(z)
                d_fake = D(fake_img)
                g_loss = - torch.mean(d_fake)

                optimizer_G.zero_grad()
                g_loss.backward()
                optimizer_G.step()

                avg_loss_g += g_loss.item()

        # max(..., 1) avoids a ZeroDivisionError when no update of that kind
        # happened in the epoch (empty loader, or k larger than the number
        # of batches).
        mean_loss_d = avg_loss_d / max(num_d, 1)
        mean_loss_g = avg_loss_g / max(num, 1)
        print("Epoch:{}, Loss_D:{}, Loss_G:{}".format(epoch, mean_loss_d, mean_loss_g))
        writer.add_scalar('Loss_D', mean_loss_d, epoch)
        writer.add_scalar('Loss_G', mean_loss_g, epoch)

        # Log a grid of samples from the fixed latent batch.
        # NOTE(review): G is still in training mode here, so any BatchNorm
        # layers normalise with this 32-sample batch's statistics rather
        # than their running averages - confirm this is intended.
        with torch.no_grad():
            s = "Epoch-"+str(epoch)
            show = torch.clamp(G(test), 0, 1).reshape(32, 1, 28, 28)
            writer.add_images(s, show, 0)