## TF-Slim을 사용한 pre-trained model 사용법
참고: https://github.com/tensorflow/tensorflow/tree/r1.15/tensorflow/contrib/slim/python/slim/nets (`tf.contrib`는 TensorFlow 2.x에서 제거되었으므로 이 노트북은 TensorFlow 1.x 환경이 필요합니다.)

In [1]:
import os
import urllib.request
import tarfile

model_dir = "model/"
vgg_tar_path = "model/vgg_19_2016_08_28.tar.gz"
vgg_url = "http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz"

resnet_url = "http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz"
resnet_tar_path = "model/resnet_v1_101_2016_08_28.tar.gz"


def download_checkpoint(url, tar_path, ckpt_name, label, dest_dir):
    """Download and unpack a TF-Slim checkpoint archive unless it is already extracted.

    Args:
        url: Address of the ``.tar.gz`` archive.
        tar_path: Local path the archive is downloaded to.
        ckpt_name: File name of the checkpoint contained in the archive.
        label: Human-readable model name used in the progress message.
        dest_dir: Directory the archive is extracted into.

    Returns:
        Path of the checkpoint file, ``os.path.join(dest_dir, ckpt_name)``.
    """
    ckpt_path = os.path.join(dest_dir, ckpt_name)
    if not os.path.exists(ckpt_path):
        print("Downloading and extracting {} networks...".format(label))
        urllib.request.urlretrieve(url=url, filename=tar_path)
        # `with` guarantees the archive handle is closed after extraction.
        # NOTE(review): extractall trusts the member paths stored in the downloaded
        # archive; only use it with trusted sources such as the official URLs above.
        with tarfile.open(name=tar_path, mode="r:gz") as archive:
            archive.extractall(dest_dir)
    return ckpt_path


# Create the model directory only if it is missing.
os.makedirs(model_dir, exist_ok=True)

model_path = download_checkpoint(vgg_url, vgg_tar_path, "vgg_19.ckpt", "vgg-19", model_dir)
model_path = download_checkpoint(resnet_url, resnet_tar_path, "resnet_v1_101.ckpt", "resnet-101", model_dir)

In [2]:
%matplotlib inline
import numpy as np
import scipy.misc as misc
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
# Short aliases for the two TF-Slim network modules used below.
vgg, resnet = nets.vgg, nets.resnet_v1

In [3]:
image_path = "image.jpg"
# Spatial size expected by the [None, 224, 224, 3] placeholders built below.
image_size = 224

# Per-channel RGB means subtracted from every pixel; presumably the ImageNet means
# used by TF-Slim's VGG preprocessing (also used for ResNet-v1) — verify if changed.
mean_pixel = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3))

# load images
# NOTE(review): scipy.misc.imread / imresize were removed in newer SciPy releases
# (and need Pillow); this cell requires the old SciPy that ships with TF 1.x setups.
# mode="RGB" forces exactly three channels so grayscale, RGBA or palette images
# do not break the reshape below.
image = misc.imread(image_path, mode="RGB")
# Resizes straight to a square; the aspect ratio is not preserved.
image = misc.imresize(image, (image_size, image_size))
image = image.reshape(1, image_size, image_size, 3).astype(np.float32)
image -= mean_pixel

## VGG-19 모델에서 특징 추출

In [4]:
# Placeholder for a batch of preprocessed 224x224 RGB images.
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])

# Build VGG-19 in inference mode. With the default is_training=True, dropout is
# applied to fc6 before it feeds fc7, so the fc7 features extracted later would
# change randomly from run to run.
with slim.arg_scope(vgg.vgg_arg_scope()):
    logit, model = vgg.vgg_19(inputs, is_training=False)

# `model` maps end-point names to tensors; list them with their static shapes.
for k, v in model.items():
    print(k, v.get_shape())

# Restore the pretrained weights except the fc8 classifier, which is not needed
# for feature extraction; variables missing from the checkpoint are ignored.
init_fn = slim.assign_from_checkpoint_fn("model/vgg_19.ckpt",
    slim.get_variables_to_restore(exclude=["vgg_19/fc8"]),
    ignore_missing_vars=True)

# allow_growth lets TensorFlow claim GPU memory on demand instead of all at once.
sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=sess_config)
# Initialize every variable (including the excluded fc8), then overwrite the
# restorable ones with checkpoint values.
sess.run(tf.global_variables_initializer())
init_fn(sess)

