# Notes for FCN8

## Reference PyTorch Implementations (for Quick Understanding)
- https://github.com/wkentaro/pytorch-fcn
- https://github.com/bodokaiser/piwise

In [1]:
import tensorflow as tf
from tensorflow.contrib import layers
import numpy as np

## Load VGG16

In [2]:
# Open a TF1 session on the process-wide default graph; the SavedModel below is
# loaded into this same graph.
sess = tf.Session(graph=tf.get_default_graph())

In [3]:

# load the frozen model
tf.saved_model.loader.load(sess, ["vgg16"], "data/vgg/")
# check the graph
graph = sess.graph

INFO:tensorflow:Restoring parameters from b'data/vgg/variables/variables'


In [4]:
# Tensors of the loaded VGG16 graph that the FCN8 decoder needs; each name refers
# to output 0 of the op with that name.
vgg_names = ["image_input", "keep_prob", "layer3_out", "layer4_out", "layer7_out"]
vgg16 = {
    tensor_name: graph.get_tensor_by_name("%s:0" % tensor_name)
    for tensor_name in vgg_names
}

In [5]:
# Show each VGG16 tensor with its static shape (`?` marks a dimension only known at run time).
for name, tensor in vgg16.items():
    print(name, tensor.shape)

image_input (?, ?, ?, 3)
keep_prob <unknown>
layer3_out (?, ?, ?, 256)
layer4_out (?, ?, ?, 512)
layer7_out (?, ?, ?, 4096)


## Build FCN8

In [6]:
def build_fcn8(graph, vgg16, num_classes, weight_initializer=None, kernel_regularizer=None):
    """Attach an FCN8 decoder (1x1 scoring convs, transposed-conv upsampling and
    skip connections) on top of the VGG16 tensors.

    Args:
        graph: the tf.Graph that holds the VGG16 tensors; new layers are added to it.
        vgg16: dict of VGG16 tensors; must contain "layer3_out", "layer4_out" and
            "layer7_out". It is not modified (a shallow copy is extended instead).
        num_classes: number of output channels of every new layer.
        weight_initializer: kernel initializer used by every new layer. Defaults to
            ``layers.variance_scaling_initializer()`` (the previous hard-coded choice).
        kernel_regularizer: optional kernel regularizer (e.g. an L2 penalty) applied
            to every new layer. ``None`` (default) adds no regularization.

    Returns:
        A new dict with all entries of ``vgg16`` plus "layer3_fcn", "layer4_fcn",
        "layer7_fcn", "layer7_up", "layer4_skip", "layer4_up", "layer3_skip" and
        "heatmap" (per-class scores, no activation applied).

    Note:
        The skip connections use ``tf.add``, so every upsampled tensor must match the
        spatial size of the tensor it is added to. Presumably the input height/width
        should be multiples of 32 for this to hold. TODO confirm for other input sizes.
    """
    model = vgg16.copy()
    with graph.as_default():
        if weight_initializer is None:
            weight_initializer = layers.variance_scaling_initializer()
        # Keyword arguments shared by every new conv / transposed-conv layer.
        common = dict(kernel_initializer=weight_initializer,
                      kernel_regularizer=kernel_regularizer)

        # 1x1 convolutions projecting each VGG feature map to num_classes channels.
        for layer in ("layer3", "layer4", "layer7"):
            key = layer + "_fcn"
            model[key] = tf.layers.conv2d(model[layer + "_out"], num_classes,
                                          kernel_size=(1, 1), strides=(1, 1),
                                          name=key, **common)

        # Upsampling and skipping, bottom up. A stride-2 transposed conv with SAME
        # padding doubles height and width, so each result lines up with the next
        # shallower scoring layer.
        model["layer7_up"] = tf.layers.conv2d_transpose(model["layer7_fcn"], num_classes,
                                                        kernel_size=(4, 4), strides=(2, 2),
                                                        padding="SAME", name="layer7_up",
                                                        **common)
        model["layer4_skip"] = tf.add(model["layer4_fcn"], model["layer7_up"], name="layer4_skip")
        model["layer4_up"] = tf.layers.conv2d_transpose(model["layer4_skip"], num_classes,
                                                        kernel_size=(4, 4), strides=(2, 2),
                                                        padding="SAME", name="layer4_up",
                                                        **common)
        model["layer3_skip"] = tf.add(model["layer3_fcn"], model["layer4_up"], name="layer3_skip")
        # Final 8x upsampling. The graph op keeps its original name "layer3_up" (not
        # "heatmap") so existing variable names and checkpoints still match.
        model["heatmap"] = tf.layers.conv2d_transpose(model["layer3_skip"], num_classes,
                                                      kernel_size=(16, 16), strides=(8, 8),
                                                      padding="SAME", name="layer3_up",
                                                      **common)
        return model

In [7]:
# Two-class FCN8 head on top of the loaded VGG16 tensors.
fcn8 = build_fcn8(graph, vgg16, num_classes=2)

In [8]:
# Static shapes of every tensor in the model dict: VGG16 inputs/outputs followed by the new FCN8 layers.
for name, tensor in fcn8.items():
    print(name, tensor.shape)

image_input (?, ?, ?, 3)
keep_prob <unknown>
layer3_out (?, ?, ?, 256)
layer4_out (?, ?, ?, 512)
layer7_out (?, ?, ?, 4096)
layer3_fcn (?, ?, ?, 2)
layer4_fcn (?, ?, ?, 2)
layer7_fcn (?, ?, ?, 2)
layer7_up (?, ?, ?, 2)
layer4_skip (?, ?, ?, 2)
layer4_up (?, ?, ?, 2)
layer3_skip (?, ?, ?, 2)
heatmap (?, ?, ?, 2)


## Test FCN8

In [9]:
# Fixed-seed random batch (16 images of 256x256x3 in [0, 1)) so Restart & Run All feeds identical input.
x = np.random.RandomState(0).rand(16, 256, 256, 3)

In [10]:
# Initialize only variables that are still uninitialized, i.e. the new FCN8 layers.
# tf.global_variables_initializer() would also re-run the initializers of the VGG16
# variables just restored from the SavedModel, which can overwrite the pretrained weights.
with sess.graph.as_default():
    uninitialized = {raw.decode() for raw in sess.run(tf.report_uninitialized_variables())}
    new_variables = [v for v in tf.global_variables() if v.op.name in uninitialized]
    sess.run(tf.variables_initializer(new_variables))

# Materialize the key order once so it stays aligned with `outputs` when zipped later.
names = list(fcn8.keys())
outputs = sess.run([fcn8[n] for n in names],
                   feed_dict={fcn8["image_input"]: x,
                              # keep_prob is presumably the dropout keep probability; 1.0 disables dropout
                              fcn8["keep_prob"]: 1.0})

In [11]:
# Concrete run-time shapes of every fetched tensor, in the same order as `names`.
for name, output in zip(names, outputs):
    print(name, np.shape(output))

image_input (16, 256, 256, 3)
keep_prob ()
layer3_out (16, 32, 32, 256)
layer4_out (16, 16, 16, 512)
layer7_out (16, 8, 8, 4096)
layer3_fcn (16, 32, 32, 2)
layer4_fcn (16, 16, 16, 2)
layer7_fcn (16, 8, 8, 2)
layer7_up (16, 16, 16, 2)
layer4_skip (16, 16, 16, 2)
layer4_up (16, 32, 32, 2)
layer3_skip (16, 32, 32, 2)
heatmap (16, 256, 256, 2)


In [12]:
# Release the resources held by the TensorFlow session; cells that use `sess`
# will fail after this until the session is recreated.
sess.close()