In [4]:
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
plt.ion()
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity='all'
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import time
from graphviz import Digraph
import torchvision
import scipy.misc as misc
from tensorflow.examples.tutorials.mnist import input_data
import os
import sys

  from ._conv import register_converters as _register_converters


In [5]:
def display(g,img,x,meta):
    def tensor2array(x):
        return np.array(x)
    def variable2array(x):
        return np.array(x.data)
    def tensors2array(xs):
        return [tensor2array(_) for _ in xs]
    def variables2array(xs):
        return [variable2array(_) for _ in xs]
    print(meta)
    g,img,x=variables2array([g,img,x])
    g=g.reshape((10,10,4))
    parsed_g=np.unravel_index(np.argmax(g),g.shape)
    print(parsed_g)
    plt.figure()
    plt.subplot(121)
    plt.title('Correct Image')
    plt.imshow(img.reshape((56,56)),cmap='gray')
    plt.subplot(122)
    plt.title('Generated Image')
    plt.imshow(x.reshape((56,56)),cmap='gray')
    
#     print("real d_output:{}, fake d_output:{}".format(d_r, d_f))

In [6]:

class doodleGenerator:
    def __init__(self):
        self.edges=['u','r','d','l']
        self.MNIST= input_data.read_data_sets("../MNIST_data/", one_hot=True, reshape=[])
        self.MNIST_TRAIN_NUM = self.MNIST.train.num_examples
        self.MNIST_TRAIN_SET = self.MNIST.train.images, self.MNIST.train.labels, np.argmax(self.MNIST.train.labels, axis=1)
        
    def getRecordInfo(self,ind):
        return self.MNIST_TRAIN_SET[0][ind].reshape((28, 28)), self.MNIST_TRAIN_SET[1][ind], self.MNIST_TRAIN_SET[2][ind]

    def fetchRandomPair(self,shape=(10,10,4)):
        g = np.zeros(shape=shape)
        ind1, ind2 = np.random.randint(0, self.MNIST_TRAIN_NUM), np.random.randint(0, self.MNIST_TRAIN_NUM)
        pos = np.random.randint(0, 4)
        img1, label1, num1 = self.getRecordInfo(ind1)
        img2, label2, num2 = self.getRecordInfo(ind2)
        g[num1, num2, pos] = 1
        plt.subplot(131)
        plt.imshow(img1, cmap='gray')
        plt.subplot(132)
        plt.imshow(img2, cmap='gray')
        plt.subplot(133)
        newImg = {
            0: np.concatenate([np.zeros((56, 14)), np.concatenate([img2, img1], axis=0), np.zeros((56, 14))], axis=1),
            1: np.concatenate([np.zeros((14, 56)), np.concatenate([img1, img2], axis=1), np.zeros((14, 56))], axis=0),
            2: np.concatenate([np.zeros((56, 14)), np.concatenate([img1, img2], axis=0), np.zeros((56, 14))], axis=1),
            3: np.concatenate([np.zeros((14, 56)), np.concatenate([img2, img1], axis=1), np.zeros((14, 56))], axis=0)
        }.get(pos)
        return g, newImg


    def createDataSet(self,num=100000):
        G = []
        I = []
        for i in range(num):
            g, img = self.fetchRandomPair()
            G.append(g)
            I.append(img)
            if i%10000==0:
                print("{}/100000".format(i))

        I = np.array(I).reshape((num, 56, 56, 1))
        G = np.array(G)
        np.save('G' + str(num), G)
        np.save('I' + str(num), I)
        self.TRAIN_SET=[G,I]
        return G, I

In [7]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.g_conv1=nn.Conv2d(4,8,3,padding=1)
        self.z_conv1=nn.Conv2d(16,12,3,padding=1)
        self.z_pool1=nn.MaxPool2d(2,2)
        self.z_conv2=nn.Conv2d(12,8,7,stride=1)
        self.m_fc=nn.Linear(10*10*16,56*56*1)
        self.bn_g=nn.modules.BatchNorm2d(8)
        self.bn_z=nn.modules.BatchNorm2d(12)

    def forward(self, g,z):
        g=g.view(-1,4,10,10)
        z=z.view(-1,16,32,32)
#         print(g.shape,z.shape)
        gc1 = self.g_conv1(g)

        gc1=F.relu(gc1)
