feat(visualization): update scripts
npurson committed Sep 27, 2023
1 parent 94f5aaa commit 6411bc1
Showing 8 changed files with 202 additions and 95 deletions.
16 changes: 11 additions & 5 deletions README.md
@@ -18,7 +18,7 @@ Haoyang Zhang<sup>2</sup>,

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/symphonize-3d-semantic-scene-completion-with/3d-semantic-scene-completion-from-a-single-1)](https://paperswithcode.com/sota/3d-semantic-scene-completion-from-a-single-1?p=symphonize-3d-semantic-scene-completion-with)

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/symphonize-3d-semantic-scene-completion-with/3d-semantic-scene-completion-from-a-single-2)](https://paperswithcode.com/sota/3d-semantic-scene-completion-from-a-single-2?p=symphonize-3d-semantic-scene-completion-with)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/symphonize-3d-semantic-scene-completion-with/3d-semantic-scene-completion-on-kitti-360)](https://paperswithcode.com/sota/3d-semantic-scene-completion-on-kitti-360?p=symphonize-3d-semantic-scene-completion-with)

**TL;DR:** Our paper delves into enhancing SSC by leveraging instance-centric representations. We propose a novel paradigm that integrates ***instance queries*** to facilitate ***instance semantics*** and capture ***global context***. Our approach achieves SOTA results of ***13.02 mIoU & 41.07 IoU*** on the SemanticKITTI *test* benchmark.

@@ -47,10 +47,16 @@ This project is built upon ***[TmPL](https://github.com/npurson/tmpl)***, a temp

### Prepare Dataset

#### SemanticKITTI

1. Download the RGB images and calibration files, and preprocess the labels, following the documentation of [VoxFormer](https://github.com/NVlabs/VoxFormer/blob/main/docs/prepare_dataset.md) or [MonoScene](https://github.com/astra-vision/MonoScene#semantickitti).

2. Generate depth predictions with the pre-trained MobileStereoNet, following VoxFormer's guide: https://github.com/NVlabs/VoxFormer/tree/main/preprocess#3-image-to-depth.

#### SSCBench-KITTI-360

1. Refer to https://github.com/ai4ce/SSCBench/tree/main/dataset/KITTI-360.

## Usage

0. **Setup**
@@ -89,7 +95,7 @@ This project is built upon ***[TmPL](https://github.com/npurson/tmpl)***, a temp
2. Visualization
```shell
python tools/generate_outputs.py [+output_file=...]
python tools/generate_outputs.py [+path=...]
```
## Results
@@ -98,14 +104,14 @@ This project is built upon ***[TmPL](https://github.com/npurson/tmpl)***, a temp
| Method | Split | IoU | mIoU | Download |
| :------------------------------------------: | :---: | :---: | :---: | :----------------------: |
| [Symphonies](symphonies/configs/config.yaml) | test | 41.07 | 13.02 | [model](<https://github.com/hustvl/Symphonies/releases/download/v1.0/e28_miou0.1344_remapped.ckpt>) |
| [Symphonies](symphonies/configs/config.yaml) | val | 41.44 | 13.44 | [log](<https://github.com/hustvl/Symphonies/releases/download/v1.0/log>) |
| [Symphonies](symphonies/configs/config.yaml) | val | 42.17 | 14.66 | [log](https://github.com/hustvl/Symphonies/releases/download/v1.0/semantic_kitti.log) / [model](https://github.com/hustvl/Symphonies/releases/download/v1.0/semantic_kitti_e26_miou0.1466.ckpt) |
| [Symphonies](symphonies/configs/config.yaml) | test | 41.07 | 13.02 | |
2. **KITTI-360**
| Method | Split | IoU | mIoU | Download |
| :------------------------------------------: | :---: | :---: | :---: | :----------------------: |
| [Symphonies](symphonies/configs/config.yaml) | test | 43.11 | 16.22 | available soon |
| [Symphonies](symphonies/configs/config.yaml) | test | 44.12 | 18.58 | [log](https://github.com/hustvl/Symphonies/releases/download/v1.0/kitti_360.log) / [model](https://github.com/hustvl/Symphonies/releases/download/v1.0/kitti_360_e27_miou0.1858.ckpt) |
## Citation
Binary file modified assets/arch.png
30 changes: 5 additions & 25 deletions ssc_pl/models/decoders/symphonies_decoder.py
@@ -7,30 +7,10 @@
from torch.cuda.amp import autocast

from ..layers import (ASPP, DeformableSqueezeAttention, DeformableTransformerLayer,
LearnableSqueezePositionalEncoding, TransformerLayer, Upsample,
nchw_to_nlc, nlc_to_nchw)
from ..utils import (cumprod, flatten_fov_from_voxels, generate_grid, index_fov_back_to_voxels,
interpolate_flatten)


def flatten_multi_scale_feats(feats):
feat_flatten = torch.cat([nchw_to_nlc(feat) for feat in feats], dim=1)
shapes = torch.stack([torch.tensor(feat.shape[2:]) for feat in feats]).to(feat_flatten.device)
return feat_flatten, shapes


def get_level_start_index(shapes):
return torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))


def pix2vox(pix_coords, depth, K, E, voxel_origin, voxel_size, offset=0.5, downsample_z=1):
p_x = torch.cat([pix_coords * depth, depth], dim=1) # bs, 3, h, w
p_c = K.inverse() @ p_x.flatten(2)
p_w = E.inverse() @ F.pad(p_c, (0, 0, 0, 1), value=1)
p_v = ((p_w[:, :-1].transpose(1, 2) - voxel_origin.unsqueeze(1)) / voxel_size - offset)
if downsample_z != 1:
p_v[..., -1] /= downsample_z
return p_v.long()
LearnableSqueezePositionalEncoding, TransformerLayer, Upsample)
from ..utils import (cumprod, flatten_fov_from_voxels, flatten_multi_scale_feats, generate_grid,
get_level_start_index, index_fov_back_to_voxels, interpolate_flatten,
nchw_to_nlc, nlc_to_nchw, pix2vox)


class SymphoniesLayer(nn.Module):
@@ -281,5 +261,5 @@ def generate_vol_ref_pts_from_pts(self, pred_pts, vol_pts):
pred_pts = pred_pts[..., 1] * self.image_shape[1] + pred_pts[..., 0]
assert pred_pts.size(0) == 1
ref_pts = vol_pts[:, pred_pts.squeeze()]
ref_pts = ref_pts / torch.tensor(self.scene_shape).to(pred_pts)
ref_pts = ref_pts / (torch.tensor(self.scene_shape) - 1).to(pred_pts)
return ref_pts.clamp(0, 1)
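
A side note on the change above: dividing by `self.scene_shape - 1` rather than `self.scene_shape` maps the largest voxel index exactly to 1.0 before clamping, akin to an align-corners style normalization. A minimal sketch of the effect (the grid size below is only illustrative):

```python
import torch

scene_shape = (256, 256, 32)  # illustrative X, Y, Z grid size
# First and last voxel indices of the grid.
vox_idx = torch.tensor([[0., 0., 0.], [255., 255., 31.]])

old = vox_idx / torch.tensor(scene_shape)        # last index -> [0.996, 0.996, 0.969], never reaches 1.0
new = vox_idx / (torch.tensor(scene_shape) - 1)  # last index -> [1.0, 1.0, 1.0]
print(old[1], new[1])
```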
26 changes: 0 additions & 26 deletions ssc_pl/models/layers/transformer.py
@@ -1,32 +1,6 @@
import torch.nn as nn
from mmcv.ops import MultiScaleDeformableAttention

from ..utils import cumprod


def nlc_to_nchw(x, shape):
"""Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
Args:
x (Tensor): The input tensor of shape [N, L, C] before conversion.
shape (Sequence[int]): The height and width of output feature map.
Returns:
Tensor: The output tensor of shape [N, C, H, W] after conversion.
"""
B, L, C = x.shape
assert L == cumprod(shape), 'The seq_len does not match H, W'
return x.transpose(1, 2).reshape(B, C, *shape).contiguous()


def nchw_to_nlc(x):
"""Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
Args:
x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
Returns:
Tensor: The output tensor of shape [N, L, C] after conversion.
tuple: The [H, W] shape.
"""
return x.flatten(2).transpose(1, 2).contiguous()


class TransformerLayer(nn.Module):

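The `nchw_to_nlc` / `nlc_to_nchw` helpers removed here now live in `ssc_pl/models/utils/utils.py`, next to `flatten_multi_scale_feats` and `get_level_start_index` (see the decoder's updated imports above). A minimal sketch of the shape bookkeeping these helpers perform ahead of multi-scale deformable attention, assuming they are importable from `ssc_pl.models.utils` as the decoder's imports suggest; the feature sizes are illustrative:

```python
import torch

# Assumed re-exports, per the decoder's `from ..utils import ...` in this commit.
from ssc_pl.models.utils import (flatten_multi_scale_feats, get_level_start_index,
                                 nchw_to_nlc, nlc_to_nchw)

# Two illustrative levels of a multi-scale feature pyramid.
feats = [torch.randn(1, 32, 24, 80), torch.randn(1, 32, 12, 40)]

feat_flatten, shapes = flatten_multi_scale_feats(feats)
print(feat_flatten.shape)  # (1, 24*80 + 12*40, 32) = (1, 2400, 32)
print(shapes)              # [[24, 80], [12, 40]]

level_start_index = get_level_start_index(shapes)
print(level_start_index)   # [0, 1920]: where each level starts in the flattened sequence

# Round trip for a single level: values are preserved exactly.
nlc = nchw_to_nlc(feats[0])                  # (1, 1920, 32)
nchw = nlc_to_nchw(nlc, feats[0].shape[2:])  # back to (1, 32, 24, 80)
assert torch.equal(nchw, feats[0])
```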
4 changes: 2 additions & 2 deletions ssc_pl/models/projections/cvt.py
@@ -2,8 +2,8 @@
import torch.nn as nn
import torch.nn.functional as F

from ..layers import TransformerLayer, nchw_to_nlc
from ..utils import generate_grid
from ..layers import TransformerLayer
from ..utils import generate_grid, nchw_to_nlc


class ProjectionLayer(nn.Module):
56 changes: 54 additions & 2 deletions ssc_pl/models/utils/utils.py
@@ -58,8 +58,60 @@ def interpolate_flatten(x, src_shape, dst_shape, mode='nearest'):
bs, n, c = *x.shape, 1
assert cumprod(src_shape) == n
x = F.interpolate(
x.reshape(bs, c, *src_shape).float(), dst_shape,
mode=mode).flatten(2).transpose(1, 2).to(x.dtype)
x.reshape(bs, c, *src_shape).float(), dst_shape, mode=mode,
align_corners=False).flatten(2).transpose(1, 2).to(x.dtype)
if c == 1:
x = x.squeeze(2)
return x


def flatten_multi_scale_feats(feats):
feat_flatten = torch.cat([nchw_to_nlc(feat) for feat in feats], dim=1)
shapes = torch.stack([torch.tensor(feat.shape[2:]) for feat in feats]).to(feat_flatten.device)
return feat_flatten, shapes


def get_level_start_index(shapes):
return torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))


def nlc_to_nchw(x, shape):
"""Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
Args:
x (Tensor): The input tensor of shape [N, L, C] before conversion.
shape (Sequence[int]): The height and width of output feature map.
Returns:
Tensor: The output tensor of shape [N, C, H, W] after conversion.
"""
B, L, C = x.shape
assert L == cumprod(shape), 'The seq_len does not match H, W'
return x.transpose(1, 2).reshape(B, C, *shape).contiguous()


def nchw_to_nlc(x):
"""Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
Args:
x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
Returns:
Tensor: The output tensor of shape [N, L, C] after conversion.
tuple: The [H, W] shape.
"""
return x.flatten(2).transpose(1, 2).contiguous()


def pix2vox(pix_coords, depth, K, E, voxel_origin, voxel_size, offset=0.5, downsample_z=1):
p_x = torch.cat([pix_coords * depth, depth], dim=1) # bs, 3, h, w
p_c = K.inverse() @ p_x.flatten(2)
p_w = E.inverse() @ F.pad(p_c, (0, 0, 0, 1), value=1)
p_v = (p_w[:, :-1].transpose(1, 2) - voxel_origin.unsqueeze(1)) / voxel_size - offset
if downsample_z != 1:
p_v[..., -1] /= downsample_z
return p_v.long()


def vox2pix(voxel_pts, K, E, voxel_origin, scene_shape, image_shape, voxel_size):
p_v = voxel_pts.squeeze(2) * torch.tensor(scene_shape).to(voxel_pts) * voxel_size + voxel_origin
p_c = E @ F.pad(p_v.transpose(1, 2), (0, 0, 0, 1), value=1)
p_x = (K @ p_c[:, :-1]) / p_c[:, 2]
p_x = p_x[:, :-1].transpose(1, 2) / (torch.tensor(image_shape[::-1]).to(p_x) - 1)
return p_x.clamp(0, 1)
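
`pix2vox` back-projects pixel coordinates with metric depth into integer voxel-grid indices (pixel → camera → world → grid), and `vox2pix` runs the inverse chain from normalized voxel positions back to normalized pixel coordinates. A hedged sketch of `pix2vox` on a toy camera, assuming the helper is importable from `ssc_pl.models.utils` as the decoder's imports suggest; the intrinsics, pose, and voxel-grid parameters are made up purely for illustration:

```python
import torch

from ssc_pl.models.utils import pix2vox  # assumed re-export, as in the decoder imports

bs, h, w = 1, 4, 6
# Toy pinhole camera looking down +z with an identity camera-to-world pose.
K = torch.tensor([[[100., 0., w / 2],
                   [0., 100., h / 2],
                   [0., 0., 1.]]])             # (1, 3, 3) intrinsics
E = torch.eye(4).unsqueeze(0)                  # (1, 4, 4) extrinsics
voxel_origin = torch.tensor([[-1., -1., 0.]])  # (1, 3) grid origin in metres
voxel_size = 0.2                               # metres per voxel

# (x, y) pixel coordinates and a constant 10 m depth map.
ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
pix_coords = torch.stack([xs, ys]).float().unsqueeze(0)  # (1, 2, H, W)
depth = torch.full((1, 1, h, w), 10.0)                    # (1, 1, H, W)

vox_idx = pix2vox(pix_coords, depth, K, E, voxel_origin, voxel_size)
print(vox_idx.shape)  # (1, H*W, 3): one (x, y, z) voxel index per pixel
```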
80 changes: 80 additions & 0 deletions tools/render_depth.py
@@ -0,0 +1,80 @@
import os

import hydra
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from rich.progress import track
from ssc_pl import LitModule, build_data_loaders, generate_grid, pre_build_callbacks


@hydra.main(version_base=None, config_path='../configs', config_name='config')
def main(cfg: DictConfig):
if os.environ.get('LOCAL_RANK', 0) == 0:
print(OmegaConf.to_yaml(cfg))
cfg, _ = pre_build_callbacks(cfg)

dls, meta_info = build_data_loaders(cfg.data)
data_loader = dls[1]

if cfg.get('ckpt_path'):
model = LitModule.load_from_checkpoint(cfg.ckpt_path, **cfg, meta_info=meta_info)
else:
import warnings
warnings.warn('\033[31;1m{}\033[0m'.format('No ckpt_path is provided'))
model = LitModule(**cfg, meta_info=meta_info)

model.cuda()
model.eval()

with torch.no_grad():
for batch_inputs, targets in data_loader:
for key in batch_inputs:
if isinstance(batch_inputs[key], torch.Tensor):
batch_inputs[key] = batch_inputs[key].cuda()

# outputs = model(batch_inputs)

# vol = outputs['ssc_logits'] # (B, C, X, Y, Z)
vol = targets['target'].cuda()
K = batch_inputs['cam_K'] # (B, 3, 3)
E = batch_inputs['cam_pose'] # (B, 4, 4)
vox_origin = batch_inputs['voxel_origin'] # (B, 3)
vox_size = 0.2
image_shape = batch_inputs['img'].shape[-2:]

pix_coords = generate_grid(image_shape).to(vol) # (2, H, W)
pix_coords = torch.flip(pix_coords, dims=[0])
depth = torch.arange(2, 50, step=1).to(pix_coords) # (D,)
p_x = F.pad(pix_coords, (0, 0, 0, 0, 0, 1), value=1)
p_x = p_x.unsqueeze(-1).repeat(1, 1, 1, depth.size(0)) # (3, H, W, D)
d_ = depth.reshape(1, 1, 1, -1)
p_x = p_x * d_

p_c = K.inverse() @ p_x.flatten(1)
p_w = E.inverse() @ F.pad(p_c, (0, 0, 0, 1), value=1)
p_v = (p_w[:, :-1].transpose(1, 2) - vox_origin.unsqueeze(1)) / vox_size - 0.5
p_v = p_v.reshape(1, *image_shape, depth.size(0), -1) # (1, H, W, D, 3)
p_v = p_v / (torch.tensor(vol.shape[-3:]) - 1).to(p_v)

# vol = 1 - vol.softmax(dim=1)[:, 0].unsqueeze(1) # prob of non-empty
vol = ((vol.int() != 0) & (vol.int() != 255)).to(vol).unsqueeze(1)
sigmas = F.grid_sample(vol, torch.flip(p_v, dims=[-1]) * 2 - 1, padding_mode='border')
T = torch.exp(-torch.cumsum(sigmas * 1, dim=-1))
alpha = 1 - torch.exp(-sigmas * 1)
depth_map = torch.sum(T * alpha * d_.unsqueeze(0), dim=-1)
draw_depth(depth_map, 'rendered_depth.png')
draw_depth(batch_inputs['depth'], 'depth.png')
import pdb; pdb.set_trace()


def draw_depth(depth_map, path):
depth_map = depth_map.squeeze().cpu().numpy()
plt.imshow(depth_map, cmap='jet')
plt.colorbar()
plt.imsave(path, depth_map, cmap='jet')


if __name__ == '__main__':
main()
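
For context on the new `tools/render_depth.py`: it samples the occupancy volume along each camera ray and accumulates an expected depth with an emission-absorption model. Below is a minimal, self-contained sketch of the textbook accumulation it builds on, assuming unit spacing between depth samples; note that the script itself applies an inclusive `torch.cumsum`, so its transmittance also absorbs each sample's own density, whereas the sketch uses the conventional exclusive sum:

```python
import torch


def render_expected_depth(sigmas: torch.Tensor, depths: torch.Tensor) -> torch.Tensor:
    """Emission-absorption expected depth (unit spacing between samples assumed).

    sigmas: (..., D) non-negative density/occupancy at each depth sample
    depths: (D,) metric depth of each sample
    """
    # Transmittance before each sample (exclusive cumulative sum), the probability
    # of the ray terminating at each sample, then the expectation of depth.
    T = torch.exp(-(torch.cumsum(sigmas, dim=-1) - sigmas))
    alpha = 1 - torch.exp(-sigmas)
    return torch.sum(T * alpha * depths, dim=-1)


# A solid wall starting at 10 m, sampled every metre from 2 m to 49 m as in the script.
depths = torch.arange(2, 50, step=1).float()
sigmas = (depths >= 10).float()
print(render_expected_depth(sigmas, depths))  # ~10.6
```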
