Results on Kitti dataset are not reproducible #31

Closed
RuslanOm opened this issue Jun 8, 2021 · 18 comments

Comments

@RuslanOm

RuslanOm commented Jun 8, 2021

Hi! Thanks again for your work!

Recently, I tried to check the accuracy of the pre-trained models on KITTI (Eigen split) and found that it differs from the results reported in the paper.

[Screenshot 2021-06-08 at 13.08.50]

In this screenshot you can see the standard depth-prediction metrics on the Eigen split (I took the split files from this repo). For ground truth I used the raw Velodyne data (with a loader like this one).

I hope you can explain these results. Thanks!

@ranftlr
Contributor

ranftlr commented Jun 8, 2021

That is not enough information for me. What does your evaluation script look like?

@RuslanOm
Author

RuslanOm commented Jun 8, 2021

Of course. I used this script to get gt_depth from the Velodyne data in .npy format for every image in the Eigen split. The kitti_data directory should contain the raw KITTI dataset with its usual file structure. After running the script, you will have the Eigen split data in a simple form:

.
+-- data
|   +-- kitti_eigen
|   |    +-- groundtruth
|   |    +-- images

The groundtruth directory will contain the ground truth as .npy files, and the images directory the corresponding .png files.

I used this script for evaluation. dpt_large.npy should contain an ndarray with the predictions for the images in kitti_eigen/images.

@ranftlr
Contributor

ranftlr commented Jun 8, 2021

@RuslanOm Judging from the code, it seems that you are trying to reproduce "DPT-Large" in Table 1. If this is not the case, please tell me exactly which model you are using and which number in the paper you are trying to reproduce. If that is the case, then your code is missing a proper alignment step that accounts for both scale and shift in the appropriate way. The alignment you do in your code is different from ours and also doesn't account for the shift. See isl-org/MiDaS#29 for a link to the evaluation code for NYUv2, specifically BadPixelMetric. We use the same metric for the KITTI results in Table 1.
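For readers following along, here is a minimal NumPy sketch of a per-image least-squares scale-and-shift fit in the spirit of the BadPixelMetric code referenced above. It is a single-image adaptation written for illustration, not the repository's implementation, and the name `align_scale_shift` is hypothetical. It minimizes the sum of (s * p + t - y)^2 over valid pixels by solving the 2 x 2 normal equations in closed form. The target is 1 / ground truth because the fit is assumed to happen in inverse-depth space.

```python
import numpy as np


def align_scale_shift(prediction, target, mask):
    """Least-squares fit of scale s and shift t so that s * prediction + t ~= target.

    Hypothetical single-image helper; only pixels where `mask` is True are used.
    Returns (s, t), or (0.0, 0.0) if the normal equations are singular.
    """
    p = prediction[mask].astype(np.float64)
    y = target[mask].astype(np.float64)

    # Normal equations: [[sum p^2, sum p], [sum p, n]] @ [s, t] = [sum p*y, sum y]
    a_00 = np.sum(p * p)
    a_01 = np.sum(p)
    a_11 = float(p.size)
    b_0 = np.sum(p * y)
    b_1 = np.sum(y)

    det = a_00 * a_11 - a_01 * a_01
    if det <= 0:
        return 0.0, 0.0  # degenerate: no valid pixels or a constant prediction
    s = (a_11 * b_0 - a_01 * b_1) / det
    t = (-a_01 * b_0 + a_00 * b_1) / det
    return s, t


# Usage: align a synthetic inverse-depth prediction to 1 / metric ground truth.
rng = np.random.default_rng(0)
gt_depth = rng.uniform(2.0, 80.0, size=(4, 5))
pred_disparity = (1.0 / gt_depth - 0.01) / 0.5  # built so that s = 0.5, t = 0.01
valid = gt_depth > 0
s, t = align_scale_shift(pred_disparity, 1.0 / gt_depth, valid)
print(s, t)
```

A median-ratio rescaling, by contrast, can only correct a scale factor and leaves any shift in the prediction untouched.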

@cheniynan: Can you post a screenshot of the results and/or code on how you are using the model?

@revisitq

revisitq commented Jun 8, 2021

Thanks for your reply. Sorry, I can't share the code, but I can show you the results. Our work uses depth maps to improve the performance of monocular 3D detection. The results using depth maps extracted by DORN are as follows:
[Image: detection results with DORN depth maps]
The results using depth maps extracted by the dpt-hybrid-kitti model you provided are as follows:
[Image: detection results with dpt-hybrid-kitti depth maps]
As you can see, the performance with depth maps generated by dpt-hybrid-kitti is far below that with DORN, which suggests that dpt-hybrid-kitti may be less accurate than DORN.
The RMSE of the depth maps generated by DORN is as follows:
[Image: DORN depth metrics]
The RMSE of the depth maps generated by dpt-hybrid-kitti is as follows:
[Image: dpt-hybrid-kitti depth metrics]
The ground-truth depth maps are generated from Velodyne data. Here is the evaluation code:

import os
import os.path as osp

import cv2
import numpy as np
from tqdm.auto import tqdm


def compute_errors(gt, pred):
    """Computation of error metrics between predicted and ground truth depths
    """
    thresh = np.maximum((gt / pred), (pred / gt))
    a1 = (thresh < 1.25     ).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()

    rmse = (gt - pred) ** 2
    rmse = np.sqrt(rmse.mean())

    rmse_log = (np.log(gt) - np.log(pred)) ** 2
    rmse_log = np.sqrt(rmse_log.mean())

    abs_rel = np.mean(np.abs(gt - pred) / gt)

    sq_rel = np.mean(((gt - pred) ** 2) / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3


def evaluate(opt):
    """Evaluates stored depth predictions against ground-truth depth maps.

    If opt == "gt", each prediction is rescaled by the per-image median ratio
    gt / pred (scale only, no shift); otherwise predictions are used as-is.
    Both PNG folders are read as 16-bit maps with depth = value / 256.
    """
    MIN_DEPTH = 1e-3
    MAX_DEPTH = 80

    base_path = '/home/chenyinan/Projects/github/DPT/data/kitti/training'
    gt_path = osp.join(base_path, 'depth_map')
    pred_path = osp.join(base_path, 'depth_2_dpt_crop')
    preds = sorted(os.listdir(pred_path))
    ratios = []
    errors = []
    for name in tqdm(preds):
        gt_depth = cv2.imread(osp.join(gt_path, name), cv2.IMREAD_UNCHANGED).astype('float') / 256
        pred_depth = cv2.imread(osp.join(pred_path, name), cv2.IMREAD_UNCHANGED).astype('float') / 256

        # Evaluate only pixels that have ground truth (0 = no LiDAR point)
        mask = gt_depth > 0
        if not mask.any():
            continue  # avoid taking the median of an empty array

        pred_depth = pred_depth[mask]
        gt_depth = gt_depth[mask]

        ratio = np.median(gt_depth / pred_depth) if opt == "gt" else 1
        pred_depth *= ratio
        ratios.append(ratio)

        # Clamp predictions to the evaluated depth range
        pred_depth = np.clip(pred_depth, MIN_DEPTH, MAX_DEPTH)
        errors.append(compute_errors(gt_depth, pred_depth))

    ratios = np.array(ratios)
    med = np.median(ratios)
    print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(med, np.std(ratios / med)))

    mean_errors = np.array(errors).mean(0)

    print("\n  " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
    print(("&{: 8.3f}  " * 7).format(*mean_errors.tolist()) + "\\\\")
    print("\n-> Done!")


if __name__ == '__main__':
    st = 'gt'
    evaluate(st)
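
The script above divides both PNG images by 256, which assumes the KITTI depth-benchmark convention: a 16-bit value v encodes v / 256 meters, and 0 marks pixels without data. A small sketch of that decoding on a NumPy array (so it does not depend on how the file was read) is shown below; `decode_kitti_depth` is a hypothetical helper name.

```python
import numpy as np


def decode_kitti_depth(png_uint16):
    """Convert a 16-bit KITTI-style depth image to meters plus a validity mask.

    Hypothetical helper: a value v > 0 means v / 256 meters, v == 0 means no data.
    """
    raw = np.asarray(png_uint16)
    if raw.dtype != np.uint16:
        raise ValueError(f"expected a uint16 depth image, got {raw.dtype}")
    depth = raw.astype(np.float64) / 256.0
    valid = raw > 0
    return depth, valid


raw = np.array([[0, 256], [5120, 65535]], dtype=np.uint16)
depth, valid = decode_kitti_depth(raw)
print(depth)
print(valid)
```

If the prediction PNGs were written with a different convention (for example raw meters cast to uint16, or values min-max normalized to the full 16-bit range), dividing by 256 yields wrong depths. Per-image median scaling can compensate for a global scale error, but not for precision that was lost when the file was written.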
 

@RuslanOm
Author

RuslanOm commented Jun 8, 2021

Thanks! One more question: if I want to get absolute depth from the model's prediction (for example, using the camera height, like the DGC module from this paper), should I compute both the scale and the shift, or just a scale obtained by some other method (like the paper I mentioned)?

@ranftlr
Contributor

ranftlr commented Jun 8, 2021

@RuslanOm I haven't read this paper, but the output p of the relative depth models is related to the true depth d by d = 1 / (s*p + t), where s is the scale and t is the shift. You need to determine both if you want to recover absolute metric depth, and they vary per image.

@cheniynan The depth evaluation numbers are quite different from what we got (both with our method and with DORN). Are you using a specific subset of KITTI for evaluation? Could you perhaps post only the part of the code that shows the image transform, the pass through DPT, and any post-processing, so that I can try to reproduce your evaluation? Without that it is pretty much impossible to debug your issue.
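As a small illustration of the relation d = 1 / (s*p + t): once s and t are known for an image (for example from a least-squares fit against sparse metric measurements), metric depth follows element-wise. The helper name and the `max_depth` floor below are assumptions for illustration; flooring the aligned disparity at 1 / max_depth avoids dividing by zero or by negative values and caps the output depth.

```python
import numpy as np


def relative_to_metric_depth(p, s, t, max_depth=80.0):
    """Map a relative inverse-depth prediction p to metric depth d = 1 / (s * p + t).

    Hypothetical helper; s and t have to be estimated separately for every image.
    The aligned disparity is floored at 1 / max_depth, so depth never exceeds max_depth.
    """
    disparity = s * np.asarray(p, dtype=np.float64) + t
    disparity = np.maximum(disparity, 1.0 / max_depth)
    return 1.0 / disparity


p = np.array([0.0, 0.02, 0.2, -1.0])
d = relative_to_metric_depth(p, s=0.5, t=0.01)
print(d)
```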

@revisitq
Copy link

revisitq commented Jun 8, 2021

The dataset I use is the KITTI 3D object detection dataset, which contains 7481 images for training and validation and 7518 images for testing. The training set is split into 3712 images for training and 3769 images for validation. The workflow of our model is as follows:
1) Extract a depth map from each image, using a method like DORN or yours.
2) Normalize the depth map and the image; the means and standard deviations are computed over the whole training set (7481 images).
3) Feed the depth map and the image into our detection model to get the detection results.
The RMSE values mentioned above are computed on the training set.
Do the scale and shift factors need to be fine-tuned when I use DPT to extract depth maps from the KITTI object detection dataset? I'm not sure.
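Step 2 above (normalizing with statistics computed over the whole training set) can be sketched as a single streaming pass that accumulates per-channel sums and sums of squares. This is an illustrative sketch, not the poster's code: `dataset_mean_std` is a hypothetical name, and every array is assumed to be H x W x C (a depth map would be H x W x 1). Accumulating in float64 limits the rounding error of the sum-of-squares formula.

```python
import numpy as np


def dataset_mean_std(arrays):
    """Per-channel mean and standard deviation over an iterable of H x W x C arrays.

    Hypothetical helper: one pass, accumulating sums and sums of squares in float64.
    """
    total = None
    total_sq = None
    count = 0
    for a in arrays:
        a = np.asarray(a, dtype=np.float64)
        flat = a.reshape(-1, a.shape[-1])  # (pixels, channels)
        if total is None:
            total = np.zeros(flat.shape[1])
            total_sq = np.zeros(flat.shape[1])
        total += flat.sum(axis=0)
        total_sq += (flat ** 2).sum(axis=0)
        count += flat.shape[0]
    if count == 0:
        raise ValueError("no data")
    mean = total / count
    var = np.maximum(total_sq / count - mean ** 2, 0.0)  # guard tiny negative rounding
    return mean, np.sqrt(var)


imgs = [np.full((2, 2, 3), 1.0), np.full((2, 2, 3), 3.0)]
mean, std = dataset_mean_std(imgs)
print(mean, std)
```

Each sample would then be normalized as (x - mean) / std, using the same training-set statistics at validation and test time.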

@ranftlr
Contributor

ranftlr commented Jun 8, 2021

The KITTI model doesn't have a missing scale or shift, so this should be fine. Let me ask you a couple of specific questions:

  • How do you exactly do step 1 with our method? Are you using our script? If not, can you show me how you are using our model (the code)?
  • How do you write out the png files that you seem to be storing the results as? Are you sure that you did that correctly?
  • Did you consider the different input image normalization that is required for DPT models?
  • Can you give me one input image and one corresponding result of DPT that you are getting so that I can try to reproduce this without having to download and preprocess a complete dataset?

@AlexeyAB
Contributor

AlexeyAB commented Jun 8, 2021

@ranftlr @RuslanOm @cheniynan Hi, I fixed this issue; see this PR for how to reproduce our results: #32

@revisitq

revisitq commented Jun 9, 2021

@ranftlr

  • I just used python run_monodepth.py -i data/kitti/training/image_2/ -o ./crop_train -t dpt_hybrid_kitti --kitti_crop to extract the depth maps.
  • I modified write_depth as follows to store PNG files:
def write_depth(path, depth, bits=1):
    """Write depth map to pfm and png file.

    Args:
        path (str): filepath without extension
        depth (array): depth
    """
    # write_pfm(path + ".pfm", depth.astype(np.float32))

    depth_min = depth.min()
    depth_max = depth.max()

    max_val = (2 ** (8 * bits)) - 1

    if depth_max - depth_min > np.finfo("float").eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        assert True, 'zeros'  # note: `assert True` can never fail, so this line has no effect
        out = np.zeros(depth.shape, dtype=depth.dtype)
    out = depth  # note: this overwrites the normalized result computed above with the raw depth
    if bits == 1:
        cv2.imwrite(path + ".png", out.astype("uint8"))
    elif bits == 2:
        # note: if depth holds meters, this cast keeps only whole meters
        cv2.imwrite(path + ".png", out.astype("uint16"))
    return
  • I didn't modify the image normalization parameters for dpt_hybrid_kitti; I just followed yours.
  • One input image and the corresponding DPT output are here:
    [Image: depth_000029]
    [Image: img_000029]
    One important thing: I didn't set cv2.IMWRITE_PNG_COMPRESSION to 0, so the PNG files were compressed when saved. I then set cv2.IMWRITE_PNG_COMPRESSION to 0, extracted the depth maps again with python run_monodepth.py -i data/kitti/training/image_2/ -o ./crop_train -t dpt_hybrid_kitti --kitti_crop, and re-evaluated their accuracy, but I still got the same result.
    [Image: evaluation results]
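
Why the stored format matters here, as a hedged sketch: PNG compression is lossless, so the compression level should not change the stored values, but the `out.astype("uint16")` cast in the modified write_depth does. If `depth` holds metric depth in meters (an assumption about the model output at that point), the cast truncates every value to whole meters. One common way to keep sub-meter precision is the KITTI convention of storing depth * 256 as uint16, which gives a resolution of 1/256 m (about 4 mm) and a maximum of about 256 m. The sketch works on NumPy arrays only; `encode_kitti_depth` is a hypothetical name, and writing the file (for example with cv2.imwrite) is left out.

```python
import numpy as np


def encode_kitti_depth(depth_m):
    """Encode metric depth (meters) as uint16 using the KITTI convention v = depth * 256.

    Hypothetical helper: values are rounded to the nearest 1/256 m and clipped to the
    uint16 range; non-finite or non-positive depths become 0 ("no data").
    """
    d = np.asarray(depth_m, dtype=np.float64)
    d = np.where(np.isfinite(d) & (d > 0), d, 0.0)
    return np.clip(np.round(d * 256.0), 0, 65535).astype(np.uint16)


depth = np.array([[1.3, 12.75], [47.9, 0.0]])
naive = depth.astype(np.uint16)               # a plain cast keeps whole meters only
encoded = encode_kitti_depth(depth)
decoded = encoded.astype(np.float64) / 256.0  # what an evaluation script dividing by 256 sees
print(naive)
print(decoded)
```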

@revisitq

revisitq commented Jun 9, 2021

I evaluated the accuracy on the KITTI depth prediction dataset and got the same result as in your paper. However, the accuracy of DPT on the KITTI object detection dataset is worse than that of DORN, so maybe I should look for another model that gives better depth maps for object detection. Anyway, thanks for your great work!

@RuslanOm
Author

RuslanOm commented Jun 9, 2021

@AlexeyAB Thanks! Could you please explain why we need to use these crops for KITTI (or other crops)?

@ranftlr
Contributor

ranftlr commented Jun 9, 2021

@cheniynan Seems like you are normalizing the output to [0, 1]. Not sure if this influences your results since you seem to be able to reproduce our numbers on the Eigen split, but we've added a flag "--absolute_depth" that should write results in the format that you need.

@RuslanOm Crops are usually necessary to match the input image size to the requirements of the network. For example, DPT requires the input image dimensions to be multiples of 32. For KITTI and NYUv2 we use crops that have been standard for evaluating models on these datasets in prior work to ensure that the results are comparable.
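A minimal sketch of the size constraint: round each dimension down to a multiple of 32, which corresponds to cropping. The helper name is hypothetical, and DPT's own preprocessing may resize rather than crop. For a typical KITTI frame of 375 x 1242 pixels this yields 352 x 1216 (352 = 11 * 32, 1216 = 38 * 32), the same size as the commonly used KITTI crop discussed later in this thread.

```python
def crop_to_multiple(height, width, multiple=32):
    """Largest (h, w) not exceeding (height, width) with both divisible by `multiple`.

    Hypothetical helper for illustration.
    """
    if height < multiple or width < multiple:
        raise ValueError("image is smaller than one block")
    return (height // multiple) * multiple, (width // multiple) * multiple


print(crop_to_multiple(375, 1242))  # → (352, 1216)
```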

@AlexeyAB
Contributor

AlexeyAB commented Jun 9, 2021

@RuslanOm

We need to use --do_kb_crop --garg_crop for KITTI, and --eigen_crop for NYU, when evaluating with eval_with_pngs.py. This is the standard evaluation approach proposed by David Eigen et al. and Ravi Garg et al., and it allows a fair comparison with other state-of-the-art works.
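The two KITTI options can be sketched as follows, assuming the conventions of the BTS evaluation script: the KB crop places a 352 x 1216 prediction at the bottom center of the full-size frame (zeros elsewhere), and the Garg crop restricts evaluation to a fixed fractional region of the ground truth. The function names are hypothetical, and the fractions and depth range should be checked against eval_with_pngs.py rather than taken as its code.

```python
import numpy as np


def kb_uncrop(pred_crop, full_height, full_width, crop_h=352, crop_w=1216):
    """Place a bottom-center crop prediction back into a full-size array (zeros elsewhere)."""
    top = full_height - crop_h
    left = (full_width - crop_w) // 2
    full = np.zeros((full_height, full_width), dtype=np.float32)
    full[top:top + crop_h, left:left + crop_w] = pred_crop
    return full


def garg_crop_mask(gt_height, gt_width):
    """Boolean mask of the Garg et al. evaluation region (fractional bounds)."""
    mask = np.zeros((gt_height, gt_width), dtype=bool)
    mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
         int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = True
    return mask


def eval_mask(gt_depth, min_depth=1e-3, max_depth=80.0):
    """Pixels with valid ground truth inside the depth range and inside the Garg crop."""
    valid = (gt_depth > min_depth) & (gt_depth < max_depth)
    return valid & garg_crop_mask(*gt_depth.shape)


gt = np.full((375, 1242), 10.0)
pred = kb_uncrop(np.ones((352, 1216), dtype=np.float32), 375, 1242)
m = eval_mask(gt)
print(pred.shape, int(m.sum()))
```

Metrics would then be computed on pred[m] and gt[m] only.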

In the PR I show how to match the results from our paper (only for NYU is the result slightly better than in the paper: AbsRel=0.109 vs. AbsRel=0.110 in the paper): #32

We use the same evaluation approach as in the papers by Ravi Garg et al., David Eigen et al. and Jin Han Lee et al.: https://github.com/cogaplex-bts/bts/tree/master/pytorch#bts


Ravi Garg et al. Unsupervised CNN for Single View Depth Estimation: Geometry to the Rescue: https://arxiv.org/pdf/1603.04992.pdf

For fair comparison with state-of-the-art single view depth prediction, we evaluate our results on the same cropped region of interest as [8].

David Eigen et al. Depth Map Prediction from a Single Image using a Multi-Scale Deep Network: https://papers.nips.cc/paper/2014/file/7bccfde7714a1ebadf06c5f4cea752c1-Paper.pdf

Depth is only provided within the bottom part of the RGB image, however we feed the entire image into our model to provide additional context to the global coarse-scale network (the fine network sees the bottom crop corresponding to the target area).

@AlexeyAB
Contributor

AlexeyAB commented Jun 9, 2021

@cheniynan

Try the new write_depth() function with the flag absolute_depth=True (the PR has been merged):
util.io.write_depth(filename, prediction, bits=2, absolute_depth=True)

Or use this command:
python run_monodepth.py --model_type dpt_hybrid_kitti --kitti_crop --absolute_depth


https://github.com/intel-isl/DPT/blob/b71c02b3cda1823e703cad2dae8b005b13e90590/util/io.py#L171-L198

@RuslanOm
Author

RuslanOm commented Jun 9, 2021

@ranftlr @AlexeyAB I'm sorry to bother you, but I tried the same steps with dpt_large and got poor results. I understand that dpt_large returns inverse depth. I also understand that I need to compute the scale and shift (they are not the same as for the hybrid KITTI model, are they?). I used the compute_scale_and_shift procedure for this, but I get a negative scale and a strange shift (I passed the model's prediction, the ground truth, and the mask to the function). Could you please give me some advice on how to deal with this?
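A plausible explanation for the negative scale, sketched under the assumption that the prediction is inverse depth: fitting an inverse-depth prediction directly against metric depth means fitting a decreasing relationship with a straight line, so the least-squares slope comes out negative. Fitting against 1 / gt instead (inverse-depth space, as in the BadPixelMetric code referenced earlier in this thread) recovers a positive scale. The example below uses synthetic data with made-up s and t and uses np.polyfit for the line fit.

```python
import numpy as np

rng = np.random.default_rng(0)
gt_depth = rng.uniform(3.0, 80.0, size=1000)  # synthetic metric depths in meters
true_s, true_t = 2e-3, 1e-2                   # made-up per-image scale and shift
pred = (1.0 / gt_depth - true_t) / true_s     # relative inverse-depth "prediction"

# Fitting the inverse-depth prediction to metric depth: the slope is negative.
s_depth, t_depth = np.polyfit(pred, gt_depth, deg=1)

# Fitting in inverse-depth space (target = 1 / gt) recovers the positive scale.
s_disp, t_disp = np.polyfit(pred, 1.0 / gt_depth, deg=1)

print(f"fit to depth:     s = {s_depth:.4f}, t = {t_depth:.4f}")
print(f"fit to 1 / depth: s = {s_disp:.6f}, t = {t_disp:.6f}")
```

With s and t fitted in inverse-depth space, metric depth is then 1 / (s * pred + t).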

@revisitq

@AlexeyAB OK, thanks!

@ranftlr
Contributor

ranftlr commented Jun 10, 2021

@RuslanOm Here is a gist that reproduces the evaluation https://gist.github.com/ranftlr/45f4c7ddeb1bbb88d606bc600cab6c8d

Result with DPT-large:
[Image: evaluation result with DPT-Large]

See here for the subset of 161 images that we use for evaluation: isl-org/MiDaS#18

Please pull my latest changes to get exactly this result, as I found a small discrepancy in the interpolation module which led to a slightly higher error (8.47% instead of 8.46%).
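For reference, a percentage like the one quoted above can be sketched as the share of valid pixels whose ratio max(pred / gt, gt / pred) exceeds 1.25, computed after alignment and with predicted depth capped at 80 m. This is an illustrative NumPy sketch in the spirit of BadPixelMetric, not the gist itself; the function name and arguments are assumptions, and the input is assumed to be already-aligned inverse depth (s * p + t).

```python
import numpy as np


def bad_pixel_percentage(aligned_disparity, gt_depth, mask, threshold=1.25, depth_cap=80.0):
    """Percentage of masked pixels with max(pred / gt, gt / pred) > threshold.

    Hypothetical helper. Disparity is floored at 1 / depth_cap, so predicted depth
    never exceeds depth_cap.
    """
    disparity = np.maximum(aligned_disparity, 1.0 / depth_cap)
    pred_depth = 1.0 / disparity
    gt = gt_depth[mask]
    pred = pred_depth[mask]
    if gt.size == 0:
        raise ValueError("no valid pixels")
    ratio = np.maximum(pred / gt, gt / pred)
    return 100.0 * np.mean(ratio > threshold)


gt = np.array([10.0, 20.0, 40.0, 0.0])
aligned = np.array([1 / 10.0, 1 / 30.0, 1 / 41.0, 1.0])  # 30 m predicted where gt is 20 m
mask = gt > 0
result = bad_pixel_percentage(aligned, gt, mask)
print(result)
```

Only the 30 m vs. 20 m pixel (ratio 1.5) counts as bad here; the last pixel is excluded by the mask.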

@ranftlr ranftlr closed this as completed Jul 2, 2021