Skip to content

Commit

Permalink
Merge Midas v2.1 (#54)
Browse files Browse the repository at this point in the history
* Assign an array instead of a scalar to out if the range is less than eps.

* Added Mobile (Android, iOS) and ROS1

* Added small model

* Rename small model

* Update README.md

* Update Dockerfile

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* fix error on cpu execution

Co-authored-by: AlexeyAB <kikots@mail.ru>
Co-authored-by: Alexey <AlexeyAB@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 10, 2020
1 parent b00bf61 commit 9983656
Show file tree
Hide file tree
Showing 117 changed files with 11,500 additions and 89 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ COPY ./midas ./midas
COPY ./*.py ./

# download model weights so the docker image can be used offline
RUN curl -OL https://github.com/intel-isl/MiDaS/releases/download/v2/model-f46da743.pt
RUN curl -OL https://github.com/intel-isl/MiDaS/releases/download/v2_1/model-f6b98070.pt
RUN python3 run.py; exit 0

# entrypoint (dont forget to mount input and output directories)
Expand Down
47 changes: 44 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@ This repository contains code to compute depth from a single image. It accompani
>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer
René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun

The pre-trained model corresponds to `MIX 5` with multi-objective optimization enabled.
MiDaS v2.1 was trained on 10 datasets (ReDWeb, DIML, Movies, MegaDepth, WSVD, TartanAir, ApolloScape, BlendedMVS, IRS) with
multi-objective optimization enabled .
The original model that was trained on 5 datasets (`MIX 5` in the paper) can be found [here](https://github.com/intel-isl/MiDaS/releases/tag/v2)


### Changelog
* [Nov 2020] Released MiDaS v2.1:
- New model that was trained on 10 datasets and is on average about [10% more accurate](#Accuracy) than [MiDaS v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2)
- New light-weight model that achieves [real-time performance](https://github.com/intel-isl/MiDaS/tree/master/mobile) on mobile platforms.
- Sample applications for [iOS](https://github.com/intel-isl/MiDaS/tree/master/mobile/ios) and [Android](https://github.com/intel-isl/MiDaS/tree/master/mobile/android)
- [ROS package](https://github.com/intel-isl/MiDaS/tree/master/ros) for easy deployment on robots
* [Jul 2020] Added TensorFlow and ONNX code. Added [online demo](http://35.202.76.57/).
* [Dec 2019] Released new version of MiDaS - the new model is significantly more accurate and robust
* [Jul 2019] Initial release of MiDaS ([Link](https://github.com/intel-isl/MiDaS/releases/tag/v1))
Expand All @@ -20,7 +28,8 @@ Please be patient. Inference might take up to 30 seconds due to hardware restric

### Setup

1) Download the model weights [model-f45da743.pt](https://github.com/intel-isl/MiDaS/releases/download/v2/model-f46da743.pt) and place the
1) Download the model weights [model-f6b98070.pt](https://github.com/intel-isl/MiDaS/releases/download/v2_1/model-f6b98070.pt)
and [model-small-70d6b9c8.pt](https://github.com/intel-isl/MiDaS/releases/download/v2_1/model-small-70d6b9c8.pt) and place the
file in the root folder.

2) Set up dependencies:
Expand All @@ -29,7 +38,7 @@ file in the root folder.
conda install pytorch torchvision opencv
```

The code was tested with Python 3.7, PyTorch 1.2.0, and OpenCV 3.4.2.
The code was tested with Python 3.7, PyTorch 1.7.0, and OpenCV 4.4.0.


### Usage
Expand All @@ -42,6 +51,12 @@ file in the root folder.
python run.py
```

Or run the small model:

```shell
python run.py --model_weights model-small-70d6b9c8.pt --model_type small
```

3) The resulting inverse depth maps are written to the `output` folder.


Expand Down Expand Up @@ -73,6 +88,32 @@ The pretrained model is also available on [PyTorch Hub](https://pytorch.org/hub/

See [README](https://github.com/intel-isl/MiDaS/tree/master/tf) in the `tf` subdirectory.

#### via Mobile (iOS / Android)

See [README](https://github.com/intel-isl/MiDaS/tree/master/mobile) in the `mobile` subdirectory.

#### via ROS1 (Robot Operating System)

See [README](https://github.com/intel-isl/MiDaS/tree/master/ros) in the `ros` subdirectory.


### Accuracy

Zero-shot error (the lower - the better) and speed (FPS):

| Model | DIW, WHDR | Eth3d, AbsRel | Sintel, AbsRel | Kitti, δ>1.25 | NyuDepthV2, δ>1.25 | TUM, δ>1.25 | Speed, FPS |
|---|---|---|---|---|---|---|---|
| **Small models:** | | | | | | | iPhone 11 |
| MiDaS v2 small | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 | 0.6 |
| MiDaS v2.1 small [URL](https://github.com/intel-isl/MiDaS/releases/download/v2_1/model-small-70d6b9c8.pt) | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | **14.53** | 30 |
| Relative improvement | -7.7% | **+13.3%** | -2.1% | -34.2% | **+14.6%** | **+14.5%** | **50x** |
| | | | | | | |
| **Big models:** | | | | | | | GPU RTX 2080Ti |
| MiDaS v2 large [URL](https://github.com/intel-isl/MiDaS/releases/download/v2/model-f46da743.pt) | **0.1246** | 0.1290 | **0.3270** | 23.90 | 9.55 | 14.29 | 59 |
| MiDaS v2.1 large [URL](https://github.com/intel-isl/MiDaS/releases/download/v2_1/model-f6b98070.pt) | 0.1295 | **0.1155** | 0.3285 | **16.08** | **8.71** | **12.51** | 59 |
| Relative improvement | -3.9% | **+10.5%** | -0.52% | **+32.7%** | **+8.8%** | **+12.5%** | 1x |


### Citation

Please cite our paper if you use this code or any of the models:
Expand Down
22 changes: 21 additions & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch

from midas.midas_net import MidasNet
from midas.midas_net_custom import MidasNet_small


def MiDaS(pretrained=True, **kwargs):
Expand All @@ -15,7 +16,26 @@ def MiDaS(pretrained=True, **kwargs):

if pretrained:
checkpoint = (
"https://github.com/intel-isl/MiDaS/releases/download/v2/model-f46da743.pt"
"https://github.com/intel-isl/MiDaS/releases/download/v2_1/model-f6b98070.pt"
)
state_dict = torch.hub.load_state_dict_from_url(
checkpoint, progress=True, check_hash=True
)
model.load_state_dict(state_dict)

return model

def MiDaS_small(pretrained=True, **kwargs):
""" # This docstring shows up in hub.help()
MiDaS model for monocular depth estimation
pretrained (bool): load pretrained weights into model
"""

model = MidasNet_small(None, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True})

if pretrained:
checkpoint = (
"https://github.com/intel-isl/MiDaS/releases/download/v2_1/model-small-70d6b9c8.pt"
)
state_dict = torch.hub.load_state_dict_from_url(
checkpoint, progress=True, check_hash=True
Expand Down
2 changes: 1 addition & 1 deletion midas/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def load(self, path):
Args:
path (str): file path
"""
parameters = torch.load(path)
parameters = torch.load(path, map_location=torch.device('cpu'))

if "optimizer" in parameters:
parameters = parameters["model"]
Expand Down
200 changes: 179 additions & 21 deletions midas/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,72 @@
import torch.nn as nn


def _make_encoder(features, use_pretrained):
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
scratch = _make_scratch([256, 512, 1024, 2048], features)

def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True):
if backbone == "resnext101_wsl":
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
elif backbone == "efficientnet_lite3":
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
else:
print(f"Backbone '{backbone}' not implemented")
assert False

return pretrained, scratch


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()

out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand==True:
out_shape1 = out_shape
out_shape2 = out_shape*2
out_shape3 = out_shape*4
out_shape4 = out_shape*8

scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)

return scratch


def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
efficientnet = torch.hub.load(
"rwightman/gen-efficientnet-pytorch",
"tf_efficientnet_lite3",
pretrained=use_pretrained,
exportable=exportable
)
return _make_efficientnet_backbone(efficientnet)


def _make_efficientnet_backbone(effnet):
pretrained = nn.Module()

pretrained.layer1 = nn.Sequential(
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
)
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])

return pretrained


def _make_resnet_backbone(resnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
Expand All @@ -27,23 +86,6 @@ def _make_pretrained_resnext101_wsl(use_pretrained):
return _make_resnet_backbone(resnet)


def _make_scratch(in_shape, out_shape):
scratch = nn.Module()

scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape, kernel_size=3, stride=1, padding=1, bias=False
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape, kernel_size=3, stride=1, padding=1, bias=False
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape, kernel_size=3, stride=1, padding=1, bias=False
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape, kernel_size=3, stride=1, padding=1, bias=False
)
return scratch


class Interpolate(nn.Module):
"""Interpolation module.
Expand Down Expand Up @@ -151,3 +193,119 @@ def forward(self, *xs):
)

return output




class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module.
"""

def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()

self.bn = bn

self.groups=1

self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)

self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)

if self.bn==True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)

self.activation = activation

self.skip_add = nn.quantized.FloatFunctional()

def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""

out = self.activation(x)
out = self.conv1(out)
if self.bn==True:
out = self.bn1(out)

out = self.activation(out)
out = self.conv2(out)
if self.bn==True:
out = self.bn2(out)

if self.groups > 1:
out = self.conv_merge(out)

return self.skip_add.add(out, x)

# return out + x


class FeatureFusionBlock_custom(nn.Module):
"""Feature fusion block.
"""

def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock_custom, self).__init__()

self.deconv = deconv
self.align_corners = align_corners

self.groups=1

self.expand = expand
out_features = features
if self.expand==True:
out_features = features//2

self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

self.skip_add = nn.quantized.FloatFunctional()

def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]

if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
# output += res

output = self.resConfUnit2(output)

output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
)

output = self.out_conv(output)

return output

2 changes: 1 addition & 1 deletion midas/midas_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, path=None, features=256, non_negative=True):

use_pretrained = False if path is None else True

self.pretrained, self.scratch = _make_encoder(features, use_pretrained)
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)

self.scratch.refinenet4 = FeatureFusionBlock(features)
self.scratch.refinenet3 = FeatureFusionBlock(features)
Expand Down
Loading

0 comments on commit 9983656

Please sign in to comment.