In [None]:
# coding: utf-8
import tensorflow as tf
from tensorflow import keras

from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.models import Model
import numpy as np
import os

# Side length (pixels) that images are resized to before feature extraction.
img_size = 256

# Uncomment to hide every GPU and force CPU-only execution.
#os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Ask TensorFlow to grow GPU memory on demand rather than reserving it all at start-up.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)
 

---
## 2. Build the transfer model

In [None]:

# Per-channel color means (these values appear to be the ImageNet RGB means), shaped
# (1, 1, 1, 3) so they broadcast over a channels-last image batch.
MEAN_VALUES = np.reshape(np.array([123.68, 116.779, 103.939]), (1, 1, 1, 3))

# Kernel initializer shared by every conv layer: truncated normal, mean 0, stddev 0.1.
k_initializer = keras.initializers.TruncatedNormal(0, 0.1)

# L2 penalty coefficient, applied separately to conv kernels and conv biases.
variation_weight = 0.01
kernel_regularizer = keras.regularizers.l2(variation_weight)
bias_regularizer = keras.regularizers.l2(variation_weight)
def relu(X):
    """Pass X through a new Keras ReLU activation layer and return the result."""
    activation_layer = keras.layers.Activation('relu')
    return activation_layer(X)

class InstanceNormalization(keras.layers.Layer):
    """Instance normalization with a learnable per-channel scale and offset.

    For every sample and every channel, activations of a channels-last
    (batch, height, width, channels) tensor are normalized to zero mean and
    unit variance over the spatial axes (1, 2). They are then scaled by
    `gamma` and shifted by `beta`, both of shape (channels,).
    """

    def __init__(self, epsilon=1e-3, **kwargs):
        super().__init__(**kwargs)
        # Added to the variance to avoid division by zero; 1e-3 is the same
        # default that keras.layers.LayerNormalization uses.
        self.epsilon = epsilon

    def build(self, input_shape):
        channels = input_shape[-1]
        self.gamma = self.add_weight(
            name="gamma", shape=(channels,), initializer="ones", trainable=True)
        self.beta = self.add_weight(
            name="beta", shape=(channels,), initializer="zeros", trainable=True)
        super().build(input_shape)

    def call(self, inputs):
        # Statistics per (sample, channel), taken over height and width only.
        mean, variance = tf.nn.moments(inputs, axes=[1, 2], keepdims=True)
        normalized = (inputs - mean) * tf.math.rsqrt(variance + self.epsilon)
        return normalized * self.gamma + self.beta

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config


def instance_norm(X):
    """Apply instance normalization to a channels-last feature map X.

    keras.layers.LayerNormalization() with its default axis normalizes over the
    channel axis at each pixel, which is not instance normalization. Here the
    statistics are computed per sample and per channel over the spatial axes.
    """
    return InstanceNormalization()(X)

def conv2d(inputs, filters, kernel_size, strides, name="noname"):
    """Apply a 'same'-padded Conv2D layer to `inputs` and return its output.

    The layer uses the shared `k_initializer` and the L2 `kernel_regularizer` /
    `bias_regularizer` defined above.
    """
    conv_layer = keras.layers.Conv2D(
        filters,
        kernel_size,
        strides,
        padding="same",
        kernel_initializer=k_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        name=name,
    )
    return conv_layer(inputs)

def deconv2d(inputs, filters, kernel_size=3, strides=1, name="noname"):
    """Apply a 'same'-padded Conv2DTranspose layer to `inputs` and return its output.

    It shares the initializer and L2 regularizers with `conv2d`; with strides=2
    it upsamples the spatial dimensions by a factor of two.
    """
    transpose_layer = tf.keras.layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides,
        padding="same",
        kernel_initializer=k_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        name=name,
    )
    return transpose_layer(inputs)

