In [29]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
import datetime
import scipy.misc
import imageio


# 1. Generative Adversarial Network (GAN)

* Architecture of GAN
<div align="center">
    <img src="https://cdn-images-1.medium.com/v2/resize:fit:2000/1*39Nnni_nhPDaLu9AnTLoWw.png" alt="GAN">
</div>

* Generator
<div align="center">
    <img src="https://cdn-images-1.medium.com/v2/resize:fit:1000/1*7i9iCdLZraZkrMy1-KADrA.png" alt="Generator">
</div>

* Discriminator
<div align="center">
    <img src="https://www.researchgate.net/profile/Sinan-Kaplan/publication/319093376/figure/fig20/AS:526859935731712@1502624605127/Architecture-of-proposed-discriminator-network-which-is-part-of-GAN-based-on-CNN-units.png" alt="Discriminator">
</div>

In [2]:
# Run-wide configuration.
Model_Name='ConditionGAN'  # prefix of the sample-image files written during training
SEED=0                     # seed for torch's global RNG so a Restart & Run All is reproducible
torch.manual_seed(SEED)
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
def to_cuda(x):
    """Return ``x`` moved onto the notebook-wide ``device``.

    Despite the name, ``device`` falls back to the CPU when no GPU is
    present. Anything with a ``.to(device)`` method (tensors and
    ``nn.Module``s) can be passed.
    """
    target_device = device
    return x.to(target_device)

In [4]:
def to_onehot(x,num_classes=10):
    """Encode integer class label(s) as a one-hot ``torch.long`` tensor on the CPU.

    Args:
        x: a Python ``int`` (one label), or an integer tensor of labels shaped
            ``()``, ``(batch,)`` or ``(batch, k)``; it may live on any device.
            With ``k > 1`` every listed index in a row is set (multi-hot).
        num_classes: width of each one-hot row.

    Returns:
        A CPU ``torch.long`` tensor of shape ``(1, num_classes)`` for an ``int``
        (or 0-d tensor) input, otherwise ``(batch, num_classes)``.

    Raises:
        AssertionError: if ``x`` is neither an ``int`` nor an integer tensor.
        ValueError: if any label lies outside ``[0, num_classes)``.
    """
    integer_dtypes=(torch.uint8,torch.int8,torch.int16,torch.int32,torch.int64)
    assert isinstance(x,int) or (torch.is_tensor(x) and x.dtype in integer_dtypes)
    if isinstance(x,int):
        index=torch.tensor([[x]],dtype=torch.long)
    else:
        # scatter_ needs an int64 index on the same (CPU) device as the target,
        # with one row per sample: promote 0-d / 1-D label tensors to (B, 1).
        index=x.detach().cpu().long()
        if index.dim()<2:
            index=index.reshape(-1,1)
    # Reject out-of-range labels explicitly (a negative int would otherwise
    # silently index from the end).
    if index.numel()>0 and (int(index.min())<0 or int(index.max())>=num_classes):
        raise ValueError('labels must lie in [0, {}), got {}'.format(num_classes,index.flatten().tolist()))
    c=torch.zeros(index.size(0),num_classes,dtype=torch.long)
    c.scatter_(1,index,1)
    return c

In [5]:
def get_simple_image(G,n_noise=100,n_classes=10,n_per_class=10,img_size=28):
    """Render a grid of generator samples with one row per class label.

    Row ``k`` of the grid holds ``n_per_class`` samples conditioned on label
    ``k``, each drawn from its own Gaussian noise vector. All samples are
    produced in a single batched forward pass without autograd.

    Args:
        G: conditional generator called as ``G(z, c)`` with ``z`` of shape
            ``(B, n_noise)`` and a one-hot ``torch.long`` ``c`` of shape
            ``(B, n_classes)``; its output must hold ``img_size * img_size``
            values per sample. The caller controls ``G.train()``/``G.eval()``.
        n_noise: length of each noise vector.
        n_classes: number of labels, i.e. grid rows.
        n_per_class: samples per label, i.e. grid columns.
        img_size: side length of each (square) sample.

    Returns:
        A NumPy array of shape ``(n_classes * img_size, n_per_class * img_size)``
        holding the raw generator values (no rescaling is applied).
    """
    # Sample on the device holding the generator's weights (CPU if it has none).
    first_param=next(G.parameters(),None)
    dev=first_param.device if first_param is not None else torch.device('cpu')
    n_samples=n_classes*n_per_class
    # Labels 0,0,...,1,1,...: sample p belongs to row p // n_per_class.
    labels=torch.arange(n_classes,device=dev).repeat_interleave(n_per_class)
    c=F.one_hot(labels,num_classes=n_classes)
    z=torch.randn(n_samples,n_noise,device=dev)
    with torch.no_grad():
        out=G(z,c)
    # (class, sample, row, col) -> (class, row, sample, col) -> 2-D grid.
    grid=(out.reshape(n_classes,n_per_class,img_size,img_size)
             .permute(0,2,1,3)
             .reshape(n_classes*img_size,n_per_class*img_size))
    return grid.cpu().numpy()

    

