<a href="https://colab.research.google.com/github/Doji-Technologies/com.doji.midas/blob/master/tools/MiDaS_ONNX_Export.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Environment Setup
# Clone the MiDaS fork and work from inside it: the export cell imports
# `midas` and `utils` relative to this checkout.
!git clone https://github.com/semjon00/MiDaS
%cd MiDaS
# `git submodule add` must run inside the MiDaS git repository, so it has to
# come after the `%cd`; run from the parent directory it fails with
# "not a git repository" and the Next-ViT backbone code is never fetched.
!git submodule add https://github.com/isl-org/Next-ViT midas/external/next_vit
# Pinned dependency versions; %pip installs into the kernel's own environment.
%pip install torch==1.13 torchvision==0.14.0
%pip install timm==0.6.13
%pip install einops==0.6.0

In [None]:
# @title Export all models to ONNX
import cv2
import torch
import utils
from midas.dpt_depth import DPTDepthModel
from midas.midas_net_custom import MidasNet_small
from midas.midas_net import MidasNet
import os
import requests
import gc

def download_file(url, folder_path, chunk_size=1 << 20, timeout=60):
    """Download ``url`` into ``folder_path`` unless a file of the same name is already there.

    The local file name is the last ``/``-separated component of ``url``. The
    response body is streamed to ``<file>.part`` in ``chunk_size`` pieces and only
    renamed to its final name once fully written. This keeps large checkpoint
    files out of memory. It also ensures an interrupted download is never mistaken
    for a finished one by the "already downloaded" check on the next run.

    Args:
        url: HTTP(S) URL of the file to fetch.
        folder_path: Directory to save into; created if it does not exist.
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        The local file path if the file exists after the call, otherwise ``None``.
        A non-200 response is reported with ``print`` rather than raised.
    """
    os.makedirs(folder_path, exist_ok=True)

    file_name = url.split("/")[-1]
    file_path = os.path.join(folder_path, file_name)

    if os.path.exists(file_path):
        print(f"File already downloaded: {file_path}")
        return file_path

    tmp_path = file_path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Failed to download the file. HTTP status code: {response.status_code}")
            return None
        try:
            with open(tmp_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip empty keep-alive chunks
                        file.write(chunk)
        except BaseException:
            # Never leave a truncated partial file behind.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    # Atomic rename: the final name only ever refers to a complete download.
    os.replace(tmp_path, file_path)
    print(f"File downloaded and saved to: {file_path}")
    return file_path


def patchUnflatten():
    """Globally replace ``torch.nn.Unflatten`` with a minimal ``Tensor.view``-based module.

    The replacement is assigned onto the ``torch.nn`` module itself. Any code that
    looks up ``nn.Unflatten`` after this call gets the replacement, while modules
    created earlier keep the original class. Presumably this exists to make the
    ONNX export trace a plain reshape instead of ``nn.Unflatten``; TODO confirm
    which exporter issue it works around.

    Calling this more than once is harmless: the patch is installed only once.
    """
    import torch.nn as nn

    if getattr(nn.Unflatten, "_is_view_patch", False):
        return

    class View(nn.Module):
        """Stand-in for ``nn.Unflatten(dim, shape)``: splits dimension ``dim`` into ``shape``."""

        _is_view_patch = True

        def __init__(self, dim, shape):
            super().__init__()
            self.dim = dim
            self.shape = shape

        def forward(self, input):
            sizes = list(input.shape)
            # nn.Unflatten accepts negative dims. Normalize them, otherwise
            # dim=-1 makes sizes[dim + 1:] == sizes[0:] and duplicates the shape.
            dim = self.dim if self.dim >= 0 else self.dim + len(sizes)
            new_shape = sizes[:dim] + list(self.shape) + sizes[dim + 1:]
            return input.view(*new_shape)

    nn.Unflatten = View

# Every MiDaS checkpoint exported by this notebook, listed as
# (name, backbone identifier, GitHub release tag). Each row expands to a dict
# holding the local checkpoint path and the download URL, both derived from the
# model name and release tag.
model_params = [
    {
        "name": name,
        "path": f"weights/{name}.pt",
        "backbone": backbone,
        "url": f"https://github.com/isl-org/MiDaS/releases/download/{release}/{name}.pt",
    }
    for name, backbone, release in (
        ("dpt_beit_large_512", "beitl16_512", "v3_1"),
        ("dpt_beit_large_384", "beitl16_384", "v3_1"),
        ("dpt_beit_base_384", "beitb16_384", "v3_1"),
        ("dpt_swin2_large_384", "swin2l24_384", "v3_1"),
        ("dpt_swin2_base_384", "swin2b24_384", "v3_1"),
        ("dpt_swin2_tiny_256", "swin2t16_256", "v3_1"),
        ("dpt_swin_large_384", "swinl12_384", "v3_1"),
        ("dpt_next_vit_large_384", "next_vit_large_6m", "v3_1"),
        ("dpt_levit_224", "levit_384", "v3_1"),
        ("dpt_large_384", "vitl16_384", "v3"),
        # The two v2.1 models are built without reading the "backbone" field,
        # so it is left empty for them.
        ("midas_v21_384", "", "v2_1"),
        ("midas_v21_small_256", "", "v2_1"),
    )
]

# Export input resolution (width, height) per model: the resolution encoded in
# the numeric suffix of the model name. Models not listed use 384x384.
# NOTE: dpt_beit_large_512 used to fall through to 384x384. Delete an existing
# weights/dpt_beit_large_512.onnx to regenerate it at its native 512x512.
EXPORT_INPUT_SIZES = {
    "dpt_beit_large_512": (512, 512),
    "dpt_swin2_tiny_256": (256, 256),
    "midas_v21_small_256": (256, 256),
    "dpt_levit_224": (224, 224),
}
DEFAULT_EXPORT_INPUT_SIZE = (384, 384)


def _build_model(model_param):
    """Construct the network for one ``model_params`` entry, passing it the checkpoint path.

    LeViT and the two MiDaS v2.1 models need non-default constructor arguments;
    every other entry is a plain ``DPTDepthModel`` using the entry's backbone.
    """
    name = model_param["name"]
    path = model_param["path"]
    if name == "dpt_levit_224":
        return DPTDepthModel(
            path=path,
            backbone=model_param["backbone"],
            non_negative=True,
            head_features_1=64,
            head_features_2=8,
        )
    if name == "midas_v21_384":
        return MidasNet(path, non_negative=True)
    if name == "midas_v21_small_256":
        return MidasNet_small(
            path,
            features=64,
            backbone="efficientnet_lite3",
            exportable=True,
            non_negative=True,
            blocks={'expand': True})
    return DPTDepthModel(
        path=path,
        backbone=model_param["backbone"],
        non_negative=True,
    )


# Install the nn.Unflatten patch once, before any model is constructed or traced.
patchUnflatten()

for model_param in reversed(model_params):
    name = model_param["name"]
    onnx_file = "weights/" + name + ".onnx"
    if os.path.exists(onnx_file):
        print(f"ONNX model for {name} already exists. Skipping...")
        continue

    download_file(model_param["url"], "weights")
    if not os.path.exists(model_param["path"]):
        # Download failed; skip this model rather than crash the whole export run.
        print(f"Weights for {name} not found at {model_param['path']}. Skipping...")
        continue

    model = _build_model(model_param)
    model.eval()

    # The graph is traced with a fixed-size random input of shape
    # (1, 3, net_h, net_w). No resizing or normalization is baked into the
    # exported model, so consumers must apply the model's preprocessing themselves.
    net_w, net_h = EXPORT_INPUT_SIZES.get(name, DEFAULT_EXPORT_INPUT_SIZE)
    dummy_input = torch.rand(1, 3, net_h, net_w, dtype=torch.float)
    torch.onnx.export(model, dummy_input, onnx_file, export_params=True)

    # Release the model before building the next one; checkpoints can be large.
    del model
    gc.collect()


ONNX model for midas_v21_small_256 already exists. Skipping...
ONNX model for midas_v21_384 already exists. Skipping...
ONNX model for dpt_large_384 already exists. Skipping...
ONNX model for dpt_levit_224 already exists. Skipping...
ONNX model for dpt_next_vit_large_384 already exists. Skipping...
ONNX model for dpt_swin_large_384 already exists. Skipping...
ONNX model for dpt_swin2_tiny_256 already exists. Skipping...
ONNX model for dpt_swin2_base_384 already exists. Skipping...
ONNX model for dpt_swin2_large_384 already exists. Skipping...
ONNX model for dpt_beit_base_384 already exists. Skipping...
ONNX model for dpt_beit_large_384 already exists. Skipping...
File already downloaded: weights/dpt_beit_large_512.pt


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  window_size = tuple(np.array(resolution) // 16)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [None]:
# @title Copy to Google Drive
# google.colab is provided by the Colab runtime itself, so no pip install is needed.
from google.colab import drive
drive.mount('/content/drive')

import os
import shutil

# Folder the export cell writes its files into.
source_folder = 'weights/'

# Destination folder in Google Drive (created if it does not exist).
target_folder = '/content/drive/MyDrive/MiDaS_Models/'

# Only files with this extension are copied.
file_extension = '.onnx'

os.makedirs(target_folder, exist_ok=True)

# Names already present in Drive. A file is copied only if its name is not in
# this set; the set is updated as we copy, so a name seen twice is copied once.
existing_files = set(os.listdir(target_folder))

copied_count = 0
for root, dirs, files in os.walk(source_folder):
    for file in files:
        if not file.endswith(file_extension) or file in existing_files:
            continue
        shutil.copy(os.path.join(root, file), os.path.join(target_folder, file))
        existing_files.add(file)
        copied_count += 1
        print(f'Copied {file} to Google Drive.')

print(f'Copied {copied_count} new {file_extension} file(s) to Google Drive; '
      f'files already present there were left untouched.')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Copied dpt_beit_large_512.onnx to Google Drive.
Copied dpt_beit_large_384.onnx to Google Drive.
Copied dpt_beit_base_384.onnx to Google Drive.
Copied all .onnx files to Google Drive.
