# GAN을 이용해 흑백 이미지를 컬러 이미지로 변환하기

# Discriminator와 Generator

In [45]:
from torch import nn

class Discriminator_FCN(nn.Module):
    """Fully connected discriminator over flattened 784-feature inputs.

    ``forward`` maps ``z`` of shape (batch, 784) to a score of shape
    (batch, 1) squashed into (0, 1) by a sigmoid.
    """

    def __init__(self):
        # Bug fix: the original called ``super(Discriminator, self)``, which
        # raises TypeError because an instance of this class is not an
        # instance of ``Discriminator``.
        super(Discriminator_FCN, self).__init__()
        self.FCN1 = nn.Linear(784, 100)
        self.FCN2 = nn.Linear(100, 1)
        self.sig = nn.Sigmoid()
        self.LRU = nn.LeakyReLU(0.2)

    def forward(self, z):
        """Return the sigmoid score for each row of ``z``."""
        hidden = self.LRU(self.FCN1(z))
        logits = self.FCN2(hidden)
        # NOTE(review): applying LeakyReLU before the sigmoid compresses
        # negative logits toward 0.5; kept to preserve the original model.
        return self.sig(self.LRU(logits))

class Discriminator(nn.Module):
    """Convolutional discriminator over 4-channel images.

    Three stride-2 convolution stages (each followed by batch norm and
    LeakyReLU(0.2)) take the channels 4 -> 128 -> 256 -> 512 while halving the
    spatial size. A final 4x4 convolution produces one channel, which goes
    through LeakyReLU and then a sigmoid. For a 32x32 input the output has
    shape (batch, 1, 1, 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Modules are created in the same order as before so that parameter
        # initialisation (and the state_dict keys) are unchanged.
        self.Conv1 = nn.Conv2d(in_channels=4, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=128)
        self.Conv2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=256)
        self.Conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=512)
        self.Conv4 = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0)
        self.LRU = nn.LeakyReLU(0.2)
        self.sig = nn.Sigmoid()

    def forward(self, z):
        """Score ``z``; each output value lies in (0, 1)."""
        features = z
        for conv, norm in ((self.Conv1, self.bn1),
                           (self.Conv2, self.bn2),
                           (self.Conv3, self.bn3)):
            features = self.LRU(norm(conv(features)))
        logits = self.Conv4(features)
        return self.sig(self.LRU(logits))

class Generator(nn.Module):
    """Map a 1-channel image to a 3-channel image with values in [-1, 1].

    Every layer is a 3x3 transposed convolution with stride 1 and padding 1,
    so the spatial size of the input is preserved. The channels go
    1 -> 256 -> 128 -> 64 -> 3. The hidden layers use batch norm and ReLU, and
    the output layer uses tanh.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Same creation order as before: identical init and state_dict keys.
        self.DeConv1 = nn.ConvTranspose2d(in_channels=1, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=256)
        self.DeConv2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=128)
        self.DeConv3 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=64)
        self.DeConv4 = nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.Relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z):
        """Colorize ``z`` (batch, 1, H, W) into (batch, 3, H, W)."""
        hidden = z
        for deconv, norm in ((self.DeConv1, self.bn1),
                             (self.DeConv2, self.bn2),
                             (self.DeConv3, self.bn3)):
            hidden = self.Relu(norm(deconv(hidden)))
        return self.tanh(self.DeConv4(hidden))


# 폴더로부터 이미지 데이터 가져오기

In [7]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from os.path import join
from os import listdir


class DatasetFromfolder(Dataset):
    """Dataset of ``(gray, color)`` tensor pairs built from every file in a folder.

    Item ``i`` opens the ``i``-th file (in sorted file-name order) with PIL and
    returns:

    * ``gray``: the image converted to mode ``'L'`` (1 channel), and
    * ``color``: the image converted to mode ``'RGB'`` (3 channels).

    Both are converted by ``transforms.ToTensor``, giving float tensors in
    [0, 1]. Every entry in ``dir_`` is treated as an image file.
    """

    def __init__(self, dir_):
        super(DatasetFromfolder, self).__init__()
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.filelist_tmp = [join(dir_, x) for x in sorted(listdir(dir_))]
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, index):
        # Open the file once (the original opened it twice) and close it
        # deterministically via the context manager.
        with Image.open(self.filelist_tmp[index]) as img:
            gray = self.transform(img.convert('L'))
            # Force 3 channels. Without this, RGBA, palette or grayscale files
            # would yield a tensor with the wrong channel count for the
            # 3-channel color + 1-channel gray concatenation fed to the
            # 4-channel discriminator.
            color = self.transform(img.convert('RGB'))
        return gray, color

    def __len__(self):
        return len(self.filelist_tmp)

# 트레이닝 데이터로 학습하기

In [None]:
import torch
import torch.nn as nn
#from model import Generator, Discriminator
#from Data_loader import DatasetFromfolder
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import random

# Grayscale/color image pairs for the generator, loaded at module level
# from the training folder.
z_train_set = DatasetFromfolder('./cifar/cifar/train')
z_train_loader = DataLoader(
    dataset=z_train_set,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)

