In [None]:
# coding: utf-8
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras

from tensorflow.keras.applications.vgg19 import VGG19
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import numpy as np
import os

img_size = 256  # side length (pixels) that images are resized to for VGG feature extraction

# Uncomment to hide every GPU from TensorFlow and run on CPU only.
#os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Ask TensorFlow to grow GPU memory on demand for each visible GPU
# instead of reserving it all up front.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)
 

---
## 2. Build the transfer model

In [None]:

# Per-channel colour means (presumably the ImageNet RGB means), shaped
# (1, 1, 1, 3) so they broadcast over a batch of NHWC images.
MEAN_VALUES = np.array([123.68, 116.779, 103.939]).reshape(1, 1, 1, 3)

# Conv kernel initializer: truncated normal with mean 0 and stddev 0.1.
k_initializer = keras.initializers.TruncatedNormal(0, 0.1)

# L2 regularizers. NOTE(review): in the visible code these are only referenced
# from commented-out layer arguments, and `variation_weight` is reassigned
# later as the total-variation loss weight.
variation_weight = 0.01
kernel_regularizer = keras.regularizers.l2(variation_weight)
bias_regularizer = keras.regularizers.l2(variation_weight)
activity_regularizer = keras.regularizers.l2(0.01)
def relu(X):
    """Apply a ReLU activation (as a Keras layer) to X and return the result."""
    relu_layer = keras.layers.Activation('relu')
    return relu_layer(X)

def instance_norm(X):
    """Apply a new tfa InstanceNormalization layer to X and return the result.

    axis=3 is passed as the normalized (channel) axis, so NHWC layout is assumed.
    Both the learnable offset (beta) and scale (gamma) are enabled and use
    "random_uniform" initializers.
    """
    # An earlier variant used keras.layers.LayerNormalization() here instead.
    norm_layer = tfa.layers.InstanceNormalization(
        axis=3,
        center=True,
        scale=True,
        beta_initializer="random_uniform",
        gamma_initializer="random_uniform",
    )
    return norm_layer(X)
# NOTE: kernel_regularizer / bias_regularizer are defined with the other
# constants but are deliberately not passed to Conv2D below.
def conv2d(inputs, filters, kernel_size, strides, name = "noname"):
    """Apply a "same"-padded Conv2D layer to `inputs` and return the output.

    Kernels are initialized with `k_initializer`; no activation argument is
    given. `name` must be unique per layer within a model.
    """
    conv_layer = keras.layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding="same",
        kernel_initializer=k_initializer,
        name=name,
    )
    return conv_layer(inputs)

