<a href="https://colab.research.google.com/github/eyaler/LordTubeMaster/blob/main/models/teed/teed2onnx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab workaround: report UTF-8 as the preferred encoding (presumably to avoid the
# "A UTF-8 locale is required" error some Colab runtimes raise for shell/pip commands).
import locale
locale.getpreferredencoding = lambda: 'UTF-8'

%cd /content
# Fork of TEED providing `ted.py` and the BIPED checkpoint used by the export cell.
!git clone --depth=1 https://github.com/eyaler/TEED
# Pinned dependencies, installed one command at a time in this order.
!pip install kornia==0.7.3
!pip install onnx==1.16.2
!pip install numpy==1.26.4
# CUDA 12 build of ONNX Runtime GPU, pulled from the extra package index below.
!pip install onnxruntime-gpu==1.18.0 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
# Used for the fp16 conversion; --no-deps installs only this package, not its requirements.
!pip install onnxconverter-common==1.14.0 --no-deps
# Test image read by the benchmark cells; -nc (no-clobber) is meant to avoid re-downloading.
!wget -nc https://upload.wikimedia.org/wikipedia/commons/a/a9/Hong_Kong_Night_view.jpg -O TEED/data/image.jpg

In [None]:
#@title Export to ONNX (with pre/post-processing)

# Work from the repo root: `ted` is imported from here and the checkpoint path in MyTED is relative.
%cd /content/TEED


import torch
import torch.nn as nn
import torch.nn.functional as F

from ted import TED


# Device used to load the checkpoint and to place both the model and the tracing input.
device = 'cuda'


