<a href="https://colab.research.google.com/github/olream/GAN_Series/blob/main/DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

挂载谷歌云盘

In [1]:
# Mount Google Drive so the dataset, checkpoints and TensorBoard logs used below can
# live under /content/drive; mounting prompts for Google Drive authorization.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


数据集与训练结果（模型检查点、TensorBoard 日志）的存放路径

In [2]:
# Root folder on Google Drive for the MNIST download, checkpoints and TensorBoard logs.
# Keep the trailing slash: later cells build paths with plain string concatenation.
PATH = '/content/drive/MyDrive/ColabNotebooks/DCGAN/'

定义模型

In [3]:
import torch
import torch.nn as nn

In [4]:
class Discriminator(nn.Module):
  """DCGAN discriminator.

  Maps an image batch of shape (N, channels_img, 64, 64) to (N, 1, 1, 1)
  sigmoid scores in [0, 1] (probability that each image is real).

  Args:
    channels_img: number of channels in the input images.
    features_d: channel width of the first conv layer; each later stage doubles it.
  """

  def __init__(self, channels_img, features_d):
    super().__init__()
    width1, width2, width4, width8 = (features_d * k for k in (1, 2, 4, 8))
    self.disc = nn.Sequential(
        # 64x64 -> 32x32; the input layer has no BatchNorm
        nn.Conv2d(channels_img, width1, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2),
        self._block(width1, width2, 4, 2, 1),  # -> 16x16
        self._block(width2, width4, 4, 2, 1),  # -> 8x8
        self._block(width4, width8, 4, 2, 1),  # -> 4x4
        # 4x4 -> 1x1: a 4x4 kernel without padding covers the whole map
        nn.Conv2d(width8, 1, kernel_size=4, stride=2, padding=0),
        nn.Sigmoid(),
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Conv (bias disabled because BatchNorm follows) -> BatchNorm2d -> LeakyReLU(0.2)."""
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2),
    ]
    return nn.Sequential(*layers)

  def forward(self, x):
    return self.disc(x)

class Generator(nn.Module):
  """DCGAN generator.

  Maps latent noise of shape (N, z_dim, 1, 1) to images of shape
  (N, channels_img, 64, 64) with values in [-1, 1] (Tanh output).

  Args:
    z_dim: length of the latent noise vector.
    channels_img: number of channels in the generated images.
    features_g: base channel width; the first stage uses features_g * 16
      and each later stage halves it.
  """

  def __init__(self, z_dim, channels_img, features_g):
    super().__init__()
    stages = [
        # (in, out, kernel, stride, padding)            spatial size
        (z_dim, features_g * 16, 4, 1, 0),            # 1x1   -> 4x4
        (features_g * 16, features_g * 8, 4, 2, 1),   # 4x4   -> 8x8
        (features_g * 8, features_g * 4, 4, 2, 1),    # 8x8   -> 16x16
        (features_g * 4, features_g * 2, 4, 2, 1),    # 16x16 -> 32x32
    ]
    upsampling = [self._block(*spec) for spec in stages]
    self.gen = nn.Sequential(
        *upsampling,
        # 32x32 -> 64x64 image; the output layer has no BatchNorm
        nn.ConvTranspose2d(features_g * 2, channels_img, kernel_size=4, stride=2, padding=1),
        nn.Tanh(),
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Transposed conv (bias disabled because BatchNorm follows) -> BatchNorm2d -> ReLU."""
    return nn.Sequential(*[
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    ])

  def forward(self, x):
    return self.gen(x)


In [5]:
def initialize_weights(model):
  """Apply DCGAN-style initialization, in place, to every conv / batch-norm layer of `model`.

  - Conv2d / ConvTranspose2d weights ~ N(0, 0.02); their biases keep PyTorch's default init.
  - BatchNorm2d scale (gamma) ~ N(1, 0.02) and shift (beta) = 0.

  The BatchNorm scale is centred on 1, not 0: drawing gamma from N(0, 0.02)
  would shrink every normalized activation to ~2% of unit scale and randomly
  flip its sign, partly cancelling what BatchNorm is there to do.

  Args:
    model: any nn.Module; its submodules are visited via model.modules().
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.affine:
      # affine=False BatchNorm has no weight/bias to initialize
      nn.init.normal_(m.weight.data, 1.0, 0.02)
      nn.init.constant_(m.bias.data, 0.0)

In [6]:
def test():
  """Smoke-test both networks on random input and print 'success' if the shapes match.

  Expected: Discriminator (N, C, 64, 64) -> (N, 1, 1, 1);
            Generator (N, z_dim, 1, 1) -> (N, C, 64, 64).
  """
  batch, channels, height, width = 8, 3, 64, 64
  latent_dim = 100

  images = torch.randn((batch, channels, height, width))
  discriminator_net = Discriminator(channels, 8)
  initialize_weights(discriminator_net)
  assert discriminator_net(images).shape == (batch, 1, 1, 1), "Discriminator test failed"

  generator_net = Generator(latent_dim, channels, 64)
  initialize_weights(generator_net)
  latent = torch.randn((batch, latent_dim, 1, 1))
  assert generator_net(latent).shape == (batch, channels, height, width), "Generator test failed"

  print('success')

# test()  # uncomment to check the model output shapes

主流程：数据加载与模型训练

In [7]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [8]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
  print(torch.cuda.get_device_name())
else:
  device = torch.device('cpu')
  print('cpu only')

# Seed torch's RNGs (CPU and CUDA) so weight init, noise and DataLoader shuffling
# start from the same state on every run. cuDNN kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

# Training hyperparameters
LEARNING_RATE = 2e-4  # Adam learning rate, shared by both networks
BATCH_SIZE = 128
IMAGE_SIZE = 64       # MNIST digits are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS_IMG = 1      # MNIST is single-channel
Z_DIM = 100           # length of the generator's latent noise vector
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64
SAVE_MODEL_PATH = PATH + 'check_points'  # checkpoint directory on Drive

Tesla K80


In [9]:
# Download MNIST and build the preprocessing pipeline.
# The pipeline is named `transform` (not `transforms`) so it does not shadow the
# torchvision.transforms module; with the old name, re-running this cell failed
# with AttributeError because `transforms.Compose` no longer existed.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        # Map pixels from [0, 1] to [-1, 1], the same range as the generator's Tanh output.
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)
dataset = datasets.MNIST(root=PATH + 'Datasets', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [10]:
# Build both networks on the target device, then apply the DCGAN weight init
# (generator first, then discriminator).
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
for net in (gen, disc):
  initialize_weights(net)

# To resume training, load the latest full-model checkpoints instead.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True, which
# rejects pickled full modules; pass weights_only=False (trusted files only).
# gen = torch.load(os.path.join(SAVE_MODEL_PATH,f'gen_newest.pth'))
# disc = torch.load(os.path.join(SAVE_MODEL_PATH,f'disc_newest.pth'))

# Adam with beta1 = 0.5, shared by both optimizers.
ADAM_BETAS = (0.5, 0.999)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# A fixed latent batch: feeding the same z to the generator at every logging step
# makes successive sample grids directly comparable as training progresses.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)

# TensorBoard event files go under PATH + 'logs/all'.
writer = SummaryWriter(PATH + 'logs/all')


In [None]:
# Train the GAN. `step` counts TensorBoard logging events (one every 100 batches).
step = 0
# torch.save below fails if the checkpoint directory does not exist yet.
os.makedirs(SAVE_MODEL_PATH, exist_ok=True)
gen.train()
disc.train()
for epoch in range(NUM_EPOCHS):
  for batch_idx, (real, _) in enumerate(dataloader):
    real = real.to(device)
    # Size the noise batch to match the real batch: the last batch of an epoch can be
    # smaller than BATCH_SIZE (60000 % 128 = 96 images for MNIST train).
    cur_batch_size = real.size(0)
    noise = torch.randn((cur_batch_size, Z_DIM, 1, 1)).to(device)
    fake = gen(noise)

    ### Train the discriminator: max log(D(x)) + log(1 - D(G(z)))
    disc_real = disc(real).reshape(-1)  # N x 1 x 1 x 1 -> N
    loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach() keeps this backward pass out of the generator's graph, so no
    # retain_graph is needed and no generator gradients are computed here.
    disc_fake = disc(fake.detach()).reshape(-1)
    loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    loss_disc = (loss_disc_real + loss_disc_fake) / 2
    opt_disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    ### Train the generator: min log(1 - D(G(z))), implemented as max log(D(G(z)))
    output = disc(fake).reshape(-1)
    loss_gen = criterion(output, torch.ones_like(output))
    opt_gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    # Every 100 batches: save checkpoints, print losses, log image grids to TensorBoard.
    if batch_idx % 100 == 0:
      torch.save(gen, os.path.join(SAVE_MODEL_PATH, 'gen_newest.pth'))
      torch.save(disc, os.path.join(SAVE_MODEL_PATH, 'disc_newest.pth'))
      print(
          f'Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} '
          f'Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}'
      )

      with torch.no_grad():
        fake_fixed = gen(fixed_noise)
        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake_fixed[:32], normalize=True)

        writer.add_image('REAL', img_grid_real, global_step=step)
        writer.add_image('FAKE', img_grid_fake, global_step=step)
        # Flush instead of close: closing every 100 batches tears down and
        # re-creates the event-file writer.
        writer.flush()
      step += 1

writer.close()

In [None]:
# 模型效果可视化
%load_ext tensorboard
%tensorboard --logdir '/content/drive/MyDrive/ColabNotebooks/DCGAN/logs'