Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
TODO.md
Optimization.md
**/__pycache__/
33 changes: 33 additions & 0 deletions Converse2D/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
### Kernel Registry

---

**Kernel Tree**

```
├── v1  Direct translation of the Python implementation to C++
└── v2  Adds an FB/F2B cache and replaces `repeat` with broadcasting
```

**Tested Device**

- NVIDIA RTX 2080ti
- NVIDIA RTX 4090
- NVIDIA RTX 5060ti 16g

**Installation**

```bash
cd ./Converse2D
pip install . --no-build-isolation
```

**Usage**

```python
import torch
import torch_converse2d

out = torch.ops.converse2d.forward(x, x0, weight, bias, scale, eps)
print(torch.ops.converse2d)
```
48 changes: 48 additions & 0 deletions Converse2D/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Build script for the Converse2D PyTorch extension.
#
# Builds `converse2d_ext` from torch_converse2d/converse2d.cpp and, when it is
# present, torch_converse2d/converse2d.cu (which switches the build to a
# CUDAExtension).
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension
import os, pathlib

PKG_DIR = pathlib.Path(__file__).resolve().parent / "torch_converse2d"


CPP = str(PKG_DIR / "converse2d.cpp")
CU = str(PKG_DIR / "converse2d.cu")
has_cu = os.path.exists(CU)

extra_cflags = ["-O3"]
extra_cuda = ["-O3"]

# Architectures used when the local GPU cannot be queried.
FALLBACK_CUDA_ARCH_LIST = "8.0;8.6;8.9+PTX"

# Respect a user-provided TORCH_CUDA_ARCH_LIST; otherwise target the first
# visible GPU, and fall back to a fixed list when none can be detected.
if has_cu and "TORCH_CUDA_ARCH_LIST" not in os.environ:
    try:
        import torch
        if torch.cuda.is_available():
            major, minor = torch.cuda.get_device_capability(0)
            os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}+PTX"
    except Exception:
        # Detection is best-effort; the fallback below covers this case.
        pass
    # Covers both "detection raised" and "CUDA is not available on this
    # host" (e.g. a build machine without a GPU), so the CUDA build always
    # has an explicit architecture list.
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", FALLBACK_CUDA_ARCH_LIST)

sources = [CPP] + ([CU] if has_cu else [])

if has_cu:
    ext = CUDAExtension(
        name="converse2d_ext",
        sources=sources,
        extra_compile_args={"cxx": extra_cflags, "nvcc": extra_cuda},
    )
else:
    ext = CppExtension(
        name="converse2d_ext",
        sources=sources,
        extra_compile_args={"cxx": extra_cflags},
    )

print(f"[setup.py] building sources={sources}")
print(f"[setup.py] TORCH_CUDA_ARCH_LIST={os.environ.get('TORCH_CUDA_ARCH_LIST','<unset>')}")

setup(
    name="torch_converse2d",
    version="0.1",
    description="Converse2D CUDA extension for PyTorch",
    packages=["torch_converse2d"],
    ext_modules=[ext],
    cmdclass={"build_ext": BuildExtension},
    zip_safe=False,
)
7 changes: 7 additions & 0 deletions Converse2D/torch_converse2d/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os

# Importing the compiled extension registers the `converse2d` operators.
# The import is best-effort: a failure is reported but does not prevent the
# package itself from being imported.
try:
    # Load torch first; the extension links against torch's shared libraries,
    # which are typically only resolvable once torch has been imported.
    import torch  # noqa: F401
    import converse2d_ext
except Exception as e:
    # Keep the name defined so `__all__` stays valid and callers can test
    # `converse2d_ext is None` instead of hitting an AttributeError.
    converse2d_ext = None
    print("[torch_converse2d] extension import failed:", e)

__all__ = ["converse2d_ext"]
Binary file not shown.
229 changes: 229 additions & 0 deletions Converse2D/torch_converse2d/converse2d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@

#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/roll.h>
#include <ATen/ops/fft_fftn.h>
#include <ATen/ops/fft_ifftn.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/real.h>
#include <ATen/ops/sigmoid.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/conj_physical.h>