class MyTED(nn.Module):
    """TED wrapped with the pre/post-processing that gets baked into the ONNX graph.

    Input: an NHWC image batch (exported below with a uint8 tensor of shape (1, H, W, 3)).
    Output: the TED edge map cropped back to (H, W), passed through a sigmoid and
    min-max rescaled to the [0, 255] range (float).
    """

    def __init__(self):
        super().__init__()
        # Pretrained BIPED weights from the cloned repo (path is relative to /content/TEED).
        state = torch.load('checkpoints/BIPED/7/7_model.pth', map_location=device)
        self.ted_model = TED()
        self.ted_model.load_state_dict(state)

    def forward(self, x):
        # NHWC integer image -> NCHW float32, the layout the conv net consumes.
        nchw = x.permute(0, 3, 1, 2).to(torch.float32)
        height, width = nchw.shape[2:]
        # Zero-pad right/bottom up to the next multiple of 4 strictly above each side;
        # a side that is already divisible by 4 still receives 4 extra pixels.
        extra_w = (width // 4 + 1) * 4 - width
        extra_h = (height // 4 + 1) * 4 - height
        padded = F.pad(nchw, (0, extra_w, 0, extra_h))
        # Keep TED's last output (named block_cat in the original code), drop singleton
        # dims and crop the padding away before squashing to probabilities.
        fused = self.ted_model(padded)[-1]
        probs = torch.sigmoid(fused.squeeze()[:height, :width])
        # Min-max stretch to [0, 255]; the epsilon guards against a constant map.
        lo = probs.min()
        span = probs.max() - lo + 1e-12
        return (probs - lo) * 255 / span


# torch.onnx.export already traces in eval mode by default; setting it here makes that explicit.
model = MyTED().to(device).eval()

# Tracing input only: its dtype (uint8), NHWC layout and batch/channel sizes matter, its values do not.
# torch.randint's upper bound is exclusive, so 256 covers the full uint8 range 0..255.
dummy_input = torch.randint(0, 256, (1, 1080, 1920, 3), dtype=torch.uint8, device=device)

torch.onnx.export(model,
                  (dummy_input,),  # explicit 1-tuple of positional args for forward()
                  'teed.onnx',
                  input_names=['input'],
                  output_names=['output'],
                  # Height/width stay symbolic so any frame size can be fed at runtime;
                  # axes not listed here (batch and channels) are fixed to the dummy's sizes.
                  dynamic_axes=dict(input={1: 'height', 2: 'width'}, output={0: 'height', 1: 'width'}),
                 )


import onnx
from onnxconverter_common import float16


# Reload the exported fp32 graph and write a half-precision copy (used by the CUDA and WebGPU cells).
# Distinct names keep the PyTorch `model` from being shadowed by an ONNX ModelProto.
onnx_model = onnx.load('teed.onnx')
# keep_io_types=True leaves the graph inputs/outputs in their original dtypes, so callers still
# pass the same uint8 frames and read float32 results; only the internals are converted to fp16.
onnx_model16 = float16.convert_float_to_float16(onnx_model, keep_io_types=True)
onnx.save_model(onnx_model16, 'teed16.onnx')

In [None]:
#@title ONNX Runtime Python

%cd /content/TEED


from time import time

import cv2
from google.colab.patches import cv2_imshow
import numpy as np
import onnxruntime as ort

bench_iters = 100
image_path = 'data/image.jpg'
image = cv2.imread(image_path, cv2.IMREAD_COLOR)[..., ::-1]
# image = cv2.resize(image, (1917, 1077))  # Odd dimensions for debugging
image = image[None, :]

ort.set_default_logger_severity(0)
ort_session = ort.InferenceSession('teed16.onnx', providers=['CUDAExecutionProvider'])

bench_iters = max(bench_iters, 2)
for i in range(bench_iters):
  outputs = ort_session.run(None, {'input': image})
  output = np.dstack(outputs[:1] * 3)
  if not i:
    start_time = time()
print(f'{outputs[0].shape[1]}x{outputs[0].shape[0]} {(time()-start_time) * 1000 / (bench_iters-1) :.0f}ms/iter')

cv2.imwrite('out.jpg', output)
cv2_imshow(output)

In [None]:
!ln -sf /usr/local/share/jupyter/nbextensions /nbextensions
!cp /content/TEED/teed.onnx /nbextensions/teed.onnx
!cp /content/TEED/teed16.onnx /nbextensions/teed16.onnx
!cp /content/TEED/data/image.jpg /nbextensions/image.jpg

In [None]:
#@title ONNX Runtime Wasm

%%html
<p id="stats"></p>
<canvas id="canvas" style="width: 100%"></canvas>
<script type="module">
  // Benchmark the fp32 model on ONNX Runtime Web's Wasm (CPU) backend, then draw the edge map.
  import * as ort from 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.18.0/dist/esm/ort.wasm-core.min.js'
  // Load the .wasm binaries from the same CDN release as the JS bundle.
  ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.18.0/dist/'

  // Look the elements up explicitly instead of relying on implicit window[id] globals.
  const stats = document.getElementById('stats')
  const canvas = document.getElementById('canvas')

  const teed = await ort.InferenceSession.create('/nbextensions/teed.onnx')

  let bench_iters = 10

  const ctx = canvas.getContext('2d')
  const img = new Image()
  img.crossOrigin = 'anonymous'

  // Run the model bench_iters times on the loaded image. The first run is a warm-up excluded
  // from the reported ms/iter; pre/post-processing is included in the timed loop.
  async function predict() {
    const w = canvas.width = img.width
    const h = canvas.height = img.height
    ctx.drawImage(img, 0, 0)
    const rgba = ctx.getImageData(0, 0, w, h).data
    // Copy keeps the source alpha; the RGB channels are overwritten with the edge map below.
    const rgba_out = rgba.slice()
    ctx.clearRect(0, 0, w, h)
    let start_time
    bench_iters = Math.max(bench_iters, 2)
    for (let n = 0; n < bench_iters; n++) {
      // Canvas RGBA -> packed BGR uint8 (NHWC), matching the exported 'input'.
      const bgr = new Uint8Array(h * w * 3)
      for (let i = 0; i < bgr.length; i++)
        bgr[i] = rgba[(i/3|0)*4 + 2 - i%3]
      const result = await teed.run({input: new ort.Tensor(bgr, [1, h, w, 3])})
      // Grayscale map written to R, G and B; the clamped pixel array rounds/clamps the floats.
      for (let i = 0; i < result.output.data.length; i++)
        rgba_out[i * 4] = rgba_out[i*4 + 1] = rgba_out[i*4 + 2] = result.output.data[i]
      if (!n)
        start_time = performance.now()
    }
    stats.textContent = w + 'x' + h + ' ' + ((performance.now()-start_time)/(bench_iters-1)|0) + 'ms/iter'
    ctx.putImageData(new ImageData(rgba_out, w, h), 0, 0)
  }

  img.addEventListener('load', () => predict())
  img.src = '/nbextensions/image.jpg'
</script>

In [None]:
#@title ONNX Runtime WebGPU

%%html
<p id="stats"></p>
<canvas id="canvas" style="width: 100%"></canvas>
<script type="module">
  // Benchmark the fp16 model on ONNX Runtime Web's WebGPU backend, then draw the edge map.
  import * as ort from 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.18.0/dist/esm/ort.webgpu.min.js'
  // Load the .wasm binaries from the same CDN release as the JS bundle.
  ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.18.0/dist/'

  // Look the elements up explicitly instead of relying on implicit window[id] globals.
  const stats = document.getElementById('stats')
  const canvas = document.getElementById('canvas')

  let teed
  try {
    teed = await ort.InferenceSession.create('/nbextensions/teed16.onnx', {executionProviders: ['webgpu']})
  } catch (e) {
    // A non-Error throwable has no .message; stringify it so this check cannot itself
    // throw and mask the original failure.
    const message = e instanceof Error ? e.message : String(e)
    if (message.includes('webgpu'))
      stats.textContent = 'WebGPU not supported.'
    throw e
  }

  let bench_iters = 100

  const ctx = canvas.getContext('2d')
  const img = new Image()
  img.crossOrigin = 'anonymous'

  // Run the model bench_iters times on the loaded image. The first run is a warm-up excluded
  // from the reported ms/iter; pre/post-processing is included in the timed loop.
  async function predict() {
    const w = canvas.width = img.width
    const h = canvas.height = img.height
    ctx.drawImage(img, 0, 0)
    const rgba = ctx.getImageData(0, 0, w, h).data
    // Copy keeps the source alpha; the RGB channels are overwritten with the edge map below.
    const rgba_out = rgba.slice()
    ctx.clearRect(0, 0, w, h)
    let start_time
    bench_iters = Math.max(bench_iters, 2)
    for (let n = 0; n < bench_iters; n++) {
      // Canvas RGBA -> packed BGR uint8 (NHWC), matching the exported 'input'.
      const bgr = new Uint8Array(h * w * 3)
      for (let i = 0; i < bgr.length; i++)
        bgr[i] = rgba[(i/3|0)*4 + 2 - i%3]
      const result = await teed.run({input: new ort.Tensor(bgr, [1, h, w, 3])})
      // Grayscale map written to R, G and B; the clamped pixel array rounds/clamps the floats.
      for (let i = 0; i < result.output.data.length; i++)
        rgba_out[i * 4] = rgba_out[i*4 + 1] = rgba_out[i*4 + 2] = result.output.data[i]
      if (!n)
        start_time = performance.now()
    }
    stats.textContent = w + 'x' + h + ' ' + ((performance.now()-start_time)/(bench_iters-1)|0) + 'ms/iter'
    ctx.putImageData(new ImageData(rgba_out, w, h), 0, 0)
  }

  img.addEventListener('load', () => predict())
  img.src = '/nbextensions/image.jpg'
</script>