https://mp.weixin.qq.com/s?__biz=MzI5NTIxNTg0OA==&mid=2247498068&idx=3&sn=146f902866fa2094aca0d70989563ed3&chksm=ec544ed3db23c7c50bc370b93214219c9bd0354a58807de1763d1defb896988d6ec7d2bf9c9e&scene=0&xtrack=1&key=e8c75d77299604ef82eb6769e6b689f2ca2cec7932b55edf6f49266cca36c0541f2172f2b6a0f4c01cb62a62835ee855450c80a25548300ce516a1d8cf4d665c6952e90d059ee1d0988223de941c3a6a&ascene=1&uin=MjA1MjAyODkxNg%3D%3D&devicetype=Windows+10&version=62070141&lang=zh_CN&pass_ticket=Vsf3NZgYf5i5ZiliBW0AmsBXvt2ictUXr4fKbzPM5HHCZws0l68BCE4kvylUn040

In [1]:
import torch
from torch import nn,optim
from torch.autograd.variable import Variable
import torchvision
import torchvision.transforms as transforms

In [3]:
# Preprocessing: turn each image into a tensor, then normalise its single
# channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [4]:
# Training data: the MNIST training split, read from (and downloaded into)
# ../dataset/, with the preprocessing pipeline applied to every image.
train_set = torchvision.datasets.MNIST(
    root='../dataset/',
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of 32 samples.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,
)
# Labels: the ten digit classes as strings, '0' through '9'.
classes = list(map(str, range(10)))

In [7]:
# Discriminator
class Discriminator(nn.Module):
    """Fully connected network that scores how "real" each input sample looks.

    Every sample is flattened to 784 features and passed through three hidden
    Linear/ReLU/Dropout(0.3) stages (1024, 512 and 256 units) followed by a
    single sigmoid output unit.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one sigmoid score per sample.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions hold exactly 784 values per sample.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        flat = x.view(x.size(0), 784)
        out = self.model(flat)
        # No explicit .cuda() here: the result is produced on whatever device
        # the module and its input live on, and forcing CUDA would crash on
        # CPU-only machines.
        return out.view(out.size(0), -1)
discriminator = Discriminator()

In [9]:
# Generator
class Generator(nn.Module):
    """Fully connected network that maps a 100-value noise vector to 784 outputs.

    Three hidden Linear/ReLU stages (256, 512 and 1024 units) are followed by a
    784-unit linear layer with a tanh activation, so every output value lies in
    ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate one 784-value sample per noise vector.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions hold exactly 100 values per sample.

        Returns:
            Tensor of shape ``(batch, 784)``.
        """
        noise = x.view(x.size(0), 100)
        # No explicit .cuda() here: the output is produced on whatever device
        # the module and its input live on, and forcing CUDA would crash on
        # CPU-only machines.
        return self.model(noise)

generator = Generator()

In [10]:
# Move both networks to the GPU when CUDA is available.
if torch.cuda.is_available():
    print('Using CUDA')
    discriminator.cuda()
    generator.cuda()

# Loss function: binary cross-entropy.
criterion = nn.BCELoss()

# Hyper-parameters.
lr = 1e-4
num_epochs = 40
num_batches = len(train_loader)

# One Adam optimizer per network, both using the same learning rate.
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = optim.Adam(generator.parameters(), lr=lr)

Using CUDA


In [12]:
# Training loops in PyTorch usually consist of an outer loop over epochs and an
# inner loop over batches. The key to training a GAN is that both the generator
# and the discriminator are updated within the same loop iteration.
def train_discriminator(discriminator, real_images, real_labels, fake_images, fake_labels,
                        loss_fn=None, optimizer=None):
    """Run one optimisation step of the discriminator on a real and a fake batch.

    Args:
        discriminator: Module that maps a batch of images to scores.
        real_images: Batch of real samples.
        real_labels: Targets for the real batch; reshaped to match the scores.
        fake_images: Batch of generated samples. They are detached, so this
            step never backpropagates into the network that produced them.
        fake_labels: Targets for the fake batch; reshaped to match the scores.
        loss_fn: Loss callable. Defaults to the module-level ``criterion``.
        optimizer: Optimizer to step. Defaults to the module-level
            ``d_optimizer``.

    Returns:
        Tuple ``(d_loss, real_score, fake_score)``: the summed loss and the
        discriminator outputs for the real and the fake batch.
    """
    if loss_fn is None:
        loss_fn = criterion
    if optimizer is None:
        optimizer = d_optimizer

    discriminator.zero_grad()

    # Prediction, loss and score on the real images. The labels are reshaped
    # to the prediction's shape because the loss requires input and target of
    # the same size (e.g. (batch,) labels against (batch, 1) scores).
    real_score = discriminator(real_images)
    real_loss = loss_fn(real_score, real_labels.reshape_as(real_score))

    # Prediction, loss and score on the fake images. Detaching keeps the
    # backward pass below from running through the generator's graph.
    fake_score = discriminator(fake_images.detach())
    fake_loss = loss_fn(fake_score, fake_labels.reshape_as(fake_score))

    # Total loss, backward pass and weight update.
    d_loss = real_loss + fake_loss
    d_loss.backward()
    optimizer.step()
    return d_loss, real_score, fake_score

def train_generator(generator, discriminator_outputs, real_labels,
                    loss_fn=None, optimizer=None):
    """Run one optimisation step of the generator.

    Args:
        generator: Module whose gradients are reset before the backward pass.
        discriminator_outputs: Discriminator scores for a freshly generated
            batch; must still be attached to the generator's graph.
        real_labels: "Real" targets the generator tries to make the
            discriminator predict; reshaped to match the scores.
        loss_fn: Loss callable. Defaults to the module-level ``criterion``.
        optimizer: Optimizer to step. Defaults to the module-level
            ``g_optimizer``.

    Returns:
        The generator loss tensor.
    """
    if loss_fn is None:
        loss_fn = criterion
    if optimizer is None:
        optimizer = g_optimizer

    generator.zero_grad()

    # Loss, backward pass and weight update. The labels are reshaped to the
    # scores' shape because the loss requires input and target of the same
    # size (e.g. (batch,) labels against (batch, 1) scores).
    targets = real_labels.reshape_as(discriminator_outputs)
    g_loss = loss_fn(discriminator_outputs, targets)
    g_loss.backward()
    optimizer.step()
    return g_loss

In [None]:
# Start training.
# Put every batch on the device the generator's weights already live on, so
# the loop runs on CPU-only machines as well as on CUDA.
device = next(generator.parameters()).device
for epoch in range(num_epochs):
    for n, (images, _) in enumerate(train_loader):
        batch_size = images.size(0)
        # (1) Real images for the discriminator. Labels are (batch, 1) to
        #     match the discriminator's output shape, as the loss requires.
        real_images = images.to(device)
        real_labels = torch.ones(batch_size, 1, device=device)
        # (2) Random noise for the generator.
        noise = torch.randn(batch_size, 100).to(device)
        # (3) Fake images for the discriminator.
        fake_images = generator(noise)
        fake_labels = torch.zeros(batch_size, 1, device=device)
        # (4) Train the discriminator on the real and the fake images.
        d_loss, real_score, fake_score = train_discriminator(
            discriminator, real_images, real_labels, fake_images, fake_labels
        )
        # (5a) Produce a fresh batch of fake images from the generator.
        # (5b) Get the discriminator's predictions for those fake images.
        noise = torch.randn(batch_size, 100).to(device)
        fake_images = generator(noise)

        outputs = discriminator(fake_images)
        # Train the generator to make the discriminator predict "real".
        g_loss = train_generator(generator, outputs, real_labels)