In [14]:
import io
import numpy as np

from torch import  nn
import torch.utils.model_zoo as model_zoo
import torch.onnx

In [15]:
import torch.nn as nn
import torch.nn.init as init

class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()
        
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5,5), (1,1), (2,2))
        self.conv2 = nn.Conv2d(64, 64, (3,3), (1,1), (1,1))
        self.conv3 = nn.Conv2d(64, 32, (3,3), (1,1), (1,1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3,3), (1,1), (1,1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        
        self._initialize_weights()
        
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x
    
    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)
        
torch_model = SuperResolutionNet(upscale_factor=3)

In [16]:
# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1    # just a random number

map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

torch_model.train(False)

SuperResolutionNet(
  (relu): ReLU()
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=3)
)

In [17]:
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

torch_out = torch.onnx._export(torch_model,x,"model/super_resolution.onnx",export_params=True)

In [20]:
import onnx
import caffe2.python.onnx.backend as onnx_caffe2_backend

# Load the ONNX ModelProto object. model is a standard Python protobuf object
model = onnx.load("model/super_resolution.onnx")

# prepare the caffe2 backend for executing the model this converts the ONNX model into a
# Caffe2 NetDef that can execute it. Other ONNX backends, like one for CNTK will be
# availiable soon.
prepared_backend = onnx_caffe2_backend.prepare(model)

# run the model in Caffe2

# Construct a map from input names to Tensor data.
# The graph of the model itself contains inputs for all weight parameters, after the input image.
# Since the weights are already embedded, we just need to pass the input image.
# Set the first input.
W = {model.graph.input[0].name: x.data.numpy()}

# Run the Caffe2 net:
c2_out = prepared_backend.run(W)[0]

# Verify the numerical correctness upto 3 decimal places
np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), c2_out, decimal=2)

print("Exported model has been executed on Caffe2 backend, and the result looks good!")

AssertionError: 
Arrays are not almost equal to 2 decimals

(mismatch 98.8934506094%)
 x: array([[[[ 0.36,  0.41,  0.26, ...,  1.01,  1.15,  1.04],
         [ 0.42,  0.51,  0.34, ...,  1.29,  1.46,  1.18],
         [ 0.23,  0.33,  0.23, ...,  1.09,  1.25,  0.98],...
 y: array([[[[ 0.36,  0.07, -0.11, ...,  0.13,  0.8 ,  1.04],
         [-0.02, -0.26, -0.31, ...,  1.12,  0.5 ,  0.86],
         [-0.22,  0.65, -0.25, ...,  0.43, -0.31,  0.56],...

In [22]:
c2_workspace = prepared_backend.workspace
c2_model = prepared_backend.predict_net

from caffe2.python.predictor import mobile_exporter
init_net, predict_net = mobile_exporter.Export(c2_workspace, c2_model, c2_model.external_input)

with open('model/init_net.pb','wb') as fopen:
    fopen.write(init_net.SerializeToString())
with open('model/predict_net.pb', "wb") as fopen:
    fopen.write(predict_net.SerializeToString())

In [23]:
from caffe2.proto import caffe2_pb2
from caffe2.python import core, net_drawer, net_printer, visualize, workspace, utils

import numpy as np
import os
import subprocess
from PIL import Image
from matplotlib import pyplot
from skimage import io, transform

In [24]:
img_in = io.imread('data/cat_224x224.jpg')

img = transform.resize(img_in, [224,224])



  warn("The default mode, 'constant', will be changed to 'reflect' in "
  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "


In [25]:
img = Image.open('data/cat_224x224.jpg')
img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()

workspace.RunNetOnce(init_net)
workspace.RunNetOnce(predict_net)

print(net_printer.to_string(predict_net))

torch-jit-export_predict = core.Net('torch-jit-export_predict')
torch-jit-export_predict.Conv(['0', '1', '2'], ['9'], strides=[1L, 1L], pads=[2L, 2L, 2L, 2L], kernels=[5L, 5L], group=1, dilations=[1L, 1L])
torch-jit-export_predict.Relu(['9'], ['10'])
torch-jit-export_predict.Conv(['10', '3', '4'], ['11'], strides=[1L, 1L], pads=[1L, 1L, 1L, 1L], kernels=[3L, 3L], group=1, dilations=[1L, 1L])
torch-jit-export_predict.Relu(['11'], ['12'])
torch-jit-export_predict.Conv(['12', '5', '6'], ['13'], strides=[1L, 1L], pads=[1L, 1L, 1L, 1L], kernels=[3L, 3L], group=1, dilations=[1L, 1L])
torch-jit-export_predict.Relu(['13'], ['14'])
torch-jit-export_predict.Conv(['14', '7', '8'], ['15'], strides=[1L, 1L], pads=[1L, 1L, 1L, 1L], kernels=[3L, 3L], group=1, dilations=[1L, 1L])
torch-jit-export_predict.Reshape(['15', '16'], ['17', 'OC2_DUMMY_0'])
torch-jit-export_predict.Transpose(['17'], ['18'], axes=[0L, 1L, 4L, 2L, 5L, 3L])
torch-jit-export_predict.Reshape(['18', '19'], ['20', 'OC2_DUMMY_1'])


In [27]:
workspace.FeedBlob("9", np.array(img_y)[np.newaxis, np.newaxis, :, :].astype(np.float32))

# run the predict_net to get the model output
workspace.RunNetOnce(predict_net)

# Now let's get the model output blob
img_out = workspace.FetchBlob("20")

In [28]:
img_out_y = Image.fromarray(np.uint8((img_out[0, 0]).clip(0, 255)), mode='L')

# get the output image follow post-processing step from PyTorch implementation
final_img = Image.merge(
    "YCbCr", [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ]).convert("RGB")

# Save the image, we will compare this with the output image from mobile device
final_img.save("cat_superres.jpg")