In [3]:
# Initialise packages 
from model import U2NET
import coremltools as ct
from coremltools.proto import FeatureTypes_pb2 as ft
import torch
import os
from PIL import Image
from torchvision import transforms

In [4]:
# Initialise our model: build the U^2-Net graph (RGB input, see the first
# Conv2d(3, ...) in the printed architecture) and load the pretrained weights
# onto the CPU, then switch to evaluation mode for tracing.
device = torch.device('cpu')
model_dir = os.path.join(os.getcwd(), 'saved_models', 'u2net', 'u2net.pth')
net = U2NET(3, 1)
# NOTE(review): torch.load unpickles the checkpoint — only load trusted files.
state_dict = torch.load(model_dir, map_location=device)
net.load_state_dict(state_dict)
net.cpu()
net.eval()

U2NET(
  (stage1): RSU7(
    (rebnconvin): REBNCONV(
      (conv_s1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (rebnconv1): REBNCONV(
      (conv_s1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (rebnconv2): REBNCONV(
      (conv_s1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (rebnconv3): REBNCONV(
      (conv_s1): Conv2d(32, 32, k

In [None]:
# Create a test input.
# Specify an image as input here. It is converted to RGB because PNG files are
# often RGBA / palette / grayscale, and ToTensor would then yield 4 or 1
# channels instead of the 3 the network's first Conv2d(3, ...) expects. The
# same `input_image` is reused for the Core ML prediction at the end.
input_image = Image.open("im_01.png").convert("RGB").resize((320, 320))
example_input = transforms.ToTensor()(input_image).unsqueeze(0)
example_input = example_input.type(torch.FloatTensor)
assert tuple(example_input.shape) == (1, 3, 320, 320), example_input.shape

# or uncomment the below line to use a random Tensor instead.
# example_input = torch.rand(1, 3, 320, 320)

In [None]:
# Trace and convert the model.
# With the default scale=1 / bias=0, a Core ML ImageType input reaches the
# network as raw 0-255 pixel values, whereas the example tensor built above is
# in [0, 1]. Core ML computes `pixel * scale + bias` with a single scale and a
# per-channel bias, so the mean ImageNet std (~0.226) stands in for the
# per-channel std below.
# NOTE(review): assumes the checkpoint expects the upstream U-2-Net
# preprocessing (scale to [0, 1], subtract ImageNet mean, divide by std) —
# confirm against the training data loader.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
input_scale = 1.0 / (0.226 * 255.0)
input_bias = [-mean / std for mean, std in zip(IMAGENET_MEAN, IMAGENET_STD)]

traced_model = torch.jit.trace(net, example_input)
model = ct.convert(
    traced_model,
    inputs=[
        ct.ImageType(
            name="input_1",
            shape=example_input.shape,
            scale=input_scale,
            bias=input_bias,
        )
    ],
)

In [None]:
# Attach descriptive metadata (description, licence, authors) to the model.
model_metadata = {
    "short_description": "U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection",
    "license": "Apache 2.0",
    "author": "Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Dehghan, Masood and Zaiane, Osmar and Jagersand, Martin",
}
for attribute_name, attribute_value in model_metadata.items():
    setattr(model, attribute_name, attribute_value)

In [None]:
# Rename the input and outputs to stable names and save an intermediate model.
spec = model.get_spec()
ct.utils.rename_feature(spec, "input_1", "in_0")
# The converter names outputs after internal graph values (e.g. "2179"), which
# are tied to one particular trace. Renaming by position keeps the
# out_a0, out_a1, ... mapping without hard-coding those generated names.
# Names are collected first because rename_feature mutates the spec.
original_output_names = [output.name for output in spec.description.output]
for index, output_name in enumerate(original_output_names):
    ct.utils.rename_feature(spec, output_name, "out_a" + str(index))
model = ct.models.MLModel(spec)
model.save("u2netp.mlmodel")

In [None]:
# Re-open model for modification and append new output layers.
# For every renamed output "out_a<i>" (presumably a 0-1 map), append a linear
# activation layer computing 255 * x + 0 named "out_p<i>". Then repoint the
# model's output description to it so it can be typed as an image below.
# Outputs are located by name rather than by a hard-coded layer index, which
# only matched one specific network's layer count.
model = ct.models.MLModel("u2netp.mlmodel")
spec = model.get_spec()
spec_layers = getattr(spec, spec.WhichOneof("Type")).layers
new_layers = []  # names of the appended layers; consumed by the next cell
for output_description in spec.description.output:
    source_name = output_description.name
    if not source_name.startswith("out_a"):
        continue
    new_name = "out_p" + source_name[len("out_a"):]

    new_layer = spec_layers.add()
    new_layer.name = new_name
    new_layer.activation.linear.alpha = 255
    new_layer.activation.linear.beta = 0
    new_layer.input.append(source_name)
    new_layer.output.append(new_name)

    output_description.name = new_name
    new_layers.append(new_name)

assert new_layers, "no out_a* outputs found - run the rename cell first"

In [None]:
# Specify the outputs as grayscale images.
# The output image size is taken from the model's image input instead of being
# hard-coded, so it stays in sync if the input resolution changes.
input_feature = spec.description.input[0]
if input_feature.type.WhichOneof('Type') != 'imageType':
    raise ValueError("%s is not an image type" % input_feature.name)
input_image_type = input_feature.type.imageType

for output in spec.description.output:
    if output.name not in new_layers:
        continue
    if output.type.WhichOneof('Type') != 'multiArrayType':
        raise ValueError("%s is not a multiarray type" % output.name)
    # Writing to imageType switches the `Type` oneof away from multiArrayType.
    output.type.imageType.colorSpace = ft.ImageFeatureType.ColorSpace.Value('GRAYSCALE')
    output.type.imageType.width = input_image_type.width
    output.type.imageType.height = input_image_type.height

In [None]:
# Build the final model from the edited spec and write it over the
# intermediate file saved earlier.
final_model_path = "u2netp.mlmodel"
updated_model = ct.models.MLModel(spec)
updated_model.save(final_model_path)

In [None]:
# Test our model: run it on the sample image and write every output map to
# "<output name>.png". (Core ML prediction needs macOS.)
predictions = updated_model.predict({'in_0': input_image})
for output_name, output_image in predictions.items():
    output_image.save(f"{output_name}.png")