# In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os

# In [5]:
# Notebook-wide hyperparameters.
# NOTE(review): z_dim / image_resolution presumably feed the Generator and
# Discriminator `z_dim` / `resolution` arguments, and batch_size / lr are
# not used in the visible cells — confirm against later cells.
batch_size, z_dim, lr, image_resolution = 64, 100, 1e-3, 28

# In [8]:
class Generator(nn.Module):
    """Fully connected GAN generator mapping noise vectors to images.

    A two-layer MLP (Linear -> ReLU -> Linear -> Tanh). Its flat output is
    reshaped into an image tensor, so every output value lies in [-1, 1].

    Args:
        z_dim: Length of the input noise vector (last dimension of ``z``).
        img_dim: Number of channels in each generated image.
        resolution: Height and width of the square generated image.
    """

    def __init__(self, z_dim=100, img_dim=1, resolution=28):
        super().__init__()
        self.z_dim, self.img_dim, self.resolution = z_dim, img_dim, resolution

        n_pixels = img_dim * resolution * resolution
        layers = [
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, n_pixels),
            nn.Tanh(),
        ]
        # The attribute name ``model`` and this layer order determine the
        # state_dict keys (model.0.*, model.2.*); keep both stable.
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise ``z`` (last dim ``z_dim``) to images.

        Returns a tensor of shape (batch, img_dim, resolution, resolution),
        where batch is inferred from the number of elements in the MLP output.
        """
        flat = self.model(z)
        image_shape = (self.img_dim, self.resolution, self.resolution)
        return flat.view(-1, *image_shape)
    
class Discriminator(nn.Module):
    """Fully connected GAN discriminator scoring flattened images.

    A two-layer MLP (Linear -> LeakyReLU(0.2) -> Linear -> Sigmoid) that
    outputs one score in (0, 1) per image. By GAN convention this is read as
    the probability that the input is real.

    Args:
        img_dim: Number of channels in each input image.
        resolution: Height and width of the square input image.
    """

    def __init__(self, img_dim = 1, resolution = 28) :
        super(Discriminator, self).__init__()
        self.img_dim = img_dim
        self.resolution = resolution

        self.model = nn.Sequential(
            nn.Linear(resolution * resolution * img_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images.

        ``x`` is flattened to (-1, img_dim * resolution * resolution), so its
        element count must be a multiple of that size. Returns a tensor of
        shape (batch, 1) with values in (0, 1).
        """
        n_features = self.img_dim * self.resolution * self.resolution
        # reshape instead of view: transposed, permuted or sliced image batches
        # are non-contiguous, and .view raises on them. reshape returns a view
        # when possible and copies only when it must.
        return self.model(x.reshape(-1, n_features))