In [37]:
import torch
from torchvision import transforms
import tensorflow as tf
from pytorch_models.model_arch import TransformerNet
import onnx
import onnxruntime

In [30]:
from PIL import Image

def load_image(filename, size=None, scale=None):
    """Open an image file and return it as an RGB PIL image.

    Args:
        filename: Location of the image; passed straight to ``Image.open``.
        size: When given, the image is resized to a ``size`` x ``size``
            square (the aspect ratio is not preserved). Takes precedence
            over ``scale``.
        scale: When given (and ``size`` is None), both dimensions are divided
            by ``scale`` and truncated to integers, with a floor of 1 pixel.

    Returns:
        The RGB image, resized if ``size`` or ``scale`` was supplied.
    """
    # convert() yields an independent copy, so the opened file can be
    # released immediately instead of lingering until garbage collection.
    with Image.open(filename) as source:
        img = source.convert('RGB')

    # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS names the
    # same resampling filter and exists in old and new releases alike.
    if size is not None:
        img = img.resize((size, size), Image.LANCZOS)
    elif scale is not None:
        # Floor at 1 so a large scale factor cannot request a 0-pixel side.
        new_width = max(1, int(img.size[0] / scale))
        new_height = max(1, int(img.size[1] / scale))
        img = img.resize((new_width, new_height), Image.LANCZOS)
    return img


def save_image(filename, data):
    """Write an image tensor to ``filename``.

    Args:
        filename: Destination passed to the PIL image's ``save`` method.
        data: Tensor whose axes are reordered with ``transpose(1, 2, 0)``
            before saving, i.e. it is indexed as (channel, row, column).
            Values are clamped to [0, 255] and cast to ``uint8``. The
            tensor itself is not modified.
    """
    # detach()/cpu() keep .numpy() legal for tensors that track gradients or
    # live on a GPU; clamp() is out-of-place, so no explicit clone is needed.
    pixels = data.detach().cpu().clamp(0, 255).numpy()
    pixels = pixels.transpose(1, 2, 0).astype("uint8")
    Image.fromarray(pixels).save(filename)


In [21]:
# Load the raw checkpoint. map_location="cpu" lets the file deserialize on a
# machine without CUDA even if its tensors were saved from a GPU.
model_checkpoint = "./pytorch_models/rain_princess.pth"
state_dict = torch.load(model_checkpoint, map_location="cpu")


