# 전체코드

# Discriminator와 Generator

In [None]:
from torch import nn

class Discriminator_FCN(nn.Module):
    """Fully connected discriminator for flattened 28x28 images.

    An alternative to the convolutional ``Discriminator``; ``train()`` does not
    use it. Input shape is ``(batch, 784)``, output shape is ``(batch, 1)``
    with values produced by a sigmoid.
    """

    def __init__(self):
        # Must refer to *this* class: naming ``Discriminator`` here made the
        # constructor raise TypeError (self is not a Discriminator instance).
        super().__init__()
        self.FCN1 = nn.Linear(784, 100)
        self.FCN2 = nn.Linear(100, 1)
        self.sig = nn.Sigmoid()
        self.LRU = nn.LeakyReLU(0.2)

    def forward(self, z):
        """Score each flattened image in ``z`` (shape ``(batch, 784)``)."""
        hidden = self.LRU(self.FCN1(z))
        # LeakyReLU on the final logit is kept as designed: it scales negative
        # logits by 0.2 before the sigmoid.
        logit = self.LRU(self.FCN2(hidden))
        return self.sig(logit)

class Discriminator(nn.Module):
    """Conditional convolutional discriminator for 28x28 inputs.

    Expects ``(batch, 11, 28, 28)``: in ``train()`` this is the 1-channel image
    concatenated with 10 label planes. The spatial size shrinks
    28 -> 14 -> 7 -> 4 -> 1, so the result is ``(batch, 1, 1, 1)`` after a
    sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order are part of the state_dict
        # layout; keep them stable for saved checkpoints.
        self.Conv1 = nn.Conv2d(11, 128, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.Conv2 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.Conv3 = nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(512)
        self.Conv4 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0)
        self.LRU = nn.LeakyReLU(0.2)
        self.sig = nn.Sigmoid()

    def forward(self, z):
        """Return the sigmoid score map for the conditioned input ``z``."""
        features = z
        for conv, norm in ((self.Conv1, self.bn1),
                           (self.Conv2, self.bn2),
                           (self.Conv3, self.bn3)):
            features = self.LRU(norm(conv(features)))
        # The final logit also passes through LeakyReLU before the sigmoid.
        logit = self.LRU(self.Conv4(features))
        return self.sig(logit)

class Generator(nn.Module):
    """Conditional transposed-convolution generator producing 28x28 images.

    Expects ``(batch, 110, 1, 1)``: in ``train()`` this is 100 noise channels
    concatenated with a 10-channel one-hot label. The spatial size grows
    1 -> 4 -> 7 -> 14 -> 28 and the output ``(batch, 1, 28, 28)`` lies in
    [-1, 1] because of the final tanh.
    """

    def __init__(self):
        super().__init__()
        # Keep names/order so the state_dict layout is unchanged.
        self.DeConv1 = nn.ConvTranspose2d(110, 512, kernel_size=4, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(512)
        self.DeConv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(256)
        self.DeConv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.DeConv4 = nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1)
        self.Relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z):
        """Map the conditioned latent ``z`` to a batch of images in [-1, 1]."""
        activation = z
        for deconv, norm in ((self.DeConv1, self.bn1),
                             (self.DeConv2, self.bn2),
                             (self.DeConv3, self.bn3)):
            activation = self.Relu(norm(deconv(activation)))
        return self.tanh(self.DeConv4(activation))


# 폴더로부터 이미지 데이터 가져오기

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from os.path import join
from os import listdir


class DatasetFromfolder(Dataset):
    """Digit dataset read from per-class folders ``dir_mnist/0`` .. ``dir_mnist/9``.

    Each item is ``(image_tensor, label)``: the file is opened with PIL,
    converted to grayscale ('L') and passed through ``ToTensor``. Flat indices
    are laid out class by class (all of class 0, then class 1, ...).
    """

    def __init__(self, dir_mnist):
        super(DatasetFromfolder, self).__init__()

        self.filelist = []  # [(file_paths_of_class_i, i), ...]
        self.lenlist = []   # number of files in each class folder
        self.lensum = 0     # total number of files
        for i in range(10):
            idir = join(dir_mnist, str(i))
            # listdir order is arbitrary; sort so index -> file is reproducible.
            filelist_tmp = [join(idir, x) for x in sorted(listdir(idir))]
            self.filelist.append((filelist_tmp, i))
            self.lenlist.append(len(filelist_tmp))
            self.lensum = self.lensum + len(filelist_tmp)

        self.transform = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, index):
        c, cindex = self._findlabelfromindex(index)
        clist, label = self.filelist[c]
        # The context manager releases the file handle once pixels are loaded.
        with Image.open(clist[cindex]) as img:
            resultimage = self.transform(img.convert('L'))
        return resultimage, label

    def __len__(self):
        return self.lensum

    def _findlabelfromindex(self, index):
        """Map a flat index to ``(class_label, position_within_class)``.

        Negative indices count from the end, as for a sequence.

        Raises:
            IndexError: if ``index`` is outside ``[-len, len)``.
        """
        total = sum(self.lenlist)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError('dataset index out of range')

        start = 0  # flat index of the first file of the current class
        for label, count in enumerate(self.lenlist):
            if index < start + count:
                # Offset from this class's first file: 0 <= offset < count.
                return label, index - start
            start += count
        raise IndexError('dataset index out of range')


# 트레이닝 데이터로 학습하기

In [14]:
import torch
import torch.nn as nn
#from model import Generator, Discriminator
#from Data_loader import DatasetFromfolder
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import random

# Label-conditioning lookup tables, indexed by digit k:
#   onehot[k] -> (10, 1, 1) one-hot column, concatenated to the generator noise
#   fill[k]   -> (10, 28, 28) planes with channel k all ones, concatenated to
#                the discriminator's image input
onehot = torch.eye(10).view(10, 10, 1, 1)
fill = onehot.repeat(1, 1, 28, 28)

NUM_EPOCHS = 100
def _generator_input(labels):
    """Build generator input for ``labels`` (LongTensor of shape (B,)).

    Concatenates uniform noise ``(B, 100, 1, 1)`` with ``onehot[labels]``
    ``(B, 10, 1, 1)`` into ``(B, 110, 1, 1)``.
    """
    noise = torch.rand(labels.size(0), 100, 1, 1)
    return torch.cat((noise, onehot[labels]), dim=1)


def train(data_dir='./4주차/mnist_png.tar/mnist_png/training', num_epochs=None):
    """Train the label-conditioned Generator/Discriminator pair with MSE losses.

    Both networks are conditioned on the digit: G receives one-hot channels,
    D receives constant label planes (``fill``). Every 20 batches the losses
    are printed and a batch of samples is saved under ``result/``.

    Args:
        data_dir: root folder with one sub-folder per digit 0..9.
        num_epochs: number of epochs; ``None`` uses the global ``NUM_EPOCHS``.

    Returns:
        The trained ``Generator`` (can be passed to ``test``).
    """
    import os

    if num_epochs is None:
        num_epochs = NUM_EPOCHS

    train_set = DatasetFromfolder(data_dir)  # read the training set from folders
    train_loader = DataLoader(dataset=train_set, num_workers=0, batch_size=16, shuffle=True)

    netG = Generator()
    netD = Discriminator()

    criterion = nn.MSELoss()  # least-squares GAN objective (targets 0 / 1)

    optimizerD = optim.Adam(netD.parameters(), lr=0.0002)
    optimizerG = optim.Adam(netG.parameters(), lr=0.0002)

    os.makedirs('result', exist_ok=True)  # save_image does not create folders

    for epoch in range(1, num_epochs + 1):
        for batch_idx, (x, label) in enumerate(train_loader):
            batch_size = x.size(0)
            netG.train()  # undo the eval() switch from the logging step below

            x = 2 * x - 1  # ToTensor gives [0, 1]; match the generator's tanh range
            real_labels = label.long()
            fake_labels = torch.randint(0, 10, (batch_size,))

            fake = netG(_generator_input(fake_labels))
            fake_input = torch.cat((fake, fill[fake_labels]), dim=1)
            real_input = torch.cat((x, fill[real_labels]), dim=1)

            ones = torch.ones(batch_size)
            zeros = torch.zeros(batch_size)

            # Discriminator step: detach so this loss does not reach G.
            netD.zero_grad()
            d_loss = (criterion(netD(fake_input.detach()).view(-1), zeros)
                      + criterion(netD(real_input).view(-1), ones))
            d_loss.backward()
            optimizerD.step()

            # Generator step: re-score the fakes with the updated D so the
            # backward pass does not use parameters modified by optimizerD.
            netG.zero_grad()
            g_loss = criterion(netD(fake_input).view(-1), ones)
            g_loss.backward()
            optimizerG.step()

            if batch_idx % 20 == 0:
                netG.eval()
                with torch.no_grad():
                    eval_labels = torch.randint(0, 10, (batch_size,))
                    generated_image = netG(_generator_input(eval_labels))
                    generated_image = (generated_image + 1) / 2  # back to [0, 1]

                print("Epoch:{} batch[{}/{}] G_loss:{} D_loss:{}".format(
                    epoch, batch_idx, len(train_loader), g_loss.item(), d_loss.item()))
                torchvision.utils.save_image(generated_image, 'result/Generated-%d-%d.png' % (batch_idx, epoch))

    return netG


if __name__ == "__main__":
    # A Jupyter kernel also runs with __name__ == "__main__", so executing
    # this cell starts training.
    train()




Epoch:1 batch[0/3750] G_loss:0.19628453254699707 D_loss:0.5187492966651917
Epoch:1 batch[20/3750] G_loss:0.33395302295684814 D_loss:0.39638084173202515
Epoch:1 batch[40/3750] G_loss:0.4271610677242279 D_loss:0.1558760106563568
Epoch:1 batch[60/3750] G_loss:0.6553338170051575 D_loss:0.051713988184928894
Epoch:1 batch[80/3750] G_loss:0.5477611422538757 D_loss:0.06989456713199615
Epoch:1 batch[100/3750] G_loss:0.6151707172393799 D_loss:0.09008370339870453
Epoch:1 batch[120/3750] G_loss:0.7239062190055847 D_loss:0.023009538650512695
Epoch:1 batch[140/3750] G_loss:0.8211373090744019 D_loss:0.009451774880290031
Epoch:1 batch[160/3750] G_loss:0.6906387209892273 D_loss:0.030546650290489197
Epoch:1 batch[180/3750] G_loss:0.7919066548347473 D_loss:0.01298427116125822
Epoch:1 batch[200/3750] G_loss:0.8445565700531006 D_loss:0.006978472229093313
Epoch:1 batch[220/3750] G_loss:0.7974663972854614 D_loss:0.012828187085688114
Epoch:1 batch[240/3750] G_loss:0.7559667229652405 D_loss:0.01795526780188083

Epoch:1 batch[2080/3750] G_loss:0.966705322265625 D_loss:0.00029165088199079037
Epoch:1 batch[2100/3750] G_loss:0.9855790734291077 D_loss:0.021448543295264244
Epoch:1 batch[2120/3750] G_loss:0.9712396860122681 D_loss:0.000274309073574841
Epoch:1 batch[2140/3750] G_loss:0.9742198586463928 D_loss:0.000198095862288028
Epoch:1 batch[2160/3750] G_loss:0.9811457395553589 D_loss:0.00010477031173650175
Epoch:1 batch[2180/3750] G_loss:0.9312477111816406 D_loss:0.0017636979464441538
Epoch:1 batch[2200/3750] G_loss:0.9868040084838867 D_loss:4.638973041437566e-05
Epoch:1 batch[2220/3750] G_loss:0.9889776110649109 D_loss:3.143812864436768e-05
Epoch:1 batch[2240/3750] G_loss:0.978891134262085 D_loss:0.00011503935093060136
Epoch:1 batch[2260/3750] G_loss:0.9214922189712524 D_loss:0.0016546251717954874
Epoch:1 batch[2280/3750] G_loss:0.6760477423667908 D_loss:0.07183282822370529
Epoch:1 batch[2300/3750] G_loss:0.900397002696991 D_loss:0.005317536182701588
Epoch:1 batch[2320/3750] G_loss:0.886556267738

Epoch:2 batch[400/3750] G_loss:0.976107120513916 D_loss:0.00014768500113859773
Epoch:2 batch[420/3750] G_loss:0.9596096277236938 D_loss:0.000532470759935677
Epoch:2 batch[440/3750] G_loss:0.9504700303077698 D_loss:0.0007479197229258716
Epoch:2 batch[460/3750] G_loss:0.9687348008155823 D_loss:0.0002865290443878621
Epoch:2 batch[480/3750] G_loss:0.9773308634757996 D_loss:0.0001393653074046597
Epoch:2 batch[500/3750] G_loss:0.975903332233429 D_loss:0.00018316552450414747
Epoch:2 batch[520/3750] G_loss:0.9705036282539368 D_loss:0.0002314510493306443
Epoch:2 batch[540/3750] G_loss:0.937509298324585 D_loss:0.0014937728410586715
Epoch:2 batch[560/3750] G_loss:0.9316984415054321 D_loss:0.0013767934869974852
Epoch:2 batch[580/3750] G_loss:0.8764358758926392 D_loss:0.03565102070569992
Epoch:2 batch[600/3750] G_loss:0.7512375712394714 D_loss:0.023608144372701645
Epoch:2 batch[620/3750] G_loss:0.9996520280838013 D_loss:0.03165096789598465
Epoch:2 batch[640/3750] G_loss:0.9476082921028137 D_loss:0.

Epoch:2 batch[2480/3750] G_loss:0.895846962928772 D_loss:0.005661034025251865
Epoch:2 batch[2500/3750] G_loss:0.8646672368049622 D_loss:0.010688134469091892
Epoch:2 batch[2520/3750] G_loss:0.8379484415054321 D_loss:0.016247062012553215
Epoch:2 batch[2540/3750] G_loss:0.9815447330474854 D_loss:0.030538802966475487
Epoch:2 batch[2560/3750] G_loss:0.9832161664962769 D_loss:7.78387620812282e-05
Epoch:2 batch[2580/3750] G_loss:0.846763014793396 D_loss:0.007959544658660889
Epoch:2 batch[2600/3750] G_loss:0.9809679985046387 D_loss:0.0001570777385495603
Epoch:2 batch[2620/3750] G_loss:0.981971025466919 D_loss:0.00010910287528531626
Epoch:2 batch[2640/3750] G_loss:0.9414302706718445 D_loss:0.01891489140689373
Epoch:2 batch[2660/3750] G_loss:0.9547423124313354 D_loss:0.0006694306503050029
Epoch:2 batch[2680/3750] G_loss:0.9337377548217773 D_loss:0.0015942276222631335
Epoch:2 batch[2700/3750] G_loss:0.6801590323448181 D_loss:0.03966391086578369
Epoch:2 batch[2720/3750] G_loss:0.989946722984314 D_

Epoch:3 batch[800/3750] G_loss:0.7593483924865723 D_loss:0.08108426630496979
Epoch:3 batch[820/3750] G_loss:0.9543149471282959 D_loss:0.01742301881313324
Epoch:3 batch[840/3750] G_loss:0.7654434442520142 D_loss:0.016390252858400345
Epoch:3 batch[860/3750] G_loss:0.9835490584373474 D_loss:0.043157074600458145
Epoch:3 batch[880/3750] G_loss:0.8375325202941895 D_loss:0.008855902589857578
Epoch:3 batch[900/3750] G_loss:0.8067116141319275 D_loss:0.012141124345362186
Epoch:3 batch[920/3750] G_loss:0.9136952757835388 D_loss:0.1283237785100937
Epoch:3 batch[940/3750] G_loss:0.8370203375816345 D_loss:0.2242274284362793
Epoch:3 batch[960/3750] G_loss:0.9848808646202087 D_loss:0.037186942994594574
Epoch:3 batch[980/3750] G_loss:0.7586674690246582 D_loss:0.02710662968456745
Epoch:3 batch[1000/3750] G_loss:0.8364800810813904 D_loss:0.013297284953296185
Epoch:3 batch[1020/3750] G_loss:0.9909632205963135 D_loss:0.001102364039979875
Epoch:3 batch[1040/3750] G_loss:0.9738165140151978 D_loss:0.000377733

Epoch:3 batch[2880/3750] G_loss:0.88602614402771 D_loss:0.004810280632227659
Epoch:3 batch[2900/3750] G_loss:0.9291326403617859 D_loss:0.001616532914340496
Epoch:3 batch[2920/3750] G_loss:0.9630170464515686 D_loss:0.0004289568751119077
Epoch:3 batch[2940/3750] G_loss:0.9562892913818359 D_loss:0.0005876501672901213
Epoch:3 batch[2960/3750] G_loss:0.9836711883544922 D_loss:0.0001295321126235649
Epoch:3 batch[2980/3750] G_loss:0.9720038771629333 D_loss:0.00036186116631142795
Epoch:3 batch[3000/3750] G_loss:0.8615792989730835 D_loss:0.009186356328427792
Epoch:3 batch[3020/3750] G_loss:0.8967511653900146 D_loss:0.05844280123710632
Epoch:3 batch[3040/3750] G_loss:0.9721677303314209 D_loss:0.0006495246198028326
Epoch:3 batch[3060/3750] G_loss:0.9418764710426331 D_loss:0.0021614599972963333
Epoch:3 batch[3080/3750] G_loss:0.9411250352859497 D_loss:0.0017970317276194692
Epoch:3 batch[3100/3750] G_loss:0.942913293838501 D_loss:0.019098574295639992
Epoch:3 batch[3120/3750] G_loss:0.95907759666442

FileNotFoundError: [Errno 2] No such file or directory: './4주차/mnist_png.tar/mnist_png/training\\0\\27658.png'

# 테스트하기

In [None]:
batch_size = 16


def test(netG=None):
    """Generate one batch of samples with labels 1, 2, ..., 9, 0, 1, ... and save it.

    Args:
        netG: a trained ``Generator`` (e.g. the return value of ``train()``).
            If omitted, a freshly initialised (untrained) Generator is used,
            so the saved image is only a smoke test of the pipeline.
    """
    import os

    # Same one-hot table as the training cell: onehot[k] has shape (10, 1, 1).
    onehot = torch.eye(10).view(10, 10, 1, 1)

    batch_idx = 321
    epoch = 123  # batch_idx / epoch only feed the output file name

    if netG is None:
        netG = Generator()

    # Cycle the digits starting at 1; for batch_size 16 this is 1..9, 0, 1..6.
    labels = torch.tensor([(i + 1) % 10 for i in range(batch_size)])
    z = torch.cat((torch.rand(batch_size, 100, 1, 1), onehot[labels]), dim=1)

    netG.eval()
    with torch.no_grad():
        generated_image = netG(z)
        generated_image = (generated_image + 1) / 2  # tanh [-1, 1] -> [0, 1]

    os.makedirs('result', exist_ok=True)  # save_image does not create folders
    torchvision.utils.save_image(generated_image, 'result/Generated-%d-%d.png' % (batch_idx, epoch))

    print("generated_image.data : ")


if __name__ == "__main__":
    test()


In [8]:
# Inspect one entry of the label table; shows torch.Size([10, 1, 1]).
onehot[0].shape

torch.Size([10, 1, 1])