Skip to content

Commit

Permalink
classic tensorflow folder
Browse files Browse the repository at this point in the history
  • Loading branch information
matteo-dunnhofer committed Nov 20, 2017
1 parent 7841e19 commit 4924f2a
Show file tree
Hide file tree
Showing 9 changed files with 899 additions and 64 deletions.
68 changes: 4 additions & 64 deletions README.md
@@ -1,72 +1,12 @@
# AlexNet training on ImageNet LSVRC 2012

As part of a university project I implemented AlexNet model and its training and testing procedures on the ILSVRC 2012 dataset, all using TensorFlow.
![alt text](tsne-full.jpg)

This repository contains an implementation of AlexNet convolutional neural network and its training and testing procedures on the ILSVRC 2012 dataset, all using TensorFlow.

Folder ```tf``` contains code in the "classic TensorFlow" framework whereas code in the ```tf_eager``` directory has been developed with TensorFlow's new impearative style, TensorFlow eager.

### Training
To train AlexNet just run the command:
```shell
python train.py option
```
with options ```-scratch``` to train the model from scratch or ```-resume``` to resume the training from a checkpoint.

I trained AlexNet with the hyperparameters set in the script for ~46000 steps (roughly 46 epochs), decreasing the learning rate two times (by a factor of 10) when the loss became stagnant. The training image were preprocessed subtracting the training-set mean for each channel. No data-augmentation was performed (future improvement). The training was carried on a NVIDIA Tesla K40c (thanks to [Avires Lab](https://https://avires.dimi.uniud.it)) and took a few days.



### Testing
To evaluate the accuracy of the trained model I used the ILSVRC validation set (no test set is available). Run simply:
```shell
python train.py
```
This evaluates *Top-1* and *Top-k* (you can set *k* inside the script) accuracy and error-rate.
Inside the script you can also play with the ```K_CROPS``` parameter to see how the accuracy change when the predictions are averaged through different random crops of the images.

I tested the trained model on the ILSVRC validation set consisting of 50000 images. I obtained a *Top-1* accuracy of **57.31%** and a *Top-5* accuracy of **80.31%**, averaging the predictions on 5 random crops. With more epochs and some tweaks they can be improved of a few more points. I hope to do so in the next weeks.



### Classify an image
To predict the classes of an input image run:
```shell
python classify.py image
```
where ```image``` is the path of the image you want to classify.

e. g. that command on the ```lussari.jpg``` image
![alt text](lussari.jpg)
gives the output:
```shell
AlexNet saw:
alp - score: 0.575796604156
church, church building - score: 0.0516746938229
valley, vale - score: 0.0432425364852
castle - score: 0.0284509658813
monastery - score: 0.0265731271356
```
Again, you can change the number of random crops produced and the *Top-k* prediction retrieved (here are both `5`).



### Notes
```train.py``` and ```test.py``` scripts assume that ImageNet dataset folder is structured in this way:
```
ILSVRC2012
ILSVRC2012_img_train
n01440764
n01443537
n01484850
...
ILSVRC2012_img_val
ILSVRC2012_val_00000001.JPEG
ILSVRC2012_val_00000002.JPEG
...
data
meta.mat
ILSVRC2012_validation_ground_truth.txt
```

The two implementations are independent and refer to the READMEs inside the folders for specific instruction on how to train and to test.


#### References
Expand Down
74 changes: 74 additions & 0 deletions tf/README.md
@@ -0,0 +1,74 @@
# AlexNet training on ImageNet LSVRC 2012

As part of a university project I implemented AlexNet model and its training and testing procedures on the ILSVRC 2012 dataset, all using TensorFlow.



### Training
To train AlexNet just run the command:
```shell
python train.py option
```
with options ```-scratch``` to train the model from scratch or ```-resume``` to resume the training from a checkpoint.

I trained AlexNet with the hyperparameters set in the script for ~46000 steps (roughly 46 epochs), decreasing the learning rate two times (by a factor of 10) when the loss became stagnant. The training image were preprocessed subtracting the training-set mean for each channel. No data-augmentation was performed (future improvement). The training was carried on a NVIDIA Tesla K40c (thanks to [Avires Lab](https://https://avires.dimi.uniud.it)) and took a few days.



### Testing
To evaluate the accuracy of the trained model I used the ILSVRC validation set (no test set is available). Run simply:
```shell
python train.py
```
This evaluates *Top-1* and *Top-k* (you can set *k* inside the script) accuracy and error-rate.
Inside the script you can also play with the ```K_CROPS``` parameter to see how the accuracy change when the predictions are averaged through different random crops of the images.

I tested the trained model on the ILSVRC validation set consisting of 50000 images. I obtained a *Top-1* accuracy of **57.31%** and a *Top-5* accuracy of **80.31%**, averaging the predictions on 5 random crops. With more epochs and some tweaks they can be improved of a few more points. I hope to do so in the next weeks.



### Classify an image
To predict the classes of an input image run:
```shell
python classify.py image
```
where ```image``` is the path of the image you want to classify.

e. g. that command on the ```lussari.jpg``` image
![alt text](lussari.jpg)
gives the output:
```shell
AlexNet saw:
alp - score: 0.575796604156
church, church building - score: 0.0516746938229
valley, vale - score: 0.0432425364852
castle - score: 0.0284509658813
monastery - score: 0.0265731271356
```
Again, you can change the number of random crops produced and the *Top-k* prediction retrieved (here are both `5`).



### Notes
```train.py``` and ```test.py``` scripts assume that ImageNet dataset folder is structured in this way:
```
ILSVRC2012
ILSVRC2012_img_train
n01440764
n01443537
n01484850
...
ILSVRC2012_img_val
ILSVRC2012_val_00000001.JPEG
ILSVRC2012_val_00000002.JPEG
...
data
meta.mat
ILSVRC2012_validation_ground_truth.txt
```



#### References
+ *Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton.* ImageNet Classification with Deep Convolutional Neural Networks. Advances in Neural Inforamtion Processing Systems 25, 2012.
+ *Olga Russakovsky°, Jia Deng°, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei.* (° = equal contribution) ImageNet Large Scale Visual Recognition Challenge. IJCV, 2015
73 changes: 73 additions & 0 deletions tf/classify.py
@@ -0,0 +1,73 @@
"""
Written by Matteo Dunnhofer - 2017
Classify an input image
"""
import sys
import os.path
import tensorflow as tf
import train_util as tu
from models import alexnet
import numpy as np

def classify(
image,
top_k,
k_patches,
ckpt_path,
imagenet_path):
""" Procedure to classify the image given through the command line
Args:
image: path to the image to classify
top_k: integer representing the number of predictions with highest probability
to retrieve
k_patches: number of crops taken from an image and to input to the model
ckpt_path: path to model's tensorflow checkpoint
imagenet_path: path to ILSRVC12 ImageNet folder containing train images,
validation images, annotations and metadata file
"""
wnids, words = tu.load_imagenet_meta(os.path.join(imagenet_path, 'data/meta.mat'))

# taking a few crops from an image
image_patches = tu.read_k_patches(image, k_patches)

x = tf.placeholder(tf.float32, [None, 224, 224, 3])

_, pred = alexnet.classifier(x, dropout=1.0)

# calculate the average precision through the crops
avg_prediction = tf.div(tf.reduce_sum(pred, 0), k_patches)

# retrieve top 5 scores
scores, indexes = tf.nn.top_k(avg_prediction, k=top_k)

saver = tf.train.Saver()

with tf.Session(config=tf.ConfigProto()) as sess:
saver.restore(sess, os.path.join(ckpt_path, 'alexnet-cnn.ckpt'))

s, i = sess.run([scores, indexes], feed_dict={x: image_patches})
s, i = np.squeeze(s), np.squeeze(i)

print('AlexNet saw:')
for idx in range(top_k):
print ('{} - score: {}'.format(words[i[idx]], s[idx]))


if __name__ == '__main__':
TOP_K = 5
K_CROPS = 5
IMAGENET_PATH = 'ILSVRC2012'
CKPT_PATH = 'ckpt-alexnet'

image_path = sys.argv[1]

classify(
image_path,
TOP_K,
K_CROPS,
CKPT_PATH,
IMAGENET_PATH)

Binary file added tf/lussari.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file added tf/models/__init__.py
Empty file.
108 changes: 108 additions & 0 deletions tf/models/alexnet.py
@@ -0,0 +1,108 @@
"""
Written by Matteo Dunnhofer - 2017
Definition of AlexNet architecture
"""

import tensorflow as tf
import train_util as tu

def cnn(x):
"""
AlexNet convolutional layers definition
Args:
x: tensor of shape [batch_size, width, height, channels]
Returns:
pool5: tensor with all convolutions, pooling and lrn operations applied
"""
with tf.name_scope('alexnet_cnn') as scope:
with tf.name_scope('alexnet_cnn_conv1') as inner_scope:
wcnn1 = tu.weight([11, 11, 3, 96], name='wcnn1')
bcnn1 = tu.bias(0.0, [96], name='bcnn1')
conv1 = tf.add(tu.conv2d(x, wcnn1, stride=(4, 4), padding='SAME'), bcnn1)
#conv1 = tu.batch_norm(conv1)
conv1 = tu.relu(conv1)
norm1 = tu.lrn(conv1, depth_radius=2, bias=1.0, alpha=2e-05, beta=0.75)
pool1 = tu.max_pool2d(norm1, kernel=[1, 3, 3, 1], stride=[1, 2, 2, 1], padding='VALID')

with tf.name_scope('alexnet_cnn_conv2') as inner_scope:
wcnn2 = tu.weight([5, 5, 96, 256], name='wcnn2')
bcnn2 = tu.bias(1.0, [256], name='bcnn2')
conv2 = tf.add(tu.conv2d(pool1, wcnn2, stride=(1, 1), padding='SAME'), bcnn2)
#conv2 = tu.batch_norm(conv2)
conv2 = tu.relu(conv2)
norm2 = tu.lrn(conv2, depth_radius=2, bias=1.0, alpha=2e-05, beta=0.75)
pool2 = tu.max_pool2d(norm2, kernel=[1, 3, 3, 1], stride=[1, 2, 2, 1], padding='VALID')

with tf.name_scope('alexnet_cnn_conv3') as inner_scope:
wcnn3 = tu.weight([3, 3, 256, 384], name='wcnn3')
bcnn3 = tu.bias(0.0, [384], name='bcnn3')
conv3 = tf.add(tu.conv2d(pool2, wcnn3, stride=(1, 1), padding='SAME'), bcnn3)
#conv3 = tu.batch_norm(conv3)
conv3 = tu.relu(conv3)

with tf.name_scope('alexnet_cnn_conv4') as inner_scope:
wcnn4 = tu.weight([3, 3, 384, 384], name='wcnn4')
bcnn4 = tu.bias(1.0, [384], name='bcnn4')
conv4 = tf.add(tu.conv2d(conv3, wcnn4, stride=(1, 1), padding='SAME'), bcnn4)
#conv4 = tu.batch_norm(conv4)
conv4 = tu.relu(conv4)

with tf.name_scope('alexnet_cnn_conv5') as inner_scope:
wcnn5 = tu.weight([3, 3, 384, 256], name='wcnn5')
bcnn5 = tu.bias(1.0, [256], name='bcnn5')
conv5 = tf.add(tu.conv2d(conv4, wcnn5, stride=(1, 1), padding='SAME'), bcnn5)
#conv5 = tu.batch_norm(conv5)
conv5 = tu.relu(conv5)
pool5 = tu.max_pool2d(conv5, kernel=[1, 3, 3, 1], stride=[1, 2, 2, 1], padding='VALID')

return pool5

def classifier(x, dropout):
"""
AlexNet fully connected layers definition
Args:
x: tensor of shape [batch_size, width, height, channels]
dropout: probability of non dropping out units
Returns:
fc3: 1000 linear tensor taken just before applying the softmax operation
it is needed to feed it to tf.softmax_cross_entropy_with_logits()
softmax: 1000 linear tensor representing the output probabilities of the image to classify
"""
pool5 = cnn(x)

dim = pool5.get_shape().as_list()
flat_dim = dim[1] * dim[2] * dim[3] # 6 * 6 * 256
flat = tf.reshape(pool5, [-1, flat_dim])

with tf.name_scope('alexnet_classifier') as scope:
with tf.name_scope('alexnet_classifier_fc1') as inner_scope:
wfc1 = tu.weight([flat_dim, 4096], name='wfc1')
bfc1 = tu.bias(0.0, [4096], name='bfc1')
fc1 = tf.add(tf.matmul(flat, wfc1), bfc1)
#fc1 = tu.batch_norm(fc1)
fc1 = tu.relu(fc1)
fc1 = tf.nn.dropout(fc1, dropout)

with tf.name_scope('alexnet_classifier_fc2') as inner_scope:
wfc2 = tu.weight([4096, 4096], name='wfc2')
bfc2 = tu.bias(0.0, [4096], name='bfc2')
fc2 = tf.add(tf.matmul(fc1, wfc2), bfc2)
#fc2 = tu.batch_norm(fc2)
fc2 = tu.relu(fc2)
fc2 = tf.nn.dropout(fc2, dropout)

with tf.name_scope('alexnet_classifier_output') as inner_scope:
wfc3 = tu.weight([4096, 1000], name='wfc3')
bfc3 = tu.bias(0.0, [1000], name='bfc3')
fc3 = tf.add(tf.matmul(fc2, wfc3), bfc3)
softmax = tf.nn.softmax(fc3)

return fc3, softmax

0 comments on commit 4924f2a

Please sign in to comment.