#include <unordered_map>
#include <tuple>
#include <list>
#include <mutex>

using at::Tensor;

struct FBKey
{
int64_t device_id;
at::ScalarType dtype;
int64_t channels;
int64_t H, W;
void *ptr;

bool operator==(const FBKey &other) const
{
return device_id == other.device_id && dtype == other.dtype &&
channels == other.channels && H == other.H && W == other.W &&
ptr == other.ptr;
}
};

namespace std
{
template <>
struct hash<FBKey>
{
size_t operator()(const FBKey &k) const
{
return ((hash<int64_t>()(k.device_id) ^ hash<int64_t>()(k.channels)) << 1) ^
((hash<int64_t>()(k.H) ^ hash<int64_t>()(k.W)) << 1) ^
((hash<void *>()(k.ptr)) ^ hash<int>()(static_cast<int>(k.dtype)));
}
};
}

constexpr size_t FB_CACHE_MAX_SIZE = 64;

static std::unordered_map<FBKey, std::tuple<at::Tensor, at::Tensor, at::Tensor>> fb_cache;
static std::list<FBKey> fb_cache_lru;
static std::mutex fb_cache_mutex;

static inline std::tuple<Tensor, Tensor, at::Tensor> p2o_cached(const Tensor &psf, int64_t H, int64_t W)
{
const bool training_with_grad = at::GradMode::is_enabled() && psf.requires_grad();
auto C = psf.size(1);
FBKey key{
psf.device().index(),
psf.scalar_type(),
C, H, W,
psf.data_ptr()};

if (!training_with_grad)
{
std::lock_guard<std::mutex> lock(fb_cache_mutex);
auto it = fb_cache.find(key);
if (it != fb_cache.end())
{
fb_cache_lru.remove(key);
fb_cache_lru.push_front(key);
return it->second;
}
}

Tensor otf = at::zeros({1, C, H, W}, psf.options());
int64_t kh = psf.size(2), kw = psf.size(3);
otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf);
otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1});
Tensor FB = at::fft_fftn(otf, c10::nullopt, {-2, -1}, c10::nullopt);
Tensor FBC = at::conj_physical(FB);
Tensor F2B = at::abs(FB).pow(2);

if (!training_with_grad)
{
std::lock_guard<std::mutex> lock(fb_cache_mutex);
fb_cache[key] = std::make_tuple(FB, FBC, F2B);
fb_cache_lru.push_front(key);

if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE)
{
fb_cache.erase(fb_cache_lru.back());
fb_cache_lru.pop_back();
}
}

return std::make_tuple(FB, FBC, F2B);
}

/// Upsamples the last two dimensions of `x` by a factor `s` by zero insertion.
///
/// The result has its last two dimensions multiplied by `s`; the values of `x`
/// are written at every `s`-th position along both axes (starting at 0) and
/// every other element is zero. For `s == 1` the input is returned unchanged.
///
/// @param x Input tensor; the strided write indexes it with two leading
///          full slices followed by the two spatial axes.
/// @param s Upsampling factor, must be >= 1.
static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s)
{
    using at::indexing::Slice;

    TORCH_CHECK(s >= 1, "scale must be >= 1");
    if (s == 1)
    {
        return x;
    }

    auto up_shape = x.sizes().vec();
    const auto rank = up_shape.size();
    up_shape[rank - 2] *= s;
    up_shape[rank - 1] *= s;

    Tensor upsampled = at::zeros(up_shape, x.options());
    const int64_t up_rows = upsampled.size(-2);
    const int64_t up_cols = upsampled.size(-1);
    // Scatter x onto the stride-s lattice; everything else stays zero.
    upsampled.index_put_(
        {Slice(), Slice(), Slice(0, up_rows, s), Slice(0, up_cols, s)},
        x);
    return upsampled;
}