NUM_EPOCHS = 100  # number of passes over the training data in train()
def train(data_dir='./cifar/cifar/train', num_epochs=None, batch_size=16,
          lr=0.0002, log_every=20, result_dir='result'):
    """Train the grayscale-to-color GAN and return ``(netG, netD)``.

    The generator receives the grayscale tensor exactly as produced by
    ``DatasetFromfolder`` and emits a 3-channel image through ``tanh``. The
    discriminator scores the channel-wise concatenation ``(color, gray)``.
    The losses follow the least-squares GAN recipe with ``nn.MSELoss``:

    * D: ``MSE(D(real), 1) + MSE(D(fake.detach()), 0)``
    * G: ``MSE(D(fake), 1)``

    Args:
        data_dir: folder of training images.
        num_epochs: number of passes over the data. ``None`` reads the
            module-level ``NUM_EPOCHS`` at call time.
        batch_size: mini-batch size.
        lr: Adam learning rate for both networks.
        log_every: print losses and save a sample grid every this many batches.
        result_dir: directory for sample grids; created if missing.

    Returns:
        The trained ``(netG, netD)``. The original version returned ``None``,
        so existing callers are unaffected.
    """
    import os

    if num_epochs is None:
        num_epochs = NUM_EPOCHS

    train_set = DatasetFromfolder(data_dir)
    train_loader = DataLoader(dataset=train_set, num_workers=0,
                              batch_size=batch_size, shuffle=True)

    netG = Generator()
    netD = Discriminator()

    criterion = nn.MSELoss()
    optimizerD = optim.Adam(netD.parameters(), lr=lr)
    optimizerG = optim.Adam(netG.parameters(), lr=lr)

    os.makedirs(result_dir, exist_ok=True)

    for epoch in range(1, num_epochs + 1):
        for batch_idx, (gray, color) in enumerate(train_loader):
            # The logging branch below switches G to eval mode; switch it back
            # before this batch's forward pass.
            netG.train()

            # ToTensor yields colors in [0, 1], while the generator's tanh
            # output is in [-1, 1]. Put the real colors in the same range so D
            # cannot separate real from fake by value range alone.
            color = 2 * color - 1

            cur_size = gray.size(0)  # the last batch may be smaller
            real_target = torch.ones(cur_size)
            fake_target = torch.zeros(cur_size)

            fake_color = netG(gray)

            # Discriminator step: real pairs -> 1, fake pairs -> 0.
            netD.zero_grad()
            d_real = netD(torch.cat((color, gray), dim=1)).view(-1)
            # detach(): the D update must not propagate gradients into G.
            d_fake = netD(torch.cat((fake_color.detach(), gray), dim=1)).view(-1)
            d_loss = criterion(d_real, real_target) + criterion(d_fake, fake_target)
            d_loss.backward()
            optimizerD.step()

            # Generator step: push the updated D to score fakes as real. D is
            # run again here. Reusing the pre-step outputs would fail, because
            # optimizerD.step() modified D's weights in place.
            netG.zero_grad()
            d_on_fake = netD(torch.cat((fake_color, gray), dim=1)).view(-1)
            g_loss = criterion(d_on_fake, real_target)
            g_loss.backward()
            optimizerG.step()

            if batch_idx % log_every == 0:
                netG.eval()
                with torch.no_grad():
                    generated_image = (netG(gray) + 1) / 2  # back to [0, 1]
                print("Epoch:{} batch[{}/{}] G_loss:{} D_loss:{}".format(
                    epoch, batch_idx, len(train_loader), g_loss.item(), d_loss.item()))
                torchvision.utils.save_image(
                    generated_image,
                    os.path.join(result_dir, 'Generated-%d-%d.png' % (batch_idx, epoch)))

    return netG, netD


if __name__ == "__main__":
    train()
    # test()




# 테스트하기

In [None]:
batch_size = 16  # NOTE(review): not read by test(); kept for other notebook cells


def test(netG=None, loader=None, num_batches=1, result_dir='result'):
    """Colorize grayscale batches with a generator and save the images.

    Args:
        netG: generator to evaluate. Defaults to a freshly constructed,
            untrained ``Generator`` (as in the original). Pass the generator
            returned by ``train()`` to see meaningful output.
        loader: iterable of ``(gray, color)`` batches. Defaults to the
            module-level ``z_train_loader``.
        num_batches: how many batches to colorize; ``None`` means all. The
            original looped over the entire loader and overwrote one file.
        result_dir: output directory; created if missing.

    Each batch is rescaled from the generator's [-1, 1] output to [0, 1] and
    written to ``<result_dir>/Test-<batch_index>.png``. The generator's
    train/eval mode is restored afterwards.
    """
    import os

    if netG is None:
        netG = Generator()
    if loader is None:
        # The original referenced ``train_loader``, which is a local of
        # train() and therefore undefined here (NameError).
        loader = z_train_loader
    os.makedirs(result_dir, exist_ok=True)

    was_training = netG.training
    netG.eval()
    try:
        with torch.no_grad():
            for batch_idx, (gray, _color) in enumerate(loader):
                if num_batches is not None and batch_idx >= num_batches:
                    break
                generated_image = (netG(gray) + 1) / 2
                path = os.path.join(result_dir, 'Test-%d.png' % batch_idx)
                torchvision.utils.save_image(generated_image, path)
                print("saved generated images to", path)
    finally:
        netG.train(was_training)


if __name__ == "__main__":
    test()
