
Convert Model from MXNet to PyTorch #6

Closed
ahkarami opened this issue Nov 29, 2017 · 26 comments

@ahkarami

Dear @kitstar,
Thank you for your nice repository. I have a ResNet152 model pre-trained in MXNet and I want to convert it to PyTorch. Could you please guide me on how to do that?

@kitstar
Contributor

kitstar commented Nov 29, 2017

Hi @ahkarami,
The PyTorch emitter is still in progress. Would you mind sharing your model through a net disk so I can try it first?
Thanks.

@ahkarami
Author

ahkarami commented Nov 29, 2017

Dear @kitstar,
In fact, I want to convert the ResNet152 trained on ImageNet-11k with 11221 classes. The model can be found at the links below:
Large Scale Image Classification
Image Classification MXNet Example
Model Files
I think it would be great to have this valuable model in PyTorch.
Thanks.

@kitstar
Contributor

kitstar commented Nov 30, 2017

I will follow up and tell you if there is any progress. Thanks.

@kitstar
Contributor

kitstar commented Dec 18, 2017

Hi @ahkarami, please check the newest code. The MXNet -> PyTorch conversion of ResNet152 with 11K classes has been tested.

@ahkarami
Author

Dear @kitstar,
I have converted the ResNet152 (trained on ImageNet-11k with 11221 classes) from MXNet to PyTorch. However, the results of the MXNet version differ significantly from those of the converted PyTorch model.
I used PyTorch 0.2 and MXNet 0.11.0.
In fact, I did the following steps:
1- python -m mmdnn.conversion._script.convertToIR -f mxnet -n resnet-152-symbol.json -w resnet-152-0000.params -d resnet152 --inputShape 3 224 224
which printed:
IR network structure is saved as [resnet152.json].
IR network structure is saved as [resnet152.pb].
IR weights are saved as [resnet152.npy].

2- python -m mmdnn.conversion._script.IRToCode -f pytorch --IRModelPath resnet152.pb --dstModelPath kit_imagenet.py --IRWeightPath resnet152.npy -dw kit_pytorch.npy
which printed:
Parse file [resnet152.pb] with binary format successfully.
Target network code snippet is saved as [kit_imagenet.py]
Target weights are saved as [kit_pytorch.npy]

3- python -m mmdnn.conversion.examples.pytorch.imagenet_test --dump resnet152Full.pth -n kit_imagenet.py -w kit_pytorch.npy
which printed:
PyTorch model file is saved as [resnet152Full.pth], generated by [kit_imagenet.py] and [kit_pytorch.npy].

I think I have converted the model correctly. After that, I compared the results of the two models (i.e., the original MXNet model and the converted PyTorch model, resnet152Full.pth) using my own scripts. However, the results differed from each other. Have you ever compared the results of these two models yourself? Would you please help me?

@kitstar
Contributor

kitstar commented Dec 20, 2017

I tested the conversion with

python -m mmdnn.conversion.examples.pytorch.imagenet_test -n kit_imagenet.py -w kit_pytorch.npy -i mmdnn/conversion/examples/data/seagull.jpg

which uses seagull.jpg to test the inference result.

MXNet top-5 inference result:
[(1278, 0.49073416), (1277, 0.21393695), (282, 0.12980066), (1282, 0.0663582), (1224, 0.022041745)]

Converted PyTorch top-5 inference result:
[(1278, 0.49070838), (1277, 0.21392572), (282, 0.12979434), (1282, 0.066355459), (1224, 0.022040628)]

Since the difference is not very significant, I didn't look into it (different implementations lead to slightly different results).
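One way to put a number on "not very significant" is to check that both top-5 lists rank the same classes and to take the largest absolute probability gap. A minimal sketch in plain Python, using the two lists printed above; the helper name `compare_top5` is hypothetical and not part of MMdnn:

```python
# Top-5 (class_id, probability) pairs copied from the two runs above.
mxnet_top5 = [(1278, 0.49073416), (1277, 0.21393695), (282, 0.12980066),
              (1282, 0.0663582), (1224, 0.022041745)]
pytorch_top5 = [(1278, 0.49070838), (1277, 0.21392572), (282, 0.12979434),
                (1282, 0.066355459), (1224, 0.022040628)]


def compare_top5(a, b):
    """Return (same_ranking, max_abs_diff) for two lists of (class_id, prob)."""
    same_ranking = [cls for cls, _ in a] == [cls for cls, _ in b]
    # Match probabilities by class id, so the gap stays meaningful
    # even if the two rankings disagree.
    probs_b = dict(b)
    max_abs_diff = max(abs(p - probs_b[cls]) for cls, p in a if cls in probs_b)
    return same_ranking, max_abs_diff


same, gap = compare_top5(mxnet_top5, pytorch_top5)
print(same, "%.2e" % gap)  # True 2.58e-05
```

Both runs agree on the ranking, and the largest gap (on the top-1 class) is about 2.6e-5.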

  1. Could you share your inference scripts and check whether you use the same preprocessing method?
  2. I'm not quite sure whether the inference result of my converted model above is acceptable. Do you have any idea about it?

Thanks.

@ahkarami
Author

Dear @kitstar,
Thank you for following up. First, unfortunately, when I run this command:
python -m mmdnn.conversion.examples.pytorch.imagenet_test -n kit_imagenet.py -w kit_pytorch.npy -i mmdnn/conversion/examples/data/seagull.jpg
I got this error:

return f(*args, **kwds)
Traceback (most recent call last):
  File "/home/karami/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/karami/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/karami/anaconda3/lib/python3.6/site-packages/mmdnn/conversion/examples/pytorch/imagenet_test.py", line 66, in <module>
    tester.inference(tester.args.image)
  File "/home/karami/anaconda3/lib/python3.6/site-packages/mmdnn/conversion/examples/pytorch/imagenet_test.py", line 45, in inference
    self.preprocess(image_path)
  File "/home/karami/anaconda3/lib/python3.6/site-packages/mmdnn/conversion/examples/pytorch/imagenet_test.py", line 26, in preprocess
    x = super(TestTorch, self).preprocess(image_path)
  File "/home/karami/anaconda3/lib/python3.6/site-packages/mmdnn/conversion/examples/imagenet_test.py", line 166, in preprocess
    func = self.preprocess_func[self.args.s][self.args.preprocess]
KeyError: None

Nevertheless, I will share my test scripts for both frameworks (i.e., MXNet and PyTorch) in the next comment.

@kitstar
Contributor

kitstar commented Dec 20, 2017

Sorry for the missing arguments:

python -m mmdnn.conversion.examples.pytorch.imagenet_test -n kit_imagenet.py -w kit_pytorch.npy -i mmdnn/conversion/examples/data/seagull.jpg -p resnet152-11k -s mxnet

The key addition is -s mxnet (together with -p resnet152-11k), and the image is in the git repo.
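The earlier `KeyError: None` comes from the two-level lookup `self.preprocess_func[self.args.s][self.args.preprocess]` shown in the traceback: when `-s` is omitted, `args.s` is `None` and the outer dictionary has no `None` key. A minimal sketch reproducing that behaviour; the table contents and the parser below are hypothetical stand-ins, not MMdnn's actual code:

```python
import argparse

# Hypothetical stand-in for the preprocessing table in imagenet_test.py;
# the real keys and functions in MMdnn may differ.
preprocess_func = {
    "mxnet": {"resnet152-11k": lambda path: "preprocessed " + path},
}

# Hypothetical parser mirroring the attribute names in the traceback
# (args.s and args.preprocess); both default to None when omitted.
parser = argparse.ArgumentParser()
parser.add_argument("-s")
parser.add_argument("-p", "--preprocess")


def get_preprocess(argv):
    args = parser.parse_args(argv)
    # Same two-level lookup as in the traceback.
    return preprocess_func[args.s][args.preprocess]


try:
    get_preprocess([])
except KeyError as err:
    print("KeyError:", err)  # KeyError: None

func = get_preprocess(["-p", "resnet152-11k", "-s", "mxnet"])
print(func("seagull.jpg"))  # preprocessed seagull.jpg
```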

@ahkarami
Author

Dear @kitstar,
Sorry for the inconvenience. I ran the command you mentioned (with the new arguments) and got this result:

return f(*args, **kwds)
[(1278, 0.48384374), (1277, 0.26611671), (282, 0.1002832), (1224, 0.027631452), (1282, 0.025890373)]
Traceback (most recent call last):
  File "/home/karami/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/karami/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/karami/anaconda3/lib/python3.6/site-packages/mmdnn/conversion/examples/pytorch/imagenet_test.py", line 66, in <module>
    tester.inference(tester.args.image)
  File "/home/karami/anaconda3/lib/python3.6/site-packages/mmdnn/conversion/examples/pytorch/imagenet_test.py", line 51, in inference
    self.test_truth()
  File "/home/karami/anaconda3/lib/python3.6/site-packages/mmdnn/conversion/examples/imagenet_test.py", line 192, in test_truth
    assert np.isclose(this_truth[index][1], i[1], atol = 1e-6)
AssertionError

Unfortunately, I got an error again! Another surprising thing is that the results on my machine are a little different from the results you reported:

My System Results:
[(1278, 0.48384374), (1277, 0.26611671), (282, 0.1002832), (1224, 0.027631452), (1282, 0.025890373)]
Your Reported Results:
[(1278, 0.49070838), (1277, 0.21392572), (282, 0.12979434), (1282, 0.066355459), (1224, 0.022040628)]
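The failing assertion is `np.isclose(this_truth[index][1], i[1], atol = 1e-6)`. NumPy's `isclose` is asymmetric: it checks `|a - b| <= atol + rtol * |b|`, with a default `rtol` of `1e-5`. A plain-Python sketch of that check applied to the two top-1 values above (the function name `isclose_like_numpy` is hypothetical) shows how far outside the tolerance they are; for comparison, even the roughly 2.6e-5 gap between the MXNet and PyTorch top-1 values reported earlier would exceed it:

```python
def isclose_like_numpy(a, b, rtol=1e-5, atol=1e-8):
    """Scalar version of the numpy.isclose test: |a - b| <= atol + rtol * |b|."""
    return abs(a - b) <= atol + rtol * abs(b)


# Top-1 probabilities for class 1278 from the two result lists above.
mine, reported = 0.48384374, 0.49070838

# atol=1e-6 as in imagenet_test.py, rtol left at NumPy's default of 1e-5.
tolerance = 1e-6 + 1e-5 * abs(reported)
print("gap=%.2e tolerance=%.2e" % (abs(mine - reported), tolerance))
# gap=6.86e-03 tolerance=5.91e-06
print(isclose_like_numpy(mine, reported, atol=1e-6))  # False
```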

@kitstar
Contributor

kitstar commented Dec 20, 2017

My MXNet version is '0.12.0' and my PyTorch version is '0.4.0a0+7ddcb91'.

Maybe it matters?

@ahkarami
Author

Dear @kitstar,
As previously stated, I used the above-mentioned sequence of commands to convert ResNet152-11k from MXNet to PyTorch with your fantastic repository. I hope my conversion commands are correct. The sequence of commands I used is as follows:
1- python -m mmdnn.conversion._script.convertToIR -f mxnet -n resnet-152-symbol.json -w resnet-152-0000.params -d resnet152 --inputShape 3 224 224
2- python -m mmdnn.conversion._script.IRToCode -f pytorch --IRModelPath resnet152.pb --dstModelPath kit_imagenet.py --IRWeightPath resnet152.npy -dw kit_pytorch.npy
3- python -m mmdnn.conversion.examples.pytorch.imagenet_test --dump resnet152Full.pth -n kit_imagenet.py -w kit_pytorch.npy

Now, I want to visually compare the results of the original MXNet model (i.e., resnet-152-symbol.json & resnet-152-0000.params) with those of the converted PyTorch model (i.e., resnet152Full.pth).
This is the script I used for MXNet:

import mxnet as mx
import matplotlib.pyplot as plt
import cv2
import numpy as np
from collections import namedtuple


# Download the ResNet152 trained on the full ImageNet:

# path='http://data.mxnet.io/models/imagenet-11k/'
# [mx.test_utils.download(path+'resnet-152/resnet-152-symbol.json'),
#  mx.test_utils.download(path+'resnet-152/resnet-152-0000.params'),
#  mx.test_utils.download(path+'synset.txt')]


# Load the ResNet152 trained on the full ImageNet:

sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)
mod = mx.mod.Module(symbol=sym, context=mx.gpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))],
         label_shapes=mod._label_shapes)
mod.set_params(arg_params, aux_params, allow_missing=True)
with open('synset.txt', 'r') as f:
    labels = [l.rstrip() for l in f]


# ** Predicting:

# define a simple data batch
Batch = namedtuple('Batch', ['data'])

def get_image(imageAddress, show=True):
    # load the image; cv2.imread returns None if the file cannot be read,
    # so check before calling cvtColor
    img = cv2.imread(imageAddress)
    if img is None:
        return None
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
    if show:
        plt.imshow(img)
        plt.axis('off')
        plt.show()
    # resize, then convert into format (batch, RGB, height, width)
    img = cv2.resize(img, (224, 224))
    img = np.swapaxes(img, 0, 2)  # (H, W, C) -> (C, W, H)
    img = np.swapaxes(img, 1, 2)  # (C, W, H) -> (C, H, W)
    img = img[np.newaxis, :]
    return img

def predict(imageAddress):
    img = get_image(imageAddress, show=True)
    # compute the predict probabilities
    mod.forward(Batch([mx.nd.array(img)]))
    prob = mod.get_outputs()[0].asnumpy()
    # print the top-5
    prob = np.squeeze(prob)
    a = np.argsort(prob)[::-1]
    for i in a[0:5]:
        print(i)
        print('probability=%f, class=%s' %(prob[i], labels[i]))


# *** Test Model:
sampleInputImage = 'seagull.jpg'
predict(sampleInputImage)


And these are the results it produced:

1278
probability=0.536169, class=n02041246 gull, seagull, sea gull
1277
probability=0.104755, class=n02041085 larid
282
probability=0.098345, class=n01517966 carinate, carinate bird, flying bird
1280
probability=0.081313, class=n02041875 black-backed gull, great black-backed gull, cob, Larus marinus
1281
probability=0.051755, class=n02042046 herring gull, Larus argentatus

And this is the PyTorch script I used:

import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.nn import functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.misc import imread

# Parameter:
num_predictions = 5


# Load Model
bestModelAddress = 'resnet152Full.pth' # for loading models
model = torch.load(bestModelAddress).cuda()
model.eval()


# Load One Input Image
test_image_address = 'seagull.jpg'

# ***** Show the original image
img = imread(test_image_address)
plt.subplot(1, 1, 1)
plt.axis('off') # turn off axis

# ***** Image pre-processing transforms
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])


img_pil = Image.open(test_image_address)
img_tensor = preprocess(img_pil)
img_tensor.unsqueeze_(0)
img_tensor = Variable(img_tensor).cuda()


# Load Full-Imagenet Synset:
# *** Note: You can download it from:
# http://data.mxnet.io/models/imagenet-11k/synset.txt
with open('synset.txt', 'r') as f:
    labels = [l.rstrip() for l in f]


# Make predictions
output = model(img_tensor) # shape (1, 11221)
_, argmax = output.data.squeeze().max(0) # '_' avoids shadowing the built-in max()
class_id = argmax[0]
classname = labels[class_id]
# print(class_id)

# print the top-5
h_x = F.softmax(output).data.squeeze()
probs, idx = h_x.sort(0, True)

print('Result: ')
# output the prediction
for i in range(0, num_predictions):
    print('{:.3f} -> {}'.format(100 * probs[i], labels[idx[i]]))



print('The Image is a', classname)

StrLabel = 'The Image is a ' + classname

# ***** Show the result & Image:
ax = plt.subplot(1, 1, 1)
# str_confidence = str("{0:.2f}".format(class_probabilities_cpu[0][max_index]))
ax.set_title(StrLabel)
plt.imshow(img)

plt.show()

And these are the results produced by PyTorch:

0.013 -> n04960277 black, blackness, inkiness
0.012 -> n04960582 coal black, ebony, jet black, pitch black, sable, soot black
0.009 -> n13896217 crescent
0.009 -> n03728437 match, lucifer, friction match
0.009 -> n02846141 black

So it is clear that the results are significantly different from each other. Note that I want to use the resnet152Full.pth model in PyTorch. I also tested some other images, and the results were very different as well.
I hope my conversion commands are correct. If you could help me tackle this issue, I would greatly appreciate it.
Thank you.

@kitstar
Contributor

kitstar commented Dec 20, 2017

From your code, the MXNet script feeds raw pixel values to the original graph, but your PyTorch code normalizes the input image with mean and std arrays. If you want to get the same result, try removing the normalization step from the preprocessing.

I'm also not sure about the transforms.Scale(256) step; maybe you can try removing it too. You can refer to our testing method when implementing your inference code.
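To see how far apart the two pipelines are, consider a single mid-gray pixel. torchvision's `ToTensor()` scales 0..255 to 0..1 and `Normalize` then subtracts the mean and divides by the std, whereas the MXNet script passes the raw 0..255 value straight through. A small plain-Python sketch of both (the helper names are hypothetical); it ignores the RGB/BGR order and the Scale/CenterCrop geometry, which also differ between the two scripts:

```python
# ImageNet statistics used by the transforms.Normalize call in the PyTorch script.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def torchvision_style(pixel_rgb):
    """ToTensor() scales 0..255 to 0..1, then Normalize applies (x - mean) / std."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(pixel_rgb, MEAN, STD))


def raw_style(pixel_rgb):
    """No scaling or normalization, as in the MXNet script."""
    return tuple(float(v) for v in pixel_rgb)


gray = (128, 128, 128)
print([round(v, 3) for v in torchvision_style(gray)])  # [0.074, 0.205, 0.426]
print(raw_style(gray))                                 # (128.0, 128.0, 128.0)
```

The same pixel reaches the network as roughly 0.07 to 0.43 in one pipeline and as 128 in the other, so the converted weights see inputs on a very different scale.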

@ahkarami
Author

ahkarami commented Dec 20, 2017

Dear @kitstar,
I also removed the normalization step, but the results from the two frameworks are still different.
Maybe the problem is related to the data layout of the two frameworks (i.e., NHWC (channels-last) vs. NCHW (channels-first) format). What's your opinion?
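Both MXNet's and PyTorch's convolutions use NCHW by default, and each script above already produces that layout: the MXNet script through its two `np.swapaxes` calls, and the PyTorch script through `transforms.ToTensor()`, which returns a CHW tensor. A small NumPy sketch on a fake image confirms that the two `swapaxes` calls are equivalent to a single HWC -> CHW transpose, which suggests the layout itself is probably not the culprit:

```python
import numpy as np

# A small fake HWC image (height=2, width=3, 3 channels), like cv2/PIL output.
hwc = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)

# The MXNet script's two swapaxes calls ...
via_swapaxes = np.swapaxes(np.swapaxes(hwc, 0, 2), 1, 2)
# ... give the same result as one transpose from HWC to CHW.
via_transpose = np.transpose(hwc, (2, 0, 1))

print(via_swapaxes.shape)                           # (3, 2, 3)
print(np.array_equal(via_swapaxes, via_transpose))  # True
```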

@kitstar
Contributor

kitstar commented Dec 20, 2017

Possible. I am not familiar with PyTorch preprocessing methods. Have you also removed the Scale step?
Could you try my preprocessing method first, like this:

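import numpy as np
import torch
from tensorflow.contrib.keras.api.keras.preprocessing import image  # TF 1.x import path

path = 'seagull.jpg'  # test image
img = image.load_img(path, target_size = (224, 224))  # resize directly to 224x224, no crop
x = image.img_to_array(img)  # HWC, RGB, float32 in [0, 255], no normalization
x = x[..., ::-1]     # In my test, I transform image from RGB --> BGR in both mxnet and pytorch
x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
x = np.expand_dims(x, 0).copy()  # add batch axis; copy() removes the negative stride torch.from_numpy rejects
x = torch.from_numpy(x)
x = torch.autograd.Variable(x, requires_grad = False)
output = model(x)  # model: the converted network, e.g. loaded from resnet152Full.pth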

@ahkarami
Author

Dear @kitstar,
Thank you very much for your help. I used your suggested preprocessing and the problem has been solved. The results from the two frameworks still differ slightly, but I think that is normal.
I think the main problem was the RGB channel order of the input image. When I changed it to BGR and removed the normalization step, the problem was solved.
Thank you very much

@kitstar
Contributor

kitstar commented Dec 21, 2017

Glad I could help.

@kitstar kitstar closed this as completed Dec 21, 2017
@yuzcccc

yuzcccc commented May 23, 2018

Hi, @kitstar
I am also interested in this imagenet-11k model in MXNet, and I want to convert it to Caffe. Could you please give me some hints on how to do this?

Also, does running the following script require TensorFlow?

python -m mmdnn.conversion.examples.pytorch.imagenet_test -n kit_imagenet.py -w kit_pytorch.npy -i mmdnn/conversion/examples/data/seagull.jpg

since the following module uses the image-processing module from Keras:

from mmdnn.conversion.examples.imagenet_test import TestKit

@kitstar
Contributor

kitstar commented May 23, 2018

Hi @yuzcccc,

  1. I'm not sure whether Caffe supports all the operators in the MXNet imagenet-11k model. You can try the mmconvert command to do the conversion and post the result.

  2. Yes, it does. But we could use the PIL module instead.
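Following that suggestion, here is a minimal sketch of a PIL-based stand-in for Keras' `load_img` + `img_to_array`, so that loading the test image would not need TensorFlow. The function name `load_like_keras` is hypothetical, and the sketch assumes Pillow and NumPy are installed. Keras' `load_img` resizes with nearest-neighbour interpolation by default, so `Image.NEAREST` is passed explicitly. The RGB -> BGR flip and NCHW transpose from the preprocessing snippet earlier in this thread would still follow:

```python
import io

import numpy as np
from PIL import Image


def load_like_keras(path_or_file, size=(224, 224)):
    """Hypothetical PIL replacement for keras load_img + img_to_array.

    Returns an HWC, RGB, float32 array with values in [0, 255].
    """
    img = Image.open(path_or_file).convert("RGB")
    img = img.resize(size, resample=Image.NEAREST)  # size is (width, height)
    return np.asarray(img, dtype=np.float32)


# Usage with an in-memory image instead of seagull.jpg:
buf = io.BytesIO()
Image.new("RGB", (300, 200), color=(10, 20, 30)).save(buf, format="PNG")
buf.seek(0)
x = load_like_keras(buf)
print(x.shape, x[0, 0])  # (224, 224, 3) [10. 20. 30.]
```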

@yuzcccc

yuzcccc commented May 23, 2018

I used mmconvert and successfully converted the MXNet model to Caffe; however, I found a very strange phenomenon in the generated Caffe prototxt:

layer {
  name: "pooling0"
  type: "Pooling"
  bottom: "bn0"
  top: "pooling0"
  pooling_param {
    pool: MAX
    kernel_size: 3
    stride: 2
    pad_h: 1
    pad_w: 1
  }
}
layer {
  name: "DummyData1"
  type: "DummyData"
  top: "DummyData1"
  dummy_data_param {
    shape {
      dim: 1
      dim: 64
      dim: 56
      dim: 56
    }
  }
}
layer {
  name: "pooling0_crop"
  type: "Crop"
  bottom: "pooling0"
  bottom: "DummyData1"
  top: "pooling0_crop"
}

Why is there such a crop operation? The original Caffe implementation has no such part, and I could not find it in MXNet's JSON file either.

@yuzcccc

yuzcccc commented May 23, 2018

The command I used is:

mmconvert -sf mxnet -in ../model_mxnet/resnet-152-symbol.json -iw ../model_mxnet/resnet-152-0000.params -df caffe -om resnet-152-11k --inputShape 3 224 224

@kitstar
Contributor

kitstar commented May 23, 2018

It is there to eliminate the difference in padding algorithms between Caffe and other frameworks.
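A likely reading of that answer, under two stated assumptions: that conv0 is the usual ResNet stem (7x7, stride 2, pad 3), and that the MXNet symbol uses Pooling's default `pooling_convention='valid'`. Caffe's pooling layer rounds its output size up, while MXNet's default rounds down, so on a 224x224 input Caffe's `pooling0` produces 57x57 where MXNet produces 56x56. The Crop layer (default offset 0) keeps the top-left 56x56 region to match `DummyData1`. The same arithmetic shows why a fixed 56x56 DummyData only fits 224x224 inputs:

```python
import math


def conv_out(size, kernel, stride, pad):
    # Convolution output size rounds down in both Caffe and MXNet.
    return (size + 2 * pad - kernel) // stride + 1


def pool_out_floor(size, kernel, stride, pad):
    # MXNet Pooling with the default pooling_convention='valid' rounds down.
    return (size + 2 * pad - kernel) // stride + 1


def pool_out_caffe(size, kernel, stride, pad):
    # Caffe's PoolingLayer rounds up ...
    out = int(math.ceil((size + 2 * pad - kernel) / stride)) + 1
    # ... and drops the last window if it would start inside the padding.
    if pad > 0 and (out - 1) * stride >= size + pad:
        out -= 1
    return out


for image_size in (224, 256):
    c = conv_out(image_size, 7, 2, 3)  # assumed stem: 7x7, stride 2, pad 3
    # pooling0 from the prototxt: kernel 3, stride 2, pad 1
    print(image_size, c, pool_out_floor(c, 3, 2, 1), pool_out_caffe(c, 3, 2, 1))
# Prints:
# 224 112 56 57
# 256 128 64 65
```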

@yuzcccc

yuzcccc commented May 23, 2018

Oh, I found the difference: Caffe's ResNet usually does not pad this pooling layer. However, I am wondering whether I can remove the DummyData and Crop layers and simply delete the padding (pad=1) parameters from the prototxt manually. Would this influence the results?
The DummyData layer gives a fixed shape [56, 56], which is not suitable when the input image size is not fixed.
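Some arithmetic bears on that question. With pad removed, Caffe's output on a 112x112 feature map is also 56x56 (ceil((112 - 3) / 2) + 1 = 56), so the shapes would line up without the Crop. However, each pooling window would then start one pixel later than in the MXNet model, so the pooled values would not be identical to MXNet's; how much that changes accuracy would have to be measured. A tiny sketch of the window start positions (the helper name is hypothetical):

```python
def window_starts(out_size, stride, pad):
    # Row (or column) of the unpadded input where each pooling window starts;
    # -1 means the window begins in the padding.
    return [i * stride - pad for i in range(out_size)]


# 56 outputs in both cases: pad=1 (MXNet, and Caffe after the crop)
# versus pad=0 (padding removed from the prototxt).
with_pad = window_starts(56, 2, 1)
without_pad = window_starts(56, 2, 0)
print(with_pad[:4])     # [-1, 1, 3, 5]
print(without_pad[:4])  # [0, 2, 4, 6]
```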

P.S. Do I have to worry about the RGB vs. BGR problem for the generated caffemodel?

@kitstar
Contributor

kitstar commented May 23, 2018

  1. You could give it a try.

  2. I think not. Just follow the MXNet preprocessing method.

@yuzcccc

yuzcccc commented May 23, 2018

Thanks a lot!

@back2yes

I followed the instructions and it worked very well. Great thanks!

@ahkarami
Author

I think this link would be helpful:
Convert Full ImageNet Pre-trained Model from MXNet to PyTorch