In [7]:
class Discriminator(nn.Module):
    
    def __init__(self,input_size=784,label_size=10,num_classes=1):
        super(Discriminator,self).__init__()
        self.layer1=nn.Sequential(
            nn.Linear(input_size+label_size,200),
            nn.ReLU(),
            nn.Dropout()
        )
        self.layer2=nn.Sequential(
            nn.Linear(200,200),
            nn.ReLU(),
            nn.Dropout()
        )
        self.layer3=nn.Sequential(
            nn.Linear(200,num_classes),
            nn.Sigmoid()
        )
    
    def forward(self,x,y):
        x,y=x.view(x.size(0),-1),y.view(y.size(0),-1).float()
        v=torch.cat((x,y),1)
        y_=self.layer1(v)
        y_=self.layer2(y_)
        y_=self.layer3(y_)
        return y_

In [8]:
class Generator(nn.Module):

    def __init__(self,input_size=100,label_size=10,num_classes=784):
        super(Generator,self).__init__()
        self.layer=nn.Sequential(
            nn.Linear(input_size+label_size,200),
            nn.LeakyReLU(0.2),
            nn.Linear(200,200),
            nn.LeakyReLU(0.2),
            nn.Linear(200,num_classes),
            nn.Tanh()
        )
    
    def forward(self,x,y):
        x,y=x.view(x.size(0),-1),y.view(y.size(0),-1).float()
        v=torch.cat((x,y),1)
        y_=self.layer(v)
        y_=y_.view(x.size(0),1,28,28)
        return y_

In [9]:
# Build both networks on `device`; the tuple is evaluated left to right, so D is still constructed before G.
D, G = to_cuda(Discriminator()), to_cuda(Generator())

In [19]:
# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [20]:
mnist = datasets.MNIST(root='./data/', train=True, download=True, transform=transform)  # MNIST training split under ./data/

In [21]:
# Mini-batch size and label-vector width (condition_size is not referenced in the visible cells; the models default to label_size=10).
batch_size, condition_size = 64, 10

In [22]:
# drop_last keeps every batch at exactly `batch_size` samples, the size of the constant D_labels/D_fakes targets.
data_loader = DataLoader(mnist, batch_size=batch_size, shuffle=True, drop_last=True)

In [23]:
# BCE on the discriminator's sigmoid scores; one Adam optimizer (library defaults) per network, D first.
criterion = nn.BCELoss()
D_opt, G_opt = [torch.optim.Adam(net.parameters()) for net in (D, G)]

In [33]:
# Training schedule: epochs, D updates per G update, and noise-vector length.
max_epoch, n_critic, n_noise = 50, 5, 100
step = 0  # global step counter, advanced by the training loop


In [34]:
# Constant BCE targets for a full batch: 1.0 means "real", 0.0 means "fake".
D_labels = to_cuda(torch.ones(batch_size, 1))
D_fakes = to_cuda(torch.zeros(batch_size, 1))

In [36]:
# Conditional-GAN training loop.
# D is updated on every batch (real -> 1, generated -> 0); G is updated once every
# `n_critic` steps with the non-saturating loss (generated -> 1).
# Note: `step` is only reset by the schedule cell, so re-running this cell alone
# keeps counting from where the previous run stopped.
G_loss=None  # defined up front so the progress print cannot hit a NameError before G's first update
for epoch in range(max_epoch):
    for idx,(images,labels) in enumerate(data_loader):
        step+=1
        cur_bs=images.size(0)  # actual batch size; the target tensors are sliced to match

        # ---- Discriminator update ----
        x=to_cuda(images)
        y=to_cuda(to_onehot(labels.view(-1,1)))
        x_outputs=D(x,y)
        D_x_loss=criterion(x_outputs,D_labels[:cur_bs])

        z=to_cuda(torch.randn(cur_bs,n_noise))
        # detach(): the D update must not back-propagate through G.
        z_outputs=D(G(z,y).detach(),y)
        D_z_loss=criterion(z_outputs,D_fakes[:cur_bs])
        D_loss=D_x_loss+D_z_loss

        D.zero_grad()
        D_loss.backward()
        D_opt.step()

        # ---- Generator update ----
        if step % n_critic==0:
            z=to_cuda(torch.randn(cur_bs,n_noise))
            z_outputs=D(G(z,y),y)
            G_loss=criterion(z_outputs,D_labels[:cur_bs])

            G.zero_grad()
            G_loss.backward()
            G_opt.step()

        if step%1000==0:
            g_loss_value=G_loss.item() if G_loss is not None else float('nan')
            print('Epoch:{}/{},step:{},D Loss:{},G Loss:{}'.format(epoch,max_epoch,step,D_loss.item(),g_loss_value))

    # Once per epoch (every 5th epoch, plus the final one), save a sample grid.
    # This used to sit inside the batch loop and rewrote the same file on every batch.
    if epoch%5==0 or epoch==max_epoch-1:
        G.eval()
        img=get_simple_image(G,n_noise=n_noise)
        # Map Tanh output [-1, 1] to uint8 [0, 255] ourselves instead of relying on
        # imageio's lossy per-image auto-conversion of float arrays.
        img=np.clip(np.rint((img+1)*127.5),0,255).astype(np.uint8)
        imageio.imwrite("{}_epoch_{}_type1.jpg".format(Model_Name,epoch),img)
        G.train()



