<a href="https://colab.research.google.com/github/gkrry2723/ML_pytorch_study/blob/master/5_DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. 데이터 준비
- 꽃 데이터: Oxford 102 Flowers 데이터셋
- 102종의 꽃 이미지 약 8,000장으로 구성

In [None]:
# Download the Oxford 102 Flowers archive (the http URL only redirects here).
!wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
!tar xf 102flowers.tgz
# -p makes the cell safe to re-run when the directory already exists.
!mkdir -p oxford-102/jpg
# ImageFolder expects <root>/<class>/<image>; "jpg" becomes the single class.
!mv jpg/*.jpg oxford-102/jpg

--2021-05-14 08:39:59--  http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
Resolving www.robots.ox.ac.uk (www.robots.ox.ac.uk)... 129.67.94.2
Connecting to www.robots.ox.ac.uk (www.robots.ox.ac.uk)|129.67.94.2|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz [following]
--2021-05-14 08:39:59--  https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
Connecting to www.robots.ox.ac.uk (www.robots.ox.ac.uk)|129.67.94.2|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 344862509 (329M) [application/x-gzip]
Saving to: ‘102flowers.tgz’


2021-05-14 08:40:15 (20.8 MB/s) - ‘102flowers.tgz’ saved [344862509/344862509]



2. DataLoader 만들기

In [None]:
import torch
from torch import nn,optim
from torch.utils.data import (Dataset,DataLoader,TensorDataset)
from torchvision.datasets import ImageFolder
from torchvision import transforms

# Preprocessing: Resize(80) scales the shorter side to 80 px, CenterCrop(64)
# keeps the central 64x64 patch, ToTensor converts to a float tensor in [0, 1].
# ImageFolder treats each sub-directory as a class; the only one here is jpg/,
# so every image gets the same label (the label is ignored during training).
# NOTE(review): GNet ends in Tanh (range [-1, 1]) while these real images stay
# in [0, 1]. The usual DCGAN recipe appends
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)); if that is added,
# generated images must be mapped back with (x + 1) / 2 before saving.
img_data = ImageFolder(
    root="/content/oxford-102/",
    transform=transforms.Compose([
        transforms.Resize(80),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
    ]),
)

batch_size = 64  # images per mini-batch
img_loader = DataLoader(dataset=img_data, batch_size=batch_size, shuffle=True)


3. 네트워크(모델) 만들기

- 생성 모델 (Generator): 잠재 벡터 $z$ (100차원) → 이미지 $3 \times 64 \times 64$
- 식별 모델 (Discriminator): 이미지 $3 \times 64 \times 64$ → 스칼라 1개 (진짜/가짜 판별 logit)

In [None]:
nz = 100   # length of the latent noise vector z
ngf = 32   # base channel count of the generator


# Generator: noise z of shape (B, nz, 1, 1) -> image of shape (B, 3, 64, 64)
class GNet(nn.Module):
  """DCGAN generator built from transposed convolutions.

  The first 4x4 transposed conv (stride 1, no padding) turns the 1x1 input
  into a 4x4 map. Each following stride-2, padding-1 layer doubles the
  spatial size (4 -> 8 -> 16 -> 32 -> 64). Every hidden stage is
  ConvTranspose2d -> BatchNorm2d -> ReLU. The output passes through Tanh, so
  pixel values lie in [-1, 1].
  """

  def __init__(self):
    super().__init__()
    # Channel widths of the four hidden upsampling stages, input width first.
    widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]
    layers = []
    for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
      # Only the first stage uses stride 1 / padding 0 (1x1 -> 4x4).
      stride, padding = (1, 0) if stage == 0 else (2, 1)
      layers.extend([
          nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
          nn.BatchNorm2d(c_out),
          nn.ReLU(inplace=True),
      ])
    # Final stage: ngf channels -> 3 RGB channels, squashed by Tanh.
    layers.extend([nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False), nn.Tanh()])
    self.main = nn.Sequential(*layers)

  def forward(self, x):
    return self.main(x)


# Discriminator: image of shape (B, 3, 64, 64) -> one real/fake logit per image
ndf = 32  # base channel count of the discriminator
class DNet(nn.Module):
  """DCGAN discriminator built from strided convolutions.

  Four stride-2, padding-1 convolutions shrink a 64x64 input to 4x4
  (64 -> 32 -> 16 -> 8 -> 4) while widening the channels. They use
  LeakyReLU(0.2) activations and BatchNorm on every stage except the first.
  A final 4x4 convolution without padding collapses the 4x4 map to a single
  value. No sigmoid is applied, so the outputs are raw logits.
  """
  def __init__(self):
    super().__init__()
    self.main = nn.Sequential(
        nn.Conv2d(3,ndf,4,2,1,bias=False),
        nn.LeakyReLU(0.2,inplace=True),

        nn.Conv2d(ndf,ndf*2,4,2,1,bias=False),
        nn.BatchNorm2d(ndf*2),
        nn.LeakyReLU(0.2,inplace=True),

        nn.Conv2d(ndf*2,ndf*4,4,2,1,bias=False),
        nn.BatchNorm2d(ndf*4),
        nn.LeakyReLU(0.2,inplace=True),

        nn.Conv2d(ndf*4,ndf*8,4,2,1,bias=False),
        nn.BatchNorm2d(ndf*8),
        nn.LeakyReLU(0.2,inplace=True),

        nn.Conv2d(ndf*8, 1, 4,1,0,bias=False)
    )

  def forward(self,x):
    """Return a 1-D tensor of logits, one per image in the batch ``x``."""
    out = self.main(x)
    # Collapse (B, 1, 1, 1) to (B,). A bare .squeeze() would also drop the
    # batch axis when B == 1 and return a 0-d tensor. That tensor no longer
    # matches a (B,) target in the loss.
    return out.flatten(1).squeeze(1)

4. 훈련 함수 작성


In [None]:
# Build both networks on the first GPU (discriminator first, then generator).
d = DNet().to("cuda:0")
g = GNet().to("cuda:0")

# Adam with lr=2e-4 and betas=(0.5, 0.999) for both networks.
opt_d = optim.Adam(d.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_g = optim.Adam(g.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
loss_f = nn.BCEWithLogitsLoss()
# Constant "real" (1) and "fake" (0) target vectors, one entry per image of a
# full batch.
ones = torch.ones(batch_size, device="cuda:0")
zeros = torch.zeros(batch_size, device="cuda:0")

# Fixed monitoring noise: generating from the same z every time makes the
# saved samples comparable across epochs.
fixed_z = torch.randn(batch_size, nz, 1, 1).to("cuda:0")

# 훈련함수
from statistics import mean
import tqdm

def train_dcgan(g, d, opt_g, opt_d, loader):
  """Run one epoch of DCGAN training over ``loader``.

  For every batch the generator is updated first, trying to make ``d`` label
  its samples as real. The discriminator is then updated on the real batch
  plus the same generated batch, detached from ``g``.

  Args:
    g: generator; called as g(z) with z of shape (B, nz, 1, 1).
    d: discriminator; returns one logit per image.
    opt_g: optimizer over ``g``'s parameters.
    opt_d: optimizer over ``d``'s parameters.
    loader: iterable of ``(images, labels)`` batches; labels are ignored.

  Returns:
    (mean generator loss, mean discriminator loss) over the epoch's batches.
  """
  # Train on whichever device the models live on instead of hard-coding one.
  device = next(d.parameters()).device
  log_loss_g = []
  log_loss_d = []
  for real_img, _ in tqdm.tqdm(loader):
    batch_len = len(real_img)
    real_img = real_img.to(device)

    # Generate a batch of fake images from random noise.
    z = torch.randn(batch_len, nz, 1, 1).to(device)
    fake_img = g(z)
    # Detached copy for the discriminator step, so loss_d does not
    # backpropagate into g.
    fake_img_detached = fake_img.detach()

    # --- Generator update ---
    # Cross-entropy of d's verdict on the fakes against "real" (1). It
    # decreases as the fakes become harder to tell apart from real images.
    out = d(fake_img)
    loss_g = loss_f(out, torch.ones_like(out))
    log_loss_g.append(loss_g.item())

    # loss_g.backward() also fills d's .grad; only opt_g steps here, and d's
    # gradients are cleared again before the discriminator update.
    d.zero_grad()
    g.zero_grad()
    loss_g.backward()
    opt_g.step()

    # --- Discriminator update ---
    real_out = d(real_img)
    loss_d_real = loss_f(real_out, torch.ones_like(real_out))
    fake_out = d(fake_img_detached)
    loss_d_fake = loss_f(fake_out, torch.zeros_like(fake_out))
    # The discriminator loss is the sum of its real and fake losses.
    loss_d = loss_d_real + loss_d_fake
    log_loss_d.append(loss_d.item())

    d.zero_grad()
    g.zero_grad()
    loss_d.backward()
    opt_d.step()

  return mean(log_loss_g), mean(log_loss_d)


5. 훈련하기

In [None]:
from torchvision.utils import save_image

num_epochs = 300
for epoch in range(num_epochs):
  mean_loss_g, mean_loss_d = train_dcgan(g, d, opt_g, opt_d, img_loader)

  # Checkpoint every 10 epochs, and always after the final epoch. Otherwise
  # the fully trained weights would be lost: the last epoch index is 299,
  # which is not a multiple of 10.
  if epoch % 10 == 0 or epoch == num_epochs - 1:
    print("epoch {:03d}: loss_g={:.4f} loss_d={:.4f}".format(
        epoch, mean_loss_g, mean_loss_d))

    # Save the parameters of both networks.
    torch.save(
        g.state_dict(),
        "/content/g_{:03d}.prm".format(epoch),
        pickle_protocol=4)

    torch.save(
        d.state_dict(),
        "/content/d_{:03d}.prm".format(epoch),
        pickle_protocol=4)

    # Save images generated from the fixed monitoring noise. no_grad avoids
    # building an autograd graph that is never backpropagated.
    with torch.no_grad():
      generated_img = g(fixed_z)
    save_image(generated_img, "/content/{:03d}.jpg".format(epoch))


100%|██████████| 128/128 [00:54<00:00,  2.36it/s]
100%|██████████| 128/128 [00:53<00:00,  2.39it/s]
100%|██████████| 128/128 [00:53<00:00,  2.38it/s]
100%|██████████| 128/128 [00:53<00:00,  2.38it/s]
100%|██████████| 128/128 [00:53<00:00,  2.39it/s]
100%|██████████| 128/128 [00:53<00:00,  2.38it/s]
100%|██████████| 128/128 [00:53<00:00,  2.40it/s]
100%|██████████| 128/128 [00:53<00:00,  2.38it/s]
100%|██████████| 128/128 [00:53<00:00,  2.38it/s]
100%|██████████| 128/128 [00:53<00:00,  2.41it/s]
100%|██████████| 128/128 [00:53<00:00,  2.40it/s]
100%|██████████| 128/128 [00:53<00:00,  2.41it/s]
100%|██████████| 128/128 [00:53<00:00,  2.40it/s]
100%|██████████| 128/128 [00:53<00:00,  2.38it/s]
100%|██████████| 128/128 [00:53<00:00,  2.39it/s]
100%|██████████| 128/128 [00:53<00:00,  2.38it/s]
100%|██████████| 128/128 [00:53<00:00,  2.38it/s]
100%|██████████| 128/128 [00:53<00:00,  2.41it/s]
100%|██████████| 128/128 [00:53<00:00,  2.39it/s]
100%|██████████| 128/128 [00:53<00:00,  2.38it/s]


In [None]:
from IPython.display import Image,display_jpeg
# Show the sample grid saved at the first checkpoint (epoch 000).
first_checkpoint_img = Image('/content/000.jpg')
display_jpeg(first_checkpoint_img)