# Residual block
def residual(inputs, filters=128, kernel_size=3, name="noname"):
    """Residual block: conv -> ReLU -> conv, then add the block input back.

    `filters` must equal the channel count of `inputs` so the skip addition
    shapes match. The two convs are named `name + "_1"` and `name + "_2"`.
    """
    branch = conv2d(inputs, filters, kernel_size, 1, name=name+"_1")
    branch = relu(branch)
    branch = conv2d(branch, filters, kernel_size, 1, name=name+"_2")
    skip_sum = keras.layers.Add()
    return skip_sum([inputs, branch])

def get_transfer_model(input_shape=(256, 256, 3), name="style_transfer_net"):
    """Build the feed-forward style-transfer network as a Keras Model.

    Architecture:
    - subtract MEAN_VALUES;
    - three conv encoders (the last two with stride 2);
    - five 128-filter residual blocks;
    - two stride-2 transposed-conv decoders;
    - a final 9x9 conv with tanh, rescaled from [-1, 1] to [0, 255].

    The input layer is named "transfer_inputs" and the output layer "transfer_outputs".
    """
    img_inputs = keras.Input(input_shape, name="transfer_inputs")
    X = keras.layers.Subtract()([img_inputs, MEAN_VALUES])

    # Encoder: conv -> norm -> ReLU, downsampling twice.
    encoder_specs = (("conv1", 32, 9, 1), ("conv2", 64, 3, 2), ("conv3", 128, 3, 2))
    for layer_name, n_filters, k_size, stride in encoder_specs:
        X = relu(instance_norm(conv2d(X, n_filters, k_size, stride, name=layer_name)))

    for block_idx in range(5):
        X = residual(X, 128, 3, name="res" + str(block_idx))

    # Decoder: transposed conv -> norm -> ReLU, upsampling twice.
    for layer_name, n_filters in (("conv4", 64), ("conv5", 32)):
        X = relu(instance_norm(deconv2d(X, n_filters, 3, 2, name=layer_name)))

    X = keras.layers.Activation('tanh')(instance_norm(conv2d(X, 3, 9, 1, name="conv6")))
    # Map the tanh range [-1, 1] onto pixel values [0, 255].
    X = keras.layers.Lambda(lambda x: (x+1)*(255.0/2), name="transfer_outputs")(X)
    return keras.Model(inputs=img_inputs, outputs=X, name=name)

In [None]:
# # test transfer model
# test_module = get_transfer_model()
# test_module.summary()

In [None]:
# keras.utils.plot_model(test_module , show_shapes=True)

---
## Import VGG19
### Content features and content loss

In [None]:
# ImageNet-pretrained VGG19 convolutional base (no classifier head), used as a fixed feature extractor.
vgg19 = VGG19(weights='imagenet', include_top=False)
# Freeze every VGG19 layer so its weights are not updated during training.
for vgg_layer in vgg19.layers:
    vgg_layer.trainable = False

In [None]:
# get extract model
# Output names Keras assigns when the 'output_feature' model is nested in another model:
# the first output keeps the model name, later ones get _1, _2, _3 suffixes.
# Not referenced elsewhere in this notebook.
OUTPUT_LAYERS = ["output_feature"] + ["output_feature_" + str(i) for i in range(1, 4)]
# VGG19 layers whose activations feed the style (and content) losses.
STYLE_LAYERS = ['block1_conv2', 'block2_conv2', 'block3_conv3', 'block4_conv3']


def get_feature_extract_model(vgg19, layer_names=None):
    """Return a Model mapping VGG19 inputs to the activations of selected layers.

    Args:
        vgg19: the VGG19 model whose layers are tapped.
        layer_names: names of the layers to expose as outputs, in order.
            Defaults to STYLE_LAYERS (read at call time).

    Returns:
        A keras.Model named 'output_feature' with one output per layer name.
    """
    if layer_names is None:
        layer_names = STYLE_LAYERS
    features_list = [vgg19.get_layer(layer_name).output for layer_name in layer_names]
    return keras.Model(inputs=vgg19.input, outputs=features_list, name='output_feature')


