Fast and generic implementation using OpenMP and CUDA #45

Open · wants to merge 14 commits into main
README.md: 2 changes (1 addition & 1 deletion)
@@ -6,7 +6,7 @@ By [Duo Li](https://duoli.org/), [Jie Hu](https://github.com/hujie-frank), [Chan

<p align="center"><img src="fig/involution.png" width="500" /></p>

**TL; DR.** `involution` is a general-purpose neural primitive that is versatile for a spectrum of deep learning models on different vision tasks. `involution` bridges `convolution` and `self-attention` in design, while being more efficient and effective than `convolution`, simpler than `self-attention` in form.

<p align="center"><img src="fig/complexity.png" width="400" /><img src="fig/parameter.png" width="400" /></p>

docker/run-docker.sh: 17 changes (17 additions & 0 deletions)
@@ -0,0 +1,17 @@
#!/bin/bash
RUN_DIR=$(dirname $(readlink -f $0))

# Mount the repository root (the parent of this script's directory) into the container.
DOCKER_VOLUME="${DOCKER_VOLUME} -v $(dirname ${RUN_DIR}):/workspace/involution:rw"

docker run \
-it \
--rm \
--gpus '"device=0"' \
${DOCKER_VOLUME} \
--name Involution-PyTorch \
pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel bash
# pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel bash
# pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel bash
# pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel bash
# nvcr.io/nvidia/pytorch:21.05-py3
# nvcr.io/nvidia/pytorch:20.08-py3
include/involution2d_cpu.h: 53 changes (53 additions & 0 deletions)
@@ -0,0 +1,53 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/Parallel.h>

namespace involution {
namespace cpu {

at::Tensor involution2d_forward(
const at::Tensor& input,
const at::Tensor& weight,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
);

at::Tensor involution2d_backward_grad_input(
const at::Tensor& grad,
const at::Tensor& weight,
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
);

at::Tensor involution2d_backward_grad_weight(
const at::Tensor& grad,
const at::Tensor& input,
const std::vector<int64_t>& weight_shape,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
);

std::vector<at::Tensor> involution2d_backward(
const at::Tensor& grad,
const at::Tensor& weight,
const at::Tensor& input,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
);

} // namespace cpu
} // namespace involution
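
The CPU header pulls in ATen/Parallel.h, ATen's parallel-for facility, which dispatches to OpenMP when PyTorch is built with it. Below is a minimal sketch of that pattern; it is illustrative only and is not the PR's actual involution kernel (scale_inplace_sketch and its body are hypothetical, standing in for the per-output-element work).

#include <cstdint>
#include <ATen/ATen.h>
#include <ATen/Parallel.h>

// Split the [0, numel) range across threads; each chunk runs the same serial body.
void scale_inplace_sketch(at::Tensor& out, float alpha) {
    float* data = out.data_ptr<float>();
    const int64_t numel = out.numel();

    at::parallel_for(0, numel, /*grain_size=*/4096, [&](int64_t begin, int64_t end) {
        for (int64_t i = begin; i < end; ++i) {
            data[i] *= alpha;  // stand-in for the per-output involution arithmetic
        }
    });
}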
include/involution2d_cuda.cuh: 59 changes (59 additions & 0 deletions)
@@ -0,0 +1,59 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

namespace involution {
namespace cuda {

#define CUDA_MAX_THREADS 1024u

#define CUDA_KERNEL_LOOP(i, n) \
for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)

at::Tensor involution2d_forward(
const at::Tensor& input,
const at::Tensor& weight,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
);

at::Tensor involution2d_backward_grad_input(
const at::Tensor& grad,
const at::Tensor& weight,
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
);

at::Tensor involution2d_backward_grad_weight(
const at::Tensor& grad,
const at::Tensor& input,
const std::vector<int64_t>& weight_shape,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
);

std::vector<at::Tensor> involution2d_backward(
const at::Tensor& grad,
const at::Tensor& weight,
const at::Tensor& input,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
);

} // namespace cuda
} // namespace involution
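
CUDA_KERNEL_LOOP is a standard grid-stride loop: each thread starts at its global index and advances by the total number of launched threads, so a bounded launch covers any element count n. The sketch below shows how a kernel and its launch might use the macro; axpy_kernel and launch_axpy are hypothetical names for illustration, not the PR's involution kernels.

#include <algorithm>
#include <cstdint>
#include "involution2d_cuda.cuh"

// Illustrative grid-stride kernel; each thread handles indices i, i + stride, i + 2*stride, ...
__global__ void axpy_kernel(const float* x, float* y, float alpha, int64_t n) {
    CUDA_KERNEL_LOOP(i, n) {
        y[i] += alpha * x[i];
    }
}

// Typical launch: cap the block count so a very large n does not exceed grid limits;
// the loop inside the kernel picks up the remaining elements.
void launch_axpy(const float* x, float* y, float alpha, int64_t n, cudaStream_t stream) {
    const unsigned threads = CUDA_MAX_THREADS;
    const unsigned blocks = static_cast<unsigned>(
        std::min<int64_t>((n + threads - 1) / threads, 4096));
    axpy_kernel<<<blocks, threads, 0, stream>>>(x, y, alpha, n);
}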
include/involution2d_wrapper.h: 234 changes (234 additions & 0 deletions)
@@ -0,0 +1,234 @@
#pragma once

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/autocast_mode.h>
#include <torch/csrc/autograd/custom_function.h>

#include "involution2d_cpu.h"

#ifdef USE_CUDA
# include "involution2d_cuda.cuh"
#endif

namespace involution {

at::Tensor involution2d(
const at::Tensor& input,
const at::Tensor& weight,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation
) {
static auto op = at::Dispatcher::singleton()
.findSchemaOrThrow("involution::involution2d", "")
.typed<decltype(involution2d)>();

return op.call(input, weight, stride, padding, dilation);
}

at::Tensor involution2d_autocast(
const at::Tensor& input,
const at::Tensor& weight,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation
) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto exec_type = at::autocast::promote_type(at::kFloat, input, weight);
return involution2d(at::autocast::cached_cast(exec_type, input), at::autocast::cached_cast(exec_type, weight), stride, padding, dilation)
.to(input.scalar_type());
}

at::Tensor _involution2d_backward_grad_input(
const at::Tensor& grad,
const at::Tensor& weight,
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation
) {
static auto op = at::Dispatcher::singleton()
.findSchemaOrThrow("involution2d::_involution2d_backward_grad_input", "")
.typed<decltype(_involution2d_backward_grad_input)>();

return op.call(grad, weight, input_shape, stride, padding, dilation);
}

at::Tensor _involution2d_backward_grad_weight(
const at::Tensor& grad,
const at::Tensor& input,
const std::vector<int64_t>& weight_shape,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation
) {
static auto op = at::Dispatcher::singleton()
.findSchemaOrThrow("involution2d::_involution2d_backward_grad_weight", "")
.typed<decltype(_involution2d_backward_grad_weight)>();

return op.call(grad, input, weight_shape, stride, padding, dilation);
}

namespace cpu {

class Involution2dFunctionCPU : public torch::autograd::Function<Involution2dFunctionCPU>
{
public:

static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
) {
ctx->saved_data["kernel_size"] = kernel_size;
ctx->saved_data["stride"] = stride;
ctx->saved_data["padding"] = padding;
ctx->saved_data["dilation"] = dilation;
ctx->saved_data["groups"] = groups;
ctx->save_for_backward({input, weight});

auto output = involution2d_forward(input, weight, kernel_size, stride, padding, dilation, groups);

return {output};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list grad_output
) {
torch::autograd::variable_list saved = ctx->get_saved_variables();
torch::autograd::Variable input = saved[0];
torch::autograd::Variable weight = saved[1];

auto kernel_size = ctx->saved_data["kernel_size"].toIntVector();
auto stride = ctx->saved_data["stride"].toIntVector();
auto padding = ctx->saved_data["padding"].toIntVector();
auto dilation = ctx->saved_data["dilation"].toIntVector();
auto groups = ctx->saved_data["groups"].toInt();

auto grads = involution2d_backward(grad_output[0], weight, input, kernel_size, stride, padding, dilation, groups);

return {
grads[0],
grads[1],
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()
};
}
};

at::Tensor involution2d_autograd(
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
) {
return Involution2dFunctionCPU::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0];
}

} // namespace cpu

#ifdef USE_CUDA

namespace cuda {

class Involution2dFunctionCUDA : public torch::autograd::Function<Involution2dFunctionCUDA>
{
public:

static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
) {
ctx->saved_data["kernel_size"] = kernel_size;
ctx->saved_data["stride"] = stride;
ctx->saved_data["padding"] = padding;
ctx->saved_data["dilation"] = dilation;
ctx->saved_data["groups"] = groups;
ctx->save_for_backward({input, weight});

auto output = involution2d_forward(input, weight, kernel_size, stride, padding, dilation, groups);

return {output};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list grad_output
) {
torch::autograd::variable_list saved = ctx->get_saved_variables();
torch::autograd::Variable input = saved[0];
torch::autograd::Variable weight = saved[1];

auto kernel_size = ctx->saved_data["kernel_size"].toIntVector();
auto stride = ctx->saved_data["stride"].toIntVector();
auto padding = ctx->saved_data["padding"].toIntVector();
auto dilation = ctx->saved_data["dilation"].toIntVector();
auto groups = ctx->saved_data["groups"].toInt();

auto grads = involution2d_backward(grad_output[0], weight, input, kernel_size, stride, padding, dilation, groups);

return {
grads[0],
grads[1],
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()
};
}
};

at::Tensor involution2d_autograd(
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
) {
return Involution2dFunctionCUDA::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0];
}

at::Tensor involution2d_autocast(
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding,
const std::vector<int64_t>& dilation,
const int64_t groups
) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto exec_type = at::autocast::promote_type(at::kFloat, input, weight);
return involution2d_autograd(
at::autocast::cached_cast(exec_type, input),
at::autocast::cached_cast(exec_type, weight),
kernel_size, stride, padding, dilation, groups
);
}

Comment on lines +212 to +229

@csvance
Fixed CUDA implementation input to be full precision using Autocast.

} // namespace cuda

#endif

} // namespace involution
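
The wrapper resolves "involution::involution2d" (and the backward ops) through the dispatcher, so a separate translation unit in the PR presumably registers the schema and per-backend kernels. The sketch below is hypothetical and only shows the general TORCH_LIBRARY / TORCH_LIBRARY_IMPL shape; the PR's actual registration file is not part of this excerpt, and its schema may differ (the five-argument involution2d helper above, for instance, suggests a signature without explicit kernel_size and groups).

#include <torch/library.h>
#include "involution2d_wrapper.h"

// Hypothetical schema following the seven-argument autograd wrappers.
TORCH_LIBRARY(involution, m) {
    m.def("involution2d(Tensor input, Tensor weight, int[] kernel_size, int[] stride, "
          "int[] padding, int[] dilation, int groups) -> Tensor");
}

// Route autograd-enabled calls to the backend-specific autograd wrappers.
TORCH_LIBRARY_IMPL(involution, AutogradCPU, m) {
    m.impl("involution2d", involution::cpu::involution2d_autograd);
}

#ifdef USE_CUDA
TORCH_LIBRARY_IMPL(involution, AutogradCUDA, m) {
    m.impl("involution2d", involution::cuda::involution2d_autograd);
}

// Mixed-precision calls are intercepted here and promoted to a common dtype.
TORCH_LIBRARY_IMPL(involution, Autocast, m) {
    m.impl("involution2d", involution::cuda::involution2d_autocast);
}
#endif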
involution/__init__.py: 9 changes (9 additions & 0 deletions)
@@ -0,0 +1,9 @@
from glob import glob
import os

from torch import ops

# Locate the compiled extension (involution.*.so) in the repository root and register
# its custom operators with torch.ops before importing the Python module wrapper.
_LIB_PATH = glob(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'involution.*.so'))[0]
ops.load_library(_LIB_PATH)

from .involution2d import Involution2d