From 2760962e31cc6f4c0fcbe6e1eb18e4e17f3e6652 Mon Sep 17 00:00:00 2001 From: BoyceYi <1473416941@qq.com> Date: Fri, 15 Aug 2025 16:33:32 +0800 Subject: [PATCH 01/22] add naive cuda implement --- models/CUDA/converse2d_cuda.cu | 92 ++++++++++++++++++++++ models/CUDA/converse2d_op.cpp | 13 +++ models/benchmark.py | 139 +++++++++++++++++++++++++++++++++ models/util_converse.py | 86 ++++++++++++++------ 4 files changed, 306 insertions(+), 24 deletions(-) create mode 100644 models/CUDA/converse2d_cuda.cu create mode 100644 models/CUDA/converse2d_op.cpp create mode 100644 models/benchmark.py diff --git a/models/CUDA/converse2d_cuda.cu b/models/CUDA/converse2d_cuda.cu new file mode 100644 index 0000000..7a7caae --- /dev/null +++ b/models/CUDA/converse2d_cuda.cu @@ -0,0 +1,92 @@ +#include +#include +#include + + +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA error at %s %d: %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + + +torch::Tensor p2o_cuda(torch::Tensor psf, const std::vector& shape) { + auto otf = torch::zeros(torch::IntArrayRef({psf.size(0), psf.size(1), shape[0], shape[1]})).to(psf.device()); + otf.slice(2, 0, psf.size(2)).slice(3, 0, psf.size(3)).copy_(psf); + + otf = torch::roll(otf, {-psf.size(2) / 2, -psf.size(3) / 2}, {2, 3}); + return torch::fft::fftn(otf, c10::nullopt, c10::IntArrayRef({-2, -1})); +} + +torch::Tensor splits_cuda(torch::Tensor a, int scale) { + auto sizes = a.sizes(); + long W = sizes[2]; + long H = sizes[3]; + long W_s = W / scale; + long H_s = H / scale; + + auto b = a.view({sizes[0], sizes[1], scale, W_s, scale, H_s}); + b = b.permute({0, 1, 3, 5, 2, 4}).contiguous(); + return b.view({sizes[0], sizes[1], W_s, H_s, scale * scale}); +} + +torch::Tensor converse2d_cuda_forward( + torch::Tensor x, + torch::Tensor weight, + torch::Tensor bias, + int scale, + float eps +) { + TORCH_CHECK(x.is_cuda(), "Input tensor must be a CUDA tensor"); + TORCH_CHECK(weight.is_cuda(), "Weight tensor must be a CUDA tensor"); + TORCH_CHECK(bias.is_cuda(), "Bias tensor must be a CUDA tensor"); + + x = x.contiguous(); + weight = weight.contiguous(); + bias = bias.contiguous(); + + const int N = x.size(0); + const int C = x.size(1); + const int H = x.size(2); + const int W = x.size(3); + const int H_up = H * scale; + const int W_up = W * scale; + + auto biaseps = (torch::sigmoid(bias - 9.0f) + eps).contiguous(); + + auto STy = torch::zeros({N, C, H_up, W_up}, x.options()); + STy.slice(2, 0, H_up, scale).slice(3, 0, W_up, scale).copy_(x); + + if (scale != 1) { + x = torch::nn::functional::interpolate(x, + torch::nn::functional::InterpolateFuncOptions().scale_factor(std::vector({(double)scale, (double)scale})).mode(torch::kNearest)); + } + + auto FB = p2o_cuda(weight, {H_up, W_up}).contiguous(); + auto FBC = torch::conj(FB).contiguous(); + auto F2B = torch::pow(torch::abs(FB), 2).contiguous(); + + auto STy_fft = torch::fft::fftn(STy, c10::nullopt, c10::IntArrayRef({-2, -1})).contiguous(); + auto FBFy = (FBC * STy_fft).contiguous(); + + auto x_fft = torch::fft::fftn(biaseps * x, c10::nullopt, c10::IntArrayRef({-2, -1})).contiguous(); + + auto FR = FBFy + x_fft; + auto x1 = FB.mul(FR); + + auto FBR = torch::mean(splits_cuda(x1, scale), -1, false); + auto invW = torch::mean(splits_cuda(F2B.to(torch::kComplexFloat), scale), -1, false); + + auto invWBR = FBR.div(invW + biaseps.to(torch::kComplexFloat)); + auto FCBinvWBR = FBC * invWBR.repeat({1, 1, scale, scale}); + + 
auto FX = (FR - FCBinvWBR) / biaseps.to(torch::kComplexFloat); + + auto out_complex = torch::fft::ifftn(FX, c10::nullopt, c10::IntArrayRef({-2, -1})); + auto out = torch::real(out_complex); + + return out; +} \ No newline at end of file diff --git a/models/CUDA/converse2d_op.cpp b/models/CUDA/converse2d_op.cpp new file mode 100644 index 0000000..3cc0e80 --- /dev/null +++ b/models/CUDA/converse2d_op.cpp @@ -0,0 +1,13 @@ +#include + +torch::Tensor converse2d_cuda_forward( + torch::Tensor x, + torch::Tensor weight, + torch::Tensor bias, + int scale, + float eps +); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &converse2d_cuda_forward, "Converse2D forward (CUDA)"); +} \ No newline at end of file diff --git a/models/benchmark.py b/models/benchmark.py new file mode 100644 index 0000000..2f91d97 --- /dev/null +++ b/models/benchmark.py @@ -0,0 +1,139 @@ +import torch +import time +from util_converse import Converse2D + + +def benchmark(model, input_tensor, model_name, num_runs=100): + """ + Measures the average forward pass time of a model. + """ + print(f"Warming up {model_name} backend...") + # Warm-up runs to stabilize performance measurement + for _ in range(10): + _ = model(input_tensor) + torch.cuda.synchronize() + + print(f"Running benchmark for {model_name} backend ({num_runs} iterations)...") + start_time = time.time() + for _ in range(num_runs): + _ = model(input_tensor) + # Wait for all kernels to complete + torch.cuda.synchronize() + end_time = time.time() + + avg_time = (end_time - start_time) / num_runs + return avg_time + +def run_comparison(): + if not torch.cuda.is_available(): + print("CUDA is not available. Performance comparison cannot be run.") + return + + if Converse2D is None: + print("Converse2D class not loaded. 
Aborting benchmark.") + return + + params = { + 'in_channels': 64, + 'out_channels': 64, # Must be the same as in_channels for Converse2D + 'kernel_size': 3, + 'scale': 2, + 'padding': 2, + 'batch_size': 4, + 'height': 256, + 'width': 256, + 'device': torch.device("cuda") + } + + print("\n--- Benchmark Configuration ---") + for key, value in params.items(): + if key != 'device': + print(f"{key.replace('_', ' ').capitalize()}: {value}") + print(f"Device: {params['device']}") + print("---------------------------------\n") + + + # Create a dummy input tensor on the GPU + input_tensor = torch.randn( + params['batch_size'], + params['in_channels'], + params['height'], + params['width'] + ).to(params['device']) + + try: + # Initialize PyTorch backend model + print("Initializing PyTorch backend model...") + converse_torch = Converse2D( + in_channels=params['in_channels'], + out_channels=params['out_channels'], + kernel_size=params['kernel_size'], + scale=params['scale'], + padding=params['padding'], + backend='torch' + ).to(params['device']) + print("PyTorch backend model initialized.") + + # Initialize CUDA backend model (this will trigger the JIT compilation) + print("\nInitializing CUDA backend model (compilation may take a moment)...") + converse_cuda = Converse2D( + in_channels=params['in_channels'], + out_channels=params['out_channels'], + kernel_size=params['kernel_size'], + scale=params['scale'], + padding=params['padding'], + backend='cuda' + ).to(params['device']) + print("CUDA backend model initialized and compiled successfully.") + + except Exception as e: + print(f"\nAn error occurred during model initialization: {e}") + print("Please ensure that a compatible CUDA toolkit is installed and configured correctly for PyTorch.") + return + + # Run benchmarks + torch_time = benchmark(converse_torch, input_tensor, "PyTorch") + cuda_time = benchmark(converse_cuda, input_tensor, "CUDA") + + # --- Step 4: Report the results --- + print("\n--- Performance Comparison Results ---") + print(f"Input Tensor Shape: ({params['batch_size']}, {params['in_channels']}, {params['height']}, {params['width']})") + print(f"PyTorch Backend Average Time: {torch_time * 1000:.4f} ms") + print(f"CUDA Backend Average Time: {cuda_time * 1000:.4f} ms") + print("--------------------------------------") + + if cuda_time > 0: + speedup = torch_time / cuda_time + print(f"The CUDA implementation is approximately {speedup:.2f}x faster than the PyTorch implementation.") + else: + print("Could not calculate speedup due to zero execution time.") + +if __name__ == "__main__": + run_comparison() + +""" +--- Device Details --- +GPU Architecture: RTX 2080ti +CUDA version: 12.8 +Torch verison: 2.8.0 + +--- Benchmark Configuration --- +In channels: 64 +Out channels: 64 +Kernel size: 3 +Scale: 2 +Padding: 2 +Batch size: 4 +Height: 256 +Width: 256 +Device: cuda +--------------------------------- + +--- Performance Comparison Results --- +Input Tensor Shape: (4, 64, 256, 256) +PyTorch Backend Average Time: 131.7963 ms +CUDA Backend Average Time: 67.5533 ms +-------------------------------------- +The CUDA implementation is approximately 1.95x faster than the PyTorch implementation. 
+ +""" \ No newline at end of file diff --git a/models/util_converse.py b/models/util_converse.py index 8757172..1990fda 100644 --- a/models/util_converse.py +++ b/models/util_converse.py @@ -1,13 +1,41 @@ +import os +os.environ['TORCH_CUDA_ARCH_LIST'] = '7.5' import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.cpp_extension import load from collections import OrderedDict + + """ # -------------------------------------------- # LayerNorm for Vision Normalization # -------------------------------------------- """ + + +def load_converse2D(): + current_file_dir = os.path.dirname(os.path.abspath(__file__)) + converse2D_cuda = load( + name="converse2D_cuda", + sources=[ + os.path.join(current_file_dir, "CUDA/converse2d_op.cpp"), + os.path.join(current_file_dir, "CUDA/converse2d_cuda.cu"), + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "-gencode arch=compute_75,code=sm_75", + ], + ) + return converse2D_cuda + + class LayerNorm(nn.Module): ''' LayerNorm that supports two data formats: channels_last (default) or channels_first. @@ -65,7 +93,7 @@ def sequential(*args): # -------------------------------------------- """ class Converse2D(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, scale=1, padding=2, padding_mode='circular', eps=1e-5): + def __init__(self, in_channels, out_channels, kernel_size, scale=1, padding=2, padding_mode='circular', eps=1e-5, backend:str='torch'): super(Converse2D, self).__init__() """ Converse2D Operator for Image Restoration Tasks. @@ -95,6 +123,8 @@ def __init__(self, in_channels, out_channels, kernel_size, scale=1, padding=2, p self.padding = padding self.padding_mode = padding_mode self.eps = eps + assert backend in ['torch', 'cuda'], "Not Implementd Yet" + self.backend = backend # ensure depthwise @@ -109,27 +139,35 @@ def forward(self, x): if self.padding > 0: x = nn.functional.pad(x, pad=[self.padding, self.padding, self.padding, self.padding], mode=self.padding_mode, value=0) - self.biaseps = torch.sigmoid(self.bias-9.0) + self.eps - _, _, h, w = x.shape - STy = self.upsample(x, scale=self.scale) - if self.scale != 1: - x = nn.functional.interpolate(x, scale_factor=self.scale, mode='nearest') - # x = nn.functional.interpolate(x, scale_factor=self.scale, mode='bilinear',align_corners=False) - # x = torch.zeros_like(x) - - FB = self.p2o(self.weight, (h*self.scale, w*self.scale)) - FBC = torch.conj(FB) - F2B = torch.pow(torch.abs(FB), 2) - FBFy = FBC*torch.fft.fftn(STy, dim=(-2, -1)) - - FR = FBFy + torch.fft.fftn(self.biaseps*x, dim=(-2,-1)) - x1 = FB.mul(FR) - FBR = torch.mean(self.splits(x1, self.scale), dim=-1, keepdim=False) - invW = torch.mean(self.splits(F2B, self.scale), dim=-1, keepdim=False) - invWBR = FBR.div(invW + self.biaseps) - FCBinvWBR = FBC*invWBR.repeat(1, 1, self.scale, self.scale) - FX = (FR-FCBinvWBR)/self.biaseps - out = torch.real(torch.fft.ifftn(FX, dim=(-2, -1))) + if self.backend == 'torch': + + self.biaseps = torch.sigmoid(self.bias-9.0) + self.eps + _, _, h, w = x.shape + STy = self.upsample(x, scale=self.scale) + if self.scale != 1: + x = nn.functional.interpolate(x, scale_factor=self.scale, mode='nearest') + # x = nn.functional.interpolate(x, scale_factor=self.scale, mode='bilinear',align_corners=False) + # x = torch.zeros_like(x) + + FB = self.p2o(self.weight, (h*self.scale, w*self.scale)) + FBC = torch.conj(FB) + F2B = torch.pow(torch.abs(FB), 2) + FBFy = FBC*torch.fft.fftn(STy, dim=(-2, -1)) + 
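+            # (descriptive comments, added for readability) FR is the FFT of the
+            # regularized right-hand side conj(FB) * F(S^T y) + F(biaseps * x0);
+            # the splits/mean below averages the scale x scale sub-blocks, the
+            # distinct-block trick that inverts the downsampled data term in closed form.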
+ FR = FBFy + torch.fft.fftn(self.biaseps*x, dim=(-2,-1)) + x1 = FB.mul(FR) + FBR = torch.mean(self.splits(x1, self.scale), dim=-1, keepdim=False) + invW = torch.mean(self.splits(F2B, self.scale), dim=-1, keepdim=False) + invWBR = FBR.div(invW + self.biaseps) + FCBinvWBR = FBC*invWBR.repeat(1, 1, self.scale, self.scale) + FX = (FR-FCBinvWBR)/self.biaseps + out = torch.real(torch.fft.ifftn(FX, dim=(-2, -1))) + + elif self.backend == 'cuda': + out = load_converse2D().forward(x, self.weight, self.bias, self.scale, self.eps) + else: + raise NotImplementedError + if self.padding > 0: out = out[..., self.padding*self.scale:-self.padding*self.scale, self.padding*self.scale:-self.padding*self.scale] @@ -196,7 +234,7 @@ def upsample(self, x, scale=3): # -------------------------------------------- """ class ConverseBlock(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=3, scale=1, padding=2, padding_mode='replicate', eps=1e-5): + def __init__(self, in_channels, out_channels, kernel_size=3, scale=1, padding=2, padding_mode='replicate', eps=1e-5, backend='torch'): super(ConverseBlock, self).__init__() """ ConverseBlock: A Convolutional Block for Image Restoration using Converse2D Operations. @@ -223,7 +261,7 @@ def __init__(self, in_channels, out_channels, kernel_size=3, scale=1, padding=2, self.conv1 = nn.Sequential(LayerNorm(in_channels, eps=1e-5, data_format="channels_first"), nn.Conv2d(in_channels, 2*out_channels, 1, 1, 0), nn.GELU(), - Converse2D(2*out_channels, 2*out_channels, kernel_size, scale=scale, padding=padding, padding_mode=padding_mode, eps=eps), + Converse2D(2*out_channels, 2*out_channels, kernel_size, scale=scale, padding=padding, padding_mode=padding_mode, eps=eps, backend=backend), nn.GELU(), nn.Conv2d(2*out_channels, out_channels, 1, 1, 0)) From bfa771cd13659cfc73dc2d9f609eb0fcb85c28fc Mon Sep 17 00:00:00 2001 From: BoyceYi <1473416941@qq.com> Date: Sat, 16 Aug 2025 13:35:48 +0800 Subject: [PATCH 02/22] add naive backward --- models/CUDA/converse2d_cuda.cu | 129 ++++++++++++++++++++++----------- models/CUDA/converse2d_op.cpp | 15 ++-- models/benchmark.py | 128 +++++++++++++++++++++++--------- models/util_converse.py | 43 ++++++++--- 4 files changed, 222 insertions(+), 93 deletions(-) diff --git a/models/CUDA/converse2d_cuda.cu b/models/CUDA/converse2d_cuda.cu index 7a7caae..8ceed7f 100644 --- a/models/CUDA/converse2d_cuda.cu +++ b/models/CUDA/converse2d_cuda.cu @@ -2,7 +2,6 @@ #include #include - #define CUDA_CHECK(call) \ do { \ cudaError_t err = call; \ @@ -12,7 +11,6 @@ } \ } while (0) - torch::Tensor p2o_cuda(torch::Tensor psf, const std::vector& shape) { auto otf = torch::zeros(torch::IntArrayRef({psf.size(0), psf.size(1), shape[0], shape[1]})).to(psf.device()); otf.slice(2, 0, psf.size(2)).slice(3, 0, psf.size(3)).copy_(psf); @@ -21,7 +19,7 @@ torch::Tensor p2o_cuda(torch::Tensor psf, const std::vector& shape) { return torch::fft::fftn(otf, c10::nullopt, c10::IntArrayRef({-2, -1})); } -torch::Tensor splits_cuda(torch::Tensor a, int scale) { +inline torch::Tensor splits_cuda(torch::Tensor a, int scale) { auto sizes = a.sizes(); long W = sizes[2]; long H = sizes[3]; @@ -33,60 +31,103 @@ torch::Tensor splits_cuda(torch::Tensor a, int scale) { return b.view({sizes[0], sizes[1], W_s, H_s, scale * scale}); } -torch::Tensor converse2d_cuda_forward( - torch::Tensor x, - torch::Tensor weight, - torch::Tensor bias, - int scale, - float eps -) { - TORCH_CHECK(x.is_cuda(), "Input tensor must be a CUDA tensor"); - TORCH_CHECK(weight.is_cuda(), "Weight tensor must 
be a CUDA tensor"); - TORCH_CHECK(bias.is_cuda(), "Bias tensor must be a CUDA tensor"); +inline torch::Tensor unsplit_cuda(torch::Tensor b, int scale) { + auto sizes = b.sizes(); + long N = sizes[0]; + long C = sizes[1]; + long W_s = sizes[2]; + long H_s = sizes[3]; + + auto a = b.view({N, C, W_s, H_s, scale, scale}); + a = a.permute({0, 1, 4, 2, 5, 3}).contiguous(); + return a.view({N, C, W_s * scale, H_s * scale}); +} - x = x.contiguous(); - weight = weight.contiguous(); - bias = bias.contiguous(); +inline torch::Tensor interpolate_backward_nearest_cuda(torch::Tensor grad_out, int scale) { + if (scale == 1) return grad_out; + auto options = torch::nn::functional::AvgPool2dFuncOptions({scale, scale}).stride({scale, scale}); + auto grad_in = torch::nn::functional::avg_pool2d(grad_out, options); + return grad_in * (scale * scale); +} - const int N = x.size(0); - const int C = x.size(1); - const int H = x.size(2); - const int W = x.size(3); - const int H_up = H * scale; - const int W_up = W * scale; +std::vector converse2d_cuda_forward( + torch::Tensor x, torch::Tensor weight, torch::Tensor bias, int scale, float eps +) { + TORCH_CHECK(x.is_cuda(), "Input tensor must be a CUDA tensor"); + const int H_up = x.size(2) * scale; + const int W_up = x.size(3) * scale; auto biaseps = (torch::sigmoid(bias - 9.0f) + eps).contiguous(); - - auto STy = torch::zeros({N, C, H_up, W_up}, x.options()); + auto STy = torch::zeros({x.size(0), x.size(1), H_up, W_up}, x.options()); STy.slice(2, 0, H_up, scale).slice(3, 0, W_up, scale).copy_(x); - + auto x_interp = x; if (scale != 1) { - x = torch::nn::functional::interpolate(x, + x_interp = torch::nn::functional::interpolate(x, torch::nn::functional::InterpolateFuncOptions().scale_factor(std::vector({(double)scale, (double)scale})).mode(torch::kNearest)); } - auto FB = p2o_cuda(weight, {H_up, W_up}).contiguous(); auto FBC = torch::conj(FB).contiguous(); - auto F2B = torch::pow(torch::abs(FB), 2).contiguous(); - auto STy_fft = torch::fft::fftn(STy, c10::nullopt, c10::IntArrayRef({-2, -1})).contiguous(); - auto FBFy = (FBC * STy_fft).contiguous(); - - auto x_fft = torch::fft::fftn(biaseps * x, c10::nullopt, c10::IntArrayRef({-2, -1})).contiguous(); - - auto FR = FBFy + x_fft; - auto x1 = FB.mul(FR); - - auto FBR = torch::mean(splits_cuda(x1, scale), -1, false); - auto invW = torch::mean(splits_cuda(F2B.to(torch::kComplexFloat), scale), -1, false); - + auto x_fft = torch::fft::fftn(biaseps * x_interp, c10::nullopt, c10::IntArrayRef({-2, -1})).contiguous(); + auto FR = (FBC * STy_fft) + x_fft; + auto invW = torch::mean(splits_cuda(torch::pow(torch::abs(FB), 2).to(torch::kComplexFloat), scale), -1, false); + auto FBR = torch::mean(splits_cuda(FB.mul(FR), scale), -1, false); auto invWBR = FBR.div(invW + biaseps.to(torch::kComplexFloat)); auto FCBinvWBR = FBC * invWBR.repeat({1, 1, scale, scale}); - auto FX = (FR - FCBinvWBR) / biaseps.to(torch::kComplexFloat); - - auto out_complex = torch::fft::ifftn(FX, c10::nullopt, c10::IntArrayRef({-2, -1})); - auto out = torch::real(out_complex); + auto out = torch::real(torch::fft::ifftn(FX, c10::nullopt, c10::IntArrayRef({-2, -1}))); + + return {out, x_interp, biaseps, FB, FBC, FR, invW, FBR, invWBR, STy_fft}; +} - return out; +std::vector converse2d_cuda_backward( + torch::Tensor grad_out, torch::Tensor x, torch::Tensor weight, torch::Tensor bias, int scale, + const std::vector& saved_tensors +) { + // --- Unpack saved tensors --- + auto x_interp = saved_tensors[0]; + auto biaseps = saved_tensors[1]; + auto FB = 
saved_tensors[2]; + auto FBC = saved_tensors[3]; + auto FR = saved_tensors[4]; + auto invW = saved_tensors[5]; + auto FBR = saved_tensors[6]; + auto invWBR = saved_tensors[7]; + auto STy_fft = saved_tensors[8]; + + + auto grad_out_c = grad_out.to(torch::kComplexFloat); + auto grad_FX = torch::fft::fftn(grad_out_c, c10::nullopt, c10::IntArrayRef({-2, -1})); + auto FCBinvWBR = FBC * invWBR.repeat({1, 1, scale, scale}); + auto grad_FR = grad_FX / biaseps.to(torch::kComplexFloat); + auto grad_FCBinvWBR = -grad_FR; + auto grad_biaseps = -torch::sum(torch::real(grad_FX * (FR - FCBinvWBR) / torch::pow(biaseps.to(torch::kComplexFloat), 2)), {0, 2, 3}, true); + auto grad_FBC = grad_FCBinvWBR * invWBR.repeat({1, 1, scale, scale}); + auto grad_invWBR = torch::sum(splits_cuda(grad_FCBinvWBR * FBC, scale), -1, false); + auto denom = invW + biaseps.to(torch::kComplexFloat); + auto grad_FBR = grad_invWBR / denom; + auto grad_denom = -grad_invWBR * FBR / torch::pow(denom, 2); + grad_biaseps += torch::sum(torch::real(grad_denom), {0, 2, 3}, true); + auto F2B = torch::pow(torch::abs(FB), 2); + auto grad_F2B = unsplit_cuda( (grad_denom).unsqueeze(-1).expand({-1, -1, -1, -1, scale * scale}) / (float)(scale * scale), scale).to(F2B.dtype()); + auto grad_x1 = unsplit_cuda( (grad_FBR).unsqueeze(-1).expand({-1, -1, -1, -1, scale * scale}) / (float)(scale * scale), scale); + auto grad_FB = grad_x1 * torch::conj(FR); + grad_FR += grad_x1 * torch::conj(FB); + auto grad_x_fft = grad_FR; + auto grad_biaseps_x_interp = torch::real(torch::fft::ifftn(grad_x_fft, c10::nullopt, c10::IntArrayRef({-2, -1}))); + grad_biaseps += torch::sum(grad_biaseps_x_interp * x_interp, {0, 2, 3}, true); + auto grad_x_interp = grad_biaseps_x_interp * biaseps; + grad_FBC += grad_FR * torch::conj(STy_fft); + auto grad_STy_fft = grad_FR * torch::conj(FBC); + auto grad_STy = torch::real(torch::fft::ifftn(grad_STy_fft, c10::nullopt, c10::IntArrayRef({-2, -1}))); + const int H_up = x.size(2) * scale; const int W_up = x.size(3) * scale; + auto grad_x = grad_STy.slice(2, 0, H_up, scale).slice(3, 0, W_up, scale) + interpolate_backward_nearest_cuda(grad_x_interp, scale); + grad_FB += 2 * F2B.to(torch::kComplexFloat) * FBC; + grad_FB += torch::conj(grad_FBC); + auto grad_otf = torch::real(torch::fft::ifftn(grad_FB, c10::nullopt, c10::IntArrayRef({-2, -1}))); + auto grad_weight = torch::roll(grad_otf, {weight.size(2) / 2, weight.size(3) / 2}, {2, 3}).slice(2, 0, weight.size(2)).slice(3, 0, weight.size(3)).clone(); + auto sig = torch::sigmoid(bias - 9.0f); + auto grad_bias = grad_biaseps * sig * (1.0f - sig); + + return {grad_x, grad_weight, grad_bias}; } \ No newline at end of file diff --git a/models/CUDA/converse2d_op.cpp b/models/CUDA/converse2d_op.cpp index 3cc0e80..00a0d58 100644 --- a/models/CUDA/converse2d_op.cpp +++ b/models/CUDA/converse2d_op.cpp @@ -1,13 +1,16 @@ #include +#include -torch::Tensor converse2d_cuda_forward( - torch::Tensor x, - torch::Tensor weight, - torch::Tensor bias, - int scale, - float eps +std::vector converse2d_cuda_forward( + torch::Tensor x, torch::Tensor weight, torch::Tensor bias, int scale, float eps +); + +std::vector converse2d_cuda_backward( + torch::Tensor grad_out, torch::Tensor x, torch::Tensor weight, torch::Tensor bias, int scale, + const std::vector& saved_tensors ); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &converse2d_cuda_forward, "Converse2D forward (CUDA)"); + m.def("backward", &converse2d_cuda_backward, "Converse2D backward (CUDA)"); } \ No newline at end of file diff --git 
a/models/benchmark.py b/models/benchmark.py index 2f91d97..e8a3807 100644 --- a/models/benchmark.py +++ b/models/benchmark.py @@ -1,29 +1,62 @@ import torch -import time from util_converse import Converse2D -def benchmark(model, input_tensor, model_name, num_runs=100): +def benchmark(model, input_tensor, model_name, pass_type='forward', num_runs=100): """ - Measures the average forward pass time of a model. + Measures the average forward or backward pass time of a model. + Uses torch.cuda.Event for precise GPU timing. """ - print(f"Warming up {model_name} backend...") - # Warm-up runs to stabilize performance measurement + if pass_type not in ['forward', 'backward']: + raise ValueError("pass_type must be 'forward' or 'backward'") + + print(f"Warming up {model_name} backend for {pass_type} pass...") + # Warm-up runs to stabilize performance and CUDA kernels for _ in range(10): - _ = model(input_tensor) - torch.cuda.synchronize() + output = model(input_tensor) + if pass_type == 'backward': + grad_output = torch.ones_like(output) + output.backward(gradient=grad_output) - print(f"Running benchmark for {model_name} backend ({num_runs} iterations)...") - start_time = time.time() - for _ in range(num_runs): - _ = model(input_tensor) - # Wait for all kernels to complete torch.cuda.synchronize() - end_time = time.time() - avg_time = (end_time - start_time) / num_runs + print(f"Running benchmark for {model_name} backend {pass_type} pass ({num_runs} iterations)...") + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + total_time = 0.0 + + if pass_type == 'forward': + start_event.record() + for _ in range(num_runs): + _ = model(input_tensor) + end_event.record() + else: + for _ in range(num_runs): + output = model(input_tensor) + grad_output = torch.ones_like(output) + torch.cuda.synchronize() + + start_event.record() + output.backward(gradient=grad_output) + end_event.record() + + torch.cuda.synchronize() + total_time += start_event.elapsed_time(end_event) + + del output, grad_output + + if pass_type == 'forward': + torch.cuda.synchronize() + total_time = start_event.elapsed_time(end_event) + + avg_time = (total_time / 1000.0) / num_runs return avg_time +# ======================================================================= +# The 'run_comparison' function and the rest of the file remain the same. +# No changes are needed below this line. +# ======================================================================= def run_comparison(): if not torch.cuda.is_available(): print("CUDA is not available. 
Performance comparison cannot be run.") @@ -35,7 +68,7 @@ def run_comparison(): params = { 'in_channels': 64, - 'out_channels': 64, # Must be the same as in_channels for Converse2D + 'out_channels': 64, 'kernel_size': 3, 'scale': 2, 'padding': 2, @@ -52,8 +85,6 @@ def run_comparison(): print(f"Device: {params['device']}") print("---------------------------------\n") - - # Create a dummy input tensor on the GPU input_tensor = torch.randn( params['batch_size'], params['in_channels'], @@ -62,7 +93,6 @@ def run_comparison(): ).to(params['device']) try: - # Initialize PyTorch backend model print("Initializing PyTorch backend model...") converse_torch = Converse2D( in_channels=params['in_channels'], @@ -74,7 +104,6 @@ def run_comparison(): ).to(params['device']) print("PyTorch backend model initialized.") - # Initialize CUDA backend model (this will trigger the JIT compilation) print("\nInitializing CUDA backend model (compilation may take a moment)...") converse_cuda = Converse2D( in_channels=params['in_channels'], @@ -91,22 +120,45 @@ def run_comparison(): print("Please ensure that a compatible CUDA toolkit is installed and configured correctly for PyTorch.") return - # Run benchmarks - torch_time = benchmark(converse_torch, input_tensor, "PyTorch") - cuda_time = benchmark(converse_cuda, input_tensor, "CUDA") - - # --- Step 4: Report the results --- - print("\n--- Performance Comparison Results ---") + # --- Run Forward Pass benchmarks --- + print("\n--- Benchmarking Forward Pass ---") + cuda_forward_time = benchmark(converse_cuda, input_tensor, "CUDA", pass_type='forward') + torch_forward_time = benchmark(converse_torch, input_tensor, "PyTorch", pass_type='forward') + + # --- Run Backward Pass benchmarks --- + print("\n--- Benchmarking Backward Pass ---") + # Enable gradient computation on the input tensor for the backward pass + input_tensor.requires_grad_(True) + cuda_backward_time = benchmark(converse_cuda, input_tensor, "CUDA", pass_type='backward') + torch_backward_time = benchmark(converse_torch, input_tensor, "PyTorch", pass_type='backward') + + # --- Report the results --- + print("\n\n--- Performance Comparison Results ---") print(f"Input Tensor Shape: ({params['batch_size']}, {params['in_channels']}, {params['height']}, {params['width']})") - print(f"PyTorch Backend Average Time: {torch_time * 1000:.4f} ms") - print(f"CUDA Backend Average Time: {cuda_time * 1000:.4f} ms") print("--------------------------------------") - if cuda_time > 0: - speedup = torch_time / cuda_time - print(f"The CUDA implementation is approximately {speedup:.2f}x faster than the PyTorch implementation.") + # Forward pass results + print("\n--- Forward Pass ---") + print(f"PyTorch Backend Average Time: {torch_forward_time * 1000:.4f} ms") + print(f"CUDA Backend Average Time: {cuda_forward_time * 1000:.4f} ms") + if cuda_forward_time > 0: + forward_speedup = torch_forward_time / cuda_forward_time + print(f"CUDA implementation is {forward_speedup:.2f}x faster.") + else: + print("Could not calculate forward pass speedup.") + + # Backward pass results + print("\n--- Backward Pass ---") + print(f"PyTorch Backend Average Time: {torch_backward_time * 1000:.4f} ms") + print(f"CUDA Backend Average Time: {cuda_backward_time * 1000:.4f} ms") + if cuda_backward_time > 0: + backward_speedup = torch_backward_time / cuda_backward_time + print(f"CUDA implementation is {backward_speedup:.2f}x faster.") else: - print("Could not calculate speedup due to zero execution time.") + print("Could not calculate backward pass 
speedup.") + + print("\n--------------------------------------") + if __name__ == "__main__": run_comparison() @@ -131,8 +183,18 @@ def run_comparison(): --- Performance Comparison Results --- Input Tensor Shape: (4, 64, 256, 256) -PyTorch Backend Average Time: 131.7963 ms -CUDA Backend Average Time: 67.5533 ms +-------------------------------------- + +--- Forward Pass --- +PyTorch Backend Average Time: 202.0048 ms +CUDA Backend Average Time: 76.2283 ms +CUDA implementation is 2.65x faster. + +--- Backward Pass --- +PyTorch Backend Average Time: 105.8770 ms +CUDA Backend Average Time: 123.8077 ms +CUDA implementation is 0.86x faster. + -------------------------------------- The CUDA implementation is approximately 1.95x faster than the PyTorch implementation. diff --git a/models/util_converse.py b/models/util_converse.py index 1990fda..cd134c5 100644 --- a/models/util_converse.py +++ b/models/util_converse.py @@ -13,11 +13,11 @@ # LayerNorm for Vision Normalization # -------------------------------------------- """ - +converse2D_cuda = None def load_converse2D(): current_file_dir = os.path.dirname(os.path.abspath(__file__)) - converse2D_cuda = load( + converse2D_cuda_lib = load( name="converse2D_cuda", sources=[ os.path.join(current_file_dir, "CUDA/converse2d_op.cpp"), @@ -33,7 +33,35 @@ def load_converse2D(): "-gencode arch=compute_75,code=sm_75", ], ) - return converse2D_cuda + return converse2D_cuda_lib + + +class Converse2DFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, scale, eps): + ctx.scale = scale + + global converse2D_cuda + if converse2D_cuda is None: + converse2D_cuda = load_converse2D() + + out = converse2D_cuda.forward(x, weight, bias, scale, eps) + tensors_to_save = out[1:] + ctx.save_for_backward(x, weight, bias, *tensors_to_save) + + return out[0] + + @staticmethod + def backward(ctx, grad_out): + x, weight, bias, *saved_tensors = ctx.saved_tensors + scale = ctx.scale + + grad_x, grad_weight, grad_bias = converse2D_cuda.backward(grad_out, x, weight, bias, scale, saved_tensors) + + return grad_x, grad_weight, grad_bias, None, None + +def Converse2D_CUDA(x, weight, bias, scale, eps): + return Converse2DFunction.apply(x, weight, bias, scale, eps) class LayerNorm(nn.Module): @@ -123,10 +151,9 @@ def __init__(self, in_channels, out_channels, kernel_size, scale=1, padding=2, p self.padding = padding self.padding_mode = padding_mode self.eps = eps - assert backend in ['torch', 'cuda'], "Not Implementd Yet" + assert backend in ['torch', 'cuda'], "Not Implementd Yet" #, 'cuda_fuse' self.backend = backend - # ensure depthwise assert self.out_channels == self.in_channels self.weight = nn.Parameter(torch.randn(1, self.in_channels, self.kernel_size, self.kernel_size)) @@ -135,12 +162,10 @@ def __init__(self, in_channels, out_channels, kernel_size, scale=1, padding=2, p def forward(self, x): - if self.padding > 0: x = nn.functional.pad(x, pad=[self.padding, self.padding, self.padding, self.padding], mode=self.padding_mode, value=0) if self.backend == 'torch': - self.biaseps = torch.sigmoid(self.bias-9.0) + self.eps _, _, h, w = x.shape STy = self.upsample(x, scale=self.scale) @@ -162,13 +187,11 @@ def forward(self, x): FCBinvWBR = FBC*invWBR.repeat(1, 1, self.scale, self.scale) FX = (FR-FCBinvWBR)/self.biaseps out = torch.real(torch.fft.ifftn(FX, dim=(-2, -1))) - elif self.backend == 'cuda': - out = load_converse2D().forward(x, self.weight, self.bias, self.scale, self.eps) + out = Converse2D_CUDA(x, self.weight, self.bias, self.scale, self.eps) 
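+            # Converse2D_CUDA wraps Converse2DFunction.apply, which calls the
+            # extension's forward and stashes its FFT intermediates so the
+            # hand-written CUDA backward can reuse them instead of recomputing.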
         else:
             raise NotImplementedError
-
         if self.padding > 0:
             out = out[..., self.padding*self.scale:-self.padding*self.scale, self.padding*self.scale:-self.padding*self.scale]

From 6cdc522ed0c27d25724e822d7a3ddfd4b7e41d79 Mon Sep 17 00:00:00 2001
From: BoyceYi <1473416941@qq.com>
Date: Fri, 29 Aug 2025 14:30:02 +0800
Subject: [PATCH 03/22] Update CUDA kernel

---
 Converse2D/setup.py                           |  18 ++
 Converse2D/torch_converse2d/__init__.py       |   0
 .../__pycache__/__init__.cpython-311.pyc      | Bin 0 -> 156 bytes
 .../torch_converse2d/converse2d_ext.cpp       | 146 +++++++++++++
 README.md                                     |  22 +-
 models/CUDA/converse2d_cuda.cu                | 133 ------------
 models/CUDA/converse2d_op.cpp                 |  16 --
 models/benchmark.py                           | 201 ------------------
 models/util_converse.py                       | 126 +++++------
 test/Results.md                               |  36 ++++
 test/test_error.py                            |  66 ++++++
 test/test_speed.py                            | 177 +++++++++++++++
 12 files changed, 522 insertions(+), 419 deletions(-)
 create mode 100644 Converse2D/setup.py
 create mode 100644 Converse2D/torch_converse2d/__init__.py
 create mode 100644 Converse2D/torch_converse2d/__pycache__/__init__.cpython-311.pyc
 create mode 100644 Converse2D/torch_converse2d/converse2d_ext.cpp
 delete mode 100644 models/CUDA/converse2d_cuda.cu
 delete mode 100644 models/CUDA/converse2d_op.cpp
 delete mode 100644 models/benchmark.py
 create mode 100644 test/Results.md
 create mode 100644 test/test_error.py
 create mode 100644 test/test_speed.py

diff --git a/Converse2D/setup.py b/Converse2D/setup.py
new file mode 100644
index 0000000..30cb75f
--- /dev/null
+++ b/Converse2D/setup.py
@@ -0,0 +1,18 @@
+from setuptools import setup
+from torch.utils.cpp_extension import CppExtension, BuildExtension
+
+setup(
+    name="torch_converse2d",
+    version="0.1",
+    description="Converse2D CUDA extension for PyTorch",
+    packages=["torch_converse2d"],
+    ext_modules=[
+        CppExtension(
+            name="converse2d_ext",
+            sources=["torch_converse2d/converse2d_ext.cpp"],
+            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"]},
+        )
+    ],
+    cmdclass={"build_ext": BuildExtension},
+    zip_safe=False,
+)
diff --git a/Converse2D/torch_converse2d/__init__.py b/Converse2D/torch_converse2d/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/Converse2D/torch_converse2d/__pycache__/__init__.cpython-311.pyc b/Converse2D/torch_converse2d/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8edbc55e081516b4347e4e0cd6fc099e2cb6bfc
GIT binary patch
literal 156
zcmZ3^%ge<81e2XMW`O9&AOZ#$p^VRLK*n^26oz01O-8?!3`I;p{%4TnFMa)tO5Kv&
z0)6NFyt34y;#8v){gV8m

diff --git a/Converse2D/torch_converse2d/converse2d_ext.cpp b/Converse2D/torch_converse2d/converse2d_ext.cpp
new file mode 100644
--- /dev/null
+++ b/Converse2D/torch_converse2d/converse2d_ext.cpp
@@ -0,0 +1,146 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using at::Tensor;
+
+namespace {
+
+static inline Tensor sfold_upsample_zero_insertion(const Tensor& x, int64_t s) {
+    TORCH_CHECK(s >= 1, "scale must be >= 1");
+    if (s == 1) return x;
+    auto sizes = x.sizes().vec();
+    sizes[sizes.size()-2] *= s;
+    sizes[sizes.size()-1] *= s;
+    Tensor z = at::zeros(sizes, x.options());
+    z.index_put_(
+        {at::indexing::Slice(), at::indexing::Slice(),
+         at::indexing::Slice(0, z.size(-2), s),
+         at::indexing::Slice(0, z.size(-1), s)}, x);
+    return z;
+}
+
+static inline Tensor p2o(const Tensor& psf, int64_t H, int64_t W) {
+    TORCH_CHECK(psf.dim() == 4 && psf.size(0) == 1, "psf must be (1,C,kh,kw)");
+    auto C = psf.size(1);
+    auto kh = psf.size(2);
+    auto kw = psf.size(3);
+    Tensor otf = at::zeros({1, C, H, W}, psf.options());
+    otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf);
+    const int64_t sh = 
-static_cast(kh / 2); + const int64_t sw = -static_cast(kw / 2); + otf = at::roll(otf, {sh, sw}, {-2, -1}); + return at::fft_fftn(otf, c10::optional({}), c10::optional({-2,-1}), c10::nullopt); +} + +static inline Tensor splits_mean_then_mean(const Tensor& a, int64_t s) { + TORCH_CHECK(a.dim() >= 2, "tensor must have spatial dims"); + TORCH_CHECK(a.size(-2) % s == 0 && a.size(-1) % s == 0, "spatial not divisible by scale"); + + const auto& sizes = a.sizes(); + const int64_t L = a.dim(); + const int64_t W = sizes[L-2]; + const int64_t H = sizes[L-1]; + const int64_t W_s = W / s; + const int64_t H_s = H / s; + + std::vector view_shape; + view_shape.reserve(L + 2); + for (int64_t i = 0; i < L-2; ++i) view_shape.push_back(sizes[i]); + view_shape.push_back(s); + view_shape.push_back(W_s); + view_shape.push_back(s); + view_shape.push_back(H_s); + Tensor v = a.view(view_shape); + + std::vector perm; + perm.reserve(view_shape.size()); + for (int64_t i = 0; i < L-2; ++i) perm.push_back(i); + perm.push_back(L-2 + 1); // W_s + perm.push_back(L-2 + 3); // H_s + perm.push_back(L-2 + 0); // s + perm.push_back(L-2 + 2); // s + Tensor p = v.permute(perm).contiguous(); + + std::vector merge_shape; + merge_shape.reserve(L+1); + for (int64_t i = 0; i < L-2; ++i) merge_shape.push_back(p.size(i)); + merge_shape.push_back(W_s); + merge_shape.push_back(H_s); + merge_shape.push_back(s * s); + Tensor r = p.view(merge_shape); + + return r.mean(-1, /*keepdim=*/false); +} + +Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) { + TORCH_CHECK(x.dim()==4, "x must be (B,C,H,W)"); + TORCH_CHECK(x0.dim()==4, "x0 must be (B,C,Hs,Ws)"); + TORCH_CHECK(weight.dim()==4 && weight.size(0)==1, "weight must be (1,C,kh,kw)"); + TORCH_CHECK(bias.dim()==4 && bias.size(0)==1 && bias.size(2)==1 && bias.size(3)==1, "bias must be (1,C,1,1)"); + TORCH_CHECK(x.device()==x0.device() && x.device()==weight.device() && x.device()==bias.device(), "tensors on same device"); + TORCH_CHECK(scale >= 1, "scale must be >= 1"); + + x = x.contiguous(); + x0 = x0.contiguous(); + weight= weight.contiguous(); + bias = bias.contiguous(); + + const int64_t B = x.size(0); + const int64_t C = x.size(1); + const int64_t H = x.size(2); + const int64_t W = x.size(3); + const int64_t Hs = H * scale; + const int64_t Ws = W * scale; + + Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; + + Tensor STy = sfold_upsample_zero_insertion(x, scale); + + Tensor FB = p2o(weight, Hs, Ws); + Tensor FBC = at::conj_physical(FB); + Tensor F2B = at::abs(FB).pow(2.0); + + Tensor F_STy = at::fft_fftn(STy, c10::optional({}), c10::optional({-2,-1}), c10::nullopt); + Tensor FBFy = FBC * F_STy; + Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::optional({}), c10::optional({-2,-1}), c10::nullopt); + + Tensor x1 = FB * FR; + + Tensor FBR = splits_mean_then_mean(x1, scale); + Tensor invW= splits_mean_then_mean(F2B, scale); + + Tensor invW_plus = invW + lambda_; + Tensor invWBR = FBR / invW_plus; + + Tensor invWBR_rep = invWBR.repeat({1,1,scale,scale}); + + Tensor FCBinvWBR = FBC * invWBR_rep; + + Tensor FX = (FR - FCBinvWBR) / lambda_; + Tensor out_c = at::fft_ifftn(FX, c10::optional({}), c10::optional({-2,-1}), c10::nullopt); + Tensor out = at::real(out_c); + (void)B; (void)C; (void)H; (void)W; + return out; +} + +} + + +TORCH_LIBRARY(converse2d, m) { + m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); +} +TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) { + 
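+    // CompositeImplicitAutograd: converse2d_forward is built entirely from
+    // differentiable ATen ops, so autograd is derived automatically and no
+    // explicit backward kernel has to be registered here.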
m.impl("forward", TORCH_FN(converse2d_forward));
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
diff --git a/README.md b/README.md
index cc47a99..4d1a03a 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,26 @@ ___________
    * [Visual results of ConverseNet](#visual-results-of-conversenet)
    * [Visual results of Converse-USRNet](#visual-results-of-converse-usrnet)
 
+Kernel Registry
+----------
+**Installation**
+
+```bash
+cd ./Converse2D
+pip install -e .
+```
+
+**Usage**
+
+```python
+import torch
+import torch_converse2d  # importing registers torch.ops.converse2d
+
+B, C, H, W, scale, eps = 1, 3, 32, 32, 2, 1e-5
+x = torch.randn(B, C, H, W, device="cuda")
+x0 = torch.nn.functional.interpolate(x, scale_factor=scale, mode="nearest")
+weight = torch.randn(1, C, 5, 5, device="cuda")  # (1,C,kh,kw), depthwise PSF
+bias = torch.zeros(1, C, 1, 1, device="cuda")    # (1,C,1,1)
+out = torch.ops.converse2d.forward(x, x0, weight, bias, scale, eps)
+print(torch.ops.converse2d, out.shape)  # (B, C, H*scale, W*scale)
+```
+
+
 Motivation
 ----------
 
@@ -43,7 +63,6 @@
 $$
 \mathbf{X}^\ast = \arg\min_{\mathbf{X}} \left\| \mathbf{Y} - \left( \mathbf{X} \otimes \mathbf{K} \right) \downarrow_{s} \right\|_F^2 + \lambda \left\| \mathbf{X} - \mathbf{X}_0 \right\|_F^2,
 $$
-
 $$
 \mathbf{X}^\ast = \arg\min_{\mathbf{X}} \left\| \mathbf{Y} - \left( \mathbf{X} \otimes \mathbf{K} \right) \downarrow_{s} \right\|_F^2
 $$
@@ -136,4 +155,3 @@ Citation
 }
 ```
 
-
diff --git a/models/CUDA/converse2d_cuda.cu b/models/CUDA/converse2d_cuda.cu
deleted file mode 100644
index 8ceed7f..0000000
--- a/models/CUDA/converse2d_cuda.cu
+++ /dev/null
@@ -1,133 +0,0 @@
-#include
-#include
-#include
-
-#define CUDA_CHECK(call) \
-    do { \
-        cudaError_t err = call; \
-        if (err != cudaSuccess) { \
-            fprintf(stderr, "CUDA error at %s %d: %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
-            exit(EXIT_FAILURE); \
-        } \
-    } while (0)
-
-torch::Tensor p2o_cuda(torch::Tensor psf, const std::vector& shape) {
-    auto otf = torch::zeros(torch::IntArrayRef({psf.size(0), psf.size(1), shape[0], shape[1]})).to(psf.device());
-    otf.slice(2, 0, psf.size(2)).slice(3, 0, psf.size(3)).copy_(psf);
-
-    otf = torch::roll(otf, {-psf.size(2) / 2, -psf.size(3) / 2}, {2, 3});
-    return torch::fft::fftn(otf, c10::nullopt, c10::IntArrayRef({-2, -1}));
-}
-
-inline torch::Tensor splits_cuda(torch::Tensor a, int scale) {
-    auto sizes = a.sizes();
-    long W = sizes[2];
-    long H = sizes[3];
-    long W_s = W / scale;
-    long H_s = H / scale;
-
-    auto b = a.view({sizes[0], sizes[1], scale, W_s, scale, H_s});
-    b = b.permute({0, 1, 3, 5, 2, 4}).contiguous();
-    return b.view({sizes[0], sizes[1], W_s, H_s, scale * scale});
-}
-
-inline torch::Tensor unsplit_cuda(torch::Tensor b, int scale) {
-    auto sizes = b.sizes();
-    long N = sizes[0];
-    long C = sizes[1];
-    long W_s = sizes[2];
-    long H_s = sizes[3];
-
-    auto a = b.view({N, C, W_s, H_s, scale, scale});
-    a = a.permute({0, 1, 4, 2, 5, 3}).contiguous();
-    return a.view({N, C, W_s * scale, H_s * scale});
-}
-
-inline torch::Tensor interpolate_backward_nearest_cuda(torch::Tensor grad_out, int scale) {
-    if (scale == 1) return grad_out;
-    auto options = torch::nn::functional::AvgPool2dFuncOptions({scale, scale}).stride({scale, scale});
-    auto grad_in = torch::nn::functional::avg_pool2d(grad_out, options);
-    return grad_in * (scale * scale);
-}
-
-
-std::vector converse2d_cuda_forward(
-    torch::Tensor x, torch::Tensor weight, torch::Tensor bias, int scale, float eps
-) {
-    TORCH_CHECK(x.is_cuda(), "Input tensor must be a CUDA tensor");
-    const int H_up = x.size(2) * scale;
-    const int W_up = x.size(3) * scale;
-    auto biaseps = (torch::sigmoid(bias - 9.0f) + eps).contiguous();
-    auto STy = torch::zeros({x.size(0), x.size(1), H_up, W_up}, x.options());
-    STy.slice(2, 0, H_up, scale).slice(3, 0, W_up, scale).copy_(x);
-    auto x_interp = x;
-    if (scale != 1) {
-        x_interp = torch::nn::functional::interpolate(x,
-
torch::nn::functional::InterpolateFuncOptions().scale_factor(std::vector({(double)scale, (double)scale})).mode(torch::kNearest)); - } - auto FB = p2o_cuda(weight, {H_up, W_up}).contiguous(); - auto FBC = torch::conj(FB).contiguous(); - auto STy_fft = torch::fft::fftn(STy, c10::nullopt, c10::IntArrayRef({-2, -1})).contiguous(); - auto x_fft = torch::fft::fftn(biaseps * x_interp, c10::nullopt, c10::IntArrayRef({-2, -1})).contiguous(); - auto FR = (FBC * STy_fft) + x_fft; - auto invW = torch::mean(splits_cuda(torch::pow(torch::abs(FB), 2).to(torch::kComplexFloat), scale), -1, false); - auto FBR = torch::mean(splits_cuda(FB.mul(FR), scale), -1, false); - auto invWBR = FBR.div(invW + biaseps.to(torch::kComplexFloat)); - auto FCBinvWBR = FBC * invWBR.repeat({1, 1, scale, scale}); - auto FX = (FR - FCBinvWBR) / biaseps.to(torch::kComplexFloat); - auto out = torch::real(torch::fft::ifftn(FX, c10::nullopt, c10::IntArrayRef({-2, -1}))); - - return {out, x_interp, biaseps, FB, FBC, FR, invW, FBR, invWBR, STy_fft}; -} - -std::vector converse2d_cuda_backward( - torch::Tensor grad_out, torch::Tensor x, torch::Tensor weight, torch::Tensor bias, int scale, - const std::vector& saved_tensors -) { - // --- Unpack saved tensors --- - auto x_interp = saved_tensors[0]; - auto biaseps = saved_tensors[1]; - auto FB = saved_tensors[2]; - auto FBC = saved_tensors[3]; - auto FR = saved_tensors[4]; - auto invW = saved_tensors[5]; - auto FBR = saved_tensors[6]; - auto invWBR = saved_tensors[7]; - auto STy_fft = saved_tensors[8]; - - - auto grad_out_c = grad_out.to(torch::kComplexFloat); - auto grad_FX = torch::fft::fftn(grad_out_c, c10::nullopt, c10::IntArrayRef({-2, -1})); - auto FCBinvWBR = FBC * invWBR.repeat({1, 1, scale, scale}); - auto grad_FR = grad_FX / biaseps.to(torch::kComplexFloat); - auto grad_FCBinvWBR = -grad_FR; - auto grad_biaseps = -torch::sum(torch::real(grad_FX * (FR - FCBinvWBR) / torch::pow(biaseps.to(torch::kComplexFloat), 2)), {0, 2, 3}, true); - auto grad_FBC = grad_FCBinvWBR * invWBR.repeat({1, 1, scale, scale}); - auto grad_invWBR = torch::sum(splits_cuda(grad_FCBinvWBR * FBC, scale), -1, false); - auto denom = invW + biaseps.to(torch::kComplexFloat); - auto grad_FBR = grad_invWBR / denom; - auto grad_denom = -grad_invWBR * FBR / torch::pow(denom, 2); - grad_biaseps += torch::sum(torch::real(grad_denom), {0, 2, 3}, true); - auto F2B = torch::pow(torch::abs(FB), 2); - auto grad_F2B = unsplit_cuda( (grad_denom).unsqueeze(-1).expand({-1, -1, -1, -1, scale * scale}) / (float)(scale * scale), scale).to(F2B.dtype()); - auto grad_x1 = unsplit_cuda( (grad_FBR).unsqueeze(-1).expand({-1, -1, -1, -1, scale * scale}) / (float)(scale * scale), scale); - auto grad_FB = grad_x1 * torch::conj(FR); - grad_FR += grad_x1 * torch::conj(FB); - auto grad_x_fft = grad_FR; - auto grad_biaseps_x_interp = torch::real(torch::fft::ifftn(grad_x_fft, c10::nullopt, c10::IntArrayRef({-2, -1}))); - grad_biaseps += torch::sum(grad_biaseps_x_interp * x_interp, {0, 2, 3}, true); - auto grad_x_interp = grad_biaseps_x_interp * biaseps; - grad_FBC += grad_FR * torch::conj(STy_fft); - auto grad_STy_fft = grad_FR * torch::conj(FBC); - auto grad_STy = torch::real(torch::fft::ifftn(grad_STy_fft, c10::nullopt, c10::IntArrayRef({-2, -1}))); - const int H_up = x.size(2) * scale; const int W_up = x.size(3) * scale; - auto grad_x = grad_STy.slice(2, 0, H_up, scale).slice(3, 0, W_up, scale) + interpolate_backward_nearest_cuda(grad_x_interp, scale); - grad_FB += 2 * F2B.to(torch::kComplexFloat) * FBC; - grad_FB += torch::conj(grad_FBC); - 
auto grad_otf = torch::real(torch::fft::ifftn(grad_FB, c10::nullopt, c10::IntArrayRef({-2, -1}))); - auto grad_weight = torch::roll(grad_otf, {weight.size(2) / 2, weight.size(3) / 2}, {2, 3}).slice(2, 0, weight.size(2)).slice(3, 0, weight.size(3)).clone(); - auto sig = torch::sigmoid(bias - 9.0f); - auto grad_bias = grad_biaseps * sig * (1.0f - sig); - - return {grad_x, grad_weight, grad_bias}; -} \ No newline at end of file diff --git a/models/CUDA/converse2d_op.cpp b/models/CUDA/converse2d_op.cpp deleted file mode 100644 index 00a0d58..0000000 --- a/models/CUDA/converse2d_op.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include -#include - -std::vector converse2d_cuda_forward( - torch::Tensor x, torch::Tensor weight, torch::Tensor bias, int scale, float eps -); - -std::vector converse2d_cuda_backward( - torch::Tensor grad_out, torch::Tensor x, torch::Tensor weight, torch::Tensor bias, int scale, - const std::vector& saved_tensors -); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &converse2d_cuda_forward, "Converse2D forward (CUDA)"); - m.def("backward", &converse2d_cuda_backward, "Converse2D backward (CUDA)"); -} \ No newline at end of file diff --git a/models/benchmark.py b/models/benchmark.py deleted file mode 100644 index e8a3807..0000000 --- a/models/benchmark.py +++ /dev/null @@ -1,201 +0,0 @@ -import torch -from util_converse import Converse2D - - -def benchmark(model, input_tensor, model_name, pass_type='forward', num_runs=100): - """ - Measures the average forward or backward pass time of a model. - Uses torch.cuda.Event for precise GPU timing. - """ - if pass_type not in ['forward', 'backward']: - raise ValueError("pass_type must be 'forward' or 'backward'") - - print(f"Warming up {model_name} backend for {pass_type} pass...") - # Warm-up runs to stabilize performance and CUDA kernels - for _ in range(10): - output = model(input_tensor) - if pass_type == 'backward': - grad_output = torch.ones_like(output) - output.backward(gradient=grad_output) - - torch.cuda.synchronize() - - print(f"Running benchmark for {model_name} backend {pass_type} pass ({num_runs} iterations)...") - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - total_time = 0.0 - - if pass_type == 'forward': - start_event.record() - for _ in range(num_runs): - _ = model(input_tensor) - end_event.record() - else: - for _ in range(num_runs): - output = model(input_tensor) - grad_output = torch.ones_like(output) - torch.cuda.synchronize() - - start_event.record() - output.backward(gradient=grad_output) - end_event.record() - - torch.cuda.synchronize() - total_time += start_event.elapsed_time(end_event) - - del output, grad_output - - if pass_type == 'forward': - torch.cuda.synchronize() - total_time = start_event.elapsed_time(end_event) - - avg_time = (total_time / 1000.0) / num_runs - return avg_time - -# ======================================================================= -# The 'run_comparison' function and the rest of the file remain the same. -# No changes are needed below this line. -# ======================================================================= -def run_comparison(): - if not torch.cuda.is_available(): - print("CUDA is not available. Performance comparison cannot be run.") - return - - if Converse2D is None: - print("Converse2D class not loaded. 
Aborting benchmark.") - return - - params = { - 'in_channels': 64, - 'out_channels': 64, - 'kernel_size': 3, - 'scale': 2, - 'padding': 2, - 'batch_size': 4, - 'height': 256, - 'width': 256, - 'device': torch.device("cuda") - } - - print("\n--- Benchmark Configuration ---") - for key, value in params.items(): - if key != 'device': - print(f"{key.replace('_', ' ').capitalize()}: {value}") - print(f"Device: {params['device']}") - print("---------------------------------\n") - - input_tensor = torch.randn( - params['batch_size'], - params['in_channels'], - params['height'], - params['width'] - ).to(params['device']) - - try: - print("Initializing PyTorch backend model...") - converse_torch = Converse2D( - in_channels=params['in_channels'], - out_channels=params['out_channels'], - kernel_size=params['kernel_size'], - scale=params['scale'], - padding=params['padding'], - backend='torch' - ).to(params['device']) - print("PyTorch backend model initialized.") - - print("\nInitializing CUDA backend model (compilation may take a moment)...") - converse_cuda = Converse2D( - in_channels=params['in_channels'], - out_channels=params['out_channels'], - kernel_size=params['kernel_size'], - scale=params['scale'], - padding=params['padding'], - backend='cuda' - ).to(params['device']) - print("CUDA backend model initialized and compiled successfully.") - - except Exception as e: - print(f"\nAn error occurred during model initialization: {e}") - print("Please ensure that a compatible CUDA toolkit is installed and configured correctly for PyTorch.") - return - - # --- Run Forward Pass benchmarks --- - print("\n--- Benchmarking Forward Pass ---") - cuda_forward_time = benchmark(converse_cuda, input_tensor, "CUDA", pass_type='forward') - torch_forward_time = benchmark(converse_torch, input_tensor, "PyTorch", pass_type='forward') - - # --- Run Backward Pass benchmarks --- - print("\n--- Benchmarking Backward Pass ---") - # Enable gradient computation on the input tensor for the backward pass - input_tensor.requires_grad_(True) - cuda_backward_time = benchmark(converse_cuda, input_tensor, "CUDA", pass_type='backward') - torch_backward_time = benchmark(converse_torch, input_tensor, "PyTorch", pass_type='backward') - - # --- Report the results --- - print("\n\n--- Performance Comparison Results ---") - print(f"Input Tensor Shape: ({params['batch_size']}, {params['in_channels']}, {params['height']}, {params['width']})") - print("--------------------------------------") - - # Forward pass results - print("\n--- Forward Pass ---") - print(f"PyTorch Backend Average Time: {torch_forward_time * 1000:.4f} ms") - print(f"CUDA Backend Average Time: {cuda_forward_time * 1000:.4f} ms") - if cuda_forward_time > 0: - forward_speedup = torch_forward_time / cuda_forward_time - print(f"CUDA implementation is {forward_speedup:.2f}x faster.") - else: - print("Could not calculate forward pass speedup.") - - # Backward pass results - print("\n--- Backward Pass ---") - print(f"PyTorch Backend Average Time: {torch_backward_time * 1000:.4f} ms") - print(f"CUDA Backend Average Time: {cuda_backward_time * 1000:.4f} ms") - if cuda_backward_time > 0: - backward_speedup = torch_backward_time / cuda_backward_time - print(f"CUDA implementation is {backward_speedup:.2f}x faster.") - else: - print("Could not calculate backward pass speedup.") - - print("\n--------------------------------------") - - -if __name__ == "__main__": - run_comparison() - -""" ---- Device Details --- -GPU Architecture: RTX 2080ti -CUDA version: 12.8 -Torch verison: 2.8.0 
- ---- Benchmark Configuration --- -In channels: 64 -Out channels: 64 -Kernel size: 3 -Scale: 2 -Padding: 2 -Batch size: 4 -Height: 256 -Width: 256 -Device: cuda ---------------------------------- - ---- Performance Comparison Results --- -Input Tensor Shape: (4, 64, 256, 256) --------------------------------------- - ---- Forward Pass --- -PyTorch Backend Average Time: 202.0048 ms -CUDA Backend Average Time: 76.2283 ms -CUDA implementation is 2.65x faster. - ---- Backward Pass --- -PyTorch Backend Average Time: 105.8770 ms -CUDA Backend Average Time: 123.8077 ms -CUDA implementation is 0.86x faster. - --------------------------------------- -The CUDA implementation is approximately 1.95x faster than the PyTorch implementation. - -""" \ No newline at end of file diff --git a/models/util_converse.py b/models/util_converse.py index cd134c5..e910b3a 100644 --- a/models/util_converse.py +++ b/models/util_converse.py @@ -1,69 +1,42 @@ import os -os.environ['TORCH_CUDA_ARCH_LIST'] = '7.5' import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.cpp_extension import load from collections import OrderedDict +_HAS_CONVERSE2D_EXT = False -""" -# -------------------------------------------- -# LayerNorm for Vision Normalization -# -------------------------------------------- -""" -converse2D_cuda = None - -def load_converse2D(): - current_file_dir = os.path.dirname(os.path.abspath(__file__)) - converse2D_cuda_lib = load( - name="converse2D_cuda", - sources=[ - os.path.join(current_file_dir, "CUDA/converse2d_op.cpp"), - os.path.join(current_file_dir, "CUDA/converse2d_cuda.cu"), - ], - verbose=True, - extra_cuda_cflags=[ - "-res-usage", - "--maxrregcount 60", - "--use_fast_math", - "-O3", - "-Xptxas -O3", - "-gencode arch=compute_75,code=sm_75", - ], - ) - return converse2D_cuda_lib - - -class Converse2DFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x, weight, bias, scale, eps): - ctx.scale = scale - - global converse2D_cuda - if converse2D_cuda is None: - converse2D_cuda = load_converse2D() - - out = converse2D_cuda.forward(x, weight, bias, scale, eps) - tensors_to_save = out[1:] - ctx.save_for_backward(x, weight, bias, *tensors_to_save) - - return out[0] +def _try_import_converse2d_ext(): + global _HAS_CONVERSE2D_EXT + if _HAS_CONVERSE2D_EXT: + return - @staticmethod - def backward(ctx, grad_out): - x, weight, bias, *saved_tensors = ctx.saved_tensors - scale = ctx.scale + candidates = [ + "converse2d_ext", + "torch_converse2d.converse2d_ext", + "torch_converse2d", + ] + for mod in candidates: + try: + __import__(mod) + except Exception: + continue + if hasattr(torch.ops, "converse2d") and hasattr(torch.ops.converse2d, "forward"): + print(mod) + _HAS_CONVERSE2D_EXT = True + break - grad_x, grad_weight, grad_bias = converse2D_cuda.backward(grad_out, x, weight, bias, scale, saved_tensors) +_try_import_converse2d_ext() - return grad_x, grad_weight, grad_bias, None, None - -def Converse2D_CUDA(x, weight, bias, scale, eps): - return Converse2DFunction.apply(x, weight, bias, scale, eps) +converse2d_CUDA = torch.ops.converse2d.forward +""" +# -------------------------------------------- +# LayerNorm for Vision Normalization +# -------------------------------------------- +""" class LayerNorm(nn.Module): ''' LayerNorm that supports two data formats: channels_last (default) or channels_first. 
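For reference, the import probe above registers the op as a side effect of importing the extension module. A minimal standalone sketch of the same availability check, assuming the package from this patch's setup.py is installed (illustrative, not part of the diff):

```python
import torch

# Try the same candidate module names as _try_import_converse2d_ext above;
# importing any of them runs the TORCH_LIBRARY registration as a side effect.
for mod in ("converse2d_ext", "torch_converse2d.converse2d_ext", "torch_converse2d"):
    try:
        __import__(mod)
        break
    except ImportError:
        continue

# Same check the module uses: the op-namespace attribute lookup only
# succeeds once "forward" has actually been registered.
ok = hasattr(torch.ops, "converse2d") and hasattr(torch.ops.converse2d, "forward")
print("converse2d extension available:", ok)
```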
@@ -121,7 +94,7 @@ def sequential(*args): # -------------------------------------------- """ class Converse2D(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, scale=1, padding=2, padding_mode='circular', eps=1e-5, backend:str='torch'): + def __init__(self, in_channels, out_channels, kernel_size, scale=1, padding=2, padding_mode='circular', eps=1e-5, backend: str = "auto"): super(Converse2D, self).__init__() """ Converse2D Operator for Image Restoration Tasks. @@ -138,6 +111,8 @@ def __init__(self, in_channels, out_channels, kernel_size, scale=1, padding=2, p Default is `circular`. eps (float, optional): Small value added to denominators for numerical stability. Default is a small value like 1e-5. + backend (str, optional): Backend for computing the convolution. One of {'auto', 'cuda', 'pytorch'}. + Default is 'auto'. Returns: Tensor: Output tensor of shape (N, out_channels, H * scale, W * scale), where spatial dimensions @@ -151,8 +126,9 @@ def __init__(self, in_channels, out_channels, kernel_size, scale=1, padding=2, p self.padding = padding self.padding_mode = padding_mode self.eps = eps - assert backend in ['torch', 'cuda'], "Not Implementd Yet" #, 'cuda_fuse' - self.backend = backend + self.backend = backend.lower() + if self.backend not in ("auto", "cuda", "pytorch"): + raise ValueError(f"backend must be 'auto' | 'cuda' | 'pytorch', got: {self.backend}") # ensure depthwise assert self.out_channels == self.in_channels @@ -162,17 +138,37 @@ def __init__(self, in_channels, out_channels, kernel_size, scale=1, padding=2, p def forward(self, x): + if self.padding > 0: x = nn.functional.pad(x, pad=[self.padding, self.padding, self.padding, self.padding], mode=self.padding_mode, value=0) - if self.backend == 'torch': - self.biaseps = torch.sigmoid(self.bias-9.0) + self.eps - _, _, h, w = x.shape + self.biaseps = torch.sigmoid(self.bias-9.0) + self.eps + _, _, h, w = x.shape + + backend = (os.environ.get("CONVERSE2D_BACKEND", "") or self.backend).lower() + + def _can_use_cuda_backend(): + return (_HAS_CONVERSE2D_EXT and x.is_cuda) + + use_cuda_backend = False + if backend == "cuda": + if not _can_use_cuda_backend(): + raise RuntimeError("Converse2D backend='cuda' but CUDA extension is unavailable.") + use_cuda_backend = True + elif backend == "python": + use_cuda_backend = False + else: # "auto" + use_cuda_backend = _can_use_cuda_backend() + + if use_cuda_backend: + x0 = x if self.scale == 1 else F.interpolate(x, scale_factor=self.scale, mode='nearest') + out = converse2d_CUDA( + x, x0, self.weight, self.bias, int(self.scale), float(self.eps) + ) + else: STy = self.upsample(x, scale=self.scale) if self.scale != 1: x = nn.functional.interpolate(x, scale_factor=self.scale, mode='nearest') - # x = nn.functional.interpolate(x, scale_factor=self.scale, mode='bilinear',align_corners=False) - # x = torch.zeros_like(x) FB = self.p2o(self.weight, (h*self.scale, w*self.scale)) FBC = torch.conj(FB) @@ -187,10 +183,6 @@ def forward(self, x): FCBinvWBR = FBC*invWBR.repeat(1, 1, self.scale, self.scale) FX = (FR-FCBinvWBR)/self.biaseps out = torch.real(torch.fft.ifftn(FX, dim=(-2, -1))) - elif self.backend == 'cuda': - out = Converse2D_CUDA(x, self.weight, self.bias, self.scale, self.eps) - else: - raise NotImplementedError if self.padding > 0: out = out[..., self.padding*self.scale:-self.padding*self.scale, self.padding*self.scale:-self.padding*self.scale] @@ -257,7 +249,7 @@ def upsample(self, x, scale=3): # -------------------------------------------- """ class 
ConverseBlock(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=3, scale=1, padding=2, padding_mode='replicate', eps=1e-5, backend='torch'): + def __init__(self, in_channels, out_channels, kernel_size=3, scale=1, padding=2, padding_mode='replicate', eps=1e-5): super(ConverseBlock, self).__init__() """ ConverseBlock: A Convolutional Block for Image Restoration using Converse2D Operations. @@ -284,7 +276,7 @@ def __init__(self, in_channels, out_channels, kernel_size=3, scale=1, padding=2, self.conv1 = nn.Sequential(LayerNorm(in_channels, eps=1e-5, data_format="channels_first"), nn.Conv2d(in_channels, 2*out_channels, 1, 1, 0), nn.GELU(), - Converse2D(2*out_channels, 2*out_channels, kernel_size, scale=scale, padding=padding, padding_mode=padding_mode, eps=eps, backend=backend), + Converse2D(2*out_channels, 2*out_channels, kernel_size, scale=scale, padding=padding, padding_mode=padding_mode, eps=eps), nn.GELU(), nn.Conv2d(2*out_channels, out_channels, 1, 1, 0)) diff --git a/test/Results.md b/test/Results.md new file mode 100644 index 0000000..acf966a --- /dev/null +++ b/test/Results.md @@ -0,0 +1,36 @@ +# Environment +- **Device**: `cuda` +- **PyTorch version**: `2.4.0+cu121` +- **CUDA version**: `12.1` +- **cuDNN version**: `90100` + +--- + +# Test configuration +- **Data types**: `['torch.float32']` +- **AMP**: `False` +- **Modes**: `TRAIN=True`, `INFER=True` +- **Warm-up iterations**: `10` +- **Total iterations**: `50` + +--- + +# Per-case comparison (CUDA vs PyTorch) + +| dtype | B | C | H | W | s | k | fwd_py(ms) | fwd_cu(ms) | fwd_Gpix/s(py) | fwd_Gpix/s(cu) | fwd_speedup | bwd_py(ms) | bwd_cu(ms) | bwd_Gpix/s(py) | bwd_Gpix/s(cu) | bwd_speedup | +| :---: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :---: | :-: | :-: | :-: | :-: | :---: | +| float32 | 1 | 3 | 128 | 128 | 2 | 5 | 4.62 | 0.60 | 0.014 | 0.110 | 7.73× | 17.57 | 2.29 | 0.004 | 0.029 | 7.68× | +| float32 | 2 | 3 | 256 | 256 | 2 | 5 | 10.21 | 0.61 | 0.051 | 0.863 | 16.81× | 19.87 | 2.73 | 0.026 | 0.192 | 7.27× | +| float32 | 4 | 3 | 256 | 256 | 2 | 5 | 15.71 | 0.71 | 0.067 | 1.476 | 22.12× | 24.21 | 3.58 | 0.043 | 0.293 | 6.77× | +| float32 | 2 | 8 | 256 | 256 | 2 | 5 | 22.28 | 1.16 | 0.024 | 0.454 | 19.27× | 31.86 | 4.14 | 0.016 | 0.127 | 7.70× | +| float32 | 1 | 3 | 512 | 512 | 2 | 5 | 20.03 | 1.30 | 0.052 | 0.806 | 15.41× | 32.28 | 3.90 | 0.032 | 0.269 | 8.27× | + +--- + +### Overall performance +- **Overall forward geometric-mean speedup**: **15.35×** (CUDA vs PyTorch) +- **Overall training geometric-mean speedup**: **7.52×** (CUDA vs PyTorch) + +### Accuracy +- forward: max|Δ|=0.000e+00 max rel=0.000e+00 +- backward: max|Δ|=0.000e+00 max rel=0.000e+00 \ No newline at end of file diff --git a/test/test_error.py b/test/test_error.py new file mode 100644 index 0000000..6a01fbe --- /dev/null +++ b/test/test_error.py @@ -0,0 +1,66 @@ +import os +import torch +from models.util_converse import Converse2D + +torch.manual_seed(0) + +device = "cuda" if torch.cuda.is_available() else "cpu" +dtype = torch.float32 +B, C, H, W, scale = 2, 3, 32, 40, 2 + +x = torch.randn(B, C, H, W, device=device, dtype=dtype, requires_grad=True) + +m = Converse2D( + in_channels=C, out_channels=C, kernel_size=5, scale=scale, + padding=2, padding_mode="circular", eps=1e-5, backend="pytorch" +).to(device=device, dtype=dtype) +m.eval() + +x_py = x.detach().clone().requires_grad_(True) +m.backend = "python" +y_py = m(x_py) +loss_py = y_py.square().mean() +g_py = torch.autograd.grad(loss_py, x_py)[0].detach() + + +have_cuda = False +try: + if device == "cuda": + m.backend = "cuda" + x_cuda = x.detach().clone().requires_grad_(True) + y_cuda = m(x_cuda) + loss_cuda =
y_cuda.square().mean() + g_cuda = torch.autograd.grad(loss_cuda, x_cuda)[0].detach() + have_cuda = True + print("[INFO] CUDA backend: OK") + else: + print("[WARN] CUDA not available on this device.") +except Exception as e: + print("[WARN] CUDA backend unavailable ->", repr(e)) + +print("[INFO] Python backend: OK") + +if have_cuda: + with torch.no_grad(): + out_abs = (y_cuda - y_py).abs() + grad_abs = (g_cuda - g_py).abs() + out_mae = out_abs.max().item() + grad_mae = grad_abs.max().item() + out_rel = (out_abs / (y_py.abs() + 1e-8)).max().item() + grad_rel = (grad_abs / (g_py.abs() + 1e-8)).max().item() + + print(f"forward: max|Δ|={out_mae:.3e} max rel={out_rel:.3e}") + print(f"backward: max|Δ|={grad_mae:.3e} max rel={grad_rel:.3e}") + +try: + torch.manual_seed(0) + B2, C2, H2, W2, s2 = 1, 2, 8, 9, 2 + x64 = torch.randn(B2, C2, H2, W2, device=device, dtype=torch.float64, requires_grad=True) + m64 = Converse2D(C2, C2, kernel_size=5, scale=s2, padding=2, + padding_mode="circular", eps=1e-5, backend="auto").to(device=device, dtype=torch.float64) + m64.eval() + def f(inp): return m64(inp) + torch.autograd.gradcheck(f, (x64,), eps=1e-6, atol=1e-4, rtol=1e-4) + print("[INFO] Gradcheck (float64) passed.") +except Exception as e: + print("[WARN] Gradcheck skipped/failed ->", repr(e)) diff --git a/test/test_speed.py b/test/test_speed.py new file mode 100644 index 0000000..5d10efb --- /dev/null +++ b/test/test_speed.py @@ -0,0 +1,177 @@ +import os +import time +import math +import torch +from util_converse import Converse2D + +# -------------------------- +# Config +# -------------------------- +device = "cuda" if torch.cuda.is_available() else "cpu" +DTYPES = [torch.float32] +USE_AUTOCast = False +WARMUP = 10 +ITERS = 50 +TRAIN = True +INFER = True +CASES = [ + # (B, C, H, W, scale, ksize, padding, padding_mode) + (1, 3, 128, 128, 2, 5, 2, "circular"), + (2, 3, 256, 256, 2, 5, 2, "circular"), + (4, 3, 256, 256, 2, 5, 2, "circular"), + (2, 8, 256, 256, 2, 5, 2, "circular"), + (1, 3, 512, 512, 2, 5, 2, "circular"), # 320 +] + +def synchronize(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + +def timed_run(fn, warmup=WARMUP, iters=ITERS): + for _ in range(warmup): + fn() + synchronize() + t0 = time.perf_counter() + for _ in range(iters): + fn() + synchronize() + t1 = time.perf_counter() + return (t1 - t0) / iters + +def make_model(C, scale, ksize=5, padding=2, padding_mode="circular", dtype=torch.float32): + m = Converse2D( + in_channels=C, out_channels=C, kernel_size=ksize, + scale=scale, padding=padding, padding_mode=padding_mode, + eps=1e-5, backend="pytorch" + ).to(device=device, dtype=dtype) + m.eval() + return m + +def clone_as_cuda_backend(m): + m2 = Converse2D( + in_channels=m.in_channels, out_channels=m.out_channels, + kernel_size=m.kernel_size, scale=m.scale, padding=m.padding, + padding_mode=m.padding_mode, eps=m.eps, backend="cuda" + ).to(device=device, dtype=next(m.parameters()).dtype) + m2.load_state_dict(m.state_dict()) + m2.eval() + return m2 + +def tp_gpix_per_s(B,H,W,s,t): + if t is None or t <= 0: return None + return (B * (H*s) * (W*s) / t) / 1e9 + +def speedup_and_pct(t_py, t_cu): + if t_py and t_cu and t_py > 0 and t_cu > 0: + sp = t_py / t_cu + pct = (t_py - t_cu) / t_py * 100.0 + return sp, pct + return None, None + +def fmt_ms(t): return "-" if t is None else f"{t*1e3:7.2f}" +def fmt_tp(x): return "-" if x is None else f"{x:6.3f}" +def fmt_sp(x): return "-" if x is None else f"{x:5.2f}×" +def fmt_pct(p): return "-" if p is None else f"{p:6.1f}%" + +def 
geom_mean(vals): + vals = [v for v in vals if v and v > 0] + if not vals: return None + return math.exp(sum(math.log(v) for v in vals) / len(vals)) + +def run_case(B,C,H,W,scale,ksize,padding,padding_mode,dtype): + x = torch.randn(B, C, H, W, device=device, dtype=dtype, requires_grad=TRAIN) + + torch.manual_seed(0); m_py = make_model(C, scale, ksize, padding, padding_mode, dtype) + torch.manual_seed(0); m_cu = clone_as_cuda_backend(m_py) + + fwd_py = fwd_cu = None + if INFER: + def fwd_run(m): + def _call(): + with torch.no_grad(): + if USE_AUTOCast and dtype is torch.bfloat16 and device == "cuda": + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + _ = m(x) + else: + _ = m(x) + return _call + fwd_py = timed_run(fwd_run(m_py)) + fwd_cu = timed_run(fwd_run(m_cu)) + + bwd_py = bwd_cu = None + if TRAIN: + def train_run(m): + def _call(): + x_local = x.detach().clone().requires_grad_(True) + if USE_AUTOCast and dtype is torch.bfloat16 and device == "cuda": + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + y = m(x_local); loss = y.square().mean() + else: + y = m(x_local); loss = y.square().mean() + loss.backward() + return _call + bwd_py = timed_run(train_run(m_py)) + bwd_cu = timed_run(train_run(m_cu)) + + fwd_tp_py = tp_gpix_per_s(B,H,W,scale,fwd_py) + fwd_tp_cu = tp_gpix_per_s(B,H,W,scale,fwd_cu) + bwd_tp_py = tp_gpix_per_s(B,H,W,scale,bwd_py) + bwd_tp_cu = tp_gpix_per_s(B,H,W,scale,bwd_cu) + + fwd_sp, fwd_pct = speedup_and_pct(fwd_py, fwd_cu) + bwd_sp, bwd_pct = speedup_and_pct(bwd_py, bwd_cu) + + return { + "shape": (B,C,H,W,scale,ksize,padding,padding_mode,str(dtype).split('.')[-1]), + "fwd_py": fwd_py, "fwd_cu": fwd_cu, "fwd_tp_py": fwd_tp_py, "fwd_tp_cu": fwd_tp_cu, + "bwd_py": bwd_py, "bwd_cu": bwd_cu, "bwd_tp_py": bwd_tp_py, "bwd_tp_cu": bwd_tp_cu, + "fwd_sp": fwd_sp, "fwd_pct": fwd_pct, "bwd_sp": bwd_sp, "bwd_pct": bwd_pct + } + +def main(): + print(f"[Env] device={device}, torch={torch.__version__}, cuda={torch.version.cuda}, cudnn={torch.backends.cudnn.version()}") + print(f"[Cfg] dtypes={[d.__str__() for d in DTYPES]}, AMP={USE_AUTOCast}, TRAIN={TRAIN}, INFER={INFER}, warmup={WARMUP}, iters={ITERS}\n") + + rows = [] + for dtype in DTYPES: + for (B,C,H,W,s,ks,pd,pm) in CASES: + rows.append(run_case(B,C,H,W,s,ks,pd,pm,dtype)) + + print("=== Per‑case Comparison (CUDA vs PyTorch) ===") + for r in rows: + B,C,H,W,s,ks,pd,pm,dtype = r["shape"] + tag = f"[{dtype}] B{B} C{C} {H}x{W} s{s} k{ks}" + # forward + if INFER: + print(f"{tag} | Forward : Py {fmt_ms(r['fwd_py'])} ms ({fmt_tp(r['fwd_tp_py'])} Gpix/s) " + f"vs CUDA {fmt_ms(r['fwd_cu'])} ms ({fmt_tp(r['fwd_tp_cu'])} Gpix/s) " + f"-> CUDA is {fmt_sp(r['fwd_sp'])} faster ({fmt_pct(r['fwd_pct'])} time saved)") + # backward + if TRAIN: + print(f"{tag} | Train : Py {fmt_ms(r['bwd_py'])} ms ({fmt_tp(r['bwd_tp_py'])} Gpix/s) " + f"vs CUDA {fmt_ms(r['bwd_cu'])} ms ({fmt_tp(r['bwd_tp_cu'])} Gpix/s) " + f"-> CUDA is {fmt_sp(r['bwd_sp'])} faster ({fmt_pct(r['bwd_pct'])} time saved)") + print("") + + hdr = ("dtype B C H W s k | fwd_py(ms) fwd_cu(ms) fwd_Gpix/s(py) fwd_Gpix/s(cu) fwd_speedup " + "| bwd_py(ms) bwd_cu(ms) bwd_Gpix/s(py) bwd_Gpix/s(cu) bwd_speedup") + print(hdr); print("-"*len(hdr)) + for r in rows: + B,C,H,W,s,ks,pd,pm,dtype = r["shape"] + line = (f"{dtype:6s} {B:3d} {C:3d} {H:5d} {W:5d} {s:2d} {ks:3d} | " + f"{fmt_ms(r['fwd_py'])} {fmt_ms(r['fwd_cu'])} {fmt_tp(r['fwd_tp_py'])} {fmt_tp(r['fwd_tp_cu'])} {fmt_sp(r['fwd_sp'])} | " + f"{fmt_ms(r['bwd_py'])} {fmt_ms(r['bwd_cu'])} {fmt_tp(r['bwd_tp_py'])} 
{fmt_tp(r['bwd_tp_cu'])} {fmt_sp(r['bwd_sp'])}") + print(line) + + fwd_sps = [r["fwd_sp"] for r in rows if r["fwd_sp"]] + bwd_sps = [r["bwd_sp"] for r in rows if r["bwd_sp"]] + gm_fwd = geom_mean(fwd_sps) + gm_bwd = geom_mean(bwd_sps) + if gm_fwd: + print(f"\nOverall Forward Geomean Speedup : {gm_fwd:.2f}× (CUDA vs PyTorch)") + if gm_bwd: + print(f"Overall Train Geomean Speedup : {gm_bwd:.2f}× (CUDA vs PyTorch)") + +if __name__ == "__main__": + torch.set_grad_enabled(True) + main() From 177131d52a6fb30699f003e38734d1f112eb08b2 Mon Sep 17 00:00:00 2001 From: BoyceYi <1473416941@qq.com> Date: Fri, 29 Aug 2025 19:51:41 +0800 Subject: [PATCH 04/22] Update installation method --- .gitignore | 1 + README.md | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3aa864b --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +TODO.md diff --git a/README.md b/README.md index 4d1a03a..ca9da18 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ Kernel Registry ```python cd ./Converse2D -pip install -e. +pip install --no-build-isolation -e. ``` **Usage** From 34311b20fb155efcc07dad4bb8f1a6633cf074e9 Mon Sep 17 00:00:00 2001 From: BoyceYi <1473416941@qq.com> Date: Sat, 30 Aug 2025 18:33:54 +0800 Subject: [PATCH 05/22] add cache --- .gitignore | 1 + .../{converse2d_ext.cpp => converse2d_v1.cpp} | 0 Converse2D/torch_converse2d/converse2d_v2.cpp | 225 ++++++++++++++ Converse2D/torch_converse2d/converse2d_v3.cpp | 284 ++++++++++++++++++ test/Results.md | 22 +- test/test_cache.py | 136 +++++++++ test/test_error.py | 1 - test/test_speed.py | 1 - 8 files changed, 659 insertions(+), 11 deletions(-) rename Converse2D/torch_converse2d/{converse2d_ext.cpp => converse2d_v1.cpp} (100%) create mode 100644 Converse2D/torch_converse2d/converse2d_v2.cpp create mode 100644 Converse2D/torch_converse2d/converse2d_v3.cpp create mode 100644 test/test_cache.py diff --git a/.gitignore b/.gitignore index 3aa864b..d41279a 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ TODO.md +Optimization.md \ No newline at end of file diff --git a/Converse2D/torch_converse2d/converse2d_ext.cpp b/Converse2D/torch_converse2d/converse2d_v1.cpp similarity index 100% rename from Converse2D/torch_converse2d/converse2d_ext.cpp rename to Converse2D/torch_converse2d/converse2d_v1.cpp diff --git a/Converse2D/torch_converse2d/converse2d_v2.cpp b/Converse2D/torch_converse2d/converse2d_v2.cpp new file mode 100644 index 0000000..23b8a36 --- /dev/null +++ b/Converse2D/torch_converse2d/converse2d_v2.cpp @@ -0,0 +1,225 @@ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +using at::Tensor; + +struct FBKey +{ + int64_t device_id; + at::ScalarType dtype; + int64_t channels; + int64_t H, W; + void *ptr; + + bool operator==(const FBKey &other) const + { + return device_id == other.device_id && dtype == other.dtype && + channels == other.channels && H == other.H && W == other.W && + ptr == other.ptr; + } +}; + +namespace std +{ + template <> + struct hash + { + size_t operator()(const FBKey &k) const + { + return ((hash()(k.device_id) ^ hash()(k.channels)) << 1) ^ + ((hash()(k.H) ^ hash()(k.W)) << 1) ^ + ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); + } + }; +} + +constexpr size_t FB_CACHE_MAX_SIZE = 64; + +static std::unordered_map> fb_cache; +static std::list fb_cache_lru; +static std::mutex fb_cache_mutex; + +static inline std::pair p2o_cached(const 
Tensor &psf, int64_t H, int64_t W) +{ + auto C = psf.size(1); + FBKey key{ + psf.device().index(), + psf.scalar_type(), + C, H, W, + psf.data_ptr()}; + + { + std::lock_guard lock(fb_cache_mutex); + auto it = fb_cache.find(key); + if (it != fb_cache.end()) + { + fb_cache_lru.remove(key); + fb_cache_lru.push_front(key); + return it->second; + } + } + + Tensor otf = at::zeros({1, C, H, W}, psf.options()); + int64_t kh = psf.size(2), kw = psf.size(3); + otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); + otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1}); + Tensor FB = at::fft_fftn(otf, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor F2B = at::abs(FB).pow(2); + + { + std::lock_guard lock(fb_cache_mutex); + fb_cache[key] = {FB, F2B}; + fb_cache_lru.push_front(key); + + if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) + { + fb_cache.erase(fb_cache_lru.back()); + fb_cache_lru.pop_back(); + } + } + + return {FB, F2B}; +} + +static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) +{ + TORCH_CHECK(s >= 1, "scale must be >= 1"); + if (s == 1) + return x; + auto sizes = x.sizes().vec(); + sizes[sizes.size() - 2] *= s; + sizes[sizes.size() - 1] *= s; + Tensor z = at::zeros(sizes, x.options()); + z.index_put_( + {at::indexing::Slice(), at::indexing::Slice(), + at::indexing::Slice(0, z.size(-2), s), + at::indexing::Slice(0, z.size(-1), s)}, + x); + return z; +} + +static inline Tensor splits_mean_then_mean(const Tensor &a, int64_t s) +{ + TORCH_CHECK(a.dim() >= 2, "tensor must have spatial dims"); + TORCH_CHECK(a.size(-2) % s == 0 && a.size(-1) % s == 0, "spatial not divisible by scale"); + + const auto &sizes = a.sizes(); + const int64_t L = a.dim(); + const int64_t W = sizes[L - 2]; + const int64_t H = sizes[L - 1]; + const int64_t W_s = W / s; + const int64_t H_s = H / s; + + std::vector view_shape; + view_shape.reserve(L + 2); + for (int64_t i = 0; i < L - 2; ++i) + view_shape.push_back(sizes[i]); + view_shape.push_back(s); + view_shape.push_back(W_s); + view_shape.push_back(s); + view_shape.push_back(H_s); + Tensor v = a.view(view_shape); + + std::vector perm; + perm.reserve(view_shape.size()); + for (int64_t i = 0; i < L - 2; ++i) + perm.push_back(i); + perm.push_back(L - 2 + 1); + perm.push_back(L - 2 + 3); + perm.push_back(L - 2 + 0); + perm.push_back(L - 2 + 2); + Tensor p = v.permute(perm).contiguous(); + + std::vector merge_shape; + merge_shape.reserve(L + 1); + for (int64_t i = 0; i < L - 2; ++i) + merge_shape.push_back(p.size(i)); + merge_shape.push_back(W_s); + merge_shape.push_back(H_s); + merge_shape.push_back(s * s); + Tensor r = p.view(merge_shape); + + return r.mean(-1, /*keepdim=*/false); +} + +Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) +{ + TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); + TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); + TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); + TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); + TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); + TORCH_CHECK(scale >= 1, "scale must be >= 1"); + + x = x.contiguous(); + x0 = x0.contiguous(); + weight = weight.contiguous(); + bias = bias.contiguous(); + + const int64_t B = x.size(0); + const int64_t C = x.size(1); + const int64_t H = x.size(2); + const int64_t W = 
x.size(3); + const int64_t Hs = H * scale; + const int64_t Ws = W * scale; + + Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; + Tensor STy = sfold_upsample_zero_insertion(x, scale); + + auto [FB, F2B] = p2o_cached(weight, Hs, Ws); + Tensor FBC = at::conj_physical(FB); + + Tensor F_STy = at::fft_fftn(STy, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor FBFy = FBC * F_STy; + Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt); + + Tensor x1 = FB * FR; + Tensor FBR = splits_mean_then_mean(x1, scale); + Tensor invW = splits_mean_then_mean(F2B, scale); + + Tensor invW_plus = invW + lambda_; + Tensor invWBR = FBR / invW_plus; + + Tensor invWBR_rep = invWBR.repeat({1, 1, scale, scale}); + Tensor FCBinvWBR = FBC * invWBR_rep; + + Tensor FX = (FR - FCBinvWBR) / lambda_; + Tensor out_c = at::fft_ifftn(FX, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor out = at::real(out_c); + return out; +} + +void clear_fb_cache() +{ + std::lock_guard lock(fb_cache_mutex); + fb_cache.clear(); + fb_cache_lru.clear(); +} + +TORCH_LIBRARY(converse2d, m) +{ + m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); + m.def("clear_cache() -> ()"); +} +TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) +{ + m.impl("forward", TORCH_FN(converse2d_forward)); + m.impl("clear_cache", TORCH_FN(clear_fb_cache)); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/Converse2D/torch_converse2d/converse2d_v3.cpp b/Converse2D/torch_converse2d/converse2d_v3.cpp new file mode 100644 index 0000000..a7ac3cc --- /dev/null +++ b/Converse2D/torch_converse2d/converse2d_v3.cpp @@ -0,0 +1,284 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using at::Tensor; + +// ---------- FB Cache ---------- +struct FBKey +{ + int64_t device_id; + at::ScalarType dtype; + int64_t channels; + int64_t H, W; + void *ptr; + + bool operator==(const FBKey &other) const + { + return device_id == other.device_id && dtype == other.dtype && + channels == other.channels && H == other.H && W == other.W && + ptr == other.ptr; + } +}; + +namespace std +{ + template <> + struct hash + { + size_t operator()(const FBKey &k) const + { + return ((hash()(k.device_id) ^ hash()(k.channels)) << 1) ^ + ((hash()(k.H) ^ hash()(k.W)) << 1) ^ + ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); + } + }; +} // namespace std + +constexpr size_t FB_CACHE_MAX_SIZE = 64; +static std::unordered_map> fb_cache; +static std::list fb_cache_lru; +static std::mutex fb_cache_mutex; + +__global__ void block_mean_kernel( + const float *__restrict__ input, + float *__restrict__ output, + int B, int C, int H, int W, int s) +{ + int b = blockIdx.z; + int c = blockIdx.y; + int h = blockIdx.x / W; + int w = blockIdx.x % W; + + if (h >= H || w >= W) + return; + + int Hs = H * s; + int Ws = W * s; + + float sum = 0.0f; + for (int i = 0; i < s; ++i) + { + for (int j = 0; j < s; ++j) + { + int hs = h * s + i; + int ws = w * s + j; + int idx = ((b * C + c) * Hs + hs) * Ws + ws; + sum += input[idx]; + } + } + + int out_idx = ((b * C + c) * H + h) * W + w; + output[out_idx] = sum / (s * s); +} + +Tensor block_mean_cuda(const Tensor &input, int64_t s) +{ + TORCH_CHECK(input.dim() == 4, "input must be (B,C,Hs,Ws)"); + TORCH_CHECK(input.device().is_cuda(), "input must be CUDA tensor"); + + const int64_t B = input.size(0); + const int64_t C = input.size(1); + const int64_t Hs = 
input.size(2); + const int64_t Ws = input.size(3); + TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "Hs and Ws must be divisible by s"); + + const int64_t H = Hs / s; + const int64_t W = Ws / s; + + Tensor output = at::empty({B, C, H, W}, input.options()); + + const int threads = 1; + const dim3 blocks(H * W, C, B); + + block_mean_kernel<<>>( + input.data_ptr(), + output.data_ptr(), + B, C, H, W, s); + + return output; +} + +static inline std::pair p2o_cached(const Tensor &psf, int64_t H, int64_t W) +{ + auto C = psf.size(1); + FBKey key{psf.device().index(), psf.scalar_type(), C, H, W, psf.data_ptr()}; + + { + std::lock_guard lock(fb_cache_mutex); + auto it = fb_cache.find(key); + if (it != fb_cache.end()) + { + fb_cache_lru.remove(key); + fb_cache_lru.push_front(key); + return it->second; + } + } + + Tensor otf = at::zeros({1, C, H, W}, psf.options()); + int64_t kh = psf.size(2), kw = psf.size(3); + otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); + otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1}); + Tensor FB = at::fft_fftn(otf, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor F2B = at::abs(FB).pow(2); + + { + std::lock_guard lock(fb_cache_mutex); + fb_cache[key] = {FB, F2B}; + fb_cache_lru.push_front(key); + + if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) + { + fb_cache.erase(fb_cache_lru.back()); + fb_cache_lru.pop_back(); + } + } + + return {FB, F2B}; +} + +// ---------- Utility ---------- +static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) +{ + if (s == 1) + return x; + auto sizes = x.sizes().vec(); + sizes[sizes.size() - 2] *= s; + sizes[sizes.size() - 1] *= s; + Tensor z = at::zeros(sizes, x.options()); + z.index_put_({at::indexing::Slice(), at::indexing::Slice(), + at::indexing::Slice(0, z.size(-2), s), + at::indexing::Slice(0, z.size(-1), s)}, + x); + return z; +} + +static inline Tensor splits_mean_then_mean(const Tensor &a, int64_t s) +{ + const auto &sizes = a.sizes(); + const int64_t L = a.dim(); + const int64_t W = sizes[L - 2]; + const int64_t H = sizes[L - 1]; + const int64_t W_s = W / s; + const int64_t H_s = H / s; + + std::vector view_shape; + view_shape.reserve(L + 2); + for (int64_t i = 0; i < L - 2; ++i) + view_shape.push_back(sizes[i]); + view_shape.push_back(s); + view_shape.push_back(W_s); + view_shape.push_back(s); + view_shape.push_back(H_s); + Tensor v = a.view(view_shape); + + std::vector perm; + perm.reserve(view_shape.size()); + for (int64_t i = 0; i < L - 2; ++i) + perm.push_back(i); + perm.push_back(L - 2 + 1); + perm.push_back(L - 2 + 3); + perm.push_back(L - 2 + 0); + perm.push_back(L - 2 + 2); + Tensor p = v.permute(perm).contiguous(); + + std::vector merge_shape; + merge_shape.reserve(L + 1); + for (int64_t i = 0; i < L - 2; ++i) + merge_shape.push_back(p.size(i)); + merge_shape.push_back(W_s); + merge_shape.push_back(H_s); + merge_shape.push_back(s * s); + Tensor r = p.view(merge_shape); + + return r.mean(-1, /*keepdim=*/false); +} + +// ---------- Forward ---------- +Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) +{ + TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); + TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); + TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); + TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); + TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == 
bias.device(), "tensors on same device"); + + x = x.contiguous(); + x0 = x0.contiguous(); + weight = weight.contiguous(); + bias = bias.contiguous(); + + const int64_t B = x.size(0); + const int64_t C = x.size(1); + const int64_t H = x.size(2); + const int64_t W = x.size(3); + const int64_t Hs = H * scale; + const int64_t Ws = W * scale; + + Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; + Tensor STy = sfold_upsample_zero_insertion(x, scale); + + auto [FB, F2B] = p2o_cached(weight, Hs, Ws); + Tensor FBC = at::conj_physical(FB); + + Tensor F_STy = at::fft_fftn(STy, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor FBFy = FBC * F_STy; + Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt); + + Tensor x1 = FB * FR; + // Tensor FBR = splits_mean_then_mean(x1, scale); + // Tensor invW = splits_mean_then_mean(F2B, scale); + Tensor FBR = block_mean_cuda(x1, scale); + Tensor invW = block_mean_cuda(F2B, scale); + + Tensor invW_plus = invW + lambda_; + Tensor invWBR = FBR / invW_plus; + + // ---- Key change: avoid materializing repeat(); use a zero-stride broadcast expansion ---- + // Shape: (B,C,H,1,W,1) -> expand to (B,C,H,scale,W,scale) -> reshape to (B,C,Hs,Ws) + Tensor invWBR_exp = invWBR.view({B, C, H, 1, W, 1}) + .expand({B, C, H, scale, W, scale}) + .reshape({B, C, Hs, Ws}); + Tensor FCBinvWBR = FBC * invWBR_exp; + + Tensor FX = (FR - FCBinvWBR) / lambda_; + Tensor out_c = at::fft_ifftn(FX, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor out = at::real(out_c); + return out; +} + +// ---------- Clear Cache ---------- +void clear_fb_cache() +{ + std::lock_guard lock(fb_cache_mutex); + fb_cache.clear(); + fb_cache_lru.clear(); +} + +// ---------- Registration ---------- +TORCH_LIBRARY(converse2d, m) +{ + m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); + m.def("clear_cache() -> ()"); +} +TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) +{ + m.impl("forward", TORCH_FN(converse2d_forward)); + m.impl("clear_cache", TORCH_FN(clear_fb_cache)); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} +s \ No newline at end of file diff --git a/test/Results.md b/test/Results.md index acf966a..2da77aa 100644 --- a/test/Results.md +++ b/test/Results.md @@ -1,5 +1,6 @@ # Environment -- **Device**: `cuda` + +- **Device**: 4090 - **PyTorch version**: `2.4.0+cu121` - **CUDA version**: `12.1` - **cuDNN version**: `90100` @@ -7,6 +8,7 @@ --- # Test configuration + - **Data types**: `['torch.float32']` - **AMP**: `False` - **Modes**: `TRAIN=True`, `INFER=True` @@ -17,20 +19,22 @@ # Per-case comparison (CUDA vs PyTorch) -| dtype | B | C | H | W | s | k | fwd_py(ms) | fwd_cu(ms) | fwd_Gpix/s(py) | fwd_Gpix/s(cu) | fwd_speedup | bwd_py(ms) | bwd_cu(ms) | bwd_Gpix/s(py) | bwd_Gpix/s(cu) | bwd_speedup | -| :---: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :---: | :-: | :-: | :-: | :-: | :---: | -| float32 | 1 | 3 | 128 | 128 | 2 | 5 | 4.62 | 0.60 | 0.014 | 0.110 | 7.73× | 17.57 | 2.29 | 0.004 | 0.029 | 7.68× | -| float32 | 2 | 3 | 256 | 256 | 2 | 5 | 10.21 | 0.61 | 0.051 | 0.863 | 16.81× | 19.87 | 2.73 | 0.026 | 0.192 | 7.27× | -| float32 | 4 | 3 | 256 | 256 | 2 | 5 | 15.71 | 0.71 | 0.067 | 1.476 | 22.12× | 24.21 | 3.58 | 0.043 | 0.293 | 6.77× | -| float32 | 2 | 8 | 256 | 256 | 2 | 5 | 22.28 | 1.16 | 0.024 | 0.454 | 19.27× | 31.86 | 4.14 | 0.016 | 0.127 | 7.70× | -| float32 | 1 | 3 | 512 | 512 | 2 | 5 | 20.03 | 1.30 | 0.052 | 0.806 | 15.41× | 32.28 | 3.90 | 0.032 | 0.269 | 8.27× | +| dtype | B | C | H | W | s | k | fwd_py(ms) | fwd_cu(ms) | fwd_Gpix/s(py) | fwd_Gpix/s(cu) | fwd_speedup | bwd_py(ms) | bwd_cu(ms) | bwd_Gpix/s(py) |
bwd_Gpix/s(cu) | bwd_speedup | +| :-----: | :-: | :-: | :-: | :-: | :-: | :-: | :--------: | :--------: | :------------: | :------------: | :---------: | :--------: | :--------: | :------------: | :------------: | :---------: | +| float32 | 1 | 3 | 128 | 128 | 2 | 5 | 4.62 | 0.60 | 0.014 | 0.110 | 7.73× | 17.57 | 2.29 | 0.004 | 0.029 | 7.68× | +| float32 | 2 | 3 | 256 | 256 | 2 | 5 | 10.21 | 0.61 | 0.051 | 0.863 | 16.81× | 19.87 | 2.73 | 0.026 | 0.192 | 7.27× | +| float32 | 4 | 3 | 256 | 256 | 2 | 5 | 15.71 | 0.71 | 0.067 | 1.476 | 22.12× | 24.21 | 3.58 | 0.043 | 0.293 | 6.77× | +| float32 | 2 | 8 | 256 | 256 | 2 | 5 | 22.28 | 1.16 | 0.024 | 0.454 | 19.27× | 31.86 | 4.14 | 0.016 | 0.127 | 7.70× | +| float32 | 1 | 3 | 512 | 512 | 2 | 5 | 20.03 | 1.30 | 0.052 | 0.806 | 15.41× | 32.28 | 3.90 | 0.032 | 0.269 | 8.27× | --- ### Overall performance + - **Overall forward geometric-mean speedup**: **15.35×** (CUDA vs PyTorch) - **Overall training geometric-mean speedup**: **7.52×** (CUDA vs PyTorch) ### Accuracy + - forward: max|Δ|=0.000e+00 max rel=0.000e+00 -- backward: max|Δ|=0.000e+00 max rel=0.000e+00 \ No newline at end of file +- backward: max|Δ|=0.000e+00 max rel=0.000e+00 diff --git a/test/test_cache.py b/test/test_cache.py new file mode 100644 index 0000000..9ab894b --- /dev/null +++ b/test/test_cache.py @@ -0,0 +1,136 @@ + +import os, sys, math, subprocess, json +import torch + +# Paths to the two C++ sources (no-cache vs cache) +ROOT = os.path.dirname(os.path.abspath(__file__)) +SRC_NOCACHE = os.path.join(ROOT, "models/backend/converse2d_v1.cpp") +SRC_CACHE = os.path.join(ROOT, "models/backend/converse2d_v3.cpp") + +# Benchmark config +CASES = [ + # (B, C, H, W, scale, ksize) + (1, 3, 128, 128, 2, 5), + (2, 3, 256, 256, 2, 5), + (4, 3, 256, 256, 2, 5), + (2, 8, 256, 256, 2, 5), + (1, 3, 512, 512, 2, 5), +] +WARMUP = 10 +ITERS = 50 +DTYPE = "float32" +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + +def run_single_bench(src_path, tag): + # compile & run timings in a clean subprocess to avoid op name collisions + child = f''' +import os, time, torch, json +from torch.utils.cpp_extension import load + +torch.manual_seed(0) +device = "{DEVICE}" +dtype = torch.{DTYPE} + +ext = load( + name="converse2d_ext", + sources=[r\"\"\"{src_path}\"\"\"], + verbose=False, + extra_cflags=["-O3"], + extra_cuda_cflags=["-O3","-gencode","arch=compute_89,code=sm_89"], +) + +op = torch.ops.converse2d.forward +clear = getattr(torch.ops.converse2d, "clear_cache", None) + +def synchronize(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + +def timed(fn, warmup={WARMUP}, iters={ITERS}): + for _ in range(warmup): + fn() + synchronize() + t0 = time.perf_counter() + for _ in range(iters): + fn() + synchronize() + t1 = time.perf_counter() + return (t1 - t0) / iters + +def bench_case(B,C,H,W,scale,ksize): + x = torch.randn(B,C,H,W, device=device, dtype=dtype, requires_grad=True) + x0 = x if scale == 1 else torch.nn.functional.interpolate(x, scale_factor=scale, mode="nearest") + weight = torch.randn(1,C,ksize,ksize, device=device, dtype=dtype, requires_grad=False) + weight = torch.nn.functional.softmax(weight.view(1,C,-1), dim=-1).view_as(weight) + bias = torch.zeros(1,C,1,1, device=device, dtype=dtype, requires_grad=False) + + if clear is not None: + clear() + + def fwd(): + with torch.no_grad(): + _ = op(x, x0, weight, bias, int(scale), float(1e-5)) + + def train(): + x_local = x.detach().clone().requires_grad_(True) + y = op(x_local, x0, weight, bias, int(scale), float(1e-5)) + loss = y.square().mean() + loss.backward() + + t_f = timed(fwd) + t_b =
timed(train) + t_f_hot = timed(fwd) # hot cache + t_b_hot = timed(train) # hot cache + + def tp(B,H,W,s,t): return (B*H*s*W*s / t) / 1e9 + + return dict( + shape=(B,C,H,W,scale,ksize), + fwd_ms=t_f*1e3, bwd_ms=t_b*1e3, fwd_hot_ms=t_f_hot*1e3, bwd_hot_ms=t_b_hot*1e3, + fwd_tp=tp(B,H,W,scale,t_f), bwd_tp=tp(B,H,W,scale,t_b), + fwd_tp_hot=tp(B,H,W,scale,t_f_hot), bwd_tp_hot=tp(B,H,W,scale,t_b_hot), + ) + +rows = [] +for (B,C,H,W,s,k) in {CASES}: + rows.append(bench_case(B,C,H,W,s,k)) + +print(json.dumps(dict(tag="{tag}", rows=rows), indent=2)) +''' + out = subprocess.check_output([sys.executable, "-c", child], text=True) + return json.loads(out) + +def main(): + if DEVICE != "cuda": + print("[WARN] CUDA device not available; this script is intended for RTX 4090 tests.") + res_nc = run_single_bench(SRC_NOCACHE, "nocache") + res_cc = run_single_bench(SRC_CACHE, "cache") + + def fmt_ms(x): return f"{x:7.2f}" + print("=== Converse2D CUDA Backend: Cache vs No-Cache ===") + print(f"[Env] torch={torch.__version__}, cuda={torch.version.cuda}, device={DEVICE}") + print("case | no‑cache fwd cache fwd | no‑cache bwd cache bwd || fwd speedup bwd speedup") + print("-"*110) + for r_nc, r_c in zip(res_nc["rows"], res_cc["rows"]): + B,C,H,W,s,k = r_nc["shape"] + tag = f"B{B} C{C} {H}x{W} s{s} k{k}" + tf0, tb0 = r_nc["fwd_hot_ms"], r_nc["bwd_hot_ms"] + tf1, tb1 = r_c["fwd_hot_ms"], r_c["bwd_hot_ms"] + sp_f = tf0 / tf1 if tf1 > 0 else float('nan') + sp_b = tb0 / tb1 if tb1 > 0 else float('nan') + print(f"{tag:24s} | {fmt_ms(tf0)} {fmt_ms(tf1)} | {fmt_ms(tb0)} {fmt_ms(tb1)} || {sp_f:5.2f}× {sp_b:5.2f}×") + + # Geometric mean speedups + sps_f, sps_b = [], [] + for r_nc, r_c in zip(res_nc["rows"], res_cc["rows"]): + tf0, tb0 = r_nc["fwd_hot_ms"], r_nc["bwd_hot_ms"] + tf1, tb1 = r_c["fwd_hot_ms"], r_c["bwd_hot_ms"] + sps_f.append(tf0/tf1); sps_b.append(tb0/tb1) + def gmean(a): + a = [x for x in a if x>0 and math.isfinite(x)] + return math.exp(sum(math.log(x) for x in a)/len(a)) if a else float('nan') + print("-"*110) + print(f"Geomean speedup: Forward {gmean(sps_f):.2f}×, Backward {gmean(sps_b):.2f}×") + +if __name__ == "__main__": + main() diff --git a/test/test_error.py b/test/test_error.py index 6a01fbe..de14de1 100644 --- a/test/test_error.py +++ b/test/test_error.py @@ -1,4 +1,3 @@ -import os import torch from models.util_converse import Converse2D diff --git a/test/test_speed.py b/test/test_speed.py index 5d10efb..841c095 100644 --- a/test/test_speed.py +++ b/test/test_speed.py @@ -1,4 +1,3 @@ -import os import time import math import torch From e1c4d7dbc0c5bcd8647bec6e7661872f6cbb4c6e Mon Sep 17 00:00:00 2001 From: BoyceYi <1473416941@qq.com> Date: Sat, 30 Aug 2025 18:38:56 +0800 Subject: [PATCH 06/22] fix bug in backend choice --- models/util_converse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/util_converse.py b/models/util_converse.py index e910b3a..3fe1231 100644 --- a/models/util_converse.py +++ b/models/util_converse.py @@ -155,7 +155,7 @@ def _can_use_cuda_backend(): if not _can_use_cuda_backend(): raise RuntimeError("Converse2D backend='cuda' but CUDA extension is unavailable.") use_cuda_backend = True - elif backend == "python": + elif backend == "pytorch": use_cuda_backend = False else: # "auto" use_cuda_backend = _can_use_cuda_backend() From 19f401b1c3cf4b5a000fcb35be89de9025bbeb12 Mon Sep 17 00:00:00 2001 From: BoyceYi <1473416941@qq.com> Date: Sat, 30 Aug 2025 20:17:23 +0800 Subject: [PATCH 07/22] add Inner cuda kernel --- 
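Note on the kernel this patch moves into converse2d_v3.cu: block_mean_cuda averages each non-overlapping s×s tile, i.e. out[b,c,h,w] is the mean of the s×s input patch starting at (h*s, w*s). A minimal PyTorch reference for spot-checking it (assumes a contiguous (B,C,H*s,W*s) input; the helper name is illustrative):

import torch

def block_mean_ref(x: torch.Tensor, s: int) -> torch.Tensor:
    # Mean over each non-overlapping s-by-s tile: (B, C, H*s, W*s) -> (B, C, H, W).
    B, C, Hs, Ws = x.shape
    return x.view(B, C, Hs // s, s, Ws // s, s).mean(dim=(3, 5))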
Converse2D/torch_converse2d/converse2d_v3.cpp | 118 ++---------------- Converse2D/torch_converse2d/converse2d_v3.cu | 80 ++++++++++++ 2 files changed, 87 insertions(+), 111 deletions(-) create mode 100644 Converse2D/torch_converse2d/converse2d_v3.cu diff --git a/Converse2D/torch_converse2d/converse2d_v3.cpp b/Converse2D/torch_converse2d/converse2d_v3.cpp index a7ac3cc..af94ca8 100644 --- a/Converse2D/torch_converse2d/converse2d_v3.cpp +++ b/Converse2D/torch_converse2d/converse2d_v3.cpp @@ -10,14 +10,15 @@ #include #include #include -#include -#include + #include #include #include using at::Tensor; +Tensor block_mean_cuda(const Tensor &input, int64_t s); + // ---------- FB Cache ---------- struct FBKey { @@ -47,72 +48,13 @@ namespace std ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); } }; -} // namespace std +} constexpr size_t FB_CACHE_MAX_SIZE = 64; static std::unordered_map> fb_cache; static std::list fb_cache_lru; static std::mutex fb_cache_mutex; -__global__ void block_mean_kernel( - const float *__restrict__ input, - float *__restrict__ output, - int B, int C, int H, int W, int s) -{ - int b = blockIdx.z; - int c = blockIdx.y; - int h = blockIdx.x / W; - int w = blockIdx.x % W; - - if (h >= H || w >= W) - return; - - int Hs = H * s; - int Ws = W * s; - - float sum = 0.0f; - for (int i = 0; i < s; ++i) - { - for (int j = 0; j < s; ++j) - { - int hs = h * s + i; - int ws = w * s + j; - int idx = ((b * C + c) * Hs + hs) * Ws + ws; - sum += input[idx]; - } - } - - int out_idx = ((b * C + c) * H + h) * W + w; - output[out_idx] = sum / (s * s); -} - -Tensor block_mean_cuda(const Tensor &input, int64_t s) -{ - TORCH_CHECK(input.dim() == 4, "input must be (B,C,Hs,Ws)"); - TORCH_CHECK(input.device().is_cuda(), "input must be CUDA tensor"); - - const int64_t B = input.size(0); - const int64_t C = input.size(1); - const int64_t Hs = input.size(2); - const int64_t Ws = input.size(3); - TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "Hs and Ws must be divisible by s"); - - const int64_t H = Hs / s; - const int64_t W = Ws / s; - - Tensor output = at::empty({B, C, H, W}, input.options()); - - const int threads = 1; - const dim3 blocks(H * W, C, B); - - block_mean_kernel<<>>( - input.data_ptr(), - output.data_ptr(), - B, C, H, W, s); - - return output; -} - static inline std::pair p2o_cached(const Tensor &psf, int64_t H, int64_t W) { auto C = psf.size(1); @@ -151,7 +93,6 @@ static inline std::pair p2o_cached(const Tensor &psf, int64_t H, return {FB, F2B}; } -// ---------- Utility ---------- static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) { if (s == 1) @@ -167,47 +108,6 @@ static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) return z; } -static inline Tensor splits_mean_then_mean(const Tensor &a, int64_t s) -{ - const auto &sizes = a.sizes(); - const int64_t L = a.dim(); - const int64_t W = sizes[L - 2]; - const int64_t H = sizes[L - 1]; - const int64_t W_s = W / s; - const int64_t H_s = H / s; - - std::vector view_shape; - view_shape.reserve(L + 2); - for (int64_t i = 0; i < L - 2; ++i) - view_shape.push_back(sizes[i]); - view_shape.push_back(s); - view_shape.push_back(W_s); - view_shape.push_back(s); - view_shape.push_back(H_s); - Tensor v = a.view(view_shape); - - std::vector perm; - perm.reserve(view_shape.size()); - for (int64_t i = 0; i < L - 2; ++i) - perm.push_back(i); - perm.push_back(L - 2 + 1); - perm.push_back(L - 2 + 3); - perm.push_back(L - 2 + 0); - perm.push_back(L - 2 + 2); - Tensor p = v.permute(perm).contiguous(); - - std::vector 
merge_shape; - merge_shape.reserve(L + 1); - for (int64_t i = 0; i < L - 2; ++i) - merge_shape.push_back(p.size(i)); - merge_shape.push_back(W_s); - merge_shape.push_back(H_s); - merge_shape.push_back(s * s); - Tensor r = p.view(merge_shape); - - return r.mean(-1, /*keepdim=*/false); -} - // ---------- Forward ---------- Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) { TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); @@ -240,16 +140,13 @@ Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64 Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt); Tensor x1 = FB * FR; - // Tensor FBR = splits_mean_then_mean(x1, scale); - // Tensor invW = splits_mean_then_mean(F2B, scale); - Tensor FBR = block_mean_cuda(x1, scale); - Tensor invW = block_mean_cuda(F2B, scale); + + Tensor FBR = block_mean_cuda(x1, scale); // (B,C,H,W) + Tensor invW = block_mean_cuda(F2B, scale); // (B,C,H,W) Tensor invW_plus = invW + lambda_; Tensor invWBR = FBR / invW_plus; - // ---- Key change: avoid materializing repeat(); use a zero-stride broadcast expansion ---- - // Shape: (B,C,H,1,W,1) -> expand to (B,C,H,scale,W,scale) -> reshape to (B,C,Hs,Ws) Tensor invWBR_exp = invWBR.view({B, C, H, 1, W, 1}) .expand({B, C, H, scale, W, scale}) .reshape({B, C, Hs, Ws}); @@ -281,4 +178,3 @@ TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) m.impl("clear_cache", TORCH_FN(clear_fb_cache)); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} -s \ No newline at end of file diff --git a/Converse2D/torch_converse2d/converse2d_v3.cu b/Converse2D/torch_converse2d/converse2d_v3.cu new file mode 100644 index 0000000..e50633b --- /dev/null +++ b/Converse2D/torch_converse2d/converse2d_v3.cu @@ -0,0 +1,80 @@ +#include +#include +#include + +using at::Tensor; + +template +__global__ void block_mean_kernel( + const scalar_t *__restrict__ input, + scalar_t *__restrict__ output, + int64_t B, int64_t C, int64_t H, int64_t W, int64_t s, + int64_t Hs, int64_t Ws, int64_t n_out // B*C*H*W +) +{ + int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_out) + return; + + int64_t w = tid % W; + int64_t t1 = tid / W; + int64_t h = t1 % H; + int64_t t2 = t1 / H; + int64_t c = t2 % C; + int64_t b = t2 / C; + + const int64_t hs0 = h * s; + const int64_t ws0 = w * s; + + const int64_t in_bc_off = ((b * C + c) * Hs); + scalar_t sum = scalar_t(0); + + for (int64_t di = 0; di < s; ++di) + { + const int64_t hs = hs0 + di; + const int64_t row_off = (in_bc_off + hs) * Ws; + for (int64_t dj = 0; dj < s; ++dj) + { + const int64_t ws = ws0 + dj; + sum += input[row_off + ws]; + } + } + + using value_t = typename c10::scalar_value_type::type; + const value_t denom = static_cast(s * s); + output[tid] = sum / denom; +} + +Tensor block_mean_cuda(const Tensor &input, int64_t s) +{ + TORCH_CHECK(input.is_cuda(), "block_mean_cuda: input must be CUDA"); + TORCH_CHECK(input.dim() == 4, "block_mean_cuda: input must be (B,C,Hs,Ws)"); + TORCH_CHECK(s > 0, "block_mean_cuda: s must be > 0"); + const int64_t B = input.size(0); + const int64_t C = input.size(1); + const int64_t Hs = input.size(2); + const int64_t Ws = input.size(3); + TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "Hs and Ws must be divisible by s"); + + const int64_t H = Hs / s; + const int64_t W = Ws / s; + + Tensor output = at::empty({B, C, H, W}, input.options()); + const int64_t n_out = B * C * H * W; + + const int threads = 256; + const int blocks = static_cast((n_out + threads - 1) / threads); + auto stream = at::cuda::getCurrentCUDAStream(); + +
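+ // Launch: one thread per output element; the grid is ceil(n_out / threads) 1-D blocks on the current stream. + // The AT_DISPATCH below instantiates the kernel once per supported dtype (real and complex, since x1 and F2B come from FFTs).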
AT_DISPATCH_FLOATING_TYPES_AND2(at::kComplexFloat, at::kComplexDouble, + input.scalar_type(), "block_mean_kernel", [&] + { block_mean_kernel<<>>( + input.data_ptr(), + output.data_ptr(), + B, C, H, W, s, Hs, Ws, n_out); }); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return output; +} + +TORCH_LIBRARY_FRAGMENT(converse2d_v3, m) {} From 12a86e618f176dee3726dc815f8947dcc2ef83a2 Mon Sep 17 00:00:00 2001 From: BoyceYi <1473416941@qq.com> Date: Sat, 30 Aug 2025 21:27:38 +0800 Subject: [PATCH 08/22] Modify STy --- Converse2D/torch_converse2d/converse2d_v3.cpp | 3 +- Converse2D/torch_converse2d/converse2d_v3.cu | 174 +++++++---- Converse2D/torch_converse2d/converse2d_v4.cpp | 171 +++++++++++ Converse2D/torch_converse2d/converse2d_v4.cu | 277 ++++++++++++++++++ 4 files changed, 566 insertions(+), 59 deletions(-) create mode 100644 Converse2D/torch_converse2d/converse2d_v4.cpp create mode 100644 Converse2D/torch_converse2d/converse2d_v4.cu diff --git a/Converse2D/torch_converse2d/converse2d_v3.cpp b/Converse2D/torch_converse2d/converse2d_v3.cpp index af94ca8..7a1c460 100644 --- a/Converse2D/torch_converse2d/converse2d_v3.cpp +++ b/Converse2D/torch_converse2d/converse2d_v3.cpp @@ -1,3 +1,4 @@ +// backend/converse2d_v3.cpp #include #include #include @@ -48,7 +49,7 @@ namespace std ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); } }; -} +} // namespace std constexpr size_t FB_CACHE_MAX_SIZE = 64; static std::unordered_map> fb_cache; diff --git a/Converse2D/torch_converse2d/converse2d_v3.cu b/Converse2D/torch_converse2d/converse2d_v3.cu index e50633b..40f838e 100644 --- a/Converse2D/torch_converse2d/converse2d_v3.cu +++ b/Converse2D/torch_converse2d/converse2d_v3.cu @@ -1,80 +1,138 @@ #include #include +#include #include +#include -using at::Tensor; +// ====================== +// block mean (forward): +// in : (B,C,Hs,Ws) +// out: (B,C,Ho,Wo), Ho=Hs/s, Wo=Ws/s +// ====================== +template +struct AccT +{ + using type = T; +}; +template <> +struct AccT +{ + using type = float; +}; +template <> +struct AccT +{ + using type = float; +}; template __global__ void block_mean_kernel( - const scalar_t *__restrict__ input, - scalar_t *__restrict__ output, - int64_t B, int64_t C, int64_t H, int64_t W, int64_t s, - int64_t Hs, int64_t Ws, int64_t n_out // B*C*H*W -) + const scalar_t *__restrict__ in, // (B,C,Hs,Ws) + scalar_t *__restrict__ out, // (B,C,Ho,Wo) + int B, int C, int Ho, int Wo, int s, int Hs, int Ws, + long long total_out) { - int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= n_out) + using acc_t = typename AccT::type; + + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_out) return; - int64_t w = tid % W; - int64_t t1 = tid / W; - int64_t h = t1 % H; - int64_t t2 = t1 / H; - int64_t c = t2 % C; - int64_t b = t2 / C; + int wo = static_cast(idx % Wo); + int ho = static_cast((idx / Wo) % Ho); + int c = static_cast((idx / (1LL * Wo * Ho)) % C); + int b = static_cast(idx / (1LL * Wo * Ho * C)); - const int64_t hs0 = h * s; - const int64_t ws0 = w * s; + const int hi0 = ho * s; + const int wi0 = wo * s; - const int64_t in_bc_off = ((b * C + c) * Hs); - scalar_t sum = scalar_t(0); + const long long base_in = ((long long)b * C + c) * Hs * Ws; - for (int64_t di = 0; di < s; ++di) + acc_t acc = acc_t(0); + for (int di = 0; di < s; ++di) { - const int64_t hs = hs0 + di; - const int64_t row_off = (in_bc_off + hs) * Ws; - for (int64_t dj = 0; dj < s; ++dj) + const int hi = hi0 + di; + const long long row_off = base_in + (long long)hi * Ws + wi0; +#pragma unroll + for 
(int dj = 0; dj < s; ++dj) { - const int64_t ws = ws0 + dj; - sum += input[row_off + ws]; + acc += static_cast(in[row_off + dj]); } } + const float inv_area = 1.0f / (s * s); + acc = acc * static_cast(inv_area); - using value_t = typename c10::scalar_value_type::type; - const value_t denom = static_cast(s * s); - output[tid] = sum / denom; + const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo; + out[out_off] = static_cast(acc); } -Tensor block_mean_cuda(const Tensor &input, int64_t s) +struct BlockMeanFunctionV3 : public torch::autograd::Function { - TORCH_CHECK(input.is_cuda(), "block_mean_cuda: input must be CUDA"); - TORCH_CHECK(input.dim() == 4, "block_mean_cuda: input must be (B,C,Hs,Ws)"); - TORCH_CHECK(s > 0, "block_mean_cuda: s must be > 0"); - const int64_t B = input.size(0); - const int64_t C = input.size(1); - const int64_t Hs = input.size(2); - const int64_t Ws = input.size(3); - TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "Hs and Ws must be divisible by s"); - - const int64_t H = Hs / s; - const int64_t W = Ws / s; - - Tensor output = at::empty({B, C, H, W}, input.options()); - const int64_t n_out = B * C * H * W; - - const int threads = 256; - const int blocks = static_cast((n_out + threads - 1) / threads); - auto stream = at::cuda::getCurrentCUDAStream(); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::kComplexFloat, at::kComplexDouble, - input.scalar_type(), "block_mean_kernel", [&] - { block_mean_kernel<<>>( - input.data_ptr(), - output.data_ptr(), - B, C, H, W, s, Hs, Ws, n_out); }); - - C10_CUDA_KERNEL_LAUNCH_CHECK(); - return output; -} + static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &input, int64_t s) + { + TORCH_CHECK(input.is_cuda() && input.dim() == 4, "block_mean_cuda: input must be (B,C,Hs,Ws) CUDA"); + TORCH_CHECK(s >= 1, "block_mean_cuda: s must be >= 1"); -TORCH_LIBRARY_FRAGMENT(converse2d_v3, m) {} + auto x = input.contiguous(); + const int B = (int)x.size(0); + const int C = (int)x.size(1); + const int Hs = (int)x.size(2); + const int Ws = (int)x.size(3); + + TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "block_mean_cuda: H,W must be divisible by s"); + const int Ho = Hs / (int)s; + const int Wo = Ws / (int)s; + + auto out = at::empty({B, C, Ho, Wo}, x.options()); + + // launch forward kernel + { + const long long total_out = 1LL * B * C * Ho * Wo; + const int threads = 256; + const int blocks = (int)((total_out + threads - 1) / threads); + auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::kBFloat16, + x.scalar_type(), "block_mean_v3_fwd", [&] + { block_mean_kernel<<>>( + x.data_ptr(), out.data_ptr(), + B, C, Ho, Wo, (int)s, Hs, Ws, total_out); }); + } + + // save for backward + ctx->saved_data["B"] = (int64_t)B; + ctx->saved_data["C"] = (int64_t)C; + ctx->saved_data["Hs"] = (int64_t)Hs; + ctx->saved_data["Ws"] = (int64_t)Ws; + ctx->saved_data["s"] = (int64_t)s; + return out; + } + + static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, + torch::autograd::tensor_list grad_outputs) + { + auto go = grad_outputs[0]; // (B,C,Ho,Wo) + const int B = (int)ctx->saved_data["B"].toInt(); + const int C = (int)ctx->saved_data["C"].toInt(); + const int Hs = (int)ctx->saved_data["Hs"].toInt(); + const int Ws = (int)ctx->saved_data["Ws"].toInt(); + const int s = (int)ctx->saved_data["s"].toInt(); + + const int Ho = Hs / s; + const int Wo = Ws / s; + + // gi = expand( go / (s*s), dims=[B,C,Ho,1,Wo,1] ) -> reshape(B,C,Hs,Ws) + auto 
go_scaled = go / static_cast(s * s); + auto gi = go_scaled.view({B, C, Ho, 1, Wo, 1}) + .expand({B, C, Ho, s, Wo, s}) + .reshape({B, C, Hs, Ws}) + .contiguous(); + + return {gi, torch::Tensor()}; // no grad for s + } +}; + +at::Tensor block_mean_cuda(const at::Tensor &input, int64_t s) +{ + return BlockMeanFunctionV3::apply(input, s); +} diff --git a/Converse2D/torch_converse2d/converse2d_v4.cpp b/Converse2D/torch_converse2d/converse2d_v4.cpp new file mode 100644 index 0000000..382cd7b --- /dev/null +++ b/Converse2D/torch_converse2d/converse2d_v4.cpp @@ -0,0 +1,171 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +using at::Tensor; + +Tensor block_mean_cuda(const Tensor &input, int64_t s); +Tensor sfold_upsample_cuda_launcher(const Tensor &x, int64_t scale); + +// ---------- FB Cache ---------- +struct FBKey +{ + int64_t device_id; + at::ScalarType dtype; + int64_t channels; + int64_t H, W; + void *ptr; + + bool operator==(const FBKey &other) const + { + return device_id == other.device_id && dtype == other.dtype && + channels == other.channels && H == other.H && W == other.W && + ptr == other.ptr; + } +}; + +namespace std +{ + template <> + struct hash + { + size_t operator()(const FBKey &k) const + { + return ((hash()(k.device_id) ^ hash()(k.channels)) << 1) ^ + ((hash()(k.H) ^ hash()(k.W)) << 1) ^ + ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); + } + }; +} // namespace std + +constexpr size_t FB_CACHE_MAX_SIZE = 64; +static std::unordered_map> fb_cache; +static std::list fb_cache_lru; +static std::mutex fb_cache_mutex; + +static inline std::pair p2o_cached(const Tensor &psf, int64_t H, int64_t W) +{ + auto C = psf.size(1); + FBKey key{psf.device().index(), psf.scalar_type(), C, H, W, psf.data_ptr()}; + + { + std::lock_guard lock(fb_cache_mutex); + auto it = fb_cache.find(key); + if (it != fb_cache.end()) + { + fb_cache_lru.remove(key); + fb_cache_lru.push_front(key); + return it->second; + } + } + + Tensor otf = at::zeros({1, C, H, W}, psf.options()); + int64_t kh = psf.size(2), kw = psf.size(3); + otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); + otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1}); + Tensor FB = at::fft_fftn(otf, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor F2B = at::abs(FB).pow(2); + + { + std::lock_guard lock(fb_cache_mutex); + fb_cache[key] = {FB, F2B}; + fb_cache_lru.push_front(key); + + if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) + { + fb_cache.erase(fb_cache_lru.back()); + fb_cache_lru.pop_back(); + } + } + + return {FB, F2B}; +} + +static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) +{ + if (s == 1) + return x; + return sfold_upsample_cuda_launcher(x, s); +} + +// ---------- Forward ---------- +Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) +{ + TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); + TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); + TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); + TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); + TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); + + x = x.contiguous(); + x0 = x0.contiguous(); + weight = weight.contiguous(); + bias = bias.contiguous(); + + const int64_t B = 
x.size(0); + const int64_t C = x.size(1); + const int64_t H = x.size(2); + const int64_t W = x.size(3); + const int64_t Hs = H * scale; + const int64_t Ws = W * scale; + + Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; + Tensor STy = sfold_upsample_zero_insertion(x, scale); + + auto [FB, F2B] = p2o_cached(weight, Hs, Ws); + Tensor FBC = at::conj_physical(FB); + + Tensor F_STy = at::fft_fftn(STy, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor FBFy = FBC * F_STy; + Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt); + + Tensor x1 = FB * FR; + + Tensor FBR = block_mean_cuda(x1, scale); // (B,C,H,W) + Tensor invW = block_mean_cuda(F2B, scale); // (B,C,H,W) + + Tensor invW_plus = invW + lambda_; + Tensor invWBR = FBR / invW_plus; + + Tensor invWBR_exp = invWBR.view({B, C, H, 1, W, 1}) + .expand({B, C, H, scale, W, scale}) + .reshape({B, C, Hs, Ws}); + Tensor FCBinvWBR = FBC * invWBR_exp; + + Tensor FX = (FR - FCBinvWBR) / lambda_; + Tensor out_c = at::fft_ifftn(FX, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor out = at::real(out_c); + return out; +} + +void clear_fb_cache() +{ + std::lock_guard lock(fb_cache_mutex); + fb_cache.clear(); + fb_cache_lru.clear(); +} + +TORCH_LIBRARY(converse2d, m) +{ + m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); + m.def("clear_cache() -> ()"); +} +TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) +{ + m.impl("forward", TORCH_FN(converse2d_forward)); + m.impl("clear_cache", TORCH_FN(clear_fb_cache)); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/Converse2D/torch_converse2d/converse2d_v4.cu b/Converse2D/torch_converse2d/converse2d_v4.cu new file mode 100644 index 0000000..e02250c --- /dev/null +++ b/Converse2D/torch_converse2d/converse2d_v4.cu @@ -0,0 +1,277 @@ +#include +#include +#include +#include +#include +#include +#include + +// ====================================================================== +// S-FOLD UPSAMPLE (zero-insertion upsample) +// forward: out[b,c,h*s, w*s] = x[b,c,h,w]; others = 0 +// backward: grad_x[b,c,h,w] = grad_out[b,c,h*s, w*s] +// dtypes: float/double/half/bfloat16 +// ====================================================================== + +using namespace at; +using namespace at::indexing; + +template +__global__ void sfold_upsample_kernel( + const scalar_t *__restrict__ x, + scalar_t *__restrict__ out, + int B, int C, int H, int W, int s, + int Hs, int Ws, long long total_in) +{ + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_in) + return; + + int w = static_cast(idx % W); + int h = static_cast((idx / W) % H); + int c = static_cast((idx / (1LL * W * H)) % C); + int b = static_cast(idx / (1LL * W * H * C)); + + long long in_off = ((long long)b * C + c) * H * W + (long long)h * W + w; + long long out_off = ((long long)b * C + c) * Hs * Ws + (long long)(h * s) * Ws + (w * s); + + out[out_off] = x[in_off]; +} + +template +__global__ void sfold_downsample_grad_kernel( // backward of zero-insertion upsample + const scalar_t *__restrict__ grad_out, // (B,C,Hs,Ws) + scalar_t *__restrict__ grad_in, // (B,C,H,W) + int B, int C, int H, int W, int s, int Hs, int Ws, long long total_in) +{ + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_in) + return; + + int w = static_cast(idx % W); + int h = static_cast((idx / W) % H); + int c = static_cast((idx / (1LL * W * H)) % C); + int b = static_cast(idx / (1LL * W * H * C)); + + long long in_off = ((long long)b * C + c) * H * 
+
+struct SFoldFunction : public torch::autograd::Function<SFoldFunction>
+{
+    static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &x, int64_t scale)
+    {
+        TORCH_CHECK(x.is_cuda() && x.dim() == 4, "sfold: x must be (B,C,H,W) CUDA");
+        TORCH_CHECK(scale >= 1, "sfold: scale must be >= 1");
+        if (scale == 1)
+        {
+            ctx->saved_data["s"] = (int64_t)1;
+            return x;
+        }
+
+        auto x_ = x.contiguous();
+        const int B = (int)x_.size(0), C = (int)x_.size(1), H = (int)x_.size(2), W = (int)x_.size(3);
+        const int s = (int)scale, Hs = H * s, Ws = W * s;
+
+        auto out = at::zeros({B, C, Hs, Ws}, x_.options());
+
+        const long long total = 1LL * B * C * H * W;
+        const int threads = 256, blocks = (int)((total + threads - 1) / threads);
+        auto stream = at::cuda::getCurrentCUDAStream();
+
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x_.scalar_type(), "sfold_fwd", [&]
+                                        { sfold_upsample_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+                                              x_.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(),
+                                              B, C, H, W, s, Hs, Ws, total); });
+
+        // save for backward
+        ctx->saved_data["B"] = (int64_t)B;
+        ctx->saved_data["C"] = (int64_t)C;
+        ctx->saved_data["H"] = (int64_t)H;
+        ctx->saved_data["W"] = (int64_t)W;
+        ctx->saved_data["s"] = (int64_t)s;
+        return out;
+    }
+
+    static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx,
+                                                 torch::autograd::tensor_list grad_outputs)
+    {
+        auto go = grad_outputs[0]; // (B,C,Hs,Ws)
+        const int s = (int)ctx->saved_data["s"].toInt();
+
+        at::Tensor gx;
+        if (s == 1)
+        {
+            gx = go; // identity (the scale == 1 path saved no shape info)
+        }
+        else
+        {
+            const int H = (int)ctx->saved_data["H"].toInt();
+            const int W = (int)ctx->saved_data["W"].toInt();
+            const int Hs = H * s, Ws = W * s;
+            // the strided slice reads back exactly the pixels the forward pass
+            // wrote; the dedicated grad kernel above is kept as an alternative
+            gx = go.index({Slice(), Slice(), Slice(0, Hs, s), Slice(0, Ws, s)}).contiguous();
+        }
+        return {gx, torch::Tensor()}; // no grad for scale
+    }
+};
+
+// exposed symbol for v4.cpp
+at::Tensor sfold_upsample_cuda_launcher(const at::Tensor &x, int64_t scale)
+{
+    return SFoldFunction::apply(x, scale);
+}
+
+// ======================================================================
+// BLOCK MEAN over non-overlapping s×s tiles
+// forward:  out[b,c,ho,wo] = mean_{i,j in s×s} in[b,c, ho*s+i, wo*s+j]
+// backward: grad_in[b,c,hi,wi] = grad_out[b,c,hi/s, wi/s] / (s*s)
+// dtypes: float/double/half/bfloat16 + complex64/complex128
+// ======================================================================
+
+template <typename T>
+struct AccT
+{
+    using type = T;
+};
+template <>
+struct AccT<at::Half>
+{
+    using type = float;
+};
+template <>
+struct AccT<at::BFloat16>
+{
+    using type = float;
+};
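+
+// Note: AccT promotes half/bfloat16 accumulation to float and casts back once
+// per tile; all other dtypes, including complex64/128, accumulate natively.
+// With only ~11 mantissa bits, a raw half sum over an s*s tile would lose
+// low-order bits.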
+
+template <typename scalar_t>
+__global__ void block_mean_kernel(
+    const scalar_t *__restrict__ in, // (B,C,Hs,Ws)
+    scalar_t *__restrict__ out,      // (B,C,Ho,Wo)
+    int B, int C, int Ho, int Wo, int s, int Hs, int Ws,
+    long long total_out)
+{
+    using acc_t = typename AccT<scalar_t>::type;
+
+    long long idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= total_out)
+        return;
+
+    int wo = static_cast<int>(idx % Wo);
+    int ho = static_cast<int>((idx / Wo) % Ho);
+    int c = static_cast<int>((idx / (1LL * Wo * Ho)) % C);
+    int b = static_cast<int>(idx / (1LL * Wo * Ho * C));
+
+    const int hi0 = ho * s;
+    const int wi0 = wo * s;
+
+    const long long base_in = ((long long)b * C + c) * Hs * Ws;
+
+    acc_t acc = acc_t(0);
+    for (int di = 0; di < s; ++di)
+    {
+        const int hi = hi0 + di;
+        const long long row_off = base_in + (long long)hi * Ws + wi0;
+#pragma unroll
+        for (int dj = 0; dj < s; ++dj)
+        {
+            acc += static_cast<acc_t>(in[row_off + dj]);
+        }
+    }
+    const float inv_area = 1.0f / (s * s);
+    acc = acc * static_cast<acc_t>(inv_area);
+
+    const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo;
+    out[out_off] = static_cast<scalar_t>(acc);
+}
+
+template <typename scalar_t>
+__global__ void block_mean_grad_kernel(
+    const scalar_t *__restrict__ grad_out, // (B,C,Ho,Wo)
+    scalar_t *__restrict__ grad_in,        // (B,C,Hs,Ws)
+    int B, int C, int Ho, int Wo, int s, int Hs, int Ws,
+    long long total_in)
+{
+    using acc_t = typename AccT<scalar_t>::type;
+
+    long long idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= total_in)
+        return;
+
+    int wi = static_cast<int>(idx % Ws);
+    int hi = static_cast<int>((idx / Ws) % Hs);
+    int c = static_cast<int>((idx / (1LL * Ws * Hs)) % C);
+    int b = static_cast<int>(idx / (1LL * Ws * Hs * C));
+
+    const int ho = hi / s;
+    const int wo = wi / s;
+
+    const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo;
+    acc_t g = static_cast<acc_t>(grad_out[out_off]) * static_cast<acc_t>(1.0f / (s * s));
+
+    const long long in_off = ((long long)b * C + c) * Hs * Ws + (long long)hi * Ws + wi;
+    grad_in[in_off] = static_cast<scalar_t>(g);
+}
+
+struct BlockMeanFunction : public torch::autograd::Function<BlockMeanFunction>
+{
+    static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &input, int64_t s)
+    {
+        TORCH_CHECK(input.is_cuda() && input.dim() == 4, "block_mean: input must be (B,C,Hs,Ws) CUDA");
+        TORCH_CHECK(s >= 1, "block_mean: s must be >= 1");
+
+        auto x = input.contiguous();
+        const int B = (int)x.size(0), C = (int)x.size(1), Hs = (int)x.size(2), Ws = (int)x.size(3);
+        TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "block_mean: H,W must be divisible by s");
+        const int Ho = Hs / (int)s, Wo = Ws / (int)s;
+
+        auto out = at::empty({B, C, Ho, Wo}, x.options());
+
+        const long long total_out = 1LL * B * C * Ho * Wo;
+        const int threads = 256, blocks = (int)((total_out + threads - 1) / threads);
+        auto stream = at::cuda::getCurrentCUDAStream();
+
+        AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "block_mean_fwd", [&]
+                                                    { block_mean_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+                                                          x.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(),
+                                                          B, C, Ho, Wo, (int)s, Hs, Ws, total_out); });
+
+        // save for backward
+        ctx->saved_data["B"] = (int64_t)B;
+        ctx->saved_data["C"] = (int64_t)C;
+        ctx->saved_data["Hs"] = (int64_t)Hs;
+        ctx->saved_data["Ws"] = (int64_t)Ws;
+        ctx->saved_data["s"] = (int64_t)s;
+        return out;
+    }
+
+    static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx,
+                                                 torch::autograd::tensor_list grad_outputs)
+    {
+        auto go = grad_outputs[0]; // (B,C,Ho,Wo)
+        const int B = (int)ctx->saved_data["B"].toInt();
+        const int C = (int)ctx->saved_data["C"].toInt();
+        const int Hs = (int)ctx->saved_data["Hs"].toInt();
+        const int Ws = (int)ctx->saved_data["Ws"].toInt();
+        const int s = (int)ctx->saved_data["s"].toInt();
+        const int Ho = Hs / s, Wo = Ws / s;
+
+        auto go_scaled = go / static_cast<double>(s * s);
+        auto gi = go_scaled.view({B, C, Ho, 1, Wo, 1})
+                      .expand({B, C, Ho, s, Wo, s})
+                      .reshape({B, C, Hs, Ws})
+                      .contiguous();
+
+        return {gi, torch::Tensor()}; // no grad for s
+    }
+};
+
+// exposed symbol for v4.cpp
+at::Tensor block_mean_cuda(const at::Tensor &input, int64_t s)
+{
+    return BlockMeanFunction::apply(input, s);
+}
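+
+// Shape sketch (illustrative): for an input (B,C,Hs,Ws) = (1,3,8,8) and s = 2,
+// block_mean_cuda returns a (1,3,4,4) tensor in which every element is the
+// mean of one non-overlapping 2x2 tile; wrapping the kernels in
+// BlockMeanFunction keeps the op differentiable end to end.
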
From 025c867ecbf6630636dca5d77ee1f19a62b3aaf6 Mon Sep 17 00:00:00 2001
From: Boyce <1473416941@qq.com>
Date: Sat, 30 Aug 2025 23:10:44 +0800
Subject: [PATCH 09/22] add Speedtest results

---
 test/Results.md    |  40 -----
 test/results.csv   | 181 +++++++++++++++++++++
 test/test_speed.py | 384 ++++++++++++++++++++++++++++-------------------
 3 files changed, 404 insertions(+), 201 deletions(-)
 delete mode 100644 test/Results.md
 create mode 100644 test/results.csv

diff --git a/test/Results.md b/test/Results.md
deleted file mode 100644
index 2da77aa..0000000
--- a/test/Results.md
+++ /dev/null
@@ -1,40 +0,0 @@
-# Environment
-
-- **Device**: 4090
-- **PyTorch version**: `2.4.0+cu121`
-- **CUDA version**: `12.1`
-- **cuDNN version**: `90100`
-
----
-
-# Test configuration
-
-- **Data types**: `['torch.float32']`
-- **AMP**: `False`
-- **Modes**: `TRAIN=True`, `INFER=True`
-- **Warmup iterations**: `10`
-- **Total iterations**: `50`
-
----
-
-# Per-case comparison (CUDA vs PyTorch)
-
-| dtype | B | C | H | W | s | k | fwd_py(ms) | fwd_cu(ms) | fwd_Gpix/s(py) | fwd_Gpix/s(cu) | fwd_speedup | bwd_py(ms) | bwd_cu(ms) | bwd_Gpix/s(py) | bwd_Gpix/s(cu) | bwd_speedup |
-| :-----: | :-: | :-: | :-: | :-: | :-: | :-: | :--------: | :--------: | :------------: | :------------: | :---------: | :--------: | :--------: | :------------: | :------------: | :---------: |
-| float32 | 1 | 3 | 128 | 128 | 2 | 5 | 4.62 | 0.60 | 0.014 | 0.110 | 7.73× | 17.57 | 2.29 | 0.004 | 0.029 | 7.68× |
-| float32 | 2 | 3 | 256 | 256 | 2 | 5 | 10.21 | 0.61 | 0.051 | 0.863 | 16.81× | 19.87 | 2.73 | 0.026 | 0.192 | 7.27× |
-| float32 | 4 | 3 | 256 | 256 | 2 | 5 | 15.71 | 0.71 | 0.067 | 1.476 | 22.12× | 24.21 | 3.58 | 0.043 | 0.293 | 6.77× |
-| float32 | 2 | 8 | 256 | 256 | 2 | 5 | 22.28 | 1.16 | 0.024 | 0.454 | 19.27× | 31.86 | 4.14 | 0.016 | 0.127 | 7.70× |
-| float32 | 1 | 3 | 512 | 512 | 2 | 5 | 20.03 | 1.30 | 0.052 | 0.806 | 15.41× | 32.28 | 3.90 | 0.032 | 0.269 | 8.27× |
-
----
-
-### Overall performance
-
-- **Overall forward geometric-mean speedup**: **15.35×** (CUDA vs PyTorch)
-- **Overall training geometric-mean speedup**: **7.52×** (CUDA vs PyTorch)
-
-### Precision test
-
-- forward: max|Δ|=0.000e+00 max rel=0.000e+00
-- backward: max|Δ|=0.000e+00 max rel=0.000e+00
diff --git a/test/results.csv b/test/results.csv
new file mode 100644
index 0000000..ecca0b3
--- /dev/null
+++ b/test/results.csv
@@ -0,0 +1,181 @@
+variant,B,C,H,W,scale,ksize,mean_ms,p50_ms,p90_ms,tp,speedup_vs_pytorch,warmup,dtype,device,iters
+pytorch,1,3,128,128,1,3,1.52592733502388,0.8647029753774405,2.298735734075308,0.010737077463615656,,10,float32,cuda,50
+cuda_v1,1,3,128,128,1,3,0.44544885866343975,0.4432220011949539,0.472044013440609,0.036780877717724675,3.425594892312428,10,float32,cuda,50
+cuda_v2,1,3,128,128,1,3,0.3007301315665245,0.29844650998711586,0.3108557313680649,0.054480739640735645,5.074075308234717,10,float32,cuda,50
+cuda_v3,1,3,128,128,1,3,0.30307079665362835,0.2994614187628031,0.3240731079131365,0.0540599760217902,5.03488739882722,10,float32,cuda,50
+cuda_v4,1,3,128,128,1,3,0.32072700560092926,0.32319221645593643,0.3775870893150568,0.051083942773394356,4.757713907392458,10,float32,cuda,50
+pytorch,1,3,128,128,1,5,4.0221707709133625,0.9404211305081844,7.168814446777105,0.004073422271993561,,10,float32,cuda,50
+cuda_v1,1,3,128,128,1,5,0.4805893823504448,0.4761132877320051,0.5044737830758095,0.0340914730988643,8.369246010475537,10,float32,cuda,50
+cuda_v2,1,3,128,128,1,5,0.3053080663084984,0.3032265231013298,0.3142551053315401,0.053663829449709974,13.174138566156198,10,float32,cuda,50
+cuda_v3,1,3,128,128,1,5,0.2769072540104389,0.2749105915427208,0.28512105345726013,0.05916782519313254,14.525335514546446,10,float32,cuda,50
+cuda_v4,1,3,128,128,1,5,0.27410948649048805,0.27215038426220417,0.2808789722621441,0.05977173650488943,14.673592010296721,10,float32,cuda,50
+pytorch,1,3,128,128,1,7,3.5734746791422367,0.8284670766443014,10.36820076406002,0.004584893268065006,,10,float32,cuda,50 +cuda_v1,1,3,128,128,1,7,0.5126208532601595,0.5032145418226719,0.5624458193778992,0.03196124366732498,6.970989682561095,10,float32,cuda,50 +cuda_v2,1,3,128,128,1,7,0.30147168785333633,0.299196457490325,0.31120297499001026,0.05434672859884173,11.853433748912119,10,float32,cuda,50 +cuda_v3,1,3,128,128,1,7,0.2754225581884384,0.27269101701676846,0.28396081179380417,0.05948677591176249,12.974517057158911,10,float32,cuda,50 +cuda_v4,1,3,128,128,1,7,0.27901765890419483,0.2780151553452015,0.2874089404940605,0.058720297720029645,12.807342349500704,10,float32,cuda,50 +pytorch,1,3,128,128,2,3,5.1218782644718885,1.200301107019186,8.238791720941663,0.012795306060784977,,10,float32,cuda,50 +cuda_v1,1,3,128,128,2,3,0.48394261859357357,0.46004820615053177,0.5499029066413641,0.13542101373600798,10.583647869983038,10,float32,cuda,50 +cuda_v2,1,3,128,128,2,3,0.3488078713417053,0.3472878597676754,0.3549169283360243,0.18788566825603104,14.683952643529377,10,float32,cuda,50 +cuda_v3,1,3,128,128,2,3,0.34828370437026024,0.34455815330147743,0.36369492299854755,0.18816843618479692,14.706051992104753,10,float32,cuda,50 +cuda_v4,1,3,128,128,2,3,0.3102908004075289,0.30851690098643303,0.3184992354363203,0.2112083243007092,16.506703575307196,10,float32,cuda,50 +pytorch,1,3,128,128,2,5,2.7414161060005426,1.1525587178766727,3.5016948357224464,0.02390589296406032,,10,float32,cuda,50 +cuda_v1,1,3,128,128,2,5,0.5049472488462925,0.4589471500366926,0.6366008426994085,0.1297878147662695,5.429113857465611,10,float32,cuda,50 +cuda_v2,1,3,128,128,2,5,0.3760635666549206,0.37418887950479984,0.38898889906704426,0.1742684104789562,7.28976787191962,10,float32,cuda,50 +cuda_v3,1,3,128,128,2,5,0.3826252557337284,0.3617340698838234,0.443447008728981,0.17127985938703158,7.1647547173632296,10,float32,cuda,50 +cuda_v4,1,3,128,128,2,5,0.33498174510896206,0.326463021337986,0.3694041632115841,0.1956405116305147,8.183777611848136,10,float32,cuda,50 +pytorch,1,3,128,128,2,7,6.999819874763489,1.7441920936107635,18.960934737697244,0.009362526632474858,,10,float32,cuda,50 +cuda_v1,1,3,128,128,2,7,0.4698681924492121,0.46762311831116676,0.4844237584620714,0.13947741314088583,14.897411630007488,10,float32,cuda,50 +cuda_v2,1,3,128,128,2,7,0.37914127111434937,0.36833412013947964,0.40491526015102863,0.1728537750780349,18.462299960618996,10,float32,cuda,50 +cuda_v3,1,3,128,128,2,7,0.44248790480196476,0.44069206342101097,0.4555768799036741,0.14810800315396327,15.819234376352625,10,float32,cuda,50 +cuda_v4,1,3,128,128,2,7,0.3105806838721037,0.3085271455347538,0.3185570705682039,0.2110111909824616,22.53784680841902,10,float32,cuda,50 +pytorch,1,3,128,128,3,3,7.3913106974214315,1.91159313544631,23.464363627135754,0.019949912273535222,,10,float32,cuda,50 +cuda_v1,1,3,128,128,3,3,0.523065309971571,0.5227448418736458,0.5595123395323753,0.281907435245542,14.130760645975844,10,float32,cuda,50 +cuda_v2,1,3,128,128,3,3,0.3566680010408163,0.3553489223122597,0.3626151941716671,0.4134264906571348,20.72322349033937,10,float32,cuda,50 +cuda_v3,1,3,128,128,3,3,0.33742536790668964,0.3273128531873226,0.36468892358243465,0.43700330213695415,21.90502374873426,10,float32,cuda,50 +cuda_v4,1,3,128,128,3,3,0.31496345065534115,0.31304731965065,0.3236026968806982,0.46816860716121145,23.467201296031043,10,float32,cuda,50 +pytorch,1,3,128,128,3,5,3.313328195363283,1.4438305515795946,5.463926354423165,0.044503891949596766,,10,float32,cuda,50 
+cuda_v1,1,3,128,128,3,5,0.4630151018500328,0.4609823226928711,0.4760188050568104,0.31846909401188367,7.155982995207888,10,float32,cuda,50 +cuda_v2,1,3,128,128,3,5,0.36386603489518166,0.35751843824982643,0.3770098090171814,0.40524804697002687,9.105901286768216,10,float32,cuda,50 +cuda_v3,1,3,128,128,3,5,0.32810534350574017,0.3216217737644911,0.356891006231308,0.44941663681688954,10.098367066994497,10,float32,cuda,50 +cuda_v4,1,3,128,128,3,5,0.3466991614550352,0.3456580452620983,0.357994856312871,0.4253139793622609,9.556781682014543,10,float32,cuda,50 +pytorch,1,3,128,128,3,7,6.756937270984054,1.7321815248578787,10.7049988117069,0.021822904976965263,,10,float32,cuda,50 +cuda_v1,1,3,128,128,3,7,0.4557925648987293,0.45467307791113853,0.465529877692461,0.32351558879149916,14.824588620670786,10,float32,cuda,50 +cuda_v2,1,3,128,128,3,7,0.38067104294896126,0.3792489878833294,0.39135636761784554,0.3873580686823354,17.75006897987246,10,float32,cuda,50 +cuda_v3,1,3,128,128,3,7,0.32470209524035454,0.3225074615329504,0.3330751322209835,0.454127035708989,20.80965096939815,10,float32,cuda,50 +cuda_v4,1,3,128,128,3,7,0.3528321161866188,0.3392628859728575,0.38030720315873623,0.41792113936138414,19.150573207486012,10,float32,cuda,50 +pytorch,1,3,128,256,1,3,4.7790092043578625,1.1113823857158422,11.791642662137747,0.006856651368262623,,10,float32,cuda,50 +cuda_v1,1,3,128,256,1,3,0.43059052899479866,0.4229459445923567,0.4630208481103182,0.07610014107020878,11.098732746199326,10,float32,cuda,50 +cuda_v2,1,3,128,256,1,3,0.3584872093051672,0.3435080870985985,0.4171540029346943,0.09140632956894645,13.331045237627055,10,float32,cuda,50 +cuda_v3,1,3,128,256,1,3,0.31005123630166054,0.30543701723217964,0.3326671663671732,0.10568575823422545,15.413611186856176,10,float32,cuda,50 +cuda_v4,1,3,128,256,1,3,0.274530379101634,0.2726605162024498,0.28227311559021473,0.11936019651897593,17.407943048039222,10,float32,cuda,50 +pytorch,1,3,128,256,1,5,3.1171874701976776,0.8945895824581385,2.6657020207494497,0.010512040200752509,,10,float32,cuda,50 +cuda_v1,1,3,128,256,1,5,0.48769986256957054,0.4741228185594082,0.5690208170562983,0.06718886453515381,6.391610310845659,10,float32,cuda,50 +cuda_v2,1,3,128,256,1,5,0.34617194905877113,0.3371278289705515,0.35091196186840534,0.0946581607466896,9.004737323960525,10,float32,cuda,50 +cuda_v3,1,3,128,256,1,5,0.29009729623794556,0.290280906483531,0.31339898705482483,0.11295520649431635,10.745317211232734,10,float32,cuda,50 +cuda_v4,1,3,128,256,1,5,0.27919040992856026,0.2755909226834774,0.2894400618970394,0.11736792824791058,11.165095072553921,10,float32,cuda,50 +pytorch,1,3,128,256,1,7,0.8610220160335302,0.8528372272849083,0.8728299289941788,0.03805709887762491,,10,float32,cuda,50 +cuda_v1,1,3,128,256,1,7,0.4062088765203953,0.4038410261273384,0.4159193020313978,0.08066785807511706,2.119653374906738,10,float32,cuda,50 +cuda_v2,1,3,128,256,1,7,0.3033390734344721,0.29832683503627777,0.3132038749754429,0.10802432943766017,2.8384804050623895,10,float32,cuda,50 +cuda_v3,1,3,128,256,1,7,0.2987572457641363,0.2943666186183691,0.31618126668035984,0.10968102184831952,2.8820121628557662,10,float32,cuda,50 +cuda_v4,1,3,128,256,1,7,0.2784122433513403,0.2752200234681368,0.2866980619728565,0.11769597344412998,3.09261548870525,10,float32,cuda,50 +pytorch,1,3,128,256,2,3,4.812479577958584,1.4287945814430714,9.668499417603016,0.027235855836213175,,10,float32,cuda,50 +cuda_v1,1,3,128,256,2,3,0.4676768183708191,0.4681474529206753,0.48343208618462086,0.2802619134653656,10.290181999448148,10,float32,cuda,50 
+cuda_v2,1,3,128,256,2,3,0.3709117043763399,0.3670940641313791,0.3977825865149498,0.35337790221634474,12.974730970138582,10,float32,cuda,50 +cuda_v3,1,3,128,256,2,3,0.40193247608840466,0.41314586997032166,0.4363299813121557,0.32610452699813897,11.973353397051905,10,float32,cuda,50 +cuda_v4,1,3,128,256,2,3,0.34459170885384083,0.34502334892749786,0.3664530348032713,0.3803689892480681,13.965743963966949,10,float32,cuda,50 +pytorch,1,3,128,256,2,5,3.534023268148303,1.3921631034463644,7.826935639604926,0.03708860696570254,,10,float32,cuda,50 +cuda_v1,1,3,128,256,2,5,0.4526952747255564,0.45057223178446293,0.4625048488378525,0.28953692984637747,7.806627251169746,10,float32,cuda,50 +cuda_v2,1,3,128,256,2,5,0.35566513426601887,0.34616305492818356,0.3815658390522003,0.36852642379605594,9.936378148061708,10,float32,cuda,50 +cuda_v3,1,3,128,256,2,5,0.41979328729212284,0.41792611591517925,0.45677535235881805,0.3122298616194654,8.418484466353727,10,float32,cuda,50 +cuda_v4,1,3,128,256,2,5,0.35099745728075504,0.343183521181345,0.3787568770349026,0.37342720661123885,10.068515297880111,10,float32,cuda,50 +pytorch,1,3,128,256,2,7,5.221625966951251,1.6814591363072395,8.201262401416898,0.025101759649117296,,10,float32,cuda,50 +cuda_v1,1,3,128,256,2,7,0.5324110481888056,0.5313355941325426,0.5536912009119987,0.24618572519464088,9.80750866217851,10,float32,cuda,50 +cuda_v2,1,3,128,256,2,7,0.35434636287391186,0.3511281684041023,0.36454498767852783,0.3698979691422422,14.735937811246218,10,float32,cuda,50 +cuda_v3,1,3,128,256,2,7,0.3654781263321638,0.3502380568534136,0.43992577120661736,0.3586315857405801,14.287109380126324,10,float32,cuda,50 +cuda_v4,1,3,128,256,2,7,0.3121230937540531,0.3096370492130518,0.31899111345410347,0.41993688587260436,16.729380399727134,10,float32,cuda,50 +pytorch,1,3,128,256,3,3,10.960625801235437,2.763780066743493,45.61858847737312,0.026906492872583856,,10,float32,cuda,50 +cuda_v1,1,3,128,256,3,3,0.46120816841721535,0.45501673594117165,0.4788396880030632,0.6394336011265492,23.765029658625433,10,float32,cuda,50 +cuda_v2,1,3,128,256,3,3,0.36482485942542553,0.361383892595768,0.3778459504246712,0.8083659662460131,30.04352778617576,10,float32,cuda,50 +cuda_v3,1,3,128,256,3,3,0.4006952326744795,0.3770939074456692,0.4836510866880417,0.7360007705396968,27.354020980179047,10,float32,cuda,50 +cuda_v4,1,3,128,256,3,3,0.32492942176759243,0.320731895044446,0.33880281262099743,0.9076186403671916,33.732327905581556,10,float32,cuda,50 +pytorch,1,3,128,256,3,5,11.318621216341853,2.605273388326168,45.752703258767724,0.026055470393708847,,10,float32,cuda,50 +cuda_v1,1,3,128,256,3,5,0.4699833784252405,0.4665427841246128,0.48548299819231033,0.6274945318027054,24.083024498157414,10,float32,cuda,50 +cuda_v2,1,3,128,256,3,5,0.4197291377931833,0.42938650585711,0.4759266972541809,0.7026245581866524,26.96648909306595,10,float32,cuda,50 +cuda_v3,1,3,128,256,3,5,0.3752768971025944,0.3709190059453249,0.4119148012250662,0.7858517331520571,30.160719468023988,10,float32,cuda,50 +cuda_v4,1,3,128,256,3,5,0.36198515444993973,0.35343365743756294,0.3866641316562891,0.8147074441440512,31.26819173991056,10,float32,cuda,50 +pytorch,1,3,128,256,3,7,7.481619408354163,2.37233005464077,25.27893357910216,0.039418203988122395,,10,float32,cuda,50 +cuda_v1,1,3,128,256,3,7,0.4637504182755947,0.46098814345896244,0.47804401256144047,0.6359282673999478,16.132857488676233,10,float32,cuda,50 +cuda_v2,1,3,128,256,3,7,0.3909336030483246,0.3640139475464821,0.45265606604516506,0.7543787428361459,19.137826347021225,10,float32,cuda,50 
+cuda_v3,1,3,128,256,3,7,0.3465086594223976,0.34476793371140957,0.35570100881159306,0.8510956132859561,21.59143561036953,10,float32,cuda,50 +cuda_v4,1,3,128,256,3,7,0.32119077630341053,0.3183919470757246,0.33296276815235615,0.9181832784681636,23.293381878708452,10,float32,cuda,50 +pytorch,1,3,256,128,1,3,3.82584142498672,1.0650705080479383,7.791366055607796,0.008564913272670139,,10,float32,cuda,50 +cuda_v1,1,3,256,128,1,3,0.48191010020673275,0.4935292527079582,0.54588015191257,0.06799608471775748,7.938911061099336,10,float32,cuda,50 +cuda_v2,1,3,256,128,1,3,0.31497727148234844,0.31047710217535496,0.33008200116455555,0.10403290321802264,12.146404745274209,10,float32,cuda,50 +cuda_v3,1,3,256,128,1,3,0.2761990111321211,0.2733948640525341,0.28623687103390694,0.11863909239097628,13.85175641760937,10,float32,cuda,50 +cuda_v4,1,3,256,128,1,3,0.2851138450205326,0.28066104277968407,0.2989410888403654,0.11492952928203187,13.418644838910577,10,float32,cuda,50 +pytorch,1,3,256,128,1,5,3.586227549239993,0.8654326666146517,11.385623132809997,0.009137178148928202,,10,float32,cuda,50 +cuda_v1,1,3,256,128,1,5,0.423511927947402,0.4099758807569742,0.48163579776883125,0.07737208290404897,8.46783127601871,10,float32,cuda,50 +cuda_v2,1,3,256,128,1,5,0.2990085817873478,0.296260928735137,0.30720722861588,0.10958882786616574,11.993727831499116,10,float32,cuda,50 +cuda_v3,1,3,256,128,1,5,0.275130495429039,0.27071102522313595,0.28432300314307213,0.1190998473248177,13.034642138261058,10,float32,cuda,50 +cuda_v4,1,3,256,128,1,5,0.27801617980003357,0.2752654254436493,0.28759418055415154,0.11786364384824212,12.899348346630148,10,float32,cuda,50 +pytorch,1,3,256,128,1,7,3.53361826390028,1.0452242568135262,5.550267640501261,0.009273214465399514,,10,float32,cuda,50 +cuda_v1,1,3,256,128,1,7,0.4154033772647381,0.412175664678216,0.4310780204832554,0.07888236300764795,8.50647456736562,10,float32,cuda,50 +cuda_v2,1,3,256,128,1,7,0.300332996994257,0.29647164046764374,0.3192121163010597,0.10910556058755874,11.765667772988163,10,float32,cuda,50 +cuda_v3,1,3,256,128,1,7,0.35520353354513645,0.35250792279839516,0.39303875528275967,0.09225133453194125,9.948150652198553,10,float32,cuda,50 +cuda_v4,1,3,256,128,1,7,0.28812913224101067,0.28413604013621807,0.30301674269139767,0.11372678543518687,12.264008975477436,10,float32,cuda,50 +pytorch,1,3,256,128,2,3,1.6048630606383085,1.2890337966382504,1.585709908977151,0.08167176578160394,,10,float32,cuda,50 +cuda_v1,1,3,256,128,2,3,0.49104482866823673,0.45836716890335083,0.6075259298086166,0.2669247130765648,3.2682618102116248,10,float32,cuda,50 +cuda_v2,1,3,256,128,2,3,0.3567333798855543,0.35033351741731167,0.3854172769933939,0.367422863658147,4.498774578238723,10,float32,cuda,50 +cuda_v3,1,3,256,128,2,3,0.35774463787674904,0.3515880089253187,0.385533319786191,0.36638424765197236,4.486057625247256,10,float32,cuda,50 +cuda_v4,1,3,256,128,2,3,0.3179950825870037,0.307686161249876,0.3502919338643551,0.4121824744385429,5.046817226172725,10,float32,cuda,50 +pytorch,1,3,256,128,2,5,5.680167442187667,1.439184183254838,13.859670702368021,0.02307537609305383,,10,float32,cuda,50 +cuda_v1,1,3,256,128,2,5,0.4627335909754038,0.46338303945958614,0.48728715628385544,0.2832558572713755,12.2752433645769,10,float32,cuda,50 +cuda_v2,1,3,256,128,2,5,0.3555159270763397,0.35281339660286903,0.3653420601040125,0.36868109138709554,15.977251677300995,10,float32,cuda,50 +cuda_v3,1,3,256,128,2,5,0.3210102953016758,0.31748763285577297,0.3332026768475771,0.4083108919507472,17.694658163064886,10,float32,cuda,50 
+cuda_v4,1,3,256,128,2,5,0.32647970132529736,0.31356699764728546,0.3721813205629587,0.40147059516390177,17.398225430646516,10,float32,cuda,50 +pytorch,1,3,256,128,2,7,4.92123176343739,1.4630758669227362,18.036476150155067,0.026633982364701436,,10,float32,cuda,50 +cuda_v1,1,3,256,128,2,7,0.44690595008432865,0.4451470449566841,0.45659723691642284,0.29328765923852085,11.011783939123616,10,float32,cuda,50 +cuda_v2,1,3,256,128,2,7,0.35398226231336594,0.351473456248641,0.3692640457302332,0.37027844034729446,13.902481246590899,10,float32,cuda,50 +cuda_v3,1,3,256,128,2,7,0.3210613317787647,0.31872186809778214,0.33104592002928257,0.4082459861292746,15.32801143062749,10,float32,cuda,50 +cuda_v4,1,3,256,128,2,7,0.32810162752866745,0.32733287662267685,0.3445217851549387,0.3994859793511623,14.99910805229823,10,float32,cuda,50 +pytorch,1,3,256,128,3,3,9.371620612218976,2.651255577802658,27.0526010543108,0.03146862343269484,,10,float32,cuda,50 +cuda_v1,1,3,256,128,3,3,0.5212203320115805,0.5533958319574594,0.5914739333093166,0.5658106215116867,17.980151649208416,10,float32,cuda,50 +cuda_v2,1,3,256,128,3,3,0.3939199075102806,0.38670445792376995,0.4216096829622984,0.7486598021002614,23.790675296029292,10,float32,cuda,50 +cuda_v3,1,3,256,128,3,3,0.39537341333925724,0.39467494934797287,0.4045611247420311,0.745907514390568,23.7032139644086,10,float32,cuda,50 +cuda_v4,1,3,256,128,3,3,0.3381696157157421,0.32397685572505,0.3813320305198431,0.8720830798941337,27.712781328339542,10,float32,cuda,50 +pytorch,1,3,256,128,3,5,6.334149120375514,2.5149499997496605,9.2535394243896,0.04655905543040269,,10,float32,cuda,50 +cuda_v1,1,3,256,128,3,5,0.4680374823510647,0.46257232315838337,0.49366913735866547,0.6301033808629732,13.533422768957653,10,float32,cuda,50 +cuda_v2,1,3,256,128,3,5,0.3780175279825926,0.38000987842679024,0.39740419015288353,0.7801543001825579,16.756231263083645,10,float32,cuda,50 +cuda_v3,1,3,256,128,3,5,0.37817317992448807,0.36235409788787365,0.45027188025414944,0.7798331972110945,16.749334581686327,10,float32,cuda,50 +cuda_v4,1,3,256,128,3,5,0.34465146251022816,0.34011295065283775,0.35576531663537025,0.8556818469651727,18.378419387056965,10,float32,cuda,50 +pytorch,1,3,256,128,3,7,9.138631783425808,3.381723305210471,33.623141143471,0.03227091396054104,,10,float32,cuda,50 +cuda_v1,1,3,256,128,3,7,0.4677951242774725,0.4621578846126795,0.49515804275870323,0.6304298285611738,19.53554303829219,10,float32,cuda,50 +cuda_v2,1,3,256,128,3,7,0.3610655106604099,0.35483832471072674,0.38772691041231155,0.8167825264190665,25.310176446126675,10,float32,cuda,50 +cuda_v3,1,3,256,128,3,7,0.3526437934488058,0.32708211801946163,0.39914101362228394,0.8362886444584856,25.91462533354493,10,float32,cuda,50 +cuda_v4,1,3,256,128,3,7,0.32145872712135315,0.3138268366456032,0.3484130371361971,0.9174179299498951,28.428631772612924,10,float32,cuda,50 +pytorch,1,3,256,256,1,3,4.861515955999494,1.0742205195128918,14.73111561499536,0.013480568734763365,,10,float32,cuda,50 +cuda_v1,1,3,256,256,1,3,0.4849292803555727,0.48372894525527954,0.49910699017345905,0.13514547925822495,10.025206051560351,10,float32,cuda,50 +cuda_v2,1,3,256,256,1,3,0.3039980586618185,0.2988413907587528,0.32487385906279087,0.21558032406024435,15.99193092679473,10,float32,cuda,50 +cuda_v3,1,3,256,256,1,3,0.2766306512057781,0.273675424978137,0.2857776824384928,0.23690794824919645,17.574032142892012,10,float32,cuda,50 +cuda_v4,1,3,256,256,1,3,0.2745348773896694,0.26938505470752716,0.2950329799205065,0.23871648157468708,17.708190675894173,10,float32,cuda,50 
+pytorch,1,3,256,256,1,5,1.1926674656569958,1.1049916502088308,1.324015948921442,0.05494909678273036,,10,float32,cuda,50 +cuda_v1,1,3,256,256,1,5,0.4482492245733738,0.419301213696599,0.5197371356189251,0.14620438007979739,2.660723990748964,10,float32,cuda,50 +cuda_v2,1,3,256,256,1,5,0.31799268908798695,0.2957459073513746,0.39233872666954994,0.2060927884473046,3.7506128492375206,10,float32,cuda,50 +cuda_v3,1,3,256,256,1,5,0.27819630689918995,0.2757101319730282,0.28871200047433376,0.23557465852250975,4.287143416642058,10,float32,cuda,50 +cuda_v4,1,3,256,256,1,5,0.27723253704607487,0.2712407149374485,0.302234198898077,0.23639360912787882,4.302047221314356,10,float32,cuda,50 +pytorch,1,3,256,256,1,7,1.1875793617218733,1.0118531063199043,1.2298297137022018,0.05518452249370455,,10,float32,cuda,50 +cuda_v1,1,3,256,256,1,7,0.4127761535346508,0.4072303418070078,0.43675173074007034,0.15876886161859766,2.87705418918339,10,float32,cuda,50 +cuda_v2,1,3,256,256,1,7,0.29986392706632614,0.29469607397913933,0.319720059633255,0.21855246358293792,3.9603942139368957,10,float32,cuda,50 +cuda_v3,1,3,256,256,1,7,0.3060135804116726,0.2915910445153713,0.36104372702538967,0.21416043010848088,3.8808060744371273,10,float32,cuda,50 +cuda_v4,1,3,256,256,1,7,0.2705792896449566,0.2661552280187607,0.29452070593833923,0.242206268210674,4.389025351053911,10,float32,cuda,50 +pytorch,1,3,256,256,2,3,4.6642600279301405,2.254175953567028,5.693626776337624,0.05620269848384325,,10,float32,cuda,50 +cuda_v1,1,3,256,256,2,3,0.47030373476445675,0.4551168531179428,0.5582175217568874,0.5573929795205436,9.917548348337382,10,float32,cuda,50 +cuda_v2,1,3,256,256,2,3,0.3628383856266737,0.35751843824982643,0.3879097755998373,0.722481441833228,12.854924431091538,10,float32,cuda,50 +cuda_v3,1,3,256,256,2,3,0.3875340800732374,0.3856392577290535,0.41215093806385994,0.6764411531250598,12.035741545746568,10,float32,cuda,50 +cuda_v4,1,3,256,256,2,3,0.33060619607567787,0.3230019938200712,0.3539799712598324,0.7929191984653353,14.108205119248465,10,float32,cuda,50 +pytorch,1,3,256,256,2,5,8.932606596499681,2.1693871822208166,25.42668771930039,0.02934686501280862,,10,float32,cuda,50 +cuda_v1,1,3,256,256,2,5,0.5238902755081654,0.5380609072744846,0.5635851062834263,0.5003795875877337,17.05052949844353,10,float32,cuda,50 +cuda_v2,1,3,256,256,2,5,0.37138083949685097,0.36642886698246,0.3967938479036093,0.7058630174759535,24.052416405223358,10,float32,cuda,50 +cuda_v3,1,3,256,256,2,5,0.3278264496475458,0.3224520478397608,0.3462827764451504,0.7996426166401076,27.24797406098057,10,float32,cuda,50 +cuda_v4,1,3,256,256,2,5,0.32294890843331814,0.3143919166177511,0.3508441150188446,0.811719727655085,27.65950391296668,10,float32,cuda,50 +pytorch,1,3,256,256,2,7,5.310848616063595,2.0857034251093864,5.400367686524987,0.049360096465016795,,10,float32,cuda,50 +cuda_v1,1,3,256,256,2,7,0.5602501425892115,0.5671817343682051,0.6339772138744593,0.4679052802887194,9.479423943596625,10,float32,cuda,50 +cuda_v2,1,3,256,256,2,7,0.42169813998043537,0.4308209754526615,0.4505240358412266,0.6216389761931654,12.593957887293481,10,float32,cuda,50 +cuda_v3,1,3,256,256,2,7,0.38437320850789547,0.3447979688644409,0.5109140183776617,0.6820038290848132,13.816906325703249,10,float32,cuda,50 +cuda_v4,1,3,256,256,2,7,0.3148854896426201,0.31107640825212,0.32421983778476715,0.832505811231635,16.865968076493942,10,float32,cuda,50 +pytorch,1,3,256,256,3,3,10.819459995254874,5.247258115559816,43.626357009634376,0.054515105214001526,,10,float32,cuda,50 
+cuda_v1,1,3,256,256,3,3,0.5695007462054491,0.5678911693394184,0.5740981083363295,1.0356860880867385,18.998148935439186,10,float32,cuda,50 +cuda_v2,1,3,256,256,3,3,0.47695184126496315,0.47551305033266544,0.48432392068207264,1.2366531565025085,22.68459634532429,10,float32,cuda,50 +cuda_v3,1,3,256,256,3,3,0.4476183373481035,0.44674682430922985,0.4503197968006134,1.3176940057781994,24.171172386176853,10,float32,cuda,50 +cuda_v4,1,3,256,256,3,3,0.43853784911334515,0.4363264888525009,0.4481939598917961,1.3449785490409363,24.671667490343484,10,float32,cuda,50 +pytorch,1,3,256,256,3,5,16.53141546063125,6.35837041772902,54.64553306810558,0.03567897748409002,,10,float32,cuda,50 +cuda_v1,1,3,256,256,3,5,0.5723217409104109,0.5713619757443666,0.5795460194349289,1.0305811536387692,28.88482872297353,10,float32,cuda,50 +cuda_v2,1,3,256,256,3,5,0.4708682466298342,0.46868249773979187,0.4833988845348358,1.2526306545866557,35.10836752945706,10,float32,cuda,50 +cuda_v3,1,3,256,256,3,5,0.4524185135960579,0.4476869944483042,0.4720529541373253,1.3037132263040514,36.540094986898204,10,float32,cuda,50 +cuda_v4,1,3,256,256,3,5,0.47651665285229683,0.4608971066772938,0.5228085909038782,1.2377825548582126,34.69220931037514,10,float32,cuda,50 +pytorch,1,3,256,256,3,7,18.527234653010964,7.752042729407549,61.81915830820799,0.03183551193940022,,10,float32,cuda,50 +cuda_v1,1,3,256,256,3,7,0.5787044391036034,0.5703659262508154,0.5973172839730978,1.0192145768116458,32.015020796642105,10,float32,cuda,50 +cuda_v2,1,3,256,256,3,7,0.4701301362365484,0.46895304694771767,0.47616218216717243,1.2545973009125009,39.40873648586716,10,float32,cuda,50 +cuda_v3,1,3,256,256,3,7,0.4477470647543669,0.44500199146568775,0.46136612072587013,1.3173151683832396,41.3787964487829,10,float32,cuda,50 +cuda_v4,1,3,256,256,3,7,0.43689489364624023,0.43615163303911686,0.44089406728744507,1.3500363784924174,42.40661752392263,10,float32,cuda,50 \ No newline at end of file diff --git a/test/test_speed.py b/test/test_speed.py index 841c095..5ff90ce 100644 --- a/test/test_speed.py +++ b/test/test_speed.py @@ -1,176 +1,238 @@ -import time -import math +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os, sys, time, json, argparse, pathlib, statistics, subprocess, csv, re import torch -from util_converse import Converse2D - -# -------------------------- -# Config -# -------------------------- -device = "cuda" if torch.cuda.is_available() else "cpu" -DTYPES = [torch.float32] -USE_AUTOCast = False -WARMUP = 10 -ITERS = 50 -TRAIN = True -INFER = True -CASES = [ - # (B, C, H, W, scale, ksize, padding, padding_mode) - (1, 3, 128, 128, 2, 5, 2, "circular"), - (2, 3, 256, 256, 2, 5, 2, "circular"), - (4, 3, 256, 256, 2, 5, 2, "circular"), - (2, 8, 256, 256, 2, 5, 2, "circular"), - (1, 3, 512, 512, 2, 5, 2, "circular"), # 320 -] +import torch.nn.functional as F + +ROOT = pathlib.Path(__file__).resolve().parent +MODELS = ROOT / "models" +BACKEND = MODELS / "backend" +# ----------------------------- +# Common utils +# ----------------------------- def synchronize(): if torch.cuda.is_available(): torch.cuda.synchronize() -def timed_run(fn, warmup=WARMUP, iters=ITERS): - for _ in range(warmup): - fn() +def timed_run(fn, warmup, iters): + for _ in range(warmup): fn() synchronize() - t0 = time.perf_counter() + times=[] for _ in range(iters): - fn() - synchronize() - t1 = time.perf_counter() - return (t1 - t0) / iters - -def make_model(C, scale, ksize=5, padding=2, padding_mode="circular", dtype=torch.float32): - m = Converse2D( - in_channels=C, out_channels=C, 
kernel_size=ksize, - scale=scale, padding=padding, padding_mode=padding_mode, - eps=1e-5, backend="pytorch" - ).to(device=device, dtype=dtype) - m.eval() - return m - -def clone_as_cuda_backend(m): - m2 = Converse2D( - in_channels=m.in_channels, out_channels=m.out_channels, - kernel_size=m.kernel_size, scale=m.scale, padding=m.padding, - padding_mode=m.padding_mode, eps=m.eps, backend="cuda" - ).to(device=device, dtype=next(m.parameters()).dtype) - m2.load_state_dict(m.state_dict()) - m2.eval() - return m2 - -def tp_gpix_per_s(B,H,W,s,t): - if t is None or t <= 0: return None - return (B * (H*s) * (W*s) / t) / 1e9 - -def speedup_and_pct(t_py, t_cu): - if t_py and t_cu and t_py > 0 and t_cu > 0: - sp = t_py / t_cu - pct = (t_py - t_cu) / t_py * 100.0 - return sp, pct - return None, None - -def fmt_ms(t): return "-" if t is None else f"{t*1e3:7.2f}" -def fmt_tp(x): return "-" if x is None else f"{x:6.3f}" -def fmt_sp(x): return "-" if x is None else f"{x:5.2f}×" -def fmt_pct(p): return "-" if p is None else f"{p:6.1f}%" - -def geom_mean(vals): - vals = [v for v in vals if v and v > 0] - if not vals: return None - return math.exp(sum(math.log(v) for v in vals) / len(vals)) - -def run_case(B,C,H,W,scale,ksize,padding,padding_mode,dtype): - x = torch.randn(B, C, H, W, device=device, dtype=dtype, requires_grad=TRAIN) - - torch.manual_seed(0); m_py = make_model(C, scale, ksize, padding, padding_mode, dtype) - torch.manual_seed(0); m_cu = clone_as_cuda_backend(m_py) - - fwd_py = fwd_cu = None - if INFER: - def fwd_run(m): - def _call(): - with torch.no_grad(): - if USE_AUTOCast and dtype is torch.bfloat16 and device == "cuda": - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - _ = m(x) - else: - _ = m(x) - return _call - fwd_py = timed_run(fwd_run(m_py)) - fwd_cu = timed_run(fwd_run(m_cu)) - - bwd_py = bwd_cu = None - if TRAIN: - def train_run(m): - def _call(): - x_local = x.detach().clone().requires_grad_(True) - if USE_AUTOCast and dtype is torch.bfloat16 and device == "cuda": - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - y = m(x_local); loss = y.square().mean() - else: - y = m(x_local); loss = y.square().mean() - loss.backward() - return _call - bwd_py = timed_run(train_run(m_py)) - bwd_cu = timed_run(train_run(m_cu)) - - fwd_tp_py = tp_gpix_per_s(B,H,W,scale,fwd_py) - fwd_tp_cu = tp_gpix_per_s(B,H,W,scale,fwd_cu) - bwd_tp_py = tp_gpix_per_s(B,H,W,scale,bwd_py) - bwd_tp_cu = tp_gpix_per_s(B,H,W,scale,bwd_cu) - - fwd_sp, fwd_pct = speedup_and_pct(fwd_py, fwd_cu) - bwd_sp, bwd_pct = speedup_and_pct(bwd_py, bwd_cu) - + t0=time.perf_counter(); fn(); synchronize() + times.append((time.perf_counter()-t0)*1000.0) return { - "shape": (B,C,H,W,scale,ksize,padding,padding_mode,str(dtype).split('.')[-1]), - "fwd_py": fwd_py, "fwd_cu": fwd_cu, "fwd_tp_py": fwd_tp_py, "fwd_tp_cu": fwd_tp_cu, - "bwd_py": bwd_py, "bwd_cu": bwd_cu, "bwd_tp_py": bwd_tp_py, "bwd_tp_cu": bwd_tp_cu, - "fwd_sp": fwd_sp, "fwd_pct": fwd_pct, "bwd_sp": bwd_sp, "bwd_pct": bwd_pct + "mean_ms": statistics.mean(times), + "p50_ms": statistics.median(times), + "p90_ms": statistics.quantiles(times, n=10)[8] if len(times)>=10 else statistics.median_high(times), } -def main(): +def tp_gpix_per_s(B,H,W,s,mean_ms): + if mean_ms<=0: return None + return (B*(H*s)*(W*s)/(mean_ms/1e3))/1e9 + +def to_dtype(name): + name=name.lower() + if name in ("fp16","half","float16"): return torch.float16 + if name in ("bf16","bfloat16"): return torch.bfloat16 + if name in ("fp32","float32","float"):return torch.float32 + raise ValueError(name) + 
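+# Worked example for tp_gpix_per_s (hypothetical numbers): B=1, H=W=256, s=2
+# gives 512*512 = 262144 output pixels; a mean of 0.5 ms yields
+# (262144 / 5e-4) / 1e9 ≈ 0.52 Gpix/s.
+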
+# -----------------------------
+# Subprocess runner
+# -----------------------------
+def _parse_last_json_from_text(txt: str):
+    # Extract the last {...} JSON object; tolerate all other output (incl. build logs)
+    m = re.findall(r"\{.*\}", txt, flags=re.S)
+    if not m:
+        tail = txt[-2000:] if len(txt) > 2000 else txt
+        raise RuntimeError("Child produced no JSON. Tail of output:\n" + tail)
+    return json.loads(m[-1])
+
+
+def run_variant_subprocess(variant, case_args, cache_root):
+    """
+    Run one variant (pytorch / cuda_v1..v4) in a subprocess and return its JSON result.
+    stdout and stderr are merged; the last JSON object is parsed.
+    """
+    cmd = [
+        sys.executable, __file__, "--worker",
+        "--variant", variant,
+        "--B", str(case_args["B"]), "--C", str(case_args["C"]),
+        "--H", str(case_args["H"]), "--W", str(case_args["W"]),
+        "--scale", str(case_args["scale"]), "--ksize", str(case_args["ksize"]),
+        "--warmup", str(case_args["warmup"]), "--iters", str(case_args["iters"]),
+        "--dtype", case_args["dtype"], "--device", case_args["device"],
+    ]
+    env = os.environ.copy()
+    env["TORCH_EXTENSIONS_DIR"] = str(pathlib.Path(cache_root) / variant)
+    env.setdefault("PYTHONWARNINGS", "ignore")
+    env.setdefault("TORCH_SHOW_CPP_STACKTRACES", "0")
+
+    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, env=env)
+    out = proc.stdout
+    if proc.returncode != 0:
+        raise RuntimeError(f"Subprocess failed (variant={variant}). Output:\n{out}")
+    return _parse_last_json_from_text(out)
+
+# -----------------------------
+# Worker: run in subprocess
+# -----------------------------
+def worker_main(args):
+    device = "cuda" if (args.device=="cuda" and torch.cuda.is_available()) else "cpu"
+    dtype  = to_dtype(args.dtype)
+    B,C,H,W,s,k = args.B,args.C,args.H,args.W,args.scale,args.ksize
+
+    if args.variant == "pytorch":
+        from models.util_converse import Converse2D
+        torch.manual_seed(0)
+        x = torch.randn(B,C,H,W, device=device, dtype=dtype)
+        m = Converse2D(C, C, kernel_size=k, scale=s, padding=k//2,
+                       padding_mode="circular", eps=1e-5, backend="pytorch").to(device=device, dtype=dtype)
+        m.eval()
+        def call():
+            with torch.no_grad():
+                _ = m(x)
+        stat = timed_run(call, args.warmup, args.iters)
+        stat["tp"] = tp_gpix_per_s(B,H,W,s,stat["mean_ms"])
+        print(json.dumps({"variant":"pytorch", **stat}))
+        return
+
+    from torch.utils.cpp_extension import load
+    vnum = int(args.variant.split("_v")[1])
+    cpp = BACKEND / f"converse2d_v{vnum}.cpp"
+    cu  = BACKEND / f"converse2d_v{vnum}.cu"
+    sources = [str(cpp)]
+    if cu.exists(): sources.append(str(cu))
+
+    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9"  # RTX 4090 (sm_89)
+    load(
+        name=f"converse2d_v{vnum}_ext",
+        sources=sources,
+        verbose=False,
+        extra_cflags=["-O3"],
+        extra_cuda_cflags=(["-O3","-gencode=arch=compute_89,code=sm_89"] if cu.exists() else [])
+    )
+
+    torch.manual_seed(0)
+    x = torch.randn(B,C,H,W, device=device, dtype=dtype)
+    x0 = x if s==1 else F.interpolate(x, scale_factor=s, mode="nearest")
+    weight = torch.randn(1,C,k,k, device=device, dtype=dtype)
+    weight = torch.softmax(weight.view(1,C,-1), dim=-1).view(1,C,k,k).contiguous()
+    bias = torch.zeros(1,C,1,1, device=device, dtype=dtype)
+
+    converse2d_forward = torch.ops.converse2d.forward
+    def call():
+        with torch.no_grad():
+            _ = converse2d_forward(x, x0, weight, bias, int(s), float(1e-5))
+    stat = timed_run(call, args.warmup, args.iters)
+    stat["tp"] = tp_gpix_per_s(B,H,W,s,stat["mean_ms"])
+
+    try: torch.ops.converse2d.clear_cache()
+    except Exception: pass
+
+    print(json.dumps({"variant": args.variant, **stat}))
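+
+# On success a worker prints exactly one JSON line, e.g. (illustrative values):
+#   {"variant": "cuda_v3", "mean_ms": 0.328, "p50_ms": 0.322, "p90_ms": 0.346, "tp": 0.799}
+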
+# -----------------------------
+# Orchestrator (parent)
+# -----------------------------
+def parent_main(args):
+    device = "cuda" if (args.device=="cuda" and torch.cuda.is_available()) else "cpu"
+    Bs = [int(x) for x in args.B_list.split(",")]
+    Cs = [int(x) for x in args.C_list.split(",")]
+    Hs = [int(x) for x in args.H_list.split(",")]
+    Ws = [int(x) for x in args.W_list.split(",")]
+    Ss = [int(x) for x in args.scale_list.split(",")]
+    Ks = [int(x) for x in args.ksize_list.split(",")]
+
     print(f"[Env] device={device}, torch={torch.__version__}, cuda={torch.version.cuda}, cudnn={torch.backends.cudnn.version()}")
-    print(f"[Cfg] dtypes={[d.__str__() for d in DTYPES]}, AMP={USE_AUTOCast}, TRAIN={TRAIN}, INFER={INFER}, warmup={WARMUP}, iters={ITERS}\n")
-
-    rows = []
-    for dtype in DTYPES:
-        for (B,C,H,W,s,ks,pd,pm) in CASES:
-            rows.append(run_case(B,C,H,W,s,ks,pd,pm,dtype))
-
-    print("=== Per‑case Comparison (CUDA vs PyTorch) ===")
-    for r in rows:
-        B,C,H,W,s,ks,pd,pm,dtype = r["shape"]
-        tag = f"[{dtype}] B{B} C{C} {H}x{W} s{s} k{ks}"
-        # forward
-        if INFER:
-            print(f"{tag} | Forward : Py {fmt_ms(r['fwd_py'])} ms ({fmt_tp(r['fwd_tp_py'])} Gpix/s) "
-                  f"vs CUDA {fmt_ms(r['fwd_cu'])} ms ({fmt_tp(r['fwd_tp_cu'])} Gpix/s) "
-                  f"-> CUDA is {fmt_sp(r['fwd_sp'])} faster ({fmt_pct(r['fwd_pct'])} time saved)")
-        # backward
-        if TRAIN:
-            print(f"{tag} | Train : Py {fmt_ms(r['bwd_py'])} ms ({fmt_tp(r['bwd_tp_py'])} Gpix/s) "
-                  f"vs CUDA {fmt_ms(r['bwd_cu'])} ms ({fmt_tp(r['bwd_tp_cu'])} Gpix/s) "
-                  f"-> CUDA is {fmt_sp(r['bwd_sp'])} faster ({fmt_pct(r['bwd_pct'])} time saved)")
-        print("")
-
-    hdr = ("dtype   B   C     H     W  s   k | fwd_py(ms) fwd_cu(ms) fwd_Gpix/s(py) fwd_Gpix/s(cu) fwd_speedup "
-           "| bwd_py(ms) bwd_cu(ms) bwd_Gpix/s(py) bwd_Gpix/s(cu) bwd_speedup")
-    print(hdr); print("-"*len(hdr))
-    for r in rows:
-        B,C,H,W,s,ks,pd,pm,dtype = r["shape"]
-        line = (f"{dtype:6s} {B:3d} {C:3d} {H:5d} {W:5d} {s:2d} {ks:3d} | "
-                f"{fmt_ms(r['fwd_py'])} {fmt_ms(r['fwd_cu'])} {fmt_tp(r['fwd_tp_py'])} {fmt_tp(r['fwd_tp_cu'])} {fmt_sp(r['fwd_sp'])} | "
-                f"{fmt_ms(r['bwd_py'])} {fmt_ms(r['bwd_cu'])} {fmt_tp(r['bwd_tp_py'])} {fmt_tp(r['bwd_tp_cu'])} {fmt_sp(r['bwd_sp'])}")
-        print(line)
-
-    fwd_sps = [r["fwd_sp"] for r in rows if r["fwd_sp"]]
-    bwd_sps = [r["bwd_sp"] for r in rows if r["bwd_sp"]]
-    gm_fwd = geom_mean(fwd_sps)
-    gm_bwd = geom_mean(bwd_sps)
-    if gm_fwd:
-        print(f"\nOverall Forward Geomean Speedup : {gm_fwd:.2f}× (CUDA vs PyTorch)")
-    if gm_bwd:
-        print(f"Overall Train Geomean Speedup : {gm_bwd:.2f}× (CUDA vs PyTorch)")
+    print(f"[Cfg] dtype={args.dtype}, warmup={args.warmup}, iters={args.iters}")
+    print(f"[Grid] B={Bs} C={Cs} H={Hs} W={Ws} scale={Ss} ksize={Ks}\n")
+
+    variants = ["pytorch", "cuda_v1", "cuda_v2", "cuda_v3", "cuda_v4"]
+    results = []
+    cache_root = ROOT / ".torch_ext_cache_grid"
+    cache_root.mkdir(exist_ok=True)
+
+    for B in Bs:
+        for C in Cs:
+            for H in Hs:
+                for W in Ws:
+                    for s in Ss:
+                        for k in Ks:
+                            case = dict(B=B,C=C,H=H,W=W,scale=s,ksize=k,
+                                        warmup=args.warmup,iters=args.iters,
+                                        dtype=args.dtype,device=device)
+                            base = run_variant_subprocess("pytorch", case, cache_root)
+                            base_mean = base["mean_ms"]
+                            results.append({**case,"variant":"pytorch",**base})
+
+                            print(f"[Case] B{B} C{C} {H}x{W} s{s} k{k}")
+                            print(f"  PyTorch : {base_mean:.3f} ms")
+
+                            for v in variants[1:]:
+                                r = run_variant_subprocess(v, case, cache_root)
+                                sp = base_mean / r["mean_ms"] if r["mean_ms"]>0 else None
+                                results.append({**case, "variant":v, **r, "speedup_vs_pytorch": sp})
+                                print(f"  {v:8s}: {r['mean_ms']:.3f} ms  ({sp:.2f}x vs PyTorch)")
+                            print("")
+
+    # Summary/CSV column schema: one row per variant per case, including the
+    # run configuration fields.
+    header = ["variant","B","C","H","W","scale","ksize","mean_ms","p50_ms","p90_ms","tp","speedup_vs_pytorch","warmup","dtype","device","iters"]
+
+    print("\n=== Summary (normalized to PyTorch) ===")
+    print(" | ".join(h.rjust(10) for h in header))
+    print("-"*120)
+    for r in results:
+        line=[]
+        for h in header:
+            v = r.get(h,"")
+            if isinstance(v,float):
+                line.append(f"{v:10.3f}")
+            else:
+                line.append(str(v).rjust(10))
+        print(" | ".join(line))
+
+    if args.csv:
+        with open(args.csv,"w",newline="") as f:
+            w=csv.DictWriter(f, fieldnames=header); w.writeheader(); w.writerows(results)
+        print(f"\n[Saved] {args.csv}")
+
+# -----------------------------
+# CLI
+# -----------------------------
+def main():
+    p = argparse.ArgumentParser()
+    p.add_argument("--worker", action="store_true", help="internal")
+    p.add_argument("--variant", default="")
+    p.add_argument("--B", type=int, default=2)
+    p.add_argument("--C", type=int, default=16)
+    p.add_argument("--H", type=int, default=128)
+    p.add_argument("--W", type=int, default=128)
+    p.add_argument("--scale", type=int, default=2)
+    p.add_argument("--ksize", type=int, default=5)
+    p.add_argument("--warmup", type=int, default=10)
+    p.add_argument("--iters", type=int, default=50)
+    p.add_argument("--dtype", default="float32", choices=["float16","bfloat16","float32"])
+    p.add_argument("--device", default="cuda")
+    # grid
+    p.add_argument("--B_list", default="1")
+    p.add_argument("--C_list", default="3")
+    p.add_argument("--H_list", default="128,256")
+    p.add_argument("--W_list", default="128,256")
+    p.add_argument("--scale_list", default="1,2,3")
+    p.add_argument("--ksize_list", default="3,5,7")
+    p.add_argument("--csv", default="")
+    args = p.parse_args()
+
+    if args.worker:
+        worker_main(args)
+    else:
+        parent_main(args)
 
 if __name__ == "__main__":
-    torch.set_grad_enabled(True)
     main()

From 54b9b99f30aeafe3eb795d72460ae8574c14e124 Mon Sep 17 00:00:00 2001
From: Yiozolm <1473416941@qq.com>
Date: Sun, 31 Aug 2025 15:31:03 +0800
Subject: [PATCH 10/22] Update project structure

---
 .gitignore                              |   3 +-
 Converse2D/README.md                    |  37 ++++++
 Converse2D/setup.py                     |  78 +++++++++++--
 Converse2D/torch_converse2d/__init__.py |   7 ++
 README.md                               |  19 ---
 models/util_converse.py                 |   4 +-
 test/README.md                          |   8 ++
 test/test_cache.py                      | 146 +++--------------------
 test/test_error.py                      |  51 ++++-----
 test/test_speed.py                      |  89 +++-------------
 10 files changed, 185 insertions(+), 257 deletions(-)
 create mode 100644 Converse2D/README.md
 create mode 100644 test/README.md

diff --git a/.gitignore b/.gitignore
index d41279a..e198538 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 TODO.md
-Optimization.md
\ No newline at end of file
+Optimization.md
+**/__pycache__/
\ No newline at end of file
diff --git a/Converse2D/README.md b/Converse2D/README.md
new file mode 100644
index 0000000..3401b8f
--- /dev/null
+++ b/Converse2D/README.md
@@ -0,0 +1,37 @@
+### Kernel Registry
+----------
+**Kernel Details**
+We offer four versions of the Converse2D kernel:
+- v1: direct translation from Python to C++ (**faster**)
+- v2: adds an FB/F2B cache and replaces `repeat` with broadcasting (**much faster**)
+- v3: replaces the `splits→permute→view→mean` chain with a block-mean CUDA kernel (**fastest**)
+- v4: adds an s-fold zero-insertion upsampling CUDA kernel for STy (**fastest**)
+
+**Tested Devices**
+- NVIDIA RTX 2080 Ti
+- NVIDIA RTX 4090
+- NVIDIA RTX 5060 Ti 16GB
+
+Depending on the workload, **v3** and **v4** each have their own performance advantages, but both are faster than **v1** and **v2**.
+
+We highly recommend running `test/test_speed.py` first to choose the most suitable backend for your GPU.
+
+
+**Installation**
+
+```bash
+cd ./Converse2D
+# Remember to choose the desired kernel version
+CONVERSE2D_VARIANT={v1,v2,v3,v4} pip install -e .
+```
+
+**Usage**
+
+A minimal runnable sketch (tensor shapes follow `test/test_speed.py`; the values are placeholders):
+
+```python
+import torch
+import torch.nn.functional as F
+import torch_converse2d  # registers torch.ops.converse2d
+
+B, C, H, W, scale, eps = 1, 3, 64, 64, 2, 1e-5
+x = torch.randn(B, C, H, W, device="cuda")
+x0 = F.interpolate(x, scale_factor=scale, mode="nearest")
+weight = torch.randn(1, C, 5, 5, device="cuda")
+bias = torch.zeros(1, C, 1, 1, device="cuda")
+out = torch.ops.converse2d.forward(x, x0, weight, bias, scale, eps)
+print(out.shape)  # (B, C, H*scale, W*scale)
+```
+
diff --git a/Converse2D/setup.py b/Converse2D/setup.py
index 30cb75f..45b6837 100644
--- a/Converse2D/setup.py
+++ b/Converse2D/setup.py
@@ -1,18 +1,80 @@
+# setup.py — selectable variants: v1 | v2 | v3 | v4
 from setuptools import setup
-from torch.utils.cpp_extension import CppExtension, BuildExtension
+from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension
+import os, sys, pathlib
+
+PKG_DIR = pathlib.Path(__file__).resolve().parent / "torch_converse2d"
+
+# ---------------------------
+# parse custom args
+# ---------------------------
+variant = os.environ.get("CONVERSE2D_VARIANT", "").lower()
+to_remove = []
+for i, a in enumerate(list(sys.argv)):
+    aa = a.lower()
+    if aa.startswith("--variant="):
+        variant = a.split("=", 1)[1].lower(); to_remove.append(i)
+    elif aa in ("--v1","--v2","--v3","--v4"):
+        variant = aa[2:]; to_remove.append(i)
+# scrub custom flags so setuptools doesn't see them
+for idx in reversed(to_remove):
+    sys.argv.pop(idx)
+
+if variant not in {"", "v1","v2","v3","v4"}:
+    raise SystemExit(f"[setup.py] invalid --variant={variant!r}; pick from v1|v2|v3|v4")
+
+if not variant:
+    variant = "v1"  # default
+
+# ---------------------------
+# pick sources per variant
+# ---------------------------
+CPP = str(PKG_DIR / f"converse2d_{variant}.cpp")
+CU = str(PKG_DIR / f"converse2d_{variant}.cu")
+has_cu = os.path.exists(CU)  # v3,v4 have .cu; v1,v2 usually not
+
+# ---------------------------
+# CUDA arch (auto if not set)
+# ---------------------------
+extra_cflags = ["-O3"]
+extra_cuda = ["-O3"]
+
+# Respect TORCH_CUDA_ARCH_LIST if user already set it; otherwise auto-detect.
+if has_cu and "TORCH_CUDA_ARCH_LIST" not in os.environ:
+    try:
+        import torch
+        if torch.cuda.is_available():
+            maj, min = torch.cuda.get_device_capability(0)
+            os.environ["TORCH_CUDA_ARCH_LIST"] = f"{maj}.{min}+PTX"
+    except Exception:
+        # Fallback: a safe default that covers Ampere/Lovelace widely.
+        os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "8.0;8.6;8.9+PTX")
+
+# ---------------------------
+# Extension definition
+# ---------------------------
+if has_cu:
+    ext = CUDAExtension(
+        name="converse2d_ext",
+        sources=[CPP, CU],
+        extra_compile_args={"cxx": extra_cflags, "nvcc": extra_cuda},
+    )
+else:
+    ext = CppExtension(
+        name="converse2d_ext",
+        sources=[CPP],
+        extra_compile_args={"cxx": extra_cflags},
+    )
+
+print(f"[setup.py] building variant={variant} sources={[p for p in ([CPP] + ([CU] if has_cu else []))]}")
+print(f"[setup.py] TORCH_CUDA_ARCH_LIST={os.environ.get('TORCH_CUDA_ARCH_LIST','')}")
 
 setup(
     name="torch_converse2d",
     version="0.1",
     description="Converse2D CUDA extension for PyTorch",
     packages=["torch_converse2d"],
-    ext_modules=[
-        CppExtension(
-            name="converse2d_ext",
-            sources=["torch_converse2d/converse2d_ext.cpp"],
-            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"]},
-        )
-    ],
+    ext_modules=[ext],
     cmdclass={"build_ext": BuildExtension},
     zip_safe=False,
 )
diff --git a/Converse2D/torch_converse2d/__init__.py b/Converse2D/torch_converse2d/__init__.py
index e69de29..d5c9823 100644
--- a/Converse2D/torch_converse2d/__init__.py
+++ b/Converse2D/torch_converse2d/__init__.py
@@ -0,0 +1,7 @@
+import os
+try:
+    import converse2d_ext
+except Exception as e:
+    print("[torch_converse2d] extension import failed:", e)
+
+__all__ = ["converse2d_ext"]
diff --git a/README.md b/README.md
index ca9da18..946b4a2 100644
--- a/README.md
+++ b/README.md
@@ -17,25 +17,6 @@ ___________
 * [Visual results of ConverseNet](#visual-results-of-conversenet)
 * [Visual results of Converse-USRNet](#visual-results-of-converse-usrnet)
 
-Kernel Registry
-----------
-**Installation**
-
-```python
-cd ./Converse2D
-pip install --no-build-isolation -e.
-```
-
-**Usage**
-
-```python
-import torch
-import torch_converse2d
-
-out = torch.ops.converse2d.forward(x, x0, weight, bias, scale, eps)
-print(torch.ops.converse2d)
-```
-
 
 Motivation
 
diff --git a/models/util_converse.py b/models/util_converse.py
index 3fe1231..5a18c26 100644
--- a/models/util_converse.py
+++ b/models/util_converse.py
@@ -29,7 +29,9 @@ def _try_import_converse2d_ext():
 
 _try_import_converse2d_ext()
 
-converse2d_CUDA = torch.ops.converse2d.forward
+converse2d_CUDA = torch.ops.converse2d.forward if (
+    hasattr(torch.ops, "converse2d") and hasattr(torch.ops.converse2d, "forward")
+) else None
 
 
 """
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 0000000..c6a4d6b
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,8 @@
+### Test
+---
+`test_cache.py`: tests the cache-clearing function
+`test_error.py`: verifies CUDA backend precision
+`test_speed.py`: speed comparison between the PyTorch and CUDA backends
+
+`test_cache.py` and `test_error.py` need the kernel registry to run;
+`test_speed.py` can be run at any time.
\ No newline at end of file diff --git a/test/test_cache.py b/test/test_cache.py index 9ab894b..65587f5 100644 --- a/test/test_cache.py +++ b/test/test_cache.py @@ -1,136 +1,28 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- -import os, sys, math, subprocess, json -import torch +import sys, pathlib, torch +PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1] +sys.path.insert(0, str(PROJECT_ROOT)) -# Paths to the two C++ sources (no-cache vs cache) -ROOT = os.path.dirname(os.path.abspath(__file__)) -SRC_NOCACHE = os.path.join(ROOT, "models/backend/converse2d_v1.cpp") -SRC_CACHE = os.path.join(ROOT, "models/backend/converse2d_v3.cpp") - -# Benchmark config -CASES = [ - # (B, C, H, W, scale, ksize) - (1, 3, 128, 128, 2, 5), - (2, 3, 256, 256, 2, 5), - (4, 3, 256, 256, 2, 5), - (2, 8, 256, 256, 2, 5), - (1, 3, 512, 512, 2, 5), -] -WARMUP = 10 -ITERS = 50 -DTYPE = "float32" -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" - -def run_single_bench(src_path, tag): - # compile & run timings in a clean subprocess to avoid op name collisions - child = f''' -import os, time, torch, json -from torch.utils.cpp_extension import load - -torch.manual_seed(0) -device = "{DEVICE}" -dtype = torch.{DTYPE} - -ext = load( - name="converse2d_ext", - sources=[r\"\"\"{src_path}\"\"\"], - verbose=False, - extra_cflags=["-O3"], - extra_cuda_cflags=["-O3","-gencode","arch=compute_89,code=sm_89"], -) - -op = torch.ops.converse2d.forward -clear = getattr(torch.ops.converse2d, "clear_cache", None) - -def synchronize(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - -def timed(fn, warmup={WARMUP}, iters={ITERS}): - for _ in range(warmup): - fn() - synchronize() - t0 = time.perf_counter() - for _ in range(iters): - fn() - synchronize() - t1 = time.perf_counter() - return (t1 - t0) / iters - -def bench_case(B,C,H,W,scale,ksize): - x = torch.randn(B,C,H,W, device=device, dtype=dtype, requires_grad=True) - x0 = x if scale == 1 else torch.nn.functional.interpolate(x, scale_factor=scale, mode="nearest") - weight = torch.randn(1,C,ksize,ksize, device=device, dtype=dtype, requires_grad=False) - weight = torch.nn.functional.softmax(weight.view(1,C,-1), dim=-1).view_as(weight) - bias = torch.zeros(1,C,1,1, device=device, dtype=dtype, requires_grad=False) - - if clear is not None: - clear() - - def fwd(): - with torch.no_grad(): - _ = op(x, x0, weight, bias, int(scale), float(1e-5)) - - def train(): - x_local = x.detach().clone().requires_grad_(True) - y = op(x_local, x0, weight, bias, int(scale), float(1e-5)) - loss = y.square().mean() - loss.backward() - - t_f = timed(fwd) - t_b = timed(train) - t_f_hot = timed(fwd) # hot cache - t_b_hot = timed(train) # hot cache - - def tp(B,H,W,s,t): return (B*H*s*W*s / t) / 1e9 - - return dict( - shape=(B,C,H,W,scale,ksize), - fwd_ms=t_f*1e3, bwd_ms=t_b*1e3, fwd_hot_ms=t_f_hot*1e3, bwd_hot_ms=t_b_hot*1e3, - fwd_tp=tp(B,H,W,scale,t_f), bwd_tp=tp(B,H,W,scale,t_b), - fwd_tp_hot=tp(B,H,W,scale,t_f_hot), bwd_tp_hot=tp(B,H,W,scale,t_b_hot), - ) - -rows = [] -for (B,C,H,W,s,k) in {CASES}: - rows.append(bench_case(B,C,H,W,s,k)) - -print(json.dumps(dict(tag="{tag}", rows=rows), indent=2)) -''' - out = subprocess.check_output([sys.executable, "-c", child], text=True) - return json.loads(out) +from models.util_converse import Converse2D def main(): - if DEVICE != "cuda": - print("[WARN] CUDA device not available; this script is intended for RTX 4090 tests.") - res_nc = run_single_bench(SRC_NOCACHE, "nocache") - res_cc = run_single_bench(SRC_CACHE, "cache") + device = 
"cuda" if torch.cuda.is_available() else "cpu" + + backend = "cuda" if device == "cuda" else "pytorch" + m = Converse2D(3,3,5, scale=2, padding=2, padding_mode="circular", eps=1e-5, backend=backend).to(device) + m.eval() - def fmt_ms(x): return f"{x:7.2f}" - print("=== Converse2D CUDA Backend: Cache vs No-Cache ===") - print(f"[Env] torch={torch.__version__}, cuda={torch.version.cuda}, device={DEVICE}") - print("case | no‑cache fwd cache fwd | no‑cache bwd cache bwd || fwd speedup bwd speedup") - print("-"*110) - for r_nc, r_c in zip(res_nc["rows"], res_cc["rows"]): - B,C,H,W,s,k = r_nc["shape"] - tag = f"B{B} C{C} {H}x{W} s{s} k{k}" - tf0, tb0 = r_nc["fwd_hot_ms"], r_nc["bwd_hot_ms"] - tf1, tb1 = r_c["fwd_hot_ms"], r_c["bwd_hot_ms"] - sp_f = tf0 / tf1 if tf1 > 0 else float('nan') - sp_b = tb0 / tb1 if tb1 > 0 else float('nan') - print(f"{tag:24s} | {fmt_ms(tf0)} {fmt_ms(tf1)} | {fmt_ms(tb0)} {fmt_ms(tb1)} || {sp_f:5.2f}× {sp_b:5.2f}×") + x = torch.randn(1,3,64,64, device=device) + with torch.no_grad(): + _ = m(x) - # Geometric mean speedups - sps_f, sps_b = [], [] - for r_nc, r_c in zip(res_nc["rows"], res_cc["rows"]): - tf0, tb0 = r_nc["fwd_hot_ms"], r_nc["bwd_hot_ms"] - tf1, tb1 = r_c["fwd_hot_ms"], r_c["bwd_hot_ms"] - sps_f.append(tf0/tf1); sps_b.append(tb0/tb1) - def gmean(a): - a = [x for x in a if x>0 and math.isfinite(x)] - return math.exp(sum(math.log(x) for x in a)/len(a)) if a else float('nan') - print("-"*110) - print(f"Geomean speedup: Forward {gmean(sps_f):.2f}×, Backward {gmean(sps_b):.2f}×") + if hasattr(torch.ops, "converse2d") and hasattr(torch.ops.converse2d, "clear_cache"): + torch.ops.converse2d.clear_cache() + print("[INFO] cleared converse2d FB cache") + else: + print("[WARN] converse2d extension not available; nothing to clear") if __name__ == "__main__": main() diff --git a/test/test_error.py b/test/test_error.py index de14de1..2c75c17 100644 --- a/test/test_error.py +++ b/test/test_error.py @@ -1,35 +1,32 @@ -import torch +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import sys, pathlib, torch +PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1] +sys.path.insert(0, str(PROJECT_ROOT)) + from models.util_converse import Converse2D torch.manual_seed(0) - device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float32 B, C, H, W, scale = 2, 3, 32, 40, 2 x = torch.randn(B, C, H, W, device=device, dtype=dtype, requires_grad=True) -m = Converse2D( - in_channels=C, out_channels=C, kernel_size=5, scale=scale, - padding=2, padding_mode="circular", eps=1e-5, backend="pytorch" -).to(device=device, dtype=dtype) +m = Converse2D(C, C, 5, scale=scale, padding=2, padding_mode="circular", eps=1e-5, backend="pytorch").to(device=device, dtype=dtype) m.eval() - x_py = x.detach().clone().requires_grad_(True) -m.backend = "python" y_py = m(x_py) -loss_py = y_py.square().mean() -g_py = torch.autograd.grad(loss_py, x_py)[0].detach() - +g_py = torch.autograd.grad(y_py.square().mean(), x_py)[0].detach() have_cuda = False try: if device == "cuda": - m.backend = "cuda" - x_cuda = x.detach().clone().requires_grad_(True) - y_cuda = m(x_cuda) - loss_cuda = y_cuda.square().mean() - g_cuda = torch.autograd.grad(loss_cuda, x_cuda)[0].detach() + m.backend = "cuda" + x_cu = x.detach().clone().requires_grad_(True) + y_cu = m(x_cu) + g_cu = torch.autograd.grad(y_cu.square().mean(), x_cu)[0].detach() have_cuda = True print("[INFO] CUDA backend: OK") else: @@ -41,25 +38,19 @@ if have_cuda: with torch.no_grad(): - out_abs = (y_cuda - y_py).abs() - grad_abs = (g_cuda - 
g_py).abs()
-        out_mae = out_abs.max().item()
-        grad_mae = grad_abs.max().item()
-        out_rel = (out_abs / (y_py.abs() + 1e-8)).max().item()
-        grad_rel = (grad_abs / (g_py.abs() + 1e-8)).max().item()
-
+        out_mae = (y_cu - y_py).abs().max().item()
+        grad_mae = (g_cu - g_py).abs().max().item()
+        out_rel = ((y_cu - y_py).abs() / (y_py.abs() + 1e-8)).max().item()
+        grad_rel = ((g_cu - g_py).abs() / (g_py.abs() + 1e-8)).max().item()
     print(f"forward: max|Δ|={out_mae:.3e} max rel={out_rel:.3e}")
     print(f"backward: max|Δ|={grad_mae:.3e} max rel={grad_rel:.3e}")
 
+# gradcheck (float64)
 try:
-    torch.manual_seed(0)
-    B2, C2, H2, W2, s2 = 1, 2, 8, 9, 2
-    x64 = torch.randn(B2, C2, H2, W2, device=device, dtype=torch.float64, requires_grad=True)
-    m64 = Converse2D(C2, C2, kernel_size=5, scale=s2, padding=2,
-                     padding_mode="circular", eps=1e-5, backend="auto").to(device=device, dtype=torch.float64)
+    x64 = torch.randn(1,2,8,9, device=device, dtype=torch.float64, requires_grad=True)
+    m64 = Converse2D(2,2,5, scale=2, padding=2, padding_mode="circular", eps=1e-5, backend="auto").to(device=device, dtype=torch.float64)
     m64.eval()
-    def f(inp): return m64(inp)
-    torch.autograd.gradcheck(f, (x64,), eps=1e-6, atol=1e-4, rtol=1e-4)
+    torch.autograd.gradcheck(lambda t: m64(t), (x64,), eps=1e-6, atol=1e-4, rtol=1e-4)
     print("[INFO] Gradcheck (float64) passed.")
 except Exception as e:
     print("[WARN] Gradcheck skipped/failed ->", repr(e))
diff --git a/test/test_speed.py b/test/test_speed.py
index 5ff90ce..a2fe2e5 100644
--- a/test/test_speed.py
+++ b/test/test_speed.py
@@ -5,13 +5,11 @@
 import torch
 import torch.nn.functional as F
 
-ROOT = pathlib.Path(__file__).resolve().parent
-MODELS = ROOT / "models"
-BACKEND = MODELS / "backend"
+# Paths: project root + package source directory
+PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(PROJECT_ROOT))
+PKG = PROJECT_ROOT / "Converse2D/torch_converse2d"  # <== unified location for the CUDA sources
 
-# -----------------------------
-# Common utils
-# -----------------------------
 def synchronize():
     if torch.cuda.is_available():
         torch.cuda.synchronize()
@@ -40,23 +38,16 @@ def to_dtype(name):
     if name in ("fp32","float32","float"):return torch.float32
     raise ValueError(name)
 
-# -----------------------------
-# Subprocess runner
-# -----------------------------
+# -------- parent <-> child plumbing --------
+import re
 def _parse_last_json_from_text(txt: str):
-    # 提取最后一个 {...} JSON 对象;把所有输出(包含编译日志)都容忍
    m = re.findall(r"\{.*\}", txt, flags=re.S)
    if not m:
        tail = txt[-2000:] if len(txt) > 2000 else txt
-        raise RuntimeError("Child produced no JSON. Tail of output:\n" + tail)
+        raise RuntimeError("Child produced no JSON. Tail:\n" + tail)
    return json.loads(m[-1])
 
-
 def run_variant_subprocess(variant, case_args, cache_root):
-    """
-    在子进程中跑一个 variant(pytorch / cuda_v1..v4),返回 json 结果。
-    合并 stdout+stderr,解析最后一个 JSON。
-    """
    cmd = [
        sys.executable, __file__, "--worker",
        "--variant", variant,
@@ -68,18 +59,13 @@
    ]
    env = os.environ.copy()
    env["TORCH_EXTENSIONS_DIR"] = str(pathlib.Path(cache_root) / variant)
-    env.setdefault("PYTHONWARNINGS", "ignore")
-    env.setdefault("TORCH_SHOW_CPP_STACKTRACES", "0")
-
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    out = proc.stdout
    if proc.returncode != 0:
        raise RuntimeError(f"Subprocess failed (variant={variant}). Output:\n{out}")
Output:\n{out}") return _parse_last_json_from_text(out) -# ----------------------------- -# Worker: run in subprocess -# ----------------------------- +# -------- worker -------- def worker_main(args): device = "cuda" if (args.device=="cuda" and torch.cuda.is_available()) else "cpu" dtype = to_dtype(args.dtype) @@ -102,18 +88,18 @@ def call(): from torch.utils.cpp_extension import load vnum = int(args.variant.split("_v")[1]) - cpp = BACKEND / f"converse2d_v{vnum}.cpp" - cu = BACKEND / f"converse2d_v{vnum}.cu" + cpp = PKG / f"converse2d_v{vnum}.cpp" + cu = PKG / f"converse2d_v{vnum}.cu" sources = [str(cpp)] if cu.exists(): sources.append(str(cu)) - os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9" # RTX 4090 (sm_89) + os.environ["TORCH_CUDA_ARCH_LIST"] = "7.5" # RTX 4090 load( name=f"converse2d_v{vnum}_ext", sources=sources, verbose=False, extra_cflags=["-O3"], - extra_cuda_cflags=(["-O3","-gencode=arch=compute_89,code=sm_89"] if cu.exists() else []) + extra_cuda_cflags=(["-O3","-gencode=arch=compute_75,code=sm_75"] if cu.exists() else []), ) torch.manual_seed(0) @@ -129,15 +115,11 @@ def call(): _ = converse2d_forward(x, x0, weight, bias, int(s), float(1e-5)) stat = timed_run(call, args.warmup, args.iters) stat["tp"] = tp_gpix_per_s(B,H,W,s,stat["mean_ms"]) - try: torch.ops.converse2d.clear_cache() except Exception: pass - print(json.dumps({"variant": args.variant, **stat})) -# ----------------------------- -# Orchestrator (parent) -# ----------------------------- +# -------- parent orchestrator -------- def parent_main(args): device = "cuda" if (args.device=="cuda" and torch.cuda.is_available()) else "cpu" Bs = [int(x) for x in args.B_list.split(",")] @@ -152,8 +134,7 @@ def parent_main(args): print(f"[Grid] B={Bs} C={Cs} H={Hs} W={Ws} scale={Ss} ksize={Ks}\n") variants = ["pytorch", "cuda_v1", "cuda_v2", "cuda_v3", "cuda_v4"] - results = [] - cache_root = ROOT / ".torch_ext_cache_grid" + cache_root = PROJECT_ROOT / ".torch_ext_cache_grid" cache_root.mkdir(exist_ok=True) for B in Bs: @@ -167,44 +148,14 @@ def parent_main(args): dtype=args.dtype,device=device) base = run_variant_subprocess("pytorch", case, cache_root) base_mean = base["mean_ms"] - results.append({**case,"variant":"pytorch",**base}) - print(f"[Case] B{B} C{C} {H}x{W} s{s} k{k}") print(f" PyTorch : {base_mean:.3f} ms") - for v in variants[1:]: r = run_variant_subprocess(v, case, cache_root) sp = base_mean / r["mean_ms"] if r["mean_ms"]>0 else None - results.append({**case, "variant":v, **r, "speedup_vs_pytorch": sp}) print(f" {v:8s}: {r['mean_ms']:.3f} ms ({sp:.2f}x vs PyTorch)") print("") - # 将原来的header定义 - header = ["variant","B","C","H","W","scale","ksize","mean_ms","p50_ms","p90_ms","tp","speedup_vs_pytorch"] - - # 修改为包含所有需要的字段 - header = ["variant","B","C","H","W","scale","ksize","mean_ms","p50_ms","p90_ms","tp","speedup_vs_pytorch","warmup","dtype","device","iters"] - print("\n=== Summary (normalized to PyTorch) ===") - print(" | ".join(h.rjust(10) for h in header)) - print("-"*120) - for r in results: - line=[] - for h in header: - v = r.get(h,"") - if isinstance(v,float): - line.append(f"{v:10.3f}") - else: - line.append(str(v).rjust(10)) - print(" | ".join(line)) - - if args.csv: - with open(args.csv,"w",newline="") as f: - w=csv.DictWriter(f, fieldnames=header); w.writeheader(); w.writerows(results) - print(f"\n[Saved] {args.csv}") - -# ----------------------------- -# CLI -# ----------------------------- def main(): p = argparse.ArgumentParser() p.add_argument("--worker", action="store_true", help="internal") @@ -220,19 
+171,15 @@ def main(): p.add_argument("--dtype", default="float32", choices=["float16","bfloat16","float32"]) p.add_argument("--device", default="cuda") # grid - p.add_argument("--B_list", default="1") - p.add_argument("--C_list", default="3") + p.add_argument("--B_list", default="1,2") + p.add_argument("--C_list", default="3,8") p.add_argument("--H_list", default="128,256") p.add_argument("--W_list", default="128,256") p.add_argument("--scale_list", default="1,2,3") p.add_argument("--ksize_list", default="3,5,7") - p.add_argument("--csv", default="") args = p.parse_args() - - if args.worker: - worker_main(args) - else: - parent_main(args) + if args.worker: worker_main(args) + else: parent_main(args) if __name__ == "__main__": main() From aae6bb2ed0a663075d3cd6f89d6368d89e44f23f Mon Sep 17 00:00:00 2001 From: Yiozolm <1473416941@qq.com> Date: Sun, 31 Aug 2025 17:10:07 +0800 Subject: [PATCH 11/22] Add TODO list --- Converse2D/README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Converse2D/README.md b/Converse2D/README.md index 3401b8f..485ba2b 100644 --- a/Converse2D/README.md +++ b/Converse2D/README.md @@ -35,3 +35,11 @@ out = torch.ops.converse2d.forward(x, x0, weight, bias, scale, eps) print(torch.ops.converse2d) ``` +**TODO** +- [ ] Temporary Tensor Reuse and In-Place Writing +- [ ] Larger batched FFT +- [ ] Eliminate redundant calculations of `conj/abs/pow(2)` +- [ ] The minimal necessary policy for `contiguous()` +- [ ] R2C/C2R (Real FFT) replaces C2C **(Optional)** +- [ ] Mixed precision **(Optional)** +- [ ] Adaptive padding **(Optional)** \ No newline at end of file From 0ff28496b69d6a6d90fbb6d30d0feaa95b77aa4e Mon Sep 17 00:00:00 2001 From: Yiozolm <1473416941@qq.com> Date: Sun, 31 Aug 2025 17:24:00 +0800 Subject: [PATCH 12/22] Update test results saving --- test/results_2080ti.csv | 361 ++++++++++++++++++++++++++++++++++++++++ test/results_4090.csv | 181 ++++++++++++++++++++ test/test_speed.py | 52 +++++- 3 files changed, 586 insertions(+), 8 deletions(-) create mode 100644 test/results_2080ti.csv create mode 100644 test/results_4090.csv diff --git a/test/results_2080ti.csv b/test/results_2080ti.csv new file mode 100644 index 0000000..f50f58f --- /dev/null +++ b/test/results_2080ti.csv @@ -0,0 +1,361 @@ +variant,B,C,H,W,scale,ksize,mean_ms,p50_ms,p90_ms,tp,speedup_vs_pytorch,warmup,dtype,device,iters +pytorch,1,3,128,128,1,3,1.3588724099099636,1.3492174912244081,1.3725919649004936,0.012057055453120557,,10,float32,cuda,50 +cuda_v1,1,3,128,128,1,3,0.718311658129096,0.712866079993546,0.7390974787995219,0.022809040915016645,1.8917588132277605,10,float32,cuda,50 +cuda_v2,1,3,128,128,1,3,0.4002129239961505,0.3978795139119029,0.4132815869525075,0.0409382081827962,3.3953736334687536,10,float32,cuda,50 +cuda_v3,1,3,128,128,1,3,0.3714838158339262,0.36968302447348833,0.3826369298622012,0.044104209393941815,3.657958575825157,10,float32,cuda,50 +cuda_v4,1,3,128,128,1,3,0.36731038708239794,0.36497542168945074,0.3836334450170398,0.044605327200628854,3.6995207805139763,10,float32,cuda,50 +pytorch,1,3,128,128,1,5,2.02513235155493,2.008417039178312,2.028485178016126,0.008090335422976228,,10,float32,cuda,50 +cuda_v1,1,3,128,128,1,5,0.5582067556679249,0.5469518946483731,0.6023913389071822,0.029351131697422133,3.6279251925780596,10,float32,cuda,50 +cuda_v2,1,3,128,128,1,5,0.5319257266819477,0.5326560931280255,0.5428183358162642,0.030801292695881242,3.8071712834559133,10,float32,cuda,50 
+cuda_v3,1,3,128,128,1,5,0.3715480910614133,0.37001201417297125,0.3865184495225549,0.04409657967342882,5.450525518162857,10,float32,cuda,50 +cuda_v4,1,3,128,128,1,5,0.3673481987789273,0.3656859043985605,0.37831307854503393,0.04460073590794984,5.512841381246758,10,float32,cuda,50 +pytorch,1,3,128,128,1,7,1.6548363445326686,1.6421884065493941,1.6600201604887843,0.009900676918373397,,10,float32,cuda,50 +cuda_v1,1,3,128,128,1,7,0.53962841629982,0.5401384551078081,0.5467256065458059,0.03036163312588968,3.0666219467827918,10,float32,cuda,50 +cuda_v2,1,3,128,128,1,7,0.40642036590725183,0.4043428925797343,0.41671814396977425,0.04031294042911952,4.071735777410118,10,float32,cuda,50 +cuda_v3,1,3,128,128,1,7,0.3748922608792782,0.35778700839728117,0.4316947190091014,0.0437032227914567,4.414165127472596,10,float32,cuda,50 +cuda_v4,1,3,128,128,1,7,0.37697022780776024,0.3745625726878643,0.387030653655529,0.04346231821881484,4.389832995980173,10,float32,cuda,50 +pytorch,1,3,128,128,2,3,1.8107919162139297,1.7898115329444408,1.8179329112172127,0.03619190002627418,,10,float32,cuda,50 +cuda_v1,1,3,128,128,2,3,0.6132322456687689,0.6145505467429757,0.6199152441695333,0.10686978785423916,2.9528648061211222,10,float32,cuda,50 +cuda_v2,1,3,128,128,2,3,0.48999007791280746,0.4895030288025737,0.5026075756177306,0.13374964709318457,3.6955685387085646,10,float32,cuda,50 +cuda_v3,1,3,128,128,2,3,0.4332400672137737,0.42917008977383375,0.45607127249240875,0.1512694807326363,4.179650159920297,10,float32,cuda,50 +cuda_v4,1,3,128,128,2,3,0.42093018535524607,0.4206504672765732,0.4302357789129019,0.15569327712786998,4.301881830322297,10,float32,cuda,50 +pytorch,1,3,128,128,2,5,2.1734800469130278,2.159254509024322,2.198834274895489,0.030152565740403335,,10,float32,cuda,50 +cuda_v1,1,3,128,128,2,5,0.6376416468992829,0.6374488584697247,0.6438577082008123,0.10277873209613546,3.408623099639435,10,float32,cuda,50 +cuda_v2,1,3,128,128,2,5,0.478832246735692,0.4765290068462491,0.4935482516884804,0.13686630432009075,4.539126305152032,10,float32,cuda,50 +cuda_v3,1,3,128,128,2,5,0.4307904839515686,0.43061794713139534,0.44139688834547997,0.15212963712394317,5.0453297551423635,10,float32,cuda,50 +cuda_v4,1,3,128,128,2,5,0.4390825191512704,0.4367530345916748,0.46043451875448227,0.14925668215322843,4.950049141364774,10,float32,cuda,50 +pytorch,1,3,128,128,2,7,2.1709761628881097,2.145085483789444,2.219454082660377,0.0301873420447029,,10,float32,cuda,50 +cuda_v1,1,3,128,128,2,7,0.8158506592735648,0.8078685496002436,0.8693302515894175,0.08032842684512065,2.660997007492954,10,float32,cuda,50 +cuda_v2,1,3,128,128,2,7,0.48865099903196096,0.48417190555483103,0.5218630190938711,0.134116169065099,4.44279489285586,10,float32,cuda,50 +cuda_v3,1,3,128,128,2,7,0.4235974932089448,0.42357854545116425,0.4345803987234831,0.15471290800975904,5.1250920925947225,10,float32,cuda,50 +cuda_v4,1,3,128,128,2,7,0.42114653158932924,0.41657453402876854,0.45004035346210003,0.1556132962859251,5.154918775408755,10,float32,cuda,50 +pytorch,1,3,128,128,3,3,2.2250490309670568,2.2027044324204326,2.2900789277628064,0.06627089917920248,,10,float32,cuda,50 +cuda_v1,1,3,128,128,3,3,0.6517368508502841,0.6480665178969502,0.67496825940907,0.22625082471187952,3.414029800622999,10,float32,cuda,50 +cuda_v2,1,3,128,128,3,3,0.48238967079669237,0.47929398715496063,0.4991346038877964,0.30567818700692434,4.612555296410616,10,float32,cuda,50 +cuda_v3,1,3,128,128,3,3,0.4496808862313628,0.44654798693954945,0.46137943863868713,0.3279125364562043,4.948062279485589,10,float32,cuda,50 
+cuda_v4,1,3,128,128,3,3,0.426276377402246,0.423771096393466,0.4365326603874564,0.34591642374978826,5.219733367649035,10,float32,cuda,50 +pytorch,1,3,128,128,3,5,2.170854858122766,2.143112593330443,2.2381292656064034,0.06792531497361907,,10,float32,cuda,50 +cuda_v1,1,3,128,128,3,5,0.627955780364573,0.6281071109697223,0.6354789482429624,0.23481908218822556,3.4570186723377723,10,float32,cuda,50 +cuda_v2,1,3,128,128,3,5,0.5390440486371517,0.5394599866122007,0.5951238563284278,0.27355092848684337,4.0272309166779054,10,float32,cuda,50 +cuda_v3,1,3,128,128,3,5,0.44883018359541893,0.4477289039641619,0.45829941518604755,0.32853405450315853,4.836695341504932,10,float32,cuda,50 +cuda_v4,1,3,128,128,3,5,0.4391154181212187,0.4359594313427806,0.4516175948083401,0.3358023743071906,4.943699921562529,10,float32,cuda,50 +pytorch,1,3,128,128,3,7,2.302859895862639,2.2876025177538395,2.3640060564503074,0.06403168523839518,,10,float32,cuda,50 +cuda_v1,1,3,128,128,3,7,1.1249064421281219,1.1208199430257082,1.1454877443611622,0.13108290118868873,2.047156820887291,10,float32,cuda,50 +cuda_v2,1,3,128,128,3,7,0.4788805032148957,0.47789840027689934,0.4864738555625081,0.3079181528796333,4.8088403691583155,10,float32,cuda,50 +cuda_v3,1,3,128,128,3,7,0.4445829102769494,0.4424919607117772,0.45168143697082996,0.33167266800278816,5.179821002179527,10,float32,cuda,50 +cuda_v4,1,3,128,128,3,7,0.42853476013988256,0.4264645976945758,0.43528247624635696,0.34409344052246155,5.3737995375472885,10,float32,cuda,50 +pytorch,1,3,128,256,1,3,1.7951560160145164,1.7785559175536036,1.8055073218420148,0.018253566658094317,,10,float32,cuda,50 +cuda_v1,1,3,128,256,1,3,0.5566090485081077,0.5568780470639467,0.5647305399179459,0.058870764116804856,3.225164989369316,10,float32,cuda,50 +cuda_v2,1,3,128,256,1,3,0.4132459592074156,0.4112694878131151,0.42841359972953796,0.07929418127365924,4.344037675425872,10,float32,cuda,50 +cuda_v3,1,3,128,256,1,3,0.374758830294013,0.37343648727983236,0.3899722592905164,0.08743756611229739,4.790163355473561,10,float32,cuda,50 +cuda_v4,1,3,128,256,1,3,0.3926960425451398,0.37166697438806295,0.43946001678705215,0.0834436725861157,4.57136263553806,10,float32,cuda,50 +pytorch,1,3,128,256,1,5,1.4699688693508506,1.4445290435105562,1.4817556831985712,0.022291628539365324,,10,float32,cuda,50 +cuda_v1,1,3,128,256,1,5,0.5615295935422182,0.5593735259026289,0.5792464362457395,0.05835489416202311,2.617794122083669,10,float32,cuda,50 +cuda_v2,1,3,128,256,1,5,0.4279289720579982,0.42431592009961605,0.447837240062654,0.0765734552685507,3.4350767658508112,10,float32,cuda,50 +cuda_v3,1,3,128,256,1,5,0.37391155026853085,0.37109002005308867,0.3892844310030341,0.08763569880755787,3.931327792081223,10,float32,cuda,50 +cuda_v4,1,3,128,256,1,5,0.37634771782904863,0.37289841566234827,0.39316134061664343,0.08706841691248005,3.9058795887758406,10,float32,cuda,50 +pytorch,1,3,128,256,1,7,1.4711981965228915,1.4571724459528923,1.4904611511155963,0.022273001746090804,,10,float32,cuda,50 +cuda_v1,1,3,128,256,1,7,0.5647571152076125,0.5636163987219334,0.5756448954343796,0.058021402683796255,2.6050104671670384,10,float32,cuda,50 +cuda_v2,1,3,128,256,1,7,0.41376128792762756,0.4112150054425001,0.43163460213690996,0.0791954224720307,3.5556690281286634,10,float32,cuda,50 +cuda_v3,1,3,128,256,1,7,0.40381398517638445,0.4036114551126957,0.4396184580400586,0.08114627329136968,3.6432571692143787,10,float32,cuda,50 +cuda_v4,1,3,128,256,1,7,0.364238740876317,0.36335503682494164,0.3761310363188386,0.08996297296977228,4.039104113371784,10,float32,cuda,50 
+pytorch,1,3,128,256,2,3,2.44969944935292,2.412254922091961,2.47289901599288,0.05350533920992726,,10,float32,cuda,50 +cuda_v1,1,3,128,256,2,3,0.6509987637400627,0.6330730393528938,0.6970161804929376,0.20133986007435145,3.7629863308481806,10,float32,cuda,50 +cuda_v2,1,3,128,256,2,3,0.48265908844769,0.4774669650942087,0.49854174721986055,0.27156227477565753,5.075423850883138,10,float32,cuda,50 +cuda_v3,1,3,128,256,2,3,0.4738498851656914,0.4671409260481596,0.5251517286524177,0.2766108088306658,5.169779556866057,10,float32,cuda,50 +cuda_v4,1,3,128,256,2,3,0.433603972196579,0.42632746044546366,0.46148020774126053,0.3022850536539299,5.649624095792007,10,float32,cuda,50 +pytorch,1,3,128,256,2,5,2.4055594438686967,2.3875508923083544,2.4414390325546265,0.05448711747035686,,10,float32,cuda,50 +cuda_v1,1,3,128,256,2,5,0.6060707569122314,0.6057386053726077,0.6157765863463283,0.21626517779504295,3.969106604193179,10,float32,cuda,50 +cuda_v2,1,3,128,256,2,5,0.4737306758761406,0.4720538854598999,0.48464827705174685,0.2766804150007577,5.077905160816824,10,float32,cuda,50 +cuda_v3,1,3,128,256,2,5,0.578413438051939,0.5772160366177559,0.5907158832997084,0.22660607685990572,4.158892732455306,10,float32,cuda,50 +cuda_v4,1,3,128,256,2,5,0.44161519035696983,0.43730391189455986,0.46488025691360235,0.29680138469433276,5.447184554327074,10,float32,cuda,50 +pytorch,1,3,128,256,2,7,2.1301160426810384,2.1134009584784508,2.1731291199103,0.06153279791979229,,10,float32,cuda,50 +cuda_v1,1,3,128,256,2,7,0.6413558637723327,0.6351619958877563,0.6695011164993048,0.20436704083293092,3.3212700826528705,10,float32,cuda,50 +cuda_v2,1,3,128,256,2,7,0.4867607494816184,0.4816199652850628,0.5202332977205515,0.26927397112356877,4.376104780324894,10,float32,cuda,50 +cuda_v3,1,3,128,256,2,7,0.43859776109457016,0.43492461554706097,0.4642078885808587,0.2988432947603176,4.856650515873802,10,float32,cuda,50 +cuda_v4,1,3,128,256,2,7,0.44836338609457016,0.447872094810009,0.47701222356408834,0.2923343075394517,4.75086973812743,10,float32,cuda,50 +pytorch,1,3,128,256,3,3,3.6672434210777283,3.676302614621818,3.76545328181237,0.0804178959882983,,10,float32,cuda,50 +cuda_v1,1,3,128,256,3,3,1.0867606522515416,1.0846543591469526,1.094102906063199,0.2713679404834945,3.3744720270097783,10,float32,cuda,50 +cuda_v2,1,3,128,256,3,3,0.8823559479787946,0.8789199637249112,0.9115118300542235,0.334232461032934,4.1561950474502405,10,float32,cuda,50 +cuda_v3,1,3,128,256,3,3,0.8446352742612362,0.8448135340586305,0.8729633176699281,0.3491589908531183,4.3418070886102855,10,float32,cuda,50 +cuda_v4,1,3,128,256,3,3,0.7847000611945987,0.7841064361855388,0.7894640555605292,0.375827675546548,4.673433331322609,10,float32,cuda,50 +pytorch,1,3,128,256,3,5,3.0747908260673285,3.057163907214999,3.187051648274064,0.09591286584433902,,10,float32,cuda,50 +cuda_v1,1,3,128,256,3,5,1.089889518916607,1.0872494895011187,1.096566766500473,0.2705888944534067,2.821194967654879,10,float32,cuda,50 +cuda_v2,1,3,128,256,3,5,0.871727392077446,0.8700459729880095,0.8790818741545081,0.3383075978571515,3.527238967149673,10,float32,cuda,50 +cuda_v3,1,3,128,256,3,5,0.8495373092591763,0.8468780433759093,0.8814086206257343,0.3471442593347345,3.619371147747054,10,float32,cuda,50 +cuda_v4,1,3,128,256,3,5,0.7938195066526532,0.7923004450276494,0.8032999699935317,0.3715101449743573,3.873413036967289,10,float32,cuda,50 +pytorch,1,3,128,256,3,7,3.572395290248096,3.608741913922131,3.6893750075250864,0.08255301444525165,,10,float32,cuda,50 
+cuda_v1,1,3,128,256,3,7,1.090473192743957,1.088742632418871,1.1037262855097651,0.2704440622312898,3.2760046867900345,10,float32,cuda,50 +cuda_v2,1,3,128,256,3,7,0.8857840858399868,0.8821134688332677,0.9158464381471276,0.3329389235079062,4.033031691758611,10,float32,cuda,50 +cuda_v3,1,3,128,256,3,7,0.7986934995278716,0.7949564605951309,0.8138878736644983,0.36924302022531813,4.472798754916412,10,float32,cuda,50 +cuda_v4,1,3,128,256,3,7,0.7899459125474095,0.7872970309108496,0.8032441604882479,0.37333188933020844,4.522328976585083,10,float32,cuda,50 +pytorch,1,3,256,128,1,3,1.4993645157665014,1.4857796486467123,1.5059428755193949,0.021854592165833953,,10,float32,cuda,50 +cuda_v1,1,3,256,128,1,3,0.7230725651606917,0.7242871215566993,0.7289966102689505,0.04531771993412281,2.073601721333862,10,float32,cuda,50 +cuda_v2,1,3,256,128,1,3,0.405933721922338,0.4037924809381366,0.4139351425692439,0.08072253727732694,3.6936190190509848,10,float32,cuda,50 +cuda_v3,1,3,256,128,1,3,0.38566675037145615,0.38364401552826166,0.39766388945281506,0.08496454508572336,3.88772045897756,10,float32,cuda,50 +cuda_v4,1,3,256,128,1,3,0.36679978482425213,0.3645513206720352,0.3863898105919361,0.0893348397565184,4.087691917499091,10,float32,cuda,50 +pytorch,1,3,256,128,1,5,1.4416432147845626,1.4279410243034363,1.4532520435750484,0.022729618302193316,,10,float32,cuda,50 +cuda_v1,1,3,256,128,1,5,0.5475908936932683,0.5470785545185208,0.5588179919868708,0.0598402938715685,2.6327012216388237,10,float32,cuda,50 +cuda_v2,1,3,256,128,1,5,0.437559150159359,0.43854652903974056,0.4863776499405503,0.07488816080766657,3.2947390410176918,10,float32,cuda,50 +cuda_v3,1,3,256,128,1,5,0.37053246051073074,0.3683719551190734,0.38405111990869045,0.08843489705283467,3.89073392597627,10,float32,cuda,50 +cuda_v4,1,3,256,128,1,5,0.3653422696515918,0.3640999784693122,0.37856195122003555,0.08969123674424304,3.946007167906916,10,float32,cuda,50 +pytorch,1,3,256,128,1,7,1.5084912767633796,1.5014259843155742,1.5185259049758315,0.021722366250805942,,10,float32,cuda,50 +cuda_v1,1,3,256,128,1,7,0.5360735580325127,0.5358878988772631,0.5449545569717884,0.06112593973160048,2.8139632223231015,10,float32,cuda,50 +cuda_v2,1,3,256,128,1,7,0.40473042987287045,0.4033439327031374,0.4160322714596987,0.08096253106121211,3.7271506302039374,10,float32,cuda,50 +cuda_v3,1,3,256,128,1,7,0.359728392213583,0.35842100623995066,0.3687401069328189,0.09109094725151559,4.193417337677747,10,float32,cuda,50 +cuda_v4,1,3,256,128,1,7,0.3663349011912942,0.3646225668489933,0.3803261090070009,0.08944820680050104,4.117792959005207,10,float32,cuda,50 +pytorch,1,3,256,128,2,3,2.133134347386658,2.1139864111319184,2.160443994216621,0.06144573132985211,,10,float32,cuda,50 +cuda_v1,1,3,256,128,2,3,0.6306317308917642,0.6240959046408534,0.6758735282346606,0.20784238023458415,3.382535706489482,10,float32,cuda,50 +cuda_v2,1,3,256,128,2,3,0.5141867883503437,0.4855890292674303,0.6295430706813931,0.2549112559280567,4.148559231228704,10,float32,cuda,50 +cuda_v3,1,3,256,128,2,3,0.423141997307539,0.4228444304317236,0.4300167551264167,0.30975890087491614,5.041178519172842,10,float32,cuda,50 +cuda_v4,1,3,256,128,2,3,0.4352907044813037,0.43021049350500107,0.4722143057733774,0.3011137124009725,4.900482195981006,10,float32,cuda,50 +pytorch,1,3,256,128,2,5,2.415369115769863,2.3783150827512145,2.4357105838134885,0.05426582593287103,,10,float32,cuda,50 +cuda_v1,1,3,256,128,2,5,0.6090639485046268,0.6084414198994637,0.6245741387829185,0.21520236146271315,3.9657069207594295,10,float32,cuda,50 
+cuda_v2,1,3,256,128,2,5,0.47608069609850645,0.4748969804495573,0.48580863513052464,0.2753146705466918,5.0734447659060224,10,float32,cuda,50 +cuda_v3,1,3,256,128,2,5,0.44825855642557144,0.4457270260900259,0.45704932417720556,0.29240267279039234,5.388339120685384,10,float32,cuda,50 +cuda_v4,1,3,256,128,2,5,0.43435935862362385,0.4245585296303034,0.47102225944399834,0.30175935523833164,5.560762230203958,10,float32,cuda,50 +pytorch,1,3,256,128,2,7,2.6145117403939366,2.644861117005348,2.7359860949218273,0.0501324962420138,,10,float32,cuda,50 +cuda_v1,1,3,256,128,2,7,0.6282500876113772,0.6161184282973409,0.6790379295125604,0.20863029322978538,4.161577995690182,10,float32,cuda,50 +cuda_v2,1,3,256,128,2,7,0.630432553589344,0.6254641339182854,0.6581001449376345,0.2079080454423657,4.1471712168229775,10,float32,cuda,50 +cuda_v3,1,3,256,128,2,7,0.4406739352270961,0.43863849714398384,0.47307766508311033,0.2974353360210733,5.932984756737607,10,float32,cuda,50 +cuda_v4,1,3,256,128,2,7,0.4157876316457987,0.4142334219068289,0.42685249354690313,0.3152378522689142,6.288094068707623,10,float32,cuda,50 +pytorch,1,3,256,128,3,3,3.8604717003181577,3.8671459769830108,3.9452574448660016,0.07639273718175296,,10,float32,cuda,50 +cuda_v1,1,3,256,128,3,3,1.1453364789485931,1.1441544629633427,1.1537153273820877,0.25748939758797007,3.370600492758277,10,float32,cuda,50 +cuda_v2,1,3,256,128,3,3,0.9098627092316747,0.9093486005440354,0.9173051686957479,0.32412802174191285,4.242916718257525,10,float32,cuda,50 +cuda_v3,1,3,256,128,3,3,0.8363525243476033,0.8355780737474561,0.8449909742921591,0.3526168588180517,4.615842707391262,10,float32,cuda,50 +cuda_v4,1,3,256,128,3,3,0.8323059324175119,0.8312843274325132,0.8499871473759413,0.35433124829880763,4.638284493665748,10,float32,cuda,50 +pytorch,1,3,256,128,3,5,3.06809326633811,3.044669982045889,3.175138426013291,0.09612224088350126,,10,float32,cuda,50 +cuda_v1,1,3,256,128,3,5,1.1442034784704447,1.1442825198173523,1.1537955841049552,0.25774436588344785,2.6814227749416517,10,float32,cuda,50 +cuda_v2,1,3,256,128,3,5,0.9060597093775868,0.9040400618687272,0.9134287713095546,0.3254884826548444,3.386193243760636,10,float32,cuda,50 +cuda_v3,1,3,256,128,3,5,0.8386831870302558,0.8376993937417865,0.8442188147455454,0.35163695249963434,3.658226746147264,10,float32,cuda,50 +cuda_v4,1,3,256,128,3,5,0.8441900322213769,0.8293610299006104,0.8354945341125131,0.34934314401222816,3.6343632940853605,10,float32,cuda,50 +pytorch,1,3,256,128,3,7,4.190151221118867,4.013016470707953,5.129964789375663,0.07038218537641504,,10,float32,cuda,50 +cuda_v1,1,3,256,128,3,7,1.148314750753343,1.1480144457891583,1.1549783172085881,0.25682157248831405,3.648957063705705,10,float32,cuda,50 +cuda_v2,1,3,256,128,3,7,0.908088181167841,0.9069349616765976,0.9139806730672717,0.32476141206983916,4.614255870757122,10,float32,cuda,50 +cuda_v3,1,3,256,128,3,7,0.8402302768081427,0.8394863689318299,0.8471378590911627,0.3509894943566047,4.9869081569357,10,float32,cuda,50 +cuda_v4,1,3,256,128,3,7,0.8234968921169639,0.8228260558098555,0.8296727202832699,0.35812157012744705,5.0882416937205965,10,float32,cuda,50 +pytorch,1,3,256,256,1,3,1.7476121010258794,1.7350299749523401,1.7694034380838275,0.03750031254734915,,10,float32,cuda,50 +cuda_v1,1,3,256,256,1,3,0.549295274540782,0.5471804179251194,0.5695500643923879,0.11930923683038953,3.1815531318504533,10,float32,cuda,50 +cuda_v2,1,3,256,256,1,3,0.4180182237178087,0.41706988122314215,0.4277727100998163,0.15677785388667967,4.1807079257999975,10,float32,cuda,50 
+cuda_v3,1,3,256,256,1,3,0.39928949903696775,0.39886811282485723,0.4483650205656886,0.16413153904138217,4.376804562205827,10,float32,cuda,50 +cuda_v4,1,3,256,256,1,3,0.37084896117448807,0.3665839321911335,0.39022082928568125,0.17671884476215285,4.7124632505134905,10,float32,cuda,50 +pytorch,1,3,256,256,1,5,1.6822041990235448,1.661254558712244,1.7334380885586143,0.03895840947135974,,10,float32,cuda,50 +cuda_v1,1,3,256,256,1,5,0.5354204308241606,0.5324805388227105,0.5574033362790942,0.12240100718443245,3.1418378944452416,10,float32,cuda,50 +cuda_v2,1,3,256,256,1,5,0.3861092450097203,0.3840719582512975,0.3941373433917761,0.16973434551754937,4.356808910341412,10,float32,cuda,50 +cuda_v3,1,3,256,256,1,5,0.36452691070735455,0.3633134765550494,0.3785670269280672,0.17978370889773043,4.614759979611034,10,float32,cuda,50 +cuda_v4,1,3,256,256,1,5,0.3729759994894266,0.36907452158629894,0.3943886375054717,0.17571103794805398,4.510221036544827,10,float32,cuda,50 +pytorch,1,3,256,256,1,7,1.6930993692949414,1.6854360001161695,1.7194566549733281,0.03870771036155498,,10,float32,cuda,50 +cuda_v1,1,3,256,256,1,7,0.5525457859039307,0.5440320819616318,0.5869314773008227,0.11860736552860893,3.064179317782934,10,float32,cuda,50 +cuda_v2,1,3,256,256,1,7,0.40626043919473886,0.40068302769213915,0.4402776714414358,0.16131523938166584,4.167522126079726,10,float32,cuda,50 +cuda_v3,1,3,256,256,1,7,0.3619052888825536,0.3611855208873749,0.3710400080308318,0.18108605210593623,4.678294076670403,10,float32,cuda,50 +cuda_v4,1,3,256,256,1,7,0.38414323702454567,0.3804925363510847,0.39720607455819845,0.17060302950436282,4.407468897302902,10,float32,cuda,50 +pytorch,1,3,256,256,2,3,3.5527321184054017,3.5565514117479324,3.642054391093552,0.07378659332121555,,10,float32,cuda,50 +cuda_v1,1,3,256,256,2,3,0.9840514045208693,0.9831475326791406,0.990622048266232,0.26639258761856743,3.6103115163330446,10,float32,cuda,50 +cuda_v2,1,3,256,256,2,3,0.7950076553970575,0.7889129919931293,0.8149547036737204,0.3297377053169069,4.468802399935419,10,float32,cuda,50 +cuda_v3,1,3,256,256,2,3,0.7262257812544703,0.7252609357237816,0.7328283973038197,0.36096763123332914,4.892049015759908,10,float32,cuda,50 +cuda_v4,1,3,256,256,2,3,0.7457686774432659,0.744950957596302,0.7515277713537216,0.35150846090601934,4.763852687652834,10,float32,cuda,50 +pytorch,1,3,256,256,2,5,3.0734648229554296,2.9988345922902226,3.703110688365996,0.08529266319954933,,10,float32,cuda,50 +cuda_v1,1,3,256,256,2,5,0.9879587311297655,0.985164544545114,0.9992359671741724,0.265339018463078,3.1109242988731065,10,float32,cuda,50 +cuda_v2,1,3,256,256,2,5,0.7879760023206472,0.7871531415730715,0.7931290892884135,0.33268018217301876,3.900454853832921,10,float32,cuda,50 +cuda_v3,1,3,256,256,2,5,0.7290018862113357,0.7279976271092892,0.736223254352808,0.35959303392529934,4.215990220448401,10,float32,cuda,50 +cuda_v4,1,3,256,256,2,5,0.7129816291853786,0.7123425602912903,0.7170673925429583,0.36767286739142807,4.310720917826502,10,float32,cuda,50 +pytorch,1,3,256,256,2,7,3.3563410444185138,3.334320499561727,3.463956923224032,0.07810410102273038,,10,float32,cuda,50 +cuda_v1,1,3,256,256,2,7,0.9848026884719729,0.9839965496212244,0.9889016160741448,0.2661893626699421,3.408135541979721,10,float32,cuda,50 +cuda_v2,1,3,256,256,2,7,0.790087953209877,0.7883070502430201,0.7980464026331902,0.3317909087652735,4.248060017600273,10,float32,cuda,50 +cuda_v3,1,3,256,256,2,7,0.7256730319932103,0.7235906086862087,0.7303511258214712,0.36124258232384293,4.625142311268799,10,float32,cuda,50 
+cuda_v4,1,3,256,256,2,7,0.7089367602020502,0.7080744253471494,0.7145792013034225,0.3697706406496508,4.734330666478546,10,float32,cuda,50 +pytorch,1,3,256,256,3,3,6.674171555787325,6.3839604845270514,7.9156504943966866,0.08837411431064435,,10,float32,cuda,50 +cuda_v1,1,3,256,256,3,3,2.0255626551806927,2.0256630377843976,2.0410686964169145,0.29119020262909817,3.294971665634282,10,float32,cuda,50 +cuda_v2,1,3,256,256,3,3,1.6004600655287504,1.5984426718205214,1.6165184089913964,0.3685340313724963,4.170158131113601,10,float32,cuda,50 +cuda_v3,1,3,256,256,3,3,1.4722054777666926,1.471800496801734,1.4807991916313767,0.4006397265242836,4.533451108952478,10,float32,cuda,50 +cuda_v4,1,3,256,256,3,3,1.4688942860811949,1.4653319958597422,1.4887887286022305,0.40154285137398704,4.543670445878774,10,float32,cuda,50 +pytorch,1,3,256,256,3,5,5.245809820480645,5.211159586906433,5.4148891242221,0.11243716798447673,,10,float32,cuda,50 +cuda_v1,1,3,256,256,3,5,2.024001074023545,2.023787470534444,2.0335125038400292,0.29141486512528336,2.591801895664222,10,float32,cuda,50 +cuda_v2,1,3,256,256,3,5,1.5998238697648048,1.5992920380085707,1.60962687805295,0.3686805848738286,3.278992093830834,10,float32,cuda,50 +cuda_v3,1,3,256,256,3,5,1.5252059418708086,1.5229755081236362,1.5630704816430807,0.3867176122304673,3.439410820840473,10,float32,cuda,50 +cuda_v4,1,3,256,256,3,5,1.4621745282784104,1.4624440809711814,1.4697926817461848,0.4033882334788508,3.587676928456107,10,float32,cuda,50 +pytorch,1,3,256,256,3,7,6.0111372359097,5.690280115231872,7.076818193309009,0.09812186560580803,,10,float32,cuda,50 +cuda_v1,1,3,256,256,3,7,2.0201045367866755,2.0184863824397326,2.029761392623186,0.2919769691415162,2.9756565199697294,10,float32,cuda,50 +cuda_v2,1,3,256,256,3,7,1.6010280745103955,1.6005345387384295,1.6115479171276093,0.3684032837340294,3.754548300315061,10,float32,cuda,50 +cuda_v3,1,3,256,256,3,7,1.4723994303494692,1.4696434373036027,1.4831635169684887,0.4005869520473851,4.08254520614897,10,float32,cuda,50 +cuda_v4,1,3,256,256,3,7,1.4713413268327713,1.4690780080854893,1.4835603768005967,0.4008750309961475,4.085481136351517,10,float32,cuda,50 +pytorch,1,8,128,128,1,3,1.8926894385367632,1.8772950861603022,1.9192735198885202,0.008656465063104309,,10,float32,cuda,50 +cuda_v1,1,8,128,128,1,3,0.5577139323577285,0.5577560514211655,0.5638764938339591,0.029377067793047325,3.3936563688402814,10,float32,cuda,50 +cuda_v2,1,8,128,128,1,3,0.40895847138017416,0.4075943725183606,0.4168321844190359,0.04006274755650967,4.628072459653953,10,float32,cuda,50 +cuda_v3,1,8,128,128,1,3,0.46957184094935656,0.46929146628826857,0.48177684657275677,0.034891359683910474,4.030670652461229,10,float32,cuda,50 +cuda_v4,1,8,128,128,1,3,0.37016375456005335,0.37107395473867655,0.381774315610528,0.044261491834803476,5.113113899512556,10,float32,cuda,50 +pytorch,1,8,128,128,1,5,1.8416153965517879,1.8245259998366237,1.8423532135784626,0.008896537263251137,,10,float32,cuda,50 +cuda_v1,1,8,128,128,1,5,0.5451813340187073,0.5464649293571711,0.5536802345886827,0.030052386201905076,3.3779868855315485,10,float32,cuda,50 +cuda_v2,1,8,128,128,1,5,0.4040445853024721,0.400731572881341,0.4195025423541665,0.04054998036351548,4.557950937946947,10,float32,cuda,50 +cuda_v3,1,8,128,128,1,5,0.4850057791918516,0.4806459182873368,0.5031358683481812,0.03378104076883392,3.797099901820568,10,float32,cuda,50 +cuda_v4,1,8,128,128,1,5,0.41654150001704693,0.41544996201992035,0.44771814718842506,0.03933341575648401,4.421205081549905,10,float32,cuda,50 
+pytorch,1,8,128,128,1,7,1.6017067013308406,1.5827155439183116,1.6301571391522884,0.01022908875038527,,10,float32,cuda,50 +cuda_v1,1,8,128,128,1,7,0.5676993587985635,0.5643828772008419,0.5944325588643551,0.028860346142849047,2.821399525129944,10,float32,cuda,50 +cuda_v2,1,8,128,128,1,7,0.5288944765925407,0.5266304360702634,0.5552147515118122,0.030977823980230752,3.0284050452748295,10,float32,cuda,50 +cuda_v3,1,8,128,128,1,7,0.36155916284769773,0.3598505863919854,0.3737625200301409,0.04531485212809157,4.429998921104758,10,float32,cuda,50 +cuda_v4,1,8,128,128,1,7,0.3603372583165765,0.3556694136932492,0.39303568191826344,0.045468514903351284,4.445021058365414,10,float32,cuda,50 +pytorch,1,8,128,128,2,3,2.2949183266609907,2.2692334605380893,2.3775779409334064,0.028557007558239384,,10,float32,cuda,50 +cuda_v1,1,8,128,128,2,3,0.6507242191582918,0.6433745147660375,0.6627354305237532,0.10071240330469096,3.526714173370487,10,float32,cuda,50 +cuda_v2,1,8,128,128,2,3,0.5056051397696137,0.5040100077167153,0.5108454264700413,0.1296189354994738,4.538953713379384,10,float32,cuda,50 +cuda_v3,1,8,128,128,2,3,0.45750402845442295,0.4539460642263293,0.463389465585351,0.14324682609112538,5.016170752449698,10,float32,cuda,50 +cuda_v4,1,8,128,128,2,3,0.4524831008166075,0.44547696597874165,0.46550282277166843,0.1448363483226789,5.071832124822551,10,float32,cuda,50 +pytorch,1,8,128,128,2,5,2.7334573213011026,2.684682374820113,3.138685319572687,0.023975497802469953,,10,float32,cuda,50 +cuda_v1,1,8,128,128,2,5,0.6499058287590742,0.6500064628198743,0.6569568300619721,0.1008392248537515,4.205928305829088,10,float32,cuda,50 +cuda_v2,1,8,128,128,2,5,0.5131116230040789,0.5048854509368539,0.5448845447972417,0.1277226963137396,5.32721770225691,10,float32,cuda,50 +cuda_v3,1,8,128,128,2,5,0.5604248121380806,0.5589330103248358,0.5785022862255573,0.11693986165596977,4.877473770072154,10,float32,cuda,50 +cuda_v4,1,8,128,128,2,5,0.44553374871611595,0.4445739323273301,0.4527335288003087,0.14709547860033845,6.135241896215589,10,float32,cuda,50 +pytorch,1,8,128,128,2,7,2.5094290915876627,2.4871930945664644,2.599560678936541,0.026115900313619444,,10,float32,cuda,50 +cuda_v1,1,8,128,128,2,7,0.654459148645401,0.6531750550493598,0.6606933427974582,0.10013764821783966,3.834355584732335,10,float32,cuda,50 +cuda_v2,1,8,128,128,2,7,0.5120710004121065,0.5095340311527252,0.5175400758162141,0.1279822523580864,4.900549122227415,10,float32,cuda,50 +cuda_v3,1,8,128,128,2,7,0.4739638837054372,0.47122593969106674,0.48143870662897825,0.1382721389816483,5.294557618966686,10,float32,cuda,50 +cuda_v4,1,8,128,128,2,7,0.45019406359642744,0.44682237785309553,0.45996704138815403,0.14557277694081097,5.574105245948376,10,float32,cuda,50 +pytorch,1,8,128,128,3,3,5.082319467328489,4.843820002861321,6.053852685727179,0.029013524424805582,,10,float32,cuda,50 +cuda_v1,1,8,128,128,3,3,1.3734258990734816,1.3736775144934654,1.3794763945043087,0.10736363723698117,3.7004686388665315,10,float32,cuda,50 +cuda_v2,1,8,128,128,3,3,1.0804085666313767,1.0780345182865858,1.0903888382017612,0.13648170197294476,4.704071796815472,10,float32,cuda,50 +cuda_v3,1,8,128,128,3,3,0.9943246794864535,0.9936904534697533,0.9973179781809449,0.1482976366192155,5.111327891361797,10,float32,cuda,50 +cuda_v4,1,8,128,128,3,3,0.979733420535922,0.9789994219318032,0.9875095216557384,0.15050624681083186,5.187451362584344,10,float32,cuda,50 +pytorch,1,8,128,128,3,5,3.830469772219658,3.782373503781855,3.919286490418017,0.038495539390342996,,10,float32,cuda,50 
+cuda_v1,1,8,128,128,3,5,1.363285849802196,1.3614975614473224,1.373448595404625,0.10816220238873228,2.809733389938313,10,float32,cuda,50 +cuda_v2,1,8,128,128,3,5,1.0831660823896527,1.0821976466104388,1.0879256762564182,0.13613424792132192,3.5363642145893044,10,float32,cuda,50 +cuda_v3,1,8,128,128,3,5,0.9961470495909452,0.9955629939213395,0.9998968336731195,0.14802633814008773,3.8452854664304734,10,float32,cuda,50 +cuda_v4,1,8,128,128,3,5,0.9828188642859459,0.9822685969993472,0.9898525662720203,0.15003375022429208,3.897432081752557,10,float32,cuda,50 +pytorch,1,8,128,128,3,7,4.917632234282792,4.972781520336866,5.124774295836687,0.02998516216239696,,10,float32,cuda,50 +cuda_v1,1,8,128,128,3,7,1.3663292303681374,1.3654798967763782,1.372660114429891,0.1079212804078488,3.5991561367371228,10,float32,cuda,50 +cuda_v2,1,8,128,128,3,7,1.0768034821376204,1.0758364805951715,1.0850996943190694,0.13693863592201352,4.566879951502883,10,float32,cuda,50 +cuda_v3,1,8,128,128,3,7,0.9991599293425679,0.9962470503523946,1.0128907160833478,0.14757997760881364,4.921766866209816,10,float32,cuda,50 +cuda_v4,1,8,128,128,3,7,0.9830530360341072,0.9821520652621984,0.9894790593534708,0.149998010885431,5.002407860029409,10,float32,cuda,50 +pytorch,1,8,128,256,1,3,2.067403239198029,1.9350905204191804,2.2734523052349687,0.015849834893705162,,10,float32,cuda,50 +cuda_v1,1,8,128,256,1,3,0.5463245930150151,0.5399889778345823,0.5833456525579095,0.05997899493991736,3.7842031378975625,10,float32,cuda,50 +cuda_v2,1,8,128,256,1,3,0.4024812765419483,0.4006690578535199,0.41477736085653305,0.08141496737820245,5.136644509182666,10,float32,cuda,50 +cuda_v3,1,8,128,256,1,3,0.4817423736676574,0.48082845751196146,0.49287278670817614,0.06801975867417855,4.291512128065947,10,float32,cuda,50 +cuda_v4,1,8,128,256,1,3,0.3658107668161392,0.3623685333877802,0.37896146532148123,0.08957636836443796,5.6515647617259175,10,float32,cuda,50 +pytorch,1,8,128,256,1,5,1.814923589117825,1.8042140873149037,1.8416738137602806,0.01805475458938051,,10,float32,cuda,50 +cuda_v1,1,8,128,256,1,5,0.5642586294561625,0.5593795794993639,0.5676337750628591,0.058072660814389485,3.216474670253718,10,float32,cuda,50 +cuda_v2,1,8,128,256,1,5,0.4024905152618885,0.3992595011368394,0.4130690125748515,0.08141309858861853,4.5092331876116605,10,float32,cuda,50 +cuda_v3,1,8,128,256,1,5,0.36135319620370865,0.36130042281001806,0.37272502668201923,0.09068136201437506,5.022575165198437,10,float32,cuda,50 +cuda_v4,1,8,128,256,1,5,0.3723372332751751,0.3720894455909729,0.38131470791995525,0.08800624023486492,4.874408001459551,10,float32,cuda,50 +pytorch,1,8,128,256,1,7,2.195071200840175,2.182921045459807,2.221221197396517,0.014927989573849759,,10,float32,cuda,50 +cuda_v1,1,8,128,256,1,7,0.5654219072312117,0.562358065508306,0.5834365030750632,0.0579531843052564,3.8821827961868642,10,float32,cuda,50 +cuda_v2,1,8,128,256,1,7,0.40377453435212374,0.4015684826299548,0.4227354656904936,0.08115420169471034,5.436378508521534,10,float32,cuda,50 +cuda_v3,1,8,128,256,1,7,0.3754196735098958,0.37615804467350245,0.38602480199187994,0.08728365163616354,5.8469796756197825,10,float32,cuda,50 +cuda_v4,1,8,128,256,1,7,0.3583986544981599,0.35712961107492447,0.3721989458426833,0.09142891467012537,6.124663620497619,10,float32,cuda,50 +pytorch,1,8,128,256,2,3,4.291606065817177,4.2037880048155785,4.73125025164336,0.030541479807290328,,10,float32,cuda,50 +cuda_v1,1,8,128,256,2,3,1.2242945516481996,1.2185699306428432,1.2504691258072853,0.10705920386850132,3.5053705499543613,10,float32,cuda,50 
+cuda_v2,1,8,128,256,2,3,0.9949567029252648,0.9701449889689684,0.9783105226233602,0.13173638572878216,4.313359619769844,10,float32,cuda,50 +cuda_v3,1,8,128,256,2,3,0.8927233191207051,0.8920715190470219,0.8991760900244117,0.14682264615771481,4.807319327162002,10,float32,cuda,50 +cuda_v4,1,8,128,256,2,3,0.8806978771463037,0.8800718933343887,0.8872309466823936,0.1488274281127011,4.872960611331467,10,float32,cuda,50 +pytorch,1,8,128,256,2,5,3.808657177723944,3.604250494390726,4.542697803117335,0.03441422892210233,,10,float32,cuda,50 +cuda_v1,1,8,128,256,2,5,1.21144263073802,1.2110269162803888,1.2173108290880919,0.10819497075164834,3.143902221274549,10,float32,cuda,50 +cuda_v2,1,8,128,256,2,5,0.9682021802291274,0.9672249434515834,0.9752721525728703,0.13537668337927258,3.9337415836252467,10,float32,cuda,50 +cuda_v3,1,8,128,256,2,5,0.925555769354105,0.9218446211889386,0.9484240086749196,0.1416143730501167,4.114994799699427,10,float32,cuda,50 +cuda_v4,1,8,128,256,2,5,0.8804069506004453,0.8796894690021873,0.8853658568114042,0.14887660747181486,4.326018979207747,10,float32,cuda,50 +pytorch,1,8,128,256,2,7,4.556696848012507,4.437481984496117,5.234006349928677,0.02876469608838903,,10,float32,cuda,50 +cuda_v1,1,8,128,256,2,7,1.2110048020258546,1.2096769642084837,1.2163812527433038,0.10823408774328021,3.7627405278573147,10,float32,cuda,50 +cuda_v2,1,8,128,256,2,7,0.9702552948147058,0.9701745584607124,0.9744886308908463,0.13509021872952667,4.696389571244463,10,float32,cuda,50 +cuda_v3,1,8,128,256,2,7,0.8937292965129018,0.8919144747778773,0.9041925193741918,0.1466573832942578,5.0985201736046335,10,float32,cuda,50 +cuda_v4,1,8,128,256,2,7,0.8822291437536478,0.8815015899017453,0.8899182314053178,0.1485691114695258,5.164981093942314,10,float32,cuda,50 +pytorch,1,8,128,256,3,3,8.531096149235964,8.202063501812518,10.174798616208136,0.03456906297163372,,10,float32,cuda,50 +cuda_v1,1,8,128,256,3,3,2.46146900113672,2.46087193954736,2.4681103182956576,0.11981138087207599,3.465855611139631,10,float32,cuda,50 +cuda_v2,1,8,128,256,3,3,1.9530037604272366,1.9528679549694061,1.961797708645463,0.15100431754185942,4.368192382471252,10,float32,cuda,50 +cuda_v3,1,8,128,256,3,3,1.7915777256712317,1.7895660130307078,1.803082088008523,0.1646102180074316,4.761778418538725,10,float32,cuda,50 +cuda_v4,1,8,128,256,3,3,1.7834124807268381,1.7811920261010528,1.78979211486876,0.16536387582070033,4.7835799297306005,10,float32,cuda,50 +pytorch,1,8,128,256,3,5,6.88839724753052,6.5992234740406275,8.008980075828731,0.04281286188971252,,10,float32,cuda,50 +cuda_v1,1,8,128,256,3,5,2.5040291901677847,2.5033794809132814,2.51198906917125,0.11777498487557134,2.750925298546123,10,float32,cuda,50 +cuda_v2,1,8,128,256,3,5,1.9446935411542654,1.9462434574961662,1.9529090961441398,0.15164960121426438,3.5421505248800993,10,float32,cuda,50 +cuda_v3,1,8,128,256,3,5,1.7982480069622397,1.7984320875257254,1.8055621068924665,0.16399962566798088,3.830615810978716,10,float32,cuda,50 +cuda_v4,1,8,128,256,3,5,1.7795015452429652,1.778624951839447,1.786777633242309,0.16572730762070453,3.870970084822269,10,float32,cuda,50 +pytorch,1,8,128,256,3,7,7.802273272536695,7.923941942863166,8.00987989641726,0.037798214661112155,,10,float32,cuda,50 +cuda_v1,1,8,128,256,3,7,2.446714648976922,2.456792979501188,2.4712163489311934,0.12053387595619929,3.1888774916188796,10,float32,cuda,50 +cuda_v2,1,8,128,256,3,7,1.95212263148278,1.9523354712873697,1.9586187787353992,0.15107247631056495,3.9968151317472804,10,float32,cuda,50 
+cuda_v3,1,8,128,256,3,7,1.7935446230694652,1.792844501323998,1.800841884687543,0.1644296975980942,4.3501974649417505,10,float32,cuda,50 +cuda_v4,1,8,128,256,3,7,1.783656389452517,1.7828240524977446,1.7908786423504353,0.16534126289342171,4.37431408800187,10,float32,cuda,50 +pytorch,1,8,256,128,1,3,1.963044130243361,1.8950920784845948,2.2486007306724787,0.016692441853529656,,10,float32,cuda,50 +cuda_v1,1,8,256,128,1,3,0.5548413703218102,0.5544835003092885,0.5644256481900811,0.05905832144599172,3.5380276872735488,10,float32,cuda,50 +cuda_v2,1,8,256,128,1,3,0.4113389831036329,0.4055905155837536,0.42195883579552174,0.07966179075165461,4.772326987906203,10,float32,cuda,50 +cuda_v3,1,8,256,128,1,3,0.3597167832776904,0.3596280002966523,0.37039099261164665,0.09109388697803433,5.457193607583082,10,float32,cuda,50 +cuda_v4,1,8,256,128,1,3,0.3603894030675292,0.3588370746001601,0.37184106186032295,0.09092387212578497,5.447008467880863,10,float32,cuda,50 +pytorch,1,8,256,128,1,5,1.8885056534782052,1.820328994654119,2.098836237564683,0.01735128509657817,,10,float32,cuda,50 +cuda_v1,1,8,256,128,1,5,0.5418416438624263,0.5404235562309623,0.5499669583514333,0.06047523362438307,3.4853460874958095,10,float32,cuda,50 +cuda_v2,1,8,256,128,1,5,0.39879064075648785,0.39732293225824833,0.4064814653247595,0.08216842786942187,4.73558168239805,10,float32,cuda,50 +cuda_v3,1,8,256,128,1,5,0.36798678804188967,0.36877451930195093,0.3769081551581621,0.08904667521995345,5.131993090097647,10,float32,cuda,50 +cuda_v4,1,8,256,128,1,5,0.4367315862327814,0.4385111387819052,0.46610296703875065,0.07503006659686481,4.324179228180708,10,float32,cuda,50 +pytorch,1,8,256,128,1,7,1.8230937235057354,1.8126420909538865,1.8479618011042476,0.017973842802216696,,10,float32,cuda,50 +cuda_v1,1,8,256,128,1,7,0.7252306723967195,0.7257384713739157,0.7311144843697548,0.04518286560013981,2.5138122157476226,10,float32,cuda,50 +cuda_v2,1,8,256,128,1,7,0.40039450861513615,0.3968594828620553,0.41684405878186226,0.08183928424327364,4.553243574222228,10,float32,cuda,50 +cuda_v3,1,8,256,128,1,7,0.3655512351542711,0.3629459533840418,0.38629008922725916,0.08963996520534567,4.987245420566962,10,float32,cuda,50 +cuda_v4,1,8,256,128,1,7,0.3956288564950228,0.3936109133064747,0.42458868119865656,0.0828251010057762,4.608090875010961,10,float32,cuda,50 +pytorch,1,8,256,128,2,3,3.778682304546237,3.7858879659324884,3.9505227701738477,0.03468722412633199,,10,float32,cuda,50 +cuda_v1,1,8,256,128,2,3,1.2544773099943995,1.2530890526250005,1.2617591070011258,0.10448335649895905,3.012156755998326,10,float32,cuda,50 +cuda_v2,1,8,256,128,2,3,1.0024462034925818,0.998719478957355,1.0154257994145155,0.1307521536251396,3.7694614348192297,10,float32,cuda,50 +cuda_v3,1,8,256,128,2,3,0.9228101978078485,0.9217309998348355,0.931913498789072,0.14203570822186815,4.09475568597157,10,float32,cuda,50 +cuda_v4,1,8,256,128,2,3,0.9084250312298536,0.9071275126188993,0.9134697495028377,0.14428488372073006,4.15959729712703,10,float32,cuda,50 +pytorch,1,8,256,128,2,5,3.9364816807210445,3.7192939780652523,5.02905345056206,0.03329673821217722,,10,float32,cuda,50 +cuda_v1,1,8,256,128,2,5,1.2536142067983747,1.2501939199864864,1.2665793765336275,0.10455529244100295,3.1401061501803555,10,float32,cuda,50 +cuda_v2,1,8,256,128,2,5,0.9965211106464267,0.9953664848580956,1.001305691897869,0.13152957684456457,3.9502240731935063,10,float32,cuda,50 +cuda_v3,1,8,256,128,2,5,0.9432033356279135,0.9403220610693097,0.9684782708063722,0.13896473331780806,4.173523917936987,10,float32,cuda,50 
+cuda_v4,1,8,256,128,2,5,0.9114051656797528,0.9110620012506843,0.9172238875180483,0.14381309755057473,4.319134704250991,10,float32,cuda,50 +pytorch,1,8,256,128,2,7,4.343443512916565,4.256172454915941,4.450874077156186,0.030176978153443713,,10,float32,cuda,50 +cuda_v1,1,8,256,128,2,7,1.2559509836137295,1.2548479717224836,1.2616410618647933,0.10436076065872286,3.4582906256574106,10,float32,cuda,50 +cuda_v2,1,8,256,128,2,7,1.0031307814642787,1.0019369656220078,1.010197401046753,0.13066292294278226,4.329887581135137,10,float32,cuda,50 +cuda_v3,1,8,256,128,2,7,0.9242268512025476,0.9237965568900108,0.9295360650867224,0.14181799612233414,4.69954265802291,10,float32,cuda,50 +cuda_v4,1,8,256,128,2,7,0.908985547721386,0.9086495265364647,0.9132696315646172,0.14419591194663856,4.778341662091946,10,float32,cuda,50 +pytorch,1,8,256,128,3,3,8.587907194159925,8.186906110495329,10.415960941463709,0.03434038041311746,,10,float32,cuda,50 +cuda_v1,1,8,256,128,3,3,2.5986746698617935,2.597782062366605,2.6080208364874125,0.11348554069512842,3.304725787247797,10,float32,cuda,50 +cuda_v2,1,8,256,128,3,3,2.026054351590574,2.0258149597793818,2.035799017176032,0.14555976732237044,4.2387348530003175,10,float32,cuda,50 +cuda_v3,1,8,256,128,3,3,1.861856454052031,1.8611680716276169,1.8698341678828,0.1583967439370374,4.612550648289629,10,float32,cuda,50 +cuda_v4,1,8,256,128,3,3,1.8496425542980433,1.8488640198484063,1.8567960010841489,0.15944269843635916,4.643009090704618,10,float32,cuda,50 +pytorch,1,8,256,128,3,5,6.725382432341576,6.6917366348207,6.833369680680335,0.04385059183873362,,10,float32,cuda,50 +cuda_v1,1,8,256,128,3,5,2.5601254729554057,2.559990040026605,2.569140726700425,0.11519435399373373,2.6269737571017573,10,float32,cuda,50 +cuda_v2,1,8,256,128,3,5,2.0295281894505024,2.029025577940047,2.0362793002277613,0.14531062023821795,3.3137664543415295,10,float32,cuda,50 +cuda_v3,1,8,256,128,3,5,1.8621120229363441,1.8612504936754704,1.8688065931200981,0.15837500449353012,3.6116959396118355,10,float32,cuda,50 +cuda_v4,1,8,256,128,3,5,1.8565295031294227,1.855071634054184,1.8653924344107509,0.15885123263750311,3.6225561840008824,10,float32,cuda,50 +pytorch,1,8,256,128,3,7,9.338932940736413,9.335665381513536,9.872243599966168,0.03157876835303038,,10,float32,cuda,50 +cuda_v1,1,8,256,128,3,7,2.5553510058671236,2.5530324783176184,2.5751939741894603,0.11540958534576178,3.654657586881053,10,float32,cuda,50 +cuda_v2,1,8,256,128,3,7,2.02504463493824,2.0252445247024298,2.030028752051294,0.14563234553543272,4.611717084952664,10,float32,cuda,50 +cuda_v3,1,8,256,128,3,7,1.867524804547429,1.8653111765161157,1.8787236418575048,0.1579159748143041,5.000700883863004,10,float32,cuda,50 +cuda_v4,1,8,256,128,3,7,1.857128101401031,1.8561474280431867,1.8686209805309772,0.15880003096044706,5.028696153857697,10,float32,cuda,50 +pytorch,1,8,256,256,1,3,2.6666148006916046,2.5547059485688806,2.9847467551007867,0.024576478006123267,,10,float32,cuda,50 +cuda_v1,1,8,256,256,1,3,0.6383521435782313,0.6306543946266174,0.6635474506765604,0.10266433763759804,4.1773413428896955,10,float32,cuda,50 +cuda_v2,1,8,256,256,1,3,0.48515914008021355,0.48316200263798237,0.49340776167809963,0.13508144974691116,5.496371356109505,10,float32,cuda,50 +cuda_v3,1,8,256,256,1,3,0.4773047938942909,0.470960047096014,0.496916426345706,0.13730429871717215,5.586817553066902,10,float32,cuda,50 +cuda_v4,1,8,256,256,1,3,0.46722331549972296,0.4641955019906163,0.476591382175684,0.14026697261438115,5.707366717860607,10,float32,cuda,50 
+pytorch,1,8,256,256,1,5,2.2214805241674185,2.1969579393044114,2.284339559264481,0.029501046390925274,,10,float32,cuda,50 +cuda_v1,1,8,256,256,1,5,0.641844249330461,0.6362345302477479,0.6586905103176832,0.10210576797776687,3.461089705929052,10,float32,cuda,50 +cuda_v2,1,8,256,256,1,5,0.4858332918956876,0.4853704012930393,0.4932035459205508,0.13489400807483387,4.5725160486622,10,float32,cuda,50 +cuda_v3,1,8,256,256,1,5,0.46678606420755386,0.46398292761296034,0.47364868223667145,0.14039836452970836,4.7590977848465705,10,float32,cuda,50 +cuda_v4,1,8,256,256,1,5,0.4626097623258829,0.46146148815751076,0.46768535394221544,0.14166583876332794,4.802061489144513,10,float32,cuda,50 +pytorch,1,8,256,256,1,7,2.6318022841587663,2.5641301181167364,2.948883641511202,0.024901566654330964,,10,float32,cuda,50 +cuda_v1,1,8,256,256,1,7,0.6335034826770425,0.6320229731500149,0.6402655737474561,0.10345010215738622,4.154361193149823,10,float32,cuda,50 +cuda_v2,1,8,256,256,1,7,0.48780177254229784,0.4840875044465065,0.5006202263757586,0.1343496553086373,5.395229030108864,10,float32,cuda,50 +cuda_v3,1,8,256,256,1,7,0.4626284958794713,0.4613135242834687,0.46950376126915216,0.14166010218504593,5.688802803112306,10,float32,cuda,50 +cuda_v4,1,8,256,256,1,7,0.46635448932647705,0.4646843299269676,0.4732209723442793,0.14052829231825134,5.643351451295543,10,float32,cuda,50 +pytorch,1,8,256,256,2,3,7.843744731508195,7.943923585116863,8.010599622502923,0.03342077145205552,,10,float32,cuda,50 +cuda_v1,1,8,256,256,2,3,2.281022099778056,2.2740440908819437,2.3007393116131425,0.1149239194243259,3.4386973858216425,10,float32,cuda,50 +cuda_v2,1,8,256,256,2,3,1.8037663120776415,1.8034030217677355,1.8102034227922559,0.14533146463859462,4.348537102055917,10,float32,cuda,50 +cuda_v3,1,8,256,256,2,3,1.6670777183026075,1.6650534234941006,1.679345965385437,0.15724761786565708,4.705086418823097,10,float32,cuda,50 +cuda_v4,1,8,256,256,2,3,1.6558474162593484,1.6553474124521017,1.6613460145890713,0.15831410396025372,4.736997294851993,10,float32,cuda,50 +pytorch,1,8,256,256,2,5,7.005033129826188,6.712923059239984,8.13796438742429,0.03742223557570876,,10,float32,cuda,50 +cuda_v1,1,8,256,256,2,5,2.276116036809981,2.274753525853157,2.2865735460072756,0.11517163262352816,3.0776256643065847,10,float32,cuda,50 +cuda_v2,1,8,256,256,2,5,1.8342435453087091,1.8338535446673632,1.840946963056922,0.14291668119561532,3.819031091995593,10,float32,cuda,50 +cuda_v3,1,8,256,256,2,5,1.671182638965547,1.6702709253877401,1.679073041304946,0.15686137103618172,4.191662219613689,10,float32,cuda,50 +cuda_v4,1,8,256,256,2,5,1.663348381407559,1.6634294297546148,1.6699650092050433,0.15760017740731405,4.211404663103942,10,float32,cuda,50 +pytorch,1,8,256,256,2,7,6.85543866828084,6.597880506888032,7.352691376581788,0.03823883673745107,,10,float32,cuda,50 +cuda_v1,1,8,256,256,2,7,2.254212941043079,2.2627164144068956,2.27433773688972,0.11629069961717975,3.041167293232139,10,float32,cuda,50 +cuda_v2,1,8,256,256,2,7,1.8175022071227431,1.8079809378832579,1.837374921888113,0.14423311233002337,3.771901151709504,10,float32,cuda,50 +cuda_v3,1,8,256,256,2,7,1.6725403722375631,1.6703285509720445,1.689420617185533,0.15673403425789817,4.09881805071735,10,float32,cuda,50 +cuda_v4,1,8,256,256,2,7,1.6611922485753894,1.660865033045411,1.6678510000929236,0.15780473345262133,4.126818358416943,10,float32,cuda,50 +pytorch,1,8,256,256,3,3,14.831915474496782,14.468007488176227,15.825923206284642,0.039767216919095315,,10,float32,cuda,50 
+cuda_v1,1,8,256,256,3,3,4.795580897480249,4.861502558924258,4.891411936841905,0.1229932332723054,3.092829793005878,10,float32,cuda,50 +cuda_v2,1,8,256,256,3,3,3.8363351486623287,3.8457076298072934,3.8683589547872543,0.15374673409482031,3.86616781374496,10,float32,cuda,50 +cuda_v3,1,8,256,256,3,3,3.5528057161718607,3.5565325524657965,3.5763337276875973,0.16601639580661728,4.174704911947207,10,float32,cuda,50 +cuda_v4,1,8,256,256,3,3,3.4883907111361623,3.528101951815188,3.555237129330635,0.16908197757695997,4.251793076718191,10,float32,cuda,50 +pytorch,1,8,256,256,3,5,12.49826724641025,12.370950891636312,13.13346887473017,0.04719246183261197,,10,float32,cuda,50 +cuda_v1,1,8,256,256,3,5,4.857695340178907,4.866057075560093,4.883602284826338,0.12142054177861986,2.57288001226317,10,float32,cuda,50 +cuda_v2,1,8,256,256,3,5,3.828923678956926,3.873889450915158,3.933584224432707,0.15404433450621288,3.264172465776342,10,float32,cuda,50 +cuda_v3,1,8,256,256,3,5,3.510385132394731,3.535768366418779,3.565721120685339,0.16802258947514145,3.560369240136373,10,float32,cuda,50 +cuda_v4,1,8,256,256,3,5,3.5362447844818234,3.54151357896626,3.5553407855331898,0.16679388332740339,3.5343331720860096,10,float32,cuda,50 +pytorch,1,8,256,256,3,7,13.576734396629035,13.386719045229256,14.27298269700259,0.043443731222026945,,10,float32,cuda,50 +cuda_v1,1,8,256,256,3,7,4.754021079279482,4.852510988712311,4.898783378303051,0.12406844441030404,2.855842279665853,10,float32,cuda,50 +cuda_v2,1,8,256,256,3,7,3.7703912518918514,3.8328804075717926,3.857595380395651,0.15643575443372004,3.600882107345412,10,float32,cuda,50 +cuda_v3,1,8,256,256,3,7,3.478135773912072,3.5281829768791795,3.553933184593916,0.16958049896269256,3.9034515266660943,10,float32,cuda,50 +cuda_v4,1,8,256,256,3,7,3.560623242519796,3.560943529009819,3.572101565077901,0.1656518985093719,3.8130218986665376,10,float32,cuda,50 diff --git a/test/results_4090.csv b/test/results_4090.csv new file mode 100644 index 0000000..ecca0b3 --- /dev/null +++ b/test/results_4090.csv @@ -0,0 +1,181 @@ +variant,B,C,H,W,scale,ksize,mean_ms,p50_ms,p90_ms,tp,speedup_vs_pytorch,warmup,dtype,device,iters +pytorch,1,3,128,128,1,3,1.52592733502388,0.8647029753774405,2.298735734075308,0.010737077463615656,,10,float32,cuda,50 +cuda_v1,1,3,128,128,1,3,0.44544885866343975,0.4432220011949539,0.472044013440609,0.036780877717724675,3.425594892312428,10,float32,cuda,50 +cuda_v2,1,3,128,128,1,3,0.3007301315665245,0.29844650998711586,0.3108557313680649,0.054480739640735645,5.074075308234717,10,float32,cuda,50 +cuda_v3,1,3,128,128,1,3,0.30307079665362835,0.2994614187628031,0.3240731079131365,0.0540599760217902,5.03488739882722,10,float32,cuda,50 +cuda_v4,1,3,128,128,1,3,0.32072700560092926,0.32319221645593643,0.3775870893150568,0.051083942773394356,4.757713907392458,10,float32,cuda,50 +pytorch,1,3,128,128,1,5,4.0221707709133625,0.9404211305081844,7.168814446777105,0.004073422271993561,,10,float32,cuda,50 +cuda_v1,1,3,128,128,1,5,0.4805893823504448,0.4761132877320051,0.5044737830758095,0.0340914730988643,8.369246010475537,10,float32,cuda,50 +cuda_v2,1,3,128,128,1,5,0.3053080663084984,0.3032265231013298,0.3142551053315401,0.053663829449709974,13.174138566156198,10,float32,cuda,50 +cuda_v3,1,3,128,128,1,5,0.2769072540104389,0.2749105915427208,0.28512105345726013,0.05916782519313254,14.525335514546446,10,float32,cuda,50 +cuda_v4,1,3,128,128,1,5,0.27410948649048805,0.27215038426220417,0.2808789722621441,0.05977173650488943,14.673592010296721,10,float32,cuda,50 
+pytorch,1,3,128,128,1,7,3.5734746791422367,0.8284670766443014,10.36820076406002,0.004584893268065006,,10,float32,cuda,50 +cuda_v1,1,3,128,128,1,7,0.5126208532601595,0.5032145418226719,0.5624458193778992,0.03196124366732498,6.970989682561095,10,float32,cuda,50 +cuda_v2,1,3,128,128,1,7,0.30147168785333633,0.299196457490325,0.31120297499001026,0.05434672859884173,11.853433748912119,10,float32,cuda,50 +cuda_v3,1,3,128,128,1,7,0.2754225581884384,0.27269101701676846,0.28396081179380417,0.05948677591176249,12.974517057158911,10,float32,cuda,50 +cuda_v4,1,3,128,128,1,7,0.27901765890419483,0.2780151553452015,0.2874089404940605,0.058720297720029645,12.807342349500704,10,float32,cuda,50 +pytorch,1,3,128,128,2,3,5.1218782644718885,1.200301107019186,8.238791720941663,0.012795306060784977,,10,float32,cuda,50 +cuda_v1,1,3,128,128,2,3,0.48394261859357357,0.46004820615053177,0.5499029066413641,0.13542101373600798,10.583647869983038,10,float32,cuda,50 +cuda_v2,1,3,128,128,2,3,0.3488078713417053,0.3472878597676754,0.3549169283360243,0.18788566825603104,14.683952643529377,10,float32,cuda,50 +cuda_v3,1,3,128,128,2,3,0.34828370437026024,0.34455815330147743,0.36369492299854755,0.18816843618479692,14.706051992104753,10,float32,cuda,50 +cuda_v4,1,3,128,128,2,3,0.3102908004075289,0.30851690098643303,0.3184992354363203,0.2112083243007092,16.506703575307196,10,float32,cuda,50 +pytorch,1,3,128,128,2,5,2.7414161060005426,1.1525587178766727,3.5016948357224464,0.02390589296406032,,10,float32,cuda,50 +cuda_v1,1,3,128,128,2,5,0.5049472488462925,0.4589471500366926,0.6366008426994085,0.1297878147662695,5.429113857465611,10,float32,cuda,50 +cuda_v2,1,3,128,128,2,5,0.3760635666549206,0.37418887950479984,0.38898889906704426,0.1742684104789562,7.28976787191962,10,float32,cuda,50 +cuda_v3,1,3,128,128,2,5,0.3826252557337284,0.3617340698838234,0.443447008728981,0.17127985938703158,7.1647547173632296,10,float32,cuda,50 +cuda_v4,1,3,128,128,2,5,0.33498174510896206,0.326463021337986,0.3694041632115841,0.1956405116305147,8.183777611848136,10,float32,cuda,50 +pytorch,1,3,128,128,2,7,6.999819874763489,1.7441920936107635,18.960934737697244,0.009362526632474858,,10,float32,cuda,50 +cuda_v1,1,3,128,128,2,7,0.4698681924492121,0.46762311831116676,0.4844237584620714,0.13947741314088583,14.897411630007488,10,float32,cuda,50 +cuda_v2,1,3,128,128,2,7,0.37914127111434937,0.36833412013947964,0.40491526015102863,0.1728537750780349,18.462299960618996,10,float32,cuda,50 +cuda_v3,1,3,128,128,2,7,0.44248790480196476,0.44069206342101097,0.4555768799036741,0.14810800315396327,15.819234376352625,10,float32,cuda,50 +cuda_v4,1,3,128,128,2,7,0.3105806838721037,0.3085271455347538,0.3185570705682039,0.2110111909824616,22.53784680841902,10,float32,cuda,50 +pytorch,1,3,128,128,3,3,7.3913106974214315,1.91159313544631,23.464363627135754,0.019949912273535222,,10,float32,cuda,50 +cuda_v1,1,3,128,128,3,3,0.523065309971571,0.5227448418736458,0.5595123395323753,0.281907435245542,14.130760645975844,10,float32,cuda,50 +cuda_v2,1,3,128,128,3,3,0.3566680010408163,0.3553489223122597,0.3626151941716671,0.4134264906571348,20.72322349033937,10,float32,cuda,50 +cuda_v3,1,3,128,128,3,3,0.33742536790668964,0.3273128531873226,0.36468892358243465,0.43700330213695415,21.90502374873426,10,float32,cuda,50 +cuda_v4,1,3,128,128,3,3,0.31496345065534115,0.31304731965065,0.3236026968806982,0.46816860716121145,23.467201296031043,10,float32,cuda,50 +pytorch,1,3,128,128,3,5,3.313328195363283,1.4438305515795946,5.463926354423165,0.044503891949596766,,10,float32,cuda,50 
+cuda_v1,1,3,128,128,3,5,0.4630151018500328,0.4609823226928711,0.4760188050568104,0.31846909401188367,7.155982995207888,10,float32,cuda,50 +cuda_v2,1,3,128,128,3,5,0.36386603489518166,0.35751843824982643,0.3770098090171814,0.40524804697002687,9.105901286768216,10,float32,cuda,50 +cuda_v3,1,3,128,128,3,5,0.32810534350574017,0.3216217737644911,0.356891006231308,0.44941663681688954,10.098367066994497,10,float32,cuda,50 +cuda_v4,1,3,128,128,3,5,0.3466991614550352,0.3456580452620983,0.357994856312871,0.4253139793622609,9.556781682014543,10,float32,cuda,50 +pytorch,1,3,128,128,3,7,6.756937270984054,1.7321815248578787,10.7049988117069,0.021822904976965263,,10,float32,cuda,50 +cuda_v1,1,3,128,128,3,7,0.4557925648987293,0.45467307791113853,0.465529877692461,0.32351558879149916,14.824588620670786,10,float32,cuda,50 +cuda_v2,1,3,128,128,3,7,0.38067104294896126,0.3792489878833294,0.39135636761784554,0.3873580686823354,17.75006897987246,10,float32,cuda,50 +cuda_v3,1,3,128,128,3,7,0.32470209524035454,0.3225074615329504,0.3330751322209835,0.454127035708989,20.80965096939815,10,float32,cuda,50 +cuda_v4,1,3,128,128,3,7,0.3528321161866188,0.3392628859728575,0.38030720315873623,0.41792113936138414,19.150573207486012,10,float32,cuda,50 +pytorch,1,3,128,256,1,3,4.7790092043578625,1.1113823857158422,11.791642662137747,0.006856651368262623,,10,float32,cuda,50 +cuda_v1,1,3,128,256,1,3,0.43059052899479866,0.4229459445923567,0.4630208481103182,0.07610014107020878,11.098732746199326,10,float32,cuda,50 +cuda_v2,1,3,128,256,1,3,0.3584872093051672,0.3435080870985985,0.4171540029346943,0.09140632956894645,13.331045237627055,10,float32,cuda,50 +cuda_v3,1,3,128,256,1,3,0.31005123630166054,0.30543701723217964,0.3326671663671732,0.10568575823422545,15.413611186856176,10,float32,cuda,50 +cuda_v4,1,3,128,256,1,3,0.274530379101634,0.2726605162024498,0.28227311559021473,0.11936019651897593,17.407943048039222,10,float32,cuda,50 +pytorch,1,3,128,256,1,5,3.1171874701976776,0.8945895824581385,2.6657020207494497,0.010512040200752509,,10,float32,cuda,50 +cuda_v1,1,3,128,256,1,5,0.48769986256957054,0.4741228185594082,0.5690208170562983,0.06718886453515381,6.391610310845659,10,float32,cuda,50 +cuda_v2,1,3,128,256,1,5,0.34617194905877113,0.3371278289705515,0.35091196186840534,0.0946581607466896,9.004737323960525,10,float32,cuda,50 +cuda_v3,1,3,128,256,1,5,0.29009729623794556,0.290280906483531,0.31339898705482483,0.11295520649431635,10.745317211232734,10,float32,cuda,50 +cuda_v4,1,3,128,256,1,5,0.27919040992856026,0.2755909226834774,0.2894400618970394,0.11736792824791058,11.165095072553921,10,float32,cuda,50 +pytorch,1,3,128,256,1,7,0.8610220160335302,0.8528372272849083,0.8728299289941788,0.03805709887762491,,10,float32,cuda,50 +cuda_v1,1,3,128,256,1,7,0.4062088765203953,0.4038410261273384,0.4159193020313978,0.08066785807511706,2.119653374906738,10,float32,cuda,50 +cuda_v2,1,3,128,256,1,7,0.3033390734344721,0.29832683503627777,0.3132038749754429,0.10802432943766017,2.8384804050623895,10,float32,cuda,50 +cuda_v3,1,3,128,256,1,7,0.2987572457641363,0.2943666186183691,0.31618126668035984,0.10968102184831952,2.8820121628557662,10,float32,cuda,50 +cuda_v4,1,3,128,256,1,7,0.2784122433513403,0.2752200234681368,0.2866980619728565,0.11769597344412998,3.09261548870525,10,float32,cuda,50 +pytorch,1,3,128,256,2,3,4.812479577958584,1.4287945814430714,9.668499417603016,0.027235855836213175,,10,float32,cuda,50 +cuda_v1,1,3,128,256,2,3,0.4676768183708191,0.4681474529206753,0.48343208618462086,0.2802619134653656,10.290181999448148,10,float32,cuda,50 
+cuda_v2,1,3,128,256,2,3,0.3709117043763399,0.3670940641313791,0.3977825865149498,0.35337790221634474,12.974730970138582,10,float32,cuda,50 +cuda_v3,1,3,128,256,2,3,0.40193247608840466,0.41314586997032166,0.4363299813121557,0.32610452699813897,11.973353397051905,10,float32,cuda,50 +cuda_v4,1,3,128,256,2,3,0.34459170885384083,0.34502334892749786,0.3664530348032713,0.3803689892480681,13.965743963966949,10,float32,cuda,50 +pytorch,1,3,128,256,2,5,3.534023268148303,1.3921631034463644,7.826935639604926,0.03708860696570254,,10,float32,cuda,50 +cuda_v1,1,3,128,256,2,5,0.4526952747255564,0.45057223178446293,0.4625048488378525,0.28953692984637747,7.806627251169746,10,float32,cuda,50 +cuda_v2,1,3,128,256,2,5,0.35566513426601887,0.34616305492818356,0.3815658390522003,0.36852642379605594,9.936378148061708,10,float32,cuda,50 +cuda_v3,1,3,128,256,2,5,0.41979328729212284,0.41792611591517925,0.45677535235881805,0.3122298616194654,8.418484466353727,10,float32,cuda,50 +cuda_v4,1,3,128,256,2,5,0.35099745728075504,0.343183521181345,0.3787568770349026,0.37342720661123885,10.068515297880111,10,float32,cuda,50 +pytorch,1,3,128,256,2,7,5.221625966951251,1.6814591363072395,8.201262401416898,0.025101759649117296,,10,float32,cuda,50 +cuda_v1,1,3,128,256,2,7,0.5324110481888056,0.5313355941325426,0.5536912009119987,0.24618572519464088,9.80750866217851,10,float32,cuda,50 +cuda_v2,1,3,128,256,2,7,0.35434636287391186,0.3511281684041023,0.36454498767852783,0.3698979691422422,14.735937811246218,10,float32,cuda,50 +cuda_v3,1,3,128,256,2,7,0.3654781263321638,0.3502380568534136,0.43992577120661736,0.3586315857405801,14.287109380126324,10,float32,cuda,50 +cuda_v4,1,3,128,256,2,7,0.3121230937540531,0.3096370492130518,0.31899111345410347,0.41993688587260436,16.729380399727134,10,float32,cuda,50 +pytorch,1,3,128,256,3,3,10.960625801235437,2.763780066743493,45.61858847737312,0.026906492872583856,,10,float32,cuda,50 +cuda_v1,1,3,128,256,3,3,0.46120816841721535,0.45501673594117165,0.4788396880030632,0.6394336011265492,23.765029658625433,10,float32,cuda,50 +cuda_v2,1,3,128,256,3,3,0.36482485942542553,0.361383892595768,0.3778459504246712,0.8083659662460131,30.04352778617576,10,float32,cuda,50 +cuda_v3,1,3,128,256,3,3,0.4006952326744795,0.3770939074456692,0.4836510866880417,0.7360007705396968,27.354020980179047,10,float32,cuda,50 +cuda_v4,1,3,128,256,3,3,0.32492942176759243,0.320731895044446,0.33880281262099743,0.9076186403671916,33.732327905581556,10,float32,cuda,50 +pytorch,1,3,128,256,3,5,11.318621216341853,2.605273388326168,45.752703258767724,0.026055470393708847,,10,float32,cuda,50 +cuda_v1,1,3,128,256,3,5,0.4699833784252405,0.4665427841246128,0.48548299819231033,0.6274945318027054,24.083024498157414,10,float32,cuda,50 +cuda_v2,1,3,128,256,3,5,0.4197291377931833,0.42938650585711,0.4759266972541809,0.7026245581866524,26.96648909306595,10,float32,cuda,50 +cuda_v3,1,3,128,256,3,5,0.3752768971025944,0.3709190059453249,0.4119148012250662,0.7858517331520571,30.160719468023988,10,float32,cuda,50 +cuda_v4,1,3,128,256,3,5,0.36198515444993973,0.35343365743756294,0.3866641316562891,0.8147074441440512,31.26819173991056,10,float32,cuda,50 +pytorch,1,3,128,256,3,7,7.481619408354163,2.37233005464077,25.27893357910216,0.039418203988122395,,10,float32,cuda,50 +cuda_v1,1,3,128,256,3,7,0.4637504182755947,0.46098814345896244,0.47804401256144047,0.6359282673999478,16.132857488676233,10,float32,cuda,50 +cuda_v2,1,3,128,256,3,7,0.3909336030483246,0.3640139475464821,0.45265606604516506,0.7543787428361459,19.137826347021225,10,float32,cuda,50 
+cuda_v3,1,3,128,256,3,7,0.3465086594223976,0.34476793371140957,0.35570100881159306,0.8510956132859561,21.59143561036953,10,float32,cuda,50 +cuda_v4,1,3,128,256,3,7,0.32119077630341053,0.3183919470757246,0.33296276815235615,0.9181832784681636,23.293381878708452,10,float32,cuda,50 +pytorch,1,3,256,128,1,3,3.82584142498672,1.0650705080479383,7.791366055607796,0.008564913272670139,,10,float32,cuda,50 +cuda_v1,1,3,256,128,1,3,0.48191010020673275,0.4935292527079582,0.54588015191257,0.06799608471775748,7.938911061099336,10,float32,cuda,50 +cuda_v2,1,3,256,128,1,3,0.31497727148234844,0.31047710217535496,0.33008200116455555,0.10403290321802264,12.146404745274209,10,float32,cuda,50 +cuda_v3,1,3,256,128,1,3,0.2761990111321211,0.2733948640525341,0.28623687103390694,0.11863909239097628,13.85175641760937,10,float32,cuda,50 +cuda_v4,1,3,256,128,1,3,0.2851138450205326,0.28066104277968407,0.2989410888403654,0.11492952928203187,13.418644838910577,10,float32,cuda,50 +pytorch,1,3,256,128,1,5,3.586227549239993,0.8654326666146517,11.385623132809997,0.009137178148928202,,10,float32,cuda,50 +cuda_v1,1,3,256,128,1,5,0.423511927947402,0.4099758807569742,0.48163579776883125,0.07737208290404897,8.46783127601871,10,float32,cuda,50 +cuda_v2,1,3,256,128,1,5,0.2990085817873478,0.296260928735137,0.30720722861588,0.10958882786616574,11.993727831499116,10,float32,cuda,50 +cuda_v3,1,3,256,128,1,5,0.275130495429039,0.27071102522313595,0.28432300314307213,0.1190998473248177,13.034642138261058,10,float32,cuda,50 +cuda_v4,1,3,256,128,1,5,0.27801617980003357,0.2752654254436493,0.28759418055415154,0.11786364384824212,12.899348346630148,10,float32,cuda,50 +pytorch,1,3,256,128,1,7,3.53361826390028,1.0452242568135262,5.550267640501261,0.009273214465399514,,10,float32,cuda,50 +cuda_v1,1,3,256,128,1,7,0.4154033772647381,0.412175664678216,0.4310780204832554,0.07888236300764795,8.50647456736562,10,float32,cuda,50 +cuda_v2,1,3,256,128,1,7,0.300332996994257,0.29647164046764374,0.3192121163010597,0.10910556058755874,11.765667772988163,10,float32,cuda,50 +cuda_v3,1,3,256,128,1,7,0.35520353354513645,0.35250792279839516,0.39303875528275967,0.09225133453194125,9.948150652198553,10,float32,cuda,50 +cuda_v4,1,3,256,128,1,7,0.28812913224101067,0.28413604013621807,0.30301674269139767,0.11372678543518687,12.264008975477436,10,float32,cuda,50 +pytorch,1,3,256,128,2,3,1.6048630606383085,1.2890337966382504,1.585709908977151,0.08167176578160394,,10,float32,cuda,50 +cuda_v1,1,3,256,128,2,3,0.49104482866823673,0.45836716890335083,0.6075259298086166,0.2669247130765648,3.2682618102116248,10,float32,cuda,50 +cuda_v2,1,3,256,128,2,3,0.3567333798855543,0.35033351741731167,0.3854172769933939,0.367422863658147,4.498774578238723,10,float32,cuda,50 +cuda_v3,1,3,256,128,2,3,0.35774463787674904,0.3515880089253187,0.385533319786191,0.36638424765197236,4.486057625247256,10,float32,cuda,50 +cuda_v4,1,3,256,128,2,3,0.3179950825870037,0.307686161249876,0.3502919338643551,0.4121824744385429,5.046817226172725,10,float32,cuda,50 +pytorch,1,3,256,128,2,5,5.680167442187667,1.439184183254838,13.859670702368021,0.02307537609305383,,10,float32,cuda,50 +cuda_v1,1,3,256,128,2,5,0.4627335909754038,0.46338303945958614,0.48728715628385544,0.2832558572713755,12.2752433645769,10,float32,cuda,50 +cuda_v2,1,3,256,128,2,5,0.3555159270763397,0.35281339660286903,0.3653420601040125,0.36868109138709554,15.977251677300995,10,float32,cuda,50 +cuda_v3,1,3,256,128,2,5,0.3210102953016758,0.31748763285577297,0.3332026768475771,0.4083108919507472,17.694658163064886,10,float32,cuda,50 
+cuda_v4,1,3,256,128,2,5,0.32647970132529736,0.31356699764728546,0.3721813205629587,0.40147059516390177,17.398225430646516,10,float32,cuda,50 +pytorch,1,3,256,128,2,7,4.92123176343739,1.4630758669227362,18.036476150155067,0.026633982364701436,,10,float32,cuda,50 +cuda_v1,1,3,256,128,2,7,0.44690595008432865,0.4451470449566841,0.45659723691642284,0.29328765923852085,11.011783939123616,10,float32,cuda,50 +cuda_v2,1,3,256,128,2,7,0.35398226231336594,0.351473456248641,0.3692640457302332,0.37027844034729446,13.902481246590899,10,float32,cuda,50 +cuda_v3,1,3,256,128,2,7,0.3210613317787647,0.31872186809778214,0.33104592002928257,0.4082459861292746,15.32801143062749,10,float32,cuda,50 +cuda_v4,1,3,256,128,2,7,0.32810162752866745,0.32733287662267685,0.3445217851549387,0.3994859793511623,14.99910805229823,10,float32,cuda,50 +pytorch,1,3,256,128,3,3,9.371620612218976,2.651255577802658,27.0526010543108,0.03146862343269484,,10,float32,cuda,50 +cuda_v1,1,3,256,128,3,3,0.5212203320115805,0.5533958319574594,0.5914739333093166,0.5658106215116867,17.980151649208416,10,float32,cuda,50 +cuda_v2,1,3,256,128,3,3,0.3939199075102806,0.38670445792376995,0.4216096829622984,0.7486598021002614,23.790675296029292,10,float32,cuda,50 +cuda_v3,1,3,256,128,3,3,0.39537341333925724,0.39467494934797287,0.4045611247420311,0.745907514390568,23.7032139644086,10,float32,cuda,50 +cuda_v4,1,3,256,128,3,3,0.3381696157157421,0.32397685572505,0.3813320305198431,0.8720830798941337,27.712781328339542,10,float32,cuda,50 +pytorch,1,3,256,128,3,5,6.334149120375514,2.5149499997496605,9.2535394243896,0.04655905543040269,,10,float32,cuda,50 +cuda_v1,1,3,256,128,3,5,0.4680374823510647,0.46257232315838337,0.49366913735866547,0.6301033808629732,13.533422768957653,10,float32,cuda,50 +cuda_v2,1,3,256,128,3,5,0.3780175279825926,0.38000987842679024,0.39740419015288353,0.7801543001825579,16.756231263083645,10,float32,cuda,50 +cuda_v3,1,3,256,128,3,5,0.37817317992448807,0.36235409788787365,0.45027188025414944,0.7798331972110945,16.749334581686327,10,float32,cuda,50 +cuda_v4,1,3,256,128,3,5,0.34465146251022816,0.34011295065283775,0.35576531663537025,0.8556818469651727,18.378419387056965,10,float32,cuda,50 +pytorch,1,3,256,128,3,7,9.138631783425808,3.381723305210471,33.623141143471,0.03227091396054104,,10,float32,cuda,50 +cuda_v1,1,3,256,128,3,7,0.4677951242774725,0.4621578846126795,0.49515804275870323,0.6304298285611738,19.53554303829219,10,float32,cuda,50 +cuda_v2,1,3,256,128,3,7,0.3610655106604099,0.35483832471072674,0.38772691041231155,0.8167825264190665,25.310176446126675,10,float32,cuda,50 +cuda_v3,1,3,256,128,3,7,0.3526437934488058,0.32708211801946163,0.39914101362228394,0.8362886444584856,25.91462533354493,10,float32,cuda,50 +cuda_v4,1,3,256,128,3,7,0.32145872712135315,0.3138268366456032,0.3484130371361971,0.9174179299498951,28.428631772612924,10,float32,cuda,50 +pytorch,1,3,256,256,1,3,4.861515955999494,1.0742205195128918,14.73111561499536,0.013480568734763365,,10,float32,cuda,50 +cuda_v1,1,3,256,256,1,3,0.4849292803555727,0.48372894525527954,0.49910699017345905,0.13514547925822495,10.025206051560351,10,float32,cuda,50 +cuda_v2,1,3,256,256,1,3,0.3039980586618185,0.2988413907587528,0.32487385906279087,0.21558032406024435,15.99193092679473,10,float32,cuda,50 +cuda_v3,1,3,256,256,1,3,0.2766306512057781,0.273675424978137,0.2857776824384928,0.23690794824919645,17.574032142892012,10,float32,cuda,50 +cuda_v4,1,3,256,256,1,3,0.2745348773896694,0.26938505470752716,0.2950329799205065,0.23871648157468708,17.708190675894173,10,float32,cuda,50 
+pytorch,1,3,256,256,1,5,1.1926674656569958,1.1049916502088308,1.324015948921442,0.05494909678273036,,10,float32,cuda,50 +cuda_v1,1,3,256,256,1,5,0.4482492245733738,0.419301213696599,0.5197371356189251,0.14620438007979739,2.660723990748964,10,float32,cuda,50 +cuda_v2,1,3,256,256,1,5,0.31799268908798695,0.2957459073513746,0.39233872666954994,0.2060927884473046,3.7506128492375206,10,float32,cuda,50 +cuda_v3,1,3,256,256,1,5,0.27819630689918995,0.2757101319730282,0.28871200047433376,0.23557465852250975,4.287143416642058,10,float32,cuda,50 +cuda_v4,1,3,256,256,1,5,0.27723253704607487,0.2712407149374485,0.302234198898077,0.23639360912787882,4.302047221314356,10,float32,cuda,50 +pytorch,1,3,256,256,1,7,1.1875793617218733,1.0118531063199043,1.2298297137022018,0.05518452249370455,,10,float32,cuda,50 +cuda_v1,1,3,256,256,1,7,0.4127761535346508,0.4072303418070078,0.43675173074007034,0.15876886161859766,2.87705418918339,10,float32,cuda,50 +cuda_v2,1,3,256,256,1,7,0.29986392706632614,0.29469607397913933,0.319720059633255,0.21855246358293792,3.9603942139368957,10,float32,cuda,50 +cuda_v3,1,3,256,256,1,7,0.3060135804116726,0.2915910445153713,0.36104372702538967,0.21416043010848088,3.8808060744371273,10,float32,cuda,50 +cuda_v4,1,3,256,256,1,7,0.2705792896449566,0.2661552280187607,0.29452070593833923,0.242206268210674,4.389025351053911,10,float32,cuda,50 +pytorch,1,3,256,256,2,3,4.6642600279301405,2.254175953567028,5.693626776337624,0.05620269848384325,,10,float32,cuda,50 +cuda_v1,1,3,256,256,2,3,0.47030373476445675,0.4551168531179428,0.5582175217568874,0.5573929795205436,9.917548348337382,10,float32,cuda,50 +cuda_v2,1,3,256,256,2,3,0.3628383856266737,0.35751843824982643,0.3879097755998373,0.722481441833228,12.854924431091538,10,float32,cuda,50 +cuda_v3,1,3,256,256,2,3,0.3875340800732374,0.3856392577290535,0.41215093806385994,0.6764411531250598,12.035741545746568,10,float32,cuda,50 +cuda_v4,1,3,256,256,2,3,0.33060619607567787,0.3230019938200712,0.3539799712598324,0.7929191984653353,14.108205119248465,10,float32,cuda,50 +pytorch,1,3,256,256,2,5,8.932606596499681,2.1693871822208166,25.42668771930039,0.02934686501280862,,10,float32,cuda,50 +cuda_v1,1,3,256,256,2,5,0.5238902755081654,0.5380609072744846,0.5635851062834263,0.5003795875877337,17.05052949844353,10,float32,cuda,50 +cuda_v2,1,3,256,256,2,5,0.37138083949685097,0.36642886698246,0.3967938479036093,0.7058630174759535,24.052416405223358,10,float32,cuda,50 +cuda_v3,1,3,256,256,2,5,0.3278264496475458,0.3224520478397608,0.3462827764451504,0.7996426166401076,27.24797406098057,10,float32,cuda,50 +cuda_v4,1,3,256,256,2,5,0.32294890843331814,0.3143919166177511,0.3508441150188446,0.811719727655085,27.65950391296668,10,float32,cuda,50 +pytorch,1,3,256,256,2,7,5.310848616063595,2.0857034251093864,5.400367686524987,0.049360096465016795,,10,float32,cuda,50 +cuda_v1,1,3,256,256,2,7,0.5602501425892115,0.5671817343682051,0.6339772138744593,0.4679052802887194,9.479423943596625,10,float32,cuda,50 +cuda_v2,1,3,256,256,2,7,0.42169813998043537,0.4308209754526615,0.4505240358412266,0.6216389761931654,12.593957887293481,10,float32,cuda,50 +cuda_v3,1,3,256,256,2,7,0.38437320850789547,0.3447979688644409,0.5109140183776617,0.6820038290848132,13.816906325703249,10,float32,cuda,50 +cuda_v4,1,3,256,256,2,7,0.3148854896426201,0.31107640825212,0.32421983778476715,0.832505811231635,16.865968076493942,10,float32,cuda,50 +pytorch,1,3,256,256,3,3,10.819459995254874,5.247258115559816,43.626357009634376,0.054515105214001526,,10,float32,cuda,50 
+cuda_v1,1,3,256,256,3,3,0.5695007462054491,0.5678911693394184,0.5740981083363295,1.0356860880867385,18.998148935439186,10,float32,cuda,50
+cuda_v2,1,3,256,256,3,3,0.47695184126496315,0.47551305033266544,0.48432392068207264,1.2366531565025085,22.68459634532429,10,float32,cuda,50
+cuda_v3,1,3,256,256,3,3,0.4476183373481035,0.44674682430922985,0.4503197968006134,1.3176940057781994,24.171172386176853,10,float32,cuda,50
+cuda_v4,1,3,256,256,3,3,0.43853784911334515,0.4363264888525009,0.4481939598917961,1.3449785490409363,24.671667490343484,10,float32,cuda,50
+pytorch,1,3,256,256,3,5,16.53141546063125,6.35837041772902,54.64553306810558,0.03567897748409002,,10,float32,cuda,50
+cuda_v1,1,3,256,256,3,5,0.5723217409104109,0.5713619757443666,0.5795460194349289,1.0305811536387692,28.88482872297353,10,float32,cuda,50
+cuda_v2,1,3,256,256,3,5,0.4708682466298342,0.46868249773979187,0.4833988845348358,1.2526306545866557,35.10836752945706,10,float32,cuda,50
+cuda_v3,1,3,256,256,3,5,0.4524185135960579,0.4476869944483042,0.4720529541373253,1.3037132263040514,36.540094986898204,10,float32,cuda,50
+cuda_v4,1,3,256,256,3,5,0.47651665285229683,0.4608971066772938,0.5228085909038782,1.2377825548582126,34.69220931037514,10,float32,cuda,50
+pytorch,1,3,256,256,3,7,18.527234653010964,7.752042729407549,61.81915830820799,0.03183551193940022,,10,float32,cuda,50
+cuda_v1,1,3,256,256,3,7,0.5787044391036034,0.5703659262508154,0.5973172839730978,1.0192145768116458,32.015020796642105,10,float32,cuda,50
+cuda_v2,1,3,256,256,3,7,0.4701301362365484,0.46895304694771767,0.47616218216717243,1.2545973009125009,39.40873648586716,10,float32,cuda,50
+cuda_v3,1,3,256,256,3,7,0.4477470647543669,0.44500199146568775,0.46136612072587013,1.3173151683832396,41.3787964487829,10,float32,cuda,50
+cuda_v4,1,3,256,256,3,7,0.43689489364624023,0.43615163303911686,0.44089406728744507,1.3500363784924174,42.40661752392263,10,float32,cuda,50
\ No newline at end of file
diff --git a/test/test_speed.py b/test/test_speed.py
index a2fe2e5..a812435 100644
--- a/test/test_speed.py
+++ b/test/test_speed.py
@@ -5,10 +5,9 @@
 import torch
 import torch.nn.functional as F
 
-# Paths: project root + package source directory
 PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
 sys.path.insert(0, str(PROJECT_ROOT))
-PKG = PROJECT_ROOT / "Converse2D/torch_converse2d" # <== single location for the CUDA sources
+PKG = PROJECT_ROOT / "Converse2D/torch_converse2d"
 
 def synchronize():
     if torch.cuda.is_available():
@@ -93,13 +92,25 @@ def call():
     sources = [str(cpp)]
     if cu.exists():
         sources.append(str(cu))
-    os.environ["TORCH_CUDA_ARCH_LIST"] = "7.5" # RTX 4090
+    maj, min = torch.cuda.get_device_capability(0) if device == "cuda" else (0, 0)
+    arch_num = f"{maj}{min}" # e.g. "75", "86", "89"
+    arch_str = f"{maj}.{min}" # e.g. 
"7.5" + os.environ.setdefault("TORCH_CUDA_ARCH_LIST", f"{arch_str}+PTX") + + ext_name = f"converse2d_v{vnum}_sm{arch_num}_ext" + + print(f"[build] compiling {ext_name} for sm_{arch_num} (variant={args.variant}) ...", flush=True) + + extra_cuda = [] + if cu.exists() and device == "cuda": + extra_cuda = ["-O3", f"-gencode=arch=compute_{arch_num},code=sm_{arch_num}"] + load( - name=f"converse2d_v{vnum}_ext", + name=ext_name, sources=sources, verbose=False, extra_cflags=["-O3"], - extra_cuda_cflags=(["-O3","-gencode=arch=compute_75,code=sm_75"] if cu.exists() else []), + extra_cuda_cflags=extra_cuda, ) torch.manual_seed(0) @@ -134,6 +145,7 @@ def parent_main(args): print(f"[Grid] B={Bs} C={Cs} H={Hs} W={Ws} scale={Ss} ksize={Ks}\n") variants = ["pytorch", "cuda_v1", "cuda_v2", "cuda_v3", "cuda_v4"] + results = list() cache_root = PROJECT_ROOT / ".torch_ext_cache_grid" cache_root.mkdir(exist_ok=True) @@ -148,14 +160,35 @@ def parent_main(args): dtype=args.dtype,device=device) base = run_variant_subprocess("pytorch", case, cache_root) base_mean = base["mean_ms"] + results.append({**case,"variant":"pytorch",**base}) print(f"[Case] B{B} C{C} {H}x{W} s{s} k{k}") print(f" PyTorch : {base_mean:.3f} ms") for v in variants[1:]: r = run_variant_subprocess(v, case, cache_root) sp = base_mean / r["mean_ms"] if r["mean_ms"]>0 else None + results.append({**case, "variant":v, **r, "speedup_vs_pytorch": sp}) print(f" {v:8s}: {r['mean_ms']:.3f} ms ({sp:.2f}x vs PyTorch)") print("") + header = ["variant","B","C","H","W","scale","ksize","mean_ms","p50_ms","p90_ms","tp","speedup_vs_pytorch","warmup","dtype","device","iters"] + print("\n=== Summary (normalized to PyTorch) ===") + print(" | ".join(h.rjust(10) for h in header)) + print("-"*120) + for r in results: + line=[] + for h in header: + v = r.get(h,"") + if isinstance(v,float): + line.append(f"{v:10.3f}") + else: + line.append(str(v).rjust(10)) + print(" | ".join(line)) + + if args.csv: + with open(args.csv,"w",newline="") as f: + w=csv.DictWriter(f, fieldnames=header); w.writeheader(); w.writerows(results) + print(f"\n[Saved] {args.csv}") + def main(): p = argparse.ArgumentParser() p.add_argument("--worker", action="store_true", help="internal") @@ -171,15 +204,18 @@ def main(): p.add_argument("--dtype", default="float32", choices=["float16","bfloat16","float32"]) p.add_argument("--device", default="cuda") # grid - p.add_argument("--B_list", default="1,2") + p.add_argument("--B_list", default="1") p.add_argument("--C_list", default="3,8") p.add_argument("--H_list", default="128,256") p.add_argument("--W_list", default="128,256") p.add_argument("--scale_list", default="1,2,3") p.add_argument("--ksize_list", default="3,5,7") + p.add_argument("--csv", default="") args = p.parse_args() - if args.worker: worker_main(args) - else: parent_main(args) + if args.worker: + worker_main(args) + else: + parent_main(args) if __name__ == "__main__": main() From 100223d5c7a4f923558059d6a20c8a10770d6f33 Mon Sep 17 00:00:00 2001 From: Boyce Yi <59084547+Yiozolm@users.noreply.github.com> Date: Sun, 31 Aug 2025 19:27:28 +0800 Subject: [PATCH 13/22] Delete test/results.csv --- test/results.csv | 181 ----------------------------------------------- 1 file changed, 181 deletions(-) delete mode 100644 test/results.csv diff --git a/test/results.csv b/test/results.csv deleted file mode 100644 index ecca0b3..0000000 --- a/test/results.csv +++ /dev/null @@ -1,181 +0,0 @@ -variant,B,C,H,W,scale,ksize,mean_ms,p50_ms,p90_ms,tp,speedup_vs_pytorch,warmup,dtype,device,iters 
-pytorch,1,3,128,128,1,3,1.52592733502388,0.8647029753774405,2.298735734075308,0.010737077463615656,,10,float32,cuda,50 -cuda_v1,1,3,128,128,1,3,0.44544885866343975,0.4432220011949539,0.472044013440609,0.036780877717724675,3.425594892312428,10,float32,cuda,50 -cuda_v2,1,3,128,128,1,3,0.3007301315665245,0.29844650998711586,0.3108557313680649,0.054480739640735645,5.074075308234717,10,float32,cuda,50 -cuda_v3,1,3,128,128,1,3,0.30307079665362835,0.2994614187628031,0.3240731079131365,0.0540599760217902,5.03488739882722,10,float32,cuda,50 -cuda_v4,1,3,128,128,1,3,0.32072700560092926,0.32319221645593643,0.3775870893150568,0.051083942773394356,4.757713907392458,10,float32,cuda,50 -pytorch,1,3,128,128,1,5,4.0221707709133625,0.9404211305081844,7.168814446777105,0.004073422271993561,,10,float32,cuda,50 -cuda_v1,1,3,128,128,1,5,0.4805893823504448,0.4761132877320051,0.5044737830758095,0.0340914730988643,8.369246010475537,10,float32,cuda,50 -cuda_v2,1,3,128,128,1,5,0.3053080663084984,0.3032265231013298,0.3142551053315401,0.053663829449709974,13.174138566156198,10,float32,cuda,50 -cuda_v3,1,3,128,128,1,5,0.2769072540104389,0.2749105915427208,0.28512105345726013,0.05916782519313254,14.525335514546446,10,float32,cuda,50 -cuda_v4,1,3,128,128,1,5,0.27410948649048805,0.27215038426220417,0.2808789722621441,0.05977173650488943,14.673592010296721,10,float32,cuda,50 -pytorch,1,3,128,128,1,7,3.5734746791422367,0.8284670766443014,10.36820076406002,0.004584893268065006,,10,float32,cuda,50 -cuda_v1,1,3,128,128,1,7,0.5126208532601595,0.5032145418226719,0.5624458193778992,0.03196124366732498,6.970989682561095,10,float32,cuda,50 -cuda_v2,1,3,128,128,1,7,0.30147168785333633,0.299196457490325,0.31120297499001026,0.05434672859884173,11.853433748912119,10,float32,cuda,50 -cuda_v3,1,3,128,128,1,7,0.2754225581884384,0.27269101701676846,0.28396081179380417,0.05948677591176249,12.974517057158911,10,float32,cuda,50 -cuda_v4,1,3,128,128,1,7,0.27901765890419483,0.2780151553452015,0.2874089404940605,0.058720297720029645,12.807342349500704,10,float32,cuda,50 -pytorch,1,3,128,128,2,3,5.1218782644718885,1.200301107019186,8.238791720941663,0.012795306060784977,,10,float32,cuda,50 -cuda_v1,1,3,128,128,2,3,0.48394261859357357,0.46004820615053177,0.5499029066413641,0.13542101373600798,10.583647869983038,10,float32,cuda,50 -cuda_v2,1,3,128,128,2,3,0.3488078713417053,0.3472878597676754,0.3549169283360243,0.18788566825603104,14.683952643529377,10,float32,cuda,50 -cuda_v3,1,3,128,128,2,3,0.34828370437026024,0.34455815330147743,0.36369492299854755,0.18816843618479692,14.706051992104753,10,float32,cuda,50 -cuda_v4,1,3,128,128,2,3,0.3102908004075289,0.30851690098643303,0.3184992354363203,0.2112083243007092,16.506703575307196,10,float32,cuda,50 -pytorch,1,3,128,128,2,5,2.7414161060005426,1.1525587178766727,3.5016948357224464,0.02390589296406032,,10,float32,cuda,50 -cuda_v1,1,3,128,128,2,5,0.5049472488462925,0.4589471500366926,0.6366008426994085,0.1297878147662695,5.429113857465611,10,float32,cuda,50 -cuda_v2,1,3,128,128,2,5,0.3760635666549206,0.37418887950479984,0.38898889906704426,0.1742684104789562,7.28976787191962,10,float32,cuda,50 -cuda_v3,1,3,128,128,2,5,0.3826252557337284,0.3617340698838234,0.443447008728981,0.17127985938703158,7.1647547173632296,10,float32,cuda,50 -cuda_v4,1,3,128,128,2,5,0.33498174510896206,0.326463021337986,0.3694041632115841,0.1956405116305147,8.183777611848136,10,float32,cuda,50 -pytorch,1,3,128,128,2,7,6.999819874763489,1.7441920936107635,18.960934737697244,0.009362526632474858,,10,float32,cuda,50 
-cuda_v1,1,3,128,128,2,7,0.4698681924492121,0.46762311831116676,0.4844237584620714,0.13947741314088583,14.897411630007488,10,float32,cuda,50 -cuda_v2,1,3,128,128,2,7,0.37914127111434937,0.36833412013947964,0.40491526015102863,0.1728537750780349,18.462299960618996,10,float32,cuda,50 -cuda_v3,1,3,128,128,2,7,0.44248790480196476,0.44069206342101097,0.4555768799036741,0.14810800315396327,15.819234376352625,10,float32,cuda,50 -cuda_v4,1,3,128,128,2,7,0.3105806838721037,0.3085271455347538,0.3185570705682039,0.2110111909824616,22.53784680841902,10,float32,cuda,50 -pytorch,1,3,128,128,3,3,7.3913106974214315,1.91159313544631,23.464363627135754,0.019949912273535222,,10,float32,cuda,50 -cuda_v1,1,3,128,128,3,3,0.523065309971571,0.5227448418736458,0.5595123395323753,0.281907435245542,14.130760645975844,10,float32,cuda,50 -cuda_v2,1,3,128,128,3,3,0.3566680010408163,0.3553489223122597,0.3626151941716671,0.4134264906571348,20.72322349033937,10,float32,cuda,50 -cuda_v3,1,3,128,128,3,3,0.33742536790668964,0.3273128531873226,0.36468892358243465,0.43700330213695415,21.90502374873426,10,float32,cuda,50 -cuda_v4,1,3,128,128,3,3,0.31496345065534115,0.31304731965065,0.3236026968806982,0.46816860716121145,23.467201296031043,10,float32,cuda,50 -pytorch,1,3,128,128,3,5,3.313328195363283,1.4438305515795946,5.463926354423165,0.044503891949596766,,10,float32,cuda,50 -cuda_v1,1,3,128,128,3,5,0.4630151018500328,0.4609823226928711,0.4760188050568104,0.31846909401188367,7.155982995207888,10,float32,cuda,50 -cuda_v2,1,3,128,128,3,5,0.36386603489518166,0.35751843824982643,0.3770098090171814,0.40524804697002687,9.105901286768216,10,float32,cuda,50 -cuda_v3,1,3,128,128,3,5,0.32810534350574017,0.3216217737644911,0.356891006231308,0.44941663681688954,10.098367066994497,10,float32,cuda,50 -cuda_v4,1,3,128,128,3,5,0.3466991614550352,0.3456580452620983,0.357994856312871,0.4253139793622609,9.556781682014543,10,float32,cuda,50 -pytorch,1,3,128,128,3,7,6.756937270984054,1.7321815248578787,10.7049988117069,0.021822904976965263,,10,float32,cuda,50 -cuda_v1,1,3,128,128,3,7,0.4557925648987293,0.45467307791113853,0.465529877692461,0.32351558879149916,14.824588620670786,10,float32,cuda,50 -cuda_v2,1,3,128,128,3,7,0.38067104294896126,0.3792489878833294,0.39135636761784554,0.3873580686823354,17.75006897987246,10,float32,cuda,50 -cuda_v3,1,3,128,128,3,7,0.32470209524035454,0.3225074615329504,0.3330751322209835,0.454127035708989,20.80965096939815,10,float32,cuda,50 -cuda_v4,1,3,128,128,3,7,0.3528321161866188,0.3392628859728575,0.38030720315873623,0.41792113936138414,19.150573207486012,10,float32,cuda,50 -pytorch,1,3,128,256,1,3,4.7790092043578625,1.1113823857158422,11.791642662137747,0.006856651368262623,,10,float32,cuda,50 -cuda_v1,1,3,128,256,1,3,0.43059052899479866,0.4229459445923567,0.4630208481103182,0.07610014107020878,11.098732746199326,10,float32,cuda,50 -cuda_v2,1,3,128,256,1,3,0.3584872093051672,0.3435080870985985,0.4171540029346943,0.09140632956894645,13.331045237627055,10,float32,cuda,50 -cuda_v3,1,3,128,256,1,3,0.31005123630166054,0.30543701723217964,0.3326671663671732,0.10568575823422545,15.413611186856176,10,float32,cuda,50 -cuda_v4,1,3,128,256,1,3,0.274530379101634,0.2726605162024498,0.28227311559021473,0.11936019651897593,17.407943048039222,10,float32,cuda,50 -pytorch,1,3,128,256,1,5,3.1171874701976776,0.8945895824581385,2.6657020207494497,0.010512040200752509,,10,float32,cuda,50 -cuda_v1,1,3,128,256,1,5,0.48769986256957054,0.4741228185594082,0.5690208170562983,0.06718886453515381,6.391610310845659,10,float32,cuda,50 
-cuda_v2,1,3,128,256,1,5,0.34617194905877113,0.3371278289705515,0.35091196186840534,0.0946581607466896,9.004737323960525,10,float32,cuda,50 -cuda_v3,1,3,128,256,1,5,0.29009729623794556,0.290280906483531,0.31339898705482483,0.11295520649431635,10.745317211232734,10,float32,cuda,50 -cuda_v4,1,3,128,256,1,5,0.27919040992856026,0.2755909226834774,0.2894400618970394,0.11736792824791058,11.165095072553921,10,float32,cuda,50 -pytorch,1,3,128,256,1,7,0.8610220160335302,0.8528372272849083,0.8728299289941788,0.03805709887762491,,10,float32,cuda,50 -cuda_v1,1,3,128,256,1,7,0.4062088765203953,0.4038410261273384,0.4159193020313978,0.08066785807511706,2.119653374906738,10,float32,cuda,50 -cuda_v2,1,3,128,256,1,7,0.3033390734344721,0.29832683503627777,0.3132038749754429,0.10802432943766017,2.8384804050623895,10,float32,cuda,50 -cuda_v3,1,3,128,256,1,7,0.2987572457641363,0.2943666186183691,0.31618126668035984,0.10968102184831952,2.8820121628557662,10,float32,cuda,50 -cuda_v4,1,3,128,256,1,7,0.2784122433513403,0.2752200234681368,0.2866980619728565,0.11769597344412998,3.09261548870525,10,float32,cuda,50 -pytorch,1,3,128,256,2,3,4.812479577958584,1.4287945814430714,9.668499417603016,0.027235855836213175,,10,float32,cuda,50 -cuda_v1,1,3,128,256,2,3,0.4676768183708191,0.4681474529206753,0.48343208618462086,0.2802619134653656,10.290181999448148,10,float32,cuda,50 -cuda_v2,1,3,128,256,2,3,0.3709117043763399,0.3670940641313791,0.3977825865149498,0.35337790221634474,12.974730970138582,10,float32,cuda,50 -cuda_v3,1,3,128,256,2,3,0.40193247608840466,0.41314586997032166,0.4363299813121557,0.32610452699813897,11.973353397051905,10,float32,cuda,50 -cuda_v4,1,3,128,256,2,3,0.34459170885384083,0.34502334892749786,0.3664530348032713,0.3803689892480681,13.965743963966949,10,float32,cuda,50 -pytorch,1,3,128,256,2,5,3.534023268148303,1.3921631034463644,7.826935639604926,0.03708860696570254,,10,float32,cuda,50 -cuda_v1,1,3,128,256,2,5,0.4526952747255564,0.45057223178446293,0.4625048488378525,0.28953692984637747,7.806627251169746,10,float32,cuda,50 -cuda_v2,1,3,128,256,2,5,0.35566513426601887,0.34616305492818356,0.3815658390522003,0.36852642379605594,9.936378148061708,10,float32,cuda,50 -cuda_v3,1,3,128,256,2,5,0.41979328729212284,0.41792611591517925,0.45677535235881805,0.3122298616194654,8.418484466353727,10,float32,cuda,50 -cuda_v4,1,3,128,256,2,5,0.35099745728075504,0.343183521181345,0.3787568770349026,0.37342720661123885,10.068515297880111,10,float32,cuda,50 -pytorch,1,3,128,256,2,7,5.221625966951251,1.6814591363072395,8.201262401416898,0.025101759649117296,,10,float32,cuda,50 -cuda_v1,1,3,128,256,2,7,0.5324110481888056,0.5313355941325426,0.5536912009119987,0.24618572519464088,9.80750866217851,10,float32,cuda,50 -cuda_v2,1,3,128,256,2,7,0.35434636287391186,0.3511281684041023,0.36454498767852783,0.3698979691422422,14.735937811246218,10,float32,cuda,50 -cuda_v3,1,3,128,256,2,7,0.3654781263321638,0.3502380568534136,0.43992577120661736,0.3586315857405801,14.287109380126324,10,float32,cuda,50 -cuda_v4,1,3,128,256,2,7,0.3121230937540531,0.3096370492130518,0.31899111345410347,0.41993688587260436,16.729380399727134,10,float32,cuda,50 -pytorch,1,3,128,256,3,3,10.960625801235437,2.763780066743493,45.61858847737312,0.026906492872583856,,10,float32,cuda,50 -cuda_v1,1,3,128,256,3,3,0.46120816841721535,0.45501673594117165,0.4788396880030632,0.6394336011265492,23.765029658625433,10,float32,cuda,50 -cuda_v2,1,3,128,256,3,3,0.36482485942542553,0.361383892595768,0.3778459504246712,0.8083659662460131,30.04352778617576,10,float32,cuda,50 
-cuda_v3,1,3,128,256,3,3,0.4006952326744795,0.3770939074456692,0.4836510866880417,0.7360007705396968,27.354020980179047,10,float32,cuda,50 -cuda_v4,1,3,128,256,3,3,0.32492942176759243,0.320731895044446,0.33880281262099743,0.9076186403671916,33.732327905581556,10,float32,cuda,50 -pytorch,1,3,128,256,3,5,11.318621216341853,2.605273388326168,45.752703258767724,0.026055470393708847,,10,float32,cuda,50 -cuda_v1,1,3,128,256,3,5,0.4699833784252405,0.4665427841246128,0.48548299819231033,0.6274945318027054,24.083024498157414,10,float32,cuda,50 -cuda_v2,1,3,128,256,3,5,0.4197291377931833,0.42938650585711,0.4759266972541809,0.7026245581866524,26.96648909306595,10,float32,cuda,50 -cuda_v3,1,3,128,256,3,5,0.3752768971025944,0.3709190059453249,0.4119148012250662,0.7858517331520571,30.160719468023988,10,float32,cuda,50 -cuda_v4,1,3,128,256,3,5,0.36198515444993973,0.35343365743756294,0.3866641316562891,0.8147074441440512,31.26819173991056,10,float32,cuda,50 -pytorch,1,3,128,256,3,7,7.481619408354163,2.37233005464077,25.27893357910216,0.039418203988122395,,10,float32,cuda,50 -cuda_v1,1,3,128,256,3,7,0.4637504182755947,0.46098814345896244,0.47804401256144047,0.6359282673999478,16.132857488676233,10,float32,cuda,50 -cuda_v2,1,3,128,256,3,7,0.3909336030483246,0.3640139475464821,0.45265606604516506,0.7543787428361459,19.137826347021225,10,float32,cuda,50 -cuda_v3,1,3,128,256,3,7,0.3465086594223976,0.34476793371140957,0.35570100881159306,0.8510956132859561,21.59143561036953,10,float32,cuda,50 -cuda_v4,1,3,128,256,3,7,0.32119077630341053,0.3183919470757246,0.33296276815235615,0.9181832784681636,23.293381878708452,10,float32,cuda,50 -pytorch,1,3,256,128,1,3,3.82584142498672,1.0650705080479383,7.791366055607796,0.008564913272670139,,10,float32,cuda,50 -cuda_v1,1,3,256,128,1,3,0.48191010020673275,0.4935292527079582,0.54588015191257,0.06799608471775748,7.938911061099336,10,float32,cuda,50 -cuda_v2,1,3,256,128,1,3,0.31497727148234844,0.31047710217535496,0.33008200116455555,0.10403290321802264,12.146404745274209,10,float32,cuda,50 -cuda_v3,1,3,256,128,1,3,0.2761990111321211,0.2733948640525341,0.28623687103390694,0.11863909239097628,13.85175641760937,10,float32,cuda,50 -cuda_v4,1,3,256,128,1,3,0.2851138450205326,0.28066104277968407,0.2989410888403654,0.11492952928203187,13.418644838910577,10,float32,cuda,50 -pytorch,1,3,256,128,1,5,3.586227549239993,0.8654326666146517,11.385623132809997,0.009137178148928202,,10,float32,cuda,50 -cuda_v1,1,3,256,128,1,5,0.423511927947402,0.4099758807569742,0.48163579776883125,0.07737208290404897,8.46783127601871,10,float32,cuda,50 -cuda_v2,1,3,256,128,1,5,0.2990085817873478,0.296260928735137,0.30720722861588,0.10958882786616574,11.993727831499116,10,float32,cuda,50 -cuda_v3,1,3,256,128,1,5,0.275130495429039,0.27071102522313595,0.28432300314307213,0.1190998473248177,13.034642138261058,10,float32,cuda,50 -cuda_v4,1,3,256,128,1,5,0.27801617980003357,0.2752654254436493,0.28759418055415154,0.11786364384824212,12.899348346630148,10,float32,cuda,50 -pytorch,1,3,256,128,1,7,3.53361826390028,1.0452242568135262,5.550267640501261,0.009273214465399514,,10,float32,cuda,50 -cuda_v1,1,3,256,128,1,7,0.4154033772647381,0.412175664678216,0.4310780204832554,0.07888236300764795,8.50647456736562,10,float32,cuda,50 -cuda_v2,1,3,256,128,1,7,0.300332996994257,0.29647164046764374,0.3192121163010597,0.10910556058755874,11.765667772988163,10,float32,cuda,50 -cuda_v3,1,3,256,128,1,7,0.35520353354513645,0.35250792279839516,0.39303875528275967,0.09225133453194125,9.948150652198553,10,float32,cuda,50 
-cuda_v4,1,3,256,128,1,7,0.28812913224101067,0.28413604013621807,0.30301674269139767,0.11372678543518687,12.264008975477436,10,float32,cuda,50 -pytorch,1,3,256,128,2,3,1.6048630606383085,1.2890337966382504,1.585709908977151,0.08167176578160394,,10,float32,cuda,50 -cuda_v1,1,3,256,128,2,3,0.49104482866823673,0.45836716890335083,0.6075259298086166,0.2669247130765648,3.2682618102116248,10,float32,cuda,50 -cuda_v2,1,3,256,128,2,3,0.3567333798855543,0.35033351741731167,0.3854172769933939,0.367422863658147,4.498774578238723,10,float32,cuda,50 -cuda_v3,1,3,256,128,2,3,0.35774463787674904,0.3515880089253187,0.385533319786191,0.36638424765197236,4.486057625247256,10,float32,cuda,50 -cuda_v4,1,3,256,128,2,3,0.3179950825870037,0.307686161249876,0.3502919338643551,0.4121824744385429,5.046817226172725,10,float32,cuda,50 -pytorch,1,3,256,128,2,5,5.680167442187667,1.439184183254838,13.859670702368021,0.02307537609305383,,10,float32,cuda,50 -cuda_v1,1,3,256,128,2,5,0.4627335909754038,0.46338303945958614,0.48728715628385544,0.2832558572713755,12.2752433645769,10,float32,cuda,50 -cuda_v2,1,3,256,128,2,5,0.3555159270763397,0.35281339660286903,0.3653420601040125,0.36868109138709554,15.977251677300995,10,float32,cuda,50 -cuda_v3,1,3,256,128,2,5,0.3210102953016758,0.31748763285577297,0.3332026768475771,0.4083108919507472,17.694658163064886,10,float32,cuda,50 -cuda_v4,1,3,256,128,2,5,0.32647970132529736,0.31356699764728546,0.3721813205629587,0.40147059516390177,17.398225430646516,10,float32,cuda,50 -pytorch,1,3,256,128,2,7,4.92123176343739,1.4630758669227362,18.036476150155067,0.026633982364701436,,10,float32,cuda,50 -cuda_v1,1,3,256,128,2,7,0.44690595008432865,0.4451470449566841,0.45659723691642284,0.29328765923852085,11.011783939123616,10,float32,cuda,50 -cuda_v2,1,3,256,128,2,7,0.35398226231336594,0.351473456248641,0.3692640457302332,0.37027844034729446,13.902481246590899,10,float32,cuda,50 -cuda_v3,1,3,256,128,2,7,0.3210613317787647,0.31872186809778214,0.33104592002928257,0.4082459861292746,15.32801143062749,10,float32,cuda,50 -cuda_v4,1,3,256,128,2,7,0.32810162752866745,0.32733287662267685,0.3445217851549387,0.3994859793511623,14.99910805229823,10,float32,cuda,50 -pytorch,1,3,256,128,3,3,9.371620612218976,2.651255577802658,27.0526010543108,0.03146862343269484,,10,float32,cuda,50 -cuda_v1,1,3,256,128,3,3,0.5212203320115805,0.5533958319574594,0.5914739333093166,0.5658106215116867,17.980151649208416,10,float32,cuda,50 -cuda_v2,1,3,256,128,3,3,0.3939199075102806,0.38670445792376995,0.4216096829622984,0.7486598021002614,23.790675296029292,10,float32,cuda,50 -cuda_v3,1,3,256,128,3,3,0.39537341333925724,0.39467494934797287,0.4045611247420311,0.745907514390568,23.7032139644086,10,float32,cuda,50 -cuda_v4,1,3,256,128,3,3,0.3381696157157421,0.32397685572505,0.3813320305198431,0.8720830798941337,27.712781328339542,10,float32,cuda,50 -pytorch,1,3,256,128,3,5,6.334149120375514,2.5149499997496605,9.2535394243896,0.04655905543040269,,10,float32,cuda,50 -cuda_v1,1,3,256,128,3,5,0.4680374823510647,0.46257232315838337,0.49366913735866547,0.6301033808629732,13.533422768957653,10,float32,cuda,50 -cuda_v2,1,3,256,128,3,5,0.3780175279825926,0.38000987842679024,0.39740419015288353,0.7801543001825579,16.756231263083645,10,float32,cuda,50 -cuda_v3,1,3,256,128,3,5,0.37817317992448807,0.36235409788787365,0.45027188025414944,0.7798331972110945,16.749334581686327,10,float32,cuda,50 -cuda_v4,1,3,256,128,3,5,0.34465146251022816,0.34011295065283775,0.35576531663537025,0.8556818469651727,18.378419387056965,10,float32,cuda,50 
-pytorch,1,3,256,128,3,7,9.138631783425808,3.381723305210471,33.623141143471,0.03227091396054104,,10,float32,cuda,50 -cuda_v1,1,3,256,128,3,7,0.4677951242774725,0.4621578846126795,0.49515804275870323,0.6304298285611738,19.53554303829219,10,float32,cuda,50 -cuda_v2,1,3,256,128,3,7,0.3610655106604099,0.35483832471072674,0.38772691041231155,0.8167825264190665,25.310176446126675,10,float32,cuda,50 -cuda_v3,1,3,256,128,3,7,0.3526437934488058,0.32708211801946163,0.39914101362228394,0.8362886444584856,25.91462533354493,10,float32,cuda,50 -cuda_v4,1,3,256,128,3,7,0.32145872712135315,0.3138268366456032,0.3484130371361971,0.9174179299498951,28.428631772612924,10,float32,cuda,50 -pytorch,1,3,256,256,1,3,4.861515955999494,1.0742205195128918,14.73111561499536,0.013480568734763365,,10,float32,cuda,50 -cuda_v1,1,3,256,256,1,3,0.4849292803555727,0.48372894525527954,0.49910699017345905,0.13514547925822495,10.025206051560351,10,float32,cuda,50 -cuda_v2,1,3,256,256,1,3,0.3039980586618185,0.2988413907587528,0.32487385906279087,0.21558032406024435,15.99193092679473,10,float32,cuda,50 -cuda_v3,1,3,256,256,1,3,0.2766306512057781,0.273675424978137,0.2857776824384928,0.23690794824919645,17.574032142892012,10,float32,cuda,50 -cuda_v4,1,3,256,256,1,3,0.2745348773896694,0.26938505470752716,0.2950329799205065,0.23871648157468708,17.708190675894173,10,float32,cuda,50 -pytorch,1,3,256,256,1,5,1.1926674656569958,1.1049916502088308,1.324015948921442,0.05494909678273036,,10,float32,cuda,50 -cuda_v1,1,3,256,256,1,5,0.4482492245733738,0.419301213696599,0.5197371356189251,0.14620438007979739,2.660723990748964,10,float32,cuda,50 -cuda_v2,1,3,256,256,1,5,0.31799268908798695,0.2957459073513746,0.39233872666954994,0.2060927884473046,3.7506128492375206,10,float32,cuda,50 -cuda_v3,1,3,256,256,1,5,0.27819630689918995,0.2757101319730282,0.28871200047433376,0.23557465852250975,4.287143416642058,10,float32,cuda,50 -cuda_v4,1,3,256,256,1,5,0.27723253704607487,0.2712407149374485,0.302234198898077,0.23639360912787882,4.302047221314356,10,float32,cuda,50 -pytorch,1,3,256,256,1,7,1.1875793617218733,1.0118531063199043,1.2298297137022018,0.05518452249370455,,10,float32,cuda,50 -cuda_v1,1,3,256,256,1,7,0.4127761535346508,0.4072303418070078,0.43675173074007034,0.15876886161859766,2.87705418918339,10,float32,cuda,50 -cuda_v2,1,3,256,256,1,7,0.29986392706632614,0.29469607397913933,0.319720059633255,0.21855246358293792,3.9603942139368957,10,float32,cuda,50 -cuda_v3,1,3,256,256,1,7,0.3060135804116726,0.2915910445153713,0.36104372702538967,0.21416043010848088,3.8808060744371273,10,float32,cuda,50 -cuda_v4,1,3,256,256,1,7,0.2705792896449566,0.2661552280187607,0.29452070593833923,0.242206268210674,4.389025351053911,10,float32,cuda,50 -pytorch,1,3,256,256,2,3,4.6642600279301405,2.254175953567028,5.693626776337624,0.05620269848384325,,10,float32,cuda,50 -cuda_v1,1,3,256,256,2,3,0.47030373476445675,0.4551168531179428,0.5582175217568874,0.5573929795205436,9.917548348337382,10,float32,cuda,50 -cuda_v2,1,3,256,256,2,3,0.3628383856266737,0.35751843824982643,0.3879097755998373,0.722481441833228,12.854924431091538,10,float32,cuda,50 -cuda_v3,1,3,256,256,2,3,0.3875340800732374,0.3856392577290535,0.41215093806385994,0.6764411531250598,12.035741545746568,10,float32,cuda,50 -cuda_v4,1,3,256,256,2,3,0.33060619607567787,0.3230019938200712,0.3539799712598324,0.7929191984653353,14.108205119248465,10,float32,cuda,50 -pytorch,1,3,256,256,2,5,8.932606596499681,2.1693871822208166,25.42668771930039,0.02934686501280862,,10,float32,cuda,50 
-cuda_v1,1,3,256,256,2,5,0.5238902755081654,0.5380609072744846,0.5635851062834263,0.5003795875877337,17.05052949844353,10,float32,cuda,50 -cuda_v2,1,3,256,256,2,5,0.37138083949685097,0.36642886698246,0.3967938479036093,0.7058630174759535,24.052416405223358,10,float32,cuda,50 -cuda_v3,1,3,256,256,2,5,0.3278264496475458,0.3224520478397608,0.3462827764451504,0.7996426166401076,27.24797406098057,10,float32,cuda,50 -cuda_v4,1,3,256,256,2,5,0.32294890843331814,0.3143919166177511,0.3508441150188446,0.811719727655085,27.65950391296668,10,float32,cuda,50 -pytorch,1,3,256,256,2,7,5.310848616063595,2.0857034251093864,5.400367686524987,0.049360096465016795,,10,float32,cuda,50 -cuda_v1,1,3,256,256,2,7,0.5602501425892115,0.5671817343682051,0.6339772138744593,0.4679052802887194,9.479423943596625,10,float32,cuda,50 -cuda_v2,1,3,256,256,2,7,0.42169813998043537,0.4308209754526615,0.4505240358412266,0.6216389761931654,12.593957887293481,10,float32,cuda,50 -cuda_v3,1,3,256,256,2,7,0.38437320850789547,0.3447979688644409,0.5109140183776617,0.6820038290848132,13.816906325703249,10,float32,cuda,50 -cuda_v4,1,3,256,256,2,7,0.3148854896426201,0.31107640825212,0.32421983778476715,0.832505811231635,16.865968076493942,10,float32,cuda,50 -pytorch,1,3,256,256,3,3,10.819459995254874,5.247258115559816,43.626357009634376,0.054515105214001526,,10,float32,cuda,50 -cuda_v1,1,3,256,256,3,3,0.5695007462054491,0.5678911693394184,0.5740981083363295,1.0356860880867385,18.998148935439186,10,float32,cuda,50 -cuda_v2,1,3,256,256,3,3,0.47695184126496315,0.47551305033266544,0.48432392068207264,1.2366531565025085,22.68459634532429,10,float32,cuda,50 -cuda_v3,1,3,256,256,3,3,0.4476183373481035,0.44674682430922985,0.4503197968006134,1.3176940057781994,24.171172386176853,10,float32,cuda,50 -cuda_v4,1,3,256,256,3,3,0.43853784911334515,0.4363264888525009,0.4481939598917961,1.3449785490409363,24.671667490343484,10,float32,cuda,50 -pytorch,1,3,256,256,3,5,16.53141546063125,6.35837041772902,54.64553306810558,0.03567897748409002,,10,float32,cuda,50 -cuda_v1,1,3,256,256,3,5,0.5723217409104109,0.5713619757443666,0.5795460194349289,1.0305811536387692,28.88482872297353,10,float32,cuda,50 -cuda_v2,1,3,256,256,3,5,0.4708682466298342,0.46868249773979187,0.4833988845348358,1.2526306545866557,35.10836752945706,10,float32,cuda,50 -cuda_v3,1,3,256,256,3,5,0.4524185135960579,0.4476869944483042,0.4720529541373253,1.3037132263040514,36.540094986898204,10,float32,cuda,50 -cuda_v4,1,3,256,256,3,5,0.47651665285229683,0.4608971066772938,0.5228085909038782,1.2377825548582126,34.69220931037514,10,float32,cuda,50 -pytorch,1,3,256,256,3,7,18.527234653010964,7.752042729407549,61.81915830820799,0.03183551193940022,,10,float32,cuda,50 -cuda_v1,1,3,256,256,3,7,0.5787044391036034,0.5703659262508154,0.5973172839730978,1.0192145768116458,32.015020796642105,10,float32,cuda,50 -cuda_v2,1,3,256,256,3,7,0.4701301362365484,0.46895304694771767,0.47616218216717243,1.2545973009125009,39.40873648586716,10,float32,cuda,50 -cuda_v3,1,3,256,256,3,7,0.4477470647543669,0.44500199146568775,0.46136612072587013,1.3173151683832396,41.3787964487829,10,float32,cuda,50 -cuda_v4,1,3,256,256,3,7,0.43689489364624023,0.43615163303911686,0.44089406728744507,1.3500363784924174,42.40661752392263,10,float32,cuda,50 \ No newline at end of file From 207e2144a344d81760bb33d5ce5dbe38489e76d1 Mon Sep 17 00:00:00 2001 From: BoyceYi <1473416941@qq.com> Date: Sun, 31 Aug 2025 22:38:56 +0800 Subject: [PATCH 14/22] Update kernel & Add fwd test --- Converse2D/README.md | 14 +- Converse2D/setup.py | 2 +- 
Converse2D/torch_converse2d/converse2d_v5.cpp | 208 +++++
Converse2D/torch_converse2d/converse2d_v5.cu | 277 +++++++++++++++
Converse2D/torch_converse2d/converse2d_v6.cpp | 172 ++++++++++
Converse2D/torch_converse2d/converse2d_v6.cu | 323 ++++++++++++++++++
Converse2D/torch_converse2d/converse2d_v7.cpp | 194 +++++++++++
Converse2D/torch_converse2d/converse2d_v7.cu | 251 ++++++++++++++
test/test_error.py | 198 ++++++++---
test/test_speed.py | 231 ++++++++-----
test/test_v5.csv | 37 ++
test/test_v6.csv | 37 ++
test/test_v7.csv | 73 ++++
13 files changed, 1876 insertions(+), 141 deletions(-)
create mode 100644 Converse2D/torch_converse2d/converse2d_v5.cpp
create mode 100644 Converse2D/torch_converse2d/converse2d_v5.cu
create mode 100644 Converse2D/torch_converse2d/converse2d_v6.cpp
create mode 100644 Converse2D/torch_converse2d/converse2d_v6.cu
create mode 100644 Converse2D/torch_converse2d/converse2d_v7.cpp
create mode 100644 Converse2D/torch_converse2d/converse2d_v7.cu
create mode 100644 test/test_v5.csv
create mode 100644 test/test_v6.csv
create mode 100644 test/test_v7.csv
diff --git a/Converse2D/README.md b/Converse2D/README.md
index 485ba2b..5953226 100644
--- a/Converse2D/README.md
+++ b/Converse2D/README.md
@@ -6,13 +6,17 @@ We offer the following versions of the Converse2D kernel.
- v2: Add FB/F2B cache & broadcast replace repeat **much faster**
- v3: `splits→permute→view→mean` to `block mean CUDA kernel` **fastest**
- v4: STy s-fold upsampler CUDA kernel **fastest**
+- v5: Larger batched FFT CUDA kernel
+- v6: Eliminate redundant calculations of `conj/abs/pow(2)`
+- v7: R2C/C2R (Real FFT) replaces C2C
 
 **Tested Devices**
 - NVIDIA RTX 2080ti
 - NVIDIA RTX 4090
 - NVIDIA RTX 5060ti 16g
 
-Under different circumstances, **v3** and **v4** each have their own performance advantages, but they are both faster than **v1** and **v2**.
+~~Under different circumstances, **v3** and **v4** each have their own performance advantages, but they are both faster than **v1** and **v2**.~~
+**v7** is the fastest.
 
 We highly recommend running `test/test_speed.py` first to choose the most suitable backend for your GPU.
 
@@ -22,7 +26,7 @@
 ```python
 cd ./Converse2D
 # Remember to choose the kernel version you want
-CONVERSE2D_VARIANT={v1,v2,v3,v4} pip install -e .
+pip install . 
**Usage**
```python
import torch
print(torch.ops.converse2d)
```

**TODO**
- [ ] Temporary Tensor Reuse and In-Place Writing
-- [ ] Larger batched FFT
-- [ ] Eliminate redundant calculations of `conj/abs/pow(2)`
+- [x] Larger batched FFT (v5) **Note: not very useful**
+- [x] Eliminate redundant calculations of `conj/abs/pow(2)` (v6) **Note: not very useful**
- [ ] The minimal necessary policy for `contiguous()`
-- [ ] R2C/C2R (Real FFT) replaces C2C **(Optional)**
+- [x] R2C/C2R (Real FFT) replaces C2C (v7) **(Optional)**
- [ ] Mixed precision **(Optional)**
- [ ] Adaptive padding **(Optional)**
\ No newline at end of file
diff --git a/Converse2D/setup.py b/Converse2D/setup.py
index 45b6837..1fd7114 100644
--- a/Converse2D/setup.py
+++ b/Converse2D/setup.py
@@ -20,7 +20,7 @@ for idx in reversed(to_remove):
sys.argv.pop(idx)

-if variant not in {"", "v1","v2","v3","v4"}:
+if variant not in {"", "v1", "v2", "v3", "v4", "v5", "v6", "v7"}:
raise SystemExit(f"[setup.py] invalid --variant={variant!r}; pick from v1|v2|v3|v4|v5|v6|v7")

if not variant:
diff --git a/Converse2D/torch_converse2d/converse2d_v5.cpp b/Converse2D/torch_converse2d/converse2d_v5.cpp
new file mode 100644
index 0000000..fefd0e9
--- /dev/null
+++ b/Converse2D/torch_converse2d/converse2d_v5.cpp
@@ -0,0 +1,208 @@
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include <vector>
+#include <unordered_map>
+#include <list>
+#include <mutex>
+#include <tuple>
+#include <utility>
+
+using at::Tensor;
+
+Tensor block_mean_cuda(const Tensor &input, int64_t s);
+Tensor sfold_upsample_cuda_launcher(const Tensor &x, int64_t scale);
+
+// ---------- FB Cache ----------
+struct FBKey
+{
+    int64_t device_id;
+    at::ScalarType dtype;
+    int64_t channels;
+    int64_t H, W;
+    void *ptr;
+
+    bool operator==(const FBKey &other) const
+    {
+        return device_id == other.device_id && dtype == other.dtype &&
+               channels == other.channels && H == other.H && W == other.W &&
+               ptr == other.ptr;
+    }
+};
+
+namespace std
+{
+    template <>
+    struct hash<FBKey>
+    {
+        size_t operator()(const FBKey &k) const
+        {
+            return ((hash<int64_t>()(k.device_id) ^ hash<int64_t>()(k.channels)) << 1) ^
+                   ((hash<int64_t>()(k.H) ^ hash<int64_t>()(k.W)) << 1) ^
+                   ((hash<void *>()(k.ptr)) ^ hash<int>()(static_cast<int>(k.dtype)));
+        }
+    };
+} // namespace std
+
+constexpr size_t FB_CACHE_MAX_SIZE = 64;
+static std::unordered_map<FBKey, std::pair<Tensor, Tensor>> fb_cache;
+static std::list<FBKey> fb_cache_lru;
+static std::mutex fb_cache_mutex;
+
+inline Tensor fft2_auto_batched(const Tensor &x)
+{
+    TORCH_CHECK(x.dim() == 4, "Expected input of shape (B,C,H,W)");
+    const int64_t B = x.size(0), C = x.size(1), H = x.size(2), W = x.size(3);
+
+    if (B * C >= 8)
+    {
+        auto x_reshaped = x.view({B * C, H, W}).contiguous();
+        auto fx = at::fft_fftn(x_reshaped, c10::nullopt, {-2, -1}, c10::nullopt);
+        return fx.view({B, C, H, W});
+    }
+    else
+    {
+        return at::fft_fftn(x, c10::nullopt, {-2, -1}, c10::nullopt);
+    }
+}
+
+inline Tensor ifft2_auto_batched(const Tensor &x)
+{
+    TORCH_CHECK(x.dim() == 4, "Expected input of shape (B,C,H,W)");
+    const int64_t B = x.size(0), C = x.size(1), H = x.size(2), W = x.size(3);
+
+    if (B * C >= 8)
+    {
+        auto x_reshaped = x.view({B * C, H, W}).contiguous();
+        auto fx = at::fft_ifftn(x_reshaped, c10::nullopt, {-2, -1}, c10::nullopt);
+        return fx.view({B, C, H, W});
+    }
+    else
+    {
+        return at::fft_ifftn(x, c10::nullopt, {-2, -1}, c10::nullopt);
+    }
+}
+
+static inline std::pair<Tensor, Tensor> p2o_cached(const Tensor &psf, int64_t H, int64_t W)
+{
+    auto C = psf.size(1);
+    FBKey key{psf.device().index(), psf.scalar_type(), C, H, W, psf.data_ptr()};
+
+    {
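+        // Fast path: probe the cache under the mutex and refresh the entry's
+        // LRU position on a hit. The key includes the weight's raw data pointer,
+        // so a reallocated or freshly updated kernel misses and is recomputed.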
std::lock_guard lock(fb_cache_mutex); + auto it = fb_cache.find(key); + if (it != fb_cache.end()) + { + fb_cache_lru.remove(key); + fb_cache_lru.push_front(key); + return it->second; + } + } + + Tensor otf = at::zeros({1, C, H, W}, psf.options()); + int64_t kh = psf.size(2), kw = psf.size(3); + otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); + otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1}); + Tensor FB = fft2_auto_batched(otf); + Tensor F2B = at::abs(FB).pow(2); + + { + std::lock_guard lock(fb_cache_mutex); + fb_cache[key] = {FB, F2B}; + fb_cache_lru.push_front(key); + + if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) + { + fb_cache.erase(fb_cache_lru.back()); + fb_cache_lru.pop_back(); + } + } + + return {FB, F2B}; +} + +static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) +{ + if (s == 1) + return x; + return sfold_upsample_cuda_launcher(x, s); +} + + + + +// ---------- Forward ---------- +Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) +{ + TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); + TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); + TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); + TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); + TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); + + x = x.contiguous(); + x0 = x0.contiguous(); + weight = weight.contiguous(); + bias = bias.contiguous(); + + const int64_t B = x.size(0); + const int64_t C = x.size(1); + const int64_t H = x.size(2); + const int64_t W = x.size(3); + const int64_t Hs = H * scale; + const int64_t Ws = W * scale; + + Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; + Tensor STy = sfold_upsample_zero_insertion(x, scale); + + auto [FB, F2B] = p2o_cached(weight, Hs, Ws); + Tensor FBC = at::conj_physical(FB); + + Tensor F_STy = fft2_auto_batched(STy); + Tensor FB_Fy = FBC * F_STy; + Tensor FR = FB_Fy + fft2_auto_batched(lambda_ * x0); + + Tensor x1 = FB * FR; + + Tensor FBR = block_mean_cuda(x1, scale); // (B,C,H,W) + Tensor invW = block_mean_cuda(F2B, scale); // (B,C,H,W) + + Tensor invW_plus = invW + lambda_; + Tensor invWBR = FBR / invW_plus; + + Tensor invWBR_exp = invWBR.view({B, C, H, 1, W, 1}) + .expand({B, C, H, scale, W, scale}) + .reshape({B, C, Hs, Ws}); + Tensor FCBinvWBR = FBC * invWBR_exp; + + Tensor FX = (FR - FCBinvWBR) / lambda_; + Tensor out_c = ifft2_auto_batched(FX); + Tensor out = at::real(out_c); + return out; +} + +void clear_fb_cache() +{ + std::lock_guard lock(fb_cache_mutex); + fb_cache.clear(); + fb_cache_lru.clear(); +} + +TORCH_LIBRARY(converse2d, m) +{ + m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); + m.def("clear_cache() -> ()"); +} +TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) +{ + m.impl("forward", TORCH_FN(converse2d_forward)); + m.impl("clear_cache", TORCH_FN(clear_fb_cache)); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/Converse2D/torch_converse2d/converse2d_v5.cu b/Converse2D/torch_converse2d/converse2d_v5.cu new file mode 100644 index 0000000..e02250c --- /dev/null +++ b/Converse2D/torch_converse2d/converse2d_v5.cu @@ -0,0 +1,277 @@ +#include +#include +#include +#include +#include +#include +#include + +// 
====================================================================== +// S-FOLD UPSAMPLE (zero-insertion upsample) +// forward: out[b,c,h*s, w*s] = x[b,c,h,w]; others = 0 +// backward: grad_x[b,c,h,w] = grad_out[b,c,h*s, w*s] +// dtypes: float/double/half/bfloat16 +// ====================================================================== + +using namespace at; +using namespace at::indexing; + +template +__global__ void sfold_upsample_kernel( + const scalar_t *__restrict__ x, + scalar_t *__restrict__ out, + int B, int C, int H, int W, int s, + int Hs, int Ws, long long total_in) +{ + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_in) + return; + + int w = static_cast(idx % W); + int h = static_cast((idx / W) % H); + int c = static_cast((idx / (1LL * W * H)) % C); + int b = static_cast(idx / (1LL * W * H * C)); + + long long in_off = ((long long)b * C + c) * H * W + (long long)h * W + w; + long long out_off = ((long long)b * C + c) * Hs * Ws + (long long)(h * s) * Ws + (w * s); + + out[out_off] = x[in_off]; +} + +template +__global__ void sfold_downsample_grad_kernel( // backward of zero-insertion upsample + const scalar_t *__restrict__ grad_out, // (B,C,Hs,Ws) + scalar_t *__restrict__ grad_in, // (B,C,H,W) + int B, int C, int H, int W, int s, int Hs, int Ws, long long total_in) +{ + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_in) + return; + + int w = static_cast(idx % W); + int h = static_cast((idx / W) % H); + int c = static_cast((idx / (1LL * W * H)) % C); + int b = static_cast(idx / (1LL * W * H * C)); + + long long in_off = ((long long)b * C + c) * H * W + (long long)h * W + w; + long long out_off = ((long long)b * C + c) * Hs * Ws + (long long)(h * s) * Ws + (w * s); + + grad_in[in_off] = grad_out[out_off]; +} + +struct SFoldFunction : public torch::autograd::Function +{ + static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &x, int64_t scale) + { + TORCH_CHECK(x.is_cuda() && x.dim() == 4, "sfold: x must be (B,C,H,W) CUDA"); + TORCH_CHECK(scale >= 1, "sfold: scale must be >= 1"); + if (scale == 1) + { + ctx->saved_data["s"] = (int64_t)1; + return x; + } + + auto x_ = x.contiguous(); + const int B = (int)x_.size(0), C = (int)x_.size(1), H = (int)x_.size(2), W = (int)x_.size(3); + const int s = (int)scale, Hs = H * s, Ws = W * s; + + auto out = at::zeros({B, C, Hs, Ws}, x_.options()); + + const long long total = 1LL * B * C * H * W; + const int threads = 256, blocks = (int)((total + threads - 1) / threads); + auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x_.scalar_type(), "sfold_fwd", [&] + { sfold_upsample_kernel<<>>( + x_.data_ptr(), out.data_ptr(), + B, C, H, W, s, Hs, Ws, total); }); + + // save for backward + ctx->saved_data["B"] = (int64_t)B; + ctx->saved_data["C"] = (int64_t)C; + ctx->saved_data["H"] = (int64_t)H; + ctx->saved_data["W"] = (int64_t)W; + ctx->saved_data["s"] = (int64_t)s; + return out; + } + + static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, + torch::autograd::tensor_list grad_outputs) + { + auto go = grad_outputs[0]; // (B,C,Hs,Ws) + const int B = (int)ctx->saved_data["B"].toInt(); + const int C = (int)ctx->saved_data["C"].toInt(); + const int H = (int)ctx->saved_data["H"].toInt(); + const int W = (int)ctx->saved_data["W"].toInt(); + const int s = (int)ctx->saved_data["s"].toInt(); + const int Hs = H * s, Ws = W * s; + + at::Tensor gx; + if (s == 1) + { + gx = go; // identity + } 
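+        // For s > 1 the forward pass scattered x to every s-th row/column, so
+        // the gradient is simply the strided slice of grad_out at those positions.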
+ else + { + gx = go.index({Slice(), Slice(), Slice(0, Hs, s), Slice(0, Ws, s)}).contiguous(); + } + return {gx, torch::Tensor()}; // no grad for scale + } +}; + +// exposed symbol for v4.cpp +at::Tensor sfold_upsample_cuda_launcher(const at::Tensor &x, int64_t scale) +{ + return SFoldFunction::apply(x, scale); +} + +// ====================================================================== +// BLOCK MEAN over non-overlapping s×s tiles +// forward: out[b,c,ho,wo] = mean_{i,j in s×s} in[b,c, ho*s+i, wo*s+j] +// backward: grad_in[b,c,hi,wi] = grad_out[b,c,hi/s, wi/s] / (s*s) +// dtypes: float/double/half/bfloat16 + complex64/complex128 +// ====================================================================== + +template +struct AccT +{ + using type = T; +}; +template <> +struct AccT +{ + using type = float; +}; +template <> +struct AccT +{ + using type = float; +}; + +template +__global__ void block_mean_kernel( + const scalar_t *__restrict__ in, // (B,C,Hs,Ws) + scalar_t *__restrict__ out, // (B,C,Ho,Wo) + int B, int C, int Ho, int Wo, int s, int Hs, int Ws, + long long total_out) +{ + using acc_t = typename AccT::type; + + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_out) + return; + + int wo = static_cast(idx % Wo); + int ho = static_cast((idx / Wo) % Ho); + int c = static_cast((idx / (1LL * Wo * Ho)) % C); + int b = static_cast(idx / (1LL * Wo * Ho * C)); + + const int hi0 = ho * s; + const int wi0 = wo * s; + + const long long base_in = ((long long)b * C + c) * Hs * Ws; + + acc_t acc = acc_t(0); + for (int di = 0; di < s; ++di) + { + const int hi = hi0 + di; + const long long row_off = base_in + (long long)hi * Ws + wi0; +#pragma unroll + for (int dj = 0; dj < s; ++dj) + { + acc += static_cast(in[row_off + dj]); + } + } + const float inv_area = 1.0f / (s * s); + acc = acc * static_cast(inv_area); + + const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo; + out[out_off] = static_cast(acc); +} + +template +__global__ void block_mean_grad_kernel( + const scalar_t *__restrict__ grad_out, // (B,C,Ho,Wo) + scalar_t *__restrict__ grad_in, // (B,C,Hs,Ws) + int B, int C, int Ho, int Wo, int s, int Hs, int Ws, + long long total_in) +{ + using acc_t = typename AccT::type; + + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_in) + return; + + int wi = static_cast(idx % Ws); + int hi = static_cast((idx / Ws) % Hs); + int c = static_cast((idx / (1LL * Ws * Hs)) % C); + int b = static_cast(idx / (1LL * Ws * Hs * C)); + + const int ho = hi / s; + const int wo = wi / s; + + const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo; + acc_t g = static_cast(grad_out[out_off]) * static_cast(1.0f / (s * s)); + + const long long in_off = ((long long)b * C + c) * Hs * Ws + (long long)hi * Ws + wi; + grad_in[in_off] = static_cast(g); +} + +struct BlockMeanFunction : public torch::autograd::Function +{ + static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &input, int64_t s) + { + TORCH_CHECK(input.is_cuda() && input.dim() == 4, "block_mean: input must be (B,C,Hs,Ws) CUDA"); + TORCH_CHECK(s >= 1, "block_mean: s must be >= 1"); + + auto x = input.contiguous(); + const int B = (int)x.size(0), C = (int)x.size(1), Hs = (int)x.size(2), Ws = (int)x.size(3); + TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "block_mean: H,W must be divisible by s"); + const int Ho = Hs / (int)s, Wo = Ws / (int)s; + + auto out = at::empty({B, C, Ho, Wo}, x.options()); + + const long long total_out = 1LL * 
B * C * Ho * Wo; + const int threads = 256, blocks = (int)((total_out + threads - 1) / threads); + auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "block_mean_fwd", [&] + { block_mean_kernel<<>>( + x.data_ptr(), out.data_ptr(), + B, C, Ho, Wo, (int)s, Hs, Ws, total_out); }); + + // save for backward + ctx->saved_data["B"] = (int64_t)B; + ctx->saved_data["C"] = (int64_t)C; + ctx->saved_data["Hs"] = (int64_t)Hs; + ctx->saved_data["Ws"] = (int64_t)Ws; + ctx->saved_data["s"] = (int64_t)s; + return out; + } + + static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, + torch::autograd::tensor_list grad_outputs) + { + auto go = grad_outputs[0]; // (B,C,Ho,Wo) + const int B = (int)ctx->saved_data["B"].toInt(); + const int C = (int)ctx->saved_data["C"].toInt(); + const int Hs = (int)ctx->saved_data["Hs"].toInt(); + const int Ws = (int)ctx->saved_data["Ws"].toInt(); + const int s = (int)ctx->saved_data["s"].toInt(); + const int Ho = Hs / s, Wo = Ws / s; + + auto go_scaled = go / static_cast(s * s); + auto gi = go_scaled.view({B, C, Ho, 1, Wo, 1}) + .expand({B, C, Ho, s, Wo, s}) + .reshape({B, C, Hs, Ws}) + .contiguous(); + + return {gi, torch::Tensor()}; // no grad for s + } +}; + +// exposed symbol for v4.cpp +at::Tensor block_mean_cuda(const at::Tensor &input, int64_t s) +{ + return BlockMeanFunction::apply(input, s); +} diff --git a/Converse2D/torch_converse2d/converse2d_v6.cpp b/Converse2D/torch_converse2d/converse2d_v6.cpp new file mode 100644 index 0000000..ab8061e --- /dev/null +++ b/Converse2D/torch_converse2d/converse2d_v6.cpp @@ -0,0 +1,172 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +using at::Tensor; + +Tensor block_mean_cuda(const Tensor &input, int64_t s); +Tensor sfold_upsample_cuda_launcher(const Tensor &x, int64_t scale); +std::tuple fb_postprocess_cuda(const at::Tensor& FB); + +// ---------- FB Cache ---------- +struct FBKey +{ + int64_t device_id; + at::ScalarType dtype; + int64_t channels; + int64_t H, W; + void *ptr; + + bool operator==(const FBKey &other) const + { + return device_id == other.device_id && dtype == other.dtype && + channels == other.channels && H == other.H && W == other.W && + ptr == other.ptr; + } +}; + +namespace std +{ + template <> + struct hash + { + size_t operator()(const FBKey &k) const + { + return ((hash()(k.device_id) ^ hash()(k.channels)) << 1) ^ + ((hash()(k.H) ^ hash()(k.W)) << 1) ^ + ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); + } + }; +} // namespace std + +constexpr size_t FB_CACHE_MAX_SIZE = 64; +static std::unordered_map> fb_cache; +static std::list fb_cache_lru; +static std::mutex fb_cache_mutex; + +static inline std::tuple p2o_cached(const Tensor &psf, int64_t H, int64_t W) +{ + auto C = psf.size(1); + FBKey key{psf.device().index(), psf.scalar_type(), C, H, W, psf.data_ptr()}; + + { + std::lock_guard lock(fb_cache_mutex); + auto it = fb_cache.find(key); + if (it != fb_cache.end()) + { + fb_cache_lru.remove(key); + fb_cache_lru.push_front(key); + return it->second; + } + } + + Tensor otf = at::zeros({1, C, H, W}, psf.options()); + int64_t kh = psf.size(2), kw = psf.size(3); + otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); + otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1}); + Tensor FB = at::fft_fftn(otf, c10::nullopt, {-2, -1}, 
c10::nullopt); + auto [FBC, F2B] = fb_postprocess_cuda(FB); + + { + std::lock_guard lock(fb_cache_mutex); + fb_cache[key] = std::make_tuple(FB, FBC, F2B); + fb_cache_lru.push_front(key); + + if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) + { + fb_cache.erase(fb_cache_lru.back()); + fb_cache_lru.pop_back(); + } + } + + return std::make_tuple(FB, FBC, F2B); +} + +static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) +{ + if (s == 1) + return x; + return sfold_upsample_cuda_launcher(x, s); +} + +// ---------- Forward ---------- +Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) +{ + TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); + TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); + TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); + TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); + TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); + + x = x.contiguous(); + x0 = x0.contiguous(); + weight = weight.contiguous(); + bias = bias.contiguous(); + + const int64_t B = x.size(0); + const int64_t C = x.size(1); + const int64_t H = x.size(2); + const int64_t W = x.size(3); + const int64_t Hs = H * scale; + const int64_t Ws = W * scale; + + Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; + Tensor STy = sfold_upsample_zero_insertion(x, scale); + + auto [FB, FBC, F2B] = p2o_cached(weight, Hs, Ws); + + Tensor F_STy = at::fft_fftn(STy, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor FBFy = FBC * F_STy; + Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt); + + Tensor x1 = FB * FR; + + Tensor FBR = block_mean_cuda(x1, scale); // (B,C,H,W) + Tensor invW = block_mean_cuda(F2B, scale); // (B,C,H,W) + + Tensor invW_plus = invW + lambda_; + Tensor invWBR = FBR / invW_plus; + + Tensor invWBR_exp = invWBR.view({B, C, H, 1, W, 1}) + .expand({B, C, H, scale, W, scale}) + .reshape({B, C, Hs, Ws}); + Tensor FCBinvWBR = FBC * invWBR_exp; + + Tensor FX = (FR - FCBinvWBR) / lambda_; + Tensor out_c = at::fft_ifftn(FX, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor out = at::real(out_c); + return out; +} + +void clear_fb_cache() +{ + std::lock_guard lock(fb_cache_mutex); + fb_cache.clear(); + fb_cache_lru.clear(); +} + +TORCH_LIBRARY(converse2d, m) +{ + m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); + m.def("clear_cache() -> ()"); +} +TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) +{ + m.impl("forward", TORCH_FN(converse2d_forward)); + m.impl("clear_cache", TORCH_FN(clear_fb_cache)); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/Converse2D/torch_converse2d/converse2d_v6.cu b/Converse2D/torch_converse2d/converse2d_v6.cu new file mode 100644 index 0000000..4830d49 --- /dev/null +++ b/Converse2D/torch_converse2d/converse2d_v6.cu @@ -0,0 +1,323 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// ====================================================================== +// S-FOLD UPSAMPLE (zero-insertion upsample) +// forward: out[b,c,h*s, w*s] = x[b,c,h,w]; others = 0 +// backward: grad_x[b,c,h,w] = grad_out[b,c,h*s, w*s] +// dtypes: float/double/half/bfloat16 +// ====================================================================== + +using namespace at; +using namespace at::indexing; + +template 
+__global__ void sfold_upsample_kernel( + const scalar_t *__restrict__ x, + scalar_t *__restrict__ out, + int B, int C, int H, int W, int s, + int Hs, int Ws, long long total_in) +{ + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_in) + return; + + int w = static_cast(idx % W); + int h = static_cast((idx / W) % H); + int c = static_cast((idx / (1LL * W * H)) % C); + int b = static_cast(idx / (1LL * W * H * C)); + + long long in_off = ((long long)b * C + c) * H * W + (long long)h * W + w; + long long out_off = ((long long)b * C + c) * Hs * Ws + (long long)(h * s) * Ws + (w * s); + + out[out_off] = x[in_off]; +} + +template +__global__ void sfold_downsample_grad_kernel( // backward of zero-insertion upsample + const scalar_t *__restrict__ grad_out, // (B,C,Hs,Ws) + scalar_t *__restrict__ grad_in, // (B,C,H,W) + int B, int C, int H, int W, int s, int Hs, int Ws, long long total_in) +{ + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_in) + return; + + int w = static_cast(idx % W); + int h = static_cast((idx / W) % H); + int c = static_cast((idx / (1LL * W * H)) % C); + int b = static_cast(idx / (1LL * W * H * C)); + + long long in_off = ((long long)b * C + c) * H * W + (long long)h * W + w; + long long out_off = ((long long)b * C + c) * Hs * Ws + (long long)(h * s) * Ws + (w * s); + + grad_in[in_off] = grad_out[out_off]; +} + +struct SFoldFunction : public torch::autograd::Function +{ + static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &x, int64_t scale) + { + TORCH_CHECK(x.is_cuda() && x.dim() == 4, "sfold: x must be (B,C,H,W) CUDA"); + TORCH_CHECK(scale >= 1, "sfold: scale must be >= 1"); + if (scale == 1) + { + ctx->saved_data["s"] = (int64_t)1; + return x; + } + + auto x_ = x.contiguous(); + const int B = (int)x_.size(0), C = (int)x_.size(1), H = (int)x_.size(2), W = (int)x_.size(3); + const int s = (int)scale, Hs = H * s, Ws = W * s; + + auto out = at::zeros({B, C, Hs, Ws}, x_.options()); + + const long long total = 1LL * B * C * H * W; + const int threads = 256, blocks = (int)((total + threads - 1) / threads); + auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x_.scalar_type(), "sfold_fwd", [&] + { sfold_upsample_kernel<<>>( + x_.data_ptr(), out.data_ptr(), + B, C, H, W, s, Hs, Ws, total); }); + + // save for backward + ctx->saved_data["B"] = (int64_t)B; + ctx->saved_data["C"] = (int64_t)C; + ctx->saved_data["H"] = (int64_t)H; + ctx->saved_data["W"] = (int64_t)W; + ctx->saved_data["s"] = (int64_t)s; + return out; + } + + static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, + torch::autograd::tensor_list grad_outputs) + { + auto go = grad_outputs[0]; // (B,C,Hs,Ws) + const int B = (int)ctx->saved_data["B"].toInt(); + const int C = (int)ctx->saved_data["C"].toInt(); + const int H = (int)ctx->saved_data["H"].toInt(); + const int W = (int)ctx->saved_data["W"].toInt(); + const int s = (int)ctx->saved_data["s"].toInt(); + const int Hs = H * s, Ws = W * s; + + at::Tensor gx; + if (s == 1) + { + gx = go; // identity + } + else + { + gx = go.index({Slice(), Slice(), Slice(0, Hs, s), Slice(0, Ws, s)}).contiguous(); + } + return {gx, torch::Tensor()}; // no grad for scale + } +}; + +// exposed symbol for v4.cpp +at::Tensor sfold_upsample_cuda_launcher(const at::Tensor &x, int64_t scale) +{ + return SFoldFunction::apply(x, scale); +} + +// ====================================================================== +// BLOCK MEAN 
over non-overlapping s×s tiles +// forward: out[b,c,ho,wo] = mean_{i,j in s×s} in[b,c, ho*s+i, wo*s+j] +// backward: grad_in[b,c,hi,wi] = grad_out[b,c,hi/s, wi/s] / (s*s) +// dtypes: float/double/half/bfloat16 + complex64/complex128 +// ====================================================================== + +template +struct AccT +{ + using type = T; +}; +template <> +struct AccT +{ + using type = float; +}; +template <> +struct AccT +{ + using type = float; +}; + +template +__global__ void block_mean_kernel( + const scalar_t *__restrict__ in, // (B,C,Hs,Ws) + scalar_t *__restrict__ out, // (B,C,Ho,Wo) + int B, int C, int Ho, int Wo, int s, int Hs, int Ws, + long long total_out) +{ + using acc_t = typename AccT::type; + + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_out) + return; + + int wo = static_cast(idx % Wo); + int ho = static_cast((idx / Wo) % Ho); + int c = static_cast((idx / (1LL * Wo * Ho)) % C); + int b = static_cast(idx / (1LL * Wo * Ho * C)); + + const int hi0 = ho * s; + const int wi0 = wo * s; + + const long long base_in = ((long long)b * C + c) * Hs * Ws; + + acc_t acc = acc_t(0); + for (int di = 0; di < s; ++di) + { + const int hi = hi0 + di; + const long long row_off = base_in + (long long)hi * Ws + wi0; +#pragma unroll + for (int dj = 0; dj < s; ++dj) + { + acc += static_cast(in[row_off + dj]); + } + } + const float inv_area = 1.0f / (s * s); + acc = acc * static_cast(inv_area); + + const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo; + out[out_off] = static_cast(acc); +} + +template +__global__ void block_mean_grad_kernel( + const scalar_t *__restrict__ grad_out, // (B,C,Ho,Wo) + scalar_t *__restrict__ grad_in, // (B,C,Hs,Ws) + int B, int C, int Ho, int Wo, int s, int Hs, int Ws, + long long total_in) +{ + using acc_t = typename AccT::type; + + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_in) + return; + + int wi = static_cast(idx % Ws); + int hi = static_cast((idx / Ws) % Hs); + int c = static_cast((idx / (1LL * Ws * Hs)) % C); + int b = static_cast(idx / (1LL * Ws * Hs * C)); + + const int ho = hi / s; + const int wo = wi / s; + + const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo; + acc_t g = static_cast(grad_out[out_off]) * static_cast(1.0f / (s * s)); + + const long long in_off = ((long long)b * C + c) * Hs * Ws + (long long)hi * Ws + wi; + grad_in[in_off] = static_cast(g); +} + +struct BlockMeanFunction : public torch::autograd::Function +{ + static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &input, int64_t s) + { + TORCH_CHECK(input.is_cuda() && input.dim() == 4, "block_mean: input must be (B,C,Hs,Ws) CUDA"); + TORCH_CHECK(s >= 1, "block_mean: s must be >= 1"); + + auto x = input.contiguous(); + const int B = (int)x.size(0), C = (int)x.size(1), Hs = (int)x.size(2), Ws = (int)x.size(3); + TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "block_mean: H,W must be divisible by s"); + const int Ho = Hs / (int)s, Wo = Ws / (int)s; + + auto out = at::empty({B, C, Ho, Wo}, x.options()); + + const long long total_out = 1LL * B * C * Ho * Wo; + const int threads = 256, blocks = (int)((total_out + threads - 1) / threads); + auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "block_mean_fwd", [&] + { block_mean_kernel<<>>( + x.data_ptr(), out.data_ptr(), + B, C, Ho, Wo, (int)s, Hs, Ws, total_out); }); + + // save for backward + 
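+        // Only integer shape metadata is saved (no tensors), so the backward
+        // pass can rebuild the gradient as a broadcasted expand of grad_out
+        // without retaining any forward activations.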
ctx->saved_data["B"] = (int64_t)B; + ctx->saved_data["C"] = (int64_t)C; + ctx->saved_data["Hs"] = (int64_t)Hs; + ctx->saved_data["Ws"] = (int64_t)Ws; + ctx->saved_data["s"] = (int64_t)s; + return out; + } + + static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, + torch::autograd::tensor_list grad_outputs) + { + auto go = grad_outputs[0]; // (B,C,Ho,Wo) + const int B = (int)ctx->saved_data["B"].toInt(); + const int C = (int)ctx->saved_data["C"].toInt(); + const int Hs = (int)ctx->saved_data["Hs"].toInt(); + const int Ws = (int)ctx->saved_data["Ws"].toInt(); + const int s = (int)ctx->saved_data["s"].toInt(); + const int Ho = Hs / s, Wo = Ws / s; + + auto go_scaled = go / static_cast(s * s); + auto gi = go_scaled.view({B, C, Ho, 1, Wo, 1}) + .expand({B, C, Ho, s, Wo, s}) + .reshape({B, C, Hs, Ws}) + .contiguous(); + + return {gi, torch::Tensor()}; // no grad for s + } +}; + +at::Tensor block_mean_cuda(const at::Tensor &input, int64_t s) +{ + return BlockMeanFunction::apply(input, s); +} + +template +__global__ void fb_postprocess_kernel( + const thrust::complex* __restrict__ FB, + thrust::complex* __restrict__ FBC, + real_t* __restrict__ F2B, + int64_t N +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) return; + + thrust::complex val = FB[idx]; + real_t re = val.real(); + real_t im = val.imag(); + F2B[idx] = re * re + im * im; + FBC[idx] = thrust::complex(re, -im); +} + +std::tuple fb_postprocess_cuda(const at::Tensor& FB) { + TORCH_CHECK(FB.is_cuda(), "FB must be CUDA tensor"); + TORCH_CHECK(FB.is_complex(),"FB must be complex"); + + auto FBc = FB.contiguous(); + const auto N = FBc.numel(); + + at::Tensor FBC = at::empty_like(FBc); + at::Tensor F2B = at::empty(FBc.sizes(), + FBc.scalar_type() == at::kComplexFloat ? 
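+                                // |FB|^2 is real-valued, so F2B gets the real dtype matching FB's complex precision: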
FBc.options().dtype(at::kFloat) + : FBc.options().dtype(at::kDouble)); + + constexpr int threads = 256; + const int blocks = (static_cast(N) + threads - 1) / threads; + + AT_DISPATCH_COMPLEX_TYPES(FB.scalar_type(), "fb_postprocess_cuda", [&] { + using real_t = typename scalar_t::value_type; + fb_postprocess_kernel<<>>( + reinterpret_cast*>(FBc.data_ptr()), + reinterpret_cast*>(FBC.data_ptr()), + F2B.data_ptr(), + N + ); + }); + + return {FBC, F2B}; +} \ No newline at end of file diff --git a/Converse2D/torch_converse2d/converse2d_v7.cpp b/Converse2D/torch_converse2d/converse2d_v7.cpp new file mode 100644 index 0000000..867b542 --- /dev/null +++ b/Converse2D/torch_converse2d/converse2d_v7.cpp @@ -0,0 +1,194 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +using at::Tensor; +using at::indexing::Slice; + +#ifdef CONVERSE2D_USE_CUDA_KERNELS +Tensor block_mean_cuda(const Tensor &input, int64_t s); +Tensor block_mean_cuda_backward(const Tensor &grad_out, int64_t s); + +Tensor sfold_upsample_cuda_launcher(const Tensor &x, int64_t s); +Tensor sfold_upsample_cuda_backward(const Tensor &grad_out, int64_t s); +#endif + +static inline Tensor sfold_upsample_zero_insertion_autograd(const Tensor &x, int64_t s) { + TORCH_CHECK(x.dim() == 4, "sfold_upsample expects (B,C,H,W)"); + if (s == 1) return x; + const auto B = x.size(0), C = x.size(1), H = x.size(2), W = x.size(3); + auto y = at::zeros({B, C, H * s, W * s}, x.options()); + y.index_put_({Slice(), Slice(), Slice(0, c10::nullopt, s), Slice(0, c10::nullopt, s)}, x); + return y; +} + +static inline Tensor block_mean_autograd(const Tensor &input, int64_t s) { + if (s == 1) return input; + return at::avg_pool2d(input, /*kernel_size=*/{s, s}, /*stride=*/{s, s}, + /*padding=*/{0, 0}, /*ceil_mode=*/false, /*count_include_pad=*/true); +} + +static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) { +#ifdef CONVERSE2D_USE_CUDA_KERNELS + if (s == 1) return x; + return sfold_upsample_cuda_launcher(x, s); +#else + return sfold_upsample_zero_insertion_autograd(x, s); +#endif +} + +static inline Tensor block_mean(const Tensor &input, int64_t s) { +#ifdef CONVERSE2D_USE_CUDA_KERNELS + if (s == 1) return input; + return block_mean_cuda(input, s); +#else + return block_mean_autograd(input, s); +#endif +} + +struct FBKey { + int64_t device_id; + at::ScalarType dtype; + int64_t channels; + int64_t H, W; + void *ptr; + bool operator==(const FBKey &other) const { + return device_id == other.device_id && dtype == other.dtype && + channels == other.channels && H == other.H && W == other.W && + ptr == other.ptr; + } +}; +namespace std { +template <> struct hash { + size_t operator()(const FBKey &k) const { + return ((hash()(k.device_id) ^ hash()(k.channels)) << 1) ^ + ((hash()(k.H) ^ hash()(k.W)) << 1) ^ + ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); + } +}; +} + +constexpr size_t FB_CACHE_MAX_SIZE = 64; +static std::unordered_map> fb_cache; +static std::list fb_cache_lru; +static std::mutex fb_cache_mutex; + +static inline std::pair p2o_cached_rfft(const Tensor &psf, int64_t H, int64_t W) { + auto C = psf.size(1); + FBKey key{psf.device().index(), psf.scalar_type(), C, H, W, psf.data_ptr()}; + + { + std::lock_guard lock(fb_cache_mutex); + auto it = fb_cache.find(key); + if (it != fb_cache.end()) { + fb_cache_lru.remove(key); + fb_cache_lru.push_front(key); + return it->second; + } + } + + 
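+    // Cache miss: build the OTF once. rfft2 keeps only the non-redundant half
+    // of the spectrum (W/2+1 columns), roughly halving FFT work and cache
+    // footprint compared with the C2C variants v1-v6.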
Tensor otf = at::zeros({1, C, H, W}, psf.options());
+    int64_t kh = psf.size(2), kw = psf.size(3);
+    otf.index_put_({0, Slice(), Slice(0, kh), Slice(0, kw)}, psf);
+    otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1});
+
+    Tensor FB = at::fft_rfft2(otf, c10::nullopt, {-2, -1}, c10::nullopt);
+    Tensor F2B = at::abs(FB).pow(2);
+
+    {
+        std::lock_guard<std::mutex> lock(fb_cache_mutex);
+        fb_cache[key] = {FB, F2B};
+        fb_cache_lru.push_front(key);
+        if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) {
+            fb_cache.erase(fb_cache_lru.back());
+            fb_cache_lru.pop_back();
+        }
+    }
+    return {FB, F2B};
+}
+
+Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) {
+    TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)");
+    TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)");
+    TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)");
+    TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)");
+    TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device");
+
+    x = x.contiguous();
+    x0 = x0.contiguous();
+    weight = weight.contiguous();
+    bias = bias.contiguous();
+
+    const int64_t B = x.size(0);
+    const int64_t C = x.size(1);
+    const int64_t H = x.size(2);
+    const int64_t W = x.size(3);
+    const int64_t Hs = H * scale;
+    const int64_t Ws = W * scale;
+
+    Tensor lambda_ = at::sigmoid(bias - 9.0) + eps;
+    Tensor STy = sfold_upsample_zero_insertion(x, scale);
+
+    auto [FB, F2B] = p2o_cached_rfft(weight, Hs, Ws);
+    Tensor FBC = at::conj_physical(FB);
+
+    Tensor F_STy = at::fft_rfft2(STy, c10::nullopt, {-2, -1}, c10::nullopt);
+    Tensor F_lambda_x0 = at::fft_rfft2(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt);
+
+    Tensor FBFy = FBC * F_STy;
+    Tensor FR = FBFy + F_lambda_x0;
+
+    Tensor x1 = FB * FR;
+    Tensor x1_real = at::fft_irfft2(x1, {Hs, Ws}, {-2, -1}, c10::nullopt);
+    Tensor F2B_real = at::fft_irfft2(F2B, {Hs, Ws}, {-2, -1}, c10::nullopt);
+
+    Tensor FBR = block_mean(x1_real, scale);
+    Tensor invW = block_mean(F2B_real, scale);
+
+    Tensor invW_plus = invW + lambda_;
+    Tensor invWBR = FBR / invW_plus;
+
+    Tensor invWBR_exp = invWBR.view({B, C, H, 1, W, 1})
+                            .expand({B, C, H, scale, W, scale})
+                            .reshape({B, C, Hs, Ws});
+    Tensor FCBinvWBR = FBC * at::fft_rfft2(invWBR_exp, c10::nullopt, {-2, -1}, c10::nullopt);
+
+    Tensor FX = (FR - FCBinvWBR) / lambda_;
+    Tensor out = at::fft_irfft2(FX, {Hs, Ws}, {-2, -1}, c10::nullopt);
+    return out;
+}
+
+void clear_fb_cache() {
+    std::lock_guard<std::mutex> lock(fb_cache_mutex);
+    fb_cache.clear();
+    fb_cache_lru.clear();
+}
+
+TORCH_LIBRARY(converse2d, m) {
+    m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor");
+    m.def("clear_cache() -> ()");
+}
+TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) {
+    m.impl("forward", TORCH_FN(converse2d_forward));
+    m.impl("clear_cache", TORCH_FN(clear_fb_cache));
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
diff --git a/Converse2D/torch_converse2d/converse2d_v7.cu b/Converse2D/torch_converse2d/converse2d_v7.cu
new file mode 100644
index 0000000..e7ebeec
--- /dev/null
+++ b/Converse2D/torch_converse2d/converse2d_v7.cu
@@ -0,0 +1,251 @@
+// converse2d_v7.cu (compatible with PyTorch 2.4 / CUDA 12.1)
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h> // getCurrentCUDAStream
+#include <c10/cuda/CUDAGuard.h>    // CUDAGuard (moved from at::cuda)
+#include <cuda_runtime.h>
+
+using at::Tensor;
+
+// ---- Simple accumulation-type mapping: half/bfloat16/float accumulate in float; double in double ----
+template <typename T> struct
acc_type_map { using type = float; }; +template <> struct acc_type_map { using type = double; }; +template <> struct acc_type_map { using type = float; }; +template <> struct acc_type_map { using type = float; }; +template <> struct acc_type_map { using type = float; }; +template using acc_t = typename acc_type_map::type; + +// ======================= block_mean forward ======================= +template +__global__ void block_mean_forward_kernel( + const scalar_t* __restrict__ x, + scalar_t* __restrict__ y, + int B, int C, int Hin, int Win, int s) +{ + const int Hout = Hin / s; + const int Wout = Win / s; + const int Nout = B * C * Hout * Wout; + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= Nout) return; + + int wout = idx % Wout; + int t = idx / Wout; + int hout = t % Hout; + t /= Hout; + int c = t % C; + int b = t / C; + + const int h0 = hout * s; + const int w0 = wout * s; + + const int64_t base_in = (((int64_t)b * C + c) * Hin + h0) * Win + w0; + const int64_t base_out = (((int64_t)b * C + c) * Hout + hout) * Wout + wout; + + acc_t sum = static_cast>(0); + for (int dh = 0; dh < s; ++dh) { + int64_t row = base_in + (int64_t)dh * Win; + for (int dw = 0; dw < s; ++dw) { + sum += static_cast>(x[row + dw]); + } + } + const acc_t denom = static_cast>(s) * static_cast>(s); + y[base_out] = static_cast(sum / denom); +} + +Tensor block_mean_cuda(const Tensor &input, int64_t s) +{ + TORCH_CHECK(input.is_cuda(), "block_mean_cuda: input must be CUDA tensor"); + TORCH_CHECK(input.dim() == 4, "block_mean_cuda: expect (B,C,H,W)"); + TORCH_CHECK(s >= 1, "block_mean_cuda: s must be >= 1"); + if (s == 1) return input; + + auto x = input.contiguous(); + const int B = x.size(0); + const int C = x.size(1); + const int Hin = x.size(2); + const int Win = x.size(3); + TORCH_CHECK(Hin % s == 0 && Win % s == 0, "block_mean_cuda: H/W must be divisible by s"); + + const int Hout = Hin / s; + const int Wout = Win / s; + + auto y = at::empty({B, C, Hout, Wout}, x.options()); + + const int threads = 256; + const int blocks = (B * C * Hout * Wout + threads - 1) / threads; + + c10::cuda::CUDAGuard guard(x.device()); + AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "block_mean_forward", [&] { + block_mean_forward_kernel<<>>( + x.data_ptr(), y.data_ptr(), B, C, Hin, Win, (int)s); + }); + return y; +} + +// ======================= block_mean backward ======================= +// grad_x[b,c,i*s+p,j*s+q] = grad_y[b,c,i,j] / (s*s) +template +__global__ void block_mean_backward_kernel( + const scalar_t* __restrict__ gy, + scalar_t* __restrict__ gx, + int B, int C, int Hin, int Win, int s) +{ + const int Hout = Hin / s; + const int Wout = Win / s; + const int Nin = B * C * Hin * Win; + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= Nin) return; + + int win = idx % Win; + int t = idx / Win; + int hin = t % Hin; + t /= Hin; + int c = t % C; + int b = t / C; + + const int hout = hin / s; + const int wout = win / s; + + const int64_t index_out = (((int64_t)b * C + c) * Hout + hout) * Wout + wout; + const acc_t denom = static_cast>(s) * static_cast>(s); + gx[idx] = static_cast( static_cast>(gy[index_out]) / denom ); +} + +Tensor block_mean_cuda_backward(const Tensor &grad_out, int64_t s) +{ + TORCH_CHECK(grad_out.is_cuda(), "block_mean_cuda_backward: grad_out must be CUDA tensor"); + TORCH_CHECK(grad_out.dim() == 4, "block_mean_cuda_backward: expect (B,C,Hout,Wout)"); + TORCH_CHECK(s >= 1, "block_mean_cuda_backward: s must be >= 1"); + + auto gy = grad_out.contiguous(); 
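+    // Each input pixel of an s×s tile contributed 1/(s*s) to its tile's mean,
+    // so the backward pass reads exactly one upstream value per input location
+    // and scales it; one thread per input element, no atomics required.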
+    const int B = gy.size(0);
+    const int C = gy.size(1);
+    const int Hout = gy.size(2);
+    const int Wout = gy.size(3);
+
+    const int Hin = Hout * s;
+    const int Win = Wout * s;
+
+    auto gx = at::empty({B, C, Hin, Win}, gy.options());
+
+    const int threads = 256;
+    const int blocks = (B * C * Hin * Win + threads - 1) / threads;
+
+    c10::cuda::CUDAGuard guard(gy.device());
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, gy.scalar_type(), "block_mean_backward", [&] {
+        block_mean_backward_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+            gy.data_ptr<scalar_t>(), gx.data_ptr<scalar_t>(), B, C, Hin, Win, (int)s);
+    });
+    return gx;
+}
+
+// ======================= s-fold zero insertion forward =======================
+template <typename scalar_t>
+__global__ void sfold_upsample_forward_kernel(
+    const scalar_t* __restrict__ x,
+    scalar_t* __restrict__ y,
+    int B, int C, int H, int W, int s, int Hs, int Ws)
+{
+    const int Nin = B * C * H * W;
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= Nin) return;
+
+    int w = idx % W;
+    int t = idx / W;
+    int h = t % H;
+    t /= H;
+    int c = t % C;
+    int b = t / C;
+
+    const int hs = h * s;
+    const int ws = w * s;
+
+    const int64_t out_index = (((int64_t)b * C + c) * Hs + hs) * Ws + ws;
+    const scalar_t v = x[idx];
+    y[out_index] = v; // all other positions stay 0
+}
+
+Tensor sfold_upsample_cuda_launcher(const Tensor &x, int64_t s)
+{
+    TORCH_CHECK(x.is_cuda(), "sfold_upsample_cuda: x must be CUDA tensor");
+    TORCH_CHECK(x.dim() == 4, "sfold_upsample_cuda: expect (B,C,H,W)");
+    TORCH_CHECK(s >= 1, "sfold_upsample_cuda: s must be >= 1");
+    if (s == 1) return x;
+
+    auto xx = x.contiguous();
+    const int B = xx.size(0);
+    const int C = xx.size(1);
+    const int H = xx.size(2);
+    const int W = xx.size(3);
+    const int Hs = H * s;
+    const int Ws = W * s;
+
+    auto y = at::zeros({B, C, Hs, Ws}, xx.options());
+
+    const int threads = 256;
+    const int blocks = (B * C * H * W + threads - 1) / threads;
+
+    c10::cuda::CUDAGuard guard(xx.device());
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, xx.scalar_type(), "sfold_upsample_forward", [&] {
+        sfold_upsample_forward_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+            xx.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(), B, C, H, W, (int)s, Hs, Ws);
+    });
+    return y;
+}
+
+// ======================= s-fold zero insertion backward =======================
+// grad_x[b,c,h,w] = grad_y[b,c,h*s, w*s]
+template <typename scalar_t>
+__global__ void sfold_upsample_backward_kernel(
+    const scalar_t* __restrict__ gy,
+    scalar_t* __restrict__ gx,
+    int B, int C, int H, int W, int s, int Hs, int Ws)
+{
+    const int Nin = B * C * H * W;
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= Nin) return;
+
+    int w = idx % W;
+    int t = idx / W;
+    int h = t % H;
+    t /= H;
+    int c = t % C;
+    int b = t / C;
+
+    const int hs = h * s;
+    const int ws = w * s;
+
+    const int64_t in_index = (((int64_t)b * C + c) * Hs + hs) * Ws + ws;
+    gx[idx] = gy[in_index];
+}
+
+Tensor sfold_upsample_cuda_backward(const Tensor &grad_out, int64_t s)
+{
+    TORCH_CHECK(grad_out.is_cuda(), "sfold_upsample_cuda_backward: grad_out must be CUDA tensor");
+    TORCH_CHECK(grad_out.dim() == 4, "sfold_upsample_cuda_backward: expect (B,C,Hs,Ws)");
+    TORCH_CHECK(s >= 1, "sfold_upsample_cuda_backward: s must be >= 1");
+    if (s == 1) return grad_out;
+
+    auto gy = grad_out.contiguous();
+    const int B = gy.size(0);
+    const int C = gy.size(1);
+    const int Hs = gy.size(2);
+    const int Ws = gy.size(3);
+    TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "sfold_upsample_cuda_backward: Hs/Ws must be divisible by s");
+    const int H = Hs / s;
+    const int W = Ws / s;
+
+    auto gx = at::empty({B, C, H, W}, gy.options());
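+    // Gather formulation: one thread per low-resolution gradient element, each
+    // reading a single strided position of grad_out, so writes never collide.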
const int threads = 256; + const int blocks = (B * C * H * W + threads - 1) / threads; + + c10::cuda::CUDAGuard guard(gy.device()); + AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, gy.scalar_type(), "sfold_upsample_backward", [&] { + sfold_upsample_backward_kernel<<>>( + gy.data_ptr(), gx.data_ptr(), B, C, H, W, (int)s, Hs, Ws); + }); + return gx; +} diff --git a/test/test_error.py b/test/test_error.py index 2c75c17..d4d7f83 100644 --- a/test/test_error.py +++ b/test/test_error.py @@ -1,56 +1,156 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import sys, pathlib, torch +""" +Converse2D accuracy test: +- Forward/Backward numerical consistency: CUDA backend vs PyTorch backend +- Autograd gradcheck in float64 +- Multiple shapes / batch sizes / scales +""" + +import sys, pathlib, argparse, math, itertools, time +import torch + PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1] sys.path.insert(0, str(PROJECT_ROOT)) -from models.util_converse import Converse2D - -torch.manual_seed(0) -device = "cuda" if torch.cuda.is_available() else "cpu" -dtype = torch.float32 -B, C, H, W, scale = 2, 3, 32, 40, 2 - -x = torch.randn(B, C, H, W, device=device, dtype=dtype, requires_grad=True) - -m = Converse2D(C, C, 5, scale=scale, padding=2, padding_mode="circular", eps=1e-5, backend="pytorch").to(device=device, dtype=dtype) -m.eval() -x_py = x.detach().clone().requires_grad_(True) -y_py = m(x_py) -g_py = torch.autograd.grad(y_py.square().mean(), x_py)[0].detach() - -have_cuda = False -try: - if device == "cuda": - m.backend = "cuda" - x_cu = x.detach().clone().requires_grad_(True) - y_cu = m(x_cu) - g_cu = torch.autograd.grad(y_cu.square().mean(), x_cu)[0].detach() - have_cuda = True - print("[INFO] CUDA backend: OK") - else: - print("[WARN] CUDA not available on this device.") -except Exception as e: - print("[WARN] CUDA backend unavailable ->", repr(e)) - -print("[INFO] Python backend: OK") - -if have_cuda: +from models.util_converse import Converse2D # noqa: E402 + + +def max_rel_err(a, b, eps=1e-8): + with torch.no_grad(): + return ((a - b).abs() / (b.abs() + eps)).max().item() + + +def run_one_case(device, dtype, B, C, H, W, k, scale, eps, atol_fwd, rtol_fwd, atol_bwd, rtol_bwd, seed=0): + torch.manual_seed(seed) + + # Build inputs + x = torch.randn(B, C, H, W, device=device, dtype=dtype, requires_grad=True) + x0 = torch.randn(B, C, H*scale, W*scale, device=device, dtype=dtype) + + # Two identical modules (same init), different backends + m_py = Converse2D(C, C, k, scale=scale, padding=k//2, padding_mode="circular", + eps=eps, backend="pytorch").to(device=device, dtype=dtype).eval() + m_cu = Converse2D(C, C, k, scale=scale, padding=k//2, padding_mode="circular", + eps=eps, backend="cuda").to(device=device, dtype=dtype).eval() + + # Copy weights/bias so two modules are identical + with torch.no_grad(): + for (pn, p_py), (_, p_cu) in zip(m_py.named_parameters(), m_cu.named_parameters()): + p_cu.copy_(p_py) + + # Forward + x_py = x.detach().clone().requires_grad_(True) + x_cu = x.detach().clone().requires_grad_(True) + + y_py = m_py(x_py) + y_cu = m_cu(x_cu) + + # Backward (use a simple scalar loss so both sides comparable) + loss_py = (y_py.square()).mean() + loss_cu = (y_cu.square()).mean() + g_py = torch.autograd.grad(loss_py, x_py, retain_graph=False, create_graph=False)[0].detach() + g_cu = torch.autograd.grad(loss_cu, x_cu, retain_graph=False, create_graph=False)[0].detach() + + # Errors with torch.no_grad(): - out_mae = (y_cu - y_py).abs().max().item() - grad_mae = (g_cu - 
g_py).abs().max().item() - out_rel = ((y_cu - y_py).abs() / (y_py.abs() + 1e-8)).max().item() - grad_rel = ((g_cu - g_py).abs() / (g_py.abs() + 1e-8)).max().item() - print(f"forward: max|Δ|={out_mae:.3e} max rel={out_rel:.3e}") - print(f"backward: max|Δ|={grad_mae:.3e} max rel={grad_rel:.3e}") - -# gradcheck (float64) -try: - x64 = torch.randn(1,2,8,9, device=device, dtype=torch.float64, requires_grad=True) - m64 = Converse2D(2,2,5, scale=2, padding=2, padding_mode="circular", eps=1e-5, backend="auto").to(device=device, dtype=torch.float64) - m64.eval() - torch.autograd.gradcheck(lambda t: m64(t), (x64,), eps=1e-6, atol=1e-4, rtol=1e-4) - print("[INFO] Gradcheck (float64) passed.") -except Exception as e: - print("[WARN] Gradcheck skipped/failed ->", repr(e)) + f_mae = (y_cu - y_py).abs().max().item() + f_mre = max_rel_err(y_cu, y_py) + b_mae = (g_cu - g_py).abs().max().item() + b_mre = max_rel_err(g_cu, g_py) + + print(f"[Case] B{B} C{C} {H}x{W} s{scale} k{k} dtype={str(dtype).split('.')[-1]}") + print(f" forward: max|Δ|={f_mae:.3e} max rel={f_mre:.3e}") + print(f" backward: max|Δ|={b_mae:.3e} max rel={b_mre:.3e}") + + # Assertions + assert f_mae <= atol_fwd + rtol_fwd * y_py.abs().max().item() + 1e-12, \ + f"FWD max abs error too large: {f_mae:.3e}" + assert f_mre <= rtol_fwd or f_mae <= atol_fwd, \ + f"FWD max rel error too large: {f_mre:.3e}" + + assert b_mae <= atol_bwd + rtol_bwd * g_py.abs().max().item() + 1e-12, \ + f"BWD max abs error too large: {b_mae:.3e}" + assert b_mre <= rtol_bwd or b_mae <= atol_bwd, \ + f"BWD max rel error too large: {b_mre:.3e}" + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--dtype", default="float32", choices=["float32", "float16", "bfloat16"]) + parser.add_argument("--eps", type=float, default=1e-5) + + # tolerances (forward/backward) + parser.add_argument("--atol-fwd", type=float, default=5e-4) + parser.add_argument("--rtol-fwd", type=float, default=5e-3) + parser.add_argument("--atol-bwd", type=float, default=8e-4) + parser.add_argument("--rtol-bwd", type=float, default=8e-3) + + # shapes + parser.add_argument("--B", type=int, nargs="*", default=[1, 2]) + parser.add_argument("--C", type=int, nargs="*", default=[2, 3]) + parser.add_argument("--H", type=int, nargs="*", default=[16, 32]) + parser.add_argument("--W", type=int, nargs="*", default=[18, 40]) + parser.add_argument("--k", type=int, nargs="*", default=[3, 5]) + parser.add_argument("--s", type=int, nargs="*", default=[1, 2, 3]) + parser.add_argument("--seed", type=int, default=0) + + parser.add_argument("--gradcheck", action="store_true", help="run float64 gradcheck") + args = parser.parse_args() + + device = args.device + if device == "cpu": + print("[WARN] CUDA not available, only PyTorch backend will run; comparisons will be skipped.") + + # dtype + map_dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + dtype = map_dtype[args.dtype] + + # Run a grid + any_cuda = torch.cuda.is_available() and device.startswith("cuda") + if any_cuda: + print("[INFO] CUDA backend: will be tested") + print("[INFO] Python backend: will be tested") + + for (B, C, H, W, k, s) in itertools.product(args.B, args.C, args.H, args.W, args.k, args.s): + if H * s > 2048 or W * s > 2048: + # keep test practical + continue + if not any_cuda: + # Just dry run to ensure Python path healthy + torch.manual_seed(args.seed) + m_py = Converse2D(C, C, k, scale=s, 
padding=k//2, padding_mode="circular", + eps=args.eps, backend="pytorch").to(device=device, dtype=dtype).eval() + x = torch.randn(B, C, H, W, device=device, dtype=dtype, requires_grad=True) + x0 = torch.randn(B, C, H*s, W*s, device=device, dtype=dtype) + y = m_py(x) + _ = torch.autograd.grad((y.square()).mean(), x)[0] + print(f"[CPU dry-run] B{B} C{C} {H}x{W} s{s} k{k} OK") + else: + run_one_case(device, dtype, B, C, H, W, k, s, args.eps, + args.atol_fwd, args.rtol_fwd, args.atol_bwd, args.rtol_bwd, seed=args.seed) + + # ====== gradcheck (float64, stricter) ====== + if args.gradcheck and any_cuda: + print("[INFO] Running gradcheck (float64)…") + torch.manual_seed(0) + x64 = torch.randn(1, 2, 8, 9, device=device, dtype=torch.float64, requires_grad=True) + x0 = torch.randn(1, 2, 16, 18, device=device, dtype=torch.float64) + m64 = Converse2D(2, 2, 5, scale=2, padding=2, padding_mode="circular", + eps=args.eps, backend="cuda").to(device=device, dtype=torch.float64).eval() + + # Wrap a function of a single tensor for gradcheck; x0 is closed-over constant + def f(t): + return m64(t) + + ok = torch.autograd.gradcheck(f, (x64,), eps=1e-6, atol=1e-4, rtol=1e-4) + print("[INFO] Gradcheck (float64) passed." if ok else "[WARN] Gradcheck failed.") + +if __name__ == "__main__": + main() diff --git a/test/test_speed.py b/test/test_speed.py index a812435..26707ee 100644 --- a/test/test_speed.py +++ b/test/test_speed.py @@ -7,38 +7,44 @@ PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1] sys.path.insert(0, str(PROJECT_ROOT)) -PKG = PROJECT_ROOT / "Converse2D/torch_converse2d" +PKG = PROJECT_ROOT / "Converse2D/torch_converse2d" +# ------------ utils ------------ def synchronize(): if torch.cuda.is_available(): torch.cuda.synchronize() def timed_run(fn, warmup, iters): - for _ in range(warmup): fn() + # warmup + for _ in range(warmup): + fn() synchronize() - times=[] + times = [] for _ in range(iters): - t0=time.perf_counter(); fn(); synchronize() - times.append((time.perf_counter()-t0)*1000.0) + t0 = time.perf_counter() + fn() + synchronize() + times.append((time.perf_counter() - t0) * 1000.0) return { "mean_ms": statistics.mean(times), "p50_ms": statistics.median(times), - "p90_ms": statistics.quantiles(times, n=10)[8] if len(times)>=10 else statistics.median_high(times), + "p90_ms": statistics.quantiles(times, n=10)[8] if len(times) >= 10 else statistics.median_high(times), } -def tp_gpix_per_s(B,H,W,s,mean_ms): - if mean_ms<=0: return None - return (B*(H*s)*(W*s)/(mean_ms/1e3))/1e9 +def tp_gpix_per_s(B, H, W, s, mean_ms): + if not mean_ms or mean_ms <= 0: + return None + # pixels processed per example = (H*s)*(W*s) + return (B * (H * s) * (W * s) / (mean_ms / 1e3)) / 1e9 def to_dtype(name): - name=name.lower() - if name in ("fp16","half","float16"): return torch.float16 - if name in ("bf16","bfloat16"): return torch.bfloat16 - if name in ("fp32","float32","float"):return torch.float32 + name = name.lower() + if name in ("fp16", "half", "float16"): return torch.float16 + if name in ("bf16", "bfloat16"): return torch.bfloat16 + if name in ("fp32", "float32", "float"): return torch.float32 raise ValueError(name) -# -------- parent <-> child plumbing -------- -import re +# ------------ parent <-> child plumbing ------------ def _parse_last_json_from_text(txt: str): m = re.findall(r"\{.*\}", txt, flags=re.S) if not m: @@ -55,6 +61,7 @@ def run_variant_subprocess(variant, case_args, cache_root): "--scale", str(case_args["scale"]), "--ksize", str(case_args["ksize"]), "--warmup", 
str(case_args["warmup"]), "--iters", str(case_args["iters"]),
        "--dtype", case_args["dtype"], "--device", case_args["device"],
+        "--eps", str(case_args["eps"]),
    ]
    env = os.environ.copy()
    env["TORCH_EXTENSIONS_DIR"] = str(pathlib.Path(cache_root) / variant)
@@ -64,25 +71,42 @@ def run_variant_subprocess(variant, case_args, cache_root):
        raise RuntimeError(f"Subprocess failed (variant={variant}). Output:\n{out}")
    return _parse_last_json_from_text(out)

-# -------- worker --------
+# ------------ worker ------------
def worker_main(args):
-    device = "cuda" if (args.device=="cuda" and torch.cuda.is_available()) else "cpu"
+    device = "cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu"
    dtype = to_dtype(args.dtype)
-    B,C,H,W,s,k = args.B,args.C,args.H,args.W,args.scale,args.ksize
+    B, C, H, W, s, k = args.B, args.C, args.H, args.W, args.scale, args.ksize
+    eps = float(args.eps)

    if args.variant == "pytorch":
-        from models.util_converse import Converse2D
+        from models.util_converse import Converse2D  # imported only for the baseline, so no conflicting v1 extension gets registered
        torch.manual_seed(0)
-        x = torch.randn(B,C,H,W, device=device, dtype=dtype)
        m = Converse2D(C, C, kernel_size=k, scale=s, padding=k//2,
-                       padding_mode="circular", eps=1e-5, backend="pytorch").to(device=device, dtype=dtype)
-        m.eval()
-        def call():
+                       padding_mode="circular", eps=eps, backend="pytorch").to(device=device, dtype=dtype).eval()
+
+        x_fwd = torch.randn(B, C, H, W, device=device, dtype=dtype)
+        x_bwd = torch.randn(B, C, H, W, device=device, dtype=dtype, requires_grad=True)
+
+        def call_fwd():
            with torch.no_grad():
-                _ = m(x)
-        stat = timed_run(call, args.warmup, args.iters)
-        stat["tp"] = tp_gpix_per_s(B,H,W,s,stat["mean_ms"])
-        print(json.dumps({"variant":"pytorch", **stat}))
+                _ = m(x_fwd)
+        fstat = timed_run(call_fwd, args.warmup, args.iters)
+
+        def call_bwd():
+            y = m(x_bwd)
+            (y.square().mean()).backward()
+            assert x_bwd.grad is not None, "x.grad is None"
+            x_bwd.grad.zero_()
+        bstat = timed_run(call_bwd, args.warmup, args.iters)
+
+        out = {
+            "variant": "pytorch",
+            "fwd_mean_ms": fstat["mean_ms"], "fwd_p50_ms": fstat["p50_ms"], "fwd_p90_ms": fstat["p90_ms"], "fwd_tp": tp_gpix_per_s(B,H,W,s,fstat["mean_ms"]),
+            "bwd_mean_ms": bstat["mean_ms"], "bwd_p50_ms": bstat["p50_ms"], "bwd_p90_ms": bstat["p90_ms"], "bwd_tp": tp_gpix_per_s(B,H,W,s,bstat["mean_ms"]),
+            "grad_ok": True,
+            "warmup": args.warmup, "iters": args.iters, "dtype": args.dtype, "device": device
+        }
+        print(json.dumps(out))
        return

    from torch.utils.cpp_extension import load
@@ -92,45 +116,62 @@ def call():
    sources = [str(cpp)]
    if cu.exists(): sources.append(str(cu))
-    maj, min = torch.cuda.get_device_capability(0) if device == "cuda" else (0, 0)
-    arch_num = f"{maj}{min}"  # e.g. "75", "86", "89"
-    arch_str = f"{maj}.{min}"  # e.g. 
"7.5" + maj, minr = torch.cuda.get_device_capability(0) if device == "cuda" else (0, 0) + arch_num = f"{maj}{minr}" + arch_str = f"{maj}.{minr}" os.environ.setdefault("TORCH_CUDA_ARCH_LIST", f"{arch_str}+PTX") + extra_cuda = ["-O3", f"-gencode=arch=compute_{arch_num},code=sm_{arch_num}"] if (cu.exists() and device=="cuda") else [] ext_name = f"converse2d_v{vnum}_sm{arch_num}_ext" + print(f"[build] compiling {ext_name} (variant={args.variant}) ...", flush=True) + load(name=ext_name, sources=sources, verbose=False, extra_cflags=["-O3"], extra_cuda_cflags=extra_cuda) - print(f"[build] compiling {ext_name} for sm_{arch_num} (variant={args.variant}) ...", flush=True) - - extra_cuda = [] - if cu.exists() and device == "cuda": - extra_cuda = ["-O3", f"-gencode=arch=compute_{arch_num},code=sm_{arch_num}"] - - load( - name=ext_name, - sources=sources, - verbose=False, - extra_cflags=["-O3"], - extra_cuda_cflags=extra_cuda, - ) + converse2d_forward = torch.ops.converse2d.forward torch.manual_seed(0) - x = torch.randn(B,C,H,W, device=device, dtype=dtype) - x0 = x if s==1 else F.interpolate(x, scale_factor=s, mode="nearest") - weight = torch.randn(1,C,k,k, device=device, dtype=dtype) - weight = torch.softmax(weight.view(1,C,-1), dim=-1).view(1,C,k,k).contiguous() - bias = torch.zeros(1,C,1,1, device=device, dtype=dtype) + x_fwd = torch.randn(B, C, H, W, device=device, dtype=dtype) + x_bwd = torch.randn(B, C, H, W, device=device, dtype=dtype, requires_grad=True) + x0_fwd = x_fwd if s == 1 else F.interpolate(x_fwd, scale_factor=s, mode="nearest") + x0_bwd = x_bwd if s == 1 else F.interpolate(x_bwd, scale_factor=s, mode="nearest") + weight = torch.randn(1, C, k, k, device=device, dtype=dtype) + weight = torch.softmax(weight.view(1, C, -1), dim=-1).view(1, C, k, k).contiguous() + bias = torch.zeros(1, C, 1, 1, device=device, dtype=dtype) - converse2d_forward = torch.ops.converse2d.forward - def call(): + def call_fwd(): with torch.no_grad(): - _ = converse2d_forward(x, x0, weight, bias, int(s), float(1e-5)) - stat = timed_run(call, args.warmup, args.iters) - stat["tp"] = tp_gpix_per_s(B,H,W,s,stat["mean_ms"]) - try: torch.ops.converse2d.clear_cache() - except Exception: pass - print(json.dumps({"variant": args.variant, **stat})) - -# -------- parent orchestrator -------- + _ = converse2d_forward(x_fwd, x0_fwd, weight, bias, int(s), float(eps)) + fstat = timed_run(call_fwd, args.warmup, args.iters) + + def call_bwd(): + y = converse2d_forward(x_bwd, x0_bwd, weight, bias, int(s), float(eps)) + (y.square().mean()).backward() + assert x_bwd.grad is not None, "x.grad is None" + x_bwd.grad.zero_() + + grad_ok = True + try: + bstat = timed_run(call_bwd, args.warmup, args.iters) + except Exception: + grad_ok = False + bstat = {"mean_ms": None, "p50_ms": None, "p90_ms": None} + + try: + torch.ops.converse2d.clear_cache() + except Exception: + pass + + out = { + "variant": args.variant, + "fwd_mean_ms": fstat["mean_ms"], "fwd_p50_ms": fstat["p50_ms"], "fwd_p90_ms": fstat["p90_ms"], + "fwd_tp": tp_gpix_per_s(B, H, W, s, fstat["mean_ms"]), + "bwd_mean_ms": bstat["mean_ms"], "bwd_p50_ms": bstat["p50_ms"], "bwd_p90_ms": bstat["p90_ms"], + "bwd_tp": tp_gpix_per_s(B, H, W, s, bstat["mean_ms"]) if bstat["mean_ms"] else None, + "grad_ok": grad_ok, + "warmup": args.warmup, "iters": args.iters, "dtype": args.dtype, "device": device + } + print(json.dumps(out)) + +# ------------ parent orchestrator ------------ def parent_main(args): device = "cuda" if (args.device=="cuda" and torch.cuda.is_available()) else "cpu" Bs = 
[int(x) for x in args.B_list.split(",")] @@ -144,8 +185,9 @@ def parent_main(args): print(f"[Cfg] dtype={args.dtype}, warmup={args.warmup}, iters={args.iters}") print(f"[Grid] B={Bs} C={Cs} H={Hs} W={Ws} scale={Ss} ksize={Ks}\n") - variants = ["pytorch", "cuda_v1", "cuda_v2", "cuda_v3", "cuda_v4"] - results = list() + variants = ["pytorch", "cuda_v1", "cuda_v2", "cuda_v3","cuda_v4", "cuda_v5", "cuda_v6", "cuda_v7"] + + results = [] cache_root = PROJECT_ROOT / ".torch_ext_cache_grid" cache_root.mkdir(exist_ok=True) @@ -155,38 +197,54 @@ def parent_main(args): for W in Ws: for s in Ss: for k in Ks: - case = dict(B=B,C=C,H=H,W=W,scale=s,ksize=k, - warmup=args.warmup,iters=args.iters, - dtype=args.dtype,device=device) + case = dict(B=B, C=C, H=H, W=W, scale=s, ksize=k, + warmup=args.warmup, iters=args.iters, + dtype=args.dtype, device=device, eps=args.eps) + base = run_variant_subprocess("pytorch", case, cache_root) - base_mean = base["mean_ms"] - results.append({**case,"variant":"pytorch",**base}) + results.append({**case, **base}) print(f"[Case] B{B} C{C} {H}x{W} s{s} k{k}") - print(f" PyTorch : {base_mean:.3f} ms") + print(f" PyTorch : fwd {base['fwd_mean_ms']:.3f} ms | bwd {base['bwd_mean_ms']:.3f} ms | grad_ok={base['grad_ok']}") + + base_fwd = base["fwd_mean_ms"] + base_bwd = base["bwd_mean_ms"] + for v in variants[1:]: r = run_variant_subprocess(v, case, cache_root) - sp = base_mean / r["mean_ms"] if r["mean_ms"]>0 else None - results.append({**case, "variant":v, **r, "speedup_vs_pytorch": sp}) - print(f" {v:8s}: {r['mean_ms']:.3f} ms ({sp:.2f}x vs PyTorch)") + r["fwd_speedup_vs_pytorch"] = (base_fwd / r["fwd_mean_ms"]) if (r["fwd_mean_ms"] and base_fwd) else None + r["bwd_speedup_vs_pytorch"] = (base_bwd / r["bwd_mean_ms"]) if (r["bwd_mean_ms"] and base_bwd) else None + results.append({**case, **r}) + fsp = f"{r['fwd_speedup_vs_pytorch']:.2f}x" if r["fwd_speedup_vs_pytorch"] else "n/a" + if r["bwd_mean_ms"]: + bsp = f"{r['bwd_speedup_vs_pytorch']:.2f}x" if r["bwd_speedup_vs_pytorch"] else "n/a" + bwd_repr = f"{r['bwd_mean_ms']:.3f} ms ({bsp})" + else: + bwd_repr = "n/a" + print(f" {v:8s}: fwd {r['fwd_mean_ms']:.3f} ms ({fsp}) | bwd {bwd_repr} | grad_ok={r['grad_ok']}") print("") - header = ["variant","B","C","H","W","scale","ksize","mean_ms","p50_ms","p90_ms","tp","speedup_vs_pytorch","warmup","dtype","device","iters"] - print("\n=== Summary (normalized to PyTorch) ===") - print(" | ".join(h.rjust(10) for h in header)) - print("-"*120) + header = [ + "variant","B","C","H","W","scale","ksize", + "fwd_mean_ms","fwd_p50_ms","fwd_p90_ms","fwd_tp","fwd_speedup_vs_pytorch", + "bwd_mean_ms","bwd_p50_ms","bwd_p90_ms","bwd_tp","bwd_speedup_vs_pytorch", + "grad_ok","eps","warmup","dtype","device","iters" + ] + print("\n=== Summary (speed vs PyTorch) ===") + print(" | ".join(h.rjust(14) for h in header)) + print("-"*160) for r in results: - line=[] + row=[] for h in header: v = r.get(h,"") - if isinstance(v,float): - line.append(f"{v:10.3f}") + if isinstance(v, float): + row.append(f"{v:14.3f}") else: - line.append(str(v).rjust(10)) - print(" | ".join(line)) + row.append(str(v).rjust(14)) + print(" | ".join(row)) if args.csv: - with open(args.csv,"w",newline="") as f: - w=csv.DictWriter(f, fieldnames=header); w.writeheader(); w.writerows(results) + with open(args.csv, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=header); w.writeheader(); w.writerows(results) print(f"\n[Saved] {args.csv}") def main(): @@ -203,18 +261,19 @@ def main(): p.add_argument("--iters", type=int, default=50) 
p.add_argument("--dtype", default="float32", choices=["float16","bfloat16","float32"]) p.add_argument("--device", default="cuda") + p.add_argument("--eps", default=1e-5, type=float) # grid - p.add_argument("--B_list", default="1") - p.add_argument("--C_list", default="3,8") - p.add_argument("--H_list", default="128,256") - p.add_argument("--W_list", default="128,256") + p.add_argument("--B_list", default="8") + p.add_argument("--C_list", default="8") + p.add_argument("--H_list", default="256") + p.add_argument("--W_list", default="256") p.add_argument("--scale_list", default="1,2,3") p.add_argument("--ksize_list", default="3,5,7") p.add_argument("--csv", default="") args = p.parse_args() - if args.worker: + if args.worker: worker_main(args) - else: + else: parent_main(args) if __name__ == "__main__": diff --git a/test/test_v5.csv b/test/test_v5.csv new file mode 100644 index 0000000..c8bf52e --- /dev/null +++ b/test/test_v5.csv @@ -0,0 +1,37 @@ +variant,B,C,H,W,scale,ksize,mean_ms,p50_ms,p90_ms,tp,speedup_vs_pytorch,warmup,dtype,device,iters +pytorch,8,8,256,256,1,3,9.018746092915535,8.91665113158524,9.489735681563616,0.058133136757430405,,10,float32,cuda,50 +cuda_v3,8,8,256,256,1,3,3.061873996630311,3.0620904872193933,3.0763483606278896,0.17123108285219948,2.9454987706355484,10,float32,cuda,50 +cuda_v4,8,8,256,256,1,3,3.043683832511306,3.0511280056089163,3.0741054099053144,0.1722544222234201,2.963102145032679,10,float32,cuda,50 +cuda_v5,8,8,256,256,1,3,3.0455162096768618,3.054927452467382,3.0771585879847407,0.17215078295565156,2.9613193534348157,10,float32,cuda,50 +pytorch,8,8,256,256,1,5,8.164601628668606,8.02653655409813,9.788747737184167,0.06421476807380928,,10,float32,cuda,50 +cuda_v3,8,8,256,256,1,5,3.0641747498884797,3.0675254529342055,3.079495718702674,0.17110251300748477,2.664535248508838,10,float32,cuda,50 +cuda_v4,8,8,256,256,1,5,3.0497499415650964,3.0503239249810576,3.0591003131121397,0.17191179934278203,2.6771380556133817,10,float32,cuda,50 +cuda_v5,8,8,256,256,1,5,3.067394499666989,3.0633179703727365,3.082340001128614,0.17092291195570677,2.6617383677107704,10,float32,cuda,50 +pytorch,8,8,256,256,1,7,9.292588303796947,9.34083794709295,10.928183235228062,0.05642001806813891,,10,float32,cuda,50 +cuda_v3,8,8,256,256,1,7,3.0567877739667892,3.0584499472752213,3.077001217752695,0.17151599612675505,3.0399847784453704,10,float32,cuda,50 +cuda_v4,8,8,256,256,1,7,3.011856614612043,3.047129837796092,3.0612722039222717,0.1740746878375329,3.085335556385351,10,float32,cuda,50 +cuda_v5,8,8,256,256,1,7,3.0088233342394233,3.04994557518512,3.0606051674112678,0.17425017748093563,3.088445970905084,10,float32,cuda,50 +pytorch,8,8,256,256,2,3,48.46728646196425,48.91097836662084,52.44047886226326,0.04326943291215169,,10,float32,cuda,50 +cuda_v3,8,8,256,256,2,3,10.859722616150975,10.825826553627849,11.112747946754098,0.1931128514167608,4.463031715918962,10,float32,cuda,50 +cuda_v4,8,8,256,256,2,3,10.890666269697249,10.818996001034975,11.145172524265945,0.19256415981042627,4.45035090248079,10,float32,cuda,50 +cuda_v5,8,8,256,256,2,3,10.889667607843876,10.824134456925094,11.142246099188924,0.19258181934675508,4.450759032080377,10,float32,cuda,50 +pytorch,8,8,256,256,2,5,39.19886094983667,39.11385603714734,40.184705122374,0.053500330090809387,,10,float32,cuda,50 +cuda_v3,8,8,256,256,2,5,10.921839205548167,10.843515512533486,11.04965121485293,0.19201454631694917,3.589034796440155,10,float32,cuda,50 
+cuda_v4,8,8,256,256,2,5,10.950461719185114,10.84580144379288,11.239956971257925,0.191512655245012,3.57965371278917,10,float32,cuda,50 +cuda_v5,8,8,256,256,2,5,10.826412132009864,10.821027914062142,10.844661109149456,0.19370701710121166,3.6206695691862247,10,float32,cuda,50 +pytorch,8,8,256,256,2,7,44.36711141373962,43.92930108588189,47.58514082059264,0.04726816628748459,,10,float32,cuda,50 +cuda_v3,8,8,256,256,2,7,10.912643321789801,10.832935920916498,11.155167641118169,0.19217635344248038,4.065661279806486,10,float32,cuda,50 +cuda_v4,8,8,256,256,2,7,10.899619800038636,10.823742602951825,11.146645108237863,0.19240597731606807,4.070519176603055,10,float32,cuda,50 +cuda_v5,8,8,256,256,2,7,10.856493567116559,10.844919481314719,10.878617619164288,0.19317028901045027,4.086688868689981,10,float32,cuda,50 +pytorch,8,8,256,256,3,3,101.29073102492839,100.59209598693997,105.4747574031353,0.046584637629268566,,10,float32,cuda,50 +cuda_v3,8,8,256,256,3,3,24.37616023235023,24.372862884774804,24.413138814270496,0.1935740475539636,4.15531938005982,10,float32,cuda,50 +cuda_v4,8,8,256,256,3,3,24.461542363278568,24.370684404857457,24.882999388501048,0.19289838432606382,4.14081538770773,10,float32,cuda,50 +cuda_v5,8,8,256,256,3,3,24.447299065068364,24.368930491618812,24.460403178818524,0.1930107693058896,4.143227877866399,10,float32,cuda,50 +pytorch,8,8,256,256,3,5,86.84147734194994,86.47282503079623,90.13091719243675,0.054335694698282394,,10,float32,cuda,50 +cuda_v3,8,8,256,256,3,5,24.3759097578004,24.37639352865517,24.403270613402128,0.19357603662320866,3.562594307445706,10,float32,cuda,50 +cuda_v4,8,8,256,256,3,5,24.360263156704605,24.359582574106753,24.388629896566272,0.19370037054387548,3.5648825623646356,10,float32,cuda,50 +cuda_v5,8,8,256,256,3,5,24.40255253110081,24.365780991502106,24.539125943556428,0.1933646897793254,3.558704657280067,10,float32,cuda,50 +pytorch,8,8,256,256,3,7,92.940255952999,93.59969745855778,96.69805229641497,0.05077016360258625,,10,float32,cuda,50 +cuda_v3,8,8,256,256,3,7,24.371705148369074,24.363814387470484,24.390139686875045,0.19360943238375597,3.813449054434275,10,float32,cuda,50 +cuda_v4,8,8,256,256,3,7,24.42510688211769,24.368971004150808,24.507023952901363,0.19318613518349081,3.8051115355013323,10,float32,cuda,50 +cuda_v5,8,8,256,256,3,7,24.401678266003728,24.37236753758043,24.425271223299205,0.19337161766344219,3.808764911161164,10,float32,cuda,50 diff --git a/test/test_v6.csv b/test/test_v6.csv new file mode 100644 index 0000000..122bef0 --- /dev/null +++ b/test/test_v6.csv @@ -0,0 +1,37 @@ +variant,B,C,H,W,scale,ksize,mean_ms,p50_ms,p90_ms,tp,speedup_vs_pytorch,warmup,dtype,device,iters +pytorch,8,8,256,256,1,3,9.205426471307874,9.001746540889144,9.537389921024442,0.0569542325533899,,10,float32,cuda,50 +cuda_v3,8,8,256,256,1,3,3.0636826157569885,3.0644324142485857,3.076496347784996,0.17112999803031376,3.0046932485639855,10,float32,cuda,50 +cuda_v4,8,8,256,256,1,3,2.9942896962165833,2.9994455398991704,3.0572090996429324,0.17509595035592612,3.0743272713189156,10,float32,cuda,50 +cuda_v6,8,8,256,256,1,3,3.0439832201227546,3.0431936029344797,3.0574772041291,0.17223748032975592,3.0241383758142555,10,float32,cuda,50 +pytorch,8,8,256,256,1,5,7.813363987952471,7.445647963322699,8.981779753230512,0.0671014432206674,,10,float32,cuda,50 +cuda_v3,8,8,256,256,1,5,3.069853764027357,3.0687025282531977,3.0871906550601125,0.17078598535983158,2.5451909402036397,10,float32,cuda,50 
+cuda_v4,8,8,256,256,1,5,3.0559902312234044,3.0536623671650887,3.0731556937098503,0.17156075783334943,2.5567372264879746,10,float32,cuda,50 +cuda_v6,8,8,256,256,1,5,3.0378469452261925,3.0349043663591146,3.0528080882504582,0.17258539006512144,2.572007125056362,10,float32,cuda,50 +pytorch,8,8,256,256,1,7,8.450384107418358,8.219452458433807,9.284918219782412,0.062043096897777934,,10,float32,cuda,50 +cuda_v3,8,8,256,256,1,7,3.0614592181518674,3.0635460279881954,3.0748489312827587,0.17125428190955963,2.7602471583860133,10,float32,cuda,50 +cuda_v4,8,8,256,256,1,7,3.06079070083797,3.0586729990318418,3.079405124299228,0.17129168611772855,2.760850033001227,10,float32,cuda,50 +cuda_v6,8,8,256,256,1,7,3.066543652676046,3.0657140305265784,3.081072401255369,0.17097033643805315,2.755670573951901,10,float32,cuda,50 +pytorch,8,8,256,256,2,3,46.00589778739959,45.81692651845515,47.76265760883689,0.04558441636529442,,10,float32,cuda,50 +cuda_v3,8,8,256,256,2,3,10.883753076195717,10.826129000633955,11.165698524564505,0.1926864736197262,4.227025132352634,10,float32,cuda,50 +cuda_v4,8,8,256,256,2,3,10.945874289609492,10.827564052306116,11.150588723830879,0.1915929184378399,4.20303546068232,10,float32,cuda,50 +cuda_v6,8,8,256,256,2,3,10.81716647837311,10.752257890999317,11.076831934042275,0.19387258245427405,4.253045183263078,10,float32,cuda,50 +pytorch,8,8,256,256,2,5,38.73277612961829,38.58403104823083,40.112837520428,0.05414411796825335,,10,float32,cuda,50 +cuda_v3,8,8,256,256,2,5,10.870220344513655,10.826208395883441,11.142217181622982,0.19292635600146416,3.5632006437815438,10,float32,cuda,50 +cuda_v4,8,8,256,256,2,5,10.877257627435029,10.877057909965515,10.957135050557554,0.19280153801914962,3.560895352145105,10,float32,cuda,50 +cuda_v6,8,8,256,256,2,5,10.755426031537354,10.749659966677427,10.778588545508683,0.19498548861297296,3.601231231198558,10,float32,cuda,50 +pytorch,8,8,256,256,2,7,44.22167818527669,44.16610090993345,45.8656846312806,0.04742361859750118,,10,float32,cuda,50 +cuda_v3,8,8,256,256,2,7,10.892827231436968,10.832248604856431,11.152946949005127,0.19252595817801713,4.059706194334179,10,float32,cuda,50 +cuda_v4,8,8,256,256,2,7,10.838612425141037,10.817955480888486,10.870237019844353,0.19348897420997238,4.080012869793271,10,float32,cuda,50 +cuda_v6,8,8,256,256,2,7,10.811590934172273,10.752530070021749,11.083307187072933,0.1939725626661953,4.09021007680793,10,float32,cuda,50 +pytorch,8,8,256,256,3,3,100.32575480174273,100.2145285019651,101.78962631616741,0.04703270869304274,,10,float32,cuda,50 +cuda_v3,8,8,256,256,3,3,24.35199290048331,24.352667038328946,24.379127472639084,0.19376615373053724,4.119817019154584,10,float32,cuda,50 +cuda_v4,8,8,256,256,3,3,24.405628656968474,24.35805497225374,24.45684978738427,0.19334031777348676,4.110762980616645,10,float32,cuda,50 +cuda_v6,8,8,256,256,3,3,24.225411019288003,24.226056993938982,24.253131332807243,0.19477861474643748,4.141343761798901,10,float32,cuda,50 +pytorch,8,8,256,256,3,5,83.76444775145501,83.35559000261128,86.63466500584036,0.05633167921074292,,10,float32,cuda,50 +cuda_v3,8,8,256,256,3,5,24.367065099067986,24.362630443647504,24.40121565014124,0.19364630007002692,3.437609224208906,10,float32,cuda,50 +cuda_v4,8,8,256,256,3,5,24.346372256986797,24.343705968931317,24.382443260401487,0.1938108869031148,3.4405309697594357,10,float32,cuda,50 +cuda_v6,8,8,256,256,3,5,24.30050970055163,24.209839990362525,24.665594031102955,0.19417666782079415,3.4470243128091065,10,float32,cuda,50 
+pytorch,8,8,256,256,3,7,93.56231243349612,94.02274643070996,95.42696874123067,0.0504326141292624,,10,float32,cuda,50 +cuda_v3,8,8,256,256,3,7,24.370713336393237,24.374024011194706,24.39629177097231,0.19361731168343105,3.8391290046392603,10,float32,cuda,50 +cuda_v4,8,8,256,256,3,7,24.427954778075218,24.36559647321701,24.452975136227906,0.19316361287171985,3.830132865542675,10,float32,cuda,50 +cuda_v6,8,8,256,256,3,7,24.22358512878418,24.221764993853867,24.25856285262853,0.19479329648826565,3.862446947306687,10,float32,cuda,50 diff --git a/test/test_v7.csv b/test/test_v7.csv new file mode 100644 index 0000000..e01b4d1 --- /dev/null +++ b/test/test_v7.csv @@ -0,0 +1,73 @@ +variant,B,C,H,W,scale,ksize,fwd_mean_ms,fwd_p50_ms,fwd_p90_ms,fwd_tp,fwd_speedup_vs_pytorch,bwd_mean_ms,bwd_p50_ms,bwd_p90_ms,bwd_tp,bwd_speedup_vs_pytorch,grad_ok,eps,warmup,dtype,device,iters +pytorch,8,8,256,256,1,3,9.288488607853651,8.991222945041955,9.583412948995829,0.05644492038852277,,20.064111375249922,19.56672000233084,21.39674376230687,0.02613063644805796,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,1,3,3.3021508576348424,3.324655001051724,3.3440517028793693,0.158771668104685,2.8128601654821153,6.695929053239524,6.69370440300554,6.72081895172596,0.07829951539679872,2.9964641524304687,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,1,3,3.1805392680689692,3.1820274889469147,3.1898362562060356,0.16484248607259483,2.920413120223181,6.543120900169015,6.540277507156134,6.558541930280626,0.08012812356660827,3.066443625507768,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,1,3,3.024693327024579,3.04701691493392,3.0670031206682324,0.17333591981562882,3.070886071280102,6.397886727936566,6.395927630364895,6.411892082542181,0.08194705881719984,3.1360529231690477,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,1,3,3.0596842663362622,3.0616489239037037,3.07619022205472,0.17135362814013316,3.0357670266990997,6.399778164923191,6.40075805131346,6.410616356879473,0.08192283958740817,3.1351260712785547,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,1,3,3.0559819284826517,3.0527031049132347,3.0764779541641474,0.17156122394359777,3.0394448740950337,6.409975425340235,6.408171029761434,6.429016985930502,0.08179251326414738,3.1301385799283215,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,1,3,3.0050428677350283,3.0304675456136465,3.052515466697514,0.17446939131193434,3.090967089882017,6.387096201069653,6.387264467775822,6.399212614633143,0.08208550231515176,3.141351052750664,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,1,3,1.8821742571890354,1.8811896443367004,1.8879902781918645,0.27855444202228474,4.934978029996914,5.040463833138347,5.040465504862368,5.048878467641771,0.10401582420909114,3.9806081423180046,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,1,5,7.326102494262159,7.420093985274434,7.613263255916536,0.07156438234526817,,16.14620674867183,16.1027709254995,16.263780114240944,0.03247127998302929,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,1,5,3.2715308107435703,3.3216774463653564,3.3417833968997,0.16025769901914425,2.2393499918153132,6.707657393999398,6.701117963530123,6.731993076391518,0.07816260867304027,2.407130507756115,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,1,5,3.159591546282172,3.1798079144209623,3.199100005440414,0.16593537244297263,2.318686572915615,6.542924279347062,6.542460061609745,6.554703111760318,0.08013053148955596,2.467735535261782,True,1e-05,10,float32,cuda,50 
+cuda_v3,8,8,256,256,1,5,3.05445936974138,3.055477049201727,3.073208383284509,0.1716467422005326,2.398494007429687,6.400615535676479,6.398052908480167,6.4109522849321365,0.08191212190103653,2.5226021870356474,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,1,5,3.0754889361560345,3.0726579716429114,3.091029403731227,0.17047305676712743,2.3820935943339285,6.441063601523638,6.430621026083827,6.473009032197297,0.08139773683898717,2.5067609555745474,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,1,5,3.048403072170913,3.052991349250078,3.072553128004074,0.17198775476453956,2.403259123159488,6.407246896997094,6.406409083865583,6.420211470685899,0.08182734463466983,2.5199913485836056,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,1,5,2.9911579517647624,3.0331225134432316,3.045725799165666,0.1752792759374923,2.449252969051604,6.385393836535513,6.382934865541756,6.403305334970355,0.08210738654836988,2.5286156440793923,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,1,5,1.8822504533454776,1.8811115296557546,1.8892435124143958,0.2785431657451006,3.8922038675782376,5.046942573972046,5.043365992605686,5.0634910352528095,0.10388229949432032,3.1992055609946117,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,1,7,9.156141332350671,9.012500522658229,10.730735887773335,0.057260802446066954,,18.346093399450183,18.29959498718381,18.474590103141963,0.02857763713422022,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,1,7,3.3304048283025622,3.330954583361745,3.3437674399465322,0.15742470571279427,2.7492577642632607,6.747524798847735,6.6985663725063205,6.975994328968227,0.07770078890106961,2.718936787395449,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,1,7,3.1805677665397525,3.180499654263258,3.187753399834037,0.16484100905367305,2.8787757420782603,6.5445497911423445,6.542496965266764,6.557909771800041,0.0801106289556529,2.803262865274632,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,1,7,2.9923936538398266,2.959055360406637,3.0513782519847155,0.17520689476374068,3.0598050896818174,6.439674673601985,6.414378061890602,6.499357661232352,0.08141529294161429,2.8489161843308506,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,1,7,3.0570596596226096,3.056291490793228,3.0761950882151723,0.17150074201192486,2.9950810097965133,6.428387816995382,6.409163004718721,6.460378412157297,0.08155824056132495,2.853918264070309,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,1,7,3.0619724839925766,3.061673021875322,3.0807194765657187,0.17122557525937293,2.9902755103833485,6.415656288154423,6.413351511582732,6.4297555247321725,0.081720088553999,2.8595817131481276,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,1,7,3.042473620735109,3.0388199957087636,3.0632448382675648,0.17232294026375944,3.0094398419593875,6.388430343940854,6.389048998244107,6.399145186878741,0.08206835979627831,2.871768558430108,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,1,7,1.8896355107426643,1.883591990917921,1.9149431493133307,0.2774545657188377,4.845453676276503,5.0843368750065565,5.057928501628339,5.145174008794129,0.10311826554555825,3.6083551996004446,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,2,3,45.595936393365264,45.675908448174596,46.11496950965375,0.045994274180651766,,80.62302918639034,80.68823453504592,81.43389630131423,0.02601182343510869,,True,1e-05,10,float32,cuda,50 
+cuda_v1,8,8,256,256,2,3,11.85663956683129,11.822210508398712,11.90752275288105,0.17687574866210323,3.8456036498671153,25.430614417418838,25.363861583173275,25.43278068769723,0.08246564418685631,3.170313853336047,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,2,3,11.422737971879542,11.390020838007331,11.460354528389871,0.1835945116803661,3.9916818984741824,24.925604569725692,24.912420543842018,24.930561799556017,0.08413645470999621,3.2345465868584786,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,2,3,10.840446311049163,10.825388482771814,10.85302964784205,0.19345624154445284,4.206094019108,24.0234538866207,24.002613499760628,24.026684556156397,0.08729602370656452,3.3560132346869342,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,2,3,10.878428206779063,10.826261015608907,11.14368592388928,0.1927807915019494,4.191408494560955,23.699625409208238,23.694894975051284,23.71811883058399,0.0884888247721069,3.4018693457941795,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,2,3,10.918543636798859,10.834143031388521,11.191598395816982,0.1920725025022523,4.176008990768045,23.730977713130414,23.69663491845131,23.714118194766343,0.08837191730367012,3.3973749485164024,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,2,3,10.793533944524825,10.768141131848097,10.864380979910493,0.19429706811306324,4.224375133085532,23.638727851212025,23.631693911738694,23.66107157431543,0.08871678768840655,3.4106331649423582,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,2,3,7.474436634220183,7.435761974193156,7.621926814317703,0.2805765976259157,6.100250577363058,20.239599947817624,20.2378734247759,20.263943052850664,0.10361627726866854,3.9834299785694967,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,2,5,38.853937541134655,38.41745341196656,40.494335955008864,0.05397527593644133,,66.22083507943898,66.26046146266162,67.06938743591309,0.031669066049744635,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,2,5,11.842856714501977,11.828812537714839,11.868507391773164,0.17708159868488205,3.2807909846242334,25.389751312322915,25.350986630655825,25.415705051273108,0.08259836712075817,2.608171866862623,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,2,5,11.411372288130224,11.404957505874336,11.43825901672244,0.1837773711213853,3.4048435683365956,24.930950277484953,24.917138973250985,24.93941232096404,0.08411841412615266,2.656169714446977,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,2,5,10.863104425370693,10.828325641341507,11.124962358735502,0.19305273316733643,3.5766882117409824,24.023548257537186,24.010553024709225,24.034774862229824,0.08729568078445848,2.756496849238861,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,2,5,10.90485557448119,10.82881959155202,11.15186910610646,0.19231359697304146,3.562994234610317,23.72426231391728,23.71151139959693,23.792543495073915,0.08839693189405325,2.7912705652639866,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,2,5,10.839254357852042,10.825947392731905,10.844685533083975,0.19347751522048257,3.584558149333266,23.702450119890273,23.698252509348094,23.726728814654052,0.08847827922397537,2.7938392336861733,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,2,5,10.759746134281158,10.75667655095458,10.77785084489733,0.19490720076734483,3.6110459351214446,23.636622526682913,23.63265841268003,23.660433711484075,0.08872468973232393,2.8016200286095696,True,1e-05,10,float32,cuda,50 
+cuda_v7,8,8,256,256,2,5,7.5289920112118125,7.6012545032426715,7.6225581811740994,0.2785435283869371,5.160576274124775,20.23725756444037,20.235952455550432,20.263770665042102,0.10362827044732498,3.2722237619685215,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,2,7,43.63379757385701,43.68262959178537,44.327281741425395,0.048062559680950134,,75.68363416008651,75.79991698730737,76.05113848112524,0.02770945163077252,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,2,7,11.998974299058318,11.853986419737339,12.264510500244796,0.1747776057962375,3.6364606245786693,25.40259242989123,25.356607511639595,25.428785989060998,0.08255661329795148,2.9793665496530024,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,2,7,11.41712686046958,11.413594475015998,11.448350944556296,0.18368474184701714,3.821784421519722,24.922909317538142,24.915332440286875,24.959108070470393,0.0841455535258977,3.036709446550378,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,2,7,10.88796194177121,10.827695950865746,11.144098523072898,0.19261198847089686,4.007526643389317,24.021303891204298,24.00768455117941,24.02521837502718,0.08730383702309759,3.1506880102290795,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,2,7,10.852394071407616,10.832656407728791,10.89512084145099,0.1932432591556259,4.020660997633443,23.70872948784381,23.69790489319712,23.748202505521476,0.08845484533767504,3.192226483451668,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,2,7,10.89978810865432,10.831900988705456,11.149773467332125,0.19240300628733165,4.003178514930233,23.71836454141885,23.704793537035584,23.77322791144252,0.08841891254086219,3.190929713046693,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,2,7,10.774082760326564,10.765377548523247,10.784391174092889,0.19464784582148828,4.049885131245684,23.65874081850052,23.644065018743277,23.685648757964373,0.08864174201359364,3.198971354422374,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,2,7,7.470769472420216,7.435820531100035,7.61996959336102,0.2807143237041433,5.840602863592509,20.250088809989393,20.24794602766633,20.27125945314765,0.10356260753609496,3.7374470240718987,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,3,3,101.28841781057417,101.60925146192312,107.4452179018408,0.04658570152437898,,175.37028575781733,175.02894543576986,177.75019966065884,0.026906450996586024,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,3,3,26.66485572233796,26.59923385363072,27.004664088599384,0.17695921737341677,3.798573630598038,58.96571567747742,58.67917649447918,60.70718690752983,0.08002263596373708,2.974105948565666,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,3,3,25.672803041525185,25.64728946890682,25.70093993563205,0.18379730457822557,3.9453587380677684,56.8397456035018,56.80113600101322,56.884816801175475,0.08301571285901913,3.085346070708189,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,3,3,24.374681571498513,24.3739370489493,24.404809717088938,0.19358579049161745,4.155476555189605,53.81168487481773,53.78810840193182,53.942976240068674,0.08768712615070265,3.2589629216364755,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,3,3,24.441681993193924,24.361361982300878,24.429271719418466,0.1930551261289607,4.144085412729743,53.07756224647164,53.07193798944354,53.1614552019164,0.08889993813371996,3.3040380593114964,True,1e-05,10,float32,cuda,50 
+cuda_v5,8,8,256,256,3,3,24.375403551384807,24.365137447603047,24.41326177213341,0.19358005663590047,4.155353473309771,53.05147184524685,53.03926358465105,53.125715954229236,0.08894365859941288,3.305662965758595,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,3,3,24.276557215489447,24.22345709055662,24.269587476737797,0.1943682523891544,4.17227273667734,52.906081513501704,52.899461006745696,52.96852015890181,0.08918808320355022,3.314747203741834,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,3,3,16.926095513626933,16.923529445193708,16.972190444357693,0.2787761652532999,5.9841572871326685,46.062268051318824,46.03102651890367,46.21587647125125,0.10243941949933792,3.807243828341975,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,3,5,83.52379720192403,83.7456090375781,85.80605688039213,0.05649398324878007,,145.5188752664253,145.75104543473572,148.14186077564955,0.03242597904471772,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,3,5,26.591138602234423,26.581451063975692,26.635813480243087,0.17744979147314519,3.1410387667606536,57.973297983407974,57.822484406642616,58.44167503528297,0.08139250593179065,2.510101724902264,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,3,5,25.637969239614904,25.634926394559443,25.683102500624955,0.1840470263420472,3.2578164214685903,56.7909628059715,56.76805193070322,56.88152206130326,0.08308702242152946,2.562359714935563,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,3,5,24.41519634798169,24.391590617597103,24.450342124328017,0.19326455264776388,3.4209758550161573,53.789559081196785,53.765413467772305,53.912423201836646,0.08772319536728602,2.705336830271478,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,3,5,24.383281343616545,24.362614494748414,24.392798193730414,0.19351751446018198,3.4254535320690236,53.027258049696684,53.017987054772675,53.09310944285244,0.08898427287297746,2.744227791865955,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,3,5,24.393998621962965,24.367429548874497,24.428526940755546,0.19343249432471676,3.423948590647372,53.11096484772861,53.071493515744805,53.206104156561196,0.08884402709550474,2.73990268645343,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,3,5,24.267674176953733,24.23390606418252,24.288278119638562,0.194439399737825,3.441771823409596,52.94822241179645,52.909665973857045,53.120820759795606,0.08911709940518672,2.7483240916885783,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,3,5,17.1291301259771,16.946400050073862,18.221904919482768,0.2754717820050909,4.876126025527497,46.0351287573576,46.020424575544894,46.14793793298304,0.1024998110653888,3.1610398231626027,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,3,7,90.59936547186226,89.40491802059114,93.76176362857223,0.05208195416628463,,158.36064806673676,158.1671329913661,160.4503425071016,0.02979649336880384,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,3,7,26.627318738028407,26.58525703009218,26.68023353908211,0.17720867979324692,3.4024967501692434,58.71118636801839,58.641764568164945,60.120047559030354,0.08036955632990492,2.6972823726314648,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,3,7,25.703592961654067,25.652661453932524,26.049211691133678,0.18357713674658,3.5247743615852856,56.822463613934815,56.79537542164326,56.92225047387183,0.0830409612659392,2.786937383473484,True,1e-05,10,float32,cuda,50 
+cuda_v3,8,8,256,256,3,7,24.43074162583798,24.387317011132836,24.824217706918716,0.1931415784369645,3.7084165048859696,53.77089409157634,53.76591603271663,53.81281706504524,0.08775364590300178,2.9450997745552696,True,1e-05,10,float32,cuda,50
+cuda_v4,8,8,256,256,3,7,24.416185086593032,24.371870909817517,24.417825369164348,0.1932567263585738,3.7106274035254803,53.0829459335655,53.052632487379014,53.19422874599695,0.0888909218773469,2.983267889180994,True,1e-05,10,float32,cuda,50
+cuda_v5,8,8,256,256,3,7,24.362534414976835,24.361073039472103,24.400543165393174,0.19368231234182484,3.718798870784413,53.057269509881735,53.02855698391795,53.15046065952629,0.0889339395635725,2.984711605583899,True,1e-05,10,float32,cuda,50
+cuda_v6,8,8,256,256,3,7,24.26755757071078,24.222337524406612,24.37746999785304,0.1944403340241791,3.733353272486279,52.89180411491543,52.88355250377208,52.94667130801827,0.08921215827216153,2.9940489025988666,True,1e-05,10,float32,cuda,50
+cuda_v7,8,8,256,256,3,7,16.87394628766924,16.87477866653353,16.914236755110323,0.27963772786500724,5.369186551107395,46.165161593817174,46.10549903009087,46.407426544465125,0.10221110112245234,3.4303063738857515,True,1e-05,10,float32,cuda,50

From ab26c2815277085a7740f805831a1c9702050a94 Mon Sep 17 00:00:00 2001
From: BoyceYi <1473416941@qq.com>
Date: Mon, 1 Sep 2025 10:23:14 +0800
Subject: [PATCH 15/22] Add Optimization Details

---
 Converse2D/README.md    | 45 ++++++++++++++++++++++++++--------------
 figs/pytorch_hotmap.png | Bin 0 -> 28444 bytes
 2 files changed, 29 insertions(+), 16 deletions(-)
 create mode 100644 figs/pytorch_hotmap.png

diff --git a/Converse2D/README.md b/Converse2D/README.md
index 5953226..fb42493 100644
--- a/Converse2D/README.md
+++ b/Converse2D/README.md
@@ -1,26 +1,38 @@
+### PyTorch Analysis
+
+![hotmap](../figs/pytorch_hotmap.png)
+
+**Bottlenecks**
+
+- p2o
+- upsample
+
 ### Kernel Registry
------------
-**Kernel Details**
-We offer four versions Converse2d Kernel.
-- v1: Translation from python to CPP **faster**
-- v2: Add FB/F2B cache & broadcast replace repeat **much faster**
-- v3: `splits→permute→view→mean` to `block mean CUDA kernel` **fastest**
-- v4: STy s-fold upsampler CUDA kernel **fastest**
-- v5: Larger batched FFT CUDA kernel
-- v6: Eliminate redundant calculations of `conj/abs/pow(2)`
-- v7: R2C/C2R (Real FFT) replaces C2C
+
+---
+
+**Kernel Tree**
+
+```
+├---v1 Translation from python to CPP
+├---v2 Add FB/F2B cache & broadcast replace repeat
+├---v3 splits→permute→view→mean to block mean CUDA kernel
+├---v4 STy s-fold upsampler CUDA kernel
+    ├---v5 Larger batched FFT CUDA kernel
+    ├---v6 Eliminate redundant calculations of conj/abs/pow(2)
+    ├---v7 R2C/C2R (Real FFT) replaces C2C
+```

 **Tested Device**
+
 - NVIDIA RTX 2080ti
 - NVIDIA RTX 4090
 - NVIDIA RTX 5060ti 16g

-~~Under different circumstances, **v3** and **v4** each have their own performance advantages, but they are both faster than **v1** and **v2**.~~
+**v7** is the fastest. We highly recommend running `test/test_speed.py` first to choose the most suitable backend for your GPU.
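For instance, a typical sweep over the benchmark grid might look like the following (flag names are taken from the argument parser in `test/test_speed.py`; the output CSV path is illustrative):

```
python test/test_speed.py --device cuda --dtype float32 \
    --B_list 8 --C_list 8 --H_list 256 --W_list 256 \
    --scale_list 1,2,3 --ksize_list 3,5,7 --csv results.csv
```

Each variant is built and timed in a fresh subprocess (with its own `TORCH_EXTENSIONS_DIR`), and the parent prints forward/backward timings plus speedups normalized to the PyTorch baseline.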
-
 **Installation**

 ```python
@@ -40,10 +52,11 @@ print(torch.ops.converse2d)
 ```

 **TODO**
+
 - [ ] Temporary Tensor Reuse and In-Place Writing
-- [x] Larger batched FFT(v5) **Note: not very useful**
-- [x] Eliminate redundant calculations of `conj/abs/pow(2)` (v6) *Note: not very useful**
+- [X] Larger batched FFT (v5) **Note: not very useful**
+- [X] Eliminate redundant calculations of `conj/abs/pow(2)` (v6) **Note: not very useful**
 - [ ] The minimal necessary policy for `contiguous()`
-- [x] R2C/C2R (Real FFT) replaces C2C (v7) **(Optional)**
+- [X] R2C/C2R (Real FFT) replaces C2C (v7) **(Optional)**
 - [ ] Mixed precision **(Optional)**
-- [ ] Adaptive padding **(Optional)**
\ No newline at end of file
+- [ ] Adaptive padding **(Optional)**
diff --git a/figs/pytorch_hotmap.png b/figs/pytorch_hotmap.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f442b48bf3195d613ced50b44e5ae8764385be9
GIT binary patch
literal 28444
[base85-encoded PNG payload omitted: figs/pytorch_hotmap.png]
Date: Mon, 1 Sep 2025 10:32:38 +0800
Subject: [PATCH 16/22] Add profiler scripts

---
 test/profiler.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)
 create mode 100644 test/profiler.py

diff --git a/test/profiler.py b/test/profiler.py
new file mode 100644
index 0000000..e25c5e2
--- /dev/null
+++ b/test/profiler.py
@@ -0,0 +1,51 @@
+import os
+import sys
+import argparse
+import torch
+import pathlib
+from torch.profiler import profile, record_function, ProfilerActivity
+
+PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(PROJECT_ROOT))
+from models.util_converse import Converse2D
+
+def main():
+    parser = argparse.ArgumentParser(description="Converse2D Profiler")
+    parser.add_argument("--backend", type=str,
default="pytorch", choices=["pytorch", "cuda"], + help="Which backend to profile: pytorch or cuda") + parser.add_argument("--H", type=int, default=256, help="Input height") + parser.add_argument("--W", type=int, default=256, help="Input width") + parser.add_argument("--B", type=int, default=2, help="Batch size") + parser.add_argument("--C", type=int, default=8, help="Channels") + parser.add_argument("--scale", type=int, default=2, help="Upsampling scale") + args = parser.parse_args() + + os.environ["CONVERSE2D_BACKEND"] = args.backend.lower() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + print(f"[INFO] Running Converse2D with backend: {args.backend.upper()} on {device}") + + x = torch.randn(args.B, args.C, args.H, args.W, device=device) + model = Converse2D(args.C, args.C, kernel_size=5, scale=args.scale, padding=4, backend=args.backend).to(device) + + # warmup + for _ in range(10): + _ = model(x) + + trace_file = f"profile_converse2d_{args.backend.lower()}.json" + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + profile_memory=True, + with_stack=False, + ) as prof: + with record_function(f"Converse2D::forward({args.backend.upper()})"): + y = model(x) + + print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=25)) + prof.export_chrome_trace(trace_file) + print(f"[INFO] Saved trace to: {trace_file}") + +if __name__ == "__main__": + main() From 45f5d59ba06a14bc10f5e2614b526540e5622720 Mon Sep 17 00:00:00 2001 From: BoyceYi <1473416941@qq.com> Date: Mon, 1 Sep 2025 10:38:32 +0800 Subject: [PATCH 17/22] Format main branch --- Converse2D/setup.py | 42 +- .../{converse2d_v7.cpp => converse2d.cpp} | 0 .../{converse2d_v7.cu => converse2d.cu} | 0 Converse2D/torch_converse2d/converse2d_v1.cpp | 146 ------- Converse2D/torch_converse2d/converse2d_v2.cpp | 225 ----------- Converse2D/torch_converse2d/converse2d_v3.cpp | 181 --------- Converse2D/torch_converse2d/converse2d_v3.cu | 138 ------- Converse2D/torch_converse2d/converse2d_v4.cpp | 171 --------- Converse2D/torch_converse2d/converse2d_v4.cu | 277 -------------- Converse2D/torch_converse2d/converse2d_v5.cpp | 208 ---------- Converse2D/torch_converse2d/converse2d_v5.cu | 277 -------------- Converse2D/torch_converse2d/converse2d_v6.cpp | 172 --------- Converse2D/torch_converse2d/converse2d_v6.cu | 323 ---------------- test/profiler.py | 51 --- test/results_2080ti.csv | 361 ------------------ test/results_4090.csv | 254 ++++-------- test/test_speed.py | 2 +- test/test_v5.csv | 37 -- test/test_v6.csv | 37 -- test/test_v7.csv | 73 ---- 20 files changed, 79 insertions(+), 2896 deletions(-) rename Converse2D/torch_converse2d/{converse2d_v7.cpp => converse2d.cpp} (100%) rename Converse2D/torch_converse2d/{converse2d_v7.cu => converse2d.cu} (100%) delete mode 100644 Converse2D/torch_converse2d/converse2d_v1.cpp delete mode 100644 Converse2D/torch_converse2d/converse2d_v2.cpp delete mode 100644 Converse2D/torch_converse2d/converse2d_v3.cpp delete mode 100644 Converse2D/torch_converse2d/converse2d_v3.cu delete mode 100644 Converse2D/torch_converse2d/converse2d_v4.cpp delete mode 100644 Converse2D/torch_converse2d/converse2d_v4.cu delete mode 100644 Converse2D/torch_converse2d/converse2d_v5.cpp delete mode 100644 Converse2D/torch_converse2d/converse2d_v5.cu delete mode 100644 Converse2D/torch_converse2d/converse2d_v6.cpp delete mode 100644 Converse2D/torch_converse2d/converse2d_v6.cu delete mode 100644 test/profiler.py delete mode 100644 test/results_2080ti.csv 
delete mode 100644 test/test_v5.csv delete mode 100644 test/test_v6.csv delete mode 100644 test/test_v7.csv diff --git a/Converse2D/setup.py b/Converse2D/setup.py index 1fd7114..6b310cd 100644 --- a/Converse2D/setup.py +++ b/Converse2D/setup.py @@ -1,45 +1,17 @@ -# setup.py — selectable variants: v1 | v2 | v3 | v4 from setuptools import setup from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension -import os, sys, pathlib +import os, pathlib PKG_DIR = pathlib.Path(__file__).resolve().parent / "torch_converse2d" -# --------------------------- -# parse custom args -# --------------------------- -variant = os.environ.get("CONVERSE2D_VARIANT", "").lower() -to_remove = [] -for i, a in enumerate(list(sys.argv)): - aa = a.lower() - if aa.startswith("--variant="): - variant = a.split("=", 1)[1].lower(); to_remove.append(i) - elif aa in ("--v1","--v2","--v3","--v4"): - variant = aa[2:]; to_remove.append(i) -# scrub custom flags so setuptools doesn't see them -for idx in reversed(to_remove): - sys.argv.pop(idx) -if variant not in {"", "v1","v2","v3","v4", "v5", "v6","v7"}: - raise SystemExit(f"[setup.py] invalid --variant={variant!r}; pick from v1|v2|v3|v4") +CPP = str(PKG_DIR / f"converse2d.cpp") +CU = str(PKG_DIR / f"converse2d.cu") +has_cu = os.path.exists(CU) -if not variant: - variant = "v1" # default - -# --------------------------- -# pick sources per variant -# --------------------------- -CPP = str(PKG_DIR / f"converse2d_{variant}.cpp") -CU = str(PKG_DIR / f"converse2d_{variant}.cu") -has_cu = os.path.exists(CU) # v3,v4 have .cu; v1,v2 usually not - -# --------------------------- -# CUDA arch (auto if not set) -# --------------------------- extra_cflags = ["-O3"] extra_cuda = ["-O3"] -# Respect TORCH_CUDA_ARCH_LIST if user already set it; otherwise auto-detect. if has_cu and "TORCH_CUDA_ARCH_LIST" not in os.environ: try: import torch @@ -47,12 +19,8 @@ maj, min = torch.cuda.get_device_capability(0) os.environ["TORCH_CUDA_ARCH_LIST"] = f"{maj}.{min}+PTX" except Exception: - # Fallback: a safe default that covers Ampere/Lovelace widely. 
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "8.0;8.6;8.9+PTX") -# --------------------------- -# Extension definition -# --------------------------- if has_cu: ext = CUDAExtension( name="converse2d_ext", @@ -66,7 +34,7 @@ extra_compile_args={"cxx": extra_cflags}, ) -print(f"[setup.py] building variant={variant} sources={[p for p in ([CPP] + ([CU] if has_cu else []))]}") +print(f"[setup.py] building sources={[p for p in ([CPP] + ([CU] if has_cu else []))]}") print(f"[setup.py] TORCH_CUDA_ARCH_LIST={os.environ.get('TORCH_CUDA_ARCH_LIST','')}") setup( diff --git a/Converse2D/torch_converse2d/converse2d_v7.cpp b/Converse2D/torch_converse2d/converse2d.cpp similarity index 100% rename from Converse2D/torch_converse2d/converse2d_v7.cpp rename to Converse2D/torch_converse2d/converse2d.cpp diff --git a/Converse2D/torch_converse2d/converse2d_v7.cu b/Converse2D/torch_converse2d/converse2d.cu similarity index 100% rename from Converse2D/torch_converse2d/converse2d_v7.cu rename to Converse2D/torch_converse2d/converse2d.cu diff --git a/Converse2D/torch_converse2d/converse2d_v1.cpp b/Converse2D/torch_converse2d/converse2d_v1.cpp deleted file mode 100644 index 17ecb38..0000000 --- a/Converse2D/torch_converse2d/converse2d_v1.cpp +++ /dev/null @@ -1,146 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using at::Tensor; - -namespace { - -static inline Tensor sfold_upsample_zero_insertion(const Tensor& x, int64_t s) { - TORCH_CHECK(s >= 1, "scale must be >= 1"); - if (s == 1) return x; - auto sizes = x.sizes().vec(); - sizes[sizes.size()-2] *= s; - sizes[sizes.size()-1] *= s; - Tensor z = at::zeros(sizes, x.options()); - z.index_put_( - {at::indexing::Slice(), at::indexing::Slice(), - at::indexing::Slice(0, z.size(-2), s), - at::indexing::Slice(0, z.size(-1), s)}, x); - return z; -} - -static inline Tensor p2o(const Tensor& psf, int64_t H, int64_t W) { - TORCH_CHECK(psf.dim() == 4 && psf.size(0) == 1, "psf must be (1,C,kh,kw)"); - auto C = psf.size(1); - auto kh = psf.size(2); - auto kw = psf.size(3); - Tensor otf = at::zeros({1, C, H, W}, psf.options()); - otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); - const int64_t sh = -static_cast(kh / 2); - const int64_t sw = -static_cast(kw / 2); - otf = at::roll(otf, {sh, sw}, {-2, -1}); - return at::fft_fftn(otf, c10::optional({}), c10::optional({-2,-1}), c10::nullopt); -} - -static inline Tensor splits_mean_then_mean(const Tensor& a, int64_t s) { - TORCH_CHECK(a.dim() >= 2, "tensor must have spatial dims"); - TORCH_CHECK(a.size(-2) % s == 0 && a.size(-1) % s == 0, "spatial not divisible by scale"); - - const auto& sizes = a.sizes(); - const int64_t L = a.dim(); - const int64_t W = sizes[L-2]; - const int64_t H = sizes[L-1]; - const int64_t W_s = W / s; - const int64_t H_s = H / s; - - std::vector view_shape; - view_shape.reserve(L + 2); - for (int64_t i = 0; i < L-2; ++i) view_shape.push_back(sizes[i]); - view_shape.push_back(s); - view_shape.push_back(W_s); - view_shape.push_back(s); - view_shape.push_back(H_s); - Tensor v = a.view(view_shape); - - std::vector perm; - perm.reserve(view_shape.size()); - for (int64_t i = 0; i < L-2; ++i) perm.push_back(i); - perm.push_back(L-2 + 1); // W_s - perm.push_back(L-2 + 3); // H_s - perm.push_back(L-2 + 0); // s - perm.push_back(L-2 + 2); // s - Tensor p = v.permute(perm).contiguous(); - - std::vector merge_shape; - merge_shape.reserve(L+1); - for (int64_t i = 0; i < L-2; ++i) 
merge_shape.push_back(p.size(i)); - merge_shape.push_back(W_s); - merge_shape.push_back(H_s); - merge_shape.push_back(s * s); - Tensor r = p.view(merge_shape); - - return r.mean(-1, /*keepdim=*/false); -} - -Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) { - TORCH_CHECK(x.dim()==4, "x must be (B,C,H,W)"); - TORCH_CHECK(x0.dim()==4, "x0 must be (B,C,Hs,Ws)"); - TORCH_CHECK(weight.dim()==4 && weight.size(0)==1, "weight must be (1,C,kh,kw)"); - TORCH_CHECK(bias.dim()==4 && bias.size(0)==1 && bias.size(2)==1 && bias.size(3)==1, "bias must be (1,C,1,1)"); - TORCH_CHECK(x.device()==x0.device() && x.device()==weight.device() && x.device()==bias.device(), "tensors on same device"); - TORCH_CHECK(scale >= 1, "scale must be >= 1"); - - x = x.contiguous(); - x0 = x0.contiguous(); - weight= weight.contiguous(); - bias = bias.contiguous(); - - const int64_t B = x.size(0); - const int64_t C = x.size(1); - const int64_t H = x.size(2); - const int64_t W = x.size(3); - const int64_t Hs = H * scale; - const int64_t Ws = W * scale; - - Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; - - Tensor STy = sfold_upsample_zero_insertion(x, scale); - - Tensor FB = p2o(weight, Hs, Ws); - Tensor FBC = at::conj_physical(FB); - Tensor F2B = at::abs(FB).pow(2.0); - - Tensor F_STy = at::fft_fftn(STy, c10::optional({}), c10::optional({-2,-1}), c10::nullopt); - Tensor FBFy = FBC * F_STy; - Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::optional({}), c10::optional({-2,-1}), c10::nullopt); - - Tensor x1 = FB * FR; - - Tensor FBR = splits_mean_then_mean(x1, scale); - Tensor invW= splits_mean_then_mean(F2B, scale); - - Tensor invW_plus = invW + lambda_; - Tensor invWBR = FBR / invW_plus; - - Tensor invWBR_rep = invWBR.repeat({1,1,scale,scale}); - - Tensor FCBinvWBR = FBC * invWBR_rep; - - Tensor FX = (FR - FCBinvWBR) / lambda_; - Tensor out_c = at::fft_ifftn(FX, c10::optional({}), c10::optional({-2,-1}), c10::nullopt); - Tensor out = at::real(out_c); - (void)B; (void)C; (void)H; (void)W; - return out; -} - -} - - -TORCH_LIBRARY(converse2d, m) { - m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); -} -TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) { - m.impl("forward", TORCH_FN(converse2d_forward)); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/Converse2D/torch_converse2d/converse2d_v2.cpp b/Converse2D/torch_converse2d/converse2d_v2.cpp deleted file mode 100644 index 23b8a36..0000000 --- a/Converse2D/torch_converse2d/converse2d_v2.cpp +++ /dev/null @@ -1,225 +0,0 @@ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -using at::Tensor; - -struct FBKey -{ - int64_t device_id; - at::ScalarType dtype; - int64_t channels; - int64_t H, W; - void *ptr; - - bool operator==(const FBKey &other) const - { - return device_id == other.device_id && dtype == other.dtype && - channels == other.channels && H == other.H && W == other.W && - ptr == other.ptr; - } -}; - -namespace std -{ - template <> - struct hash - { - size_t operator()(const FBKey &k) const - { - return ((hash()(k.device_id) ^ hash()(k.channels)) << 1) ^ - ((hash()(k.H) ^ hash()(k.W)) << 1) ^ - ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); - } - }; -} - -constexpr size_t FB_CACHE_MAX_SIZE = 64; - -static std::unordered_map> fb_cache; -static std::list fb_cache_lru; -static std::mutex fb_cache_mutex; - -static inline std::pair 
p2o_cached(const Tensor &psf, int64_t H, int64_t W) -{ - auto C = psf.size(1); - FBKey key{ - psf.device().index(), - psf.scalar_type(), - C, H, W, - psf.data_ptr()}; - - { - std::lock_guard lock(fb_cache_mutex); - auto it = fb_cache.find(key); - if (it != fb_cache.end()) - { - fb_cache_lru.remove(key); - fb_cache_lru.push_front(key); - return it->second; - } - } - - Tensor otf = at::zeros({1, C, H, W}, psf.options()); - int64_t kh = psf.size(2), kw = psf.size(3); - otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); - otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1}); - Tensor FB = at::fft_fftn(otf, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor F2B = at::abs(FB).pow(2); - - { - std::lock_guard lock(fb_cache_mutex); - fb_cache[key] = {FB, F2B}; - fb_cache_lru.push_front(key); - - if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) - { - fb_cache.erase(fb_cache_lru.back()); - fb_cache_lru.pop_back(); - } - } - - return {FB, F2B}; -} - -static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) -{ - TORCH_CHECK(s >= 1, "scale must be >= 1"); - if (s == 1) - return x; - auto sizes = x.sizes().vec(); - sizes[sizes.size() - 2] *= s; - sizes[sizes.size() - 1] *= s; - Tensor z = at::zeros(sizes, x.options()); - z.index_put_( - {at::indexing::Slice(), at::indexing::Slice(), - at::indexing::Slice(0, z.size(-2), s), - at::indexing::Slice(0, z.size(-1), s)}, - x); - return z; -} - -static inline Tensor splits_mean_then_mean(const Tensor &a, int64_t s) -{ - TORCH_CHECK(a.dim() >= 2, "tensor must have spatial dims"); - TORCH_CHECK(a.size(-2) % s == 0 && a.size(-1) % s == 0, "spatial not divisible by scale"); - - const auto &sizes = a.sizes(); - const int64_t L = a.dim(); - const int64_t W = sizes[L - 2]; - const int64_t H = sizes[L - 1]; - const int64_t W_s = W / s; - const int64_t H_s = H / s; - - std::vector view_shape; - view_shape.reserve(L + 2); - for (int64_t i = 0; i < L - 2; ++i) - view_shape.push_back(sizes[i]); - view_shape.push_back(s); - view_shape.push_back(W_s); - view_shape.push_back(s); - view_shape.push_back(H_s); - Tensor v = a.view(view_shape); - - std::vector perm; - perm.reserve(view_shape.size()); - for (int64_t i = 0; i < L - 2; ++i) - perm.push_back(i); - perm.push_back(L - 2 + 1); - perm.push_back(L - 2 + 3); - perm.push_back(L - 2 + 0); - perm.push_back(L - 2 + 2); - Tensor p = v.permute(perm).contiguous(); - - std::vector merge_shape; - merge_shape.reserve(L + 1); - for (int64_t i = 0; i < L - 2; ++i) - merge_shape.push_back(p.size(i)); - merge_shape.push_back(W_s); - merge_shape.push_back(H_s); - merge_shape.push_back(s * s); - Tensor r = p.view(merge_shape); - - return r.mean(-1, /*keepdim=*/false); -} - -Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) -{ - TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); - TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); - TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); - TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); - TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); - TORCH_CHECK(scale >= 1, "scale must be >= 1"); - - x = x.contiguous(); - x0 = x0.contiguous(); - weight = weight.contiguous(); - bias = bias.contiguous(); - - const int64_t B = x.size(0); - const int64_t C = x.size(1); - const int64_t H = x.size(2); - const 
int64_t W = x.size(3); - const int64_t Hs = H * scale; - const int64_t Ws = W * scale; - - Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; - Tensor STy = sfold_upsample_zero_insertion(x, scale); - - auto [FB, F2B] = p2o_cached(weight, Hs, Ws); - Tensor FBC = at::conj_physical(FB); - - Tensor F_STy = at::fft_fftn(STy, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor FBFy = FBC * F_STy; - Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt); - - Tensor x1 = FB * FR; - Tensor FBR = splits_mean_then_mean(x1, scale); - Tensor invW = splits_mean_then_mean(F2B, scale); - - Tensor invW_plus = invW + lambda_; - Tensor invWBR = FBR / invW_plus; - - Tensor invWBR_rep = invWBR.repeat({1, 1, scale, scale}); - Tensor FCBinvWBR = FBC * invWBR_rep; - - Tensor FX = (FR - FCBinvWBR) / lambda_; - Tensor out_c = at::fft_ifftn(FX, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor out = at::real(out_c); - return out; -} - -void clear_fb_cache() -{ - std::lock_guard lock(fb_cache_mutex); - fb_cache.clear(); - fb_cache_lru.clear(); -} - -TORCH_LIBRARY(converse2d, m) -{ - m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); - m.def("clear_cache() -> ()"); -} -TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) -{ - m.impl("forward", TORCH_FN(converse2d_forward)); - m.impl("clear_cache", TORCH_FN(clear_fb_cache)); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/Converse2D/torch_converse2d/converse2d_v3.cpp b/Converse2D/torch_converse2d/converse2d_v3.cpp deleted file mode 100644 index 7a1c460..0000000 --- a/Converse2D/torch_converse2d/converse2d_v3.cpp +++ /dev/null @@ -1,181 +0,0 @@ -// backend/converse2d_v3.cpp -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -using at::Tensor; - -Tensor block_mean_cuda(const Tensor &input, int64_t s); - -// ---------- FB Cache ---------- -struct FBKey -{ - int64_t device_id; - at::ScalarType dtype; - int64_t channels; - int64_t H, W; - void *ptr; - - bool operator==(const FBKey &other) const - { - return device_id == other.device_id && dtype == other.dtype && - channels == other.channels && H == other.H && W == other.W && - ptr == other.ptr; - } -}; - -namespace std -{ - template <> - struct hash - { - size_t operator()(const FBKey &k) const - { - return ((hash()(k.device_id) ^ hash()(k.channels)) << 1) ^ - ((hash()(k.H) ^ hash()(k.W)) << 1) ^ - ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); - } - }; -} // namespace std - -constexpr size_t FB_CACHE_MAX_SIZE = 64; -static std::unordered_map> fb_cache; -static std::list fb_cache_lru; -static std::mutex fb_cache_mutex; - -static inline std::pair p2o_cached(const Tensor &psf, int64_t H, int64_t W) -{ - auto C = psf.size(1); - FBKey key{psf.device().index(), psf.scalar_type(), C, H, W, psf.data_ptr()}; - - { - std::lock_guard lock(fb_cache_mutex); - auto it = fb_cache.find(key); - if (it != fb_cache.end()) - { - fb_cache_lru.remove(key); - fb_cache_lru.push_front(key); - return it->second; - } - } - - Tensor otf = at::zeros({1, C, H, W}, psf.options()); - int64_t kh = psf.size(2), kw = psf.size(3); - otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); - otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1}); - Tensor FB = at::fft_fftn(otf, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor F2B = at::abs(FB).pow(2); - - { - std::lock_guard lock(fb_cache_mutex); - fb_cache[key] = {FB, 
F2B}; - fb_cache_lru.push_front(key); - - if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) - { - fb_cache.erase(fb_cache_lru.back()); - fb_cache_lru.pop_back(); - } - } - - return {FB, F2B}; -} - -static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) -{ - if (s == 1) - return x; - auto sizes = x.sizes().vec(); - sizes[sizes.size() - 2] *= s; - sizes[sizes.size() - 1] *= s; - Tensor z = at::zeros(sizes, x.options()); - z.index_put_({at::indexing::Slice(), at::indexing::Slice(), - at::indexing::Slice(0, z.size(-2), s), - at::indexing::Slice(0, z.size(-1), s)}, - x); - return z; -} - -// ---------- Forward ---------- -Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) -{ - TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); - TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); - TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); - TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); - TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); - - x = x.contiguous(); - x0 = x0.contiguous(); - weight = weight.contiguous(); - bias = bias.contiguous(); - - const int64_t B = x.size(0); - const int64_t C = x.size(1); - const int64_t H = x.size(2); - const int64_t W = x.size(3); - const int64_t Hs = H * scale; - const int64_t Ws = W * scale; - - Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; - Tensor STy = sfold_upsample_zero_insertion(x, scale); - - auto [FB, F2B] = p2o_cached(weight, Hs, Ws); - Tensor FBC = at::conj_physical(FB); - - Tensor F_STy = at::fft_fftn(STy, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor FBFy = FBC * F_STy; - Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt); - - Tensor x1 = FB * FR; - - Tensor FBR = block_mean_cuda(x1, scale); // (B,C,H,W) - Tensor invW = block_mean_cuda(F2B, scale); // (B,C,H,W) - - Tensor invW_plus = invW + lambda_; - Tensor invWBR = FBR / invW_plus; - - Tensor invWBR_exp = invWBR.view({B, C, H, 1, W, 1}) - .expand({B, C, H, scale, W, scale}) - .reshape({B, C, Hs, Ws}); - Tensor FCBinvWBR = FBC * invWBR_exp; - - Tensor FX = (FR - FCBinvWBR) / lambda_; - Tensor out_c = at::fft_ifftn(FX, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor out = at::real(out_c); - return out; -} - -// ---------- Clear Cache ---------- -void clear_fb_cache() -{ - std::lock_guard lock(fb_cache_mutex); - fb_cache.clear(); - fb_cache_lru.clear(); -} - -// ---------- Registration ---------- -TORCH_LIBRARY(converse2d, m) -{ - m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); - m.def("clear_cache() -> ()"); -} -TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) -{ - m.impl("forward", TORCH_FN(converse2d_forward)); - m.impl("clear_cache", TORCH_FN(clear_fb_cache)); -} -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/Converse2D/torch_converse2d/converse2d_v3.cu b/Converse2D/torch_converse2d/converse2d_v3.cu deleted file mode 100644 index 40f838e..0000000 --- a/Converse2D/torch_converse2d/converse2d_v3.cu +++ /dev/null @@ -1,138 +0,0 @@ -#include -#include -#include -#include -#include - -// ====================== -// block mean (forward): -// in : (B,C,Hs,Ws) -// out: (B,C,Ho,Wo), Ho=Hs/s, Wo=Ws/s -// ====================== - -template -struct AccT -{ - using type = T; -}; -template <> -struct AccT -{ - using type = float; -}; -template <> 
-struct AccT -{ - using type = float; -}; -template -__global__ void block_mean_kernel( - const scalar_t *__restrict__ in, // (B,C,Hs,Ws) - scalar_t *__restrict__ out, // (B,C,Ho,Wo) - int B, int C, int Ho, int Wo, int s, int Hs, int Ws, - long long total_out) -{ - using acc_t = typename AccT::type; - - long long idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_out) - return; - - int wo = static_cast(idx % Wo); - int ho = static_cast((idx / Wo) % Ho); - int c = static_cast((idx / (1LL * Wo * Ho)) % C); - int b = static_cast(idx / (1LL * Wo * Ho * C)); - - const int hi0 = ho * s; - const int wi0 = wo * s; - - const long long base_in = ((long long)b * C + c) * Hs * Ws; - - acc_t acc = acc_t(0); - for (int di = 0; di < s; ++di) - { - const int hi = hi0 + di; - const long long row_off = base_in + (long long)hi * Ws + wi0; -#pragma unroll - for (int dj = 0; dj < s; ++dj) - { - acc += static_cast(in[row_off + dj]); - } - } - const float inv_area = 1.0f / (s * s); - acc = acc * static_cast(inv_area); - - const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo; - out[out_off] = static_cast(acc); -} - -struct BlockMeanFunctionV3 : public torch::autograd::Function -{ - static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &input, int64_t s) - { - TORCH_CHECK(input.is_cuda() && input.dim() == 4, "block_mean_cuda: input must be (B,C,Hs,Ws) CUDA"); - TORCH_CHECK(s >= 1, "block_mean_cuda: s must be >= 1"); - - auto x = input.contiguous(); - const int B = (int)x.size(0); - const int C = (int)x.size(1); - const int Hs = (int)x.size(2); - const int Ws = (int)x.size(3); - - TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "block_mean_cuda: H,W must be divisible by s"); - const int Ho = Hs / (int)s; - const int Wo = Ws / (int)s; - - auto out = at::empty({B, C, Ho, Wo}, x.options()); - - // launch forward kernel - { - const long long total_out = 1LL * B * C * Ho * Wo; - const int threads = 256; - const int blocks = (int)((total_out + threads - 1) / threads); - auto stream = at::cuda::getCurrentCUDAStream(); - - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::kBFloat16, - x.scalar_type(), "block_mean_v3_fwd", [&] - { block_mean_kernel<<>>( - x.data_ptr(), out.data_ptr(), - B, C, Ho, Wo, (int)s, Hs, Ws, total_out); }); - } - - // save for backward - ctx->saved_data["B"] = (int64_t)B; - ctx->saved_data["C"] = (int64_t)C; - ctx->saved_data["Hs"] = (int64_t)Hs; - ctx->saved_data["Ws"] = (int64_t)Ws; - ctx->saved_data["s"] = (int64_t)s; - return out; - } - - static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, - torch::autograd::tensor_list grad_outputs) - { - auto go = grad_outputs[0]; // (B,C,Ho,Wo) - const int B = (int)ctx->saved_data["B"].toInt(); - const int C = (int)ctx->saved_data["C"].toInt(); - const int Hs = (int)ctx->saved_data["Hs"].toInt(); - const int Ws = (int)ctx->saved_data["Ws"].toInt(); - const int s = (int)ctx->saved_data["s"].toInt(); - - const int Ho = Hs / s; - const int Wo = Ws / s; - - // gi = expand( go / (s*s), dims=[B,C,Ho,1,Wo,1] ) -> reshape(B,C,Hs,Ws) - auto go_scaled = go / static_cast(s * s); - auto gi = go_scaled.view({B, C, Ho, 1, Wo, 1}) - .expand({B, C, Ho, s, Wo, s}) - .reshape({B, C, Hs, Ws}) - .contiguous(); - - return {gi, torch::Tensor()}; // no grad for s - } -}; - -at::Tensor block_mean_cuda(const at::Tensor &input, int64_t s) -{ - return BlockMeanFunctionV3::apply(input, s); -} diff --git a/Converse2D/torch_converse2d/converse2d_v4.cpp 
b/Converse2D/torch_converse2d/converse2d_v4.cpp deleted file mode 100644 index 382cd7b..0000000 --- a/Converse2D/torch_converse2d/converse2d_v4.cpp +++ /dev/null @@ -1,171 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -using at::Tensor; - -Tensor block_mean_cuda(const Tensor &input, int64_t s); -Tensor sfold_upsample_cuda_launcher(const Tensor &x, int64_t scale); - -// ---------- FB Cache ---------- -struct FBKey -{ - int64_t device_id; - at::ScalarType dtype; - int64_t channels; - int64_t H, W; - void *ptr; - - bool operator==(const FBKey &other) const - { - return device_id == other.device_id && dtype == other.dtype && - channels == other.channels && H == other.H && W == other.W && - ptr == other.ptr; - } -}; - -namespace std -{ - template <> - struct hash - { - size_t operator()(const FBKey &k) const - { - return ((hash()(k.device_id) ^ hash()(k.channels)) << 1) ^ - ((hash()(k.H) ^ hash()(k.W)) << 1) ^ - ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); - } - }; -} // namespace std - -constexpr size_t FB_CACHE_MAX_SIZE = 64; -static std::unordered_map> fb_cache; -static std::list fb_cache_lru; -static std::mutex fb_cache_mutex; - -static inline std::pair p2o_cached(const Tensor &psf, int64_t H, int64_t W) -{ - auto C = psf.size(1); - FBKey key{psf.device().index(), psf.scalar_type(), C, H, W, psf.data_ptr()}; - - { - std::lock_guard lock(fb_cache_mutex); - auto it = fb_cache.find(key); - if (it != fb_cache.end()) - { - fb_cache_lru.remove(key); - fb_cache_lru.push_front(key); - return it->second; - } - } - - Tensor otf = at::zeros({1, C, H, W}, psf.options()); - int64_t kh = psf.size(2), kw = psf.size(3); - otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); - otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1}); - Tensor FB = at::fft_fftn(otf, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor F2B = at::abs(FB).pow(2); - - { - std::lock_guard lock(fb_cache_mutex); - fb_cache[key] = {FB, F2B}; - fb_cache_lru.push_front(key); - - if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) - { - fb_cache.erase(fb_cache_lru.back()); - fb_cache_lru.pop_back(); - } - } - - return {FB, F2B}; -} - -static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) -{ - if (s == 1) - return x; - return sfold_upsample_cuda_launcher(x, s); -} - -// ---------- Forward ---------- -Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) -{ - TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); - TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); - TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); - TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); - TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); - - x = x.contiguous(); - x0 = x0.contiguous(); - weight = weight.contiguous(); - bias = bias.contiguous(); - - const int64_t B = x.size(0); - const int64_t C = x.size(1); - const int64_t H = x.size(2); - const int64_t W = x.size(3); - const int64_t Hs = H * scale; - const int64_t Ws = W * scale; - - Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; - Tensor STy = sfold_upsample_zero_insertion(x, scale); - - auto [FB, F2B] = p2o_cached(weight, Hs, Ws); - Tensor FBC = at::conj_physical(FB); - - Tensor F_STy = 
at::fft_fftn(STy, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor FBFy = FBC * F_STy; - Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt); - - Tensor x1 = FB * FR; - - Tensor FBR = block_mean_cuda(x1, scale); // (B,C,H,W) - Tensor invW = block_mean_cuda(F2B, scale); // (B,C,H,W) - - Tensor invW_plus = invW + lambda_; - Tensor invWBR = FBR / invW_plus; - - Tensor invWBR_exp = invWBR.view({B, C, H, 1, W, 1}) - .expand({B, C, H, scale, W, scale}) - .reshape({B, C, Hs, Ws}); - Tensor FCBinvWBR = FBC * invWBR_exp; - - Tensor FX = (FR - FCBinvWBR) / lambda_; - Tensor out_c = at::fft_ifftn(FX, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor out = at::real(out_c); - return out; -} - -void clear_fb_cache() -{ - std::lock_guard lock(fb_cache_mutex); - fb_cache.clear(); - fb_cache_lru.clear(); -} - -TORCH_LIBRARY(converse2d, m) -{ - m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); - m.def("clear_cache() -> ()"); -} -TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) -{ - m.impl("forward", TORCH_FN(converse2d_forward)); - m.impl("clear_cache", TORCH_FN(clear_fb_cache)); -} -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/Converse2D/torch_converse2d/converse2d_v4.cu b/Converse2D/torch_converse2d/converse2d_v4.cu deleted file mode 100644 index e02250c..0000000 --- a/Converse2D/torch_converse2d/converse2d_v4.cu +++ /dev/null @@ -1,277 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -// ====================================================================== -// S-FOLD UPSAMPLE (zero-insertion upsample) -// forward: out[b,c,h*s, w*s] = x[b,c,h,w]; others = 0 -// backward: grad_x[b,c,h,w] = grad_out[b,c,h*s, w*s] -// dtypes: float/double/half/bfloat16 -// ====================================================================== - -using namespace at; -using namespace at::indexing; - -template -__global__ void sfold_upsample_kernel( - const scalar_t *__restrict__ x, - scalar_t *__restrict__ out, - int B, int C, int H, int W, int s, - int Hs, int Ws, long long total_in) -{ - long long idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_in) - return; - - int w = static_cast(idx % W); - int h = static_cast((idx / W) % H); - int c = static_cast((idx / (1LL * W * H)) % C); - int b = static_cast(idx / (1LL * W * H * C)); - - long long in_off = ((long long)b * C + c) * H * W + (long long)h * W + w; - long long out_off = ((long long)b * C + c) * Hs * Ws + (long long)(h * s) * Ws + (w * s); - - out[out_off] = x[in_off]; -} - -template -__global__ void sfold_downsample_grad_kernel( // backward of zero-insertion upsample - const scalar_t *__restrict__ grad_out, // (B,C,Hs,Ws) - scalar_t *__restrict__ grad_in, // (B,C,H,W) - int B, int C, int H, int W, int s, int Hs, int Ws, long long total_in) -{ - long long idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_in) - return; - - int w = static_cast(idx % W); - int h = static_cast((idx / W) % H); - int c = static_cast((idx / (1LL * W * H)) % C); - int b = static_cast(idx / (1LL * W * H * C)); - - long long in_off = ((long long)b * C + c) * H * W + (long long)h * W + w; - long long out_off = ((long long)b * C + c) * Hs * Ws + (long long)(h * s) * Ws + (w * s); - - grad_in[in_off] = grad_out[out_off]; -} - -struct SFoldFunction : public torch::autograd::Function -{ - static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &x, int64_t scale) - { - TORCH_CHECK(x.is_cuda() && x.dim() == 4, 
"sfold: x must be (B,C,H,W) CUDA"); - TORCH_CHECK(scale >= 1, "sfold: scale must be >= 1"); - if (scale == 1) - { - ctx->saved_data["s"] = (int64_t)1; - return x; - } - - auto x_ = x.contiguous(); - const int B = (int)x_.size(0), C = (int)x_.size(1), H = (int)x_.size(2), W = (int)x_.size(3); - const int s = (int)scale, Hs = H * s, Ws = W * s; - - auto out = at::zeros({B, C, Hs, Ws}, x_.options()); - - const long long total = 1LL * B * C * H * W; - const int threads = 256, blocks = (int)((total + threads - 1) / threads); - auto stream = at::cuda::getCurrentCUDAStream(); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x_.scalar_type(), "sfold_fwd", [&] - { sfold_upsample_kernel<<>>( - x_.data_ptr(), out.data_ptr(), - B, C, H, W, s, Hs, Ws, total); }); - - // save for backward - ctx->saved_data["B"] = (int64_t)B; - ctx->saved_data["C"] = (int64_t)C; - ctx->saved_data["H"] = (int64_t)H; - ctx->saved_data["W"] = (int64_t)W; - ctx->saved_data["s"] = (int64_t)s; - return out; - } - - static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, - torch::autograd::tensor_list grad_outputs) - { - auto go = grad_outputs[0]; // (B,C,Hs,Ws) - const int B = (int)ctx->saved_data["B"].toInt(); - const int C = (int)ctx->saved_data["C"].toInt(); - const int H = (int)ctx->saved_data["H"].toInt(); - const int W = (int)ctx->saved_data["W"].toInt(); - const int s = (int)ctx->saved_data["s"].toInt(); - const int Hs = H * s, Ws = W * s; - - at::Tensor gx; - if (s == 1) - { - gx = go; // identity - } - else - { - gx = go.index({Slice(), Slice(), Slice(0, Hs, s), Slice(0, Ws, s)}).contiguous(); - } - return {gx, torch::Tensor()}; // no grad for scale - } -}; - -// exposed symbol for v4.cpp -at::Tensor sfold_upsample_cuda_launcher(const at::Tensor &x, int64_t scale) -{ - return SFoldFunction::apply(x, scale); -} - -// ====================================================================== -// BLOCK MEAN over non-overlapping s×s tiles -// forward: out[b,c,ho,wo] = mean_{i,j in s×s} in[b,c, ho*s+i, wo*s+j] -// backward: grad_in[b,c,hi,wi] = grad_out[b,c,hi/s, wi/s] / (s*s) -// dtypes: float/double/half/bfloat16 + complex64/complex128 -// ====================================================================== - -template -struct AccT -{ - using type = T; -}; -template <> -struct AccT -{ - using type = float; -}; -template <> -struct AccT -{ - using type = float; -}; - -template -__global__ void block_mean_kernel( - const scalar_t *__restrict__ in, // (B,C,Hs,Ws) - scalar_t *__restrict__ out, // (B,C,Ho,Wo) - int B, int C, int Ho, int Wo, int s, int Hs, int Ws, - long long total_out) -{ - using acc_t = typename AccT::type; - - long long idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_out) - return; - - int wo = static_cast(idx % Wo); - int ho = static_cast((idx / Wo) % Ho); - int c = static_cast((idx / (1LL * Wo * Ho)) % C); - int b = static_cast(idx / (1LL * Wo * Ho * C)); - - const int hi0 = ho * s; - const int wi0 = wo * s; - - const long long base_in = ((long long)b * C + c) * Hs * Ws; - - acc_t acc = acc_t(0); - for (int di = 0; di < s; ++di) - { - const int hi = hi0 + di; - const long long row_off = base_in + (long long)hi * Ws + wi0; -#pragma unroll - for (int dj = 0; dj < s; ++dj) - { - acc += static_cast(in[row_off + dj]); - } - } - const float inv_area = 1.0f / (s * s); - acc = acc * static_cast(inv_area); - - const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo; - out[out_off] = static_cast(acc); -} - -template -__global__ void 
block_mean_grad_kernel( - const scalar_t *__restrict__ grad_out, // (B,C,Ho,Wo) - scalar_t *__restrict__ grad_in, // (B,C,Hs,Ws) - int B, int C, int Ho, int Wo, int s, int Hs, int Ws, - long long total_in) -{ - using acc_t = typename AccT::type; - - long long idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_in) - return; - - int wi = static_cast(idx % Ws); - int hi = static_cast((idx / Ws) % Hs); - int c = static_cast((idx / (1LL * Ws * Hs)) % C); - int b = static_cast(idx / (1LL * Ws * Hs * C)); - - const int ho = hi / s; - const int wo = wi / s; - - const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo; - acc_t g = static_cast(grad_out[out_off]) * static_cast(1.0f / (s * s)); - - const long long in_off = ((long long)b * C + c) * Hs * Ws + (long long)hi * Ws + wi; - grad_in[in_off] = static_cast(g); -} - -struct BlockMeanFunction : public torch::autograd::Function -{ - static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &input, int64_t s) - { - TORCH_CHECK(input.is_cuda() && input.dim() == 4, "block_mean: input must be (B,C,Hs,Ws) CUDA"); - TORCH_CHECK(s >= 1, "block_mean: s must be >= 1"); - - auto x = input.contiguous(); - const int B = (int)x.size(0), C = (int)x.size(1), Hs = (int)x.size(2), Ws = (int)x.size(3); - TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "block_mean: H,W must be divisible by s"); - const int Ho = Hs / (int)s, Wo = Ws / (int)s; - - auto out = at::empty({B, C, Ho, Wo}, x.options()); - - const long long total_out = 1LL * B * C * Ho * Wo; - const int threads = 256, blocks = (int)((total_out + threads - 1) / threads); - auto stream = at::cuda::getCurrentCUDAStream(); - - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "block_mean_fwd", [&] - { block_mean_kernel<<>>( - x.data_ptr(), out.data_ptr(), - B, C, Ho, Wo, (int)s, Hs, Ws, total_out); }); - - // save for backward - ctx->saved_data["B"] = (int64_t)B; - ctx->saved_data["C"] = (int64_t)C; - ctx->saved_data["Hs"] = (int64_t)Hs; - ctx->saved_data["Ws"] = (int64_t)Ws; - ctx->saved_data["s"] = (int64_t)s; - return out; - } - - static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, - torch::autograd::tensor_list grad_outputs) - { - auto go = grad_outputs[0]; // (B,C,Ho,Wo) - const int B = (int)ctx->saved_data["B"].toInt(); - const int C = (int)ctx->saved_data["C"].toInt(); - const int Hs = (int)ctx->saved_data["Hs"].toInt(); - const int Ws = (int)ctx->saved_data["Ws"].toInt(); - const int s = (int)ctx->saved_data["s"].toInt(); - const int Ho = Hs / s, Wo = Ws / s; - - auto go_scaled = go / static_cast(s * s); - auto gi = go_scaled.view({B, C, Ho, 1, Wo, 1}) - .expand({B, C, Ho, s, Wo, s}) - .reshape({B, C, Hs, Ws}) - .contiguous(); - - return {gi, torch::Tensor()}; // no grad for s - } -}; - -// exposed symbol for v4.cpp -at::Tensor block_mean_cuda(const at::Tensor &input, int64_t s) -{ - return BlockMeanFunction::apply(input, s); -} diff --git a/Converse2D/torch_converse2d/converse2d_v5.cpp b/Converse2D/torch_converse2d/converse2d_v5.cpp deleted file mode 100644 index fefd0e9..0000000 --- a/Converse2D/torch_converse2d/converse2d_v5.cpp +++ /dev/null @@ -1,208 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -using at::Tensor; - -Tensor block_mean_cuda(const Tensor &input, int64_t s); -Tensor sfold_upsample_cuda_launcher(const Tensor &x, int64_t scale); - -// ---------- FB 
Cache ---------- -struct FBKey -{ - int64_t device_id; - at::ScalarType dtype; - int64_t channels; - int64_t H, W; - void *ptr; - - bool operator==(const FBKey &other) const - { - return device_id == other.device_id && dtype == other.dtype && - channels == other.channels && H == other.H && W == other.W && - ptr == other.ptr; - } -}; - -namespace std -{ - template <> - struct hash - { - size_t operator()(const FBKey &k) const - { - return ((hash()(k.device_id) ^ hash()(k.channels)) << 1) ^ - ((hash()(k.H) ^ hash()(k.W)) << 1) ^ - ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); - } - }; -} // namespace std - -constexpr size_t FB_CACHE_MAX_SIZE = 64; -static std::unordered_map> fb_cache; -static std::list fb_cache_lru; -static std::mutex fb_cache_mutex; - -inline Tensor fft2_auto_batched(const Tensor& x) -{ - TORCH_CHECK(x.dim() == 4, "Expected input of shape (B,C,H,W)"); - const int64_t B = x.size(0), C = x.size(1), H = x.size(2), W = x.size(3); - - if (B * C >= 8) - { - auto x_reshaped = x.view({B * C, H, W}).contiguous(); - auto fx = at::fft_fftn(x_reshaped, c10::nullopt, {-2, -1}, c10::nullopt); - return fx.view({B, C, H, W}); - } - else - { - return at::fft_fftn(x, c10::nullopt, {-2, -1}, c10::nullopt); - } -} - -inline Tensor ifft2_auto_batched(const Tensor& x) -{ - TORCH_CHECK(x.dim() == 4, "Expected input of shape (B,C,H,W)"); - const int64_t B = x.size(0), C = x.size(1), H = x.size(2), W = x.size(3); - - if (B * C >= 8) - { - auto x_reshaped = x.view({B * C, H, W}).contiguous(); - auto fx = at::fft_ifftn(x_reshaped, c10::nullopt, {-2, -1}, c10::nullopt); - return fx.view({B, C, H, W}); - } - else - { - return at::fft_ifftn(x, c10::nullopt, {-2, -1}, c10::nullopt); - } -} - -static inline std::pair p2o_cached(const Tensor &psf, int64_t H, int64_t W) -{ - auto C = psf.size(1); - FBKey key{psf.device().index(), psf.scalar_type(), C, H, W, psf.data_ptr()}; - - { - std::lock_guard lock(fb_cache_mutex); - auto it = fb_cache.find(key); - if (it != fb_cache.end()) - { - fb_cache_lru.remove(key); - fb_cache_lru.push_front(key); - return it->second; - } - } - - Tensor otf = at::zeros({1, C, H, W}, psf.options()); - int64_t kh = psf.size(2), kw = psf.size(3); - otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); - otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1}); - Tensor FB = fft2_auto_batched(otf); - Tensor F2B = at::abs(FB).pow(2); - - { - std::lock_guard lock(fb_cache_mutex); - fb_cache[key] = {FB, F2B}; - fb_cache_lru.push_front(key); - - if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) - { - fb_cache.erase(fb_cache_lru.back()); - fb_cache_lru.pop_back(); - } - } - - return {FB, F2B}; -} - -static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) -{ - if (s == 1) - return x; - return sfold_upsample_cuda_launcher(x, s); -} - - - - -// ---------- Forward ---------- -Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) -{ - TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); - TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); - TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); - TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); - TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); - - x = x.contiguous(); - x0 = x0.contiguous(); - weight = weight.contiguous(); - bias = bias.contiguous(); - - 
const int64_t B = x.size(0); - const int64_t C = x.size(1); - const int64_t H = x.size(2); - const int64_t W = x.size(3); - const int64_t Hs = H * scale; - const int64_t Ws = W * scale; - - Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; - Tensor STy = sfold_upsample_zero_insertion(x, scale); - - auto [FB, F2B] = p2o_cached(weight, Hs, Ws); - Tensor FBC = at::conj_physical(FB); - - Tensor F_STy = fft2_auto_batched(STy); - Tensor FB_Fy = FBC * F_STy; - Tensor FR = FB_Fy + fft2_auto_batched(lambda_ * x0); - - Tensor x1 = FB * FR; - - Tensor FBR = block_mean_cuda(x1, scale); // (B,C,H,W) - Tensor invW = block_mean_cuda(F2B, scale); // (B,C,H,W) - - Tensor invW_plus = invW + lambda_; - Tensor invWBR = FBR / invW_plus; - - Tensor invWBR_exp = invWBR.view({B, C, H, 1, W, 1}) - .expand({B, C, H, scale, W, scale}) - .reshape({B, C, Hs, Ws}); - Tensor FCBinvWBR = FBC * invWBR_exp; - - Tensor FX = (FR - FCBinvWBR) / lambda_; - Tensor out_c = ifft2_auto_batched(FX); - Tensor out = at::real(out_c); - return out; -} - -void clear_fb_cache() -{ - std::lock_guard lock(fb_cache_mutex); - fb_cache.clear(); - fb_cache_lru.clear(); -} - -TORCH_LIBRARY(converse2d, m) -{ - m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); - m.def("clear_cache() -> ()"); -} -TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) -{ - m.impl("forward", TORCH_FN(converse2d_forward)); - m.impl("clear_cache", TORCH_FN(clear_fb_cache)); -} -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/Converse2D/torch_converse2d/converse2d_v5.cu b/Converse2D/torch_converse2d/converse2d_v5.cu deleted file mode 100644 index e02250c..0000000 --- a/Converse2D/torch_converse2d/converse2d_v5.cu +++ /dev/null @@ -1,277 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -// ====================================================================== -// S-FOLD UPSAMPLE (zero-insertion upsample) -// forward: out[b,c,h*s, w*s] = x[b,c,h,w]; others = 0 -// backward: grad_x[b,c,h,w] = grad_out[b,c,h*s, w*s] -// dtypes: float/double/half/bfloat16 -// ====================================================================== - -using namespace at; -using namespace at::indexing; - -template -__global__ void sfold_upsample_kernel( - const scalar_t *__restrict__ x, - scalar_t *__restrict__ out, - int B, int C, int H, int W, int s, - int Hs, int Ws, long long total_in) -{ - long long idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_in) - return; - - int w = static_cast(idx % W); - int h = static_cast((idx / W) % H); - int c = static_cast((idx / (1LL * W * H)) % C); - int b = static_cast(idx / (1LL * W * H * C)); - - long long in_off = ((long long)b * C + c) * H * W + (long long)h * W + w; - long long out_off = ((long long)b * C + c) * Hs * Ws + (long long)(h * s) * Ws + (w * s); - - out[out_off] = x[in_off]; -} - -template -__global__ void sfold_downsample_grad_kernel( // backward of zero-insertion upsample - const scalar_t *__restrict__ grad_out, // (B,C,Hs,Ws) - scalar_t *__restrict__ grad_in, // (B,C,H,W) - int B, int C, int H, int W, int s, int Hs, int Ws, long long total_in) -{ - long long idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_in) - return; - - int w = static_cast(idx % W); - int h = static_cast((idx / W) % H); - int c = static_cast((idx / (1LL * W * H)) % C); - int b = static_cast(idx / (1LL * W * H * C)); - - long long in_off = ((long long)b * C + c) * H * W + (long long)h * W + w; - long long out_off = ((long long)b * C + c) * Hs 
* Ws + (long long)(h * s) * Ws + (w * s); - - grad_in[in_off] = grad_out[out_off]; -} - -struct SFoldFunction : public torch::autograd::Function -{ - static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &x, int64_t scale) - { - TORCH_CHECK(x.is_cuda() && x.dim() == 4, "sfold: x must be (B,C,H,W) CUDA"); - TORCH_CHECK(scale >= 1, "sfold: scale must be >= 1"); - if (scale == 1) - { - ctx->saved_data["s"] = (int64_t)1; - return x; - } - - auto x_ = x.contiguous(); - const int B = (int)x_.size(0), C = (int)x_.size(1), H = (int)x_.size(2), W = (int)x_.size(3); - const int s = (int)scale, Hs = H * s, Ws = W * s; - - auto out = at::zeros({B, C, Hs, Ws}, x_.options()); - - const long long total = 1LL * B * C * H * W; - const int threads = 256, blocks = (int)((total + threads - 1) / threads); - auto stream = at::cuda::getCurrentCUDAStream(); - - AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x_.scalar_type(), "sfold_fwd", [&] - { sfold_upsample_kernel<<>>( - x_.data_ptr(), out.data_ptr(), - B, C, H, W, s, Hs, Ws, total); }); - - // save for backward - ctx->saved_data["B"] = (int64_t)B; - ctx->saved_data["C"] = (int64_t)C; - ctx->saved_data["H"] = (int64_t)H; - ctx->saved_data["W"] = (int64_t)W; - ctx->saved_data["s"] = (int64_t)s; - return out; - } - - static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, - torch::autograd::tensor_list grad_outputs) - { - auto go = grad_outputs[0]; // (B,C,Hs,Ws) - const int B = (int)ctx->saved_data["B"].toInt(); - const int C = (int)ctx->saved_data["C"].toInt(); - const int H = (int)ctx->saved_data["H"].toInt(); - const int W = (int)ctx->saved_data["W"].toInt(); - const int s = (int)ctx->saved_data["s"].toInt(); - const int Hs = H * s, Ws = W * s; - - at::Tensor gx; - if (s == 1) - { - gx = go; // identity - } - else - { - gx = go.index({Slice(), Slice(), Slice(0, Hs, s), Slice(0, Ws, s)}).contiguous(); - } - return {gx, torch::Tensor()}; // no grad for scale - } -}; - -// exposed symbol for v4.cpp -at::Tensor sfold_upsample_cuda_launcher(const at::Tensor &x, int64_t scale) -{ - return SFoldFunction::apply(x, scale); -} - -// ====================================================================== -// BLOCK MEAN over non-overlapping s×s tiles -// forward: out[b,c,ho,wo] = mean_{i,j in s×s} in[b,c, ho*s+i, wo*s+j] -// backward: grad_in[b,c,hi,wi] = grad_out[b,c,hi/s, wi/s] / (s*s) -// dtypes: float/double/half/bfloat16 + complex64/complex128 -// ====================================================================== - -template -struct AccT -{ - using type = T; -}; -template <> -struct AccT -{ - using type = float; -}; -template <> -struct AccT -{ - using type = float; -}; - -template -__global__ void block_mean_kernel( - const scalar_t *__restrict__ in, // (B,C,Hs,Ws) - scalar_t *__restrict__ out, // (B,C,Ho,Wo) - int B, int C, int Ho, int Wo, int s, int Hs, int Ws, - long long total_out) -{ - using acc_t = typename AccT::type; - - long long idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_out) - return; - - int wo = static_cast(idx % Wo); - int ho = static_cast((idx / Wo) % Ho); - int c = static_cast((idx / (1LL * Wo * Ho)) % C); - int b = static_cast(idx / (1LL * Wo * Ho * C)); - - const int hi0 = ho * s; - const int wi0 = wo * s; - - const long long base_in = ((long long)b * C + c) * Hs * Ws; - - acc_t acc = acc_t(0); - for (int di = 0; di < s; ++di) - { - const int hi = hi0 + di; - const long long row_off = base_in + (long long)hi * Ws + wi0; -#pragma unroll - for (int dj = 0; dj 
< s; ++dj) - { - acc += static_cast(in[row_off + dj]); - } - } - const float inv_area = 1.0f / (s * s); - acc = acc * static_cast(inv_area); - - const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo; - out[out_off] = static_cast(acc); -} - -template -__global__ void block_mean_grad_kernel( - const scalar_t *__restrict__ grad_out, // (B,C,Ho,Wo) - scalar_t *__restrict__ grad_in, // (B,C,Hs,Ws) - int B, int C, int Ho, int Wo, int s, int Hs, int Ws, - long long total_in) -{ - using acc_t = typename AccT::type; - - long long idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_in) - return; - - int wi = static_cast(idx % Ws); - int hi = static_cast((idx / Ws) % Hs); - int c = static_cast((idx / (1LL * Ws * Hs)) % C); - int b = static_cast(idx / (1LL * Ws * Hs * C)); - - const int ho = hi / s; - const int wo = wi / s; - - const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo; - acc_t g = static_cast(grad_out[out_off]) * static_cast(1.0f / (s * s)); - - const long long in_off = ((long long)b * C + c) * Hs * Ws + (long long)hi * Ws + wi; - grad_in[in_off] = static_cast(g); -} - -struct BlockMeanFunction : public torch::autograd::Function -{ - static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &input, int64_t s) - { - TORCH_CHECK(input.is_cuda() && input.dim() == 4, "block_mean: input must be (B,C,Hs,Ws) CUDA"); - TORCH_CHECK(s >= 1, "block_mean: s must be >= 1"); - - auto x = input.contiguous(); - const int B = (int)x.size(0), C = (int)x.size(1), Hs = (int)x.size(2), Ws = (int)x.size(3); - TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "block_mean: H,W must be divisible by s"); - const int Ho = Hs / (int)s, Wo = Ws / (int)s; - - auto out = at::empty({B, C, Ho, Wo}, x.options()); - - const long long total_out = 1LL * B * C * Ho * Wo; - const int threads = 256, blocks = (int)((total_out + threads - 1) / threads); - auto stream = at::cuda::getCurrentCUDAStream(); - - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "block_mean_fwd", [&] - { block_mean_kernel<<>>( - x.data_ptr(), out.data_ptr(), - B, C, Ho, Wo, (int)s, Hs, Ws, total_out); }); - - // save for backward - ctx->saved_data["B"] = (int64_t)B; - ctx->saved_data["C"] = (int64_t)C; - ctx->saved_data["Hs"] = (int64_t)Hs; - ctx->saved_data["Ws"] = (int64_t)Ws; - ctx->saved_data["s"] = (int64_t)s; - return out; - } - - static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx, - torch::autograd::tensor_list grad_outputs) - { - auto go = grad_outputs[0]; // (B,C,Ho,Wo) - const int B = (int)ctx->saved_data["B"].toInt(); - const int C = (int)ctx->saved_data["C"].toInt(); - const int Hs = (int)ctx->saved_data["Hs"].toInt(); - const int Ws = (int)ctx->saved_data["Ws"].toInt(); - const int s = (int)ctx->saved_data["s"].toInt(); - const int Ho = Hs / s, Wo = Ws / s; - - auto go_scaled = go / static_cast(s * s); - auto gi = go_scaled.view({B, C, Ho, 1, Wo, 1}) - .expand({B, C, Ho, s, Wo, s}) - .reshape({B, C, Hs, Ws}) - .contiguous(); - - return {gi, torch::Tensor()}; // no grad for s - } -}; - -// exposed symbol for v4.cpp -at::Tensor block_mean_cuda(const at::Tensor &input, int64_t s) -{ - return BlockMeanFunction::apply(input, s); -} diff --git a/Converse2D/torch_converse2d/converse2d_v6.cpp b/Converse2D/torch_converse2d/converse2d_v6.cpp deleted file mode 100644 index ab8061e..0000000 --- a/Converse2D/torch_converse2d/converse2d_v6.cpp +++ /dev/null @@ -1,172 +0,0 @@ -#include -#include 
-#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -using at::Tensor; - -Tensor block_mean_cuda(const Tensor &input, int64_t s); -Tensor sfold_upsample_cuda_launcher(const Tensor &x, int64_t scale); -std::tuple fb_postprocess_cuda(const at::Tensor& FB); - -// ---------- FB Cache ---------- -struct FBKey -{ - int64_t device_id; - at::ScalarType dtype; - int64_t channels; - int64_t H, W; - void *ptr; - - bool operator==(const FBKey &other) const - { - return device_id == other.device_id && dtype == other.dtype && - channels == other.channels && H == other.H && W == other.W && - ptr == other.ptr; - } -}; - -namespace std -{ - template <> - struct hash - { - size_t operator()(const FBKey &k) const - { - return ((hash()(k.device_id) ^ hash()(k.channels)) << 1) ^ - ((hash()(k.H) ^ hash()(k.W)) << 1) ^ - ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); - } - }; -} // namespace std - -constexpr size_t FB_CACHE_MAX_SIZE = 64; -static std::unordered_map> fb_cache; -static std::list fb_cache_lru; -static std::mutex fb_cache_mutex; - -static inline std::tuple p2o_cached(const Tensor &psf, int64_t H, int64_t W) -{ - auto C = psf.size(1); - FBKey key{psf.device().index(), psf.scalar_type(), C, H, W, psf.data_ptr()}; - - { - std::lock_guard lock(fb_cache_mutex); - auto it = fb_cache.find(key); - if (it != fb_cache.end()) - { - fb_cache_lru.remove(key); - fb_cache_lru.push_front(key); - return it->second; - } - } - - Tensor otf = at::zeros({1, C, H, W}, psf.options()); - int64_t kh = psf.size(2), kw = psf.size(3); - otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); - otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1}); - Tensor FB = at::fft_fftn(otf, c10::nullopt, {-2, -1}, c10::nullopt); - auto [FBC, F2B] = fb_postprocess_cuda(FB); - - { - std::lock_guard lock(fb_cache_mutex); - fb_cache[key] = std::make_tuple(FB, FBC, F2B); - fb_cache_lru.push_front(key); - - if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) - { - fb_cache.erase(fb_cache_lru.back()); - fb_cache_lru.pop_back(); - } - } - - return std::make_tuple(FB, FBC, F2B); -} - -static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) -{ - if (s == 1) - return x; - return sfold_upsample_cuda_launcher(x, s); -} - -// ---------- Forward ---------- -Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) -{ - TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); - TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); - TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); - TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); - TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); - - x = x.contiguous(); - x0 = x0.contiguous(); - weight = weight.contiguous(); - bias = bias.contiguous(); - - const int64_t B = x.size(0); - const int64_t C = x.size(1); - const int64_t H = x.size(2); - const int64_t W = x.size(3); - const int64_t Hs = H * scale; - const int64_t Ws = W * scale; - - Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; - Tensor STy = sfold_upsample_zero_insertion(x, scale); - - auto [FB, FBC, F2B] = p2o_cached(weight, Hs, Ws); - - Tensor F_STy = at::fft_fftn(STy, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor FBFy = FBC * F_STy; - Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, 
diff --git a/Converse2D/torch_converse2d/converse2d_v6.cu b/Converse2D/torch_converse2d/converse2d_v6.cu
deleted file mode 100644
index 4830d49..0000000
--- a/Converse2D/torch_converse2d/converse2d_v6.cu
+++ /dev/null
@@ -1,323 +0,0 @@
-#include <torch/extension.h>
-#include <ATen/ATen.h>
-#include <ATen/Dispatch.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAStream.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <thrust/complex.h>
-#include <vector>
-
-// ======================================================================
-// S-FOLD UPSAMPLE (zero-insertion upsample)
-// forward:  out[b,c,h*s, w*s] = x[b,c,h,w]; others = 0
-// backward: grad_x[b,c,h,w] = grad_out[b,c,h*s, w*s]
-// dtypes: float/double/half/bfloat16
-// ======================================================================
-
-using namespace at;
-using namespace at::indexing;
-
-template <typename scalar_t>
-__global__ void sfold_upsample_kernel(
-    const scalar_t *__restrict__ x,
-    scalar_t *__restrict__ out,
-    int B, int C, int H, int W, int s,
-    int Hs, int Ws, long long total_in)
-{
-    long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= total_in)
-        return;
-
-    int w = static_cast<int>(idx % W);
-    int h = static_cast<int>((idx / W) % H);
-    int c = static_cast<int>((idx / (1LL * W * H)) % C);
-    int b = static_cast<int>(idx / (1LL * W * H * C));
-
-    long long in_off = ((long long)b * C + c) * H * W + (long long)h * W + w;
-    long long out_off = ((long long)b * C + c) * Hs * Ws + (long long)(h * s) * Ws + (w * s);
-
-    out[out_off] = x[in_off];
-}
-
-// NOTE: kept for reference; the autograd backward below uses strided slicing instead.
-template <typename scalar_t>
-__global__ void sfold_downsample_grad_kernel( // backward of zero-insertion upsample
-    const scalar_t *__restrict__ grad_out, // (B,C,Hs,Ws)
-    scalar_t *__restrict__ grad_in,        // (B,C,H,W)
-    int B, int C, int H, int W, int s, int Hs, int Ws, long long total_in)
-{
-    long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= total_in)
-        return;
-
-    int w = static_cast<int>(idx % W);
-    int h = static_cast<int>((idx / W) % H);
-    int c = static_cast<int>((idx / (1LL * W * H)) % C);
-    int b = static_cast<int>(idx / (1LL * W * H * C));
-
-    long long in_off = ((long long)b * C + c) * H * W + (long long)h * W + w;
-    long long out_off = ((long long)b * C + c) * Hs * Ws + (long long)(h * s) * Ws + (w * s);
-
-    grad_in[in_off] = grad_out[out_off];
-}
-
-struct SFoldFunction : public torch::autograd::Function<SFoldFunction>
-{
-    static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &x, int64_t scale)
-    {
-        TORCH_CHECK(x.is_cuda() && x.dim() == 4, "sfold: x must be (B,C,H,W) CUDA");
-        TORCH_CHECK(scale >= 1, "sfold: scale must be >= 1");
-        if (scale == 1)
-        {
-            ctx->saved_data["s"] = (int64_t)1;
-            return x;
-        }
-
-        auto x_ = x.contiguous();
-        const int B = (int)x_.size(0), C = (int)x_.size(1), H = (int)x_.size(2), W = (int)x_.size(3);
-        const int s = (int)scale, Hs = H * s, Ws = W * s;
-
-        auto out = at::zeros({B, C, Hs, Ws}, x_.options());
-
-        const long long total = 1LL * B * C * H * W;
-        const int threads = 256, blocks = (int)((total + threads - 1) / threads);
-        auto stream = at::cuda::getCurrentCUDAStream();
-
-        AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x_.scalar_type(), "sfold_fwd", [&]
-                                        { sfold_upsample_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
-                                              x_.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(),
-                                              B, C, H, W, s, Hs, Ws, total); });
-
-        // save for backward
-        ctx->saved_data["H"] = (int64_t)H;
-        ctx->saved_data["W"] = (int64_t)W;
-        ctx->saved_data["s"] = (int64_t)s;
-        return out;
-    }
-
-    static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx,
-                                                 torch::autograd::tensor_list grad_outputs)
-    {
-        auto go = grad_outputs[0]; // (B,C,Hs,Ws)
-        const int s = (int)ctx->saved_data["s"].toInt();
-
-        at::Tensor gx;
-        if (s == 1)
-        {
-            gx = go; // identity; sizes are not saved on this path
-        }
-        else
-        {
-            const int H = (int)ctx->saved_data["H"].toInt();
-            const int W = (int)ctx->saved_data["W"].toInt();
-            const int Hs = H * s, Ws = W * s;
-            gx = go.index({Slice(), Slice(), Slice(0, Hs, s), Slice(0, Ws, s)}).contiguous();
-        }
-        return {gx, torch::Tensor()}; // no grad for scale
-    }
-};
-
-// exposed symbol for v4.cpp
-at::Tensor sfold_upsample_cuda_launcher(const at::Tensor &x, int64_t scale)
-{
-    return SFoldFunction::apply(x, scale);
-}
-
-// ======================================================================
-// BLOCK MEAN over non-overlapping s×s tiles
-// forward:  out[b,c,ho,wo] = mean_{i,j in s×s} in[b,c, ho*s+i, wo*s+j]
-// backward: grad_in[b,c,hi,wi] = grad_out[b,c,hi/s, wi/s] / (s*s)
-// dtypes: float/double/half/bfloat16 + complex64/complex128
-// ======================================================================
-
-// Accumulator type: half/bfloat16 accumulate in float; other dtypes in-type.
-template <typename T>
-struct AccT
-{
-    using type = T;
-};
-template <>
-struct AccT<at::Half>
-{
-    using type = float;
-};
-template <>
-struct AccT<at::BFloat16>
-{
-    using type = float;
-};
-
-template <typename scalar_t>
-__global__ void block_mean_kernel(
-    const scalar_t *__restrict__ in, // (B,C,Hs,Ws)
-    scalar_t *__restrict__ out,      // (B,C,Ho,Wo)
-    int B, int C, int Ho, int Wo, int s, int Hs, int Ws,
-    long long total_out)
-{
-    using acc_t = typename AccT<scalar_t>::type;
-
-    long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= total_out)
-        return;
-
-    int wo = static_cast<int>(idx % Wo);
-    int ho = static_cast<int>((idx / Wo) % Ho);
-    int c = static_cast<int>((idx / (1LL * Wo * Ho)) % C);
-    int b = static_cast<int>(idx / (1LL * Wo * Ho * C));
-
-    const int hi0 = ho * s;
-    const int wi0 = wo * s;
-
-    const long long base_in = ((long long)b * C + c) * Hs * Ws;
-
-    acc_t acc = acc_t(0);
-    for (int di = 0; di < s; ++di)
-    {
-        const int hi = hi0 + di;
-        const long long row_off = base_in + (long long)hi * Ws + wi0;
-#pragma unroll
-        for (int dj = 0; dj < s; ++dj)
-        {
-            acc += static_cast<acc_t>(in[row_off + dj]);
-        }
-    }
-    const float inv_area = 1.0f / (s * s);
-    acc = acc * static_cast<acc_t>(inv_area);
-
-    const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo;
-    out[out_off] = static_cast<scalar_t>(acc);
-}
-
-// NOTE: kept for reference; the autograd backward below uses view/expand/reshape instead.
-template <typename scalar_t>
-__global__ void block_mean_grad_kernel(
-    const scalar_t *__restrict__ grad_out, // (B,C,Ho,Wo)
-    scalar_t *__restrict__ grad_in,        // (B,C,Hs,Ws)
-    int B, int C, int Ho, int Wo, int s, int Hs, int Ws,
-    long long total_in)
-{
-    using acc_t = typename AccT<scalar_t>::type;
-
-    long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= total_in)
-        return;
-
-    int wi = static_cast<int>(idx % Ws);
-    int hi = static_cast<int>((idx / Ws) % Hs);
-    int c = static_cast<int>((idx / (1LL * Ws * Hs)) % C);
-    int b = static_cast<int>(idx / (1LL * Ws * Hs * C));
-
-    const int ho = hi / s;
-    const int wo = wi / s;
-
-    const long long out_off = ((long long)b * C + c) * Ho * Wo + (long long)ho * Wo + wo;
-    acc_t g = static_cast<acc_t>(grad_out[out_off]) * static_cast<acc_t>(1.0f / (s * s));
-
-    const long long in_off = ((long long)b * C + c) * Hs * Ws + (long long)hi * Ws + wi;
-    grad_in[in_off] = static_cast<scalar_t>(g);
-}
-
-struct BlockMeanFunction : public torch::autograd::Function<BlockMeanFunction>
-{
-    static at::Tensor forward(torch::autograd::AutogradContext *ctx, const at::Tensor &input, int64_t s)
-    {
-        TORCH_CHECK(input.is_cuda() && input.dim() == 4, "block_mean: input must be (B,C,Hs,Ws) CUDA");
-        TORCH_CHECK(s >= 1, "block_mean: s must be >= 1");
-
-        auto x = input.contiguous();
-        const int B = (int)x.size(0), C = (int)x.size(1), Hs = (int)x.size(2), Ws = (int)x.size(3);
-        TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "block_mean: H,W must be divisible by s");
-        const int Ho = Hs / (int)s, Wo = Ws / (int)s;
-
-        auto out = at::empty({B, C, Ho, Wo}, x.options());
-
-        const long long total_out = 1LL * B * C * Ho * Wo;
-        const int threads = 256, blocks = (int)((total_out + threads - 1) / threads);
-        auto stream = at::cuda::getCurrentCUDAStream();
-
-        AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "block_mean_fwd", [&]
-                                                    { block_mean_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
-                                                          x.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(),
-                                                          B, C, Ho, Wo, (int)s, Hs, Ws, total_out); });
-
-        // save for backward
-        ctx->saved_data["B"] = (int64_t)B;
-        ctx->saved_data["C"] = (int64_t)C;
-        ctx->saved_data["Hs"] = (int64_t)Hs;
-        ctx->saved_data["Ws"] = (int64_t)Ws;
-        ctx->saved_data["s"] = (int64_t)s;
-        return out;
-    }
-
-    static torch::autograd::tensor_list backward(torch::autograd::AutogradContext *ctx,
-                                                 torch::autograd::tensor_list grad_outputs)
-    {
-        auto go = grad_outputs[0]; // (B,C,Ho,Wo)
-        const int B = (int)ctx->saved_data["B"].toInt();
-        const int C = (int)ctx->saved_data["C"].toInt();
-        const int Hs = (int)ctx->saved_data["Hs"].toInt();
-        const int Ws = (int)ctx->saved_data["Ws"].toInt();
-        const int s = (int)ctx->saved_data["s"].toInt();
-        const int Ho = Hs / s, Wo = Ws / s;
-
-        auto go_scaled = go / static_cast<double>(s * s);
-        auto gi = go_scaled.view({B, C, Ho, 1, Wo, 1})
-                      .expand({B, C, Ho, s, Wo, s})
-                      .reshape({B, C, Hs, Ws})
-                      .contiguous();
-
-        return {gi, torch::Tensor()}; // no grad for s
-    }
-};
-
-at::Tensor block_mean_cuda(const at::Tensor &input, int64_t s)
-{
-    return BlockMeanFunction::apply(input, s);
-}
-
-// Fused post-processing of the PSF spectrum FB:
-// a single pass computes FBC = conj(FB) and F2B = |FB|^2.
-template <typename real_t>
-__global__ void fb_postprocess_kernel(
-    const thrust::complex<real_t> *__restrict__ FB,
-    thrust::complex<real_t> *__restrict__ FBC,
-    real_t *__restrict__ F2B,
-    int64_t N)
-{
-    int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= N)
-        return;
-
-    thrust::complex<real_t> val = FB[idx];
-    real_t re = val.real();
-    real_t im = val.imag();
-    F2B[idx] = re * re + im * im;
-    FBC[idx] = thrust::complex<real_t>(re, -im);
-}
-
-std::tuple<at::Tensor, at::Tensor> fb_postprocess_cuda(const at::Tensor &FB)
-{
-    TORCH_CHECK(FB.is_cuda(), "FB must be CUDA tensor");
-    TORCH_CHECK(FB.is_complex(), "FB must be complex");
-
-    auto FBc = FB.contiguous();
-    const auto N = FBc.numel();
-
-    at::Tensor FBC = at::empty_like(FBc);
-    at::Tensor F2B = at::empty(FBc.sizes(),
-                               FBc.scalar_type() == at::kComplexFloat ? FBc.options().dtype(at::kFloat)
-                                                                      : FBc.options().dtype(at::kDouble));
-
-    constexpr int threads = 256;
-    const int blocks = (static_cast<int>(N) + threads - 1) / threads;
-
-    AT_DISPATCH_COMPLEX_TYPES(FB.scalar_type(), "fb_postprocess_cuda", [&]
-                              {
-        using real_t = typename scalar_t::value_type;
-        fb_postprocess_kernel<real_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-            reinterpret_cast<thrust::complex<real_t> *>(FBc.data_ptr()),
-            reinterpret_cast<thrust::complex<real_t> *>(FBC.data_ptr()),
-            F2B.data_ptr<real_t>(),
-            N); });
-
-    return {FBC, F2B};
-}
\ No newline at end of file
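For intuition, the fused primitives in the file above have compact PyTorch equivalents. These are reference sketches for checking semantics, not part of the extension. The zero-insertion upsample:

    import torch

    def sfold_upsample_ref(x: torch.Tensor, s: int) -> torch.Tensor:
        # out[b, c, h*s, w*s] = x[b, c, h, w]; all other entries stay zero.
        if s == 1:
            return x
        B, C, H, W = x.shape
        out = x.new_zeros(B, C, H * s, W * s)
        out[:, :, ::s, ::s] = x
        return out

    # The adjoint used in SFoldFunction::backward is plain strided slicing:
    # grad_x = grad_out[:, :, ::s, ::s]

The dedicated kernel exists to write only the nonzero entries instead of materializing the zeros through generic indexing.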
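The block mean over non-overlapping s×s tiles is a reshape-and-reduce; its gradient spreads each output gradient uniformly over its tile and divides by s*s, which is exactly the view/expand/reshape expression in BlockMeanFunction::backward:

    import torch

    def block_mean_ref(x: torch.Tensor, s: int) -> torch.Tensor:
        # (B, C, Hs, Ws) -> (B, C, Hs//s, Ws//s), mean over s×s tiles.
        B, C, Hs, Ws = x.shape
        assert Hs % s == 0 and Ws % s == 0
        return x.view(B, C, Hs // s, s, Ws // s, s).mean(dim=(3, 5))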
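Finally, fb_postprocess_cuda fuses the conjugate and the squared magnitude of the PSF spectrum into one launch; since both outputs read the same input and the result is computed once per cached weight spectrum, fusion roughly halves the memory traffic relative to the unfused form:

    import torch

    def fb_postprocess_ref(FB: torch.Tensor):
        # FB is complex; returns conj(FB) and the real-valued |FB|^2.
        FBC = FB.conj()
        F2B = FB.real.square() + FB.imag.square()
        return FBC, F2B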
diff --git a/test/profiler.py b/test/profiler.py
deleted file mode 100644
index e25c5e2..0000000
--- a/test/profiler.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import os
-import sys
-import argparse
-import torch
-import pathlib
-from torch.profiler import profile, record_function, ProfilerActivity
-
-PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
-sys.path.insert(0, str(PROJECT_ROOT))
-from models.util_converse import Converse2D
-
-def main():
-    parser = argparse.ArgumentParser(description="Converse2D Profiler")
-    parser.add_argument("--backend", type=str, default="pytorch", choices=["pytorch", "cuda"],
-                        help="Which backend to profile: pytorch or cuda")
-    parser.add_argument("--H", type=int, default=256, help="Input height")
-    parser.add_argument("--W", type=int, default=256, help="Input width")
-    parser.add_argument("--B", type=int, default=2, help="Batch size")
-    parser.add_argument("--C", type=int, default=8, help="Channels")
-    parser.add_argument("--scale", type=int, default=2, help="Upsampling scale")
-    args = parser.parse_args()
-
-    os.environ["CONVERSE2D_BACKEND"] = args.backend.lower()
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    print(f"[INFO] Running Converse2D with backend: {args.backend.upper()} on {device}")
-
-    x = torch.randn(args.B, args.C, args.H, args.W, device=device)
-    model = Converse2D(args.C, args.C, kernel_size=5, scale=args.scale, padding=4, backend=args.backend).to(device)
-
-    # warmup
-    for _ in range(10):
-        _ = model(x)
-
-    trace_file = f"profile_converse2d_{args.backend.lower()}.json"
-
-    with profile(
-        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
-        record_shapes=True,
-        profile_memory=True,
-        with_stack=False,
-    ) as prof:
-        with record_function(f"Converse2D::forward({args.backend.upper()})"):
-            y = model(x)
-
-    print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=25))
-    prof.export_chrome_trace(trace_file)
-    print(f"[INFO] Saved trace to: {trace_file}")
-
-if __name__ == "__main__":
-    main()
diff --git a/test/results_2080ti.csv b/test/results_2080ti.csv
deleted file mode 100644
index f50f58f..0000000
--- a/test/results_2080ti.csv
+++ /dev/null
@@ -1,361 +0,0 @@
-variant,B,C,H,W,scale,ksize,mean_ms,p50_ms,p90_ms,tp,speedup_vs_pytorch,warmup,dtype,device,iters
-pytorch,1,3,128,128,1,3,1.3588724099099636,1.3492174912244081,1.3725919649004936,0.012057055453120557,,10,float32,cuda,50
-cuda_v1,1,3,128,128,1,3,0.718311658129096,0.712866079993546,0.7390974787995219,0.022809040915016645,1.8917588132277605,10,float32,cuda,50
-cuda_v2,1,3,128,128,1,3,0.4002129239961505,0.3978795139119029,0.4132815869525075,0.0409382081827962,3.3953736334687536,10,float32,cuda,50
-cuda_v3,1,3,128,128,1,3,0.3714838158339262,0.36968302447348833,0.3826369298622012,0.044104209393941815,3.657958575825157,10,float32,cuda,50
-cuda_v4,1,3,128,128,1,3,0.36731038708239794,0.36497542168945074,0.3836334450170398,0.044605327200628854,3.6995207805139763,10,float32,cuda,50
-pytorch,1,3,128,128,1,5,2.02513235155493,2.008417039178312,2.028485178016126,0.008090335422976228,,10,float32,cuda,50 -cuda_v1,1,3,128,128,1,5,0.5582067556679249,0.5469518946483731,0.6023913389071822,0.029351131697422133,3.6279251925780596,10,float32,cuda,50 -cuda_v2,1,3,128,128,1,5,0.5319257266819477,0.5326560931280255,0.5428183358162642,0.030801292695881242,3.8071712834559133,10,float32,cuda,50 -cuda_v3,1,3,128,128,1,5,0.3715480910614133,0.37001201417297125,0.3865184495225549,0.04409657967342882,5.450525518162857,10,float32,cuda,50 -cuda_v4,1,3,128,128,1,5,0.3673481987789273,0.3656859043985605,0.37831307854503393,0.04460073590794984,5.512841381246758,10,float32,cuda,50 -pytorch,1,3,128,128,1,7,1.6548363445326686,1.6421884065493941,1.6600201604887843,0.009900676918373397,,10,float32,cuda,50 -cuda_v1,1,3,128,128,1,7,0.53962841629982,0.5401384551078081,0.5467256065458059,0.03036163312588968,3.0666219467827918,10,float32,cuda,50 -cuda_v2,1,3,128,128,1,7,0.40642036590725183,0.4043428925797343,0.41671814396977425,0.04031294042911952,4.071735777410118,10,float32,cuda,50 -cuda_v3,1,3,128,128,1,7,0.3748922608792782,0.35778700839728117,0.4316947190091014,0.0437032227914567,4.414165127472596,10,float32,cuda,50 -cuda_v4,1,3,128,128,1,7,0.37697022780776024,0.3745625726878643,0.387030653655529,0.04346231821881484,4.389832995980173,10,float32,cuda,50 -pytorch,1,3,128,128,2,3,1.8107919162139297,1.7898115329444408,1.8179329112172127,0.03619190002627418,,10,float32,cuda,50 -cuda_v1,1,3,128,128,2,3,0.6132322456687689,0.6145505467429757,0.6199152441695333,0.10686978785423916,2.9528648061211222,10,float32,cuda,50 -cuda_v2,1,3,128,128,2,3,0.48999007791280746,0.4895030288025737,0.5026075756177306,0.13374964709318457,3.6955685387085646,10,float32,cuda,50 -cuda_v3,1,3,128,128,2,3,0.4332400672137737,0.42917008977383375,0.45607127249240875,0.1512694807326363,4.179650159920297,10,float32,cuda,50 -cuda_v4,1,3,128,128,2,3,0.42093018535524607,0.4206504672765732,0.4302357789129019,0.15569327712786998,4.301881830322297,10,float32,cuda,50 -pytorch,1,3,128,128,2,5,2.1734800469130278,2.159254509024322,2.198834274895489,0.030152565740403335,,10,float32,cuda,50 -cuda_v1,1,3,128,128,2,5,0.6376416468992829,0.6374488584697247,0.6438577082008123,0.10277873209613546,3.408623099639435,10,float32,cuda,50 -cuda_v2,1,3,128,128,2,5,0.478832246735692,0.4765290068462491,0.4935482516884804,0.13686630432009075,4.539126305152032,10,float32,cuda,50 -cuda_v3,1,3,128,128,2,5,0.4307904839515686,0.43061794713139534,0.44139688834547997,0.15212963712394317,5.0453297551423635,10,float32,cuda,50 -cuda_v4,1,3,128,128,2,5,0.4390825191512704,0.4367530345916748,0.46043451875448227,0.14925668215322843,4.950049141364774,10,float32,cuda,50 -pytorch,1,3,128,128,2,7,2.1709761628881097,2.145085483789444,2.219454082660377,0.0301873420447029,,10,float32,cuda,50 -cuda_v1,1,3,128,128,2,7,0.8158506592735648,0.8078685496002436,0.8693302515894175,0.08032842684512065,2.660997007492954,10,float32,cuda,50 -cuda_v2,1,3,128,128,2,7,0.48865099903196096,0.48417190555483103,0.5218630190938711,0.134116169065099,4.44279489285586,10,float32,cuda,50 -cuda_v3,1,3,128,128,2,7,0.4235974932089448,0.42357854545116425,0.4345803987234831,0.15471290800975904,5.1250920925947225,10,float32,cuda,50 -cuda_v4,1,3,128,128,2,7,0.42114653158932924,0.41657453402876854,0.45004035346210003,0.1556132962859251,5.154918775408755,10,float32,cuda,50 -pytorch,1,3,128,128,3,3,2.2250490309670568,2.2027044324204326,2.2900789277628064,0.06627089917920248,,10,float32,cuda,50 
-cuda_v1,1,3,128,128,3,3,0.6517368508502841,0.6480665178969502,0.67496825940907,0.22625082471187952,3.414029800622999,10,float32,cuda,50 -cuda_v2,1,3,128,128,3,3,0.48238967079669237,0.47929398715496063,0.4991346038877964,0.30567818700692434,4.612555296410616,10,float32,cuda,50 -cuda_v3,1,3,128,128,3,3,0.4496808862313628,0.44654798693954945,0.46137943863868713,0.3279125364562043,4.948062279485589,10,float32,cuda,50 -cuda_v4,1,3,128,128,3,3,0.426276377402246,0.423771096393466,0.4365326603874564,0.34591642374978826,5.219733367649035,10,float32,cuda,50 -pytorch,1,3,128,128,3,5,2.170854858122766,2.143112593330443,2.2381292656064034,0.06792531497361907,,10,float32,cuda,50 -cuda_v1,1,3,128,128,3,5,0.627955780364573,0.6281071109697223,0.6354789482429624,0.23481908218822556,3.4570186723377723,10,float32,cuda,50 -cuda_v2,1,3,128,128,3,5,0.5390440486371517,0.5394599866122007,0.5951238563284278,0.27355092848684337,4.0272309166779054,10,float32,cuda,50 -cuda_v3,1,3,128,128,3,5,0.44883018359541893,0.4477289039641619,0.45829941518604755,0.32853405450315853,4.836695341504932,10,float32,cuda,50 -cuda_v4,1,3,128,128,3,5,0.4391154181212187,0.4359594313427806,0.4516175948083401,0.3358023743071906,4.943699921562529,10,float32,cuda,50 -pytorch,1,3,128,128,3,7,2.302859895862639,2.2876025177538395,2.3640060564503074,0.06403168523839518,,10,float32,cuda,50 -cuda_v1,1,3,128,128,3,7,1.1249064421281219,1.1208199430257082,1.1454877443611622,0.13108290118868873,2.047156820887291,10,float32,cuda,50 -cuda_v2,1,3,128,128,3,7,0.4788805032148957,0.47789840027689934,0.4864738555625081,0.3079181528796333,4.8088403691583155,10,float32,cuda,50 -cuda_v3,1,3,128,128,3,7,0.4445829102769494,0.4424919607117772,0.45168143697082996,0.33167266800278816,5.179821002179527,10,float32,cuda,50 -cuda_v4,1,3,128,128,3,7,0.42853476013988256,0.4264645976945758,0.43528247624635696,0.34409344052246155,5.3737995375472885,10,float32,cuda,50 -pytorch,1,3,128,256,1,3,1.7951560160145164,1.7785559175536036,1.8055073218420148,0.018253566658094317,,10,float32,cuda,50 -cuda_v1,1,3,128,256,1,3,0.5566090485081077,0.5568780470639467,0.5647305399179459,0.058870764116804856,3.225164989369316,10,float32,cuda,50 -cuda_v2,1,3,128,256,1,3,0.4132459592074156,0.4112694878131151,0.42841359972953796,0.07929418127365924,4.344037675425872,10,float32,cuda,50 -cuda_v3,1,3,128,256,1,3,0.374758830294013,0.37343648727983236,0.3899722592905164,0.08743756611229739,4.790163355473561,10,float32,cuda,50 -cuda_v4,1,3,128,256,1,3,0.3926960425451398,0.37166697438806295,0.43946001678705215,0.0834436725861157,4.57136263553806,10,float32,cuda,50 -pytorch,1,3,128,256,1,5,1.4699688693508506,1.4445290435105562,1.4817556831985712,0.022291628539365324,,10,float32,cuda,50 -cuda_v1,1,3,128,256,1,5,0.5615295935422182,0.5593735259026289,0.5792464362457395,0.05835489416202311,2.617794122083669,10,float32,cuda,50 -cuda_v2,1,3,128,256,1,5,0.4279289720579982,0.42431592009961605,0.447837240062654,0.0765734552685507,3.4350767658508112,10,float32,cuda,50 -cuda_v3,1,3,128,256,1,5,0.37391155026853085,0.37109002005308867,0.3892844310030341,0.08763569880755787,3.931327792081223,10,float32,cuda,50 -cuda_v4,1,3,128,256,1,5,0.37634771782904863,0.37289841566234827,0.39316134061664343,0.08706841691248005,3.9058795887758406,10,float32,cuda,50 -pytorch,1,3,128,256,1,7,1.4711981965228915,1.4571724459528923,1.4904611511155963,0.022273001746090804,,10,float32,cuda,50 -cuda_v1,1,3,128,256,1,7,0.5647571152076125,0.5636163987219334,0.5756448954343796,0.058021402683796255,2.6050104671670384,10,float32,cuda,50 
-cuda_v2,1,3,128,256,1,7,0.41376128792762756,0.4112150054425001,0.43163460213690996,0.0791954224720307,3.5556690281286634,10,float32,cuda,50 -cuda_v3,1,3,128,256,1,7,0.40381398517638445,0.4036114551126957,0.4396184580400586,0.08114627329136968,3.6432571692143787,10,float32,cuda,50 -cuda_v4,1,3,128,256,1,7,0.364238740876317,0.36335503682494164,0.3761310363188386,0.08996297296977228,4.039104113371784,10,float32,cuda,50 -pytorch,1,3,128,256,2,3,2.44969944935292,2.412254922091961,2.47289901599288,0.05350533920992726,,10,float32,cuda,50 -cuda_v1,1,3,128,256,2,3,0.6509987637400627,0.6330730393528938,0.6970161804929376,0.20133986007435145,3.7629863308481806,10,float32,cuda,50 -cuda_v2,1,3,128,256,2,3,0.48265908844769,0.4774669650942087,0.49854174721986055,0.27156227477565753,5.075423850883138,10,float32,cuda,50 -cuda_v3,1,3,128,256,2,3,0.4738498851656914,0.4671409260481596,0.5251517286524177,0.2766108088306658,5.169779556866057,10,float32,cuda,50 -cuda_v4,1,3,128,256,2,3,0.433603972196579,0.42632746044546366,0.46148020774126053,0.3022850536539299,5.649624095792007,10,float32,cuda,50 -pytorch,1,3,128,256,2,5,2.4055594438686967,2.3875508923083544,2.4414390325546265,0.05448711747035686,,10,float32,cuda,50 -cuda_v1,1,3,128,256,2,5,0.6060707569122314,0.6057386053726077,0.6157765863463283,0.21626517779504295,3.969106604193179,10,float32,cuda,50 -cuda_v2,1,3,128,256,2,5,0.4737306758761406,0.4720538854598999,0.48464827705174685,0.2766804150007577,5.077905160816824,10,float32,cuda,50 -cuda_v3,1,3,128,256,2,5,0.578413438051939,0.5772160366177559,0.5907158832997084,0.22660607685990572,4.158892732455306,10,float32,cuda,50 -cuda_v4,1,3,128,256,2,5,0.44161519035696983,0.43730391189455986,0.46488025691360235,0.29680138469433276,5.447184554327074,10,float32,cuda,50 -pytorch,1,3,128,256,2,7,2.1301160426810384,2.1134009584784508,2.1731291199103,0.06153279791979229,,10,float32,cuda,50 -cuda_v1,1,3,128,256,2,7,0.6413558637723327,0.6351619958877563,0.6695011164993048,0.20436704083293092,3.3212700826528705,10,float32,cuda,50 -cuda_v2,1,3,128,256,2,7,0.4867607494816184,0.4816199652850628,0.5202332977205515,0.26927397112356877,4.376104780324894,10,float32,cuda,50 -cuda_v3,1,3,128,256,2,7,0.43859776109457016,0.43492461554706097,0.4642078885808587,0.2988432947603176,4.856650515873802,10,float32,cuda,50 -cuda_v4,1,3,128,256,2,7,0.44836338609457016,0.447872094810009,0.47701222356408834,0.2923343075394517,4.75086973812743,10,float32,cuda,50 -pytorch,1,3,128,256,3,3,3.6672434210777283,3.676302614621818,3.76545328181237,0.0804178959882983,,10,float32,cuda,50 -cuda_v1,1,3,128,256,3,3,1.0867606522515416,1.0846543591469526,1.094102906063199,0.2713679404834945,3.3744720270097783,10,float32,cuda,50 -cuda_v2,1,3,128,256,3,3,0.8823559479787946,0.8789199637249112,0.9115118300542235,0.334232461032934,4.1561950474502405,10,float32,cuda,50 -cuda_v3,1,3,128,256,3,3,0.8446352742612362,0.8448135340586305,0.8729633176699281,0.3491589908531183,4.3418070886102855,10,float32,cuda,50 -cuda_v4,1,3,128,256,3,3,0.7847000611945987,0.7841064361855388,0.7894640555605292,0.375827675546548,4.673433331322609,10,float32,cuda,50 -pytorch,1,3,128,256,3,5,3.0747908260673285,3.057163907214999,3.187051648274064,0.09591286584433902,,10,float32,cuda,50 -cuda_v1,1,3,128,256,3,5,1.089889518916607,1.0872494895011187,1.096566766500473,0.2705888944534067,2.821194967654879,10,float32,cuda,50 -cuda_v2,1,3,128,256,3,5,0.871727392077446,0.8700459729880095,0.8790818741545081,0.3383075978571515,3.527238967149673,10,float32,cuda,50 
-cuda_v3,1,3,128,256,3,5,0.8495373092591763,0.8468780433759093,0.8814086206257343,0.3471442593347345,3.619371147747054,10,float32,cuda,50 -cuda_v4,1,3,128,256,3,5,0.7938195066526532,0.7923004450276494,0.8032999699935317,0.3715101449743573,3.873413036967289,10,float32,cuda,50 -pytorch,1,3,128,256,3,7,3.572395290248096,3.608741913922131,3.6893750075250864,0.08255301444525165,,10,float32,cuda,50 -cuda_v1,1,3,128,256,3,7,1.090473192743957,1.088742632418871,1.1037262855097651,0.2704440622312898,3.2760046867900345,10,float32,cuda,50 -cuda_v2,1,3,128,256,3,7,0.8857840858399868,0.8821134688332677,0.9158464381471276,0.3329389235079062,4.033031691758611,10,float32,cuda,50 -cuda_v3,1,3,128,256,3,7,0.7986934995278716,0.7949564605951309,0.8138878736644983,0.36924302022531813,4.472798754916412,10,float32,cuda,50 -cuda_v4,1,3,128,256,3,7,0.7899459125474095,0.7872970309108496,0.8032441604882479,0.37333188933020844,4.522328976585083,10,float32,cuda,50 -pytorch,1,3,256,128,1,3,1.4993645157665014,1.4857796486467123,1.5059428755193949,0.021854592165833953,,10,float32,cuda,50 -cuda_v1,1,3,256,128,1,3,0.7230725651606917,0.7242871215566993,0.7289966102689505,0.04531771993412281,2.073601721333862,10,float32,cuda,50 -cuda_v2,1,3,256,128,1,3,0.405933721922338,0.4037924809381366,0.4139351425692439,0.08072253727732694,3.6936190190509848,10,float32,cuda,50 -cuda_v3,1,3,256,128,1,3,0.38566675037145615,0.38364401552826166,0.39766388945281506,0.08496454508572336,3.88772045897756,10,float32,cuda,50 -cuda_v4,1,3,256,128,1,3,0.36679978482425213,0.3645513206720352,0.3863898105919361,0.0893348397565184,4.087691917499091,10,float32,cuda,50 -pytorch,1,3,256,128,1,5,1.4416432147845626,1.4279410243034363,1.4532520435750484,0.022729618302193316,,10,float32,cuda,50 -cuda_v1,1,3,256,128,1,5,0.5475908936932683,0.5470785545185208,0.5588179919868708,0.0598402938715685,2.6327012216388237,10,float32,cuda,50 -cuda_v2,1,3,256,128,1,5,0.437559150159359,0.43854652903974056,0.4863776499405503,0.07488816080766657,3.2947390410176918,10,float32,cuda,50 -cuda_v3,1,3,256,128,1,5,0.37053246051073074,0.3683719551190734,0.38405111990869045,0.08843489705283467,3.89073392597627,10,float32,cuda,50 -cuda_v4,1,3,256,128,1,5,0.3653422696515918,0.3640999784693122,0.37856195122003555,0.08969123674424304,3.946007167906916,10,float32,cuda,50 -pytorch,1,3,256,128,1,7,1.5084912767633796,1.5014259843155742,1.5185259049758315,0.021722366250805942,,10,float32,cuda,50 -cuda_v1,1,3,256,128,1,7,0.5360735580325127,0.5358878988772631,0.5449545569717884,0.06112593973160048,2.8139632223231015,10,float32,cuda,50 -cuda_v2,1,3,256,128,1,7,0.40473042987287045,0.4033439327031374,0.4160322714596987,0.08096253106121211,3.7271506302039374,10,float32,cuda,50 -cuda_v3,1,3,256,128,1,7,0.359728392213583,0.35842100623995066,0.3687401069328189,0.09109094725151559,4.193417337677747,10,float32,cuda,50 -cuda_v4,1,3,256,128,1,7,0.3663349011912942,0.3646225668489933,0.3803261090070009,0.08944820680050104,4.117792959005207,10,float32,cuda,50 -pytorch,1,3,256,128,2,3,2.133134347386658,2.1139864111319184,2.160443994216621,0.06144573132985211,,10,float32,cuda,50 -cuda_v1,1,3,256,128,2,3,0.6306317308917642,0.6240959046408534,0.6758735282346606,0.20784238023458415,3.382535706489482,10,float32,cuda,50 -cuda_v2,1,3,256,128,2,3,0.5141867883503437,0.4855890292674303,0.6295430706813931,0.2549112559280567,4.148559231228704,10,float32,cuda,50 -cuda_v3,1,3,256,128,2,3,0.423141997307539,0.4228444304317236,0.4300167551264167,0.30975890087491614,5.041178519172842,10,float32,cuda,50 
-cuda_v4,1,3,256,128,2,3,0.4352907044813037,0.43021049350500107,0.4722143057733774,0.3011137124009725,4.900482195981006,10,float32,cuda,50 -pytorch,1,3,256,128,2,5,2.415369115769863,2.3783150827512145,2.4357105838134885,0.05426582593287103,,10,float32,cuda,50 -cuda_v1,1,3,256,128,2,5,0.6090639485046268,0.6084414198994637,0.6245741387829185,0.21520236146271315,3.9657069207594295,10,float32,cuda,50 -cuda_v2,1,3,256,128,2,5,0.47608069609850645,0.4748969804495573,0.48580863513052464,0.2753146705466918,5.0734447659060224,10,float32,cuda,50 -cuda_v3,1,3,256,128,2,5,0.44825855642557144,0.4457270260900259,0.45704932417720556,0.29240267279039234,5.388339120685384,10,float32,cuda,50 -cuda_v4,1,3,256,128,2,5,0.43435935862362385,0.4245585296303034,0.47102225944399834,0.30175935523833164,5.560762230203958,10,float32,cuda,50 -pytorch,1,3,256,128,2,7,2.6145117403939366,2.644861117005348,2.7359860949218273,0.0501324962420138,,10,float32,cuda,50 -cuda_v1,1,3,256,128,2,7,0.6282500876113772,0.6161184282973409,0.6790379295125604,0.20863029322978538,4.161577995690182,10,float32,cuda,50 -cuda_v2,1,3,256,128,2,7,0.630432553589344,0.6254641339182854,0.6581001449376345,0.2079080454423657,4.1471712168229775,10,float32,cuda,50 -cuda_v3,1,3,256,128,2,7,0.4406739352270961,0.43863849714398384,0.47307766508311033,0.2974353360210733,5.932984756737607,10,float32,cuda,50 -cuda_v4,1,3,256,128,2,7,0.4157876316457987,0.4142334219068289,0.42685249354690313,0.3152378522689142,6.288094068707623,10,float32,cuda,50 -pytorch,1,3,256,128,3,3,3.8604717003181577,3.8671459769830108,3.9452574448660016,0.07639273718175296,,10,float32,cuda,50 -cuda_v1,1,3,256,128,3,3,1.1453364789485931,1.1441544629633427,1.1537153273820877,0.25748939758797007,3.370600492758277,10,float32,cuda,50 -cuda_v2,1,3,256,128,3,3,0.9098627092316747,0.9093486005440354,0.9173051686957479,0.32412802174191285,4.242916718257525,10,float32,cuda,50 -cuda_v3,1,3,256,128,3,3,0.8363525243476033,0.8355780737474561,0.8449909742921591,0.3526168588180517,4.615842707391262,10,float32,cuda,50 -cuda_v4,1,3,256,128,3,3,0.8323059324175119,0.8312843274325132,0.8499871473759413,0.35433124829880763,4.638284493665748,10,float32,cuda,50 -pytorch,1,3,256,128,3,5,3.06809326633811,3.044669982045889,3.175138426013291,0.09612224088350126,,10,float32,cuda,50 -cuda_v1,1,3,256,128,3,5,1.1442034784704447,1.1442825198173523,1.1537955841049552,0.25774436588344785,2.6814227749416517,10,float32,cuda,50 -cuda_v2,1,3,256,128,3,5,0.9060597093775868,0.9040400618687272,0.9134287713095546,0.3254884826548444,3.386193243760636,10,float32,cuda,50 -cuda_v3,1,3,256,128,3,5,0.8386831870302558,0.8376993937417865,0.8442188147455454,0.35163695249963434,3.658226746147264,10,float32,cuda,50 -cuda_v4,1,3,256,128,3,5,0.8441900322213769,0.8293610299006104,0.8354945341125131,0.34934314401222816,3.6343632940853605,10,float32,cuda,50 -pytorch,1,3,256,128,3,7,4.190151221118867,4.013016470707953,5.129964789375663,0.07038218537641504,,10,float32,cuda,50 -cuda_v1,1,3,256,128,3,7,1.148314750753343,1.1480144457891583,1.1549783172085881,0.25682157248831405,3.648957063705705,10,float32,cuda,50 -cuda_v2,1,3,256,128,3,7,0.908088181167841,0.9069349616765976,0.9139806730672717,0.32476141206983916,4.614255870757122,10,float32,cuda,50 -cuda_v3,1,3,256,128,3,7,0.8402302768081427,0.8394863689318299,0.8471378590911627,0.3509894943566047,4.9869081569357,10,float32,cuda,50 -cuda_v4,1,3,256,128,3,7,0.8234968921169639,0.8228260558098555,0.8296727202832699,0.35812157012744705,5.0882416937205965,10,float32,cuda,50 
-pytorch,1,3,256,256,1,3,1.7476121010258794,1.7350299749523401,1.7694034380838275,0.03750031254734915,,10,float32,cuda,50 -cuda_v1,1,3,256,256,1,3,0.549295274540782,0.5471804179251194,0.5695500643923879,0.11930923683038953,3.1815531318504533,10,float32,cuda,50 -cuda_v2,1,3,256,256,1,3,0.4180182237178087,0.41706988122314215,0.4277727100998163,0.15677785388667967,4.1807079257999975,10,float32,cuda,50 -cuda_v3,1,3,256,256,1,3,0.39928949903696775,0.39886811282485723,0.4483650205656886,0.16413153904138217,4.376804562205827,10,float32,cuda,50 -cuda_v4,1,3,256,256,1,3,0.37084896117448807,0.3665839321911335,0.39022082928568125,0.17671884476215285,4.7124632505134905,10,float32,cuda,50 -pytorch,1,3,256,256,1,5,1.6822041990235448,1.661254558712244,1.7334380885586143,0.03895840947135974,,10,float32,cuda,50 -cuda_v1,1,3,256,256,1,5,0.5354204308241606,0.5324805388227105,0.5574033362790942,0.12240100718443245,3.1418378944452416,10,float32,cuda,50 -cuda_v2,1,3,256,256,1,5,0.3861092450097203,0.3840719582512975,0.3941373433917761,0.16973434551754937,4.356808910341412,10,float32,cuda,50 -cuda_v3,1,3,256,256,1,5,0.36452691070735455,0.3633134765550494,0.3785670269280672,0.17978370889773043,4.614759979611034,10,float32,cuda,50 -cuda_v4,1,3,256,256,1,5,0.3729759994894266,0.36907452158629894,0.3943886375054717,0.17571103794805398,4.510221036544827,10,float32,cuda,50 -pytorch,1,3,256,256,1,7,1.6930993692949414,1.6854360001161695,1.7194566549733281,0.03870771036155498,,10,float32,cuda,50 -cuda_v1,1,3,256,256,1,7,0.5525457859039307,0.5440320819616318,0.5869314773008227,0.11860736552860893,3.064179317782934,10,float32,cuda,50 -cuda_v2,1,3,256,256,1,7,0.40626043919473886,0.40068302769213915,0.4402776714414358,0.16131523938166584,4.167522126079726,10,float32,cuda,50 -cuda_v3,1,3,256,256,1,7,0.3619052888825536,0.3611855208873749,0.3710400080308318,0.18108605210593623,4.678294076670403,10,float32,cuda,50 -cuda_v4,1,3,256,256,1,7,0.38414323702454567,0.3804925363510847,0.39720607455819845,0.17060302950436282,4.407468897302902,10,float32,cuda,50 -pytorch,1,3,256,256,2,3,3.5527321184054017,3.5565514117479324,3.642054391093552,0.07378659332121555,,10,float32,cuda,50 -cuda_v1,1,3,256,256,2,3,0.9840514045208693,0.9831475326791406,0.990622048266232,0.26639258761856743,3.6103115163330446,10,float32,cuda,50 -cuda_v2,1,3,256,256,2,3,0.7950076553970575,0.7889129919931293,0.8149547036737204,0.3297377053169069,4.468802399935419,10,float32,cuda,50 -cuda_v3,1,3,256,256,2,3,0.7262257812544703,0.7252609357237816,0.7328283973038197,0.36096763123332914,4.892049015759908,10,float32,cuda,50 -cuda_v4,1,3,256,256,2,3,0.7457686774432659,0.744950957596302,0.7515277713537216,0.35150846090601934,4.763852687652834,10,float32,cuda,50 -pytorch,1,3,256,256,2,5,3.0734648229554296,2.9988345922902226,3.703110688365996,0.08529266319954933,,10,float32,cuda,50 -cuda_v1,1,3,256,256,2,5,0.9879587311297655,0.985164544545114,0.9992359671741724,0.265339018463078,3.1109242988731065,10,float32,cuda,50 -cuda_v2,1,3,256,256,2,5,0.7879760023206472,0.7871531415730715,0.7931290892884135,0.33268018217301876,3.900454853832921,10,float32,cuda,50 -cuda_v3,1,3,256,256,2,5,0.7290018862113357,0.7279976271092892,0.736223254352808,0.35959303392529934,4.215990220448401,10,float32,cuda,50 -cuda_v4,1,3,256,256,2,5,0.7129816291853786,0.7123425602912903,0.7170673925429583,0.36767286739142807,4.310720917826502,10,float32,cuda,50 -pytorch,1,3,256,256,2,7,3.3563410444185138,3.334320499561727,3.463956923224032,0.07810410102273038,,10,float32,cuda,50 
-cuda_v1,1,3,256,256,2,7,0.9848026884719729,0.9839965496212244,0.9889016160741448,0.2661893626699421,3.408135541979721,10,float32,cuda,50 -cuda_v2,1,3,256,256,2,7,0.790087953209877,0.7883070502430201,0.7980464026331902,0.3317909087652735,4.248060017600273,10,float32,cuda,50 -cuda_v3,1,3,256,256,2,7,0.7256730319932103,0.7235906086862087,0.7303511258214712,0.36124258232384293,4.625142311268799,10,float32,cuda,50 -cuda_v4,1,3,256,256,2,7,0.7089367602020502,0.7080744253471494,0.7145792013034225,0.3697706406496508,4.734330666478546,10,float32,cuda,50 -pytorch,1,3,256,256,3,3,6.674171555787325,6.3839604845270514,7.9156504943966866,0.08837411431064435,,10,float32,cuda,50 -cuda_v1,1,3,256,256,3,3,2.0255626551806927,2.0256630377843976,2.0410686964169145,0.29119020262909817,3.294971665634282,10,float32,cuda,50 -cuda_v2,1,3,256,256,3,3,1.6004600655287504,1.5984426718205214,1.6165184089913964,0.3685340313724963,4.170158131113601,10,float32,cuda,50 -cuda_v3,1,3,256,256,3,3,1.4722054777666926,1.471800496801734,1.4807991916313767,0.4006397265242836,4.533451108952478,10,float32,cuda,50 -cuda_v4,1,3,256,256,3,3,1.4688942860811949,1.4653319958597422,1.4887887286022305,0.40154285137398704,4.543670445878774,10,float32,cuda,50 -pytorch,1,3,256,256,3,5,5.245809820480645,5.211159586906433,5.4148891242221,0.11243716798447673,,10,float32,cuda,50 -cuda_v1,1,3,256,256,3,5,2.024001074023545,2.023787470534444,2.0335125038400292,0.29141486512528336,2.591801895664222,10,float32,cuda,50 -cuda_v2,1,3,256,256,3,5,1.5998238697648048,1.5992920380085707,1.60962687805295,0.3686805848738286,3.278992093830834,10,float32,cuda,50 -cuda_v3,1,3,256,256,3,5,1.5252059418708086,1.5229755081236362,1.5630704816430807,0.3867176122304673,3.439410820840473,10,float32,cuda,50 -cuda_v4,1,3,256,256,3,5,1.4621745282784104,1.4624440809711814,1.4697926817461848,0.4033882334788508,3.587676928456107,10,float32,cuda,50 -pytorch,1,3,256,256,3,7,6.0111372359097,5.690280115231872,7.076818193309009,0.09812186560580803,,10,float32,cuda,50 -cuda_v1,1,3,256,256,3,7,2.0201045367866755,2.0184863824397326,2.029761392623186,0.2919769691415162,2.9756565199697294,10,float32,cuda,50 -cuda_v2,1,3,256,256,3,7,1.6010280745103955,1.6005345387384295,1.6115479171276093,0.3684032837340294,3.754548300315061,10,float32,cuda,50 -cuda_v3,1,3,256,256,3,7,1.4723994303494692,1.4696434373036027,1.4831635169684887,0.4005869520473851,4.08254520614897,10,float32,cuda,50 -cuda_v4,1,3,256,256,3,7,1.4713413268327713,1.4690780080854893,1.4835603768005967,0.4008750309961475,4.085481136351517,10,float32,cuda,50 -pytorch,1,8,128,128,1,3,1.8926894385367632,1.8772950861603022,1.9192735198885202,0.008656465063104309,,10,float32,cuda,50 -cuda_v1,1,8,128,128,1,3,0.5577139323577285,0.5577560514211655,0.5638764938339591,0.029377067793047325,3.3936563688402814,10,float32,cuda,50 -cuda_v2,1,8,128,128,1,3,0.40895847138017416,0.4075943725183606,0.4168321844190359,0.04006274755650967,4.628072459653953,10,float32,cuda,50 -cuda_v3,1,8,128,128,1,3,0.46957184094935656,0.46929146628826857,0.48177684657275677,0.034891359683910474,4.030670652461229,10,float32,cuda,50 -cuda_v4,1,8,128,128,1,3,0.37016375456005335,0.37107395473867655,0.381774315610528,0.044261491834803476,5.113113899512556,10,float32,cuda,50 -pytorch,1,8,128,128,1,5,1.8416153965517879,1.8245259998366237,1.8423532135784626,0.008896537263251137,,10,float32,cuda,50 -cuda_v1,1,8,128,128,1,5,0.5451813340187073,0.5464649293571711,0.5536802345886827,0.030052386201905076,3.3779868855315485,10,float32,cuda,50 
-cuda_v2,1,8,128,128,1,5,0.4040445853024721,0.400731572881341,0.4195025423541665,0.04054998036351548,4.557950937946947,10,float32,cuda,50 -cuda_v3,1,8,128,128,1,5,0.4850057791918516,0.4806459182873368,0.5031358683481812,0.03378104076883392,3.797099901820568,10,float32,cuda,50 -cuda_v4,1,8,128,128,1,5,0.41654150001704693,0.41544996201992035,0.44771814718842506,0.03933341575648401,4.421205081549905,10,float32,cuda,50 -pytorch,1,8,128,128,1,7,1.6017067013308406,1.5827155439183116,1.6301571391522884,0.01022908875038527,,10,float32,cuda,50 -cuda_v1,1,8,128,128,1,7,0.5676993587985635,0.5643828772008419,0.5944325588643551,0.028860346142849047,2.821399525129944,10,float32,cuda,50 -cuda_v2,1,8,128,128,1,7,0.5288944765925407,0.5266304360702634,0.5552147515118122,0.030977823980230752,3.0284050452748295,10,float32,cuda,50 -cuda_v3,1,8,128,128,1,7,0.36155916284769773,0.3598505863919854,0.3737625200301409,0.04531485212809157,4.429998921104758,10,float32,cuda,50 -cuda_v4,1,8,128,128,1,7,0.3603372583165765,0.3556694136932492,0.39303568191826344,0.045468514903351284,4.445021058365414,10,float32,cuda,50 -pytorch,1,8,128,128,2,3,2.2949183266609907,2.2692334605380893,2.3775779409334064,0.028557007558239384,,10,float32,cuda,50 -cuda_v1,1,8,128,128,2,3,0.6507242191582918,0.6433745147660375,0.6627354305237532,0.10071240330469096,3.526714173370487,10,float32,cuda,50 -cuda_v2,1,8,128,128,2,3,0.5056051397696137,0.5040100077167153,0.5108454264700413,0.1296189354994738,4.538953713379384,10,float32,cuda,50 -cuda_v3,1,8,128,128,2,3,0.45750402845442295,0.4539460642263293,0.463389465585351,0.14324682609112538,5.016170752449698,10,float32,cuda,50 -cuda_v4,1,8,128,128,2,3,0.4524831008166075,0.44547696597874165,0.46550282277166843,0.1448363483226789,5.071832124822551,10,float32,cuda,50 -pytorch,1,8,128,128,2,5,2.7334573213011026,2.684682374820113,3.138685319572687,0.023975497802469953,,10,float32,cuda,50 -cuda_v1,1,8,128,128,2,5,0.6499058287590742,0.6500064628198743,0.6569568300619721,0.1008392248537515,4.205928305829088,10,float32,cuda,50 -cuda_v2,1,8,128,128,2,5,0.5131116230040789,0.5048854509368539,0.5448845447972417,0.1277226963137396,5.32721770225691,10,float32,cuda,50 -cuda_v3,1,8,128,128,2,5,0.5604248121380806,0.5589330103248358,0.5785022862255573,0.11693986165596977,4.877473770072154,10,float32,cuda,50 -cuda_v4,1,8,128,128,2,5,0.44553374871611595,0.4445739323273301,0.4527335288003087,0.14709547860033845,6.135241896215589,10,float32,cuda,50 -pytorch,1,8,128,128,2,7,2.5094290915876627,2.4871930945664644,2.599560678936541,0.026115900313619444,,10,float32,cuda,50 -cuda_v1,1,8,128,128,2,7,0.654459148645401,0.6531750550493598,0.6606933427974582,0.10013764821783966,3.834355584732335,10,float32,cuda,50 -cuda_v2,1,8,128,128,2,7,0.5120710004121065,0.5095340311527252,0.5175400758162141,0.1279822523580864,4.900549122227415,10,float32,cuda,50 -cuda_v3,1,8,128,128,2,7,0.4739638837054372,0.47122593969106674,0.48143870662897825,0.1382721389816483,5.294557618966686,10,float32,cuda,50 -cuda_v4,1,8,128,128,2,7,0.45019406359642744,0.44682237785309553,0.45996704138815403,0.14557277694081097,5.574105245948376,10,float32,cuda,50 -pytorch,1,8,128,128,3,3,5.082319467328489,4.843820002861321,6.053852685727179,0.029013524424805582,,10,float32,cuda,50 -cuda_v1,1,8,128,128,3,3,1.3734258990734816,1.3736775144934654,1.3794763945043087,0.10736363723698117,3.7004686388665315,10,float32,cuda,50 -cuda_v2,1,8,128,128,3,3,1.0804085666313767,1.0780345182865858,1.0903888382017612,0.13648170197294476,4.704071796815472,10,float32,cuda,50 
-cuda_v3,1,8,128,128,3,3,0.9943246794864535,0.9936904534697533,0.9973179781809449,0.1482976366192155,5.111327891361797,10,float32,cuda,50 -cuda_v4,1,8,128,128,3,3,0.979733420535922,0.9789994219318032,0.9875095216557384,0.15050624681083186,5.187451362584344,10,float32,cuda,50 -pytorch,1,8,128,128,3,5,3.830469772219658,3.782373503781855,3.919286490418017,0.038495539390342996,,10,float32,cuda,50 -cuda_v1,1,8,128,128,3,5,1.363285849802196,1.3614975614473224,1.373448595404625,0.10816220238873228,2.809733389938313,10,float32,cuda,50 -cuda_v2,1,8,128,128,3,5,1.0831660823896527,1.0821976466104388,1.0879256762564182,0.13613424792132192,3.5363642145893044,10,float32,cuda,50 -cuda_v3,1,8,128,128,3,5,0.9961470495909452,0.9955629939213395,0.9998968336731195,0.14802633814008773,3.8452854664304734,10,float32,cuda,50 -cuda_v4,1,8,128,128,3,5,0.9828188642859459,0.9822685969993472,0.9898525662720203,0.15003375022429208,3.897432081752557,10,float32,cuda,50 -pytorch,1,8,128,128,3,7,4.917632234282792,4.972781520336866,5.124774295836687,0.02998516216239696,,10,float32,cuda,50 -cuda_v1,1,8,128,128,3,7,1.3663292303681374,1.3654798967763782,1.372660114429891,0.1079212804078488,3.5991561367371228,10,float32,cuda,50 -cuda_v2,1,8,128,128,3,7,1.0768034821376204,1.0758364805951715,1.0850996943190694,0.13693863592201352,4.566879951502883,10,float32,cuda,50 -cuda_v3,1,8,128,128,3,7,0.9991599293425679,0.9962470503523946,1.0128907160833478,0.14757997760881364,4.921766866209816,10,float32,cuda,50 -cuda_v4,1,8,128,128,3,7,0.9830530360341072,0.9821520652621984,0.9894790593534708,0.149998010885431,5.002407860029409,10,float32,cuda,50 -pytorch,1,8,128,256,1,3,2.067403239198029,1.9350905204191804,2.2734523052349687,0.015849834893705162,,10,float32,cuda,50 -cuda_v1,1,8,128,256,1,3,0.5463245930150151,0.5399889778345823,0.5833456525579095,0.05997899493991736,3.7842031378975625,10,float32,cuda,50 -cuda_v2,1,8,128,256,1,3,0.4024812765419483,0.4006690578535199,0.41477736085653305,0.08141496737820245,5.136644509182666,10,float32,cuda,50 -cuda_v3,1,8,128,256,1,3,0.4817423736676574,0.48082845751196146,0.49287278670817614,0.06801975867417855,4.291512128065947,10,float32,cuda,50 -cuda_v4,1,8,128,256,1,3,0.3658107668161392,0.3623685333877802,0.37896146532148123,0.08957636836443796,5.6515647617259175,10,float32,cuda,50 -pytorch,1,8,128,256,1,5,1.814923589117825,1.8042140873149037,1.8416738137602806,0.01805475458938051,,10,float32,cuda,50 -cuda_v1,1,8,128,256,1,5,0.5642586294561625,0.5593795794993639,0.5676337750628591,0.058072660814389485,3.216474670253718,10,float32,cuda,50 -cuda_v2,1,8,128,256,1,5,0.4024905152618885,0.3992595011368394,0.4130690125748515,0.08141309858861853,4.5092331876116605,10,float32,cuda,50 -cuda_v3,1,8,128,256,1,5,0.36135319620370865,0.36130042281001806,0.37272502668201923,0.09068136201437506,5.022575165198437,10,float32,cuda,50 -cuda_v4,1,8,128,256,1,5,0.3723372332751751,0.3720894455909729,0.38131470791995525,0.08800624023486492,4.874408001459551,10,float32,cuda,50 -pytorch,1,8,128,256,1,7,2.195071200840175,2.182921045459807,2.221221197396517,0.014927989573849759,,10,float32,cuda,50 -cuda_v1,1,8,128,256,1,7,0.5654219072312117,0.562358065508306,0.5834365030750632,0.0579531843052564,3.8821827961868642,10,float32,cuda,50 -cuda_v2,1,8,128,256,1,7,0.40377453435212374,0.4015684826299548,0.4227354656904936,0.08115420169471034,5.436378508521534,10,float32,cuda,50 -cuda_v3,1,8,128,256,1,7,0.3754196735098958,0.37615804467350245,0.38602480199187994,0.08728365163616354,5.8469796756197825,10,float32,cuda,50 
-cuda_v4,1,8,128,256,1,7,0.3583986544981599,0.35712961107492447,0.3721989458426833,0.09142891467012537,6.124663620497619,10,float32,cuda,50 -pytorch,1,8,128,256,2,3,4.291606065817177,4.2037880048155785,4.73125025164336,0.030541479807290328,,10,float32,cuda,50 -cuda_v1,1,8,128,256,2,3,1.2242945516481996,1.2185699306428432,1.2504691258072853,0.10705920386850132,3.5053705499543613,10,float32,cuda,50 -cuda_v2,1,8,128,256,2,3,0.9949567029252648,0.9701449889689684,0.9783105226233602,0.13173638572878216,4.313359619769844,10,float32,cuda,50 -cuda_v3,1,8,128,256,2,3,0.8927233191207051,0.8920715190470219,0.8991760900244117,0.14682264615771481,4.807319327162002,10,float32,cuda,50 -cuda_v4,1,8,128,256,2,3,0.8806978771463037,0.8800718933343887,0.8872309466823936,0.1488274281127011,4.872960611331467,10,float32,cuda,50 -pytorch,1,8,128,256,2,5,3.808657177723944,3.604250494390726,4.542697803117335,0.03441422892210233,,10,float32,cuda,50 -cuda_v1,1,8,128,256,2,5,1.21144263073802,1.2110269162803888,1.2173108290880919,0.10819497075164834,3.143902221274549,10,float32,cuda,50 -cuda_v2,1,8,128,256,2,5,0.9682021802291274,0.9672249434515834,0.9752721525728703,0.13537668337927258,3.9337415836252467,10,float32,cuda,50 -cuda_v3,1,8,128,256,2,5,0.925555769354105,0.9218446211889386,0.9484240086749196,0.1416143730501167,4.114994799699427,10,float32,cuda,50 -cuda_v4,1,8,128,256,2,5,0.8804069506004453,0.8796894690021873,0.8853658568114042,0.14887660747181486,4.326018979207747,10,float32,cuda,50 -pytorch,1,8,128,256,2,7,4.556696848012507,4.437481984496117,5.234006349928677,0.02876469608838903,,10,float32,cuda,50 -cuda_v1,1,8,128,256,2,7,1.2110048020258546,1.2096769642084837,1.2163812527433038,0.10823408774328021,3.7627405278573147,10,float32,cuda,50 -cuda_v2,1,8,128,256,2,7,0.9702552948147058,0.9701745584607124,0.9744886308908463,0.13509021872952667,4.696389571244463,10,float32,cuda,50 -cuda_v3,1,8,128,256,2,7,0.8937292965129018,0.8919144747778773,0.9041925193741918,0.1466573832942578,5.0985201736046335,10,float32,cuda,50 -cuda_v4,1,8,128,256,2,7,0.8822291437536478,0.8815015899017453,0.8899182314053178,0.1485691114695258,5.164981093942314,10,float32,cuda,50 -pytorch,1,8,128,256,3,3,8.531096149235964,8.202063501812518,10.174798616208136,0.03456906297163372,,10,float32,cuda,50 -cuda_v1,1,8,128,256,3,3,2.46146900113672,2.46087193954736,2.4681103182956576,0.11981138087207599,3.465855611139631,10,float32,cuda,50 -cuda_v2,1,8,128,256,3,3,1.9530037604272366,1.9528679549694061,1.961797708645463,0.15100431754185942,4.368192382471252,10,float32,cuda,50 -cuda_v3,1,8,128,256,3,3,1.7915777256712317,1.7895660130307078,1.803082088008523,0.1646102180074316,4.761778418538725,10,float32,cuda,50 -cuda_v4,1,8,128,256,3,3,1.7834124807268381,1.7811920261010528,1.78979211486876,0.16536387582070033,4.7835799297306005,10,float32,cuda,50 -pytorch,1,8,128,256,3,5,6.88839724753052,6.5992234740406275,8.008980075828731,0.04281286188971252,,10,float32,cuda,50 -cuda_v1,1,8,128,256,3,5,2.5040291901677847,2.5033794809132814,2.51198906917125,0.11777498487557134,2.750925298546123,10,float32,cuda,50 -cuda_v2,1,8,128,256,3,5,1.9446935411542654,1.9462434574961662,1.9529090961441398,0.15164960121426438,3.5421505248800993,10,float32,cuda,50 -cuda_v3,1,8,128,256,3,5,1.7982480069622397,1.7984320875257254,1.8055621068924665,0.16399962566798088,3.830615810978716,10,float32,cuda,50 -cuda_v4,1,8,128,256,3,5,1.7795015452429652,1.778624951839447,1.786777633242309,0.16572730762070453,3.870970084822269,10,float32,cuda,50 
-pytorch,1,8,128,256,3,7,7.802273272536695,7.923941942863166,8.00987989641726,0.037798214661112155,,10,float32,cuda,50 -cuda_v1,1,8,128,256,3,7,2.446714648976922,2.456792979501188,2.4712163489311934,0.12053387595619929,3.1888774916188796,10,float32,cuda,50 -cuda_v2,1,8,128,256,3,7,1.95212263148278,1.9523354712873697,1.9586187787353992,0.15107247631056495,3.9968151317472804,10,float32,cuda,50 -cuda_v3,1,8,128,256,3,7,1.7935446230694652,1.792844501323998,1.800841884687543,0.1644296975980942,4.3501974649417505,10,float32,cuda,50 -cuda_v4,1,8,128,256,3,7,1.783656389452517,1.7828240524977446,1.7908786423504353,0.16534126289342171,4.37431408800187,10,float32,cuda,50 -pytorch,1,8,256,128,1,3,1.963044130243361,1.8950920784845948,2.2486007306724787,0.016692441853529656,,10,float32,cuda,50 -cuda_v1,1,8,256,128,1,3,0.5548413703218102,0.5544835003092885,0.5644256481900811,0.05905832144599172,3.5380276872735488,10,float32,cuda,50 -cuda_v2,1,8,256,128,1,3,0.4113389831036329,0.4055905155837536,0.42195883579552174,0.07966179075165461,4.772326987906203,10,float32,cuda,50 -cuda_v3,1,8,256,128,1,3,0.3597167832776904,0.3596280002966523,0.37039099261164665,0.09109388697803433,5.457193607583082,10,float32,cuda,50 -cuda_v4,1,8,256,128,1,3,0.3603894030675292,0.3588370746001601,0.37184106186032295,0.09092387212578497,5.447008467880863,10,float32,cuda,50 -pytorch,1,8,256,128,1,5,1.8885056534782052,1.820328994654119,2.098836237564683,0.01735128509657817,,10,float32,cuda,50 -cuda_v1,1,8,256,128,1,5,0.5418416438624263,0.5404235562309623,0.5499669583514333,0.06047523362438307,3.4853460874958095,10,float32,cuda,50 -cuda_v2,1,8,256,128,1,5,0.39879064075648785,0.39732293225824833,0.4064814653247595,0.08216842786942187,4.73558168239805,10,float32,cuda,50 -cuda_v3,1,8,256,128,1,5,0.36798678804188967,0.36877451930195093,0.3769081551581621,0.08904667521995345,5.131993090097647,10,float32,cuda,50 -cuda_v4,1,8,256,128,1,5,0.4367315862327814,0.4385111387819052,0.46610296703875065,0.07503006659686481,4.324179228180708,10,float32,cuda,50 -pytorch,1,8,256,128,1,7,1.8230937235057354,1.8126420909538865,1.8479618011042476,0.017973842802216696,,10,float32,cuda,50 -cuda_v1,1,8,256,128,1,7,0.7252306723967195,0.7257384713739157,0.7311144843697548,0.04518286560013981,2.5138122157476226,10,float32,cuda,50 -cuda_v2,1,8,256,128,1,7,0.40039450861513615,0.3968594828620553,0.41684405878186226,0.08183928424327364,4.553243574222228,10,float32,cuda,50 -cuda_v3,1,8,256,128,1,7,0.3655512351542711,0.3629459533840418,0.38629008922725916,0.08963996520534567,4.987245420566962,10,float32,cuda,50 -cuda_v4,1,8,256,128,1,7,0.3956288564950228,0.3936109133064747,0.42458868119865656,0.0828251010057762,4.608090875010961,10,float32,cuda,50 -pytorch,1,8,256,128,2,3,3.778682304546237,3.7858879659324884,3.9505227701738477,0.03468722412633199,,10,float32,cuda,50 -cuda_v1,1,8,256,128,2,3,1.2544773099943995,1.2530890526250005,1.2617591070011258,0.10448335649895905,3.012156755998326,10,float32,cuda,50 -cuda_v2,1,8,256,128,2,3,1.0024462034925818,0.998719478957355,1.0154257994145155,0.1307521536251396,3.7694614348192297,10,float32,cuda,50 -cuda_v3,1,8,256,128,2,3,0.9228101978078485,0.9217309998348355,0.931913498789072,0.14203570822186815,4.09475568597157,10,float32,cuda,50 -cuda_v4,1,8,256,128,2,3,0.9084250312298536,0.9071275126188993,0.9134697495028377,0.14428488372073006,4.15959729712703,10,float32,cuda,50 -pytorch,1,8,256,128,2,5,3.9364816807210445,3.7192939780652523,5.02905345056206,0.03329673821217722,,10,float32,cuda,50 
-cuda_v1,1,8,256,128,2,5,1.2536142067983747,1.2501939199864864,1.2665793765336275,0.10455529244100295,3.1401061501803555,10,float32,cuda,50 -cuda_v2,1,8,256,128,2,5,0.9965211106464267,0.9953664848580956,1.001305691897869,0.13152957684456457,3.9502240731935063,10,float32,cuda,50 -cuda_v3,1,8,256,128,2,5,0.9432033356279135,0.9403220610693097,0.9684782708063722,0.13896473331780806,4.173523917936987,10,float32,cuda,50 -cuda_v4,1,8,256,128,2,5,0.9114051656797528,0.9110620012506843,0.9172238875180483,0.14381309755057473,4.319134704250991,10,float32,cuda,50 -pytorch,1,8,256,128,2,7,4.343443512916565,4.256172454915941,4.450874077156186,0.030176978153443713,,10,float32,cuda,50 -cuda_v1,1,8,256,128,2,7,1.2559509836137295,1.2548479717224836,1.2616410618647933,0.10436076065872286,3.4582906256574106,10,float32,cuda,50 -cuda_v2,1,8,256,128,2,7,1.0031307814642787,1.0019369656220078,1.010197401046753,0.13066292294278226,4.329887581135137,10,float32,cuda,50 -cuda_v3,1,8,256,128,2,7,0.9242268512025476,0.9237965568900108,0.9295360650867224,0.14181799612233414,4.69954265802291,10,float32,cuda,50 -cuda_v4,1,8,256,128,2,7,0.908985547721386,0.9086495265364647,0.9132696315646172,0.14419591194663856,4.778341662091946,10,float32,cuda,50 -pytorch,1,8,256,128,3,3,8.587907194159925,8.186906110495329,10.415960941463709,0.03434038041311746,,10,float32,cuda,50 -cuda_v1,1,8,256,128,3,3,2.5986746698617935,2.597782062366605,2.6080208364874125,0.11348554069512842,3.304725787247797,10,float32,cuda,50 -cuda_v2,1,8,256,128,3,3,2.026054351590574,2.0258149597793818,2.035799017176032,0.14555976732237044,4.2387348530003175,10,float32,cuda,50 -cuda_v3,1,8,256,128,3,3,1.861856454052031,1.8611680716276169,1.8698341678828,0.1583967439370374,4.612550648289629,10,float32,cuda,50 -cuda_v4,1,8,256,128,3,3,1.8496425542980433,1.8488640198484063,1.8567960010841489,0.15944269843635916,4.643009090704618,10,float32,cuda,50 -pytorch,1,8,256,128,3,5,6.725382432341576,6.6917366348207,6.833369680680335,0.04385059183873362,,10,float32,cuda,50 -cuda_v1,1,8,256,128,3,5,2.5601254729554057,2.559990040026605,2.569140726700425,0.11519435399373373,2.6269737571017573,10,float32,cuda,50 -cuda_v2,1,8,256,128,3,5,2.0295281894505024,2.029025577940047,2.0362793002277613,0.14531062023821795,3.3137664543415295,10,float32,cuda,50 -cuda_v3,1,8,256,128,3,5,1.8621120229363441,1.8612504936754704,1.8688065931200981,0.15837500449353012,3.6116959396118355,10,float32,cuda,50 -cuda_v4,1,8,256,128,3,5,1.8565295031294227,1.855071634054184,1.8653924344107509,0.15885123263750311,3.6225561840008824,10,float32,cuda,50 -pytorch,1,8,256,128,3,7,9.338932940736413,9.335665381513536,9.872243599966168,0.03157876835303038,,10,float32,cuda,50 -cuda_v1,1,8,256,128,3,7,2.5553510058671236,2.5530324783176184,2.5751939741894603,0.11540958534576178,3.654657586881053,10,float32,cuda,50 -cuda_v2,1,8,256,128,3,7,2.02504463493824,2.0252445247024298,2.030028752051294,0.14563234553543272,4.611717084952664,10,float32,cuda,50 -cuda_v3,1,8,256,128,3,7,1.867524804547429,1.8653111765161157,1.8787236418575048,0.1579159748143041,5.000700883863004,10,float32,cuda,50 -cuda_v4,1,8,256,128,3,7,1.857128101401031,1.8561474280431867,1.8686209805309772,0.15880003096044706,5.028696153857697,10,float32,cuda,50 -pytorch,1,8,256,256,1,3,2.6666148006916046,2.5547059485688806,2.9847467551007867,0.024576478006123267,,10,float32,cuda,50 -cuda_v1,1,8,256,256,1,3,0.6383521435782313,0.6306543946266174,0.6635474506765604,0.10266433763759804,4.1773413428896955,10,float32,cuda,50 
-cuda_v2,1,8,256,256,1,3,0.48515914008021355,0.48316200263798237,0.49340776167809963,0.13508144974691116,5.496371356109505,10,float32,cuda,50 -cuda_v3,1,8,256,256,1,3,0.4773047938942909,0.470960047096014,0.496916426345706,0.13730429871717215,5.586817553066902,10,float32,cuda,50 -cuda_v4,1,8,256,256,1,3,0.46722331549972296,0.4641955019906163,0.476591382175684,0.14026697261438115,5.707366717860607,10,float32,cuda,50 -pytorch,1,8,256,256,1,5,2.2214805241674185,2.1969579393044114,2.284339559264481,0.029501046390925274,,10,float32,cuda,50 -cuda_v1,1,8,256,256,1,5,0.641844249330461,0.6362345302477479,0.6586905103176832,0.10210576797776687,3.461089705929052,10,float32,cuda,50 -cuda_v2,1,8,256,256,1,5,0.4858332918956876,0.4853704012930393,0.4932035459205508,0.13489400807483387,4.5725160486622,10,float32,cuda,50 -cuda_v3,1,8,256,256,1,5,0.46678606420755386,0.46398292761296034,0.47364868223667145,0.14039836452970836,4.7590977848465705,10,float32,cuda,50 -cuda_v4,1,8,256,256,1,5,0.4626097623258829,0.46146148815751076,0.46768535394221544,0.14166583876332794,4.802061489144513,10,float32,cuda,50 -pytorch,1,8,256,256,1,7,2.6318022841587663,2.5641301181167364,2.948883641511202,0.024901566654330964,,10,float32,cuda,50 -cuda_v1,1,8,256,256,1,7,0.6335034826770425,0.6320229731500149,0.6402655737474561,0.10345010215738622,4.154361193149823,10,float32,cuda,50 -cuda_v2,1,8,256,256,1,7,0.48780177254229784,0.4840875044465065,0.5006202263757586,0.1343496553086373,5.395229030108864,10,float32,cuda,50 -cuda_v3,1,8,256,256,1,7,0.4626284958794713,0.4613135242834687,0.46950376126915216,0.14166010218504593,5.688802803112306,10,float32,cuda,50 -cuda_v4,1,8,256,256,1,7,0.46635448932647705,0.4646843299269676,0.4732209723442793,0.14052829231825134,5.643351451295543,10,float32,cuda,50 -pytorch,1,8,256,256,2,3,7.843744731508195,7.943923585116863,8.010599622502923,0.03342077145205552,,10,float32,cuda,50 -cuda_v1,1,8,256,256,2,3,2.281022099778056,2.2740440908819437,2.3007393116131425,0.1149239194243259,3.4386973858216425,10,float32,cuda,50 -cuda_v2,1,8,256,256,2,3,1.8037663120776415,1.8034030217677355,1.8102034227922559,0.14533146463859462,4.348537102055917,10,float32,cuda,50 -cuda_v3,1,8,256,256,2,3,1.6670777183026075,1.6650534234941006,1.679345965385437,0.15724761786565708,4.705086418823097,10,float32,cuda,50 -cuda_v4,1,8,256,256,2,3,1.6558474162593484,1.6553474124521017,1.6613460145890713,0.15831410396025372,4.736997294851993,10,float32,cuda,50 -pytorch,1,8,256,256,2,5,7.005033129826188,6.712923059239984,8.13796438742429,0.03742223557570876,,10,float32,cuda,50 -cuda_v1,1,8,256,256,2,5,2.276116036809981,2.274753525853157,2.2865735460072756,0.11517163262352816,3.0776256643065847,10,float32,cuda,50 -cuda_v2,1,8,256,256,2,5,1.8342435453087091,1.8338535446673632,1.840946963056922,0.14291668119561532,3.819031091995593,10,float32,cuda,50 -cuda_v3,1,8,256,256,2,5,1.671182638965547,1.6702709253877401,1.679073041304946,0.15686137103618172,4.191662219613689,10,float32,cuda,50 -cuda_v4,1,8,256,256,2,5,1.663348381407559,1.6634294297546148,1.6699650092050433,0.15760017740731405,4.211404663103942,10,float32,cuda,50 -pytorch,1,8,256,256,2,7,6.85543866828084,6.597880506888032,7.352691376581788,0.03823883673745107,,10,float32,cuda,50 -cuda_v1,1,8,256,256,2,7,2.254212941043079,2.2627164144068956,2.27433773688972,0.11629069961717975,3.041167293232139,10,float32,cuda,50 -cuda_v2,1,8,256,256,2,7,1.8175022071227431,1.8079809378832579,1.837374921888113,0.14423311233002337,3.771901151709504,10,float32,cuda,50 
-cuda_v3,1,8,256,256,2,7,1.6725403722375631,1.6703285509720445,1.689420617185533,0.15673403425789817,4.09881805071735,10,float32,cuda,50
-cuda_v4,1,8,256,256,2,7,1.6611922485753894,1.660865033045411,1.6678510000929236,0.15780473345262133,4.126818358416943,10,float32,cuda,50
-pytorch,1,8,256,256,3,3,14.831915474496782,14.468007488176227,15.825923206284642,0.039767216919095315,,10,float32,cuda,50
-cuda_v1,1,8,256,256,3,3,4.795580897480249,4.861502558924258,4.891411936841905,0.1229932332723054,3.092829793005878,10,float32,cuda,50
-cuda_v2,1,8,256,256,3,3,3.8363351486623287,3.8457076298072934,3.8683589547872543,0.15374673409482031,3.86616781374496,10,float32,cuda,50
-cuda_v3,1,8,256,256,3,3,3.5528057161718607,3.5565325524657965,3.5763337276875973,0.16601639580661728,4.174704911947207,10,float32,cuda,50
-cuda_v4,1,8,256,256,3,3,3.4883907111361623,3.528101951815188,3.555237129330635,0.16908197757695997,4.251793076718191,10,float32,cuda,50
-pytorch,1,8,256,256,3,5,12.49826724641025,12.370950891636312,13.13346887473017,0.04719246183261197,,10,float32,cuda,50
-cuda_v1,1,8,256,256,3,5,4.857695340178907,4.866057075560093,4.883602284826338,0.12142054177861986,2.57288001226317,10,float32,cuda,50
-cuda_v2,1,8,256,256,3,5,3.828923678956926,3.873889450915158,3.933584224432707,0.15404433450621288,3.264172465776342,10,float32,cuda,50
-cuda_v3,1,8,256,256,3,5,3.510385132394731,3.535768366418779,3.565721120685339,0.16802258947514145,3.560369240136373,10,float32,cuda,50
-cuda_v4,1,8,256,256,3,5,3.5362447844818234,3.54151357896626,3.5553407855331898,0.16679388332740339,3.5343331720860096,10,float32,cuda,50
-pytorch,1,8,256,256,3,7,13.576734396629035,13.386719045229256,14.27298269700259,0.043443731222026945,,10,float32,cuda,50
-cuda_v1,1,8,256,256,3,7,4.754021079279482,4.852510988712311,4.898783378303051,0.12406844441030404,2.855842279665853,10,float32,cuda,50
-cuda_v2,1,8,256,256,3,7,3.7703912518918514,3.8328804075717926,3.857595380395651,0.15643575443372004,3.600882107345412,10,float32,cuda,50
-cuda_v3,1,8,256,256,3,7,3.478135773912072,3.5281829768791795,3.553933184593916,0.16958049896269256,3.9034515266660943,10,float32,cuda,50
-cuda_v4,1,8,256,256,3,7,3.560623242519796,3.560943529009819,3.572101565077901,0.1656518985093719,3.8130218986665376,10,float32,cuda,50
diff --git a/test/results_4090.csv b/test/results_4090.csv
index ecca0b3..e01b4d1 100644
--- a/test/results_4090.csv
+++ b/test/results_4090.csv
@@ -1,181 +1,73 @@
-variant,B,C,H,W,scale,ksize,mean_ms,p50_ms,p90_ms,tp,speedup_vs_pytorch,warmup,dtype,device,iters
-pytorch,1,3,128,128,1,3,1.52592733502388,0.8647029753774405,2.298735734075308,0.010737077463615656,,10,float32,cuda,50
-cuda_v1,1,3,128,128,1,3,0.44544885866343975,0.4432220011949539,0.472044013440609,0.036780877717724675,3.425594892312428,10,float32,cuda,50
-cuda_v2,1,3,128,128,1,3,0.3007301315665245,0.29844650998711586,0.3108557313680649,0.054480739640735645,5.074075308234717,10,float32,cuda,50
-cuda_v3,1,3,128,128,1,3,0.30307079665362835,0.2994614187628031,0.3240731079131365,0.0540599760217902,5.03488739882722,10,float32,cuda,50
-cuda_v4,1,3,128,128,1,3,0.32072700560092926,0.32319221645593643,0.3775870893150568,0.051083942773394356,4.757713907392458,10,float32,cuda,50
-pytorch,1,3,128,128,1,5,4.0221707709133625,0.9404211305081844,7.168814446777105,0.004073422271993561,,10,float32,cuda,50
-cuda_v1,1,3,128,128,1,5,0.4805893823504448,0.4761132877320051,0.5044737830758095,0.0340914730988643,8.369246010475537,10,float32,cuda,50
-cuda_v2,1,3,128,128,1,5,0.3053080663084984,0.3032265231013298,0.3142551053315401,0.053663829449709974,13.174138566156198,10,float32,cuda,50 -cuda_v3,1,3,128,128,1,5,0.2769072540104389,0.2749105915427208,0.28512105345726013,0.05916782519313254,14.525335514546446,10,float32,cuda,50 -cuda_v4,1,3,128,128,1,5,0.27410948649048805,0.27215038426220417,0.2808789722621441,0.05977173650488943,14.673592010296721,10,float32,cuda,50 -pytorch,1,3,128,128,1,7,3.5734746791422367,0.8284670766443014,10.36820076406002,0.004584893268065006,,10,float32,cuda,50 -cuda_v1,1,3,128,128,1,7,0.5126208532601595,0.5032145418226719,0.5624458193778992,0.03196124366732498,6.970989682561095,10,float32,cuda,50 -cuda_v2,1,3,128,128,1,7,0.30147168785333633,0.299196457490325,0.31120297499001026,0.05434672859884173,11.853433748912119,10,float32,cuda,50 -cuda_v3,1,3,128,128,1,7,0.2754225581884384,0.27269101701676846,0.28396081179380417,0.05948677591176249,12.974517057158911,10,float32,cuda,50 -cuda_v4,1,3,128,128,1,7,0.27901765890419483,0.2780151553452015,0.2874089404940605,0.058720297720029645,12.807342349500704,10,float32,cuda,50 -pytorch,1,3,128,128,2,3,5.1218782644718885,1.200301107019186,8.238791720941663,0.012795306060784977,,10,float32,cuda,50 -cuda_v1,1,3,128,128,2,3,0.48394261859357357,0.46004820615053177,0.5499029066413641,0.13542101373600798,10.583647869983038,10,float32,cuda,50 -cuda_v2,1,3,128,128,2,3,0.3488078713417053,0.3472878597676754,0.3549169283360243,0.18788566825603104,14.683952643529377,10,float32,cuda,50 -cuda_v3,1,3,128,128,2,3,0.34828370437026024,0.34455815330147743,0.36369492299854755,0.18816843618479692,14.706051992104753,10,float32,cuda,50 -cuda_v4,1,3,128,128,2,3,0.3102908004075289,0.30851690098643303,0.3184992354363203,0.2112083243007092,16.506703575307196,10,float32,cuda,50 -pytorch,1,3,128,128,2,5,2.7414161060005426,1.1525587178766727,3.5016948357224464,0.02390589296406032,,10,float32,cuda,50 -cuda_v1,1,3,128,128,2,5,0.5049472488462925,0.4589471500366926,0.6366008426994085,0.1297878147662695,5.429113857465611,10,float32,cuda,50 -cuda_v2,1,3,128,128,2,5,0.3760635666549206,0.37418887950479984,0.38898889906704426,0.1742684104789562,7.28976787191962,10,float32,cuda,50 -cuda_v3,1,3,128,128,2,5,0.3826252557337284,0.3617340698838234,0.443447008728981,0.17127985938703158,7.1647547173632296,10,float32,cuda,50 -cuda_v4,1,3,128,128,2,5,0.33498174510896206,0.326463021337986,0.3694041632115841,0.1956405116305147,8.183777611848136,10,float32,cuda,50 -pytorch,1,3,128,128,2,7,6.999819874763489,1.7441920936107635,18.960934737697244,0.009362526632474858,,10,float32,cuda,50 -cuda_v1,1,3,128,128,2,7,0.4698681924492121,0.46762311831116676,0.4844237584620714,0.13947741314088583,14.897411630007488,10,float32,cuda,50 -cuda_v2,1,3,128,128,2,7,0.37914127111434937,0.36833412013947964,0.40491526015102863,0.1728537750780349,18.462299960618996,10,float32,cuda,50 -cuda_v3,1,3,128,128,2,7,0.44248790480196476,0.44069206342101097,0.4555768799036741,0.14810800315396327,15.819234376352625,10,float32,cuda,50 -cuda_v4,1,3,128,128,2,7,0.3105806838721037,0.3085271455347538,0.3185570705682039,0.2110111909824616,22.53784680841902,10,float32,cuda,50 -pytorch,1,3,128,128,3,3,7.3913106974214315,1.91159313544631,23.464363627135754,0.019949912273535222,,10,float32,cuda,50 -cuda_v1,1,3,128,128,3,3,0.523065309971571,0.5227448418736458,0.5595123395323753,0.281907435245542,14.130760645975844,10,float32,cuda,50 -cuda_v2,1,3,128,128,3,3,0.3566680010408163,0.3553489223122597,0.3626151941716671,0.4134264906571348,20.72322349033937,10,float32,cuda,50 
-cuda_v3,1,3,128,128,3,3,0.33742536790668964,0.3273128531873226,0.36468892358243465,0.43700330213695415,21.90502374873426,10,float32,cuda,50 -cuda_v4,1,3,128,128,3,3,0.31496345065534115,0.31304731965065,0.3236026968806982,0.46816860716121145,23.467201296031043,10,float32,cuda,50 -pytorch,1,3,128,128,3,5,3.313328195363283,1.4438305515795946,5.463926354423165,0.044503891949596766,,10,float32,cuda,50 -cuda_v1,1,3,128,128,3,5,0.4630151018500328,0.4609823226928711,0.4760188050568104,0.31846909401188367,7.155982995207888,10,float32,cuda,50 -cuda_v2,1,3,128,128,3,5,0.36386603489518166,0.35751843824982643,0.3770098090171814,0.40524804697002687,9.105901286768216,10,float32,cuda,50 -cuda_v3,1,3,128,128,3,5,0.32810534350574017,0.3216217737644911,0.356891006231308,0.44941663681688954,10.098367066994497,10,float32,cuda,50 -cuda_v4,1,3,128,128,3,5,0.3466991614550352,0.3456580452620983,0.357994856312871,0.4253139793622609,9.556781682014543,10,float32,cuda,50 -pytorch,1,3,128,128,3,7,6.756937270984054,1.7321815248578787,10.7049988117069,0.021822904976965263,,10,float32,cuda,50 -cuda_v1,1,3,128,128,3,7,0.4557925648987293,0.45467307791113853,0.465529877692461,0.32351558879149916,14.824588620670786,10,float32,cuda,50 -cuda_v2,1,3,128,128,3,7,0.38067104294896126,0.3792489878833294,0.39135636761784554,0.3873580686823354,17.75006897987246,10,float32,cuda,50 -cuda_v3,1,3,128,128,3,7,0.32470209524035454,0.3225074615329504,0.3330751322209835,0.454127035708989,20.80965096939815,10,float32,cuda,50 -cuda_v4,1,3,128,128,3,7,0.3528321161866188,0.3392628859728575,0.38030720315873623,0.41792113936138414,19.150573207486012,10,float32,cuda,50 -pytorch,1,3,128,256,1,3,4.7790092043578625,1.1113823857158422,11.791642662137747,0.006856651368262623,,10,float32,cuda,50 -cuda_v1,1,3,128,256,1,3,0.43059052899479866,0.4229459445923567,0.4630208481103182,0.07610014107020878,11.098732746199326,10,float32,cuda,50 -cuda_v2,1,3,128,256,1,3,0.3584872093051672,0.3435080870985985,0.4171540029346943,0.09140632956894645,13.331045237627055,10,float32,cuda,50 -cuda_v3,1,3,128,256,1,3,0.31005123630166054,0.30543701723217964,0.3326671663671732,0.10568575823422545,15.413611186856176,10,float32,cuda,50 -cuda_v4,1,3,128,256,1,3,0.274530379101634,0.2726605162024498,0.28227311559021473,0.11936019651897593,17.407943048039222,10,float32,cuda,50 -pytorch,1,3,128,256,1,5,3.1171874701976776,0.8945895824581385,2.6657020207494497,0.010512040200752509,,10,float32,cuda,50 -cuda_v1,1,3,128,256,1,5,0.48769986256957054,0.4741228185594082,0.5690208170562983,0.06718886453515381,6.391610310845659,10,float32,cuda,50 -cuda_v2,1,3,128,256,1,5,0.34617194905877113,0.3371278289705515,0.35091196186840534,0.0946581607466896,9.004737323960525,10,float32,cuda,50 -cuda_v3,1,3,128,256,1,5,0.29009729623794556,0.290280906483531,0.31339898705482483,0.11295520649431635,10.745317211232734,10,float32,cuda,50 -cuda_v4,1,3,128,256,1,5,0.27919040992856026,0.2755909226834774,0.2894400618970394,0.11736792824791058,11.165095072553921,10,float32,cuda,50 -pytorch,1,3,128,256,1,7,0.8610220160335302,0.8528372272849083,0.8728299289941788,0.03805709887762491,,10,float32,cuda,50 -cuda_v1,1,3,128,256,1,7,0.4062088765203953,0.4038410261273384,0.4159193020313978,0.08066785807511706,2.119653374906738,10,float32,cuda,50 -cuda_v2,1,3,128,256,1,7,0.3033390734344721,0.29832683503627777,0.3132038749754429,0.10802432943766017,2.8384804050623895,10,float32,cuda,50 -cuda_v3,1,3,128,256,1,7,0.2987572457641363,0.2943666186183691,0.31618126668035984,0.10968102184831952,2.8820121628557662,10,float32,cuda,50 
-cuda_v4,1,3,128,256,1,7,0.2784122433513403,0.2752200234681368,0.2866980619728565,0.11769597344412998,3.09261548870525,10,float32,cuda,50 -pytorch,1,3,128,256,2,3,4.812479577958584,1.4287945814430714,9.668499417603016,0.027235855836213175,,10,float32,cuda,50 -cuda_v1,1,3,128,256,2,3,0.4676768183708191,0.4681474529206753,0.48343208618462086,0.2802619134653656,10.290181999448148,10,float32,cuda,50 -cuda_v2,1,3,128,256,2,3,0.3709117043763399,0.3670940641313791,0.3977825865149498,0.35337790221634474,12.974730970138582,10,float32,cuda,50 -cuda_v3,1,3,128,256,2,3,0.40193247608840466,0.41314586997032166,0.4363299813121557,0.32610452699813897,11.973353397051905,10,float32,cuda,50 -cuda_v4,1,3,128,256,2,3,0.34459170885384083,0.34502334892749786,0.3664530348032713,0.3803689892480681,13.965743963966949,10,float32,cuda,50 -pytorch,1,3,128,256,2,5,3.534023268148303,1.3921631034463644,7.826935639604926,0.03708860696570254,,10,float32,cuda,50 -cuda_v1,1,3,128,256,2,5,0.4526952747255564,0.45057223178446293,0.4625048488378525,0.28953692984637747,7.806627251169746,10,float32,cuda,50 -cuda_v2,1,3,128,256,2,5,0.35566513426601887,0.34616305492818356,0.3815658390522003,0.36852642379605594,9.936378148061708,10,float32,cuda,50 -cuda_v3,1,3,128,256,2,5,0.41979328729212284,0.41792611591517925,0.45677535235881805,0.3122298616194654,8.418484466353727,10,float32,cuda,50 -cuda_v4,1,3,128,256,2,5,0.35099745728075504,0.343183521181345,0.3787568770349026,0.37342720661123885,10.068515297880111,10,float32,cuda,50 -pytorch,1,3,128,256,2,7,5.221625966951251,1.6814591363072395,8.201262401416898,0.025101759649117296,,10,float32,cuda,50 -cuda_v1,1,3,128,256,2,7,0.5324110481888056,0.5313355941325426,0.5536912009119987,0.24618572519464088,9.80750866217851,10,float32,cuda,50 -cuda_v2,1,3,128,256,2,7,0.35434636287391186,0.3511281684041023,0.36454498767852783,0.3698979691422422,14.735937811246218,10,float32,cuda,50 -cuda_v3,1,3,128,256,2,7,0.3654781263321638,0.3502380568534136,0.43992577120661736,0.3586315857405801,14.287109380126324,10,float32,cuda,50 -cuda_v4,1,3,128,256,2,7,0.3121230937540531,0.3096370492130518,0.31899111345410347,0.41993688587260436,16.729380399727134,10,float32,cuda,50 -pytorch,1,3,128,256,3,3,10.960625801235437,2.763780066743493,45.61858847737312,0.026906492872583856,,10,float32,cuda,50 -cuda_v1,1,3,128,256,3,3,0.46120816841721535,0.45501673594117165,0.4788396880030632,0.6394336011265492,23.765029658625433,10,float32,cuda,50 -cuda_v2,1,3,128,256,3,3,0.36482485942542553,0.361383892595768,0.3778459504246712,0.8083659662460131,30.04352778617576,10,float32,cuda,50 -cuda_v3,1,3,128,256,3,3,0.4006952326744795,0.3770939074456692,0.4836510866880417,0.7360007705396968,27.354020980179047,10,float32,cuda,50 -cuda_v4,1,3,128,256,3,3,0.32492942176759243,0.320731895044446,0.33880281262099743,0.9076186403671916,33.732327905581556,10,float32,cuda,50 -pytorch,1,3,128,256,3,5,11.318621216341853,2.605273388326168,45.752703258767724,0.026055470393708847,,10,float32,cuda,50 -cuda_v1,1,3,128,256,3,5,0.4699833784252405,0.4665427841246128,0.48548299819231033,0.6274945318027054,24.083024498157414,10,float32,cuda,50 -cuda_v2,1,3,128,256,3,5,0.4197291377931833,0.42938650585711,0.4759266972541809,0.7026245581866524,26.96648909306595,10,float32,cuda,50 -cuda_v3,1,3,128,256,3,5,0.3752768971025944,0.3709190059453249,0.4119148012250662,0.7858517331520571,30.160719468023988,10,float32,cuda,50 -cuda_v4,1,3,128,256,3,5,0.36198515444993973,0.35343365743756294,0.3866641316562891,0.8147074441440512,31.26819173991056,10,float32,cuda,50 
-pytorch,1,3,128,256,3,7,7.481619408354163,2.37233005464077,25.27893357910216,0.039418203988122395,,10,float32,cuda,50 -cuda_v1,1,3,128,256,3,7,0.4637504182755947,0.46098814345896244,0.47804401256144047,0.6359282673999478,16.132857488676233,10,float32,cuda,50 -cuda_v2,1,3,128,256,3,7,0.3909336030483246,0.3640139475464821,0.45265606604516506,0.7543787428361459,19.137826347021225,10,float32,cuda,50 -cuda_v3,1,3,128,256,3,7,0.3465086594223976,0.34476793371140957,0.35570100881159306,0.8510956132859561,21.59143561036953,10,float32,cuda,50 -cuda_v4,1,3,128,256,3,7,0.32119077630341053,0.3183919470757246,0.33296276815235615,0.9181832784681636,23.293381878708452,10,float32,cuda,50 -pytorch,1,3,256,128,1,3,3.82584142498672,1.0650705080479383,7.791366055607796,0.008564913272670139,,10,float32,cuda,50 -cuda_v1,1,3,256,128,1,3,0.48191010020673275,0.4935292527079582,0.54588015191257,0.06799608471775748,7.938911061099336,10,float32,cuda,50 -cuda_v2,1,3,256,128,1,3,0.31497727148234844,0.31047710217535496,0.33008200116455555,0.10403290321802264,12.146404745274209,10,float32,cuda,50 -cuda_v3,1,3,256,128,1,3,0.2761990111321211,0.2733948640525341,0.28623687103390694,0.11863909239097628,13.85175641760937,10,float32,cuda,50 -cuda_v4,1,3,256,128,1,3,0.2851138450205326,0.28066104277968407,0.2989410888403654,0.11492952928203187,13.418644838910577,10,float32,cuda,50 -pytorch,1,3,256,128,1,5,3.586227549239993,0.8654326666146517,11.385623132809997,0.009137178148928202,,10,float32,cuda,50 -cuda_v1,1,3,256,128,1,5,0.423511927947402,0.4099758807569742,0.48163579776883125,0.07737208290404897,8.46783127601871,10,float32,cuda,50 -cuda_v2,1,3,256,128,1,5,0.2990085817873478,0.296260928735137,0.30720722861588,0.10958882786616574,11.993727831499116,10,float32,cuda,50 -cuda_v3,1,3,256,128,1,5,0.275130495429039,0.27071102522313595,0.28432300314307213,0.1190998473248177,13.034642138261058,10,float32,cuda,50 -cuda_v4,1,3,256,128,1,5,0.27801617980003357,0.2752654254436493,0.28759418055415154,0.11786364384824212,12.899348346630148,10,float32,cuda,50 -pytorch,1,3,256,128,1,7,3.53361826390028,1.0452242568135262,5.550267640501261,0.009273214465399514,,10,float32,cuda,50 -cuda_v1,1,3,256,128,1,7,0.4154033772647381,0.412175664678216,0.4310780204832554,0.07888236300764795,8.50647456736562,10,float32,cuda,50 -cuda_v2,1,3,256,128,1,7,0.300332996994257,0.29647164046764374,0.3192121163010597,0.10910556058755874,11.765667772988163,10,float32,cuda,50 -cuda_v3,1,3,256,128,1,7,0.35520353354513645,0.35250792279839516,0.39303875528275967,0.09225133453194125,9.948150652198553,10,float32,cuda,50 -cuda_v4,1,3,256,128,1,7,0.28812913224101067,0.28413604013621807,0.30301674269139767,0.11372678543518687,12.264008975477436,10,float32,cuda,50 -pytorch,1,3,256,128,2,3,1.6048630606383085,1.2890337966382504,1.585709908977151,0.08167176578160394,,10,float32,cuda,50 -cuda_v1,1,3,256,128,2,3,0.49104482866823673,0.45836716890335083,0.6075259298086166,0.2669247130765648,3.2682618102116248,10,float32,cuda,50 -cuda_v2,1,3,256,128,2,3,0.3567333798855543,0.35033351741731167,0.3854172769933939,0.367422863658147,4.498774578238723,10,float32,cuda,50 -cuda_v3,1,3,256,128,2,3,0.35774463787674904,0.3515880089253187,0.385533319786191,0.36638424765197236,4.486057625247256,10,float32,cuda,50 -cuda_v4,1,3,256,128,2,3,0.3179950825870037,0.307686161249876,0.3502919338643551,0.4121824744385429,5.046817226172725,10,float32,cuda,50 -pytorch,1,3,256,128,2,5,5.680167442187667,1.439184183254838,13.859670702368021,0.02307537609305383,,10,float32,cuda,50 
-cuda_v1,1,3,256,128,2,5,0.4627335909754038,0.46338303945958614,0.48728715628385544,0.2832558572713755,12.2752433645769,10,float32,cuda,50 -cuda_v2,1,3,256,128,2,5,0.3555159270763397,0.35281339660286903,0.3653420601040125,0.36868109138709554,15.977251677300995,10,float32,cuda,50 -cuda_v3,1,3,256,128,2,5,0.3210102953016758,0.31748763285577297,0.3332026768475771,0.4083108919507472,17.694658163064886,10,float32,cuda,50 -cuda_v4,1,3,256,128,2,5,0.32647970132529736,0.31356699764728546,0.3721813205629587,0.40147059516390177,17.398225430646516,10,float32,cuda,50 -pytorch,1,3,256,128,2,7,4.92123176343739,1.4630758669227362,18.036476150155067,0.026633982364701436,,10,float32,cuda,50 -cuda_v1,1,3,256,128,2,7,0.44690595008432865,0.4451470449566841,0.45659723691642284,0.29328765923852085,11.011783939123616,10,float32,cuda,50 -cuda_v2,1,3,256,128,2,7,0.35398226231336594,0.351473456248641,0.3692640457302332,0.37027844034729446,13.902481246590899,10,float32,cuda,50 -cuda_v3,1,3,256,128,2,7,0.3210613317787647,0.31872186809778214,0.33104592002928257,0.4082459861292746,15.32801143062749,10,float32,cuda,50 -cuda_v4,1,3,256,128,2,7,0.32810162752866745,0.32733287662267685,0.3445217851549387,0.3994859793511623,14.99910805229823,10,float32,cuda,50 -pytorch,1,3,256,128,3,3,9.371620612218976,2.651255577802658,27.0526010543108,0.03146862343269484,,10,float32,cuda,50 -cuda_v1,1,3,256,128,3,3,0.5212203320115805,0.5533958319574594,0.5914739333093166,0.5658106215116867,17.980151649208416,10,float32,cuda,50 -cuda_v2,1,3,256,128,3,3,0.3939199075102806,0.38670445792376995,0.4216096829622984,0.7486598021002614,23.790675296029292,10,float32,cuda,50 -cuda_v3,1,3,256,128,3,3,0.39537341333925724,0.39467494934797287,0.4045611247420311,0.745907514390568,23.7032139644086,10,float32,cuda,50 -cuda_v4,1,3,256,128,3,3,0.3381696157157421,0.32397685572505,0.3813320305198431,0.8720830798941337,27.712781328339542,10,float32,cuda,50 -pytorch,1,3,256,128,3,5,6.334149120375514,2.5149499997496605,9.2535394243896,0.04655905543040269,,10,float32,cuda,50 -cuda_v1,1,3,256,128,3,5,0.4680374823510647,0.46257232315838337,0.49366913735866547,0.6301033808629732,13.533422768957653,10,float32,cuda,50 -cuda_v2,1,3,256,128,3,5,0.3780175279825926,0.38000987842679024,0.39740419015288353,0.7801543001825579,16.756231263083645,10,float32,cuda,50 -cuda_v3,1,3,256,128,3,5,0.37817317992448807,0.36235409788787365,0.45027188025414944,0.7798331972110945,16.749334581686327,10,float32,cuda,50 -cuda_v4,1,3,256,128,3,5,0.34465146251022816,0.34011295065283775,0.35576531663537025,0.8556818469651727,18.378419387056965,10,float32,cuda,50 -pytorch,1,3,256,128,3,7,9.138631783425808,3.381723305210471,33.623141143471,0.03227091396054104,,10,float32,cuda,50 -cuda_v1,1,3,256,128,3,7,0.4677951242774725,0.4621578846126795,0.49515804275870323,0.6304298285611738,19.53554303829219,10,float32,cuda,50 -cuda_v2,1,3,256,128,3,7,0.3610655106604099,0.35483832471072674,0.38772691041231155,0.8167825264190665,25.310176446126675,10,float32,cuda,50 -cuda_v3,1,3,256,128,3,7,0.3526437934488058,0.32708211801946163,0.39914101362228394,0.8362886444584856,25.91462533354493,10,float32,cuda,50 -cuda_v4,1,3,256,128,3,7,0.32145872712135315,0.3138268366456032,0.3484130371361971,0.9174179299498951,28.428631772612924,10,float32,cuda,50 -pytorch,1,3,256,256,1,3,4.861515955999494,1.0742205195128918,14.73111561499536,0.013480568734763365,,10,float32,cuda,50 -cuda_v1,1,3,256,256,1,3,0.4849292803555727,0.48372894525527954,0.49910699017345905,0.13514547925822495,10.025206051560351,10,float32,cuda,50 
-cuda_v2,1,3,256,256,1,3,0.3039980586618185,0.2988413907587528,0.32487385906279087,0.21558032406024435,15.99193092679473,10,float32,cuda,50 -cuda_v3,1,3,256,256,1,3,0.2766306512057781,0.273675424978137,0.2857776824384928,0.23690794824919645,17.574032142892012,10,float32,cuda,50 -cuda_v4,1,3,256,256,1,3,0.2745348773896694,0.26938505470752716,0.2950329799205065,0.23871648157468708,17.708190675894173,10,float32,cuda,50 -pytorch,1,3,256,256,1,5,1.1926674656569958,1.1049916502088308,1.324015948921442,0.05494909678273036,,10,float32,cuda,50 -cuda_v1,1,3,256,256,1,5,0.4482492245733738,0.419301213696599,0.5197371356189251,0.14620438007979739,2.660723990748964,10,float32,cuda,50 -cuda_v2,1,3,256,256,1,5,0.31799268908798695,0.2957459073513746,0.39233872666954994,0.2060927884473046,3.7506128492375206,10,float32,cuda,50 -cuda_v3,1,3,256,256,1,5,0.27819630689918995,0.2757101319730282,0.28871200047433376,0.23557465852250975,4.287143416642058,10,float32,cuda,50 -cuda_v4,1,3,256,256,1,5,0.27723253704607487,0.2712407149374485,0.302234198898077,0.23639360912787882,4.302047221314356,10,float32,cuda,50 -pytorch,1,3,256,256,1,7,1.1875793617218733,1.0118531063199043,1.2298297137022018,0.05518452249370455,,10,float32,cuda,50 -cuda_v1,1,3,256,256,1,7,0.4127761535346508,0.4072303418070078,0.43675173074007034,0.15876886161859766,2.87705418918339,10,float32,cuda,50 -cuda_v2,1,3,256,256,1,7,0.29986392706632614,0.29469607397913933,0.319720059633255,0.21855246358293792,3.9603942139368957,10,float32,cuda,50 -cuda_v3,1,3,256,256,1,7,0.3060135804116726,0.2915910445153713,0.36104372702538967,0.21416043010848088,3.8808060744371273,10,float32,cuda,50 -cuda_v4,1,3,256,256,1,7,0.2705792896449566,0.2661552280187607,0.29452070593833923,0.242206268210674,4.389025351053911,10,float32,cuda,50 -pytorch,1,3,256,256,2,3,4.6642600279301405,2.254175953567028,5.693626776337624,0.05620269848384325,,10,float32,cuda,50 -cuda_v1,1,3,256,256,2,3,0.47030373476445675,0.4551168531179428,0.5582175217568874,0.5573929795205436,9.917548348337382,10,float32,cuda,50 -cuda_v2,1,3,256,256,2,3,0.3628383856266737,0.35751843824982643,0.3879097755998373,0.722481441833228,12.854924431091538,10,float32,cuda,50 -cuda_v3,1,3,256,256,2,3,0.3875340800732374,0.3856392577290535,0.41215093806385994,0.6764411531250598,12.035741545746568,10,float32,cuda,50 -cuda_v4,1,3,256,256,2,3,0.33060619607567787,0.3230019938200712,0.3539799712598324,0.7929191984653353,14.108205119248465,10,float32,cuda,50 -pytorch,1,3,256,256,2,5,8.932606596499681,2.1693871822208166,25.42668771930039,0.02934686501280862,,10,float32,cuda,50 -cuda_v1,1,3,256,256,2,5,0.5238902755081654,0.5380609072744846,0.5635851062834263,0.5003795875877337,17.05052949844353,10,float32,cuda,50 -cuda_v2,1,3,256,256,2,5,0.37138083949685097,0.36642886698246,0.3967938479036093,0.7058630174759535,24.052416405223358,10,float32,cuda,50 -cuda_v3,1,3,256,256,2,5,0.3278264496475458,0.3224520478397608,0.3462827764451504,0.7996426166401076,27.24797406098057,10,float32,cuda,50 -cuda_v4,1,3,256,256,2,5,0.32294890843331814,0.3143919166177511,0.3508441150188446,0.811719727655085,27.65950391296668,10,float32,cuda,50 -pytorch,1,3,256,256,2,7,5.310848616063595,2.0857034251093864,5.400367686524987,0.049360096465016795,,10,float32,cuda,50 -cuda_v1,1,3,256,256,2,7,0.5602501425892115,0.5671817343682051,0.6339772138744593,0.4679052802887194,9.479423943596625,10,float32,cuda,50 -cuda_v2,1,3,256,256,2,7,0.42169813998043537,0.4308209754526615,0.4505240358412266,0.6216389761931654,12.593957887293481,10,float32,cuda,50 
-cuda_v3,1,3,256,256,2,7,0.38437320850789547,0.3447979688644409,0.5109140183776617,0.6820038290848132,13.816906325703249,10,float32,cuda,50
-cuda_v4,1,3,256,256,2,7,0.3148854896426201,0.31107640825212,0.32421983778476715,0.832505811231635,16.865968076493942,10,float32,cuda,50
-pytorch,1,3,256,256,3,3,10.819459995254874,5.247258115559816,43.626357009634376,0.054515105214001526,,10,float32,cuda,50
-cuda_v1,1,3,256,256,3,3,0.5695007462054491,0.5678911693394184,0.5740981083363295,1.0356860880867385,18.998148935439186,10,float32,cuda,50
-cuda_v2,1,3,256,256,3,3,0.47695184126496315,0.47551305033266544,0.48432392068207264,1.2366531565025085,22.68459634532429,10,float32,cuda,50
-cuda_v3,1,3,256,256,3,3,0.4476183373481035,0.44674682430922985,0.4503197968006134,1.3176940057781994,24.171172386176853,10,float32,cuda,50
-cuda_v4,1,3,256,256,3,3,0.43853784911334515,0.4363264888525009,0.4481939598917961,1.3449785490409363,24.671667490343484,10,float32,cuda,50
-pytorch,1,3,256,256,3,5,16.53141546063125,6.35837041772902,54.64553306810558,0.03567897748409002,,10,float32,cuda,50
-cuda_v1,1,3,256,256,3,5,0.5723217409104109,0.5713619757443666,0.5795460194349289,1.0305811536387692,28.88482872297353,10,float32,cuda,50
-cuda_v2,1,3,256,256,3,5,0.4708682466298342,0.46868249773979187,0.4833988845348358,1.2526306545866557,35.10836752945706,10,float32,cuda,50
-cuda_v3,1,3,256,256,3,5,0.4524185135960579,0.4476869944483042,0.4720529541373253,1.3037132263040514,36.540094986898204,10,float32,cuda,50
-cuda_v4,1,3,256,256,3,5,0.47651665285229683,0.4608971066772938,0.5228085909038782,1.2377825548582126,34.69220931037514,10,float32,cuda,50
-pytorch,1,3,256,256,3,7,18.527234653010964,7.752042729407549,61.81915830820799,0.03183551193940022,,10,float32,cuda,50
-cuda_v1,1,3,256,256,3,7,0.5787044391036034,0.5703659262508154,0.5973172839730978,1.0192145768116458,32.015020796642105,10,float32,cuda,50
-cuda_v2,1,3,256,256,3,7,0.4701301362365484,0.46895304694771767,0.47616218216717243,1.2545973009125009,39.40873648586716,10,float32,cuda,50
-cuda_v3,1,3,256,256,3,7,0.4477470647543669,0.44500199146568775,0.46136612072587013,1.3173151683832396,41.3787964487829,10,float32,cuda,50
-cuda_v4,1,3,256,256,3,7,0.43689489364624023,0.43615163303911686,0.44089406728744507,1.3500363784924174,42.40661752392263,10,float32,cuda,50
\ No newline at end of file
+variant,B,C,H,W,scale,ksize,fwd_mean_ms,fwd_p50_ms,fwd_p90_ms,fwd_tp,fwd_speedup_vs_pytorch,bwd_mean_ms,bwd_p50_ms,bwd_p90_ms,bwd_tp,bwd_speedup_vs_pytorch,grad_ok,eps,warmup,dtype,device,iters
+pytorch,8,8,256,256,1,3,9.288488607853651,8.991222945041955,9.583412948995829,0.05644492038852277,,20.064111375249922,19.56672000233084,21.39674376230687,0.02613063644805796,,True,1e-05,10,float32,cuda,50
+cuda_v1,8,8,256,256,1,3,3.3021508576348424,3.324655001051724,3.3440517028793693,0.158771668104685,2.8128601654821153,6.695929053239524,6.69370440300554,6.72081895172596,0.07829951539679872,2.9964641524304687,True,1e-05,10,float32,cuda,50
+cuda_v2,8,8,256,256,1,3,3.1805392680689692,3.1820274889469147,3.1898362562060356,0.16484248607259483,2.920413120223181,6.543120900169015,6.540277507156134,6.558541930280626,0.08012812356660827,3.066443625507768,True,1e-05,10,float32,cuda,50
+cuda_v3,8,8,256,256,1,3,3.024693327024579,3.04701691493392,3.0670031206682324,0.17333591981562882,3.070886071280102,6.397886727936566,6.395927630364895,6.411892082542181,0.08194705881719984,3.1360529231690477,True,1e-05,10,float32,cuda,50
+cuda_v4,8,8,256,256,1,3,3.0596842663362622,3.0616489239037037,3.07619022205472,0.17135362814013316,3.0357670266990997,6.399778164923191,6.40075805131346,6.410616356879473,0.08192283958740817,3.1351260712785547,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,1,3,3.0559819284826517,3.0527031049132347,3.0764779541641474,0.17156122394359777,3.0394448740950337,6.409975425340235,6.408171029761434,6.429016985930502,0.08179251326414738,3.1301385799283215,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,1,3,3.0050428677350283,3.0304675456136465,3.052515466697514,0.17446939131193434,3.090967089882017,6.387096201069653,6.387264467775822,6.399212614633143,0.08208550231515176,3.141351052750664,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,1,3,1.8821742571890354,1.8811896443367004,1.8879902781918645,0.27855444202228474,4.934978029996914,5.040463833138347,5.040465504862368,5.048878467641771,0.10401582420909114,3.9806081423180046,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,1,5,7.326102494262159,7.420093985274434,7.613263255916536,0.07156438234526817,,16.14620674867183,16.1027709254995,16.263780114240944,0.03247127998302929,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,1,5,3.2715308107435703,3.3216774463653564,3.3417833968997,0.16025769901914425,2.2393499918153132,6.707657393999398,6.701117963530123,6.731993076391518,0.07816260867304027,2.407130507756115,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,1,5,3.159591546282172,3.1798079144209623,3.199100005440414,0.16593537244297263,2.318686572915615,6.542924279347062,6.542460061609745,6.554703111760318,0.08013053148955596,2.467735535261782,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,1,5,3.05445936974138,3.055477049201727,3.073208383284509,0.1716467422005326,2.398494007429687,6.400615535676479,6.398052908480167,6.4109522849321365,0.08191212190103653,2.5226021870356474,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,1,5,3.0754889361560345,3.0726579716429114,3.091029403731227,0.17047305676712743,2.3820935943339285,6.441063601523638,6.430621026083827,6.473009032197297,0.08139773683898717,2.5067609555745474,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,1,5,3.048403072170913,3.052991349250078,3.072553128004074,0.17198775476453956,2.403259123159488,6.407246896997094,6.406409083865583,6.420211470685899,0.08182734463466983,2.5199913485836056,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,1,5,2.9911579517647624,3.0331225134432316,3.045725799165666,0.1752792759374923,2.449252969051604,6.385393836535513,6.382934865541756,6.403305334970355,0.08210738654836988,2.5286156440793923,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,1,5,1.8822504533454776,1.8811115296557546,1.8892435124143958,0.2785431657451006,3.8922038675782376,5.046942573972046,5.043365992605686,5.0634910352528095,0.10388229949432032,3.1992055609946117,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,1,7,9.156141332350671,9.012500522658229,10.730735887773335,0.057260802446066954,,18.346093399450183,18.29959498718381,18.474590103141963,0.02857763713422022,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,1,7,3.3304048283025622,3.330954583361745,3.3437674399465322,0.15742470571279427,2.7492577642632607,6.747524798847735,6.6985663725063205,6.975994328968227,0.07770078890106961,2.718936787395449,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,1,7,3.1805677665397525,3.180499654263258,3.187753399834037,0.16484100905367305,2.8787757420782603,6.5445497911423445,6.542496965266764,6.557909771800041,0.0801106289556529,2.803262865274632,True,1e-05,10,float32,cuda,50 
+cuda_v3,8,8,256,256,1,7,2.9923936538398266,2.959055360406637,3.0513782519847155,0.17520689476374068,3.0598050896818174,6.439674673601985,6.414378061890602,6.499357661232352,0.08141529294161429,2.8489161843308506,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,1,7,3.0570596596226096,3.056291490793228,3.0761950882151723,0.17150074201192486,2.9950810097965133,6.428387816995382,6.409163004718721,6.460378412157297,0.08155824056132495,2.853918264070309,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,1,7,3.0619724839925766,3.061673021875322,3.0807194765657187,0.17122557525937293,2.9902755103833485,6.415656288154423,6.413351511582732,6.4297555247321725,0.081720088553999,2.8595817131481276,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,1,7,3.042473620735109,3.0388199957087636,3.0632448382675648,0.17232294026375944,3.0094398419593875,6.388430343940854,6.389048998244107,6.399145186878741,0.08206835979627831,2.871768558430108,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,1,7,1.8896355107426643,1.883591990917921,1.9149431493133307,0.2774545657188377,4.845453676276503,5.0843368750065565,5.057928501628339,5.145174008794129,0.10311826554555825,3.6083551996004446,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,2,3,45.595936393365264,45.675908448174596,46.11496950965375,0.045994274180651766,,80.62302918639034,80.68823453504592,81.43389630131423,0.02601182343510869,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,2,3,11.85663956683129,11.822210508398712,11.90752275288105,0.17687574866210323,3.8456036498671153,25.430614417418838,25.363861583173275,25.43278068769723,0.08246564418685631,3.170313853336047,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,2,3,11.422737971879542,11.390020838007331,11.460354528389871,0.1835945116803661,3.9916818984741824,24.925604569725692,24.912420543842018,24.930561799556017,0.08413645470999621,3.2345465868584786,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,2,3,10.840446311049163,10.825388482771814,10.85302964784205,0.19345624154445284,4.206094019108,24.0234538866207,24.002613499760628,24.026684556156397,0.08729602370656452,3.3560132346869342,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,2,3,10.878428206779063,10.826261015608907,11.14368592388928,0.1927807915019494,4.191408494560955,23.699625409208238,23.694894975051284,23.71811883058399,0.0884888247721069,3.4018693457941795,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,2,3,10.918543636798859,10.834143031388521,11.191598395816982,0.1920725025022523,4.176008990768045,23.730977713130414,23.69663491845131,23.714118194766343,0.08837191730367012,3.3973749485164024,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,2,3,10.793533944524825,10.768141131848097,10.864380979910493,0.19429706811306324,4.224375133085532,23.638727851212025,23.631693911738694,23.66107157431543,0.08871678768840655,3.4106331649423582,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,2,3,7.474436634220183,7.435761974193156,7.621926814317703,0.2805765976259157,6.100250577363058,20.239599947817624,20.2378734247759,20.263943052850664,0.10361627726866854,3.9834299785694967,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,2,5,38.853937541134655,38.41745341196656,40.494335955008864,0.05397527593644133,,66.22083507943898,66.26046146266162,67.06938743591309,0.031669066049744635,,True,1e-05,10,float32,cuda,50 
+cuda_v1,8,8,256,256,2,5,11.842856714501977,11.828812537714839,11.868507391773164,0.17708159868488205,3.2807909846242334,25.389751312322915,25.350986630655825,25.415705051273108,0.08259836712075817,2.608171866862623,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,2,5,11.411372288130224,11.404957505874336,11.43825901672244,0.1837773711213853,3.4048435683365956,24.930950277484953,24.917138973250985,24.93941232096404,0.08411841412615266,2.656169714446977,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,2,5,10.863104425370693,10.828325641341507,11.124962358735502,0.19305273316733643,3.5766882117409824,24.023548257537186,24.010553024709225,24.034774862229824,0.08729568078445848,2.756496849238861,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,2,5,10.90485557448119,10.82881959155202,11.15186910610646,0.19231359697304146,3.562994234610317,23.72426231391728,23.71151139959693,23.792543495073915,0.08839693189405325,2.7912705652639866,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,2,5,10.839254357852042,10.825947392731905,10.844685533083975,0.19347751522048257,3.584558149333266,23.702450119890273,23.698252509348094,23.726728814654052,0.08847827922397537,2.7938392336861733,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,2,5,10.759746134281158,10.75667655095458,10.77785084489733,0.19490720076734483,3.6110459351214446,23.636622526682913,23.63265841268003,23.660433711484075,0.08872468973232393,2.8016200286095696,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,2,5,7.5289920112118125,7.6012545032426715,7.6225581811740994,0.2785435283869371,5.160576274124775,20.23725756444037,20.235952455550432,20.263770665042102,0.10362827044732498,3.2722237619685215,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,2,7,43.63379757385701,43.68262959178537,44.327281741425395,0.048062559680950134,,75.68363416008651,75.79991698730737,76.05113848112524,0.02770945163077252,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,2,7,11.998974299058318,11.853986419737339,12.264510500244796,0.1747776057962375,3.6364606245786693,25.40259242989123,25.356607511639595,25.428785989060998,0.08255661329795148,2.9793665496530024,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,2,7,11.41712686046958,11.413594475015998,11.448350944556296,0.18368474184701714,3.821784421519722,24.922909317538142,24.915332440286875,24.959108070470393,0.0841455535258977,3.036709446550378,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,2,7,10.88796194177121,10.827695950865746,11.144098523072898,0.19261198847089686,4.007526643389317,24.021303891204298,24.00768455117941,24.02521837502718,0.08730383702309759,3.1506880102290795,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,2,7,10.852394071407616,10.832656407728791,10.89512084145099,0.1932432591556259,4.020660997633443,23.70872948784381,23.69790489319712,23.748202505521476,0.08845484533767504,3.192226483451668,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,2,7,10.89978810865432,10.831900988705456,11.149773467332125,0.19240300628733165,4.003178514930233,23.71836454141885,23.704793537035584,23.77322791144252,0.08841891254086219,3.190929713046693,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,2,7,10.774082760326564,10.765377548523247,10.784391174092889,0.19464784582148828,4.049885131245684,23.65874081850052,23.644065018743277,23.685648757964373,0.08864174201359364,3.198971354422374,True,1e-05,10,float32,cuda,50 
+cuda_v7,8,8,256,256,2,7,7.470769472420216,7.435820531100035,7.61996959336102,0.2807143237041433,5.840602863592509,20.250088809989393,20.24794602766633,20.27125945314765,0.10356260753609496,3.7374470240718987,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,3,3,101.28841781057417,101.60925146192312,107.4452179018408,0.04658570152437898,,175.37028575781733,175.02894543576986,177.75019966065884,0.026906450996586024,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,3,3,26.66485572233796,26.59923385363072,27.004664088599384,0.17695921737341677,3.798573630598038,58.96571567747742,58.67917649447918,60.70718690752983,0.08002263596373708,2.974105948565666,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,3,3,25.672803041525185,25.64728946890682,25.70093993563205,0.18379730457822557,3.9453587380677684,56.8397456035018,56.80113600101322,56.884816801175475,0.08301571285901913,3.085346070708189,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,3,3,24.374681571498513,24.3739370489493,24.404809717088938,0.19358579049161745,4.155476555189605,53.81168487481773,53.78810840193182,53.942976240068674,0.08768712615070265,3.2589629216364755,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,3,3,24.441681993193924,24.361361982300878,24.429271719418466,0.1930551261289607,4.144085412729743,53.07756224647164,53.07193798944354,53.1614552019164,0.08889993813371996,3.3040380593114964,True,1e-05,10,float32,cuda,50 +cuda_v5,8,8,256,256,3,3,24.375403551384807,24.365137447603047,24.41326177213341,0.19358005663590047,4.155353473309771,53.05147184524685,53.03926358465105,53.125715954229236,0.08894365859941288,3.305662965758595,True,1e-05,10,float32,cuda,50 +cuda_v6,8,8,256,256,3,3,24.276557215489447,24.22345709055662,24.269587476737797,0.1943682523891544,4.17227273667734,52.906081513501704,52.899461006745696,52.96852015890181,0.08918808320355022,3.314747203741834,True,1e-05,10,float32,cuda,50 +cuda_v7,8,8,256,256,3,3,16.926095513626933,16.923529445193708,16.972190444357693,0.2787761652532999,5.9841572871326685,46.062268051318824,46.03102651890367,46.21587647125125,0.10243941949933792,3.807243828341975,True,1e-05,10,float32,cuda,50 +pytorch,8,8,256,256,3,5,83.52379720192403,83.7456090375781,85.80605688039213,0.05649398324878007,,145.5188752664253,145.75104543473572,148.14186077564955,0.03242597904471772,,True,1e-05,10,float32,cuda,50 +cuda_v1,8,8,256,256,3,5,26.591138602234423,26.581451063975692,26.635813480243087,0.17744979147314519,3.1410387667606536,57.973297983407974,57.822484406642616,58.44167503528297,0.08139250593179065,2.510101724902264,True,1e-05,10,float32,cuda,50 +cuda_v2,8,8,256,256,3,5,25.637969239614904,25.634926394559443,25.683102500624955,0.1840470263420472,3.2578164214685903,56.7909628059715,56.76805193070322,56.88152206130326,0.08308702242152946,2.562359714935563,True,1e-05,10,float32,cuda,50 +cuda_v3,8,8,256,256,3,5,24.41519634798169,24.391590617597103,24.450342124328017,0.19326455264776388,3.4209758550161573,53.789559081196785,53.765413467772305,53.912423201836646,0.08772319536728602,2.705336830271478,True,1e-05,10,float32,cuda,50 +cuda_v4,8,8,256,256,3,5,24.383281343616545,24.362614494748414,24.392798193730414,0.19351751446018198,3.4254535320690236,53.027258049696684,53.017987054772675,53.09310944285244,0.08898427287297746,2.744227791865955,True,1e-05,10,float32,cuda,50 
+cuda_v5,8,8,256,256,3,5,24.393998621962965,24.367429548874497,24.428526940755546,0.19343249432471676,3.423948590647372,53.11096484772861,53.071493515744805,53.206104156561196,0.08884402709550474,2.73990268645343,True,1e-05,10,float32,cuda,50
+cuda_v6,8,8,256,256,3,5,24.267674176953733,24.23390606418252,24.288278119638562,0.194439399737825,3.441771823409596,52.94822241179645,52.909665973857045,53.120820759795606,0.08911709940518672,2.7483240916885783,True,1e-05,10,float32,cuda,50
+cuda_v7,8,8,256,256,3,5,17.1291301259771,16.946400050073862,18.221904919482768,0.2754717820050909,4.876126025527497,46.0351287573576,46.020424575544894,46.14793793298304,0.1024998110653888,3.1610398231626027,True,1e-05,10,float32,cuda,50
+pytorch,8,8,256,256,3,7,90.59936547186226,89.40491802059114,93.76176362857223,0.05208195416628463,,158.36064806673676,158.1671329913661,160.4503425071016,0.02979649336880384,,True,1e-05,10,float32,cuda,50
+cuda_v1,8,8,256,256,3,7,26.627318738028407,26.58525703009218,26.68023353908211,0.17720867979324692,3.4024967501692434,58.71118636801839,58.641764568164945,60.120047559030354,0.08036955632990492,2.6972823726314648,True,1e-05,10,float32,cuda,50
+cuda_v2,8,8,256,256,3,7,25.703592961654067,25.652661453932524,26.049211691133678,0.18357713674658,3.5247743615852856,56.822463613934815,56.79537542164326,56.92225047387183,0.0830409612659392,2.786937383473484,True,1e-05,10,float32,cuda,50
+cuda_v3,8,8,256,256,3,7,24.43074162583798,24.387317011132836,24.824217706918716,0.1931415784369645,3.7084165048859696,53.77089409157634,53.76591603271663,53.81281706504524,0.08775364590300178,2.9450997745552696,True,1e-05,10,float32,cuda,50
+cuda_v4,8,8,256,256,3,7,24.416185086593032,24.371870909817517,24.417825369164348,0.1932567263585738,3.7106274035254803,53.0829459335655,53.052632487379014,53.19422874599695,0.0888909218773469,2.983267889180994,True,1e-05,10,float32,cuda,50
+cuda_v5,8,8,256,256,3,7,24.362534414976835,24.361073039472103,24.400543165393174,0.19368231234182484,3.718798870784413,53.057269509881735,53.02855698391795,53.15046065952629,0.0889339395635725,2.984711605583899,True,1e-05,10,float32,cuda,50
+cuda_v6,8,8,256,256,3,7,24.26755757071078,24.222337524406612,24.37746999785304,0.1944403340241791,3.733353272486279,52.89180411491543,52.88355250377208,52.94667130801827,0.08921215827216153,2.9940489025988666,True,1e-05,10,float32,cuda,50
+cuda_v7,8,8,256,256,3,7,16.87394628766924,16.87477866653353,16.914236755110323,0.27963772786500724,5.369186551107395,46.165161593817174,46.10549903009087,46.407426544465125,0.10221110112245234,3.4303063738857515,True,1e-05,10,float32,cuda,50
diff --git a/test/test_speed.py b/test/test_speed.py
index 26707ee..75fb61f 100644
--- a/test/test_speed.py
+++ b/test/test_speed.py
@@ -185,7 +185,7 @@ def parent_main(args):
     print(f"[Cfg] dtype={args.dtype}, warmup={args.warmup}, iters={args.iters}")
     print(f"[Grid] B={Bs} C={Cs} H={Hs} W={Ws} scale={Ss} ksize={Ks}\n")

-    variants = ["pytorch", "cuda_v1", "cuda_v2", "cuda_v3","cuda_v4", "cuda_v5", "cuda_v6", "cuda_v7"]
+    variants = ["pytorch", "cuda"]
     results = []

     cache_root = PROJECT_ROOT / ".torch_ext_cache_grid"
diff --git a/test/test_v5.csv b/test/test_v5.csv
deleted file mode 100644
index c8bf52e..0000000
--- a/test/test_v5.csv
+++ /dev/null
@@ -1,37 +0,0 @@
-variant,B,C,H,W,scale,ksize,mean_ms,p50_ms,p90_ms,tp,speedup_vs_pytorch,warmup,dtype,device,iters
-pytorch,8,8,256,256,1,3,9.018746092915535,8.91665113158524,9.489735681563616,0.058133136757430405,,10,float32,cuda,50
-cuda_v3,8,8,256,256,1,3,3.061873996630311,3.0620904872193933,3.0763483606278896,0.17123108285219948,2.9454987706355484,10,float32,cuda,50 -cuda_v4,8,8,256,256,1,3,3.043683832511306,3.0511280056089163,3.0741054099053144,0.1722544222234201,2.963102145032679,10,float32,cuda,50 -cuda_v5,8,8,256,256,1,3,3.0455162096768618,3.054927452467382,3.0771585879847407,0.17215078295565156,2.9613193534348157,10,float32,cuda,50 -pytorch,8,8,256,256,1,5,8.164601628668606,8.02653655409813,9.788747737184167,0.06421476807380928,,10,float32,cuda,50 -cuda_v3,8,8,256,256,1,5,3.0641747498884797,3.0675254529342055,3.079495718702674,0.17110251300748477,2.664535248508838,10,float32,cuda,50 -cuda_v4,8,8,256,256,1,5,3.0497499415650964,3.0503239249810576,3.0591003131121397,0.17191179934278203,2.6771380556133817,10,float32,cuda,50 -cuda_v5,8,8,256,256,1,5,3.067394499666989,3.0633179703727365,3.082340001128614,0.17092291195570677,2.6617383677107704,10,float32,cuda,50 -pytorch,8,8,256,256,1,7,9.292588303796947,9.34083794709295,10.928183235228062,0.05642001806813891,,10,float32,cuda,50 -cuda_v3,8,8,256,256,1,7,3.0567877739667892,3.0584499472752213,3.077001217752695,0.17151599612675505,3.0399847784453704,10,float32,cuda,50 -cuda_v4,8,8,256,256,1,7,3.011856614612043,3.047129837796092,3.0612722039222717,0.1740746878375329,3.085335556385351,10,float32,cuda,50 -cuda_v5,8,8,256,256,1,7,3.0088233342394233,3.04994557518512,3.0606051674112678,0.17425017748093563,3.088445970905084,10,float32,cuda,50 -pytorch,8,8,256,256,2,3,48.46728646196425,48.91097836662084,52.44047886226326,0.04326943291215169,,10,float32,cuda,50 -cuda_v3,8,8,256,256,2,3,10.859722616150975,10.825826553627849,11.112747946754098,0.1931128514167608,4.463031715918962,10,float32,cuda,50 -cuda_v4,8,8,256,256,2,3,10.890666269697249,10.818996001034975,11.145172524265945,0.19256415981042627,4.45035090248079,10,float32,cuda,50 -cuda_v5,8,8,256,256,2,3,10.889667607843876,10.824134456925094,11.142246099188924,0.19258181934675508,4.450759032080377,10,float32,cuda,50 -pytorch,8,8,256,256,2,5,39.19886094983667,39.11385603714734,40.184705122374,0.053500330090809387,,10,float32,cuda,50 -cuda_v3,8,8,256,256,2,5,10.921839205548167,10.843515512533486,11.04965121485293,0.19201454631694917,3.589034796440155,10,float32,cuda,50 -cuda_v4,8,8,256,256,2,5,10.950461719185114,10.84580144379288,11.239956971257925,0.191512655245012,3.57965371278917,10,float32,cuda,50 -cuda_v5,8,8,256,256,2,5,10.826412132009864,10.821027914062142,10.844661109149456,0.19370701710121166,3.6206695691862247,10,float32,cuda,50 -pytorch,8,8,256,256,2,7,44.36711141373962,43.92930108588189,47.58514082059264,0.04726816628748459,,10,float32,cuda,50 -cuda_v3,8,8,256,256,2,7,10.912643321789801,10.832935920916498,11.155167641118169,0.19217635344248038,4.065661279806486,10,float32,cuda,50 -cuda_v4,8,8,256,256,2,7,10.899619800038636,10.823742602951825,11.146645108237863,0.19240597731606807,4.070519176603055,10,float32,cuda,50 -cuda_v5,8,8,256,256,2,7,10.856493567116559,10.844919481314719,10.878617619164288,0.19317028901045027,4.086688868689981,10,float32,cuda,50 -pytorch,8,8,256,256,3,3,101.29073102492839,100.59209598693997,105.4747574031353,0.046584637629268566,,10,float32,cuda,50 -cuda_v3,8,8,256,256,3,3,24.37616023235023,24.372862884774804,24.413138814270496,0.1935740475539636,4.15531938005982,10,float32,cuda,50 -cuda_v4,8,8,256,256,3,3,24.461542363278568,24.370684404857457,24.882999388501048,0.19289838432606382,4.14081538770773,10,float32,cuda,50 
-cuda_v5,8,8,256,256,3,3,24.447299065068364,24.368930491618812,24.460403178818524,0.1930107693058896,4.143227877866399,10,float32,cuda,50
-pytorch,8,8,256,256,3,5,86.84147734194994,86.47282503079623,90.13091719243675,0.054335694698282394,,10,float32,cuda,50
-cuda_v3,8,8,256,256,3,5,24.3759097578004,24.37639352865517,24.403270613402128,0.19357603662320866,3.562594307445706,10,float32,cuda,50
-cuda_v4,8,8,256,256,3,5,24.360263156704605,24.359582574106753,24.388629896566272,0.19370037054387548,3.5648825623646356,10,float32,cuda,50
-cuda_v5,8,8,256,256,3,5,24.40255253110081,24.365780991502106,24.539125943556428,0.1933646897793254,3.558704657280067,10,float32,cuda,50
-pytorch,8,8,256,256,3,7,92.940255952999,93.59969745855778,96.69805229641497,0.05077016360258625,,10,float32,cuda,50
-cuda_v3,8,8,256,256,3,7,24.371705148369074,24.363814387470484,24.390139686875045,0.19360943238375597,3.813449054434275,10,float32,cuda,50
-cuda_v4,8,8,256,256,3,7,24.42510688211769,24.368971004150808,24.507023952901363,0.19318613518349081,3.8051115355013323,10,float32,cuda,50
-cuda_v5,8,8,256,256,3,7,24.401678266003728,24.37236753758043,24.425271223299205,0.19337161766344219,3.808764911161164,10,float32,cuda,50
diff --git a/test/test_v6.csv b/test/test_v6.csv
deleted file mode 100644
index 122bef0..0000000
--- a/test/test_v6.csv
+++ /dev/null
@@ -1,37 +0,0 @@
-variant,B,C,H,W,scale,ksize,mean_ms,p50_ms,p90_ms,tp,speedup_vs_pytorch,warmup,dtype,device,iters
-pytorch,8,8,256,256,1,3,9.205426471307874,9.001746540889144,9.537389921024442,0.0569542325533899,,10,float32,cuda,50
-cuda_v3,8,8,256,256,1,3,3.0636826157569885,3.0644324142485857,3.076496347784996,0.17112999803031376,3.0046932485639855,10,float32,cuda,50
-cuda_v4,8,8,256,256,1,3,2.9942896962165833,2.9994455398991704,3.0572090996429324,0.17509595035592612,3.0743272713189156,10,float32,cuda,50
-cuda_v6,8,8,256,256,1,3,3.0439832201227546,3.0431936029344797,3.0574772041291,0.17223748032975592,3.0241383758142555,10,float32,cuda,50
-pytorch,8,8,256,256,1,5,7.813363987952471,7.445647963322699,8.981779753230512,0.0671014432206674,,10,float32,cuda,50
-cuda_v3,8,8,256,256,1,5,3.069853764027357,3.0687025282531977,3.0871906550601125,0.17078598535983158,2.5451909402036397,10,float32,cuda,50
-cuda_v4,8,8,256,256,1,5,3.0559902312234044,3.0536623671650887,3.0731556937098503,0.17156075783334943,2.5567372264879746,10,float32,cuda,50
-cuda_v6,8,8,256,256,1,5,3.0378469452261925,3.0349043663591146,3.0528080882504582,0.17258539006512144,2.572007125056362,10,float32,cuda,50
-pytorch,8,8,256,256,1,7,8.450384107418358,8.219452458433807,9.284918219782412,0.062043096897777934,,10,float32,cuda,50
-cuda_v3,8,8,256,256,1,7,3.0614592181518674,3.0635460279881954,3.0748489312827587,0.17125428190955963,2.7602471583860133,10,float32,cuda,50
-cuda_v4,8,8,256,256,1,7,3.06079070083797,3.0586729990318418,3.079405124299228,0.17129168611772855,2.760850033001227,10,float32,cuda,50
-cuda_v6,8,8,256,256,1,7,3.066543652676046,3.0657140305265784,3.081072401255369,0.17097033643805315,2.755670573951901,10,float32,cuda,50
-pytorch,8,8,256,256,2,3,46.00589778739959,45.81692651845515,47.76265760883689,0.04558441636529442,,10,float32,cuda,50
-cuda_v3,8,8,256,256,2,3,10.883753076195717,10.826129000633955,11.165698524564505,0.1926864736197262,4.227025132352634,10,float32,cuda,50
-cuda_v4,8,8,256,256,2,3,10.945874289609492,10.827564052306116,11.150588723830879,0.1915929184378399,4.20303546068232,10,float32,cuda,50
-cuda_v6,8,8,256,256,2,3,10.81716647837311,10.752257890999317,11.076831934042275,0.19387258245427405,4.253045183263078,10,float32,cuda,50
-pytorch,8,8,256,256,2,5,38.73277612961829,38.58403104823083,40.112837520428,0.05414411796825335,,10,float32,cuda,50
-cuda_v3,8,8,256,256,2,5,10.870220344513655,10.826208395883441,11.142217181622982,0.19292635600146416,3.5632006437815438,10,float32,cuda,50
-cuda_v4,8,8,256,256,2,5,10.877257627435029,10.877057909965515,10.957135050557554,0.19280153801914962,3.560895352145105,10,float32,cuda,50
-cuda_v6,8,8,256,256,2,5,10.755426031537354,10.749659966677427,10.778588545508683,0.19498548861297296,3.601231231198558,10,float32,cuda,50
-pytorch,8,8,256,256,2,7,44.22167818527669,44.16610090993345,45.8656846312806,0.04742361859750118,,10,float32,cuda,50
-cuda_v3,8,8,256,256,2,7,10.892827231436968,10.832248604856431,11.152946949005127,0.19252595817801713,4.059706194334179,10,float32,cuda,50
-cuda_v4,8,8,256,256,2,7,10.838612425141037,10.817955480888486,10.870237019844353,0.19348897420997238,4.080012869793271,10,float32,cuda,50
-cuda_v6,8,8,256,256,2,7,10.811590934172273,10.752530070021749,11.083307187072933,0.1939725626661953,4.09021007680793,10,float32,cuda,50
-pytorch,8,8,256,256,3,3,100.32575480174273,100.2145285019651,101.78962631616741,0.04703270869304274,,10,float32,cuda,50
-cuda_v3,8,8,256,256,3,3,24.35199290048331,24.352667038328946,24.379127472639084,0.19376615373053724,4.119817019154584,10,float32,cuda,50
-cuda_v4,8,8,256,256,3,3,24.405628656968474,24.35805497225374,24.45684978738427,0.19334031777348676,4.110762980616645,10,float32,cuda,50
-cuda_v6,8,8,256,256,3,3,24.225411019288003,24.226056993938982,24.253131332807243,0.19477861474643748,4.141343761798901,10,float32,cuda,50
-pytorch,8,8,256,256,3,5,83.76444775145501,83.35559000261128,86.63466500584036,0.05633167921074292,,10,float32,cuda,50
-cuda_v3,8,8,256,256,3,5,24.367065099067986,24.362630443647504,24.40121565014124,0.19364630007002692,3.437609224208906,10,float32,cuda,50
-cuda_v4,8,8,256,256,3,5,24.346372256986797,24.343705968931317,24.382443260401487,0.1938108869031148,3.4405309697594357,10,float32,cuda,50
-cuda_v6,8,8,256,256,3,5,24.30050970055163,24.209839990362525,24.665594031102955,0.19417666782079415,3.4470243128091065,10,float32,cuda,50
-pytorch,8,8,256,256,3,7,93.56231243349612,94.02274643070996,95.42696874123067,0.0504326141292624,,10,float32,cuda,50
-cuda_v3,8,8,256,256,3,7,24.370713336393237,24.374024011194706,24.39629177097231,0.19361731168343105,3.8391290046392603,10,float32,cuda,50
-cuda_v4,8,8,256,256,3,7,24.427954778075218,24.36559647321701,24.452975136227906,0.19316361287171985,3.830132865542675,10,float32,cuda,50
-cuda_v6,8,8,256,256,3,7,24.22358512878418,24.221764993853867,24.25856285262853,0.19479329648826565,3.862446947306687,10,float32,cuda,50
diff --git a/test/test_v7.csv b/test/test_v7.csv
deleted file mode 100644
index e01b4d1..0000000
--- a/test/test_v7.csv
+++ /dev/null
@@ -1,73 +0,0 @@
-variant,B,C,H,W,scale,ksize,fwd_mean_ms,fwd_p50_ms,fwd_p90_ms,fwd_tp,fwd_speedup_vs_pytorch,bwd_mean_ms,bwd_p50_ms,bwd_p90_ms,bwd_tp,bwd_speedup_vs_pytorch,grad_ok,eps,warmup,dtype,device,iters
-pytorch,8,8,256,256,1,3,9.288488607853651,8.991222945041955,9.583412948995829,0.05644492038852277,,20.064111375249922,19.56672000233084,21.39674376230687,0.02613063644805796,,True,1e-05,10,float32,cuda,50
-cuda_v1,8,8,256,256,1,3,3.3021508576348424,3.324655001051724,3.3440517028793693,0.158771668104685,2.8128601654821153,6.695929053239524,6.69370440300554,6.72081895172596,0.07829951539679872,2.9964641524304687,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,1,3,3.1805392680689692,3.1820274889469147,3.1898362562060356,0.16484248607259483,2.920413120223181,6.543120900169015,6.540277507156134,6.558541930280626,0.08012812356660827,3.066443625507768,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,1,3,3.024693327024579,3.04701691493392,3.0670031206682324,0.17333591981562882,3.070886071280102,6.397886727936566,6.395927630364895,6.411892082542181,0.08194705881719984,3.1360529231690477,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,1,3,3.0596842663362622,3.0616489239037037,3.07619022205472,0.17135362814013316,3.0357670266990997,6.399778164923191,6.40075805131346,6.410616356879473,0.08192283958740817,3.1351260712785547,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,1,3,3.0559819284826517,3.0527031049132347,3.0764779541641474,0.17156122394359777,3.0394448740950337,6.409975425340235,6.408171029761434,6.429016985930502,0.08179251326414738,3.1301385799283215,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,1,3,3.0050428677350283,3.0304675456136465,3.052515466697514,0.17446939131193434,3.090967089882017,6.387096201069653,6.387264467775822,6.399212614633143,0.08208550231515176,3.141351052750664,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,1,3,1.8821742571890354,1.8811896443367004,1.8879902781918645,0.27855444202228474,4.934978029996914,5.040463833138347,5.040465504862368,5.048878467641771,0.10401582420909114,3.9806081423180046,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,1,5,7.326102494262159,7.420093985274434,7.613263255916536,0.07156438234526817,,16.14620674867183,16.1027709254995,16.263780114240944,0.03247127998302929,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,1,5,3.2715308107435703,3.3216774463653564,3.3417833968997,0.16025769901914425,2.2393499918153132,6.707657393999398,6.701117963530123,6.731993076391518,0.07816260867304027,2.407130507756115,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,1,5,3.159591546282172,3.1798079144209623,3.199100005440414,0.16593537244297263,2.318686572915615,6.542924279347062,6.542460061609745,6.554703111760318,0.08013053148955596,2.467735535261782,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,1,5,3.05445936974138,3.055477049201727,3.073208383284509,0.1716467422005326,2.398494007429687,6.400615535676479,6.398052908480167,6.4109522849321365,0.08191212190103653,2.5226021870356474,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,1,5,3.0754889361560345,3.0726579716429114,3.091029403731227,0.17047305676712743,2.3820935943339285,6.441063601523638,6.430621026083827,6.473009032197297,0.08139773683898717,2.5067609555745474,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,1,5,3.048403072170913,3.052991349250078,3.072553128004074,0.17198775476453956,2.403259123159488,6.407246896997094,6.406409083865583,6.420211470685899,0.08182734463466983,2.5199913485836056,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,1,5,2.9911579517647624,3.0331225134432316,3.045725799165666,0.1752792759374923,2.449252969051604,6.385393836535513,6.382934865541756,6.403305334970355,0.08210738654836988,2.5286156440793923,True,1e-05,10,float32,cuda,50 
-cuda_v7,8,8,256,256,1,5,1.8822504533454776,1.8811115296557546,1.8892435124143958,0.2785431657451006,3.8922038675782376,5.046942573972046,5.043365992605686,5.0634910352528095,0.10388229949432032,3.1992055609946117,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,1,7,9.156141332350671,9.012500522658229,10.730735887773335,0.057260802446066954,,18.346093399450183,18.29959498718381,18.474590103141963,0.02857763713422022,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,1,7,3.3304048283025622,3.330954583361745,3.3437674399465322,0.15742470571279427,2.7492577642632607,6.747524798847735,6.6985663725063205,6.975994328968227,0.07770078890106961,2.718936787395449,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,1,7,3.1805677665397525,3.180499654263258,3.187753399834037,0.16484100905367305,2.8787757420782603,6.5445497911423445,6.542496965266764,6.557909771800041,0.0801106289556529,2.803262865274632,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,1,7,2.9923936538398266,2.959055360406637,3.0513782519847155,0.17520689476374068,3.0598050896818174,6.439674673601985,6.414378061890602,6.499357661232352,0.08141529294161429,2.8489161843308506,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,1,7,3.0570596596226096,3.056291490793228,3.0761950882151723,0.17150074201192486,2.9950810097965133,6.428387816995382,6.409163004718721,6.460378412157297,0.08155824056132495,2.853918264070309,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,1,7,3.0619724839925766,3.061673021875322,3.0807194765657187,0.17122557525937293,2.9902755103833485,6.415656288154423,6.413351511582732,6.4297555247321725,0.081720088553999,2.8595817131481276,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,1,7,3.042473620735109,3.0388199957087636,3.0632448382675648,0.17232294026375944,3.0094398419593875,6.388430343940854,6.389048998244107,6.399145186878741,0.08206835979627831,2.871768558430108,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,1,7,1.8896355107426643,1.883591990917921,1.9149431493133307,0.2774545657188377,4.845453676276503,5.0843368750065565,5.057928501628339,5.145174008794129,0.10311826554555825,3.6083551996004446,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,2,3,45.595936393365264,45.675908448174596,46.11496950965375,0.045994274180651766,,80.62302918639034,80.68823453504592,81.43389630131423,0.02601182343510869,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,2,3,11.85663956683129,11.822210508398712,11.90752275288105,0.17687574866210323,3.8456036498671153,25.430614417418838,25.363861583173275,25.43278068769723,0.08246564418685631,3.170313853336047,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,2,3,11.422737971879542,11.390020838007331,11.460354528389871,0.1835945116803661,3.9916818984741824,24.925604569725692,24.912420543842018,24.930561799556017,0.08413645470999621,3.2345465868584786,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,2,3,10.840446311049163,10.825388482771814,10.85302964784205,0.19345624154445284,4.206094019108,24.0234538866207,24.002613499760628,24.026684556156397,0.08729602370656452,3.3560132346869342,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,2,3,10.878428206779063,10.826261015608907,11.14368592388928,0.1927807915019494,4.191408494560955,23.699625409208238,23.694894975051284,23.71811883058399,0.0884888247721069,3.4018693457941795,True,1e-05,10,float32,cuda,50 
-cuda_v5,8,8,256,256,2,3,10.918543636798859,10.834143031388521,11.191598395816982,0.1920725025022523,4.176008990768045,23.730977713130414,23.69663491845131,23.714118194766343,0.08837191730367012,3.3973749485164024,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,2,3,10.793533944524825,10.768141131848097,10.864380979910493,0.19429706811306324,4.224375133085532,23.638727851212025,23.631693911738694,23.66107157431543,0.08871678768840655,3.4106331649423582,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,2,3,7.474436634220183,7.435761974193156,7.621926814317703,0.2805765976259157,6.100250577363058,20.239599947817624,20.2378734247759,20.263943052850664,0.10361627726866854,3.9834299785694967,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,2,5,38.853937541134655,38.41745341196656,40.494335955008864,0.05397527593644133,,66.22083507943898,66.26046146266162,67.06938743591309,0.031669066049744635,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,2,5,11.842856714501977,11.828812537714839,11.868507391773164,0.17708159868488205,3.2807909846242334,25.389751312322915,25.350986630655825,25.415705051273108,0.08259836712075817,2.608171866862623,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,2,5,11.411372288130224,11.404957505874336,11.43825901672244,0.1837773711213853,3.4048435683365956,24.930950277484953,24.917138973250985,24.93941232096404,0.08411841412615266,2.656169714446977,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,2,5,10.863104425370693,10.828325641341507,11.124962358735502,0.19305273316733643,3.5766882117409824,24.023548257537186,24.010553024709225,24.034774862229824,0.08729568078445848,2.756496849238861,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,2,5,10.90485557448119,10.82881959155202,11.15186910610646,0.19231359697304146,3.562994234610317,23.72426231391728,23.71151139959693,23.792543495073915,0.08839693189405325,2.7912705652639866,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,2,5,10.839254357852042,10.825947392731905,10.844685533083975,0.19347751522048257,3.584558149333266,23.702450119890273,23.698252509348094,23.726728814654052,0.08847827922397537,2.7938392336861733,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,2,5,10.759746134281158,10.75667655095458,10.77785084489733,0.19490720076734483,3.6110459351214446,23.636622526682913,23.63265841268003,23.660433711484075,0.08872468973232393,2.8016200286095696,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,2,5,7.5289920112118125,7.6012545032426715,7.6225581811740994,0.2785435283869371,5.160576274124775,20.23725756444037,20.235952455550432,20.263770665042102,0.10362827044732498,3.2722237619685215,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,2,7,43.63379757385701,43.68262959178537,44.327281741425395,0.048062559680950134,,75.68363416008651,75.79991698730737,76.05113848112524,0.02770945163077252,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,2,7,11.998974299058318,11.853986419737339,12.264510500244796,0.1747776057962375,3.6364606245786693,25.40259242989123,25.356607511639595,25.428785989060998,0.08255661329795148,2.9793665496530024,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,2,7,11.41712686046958,11.413594475015998,11.448350944556296,0.18368474184701714,3.821784421519722,24.922909317538142,24.915332440286875,24.959108070470393,0.0841455535258977,3.036709446550378,True,1e-05,10,float32,cuda,50 
-cuda_v3,8,8,256,256,2,7,10.88796194177121,10.827695950865746,11.144098523072898,0.19261198847089686,4.007526643389317,24.021303891204298,24.00768455117941,24.02521837502718,0.08730383702309759,3.1506880102290795,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,2,7,10.852394071407616,10.832656407728791,10.89512084145099,0.1932432591556259,4.020660997633443,23.70872948784381,23.69790489319712,23.748202505521476,0.08845484533767504,3.192226483451668,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,2,7,10.89978810865432,10.831900988705456,11.149773467332125,0.19240300628733165,4.003178514930233,23.71836454141885,23.704793537035584,23.77322791144252,0.08841891254086219,3.190929713046693,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,2,7,10.774082760326564,10.765377548523247,10.784391174092889,0.19464784582148828,4.049885131245684,23.65874081850052,23.644065018743277,23.685648757964373,0.08864174201359364,3.198971354422374,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,2,7,7.470769472420216,7.435820531100035,7.61996959336102,0.2807143237041433,5.840602863592509,20.250088809989393,20.24794602766633,20.27125945314765,0.10356260753609496,3.7374470240718987,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,3,3,101.28841781057417,101.60925146192312,107.4452179018408,0.04658570152437898,,175.37028575781733,175.02894543576986,177.75019966065884,0.026906450996586024,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,3,3,26.66485572233796,26.59923385363072,27.004664088599384,0.17695921737341677,3.798573630598038,58.96571567747742,58.67917649447918,60.70718690752983,0.08002263596373708,2.974105948565666,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,3,3,25.672803041525185,25.64728946890682,25.70093993563205,0.18379730457822557,3.9453587380677684,56.8397456035018,56.80113600101322,56.884816801175475,0.08301571285901913,3.085346070708189,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,3,3,24.374681571498513,24.3739370489493,24.404809717088938,0.19358579049161745,4.155476555189605,53.81168487481773,53.78810840193182,53.942976240068674,0.08768712615070265,3.2589629216364755,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,3,3,24.441681993193924,24.361361982300878,24.429271719418466,0.1930551261289607,4.144085412729743,53.07756224647164,53.07193798944354,53.1614552019164,0.08889993813371996,3.3040380593114964,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,3,3,24.375403551384807,24.365137447603047,24.41326177213341,0.19358005663590047,4.155353473309771,53.05147184524685,53.03926358465105,53.125715954229236,0.08894365859941288,3.305662965758595,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,3,3,24.276557215489447,24.22345709055662,24.269587476737797,0.1943682523891544,4.17227273667734,52.906081513501704,52.899461006745696,52.96852015890181,0.08918808320355022,3.314747203741834,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,3,3,16.926095513626933,16.923529445193708,16.972190444357693,0.2787761652532999,5.9841572871326685,46.062268051318824,46.03102651890367,46.21587647125125,0.10243941949933792,3.807243828341975,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,3,5,83.52379720192403,83.7456090375781,85.80605688039213,0.05649398324878007,,145.5188752664253,145.75104543473572,148.14186077564955,0.03242597904471772,,True,1e-05,10,float32,cuda,50 
-cuda_v1,8,8,256,256,3,5,26.591138602234423,26.581451063975692,26.635813480243087,0.17744979147314519,3.1410387667606536,57.973297983407974,57.822484406642616,58.44167503528297,0.08139250593179065,2.510101724902264,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,3,5,25.637969239614904,25.634926394559443,25.683102500624955,0.1840470263420472,3.2578164214685903,56.7909628059715,56.76805193070322,56.88152206130326,0.08308702242152946,2.562359714935563,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,3,5,24.41519634798169,24.391590617597103,24.450342124328017,0.19326455264776388,3.4209758550161573,53.789559081196785,53.765413467772305,53.912423201836646,0.08772319536728602,2.705336830271478,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,3,5,24.383281343616545,24.362614494748414,24.392798193730414,0.19351751446018198,3.4254535320690236,53.027258049696684,53.017987054772675,53.09310944285244,0.08898427287297746,2.744227791865955,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,3,5,24.393998621962965,24.367429548874497,24.428526940755546,0.19343249432471676,3.423948590647372,53.11096484772861,53.071493515744805,53.206104156561196,0.08884402709550474,2.73990268645343,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,3,5,24.267674176953733,24.23390606418252,24.288278119638562,0.194439399737825,3.441771823409596,52.94822241179645,52.909665973857045,53.120820759795606,0.08911709940518672,2.7483240916885783,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,3,5,17.1291301259771,16.946400050073862,18.221904919482768,0.2754717820050909,4.876126025527497,46.0351287573576,46.020424575544894,46.14793793298304,0.1024998110653888,3.1610398231626027,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,3,7,90.59936547186226,89.40491802059114,93.76176362857223,0.05208195416628463,,158.36064806673676,158.1671329913661,160.4503425071016,0.02979649336880384,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,3,7,26.627318738028407,26.58525703009218,26.68023353908211,0.17720867979324692,3.4024967501692434,58.71118636801839,58.641764568164945,60.120047559030354,0.08036955632990492,2.6972823726314648,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,3,7,25.703592961654067,25.652661453932524,26.049211691133678,0.18357713674658,3.5247743615852856,56.822463613934815,56.79537542164326,56.92225047387183,0.0830409612659392,2.786937383473484,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,3,7,24.43074162583798,24.387317011132836,24.824217706918716,0.1931415784369645,3.7084165048859696,53.77089409157634,53.76591603271663,53.81281706504524,0.08775364590300178,2.9450997745552696,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,3,7,24.416185086593032,24.371870909817517,24.417825369164348,0.1932567263585738,3.7106274035254803,53.0829459335655,53.052632487379014,53.19422874599695,0.0888909218773469,2.983267889180994,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,3,7,24.362534414976835,24.361073039472103,24.400543165393174,0.19368231234182484,3.718798870784413,53.057269509881735,53.02855698391795,53.15046065952629,0.0889339395635725,2.984711605583899,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,3,7,24.26755757071078,24.222337524406612,24.37746999785304,0.1944403340241791,3.733353272486279,52.89180411491543,52.88355250377208,52.94667130801827,0.08921215827216153,2.9940489025988666,True,1e-05,10,float32,cuda,50 
-cuda_v7,8,8,256,256,3,7,16.87394628766924,16.87477866653353,16.914236755110323,0.27963772786500724,5.369186551107395,46.165161593817174,46.10549903009087,46.407426544465125,0.10221110112245234,3.4303063738857515,True,1e-05,10,float32,cuda,50

From ddf147debcddc6c4d3863cb9e79023ed0aad278f Mon Sep 17 00:00:00 2001
From: BoyceYi <1473416941@qq.com>
Date: Mon, 1 Sep 2025 10:40:41 +0800
Subject: [PATCH 18/22] Update log

---
 Converse2D/README.md | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/Converse2D/README.md b/Converse2D/README.md
index fb42493..7508258 100644
--- a/Converse2D/README.md
+++ b/Converse2D/README.md
@@ -29,16 +29,15 @@
 - NVIDIA RTX 4090
 - NVIDIA RTX 5060ti 16g
 
-**v7** fastest
+Only **v7** is kept in the main branch.
 
-We highly recommend you to run `test/test_speed.py` first to choose the most suitable backend for GPU.
+If you want to test more variants, switch to the `dev` branch. We recommend running `test/test_speed.py` first to choose the most suitable backend for your GPU.
 
 **Installation**
 
 ```python
 cd ./Converse2D
-# Remember to choose the wanted kernel version
-pip install . --no-build-isolation --config-settings=--variant=v7
+pip install . --no-build-isolation
 ```
 
 **Usage**

From 9a71245763726c1998a76ca52ae7a2f318da269a Mon Sep 17 00:00:00 2001
From: BoyceYi <1473416941@qq.com>
Date: Mon, 1 Sep 2025 20:15:20 +0800
Subject: [PATCH 19/22] Maintain the most stable original CUDA backend

---
 Converse2D/torch_converse2d/converse2d.cpp | 266 +++++++++------------
 Converse2D/torch_converse2d/converse2d.cu  | 251 -------------------
 2 files changed, 116 insertions(+), 401 deletions(-)
 delete mode 100644 Converse2D/torch_converse2d/converse2d.cu

diff --git a/Converse2D/torch_converse2d/converse2d.cpp b/Converse2D/torch_converse2d/converse2d.cpp
index 867b542..b5c713d 100644
--- a/Converse2D/torch_converse2d/converse2d.cpp
+++ b/Converse2D/torch_converse2d/converse2d.cpp
@@ -5,190 +5,156 @@
 #include
 #include
 #include
-#include
-#include
 #include
 #include
 #include
 #include
 #include
-#include
-#include
-
-#include
-#include
-#include
-#include
 
 using at::Tensor;
-using at::indexing::Slice;
-
-#ifdef CONVERSE2D_USE_CUDA_KERNELS
-Tensor block_mean_cuda(const Tensor &input, int64_t s);
-Tensor block_mean_cuda_backward(const Tensor &grad_out, int64_t s);
-
-Tensor sfold_upsample_cuda_launcher(const Tensor &x, int64_t s);
-Tensor sfold_upsample_cuda_backward(const Tensor &grad_out, int64_t s);
-#endif
-
-static inline Tensor sfold_upsample_zero_insertion_autograd(const Tensor &x, int64_t s) {
-    TORCH_CHECK(x.dim() == 4, "sfold_upsample expects (B,C,H,W)");
-    if (s == 1) return x;
-    const auto B = x.size(0), C = x.size(1), H = x.size(2), W = x.size(3);
-    auto y = at::zeros({B, C, H * s, W * s}, x.options());
-    y.index_put_({Slice(), Slice(), Slice(0, c10::nullopt, s), Slice(0, c10::nullopt, s)}, x);
-    return y;
-}
-static inline Tensor block_mean_autograd(const Tensor &input, int64_t s) {
-    if (s == 1) return input;
-    return at::avg_pool2d(input, /*kernel_size=*/{s, s}, /*stride=*/{s, s},
-                          /*padding=*/{0, 0}, /*ceil_mode=*/false, /*count_include_pad=*/true);
-}
-
-static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) {
-#ifdef CONVERSE2D_USE_CUDA_KERNELS
-    if (s == 1) return x;
-    return sfold_upsample_cuda_launcher(x, s);
-#else
-    return sfold_upsample_zero_insertion_autograd(x, s);
-#endif
-}
-
-static inline Tensor block_mean(const Tensor &input, int64_t s) {
-#ifdef CONVERSE2D_USE_CUDA_KERNELS
-    if (s == 1) return input;
-    return
block_mean_cuda(input, s); -#else - return block_mean_autograd(input, s); -#endif -} - -struct FBKey { - int64_t device_id; - at::ScalarType dtype; - int64_t channels; - int64_t H, W; - void *ptr; - bool operator==(const FBKey &other) const { - return device_id == other.device_id && dtype == other.dtype && - channels == other.channels && H == other.H && W == other.W && - ptr == other.ptr; - } -}; -namespace std { -template <> struct hash { - size_t operator()(const FBKey &k) const { - return ((hash()(k.device_id) ^ hash()(k.channels)) << 1) ^ - ((hash()(k.H) ^ hash()(k.W)) << 1) ^ - ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); + static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) + { + TORCH_CHECK(s >= 1, "scale must be >= 1"); + if (s == 1) + return x; + auto sizes = x.sizes().vec(); + sizes[sizes.size() - 2] *= s; + sizes[sizes.size() - 1] *= s; + Tensor z = at::zeros(sizes, x.options()); + z.index_put_( + {at::indexing::Slice(), at::indexing::Slice(), + at::indexing::Slice(0, z.size(-2), s), + at::indexing::Slice(0, z.size(-1), s)}, + x); + return z; } -}; -} - -constexpr size_t FB_CACHE_MAX_SIZE = 64; -static std::unordered_map> fb_cache; -static std::list fb_cache_lru; -static std::mutex fb_cache_mutex; - -static inline std::pair p2o_cached_rfft(const Tensor &psf, int64_t H, int64_t W) { - auto C = psf.size(1); - FBKey key{psf.device().index(), psf.scalar_type(), C, H, W, psf.data_ptr()}; + static inline Tensor p2o(const Tensor &psf, int64_t H, int64_t W) { - std::lock_guard lock(fb_cache_mutex); - auto it = fb_cache.find(key); - if (it != fb_cache.end()) { - fb_cache_lru.remove(key); - fb_cache_lru.push_front(key); - return it->second; - } + TORCH_CHECK(psf.dim() == 4 && psf.size(0) == 1, "psf must be (1,C,kh,kw)"); + auto C = psf.size(1); + auto kh = psf.size(2); + auto kw = psf.size(3); + Tensor otf = at::zeros({1, C, H, W}, psf.options()); + otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); + const int64_t sh = -static_cast(kh / 2); + const int64_t sw = -static_cast(kw / 2); + otf = at::roll(otf, {sh, sw}, {-2, -1}); + return at::fft_fftn(otf, c10::optional({}), c10::optional({-2, -1}), c10::nullopt); } - Tensor otf = at::zeros({1, C, H, W}, psf.options()); - int64_t kh = psf.size(2), kw = psf.size(3); - otf.index_put_({0, Slice(), Slice(0, kh), Slice(0, kw)}, psf); - otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1}); - - Tensor FB = at::fft_rfft2(otf, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor F2B = at::abs(FB).pow(2); - + static inline Tensor splits_mean_then_mean(const Tensor &a, int64_t s) { - std::lock_guard lock(fb_cache_mutex); - fb_cache[key] = {FB, F2B}; - fb_cache_lru.push_front(key); - if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) { - fb_cache.erase(fb_cache_lru.back()); - fb_cache_lru.pop_back(); - } + TORCH_CHECK(a.dim() >= 2, "tensor must have spatial dims"); + TORCH_CHECK(a.size(-2) % s == 0 && a.size(-1) % s == 0, "spatial not divisible by scale"); + + const auto &sizes = a.sizes(); + const int64_t L = a.dim(); + const int64_t W = sizes[L - 2]; + const int64_t H = sizes[L - 1]; + const int64_t W_s = W / s; + const int64_t H_s = H / s; + + std::vector view_shape; + view_shape.reserve(L + 2); + for (int64_t i = 0; i < L - 2; ++i) + view_shape.push_back(sizes[i]); + view_shape.push_back(s); + view_shape.push_back(W_s); + view_shape.push_back(s); + view_shape.push_back(H_s); + Tensor v = a.view(view_shape); + + std::vector perm; + perm.reserve(view_shape.size()); + for (int64_t i 
= 0; i < L - 2; ++i) + perm.push_back(i); + perm.push_back(L - 2 + 1); // W_s + perm.push_back(L - 2 + 3); // H_s + perm.push_back(L - 2 + 0); // s + perm.push_back(L - 2 + 2); // s + Tensor p = v.permute(perm).contiguous(); + + std::vector merge_shape; + merge_shape.reserve(L + 1); + for (int64_t i = 0; i < L - 2; ++i) + merge_shape.push_back(p.size(i)); + merge_shape.push_back(W_s); + merge_shape.push_back(H_s); + merge_shape.push_back(s * s); + Tensor r = p.view(merge_shape); + + return r.mean(-1, /*keepdim=*/false); } - return {FB, F2B}; -} -Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) { - TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); - TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); - TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); - TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); - TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); + Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) + { + TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); + TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); + TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); + TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); + TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); + TORCH_CHECK(scale >= 1, "scale must be >= 1"); - x = x.contiguous(); - x0 = x0.contiguous(); - weight = weight.contiguous(); - bias = bias.contiguous(); + x = x.contiguous(); + x0 = x0.contiguous(); + weight = weight.contiguous(); + bias = bias.contiguous(); - const int64_t B = x.size(0); - const int64_t C = x.size(1); - const int64_t H = x.size(2); - const int64_t W = x.size(3); - const int64_t Hs = H * scale; - const int64_t Ws = W * scale; + const int64_t B = x.size(0); + const int64_t C = x.size(1); + const int64_t H = x.size(2); + const int64_t W = x.size(3); + const int64_t Hs = H * scale; + const int64_t Ws = W * scale; - Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; - Tensor STy = sfold_upsample_zero_insertion(x, scale); + Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; - auto [FB, F2B] = p2o_cached_rfft(weight, Hs, Ws); - Tensor FBC = at::conj_physical(FB); + Tensor STy = sfold_upsample_zero_insertion(x, scale); - Tensor F_STy = at::fft_rfft2(STy, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor F_lambda_x0 = at::fft_rfft2(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor FB = p2o(weight, Hs, Ws); + Tensor FBC = at::conj_physical(FB); + Tensor F2B = at::abs(FB).pow(2.0); - Tensor FBFy = FBC * F_STy; - Tensor FR = FBFy + F_lambda_x0; + Tensor F_STy = at::fft_fftn(STy, c10::optional({}), c10::optional({-2, -1}), c10::nullopt); + Tensor FBFy = FBC * F_STy; + Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::optional({}), c10::optional({-2, -1}), c10::nullopt); - Tensor x1 = FB * FR; - Tensor x1_real = at::fft_irfft2(x1, {Hs, Ws}, {-2, -1}, c10::nullopt); - Tensor F2B_real = at::fft_irfft2(F2B, {Hs, Ws}, {-2, -1}, c10::nullopt); + Tensor x1 = FB * FR; - Tensor FBR = block_mean(x1_real, scale); - Tensor invW = block_mean(F2B_real, scale); + Tensor FBR = splits_mean_then_mean(x1, scale); + Tensor invW = splits_mean_then_mean(F2B, scale); - 
Tensor invW_plus = invW + lambda_; - Tensor invWBR = FBR / invW_plus; + Tensor invW_plus = invW + lambda_; + Tensor invWBR = FBR / invW_plus; - Tensor invWBR_exp = invWBR.view({B, C, H, 1, W, 1}) - .expand({B, C, H, scale, W, scale}) - .reshape({B, C, Hs, Ws}); - Tensor FCBinvWBR = FBC * at::fft_rfft2(invWBR_exp, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor invWBR_rep = invWBR.repeat({1, 1, scale, scale}); - Tensor FX = (FR - FCBinvWBR) / lambda_; - Tensor out = at::fft_irfft2(FX, {Hs, Ws}, {-2, -1}, c10::nullopt); - return out; -} + Tensor FCBinvWBR = FBC * invWBR_rep; + + Tensor FX = (FR - FCBinvWBR) / lambda_; + Tensor out_c = at::fft_ifftn(FX, c10::optional({}), c10::optional({-2, -1}), c10::nullopt); + Tensor out = at::real(out_c); + (void)B; + (void)C; + (void)H; + (void)W; + return out; + } -void clear_fb_cache() { - std::lock_guard lock(fb_cache_mutex); - fb_cache.clear(); - fb_cache_lru.clear(); } -TORCH_LIBRARY(converse2d, m) { +TORCH_LIBRARY(converse2d, m) +{ m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); - m.def("clear_cache() -> ()"); } -TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) { +TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) +{ m.impl("forward", TORCH_FN(converse2d_forward)); - m.impl("clear_cache", TORCH_FN(clear_fb_cache)); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/Converse2D/torch_converse2d/converse2d.cu b/Converse2D/torch_converse2d/converse2d.cu deleted file mode 100644 index e7ebeec..0000000 --- a/Converse2D/torch_converse2d/converse2d.cu +++ /dev/null @@ -1,251 +0,0 @@ -// converse2d_v7.cu (PyTorch 2.4 / CUDA 12.1 兼容版) -#include -#include -#include // getCurrentCUDAStream -#include // CUDAGuard (moved from at::cuda) -#include - -using at::Tensor; - -// ---- 简易累加类型映射:half/bfloat16/float 用 float;double 用 double ---- -template struct acc_type_map { using type = float; }; -template <> struct acc_type_map { using type = double; }; -template <> struct acc_type_map { using type = float; }; -template <> struct acc_type_map { using type = float; }; -template <> struct acc_type_map { using type = float; }; -template using acc_t = typename acc_type_map::type; - -// ======================= block_mean forward ======================= -template -__global__ void block_mean_forward_kernel( - const scalar_t* __restrict__ x, - scalar_t* __restrict__ y, - int B, int C, int Hin, int Win, int s) -{ - const int Hout = Hin / s; - const int Wout = Win / s; - const int Nout = B * C * Hout * Wout; - - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= Nout) return; - - int wout = idx % Wout; - int t = idx / Wout; - int hout = t % Hout; - t /= Hout; - int c = t % C; - int b = t / C; - - const int h0 = hout * s; - const int w0 = wout * s; - - const int64_t base_in = (((int64_t)b * C + c) * Hin + h0) * Win + w0; - const int64_t base_out = (((int64_t)b * C + c) * Hout + hout) * Wout + wout; - - acc_t sum = static_cast>(0); - for (int dh = 0; dh < s; ++dh) { - int64_t row = base_in + (int64_t)dh * Win; - for (int dw = 0; dw < s; ++dw) { - sum += static_cast>(x[row + dw]); - } - } - const acc_t denom = static_cast>(s) * static_cast>(s); - y[base_out] = static_cast(sum / denom); -} - -Tensor block_mean_cuda(const Tensor &input, int64_t s) -{ - TORCH_CHECK(input.is_cuda(), "block_mean_cuda: input must be CUDA tensor"); - TORCH_CHECK(input.dim() == 4, "block_mean_cuda: expect (B,C,H,W)"); - TORCH_CHECK(s >= 1, "block_mean_cuda: s must be >= 1"); - if (s == 1) return input; - - auto 
x = input.contiguous(); - const int B = x.size(0); - const int C = x.size(1); - const int Hin = x.size(2); - const int Win = x.size(3); - TORCH_CHECK(Hin % s == 0 && Win % s == 0, "block_mean_cuda: H/W must be divisible by s"); - - const int Hout = Hin / s; - const int Wout = Win / s; - - auto y = at::empty({B, C, Hout, Wout}, x.options()); - - const int threads = 256; - const int blocks = (B * C * Hout * Wout + threads - 1) / threads; - - c10::cuda::CUDAGuard guard(x.device()); - AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, x.scalar_type(), "block_mean_forward", [&] { - block_mean_forward_kernel<<>>( - x.data_ptr(), y.data_ptr(), B, C, Hin, Win, (int)s); - }); - return y; -} - -// ======================= block_mean backward ======================= -// grad_x[b,c,i*s+p,j*s+q] = grad_y[b,c,i,j] / (s*s) -template -__global__ void block_mean_backward_kernel( - const scalar_t* __restrict__ gy, - scalar_t* __restrict__ gx, - int B, int C, int Hin, int Win, int s) -{ - const int Hout = Hin / s; - const int Wout = Win / s; - const int Nin = B * C * Hin * Win; - - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= Nin) return; - - int win = idx % Win; - int t = idx / Win; - int hin = t % Hin; - t /= Hin; - int c = t % C; - int b = t / C; - - const int hout = hin / s; - const int wout = win / s; - - const int64_t index_out = (((int64_t)b * C + c) * Hout + hout) * Wout + wout; - const acc_t denom = static_cast>(s) * static_cast>(s); - gx[idx] = static_cast( static_cast>(gy[index_out]) / denom ); -} - -Tensor block_mean_cuda_backward(const Tensor &grad_out, int64_t s) -{ - TORCH_CHECK(grad_out.is_cuda(), "block_mean_cuda_backward: grad_out must be CUDA tensor"); - TORCH_CHECK(grad_out.dim() == 4, "block_mean_cuda_backward: expect (B,C,Hout,Wout)"); - TORCH_CHECK(s >= 1, "block_mean_cuda_backward: s must be >= 1"); - - auto gy = grad_out.contiguous(); - const int B = gy.size(0); - const int C = gy.size(1); - const int Hout = gy.size(2); - const int Wout = gy.size(3); - - const int Hin = Hout * s; - const int Win = Wout * s; - - auto gx = at::empty({B, C, Hin, Win}, gy.options()); - - const int threads = 256; - const int blocks = (B * C * Hin * Win + threads - 1) / threads; - - c10::cuda::CUDAGuard guard(gy.device()); - AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, gy.scalar_type(), "block_mean_backward", [&] { - block_mean_backward_kernel<<>>( - gy.data_ptr(), gx.data_ptr(), B, C, Hin, Win, (int)s); - }); - return gx; -} - -// ======================= s-fold zero insertion forward ======================= -template -__global__ void sfold_upsample_forward_kernel( - const scalar_t* __restrict__ x, - scalar_t* __restrict__ y, - int B, int C, int H, int W, int s, int Hs, int Ws) -{ - const int Nin = B * C * H * W; - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= Nin) return; - - int w = idx % W; - int t = idx / W; - int h = t % H; - t /= H; - int c = t % C; - int b = t / C; - - const int hs = h * s; - const int ws = w * s; - - const int64_t out_index = (((int64_t)b * C + c) * Hs + hs) * Ws + ws; - const scalar_t v = x[idx]; - y[out_index] = v; // 其它位置保持 0 -} - -Tensor sfold_upsample_cuda_launcher(const Tensor &x, int64_t s) -{ - TORCH_CHECK(x.is_cuda(), "sfold_upsample_cuda: x must be CUDA tensor"); - TORCH_CHECK(x.dim() == 4, "sfold_upsample_cuda: expect (B,C,H,W)"); - TORCH_CHECK(s >= 1, "sfold_upsample_cuda: s must be >= 1"); - if (s == 1) return x; - - auto xx = x.contiguous(); - const int B = xx.size(0); - const int C = xx.size(1); - const int H = 
xx.size(2); - const int W = xx.size(3); - const int Hs = H * s; - const int Ws = W * s; - - auto y = at::zeros({B, C, Hs, Ws}, xx.options()); - - const int threads = 256; - const int blocks = (B * C * H * W + threads - 1) / threads; - - c10::cuda::CUDAGuard guard(xx.device()); - AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, xx.scalar_type(), "sfold_upsample_forward", [&] { - sfold_upsample_forward_kernel<<>>( - xx.data_ptr(), y.data_ptr(), B, C, H, W, (int)s, Hs, Ws); - }); - return y; -} - -// ======================= s-fold zero insertion backward ======================= -// grad_x[b,c,h,w] = grad_y[b,c,h*s, w*s] -template -__global__ void sfold_upsample_backward_kernel( - const scalar_t* __restrict__ gy, - scalar_t* __restrict__ gx, - int B, int C, int H, int W, int s, int Hs, int Ws) -{ - const int Nin = B * C * H * W; - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= Nin) return; - - int w = idx % W; - int t = idx / W; - int h = t % H; - t /= H; - int c = t % C; - int b = t / C; - - const int hs = h * s; - const int ws = w * s; - - const int64_t in_index = (((int64_t)b * C + c) * Hs + hs) * Ws + ws; - gx[idx] = gy[in_index]; -} - -Tensor sfold_upsample_cuda_backward(const Tensor &grad_out, int64_t s) -{ - TORCH_CHECK(grad_out.is_cuda(), "sfold_upsample_cuda_backward: grad_out must be CUDA tensor"); - TORCH_CHECK(grad_out.dim() == 4, "sfold_upsample_cuda_backward: expect (B,C,Hs,Ws)"); - TORCH_CHECK(s >= 1, "sfold_upsample_cuda_backward: s must be >= 1"); - if (s == 1) return grad_out; - - auto gy = grad_out.contiguous(); - const int B = gy.size(0); - const int C = gy.size(1); - const int Hs = gy.size(2); - const int Ws = gy.size(3); - TORCH_CHECK(Hs % s == 0 && Ws % s == 0, "sfold_upsample_cuda_backward: Hs/Ws must be divisible by s"); - const int H = Hs / s; - const int W = Ws / s; - - auto gx = at::empty({B, C, H, W}, gy.options()); - - const int threads = 256; - const int blocks = (B * C * H * W + threads - 1) / threads; - - c10::cuda::CUDAGuard guard(gy.device()); - AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, gy.scalar_type(), "sfold_upsample_backward", [&] { - sfold_upsample_backward_kernel<<>>( - gy.data_ptr(), gx.data_ptr(), B, C, H, W, (int)s, Hs, Ws); - }); - return gx; -} From eac211bf8145c705bab52d410e59c1c5ff45a1dd Mon Sep 17 00:00:00 2001 From: BoyceYi <1473416941@qq.com> Date: Thu, 4 Sep 2025 09:12:02 +0800 Subject: [PATCH 20/22] Fix v7 precision error --- Converse2D/setup.py | 2 +- Converse2D/torch_converse2d/converse2d.cpp | 332 ++++++++++++------ .../torch_converse2d/converse2d_kernel.cu | 203 +++++++++++ 3 files changed, 427 insertions(+), 110 deletions(-) create mode 100644 Converse2D/torch_converse2d/converse2d_kernel.cu diff --git a/Converse2D/setup.py b/Converse2D/setup.py index 6b310cd..d80a440 100644 --- a/Converse2D/setup.py +++ b/Converse2D/setup.py @@ -6,7 +6,7 @@ CPP = str(PKG_DIR / f"converse2d.cpp") -CU = str(PKG_DIR / f"converse2d.cu") +CU = str(PKG_DIR / f"converse2d_kernel.cu") has_cu = os.path.exists(CU) extra_cflags = ["-O3"] diff --git a/Converse2D/torch_converse2d/converse2d.cpp b/Converse2D/torch_converse2d/converse2d.cpp index b5c713d..bc7ebc8 100644 --- a/Converse2D/torch_converse2d/converse2d.cpp +++ b/Converse2D/torch_converse2d/converse2d.cpp @@ -3,158 +3,272 @@ #include #include #include -#include -#include +#include +#include #include #include #include #include +#include #include +#include +#include + +#include +#include +#include +#include using at::Tensor; +using at::indexing::Slice; -namespace -{ 
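+// Context for the launchers declared below, reconstructed from the kernel
+// bodies in converse2d_kernel.cu: the reduced tensors live on an rFFT
+// half-spectrum of width Ws_r = Ws_full/2 + 1. There, every interior
+// frequency column represents two conjugate bins of the full spectrum,
+// while the DC column (and the Nyquist column when Ws_full is even)
+// represents only one. A plain s-by-s block mean would weight all columns
+// equally, so these kernels average with per-column weights of 2 and 1.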
+void weighted_block_mean_kernel_launcher( + const at::Tensor &in, at::Tensor &out, + int B, int C, int Ho, int Wo_r, int s, int Hs, int Ws_r, + int Ws_full, long long total_out); + +void weighted_block_mean_grad_kernel_launcher( + const at::Tensor &grad_out, at::Tensor &grad_in, + int B, int C, int Ho, int Wo_r, int s, int Hs, int Ws_r, + int Ws_full, long long total_in); - static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) +struct WeightedBlockMeanFunction : public torch::autograd::Function +{ + static at::Tensor forward( + torch::autograd::AutogradContext *ctx, + const at::Tensor &input, int64_t Ws_full, int64_t s) { - TORCH_CHECK(s >= 1, "scale must be >= 1"); + TORCH_CHECK(input.is_cuda() && input.dim() == 4, + "weighted_block_mean: input must be (B,C,Hs,Ws_r) CUDA Tensor"); + TORCH_CHECK(s >= 1, "weighted_block_mean: s must be >= 1"); + if (s == 1) - return x; - auto sizes = x.sizes().vec(); - sizes[sizes.size() - 2] *= s; - sizes[sizes.size() - 1] *= s; - Tensor z = at::zeros(sizes, x.options()); - z.index_put_( - {at::indexing::Slice(), at::indexing::Slice(), - at::indexing::Slice(0, z.size(-2), s), - at::indexing::Slice(0, z.size(-1), s)}, - x); - return z; + { + ctx->saved_data["s"] = s; + return input; + } + + auto x = input.contiguous(); + const int B = (int)x.size(0); + const int C = (int)x.size(1); + const int Hs = (int)x.size(2); + const int Ws_r = (int)x.size(3); + + TORCH_CHECK(Hs % s == 0, "weighted_block_mean: H must be divisible by s"); + const int Ho = Hs / (int)s; + const int Wo_r = (Ws_r + (int)s - 1) / (int)s; // ceil_div for safety + + auto out = at::empty({B, C, Ho, Wo_r}, x.options()); + const long long total_out = out.numel(); + + weighted_block_mean_kernel_launcher( + x, out, B, C, Ho, Wo_r, (int)s, Hs, Ws_r, (int)Ws_full, total_out); + + ctx->save_for_backward({x}); + ctx->saved_data["s"] = s; + ctx->saved_data["Ws_full"] = Ws_full; + return out; } - static inline Tensor p2o(const Tensor &psf, int64_t H, int64_t W) + static torch::autograd::tensor_list backward( + torch::autograd::AutogradContext *ctx, + torch::autograd::tensor_list grad_outputs) { - TORCH_CHECK(psf.dim() == 4 && psf.size(0) == 1, "psf must be (1,C,kh,kw)"); - auto C = psf.size(1); - auto kh = psf.size(2); - auto kw = psf.size(3); - Tensor otf = at::zeros({1, C, H, W}, psf.options()); - otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf); - const int64_t sh = -static_cast(kh / 2); - const int64_t sw = -static_cast(kw / 2); - otf = at::roll(otf, {sh, sw}, {-2, -1}); - return at::fft_fftn(otf, c10::optional({}), c10::optional({-2, -1}), c10::nullopt); + auto go = grad_outputs[0].contiguous(); + auto s = ctx->saved_data["s"].toInt(); + + if (s == 1) + { + return {go, torch::Tensor(), torch::Tensor()}; + } + + auto Ws_full = ctx->saved_data["Ws_full"].toInt(); + auto saved = ctx->get_saved_variables(); + auto input = saved[0]; + + const int B = (int)input.size(0); + const int C = (int)input.size(1); + const int Hs = (int)input.size(2); + const int Ws_r = (int)input.size(3); + const int Ho = (int)go.size(2); + const int Wo_r = (int)go.size(3); + + auto grad_in = at::empty_like(input); + const long long total_in = input.numel(); + + weighted_block_mean_grad_kernel_launcher( + go, grad_in, B, C, Ho, Wo_r, (int)s, Hs, Ws_r, (int)Ws_full, total_in); + + return {grad_in, torch::Tensor(), torch::Tensor()}; } +}; + +static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) +{ + TORCH_CHECK(x.dim() == 4, 
"sfold_upsample expects (B,C,H,W)"); + if (s == 1) + return x; + const auto B = x.size(0), C = x.size(1), H = x.size(2), W = x.size(3); + auto y = at::zeros({B, C, H * s, W * s}, x.options()); + y.index_put_({Slice(), Slice(), Slice(0, c10::nullopt, s), Slice(0, c10::nullopt, s)}, x); + return y; +} + +at::Tensor weighted_block_mean_cuda(const at::Tensor &input, int64_t Ws_full, int64_t s) +{ + return WeightedBlockMeanFunction::apply(input, Ws_full, s); +} - static inline Tensor splits_mean_then_mean(const Tensor &a, int64_t s) +struct FBKey +{ + int64_t device_id; + at::ScalarType dtype; + int64_t channels; + int64_t H, W; + void *ptr; + bool operator==(const FBKey &other) const { - TORCH_CHECK(a.dim() >= 2, "tensor must have spatial dims"); - TORCH_CHECK(a.size(-2) % s == 0 && a.size(-1) % s == 0, "spatial not divisible by scale"); - - const auto &sizes = a.sizes(); - const int64_t L = a.dim(); - const int64_t W = sizes[L - 2]; - const int64_t H = sizes[L - 1]; - const int64_t W_s = W / s; - const int64_t H_s = H / s; - - std::vector view_shape; - view_shape.reserve(L + 2); - for (int64_t i = 0; i < L - 2; ++i) - view_shape.push_back(sizes[i]); - view_shape.push_back(s); - view_shape.push_back(W_s); - view_shape.push_back(s); - view_shape.push_back(H_s); - Tensor v = a.view(view_shape); - - std::vector perm; - perm.reserve(view_shape.size()); - for (int64_t i = 0; i < L - 2; ++i) - perm.push_back(i); - perm.push_back(L - 2 + 1); // W_s - perm.push_back(L - 2 + 3); // H_s - perm.push_back(L - 2 + 0); // s - perm.push_back(L - 2 + 2); // s - Tensor p = v.permute(perm).contiguous(); - - std::vector merge_shape; - merge_shape.reserve(L + 1); - for (int64_t i = 0; i < L - 2; ++i) - merge_shape.push_back(p.size(i)); - merge_shape.push_back(W_s); - merge_shape.push_back(H_s); - merge_shape.push_back(s * s); - Tensor r = p.view(merge_shape); - - return r.mean(-1, /*keepdim=*/false); + return device_id == other.device_id && dtype == other.dtype && + channels == other.channels && H == other.H && W == other.W && + ptr == other.ptr; } +}; +namespace std +{ + template <> + struct hash + { + size_t operator()(const FBKey &k) const + { + return ((hash()(k.device_id) ^ hash()(k.channels)) << 1) ^ + ((hash()(k.H) ^ hash()(k.W)) << 1) ^ + ((hash()(k.ptr)) ^ hash()(static_cast(k.dtype))); + } + }; +} - Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) +constexpr size_t FB_CACHE_MAX_SIZE = 64; +static std::unordered_map> fb_cache; +static std::list fb_cache_lru; +static std::mutex fb_cache_mutex; + +static inline std::pair p2o_cached_rfft(const Tensor &psf, int64_t H, int64_t W) +{ + const bool training_with_grad = at::GradMode::is_enabled() && psf.requires_grad(); + + auto C = psf.size(1); + FBKey key{psf.device().index(), psf.scalar_type(), C, H, W, psf.data_ptr()}; + + if (!training_with_grad) { - TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); - TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); - TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); - TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); - TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); - TORCH_CHECK(scale >= 1, "scale must be >= 1"); + std::lock_guard lock(fb_cache_mutex); + auto it = fb_cache.find(key); + if (it != fb_cache.end()) + { + fb_cache_lru.remove(key); + fb_cache_lru.push_front(key); + return 
it->second; + } + } - x = x.contiguous(); - x0 = x0.contiguous(); - weight = weight.contiguous(); - bias = bias.contiguous(); + Tensor otf = at::zeros({1, C, H, W}, psf.options()); + int64_t kh = psf.size(2), kw = psf.size(3); + otf.index_put_({0, Slice(), Slice(0, kh), Slice(0, kw)}, psf); + otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1}); - const int64_t B = x.size(0); - const int64_t C = x.size(1); - const int64_t H = x.size(2); - const int64_t W = x.size(3); - const int64_t Hs = H * scale; - const int64_t Ws = W * scale; + Tensor FB = at::fft_rfftn(otf, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor F2B = at::abs(FB).pow(2); - Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; + if (!training_with_grad) + { + std::lock_guard lock(fb_cache_mutex); + fb_cache[key] = {FB, F2B}; + fb_cache_lru.push_front(key); + if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE) + { + fb_cache.erase(fb_cache_lru.back()); + fb_cache_lru.pop_back(); + } + } + return {FB, F2B}; +} - Tensor STy = sfold_upsample_zero_insertion(x, scale); +Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps) +{ + TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)"); + TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)"); + TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)"); + TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)"); + TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device"); - Tensor FB = p2o(weight, Hs, Ws); - Tensor FBC = at::conj_physical(FB); - Tensor F2B = at::abs(FB).pow(2.0); + x = x.contiguous(); + x0 = x0.contiguous(); + weight = weight.contiguous(); + bias = bias.contiguous(); - Tensor F_STy = at::fft_fftn(STy, c10::optional({}), c10::optional({-2, -1}), c10::nullopt); - Tensor FBFy = FBC * F_STy; - Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::optional({}), c10::optional({-2, -1}), c10::nullopt); + const int64_t B = x.size(0); + const int64_t C = x.size(1); + const int64_t H = x.size(2); + const int64_t W = x.size(3); + const int64_t Hs = H * scale; + const int64_t Ws = W * scale; - Tensor x1 = FB * FR; + Tensor lambda_ = at::sigmoid(bias - 9.0) + eps; + Tensor STy = sfold_upsample_zero_insertion(x, scale); - Tensor FBR = splits_mean_then_mean(x1, scale); - Tensor invW = splits_mean_then_mean(F2B, scale); + auto [FB, F2B] = p2o_cached_rfft(weight, Hs, Ws); + Tensor FBC = at::conj_physical(FB); + Tensor F_STy = at::fft_rfftn(STy, c10::nullopt, {-2, -1}, c10::nullopt); + Tensor F_lx0 = at::fft_rfftn(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt); - Tensor invW_plus = invW + lambda_; - Tensor invWBR = FBR / invW_plus; + Tensor FBFy = FBC * F_STy; + Tensor FR = FBFy + F_lx0; - Tensor invWBR_rep = invWBR.repeat({1, 1, scale, scale}); + Tensor x1 = FB * FR; + Tensor FBR = weighted_block_mean_cuda(x1, Ws, scale); + Tensor invW = weighted_block_mean_cuda(F2B, Ws, scale); - Tensor FCBinvWBR = FBC * invWBR_rep; + Tensor invW_plus = invW + lambda_; + Tensor invWBR = FBR / invW_plus; - Tensor FX = (FR - FCBinvWBR) / lambda_; - Tensor out_c = at::fft_ifftn(FX, c10::optional({}), c10::optional({-2, -1}), c10::nullopt); - Tensor out = at::real(out_c); - (void)B; - (void)C; - (void)H; - (void)W; - return out; + const int64_t Ws_rfft = Ws / 2 + 1; + const int64_t Ho = invWBR.size(2); + const int64_t Wo_r = invWBR.size(3); + + Tensor invWBR_exp = invWBR + .view({B, C, Ho, 1, Wo_r, 1}) + .expand({B, C, Ho, 
(int64_t)scale, Wo_r, (int64_t)scale}) + .reshape({B, C, Ho * (int64_t)scale, Wo_r * (int64_t)scale}); + + if (invWBR_exp.size(-2) != Hs || invWBR_exp.size(-1) != Ws_rfft) + { + using at::indexing::Slice; + invWBR_exp = invWBR_exp.index({Slice(), Slice(), + Slice(0, Hs), Slice(0, Ws_rfft)}); } + Tensor FCBinvWBR = FBC * invWBR_exp; + Tensor FX = (FR - FCBinvWBR) / lambda_; + Tensor out = at::fft_irfftn(FX, {Hs, Ws}, {-2, -1}, c10::nullopt); + return out; +} + +void clear_fb_cache() +{ + std::lock_guard lock(fb_cache_mutex); + fb_cache.clear(); + fb_cache_lru.clear(); } TORCH_LIBRARY(converse2d, m) { m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor"); + m.def("clear_cache() -> ()"); } TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m) { m.impl("forward", TORCH_FN(converse2d_forward)); + m.impl("clear_cache", TORCH_FN(clear_fb_cache)); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} diff --git a/Converse2D/torch_converse2d/converse2d_kernel.cu b/Converse2D/torch_converse2d/converse2d_kernel.cu new file mode 100644 index 0000000..cf9f625 --- /dev/null +++ b/Converse2D/torch_converse2d/converse2d_kernel.cu @@ -0,0 +1,203 @@ +#include +#include +#include +#include +#include + +// -------- accumulator mapping (keep numerics stable) ---------- +template +struct acc_type_map +{ + using type = float; +}; +template <> +struct acc_type_map +{ + using type = double; +}; +template <> +struct acc_type_map +{ + using type = float; +}; +template <> +struct acc_type_map +{ + using type = float; +}; +template <> +struct acc_type_map +{ + using type = float; +}; +template <> +struct acc_type_map> +{ + using type = c10::complex; +}; +template <> +struct acc_type_map> +{ + using type = c10::complex; +}; +template <> +struct acc_type_map> +{ + using type = c10::complex; +}; + +template +using acc_scalar_t = typename acc_type_map::type; + +// ----------------------- forward kernel ----------------------- +template +__global__ void weighted_block_mean_kernel( + const scalar_t *__restrict__ in, // (B,C,Hs,Ws_r) + scalar_t *__restrict__ out, // (B,C,Ho,Wo_r) + int B, int C, int Ho, int Wo_r, int s, int Hs, int Ws_r, + int Ws_full, long long total_out) +{ + using acc_t = acc_scalar_t; + + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_out) + return; + + const int wo = static_cast(idx % Wo_r); + const int ho = static_cast((idx / Wo_r) % Ho); + const int c = static_cast((idx / (1LL * Wo_r * Ho)) % C); + const int b = static_cast(idx / (1LL * Wo_r * Ho * C)); + + const int hi0 = ho * s; + const int wi0 = wo * s; + const long long base_in = ((long long)b * C + c) * Hs * Ws_r; + + acc_t acc_val = acc_t(0); + float acc_w = 0.0f; + + // average over an s x s block with DC/Nyquist half-weights along last rFFT dim + for (int di = 0; di < s; ++di) + { + const int hi = hi0 + di; + if (hi >= Hs) + continue; + const long long row_off = base_in + (long long)hi * Ws_r; + for (int dj = 0; dj < s; ++dj) + { + const int wi = wi0 + dj; + if (wi >= Ws_r) + continue; + + float w = 2.0f; + if (wi == 0) + w = 1.0f; // DC + if ((Ws_full % 2 == 0) && wi == (Ws_r - 1)) + w = 1.0f; // Nyquist if even W + + acc_val += static_cast(in[row_off + wi]) * static_cast(w); + acc_w += w; + } + } + + const long long out_off = ((long long)b * C + c) * Ho * Wo_r + (long long)ho * Wo_r + wo; + out[out_off] = (acc_w > 1e-8f) + ? 
static_cast(acc_val / static_cast(acc_w)) + : static_cast(0); +} + +// ----------------------- backward kernel ---------------------- +template +__global__ void weighted_block_mean_grad_kernel( + const scalar_t *__restrict__ grad_out, // (B,C,Ho,Wo_r) + scalar_t *__restrict__ grad_in, // (B,C,Hs,Ws_r) + int B, int C, int Ho, int Wo_r, int s, int Hs, int Ws_r, + int Ws_full, long long total_in) +{ + using acc_t = acc_scalar_t; + + long long idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_in) + return; + + const int wi = static_cast(idx % Ws_r); + const int hi = static_cast((idx / Ws_r) % Hs); + const int c = static_cast((idx / (1LL * Ws_r * Hs)) % C); + const int b = static_cast(idx / (1LL * Ws_r * Hs * C)); + + const int ho = hi / s; + const int wo = wi / s; + + // same weights as forward + float w = 2.0f; + if (wi == 0) + w = 1.0f; + if ((Ws_full % 2 == 0) && wi == (Ws_r - 1)) + w = 1.0f; + + // denominator over that output block + float denom = 0.0f; + for (int di = 0; di < s; ++di) + { + const int hi2 = ho * s + di; + if (hi2 >= Hs) + continue; + for (int dj = 0; dj < s; ++dj) + { + const int wi2 = wo * s + dj; + if (wi2 >= Ws_r) + continue; + float w2 = 2.0f; + if (wi2 == 0) + w2 = 1.0f; + if ((Ws_full % 2 == 0) && wi2 == (Ws_r - 1)) + w2 = 1.0f; + denom += w2; + } + } + if (denom < 1e-8f) + denom = 1.0f; + + const long long go_off = ((long long)b * C + c) * Ho * Wo_r + (long long)ho * Wo_r + wo; + const long long gi_off = ((long long)b * C + c) * Hs * Ws_r + (long long)hi * Ws_r + wi; + + const acc_t go_val = static_cast(grad_out[go_off]); + const acc_t scale = static_cast(w / denom); + grad_in[gi_off] = static_cast(go_val * scale); +} + +// ----------------------- launchers ---------------------------- +void weighted_block_mean_kernel_launcher( + const at::Tensor &in, at::Tensor &out, + int B, int C, int Ho, int Wo_r, int s, int Hs, int Ws_r, + int Ws_full, long long total_out) +{ + const int threads = 256; + const int blocks = (int)((total_out + threads - 1) / threads); + auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::kHalf, at::kBFloat16, in.scalar_type(), + "weighted_block_mean_fwd", [&] + { weighted_block_mean_kernel<<>>( + in.data_ptr(), out.data_ptr(), + B, C, Ho, Wo_r, s, Hs, Ws_r, Ws_full, total_out); }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void weighted_block_mean_grad_kernel_launcher( + const at::Tensor &grad_out, at::Tensor &grad_in, + int B, int C, int Ho, int Wo_r, int s, int Hs, int Ws_r, + int Ws_full, long long total_in) +{ + const int threads = 256; + const int blocks = (int)((total_in + threads - 1) / threads); + auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::kHalf, at::kBFloat16, grad_out.scalar_type(), + "weighted_block_mean_bwd", [&] + { weighted_block_mean_grad_kernel<<>>( + grad_out.data_ptr(), grad_in.data_ptr(), + B, C, Ho, Wo_r, s, Hs, Ws_r, Ws_full, total_in); }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} From 8de81bf2bddfb5978f388c4f1b3e2a99dec4daa9 Mon Sep 17 00:00:00 2001 From: BoyceYi <1473416941@qq.com> Date: Thu, 4 Sep 2025 09:26:39 +0800 Subject: [PATCH 21/22] Revert "Fix v7 precision error" This reverts commit eac211bf8145c705bab52d410e59c1c5ff45a1dd. 
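
This restores the previous commit's plain full-FFT implementation (p2o +
fft_fftn + splits_mean_then_mean) and drops the rFFT pipeline, the cached
(FB, |FB|^2) pair, and the weighted block-mean CUDA kernel.

For reference, splits_mean_then_mean chops the (Hs, Ws) frequency grid into
an s-by-s grid of (Hs/s, Ws/s) tiles and averages the tiles elementwise. A
minimal PyTorch sketch of that reduction (hypothetical helper name, shapes
assumed (B, C, Hs, Ws)):

    import torch

    def splits_mean(a: torch.Tensor, s: int) -> torch.Tensor:
        B, C, H, W = a.shape
        # split each spatial dim into (tile index, position inside the tile)
        v = a.view(B, C, s, H // s, s, W // s)
        # move positions forward, collect the s*s tiles in a trailing dim
        p = v.permute(0, 1, 3, 5, 2, 4).reshape(B, C, H // s, W // s, s * s)
        # elementwise mean over the tiles
        return p.mean(-1)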
--- Converse2D/setup.py | 2 +- Converse2D/torch_converse2d/converse2d.cpp | 332 ++++++------------ .../torch_converse2d/converse2d_kernel.cu | 203 ----------- 3 files changed, 110 insertions(+), 427 deletions(-) delete mode 100644 Converse2D/torch_converse2d/converse2d_kernel.cu diff --git a/Converse2D/setup.py b/Converse2D/setup.py index d80a440..6b310cd 100644 --- a/Converse2D/setup.py +++ b/Converse2D/setup.py @@ -6,7 +6,7 @@ CPP = str(PKG_DIR / f"converse2d.cpp") -CU = str(PKG_DIR / f"converse2d_kernel.cu") +CU = str(PKG_DIR / f"converse2d.cu") has_cu = os.path.exists(CU) extra_cflags = ["-O3"] diff --git a/Converse2D/torch_converse2d/converse2d.cpp b/Converse2D/torch_converse2d/converse2d.cpp index bc7ebc8..b5c713d 100644 --- a/Converse2D/torch_converse2d/converse2d.cpp +++ b/Converse2D/torch_converse2d/converse2d.cpp @@ -3,272 +3,158 @@ #include #include #include -#include -#include +#include +#include #include #include #include #include -#include #include -#include -#include - -#include -#include -#include -#include using at::Tensor; -using at::indexing::Slice; - -void weighted_block_mean_kernel_launcher( - const at::Tensor &in, at::Tensor &out, - int B, int C, int Ho, int Wo_r, int s, int Hs, int Ws_r, - int Ws_full, long long total_out); - -void weighted_block_mean_grad_kernel_launcher( - const at::Tensor &grad_out, at::Tensor &grad_in, - int B, int C, int Ho, int Wo_r, int s, int Hs, int Ws_r, - int Ws_full, long long total_in); -struct WeightedBlockMeanFunction : public torch::autograd::Function +namespace { - static at::Tensor forward( - torch::autograd::AutogradContext *ctx, - const at::Tensor &input, int64_t Ws_full, int64_t s) - { - TORCH_CHECK(input.is_cuda() && input.dim() == 4, - "weighted_block_mean: input must be (B,C,Hs,Ws_r) CUDA Tensor"); - TORCH_CHECK(s >= 1, "weighted_block_mean: s must be >= 1"); - - if (s == 1) - { - ctx->saved_data["s"] = s; - return input; - } - - auto x = input.contiguous(); - const int B = (int)x.size(0); - const int C = (int)x.size(1); - const int Hs = (int)x.size(2); - const int Ws_r = (int)x.size(3); - - TORCH_CHECK(Hs % s == 0, "weighted_block_mean: H must be divisible by s"); - const int Ho = Hs / (int)s; - const int Wo_r = (Ws_r + (int)s - 1) / (int)s; // ceil_div for safety - - auto out = at::empty({B, C, Ho, Wo_r}, x.options()); - const long long total_out = out.numel(); - - weighted_block_mean_kernel_launcher( - x, out, B, C, Ho, Wo_r, (int)s, Hs, Ws_r, (int)Ws_full, total_out); - - ctx->save_for_backward({x}); - ctx->saved_data["s"] = s; - ctx->saved_data["Ws_full"] = Ws_full; - return out; - } - static torch::autograd::tensor_list backward( - torch::autograd::AutogradContext *ctx, - torch::autograd::tensor_list grad_outputs) + static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s) { - auto go = grad_outputs[0].contiguous(); - auto s = ctx->saved_data["s"].toInt(); - + TORCH_CHECK(s >= 1, "scale must be >= 1"); if (s == 1) - { - return {go, torch::Tensor(), torch::Tensor()}; - } - - auto Ws_full = ctx->saved_data["Ws_full"].toInt(); - auto saved = ctx->get_saved_variables(); - auto input = saved[0]; - - const int B = (int)input.size(0); - const int C = (int)input.size(1); - const int Hs = (int)input.size(2); - const int Ws_r = (int)input.size(3); - const int Ho = (int)go.size(2); - const int Wo_r = (int)go.size(3); - - auto grad_in = at::empty_like(input); - const long long total_in = input.numel(); - - weighted_block_mean_grad_kernel_launcher( - go, grad_in, B, C, Ho, Wo_r, (int)s, Hs, Ws_r, (int)Ws_full, 
-            total_in);
-
-        return {grad_in, torch::Tensor(), torch::Tensor()};
+            return x;
+        auto sizes = x.sizes().vec();
+        sizes[sizes.size() - 2] *= s;
+        sizes[sizes.size() - 1] *= s;
+        Tensor z = at::zeros(sizes, x.options());
+        z.index_put_(
+            {at::indexing::Slice(), at::indexing::Slice(),
+             at::indexing::Slice(0, z.size(-2), s),
+             at::indexing::Slice(0, z.size(-1), s)},
+            x);
+        return z;
     }
-};
-
-static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s)
-{
-    TORCH_CHECK(x.dim() == 4, "sfold_upsample expects (B,C,H,W)");
-    if (s == 1)
-        return x;
-    const auto B = x.size(0), C = x.size(1), H = x.size(2), W = x.size(3);
-    auto y = at::zeros({B, C, H * s, W * s}, x.options());
-    y.index_put_({Slice(), Slice(), Slice(0, c10::nullopt, s), Slice(0, c10::nullopt, s)}, x);
-    return y;
-}
-at::Tensor weighted_block_mean_cuda(const at::Tensor &input, int64_t Ws_full, int64_t s)
-{
-    return WeightedBlockMeanFunction::apply(input, Ws_full, s);
-}
-
-struct FBKey
-{
-    int64_t device_id;
-    at::ScalarType dtype;
-    int64_t channels;
-    int64_t H, W;
-    void *ptr;
-    bool operator==(const FBKey &other) const
+    static inline Tensor p2o(const Tensor &psf, int64_t H, int64_t W)
     {
-        return device_id == other.device_id && dtype == other.dtype &&
-               channels == other.channels && H == other.H && W == other.W &&
-               ptr == other.ptr;
+        TORCH_CHECK(psf.dim() == 4 && psf.size(0) == 1, "psf must be (1,C,kh,kw)");
+        auto C = psf.size(1);
+        auto kh = psf.size(2);
+        auto kw = psf.size(3);
+        Tensor otf = at::zeros({1, C, H, W}, psf.options());
+        otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf);
+        const int64_t sh = -static_cast<int64_t>(kh / 2);
+        const int64_t sw = -static_cast<int64_t>(kw / 2);
+        otf = at::roll(otf, {sh, sw}, {-2, -1});
+        return at::fft_fftn(otf, c10::optional<c10::IntArrayRef>({}), c10::optional<c10::IntArrayRef>({-2, -1}), c10::nullopt);
     }
-};
-namespace std
-{
-    template <>
-    struct hash<FBKey>
-    {
-        size_t operator()(const FBKey &k) const
-        {
-            return ((hash<int64_t>()(k.device_id) ^ hash<int64_t>()(k.channels)) << 1) ^
-                   ((hash<int64_t>()(k.H) ^ hash<int64_t>()(k.W)) << 1) ^
-                   ((hash<void *>()(k.ptr)) ^ hash<int>()(static_cast<int>(k.dtype)));
-        }
-    };
-}
-
-constexpr size_t FB_CACHE_MAX_SIZE = 64;
-static std::unordered_map<FBKey, std::pair<Tensor, Tensor>> fb_cache;
-static std::list<FBKey> fb_cache_lru;
-static std::mutex fb_cache_mutex;
-
-static inline std::pair<Tensor, Tensor> p2o_cached_rfft(const Tensor &psf, int64_t H, int64_t W)
-{
-    const bool training_with_grad = at::GradMode::is_enabled() && psf.requires_grad();
-
-    auto C = psf.size(1);
-    FBKey key{psf.device().index(), psf.scalar_type(), C, H, W, psf.data_ptr()};
-    if (!training_with_grad)
+    static inline Tensor splits_mean_then_mean(const Tensor &a, int64_t s)
     {
-        std::lock_guard<std::mutex> lock(fb_cache_mutex);
-        auto it = fb_cache.find(key);
-        if (it != fb_cache.end())
-        {
-            fb_cache_lru.remove(key);
-            fb_cache_lru.push_front(key);
-            return it->second;
-        }
+        TORCH_CHECK(a.dim() >= 2, "tensor must have spatial dims");
+        TORCH_CHECK(a.size(-2) % s == 0 && a.size(-1) % s == 0, "spatial not divisible by scale");
+
+        const auto &sizes = a.sizes();
+        const int64_t L = a.dim();
+        const int64_t W = sizes[L - 2];
+        const int64_t H = sizes[L - 1];
+        const int64_t W_s = W / s;
+        const int64_t H_s = H / s;
+
+        std::vector<int64_t> view_shape;
+        view_shape.reserve(L + 2);
+        for (int64_t i = 0; i < L - 2; ++i)
+            view_shape.push_back(sizes[i]);
+        view_shape.push_back(s);
+        view_shape.push_back(W_s);
+        view_shape.push_back(s);
+        view_shape.push_back(H_s);
+        Tensor v = a.view(view_shape);
+
+        std::vector<int64_t> perm;
+        perm.reserve(view_shape.size());
+        for (int64_t i = 0; i < L - 2; ++i)
+            perm.push_back(i);
+        perm.push_back(L - 2 + 1); // W_s
+        perm.push_back(L - 2 + 3); // H_s
+        perm.push_back(L - 2 + 0); // s
+        perm.push_back(L - 2 + 2); // s
+        Tensor p = v.permute(perm).contiguous();
+
+        std::vector<int64_t> merge_shape;
+        merge_shape.reserve(L + 1);
+        for (int64_t i = 0; i < L - 2; ++i)
+            merge_shape.push_back(p.size(i));
+        merge_shape.push_back(W_s);
+        merge_shape.push_back(H_s);
+        merge_shape.push_back(s * s);
+        Tensor r = p.view(merge_shape);
+
+        return r.mean(-1, /*keepdim=*/false);
     }
-
-    Tensor otf = at::zeros({1, C, H, W}, psf.options());
-    int64_t kh = psf.size(2), kw = psf.size(3);
-    otf.index_put_({0, Slice(), Slice(0, kh), Slice(0, kw)}, psf);
-    otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1});
-
-    Tensor FB = at::fft_rfftn(otf, c10::nullopt, {-2, -1}, c10::nullopt);
-    Tensor F2B = at::abs(FB).pow(2);
-
-    if (!training_with_grad)
+    Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps)
     {
-        std::lock_guard<std::mutex> lock(fb_cache_mutex);
-        fb_cache[key] = {FB, F2B};
-        fb_cache_lru.push_front(key);
-        if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE)
-        {
-            fb_cache.erase(fb_cache_lru.back());
-            fb_cache_lru.pop_back();
-        }
-    }
-    return {FB, F2B};
-}
+        TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)");
+        TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)");
+        TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)");
+        TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)");
+        TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device");
+        TORCH_CHECK(scale >= 1, "scale must be >= 1");
 
-Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps)
-{
-    TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)");
-    TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)");
-    TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)");
-    TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)");
-    TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device");
+        x = x.contiguous();
+        x0 = x0.contiguous();
+        weight = weight.contiguous();
+        bias = bias.contiguous();
 
-    x = x.contiguous();
-    x0 = x0.contiguous();
-    weight = weight.contiguous();
-    bias = bias.contiguous();
+        const int64_t B = x.size(0);
+        const int64_t C = x.size(1);
+        const int64_t H = x.size(2);
+        const int64_t W = x.size(3);
+        const int64_t Hs = H * scale;
+        const int64_t Ws = W * scale;
 
-    const int64_t B = x.size(0);
-    const int64_t C = x.size(1);
-    const int64_t H = x.size(2);
-    const int64_t W = x.size(3);
-    const int64_t Hs = H * scale;
-    const int64_t Ws = W * scale;
+        Tensor lambda_ = at::sigmoid(bias - 9.0) + eps;
 
-    Tensor lambda_ = at::sigmoid(bias - 9.0) + eps;
-    Tensor STy = sfold_upsample_zero_insertion(x, scale);
+        Tensor STy = sfold_upsample_zero_insertion(x, scale);
 
-    auto [FB, F2B] = p2o_cached_rfft(weight, Hs, Ws);
-    Tensor FBC = at::conj_physical(FB);
-    Tensor F_STy = at::fft_rfftn(STy, c10::nullopt, {-2, -1}, c10::nullopt);
-    Tensor F_lx0 = at::fft_rfftn(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt);
+        Tensor FB = p2o(weight, Hs, Ws);
+        Tensor FBC = at::conj_physical(FB);
+        Tensor F2B = at::abs(FB).pow(2.0);
 
-    Tensor FBFy = FBC * F_STy;
-    Tensor FR = FBFy + F_lx0;
+        Tensor F_STy = at::fft_fftn(STy, c10::optional<c10::IntArrayRef>({}), c10::optional<c10::IntArrayRef>({-2, -1}), c10::nullopt);
+        Tensor FBFy = FBC * F_STy;
+        Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::optional<c10::IntArrayRef>({}), c10::optional<c10::IntArrayRef>({-2, -1}), c10::nullopt);
 
-    Tensor x1 = FB * FR;
-    Tensor FBR = weighted_block_mean_cuda(x1, Ws, scale);
-    Tensor invW = weighted_block_mean_cuda(F2B, Ws, scale);
+        Tensor x1 = FB * FR;
 
-    Tensor invW_plus = invW + lambda_;
-    Tensor invWBR = FBR / invW_plus;
+        Tensor FBR = splits_mean_then_mean(x1, scale);
+        Tensor invW = splits_mean_then_mean(F2B, scale);
 
-    const int64_t Ws_rfft = Ws / 2 + 1;
-    const int64_t Ho = invWBR.size(2);
-    const int64_t Wo_r = invWBR.size(3);
+        Tensor invW_plus = invW + lambda_;
+        Tensor invWBR = FBR / invW_plus;
 
-    Tensor invWBR_exp = invWBR
-                            .view({B, C, Ho, 1, Wo_r, 1})
-                            .expand({B, C, Ho, (int64_t)scale, Wo_r, (int64_t)scale})
-                            .reshape({B, C, Ho * (int64_t)scale, Wo_r * (int64_t)scale});
+        Tensor invWBR_rep = invWBR.repeat({1, 1, scale, scale});
 
-    if (invWBR_exp.size(-2) != Hs || invWBR_exp.size(-1) != Ws_rfft)
-    {
-        using at::indexing::Slice;
-        invWBR_exp = invWBR_exp.index({Slice(), Slice(),
-                                       Slice(0, Hs), Slice(0, Ws_rfft)});
-    }
+        Tensor FCBinvWBR = FBC * invWBR_rep;
 
-    Tensor FCBinvWBR = FBC * invWBR_exp;
-    Tensor FX = (FR - FCBinvWBR) / lambda_;
-    Tensor out = at::fft_irfftn(FX, {Hs, Ws}, {-2, -1}, c10::nullopt);
-    return out;
-}
+        Tensor FX = (FR - FCBinvWBR) / lambda_;
+        Tensor out_c = at::fft_ifftn(FX, c10::optional<c10::IntArrayRef>({}), c10::optional<c10::IntArrayRef>({-2, -1}), c10::nullopt);
+        Tensor out = at::real(out_c);
+        (void)B;
+        (void)C;
+        (void)H;
+        (void)W;
+        return out;
+    }
 
-void clear_fb_cache()
-{
-    std::lock_guard<std::mutex> lock(fb_cache_mutex);
-    fb_cache.clear();
-    fb_cache_lru.clear();
 }
 
 TORCH_LIBRARY(converse2d, m)
 {
     m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor");
-    m.def("clear_cache() -> ()");
 }
 
 TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m)
 {
     m.impl("forward", TORCH_FN(converse2d_forward));
-    m.impl("clear_cache", TORCH_FN(clear_fb_cache));
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
diff --git a/Converse2D/torch_converse2d/converse2d_kernel.cu b/Converse2D/torch_converse2d/converse2d_kernel.cu
deleted file mode 100644
index cf9f625..0000000
--- a/Converse2D/torch_converse2d/converse2d_kernel.cu
+++ /dev/null
@@ -1,203 +0,0 @@
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/Dispatch.h>
-#include <c10/cuda/CUDAException.h>
-#include <c10/util/complex.h>
-
-// -------- accumulator mapping (keep numerics stable) ----------
-template <typename scalar_t>
-struct acc_type_map
-{
-    using type = float;
-};
-template <>
-struct acc_type_map<double>
-{
-    using type = double;
-};
-template <>
-struct acc_type_map<float>
-{
-    using type = float;
-};
-template <>
-struct acc_type_map<at::Half>
-{
-    using type = float;
-};
-template <>
-struct acc_type_map<at::BFloat16>
-{
-    using type = float;
-};
-template <>
-struct acc_type_map<c10::complex<float>>
-{
-    using type = c10::complex<float>;
-};
-template <>
-struct acc_type_map<c10::complex<double>>
-{
-    using type = c10::complex<double>;
-};
-template <>
-struct acc_type_map<c10::complex<at::Half>>
-{
-    using type = c10::complex<float>;
-};
-
-template <typename scalar_t>
-using acc_scalar_t = typename acc_type_map<scalar_t>::type;
-
-// ----------------------- forward kernel -----------------------
-template <typename scalar_t>
-__global__ void weighted_block_mean_kernel(
-    const scalar_t *__restrict__ in,  // (B,C,Hs,Ws_r)
-    scalar_t *__restrict__ out,       // (B,C,Ho,Wo_r)
-    int B, int C, int Ho, int Wo_r, int s, int Hs, int Ws_r,
-    int Ws_full, long long total_out)
-{
-    using acc_t = acc_scalar_t<scalar_t>;
-
-    long long idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= total_out)
-        return;
-
-    const int wo = static_cast<int>(idx % Wo_r);
-    const int ho = static_cast<int>((idx / Wo_r) % Ho);
-    const int c = static_cast<int>((idx / (1LL * Wo_r * Ho)) % C);
-    const int b = static_cast<int>(idx / (1LL * Wo_r * Ho * C));
-
-    const int hi0 = ho * s;
-    const int wi0 = wo * s;
-    const long long base_in = ((long long)b * C + c) * Hs * Ws_r;
-
-    acc_t acc_val = acc_t(0);
-    float acc_w = 0.0f;
-
-    // average over an s x s block with DC/Nyquist half-weights along last rFFT dim
-    for (int di = 0; di < s; ++di)
-    {
-        const int hi = hi0 + di;
-        if (hi >= Hs)
-            continue;
-        const long long row_off = base_in + (long long)hi * Ws_r;
-        for (int dj = 0; dj < s; ++dj)
-        {
-            const int wi = wi0 + dj;
-            if (wi >= Ws_r)
-                continue;
-
-            float w = 2.0f;
-            if (wi == 0)
-                w = 1.0f; // DC
-            if ((Ws_full % 2 == 0) && wi == (Ws_r - 1))
-                w = 1.0f; // Nyquist if even W
-
-            acc_val += static_cast<acc_t>(in[row_off + wi]) * static_cast<acc_t>(w);
-            acc_w += w;
-        }
-    }
-
-    const long long out_off = ((long long)b * C + c) * Ho * Wo_r + (long long)ho * Wo_r + wo;
-    out[out_off] = (acc_w > 1e-8f)
-                       ? static_cast<scalar_t>(acc_val / static_cast<acc_t>(acc_w))
-                       : static_cast<scalar_t>(0);
-}
-
-// ----------------------- backward kernel ----------------------
-template <typename scalar_t>
-__global__ void weighted_block_mean_grad_kernel(
-    const scalar_t *__restrict__ grad_out, // (B,C,Ho,Wo_r)
-    scalar_t *__restrict__ grad_in,        // (B,C,Hs,Ws_r)
-    int B, int C, int Ho, int Wo_r, int s, int Hs, int Ws_r,
-    int Ws_full, long long total_in)
-{
-    using acc_t = acc_scalar_t<scalar_t>;
-
-    long long idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= total_in)
-        return;
-
-    const int wi = static_cast<int>(idx % Ws_r);
-    const int hi = static_cast<int>((idx / Ws_r) % Hs);
-    const int c = static_cast<int>((idx / (1LL * Ws_r * Hs)) % C);
-    const int b = static_cast<int>(idx / (1LL * Ws_r * Hs * C));
-
-    const int ho = hi / s;
-    const int wo = wi / s;
-
-    // same weights as forward
-    float w = 2.0f;
-    if (wi == 0)
-        w = 1.0f;
-    if ((Ws_full % 2 == 0) && wi == (Ws_r - 1))
-        w = 1.0f;
-
-    // denominator over that output block
-    float denom = 0.0f;
-    for (int di = 0; di < s; ++di)
-    {
-        const int hi2 = ho * s + di;
-        if (hi2 >= Hs)
-            continue;
-        for (int dj = 0; dj < s; ++dj)
-        {
-            const int wi2 = wo * s + dj;
-            if (wi2 >= Ws_r)
-                continue;
-            float w2 = 2.0f;
-            if (wi2 == 0)
-                w2 = 1.0f;
-            if ((Ws_full % 2 == 0) && wi2 == (Ws_r - 1))
-                w2 = 1.0f;
-            denom += w2;
-        }
-    }
-    if (denom < 1e-8f)
-        denom = 1.0f;
-
-    const long long go_off = ((long long)b * C + c) * Ho * Wo_r + (long long)ho * Wo_r + wo;
-    const long long gi_off = ((long long)b * C + c) * Hs * Ws_r + (long long)hi * Ws_r + wi;
-
-    const acc_t go_val = static_cast<acc_t>(grad_out[go_off]);
-    const acc_t scale = static_cast<acc_t>(w / denom);
-    grad_in[gi_off] = static_cast<scalar_t>(go_val * scale);
-}
-
-// ----------------------- launchers ----------------------------
-void weighted_block_mean_kernel_launcher(
-    const at::Tensor &in, at::Tensor &out,
-    int B, int C, int Ho, int Wo_r, int s, int Hs, int Ws_r,
-    int Ws_full, long long total_out)
-{
-    const int threads = 256;
-    const int blocks = (int)((total_out + threads - 1) / threads);
-    auto stream = at::cuda::getCurrentCUDAStream();
-
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
-        at::kHalf, at::kBFloat16, in.scalar_type(),
-        "weighted_block_mean_fwd", [&]
-        { weighted_block_mean_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
-              in.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(),
-              B, C, Ho, Wo_r, s, Hs, Ws_r, Ws_full, total_out); });
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-
-void weighted_block_mean_grad_kernel_launcher(
-    const at::Tensor &grad_out, at::Tensor &grad_in,
-    int B, int C, int Ho, int Wo_r, int s, int Hs, int Ws_r,
-    int Ws_full, long long total_in)
-{
-    const int threads = 256;
-    const int blocks = (int)((total_in + threads - 1) / threads);
-    auto stream = at::cuda::getCurrentCUDAStream();
-
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
-        at::kHalf, at::kBFloat16, grad_out.scalar_type(),
-        "weighted_block_mean_bwd", [&]
-        { weighted_block_mean_grad_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
-              grad_out.data_ptr<scalar_t>(), grad_in.data_ptr<scalar_t>(),
-              B, C, Ho, Wo_r, s, Hs, Ws_r, Ws_full, total_in); });
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-}

From c7eb880f1361f46aae57c2a8369e5f116f97042d Mon Sep 17 00:00:00 2001
From: BoyceYi <1473416941@qq.com>
Date: Sun, 14 Sep 2025 10:31:36 +0800
Subject: [PATCH 22/22] Stable C++ only version

---
 Converse2D/README.md                       |  28 --
 Converse2D/torch_converse2d/converse2d.cpp | 305 +++++++++++++--------
 test/results_4090.csv                      |  73 -----
 test/test_speed.py                         |   7 +-
 4 files changed, 190 insertions(+), 223 deletions(-)
 delete mode 100644 test/results_4090.csv

diff --git a/Converse2D/README.md b/Converse2D/README.md
index 7508258..09e9d73 100644
--- a/Converse2D/README.md
+++ b/Converse2D/README.md
@@ -1,12 +1,3 @@
-### Pytorch Analysis
-
-![hotmap](../figs/pytorch_hotmap.png)
-
-**Bottleneck**
-
-- p2o
-- upsample
-
 ### Kernel Registry
 
 ---
@@ -16,11 +7,6 @@
 ```
 ├---v1 Translation from python to CPP
 ├---v2 Add FB/F2B cache & broadcast replace repeat
-├---v3 splits→permute→view→mean to block mean CUDA kernel
-├---v4 STy s-fold upsampler CUDA kernel
- ├---v5 Larger batched FFT CUDA kernel
- ├---v6 Eliminate redundant calculations of conj/abs/pow(2)
- ├---v7 R2C/C2R (Real FFT) replaces C2C
 ```
 
 **Tested Device**
@@ -29,10 +15,6 @@
 - NVIDIA RTX 4090
 - NVIDIA RTX 5060ti 16g
 
-Only **v7** left in main branch.
-
-If you want test more, switch to branch `dev`. We recommend you to run `test/test_speed.py` first to choose the most suitable backend for GPU.
-
 **Installation**
 
 ```python
@@ -49,13 +31,3 @@ import torch_converse2d
 out = torch.ops.converse2d.forward(x, x0, weight, bias, scale, eps)
 print(torch.ops.converse2d)
 ```
-
-**TODO**
-
-- [ ] Temporary Tensor Reuse and In-Place Writing
-- [X] Larger batched FFT(v5) **Note: not very useful**
-- [X] Eliminate redundant calculations of `conj/abs/pow(2)` (v6) *Note: not very useful**
-- [ ] The minimal necessary policy for `contiguous()`
-- [X] R2C/C2R (Real FFT) replaces C2C (v7) **(Optional)**
-- [ ] Mixed precision **(Optional)**
-- [ ] Adaptive padding **(Optional)**
diff --git a/Converse2D/torch_converse2d/converse2d.cpp b/Converse2D/torch_converse2d/converse2d.cpp
index b5c713d..251e740 100644
--- a/Converse2D/torch_converse2d/converse2d.cpp
+++ b/Converse2D/torch_converse2d/converse2d.cpp
@@ -1,3 +1,4 @@
+
 #include <torch/extension.h>
 #include <ATen/ATen.h>
 #include <ATen/Functions.h>
@@ -11,150 +12,218 @@
 #include <vector>
 #include <cmath>
+#include <unordered_map>
+#include <list>
+#include <mutex>
+#include <tuple>
+
 using at::Tensor;
 
-namespace
-{
-
-    static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s)
-    {
-        TORCH_CHECK(s >= 1, "scale must be >= 1");
-        if (s == 1)
-            return x;
-        auto sizes = x.sizes().vec();
-        sizes[sizes.size() - 2] *= s;
-        sizes[sizes.size() - 1] *= s;
-        Tensor z = at::zeros(sizes, x.options());
-        z.index_put_(
-            {at::indexing::Slice(), at::indexing::Slice(),
-             at::indexing::Slice(0, z.size(-2), s),
-             at::indexing::Slice(0, z.size(-1), s)},
-            x);
-        return z;
-    }
-
-    static inline Tensor p2o(const Tensor &psf, int64_t H, int64_t W)
-    {
-        TORCH_CHECK(psf.dim() == 4 && psf.size(0) == 1, "psf must be (1,C,kh,kw)");
-        auto C = psf.size(1);
-        auto kh = psf.size(2);
-        auto kw = psf.size(3);
-        Tensor otf = at::zeros({1, C, H, W}, psf.options());
-        otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf);
-        const int64_t sh = -static_cast<int64_t>(kh / 2);
-        const int64_t sw = -static_cast<int64_t>(kw / 2);
-        otf = at::roll(otf, {sh, sw}, {-2, -1});
-        return at::fft_fftn(otf, c10::optional<c10::IntArrayRef>({}), c10::optional<c10::IntArrayRef>({-2, -1}), c10::nullopt);
-    }
-
-    static inline Tensor splits_mean_then_mean(const Tensor &a, int64_t s)
-    {
-        TORCH_CHECK(a.dim() >= 2, "tensor must have spatial dims");
-        TORCH_CHECK(a.size(-2) % s == 0 && a.size(-1) % s == 0, "spatial not divisible by scale");
-
-        const auto &sizes = a.sizes();
-        const int64_t L = a.dim();
-        const int64_t W = sizes[L - 2];
-        const int64_t H = sizes[L - 1];
-        const int64_t W_s = W / s;
-        const int64_t H_s = H / s;
-
-        std::vector<int64_t> view_shape;
-        view_shape.reserve(L + 2);
-        for (int64_t i = 0; i < L - 2; ++i)
-            view_shape.push_back(sizes[i]);
-        view_shape.push_back(s);
-        view_shape.push_back(W_s);
-        view_shape.push_back(s);
-        view_shape.push_back(H_s);
-        Tensor v = a.view(view_shape);
-
-        std::vector<int64_t> perm;
-        perm.reserve(view_shape.size());
-        for (int64_t i = 0; i < L - 2; ++i)
-            perm.push_back(i);
-        perm.push_back(L - 2 + 1); // W_s
-        perm.push_back(L - 2 + 3); // H_s
-        perm.push_back(L - 2 + 0); // s
-        perm.push_back(L - 2 + 2); // s
-        Tensor p = v.permute(perm).contiguous();
-
-        std::vector<int64_t> merge_shape;
-        merge_shape.reserve(L + 1);
-        for (int64_t i = 0; i < L - 2; ++i)
-            merge_shape.push_back(p.size(i));
-        merge_shape.push_back(W_s);
-        merge_shape.push_back(H_s);
-        merge_shape.push_back(s * s);
-        Tensor r = p.view(merge_shape);
-
-        return r.mean(-1, /*keepdim=*/false);
-    }
-
-    Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps)
-    {
-        TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)");
-        TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)");
-        TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)");
-        TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)");
-        TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device");
-        TORCH_CHECK(scale >= 1, "scale must be >= 1");
-
-        x = x.contiguous();
-        x0 = x0.contiguous();
-        weight = weight.contiguous();
-        bias = bias.contiguous();
-
-        const int64_t B = x.size(0);
-        const int64_t C = x.size(1);
-        const int64_t H = x.size(2);
-        const int64_t W = x.size(3);
-        const int64_t Hs = H * scale;
-        const int64_t Ws = W * scale;
-
-        Tensor lambda_ = at::sigmoid(bias - 9.0) + eps;
-
-        Tensor STy = sfold_upsample_zero_insertion(x, scale);
-
-        Tensor FB = p2o(weight, Hs, Ws);
-        Tensor FBC = at::conj_physical(FB);
-        Tensor F2B = at::abs(FB).pow(2.0);
-
-        Tensor F_STy = at::fft_fftn(STy, c10::optional<c10::IntArrayRef>({}), c10::optional<c10::IntArrayRef>({-2, -1}), c10::nullopt);
-        Tensor FBFy = FBC * F_STy;
-        Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::optional<c10::IntArrayRef>({}), c10::optional<c10::IntArrayRef>({-2, -1}), c10::nullopt);
-
-        Tensor x1 = FB * FR;
-
-        Tensor FBR = splits_mean_then_mean(x1, scale);
-        Tensor invW = splits_mean_then_mean(F2B, scale);
-
-        Tensor invW_plus = invW + lambda_;
-        Tensor invWBR = FBR / invW_plus;
-
-        Tensor invWBR_rep = invWBR.repeat({1, 1, scale, scale});
-
-        Tensor FCBinvWBR = FBC * invWBR_rep;
-
-        Tensor FX = (FR - FCBinvWBR) / lambda_;
-        Tensor out_c = at::fft_ifftn(FX, c10::optional<c10::IntArrayRef>({}), c10::optional<c10::IntArrayRef>({-2, -1}), c10::nullopt);
-        Tensor out = at::real(out_c);
-        (void)B;
-        (void)C;
-        (void)H;
-        (void)W;
-        return out;
-    }
-
+struct FBKey
+{
+    int64_t device_id;
+    at::ScalarType dtype;
+    int64_t channels;
+    int64_t H, W;
+    void *ptr;
+
+    bool operator==(const FBKey &other) const
+    {
+        return device_id == other.device_id && dtype == other.dtype &&
+               channels == other.channels && H == other.H && W == other.W &&
+               ptr == other.ptr;
+    }
+};
+
+namespace std
+{
+    template <>
+    struct hash<FBKey>
+    {
+        size_t operator()(const FBKey &k) const
+        {
+            return ((hash<int64_t>()(k.device_id) ^ hash<int64_t>()(k.channels)) << 1) ^
+                   ((hash<int64_t>()(k.H) ^ hash<int64_t>()(k.W)) << 1) ^
+                   ((hash<void *>()(k.ptr)) ^ hash<int>()(static_cast<int>(k.dtype)));
+        }
+    };
+}
+
+constexpr size_t FB_CACHE_MAX_SIZE = 64;
+
+static std::unordered_map<FBKey, std::tuple<Tensor, Tensor, Tensor>> fb_cache;
+static std::list<FBKey> fb_cache_lru;
+static std::mutex fb_cache_mutex;
+
+static inline std::tuple<Tensor, Tensor, Tensor> p2o_cached(const Tensor &psf, int64_t H, int64_t W)
+{
+    const bool training_with_grad = at::GradMode::is_enabled() && psf.requires_grad();
+    auto C = psf.size(1);
+    FBKey key{
+        psf.device().index(),
+        psf.scalar_type(),
+        C, H, W,
+        psf.data_ptr()};
+
+    if (!training_with_grad)
+    {
+        std::lock_guard<std::mutex> lock(fb_cache_mutex);
+        auto it = fb_cache.find(key);
+        if (it != fb_cache.end())
+        {
+            fb_cache_lru.remove(key);
+            fb_cache_lru.push_front(key);
+            return it->second;
+        }
+    }
+
+    Tensor otf = at::zeros({1, C, H, W}, psf.options());
+    int64_t kh = psf.size(2), kw = psf.size(3);
+    otf.index_put_({0, at::indexing::Slice(), at::indexing::Slice(0, kh), at::indexing::Slice(0, kw)}, psf);
+    otf = at::roll(otf, {-kh / 2, -kw / 2}, {-2, -1});
+    Tensor FB = at::fft_fftn(otf, c10::nullopt, {-2, -1}, c10::nullopt);
+    Tensor FBC = at::conj_physical(FB);
+    Tensor F2B = at::abs(FB).pow(2);
+
+    if (!training_with_grad)
+    {
+        std::lock_guard<std::mutex> lock(fb_cache_mutex);
+        fb_cache[key] = std::make_tuple(FB, FBC, F2B);
+        fb_cache_lru.push_front(key);
+
+        if (fb_cache_lru.size() > FB_CACHE_MAX_SIZE)
+        {
+            fb_cache.erase(fb_cache_lru.back());
+            fb_cache_lru.pop_back();
+        }
+    }
+
+    return std::make_tuple(FB, FBC, F2B);
+}
+
+static inline Tensor sfold_upsample_zero_insertion(const Tensor &x, int64_t s)
+{
+    TORCH_CHECK(s >= 1, "scale must be >= 1");
+    if (s == 1)
+        return x;
+    auto sizes = x.sizes().vec();
+    sizes[sizes.size() - 2] *= s;
+    sizes[sizes.size() - 1] *= s;
+    Tensor z = at::zeros(sizes, x.options());
+    z.index_put_(
+        {at::indexing::Slice(), at::indexing::Slice(),
+         at::indexing::Slice(0, z.size(-2), s),
+         at::indexing::Slice(0, z.size(-1), s)},
+        x);
+    return z;
+}
+
+static inline Tensor splits_mean_then_mean(const Tensor &a, int64_t s)
+{
+    TORCH_CHECK(a.dim() >= 2, "tensor must have spatial dims");
+    TORCH_CHECK(a.size(-2) % s == 0 && a.size(-1) % s == 0, "spatial not divisible by scale");
+
+    const auto &sizes = a.sizes();
+    const int64_t L = a.dim();
+    const int64_t W = sizes[L - 2];
+    const int64_t H = sizes[L - 1];
+    const int64_t W_s = W / s;
+    const int64_t H_s = H / s;
+
+    std::vector<int64_t> view_shape;
+    view_shape.reserve(L + 2);
+    for (int64_t i = 0; i < L - 2; ++i)
+        view_shape.push_back(sizes[i]);
+    view_shape.push_back(s);
+    view_shape.push_back(W_s);
+    view_shape.push_back(s);
+    view_shape.push_back(H_s);
+    Tensor v = a.view(view_shape);
+
+    std::vector<int64_t> perm;
+    perm.reserve(view_shape.size());
+    for (int64_t i = 0; i < L - 2; ++i)
+        perm.push_back(i);
+    perm.push_back(L - 2 + 1);
+    perm.push_back(L - 2 + 3);
+    perm.push_back(L - 2 + 0);
+    perm.push_back(L - 2 + 2);
+    Tensor p = v.permute(perm).contiguous();
+
+    std::vector<int64_t> merge_shape;
+    merge_shape.reserve(L + 1);
+    for (int64_t i = 0; i < L - 2; ++i)
+        merge_shape.push_back(p.size(i));
+    merge_shape.push_back(W_s);
+    merge_shape.push_back(H_s);
+    merge_shape.push_back(s * s);
+    Tensor r = p.view(merge_shape);
+
+    return r.mean(-1, /*keepdim=*/false);
+}
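+// Derivation sketch for the forward pass below (assuming periodic boundary
+// conditions; cf. the closed-form l2-l2 solution with s-fold downsampling,
+// Zhao et al. 2016):
+//
+//   FR  = conj(FB) .* FFT(S^T y) + FFT(lambda .* x0)
+//   out = real( IFFT( ( FR - conj(FB) .* tile_s( blockmean_s(FB .* FR)
+//                       / (blockmean_s(|FB|^2) + lambda) ) ) / lambda ) )
+//
+// where S^T is zero-insertion upsampling (sfold_upsample_zero_insertion),
+// blockmean_s is the s x s distinct-block mean (splits_mean_then_mean), and
+// tile_s is its s x s spatial tiling (the repeat in the code below).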
+Tensor converse2d_forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int64_t scale, double eps)
+{
+    TORCH_CHECK(x.dim() == 4, "x must be (B,C,H,W)");
+    TORCH_CHECK(x0.dim() == 4, "x0 must be (B,C,Hs,Ws)");
+    TORCH_CHECK(weight.dim() == 4 && weight.size(0) == 1, "weight must be (1,C,kh,kw)");
+    TORCH_CHECK(bias.dim() == 4 && bias.size(0) == 1 && bias.size(2) == 1 && bias.size(3) == 1, "bias must be (1,C,1,1)");
+    TORCH_CHECK(x.device() == x0.device() && x.device() == weight.device() && x.device() == bias.device(), "tensors on same device");
+    TORCH_CHECK(scale >= 1, "scale must be >= 1");
+
+    x = x.contiguous();
+    x0 = x0.contiguous();
+    weight = weight.contiguous();
+    bias = bias.contiguous();
+
+    const int64_t B = x.size(0);
+    const int64_t C = x.size(1);
+    const int64_t H = x.size(2);
+    const int64_t W = x.size(3);
+    const int64_t Hs = H * scale;
+    const int64_t Ws = W * scale;
+
+    Tensor lambda_ = at::sigmoid(bias - 9.0) + eps;
+    Tensor STy = sfold_upsample_zero_insertion(x, scale);
+
+    auto [FB, FBC, F2B] = p2o_cached(weight, Hs, Ws);
+
+    Tensor F_STy = at::fft_fftn(STy, c10::nullopt, {-2, -1}, c10::nullopt);
+    Tensor FBFy = FBC * F_STy;
+    Tensor FR = FBFy + at::fft_fftn(lambda_ * x0, c10::nullopt, {-2, -1}, c10::nullopt);
+
+    Tensor x1 = FB * FR;
+    Tensor FBR = splits_mean_then_mean(x1, scale);
+    Tensor invW = splits_mean_then_mean(F2B, scale);
+
+    Tensor invW_plus = invW + lambda_;
+    Tensor invWBR = FBR / invW_plus;
+
+    Tensor invWBR_rep = invWBR.repeat({1, 1, scale, scale});
+    Tensor FCBinvWBR = FBC * invWBR_rep;
+
+    Tensor FX = (FR - FCBinvWBR) / lambda_;
+    Tensor out_c = at::fft_ifftn(FX, c10::nullopt, {-2, -1}, c10::nullopt);
+    Tensor out = at::real(out_c);
+    return out;
+}
+void clear_fb_cache()
+{
+    std::lock_guard<std::mutex> lock(fb_cache_mutex);
+    fb_cache.clear();
+    fb_cache_lru.clear();
 }
 
 TORCH_LIBRARY(converse2d, m)
 {
     m.def("forward(Tensor x, Tensor x0, Tensor weight, Tensor bias, int scale, float eps=1e-5) -> Tensor");
+    m.def("clear_cache() -> ()");
 }
 
 TORCH_LIBRARY_IMPL(converse2d, CompositeImplicitAutograd, m)
 {
     m.impl("forward", TORCH_FN(converse2d_forward));
+    m.impl("clear_cache", TORCH_FN(clear_fb_cache));
 }
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
\ No newline at end of file
diff --git a/test/results_4090.csv b/test/results_4090.csv
deleted file mode 100644
index e01b4d1..0000000
--- a/test/results_4090.csv
+++ /dev/null
@@ -1,73 +0,0 @@
-variant,B,C,H,W,scale,ksize,fwd_mean_ms,fwd_p50_ms,fwd_p90_ms,fwd_tp,fwd_speedup_vs_pytorch,bwd_mean_ms,bwd_p50_ms,bwd_p90_ms,bwd_tp,bwd_speedup_vs_pytorch,grad_ok,eps,warmup,dtype,device,iters
-pytorch,8,8,256,256,1,3,9.288488607853651,8.991222945041955,9.583412948995829,0.05644492038852277,,20.064111375249922,19.56672000233084,21.39674376230687,0.02613063644805796,,True,1e-05,10,float32,cuda,50
-cuda_v1,8,8,256,256,1,3,3.3021508576348424,3.324655001051724,3.3440517028793693,0.158771668104685,2.8128601654821153,6.695929053239524,6.69370440300554,6.72081895172596,0.07829951539679872,2.9964641524304687,True,1e-05,10,float32,cuda,50
-cuda_v2,8,8,256,256,1,3,3.1805392680689692,3.1820274889469147,3.1898362562060356,0.16484248607259483,2.920413120223181,6.543120900169015,6.540277507156134,6.558541930280626,0.08012812356660827,3.066443625507768,True,1e-05,10,float32,cuda,50
-cuda_v3,8,8,256,256,1,3,3.024693327024579,3.04701691493392,3.0670031206682324,0.17333591981562882,3.070886071280102,6.397886727936566,6.395927630364895,6.411892082542181,0.08194705881719984,3.1360529231690477,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,1,3,3.0596842663362622,3.0616489239037037,3.07619022205472,0.17135362814013316,3.0357670266990997,6.399778164923191,6.40075805131346,6.410616356879473,0.08192283958740817,3.1351260712785547,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,1,3,3.0559819284826517,3.0527031049132347,3.0764779541641474,0.17156122394359777,3.0394448740950337,6.409975425340235,6.408171029761434,6.429016985930502,0.08179251326414738,3.1301385799283215,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,1,3,3.0050428677350283,3.0304675456136465,3.052515466697514,0.17446939131193434,3.090967089882017,6.387096201069653,6.387264467775822,6.399212614633143,0.08208550231515176,3.141351052750664,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,1,3,1.8821742571890354,1.8811896443367004,1.8879902781918645,0.27855444202228474,4.934978029996914,5.040463833138347,5.040465504862368,5.048878467641771,0.10401582420909114,3.9806081423180046,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,1,5,7.326102494262159,7.420093985274434,7.613263255916536,0.07156438234526817,,16.14620674867183,16.1027709254995,16.263780114240944,0.03247127998302929,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,1,5,3.2715308107435703,3.3216774463653564,3.3417833968997,0.16025769901914425,2.2393499918153132,6.707657393999398,6.701117963530123,6.731993076391518,0.07816260867304027,2.407130507756115,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,1,5,3.159591546282172,3.1798079144209623,3.199100005440414,0.16593537244297263,2.318686572915615,6.542924279347062,6.542460061609745,6.554703111760318,0.08013053148955596,2.467735535261782,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,1,5,3.05445936974138,3.055477049201727,3.073208383284509,0.1716467422005326,2.398494007429687,6.400615535676479,6.398052908480167,6.4109522849321365,0.08191212190103653,2.5226021870356474,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,1,5,3.0754889361560345,3.0726579716429114,3.091029403731227,0.17047305676712743,2.3820935943339285,6.441063601523638,6.430621026083827,6.473009032197297,0.08139773683898717,2.5067609555745474,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,1,5,3.048403072170913,3.052991349250078,3.072553128004074,0.17198775476453956,2.403259123159488,6.407246896997094,6.406409083865583,6.420211470685899,0.08182734463466983,2.5199913485836056,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,1,5,2.9911579517647624,3.0331225134432316,3.045725799165666,0.1752792759374923,2.449252969051604,6.385393836535513,6.382934865541756,6.403305334970355,0.08210738654836988,2.5286156440793923,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,1,5,1.8822504533454776,1.8811115296557546,1.8892435124143958,0.2785431657451006,3.8922038675782376,5.046942573972046,5.043365992605686,5.0634910352528095,0.10388229949432032,3.1992055609946117,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,1,7,9.156141332350671,9.012500522658229,10.730735887773335,0.057260802446066954,,18.346093399450183,18.29959498718381,18.474590103141963,0.02857763713422022,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,1,7,3.3304048283025622,3.330954583361745,3.3437674399465322,0.15742470571279427,2.7492577642632607,6.747524798847735,6.6985663725063205,6.975994328968227,0.07770078890106961,2.718936787395449,True,1e-05,10,float32,cuda,50 
-cuda_v2,8,8,256,256,1,7,3.1805677665397525,3.180499654263258,3.187753399834037,0.16484100905367305,2.8787757420782603,6.5445497911423445,6.542496965266764,6.557909771800041,0.0801106289556529,2.803262865274632,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,1,7,2.9923936538398266,2.959055360406637,3.0513782519847155,0.17520689476374068,3.0598050896818174,6.439674673601985,6.414378061890602,6.499357661232352,0.08141529294161429,2.8489161843308506,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,1,7,3.0570596596226096,3.056291490793228,3.0761950882151723,0.17150074201192486,2.9950810097965133,6.428387816995382,6.409163004718721,6.460378412157297,0.08155824056132495,2.853918264070309,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,1,7,3.0619724839925766,3.061673021875322,3.0807194765657187,0.17122557525937293,2.9902755103833485,6.415656288154423,6.413351511582732,6.4297555247321725,0.081720088553999,2.8595817131481276,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,1,7,3.042473620735109,3.0388199957087636,3.0632448382675648,0.17232294026375944,3.0094398419593875,6.388430343940854,6.389048998244107,6.399145186878741,0.08206835979627831,2.871768558430108,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,1,7,1.8896355107426643,1.883591990917921,1.9149431493133307,0.2774545657188377,4.845453676276503,5.0843368750065565,5.057928501628339,5.145174008794129,0.10311826554555825,3.6083551996004446,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,2,3,45.595936393365264,45.675908448174596,46.11496950965375,0.045994274180651766,,80.62302918639034,80.68823453504592,81.43389630131423,0.02601182343510869,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,2,3,11.85663956683129,11.822210508398712,11.90752275288105,0.17687574866210323,3.8456036498671153,25.430614417418838,25.363861583173275,25.43278068769723,0.08246564418685631,3.170313853336047,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,2,3,11.422737971879542,11.390020838007331,11.460354528389871,0.1835945116803661,3.9916818984741824,24.925604569725692,24.912420543842018,24.930561799556017,0.08413645470999621,3.2345465868584786,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,2,3,10.840446311049163,10.825388482771814,10.85302964784205,0.19345624154445284,4.206094019108,24.0234538866207,24.002613499760628,24.026684556156397,0.08729602370656452,3.3560132346869342,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,2,3,10.878428206779063,10.826261015608907,11.14368592388928,0.1927807915019494,4.191408494560955,23.699625409208238,23.694894975051284,23.71811883058399,0.0884888247721069,3.4018693457941795,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,2,3,10.918543636798859,10.834143031388521,11.191598395816982,0.1920725025022523,4.176008990768045,23.730977713130414,23.69663491845131,23.714118194766343,0.08837191730367012,3.3973749485164024,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,2,3,10.793533944524825,10.768141131848097,10.864380979910493,0.19429706811306324,4.224375133085532,23.638727851212025,23.631693911738694,23.66107157431543,0.08871678768840655,3.4106331649423582,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,2,3,7.474436634220183,7.435761974193156,7.621926814317703,0.2805765976259157,6.100250577363058,20.239599947817624,20.2378734247759,20.263943052850664,0.10361627726866854,3.9834299785694967,True,1e-05,10,float32,cuda,50 
-pytorch,8,8,256,256,2,5,38.853937541134655,38.41745341196656,40.494335955008864,0.05397527593644133,,66.22083507943898,66.26046146266162,67.06938743591309,0.031669066049744635,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,2,5,11.842856714501977,11.828812537714839,11.868507391773164,0.17708159868488205,3.2807909846242334,25.389751312322915,25.350986630655825,25.415705051273108,0.08259836712075817,2.608171866862623,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,2,5,11.411372288130224,11.404957505874336,11.43825901672244,0.1837773711213853,3.4048435683365956,24.930950277484953,24.917138973250985,24.93941232096404,0.08411841412615266,2.656169714446977,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,2,5,10.863104425370693,10.828325641341507,11.124962358735502,0.19305273316733643,3.5766882117409824,24.023548257537186,24.010553024709225,24.034774862229824,0.08729568078445848,2.756496849238861,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,2,5,10.90485557448119,10.82881959155202,11.15186910610646,0.19231359697304146,3.562994234610317,23.72426231391728,23.71151139959693,23.792543495073915,0.08839693189405325,2.7912705652639866,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,2,5,10.839254357852042,10.825947392731905,10.844685533083975,0.19347751522048257,3.584558149333266,23.702450119890273,23.698252509348094,23.726728814654052,0.08847827922397537,2.7938392336861733,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,2,5,10.759746134281158,10.75667655095458,10.77785084489733,0.19490720076734483,3.6110459351214446,23.636622526682913,23.63265841268003,23.660433711484075,0.08872468973232393,2.8016200286095696,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,2,5,7.5289920112118125,7.6012545032426715,7.6225581811740994,0.2785435283869371,5.160576274124775,20.23725756444037,20.235952455550432,20.263770665042102,0.10362827044732498,3.2722237619685215,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,2,7,43.63379757385701,43.68262959178537,44.327281741425395,0.048062559680950134,,75.68363416008651,75.79991698730737,76.05113848112524,0.02770945163077252,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,2,7,11.998974299058318,11.853986419737339,12.264510500244796,0.1747776057962375,3.6364606245786693,25.40259242989123,25.356607511639595,25.428785989060998,0.08255661329795148,2.9793665496530024,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,2,7,11.41712686046958,11.413594475015998,11.448350944556296,0.18368474184701714,3.821784421519722,24.922909317538142,24.915332440286875,24.959108070470393,0.0841455535258977,3.036709446550378,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,2,7,10.88796194177121,10.827695950865746,11.144098523072898,0.19261198847089686,4.007526643389317,24.021303891204298,24.00768455117941,24.02521837502718,0.08730383702309759,3.1506880102290795,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,2,7,10.852394071407616,10.832656407728791,10.89512084145099,0.1932432591556259,4.020660997633443,23.70872948784381,23.69790489319712,23.748202505521476,0.08845484533767504,3.192226483451668,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,2,7,10.89978810865432,10.831900988705456,11.149773467332125,0.19240300628733165,4.003178514930233,23.71836454141885,23.704793537035584,23.77322791144252,0.08841891254086219,3.190929713046693,True,1e-05,10,float32,cuda,50 
-cuda_v6,8,8,256,256,2,7,10.774082760326564,10.765377548523247,10.784391174092889,0.19464784582148828,4.049885131245684,23.65874081850052,23.644065018743277,23.685648757964373,0.08864174201359364,3.198971354422374,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,2,7,7.470769472420216,7.435820531100035,7.61996959336102,0.2807143237041433,5.840602863592509,20.250088809989393,20.24794602766633,20.27125945314765,0.10356260753609496,3.7374470240718987,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,3,3,101.28841781057417,101.60925146192312,107.4452179018408,0.04658570152437898,,175.37028575781733,175.02894543576986,177.75019966065884,0.026906450996586024,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,3,3,26.66485572233796,26.59923385363072,27.004664088599384,0.17695921737341677,3.798573630598038,58.96571567747742,58.67917649447918,60.70718690752983,0.08002263596373708,2.974105948565666,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,3,3,25.672803041525185,25.64728946890682,25.70093993563205,0.18379730457822557,3.9453587380677684,56.8397456035018,56.80113600101322,56.884816801175475,0.08301571285901913,3.085346070708189,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,3,3,24.374681571498513,24.3739370489493,24.404809717088938,0.19358579049161745,4.155476555189605,53.81168487481773,53.78810840193182,53.942976240068674,0.08768712615070265,3.2589629216364755,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,3,3,24.441681993193924,24.361361982300878,24.429271719418466,0.1930551261289607,4.144085412729743,53.07756224647164,53.07193798944354,53.1614552019164,0.08889993813371996,3.3040380593114964,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,3,3,24.375403551384807,24.365137447603047,24.41326177213341,0.19358005663590047,4.155353473309771,53.05147184524685,53.03926358465105,53.125715954229236,0.08894365859941288,3.305662965758595,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,3,3,24.276557215489447,24.22345709055662,24.269587476737797,0.1943682523891544,4.17227273667734,52.906081513501704,52.899461006745696,52.96852015890181,0.08918808320355022,3.314747203741834,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,3,3,16.926095513626933,16.923529445193708,16.972190444357693,0.2787761652532999,5.9841572871326685,46.062268051318824,46.03102651890367,46.21587647125125,0.10243941949933792,3.807243828341975,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,3,5,83.52379720192403,83.7456090375781,85.80605688039213,0.05649398324878007,,145.5188752664253,145.75104543473572,148.14186077564955,0.03242597904471772,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,3,5,26.591138602234423,26.581451063975692,26.635813480243087,0.17744979147314519,3.1410387667606536,57.973297983407974,57.822484406642616,58.44167503528297,0.08139250593179065,2.510101724902264,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,3,5,25.637969239614904,25.634926394559443,25.683102500624955,0.1840470263420472,3.2578164214685903,56.7909628059715,56.76805193070322,56.88152206130326,0.08308702242152946,2.562359714935563,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,3,5,24.41519634798169,24.391590617597103,24.450342124328017,0.19326455264776388,3.4209758550161573,53.789559081196785,53.765413467772305,53.912423201836646,0.08772319536728602,2.705336830271478,True,1e-05,10,float32,cuda,50 
-cuda_v4,8,8,256,256,3,5,24.383281343616545,24.362614494748414,24.392798193730414,0.19351751446018198,3.4254535320690236,53.027258049696684,53.017987054772675,53.09310944285244,0.08898427287297746,2.744227791865955,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,3,5,24.393998621962965,24.367429548874497,24.428526940755546,0.19343249432471676,3.423948590647372,53.11096484772861,53.071493515744805,53.206104156561196,0.08884402709550474,2.73990268645343,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,3,5,24.267674176953733,24.23390606418252,24.288278119638562,0.194439399737825,3.441771823409596,52.94822241179645,52.909665973857045,53.120820759795606,0.08911709940518672,2.7483240916885783,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,3,5,17.1291301259771,16.946400050073862,18.221904919482768,0.2754717820050909,4.876126025527497,46.0351287573576,46.020424575544894,46.14793793298304,0.1024998110653888,3.1610398231626027,True,1e-05,10,float32,cuda,50 -pytorch,8,8,256,256,3,7,90.59936547186226,89.40491802059114,93.76176362857223,0.05208195416628463,,158.36064806673676,158.1671329913661,160.4503425071016,0.02979649336880384,,True,1e-05,10,float32,cuda,50 -cuda_v1,8,8,256,256,3,7,26.627318738028407,26.58525703009218,26.68023353908211,0.17720867979324692,3.4024967501692434,58.71118636801839,58.641764568164945,60.120047559030354,0.08036955632990492,2.6972823726314648,True,1e-05,10,float32,cuda,50 -cuda_v2,8,8,256,256,3,7,25.703592961654067,25.652661453932524,26.049211691133678,0.18357713674658,3.5247743615852856,56.822463613934815,56.79537542164326,56.92225047387183,0.0830409612659392,2.786937383473484,True,1e-05,10,float32,cuda,50 -cuda_v3,8,8,256,256,3,7,24.43074162583798,24.387317011132836,24.824217706918716,0.1931415784369645,3.7084165048859696,53.77089409157634,53.76591603271663,53.81281706504524,0.08775364590300178,2.9450997745552696,True,1e-05,10,float32,cuda,50 -cuda_v4,8,8,256,256,3,7,24.416185086593032,24.371870909817517,24.417825369164348,0.1932567263585738,3.7106274035254803,53.0829459335655,53.052632487379014,53.19422874599695,0.0888909218773469,2.983267889180994,True,1e-05,10,float32,cuda,50 -cuda_v5,8,8,256,256,3,7,24.362534414976835,24.361073039472103,24.400543165393174,0.19368231234182484,3.718798870784413,53.057269509881735,53.02855698391795,53.15046065952629,0.0889339395635725,2.984711605583899,True,1e-05,10,float32,cuda,50 -cuda_v6,8,8,256,256,3,7,24.26755757071078,24.222337524406612,24.37746999785304,0.1944403340241791,3.733353272486279,52.89180411491543,52.88355250377208,52.94667130801827,0.08921215827216153,2.9940489025988666,True,1e-05,10,float32,cuda,50 -cuda_v7,8,8,256,256,3,7,16.87394628766924,16.87477866653353,16.914236755110323,0.27963772786500724,5.369186551107395,46.165161593817174,46.10549903009087,46.407426544465125,0.10221110112245234,3.4303063738857515,True,1e-05,10,float32,cuda,50 diff --git a/test/test_speed.py b/test/test_speed.py index 75fb61f..7991072 100644 --- a/test/test_speed.py +++ b/test/test_speed.py @@ -110,9 +110,8 @@ def call_bwd(): return from torch.utils.cpp_extension import load - vnum = int(args.variant.split("_v")[1]) - cpp = PKG / f"converse2d_v{vnum}.cpp" - cu = PKG / f"converse2d_v{vnum}.cu" + cpp = PKG / f"converse2d.cpp" + cu = PKG / f"converse2d.cu" sources = [str(cpp)] if cu.exists(): sources.append(str(cu)) @@ -122,7 +121,7 @@ def call_bwd(): os.environ.setdefault("TORCH_CUDA_ARCH_LIST", f"{arch_str}+PTX") extra_cuda = ["-O3", f"-gencode=arch=compute_{arch_num},code=sm_{arch_num}"] if (cu.exists() and device=="cuda") else [] - 
ext_name = f"converse2d_v{vnum}_sm{arch_num}_ext" + ext_name = f"converse2d_sm{arch_num}_ext" print(f"[build] compiling {ext_name} (variant={args.variant}) ...", flush=True) load(name=ext_name, sources=sources, verbose=False, extra_cflags=["-O3"], extra_cuda_cflags=extra_cuda)