In [14]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import os
from collections import OrderedDict
import torchvision.utils as v_utils
import itertools
import FashionMnist as FM
from skimage import io
import imageio

In [2]:
# Build the FashionMNIST training set (images converted with ToTensor) and a
# shuffling loader over it. drop_last=True guarantees every batch holds exactly
# 128 samples; the training loop's default batch_size is also 128.
data_set = FM.FashionMNIST(
    './FashionMnist_data/raw',
    train=True,
    transform=transforms.ToTensor(),
    target_transform=None,
)

data = torch.utils.data.DataLoader(
    data_set,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

# Number of full batches per epoch.
print(len(data))

468


In [3]:
def deconv(ch_in,ch_out,k_size=4,stride=2,pad=1,bn=True,momentum=0.9):
    """Transposed-convolution (upsampling) unit, optionally batch-normalised.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        k_size: kernel size of the transposed convolution.
        stride: stride; with the defaults (k=4, s=2, p=1) the spatial size doubles.
        pad: padding passed to ``nn.ConvTranspose2d``.
        bn: when True, append ``nn.BatchNorm2d(ch_out)``.
        momentum: BatchNorm momentum. NOTE(review): in PyTorch this is the weight
            of the *current* batch statistic (PyTorch's own default is 0.1), so
            0.9 tracks almost only the latest batch; confirm this is intended.

    Returns:
        ``nn.Sequential`` holding the transposed conv and, if ``bn``, the BatchNorm.
    """
    upsample = nn.ConvTranspose2d(ch_in, ch_out, k_size, stride, padding=pad)
    if not bn:
        return nn.Sequential(upsample)
    norm = nn.BatchNorm2d(ch_out, momentum=momentum)
    return nn.Sequential(upsample, norm)

def conv(ch_in, ch_out, k_size=4, stride=2, pad=1, bn=True,momentum=0.9):
    """Strided convolution (downsampling) unit, optionally batch-normalised.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        k_size: kernel size.
        stride: stride; with the defaults (k=4, s=2, p=1) a side of length H
            becomes floor(H / 2).
        pad: padding passed to ``nn.Conv2d``.
        bn: when True, append ``nn.BatchNorm2d(ch_out)``.
        momentum: BatchNorm momentum. NOTE(review): PyTorch weights the current
            batch statistic by this value (default there is 0.1); 0.9 is unusual.

    Returns:
        ``nn.Sequential`` holding the conv and, if ``bn``, the BatchNorm.
    """
    downsample = nn.Conv2d(ch_in, ch_out, k_size, stride, padding=pad)
    if not bn:
        return nn.Sequential(downsample)
    norm = nn.BatchNorm2d(ch_out, momentum=momentum)
    return nn.Sequential(downsample, norm)

def fullycon(input_size,out_size,bn=True,momentum=0.9):
    """Fully connected unit, optionally batch-normalised.

    Args:
        input_size: number of input features.
        out_size: number of output features.
        bn: when True, append a batch-norm layer over the ``out_size`` features.
        momentum: BatchNorm momentum. NOTE(review): PyTorch weights the current
            batch statistic by this value (default there is 0.1); 0.9 is unusual.

    Returns:
        ``nn.Sequential`` holding the Linear layer and, if ``bn``, the BatchNorm.
    """
    layers = [nn.Linear(input_size, out_size)]
    if bn:
        # nn.Linear produces (batch, features). BatchNorm1d is the normaliser for
        # that layout; BatchNorm2d requires 4-D (N, C, H, W) input and rejects it.
        layers.append(nn.BatchNorm1d(out_size, momentum=momentum))
    return nn.Sequential(*layers)

def normal_init(m,mean,std):
    """Initialise a (transposed) convolution layer in place.

    Weights are drawn from N(mean, std) and the bias, when present, is zeroed.
    Modules of any other type are left untouched.

    Args:
        m: module to initialise.
        mean: mean of the normal distribution for the weights.
        std: standard deviation of the normal distribution for the weights.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        # Layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            m.bias.data.zero_()
        
def bn_init(m):
    """Reset a batch-norm layer's affine parameters in place (scale 1, shift 0).

    Handles both BatchNorm1d and BatchNorm2d; other module types and batch-norm
    layers created with ``affine=False`` (which have no parameters) are ignored.

    Args:
        m: module to initialise.
    """
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)) and m.weight is not None:
        # fill_ is the in-place fill; Tensor has no plain `fill` method.
        m.weight.data.fill_(1)
        m.bias.data.zero_()
        

In [4]:
class G(nn.Module):
    """InfoGAN generator: maps noise ``z`` plus code ``c`` to an image in (-1, 1).

    Pipeline: concat(z, c) -> fc1 -> fc2 -> reshape to
    ``(conv_dim*8, input_size/4, input_size/4)`` -> two stride-2 transposed
    convolutions (each doubles the spatial size with the ``deconv`` defaults)
    -> tanh.

    Args:
        dim_z: size of the noise vector ``z``.
        dim_c: size of the code vector ``c``.
        input_size: side length of the generated square image; presumably a
            multiple of 4 so the two upsampling steps land exactly on it.
        colch: number of output image channels.
        conv_dim: base channel width of the convolutional stages.
    """
    def __init__(self,dim_z=100, dim_c=10, input_size=28, colch=1,conv_dim=16):
        super(G,self).__init__()

        self.dim_z = dim_z
        self.dim_c = dim_c
        self.input_size = input_size
        self.colch = colch
        self.colcn = colch  # legacy misspelled alias of `colch`, kept for existing readers
        self.conv_dim = conv_dim
        # Spatial size before upsampling; two x2 transposed convs restore input_size.
        self.start_size = int(input_size/4)

        self.fc1 = fullycon(dim_c+dim_z,1024)
        self.fc2 = fullycon(1024,conv_dim*8*(self.start_size**2))
        self.deconv1 = deconv(conv_dim*8,conv_dim*2)
        # Output channel count follows `colch` so colour images are supported.
        self.deconv2 = deconv(conv_dim*2,colch,bn=False)
        self.weight_init()

    def weight_init(self,mean=0,std=0.02):
        """Initialise all submodules in place.

        (Transposed) conv weights ~ N(mean, std) with zero bias; batch-norm
        scale 1 and shift 0. Linear layers keep PyTorch's default initialisation.
        """
        # self.modules() recurses into the nn.Sequential wrappers; iterating
        # self._modules would only yield the top-level attribute *names*.
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
                m.weight.data.normal_(mean, std)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)) and m.weight is not None:
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self,z,c):
        """Generate images from noise and code.

        Args:
            z: noise tensor of shape (batch, dim_z).
            c: code tensor of shape (batch, dim_c); concatenated with ``z`` on dim 1.

        Returns:
            Tensor of shape (batch, colch, 4*start_size, 4*start_size), values in (-1, 1).
        """
        zc = Variable(torch.cat([z,c],1))
        # Follow the device the module lives on instead of assuming CUDA, so the
        # generator also runs on CPU-only machines.
        if next(self.parameters()).is_cuda:
            zc = zc.cuda()
        out = F.relu(self.fc1(zc))
        out = F.relu(self.fc2(out))
        out = out.view(-1,self.conv_dim*8,
                       self.start_size,self.start_size)

        out = F.relu(self.deconv1(out))
        out = torch.tanh(self.deconv2(out))
        return out

In [5]:
class D(nn.Module):
    """InfoGAN discriminator with an auxiliary code head.

    Two stride-2 convolutions reduce the spatial size by 4 overall, followed by
    two fully connected layers producing ``1 + dim_c`` outputs: column 0 is the
    real/fake score, the remaining ``dim_c`` columns are raw logits for the code.

    Args:
        input_size: side length of the square input image (presumably a multiple of 4).
        colch: number of input image channels.
        conv_dim: base channel width of the convolutional stages.
        leaky: negative slope of the LeakyReLU activations.
        dim_c: size of the code predicted by the auxiliary head (default 10,
            i.e. 11 outputs in total).
    """
    def __init__(self,input_size=28,colch=1,conv_dim=16,leaky=0.2,dim_c=10):
        super(D,self).__init__()

        self.input_size = input_size
        self.colch = colch
        # Must match the width actually used by the layers below; forward()'s
        # reshape depends on it.
        self.conv_dim = conv_dim
        self.leaky = leaky
        self.dim_c = dim_c

        # [B, colch, S, S] -> [B, conv_dim*4, S/2, S/2]  (64 channels at defaults)
        self.conv1 = conv(colch,conv_dim*4,4,2,bn=True)
        # [B, conv_dim*4, S/2, S/2] -> [B, conv_dim*8, S/4, S/4]  (128 channels at defaults)
        self.conv2 = conv(conv_dim*4,conv_dim*8,4,2,bn=True)
        self.fc1 = fullycon(conv_dim*8*int(input_size/4)**2,1024)
        # One real/fake output plus dim_c code logits.
        self.fc2 = fullycon(1024,1 + dim_c,bn=False)
        self.weight_init()

    def weight_init(self,mean=0,std=0.02):
        """Initialise all submodules in place.

        (Transposed) conv weights ~ N(mean, std) with zero bias; batch-norm
        scale 1 and shift 0. Linear layers keep PyTorch's default initialisation.
        """
        # self.modules() recurses into the nn.Sequential wrappers; iterating
        # self._modules would only yield the top-level attribute *names*.
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
                m.weight.data.normal_(mean, std)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)) and m.weight is not None:
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self,x):
        """Score images and predict their codes.

        Args:
            x: image batch of shape (batch, colch, input_size, input_size).

        Returns:
            (prob_real, latent): ``prob_real`` is the sigmoid of column 0, shape
            (batch,); ``latent`` holds the remaining raw (pre-softmax) code
            logits, shape (batch, dim_c).
        """
        out = F.leaky_relu(self.conv1(x),self.leaky)
        out = F.leaky_relu(self.conv2(out),self.leaky)
        out = out.view(-1,self.conv_dim*8*int(self.input_size/4)**2)
        out = F.leaky_relu(self.fc1(out),self.leaky)
        out = self.fc2(out)
        # Everything after column 0 is the code head.
        latent = out[:, 1:]
        out = torch.sigmoid(out[:, 0])
        return out, latent

In [6]:
class InfoGan():
    """Builds a generator ``G`` and discriminator ``D`` and trains them as an InfoGAN.

    The code ``c`` is a one-hot categorical vector of size ``dim_c``. The
    discriminator's auxiliary head is trained, together with ``G``, to recover
    it from generated images via cross-entropy.

    Args:
        input_size: side length of the square images.
        dim_z: size of the noise vector.
        dim_c: size of the categorical code. NOTE(review): D is constructed with
            its own default code-head size, so keep this consistent with D.
        colch: number of image channels.
        conv_dim: base channel width passed to G and D.
    """
    def __init__(self,input_size=28,dim_z=100,dim_c=10,colch=1,conv_dim=16):
        self.input_size = input_size
        self.dim_z = dim_z
        self.dim_c = dim_c
        self.colch = colch
        self.conv_dim = conv_dim
        # Train on the GPU when one is present instead of unconditionally
        # calling .cuda().
        self.use_cuda = torch.cuda.is_available()

        self.G = None
        self.D = None
        self.G_optim = None
        self.D_optim = None
        self.Info_optim = None
        self.G_losses = []
        self.D_losses = []
        self.Info_losses = []

        self.build_model()

    def build_model(self):
        """Instantiate G and D with this object's hyper-parameters (on the GPU if available)."""
        self.G = G(self.dim_z,self.dim_c,self.input_size,self.colch,self.conv_dim)
        self.D = D(self.input_size,self.colch,self.conv_dim)
        if self.use_cuda:
            self.G = self.G.cuda()
            self.D = self.D.cuda()

    def _to_device(self, x):
        """Move ``x`` to the GPU when the model was built on CUDA; otherwise return it unchanged."""
        return x.cuda() if self.use_cuda else x

    def train(self, data, batch_size=128, epoch=10, lr=0.0002,check_step=100,
             view_size=100):
        """Train G and D with the InfoGAN objective.

        Each iteration performs three updates:
          1. D: BCE on the real batch (target 1) and a detached fake batch (target 0).
          2. G: BCE pushing D(G(z, c)) toward 1.
          3. Info: cross-entropy between D's code head on G(z, c) and the sampled
             code index; ``Info_optim`` updates both G and D.

        Every ``check_step`` iterations the losses are recorded, progress is
        printed and a grid generated from fixed noise is written to
        ``./infogan_result/gen_{epoch}_{step}.png``.

        Args:
            data: iterable yielding ``(images, labels)`` batches; labels are unused.
            batch_size: expected batch size. Kept for backward compatibility; the
                size of each incoming batch is what is actually used for noise
                and targets, so a short final batch no longer breaks the losses.
            epoch: number of passes over ``data``.
            lr: Adam learning rate for all three optimizers.
            check_step: logging / sampling interval, in iterations.
            view_size: number of images in each saved sample grid.
        """
        batch_length = len(data)

        self.G_optim = torch.optim.Adam(self.G.parameters(), lr=lr)
        self.D_optim = torch.optim.Adam(self.D.parameters(), lr=lr)
        self.Info_optim = torch.optim.Adam(
            itertools.chain(self.G.parameters(), self.D.parameters()), lr=lr)

        loss_func = nn.BCELoss()
        loss_CE = nn.CrossEntropyLoss()
        counter = 0
        self.fixed_z, self.fixed_c = self.test_sample(view_size,self.dim_z,self.dim_c)

        if not os.path.isdir("./infogan_result"):
            os.mkdir("./infogan_result")

        for e in range(epoch):
            for i, (batch_x, _) in enumerate(data):
                n = batch_x.size(0)
                batch_x = self._to_device(Variable(batch_x))
                self.batch_x = batch_x
                real_target = self._to_device(Variable(torch.ones(n)))
                fake_target = self._to_device(Variable(torch.zeros(n)))

                # ---- update D ----
                z, c = self.noise_sample(n, self.dim_z, self.dim_c)
                self.D_optim.zero_grad()
                # detach(): the discriminator step must not backpropagate into G.
                G_fake = self.denorm(self.G(z, c)).detach()
                D_fake, _ = self.D(G_fake)
                D_real, _ = self.D(batch_x)
                # BCELoss already reduces to a scalar mean; no extra sum needed.
                D_loss = loss_func(D_fake, fake_target) + loss_func(D_real, real_target)
                D_loss.backward()
                self.D_optim.step()

                # ---- update G ----
                z, c = self.noise_sample(n, self.dim_z, self.dim_c)
                self.G_optim.zero_grad()
                G_fake = self.denorm(self.G(z, c))
                D_fake, _ = self.D(G_fake)
                G_loss = loss_func(D_fake, real_target)
                G_loss.backward()
                self.G_optim.step()

                # ---- update info (code-recovery) term ----
                # Re-run the forward pass: G_optim.step() modified G's parameters in
                # place, which invalidates the graph built for G_loss (recent
                # PyTorch versions raise when backpropagating through it).
                self.Info_optim.zero_grad()
                G_fake = self.denorm(self.G(z, c))
                _, L_fake = self.D(G_fake)
                self.L = L_fake.detach()
                self.c = c
                c_target = self._to_device(Variable(torch.max(c, 1)[1]))
                Info_loss = loss_CE(L_fake, c_target)
                Info_loss.backward()
                self.Info_optim.step()

                counter += 1

                if counter % check_step == 0:
                    self.D_losses.append(D_loss.cpu().data.numpy())
                    self.G_losses.append(G_loss.cpu().data.numpy())
                    self.Info_losses.append(Info_loss.cpu().data.numpy())
                    print("Epoch [%d/%d], Step [%d/%d], D_loss : %.4f, G_loss : %.4f" %
                          (e+1,epoch,i+1,batch_length,D_loss.cpu().data.numpy(),G_loss.cpu().data.numpy()))
                    view = self.denorm(self.G(self.fixed_z, self.fixed_c))
                    v_utils.save_image(view.data,"./infogan_result/gen_{}_{}.png".format(e,i),nrow=10)

    def noise_sample(self,batch_size,dim_z,dim_c):
        """Sample a batch of noise vectors and random one-hot codes.

        Returns:
            z: FloatTensor (batch_size, dim_z), drawn uniformly between -1 and 1.
            c: FloatTensor (batch_size, dim_c), one-hot; each row's class is
               drawn uniformly from ``range(dim_c)``.
        """
        idx = np.random.randint(dim_c, size=batch_size)
        c = np.zeros([batch_size, dim_c], dtype=np.float32)
        c[np.arange(batch_size), idx] = 1.
        c = torch.from_numpy(c)

        z = torch.Tensor(batch_size, dim_z).uniform_(-1, 1)
        return z, c

    def test_sample(self,view_size=100,dim_z=10,dim_c=10):
        """Fixed noise and codes for the visualisation grids.

        Row ``i`` gets class ``i // dim_c``, so with ``view_size == dim_c**2``
        and ``nrow == dim_c`` each grid row shows a single class. Rows where
        ``i // dim_c >= dim_c`` receive an all-zero code.

        Returns:
            z: FloatTensor (view_size, dim_z), drawn uniformly between -1 and 1.
            c: FloatTensor (view_size, dim_c).
        """
        c = np.zeros([view_size, dim_c], dtype=np.float32)
        if dim_c > 0:
            rows = np.arange(view_size)
            cls = rows // dim_c
            valid = cls < dim_c
            c[rows[valid], cls[valid]] = 1.
        c = torch.from_numpy(c)

        z = torch.Tensor(view_size, dim_z).uniform_(-1, 1)
        return z, c

    def denorm(self, x):
        """Convert range (-1, 1) to (0, 1)"""
        out = (x + 1) / 2
        return out.clamp(0, 1)

In [7]:
infogan = InfoGan()

# Pass the loader's batch size explicitly so the noise/target tensors built in
# train() always match the batches the DataLoader actually yields.
infogan.train(data, batch_size=data.batch_size, epoch=50, lr=0.001)

# Make an animation of the generated samples (애니메이션 만들기)

In [17]:
# Stitch the sample grids saved during training into an animated GIF.
result_dir = './infogan_result'
frame_stride = 5  # keep every 5th saved grid
gif_fps = 5

# Load the saved sample grids. ImageCollection reads files lazily; frames are
# ordered by its filename sort. NOTE(review): confirm the installed
# scikit-image uses natural ordering so gen_10_* comes after gen_2_*.
result_load = io.ImageCollection(result_dir + '/*.png')

# Index only the frames we keep so the skipped images are never decoded.
images = [result_load[k] for k in range(0, len(result_load), frame_stride)]

# Save the selected frames as a GIF.
imageio.mimsave('generation_animation.gif', images, fps=gif_fps)
