# GAN 전체코드

In [1]:
from torch import nn

class Discriminator_FCN(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.FCN1 = nn.Linear(784, 100)
        self.FCN2 = nn.Linear(100, 1)
        self.sig = nn.Sigmoid()
        self.LRU = nn.LeakyReLU(0.2)
    def forward(self, z):
        out1 = self.FCN1(z)
        out3 = self.LRU(out1)
        out4 = self.FCN2(out3)
        out5 = self.LRU(out4)
        result = self.sig(out5)
        return result

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.Conv1 = nn.Conv2d(1, 128, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(128)
        self.Conv2 = nn.Conv2d(128, 256, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(256)
        self.Conv3 = nn.Conv2d(256, 512, 4, 1, 0)
        self.bn3 = nn.BatchNorm2d(512)
        self.Conv4 = nn.Conv2d(512, 1, 4, 1, 0)
        self.LRU = nn.LeakyReLU(0.2)
        self.sig = nn.Sigmoid()

    def forward(self, z):
        z = self.LRU(self.bn1(self.Conv1(z)))
        z = self.LRU(self.bn2(self.Conv2(z)))
        z = self.LRU(self.bn3(self.Conv3(z)))
        out = self.sig(self.LRU(self.Conv4(z)))


        return out

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.DeConv1 = nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(512)
        self.DeConv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(256)
        self.DeConv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.DeConv4 = nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1)
        self.Relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z):

        out1 = self.Relu(self.bn1(self.DeConv1(z)))
        out2 = self.Relu(self.bn2(self.DeConv2(out1)))
        out3 = self.Relu(self.bn3(self.DeConv3(out2)))
        out4 = self.tanh(self.DeConv4(out3))

        return out4


In [2]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from os.path import join
from os import listdir


class DatasetFromfolder(Dataset):
    def __init__(self, dir_mnist):
        super(DatasetFromfolder, self).__init__()

        self.filelist = []
        self.lenlist = []
        self.lensum = 0
        for i in range(10):
            idir = join(dir_mnist, str(i))
            filelist_tmp = [join(idir, x) for x in listdir(idir)]
            self.filelist.append((filelist_tmp, i))
            self.lenlist.append(len(filelist_tmp))
            self.lensum = self.lensum + len(filelist_tmp)

        self.transform = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, index):

        c, cindex = self._findlabelfromindex(index)
        clist, label = self.filelist[c]
        resultimage = self.transform(Image.open(clist[cindex]).convert('L'))
        return resultimage, label

    def __len__(self):
        return self.lensum

    def _findlabelfromindex(self, index):
        label = 0
        indexsum = 0

        for i in range(10):
            indexsum += self.lenlist[i]
            if index < indexsum:
                label = i
                break

        classindex = index - indexsum


        return label, classindex


In [6]:
import torch
import torch.nn as nn
from model import Generator, Discriminator
from Data_loader import DatasetFromfolder
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision

NUM_EPOCHS = 100
def train():
    train_set = DatasetFromfolder('./4주차/mnist_png.tar/mnist_png/training')
    train_loader = DataLoader(dataset=train_set, num_workers=0, batch_size=16, shuffle=True)

    netG = Generator()
    netD = Discriminator()

    criterion = nn.MSELoss()

    optimizerD = optim.Adam(netD.parameters())
    optimizerG = optim.Adam(netG.parameters())

    for epoch in range(1, NUM_EPOCHS + 1):
        batch_idx = 0
        for x, label in train_loader:

            batch_size = x.size(0)

            x = 2*x - 1
            z = torch.rand(batch_size, 100, 1, 1)
            fake_image = netG(z)

            fake = netD(fake_image)
            real = netD(x)

            netD.zero_grad()

            d_loss = criterion(fake.squeeze(), torch.zeros(batch_size)) + criterion(real.squeeze(), torch.ones(batch_size))
            d_loss.backward(retain_graph=True)
            optimizerD.step()

            netG.train()
            netG.zero_grad()

            g_loss = criterion(fake.squeeze(), torch.ones(batch_size))
            g_loss.backward(retain_graph=True)
            optimizerG.step()

            if batch_idx % 20 == 0:
                netG.eval()

                eval_z = torch.rand(batch_size, 100, 1, 1)
                generated_image = netG(eval_z)

                generated_image = (generated_image + 1) / 2

                print("Epoch:{} batch[{}/{}] G_loss:{} D_loss:{}".format(epoch, batch_idx, len(train_loader), g_loss, d_loss))
                torchvision.utils.save_image(generated_image.data, 'result/Generated-%d-%d.png' % (batch_idx, epoch))

            batch_idx += 1


if __name__ == "__main__":
    train()
    # test()




Epoch:1 batch[0/3750] G_loss:0.2260117083787918 D_loss:0.5138103365898132
Epoch:1 batch[20/3750] G_loss:0.8578395843505859 D_loss:0.005511901341378689
Epoch:1 batch[40/3750] G_loss:0.930311381816864 D_loss:0.0012601178605109453
Epoch:1 batch[60/3750] G_loss:0.7960892915725708 D_loss:0.01166468020528555
Epoch:1 batch[80/3750] G_loss:0.7694305777549744 D_loss:0.02663208544254303
Epoch:1 batch[100/3750] G_loss:0.9670018553733826 D_loss:0.00027709550340659916
Epoch:1 batch[120/3750] G_loss:0.9178361892700195 D_loss:0.0017631229711696506
Epoch:1 batch[140/3750] G_loss:0.9362282752990723 D_loss:0.001050595543347299
Epoch:1 batch[160/3750] G_loss:0.9596448540687561 D_loss:0.00044444037484936416
Epoch:1 batch[180/3750] G_loss:0.9178345203399658 D_loss:0.0017680064775049686
Epoch:1 batch[200/3750] G_loss:0.9278940558433533 D_loss:0.0013640668476000428
Epoch:1 batch[220/3750] G_loss:0.9311988353729248 D_loss:0.002220453228801489
Epoch:1 batch[240/3750] G_loss:0.9237613677978516 D_loss:0.00151491