#         print('gc1.shape:',gc1.shape)
        gc1=self.bn_g(gc1)
        zc1 = self.z_pool1(self.z_conv1(z))
        zc1=F.relu(zc1)
#         print('zc1.shape',zc1.shape)        
        zc1=self.bn_z(zc1)
        zc2 = self.z_conv2(zc1)
        zc2=F.relu(zc2)
#         print('zc2.shape',zc2.shape)        
        merged = torch.cat([gc1, zc2], 1)
#         print('merged.shape',merged.shape)
        merged_r=merged.view(-1,16*10*10)
        o = self.m_fc(merged_r)
#         print('output.shape',o.shape)
        return o.view((-1,1,56, 56))

In [8]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.g_conv1 = nn.Conv2d(4, 8, 3, padding=1)
        self.x_conv1 = nn.Conv2d(1, 16, 3, padding=0)
        self.x_pool1 = nn.MaxPool2d(2, 2)
        self.x_conv2 = nn.Conv2d(16, 64, 8)

        self.m_conv = nn.Conv2d(72, 128, 3)
        self.m_fc1 = nn.Linear(1 *128*8*8, 2048)
        self.m_fc2 = nn.Linear(2048, 256)
        self.m_fc3 = nn.Linear(256, 2)
        self.d_softmax=nn.Softmax()
        
        self.bn_g=nn.modules.BatchNorm2d(num_features=8)
        self.bn_x=nn.modules.BatchNorm2d(num_features=16)

    def forward(self, g,x):
        g=g.view(-1,4,10,10)
        x=x.view(-1,1,56,56)
        gc1 = self.g_conv1(g)
        gc1=F.relu(gc1)
        gc1=self.bn_g(gc1)
#         print('gc1.shape:',gc1.shape)
        xc1=self.x_conv1(x)
#         print('xc1.shape:',xc1.shape)
        xp1=F.relu(self.x_pool1(xc1))
        xp1=self.bn_x(xp1)
#         print('xp1.shape:',xp1.shape)
        xc2=self.x_conv2(xp1)
#         print('xc2.shape:',xc2.shape)
        xp2=F.relu(self.x_pool1(xc2))
#         print('xp2.shape:',xp2.shape)
        merged = torch.cat([xp2, gc1], 1)
#         print('merged.shape:',merged.shape)
        mc = self.m_conv(merged)
#         print('mc.shape:',mc.shape)
        mc=mc.view(-1,128*8*8)
        m_fc1=self.m_fc1(mc)
#         print('m_fc1.shape',m_fc1.shape)
        m_fc2=self.m_fc2(m_fc1)
#         print('m_fc2.shape',m_fc2.shape)
        m_fc3 = self.m_fc3(m_fc2)
#         print('m_fc3.shape',m_fc3.shape)
        return self.d_softmax(m_fc3)

