In [1]:
import tensorflow as tf
import numpy as np
from matplotlib import pylab as plt
import cv2
from time import time
from time import sleep
from progressbar import ProgressBar
from mpl_toolkits.axes_grid1 import ImageGrid
import os,pickle
from ipywidgets import ToggleButtons
from ipywidgets import Button
from ipywidgets import interact
from IPython import display
from IPython.display import clear_output

In [2]:
%matplotlib inline
"""画像の読み込み"""
# Load the training images from dataset/padded into one float32 array.
files = os.listdir('dataset/padded')
# '.DS_Store' is a macOS metadata file, not an image. Remove it only when it
# is actually listed instead of hiding every possible error behind a bare except.
if '.DS_Store' in files:
    files.remove('.DS_Store')
pad = 5
data = []
for name in files[:256]:
    d = cv2.imread("dataset/padded/"+name)
    if d is None:
        # cv2.imread reports an unreadable / non-image file by returning None
        # rather than raising; without this guard the crop below would crash.
        print("skipping unreadable image:", name)
        continue
    # Trim a `pad`-pixel border (the literal 88 assumes 88x88 inputs — TODO confirm),
    # then scale to the 64x64 size used by the discriminator input.
    d = d[pad:88-pad,pad:88-pad]
    d = cv2.resize(d,(64,64))
    # OpenCV loads images as BGR; matplotlib expects RGB.
    d = cv2.cvtColor(d, cv2.COLOR_BGR2RGB)
    data.append(d)
data = np.array(data, dtype=np.float32)
# Rescale so that values are centred on zero, in line with the generator's tanh output.
data = data / 128.0 - 1

In [None]:
"""ネットワークのセット"""
# Layer sizes shared by the generator and the discriminator.
nz = 100      # length of the noise vector fed to the generator
n0 = 256      # channels right after the generator's dense + reshape stage
n1 = 128      # generator deconv0 output channels / discriminator conv4 output channels
n2 = 64       # generator deconv1 output channels / discriminator conv3 output channels
n3 = 32       # generator deconv2 output channels / discriminator conv2 output channels
n4 = 16       # discriminator conv1 output channels
nc = 3        # image channels (generator output, discriminator input)
f_size = 4    # spatial side of the generator's first feature map

# The rest of the notebook calls .eval() / .run() without naming a session,
# so it relies on this session being the default one.
sess = tf.InteractiveSession()

def leaky(x):
    """Leaky ReLU with slope 0.2: the element-wise maximum of ``0.2 * x`` and ``x``."""
    damped = 0.2 * x
    return tf.maximum(damped, x)

def regularizer(x):
    """Weight-decay term for a kernel: ``tf.nn.l2_loss(x)`` scaled by 1e-4.

    Passed as ``kernel_regularizer`` to the layers of both networks.
    """
    penalty = tf.nn.l2_loss(x)
    return 0.0001 * penalty

# Generator input: a batch of noise vectors of length nz (batch size left open).
G_feed = tf.placeholder(tf.float32, shape=[None, nz])

