Skip to content

Commit

Permalink
add Malleable 2.5D Conv
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesCXK committed Aug 17, 2020
1 parent 087b257 commit 32b3f86
Show file tree
Hide file tree
Showing 17 changed files with 1,760 additions and 26 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -1 +1,5 @@
.DS_Store
__pycache__
log


74 changes: 53 additions & 21 deletions README.md
Expand Up @@ -2,34 +2,54 @@

![license](https://img.shields.io/badge/license-MIT-green) ![PyTorch-1.0.0](https://img.shields.io/badge/PyTorch-1.0.0-blue)

**Official PyTorch implementation of "Bi-directional Cross-Modality Feature Propagation with Separation-and-Aggregation Gate for RGB-D Semantic Segmentation"** ([ECCV, 2020](http://eccv2020.eu/)).
Implement some state-of-the-art methods of RGBD Semantic Segmentation task in PyTorch.

<img src='pic/arch.png'>
Currently, we provide code of:

- **SA-Gate, ECCV 2020** [[arXiv](https://arxiv.org/abs/2007.09183)]
<img src='pic/sagate.png' width="600">
- **Malleable 2.5D Convolution, ECCV 2020** [[arXiv](https://arxiv.org/abs/2007.09365)]
<img src='pic/malleable.png' width="600">



## News

- 2020/08/16

Official code release for the paper **Malleable 2.5D Convolution: Learning Receptive Fields along the Depth-axis for RGB-D Scene Parsing**, *ECCV 2020*. [[arXiv](https://arxiv.org/abs/2007.09365)], [[code](./model/malleable2_5d.nyu.res101)]

Thanks [aurora95](https://github.com/aurora95) for his open source code!

- 2020/07/20

Official code release for the paper **Bi-directional Cross-Modality Feature Propagation with Separation-and-Aggregation Gate for RGB-D Semantic Segmentation**, *ECCV 2020*. [[arXiv](https://arxiv.org/abs/2007.09183)], [[code](./model/SA-Gate.nyu)]


## Main Results

#### Results on NYU Depth V2 Test Set with Multi-scale Inference

| Method | mIoU (%) |
| :--------: | :------: |
| 3DGNN | 43.1 |
| ACNet | 48.3 |
| RDFNet-101 | 49.1 |
| PADNet | 50.2 |
| PAP | 50.4 |
| **Ours** | **52.4** |
| Method | mIoU (%) |
| :----------------: | :------: |
| 3DGNN | 43.1 |
| ACNet | 48.3 |
| RDFNet-101 | 49.1 |
| PADNet | 50.2 |
| PAP | 50.4 |
| **Malleable 2.5D** | **50.9** |
| **SA-Gate** | **52.4** |

#### Results on CityScapes Test Set with Multi-scale Inference (out method uses output stride=16 and does not use coarse-labeled data)

| Method | mIoU (%) |
| :------: | :------: |
| PADNet | 80.3 |
| DANet | 81.5 |
| GALD | 81.8 |
| ACFNet | 81.8 |
| **Ours** | **82.8** |
| Method | mIoU (%) |
| :---------: | :------: |
| PADNet | 80.3 |
| DANet | 81.5 |
| GALD | 81.8 |
| ACFNet | 81.8 |
| **SA-Gate** | **82.8** |

For more details, please refer to our paper.

Expand All @@ -55,7 +75,7 @@ Your directory tree should look like this:
| |-- train.txt
```



## Installation

Expand Down Expand Up @@ -114,6 +134,8 @@ If you want to generate HHA maps from Depth maps, please refer to [https://githu

## Training and Inference

*We just take SA-Gate as an example. You could run other models in a similar way.*

### Training

Training on NYU Depth V2:
Expand Down Expand Up @@ -154,16 +176,26 @@ $ python eval.py -e 300-400 -d 0-7 --save_path results

Please consider citing this project in your publications if it helps your research.

```
@inproceedings{chen2020SAGate,
```tex
@inproceedings{chen2020-SAGate,
title={Bi-directional Cross-Modality Feature Propagation with Separation-and-Aggregation Gate for RGB-D Semantic Segmentation},
author={Chen, Xiaokang and Lin, Kwan-Yee and Wang, Jingbo and Wu, Wayne and Qian, Chen and Li, Hongsheng and Zeng, Gang},
booktitle={European Conference on Computer Vision (ECCV)},
year={2020}
}
```

```tex
@inproceedings{xing2020-melleable,
title={Malleable 2.5D Convolution: Learning Receptive Fields along the Depth-axis for RGB-D Scene Parsing
},
author={Xing, Yajie and Wang, Jingbo and Zeng, Gang},
booktitle={European Conference on Computer Vision (ECCV)},
year={2020}
}
```



## Acknowledgement

Expand Down
180 changes: 180 additions & 0 deletions furnace/seg_opr/conv_2_5d.py
@@ -0,0 +1,180 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

def _ntuple(n):
def parse(x):
if isinstance(x, list) or isinstance(x, tuple):
return x
return tuple([x]*n)
return parse
_pair = _ntuple(2)

class Conv2_5D_Depth(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, pixel_size=1):
super(Conv2_5D_Depth, self).__init__()
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)

self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.kernel_size_prod = self.kernel_size[0]*self.kernel_size[1]
self.stride = stride
self.padding = padding
self.dilation = dilation
self.pixel_size = pixel_size
assert self.kernel_size_prod%2==1

self.weight_0 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
self.weight_1 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
self.weight_2 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)

def forward(self, x, depth, camera_params):
N, C, H, W = x.size(0), x.size(1), x.size(2), x.size(3)
out_H = (H+2*self.padding[0]-self.dilation[0]*(self.kernel_size[0]-1)-1)//self.stride[0]+1
out_W = (W+2*self.padding[1]-self.dilation[1]*(self.kernel_size[1]-1)-1)//self.stride[1]+1

intrinsic = camera_params['intrinsic']
x_col = F.unfold(x, self.kernel_size, dilation=self.dilation, padding=self.padding, stride=self.stride) # N*(C*kh*kw)*(out_H*out_W)
x_col = x_col.view(N, C, self.kernel_size_prod, out_H*out_W)
depth_col = F.unfold(depth, self.kernel_size, dilation=self.dilation, padding=self.padding, stride=self.stride) # N*(kh*kw)*(out_H*out_W)
valid_mask = 1-depth_col.eq(0.).to(torch.float32)

valid_mask = valid_mask*valid_mask[:, self.kernel_size_prod//2, :].view(N,1,out_H*out_W)
depth_col *= valid_mask
valid_mask = valid_mask.view(N,1,self.kernel_size_prod,out_H*out_W)

center_depth = depth_col[:,self.kernel_size_prod//2,:].view(N,1,out_H*out_W)
# grid_range = self.pixel_size * center_depth / (intrinsic['fx'].view(N,1,1) * camera_params['scale'].view(N,1,1))
grid_range = self.pixel_size * self.dilation[0] * center_depth / intrinsic['fx'].view(N,1,1)

mask_0 = torch.abs(depth_col - (center_depth + grid_range)).le(grid_range/2).view(N,1,self.kernel_size_prod,out_H*out_W).to(torch.float32)
mask_1 = torch.abs(depth_col - (center_depth )).le(grid_range/2).view(N,1,self.kernel_size_prod,out_H*out_W).to(torch.float32)
mask_1 = (mask_1 + 1- valid_mask).clamp(min=0., max=1.)
mask_2 = torch.abs(depth_col - (center_depth - grid_range)).le(grid_range/2).view(N,1,self.kernel_size_prod,out_H*out_W).to(torch.float32)
output = torch.matmul(self.weight_0.view(-1,C*self.kernel_size_prod), (x_col*mask_0).view(N, C*self.kernel_size_prod, out_H*out_W))
output += torch.matmul(self.weight_1.view(-1,C*self.kernel_size_prod), (x_col*mask_1).view(N, C*self.kernel_size_prod, out_H*out_W))
output += torch.matmul(self.weight_2.view(-1,C*self.kernel_size_prod), (x_col*mask_2).view(N, C*self.kernel_size_prod, out_H*out_W))
output = output.view(N,-1,out_H,out_W)
if self.bias:
output += self.bias.view(1,-1,1,1)
return output

def extra_repr(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}')
if self.padding != (0,) * len(self.padding):
s += ', padding={padding}'
if self.dilation != (1,) * len(self.dilation):
s += ', dilation={dilation}'
if self.bias is None:
s += ', bias=False'
return s.format(**self.__dict__)


class Malleable_Conv2_5D_Depth(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, pixel_size=1, anchor_init=[-2.,-1.,0.,1.,2.], scale_const=100, fix_center=False, adjust_to_scale=False):
super(Malleable_Conv2_5D_Depth, self).__init__()
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)

self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.kernel_size_prod = self.kernel_size[0]*self.kernel_size[1]
self.stride = stride
self.padding = padding
self.dilation = dilation
self.pixel_size = pixel_size
self.fix_center = fix_center
self.adjust_to_scale = adjust_to_scale
assert self.kernel_size_prod%2==1

self.weight_0 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
self.weight_1 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
self.weight_2 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
self.depth_anchor = Parameter(torch.tensor(anchor_init, requires_grad=True).view(1,5,1,1))
# self.depth_bias = Parameter(torch.tensor([0.,0.,0.,0.,0.], requires_grad=True).view(1,5,1,1))
self.temperature = Parameter(torch.tensor([1.], requires_grad=True))
self.kernel_weight = Parameter(torch.tensor([0.,0.,0.], requires_grad=True))
self.scale_const = scale_const
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)

def forward(self, x, depth, camera_params):
N, C, H, W = x.size(0), x.size(1), x.size(2), x.size(3)
out_H = (H+2*self.padding[0]-self.dilation[0]*(self.kernel_size[0]-1)-1)//self.stride[0]+1
out_W = (W+2*self.padding[1]-self.dilation[1]*(self.kernel_size[1]-1)-1)//self.stride[1]+1

intrinsic = camera_params['intrinsic']
x_col = F.unfold(x, self.kernel_size, dilation=self.dilation, padding=self.padding, stride=self.stride) # N*(C*kh*kw)*(out_H*out_W)
x_col = x_col.view(N, C, self.kernel_size_prod, out_H*out_W)
depth_col = F.unfold(depth, self.kernel_size, dilation=self.dilation, padding=self.padding, stride=self.stride) # N*(kh*kw)*(out_H*out_W)
valid_mask = 1-depth_col.eq(0.).to(torch.float32)

valid_mask = valid_mask*valid_mask[:, self.kernel_size_prod//2, :].view(N,1,out_H*out_W)
depth_col *= valid_mask
valid_mask = valid_mask.view(N,1,self.kernel_size_prod,out_H*out_W)

center_depth = depth_col[:,self.kernel_size_prod//2,:].view(N,1,out_H*out_W)
if self.adjust_to_scale:
grid_range = self.pixel_size * self.dilation[0] * center_depth / (intrinsic['fx'].view(N,1,1) * camera_params['scale'].view(N,1,1))
else:
grid_range = self.pixel_size * self.dilation[0] * center_depth / intrinsic['fx'].view(N,1,1)
depth_diff = (depth_col - center_depth).view(N, 1, self.kernel_size_prod, out_H*out_W) # N*1*(kh*kw)*(out_H*out_W)
relative_diff = depth_diff*self.scale_const/(1e-5 + grid_range.view(N,1,1,out_H*out_W)*self.scale_const)
depth_logit = -( ((relative_diff - self.depth_anchor).pow(2)) / (1e-5 + torch.clamp(self.temperature, min=0.)) ) # N*5*(kh*kw)*(out_H*out_W)
if self.fix_center:
depth_logit[:,2,:,:] = -( ((relative_diff - 0.).pow(2)) / (1e-5 + torch.clamp(self.temperature, min=0.)) ).view(N,self.kernel_size_prod,out_H*out_W)

depth_out_range_0 = (depth_diff<self.depth_anchor[0,0,0,0]).to(torch.float32).view(N,self.kernel_size_prod,out_H*out_W)
depth_out_range_4 = (depth_diff>self.depth_anchor[0,4,0,0]).to(torch.float32).view(N,self.kernel_size_prod,out_H*out_W)
depth_logit[:,0,:,:] = depth_logit[:,0,:,:]*(1 - 2*depth_out_range_0)
depth_logit[:,4,:,:] = depth_logit[:,4,:,:]*(1 - 2*depth_out_range_4)

depth_class = F.softmax(depth_logit, dim=1) # N*5*(kh*kw)*(out_H*out_W)

mask_0 = depth_class[:,1,:,:].view(N,1,self.kernel_size_prod,out_H*out_W).to(torch.float32)
mask_1 = depth_class[:,2,:,:].view(N,1,self.kernel_size_prod,out_H*out_W).to(torch.float32)
mask_2 = depth_class[:,3,:,:].view(N,1,self.kernel_size_prod,out_H*out_W).to(torch.float32)

invalid_mask_bool = valid_mask.eq(0.)

mask_0 = mask_0*valid_mask
mask_1 = mask_1*valid_mask
mask_2 = mask_2*valid_mask
mask_0[invalid_mask_bool] = 1./5.
mask_1[invalid_mask_bool] = 1./5.
mask_2[invalid_mask_bool] = 1./5.

weight = F.softmax(self.kernel_weight, dim=0) * 3 #???
output = torch.matmul(self.weight_0.view(-1,C*self.kernel_size_prod), (x_col*mask_0).view(N, C*self.kernel_size_prod, out_H*out_W)) * weight[0]
output += torch.matmul(self.weight_1.view(-1,C*self.kernel_size_prod), (x_col*mask_1).view(N, C*self.kernel_size_prod, out_H*out_W)) * weight[1]
output += torch.matmul(self.weight_2.view(-1,C*self.kernel_size_prod), (x_col*mask_2).view(N, C*self.kernel_size_prod, out_H*out_W)) * weight[2]
output = output.view(N,-1,out_H,out_W)
if self.bias:
output += self.bias.view(1,-1,1,1)
return output

def extra_repr(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}')
if self.padding != (0,) * len(self.padding):
s += ', padding={padding}'
if self.dilation != (1,) * len(self.dilation):
s += ', dilation={dilation}'
if self.bias is None:
s += ', bias=False'
return s.format(**self.__dict__)
44 changes: 44 additions & 0 deletions furnace/seg_opr/geo_utils.py
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

class Plane2Space(nn.Module):
def __init__(self):
super(Plane2Space, self).__init__()

def forward(self, depth, coordinate, camera_params):
valid_mask = 1-depth.eq(0.).to(torch.float32)

depth = torch.clamp(depth, min=1e-5)
N, H, W = depth.size(0), depth.size(2), depth.size(3)
intrinsic = camera_params['intrinsic']

K_inverse = depth.new_zeros(N, 3, 3)
K_inverse[:,0,0] = 1./intrinsic['fx']
K_inverse[:,1,1] = 1./intrinsic['fy']
K_inverse[:,2,2] = 1.
if 'cx' in intrinsic:
K_inverse[:,0,2] = -intrinsic['cx']/intrinsic['fx']
K_inverse[:,1,2] = -intrinsic['cy']/intrinsic['fy']
elif 'u0' in intrinsic:
K_inverse[:,0,2] = -intrinsic['u0']/intrinsic['fx']
K_inverse[:,1,2] = -intrinsic['v0']/intrinsic['fy']
coord_3d = torch.matmul(K_inverse, (coordinate.float()*depth.float()).view(N,3,H*W)).view(N,3,H,W).contiguous()
coord_3d *= valid_mask

return coord_3d

class Disp2Depth(nn.Module):
def __init__(self, min_disp=0.01, max_disp=256):
self.min_disp = min_disp
self.max_disp = max_disp
super(Disp2Depth, self).__init__()

def forward(self, disp, camera_params):
N = disp.size(0)
intrinsic, extrinsic = camera_params['intrinsic'], camera_params['extrinsic']
valid_mask = 1 - disp.eq(0.).to(torch.float32)
depth = (extrinsic['baseline'] * intrinsic['fx']).view(N, 1, 1, 1).cuda() / torch.clamp(disp, self.min_disp, self.max_disp)
depth *= valid_mask
return depth

0 comments on commit 32b3f86

Please sign in to comment.