[vulkan] Pad channels when using texture storage instead of "tight packing" (#95251)

Currently, in Vulkan, 4D tensors are represented in GPU textures by simply combining the batch and channel dimensions into the depth axis. However, if the number of channels is not a multiple of 4, then data belonging to the same batch can cross texel boundaries.

For instance, consider a tensor with `N=2`, `C=3`. The depth axis of the texture would contain the data

```
|tex1|tex2|
-----------
|AAAB|BB00|
```
where A represents data from `n=1` and B represents data from `n=2`.

This packing structure ("tight packing") makes some ops that care about batch boundaries more complex and less efficient to implement. Therefore, this diff introduces channel padding when storing tensors as image textures.

The same tensor with `N=2`, `C=3` would now have the depth axis contain

```
|tex1|tex2|
-----------
|AAA0|BBB0|
```
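
For a concrete comparison, the sketch below (standalone C++, not the PyTorch API; `align_up` here is a stand-in for `api::utils::align_up`) computes the number of depth-axis element slots under each layout:

```cpp
#include <cstdint>
#include <iostream>

// Stand-in for api::utils::align_up: round v up to the next multiple of m.
int64_t align_up(int64_t v, int64_t m) {
  return (v + m - 1) / m * m;
}

int main() {
  const int64_t N = 2, C = 3;

  // Tight packing: combine batch and channel, then align the result.
  // Batch data may straddle texel boundaries: |AAAB|BB00|
  const int64_t tight_depth = align_up(N * C, 4);

  // Channel padding: align each batch's channels to 4 first, so every
  // batch starts on a texel boundary: |AAA0|BBB0|
  const int64_t padded_depth = N * align_up(C, 4);

  std::cout << tight_depth / 4 << " vs " << padded_depth / 4 << " texels\n";
}
```

For `N=2`, `C=3` both layouts happen to occupy two texels; for channel counts further from a multiple of 4 (e.g. `C=5`), padding costs additional texels in exchange for batch-aligned addressing.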

Differential Revision: [D43068669](https://our.internmc.facebook.com/intern/diff/D43068669/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D43068669/)!
Pull Request resolved: pytorch/pytorch#95251
Approved by: https://github.com/salilsdesai
SS-JIA authored and cyyever committed Feb 25, 2023
1 parent 9d09418 commit 51046c6
Showing 16 changed files with 418 additions and 247 deletions.
7 changes: 3 additions & 4 deletions aten/src/ATen/native/vulkan/api/Tensor.cpp
```diff
@@ -124,8 +124,7 @@ c10::SmallVector<int64_t, 6u> calc_gpu_sizes(
 
   c10::SmallVector<int64_t, 6u> gpu_sizes(3);
 
-  // Channel dim will be always be aligned. For 4 dimensional tensors, batch
-  // and channel are combined, then aligned.
+  // Channel dim will be aligned to the next multiple of 4
   switch (ndim) {
     case 1:
       gpu_sizes[0] = 4;
@@ -146,8 +145,8 @@ c10::SmallVector<int64_t, 6u> calc_gpu_sizes(
       break;
 
     case 4:
-      int64_t combined_depth = sizes[0] * sizes[1];
-      gpu_sizes[0] = api::utils::align_up(combined_depth, INT64_C(4));
+      int64_t padded_c = api::utils::align_up(sizes[1], INT64_C(4));
+      gpu_sizes[0] = sizes[0] * padded_c;
       gpu_sizes[1] = sizes[2];
       gpu_sizes[2] = sizes[3];
       break;
```
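
The layout implied by the new `gpu_sizes` can be summarized with a hypothetical helper (illustration only, not part of this PR): element `(n, c, h, w)` lands at a depth index computed from the aligned channel count.

```cpp
#include <cstdint>

// Hypothetical helper, not part of the PR: where NCHW element (n, c, h, w)
// lands in the channel-padded texture.
struct TexelLoc {
  int64_t x, y, z; // texture coordinate
  int64_t comp;    // component within the texel, 0..3
};

TexelLoc nchw_to_texel(int64_t n, int64_t c, int64_t h, int64_t w, int64_t C) {
  const int64_t c_aligned = (C + 3) / 4 * 4; // channels padded up to 4
  const int64_t nc = n * c_aligned + c;      // depth index in elements
  return TexelLoc{w, h, nc / 4, nc % 4};     // 4 elements per texel
}
```

Because `c_aligned` is a multiple of 4, `nc / 4` never mixes two batches in one texel, which is exactly the property the shaders below rely on.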
91 changes: 59 additions & 32 deletions aten/src/ATen/native/vulkan/glsl/cat_feature.glsl
```diff
@@ -1,47 +1,74 @@
 #version 450 core
 #define PRECISION $precision
 #define FORMAT $format
 
 layout(std430) buffer;
 
-/* Qualifiers: layout - storage - precision - memory */
+/*
+ * Output Image
+ */
+layout(set = 0, binding = 0, FORMAT) uniform PRECISION image3D uOutput;
 
-layout(set = 0, binding = 0, FORMAT) uniform PRECISION image3D uOutput;
-layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
-layout(set = 0, binding = 2) uniform PRECISION restrict Block {
-  ivec4 size; // output texture size (x=width,y=height,z=depth,w=unused)
-  ivec4 isize; // input texture size (x=width,y=height,z=depth,w=unused)
-  uint batch_size; // input tensor's batch size
-  uint ch_size; // input tensor's channel size
-  uint ch_interval; // channel interval (total # of channels for all tensors)
-  uint ch_size_allprior; // # of channels for tensor 0 to i-1 at ith tensor
-} uBlock;
+/*
+ * Input Textures
+ */
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+
+/*
+ * Params Buffer
+ */
+layout(set = 0, binding = 2) uniform PRECISION restrict Block {
+  // output texture size (x=width,y=height,z=depth,w=unused)
+  ivec4 out_extents;
+  // input texture size (x=width,y=height,z=depth,w=unused)
+  ivec4 in_extents;
+  // input tensor's batch size
+  uint batch_size;
+  // input tensor's channel size
+  uint ch_size;
+  // channel interval (total # of channels for all tensors)
+  uint ch_interval;
+  // # of channels for tensor 0 to i-1 at ith tensor
+  uint ch_size_allprior;
+}
+uBlock;
 
+/*
+ * Local Work Group
+ */
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
 void main() {
-  const ivec3 posIn = ivec3(gl_GlobalInvocationID);
+  const ivec3 in_pos = ivec3(gl_GlobalInvocationID);
   const uint max_src_index = uBlock.ch_size * uBlock.batch_size;
 
-  if (all(lessThan(posIn, uBlock.isize.xyz))) {
-    ivec3 posOut = posIn; // x and y don't change. only z and index matter
-    const vec4 inval = texelFetch(uInput, posIn, 0);
-
-    for (uint i = 0; i < 4; ++i)
-    {
-      uint src_index = posIn.z * 4 + i;
-      if (src_index >= max_src_index) {
-        // out of range
-        break;
-      }
-
-      uint dst_index = uint(src_index / uBlock.ch_size) * uBlock.ch_interval + (src_index % uBlock.ch_size) + uBlock.ch_size_allprior;
-      posOut.z = int(dst_index / 4);
-      uint j = (dst_index % 4);
-
-      vec4 outval = imageLoad(uOutput, posOut);
-      outval[j] = inval[i];
-      imageStore(uOutput, posOut, outval);
-    }
+  if (any(greaterThanEqual(in_pos, uBlock.in_extents.xyz))) {
+    return;
+  }
+
+  // x and y don't change. only z and index matter
+  ivec3 out_pos = in_pos;
+  const vec4 in_tex = texelFetch(uInput, in_pos, 0);
+
+  for (uint i = 0; i < 4; ++i) {
+    uint src_index = in_pos.z * 4 + i;
+
+    if (src_index >= max_src_index) {
+      // out of range
+      break;
+    }
+
+    uint src_n_idx = src_index / uBlock.ch_size;
+    uint src_c_idx = src_index % uBlock.ch_size;
+
+    uint dst_nc_idx =
+        src_n_idx * uBlock.ch_interval + src_c_idx + uBlock.ch_size_allprior;
+
+    out_pos.z = int(dst_nc_idx / 4);
+    uint j = (dst_nc_idx % 4);
+
+    vec4 out_tex = imageLoad(uOutput, out_pos);
+    out_tex[j] = in_tex[i];
+    imageStore(uOutput, out_pos, out_tex);
   }
 }
```
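
The destination index arithmetic above can be mirrored on the host for testing; a sketch (hypothetical, reusing the shader's uniform names):

```cpp
#include <cstdint>

// Mirrors dst_nc_idx in cat_feature.glsl: maps the flattened (n, c) index of
// one input tensor to its flattened (n, c) index in the concatenated output.
uint32_t cat_dst_index(
    uint32_t src_index,          // n * ch_size + c within this input
    uint32_t ch_size,            // this input's channel count
    uint32_t ch_interval,        // total channels across all inputs
    uint32_t ch_size_allprior) { // channels of the inputs before this one
  const uint32_t src_n = src_index / ch_size;
  const uint32_t src_c = src_index % ch_size;
  return src_n * ch_interval + src_c + ch_size_allprior;
}
```

As in the shader, the output texel is then `dst_nc_idx / 4` along depth, with component `dst_nc_idx % 4`.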
27 changes: 20 additions & 7 deletions aten/src/ATen/native/vulkan/glsl/image_to_nchw.glsl
```diff
@@ -20,9 +20,10 @@ uBuffer;
  * Params Buffer
  */
 layout(set = 0, binding = 2) uniform PRECISION restrict Block {
-  // xyz contain the extents of the input texture, w contains HxW to help
-  // calculate buffer offsets
+  // Extents of the output texture
   ivec4 in_extents;
+  // Number of texels spanned by one channel
+  ivec2 c_info;
 }
 uBlock;
@@ -40,13 +41,25 @@ void main() {
 
   const vec4 intex = texelFetch(uImage, pos, 0);
 
+  const int n_index = int(pos.z / uBlock.c_info.x);
+  const int c_index = (pos.z % uBlock.c_info.x) * 4;
+  int d_offset = (n_index * uBlock.c_info.y) + c_index;
+
   const int base_index =
-      pos.x + uBlock.in_extents.x * pos.y + (4 * uBlock.in_extents.w) * pos.z;
+      pos.x + uBlock.in_extents.x * pos.y + uBlock.in_extents.w * d_offset;
   const ivec4 buf_indices =
       base_index + ivec4(0, 1, 2, 3) * uBlock.in_extents.w;
 
-  uBuffer.data[buf_indices.x] = intex.x;
-  uBuffer.data[buf_indices.y] = intex.y;
-  uBuffer.data[buf_indices.z] = intex.z;
-  uBuffer.data[buf_indices.w] = intex.w;
+  if (c_index < uBlock.c_info.y) {
+    uBuffer.data[buf_indices.x] = intex.x;
+  }
+  if (c_index + 1 < uBlock.c_info.y) {
+    uBuffer.data[buf_indices.y] = intex.y;
+  }
+  if (c_index + 2 < uBlock.c_info.y) {
+    uBuffer.data[buf_indices.z] = intex.z;
+  }
+  if (c_index + 3 < uBlock.c_info.y) {
+    uBuffer.data[buf_indices.w] = intex.w;
+  }
 }
```
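
A host-side mirror of the new addressing, for illustration (hypothetical helper; `c_info.x` is the per-batch texel count and `c_info.y` the real channel count, as used above):

```cpp
// Hypothetical mirror of image_to_nchw.glsl: copy one texel at texture
// position (x, y, z) into a dense NCHW buffer, skipping padded components.
void store_texel(float* buf, const float texel[4],
                 int x, int y, int z,
                 int W, int HxW,  // tensor width and H*W (in_extents.x, .w)
                 int c_texels,    // texels per batch's channels (c_info.x)
                 int C) {         // real channel count (c_info.y)
  const int n = z / c_texels;
  const int c0 = (z % c_texels) * 4;  // first channel held by this texel
  const int d_offset = n * C + c0;    // flattened (n, c) offset in NCHW
  const int base = x + W * y + HxW * d_offset;
  for (int k = 0; k < 4; ++k) {
    if (c0 + k < C) {                 // drop the padding components
      buf[base + k * HxW] = texel[k];
    }
  }
}
```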
59 changes: 41 additions & 18 deletions aten/src/ATen/native/vulkan/glsl/mean.glsl
```diff
@@ -1,54 +1,77 @@
 #version 450 core
 #define PRECISION $precision
 #define FORMAT $format
 
 layout(std430) buffer;
 
-/* Qualifiers: layout - storage - precision - memory */
+/*
+ * Output Image
+ */
+layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
 
-layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
-layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
-layout(set = 0, binding = 2) uniform PRECISION restrict Block {
-  ivec4 size;
-  ivec3 isize;
-} uBlock;
+/*
+ * Input Textures
+ */
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+
+/*
+ * Params Buffer
+ */
+layout(set = 0, binding = 2) uniform PRECISION restrict Block {
+  // extents of the output texture
+  // w contains pre-computed H*W of the input texture for convenience
+  ivec4 out_extents;
+  // extents of the input texture
+  // w contains size of input channels aligned to 4
+  ivec4 in_extents;
+}
+uBlock;
 
+/*
+ * Shared memory buffer
+ */
 shared vec4 sh_mem[64];
 
+/*
+ * Local Work Group
+ */
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+/*
+ * Computes the mean of an input tensor along the width, height, and channel
+ * axes.
+ */
 void main() {
   const ivec3 pos = ivec3(gl_GlobalInvocationID);
   const ivec3 tid = ivec3(gl_LocalInvocationID);
   const ivec3 group_size = ivec3(gl_WorkGroupSize);
 
-  if (pos.z < uBlock.isize.z) {
+  if (pos.z < uBlock.in_extents.z) {
     vec4 sum = vec4(0);
 
-    for (int y = tid.y; y < uBlock.isize.y; y+=group_size.y) {
-      for (int x = tid.x; x < uBlock.isize.x; x+=group_size.x) {
+    for (int y = tid.y; y < uBlock.in_extents.y; y += group_size.y) {
+      for (int x = tid.x; x < uBlock.in_extents.x; x += group_size.x) {
         sum += texelFetch(uInput, ivec3(x, y, pos.z), 0);
       }
     }
 
-    sh_mem[tid.z * group_size.y * group_size.x + tid.y * group_size.x + tid.x] = sum;
+    sh_mem[tid.z * group_size.y * group_size.x + tid.y * group_size.x + tid.x] =
+        sum;
   }
   memoryBarrierShared();
   barrier();
 
-  if (tid.y > 0 || tid.x > 0 || pos.z >= uBlock.size.z) {
+  if (tid.y > 0 || tid.x > 0 || pos.z >= uBlock.out_extents.z) {
     return;
   }
 
   vec4 total = vec4(0);
   for (int y = 0; y < group_size.y; ++y) {
     for (int x = 0; x < group_size.x; ++x) {
-      total += sh_mem[tid.z * group_size.y * group_size.x + y * group_size.x + x];
+      total +=
+          sh_mem[tid.z * group_size.y * group_size.x + y * group_size.x + x];
     }
   }
 
-  imageStore(
-      uOutput,
-      pos,
-      total / uBlock.size.w);
+  imageStore(uOutput, pos, total / uBlock.out_extents.w);
 }
```
93 changes: 55 additions & 38 deletions aten/src/ATen/native/vulkan/glsl/mean2d.glsl
```diff
@@ -1,73 +1,90 @@
 #version 450 core
 #define PRECISION $precision
 #define FORMAT $format
 
 layout(std430) buffer;
 
-/* Qualifiers: layout - storage - precision - memory */
-
-layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
-layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
-layout(set = 0, binding = 2) uniform PRECISION restrict Block {
-  ivec4 size;
-  ivec3 isize;
-} uBlock;
+/*
+ * Output Image
+ */
+layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
+
+/*
+ * Input Textures
+ */
+layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
+
+/*
+ * Params Buffer
+ */
+layout(set = 0, binding = 2) uniform PRECISION restrict Block {
+  // extents of the output texture
+  // w contains pre-computed H*W of the input texture for convenience
+  ivec4 out_extents;
+  // extents of the input texture
+  // w contains size of input channels aligned to 4
+  ivec4 in_extents;
+}
+uBlock;
 
+/*
+ * Shared memory buffer
+ */
 shared vec4 sh_mem[64];
 
+/*
+ * Local Work Group
+ */
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+/*
+ * Computes the mean of an input tensor along the width and height axes.
+ */
 void main() {
   const ivec3 pos = ivec3(gl_GlobalInvocationID);
   const ivec3 tid = ivec3(gl_LocalInvocationID);
   const ivec3 group_size = ivec3(gl_WorkGroupSize);
 
-  if (pos.z < uBlock.isize.z) {
+  if (pos.z < uBlock.in_extents.z) {
     vec4 sum = vec4(0);
 
-    for (int y = tid.y; y < uBlock.isize.y; y+=group_size.y) {
-      for (int x = tid.x; x < uBlock.isize.x; x+=group_size.x) {
+    for (int y = tid.y; y < uBlock.in_extents.y; y += group_size.y) {
+      for (int x = tid.x; x < uBlock.in_extents.x; x += group_size.x) {
         sum += texelFetch(uInput, ivec3(x, y, pos.z), 0);
       }
     }
 
-    sh_mem[tid.z * group_size.y * group_size.x + tid.y * group_size.x + tid.x] = sum;
+    sh_mem[tid.z * group_size.y * group_size.x + tid.y * group_size.x + tid.x] =
+        sum;
   }
   memoryBarrierShared();
   barrier();
 
-  if (tid.y > 0 || tid.x > 0 || pos.z >= uBlock.isize.z) {
+  if (tid.y > 0 || tid.x > 0 || pos.z >= uBlock.in_extents.z) {
     return;
   }
 
   vec4 total = vec4(0);
   for (int y = 0; y < group_size.y; ++y) {
     for (int x = 0; x < group_size.x; ++x) {
-      total += sh_mem[tid.z * group_size.y * group_size.x + y * group_size.x + x];
+      total +=
+          sh_mem[tid.z * group_size.y * group_size.x + y * group_size.x + x];
     }
   }
 
-  const vec4 outtex = total / uBlock.size.w;
-  const int zoutx = 4*pos.z;
-  const int width = uBlock.size.x;
-  const int maxlen = uBlock.size.x * uBlock.size.y;
-
-  const int zouty = min(zoutx + 1, maxlen);
-  ivec3 posy = ivec3((zouty)%width, (zouty)/width, 0);
-  vec4 outy = vec4(outtex.y, 0, 0, 0);
-  imageStore(uOutput, posy, outy);
-
-  const int zoutz = min(zoutx + 2, maxlen);
-  ivec3 posz = ivec3((zoutz)%width, (zoutz)/width, 0);
-  vec4 outz = vec4(outtex.z, 0, 0, 0);
-  imageStore(uOutput, posz, outz);
-
-  const int zoutw = min(zoutx + 3, maxlen);
-  ivec3 posw = ivec3((zoutw)%width, (zoutw)/width, 0);
-  vec4 outw = vec4(outtex.w, 0, 0, 0);
-  imageStore(uOutput, posw, outw);
-
-  ivec3 posx = ivec3(zoutx%width, zoutx/width, 0);
-  vec4 outx = vec4(outtex.x, 0, 0, 0);
-  imageStore(uOutput, posx, outx);
+  const vec4 outtex = total / uBlock.out_extents.w;
+
+  const int nc_idx = pos.z * 4;
+  const int out_width = uBlock.out_extents.x;
+  const int out_height = uBlock.out_extents.y;
+
+  for (int i = 0; i < 4; ++i) {
+    const int n_idx = (nc_idx + i) / uBlock.in_extents.w;
+    const int c_idx = (nc_idx + i) % uBlock.in_extents.w;
+
+    ivec3 pos = ivec3(c_idx, n_idx, 0);
+    if (c_idx < out_width && n_idx < out_height) {
+      imageStore(uOutput, pos, vec4(outtex[i], 0, 0, 0));
+    }
+  }
 }
```
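
The rewritten epilogue writes each reduced component to its own output location; a hypothetical host-side rendering of the same mapping:

```cpp
// Hypothetical mirror of the output write in mean2d.glsl: the reduced texel
// at depth z holds four consecutive entries of the padded (n, c) axis.
void write_mean2d(float* out, int out_width, int out_height,
                  const float texel[4], int z,
                  int c_aligned) {  // channels padded to 4 (in_extents.w)
  for (int i = 0; i < 4; ++i) {
    const int nc = z * 4 + i;
    const int n = nc / c_aligned;  // batch index
    const int c = nc % c_aligned;  // channel index; may fall in the padding
    if (c < out_width && n < out_height) {
      out[n * out_width + c] = texel[i];
    }
  }
}
```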
