Skip to content

Commit

Permalink
[Pytorch] Add Vulkan support for aten::unsqueeze for 2d to 3d (pytorch#101719)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#101719

Unsqueeze operator: https://pytorch.org/docs/stable/generated/torch.unsqueeze.html#torch.unsqueeze

Test Plan:
Unsqueeze tests:
https://www.internalfb.com/phabricator/paste/view/P738187802
```
lfq@lfq-mbp fbsource % buck run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*unsqueeze*"
Downloaded 0/2 artifacts, 0.00 bytes, 100.0% cache miss (for updated rules)
Building: finished in 15.0 sec (100%) 455/455 jobs, 2/455 updated
  Total time: 15.0 sec
BUILD SUCCEEDED
Running main() from xplat/third-party/gmock/googletest-1.12.1/googletest/src/gtest_main.cc
Note: Google Test filter = *unsqueeze*
[==========] Running 3 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 3 tests from VulkanAPITest
[ RUN      ] VulkanAPITest.unsqueeze_dim0
[       OK ] VulkanAPITest.unsqueeze_dim0 (96 ms)
[ RUN      ] VulkanAPITest.unsqueeze_dim1
[       OK ] VulkanAPITest.unsqueeze_dim1 (2 ms)
[ RUN      ] VulkanAPITest.unsqueeze_dim2
[       OK ] VulkanAPITest.unsqueeze_dim2 (3 ms)
[----------] 3 tests from VulkanAPITest (101 ms total)
[----------] Global test environment tear-down
[==========] 3 tests from 1 test suite ran. (101 ms total)
[  PASSED  ] 3 tests.
```
All tests:
buck run //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64

https://www.internalfb.com/phabricator/paste/view/P738255852

Reviewed By: SS-JIA

Differential Revision: D45893511

fbshipit-source-id: 6875e1dcebc928282d8d6a2c795d3167e1dfb2cd
  • Loading branch information
lucylq authored and facebook-github-bot committed May 17, 2023
1 parent 66e3989 commit 2607021
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 0 deletions.
60 changes: 60 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/unsqueeze_2dto3d.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/*
* Output Image
*/
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;

/*
* Input Sampler
*/
layout(set = 0, binding = 1) uniform PRECISION sampler3D uImage;

/*
* Params Buffer
*/
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// dim: dimension to insert at. Only dim.x is read by main(); dim.y is not
// referenced in this shader.
ivec2 dim;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Writes one output texel per invocation for an unsqueeze that inserts a
 * size-one dimension at uBlock.dim.x.
 *
 * - dim 0 (or -3): the input texel at the same position is copied unchanged.
 * - dim 1 (or -2) and dim 2 (or -1): the first component of four consecutive
 *   input rows (y = 4 * pos.z + lane, at z = 0) is gathered into the four
 *   components of the output texel. The source column is pos.x for dim 1 and
 *   pos.y for dim 2.
 *
 * Any other dim value results in no store.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  const int dim = uBlock.dim.x;

  // Inserting at the front: plain texel-for-texel copy.
  if (dim == 0 || dim == -3) {
    imageStore(uOutput, pos, texelFetch(uImage, pos, 0));
    return;
  }

  const bool insert_middle = (dim == 1 || dim == -2);
  const bool insert_last = (dim == 2 || dim == -1);
  if (!insert_middle && !insert_last) {
    // Unrecognized dim: leave the output texel untouched.
    return;
  }

  // The two gather cases differ only in which output coordinate selects the
  // source column.
  const int src_x = insert_middle ? pos.x : pos.y;
  const int first_row = pos.z * 4;

  vec4 out_texel = vec4(0, 0, 0, 0);
  for (int lane = 0; lane < 4; ++lane) {
    const vec4 src = texelFetch(uImage, ivec3(src_x, first_row + lane, 0), 0);
    out_texel[lane] = src.x;
  }
  imageStore(uOutput, pos, out_texel);
}
106 changes: 106 additions & 0 deletions aten/src/ATen/native/vulkan/ops/Unsqueeze.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

// Parameter block uploaded to the unsqueeze_2dto3d shader as its uniform
// params buffer.
struct Block final {
// Dimension to insert at. unsqueeze_2dto3d only initializes the first
// component explicitly, and the shader only reads that component.
ivec2 dim;
};