vgg_19/conv1/conv1_1 (?, 224, 224, 64)
vgg_19/conv1/conv1_2 (?, 224, 224, 64)
vgg_19/pool1 (?, 112, 112, 64)
vgg_19/conv2/conv2_1 (?, 112, 112, 128)
vgg_19/conv2/conv2_2 (?, 112, 112, 128)
vgg_19/pool2 (?, 56, 56, 128)
vgg_19/conv3/conv3_1 (?, 56, 56, 256)
vgg_19/conv3/conv3_2 (?, 56, 56, 256)
vgg_19/conv3/conv3_3 (?, 56, 56, 256)
vgg_19/conv3/conv3_4 (?, 56, 56, 256)
vgg_19/pool3 (?, 28, 28, 256)
vgg_19/conv4/conv4_1 (?, 28, 28, 512)
vgg_19/conv4/conv4_2 (?, 28, 28, 512)
vgg_19/conv4/conv4_3 (?, 28, 28, 512)
vgg_19/conv4/conv4_4 (?, 28, 28, 512)
vgg_19/pool4 (?, 14, 14, 512)
vgg_19/conv5/conv5_1 (?, 14, 14, 512)
vgg_19/conv5/conv5_2 (?, 14, 14, 512)
vgg_19/conv5/conv5_3 (?, 14, 14, 512)
vgg_19/conv5/conv5_4 (?, 14, 14, 512)
vgg_19/pool5 (?, 7, 7, 512)
vgg_19/fc6 (?, 1, 1, 4096)
vgg_19/fc7 (?, 1, 1, 4096)
vgg_19/fc8 (?, 1000)
INFO:tensorflow:Restoring parameters from model/vgg_19.ckpt


In [5]:
# Feed the preprocessed image through VGG-19 and drop the singleton axes of the
# fc7 activations so a single image yields a flat feature vector.
fc7 = sess.run(model["vgg_19/fc7"], feed_dict={inputs: image}).squeeze()
print(fc7.shape)

(4096,)


In [6]:
# Close the session first, then reset the default graph so the ResNet model below
# is built in a fresh graph (TensorFlow documents tf.reset_default_graph() as
# undefined behavior while a session is still active).
sess.close()
tf.reset_default_graph()

## ResNet-101 모델에서 특징 추출

In [7]:
# Placeholder for a batch of preprocessed 224x224 RGB images.
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])

# Build ResNet-101 in inference mode. With the default is_training=True, batch
# normalization uses the statistics of the current batch (here a single image)
# instead of the learned moving averages, which corrupts the extracted features.
with slim.arg_scope(resnet.resnet_arg_scope()):
    net, model = resnet.resnet_v1_101(inputs, is_training=False)

# `model` maps end-point names to tensors; list them with their static shapes.
for k, v in model.items():
    print(k, v.get_shape())
print(net)

# Restore the pretrained weights except the classification logits; variables
# missing from the checkpoint are ignored.
# NOTE(review): no num_classes is passed, so a logits layer is presumably not
# built at all and the exclude acts only as a safeguard — confirm.
init_fn = slim.assign_from_checkpoint_fn("model/resnet_v1_101.ckpt",
    slim.get_variables_to_restore(exclude=["resnet_v1_101/logits"]),
    ignore_missing_vars=True)

# allow_growth lets TensorFlow claim GPU memory on demand instead of all at once.
sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=sess_config)
# Initialize every variable, then overwrite the restorable ones with checkpoint values.
sess.run(tf.global_variables_initializer())
init_fn(sess)

resnet_v1_101/conv1 (?, 112, 112, 64)
resnet_v1_101/block1/unit_1/bottleneck_v1/shortcut (?, 56, 56, 256)
resnet_v1_101/block1/unit_1/bottleneck_v1/conv1 (?, 56, 56, 64)
resnet_v1_101/block1/unit_1/bottleneck_v1/conv2 (?, 56, 56, 64)
resnet_v1_101/block1/unit_1/bottleneck_v1/conv3 (?, 56, 56, 256)
resnet_v1_101/block1/unit_1/bottleneck_v1 (?, 56, 56, 256)
resnet_v1_101/block1/unit_2/bottleneck_v1/conv1 (?, 56, 56, 64)
resnet_v1_101/block1/unit_2/bottleneck_v1/conv2 (?, 56, 56, 64)
resnet_v1_101/block1/unit_2/bottleneck_v1/conv3 (?, 56, 56, 256)
resnet_v1_101/block1/unit_2/bottleneck_v1 (?, 56, 56, 256)
resnet_v1_101/block1/unit_3/bottleneck_v1/conv1 (?, 56, 56, 64)
resnet_v1_101/block1/unit_3/bottleneck_v1/conv2 (?, 28, 28, 64)
resnet_v1_101/block1/unit_3/bottleneck_v1/conv3 (?, 28, 28, 256)
resnet_v1_101/block1/unit_3/bottleneck_v1 (?, 28, 28, 256)
resnet_v1_101/block1 (?, 28, 28, 256)
resnet_v1_101/block2/unit_1/bottleneck_v1/shortcut (?, 28, 28, 512)
resnet_v1_101/block2/unit_1/bott

INFO:tensorflow:Restoring parameters from model/resnet_v1_101.ckpt


In [8]:
# Feed the preprocessed image through ResNet-101 and drop the singleton axes of
# the pooled output so a single image yields a flat feature vector.
pool5 = sess.run(net, feed_dict={inputs: image}).squeeze()
print(pool5.shape)

(2048,)