odict_keys(['conv1.conv2d.weight', 'conv1.conv2d.bias', 'in1.weight', 'in1.bias', 'in1.running_mean', 'in1.running_var', 'conv2.conv2d.weight', 'conv2.conv2d.bias', 'in2.weight', 'in2.bias', 'in2.running_mean', 'in2.running_var', 'conv3.conv2d.weight', 'conv3.conv2d.bias', 'in3.weight', 'in3.bias', 'in3.running_mean', 'in3.running_var', 'res1.conv1.conv2d.weight', 'res1.conv1.conv2d.bias', 'res1.in1.weight', 'res1.in1.bias', 'res1.in1.running_mean', 'res1.in1.running_var', 'res1.conv2.conv2d.weight', 'res1.conv2.conv2d.bias', 'res1.in2.weight', 'res1.in2.bias', 'res1.in2.running_mean', 'res1.in2.running_var', 'res2.conv1.conv2d.weight', 'res2.conv1.conv2d.bias', 'res2.in1.weight', 'res2.in1.bias', 'res2.in1.running_mean', 'res2.in1.running_var', 'res2.conv2.conv2d.weight', 'res2.conv2.conv2d.bias', 'res2.in2.weight', 'res2.in2.bias', 'res2.in2.running_mean', 'res2.in2.running_var', 'res3.conv1.conv2d.weight', 'res3.conv1.conv2d.bias', 'res3.in1.weight', 'res3.in1.bias', 'res3.in1.runni

In [24]:
model_checkpoint = "./saved_models/rain_princess.pth"
# map_location="cpu" lets the file deserialize on a machine without CUDA even
# if its tensors were saved from a GPU.
state_dict = torch.load(model_checkpoint, map_location="cpu")

# Names of the normalization layers whose running statistics are dropped
# before the weights are loaded into the model: the five top-level layers
# first, then the two layers inside each of the five residual blocks.
norm_layers = [f"in{i}" for i in range(1, 6)]
norm_layers += [f"res{r}.in{i}" for i in range(1, 3) for r in range(1, 6)]

keys_to_del = [
    f"{layer}.{stat}"
    for layer in norm_layers
    for stat in ("running_mean", "running_var")
]

for k in keys_to_del:
    # pop() with a default instead of `del`: a checkpoint that does not carry
    # these entries (e.g. one that was already stripped) must not raise KeyError.
    state_dict.pop(k, None)

print(state_dict.keys())

In [25]:
# Instantiate the network and copy the cleaned checkpoint weights into it.
# The value displayed by the last line reports whether the keys matched.
model = TransformerNet()
model.load_state_dict(state_dict)

<All keys matched successfully>

In [26]:
# Switch the model to evaluation mode; the call's return value is displayed as the cell output.
model.eval()

TransformerNet(
  (conv1): ConvLayer(
    (reflection_pad): ReflectionPad2d((4, 4, 4, 4))
    (conv2d): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1))
  )
  (in1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv2): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
  )
  (in2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv3): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
  )
  (in3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (res1): ResidualBlock(
    (conv1): ConvLayer(
      (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
      (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    )
    (in1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (

In [31]:
# Dimensions of the dummy input used for the export (the values are arbitrary)
batch_size = 1
height = 224
width = 224

x = torch.randn(batch_size, 3, height, width, requires_grad=True)
output = model(x)

# Export the PyTorch model to ONNX format
torch.onnx.export(
    model,                         # network being exported
    x,                             # example input (or a tuple for multiple inputs)
    "./temp/rain_princess.onnx",   # destination (a path or a file-like object)
    input_names=['input'],         # name given to the graph input
    output_names=['output'],       # name given to the graph output
    # Variable-length axes: batch, height and width stay symbolic on both tensors
    dynamic_axes={
        tensor_name: {0: "batch_size", 2: "height", 3: "width"}
        for tensor_name in ('input', 'output')
    },
    opset_version=11,              # ONNX opset to target
    do_constant_folding=True,      # fold constants as an export-time optimization
    export_params=True,            # embed the trained weights in the model file
)


In [34]:
def stylize_onnx(model_path, content_image_path, output_image_path):
    """
    Read ONNX model and run it using onnxruntime

    Args:
        model_path: Path to the exported model; must end with ".onnx".
        content_image_path: Image to stylize, read with ``load_image``.
        output_image_path: Where the stylized result is written with
            ``save_image``.
    """
    assert model_path.endswith(".onnx"), f"expected a .onnx file, got {model_path!r}"

    # Since onnxruntime 1.9, builds with more than the CPU provider refuse to
    # create a session unless the providers are listed explicitly. Passing
    # every available provider keeps the previous "use what is there" behaviour.
    ort_session = onnxruntime.InferenceSession(
        model_path, providers=onnxruntime.get_available_providers()
    )

    content_image = load_image(content_image_path, scale=None)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        # ToTensor scales pixels to [0, 1]; multiply back up to the 0-255 range
        transforms.Lambda(lambda x: x.mul(255))
    ])
    # onnxruntime consumes NumPy arrays, so the tensor is converted directly
    # on the CPU rather than being moved to a GPU and straight back again.
    content_batch = content_transform(content_image).unsqueeze(0).numpy()

    ort_inputs = {ort_session.get_inputs()[0].name: content_batch}
    ort_outs = ort_session.run(None, ort_inputs)

    # First output, first batch element -> the single stylized image
    img_out_y = torch.from_numpy(ort_outs[0])
    save_image(output_image_path, img_out_y[0])

In [40]:
# Run the exported ONNX model on the test image and write the stylized result to ./temp/test_onnx.jpg.
stylize_onnx("./temp/rain_princess.onnx", "./temp/test.jpg", "./temp/test_onnx.jpg")