<a href="https://colab.research.google.com/github/frzlh/DeepLearning/blob/main/DCGAN_Adversarial_generating_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

准备数据

In [1]:
!wget http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
!tar xf 102flowers.tgz
!mkdir oxford-102
!mkdir g_state
!mkdir d_state
!mkdir save_img
!mkdir oxford-102/jpg
!mv jpg/*.jpg oxford-102/jpg


--2024-08-21 08:57:14--  http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
Resolving www.robots.ox.ac.uk (www.robots.ox.ac.uk)... 129.67.94.2
Connecting to www.robots.ox.ac.uk (www.robots.ox.ac.uk)|129.67.94.2|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz [following]
--2024-08-21 08:57:14--  https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
Connecting to www.robots.ox.ac.uk (www.robots.ox.ac.uk)|129.67.94.2|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://thor.robots.ox.ac.uk/flowers/102/102flowers.tgz [following]
--2024-08-21 08:57:14--  https://thor.robots.ox.ac.uk/flowers/102/102flowers.tgz
Resolving thor.robots.ox.ac.uk (thor.robots.ox.ac.uk)... 129.67.95.98
Connecting to thor.robots.ox.ac.uk (thor.robots.ox.ac.uk)|129.67.95.98|:443... connected.
HTTP request sent, awaiting response... 200 OK
Leng

数据预处理

In [32]:
import torch
from torch import nn,optim
from torch.utils.data import (Dataset,TensorDataset,DataLoader)
import tqdm
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
image_data=ImageFolder(
    "/content/oxford-102",
    transform=transforms.Compose([
        transforms.Resize(80),
        transforms.CenterCrop(64),
        transforms.ToTensor()
    ])
)
batch_size=64
img_loader=DataLoader(image_data,batch_size=batch_size,shuffle=True)


图像生成模型的构建

In [33]:
nz=100
ngf=32

class GNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.main=nn.Sequential(
        nn.ConvTranspose2d(nz,ngf*8,4,1,0,bias=False),
        nn.BatchNorm2d(ngf*8),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(ngf*8,ngf*4,4,2,1,bias=False),
        nn.BatchNorm2d(ngf*4),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(ngf*4,ngf*2,4,2,1,bias=False),
        nn.BatchNorm2d(ngf*2),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(ngf*2,ngf,4,2,1,bias=False),
        nn.BatchNorm2d(ngf),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(ngf,3,4,2,1,bias=False),
        nn.Tanh()

    )

  def forward(self,x):
    out=self.main(x)
    return out

图像识别模型的构建

In [34]:
ndf=32

class DNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.main=nn.Sequential(
        nn.Conv2d(3,ndf,4,2,1,bias=False),
        nn.LeakyReLU(0.2,inplace=True),
        nn.Conv2d(ndf,ndf*2,4,2,1,bias=False),
        nn.BatchNorm2d(ndf*2),
        nn.LeakyReLU(0.2,inplace=True),
        nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False),
        nn.BatchNorm2d(ndf*4),
        nn.LeakyReLU(0.2,inplace=True),
        nn.Conv2d(ndf*4,ndf*8,4,2,1,bias=False),
        nn.BatchNorm2d(ndf*8),
        nn.LeakyReLU(0.2,inplace=True),
        nn.Conv2d(ndf*8,1,4,1,0,bias=False)
    )
  def forward(self,x):
    out=self.main(x)
    return out.squeeze()


训练函数的构建

In [35]:
d=DNet().to("cuda:0")
g=GNet().to("cuda:0")
opt_d=optim.Adam(d.parameters(),lr=0.02,betas=(0.5,0.999))
opt_g=optim.Adam(g.parameters(),lr=0.02,betas=(0.5,0.999))

ones=torch.ones(batch_size).to("cuda:0")
zeros=torch.zeros(batch_size).to("cuda:0")
loss_f=nn.BCEWithLogitsLoss()
fixed_z=torch.randn(batch_size,nz,1,1).to("cuda:0")

训练函数

In [36]:
from statistics import mean
def train_dcgan(g,d,opt_g,opt_d,loader):
  log_loss_g=[]
  log_loss_d=[]
  for real_img,_ in tqdm.tqdm(loader):
    batch_len=len(real_img)
    real_img=real_img.to("cuda:0")
    z=torch.randn(batch_len,nz,1,1).to("cuda:0")
    fake_img=g(z)
    fake_img_tensor=fake_img.detach()
    out=d(fake_img)
    loss_g=loss_f(out,ones[:batch_len])
    log_loss_g.append(loss_g.item())
    d.zero_grad(),g.zero_grad()
    loss_g.backward()
    opt_g.step()
    real_out=d(real_img)
    loss_d_real=loss_f(real_out,ones[:batch_len])
    fake_img=fake_img_tensor
    fake_out=d(fake_img_tensor)
    loss_d_fake=loss_f(fake_out,zeros[:batch_len])
    loss_d=loss_d_real+loss_d_fake
    log_loss_d.append(loss_d.item())
    d.zero_grad,g.zero_grad()
    loss_d.backward()
    opt_d.step()

  return mean(log_loss_g),mean(log_loss_d)

DCGAN模型的训练

In [None]:
from IPython.display import Image,display_jpeg
for epoch in range(300):
  train_dcgan(g,d,opt_g,opt_d,img_loader)
  if epoch%10==0:
    torch.save(
        g.state_dict(),
        "g_state/g_{:03d}.prm".format(epoch),
        pickle_protocol=4
    )
    torch.save(
        d.state_dict(),
        "d_state/g_{:03d}.prm".format(epoch),
        pickle_protocol=4
    )
    generated_img=g(fixed_z)
    save_image(generated_img,"save_img/{:03d}.jpg".format(epoch))


100%|██████████| 256/256 [00:37<00:00,  6.84it/s]
100%|██████████| 256/256 [00:37<00:00,  6.89it/s]
100%|██████████| 256/256 [00:37<00:00,  6.83it/s]
100%|██████████| 256/256 [00:37<00:00,  6.88it/s]
100%|██████████| 256/256 [00:37<00:00,  6.87it/s]
100%|██████████| 256/256 [00:37<00:00,  6.90it/s]
100%|██████████| 256/256 [00:37<00:00,  6.85it/s]
100%|██████████| 256/256 [00:37<00:00,  6.89it/s]
100%|██████████| 256/256 [00:37<00:00,  6.85it/s]
100%|██████████| 256/256 [00:37<00:00,  6.86it/s]
100%|██████████| 256/256 [00:37<00:00,  6.81it/s]
100%|██████████| 256/256 [00:37<00:00,  6.85it/s]
100%|██████████| 256/256 [00:36<00:00,  6.97it/s]
100%|██████████| 256/256 [00:37<00:00,  6.91it/s]
100%|██████████| 256/256 [00:37<00:00,  6.90it/s]
100%|██████████| 256/256 [00:37<00:00,  6.89it/s]
100%|██████████| 256/256 [00:37<00:00,  6.83it/s]
100%|██████████| 256/256 [00:37<00:00,  6.84it/s]
100%|██████████| 256/256 [00:37<00:00,  6.77it/s]
100%|██████████| 256/256 [00:37<00:00,  6.77it/s]