class Generator:
    """Transposed-convolution generator: noise vector -> image with values in [-1, 1].

    All variables live under the variable scope 'g'. The first call creates
    them; every later call reuses them (tracked by ``self.reuse``).
    After a call, ``self.variables`` holds the trainable variables of scope 'g'.
    """
    def __init__(self):
        # False until the first call has created the variables.
        self.reuse = False

    def __call__(self, G_feed ,mid=False):
        """Build the generator ops on top of ``G_feed``.

        Args:
            G_feed: tensor or array of noise vectors (converted with
                tf.convert_to_tensor); second dimension is expected to be nz.
            mid: when True, return tanh of the last hidden feature map
                (n3 channels) instead of the final image.

        Returns:
            The tanh-activated image tensor, or the tanh-activated hidden
            feature map when ``mid`` is True.
        """
        G_feed = tf.convert_to_tensor(G_feed)
        with tf.variable_scope('g', reuse=self.reuse):
            # Project the noise vector and reshape it to a f_size x f_size map.
            out = tf.layers.dense(G_feed, n0*f_size*f_size, kernel_regularizer=regularizer)
            out = tf.reshape(out, [-1, f_size, f_size, n0])
            out = tf.layers.batch_normalization(out,training=True)
            out = tf.nn.relu(out)
            # deconv0..deconv2: each stride-2 layer doubles the spatial size.
            # The layers are created in the same order as before, so the
            # automatically assigned variable names are unchanged.
            for channels in (n1, n2, n3):
                out = tf.layers.conv2d_transpose(out, channels, [5, 5], strides=(2, 2), padding='SAME', kernel_regularizer=regularizer)
                out = tf.nn.relu(tf.layers.batch_normalization(out, training=True))
            hidden = out
            # deconv3 is always built, even when only the hidden map is wanted.
            # Previously the mid=True path returned before this point and before
            # the bookkeeping below, so a first call through midout() left the
            # final layer uncreated and self.reuse False, breaking later calls.
            out = tf.layers.conv2d_transpose(hidden, nc, [5, 5], strides=(2, 2), padding='SAME', kernel_regularizer=regularizer)
            out = tf.tanh(out)
            if mid:
                out = tf.tanh(hidden)
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='g')
        self.reuse = True
        return(out)

    def midout(self, G_feed):
        """Return the tanh-activated hidden feature map (shortcut for ``mid=True``)."""
        return(self(G_feed, mid=True))

# Discriminator input: a batch of 64x64 images with nc channels (batch size left open).
D_feed = tf.placeholder(tf.float32, shape=[None, 64, 64, nc])

class Discriminator:
    """Convolutional classifier producing two logits per input image.

    Variables are created under scope 'd' on the first call and reused on
    every later call. After a call, ``self.variables`` lists the trainable
    variables of that scope.
    """
    def __init__(self):
        # Flipped to True once the variables exist.
        self.reuse = False

    def __call__(self, D_feed):
        """Build the discriminator ops on ``D_feed`` and return the logits tensor."""
        images = tf.convert_to_tensor(D_feed)
        with tf.variable_scope('d', reuse=self.reuse):
            hidden = images
            # conv1..conv4: stride-2 convolutions, each followed by batch norm
            # and the leaky activation.
            for width in (n4, n3, n2, n1):
                hidden = tf.layers.conv2d(hidden, width, [5, 5], strides=(2, 2), padding='SAME', kernel_regularizer=regularizer)
                hidden = leaky(tf.layers.batch_normalization(hidden, training=True))
            # Fully connected head on the flattened 4x4xn1 feature map.
            flat = tf.reshape(hidden, [-1,4*4*n1])
            logits = tf.layers.dense(flat, 2, kernel_regularizer=regularizer)
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='d')
        self.reuse = True
        return(logits)


In [None]:
"""誤差関数のセット"""
"""prediction: G->0, T->1"""
gen = Generator()
dis = Discriminator()

# Build the generator -> discriminator path once and share it between both
# losses. Calling dis(gen(G_feed)) separately for each loss added a second,
# identical copy of the whole subgraph (same variables, same input).
fake_logits = dis(gen(G_feed))
real_logits = dis(D_feed)

# NOTE(review): the layers register kernel_regularizer terms, but neither loss
# below adds them, so the weight-decay penalty is not part of the objective —
# confirm whether that was intended.

# Generator loss: cross-entropy of the discriminator's verdict on generated
# images against the labels fed through G_y.
G_y = tf.placeholder(tf.int32, shape=None)
G_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=G_y,
        logits=fake_logits))

# Discriminator loss, part 1: generated images against the labels in D_yF.
D_yF = tf.placeholder(tf.int32, shape=None)
D_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=D_yF,
        logits=fake_logits))

# Discriminator loss, part 2: training images against the labels in D_yT.
D_yT = tf.placeholder(tf.int32, shape=None)
D_loss += tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=D_yT,
        logits=real_logits))

In [None]:
"""AdamOptimizerのセット"""
# One Adam optimizer per network, both with the same hyper-parameters.
# Each one updates only its own network's variables.
adam_settings = dict(learning_rate=0.0002, beta1=0.5)
G_optimizer = tf.train.AdamOptimizer(**adam_settings).minimize(
    G_loss, var_list=gen.variables)
