inference issue #32
Hi @deep-practice, please refer to the 'validate()' function in train.py (Line 241); there is an extra 'softmax' in your script.
I use softmax to get the probability of each class. It only rescales the scores, so it doesn't change the predicted class.
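Softmax is a monotonic function of the logits, so it turns scores into probabilities without changing which class ranks highest. A small pure-Python check, using made-up logits rather than real model output:

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability; exp is monotonic, so the ranking is preserved.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values):
    # Index of the largest value.
    return max(range(len(values)), key=values.__getitem__)

logits = [2.0, 5.5, -1.0, 5.4]  # illustrative values only
probs = softmax(logits)
print(argmax(logits), argmax(probs))  # 1 1
```

So an extra softmax after the model cannot by itself produce the wrong label; it only affects the reported score.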
I just took the pipeline in 'validate()' as a reference.
The transform is the same as ours.
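OpencvResize(256) followed by CenterCrop(224) is the usual ImageNet evaluation geometry. The sketch below assumes that OpencvResize scales the shorter side to 256 while preserving the aspect ratio (an assumption; check the repository's implementation) and computes the resulting crop window for a 640x480 photo. The helper names are hypothetical.

```python
def shorter_side_resize(width, height, size=256):
    # Assumed OpencvResize behaviour: scale so the shorter side equals `size`,
    # keeping the aspect ratio and rounding the longer side to the nearest pixel.
    if width < height:
        return size, int(size * height / width + 0.5)
    return int(size * width / height + 0.5), size

def center_crop_box(width, height, crop=224):
    # Centered crop window as (left, top, right, bottom).
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop

w, h = shorter_side_resize(640, 480)
print((w, h), center_crop_box(w, h))
```

If both pipelines produce the same window for the same image, the geometric part of preprocessing can be ruled out.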
I get the correct result with MobileNet on the same picture. I just downloaded your code and added a demo function; nothing else changed. So I'm confused about what's wrong with my demo.
Try removing 'ToBGRTensor' if your image is already in BGR format.
I use the PIL.Image module to read the image, so I believe the input format is RGB.
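PIL decodes images in height x width x channel order, and for RGB images the channels are R, G, B, so a model trained on BGR input needs the channel axis reversed exactly once. Reversing an image that is already BGR (for example, one read with cv2.imread) would swap red and blue again. Note also that Image.open can return other modes such as 'L' or 'RGBA'; Image.convert('RGB') normalizes that. The plain-Python sketch below mimics what a ToBGRTensor-style transform is assumed to do (reverse the channels, move them first); it is not the repository's code, and the function name is hypothetical.

```python
def to_bgr_chw(image_hwc):
    # image_hwc[y][x] is an (R, G, B) tuple, matching PIL's RGB layout.
    # Reverse the channel order (RGB -> BGR), then move channels first (HWC -> CHW).
    # Values stay in 0-255; no scaling or mean/std normalization is applied here.
    height, width = len(image_hwc), len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [
        [[float(image_hwc[y][x][channels - 1 - c]) for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]

# A 1x2 RGB "image": one pure-red pixel and one pure-blue pixel.
rgb = [[(255, 0, 0), (0, 0, 255)]]
bgr = to_bgr_chw(rgb)
print(bgr[0])  # blue plane: [[0.0, 255.0]]
print(bgr[2])  # red plane: [[255.0, 0.0]]
```

With PIL input, the red pixel should end up in the last plane; if it lands in the first plane instead, the image was flipped twice or not at all.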
When running inference with ShuffleNetV2, the class score is low and the predicted label is wrong. I use the following code:
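import torch
import torchvision.transforms as transforms
from PIL import Image
# OpencvResize, ToBGRTensor and ShuffleNetV2 come from the repository code.
def remove_prefix(state_dict, prefix):
    # Assumed definition (remove_prefix was not shown in the report): strip the
    # "module." prefix that nn.DataParallel adds to checkpoint keys.
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
transform = transforms.Compose([
    OpencvResize(256),
    transforms.CenterCrop(224),
    ToBGRTensor(),
])
model = ShuffleNetV2(model_size='2.0x')
checkpoint = torch.load("ShuffleNetV2.2.0x.pth.tar", map_location="cpu")
model.load_state_dict(remove_prefix(checkpoint["state_dict"], "module."))
model.eval()
image = Image.open("../Image/cat.jpg").convert("RGB")  # force 3-channel RGB
img = transform(image).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)
with torch.no_grad():
    output = model(img)
output = torch.softmax(output, dim=1)  # probabilities; the argmax is unchanged
vals, idxs = torch.max(output, dim=1)
print(vals[0].item(), idxs[0].item())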
Output:
0.31983616948127747 287
Index 287 is the ImageNet class 'lynx, catamount'.
Am I preprocessing the input image incorrectly?
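A single argmax with a probability around 0.32 says little on its own. Listing the top-k classes shows whether the remaining probability mass falls on related classes (ImageNet indices 281-285 are domestic cat classes) or is scattered across unrelated ones, which helps separate a preprocessing bug from an ambiguous image. With real model output the same idea is torch.topk(output, k=5, dim=1); the sketch below uses a plain list with made-up probabilities, and the helper name is hypothetical.

```python
def top_k(probs, k=5):
    # Return the k (class_index, probability) pairs with the highest probability.
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Hypothetical probability vector over 6 classes, for illustration only.
probs = [0.05, 0.30, 0.02, 0.28, 0.25, 0.10]
for idx, p in top_k(probs, k=3):
    print(idx, p)
```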