Epoch:1 batch[2080/3750] G_loss:0.9991474151611328 D_loss:1.8260855938478926e-07
Epoch:1 batch[2100/3750] G_loss:0.9988434314727783 D_loss:3.362059999290068e-07
Epoch:1 batch[2120/3750] G_loss:0.9983267188072205 D_loss:7.087974154273979e-07
Epoch:1 batch[2140/3750] G_loss:0.9966883659362793 D_loss:2.764684950307128e-06
Epoch:1 batch[2160/3750] G_loss:0.9694362282752991 D_loss:0.00024024075537454337
Epoch:1 batch[2180/3750] G_loss:0.9848588705062866 D_loss:5.818477438879199e-05
Epoch:1 batch[2200/3750] G_loss:0.9536355137825012 D_loss:0.0005527095054276288
Epoch:1 batch[2220/3750] G_loss:0.9755405187606812 D_loss:0.004553710110485554
Epoch:1 batch[2240/3750] G_loss:0.9990506768226624 D_loss:2.324008931964272e-07
Epoch:1 batch[2260/3750] G_loss:0.9984657764434814 D_loss:6.122447189227387e-07
Epoch:1 batch[2280/3750] G_loss:0.9911791682243347 D_loss:2.0127672542002983e-05
Epoch:1 batch[2300/3750] G_loss:0.9539170265197754 D_loss:0.0005462287808768451
Epoch:1 batch[2320/3750] G_loss:0.9999

Epoch:2 batch[380/3750] G_loss:0.9999985694885254 D_loss:5.698677749306491e-13
Epoch:2 batch[400/3750] G_loss:0.9999986886978149 D_loss:4.619292144972253e-13
Epoch:2 batch[420/3750] G_loss:0.9999986886978149 D_loss:4.2817610351394975e-13
Epoch:2 batch[440/3750] G_loss:0.9999986886978149 D_loss:4.569202546704509e-13
Epoch:2 batch[460/3750] G_loss:0.9999984502792358 D_loss:9.99854170811998e-13
Epoch:2 batch[480/3750] G_loss:0.9999985694885254 D_loss:6.597079653392068e-13
Epoch:2 batch[500/3750] G_loss:0.9999986290931702 D_loss:5.198325494018552e-13
Epoch:2 batch[520/3750] G_loss:0.9999986290931702 D_loss:5.002932204796473e-13
Epoch:2 batch[540/3750] G_loss:0.9999986290931702 D_loss:5.503326201868053e-13
Epoch:2 batch[560/3750] G_loss:0.9999986290931702 D_loss:5.927213472638193e-13
Epoch:2 batch[580/3750] G_loss:0.9999986886978149 D_loss:4.2049073641428625e-13
Epoch:2 batch[600/3750] G_loss:0.9999986886978149 D_loss:4.675277092552887e-13
Epoch:2 batch[620/3750] G_loss:0.9999986886978149 D

Epoch:2 batch[2440/3750] G_loss:0.9999982714653015 D_loss:8.359580389027954e-13
Epoch:2 batch[2460/3750] G_loss:0.9999983906745911 D_loss:7.837651426305381e-13
Epoch:2 batch[2480/3750] G_loss:0.9999982714653015 D_loss:9.454711319412112e-13
Epoch:2 batch[2500/3750] G_loss:0.9999980926513672 D_loss:1.3410137800554112e-12
Epoch:2 batch[2520/3750] G_loss:0.9999980926513672 D_loss:1.2388958131950845e-12
Epoch:2 batch[2540/3750] G_loss:0.9999980926513672 D_loss:1.2041701186529807e-12
Epoch:2 batch[2560/3750] G_loss:0.9999980926513672 D_loss:1.3038781209243067e-12
Epoch:2 batch[2580/3750] G_loss:0.9999982714653015 D_loss:8.500817779734382e-13
Epoch:2 batch[2600/3750] G_loss:0.999998152256012 D_loss:1.0795344652925198e-12
Epoch:2 batch[2620/3750] G_loss:0.9999977946281433 D_loss:2.061487916135696e-12
Epoch:2 batch[2640/3750] G_loss:0.9999979734420776 D_loss:1.3327003346372268e-12
Epoch:2 batch[2660/3750] G_loss:0.9999980330467224 D_loss:1.2671740826780686e-12
Epoch:2 batch[2680/3750] G_loss:0.

Epoch:3 batch[960/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[980/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1000/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1020/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1040/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1060/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1080/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1100/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1120/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1140/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1160/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1180/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1200/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1220/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1240/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1260/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1280/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1300/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1320/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1340/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1360/3750] G_loss:0.0 D_loss:1.0
Epoch:3 batch[1

Epoch:4 batch[720/3750] G_loss:0.0 D_loss:1.0
Epoch:4 batch[740/3750] G_loss:0.0 D_loss:1.0
Epoch:4 batch[760/3750] G_loss:0.0 D_loss:1.0
Epoch:4 batch[780/3750] G_loss:0.0 D_loss:1.0
Epoch:4 batch[800/3750] G_loss:0.0 D_loss:1.0
Epoch:4 batch[820/3750] G_loss:0.0 D_loss:1.0
Epoch:4 batch[840/3750] G_loss:0.0 D_loss:1.0
Epoch:4 batch[860/3750] G_loss:0.0 D_loss:1.0
Epoch:4 batch[880/3750] G_loss:0.0 D_loss:1.0


KeyboardInterrupt: 

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])
