In [None]:
# Move the kernel's working directory up one level. Not idempotent: every
# re-run climbs one more level, so run this once per kernel session.
%cd ..

In [None]:
# Display the current working directory to confirm where the `%cd` left the kernel.
%pwd

In [None]:
API_KEY = ""
!mkdir -p .ivy
!echo -n $API_KEY > .ivy/key.pem
%cd ..

In [None]:
import jax
import haiku as hk
import ivy
import torch
import requests
import numpy as np
from PIL import Image
import traceback

import torchvision
from mmpretrain import get_model, list_models, inference_model
from mmengine import ConfigDict

# Enable 64-bit dtypes in JAX; without this flag JAX keeps float64/int64 values
# at 32-bit precision. NOTE(review): presumably set so the transpiled Haiku
# model's numerics stay close to PyTorch's in the allclose check — confirm.
jax.config.update("jax_enable_x64", True)

In [None]:
def get_scale(cfg):
  """Depth-first search of a config tree for the first transform ``scale``.

  Args:
    cfg: A config node: a mapping (e.g. an mmengine ``ConfigDict``), a
      list/tuple of nodes (e.g. a data pipeline), or any leaf value.

  Returns:
    The ``scale`` value of the first mapping that has both a truthy ``type``
    and a truthy ``scale`` entry, visiting mapping entries in key order and
    sequence items in index order; ``None`` if there is no such mapping.
  """
  # isinstance instead of an exact type() match: ConfigDict is a dict
  # subclass, and nested nodes may be plain dicts or tuples depending on how
  # the config was built.
  if isinstance(cfg, dict):
    if cfg.get('type', False) and cfg.get('scale', False):
      return cfg['scale']
    children = (cfg[key] for key in cfg)
  elif isinstance(cfg, (list, tuple)):
    children = cfg
  else:
    return None

  for child in children:
    scale = get_scale(child)
    if scale:
      return scale
  return None

In [None]:
# Architecture keys already handed to the sweep loop. Only reset when this
# cell is re-run, so it persists across re-runs of the loop cell.
tested_model_archs = []
def to_test(model_name):
    """Return True the first time an architecture family is seen, else False.

    The family key is the part of ``model_name`` before the first ``"-base"``
    when that substring occurs, otherwise the part before the first ``"-"``
    (the whole name if there is none). Newly seen keys are appended to the
    module-level ``tested_model_archs`` list.
    """
    marker = "-base" if "-base" in model_name else "-"
    arch = model_name.partition(marker)[0]
    if arch in tested_model_archs:
        return False
    tested_model_archs.append(arch)
    return True

In [None]:
# Sweep mmpretrain models from index `start_from` on: transpile the first model
# of each architecture family to Haiku and compare its output with PyTorch's.
start_from = 35

INIT_IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
EVAL_IMAGE_URL = "http://images.cocodataset.org/train2017/000000283921.jpg"


def fetch_image(url, timeout=30):
  """Download ``url`` and return it as an RGB ``PIL.Image.Image``.

  Raises a ``requests`` exception on a network error, on a timeout after
  ``timeout`` seconds, or on a non-2xx HTTP status.
  """
  response = requests.get(url, stream=True, timeout=timeout)
  response.raise_for_status()
  # Image.open is lazy; convert() decodes now (while the stream is open) and
  # guarantees three channels even for grayscale JPEGs.
  return Image.open(response.raw).convert("RGB")


def square_side(scale):
  """Return the side length if ``scale`` describes a square input, else None.

  Accepts an int ``s`` or a 2-element tuple/list ``(s, s)``. Non-square pairs
  are rejected because the (w, h) vs (h, w) ordering of a config ``scale``
  cannot be assumed here.
  """
  if type(scale) is int:
    return scale
  if (isinstance(scale, (tuple, list)) and len(scale) == 2
      and type(scale[0]) is int and scale[0] == scale[1]):
    return scale[0]
  return None


# Fetched once and shared by every model instead of re-downloading per model.
init_image = fetch_image(INIT_IMAGE_URL)  # traces the transpilation / initialises params
eval_image = fetch_image(EVAL_IMAGE_URL)  # separate image for the torch-vs-jax comparison
rng_key = jax.random.PRNGKey(42)

for i, model_name in enumerate(list_models()[start_from:], start=start_from):
  print(f'testing {model_name} -> {i}')
  if not to_test(model_name):
    print('skipped because this arch already tested before')
    continue
  try:
    model = get_model(model_name, pretrained=True)
  except Exception:
    print(f'model was skipped due to {traceback.format_exc()}')
    continue

  # Skip (like a load failure) instead of asserting, so one unusual config
  # does not abort the whole sweep; asserts also vanish under `python -O`.
  input_shape = square_side(get_scale(model._config.train_pipeline))
  if input_shape is None:
    print('model was skipped because a square input shape was not detected')
    continue
  transform = torchvision.transforms.Compose([
      torchvision.transforms.Resize((input_shape, input_shape)),
      torchvision.transforms.ToTensor()
  ])

  # A fresh input tensor is built for each consumer below. NOTE(review):
  # presumably guards against transpile/compile tracing keeping or altering
  # its input — confirm before sharing a single tensor.
  print('transpiling..')
  transpiled_graph = ivy.transpile(
      model, to="haiku", args=(transform(init_image).unsqueeze(0),))

  def _f(args):
    return model(args)

  comp_model = torch.compile(_f)
  _ = comp_model(transform(init_image).unsqueeze(0))  # warm-up call triggers compilation

  def _forward(args):
    module = transpiled_graph()
    return module(args)

  # JIT the *transformed* apply. Wrapping `_forward` in jax.jit before
  # hk.transform traces through Haiku's implicit state: params are read from
  # Haiku's frame rather than passed as jit arguments, so they would be baked
  # in as trace-time constants.
  jax_mlp_forward = hk.transform(_forward)
  init_np = transform(init_image).unsqueeze(0).detach().cpu().numpy()
  params = jax_mlp_forward.init(rng=rng_key, args=init_np)
  jax_apply = jax.jit(jax_mlp_forward.apply)

  tensor_image = transform(eval_image).unsqueeze(0)
  np_image = tensor_image.detach().cpu().numpy()
  out_torch = comp_model(tensor_image)
  out_jax = jax_apply(params, None, np_image)

  if isinstance(out_torch, torch.Tensor):
    print(np.allclose(out_torch.detach().cpu().numpy(), out_jax, atol=1e-3))
  else:
    print('Fancy output detected. Verify manually depending on output structure')
  del model, transpiled_graph, comp_model, params, jax_mlp_forward, jax_apply
  del tensor_image, np_image, init_np
  del out_torch, out_jax
  # Every iteration's `_f` shares one code object, so Dynamo would otherwise
  # recompile it for each new `model` and fall back to eager once its
  # recompile limit is hit; resetting also drops compiled graphs that may keep
  # the previous model alive.
  torch._dynamo.reset()