Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
2 contributors

Users who have contributed to this file

@bojone @repletetop
251 lines (201 sloc) 6.86 KB
#! -*- coding: utf-8 -*-
# Keras implementation of NICE (Non-linear Independent Components Estimation)
# https://arxiv.org/abs/1410.8516
import imageio
import numpy as np
from keras.layers import *
from keras.models import Model
from keras.datasets import mnist
from keras import backend as K
from keras.callbacks import ModelCheckpoint
# Load MNIST, flatten every image into one vector and divide the pixel
# values by 255.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
image_size = x_train.shape[1]
original_dim = image_size * image_size
x_train = x_train.reshape(-1, original_dim).astype('float32') / 255
x_test = x_test.reshape(-1, original_dim).astype('float32') / 255
class Shuffle(Layer):
    """Permutation layer that reorders the last axis of its input.

    When no explicit index list is supplied, the permutation is chosen on
    the first call: ``'reverse'`` (the default) flips the feature order and
    ``'random'`` shuffles it with ``np.random.shuffle``. Any other ``mode``
    leaves the order unchanged.

    Args:
        idxs: Optional explicit permutation (sequence of ints). When given,
            ``mode`` is not consulted.
        mode: ``'reverse'`` or ``'random'``; only used when ``idxs`` is
            ``None``.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, idxs=None, mode='reverse', **kwargs):
        super(Shuffle, self).__init__(**kwargs)
        self.idxs = idxs
        self.mode = mode

    def call(self, inputs):
        """Return ``inputs`` with its last axis permuted by ``self.idxs``."""
        v_dim = K.int_shape(inputs)[-1]
        # Identity check, not `== None`: comparing a NumPy index array with
        # `==` is elementwise and cannot be used as a condition.
        if self.idxs is None:
            idxs = list(range(v_dim))
            if self.mode == 'reverse':
                idxs = idxs[::-1]
            elif self.mode == 'random':
                np.random.shuffle(idxs)
            # Stored on the layer so that inverse() can undo the same
            # permutation later.
            self.idxs = idxs
        # K.gather selects along the first axis, so bring the feature axis
        # to the front, pick the entries, then transpose back.
        outputs = K.gather(K.transpose(inputs), self.idxs)
        return K.transpose(outputs)

    def inverse(self):
        """Return a new ``Shuffle`` layer that undoes this permutation.

        Raises:
            RuntimeError: If no permutation is known yet, i.e. ``idxs`` was
                not given and the layer has not been called.
        """
        if self.idxs is None:
            raise RuntimeError(
                'Shuffle.inverse() requires the permutation to be known; '
                'call the layer on an input first or pass `idxs` explicitly.'
            )
        # Positions ordered by the index they hold form the inverse
        # permutation (an argsort of self.idxs).
        reverse_idxs = sorted(range(len(self.idxs)),
                              key=lambda pos: self.idxs[pos])
        return Shuffle(reverse_idxs)
class SplitVector(Layer):
    """Split the input into two parts by interleaving.

    Neighbouring features are paired up; the first element of every pair
    goes to the first output and the second element to the second output.
    """

    def __init__(self, **kwargs):
        super(SplitVector, self).__init__(**kwargs)

    def call(self, inputs):
        """Return the two interleaved halves of ``inputs`` as a list."""
        half_dim = K.int_shape(inputs)[-1] // 2
        # Group consecutive features in pairs: (batch, half_dim, 2).
        paired = K.reshape(inputs, (-1, half_dim, 2))
        first, second = paired[:, :, 0], paired[:, :, 1]
        return [first, second]

    def compute_output_shape(self, input_shape):
        """Both outputs have half of the input's last dimension."""
        half_shape = (None, input_shape[-1] // 2)
        return [half_shape, half_shape]

    def inverse(self):
        """Return a ``ConcatVector`` layer, which merges the two parts."""
        return ConcatVector()
class ConcatVector(Layer):
    """Merge the two split parts back into a single vector."""

    def __init__(self, **kwargs):
        super(ConcatVector, self).__init__(**kwargs)

    def call(self, inputs):
        """Interleave the tensors in ``inputs`` into one flat vector."""
        # Give every part a new axis 2 and join along it, so that the
        # flattening below interleaves the parts element by element.
        stacked = K.concatenate(
            [K.expand_dims(part, 2) for part in inputs], 2)
        flat_dim = np.prod(K.int_shape(stacked)[1:])
        return K.reshape(stacked, (-1, flat_dim))

    def compute_output_shape(self, input_shape):
        """The output width is the sum of the parts' last dimensions."""
        total_dim = sum(shape[-1] for shape in input_shape)
        return (None, total_dim)

    def inverse(self):
        """Return a ``SplitVector`` layer, which splits the vector again."""
        return SplitVector()
class AddCouple(Layer):
    """Additive coupling layer.

    Takes ``[part1, part2, mpart1]`` and returns ``[part1, part2 - mpart1]``
    in the forward direction, or ``[part1, part2 + mpart1]`` when
    ``isinverse`` is true.
    """

    def __init__(self, isinverse=False, **kwargs):
        self.isinverse = isinverse
        super(AddCouple, self).__init__(**kwargs)

    def call(self, inputs):
        """Shift ``part2`` by ``mpart1`` and pass ``part1`` through."""
        part1, part2, mpart1 = inputs
        # The forward direction subtracts; the inverse adds it back.
        shifted = part2 + mpart1 if self.isinverse else part2 - mpart1
        return [part1, shifted]

    def compute_output_shape(self, input_shape):
        """The outputs keep the shapes of the first two inputs."""
        first_shape, second_shape = input_shape[0], input_shape[1]
        return [first_shape, second_shape]

    def inverse(self):
        """Return an ``AddCouple`` layer working in the inverse direction."""
        return AddCouple(True)
class Scale(Layer):
    """Learnable element-wise scaling layer.

    Multiplies the input by ``exp(kernel)`` and registers ``-sum(kernel)``
    as an additional loss (the log-determinant term).
    """

    def __init__(self, **kwargs):
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create one trainable log-scale per feature."""
        kernel_shape = (1, input_shape[1])
        self.kernel = self.add_weight(
            name='kernel',
            shape=kernel_shape,
            initializer='glorot_normal',
            trainable=True,
        )

    def call(self, inputs):
        """Scale ``inputs`` and add the log-determinant loss term."""
        log_det = K.sum(self.kernel)  # log-determinant
        self.add_loss(-log_det)
        return K.exp(self.kernel) * inputs

    def inverse(self):
        """Return a ``Lambda`` layer that multiplies by ``exp(-kernel)``."""
        inv_scale = K.exp(-self.kernel)
        return Lambda(lambda x: inv_scale * x)
def build_basic_model(v_dim, n_layers=5, units=1000):
    """Build the basic model, i.e. the function ``m`` of a coupling layer.

    The network is a stack of ``n_layers`` ReLU ``Dense`` layers of width
    ``units`` followed by a ReLU ``Dense`` layer of width ``v_dim``.

    Args:
        v_dim: Size of both the input and the output vector.
        n_layers: Number of hidden ``Dense`` layers (default 5).
        units: Width of every hidden ``Dense`` layer (default 1000).

    Returns:
        A Keras ``Model`` mapping a ``(v_dim,)`` input to a ``(v_dim,)``
        output.
    """
    net_in = Input(shape=(v_dim,))
    hidden = net_in
    for _ in range(n_layers):
        hidden = Dense(units, activation='relu')(hidden)
    net_out = Dense(v_dim, activation='relu')(hidden)
    return Model(net_in, net_out)
# One permutation layer per coupling stage; the split / couple / concat
# layers are shared between all stages.
shuffle1 = Shuffle()
shuffle2 = Shuffle()
shuffle3 = Shuffle()
shuffle4 = Shuffle()
split = SplitVector()
couple = AddCouple()
concat = ConcatVector()
scale = Scale()

# One coupling network per stage, each acting on half of the features.
basic_model_1 = build_basic_model(original_dim//2)
basic_model_2 = build_basic_model(original_dim//2)
basic_model_3 = build_basic_model(original_dim//2)
basic_model_4 = build_basic_model(original_dim//2)

x_in = Input(shape=(original_dim,))
# Add negative noise to the input (in the training phase only).
x = Lambda(
    lambda s: K.in_train_phase(s - 0.01 * K.random_uniform(K.shape(s)), s)
)(x_in)

# Each stage: permute, split, shift the second part by m(first part), merge.
encoder_stages = [
    (shuffle1, basic_model_1),
    (shuffle2, basic_model_2),
    (shuffle3, basic_model_3),
    (shuffle4, basic_model_4),
]
for stage_shuffle, stage_net in encoder_stages:
    x = stage_shuffle(x)
    x1, x2 = split(x)
    mx1 = stage_net(x1)
    x1, x2 = couple([x1, x2, mx1])
    x = concat([x1, x2])

x = scale(x)
encoder = Model(x_in, x)
encoder.summary()
# The targets are ignored: the loss is 0.5 * ||z||^2 summed over the feature
# axis of the encoder output (the Scale layer contributes its own term via
# add_loss).
encoder.compile(loss=lambda y_true, y_pred: K.sum(0.5 * y_pred**2, 1),
                optimizer='adam')

# Only the weights are reloaded below, so store weights only; this avoids
# serializing the full model (custom layers, closure-based Lambda layers).
checkpoint = ModelCheckpoint(filepath='./best_encoder.weights',
                             monitor='val_loss',
                             verbose=1,
                             save_best_only=True,
                             save_weights_only=True)

encoder.fit(x_train,
            x_train,
            batch_size=128,
            epochs=30,
            validation_data=(x_test, x_test),
            callbacks=[checkpoint])

# Restore the weights of the epoch with the best validation loss.
encoder.load_weights('./best_encoder.weights')
# Build the inverse (generative) model by running every operation backwards.
x = scale.inverse()(x_in)

# Undo the coupling stages from the last one to the first one.
decoder_stages = [
    (shuffle4, basic_model_4),
    (shuffle3, basic_model_3),
    (shuffle2, basic_model_2),
    (shuffle1, basic_model_1),
]
for shuffle_layer, coupling_net in decoder_stages:
    x1, x2 = concat.inverse()(x)
    mx1 = coupling_net(x1)
    x1, x2 = couple.inverse()([x1, x2, mx1])
    x = split.inverse()([x1, x2])
    x = shuffle_layer.inverse()(x)

decoder = Model(x_in, x)
# Sample from the latent space and look at the generated digits.
n = 15
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))

# Draw all n * n latent vectors at once and decode them in a single predict
# call; a standard deviation of 0.75 is used instead of 1.
z_sample = np.random.randn(n * n, original_dim) * 0.75
x_decoded = decoder.predict(z_sample)

# Tile the decoded digits row by row into an n x n grid.
for i in range(n):
    for j in range(n):
        digit = x_decoded[i * n + j].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

# Convert to 8-bit pixels explicitly so the writer does not have to guess
# how to map a float64 array onto an image.
figure = np.rint(np.clip(figure * 255, 0, 255)).astype('uint8')
imageio.imwrite('test.png', figure)
You can’t perform that action at this time.