
Validating ResNet50 #8672

Closed · SmileyScientist opened this issue Dec 3, 2017 · 21 comments

Comments

@SmileyScientist

SmileyScientist commented Dec 3, 2017

I am trying to validate the ResNet50 model which is supposed to give 92.9% Top5 Accuracy and 75.9% Top1 Accuracy (as mentioned: https://keras.io/applications/)

But I am getting only 88.3% Top5 Accuracy and 68.094% Top1 Accuracy.

I have no clue where I am going wrong. Can somebody please help?

The code is as follows:
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np
import glob

# Functions for top-1 and top-5 accuracies
def gettop1acc(predictions, truth):
    counter = 0
    for i in range(len(predictions)):
        if truth[i] == predictions[i]:
            counter = counter + 1
    return counter * 100 / (np.size(predictions, axis=0))

def gettop5acc(predictions5, truth):
    counter = 0
    for i in range(np.size(predictions5, axis=0)):
        if truth[i] in predictions5[i][:]:
            counter = counter + 1
    return counter * 100 / (np.size(predictions5, axis=0))

# Model
rn = ResNet50(weights='imagenet')

#%%
# Read data
path = '/hdd/rmk6217/ImageNet/ILSVRC2015/Data/CLS-LOC/val/'  # path to ImageNet validation set images
X_names = np.array(glob.glob(path + "*.JPEG"))
X_names.sort()

noofinstances = len(X_names)

batchsize = 500

# Batch-sized X
X = np.zeros((batchsize, 224, 224, 3), dtype=np.float32)

# Predictions for all images in the validation set
preds = np.zeros((noofinstances, 1000), dtype=np.float32)
for i in range(0, int(noofinstances / batchsize)):
    for j in range(0, batchsize):
        img_path = X_names[j + (i * batchsize)]
        img = image.load_img(img_path, target_size=(224, 224))
        X[j] = image.img_to_array(img)
    X = preprocess_input(X)
    preds[i * batchsize:i * batchsize + batchsize, :] = rn.predict(X)

del X

# Decode to get class, description and probability
p1all = decode_predictions(preds, top=1)
p5all = decode_predictions(preds, top=5)

# Get class for top 1
p1d = []
for i in range(noofinstances):
    p1d.append(p1all[i][0][0])

# Get classes for top 5
p5d = []
for i in range(noofinstances):
    temp = []
    for j in range(5):
        temp.append(p5all[i][j][0])
    p5d.append(temp)

#%%
path = '/hdd/rmk6217/ImageNet/ILSVRC2015/devkit/data/'  # path to dev-kit

# Load ground-truth labels
load = open(path + 'ILSVRC2015_clsloc_validation_ground_truth.txt', 'r')
full = load.readlines()
gt = []
for i in range(len(full)):
    temp = full[i].split()
    gt.append(int(temp[0]))

# Load mapping between labels and classes
mapload = open(path + 'map_clsloc.txt', 'r')
mapfull = mapload.readlines()
mapped = []
for i in range(len(mapfull)):
    temp = mapfull[i].split()
    mapped.append(temp[0])

# Map the ground truth to class IDs
gtactual = []
for i in range(len(gt)):
    temp = mapped[gt[i] - 1]
    gtactual.append(temp)

# Get top-1 and top-5 accuracy
print(gettop1acc(p1d, gtactual))
print(gettop5acc(p5d, gtactual))

@fchollet
Member

fchollet commented Dec 3, 2017

Possible issues:

  • different image preprocessing
  • different validation set

Also try to see if you are getting the same results with the TF and Theano backends.

I believe the weights in the model came from He et al. (ported from Caffe), so they should achieve the claimed accuracy.

@SmileyScientist
Author

If I use the pixel-wise mean instead of the pre-processing function given (https://keras.io/applications/), the top-5 accuracy goes up to 89.88%. It doesn't even touch 90%.

The validation set for classification has been the same since 2012. The images don't change, so it's highly unlikely for the model to perform differently on the same set of images. [Also, coincidentally, I am using the same set as the paper, i.e. 2015, so it actually shouldn't vary.]

I did try the two backends. The result is exactly the same.

If you can think of any other thing, please do let me know.
Quick question: did you not train ResNet50 for Keras?

@guoxiaolu

In fact, I have run into the same problem. I have tried many validation methods, and the top-n accuracy I get is even lower than yours, which confuses me. To my knowledge, the pretrained model weights are ported from other frameworks such as Caffe and TensorFlow. If you find a validation or training method that gets close to the reported results, please let me know.

@kushalkafle

I also have the same problem as @blarkj.

@fchollet: Since you did not actually read @blarkj's code, which already clearly shows that she used the preprocess_input designated for the ResNet50 class (defined in imagenet_utils), I see no point in putting mine up. But the key points are as follows:

  1. Used the 'imagenet' weights that Keras provides.
  2. Used the aptly named "preprocess_input" function for the ResNet50 class. If this is not the preprocessing to use, I don't know why it is named as such.
  3. Used the same code that is available under "Usage examples for image classification models", which literally says "Classify ImageNet classes with ResNet50".
  4. I don't think the validation images have changed. I used the 2012 version, if that matters.
  5. The documentation for individual models claims 92.9% for ResNet50, which is different from the number reported in the original paper (94.75%, 10-crop) as well as from the number reported on the Caffe weights page (92.2%, single crop). This strongly suggests that the number found on the Keras page is derived from some form of independent test done by someone, does it not?

We all just want to know what combination of preprocessing, cropping, and validation set gives the same number as claimed on the Keras website. Could you provide, or point towards, a reproducible script that gets the accuracy reported on the Keras website, or perhaps the relevant pull request that does so? I could not find anything about the ImageNet accuracy anywhere. If ResNet50 is meant to be used as an off-the-shelf feature extractor, I think it is paramount that there exists a combination of preprocessing + code that can achieve the reported accuracy.

@acobus

acobus commented May 28, 2018

Let's see if we can bring this thread back to life...

I am facing the same problem as @blarkj and @kushalkafle.
I am certain I am using the correct ImageNet validation set and the "preprocess_input" method provided by Keras... it results in exactly the same accuracy: 88.3% Top5 and 68.1% Top1.

Also, I can't reproduce the accuracies of the VGG16 and VGG19 nets provided in Keras (both come out significantly lower)!

@fchollet, there must be some script with which you confirmed the results on the Keras website... it would be really helpful if you could provide this. Thanks in advance!

@sehgal-abhishek

What the Keras Applications page shows is that you can load the image resized to (224, 224) RGB, preprocess the input (i.e. subtract the channel means and flip the image to BGR), and then predict, and you should get a pretty accurate classification of the image.

However, if you're evaluating against the accuracy reported in the paper, the preprocessing steps for validation are a bit different:

  • Resize the image so that the smaller edge becomes 256, i.e. if the size of the image is (500, 375), you resize it to (341, 256).

  • Now crop the center square of dimensions (224, 224) from the image.

  • Flip the channel order to BGR if the image was loaded as RGB.

  • Now subtract the mean image that was provided in Kaiming He's GitHub repo. It is in .proto format, so if you can convert it to .npy that would be excellent; otherwise you can subtract the scalar channel means. As the image is BGR, subtract (103.939, 116.779, 123.68) from the respective channels.

Now when you validate the model, your Top-5 Accuracy should be around 92% and Top-1 around 74-75%. (These steps are sketched in the code below.)
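A minimal sketch of these validation preprocessing steps, assuming PIL and NumPy are available; the function name and structure are illustrative and not taken from Keras or from anyone's script in this thread:

import numpy as np
from PIL import Image

def load_and_preprocess(img_path):
    img = Image.open(img_path).convert('RGB')

    # Resize so that the shorter edge becomes 256, preserving the aspect ratio.
    w, h = img.size
    scale = 256.0 / min(w, h)
    img = img.resize((int(round(w * scale)), int(round(h * scale))), Image.BILINEAR)

    # Central (224, 224) crop.
    w, h = img.size
    left, top = (w - 224) // 2, (h - 224) // 2
    img = img.crop((left, top, left + 224, top + 224))

    x = np.asarray(img, dtype=np.float32)

    # Flip the channel order RGB -> BGR and subtract the per-channel means
    # (this is what Keras' 'caffe'-style preprocess_input does internally).
    x = x[..., ::-1]
    x = x - np.array([103.939, 116.779, 123.68], dtype=np.float32)
    return x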

@guoxiaolu

@sehgal-abhishek, thank you for your response; this problem has confused me for a long time.

@RobinC94

@sehgal-abhishek Hello! I have the same problem when evaluating the pre-trained models provided by Keras. I have tried your resizing and cropping procedure before the preprocess_input function. The accuracy of ResNet rises to top-1: 72.74% and top-5: 90.81%, still lower than what Keras reports; the same holds for other models such as VGG and MobileNet.
Are there any other operations I have missed that would improve the performance of these pre-trained weights?
Thanks!

@ArashAkbarinia

I'm experiencing the same issue for all 13 pretrained networks and their respective accuracies (https://keras.io/applications/).

I do believe that Keras should elaborate in greater detail on both the training and evaluation procedures in order to allow its users to replicate those results faithfully and make meaningful comparisons.

I think this issue of not obtaining the exact validation accuracy is very important, but in my personal opinion it's even more important to know the exact training procedure (including all the parameters set for the optimiser, the number of epochs, image generation, etc.) in order to obtain "identical" weights (I understand that, due to the random nature of some of those procedures, one can't obtain identical weights, but close enough).

This way one can really compare the effect of changing one parameter on the achieved accuracies. Otherwise, it will always be close to impossible to disentangle the importance of "architecture" versus "training procedure".

Thanks a lot :).

@gabrieldemarmiesse
Contributor

In the documentation, there is a section explaining for every model how the weights were obtained. What do you think is missing?

@gabrieldemarmiesse
Contributor

My bad, some are missing. You can find all the information here in the corresponding files: https://github.com/keras-team/keras-applications/tree/master/keras_applications

@ArashAkbarinia

@gabrieldemarmiesse thank you for your prompt response.

Let's take ResNet50 as an example. In the documentation (https://keras.io/applications/#resnet50) it states: "These weights are ported from the ones released by Kaiming He under the MIT license."

In Kaiming He's GitHub repo (https://github.com/KaimingHe/deep-residual-networks) there are two tables with different top-1 and top-5 accuracies, neither of which corresponds exactly to those reported by Keras. So the confusion is already starting.

If we look at Kaiming He's article (https://arxiv.org/pdf/1512.03385.pdf), section 3.4, it states "The standard color augmentation in [21] is used". At least to me, it's not clear what that standard augmentation is, whether it is augmentation, generation, etc. With respect to the optimiser, it doesn't state whether SGD used Nesterov momentum or not.

In the source code itself (keras_applications/resnet50.py) there is no further information.

Naturally, you could object that I haven't investigated enough, which is fair. But the point I'd like to make is that a lot of effort has gone into collecting these 13 pretrained networks, and surely they were rigorously tested before publishing. Why not release the source code that exactly replicates those results? This would facilitate the work of many researchers and avoid confusion.

This was an example, since the thread is about ResNet50; the same is true for most of the others. Another quick example is mobilenet_v2. By default the code points to this model (https://github.com/JonathanCMitchell/mobilenet_v2_keras/releases/download/v1.1/). However, in the accuracy table at the top of keras_applications/mobilenet_v2.py, there is no "v1.1".

I think these little things could make comparisons very difficult.

Thanks a lot :).

@calebrob6

@ArashAkbarinia I just wrote up the steps I have found for reproducing the top-1/top-5 accuracy reported in the Keras documentation here, http://calebrob.com/ml/imagenet/ilsvrc2012/2018/10/22/imagenet-benchmarking.html. The code is here, https://github.com/calebrob6/imagenet_validation.

Let me know if this helps or if I can make the steps more clear!

Best.

@ArashAkbarinia

Thanks @calebrob6. I've tried the Keras "flow_from_directory" pipeline with the cropping @sehgal-abhishek mentioned (similar to your code, but with linear interpolation rather than cubic), and I could reproduce results closer to what the Keras documentation states. However, for almost all models, as @RobinC94 mentioned, there is a tendency towards a 1-2% mismatch.

@calebrob6

The difference between using linear and cubic interpolation is not large (I'm observing a 0.1% mismatch in top-5 accuracy). The most important step seems to be the resizing method (the way @sehgal-abhishek described it):

  • Resize the shorter side of each image to 256
  • Resize the longer side to maintain the aspect ratio
  • Central 224x224 crop

E.g. if we instead resize both edges of the validation images to 256x256 and then take the central crop, we get a 1-2% mismatch. (The two strategies are contrasted in the sketch below.)
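To make the comparison concrete, here is a small sketch of the two resize strategies, assuming PIL; the helper names are hypothetical and not taken from @calebrob6's code:

from PIL import Image

def resize_keep_aspect(img, shorter=256):
    # Shorter side -> 256, longer side scaled to preserve the aspect ratio.
    w, h = img.size
    scale = shorter / min(w, h)
    return img.resize((round(w * scale), round(h * scale)), Image.BILINEAR)

def resize_square(img, size=256):
    # Both sides forced to 256; distorts the aspect ratio and costs ~1-2% accuracy.
    return img.resize((size, size), Image.BILINEAR)

def center_crop(img, size=224):
    w, h = img.size
    left, top = (w - size) // 2, (h - size) // 2
    return img.crop((left, top, left + size, top + size))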

@sehgal-abhishek

The pre-processing provided in my answer is only for ResNet, not for the other models in the applications list. I know Inception and MobileNet follow the same pre-processing as each other, since they were both developed at Google. But each of the models has its own pre-processing technique, and Keras provides a built-in pre-processing module for each, though the documentation around this is a bit sparse.

If you use a generic ImageNet validation pipeline, it might not give you the exact accuracy reported in the paper. Also, in the papers they use an ensemble of models and take multiple crops instead of the single center crop that I used, which boosts the accuracy by 1-2%.
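As a small illustration of that point (not from this thread), each Keras application module exports its own preprocess_input: resnet50 and the VGGs use 'caffe'-style BGR conversion plus mean subtraction, while inception_v3 and mobilenet use 'tf'-style scaling to [-1, 1]:

import numpy as np
from keras.applications.resnet50 import preprocess_input as resnet_preprocess
from keras.applications.inception_v3 import preprocess_input as inception_preprocess

x = np.random.uniform(0, 255, (1, 224, 224, 3)).astype(np.float32)

print(resnet_preprocess(x.copy())[0, 0, 0])     # BGR order, zero-centred by the ImageNet channel means
print(inception_preprocess(x.copy())[0, 0, 0])  # scaled to the range [-1, 1]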

@mostafaelhoushi

mostafaelhoushi commented Apr 30, 2019

What the Keras Applications page shows is that you can load the image resized to (224, 224) RGB, preprocess the input (i.e. subtract the channel means and flip the image to BGR), and then predict, and you should get a pretty accurate classification of the image.

However, if you're evaluating against the accuracy reported in the paper, the preprocessing steps for validation are a bit different:

  • Resize the image so that the smaller edge becomes 256, i.e. if the size of the image is (500, 375), you resize it to (341, 256).
  • Now crop the center square of dimensions (224, 224) from the image.
  • Flip the channel order to BGR if the image was loaded as RGB.
  • Now subtract the mean image that was provided in Kaiming He's GitHub repo. It is in .proto format, so if you can convert it to .npy that would be excellent; otherwise you can subtract the scalar channel means. As the image is BGR, subtract (103.939, 116.779, 123.68) from the respective channels.

Now when you validate the model, your Top-5 Accuracy should be around 92% and Top-1 around 74-75%.

Thanks @sehgal-abhishek for your detailed answer. I have tried the code below instead of calling Keras' ResNet preprocess_input function, but for some reason the accuracy actually went down to a Top-5 accuracy of 78.3% and a Top-1 accuracy of 59%:

        img_shape = image.shape
        # Resize the image by converting the smaller edge to 256
        smaller_edge = min(img_shape[0], img_shape[1])
        ratio = 256/smaller_edge
        img_shape = (int(img_shape[0] * ratio), int(img_shape[1] * ratio), img_shape[2])
        image = cv.resize(image, img_shape[0:2] )

        # Now crop the center square of dimensions (224, 224) from the image
        h, w, c = img_shape
        centre = (h//2, w//2)
        image = cv.getRectSubPix(image, (224,224), centre)

        # Flip the spatial dimensions to BGR if the image was loaded as RGB.
        image = np.flip(image, axis=-1)

        # Now subtract the spatial mean image that was provided in Kaiming He's github repo
        image = image - [103.939, 116.779, 123.68]

@abhishek-sehgal

Hi @mostafaelhoushi, your code looks perfectly okay to me. The only thing I can say is that maybe you read the image as BGR already, in which case you may not have needed to flip the channel order.

Also, I'm not fully familiar with the workings of cv.getRectSubPix.

@calebrob6 has provided a great ImageNet validation notebook. You can compare with his code and see:
https://github.com/calebrob6/imagenet_validation/blob/master/1.%20Preprocess%20ImageNet%20validation%20set.ipynb

Let me know if it works out.

@lynn901213

Many thanks to @calebrob6! Now I get almost the same accuracy as reported by Keras.
Has anyone managed to get the same validation accuracy for ResNet50V2? I noticed that its preprocessing is "tf" mode, but the aspect-preserving rescale and crop still increase the accuracy. However, the accuracy I get for the ResNetV2 models is always around 5% lower than the reported accuracy. Has anyone tried that?

@dribnet
Contributor

dribnet commented Oct 31, 2020

I'm also noticing a difference between the ILSVRC2012 validation accuracy I measure across models and what the current documentation reports. This is when using the sample inference code given there:

from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np

model = ResNet50(weights='imagenet')

img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

preds = model.predict(x)
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print('Predicted:', decode_predictions(preds, top=3)[0])
# Predicted: [(u'n02504013', u'Indian_elephant', 0.82658225), (u'n01871265', u'tusker', 0.1122357), (u'n02504458', u'African_elephant', 0.061040461)]

I've verified that it is not the channel order, etc. Here are some of the differences I'm seeing.

| Model       | Reported Top-1 Accuracy | Actual Top-1 Accuracy | Top-1 diff |
| ----------- | ----------------------- | --------------------- | ---------- |
| DenseNet169 | 76.2%                   | 73.7%                 | -2.5%      |
| DenseNet201 | 77.3%                   | 74.6%                 | -2.7%      |
| InceptionV3 | 77.9%                   | 76.4%                 | -1.5%      |
| ResNet101   | 76.4%                   | 68.6%                 | -7.8%      |

This seemed a bit surprising to me (and I'm still looking for a bug in my pipeline to explain it), but if I'm understanding this thread correctly, the top-1 accuracy scores reported in the documentation are only attainable with a custom input pre-processing routine, so perhaps a drop is expected, and this reference table of actual accuracies with the given sample code may be useful for others.

@mostafaelhoushi

FYI... I totally gave up on getting the correct accuracy on ImageNet using Keras two years ago and had to switch to PyTorch. Their ImageNet script works out of the box: https://github.com/pytorch/examples/tree/master/imagenet
