In [None]:
# Mount Google Drive (Colab asks for authorization on first run), then put the
# project folder on sys.path so the notebook can import the project's own .py
# modules (e.g. `model`).
import sys
from google.colab import drive

drive.mount('/content/drive')
sys.path.append('/content/drive/MyDrive/ColabNotebooks/DCGAN')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights

In [None]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
  print(torch.cuda.get_device_name())
else:
  device = torch.device('cpu')
  print('cpu only')

# Training hyperparameters
LEARNING_RATE = 2e-4
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 1      # MNIST images are single-channel (grayscale)
Z_DIM = 100           # length of the generator's latent noise vector
NUM_EPOCHS = 5
FEATURES_DISC = 64    # presumably the base channel width of the discriminator — see model.py
FEATURES_GEN = 64     # presumably the base channel width of the generator — see model.py
SAVE_MODEL_PATH = '/content/drive/MyDrive/ColabNotebooks/DCGAN/check_points'

# Seed PyTorch's RNGs (CPU and all CUDA devices) so that weight initialization,
# latent noise and DataLoader shuffling repeat across runs. cuDNN kernels may
# still be nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED)

Tesla K80


In [None]:
# Download MNIST and build the shuffled training DataLoader.
# Each image is resized with Resize(IMAGE_SIZE). ToTensor then gives values in
# [0, 1], and Normalize with mean 0.5 / std 0.5 per channel maps them to [-1, 1].
# The pipeline is bound to `transform`, not `transforms`. Rebinding `transforms`
# would shadow the torchvision.transforms module and break any re-run of this
# cell with AttributeError.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ]
)
dataset = datasets.MNIST(root='/content/drive/MyDrive/ColabNotebooks/DCGAN/Datasets', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Build both networks on the target device and apply the project's custom
# weight initialization (initialize_weights comes from model.py). Construction
# and initialization order is kept fixed because both draw from the RNG.
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

# To resume training, uncomment these lines to replace the fresh networks with
# the last checkpoints. They must stay ABOVE the optimizers so that the
# optimizers track the loaded parameters.
# NOTE(review): the checkpoints are pickled nn.Module objects. torch.load
# unpickles them, so only load files you trust. Recent PyTorch releases default
# to weights_only=True, which may reject full-module pickles — confirm for the
# installed version.
# gen = torch.load(os.path.join(SAVE_MODEL_PATH, 'gen_newest.pth'))
# disc = torch.load(os.path.join(SAVE_MODEL_PATH, 'disc_newest.pth'))

# One Adam optimizer per network, both with betas=(0.5, 0.999).
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

# Binary cross-entropy. nn.BCELoss expects probabilities in [0, 1], so this
# assumes the Discriminator ends in a sigmoid — confirm in model.py.
criterion = nn.BCELoss()

# A fixed batch of 32 latent vectors, fed to the generator at every logging
# step so that samples from the same z can be compared as training progresses.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)

# TensorBoard writer; event files are written under logs/all.
writer = SummaryWriter('/content/drive/MyDrive/ColabNotebooks/DCGAN/logs/all')


In [None]:
# Train the DCGAN. Every 100 batches: checkpoint both networks, print the
# current losses, and log real / generated image grids to TensorBoard.
os.makedirs(SAVE_MODEL_PATH, exist_ok=True)  # torch.save does not create missing directories
step = 0  # TensorBoard global step; advances once per logging event
gen.train()
disc.train()
for epoch in range(NUM_EPOCHS):
  for batch_idx, (real, _) in enumerate(dataloader):
    # Real images
    real = real.to(device)
    # Generated images. Size the noise from the real batch, because the final
    # batch of an epoch can be smaller than BATCH_SIZE (the DataLoader keeps
    # the last partial batch).
    cur_batch_size = real.shape[0]
    noise = torch.randn((cur_batch_size, Z_DIM, 1, 1)).to(device)
    fake = gen(noise)

    ### Train the discriminator: maximize log(D(x)) + log(1 - D(G(z)))
    disc_real = disc(real).reshape(-1)  # N x 1 x 1 x 1 -> N
    loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach(): this loss should only update D, so do not backprop into G.
    # This also leaves G's graph intact for the generator step below, so
    # retain_graph is no longer needed.
    disc_fake = disc(fake.detach()).reshape(-1)
    loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    loss_disc = (loss_disc_real + loss_disc_fake) / 2
    opt_disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    ### Train the generator: min log(1 - D(G(z))), equivalently max log(D(G(z)))
    # Re-run D (now updated) on the non-detached fakes so gradients reach G.
    output = disc(fake).reshape(-1)
    loss_gen = criterion(output, torch.ones_like(output))
    opt_gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    # Periodically checkpoint, print losses and log image grids to TensorBoard.
    if batch_idx % 100 == 0:
      # Whole-module pickles, matching the torch.load(...) resume code in the
      # network-initialization cell.
      torch.save(gen, os.path.join(SAVE_MODEL_PATH, 'gen_newest.pth'))
      torch.save(disc, os.path.join(SAVE_MODEL_PATH, 'disc_newest.pth'))
      print(
          f'Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} '
          f'Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}'
      )

      with torch.no_grad():
        fake_fixed = gen(fixed_noise)
        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake_fixed[:32], normalize=True)
        writer.add_image('REAL', img_grid_real, global_step=step)
        writer.add_image('FAKE', img_grid_fake, global_step=step)
      # flush (not close) so a single event file is used for the whole run.
      writer.flush()
      step += 1

# Close the writer once training has finished; this flushes pending events.
writer.close()

Epoch [0/5] Batch 0/469           Loss D: 0.6404, Loss G: 0.8267
Epoch [0/5] Batch 100/469           Loss D: 0.0148, Loss G: 4.1180
Epoch [0/5] Batch 200/469           Loss D: 0.1823, Loss G: 0.0250
Epoch [0/5] Batch 300/469           Loss D: 0.4846, Loss G: 1.2697
Epoch [0/5] Batch 400/469           Loss D: 0.5625, Loss G: 1.2188
Epoch [1/5] Batch 0/469           Loss D: 0.5862, Loss G: 1.0467
Epoch [1/5] Batch 100/469           Loss D: 0.6174, Loss G: 0.9048
Epoch [1/5] Batch 200/469           Loss D: 0.5708, Loss G: 0.9373
Epoch [1/5] Batch 300/469           Loss D: 0.6005, Loss G: 1.0897


In [None]:
# Sanity check for the flattening used in the training loop: a (32, 1, 1, 1)
# tensor (the N x 1 x 1 x 1 shape noted there for the discriminator output)
# becomes a flat (32,) vector under reshape(-1).
import torch

a = torch.randn((32, 1, 1, 1))
b = a.reshape(-1)
print(a.shape)
print(b.shape)