In [1]:
import torch

# Pull the AnimeGANv2 generator with the "paprika" style weights from Torch Hub
# (downloaded on first run, then reused from the local hub cache) and switch it
# to eval mode for inference.
model = torch.hub.load(
    "bryandlee/animegan2-pytorch",
    "generator",
    pretrained="paprika",
).eval()


Downloading: "https://github.com/bryandlee/animegan2-pytorch/zipball/main" to /home/bowserj/.cache/torch/hub/main.zip
Downloading: "https://github.com/bryandlee/animegan2-pytorch/raw/main/weights/paprika.pt" to /home/bowserj/.cache/torch/hub/checkpoints/paprika.pt


  0%|          | 0.00/8.20M [00:00<?, ?B/s]

In [2]:
# Show the generator's architecture: stacks of ConvNormLReLU blocks
# (ReflectionPad2d -> Conv2d -> GroupNorm -> LeakyReLU).
model

Generator(
  (block_a): Sequential(
    (0): ConvNormLReLU(
      (0): ReflectionPad2d((3, 3, 3, 3))
      (1): Conv2d(3, 32, kernel_size=(7, 7), stride=(1, 1), bias=False)
      (2): GroupNorm(1, 32, eps=1e-05, affine=True)
      (3): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (1): ConvNormLReLU(
      (0): ReflectionPad2d((0, 1, 0, 1))
      (1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (2): GroupNorm(1, 64, eps=1e-05, affine=True)
      (3): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (2): ConvNormLReLU(
      (0): ReflectionPad2d((1, 1, 1, 1))
      (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (2): GroupNorm(1, 64, eps=1e-05, affine=True)
      (3): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (block_b): Sequential(
    (0): ConvNormLReLU(
      (0): ReflectionPad2d((0, 1, 0, 1))
      (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (2): GroupNorm(1, 128, eps=1e-05, affi

In [3]:
from PIL import Image
import numpy as np

import torch
from torchvision.transforms.functional import to_tensor, to_pil_image

In [4]:
# Load the source photo and force 3-channel RGB.
# The context manager closes the underlying file handle deterministically;
# `convert` returns a new, fully loaded copy that stays valid after the `with` block.
with Image.open('packraft.png') as source_img:
    img = source_img.convert("RGB")

In [5]:
# PIL image (H x W x C, 8-bit) -> float tensor (C x H x W) with values in [0, 1].
image_input = to_tensor(pic=img)

In [6]:
# Map pixel values from [0, 1] to [-1, 1] and prepend a batch axis -> (1, C, H, W).
image_input = (2 * image_input - 1).unsqueeze(dim=0)

In [7]:
# Stylise the photo.
# Call the module itself rather than `.forward` so any registered hooks run.
# Disable autograd: this is pure inference, so there is no reason to keep a
# gradient graph for a full-resolution forward pass alive in the kernel.
with torch.no_grad():
    im_out = model(image_input)

In [8]:
# Undo the [-1, 1] scaling: drop the batch axis, clamp stray values, then map back to [0, 1].
out = im_out.squeeze(0).clamp(min=-1, max=1).mul(0.5).add(0.5)

In [9]:
# Turn the (C, H, W) float tensor in [0, 1] into a PIL image (rebinds `out`).
out = to_pil_image(pic=out)

In [10]:
# Display the stylised result for a quick visual check.
out.show()

In [11]:
import torch.onnx

# Random example input (batch 1, 3 channels, 512 x 512) used to trace the ONNX exports.
x = torch.randn(1, 3, 512, 512, requires_grad=True)

# Reference PyTorch output for the same input.
# NOTE(review): presumably kept for comparison against an ONNX runtime; it is not
# used in the visible cells.
# Computed without autograd so the kernel does not hold a whole backward graph in memory.
with torch.no_grad():
    torch_out = model(x)

In [12]:
# Export with open-ended batch size, height and width, so the ONNX graph accepts
# images of any resolution. The 512 x 512 example `x` only drives tracing.
torch.onnx.export(
    model,                      # generator being exported
    x,                          # example input used for tracing
    "animegan_paprika.onnx",    # destination file
    export_params=True,         # embed the trained weights in the file
    opset_version=14,           # target ONNX opset
    do_constant_folding=True,   # pre-compute constant subgraphs at export time
    input_names=['input'],
    output_names=['output'],
    # Axes 0, 2 and 3 (N, H, W) of both tensors are symbolic; the channel axis stays fixed.
    dynamic_axes={
        tensor_name: {0: 'batch_size', 2: 'height', 3: 'width'}
        for tensor_name in ('input', 'output')
    },
)

  if align_corners:
  if align_corners:
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [13]:
# Fixed-shape export: no `dynamic_axes`, so the graph is locked to the shape of the
# example input `x` (1 x 3 x 512 x 512).
# Tensor names match those of animegan_paprika.onnx ('input' / 'output'), so both
# files can be driven by the same client code.
# NOTE(review): "constraned" is a typo; the filename is kept as-is in case
# downstream code already loads this exact path.
torch.onnx.export(model,                                # model being run
                  x,                                    # model input (or a tuple for multiple inputs)
                  "animegan_paprika_constraned.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,                   # store the trained parameter weights inside the model file
                  opset_version=14,                     # the ONNX version to export the model to
                  do_constant_folding=True,             # whether to execute constant folding for optimization
                  input_names=['input'],                # same input name as the dynamic export
                  output_names=['output'],              # same output name as the dynamic export
                  )

In [14]:
from torch.utils.mobile_optimizer import optimize_for_mobile
import torch.backends._nnapi.prepare  # NOTE(review): not used in this cell

# TorchScript build for PyTorch Mobile (CPU backend):
# 1. trace the eager generator on a random 1 x 3 x 512 x 512 example,
# 2. run the mobile optimisation passes,
# 3. save the result.
dummy_input = torch.rand(1, 3, 512, 512)
torchscript_model = torch.jit.trace(model, example_inputs=dummy_input)
torchscript_model_optimized = optimize_for_mobile(torchscript_model, backend='CPU')
torch.jit.save(torchscript_model_optimized, f="animegan2.pt")

In [15]:
import copy

# Channels-last (NHWC memory layout) copy of the generator for the NHWC mobile export.
# `nn.Module.to` converts parameters in place and returns the same object, so
# converting `model` directly would silently change `model` too. Convert a deep
# copy instead and leave `model` untouched.
chan_last_model = copy.deepcopy(model).to(memory_format=torch.channels_last)

In [16]:
# NHWC model: trace the channels-last generator on a channels-last example input.
# In PyTorch, `channels_last` only changes the memory strides; the logical shape stays
# (N, C, H, W) = (1, 3, 512, 512).
# A (1, 512, 512, 3) tensor would not match the first Conv2d(3, 32, ...) and fails.
# Tracing with a plain contiguous NCHW input would not exercise the channels-last
# layout at all.
dummy_input_last = torch.rand(1, 3, 512, 512).contiguous(memory_format=torch.channels_last)
traced_script_module = torch.jit.trace(chan_last_model, dummy_input_last)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
torch.jit.save(traced_script_module_optimized, "animegan2_nhwc.pt")

In [17]:
# Vulkan (GPU) variants of both TorchScript models: the same mobile
# optimisation passes, this time targeting the Vulkan backend.

# NCHW: derived from the plain traced model.
nchw_script_module_optimized = optimize_for_mobile(
    torchscript_model,
    backend='vulkan',
)
torch.jit.save(nchw_script_module_optimized, f="animegan_vulkan_nchw.pt")

# NHWC: derived from the model traced from the channels-last generator.
traced_script_module_vulkan_optimized = optimize_for_mobile(
    traced_script_module,
    backend='vulkan',
)
torch.jit.save(traced_script_module_vulkan_optimized, f="animegan_vulkan_nhwc.pt")


