# Unet을 이용한 Generator Networks 만들기

# Discriminator  와 Generator

In [4]:
from torch import nn

class Discriminator_FCN(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.FCN1 = nn.Linear(784, 100)
        self.FCN2 = nn.Linear(100, 1)
        self.sig = nn.Sigmoid()
        self.LRU = nn.LeakyReLU(0.2)
    def forward(self, z):
        out1 = self.FCN1(z)
        out3 = self.LRU(out1)
        out4 = self.FCN2(out3)
        out5 = self.LRU(out4)
        result = self.sig(out5)
        return result

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.Conv1 = nn.Conv2d(4, 128, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(128)
        self.Conv2 = nn.Conv2d(128, 256, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(256)
        self.Conv3 = nn.Conv2d(256, 512, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(512)
        self.Conv4 = nn.Conv2d(512, 1, 4, 1, 0)
        self.LRU = nn.LeakyReLU(0.2)
        self.sig = nn.Sigmoid()

    def forward(self, z):
        z = self.LRU(self.bn1(self.Conv1(z)))
        z = self.LRU(self.bn2(self.Conv2(z)))
        z = self.LRU(self.bn3(self.Conv3(z)))
        out = self.sig(self.LRU(self.Conv4(z)))

        return out

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # 줄이기
        self.Conv1 = nn.Conv2d(1, 64, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(64)         # 32x32,64
        self.Conv2 = nn.Conv2d(64, 128, 4, 2, 1) 
        self.bn2 = nn.BatchNorm2d(128)         # 16x16,128
        self.Conv3 = nn.Conv2d(128, 256, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(256)         # 8x8,256
        self.Conv4 = nn.Conv2d(256, 256, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(256)         # 4x4,256
        self.Conv5 = nn.Conv2d(256, 256, 4, 2, 1)
        self.bn5 = nn.BatchNorm2d(256)         # 2x2,256
        self.Conv6 = nn.Conv2d(256, 256, 4, 2, 1)
        self.bn6 = nn.BatchNorm2d(256)         # 1x1,256
        
        # 늘이기
        self.DeConv1 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1)
        self.Dbn1 = nn.BatchNorm2d(256)       # 2x2,256
        self.DeConv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.Dbn2 = nn.BatchNorm2d(256)       # 4x4,128
        self.DeConv3 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.Dbn3 = nn.BatchNorm2d(256)       # 8x8 64
        self.DeConv4 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)
        self.Dbn4 = nn.BatchNorm2d(128)       # 16x16 32
        self.DeConv5 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)
        self.Dbn5 = nn.BatchNorm2d(64)       # 16x16 32
        self.DeConv6 = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1)
        # 32x32 3
        
        
        self.Relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z):

        out1 = self.Relu(self.bn1(self.Conv1(z)))       # 32x32 64
        out2 = self.Relu(self.bn2(self.Conv2(out1)))    # 16x16 128
        out3 = self.Relu(self.bn3(self.Conv3(out2)))    # 8x8,256
        out4 = self.Relu(self.bn4(self.Conv4(out3)))    # 4x4,256
        out5 = self.Relu(self.bn5(self.Conv5(out4)))    # 2x2,256
        out6 = self.Relu(self.bn6(self.Conv6(out5)))    # 1x1,256
        
       
        Dout1 = self.Relu(self.Dbn1(self.DeConv1(out6))) # out6 : 256
        Dout2 = self.Relu(self.Dbn2(self.DeConv2(torch.cat((out5, Dout1), dim=1))))#256 256out5 dout1
        Dout3 = self.Relu(self.Dbn3(self.DeConv3(torch.cat((out4, Dout2), dim=1))))#256 128 out4 dout2
        Dout4 = self.Relu(self.Dbn4(self.DeConv4(torch.cat((out3, Dout3), dim=1))))#256 64  out3 dout3
        Dout5 = self.Relu(self.Dbn5(self.DeConv5(torch.cat((out2, Dout4), dim=1))))
        Dout6 = self.tanh(self.DeConv6(torch.cat((out1, Dout5), dim=1)))#128 32 out2 dout4
        
        return Dout6


# Dataloader

In [5]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from os.path import join
from os import listdir


class DatasetFromfolder(Dataset):
    def __init__(self, dir_):
        super(DatasetFromfolder, self).__init__()
        
        self.filelist_tmp = [join(dir_, x) for x in listdir(dir_)]
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, index):

        gray = self.transform(Image.open(self.filelist_tmp[index]).convert('L'))
        color = self.transform(Image.open(self.filelist_tmp[index]))
        
        return gray, color

    def __len__(self):
        return len(self.filelist_tmp)

# 트레이닝 데이터로 학습하기

In [11]:
import torch
import torch.nn as nn
#from model import Generator, Discriminator
#from Data_loader import DatasetFromfolder
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import random

#제네레이터를 위한 흑백사진 가져오기
z_train_set = DatasetFromfolder('./cifar/cifar/train')   #폴더로 부터 트레이닝셋 가져오기
z_train_loader = DataLoader(dataset=z_train_set, num_workers=0, batch_size=16, shuffle=True) #트레이닝셋 데이터로더에 적재

NUM_EPOCHS = 100
def train():
    test_cnt = 0
    
    train_set = DatasetFromfolder('./cifar/cifar/train')   #폴더로 부터 트레이닝셋 가져오기
    train_loader = DataLoader(dataset=train_set, num_workers=0, batch_size=16, shuffle=True) #트레이닝셋 데이터로더에 적재
    
    test_set = DatasetFromfolder('./cifar/cifar/test1')   #폴더로 부터 테스트셋 가져오기
    test_loader = DataLoader(dataset=train_set, num_workers=0, batch_size=1, shuffle=True) #테스트셋 데이터로더에 할당
    
    
    netG = Generator()        # Generator선언
    netD = Discriminator()        # Discriminator 선언

    criterion = nn.MSELoss()        # cost계산은 MSELOSS선언
    
    optimizerD = optim.Adam(netD.parameters(), lr=0.0002)    # 아담옵티마이저
    optimizerG = optim.Adam(netG.parameters(), lr=0.0002)

    for epoch in range(1, NUM_EPOCHS + 1):
        batch_idx = 0
        for x, label in train_loader:

            gray = x; color = label
            
            batch_size = x.size(0)    # 배치사이즈 설정
            #gray = 2*gray  - 1        #generator 라벨과 이미지 결합
            #color = 2*color - 1
            
            #print("흑백  : ",gray.size())
            #print("칼라  : ",color.size())
            
            fake_image = netG(gray)
            
            #print("fake_image : " , fake_image.size())
            #print("gray_image : ", gray.size())
            #print("color_image : " , color.size())
            
            
            fake_image = torch.cat((fake_image,gray),dim=1)    # 흑백 제네레이터 칼라이미지 합치기
            real_image = torch.cat((color, gray),dim=1)    # 흑백 실제칼라 이미지 합치기

            #print("fake_image : " , fake_image.size())
            #print("real_image : " , real_image.size())
            
            fake_decision = netD(fake_image)
            real_decision = netD(real_image)

            #print("fake_image1 : " , fake_image.size())
            #print("real_image1 : " , real_image.size())
            
            
            netD.zero_grad()
            
            d_loss = criterion(fake_decision.squeeze(), torch.zeros(batch_size)) + criterion(real_decision.squeeze(), torch.ones(batch_size))
            d_loss.backward(retain_graph=True)
            optimizerD.step()

            netG.train()
            netG.zero_grad()
            
            g_loss = 10*criterion(fake_image.squeeze(), real_image.squeeze()) + criterion(fake_decision.squeeze(), torch.ones(batch_size))
            g_loss.backward(retain_graph=True)
            optimizerG.step()

            if batch_idx % 20 == 0:
                netG.eval()
                for gray_image, original in test_loader:
                    generated_image = netG(gray_image) 

                    #gray_image = (gray_image + 1)/2     # 테스트셋으로 만든 흑백 이미지
                    #generated_image = (generated_image + 1) / 2        #테스트셋에 제네레이터를 거쳐 만든 이미지

                    print("Epoch:{} batch[{}/{}] G_loss:{} D_loss:{}".format(epoch, batch_idx, len(train_loader), g_loss, d_loss))
                    
                    
                    total = torch.cat((generated_image, original), dim=2)
                    torchvision.utils.save_image(total.data, 'result/Generated-%d-%d.png' % (batch_idx, epoch))   #흑백이미지 저장
                    #torchvision.utils.save_image(generated_image.data, 'result/Generated-%d-%d.png' % (batch_idx, epoch))   #칼라이미지 저장
                    #torchvision.utils.save_image(original.data, 'result/Original-%d-%d.png' % (batch_idx, epoch))   #칼라이미지 저장
                
            batch_idx += 1


if __name__ == "__main__":
    train()
    # test()




Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_lo

Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_lo

Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_lo

Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_loss:0.5066874027252197
Epoch:1 batch[0/3125] G_loss:7.249268531799316 D_lo

PermissionError: [Errno 13] Permission denied: 'result/Generated-0-1.png'

# 테스트하기