def get_features(img_path, model, target_size=None):
    """Load an image file, apply VGG19 preprocessing and run it through `model`.

    Args:
        img_path: path of the image to load.
        model: model to evaluate, typically the feature extractor.
        target_size: (height, width) to resize to; defaults to (img_size, img_size).

    Returns:
        The model's output for a batch containing this single image.
    """
    if target_size is None:
        target_size = (img_size, img_size)
    img = image.load_img(img_path, target_size=target_size)
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    # VGG19's own preprocessing (the same transform its ImageNet weights expect).
    x = preprocess_input(x)
    return model(x)


feat_extract_model = get_feature_extract_model(vgg19)

In [None]:
# Symbolic image input for the training graph built further below.
img_input = keras.Input(shape=(img_size, img_size, 3), name="img")

# Target VGG19 features: one fixed content image and the chosen style image.
content_features = get_features('content/0.jpg', feat_extract_model)
style_img_path = 'styles/wave.jpg'
style_features = get_features(style_img_path, feat_extract_model)

In [None]:
# Inspect the extracted target features (content first, then style).
for extracted in (content_features, style_features):
    print(extracted)

In [None]:
# Generated-image features.
# The transfer net emits RGB pixel values in [0, 255], while the content/style targets
# were extracted after VGG19's `preprocess_input`. Apply the same preprocessing to the
# generated image so both sides of each loss are computed in the same input space.
transfer_model = get_transfer_model()
X = transfer_model(img_input)
vgg_input = keras.layers.Lambda(preprocess_input, name="vgg_preprocess")(X)
generate_features = feat_extract_model(vgg_input)


In [None]:
# content loss
def get_content_loss(content_feature, generate_feature):
    """Mean squared error between the content-target and generated feature maps.

    `content_feature` (e.g. a batch of one) is broadcast against
    `generate_feature`, and the mean is taken over the broadcast difference.
    The value therefore does not grow with the generated batch size. For
    equal shapes it equals 2 * l2_loss(diff) / size(diff).
    """
    return tf.reduce_mean(tf.square(content_feature - generate_feature))


In [None]:
# style loss

def get_style_gram(style_features):
    """Compute a normalized Gram matrix for each feature map in `style_features`.

    Each map is flattened to (-1, C), where C = shape[3] (channels last). Any
    batch dimension is folded into the rows. The Gram matrix flat.T @ flat is
    then divided by the map's total element count.

    Returns:
        A list of (C, C) arrays, one per input feature map, in input order.
    """
    def _normalized_gram(feature_map):
        flat = np.reshape(feature_map, (-1, feature_map.shape[3]))
        return np.matmul(flat.T, flat) / flat.size

    return [_normalized_gram(fm) for fm in style_features]

# Target Gram matrices of the style image, one per STYLE_LAYERS output of the feature
# extractor; computed once and read by every style loss below.
style_grams = get_style_gram(style_features)

def get_style_loss(style_gram, generate_feature):
    """Style loss between a generated feature map and a target Gram matrix.

    Args:
        style_gram: target Gram matrix, broadcast against the per-sample Gram
            matrices of the generated batch.
        generate_feature: feature tensor whose height, width and channel sizes
            (dims 1, 2, 3) are known statically.

    Returns:
        Scalar tensor: sum of squared Gram differences (2 * l2_loss) divided by
        the number of elements in `generate_feature`.
    """
    dims = generate_feature.get_shape().as_list()
    height, width, channel = dims[1], dims[2], dims[3]
    # Flatten spatial positions: (batch, height*width, channel).
    flat = tf.reshape(generate_feature, (-1, height * width, channel))
    # Per-sample Gram matrix (batch, channel, channel), normalized by H*W*C.
    gram = tf.matmul(tf.transpose(flat, (0, 2, 1)), flat) / (height * width * channel * 1.0)
    n_elements = tf.cast(tf.size(generate_feature), tf.float32)
    return 2 * tf.nn.l2_loss(gram - style_gram) / n_elements



## Assemble and train the model

In [None]:
# # 
# feat_extract_model.summary()
# keras.utils.plot_model(feat_extract_model , show_shapes=True)


