diff --git a/README.md b/README.md
index e799f74..68d94bd 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ By [Duo Li](https://duoli.org/), [Jie Hu](https://github.com/hujie-frank), [Chan

-**TL; DR.** `involution` is a general-purpose neural primitive that is versatile for a spectrum of deep learning models on different vision tasks. `involution` bridges `convolution` and `self-attention` in design, while being more efficient and effective than `convolution`, simpler than `self-attention` in form.
+**TL; DR.** `involution` is a general-purpose neural primitive that is versatile for a spectrum of deep learning models on different vision tasks. `involution` bridges `convolution` and `self-attention` in design, while being more efficient and effective than `convolution`, simpler than `self-attention` in form.

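To make the TL;DR above concrete, here is a minimal pure-PyTorch sketch of the involution operation, separate from the CPU/CUDA kernels added in this diff. It assumes the per-pixel kernel tensor has already been generated from the input (in this repo that is done by the reduce/span mappings); `involution2d_reference` is a hypothetical name used only for illustration.

```python
import torch
import torch.nn.functional as F

def involution2d_reference(x, kernel, kernel_size=7, stride=1, padding=3, dilation=1, groups=1):
    # x:      (B, C, H, W) input features
    # kernel: (B, groups * K * K, H_out, W_out) per-pixel kernels generated from x
    B, C, H, W = x.shape
    K = kernel_size
    # Gather the K*K neighbourhood of every output location: (B, C*K*K, H_out*W_out)
    patches = F.unfold(x, K, dilation=dilation, padding=padding, stride=stride)
    H_out = (H + 2 * padding - dilation * (K - 1) - 1) // stride + 1
    W_out = (W + 2 * padding - dilation * (K - 1) - 1) // stride + 1
    patches = patches.view(B, groups, C // groups, K * K, H_out, W_out)
    kernel = kernel.view(B, groups, 1, K * K, H_out, W_out)
    # The spatially varying, channel-shared kernel weights the neighbourhood;
    # summing over the K*K positions gives the involution output.
    out = (kernel * patches).sum(dim=3)
    return out.view(B, C, H_out, W_out)
```

The extension below implements the same computation with dedicated forward/backward kernels instead of `unfold`, which avoids materializing the `(B, C*K*K, H_out*W_out)` patch tensor.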
diff --git a/docker/run-docker.sh b/docker/run-docker.sh new file mode 100755 index 0000000..e30c835 --- /dev/null +++ b/docker/run-docker.sh @@ -0,0 +1,17 @@ +#!/bin/bash +RUN_DIR=$(dirname $(readlink -f $0)) + +DOCKER_VOLUME="${DOCKER_VOLUME} -v $(dirname ${RUN_DIR}):/workspace/involution:rw" + +docker run \ + -it \ + --rm \ + --gpus '"device=0"' \ + ${DOCKER_VOLUME} \ + --name Involution-PyTorch \ + pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel bash + # pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel bash + # pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel bash + # pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel bash + # nvcr.io/nvidia/pytorch:21.05-py3 + # nvcr.io/nvidia/pytorch:20.08-py3 diff --git a/include/involution2d_cpu.h b/include/involution2d_cpu.h new file mode 100644 index 0000000..a1e8851 --- /dev/null +++ b/include/involution2d_cpu.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +namespace involution { +namespace cpu { + +at::Tensor involution2d_forward( + const at::Tensor& input, + const at::Tensor& weight, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +); + +at::Tensor involution2d_backward_grad_input( + const at::Tensor& grad, + const at::Tensor& weight, + const std::vector& input_shape, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +); + +at::Tensor involution2d_backward_grad_weight( + const at::Tensor& grad, + const at::Tensor& input, + const std::vector& weight_shape, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +); + +std::vector involution2d_backward( + const at::Tensor& grad, + const at::Tensor& weight, + const at::Tensor& input, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +); + +} // namespace cpu +} // namespace involution diff --git a/include/involution2d_cuda.cuh b/include/involution2d_cuda.cuh new file mode 100644 index 0000000..dfc9425 --- /dev/null +++ b/include/involution2d_cuda.cuh @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include + +namespace involution { +namespace cuda { + +#define CUDA_MAX_THREADS 1024u + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) + +at::Tensor involution2d_forward( + const at::Tensor& input, + const at::Tensor& weight, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +); + +at::Tensor involution2d_backward_grad_input( + const at::Tensor& grad, + const at::Tensor& weight, + const std::vector& input_shape, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +); + +at::Tensor involution2d_backward_grad_weight( + const at::Tensor& grad, + const at::Tensor& input, + const std::vector& weight_shape, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +); + +std::vector involution2d_backward( + const at::Tensor& grad, + const at::Tensor& weight, + const at::Tensor& input, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + 
const std::vector& dilation, + const int64_t groups +); + +} // namespace cuda +} // namespace involution diff --git a/include/involution2d_wrapper.h b/include/involution2d_wrapper.h new file mode 100644 index 0000000..12f4295 --- /dev/null +++ b/include/involution2d_wrapper.h @@ -0,0 +1,234 @@ +#pragma once + +#include +#include +#include + +#include "involution2d_cpu.h" + +#ifdef USE_CUDA +# include "involution2d_cuda.cuh" +#endif + +namespace involution { + +at::Tensor involution2d( + const at::Tensor& input, + const at::Tensor& weight, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation +) { + static auto op = at::Dispatcher::singleton() + .findSchemaOrThrow("involution::involution2d", "") + .typed(); + + return op.call(input, weight, stride, padding, dilation); +} + +at::Tensor involution2d_autocast( + const at::Tensor& input, + const at::Tensor& weight, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation +) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + auto exec_type = at::autocast::promote_type(at::kFloat, input, weight); + return involution2d(at::autocast::cached_cast(exec_type, input), at::autocast::cached_cast(exec_type, weight), stride, padding, dilation) + .to(input.scalar_type()); +} + +at::Tensor _involution2d_backward_grad_input( + const at::Tensor& grad, + const at::Tensor& weight, + const std::vector& input_shape, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation +) { + static auto op = at::Dispatcher::singleton() + .findSchemaOrThrow("involution2d::_involution2d_backward_grad_input", "") + .typed(); + + return op.call(grad, weight, input_shape, stride, padding, dilation); +} + +at::Tensor _involution2d_backward_grad_weight( + const at::Tensor& grad, + const at::Tensor& input, + const std::vector& weight_shape, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation +) { + static auto op = at::Dispatcher::singleton() + .findSchemaOrThrow("involution2d::_involution2d_backward_grad_weight", "") + .typed(); + + return op.call(grad, input, weight_shape, stride, padding, dilation); +} + +namespace cpu { + +class Involution2dFunctionCPU : public torch::autograd::Function +{ + public: + + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& weight, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups + ) { + ctx->saved_data["kernel_size"] = kernel_size; + ctx->saved_data["stride"] = stride; + ctx->saved_data["padding"] = padding; + ctx->saved_data["dilation"] = dilation; + ctx->saved_data["groups"] = groups; + ctx->save_for_backward({input, weight}); + + auto output = involution2d_forward(input, weight, kernel_size, stride, padding, dilation, groups); + + return {output}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list grad_output + ) { + torch::autograd::variable_list saved = ctx->get_saved_variables(); + torch::autograd::Variable input = saved[0]; + torch::autograd::Variable weight = saved[1]; + + auto kernel_size = ctx->saved_data["kernel_size"].toIntVector(); + auto stride = ctx->saved_data["stride"].toIntVector(); + auto padding = ctx->saved_data["padding"].toIntVector(); + auto dilation = 
ctx->saved_data["dilation"].toIntVector(); + auto groups = ctx->saved_data["groups"].toInt(); + + auto grads = involution2d_backward(grad_output[0], weight, input, kernel_size, stride, padding, dilation, groups); + + return { + grads[0], + grads[1], + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable() + }; + } +}; + +at::Tensor involution2d_autograd( + const torch::autograd::Variable& input, + const torch::autograd::Variable& weight, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +) { + return Involution2dFunctionCPU::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0]; +} + +} // namespace cpu + +#ifdef USE_CUDA + +namespace cuda { + +class Involution2dFunctionCUDA : public torch::autograd::Function +{ + public: + + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& weight, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups + ) { + ctx->saved_data["kernel_size"] = kernel_size; + ctx->saved_data["stride"] = stride; + ctx->saved_data["padding"] = padding; + ctx->saved_data["dilation"] = dilation; + ctx->saved_data["groups"] = groups; + ctx->save_for_backward({input, weight}); + + auto output = involution2d_forward(input, weight, kernel_size, stride, padding, dilation, groups); + + return {output}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::variable_list grad_output + ) { + torch::autograd::variable_list saved = ctx->get_saved_variables(); + torch::autograd::Variable input = saved[0]; + torch::autograd::Variable weight = saved[1]; + + auto kernel_size = ctx->saved_data["kernel_size"].toIntVector(); + auto stride = ctx->saved_data["stride"].toIntVector(); + auto padding = ctx->saved_data["padding"].toIntVector(); + auto dilation = ctx->saved_data["dilation"].toIntVector(); + auto groups = ctx->saved_data["groups"].toInt(); + + auto grads = involution2d_backward(grad_output[0], weight, input, kernel_size, stride, padding, dilation, groups); + + return { + grads[0], + grads[1], + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable(), + torch::autograd::Variable() + }; + } +}; + +at::Tensor involution2d_autograd( + const torch::autograd::Variable& input, + const torch::autograd::Variable& weight, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +) { + return Involution2dFunctionCUDA::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0]; +} + +at::Tensor involution2d_autocast( + const torch::autograd::Variable& input, + const torch::autograd::Variable& weight, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + auto exec_type = at::autocast::promote_type(at::kFloat, input, weight); + return involution2d_autograd( + at::autocast::cached_cast(exec_type, input), + at::autocast::cached_cast(exec_type, weight), + kernel_size, stride, padding, 
dilation, groups + ); +} + +} // namespace cuda + +#endif + +} // namespace involution diff --git a/involution/__init__.py b/involution/__init__.py new file mode 100644 index 0000000..7839be6 --- /dev/null +++ b/involution/__init__.py @@ -0,0 +1,9 @@ +from glob import glob +import os + +from torch import ops + +_LIB_PATH = glob(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'involution.*.so'))[0] +ops.load_library(_LIB_PATH) + +from .involution2d import Involution2d diff --git a/involution/involution2d.py b/involution/involution2d.py new file mode 100644 index 0000000..bdf41ec --- /dev/null +++ b/involution/involution2d.py @@ -0,0 +1,123 @@ +from typing import Optional, Tuple, Union +import torch +import torch.nn as nn +from torch.nn.modules.utils import _pair +from torch import ops + +def _involution2d( + input: torch.Tensor, + weight: torch.Tensor, + kernel_size: Union[int, Tuple[int, int]] = 7, + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: torch.Tensor = None, + ) -> torch.Tensor: + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = _pair(padding) + dilation_ = _pair(dilation) + + output: torch.Tensor = ops.involution.involution2d(input, weight, kernel_size_, stride_, padding_, dilation_, groups) + + if bias is not None: + output += bias.view(1, -1, 1, 1) + + return output + +class Involution2d(nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]] = 7, + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 3, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = False, + sigma_mapping: Optional[nn.Module] = None, + reduce_ratio: int = 1, + ) -> None: + """2D Involution: https://arxiv.org/pdf/2103.06255.pdf + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + kernel_size (Union[int, Tuple[int, int]], optional): Kernel size to be used. Defaults to 7. + stride (Union[int, Tuple[int, int]], optional): Stride factor to be utilized. Defaults to 1. + padding (Union[int, Tuple[int, int]], optional): Padding to be used in unfold operation. Defaults to 3. + dilation (Union[int, Tuple[int, int]], optional): Dilation in unfold to be employed. Defaults to 1. + groups (int, optional): Number of groups to be employed. Defaults to 1. + bias (bool, optional): If true bias is utilized in each convolution layer. Defaults to False. + sigma_mapping (Optional[nn.Module], optional): Non-linear mapping as introduced in the paper. If none BN + ReLU is utilized + reduce_ratio (int, optional): Reduce ration of involution channels. Defaults to 1. + """ + super(Involution2d, self).__init__() + + assert isinstance(in_channels, int) and in_channels > 0, \ + '"in_channels" must be a positive integer.' + assert isinstance(out_channels, int) and out_channels > 0, \ + '"out_channels" must be a positive integer.' + assert isinstance(kernel_size, (int, tuple)), \ + '"kernel_size" must be an int or a tuple of ints.' + assert isinstance(stride, (int, tuple)), \ + '"stride" must be an int or a tuple of ints.' + assert isinstance(padding, (int, tuple)), \ + '"padding" must be an int or a tuple of ints.' + assert isinstance(dilation, (int, tuple)), \ + '"dilation" must be an int or a tuple of ints.' + assert isinstance(groups, int) and groups > 0, \ + '"groups" must be a positive integer.' 
+ assert in_channels % groups == 0, '"in_channels" must be divisible by "groups".' + assert out_channels % groups == 0, '"out_channels" must be divisible by "groups".' + assert isinstance(bias, bool), '"bias" must be a bool.' + assert isinstance(sigma_mapping, nn.Module) or sigma_mapping is None, \ + '"sigma_mapping" muse be an int or a tuple of ints.' + assert isinstance(reduce_ratio, int) and reduce_ratio > 0, \ + '"reduce_ratio" must be a positive integer.' + + self.in_channels: int = in_channels + self.out_channels: int = out_channels + self.kernel_size: Tuple[int, int] = _pair(kernel_size) + self.stride: Tuple[int, int] = _pair(stride) + self.padding: Tuple[int, int] = _pair(padding) + self.dilation: Tuple[int, int] = _pair(dilation) + self.groups: int = groups + self.bias: bool = bias + self.reduce_ratio: int = reduce_ratio + + self.sigma_mapping = sigma_mapping if isinstance(sigma_mapping, nn.Module) else nn.Sequential( + nn.BatchNorm2d(num_features=self.out_channels // + self.reduce_ratio, momentum=0.3), + nn.ReLU() + ) + self.initial_mapping = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=1, bias=bias) \ + if self.in_channels != self.out_channels else nn.Identity() + self.o_mapping = nn.AvgPool2d( + kernel_size=self.stride) if self.stride[0] > 1 or self.stride[1] > 1 else nn.Identity() + self.reduce_mapping = nn.Conv2d( + in_channels=self.in_channels, out_channels=self.out_channels // self.reduce_ratio, kernel_size=1, bias=bias) + self.span_mapping = nn.Conv2d(in_channels=self.out_channels // self.reduce_ratio, + out_channels=self.kernel_size[0] * self.kernel_size[1] * self.groups, kernel_size=1, bias=bias) + + def __repr__(self) -> str: + """Method returns information about the module + Returns: + str: Info string + """ + return (f'{self.__class__.__name__}({self.in_channels}, {self.out_channels}, kernel_size=({self.kernel_size[0]}, {self.kernel_size[1]}), ' + f'stride=({self.stride[0]}, {self.stride[1]}), padding=({self.padding[0]}, {self.padding[1]}), dilation=({self.dilation[0], self.dilation[1]}), ' + f'groups={self.groups}, bias={self.bias}, reduce_ratio={self.reduce_ratio}, sigma_mapping={str(self.sigma_mapping)}' + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """Forward pass + Args: + input (torch.Tensor): Input tensor of the shape [batch size, in channels, height, width] + Returns: + torch.Tensor: Output tensor of the shape [batch size, out channels, height, width] (w/ same padding) + """ + weight: torch.Tensor = self.span_mapping(self.sigma_mapping(self.reduce_mapping(self.o_mapping(input)))) + input_init: torch.Tensor = self.initial_mapping(input) + + return _involution2d(input_init, weight, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..5207016 --- /dev/null +++ b/setup.py @@ -0,0 +1,70 @@ +import os +from os.path import abspath, dirname, join +from setuptools import setup, find_packages +from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension + +INCLUDE_DIR = join(dirname(abspath(__file__)), 'include') +EXTRA_COMPILE_ARGS = ['-O3'] + +EXTENSION = [] + +CC = ['52', '53', '60', '61', '62', '70', '72', '75', '80'] + +if os.getenv('USE_OPENMP', '1') == '1': + EXTRA_COMPILE_ARGS.append('-fopenmp') + +if os.getenv('USE_CUDA', '1') == '1': + EXTRA_COMPILE_ARGS.append('-DUSE_CUDA') + + GENERATE_CODES = [] + + for cc in CC: + 
GENERATE_CODES.append('--generate-code') + GENERATE_CODES.append(f'arch=compute_{cc},code=compute_{cc}') + + EXTENSION.append( + CUDAExtension( + name='involution', + sources=[ + 'src/involution2d_cpu.cpp', + 'src/involution2d_cuda.cu', + 'src/pytorch_wrapper.cpp', + ], + include_dirs=[ + INCLUDE_DIR + ], + extra_compile_args={ + 'cxx': EXTRA_COMPILE_ARGS, + 'nvcc': ['-O3'] + GENERATE_CODES, + } + ) + ) +else: + EXTENSION.append( + CppExtension( + name='involution', + sources=[ + 'src/involution2d_cpu.cpp', + 'src/pytorch_wrapper.cpp', + ], + include_dirs=[ + INCLUDE_DIR + ], + extra_compile_args=EXTRA_COMPILE_ARGS + ) + ) + +setup( + name='involution-pytorch', + version="0.1.0", + url="https://github.com/shikishima-TasakiLab/Involution-PyTorch", + license="MIT License", + author="Junya Shikishima", + author_email="160442065@ccalumni.meijo-u.ac.jp", + description="PyTorch Involution", + packages=find_packages(), + ext_modules=EXTENSION, + cmdclass={ + 'build_ext': BuildExtension, + } +) diff --git a/src/involution2d_cpu.cpp b/src/involution2d_cpu.cpp new file mode 100644 index 0000000..d9af583 --- /dev/null +++ b/src/involution2d_cpu.cpp @@ -0,0 +1,359 @@ +#include "involution2d_cpu.h" + +namespace involution { +namespace cpu { + +template +static void involution2d_forward_frame( + const at::Tensor& in_data, + const at::Tensor& weight_data, + at::Tensor& out_data, + const at::IntArrayRef& kernel_size, + const at::IntArrayRef& padding, + const at::IntArrayRef& stride, + const at::IntArrayRef& dilation +) { + auto num_elements = out_data.numel(); + const auto groups = weight_data.size(1); + const auto channels = in_data.size(1); + const auto in_height = in_data.size(2); + const auto in_width = in_data.size(3); + const auto out_height = out_data.size(2); + const auto out_width = out_data.size(3); + + auto in_data_a = in_data.accessor(); + auto weight_data_a = weight_data.accessor(); + auto* out_data_p = out_data.data_ptr(); + + #pragma omp parallel for + for (int64_t idx = 0l; idx < num_elements; idx++) { + const int64_t w = idx % out_width; + const int64_t h = (idx / out_width) % out_height; + int64_t divisor = out_width * out_height; + const int64_t c = (idx / divisor) % channels; + divisor *= channels; + const int64_t n = idx / divisor; + const int64_t g = c / (channels / groups); + + scalar_t value = 0; + + for (int64_t kh = 0l; kh < kernel_size[0]; kh++) { + const int64_t h_in = h * stride[0] + kh * dilation[0] - padding[0]; + + if ((0l <= h_in) && (h_in < in_height)) { + for (int64_t kw = 0l; kw < kernel_size[1]; kw++) { + const int64_t w_in = w * stride[1] + kw * dilation[1] - padding[1]; + + if ((0l <= w_in) && (w_in < in_width)) { + value += weight_data_a[n][g][kh][kw][h][w] * in_data_a[n][c][h_in][w_in]; + } + } + } + } + out_data_p[idx] = value; + } +} + +at::Tensor involution2d_forward( + const at::Tensor& input, + const at::Tensor& weight, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +) { + AT_ASSERTM(input.device().is_cpu(), "\"input\" must be a CPU tensor."); + AT_ASSERTM(weight.device().is_cpu(), "\"weight\" must be a CPU tensor."); + + at::TensorArg input_t{input, "input", 1}, weight_t{weight, "weight", 2}; + + at::CheckedFrom c = __func__; + at::checkAllSameType(c, {input_t, weight_t}); + + const auto batch_size = input.size(0); + const auto channels = input.size(1); + const auto in_height = input.size(2); + const auto in_width = input.size(3); + + const auto weight_height = 
weight.size(2); + const auto weight_width = weight.size(3); + + const at::Tensor weight_ = weight.view({batch_size, groups, kernel_size[0], kernel_size[1], weight_height, weight_width}); + + const auto out_height = (in_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1; + const auto out_width = (in_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1; + + at::Tensor output = at::zeros({batch_size, channels, out_height, out_width}, input.options()); + + if (output.numel() == 0) { + return output; + } + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kHalf, + at::kBFloat16, + input.scalar_type(), + "involution2d_forward_kernel", [&] { + involution2d_forward_frame( + input, + weight_, + output, + kernel_size, + padding, + stride, + dilation + ); + } + ); + return output; +} + +template +static void involution2d_backward_grad_input_frame( + const at::Tensor& out_diff, + const at::Tensor& weight_data, + at::Tensor& in_diff, + const at::IntArrayRef& kernel_size, + const at::IntArrayRef& padding, + const at::IntArrayRef& stride, + const at::IntArrayRef& dilation +) { + auto num_elements = in_diff.numel(); + const auto groups = weight_data.size(1); + const auto channels = in_diff.size(1); + const auto in_height = in_diff.size(2); + const auto in_width = in_diff.size(3); + const auto out_height = out_diff.size(2); + const auto out_width = out_diff.size(3); + + auto out_diff_a = out_diff.accessor(); + auto weight_data_a = weight_data.accessor(); + auto* in_diff_p = in_diff.data_ptr(); + + #pragma omp parallel for + for (int64_t idx = 0l; idx < num_elements; idx++) { + const int64_t w = idx % in_width; + const int64_t h = (idx / in_width) % in_height; + int64_t divisor = in_width * in_height; + const int64_t c = (idx / divisor) % channels; + divisor *= channels; + const int64_t n = idx / divisor; + const int64_t g = c / (channels / groups); + + scalar_t value = 0; + + for (int64_t kh = 0l; kh < kernel_size[0]; kh++) { + const int64_t h_out_s = h + padding[0] - kh * dilation[0]; + + for (int64_t kw = 0l; kw < kernel_size[1]; kw++) { + const int64_t w_out_s = w + padding[1] - kw * dilation[1]; + + if (((h_out_s % stride[0]) == 0) && ((w_out_s % stride[1]) == 0)) { + const int64_t h_out = h_out_s / stride[0]; + const int64_t w_out = h_out_s / stride[1]; + + if ((0l <= h_out) && (h_out < out_height) && (0l <= w_out) && (w_out < out_width)) { + value += weight_data_a[n][g][kh][kw][h_out][w_out] * out_diff_a[n][c][h_out][w_out]; + } + } + } + } + in_diff_p[idx] = value; + } +} + +at::Tensor involution2d_backward_grad_input( + const at::Tensor& grad, + const at::Tensor& weight, + const std::vector& input_shape, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +) { + AT_ASSERTM(grad.device().is_cpu(), "\"grad\" must be a CPU tensor."); + AT_ASSERTM(weight.device().is_cpu(), "\"weight\" must be a CPU tensor."); + + at::TensorArg grad_t{grad, "grad", 1}, weight_t{weight, "weight", 2}; + + at::CheckedFrom c = __func__; + at::checkAllSameType(c, {grad_t, weight_t}); + + const auto batch_size = input_shape[0]; + + const auto weight_height = weight.size(2); + const auto weight_width = weight.size(3); + + const at::Tensor weight_ = weight.view({batch_size, groups, kernel_size[0], kernel_size[1], weight_height, weight_width}); + + at::Tensor grad_input = at::zeros(input_shape, grad.options()); + + if (grad_input.numel() == 0) { + return grad_input; + } + + 
AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, grad.scalar_type(), "involution2d_backward_grad_input_frame", [&] { + involution2d_backward_grad_input_frame( + grad, + weight_, + grad_input, + kernel_size, + padding, + stride, + dilation + ); + }); + + return grad_input; +} + +template +static void involution2d_backward_grad_weight_frame( + const at::Tensor& out_diff, + const at::Tensor& in_data, + at::Tensor& weight_diff, + const at::IntArrayRef& kernel_size, + const at::IntArrayRef& padding, + const at::IntArrayRef& stride, + const at::IntArrayRef& dilation +) { + auto num_elements = weight_diff.numel(); + const auto groups = weight_diff.size(1); + const auto batch_size = in_data.size(0); + const auto channels = in_data.size(1); + const auto in_height = in_data.size(2); + const auto in_width = in_data.size(3); + const auto out_height = out_diff.size(2); + const auto out_width = out_diff.size(3); + const auto channels_per_group = channels / groups; + + auto out_diff_a = out_diff.accessor(); + auto in_data_a = in_data.accessor(); + auto* weight_diff_p = weight_diff.data_ptr(); + + #pragma omp parallel for + for (int64_t idx = 0l; idx < num_elements; idx++) { + const int64_t w = idx % out_width; + const int64_t h = (idx / out_width) % out_height; + int64_t divisor = out_width * out_height; + const int64_t kw = (idx / divisor) % kernel_size[1]; + divisor *= kernel_size[1]; + const int64_t kh = (idx / divisor) % kernel_size[0]; + + const int64_t h_in = h * stride[0] + kh * dilation[0] - padding[0]; + const int64_t w_in = w * stride[1] + kw * dilation[1] - padding[1]; + + if ((0l <= h_in) && (h_in < in_height) && (0l <= w_in) && (w_in < in_width)) { + divisor *= kernel_size[0]; + const int64_t g = (idx / divisor) % groups; + divisor *= groups; + const int64_t n = (idx / divisor) % batch_size; + + scalar_t value = 0; + + for (int64_t c = g * channels_per_group; c < (g + 1) * channels_per_group; c++) { + value += out_diff_a[n][c][h][w] * in_data_a[n][c][h_in][w_in]; + } + weight_diff_p[idx] = value; + } + else { + weight_diff_p[idx] = 0; + } + } +} + +at::Tensor involution2d_backward_grad_weight( + const at::Tensor& grad, + const at::Tensor& input, + const std::vector& weight_shape, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +) { + AT_ASSERTM(grad.device().is_cpu(), "\"grad\" must be a CPU tensor."); + AT_ASSERTM(input.device().is_cpu(), "\"input\" must be a CPU tensor."); + + at::TensorArg grad_t{grad, "grad", 1}, input_t{input, "input", 2}; + + at::CheckedFrom c = __func__; + at::checkAllSameType(c, {grad_t, input_t}); + + const auto batch_size = input.size(0); + + at::Tensor grad_weight = at::zeros({batch_size, groups, kernel_size[0], kernel_size[1], weight_shape[2], weight_shape[3]}, grad.options()); + + if (grad_weight.numel() == 0) { + return grad_weight.view(weight_shape); + } + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kHalf, + at::kBFloat16, + grad.scalar_type(), + "involution2d_backward_grad_weight_kernel", [&] { + involution2d_backward_grad_weight_frame( + grad, + input, + grad_weight, + kernel_size, + padding, + stride, + dilation + ); + } + ); + return grad_weight.view(weight_shape); +} + +std::vector involution2d_backward( + const at::Tensor& grad, + const at::Tensor& weight, + const at::Tensor& input, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +) { + auto input_sizes 
= input.sizes(); + std::vector input_size; + std::copy(input_sizes.begin(), input_sizes.end(), std::back_inserter(input_size)); + + auto grad_input = involution2d_backward_grad_input( + grad, + weight, + input_size, + kernel_size, + stride, + padding, + dilation, + groups + ); + + auto weight_sizes = weight.sizes(); + std::vector weight_size; + std::copy(weight_sizes.begin(), weight_sizes.end(), std::back_inserter(weight_size)); + + auto grad_weight = involution2d_backward_grad_weight( + grad, + input, + weight_size, + kernel_size, + stride, + padding, + dilation, + groups + ); + + // std::vector output{grad_input, grad_weight}; + + // return output; + return {grad_input, grad_weight}; +} + +} // namespace cpu +} // namespace involution diff --git a/src/involution2d_cuda.cu b/src/involution2d_cuda.cu new file mode 100644 index 0000000..925e6a5 --- /dev/null +++ b/src/involution2d_cuda.cu @@ -0,0 +1,407 @@ +#include + +namespace involution { +namespace cuda { + +static u_int32_t ceildiv(u_int32_t num_elements, u_int32_t threads) { + return (num_elements + threads - 1) / threads; +} + +template +C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS_PER_BLOCK) +__global__ static void involution2d_forward_kernel( + const at::GenericPackedTensorAccessor in_data, + const at::GenericPackedTensorAccessor weight_data, + scalar_t* const out_data, + const int64_t num_elements, + const int64_t channels, + const int64_t groups, + const int64_t in_height, const int64_t in_width, + const int64_t out_height, const int64_t out_width, + const int64_t kernel_height, const int64_t kernel_width, + const int64_t pad_h, const int64_t pad_w, + const int64_t stride_h, const int64_t stride_w, + const int64_t dilation_h, const int64_t dilation_w +) { + CUDA_KERNEL_LOOP(idx, num_elements) { + const int64_t w = idx % out_width; + const int64_t h = (idx / out_width) % out_height; + int64_t divisor = out_width * out_height; + const int64_t c = (idx / divisor) % channels; + divisor *= channels; + const int64_t n = idx / divisor; + const int64_t g = c / (channels / groups); + + scalar_t value = 0; + + for (int64_t kh = 0l; kh < kernel_height; kh++) { + const int64_t h_in = h * stride_h + kh * dilation_h - pad_h; + + if ((0l <= h_in) && (h_in < in_height)) { + for (int64_t kw = 0l; kw < kernel_width; kw++) { + const int64_t w_in = w * stride_w + kw * dilation_w - pad_w; + + if ((0l <= w_in) && (w_in < in_width)) { + value += weight_data[n][g][kh][kw][h][w] * in_data[n][c][h_in][w_in]; + } + } + } + } + + out_data[idx] = value; + } +} + +at::Tensor involution2d_forward( + const at::Tensor& input, + const at::Tensor& weight, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +) { + AT_ASSERTM(input.device().is_cuda(), "\"input\" must be a CUDA tensor."); + AT_ASSERTM(weight.device().is_cuda(), "\"weight\" must be a CUDA tensor."); + + at::TensorArg input_t{input, "input", 1}, weight_t{weight, "weight", 2}; + + at::CheckedFrom c = __func__; + at::checkAllSameGPU(c, {input_t, weight_t}); + at::checkAllSameType(c, {input_t, weight_t}); + + at::cuda::CUDAGuard device_guard(input.device()); + + const auto batch_size = input.size(0); + const auto channels = input.size(1); + const auto in_height = input.size(2); + const auto in_width = input.size(3); + + const auto weight_height = weight.size(2); + const auto weight_width = weight.size(3); + + const at::Tensor weight_ = weight.view({batch_size, groups, kernel_size[0], kernel_size[1], weight_height, 
weight_width}); + + const auto out_height = (in_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1; + const auto out_width = (in_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1; + + at::Tensor output = at::zeros({batch_size, channels, out_height, out_width}, input.options()); + const auto num_elements = output.numel(); + + if (num_elements == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return output; + } + + const auto threads = std::min(static_cast(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock), CUDA_MAX_THREADS_PER_BLOCK); + const dim3 num_blocks(ceildiv(num_elements, threads), 1u, 1u); + const dim3 threads_per_block(threads, 1u, 1u); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kHalf, + at::kBFloat16, + input.scalar_type(), + "involution2d_forward_kernel", [&] { + involution2d_forward_kernel<<>>( + input.generic_packed_accessor(), + weight_.generic_packed_accessor(), + output.data_ptr(), + num_elements, + channels, + groups, + in_height, in_width, + out_height, out_width, + kernel_size[0], kernel_size[1], + padding[0], padding[1], + stride[0], stride[1], + dilation[0], dilation[1] + ); + } + ); + AT_CUDA_CHECK(cudaGetLastError()); + return output; +} + +template +C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS_PER_BLOCK) +__global__ static void involution2d_backward_grad_input_kernel( + const at::GenericPackedTensorAccessor out_diff, + const at::GenericPackedTensorAccessor weight_data, + scalar_t* const in_diff, + const int64_t num_elements, + const int64_t channels, + const int64_t groups, + const int64_t in_height, const int64_t in_width, + const int64_t out_height, const int64_t out_width, + const int64_t kernel_height, const int64_t kernel_width, + const int64_t pad_h, const int64_t pad_w, + const int64_t stride_h, const int64_t stride_w, + const int64_t dilation_h, const int64_t dilation_w +) { + CUDA_KERNEL_LOOP(idx, num_elements) { + const int64_t w = idx % in_width; + const int64_t h = (idx / in_width) % in_height; + int64_t divisor = in_width * in_height; + const int64_t c = (idx / divisor) % channels; + divisor *= channels; + const int64_t n = idx / divisor; + const int64_t g = c / (channels / groups); + + scalar_t value = 0; + + for (int64_t kh = 0l; kh < kernel_height; kh++) { + const int64_t h_out_s = h + pad_h - kh * dilation_h; + + for (int64_t kw = 0l; kw < kernel_width; kw++) { + const int64_t w_out_s = w + pad_w - kw * dilation_w; + + if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0)) { + const int64_t h_out = h_out_s / stride_h; + const int64_t w_out = h_out_s / stride_w; + + if ((0l <= h_out) && (h_out < out_height) && (0l <= w_out) && (w_out < out_width)) { + value += weight_data[n][g][kh][kw][h_out][w_out] * out_diff[n][c][h_out][w_out]; + } + } + } + } + in_diff[idx] = value; + } +} + +at::Tensor involution2d_backward_grad_input( + const at::Tensor& grad, + const at::Tensor& weight, + const std::vector& input_shape, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +) { + AT_ASSERTM(grad.device().is_cuda(), "\"grad\" must be a CUDA tensor."); + AT_ASSERTM(weight.device().is_cuda(), "\"weight\" must be a CUDA tensor."); + + at::TensorArg grad_t{grad, "grad", 1}, weight_t{weight, "weight", 2}; + + at::CheckedFrom c = __func__; + at::checkAllSameGPU(c, {grad_t, weight_t}); + at::checkAllSameType(c, {grad_t, weight_t}); + + 
at::cuda::CUDAGuard device_guard(grad.device()); + + const auto batch_size = input_shape[0]; + const auto channels = input_shape[1]; + const auto in_height = input_shape[2]; + const auto in_width = input_shape[3]; + + const auto weight_height = weight.size(2); + const auto weight_width = weight.size(3); + + const at::Tensor weight_ = weight.view({batch_size, groups, kernel_size[0], kernel_size[1], weight_height, weight_width}); + + const auto out_height = grad.size(2); + const auto out_width = grad.size(3); + + at::Tensor grad_input = at::zeros(input_shape, grad.options()); + const auto num_elements = grad_input.numel(); + + if (num_elements == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return grad_input; + } + + const auto threads = std::min(static_cast(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock), CUDA_MAX_THREADS_PER_BLOCK); + const dim3 num_blocks(ceildiv(num_elements, threads), 1u, 1u); + const dim3 threads_per_block(threads, 1u, 1u); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kHalf, + at::kBFloat16, + grad.scalar_type(), + "involution2d_backward_grad_input_kernel", [&] { + involution2d_backward_grad_input_kernel<<>>( + grad.generic_packed_accessor(), + weight_.generic_packed_accessor(), + grad_input.data_ptr(), + num_elements, + channels, + groups, + in_height, in_width, + out_height, out_width, + kernel_size[0], kernel_size[1], + padding[0], padding[1], + stride[0], stride[1], + dilation[0], dilation[1] + ); + } + ); + AT_CUDA_CHECK(cudaGetLastError()); + return grad_input; +} + +template +C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS_PER_BLOCK) +__global__ static void involution2d_backward_grad_weight_kernel( + const at::GenericPackedTensorAccessor out_diff, + const at::GenericPackedTensorAccessor in_data, + scalar_t* const weight_diff, + const int64_t num_elements, + const int64_t batch_size, + const int64_t channels_per_group, + const int64_t groups, + const int64_t in_height, const int64_t in_width, + const int64_t out_height, const int64_t out_width, + const int64_t kernel_height, const int64_t kernel_width, + const int64_t pad_h, const int64_t pad_w, + const int64_t stride_h, const int64_t stride_w, + const int64_t dilation_h, const int64_t dilation_w +) { + CUDA_KERNEL_LOOP(idx, num_elements) { + const int64_t w = idx % out_width; + const int64_t h = (idx / out_width) % out_height; + int64_t divisor = out_width * out_height; + const int64_t kw = (idx / divisor) % kernel_width; + divisor *= kernel_width; + const int64_t kh = (idx / divisor) % kernel_height; + + const int64_t h_in = -pad_h + h * stride_h + kh * dilation_h; + const int64_t w_in = -pad_w + w * stride_w + kw * dilation_w; + + if ((0l <= h_in) && (h_in < in_height) && (0l <= w_in) && (w_in < in_width)) { + divisor *= kernel_height; + const int64_t g = (idx / divisor) % groups; + divisor *= groups; + const int64_t n = (idx / divisor) % batch_size; + + scalar_t value = 0; + + for (int64_t c = g * channels_per_group; c < (g + 1) * channels_per_group; c++) { + value += out_diff[n][c][h][w] * in_data[n][c][h_in][w_in]; + } + weight_diff[idx] = value; + } + else { + weight_diff[idx] = 0; + } + } +} + +at::Tensor involution2d_backward_grad_weight( + const at::Tensor& grad, + const at::Tensor& input, + const std::vector& weight_shape, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +) { + AT_ASSERTM(grad.device().is_cuda(), "\"grad\" must be a CUDA tensor."); 
+ AT_ASSERTM(input.device().is_cuda(), "\"input\" must be a CUDA tensor."); + + at::TensorArg grad_t{grad, "grad", 1}, input_t{input, "input", 2}; + + at::CheckedFrom c = __func__; + at::checkAllSameGPU(c, {grad_t, input_t}); + at::checkAllSameType(c, {grad_t, input_t}); + + at::cuda::CUDAGuard device_guard(grad.device()); + + const auto batch_size = input.size(0); + const auto channels = input.size(1); + const auto in_height = input.size(2); + const auto in_width = input.size(3); + + const auto out_height = grad.size(2); + const auto out_width = grad.size(3); + + at::Tensor grad_weight = at::zeros({batch_size, groups, kernel_size[0], kernel_size[1], weight_shape[2], weight_shape[3]}, grad.options()); + const auto num_elements = grad_weight.numel(); + + if (num_elements == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return grad_weight.view(weight_shape); + } + + const auto threads = std::min(static_cast(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock), CUDA_MAX_THREADS_PER_BLOCK); + const dim3 num_blocks(ceildiv(num_elements, threads), 1u, 1u); + const dim3 threads_per_block(threads, 1u, 1u); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kHalf, + at::kBFloat16, + grad.scalar_type(), + "involution2d_backward_grad_weight_kernel", [&] { + involution2d_backward_grad_weight_kernel<<>>( + grad.generic_packed_accessor(), + input.generic_packed_accessor(), + grad_weight.data_ptr(), + num_elements, + batch_size, + channels / groups, + groups, + in_height, in_width, + out_height, out_width, + kernel_size[0], kernel_size[1], + padding[0], padding[1], + stride[0], stride[1], + dilation[0], dilation[1] + ); + } + ); + AT_CUDA_CHECK(cudaGetLastError()); + return grad_weight.view(weight_shape); +} + +std::vector involution2d_backward( + const at::Tensor& grad, + const at::Tensor& weight, + const at::Tensor& input, + const std::vector& kernel_size, + const std::vector& stride, + const std::vector& padding, + const std::vector& dilation, + const int64_t groups +) { + auto input_sizes = input.sizes(); + std::vector input_size; + std::copy(input_sizes.begin(), input_sizes.end(), std::back_inserter(input_size)); + + auto grad_input = involution2d_backward_grad_input( + grad, + weight, + input_size, + kernel_size, + stride, + padding, + dilation, + groups + ); + + auto weight_sizes = weight.sizes(); + std::vector weight_size; + std::copy(weight_sizes.begin(), weight_sizes.end(), std::back_inserter(weight_size)); + + auto grad_weight = involution2d_backward_grad_weight( + grad, + input, + weight_size, + kernel_size, + stride, + padding, + dilation, + groups + ); + + return {grad_input, grad_weight}; +} + +} // namespace cuda +} // namespace involution diff --git a/src/pytorch_wrapper.cpp b/src/pytorch_wrapper.cpp new file mode 100644 index 0000000..af492c6 --- /dev/null +++ b/src/pytorch_wrapper.cpp @@ -0,0 +1,39 @@ +#include +#include "involution2d_wrapper.h" + +TORCH_LIBRARY(involution, m) { + m.def("involution2d(Tensor input, Tensor weight, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor"); + m.def("_involution2d_backward_grad_input(Tensor grad, Tensor weight, int[] input_shape, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor"); + m.def("_involution2d_backward_grad_weight(Tensor grad, Tensor input, int[] weight_shape, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor"); + m.def("_involution2d_backward(Tensor grad, Tensor weight, 
Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(involution, CPU, m) { + m.impl("involution2d", involution::cpu::involution2d_forward); + m.impl("_involution2d_backward_grad_input", involution::cpu::involution2d_backward_grad_input); + m.impl("_involution2d_backward_grad_weight", involution::cpu::involution2d_backward_grad_weight); + m.impl("_involution2d_backward", involution::cpu::involution2d_backward); +} + +#ifdef USE_CUDA +TORCH_LIBRARY_IMPL(involution, CUDA, m) { + m.impl("involution2d", involution::cuda::involution2d_forward); + m.impl("_involution2d_backward_grad_input", involution::cuda::involution2d_backward_grad_input); + m.impl("_involution2d_backward_grad_weight", involution::cuda::involution2d_backward_grad_weight); + m.impl("_involution2d_backward", involution::cuda::involution2d_backward); +} +#endif + +TORCH_LIBRARY_IMPL(involution, AutogradCPU, m) { + m.impl("involution2d", involution::cpu::involution2d_autograd); +} + +#ifdef USE_CUDA +TORCH_LIBRARY_IMPL(involution, AutogradCUDA, m) { + m.impl("involution2d", involution::cuda::involution2d_autograd); +} + +TORCH_LIBRARY_IMPL(involution, Autocast, m) { + m.impl("involution2d", involution::cuda::involution2d_autocast); +} +#endif
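For reference, a short sketch of how the extension added in this diff is intended to be used once built (e.g. via `pip install .`, which produces the `involution.*.so` loaded by `involution/__init__.py`). The shapes are illustrative and a CUDA device is assumed; on a CPU-only build the same code works without the `.cuda()`/`device='cuda'` parts.

```python
import torch
from involution import Involution2d  # loads the compiled involution.*.so and the Involution2d module

inv = Involution2d(in_channels=64, out_channels=64, kernel_size=7, padding=3, groups=4).cuda()
x = torch.randn(2, 64, 56, 56, device='cuda', requires_grad=True)

y = inv(x)          # (2, 64, 56, 56): spatial size preserved with kernel_size=7, padding=3, stride=1
y.sum().backward()  # runs the custom backward registered via TORCH_LIBRARY_IMPL(involution, AutogradCUDA, ...)
```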