D_optimizer = tf.train.AdamOptimizer(**adam_settings).minimize(
    D_loss, var_list=dis.variables)
# Initialise after the optimizers exist so their slot variables are covered too.
init_op = tf.global_variables_initializer()
sess.run(init_op)
# Epoch counter.
e = 1700

In [None]:
class midresult:
    """Render and save a grid of generator outputs for a fixed set of latent vectors.

    The latent vectors are read once from ``testgen.pickle`` so that the
    snapshots taken at different points of training use the same inputs.
    """
    def __init__(self):
        # NOTE(security): pickle.load can execute arbitrary code; only load a
        # testgen.pickle that this notebook wrote itself.
        with open("testgen.pickle", "rb") as f:
            self.testgen = pickle.load(f)
        # Generator output tensor for self.testgen, built on the first call.
        # Reset to None if self.testgen is ever replaced.
        self._images = None

    def __call__(self,n):
        """Generate the image grid, save it as res/<n>_cg.png and display it.

        Args:
            n: integer tag; zero-padded to six digits in the file name.
        """
        # Build the ops once. Calling gen(...) on every snapshot kept adding
        # new nodes to the graph.
        if self._images is None:
            self._images = gen(self.testgen)
        im = self._images.eval()
        # A single figure: the previous extra plt.figure() call produced an
        # additional empty figure each time.
        fig = plt.figure(figsize=(10,10), dpi=100)
        grid = ImageGrid(fig, 111, nrows_ncols = (8, batch//8))
        for j in range(batch):
            # Booleans, not 'off' strings, which current matplotlib does not honour.
            grid[j].tick_params(labelbottom=False, labelleft=False)
            # Map the tanh range [-1, 1] to [0, 1] for imshow.
            im[j] += 1
            im[j] /= 2
            grid[j].imshow(im[j])
        # savefig does not create missing directories.
        os.makedirs("res", exist_ok=True)
        fig.savefig("res/{0:06d}_cg.png".format(n))
        plt.show()
        plt.close(fig)
writer = midresult()

In [None]:
"""学習"""
batch = 64
e=0
epoch = 100
dsize = int(data.shape[0])
p = ProgressBar(max_value=int(epoch)*int(dsize/batch))
b = 0
# The label vectors never change, so build them once. int32 matches the
# label placeholders.
ones = np.ones(batch, dtype=np.int32)
zeros = np.zeros(batch, dtype=np.int32)
for _ in range(epoch):
    e += 1
    np.random.shuffle(data)
    # "+ 1" so that the last full batch is included: with the old stop value
    # of dsize - batch, a dataset of 256 images and batch 64 trained on only
    # 3 of its 4 batches per epoch, and the progress bar never reached its max.
    for i in range(0, dsize - batch + 1, batch):
        b += 1
        p.update(b)
        # Noise drawn uniformly from [-1, 1).
        G_f = np.random.rand(batch, nz).astype(np.float32)*2 - 1
        D_f = data[i:i+batch]
        # TODO: the generator is evaluated twice per step (once per optimizer),
        # which is inefficient; feeding its output through a placeholder placed
        # after G would avoid that.
        G_optimizer.run(feed_dict={G_feed:G_f,G_y:ones})
        D_optimizer.run(feed_dict={G_feed:G_f,D_feed:D_f,D_yF:zeros,D_yT:ones})
    if e % 10 == 0:
        # Report the losses on the last batch of the epoch and save a snapshot.
        print("epoch",e,"of",epoch,":")
        print(" G_loss:",G_loss.eval(feed_dict={G_feed:G_f, G_y:ones}))
        print(" D_loss:",D_loss.eval(feed_dict={G_feed:G_f, D_feed:D_f, D_yT:ones, D_yF:zeros}))
        writer(e)



In [None]:
# Render a snapshot grid by hand; the tag 5000 yields the file res/005000_cg.png.
writer(5000)

In [None]:
"""変数ロード"""
# Restore every saved variable into the current session from the checkpoint.
checkpoint_path = "cg_model_4_1450.tfv"
saver = tf.train.Saver()
saver.restore(sess, checkpoint_path)
# Reload the `features` object written by the save cell.
# NOTE(security): pickle.load can execute arbitrary code; only open files
# produced by this notebook.
with open("features", "rb") as feature_file:
    features = pickle.load(feature_file)

In [None]:
"""変数セーブ"""
# Write the session's variables to a checkpoint.
checkpoint_path = "cg_model_4_2000.tfv"
saver = tf.train.Saver()
saver.save(sess, checkpoint_path)
# Persist the fixed snapshot latents used by `writer` ...
with open("testgen.pickle", "wb") as testgen_file:
    pickle.dump(writer.testgen, testgen_file)
# ... and the `features` object.
with open("features", "wb") as feature_file:
    pickle.dump(features, feature_file)
print("Done")



In [None]:
"""途中層"""
# For five random latents, show the generated image next to the n3 channels
# of the generator's last hidden feature map.
G_f = np.random.rand(batch, nz).astype(np.float32) -.5
im = gen.midout(G_f).eval()
real = (gen(G_f).eval()+1)/2
for i in range(5):
    fig = plt.figure(figsize=(10,10), dpi=100)
    grid = ImageGrid(fig, 111, nrows_ncols = (5, 8))
    # Booleans, not 'off' strings, which current matplotlib does not honour.
    grid[0].tick_params(labelbottom=False, labelleft=False)
    # First row: the final image, shrunk to the feature-map size for comparison.
    grid[0].imshow(cv2.resize(real[i],(32,32)))
    # Rows 2-5 (axes 8 onward): one grayscale panel per hidden channel.
    for j in range(n3):
        grid[8+j].tick_params(labelbottom=False, labelleft=False)
        im[i,:,:,j] += 1
        im[i,:,:,j] /= 2
        grid[8+j].imshow(im[i,:,:,j],"gray")
    plt.show()
    # Close each figure as soon as it is shown; closing once after the loop
    # only released the last one.
    plt.close(fig)

In [None]:
"""ええやつだけ出す"""
# Generate one batch from fresh noise and show it on an 8x16 grid.
# Only the first `batch` axes receive an image.
fig = plt.figure(figsize=(10,10), dpi=100)
grid = ImageGrid(fig, 111, nrows_ncols=(8, 16))
G_f = np.random.rand(batch, nz).astype(np.float32)-.5
# Map the generator's [-1, 1] output to [0, 1] for display.
ims = gen(G_f).eval()/2 +.5
for slot, picture in enumerate(ims):
    grid[slot].imshow(picture)
plt.show()
plt.close()

In [None]:
"""人力選別"""
# Shared state for the manual GOOD / BAD selection.
good, goodz = [], []      # accepted images and the latent vectors behind them
G_f = None                # pending latent vectors (filled by foo)
a = Button(description="GOOD")
b = Button(description="BAD")
for button in (a, b):
    display.display(button)
ims = []                  # pending images (filled by foo)
im = None                 # image currently on screen
z = None                  # latent vector of the image currently on screen

def foo():
    """Show the next candidate image.

    Clears the cell output, generates a new batch when the pending list is
    empty, pops one image and its latent vector into the globals ``im`` and
    ``z``, displays the image and prints how many images have been accepted.
    """
    global good,ims,im,goodz,z,G_f
    clear_output()
    if not ims:
        # Nothing left to review: generate a fresh batch and keep the latent
        # vectors in the same order as the images.
        G_f = np.random.rand(batch, nz).astype(np.float32) - 0.5
        generated = gen(G_f).eval()
        ims = list((generated+1)/2)
        G_f = list(G_f)
    im = ims.pop()
    z = G_f.pop()
    plt.imshow(im)
    plt.show()
    plt.close()
    accepted = len(good)
    print(accepted)
    
def a_c(a):
    """GOOD-button callback: keep the displayed image and its latent, then move on."""
    # Appending mutates the module-level lists in place, so no `global`
    # declaration is needed here.
    good.append(im)
    goodz.append(z)
    foo()
    
def b_c(b):
    """BAD-button callback: discard the displayed image and move on to the next one."""
    foo()
    
# Show the first candidate, then attach the click handlers to the two buttons.
foo()
for button, handler in ((a, a_c), (b, b_c)):
    button.on_click(handler)

In [None]:
# Generate a batch from noise in [-0.75, 0.75) and show every image with its
# row index in G_f as the title, so that good ones can be picked by number.
centered_noise = np.random.rand(batch, nz).astype(np.float32) - .5
G_f = centered_noise*1.5
images = gen(G_f).eval()
fig = plt.figure(figsize=(10,10), dpi=100)
grid = ImageGrid(fig, 111, nrows_ncols=(10, batch//9))
for index in range(batch):
    panel = grid[index]
    panel.imshow((images[index]+1)/2)
    panel.set_title(str(index))
plt.show()
plt.close()

In [None]:
"""↑で見つけたいいやつの保存"""
# Append hand-picked rows of G_f to the stored "bright_hair" latents.
# First-time initialisation looked like:
#     features["bright_hair"] = G_f[[0,24,25,73]]
# Put the chosen row numbers in `picked_rows`; an empty list appends nothing.
picked_rows = []
features["bright_hair"] = np.vstack((features["bright_hair"], G_f[picked_rows]))
print(features["bright_hair"].shape)

In [None]:
"""check"""
# Regenerate the images for the stored "bright_hair" latents as a sanity check.
seed = features["bright_hair"]
# Pad with zero vectors up to a full batch; only the first len(seed) results
# are displayed.
G_f = np.zeros((batch,nz),dtype=np.float32)
G_f[:len(seed)] = seed
chkims = gen(G_f).eval()
fig = plt.figure(figsize=(10,10), dpi=100)
grid = ImageGrid(fig, 111, nrows_ncols=(10, batch//9))
for index in range(len(seed)):
    grid[index].imshow((chkims[index]+1)/2)
plt.show()
plt.close()

In [None]:
# Drop rows from the stored "bright_hair" latents: list the row numbers to
# remove in the second argument (an empty list removes nothing).
features["bright_hair"] = np.delete(features["bright_hair"],[],axis=0)

In [None]:
"""比較"""
# Attribute direction: the mean of the stored "bright_hair" latent vectors.
vec = features["bright_hair"].mean(axis=0)
# Fresh noise in [-0.5, 0.5); rows 12..23 act as the base latents below.
G_f = np.random.rand(batch, nz).astype(np.float32) - .5
def norm(x):
    """Euclidean length of each row of a 2-D array, returned as a column vector."""
    squared_sum = np.square(x).sum(axis=1)
    return np.reshape(np.sqrt(squared_sum), (-1, 1))
# Rows 12..23 are the base latents. Rows 0..11 get the attribute vector
# subtracted and rows 24..35 get it added; both are rescaled to the length
# of their base latent.
base = G_f[12:12*2]
base_length = norm(base)
G_f[:12] = base - vec
G_f[:12] *= base_length/norm(G_f[:12])
G_f[12*2:12*3] = base + vec
G_f[12*2:12*3] *= base_length/norm(G_f[12*2:12*3])
fig = plt.figure(figsize=(10,3), dpi=100)
grid = ImageGrid(fig, 111, nrows_ncols=(3,12), axes_pad=0)
# Rows of the grid: minus / base / plus.
images = gen(G_f).eval()/2 + 0.5
for index in range(12*3):
    grid[index].imshow(images[index])
plt.show()
plt.close()

In [None]:
"ええやつ"
# Show up to 64 of the images accepted during manual selection.
fig = plt.figure(figsize=(10,10), dpi=100)
grid = ImageGrid(fig, 111, nrows_ncols=(8, 8))
for slot, picture in enumerate(good[:64]):
    grid[slot].imshow(picture)
plt.show()
plt.close()

In [None]:
"""aからbへの変遷"""
# End points of the interpolation: the first two accepted latent vectors.
# NOTE: this rebinds `a` and `b`, which held the GOOD / BAD buttons.
a = goodz[0]
b = goodz[1]
# Number of interpolation steps, and one grid column per step.
div = 16
fig = plt.figure(figsize=(div,1), dpi=60)
grid = ImageGrid(fig, 111, nrows_ncols=(1,div), axes_pad=0)
def indiv(a,b,n):
    """Linearly interpolate from ``a`` to ``b`` in ``n`` evenly spaced steps.

    Args:
        a: start value (scalar or array).
        b: end value, same shape as ``a``.
        n: number of points to return, both end points included.

    Returns:
        A list of exactly ``n`` values ``a*(1-t) + b*t`` with ``t`` running
        from 0 to 1.
    """
    # np.linspace always yields exactly n weights with exact end points.
    # np.arange with a fractional step and a stop of 1 + step can produce
    # n + 1 elements because of floating-point rounding, and it divides by
    # zero for n == 1.
    return([a*(1-t) + b*t for t in np.linspace(0.0, 1.0, n)])
# Start from a full batch of noise in [-1, 1) and overwrite its first `div`
# rows with the interpolated latents; only those rows are displayed.
G_f = np.random.rand(batch, nz).astype(np.float32)*2 - 1
G_f[:div] = indiv(a,b,div)
images = gen(G_f).eval()/2 + 0.5
for step in range(div):
    grid[step].imshow(images[step])
plt.show()

In [None]:
# Four accepted latent vectors used as the corners of the interpolation square:
# a -> b along the top edge, c -> d along the bottom edge.
a = goodz[0]
b = goodz[1]
c = goodz[2]
d = goodz[6]
# Points per side; the grid holds div x div images.
div = 11
fig = plt.figure(figsize=(div,div), dpi=60)
grid = ImageGrid(fig, 111, nrows_ncols=(div,div), axes_pad=0)
def indiv(a,b,n):
    """Linearly interpolate from ``a`` to ``b`` in ``n`` evenly spaced steps.

    Args:
        a: start value (scalar or array).
        b: end value, same shape as ``a``.
        n: number of points to return, both end points included.

    Returns:
        A list of exactly ``n`` values ``a*(1-t) + b*t`` with ``t`` running
        from 0 to 1.
    """
    # np.linspace replaces the np.arange call whose stop value needed a
    # 1e-10 fudge to get the element count right. It returns exactly n
    # weights and the last weight is exactly 1, so the final point equals b.
    return([a*(1-t) + b*t for t in np.linspace(0.0, 1.0, n)])
# Fill a div x div square of latent vectors by interpolating between the
# four corners: first along the top (a -> b) and bottom (c -> d) edges, then
# down every column between the two edges.
square = np.zeros((div,div,nz),dtype=np.float32)
upper = indiv(a,b,div)
lower = indiv(c,d,div)
for i in range(div):
    # Column i runs from upper[i] (row 0) to lower[i] (last row), so the
    # edge rows need no separate assignment.
    square[:,i] = indiv(upper[i],lower[i],div)
# Feed the div*div latents directly. Copying them into a noise array of only
# `batch` rows failed whenever div**2 exceeded batch (121 rows into 64), and
# the reshape hard-coded 100 instead of nz.
G_f = square.reshape(-1, nz)
images = gen(G_f).eval()/2 + 0.5
for i in range(div**2):
    grid[i].imshow(images[i])
plt.show()


In [None]:
"""ガウス分布による発生"""
# Generate one batch and show it on an 8-row grid.
# NOTE(review): the cell heading says "Gaussian", but np.random.rand draws
# uniform noise, here shifted to [-0.5, 0.5) — confirm which was intended.
G_f = np.random.rand(batch, nz).astype(np.float32) - .5
images = gen(G_f).eval()
fig = plt.figure(figsize=(10,10), dpi=100)
grid = ImageGrid(fig, 111, nrows_ncols=(8, batch//8))
for index in range(batch):
    # Map the generator's [-1, 1] output to [0, 1] for display.
    grid[index].imshow((images[index]+1)/2)
plt.show()
plt.close()