In [0]:
import argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.autograd as autograd

In [0]:
os.makedirs("images", exist_ok=True)
opt={
"epoch":190,
"n_epochs":400, 
"batch_size":64, 
"lr":0.0002, 
"b1":0.5, 
"b2":0.999, 
"n_cpu":8, 
"latent_dim":100, 
"img_size":64, 
"channels":3, 
"sample_interval":1,
"checkpoint_interval":20,
"dataset_name":'faces'
}
 



In [0]:
cuda = True if torch.cuda.is_available() else False

In [0]:
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [0]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        print('---',classname)
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [0]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = opt['img_size'] // 4
        self.l1 = nn.Sequential(nn.Linear(100, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt['channels'], 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

In [0]:
gen = Generator()

In [0]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(3, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt['img_size'] // 2 ** 4
        self.adv_layer =nn.Linear(128 * ds_size ** 2, 1)

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

In [0]:
d = Discriminator()

In [0]:
#summary(d,input_size=(3,96,96))
#summary(gen,input_size=(100,))

In [0]:
adversarial_loss = torch.nn.MSELoss()

In [0]:
generator = Generator()
discriminator = Discriminator()

In [0]:
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()


In [0]:
if opt['epoch'] != 0:
    # Load pretrained models
    generator.load_state_dict(torch.load("colab/saved_models/generator_%d.pth" % (opt['epoch'])))
    discriminator.load_state_dict(torch.load("colab/saved_models/discriminator_%d.pth" % (opt['epoch'])))
else:
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

In [0]:
class LambdaLR:
    #定义一个学习率衰减 offset 为当前初始epoch （一般为0）
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)

In [0]:
transforms_ = [
    #transforms.Resize(int(96 * 1.12),3),
    #transforms.RandomHorizontalFlip(),#随机水平翻转
    transforms.ToTensor(),#转化为tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

In [0]:
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [0]:
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [0]:
import glob
import random
import os
from torch.utils.data import Dataset
from PIL import Image

import matplotlib.pyplot as plt
import torchvision.transforms as transforms
def to_rgb(image):
    rgb_image = Image.new("RGB", image.size)
    rgb_image.paste(image)
    return rgb_image
class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None):
        self.transform = transforms.Compose(transforms_)#将多个transforms 组合起来使用
        self.files_A = sorted(glob.glob(root + "/*.*"))  
    def __getitem__(self, index):
        image_A = Image.open(self.files_A[index % len(self.files_A)])
        # Convert grayscale images to rgb
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        item_A = self.transform(image_A)
        return item_A
    def __len__(self):
        return len(self.files_A)


In [0]:
dataloader = DataLoader(
    ImageDataset("face_data/extra_data/images/", transforms_=transforms_),
    batch_size=opt['batch_size'],
    shuffle=True,
    num_workers=opt['n_cpu'],
)

In [58]:
!mkdir colab/saved_models
!mkdir colab/images

mkdir: cannot create directory ‘colab/saved_models’: File exists
mkdir: cannot create directory ‘colab/images’: File exists


In [0]:
import sys
import datetime
import time
prev_time = time.time()
for epoch in range(opt['epoch'], opt['n_epochs']):
    for i,imgs in enumerate(dataloader):

        # Adversarial ground truths
        #valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        valid = Variable(Tensor(np.ones((imgs.size(0),16))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((imgs.size(0),16))), requires_grad=False)
        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------
        
        

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], 100))))

        # Generate a batch of images
        gen_imgs = generator(z)
        if i % 5 == 0:
            optimizer_G.zero_grad()
        # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            g_loss.backward()
            optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, gen_imgs.data)
        d_loss = (real_loss + fake_loss) / 2 + gradient_penalty*10

        d_loss.backward()
        optimizer_D.step()
        
        batches_done = epoch * len(dataloader) + i
        batches_left = opt['n_epochs'] * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]  ETA: %s"
            % (epoch, opt['n_epochs'], i, len(dataloader), d_loss.item(), g_loss.item(),time_left,)
        )

        batches_done = epoch * len(dataloader) + i
    if epoch % opt['sample_interval'] == 0:
        
        save_image(gen_imgs.data[:9], "colab/images/%d.png" % epoch, nrow=3, normalize=True)
        img = Image.open('colab/images/%d.png'%epoch)
        plt.imshow(img)
    if opt.get('checkpoint_interval') != -1 and epoch+1 % opt.get('checkpoint_interval') == 0:
        # Save model checkpoints
        torch.save(generator.state_dict(), "colab/saved_models/generator_%d.pth" % (epoch))
        torch.save(discriminator.state_dict(), "colab/saved_models/discriminator_%d.pth" % (epoch))
        

  return F.mse_loss(input, target, reduction=self.reduction)


[Epoch 190/400] [Batch 574/575] [D loss: 0.099500] [G loss: 0.224289]  ETA: 0:44:50.698620

  return F.mse_loss(input, target, reduction=self.reduction)


[Epoch 228/400] [Batch 553/575] [D loss: 0.216985] [G loss: 0.208622]  ETA: 1:19:26.881086