<a href="https://colab.research.google.com/github/bgeunjo/cGAN/blob/master/cGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torch import nn,cuda,optim,cat
from torchvision import transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
from os import path
#from google.colab import drive
#
#notebooks_dir_name = 'notebooks'
#drive.mount('/content/gdrive')
#notebooks_base_dir = path.join('./gdrive/My Drive/', notebooks_dir_name)
#if not path.exists(notebooks_base_dir):
#  print('Check your google drive directory. See you file explorer')
# Settings

download_root='mnist'
stored_path='images'
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,),std=(0.5,))
])
device= 'cuda' if cuda.is_available() else 'cpu'

# Params setting

leraing_rate=0.0001
batch_size=100


# Dataset
train_set=MNIST(download_root,train=True,transform=transform,download=True)

# Dataloader
train_loader=DataLoader(train_set,batch_size=batch_size,shuffle=True)

# Image_dir
import os
import imageio

if not os.path.isdir(stored_path):
    os.makedirs(stored_path,exist_ok=True)

# Model
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        #self.label=nn.Embedding(10,)
        def gen_block(in_features,out_features):
            layers=[nn.Linear(in_features,out_features)]
            layers.append(nn.ReLU())
            layers.append(nn.Dropout())
            return layers
        self.generator=nn.Sequential(
            *gen_block(110,128),
            *gen_block(128,256),
            *gen_block(256,512),
            *gen_block(512,1024),
            nn.Linear(1024,784),
            nn.Tanh()
        )
    def forward(self,z,label):
        z=cat([z,label],1)
        z=self.generator(z)
        return z

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        def disc_block(in_features,out_features):
            layers=[nn.Linear(in_features,out_features)]
            layers.append(nn.ReLU())
            layers.append(nn.Dropout())
            return layers
        self.discriminator=nn.Sequential(
            *disc_block(794,1024),
            *disc_block(1024,512),
            *disc_block(512,256),
            nn.Linear(256,1),
            nn.Sigmoid()   
        )
    def forward(self,x,label):
        x=x.view(x.size(0),-1)
        x=cat([x,label],1)
        x=self.discriminator(x)
        return x

Gen=Generator().to(device)
Discrim=Discriminator().to(device)

# Loss & Optim
criterion=nn.BCELoss()

G_optimizer=optim.Adam(Gen.parameters(),lr=leraing_rate)
D_optimizer=optim.Adam(Discrim.parameters(),lr=leraing_rate)

# One_Hot

def oneHot(label,len_label=10): # label : 0 ~ 9
    # fills [100,10] with 0, only 1 at label ex ) [0,0,0,1,0,0,0,0,0,0]
    one_hot=torch.zeros(label.size(0),len_label,dtype=torch.int64).to(device)
    one_hot=one_hot.scatter(1,label.unsqueeze(1),1)
    return Variable(one_hot).to(device)

# Train
def train(epoch):
    for batch_idx,(data,target) in enumerate(train_loader):
        with torch.autograd.set_detect_anomaly(True):
            batch_size=data.size(0)
            
            fake_correct=Variable(torch.zeros(batch_size,1)).to(device)
            real_correct=Variable(torch.ones(batch_size,1)).to(device)
            z=torch.randn(batch_size, 100,device=device)
            z_label=torch.randint(10,(batch_size,),dtype=torch.int64).to(device)
            z_one_hot=oneHot(z_label,len_label=10)
            data,target=Variable(data).to(device),Variable(target).to(device)

            one_hot=oneHot(target.to(device),len_label=10) 

            # Gen 학습
            gen_img=Gen(z,z_one_hot)
            G_optimizer.zero_grad()
            G_loss=criterion(Discrim(gen_img,z_one_hot),real_correct)
            G_loss.backward()
            G_optimizer.step()
            # Discrim 학습
            # 진짜 이미지를 진짜로 판별할 수 있게 학습
            real_output=Discrim(data,one_hot)
            D_real_loss=criterion(real_output,real_correct)

            # 가짜 이미지를 가짜로 판별할 수 있게 학습
            fake_output=Discrim(gen_img.detach().to(device),one_hot) # Gen은 이미 학습해서 다시 학습 안 시키게 detach()
            D_optimizer.zero_grad()
            D_fake_loss=criterion(fake_output,fake_correct)
            D_loss=(D_real_loss+D_fake_loss)/2
            D_loss.backward()
            D_optimizer.step()

            batch_finish=epoch * len(train_loader) + batch_idx
            if (batch_finish) % 400 == 0:
                print("[Epoch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, 200, D_loss.item(), G_loss.item())
                )
    if (epoch+1) % 10 == 0:
        z=torch.randn(100, 100,device=device)
        z_label=torch.zeros(100,dtype=torch.int64).to(device) 
        for i in range(100):
            z_label[i] = i%10
        z_one_hot=oneHot(z_label,len_label=10)
        gen_img = Gen(z, z_one_hot)
        gen_img = gen_img.reshape([batch_size, 1, 28, 28])
        img_grid = make_grid(gen_img, nrow=10, normalize=True)
        save_image(img_grid, "images/result_%d.png"%(epoch+1)) 
if __name__ == "__main__":
    for epoch in range(200):
        train(epoch)
    images=[]
    for file_name in os.listdir(stored_path):
        images.append(imageio.imread(file_name))
        imageio.mimsave('result.gif',images)

[Epoch 0/200] [D loss: 0.679655] [G loss: 0.665070]
[Epoch 0/200] [D loss: 0.051774] [G loss: 6.450368]
[Epoch 1/200] [D loss: 0.013593] [G loss: 4.170480]
[Epoch 2/200] [D loss: 0.010070] [G loss: 5.511332]
[Epoch 2/200] [D loss: 0.022844] [G loss: 5.276737]
[Epoch 3/200] [D loss: 0.062123] [G loss: 3.883190]
[Epoch 4/200] [D loss: 0.017367] [G loss: 4.138217]
[Epoch 4/200] [D loss: 0.045535] [G loss: 3.781675]
[Epoch 5/200] [D loss: 0.053672] [G loss: 4.688679]
[Epoch 6/200] [D loss: 0.129361] [G loss: 3.425082]
[Epoch 6/200] [D loss: 0.163527] [G loss: 3.822104]
[Epoch 7/200] [D loss: 0.234814] [G loss: 3.883720]
[Epoch 8/200] [D loss: 0.197073] [G loss: 2.999124]
[Epoch 8/200] [D loss: 0.258423] [G loss: 2.783215]
[Epoch 9/200] [D loss: 0.166576] [G loss: 2.953043]
[Epoch 10/200] [D loss: 0.178604] [G loss: 2.902812]
[Epoch 10/200] [D loss: 0.172020] [G loss: 2.295790]
[Epoch 11/200] [D loss: 0.367456] [G loss: 2.311677]
[Epoch 12/200] [D loss: 0.284416] [G loss: 2.242724]
[Epoch 1

FileNotFoundError: ignored

In [15]:
for file_name in os.listdir('images'):
  print (1)
  images.append(imageio.imread(file_name))
  imageio.mimsave('result.gif',images)

1


FileNotFoundError: ignored

In [None]:
"h"