// Inserts a size-one dimension at `dim` by dispatching the unsqueeze_2dto3d
// compute shader.
//
// input_arg: tensor to unsqueeze; it is moved to Vulkan first if it is not
//            already a Vulkan tensor. The caller in this file only passes
//            2D tensors.
// dim:       position of the new dimension. Negative values are counted from
//            the end of the output shape.
//
// Returns the tensor wrapping the freshly created output vTensor.
Tensor unsqueeze_2dto3d(const at::Tensor& input_arg, int64_t dim) {
  // Get the global Vulkan context
  api::Context* const context = api::context();

  // Cast the input Tensor to a vTensor
  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
  const vTensor& v_input = convert(input);

  // Build the output shape: the input shape with a 1 inserted at `dim`.
  std::vector<int64_t> output_size = input_arg.sizes().vec();
  const int64_t output_rank = static_cast<int64_t>(output_size.size()) + 1;
  // Normalize a negative dim against the output rank (3 for a 2D input).
  if (dim < 0) {
    dim += output_rank;
  }
  // Guard the iterator arithmetic below; inserting past end() is undefined.
  TORCH_CHECK(
      dim >= 0 && dim < output_rank,
      "Vulkan unsqueeze_2dto3d: dimension out of range for output rank ",
      output_rank,
      ", got ",
      dim);
  output_size.insert(output_size.begin() + dim, 1);

  // Create the output texture
  vTensor v_output{
      context,
      output_size,
      input_arg.scalar_type(),
  };

  // Required to determine how to insert memory barriers in the command buffer
  api::PipelineBarrier pipeline_barrier{};

  // Total number of work items is equal to the size of the output texture
  uvec3 global_size = v_output.extents();
  // Adaptively determine local work group size, will usually be {4, 4, 4}
  uvec3 local_size = adaptive_work_group_size(global_size);

  // Create the params buffer. The shader receives the already-normalized
  // (non-negative) dim in the first component.
  struct Block block {
    {
      static_cast<int32_t>(dim)
    }
  };
  api::UniformParamsBuffer params(context, block);

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(unsqueeze_2dto3d),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_output);
}

// Entry point registered for aten::unsqueeze on the Vulkan backend.
//
// self: input tensor; currently must be 2D.
// dim:  position at which to insert the size-one dimension; must lie in
//       [-self.dim() - 1, self.dim()].
//
// Fails a TORCH_CHECK if the rank or dim is unsupported; otherwise forwards
// to unsqueeze_2dto3d.
Tensor unsqueeze(const at::Tensor& self, int64_t dim) {
  // Both bounds must hold; with `||` this check could never fail.
  TORCH_CHECK(
      self.dim() >= 1 && self.dim() <= 3,
      "Vulkan unsqueeze supports 1d, 2d, 3d tensors as input!");
  TORCH_CHECK(
      dim >= -self.dim() - 1 && dim <= self.dim(),
      "Vulkan unsqueeze dimension out of range (expected to be in range of [",
      -self.dim() - 1,
      ",",
      self.dim(),
      "], but got ",
      dim);
  // Remove this when 1d->2d and 3d->4d are supported.
  TORCH_CHECK(self.dim() == 2, "Vulkan unsqueeze expects input dimension = 2!");
  return unsqueeze_2dto3d(self, dim);
}

#ifdef USE_VULKAN_API

// Registers unsqueeze() above as the implementation of aten::unsqueeze for
// the Vulkan dispatch key.
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
m.impl(TORCH_SELECTIVE_NAME("aten::unsqueeze"), TORCH_FN(unsqueeze));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at
41 changes: 41 additions & 0 deletions aten/src/ATen/test/vulkan_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3343,6 +3343,47 @@ TEST_F(VulkanAPITest, sub_to_scalar_wrapped) {
ASSERT_TRUE(check);
}

// Runs at::unsqueeze on a random float tensor of `input_shape` on both the CPU
// and Vulkan backends and asserts that the two results are almost equal.
//
// input_shape: shape of the random input tensor.
// dim:         dimension index passed to at::unsqueeze.
void test_unsqueeze(const at::IntArrayRef input_shape, int64_t dim) {
  const auto in_cpu =
      at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat));
  const auto out_cpu = at::unsqueeze(in_cpu, dim);

  const auto in_vulkan = in_cpu.vulkan();
  const auto out_vulkan = at::unsqueeze(in_vulkan, dim);

  const auto check = almostEqual(out_cpu, out_vulkan.cpu());
  if (!check) {
    // Print the tolerance diagnostics before failing the assertion.
    showRtol(out_cpu, out_vulkan.cpu());
  }
  ASSERT_TRUE(check);
}

// Inserting at the front, addressed by both the positive (0) and the negative
// (-3) index, for a small and a larger 2D shape.
TEST_F(VulkanAPITest, unsqueeze_dim0) {
  c10::InferenceMode guard;
  const std::vector<std::vector<int64_t>> shapes = {{5, 7}, {111, 222}};
  for (const auto& shape : shapes) {
    test_unsqueeze(shape, 0);
    test_unsqueeze(shape, -3);
  }
}

// Inserting in the middle, addressed by both the positive (1) and the negative
// (-2) index, for a small and a larger 2D shape.
TEST_F(VulkanAPITest, unsqueeze_dim1) {
  c10::InferenceMode guard;
  const std::vector<std::vector<int64_t>> shapes = {{5, 7}, {111, 222}};
  for (const auto& shape : shapes) {
    test_unsqueeze(shape, 1);
    test_unsqueeze(shape, -2);
  }
}

// Inserting at the end, addressed by both the positive (2) and the negative
// (-1) index, for a small and a larger 2D shape.
TEST_F(VulkanAPITest, unsqueeze_dim2) {
  c10::InferenceMode guard;
  const std::vector<std::vector<int64_t>> shapes = {{5, 7}, {111, 222}};
  for (const auto& shape : shapes) {
    test_unsqueeze(shape, 2);
    test_unsqueeze(shape, -1);
  }
}

TEST_F(VulkanAPITest, upsample_nearest2d) {
const auto in_cpu = at::rand({1, 2, 2, 3}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
const auto out_cpu = at::upsample_nearest2d(in_cpu, {4, 6});
Expand Down

0 comments on commit 2607021

Please sign in to comment.