Skip to content

Commit

Permalink
Fix pytorch1.5 compile error (#2524)
Browse files Browse the repository at this point in the history
* replace AT_CHECK with TORCH_CHECK to support pytorch1.5

* fix compile warning, delete grid sample op

* fix ci pytorch version

* delete python3.5

* update version of pytorch and torchvision in ci

* change is_cuda to device().is_cuda()

* specify TORCH_CUDA_ARCH_LIST=6.0,7.0

* fix typo

* delete affine grid

* fix typos

* fix setup.py

* remove redundant comment
  • Loading branch information
yhcao6 committed Apr 24, 2020
1 parent 2e802ce commit e5921ca
Show file tree
Hide file tree
Showing 41 changed files with 215 additions and 2,510 deletions.
16 changes: 12 additions & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@ dist: bionic # ubuntu 18.04
language: python

python:
- "3.5"
- "3.6"
- "3.7"

env: CUDA=10.1.105-1 CUDA_SHORT=10.1 UBUNTU_VERSION=ubuntu1804 FORCE_CUDA=1
env:
global:
- CUDA=10.1.105-1
- CUDA_SHORT=10.1
- UBUNTU_VERSION=ubuntu1804
- FORCE_CUDA=1
matrix:
- TORCH=1.3.1 TORCHVISION=0.4.2 CUDA_ARCH=6.0
- TORCH=1.5.0 TORCHVISION=0.6.0 CUDA_ARCH=7.0

cache: pip

# Ref to CUDA installation in Travis: https://github.com/jeremad/cuda-travis
Expand All @@ -25,7 +33,7 @@ before_install:

install:
- pip install Pillow==6.2.2 # remove this line when torchvision>=0.5
- pip install torch==1.2 torchvision==0.4.0 # TODO: fix CI for pytorch>1.2
- pip install torch==${TORCH} torchvision==${TORCHVISION}
- pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
- pip install -r requirements.txt

Expand All @@ -36,7 +44,7 @@ before_script:

script:
- python setup.py check -m -s
- python setup.py build_ext --inplace
- TORCH_CUDA_ARCH_LIST="${CUDA_ARCH}" python setup.py build_ext --inplace
- coverage run --branch --source mmdet -m py.test -v --xdoctest-modules tests mmdet

after_success:
Expand Down
4 changes: 2 additions & 2 deletions mmdet/models/mask_heads/fcn_mask_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import pycocotools.mask as mask_util
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair

from mmdet.core import auto_fp16, force_fp32, mask_target
from mmdet.ops import Conv2d, ConvModule, build_upsample_layer
from mmdet.ops.carafe import CARAFEPack
from mmdet.ops.grid_sampler import grid_sample
from ..builder import HEADS, build_loss

BYTES_PER_FLOAT = 4
Expand Down Expand Up @@ -302,7 +302,7 @@ def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
grid = torch.stack([gx, gy], dim=3)

img_masks = grid_sample(
img_masks = F.grid_sample(
masks.to(dtype=torch.float32), grid, align_corners=False)

if skip_empty:
Expand Down
3 changes: 0 additions & 3 deletions mmdet/ops/affine_grid/__init__.py

This file was deleted.

68 changes: 0 additions & 68 deletions mmdet/ops/affine_grid/affine_grid.py

This file was deleted.

23 changes: 0 additions & 23 deletions mmdet/ops/affine_grid/src/affine_grid_ext.cpp

This file was deleted.

108 changes: 0 additions & 108 deletions mmdet/ops/affine_grid/src/cpu/affine_grid_cpu.cpp

This file was deleted.

4 changes: 2 additions & 2 deletions mmdet/ops/carafe/src/carafe_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ int carafe_forward(at::Tensor features, at::Tensor rfeatures,
at::Tensor masks, at::Tensor rmasks, int kernel_size,
int group_size, int scale_factor, at::Tensor routput,
at::Tensor output) {
if (features.type().is_cuda()) {
if (features.device().is_cuda()) {
#ifdef WITH_CUDA
return carafe_forward_cuda(features, rfeatures, masks, rmasks, kernel_size,
group_size, scale_factor, routput, output);
Expand All @@ -39,7 +39,7 @@ int carafe_backward(at::Tensor top_grad, at::Tensor rfeatures,
at::Tensor rbottom_grad_hs, at::Tensor rbottom_grad,
at::Tensor rmask_grad, at::Tensor bottom_grad,
at::Tensor mask_grad) {
if (top_grad.type().is_cuda()) {
if (top_grad.device().is_cuda()) {
#ifdef WITH_CUDA
return carafe_backward_cuda(top_grad, rfeatures, masks, kernel_size,
group_size, scale_factor, rtop_grad, rbottom_grad_hs, rbottom_grad,
Expand Down
4 changes: 2 additions & 2 deletions mmdet/ops/carafe/src/carafe_naive_ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ int carafe_naive_backward_cuda(at::Tensor top_grad, at::Tensor features,
int carafe_naive_forward(at::Tensor features, at::Tensor masks,
int kernel_size, int group_size, int scale_factor,
at::Tensor output) {
if (features.type().is_cuda()) {
if (features.device().is_cuda()) {
#ifdef WITH_CUDA
return carafe_naive_forward_cuda(features, masks, kernel_size,
group_size, scale_factor, output);
Expand All @@ -33,7 +33,7 @@ int carafe_naive_backward(at::Tensor top_grad, at::Tensor features,
at::Tensor masks, int kernel_size,
int group_size, int scale_factor,
at::Tensor bottom_grad, at::Tensor mask_grad) {
if (top_grad.type().is_cuda()) {
if (top_grad.device().is_cuda()) {
#ifdef WITH_CUDA
return carafe_naive_backward_cuda(top_grad, features, masks, kernel_size,
group_size, scale_factor, bottom_grad, mask_grad);
Expand Down
4 changes: 2 additions & 2 deletions mmdet/ops/carafe/src/cuda/carafe_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ int CARAFEBackwardLaucher(const at::Tensor top_grad, const at::Tensor rfeatures,
at::Tensor rmask_grad, at::Tensor bottom_grad,
at::Tensor mask_grad);

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
Expand Down
Loading

0 comments on commit e5921ca

Please sign in to comment.