In [4]:
import logging
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import add_arg_scope
from neuralgym.ops.layers import resize
from neuralgym.ops.layers import *
from neuralgym.ops.loss_ops import *
from neuralgym.ops.summary_ops import *

  from ._conv import register_converters as _register_converters


Instructions for updating:
Use the retry module or similar alternatives.


[32m[2018-10-15 23:11:33 @__init__.py:79][0m Set root logger. Unset logger with neuralgym.unset_logger().
[32m[2018-10-15 23:11:33 @__init__.py:80][0m Saving logging to file: neuralgym_logs/20181015231133654053.


In [5]:
# Seed NumPy's global random generator so NumPy-based randomness is repeatable.
np.random.seed(seed=2018)
# Handle to the root logger (no logger name is given).
logger = logging.getLogger(name=None)

In [6]:
def gen_conv(x,cnum,ksize,stride=1,rate=1,name='conv',padding = 'SAME',activation = tf.nn.elu,training = True):
    """Gated convolution layer.

    A single convolution produces `cnum` channels. Unless `activation` is
    None, those channels are split in half along the channel axis: the first
    half is passed through `activation`, the second half through a sigmoid,
    and the two are multiplied element-wise (a learned soft gate). A gated
    layer therefore outputs `cnum / 2` channels.

    Args:
        x: input tensor; the split is taken along axis 3, so a 4-D
            channels-last tensor is expected.
        cnum: number of filters of the underlying convolution. Must be even
            when `activation` is not None.
        ksize: convolution kernel size.
        stride: convolution stride.
        rate: dilation rate of the convolution.
        name: name passed to `tf.layers.conv2d`.
        padding: padding mode passed to `tf.layers.conv2d`.
        activation: activation applied to the feature half. None disables
            gating altogether and returns the plain linear convolution
            (used for the RGB output layers).
        training: accepted for interface compatibility; not used here.

    Returns:
        The gated feature tensor, or the raw convolution output when
        `activation` is None.
    """
    x = tf.layers.conv2d(
        x, cnum, ksize, stride,
        padding=padding, dilation_rate=rate, activation=None, name=name)
    if activation is None:
        # Output layer: no gating, keep all `cnum` channels.
        return x
    # tf.split defaults to axis=0 (the batch); the gate must be split off
    # along the channel axis instead.
    features, gate = tf.split(x, 2, axis=3)
    return activation(features) * tf.nn.sigmoid(gate)

In [7]:
def gen_deconv(x,cnum,name = 'upsample',padding = 'SAME',training = True):
    """Upsample `x`, then apply a 3x3, stride-1 `gen_conv`.

    The input is resized with `resize` using nearest-neighbour interpolation
    (the scale factor is whatever `resize` defaults to - presumably 2x, TODO
    confirm against neuralgym), and the result is fed to `gen_conv` inside a
    variable scope called `name`.

    Args:
        x: input tensor.
        cnum: filter count forwarded to `gen_conv`.
        name: variable-scope name; the inner convolution is named
            `name + '_conv'`.
        padding: padding mode forwarded to `gen_conv`.
        training: forwarded to `gen_conv`.

    Returns:
        The output of the inner `gen_conv` call.
    """
    conv_name = name + '_conv'
    with tf.variable_scope(name):
        upsampled = resize(x, func=tf.image.resize_nearest_neighbor)
        return gen_conv(
            upsampled, cnum, 3, 1,
            name=conv_name, padding=padding, training=training)

In [8]:
def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
                         fuse_k=3, softmax_scale=10., training=True, fuse=True):
    """Rebuild foreground features `f` from patches of the background `b`.

    Background patches are used twice. Patches taken from the downscaled
    background are L2-normalised and used as convolution filters, scoring
    every foreground location against every background patch. The matching
    patches taken from the full-resolution background are then used as
    transposed-convolution filters, pasting background content back with the
    softmax of the scores as weights.

    Args:
        f: foreground feature map indexed as [batch, height, width, channels].
            The batch size must be statically known (it is used as a split
            count).
        b: background feature map with the same layout and a statically
            known shape.
        mask: optional mask on the grid of `b`. A background patch is usable
            only if the mean of the mask over that patch equals 0. It is
            reshaped with a leading dimension of 1, so a batch size of 1 is
            presumably expected - TODO confirm with callers. None means every
            patch is usable.
        ksize: patch size used for matching.
        stride: stride used when extracting patches. NOTE(review): the
            reshapes below assume one patch per background location, which
            looks like it only holds for stride 1 - confirm before using
            other values.
        rate: factor by which `f`, `b` and `mask` are downscaled before
            matching.
        fuse_k: size of the identity kernel used to fuse neighbouring scores.
        softmax_scale: multiplier applied to the scores before the softmax.
        training: accepted for interface compatibility; not used here.
        fuse: whether to apply the score-fusion step.

    Returns:
        (y, flow): `y` carries the static shape of the original `f`; `flow`
        is `flow_to_image_tf` applied to the offset between each foreground
        location and its best-matching background patch, upscaled by `rate`
        when `rate != 1`. `flow_to_image_tf` is not defined in the visible
        part of the notebook.
    """
    # Shapes of the inputs before any downscaling.
    raw_fs = tf.shape(f)
    raw_int_fs = f.get_shape().as_list()
    raw_int_bs = b.get_shape().as_list()

    # Full-resolution background patches, later used as deconvolution filters.
    kernel = 2*rate
    raw_w = tf.extract_image_patches(
        b, [1,kernel,kernel,1], [1,rate*stride,rate*stride,1], [1,1,1,1], padding='SAME')
    raw_w = tf.reshape(raw_w, [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
    raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw

    # Downscale foreground, background and mask by `rate` for matching.
    f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor)
    b = resize(b, to_shape=[int(raw_int_bs[1]/rate), int(raw_int_bs[2]/rate)], func=tf.image.resize_nearest_neighbor)  # https://github.com/tensorflow/tensorflow/issues/11651
    if mask is not None:
        mask = resize(mask, scale=1./rate, func=tf.image.resize_nearest_neighbor)
    fs = tf.shape(f)
    int_fs = f.get_shape().as_list()
    f_groups = tf.split(f, int_fs[0], axis=0)
    bs = tf.shape(b)
    int_bs = b.get_shape().as_list()

    # Low-resolution background patches, used as matching filters.
    w = tf.extract_image_patches(
        b, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
    w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
    w = tf.transpose(w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw

    # Per-patch validity: 1.0 where the mask averages to exactly 0 over the patch.
    if mask is None:
        mask = tf.zeros([1, bs[1], bs[2], 1])
    m = tf.extract_image_patches(
        mask, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
    m = tf.reshape(m, [1, -1, ksize, ksize, 1])
    m = tf.transpose(m, [0, 2, 3, 4, 1])
    m = m[0]
    mm = tf.cast(tf.equal(tf.reduce_mean(m, axis=[0,1,2], keep_dims=True), 0.), tf.float32)

    w_groups = tf.split(w, int_bs[0], axis=0)
    raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
    y = []
    offsets = []
    k = fuse_k
    scale = softmax_scale
    fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
    for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
        # Score every foreground location against every normalised patch.
        wi = wi[0]
        wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0,1,2])), 1e-4)
        yi = tf.nn.conv2d(xi, wi_normed, strides=[1,1,1,1], padding="SAME")

        if fuse:
            # Convolve the (foreground x background) score matrix with an
            # identity kernel, once in each flattening order.
            yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
            yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
            yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
            yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
        yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1]*bs[2]])

        # Mask out unusable patches before and after the softmax.
        yi *=  mm
        yi = tf.nn.softmax(yi*scale, 3)
        yi *=  mm

        # The argmax runs over background patches, so the flat index has to
        # be decoded with the background width, not the foreground width.
        offset = tf.argmax(yi, axis=3, output_type=tf.int32)
        offset = tf.stack([offset // bs[2], offset % bs[2]], axis=-1)

        # Paste full-resolution patches back, weighted by the attention.
        # NOTE(review): the division by 4 presumably averages the overlapping
        # patches (kernel = 2*rate placed every `rate` pixels) - confirm.
        wi_center = raw_wi[0]
        yi = tf.nn.conv2d_transpose(yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0), strides=[1,rate,rate,1]) / 4.
        y.append(yi)
        offsets.append(offset)
    y = tf.concat(y, axis=0)
    y.set_shape(raw_int_fs)

    # `offsets` holds one (row, col) pair per foreground location, so its
    # static shape and the coordinate grids subtracted from it follow the
    # foreground dimensions.
    offsets = tf.concat(offsets, axis=0)
    offsets.set_shape(int_fs[:3] + [2])
    h_add = tf.tile(tf.reshape(tf.range(fs[1]), [1, fs[1], 1, 1]), [fs[0], 1, fs[2], 1])
    w_add = tf.tile(tf.reshape(tf.range(fs[2]), [1, 1, fs[2], 1]), [fs[0], fs[1], 1, 1])
    offsets = offsets - tf.concat([h_add, w_add], axis=3)
    flow = flow_to_image_tf(offsets)
    if rate != 1:
        flow = resize(flow, scale=rate, func=tf.image.resize_nearest_neighbor)
    return y, flow


In [9]:
def build_inpaint_net(x, mask, config=None, reuse=False,training=True, padding='SAME', name='inpaint_net'):
    """Build the two-stage inpainting generator.

    Stage 1 is an encoder / dilated-convolution / decoder stack that outputs
    a coarse result. Its output is blended with the input (the coarse result
    where `mask` is set, the original input elsewhere) and fed to stage 2,
    which runs two parallel branches - a dilated-convolution branch and a
    contextual-attention branch - concatenates them and decodes the final
    result.

    Args:
        x: input image batch; it is concatenated with extra channels along
            axis 3, so a channels-last 4-D tensor is expected.
        mask: mask tensor, multiplied against a one-channel slice of `x`.
        config: accepted for interface compatibility; not used here.
        reuse: passed to `tf.variable_scope`.
        training: forwarded to every `gen_conv` / `gen_deconv` call.
        padding: forwarded to every `gen_conv` / `gen_deconv` call.
        name: name of the enclosing variable scope.

    Returns:
        (x_stage1, x_stage2, offset_flow): the coarse output and the refined
        output, both clipped to [-1, 1], and the flow visualisation returned
        by `contextual_attention`.

    Note:
        `resize_mask_like` is not defined in the visible part of the
        notebook; it must be available in the namespace when this runs.
    """
    import functools

    # arg_scope only accepts functions decorated with @add_arg_scope, which
    # gen_conv / gen_deconv are not (and `arg_scope` itself is not among the
    # explicit imports), so bind the shared keyword arguments directly.
    conv = functools.partial(gen_conv, training=training, padding=padding)
    deconv = functools.partial(gen_deconv, training=training, padding=padding)

    xin = x
    offset_flow = None
    # One-channel tensor of ones with the spatial layout of `x`.
    ones_x = tf.ones_like(x)[:, :, :, 0:1]
    x = tf.concat([x, ones_x, ones_x*mask], axis=3)
    cnum = 32
    with tf.variable_scope(name, reuse=reuse):
        # ---- stage 1: coarse network ----
        x = conv(x, cnum, 5, 1, name='conv1')
        x = conv(x, 2*cnum, 3, 2, name='conv2_downsample')
        x = conv(x, 2*cnum, 3, 1, name='conv3')
        x = conv(x, 4*cnum, 3, 2, name='conv4_downsample')
        x = conv(x, 4*cnum, 3, 1, name='conv5')
        x = conv(x, 4*cnum, 3, 1, name='conv6')
        # Mask at the resolution of the bottleneck features; used later by
        # the contextual-attention branch.
        mask_s = resize_mask_like(mask, x)
        x = conv(x, 4*cnum, 3, rate=2, name='conv7_atrous')
        x = conv(x, 4*cnum, 3, rate=4, name='conv8_atrous')
        x = conv(x, 4*cnum, 3, rate=8, name='conv9_atrous')
        x = conv(x, 4*cnum, 3, rate=16, name='conv10_atrous')
        x = conv(x, 4*cnum, 3, 1, name='conv11')
        x = conv(x, 4*cnum, 3, 1, name='conv12')
        x = deconv(x, 2*cnum, name='conv13_upsample')
        x = conv(x, 2*cnum, 3, 1, name='conv14')
        x = deconv(x, cnum, name='conv15_upsample')
        x = conv(x, cnum//2, 3, 1, name='conv16')
        x = conv(x, 3, 3, 1, activation=None, name='conv17')
        x = tf.clip_by_value(x, -1., 1.)
        x_stage1 = x

        # ---- stage 2: refinement network ----
        # Keep the coarse prediction where mask is set, the input elsewhere.
        x = x*mask + xin*(1.-mask)
        x.set_shape(xin.get_shape().as_list())
        xnow = tf.concat([x, ones_x, ones_x*mask], axis=3)

        # Dilated-convolution branch.
        x = conv(xnow, cnum, 5, 1, name='xconv1')
        x = conv(x, cnum, 3, 2, name='xconv2_downsample')
        x = conv(x, 2*cnum, 3, 1, name='xconv3')
        x = conv(x, 2*cnum, 3, 2, name='xconv4_downsample')
        x = conv(x, 4*cnum, 3, 1, name='xconv5')
        x = conv(x, 4*cnum, 3, 1, name='xconv6')
        x = conv(x, 4*cnum, 3, rate=2, name='xconv7_atrous')
        x = conv(x, 4*cnum, 3, rate=4, name='xconv8_atrous')
        x = conv(x, 4*cnum, 3, rate=8, name='xconv9_atrous')
        x = conv(x, 4*cnum, 3, rate=16, name='xconv10_atrous')
        x_hallu = x

        # Contextual-attention branch.
        x = conv(xnow, cnum, 5, 1, name='pmconv1')
        x = conv(x, cnum, 3, 2, name='pmconv2_downsample')
        x = conv(x, 2*cnum, 3, 1, name='pmconv3')
        x = conv(x, 4*cnum, 3, 2, name='pmconv4_downsample')
        x = conv(x, 4*cnum, 3, 1, name='pmconv5')
        x = conv(x, 4*cnum, 3, 1, name='pmconv6', activation=tf.nn.relu)
        x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
        x = conv(x, 4*cnum, 3, 1, name='pmconv9')
        x = conv(x, 4*cnum, 3, 1, name='pmconv10')
        pm = x

        # Merge both branches and decode.
        x = tf.concat([x_hallu, pm], axis=3)
        x = conv(x, 4*cnum, 3, 1, name='allconv11')
        x = conv(x, 4*cnum, 3, 1, name='allconv12')
        x = deconv(x, 2*cnum, name='allconv13_upsample')
        x = conv(x, 2*cnum, 3, 1, name='allconv14')
        x = deconv(x, cnum, name='allconv15_upsample')
        x = conv(x, cnum//2, 3, 1, name='allconv16')
        x = conv(x, 3, 3, 1, activation=None, name='allconv17')
        x_stage2 = tf.clip_by_value(x, -1., 1.)
    return x_stage1, x_stage2, offset_flow