In [None]:
# Extract the aligned CelebA faces from Drive into the GAN data folder.
# -q suppresses the per-file "inflating:" lines (one per image), which would
#    otherwise bloat the saved notebook output.
# -n never overwrites existing files, so re-running this cell skips files that are
#    already extracted instead of blocking on unzip's interactive overwrite prompt.
!unzip -q -n '/content/drive/MyDrive/apk/img_align_celeba.zip' -d '/content/drive/MyDrive/GAN/'

In [None]:
import glob
import matplotlib.pyplot as plt
import os
from PIL import Image

# Collect every extracted .jpg; sorting makes the preview identical on every run
# (glob returns files in arbitrary filesystem order).
imgs = sorted(glob.glob('/content/drive/MyDrive/GAN/img_align_celeba/*.jpg'))

print(len(imgs))

# Preview up to six images in a 2x3 grid; slicing avoids an IndexError when
# fewer than six files were found.
for i, path in enumerate(imgs[:6]):
    plt.subplot(2, 3, i + 1)
    # The context manager releases the file handle; convert() decodes the pixels
    # into a new in-memory image before the file is closed.
    with Image.open(path) as img:
        plt.imshow(img.convert('RGB'))
    plt.axis('off')
plt.show()

In [None]:
# Show how many .jpg files the glob matched, rendered as this cell's output.
num_images = len(imgs)
num_images

In [None]:
# Input image preprocessing
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data.dataloader import DataLoader

In [None]:
# Preprocessing pipeline:
#   1. Resize the shorter side to 64 px and centre-crop to 64x64.
#   2. Convert to a float tensor in [0, 1].
#   3. Apply Normalize(mean=0.5, std=0.5) per channel, mapping pixels to [-1, 1].
#      This matches the range of the generator's Tanh output.
# The result is bound to `transform`, not `transforms`. Rebinding `transforms` would
# shadow the torchvision module and make this cell fail when it is re-run.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder treats each sub-directory of `root` as one class. Presumably
# img_align_celeba/ is the only one, so every image gets class 0; the
# labels are not used for GAN training.
dataset = ImageFolder(
    root='/content/drive/MyDrive/GAN',
    transform=transform,
)

loader = DataLoader(dataset, batch_size=128, shuffle=True)

In [None]:
import numpy as np

# Pull one batch and display its 4th image as a sanity check of the pipeline.
# Use the builtin next(): DataLoader iterators in current PyTorch have no .next() method.
iterdata = iter(loader)
img, label = next(iterdata)

# The loader output was normalised to [-1, 1] by Normalize(mean=0.5, std=0.5).
# Map it back to [0, 1]; otherwise imshow clips the float data and the colours are wrong.
img = (img[3] * 0.5 + 0.5).clamp(0, 1).numpy()
# Tensor layout is (C, H, W); matplotlib expects (H, W, C).
plt.imshow(np.transpose(img, (1, 2, 0)))
plt.show()

In [None]:
# Generator