In [None]:
# Training graph: input image -> transfer network -> VGG19 feature maps (one output per style layer).
train_model = keras.Model(img_input, generate_features, name='train_model')


In [None]:
# train_model.summary()
# keras.utils.plot_model(train_model , show_shapes=True)

In [None]:
# y = get_features("content/0.jpg",train_model)
# print(y)

In [None]:
# for x in train_model.trainable_weights:
#     print(x.name)

In [None]:
# Loss functions.
# train_model has one output per feature-extractor layer; Keras pairs output i with
# loss_fn[i]. Every output carries a weighted style loss, and output 2 also carries
# the content loss.

content_weight = 1
style_weight = 250


def _make_style_loss(layer_index):
    """Build a Keras loss returning the weighted style loss for output `layer_index`.

    The returned function ignores y_true: its target is the precomputed Gram
    matrix style_grams[layer_index]. style_weight is read at call time.
    """
    def style_loss_fn(y_true, y_pred):
        return get_style_loss(style_grams[layer_index], y_pred) * style_weight
    style_loss_fn.__name__ = "loss_fn_" + str(layer_index)
    return style_loss_fn


loss_fn_0 = _make_style_loss(0)
loss_fn_1 = _make_style_loss(1)


def loss_fn_2(y_true, y_pred):
    """Weighted content loss plus weighted style loss for output 2.

    NOTE(review): the content target is content_features[2], i.e. the fixed image
    'content/0.jpg', not the current training input y_true. Confirm this is intended.
    """
    # Do not call Model.predict_on_batch here: Keras evaluates losses inside a
    # tf.function, where predict_on_batch raises a RuntimeError.
    content_loss = get_content_loss(content_features[2], y_pred)
    style_loss = get_style_loss(style_grams[2], y_pred)
    return content_weight * content_loss + style_weight * style_loss


loss_fn_3 = _make_style_loss(3)

opt = keras.optimizers.Adam(learning_rate=0.001)
loss_fn = [loss_fn_0, loss_fn_1, loss_fn_2, loss_fn_3]
train_model.compile(optimizer=opt, loss=loss_fn, loss_weights=[1, 1, 1, 1])

In [None]:
    # ## 加载图片
    X_data = np.load('train/train2014_5000.preprocessing.npy')
    X_data = X_data[0:60, :, :, :]
    Y_data = X_data.astype(float)


In [None]:
# Begin training.
# Keras keys multi-output targets by output name. Read the names from the model
# instead of hard-coding them, and leave STYLE_LAYERS (the VGG19 layer names used
# to build the feature extractor) untouched so earlier cells can still be re-run.
train_output_names = train_model.output_names

# Every output receives the same target array.
outs = {output_name: Y_data for output_name in train_output_names}

train_model.fit(
    {'img': X_data},
    outs,
    batch_size=1,
    epochs=2,
)

In [None]:
# Test: load a sample image, show it, and prepare it for the transfer network.
import matplotlib.pyplot as plt

sample_img_path = 'content/0.jpg'
sample_size = 256

# Load and display the resized RGB sample image.
img = image.load_img(sample_img_path, target_size=(sample_size,) * 2)
x = image.img_to_array(img)
plt.imshow(x.astype(int))

# Add a batch axis, then apply VGG19's preprocessing.
x = preprocess_input(np.expand_dims(x, axis=0))

# The transfer network is nested inside train_model under this layer name.
predit_model = train_model.get_layer("style_transfer_net")

#imsave('content/0_wave.jpg', result[0])




In [None]:
# Stylize the sample and truncate the float pixel values to integers for display.
result = tf.cast(predit_model(x), dtype=tf.int32)


In [None]:
import matplotlib.pyplot as plt

# Show the first (only) stylized image of the batch without axis ticks.
plt.axis('off')
plt.imshow(result[0])
# print(x[0].shape)
# print(x[0])

In [None]:
# Sanity check: VGG19 feature maps of the preprocessed sample image.
y = feat_extract_model.predict(x)
print(y)