In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch 
from torchvision import transforms,datasets
import matplotlib.pyplot as plt 
from torch import nn
import numpy as np
import torchvision

In [31]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5,std=0.5)
])

train_ds = datasets.MNIST(
    'data',
    train = True,
    transform=transform,
    download=True
)

train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=16,
    shuffle=True
)

In [37]:
# 定义生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.linear1= nn.Linear(100,256*7*7)
        self.bn1 = nn.BatchNorm1d(256*7*7)
        self.deconv1 = nn.ConvTranspose2d(256,128,kernel_size=(3,3),padding=1)
            
        self.bn2 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128,64,kernel_size=(3,3),stride=2,padding=1)
            
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64,1,kernel_size=(4,4),stride=2,padding=1)
        
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
    def forward(self,x):
        #x = nn.ReLU(self.linear1(x))
        x = self.relu(self.linear1(x))
        x = self.bn1(x)
        x = x.view(-1,256,7,7)
        x = self.relu(self.deconv1(x))
        x = self.bn2(x)
        x = self.relu(self.deconv2(x))
        x = self.bn3(x)
        x = self.tanh(self.deconv3(x))
        return x

In [54]:
# 判别器
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator,self).__init__()
        self.conv1 = nn.Conv2d(1,64,3,2)
        self.conv2 = nn.Conv2d(64,128,3,2)
        self.bn = nn.BatchNorm2d(128)
        self.dropout = nn.Dropout2d(p=0.3)
        self.leaky = nn.LeakyReLU() 
        self.fc  = nn.Linear(128*6*6,1)
        self.sigmoid = nn.Sigmoid()
    def forward(self,x):
#         x = nn.Dropout2d(nn.LeakyReLU(self.conv1(x)),p=0.3)
#         x = nn.Dropout2d(nn.LeakyReLU(self.conv2(x)),p=0.3)
        x = self.conv1(x)
        x = self.leaky(x)
        x = self.dropout(x)
        
        x = self.conv2(x)
        x = self.leaky(x)
        x = self.dropout(x)
        x = self.bn(x)
        x = x.view(-1,256*7*7)
        x = self.fc(x)
        x = self.sigmoid(x)
#         x = nn.Sigmoid(self.fc(x))
        return x


In [55]:
gen = Generator()       # 实例化生成器
dis = Discriminator()   # 实例化判别器
loss_fn = nn.BCELoss()   # 定义损失函数
g_optimizer = torch.optim.Adam(gen.parameters(),lr=0.0001)  # 生成器优化器
d_optimizer = torch.optim.Adam(dis.parameters(),lr=0.0001)  # 判别器优化器
test_seed = torch.randn(16,100)

In [56]:
# 可视化
def generate_and_save_image(model,epoch,test_input):
    predictions = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4,4))  # 可视化
    for i in range(predictions.shape[0]):
        plt.subplot(4,4,i+1)
        plt.imshow((predictions[i]+1)/2,cmap='gray')
        plt.axis('off')
    plt.savefig('./dcgan_image/image_epoch_{}.png'.format(epoch))
    plt.show()

In [57]:
D_loss = []    # 记录训练过程中判别器的损失变化
G_loss = []    # 记录训练过程中生成器的损失变化
epochs = 20
# 训练环节
for epoch in range(epochs):
    D_epoch_loss = 0    # 用于累加一个epoch中的D的损失
    G_epoch_loss = 0    # 用于累加一个epoch中的G的损失
    count = len(train_dl)  # 共计的batch数量
    for step,(img,_) in enumerate(train_dl):
        size = img.shape[0]  # 获取bath大小
        random_seed = torch.randn(size,100)
        
        #######判别器损失值#######
        # 第一部分 真实图片的判定
        d_optimizer.zero_grad()   # 第 1-1 步 梯度归零 
        real_output = dis(img)    # 第 1-2 步 将原图传入判别器
        d_real_loss = loss_fn(real_output,torch.ones_like(real_output)) # 第1-3 获取损失值（真实图片与真实图片） 
        d_real_loss.backward()    # 第 1-4 步 误差向后传递
        
        # 第二部分 判定生成图片
        generated_img = gen(random_seed)           # 第 2-1 生成器输入随机张量得到生成图片
        fake_output = dis(generated_img.detach())  # 第 2-2 判别器输入生成图像
        d_fake_loss = loss_fn(fake_output,torch.zeros_like(fake_output))
        d_fake_loss.backward()
        
        disc_loss = d_real_loss + d_fake_loss
        d_optimizer.step() 
        #######判别器损失值end#######
        
        #######生成器损失值#######
        g_optimizer.zero_grad()
        fake_output = dis(generated_img)
        gen_loss = loss_fn(fake_output,torch.ones_like(fake_output))
        gen_loss.backward()
        g_optimizer.step()
        #######生成器损失值end#######
        
        # 记录器
        with torch.no_grad():
            D_epoch_loss += disc_loss
            G_epoch_loss += gen_loss
        
    with torch.no_grad():
        D_epoch_loss /= count
        G_epoch_loss /= count
        D_loss.append(D_epoch_loss)
        G_loss.append(G_epoch_loss)
        print('Epoch:',epoch)
        generate_and_save_image(gen,epoch,test_seed)

RuntimeError: shape '[-1, 12544]' is invalid for input of size 73728