up-to-date with tensorflow/models commits
daviddao committed Jul 7, 2016
1 parent b7f98dd commit 65a67c6
Showing 6 changed files with 197 additions and 124 deletions.
23 changes: 9 additions & 14 deletions README.md
@@ -1,16 +1,14 @@
# Spatial Transformer Network

A more recent and maintained version can be found in [Tensorflow/models](https://github.com/tensorflow/models/tree/master/transformer).

Spatial Transformer Networks [1] allow us to attend specific regions of interest of an image while, at the same time, provide invariance to shapes and sizes of the resulting image patches. This can improve the accuracy of the CNN and discover meaningful discriminative regions of an image.
The Spatial Transformer Network [1] allows the spatial manipulation of data within the network.

<div align="center">
<img width="600px" src="http://i.imgur.com/ExGDVul.png"><br><br>
</div>

### API

A Spatial Transformer Network based on [2] and implemented in Tensorflow.
A Spatial Transformer Network implemented in Tensorflow 0.7 and based on [2].

#### How to use

@@ -19,7 +17,7 @@ A Spatial Transformer Network based on [2] and implemented in Tensorflow.
</div>

```python
transformer(U, theta, downsample_factor=1)
transformer(U, theta, out_size)
```

#### Parameters
@@ -30,13 +28,8 @@ transformer(U, theta, downsample_factor=1)
theta: float
The output of the
localisation network should be [num_batch, 6].
downsample_factor : float
A value of 1 will keep the original size of the image
Values larger than 1 will downsample the image.
Values below 1 will upsample the image
example image: height = 100, width = 200
downsample_factor = 2
output image will then be 50, 100
out_size: tuple of two ints
The size of the output of the network

#### Notes
@@ -52,10 +45,12 @@ theta = tf.Variable(initial_value=identity)
#### Experiments

<div align="center">
<img width="600px" src="./cluttered_mnist.png"><br><br>
<img width="600px" src="http://i.imgur.com/HtCBYk2.png"><br><br>
</div>

We used cluttered MNIST. Left columns are the input images, right columns are the attended parts of the image by an STN.
We used cluttered MNIST. Left column are the input images, right are the attended parts of the image by an STN.

All experiments were run in Tensorflow 0.7.

### References

Binary file removed cluttered_mnist.png
Binary file not shown.
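
Note on the README change above: the `downsample_factor` argument is replaced by an explicit `out_size` tuple. The sketch below shows how a caller might translate an old factor into the new argument. It assumes the old behaviour was floor division of height and width by the factor, which matches the removed README example (100 x 200 with factor 2 giving 50 x 100) but is not confirmed by this diff. The helper name `out_size_from_factor` is hypothetical and not part of the repository.

```python
def out_size_from_factor(height, width, downsample_factor):
    """Hypothetical helper: map the removed downsample_factor to an out_size tuple.

    Assumption: the old code floor-divided each spatial dimension by the factor.
    """
    if downsample_factor <= 0:
        raise ValueError("downsample_factor must be positive")
    return (int(height // downsample_factor), int(width // downsample_factor))


# The removed README example: 100 x 200 image, factor 2.
readme_size = out_size_from_factor(100, 200, 2)

# example.py in this commit: 1200 x 1600 image, previously downsample_factor=2.
example_size = out_size_from_factor(1200, 1600, 2)
```

Under that assumption, the old call `transformer(x, h_fc1, downsample_factor=2)` on a 1200 x 1600 input corresponds to `transformer(x, h_fc1, (600, 800))`, which is the `out_size` that `example.py` uses after this commit.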
74 changes: 45 additions & 29 deletions cluttered_mnist.py
@@ -1,11 +1,23 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import tensorflow as tf
from spatial_transformer import transformer
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
from tf_utils import conv2d, linear, weight_variable, bias_variable, dense_to_one_hot
from tf_utils import weight_variable, bias_variable, dense_to_one_hot

# %% Load the data
# %% Load data
mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')

X_train = mnist_cluttered['X_train']
@@ -23,7 +35,7 @@
# %% Graph representation of our network

# %% Placeholders for 40x40 resolution
x = tf.placeholder(tf.float32, [None, 1600])
x = tf.placeholder(tf.float32, [None, 1600])
y = tf.placeholder(tf.float32, [None, 10])

# %% Since x is currently [batch, height*width], we need to reshape to a
@@ -34,13 +46,15 @@
# dimension should not change size.
x_tensor = tf.reshape(x, [-1, 40, 40, 1])

# %% We'll setup the two-layer localisation network to figure out the parameters for an affine transformation of the input
# %% We'll setup the two-layer localisation network to figure out the
# %% parameters for an affine transformation of the input
# %% Create variables for fully connected layer
W_fc_loc1 = weight_variable([1600, 20])
b_fc_loc1 = bias_variable([20])

W_fc_loc2 = weight_variable([20, 6])
initial = np.array([[1.,0, 0],[0,1.,0]]) # Use identity transformation as starting point
# Use identity transformation as starting point
initial = np.array([[1., 0, 0], [0, 1., 0]])
initial = initial.astype('float32')
initial = initial.flatten()
b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')
@@ -53,8 +67,10 @@
# %% Second layer
h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)

# %% We'll create a spatial transformer module to identify discriminative patches
h_trans = transformer(x_tensor, h_fc_loc2, downsample_factor=1)
# %% We'll create a spatial transformer module to identify discriminative
# %% patches
out_size = (40, 40)
h_trans = transformer(x_tensor, h_fc_loc2, out_size)

# %% We'll setup the first convolutional layer
# Weight matrix is [height x width x input_channels x output_channels]
@@ -103,16 +119,17 @@
# %% And finally our softmax layer:
W_fc2 = weight_variable([n_fc, 10])
b_fc2 = bias_variable([10])
y_pred = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
y_logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

# %% Define loss/eval/training functions
cross_entropy = -tf.reduce_sum(y * tf.log(y_pred))
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(y_logits, y))
opt = tf.train.AdamOptimizer()
optimizer = opt.minimize(cross_entropy)
grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])

# %% Monitor accuracy
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
correct_prediction = tf.equal(tf.argmax(y_logits, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))

# %% We now create a new session to actually perform the initialization the
@@ -126,33 +143,32 @@
n_epochs = 500
train_size = 10000

indices = np.linspace(0,10000 - 1,iter_per_epoch)
indices = np.linspace(0, 10000 - 1, iter_per_epoch)
indices = indices.astype('int')

for epoch_i in range(n_epochs):
for iter_i in range(iter_per_epoch - 1):
batch_xs = X_train[indices[iter_i]:indices[iter_i+1]]
batch_xs = X_train[indices[iter_i]:indices[iter_i+1]]
batch_ys = Y_train[indices[iter_i]:indices[iter_i+1]]

if iter_i % 10 == 0:
loss = sess.run(cross_entropy,
feed_dict={
x: batch_xs,
y: batch_ys,
keep_prob: 1.0
})
feed_dict={
x: batch_xs,
y: batch_ys,
keep_prob: 1.0
})
print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))

sess.run(optimizer, feed_dict={
x: batch_xs, y: batch_ys, keep_prob: 0.8})


print('Accuracy: ' + str(sess.run(accuracy,
feed_dict={
x: X_valid,
y: Y_valid,
keep_prob: 1.0
})))
#theta = sess.run(h_fc_loc2, feed_dict={

print('Accuracy (%d): ' % epoch_i + str(sess.run(accuracy,
feed_dict={
x: X_valid,
y: Y_valid,
keep_prob: 1.0
})))
# theta = sess.run(h_fc_loc2, feed_dict={
# x: batch_xs, keep_prob: 1.0})
#print(theta[0])
# print(theta[0])
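
Note on the loss change in `cluttered_mnist.py`: the commit replaces `-tf.reduce_sum(y * tf.log(y_pred))` with `tf.nn.softmax_cross_entropy_with_logits` wrapped in `tf.reduce_mean`. The NumPy sketch below is not the TensorFlow implementation. It only illustrates two likely consequences: working from logits avoids taking `log(0)` when a probability underflows, and the loss becomes a batch mean instead of a batch sum. The function names are illustrative.

```python
import numpy as np


def naive_cross_entropy(logits, labels):
    """Softmax, then log, summed over the batch (shape of the removed line)."""
    exp = np.exp(logits)
    probs = exp / exp.sum(axis=1, keepdims=True)
    return -np.sum(labels * np.log(probs))


def stable_cross_entropy(logits, labels):
    """Log-softmax via the max-shift trick, averaged over the batch."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return np.mean(-np.sum(labels * log_probs, axis=1))


labels = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
moderate = np.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
# Logits far enough apart that some softmax probabilities underflow to 0.
extreme = np.array([[0.0, -800.0, -800.0], [0.0, -800.0, -800.0]])

with np.errstate(divide="ignore", invalid="ignore"):
    naive_moderate = naive_cross_entropy(moderate, labels)
    naive_extreme = naive_cross_entropy(extreme, labels)

stable_moderate = stable_cross_entropy(moderate, labels)
stable_extreme = stable_cross_entropy(extreme, labels)
```

For moderate logits the two agree up to the batch-size factor (sum versus mean). For the extreme logits the naive version is no longer finite, while the log-softmax version stays finite. Taking `tf.argmax` of `y_logits` instead of `y_pred` leaves the accuracy unchanged, because softmax preserves the ordering of the logits within each row.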
57 changes: 36 additions & 21 deletions example.py
@@ -1,46 +1,61 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from scipy import ndimage
import tensorflow as tf
from spatial_transformer import transformer
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
from tf_utils import conv2d, linear, weight_variable, bias_variable

# Preprocessing
# Create a batch of three images (1600 x 1200)
im = ndimage.imread('./data/cat.jpg')
# %% Create a batch of three images (1600 x 1200)
# %% Image retrieved from:
# %% https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
im = ndimage.imread('cat.jpg')
im = im / 255.
im = im.reshape(1, 1200, 1600, 3)
im = im.astype('float32')
# Simulate batch

# %% Let the output size of the transformer be half the image size.
out_size = (600, 800)

# %% Simulate batch
batch = np.append(im, im, axis=0)
batch = np.append(batch, im, axis=0)

num_batch = 3
x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
x = tf.cast(batch,'float32')

num_batch = 3
x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
x = tf.cast(batch,'float32')
x = tf.cast(batch, 'float32')

# Create localisation network and convolutional layer
# %% Create localisation network and convolutional layer
with tf.variable_scope('spatial_transformer_0'):

# %% Create a fully-connected layer:
n_fc = 6
# %% Create a fully-connected layer with 6 output nodes
n_fc = 6
W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')
initial = np.array([[0.5,0, 0],[0,0.5,0]])

# %% Zoom into the image
initial = np.array([[0.5, 0, 0], [0, 0.5, 0]])
initial = initial.astype('float32')
initial = initial.flatten()

b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
x_flatten = tf.reshape(x,[-1,1200 * 1600 * 3])
#h_fc1 = tf.nn.relu(tf.matmul(x_flatten, W_fc1) + b_fc1)
h_fc1 = tf.matmul(tf.zeros([num_batch ,1200 * 1600 * 3]), W_fc1) + b_fc1
h_trans = transformer(x, h_fc1, downsample_factor=2)
h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), W_fc1) + b_fc1
h_trans = transformer(x, h_fc1, out_size)

# Run session
# %% Run session
sess = tf.Session()
sess.run(tf.initialize_all_variables())
y = sess.run(h_trans, feed_dict={x: batch})

plt.imshow(y[0])
# plt.imshow(y[0])
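
Note on the `# %% Zoom into the image` initialisation in `example.py`: the sketch below shows, in NumPy, what the affine parameters `[[0.5, 0, 0], [0, 0.5, 0]]` do to a sampling grid. It assumes the transformer works in normalised coordinates in [-1, 1], as in the Spatial Transformer Networks formulation. The function name `affine_grid` is illustrative and does not mirror `spatial_transformer.py` line by line. The bilinear sampling step is omitted.

```python
import numpy as np


def affine_grid(theta, out_size):
    """Return source coordinates, shape (2, out_h, out_w), for each output pixel.

    Assumption: coordinates are normalised to [-1, 1] in both x and y.
    """
    out_h, out_w = out_size
    xs, ys = np.meshgrid(np.linspace(-1.0, 1.0, out_w),
                         np.linspace(-1.0, 1.0, out_h))
    # Homogeneous target coordinates, one column per output pixel: (3, H*W).
    target = np.stack([xs.ravel(), ys.ravel(), np.ones(out_h * out_w)])
    # theta is the flattened 2x3 affine matrix, as produced by the localisation net.
    source = np.asarray(theta, dtype=np.float64).reshape(2, 3) @ target
    return source.reshape(2, out_h, out_w)


identity = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]
zoom = [0.5, 0.0, 0.0, 0.0, 0.5, 0.0]

identity_grid = affine_grid(identity, (600, 800))
zoom_grid = affine_grid(zoom, (600, 800))
```

With the identity parameters the output corners sample from (-1, -1) and (1, 1), the full input. With the zoom parameters they sample from (-0.5, -0.5) and (0.5, 0.5), the central half of the input in each dimension, spread over the 600 x 800 output.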