In [10]:
class cDCGAN(nn.Module):
    def __init__(self):
        super(cDCGAN, self).__init__()
        self.G=Generator()
        self.D=Discriminator()
        self.G_LOSS=nn.CrossEntropyLoss()
        self.rD_LOSS=nn.CrossEntropyLoss()
        self.fD_LOSS=nn.CrossEntropyLoss()
        self.use_gpu=torch.cuda.is_available()
        self.lr=0.001
        self.batch_size=100
        self.iters=1000
        self.epoch=200
        
        self.SAVE_DIR='./out/2/'
        try:
            os.makedirs(self.SAVE_DIR)
        except:
            pass

    def setTrainParas(self,batch_num,lr,iters,epoch):
        self.lr,self.batch_num,self.iters,self.epoch=lr,batch_num,iters,epoch

    def feedData(self,GI,ratio=0.8):
        G,I=GI
        num_of_records=G.shape[0]
        
        self.dataLoader={
            'train':[],
            'test':[]
        }
        batch_num=num_of_records//self.batch_size
        print("{} records, {} batches of size {} each".format(num_of_records,batch_num,self.batch_size))
        cur_batch_G=[]
        cur_batch_I = []
        batched_data=[]
        for i in range(num_of_records):
            cur_batch_G.append(G[i])
            cur_batch_I.append(I[i])
            if i %self.batch_size==self.batch_size-1:
                batched_data.append((np.array(cur_batch_G),np.array(cur_batch_I)))
                cur_batch_G=[]
                cur_batch_I=[]
        self.dataLoader['train']=batched_data[:int(batch_num*0.8)]
        self.dataLoader['test']=batched_data[int(batch_num*0.8):]
        print(self.dataLoader['train'][0][0][0].shape,self.dataLoader['train'][0][0][1].shape)
    def convert2Cuda(self,l):
        return [_.cuda() for _ in l]
            
    def trainNetwork(self):
        G_optimizer=optim.SGD(self.G.parameters(),lr=self.lr,momentum=0.9)
        D_optimizer=optim.SGD(self.D.parameters(),lr=self.lr,momentum=0.9)

        for  e in range(self.epoch):
            G_losses, D_losses = [], []
            eporch_starttime=time.time()
            print('Epoch {}/{}'.format(e+1,self.epoch))
            print("-"*10)
            for phase in ['train','test']:
                print('Current Phase:{}'.format(phase))

                if phase=='train':
                    G_optimizer.step()
                    D_optimizer.step()
                    self.D.train(True)
                    self.G.train(True)
                else:
                    self.D.train(False)
                    self.G.train(False)

                for i, data in enumerate(self.dataLoader[phase],0):
                    z=torch.randn((self.batch_size,32,32,16))
                
                    g,img=data
                    G_optimizer.zero_grad()
                    D_optimizer.zero_grad()
                    g,img,z=Variable(torch.Tensor(g)),Variable(torch.Tensor(img)),Variable(torch.Tensor(z))
                    rd_label=Variable(torch.LongTensor(np.zeros((self.batch_size))),requires_grad=False)
#                     rd_label.requires_grad=False
                    fd_label=Variable(torch.LongTensor(np.ones((self.batch_size))),requires_grad=False)
#                     fd_label.requires_grad=False

                    if self.use_gpu:
                        g,img,z,rd_label,fd_label,self.convert2Cuda([g,img,z,rd_label,fd_label])
                    
                    x=self.G(g,z)
                    r_d_out=self.D(g,img)
                    f_d_out=self.D(g,x)
                    
                    
                    real_d_loss=self.rD_LOSS(r_d_out,rd_label)
                    fake_d_loss=self.fD_LOSS(f_d_out,fd_label)
                    d_loss=real_d_loss+fake_d_loss
                    g_loss=self.G_LOSS(f_d_out,rd_label)

                    if i%50==0:
                        ind=np.random.randint(0,self.batch_size)
                        print("{}  G_loss: {}, D_loss:{}".format(phase,g_loss.data[0],d_loss.data[0]))
#                         display(g,img,x,meta=' ')
                        misc.imsave(self.SAVE_DIR+'{}_{}_{}_img.png'.format(e,phase,i),np.array(img[ind].data).reshape((56,56)))
                        misc.imsave(self.SAVE_DIR+'{}_{}_{}_x.png'.format(e,phase,i),np.array(x[ind].data).reshape((56,56)))
                    G_losses.append(g_loss)
                    D_losses.append(d_loss)
                    if phase=='train':
                        g_loss.backward(retain_graph=True)
                        G_optimizer.step()
                        d_loss.backward()
                        D_optimizer.step()
                
            if e%50==9:
                self.saveCheckpoint('{}'.format(e))
    def saveCheckpoint(self,e):
        self.D.save_state_dict('D_temp{}.pth.tar'.format(e))
        self.G.save_state_dict('G_temp{}.pth.tar'.format(e))

    def loadCheckpoint(self,e):
        self.D.load_state_dict(torch.load('D_temp{}.pth.tar'.format(e)))
        self.G.load_state_dict(torch.load('G_temp{}.pth.tar'.format(e)))
        

In [11]:
# dg=doodleGenerator()
# data=dg.createDataSet()
data=[np.load('G100000.npy'),np.load('I100000.npy')]
dcgan=cDCGAN()
dcgan.feedData(data)
dcgan.loadCheckpoint('')
dcgan.trainNetwork()

100000 records, 1000 batches of size 100 each
(10, 10, 4) (10, 10, 4)
Epoch 1/200
----------
Current Phase:train




train  G_loss: 0.657154381275177, D_loss:1.0658833980560303


KeyboardInterrupt: 

In [461]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
        

In [1]:
import torch

False