Fast and generic implementation using OpenMP and CUDA #45

Open · wants to merge 14 commits into main
Changes from 3 commits
27 changes: 0 additions & 27 deletions include/autocast.h

This file was deleted.

2 changes: 1 addition & 1 deletion include/involution2d_cuda.cuh
@@ -7,7 +7,7 @@
namespace involution {
namespace cuda {

-#define CUDA_MAX_THREADS 512u
+#define CUDA_MAX_THREADS 1024u

Author
@d-li14
In your CuPy implementation, the maximum number of CUDA threads per block is set to 1024. When I experimented, my CuPy reimplementation did not work with 1024, so I set it to 512.

My CUDA implementation does work with 1024; I had set it to 512 during those experiments and simply forgot to change it back.

Owner
@shikishima-TasakiLab
Thanks, but I have tried changing the maximum number of CUDA threads, and it seems the result is still similar.
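For reference, whether a given GPU actually supports 1024-thread blocks can be checked at runtime with the CUDA runtime API. The snippet below is a minimal standalone sketch, not part of this PR; launch failures at 1024 threads are usually caused by per-thread register or shared-memory pressure rather than by the hardware limit itself.

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0;
    int max_threads_per_block = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&max_threads_per_block, cudaDevAttrMaxThreadsPerBlock, device);
    // Every current NVIDIA architecture reports 1024 here; a kernel that fails to launch
    // with 1024 threads has usually exceeded its register or shared-memory budget instead.
    std::printf("max threads per block: %d\n", max_threads_per_block);
    return 0;
}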

#define CUDA_KERNEL_LOOP(i, n) \
for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
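As a quick illustration of how the two macros work together, here is a hypothetical kernel and launcher (not taken from this PR; it assumes the definitions above are in scope): CUDA_KERNEL_LOOP is a grid-stride loop, so each thread handles elements i, i + blockDim.x * gridDim.x, and so on, while CUDA_MAX_THREADS fixes the block size when the launch grid is computed.

// Hypothetical example, assuming involution2d_cuda.cuh (with the macros above) is included.
__global__ void scale_kernel(float* data, const int64_t n, const float alpha) {
    CUDA_KERNEL_LOOP(i, n) {   // grid-stride loop over all n elements
        data[i] *= alpha;
    }
}

void scale(float* data, const int64_t n, const float alpha) {
    // One block per CUDA_MAX_THREADS elements, rounded up; the grid-stride loop keeps the
    // kernel correct even if fewer blocks than this are actually launched.
    const unsigned int blocks =
        static_cast<unsigned int>((n + CUDA_MAX_THREADS - 1) / CUDA_MAX_THREADS);
    scale_kernel<<<blocks, CUDA_MAX_THREADS>>>(data, n, alpha);
}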
23 changes: 21 additions & 2 deletions include/involution2d_wrapper.h
@@ -1,9 +1,9 @@
#pragma once

#include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/autocast_mode.h>
#include <torch/csrc/autograd/custom_function.h>

#include "autocast.h"
#include "involution2d_cpu.h"

#ifdef USE_CUDA
@@ -34,7 +34,8 @@ at::Tensor involution2d_autocast(
const std::vector<int64_t>& dilation
) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
-return involution2d(autocast::_cast(at::kFloat, input), autocast::_cast(at::kFloat, weight), stride, padding, dilation)
+auto exec_type = at::autocast::promote_type(at::kFloat, input, weight);
+return involution2d(at::autocast::cached_cast(exec_type, input), at::autocast::cached_cast(exec_type, weight), stride, padding, dilation)
.to(input.scalar_type());
}

@@ -208,6 +209,24 @@ at::Tensor involution2d_autograd(
return Involution2dFunctionCUDA::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0];
}

+at::Tensor involution2d_autocast(
+const torch::autograd::Variable& input,
+const torch::autograd::Variable& weight,
+const std::vector<int64_t>& kernel_size,
+const std::vector<int64_t>& stride,
+const std::vector<int64_t>& padding,
+const std::vector<int64_t>& dilation,
+const int64_t groups
+) {
+c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+auto exec_type = at::autocast::promote_type(at::kFloat, input, weight);
+return involution2d_autograd(
+at::autocast::cached_cast(exec_type, input),
+at::autocast::cached_cast(exec_type, weight),
+kernel_size, stride, padding, dilation, groups
+);
+}

Comment on lines +212 to +229

@csvance
Fixed the CUDA implementation so that its inputs are cast to full precision under Autocast.
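For readers unfamiliar with the autocast helpers used in this wrapper: at::autocast::promote_type picks the widest dtype among at::kFloat and the tensor arguments, and at::autocast::cached_cast casts a tensor to that dtype only if it is not already in it. A minimal sketch of the effect, assuming libtorch with a CUDA device (tensor names and shapes are made up, not taken from the PR):

#include <ATen/ATen.h>
#include <ATen/autocast_mode.h>

int main() {
    // A half-precision activation (as autocast typically produces) and a float32 weight.
    at::Tensor input  = at::randn({1, 8, 16, 16}, at::kCUDA).to(at::kHalf);
    at::Tensor weight = at::randn({8, 9, 8, 8}, at::kCUDA);

    // at::kFloat is passed as the floor, so exec_type resolves to kFloat here.
    auto exec_type = at::autocast::promote_type(at::kFloat, input, weight);

    // cached_cast is a no-op for tensors already in exec_type and casts the rest,
    // so the involution kernels only ever see full-precision operands.
    at::Tensor input_fp32  = at::autocast::cached_cast(exec_type, input);
    at::Tensor weight_fp32 = at::autocast::cached_cast(exec_type, weight);
    return 0;
}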

} // namespace cuda

#endif
8 changes: 4 additions & 4 deletions src/pytorch_wrapper.cpp
@@ -24,10 +24,6 @@ TORCH_LIBRARY_IMPL(involution, CUDA, m) {
}
#endif

-// TORCH_LIBRARY_IMPL(involution, Autocast, m) {
-// m.impl("involution2d", involution2d_autocast);
-// }

TORCH_LIBRARY_IMPL(involution, AutogradCPU, m) {
m.impl("involution2d", involution::cpu::involution2d_autograd);
}
@@ -36,4 +32,8 @@ TORCH_LIBRARY_IMPL(involution, AutogradCPU, m) {
TORCH_LIBRARY_IMPL(involution, AutogradCUDA, m) {
m.impl("involution2d", involution::cuda::involution2d_autograd);
}

+TORCH_LIBRARY_IMPL(involution, Autocast, m) {
+m.impl("involution2d", involution::cuda::involution2d_autocast);
+}
#endif
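These TORCH_LIBRARY_IMPL blocks attach per-dispatch-key kernels to an operator schema declared elsewhere in pytorch_wrapper.cpp. As a rough sketch of what such a declaration looks like (the signature below is illustrative and not copied from this diff):

TORCH_LIBRARY(involution, m) {
    // The CPU/CUDA/Autograd*/Autocast registrations above all bind to this one schema;
    // the dispatcher then picks the right kernel for each call.
    m.def(
        "involution2d(Tensor input, Tensor weight, int[] kernel_size, int[] stride, "
        "int[] padding, int[] dilation, int groups) -> Tensor"
    );
}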