Epoch:0/50,step:47000,D Loss:1.3647253513336182,G Loss:0.771939754486084




Epoch:1/50,step:48000,D Loss:1.2177232503890991,G Loss:0.8635991811752319
Epoch:2/50,step:49000,D Loss:1.1521919965744019,G Loss:1.0906529426574707
Epoch:3/50,step:50000,D Loss:1.2815239429473877,G Loss:0.8626819849014282
Epoch:4/50,step:51000,D Loss:1.2456295490264893,G Loss:0.9893711805343628




Epoch:5/50,step:52000,D Loss:1.2127619981765747,G Loss:0.9344995617866516




Epoch:6/50,step:53000,D Loss:1.3382790088653564,G Loss:0.8188465237617493
Epoch:7/50,step:54000,D Loss:1.2356550693511963,G Loss:0.7950323820114136
Epoch:8/50,step:55000,D Loss:1.2075514793395996,G Loss:0.889132022857666
Epoch:9/50,step:56000,D Loss:1.227764368057251,G Loss:0.9141058921813965




Epoch:10/50,step:57000,D Loss:1.2319614887237549,G Loss:0.7829486131668091




Epoch:11/50,step:58000,D Loss:1.2026115655899048,G Loss:0.7664190530776978
Epoch:12/50,step:59000,D Loss:1.1460590362548828,G Loss:0.9277381300926208
Epoch:14/50,step:60000,D Loss:1.1034929752349854,G Loss:1.2138373851776123




Epoch:15/50,step:61000,D Loss:1.2300864458084106,G Loss:0.9766071438789368




Epoch:16/50,step:62000,D Loss:1.21044921875,G Loss:0.8836027383804321
Epoch:17/50,step:63000,D Loss:1.2785398960113525,G Loss:0.971884548664093
Epoch:18/50,step:64000,D Loss:1.26224684715271,G Loss:0.9984557032585144
Epoch:19/50,step:65000,D Loss:1.2951937913894653,G Loss:0.8642131090164185




Epoch:20/50,step:66000,D Loss:1.281240701675415,G Loss:0.8072707056999207




Epoch:21/50,step:67000,D Loss:1.2553575038909912,G Loss:0.9531993269920349
Epoch:22/50,step:68000,D Loss:1.2519567012786865,G Loss:0.8864179849624634
Epoch:23/50,step:69000,D Loss:1.1822668313980103,G Loss:0.9336001873016357
Epoch:24/50,step:70000,D Loss:1.1838829517364502,G Loss:1.0819480419158936




Epoch:25/50,step:71000,D Loss:1.2078381776809692,G Loss:0.9553574323654175




Epoch:26/50,step:72000,D Loss:1.1896461248397827,G Loss:0.9759012460708618
Epoch:27/50,step:73000,D Loss:1.189794898033142,G Loss:1.0154601335525513
Epoch:28/50,step:74000,D Loss:1.1271018981933594,G Loss:0.9355926513671875




Epoch:30/50,step:75000,D Loss:1.2730956077575684,G Loss:0.9699745178222656




Epoch:31/50,step:76000,D Loss:1.1653053760528564,G Loss:0.9049778580665588
Epoch:32/50,step:77000,D Loss:1.1907835006713867,G Loss:0.9782962799072266
Epoch:33/50,step:78000,D Loss:1.2393507957458496,G Loss:0.8867018222808838
Epoch:34/50,step:79000,D Loss:1.1724319458007812,G Loss:1.0954997539520264




Epoch:35/50,step:80000,D Loss:1.2534093856811523,G Loss:1.1066722869873047




Epoch:36/50,step:81000,D Loss:1.1462063789367676,G Loss:0.978063702583313
Epoch:37/50,step:82000,D Loss:1.2466256618499756,G Loss:1.0273524522781372
Epoch:38/50,step:83000,D Loss:1.4348032474517822,G Loss:0.9771679639816284
Epoch:39/50,step:84000,D Loss:1.3304972648620605,G Loss:1.0896868705749512




Epoch:40/50,step:85000,D Loss:1.1237716674804688,G Loss:0.9892462491989136




Epoch:41/50,step:86000,D Loss:1.141232967376709,G Loss:0.891743540763855
Epoch:42/50,step:87000,D Loss:1.28325617313385,G Loss:0.8048086762428284
Epoch:43/50,step:88000,D Loss:1.1756000518798828,G Loss:0.9630880951881409
Epoch:44/50,step:89000,D Loss:1.0096538066864014,G Loss:1.1114394664764404




Epoch:46/50,step:90000,D Loss:1.1958842277526855,G Loss:1.000011920928955
Epoch:47/50,step:91000,D Loss:1.1787924766540527,G Loss:0.9955776929855347
Epoch:48/50,step:92000,D Loss:1.238279104232788,G Loss:0.9955602884292603
Epoch:49/50,step:93000,D Loss:1.1818616390228271,G Loss:0.8477843999862671
