From 7730714a24a5b070a706280dea29cfa371af17cf Mon Sep 17 00:00:00 2001 From: xiaodan Date: Mon, 8 Jan 2024 11:53:09 +0800 Subject: [PATCH] add self supervised depth completion. --- modelscope/metainfo.py | 8 +- .../__init__.py | 21 + .../criteria.py | 90 +++ .../dataloaders/__init__.py | 0 .../dataloaders/calib_cam_to_cam.txt | 34 + .../dataloaders/kitti_loader.py | 330 ++++++++++ .../dataloaders/pose_estimator.py | 104 +++ .../dataloaders/transforms.py | 618 ++++++++++++++++++ .../helper.py | 264 ++++++++ .../inverse_warp.py | 138 ++++ .../metrics.py | 164 +++++ .../self_supervised_depth_completion/model.py | 211 ++++++ .../self_supervised_depth_completion.py | 396 +++++++++++ .../vis_utils.py | 113 ++++ modelscope/outputs/outputs.py | 1 + modelscope/pipelines/cv/__init__.py | 4 + ...lf_supervised_depth_completion_pipeline.py | 36 + modelscope/utils/constant.py | 2 + modelscope/utils/pipeline_schema.json | 14 +- .../test_self_supervised_depth_completion.py | 46 ++ 20 files changed, 2592 insertions(+), 2 deletions(-) create mode 100644 modelscope/models/cv/self_supervised_depth_completion/__init__.py create mode 100644 modelscope/models/cv/self_supervised_depth_completion/criteria.py create mode 100644 modelscope/models/cv/self_supervised_depth_completion/dataloaders/__init__.py create mode 100644 modelscope/models/cv/self_supervised_depth_completion/dataloaders/calib_cam_to_cam.txt create mode 100644 modelscope/models/cv/self_supervised_depth_completion/dataloaders/kitti_loader.py create mode 100644 modelscope/models/cv/self_supervised_depth_completion/dataloaders/pose_estimator.py create mode 100644 modelscope/models/cv/self_supervised_depth_completion/dataloaders/transforms.py create mode 100644 modelscope/models/cv/self_supervised_depth_completion/helper.py create mode 100644 modelscope/models/cv/self_supervised_depth_completion/inverse_warp.py create mode 100644 modelscope/models/cv/self_supervised_depth_completion/metrics.py create mode 100644 modelscope/models/cv/self_supervised_depth_completion/model.py create mode 100644 modelscope/models/cv/self_supervised_depth_completion/self_supervised_depth_completion.py create mode 100644 modelscope/models/cv/self_supervised_depth_completion/vis_utils.py create mode 100644 modelscope/pipelines/cv/self_supervised_depth_completion_pipeline.py create mode 100644 tests/pipelines/test_self_supervised_depth_completion.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index d7487f849..62883ce58 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -129,6 +129,7 @@ class Models(object): image_control_3d_portrait = 'image-control-3d-portrait' rife = 'rife' anydoor = 'anydoor' + self_supervised_depth_completion = 'self-supervised-depth-completion' # nlp models bert = 'bert' @@ -460,6 +461,7 @@ class Pipelines(object): rife_video_frame_interpolation = 'rife-video-frame-interpolation' anydoor = 'anydoor' image_to_3d = 'image-to-3d' + self_supervised_depth_completion = 'self-supervised-depth-completion' # nlp tasks automatic_post_editing = 'automatic-post-editing' @@ -941,7 +943,10 @@ class Pipelines(object): 'damo/cv_image-view-transform'), Tasks.image_control_3d_portrait: ( Pipelines.image_control_3d_portrait, - 'damo/cv_vit_image-control-3d-portrait-synthesis') + 'damo/cv_vit_image-control-3d-portrait-synthesis'), + Tasks.self_supervised_depth_completion: ( + Pipelines.self_supervised_depth_completion, + 'damo/self-supervised-depth-completion') } @@ -964,6 +969,7 @@ class CVTrainers(object): nerf_recon_4k = 'nerf-recon-4k' 
action_detection = 'action-detection' vision_efficient_tuning = 'vision-efficient-tuning' + self_supervised_depth_completion = 'self-supervised-depth-completion' class NLPTrainers(object): diff --git a/modelscope/models/cv/self_supervised_depth_completion/__init__.py b/modelscope/models/cv/self_supervised_depth_completion/__init__.py new file mode 100644 index 000000000..7b6dc05ad --- /dev/null +++ b/modelscope/models/cv/self_supervised_depth_completion/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .self_supervised_depth_completion import SelfSupervisedDepthCompletion +else: + _import_structure = { + 'selfsuperviseddepthcompletion': ['SelfSupervisedDepthCompletion'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) \ No newline at end of file diff --git a/modelscope/models/cv/self_supervised_depth_completion/criteria.py b/modelscope/models/cv/self_supervised_depth_completion/criteria.py new file mode 100644 index 000000000..d0071b80c --- /dev/null +++ b/modelscope/models/cv/self_supervised_depth_completion/criteria.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +from modelscope.utils.logger import get_logger +logger = get_logger() + +loss_names = ['l1', 'l2'] + + +class MaskedMSELoss(nn.Module): + def __init__(self): + super(MaskedMSELoss, self).__init__() + + def forward(self, pred, target): + assert pred.dim() == target.dim(), "inconsistent dimensions" + valid_mask = (target > 0).detach() + diff = target - pred + diff = diff[valid_mask] + self.loss = (diff**2).mean() + return self.loss + + +class MaskedL1Loss(nn.Module): + def __init__(self): + super(MaskedL1Loss, self).__init__() + + def forward(self, pred, target, weight=None): + assert pred.dim() == target.dim(), "inconsistent dimensions" + valid_mask = (target > 0).detach() + diff = target - pred + diff = diff[valid_mask] + self.loss = diff.abs().mean() + return self.loss + + +class PhotometricLoss(nn.Module): + def __init__(self): + super(PhotometricLoss, self).__init__() + + def forward(self, target, recon, mask=None): + + assert recon.dim( + ) == 4, "expected recon dimension to be 4, but instead got {}.".format( + recon.dim()) + assert target.dim( + ) == 4, "expected target dimension to be 4, but instead got {}.".format( + target.dim()) + assert recon.size() == target.size(), "expected recon and target to have the same size, but got {} and {} instead"\ + .format(recon.size(), target.size()) + diff = (target - recon).abs() + diff = torch.sum(diff, 1) # sum along the color channel + + # compare only pixels that are not black + valid_mask = (torch.sum(recon, 1) > 0).float() * (torch.sum(target, 1) + > 0).float() + if mask is not None: + valid_mask = valid_mask * torch.squeeze(mask).float() + valid_mask = valid_mask.byte().detach() + if valid_mask.numel() > 0: + diff = diff[valid_mask] + if diff.nelement() > 0: + self.loss = diff.mean() + else: + logger.info( + "warning: diff.nelement()==0 in PhotometricLoss (this is expected during early stage of training, try larger batch size)." 
+ ) + self.loss = 0 + else: + logger.info("warning: 0 valid pixel in PhotometricLoss") + self.loss = 0 + return self.loss + + +class SmoothnessLoss(nn.Module): + def __init__(self): + super(SmoothnessLoss, self).__init__() + + def forward(self, depth): + def second_derivative(x): + assert x.dim( + ) == 4, "expected 4-dimensional data, but instead got {}".format( + x.dim()) + horizontal = 2 * x[:, :, 1:-1, 1:-1] - x[:, :, 1:-1, : + -2] - x[:, :, 1:-1, 2:] + vertical = 2 * x[:, :, 1:-1, 1:-1] - x[:, :, :-2, 1: + -1] - x[:, :, 2:, 1:-1] + der_2nd = horizontal.abs() + vertical.abs() + return der_2nd.mean() + + self.loss = second_derivative(depth) + return self.loss diff --git a/modelscope/models/cv/self_supervised_depth_completion/dataloaders/__init__.py b/modelscope/models/cv/self_supervised_depth_completion/dataloaders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modelscope/models/cv/self_supervised_depth_completion/dataloaders/calib_cam_to_cam.txt b/modelscope/models/cv/self_supervised_depth_completion/dataloaders/calib_cam_to_cam.txt new file mode 100644 index 000000000..04a75a0db --- /dev/null +++ b/modelscope/models/cv/self_supervised_depth_completion/dataloaders/calib_cam_to_cam.txt @@ -0,0 +1,34 @@ +calib_time: 09-Jan-2012 13:57:47 +corner_dist: 9.950000e-02 +S_00: 1.392000e+03 5.120000e+02 +K_00: 9.842439e+02 0.000000e+00 6.900000e+02 0.000000e+00 9.808141e+02 2.331966e+02 0.000000e+00 0.000000e+00 1.000000e+00 +D_00: -3.728755e-01 2.037299e-01 2.219027e-03 1.383707e-03 -7.233722e-02 +R_00: 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 +T_00: 2.573699e-16 -1.059758e-16 1.614870e-16 +S_rect_00: 1.242000e+03 3.750000e+02 +R_rect_00: 9.999239e-01 9.837760e-03 -7.445048e-03 -9.869795e-03 9.999421e-01 -4.278459e-03 7.402527e-03 4.351614e-03 9.999631e-01 +P_rect_00: 7.215377e+02 0.000000e+00 6.095593e+02 0.000000e+00 0.000000e+00 7.215377e+02 1.728540e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 +S_01: 1.392000e+03 5.120000e+02 +K_01: 9.895267e+02 0.000000e+00 7.020000e+02 0.000000e+00 9.878386e+02 2.455590e+02 0.000000e+00 0.000000e+00 1.000000e+00 +D_01: -3.644661e-01 1.790019e-01 1.148107e-03 -6.298563e-04 -5.314062e-02 +R_01: 9.993513e-01 1.860866e-02 -3.083487e-02 -1.887662e-02 9.997863e-01 -8.421873e-03 3.067156e-02 8.998467e-03 9.994890e-01 +T_01: -5.370000e-01 4.822061e-03 -1.252488e-02 +S_rect_01: 1.242000e+03 3.750000e+02 +R_rect_01: 9.996878e-01 -8.976826e-03 2.331651e-02 8.876121e-03 9.999508e-01 4.418952e-03 -2.335503e-02 -4.210612e-03 9.997184e-01 +P_rect_01: 7.215377e+02 0.000000e+00 6.095593e+02 -3.875744e+02 0.000000e+00 7.215377e+02 1.728540e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 +S_02: 1.392000e+03 5.120000e+02 +K_02: 9.597910e+02 0.000000e+00 6.960217e+02 0.000000e+00 9.569251e+02 2.241806e+02 0.000000e+00 0.000000e+00 1.000000e+00 +D_02: -3.691481e-01 1.968681e-01 1.353473e-03 5.677587e-04 -6.770705e-02 +R_02: 9.999758e-01 -5.267463e-03 -4.552439e-03 5.251945e-03 9.999804e-01 -3.413835e-03 4.570332e-03 3.389843e-03 9.999838e-01 +T_02: 5.956621e-02 2.900141e-04 2.577209e-03 +S_rect_02: 1.242000e+03 3.750000e+02 +R_rect_02: 9.998817e-01 1.511453e-02 -2.841595e-03 -1.511724e-02 9.998853e-01 -9.338510e-04 2.827154e-03 9.766976e-04 9.999955e-01 +P_rect_02: 7.215377e+02 0.000000e+00 6.095593e+02 4.485728e+01 0.000000e+00 7.215377e+02 1.728540e+02 2.163791e-01 0.000000e+00 0.000000e+00 1.000000e+00 2.745884e-03 +S_03: 
1.392000e+03 5.120000e+02 +K_03: 9.037596e+02 0.000000e+00 6.957519e+02 0.000000e+00 9.019653e+02 2.242509e+02 0.000000e+00 0.000000e+00 1.000000e+00 +D_03: -3.639558e-01 1.788651e-01 6.029694e-04 -3.922424e-04 -5.382460e-02 +R_03: 9.995599e-01 1.699522e-02 -2.431313e-02 -1.704422e-02 9.998531e-01 -1.809756e-03 2.427880e-02 2.223358e-03 9.997028e-01 +T_03: -4.731050e-01 5.551470e-03 -5.250882e-03 +S_rect_03: 1.242000e+03 3.750000e+02 +R_rect_03: 9.998321e-01 -7.193136e-03 1.685599e-02 7.232804e-03 9.999712e-01 -2.293585e-03 -1.683901e-02 2.415116e-03 9.998553e-01 +P_rect_03: 7.215377e+02 0.000000e+00 6.095593e+02 -3.395242e+02 0.000000e+00 7.215377e+02 1.728540e+02 2.199936e+00 0.000000e+00 0.000000e+00 1.000000e+00 2.729905e-03 diff --git a/modelscope/models/cv/self_supervised_depth_completion/dataloaders/kitti_loader.py b/modelscope/models/cv/self_supervised_depth_completion/dataloaders/kitti_loader.py new file mode 100644 index 000000000..6dfa1e8ca --- /dev/null +++ b/modelscope/models/cv/self_supervised_depth_completion/dataloaders/kitti_loader.py @@ -0,0 +1,330 @@ +import os +import os.path +import glob +import numpy as np +from numpy import linalg as LA +from random import choice +from PIL import Image +import torch.utils.data as data +import cv2 +from modelscope.models.cv.self_supervised_depth_completion.dataloaders import transforms +from modelscope.models.cv.self_supervised_depth_completion.dataloaders.pose_estimator import get_pose_pnp + +input_options = ['d', 'rgb', 'rgbd', 'g', 'gd'] + + +def load_calib(args): + """ + Temporarily hardcoding the calibration matrix using calib file from 2011_09_26 + """ + calib = open(os.path.join(args.data_folder, "calib_cam_to_cam.txt"), "r") + lines = calib.readlines() + P_rect_line = lines[25] + + Proj_str = P_rect_line.split(":")[1].split(" ")[1:] + Proj = np.reshape(np.array([float(p) for p in Proj_str]), + (3, 4)).astype(np.float32) + K = Proj[:3, :3] # camera matrix + + # note: we will take the center crop of the images during augmentation + # that changes the optical centers, but not focal lengths + K[0, 2] = K[ + 0, + 2] - 13 # from width = 1242 to 1216, with a 13-pixel cut on both sides + K[1, 2] = K[ + 1, + 2] - 11.5 # from width = 375 to 352, with a 11.5-pixel cut on both sides + return K + + +def get_paths_and_transform(split, args): + assert (args.use_d or args.use_rgb + or args.use_g), 'no proper input selected' + + if split == "train": + transform = train_transform + glob_d = os.path.join( + args.data_folder, + 'data_depth_velodyne/train/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png' + ) + glob_gt = os.path.join( + args.data_folder, + 'data_depth_annotated/train/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png' + ) + + def get_rgb_paths(p): + ps = p.split('/') + pnew = '/'.join([args.data_folder] + ['data_rgb'] + ps[-6:-4] + + ps[-2:-1] + ['data'] + ps[-1:]) + return pnew + elif split == "val": + if args.val == "full": + transform = val_transform + glob_d = os.path.join( + args.data_folder, + 'data_depth_velodyne/val/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png' + ) + glob_gt = os.path.join( + args.data_folder, + 'data_depth_annotated/val/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png' + ) + + def get_rgb_paths(p): + ps = p.split('/') + pnew = '/'.join(ps[:-7] + + ['data_rgb '] + ps[-6:-4] + ps[-2:-1] + ['data'] + ps[-1:]) + return pnew + elif args.val == "select": + transform = no_transform + glob_d = os.path.join( + args.data_folder, + "depth_selection/val_selection_cropped/velodyne_raw/*.png") + glob_gt = 
os.path.join( + args.data_folder, + "depth_selection/val_selection_cropped/groundtruth_depth/*.png" + ) + + def get_rgb_paths(p): + return p.replace("groundtruth_depth", "image") + elif split == "test_completion": + transform = no_transform + glob_d = os.path.join( + args.data_folder, + "depth_selection/test_depth_completion_anonymous/velodyne_raw/*.png" + ) + glob_gt = None # "test_depth_completion_anonymous/" + glob_rgb = os.path.join( + args.data_folder, + "depth_selection/test_depth_completion_anonymous/image/*.png") + elif split == "test_prediction": + transform = no_transform + glob_d = None + glob_gt = None # "test_depth_completion_anonymous/" + glob_rgb = os.path.join( + args.data_folder, + "depth_selection/test_depth_prediction_anonymous/image/*.png") + else: + raise ValueError("Unrecognized split " + str(split)) + + if glob_gt is not None: + # train or val-full or val-select + paths_d = sorted(glob.glob(glob_d)) + paths_gt = sorted(glob.glob(glob_gt)) + paths_rgb = [get_rgb_paths(p) for p in paths_gt] + else: + # test only has d or rgb + paths_rgb = sorted(glob.glob(glob_rgb)) + paths_gt = [None] * len(paths_rgb) + if split == "test_prediction": + paths_d = [None] * len( + paths_rgb) # test_prediction has no sparse depth + else: + paths_d = sorted(glob.glob(glob_d)) + + if len(paths_d) == 0 and len(paths_rgb) == 0 and len(paths_gt) == 0: + raise (RuntimeError("Found 0 images under {}".format(glob_gt))) + if len(paths_d) == 0 and args.use_d: + raise (RuntimeError("Requested sparse depth but none was found")) + if len(paths_rgb) == 0 and args.use_rgb: + raise (RuntimeError("Requested rgb images but none was found")) + if len(paths_rgb) == 0 and args.use_g: + raise (RuntimeError("Requested gray images but no rgb was found")) + if len(paths_rgb) != len(paths_d) or len(paths_rgb) != len(paths_gt): + raise (RuntimeError("Produced different sizes for datasets")) + + paths = {"rgb": paths_rgb, "d": paths_d, "gt": paths_gt} + return paths, transform + + +def rgb_read(filename): + assert os.path.exists(filename), "file not found: {}".format(filename) + img_file = Image.open(filename) + # rgb_png = np.array(img_file, dtype=float) / 255.0 # scale pixels to the range [0,1] + rgb_png = np.array(img_file, dtype='uint8') # in the range [0,255] + img_file.close() + return rgb_png + + +def depth_read(filename): + # loads depth map D from png file + # and returns it as a numpy array, + # for details see readme.txt + assert os.path.exists(filename), "file not found: {}".format(filename) + img_file = Image.open(filename) + depth_png = np.array(img_file, dtype=int) + img_file.close() + # make sure we have a proper 16bit depth map here.. not 8bit! + assert np.max(depth_png) > 255, \ + "np.max(depth_png)={}, path={}".format(np.max(depth_png), filename) + + depth = depth_png.astype(float) / 256. + # depth[depth_png == 0] = -1. 
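+    # per the KITTI depth devkit, uint16 value / 256.0 is the depth in metres and
+    # a value of 0 means "no LiDAR measurement"; zeros are intentionally kept so
+    # the losses and metrics can mask them out as invalid via (target > 0)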
+ depth = np.expand_dims(depth, -1) + return depth + + +oheight, owidth = 352, 1216 + + +def drop_depth_measurements(depth, prob_keep): + mask = np.random.binomial(1, prob_keep, depth.shape) + depth *= mask + return depth + + +def train_transform(rgb, sparse, target, rgb_near, args): + # s = np.random.uniform(1.0, 1.5) # random scaling + # angle = np.random.uniform(-5.0, 5.0) # random rotation degrees + do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip + + transform_geometric = transforms.Compose([ + # transforms.Rotate(angle), + # transforms.Resize(s), + transforms.BottomCrop((oheight, owidth)), + transforms.HorizontalFlip(do_flip) + ]) + if sparse is not None: + sparse = transform_geometric(sparse) + target = transform_geometric(target) + if rgb is not None: + brightness = np.random.uniform(max(0, 1 - args.jitter), + 1 + args.jitter) + contrast = np.random.uniform(max(0, 1 - args.jitter), 1 + args.jitter) + saturation = np.random.uniform(max(0, 1 - args.jitter), + 1 + args.jitter) + transform_rgb = transforms.Compose([ + transforms.ColorJitter(brightness, contrast, saturation, 0), + transform_geometric + ]) + rgb = transform_rgb(rgb) + if rgb_near is not None: + rgb_near = transform_rgb(rgb_near) + # sparse = drop_depth_measurements(sparse, 0.9) + + return rgb, sparse, target, rgb_near + + +def val_transform(rgb, sparse, target, rgb_near, args): + transform = transforms.Compose([ + transforms.BottomCrop((oheight, owidth)), + ]) + if rgb is not None: + rgb = transform(rgb) + if sparse is not None: + sparse = transform(sparse) + if target is not None: + target = transform(target) + if rgb_near is not None: + rgb_near = transform(rgb_near) + return rgb, sparse, target, rgb_near + + +def no_transform(rgb, sparse, target, rgb_near, args): + return rgb, sparse, target, rgb_near + + +to_tensor = transforms.ToTensor() +def to_float_tensor(x): return to_tensor(x).float() + + +def handle_gray(rgb, args): + if rgb is None: + return None, None + if not args.use_g: + return rgb, None + else: + img = np.array(Image.fromarray(rgb).convert('L')) + img = np.expand_dims(img, -1) + if not args.use_rgb: + rgb_ret = None + else: + rgb_ret = rgb + return rgb_ret, img + + +def get_rgb_near(path, args): + assert path is not None, "path is None" + + def extract_frame_id(filename): + head, tail = os.path.split(filename) + number_string = tail[0:tail.find('.')] + number = int(number_string) + return head, number + + def get_nearby_filename(filename, new_id): + head, _ = os.path.split(filename) + new_filename = os.path.join(head, '%010d.png' % new_id) + return new_filename + + head, number = extract_frame_id(path) + count = 0 + max_frame_diff = 3 + candidates = [ + i - max_frame_diff for i in range(max_frame_diff * 2 + 1) + if i - max_frame_diff != 0 + ] + while True: + random_offset = choice(candidates) + path_near = get_nearby_filename(path, number + random_offset) + if os.path.exists(path_near): + break + assert count < 20, "cannot find a nearby frame in 20 trials for {}".format( + path) + count += 1 + + return rgb_read(path_near) + + +class KittiDepth(data.Dataset): + """A data loader for the Kitti dataset + """ + + def __init__(self, split, args): + self.args = args + self.split = split + paths, transform = get_paths_and_transform(split, args) + self.paths = paths + self.transform = transform + self.K = load_calib(args) + self.threshold_translation = 0.1 + + def __getraw__(self, index): + rgb = rgb_read(self.paths['rgb'][index]) if \ + (self.paths['rgb'][index] is not None and 
(self.args.use_rgb or self.args.use_g)) else None + sparse = depth_read(self.paths['d'][index]) if \ + (self.paths['d'][index] is not None and self.args.use_d) else None + target = depth_read(self.paths['gt'][index]) if \ + self.paths['gt'][index] is not None else None + rgb_near = get_rgb_near(self.paths['rgb'][index], self.args) if \ + self.split == 'train' and self.args.use_pose else None + return rgb, sparse, target, rgb_near + + def __getitem__(self, index): + rgb, sparse, target, rgb_near = self.__getraw__(index) + rgb, sparse, target, rgb_near = self.transform(rgb, sparse, target, + rgb_near, self.args) + r_mat, t_vec = None, None + if self.split == 'train' and self.args.use_pose: + success, r_vec, t_vec = get_pose_pnp(rgb, rgb_near, sparse, self.K) + # discard if translation is too small + success = success and LA.norm(t_vec) > self.threshold_translation + if success: + r_mat, _ = cv2.Rodrigues(r_vec) + else: + # return the same image and no motion when PnP fails + rgb_near = rgb + t_vec = np.zeros((3, 1)) + r_mat = np.eye(3) + + rgb, gray = handle_gray(rgb, self.args) + candidates = {"rgb": rgb, "d": sparse, "gt": target, + "g": gray, "r_mat": r_mat, "t_vec": t_vec, "rgb_near": rgb_near} + items = { + key: to_float_tensor(val) + for key, val in candidates.items() if val is not None + } + + return items + + def __len__(self): + return len(self.paths['gt']) diff --git a/modelscope/models/cv/self_supervised_depth_completion/dataloaders/pose_estimator.py b/modelscope/models/cv/self_supervised_depth_completion/dataloaders/pose_estimator.py new file mode 100644 index 000000000..e1af127c6 --- /dev/null +++ b/modelscope/models/cv/self_supervised_depth_completion/dataloaders/pose_estimator.py @@ -0,0 +1,104 @@ +import cv2 +import numpy as np + + +def rgb2gray(rgb): + return np.dot(rgb[..., :3], [0.299, 0.587, 0.114]) + + +def convert_2d_to_3d(u, v, z, K): + v0 = K[1][2] + u0 = K[0][2] + fy = K[1][1] + fx = K[0][0] + x = (u - u0) * z / fx + y = (v - v0) * z / fy + return (x, y, z) + + +def feature_match(img1, img2): + r''' Find features on both images and match them pairwise + ''' + max_n_features = 1000 + # max_n_features = 500 + use_flann = False # better not use flann + + detector = cv2.xfeatures2d.SIFT_create(max_n_features) + + # find the keypoints and descriptors with SIFT + kp1, des1 = detector.detectAndCompute(img1, None) + kp2, des2 = detector.detectAndCompute(img2, None) + if (des1 is None) or (des2 is None): + return [], [] + des1 = des1.astype(np.float32) + des2 = des2.astype(np.float32) + + if use_flann: + # FLANN parameters + FLANN_INDEX_KDTREE = 0 + index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5) + search_params = dict(checks=50) + flann = cv2.FlannBasedMatcher(index_params, search_params) + matches = flann.knnMatch(des1, des2, k=2) + else: + matcher = cv2.DescriptorMatcher().create('BruteForce') + matches = matcher.knnMatch(des1, des2, k=2) + + good = [] + pts1 = [] + pts2 = [] + # ratio test as per Lowe's paper + for i, (m, n) in enumerate(matches): + if m.distance < 0.8 * n.distance: + good.append(m) + pts2.append(kp2[m.trainIdx].pt) + pts1.append(kp1[m.queryIdx].pt) + + pts1 = np.int32(pts1) + pts2 = np.int32(pts2) + return pts1, pts2 + + +def get_pose_pnp(rgb_curr, rgb_near, depth_curr, K): + gray_curr = rgb2gray(rgb_curr).astype(np.uint8) + gray_near = rgb2gray(rgb_near).astype(np.uint8) + height, width = gray_curr.shape + + pts2d_curr, pts2d_near = feature_match(gray_curr, + gray_near) # feature matching + + # dilation of depth + kernel = np.ones((4, 4), 
np.uint8) + depth_curr_dilated = cv2.dilate(depth_curr, kernel) + + # extract 3d pts + pts3d_curr = [] + pts2d_near_filtered = [ + ] # keep only feature points with depth in the current frame + for i, pt2d in enumerate(pts2d_curr): + # print(pt2d) + u, v = pt2d[0], pt2d[1] + z = depth_curr_dilated[v, u] + if z > 0: + xyz_curr = convert_2d_to_3d(u, v, z, K) + pts3d_curr.append(xyz_curr) + pts2d_near_filtered.append(pts2d_near[i]) + + # the minimal number of points accepted by solvePnP is 4: + if len(pts3d_curr) >= 4 and len(pts2d_near_filtered) >= 4: + pts3d_curr = np.expand_dims(np.array(pts3d_curr).astype(np.float32), + axis=1) + pts2d_near_filtered = np.expand_dims( + np.array(pts2d_near_filtered).astype(np.float32), axis=1) + + # ransac + ret = cv2.solvePnPRansac(pts3d_curr, + pts2d_near_filtered, + K, + distCoeffs=None) + success = ret[0] + rotation_vector = ret[1] + translation_vector = ret[2] + return (success, rotation_vector, translation_vector) + else: + return (0, None, None) diff --git a/modelscope/models/cv/self_supervised_depth_completion/dataloaders/transforms.py b/modelscope/models/cv/self_supervised_depth_completion/dataloaders/transforms.py new file mode 100644 index 000000000..ddee9123b --- /dev/null +++ b/modelscope/models/cv/self_supervised_depth_completion/dataloaders/transforms.py @@ -0,0 +1,618 @@ +from __future__ import division +import torch + +from PIL import Image, ImageEnhance +try: + import accimage +except ImportError: + accimage = None + +import numpy as np +import numbers +import types + +import scipy.ndimage.interpolation as itpl +import skimage.transform + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +def _is_pil_image(img): + if accimage is not None: + return isinstance(img, (Image.Image, accimage.Image)) + else: + return isinstance(img, Image.Image) + + +def _is_tensor_image(img): + return torch.is_tensor(img) and img.ndimension() == 3 + + +def adjust_brightness(img, brightness_factor): + """Adjust brightness of an Image. + + Args: + img (PIL Image): PIL Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + PIL Image: Brightness adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +def adjust_contrast(img, contrast_factor): + """Adjust contrast of an Image. + + Args: + img (PIL Image): PIL Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + + Returns: + PIL Image: Contrast adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +def adjust_saturation(img, saturation_factor): + """Adjust color saturation of an image. + + Args: + img (PIL Image): PIL Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + + Returns: + PIL Image: Saturation adjusted image. 
+ """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +def adjust_hue(img, hue_factor): + """Adjust hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + See https://en.wikipedia.org/wiki/Hue for more details on Hue. + + Args: + img (PIL Image): PIL Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + PIL Image: Hue adjusted image. + """ + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError( + 'hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) + + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + input_mode = img.mode + if input_mode in {'L', '1', 'I', 'F'}: + return img + + h, s, v = img.convert('HSV').split() + + np_h = np.array(h, dtype=np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over='ignore'): + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, 'L') + + img = Image.merge('HSV', (h, s, v)).convert(input_mode) + return img + + +def adjust_gamma(img, gamma, gain=1): + """Perform gamma correction on an image. + + Also known as Power Law Transform. Intensities in RGB mode are adjusted + based on the following equation: + + I_out = 255 * gain * ((I_in / 255) ** gamma) + + See https://en.wikipedia.org/wiki/Gamma_correction for more details. + + Args: + img (PIL Image): PIL Image to be adjusted. + gamma (float): Non negative real number. gamma larger than 1 make the + shadows darker, while gamma smaller than 1 make dark regions + lighter. + gain (float): The constant multiplier. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if gamma < 0: + raise ValueError('Gamma should be a non-negative real number') + + input_mode = img.mode + img = img.convert('RGB') + + np_img = np.array(img, dtype=np.float32) + np_img = 255 * gain * ((np_img / 255)**gamma) + np_img = np.uint8(np.clip(np_img, 0, 255)) + + img = Image.fromarray(np_img, 'RGB').convert(input_mode) + return img + + +class Compose(object): + """Composes several transforms together. + + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img): + for t in self.transforms: + img = t(img) + return img + + +class ToTensor(object): + """Convert a ``numpy.ndarray`` to tensor. + + Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). + """ + + def __call__(self, img): + """Convert a ``numpy.ndarray`` to tensor. + + Args: + img (numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + if not (_is_numpy_image(img)): + raise TypeError('img should be ndarray. 
Got {}'.format(type(img))) + + if isinstance(img, np.ndarray): + # handle numpy array + if img.ndim == 3: + img = torch.from_numpy(img.transpose((2, 0, 1)).copy()) + elif img.ndim == 2: + img = torch.from_numpy(img.copy()) + else: + raise RuntimeError( + 'img should be ndarray with 2 or 3 dimensions. Got {}'. + format(img.ndim)) + + return img + + +class NormalizeNumpyArray(object): + """Normalize a ``numpy.ndarray`` with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform + will normalize each channel of the input ``numpy.ndarray`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, img): + """ + Args: + img (numpy.ndarray): Image of size (H, W, C) to be normalized. + + Returns: + Tensor: Normalized image. + """ + if not (_is_numpy_image(img)): + raise TypeError('img should be ndarray. Got {}'.format(type(img))) + # TODO: make efficient + # print(img.shape) + for i in range(3): + img[:, :, i] = (img[:, :, i] - self.mean[i]) / self.std[i] + return img + + +class NormalizeTensor(object): + """Normalize an tensor image with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + + Returns: + Tensor: Normalized Tensor image. + """ + if not _is_tensor_image(tensor): + raise TypeError('tensor is not a torch image.') + # TODO: make efficient + for t, m, s in zip(tensor, self.mean, self.std): + t.sub_(m).div_(s) + return tensor + + +class Rotate(object): + """Rotates the given ``numpy.ndarray``. + + Args: + angle (float): The rotation angle in degrees. + """ + + def __init__(self, angle): + self.angle = angle + + def __call__(self, img): + """ + Args: + img (numpy.ndarray (C x H x W)): Image to be rotated. + + Returns: + img (numpy.ndarray (C x H x W)): Rotated image. + """ + + # order=0 means nearest-neighbor type interpolation + return skimage.transform.rotate(img, self.angle, resize=False, order=0) + + +class Resize(object): + """Resize the the given ``numpy.ndarray`` to the given size. + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size) + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + """ + + def __init__(self, size, interpolation='nearest'): + assert isinstance(size, float) + self.size = size + self.interpolation = interpolation + + def __call__(self, img): + """ + Args: + img (numpy.ndarray (C x H x W)): Image to be scaled. + Returns: + img (numpy.ndarray (C x H x W)): Rescaled image. 
+ """ + if img.ndim == 3: + return skimage.transform.rescale(img, self.size, order=0) + elif img.ndim == 2: + return skimage.transform.rescale(img, self.size, order=0) + else: + RuntimeError( + 'img should be ndarray with 2 or 3 dimensions. Got {}'.format( + img.ndim)) + + +class CenterCrop(object): + """Crops the given ``numpy.ndarray`` at the center. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + @staticmethod + def get_params(img, output_size): + """Get parameters for ``crop`` for center crop. + + Args: + img (numpy.ndarray (C x H x W)): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for center crop. + """ + h = img.shape[0] + w = img.shape[1] + th, tw = output_size + i = int(round((h - th) / 2.)) + j = int(round((w - tw) / 2.)) + + # # randomized cropping + # i = np.random.randint(i-3, i+4) + # j = np.random.randint(j-3, j+4) + + return i, j, th, tw + + def __call__(self, img): + """ + Args: + img (numpy.ndarray (C x H x W)): Image to be cropped. + + Returns: + img (numpy.ndarray (C x H x W)): Cropped image. + """ + i, j, h, w = self.get_params(img, self.size) + """ + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + """ + if not (_is_numpy_image(img)): + raise TypeError('img should be ndarray. Got {}'.format(type(img))) + if img.ndim == 3: + return img[i:i + h, j:j + w, :] + elif img.ndim == 2: + return img[i:i + h, j:j + w] + else: + raise RuntimeError( + 'img should be ndarray with 2 or 3 dimensions. Got {}'.format( + img.ndim)) + + +class BottomCrop(object): + """Crops the given ``numpy.ndarray`` at the bottom. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + @staticmethod + def get_params(img, output_size): + """Get parameters for ``crop`` for bottom crop. + + Args: + img (numpy.ndarray (C x H x W)): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for bottom crop. + """ + h = img.shape[0] + w = img.shape[1] + th, tw = output_size + i = h - th + j = int(round((w - tw) / 2.)) + + # randomized left and right cropping + # i = np.random.randint(i-3, i+4) + # j = np.random.randint(j-1, j+1) + + return i, j, th, tw + + def __call__(self, img): + """ + Args: + img (numpy.ndarray (C x H x W)): Image to be cropped. + + Returns: + img (numpy.ndarray (C x H x W)): Cropped image. + """ + i, j, h, w = self.get_params(img, self.size) + """ + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + """ + if not (_is_numpy_image(img)): + raise TypeError('img should be ndarray. Got {}'.format(type(img))) + if img.ndim == 3: + return img[i:i + h, j:j + w, :] + elif img.ndim == 2: + return img[i:i + h, j:j + w] + else: + raise RuntimeError( + 'img should be ndarray with 2 or 3 dimensions. 
Got {}'.format( + img.ndim)) + + +class Crop(object): + """Crops the given ``numpy.ndarray`` at the center. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + + def __init__(self, crop): + self.crop = crop + + @staticmethod + def get_params(img, crop): + """Get parameters for ``crop`` for center crop. + + Args: + img (numpy.ndarray (C x H x W)): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for center crop. + """ + x_l, x_r, y_b, y_t = crop + h = img.shape[0] + w = img.shape[1] + assert x_l >= 0 and x_l < w + assert x_r >= 0 and x_r < w + assert y_b >= 0 and y_b < h + assert y_t >= 0 and y_t < h + assert x_l < x_r and y_b < y_t + + return x_l, x_r, y_b, y_t + + def __call__(self, img): + """ + Args: + img (numpy.ndarray (C x H x W)): Image to be cropped. + + Returns: + img (numpy.ndarray (C x H x W)): Cropped image. + """ + x_l, x_r, y_b, y_t = self.get_params(img, self.crop) + """ + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + """ + if not (_is_numpy_image(img)): + raise TypeError('img should be ndarray. Got {}'.format(type(img))) + if img.ndim == 3: + return img[y_b:y_t, x_l:x_r, :] + elif img.ndim == 2: + return img[y_b:y_t, x_l:x_r] + else: + raise RuntimeError( + 'img should be ndarray with 2 or 3 dimensions. Got {}'.format( + img.ndim)) + + +class Lambda(object): + """Apply a user-defined lambda as a transform. + + Args: + lambd (function): Lambda/function to be used for transform. + """ + + def __init__(self, lambd): + assert isinstance(lambd, types.LambdaType) + self.lambd = lambd + + def __call__(self, img): + return self.lambd(img) + + +class HorizontalFlip(object): + """Horizontally flip the given ``numpy.ndarray``. + + Args: + do_flip (boolean): whether or not do horizontal flip. + + """ + + def __init__(self, do_flip): + self.do_flip = do_flip + + def __call__(self, img): + """ + Args: + img (numpy.ndarray (C x H x W)): Image to be flipped. + + Returns: + img (numpy.ndarray (C x H x W)): flipped image. + """ + if not (_is_numpy_image(img)): + raise TypeError('img should be ndarray. Got {}'.format(type(img))) + + if self.do_flip: + return np.fliplr(img) + else: + return img + + +class ColorJitter(object): + """Randomly change the brightness, contrast and saturation of an image. + + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. 
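+
+    Note: unlike torchvision's ColorJitter, the factors given here are applied
+    as-is; random sampling of the factors is done by the caller
+    (see kitti_loader.train_transform).
+
+    Example:
+        >>> jitter = ColorJitter(1.2, 1.0, 0.8, 0)
+        >>> rgb_out = jitter(rgb)  # rgb: uint8 numpy.ndarray image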
+ """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + transforms = [] + transforms.append( + Lambda(lambda img: adjust_brightness(img, brightness))) + transforms.append(Lambda(lambda img: adjust_contrast(img, contrast))) + transforms.append( + Lambda(lambda img: adjust_saturation(img, saturation))) + transforms.append(Lambda(lambda img: adjust_hue(img, hue))) + np.random.shuffle(transforms) + self.transform = Compose(transforms) + + def __call__(self, img): + """ + Args: + img (numpy.ndarray (C x H x W)): Input image. + + Returns: + img (numpy.ndarray (C x H x W)): Color jittered image. + """ + if not (_is_numpy_image(img)): + raise TypeError('img should be ndarray. Got {}'.format(type(img))) + + pil = Image.fromarray(img) + return np.array(self.transform(pil)) diff --git a/modelscope/models/cv/self_supervised_depth_completion/helper.py b/modelscope/models/cv/self_supervised_depth_completion/helper.py new file mode 100644 index 000000000..c43534030 --- /dev/null +++ b/modelscope/models/cv/self_supervised_depth_completion/helper.py @@ -0,0 +1,264 @@ +import os +import time +import shutil +import torch +import csv +from modelscope.models.cv.self_supervised_depth_completion import vis_utils +from modelscope.models.cv.self_supervised_depth_completion.metrics import Result + +fieldnames = [ + 'epoch', 'rmse', 'photo', 'mae', 'irmse', 'imae', 'mse', 'absrel', 'lg10', + 'silog', 'squared_rel', 'delta1', 'delta2', 'delta3', 'data_time', + 'gpu_time' +] + + +class logger: + def __init__(self, args, prepare=True): + self.args = args + output_directory = get_folder_name(args) + self.output_directory = output_directory + self.best_result = Result() + self.best_result.set_to_worst() + + if not prepare: + return + if not os.path.exists(output_directory): + os.makedirs(output_directory) + self.train_csv = os.path.join(output_directory, 'train.csv') + self.val_csv = os.path.join(output_directory, 'val.csv') + self.best_txt = os.path.join(output_directory, 'best.txt') + + # backup the source code + if args.resume == '': + print("=> creating source code backup ...") + backup_directory = os.path.join(output_directory, "code_backup") + self.backup_directory = backup_directory + # backup_source_code(backup_directory) + # create new csv files with only header + with open(self.train_csv, 'w') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + with open(self.val_csv, 'w') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + print("=> finished creating source code backup.") + + def conditional_print(self, split, i, epoch, lr, n_set, blk_avg_meter, + avg_meter): + if (i + 1) % self.args.print_freq == 0: + avg = avg_meter.average() + blk_avg = blk_avg_meter.average() + print('=> output: {}'.format(self.output_directory)) + print( + '{split} Epoch: {0} [{1}/{2}]\tlr={lr} ' + 't_Data={blk_avg.data_time:.3f}({average.data_time:.3f}) ' + 't_GPU={blk_avg.gpu_time:.3f}({average.gpu_time:.3f})\n\t' + 'RMSE={blk_avg.rmse:.2f}({average.rmse:.2f}) ' + 'MAE={blk_avg.mae:.2f}({average.mae:.2f}) ' + 'iRMSE={blk_avg.irmse:.2f}({average.irmse:.2f}) ' + 'iMAE={blk_avg.imae:.2f}({average.imae:.2f})\n\t' + 'silog={blk_avg.silog:.2f}({average.silog:.2f}) ' + 'squared_rel={blk_avg.squared_rel:.2f}({average.squared_rel:.2f}) ' + 'Delta1={blk_avg.delta1:.3f}({average.delta1:.3f}) ' + 'REL={blk_avg.absrel:.3f}({average.absrel:.3f})\n\t' + 'Lg10={blk_avg.lg10:.3f}({average.lg10:.3f}) ' + 
'Photometric={blk_avg.photometric:.3f}({average.photometric:.3f}) ' + .format(epoch, + i + 1, + n_set, + lr=lr, + blk_avg=blk_avg, + average=avg, + split=split.capitalize())) + blk_avg_meter.reset() + + def conditional_save_info(self, split, average_meter, epoch): + avg = average_meter.average() + if split == "train": + csvfile_name = self.train_csv + elif split == "val": + csvfile_name = self.val_csv + elif split == "eval": + eval_filename = os.path.join(self.output_directory, 'eval.txt') + self.save_single_txt(eval_filename, avg, epoch) + return avg + elif "test" in split: + return avg + else: + raise ValueError("wrong split provided to logger") + with open(csvfile_name, 'a') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writerow({ + 'epoch': epoch, + 'rmse': avg.rmse, + 'photo': avg.photometric, + 'mae': avg.mae, + 'irmse': avg.irmse, + 'imae': avg.imae, + 'mse': avg.mse, + 'silog': avg.silog, + 'squared_rel': avg.squared_rel, + 'absrel': avg.absrel, + 'lg10': avg.lg10, + 'delta1': avg.delta1, + 'delta2': avg.delta2, + 'delta3': avg.delta3, + 'gpu_time': avg.gpu_time, + 'data_time': avg.data_time + }) + return avg + + def save_single_txt(self, filename, result, epoch): + with open(filename, 'w') as txtfile: + txtfile.write( + ("rank_metric={}\n" + "epoch={}\n" + "rmse={:.3f}\n" + + "mae={:.3f}\n" + "silog={:.3f}\n" + "squared_rel={:.3f}\n" + + "irmse={:.3f}\n" + "imae={:.3f}\n" + "mse={:.3f}\n" + + "absrel={:.3f}\n" + "lg10={:.3f}\n" + "delta1={:.3f}\n" + + "t_gpu={:.4f}").format(self.args.rank_metric, epoch, + result.rmse, result.mae, result.silog, + result.squared_rel, result.irmse, + result.imae, result.mse, result.absrel, + result.lg10, result.delta1, + result.gpu_time)) + + def save_best_txt(self, result, epoch): + self.save_single_txt(self.best_txt, result, epoch) + + def _get_img_comparison_name(self, mode, epoch, is_best=False): + if mode == 'eval': + return self.output_directory + '/comparison_eval.png' + if mode == 'val': + if is_best: + return self.output_directory + '/comparison_best.png' + else: + return self.output_directory + '/comparison_' + str( + epoch) + '.png' + + def conditional_save_img_comparison(self, mode, i, ele, pred, epoch): + # save 8 images for visualization + if mode == 'val' or mode == 'eval': + skip = 100 + if i == 0: + self.img_merge = vis_utils.merge_into_row(ele, pred) + elif i % skip == 0 and i < 8 * skip: + row = vis_utils.merge_into_row(ele, pred) + self.img_merge = vis_utils.add_row(self.img_merge, row) + elif i == 8 * skip: + filename = self._get_img_comparison_name(mode, epoch) + vis_utils.save_image(self.img_merge, filename) + + def save_img_comparison_as_best(self, mode, epoch): + if mode == 'val': + filename = self._get_img_comparison_name(mode, epoch, is_best=True) + vis_utils.save_image(self.img_merge, filename) + + def get_ranking_error(self, result): + return getattr(result, self.args.rank_metric) + + def rank_conditional_save_best(self, mode, result, epoch): + error = self.get_ranking_error(result) + best_error = self.get_ranking_error(self.best_result) + is_best = error < best_error + if is_best and mode == "val": + self.old_best_result = self.best_result + self.best_result = result + self.save_best_txt(result, epoch) + return is_best + + def conditional_save_pred(self, mode, i, pred, epoch): + if ("test" in mode or mode == "eval") and self.args.save_pred: + + # save images for visualization/ testing + image_folder = os.path.join(self.output_directory, + mode + "_output") + if not 
os.path.exists(image_folder): + os.makedirs(image_folder) + img = torch.squeeze(pred.data.cpu()).numpy() + filename = os.path.join(image_folder, '{0:010d}.png'.format(i)) + vis_utils.save_depth_as_uint16png(img, filename) + + def conditional_summarize(self, mode, avg, is_best): + print("\n*\nSummary of ", mode, "round") + print('' + 'RMSE={average.rmse:.3f}\n' + 'MAE={average.mae:.3f}\n' + 'Photo={average.photometric:.3f}\n' + 'iRMSE={average.irmse:.3f}\n' + 'iMAE={average.imae:.3f}\n' + 'squared_rel={average.squared_rel}\n' + 'silog={average.silog}\n' + 'Delta1={average.delta1:.3f}\n' + 'REL={average.absrel:.3f}\n' + 'Lg10={average.lg10:.3f}\n' + 't_GPU={time:.3f}'.format(average=avg, time=avg.gpu_time)) + if is_best and mode == "val": + print("New best model by %s (was %.3f)" % + (self.args.rank_metric, + self.get_ranking_error(self.old_best_result))) + elif mode == "val": + print("(best %s is %.3f)" % + (self.args.rank_metric, + self.get_ranking_error(self.best_result))) + print("*\n") + + +ignore_hidden = shutil.ignore_patterns(".", "..", ".git*", "*pycache*", + "*build", "*.fuse*", "*_drive_*") + + +def backup_source_code(backup_directory): + if os.path.exists(backup_directory): + shutil.rmtree(backup_directory) + shutil.copytree('.', backup_directory, ignore=ignore_hidden) + + +def adjust_learning_rate(lr_init, optimizer, epoch): + """Sets the learning rate to the initial LR decayed by 10 every 5 epochs""" + lr = lr_init * (0.1**(epoch // 5)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return lr + + +def save_checkpoint(state, is_best, epoch, output_directory): + checkpoint_filename = os.path.join(output_directory, + 'checkpoint-' + str(epoch) + '.pth.tar') + torch.save(state, checkpoint_filename) + if is_best: + best_filename = os.path.join(output_directory, 'model_best.pth.tar') + shutil.copyfile(checkpoint_filename, best_filename) + if epoch > 0: + prev_checkpoint_filename = os.path.join( + output_directory, 'checkpoint-' + str(epoch - 1) + '.pth.tar') + if os.path.exists(prev_checkpoint_filename): + os.remove(prev_checkpoint_filename) + + +def get_folder_name(args): + current_time = time.strftime('%Y-%m-%d@%H-%M') + if args.use_pose: + prefix = "mode={}.w1={}.w2={}.".format(args.train_mode, args.w1, + args.w2) + else: + prefix = "mode={}.".format(args.train_mode) + # return os.path.join(args.result, + # prefix + 'input={}.resnet{}.criterion={}.lr={}.bs={}.wd={}.pretrained={}.jitter={}.time={}'. 
+ # format(args.input, args.layers, args.criterion, \ + # args.lr, args.batch_size, args.weight_decay, \ + # args.pretrained, args.jitter, current_time + # )) + return os.path.join(args.result, 'test') + + +avgpool = torch.nn.AvgPool2d(kernel_size=2, stride=2).cuda() + + +def multiscale(img): + img1 = avgpool(img) + img2 = avgpool(img1) + img3 = avgpool(img2) + img4 = avgpool(img3) + img5 = avgpool(img4) + return img5, img4, img3, img2, img1 diff --git a/modelscope/models/cv/self_supervised_depth_completion/inverse_warp.py b/modelscope/models/cv/self_supervised_depth_completion/inverse_warp.py new file mode 100644 index 000000000..81dccfdb6 --- /dev/null +++ b/modelscope/models/cv/self_supervised_depth_completion/inverse_warp.py @@ -0,0 +1,138 @@ +import torch +import torch.nn.functional as F + +from modelscope.utils.logger import get_logger +logger = get_logger() + + +class Intrinsics: + """Intrinsics""" + def __init__(self, width, height, fu, fv, cu=0, cv=0): + self.height, self.width = height, width + self.fu, self.fv = fu, fv # fu, fv: focal length along the horizontal and vertical axes + + # cu, cv: optical center along the horizontal and vertical axes + self.cu = cu if cu > 0 else (width - 1) / 2.0 + self.cv = cv if cv > 0 else (height - 1) / 2.0 + + # U, V represent the homogeneous horizontal and vertical coordinates in the pixel space + self.U = torch.arange(start=0, end=width).expand(height, width).float() + self.V = torch.arange(start=0, end=height).expand(width, + height).t().float() + + # X_cam, Y_cam represent the homogeneous x, y coordinates (assuming depth z=1) in the camera coordinate system + self.X_cam = (self.U - self.cu) / self.fu + self.Y_cam = (self.V - self.cv) / self.fv + + self.is_cuda = False + + def cuda(self): + self.X_cam.data = self.X_cam.data.cuda() + self.Y_cam.data = self.Y_cam.data.cuda() + self.is_cuda = True + return self + + def scale(self, height, width): + # return a new set of corresponding intrinsic parameters for the scaled image + ratio_u = float(width) / self.width + ratio_v = float(height) / self.height + fu = ratio_u * self.fu + fv = ratio_v * self.fv + cu = ratio_u * self.cu + cv = ratio_v * self.cv + new_intrinsics = Intrinsics(width, height, fu, fv, cu, cv) + if self.is_cuda: + new_intrinsics.cuda() + return new_intrinsics + + def __print__(self): + logger.info('size=({},{})\nfocal length=({},{})\noptical center=({},{})'. 
+ format(self.height, self.width, self.fv, self.fu, self.cv, + self.cu)) + + +def image_to_pointcloud(depth, intrinsics): + assert depth.dim() == 4 + assert depth.size(1) == 1 + + X = depth * intrinsics.X_cam + Y = depth * intrinsics.Y_cam + return torch.cat((X, Y, depth), dim=1) + + +def pointcloud_to_image(pointcloud, intrinsics): + assert pointcloud.dim() == 4 + + batch_size = pointcloud.size(0) + X = pointcloud[:, 0, :, :] # .view(batch_size, -1) + Y = pointcloud[:, 1, :, :] # .view(batch_size, -1) + Z = pointcloud[:, 2, :, :].clamp(min=1e-3) # .view(batch_size, -1) + + # compute pixel coordinates + U_proj = intrinsics.fu * X / Z + intrinsics.cu # horizontal pixel coordinate + V_proj = intrinsics.fv * Y / Z + intrinsics.cv # vertical pixel coordinate + + # normalization to [-1, 1], required by torch.nn.functional.grid_sample + U_proj_normalized = (2 * U_proj / (intrinsics.width - 1) - 1).view( + batch_size, -1) + V_proj_normalized = (2 * V_proj / (intrinsics.height - 1) - 1).view( + batch_size, -1) + + # This was important since PyTorch didn't do as it claimed for points out of boundary + # See https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py + # Might not be necessary any more + U_proj_mask = ((U_proj_normalized > 1) + (U_proj_normalized < -1)).detach() + U_proj_normalized[U_proj_mask] = 2 + V_proj_mask = ((V_proj_normalized > 1) + (V_proj_normalized < -1)).detach() + V_proj_normalized[V_proj_mask] = 2 + + pixel_coords = torch.stack([U_proj_normalized, V_proj_normalized], + dim=2) # [B, H*W, 2] + return pixel_coords.view(batch_size, intrinsics.height, intrinsics.width, + 2) + + +def batch_multiply(batch_scalar, batch_matrix): + # input: batch_scalar of size b, batch_matrix of size b * 3 * 3 + # output: batch_matrix of size b * 3 * 3 + batch_size = batch_scalar.size(0) + output = batch_matrix.clone() + for i in range(batch_size): + output[i] = batch_scalar[i] * batch_matrix[i] + return output + + +def transform_curr_to_near(pointcloud_curr, r_mat, t_vec, intrinsics): + # translation and rotmat represent the transformation from tgt pose to src pose + batch_size = pointcloud_curr.size(0) + XYZ_ = torch.bmm(r_mat, pointcloud_curr.view(batch_size, 3, -1)) + + X = (XYZ_[:, 0, :] + t_vec[:, 0].unsqueeze(1)).view( + -1, 1, intrinsics.height, intrinsics.width) + Y = (XYZ_[:, 1, :] + t_vec[:, 1].unsqueeze(1)).view( + -1, 1, intrinsics.height, intrinsics.width) + Z = (XYZ_[:, 2, :] + t_vec[:, 2].unsqueeze(1)).view( + -1, 1, intrinsics.height, intrinsics.width) + + pointcloud_near = torch.cat((X, Y, Z), dim=1) + + return pointcloud_near + + +def homography_from(rgb_near, depth_curr, r_mat, t_vec, intrinsics): + # inverse warp the RGB image from the nearby frame to the current frame + + # to ensure dimension consistency + r_mat = r_mat.view(-1, 3, 3) + t_vec = t_vec.view(-1, 3) + + # compute source pixel coordinate + pointcloud_curr = image_to_pointcloud(depth_curr, intrinsics) + pointcloud_near = transform_curr_to_near(pointcloud_curr, r_mat, t_vec, + intrinsics) + pixel_coords_near = pointcloud_to_image(pointcloud_near, intrinsics) + + # the warping + warped = F.grid_sample(rgb_near, pixel_coords_near) + + return warped diff --git a/modelscope/models/cv/self_supervised_depth_completion/metrics.py b/modelscope/models/cv/self_supervised_depth_completion/metrics.py new file mode 100644 index 000000000..b7a58a548 --- /dev/null +++ b/modelscope/models/cv/self_supervised_depth_completion/metrics.py @@ -0,0 +1,164 @@ +import torch +import math +import numpy as np + 
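+# Units follow the KITTI depth completion benchmark: Result.evaluate() expects
+# depths in metres and reports rmse/mae/mse in millimetres, irmse/imae in 1/km,
+# and silog in log(m) * 100.
+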
+lg_e_10 = math.log(10) + + +def log10(x): + """Convert a new tensor with the base-10 logarithm of the elements of x. """ + return torch.log(x) / lg_e_10 + + +class Result(object): + """Result""" + def __init__(self): + self.irmse = 0 + self.imae = 0 + self.mse = 0 + self.rmse = 0 + self.mae = 0 + self.absrel = 0 + self.squared_rel = 0 + self.lg10 = 0 + self.delta1 = 0 + self.delta2 = 0 + self.delta3 = 0 + self.data_time = 0 + self.gpu_time = 0 + self.silog = 0 # Scale invariant logarithmic error [log(m)*100] + self.photometric = 0 + + def set_to_worst(self): + self.irmse = np.inf + self.imae = np.inf + self.mse = np.inf + self.rmse = np.inf + self.mae = np.inf + self.absrel = np.inf + self.squared_rel = np.inf + self.lg10 = np.inf + self.silog = np.inf + self.delta1 = 0 + self.delta2 = 0 + self.delta3 = 0 + self.data_time = 0 + self.gpu_time = 0 + + def update(self, irmse, imae, mse, rmse, mae, absrel, squared_rel, lg10, + delta1, delta2, delta3, gpu_time, data_time, silog, photometric=0): + """update""" + self.irmse = irmse + self.imae = imae + self.mse = mse + self.rmse = rmse + self.mae = mae + self.absrel = absrel + self.squared_rel = squared_rel + self.lg10 = lg10 + self.delta1 = delta1 + self.delta2 = delta2 + self.delta3 = delta3 + self.data_time = data_time + self.gpu_time = gpu_time + self.silog = silog + self.photometric = photometric + + def evaluate(self, output, target, photometric=0): + """evaluate""" + valid_mask = target > 0.1 + + # convert from meters to mm + output_mm = 1e3 * output[valid_mask] + target_mm = 1e3 * target[valid_mask] + + abs_diff = (output_mm - target_mm).abs() + + self.mse = float((torch.pow(abs_diff, 2)).mean()) + self.rmse = math.sqrt(self.mse) + self.mae = float(abs_diff.mean()) + self.lg10 = float((log10(output_mm) - log10(target_mm)).abs().mean()) + self.absrel = float((abs_diff / target_mm).mean()) + self.squared_rel = float(((abs_diff / target_mm)**2).mean()) + + maxRatio = torch.max(output_mm / target_mm, target_mm / output_mm) + self.delta1 = float((maxRatio < 1.25).float().mean()) + self.delta2 = float((maxRatio < 1.25**2).float().mean()) + self.delta3 = float((maxRatio < 1.25**3).float().mean()) + self.data_time = 0 + self.gpu_time = 0 + + # silog uses meters + err_log = torch.log(target[valid_mask]) - torch.log(output[valid_mask]) + normalized_squared_log = (err_log**2).mean() + log_mean = err_log.mean() + self.silog = math.sqrt(normalized_squared_log + - log_mean * log_mean) * 100 + + # convert from meters to km + inv_output_km = (1e-3 * output[valid_mask])**(-1) + inv_target_km = (1e-3 * target[valid_mask])**(-1) + abs_inv_diff = (inv_output_km - inv_target_km).abs() + self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean()) + self.imae = float(abs_inv_diff.mean()) + + self.photometric = float(photometric) + + +class AverageMeter(object): + """AverageMeter""" + def __init__(self): + self.reset() + + def reset(self): + """reset""" + self.count = 0.0 + self.sum_irmse = 0 + self.sum_imae = 0 + self.sum_mse = 0 + self.sum_rmse = 0 + self.sum_mae = 0 + self.sum_absrel = 0 + self.sum_squared_rel = 0 + self.sum_lg10 = 0 + self.sum_delta1 = 0 + self.sum_delta2 = 0 + self.sum_delta3 = 0 + self.sum_data_time = 0 + self.sum_gpu_time = 0 + self.sum_photometric = 0 + self.sum_silog = 0 + + def update(self, result, gpu_time, data_time, n=1): + """update""" + self.count += n + self.sum_irmse += n * result.irmse + self.sum_imae += n * result.imae + self.sum_mse += n * result.mse + self.sum_rmse += n * result.rmse + self.sum_mae += n * result.mae + 
self.sum_absrel += n * result.absrel + self.sum_squared_rel += n * result.squared_rel + self.sum_lg10 += n * result.lg10 + self.sum_delta1 += n * result.delta1 + self.sum_delta2 += n * result.delta2 + self.sum_delta3 += n * result.delta3 + self.sum_data_time += n * data_time + self.sum_gpu_time += n * gpu_time + self.sum_silog += n * result.silog + self.sum_photometric += n * result.photometric + + def average(self): + """average""" + avg = Result() + if self.count > 0: + avg.update( + self.sum_irmse / self.count, self.sum_imae / self.count, + self.sum_mse / self.count, self.sum_rmse / self.count, + self.sum_mae / self.count, self.sum_absrel / self.count, + self.sum_squared_rel / self.count, self.sum_lg10 / self.count, + self.sum_delta1 / self.count, self.sum_delta2 / self.count, + self.sum_delta3 / self.count, self.sum_gpu_time / self.count, + self.sum_data_time / self.count, self.sum_silog / self.count, + self.sum_photometric / self.count) + return avg diff --git a/modelscope/models/cv/self_supervised_depth_completion/model.py b/modelscope/models/cv/self_supervised_depth_completion/model.py new file mode 100644 index 000000000..1c25d7569 --- /dev/null +++ b/modelscope/models/cv/self_supervised_depth_completion/model.py @@ -0,0 +1,211 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.models import resnet + + +def init_weights(m): + """init_weights""" + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + m.weight.data.normal_(0, 1e-3) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.ConvTranspose2d): + m.weight.data.normal_(0, 1e-3) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +def conv_bn_relu(in_channels, out_channels, kernel_size, + stride=1, padding=0, bn=True, relu=True): + """conv_bn_relu""" + bias = not bn + layers = [] + layers.append( + nn.Conv2d(in_channels, + out_channels, + kernel_size, + stride, + padding, + bias=bias)) + if bn: + layers.append(nn.BatchNorm2d(out_channels)) + if relu: + layers.append(nn.LeakyReLU(0.2, inplace=True)) + layers = nn.Sequential(*layers) + + # initialize the weights + for m in layers.modules(): + init_weights(m) + + return layers + + +def convt_bn_relu(in_channels, out_channels, kernel_size, + stride=1, padding=0, output_padding=0, bn=True, relu=True): + """convt_bn_relu""" + bias = not bn + layers = [] + layers.append( + nn.ConvTranspose2d(in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + bias=bias)) + if bn: + layers.append(nn.BatchNorm2d(out_channels)) + if relu: + layers.append(nn.LeakyReLU(0.2, inplace=True)) + layers = nn.Sequential(*layers) + + # initialize the weights + for m in layers.modules(): + init_weights(m) + + return layers + + +class DepthCompletionNet(nn.Module): + """DepthCompletionNet""" + def __init__(self, args): + assert ( + args.layers in [18, 34, 50, 101, 152] + ), f'Only layers 18, 34, 50, 101, and 152 are defined, but got {layers}'.format( + layers) + super(DepthCompletionNet, self).__init__() + self.modality = args.input + + if 'd' in self.modality: + channels = 64 // len(self.modality) + self.conv1_d = conv_bn_relu(1, + channels, + kernel_size=3, + stride=1, + padding=1) + if 'rgb' in self.modality: + channels = 64 * 3 // len(self.modality) + self.conv1_img = conv_bn_relu(3, + channels, + kernel_size=3, + stride=1, + padding=1) + elif 'g' in self.modality: + channels = 64 // len(self.modality) + self.conv1_img = conv_bn_relu(1, 
+ channels, + kernel_size=3, + stride=1, + padding=1) + + pretrained_model = resnet.__dict__['resnet{}'.format( + args.layers)](pretrained=args.pretrained) + if not args.pretrained: + pretrained_model.apply(init_weights) + # self.maxpool = pretrained_model._modules['maxpool'] + self.conv2 = pretrained_model._modules['layer1'] + self.conv3 = pretrained_model._modules['layer2'] + self.conv4 = pretrained_model._modules['layer3'] + self.conv5 = pretrained_model._modules['layer4'] + del pretrained_model # clear memory + + # define number of intermediate channels + if args.layers <= 34: + num_channels = 512 + elif args.layers >= 50: + num_channels = 2048 + self.conv6 = conv_bn_relu(num_channels, + 512, + kernel_size=3, + stride=2, + padding=1) + + # decoding layers + kernel_size = 3 + stride = 2 + self.convt5 = convt_bn_relu(in_channels=512, + out_channels=256, + kernel_size=kernel_size, + stride=stride, + padding=1, + output_padding=1) + self.convt4 = convt_bn_relu(in_channels=768, + out_channels=128, + kernel_size=kernel_size, + stride=stride, + padding=1, + output_padding=1) + self.convt3 = convt_bn_relu(in_channels=(256 + 128), + out_channels=64, + kernel_size=kernel_size, + stride=stride, + padding=1, + output_padding=1) + self.convt2 = convt_bn_relu(in_channels=(128 + 64), + out_channels=64, + kernel_size=kernel_size, + stride=stride, + padding=1, + output_padding=1) + self.convt1 = convt_bn_relu(in_channels=128, + out_channels=64, + kernel_size=kernel_size, + stride=1, + padding=1) + self.convtf = conv_bn_relu(in_channels=128, + out_channels=1, + kernel_size=1, + stride=1, + bn=False, + relu=False) + + def forward(self, x): + """forward""" + # first layer + if 'd' in self.modality: + conv1_d = self.conv1_d(x['d']) + if 'rgb' in self.modality: + conv1_img = self.conv1_img(x['rgb']) + elif 'g' in self.modality: + conv1_img = self.conv1_img(x['g']) + + if self.modality == 'rgbd' or self.modality == 'gd': + conv1 = torch.cat((conv1_d, conv1_img), 1) + else: + conv1 = conv1_d if (self.modality == 'd') else conv1_img + + conv2 = self.conv2(conv1) + conv3 = self.conv3(conv2) # batchsize * ? * 176 * 608 + conv4 = self.conv4(conv3) # batchsize * ? * 88 * 304 + conv5 = self.conv5(conv4) # batchsize * ? * 44 * 152 + conv6 = self.conv6(conv5) # batchsize * ? 
* 22 * 76 + + # decoder + convt5 = self.convt5(conv6) + y = torch.cat((convt5, conv5), 1) + + convt4 = self.convt4(y) + y = torch.cat((convt4, conv4), 1) + + convt3 = self.convt3(y) + y = torch.cat((convt3, conv3), 1) + + convt2 = self.convt2(y) + y = torch.cat((convt2, conv2), 1) + + convt1 = self.convt1(y) + y = torch.cat((convt1, conv1), 1) + + y = self.convtf(y) + + if self.training: + return 100 * y + else: + min_distance = 0.9 + return F.relu( + 100 * y - min_distance + ) + min_distance # the minimum range of Velodyne is around 3 feet ~= 0.9m diff --git a/modelscope/models/cv/self_supervised_depth_completion/self_supervised_depth_completion.py b/modelscope/models/cv/self_supervised_depth_completion/self_supervised_depth_completion.py new file mode 100644 index 000000000..4ebc22181 --- /dev/null +++ b/modelscope/models/cv/self_supervised_depth_completion/self_supervised_depth_completion.py @@ -0,0 +1,396 @@ +from modelscope.models.cv.self_supervised_depth_completion import helper +from modelscope.models.cv.self_supervised_depth_completion import criteria +from modelscope.models.cv.self_supervised_depth_completion.metrics import AverageMeter, Result +# import mmcv +from argparse import ArgumentParser +# import torchvision +from os import makedirs +from tqdm import tqdm +import cv2 +import numpy as np +import torch.utils.data +import torch.optim +import torch.nn.parallel +import torch +import time +# import argparse +import os +import sys +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" +from modelscope.utils.logger import get_logger +from modelscope.utils.constant import Tasks +from modelscope.models.builder import MODELS +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.metainfo import Models +from modelscope.models.cv.self_supervised_depth_completion.inverse_warp import Intrinsics, homography_from +from modelscope.models.cv.self_supervised_depth_completion.model import DepthCompletionNet +from .dataloaders.kitti_loader import load_calib, oheight, owidth, input_options, KittiDepth + + + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + + +# from modelscope.utils.config import Config + + +m_logger = get_logger() + + +@MODELS.register_module( + Tasks.self_supervised_depth_completion, module_name=Models.self_supervised_depth_completion) +class SelfSupervisedDepthCompletion(TorchModel): + """SelfSupervisedDepthCompletion Class""" + def __init__(self, model_dir: str, *args, **kwargs): + """str -- model file root.""" + super().__init__(model_dir, *args, **kwargs) + # self.skip_test = False + # self.skip_train = True + # self.skip_video = False + # self.configs = os.path.join(model_dir, 'bouncingballs\\bouncingballs.py') + # self.iteration = 20000 + # define loss functions + self.depth_criterion = criteria.MaskedMSELoss() + self.photometric_criterion = criteria.PhotometricLoss() + self.smoothness_criterion = criteria.SmoothnessLoss() + + def add_args(self, parser): + """add args.""" + parser.add_argument('-w', + '--workers', + default=4, + type=int, + metavar='N', + help='number of data loading workers (default: 4)') + parser.add_argument('--epochs', + default=11, + type=int, + metavar='N', + help='number of total epochs to run (default: 11)') + parser.add_argument('--start-epoch', + default=0, + type=int, + metavar='N', + help='manual epoch number (useful on restarts)') + parser.add_argument('-c', + '--criterion', + metavar='LOSS', + default='l2', + choices=criteria.loss_names, + help='loss function: | '.join(criteria.loss_names) + + ' (default: l2)') 
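+        # (editor's note) In run(), '--criterion l2' selects criteria.MaskedMSELoss and any
+        # other value selects criteria.MaskedL1Loss; as the names suggest, both presumably
+        # average the error only over pixels that carry a valid ground-truth depth.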
+ parser.add_argument('-b', + '--batch-size', + default=1, + type=int, + help='mini-batch size (default: 1)') + parser.add_argument('--lr', + '--learning-rate', + default=1e-5, + type=float, + metavar='LR', + help='initial learning rate (default 1e-5)') + parser.add_argument('--weight-decay', + '--wd', + default=0, + type=float, + metavar='W', + help='weight decay (default: 0)') + parser.add_argument('--print-freq', + '-p', + default=10, + type=int, + metavar='N', + help='print frequency (default: 10)') + parser.add_argument('--resume', + default='', + type=str, + metavar='PATH', + help='path to latest checkpoint (default: none)') + parser.add_argument('--data-folder', + default='../data', + type=str, + metavar='PATH', + help='data folder (default: none)') + parser.add_argument('-i', + '--input', + type=str, + default='gd', + choices=input_options, + help='input: | '.join(input_options)) + parser.add_argument('-l', + '--layers', + type=int, + default=34, + help='use 16 for sparse_conv; use 18 or 34 for resnet') + parser.add_argument('--pretrained', + action="store_true", + help='use ImageNet pre-trained weights') + parser.add_argument('--val', + type=str, + default="select", + choices=["select", "full"], + help='full or select validation set') + parser.add_argument('--jitter', + type=float, + default=0.1, + help='color jitter for images') + parser.add_argument( + '--rank-metric', + type=str, + default='rmse', + choices=[m for m in dir(Result()) if not m.startswith('_')], + help='metrics for which best result is sbatch_datacted') + parser.add_argument( + '-m', + '--train-mode', + type=str, + default="dense", + choices=["dense", "sparse", "photo", "sparse+photo", "dense+photo"], + help='dense | sparse | photo | sparse+photo | dense+photo') + parser.add_argument('-e', '--evaluate', default='', type=str, metavar='PATH') + parser.add_argument('--cpu', action="store_true", help='run on cpu') + + def iterate(self, mode, args, loader, model, optimizer, logger, epoch): + """iterate data""" + block_average_meter = AverageMeter() + average_meter = AverageMeter() + meters = [block_average_meter, average_meter] + + # switch to appropriate mode + assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \ + "unsupported mode: {}".format(mode) + if mode == 'train': + model.train() + lr = helper.adjust_learning_rate(args.lr, optimizer, epoch) + else: + model.eval() + lr = 0 + + for i, batch_data in enumerate(loader): + start = time.time() + batch_data = { + key: val.to(self.device) + for key, val in batch_data.items() if val is not None + } + gt = batch_data[ + 'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None + data_time = time.time() - start + + start = time.time() + pred = model(batch_data) + depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None + if mode == 'train': + # Loss 1: the direct depth supervision from ground truth label + # mask=1 indicates that a pixel does not ground truth labels + if 'sparse' in args.train_mode: + depth_loss = self.depth_criterion(pred, batch_data['d']) + mask = (batch_data['d'] < 1e-3).float() + elif 'dense' in args.train_mode: + depth_loss = self.depth_criterion(pred, gt) + mask = (gt < 1e-3).float() + + # Loss 2: the self-supervised photometric loss + if args.use_pose: + # create multi-scale pyramids + pred_array = helper.multiscale(pred) + rgb_curr_array = helper.multiscale(batch_data['rgb']) + rgb_near_array = helper.multiscale(batch_data['rgb_near']) + if mask is not None: + mask_array = helper.multiscale(mask) 
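+                    # (editor's note) The photometric term below is evaluated on this image
+                    # pyramid, with coarser scales down-weighted by 2**(scale - num_scales);
+                    # the total training loss is depth_loss + w1 * photometric + w2 * smooth,
+                    # where run() sets w1 = w2 = 0.1 when pose supervision is enabled.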
+ num_scales = len(pred_array) + + # compute photometric loss at multiple scales + for scale in range(len(pred_array)): + pred_ = pred_array[scale] + rgb_curr_ = rgb_curr_array[scale] + rgb_near_ = rgb_near_array[scale] + mask_ = None + if mask is not None: + mask_ = mask_array[scale] + + # compute the corresponding intrinsic parameters + height_, width_ = pred_.size(2), pred_.size(3) + intrinsics_ = self.kitti_intrinsics.scale(height_, width_) + + # inverse warp from a nearby frame to the current frame + warped_ = homography_from(rgb_near_, pred_, + batch_data['r_mat'], + batch_data['t_vec'], intrinsics_) + photometric_loss += self.photometric_criterion( + rgb_curr_, warped_, mask_) * (2**(scale - num_scales)) + + # Loss 3: the depth smoothness loss + smooth_loss = self.smoothness_criterion(pred) if args.w2 > 0 else 0 + + # backprop + loss = depth_loss + args.w1 * photometric_loss + args.w2 * smooth_loss + optimizer.zero_grad() + loss.backward() + optimizer.step() + + gpu_time = time.time() - start + + # measure accuracy and record loss + with torch.no_grad(): + mini_batch_size = next(iter(batch_data.values())).size(0) + result = Result() + if mode != 'test_prediction' and mode != 'test_completion': + result.evaluate(pred.data, gt.data, photometric_loss) + [ + m.update(result, gpu_time, data_time, mini_batch_size) + for m in meters + ] + logger.conditional_print(mode, i, epoch, lr, len(loader), + block_average_meter, average_meter) + logger.conditional_save_img_comparison(mode, i, batch_data, pred, + epoch) + logger.conditional_save_pred(mode, i, pred, epoch) + + avg = logger.conditional_save_info(mode, average_meter, epoch) + is_best = logger.rank_conditional_save_best(mode, avg, epoch) + if is_best and not (mode == "train"): + logger.save_img_comparison_as_best(mode, epoch) + logger.conditional_summarize(mode, avg, is_best) + + return avg, is_best + + def run(self, model_dir, source_dir): + """main function""" + parser = ArgumentParser(description="Testing script parameters") + self.add_args(parser) + + args = parser.parse_args() + + self.depth_criterion = criteria.MaskedMSELoss() if ( + args.criterion == 'l2') else criteria.MaskedL1Loss() + + + args.use_pose = ("photo" in args.train_mode) + # args.pretrained = not args.no_pretrained + args.use_rgb = ('rgb' in args.input) or args.use_pose + args.use_d = 'd' in args.input + args.use_g = 'g' in args.input + + args.evaluate = os.path.join(self.model_dir, 'model_best.pth') + args.data_folder = source_dir + args.result = os.path.join(args.data_folder, 'results') + + if args.use_pose: + args.w1, args.w2 = 0.1, 0.1 + else: + args.w1, args.w2 = 0, 0 + + cuda = torch.cuda.is_available() and not args.cpu + if cuda: + import torch.backends.cudnn as cudnn + cudnn.benchmark = True + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + print("=> using '{}' for computation.".format(self.device)) + + if args.use_pose: + # hard-coded KITTI camera intrinsics + K = load_calib() + fu, fv = float(K[0, 0]), float(K[1, 1]) + cu, cv = float(K[0, 2]), float(K[1, 2]) + kitti_intrinsics = Intrinsics(owidth, oheight, fu, fv, cu, cv) + if cuda: + kitti_intrinsics = kitti_intrinsics.cuda() + + if args.evaluate: + args_new = args + if os.path.isfile(args.evaluate): + print("=> loading checkpoint '{}' ... 
".format(args.evaluate), + end='') + checkpoint = torch.load(args.evaluate, map_location=self.device) + args = checkpoint['args'] + args.data_folder = args_new.data_folder + args.result = args_new.result + args.val = args_new.val + is_eval = True + print("Completed.") + else: + print("No model found at '{}'".format(args.evaluate)) + return + elif args.resume: # optionally resume from a checkpoint + args_new = args + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}' ... ".format(args.resume), + end='') + checkpoint = torch.load(args.resume, map_location=self.device) + args.start_epoch = checkpoint['epoch'] + 1 + args.data_folder = args_new.data_folder + args.result = args_new.result + + args.val = args_new.val + print("Completed. Resuming from epoch {}.".format( + checkpoint['epoch'])) + else: + print("No checkpoint found at '{}'".format(args.resume)) + return + + print("=> creating model and optimizer ... ", end='') + model = DepthCompletionNet(args).to(self.device) + model_named_params = [ + p for _, p in model.named_parameters() if p.requires_grad + ] + optimizer = torch.optim.Adam(model_named_params, + lr=args.lr, + weight_decay=args.weight_decay) + print("completed.") + if checkpoint is not None: + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> checkpoint state loaded.") + + model = torch.nn.DataParallel(model) + + # Data loading code + print("=> creating data loaders ... ") + if not is_eval: + train_dataset = KittiDepth('train', args) + train_loader = torch.utils.data.DataLoader(train_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.workers, + pin_memory=True, + sampler=None) + print("\t==> train_loader size:{}".format(len(train_loader))) + val_dataset = KittiDepth('val', args) + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=1, + shuffle=False, + num_workers=2, + pin_memory=True) # set batch size to be 1 for validation + print("\t==> val_loader size:{}".format(len(val_loader))) + + # create backups and results folder + logger = helper.logger(args) + if checkpoint is not None: + logger.best_result = checkpoint['best_result'] + print("=> logger created.") + + if is_eval: + print("=> starting model evaluation ...") + result, is_best = self.iterate("val", args, val_loader, model, None, logger, + checkpoint['epoch']) + return + + # main loop + print("=> starting main loop ...") + for epoch in range(args.start_epoch, args.epochs): + print("=> starting training epoch {} ..".format(epoch)) + self.iterate("train", args, train_loader, model, optimizer, logger, + epoch) # train for one epoch + result, is_best = self.iterate("val", args, val_loader, model, None, logger, + epoch) # evaluate on validation set + helper.save_checkpoint({ # save checkpoint + 'epoch': epoch, + 'model': model.module.state_dict(), + 'best_result': logger.best_result, + 'optimizer': optimizer.state_dict(), + 'args': args, + }, is_best, epoch, logger.output_directory) diff --git a/modelscope/models/cv/self_supervised_depth_completion/vis_utils.py b/modelscope/models/cv/self_supervised_depth_completion/vis_utils.py new file mode 100644 index 000000000..7668ac1e8 --- /dev/null +++ b/modelscope/models/cv/self_supervised_depth_completion/vis_utils.py @@ -0,0 +1,113 @@ +import os +if not ("DISPLAY" in os.environ): + import matplotlib as mpl + mpl.use('Agg') +import matplotlib.pyplot as plt +from PIL import Image +import numpy as np +import cv2 + +cmap = plt.cm.jet + + +def depth_colorize(depth): + 
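+    # (editor's note) min-max normalize the depth map and apply the jet colormap,
+    # returning an (H, W, 3) uint8 image.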
depth = (depth - np.min(depth)) / (np.max(depth) - np.min(depth)) + depth = 255 * cmap(depth)[:, :, :3] # H, W, C + return depth.astype('uint8') + + +def merge_into_row(ele, pred): + def preprocess_depth(x): + y = np.squeeze(x.data.cpu().numpy()) + return depth_colorize(y) + + # if is gray, transforms to rgb + img_list = [] + if 'rgb' in ele: + rgb = np.squeeze(ele['rgb'][0, ...].data.cpu().numpy()) + rgb = np.transpose(rgb, (1, 2, 0)) + img_list.append(rgb) + elif 'g' in ele: + g = np.squeeze(ele['g'][0, ...].data.cpu().numpy()) + g = np.array(Image.fromarray(g).convert('RGB')) + img_list.append(g) + if 'd' in ele: + img_list.append(preprocess_depth(ele['d'][0, ...])) + img_list.append(preprocess_depth(pred[0, ...])) + if 'gt' in ele: + img_list.append(preprocess_depth(ele['gt'][0, ...])) + + img_merge = np.hstack(img_list) + return img_merge.astype('uint8') + + +def add_row(img_merge, row): + return np.vstack([img_merge, row]) + + +def save_image(img_merge, filename): + image_to_write = cv2.cvtColor(img_merge, cv2.COLOR_RGB2BGR) + cv2.imwrite(filename, image_to_write) + + +def save_depth_as_uint16png(img, filename): + img = (img * 256).astype('uint16') + cv2.imwrite(filename, img) + + +if ("DISPLAY" in os.environ): + f, axarr = plt.subplots(4, 1) + plt.tight_layout() + plt.ion() + + +def display_warping(rgb_tgt, pred_tgt, warped): + def preprocess(rgb_tgt, pred_tgt, warped): + rgb_tgt = 255 * np.transpose(np.squeeze(rgb_tgt.data.cpu().numpy()), + (1, 2, 0)) # H, W, C + # depth = np.squeeze(depth.cpu().numpy()) + # depth = depth_colorize(depth) + + # convert to log-scale + pred_tgt = np.squeeze(pred_tgt.data.cpu().numpy()) + # pred_tgt[pred_tgt<=0] = 0.9 # remove negative predictions + # pred_tgt = np.log10(pred_tgt) + + pred_tgt = depth_colorize(pred_tgt) + + warped = 255 * np.transpose(np.squeeze(warped.data.cpu().numpy()), + (1, 2, 0)) # H, W, C + recon_err = np.absolute( + warped.astype('float') - rgb_tgt.astype('float')) * (warped > 0) + recon_err = recon_err[:, :, 0] + recon_err[:, :, 1] + recon_err[:, :, 2] + recon_err = depth_colorize(recon_err) + return rgb_tgt.astype('uint8'), warped.astype( + 'uint8'), recon_err, pred_tgt + + rgb_tgt, warped, recon_err, pred_tgt = preprocess(rgb_tgt, pred_tgt, + warped) + + # 1st column + column = 0 + axarr[0].imshow(rgb_tgt) + axarr[0].axis('off') + axarr[0].axis('equal') + # axarr[0, column].set_title('rgb_tgt') + + axarr[1].imshow(warped) + axarr[1].axis('off') + axarr[1].axis('equal') + # axarr[1, column].set_title('warped') + + axarr[2].imshow(recon_err, 'hot') + axarr[2].axis('off') + axarr[2].axis('equal') + # axarr[2, column].set_title('recon_err error') + + axarr[3].imshow(pred_tgt, 'hot') + axarr[3].axis('off') + axarr[3].axis('equal') + # axarr[3, column].set_title('pred_tgt') + + # plt.show() + plt.pause(0.001) diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py index 0b01e69ec..41540af3a 100644 --- a/modelscope/outputs/outputs.py +++ b/modelscope/outputs/outputs.py @@ -772,6 +772,7 @@ class OutputKeys(object): Tasks.surface_recon_common: [OutputKeys.OUTPUT], Tasks.video_colorization: [OutputKeys.OUTPUT_VIDEO], Tasks.image_control_3d_portrait: [OutputKeys.OUTPUT], + Tasks.self_supervised_depth_completion: [OutputKeys.OUTPUT_IMG], # image quality assessment degradation result for single image # { diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 30c5e484d..a19efd1f2 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -119,6 
+119,8 @@ from .human3d_animation_pipeline import Human3DAnimationPipeline from .rife_video_frame_interpolation_pipeline import RIFEVideoFrameInterpolationPipeline from .anydoor_pipeline import AnydoorPipeline + from .self_supervised_depth_completion_pipeline import SelfSupervisedDepthCompletionPipeline + else: _import_structure = { 'action_recognition_pipeline': ['ActionRecognitionPipeline'], @@ -295,6 +297,8 @@ 'human3d_animation_pipeline': ['Human3DAnimationPipeline'], 'rife_video_frame_interpolation_pipeline': ['RIFEVideoFrameInterpolationPipeline'], 'anydoor_pipeline': ['AnydoorPipeline'], + 'self_supervised_depth_completion_pipeline': ['SelfSupervisedDepthCompletionPipeline'], + } import sys diff --git a/modelscope/pipelines/cv/self_supervised_depth_completion_pipeline.py b/modelscope/pipelines/cv/self_supervised_depth_completion_pipeline.py new file mode 100644 index 000000000..07b8aadbc --- /dev/null +++ b/modelscope/pipelines/cv/self_supervised_depth_completion_pipeline.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.self_supervised_depth_completion, module_name=Pipelines.self_supervised_depth_completion) +class SelfSupervisedDepthCompletionPipeline(Pipeline): + """SelfSupervisedDepthCompletionPipeline Class""" + def __init__(self, model: str, **kwargs): + + super().__init__(model=model, **kwargs) + logger.info('load model done') + + def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """preprocess, not used at present""" + return inputs + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """forward""" + model_dir = inputs['model_dir'] + source_dir = inputs['source_dir'] + self.model.run(model_dir, source_dir) + return {OutputKeys.OUTPUT: 'Done'} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """postprocess, not used at present""" + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 999be1543..8232280ed 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -168,6 +168,8 @@ class CVTasks(object): human3d_render = 'human3d-render' human3d_animation = 'human3d-animation' image_control_3d_portrait = 'image-control-3d-portrait' + self_supervised_depth_completion = 'self-supervised-depth-completion' + # 3d generation image_to_3d = 'image-to-3d' diff --git a/modelscope/utils/pipeline_schema.json b/modelscope/utils/pipeline_schema.json index cf5c7fb7d..56095319e 100644 --- a/modelscope/utils/pipeline_schema.json +++ b/modelscope/utils/pipeline_schema.json @@ -3777,5 +3777,17 @@ } } } - } + }, + "self-supervised-depth-completion": { + "input": {}, + "parameters": {}, + "output": { + "type": "object", + "properties": { + "output": { + "type": "object" + } + } + } + }, } diff --git a/tests/pipelines/test_self_supervised_depth_completion.py b/tests/pipelines/test_self_supervised_depth_completion.py new file mode 100644 index 000000000..918ab7c9c --- /dev/null +++ b/tests/pipelines/test_self_supervised_depth_completion.py @@ -0,0 +1,46 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
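+# (editor's note) Regression test: downloads the KITTI depth test split from ModelScope and
+# runs the self-supervised depth completion pipeline on it (CUDA-only, test level >= 0).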
+import os +import unittest + +import torch + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.utils.constant import DownloadMode, Tasks +from modelscope.utils.test_utils import test_level +from modelscope import get_logger +logger = get_logger() + + +class SelfSupervisedDepthCompletionTest(unittest.TestCase): + """class SelfSupervisedDepthCompletionTest""" + def setUp(self) -> None: + self.model_id = 'Damo_XR_Lab/Self_Supervised_Depth_Completion' + data_dir = MsDataset.load( + 'KITTI_Depth_Dataset', + namespace='Damo_XR_Lab', + split='test', + download_mode=DownloadMode.FORCE_REDOWNLOAD + ).config_kwargs['split_config']['test'] + self.source_dir = os.path.join(data_dir, 'selected_data') + logger.info(data_dir) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest only') + def test_run(self): + """test running evaluation""" + snapshot_path = snapshot_download(self.model_id) + logger.info('snapshot_path: %s', snapshot_path) + self_supervised_depth_completion = pipeline( + task=Tasks.self_supervised_depth_completion, + model=self.model_id + # ,config_file = os.path.join(modelPath, "configuration.json") + ) + + self_supervised_depth_completion(dict(model_dir=snapshot_path, source_dir=self.source_dir)) + logger.info('self-supervised-depth-completion_damo.test_run_modelhub done') + + +if __name__ == '__main__': + unittest.main()
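
Editor's note (usage sketch, not part of the patch): the test above exercises the pipeline end to end; an equivalent standalone invocation looks roughly like the following, where source_dir is a placeholder for a local copy of the KITTI sample data and run() points its results directory at <source_dir>/results.

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

model_id = 'Damo_XR_Lab/Self_Supervised_Depth_Completion'
snapshot_path = snapshot_download(model_id)  # local checkpoint folder containing model_best.pth
source_dir = '/path/to/KITTI_Depth_Dataset/selected_data'  # placeholder path

depth_completion = pipeline(
    task=Tasks.self_supervised_depth_completion, model=model_id)
result = depth_completion(dict(model_dir=snapshot_path, source_dir=source_dir))
print(result[OutputKeys.OUTPUT])  # 'Done' once the evaluation pass has finished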