/// Splits the last two dimensions of `a` into an s-by-s grid of equally sized
/// tiles and returns the element-wise mean over those s*s tiles.
///
/// For an input of shape (..., R, Q) the output has shape (..., R/s, Q/s).
///
/// @param a Tensor with at least two dimensions; `view` is applied directly,
///          so it must be viewable with the split shape.
/// @param s Number of splits along each of the last two dimensions; both
///          sizes must be divisible by it.
static inline Tensor splits_mean_then_mean(const Tensor &a, int64_t s)
{
    TORCH_CHECK(a.dim() >= 2, "tensor must have spatial dims");
    TORCH_CHECK(a.size(-2) % s == 0 && a.size(-1) % s == 0, "spatial not divisible by scale");

    const auto dims = a.sizes();
    const int64_t lead = a.dim() - 2; // number of leading (non-spatial) dims
    const int64_t tile_rows = dims[lead] / s;
    const int64_t tile_cols = dims[lead + 1] / s;

    // (..., R, Q) -> (..., s, R/s, s, Q/s)
    std::vector<int64_t> split_shape(dims.begin(), dims.begin() + lead);
    split_shape.insert(split_shape.end(), {s, tile_rows, s, tile_cols});
    Tensor split_view = a.view(split_shape);

    // Move the two split axes to the end: (..., R/s, Q/s, s, s)
    std::vector<int64_t> order;
    order.reserve(split_shape.size());
    for (int64_t d = 0; d < lead; ++d)
    {
        order.push_back(d);
    }
    order.insert(order.end(), {lead + 1, lead + 3, lead, lead + 2});
    Tensor tiles_last = split_view.permute(order).contiguous();

    // Flatten the two split axes into one of length s*s and average over it.
    std::vector<int64_t> flat_shape;
    flat_shape.reserve(lead + 3);
    for (int64_t d = 0; d < lead; ++d)
    {
        flat_shape.push_back(tiles_last.size(d));
    }
    flat_shape.insert(flat_shape.end(), {tile_rows, tile_cols, s * s});

    return tiles_last.view(flat_shape).mean(-1, /*keepdim=*/false);
}

/// Forward pass of the Converse2D operator, evaluated in the Fourier domain.
///
/// Steps performed:
///   1. lambda = sigmoid(bias - 9) + eps
///   2. STy    = zero-insertion upsampling of x by `scale`
///   3. FR     = conj(FB) * FFT(STy) + FFT(lambda * x0)
///   4. the s*s tile means of FB * FR and |FB|^2 are combined as
///      mean(FB * FR) / (mean(|FB|^2) + lambda)
///   5. FX     = (FR - conj(FB) * tiled(step 4)) / lambda
///   6. output = real(IFFT(FX))
/// where FB is the kernel spectrum returned by p2o_cached.
///
/// @param x      Input of shape (B, C, H, W).
/// @param x0     Initial estimate with spatial size (H*scale, W*scale).
/// @param weight Kernel of shape (1, C, kh, kw); must fit the upsampled grid.
/// @param bias   Tensor of shape (1, C, 1, 1) from which lambda is derived.
/// @param scale  Integer upsampling factor, >= 1.
/// @param eps    Constant added to lambda.
/// @return       Real tensor with spatial size (H*scale, W*scale).
/// @throws c10::Error (via TORCH_CHECK) if any argument check fails.
Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps)
{
    TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)");
    TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)");
    TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)");
    TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)");
    TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device");
    TORCH_CHECK(scale >= 1, "scale must be >= 1");

    const int64_t Hs = x.size(2) * scale;
    const int64_t Ws = x.size(3) * scale;

    // A spatially mismatched x0 can still broadcast after the FFT and would
    // silently produce a wrong result, so require the upsampled size here.
    TORCH_CHECK(x0.size(2) == Hs && x0.size(3) == Ws,
                "x0 spatial size must be (H*scale, W*scale); got (", x0.size(2), ",", x0.size(3),
                "), expected (", Hs, ",", Ws, ")");
    // The kernel is zero-padded to (Hs, Ws); it cannot be larger than that.
    TORCH_CHECK(weight.size(2) <= Hs && weight.size(3) <= Ws,
                "weight kernel (", weight.size(2), ",", weight.size(3),
                ") must not exceed the upsampled size (", Hs, ",", Ws, ")");

    x = x.contiguous();
    x0 = x0.contiguous();
    weight = weight.contiguous();
    bias = bias.contiguous();

    Tensor lambda_ = at::sigmoid(bias - 9.0) + eps;
    Tensor STy = sfold_upsample_zero_insertion(x, scale);

    auto [FB, FBC, F2B] = p2o_cached(weight, Hs, Ws);

    Tensor F_STy = at::fft_fftn(STy, c10::nullopt, {-2, -1}, c10::nullopt);
    Tensor FBFy = FBC * F_STy;
    Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt);

    Tensor x1 = FB * FR;
    Tensor FBR = splits_mean_then_mean(x1, scale);
    Tensor invW = splits_mean_then_mean(F2B, scale);

    Tensor invW_plus = invW + lambda_;
    Tensor invWBR = FBR / invW_plus;

    // Tile the low-resolution term back to the (Hs, Ws) frequency grid.
    Tensor invWBR_rep = invWBR.repeat({1, 1, scale, scale});
    Tensor FCBinvWBR = FBC * invWBR_rep;

    Tensor FX = (FR - FCBinvWBR) / lambda_;
    Tensor out_c = at::fft_ifftn(FX, c10::nullopt, {-2, -1}, c10::nullopt);
    Tensor out = at::real(out_c);
    return out;
}

