In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
torch.manual_seed(42)

<torch._C.Generator at 0x1e79cc81110>

In [3]:
# 超参数
batch_size = 64
z_dim = 100
lr = 1e-3
image_resolution = 28

In [4]:
# 生成器模型
class Generator(nn.Module):
    def __init__(self, z_dim = 100, img_dim =1, resolution = 28) :
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.img_dim = img_dim
        self.resolution = resolution
        
        self.model = nn.Sequential(
            nn.Linear(self.z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, resolution * resolution * img_dim),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.model(z).view(-1, self.img_dim, self.resolution, self.resolution)

#判别器模型 
class Discriminator(nn.Module):
    def __init__(self, img_dim = 1, resolution = 28):
        super(Discriminator, self).__init__()
        self.img_dim = img_dim
        self.resolution = resolution
        
        self.model = nn.Sequential(
            nn.Linear(resolution * resolution * img_dim, 256),
            nn.ReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img):
        return self.model(img.view(-1, self.img_dim * self.resolution * self.resolution))

In [5]:
# 模型
generator = Generator(z_dim, img_dim = 1, resolution=  image_resolution).to("cuda")
discriminator = Discriminator(img_dim = 1, resolution = image_resolution).to("cuda")

# 优化器
optimizer_G = optim.Adam(generator.parameters(), lr = lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr = lr)

# 损失函数  
criterion = nn.BCELoss()

# 转换以将图像标准化为 (-1, 1)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# 加载数据集
train_loader = datasets.FashionMNIST(root = '../data', train = False, download = True, transform = transform)

In [8]:
num_epochs = 10
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(train_loader):
        batch_size = real_images.size(0)
        
        discriminator.zero_grad()
        real_images = real_images.view(-1, image_resolution * image_resolution).to("cuda")
        real_labels = Variable(torch.ones(batch_size, 1)).to("cuda")
        fake_labels = Variable(torch.zeros(batch_size, 1)).to("cuda")
        
        # 判别器对真实图像的前向传播
        output_real = discriminator(real_images)
        loss_real = criterion(output_real, real_labels.to("cuda"))
        loss_real.backward()
        
        # 判别器对生成的图像的前向传播
        z = Variable(torch.randn(batch_size, z_dim)).to("cuda")
        fake_images = generator(z.to("cuda"))
        output_fake = discriminator(fake_images.detach().view(-1, image_resolution * image_resolution))
        
        optimizer_D.step()
        
        # 训练生成器
        generator.zero_grad()
        output_fake = discriminator(fake_images.view(-1, image_resolution * image_resolution))
        loss_G = criterion(output_fake, real_labels)
        loss_G.backward()
        optimizer_G.step()
        
        if i % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'.format(
                epoch, num_epochs, i, len(train_loader), loss_real.item(), loss_G.item()))
        
        
        with torch.no_grad():
            z = Variable(torch.randn(16, z_dim))
            generated_images = generator(z).detach().numpy()
            
            fig = plt.figure(figsize = (4, 4))
            for i in range(generated_images.shape[0]):
                plt.subplot(4, 4, i + 1)
                plt.imshow(generated_images[i][0], cmap = 'gray')
                plt.axis('off')
                
            if not os.path.exists('out'):
                os.makedirs('out')
            plt.savefig(f'out/epoch_{epoch}.png')
            plt.close(fig)

Epoch [0/10], Step [0/10000], d_loss: 0.0694, g_loss: 0.6120


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_addmm)