import torch.nn as nn

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to RGB images.

    The first transposed conv takes 100 input channels. With a (B, 100, 1, 1)
    input, that layer produces 4x4 maps. Each of the four following stride-2
    layers doubles the height and width: 4 -> 8 -> 16 -> 32 -> 64.
    The result is a (B, 3, 64, 64) batch squashed into [-1, 1] by the final Tanh.
    """

    def __init__(self):
        super().__init__()

        def up_block(in_ch, out_ch, **conv_kwargs):
            # Transposed conv -> BatchNorm -> ReLU; the conv bias is dropped because
            # the following BatchNorm re-centres the activations anyway.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, bias=False, **conv_kwargs),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        # Projection from the latent vector: stride 1, no padding (1x1 -> 4x4).
        layers = up_block(100, 512)
        # Three upsampling stages, each doubling H and W.
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            layers += up_block(in_ch, out_ch, stride=2, padding=1)
        # Output stage: 3 colour channels, no BatchNorm, Tanh into [-1, 1].
        layers += [
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]

        # Same layer order as a hand-written Sequential, so state_dict keys are gen.0 ... gen.13.
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Run the noise batch `x` through the upsampling stack."""
        return self.gen(x)

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring images as real (near 1) or fake (near 0).

    Four stride-2 convs shrink a 64x64 input: 64 -> 32 -> 16 -> 8 -> 4.
    A final 4x4 conv with no padding collapses the 4x4 maps to 1x1.
    A (B, 3, 64, 64) batch therefore yields (B, 1, 1, 1) probabilities from the Sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # NOTE(review): the DCGAN paper applies no BatchNorm to the discriminator's
        # input layer. It is kept here so the layer layout (and saved state_dict)
        # stays unchanged.
        self.disc = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 1, kernel_size=4),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return per-image real/fake probabilities for the image batch `x`.

        nn.Module requires forward() to be defined; without it, calling D(x)
        raises NotImplementedError.
        """
        return self.disc(x)

In [None]:
def weights_init(m):
    """Apply DCGAN-style initialisation to a single module; use via ``model.apply``.

    - Modules whose class name contains "Conv" (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02).
    - Modules whose class name contains "BatchNorm": weight ~ N(1, 0.02), bias = 0.
    - Any other module, or a layer without learnable affine parameters, is left untouched.

    Args:
        m: an ``nn.Module``; ``Module.apply`` calls this once for every sub-module.
    """
    # type(m).__name__ is the concrete layer class, e.g. "ConvTranspose2d".
    # Note: type(m).__class__.__name__ would be the metaclass name "type" and never match.
    classname = type(m).__name__

    # e.g. BatchNorm2d(affine=False) has weight/bias set to None; nothing to initialise.
    if getattr(m, 'weight', None) is None:
        return

    if 'Conv' in classname:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)

In [None]:
import tqdm
from torch.optim.adam import Adam

# Train on the GPU when CUDA is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Instantiate each network on `device` and run weights_init over every sub-module.
# Module.to() and Module.apply() both return the module itself, so construction,
# placement and initialisation chain into one expression.
G = Generator().to(device).apply(weights_init)
D = Discriminator().to(device).apply(weights_init)

# One Adam optimiser per network with identical hyper-parameters (lr 1e-4, beta1 0.5).
G_optim, D_optim = (
    Adam(net.parameters(), lr=0.0001, betas=(0.5, 0.999)) for net in (G, D)
)


In [None]:
# Adversarial training: each batch runs one discriminator update, then one
# generator update. The generator uses the non-saturating BCE objective.
criterion = nn.BCELoss()  # built once instead of once per loss term per batch

for epochs in range(20):  # number of training epochs
    iterator = tqdm.tqdm(enumerate(loader, 0), total=len(loader))

    for i, data in iterator:
        # ---- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0 ----
        D_optim.zero_grad()

        # One float target per image. The ImageFolder class labels in data[1] are
        # used only for their shape (batch size).
        label = torch.ones_like(data[1], dtype=torch.float32).to(device)
        label_fake = torch.zeros_like(data[1], dtype=torch.float32).to(device)

        # D returns (B, 1, 1, 1). view(-1) flattens that to (B,) to match the targets;
        # unlike squeeze(), it keeps a final batch of one image as shape (1,).
        real = D(data[0].to(device))
        Dloss_real = criterion(real.view(-1), label)
        Dloss_real.backward()

        noise = torch.randn(label.shape[0], 100, 1, 1, device=device)
        fake = G(noise)

        # detach(): the discriminator loss must not backpropagate into G.
        output = D(fake.detach())
        Dloss_fake = criterion(output.view(-1), label_fake)
        Dloss_fake.backward()

        Dloss = Dloss_real + Dloss_fake  # for logging only; gradients already accumulated
        D_optim.step()

        # ---- Generator step: make D classify the fakes as real (target 1) ----
        G_optim.zero_grad()
        output = D(fake)
        Gloss = criterion(output.view(-1), label)
        Gloss.backward()
        G_optim.step()

        # .item() logs plain numbers instead of full tensor reprs (device, grad_fn).
        iterator.set_description(
            f"epoch:{epochs} iteration:{i} D_loss:{Dloss.item():.4f} G_loss:{Gloss.item():.4f}"
        )

    # Checkpoint after every epoch; each save overwrites the previous files.
    torch.save(G.state_dict(), "Generator.pth")
    torch.save(D.state_dict(), "Discriminator.pth")