In [1]:
import matplotlib.pyplot as plt 
import numpy as np
import os
from matplotlib.image import imread, imsave
from PIL import Image

import torch                                                
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

### 0) Hyperparmeters

In [2]:
#num_workers = 1
batch_size = 128     # NCHW 중 N값
nc = 3  # the number of output channels
nz = 100  # size of z latent vector(랜덤개수) (i.e., size of the generator input)
ngf = 64  # size of feature maps in generator
ndf = 64  # size of feature maps in discriminator
epochs = 50     #
lr = 0.0002     #업데이트 간격
beta1 = 0.5     # optima에 들어갈 값
seed = 230522   # 랜덤값 고정하기 위한
torch.manual_seed(seed)
check_epoch = 1
num_data = 50000    #202599중 5만개  # whole data

save_dir = "save_HT"
os.makedirs(save_dir)  

### 1) Define generator

In [3]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, (4, 4), 1, 0, bias=False),  # (100, 1, 1) -> (512, 4, 4) / can be linear  
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),    #  activation func
            
            nn.ConvTranspose2d(ngf * 8, ngf * 4, (4, 4), 2, 1, bias=False),  # (512,4,4) -> (256, 8, 8)
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf * 4, ngf * 2, (4, 4), 2, 1, bias=False),  #  (256, 8, 8) -> (128, 16, 16)
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf * 2, ngf, (4, 4), 2, 1, bias=False),  # (128, 16, 16) -> (64, 32, 32)
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(ngf, nc, (4, 4), 2, 1),  # (3, 64, 64)  # 이미지 채널에 맞게 3으로 줄여나가고 / 정보보존
            # nn.BatchNorm2d(nc),
            nn.Tanh()
        )

    def forward(self, z):           
        """z: random uniform (or Gaussian) noise of size (N, 100, 1, 1)"""
        return self.main(z)
    
    
    
    
    # init이 일단 실행 

### 2) Define discriminator

In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, (4, 4), 2, 1),  # (64, 32, 32)
            # nn.BatchNorm2d(ndf),
            nn.LeakyReLU(0.2, True),
            
            nn.Conv2d(ndf, ndf * 2, (4, 4), 2, 1, bias=False),  # (128, 16, 16)
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, True),
            
            nn.Conv2d(ndf * 2, ndf * 4, (4, 4), 2, 1, bias=False),  # (256, 8, 8)
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, True),
            
            nn.Conv2d(ndf * 4, ndf * 8, (4, 4), 2, 1, bias=False),  # (512, 4, 4)
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, True),
            
            nn.Conv2d(ndf * 8, 1, (4, 4), 1, 0),  # (512, 4, 4) -> (1, 1, 1) / can be linear
            nn.Sigmoid()
        )

    def forward(self, inputs):
        return self.main(inputs).view(-1)  # (N)

### 3) Define GAN loss

In [5]:
criterion = nn.BCELoss()  # binaty Cross Entropy / class가 2개

### 4) Module construction

In [6]:
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02) # 가우시안에서 평균이0
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.0)

In [7]:
g = Generator()
g.apply(weights_init) 

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): Tanh()
  )
)

In [8]:
d = Discriminator()
d.apply(weights_init)

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
    (12): Sigmoid()
  )
)

In [9]:
g_optim = torch.optim.Adam(
    g.parameters(),  #학습시킬 파라미터
    lr=lr,
    betas=(beta1, 0.999),
    weight_decay=0
)

In [10]:
d_optim = torch.optim.Adam(
    d.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
    weight_decay=0
)

### 5) Data construction

In [11]:
##  py에서 대문자는 클래스  파이토치 클래스 기본구조 : 3 method




In [12]:
class Face(Dataset):
    def __init__(self, datadir, num_data, transforms=[]):
        files = os.listdir(datadir)                                         # 폴더안의 파일들을 리스트 형식으로 
        files = sorted(files, key=lambda x: int(x[:-4]))                   #정렬함수 sorted 
        self.files = [os.path.join(datadir, file) for file in files]
        self.files = self.files[:num_data]  #num_data = 50000 / 5만개 까지의 이미지만 사용
        self.transforms = transforms

    def __len__(self):    
        """
        Face dataset contains 202599 items (N=202599, C=3, H=218, W=178)
        Among them, uses specific number of data only (num_data parameter in __init__)
        """
        return len(self.files) ##전체 데이터 갯수

    def __getitem__(self, item):               #HDD -> RAM으로 전처리해서 데이터 이동
        file = self.files[item]
        img = imread(file)                      # 실제로 이미지 읽는 함수 : imread
        for transform in self.transforms:      
            img = transform(img)
        return img
    
    

In [13]:
def max_min_scaling(tensor):         #정규화
    u = tensor.min()
    v = tensor.max()
    return (tensor - 0.5 * (v + u)) / (0.5 * (v - u))  # -1 ~ 1


dataset = Face(
    "img_align_celeba",
    num_data=num_data,
    transforms=[
        transforms.ToPILImage(),  # Resize class only accepts PIL image as a forward input (torch 1.6.0)
        transforms.Resize(size=(64, 64), interpolation=Image.BICUBIC),
        transforms.ToTensor(),#파이토치에서 사용하는 데이터형식
        max_min_scaling
    ]
)

dataloader = DataLoader(            
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    #num_workers=num_workers,  # multi-process dataloading
    #pin_memory=True,
    #sampler=None
)

  transforms.Resize(size=(64, 64), interpolation=Image.BICUBIC),


