<a href="https://colab.research.google.com/github/olream/GAN_Series/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Print the NVIDIA driver / GPU status of the current runtime.
!nvidia-smi

Sun Apr 17 13:54:08 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   47C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import sys

# Directory holding the project's own modules (config, dataset, utils, ...).
# NOTE(review): assumes Google Drive is already mounted at /content/drive - confirm.
PIX2PIX_DIR = '/content/drive/MyDrive/ColabNotebooks/pix2pix'

# Only add the directory once so that re-running this cell does not keep
# appending duplicate entries to sys.path.
if PIX2PIX_DIR not in sys.path:
  sys.path.append(PIX2PIX_DIR)

In [3]:
import torch
from utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
import config
from dataset import MapDataset
from generator_model import Generator
from discriminator_model import Discriminator
from torch.utils.data import DataLoader
from tqdm import tqdm

In [4]:
def train_fn(disc, gen, loader, opt_disc, opt_gen, l1, bce):
  """Run one pass over ``loader``, updating the discriminator and then the generator on every batch.

  Args:
    disc: Discriminator; called as ``disc(x, y)``.
    gen: Generator; called as ``gen(x)``.
    loader: Iterable yielding ``(x, y)`` batches; it is wrapped in a tqdm progress bar.
    opt_disc: Optimizer stepped after the discriminator loss is back-propagated.
    opt_gen: Optimizer stepped after the generator loss is back-propagated.
    l1: Loss called as ``l1(y_fake, y)``; its result is multiplied by ``config.L1_LAMBDA``.
    bce: Loss called as ``bce(disc_output, target)`` with all-ones / all-zeros targets.
      NOTE(review): if this is ``nn.BCELoss`` the discriminator presumably outputs
      probabilities in [0, 1] - confirm against the discriminator model.

  Returns:
    None. ``disc``, ``gen`` and both optimizers are updated in place.
  """
  progress = tqdm(loader, leave=True)
  for batch in progress:
    source, target = batch
    source = source.to(config.DEVICE)
    target = target.to(config.DEVICE)

    # ---- Train the discriminator ----
    generated = gen(source)
    real_score = disc(source, target)
    # detach() separates the fake image from the computation graph, so the
    # discriminator loss does not back-propagate into the generator.
    fake_score = disc(source, generated.detach())
    loss_on_real = bce(real_score, torch.ones_like(real_score))
    loss_on_fake = bce(fake_score, torch.zeros_like(fake_score))
    disc_loss = (loss_on_real + loss_on_fake) / 2

    disc.zero_grad()
    disc_loss.backward()
    opt_disc.step()

    # ---- Train the generator ----
    # Score the fake image again, this time without detach(), so gradients
    # reach the generator.
    fake_score = disc(source, generated)
    adversarial_loss = bce(fake_score, torch.ones_like(fake_score))
    weighted_l1 = l1(generated, target) * config.L1_LAMBDA
    gen_loss = adversarial_loss + weighted_l1

    gen.zero_grad()
    gen_loss.backward()
    opt_gen.step()


In [5]:
def main():
  """Build the models, optionally resume from checkpoints, and run the training loop.

  Reads its settings from ``config`` (DEVICE, LEARNING_RATE, LOAD_MODEL,
  SAVE_MODEL, CHECKPOINT_GEN, CHECKPOINT_DISC, TRAIN_DIR, VAL_DIR, BATCH_SIZE,
  NUM_WORKERS, NUM_EPOCHS).

  Checkpoints are loaded from the same Drive files they are saved to, so a run
  with ``config.LOAD_MODEL`` enabled resumes from what an earlier run wrote.

  Returns:
    None.
  """
  # Drive locations used by this notebook.
  drive_root = '/content/drive/MyDrive/ColabNotebooks/'
  project_dir = drive_root + 'pix2pix/'
  # Build each checkpoint path once and use it for both loading and saving.
  gen_checkpoint = project_dir + config.CHECKPOINT_GEN
  disc_checkpoint = project_dir + config.CHECKPOINT_DISC

  disc = Discriminator(in_channels=3).to(config.DEVICE)
  gen = Generator(in_channels=3).to(config.DEVICE)
  opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
  opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
  BCE = nn.BCELoss()
  L1_LOSS = nn.L1Loss()

  if config.LOAD_MODEL:
    load_checkpoint(gen_checkpoint, gen, opt_gen, config.LEARNING_RATE)
    load_checkpoint(disc_checkpoint, disc, opt_disc, config.LEARNING_RATE)

  train_dataset = MapDataset(root_dir=drive_root + config.TRAIN_DIR)
  train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS)
  val_dataset = MapDataset(root_dir=drive_root + config.VAL_DIR)
  # Validation batches are single, unshuffled samples.
  val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

  for epoch in range(config.NUM_EPOCHS):
    train_fn(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE)

    # Every 10th epoch (0, 10, 20, ...): write both checkpoints and call
    # save_some_examples with the validation loader.
    if config.SAVE_MODEL and epoch % 10 == 0:
      save_checkpoint(gen, opt_gen, filename=gen_checkpoint)
      save_checkpoint(disc, opt_disc, filename=disc_checkpoint)
      save_some_examples(gen, val_loader, epoch, folder=project_dir + 'evaluation')


In [None]:
# Start training only when this code is executed as the top-level program.
if __name__ == '__main__':
  main()

['105.jpg', '90.jpg', '92.jpg', '107.jpg', '95.jpg', '110.jpg', '97.jpg', '112.jpg', '98.jpg', '114.jpg', '99.jpg', '116.jpg', '118.jpg', '119.jpg', '120.jpg', '121.jpg', '122.jpg', '123.jpg', '124.jpg', '125.jpg', '126.jpg', '127.jpg', '128.jpg', '129.jpg', '130.jpg', '131.jpg', '132.jpg', '133.jpg', '101.jpg', '134.jpg', '104.jpg', '136.jpg', '106.jpg', '138.jpg', '108.jpg', '140.jpg', '142.jpg', '109.jpg', '111.jpg', '144.jpg', '146.jpg', '147.jpg', '113.jpg', '148.jpg', '150.jpg', '151.jpg', '115.jpg', '152.jpg', '117.jpg', '154.jpg', '135.jpg', '156.jpg', '137.jpg', '158.jpg', '139.jpg', '160.jpg', '141.jpg', '162.jpg', '143.jpg', '164.jpg', '145.jpg', '149.jpg', '167.jpg', '153.jpg', '169.jpg', '155.jpg', '171.jpg', '157.jpg', '173.jpg', '159.jpg', '175.jpg', '161.jpg', '177.jpg', '163.jpg', '179.jpg', '165.jpg', '166.jpg', '168.jpg', '170.jpg', '172.jpg', '181.jpg', '174.jpg', '187.jpg', '176.jpg', '189.jpg', '178.jpg', '191.jpg', '180.jpg', '193.jpg', '194.jpg', '182.jpg', '196

100%|██████████| 69/69 [00:49<00:00,  1.39it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:50<00:00,  1.36it/s]
100%|██████████| 69/69 [00:52<00:00,  1.33it/s]
100%|██████████| 69/69 [00:53<00:00,  1.30it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:53<00:00,  1.28it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
100%|██████████| 69/69 [00:53<00:00,  1.29it/s]
  0%|          | 0/69 [00:00<?, ?it/s]