def deconv2d(inputs, filters, kernel_size = 3, strides = 1, name="noname"):
    """Upsample with a nearest-neighbour resize followed by `conv2d`.

    The input is resized to (H * strides * 2, W * strides * 2) and then passed
    through `conv2d` with the same `strides` and "same" padding. The output
    spatial size is therefore 2H x 2W for any `strides` (resize-convolution).
    The commented-out predecessor used Conv2DTranspose.

    Args:
        inputs: 4-D tensor whose dims 1 and 2 are height and width. These
            must be statically known, because they are read via get_shape().
        filters: number of output channels of the convolution.
        kernel_size: convolution kernel size.
        strides: convolution stride; also scales the pre-resize target size.
        name: Keras layer name for the convolution.

    Returns:
        The output tensor of `conv2d`.
    """
    height, width = inputs.get_shape().as_list()[1:3]
    # Nearest-neighbour interpolation (the original leftover debug
    # `print(inputs)` has been removed).
    upsampled = tf.image.resize(
        inputs,
        [height * strides * 2, width * strides * 2],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    return conv2d(upsampled, filters, kernel_size, strides, name)

# Residual block
def residual(inputs, filters = 128, kernel_size = 3, name="noname"):
    """Residual block: conv -> ReLU -> conv, added to the input (identity shortcut).

    The two convolutions are named `name + "_1"` and `name + "_2"`. `inputs`
    must already have `filters` channels so that the Add shapes match.
    """
    hidden = conv2d(inputs, filters, kernel_size, 1, name=name + "_1")
    hidden = relu(hidden)
    hidden = conv2d(hidden, filters, kernel_size, 1, name=name + "_2")
    return keras.layers.Add()([inputs, hidden])

def get_transfer_model(input_shape=(256, 256, 3), name="style_transfer_net"):
    """Build the feed-forward image transformation network.

    Pipeline: subtract MEAN_VALUES from the input, then three conv + instance
    norm + ReLU encoder layers (conv1 9x9/stride 1, conv2 and conv3 3x3/stride 2).
    These are followed by five 128-filter residual blocks and two resize-conv
    upsampling layers (conv4, conv5). The last stage is a 9x9 conv (conv6) with
    instance norm and tanh, rescaled from (-1, 1) to (0, 255).

    Args:
        input_shape: shape of one input image, excluding the batch dimension.
        name: name of the returned keras.Model.

    Returns:
        A keras.Model mapping images to stylized images in the range (0, 255).
    """
    img_inputs = keras.Input(input_shape, name="transfer_inputs")
    X = keras.layers.Subtract()([img_inputs, MEAN_VALUES])

    # Encoder: (filters, kernel size, stride, layer name)
    encoder_specs = ((32, 9, 1, "conv1"), (64, 3, 2, "conv2"), (128, 3, 2, "conv3"))
    for n_filters, ksize, stride, layer_name in encoder_specs:
        X = relu(instance_norm(conv2d(X, n_filters, ksize, stride, name=layer_name)))

    for block_idx in range(5):
        X = residual(X, 128, 3, name="res" + str(block_idx))

    # Decoder: each resize-conv doubles height and width.
    for n_filters, layer_name in ((64, "conv4"), (32, "conv5")):
        X = relu(instance_norm(deconv2d(X, n_filters, 3, 2, name=layer_name)))

    X = keras.layers.Activation('tanh')(instance_norm(conv2d(X, 3, 9, 1, name="conv6")))
    X = (X + 1) * (255.0 / 2)
    return keras.Model(inputs=img_inputs, outputs=X, name=name)

In [None]:
# # test transfer model
# test_module = get_transfer_model()
# test_module.summary()

In [None]:
# keras.utils.plot_model(test_model , show_shapes=True)

---
## Import VGG19
### Content features and content loss

In [None]:
# ImageNet-pretrained VGG19 convolutional base (no classifier head).
vgg19 = VGG19(include_top=False, weights='imagenet')
# Freeze every layer: VGG19 is only used as a fixed feature extractor.
for vgg_layer in vgg19.layers:
    vgg_layer.trainable = False

In [None]:
# Feature-extraction model built on the frozen VGG19.
# NOTE(review): OUTPUT_LAYERS is not referenced by the visible code.
OUTPUT_LAYERS = ["output_feature_%d" % i for i in range(4)]
STYLE_LAYERS = ['block1_conv2', 'block2_conv2', 'block3_conv3', 'block4_conv3']
def get_feature_extract_model(vgg19):
    """Return a keras.Model mapping vgg19's input to the outputs of STYLE_LAYERS, in order."""
    outputs = []
    for layer_name in STYLE_LAYERS:
        outputs.append(vgg19.get_layer(layer_name).output)
    return keras.Model(inputs=vgg19.input, outputs=outputs, name='output_feature')

def get_features(img_path, model):
    """Load one image, apply VGG19 preprocessing and return `model` applied to it.

    The image is resized to img_size x img_size, given a leading batch axis of 1
    and passed through VGG19 `preprocess_input` before being fed to `model`.
    """
    pil_img = image.load_img(img_path, target_size=(img_size, img_size))
    batch = np.expand_dims(image.img_to_array(pil_img), axis=0)
    return model(preprocess_input(batch))

feat_extract_model = get_feature_extract_model(vgg19)

In [None]:
# Extract the VGG features of the style image once; its Gram matrices are the style target.

style_img_path = 'styles/wave.jpg'
style_features = get_features(style_img_path, model=feat_extract_model)

In [None]:
# # for tensor in style_features:
# #     print(tensor.shape)
# print(content_features)
# print(style_features)

In [None]:
# Build the image transformation network that is trained below.

transfer_model = get_transfer_model(input_shape=(256, 256, 3), name="style_transfer_net")


In [None]:
# content loss
def get_content_loss(content_features, generate_features):
    """Mean squared error between the third feature maps (index 2) of two feature lists.

    With feat_extract_model, index 2 is the 'block3_conv3' output. Because
    tf.nn.l2_loss(t) is sum(t**2) / 2, doubling it and dividing by the element
    count gives the mean squared difference.
    """
    target = content_features[2]
    generated = generate_features[2]
    num_elements = tf.cast(tf.size(target), dtype=tf.float32)
    return 2 * tf.nn.l2_loss(target - generated) / num_elements


In [None]:
# style loss

def get_style_gram(style_features): # input: list of tensor
    """Compute a normalized Gram matrix for each feature map in `style_features`.

    Each map is flattened to (rows, channels), with all leading axes (batch and
    spatial) merged into rows. Its Gram matrix F^T F is then divided by the
    total element count. Returns a list of (channels, channels) NumPy arrays in
    the same order as the input.
    """
    def _normalized_gram(feature_map):
        flat = np.reshape(feature_map, (-1, feature_map.shape[3]))
        return np.matmul(flat.T, flat) / flat.size

    return [_normalized_gram(feature_map) for feature_map in style_features]

style_grams = get_style_gram(style_features)

def get_style_loss( generate_features):
    """Sum over layers of the style loss between generated features and `style_grams`.

    For each layer i, the Gram matrix of generate_features[i] is normalized by
    H*W*C, the same normalization used by get_style_gram. It is compared with
    style_grams[i], and the squared error is scaled by 2 / tf.size(layer).
    The per-layer terms are summed.

    Args:
        generate_features: list of 4-D feature tensors, one per entry of
            `style_grams`. Height, width and channels must be statically known.

    Returns:
        Scalar tensor named 'style_loss'; 0.0 if `generate_features` is empty.
    """
    # Collect every layer's term. The previous version reassigned a single
    # variable on each iteration, so only the last layer contributed.
    layer_losses = []
    for i, layer in enumerate(generate_features):
        _, height, width, channel = layer.get_shape().as_list()
        features = tf.reshape(layer, (-1, height * width, channel))
        gram = tf.matmul(tf.transpose(features, (0, 2, 1)), features) / (height * width * channel * 1.0)
        size = tf.cast(tf.size(layer), tf.float32)
        layer_losses.append(2 * tf.nn.l2_loss(gram - style_grams[i]) / size)
    return tf.reduce_sum(layer_losses, name='style_loss')



## 

In [None]:
# variation loss
def get_total_variation_loss(inputs):
    """Total-variation (smoothness) loss for a batch of images.

    Uses the squared differences between vertically adjacent pixels (axis 1)
    and horizontally adjacent pixels (axis 2). Each term is
    tf.nn.l2_loss(diff) / diff.size, i.e. half the mean squared difference,
    and the two terms are added.

    Args:
        inputs: 4-D image tensor; axes 1 and 2 are treated as height and width.

    Returns:
        Scalar loss tensor.
    """
    h = inputs[:, :-1, :, :] - inputs[:, 1:, :, :]
    # Horizontal neighbour differences. Previously this was just
    # inputs[:, :, 1:, :], which penalized raw pixel values, not variation.
    w = inputs[:, :, :-1, :] - inputs[:, :, 1:, :]
    h_size = tf.cast(tf.size(h), tf.float32)
    w_size = tf.cast(tf.size(w), tf.float32)
    return tf.nn.l2_loss(h) / h_size + tf.nn.l2_loss(w) / w_size

In [None]:
# # 
# feat_extract_model.summary()
# keras.utils.plot_model(feat_extract_model , show_shapes=True)


In [None]:
# Loss function

content_weight = 1
style_weight = 250
variation_weight = 0.01

def loss_fn(y_true, y_pred):
    """Perceptual loss used to train `transfer_model`.

    total = content_weight * content_loss
          + style_weight * style_loss
          + variation_weight * total_variation_loss

    Args:
        y_true: batch of content target images. NOTE(review): assumed to be raw
            RGB in [0, 255], like the transfer net's input; confirm against how
            the training .npy was produced.
        y_pred: batch of images produced by `transfer_model`, in (0, 255).

    Returns:
        Scalar loss tensor.
    """
    # Call the feature model directly rather than via `predict_on_batch`.
    # That method returns NumPy arrays, which cannot be used inside the traced
    # training step and would cut the gradient path back to `transfer_model`.
    # Both inputs go through the same VGG19 `preprocess_input` that
    # `get_features` applied to the style image, so the generated Gram matrices
    # are compared against style Gram matrices from the same input space.
    content_features = feat_extract_model(preprocess_input(y_true), training=False)
    generate_features = feat_extract_model(preprocess_input(y_pred), training=False)
    content_loss = get_content_loss(content_features, generate_features)
    style_loss = get_style_loss(generate_features)
    # Smoothness is measured on the generated image itself, in pixel space.
    variation_loss = get_total_variation_loss(y_pred)
    total_loss = content_weight*content_loss + style_weight*style_loss + variation_weight*variation_loss
    return total_loss

opt = keras.optimizers.Adam(learning_rate=0.001)
transfer_model.compile(optimizer=opt, loss=loss_fn)

In [None]:
    # ## Load training images
    # NOTE(review): the .npy is assumed to hold images of shape
    # (N, 256, 256, 3) in raw RGB [0, 255]; confirm against how it was produced.
    X_data = np.load('train/train2014_5000.preprocessing.npy')
    X_data = X_data[0:1000, :, :, :]
    # The content target is the input image itself. Use float32 instead of
    # astype(float) (float64): this halves memory, and presumably matches the
    # model's float32 output anyway.
    Y_data = X_data.astype(np.float32)


In [None]:
# Train the transfer network; inputs and content targets are the same images.
transfer_model.fit(x=X_data, y=Y_data, batch_size=2, epochs=1)


In [None]:
# Test: stylize a single content image

sample_img_path = 'content/0.jpg'
# Must match the transfer model's input size (get_transfer_model defaults to 256x256).
sample_size = 256
img = image.load_img(sample_img_path, target_size=(sample_size, sample_size))
x = image.img_to_array(img)
plt.imshow(x.astype(int))
x = np.expand_dims(x, axis=0)
# Do NOT apply VGG19 `preprocess_input` here. The transfer network subtracts
# MEAN_VALUES itself and expects raw RGB in [0, 255]. `preprocess_input` would
# first swap RGB->BGR and subtract the means, so the means would be removed twice.





In [None]:
# Run the trained network on the sample image and display the stylized output.
result = transfer_model.predict(x)
stylized = result[0].astype(int)
plt.axis('off')
plt.imshow(stylized)