FileNotFoundError: [WinError 3] 지정된 경로를 찾을 수 없습니다: 'img_align_celeba'

In [None]:
len(dataset)    #plt.imshow()

In [None]:
len(dataloader)   # 50000/128

### 6) Train

In [None]:
def train(dataloader, g, d, criterion, g_optim, d_optim):
    g.train()
    d.train()

    for batch_idx, imgs in enumerate(dataloader):                    
        z = torch.FloatTensor(batch_size, nz, 1, 1).uniform_(-1, 1)  #랜덤값 

        
        # gen 시작
        fake = g(z)
        d_real = d(imgs)         # input실제이미지 output : 0~1
        d_fake = d(fake.detach()) # input가짜이미지 output : 0~1
        
        d_loss_real = criterion(d_real, torch.ones_like(d_real))  # 실제이미지 넣고 lossfuc 1로보내기
        d_loss_fake = criterion(d_fake, torch.zeros_like(d_fake))
        d_loss = 0.5 * (d_loss_real + d_loss_fake)
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()        #discri training

        d_fake = d(fake)  # extends graph from generator to discriminator
        g_loss = criterion(d_fake, torch.ones_like(d_fake))
        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()
        
        print('batch_idx: {}'.format(batch_idx))

In [None]:
def test(z, g, d):
    g.eval()
    d.eval()

    with torch.no_grad():          #이미지 학습이 끝나면 discrimi 
        fake = g(z)

        return denormalize(fake)
    

def denormalize(tensor):
    return 0.5 * tensor + 0.5  # 0 ~ 1   -1~1를 0~1로 보내줌

In [None]:
class AppendTensor:    #학습을
    """this class is for tensors to accumulate"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = []

    def update(self, tensor):
        self.sum.append(tensor)

In [None]:
def figure(epoch, size, tensor, save_dir):
    fig = plt.figure(figsize=(15, 15))
    plt.title('epochs: {}'.format(epoch), fontsize=30)
    plt.axis('off')
    for i in range(size):
        subplot = fig.add_subplot(10, 10, i + 1)

        subplot.imshow(
            tensor.permute(0, 2, 3, 1).numpy()[i]
        )
        plt.axis('off')
        plt.tight_layout()

    plt.savefig(
        os.path.join(save_dir, "{}.png".format(epoch)),
        dpi=300
    )
    plt.show()

In [None]:
# fixed_z = torch.FloatTensor(100, nz, 1, 1).uniform_(-1, 1)
# outs = AppendTensor()

# for epoch in range(epochs):
#     train(dataloader, g, d, criterion, g_optim, d_optim)
#     print('Epoch{} has been completed.'.format(epoch))
#     print('')
    
#     if epoch % check_epoch == 0:
#         fake = test(fixed_z, g, d)
#         outs.update(fake)
        
#         figure(epoch, fixed_z.size(0), fake, save_dir)

In [None]:
# for epoch in range(17, epochs):
#     train(dataloader, g, d, criterion, g_optim, d_optim)
#     print('Epoch{} has been completed.'.format(epoch))
#     print('')
    
#     if epoch % check_epoch == 0:
#         fake = test(fixed_z, g, d)
#         outs.update(fake)
        
#         figure(epoch, fixed_z.size(0), fake, save_dir)

In [None]:
# for epoch in range(28, epochs):
#     train(dataloader, g, d, criterion, g_optim, d_optim)
#     print('Epoch{} has been completed.'.format(epoch))
#     print('')
    
#     if epoch % check_epoch == 0:
#         fake = test(fixed_z, g, d)
#         outs.update(fake)
        
#         figure(epoch, fixed_z.size(0), fake, save_dir)

In [None]:
# torch.save(g.state_dict(), os.path.join(save_dir, "g_state_dict_at_epoch{}.pth.tar".format(epoch - 1)))
# torch.save(d.state_dict(), os.path.join(save_dir, "d_state_dict_at_epoch{}.pth.tar".format(epoch - 1)))

In [None]:
os.environ['KMP_DUPLCATE_LIB_OK'] = 'True'

In [None]:
g.load_state_dict(torch.load("save\g_state_dict_at_epoch48.pth.tar"))
#d.load_state_dict(torch.load(save_dir + "\d_state_dict_at_epoch27.pth.tar"))

### 7) Walking in the latent space

In [None]:
def lerp(v0, v1, n_step):
    h = []
    for i in torch.linspace(0, 1, n_step):
        h.append((1 - i) * v0 + i * v1)
    
    return torch.cat(h, dim=0)

In [None]:
for _ in range(10):
    # define latent vectors
    z0 = torch.Tensor(1, nz, 1, 1).uniform_(-1, 1)
    z1 = torch.Tensor(1, nz, 1, 1).uniform_(-1, 1)
    zs = lerp(z0, z1, 10)
    #print(zs.size())  # torch.Size([10, 100, 1, 1])

    with torch.no_grad():
        fake = g(zs)
        fake = denormalize(fake)

    fig = plt.figure(figsize=(15, 15))
    for i in range(fake.size(0)):
        subplot = fig.add_subplot(1, 10, i+1)
        subplot.imshow(fake[i].permute(1, 2, 0).numpy())
        plt.axis('off')
    plt.show()

In [None]:
z_mk = torch.Tensor(1, nz, 1, 1).uniform_(-1, 1)
r_mk=denormalize(g(z_mk))
plt.imshow(r_mk.detach().permute(0,2,3,1).numpy()[0])  #permute 파라미터 순서 바꿈