/// Empties the kernel-spectrum cache and its LRU bookkeeping list.
/// Both containers are cleared while holding `fb_cache_mutex`.
void clear_fb_cache()
{
    const std::lock_guard<std::mutex> guard(fb_cache_mutex);
    fb_cache_lru.clear();
    fb_cache.clear();
}

// Operator schemas for the `converse2d` namespace:
//   - forward:     the Converse2D forward pass; `eps` defaults to 1e-5.
//   - clear_cache: takes no arguments and returns nothing.
TORCH_LIBRARY(converse2d, m)
{
m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor");
m.def("clear_cache() -> ()");
}
// Binds the schemas above to their C++ implementations under the
// CompositeImplicitAutograd dispatch key.
TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m)
{
m.impl("forward", TORCH_FN(converse2d_forward));
m.impl("clear_cache", TORCH_FN(clear_fb_cache));
}

// Intentionally empty Python module: it exports no functions of its own.
// Presumably it exists so the shared library can be imported from Python,
// which runs the operator registrations above.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ ___________
* [Visual results of Converse-USRNet](#visual-results-of-converse-usrnet)



Motivation
----------
Convolution and transposed convolution (often referred to as deconvolution) are fundamental operations in deep neural networks. Convolution is commonly used for feature extraction and spatial downsampling. In contrast, transposed convolution is used to upsample spatial dimensions. Due to this functional relationship, transposed convolution is sometimes described in the literature as a reverse convolution operator. However, it is not the mathematical inverse of convolution. Instead, it performs upsampling by inserting zeros between input elements, followed by a standard convolution. While this interpretation is widely accepted, implementing a reverse convolution operator has received little attention. Notably, popular deep learning frameworks such as PyTorch do not provide native support for such an operator.
Expand All @@ -43,7 +44,6 @@ $$
\mathbf{X}^\ast = \arg\min_{\mathbf{X}} \left\| \mathbf{Y} - \left( \mathbf{X} \otimes \mathbf{K} \right) \downarrow_{s} \right\|_F^2 + \lambda \left\| \mathbf{X} - \mathbf{X}_0 \right\|_F^2,
$$


$$
\mathbf{X}^\ast = \arg\min_{\mathbf{X}} \left\| \mathbf{Y} - \left( \mathbf{X} \otimes \mathbf{K} \right) \downarrow_{s} \right\|_F^2
$$
Expand Down Expand Up @@ -136,4 +136,3 @@ Citation
}
```


Binary file added figs/pytorch_hotmap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading